"""Graph-distance baseline for multiple-choice questions about annotated images.

For every image the script builds an undirected graph whose nodes are the
normalized text labels of linked objects, then answers each question by
picking the choice whose first matching graph node is closest to the first
matching graph node of the question.
"""
import argparse
import json
import re
from os import listdir, path
from random import randint

import networkx as nx
from nltk.stem import PorterStemmer

from utils import get_pbar


def _get_args():
    """Parse the command line: positional ``data_dir`` and ``fold_path``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir")
    parser.add_argument("fold_path")
    return parser.parse_args()


def _tokenize(raw):
    """Return every maximal run of word characters in *raw*, in order."""
    return re.findall(r"[\w]+", raw)


# When true, _normalize() applies the Porter stemmer after lowercasing.
stem = True
stemmer = PorterStemmer()


def _normalize(word):
    """Lowercase *word* and, if the module flag ``stem`` is set, stem it."""
    word = word.lower()
    if stem:
        word = stemmer.stem(word)
    return word


def _load_json(file_path):
    """Parse the JSON file at *file_path*, closing the handle afterwards."""
    with open(file_path, "r") as handle:
        return json.load(handle)


def load_all(data_dir):
    """Load annotations and non-ABC-label questions for every image.

    Image ids are the stems of the ``*.png`` files in ``<data_dir>/images``,
    processed in ascending integer order (so every stem must parse as an int).
    An image is kept only when both ``annotations/<id>.png.json`` and
    ``questions/<id>.png.json`` exist.

    Returns:
        A 4-tuple ``(anno_dict, questions_dict, choicess_dict, answers_dict)``,
        each keyed by image id. The last three hold parallel lists: question
        text, the list of answer texts, and the ``correctAnswer`` value.
    """
    annos_dir = path.join(data_dir, 'annotations')
    images_dir = path.join(data_dir, 'images')
    questions_dir = path.join(data_dir, 'questions')

    anno_dict = {}
    questions_dict = {}
    choicess_dict = {}
    answers_dict = {}

    image_ids = sorted(
        (path.splitext(name)[0] for name in listdir(images_dir)
         if name.endswith(".png")),
        key=int,
    )

    pbar = get_pbar(len(image_ids)).start()
    for i, image_id in enumerate(image_ids):
        json_name = "%s.png.json" % image_id
        anno_path = path.join(annos_dir, json_name)
        ques_path = path.join(questions_dir, json_name)
        if path.exists(anno_path) and path.exists(ques_path):
            anno = _load_json(anno_path)
            ques = _load_json(ques_path)

            questions = []
            choicess = []
            answers = []
            for question, d in ques['questions'].items():
                # Questions flagged with 'abcLabel' are skipped.
                if not d['abcLabel']:
                    questions.append(question)
                    choicess.append(d['answerTexts'])
                    answers.append(d['correctAnswer'])

            questions_dict[image_id] = questions
            choicess_dict[image_id] = choicess
            answers_dict[image_id] = answers
            anno_dict[image_id] = anno
        pbar.update(i)
    pbar.finish()
    return anno_dict, questions_dict, choicess_dict, answers_dict


def _get_val(anno, key):
    """Resolve an annotation key to a normalized text label.

    Keys starting with ``'T'`` are looked up in ``anno['text']`` and their
    ``'value'`` is normalized. Keys starting with ``'O'`` are looked up in
    ``anno['objects']`` and resolved through the first entry of the object's
    ``'text'`` list; an object without text yields ``None``.

    Raises:
        ValueError: if *key* starts with anything other than ``'T'`` or ``'O'``.
    """
    first = key[0]
    if first == 'T':
        return _normalize(anno['text'][key]['value'])
    if first == 'O':
        texts = anno['objects'][key].get('text')
        if texts:
            return _get_val(anno, texts[0])
        return None
    raise ValueError("unsupported annotation key: %r" % (key,))


def create_graph(anno):
    """Build an undirected graph linking the labels of related objects.

    One edge is added per ``objectToObject`` entry of
    ``anno['relationships']['interObject']['linkage']`` whose first origin
    and first destination both resolve to a non-empty label. An annotation
    without that linkage section produces an empty graph.
    """
    graph = nx.Graph()
    try:
        linkage = anno['relationships']['interObject']['linkage']
    except (KeyError, TypeError):
        # Section missing (or an intermediate value is not a mapping).
        return graph

    for link in linkage.values():
        if link['category'] == 'objectToObject':
            dest = _get_val(anno, link['destination'][0])
            orig = _get_val(anno, link['origin'][0])
            if dest and orig:
                graph.add_edge(dest, orig)
    return graph


def find_node(graph, text):
    """Return the first normalized token of *text* that is a node of *graph*.

    Returns ``None`` when no token matches.
    """
    for word in _tokenize(text):
        node = _normalize(word)
        if node in graph:
            return node
    return None


def guess(graph, question, choices):
    """Pick the index of the choice closest to the question in *graph*.

    Each choice gets a distance:

    * ``MAX`` if the choice matches no node, or no path joins it to the
      question's node;
    * ``SUBMAX`` if the choice matches a node but the question does not;
    * otherwise the number of nodes on the shortest path between them.

    Returns:
        The index of the first choice with the smallest distance, or ``None``
        when there are no choices, when the best distance is ``MAX``, or when
        all distances are equal (nothing distinguishes the choices).
    """
    MAX = 9999
    SUBMAX = 999

    ques_node = find_node(graph, question)
    dists = []
    for choice in choices:
        choice_node = find_node(graph, choice)
        if choice_node is None:
            dist = MAX
        elif ques_node is None:
            dist = SUBMAX
        else:
            try:
                # +1 converts an edge count into the node count of the path.
                dist = nx.shortest_path_length(graph, ques_node, choice_node) + 1
            except nx.NetworkXNoPath:
                dist = MAX
        dists.append(dist)

    if not dists:
        return None

    answer, dist = min(enumerate(dists), key=lambda pair: pair[1])
    if dist == MAX or dist == max(dists):
        return None
    return answer


def evaluate(anno_dict, questions_dict, choicess_dict, answers_dict):
    """Run guess() on every question and print accuracy and precision.

    Questions for which guess() returns ``None`` are counted in ``guessed``
    and credited 0.25 each in the expected accuracy (presumably chance level
    for four choices -- confirm against the dataset). Precision is computed
    over the questions that did receive an answer. A ratio with a zero
    denominator is reported as NaN.
    """
    total = 0
    correct = 0
    incorrect = 0
    guessed = 0

    pbar = get_pbar(len(anno_dict)).start()
    for i, (image_id, anno) in enumerate(anno_dict.items()):
        graph = create_graph(anno)
        questions = questions_dict[image_id]
        choicess = choicess_dict[image_id]
        answers = answers_dict[image_id]
        for question, choices, answer in zip(questions, choicess, answers):
            total += 1
            a = guess(graph, question, choices)
            if a is None:
                guessed += 1
            elif answer == a:
                correct += 1
            else:
                incorrect += 1
        pbar.update(i)
    pbar.finish()

    undefined = float('nan')
    attempted = correct + incorrect
    accuracy = (0.25 * guessed + correct) / total if total else undefined
    precision = float(correct) / attempted if attempted else undefined

    print("expected accuracy: (0.25 * %d + %d)/%d = %.4f"
          % (guessed, correct, total, accuracy))
    print("precision: %d/%d = %.4f" % (correct, attempted, precision))


def select(fold_path, *all_):
    """Restrict each dict in *all_* to the ``'test'`` ids of a fold file.

    Returns a list with one filtered dict per input dict, in the same order.
    Ids listed in the fold but absent from a dict are skipped.
    """
    test_ids = _load_json(fold_path)['test']
    new_all = []
    for each in all_:
        new_all.append({id_: each[id_] for id_ in test_ids if id_ in each})
    return new_all


def main():
    """Load the data, keep the fold's test split, and evaluate on it."""
    args = _get_args()
    all_ = load_all(args.data_dir)
    selected = select(args.fold_path, *all_)
    evaluate(*selected)


if __name__ == "__main__":
    main()