#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Provides access to the CIFAR-10 dataset, including simple data augmentation.

Author: Jan Schlüter
"""
import os
import sys

import numpy as np


def download_dataset(path, source='https://www.cs.toronto.edu/~kriz/'
                                  'cifar-10-python.tar.gz'):
    """
    Downloads and extracts the dataset, if needed.

    Does nothing if all five training batches and the test batch already
    exist below `path`/cifar-10-batches-py. Otherwise streams the gzipped
    tarball from `source` and extracts it into `path`, creating `path` first
    if it does not exist.

    :param path: directory that holds (or will hold) ``cifar-10-batches-py``
    :param source: URL of the ``cifar-10-python.tar.gz`` archive
    """
    files = ['data_batch_%d' % (i + 1) for i in range(5)] + ['test_batch']
    batch_dir = os.path.join(path, 'cifar-10-batches-py')
    if all(os.path.exists(os.path.join(batch_dir, fn)) for fn in files):
        return  # dataset is already complete

    print("Downloading and extracting %s into %s..." % (source, path))
    if sys.version_info[0] == 2:
        from urllib import urlopen
    else:
        from urllib.request import urlopen
    import tarfile
    if not os.path.exists(path):
        os.makedirs(path)
    u = urlopen(source)
    try:
        # 'r|gz' reads the archive as a forward-only stream, so it can be
        # extracted directly from the response without a temporary file
        with tarfile.open(fileobj=u, mode='r|gz') as f:
            if hasattr(tarfile, 'data_filter'):
                # the archive comes from the network: refuse members with
                # absolute paths, '..' components or links escaping `path`
                # (extraction filters exist in Python 3.12+ and backports)
                f.extractall(path=path, filter='data')
            else:
                f.extractall(path=path)
    finally:
        # release the connection even if downloading or extracting fails
        u.close()


def _unpickle(filename):
    """
    Loads one pickled batch file of the CIFAR-10 python distribution.

    ``np.load`` refuses pickled files by default (``allow_pickle=False``
    since NumPy 1.16.3), so the batches are read with :mod:`pickle` directly.
    The files were pickled under Python 2; on Python 3 ``encoding='latin1'``
    is required so that pickled NumPy arrays are restored correctly.

    NOTE(security): unpickling can execute arbitrary code -- only use this on
    files obtained from a trusted source.
    """
    import pickle
    with open(filename, 'rb') as f:
        if sys.version_info[0] == 2:
            return pickle.load(f)
        return pickle.load(f, encoding='latin1')


def load_dataset(path):
    """
    Loads CIFAR-10 from `path`, downloading and extracting it first if needed.

    Images are reshaped to (N, 3, 32, 32) and normalized per channel with
    the mean and standard deviation of the training images. These statistics
    are cached in `path`/cifar-10-mean_std.npz and reused on later calls.

    :param path: directory holding (or to receive) the dataset
    :returns: ``(X_train, y_train, X_test, y_test)``; labels are int8 arrays
    """
    download_dataset(path)
    batch_dir = os.path.join(path, 'cifar-10-batches-py')

    # training data
    data = [_unpickle(os.path.join(batch_dir, 'data_batch_%d' % (i + 1)))
            for i in range(5)]
    X_train = np.vstack([d['data'] for d in data])
    y_train = np.hstack([np.asarray(d['labels'], np.int8) for d in data])

    # test data
    data = _unpickle(os.path.join(batch_dir, 'test_batch'))
    X_test = data['data']
    y_test = np.asarray(data['labels'], np.int8)

    # reshape flat rows into (N, 3, 32, 32) images
    X_train = X_train.reshape(-1, 3, 32, 32)
    X_test = X_test.reshape(-1, 3, 32, 32)

    # normalize with per-channel training statistics, cached on disk so they
    # are only computed once
    mean_std_file = os.path.join(path, 'cifar-10-mean_std.npz')
    try:
        # the context manager closes the underlying .npz file handle
        with np.load(mean_std_file) as mean_std:
            mean = mean_std['mean']
            std = mean_std['std']
    except IOError:
        mean = X_train.mean(axis=(0, 2, 3), keepdims=True).astype(np.float32)
        std = X_train.std(axis=(0, 2, 3), keepdims=True).astype(np.float32)
        np.savez(mean_std_file, mean=mean, std=std)
    X_train = (X_train - mean) / std
    X_test = (X_test - mean) / std

    return X_train, y_train, X_test, y_test


def iterate_minibatches(inputs, targets, batchsize, shuffle=False):
    """
    Yields one epoch of ``(inputs, targets)`` minibatches.

    Batches are consecutive, non-overlapping and hold exactly `batchsize`
    examples; a trailing remainder smaller than `batchsize` is skipped.
    With `shuffle`, examples are visited in an order permuted by
    ``np.random.shuffle`` and gathered through an index array; otherwise each
    batch is selected with a plain slice.
    """
    assert len(inputs) == len(targets)
    total = len(inputs)
    if shuffle:
        permutation = np.arange(total)
        np.random.shuffle(permutation)
    else:
        permutation = None
    last_start = total - batchsize
    for offset in range(0, last_start + 1, batchsize):
        window = slice(offset, offset + batchsize)
        if permutation is not None:
            # swap the slice for the corresponding run of shuffled indices
            window = permutation[window]
        yield inputs[window], targets[window]


def augment_minibatches(minibatches, flip=0.5, trans=4):
    """
    Randomly augments images by horizontal flipping with a probability of
    `flip` and random translation of up to `trans` pixels in both directions.

    :param minibatches: iterable of ``(inputs, targets)`` pairs, where
        `inputs` is a 4-d array of shape (batchsize, channels, height, width)
    :param flip: probability of mirroring each image along its last (width)
        axis; 0 disables flipping
    :param trans: maximum shift in pixels, drawn independently for the
        height and width axes uniformly from [-trans, trans]; pixels uncovered
        by the shift are set to zero; 0 disables translation
    :returns: generator of ``(augmented_inputs, targets)`` pairs; targets are
        passed through unchanged
    """
    for inputs, targets in minibatches:
        batchsize, c, h, w = inputs.shape
        # remember the dtype now: after flipping, `inputs` is a list that may
        # be empty, so it can no longer be queried for it
        dtype = inputs.dtype
        if flip:
            coins = np.random.rand(batchsize) < flip
            # reversing the last axis mirrors an image horizontally (a view)
            inputs = [inp[:, :, ::-1] if coin else inp
                      for inp, coin in zip(inputs, coins)]
            if not trans:
                # reshape keeps the 4-d shape even for an empty batch
                inputs = np.asarray(inputs, dtype).reshape(batchsize, c, h, w)
        outputs = inputs
        if trans:
            outputs = np.empty((batchsize, c, h, w), dtype)
            # randint's upper bound is exclusive: use trans + 1 so shifts are
            # symmetric in [-trans, trans]
            shifts = np.random.randint(-trans, trans + 1, (batchsize, 2))
            for outp, inp, (dy, dx) in zip(outputs, inputs, shifts):
                # zero the strip uncovered by the shift, then narrow outp and
                # inp to the overlapping region that gets copied below
                if dy > 0:  # move content down
                    outp[:, :dy] = 0
                    outp = outp[:, dy:]
                    inp = inp[:, :-dy]
                elif dy < 0:  # move content up
                    outp[:, dy:] = 0
                    outp = outp[:, :dy]
                    inp = inp[:, -dy:]
                if dx > 0:  # move content right
                    outp[:, :, :dx] = 0
                    outp = outp[:, :, dx:]
                    inp = inp[:, :, :-dx]
                elif dx < 0:  # move content left
                    outp[:, :, dx:] = 0
                    outp = outp[:, :, :dx]
                    inp = inp[:, :, -dx:]
                outp[:] = inp
        yield outputs, targets