import lightgbm as lgb
from attrdict import AttrDict

try:
    import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; only fall
    # back to the vendored copy on environments without standalone joblib.
    from sklearn.externals import joblib

from .base import BaseTransformer
from .utils import get_logger

logger = get_logger()


class LightGBM(BaseTransformer):
    """Transformer wrapping ``lightgbm.train``.

    ``fit`` trains a booster and stores it on ``self.estimator``.
    ``transform`` returns that booster's predictions.
    ``save`` / ``load`` persist the booster with joblib.
    """

    def __init__(self, model_config, training_config):
        """Store the model and training configuration.

        Args:
            model_config: mapping passed to ``lgb.train`` as ``params``. Its
                ``verbose`` entry is also forwarded as ``verbose_eval``.
            training_config: mapping that must provide
                ``number_boosting_rounds`` and ``early_stopping_rounds``.
        """
        self.model_config = AttrDict(model_config)
        self.training_config = AttrDict(training_config)
        # Optional custom metric, forwarded to lgb.train as ``feval``.
        self.evaluation_function = None
        # Trained booster; populated by fit() or load().
        self.estimator = None
        # Metric history recorded by lgb.train during the last fit().
        self.evaluation_results = {}

    def fit(self, X, y, X_valid, y_valid, feature_names, categorical_features, **kwargs):
        """Train a booster on ``(X, y)`` while evaluating on both datasets.

        Args:
            X, y: training features and labels.
            X_valid, y_valid: validation features and labels.
            feature_names: column names given to both LightGBM Datasets.
            categorical_features: columns LightGBM treats as categorical.
            **kwargs: accepted and ignored.

        Returns:
            self, with ``self.estimator`` set to the trained booster and
            ``self.evaluation_results`` holding the per-round metrics keyed by
            ``'train'`` and ``'valid'``.
        """
        train = lgb.Dataset(X, label=y,
                            feature_name=feature_names,
                            categorical_feature=categorical_features)
        valid = lgb.Dataset(X_valid, label=y_valid,
                            feature_name=feature_names,
                            categorical_feature=categorical_features)

        evaluation_results = {}
        self.estimator = lgb.train(self.model_config,
                                   train,
                                   valid_sets=[train, valid],
                                   valid_names=['train', 'valid'],
                                   evals_result=evaluation_results,
                                   num_boost_round=self.training_config.number_boosting_rounds,
                                   early_stopping_rounds=self.training_config.early_stopping_rounds,
                                   verbose_eval=self.model_config.verbose,
                                   feval=self.evaluation_function)
        # Keep the metric history instead of discarding it with the local.
        self.evaluation_results = evaluation_results
        return self

    def transform(self, X, y=None, **kwargs):
        """Predict with the trained booster.

        Args:
            X: features to score.
            y, **kwargs: accepted and ignored.

        Returns:
            dict with key ``'prediction'`` holding ``self.estimator.predict(X)``.
        """
        prediction = self.estimator.predict(X)
        return {'prediction': prediction}

    def load(self, filepath):
        """Restore ``self.estimator`` from a joblib file and return self."""
        self.estimator = joblib.load(filepath)
        return self

    def save(self, filepath):
        """Write ``self.estimator`` to ``filepath`` with joblib."""
        joblib.dump(self.estimator, filepath)