# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# Tests for kernel, tree and deep explainers.

import pytest
import logging

from lightgbm import LGBMClassifier, LGBMRegressor

from interpret_community.shap.kernel_explainer import KernelExplainer
from interpret_community.shap.tree_explainer import TreeExplainer
from interpret_community.shap.deep_explainer import DeepExplainer
from interpret_community.shap.linear_explainer import LinearExplainer
from common_tabular_tests import VerifyTabularTests
from common_utils import create_keras_multiclass_classifier, create_keras_regressor, \
    create_sklearn_linear_regressor, create_sklearn_logistic_regressor
from constants import owner_email_tools_and_ux

test_logger = logging.getLogger(__name__)
test_logger.setLevel(logging.INFO)


@pytest.mark.owner(email=owner_email_tools_and_ux)
@pytest.mark.usefixtures("clean_dir")
class TestKernelExplainer(object):
    def setup_class(self):
        def create_explainer(model, x_train, **kwargs):
            return KernelExplainer(model, x_train, **kwargs)

        self._verify_tabular = VerifyTabularTests(test_logger, create_explainer)

    def test_kernel_explainer_raw_transformations_list_classification(self):
        self._verify_tabular.verify_explain_model_transformations_list_classification()

    def test_kernel_explainer_raw_transformations_column_transformer_classification(self):
        self._verify_tabular.verify_explain_model_transformations_column_transformer_classification()

    def test_kernel_explainer_raw_transformations_list_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression()

    def test_kernel_explainer_raw_transformations_column_transformer_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression()


@pytest.mark.owner(email=owner_email_tools_and_ux)
@pytest.mark.usefixtures("clean_dir")
class TestDeepExplainer(object):
    def setup_class(self):
        def create_explainer(model, x_train, **kwargs):
            return DeepExplainer(model, x_train, **kwargs)

        self._verify_tabular = VerifyTabularTests(test_logger, create_explainer)

    def _get_create_model(self, classification):
        if classification:
            train_fn = create_keras_multiclass_classifier
        else:
            train_fn = create_keras_regressor

        def create_model(x, y):
            return train_fn(x, y)
        return create_model

    def test_deep_explainer_raw_transformations_list_classification(self):
        self._verify_tabular.verify_explain_model_transformations_list_classification(self._get_create_model(
            classification=True))

    def test_deep_explainer_raw_transformations_column_transformer_classification(self):
        self._verify_tabular.verify_explain_model_transformations_column_transformer_classification(
            self._get_create_model(classification=True))

    def test_deep_explainer_raw_transformations_list_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression(self._get_create_model(
            classification=False))

    def test_deep_explainer_raw_transformations_column_transformer_regression(self):
        self._verify_tabular.verify_explain_model_transformations_column_transformer_regression(
            self._get_create_model(classification=False))


@pytest.mark.owner(email=owner_email_tools_and_ux)
@pytest.mark.usefixtures("clean_dir")
class TestTreeExplainer(object):
    def setup_class(self):
        def create_explainer(model, x_train, **kwargs):
            return TreeExplainer(model, **kwargs)

        self._verify_tabular = VerifyTabularTests(test_logger, create_explainer)

    def _get_create_model(self, classification):
        if classification:
            model = LGBMClassifier()
        else:
            model = LGBMRegressor()

        def create_model(x, y):
            return model.fit(x, y)

        return create_model

    def test_tree_explainer_raw_transformations_list_classification(self):
        self._verify_tabular.verify_explain_model_transformations_list_classification(self._get_create_model(
            classification=True))

    def test_tree_explainer_raw_transformations_column_transformer_classification(self):
        self._verify_tabular.verify_explain_model_transformations_column_transformer_classification(
            self._get_create_model(classification=True))

    def test_tree_explainer_raw_transformations_list_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression(self._get_create_model(
            classification=False))

    def test_tree_explainer_raw_transformations_column_transformer_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression(self._get_create_model(
            classification=False))


@pytest.mark.owner(email=owner_email_tools_and_ux)
@pytest.mark.usefixtures("clean_dir")
class TestLinearExplainer(object):
    def setup_class(self):
        def create_explainer(model, x_train, **kwargs):
            return LinearExplainer(model, x_train, **kwargs)

        self._verify_tabular = VerifyTabularTests(test_logger, create_explainer)

    def _get_create_model(self, classification):
        if classification:
            train_fn = create_sklearn_logistic_regressor
        else:
            train_fn = create_sklearn_linear_regressor

        def create_model(x, y):
            return train_fn(x, y)

        return create_model

    def test_linear_explainer_raw_transformations_list_classification(self):
        self._verify_tabular.verify_explain_model_transformations_list_classification(self._get_create_model(
            classification=True))

    def test_linear_explainer_raw_transformations_column_transformer_classification(self):
        self._verify_tabular.verify_explain_model_transformations_column_transformer_classification(
            self._get_create_model(classification=True))

    def test_linear_explainer_raw_transformations_list_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression(self._get_create_model(
            classification=False))

    def test_linear_explainer_raw_transformations_column_transformer_regression(self):
        self._verify_tabular.verify_explain_model_transformations_list_regression(self._get_create_model(
            classification=False))