import numpy as np
import pandas as pd
from sklearn import clone
from sklearn.base import BaseEstimator
from sklearn.utils.validation import (
    check_is_fitted,
    check_X_y,
    check_array,
)

from sklego.common import as_list, expanding_list


def constant_shrinkage(group_sizes: list, alpha: float) -> np.ndarray:
    r"""
    Return constant-shrinkage weights, one per hierarchy level (root first).

    The augmented prediction for each level is the weighted average between its
    prediction and the augmented prediction for its parent.

    Let $\hat{y}_i$ be the prediction at level $i$, with $i=0$ being the root, then
    the augmented prediction is
    $\hat{y}_i^* = (1 - \alpha) \hat{y}_i + \alpha \hat{y}_{i-1}^*$,
    with $\hat{y}_0^* = \hat{y}_0$.

    Unrolled over $n$ levels this gives weight $\alpha^{n-1}$ to the root and
    $\alpha^{n-1-i} (1 - \alpha)$ to level $i \geq 1$; the weights sum to one.

    :param group_sizes: number of observations per hierarchy level; only its length is used
    :param alpha: weight given to the parent's augmented prediction
    :return: array with one weight per entry of `group_sizes`
    """
    n_levels = len(group_sizes)
    if n_levels < 2:
        # With at most one level there is no parent to shrink towards; the general
        # formula below would return two weights here.
        return np.ones(n_levels)

    root_weight = [alpha ** (n_levels - 1)]
    inner_weights = [
        alpha ** (n_levels - 1 - i) * (1 - alpha) for i in range(1, n_levels - 1)
    ]
    leaf_weight = [(1 - alpha)]
    return np.array(root_weight + inner_weights + leaf_weight)


def relative_shrinkage(group_sizes: list) -> np.ndarray:
    """
    Weigh each group according to its size.

    :param group_sizes: number of observations per hierarchy level
    :return: the group sizes as an array (un-normalised weights)
    """
    return np.array(group_sizes)


def min_n_obs_shrinkage(group_sizes: list, min_n_obs) -> np.ndarray:
    """
    Use only the smallest group with a certain amount of observations.

    NOTE(review): the index arithmetic assumes `group_sizes` is ordered from the
    largest (root) to the smallest (most granular) group - confirm against callers.

    :param group_sizes: number of observations per hierarchy level
    :param min_n_obs: minimum number of observations a group needs to be selected
    :return: array of zeros with a single 1 at the selected level
    :raises ValueError: if no group has at least `min_n_obs` observations
    """
    if min_n_obs > max(group_sizes):
        raise ValueError(
            f"There is no group with size greater than or equal to {min_n_obs}"
        )

    res = np.zeros(len(group_sizes))
    # argmin finds the first level that is too small, so the level before it is selected.
    # If every level is large enough argmin returns 0 and index -1 selects the last level.
    res[np.argmin(np.array(group_sizes) >= min_n_obs) - 1] = 1
    return res


class GroupedEstimator(BaseEstimator):
    """
    Construct an estimator per data group. Splits data by the values of the grouping
    column(s) and fits one estimator per group.

    :param estimator: the model/pipeline to be applied per group
    :param groups: the column(s) of the matrix/dataframe to select as a grouping parameter set
    :param value_columns: Columns to use in the prediction. If None (default), use all non-grouping columns
    :param shrinkage: How to perform shrinkage.
                      None: No shrinkage (default)
                      {"constant", "min_n_obs", "relative"} or a callable
                      * constant: shrunk prediction for a level is weighted average of its prediction and its
                                  parents prediction
                      * min_n_obs: shrunk prediction is the prediction for the smallest group with at least
                                   n observations in it
                      * relative: each group-level is weighted according to its size
                      * function: a function that takes a list of group lengths and returns an array of the
                                  same size with the weights for each group
    :param use_global_model: With shrinkage: whether to have a model over the entire input as first group
                             Without shrinkage: whether or not to fall back to a general model in case the group
                             parameter is not found during `.predict()`
    :param **shrinkage_kwargs: keyword arguments to the shrinkage function
    """

    # Name of the constant column that is added to X to model "all data" as the root group
    _GLOBAL_COL = "a-column-that-is-constant-for-all-data"

    def __init__(
        self,
        estimator,
        groups,
        value_columns=None,
        shrinkage=None,
        use_global_model=True,
        **shrinkage_kwargs,
    ):
        self.estimator = estimator
        self.groups = groups
        self.value_columns = value_columns
        self.shrinkage = shrinkage
        self.use_global_model = use_global_model
        self.shrinkage_kwargs = shrinkage_kwargs

    def __set_shrinkage_function(self):
        """
        Resolve `self.shrinkage` into the callable stored in `self.shrinkage_function_`.

        :raises ValueError: if `shrinkage` is an unknown name, or neither a str nor a callable
        """
        if isinstance(self.shrinkage, str):
            # Predefined shrinkage functions
            shrink_options = {
                "constant": constant_shrinkage,
                "relative": relative_shrinkage,
                "min_n_obs": min_n_obs_shrinkage,
            }

            try:
                self.shrinkage_function_ = shrink_options[self.shrinkage]
            except KeyError:
                raise ValueError(
                    f"The specified shrinkage function {self.shrinkage} is not valid, "
                    f"choose from {list(shrink_options.keys())} or supply a callable."
                ) from None
        elif callable(self.shrinkage):
            self.__check_shrinkage_func()
            self.shrinkage_function_ = self.shrinkage
        else:
            raise ValueError(
                "Invalid shrinkage specified. Should be either None (no shrinkage), str or callable."
            )

    def __check_shrinkage_func(self):
        """
        Validate the shrinkage function if a function is specified.

        The callable is probed with example group lengths and must return an
        np.ndarray with one weight per group length.

        :raises ValueError: if the callable raises, or returns the wrong type or shape
        """
        group_lengths = [10, 5, 2]
        expected_shape = np.array(group_lengths).shape

        try:
            # Probe with the same keyword arguments that are used when computing the
            # shrinkage factors, so callables with required keyword arguments pass.
            result = self.shrinkage(group_lengths, **self.shrinkage_kwargs)
        except Exception as e:
            raise ValueError(
                f"Caught an exception while checking the shrinkage function: {str(e)}"
            ) from e
        else:
            if not isinstance(result, np.ndarray):
                raise ValueError(
                    f"shrinkage_function({group_lengths}) should return an np.ndarray"
                )
            if result.shape != expected_shape:
                raise ValueError(
                    f"shrinkage_function({group_lengths}).shape should be {expected_shape}"
                )

    @staticmethod
    def __check_cols_exist(X, cols):
        """
        Check whether the specified columns are in X.

        :param X: pd.DataFrame to check
        :param cols: column name or collection of column names that must be present
        :raises ValueError: if X has no columns or any of `cols` is missing
        """
        if X.shape[1] == 0:
            raise ValueError(
                f"0 feature(s) (shape=({X.shape[0]}, 0)) while a minimum of 1 is required."
            )

        # X has been converted to a DataFrame
        x_cols = set(X.columns)
        diff = set(as_list(cols)) - x_cols

        if len(diff) > 0:
            raise ValueError(f"{diff} not in columns of X {x_cols}")

    @staticmethod
    def __check_missing_and_inf(X):
        """
        Check that all elements of X are non-missing and finite, needed because
        check_array cannot handle strings.

        :raises ValueError: if X contains missing or infinite values
        """
        if np.any(pd.isnull(X)):
            raise ValueError("X has NaN values")
        try:
            if np.any(np.isinf(X)):
                raise ValueError("X has infinite values")
        except TypeError:
            # if X cannot be converted to numeric, checking infinites does not make sense
            pass

    def __validate(self, X, y=None):
        """
        Validate the input, used in both fit and predict.

        Requires `value_colnames_` and `group_colnames_` to be set.

        :param X: pd.DataFrame holding both the value and the grouping columns
        :param y: target values, None when called from predict
        :raises ValueError: on an invalid shrinkage configuration or invalid data
        """
        if (
            self.shrinkage
            and len(as_list(self.groups)) == 1
            and not self.use_global_model
        ):
            raise ValueError(
                "Cannot do shrinkage with a single group if use_global_model is False"
            )

        self.__check_cols_exist(X, self.value_colnames_)
        self.__check_cols_exist(X, self.group_colnames_)

        # Split the model data from the grouping columns, this part is checked `regularly`
        X_data = X.loc[:, self.value_colnames_]

        # y can be None because __validate used in predict, X can have no columns if the estimator only uses y
        if X_data.shape[1] > 0 and y is not None:
            check_X_y(X_data, y, multi_output=True)
        elif y is not None:
            check_array(y, ensure_2d=False)
        elif X_data.shape[1] > 0:
            check_array(X_data)

        self.__check_missing_and_inf(X)

    def __fit_grouped_estimator(self, X, y, value_columns, group_columns):
        """
        Fit a clone of the estimator on every group defined by `group_columns`.

        :param X: pd.DataFrame holding the value and grouping columns
        :param y: pd.Series or pd.DataFrame with the targets
        :param value_columns: columns of X passed to the estimator
        :param group_columns: columns of X to group by
        :return: dict mapping each group key to its fitted estimator
        """
        # Reset indices such that they are the same in X and y
        X, y = X.reset_index(drop=True), y.reset_index(drop=True)

        group_indices = X.groupby(group_columns).indices

        grouped_estimations = {
            group: clone(self.estimator).fit(
                X.loc[indices, value_columns], y.loc[indices]
            )
            for group, indices in group_indices.items()
        }

        return grouped_estimations

    def __get_shrinkage_factor(self, X):
        """
        Get for all complete groups an array of shrinkages.

        :param X: pd.DataFrame holding the grouping columns
        :return: dict mapping each most granular group to its weights, normalised to sum to one
        """
        counts = X.groupby(self.group_colnames_).size()

        # Groups that are split on all
        most_granular_groups = [
            grp
            for grp in self.groups_
            if len(as_list(grp)) == len(self.group_colnames_)
        ]

        # For each hierarchy level in each most granular group, get the number of observations
        hierarchical_counts = {
            granular_group: [
                counts[tuple(subgroup)].sum()
                for subgroup in expanding_list(granular_group, tuple)
            ]
            for granular_group in most_granular_groups
        }

        # For each hierarchy level in each most granular group, get the shrinkage factor
        shrinkage_factors = {
            group: self.shrinkage_function_(level_counts, **self.shrinkage_kwargs)
            for group, level_counts in hierarchical_counts.items()
        }

        # Make sure that the factors sum to one
        shrinkage_factors = {
            group: value / value.sum() for group, value in shrinkage_factors.items()
        }

        return shrinkage_factors

    def __prepare_input_data(self, X, y=None):
        """
        Convert numpy input to pandas and add the global grouping column when needed.

        Does not modify any constructor parameter, so it is safe to call repeatedly
        from both fit and predict.

        :param X: np.ndarray or pd.DataFrame with the value and grouping columns
        :param y: optional np.ndarray, pd.Series or pd.DataFrame with the targets
        :return: `(X, y)` if y is given, otherwise only `X`
        """
        if isinstance(X, np.ndarray):
            X = pd.DataFrame(X, columns=[str(_) for _ in range(X.shape[1])])

        if self.shrinkage is not None and self.use_global_model:
            # Constant column, so grouping on it yields a single group with all the data
            X = X.assign(**{self._GLOBAL_COL: "global"})

        if y is not None:
            if isinstance(y, np.ndarray):
                pred_col = (
                    "the-column-that-i-want-to-predict-but-dont-have-the-name-for"
                )
                if y.ndim == 1:
                    y = pd.Series(y, name=pred_col)
                else:
                    # str(i): str.join only accepts strings
                    cols = [
                        "_".join([pred_col, str(i)]) for i in range(y.shape[1])
                    ]
                    y = pd.DataFrame(y, columns=cols)

            return X, y

        return X

    def fit(self, X, y=None):
        """
        Fit the model using X, y as training data. Will also learn the groups that exist within the dataset.

        :param X: array-like, shape=(n_samples, n_columns) training data.
        :param y: array-like, shape=(n_samples,) training data.
        :return: Returns an instance of self.
        :raises ValueError: if y is None or the input fails validation
        """
        if y is None:
            # Every grouped estimator is fitted on (X, y); fail early with a clear message
            raise ValueError("y is required to fit a GroupedEstimator, got None")

        X, y = self.__prepare_input_data(X, y)

        if self.shrinkage is not None:
            self.__set_shrinkage_function()

        # The global column is only part of the fitted column names, `self.groups` stays
        # exactly as passed to the constructor so refitting and predicting are repeatable
        group_columns = as_list(self.groups)
        if self.shrinkage is not None and self.use_global_model:
            group_columns = [self._GLOBAL_COL] + group_columns
        self.group_colnames_ = [str(_) for _ in group_columns]

        if self.value_columns is not None:
            self.value_colnames_ = [str(_) for _ in as_list(self.value_columns)]
        else:
            self.value_colnames_ = [
                _ for _ in X.columns if _ not in self.group_colnames_
            ]
        self.__validate(X, y)

        # List of all hierarchical subsets of columns
        self.group_colnames_hierarchical_ = expanding_list(self.group_colnames_, list)

        self.fallback_ = None

        if self.shrinkage is None and self.use_global_model:
            subset_x = X[self.value_colnames_]
            self.fallback_ = clone(self.estimator).fit(subset_x, y)

        if self.shrinkage is not None:
            self.estimators_ = {}

            for level_colnames in self.group_colnames_hierarchical_:
                self.estimators_.update(
                    self.__fit_grouped_estimator(
                        X, y, self.value_colnames_, level_colnames
                    )
                )
        else:
            self.estimators_ = self.__fit_grouped_estimator(
                X, y, self.value_colnames_, self.group_colnames_
            )

        self.groups_ = as_list(self.estimators_.keys())

        if self.shrinkage is not None:
            self.shrinkage_factors_ = self.__get_shrinkage_factor(X)

        return self

    def __predict_group(self, X, group_colnames):
        """
        Make predictions for all groups.

        :param X: pd.DataFrame holding the value and grouping columns
        :param group_colnames: columns of X to group by
        :return: the predictions as a squeezed numpy array
        :raises ValueError: if a group has no fitted estimator and there is no fallback model
        """
        try:
            return (
                X.groupby(group_colnames, as_index=False)
                .apply(
                    lambda d: pd.DataFrame(
                        self.estimators_.get(d.name, self.fallback_).predict(
                            d[self.value_colnames_]
                        ),
                        index=d.index,
                    )
                )
                .values.squeeze()
            )
        except AttributeError:
            # Handle new groups: without an estimator or fallback, `.predict` is called on None
            culprits = set(X[self.group_colnames_].agg(func=tuple, axis=1)) - set(
                self.estimators_.keys()
            )

            if self.shrinkage is not None and self.use_global_model:
                # Remove the global group from the culprits because the user did not specify
                culprits = {culprit[1:] for culprit in culprits}

            raise ValueError(
                f"found a group(s) {culprits} in `.predict` that was not in `.fit`"
            )

    def __predict_shrinkage_groups(self, X):
        """
        Make predictions for all shrinkage groups.

        :param X: pd.DataFrame holding the value and grouping columns
        :return: per row, the sum of the per-level predictions weighted by the shrinkage factors
        """
        # DataFrame with predictions for each hierarchy level, per row. Missing groups errors are thrown here.
        hierarchical_predictions = pd.concat(
            [
                pd.Series(self.__predict_group(X, level_columns))
                for level_columns in self.group_colnames_hierarchical_
            ],
            axis=1,
        )

        # This is a Series with values the tuples of hierarchical grouping
        prediction_groups = X[self.group_colnames_].agg(func=tuple, axis=1)

        # This is a Series of arrays
        shrinkage_factors = prediction_groups.map(self.shrinkage_factors_)

        # Convert the Series of arrays to a DataFrame
        shrinkage_factors = pd.DataFrame.from_dict(shrinkage_factors.to_dict()).T

        return (hierarchical_predictions * shrinkage_factors).sum(axis=1)

    def predict(self, X):
        """
        Predict on new data.

        :param X: array-like, shape=(n_samples, n_columns) data to predict on.
        :return: array, shape=(n_samples,) the predicted data
        :raises sklearn.exceptions.NotFittedError: if the estimator has not been fitted
        :raises ValueError: if the input fails validation or contains an unseen group without a fallback
        """
        # Check first: validation below relies on attributes that only exist after fit
        check_is_fitted(
            self,
            [
                "estimators_",
                "groups_",
                "group_colnames_",
                "value_colnames_",
                "fallback_",
            ],
        )

        X = self.__prepare_input_data(X)
        self.__validate(X)

        if self.shrinkage is None:
            return self.__predict_group(X, group_colnames=self.group_colnames_)
        else:
            return self.__predict_shrinkage_groups(X)