import unittest
import sys, platform

from skmultilearn.adapt import MLkNN
from skmultilearn.cluster import LabelCooccurrenceGraphBuilder
from skmultilearn.embedding import CLEMS, SKLearnEmbedder, EmbeddingClassifier
from skmultilearn.tests.classifier_basetest import ClassifierBaseTest
from sklearn.linear_model import LinearRegression
from sklearn.manifold import SpectralEmbedding
from copy import copy
import sklearn.metrics as metrics

# OpenNetworkEmbedder is imported only when running neither on Python 2 nor
# on a 32-bit build; on those platforms the name is left undefined.
if sys.version_info[0] != 2 and platform.architecture()[0] != '32bit':
    from skmultilearn.embedding import OpenNetworkEmbedder


class EmbeddingTest(ClassifierBaseTest):
    """Run the shared classifier checks against ``EmbeddingClassifier``.

    Each test method iterates over :meth:`classifiers` and passes every
    yielded classifier to an assertion helper inherited from
    ``ClassifierBaseTest``.
    """

    # Not referenced anywhere inside this class.
    # NOTE(review): presumably read elsewhere - confirm before removing.
    TEST_NEIGHBORS = 3

    def classifiers(self):
        """Yield every ``EmbeddingClassifier`` configuration under test.

        Yields, in order:

        * one classifier per name in ``OpenNetworkEmbedder._EMBEDDINGS``
          (this group is skipped on Python 2 and on 32-bit builds, the same
          condition under which ``OpenNetworkEmbedder`` is not imported),
        * one classifier using ``SKLearnEmbedder`` around
          ``SpectralEmbedding(n_components=2)``,
        * one classifier using ``CLEMS`` with ``metrics.accuracy_score``.

        All of them use ``LinearRegression()`` as regressor and
        ``MLkNN(k=2)`` as classifier.
        """
        graph_builder = LabelCooccurrenceGraphBuilder(weighted=True, include_self_edges=False)

        # Per-embedding parameter dicts, looked up by embedding name below.
        # An embedding name missing from this table raises KeyError, which
        # makes a newly added embedding show up as a test error.
        param_dicts = {
            'GraphFactorization': dict(epoch=1),
            'GraRep': dict(Kstep=2),
            'HOPE': dict(),
            'LaplacianEigenmaps': dict(),
            'LINE': dict(epoch=1, order=1),
            'LLE': dict(),
        }

        if not (sys.version_info[0] == 2 or platform.architecture()[0] == '32bit'):
            for embedding in OpenNetworkEmbedder._EMBEDDINGS:
                # LLE is run with 3 dimensions, every other embedding with 4.
                dimension = 3 if embedding == 'LLE' else 4

                # Each classifier gets its own shallow copy of the graph builder.
                yield EmbeddingClassifier(
                    OpenNetworkEmbedder(copy(graph_builder), embedding, dimension,
                                        'add', True, param_dicts[embedding]),
                    LinearRegression(),
                    MLkNN(k=2)
                )

        yield EmbeddingClassifier(
            SKLearnEmbedder(SpectralEmbedding(n_components=2)),
            LinearRegression(),
            MLkNN(k=2)
        )

        # The CLEMS configuration has to be yielded as well, otherwise it is
        # constructed and thrown away without ever being tested.
        # NOTE(review): the trailing positional True is EmbeddingClassifier's
        # fourth argument - presumably regressor_per_dimension; confirm.
        yield EmbeddingClassifier(
            CLEMS(metrics.accuracy_score, True),
            LinearRegression(),
            MLkNN(k=2),
            True
        )

    def test_if_embedding_classification_works_on_sparse_input(self):
        """Run the sparsity and probability helpers with the 'sparse' mode."""
        for classifier in self.classifiers():
            self.assertClassifierWorksWithSparsity(classifier, 'sparse')
            self.assertClassifierPredictsProbabilities(classifier, 'sparse')

    def test_if_embedding_classification_works_on_dense_input(self):
        """Run the sparsity and probability helpers with the 'dense' mode."""
        for classifier in self.classifiers():
            self.assertClassifierWorksWithSparsity(classifier, 'dense')
            self.assertClassifierPredictsProbabilities(classifier, 'dense')

    def test_if_embedding_works_with_cross_validation(self):
        """Run the cross-validation helper on every classifier."""
        for classifier in self.classifiers():
            self.assertClassifierWorksWithCV(classifier)

# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()