# --------------------------------------------------------
# Copyright (c) 2016 by Contributors
# Copyright (c) 2017 Microsoft
# Copyright (c) 2017 ShanghaiTech PLUS Group
# Licensed under The Apache-2.0 License [see LICENSE for details]
# Written by Zheng Zhang
# Written by Songyang Zhang 
# E-mail: sy.zhangbuaa#gmail.com
# --------------------------------------------------------

import cPickle
import os
import cv2
import numpy as np
import itertools

from imdb import IMDB
from PIL import Image
from tqdm import tqdm
from config.dff_config import config
import sys
sys.path.insert(0, './data/cityscapes/cityscapesscripts/helpers')
from labels import id2label, trainId2label
# import pdb; pdb.set_trace()
class CityScape_Video(IMDB):
    """
    Cityscapes video segmentation dataset built on top of IMDB.

    Indexes the png frames found under
    <dataset_path>/<main folder>/<sub folder>, pairs each of them with its
    gtFine train-id annotation, and writes / evaluates predicted label maps.

    NOTE(review): self.name, self.cache_path and self.result_path are read
    below but never assigned in this class; they are presumably set by
    IMDB.__init__ -- confirm against the base class.
    """

    # RGB colour of each of the 19 train ids, in train-id order
    # (road, sidewalk, building, wall, fence, pole, traffic light,
    # traffic sign, vegetation, terrain, sky, person, rider, car, truck,
    # bus, train, motorcycle, bicycle).
    _TRAIN_ID_COLORS = (
        (128,  64, 128),
        (244,  35, 232),
        ( 70,  70,  70),
        (102, 102, 156),
        (190, 153, 153),
        (153, 153, 153),
        (250, 170,  30),
        (220, 220,   0),
        (107, 142,  35),
        (152, 251, 152),
        ( 70, 130, 180),
        (220,  20,  60),
        (255,   0,   0),
        (  0,   0, 142),
        (  0,   0,  70),
        (  0,  60, 100),
        (  0,  80, 100),
        (  0,   0, 230),
        (119,  11,  32),
    )

    def __init__(self, image_set, root_path, dataset_path, result_path=None):
        """
        fill basic information to initialize imdb
        :param image_set: '<main folder>_<sub folder>', e.g. leftImg8bit_train;
                          split on the first underscore only
        :param root_path: 'selective_search_data' and 'cache'
        :param dataset_path: data and results
        :param result_path: forwarded unchanged to IMDB.__init__
        :return: imdb object
        """
        image_set_main_folder, image_set_sub_folder = image_set.split('_', 1)
        super(CityScape_Video, self).__init__('cityscape_video', image_set, root_path, dataset_path, result_path)  # set self.name

        self.image_set_main_folder = image_set_main_folder
        self.image_set_sub_folder = image_set_sub_folder
        self.root_path = root_path
        self.data_path = dataset_path
        self.num_classes = 19
        self.image_set_index = self.load_image_set_index()
        self.num_images = len(self.image_set_index)
        print('======== Total Number of images: {}======'.format(self.num_images))
        self.config = {'comp_id': 'comp4',
                       'use_diff': False,
                       'min_size': 2}
        self.global_config = config

    def load_image_set_index(self):
        """
        find out which indexes correspond to given image set
        An index is the file name with its last '_'-separated segment removed.
        Files that do not end in '.png', files whose last segment is
        'flip.png', and files without any leading segment are skipped.
        The order is the one produced by os.walk.
        :return: the indexes of given image set
        """
        image_set_folder = os.path.join(self.data_path, self.image_set_main_folder, self.image_set_sub_folder)

        index_set = []
        for _, _, filenames in os.walk(image_set_folder):
            for image_name in filenames:
                stem, _, last_segment = image_name.rpartition('_')
                if not stem:
                    # no usable '<index>_<suffix>' structure in this name
                    continue
                if last_segment == 'flip.png' or not last_segment.endswith('.png'):
                    continue
                index_set.append(stem)

        return index_set

    def image_path_from_index(self, index):
        """
        find the image path from given index
        The first '_'-separated part of the index names the sub directory.
        :param index: the given index
        :return: the image path
        """
        index_folder = index.split('_')[0]
        image_name = index + '_' + self.image_set_main_folder + '.png'
        image_file = os.path.join(self.data_path, self.image_set_main_folder, self.image_set_sub_folder,
                                  index_folder, image_name)
        assert os.path.exists(image_file), 'Path does not exist: {}'.format(image_file)
        return image_file

    def annotation_path_from_index(self, index):
        """
        find the gt path from given index
        :param index: the given index
        :return: the path of the '_gtFine_labelTrainIds.png' annotation
        """
        index_folder = index.split('_')[0]
        image_file = os.path.join(self.data_path, 'gtFine', self.image_set_sub_folder, index_folder,
                                  index + '_gtFine_labelTrainIds.png')
        assert os.path.exists(image_file), 'Path does not exist: {}'.format(image_file)
        return image_file

    def load_segdb_from_index(self, index):
        """
        load segdb from given index
        :param index: given index, '<city>_<seg id>_<frame id>[_...]'
        :return: segdb record with keys 'image', 'height', 'width',
                 'seg_cls_path', 'flipped', 'split', 'city', 'seg_id'
                 and 'frame_id'
        :raises IOError: when the image file cannot be decoded
        """
        image_path = self.image_path_from_index(index)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports failure by returning None instead of raising
            raise IOError('Failed to read image: {}'.format(image_path))

        index_parts = index.split('_')

        seg_rec = dict()
        seg_rec['image'] = image_path
        seg_rec['height'] = image.shape[0]
        seg_rec['width'] = image.shape[1]

        seg_rec['seg_cls_path'] = self.annotation_path_from_index(index)
        seg_rec['flipped'] = False
        # add for video
        seg_rec['split'] = self.image_set_sub_folder  # train or val
        seg_rec['city'] = index_parts[0]
        seg_rec['seg_id'] = index_parts[1]
        seg_rec['frame_id'] = int(index_parts[2])

        return seg_rec

    def gt_segdb(self):
        """
        return ground truth image regions database
        The database is cached as '<name>_gt_segdb.pkl' under self.cache_path
        and is read back from there when the file already exists.
        :return: list with one load_segdb_from_index record per index
        """
        print("======== Starting to get gt_segdb ========")
        cache_file = os.path.join(self.cache_path, self.name + '_gt_segdb.pkl')
        if os.path.exists(cache_file):
            # NOTE(review): unpickling executes arbitrary code; only safe as
            # long as the cache directory is trusted.
            with open(cache_file, 'rb') as fid:
                segdb = cPickle.load(fid)
            print('========= {} gt segdb loaded from {}'.format(self.name, cache_file))
            return segdb

        print("======== Starting to create gt_segdb ======")
        gt_segdb = []
        for index in tqdm(self.image_set_index):
            gt_segdb.append(self.load_segdb_from_index(index))

        with open(cache_file, 'wb') as fid:
            cPickle.dump(gt_segdb, fid, cPickle.HIGHEST_PROTOCOL)
        print('========= Wrote gt segdb to {}'.format(cache_file))

        return gt_segdb

    def getpallete(self, num_cls):
        """
        this function is to get the colormap for visualizing the segmentation mask
        Entry i holds the colour of train id i for the first 19 entries;
        all remaining entries are black.
        :param num_cls: the number of palette entries
        :return: the pallete, a flat uint8 array of length 3 * num_cls
        """
        pallete = np.zeros((num_cls, 3), dtype='uint8')

        # Clip so that palettes with fewer than 19 entries are handled too.
        num_colored = min(num_cls, len(self._TRAIN_ID_COLORS))
        if num_colored > 0:
            pallete[:num_colored, :] = np.array(self._TRAIN_ID_COLORS[:num_colored], dtype='uint8')

        return pallete.reshape(-1)

    def evaluate_segmentations(self, pred_segmentations=None):
        """
        top level evaluations
        :param pred_segmentations: the pred segmentation result; when given it
                                   is written to disk before evaluating,
                                   otherwise results already on disk are used
        :return: the evaluation results
        """
        if pred_segmentations is not None:
            self.write_segmentation_result(pred_segmentations)

        return self._py_evaluate_segmentation()

    def get_confusion_matrix(self, gt_label, pred_label, class_num):
        """
        Calculate the confusion matrix by given label and pred
        Pairs in which either label lies outside [0, class_num) are left out,
        so that they cannot be counted in an unrelated cell.
        :param gt_label: the ground truth label
        :param pred_label: the pred label
        :param class_num: the number of class
        :return: the (class_num, class_num) confusion matrix, rows indexed by
                 ground truth and columns by prediction
        """
        gt = np.asarray(gt_label).ravel().astype('int64')
        pred = np.asarray(pred_label).ravel().astype('int64')

        in_range = (gt >= 0) & (gt < class_num) & (pred >= 0) & (pred < class_num)
        flat_index = gt[in_range] * class_num + pred[in_range]

        # minlength guarantees one bin per cell even for unseen pairs
        label_count = np.bincount(flat_index, minlength=class_num * class_num)

        return label_count.reshape(class_num, class_num).astype('float64')

    def _result_folder(self):
        """
        Build the folder that holds the results of the configured test run
        :return: <result_path>/results_offset_<TEST.OFFSET>_epoch_<TEST.test_epoch>
        """
        dff_test_offset = self.global_config.TEST.OFFSET
        dff_test_epoch = self.global_config.TEST.test_epoch
        return os.path.join(self.result_path,
                            'results_offset_{}_epoch_{}'.format(dff_test_offset, dff_test_epoch))

    def _result_path_from_index(self, index, res_file_folder):
        """
        Build the path of the predicted label map of one index
        :param index: the given index
        :param res_file_folder: the folder returned by _result_folder
        :return: <res_file_folder>/<first index part>/<index>.png
        """
        return os.path.join(res_file_folder, index.split('_')[0], index + '.png')

    def _py_evaluate_segmentation(self):
        """
        This function is a wrapper to calculate the metrics for the
        segmentation results previously written to the result folder
        :return: dict with 'meanIU' (mean over all classes) and 'IU_array'
                 (per class intersection over union)
        """
        res_file_folder = self._result_folder()

        confusion_matrix = np.zeros((self.num_classes, self.num_classes))
        for index in self.image_set_index:
            seg_gt_path = self.annotation_path_from_index(index)
            res_save_path = self._result_path_from_index(index, res_file_folder)

            seg_gt = np.array(Image.open(seg_gt_path)).astype('float32')
            seg_pred = np.array(Image.open(res_save_path)).astype('float32')

            if seg_pred.shape != seg_gt.shape:
                # nearest neighbour keeps the values valid class ids
                seg_pred = cv2.resize(seg_pred, (seg_gt.shape[1], seg_gt.shape[0]),
                                      interpolation=cv2.INTER_NEAREST)

            # pixels labelled 255 in the ground truth are excluded
            keep = seg_gt != 255
            confusion_matrix += self.get_confusion_matrix(seg_gt[keep], seg_pred[keep], self.num_classes)

        pos = confusion_matrix.sum(1)
        res = confusion_matrix.sum(0)
        tp = np.diag(confusion_matrix)

        # the maximum guards against division by zero for absent classes
        IU_array = (tp / np.maximum(1.0, pos + res - tp))
        mean_IU = IU_array.mean()

        return {'meanIU': mean_IU, 'IU_array': IU_array}

    def write_segmentation_result(self, segmentation_results):
        """
        Write the segmentation result to the result folder
        Each result is stored as a palettised png at
        <result folder>/<first index part>/<index>.png
        :param segmentation_results: the prediction result, indexable in the
                                     same order as self.image_set_index
        :return: [None]
        """
        res_file_folder = self._result_folder()

        # makedirs also creates self.result_path when it is still missing
        if not os.path.exists(res_file_folder):
            os.makedirs(res_file_folder)

        pallete = self.getpallete(256)
        for i, index in enumerate(self.image_set_index):
            res_save_path = self._result_path_from_index(index, res_file_folder)
            res_save_folder = os.path.dirname(res_save_path)

            if not os.path.exists(res_save_folder):
                os.makedirs(res_save_folder)

            segmentation_result = np.uint8(np.squeeze(np.copy(segmentation_results[i])))
            segmentation_result = Image.fromarray(segmentation_result)
            segmentation_result.putpalette(pallete)
            segmentation_result.save(res_save_path)