import numpy as np
import torch
import torchnet as tnt
from torchvision.datasets.mnist import MNIST

import config


def _shift_slices(shift, size):
    """Return ``(dst, src)`` slices that translate one axis of length ``size``.

    Assigning ``out[dst] = x[src]`` moves content by ``shift`` positions
    (positive ``shift`` moves toward higher indices). The overlap length is
    clamped at zero, so a shift of ``size`` or more yields two empty slices.
    This avoids negative stop indices, which would wrap around and produce
    slices of mismatched length.
    """
    shift = int(shift)
    overlap = max(0, size - abs(shift))
    dst_start = max(0, shift)
    src_start = max(0, -shift)
    return (slice(dst_start, dst_start + overlap),
            slice(src_start, src_start + overlap))


def augmentation(x, max_shift=2):
    """Randomly translate images by up to ``max_shift`` pixels.

    One vertical and one horizontal offset are drawn uniformly from
    ``[-max_shift, max_shift]`` using NumPy's global RNG. The same offsets
    are applied to every image in ``x``. Pixels shifted in from outside the
    frame are zero.

    Args:
        x: Tensor whose last two dimensions are (height, width), e.g.
            ``(N, C, H, W)`` or ``(N, H, W)``.
        max_shift: Maximum absolute shift in pixels along each axis.

    Returns:
        A float32 tensor with the same shape and device as ``x``.
    """
    height, width = x.shape[-2:]

    h_shift, w_shift = np.random.randint(-max_shift, max_shift + 1, size=2)
    dst_h, src_h = _shift_slices(h_shift, height)
    dst_w, src_w = _shift_slices(w_shift, width)

    # Allocate on x's device so GPU batches are not silently moved to the CPU.
    shifted = torch.zeros(x.shape, dtype=torch.float32, device=x.device)
    shifted[..., dst_h, dst_w] = x[..., src_h, src_w]
    return shifted


def get_iterator(mode, root='./data', num_workers=4):
    """Build a batched loader over the MNIST training or test split.

    Args:
        mode: Truthy selects the training split, which is also shuffled.
            Falsy selects the test split, which is not shuffled.
        root: Directory MNIST is downloaded to / loaded from.
        num_workers: Forwarded to torchnet's ``Dataset.parallel``.

    Returns:
        The result of ``TensorDataset.parallel`` with ``config.BATCH_SIZE``
        examples per batch.
    """
    dataset = MNIST(root=root, train=mode, download=True)

    if hasattr(dataset, 'data'):
        # Current torchvision exposes split-agnostic ``data`` / ``targets``.
        # The split-specific names below are deprecated, warning aliases there.
        data, labels = dataset.data, dataset.targets
    else:
        # Older torchvision releases only provide the split-specific names.
        split = 'train' if mode else 'test'
        data = getattr(dataset, split + '_data')
        labels = getattr(dataset, split + '_labels')

    tensor_dataset = tnt.dataset.TensorDataset([data, labels])
    return tensor_dataset.parallel(batch_size=config.BATCH_SIZE,
                                   num_workers=num_workers,
                                   shuffle=mode)


if __name__ == "__main__":
    # Manual smoke test: show a random single-image batch before and after
    # a random shift.
    sample = torch.rand(1, 1, 28, 28)
    print(sample)
    print(augmentation(sample))