"""ILSVRC 2017 Classicifation Dataset.
"""

import os
from scipy import misc
import math
import numpy as np
import random
import pickle
import xml.etree.ElementTree as ET
from tqdm import trange, tqdm
from multiprocessing import Process, Array, Queue

import config as cfg


class ilsvrc_cls:

    def __init__(self, image_set, rebuild=False,
                 multithread=False, batch_size=cfg.BATCH_SIZE,
                 image_size=cfg.IMAGE_SIZE, random_noise=False):
        self.name = 'ilsvrc_2017_cls'
        self.devkit_path = cfg.ILSVRC_PATH
        self.data_path = self.devkit_path
        self.cache_path = cfg.CACHE_PATH
        self.batch_size = batch_size
        self.image_size = image_size
        self.image_set = image_set
        self.rebuild = rebuild
        self.multithread = multithread
        self.random_noise = random_noise
        assert os.path.exists(self.devkit_path), \
            'ILSVRC path does not exist: {}'.format(self.devkit_path)
        assert os.path.exists(self.data_path), \
            'Path does not exist: {}'.format(self.data_path)
        self.load_classes()
        self.cursor = 0
        self.epoch = 1
        self.gt_labels = None
        self.prepare()

        if self.multithread:
            self.prepare_multithread()
            self.get = self._get_multithread
        else:
            self.get = self._get

    def prepare(self):
        """Create a list of ground truth that includes input path and label.
        """
        # TODO: may still need to implement test
        cache_file = os.path.join(
            self.cache_path, 'ilsvrc_cls_' + self.image_set + '_gt_labels.pkl')
        if os.path.isfile(cache_file) and not self.rebuild:
            print('Loading gt_labels from: ' + cache_file)
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            print('{} {} dataset gt_labels loaded from {}'.
                  format(self.name, self.image_set, cache_file))
        else:
            if self.image_set == "train":
                imgset_fname = "train_cls.txt"
            else:
                imgset_fname = self.image_set + ".txt"
            imgset_file = os.path.join(
                self.data_path, 'ImageSets', 'CLS-LOC', imgset_fname)
            anno_dir = os.path.join(
                self.data_path, 'Annotations', 'CLS-LOC', self.image_set)
            print('Processing gt_labels using ' + imgset_file)
            gt_labels = []
            with open(imgset_file, 'r') as f:
                for line in tqdm(f.readlines()):
                    img_path = line.strip().split()[0]
                    if self.image_set == "train":
                        label = self.class_to_ind[img_path.split("/")[0]]
                    else:
                        anno_file = os.path.join(anno_dir, img_path + '.xml')
                        tree = ET.parse(anno_file)
                        label = tree.find('object').find('name').text
                        label = self.class_to_ind[label]
                    imname = os.path.join(
                        self.data_path, 'Data', 'CLS-LOC', self.image_set, img_path + ".JPEG")
                    gt_labels.append(
                        {'imname': imname, 'label': label})
            print('Saving gt_labels to: ' + cache_file)
            with open(cache_file, 'wb') as f:
                pickle.dump(gt_labels, f)
        random.shuffle(gt_labels)
        self.gt_labels = gt_labels
        self.image_num = len(gt_labels)
        self.total_batch = int(math.ceil(self.image_num / float(self.batch_size)))
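        # Each gt_labels entry is a dict with the image path and class index,
        # e.g. (illustrative, for a hypothetical train image):
        #   {'imname': '<ILSVRC_PATH>/Data/CLS-LOC/train/n01440764/n01440764_10026.JPEG',
        #    'label': self.class_to_ind['n01440764']}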


    def _get(self):
        """Get shuffled images and labels according to batch size.

        Return:
            images: 4D numpy array
            labels: 1D numpy array
        """
        images = np.zeros(
            (self.batch_size, self.image_size, self.image_size, 3))
        labels = np.zeros(self.batch_size)
        count = 0
        while count < self.batch_size:
            imname = self.gt_labels[self.cursor]['imname']
            images[count, :, :, :] = self.image_read(imname)
            labels[count] = self.gt_labels[self.cursor]['label']
            count += 1
            self.cursor += 1
            if self.cursor >= len(self.gt_labels):
                random.shuffle(self.gt_labels)
                self.cursor = 0
                self.epoch += 1
        return images, labels

    def prepare_multithread(self):
        """Preparation for multithread processing."""

        self.reset = False
        # num_batch_left stays -1 until the last batch block of the epoch
        self.num_batch_left = -1
        self.num_child = 10
        self.child_processes = [None] * self.num_child
        self.batch_cursor_read = 0
        self.batch_cursor_fetched = 0
        # TODO: add this to cfg file
        self.prefetch_size = 5  # in terms of batch
        # TODO: may not need readed_batch after validating everything
        self.read_batch_array_size = self.total_batch + self.prefetch_size * self.batch_size
        self.readed_batch = Array('i', self.read_batch_array_size)
        for i in range(self.read_batch_array_size):
            self.readed_batch[i] = 0
        self.prefetched_images = np.zeros((self.batch_size * self.prefetch_size
                                           * self.num_child,
                                           self.image_size, self.image_size, 3))
        self.prefetched_labels = np.zeros(
            (self.batch_size * self.prefetch_size * self.num_child))
        self.queue_in = []
        self.queue_out = []
        for i in range(self.num_child):
            self.queue_in.append(Queue())
            self.queue_out.append(Queue())
            self.start_process(i)
            self.start_prefetch(i)

        # fetch the first half
        desc = 'receive the first half: ' + \
            str(self.num_child * self.prefetch_size // 2) + ' batches'
        for i in trange(self.num_child // 2, desc=desc):
            self.collect_prefetch(i)

    def start_process(self, n):
        """Start multiprocessing process n."""
        self.child_processes[n] = Process(target=self.prefetch,
                                          args=(self.readed_batch,
                                                self.queue_in[n],
                                                self.queue_out[n]))
        self.child_processes[n].start()

    def start_prefetch(self, n):
        """Start prefetching in process n."""
        self.queue_in[n].put([self.cursor + self.batch_size * n * self.prefetch_size,
                              self.batch_cursor_fetched + self.prefetch_size * n])

        # maintain cursor and batch_cursor_fetched here
        # so it is easier to synchronize between processes
        if n == self.num_child - 1:
            batch_block = self.prefetch_size * self.num_child
            self.cursor += self.batch_size * batch_block
            self.batch_cursor_fetched += batch_block
            if self.total_batch <= self.batch_cursor_fetched + batch_block:
                self.reset = True
                self.num_batch_left = self.total_batch - self.batch_cursor_fetched
        # print "batch_cursor_fetched:", self.batch_cursor_fetched

    def start_prefetch_list(self, L):
        """Start multiple multiprocessing prefetches."""
        for p in L:
            self.start_prefetch(p)

    def collect_prefetch(self, n):
        """Collect prefetched data from process n.

        Joining is not done here because it seems faster to perform
        the Queue.get() calls in clusters.
        """
        images, labels = self.queue_out[n].get()
        fetch_size = self.batch_size * self.prefetch_size
        self.prefetched_images[n * fetch_size:(n + 1) * fetch_size] = images
        self.prefetched_labels[n * fetch_size:(n + 1) * fetch_size] = labels

    def collect_prefetch_list(self, L):
        """Collect results from a list of prefetching processes."""
        for p in L:
            self.collect_prefetch(p)

    def close_all_processes(self):
        """Release all queues, then terminate all child processes."""
        for i in range(self.num_child):
            self.queue_in[i].cancel_join_thread()
            self.queue_out[i].cancel_join_thread()
        for i in range(self.num_child):
            self.child_processes[i].terminate()

    def load_classes(self):
        """Use the training folder names as the class labels."""
        # TODO: double check if the classes are all the same as for train, test, val
        img_folder = os.path.join(
            self.data_path, 'Data', 'CLS-LOC', 'train')
        print('Loading class info from ' + img_folder)
        self.classes = [item for item in os.listdir(img_folder)
                        if os.path.isdir(os.path.join(img_folder, item))]
        self.num_class = len(self.classes)
        assert (self.num_class == 1000), "number of classes is not 1000!"
        self.class_to_ind = dict(
            list(zip(self.classes, list(range(self.num_class)))))
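        # The resulting class_to_ind maps synset folder names to indices,
        # e.g. (illustrative): {'n01440764': 0, 'n01443537': 1, ...}.
        # Note that os.listdir() order is platform-dependent, so cached
        # gt_labels are only consistent for a fixed directory listing.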

    def _get_multithread(self):
        """Get in multithread mode.

        Besides getting images and labels, this function also manages the
        start and end of the child processes that prefetch data.

        Return:
            images: 4D numpy array
            labels: 1D numpy array
        """

        if self.reset:
            print("one epoch is about to finish! resetting...")
            self.collect_prefetch_list(
                range(self.num_child // 2, self.num_child))
            self.reset = False

        elif self.num_batch_left == -1:
            # hand the next prefetch task to a child process;
            # the constants 4 and 5 are prefetch_size - 1 and prefetch_size
            batch_block = self.prefetch_size * self.num_child
            checker = (self.batch_cursor_read % batch_block) - 4
            if checker % 5 == 0:
                self.start_prefetch(checker // 5)
                if checker // 5 == self.num_child // 2 - 1:
                    self.collect_prefetch_list(
                        range(self.num_child // 2, self.num_child))

                elif checker // 5 == self.num_child - 1:
                    self.collect_prefetch_list(range(self.num_child // 2))
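
        # Net effect of the branches above: every prefetch_size-th batch read
        # hands a new prefetch task to one child, so the two halves of the
        # children alternate: while one half is being consumed, the other
        # half is refilled and its results are collected in one burst.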

        assert (self.readed_batch[self.batch_cursor_read] == 1), \
            "batch not prefetched!"
        start_index = (self.batch_cursor_read
                       % (self.prefetch_size * self.num_child)) \
            * self.batch_size
        self.batch_cursor_read += 1
        # print "batch_cursor_read:", self.batch_cursor_read

        if self.num_batch_left == self.total_batch - self.batch_cursor_read:
            # fetch and receive the last few batches of the epoch
            L = range(int(math.ceil(self.num_batch_left /
                                    float(self.prefetch_size))))
            self.start_prefetch_list(L)
            self.collect_prefetch_list(L)

        # reset after one epoch
        if self.batch_cursor_read == self.total_batch:
            self.num_batch_left = -1
            self.epoch += 1
            self.cursor = 0
            self.batch_cursor_read = 0
            self.batch_cursor_fetched = 0
            random.shuffle(self.gt_labels)
            for i in range(self.read_batch_array_size):
                self.readed_batch[i] = 0
            print "######### reset, epoch", self.epoch, "start!########"
            # prefill the fetch task for the new epoch
            for i in range(self.num_child):
                self.start_prefetch(i)
            for i in range(self.num_child / 2):
                self.collect_prefetch(i)

        return (self.prefetched_images[start_index:start_index + self.batch_size],
                self.prefetched_labels[start_index:start_index + self.batch_size])

    def prefetch(self, readed_batch, q_in, q_out):
        """Prefetch data when a task comes in from q_in and send the
        images and labels out through q_out.
        Used in multithread processing.
        q_in sends in [cursor, batch_cursor_fetched].
        """
        fetch_size = self.batch_size * self.prefetch_size

        while True:
            cursor, batch_cursor_fetched = q_in.get()
            images = np.zeros(
                (fetch_size, self.image_size, self.image_size, 3))
            labels = np.zeros(fetch_size)
            count = 0
            while count < fetch_size:
                imname = self.gt_labels[cursor]['imname']
                images[count, :, :, :] = self.image_read(imname)
                labels[count] = self.gt_labels[cursor]['label']
                count += 1
                cursor += 1
                # to simplify the multithread reading,
                # the last batch will be padded with images
                # from the beginning of the same list
                if cursor >= len(self.gt_labels):
                    cursor = 0
            for i in range(batch_cursor_fetched, batch_cursor_fetched + self.prefetch_size):
                readed_batch[i] = 1

            q_out.put([images, labels])
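
            # Note: the child processes are forked once in prepare_multithread
            # and keep their own copy of gt_labels, so the parent's reshuffle
            # at the epoch reset is not visible here (with a fork-based start
            # method); each epoch is served in the same prefetched order.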

    def add_4_side_contrast(self, x):
        """Deprecated; please use add_4_side_contrast_mtr instead."""
        x_with_contrast = np.zeros([299, 299, 15])
        for r in range(299):
            for c in range(299):
                for i in range(3):
                    x_with_contrast[r, c, i] = x[r, c, i]
                    if r != 0:
                        x_with_contrast[r, c, 3 + i] = abs(x[r, c, i] - x[r - 1, c, i])
                    if r != 299 - 1:
                        x_with_contrast[r, c, 3 * 2 + i] = abs(x[r, c, i] - x[r + 1, c, i])
                    if c != 0:
                        x_with_contrast[r, c, 3 * 3 + i] = abs(x[r, c, i] - x[r, c - 1, i])
                    if c != 299 - 1:
                        x_with_contrast[r, c, 3 * 4 + i] = abs(x[r, c, i] - x[r, c + 1, i])
        return x_with_contrast

    def add_4_side_contrast_mtr(self, x):
        """Add the contrast of the 4 neighbors of each pixel for each RGB channel.

        For each RGB channel, 4 new channels are created s.t. for pixel (r, c):
            1) absolute difference between (r, c) and (r-1, c)
            2) absolute difference between (r, c) and (r+1, c)
            3) absolute difference between (r, c) and (r, c-1)
            4) absolute difference between (r, c) and (r, c+1)
        """
        x_with_contrast = np.zeros([299, 299, 15])
        x_with_contrast[:, :, 0:3] = x
        x_with_contrast[1:, :, 3:6] = np.abs(x[1:, :, :] - x[:-1, :, :])
        x_with_contrast[:-1, :, 6:9] = np.abs(x[:-1, :, :] - x[1:, :, :])
        x_with_contrast[:, 1:, 9:12] = np.abs(x[:, 1:, :] - x[:, :-1, :])
        x_with_contrast[:, :-1, 12:] = np.abs(x[:, :-1, :] - x[:, 1:, :])
        return x_with_contrast
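
    # Sanity check (illustrative; ds is a class instance): the loop-based and
    # vectorized versions should agree on any 299 x 299 x 3 input, e.g.
    #   x = np.random.rand(299, 299, 3)
    #   assert np.allclose(ds.add_4_side_contrast(x), ds.add_4_side_contrast_mtr(x))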

    def image_read(self, imname):
        image = misc.imread(imname, mode='RGB').astype(np.float64)
        r, c, ch = image.shape
        if r < 299 or c < 299:
            # TODO: check too small images
            image = misc.imresize(image, (299, 299))
        elif r > 299 or c > 299:
            # center-crop larger images to 299 x 299
            image = image[(r - 299) // 2:(r - 299) // 2 + 299,
                          (c - 299) // 2:(c - 299) // 2 + 299, :]
        assert image.shape == (299, 299, 3)
        # rescale pixel values from [0, 255] to [-1, 1]
        image = (image / 255.0) * 2.0 - 1.0
        if self.random_noise:
            # with probability 1/2, add random per-pixel sign noise of size eps
            add_noise = bool(random.getrandbits(1))
            if add_noise:
                eps = random.choice([4.0, 8.0, 12.0, 16.0]) / 255.0 * 2.0
                noise_image = image + eps * np.random.choice([-1, 1], (299, 299, 3))
                image = np.clip(noise_image, -1.0, 1.0)
        return image
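
    # Note on the noise magnitude: eps = k / 255.0 * 2.0 with k in
    # {4, 8, 12, 16} corresponds to an L-infinity perturbation of k intensity
    # levels in the original [0, 255] scale, since the image has been
    # rescaled to [-1, 1] (a range of width 2).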


def save_synset_to_ilsvrcid_map(meta_file):
    """Create a map from synset to ilsvrcid and save it as a pickle file.
    """

    from scipy.io import loadmat
    meta = loadmat(meta_file)

    D = {}
    for item in meta['synsets']:
        D[str(item[0][1][0])] = item[0][0][0,0]
    
    pickle_file = os.path.join(os.path.dirname(__file__), 'syn2ilsid_map.pickle')
    with open(pickle_file, 'wb') as f:
        pickle.dump(D, f)


def save_ilsvrcid_to_synset_map(meta_file):
    """Create a map from ilsvrcid to synset and save it as a pickle file.
    """

    from scipy.io import loadmat
    meta = loadmat(meta_file)

    D = {}
    for item in meta['synsets']:
        D[item[0][0][0,0]] = str(item[0][1][0]) 
    
    pickle_file = os.path.join(os.path.dirname(__file__), 'ilsid2syn_map.pickle')
    with open(pickle_file, 'wb') as f:
        pickle.dump(D, f)
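

if __name__ == '__main__':
    # Minimal usage sketch (illustrative): fetch a single batch in
    # single-process mode. Assumes config.py provides ILSVRC_PATH, CACHE_PATH,
    # BATCH_SIZE, and IMAGE_SIZE, and that the standard CLS-LOC layout
    # (Data/, Annotations/, ImageSets/) exists under ILSVRC_PATH.
    dataset = ilsvrc_cls('val')
    images, labels = dataset.get()
    print(images.shape, labels.shape)  # (BATCH_SIZE, 299, 299, 3) (BATCH_SIZE,)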