from __future__ import absolute_import, print_function, division, unicode_literals
import unittest
from xcessiv import models
from sklearn.ensemble import RandomForestClassifier


class TestReturnTrainDataFromJSON(unittest.TestCase):
    """Tests for ``Extraction.return_train_dataset``."""

    def setUp(self):
        """Create an ``Extraction`` whose main dataset source loads digits.

        The source string defines ``extract_main_dataset()``, which returns
        the ``(X, y)`` pair from ``sklearn.datasets.load_digits``.
        """
        self.extraction = models.Extraction()
        self.extraction.main_dataset['source'] = (
            "from sklearn.datasets import load_digits\n"
            "\n"
            "\n"
            "def extract_main_dataset():\n"
            "    X, y = load_digits(return_X_y=True)\n"
            "    return X, y"
        )

    def test_main_is_train(self):
        """The full main dataset is returned as the training set.

        No test-dataset settings are changed in this test, so every one of
        the 1797 rows is expected back.
        """
        X, y = self.extraction.return_train_dataset()
        # assertEqual (rather than a bare assert) still runs under
        # ``python -O`` and reports both shapes when the check fails.
        self.assertEqual(X.shape, (1797, 64))
        self.assertEqual(y.shape, (1797,))

    def test_split_main_for_test(self):
        """Splitting a test set off the main data shrinks the training set.

        With ``split_ratio`` 0.1 and ``split_seed`` 8, 1617 of the 1797 rows
        are expected to remain in the training set.
        """
        self.extraction.test_dataset['method'] = 'split_from_main'
        self.extraction.test_dataset['split_ratio'] = 0.1
        self.extraction.test_dataset['split_seed'] = 8
        X, y = self.extraction.return_train_dataset()
        self.assertEqual(X.shape, (1617, 64))
        self.assertEqual(y.shape, (1617,))


class TestReturnTestDataFromJSON(unittest.TestCase):
    """Tests for ``Extraction.return_test_dataset``."""

    def setUp(self):
        """Create an ``Extraction`` configured to split test data from main.

        The main dataset source defines ``extract_main_dataset()``, which
        returns the ``(X, y)`` pair from ``sklearn.datasets.load_digits``.
        The test dataset defaults to ``split_from_main`` with ratio 0.1 and
        seed 8; individual tests may override these settings.
        """
        self.extraction = models.Extraction()
        self.extraction.main_dataset['source'] = (
            "from sklearn.datasets import load_digits\n"
            "\n"
            "\n"
            "def extract_main_dataset():\n"
            "    X, y = load_digits(return_X_y=True)\n"
            "    return X, y"
        )
        self.extraction.test_dataset['method'] = 'split_from_main'
        self.extraction.test_dataset['split_ratio'] = 0.1
        self.extraction.test_dataset['split_seed'] = 8

    def test_split_main_for_test(self):
        """A 0.1 split of the 1797-row main dataset yields 180 test rows."""
        X, y = self.extraction.return_test_dataset()
        # assertEqual (rather than a bare assert) still runs under
        # ``python -O`` and reports both shapes when the check fails.
        self.assertEqual(X.shape, (180, 64))
        self.assertEqual(y.shape, (180,))

    def test_test_dataset_from_source(self):
        """A test dataset defined by its own source is returned in full.

        The method is switched to ``"source"`` and the source string defines
        ``extract_test_dataset()``, which returns the whole digits dataset.
        """
        self.extraction.test_dataset["method"] = "source"
        self.extraction.test_dataset["source"] = (
            "from sklearn.datasets import load_digits\n"
            "def extract_test_dataset():\n"
            "    X, y = load_digits(return_X_y=True)\n"
            "    return X, y"
        )
        X, y = self.extraction.return_test_dataset()
        self.assertEqual(X.shape, (1797, 64))
        self.assertEqual(y.shape, (1797,))


class TestReturnEstimator(unittest.TestCase):
    """Tests for ``BaseLearnerOrigin.return_estimator``."""

    def setUp(self):
        """Create a ``BaseLearnerOrigin`` from a source string.

        The source binds ``base_learner`` to a ``RandomForestClassifier``
        constructed with ``random_state=8``.
        """
        self.base_learner_origin = models.BaseLearnerOrigin(
            source=(
                "from sklearn.ensemble import RandomForestClassifier\n"
                "base_learner = RandomForestClassifier(random_state=8)"
            )
        )

    def test_return_estimator_from_json(self):
        """The returned estimator is a ``RandomForestClassifier`` instance."""
        est = self.base_learner_origin.return_estimator()
        # assertIsInstance still runs under ``python -O`` and names the
        # actual type in the failure message, unlike a bare assert.
        self.assertIsInstance(est, RandomForestClassifier)