# -*- coding: utf-8 -*-


import os
import sys

import unittest
from sklearn.utils.testing import assert_equal
# noinspection PyProtectedMember
from sklearn.utils.testing import assert_allclose
from sklearn.utils.testing import assert_raises
from sklearn.metrics import precision_score
from sklearn.utils import check_random_state

import numpy as np

# Temporary workaround: add the parent directory to sys.path so that the
# `utils` package can be imported when combo is not installed. Once combo is
# installed, the following line is no longer needed.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from utils.utility import standardizer
from utils.utility import get_label_n
from utils.utility import precision_n_scores
from utils.utility import argmaxn
from utils.utility import invert_order
from utils.utility import check_detector
from utils.utility import score_to_label
from utils.utility import list_diff
from utils.utility import score_to_proba


class TestScaler(unittest.TestCase):
    """Tests for ``standardizer`` and ``invert_order``."""

    def setUp(self):
        random_state = check_random_state(42)
        self.X_train = random_state.rand(500, 5)
        self.X_test = random_state.rand(100, 5)
        # Different feature count than X_train; used to hit the
        # shape-mismatch error path of standardizer.
        self.X_test_diff = random_state.rand(100, 10)
        # The same scores as a plain list and as an ndarray, so that both
        # input types of invert_order are exercised.
        self.scores1 = [0.1, 0.3, 0.5, 0.7, 0.2, 0.1]
        self.scores2 = np.array([0.1, 0.3, 0.5, 0.7, 0.2, 0.1])

    def _assert_standardized(self, X):
        """Assert that ``X`` has (approximately) zero mean and unit std."""
        np.testing.assert_allclose(X.mean(), 0, atol=0.05)
        np.testing.assert_allclose(X.std(), 1, atol=0.05)

    def _assert_is_scaler(self, scaler):
        """Assert that ``scaler`` exposes ``fit`` and ``transform``."""
        # Reported as a test *failure* (not an error); the object is a
        # scaler, not a detector, so say so in the message.
        self.assertTrue(
            hasattr(scaler, 'fit') and hasattr(scaler, 'transform'),
            "%r is not a scaler instance." % (scaler,))

    def test_normalization(self):
        # X_t is provided and the scaler is not kept
        norm_X_train, norm_X_test = standardizer(self.X_train, self.X_test)
        self._assert_standardized(norm_X_train)
        self._assert_standardized(norm_X_test)

        # X_t is omitted and the scaler is not kept
        norm_X_train = standardizer(self.X_train)
        self._assert_standardized(norm_X_train)

        # X_t is provided and the scaler is kept
        norm_X_train, norm_X_test, scaler = standardizer(self.X_train,
                                                         self.X_test,
                                                         keep_scalar=True)
        self._assert_standardized(norm_X_train)
        self._assert_standardized(norm_X_test)
        self._assert_is_scaler(scaler)

        # X_t is omitted and the scaler is kept
        norm_X_train, scaler = standardizer(self.X_train, keep_scalar=True)
        self._assert_standardized(norm_X_train)
        self._assert_is_scaler(scaler)

        # Train/test inputs with different feature counts must be rejected
        with self.assertRaises(ValueError):
            standardizer(self.X_train, self.X_test_diff)

    def test_invert_order(self):
        # Default method: every score is negated; list and ndarray inputs
        # must give the same result.
        target = np.array([-0.1, -0.3, -0.5, -0.7, -0.2, -0.1])
        np.testing.assert_allclose(invert_order(self.scores1), target)
        np.testing.assert_allclose(invert_order(self.scores2), target)

        # 'subtraction' method: the expected values equal max(scores) - s
        # (max is 0.7 here).
        target = np.array([0.6, 0.4, 0.2, 0, 0.5, 0.6])
        np.testing.assert_allclose(
            invert_order(self.scores2, method='subtraction'), target)


class TestMetrics(unittest.TestCase):
    """Tests for the score/label helpers in ``utils.utility``."""

    def setUp(self):
        # Ground-truth binary labels (4 positives).
        self.y = [0, 0, 1, 1, 1, 0, 0, 0, 1, 0]
        # Despite the name, these are per-sample scores aligned with self.y.
        self.labels_ = [0.1, 0.2, 0.2, 0.8, 0.2, 0.5, 0.7, 0.9, 1, 0.3]
        # One element shorter than self.y, for the length-mismatch test.
        self.labels_short_ = [0.1, 0.2, 0.2, 0.8, 0.2, 0.5, 0.7, 0.9, 1]
        # Flags the 4 highest scores in self.labels_ (indices 3, 6, 7, 8);
        # 4 equals the number of positives in self.y.
        self.manual_labels = [0, 0, 0, 1, 0, 0, 1, 1, 1, 0]
        self.outliers_fraction = 0.3
        self.value_lists = [0.1, 0.3, 0.2, -2, 1.5, 0, 1, -1, -0.5, 11]

    def test_precision_n_scores(self):
        self.assertAlmostEqual(precision_score(self.y, self.manual_labels),
                               precision_n_scores(self.y, self.labels_))

    def test_get_label_n(self):
        np.testing.assert_allclose(self.manual_labels,
                                   get_label_n(self.y, self.labels_))

    def test_get_label_n_equal_3(self):
        # With n=3 only the 3 highest scores (indices 3, 7, 8) are flagged.
        manual_labels = [0, 0, 0, 1, 0, 0, 0, 1, 1, 0]
        np.testing.assert_allclose(manual_labels,
                                   get_label_n(self.y, self.labels_, n=3))

    def test_inconsistent_length(self):
        with self.assertRaises(ValueError):
            get_label_n(self.y, self.labels_short_)

    def test_score_to_label(self):
        manual_scores = [0.1, 0.4, 0.2, 0.3, 0.5, 0.9, 0.7, 1, 0.8, 0.6]

        # 10% of 10 samples -> only the top score (index 7) is an outlier
        labels = score_to_label(manual_scores, outliers_fraction=0.1)
        np.testing.assert_allclose(labels, [0, 0, 0, 0, 0, 0, 0, 1, 0, 0])

        # 30% of 10 samples -> the top 3 scores (indices 5, 7, 8)
        labels = score_to_label(manual_scores, outliers_fraction=0.3)
        np.testing.assert_allclose(labels, [0, 0, 0, 0, 0, 1, 0, 1, 1, 0])

    def test_score_to_proba(self):
        # Each row is divided by its row sum.
        manual_scores = np.array([[1, 2, 1], [1, 8, 1]])
        proba = score_to_proba(manual_scores)
        np.testing.assert_allclose(
            proba, np.array([[0.25, 0.5, 0.25], [0.1, 0.8, 0.1]]))

    def test_list_diff(self):
        list1 = [1, 2, 5, 6, 7]
        list2 = [2, 3, 4]
        diff = list_diff(list1, list2)
        np.testing.assert_allclose(diff, [1, 5, 6, 7])

    def test_argmaxn(self):
        # Default order: indices of the 3 largest values (11, 1.5, 1).
        # Compare the exact index set, not just its sum, which many wrong
        # answers would also satisfy.
        ind = argmaxn(self.value_lists, 3)
        self.assertEqual(len(ind), 3)
        np.testing.assert_array_equal(np.sort(ind), [4, 6, 9])

        # order='asc': indices of the 3 smallest values (-2, -1, -0.5)
        ind = argmaxn(self.value_lists, 3, order='asc')
        self.assertEqual(len(ind), 3)
        np.testing.assert_array_equal(np.sort(ind), [3, 7, 8])

        # Out-of-range n (negative, or larger than the list) must raise
        with self.assertRaises(ValueError):
            argmaxn(self.value_lists, -1)
        with self.assertRaises(ValueError):
            argmaxn(self.value_lists, 20)


class TestCheckDetector(unittest.TestCase):
    """Tests that ``check_detector`` accepts or rejects objects by interface."""

    def setUp(self):
        class _DetectorLike:
            """Has ``fit`` and ``decision_function``; expected to pass."""

            def fit(self):
                return None

            def decision_function(self):
                return None

        class _NotADetector:
            """Only has look-alike method names; expected to be rejected."""

            def fit_negative(self):
                return None

            def decision_function_negative(self):
                return None

        self.detector_negative = _NotADetector()
        self.detector_positive = _DetectorLike()

    def test_check_detector_positive(self):
        # Must not raise for an object with the required methods.
        check_detector(self.detector_positive)

    def test_check_detector_negative(self):
        self.assertRaises(AttributeError, check_detector,
                          self.detector_negative)


if __name__ == "__main__":
    # Discover and run every TestCase in this module when run as a script.
    unittest.main()