import torch
import re
from collections import Counter
import string
import pickle
import random
import ujson as json

IGNORE_INDEX = -100  # skipped by the loss; the default ignore_index of torch.nn.CrossEntropyLoss

RE_D = re.compile(r'\d')

def has_digit(token):
    return RE_D.search(token) is not None

def prepro(token):
    return token if not has_digit(token) else 'N'
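
# Quick illustration: prepro collapses any token containing a digit to the
# placeholder 'N', presumably so all numeric tokens share one vocabulary
# entry, e.g.:
#   prepro('hello') -> 'hello'
#   prepro('42nd')  -> 'N'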

class DataIterator(object):
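    """Bucketed batch iterator over preprocessed QA examples.

    Groups examples into buckets, batches within a bucket, and yields dense,
    length-trimmed CUDA tensors for the context, question, answer span,
    question type, and supporting-fact supervision.
    """
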
    def __init__(self, buckets, bsz, para_limit, ques_limit, char_limit, shuffle, sent_limit):
        self.buckets = buckets
        self.bsz = bsz
        if para_limit is not None and ques_limit is not None:
            self.para_limit = para_limit
            self.ques_limit = ques_limit
        else:
            para_limit, ques_limit = 0, 0
            for bucket in buckets:
                for dp in bucket:
                    para_limit = max(para_limit, dp['context_idxs'].size(0))
                    ques_limit = max(ques_limit, dp['ques_idxs'].size(0))
            self.para_limit, self.ques_limit = para_limit, ques_limit
        self.char_limit = char_limit
        self.sent_limit = sent_limit

        self.num_buckets = len(self.buckets)
        self.bkt_pool = [i for i in range(self.num_buckets) if len(self.buckets[i]) > 0]
        if shuffle:
            for i in range(self.num_buckets):
                random.shuffle(self.buckets[i])
        self.bkt_ptrs = [0 for i in range(self.num_buckets)]
        self.shuffle = shuffle

    def __iter__(self):
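        # Pre-allocate GPU buffers once at the maximum limits. Each yielded
        # batch is a sliced view into these buffers and is overwritten on the
        # next iteration; clone any tensor that must outlive a step.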
        context_idxs = torch.LongTensor(self.bsz, self.para_limit).cuda()
        ques_idxs = torch.LongTensor(self.bsz, self.ques_limit).cuda()
        context_char_idxs = torch.LongTensor(self.bsz, self.para_limit, self.char_limit).cuda()
        ques_char_idxs = torch.LongTensor(self.bsz, self.ques_limit, self.char_limit).cuda()
        y1 = torch.LongTensor(self.bsz).cuda()
        y2 = torch.LongTensor(self.bsz).cuda()
        q_type = torch.LongTensor(self.bsz).cuda()
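        # start/end/all_mapping are (batch, token, sentence) indicators that
        # mark each sentence's first token, last token, and full span; together
        # with is_support they presumably feed the supporting-fact objective.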
        start_mapping = torch.Tensor(self.bsz, self.para_limit, self.sent_limit).cuda()
        end_mapping = torch.Tensor(self.bsz, self.para_limit, self.sent_limit).cuda()
        all_mapping = torch.Tensor(self.bsz, self.para_limit, self.sent_limit).cuda()
        is_support = torch.LongTensor(self.bsz, self.sent_limit).cuda()

        while True:
            if len(self.bkt_pool) == 0: break
            bkt_id = random.choice(self.bkt_pool) if self.shuffle else self.bkt_pool[0]
            start_id = self.bkt_ptrs[bkt_id]
            cur_bucket = self.buckets[bkt_id]
            cur_bsz = min(self.bsz, len(cur_bucket) - start_id)

            ids = []

            cur_batch = cur_bucket[start_id: start_id + cur_bsz]
            # Sort by true (non-padding) context length, longest first.
            cur_batch.sort(key=lambda x: int((x['context_idxs'] > 0).long().sum()), reverse=True)

            max_sent_cnt = 0
            for mapping in [start_mapping, end_mapping, all_mapping]:
                mapping.zero_()
            is_support.fill_(IGNORE_INDEX)

            for i in range(len(cur_batch)):
                context_idxs[i].copy_(cur_batch[i]['context_idxs'])
                ques_idxs[i].copy_(cur_batch[i]['ques_idxs'])
                context_char_idxs[i].copy_(cur_batch[i]['context_char_idxs'])
                ques_char_idxs[i].copy_(cur_batch[i]['ques_char_idxs'])
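                # y1 encodes the answer type: >= 0 is a span start index;
                # -1 / -2 / -3 mark yes / no / no-answer questions, matching
                # q_type codes 1 / 2 / 3 used in convert_tokens below.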
                if cur_batch[i]['y1'] >= 0:
                    y1[i] = cur_batch[i]['y1']
                    y2[i] = cur_batch[i]['y2']
                    q_type[i] = 0
                elif cur_batch[i]['y1'] == -1:
                    y1[i] = IGNORE_INDEX
                    y2[i] = IGNORE_INDEX
                    q_type[i] = 1
                elif cur_batch[i]['y1'] == -2:
                    y1[i] = IGNORE_INDEX
                    y2[i] = IGNORE_INDEX
                    q_type[i] = 2
                elif cur_batch[i]['y1'] == -3:
                    y1[i] = IGNORE_INDEX
                    y2[i] = IGNORE_INDEX
                    q_type[i] = 3
                else:
                    assert False
                ids.append(cur_batch[i]['id'])

                for j, cur_sp_dp in enumerate(cur_batch[i]['start_end_facts']):
                    if j >= self.sent_limit: break
                    if len(cur_sp_dp) == 3:
                        start, end, is_sp_flag = cur_sp_dp
                    else:
                        start, end, is_sp_flag, _is_gold = cur_sp_dp  # gold flag unused here
                    if start < end:
                        start_mapping[i, start, j] = 1
                        end_mapping[i, end-1, j] = 1
                        all_mapping[i, start:end, j] = 1
                        is_support[i, j] = int(is_sp_flag)

                max_sent_cnt = max(max_sent_cnt, len(cur_batch[i]['start_end_facts']))

            input_lengths = (context_idxs[:cur_bsz] > 0).long().sum(dim=1)
            max_c_len = int(input_lengths.max())
            max_q_len = int((ques_idxs[:cur_bsz] > 0).long().sum(dim=1).max())

            self.bkt_ptrs[bkt_id] += cur_bsz
            if self.bkt_ptrs[bkt_id] >= len(cur_bucket):
                self.bkt_pool.remove(bkt_id)

            yield {'context_idxs': context_idxs[:cur_bsz, :max_c_len].contiguous(),
                'ques_idxs': ques_idxs[:cur_bsz, :max_q_len].contiguous(),
                'context_char_idxs': context_char_idxs[:cur_bsz, :max_c_len].contiguous(),
                'ques_char_idxs': ques_char_idxs[:cur_bsz, :max_q_len].contiguous(),
                'context_lens': input_lengths,
                'y1': y1[:cur_bsz],
                'y2': y2[:cur_bsz],
                'ids': ids,
                'q_type': q_type[:cur_bsz],
                'is_support': is_support[:cur_bsz, :max_sent_cnt].contiguous(),
                'start_mapping': start_mapping[:cur_bsz, :max_c_len, :max_sent_cnt],
                'end_mapping': end_mapping[:cur_bsz, :max_c_len, :max_sent_cnt],
                'all_mapping': all_mapping[:cur_bsz, :max_c_len, :max_sent_cnt]}

def get_buckets(record_file):
    # datapoints = pickle.load(open(record_file, 'rb'))
    datapoints = torch.load(record_file)
    return [datapoints]
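
# A minimal usage sketch; the file name and hyperparameters below are
# illustrative assumptions, not values taken from this module:
#
#   buckets = get_buckets('dev_record.pkl')  # hypothetical path
#   it = DataIterator(buckets, bsz=32, para_limit=None, ques_limit=None,
#                     char_limit=16, shuffle=False, sent_limit=100)
#   for batch in it:
#       run_model(batch)  # hypothetical consumer; clone tensors to keep them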

def convert_tokens(eval_file, qa_id, pp1, pp2, p_type):
    answer_dict = {}
    for qid, p1, p2, ans_type in zip(qa_id, pp1, pp2, p_type):
        if ans_type == 0:
            context = eval_file[str(qid)]["context"]
            spans = eval_file[str(qid)]["spans"]
            start_idx = spans[p1][0]
            end_idx = spans[p2][1]
            answer_dict[str(qid)] = context[start_idx: end_idx]
        elif ans_type == 1:
            answer_dict[str(qid)] = 'yes'
        elif ans_type == 2:
            answer_dict[str(qid)] = 'no'
        elif ans_type == 3:
            answer_dict[str(qid)] = 'noanswer'
        else:
            assert False
    return answer_dict
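
# Example with a made-up eval_file entry: a span prediction (type 0) is read
# back from character offsets stored in 'spans':
#
#   eval_file = {'7': {'context': 'Paris is the capital of France.',
#                      'spans': [(0, 5), (6, 8), (9, 12), (13, 20)]}}
#   convert_tokens(eval_file, ['7'], [0], [0], [0])  # -> {'7': 'Paris'}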

def evaluate(eval_file, answer_dict):
    f1 = exact_match = total = 0
    for key, value in answer_dict.items():
        total += 1
        ground_truths = eval_file[key]["answer"]
        prediction = value
        assert len(ground_truths) == 1
        cur_EM = exact_match_score(prediction, ground_truths[0])
        cur_f1, _, _ = f1_score(prediction, ground_truths[0])
        exact_match += cur_EM
        f1 += cur_f1

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}
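
# Example: each question is expected to carry exactly one gold answer (see the
# assert above), so
#
#   eval_file   = {'q1': {'answer': ['Paris']}}
#   answer_dict = {'q1': 'paris'}
#   evaluate(eval_file, answer_dict)  # -> {'exact_match': 100.0, 'f1': 100.0}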

# def evaluate(eval_file, answer_dict, full_stats=False):
#     if full_stats:
#         with open('qaid2type.json', 'r') as f:
#             qaid2type = json.load(f)
#         f1_b = exact_match_b = total_b = 0
#         f1_4 = exact_match_4 = total_4 = 0

#         qaid2perf = {}

#     f1 = exact_match = total = 0
#     for key, value in answer_dict.items():
#         total += 1
#         ground_truths = eval_file[key]["answer"]
#         prediction = value
#         cur_EM = metric_max_over_ground_truths(
#             exact_match_score, prediction, ground_truths)
#         # cur_f1 = metric_max_over_ground_truths(f1_score,
#                                             # prediction, ground_truths)
#         assert len(ground_truths) == 1
#         cur_f1, cur_prec, cur_recall = f1_score(prediction, ground_truths[0])
#         exact_match += cur_EM
#         f1 += cur_f1
#         if full_stats and key in qaid2type:
#             if qaid2type[key] == '4':
#                 f1_4 += cur_f1
#                 exact_match_4 += cur_EM
#                 total_4 += 1
#             elif qaid2type[key] == 'b':
#                 f1_b += cur_f1
#                 exact_match_b += cur_EM
#                 total_b += 1
#             else:
#                 assert False

#         if full_stats:
#             qaid2perf[key] = {'em': cur_EM, 'f1': cur_f1, 'pred': prediction,
#                     'prec': cur_prec, 'recall': cur_recall}

#     exact_match = 100.0 * exact_match / total
#     f1 = 100.0 * f1 / total

#     ret = {'exact_match': exact_match, 'f1': f1}
#     if full_stats:
#         if total_b > 0:
#             exact_match_b = 100.0 * exact_match_b / total_b
#             exact_match_4 = 100.0 * exact_match_4 / total_4
#             f1_b = 100.0 * f1_b / total_b
#             f1_4 = 100.0 * f1_4 / total_4
#             ret.update({'exact_match_b': exact_match_b, 'f1_b': f1_b,
#                 'exact_match_4': exact_match_4, 'f1_4': f1_4,
#                 'total_b': total_b, 'total_4': total_4, 'total': total})

#         ret['qaid2perf'] = qaid2perf

#     return ret

def normalize_answer(s):
    """Lower-case, strip punctuation and articles, and collapse whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
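
# Example: normalize_answer('The Quick, Brown Fox!') -> 'quick brown fox'
# (lower-cased, punctuation removed, articles dropped, whitespace collapsed).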


def f1_score(prediction, ground_truth):
    normalized_prediction = normalize_answer(prediction)
    normalized_ground_truth = normalize_answer(ground_truth)

    ZERO_METRIC = (0, 0, 0)

    if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC
    if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
        return ZERO_METRIC

    prediction_tokens = normalized_prediction.split()
    ground_truth_tokens = normalized_ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return ZERO_METRIC
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1, precision, recall
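
# Example: f1_score('the cat sat', 'cat sat on the mat') compares normalized
# token bags ['cat', 'sat'] and ['cat', 'sat', 'on', 'mat']; with 2 shared
# tokens, precision = 2/2 and recall = 2/4, giving f1 = 2/3.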


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))
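
# Example: exact_match_score('The answer', 'answer') -> True, since both sides
# normalize to 'answer'.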


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)
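
# Example: metric_max_over_ground_truths(exact_match_score, 'Paris',
#     ['paris', 'Paris, France']) -> True, the best score across references.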