get_sklearn_wrapper

hcrystalball.wrappers.get_sklearn_wrapper(model_cls, **model_params)

Factory function returning a model-specific SklearnWrapper initialized with the provided model_cls parameters.

This function is required for sklearn compatibility, because our SklearnWrapper needs to have all parameters of model_cls set already at the time the SklearnWrapper class is defined. Other wrappers do not need this factory function, since their regressor is already part of the wrapper.
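Why the parameters must be known at class-definition time: scikit-learn's BaseEstimator discovers an estimator's parameters (for get_params, set_params and clone) by inspecting the signature of its __init__, and it ignores any **kwargs. The stdlib-only sketch below imitates that rule to show the difference between a generic wrapper that accepts **model_params and a class built by a factory. It is not hcrystalball's implementation; ToyRegressor, get_param_names, GenericWrapper and make_wrapper_cls are hypothetical names used only for illustration.

```python
import inspect


class ToyRegressor:
    """Hypothetical stand-in for an sklearn regressor such as RandomForestRegressor."""

    def __init__(self, max_depth=None, n_estimators=100):
        self.max_depth = max_depth
        self.n_estimators = n_estimators


def get_param_names(cls):
    """Imitate sklearn's discovery rule: read __init__'s signature, skip self and **kwargs."""
    sig = inspect.signature(cls.__init__)
    return sorted(
        name
        for name, p in sig.parameters.items()
        if name != "self" and p.kind is not inspect.Parameter.VAR_KEYWORD
    )


class GenericWrapper:
    """Wrapper taking **model_params: the regressor's parameters stay hidden."""

    def __init__(self, model_cls, clip_predictions_lower=None, **model_params):
        self.model_cls = model_cls
        self.clip_predictions_lower = clip_predictions_lower
        self.model_params = model_params


def make_wrapper_cls(model_cls):
    """Hypothetical factory: build a wrapper class whose __init__ signature lists
    every model_cls parameter plus a wrapper-specific one (assumes all have defaults)."""
    kw = inspect.Parameter.KEYWORD_ONLY
    model_params = [
        p.replace(kind=kw)
        for name, p in inspect.signature(model_cls.__init__).parameters.items()
        if name != "self"
    ]
    wrapper_params = [inspect.Parameter("clip_predictions_lower", kw, default=None)]
    self_param = inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
    params = [self_param] + model_params + wrapper_params
    defaults = {p.name: p.default for p in params[1:]}

    def __init__(self, **kwargs):
        unknown = set(kwargs) - set(defaults)
        if unknown:
            raise TypeError(f"unexpected parameters: {sorted(unknown)}")
        # Store every declared parameter as an attribute, as sklearn expects.
        for name, default in defaults.items():
            setattr(self, name, kwargs.get(name, default))

    # Advertise the explicit signature so introspection sees each parameter.
    __init__.__signature__ = inspect.Signature(params)
    return type(f"{model_cls.__name__}Wrapper", (), {"__init__": __init__, "model_cls": model_cls})


print(get_param_names(GenericWrapper))  # ['clip_predictions_lower', 'model_cls']
Wrapper = make_wrapper_cls(ToyRegressor)
print(get_param_names(Wrapper))  # ['clip_predictions_lower', 'max_depth', 'n_estimators']
est = Wrapper(max_depth=6, clip_predictions_lower=0.0)
print(est.max_depth, est.n_estimators, est.clip_predictions_lower)  # 6 100 0.0
```

With the generic wrapper, max_depth would be invisible to introspection and therefore lost by any sklearn-style clone; the factory-built class exposes it as an ordinary parameter.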

Parameters
  • model_cls (class of sklearn-compatible regressor) – e.g. LinearRegression, GradientBoostingRegressor

  • model_params – model_cls-specific parameters (e.g. max_depth) and/or SklearnWrapper-specific parameters (e.g. clip_predictions_lower)
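Because model_params mixes two kinds of keywords, one way to route them is to check which names the regressor's __init__ declares. The sketch below assumes that routing rule; split_params and ToyRegressor are hypothetical names, and this is not necessarily how hcrystalball separates the parameters internally.

```python
import inspect


class ToyRegressor:
    """Hypothetical stand-in for an sklearn regressor."""

    def __init__(self, max_depth=None, n_estimators=100):
        self.max_depth = max_depth
        self.n_estimators = n_estimators


def split_params(model_cls, **model_params):
    """Hypothetical helper: send each keyword to model_cls if its __init__ declares it,
    otherwise treat it as a wrapper-specific parameter."""
    accepted = set(inspect.signature(model_cls.__init__).parameters) - {"self"}
    regressor_params = {k: v for k, v in model_params.items() if k in accepted}
    wrapper_params = {k: v for k, v in model_params.items() if k not in accepted}
    return regressor_params, wrapper_params


reg, wrap = split_params(ToyRegressor, max_depth=6, clip_predictions_lower=0.0)
print(reg)  # {'max_depth': 6}
print(wrap)  # {'clip_predictions_lower': 0.0}
model = ToyRegressor(**reg)  # only regressor parameters reach the regressor
```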

Example

>>> from hcrystalball.wrappers import get_sklearn_wrapper
>>> from sklearn.ensemble import RandomForestRegressor
>>> est = get_sklearn_wrapper(RandomForestRegressor, max_depth=6, clip_predictions_lower=0.)
>>> est
SklearnWrapper(bootstrap=True, ccp_alpha=0.0, clip_predictions_lower=0.0,
           clip_predictions_upper=None, criterion='mse', fit_params=None,
           lags=3, max_depth=6, max_features='auto', max_leaf_nodes=None,
           max_samples=None, min_impurity_decrease=0.0,
           min_impurity_split=None, min_samples_leaf=1, min_samples_split=2,
           min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=None,
           name='sklearn', oob_score=False, optimize_for_horizon=False,
           random_state=None, verbose=0, warm_start=False)

Returns

An SklearnWrapper instance for model_cls with the given parameters set.

Return type

SklearnWrapper