import logging
import os
import sklearn
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import DecisionTreeRegressor

from supervised.algorithms.algorithm import BaseAlgorithm
from supervised.algorithms.sklearn import SklearnAlgorithm
from supervised.algorithms.registry import AlgorithmsRegistry
from supervised.algorithms.registry import (
    BINARY_CLASSIFICATION,
    MULTICLASS_CLASSIFICATION,
    REGRESSION,
)
from supervised.utils.config import LOG_LEVEL

logger = logging.getLogger(__name__)
logger.setLevel(LOG_LEVEL)


from dtreeviz.trees import dtreeviz


class DecisionTreeAlgorithm(SklearnAlgorithm):

    algorithm_name = "Decision Tree"
    algorithm_short_name = "Decision Tree"

    def __init__(self, params):
        super(DecisionTreeAlgorithm, self).__init__(params)
        logger.debug("DecisionTreeAlgorithm.__init__")
        self.library_version = sklearn.__version__
        self.max_iters = additional.get("max_steps", 1)
        self.model = DecisionTreeClassifier(
            criterion=params.get("criterion", "gini"),
            max_depth=params.get("max_depth", 3),
            random_state=params.get("seed", 1),
        )

    def file_extension(self):
        return "decision_tree"

    def interpret(
        self,
        X_train,
        y_train,
        X_validation,
        y_validation,
        model_file_path,
        learner_name,
        target_name=None,
        class_names=None,
        metric_name=None,
        ml_task=None,
        explain_level=2,
    ):
        super(DecisionTreeAlgorithm, self).interpret(
            X_train,
            y_train,
            X_validation,
            y_validation,
            model_file_path,
            learner_name,
            target_name,
            class_names,
            metric_name,
            ml_task,
            explain_level,
        )
        if explain_level == 0:
            return
        try:
            if len(class_names) > 10:
                # dtreeviz does not support more than 10 classes
                return
            viz = dtreeviz(
                self.model,
                X_train,
                y_train,
                target_name="target",
                feature_names=X_train.columns,
                class_names=class_names,
            )
            tree_file_plot = os.path.join(model_file_path, learner_name + "_tree.svg")
            viz.save(tree_file_plot)
        except Exception as e:
            logger.info(f"Problem when visualizing decision tree. {str(e)}")


class DecisionTreeRegressorAlgorithm(SklearnAlgorithm):

    algorithm_name = "Decision Tree"
    algorithm_short_name = "Decision Tree"

    def __init__(self, params):
        super(DecisionTreeRegressorAlgorithm, self).__init__(params)
        logger.debug("DecisionTreeRegressorAlgorithm.__init__")
        self.library_version = sklearn.__version__
        self.max_iters = additional.get("max_steps", 1)
        self.model = DecisionTreeRegressor(
            criterion=params.get("criterion", "mse"),
            max_depth=params.get("max_depth", 3),
            random_state=params.get("seed", 1),
        )

    def file_extension(self):
        return "decision_tree"

    def interpret(
        self,
        X_train,
        y_train,
        X_validation,
        y_validation,
        model_file_path,
        learner_name,
        target_name=None,
        class_names=None,
        metric_name=None,
        ml_task=None,
        explain_level=2,
    ):
        super(DecisionTreeRegressorAlgorithm, self).interpret(
            X_train,
            y_train,
            X_validation,
            y_validation,
            model_file_path,
            learner_name,
            target_name,
            class_names,
            metric_name,
            ml_task,
            explain_level,
        )
        if explain_level == 0:
            return
        try:

            viz = dtreeviz(
                self.model,
                X_train,
                y_train,
                target_name="target",
                feature_names=X_train.columns,
            )
            tree_file_plot = os.path.join(model_file_path, learner_name + "_tree.svg")
            viz.save(tree_file_plot)
        except Exception as e:
            logger.info(f"Problem when visuzalizin decision tree regressor. {str(e)}")


dt_params = {"criterion": ["gini", "entropy"], "max_depth": [1, 2, 3, 4]}

classification_default_params = {"criterion": "gini", "max_depth": 3}

additional = {
    "trees_in_step": 1,
    "train_cant_improve_limit": 0,
    "max_steps": 1,
    "max_rows_limit": None,
    "max_cols_limit": None,
}
required_preprocessing = [
    "missing_values_inputation",
    "convert_categorical",
    "target_as_integer",
]

AlgorithmsRegistry.add(
    BINARY_CLASSIFICATION,
    DecisionTreeAlgorithm,
    dt_params,
    required_preprocessing,
    additional,
    classification_default_params,
)

AlgorithmsRegistry.add(
    MULTICLASS_CLASSIFICATION,
    DecisionTreeAlgorithm,
    dt_params,
    required_preprocessing,
    additional,
    classification_default_params,
)

dt_regression_params = {
    "criterion": [
        "mse",
        "friedman_mse",
    ],  # remove "mae" because it slows down a lot https://github.com/scikit-learn/scikit-learn/issues/9626
    "max_depth": [1, 2, 3, 4],
}
regression_required_preprocessing = ["missing_values_inputation", "convert_categorical"]

regression_default_params = {"criterion": "mse", "max_depth": 3}

AlgorithmsRegistry.add(
    REGRESSION,
    DecisionTreeRegressorAlgorithm,
    dt_regression_params,
    regression_required_preprocessing,
    additional,
    regression_default_params,
)