import numpy as np
from scipy.spatial.distance import cosine as cosine


class Evaluator(object):
    """Recall@k evaluator for embedding-based retrieval.

    Query and gallery embeddings are matched by cosine distance. The id of
    every embedding is read from a text file that holds one integer id per
    line, where the (0-based) line number is the index of the embedding.
    """

    def __init__(self,
                 query_dict_fn,
                 gallery_dict_fn,
                 topks=None,
                 extract_feature=False):
        """Load the id mappings and create the empty recall accumulators.

        Args:
            query_dict_fn(str): path to a text file with one integer id per
                line; line ``i`` is the id of the ``i``-th query embedding.
            gallery_dict_fn(str): same format, for the gallery embeddings.
            topks(list[int] | None): cut-offs ``k`` at which recall is
                measured. ``None`` selects ``[3, 5, 10, 20, 30, 50]``.
            extract_feature(bool): whether to save extracted garment feature
                or not. Only stored on the instance by this class.
        """
        # Build a fresh list per instance: a list literal used as a default
        # argument would be shared (and mutable) across all instances.
        if topks is None:
            topks = [3, 5, 10, 20, 30, 50]
        self.topks = list(topks)

        # recall[k] collects one value per evaluated query, where each value
        # is (#relevant items among the top-k) / (#relevant gallery items).
        self.recall = dict()
        for k in self.topks:
            self.recall[k] = []

        self.query_dict, self.query_id2idx = self.get_id_dict(query_dict_fn)
        self.gallery_dict, self.gallery_id2idx = self.get_id_dict(
            gallery_dict_fn)

        self.extract_feature = extract_feature

    def load_dict(self, fn):
        """Read an id file into an ``{index: id}`` dict.

        Args:
            fn(str): path to a text file with one integer id per line.

        Returns:
            dict: maps the 0-based line number to the integer on that line.

        Raises:
            ValueError: if a line cannot be parsed as an integer.
        """
        dic = dict()
        # The context manager closes the file handle deterministically.
        with open(fn) as rf:
            for i, line in enumerate(rf):
                dic[i] = int(line.strip('\n'))
        return dic

    def inverse_dict(self, idx2id):
        """Invert an ``{index: id}`` dict to ``{id: [index, ...]}``.

        Several indices may share one id, so every id maps to the list of
        its indices, in the iteration order of ``idx2id``.
        """
        id2idx = dict()
        for idx, item_id in idx2id.items():
            id2idx.setdefault(item_id, []).append(idx)
        return id2idx

    def _rank_gallery(self, query_feat, gallery_embeds):
        """Return gallery indices sorted by ascending cosine distance."""
        # scipy's ``cosine`` only accepts 1-D vectors (newer releases raise
        # ValueError for (1, N) input), so flatten both operands.
        query_vec = np.asarray(query_feat).reshape(-1)
        query_dist = np.array([
            cosine(np.asarray(feat).reshape(-1), query_vec)
            for feat in gallery_embeds
        ])
        return np.argsort(query_dist)

    def single_query(self, query_id, query_feat, gallery_embeds, query_idx):
        """Compute recall@k of one query for every ``k`` in ``self.topks``.

        Args:
            query_id(int): id of the query item.
            query_feat: embedding of the query; flattened to 1-D.
            gallery_embeds: iterable of gallery embeddings, in the index
                order used by ``self.gallery_dict``.
            query_idx(int): index of the query; unused, kept for interface
                compatibility.

        Returns:
            dict: ``{k: recall}``, where recall is the number of top-k
            retrieved items sharing ``query_id`` divided by the number of
            gallery items sharing ``query_id``.

        Raises:
            KeyError: if ``query_id`` is missing from the query or gallery
                id mappings.
        """
        order = self._rank_gallery(query_feat, gallery_embeds)

        # Trace of the query indices sharing this id; kept so the output of
        # the evaluator is unchanged.
        print(self.query_id2idx[query_id])

        # Independent of k, so it is looked up once instead of per cut-off.
        relevant_num = len(self.gallery_id2idx[query_id])

        single_recall = dict()
        for k in self.topks:
            tp = sum(1 for idx in order[:k]
                     if self.gallery_dict[idx] == query_id)
            single_recall[k] = float(tp) / relevant_num
        return single_recall

    def show_results(self):
        """Print the mean recall@k (in percent) over the accumulated queries.

        Prints ``nan`` for a cut-off that has no accumulated values yet
        instead of failing with a division by zero.
        """
        print('--------------- Retrieval Evaluation ------------')
        for k in self.topks:
            values = self.recall[k]
            if values:
                recall = 100 * float(sum(values)) / len(values)
            else:
                recall = float('nan')
            print('Recall@%d = %.2f' % (k, recall))

    def evaluate(self, query_embeds, gallery_embeds):
        """Evaluate every query against the gallery and print the results.

        The per-query recalls are appended to ``self.recall``, so repeated
        calls keep accumulating. The running results are printed after each
        query and once more at the end.

        Args:
            query_embeds: iterable of query embeddings, in the index order
                used by ``self.query_dict``.
            gallery_embeds: iterable of gallery embeddings, in the index
                order used by ``self.gallery_dict``.
        """
        for i, query_feat in enumerate(query_embeds):
            query_id = self.query_dict[i]
            single_recall = self.single_query(query_id, query_feat,
                                              gallery_embeds, i)

            for k in self.topks:
                self.recall[k].append(single_recall[k])
            self.show_results()

        self.show_results()

    def show_retrieved_images(self, query_feat, gallery_embeds):
        """Print the ids of the top-k retrieved gallery items for one query.

        For every ``k`` in ``self.topks`` the ids of the ``k`` nearest
        gallery items are printed, so the nearest items are printed once per
        cut-off.
        """
        order = self._rank_gallery(query_feat, gallery_embeds)

        for k in self.topks:
            for idx in order[:k]:
                retrieved_id = self.gallery_dict[idx]
                print('retrieved id', retrieved_id)

    def get_id_dict(self, id_file):
        """Load an id file and build both lookup directions.

        Args:
            id_file(str): path to a text file with one integer id per line.

        Returns:
            tuple: ``(idx2id, id2idx)`` where ``idx2id`` maps the line index
            to its id and ``id2idx`` maps each id to the list of line indices
            carrying it.
        """
        idx2id = self.load_dict(id_file)
        return idx2id, self.inverse_dict(idx2id)