from typing import Callable

from redshells.factory.singleton import Singleton


class _PredictionModelFactory(metaclass=Singleton):
    """Registry mapping string keys to prediction model classes (or other callables).

    Built with the project's ``Singleton`` metaclass; presumably every
    ``_PredictionModelFactory()`` call returns the same shared registry —
    confirm in ``redshells.factory.singleton``.

    On construction, a few well-known classifiers are registered if their
    library can be imported: scikit-learn, XGBoost, LightGBM and CatBoost.
    Libraries that raise ``ImportError`` are skipped silently.
    """

    # (module path, class name) pairs registered automatically when the module
    # imports cleanly. The class name doubles as the registry key. Tuple order
    # fixes insertion order, which is the key order shown in ``get``'s error.
    _OPTIONAL_MODELS = (
        ('sklearn.ensemble', 'RandomForestClassifier'),
        ('xgboost', 'XGBClassifier'),
        ('lightgbm', 'LGBMClassifier'),
        ('catboost', 'CatBoostClassifier'),
    )

    def __init__(self):
        import importlib

        self._models = dict()
        for module_name, class_name in self._OPTIONAL_MODELS:
            try:
                module = importlib.import_module(module_name)
            except ImportError:
                # Optional dependency is not installed; leave it unregistered.
                continue
            # Resolved outside the try, as in the original inline imports:
            # only ImportError is treated as "library unavailable".
            self._models[class_name] = getattr(module, class_name)

    def get(self, key: str):
        """Return the model class/callable registered under ``key``.

        Raises:
            RuntimeError: if ``key`` has not been registered. The message lists
                the currently registered keys.
        """
        try:
            return self._models[key]
        except KeyError:
            raise RuntimeError(
                f'"{key}" is not registered. Please call "register_prediction_model" beforehand. The keys are {list(self._models.keys())}'
            ) from None

    def register(self, key: str, class_name: Callable) -> None:
        """Register ``class_name`` (a model class or factory callable) under ``key``.

        An existing entry with the same key is overwritten.
        """
        self._models[key] = class_name


def get_prediction_model_type(key: str):
    """Return the model class registered under ``key``, without instantiating it.

    Raises:
        RuntimeError: if ``key`` has not been registered with the factory.
    """
    return _PredictionModelFactory().get(key)


def create_prediction_model(key: str, **kwargs):
    """Instantiate the model registered under ``key``.

    All keyword arguments are forwarded unchanged to the registered
    class/callable. The factory raises ``RuntimeError`` if ``key`` is unknown.
    """
    model_type = _PredictionModelFactory().get(key)
    model = model_type(**kwargs)
    return model


def register_prediction_model(key: str, class_name: Callable) -> None:
    """Register ``class_name`` under ``key`` in the shared model factory.

    After registration, ``key`` can be passed to ``get_prediction_model_type``
    or ``create_prediction_model``. Re-registering a key replaces the previous entry.
    """
    factory = _PredictionModelFactory()
    factory.register(key, class_name)