# -*- coding: utf-8 -*-

"""Package where the MLChallenge class is defined."""

import logging
import os
from copy import deepcopy
from urllib.parse import urljoin

import pandas as pd
from sklearn.metrics import make_scorer
from sklearn.model_selection import KFold, StratifiedKFold, cross_val_score
from sklearn.preprocessing import OneHotEncoder

from btb_benchmark.challenges.challenge import Challenge

BASE_DATASET_URL = 'https://atm-data.s3.amazonaws.com/'
BUCKET_NAME = 'atm-data'
LOGGER = logging.getLogger(__name__)

# Available datasets sorted by execution time, slowest first
DATASETS_BY_TIME = [
    'BNG(cmc)_1.csv',
    'eye_movements_1.csv',
    'wall-robot-navigation_1.csv',
    'analcatdata_germangss_2.csv',
    'scene_1.csv',
    'arcene_1.csv',
    'mv_1.csv',
    'bank32nh_1.csv',
    'BNG(cmc,nominal,55296)_1.csv',
    'AP_Endometrium_Prostate_1.csv',
    'house_16H_1.csv',
    'waveform-5000_1.csv',
    'autoUniv-au4-2500_1.csv',
    'BNG(breast-w)_1.csv',
    'wall-robot-navigation_2.csv',
    'madelon_1.csv',
    'Click_prediction_small_1.csv',
    'wall-robot-navigation_3.csv',
    'nursery_1.csv',
    'kdd_JapaneseVowels_1.csv',
    'house_8L_1.csv',
    'MagicTelescope_1.csv',
    'vehicle_1.csv',
    'musk_1.csv',
    'BNG(vote)_1.csv',
    'PopularKids_1.csv',
    'abalone_2.csv',
    'rmftsa_sleepdata_2.csv',
    'yeast_ml8_1.csv',
    'leukemia_1.csv',
    'electricity_1.csv',
    'CreditCardSubset_1.csv',
    'cmc_1.csv',
    'car_1.csv',
    'autoUniv-au7-700_1.csv',
    'grub-damage_1.csv',
    'analcatdata_authorship_1.csv',
    'skin-segmentation_1.csv',
    'splice_1.csv',
    'cpu_small_1.csv',
    'tumors_C_1.csv',
    'cpu_act_1.csv',
    'ringnorm_1.csv',
    'bank-marketing_2.csv',
    'mfeat-factors_1.csv',
    'pol_1.csv',
    '2dplanes_1.csv',
    'Amazon_employee_access_1.csv',
    'cal_housing_1.csv',
    'houses_1.csv',
    'BNG(tic-tac-toe)_1.csv',
    'mozilla4_1.csv',
    'cardiotocography_1.csv',
    'meta_ensembles.arff_1.csv',
    'balance-scale_1.csv',
    'fried_1.csv',
    'tae_1.csv',
    'spambase_1.csv',
    'ozone-level-8hr_1.csv',
    'PhishingWebsites_1.csv',
    'elevators_1.csv',
    'bank8FM_1.csv',
    'mammography_1.csv',
    'hayes-roth_2.csv',
    'sylva_prior_1.csv',
    'meta_batchincremental.arff_1.csv',
    'waveform-5000_2.csv',
    'teachingAssistant_1.csv',
    'ailerons_1.csv',
    'robot-failures-lp1_1.csv',
    'desharnais_1.csv',
    'PieChart4_1.csv',
    'eeg-eye-state_1.csv',
    'Engine1_1.csv',
    'puma32H_1.csv',
    'lymph_1.csv',
    'wind_1.csv',
    'mfeat-fourier_1.csv',
    'steel-plates-fault_1.csv',
    'mfeat-karhunen_1.csv',
    'meta_instanceincremental.arff_1.csv',
    'page-blocks_1.csv',
    'twonorm_1.csv',
    'mc1_1.csv',
    'wine_1.csv',
    'hill-valley_1.csv',
    'robot-failures-lp4_1.csv',
    'kc1_1.csv',
    'segment_1.csv',
    'spectrometer_1.csv',
    'seismic-bumps_1.csv',
    'pc4_1.csv',
    'pc3_1.csv',
    'ada_agnostic_1.csv',
    'splice_2.csv',
    'bank-marketing_1.csv',
    'seeds_1.csv',
    'delta_ailerons_1.csv',
    'white-clover_1.csv',
    'boston_1.csv',
    'letter_1.csv',
    'prnn_viruses_1.csv',
    'oil_spill_1.csv',
    'credit-g_1.csv',
    'diabetes_1.csv',
    'boston_corrected_1.csv',
    'PieChart3_1.csv',
    'optdigits_1.csv',
    'PizzaCutter3_1.csv',
    'qsar-biodeg_1.csv',
    'abalone_1.csv',
    'cmc_2.csv',
    'iris_1.csv',
    'vinnie_1.csv',
    'pc1_1.csv',
    'balloon_1.csv',
    'tic-tac-toe_1.csv',
    'fri_c4_1000_100_1.csv',
    'robot-failures-lp3_1.csv',
    'autoUniv-au1-1000_1.csv',
    'lsvt_1.csv',
    'hill-valley_2.csv',
    'banana_1.csv',
    'vertebra-column_2.csv',
    'quake_1.csv',
    'rmftsa_sleepdata_1.csv',
    'kr-vs-kp_1.csv',
    'mfeat-zernike_1.csv',
    'wisconsin_1.csv',
    'CastMetal1_1.csv',
    'car_2.csv',
    'pendigits_1.csv',
    'PizzaCutter1_1.csv',
    'kc2_1.csv',
    'delta_elevators_1.csv',
    'mw1_1.csv',
    'analcatdata_asbestos_1.csv',
    'mu284_1.csv',
    'plasma_retinol_1.csv',
    'quake_2.csv',
    'socmob_1.csv',
    'PieChart1_1.csv',
    'housing_1.csv',
    'pasture_1.csv',
    'sensory_1.csv',
    'wine_2.csv',
    'chscase_geyser1_1.csv',
    'monks-problems-2_1.csv',
    'mfeat-pixel_1.csv',
    'analcatdata_dmft_1.csv',
    'autoPrice_1.csv',
    'SPECTF_1.csv',
    'breast-tissue_1.csv',
    'phoneme_1.csv',
    'vehicle_2.csv',
    'visualizing_soil_1.csv',
    'rabe_266_1.csv',
    'blood-transfusion-service-center_1.csv',
    'thoracic-surgery_1.csv',
    'ionosphere_1.csv',
    'pm10_1.csv',
    'sa-heart_1.csv',
    'analcatdata_authorship_2.csv',
    'jEdit_4.2_4.3_1.csv',
    'Australian_1.csv',
    'ecoli_1.csv',
    'ilpd_1.csv',
    'climate-model-simulation-crashes_1.csv',
    'pc1_req_1.csv',
    'sonar_1.csv',
    'mc2_1.csv',
    'auto_price_1.csv',
    'arsenic-female-bladder_1.csv',
    'kdd_synthetic_control_1.csv',
    'analcatdata_supreme_1.csv',
    'prnn_fglass_1.csv',
    'kin8nm_1.csv',
    'no2_1.csv',
    'vowel_1.csv',
    'pc2_1.csv',
    'fri_c4_1000_50_1.csv',
    'stock_1.csv',
    'tae_2.csv',
    'transplant_1.csv',
    'chscase_funds_1.csv',
    'SPECT_1.csv',
    'visualizing_galaxy_1.csv',
    'analcatdata_vineyard_1.csv',
    'newton_hema_1.csv',
    'puma8NH_1.csv',
    'fri_c4_500_100_1.csv',
    'wilt_1.csv',
    'analcatdata_germangss_1.csv',
    'dresses-sales_1.csv',
    'haberman_1.csv',
    'fri_c1_1000_50_1.csv',
    'fri_c2_1000_50_1.csv',
    'white-clover_2.csv',
    'ar4_1.csv',
    'chscase_vine2_1.csv',
    'balance-scale_2.csv',
    'CostaMadre1_1.csv',
    'diggle_table_a2_1.csv',
    'arsenic-male-bladder_1.csv',
    'sleuth_case2002_1.csv',
    'grub-damage_2.csv',
    'glass_1.csv',
    'rmftsa_ctoarrivals_1.csv',
    'jEdit_4.0_4.2_1.csv',
    'analcatdata_apnea1_1.csv',
    'prnn_cushings_1.csv',
    'pwLinear_1.csv',
    'analcatdata_apnea2_1.csv',
    'analcatdata_boxing2_1.csv',
    'visualizing_livestock_1.csv',
    'PieChart2_1.csv',
    'monks-problems-3_1.csv',
    'wholesale-customers_1.csv',
    'analcatdata_wildcat_1.csv',
    'fri_c3_1000_50_1.csv',
    'flags_1.csv',
    'analcatdata_cyyoung9302_1.csv',
    'fri_c3_1000_5_1.csv',
    'monks-problems-1_1.csv',
    'bodyfat_1.csv',
    'fri_c1_1000_25_1.csv',
    'fri_c3_1000_25_1.csv',
    'veteran_1.csv',
    'baskball_1.csv',
    'backache_1.csv',
    'kc3_1.csv',
    'kc1-binary_1.csv',
    'heart-statlog_1.csv',
    'fri_c0_1000_50_1.csv',
    'visualizing_ethanol_1.csv',
    'parkinsons_1.csv',
    'lowbwt_1.csv',
    'analcatdata_michiganacc_1.csv',
    'sleuth_ex1605_1.csv',
    'fruitfly_1.csv',
    'machine_cpu_1.csv',
    'tecator_1.csv',
    'triazines_1.csv',
    'pyrim_1.csv',
    'fri_c2_1000_10_1.csv',
    'lymph_2.csv',
    'fri_c3_500_50_1.csv',
    'fri_c0_1000_25_1.csv',
    'fri_c2_500_50_1.csv',
    'cloud_1.csv',
    'fri_c2_500_25_1.csv',
    'pollution_1_train.csv',
    'analcatdata_seropositive_1.csv',
    'pasture_2.csv',
    'analcatdata_lawsuit_1.csv',
    'dbworld-bodies_1.csv',
    'fri_c4_250_100_1.csv',
    'sleuth_ex2016_1.csv',
    'vineyard_1.csv',
    'analcatdata_apnea3_1.csv',
    'mfeat-morphological_1.csv',
    'analcatdata_boxing1_1.csv',
    'MegaWatt1_1.csv',
    'cm1_req_1.csv',
    'iris_2.csv',
    'fri_c2_1000_25_1.csv',
    'servo_1.csv',
    'fri_c3_1000_10_1.csv',
    'space_ga_1.csv',
    'nursery_2.csv',
    'fri_c4_1000_25_1.csv',
    'pollution_1.csv',
    'fri_c0_100_10_1.csv',
    'cpu_1.csv',
    'strikes_1.csv',
    'rabe_148_1.csv',
    'fri_c1_500_50_1.csv',
    'rmftsa_ladata_1.csv',
    'dbworld-subjects-stemmed_1.csv',
    'blogger_1.csv',
    'molecular-biology_promoters_1.csv',
    'kidney_1.csv',
    'fri_c1_500_25_1.csv',
    'fri_c2_250_50_1.csv',
    'diggle_table_a1_1.csv',
    'fri_c1_1000_10_1.csv',
    'sleuth_ex1221_1.csv',
    'analcatdata_vehicle_1.csv',
    'fl2000_1.csv',
    'kc1-top5_1.csv',
    'qualitative-bankruptcy_1.csv',
    'hutsof99_logis_1.csv',
    'sleuth_case1202_1.csv',
    'ar5_1.csv',
    'datatrieve_1.csv',
    'chscase_census4_1.csv',
    'analcatdata_cyyoung8092_1.csv',
    'acute-inflammations_1.csv',
    'fri_c3_250_25_1.csv',
    'chscase_census5_1.csv',
    'analcatdata_election2000_1.csv',
    'fri_c4_500_50_1.csv',
    'fri_c0_500_10_1.csv',
    'fri_c0_1000_10_1.csv',
    'hutsof99_child_witness_1.csv',
    'chscase_vine1_1.csv',
    'mbagrade_1.csv',
    'fri_c4_100_25_1.csv',
    'fri_c1_250_25_1.csv',
    'fri_c2_250_25_1.csv',
    'acute-inflammations_2.csv',
    'fri_c3_500_5_1.csv',
    'sleuth_case1102_1.csv',
    'fri_c2_100_25_1.csv',
    'analcatdata_gviolence_1.csv',
    'banknote-authentication_1.csv',
    'humandevel_1.csv',
    'fri_c0_500_50_1.csv',
    'chscase_health_1.csv',
    'pollution_1_test.csv',
    'fri_c0_250_50_1.csv',
    'fertility_1.csv',
    'chscase_census2_1.csv',
    'pollen_1.csv',
    'ar1_1.csv',
    'fri_c1_250_50_1.csv',
    'dbworld-subjects_1.csv',
    'arsenic-female-lung_1.csv',
    'fri_c0_100_25_1.csv',
    'schlvote_1.csv',
    'ar3_1.csv',
    'disclosure_x_bias_1.csv',
    'vertebra-column_1.csv',
    'fri_c4_500_25_1.csv',
    'molecular-biology_promoters_2.csv',
    'hayes-roth_1.csv',
    'fri_c3_100_50_1.csv',
    'disclosure_x_noise_1.csv',
    'wind_correlations_1.csv',
    'fri_c3_100_10_1.csv',
    'visualizing_environmental_1.csv',
    'fri_c2_100_5_1.csv',
    'fri_c4_250_10_1.csv',
    'fri_c4_100_100_1.csv',
    'fri_c0_250_10_1.csv',
    'wdbc_1.csv',
    'fri_c1_100_50_1.csv',
    'fri_c0_100_5_1.csv',
    'confidence_1.csv',
    'analcatdata_bankruptcy_1.csv',
    'fri_c3_100_5_1.csv',
    'fri_c0_250_25_1.csv',
    'elusage_1.csv',
    'fri_c4_100_50_1.csv',
    'sleuth_case1201_1.csv',
    'fri_c1_100_25_1.csv',
    'chatfield_4_1.csv',
    'analcatdata_chlamydia_1.csv',
    'fri_c2_500_5_1.csv',
    'chscase_census6_1.csv',
    'fri_c4_250_25_1.csv',
    'aids_1.csv',
    'bolts_1.csv',
    'fri_c0_500_25_1.csv',
    'fri_c4_1000_10_1.csv',
    'fri_c2_100_50_1.csv',
    'rabe_97_1.csv',
    'fri_c3_250_50_1.csv',
    'analcatdata_olympic2000_1.csv',
    'planning-relax_1.csv',
    'collins_1.csv',
    'rabe_265_1.csv',
    'sleuth_ex2015_1.csv',
    'fri_c1_250_10_1.csv',
    'fri_c0_1000_5_1.csv',
    'lupus_1.csv',
    'visualizing_hamster_1.csv',
    'sleuth_ex1714_1.csv',
    'chscase_adopt_1.csv',
    'fri_c1_500_5_1.csv',
    'fri_c2_1000_5_1.csv',
    'fri_c3_500_25_1.csv',
    'fri_c4_250_50_1.csv',
    'rabe_131_1.csv',
    'fri_c3_250_10_1.csv',
    'zoo_1.csv',
    'ar6_1.csv',
    'witmer_census_1980_1.csv',
    'fri_c0_100_50_1.csv',
    'analcatdata_challenger_1.csv',
    'fri_c4_500_10_1.csv',
    'fri_c1_100_5_1.csv',
    'fri_c2_500_10_1.csv',
    'fri_c3_100_25_1.csv',
    'fri_c2_250_10_1.csv',
    'fri_c3_500_10_1.csv',
    'rabe_166_1.csv',
    'fri_c3_250_5_1.csv',
    'fri_c0_500_5_1.csv',
    'analcatdata_creditscore_1.csv',
    'chscase_census3_1.csv',
    'fri_c1_1000_5_1.csv',
    'fri_c2_100_10_1.csv',
    'fri_c1_500_10_1.csv',
    'prnn_synth_1.csv',
    'fri_c0_250_5_1.csv',
    'badges2_1.csv',
    'fri_c2_250_5_1.csv',
    'visualizing_slope_1.csv',
    'disclosure_x_tampered_1.csv',
    'fri_c1_100_10_1.csv',
    'analcatdata_neavote_1.csv',
    'disclosure_z_1.csv',
    'diabetes_numeric_1.csv',
    'analcatdata_challenger_2.csv',
    'fri_c1_250_5_1.csv',
    'fri_c4_100_10_1.csv',
    'rabe_176_1.csv',
    'arsenic-male-lung_1.csv',
    'sleuth_ex2016_2.csv',
    'sleuth_ex2015_2.csv',
    'analcatdata_japansolvent_1.csv'
]


class MLChallenge(Challenge):
    """Machine Learning Challenge class.

    The MLChallenge class rerpresents a single ``machine learning challenge`` that can be used for
    benchmark.

    Args:
        model (class):
            Class of a machine learning estimator.

        dataset (str):
            Name or path to a dataset. If it's a name it will try to read it from
            https://btb-data.s3.amazonaws.com/

        target_column (str):
            Name of the target column in the dataset.

        encode (bool):
            Either or not to encode the dataset using ``sklearn.preprocessing.OneHotEncoder``.

        model_defaults (dict):
            Dictionary with default keyword args for the model instantiation.

        make_binary (bool):
            Either or not to make the target column binary.

        tunable_hyperparameters (dict):
            Dictionary representing the tunable hyperparameters for the challenge.

        metric (callable):
            Metric function. If ``None``, then the estimator's metric function
            will be used in case there is otherwise the default that ``cross_val_score`` function
            offers will be used.
    """

    _data = None

    @classmethod
    def get_dataset_url(cls, name):
        if not name.endswith('.csv'):
            name = name + '.csv'

        return urljoin(BASE_DATASET_URL, name)

    @classmethod
    def get_available_dataset_names(cls):
        return DATASETS_BY_TIME.copy()

    @classmethod
    def get_all_challenges(cls, challenges=None):
        """Return a list containing the instance of the datasets available."""
        datasets = challenges or cls.get_available_dataset_names()
        loaded_challenges = []
        for dataset in datasets:
            try:
                loaded_challenges.append(cls(dataset))
                LOGGER.info('Dataset %s loaded', dataset)
            except Exception as ex:
                LOGGER.warn('Dataset: %s could not be loaded. Error: %s', dataset, ex)

        LOGGER.info('%s / %s datasets loaded.', len(loaded_challenges), len(datasets))

        return loaded_challenges

    def load_data(self):
        """Load ``X`` and ``y`` over which to perform fit and evaluate."""
        if os.path.isdir(self.dataset):
            X = pd.read_csv(self.dataset)

        else:
            url = self.get_dataset_url(self.dataset)
            X = pd.read_csv(url)

        y = X.pop(self.target_column)

        if self.make_binary:
            y = y.iloc[0] == y

        if self.encode:
            ohe = OneHotEncoder(categories='auto')
            X = ohe.fit_transform(X)

        return X, y

    @property
    def data(self):
        if self._data is None:
            self._data = self.load_data()

        return self._data

    def __init__(self, dataset, model=None, target_column=None,
                 encode=None, tunable_hyperparameters=None, metric=None,
                 model_defaults=None, make_binary=None, stratified=None,
                 cv_splits=5, cv_random_state=42, cv_shuffle=True, metric_args={}):

        self.model = model or self.MODEL
        self.dataset = dataset or self.DATASET
        self.target_column = target_column or self.TARGET_COLUMN
        self.model_defaults = model_defaults or self.MODEL_DEFAULTS
        self.make_binary = make_binary or self.MAKE_BINARY
        self.tunable_hyperparameters = tunable_hyperparameters or self.TUNABLE_HYPERPARAMETERS

        if metric:
            self.metric = metric
            self.metric_args = metric_args

        else:
            # Allow to either write a metric method or assign a METRIC function
            self.metric = getattr(self, 'metric', self.__class__.METRIC)
            self.metric_args = getattr(self, 'metric_args', self.__class__.METRIC_ARGS)

        self.stratified = self.STRATIFIED if stratified is None else stratified
        # self.X, self.y = self.load_data()

        self.encode = self.ENCODE if encode is None else encode
        self.scorer = make_scorer(self.metric, **self.metric_args)

        if self.stratified:
            self.cv = StratifiedKFold(
                shuffle=cv_shuffle,
                n_splits=cv_splits,
                random_state=cv_random_state
            )
        else:
            self.cv = KFold(
                shuffle=cv_shuffle,
                n_splits=cv_splits,
                random_state=cv_random_state
            )

    def get_tunable_hyperparameters(self):
        return deepcopy(self.tunable_hyperparameters)

    def evaluate(self, **hyperparams):
        """Apply cross validation to hyperparameter combination.

        Args:
            hyperparams (dict):
                A combination of ``self.tunable_hyperparams``.

        Returns:
            score (float):
                Returns the ``mean`` cross validated score.
        """
        hyperparams.update((self.model_defaults or {}))
        model = self.model(**hyperparams)
        X, y = self.data

        return cross_val_score(model, X, y, cv=self.cv, scoring=self.scorer).mean()

    def __repr__(self):
        return "{}('{}')".format(self.__class__.__name__, self.dataset)