Source code for hcrystalball.wrappers._prophet

import itertools

# redirect prophets and pystans output to the console
import logging
import sys

from hcrystalball.wrappers._base import TSModelWrapper
from hcrystalball.wrappers._base import tsmodel_wrapper_constructor_factory

sys_out = logging.StreamHandler(sys.__stdout__)
sys_out.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
logging.getLogger("prophet").addHandler(sys_out)
logging.getLogger("pystan").addHandler(sys_out)
logger = logging.getLogger("fbprophet.plot")
logger.setLevel(logging.CRITICAL)

import pandas as pd
from prophet import Prophet

from hcrystalball.utils import check_fit_before_predict
from hcrystalball.utils import check_X_y
from hcrystalball.utils import deep_dict_update
from hcrystalball.utils import enforce_y_type
from hcrystalball.utils import set_verbosity

pd.plotting.register_matplotlib_converters()


[docs]class ProphetWrapper(TSModelWrapper): """Wrapper for prophet.Prophet model https://facebook.github.io/prophet/docs/quick_start.html#python-api Bring prophet to sklearn time-series compatible interface and puts fit parameters to initialization stage. Parameters ---------- name : str Name of the model instance, used also as column name for returned prediction. conf_int : bool Whether confidence intervals should be also outputed. full_prophet_output: bool Whether the `predict` method should output the full prophet.Prophet dataframe. extra_seasonalities : list of dicts Dictionary will be passed to prophet.Prophet add_regressor method. extra_regressors : list or list of dicts Dictionary will be passed to prophet.Prophet add_seasonality method. extra_holidays : dict of dict Dict with name of the holiday and values as another dict with required 'lower_window' key and 'upper_window' key and optional 'prior_scale' key i.e.{'holiday_name': {'lower_window':1, 'upper_window:1, 'prior_scale: 10}}. fit_params : dict Parameters passed to `fit` prophet.Prophet model. clip_predictions_lower : float Minimal value allowed for predictions - predictions will be clipped to this value. clip_predictions_upper : float Maximum value allowed for predictions - predictions will be clipped to this value. hcb_verbose : bool Whtether to keep (True) or suppress (False) messages to stdout and stderr from the wrapper and 3rd party libraries during fit and predict """ @tsmodel_wrapper_constructor_factory(Prophet) def __init__( self, name="prophet", conf_int=False, full_prophet_output=False, extra_seasonalities=None, extra_regressors=None, extra_holidays=None, fit_params=None, clip_predictions_lower=None, clip_predictions_upper=None, hcb_verbose=True, ): """This constructor will be modified at runtime to accept all parameters of the Prophet class on top of the ones defined here!""" pass @staticmethod def _transform_data_to_tsmodel_input_format(X, y=None): """Trasnform data into `Prophet.model` required format Parameters ---------- X : pandas.DataFrame Input features with required 'date' column. y : array_like, (1d) Target vector Returns ------- pandas.DataFrame """ if y is not None: X = X.assign(y=y) return X.assign(ds=lambda x: x.index).reset_index(drop=True) def _set_model_extra_params(self, model): """Add `extra_seasonalities` and `extra_regressors` to `Prophet.model` Parameters ---------- model : Prophet.model model to be extended with extra seasonalities and regressors Returns ------- Prophet.model model with extra seasonalities and extra regressors """ if self.extra_seasonalities is not None: for s in self.extra_seasonalities: model.add_seasonality(**s) if self.extra_regressors is not None: for r in self.extra_regressors: if isinstance(r, str): model.add_regressor(r) else: model.add_regressor(**r) return model def _adjust_holidays(self, X): """Add `holidays` to `Prophet.model` Doing that in required form and drop the 'holiday' column from X Parameters ---------- X : pandas.DataFrame Input features with 'holiday' column. Returns ------- pandas.DataFrame Input features without 'holiday' column """ holiday_cols = [col for col in X.filter(like="_holiday_").select_dtypes(include="object").columns] unique_holiday_dict = {col: X.loc[X[col] != "", col].unique() for col in holiday_cols} extra_holidays = { col: { holiday: { "lower_window": self._get_holiday_windows(X, f"_before{col}"), "upper_window": self._get_holiday_windows(X, f"_after{col}"), "prior_scale": self.holidays_prior_scale, } for holiday in holidays } for col, holidays in unique_holiday_dict.items() } if self.extra_holidays: extra_holidays = {k: deep_dict_update(v, self.extra_holidays) for k, v in extra_holidays.items()} unique_holiday = set(itertools.chain.from_iterable(unique_holiday_dict.values())) all_extra_holidays = set(itertools.chain.from_iterable(extra_holidays.values())) if len(unique_holiday) > 0: missing_holidays = all_extra_holidays.difference(unique_holiday) if missing_holidays: logging.warning( f"""Following holidays weren't found in data; thus not being used {missing_holidays}. Available holidays for this data: {unique_holiday}""" ) holidays = [] for col in holiday_cols: # assign country code/country code column to the holiday names # to ensure single occurence of a holiday per country # (e.g. `BE` and `DE` both have Christmas Day -> Christmas Day_DE, Christmas Day_BE) inter = X.loc[X[col] != "", [col]].assign( **{"holiday": lambda df: df[col] + f"_{col.split('_')[2]}"} ) if not inter.empty: # translate original holiday name to extra information on the holiday affect # given the extra_holidays parameter holidays.append( inter.merge( inter[col].map(extra_holidays[col]).apply(pd.Series), left_index=True, right_index=True, ).loc[ :, ["holiday", "lower_window", "upper_window", "prior_scale"], ] ) self.model.holidays = pd.concat(holidays).assign(ds=lambda x: x.index).reset_index(drop=True) return X.drop(columns=holiday_cols, errors="ignore") def _get_holiday_windows(self, X, col_like): """Get information about window for holidays for particular country. Parameters ---------- X : pandas.DataFrame Input features with 'col_like' column. col_like: str col name pattern (i.e. `_before_holiday_DE`) Returns ------- int number of days around holidays (whether before or after depends on `col_like`) """ window = X.filter(like=f"{col_like}") window = 0 if window.empty else window.columns[0].split("_")[1] return int(window)
[docs] @enforce_y_type @check_X_y @set_verbosity def fit(self, X, y): """Transform input data to `Prophet.model` required format and fit the model. Parameters ---------- X : pandas.DataFrame Input features. y : array_like, (1d) Target vector. Returns ------- self """ # TODO Add regressors which are not in self.extra_regressors but are in X? self.model = self._init_tsmodel(Prophet) if X.filter(like="_holiday_").shape[1] > 0: X = self._adjust_holidays(X) df = self._transform_data_to_tsmodel_input_format(X, y) self.model.fit(df, **self.fit_params) if self.fit_params else self.model.fit(df) self.fitted = True return self
[docs] @check_fit_before_predict @set_verbosity def predict(self, X): """Adjust holidays, transform data to required format and provide predictions. Parameters ---------- X : pandas.DataFrame Input features. Returns ------- pandas.DataFrame with pandas.DatetimeIndex Prediction is stored in column with name being the `name` of the wrapper. If `conf_int` attribute is set to True, the returned DataFrame will have three columns, with the second and third (named 'name'_lower and 'name'_upper). If `full_prophet_output` is set to True, then full Prophet.model.predict output is returned. """ if X.filter(like="_holiday_").shape[1] > 0: X = self._adjust_holidays(X) df = self._transform_data_to_tsmodel_input_format(X) preds = ( self.model.predict(df) .rename( columns={ "yhat": self.name, "yhat_lower": f"{self.name}_lower", "yhat_upper": f"{self.name}_upper", } ) .drop(columns="ds", errors="ignore") ) if not self.full_prophet_output: if self.conf_int: preds = preds[[self.name, f"{self.name}_lower", f"{self.name}_upper"]] else: preds = preds[[self.name]] preds.index = X.index return self._clip_predictions(preds)
__all__ = ["ProphetWrapper"]