Advanced Sarimax Usage

Unless you can reliably read the correct SARIMAX orders off the PACF (partial autocorrelation function) and ACF (autocorrelation function), let AutoARIMA find them for you by setting init_with_autoarima to True.
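To give a feel for what reading orders from an ACF involves, here is a minimal pure-Python sketch of the sample autocorrelation on a toy series with a weekly pattern. The helper `acf` and the toy data are illustrative assumptions, not part of hcrystalball or pmdarima, and this is not how AutoARIMA performs its search.

```python
def acf(series, max_lag):
    """Sample autocorrelation for lags 0..max_lag (illustrative helper)."""
    n = len(series)
    mean = sum(series) / n
    centered = [v - mean for v in series]
    denom = sum(c * c for c in centered)
    return [
        sum(centered[t] * centered[t + k] for t in range(n - k)) / denom
        for k in range(max_lag + 1)
    ]


# Toy series with a period of 7 (a weekly pattern), repeated 8 times.
weekly = [5, 6, 7, 8, 9, 3, 0] * 8
r = acf(weekly, 14)

# Among lags 1..7 the autocorrelation is highest at lag 7,
# which hints at a seasonal period of m=7.
peak_lag = max(range(1, 8), key=lambda k: r[k])
print(peak_lag)
```

On real, noisy data the peaks are far less clear-cut, which is why an automated search is usually the more practical choice.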

To further configure the AutoARIMA search space, pass any parameters of pmdarima.arima.AutoARIMA through autoarima_dict.
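As a rough illustration, a search-space dictionary might look like the sketch below. The keys are keyword arguments of pmdarima.arima.AutoARIMA; the particular values are assumptions chosen for daily data with a weekly pattern, not recommendations.

```python
# Search-space configuration; every key is an AutoARIMA keyword argument.
autoarima_dict = {
    "d": 1,          # fix the order of first differencing
    "m": 7,          # seasonal period: 7 observations per cycle
    "max_p": 2,      # upper bound for the AR order p
    "max_q": 2,      # upper bound for the MA order q
    "seasonal": True,                  # also search the seasonal part
    "stepwise": True,                  # stepwise search instead of a full grid
    "information_criterion": "aic",    # criterion used to compare candidates
}

# Assumed usage, mirroring the cells in this notebook:
# model = SarimaxWrapper(init_with_autoarima=True, autoarima_dict=autoarima_dict)
```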

When you run cross-validation with AutoARIMA enabled (init_with_autoarima=True), it is often advisable to search for the correct order only during the first fit call and reuse this model on all other splits in order to simulate out-of-sample performance.
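The search-once idea can be sketched in plain Python as follows. `SearchOnceModel`, `expanding_splits`, and `toy_search` are hypothetical names made up for this illustration; they are not hcrystalball's implementation, and the fixed order returned by `toy_search` merely stands in for an expensive AutoARIMA-style search.

```python
class SearchOnceModel:
    """Hypothetical stand-in that searches for an order only once."""

    def __init__(self, search_fn):
        self.search_fn = search_fn  # expensive order search (assumed)
        self.order = None           # filled in by the first fit
        self.n_searches = 0

    def fit(self, y):
        if self.order is None:      # search only when no order is known yet
            self.order = self.search_fn(y)
            self.n_searches += 1
        # A real model would now estimate coefficients for self.order on y.
        return self


def expanding_splits(n_obs, n_splits, horizon):
    """Yield (train_end, test_end) index pairs for expanding-window CV."""
    for i in range(n_splits, 0, -1):
        train_end = n_obs - i * horizon
        yield train_end, train_end + horizon


def toy_search(y):
    # Placeholder for an order search; returns a fixed order.
    return (0, 1, 2)


y = list(range(100))
model = SearchOnceModel(toy_search)
for train_end, test_end in expanding_splits(len(y), n_splits=3, horizon=10):
    model.fit(y[:train_end])

# The order was searched on the first split only and reused afterwards.
print(model.order, model.n_searches)
```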

The signature of SarimaxWrapper exposes the parameters of pmdarima.arima.ARIMA, not those of the AutoARIMA class.

For more parameters, check the pmdarima docs.

[1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('seaborn')
plt.rcParams['figure.figsize'] = [12, 6]
[2]:
from hcrystalball.utils import get_sales_data

df = get_sales_data(n_dates=100,
                    n_assortments=1,
                    n_states=1,
                    n_stores=1)
X, y = df[["Open"]], df['Sales']
[3]:
X
[3]:
Open
Date
2015-04-23 True
2015-04-24 True
2015-04-25 True
2015-04-26 False
2015-04-27 True
... ...
2015-07-27 True
2015-07-28 True
2015-07-29 True
2015-07-30 True
2015-07-31 True

100 rows × 1 columns

[4]:
from hcrystalball.wrappers import SarimaxWrapper
[5]:
SarimaxWrapper?
[6]:
model = SarimaxWrapper(
    autoarima_dict={'d':1, 'm':7, 'max_p':2, 'max_q':2},
    init_with_autoarima=True
)
[7]:
preds = (model.fit(X[:-10], y[:-10])
         .predict(X[-10:])
         .merge(y, left_index=True, right_index=True, how='outer')
         .tail(50)
)
preds.plot(title=f"MAE:{(preds['Sales']-preds['sarimax']).abs().mean().round(3)}");
[7]:
<AxesSubplot:title={'center':'MAE:3151.698'}, xlabel='Date'>
[Figure: Sales and sarimax forecast over the last 50 days, titled MAE:3151.698]
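The MAE in the plot title is computed after an outer merge, so it may help to see why only the forecast rows contribute. The sketch below uses made-up numbers (not the sales data) and assumes pandas is available; the column names mirror the ones used in the cell above.

```python
import pandas as pd

idx = pd.date_range("2015-07-27", periods=5, freq="D")
actual = pd.Series([10.0, 12.0, 11.0, 13.0, 9.0], index=idx, name="Sales")

# Forecasts exist only for the last two days, similar to predict(X[-10:]).
forecast = pd.DataFrame({"sarimax": [12.0, 10.0]}, index=idx[-2:])

# The outer merge keeps every date; dates without a forecast get NaN.
preds = forecast.merge(actual, left_index=True, right_index=True, how="outer")

# mean() skips NaN differences, so the MAE covers the forecast days only.
mae = (preds["Sales"] - preds["sarimax"]).abs().mean()
print(mae)
```

The history rows are kept purely so the plot shows the actuals leading up to the forecast window.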

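Now inspect the model's parameters. After fitting, the wrapper shows the order and seasonal_order that AutoARIMA found, and init_with_autoarima is shown as False.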

[8]:
model
[8]:
SarimaxWrapper(always_search_model=False,
               autoarima_dict={'d': 1, 'm': 7, 'max_p': 2, 'max_q': 2},
               clip_predictions_lower=None, clip_predictions_upper=None,
               conf_int=False, hcb_verbose=True, init_with_autoarima=False,
               maxiter=50, method='lbfgs', name='sarimax', order=(0, 1, 2),
               out_of_sample_size=0, scoring='mse', scoring_args={},
               seasonal_order=(0, 0, 2, 7), start_params=None,
               suppress_warnings=True, trend=None, with_intercept=False)

You can also pass the orders directly if you know what you are doing.

[9]:
model = SarimaxWrapper(order=(1, 1, 2), seasonal_order=(1, 0, 2, 7))
[10]:
preds = (model.fit(X[:-10], y[:-10])
         .predict(X[-10:])
         .merge(y, left_index=True, right_index=True, how='outer')
         .tail(50)
)
preds.plot(title=f"MAE:{(preds['Sales']-preds['sarimax']).abs().mean().round(3)}");
[10]:
<AxesSubplot:title={'center':'MAE:2154.027'}, xlabel='Date'>
[Figure: Sales and sarimax forecast over the last 50 days, titled MAE:2154.027]
[11]:
model
[11]:
SarimaxWrapper(always_search_model=False, autoarima_dict=None,
               clip_predictions_lower=None, clip_predictions_upper=None,
               conf_int=False, hcb_verbose=True, init_with_autoarima=False,
               maxiter=50, method='lbfgs', name='sarimax', order=(1, 1, 2),
               out_of_sample_size=0, scoring='mse', scoring_args=None,
               seasonal_order=(1, 0, 2, 7), start_params=None,
               suppress_warnings=False, trend=None, with_intercept=True)