## Copyright (C) 2019, Huan Zhang <huan@huan-zhang.com>
##                     Hongge Chen <chenhg@mit.edu>
##                     Chaowei Xiao <xiaocw@umich.edu>
## 
## This program is licenced under the BSD 2-Clause License,
## contained in the LICENCE file in this directory.
##
import multiprocessing
import torch
from torch.utils import data
from functools import partial
import torchvision.transforms as transforms
import torchvision.datasets as datasets

# compute image statistics (by Andreas https://discuss.pytorch.org/t/computing-the-mean-and-std-of-dataset/34949/4)
def get_stats(loader):
    """Compute the per-channel mean and standard deviation of a loader's images.

    The loader is iterated twice: once to accumulate the mean and once to
    accumulate the squared deviations from that mean.

    Args:
        loader: iterable yielding ``(images, labels)`` batches. Each ``images``
            batch is treated as an ``(N, C, ...)`` tensor; all dimensions after
            the channel dimension are flattened.

    Returns:
        ``(mean, std)``: two 1-D tensors with one entry per channel. ``std`` is
        the population standard deviation (divides by the number of values).

    Raises:
        ValueError: if the loader yields no samples.
    """
    # First pass: sum of per-image channel means. Samples are counted as they
    # are seen instead of using len(loader.dataset), which over-counts when the
    # loader skips samples (e.g. drop_last=True or a sub-sampling sampler).
    channel_sum = 0.0
    num_samples = 0
    for images, _ in loader:
        flat = images.reshape(images.size(0), images.size(1), -1)
        channel_sum += flat.mean(2).sum(0)
        num_samples += flat.size(0)
    if num_samples == 0:
        raise ValueError("cannot compute statistics: the loader yielded no samples")
    mean = channel_sum / num_samples

    # Second pass: squared deviations from the mean, summed over batch and
    # pixel dimensions. The number of values per channel is counted per batch
    # rather than taken from the shape of the last batch.
    squared_dev = 0.0
    num_values = 0
    for images, _ in loader:
        flat = images.reshape(images.size(0), images.size(1), -1)
        squared_dev += ((flat - mean.unsqueeze(1)) ** 2).sum([0, 2])
        num_values += flat.size(0) * flat.size(2)
    std = torch.sqrt(squared_dev / num_values)
    return mean, std

# load MNIST or Fashion-MNIST
def mnist_loaders(dataset, batch_size, shuffle_train = True, shuffle_test = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build train and test DataLoaders for an MNIST-style dataset.

    Args:
        dataset: dataset constructor, called as
            ``dataset("./data", train=..., download=True, transform=...)``.
        batch_size: batch size of the train loader, and of the test loader
            when ``test_batch_size`` is not given.
        shuffle_train: whether the train loader shuffles.
        shuffle_test: whether the test loader shuffles.
        normalize_input: accepted but currently ignored; images are only
            converted with ``ToTensor`` and the reported mean/std are
            ``[0.0]``/``[1.0]``.
        num_examples: if truthy, keep only the first ``num_examples`` samples
            of each split (capped at the size of that split).
        test_batch_size: if truthy, batch size used for the test loader.

    Returns:
        ``(train_loader, test_loader)``; both carry ``mean`` and ``std``
        attributes describing the normalization applied to the inputs.
    """
    mnist_train = dataset("./data", train=True, download=True, transform=transforms.ToTensor())
    mnist_test = dataset("./data", train=False, download=True, transform=transforms.ToTensor())
    if num_examples:
        # The two splits can have different sizes, so clamp per split; indices
        # past the end of a split would fail when the loader is iterated.
        mnist_train = data.Subset(mnist_train, list(range(min(num_examples, len(mnist_train)))))
        mnist_test = data.Subset(mnist_test, list(range(min(num_examples, len(mnist_test)))))

    num_workers = min(multiprocessing.cpu_count(), 2)
    train_loader = torch.utils.data.DataLoader(
        mnist_train, batch_size=batch_size, shuffle=shuffle_train,
        pin_memory=True, num_workers=num_workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(
        mnist_test, batch_size=batch_size, shuffle=shuffle_test,
        pin_memory=True, num_workers=num_workers)

    # No normalization is applied, so report identity statistics.
    std = [1.0]
    mean = [0.0]
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = mean
    test_loader.mean = mean
    return train_loader, test_loader

def cifar_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build train and test DataLoaders for CIFAR-10 (stored under ``./data``).

    Args:
        batch_size: batch size of the train loader, and of the test loader
            when ``test_batch_size`` is not given.
        shuffle_train: whether the train loader shuffles.
        shuffle_test: whether the test loader shuffles.
        train_random_transform: if true, training images get a random
            horizontal flip and a random 32x32 crop with padding 4.
        normalize_input: if true, normalize with the fixed per-channel
            mean/std below; otherwise use mean 0 and std 1 (a no-op).
        num_examples: if truthy, keep only the first ``num_examples`` samples
            of each split (capped at the size of that split).
        test_batch_size: if truthy, batch size used for the test loader.

    Returns:
        ``(train_loader, test_loader)``; both carry ``mean`` and ``std``
        attributes holding the normalization constants in use.
    """
    if normalize_input:
        std = [0.2023, 0.1994, 0.2010]
        mean = [0.4914, 0.4822, 0.4465]
    else:
        std = [1.0, 1.0, 1.0]
        mean = [0, 0, 0]
    normalize = transforms.Normalize(mean = mean, std = std)

    # With normalize_input=False the Normalize step subtracts 0 and divides
    # by 1, so it can always be appended without changing the pixel values.
    train_steps = []
    if train_random_transform:
        train_steps += [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, 4)]
    train_steps += [transforms.ToTensor(), normalize]

    train = datasets.CIFAR10('./data', train=True, download=True,
        transform=transforms.Compose(train_steps))
    test = datasets.CIFAR10('./data', train=False,
        transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if num_examples:
        # The two splits can have different sizes, so clamp per split; indices
        # past the end of a split would fail when the loader is iterated.
        train = data.Subset(train, list(range(min(num_examples, len(train)))))
        test = data.Subset(test, list(range(min(num_examples, len(test)))))

    num_workers = min(multiprocessing.cpu_count(), 6)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=num_workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=num_workers)
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = mean
    test_loader.mean = mean
    return train_loader, test_loader

def svhn_loaders(batch_size, shuffle_train = True, shuffle_test = False, train_random_transform = False, normalize_input = False, num_examples = None, test_batch_size=None):
    """Build train and test DataLoaders for SVHN (stored under ``./data``).

    Before returning, the statistics of the training loader are measured with
    ``get_stats`` and printed; this iterates over the training loader and can
    therefore be slow.

    Args:
        batch_size: batch size of the train loader, and of the test loader
            when ``test_batch_size`` is not given.
        shuffle_train: whether the train loader shuffles.
        shuffle_test: whether the test loader shuffles.
        train_random_transform: if true, training images get a random 32x32
            crop with padding 4.
        normalize_input: if true, normalize with the fixed per-channel
            mean/std below; otherwise use mean 0 and std 1 (a no-op).
        num_examples: if truthy, keep only the first ``num_examples`` samples
            of each split (capped at the size of that split).
        test_batch_size: if truthy, batch size used for the test loader.

    Returns:
        ``(train_loader, test_loader)``; both carry ``mean`` and ``std``
        attributes holding the normalization constants in use (not the
        measured statistics that are printed).
    """
    if normalize_input:
        mean = [0.43768206, 0.44376972, 0.47280434]
        std = [0.19803014, 0.20101564, 0.19703615]
    else:
        std = [1.0, 1.0, 1.0]
        mean = [0, 0, 0]
    normalize = transforms.Normalize(mean = mean, std = std)

    # With normalize_input=False the Normalize step subtracts 0 and divides
    # by 1, so it can always be appended without changing the pixel values.
    train_steps = []
    if train_random_transform:
        train_steps.append(transforms.RandomCrop(32, 4))
    train_steps += [transforms.ToTensor(), normalize]

    train = datasets.SVHN('./data', split='train', download=True,
        transform=transforms.Compose(train_steps))
    test = datasets.SVHN('./data', split='test', download=True,
        transform=transforms.Compose([transforms.ToTensor(), normalize]))

    if num_examples:
        # The two splits can have different sizes, so clamp per split; indices
        # past the end of a split would fail when the loader is iterated.
        train = data.Subset(train, list(range(min(num_examples, len(train)))))
        test = data.Subset(test, list(range(min(num_examples, len(test)))))

    num_workers = min(multiprocessing.cpu_count(), 6)
    train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size,
        shuffle=shuffle_train, pin_memory=True, num_workers=num_workers)
    if test_batch_size:
        batch_size = test_batch_size
    test_loader = torch.utils.data.DataLoader(test, batch_size=max(batch_size, 1),
        shuffle=shuffle_test, pin_memory=True, num_workers=num_workers)
    train_loader.std = std
    test_loader.std = std
    train_loader.mean = mean
    test_loader.mean = mean

    # Report the statistics of the data as actually served by the train
    # loader; kept in separate names so the constants above are not shadowed.
    measured_mean, measured_std = get_stats(train_loader)
    print('dataset mean = ', measured_mean.numpy(), 'std = ', measured_std.numpy())
    return train_loader, test_loader

# Registry mapping a dataset name to the function that builds its
# (train_loader, test_loader) pair. Every new loader must be registered here.
loaders = {
    "mnist": partial(mnist_loaders, datasets.MNIST),
    "fashion-mnist": partial(mnist_loaders, datasets.FashionMNIST),
    "cifar": cifar_loaders,
    "svhn": svhn_loaders,
}