"""
Download all datasets and compute preprocessing structures (whitening matrices, etc).
"""
import torch
import torchvision
from nice.utils import rescale


def zca_matrix(data_tensor, eps=1e-5):
    """
    Helper function: compute ZCA whitening matrix across a dataset ~ (N, C, H, W).

    Every sample is flattened to a vector of D features and the returned matrix is

        W = U @ diag(1 / sqrt(S + eps)) @ U^T,    where U, S, _ = SVD(X^T X / N),

    i.e. the (regularized) inverse square root of the feature covariance.

    Args:
        data_tensor: tensor whose first dimension indexes the N samples.
        eps: small constant added to the singular values before the inverse square
            root, so that near-degenerate directions do not produce inf/huge entries.

    Returns:
        A symmetric (D, D) tensor, D being the number of features per sample.

    Raises:
        ValueError: if `data_tensor` contains no samples.
    """
    num_samples = data_tensor.shape[0]
    if num_samples == 0:
        raise ValueError("zca_matrix() requires at least one sample")

    # 1. flatten dataset to (N, D); reshape also accepts non-contiguous input:
    X = data_tensor.reshape(num_samples, -1)

    # 2. rescale the matrix with the project helper:
    # NOTE(review): `rescale` comes from nice.utils and is not visible here; presumably it
    # maps values into [-1, 1], which is not the same as per-feature mean-centering -- confirm.
    X = rescale(X, -1., 1.)

    # 3. compute the (D, D) covariance; dividing by N makes this a covariance estimate rather
    #    than a raw scatter matrix, so whitened data ends up with ~unit variance:
    cov = (torch.t(X) @ X) / num_samples

    # 4. cov is symmetric positive semi-definite, so its singular values S are its eigenvalues
    #    and U holds the eigenvectors. Whitening needs cov^(-1/2), hence rsqrt (not 1/S):
    U, S, _ = torch.svd(cov)
    inv_sqrt = torch.rsqrt(S + eps)

    # scaling the columns of U is equivalent to U @ diag(inv_sqrt), without building the diagonal:
    return (U * inv_sqrt) @ torch.t(U)


def main():
    """
    Download the CIFAR10, SVHN and MNIST training sets, then compute and save ZCA
    whitening matrices for CIFAR10 and SVHN inside their dataset directories.
    """
    to_tensor = torchvision.transforms.ToTensor()

    ### download training datasets:
    print("Downloading CIFAR10...")
    cifar10 = torchvision.datasets.CIFAR10(root="./datasets/cifar", train=True,
                                           transform=to_tensor, download=True)
    print("Downloading SVHN...")
    svhn = torchvision.datasets.SVHN(root="./datasets/svhn", split='train',
                                     transform=to_tensor, download=True)
    print("Downloading MNIST...")
    # MNIST is only downloaded; no whitening matrix is computed for it below.
    torchvision.datasets.MNIST(root="./datasets/mnist", train=True,
                               transform=to_tensor, download=True)

    ### save ZCA whitening matrices:
    # Each dataset item is an (image, label) pair with a (C, H, W) image tensor. torch.stack
    # adds a leading sample axis -> (N, C, H, W), the layout zca_matrix documents. torch.cat
    # along dim 0 would instead merge samples and channels into (N*C, H, W).
    print("Computing CIFAR10 ZCA matrix...")
    cifar10_images = torch.stack([x for (x, _) in cifar10], dim=0)
    torch.save(zca_matrix(cifar10_images), "./datasets/cifar/zca_matrix.pt")
    del cifar10_images  # release before materializing the next dataset

    print("Computing SVHN ZCA matrix...")
    svhn_images = torch.stack([x for (x, _) in svhn], dim=0)
    torch.save(zca_matrix(svhn_images), "./datasets/svhn/zca_matrix.pt")

    print("...All done.")

# Run the download / preprocessing pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()