import json
import os
import pickle
import sys
from concurrent.futures import ProcessPoolExecutor

# module-level import so that worker processes (see process() below) and library
# users can resolve this name; it was previously only bound under __main__
from docqa.triviaqa.evidence_corpus import TriviaQaEvidenceCorpusTxt

from jack.util.preprocessing import sort_by_tfidf


def extract_support(triviaqa_question, docs, corpus, max_num_support, max_tokens):
    """Collect support paragraphs and character-level answer spans for a question.

    Paragraphs are optionally merged/split to at most max_tokens tokens and, if
    more than max_num_support remain, pruned by TF-IDF similarity to the
    question (keeping at least one answer-bearing paragraph when answers exist).
    """
    answers = []
    supports = []
    paragraph_tokens = []
    separator = '$|$'
    for doc in docs:
        doc_tokens = corpus.get_document(doc.doc_id)
        doc_tokens_flat = [t for p in doc_tokens for s in p for t in s]
        doc_paragraph_tokens = [[t for s in p for t in s] for p in doc_tokens]
        # merge many small paragraphs (and split overlong ones) up to max_tokens
        if max_tokens > 0:
            new_paragraph_tokens = [[]]
            for s in doc_paragraph_tokens:
                if len(new_paragraph_tokens[-1]) + len(s) >= max_tokens and len(new_paragraph_tokens[-1]) > 0:
                    # start a new paragraph
                    if len(s) >= max_tokens:
                        # split an overlong paragraph into max_tokens-sized chunks
                        while s:
                            new_paragraph_tokens.append(s[:max_tokens])
                            s = s[max_tokens:]
                    else:
                        new_paragraph_tokens.append(s)
                else:
                    # merge with the most recent paragraph
                    if len(new_paragraph_tokens[-1]) > 0:
                        new_paragraph_tokens[-1].append(separator)
                    new_paragraph_tokens[-1].extend(s)
        else:
            new_paragraph_tokens = doc_paragraph_tokens
        paragraph_tokens.extend(new_paragraph_tokens)
        # paragraph index of every non-separator token, in document order
        p_idx_flat = [i for i, p in enumerate(new_paragraph_tokens) for t in p if t != separator]
        assert len(doc_tokens_flat) == len(p_idx_flat)
        doc_idx_offset = len(supports)
        supports.extend(" ".join(s) for s in new_paragraph_tokens)
        # character start offsets of each support string (unused below)
        support_offsets = [0]
        for s in supports[doc_idx_offset:]:
            support_offsets.append(support_offsets[-1] + len(s) + 1)
        if doc.answer_spans is not None:
            for flat_s, flat_e in doc.answer_spans:
                p_idx = p_idx_flat[flat_s]
                # token offset of the answer within its paragraph, ignoring separators
                s = flat_s - sum(1 for p2 in new_paragraph_tokens[:p_idx] for t in p2 if t != separator)
                p = new_paragraph_tokens[p_idx]
                k = 0
                char_s = 0
                while s > k:
                    if p[k] == separator:
                        s += 1  # separator tokens were not counted in s above
                    char_s += len(p[k]) + 1
                    k += 1
                char_e = char_s + sum(len(t) + 1 for t in doc_tokens_flat[flat_s:flat_e + 1]) - 1
                answers.append({
                    "text": " ".join(doc_tokens_flat[flat_s:flat_e + 1]),
                    "span": [char_s, char_e],
                    "doc_idx": p_idx + doc_idx_offset
                })
        del doc_tokens, p_idx_flat, doc_tokens_flat
    if max_num_support > 0 and len(supports) > max_num_support:
        # rank paragraphs by TF-IDF similarity to the question and keep the best
        sorted_supports = sort_by_tfidf(" ".join(triviaqa_question.question),
                                        [' '.join(p) for p in paragraph_tokens])
        sorted_supports = [i for i, _ in sorted_supports]
        sorted_supports_rev = {v: k for k, v in enumerate(sorted_supports)}
        if answers:
            min_answer_rev = min(sorted_supports_rev[a['doc_idx']] for a in answers)
            if min_answer_rev >= max_num_support:
                # force at least one answer-bearing paragraph into the kept set by
                # swapping it with the n-th best paragraph
                min_answer = sorted_supports[min_answer_rev]
                old_nth_best = sorted_supports[max_num_support - 1]
                sorted_supports[min_answer_rev] = old_nth_best
                sorted_supports[max_num_support - 1] = min_answer
                sorted_supports_rev[old_nth_best] = min_answer_rev
                sorted_supports_rev[min_answer] = max_num_support - 1
                sorted_supports_rev = {v: k for k, v in enumerate(sorted_supports)}
        supports = [supports[i] for i in sorted_supports[:max_num_support]]
        is_an_answer = len(answers) > 0
        answers = [a for a in answers if sorted_supports_rev[a['doc_idx']] < max_num_support]
        for a in answers:
            a['doc_idx'] = sorted_supports_rev[a['doc_idx']]
        assert not is_an_answer or len(answers) > 0
    return supports, answers
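
# A sketch of what extract_support returns, on hypothetical inputs (tokens and
# offsets invented for illustration). Suppose two paragraphs are kept, the
# second merged from two smaller ones via the '$|$' separator:
#
#   supports == ["The quick brown fox", "It jumped $|$ over the lazy dog"]
#   answers  == [{"text": "brown fox", "span": [10, 19], "doc_idx": 0}]
#
# "span" is a character range into supports[answer["doc_idx"]] with an
# exclusive end (supports[0][10:19] == "brown fox"); the while-loop above
# shifts character offsets past any separator tokens introduced by merging.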
def convert_triviaqa(triviaqa_question, corpus, max_num_support, max_tokens, is_web):
    """Convert a single TriviaQA question into jack-style QA instances."""
    question = " ".join(triviaqa_question.question)
    if is_web:
        # web setting: one instance per evidence document
        for doc in triviaqa_question.web_docs:
            supports, answers = extract_support(triviaqa_question, [doc], corpus, max_num_support, max_tokens)
            filename = corpus.file_id_map[doc.doc_id]
            question_id = triviaqa_question.question_id + '--' + filename[4:] + ".txt"
            yield {"questions": [{"answers": answers, "question": {"text": question, "id": question_id}}],
                   "support": supports}
        for doc in triviaqa_question.entity_docs:
            supports, answers = extract_support(triviaqa_question, [doc], corpus, max_num_support, max_tokens)
            question_id = triviaqa_question.question_id + '--' + doc.title.replace(' ', '_') + ".txt"
            yield {"questions": [{"answers": answers, "question": {"text": question, "id": question_id}}],
                   "support": supports}
    else:
        # wikipedia setting: one instance per question, over all entity documents
        question_id = triviaqa_question.question_id
        supports, answers = extract_support(triviaqa_question, triviaqa_question.entity_docs, corpus,
                                            max_num_support, max_tokens)
        yield {"questions": [{"answers": answers, "question": {"text": question, "id": question_id}}],
               "support": supports}


def process(x, verbose=False):
    """Convert a chunk of questions; also runs inside worker processes."""
    dataset, filemap, max_num_support, max_tokens, is_web = x
    instances = []
    corpus = TriviaQaEvidenceCorpusTxt(filemap)
    for i, q in enumerate(dataset):
        if verbose and i % 1000 == 0:
            print("%d/%d done" % (i, len(dataset)))
        instances.extend(convert_triviaqa(q, corpus, max_num_support, max_tokens, is_web))
    return instances


def convert_dataset(path, filemap, name, num_processes, max_num_support, max_tokens, is_web=True):
    with open(path, 'rb') as f:
        dataset = pickle.load(f)
    if num_processes == 1:
        instances = process((dataset, filemap, max_num_support, max_tokens, is_web), True)
    else:
        chunk_size = 1000
        instances = []
        i = 0
        with ProcessPoolExecutor(num_processes) as executor:
            for processed in executor.map(
                    process,
                    [(dataset[i * chunk_size:(i + 1) * chunk_size], filemap, max_num_support, max_tokens, is_web)
                     for i in range(len(dataset) // chunk_size + 1)]):
                instances.extend(processed)
                i += chunk_size
                print("%d/%d done" % (min(len(dataset), i), len(dataset)))
    return {"meta": {"source": name}, 'instances': instances}


if __name__ == '__main__':
    dataset = sys.argv[1]
    num_processes = int(sys.argv[2]) if len(sys.argv) > 2 else 1
    max_paragraphs = int(sys.argv[3]) if len(sys.argv) > 3 else -1
    max_tokens = int(sys.argv[4]) if len(sys.argv) > 4 else -1

    triviaqa_prepro = os.environ['TRIVIAQA_HOME'] + '/preprocessed'
    is_web = dataset.startswith('web')
    dataset, split = dataset.split('-')
    ds = os.path.join(triviaqa_prepro, 'triviaqa/', dataset)
    with open(ds + "/file_map.json") as f:
        filemap = json.load(f)
    fn = '%s-%s.json' % (dataset, split)
    print("Converting %s..." % fn)
    new_ds = convert_dataset(os.path.join(ds, split + '.pkl'), filemap, fn, num_processes,
                             max_paragraphs, max_tokens, is_web)
    with open('data/triviaqa/%s' % fn, 'w') as f:
        json.dump(new_ds, f)
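
# Example invocation (hypothetical paths and script name, not a documented CLI):
# assumes TRIVIAQA_HOME points at a directory containing
# preprocessed/triviaqa/<dataset>/{<split>.pkl, file_map.json}, and that
# data/triviaqa/ exists for the output:
#
#   TRIVIAQA_HOME=/data/triviaqa python convert_triviaqa.py web-train 8 20 400
#
# i.e. dataset "web", split "train", 8 worker processes, at most 20 support
# paragraphs per instance, and paragraphs merged/split to at most ~400 tokens.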