import scipy
from glob import glob
import numpy as np
from keras.datasets import mnist
from skimage.transform import resize as imresize
import pickle
import os
import urllib
import gzip

class DataLoader():
    """Loads images from MNIST (domain A) and MNIST-M (domain B).

    On construction both datasets are normalized to ``[-1, 1]``, resized to
    ``img_res`` and cached as ``.npy`` files under ``data_dir`` so later runs
    can skip the preprocessing.

    NOTE: the cache files are not keyed by ``img_res``; if the resolution is
    changed, the previously cached arrays are still loaded as-is. Delete the
    cached ``.npy`` files to force regeneration.
    """

    # Directory (relative to the current working directory) that holds the
    # cached arrays and the downloaded MNIST-M pickle.
    data_dir = 'datasets'

    def __init__(self, img_res=(128, 128)):
        """Prepare both datasets.

        Args:
            img_res: target size passed to the image resize function for
                every image.
        """
        self.img_res = img_res

        self.mnistm_url = 'https://github.com/VanushVaswani/keras_mnistm/releases/download/1.0/keras_mnistm.pkl.gz'

        # MNIST must be set up first: MNIST-M reuses its labels.
        self.setup_mnist(img_res)
        self.setup_mnistm(img_res)

    def _path(self, filename):
        """Return the path of ``filename`` inside the dataset directory."""
        return os.path.join(self.data_dir, filename)

    def normalize(self, images):
        """Map pixel values from ``[0, 255]`` to float32 values in ``[-1, 1]``."""
        return images.astype(np.float32) / 127.5 - 1.

    def setup_mnist(self, img_res):
        """Populate ``self.mnist_X`` / ``self.mnist_y``.

        Loads the cached arrays when both cache files exist; otherwise
        builds them from the Keras MNIST training split and writes the cache.

        Args:
            img_res: target size passed to the image resize function.
        """
        print ("Setting up MNIST...")

        x_path = self._path('mnist_x.npy')
        y_path = self._path('mnist_y.npy')

        # Require BOTH files: a run interrupted between the two saves would
        # otherwise leave a cache that can never be loaded.
        if os.path.exists(x_path) and os.path.exists(y_path):
            self.mnist_X = np.load(x_path)
            self.mnist_y = np.load(y_path)
        else:
            # Load the dataset (training split only)
            (mnist_X, mnist_y), (_, _) = mnist.load_data()

            # Normalize and rescale images
            mnist_X = self.normalize(mnist_X)
            mnist_X = np.array([imresize(x, img_res) for x in mnist_X])
            # Add a channel axis and replicate it three times so the
            # grayscale digits have the same channel count as MNIST-M.
            mnist_X = np.expand_dims(mnist_X, axis=-1)
            mnist_X = np.repeat(mnist_X, 3, axis=-1)

            self.mnist_X, self.mnist_y = mnist_X, mnist_y

            # Save formatted images (create the cache directory on first use)
            os.makedirs(self.data_dir, exist_ok=True)
            np.save(x_path, self.mnist_X)
            np.save(y_path, self.mnist_y)

        print ("+ Done.")

    def _download_mnistm(self, pkl_path):
        """Download the gzipped MNIST-M pickle and decompress it to ``pkl_path``."""
        # Function-scope imports: the module only does ``import urllib``,
        # which does not guarantee the ``request`` submodule is loaded.
        import shutil
        from urllib.request import urlopen

        gz_path = pkl_path + '.gz'
        part_path = pkl_path + '.part'

        print('+ Downloading ' + self.mnistm_url)
        os.makedirs(self.data_dir, exist_ok=True)
        try:
            # Stream to disk instead of holding the whole archive in memory.
            with urlopen(self.mnistm_url) as response, \
                    open(gz_path, 'wb') as gz_f:
                shutil.copyfileobj(response, gz_f)
            with gzip.open(gz_path, 'rb') as zip_f, \
                    open(part_path, 'wb') as out_f:
                shutil.copyfileobj(zip_f, out_f)
            # Publish the pickle only once it is complete, so an interrupted
            # run never leaves a truncated file that later runs would trust.
            os.replace(part_path, pkl_path)
        finally:
            for leftover in (gz_path, part_path):
                if os.path.exists(leftover):
                    os.unlink(leftover)

    def setup_mnistm(self, img_res):
        """Populate ``self.mnistm_X`` / ``self.mnistm_y``.

        Loads the cached arrays when both cache files exist; otherwise
        downloads the MNIST-M pickle (if not already present), preprocesses
        its ``b'train'`` entry and writes the cache. Labels are a copy of
        ``self.mnist_y``, so ``setup_mnist`` must have run beforehand.

        Args:
            img_res: target size passed to the image resize function.
        """
        print ("Setting up MNIST-M...")

        x_path = self._path('mnistm_x.npy')
        y_path = self._path('mnistm_y.npy')

        if os.path.exists(x_path) and os.path.exists(y_path):
            self.mnistm_X = np.load(x_path)
            self.mnistm_y = np.load(y_path)
        else:
            # Download the MNIST-M pkl file
            pkl_path = self._path('keras_mnistm.pkl')
            if not os.path.exists(pkl_path):
                self._download_mnistm(pkl_path)

            # load MNIST-M images from pkl file
            # SECURITY NOTE: unpickling executes arbitrary code embedded in
            # the file; this trusts the content served at ``self.mnistm_url``.
            with open(pkl_path, "rb") as f:
                data = pickle.load(f, encoding='bytes')

            # Normalize and rescale images
            mnistm_X = np.array(data[b'train'])
            mnistm_X = self.normalize(mnistm_X)
            mnistm_X = np.array([imresize(x, img_res) for x in mnistm_X])

            # MNIST-M reuses the MNIST training labels.
            self.mnistm_X, self.mnistm_y = mnistm_X, self.mnist_y.copy()

            # Save formatted images
            os.makedirs(self.data_dir, exist_ok=True)
            np.save(x_path, self.mnistm_X)
            np.save(y_path, self.mnistm_y)

        print ("+ Done.")

    def load_data(self, domain, batch_size=1):
        """Return a random batch sampled with replacement.

        Args:
            domain: ``'A'`` selects MNIST; any other value selects MNIST-M.
            batch_size: number of samples to draw.

        Returns:
            Tuple ``(X[idx], y[idx])`` of images and their labels.
        """
        use_mnist = domain == 'A'
        X = self.mnist_X if use_mnist else self.mnistm_X
        y = self.mnist_y if use_mnist else self.mnistm_y

        # An integer population is equivalent to sampling from
        # ``range(len(X))`` without building a Python list on every call.
        idx = np.random.choice(len(X), size=batch_size)

        return X[idx], y[idx]