"""Compute and plot per-question answer/buzz outcome statistics for a
feature-ablation study.

Reads classifier prediction files and their aligned meta files, classifies
each question into an :class:`Answer` outcome category, and plots the
per-experiment category proportions.
"""
import re
from glob import glob
from os import path
import json
import pprint
from collections import namedtuple
from typing import Dict, Set
from enum import Enum

import click
from functional import seq
from functional.pipeline import Sequence
import matplotlib
# Select a non-interactive backend before pyplot is imported so the script
# works on headless machines.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from qanta import qlogging
from qanta.datasets.quiz_bowl import QuestionDatabase
from qanta.util.io import safe_path

log = qlogging.get(__name__)


class Answer(Enum):
    """Outcome category for a single question.

    ``hopeless_1``: the correct answer was never the top guess.
    ``hopeless_classifier``: the correct answer never appeared among any guesses.
    ``hopeless_dan``: the correct answer was not in the guesser's answer set at all.
    """
    correct = 1
    unanswered_wrong = 2
    unanswered_hopeless_1 = 3
    unanswered_hopeless_classifier = 4
    unanswered_hopeless_dan = 5
    wrong_hopeless_1 = 6
    wrong_hopeless_classifier = 7
    wrong_hopeless_dan = 8
    wrong_early = 9
    wrong_late = 10


# Fixed ordering of outcome categories on the plot axis.
ANSWER_PLOT_ORDER = [
    'correct',
    'wrong_late',
    'wrong_early',
    'unanswered_wrong',
    'wrong_hopeless_1',
    'unanswered_hopeless_1',
    'wrong_hopeless_classifier',
    'unanswered_hopeless_classifier',
    'wrong_hopeless_dan',
    'unanswered_hopeless_dan'
]

# One scored guess from a prediction file; question/sentence/token identify
# the position in the question the prediction was made at.
Prediction = namedtuple('Prediction', ['score', 'question', 'sentence', 'token'])
# One line of the meta file: position plus the guessed answer string.
Meta = namedtuple('Meta', ['question', 'sentence', 'token', 'guess'])
# Aggregated state at one (sentence, token) position: did we buzz, what was
# the top guess, the gold answer, and all guesses sorted by score.
Line = namedtuple('Line', ['question', 'sentence', 'token', 'buzz', 'guess', 'answer', 'all_guesses'])
ScoredGuess = namedtuple('ScoredGuess', ['score', 'guess'])

SUMMARY_REGEX = re.compile(r'test\.json')
# BUGFIX: the class previously read [-+\a-z]; `\a` is the BEL character
# (\x07), so `\a-z` silently matched a huge \x07..'z' range (digits,
# uppercase, punctuation). The intent is lowercase experiment names that may
# contain '-' or '+'.
ANSWER_REGEX = re.compile(r'test\.([-+a-z]+)\.json')


def load_predictions(pred_file: str) -> Sequence:
    """Parse a prediction file into a lazy Sequence of Prediction tuples.

    Each line is "<score> <question>_<sentence>_<token>"; lines missing the
    position field yield a Prediction with None position components (they
    are repaired later against the meta file by load_data).
    """
    def parse_line(line: str) -> Prediction:
        try:
            tokens = line.split()
            score = float(tokens[0])
            if len(tokens) < 2:
                # Malformed line without a position id; filled in from the
                # aligned meta line in load_data.fix_missing_label.
                question, sentence, token = None, None, None
            else:
                question, sentence, token = [int(x) for x in tokens[1].split('_')]
            return Prediction(score, question, sentence, token)
        except Exception:
            log.info("Error parsing line: {0}".format(line))
            raise
    return seq.open(pred_file).map(parse_line)


def load_meta(meta_file: str) -> Sequence:
    """Parse a meta file into a lazy Sequence of Meta tuples.

    Each line is "<question> <sentence> <token> <guess...>"; the guess may
    contain spaces, so everything after the third field is rejoined.
    """
    def parse_line(line: str) -> Meta:
        tokens = line.split()
        question = int(tokens[0])
        sentence = int(tokens[1])
        token = int(tokens[2])
        guess = ' '.join(tokens[3:])
        return Meta(question, sentence, token, guess)
    return seq.open(meta_file).map(parse_line)


def load_data(pred_file: str, meta_file: str, q_db: QuestionDatabase) -> Sequence:
    """Join prediction and meta files line-by-line into per-question Lines.

    Returns a Sequence of (question_id, [Line]) where each Line summarizes
    one (sentence, token) position: whether the top-scored guess triggered a
    buzz (score > 0), the top guess, the gold answer, and all guesses sorted
    by descending score.
    """
    preds = load_predictions(pred_file)
    metas = load_meta(meta_file)
    answers = q_db.all_answers()

    def create_line(group):
        question = group[0]
        elements = group[1]
        # Group the (Prediction, Meta) pairs by position within the question.
        st_groups = seq(elements).group_by(lambda x: (x[0].sentence, x[0].token)).sorted()
        st_lines = []
        for st, v in st_groups:
            scored_guesses = seq(v)\
                .map(lambda x: ScoredGuess(x[0].score, x[1].guess)).sorted(reverse=True).list()
            st_lines.append(Line(
                question, st[0], st[1],
                # A positive top score means the system buzzes here.
                scored_guesses[0].score > 0,
                scored_guesses[0].guess,
                answers[question],
                scored_guesses
            ))
        return question, st_lines

    def fix_missing_label(pm):
        """Repair predictions missing position ids using the aligned meta
        line, then assert the two files are in lockstep."""
        prediction = pm[0]
        meta = pm[1]
        if prediction.question is None or prediction.token is None or prediction.sentence is None:
            log.info("WARNING: Prediction malformed, fixing with meta line: {0}".format(prediction))
            prediction = Prediction(prediction.score, meta.question, meta.sentence, meta.token)
        assert meta.question == prediction.question
        assert meta.sentence == prediction.sentence
        assert meta.token == prediction.token
        return prediction, meta

    return preds\
        .zip(metas)\
        .map(fix_missing_label)\
        .group_by(lambda x: x[0].question)\
        .map(create_line)


def load_audit(audit_file: str, meta_file: str):
    """Load audit evidence keyed by (question, sentence, token, guess).

    The audit and meta files must be line-aligned; raises ValueError if the
    position ids of a pair of lines disagree.
    """
    audit_data = {}
    with open(audit_file) as audit_f, open(meta_file) as meta_f:
        for a_line, m_line in zip(audit_f, meta_f):
            qid, evidence = a_line.split('\t')
            a_qnum, a_sentence, a_token = [int(t) for t in qid.split('_')]
            s_m_qnum, s_m_sentence, s_m_token, guess = m_line.split()
            m_qnum = int(s_m_qnum)
            m_sentence = int(s_m_sentence)
            m_token = int(s_m_token)
            if a_qnum != m_qnum or a_sentence != m_sentence or a_token != m_token:
                raise ValueError('Error occurred in audit and meta file alignment')
            audit_data[(a_qnum, a_sentence, a_token, guess)] = evidence.strip()
    return audit_data


def compute_answers(data: Sequence, dan_answers: Set[str]):
    """Classify each question into an Answer outcome category.

    data: Sequence of (question_id, [Line]) from load_data.
    dan_answers: the set of answers the guesser can produce at all; answers
    outside it are "hopeless_dan".
    Returns a dict mapping question id -> Answer.
    """
    questions = {}
    for q, lines in data:
        lines = seq(lines)
        answer = lines.first().answer
        # First position at which the system buzzed, if any.
        buzz = lines.find(lambda b: b.buzz)
        if buzz is None:
            if lines.exists(lambda g: g.guess == answer):
                # Correct answer surfaced as a top guess but we never buzzed.
                questions[q] = Answer.unanswered_wrong
            elif answer not in dan_answers:
                questions[q] = Answer.unanswered_hopeless_dan
            else:
                questions[q] = Answer.unanswered_hopeless_1
                # If the answer never appeared among any guesses the
                # classifier could not possibly have gotten it.
                if not lines.flat_map(lambda g: g.all_guesses).exists(lambda g: g.guess == answer):
                    questions[q] = Answer.unanswered_hopeless_classifier
        elif buzz.guess == buzz.answer:
            questions[q] = Answer.correct
        else:
            # We buzzed with a wrong guess; see whether/when the correct
            # answer was ever the top guess.
            correct_buzz = lines.find(lambda g: g.guess == answer)
            if correct_buzz is None:
                questions[q] = Answer.wrong_hopeless_1
                if answer not in dan_answers:
                    questions[q] = Answer.wrong_hopeless_dan
                else:
                    if not lines.flat_map(lambda g: g.all_guesses).exists(lambda g: g.guess == answer):
                        questions[q] = Answer.wrong_hopeless_classifier
            elif (correct_buzz.sentence, correct_buzz.token) < (buzz.sentence, buzz.token):
                # Correct answer was on top earlier than we buzzed.
                questions[q] = Answer.wrong_late
            elif (buzz.sentence, buzz.token) < (correct_buzz.sentence, correct_buzz.token):
                # We buzzed before the correct answer reached the top.
                questions[q] = Answer.wrong_early
            else:
                raise ValueError('Unexpected for buzz and correct buzz to be the same')

        if q not in questions:
            raise ValueError('Expected an answer type for question')
    return questions


def compute_statistics(questions: Dict[int, Answer]) -> Sequence:
    """Convert per-question outcomes into per-category proportions.

    Every Answer category is present in the result (zero-count categories
    included); keys are the string form of the enum member.
    """
    n_questions = len(questions)
    # Seed every category with zero so absent categories still appear.
    empty_set = [(a, 0) for a in Answer]
    results = seq(questions.values())\
        .map(lambda x: (x, 1))
    results = (results + seq(empty_set)).reduce_by_key(lambda x, y: x + y)\
        .map(lambda kv: (str(kv[0]), kv[1] / n_questions if kv[1] > 0 else 0))
    return results


def parse_data(stats_dir):
    """Load all test*.json stats files in stats_dir into a pandas DataFrame
    with columns experiment/result/score.

    'test.json' is the all-features run; 'test.<name>.json' is the ablation
    experiment <name>. Raises ValueError on any other file name.
    """
    def parse_file(file):
        experiment = None
        base_file = path.basename(file)
        m = SUMMARY_REGEX.match(base_file)
        if m:
            experiment = 'all features'
        m = ANSWER_REGEX.match(base_file)
        if m:
            experiment = m.group(1)
        if experiment is None:
            raise ValueError('Incorrect file name argument: {}'.format(base_file))
        with open(file) as f:
            data = json.load(f)
            return seq(data.items()).map(lambda kv: {
                'experiment': experiment,
                'result': kv[0].replace('Answer.', ''),
                'score': kv[1]
            })
    rows = seq(glob(path.join(stats_dir, 'test*.json')))\
        .sorted().flat_map(parse_file).to_pandas()
    return rows


@click.group()
def cli():
    pass


def plot_summary(summary_only, stats_dir, output):
    """Render the per-experiment outcome bar plots to a PNG file.

    NOTE(review): summary_only is currently unused here — confirm whether a
    summary-only plotting mode was intended.
    """
    # Imported lazily so the CLI loads without seaborn installed.
    import seaborn as sns
    rows = parse_data(stats_dir)
    g = sns.factorplot(y='result', x='score', col='experiment', data=rows,
                       kind='bar', ci=None, order=ANSWER_PLOT_ORDER,
                       size=4, col_wrap=4, sharex=False)
    for ax in g.axes.flat:
        for label in ax.get_xticklabels():
            label.set_rotation(30)
    # Leave headroom for the figure-level title.
    plt.subplots_adjust(top=0.93)
    g.fig.suptitle('Feature Ablation Study')
    g.savefig(output, format='png', dpi=200)


@cli.command()
# NOTE(review): is_flag=False on a --summary-only switch looks like it was
# meant to be is_flag=True; left unchanged to preserve the CLI interface.
@click.option('--summary-only', is_flag=False)
@click.argument('stats_dir')
@click.argument('output')
def plot(summary_only, stats_dir, output):
    plot_summary(summary_only, stats_dir, output)


@cli.command()
@click.option('--min-count', default=1)
@click.argument('pred_file')
@click.argument('meta_file')
@click.argument('output')
def generate(min_count, pred_file, meta_file, output):
    """Compute outcome statistics from a prediction/meta file pair and write
    them as JSON to output."""
    database = QuestionDatabase()
    data = load_data(pred_file, meta_file, database)
    dan_answers = set(database.page_by_count(min_count, True))
    answers = compute_answers(data, dan_answers)
    # cache() so the sequence can be both serialized and printed.
    stats = compute_statistics(answers).cache()
    stats.to_json(safe_path(output), root_array=False)
    pprint.pprint(stats)


if __name__ == '__main__':
    cli()