""" Official evaluation script for v1.1 of the SQuAD dataset. """
from __future__ import print_function
from collections import Counter
import string
from zhon import hanzi as zh
import re
import argparse
import json
import sys


def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation + zh.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth, tokenizer):
    prediction_tokens = tokenizer.tokenize(normalize_answer(prediction))
    ground_truth_tokens = tokenizer.tokenize(normalize_answer(ground_truth))
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth, tokenizer):
    return (''.join(tokenizer.tokenize(normalize_answer(prediction))) ==
            ''.join(tokenizer.tokenize(normalize_answer(ground_truth))))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths, tokenizer):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth, tokenizer)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(dataset, predictions, tokenizer):
    acc = f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article['paragraphs']:
            for qa in paragraph['qas']:
                total += 1
                if str(qa['id']) not in predictions and int(qa['id']) not in predictions:
                    message = 'Unanswered question ' + str(qa['id']) + \
                              ' will receive score 0.'
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x['text'], qa['answers']))
                try:
                    prediction = predictions[str(qa['id'])]
                except KeyError:
                    prediction = predictions[int(qa['id'])]
                if ground_truths[0].lower() in prediction:
                    acc += 1
                exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths, tokenizer)
                f1 += metric_max_over_ground_truths(
                    f1_score, prediction, ground_truths, tokenizer)

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    acc = 100.0 * acc / total
    return {'exact_match': exact_match, 'f1': f1, 'acc': acc}


if __name__ == '__main__':
    import tokenization
    tokenizer = tokenization.FullTokenizer(vocab_file='chinese_L-12_H-768_A-12/vocab.txt')
    expected_version = '1.1'
    parser = argparse.ArgumentParser(
        description='Evaluation for SQuAD ' + expected_version)
    parser.add_argument('dataset_file', help='Dataset file')
    parser.add_argument('prediction_file', help='Prediction File')
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if (dataset_json['version'] != expected_version):
            print('Evaluation expects v-' + expected_version +
                  ', but got dataset with v-' + dataset_json['version'],
                  file=sys.stderr)
        dataset = dataset_json['data']
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(evaluate(dataset, predictions, tokenizer)))