# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

"""Defines an explainable lightgbm model."""

from .explainable_model import BaseExplainableModel, _get_initializer_args, _clean_doc
from .tree_model_utils import _explain_local_tree_surrogate, _expected_values_tree_surrogate
from ...common.constants import ShapValuesOutput, LightGBMSerializationConstants, \
    ExplainableModelType, Extension
import json
import warnings
import logging
import inspect

with warnings.catch_warnings():
    warnings.filterwarnings('ignore', 'Starting from version 2.2.1', UserWarning)
    import shap
    try:
        from lightgbm import LGBMRegressor, LGBMClassifier, Booster
        import lightgbm
        from packaging import version
        if version.parse(lightgbm.__version__) <= version.parse('2.2.1'):
            print("Using a version of lightgbm older than supported; please upgrade to a version newer than 2.2.1")
    except ImportError:
        print("Could not import lightgbm, required if using LGBMExplainableModel")

DEFAULT_RANDOM_STATE = 123
_N_FEATURES = '_n_features'
_N_CLASSES = '_n_classes'


class LGBMExplainableModel(BaseExplainableModel):
    """LightGBM (a fast, high-performance framework based on decision trees) explainable model.

    Please see documentation for more details: https://github.com/Microsoft/LightGBM

    Additional arguments to LightGBMClassifier and LightGBMRegressor can be passed through kwargs.

    :param multiclass: Set to true to generate a multiclass model.
    :type multiclass: bool
    :param random_state: Int to seed the model.
    :type random_state: int
    :param shap_values_output: The type of the output from explain_local when using TreeExplainer.
        Currently only types 'default', 'probability' and 'teacher_probability' are supported.  If
        'probability' is specified, then we approximately scale the raw log-odds values from the
        TreeExplainer to probabilities.
    :type shap_values_output: interpret_community.common.constants.ShapValuesOutput
    :param classification: Indicates if this is a classification or regression explanation.
    :type classification: bool
    """

    available_explanations = [Extension.GLOBAL, Extension.LOCAL]
    explainer_type = Extension.GLASSBOX

    def __init__(self, multiclass=False, random_state=DEFAULT_RANDOM_STATE,
                 shap_values_output=ShapValuesOutput.DEFAULT, classification=True, **kwargs):
        """Initialize the LightGBM Model.

        Additional arguments to LightGBMClassifier and LightGBMRegressor can be passed through kwargs.

        :param multiclass: Set to true to generate a multiclass model.
        :type multiclass: bool
        :param random_state: Int to seed the model.
        :type random_state: int
        :param shap_values_output: The type of the output from explain_local when using TreeExplainer.
            Currently only types 'default', 'probability' and 'teacher_probability' are supported.  If
            'probability' is specified, then we approximately scale the raw log-odds values from the
            TreeExplainer to probabilities.
        :type shap_values_output: interpret_community.common.constants.ShapValuesOutput
        :param classification: Indicates if this is a classification or regression explanation.
        :type classification: bool
        """
        self.multiclass = multiclass
        initializer_args = _get_initializer_args(kwargs)
        if self.multiclass:
            initializer = LGBMClassifier
        else:
            initializer = LGBMRegressor
        self._lgbm = initializer(random_state=random_state, **initializer_args)
        super(LGBMExplainableModel, self).__init__(**kwargs)
        self._logger.debug('Initializing LGBMExplainableModel')
        self._method = 'lightgbm'
        self._tree_explainer = None
        self._shap_values_output = shap_values_output
        self._classification = classification

    try:
        __init__.__doc__ = (__init__.__doc__ +
                            '\nIf multiclass=True, uses the parameters for LGBMClassifier:\n' +
                            _clean_doc(LGBMClassifier.__init__.__doc__) +
                            '\nOtherwise, if multiclass=False, uses the parameters for LGBMRegressor:\n' +
                            _clean_doc(LGBMRegressor.__init__.__doc__))
    except Exception:
        pass

    def fit(self, dataset, labels, **kwargs):
        """Call lightgbm fit to fit the explainable model.

        :param dataset: The dataset to train the model on.
        :type dataset: numpy or scipy array
        :param labels: The labels to train the model on.
        :type labels: numpy or scipy array
        """
        self._lgbm.fit(dataset, labels, **kwargs)

    try:
        fit.__doc__ = (fit.__doc__ +
                       '\nIf multiclass=True, uses the parameters for LGBMClassifier:\n' +
                       _clean_doc(LGBMClassifier.fit.__doc__) +
                       '\nOtherwise, if multiclass=False, uses the parameters for LGBMRegressor:\n' +
                       _clean_doc(LGBMRegressor.fit.__doc__))
    except Exception:
        pass

    def predict(self, dataset, **kwargs):
        """Call lightgbm predict to predict labels using the explainable model.

        :param dataset: The dataset to predict on.
        :type dataset: numpy or scipy array
        :return: The predictions of the model.
        :rtype: list
        """
        return self._lgbm.predict(dataset, **kwargs)

    try:
        predict.__doc__ = (predict.__doc__ +
                           '\nIf multiclass=True, uses the parameters for LGBMClassifier:\n' +
                           _clean_doc(LGBMClassifier.predict.__doc__) +
                           '\nOtherwise, if multiclass=False, uses the parameters for LGBMRegressor:\n' +
                           _clean_doc(LGBMRegressor.predict.__doc__))
    except Exception:
        pass

    def predict_proba(self, dataset, **kwargs):
        """Call lightgbm predict_proba to predict probabilities using the explainable model.

        :param dataset: The dataset to predict probabilities on.
        :type dataset: numpy or scipy array
        :return: The predicted probabilities from the model.
        :rtype: list
        """
        if self.multiclass:
            return self._lgbm.predict_proba(dataset, **kwargs)
        else:
            raise Exception("predict_proba not supported for regression or binary classification dataset")

    try:
        predict_proba.__doc__ = (predict_proba.__doc__ +
                                 '\nIf multiclass=True, uses the parameters for LGBMClassifier:\n' +
                                 _clean_doc(LGBMClassifier.predict_proba.__doc__) +
                                 '\nOtherwise predict_proba is not supported for ' +
                                 'regression or binary classification.\n')
    except Exception:
        pass
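
    # predict_proba sketch (hypothetical): only valid when the model was created
    # with multiclass=True, since the underlying surrogate is then an
    # LGBMClassifier; for regression and binary classification the surrogate is
    # an LGBMRegressor and predict_proba raises.
    #
    #   surrogate = LGBMExplainableModel(multiclass=True)
    #   surrogate.fit(X_train, y_train)
    #   probs = surrogate.predict_proba(X_test)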

    def explain_global(self, **kwargs):
        """Call lightgbm feature importances to get the global feature importances from the explainable model.

        :return: The global explanation of feature importances.
        :rtype: numpy.ndarray
        """
        return self._lgbm.feature_importances_

    def explain_local(self, evaluation_examples, probabilities=None, **kwargs):
        """Use TreeExplainer to get the local feature importances from the trained explainable model.

        :param evaluation_examples: The evaluation examples to compute local feature importances for.
        :type evaluation_examples: numpy or scipy array
        :param probabilities: If shap_values_output is probability, can specify the teacher model's
            probability for scaling the shap values.
        :type probabilities: numpy.ndarray
        :return: The local explanation of feature importances.
        :rtype: Union[list, numpy.ndarray]
        """
        if self._tree_explainer is None:
            self._tree_explainer = shap.TreeExplainer(self._lgbm)
        return _explain_local_tree_surrogate(self._lgbm, evaluation_examples, self._tree_explainer,
                                             self._shap_values_output, self._classification,
                                             probabilities, self.multiclass)
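
    # explain_local sketch (hypothetical): when constructed with
    # shap_values_output=ShapValuesOutput.PROBABILITY, the teacher model's
    # probabilities can be passed so the raw log-odds shap values are rescaled.
    # teacher_model is a stand-in for the original model being mimicked:
    #
    #   probs = teacher_model.predict_proba(X_test)
    #   local_importances = surrogate.explain_local(X_test, probabilities=probs)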

    @property
    def expected_values(self):
        """Use TreeExplainer to get the expected values.

        :return: The expected values of the LightGBM tree model.
        :rtype: list
        """
        if self._tree_explainer is None:
            self._tree_explainer = shap.TreeExplainer(self._lgbm)
        return _expected_values_tree_surrogate(self._lgbm, self._tree_explainer, self._shap_values_output,
                                               self._classification, self.multiclass)
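
    # expected_values sketch (hypothetical): with the default output type, shap's
    # additivity property means a row's local importances plus the expected value
    # approximately reconstruct the surrogate's raw output for that row.
    #
    #   base_values = surrogate.expected_values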

    @property
    def model(self):
        """Retrieve the underlying model.

        :return: The lightgbm model, either classifier or regressor.
        :rtype: Union[LGBMClassifier, LGBMRegressor]
        """
        return self._lgbm

    @staticmethod
    def explainable_model_type():
        """Retrieve the model type.

        :return: Tree explainable model type.
        :rtype: ExplainableModelType
        """
        return ExplainableModelType.TREE_EXPLAINABLE_MODEL_TYPE

    def _save(self):
        """Return a string dictionary representation of the LGBMExplainableModel.

        :return: A serialized dictionary representation of the LGBMExplainableModel.
        :rtype: dict
        """
        properties = {}
        # Save all of the properties
        for key, value in self.__dict__.items():
            if key in LightGBMSerializationConstants.nonify_properties:
                properties[key] = None
            elif key in LightGBMSerializationConstants.save_properties:
                # Save booster model to string representation
                # This is not recommended but can be necessary to get around pickle being not secure
                # See here for more info:
                # https://github.com/Microsoft/LightGBM/issues/1942
                # https://github.com/Microsoft/LightGBM/issues/1217
                properties[key] = value.booster_.model_to_string()
            else:
                properties[key] = json.dumps(value)
        # Need to add _n_features
        properties[_N_FEATURES] = self._lgbm._n_features
        # And if classification case need to add _n_classes
        if self.multiclass:
            properties[_N_CLASSES] = self._lgbm._n_classes
        return properties
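
    # Round-trip sketch (hypothetical): the returned dictionary contains only
    # JSON-serializable values (strings, ints and None), so it can be dumped
    # with json and later restored via _load:
    #
    #   props = surrogate._save()
    #   with open('surrogate.json', 'w') as f:
    #       json.dump(props, f)
    #   with open('surrogate.json') as f:
    #       restored = LGBMExplainableModel._load(json.load(f))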

    @staticmethod
    def _load(properties):
        """Load a LGBMExplainableModel from the given properties.

        :param properties: A serialized dictionary representation of the LGBMExplainableModel.
        :type properties: dict
        :return: The deserialized LGBMExplainableModel.
        :rtype: interpret_community.mimic.models.LGBMExplainableModel
        """
        # Create the LGBMExplainableModel without any properties using the __new__ function, similar to pickle
        lgbm_model = LGBMExplainableModel.__new__(LGBMExplainableModel)
        # Get _n_features
        _n_features = properties.pop(_N_FEATURES)
        # If classification case get _n_classes
        if json.loads(properties[LightGBMSerializationConstants.MULTICLASS]):
            _n_classes = properties.pop(_N_CLASSES)
        # load all of the properties
        for key, value in properties.items():
            # Regenerate the properties on the fly
            if key in LightGBMSerializationConstants.nonify_properties:
                if key == LightGBMSerializationConstants.LOGGER:
                    parent = logging.getLogger(__name__)
                    lightgbm_identity = json.loads(properties[LightGBMSerializationConstants.IDENTITY])
                    lgbm_model.__dict__[key] = parent.getChild(lightgbm_identity)
                elif key == LightGBMSerializationConstants.TREE_EXPLAINER:
                    lgbm_model.__dict__[key] = None
                else:
                    raise Exception("Unknown nonify key on deserialize in LGBMExplainableModel: {}".format(key))
            elif key in LightGBMSerializationConstants.save_properties:
                # Load the booster from file and re-create the LGBMClassifier or LGBMRegressor
                # This is not recommended but can be necessary to get around pickle being not secure
                # See here for more info:
                # https://github.com/Microsoft/LightGBM/issues/1942
                # https://github.com/Microsoft/LightGBM/issues/1217
                booster_args = {LightGBMSerializationConstants.MODEL_STR: value}
                is_multiclass = json.loads(properties[LightGBMSerializationConstants.MULTICLASS])
                if is_multiclass:
                    objective = LightGBMSerializationConstants.MULTICLASS
                else:
                    objective = LightGBMSerializationConstants.REGRESSION
                if LightGBMSerializationConstants.MODEL_STR in inspect.getfullargspec(Booster).args:
                    extras = {LightGBMSerializationConstants.OBJECTIVE: objective}
                    lgbm_booster = Booster(**booster_args, params=extras)
                else:
                    # For backwards compatibility with older versions of lightgbm
                    booster_args[LightGBMSerializationConstants.OBJECTIVE] = objective
                    lgbm_booster = Booster(params=booster_args)
                if is_multiclass:
                    new_lgbm = LGBMClassifier()
                    new_lgbm._Booster = lgbm_booster
                    new_lgbm._n_classes = _n_classes
                else:
                    new_lgbm = LGBMRegressor()
                    new_lgbm._Booster = lgbm_booster
                new_lgbm._n_features = _n_features
                lgbm_model.__dict__[key] = new_lgbm
            elif key in LightGBMSerializationConstants.enum_properties:
                # NOTE: If more enums are added in the future, this will need to be handled differently
                lgbm_model.__dict__[key] = ShapValuesOutput(json.loads(value))
            else:
                lgbm_model.__dict__[key] = json.loads(value)
        return lgbm_model
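

# Booster-from-string sketch (hypothetical, mirroring _load above): recent
# lightgbm versions accept the serialized text directly through the model_str
# keyword argument, avoiding pickle for persistence. trained_lgbm stands in
# for any fitted LGBMClassifier or LGBMRegressor:
#
#   text = trained_lgbm.booster_.model_to_string()
#   restored_booster = Booster(model_str=text)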