""" Usage:
    qa_to_oie --in=INPUT_FILE --out=OUTPUT_FILE --conll=CONLL_FILE [--dist=DIST_FILE] [--oieinput=OIE_INPUT] [-v]
"""
# NOTE: the module docstring above is handed to docopt at runtime (see the
# __main__ section), so it must contain nothing but the usage pattern.
from docopt import docopt
import re
import itertools
from oie_readers.extraction import Extraction, escape_special_chars, normalize_element
from collections import defaultdict
import logging
import operator
from fuzzywuzzy import process
from fuzzywuzzy.utils import full_process
import itertools
from fuzzywuzzy.string_processing import StringProcessor
from fuzzywuzzy.utils import asciidammit
from operator import itemgetter
import nltk
import json
import pdb
from oie_readers.extraction import QUESTION_TRG_INDEX
from oie_readers.extraction import QUESTION_PP_INDEX
from oie_readers.extraction import QUESTION_OBJ2_INDEX

## CONSTANTS

# NOTE: this rebinds the QUESTION_TRG_INDEX name imported just above.
QUESTION_TRG_INDEX = 3  # index of the predicate within the question
QUESTION_MODALITY_INDEX = 1  # index of the modality within the question


def PASS_ALL(x):
    """Mask function which keeps a question slot unchanged."""
    return x


def MASK_ALL(x):
    """Mask function which replaces a question slot with "_"."""
    return "_"


def get_default_mask():
    """Return a fresh mask of 8 slot functions, all of them PASS_ALL."""
    return [PASS_ALL] * 8


# QA-SRL vocabulary for "AUX" placement, which modifies the predicates.
# The commented-out entries are auxiliaries which are deliberately not
# treated as predicate modifiers.
QA_SRL_AUX_MODIFIERS = [
    # "are",
    "are n't",
    "can",
    "ca n't",
    "could",
    "could n't",
    # "did",
    "did n't",
    # "do",
    # "does",
    "does n't",
    "do n't",
    "had",
    "had n't",
    # "has",
    "has n't",
    # "have",
    "have n't",
    # "is",
    "is n't",
    "may",
    "may not",
    "might",
    "might not",
    "must",
    "must n't",
    "should",
    "should n't",
    # "was",
    "was n't",
    # "were",
    "were n't",
    "will",
    "wo n't",
    "would",
    "would n't",
]


class Qa2OIE:
    """
    Converts a QA-SRL annotation file into Open IE extractions, kept in
    self.dic as a mapping from sentence to a list of Extraction objects.
    """
    # Static variables
    extractions_counter = 0  # Running id handed out by encodeExtraction

    def __init__(self, qaFile, dist_file = ""):
        """
        Loads qa file and converts it into open IE.
        If a distribution file is given, it is used to determine the
        hopefully correct order of arguments.
        Otherwise, these are ordered according to their linearization.

        qaFile - path of the QA-SRL file to convert
        dist_file - optional path of a json file, mapping each question to
                    a {location: count} dictionary
        """
        if dist_file:
            # Close the file deterministically instead of leaking the handle
            with open(dist_file) as fin:
                raw_dist = json.load(fin)
            # json object keys are always strings - convert the locations
            # back to numerical indexes
            self.question_dist = dict(
                (q, dict((int(loc), cnt) for (loc, cnt) in dist.items()))
                for (q, dist) in raw_dist.items())
        else:
            self.question_dist = {}
        self.dic = self.loadFile(self.getExtractions(qaFile))

    @staticmethod
    def _set_question_indices(extractions, indsForQuestions):
        """
        Store a copy of the question -> sentence indices mapping on each of
        the given extractions.
        """
        for ex in extractions:
            ex.indsForQuestions = dict(indsForQuestions)

    def loadFile(self, lines):
        """
        Parse the textual output of getExtractions into a dictionary from
        sentence to the list of its Extraction objects.

        lines - a single string. A line without tabs starts a new sentence;
                any other line is a tab separated extraction of the form:
                pred, pred indices, head pred index, (question, answer)*
        """
        sent = ''
        d = {}
        indsForQuestions = defaultdict(set)

        for line in lines.split('\n'):
            line = line.strip()
            if not line:
                continue
            data = line.split('\t')
            if len(data) == 1:
                # A new sentence - close the previous one first
                if sent:
                    self._set_question_indices(d[sent], indsForQuestions)
                sent = line
                d[sent] = []
                indsForQuestions = defaultdict(set)
            else:
                pred = self.preproc(data[0])
                # NOTE(security): eval runs arbitrary code. data[1] is the
                # str() of a list as written by printSent, so this is only
                # safe as long as the QA-SRL input file is trusted.
                # A real list is built (rather than map) since the indices
                # are stored inside the predicate tuple.
                pred_indices = [int(ind) for ind in eval(data[1])]
                head_pred_index = int(data[2])
                cur = Extraction((pred, [pred_indices]),
                                 head_pred_index,
                                 sent,
                                 confidence = 1.0)
                for q, a in zip(data[3::2], data[4::2]):
                    preproc_arg = self.preproc(a)
                    if not preproc_arg:
                        logging.warning("Argument reduced to None: {}".format(a))
                    indices = fuzzy_match_phrase(preproc_arg.split(" "),
                                                 sent.split(" "))
                    cur.addArg((preproc_arg, indices), q)
                    indsForQuestions[q] = indsForQuestions[q].union(flatten(indices))

                # Extractions which appear before any sentence, or which do
                # not pass noPronounArgs, are dropped
                if sent and cur.noPronounArgs():
                    cur.resolveAmbiguity()
                    d[sent].append(cur)

        # The last sentence has no following sentence line to trigger the
        # update above, so close it explicitly
        if sent:
            self._set_question_indices(d[sent], indsForQuestions)
        return d

    def preproc(self, s):
        """
        Returns a unified preproc of a string:
          - Removes duplicates spaces, to allow for space delimited words.
        """
        return " ".join([w for w in s.split(" ") if w])

    def getExtractions(self, qa_srl_path, mask = get_default_mask()):
        """
        Parse a QA-SRL file (with raw sentences) at qa_srl_path.
        Returns output which can in turn serve as input for loadFile.

        The file is read as a sequence of records (lines starting with '#'
        are skipped):
          - a sentence id line, whose first tab separated field holds an
            integer after its first "_"
          - the sentence itself
          - any number of predicates, each a line of
            index <TAB> base predicate <TAB> number of questions
            followed by that many question lines, in which the last tab
            separated field holds the answers, joined by "###"
          - an empty line, which ends the record

        mask - list of functions, the i-th is applied to the i-th slot of
               every question (see encodeQuestion)
        """
        # lc tracks the position within a record: 0 = expecting the id line,
        # 1 = expecting the sentence, 2 = expecting a predicate line (or the
        # empty line), > 2 = this many minus 2 question lines are still due
        lc = 0
        sentQAs = []
        curAnswers = []
        curSent = ""
        parts = []  # Output chunks, joined once at the end

        with open(qa_srl_path, 'r') as fin:
            for line in fin:
                if line.startswith('#'):
                    continue
                line = line.strip()
                info = line.split("\t")
                if lc == 0:
                    # Read sentence ID. The value itself is not used, but a
                    # malformed id line still fails here.
                    sent_id = int(info[0].split("_")[1])
                    lc += 1
                elif lc == 1:
                    if curSent:
                        parts.append(self.printSent(curSent, sentQAs))
                    # Write sentence.
                    curSent = line
                    lc += 1
                    sentQAs = []
                elif lc == 2:
                    if curAnswers:
                        sentQAs.append(((surfacePred, predIndex, augmented_pred_indices),
                                        curAnswers))
                        curAnswers = []
                    # Update line counter.
                    if line == "":
                        lc = 0  # new line for new sent
                    else:
                        # reading predicate and qa pairs
                        predIndex, basePred, count = info
                        surfacePred = basePred
                        lc += int(count)
                elif lc > 2:
                    question = encodeQuestion("\t".join(info[:-1]), mask)
                    # Keep the longest predicate seen over all questions
                    curSurfacePred = augment_pred_with_question(basePred, question)
                    if len(curSurfacePred) > len(surfacePred):
                        surfacePred = curSurfacePred
                    answers = self.consolidate_answers(info[-1].split("###"))
                    # Materialized, so that the pairs can be safely re-read
                    curAnswers.append(list(zip([question] * len(answers), answers)))
                    lc -= 1
                    if lc == 2:
                        # Reached the end of this predicate's questions
                        # TODO: make sure that base pred is in the indices
                        #       returned by fuzzy matching
                        augmented_pred_indices = fuzzy_match_phrase(surfacePred.split(" "),
                                                                    curSent.split(" "))
                        if not augmented_pred_indices:
                            # Fall back to the annotated predicate index
                            augmented_pred_indices = [predIndex]
                        else:
                            augmented_pred_indices = augmented_pred_indices[0]
                        sentQAs.append(((surfacePred, predIndex, augmented_pred_indices),
                                        curAnswers))
                        curAnswers = []

        # Flush
        if sentQAs:
            parts.append(self.printSent(curSent, sentQAs))
        return "".join(parts)

    def printSent(self, sent, sentQAs):
        """
        Returns the textual representation of a sentence and its
        extractions: the sentence, one tab separated line per combination
        of answers (one answer per question), and a closing empty line.
        Every combination is also registered through encodeExtraction.
        """
        rows = [sent + "\n"]
        for (pred, head_pred_index, pred_indices), predQAs in sentQAs:
            prefix = [pred, str(pred_indices), str(head_pred_index)]
            for element in itertools.product(*predQAs):
                self.encodeExtraction(element)
                rows.append("\t".join(prefix + ["\t".join(x) for x in element]) + "\n")
        rows.append("\n")
        return "".join(rows)

    def encodeExtraction(self, element):
        """
        Register an extraction (a sequence of (question, answer) pairs) in
        the module level extractionsDic, keyed by its set of questions.
        Every entry holds (count, question set, list of extraction ids).
        """
        extractionSet = set(question for (question, _) in element)
        # NOTE(review): the repr of a set depends on its iteration order -
        # confirm that equal sets always produce the same key here.
        encoding = repr(extractionSet)
        (count, _, extractions) = extractionsDic.get(encoding, (0, extractionSet, []))
        extractions.append(Qa2OIE.extractions_counter)
        Qa2OIE.extractions_counter += 1
        extractionsDic[encoding] = (count + 1, extractionSet, extractions)

    def consolidate_answers(self, answers):
        """
        For a given list of answers, returns only minimal answers - e.g.,
        ones which do not contain any other answer in the set.
        This deals with certain QA-SRL annotations which include a longer
        span than that is needed.
        Identical answers are kept once (first occurrence), otherwise they
        would eliminate each other and leave the question unanswered.
        """
        unique_answers = []
        for answer in answers:
            if answer not in unique_answers:
                unique_answers.append(answer)

        return [answer
                for i, answer in enumerate(unique_answers)
                if not any((i != j) and is_str_subset(other, answer)
                           for j, other in enumerate(unique_answers))]

    def createOIEInput(self, fn):
        """ Write all of the sentences to file fn, one per line """
        with open(fn, 'w') as fout:
            for sent in self.dic:
                fout.write(sent + '\n')

    def writeOIE(self, fn):
        """
        Write all of the extractions to file fn, one per line:
        escaped sentence <TAB> string form of the extraction
        """
        with open(fn, 'w') as fout:
            for sent, extractions in self.dic.items():
                for ex in extractions:
                    fout.write('{}\t{}\n'.format(escape_special_chars(sent),
                                                 str(ex)))

    def writeConllFile(self, fn):
        """
        Write a conll representation of all of the extractions to file
        """
        running_index = 0  # Running index enumerates the predicates in the dataset
        # Add a header file identifying each column
        header = '\t'.join(["word_id", "word", "pred", "pred_id",
                            "head_pred_id", "sent_id", "run_id", "label"])
        with open(fn, 'w') as fout:
            fout.write(header + '\n')
            for sent_index, extractions in enumerate(self.dic.values()):
                for ex in extractions:
                    fout.write(ex.conll(external_feats = [sent_index, running_index]) + '\n')
                    running_index += 1


# MORE HELPER

def augment_pred_with_question(pred, question):
    """
    Decide what elements from the question to incorporate in the given
    corresponding predicate.
    question is expected to hold exactly 8 space separated slots, the last
    of which is the question mark.
    """
    # Parse question
    wh, aux, sbj, trg, obj1, pp, obj2 = map(normalize_element,
                                            question.split(' ')[:-1])  # Last split is the question mark

    # Add auxiliary to the predicate
    if aux in QA_SRL_AUX_MODIFIERS:
        return " ".join([aux, pred])

    # Non modified predicates
    return pred


def is_str_subset(s1, s2):
    """
    returns true iff the words in string s1 are contained in string s2
    in the same order by which they appear in s2
    """
    # Every membership test consumes the iterator up to (and including) the
    # matched word, so each word is searched only after the previous match.
    # This greedy scan finds an ordered match whenever one exists.
    remaining_words = iter(s2.split(" "))
    return all(word in remaining_words for word in s1.split())


def find_all_indices(ls, elem):
    """ Returns the indices of all of the occurrences of elem in ls """
    return [i for i, x in enumerate(ls) if x == elem]


def strictly_increasing(L):
    """ Returns true iff every item in L is smaller than its successor """
    return all(x < y for x, y in zip(L, L[1:]))


def is_consecutive(ls):
    """
    Returns true iff ls represents a list of consecutive numbers.
    """
    return all((y - x == 1) for x, y in zip(ls, ls[1:]))


# encoded question -> (id, number of appearances), see encodeQuestion
questionsDic = {}
# encoded question set -> (count, question set, extraction ids),
# see Qa2OIE.encodeExtraction
extractionsDic = {}


def encodeQuestion(question, mask):
    """
    Apply the i-th function of mask to the i-th tab separated slot of
    question and replace spaces within slots with underscores.
    Counts the masked question in questionsDic and returns its slots joined
    by single spaces.
    """
    info = [mask[i](x).replace(" ", "_")
            for i, x in enumerate(question.split("\t"))]
    encoding = "\t".join(info)
    # get the encoding of a question, and the count of times it appeared
    (val, count) = questionsDic.get(encoding, (len(questionsDic), 0))
    questionsDic[encoding] = (val, count + 1)
    return " ".join(info)


def all_index(s, ss, matchCase = True, ignoreSpaces = True):
    """
    Find all occurrences of substring ss in s.
    Returns the start offsets of the (non overlapping) matches. Note that
    with ignoreSpaces these offsets refer to s after its spaces were removed.
    """
    if not matchCase:
        s = s.lower()
        ss = ss.lower()

    if ignoreSpaces:
        s = s.replace(' ', '')
        ss = ss.replace(' ', '')

    return [m.start() for m in re.finditer(re.escape(ss), s)]


def fuzzy_match_phrase(phrase, sentence):
    """
    Fuzzy find the indexes of all word in phrase against a given sentence
    (both are lists of words).
    Returns a list of matches, each is a list of consecutive sentence
    indexes in the length of phrase. The list is empty when phrase wasn't
    found as a consecutive span.
    """
    logging.debug("Fuzzy searching \"{}\" in \"{}\"".format(" ".join(phrase),
                                                            " ".join(sentence)))
    # Floor division keeps the limit integral on Python 3 as well
    limit = min((len(phrase) // 2) + 1, 3)
    # A "n't" in the phrase may also be matched against "not"
    possible_indices = [fuzzy_match_word(w, sentence, limit)
                        + (fuzzy_match_word("not", sentence, limit)
                           if w == "n't"
                           else [])
                        for w in phrase]
    indices = find_consecutive_combinations(*possible_indices)
    if not indices:
        logging.debug("\t".join(map(str, ["*** {}".format(len(indices)),
                                          " ".join(phrase),
                                          " ".join(sentence),
                                          possible_indices,
                                          indices])))
    return indices


def find_consecutive_combinations(*lists):
    """
    Given a list of lists of integers, find only the consecutive options
    from the Cartesian product.
    Returns a list of walks, each takes one item from every list, such that
    every item is larger by exactly 1 from its predecessor.
    Returns an empty list when no lists are given.
    """
    ret = []
    if not lists:
        return ret

    desired_length = len(lists)  # this is the length of a valid walk
    logging.debug("desired length: {}".format(desired_length))
    for first_item in lists[0]:
        logging.debug("starting with {}".format(first_item))
        cur_walk = [first_item]
        for ls in lists[1:]:
            logging.debug("ls = {}".format(ls))
            next_item = cur_walk[-1] + 1
            if next_item not in ls:
                # Didn't find a valid candidate -
                # break out of this first item
                break
            logging.debug("Found match: {}".format(next_item))
            cur_walk.append(next_item)

        if len(cur_walk) == desired_length:
            ret.append(cur_walk)
    return ret


def fuzzy_match_word(word, words, limit):
    """
    Fuzzy find the indexes of word in words, returns a list of indexes which
    match the best return from fuzzy.
    Exact matches are always part of the returned list.
    limit controls the number of choices to allow.
    """
    # Try finding exact matches
    exact_matches = set(i for (i, w) in enumerate(words) if w == word)
    if exact_matches:
        logging.debug("Found exact match for {}".format(word))
    else:
        logging.debug("No exact match for: {}".format(word))

    # Fuzzy matches are added even when exact matches were found.
    # Allow some variance which extractOne misses
    # For example: "Armstrong World Industries Inc" in
    # "Armstrong World Industries Inc. agreed in principle to sell its
    # carpet operations to Shaw Industries Inc ."
    best_matches = [w for (w, s) in process.extract(word,
                                                    words,
                                                    processor = semi_process,
                                                    limit = limit)
                    if (s > 70)]
    logging.debug("Best matches = {}".format(best_matches))
    return list(exact_matches.union(i for (i, w) in enumerate(words)
                                    if w in best_matches))


def flatten(l):
    """ Flatten a list of lists """
    return [item for sublist in l for item in sublist]


def semi_process(s, force_ascii=False):
    """
    Variation on Fuzzywuzzy's full_process:
    Process string by
    XX removing all but letters and numbers --> These are kept to keep
       consecutive spans
    -- trim whitespace
    XX force to lower case --> These are kept since annotators marked
       verbatim spans, so case is a good signal
    if force_ascii == True, force convert to ascii
    """
    if s is None:
        return ""

    if force_ascii:
        s = asciidammit(s)

    # Remove leading and trailing whitespaces.
    return StringProcessor.strip(s)


## MAIN
if __name__ == '__main__':
    args = docopt(__doc__)
    logging.basicConfig(level = logging.DEBUG if args['-v'] else logging.INFO)
    logging.debug(args)

    inp = args['--in']
    out = args['--out']
    dist_file = args['--dist'] or ''

    q = Qa2OIE(inp, dist_file = dist_file)
    q.writeOIE(out)
    q.writeConllFile(args['--conll'])
    if args['--oieinput']:
        q.createOIEInput(args['--oieinput'])