import tensorflow as tf
import numpy as np
from utils import data_handler as dh
import src.OPTICS as OPTICS
import scipy.misc
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score as aps
FLAGS = tf.app.flags.FLAGS
from time import time
import config
import os
import pickle
import csv


class Statistic():
    def __init__(self, logits, total_loss, loss_list, input_ph, ground_truths_ph, multi_loss_class, processed_ground_truths):
        """Keep the graph handles and create the empty statistics containers.

        logits is indexed as [0] label, [1] instance, [2] disparity output.
        loss_list holds one loss per task; its length sizes every per-loss
        history list. The remaining arguments are stored untouched for later
        use by the evaluation methods.
        """
        # network outputs, in label / instance / disparity order
        self.labelsID_outpot, self.InstanceID_outpot, self.Disparity_outpot = logits[0], logits[1], logits[2]
        self.total_loss = total_loss
        self.loss_list = loss_list
        self.input_ph = input_ph
        self.ground_truths_ph = ground_truths_ph
        self.multi_loss_class = multi_loss_class
        self.processed_ground_truths = processed_ground_truths
        self.epoch_num = 0

        num_losses = len(loss_list)

        def new_split():
            # a fresh pair of history lists, one per evaluated data split
            return {'val': [], 'train': []}

        # histories of the scalar metrics collected at every evaluation
        self.eval_keys = ['total_loss', 'labelsID_acc', 'InstanceID_per_pixel_rms', 'InstanceID_total_rms',
                          'Label_ap_acc', 'Disparity_per_pixel_rms', 'Disparity_total_rms']
        self.statistics = {key: new_split() for key in self.eval_keys}

        # per-loss histories: one inner list for each entry of loss_list
        self.statistics['loss_lists'] = {'val': [[] for _ in range(num_losses)],
                                         'train': [[] for _ in range(num_losses)]}
        self.statistics['sigmas_list'] = [[] for _ in range(num_losses)]
        self.statistics['weights_list'] = [[] for _ in range(num_losses)]

        # confusion-matrix based label scores
        score_names = ('acc', 'cl_acc_mean', 'iu_mean', 'cl_acc', 'iu', 'iu_no_void_mean')
        self.statistics['label_scores'] = {name: new_split() for name in score_names}

        # make plot dirs
        self.make_dirs()

    def make_dirs(self):
        """Create every output directory named by FLAGS that does not exist yet.

        Covers the statistics dump dir, the CSV dir, the plots dir and the
        example-prediction tree (label, disp, instance/y_reg, instance/x_reg).

        The CSV directory is now guarded by its own existence check; it used
        to be guarded by the dump directory's check, which had just been
        satisfied, so the CSV directory was never created.
        """
        example_root = FLAGS.example_preds
        required_dirs = [
            FLAGS.stat_dump_path,
            FLAGS.stat_csv_path,
            FLAGS.plots_path,
            example_root,
            os.path.join(example_root, 'label'),
            os.path.join(example_root, 'disp'),
            os.path.join(example_root, 'instance', "y_reg"),
            os.path.join(example_root, 'instance', "x_reg"),
        ]
        for path in required_dirs:
            if not os.path.exists(path):
                os.makedirs(path)

    def arange_eval_lists(self, val_eval_dict, train_eval_dict):
        """Append the results of one evaluation round to the stored histories.

        Scalar metrics are copied key by key. The label average precision is
        only appended on epochs where it was computed, and the per-loss values
        are appended to their own per-loss lists.
        """
        handled_separately = ['Label_ap_acc', 'loss_lists']   # Those are updated differently

        for key in self.eval_keys:
            if key in handled_separately:
                continue
            self.statistics[key]['val'].append(val_eval_dict[key])
        for key in self.eval_keys:
            if key in handled_separately:
                continue
            self.statistics[key]['train'].append(train_eval_dict[key])

        # the evaluation dicts carry 'Label_ap' only on AP epochs
        ap_epoch = (self.epoch_num + 1) % FLAGS.calc_ap_epoch_num == 0
        if ap_epoch:
            self.statistics['Label_ap_acc']['val'].append(val_eval_dict['Label_ap'])
            self.statistics['Label_ap_acc']['train'].append(train_eval_dict['Label_ap'])

        loss_history = self.statistics['loss_lists']
        for loss_num, val_loss in enumerate(val_eval_dict['loss_list']):
            loss_history['val'][loss_num].append(val_loss)
            loss_history['train'][loss_num].append(train_eval_dict['loss_list'][loss_num])

    def handle_statistic(self, epoch, logits, sess, train_input_imgs=None, train_gts=None,
                        val_input_imgs=None, val_gts=None, verbose=1):
        """Run the per-epoch bookkeeping: evaluation, plots and example images.

        Every FLAGS.val_epoch epochs both splits are evaluated, the histories
        and plots are refreshed and (if verbose) a summary is printed. Every
        FLAGS.example_epoch epochs the example predictions are saved. The
        duration of each stage is printed.
        """
        self.epoch_num = epoch

        run_eval = epoch % FLAGS.val_epoch == 0
        if run_eval:
            val_started = time()
            val_eval_dict = self.run_evaluation('val', logits, sess, input_imgs=val_input_imgs, gts=val_gts)
            train_started = time()
            print('time for val eval: %.2f' % (train_started - val_started))
            train_eval_dict = self.run_evaluation('train', logits, sess, input_imgs=train_input_imgs, gts=train_gts)
            train_finished = time()
            print('time for train eval: %.2f' % (train_finished - train_started))

            # sigmas/weights are recorded before the plots that draw them
            self.set_sigmas_and_wights()
            self.arange_eval_lists(val_eval_dict, train_eval_dict)
            self.save_plots()
            # CSV export is currently disabled
            if verbose:
                self.print_end_epoch_stats(train_eval_dict, val_eval_dict)

        save_examples = epoch % FLAGS.example_epoch == 0
        if save_examples:
            examples_started = time()
            self.calc_and_save_examples(logits, sess)
            examples_finished = time()
            print('calc example time: %.2f' % (examples_finished - examples_started))

    # def save_csv(self):
    #     with open(os.path.join(FLAGS.stat_csv_path, 'stat.csv'), 'wb') as myfile:
    #         wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
    #         wr.writerow(mylist)
#

    def print_end_epoch_stats(self, train_eval_dict, val_eval_dict):
        """Print a six-line summary of the latest evaluation round.

        For each of the label, instance and disparity heads a 'train' line is
        printed followed by a 'val' line. The label lines also show the latest
        and the best mean IoU stored in self.statistics['label_scores'].
        """
        iou_history = self.statistics['label_scores']['iu_mean']
        train_iou, val_iou = iou_history['train'], iou_history['val']

        # label head: accuracy of this round, latest IoU and best IoU so far
        train_label_line = ('train - label accuracy: %.2f' % train_eval_dict['labelsID_acc']
                            + ', label IoU: %.2f' % train_iou[-1]
                            + ', best IoU: %.2f' % np.max(train_iou))
        val_label_line = ('val   - label accuracy: %.2f' % val_eval_dict['labelsID_acc']
                          + ', label IoU: %.2f' % val_iou[-1]
                          + ', best IoU: %.2f' % np.max(val_iou))

        # instance head
        train_instance_line = ('train - Instance RMS per pixel: %.4f' % train_eval_dict['InstanceID_per_pixel_rms']
                               + ', Instance total RMS: %.4f' % train_eval_dict['InstanceID_total_rms'])
        val_instance_line = ('val   - Instance RMS per pixel: %.4f' % val_eval_dict['InstanceID_per_pixel_rms']
                             + ', Instance total RMS: %.4f' % val_eval_dict['InstanceID_total_rms'])

        # disparity head
        train_disp_line = ('train - disp RMS per pixel: %.4f' % train_eval_dict['Disparity_per_pixel_rms']
                           + ', disp total RMS: %.4f' % train_eval_dict['Disparity_total_rms'])
        val_disp_line = ('val   - disp RMS per pixel: %.4f' % val_eval_dict['Disparity_per_pixel_rms']
                         + ', disp total RMS: %.4f' % val_eval_dict['Disparity_total_rms'])

        for line in (train_label_line, val_label_line,
                     train_instance_line, val_instance_line,
                     train_disp_line, val_disp_line):
            print(line)

    def run_evaluation(self, set_name, logits, sess, input_imgs=None, gts=None):
        """Evaluate the network on one data split and return averaged metrics.

        Args:
            set_name: 'train' or 'val'; selects the number of images through
                FLAGS.num_of_train_imgs / FLAGS.num_of_val_imgs and the list
                that receives the label scores.
            logits: network outputs, forwarded to get_run_list.
            sess: session used to run the graph.
            input_imgs, gts: optional preloaded images and ground truths,
                indexed by image number. When input_imgs is None every sample
                is fetched with dh.get_data(ind, set_name).

        Returns:
            dict with the per-image averages of 'labelsID_acc',
            'InstanceID_per_pixel_rms', 'InstanceID_total_rms',
            'Disparity_per_pixel_rms', 'Disparity_total_rms', 'total_loss'
            and 'loss_list' (one average per loss), plus 'Label_ap' on epochs
            where (epoch_num + 1) is a multiple of FLAGS.calc_ap_epoch_num.

        Side effect: appends the confusion-matrix scores of this split to
        self.statistics['label_scores'] via calc_and_set_label_scores.
        """
        labelsID_acc_sum = 0
        InstanceID_per_pixel_rms_sum = 0
        InstanceID_total_rms_sum = 0
        disp_per_pixel_rms_sum = 0
        disp_total_rms_sum = 0
        total_loss_sum = 0
        Label_ap = 0    # average precision
        loss_list_sum = [0] * len(self.loss_list)
        n = len(config.colors[config.working_dataset])  # one colour entry per label class
        label_id_hist = np.zeros((n, n))
        num_per_set = {'train': FLAGS.num_of_train_imgs,
                       'val': FLAGS.num_of_val_imgs}
        set_images_number = num_per_set[set_name]

        # Loop invariants: neither the fetch list nor the AP decision depends
        # on the image, so they are computed once instead of once per image.
        # Fetch order (see get_run_list): [0] label logits, [1] instance
        # logits, [2] disparity logits, [3] total loss, [4] processed label
        # ground truth, [5:] the individual losses.
        run_list = self.get_run_list(logits)
        calc_ap = (self.epoch_num + 1) % FLAGS.calc_ap_epoch_num == 0

        input_len = 0
        for ind in range(set_images_number):
            if input_imgs is not None:
                image, gt = input_imgs[ind], gts[ind]
            else:
                image, gt = dh.get_data(ind, set_name)

            # Ground truths brought to the network output resolution for the
            # numpy-side instance / disparity metrics.
            # NOTE(review): scipy.misc.imresize returns a rescaled uint8 image;
            # confirm the mask channels keep the values the RMS metrics expect.
            if FLAGS.need_resize:
                output_size = (FLAGS.output_height, FLAGS.output_width)
                processed_gts = [scipy.misc.imresize(gt[task].squeeze(), output_size) for task in range(3)]
            else:
                processed_gts = [gt[task].squeeze() for task in range(3)]

            full_feed_dict = self.get_feed_dict(image, gt)
            pred_list = sess.run(run_list, feed_dict=full_feed_dict)

            labelsID_acc_sum += self.calc_labelsID_acc(pred_list[0], pred_list[4])
            label_id_hist += self.fast_hist(pred_list[0], pred_list[4], n)

            per_pixel_rms, total_rms = self.calc_InstanceID_rms(pred_list[1], processed_gts[1])
            InstanceID_per_pixel_rms_sum += per_pixel_rms
            InstanceID_total_rms_sum += total_rms

            per_pixel_rms_disp, total_rms_disp = self.calc_Disparity_rms(pred_list[2], processed_gts[2])
            disp_per_pixel_rms_sum += per_pixel_rms_disp
            disp_total_rms_sum += total_rms_disp

            total_loss_sum += pred_list[3]
            for loss_num in range(len(loss_list_sum)):
                loss_list_sum[loss_num] += pred_list[5 + loss_num]
            input_len += 1

            if calc_ap:
                Label_ap += self.calc_LabelId_ap(pred_list[0], pred_list[4])

        self.calc_and_set_label_scores(label_id_hist, set_name)

        # turn the accumulated sums into per-image means
        loss_list_sum = [loss_sum / input_len for loss_sum in loss_list_sum]
        return_dict = {'labelsID_acc': labelsID_acc_sum / input_len,
                       'InstanceID_per_pixel_rms': InstanceID_per_pixel_rms_sum / input_len,
                       'InstanceID_total_rms': InstanceID_total_rms_sum / input_len,
                       'Disparity_per_pixel_rms': disp_per_pixel_rms_sum / input_len,
                       'Disparity_total_rms': disp_total_rms_sum / input_len,
                       'total_loss': total_loss_sum / input_len,
                       'loss_list': loss_list_sum}
        if calc_ap:
            return_dict['Label_ap'] = Label_ap / input_len
        return return_dict

    def set_sigmas_and_wights(self):
        """Record the current squared sigma of every task and its loss weight.

        Does nothing unless FLAGS.use_multi_loss is set. The weight stored
        for each task is 1 / (2 * sigma_sq).
        """
        if not FLAGS.use_multi_loss:
            return
        for task_idx, sigma_sq_tn in enumerate(self.multi_loss_class._sigmas_sq):
            sigma_sq = sigma_sq_tn.eval()
            self.statistics['sigmas_list'][task_idx].append(sigma_sq)
            self.statistics['weights_list'][task_idx].append(1 / (2 * sigma_sq))

    def get_run_list(self, logits):
        """Build the list of graph nodes fetched for one evaluation step.

        Order: the three logits, the total loss, the first processed ground
        truth, then every entry of self.loss_list.
        """
        fetches = [logits[0], logits[1], logits[2], self.total_loss, self.processed_ground_truths[0]]
        fetches.extend(self.loss_list)
        return fetches

    def calc_and_save_examples(self, logits, sess):
        """Predict every 'example' image and write the results as image files.

        For each example the label colour map, the disparity map and the two
        instance regression maps are saved under FLAGS.example_preds with the
        example index and current epoch in the file name. The clustered
        instance image is written only on OPTICS epochs. The 'latest' copies
        are refreshed through save_latest_example.
        """
        example_inputs, example_ground_truths_many = dh.get_all_data('example')
        for ind, (example_input, example_ground_truths) in enumerate(zip(example_inputs, example_ground_truths_many)):
            example_preds = sess.run([logits[0], logits[1], logits[2]], feed_dict={self.input_ph: example_input})
            example_labelsID = self.calc_labelsID_rgb_img(example_preds[0])

            # channel 2 of the instance ground truth is used as the mask
            mask = example_ground_truths[1][:, :, :, 2].squeeze(0)
            if FLAGS.need_resize:
                mask = scipy.misc.imresize(mask, (FLAGS.output_height, FLAGS.output_width))
            mask = np.expand_dims(mask, 2)

            example_InstanceID = self.calc_InstanceID_example(example_preds[1].squeeze(0), mask)
            example_Disparity = example_preds[2]

            file_name = "example_%08d_epoch_%08d.png" % (ind, self.epoch_num)
            scipy.misc.imsave(os.path.join(FLAGS.example_preds, 'label', file_name), example_labelsID)
            scipy.misc.imsave(os.path.join(FLAGS.example_preds, 'disp', file_name), example_Disparity.squeeze())
            # the clustered image exists only on OPTICS epochs (it is None otherwise)
            if (self.epoch_num + 1) % FLAGS.example_OPTICS_epoch == 0:
                scipy.misc.imsave(os.path.join(FLAGS.example_preds, 'instance', file_name), example_InstanceID[0])
            scipy.misc.imsave(os.path.join(FLAGS.example_preds, 'instance', "y_reg", file_name), example_InstanceID[1])
            scipy.misc.imsave(os.path.join(FLAGS.example_preds, 'instance', "x_reg", file_name), example_InstanceID[2])

            self.save_latest_example(ind, example_labelsID, example_Disparity, example_InstanceID)
        return None

    def save_latest_example(self, ind, example_labelsID, example_Disparity, example_InstanceID):
        """Overwrite the 'latest' files of example number ind.

        Writes the label and disparity images, the clustered instance image
        (OPTICS epochs only), a jet pcolormesh plot of each raw regression
        map and the raw maps themselves as .npy files. example_InstanceID is
        the list returned by calc_InstanceID_example.
        """
        root = FLAGS.example_preds
        scipy.misc.imsave(os.path.join(root, "example_%08d_latest_label.png" % (ind)), example_labelsID)
        scipy.misc.imsave(os.path.join(root, "example_%08d_latest_disp.png" % (ind)), example_Disparity.squeeze())
        if (self.epoch_num + 1) % FLAGS.example_OPTICS_epoch == 0:
            scipy.misc.imsave(os.path.join(root, "example_%08d_latest_instance.png" % (ind)), example_InstanceID[0])

        # entries 3 and 4 are the raw y / x regression maps
        raw_maps = (("y_reg", example_InstanceID[3]), ("x_reg", example_InstanceID[4]))

        for sub_dir, raw_map in raw_maps:
            plt.clf()
            plt.pcolormesh(raw_map, cmap='jet')
            # flip the y axis so row 0 is drawn at the top, like an image
            plt.gca().invert_yaxis()
            plt.savefig(os.path.join(root, 'instance', sub_dir, "example_%08d_latest.png" % (ind)))

        for sub_dir, raw_map in raw_maps:
            np.save(os.path.join(root, 'instance', sub_dir, "example_%08d_latest" % (ind)), raw_map)

    def calc_InstanceID_example(self, xy_image, mask):
        """Prepare the instance visualisations of one example.

        Returns a list: [clustered image or None, colour-mapped channel 0,
        colour-mapped channel 1, raw channel 0, raw channel 1]. The clustered
        image is computed with OPTICS only when (epoch_num + 1) is a multiple
        of FLAGS.example_OPTICS_epoch.
        """
        # regression channels with the mask appended as an extra channel
        raw_image = np.concatenate([xy_image, mask], axis=2)

        jet = plt.get_cmap('jet')
        y_channel = xy_image[:, :, 0]
        x_channel = xy_image[:, :, 1]
        # the colormap output has 4 channels; index 3 (alpha) is dropped
        y_image = np.delete(jet(y_channel), 3, 2)
        x_image = np.delete(jet(x_channel), 3, 2)

        run_optics = (self.epoch_num + 1) % FLAGS.example_OPTICS_epoch == 0
        opt = OPTICS.calc_clusters_img(raw_image) if run_optics else None
        return [opt, y_image, x_image, y_channel.squeeze(), x_channel.squeeze()]

    def calc_labelsID_rgb_img(self, label_pred):
        """Turn label logits into a colour image.

        Each pixel gets the colour that config.colors lists for its argmax
        class; pixels whose class has no colour entry stay zero.
        """
        class_map = label_pred.squeeze().argmax(axis=2)
        height, width = label_pred.shape[1], label_pred.shape[2]
        rgb_img = np.zeros([height, width, 3])
        for class_id, color in enumerate(config.colors[config.working_dataset]):
            rgb_img[class_map == class_id] = color
        return rgb_img

    def calc_labelsID_acc(self, pred, GT):
        """Return the fraction of pixels whose argmax class matches the ground truth.

        Both inputs are squeezed and reduced with argmax over axis 2.
        """
        gt_classes = np.argmax(GT.squeeze(), axis=2)
        pred_classes = np.argmax(pred.squeeze(), axis=2)
        num_correct = np.sum(gt_classes == pred_classes)
        return num_correct / gt_classes.size

    def calc_InstanceID_rms(self, pred, GT):
        """Return (per_pixel_rms, total_rms) of the 2-channel instance regression.

        GT[:, :, 0:2] holds the two target channels and GT[:, :, 2] the mask;
        prediction and target are both multiplied by the mask. per_pixel_rms
        is the summed per-pixel error norm divided by the mask sum; total_rms
        is the square root of the summed squared error divided by the mask
        sum. Returns (0, 0) when the mask sum is not positive.
        """
        valid = GT[:, :, 2:3]                        # mask, kept 3-D
        num_of_valid_pixels = np.sum(valid)
        xy_mask = np.concatenate((valid, valid), axis=2)
        residual = pred.squeeze() * xy_mask - GT[:, :, 0:2] * xy_mask
        r_sq_matrix = np.sum(np.square(residual), axis=-1)   # squared error norm per pixel
        if not num_of_valid_pixels > 0:
            return 0, 0
        per_pixel_rms = np.sum(np.sqrt(r_sq_matrix)) / num_of_valid_pixels
        total_rms = np.sqrt(np.sum(r_sq_matrix) / num_of_valid_pixels)
        return per_pixel_rms, total_rms

    def calc_LabelId_ap(self, pred, GT):
        """
        Average precision of the label prediction over the flattened class
        channels from index 4 on (channels 0-3, the invalid classes, are
        left out).
        """
        gt_flat = GT[:, :, :, 4:].reshape(-1)
        pred_flat = pred[:, :, :, 4:].reshape(-1)
        return aps(gt_flat, pred_flat)

    def calc_Disparity_rms(self, pred, GT):
        """Return (per_pixel_rms, total_rms) of the disparity prediction.

        GT[:, :, 0] is the target and GT[:, :, 1] the mask; prediction and
        target are both multiplied by the mask.

        per_pixel_rms is the summed per-pixel error magnitude divided by the
        mask sum. total_rms is the square root of the summed squared error
        divided by the mask sum. Returns (0, 0) when the mask sum is not
        positive.

        The squared error is kept per pixel. Summing it over the last axis,
        as the 2-channel instance metric does, would add up whole image rows
        here because the disparity error map is 2-D, and per_pixel_rms would
        become a sum of row norms.
        """
        mask = GT[:, :, 1]
        num_of_valid_pixels = np.sum(mask)
        if not num_of_valid_pixels > 0:
            return 0, 0
        r_sq_matrix = np.square(pred.squeeze() * mask - GT[:, :, 0] * mask)
        per_pixel_rms = np.sum(np.sqrt(r_sq_matrix)) / num_of_valid_pixels
        total_rms = np.sqrt(np.sum(r_sq_matrix) / num_of_valid_pixels)
        return per_pixel_rms, total_rms

    def get_feed_dict(self, input, outputs):
        """Map the input placeholder and the three ground-truth placeholders to their values."""
        feed_dict = {self.input_ph: input}
        for task in range(3):
            feed_dict[self.ground_truths_ph[task]] = outputs[task]
        return feed_dict

    def save_dict(self, dict_to_save, name):
        """Pickle dict_to_save to '<FLAGS.stat_dump_path>/<name>.pkl'.

        The file is opened in a with-block so the handle is closed even when
        pickling raises.
        """
        dump_path = os.path.join(FLAGS.stat_dump_path, name + '.pkl')
        with open(dump_path, "wb") as f:
            pickle.dump(dict_to_save, f)

    def save_plots(self):
        """Redraw every statistics plot and dump the statistics to disk.

        Writes, under FLAGS.plots_path: one val/train plot per scalar metric,
        the label average-precision plot (AP epochs only), one plot per loss,
        the sigma / task-uncertainty / weight plots and one plot per
        aggregated label score. The whole statistics dict and the label
        scores are pickled through save_dict.

        Two fixes compared with the earlier version:
        * the AP plot reads self.statistics['Label_ap_acc']; it used to read
          an attribute that is never set, raising AttributeError on every AP
          epoch;
        * the task-uncertainty plot draws the computed uncertainty terms
          instead of drawing the sigma history a second time.
        """
        stats = self.statistics

        def plot_val_train(val_series, train_series, file_name, title, ylabel, xlabel):
            # one figure with the validation and the training curve
            self.save_single_plots([val_series, train_series], ['val', 'train'],
                                   os.path.join(FLAGS.plots_path, file_name),
                                   title=title, ylabel=ylabel, xlabel=xlabel)

        # (statistics key == file name, title, ylabel, xlabel)
        scalar_plots = [
            ('total_loss', 'total_loss', 'Total Loss', 'Epoch'),
            ('labelsID_acc', 'Labels ID Accuracy', 'Accuracy', 'Epoch'),
            ('InstanceID_per_pixel_rms', 'Instance ID per pixel RMS', 'RMS', 'epoch'),
            ('InstanceID_total_rms', 'Instance total RMS', 'RMS', 'epoch'),
            ('Disparity_per_pixel_rms', 'Disparity ID per pixel RMS', 'RMS', 'epoch'),
            ('Disparity_total_rms', 'Disparity total RMS', 'RMS', 'epoch'),
        ]
        for key, title, ylabel, xlabel in scalar_plots:
            plot_val_train(stats[key]['val'], stats[key]['train'], key, title, ylabel, xlabel)

        self.save_dict(stats, 'Statistics_dictionary')

        # ----- ap ----- (only collected on AP epochs)
        if (self.epoch_num + 1) % FLAGS.calc_ap_epoch_num == 0:
            plot_val_train(stats['Label_ap_acc']['val'], stats['Label_ap_acc']['train'],
                           'labelsID_ap', 'Labels ID Average Precision', 'Average Precision', 'Epoch')

        # one plot per individual loss
        for loss_num in range(len(self.loss_list)):
            plot_val_train(stats['loss_lists']['val'][loss_num], stats['loss_lists']['train'][loss_num],
                           'loss_' + str(loss_num), 'Loss num: ' + str(loss_num), 'loss', 'epoch')

        num_tasks = len(stats['sigmas_list'])

        sigma_legend = ['sigma ' + str(i) for i in range(num_tasks)]
        self.save_single_plots(stats['sigmas_list'], sigma_legend, os.path.join(FLAGS.plots_path, 'sigmas'),
                               title='Sigmas Sq', ylabel='sigma Sq value', xlabel='epoch')

        if FLAGS.use_multi_loss:
            tu_legend = []
            list_to_plot = []
            for i in range(num_tasks):
                tu_legend.append('TU ' + str(i))
                weights = np.array(stats['weights_list'][i])
                train_losses = np.array(stats['loss_lists']['train'][i])
                sigmas_sq = np.array(stats['sigmas_list'][i])
                # per-task term: weight * (train loss + log(sigma_sq))
                list_to_plot.append(weights * (train_losses + np.log(sigmas_sq)))
            self.save_single_plots(list_to_plot, tu_legend, os.path.join(FLAGS.plots_path, 'Task_uncertainty'),
                                   title='Task uncertainty', ylabel='TU', xlabel='epoch')

        weight_legend = ['weight ' + str(i) for i in range(len(stats['weights_list']))]
        self.save_single_plots(stats['weights_list'], weight_legend, os.path.join(FLAGS.plots_path, 'weights'),
                               title='Wights', ylabel='wight value', xlabel='epoch')

        # aggregated label scores; the per-class vectors are stored but not plotted
        label_scores = stats['label_scores']
        for name in list(label_scores.keys()):
            if name in ['iu', 'cl_acc']:
                continue
            plot_val_train(label_scores[name]['val'], label_scores[name]['train'],
                           'label_acc_' + name, name, name, 'epoch')
        self.save_dict(label_scores, 'label_scores')

    def save_single_plots(self, results, legend, plot_path, title='model accuracy', ylabel='accuracy', xlabel='epoch'):
        """Draw every series of results into one figure and save it as '<plot_path>.png'.

        Args:
            results: list of series, one curve each.
            legend: legend entries, in the order of results.
            plot_path: output path without the '.png' suffix.
            title, ylabel, xlabel: figure texts.

        Each series is plotted against its own index range, so an empty
        results list or series of different lengths no longer raise. The
        x axis used to be derived from results[0] only.
        """
        plt.clf()
        for result in results:
            plt.plot(range(len(result)), result)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel(xlabel)
        plt.legend(legend, loc='upper right')
        plt.savefig(plot_path + '.png')

    def calc_and_set_label_scores(self, hist_2d, set_name):
        """Derive the label scores from a confusion matrix and append them to the set_name history."""
        acc, cl_acc_mean, iu_mean, cl_acc, iu = self.get_label_score(hist_2d)
        new_scores = {
            'acc': acc,
            'cl_acc_mean': cl_acc_mean,
            'iu_mean': iu_mean,
            'iu_no_void_mean': self.calc_iu_no_void(iu),
            'cl_acc': cl_acc,
            'iu': iu,
        }
        history = self.statistics['label_scores']
        for score_name, value in new_scores.items():
            history[score_name][set_name].append(value)

    def calc_iu_no_void(self, iu):  #TODO: need to fix
        """Return the NaN-ignoring mean IoU over the class ids below 28 that are not excluded.

        NOTE(review): ids 29 and 30 are listed as excluded but can never be
        reached with range(28), and ids from 28 on are never included -
        confirm the intended class range before relying on this score.
        """
        excluded_ids = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 12, 29, 30)   #TODO: move to config
        index_list = [class_id for class_id in range(28) if class_id not in excluded_ids]
        return np.nanmean(iu[index_list])

    def get_label_score(self, hist):
        """Compute accuracy and IoU scores from a square confusion matrix.

        Returns (acc, mean of cl_acc, mean of iu, cl_acc, iu). acc is the
        diagonal sum over the total; cl_acc is the diagonal over the row
        sums; iu is the diagonal over (row sum + column sum - diagonal).
        A 1e-12 term in every denominator avoids division by zero; the means
        ignore NaNs.
        """
        diagonal = np.diag(hist)
        row_totals = hist.sum(1)
        col_totals = hist.sum(0)

        # Mean pixel accuracy
        acc = diagonal.sum() / (hist.sum() + 1e-12)
        # Per class accuracy
        cl_acc = diagonal / (row_totals + 1e-12)
        # Per class IoU
        iu = diagonal / (row_totals + col_totals - diagonal + 1e-12)

        cl_acc_mean = np.nanmean(cl_acc)
        iu_mean = np.nanmean(iu)
        return acc, cl_acc_mean, iu_mean, cl_acc, iu

    def fast_hist(self, a, b, n):
        """Build the n x n confusion matrix of two per-class score maps.

        Args:
            a, b: score arrays; each is squeezed and reduced with argmax over
                axis 2 to a class id per pixel.
            n: number of classes. The argument is now honoured; it used to be
                silently replaced by a lookup in config.

        Returns:
            Integer array of shape (n, n) where entry [i, j] counts the
            pixels with class i in `a` and class j in `b`, or the scalar 0
            when the counts do not fit an n x n matrix (the example is then
            ignored by the caller's accumulation).
        """
        a = a.squeeze().argmax(axis=2).flatten()
        b = b.squeeze().argmax(axis=2).flatten()
        # keep only the pixels whose class in `a` is a valid row index
        k = np.where((a >= 0) & (a < n))[0]
        # flat index n * row + col, counted in a single pass
        bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2)
        if len(bc) != n ** 2:
            # ignore this example if dimension mismatch
            return 0
        return bc.reshape(n, n)