# Copyright 2017 Max Planck Society
# Distributed under the BSD-3 Software license,
# (See accompanying file ./LICENSE.txt or copy at
# https://opensource.org/licenses/BSD-3-Clause)
"""This class helps to handle the data.

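A minimal usage sketch (an illustration only; it assumes the MNIST idx files
have already been downloaded into './mnist'):

    opts = {'dataset': 'mnist', 'data_dir': 'mnist',
            'input_normalize_sym': False}
    data = DataHandler(opts)
    batch = data.data[:64]        # (64, 28, 28, 1) array with values in [0, 1]
    labels = data.labels[:64]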
"""

import os
import random
import logging
import tensorflow as tf
import numpy as np
from six.moves import cPickle
import utils
import PIL
from utils import ArraySaver
from PIL import Image
import sys

datashapes = {
    'mnist': [28, 28, 1],
    'cifar10': [32, 32, 3],
    'celebA': [64, 64, 3],
    'grassli': [64, 64, 3],
    'dsprites': [64, 64, 1],
}

def _data_dir(opts):
    if opts['data_dir'].startswith("/"):
        return opts['data_dir']
    else:
        return os.path.join('./', opts['data_dir'])

def load_cifar_batch(fpath, label_key='labels'):
    """Internal utility for parsing CIFAR data.

    # Arguments
        fpath: path to the file to parse.
        label_key: key for the label data in the retrieved
            dictionary.

    # Returns
        A tuple `(data, labels)`.
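
    # Example
        A sketch only, assuming a CIFAR-10 batch file downloaded to './cifar10':

            data, labels = load_cifar_batch('./cifar10/data_batch_1')
            # data: uint8 array of shape (10000, 3, 32, 32); labels: list of 10000 ints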
    """
    f = utils.o_gfile(fpath, 'rb')
    if sys.version_info < (3,):
        d = cPickle.load(f)
    else:
        d = cPickle.load(f, encoding='bytes')
        # decode utf8
        d_decoded = {}
        for k, v in d.items():
            d_decoded[k.decode('utf8')] = v
        d = d_decoded
    f.close()
    data = d['data']
    labels = d[label_key]

    data = data.reshape(data.shape[0], 3, 32, 32)
    return data, labels

def transform_mnist(pic, mode='n'):
    """Take an MNIST picture normalized into [0, 1] and transform
        it according to the mode:
        n               -   add Gaussian noise
        i               -   invert colours
        sl/sr/su/sd     -   shift left / right / up / down by a few pixels
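
        For example (a sketch; `pic` can be any (28, 28, 1) array in [0, 1]):

            noisy   = transform_mnist(pic, mode='n')   # additive Gaussian noise
            shifted = transform_mnist(pic, mode='sl')  # shift left by 3-7 pixels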
    """
    pic = np.copy(pic)
    if mode == 'n':
        noise = np.random.randn(28, 28, 1)
        return np.clip(pic + 0.25 * noise, 0, 1)
    elif mode == 'i':
        return 1. - pic
    pixels = 3 + np.random.randint(5)
    if mode == 'sl':
        pic[:, :-pixels] = pic[:, pixels:] + 0.0
        pic[:, -pixels:] = 0.
    elif mode == 'sr':
        pic[:, pixels:] = pic[:, :-pixels] + 0.0
        pic[:, :pixels] = 0.
    elif mode == 'sd':
        pic[pixels:, :] = pic[:-pixels, :] + 0.0
        pic[:pixels, :] = 0.
    elif mode == 'su':
        pic[:-pixels, :] = pic[pixels:, :] + 0.0
        pic[-pixels:, :] = 0.
    return pic


class Data(object):
    """
    If the dataset can be loaded into memory quickly, self.X contains an np.ndarray.
    Otherwise we read the files from disk as we train. In that case self.X stays None
    and the following attributes are used instead:
        self.paths          list of paths to the files containing pictures
        self.dict_loaded    dictionary of (key, val) pairs, where key is the index of
                            an already loaded datapoint and val is its position
                            in self.loaded
        self.loaded         list of the pictures loaded so far
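
    A minimal sketch of both modes (illustration only; `opts` must carry the keys
    read in __init__, the jpg names below are made up, and disk reads are currently
    only implemented for celebA):

        in_memory = Data(opts, np.random.rand(100, 28, 28, 1))
        batch = in_memory[:16]                       # plain ndarray slicing

        lazy = Data(opts, None, paths=['000001.jpg', '000002.jpg'])
        pair = lazy[[0, 1]]                          # read from disk on first access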
    """
    def __init__(self, opts, X, paths=None, dict_loaded=None, loaded=None):
        """
        X is either an np.ndarray or None; in the latter case `paths` must be provided
        """
        data_dir = _data_dir(opts)
        self.X = None
        self.normalize = opts['input_normalize_sym']
        self.paths = None
        self.dict_loaded = None
        self.loaded = None
        if isinstance(X, np.ndarray):
            self.X = X
            self.shape = X.shape
        else:
            assert isinstance(data_dir, str), 'Data directory not provided'
            assert paths is not None and len(paths) > 0, 'No paths provided for the data'
            self.data_dir = data_dir
            self.paths = paths[:]
            self.dict_loaded = {} if dict_loaded is None else dict_loaded
            self.loaded = [] if loaded is None else loaded
            self.crop_style = opts['celebA_crop']
            self.dataset_name = opts['dataset']
            self.shape = (len(self.paths), None, None, None)

    def __len__(self):
        if isinstance(self.X, np.ndarray):
            return len(self.X)
        else:
            # The dataset was too large to fit in memory
            return len(self.paths)

    def drop_loaded(self):
        if not isinstance(self.X, np.ndarray):
            self.dict_loaded = {}
            self.loaded = []

    def __getitem__(self, key):
        if isinstance(self.X, np.ndarray):
            return self.X[key]
        else:
            # The dataset was too large to fit in memory
            if isinstance(key, int):
                keys = [key]
            elif isinstance(key, list):
                keys = key
            elif isinstance(key, np.ndarray):
                keys = list(key)
            elif isinstance(key, slice):
                start = key.start
                stop = key.stop
                step = key.step
                start = start if start is not None else 0
                if start < 0:
                    start += len(self.paths)
                # A slice's stop is exclusive, so default to len(self.paths)
                stop = stop if stop is not None else len(self.paths)
                if stop < 0:
                    stop += len(self.paths)
                step = step if step is not None else 1
                keys = range(start, stop, step)
            else:
                raise Exception(
                    'Indexing with %s is not yet supported for the dataset' % type(key))
            res = []
            new_keys = []
            new_points = []
            for key in keys:
                if key in self.dict_loaded:
                    idx = self.dict_loaded[key]
                    res.append(self.loaded[idx])
                else:
                    if self.dataset_name == 'celebA':
                        point = self._read_celeba_image(self.data_dir, self.paths[key])
                    else:
                        raise Exception('Disk read for this dataset is not implemented yet')
                    if self.normalize:
                        point = (point - 0.5) * 2.
                    res.append(point)
                    new_points.append(point)
                    new_keys.append(key)
            n = len(self.loaded)
            for cnt, key in enumerate(new_keys):
                self.dict_loaded[key] = n + cnt
            self.loaded.extend(new_points)
            return np.array(res)

    def _read_celeba_image(self, data_dir, filename):
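        # Raw CelebA jpgs are 178 x 218; 'closecrop' cuts a central 140 x 140
        # window while 'resizecrop' resizes first, and both produce 64 x 64 output.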
        width = 178
        height = 218
        new_width = 140
        new_height = 140
        im = Image.open(utils.o_gfile((data_dir, filename), 'rb'))
        if self.crop_style == 'closecrop':
            # This method was used in DCGAN, pytorch-gan-collection, AVB, ...
            left = (width - new_width) // 2
            top = (height - new_height) // 2
            right = (width + new_width) // 2
            bottom = (height + new_height) // 2
            im = im.crop((left, top, right, bottom))
            im = im.resize((64, 64), PIL.Image.ANTIALIAS)
        elif self.crop_style == 'resizecrop':
            # This method was used in ALI, AGE, ...
            im = im.resize((64, 78), PIL.Image.ANTIALIAS)
            im = im.crop((0, 7, 64, 64 + 7))
        else:
            raise Exception('Unknown crop style specified')
        return np.array(im).reshape(64, 64, 3) / 255.

class DataHandler(object):
    """A class storing and manipulating the dataset.

    In this code we assume a data point is a 3-dimensional array: for
    instance a 28*28 grayscale picture corresponds to (28, 28, 1),
    a 16*16 picture with 3 channels corresponds to (16, 16, 3), and a 2d point
    corresponds to (2, 1, 1). The shape is stored in self.data_shape.
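
    After construction the following attributes are available:
        data, test_data        Data objects with the train / held-out split
                               (some loaders leave test_data as None)
        labels, test_labels    integer labels where the loader provides them
                               (None otherwise)
        num_points             number of training points

    A minimal sketch (assuming the files already sit in opts['data_dir']):

        data = DataHandler(opts)
        print(data.num_points, data.data_shape)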
    """


    def __init__(self, opts):
        self.data_shape = None
        self.num_points = None
        self.data = None
        self.test_data = None
        self.labels = None
        self.test_labels = None
        self._load_data(opts)

    def _load_data(self, opts):
        """Load a dataset and fill all the necessary variables.

        """
        if opts['dataset'] == 'mnist':
            self._load_mnist(opts)
        elif opts['dataset'] == 'dsprites':
            self._load_dsprites(opts)
        elif opts['dataset'] == 'mnist_mod':
            self._load_mnist(opts, modified=True)
        elif opts['dataset'] == 'zalando':
            self._load_mnist(opts, zalando=True)
        elif opts['dataset'] == 'mnist3':
            self._load_mnist3(opts)
        elif opts['dataset'] == 'gmm':
            self._load_gmm(opts)
        elif opts['dataset'] == 'circle_gmm':
            self._load_mog(opts)
        elif opts['dataset'] == 'guitars':
            self._load_guitars(opts)
        elif opts['dataset'] == 'cifar10':
            self._load_cifar(opts)
        elif opts['dataset'] == 'celebA':
            self._load_celebA(opts)
        elif opts['dataset'] == 'grassli':
            self._load_grassli(opts)
        else:
            raise ValueError('Unknown dataset %s' % opts['dataset'])

        sym_applicable = ['mnist',
                          'dsprites',
                          'mnist3',
                          'guitars',
                          'cifar10',
                          'celebA',
                          'grassli']

        if opts['input_normalize_sym'] and opts['dataset'] not in sym_applicable:
            raise Exception('Cannot normalize this dataset')

        if opts['input_normalize_sym'] and opts['dataset'] in sym_applicable:
            # Normalize data to [-1, 1]
            if isinstance(self.data.X, np.ndarray):
                self.data.X = (self.data.X - 0.5) * 2.
                self.test_data.X = (self.test_data.X - 0.5) * 2.
            # Otherwise we will normalize while reading from disk


    def _load_mog(self, opts):
        """Sample data from the mixture of Gaussians on circle.

        """

        # Only use this setting in dimension 2
        assert opts['toy_dataset_dim'] == 2

        # First we fix the seed so the GMM parameters are chosen deterministically
        radius = opts['gmm_max_val']
        modes_num = opts["gmm_modes_num"]
        np.random.seed(opts["random_seed"])

        thetas = np.linspace(0, 2 * np.pi, modes_num)
        mixture_means = np.stack((radius * np.sin(thetas), radius * np.cos(thetas)), axis=1)
        mixture_variance = 0.01

        # Now we sample the points; reset to a random seed first
        np.random.seed()
        num = opts['toy_dataset_size']
        X = np.zeros((num, opts['toy_dataset_dim'], 1, 1))
        for idx in range(num):
            comp_id = np.random.randint(modes_num)
            mean = mixture_means[comp_id]
            cov = mixture_variance * np.identity(opts["toy_dataset_dim"])
            X[idx, :, 0, 0] = np.random.multivariate_normal(mean, cov, 1)

        self.data_shape = (opts['toy_dataset_dim'], 1, 1)
        self.data = Data(opts, X)
        self.num_points = len(X)

    def _load_gmm(self, opts):
        """Sample data from the mixture of Gaussians.

        """

        logging.debug('Loading GMM dataset...')
        # First we fix the seed so the GMM parameters are chosen deterministically
        modes_num = opts["gmm_modes_num"]
        np.random.seed(opts["random_seed"])
        max_val = opts['gmm_max_val']
        mixture_means = np.random.uniform(
            low=-max_val, high=max_val,
            size=(modes_num, opts['toy_dataset_dim']))

        def variance_factor(num, dim):
            if num == 1: return 3 ** (2. / dim)
            if num == 2: return 3 ** (2. / dim)
            if num == 3: return 8 ** (2. / dim)
            if num == 4: return 20 ** (2. / dim)
            if num == 5: return 10 ** (2. / dim)
            return num ** 2.0 * 3

        mixture_variance = \
                max_val / variance_factor(modes_num, opts['toy_dataset_dim'])

        # Now we sample the points; reset to a random seed first
        np.random.seed()
        num = opts['toy_dataset_size']
        X = np.zeros((num, opts['toy_dataset_dim'], 1, 1))
        for idx in range(num):
            comp_id = np.random.randint(modes_num)
            mean = mixture_means[comp_id]
            cov = mixture_variance * np.identity(opts["toy_dataset_dim"])
            X[idx, :, 0, 0] = np.random.multivariate_normal(mean, cov, 1)

        self.data_shape = (opts['toy_dataset_dim'], 1, 1)
        self.data = Data(opts, X)
        self.num_points = len(X)

        logging.debug('Loading GMM dataset done!')

    def _load_guitars(self, opts):
        """Load data from Thomann files.

        """
        logging.debug('Loading Guitars dataset')
        data_dir = os.path.join('./', 'thomann')
        X = None
        files = utils.listdir(data_dir)
        pics = []
        for f in sorted(files):
            if '.jpg' in f and f[0] != '.':
                im = Image.open(utils.o_gfile((data_dir, f), 'rb'))
                res = np.array(im.getdata()).reshape(128, 128, 3)
                pics.append(res)
        X = np.array(pics)

        seed = 123
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed()

        self.data_shape = (128, 128, 3)
        self.data = Data(opts, X/255.)
        self.num_points = len(X)

        logging.debug('Loading Done.')

    def _load_dsprites(self, opts):
        """Load data from dsprites dataset

        """
        logging.debug('Loading dsprites')
        data_dir = _data_dir(opts)
        data_file = os.path.join(data_dir, 'dsprites.npz')
        X = np.load(data_file)['imgs']
        X = X[:, :, :, None]

        seed = 123
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed()

        self.data_shape = (64, 64, 1)
        test_size = 10000

        self.data = Data(opts, X[:-test_size])
        self.test_data = Data(opts, X[-test_size:])
        self.num_points = len(self.data)

        logging.debug('Loading Done.')

    def _load_mnist(self, opts, zalando=False, modified=False):
        """Load data from MNIST or ZALANDO files.

        """
        if zalando:
            logging.debug('Loading Fashion MNIST')
        elif modified:
            logging.debug('Loading modified MNIST')
        else:
            logging.debug('Loading MNIST')
        data_dir = _data_dir(opts)
        # pylint: disable=invalid-name
        # Let us use all the bad variable names!
        tr_X = None
        tr_Y = None
        te_X = None
        te_Y = None

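        # Each idx file starts with a header before the raw bytes: 16 bytes for
        # the image files (magic, count, rows, cols) and 8 bytes for the label
        # files (magic, count), hence the [16:] and [8:] offsets below.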
        with utils.o_gfile((data_dir, 'train-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            tr_X = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float)

        with utils.o_gfile((data_dir, 'train-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            tr_Y = loaded[8:].reshape((60000)).astype(np.int)

        with utils.o_gfile((data_dir, 't10k-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_X = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float)

        with utils.o_gfile((data_dir, 't10k-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_Y = loaded[8:].reshape((10000)).astype(np.int)

        tr_Y = np.asarray(tr_Y)
        te_Y = np.asarray(te_Y)

        X = np.concatenate((tr_X, te_X), axis=0)
        y = np.concatenate((tr_Y, te_Y), axis=0)
        X = X / 255.

        seed = 123
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.shuffle(y)
        np.random.seed()

        self.data_shape = (28, 28, 1)
        test_size = 10000

        if modified:
            self.original_mnist = X
            n = opts['toy_dataset_size']
            n += test_size
            points = []
            labels = []
            for _ in range(n):
                idx = np.random.randint(len(X))
                point = X[idx]
                modes = ['n', 'i', 'sl', 'sr', 'su', 'sd']
                mode = modes[np.random.randint(len(modes))]
                point = transform_mnist(point, mode)
                points.append(point)
                labels.append(y[idx])
            X = np.array(points)
            y = np.array(labels)
        self.data = Data(opts, X[:-test_size])
        self.test_data = Data(opts, X[-test_size:])
        self.labels = y[:-test_size]
        self.test_labels = y[-test_size:]
        self.num_points = len(self.data)

        logging.debug('Loading Done.')

    def _load_mnist3(self, opts):
        """Load data from MNIST files.

        """
        logging.debug('Loading 3-digit MNIST')
        data_dir = _data_dir(opts)
        # pylint: disable=invalid-name
        # Let us use all the bad variable names!
        tr_X = None
        tr_Y = None
        te_X = None
        te_Y = None

        with utils.o_gfile((data_dir, 'train-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            tr_X = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float)

        with utils.o_gfile((data_dir, 'train-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            tr_Y = loaded[8:].reshape((60000)).astype(np.int)

        with utils.o_gfile((data_dir, 't10k-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_X = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float)

        with utils.o_gfile((data_dir, 't10k-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_Y = loaded[8:].reshape((10000)).astype(np.int)

        tr_Y = np.asarray(tr_Y)
        te_Y = np.asarray(te_Y)

        X = np.concatenate((tr_X, te_X), axis=0)
        y = np.concatenate((tr_Y, te_Y), axis=0)

        num = opts['mnist3_dataset_size']
        ids = np.random.choice(len(X), (num, 3), replace=True)
        if opts['mnist3_to_channels']:
            # Concatenate 3 digits into 3 channels
            X3 = np.zeros((num, 28, 28, 3))
            y3 = np.zeros(num)
            for idx, _id in enumerate(ids):
                X3[idx, :, :, 0] = np.squeeze(X[_id[0]], axis=2)
                X3[idx, :, :, 1] = np.squeeze(X[_id[1]], axis=2)
                X3[idx, :, :, 2] = np.squeeze(X[_id[2]], axis=2)
                y3[idx] = y[_id[0]] * 100 + y[_id[1]] * 10 + y[_id[2]]
            self.data_shape = (28, 28, 3)
        else:
            # Concatenate 3 digits side by side along the width
            X3 = np.zeros((num, 28, 3 * 28, 1))
            y3 = np.zeros(num)
            for idx, _id in enumerate(ids):
                X3[idx, :, 0:28, 0] = np.squeeze(X[_id[0]], axis=2)
                X3[idx, :, 28:56, 0] = np.squeeze(X[_id[1]], axis=2)
                X3[idx, :, 56:84, 0] = np.squeeze(X[_id[2]], axis=2)
                y3[idx] = y[_id[0]] * 100 + y[_id[1]] * 10 + y[_id[2]]
            self.data_shape = (28, 28 * 3, 1)

        self.data = Data(opts, X3/255.)
        y3 = y3.astype(int)
        self.labels = y3
        self.num_points = num

        logging.debug('Training set JS=%.4f' % utils.js_div_uniform(y3))
        logging.debug('Loading Done.')

    def _load_cifar(self, opts):
        """Load CIFAR10

        """
        logging.debug('Loading CIFAR10 dataset')

        num_train_samples = 50000
        data_dir = _data_dir(opts)
        x_train = np.zeros((num_train_samples, 3, 32, 32), dtype='uint8')
        y_train = np.zeros((num_train_samples,), dtype='uint8')

        for i in range(1, 6):
            fpath = os.path.join(data_dir, 'data_batch_' + str(i))
            data, labels = load_cifar_batch(fpath)
            x_train[(i - 1) * 10000: i * 10000, :, :, :] = data
            y_train[(i - 1) * 10000: i * 10000] = labels

        fpath = os.path.join(data_dir, 'test_batch')
        x_test, y_test = load_cifar_batch(fpath)

        y_train = np.reshape(y_train, (len(y_train), 1))
        y_test = np.reshape(y_test, (len(y_test), 1))
        x_train = x_train.transpose(0, 2, 3, 1)
        x_test = x_test.transpose(0, 2, 3, 1)

        X = np.vstack([x_train, x_test])
        X = X/255.
        y = np.vstack([y_train, y_test])

        seed = 123
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.shuffle(y)
        np.random.seed()

        self.data_shape = (32, 32, 3)

        self.data = Data(opts, X[:-1000])
        self.test_data = Data(opts, X[-1000:])
        self.labels = y[:-1000]
        self.test_labels = y[-1000:]
        self.num_points = len(self.data)

        logging.debug('Loading Done.')

    def _load_celebA(self, opts):
        """Load CelebA
        """
        logging.debug('Loading CelebA dataset')

        num_samples = 202599

        datapoint_ids = list(range(1, num_samples + 1))
        paths = ['%.6d.jpg' % i for i in range(1, num_samples + 1)]
        seed = 123
        random.seed(seed)
        random.shuffle(paths)
        random.shuffle(datapoint_ids)
        random.seed()

        saver = ArraySaver('disk', workdir=opts['work_dir'])
        saver.save('shuffled_training_ids', datapoint_ids)

        self.data_shape = (64, 64, 3)
        test_size = 20000
        self.data = Data(opts, None, paths[:-test_size])
        self.test_data = Data(opts, None, paths[-test_size:])
        self.num_points = num_samples - test_size
        self.labels = np.array(self.num_points * [0])
        self.test_labels = np.array(test_size * [0])

        logging.debug('Loading Done.')

    def _load_grassli(self, opts):
        """Load grassli

        """
        logging.debug('Loading grassli dataset')

        data_dir = _data_dir(opts)
        X = np.load(utils.o_gfile((data_dir, 'grassli.npy'), 'rb')) / 255.

        seed = 123
        np.random.seed(seed)
        np.random.shuffle(X)
        np.random.seed(seed)
        np.random.seed()

        self.data_shape = (64, 64, 3)
        test_size = 5000

        self.data = Data(opts, X[:-test_size])
        self.test_data = Data(opts, X[-test_size:])
        self.num_points = len(self.data)

        logging.debug('Loading Done.')