"""Pascal VOC dataset class
"""

import os
import xml.etree.ElementTree as ET
import numpy as np
import cv2
import pickle
import copy

import config as cfg


class pascal_voc(object):
    """PASCAL VOC 2007 dataset reader producing YOLO-style grid labels.

    Each ground-truth entry is a dict ``{'imname', 'label', 'flipped'}``.
    ``label`` has shape ``(cell_size, cell_size, 5 + num_class)`` and is
    indexed ``[row (y cell), column (x cell), channel]``.  Per occupied cell
    the channels are ``[1, cx, cy, w, h, one-hot class ...]``, with the box
    expressed in pixels of the resized ``image_size x image_size`` image.
    """

    def __init__(self, image_set, batch_size=None, rebuild=False):
        """Configure paths and load (or build) the ground-truth labels.

        Args:
            image_set: Name of the split list under ``ImageSets/Main``
                (without the ``.txt`` extension).
            batch_size: Images per batch returned by :meth:`get`.  ``None``
                (the default) means ``cfg.BATCH_SIZE``, looked up when the
                object is constructed rather than when the module is
                imported.
            rebuild: If true, ignore any cached label pickle and re-parse
                the XML annotations.
        """
        if batch_size is None:
            batch_size = cfg.BATCH_SIZE
        self.name = 'voc_2007'
        self.devkit_path = cfg.PASCAL_PATH
        self.data_path = os.path.join(self.devkit_path, 'VOC2007')
        self.cache_path = cfg.CACHE_PATH
        self.batch_size = batch_size
        self.image_size = cfg.IMAGE_SIZE
        self.cell_size = cfg.S
        self.classes = ('aeroplane', 'bicycle', 'bird', 'boat',
                        'bottle', 'bus', 'car', 'cat', 'chair',
                        'cow', 'diningtable', 'dog', 'horse',
                        'motorbike', 'person', 'pottedplant',
                        'sheep', 'sofa', 'train', 'tvmonitor')
        self.num_class = len(self.classes)
        self.class_to_ind = {name: ind for ind, name in enumerate(self.classes)}
        self.flipped = cfg.FLIPPED
        self.image_set = image_set
        self.rebuild = rebuild
        self.cursor = 0  # position of the next entry handed out by get()
        self.gt_labels = None
        assert os.path.exists(self.devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self.devkit_path)
        assert os.path.exists(self.data_path), \
            'Path does not exist: {}'.format(self.data_path)
        self.prepare()

    def get(self):
        """Return the next batch as ``(images, labels)``.

        Entries are consumed sequentially from ``self.gt_labels``.  When the
        end of the list is reached it is reshuffled in place and iteration
        restarts at index 0, so one batch may span an epoch boundary.

        Returns:
            images: float array ``(batch_size, image_size, image_size, 3)``
                with pixel values scaled to ``[-1, 1]`` by
                :meth:`image_read`.
            labels: float array
                ``(batch_size, cell_size, cell_size, 5 + num_class)``.
        """
        images = np.zeros(
            (self.batch_size, self.image_size, self.image_size, 3))
        labels = np.zeros(
            (self.batch_size, self.cell_size, self.cell_size,
             5 + self.num_class))
        for count in range(self.batch_size):
            entry = self.gt_labels[self.cursor]
            images[count] = self.image_read(entry['imname'], entry['flipped'])
            labels[count] = entry['label']
            self.cursor += 1
            if self.cursor >= len(self.gt_labels):
                np.random.shuffle(self.gt_labels)
                self.cursor = 0
        return images, labels

    def image_read(self, imname, flipped=False):
        """Load, resize and normalise one image.

        Args:
            imname: Path to the image file.
            flipped: If true, mirror the image horizontally after resizing.

        Returns:
            float32 array ``(image_size, image_size, 3)`` in ``[-1, 1]``, in
            the channel order returned by ``cv2.imread`` (BGR).

        Raises:
            IOError: If OpenCV cannot read the file.
        """
        image = cv2.imread(imname)
        if image is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise IOError('Could not read image: {}'.format(imname))
        image = cv2.resize(image, (self.image_size, self.image_size))
        image = image.astype(np.float32)
        image = (image / 255.0) * 2.0 - 1.0
        if flipped:
            image = image[:, ::-1, :]
        return image

    def prepare(self):
        """Load the labels, optionally append flipped copies, then shuffle.

        Returns:
            The shuffled list of ground-truth entries.  The same list is
            also stored in ``self.gt_labels``.
        """
        gt_labels = self.load_labels()
        # TODO: consider adding flipped data into the saved cache's
        if self.flipped:
            print('Appending horizontally-flipped training examples ...')
            gt_labels += self._flip_labels(gt_labels)
        np.random.shuffle(gt_labels)
        self.gt_labels = gt_labels
        return gt_labels

    def _flip_labels(self, gt_labels):
        """Return horizontally mirrored copies of ``gt_labels``.

        The input entries and their label arrays are left untouched.
        """
        flipped_labels = []
        for entry in gt_labels:
            # Axis 1 of the grid is x, so reversing it mirrors the cells.
            # .copy() gives the new label its own buffer instead of a view.
            label = entry['label'][:, ::-1, :].copy()
            occupied = label[:, :, 0] == 1
            # Mirror the box centre x inside the resized frame; cy, w and h
            # are unchanged by a horizontal flip.
            label[occupied, 1] = self.image_size - 1 - label[occupied, 1]
            flipped_labels.append(dict(entry, label=label, flipped=True))
        return flipped_labels

    def load_labels(self):
        """Return ground-truth entries, from the pickle cache when available.

        On a cache miss (or when ``self.rebuild`` is true), every index
        listed in ``ImageSets/Main/<image_set>.txt`` is parsed.  Images with
        no annotated objects are skipped, and the result is written back to
        the cache.

        NOTE: the cache file name depends only on ``image_set``.  After
        changing ``cfg.IMAGE_SIZE`` or ``cfg.S``, construct with
        ``rebuild=True`` or the stale labels will be reused.
        """
        cache_file = os.path.join(
            self.cache_path, 'pascal_' + self.image_set + '_gt_labels.pkl')

        if os.path.isfile(cache_file) and not self.rebuild:
            print('Loading gt_labels from: ' + cache_file)
            # pickle.load can execute arbitrary code: only ever point
            # cfg.CACHE_PATH at cache files this class wrote itself.
            with open(cache_file, 'rb') as f:
                gt_labels = pickle.load(f)
            print('{} gt_labels loaded from {}'.format(self.name, cache_file))
            return gt_labels

        print('Processing gt_labels from: ' + self.data_path)

        if not os.path.exists(self.cache_path):
            os.makedirs(self.cache_path)

        txtname = os.path.join(self.data_path, 'ImageSets', 'Main',
                               self.image_set + '.txt')
        assert os.path.exists(txtname), \
            'Path does not exist: {}'.format(txtname)
        with open(txtname, 'r') as f:
            # Skip blank lines; they would otherwise become an empty index.
            self.image_index = [line.strip() for line in f if line.strip()]

        gt_labels = []
        for index in self.image_index:
            label, num = self.load_pascal_annotation(index)
            if num == 0:
                continue
            imname = os.path.join(
                self.data_path, 'JPEGImages', index + '.jpg')
            gt_labels.append(
                {'imname': imname, 'label': label, 'flipped': False})
        print('Saving gt_labels to: ' + cache_file)
        with open(cache_file, 'wb') as f:
            pickle.dump(gt_labels, f)
        return gt_labels

    def load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.

        Returns:
            (label, num_objects): ``label`` is the grid array described on
            the class.  ``num_objects`` counts every ``<object>`` in the XML,
            including objects dropped because their cell was already taken.

        Raises:
            IOError: If the JPEG for ``index`` cannot be read.
        """
        imname = os.path.join(self.data_path, 'JPEGImages', index + '.jpg')
        im = cv2.imread(imname)
        if im is None:
            raise IOError('Could not read image: {}'.format(imname))
        # Scale factors from original pixel coordinates to the resized
        # image_size x image_size frame produced by image_read.
        h_ratio = 1.0 * self.image_size / im.shape[0]
        w_ratio = 1.0 * self.image_size / im.shape[1]
        max_coord = self.image_size - 1

        def scaled(bbox, tag, ratio):
            # Make pixel indexes 0-based, rescale, and clip to the image.
            value = (float(bbox.find(tag).text) - 1) * ratio
            return max(min(value, max_coord), 0)

        label = np.zeros((self.cell_size, self.cell_size, 5 + self.num_class))
        filename = os.path.join(
            self.data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')

        for obj in objs:
            bbox = obj.find('bndbox')
            x1 = scaled(bbox, 'xmin', w_ratio)
            y1 = scaled(bbox, 'ymin', h_ratio)
            x2 = scaled(bbox, 'xmax', w_ratio)
            y2 = scaled(bbox, 'ymax', h_ratio)
            cls_ind = self.class_to_ind[obj.find(
                'name').text.lower().strip()]
            boxes = [(x2 + x1) / 2.0, (y2 + y1) / 2.0, x2 - x1, y2 - y1]
            # Grid cell containing the box centre.
            x_ind = int(boxes[0] * self.cell_size / self.image_size)
            y_ind = int(boxes[1] * self.cell_size / self.image_size)
            # One box per cell: the first object wins, later ones are dropped.
            if label[y_ind, x_ind, 0] == 1:
                continue
            label[y_ind, x_ind, 0] = 1
            label[y_ind, x_ind, 1:5] = boxes
            label[y_ind, x_ind, 5 + cls_ind] = 1

        return label, len(objs)