from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import numpy as np
from skimage.io import imread
from skimage.transform import resize  # replaces the removed scipy.misc.imresize
from datasets.imagenet.map import class2num

from util import log

# Point __IMAGENET_IMG_PATH__ at your local ILSVRC CLS-LOC image root; the
# list path holds train_list.txt with one relative image path per line.
__IMAGENET_IMG_PATH__ = '/YOUR_IMAGENET_PATH/ILSVRC/Data/CLS-LOC'
__IMAGENET_LIST_PATH__ = './datasets/imagenet'

# Fixed seed so the shuffle in all_ids() (and hence the train/test split) is
# reproducible across runs.
rs = np.random.RandomState(123)


class Dataset(object):

    def __init__(self, ids, name='default',
                 max_examples=None, is_train=True):
        self._ids = list(ids)
        self.name = name
        self.is_train = is_train

        if max_examples is not None:
            self._ids = self._ids[:max_examples]

        # Probe one image up front so a missing dataset fails fast.
        sample_path = os.path.join(__IMAGENET_IMG_PATH__, self._ids[0])

        try:
            imread(sample_path)
        except (IOError, OSError):
            raise IOError('Dataset not found. Please make sure the dataset was downloaded.')
        log.info("Reading Done: %s", sample_path)

    def load_image(self, id):
        img = imread(os.path.join(__IMAGENET_IMG_PATH__, id))
        # Some ImageNet images are grayscale; replicate them to 3 channels so
        # the crop below always yields an HxWx3 array.
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)
        # Resize first, then scale pixel values to [-1, 1].
        img = resize(img, (128, 128), preserve_range=True)
        img = img / 255. * 2 - 1

        # Random 112x112 crop as light data augmentation.
        y = np.random.randint(img.shape[0] - 112 + 1)
        x = np.random.randint(img.shape[1] - 112 + 1)
        img = img[y:y+112, x:x+112, :3]

        # One-hot label derived from the synset directory in the image path.
        label = np.zeros(1000)
        label[class2num[id.split('/')[-2]]] = 1
        return img, label

    def get_data(self, id1, id2):
        # Load two images (random 112x112 crops) and their one-hot labels.
        img_x, l_x = self.load_image(id1)
        img_y, l_y = self.load_image(id2)
        return img_x, img_y, l_x, l_y

    @property
    def ids(self):
        return self._ids

    def __len__(self):
        return len(self.ids)

    def __size__(self):
        # Spatial size of the images returned by load_image (112x112 crops).
        return 112, 112

    def __repr__(self):
        return 'Dataset (%s, %d examples)' % (
            self.name,
            len(self)
        )
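

# A minimal usage sketch (hedged): `random_pair` is a hypothetical helper, not
# part of the original pipeline; it only illustrates how Dataset.get_data() is
# meant to be called with two ids drawn from a Dataset instance.
def random_pair(dataset):
    # Sample two distinct ids with the module-level RandomState.
    id1, id2 = rs.choice(dataset.ids, 2, replace=False)
    return dataset.get_data(id1, id2)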


def create_default_splits(is_train=True, ratio=0.8):
    ids = all_ids()

    num_trains = int(len(ids) * ratio)

    dataset_train = Dataset(ids[:num_trains], name='train', is_train=is_train)
    dataset_test = Dataset(ids[num_trains:], name='test', is_train=False)
    return dataset_train, dataset_test


def all_ids():
    id_filename = 'train_list.txt'

    id_txt = os.path.join(__IMAGENET_LIST_PATH__, id_filename)
    try:
        with open(id_txt, 'r') as fp:
            _ids = [s.strip() for s in fp if s.strip()]
    except IOError:
        raise IOError('Dataset not found. Please make sure the dataset was downloaded.')
    # Shuffle with the fixed-seed RandomState so the split is reproducible.
    rs.shuffle(_ids)
    return _ids
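

if __name__ == '__main__':
    # Hedged smoke test: assumes __IMAGENET_IMG_PATH__ points at a real
    # ILSVRC CLS-LOC tree and that train_list.txt is present. Builds the
    # default splits and loads a single pair to check shapes and labels.
    train_set, test_set = create_default_splits()
    print(train_set)
    print(test_set)
    img_x, img_y, l_x, l_y = train_set.get_data(train_set.ids[0],
                                                train_set.ids[1])
    print(img_x.shape, int(l_x.argmax()), img_y.shape, int(l_y.argmax()))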