from dataclasses import dataclass
import collections
import enum
import itertools

from torch import optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.utils.data as data
import torch

from typing import List, Optional

import mnist_model
import emnist_model
import vgg_model
from active_learning_data import ActiveLearningData
from torch_utils import get_balanced_sample_indices
from train_model import train_model
from transformed_dataset import TransformedDataset
import subrange_dataset


@dataclass
class ExperimentData:
    active_learning_data: ActiveLearningData
    train_dataset: Dataset
    available_dataset: Dataset
    validation_dataset: Dataset
    test_dataset: Dataset
    initial_samples: List[int]


@dataclass
class DataSource:
    train_dataset: Dataset
    validation_dataset: Optional[Dataset] = None
    test_dataset: Optional[Dataset] = None
    shared_transform: object = None
    train_transform: object = None
    scoring_transform: object = None


def get_CINIC10(root="./"):
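    """Load the CINIC-10 image folders from `root`/data/CINIC-10.

    The train and validation folders are concatenated into one larger training set.
    Normalization and augmentation are returned as separate transforms so that they
    can be applied later (e.g. in get_experiment_data via TransformedDataset).
    """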
    cinic_directory = root + "data/CINIC-10"
    cinic_mean = [0.47889522, 0.47227842, 0.43047404]
    cinic_std = [0.24205776, 0.23828046, 0.25874835]

    train_transform = transforms.Compose([transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()])
    shared_transform = transforms.Compose([transforms.ToTensor(),
                                           transforms.Normalize(mean=cinic_mean,
                                                                std=cinic_std)])

    train_dataset = datasets.ImageFolder(cinic_directory + '/train')
    validation_dataset = datasets.ImageFolder(cinic_directory + '/valid')

    # Concatenate the train and validation sets to obtain more samples.
    merged_train_dataset = torch.utils.data.ConcatDataset([train_dataset, validation_dataset])

    test_dataset = datasets.ImageFolder(cinic_directory + '/test')

    return DataSource(
        train_dataset=merged_train_dataset,
        test_dataset=test_dataset,
        shared_transform=shared_transform,
        train_transform=train_transform,
    )


def get_MNIST():
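    """Load the standard MNIST train/test splits, normalized with the usual MNIST statistics."""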
    # num_classes=10, input_size=28
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST("data", train=False, transform=transform)

    return DataSource(train_dataset=train_dataset, test_dataset=test_dataset)


def get_RepeatedMNIST():
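    """Load MNIST with the training set repeated three times; the test set is unchanged."""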
    # num_classes = 10, input_size = 28
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    org_train_dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    train_dataset = data.ConcatDataset([org_train_dataset] * 3)

    test_dataset = datasets.MNIST("data", train=False, transform=transform)
    return DataSource(train_dataset=train_dataset, test_dataset=test_dataset)


class DatasetEnum(enum.Enum):
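    """Datasets supported by the experiments, with their matching models, optimizers and training setup."""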
    mnist = "mnist"
    emnist = "emnist"
    emnist_bymerge = "emnist_bymerge"
    repeated_mnist_w_noise = "repeated_mnist_w_noise"
    repeated_mnist_w_noise2 = "repeated_mnist_w_noise2"
    repeated_mnist_w_noise5 = "repeated_mnist_w_noise5"
    mnist_w_noise = "mnist_w_noise"
    cinic10 = "cinic10"

    def get_data_source(self):
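        """Construct the DataSource (train/validation/test splits and transforms) for this dataset."""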
        if self == DatasetEnum.mnist:
            return get_MNIST()
        elif self in (
                DatasetEnum.repeated_mnist_w_noise2,
                DatasetEnum.repeated_mnist_w_noise5,
                DatasetEnum.repeated_mnist_w_noise,
                DatasetEnum.mnist_w_noise,
        ):
            # num_classes=10, input_size=28
            num_repetitions = {
                DatasetEnum.mnist_w_noise: 1,
                DatasetEnum.repeated_mnist_w_noise: 3,
                DatasetEnum.repeated_mnist_w_noise2: 2,
                DatasetEnum.repeated_mnist_w_noise5: 5,
            }[self]

            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            org_train_dataset = datasets.MNIST("data", train=True, download=True, transform=transform)

            def apply_noise(idx, sample):
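                # Add fixed per-index noise: each repeated copy of an image gets its own
                # noise sample, so duplicates in the concatenated dataset stay distinguishable.
                # (`dataset_noise` is defined just below and captured by this closure.)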
                data, target = sample
                return data + dataset_noise[idx], target

            dataset_noise = torch.empty(
                (len(org_train_dataset) * num_repetitions, 28, 28), dtype=torch.float32
            ).normal_(0.0, 0.1)

            train_dataset = TransformedDataset(
                data.ConcatDataset([org_train_dataset] * num_repetitions), transformer=apply_noise
            )

            test_dataset = datasets.MNIST("data", train=False, transform=transform)

            return DataSource(train_dataset=train_dataset, test_dataset=test_dataset)
        elif self in (DatasetEnum.emnist, DatasetEnum.emnist_bymerge):
            # num_classes=47, input_size=28,
            split = "balanced" if self == DatasetEnum.emnist else "bymerge"
            transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
            train_dataset = datasets.EMNIST("emnist_data", split=split, train=True, download=True,
                                            transform=transform)

            test_dataset = datasets.EMNIST("emnist_data", split=split, train=False, transform=transform)

            """
                Table II contains a summary of the EMNIST datasets and
                indicates which classes contain a validation subset in the
                training set. In these datasets, the last portion of the training
                set, equal in size to the testing set, is set aside as a validation
                set. Additionally, this subset is also balanced such that it
                contains an equal number of samples for each task. If the
                validation set is not to be used, then the training set can be
                used as one contiguous set.
            """
            if self == DatasetEnum.emnist:
                # The "balanced" split reserves the last portion of the training set,
                # equal in size to the test set, as a validation set (see the quote above).
                split_index = len(train_dataset) - len(test_dataset)
                train_dataset, validation_dataset = subrange_dataset.dataset_subset_split(train_dataset,
                                                                                          split_index)
            else:
                validation_dataset = None
            return DataSource(
                train_dataset=train_dataset, test_dataset=test_dataset, validation_dataset=validation_dataset
            )
        elif self == DatasetEnum.cinic10:
            return get_CINIC10()
        else:
            raise NotImplementedError(f"Unknown dataset {self}!")

    @property
    def num_classes(self):
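        """Number of target classes in this dataset."""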
        if self in (
                DatasetEnum.mnist,
                DatasetEnum.repeated_mnist_w_noise,
                DatasetEnum.repeated_mnist_w_noise2,
                DatasetEnum.repeated_mnist_w_noise5,
                DatasetEnum.mnist_w_noise,
        ):
            return 10
        elif self in (DatasetEnum.emnist, DatasetEnum.emnist_bymerge):
            return 47
        elif self == DatasetEnum.cinic10:
            return 10
        else:
            raise NotImplementedError(f"Unknown dataset {self}!")

    def create_bayesian_model(self, device):
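        """Build the Bayesian model matching this dataset and move it to `device`."""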
        num_classes = self.num_classes
        if self in (
                DatasetEnum.mnist,
                DatasetEnum.repeated_mnist_w_noise,
                DatasetEnum.repeated_mnist_w_noise2,
                DatasetEnum.repeated_mnist_w_noise5,
                DatasetEnum.mnist_w_noise,
        ):
            return mnist_model.BayesianNet(num_classes=num_classes).to(device)
        elif self in (DatasetEnum.emnist, DatasetEnum.emnist_bymerge):
            return emnist_model.BayesianNet(num_classes=num_classes).to(device)
        elif self == DatasetEnum.cinic10:
            return vgg_model.vgg16_cinic10_bn(pretrained=True, num_classes=num_classes).to(device)
        else:
            raise NotImplementedError(f"Unknown dataset {self}!")

    def create_optimizer(self, model):
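        """Create an Adam optimizer; CINIC-10 uses a reduced learning rate (1e-4) for the pretrained VGG-16."""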
        if self == DatasetEnum.cinic10:
            optimizer = optim.Adam(model.parameters(), lr=1e-4)
        else:
            optimizer = optim.Adam(model.parameters())
        return optimizer

    def create_train_model_extra_args(self, optimizer):
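        """Extra keyword arguments forwarded to train_model; none by default."""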
        return {}

    def train_model(
            self,
            train_loader,
            test_loader,
            validation_loader,
            num_inference_samples,
            max_epochs,
            early_stopping_patience,
            desc,
            log_interval,
            device,
            epoch_results_store=None,
    ):
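        """Create a fresh model and optimizer for this dataset and train them with train_model.

        Returns the trained model, the number of epochs actually trained, and the test metrics.
        """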
        model = self.create_bayesian_model(device)
        optimizer = self.create_optimizer(model)
        num_epochs, test_metrics = train_model(
            model,
            optimizer,
            max_epochs,
            early_stopping_patience,
            num_inference_samples,
            test_loader,
            train_loader,
            validation_loader,
            log_interval,
            desc,
            device,
            epoch_results_store=epoch_results_store,
            **self.create_train_model_extra_args(optimizer),
        )
        return model, num_epochs, test_metrics


def get_experiment_data(
        data_source,
        num_classes,
        initial_samples,
        reduced_dataset,
        samples_per_class,
        validation_set_size,
        balanced_test_set,
        balanced_validation_set,
):
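    """Assemble all datasets needed for an active-learning experiment.

    Acquires the initial training samples (either the given `initial_samples` or a
    class-balanced selection of `samples_per_class` per class), splits off or shrinks
    the validation set, optionally balances or reduces the test set, and wraps the
    resulting splits in the transforms supplied by `data_source`.
    """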
    train_dataset, test_dataset, validation_dataset = (
        data_source.train_dataset,
        data_source.test_dataset,
        data_source.validation_dataset,
    )

    active_learning_data = ActiveLearningData(train_dataset)
    if initial_samples is None:
        initial_samples = list(
            itertools.chain.from_iterable(
                get_balanced_sample_indices(
                    get_targets(train_dataset), num_classes=num_classes, n_per_digit=samples_per_class
                ).values()
            )
        )

    # Split off the validation dataset after acquiring the initial samples.
    active_learning_data.acquire(initial_samples)

    if validation_dataset is None:
        print("Acquiring validation set from training set.")
        if not validation_set_size:
            validation_set_size = len(test_dataset)

        if not balanced_validation_set:
            validation_dataset = active_learning_data.extract_dataset(validation_set_size)
        else:
            print("Using a balanced validation set")
            validation_dataset = active_learning_data.extract_dataset_from_indices(
                balance_dataset_by_repeating(
                    active_learning_data.available_dataset, num_classes, validation_set_size, upsample=False
                )
            )
    else:
        if validation_set_size == 0:
            print("Using provided validation set.")
            validation_set_size = len(validation_dataset)
        if validation_set_size < len(validation_dataset):
            print("Shrinking provided validation set.")
            if not balanced_validation_set:
                validation_dataset = data.Subset(
                    validation_dataset, torch.randperm(len(validation_dataset))[:validation_set_size].tolist()
                )
            else:
                print("Using a balanced validation set")
                validation_dataset = data.Subset(
                    validation_dataset,
                    balance_dataset_by_repeating(validation_dataset, num_classes, validation_set_size, upsample=False),
                )

    if balanced_test_set:
        print("Using a balanced test set")
        print("Distribution of original test set classes:")
        classes = get_target_bins(test_dataset)
        print(classes)

        test_dataset = data.Subset(
            test_dataset, balance_dataset_by_repeating(test_dataset, num_classes, len(test_dataset))
        )

    if reduced_dataset:
        # Shrink the unlabelled pool to roughly max(5% of the training set, 5000) samples.
        active_learning_data.extract_dataset(len(train_dataset) - max(len(train_dataset) // 20, 5000))
        test_dataset = subrange_dataset.SubrangeDataset(test_dataset, 0, max(len(test_dataset) // 10, 5000))
        if validation_dataset:
            validation_dataset = subrange_dataset.SubrangeDataset(validation_dataset, 0,
                                                                  len(validation_dataset) // 10)
        print("USING REDUCED DATASET!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

    show_class_frequencies = True
    if show_class_frequencies:
        print("Distribution of training set classes:")
        classes = get_target_bins(train_dataset)
        print(classes)

        print("Distribution of validation set classes:")
        classes = get_target_bins(validation_dataset)
        print(classes)

        print("Distribution of test set classes:")
        classes = get_target_bins(test_dataset)
        print(classes)

        print("Distribution of pool classes:")
        classes = get_target_bins(active_learning_data.available_dataset)
        print(classes)

        print("Distribution of active set classes:")
        classes = get_target_bins(active_learning_data.active_dataset)
        print(classes)

    print(f"Dataset info:")
    print(f"\t{len(active_learning_data.active_dataset)} active samples")
    print(f"\t{len(active_learning_data.available_dataset)} available samples")
    print(f"\t{len(validation_dataset)} validation samples")
    print(f"\t{len(test_dataset)} test samples")

    if data_source.shared_transform is not None or data_source.train_transform is not None:
        train_dataset = TransformedDataset(
            active_learning_data.active_dataset,
            vision_transformer=compose_transformers([data_source.train_transform, data_source.shared_transform]),
        )
    else:
        train_dataset = active_learning_data.active_dataset

    if data_source.shared_transform is not None or data_source.scoring_transform is not None:
        available_dataset = TransformedDataset(
            active_learning_data.available_dataset,
            vision_transformer=compose_transformers([data_source.scoring_transform, data_source.shared_transform]),
        )
    else:
        available_dataset = active_learning_data.available_dataset

    if data_source.shared_transform is not None:
        test_dataset = TransformedDataset(test_dataset, vision_transformer=data_source.shared_transform)
        validation_dataset = TransformedDataset(validation_dataset, vision_transformer=data_source.shared_transform)

    return ExperimentData(
        active_learning_data=active_learning_data,
        train_dataset=train_dataset,
        available_dataset=available_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        initial_samples=initial_samples,
    )


def compose_transformers(iterable):
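    """Compose the non-None transforms into a single transform; return None if none are given."""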
    iterable = list(filter(None, iterable))
    if len(iterable) == 0:
        return None
    if len(iterable) == 1:
        return iterable[0]
    return transforms.Compose(iterable)


# TODO: move to utils?
def get_target_bins(dataset):
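    """Count how many samples of each class `dataset` contains."""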
    classes = collections.Counter(int(target) for target in get_targets(dataset))
    return classes


# TODO: move to utils?
def balance_dataset_by_repeating(dataset, num_classes, target_size, upsample=True):
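    """Build a class-balanced index list over `dataset` by drawing the same number of
    indices per class (repeating indices via modulo where a class is too small).

    With `upsample=True` the per-class count is at least `target_size // num_classes`
    (and at least the size of the largest class); with `upsample=False` it is capped
    at `target_size // num_classes`.
    """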
    balanced_samples_indices = get_balanced_sample_indices(get_targets(dataset), num_classes, len(dataset)).values()

    if upsample:
        num_samples_per_class = max(
            max(len(samples_per_class) for samples_per_class in balanced_samples_indices),
            target_size // num_classes
        )
    else:
        num_samples_per_class = min(
            max(len(samples_per_class) for samples_per_class in balanced_samples_indices),
            target_size // num_classes
        )

    def sample_indices(indices, total_length):
        return (torch.randperm(total_length) % len(indices)).tolist()

    balanced_samples_indices = list(
        itertools.chain.from_iterable(
            [
                [samples_per_class[i] for i in sample_indices(samples_per_class, num_samples_per_class)]
                for samples_per_class in balanced_samples_indices
            ]
        )
    )

    print(
        f"Resampled dataset ({len(dataset)} samples) to a balanced set of {len(balanced_samples_indices)} samples!")

    return balanced_samples_indices


# TODO: move to utils?
def get_targets(dataset):
    """Get the targets of a dataset without any target target transforms(!)."""
    if isinstance(dataset, TransformedDataset):
        return get_targets(dataset.dataset)
    if isinstance(dataset, data.Subset):
        targets = get_targets(dataset.dataset)
        return torch.as_tensor(targets)[dataset.indices]
    if isinstance(dataset, data.ConcatDataset):
        return torch.cat([get_targets(sub_dataset) for sub_dataset in dataset.datasets])

    if isinstance(dataset, (datasets.MNIST, datasets.ImageFolder)):
        return torch.as_tensor(dataset.targets)
    if isinstance(dataset, datasets.SVHN):
        return torch.as_tensor(dataset.labels)

    raise NotImplementedError(f"Unknown dataset {dataset}!")
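

# Example usage (a sketch, not part of the experiment scripts): how a caller might wire
# a DatasetEnum value into get_experiment_data. The argument values below are illustrative.
#
#   dataset = DatasetEnum.mnist
#   data_source = dataset.get_data_source()
#   experiment_data = get_experiment_data(
#       data_source,
#       num_classes=dataset.num_classes,
#       initial_samples=None,        # None -> pick a class-balanced initial set automatically
#       reduced_dataset=False,
#       samples_per_class=2,
#       validation_set_size=0,       # 0 -> default to the size of the test set
#       balanced_test_set=False,
#       balanced_validation_set=False,
#   )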