import torchvision.transforms as transforms
import torchvision
import torch
import numpy as np
import os
import codecs
from torch.distributions.categorical import Categorical
import torch.utils.data as data
from PIL import Image
import errno


def _reduce_class(dataset, classes, train, preserve_label_space=True):
    """Keep only the samples of `dataset` whose label is in `classes` (in-place)."""
    if classes is None:
        return

    # Map each kept class to a new contiguous index (used when the original
    # label space is not preserved)
    new_class_idx = {c: idx for idx, c in enumerate(classes)}

    new_data = []
    new_labels = []
    if train:
        all_data = dataset.train_data
        labels = dataset.train_labels
    else:
        all_data = dataset.test_data
        labels = dataset.test_labels

    for sample, label in zip(all_data, labels):
        label_val = label if isinstance(label, int) else label.item()
        if label_val in classes:
            new_data.append(sample)
            if preserve_label_space:
                new_labels.append(label_val)
            else:
                new_labels.append(new_class_idx[label_val])
    if isinstance(new_data[0], np.ndarray):
        new_data = np.array(new_data)
    elif isinstance(new_data[0], torch.Tensor):
        new_data = torch.stack(new_data)
    else:
        assert False, "Reduce class not supported for type " + str(type(new_data[0]))
    if train:
        dataset.train_data = new_data
        dataset.train_labels = new_labels
    else:
        dataset.test_data = new_data
        dataset.test_labels = new_labels
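
# Example (sketch): keep only digits 2 and 3 of an MNIST train set, remapping
# them to labels {0, 1}. Assumes a torchvision version where datasets still
# expose `train_data`/`train_labels`:
#
#   mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True)
#   _reduce_class(mnist, classes=[2, 3], train=True, preserve_label_space=False)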


class Permutation(torch.utils.data.Dataset):
    """
    A dataset wrapper that permutes the positions of the features (pixels)
    and optionally offsets the targets.
    """
    def __init__(self, dataset, permute_idx, target_offset):
        super(Permutation, self).__init__()
        self.dataset = dataset
        self.permute_idx = permute_idx
        self.target_offset = target_offset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, target = self.dataset[index]
        target = target + self.target_offset
        shape = img.size()
        img = img.view(-1)[self.permute_idx].view(shape)
        return img, target
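
# Example (sketch): wrapping MNIST with a random pixel permutation. The
# permutation is a flat index permutation over the 1x28x28 tensor:
#
#   base = torchvision.datasets.MNIST(root='./data', train=True, download=True,
#                                     transform=transforms.ToTensor())
#   permuted = Permutation(base, torch.randperm(28 * 28), target_offset=0)
#   img, target = permuted[0]  # pixels shuffled, target unchanged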


class DatasetsLoaders:
    def __init__(self, dataset, batch_size=4, num_workers=4, pin_memory=True, **kwargs):
        self.dataset_name = dataset
        self.valid_loader = None
        self.num_workers = num_workers
        if self.num_workers is None:
            self.num_workers = 4

        self.random_erasing = kwargs.get("random_erasing", False)
        self.reduce_classes = kwargs.get("reduce_classes", None)
        self.permute = kwargs.get("permute", False)
        self.target_offset = kwargs.get("target_offset", 0)

        pin_memory = pin_memory if torch.cuda.is_available() else False
        self.batch_size = batch_size
        cifar10_mean = (0.5, 0.5, 0.5)
        cifar10_std = (0.5, 0.5, 0.5)
        cifar100_mean = (0.5070, 0.4865, 0.4409)
        cifar100_std = (0.2673, 0.2564, 0.2761)
        mnist_mean = [33.318421449829934]  # raw (0-255 scale) statistics
        mnist_std = [78.56749083061408]
        fashionmnist_mean = [73.14654541015625]
        fashionmnist_std = [89.8732681274414]

        if dataset == "CIFAR10":
            # CIFAR10:
            #   type               : uint8
            #   shape              : train_set.train_data.shape (50000, 32, 32, 3)
            #   test data shape    : (10000, 32, 32, 3)
            #   number of channels : 3
            #   Mean per channel   : train_set.train_data[:,:,:,0].mean() 125.306918046875
            #                        train_set.train_data[:,:,:,1].mean() 122.95039414062499
            #                        train_set.train_data[:,:,:,2].mean() 113.86538318359375
            #   Std per channel   :  train_set.train_data[:, :, :, 0].std() 62.993219278136884
            #                        train_set.train_data[:, :, :, 1].std() 62.088707640014213
            #                        train_set.train_data[:, :, :, 2].std() 66.704899640630913
            self.mean = cifar10_mean
            self.std = cifar10_std

            transform_train = transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])

            transform_test = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ])

            self.train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                          download=True, transform=transform_train)
            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            self.test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                         download=True, transform=transform_test)
            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "CIFAR100":
            # CIFAR100:
            #   type               : uint8
            #   shape              : train_set.train_data.shape (50000, 32, 32, 3)
            #   test data shape    : (10000, 32, 32, 3)
            #   number of channels : 3
            #   Mean per channel   : train_set.train_data[:,:,:,0].mean() 129.304165605/255=0.5070
            #                        train_set.train_data[:,:,:,1].mean() 124.069962695/255=0.4865
            #                        train_set.train_data[:,:,:,2].mean() 112.434050059/255=0.4409
            #   Std per channel   :  train_set.train_data[:, :, :, 0].std() 68.1702428992/255=0.2673
            #                        train_set.train_data[:, :, :, 1].std() 65.3918080439/255=0.2564
            #                        train_set.train_data[:, :, :, 2].std() 70.418370188/255=0.2761

            self.mean = cifar100_mean
            self.std = cifar100_std
            transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(self.mean, self.std)])

            self.train_set = torchvision.datasets.CIFAR100(root='./data', train=True,
                                                           download=True, transform=transform)
            _reduce_class(self.train_set, self.reduce_classes, train=True,
                          preserve_label_space=kwargs.get("preserve_label_space"))
            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            self.test_set = torchvision.datasets.CIFAR100(root='./data', train=False,
                                                          download=True, transform=transform)
            _reduce_class(self.test_set, self.reduce_classes, train=False,
                          preserve_label_space=kwargs.get("preserve_label_space"))
            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "MNIST":
            # MNIST:
            #   type               : torch.ByteTensor
            #   shape              : train_set.train_data.shape torch.Size([60000, 28, 28])
            #   test data shape    : [10000, 28, 28]
            #   number of channels : 1
            #   Mean per channel   : 33.318421449829934
            #   Std per channel    : 78.56749083061408

            # Transforms
            self.mean = mnist_mean
            self.std = mnist_std
            if kwargs.get("pad_to_32", False):
                transform = transforms.Compose(
                    [transforms.Pad(2, fill=0, padding_mode='constant'),
                     transforms.ToTensor(),
                     transforms.Normalize(mean=(0.1000,), std=(0.2752,))])
            else:
                transform = transforms.Compose(
                    [transforms.ToTensor()])

            # Create train set
            self.train_set = torchvision.datasets.MNIST(root='./data', train=True,
                                                        download=True, transform=transform)
            if kwargs.get("permutation", False):
                # Permute if permutation is provided
                self.train_set = Permutation(torchvision.datasets.MNIST(root='./data', train=True,
                                                                        download=True, transform=transform),
                                             kwargs.get("permutation", False), self.target_offset)
            # Reduce classes if necessary
            _reduce_class(self.train_set, self.reduce_classes, train=True,
                          preserve_label_space=kwargs.get("preserve_label_space"))
            # Remap labels
            if kwargs.get("labels_remapping", False):
                labels_remapping = kwargs.get("labels_remapping", False)
                for lbl_idx in range(len(self.train_set.train_labels)):
                    self.train_set.train_labels[lbl_idx] = labels_remapping[self.train_set.train_labels[lbl_idx]]

            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            # Create test set
            self.test_set = torchvision.datasets.MNIST(root='./data', train=False,
                                                       download=True, transform=transform)
            if kwargs.get("permutation", False):
                # Permute if permutation is provided
                self.test_set = Permutation(torchvision.datasets.MNIST(root='./data', train=False,
                                                                        download=True, transform=transform),
                                             kwargs.get("permutation", False), self.target_offset)
            # Reduce classes if necessary
            _reduce_class(self.test_set, self.reduce_classes, train=False,
                          preserve_label_space=kwargs.get("preserve_label_space"))
            # Remap labels
            if kwargs.get("labels_remapping", False):
                labels_remapping = kwargs.get("labels_remapping", False)
                for lbl_idx in range(len(self.test_set.test_labels)):
                    self.test_set.test_labels[lbl_idx] = labels_remapping[self.test_set.test_labels[lbl_idx]]

            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "FashionMNIST":
            # MNIST:
            #   type               : torch.ByteTensor
            #   shape              : train_set.train_data.shape torch.Size([60000, 28, 28])
            #   test data shape    : [10000, 28, 28]
            #   number of channels : 1
            #   Mean per channel   : fm.train_data.type(torch.FloatTensor).mean() is 72.94035223214286
            #   Std per channel    : fm.train_data.type(torch.FloatTensor).std() is 90.0211833054075
            self.mean = fashionmnist_mean
            self.std = fashionmnist_std
            transform = transforms.Compose(
                [transforms.Pad(2),
                 transforms.ToTensor(),
                 transforms.Normalize((72.94035223214286 / 255,), (90.0211833054075 / 255,))])

            self.train_set = torchvision.datasets.FashionMNIST(root='./data/fmnist', train=True,
                                                                download=True, transform=transform)
            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            self.test_set = torchvision.datasets.FashionMNIST(root='./data/fmnist', train=False,
                                                               download=True, transform=transform)
            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "SVHN":
            # SVHN:
            #   type               : numpy.ndarray
            #   shape              : self.train_set.data.shape is (73257, 3, 32, 32)
            #   test data shape    : self.test_set.data.shape is (26032, 3, 32, 32)
            #   number of channels : 3
            #   Mean per channel   : sv.data.mean(axis=0).mean(axis=1).mean(axis=1) is array([111.60893668, 113.16127466, 120.56512767])
            #   Std per channel    : np.transpose(sv.data, (1, 0, 2, 3)).reshape(3,-1).std(axis=1) is array([50.49768174, 51.2589843 , 50.24421614])
            self.mean = (111.60893668 / 255, 113.16127466 / 255, 120.56512767 / 255)
            self.std = (50.49768174 / 255, 51.2589843 / 255, 50.24421614 / 255)
            transform = transforms.Compose(
                [transforms.ToTensor(),
                 transforms.Normalize(self.mean, self.std)])

            self.train_set = torchvision.datasets.SVHN(root='./data', split="train",
                                                       download=True, transform=transform)
            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            self.test_set = torchvision.datasets.SVHN(root='./data', split="test",
                                                      download=True, transform=transform)
            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "NOTMNIST":
            # MNIST:
            #   type               : torch.ByteTensor
            #   shape              : train_set.train_data.shape torch.Size([60000, 28, 28])
            #   test data shape    : [10000, 28, 28]
            #   number of channels : 1
            #   Mean per channel   : nm.train_data.type(torch.FloatTensor).mean() is 106.51712372448979
            #   Std per channel    : nm.train_data.type(torch.FloatTensor).std() is 115.76734631096612
            self.mean = mnist_mean
            self.std = mnist_std
            transform = transforms.Compose(
                [transforms.Pad(2),
                 transforms.ToTensor(),
                 transforms.Normalize((106.51712372448979 / 255,), (115.76734631096612 / 255,))])

            self.train_set = NOTMNIST(root='./data/notmnist', train=True, download=True, transform=transform)

            self.train_loader = torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                                            shuffle=True, num_workers=self.num_workers,
                                                            pin_memory=pin_memory)

            self.test_set = NOTMNIST(root='./data/notmnist', train=False, download=True, transform=transform)
            self.test_loader = torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                                           shuffle=False, num_workers=self.num_workers,
                                                           pin_memory=pin_memory)
        if dataset == "CONTPERMUTEDPADDEDMNIST":
            transform = transforms.Compose(
                [transforms.Pad(2, fill=0, padding_mode='constant'),
                 transforms.ToTensor(),
                 transforms.Normalize(mean=(0.1000,), std=(0.2752,))])

            # Original MNIST
            tasks_datasets = [torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)]
            tasks_samples_indices = [torch.tensor(range(len(tasks_datasets[0])), dtype=torch.int32)]
            total_len = len(tasks_datasets[0])
            test_loaders = [torch.utils.data.DataLoader(torchvision.datasets.MNIST(root='./data', train=False,
                                                                                   download=True, transform=transform),
                                                        batch_size=self.batch_size, shuffle=False,
                                                        num_workers=self.num_workers, pin_memory=pin_memory)]
            all_permutation = kwargs.get("all_permutation", None)
            assert all_permutation is not None, "Must provide all_permutation for CONTPERMUTEDPADDEDMNIST"
            self.num_of_permutations = len(all_permutation)
            for p_idx in range(self.num_of_permutations):
                # Create permutation
                permutation = all_permutation[p_idx]

                # Add train set:
                tasks_datasets.append(Permutation(torchvision.datasets.MNIST(root='./data', train=True,
                                                                             download=True, transform=transform),
                                                  permutation, target_offset=0))

                tasks_samples_indices.append(torch.tensor(range(total_len,
                                                                total_len + len(tasks_datasets[-1])
                                                                ), dtype=torch.int32))
                total_len += len(tasks_datasets[-1])
                # Add test set:
                test_set = Permutation(torchvision.datasets.MNIST(root='./data', train=False,
                                                                  download=True, transform=transform),
                                       permutation, self.target_offset)
                test_loaders.append(torch.utils.data.DataLoader(test_set, batch_size=self.batch_size,
                                                                shuffle=False, num_workers=self.num_workers,
                                                                pin_memory=pin_memory))
            self.test_loader = test_loaders
            total_iters = kwargs.get("total_iters", None)
            assert total_iters is not None, "Must provide total_iters for continuous permuted MNIST"
            beta = kwargs.get("contpermuted_beta", 3)
            # Concatenate the train sets of all tasks into a single dataset
            all_datasets = torch.utils.data.ConcatDataset(tasks_datasets)

            # Create probabilities of tasks over iterations
            self.tasks_probs_over_iterations = [_create_task_probs(total_iters, self.num_of_permutations+1, task_id,
                                                                    beta=beta) for task_id in
                                                 range(self.num_of_permutations+1)]
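            # Normalize across tasks so the probabilities at each iteration sum to 1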
            normalize_probs = torch.zeros_like(self.tasks_probs_over_iterations[0])
            for probs in self.tasks_probs_over_iterations:
                normalize_probs.add_(probs)
            for probs in self.tasks_probs_over_iterations:
                probs.div_(normalize_probs)
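            # Rearrange: from one (iters,)-shaped tensor per task to one
            # (num_tasks,)-shaped probability vector per iteration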
            self.tasks_probs_over_iterations = torch.cat(self.tasks_probs_over_iterations).view(-1, self.tasks_probs_over_iterations[0].shape[0])
            tasks_probs_over_iterations_lst = []
            for col in range(self.tasks_probs_over_iterations.shape[1]):
                tasks_probs_over_iterations_lst.append(self.tasks_probs_over_iterations[:, col])
            self.tasks_probs_over_iterations = tasks_probs_over_iterations_lst

            train_sampler = ContinuousMultinomialSampler(data_source=all_datasets, samples_in_batch=self.batch_size,
                                                         tasks_samples_indices=tasks_samples_indices,
                                                         tasks_probs_over_iterations=self.tasks_probs_over_iterations,
                                                         num_of_batches=kwargs.get("iterations_per_virtual_epc", 1))
            self.train_loader = torch.utils.data.DataLoader(all_datasets, batch_size=self.batch_size,
                                                            num_workers=self.num_workers, sampler=train_sampler, pin_memory=pin_memory)
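
# Example (sketch): a plain CIFAR10 train/test loader pair. Argument values
# are illustrative:
#
#   loaders = DatasetsLoaders("CIFAR10", batch_size=64, num_workers=2)
#   images, targets = next(iter(loaders.train_loader))  # images: 64x3x32x32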


class ContinuousMultinomialSampler(torch.utils.data.Sampler):
    r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
    If with replacement, then user can specify ``num_samples`` to draw.
    self.tasks_probs_over_iterations is the probabilities of tasks over iterations.
    self.samples_distribution_over_time is the actual distribution of samples over iterations
                                            (the result of sampling from self.tasks_probs_over_iterations).
    Arguments:
        data_source (Dataset): dataset to sample from
        num_samples (int): number of samples to draw, default=len(dataset)
        replacement (bool): samples are drawn with replacement if ``True``, default=False
    """

    def __init__(self, data_source, samples_in_batch=128, num_of_batches=69, tasks_samples_indices=None,
                 tasks_probs_over_iterations=None):
        self.data_source = data_source
        assert tasks_samples_indices is not None, "Must provide tasks_samples_indices - a list of tensors," \
                                                  "each item in the list corresponds to a task, each item of the " \
                                                  "tensor corresponds to index of sample of this task"
        self.tasks_samples_indices = tasks_samples_indices
        self.num_of_tasks = len(self.tasks_samples_indices)
        assert tasks_probs_over_iterations is not None, "Must provide tasks_probs_over_iterations - a list of " \
                                                         "probs per iteration"
        assert all([isinstance(probs, torch.Tensor) and len(probs) == self.num_of_tasks for
                    probs in tasks_probs_over_iterations]), "All probs must be tensors of len " \
                                                              + str(self.num_of_tasks) + ", first tensor type is " \
                                                              + str(type(tasks_probs_over_iterations[0])) + ", and " \
                                                              "len is " + str(len(tasks_probs_over_iterations[0]))
        self.tasks_probs_over_iterations = tasks_probs_over_iterations
        self.current_iteration = 0

        self.samples_in_batch = samples_in_batch
        self.num_of_batches = num_of_batches

        # Create the samples_distribution_over_time
        self.samples_distribution_over_time = [[] for _ in range(self.num_of_tasks)]
        self.iter_indices_per_iteration = []

        if not isinstance(self.samples_in_batch, int) or self.samples_in_batch <= 0:
            raise ValueError("samples_in_batch should be a positive integer "
                             "value, but got samples_in_batch={}".format(self.samples_in_batch))

    def generate_iters_indices(self, num_of_iters):
        from_iter = len(self.iter_indices_per_iteration)
        for iter_num in range(from_iter, from_iter+num_of_iters):

            # Get random number of samples per task (according to iteration distribution)
            tsks = Categorical(probs=self.tasks_probs_over_iterations[iter_num]).sample(torch.Size([self.samples_in_batch]))
            # Generate samples indices for iter_num
            iter_indices = torch.zeros(0, dtype=torch.int32)
            for task_idx in range(self.num_of_tasks):
                if self.tasks_probs_over_iterations[iter_num][task_idx] > 0:
                    num_samples_from_task = (tsks == task_idx).sum().item()
                    self.samples_distribution_over_time[task_idx].append(num_samples_from_task)
                    # Randomize indices for each task (to allow creation of random task batch)
                    tasks_inner_permute = np.random.permutation(len(self.tasks_samples_indices[task_idx]))
                    rand_indices_of_task = tasks_inner_permute[:num_samples_from_task]
                    iter_indices = torch.cat([iter_indices, self.tasks_samples_indices[task_idx][rand_indices_of_task]])
                else:
                    self.samples_distribution_over_time[task_idx].append(0)
            self.iter_indices_per_iteration.append(iter_indices.tolist())

    def __iter__(self):
        self.generate_iters_indices(self.num_of_batches)
        self.current_iteration += self.num_of_batches
        return iter([item for sublist in self.iter_indices_per_iteration[self.current_iteration - self.num_of_batches:self.current_iteration] for item in sublist])

    def __len__(self):
        # Total number of sample indices yielded per call to __iter__
        return self.num_of_batches * self.samples_in_batch
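
# Example (sketch): a ContinuousMultinomialSampler over two tiny "tasks".
# All names and values below are illustrative:
#
#   ds = torch.utils.data.TensorDataset(torch.arange(20.))
#   task_indices = [torch.tensor(range(0, 10), dtype=torch.int32),
#                   torch.tensor(range(10, 20), dtype=torch.int32)]
#   probs = [torch.tensor([1.0, 0.0]),   # 1st batch: task 0 only
#            torch.tensor([0.5, 0.5])]   # 2nd batch: even mixture
#   sampler = ContinuousMultinomialSampler(ds, samples_in_batch=4, num_of_batches=2,
#                                          tasks_samples_indices=task_indices,
#                                          tasks_probs_over_iterations=probs)
#   loader = torch.utils.data.DataLoader(ds, batch_size=4, sampler=sampler)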


def _get_linear_line(start, end, direction="up"):
    if direction == "up":
        return torch.FloatTensor([(i - start)/(end-start) for i in range(start, end)])
    return torch.FloatTensor([1 - ((i - start) / (end - start)) for i in range(start, end)])


def _create_task_probs(iters, tasks, task_id, beta=3):
    if beta <= 1:
        peak_start = int((task_id/tasks)*iters)
        peak_end = int(((task_id + 1) / tasks)*iters)
        start = peak_start
        end = peak_end
    else:
        start = max(int(((beta*task_id - 1)*iters)/(beta*tasks)), 0)
        peak_start = int(((beta*task_id + 1)*iters)/(beta*tasks))
        peak_end = int(((beta * task_id + (beta - 1)) * iters) / (beta * tasks))
        end = min(int(((beta * task_id + (beta + 1)) * iters) / (beta * tasks)), iters)

    probs = torch.zeros(iters, dtype=torch.float)
    if task_id == 0:
        probs[start:peak_start].add_(1)
    else:
        probs[start:peak_start] = _get_linear_line(start, peak_start, direction="up")
    probs[peak_start:peak_end].add_(1)
    if task_id == tasks - 1:
        probs[peak_end:end].add_(1)
    else:
        probs[peak_end:end] = _get_linear_line(peak_end, end, direction="down")
    return probs
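
# Example (sketch): the schedule produced by _create_task_probs is trapezoidal.
# With iters=300, tasks=3, beta=3 and task_id=1:
#
#   probs = _create_task_probs(300, 3, 1)  # zero outside [66, 233),
#                                          # ramps up on [66, 133),
#                                          # equals 1 on [133, 166),
#                                          # ramps down on [166, 233)
#
# DatasetsLoaders normalizes these schedules across tasks so that each
# iteration's task probabilities sum to 1.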


###
# NotMNIST
###
class NOTMNIST(data.Dataset):
    """`notMNIST <https://github.com/davidflanagan/notMNIST-to-MNIST>`_ Dataset (in MNIST format).

    Args:
        root (string): Root directory of dataset where ``processed/training.pt``
            and  ``processed/test.pt`` exist.
        train (bool, optional): If True, creates dataset from ``training.pt``,
            otherwise from ``test.pt``.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """
    urls = [
        'https://github.com/davidflanagan/notMNIST-to-MNIST/raw/master/t10k-images-idx3-ubyte.gz',
        'https://github.com/davidflanagan/notMNIST-to-MNIST/raw/master/t10k-labels-idx1-ubyte.gz',
        'https://github.com/davidflanagan/notMNIST-to-MNIST/raw/master/train-images-idx3-ubyte.gz',
        'https://github.com/davidflanagan/notMNIST-to-MNIST/raw/master/train-labels-idx1-ubyte.gz',
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        if self.train:
            self.train_data, self.train_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.training_file))
        else:
            self.test_data, self.test_labels = torch.load(
                os.path.join(self.root, self.processed_folder, self.test_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        if self.train:
            img, target = self.train_data[index], self.train_labels[index]
        else:
            img, target = self.test_data[index], self.test_labels[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder, self.training_file)) and \
            os.path.exists(os.path.join(self.root, self.processed_folder, self.test_file))

    def download(self):
        """Download the MNIST data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import gzip

        if self._check_exists():
            return

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        for url in self.urls:
            print('Downloading ' + url)
            data = urllib.request.urlopen(url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with open(file_path, 'wb') as f:
                f.write(data.read())
            with open(file_path.replace('.gz', ''), 'wb') as out_f, \
                    gzip.GzipFile(file_path) as zip_f:
                out_f.write(zip_f.read())
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')

        training_set = (
            self.read_image_file(os.path.join(self.root, self.raw_folder, 'train-images-idx3-ubyte')),
            self.read_label_file(os.path.join(self.root, self.raw_folder, 'train-labels-idx1-ubyte'))
        )
        test_set = (
            self.read_image_file(os.path.join(self.root, self.raw_folder, 't10k-images-idx3-ubyte')),
            self.read_label_file(os.path.join(self.root, self.raw_folder, 't10k-labels-idx1-ubyte'))
        )
        with open(os.path.join(self.root, self.processed_folder, self.training_file), 'wb') as f:
            torch.save(training_set, f)
        with open(os.path.join(self.root, self.processed_folder, self.test_file), 'wb') as f:
            torch.save(test_set, f)

        print('Done!')

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = 'train' if self.train is True else 'test'
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    @staticmethod
    def get_int(b):
        return int(codecs.encode(b, 'hex'), 16)

    def read_label_file(self, path):
        with open(path, 'rb') as f:
            data = f.read()
            assert self.get_int(data[:4]) == 2049
            length = self.get_int(data[4:8])
            parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
            return torch.from_numpy(parsed).view(length).long()

    def read_image_file(self, path):
        with open(path, 'rb') as f:
            data = f.read()
            assert self.get_int(data[:4]) == 2051
            length = self.get_int(data[4:8])
            num_rows = self.get_int(data[8:12])
            num_cols = self.get_int(data[12:16])
            parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
            return torch.from_numpy(parsed).view(length, num_rows, num_cols)


###########################################################################
# Callable datasets
###########################################################################


def ds_mnist(**kwargs):
    """
    MNIST dataset.
    :param batch_size: batch size
           num_workers: num of workers
           pad_to_32: If true, will pad digits to size 32x32 and normalize to zero mean and unit variance.
           labels_remapping: Optional dict mapping original labels to new ones (used by the offline split variants).
    :return: Tuple with two lists.
             First list of the tuple contains a single train loader.
             Second list of the tuple contains a single test loader.
    """
    dataset = [DatasetsLoaders("MNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1),
                               pad_to_32=kwargs.get("pad_to_32", False))]
    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders


def ds_split_mnist(**kwargs):
    """
    Split MNIST dataset. Consists of 5 tasks: digits 0 & 1, 2 & 3, 4 & 5, 6 & 7, and 8 & 9.
    :param batch_size: batch size
           num_workers: num of workers
           pad_to_32: If true, will pad digits to size 32x32 and normalize to zero mean and unit variance.
           separate_labels_space: If true, each task will have its own label space (e.g. 01, 23 etc.).
                                  If false, all tasks will have label space of 0,1 only.
    :return: Tuple with two lists.
             First list of the tuple is a list of 5 train loaders, each loader is a task.
             Second list of the tuple is a list of 5 test loaders, each loader is a task.
    """
    classes_lst = [
        [0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]
    ]
    dataset = [DatasetsLoaders("MNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1),
                               reduce_classes=cl, pad_to_32=kwargs.get("pad_to_32", False),
                               preserve_label_space=kwargs.get("separate_labels_space")) for cl in classes_lst]
    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders
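
# Example (sketch): the five Split MNIST tasks with a shared 0/1 label space.
# Argument values are illustrative:
#
#   train_loaders, test_loaders = ds_split_mnist(batch_size=128, num_workers=1,
#                                                separate_labels_space=False)
#   assert len(train_loaders) == 5 and len(test_loaders) == 5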


def ds_padded_split_mnist(**kwargs):
    """
    Split MNIST dataset, padded to 32x32 pixels.
    """
    return ds_split_mnist(pad_to_32=True, **kwargs)


def ds_split_mnist_offline(**kwargs):
    """
    Split MNIST dataset. Offline means that all tasks are mixed together.
    """
    if kwargs.get("separate_labels_space"):
        return ds_mnist(**kwargs)
    else:
        return ds_mnist(labels_remapping={l: l % 2 for l in range(10)}, **kwargs)


def ds_padded_split_mnist_offline(**kwargs):
    """
    Split MNIST dataset. Padded to 32x32. Offline means that all tasks are mixed together.
    """
    return ds_split_mnist_offline(pad_to_32=True, **kwargs)


def ds_permuted_mnist(**kwargs):
    """
    Permuted MNIST dataset.
    First task is the MNIST datasets (with 10 possible labels).
    Other tasks are permutations (pixel-wise) of the MNIST datasets (with 10 possible labels).
    :param batch_size: batch size
           num_workers: num of workers
           pad_to_32: If true, will pad digits to size 32x32 and normalize to zero mean and unit variance.
           permutations: A list of permutations. Each permutation should be a list containing the new pixel positions.
           separate_labels_space: True for separated labels space - task i labels will be (10*i) to (10*i + 9).
                                  False for unified labels space - all tasks will have labels of 0 to 9.
    :return: Tuple with two lists.
             First list of the tuple is a list of train loaders, each loader is a task.
             Second list of the tuple is a list of test loaders, each loader is a task.
    """
    # First task
    dataset = [DatasetsLoaders("MNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1), pad_to_32=kwargs.get("pad_to_32", False))]
    target_offset = 0
    permutations = kwargs.get("permutations", [])
    for pidx in range(len(permutations)):
        if kwargs.get("separate_labels_space"):
            target_offset = (pidx + 1) * 10
        dataset.append(DatasetsLoaders("MNIST", batch_size=kwargs.get("batch_size", 128),
                                       num_workers=kwargs.get("num_workers", 1),
                                       permutation=permutations[pidx], target_offset=target_offset,
                                       pad_to_32=kwargs.get("pad_to_32", False)))
    # For offline permuted we take the datasets and mix them.
    if kwargs.get("offline", False):
        train_sets = []
        test_sets = []
        for ds in dataset:
            train_sets.append(ds.train_set)
            test_sets.append(ds.test_set)
        train_set = torch.utils.data.ConcatDataset(train_sets)
        test_set = torch.utils.data.ConcatDataset(test_sets)
        train_loader = torch.utils.data.DataLoader(train_set, batch_size=kwargs.get("batch_size", 128), shuffle=True,
                                                   num_workers=kwargs.get("num_workers", 1), pin_memory=True)
        test_loader = torch.utils.data.DataLoader(test_set, batch_size=kwargs.get("batch_size", 128), shuffle=False,
                                                  num_workers=kwargs.get("num_workers", 1), pin_memory=True)
        return [train_loader], [test_loader]
    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders
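
# Example (sketch): Permuted MNIST with two pixel-permutation tasks after the
# original MNIST task. With pad_to_32 left False, images are 28x28, so each
# permutation has 28*28 entries. Values here are illustrative:
#
#   perms = [np.random.permutation(28 * 28).tolist() for _ in range(2)]
#   train_loaders, test_loaders = ds_permuted_mnist(batch_size=128,
#                                                   permutations=perms)
#   assert len(train_loaders) == 3  # original + 2 permuted tasks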


def ds_padded_permuted_mnist(**kwargs):
    """
    Permuted MNIST dataset, padded to 32x32.
    """
    return ds_permuted_mnist(pad_to_32=True, **kwargs)


def ds_permuted_mnist_offline(**kwargs):
    """
    Permuted MNIST dataset. Offline means that all tasks are mixed together.
    """
    return ds_permuted_mnist(offline=True, **kwargs)


def ds_padded_permuted_mnist_offline(**kwargs):
    """
    Permuted MNIST dataset, padded to 32x32. Offline means that all tasks are mixed together.
    """
    return ds_permuted_mnist(pad_to_32=True, offline=True, **kwargs)


def ds_padded_cont_permuted_mnist(**kwargs):
    """
    Continuous Permuted MNIST dataset, padded to 32x32.
    Notice that this dataloader is aware of the epoch number, therefore if the training is loaded from a checkpoint
        adjustments might be needed.
    Access dataset.tasks_probs_over_iterations to see the tasks probabilities for each iteration.
    :param num_epochs: Number of epochs for the training (since it builds the distribution over iterations,
                            it needs this information in advance).
    :param iterations_per_virtual_epc: In continuous task-agnostic learning, the notion of an epoch does not exist,
                                        since we cannot define 'passing over the whole dataset'. Therefore,
                                        we define "iterations_per_virtual_epc" -
                                        how many iterations constitute a single (virtual) epoch.
    :param contpermuted_beta: The proportion in which the tasks overlap. 4 means that 1/4 of a task's duration will
                                consist of data from the previous/next task. Larger values mean less overlap.
    :param permutations: The permutations which will be used (first task is always the original MNIST).
    :param batch_size: Batch size.
    :param num_workers: Num workers.
    :return: A tuple of (train_loaders, test_loaders). train_loaders is a list with a single data loader - it loads
                the permuted MNIST dataset continuously as described in the paper. test_loaders is a list of
                1 + len(permutations) data loaders, one for each task.
    """
    dataset = [DatasetsLoaders("CONTPERMUTEDPADDEDMNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1),
                               total_iters=(kwargs.get("num_epochs")*kwargs.get("iterations_per_virtual_epc")),
                               contpermuted_beta=kwargs.get("contpermuted_beta"),
                               iterations_per_virtual_epc=kwargs.get("iterations_per_virtual_epc"),
                               all_permutation=kwargs.get("permutations", []))]
    test_loaders = [tloader for ds in dataset for tloader in ds.test_loader]
    train_loaders = [ds.train_loader for ds in dataset]

    return train_loaders, test_loaders
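
# Example (sketch): continuous permuted MNIST with two permutations over
# 10 epochs of 469 iterations each. All argument values are illustrative:
#
#   perms = [torch.randperm(32 * 32) for _ in range(2)]
#   train_loaders, test_loaders = ds_padded_cont_permuted_mnist(
#       num_epochs=10, iterations_per_virtual_epc=469, contpermuted_beta=4,
#       permutations=perms, batch_size=128, num_workers=1)
#   # train_loaders holds one continuous loader; test_loaders holds 3 loaders.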


def ds_visionmix(**kwargs):
    """
    Vision mix dataset. Consists of: MNIST, notMNIST, FashionMNIST, SVHN and CIFAR10.
    """
    dataset = [DatasetsLoaders("MNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1), pad_to_32=True),
               DatasetsLoaders("NOTMNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1)),
               DatasetsLoaders("FashionMNIST", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1)),
               DatasetsLoaders("SVHN", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1)),
               DatasetsLoaders("CIFAR10", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1))]
    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders


def ds_cifar10and100(**kwargs):
    """
    CIFAR10 and CIFAR100 dataset. Consists of 6 tasks:
        1) CIFAR10
        2-6) Subsets of 10 classes from CIFAR100.
    """
    classes_lst = [[j for j in range(i * 10, (i + 1) * 10)] for i in range(0, 5)]
    dataset = [DatasetsLoaders("CIFAR100", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1),
                               reduce_classes=cl, preserve_label_space=False) for cl in classes_lst]
    dataset = [DatasetsLoaders("CIFAR10", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1), preserve_label_space=False)] + dataset

    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders


def ds_cifar10(**kwargs):
    """
    CIFAR10 dataset. No tasks.
    """
    dataset = [DatasetsLoaders("CIFAR10", batch_size=kwargs.get("batch_size", 128),
                               num_workers=kwargs.get("num_workers", 1))]

    test_loaders = [ds.test_loader for ds in dataset]
    train_loaders = [ds.train_loader for ds in dataset]
    return train_loaders, test_loaders