# -*- coding: utf-8 -*-
# File: mnist.py

import gzip
import numpy
import os

from ...utils import logger
from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow

__all__ = ['Mnist', 'FashionMnist']

def maybe_download(url, work_directory):
    """Return the local path for ``url``, fetching the file only if absent.

    The local file name is the last ``/``-separated component of ``url``,
    placed inside ``work_directory``.

    Args:
        url (str): remote location of the file.
        work_directory (str): directory that holds (or will receive) the file.

    Returns:
        str: ``work_directory`` joined with the file name taken from ``url``.
    """
    target = os.path.join(work_directory, url.rsplit('/', 1)[-1])
    # Guard clause: a file already at the target path is reused as-is.
    if os.path.exists(target):
        return target
    logger.info("Downloading to {}...".format(target))
    download(url, work_directory)
    return target

def _read32(bytestream):
    """Read one big-endian unsigned 32-bit integer from ``bytestream``.

    Args:
        bytestream: binary file-like object positioned at the integer.

    Returns:
        int: the decoded value as a plain Python int, so that products of
        header fields cannot wrap around the way fixed-width numpy scalars do.

    Raises:
        ValueError: if fewer than 4 bytes could be read.
    """
    raw = bytestream.read(4)
    if len(raw) != 4:
        raise ValueError(
            'Unexpected end of stream: expected 4 bytes, got %d' % len(raw))
    dt = numpy.dtype(numpy.uint32).newbyteorder('>')
    return int(numpy.frombuffer(raw, dtype=dt)[0])


def extract_images(filename):
    """Extract the images into a 4D float32 numpy array [index, y, x, depth].

    The header is four big-endian uint32 values (magic 2051, image count,
    rows, columns) followed by one byte per pixel. Pixel bytes are scaled
    from [0, 255] to [0, 1] and ``depth`` is always 1.

    Args:
        filename (str): path to the gzip-compressed image file.

    Returns:
        numpy.ndarray: float32 array of shape [num_images, rows, cols, 1].

    Raises:
        ValueError: if the magic number is not 2051, or if the header or the
            pixel data is shorter than the header announces.
    """
    with gzip.open(filename, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        num_images = _read32(bytestream)
        rows = _read32(bytestream)
        cols = _read32(bytestream)
        expected = num_images * rows * cols
        buf = bytestream.read(expected)
    # Fail with a clear message instead of an opaque reshape error.
    if len(buf) != expected:
        raise ValueError(
            'Truncated MNIST image file %s: expected %d pixel bytes, got %d' %
            (filename, expected, len(buf)))
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data.astype('float32') / 255.0


def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index].

    The header is two big-endian uint32 values (magic 2049, item count)
    followed by one byte per label.

    Args:
        filename (str): path to the gzip-compressed label file.

    Returns:
        numpy.ndarray: uint8 array of shape [num_items]. It is a view over
        the bytes read from the file, so it is read-only.

    Raises:
        ValueError: if the magic number is not 2049, or if the header or the
            label data is shorter than the header announces.
    """
    with gzip.open(filename, 'rb') as bytestream:
        magic = _read32(bytestream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        num_items = _read32(bytestream)
        buf = bytestream.read(num_items)
    # Without this check a truncated file would silently yield fewer labels.
    if len(buf) != num_items:
        raise ValueError(
            'Truncated MNIST label file %s: expected %d labels, got %d' %
            (filename, num_items, len(buf)))
    return numpy.frombuffer(buf, dtype=numpy.uint8)

class Mnist(RNGDataFlow):
    """
    Produces [image, label] in MNIST dataset,
    image is 28x28 in the range [0,1], label is an int.
    """

    _DIR_NAME = 'mnist_data'
    _SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'

    def __init__(self, train_or_test, shuffle=True, dir=None):
        """
        Args:
            train_or_test (str): either 'train' or 'test'
            shuffle (bool): shuffle the dataset
            dir (str): directory holding (or receiving) the ``.gz`` files.
                Defaults to ``get_dataset_path(self._DIR_NAME)``.
        """
        if dir is None:
            dir = get_dataset_path(self._DIR_NAME)
        assert train_or_test in ['train', 'test']
        self.train_or_test = train_or_test
        self.shuffle = shuffle

        def get_images_and_labels(image_file, label_file):
            # Fetch (if needed) and decode one image/label archive pair.
            f = maybe_download(self._SOURCE_URL + image_file, dir)
            images = extract_images(f)
            f = maybe_download(self._SOURCE_URL + label_file, dir)
            labels = extract_labels(f)
            assert images.shape[0] == labels.shape[0]
            return images, labels

        # NOTE(review): archive names follow the standard MNIST /
        # Fashion-MNIST distribution layout -- confirm against the mirror.
        if self.train_or_test == 'train':
            self.images, self.labels = get_images_and_labels(
                'train-images-idx3-ubyte.gz',
                'train-labels-idx1-ubyte.gz')
        else:
            self.images, self.labels = get_images_and_labels(
                't10k-images-idx3-ubyte.gz',
                't10k-labels-idx1-ubyte.gz')

    def __len__(self):
        """Return the number of images in the selected split."""
        return self.images.shape[0]

    def __iter__(self):
        """Yield ``[image, label]`` once per sample, shuffled if requested."""
        idxs = list(range(self.__len__()))
        if self.shuffle:
            # NOTE(review): assumes RNGDataFlow provides ``self.rng`` with a
            # numpy-style ``shuffle`` -- confirm against the base class.
            self.rng.shuffle(idxs)
        for k in idxs:
            # Drop the trailing depth axis: [28, 28, 1] -> [28, 28].
            img = self.images[k].reshape((28, 28))
            label = self.labels[k]
            yield [img, label]

class FashionMnist(Mnist):
    """
    Same API as :class:`Mnist`, but more fashion.
    """

    # Only the cache directory and the download location differ from Mnist.
    _DIR_NAME = 'fashion_mnist_data'
    _SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'

    def get_label_names(self):
        """
        Returns:
            [str]: the name of each class
        """
        # copied from https://github.com/zalandoresearch/fashion-mnist
        return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
                'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

if __name__ == '__main__':
    # Manual smoke test: build the training split and iterate over it.
    ds = Mnist('train')
    for _ in ds:
        # IPython is imported inside the loop, so it is only required when
        # this file is run as a script and the dataset yields an item.
        # NOTE(review): `embed` is imported but not called in the lines
        # visible here -- the loop body may continue past this view; confirm.
        from IPython import embed