import numpy as np
import multiprocessing as actual_processing
import multiprocessing.dummy as dummy_processing
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array, column_or_1d as c1d
from sklearn.model_selection import ParameterGrid

import tbats.error as error


class Estimator(BaseEstimator):
    """Base estimator for BATS and TBATS models

    Methods
    -------
    fit(y)
        Fit to y and select best performing model based on AIC criterion.
    """

    def __init__(self, context,
                 use_box_cox=None, box_cox_bounds=(0, 1),
                 use_trend=None, use_damped_trend=None,
                 seasonal_periods=None, use_arma_errors=True,
                 n_jobs=None):
        """ Class constructor

        Parameters
        ----------
        context: abstract.ContextInterface
            For advanced users only. Provide this to override default behaviors.
            Warnings and exceptions raised by the estimator are routed through
            ``context.get_exception_handler()``.
        use_box_cox: bool or None, optional (default=None)
            If Box-Cox transformation of original series should be applied.
            When None both cases shall be considered and better is selected by AIC.
        box_cox_bounds: tuple, shape=(2,), optional (default=(0, 1))
            Minimal and maximal Box-Cox parameter values.
        use_trend: bool or None, optional (default=None)
            Indicates whether to include a trend or not.
            When None both cases shall be considered and better is selected by AIC.
        use_damped_trend: bool or None, optional (default=None)
            Indicates whether to include a damping parameter in the trend or not.
            Applies only when trend is used. When use_trend is False this is forced
            to False (with a warning if it was explicitly True).
            When None both cases shall be considered and better is selected by AIC.
        seasonal_periods: iterable or array-like, optional (default=None)
            Length of each of the periods (amount of observations in each period).
            Normalized by the subclass' ``_normalize_seasonal_periods``.
            BATS accepts only int values here.
            When None or empty array, non-seasonal model shall be fitted.
        use_arma_errors: bool, optional (default=True)
            When True BATS will try to improve the model by modelling residuals with ARMA.
            Best model will be selected by AIC.
            If False, ARMA residuals modeling will not be considered.
        n_jobs: int, optional (default=None)
            How many jobs to run in parallel when fitting BATS model.
            When not provided BATS shall try to utilize all available cpu cores.
            When 1, fitting runs in a single worker thread instead of a process pool.
        """
        # context must be set first: _normalize_seasonal_periods and the
        # damped-trend warning below both go through it.
        self.context = context
        self.n_jobs = n_jobs
        self.seasonal_periods = self._normalize_seasonal_periods(seasonal_periods)
        self.use_box_cox = use_box_cox
        self.box_cox_bounds = box_cox_bounds
        self.use_arma_errors = use_arma_errors
        self.use_trend = use_trend
        if use_trend is False:
            if use_damped_trend is True:
                self.context.get_exception_handler().warn(
                    "When use_damped_trend can be used only with use_trend. Setting damped trend to False.",
                    error.InputArgsWarning
                )
            # A damping parameter is meaningless without a trend.
            use_damped_trend = False
        self.use_damped_trend = use_damped_trend

    def _normalize_seasonal_periods(self, seasonal_periods):
        """Validate and normalize seasonal periods (abstract; implemented by subclasses)."""
        raise NotImplementedError()

    def _do_fit(self, y):
        """Fit validated, non-constant series ``y`` and return the best model (abstract)."""
        raise NotImplementedError()

    def fit(self, y):
        """Fit model to observations ``y``.

        :param y: array-like or iterable, shape=(n_samples,)
        :return: abstract.Model, Fitted model, or None when ``y`` is invalid
            and the exception handler did not raise.
        """
        y = self._validate(y)
        if y is False:
            # Input data is not valid and no exception was raised yet.
            # This can happen only when one overrides default exception handler (see tbats.error.ExceptionHandler)
            return None

        # A (numerically) constant series needs no model search.
        if np.allclose(y, y[0]):
            return self.context.create_constant_model(y[0]).fit(y)

        best_model = self._do_fit(y)
        # Surface warnings collected by the chosen model through the exception handler.
        for warning in best_model.warnings:
            self.context.get_exception_handler().warn(warning, error.ModelWarning)
        return best_model

    def _validate(self, y):
        """Validates input time series.

        Also adjusts box_cox if necessary.

        Returns
        -------
        numpy.ndarray or False
            A 1-D float64 copy of ``y``, or False when ``y`` is invalid and the
            exception handler did not raise.
        """
        try:
            # NOTE(review): scikit-learn 1.6 renamed check_array's force_all_finite
            # to ensure_all_finite; confirm against the supported sklearn versions.
            y = c1d(check_array(y, ensure_2d=False, force_all_finite=True, ensure_min_samples=1,
                                copy=True, dtype=np.float64))  # type: np.ndarray
        except Exception as validation_exception:
            self.context.get_exception_handler().exception(
                "y series is invalid",
                error.InputArgsException,
                previous_exception=validation_exception
            )
            return False

        # Box-Cox requires strictly positive data, so zeros disable it too
        # (the warning text only mentions negative values).
        if np.any(y <= 0):
            if self.use_box_cox is True:
                self.context.get_exception_handler().warn(
                    "Box-Cox transformation (use_box_cox) was forced to True "
                    "but there are negative values in input series. "
                    "Setting use_box_cox to False.",
                    error.InputArgsWarning
                )
            # NOTE: this mutates a constructor parameter during fit; the
            # components grids read self.use_box_cox afterwards.
            self.use_box_cox = False
        return y

    def _case_fit(self, components_combination):
        """Internal method used by parallel computation.

        Builds a case from one grid entry and fits it to ``self._y``.
        """
        case = self.context.create_case_from_dictionary(**components_combination)
        return case.fit(self._y)

    def _choose_model_from_possible_component_settings(self, y, components_grid):
        """Fits all models in a grid and returns best one by AIC

        Returns
        -------
        abstract.Model or None
            Best model by AIC (the first one on ties), or None when the grid is empty.
        """
        # _case_fit reads the series from self._y; with a process pool the bound
        # method (and therefore this estimator) is pickled to the workers.
        self._y = y
        try:
            # note n_jobs = None means to use cpu_count()
            # The context manager guarantees the pool is shut down even if a fit raises.
            with self._prepare_pool(self.n_jobs) as pool:
                models = pool.map(self._case_fit, components_grid)
        finally:
            self._y = None  # clean-up, also on failure

        if not models:
            return None
        # min() keeps the first model with the lowest AIC, like a strict '<' scan.
        return min(models, key=lambda model: model.aic)

    def _prepare_pool(self, n_jobs=None):
        """Create a worker pool: a single thread for n_jobs == 1, otherwise processes."""
        if n_jobs == 1:
            return dummy_processing.Pool(processes=n_jobs)
        return actual_processing.Pool(processes=n_jobs)

    def _prepare_components_grid(self, seasonal_harmonics=None):
        """Provides a grid of all allowed model component combinations.

        Parameters
        ----------
        seasonal_harmonics: array-like or None
            When provided all component combinations shall contain those harmonics

        Returns
        -------
        sklearn.model_selection.ParameterGrid
        """
        return self.__build_components_grid(self.seasonal_periods, seasonal_harmonics)

    def _prepare_non_seasonal_components_grid(self):
        """Provides a grid of all allowed non-season model component combinations."""
        return self.__build_components_grid([])

    def __build_components_grid(self, seasonal_periods, seasonal_harmonics=None):
        """Build the ParameterGrid shared by the seasonal and non-seasonal grid methods.

        Trend settings contribute one sub-grid without trend (when use_trend is
        False or None) and one with trend (when use_trend is True or None).
        """
        base_combination = {
            'use_box_cox': self.__prepare_component_boolean_combinations(self.use_box_cox),
            'box_cox_bounds': [self.box_cox_bounds],
            'use_arma_errors': [self.use_arma_errors],
            'seasonal_periods': [seasonal_periods],
        }
        if seasonal_harmonics is not None:
            base_combination['seasonal_harmonics'] = [seasonal_harmonics]

        allowed_combinations = []
        if self.use_trend is not True:  # False or None
            allowed_combinations.append({
                **base_combination,
                'use_trend': [False],
                'use_damped_trend': [False],  # Damped trend must be False when trend is False
            })
        if self.use_trend is not False:  # True or None
            allowed_combinations.append({
                **base_combination,
                'use_trend': [True],
                'use_damped_trend': self.__prepare_component_boolean_combinations(self.use_damped_trend),
            })
        return ParameterGrid(allowed_combinations)

    @staticmethod
    def __prepare_component_boolean_combinations(param):
        """Return both [False, True] for an undecided (None) flag, else just [param]."""
        if param is None:
            return [False, True]
        return [param]

    def _normalize_seasonal_periods_to_type(self, seasonal_periods, dtype):
        """Validates seasonal periods and normalizes them

        Normalization ensures periods are of proper type, unique and sorted.
        Periods not greater than 1 are dropped (with a warning).

        Returns
        -------
        numpy.ndarray or None
            Sorted unique periods, or None when no period remains.
        """
        if seasonal_periods is None:
            return None

        try:
            seasonal_periods = c1d(check_array(seasonal_periods, ensure_2d=False, force_all_finite=True,
                                               ensure_min_samples=0,
                                               copy=True, dtype=dtype))
        except Exception as validation_exception:
            self.context.get_exception_handler().exception("seasonal_periods definition is invalid",
                                                           error.InputArgsException,
                                                           previous_exception=validation_exception)

        # np.unique returns the values de-duplicated AND sorted.
        seasonal_periods = np.unique(seasonal_periods)
        if np.any(seasonal_periods <= 1):
            self.context.get_exception_handler().warn(
                "All seasonal periods should be values greater than 1. "
                "Ignoring all seasonal period values that do not meet this condition.",
                error.InputArgsWarning
            )
        seasonal_periods = seasonal_periods[seasonal_periods > 1]

        if len(seasonal_periods) == 0:
            return None
        return seasonal_periods