# -*- coding: utf-8 -*-
"""
Created on Sun Mar 19 16:47:13 2017
@author: Chin-Wei
"""

import helpers
import os
import cPickle as pickle
import numpy as np
from sklearn.preprocessing import OneHotEncoder
floatX = 'float32'

import scipy.io
from torch.utils.data import Dataset as Dataset



def load_bmnist_image(root='dataset'):
    helpers.create(root, 'bmnist')
    droot = root+'/'+'bmnist'
    
    if not os.path.exists('{}/binarized_mnist_train.amat'.format(droot)):
        from downloader import download_bmnist
        download_bmnist(droot)
    
    # Larochelle 2011
    path_tr = '{}/binarized_mnist_train.amat'.format(droot)
    path_va = '{}/binarized_mnist_valid.amat'.format(droot)
    path_te = '{}/binarized_mnist_test.amat'.format(droot)
    train_x = np.loadtxt(path_tr).astype(floatX).reshape(50000,784)
    valid_x = np.loadtxt(path_va).astype(floatX).reshape(10000,784)
    test_x = np.loadtxt(path_te).astype(floatX).reshape(10000,784)
    
    
    return train_x, valid_x, test_x


def load_mnist_image(root='dataset',n_validation=1345, state=123):
    helpers.create(root, 'bmnist')
    droot = root+'/'+'bmnist'
    
    if not os.path.exists('{}/train-images-idx3-ubyte'.format(droot)):
        from downloader import download_bmnist
        download_bmnist(droot)
    
    path_tr = '{}/train-images-idx3-ubyte'.format(droot)
    path_te = '{}/t10k-images-idx3-ubyte'.format(droot)
    train_x = np.loadtxt(path_tr).astype(floatX)
    test_x = np.loadtxt(path_te).astype(floatX)
    
    return train_x[:50000], train_x[50000:], test_x
    

def load_cifar10_image(root='dataset',labels=False):
    helpers.create(root, 'cifar10')
    droot = root+'/'+'cifar10'
    
    if not os.path.exists('{}/cifar10.pkl'.format(droot)):
        from downloader import download_cifar10
        download_cifar10(droot)
    
    f = lambda d:d.astype(floatX)
    filename = '{}/cifar10.pkl'.format(droot)
    tr_x, tr_y, te_x, te_y = pickle.load(open(filename,'r'))
    if tr_x.max() == 255:
        tr_x = tr_x / 256.
        te_x = te_x / 256.
        
    if labels:
        enc = OneHotEncoder(10)
        tr_y = enc.fit_transform(tr_y).toarray().reshape(50000,10).astype(int)
        te_y = enc.fit_transform(te_y).toarray().reshape(10000,10).astype(int)
        
        return (f(d) for d in [tr_x, tr_y, te_x, te_y])   
    else:
        return (f(d) for d in [tr_x, te_x])
    
    
def load_omniglot_image(root='dataset',n_validation=1345, state=123):
    helpers.create(root, 'omniglot')
    droot = root+'/'+'omniglot'
    
    if not os.path.exists('{}/omniglot.amat'.format(droot)):
        from downloader import download_omniglot
        download_omniglot(droot)
    
    
    def reshape_data(data):
        return data.reshape((-1, 28, 28)).reshape((-1, 28*28), order='fortran')
    path = '{}/omniglot.amat'.format(droot)
    omni_raw = scipy.io.loadmat(path)

    train_data = reshape_data(omni_raw['data'].T.astype(floatX))
    test_data = reshape_data(omni_raw['testdata'].T.astype(floatX))

    n = train_data.shape[0]
    
    ind_va = np.random.RandomState(
        state).choice(n, n_validation, replace=False)
    
    ind_tr = np.delete(np.arange(n), ind_va)
    
    return train_data[ind_tr], train_data[ind_va], test_data



def load_caltech101_image(root='dataset'):
    # binary
    # tr: 4100 x 28 x 28
    # va: 2264 x 28 x 28
    # te: 2307 x 28 x 28
    helpers.create(root, 'caltech101')
    droot = root+'/'+'caltech101'
    fn = 'caltech101_silhouettes_28_split1.mat'
    
    if not os.path.exists('{}/{}'.format(droot, fn)):
        from downloader import download_caltech101
        download_caltech101(droot)
    
    ds = scipy.io.loadmat('{}/{}'.format(droot, fn))
    ds = [ds['train_data'], ds['val_data'], ds['test_data']]
    
    return [d.astype(floatX) for d in ds]



class DatasetWrapper(Dataset):

    def __init__(self, dataset, transform=None):

        self.dataset = dataset
        self.transform = transform
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, ind):
        
        sample = self.dataset[ind]
        if self.transform:
            sample = self.transform(sample)
        
        return sample

class InputOnly(Dataset):

    def __init__(self, dataset):

        self.dataset = dataset
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, ind):
        return self.dataset[ind][0]