"""ILSVRC 2017 Classicifation Dataset.
"""

import os
import cv2
import math
import numpy as np
import random
import pickle
import xml.etree.ElementTree as ET
from tqdm import trange, tqdm
from multiprocessing import Process, Array, Queue

import config as cfg


class ilsvrc_cls:
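    """Data loader for the ILSVRC 2017 CLS-LOC classification dataset."""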

    def __init__(self, image_set, rebuild=False, data_aug=False,
                 multithread=False, batch_size=cfg.BATCH_SIZE,
                 image_size=cfg.IMAGE_SIZE, RGB=False):
        self.name = 'ilsvrc_2017_cls'
        self.devkit_path = cfg.ILSVRC_PATH
        self.data_path = self.devkit_path
        self.cache_path = cfg.CACHE_PATH
        self.batch_size = batch_size
        self.image_size = image_size
        self.image_set = image_set
        self.rebuild = rebuild
        self.multithread = multithread
        self.data_aug = data_aug
        self.RGB = RGB
        assert os.path.exists(self.devkit_path), \
            'ILSVRC path does not exist: {}'.format(self.devkit_path)
        assert os.path.exists(self.data_path), \
            'Path does not exist: {}'.format(self.data_path)
        self.load_classes()
        self.cursor = 0
        self.epoch = 1
        self.gt_labels = None
        self.prepare()

        if self.multithread:
            self.prepare_multithread()
            self.get = self._get_multithread
        else:
            self.get = self._get

    def prepare(self):
        """Create a list of ground truth that includes input path and label.
        """
        # TODO: may still need to implement test
        cache_file = os.path.join(
            self.cache_path, 'ilsvrc_cls_' + self.image_set + '_gt_labels.pkl')
        if os.path.isfile(cache_file) and not self.rebuild:
            print('Loading gt_labels from: ' + cache_file)
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            print('{} {} dataset gt_labels loaded from {}'.
                  format(self.name, self.image_set, cache_file))
        else:
            if (self.image_set == "train"):
                imgset_fname = "train_cls.txt"
            else:
                imgset_fname = self.image_set + ".txt"
            imgset_file = os.path.join(
                self.data_path, 'ImageSets', 'CLS-LOC', imgset_fname)
            anno_dir = os.path.join(
                self.data_path, 'Annotations', 'CLS-LOC', self.image_set)
            print('Processing gt_labels using ' + imgset_file)
            gt_labels = []
            with open(imgset_file, 'r') as f:
                for line in tqdm(f.readlines()):
                    img_path = line.strip().split()[0]
                    if (self.image_set == "train"):
                        label = self.class_to_ind[img_path.split("/")[0]]
                    else:
                        anno_file = os.path.join(anno_dir, img_path + '.xml')
                        tree = ET.parse(anno_file)
                        label = tree.find('object').find('name').text
                        label = self.class_to_ind[label]
                    imname = os.path.join(
                        self.data_path, 'Data', 'CLS-LOC', self.image_set, img_path + ".JPEG")
                    gt_labels.append(
                        {'imname': imname, 'label': label})
            print('Saving gt_labels to: ' + cache_file)
            with open(cache_file, 'wb') as f:
                pickle.dump(gt_labels, f)
        random.shuffle(gt_labels)
        self.gt_labels = gt_labels
        self.image_num = len(gt_labels)
        self.total_batch = int(math.ceil(self.image_num / float(self.batch_size)))


    def _get(self):
        """Get shuffled images and labels according to batchsize.

        Return: 
            images: 4D numpy array
            labels: 1D numpy array
        """
        images = np.zeros(
            (self.batch_size, self.image_size, self.image_size, 3))
        labels = np.zeros(self.batch_size)
        count = 0
        while count < self.batch_size:
            imname = self.gt_labels[self.cursor]['imname']
            images[count, :, :, :] = self.image_read(
                imname, data_aug=self.data_aug)
            labels[count] = self.gt_labels[self.cursor]['label']
            count += 1
            self.cursor += 1
            if self.cursor >= len(self.gt_labels):
                random.shuffle(self.gt_labels)
                self.cursor = 0
                self.epoch += 1
        return images, labels

    def prepare_multithread(self):
        """Preperation for mutithread processing."""

        self.reset = False
        # num_batch_left should always be -1 until the last batch block of the epoch
        self.num_batch_left = -1
        self.num_child = 10
        self.child_processes = [None] * self.num_child
        self.batch_cursor_read = 0
        self.batch_cursor_fetched = 0
        # TODO: add this to cfg file
        self.prefetch_size = 5  # in terms of batch
        # TODO: may not need readed_batch after validating everything
        self.read_batch_array_size = self.total_batch + self.prefetch_size * self.batch_size
        self.readed_batch = Array('i', self.read_batch_array_size)
        for i in range(self.read_batch_array_size):
            self.readed_batch[i] = 0
        self.prefetched_images = np.zeros((self.batch_size * self.prefetch_size
                                           * self.num_child,
                                           self.image_size, self.image_size, 3))
        self.prefetched_labels = np.zeros(
            (self.batch_size * self.prefetch_size * self.num_child))
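        # the prefetch buffers hold prefetch_size * num_child batches;
        # each child fills its own contiguous slice of prefetch_size batches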
        self.queue_in = []
        self.queue_out = []
        for i in range(self.num_child):
            self.queue_in.append(Queue())
            self.queue_out.append(Queue())
            self.start_process(i)
            self.start_prefetch(i)

        # fetch the first half of the prefetched batches
        desc = 'receive the first half: ' + \
            str(self.num_child * self.prefetch_size // 2) + ' batches'
        for i in trange(self.num_child // 2, desc=desc):
            self.collect_prefetch(i)

    def start_process(self, n):
        """Start multiprocessing prcess n."""
        self.child_processes[n] = Process(target=self.prefetch,
                                          args=(self.readed_batch,
                                                self.queue_in[n],
                                                self.queue_out[n]))
        self.child_processes[n].start()

    def start_prefetch(self, n):
        """Start prefetching in process n."""
        self.queue_in[n].put([self.cursor + self.batch_size * n * self.prefetch_size,
                              self.batch_cursor_fetched + self.prefetch_size * n])

        # maintain cursor and batch_cursor_fetched here
        # so it is easier to synchronize between processes
        if n == self.num_child - 1:
            batch_block = self.prefetch_size * self.num_child
            self.cursor += self.batch_size * batch_block
            self.batch_cursor_fetched += batch_block
            if self.total_batch <= self.batch_cursor_fetched + batch_block:
                self.reset = True
                self.num_batch_left = self.total_batch - self.batch_cursor_fetched
        # print "batch_cursor_fetched:", self.batch_cursor_fetched

    def start_prefetch_list(self, L):
        """Start multiple multiprocessing prefetches."""
        for p in L:
            self.start_prefetch(p)

    def collect_prefetch(self, n):
        """Collect prefetched data, join the processes.
        Join is not inculded because it seems faster to have
        Queue.get() perform in clusters.
        """
        images, labels = self.queue_out[n].get()
        fetch_size = self.batch_size * self.prefetch_size
        self.prefetched_images[n * fetch_size:(n + 1) * fetch_size] = images
        self.prefetched_labels[n * fetch_size:(n + 1) * fetch_size] = labels

    def collect_prefetch_list(self, L):
        """Collect and join a list of prefetcging processes."""
        for p in L:
            self.collect_prefetch(p)

    def close_all_processes(self):
        """Empty and close all queues, then terminate all child processes."""
        for i in range(self.num_child):
            self.queue_in[i].cancel_join_thread()
            self.queue_out[i].cancel_join_thread()
        for i in range(self.num_child):
            self.child_processes[i].terminate()

    def load_classes(self):
        """Use the folder name to get labels."""
        # TODO: double check if the classes are all the same as for train, test, val
        img_folder = os.path.join(
            self.data_path, 'Data', 'CLS-LOC', 'train')
        print('Loading class info from ' + img_folder)
        self.classes = [item for item in os.listdir(img_folder)
                        if os.path.isdir(os.path.join(img_folder, item))]
        self.num_class = len(self.classes)
        assert (self.num_class == 1000), "number of classes is not 1000!"
        self.class_to_ind = dict(
            list(zip(self.classes, list(range(self.num_class)))))

    def _get_multithread(self):
        """Get in multithread mode.
        Besides getting images and labels, 
        the function also manages start and end of child processes for prefetching data.

        Return: 
            images: 4D numpy array
            labels: 1D numpy array
        """

        # print "num_batch_left:", self.num_batch_left

        if self.reset:
            print "one epoch is about to finish! reseting..."
            self.collect_prefetch_list(
                range(self.num_child / 2, self.num_child))
            self.reset = False

        elif self.num_batch_left == -1:
            # run the child process
            batch_block = self.prefetch_size * self.num_child
            checker = (self.batch_cursor_read % batch_block) - 4
            # print "checker:", checker
            if checker % 5 == 0:
                # print "about to start prefetch", checker / 5
                self.start_prefetch(int(checker / 5))
                if checker / 5 == self.num_child / 2 - 1:
                    self.collect_prefetch_list(
                        range(self.num_child / 2, self.num_child))

                elif checker / 5 == self.num_child - 1:
                    self.collect_prefetch_list(range(self.num_child / 2))

        assert (self.readed_batch[self.batch_cursor_read] == 1), \
            "batch not prefetched!"
        start_index = (self.batch_cursor_read
                       % (self.prefetch_size * self.num_child)) \
            * self.batch_size
        self.batch_cursor_read += 1
        # print "batch_cursor_read:", self.batch_cursor_read

        if self.num_batch_left == self.total_batch - self.batch_cursor_read:
            # fetch and receive the last few batches of the epoch
            L = range(int(math.ceil(self.num_batch_left /
                                    float(self.prefetch_size))))
            self.start_prefetch_list(L)
            self.collect_prefetch_list(L)

        # reset after one epoch
        if self.batch_cursor_read == self.total_batch:
            self.num_batch_left = -1
            self.epoch += 1
            self.cursor = 0
            self.batch_cursor_read = 0
            self.batch_cursor_fetched = 0
            random.shuffle(self.gt_labels)
            for i in range(self.read_batch_array_size):
                self.readed_batch[i] = 0
            print "######### reset, epoch", self.epoch, "start!########"
            # prefill the fetch task for the new epoch
            for i in range(self.num_child):
                self.start_prefetch(i)
            for i in range(self.num_child // 2):
                self.collect_prefetch(i)

        return (self.prefetched_images[start_index:start_index + self.batch_size],
                self.prefetched_labels[start_index:start_index + self.batch_size])

    def prefetch(self, readed_batch, q_in, q_out):
        """Prefetch data when task coming in from q_in 
        and sent out the images and labels from q_out.
        Uses in multithread processing.
        q_in send in [cursor, batch_cursor_fetched].
        """
        fetch_size = self.batch_size * self.prefetch_size

        while True:
            cursor, batch_cursor_fetched = q_in.get()
            images = np.zeros(
                (fetch_size, self.image_size, self.image_size, 3))
            labels = np.zeros(fetch_size)
            count = 0
            while count < fetch_size:
                imname = self.gt_labels[cursor]['imname']
                images[count, :, :, :] = self.image_read(
                    imname, data_aug=self.data_aug)
                labels[count] = self.gt_labels[cursor]['label']
                count += 1
                cursor += 1
                # to simplify the multithread reading,
                # the last batch will be padded with images
                # from the beginning of the same list
                if cursor >= len(self.gt_labels):
                    cursor = 0
            for i in range(batch_cursor_fetched, batch_cursor_fetched + self.prefetch_size):
                readed_batch[i] = 1

            q_out.put([images, labels])

    def image_read(self, imname, data_aug=False):
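        """Read an image, optionally augment it, and scale pixels to [-1, 1]."""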
        image = cv2.imread(imname)
        if self.RGB:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        #####################
        # Data Augmentation #
        #####################
        if data_aug:
            flip = bool(random.getrandbits(1))
            rotate_deg = random.randint(0, 359)
            # 75% chance to do a random crop,
            # another 25% chance to keep the input at self.image_size;
            # this helps simplify the input processing for test/val
            # TODO: can make multiscale test input later
            random_crop_chance = random.randint(0, 3)
            too_small = False
            color_pert = bool(random.getrandbits(1))
            exposure_shift = bool(random.getrandbits(1))

            if flip:
                image = image[:, ::-1, :]
            # assume color image
            rows, cols, _ = image.shape
            M = cv2.getRotationMatrix2D((cols / 2, rows / 2), rotate_deg, 1)
            image = cv2.warpAffine(image, M, (cols, rows))

            # color perturbation in HSV space
            # (assumes BGR channel order, as returned by cv2.imread)
            if color_pert:
                hue_shift = random.randint(-10, 10)
                saturation_shift = random.randint(-10, 10)
                # work in int16 to avoid uint8 wrap-around, then clip to
                # OpenCV's 8-bit ranges (hue: 0-179, saturation: 0-255)
                hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.int16)
                hsv[:, :, 0] = np.clip(hsv[:, :, 0] + hue_shift, 0, 179)
                hsv[:, :, 1] = np.clip(hsv[:, :, 1] + saturation_shift, 0, 255)
                image = cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR)

            # exposure shift via gamma: pixels are mapped through
            # (x / 255) ** (1 / gamma), so gamma > 1 brightens and
            # gamma < 1 darkens
            if exposure_shift:
                brighter = bool(random.getrandbits(1))
                if brighter:
                    gamma = random.uniform(1, 2)
                else:
                    gamma = random.uniform(0.5, 1)
                image = ((image / 255.0) ** (1.0 / gamma)) * 255

            # random crop
            if random_crop_chance > 0:
                # rescale so the short side has a random length between
                # self.image_size and cfg.RAND_CROP_UPBOUND (1.3 x image_size)
                short_side_len = random.randint(
                    self.image_size, cfg.RAND_CROP_UPBOUND)
                short_side = min([cols, rows])
                if short_side == cols:
                    scaled_cols = short_side_len
                    factor = float(short_side_len) / cols
                    scaled_rows = int(rows * factor)
                else:
                    scaled_rows = short_side_len
                    factor = float(short_side_len) / rows
                    scaled_cols = int(cols * factor)
                # print "scaled_cols and rows:", scaled_cols, scaled_rows
                if scaled_cols < self.image_size or scaled_rows < self.image_size:
                    too_small = True
                    print "Image is too small,", imname
                else:
                    image = cv2.resize(image, (scaled_cols, scaled_rows))
                    col_offset = random.randint(
                        0, scaled_cols - self.image_size)
                    row_offset = random.randint(
                        0, scaled_rows - self.image_size)
                    # print "col_offset and row_offset:", col_offset, row_offset
                    image = image[row_offset:self.image_size + row_offset,
                                  col_offset:self.image_size + col_offset]
                # print "image shape is", image.shape

            if random_crop_chance == 0 or too_small:
                image = cv2.resize(image, (self.image_size, self.image_size))

        else:
            image = cv2.resize(image, (self.image_size, self.image_size))

        image = image.astype(np.float32)
        image = (image / 255.0) * 2.0 - 1.0

        return image


def save_synset_to_ilsvrcid_map(meta_file):
    """Create a mape from synset to ilsvrcid and save it as a pickle file.
    """

    from scipy.io import loadmat
    meta = loadmat(meta_file)

    D = {}
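    # the indexing below assumes scipy.io.loadmat's view of the devkit meta
    # file: each entry of meta['synsets'] stores the numeric ILSVRC ID first
    # (item[0][0][0, 0]) and the WNID string second (item[0][1][0])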
    for item in meta['synsets']:
        D[str(item[0][1][0])] = item[0][0][0, 0]

    pickle_file = os.path.join(os.path.dirname(__file__), 'syn2ilsid_map.pickle')
    with open(pickle_file, 'wb') as f:
        pickle.dump(D, f)


def save_ilsvrcid_to_synset_map(meta_file):
    """Create a mape from ilsvrcid to synset and save it as a pickle file.
    """

    from scipy.io import loadmat
    meta = loadmat(meta_file)

    D = {}
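    # same meta['synsets'] layout as in save_synset_to_ilsvrcid_map above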
    for item in meta['synsets']:
        D[item[0][0][0, 0]] = str(item[0][1][0])

    pickle_file = os.path.join(os.path.dirname(__file__), 'ilsid2syn_map.pickle')
    with open(pickle_file, 'wb') as f:
        pickle.dump(D, f)
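

if __name__ == '__main__':
    # minimal smoke test (a sketch, not part of the original pipeline):
    # assumes config.py points ILSVRC_PATH at a local copy of the ILSVRC 2017
    # CLS-LOC devkit with the standard Data/ImageSets/Annotations layout
    dataset = ilsvrc_cls('train', data_aug=True)
    images, labels = dataset.get()
    print('batch shapes: {} {}'.format(images.shape, labels.shape))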