Handling Level Changes for Prophet Predictions

I have a dataset like so:

[image: plot of the dataset]

The data is seasonal, but there is a level shift at some point.

I want Prophet to adapt to the data after the level shift faster. How can I do this?

I've read through the docs, and there are some options:

  • delete older data

But is there any way to force Prophet to adapt to the level-shifted data "faster"?
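For reference, here is a minimal sketch of the "delete older data" option, using only the standard library. It rebuilds the toy series, finds the largest jump between consecutive values, and drops everything before it. The helper names (`make_series`, `trim_before_shift`) are made up for illustration, and treating the single largest jump as the shift point is an assumption that only suits data with one clear level change.

```python
from datetime import datetime, timezone


def make_series(total=100, shift_at=10):
    """Rebuild the toy hourly series: about 20 before the shift, about 100 after."""
    rows = []
    for i in range(total):
        level = 20 if i < shift_at else 100
        ds = datetime.fromtimestamp(3600 * i, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
        rows.append({"ds": ds, "y": level + i % 3})
    return rows


def trim_before_shift(rows):
    """Drop every row before the largest jump between consecutive values.

    Assumption: the series contains exactly one dominant level shift.
    """
    jumps = [abs(rows[i]["y"] - rows[i - 1]["y"]) for i in range(1, len(rows))]
    shift_index = jumps.index(max(jumps)) + 1  # first row at the new level
    return rows[shift_index:]


rows = make_series()
recent = trim_before_shift(rows)
```

The trimmed rows can be loaded with `pd.DataFrame(recent)` and fitted as usual; the cost is that the model then has less history from which to estimate seasonality.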

Here is a repro:

import pandas as pd
from prophet import Prophet
from datetime import datetime, timezone
import matplotlib.pyplot as plt


def get_dataset():
    # Hourly toy series: about 20 before the level shift, about 100 after it
    d = {}
    total = 100
    level_shift_point = 10
    values = [20+i%3 for i in range(level_shift_point)]
    for i in range(level_shift_point, total):
        values.append(100 + i%3)

    d["y"] = values
    # One timestamp per hour, starting at the Unix epoch (UTC)
    d["ds"] = [datetime.fromtimestamp(3600*i, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S') for i in range(total)]
    return pd.DataFrame.from_dict(d)


df = get_dataset()


# Note: changepoint_prior_scale=0.0001 makes the trend very rigid (the default is 0.05)
m = Prophet(changepoint_prior_scale=0.0001)
m.fit(df)

future = m.make_future_dataframe(periods=100, freq="h", include_history=False)
forecast = m.predict(future)
m.plot(forecast)
plt.show()

[image: plot of the resulting forecast]

As you can see, the predictions don't make any sense at all.

I want the predictions to align with the data after the level shift. How can I do this?

Answer (by nz_21)

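Use a one-off holiday to absorb the old level: mark the pre-shift period as a "holiday" so that Prophet attributes the lower values to the holiday effect rather than to the trend.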

import pandas as pd
from prophet import Prophet
from datetime import datetime, timezone
import matplotlib.pyplot as plt


def to_str_date(i):
    # Timestamp string for hour i after the Unix epoch (UTC)
    return datetime.fromtimestamp(3600*i, tz=timezone.utc).strftime('%Y-%m-%d %H:%M:%S')


def get_holiday_df():
    # One-off "holiday" on the pre-shift day; the windows are counted in days
    d = {
        "holiday": ["one"],
        "ds": [to_str_date(1)],
        "upper_window": [1],
        "lower_window": [0]
    }
    return pd.DataFrame.from_dict(d)


def get_dataset():
    d = {}
    total = 100
    level_shift_point = 24  # the first 24 hourly points are at the old level
    values = [20+i%3 for i in range(level_shift_point)]
    for i in range(level_shift_point, total):
        values.append(100 + i%3)

    d["y"] = values
    d["ds"] = [to_str_date(i) for i in range(total)]
    return pd.DataFrame.from_dict(d)


df = get_dataset()

hol = get_holiday_df()


# Pass the holiday table so the old level is modelled as a holiday effect
m = Prophet(holidays=hol)
m.fit(df)

future = m.make_future_dataframe(periods=100, freq="h", include_history=False)
forecast = m.predict(future)
m.plot(forecast)
plt.show()
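To check which rows the holiday above actually covers, here is a small standard-library sketch. It rests on an assumption about Prophet's behaviour that the answer does not state: that `lower_window` and `upper_window` are counted in whole days, and that sub-daily rows are matched to a holiday by calendar date. `holiday_days` and `covered_hours` are hypothetical helpers, not part of Prophet.

```python
from datetime import datetime, timedelta, timezone


def holiday_days(ds, lower_window, upper_window):
    """Calendar dates covered by one holiday row (assumes windows are in days)."""
    day = datetime.strptime(ds, "%Y-%m-%d %H:%M:%S").date()
    return {day + timedelta(days=k) for k in range(lower_window, upper_window + 1)}


def covered_hours(days, total=100):
    """Indices of the hourly rows (hour i after the epoch, UTC) whose date is in `days`."""
    return [
        i for i in range(total)
        if datetime.fromtimestamp(3600 * i, tz=timezone.utc).date() in days
    ]


# Window as posted in the answer: holiday day plus one day after
as_posted = covered_hours(holiday_days("1970-01-01 01:00:00", 0, 1))
# Narrower window: the holiday day only
one_day = covered_hours(holiday_days("1970-01-01 01:00:00", 0, 0))
```

Under that assumption, the posted window spans hours 0 to 47, which is the 24 pre-shift hours plus the first 24 post-shift hours, while `upper_window=0` would cover only the pre-shift day.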