from __future__ import print_function
from __future__ import absolute_import
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import xml.dom.minidom as minidom

import os
# import PIL
import argparse
import numpy as np
import scipy.sparse
import subprocess
import math
import glob
import uuid
import scipy.io as sio
import xml.etree.ElementTree as ET
import pickle
from utils.imdb import imdb
from utils.imdb import ROOT_DIR
from utils import ds_utils
from utils.voc_eval import voc_eval_LRP

# TODO: make fast_rcnn irrelevant
# >>>> obsolete, because it depends on sth outside of this project
from utils.config import cfg

try:
    xrange          # Python 2
except NameError:
    xrange = range  # Python 3

# <<<< obsolete


class results_struct:
    """Simple attribute container holding LRP / AP evaluation results."""
    pass


class pascal_voc(imdb):
    def __init__(self, image_set, year, lrp_resultfile, devkit_path):
        imdb.__init__(self, 'voc_' + year + '_' + image_set)
        self._year = year
        self._image_set = image_set
        self._devkit_path = self._get_default_path() if devkit_path is None \
            else devkit_path
        self._data_path = os.path.join(self._devkit_path, 'VOC' + self._year)
        self._classes = ('__background__',  # always index 0
                         'aeroplane', 'bicycle', 'bird', 'boat',
                         'bottle', 'bus', 'car', 'cat', 'chair',
                         'cow', 'diningtable', 'dog', 'horse',
                         'motorbike', 'person', 'pottedplant',
                         'sheep', 'sofa', 'train', 'tvmonitor')
        # path of the LRP results text file, supplied via user arguments
        self.lrp_resultsfile = lrp_resultfile
        self._class_to_ind = dict(zip(self.classes, xrange(self.num_classes)))
        self._image_ext = '.jpg'
        self._image_index = self._load_image_set_index()
        # Default to roidb handler
        # self._roidb_handler = self.selective_search_roidb
        self._roidb_handler = self.gt_roidb
        self._salt = str(uuid.uuid4())
        self._comp_id = 'comp4'
        # results text file (text mode, since write_summary/write_results write str)
        self.results_fid = open(self.lrp_resultsfile, 'w')
        # PASCAL specific config options
        self.config = {'cleanup': True,
                       'use_salt': False,
                       'use_diff': False,
                       'matlab_eval': False,
                       'rpn_file': None,
                       'min_size': 2,
                       'use_comp': False
                       }

        assert os.path.exists(self._devkit_path), \
            'VOCdevkit path does not exist: {}'.format(self._devkit_path)
        assert os.path.exists(self._data_path), \
            'Path does not exist: {}'.format(self._data_path)

    def image_path_at(self, i):
        """
        Return the absolute path to image i in the image sequence.
        """
        return self.image_path_from_index(self._image_index[i])

    def image_id_at(self, i):
        """
        Return the id of image i in the image sequence.
        """
        return i

    def image_path_from_index(self, index):
        """
        Construct an image path from the image's "index" identifier.
        """
        image_path = os.path.join(self._data_path, 'JPEGImages',
                                  index + self._image_ext)
        assert os.path.exists(image_path), \
            'Path does not exist: {}'.format(image_path)
        return image_path

    def _load_image_set_index(self):
        """
        Load the indexes listed in this dataset's image set file.
        """
        # Example path to image set file:
        # self._devkit_path + /VOCdevkit2007/VOC2007/ImageSets/Main/val.txt
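        # Each line of that file is a single image id, e.g. "000005" (illustrative).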
        image_set_file = os.path.join(self._data_path, 'ImageSets', 'Main',
                                      self._image_set + '.txt')
        assert os.path.exists(image_set_file), \
            'Path does not exist: {}'.format(image_set_file)
        with open(image_set_file) as f:
            image_index = [x.strip() for x in f.readlines()]
        return image_index

    def _get_default_path(self):
        """
        Return the default path where PASCAL VOC is expected to be installed.
        """
        return os.path.join(cfg.DATA_DIR, 'VOCdevkit' + self._year)

    def gt_roidb(self):
        """
        Return the database of ground-truth regions of interest.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path, self.name + '_gt_roidb.pkl')
        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} gt roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        gt_roidb = [self._load_pascal_annotation(index)
                    for index in self.image_index]
        with open(cache_file, 'wb') as fid:
            pickle.dump(gt_roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote gt roidb to {}'.format(cache_file))

        return gt_roidb

    def selective_search_roidb(self):
        """
        Return the database of selective search regions of interest.
        Ground-truth ROIs are also included.

        This function loads/saves from/to a cache file to speed up future calls.
        """
        cache_file = os.path.join(self.cache_path,
                                  self.name + '_selective_search_roidb.pkl')

        if os.path.exists(cache_file):
            with open(cache_file, 'rb') as fid:
                roidb = pickle.load(fid)
            print('{} ss roidb loaded from {}'.format(self.name, cache_file))
            return roidb

        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            ss_roidb = self._load_selective_search_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, ss_roidb)
        else:
            roidb = self._load_selective_search_roidb(None)
        with open(cache_file, 'wb') as fid:
            pickle.dump(roidb, fid, pickle.HIGHEST_PROTOCOL)
        print('wrote ss roidb to {}'.format(cache_file))

        return roidb

    def rpn_roidb(self):
        if int(self._year) == 2007 or self._image_set != 'test':
            gt_roidb = self.gt_roidb()
            rpn_roidb = self._load_rpn_roidb(gt_roidb)
            roidb = imdb.merge_roidbs(gt_roidb, rpn_roidb)
        else:
            roidb = self._load_rpn_roidb(None)

        return roidb

    def _load_rpn_roidb(self, gt_roidb):
        filename = self.config['rpn_file']
        print('loading {}'.format(filename))
        assert os.path.exists(filename), \
            'rpn data not found at: {}'.format(filename)
        with open(filename, 'rb') as f:
            box_list = pickle.load(f)
        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_selective_search_roidb(self, gt_roidb):
        filename = os.path.abspath(os.path.join(cfg.DATA_DIR,
                                                'selective_search_data',
                                                self.name + '.mat'))
        assert os.path.exists(filename), \
            'Selective search data not found at: {}'.format(filename)
        raw_data = sio.loadmat(filename)['boxes'].ravel()

        box_list = []
        for i in xrange(raw_data.shape[0]):
            boxes = raw_data[i][:, (1, 0, 3, 2)] - 1
            keep = ds_utils.unique_boxes(boxes)
            boxes = boxes[keep, :]
            keep = ds_utils.filter_small_boxes(boxes, self.config['min_size'])
            boxes = boxes[keep, :]
            box_list.append(boxes)

        return self.create_roidb_from_box_list(box_list, gt_roidb)

    def _load_pascal_annotation(self, index):
        """
        Load image and bounding boxes info from XML file in the PASCAL VOC
        format.
        """
        filename = os.path.join(self._data_path, 'Annotations', index + '.xml')
        tree = ET.parse(filename)
        objs = tree.findall('object')
        # if not self.config['use_diff']:
        #     # Exclude the samples labeled as difficult
        #     non_diff_objs = [
        #         obj for obj in objs if int(obj.find('difficult').text) == 0]
        #     # if len(non_diff_objs) != len(objs):
        #     #     print 'Removed {} difficult objects'.format(
        #     #         len(objs) - len(non_diff_objs))
        #     objs = non_diff_objs
        num_objs = len(objs)

        boxes = np.zeros((num_objs, 4), dtype=np.uint16)
        gt_classes = np.zeros((num_objs), dtype=np.int32)
        overlaps = np.zeros((num_objs, self.num_classes), dtype=np.float32)
        # "Seg" area for pascal is just the box area
        seg_areas = np.zeros((num_objs), dtype=np.float32)
        ishards = np.zeros((num_objs), dtype=np.int32)

        # Load object bounding boxes into a data frame.
        for ix, obj in enumerate(objs):
            bbox = obj.find('bndbox')
            # Make pixel indexes 0-based
            x1 = float(bbox.find('xmin').text) - 1
            y1 = float(bbox.find('ymin').text) - 1
            x2 = float(bbox.find('xmax').text) - 1
            y2 = float(bbox.find('ymax').text) - 1

            diffc = obj.find('difficult')
            difficult = 0 if diffc is None else int(diffc.text)
            ishards[ix] = difficult

            cls = self._class_to_ind[obj.find('name').text.lower().strip()]
            boxes[ix, :] = [x1, y1, x2, y2]
            gt_classes[ix] = cls
            overlaps[ix, cls] = 1.0
            seg_areas[ix] = (x2 - x1 + 1) * (y2 - y1 + 1)

        overlaps = scipy.sparse.csr_matrix(overlaps)

        return {'boxes': boxes,
                'gt_classes': gt_classes,
                'gt_ishard': ishards,
                'gt_overlaps': overlaps,
                'flipped': False,
                'seg_areas': seg_areas}

    def _get_comp_id(self):
        comp_id = (self._comp_id + '_' + self._salt if self.config['use_salt']
                   else self._comp_id)
        return comp_id

    def _get_voc_results_file_template(self):
        # VOCdevkit/results/VOC2007/Main/<comp_id>_det_test_aeroplane.txt
        # Both competition and non-competition modes currently use the same template;
        # the _get_comp_id() prefix defined above is not applied here.
        filename = 'det_' + self._image_set + '_{:s}.txt'
        filedir = os.path.join(self._devkit_path, 'results', 'VOC' + self._year, 'Main')
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        path = os.path.join(filedir, filename)
        return path

    def _write_voc_results_file(self, all_boxes):
        for cls_ind, cls in enumerate(self.classes):
            if cls == '__background__':
                continue 
            filename = self._get_voc_results_file_template().format(cls)
            print('Writing {} VOC results file...'.format(cls))
            with open(filename, 'wt') as f:
                for im_ind, index in enumerate(self.image_index):
                    dets = all_boxes[cls_ind][im_ind]
                    if len(dets) == 0:
                        continue
                    # the VOCdevkit expects 1-based indices
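                    # An illustrative result line (values are made up):
                    #   000005 0.912 48.0 240.0 195.0 371.0
                    # i.e. <image id> <score> <xmin> <ymin> <xmax> <ymax>, 1-based coords.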
                    for k in xrange(dets.shape[0]):
                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
                                format(index, dets[k, -1],
                                       dets[k, 0] + 1, dets[k, 1] + 1,
                                       dets[k, 2] + 1, dets[k, 3] + 1))

    def _do_python_eval(self, output_dir='output', ovthresh=0.5):
        annopath = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'Annotations',
            '{:s}.xml')
        imagesetfile = os.path.join(
            self._devkit_path,
            'VOC' + self._year,
            'ImageSets',
            'Main',
            self._image_set + '.txt')
        
        cachedir = os.path.join(self._devkit_path, 'annotations_cache')
        aps = []
        # The PASCAL VOC AP metric changed in 2010; the flag is currently not
        # forwarded to voc_eval_LRP.
        use_07_metric = int(self._year) < 2010
        # print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
        if not os.path.isdir(output_dir):
            os.mkdir(output_dir)
        # Per-class result buffers (21 slots = 20 VOC classes plus '__background__'
        # at index 0, which is skipped and stays zero).
        threshold = np.zeros(21)
        oLRP = np.zeros(21)
        oLRPLoc = np.zeros(21)
        oLRPFP = np.zeros(21)
        oLRPFN = np.zeros(21)
        ap = np.zeros(21)

        classpointer = 1
        self.resultsLRP = results_struct()
        for i, cls in enumerate(self._classes):
            if cls == '__background__':
                continue
            filename = self._get_voc_results_file_template().format(cls)
            print("Evaluating class {}, IoU Threshold: {}".format(cls, ovthresh))
            self.resultsLRP = voc_eval_LRP(filename, annopath, imagesetfile, self._image_set, cls, cachedir, ovthresh)   
            
            # write results to textfile
            self.write_results()
            
            # append results to return
            oLRP[classpointer] = self.resultsLRP.olrp
            threshold[classpointer] = self.resultsLRP.th
            oLRPLoc[classpointer] = self.resultsLRP.olrploc
            oLRPFP[classpointer] = self.resultsLRP.olrpfp
            oLRPFN[classpointer] = self.resultsLRP.olrpfn
            ap[classpointer] = self.resultsLRP.ap
            classpointer = classpointer + 1

        return oLRP, threshold, oLRPLoc, oLRPFP, oLRPFN, ap

    def evaluate_detections(self, use_all_boxes_pickle, all_boxes_path, output_dir, tau=0.5):

        if use_all_boxes_pickle:
            # Load detections from the all_boxes pickle and write the per-class VOC
            # result files; otherwise those files are assumed to exist already.
            all_boxes_log = []
            with open(all_boxes_path, "rb") as openfile:
                while True:
                    try:
                        all_boxes_log = pickle.load(openfile)
                    except EOFError:
                        break
            self._write_voc_results_file(all_boxes_log)

        oLRP, threshold, oLRPLoc, oLRPFP, oLRPFN, ap = self._do_python_eval(output_dir, tau)
        
        # pack to write summary
        self.resultsLRP.olrp = oLRP
        self.resultsLRP.th = threshold
        self.resultsLRP.olrploc = oLRPLoc
        self.resultsLRP.olrpfp = oLRPFP
        self.resultsLRP.olrpfn = oLRPFN
        self.resultsLRP.ap = ap
        # write summary
        self.write_summary()

        # clean up the per-class result text files
        if self.config['cleanup']:
            for cls in self._classes:
                if cls == '__background__':
                    continue
                filename = self._get_voc_results_file_template().format(cls)
                os.remove(filename)
        
        # close text file
        self.results_fid.close()

        return oLRP, threshold, oLRPLoc, oLRPFP, oLRPFN, ap

    def competition_mode(self, on):
        if on:
            self.config['use_salt'] = False
            self.config['cleanup'] = False
            self.config['use_comp'] = True
        else:
            self.config['use_salt'] = True
            self.config['cleanup'] = True
            self.config['use_comp'] = False
    
    def write_summary(self):
        self.results_fid.write("--------------------------Overall Results------------------------------\n")
        # Means are taken over indices 1..20, skipping the '__background__' slot at index 0.
        self.results_fid.write("moLRP: {}\n"
                               "moLRP.Loc: {}\n"
                               "moLRP.FP: {}\n"
                               "moLRP.FN: {}\n"
                               "mAP: {}\n\n".format(self.resultsLRP.olrp[1:].mean(),
                                                    self.resultsLRP.olrploc[1:].mean(),
                                                    self.resultsLRP.olrpfp[1:].mean(),
                                                    self.resultsLRP.olrpfn[1:].mean(),
                                                    self.resultsLRP.ap[1:].mean()))

    def write_results(self):
        self.results_fid.write("--------------------------Classwise Results: {}--------------------------\n".format(self.resultsLRP.cls))
        self.results_fid.write("oLRP: {}\n"
                               "o.Threshold: {}\n"
                               "oLRP.Loc: {}\n"
                               "oLRP.FP: {}\n"
                               "oLRP.FN: {}\n"
                               "AP: {}\n\n".format(self.resultsLRP.olrp,
                                                   self.resultsLRP.th,
                                                   self.resultsLRP.olrploc,
                                                   self.resultsLRP.olrpfp,
                                                   self.resultsLRP.olrpfn,
                                                   self.resultsLRP.ap))


def parse_args():
    """
    Parse input arguments.
    """
    parser = argparse.ArgumentParser(description='Evaluate PASCAL-VOC detection under LRP and AP metrics.')
    parser.add_argument('--boxes_path', dest='input_path',
                        help='path to all boxes pickle file',
                        default="./results/det/detections_voc_test_base.pkl", type=str)
    parser.add_argument('--use_pickle', dest='use_pickle',
                        help='use all boxes pickle or not.',
                        action='store_true')
    parser.add_argument('--tau', dest='iouTau',
                        help='iou threshold for detections.',
                        default=0.5, type=float)
    parser.add_argument('--save_results', dest='resultsfile',
                        help='where to store lrp results text file.',
                        default="./results/eval/lrp_results.txt",
                        type=str)
    parser.add_argument('--set', dest='set',
                        help='train, val, test etc.',
                        default='test', type=str)
    parser.add_argument('--year', dest='year',
                        help='2007, 2012 etc.',
                        default='2007', type=str)
    parser.add_argument('--comp', dest='comp',
                        help='competition mode or not.',
                        action='store_true')
    parser.add_argument('--devkit_path', dest='kitpath',
                        help='to specify devkit path.',
                        default='./VOCdevkit', type=str)

    args = parser.parse_args()
    return args

if __name__ == '__main__':
    
    # get arguments
    args = parse_args()

    # create pascal voc object.
    d = pascal_voc(args.set, args.year, args.resultsfile, args.kitpath)

    d.competition_mode(args.comp)

    # get results
    oLRP, threshold, oLRPLoc, oLRPFP, oLRPFN, ap = d.evaluate_detections(args.use_pickle, args.input_path, ".", args.iouTau)
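
# Example invocation (illustrative; the script name and paths are assumptions that
# depend on how this file is placed in your project):
#   python pascal_voc_lrp_eval.py --use_pickle \
#       --boxes_path ./results/det/detections_voc_test_base.pkl \
#       --save_results ./results/eval/lrp_results.txt \
#       --set test --year 2007 --tau 0.5 --devkit_path ./VOCdevkit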