import os import numpy as np from pandas import DataFrame, Series from autogluon import try_import_lightgbm from ...constants import BINARY, MULTICLASS, REGRESSION # Mapping to specialized LightGBM metrics that are much faster than the standard metric computation _ag_to_lgbm_metric_dict = { BINARY: dict( accuracy='binary_error', log_loss='binary_logloss', roc_auc='auc', ), MULTICLASS: dict( accuracy='multi_error', log_loss='multi_logloss', ), REGRESSION: dict( mean_absolute_error='l1', mean_squared_error='l2', root_mean_squared_error='rmse', ), } def convert_ag_metric_to_lgbm(ag_metric_name, problem_type): return _ag_to_lgbm_metric_dict.get(problem_type, dict()).get(ag_metric_name, None) def func_generator(metric, is_higher_better, needs_pred_proba, problem_type): if needs_pred_proba: if problem_type == MULTICLASS: def function_template(y_hat, data): y_true = data.get_label() y_hat = y_hat.reshape(len(np.unique(y_true)), -1).T return metric.name, metric(y_true, y_hat), is_higher_better else: def function_template(y_hat, data): y_true = data.get_label() return metric.name, metric(y_true, y_hat), is_higher_better else: if problem_type == MULTICLASS: def function_template(y_hat, data): y_true = data.get_label() y_hat = y_hat.reshape(len(np.unique(y_true)), -1) y_hat = y_hat.argmax(axis=0) return metric.name, metric(y_true, y_hat), is_higher_better else: def function_template(y_hat, data): y_true = data.get_label() y_hat = np.round(y_hat) return metric.name, metric(y_true, y_hat), is_higher_better return function_template def construct_dataset(x: DataFrame, y: Series, location=None, reference=None, params=None, save=False, weight=None): try_import_lightgbm() import lightgbm as lgb dataset = lgb.Dataset(data=x, label=y, reference=reference, free_raw_data=True, params=params, weight=weight) if save: assert location is not None saving_path = f'{location}.bin' if os.path.exists(saving_path): os.remove(saving_path) os.makedirs(os.path.dirname(saving_path), exist_ok=True) dataset.save_binary(saving_path) # dataset_binary = lgb.Dataset(location + '.bin', reference=reference, free_raw_data=False)# .construct() return dataset