"""
All rights reserved.
--Yang Song, Apr 7th, 2020
"""
from sklearn.model_selection import StratifiedKFold, LeaveOneOut

from FAE.DataContainer.DataContainer import DataContainer


class BaseCrossValidation(object):
    """Split a DataContainer into train/validation DataContainer pairs.

    Parameters
    ----------
    n_split : str or int
        ``'all'`` selects scikit-learn's ``LeaveOneOut``. Any other value is
        converted with ``int()`` and used as the number of folds of a
        non-shuffled ``StratifiedKFold``.
    description : str
        Free text returned unchanged by ``GetDescription``.

    Raises
    ------
    ValueError
        If ``n_split`` is neither ``'all'`` nor convertible to an integer >= 2.
    """

    def __init__(self, n_split='', description=''):
        self._description = description
        if n_split == 'all':
            self._cv = LeaveOneOut()
            self._name = 'LeaveOneOut'
        else:
            n_fold = self._ParseFoldNumber(n_split)
            # shuffle=False keeps the fold assignment deterministic between runs.
            self._cv = StratifiedKFold(n_fold, shuffle=False)
            self._name = '{}-Fold'.format(n_fold)

    @staticmethod
    def _ParseFoldNumber(n_split):
        """Convert ``n_split`` to a fold count (>= 2), failing early with a clear message."""
        try:
            n_fold = int(n_split)
        except ValueError:
            # e.g. the default n_split='' would otherwise surface as an opaque int() error.
            raise ValueError("n_split must be 'all' or an integer >= 2, got {!r}".format(n_split)) from None
        if n_fold < 2:
            raise ValueError("n_split must be 'all' or an integer >= 2, got {!r}".format(n_split))
        return n_fold

    def GetName(self):
        """Return the splitter name, e.g. ``'5-Fold'`` or ``'LeaveOneOut'``."""
        return self._name

    def Generate(self, data_container):
        """Yield one ``(sub_train_container, sub_val_container)`` tuple per split.

        Rows are selected with ``array[index, :]``, so the container's array is
        indexed as (at least) 2-D; labels and case names are selected with the
        same indices. Both sub-containers receive the original feature names.
        """
        array, label = data_container.GetArray(), data_container.GetLabel()
        feature_name, case_name = data_container.GetFeatureName(), data_container.GetCaseName()
        for train_index, val_index in self._cv.split(array, label):
            train_array, train_label = array[train_index, :], label[train_index]
            val_array, val_label = array[val_index, :], label[val_index]

            sub_train_container = DataContainer(array=train_array, label=train_label, feature_name=feature_name,
                                                case_name=[case_name[index] for index in train_index])
            sub_val_container = DataContainer(array=val_array, label=val_label, feature_name=feature_name,
                                              case_name=[case_name[index] for index in val_index])
            yield (sub_train_container, sub_val_container)

    def GetDescription(self):
        """Return the free-text description given at construction."""
        return self._description


# Shared wording for the ready-made splitters below; only the method phrase
# (inserted at "{}") differs between them.
_CV_DESCRIPTION_TEMPLATE = ("To determine the hyper-parameter (e.g. the number of features) of model, "
                            "we applied cross validation with {} on the training data set. "
                            "The hyper-parameters were set according to the model performance "
                            "on the validation data set. ")

CrossValidation5Fold = BaseCrossValidation(n_split='5',
                                           description=_CV_DESCRIPTION_TEMPLATE.format('5-fold'))
CrossValidation10Fold = BaseCrossValidation(n_split='10',
                                            description=_CV_DESCRIPTION_TEMPLATE.format('10-fold'))
CrossValidationLOO = BaseCrossValidation(n_split='all',
                                         description=_CV_DESCRIPTION_TEMPLATE.format('leave-one-out'))

if __name__ == '__main__':
    # Manual smoke check: print the raw train/validation indices produced by
    # scikit-learn's LeaveOneOut on a random toy data set (60 positive, 40 negative).
    import numpy as np

    n_positive, n_negative = 60, 40
    toy_data = np.random.random((n_positive + n_negative, 10))
    toy_label = np.concatenate((np.ones((n_positive,)), np.zeros((n_negative,))), axis=0)

    splitter = LeaveOneOut()
    for train_idx, val_idx in splitter.split(toy_data, toy_label):
        # One line of train indices, one of validation indices, then a blank line.
        print(train_idx, val_idx, '', sep='\n')