import os
import re
import string
from collections import Counter

import nltk
import numpy as np


# -------------------- 2d spans -------------------
def get_2d_spans(text, tokenss):
    """Map each token in a list of token lists (e.g. per sentence) to its
    (start, end) character span in the original text."""
    spanss = []
    cur_idx = 0
    for tokens in tokenss:
        spans = []
        for token in tokens:
            found = text.find(token, cur_idx)
            if found < 0:
                raise ValueError("{} {} {}".format(token, cur_idx, text))
            cur_idx = found
            spans.append((cur_idx, cur_idx + len(token)))
            cur_idx += len(token)
        spanss.append(spans)
    return spanss


def get_word_span(context, wordss, start, stop):
    """Return the (sentence, token) indices covering the character span
    [start, stop): ((sent_start, token_start), (sent_stop, token_stop + 1))."""
    spanss = get_2d_spans(context, wordss)  # [[(start, end), ...], ...] -> char level
    idxs = []
    for sent_idx, spans in enumerate(spanss):
        for word_idx, span in enumerate(spans):
            # keep every token whose span overlaps [start, stop)
            if not (stop <= span[0] or start >= span[1]):
                idxs.append((sent_idx, word_idx))
    assert len(idxs) > 0, "{} {} {} {}".format(context, spanss, start, stop)
    return idxs[0], (idxs[-1][0], idxs[-1][1] + 1)


def get_word_idx(context, wordss, idx):
    """Return the start character offset of the token at idx = (sent_idx, word_idx)."""
    spanss = get_2d_spans(context, wordss)  # [[(start, end), ...], ...] -> char level
    return spanss[idx[0]][idx[1]][0]


# ----------------- 1d span -----------------------
def get_1d_spans(text, token_seq):
    """Map each token in a flat token sequence to its (start, end) character span."""
    spans = []
    cur_idx = 0
    for token in token_seq:
        token = token.replace('\xa0', ' ')  # non-breaking space -> plain space
        find_res = text.find(token, cur_idx)
        if find_res < 0:
            raise RuntimeError('{} {} {}'.format(token, cur_idx, text))
        cur_idx = find_res
        spans.append((cur_idx, cur_idx + len(token)))
        cur_idx += len(token)
    return spans


def get_word_idxs_1d(context, token_seq, char_start_idx, char_end_idx):
    """
    0-based.
    :param context: original text
    :param token_seq: tokenized context
    :param char_start_idx: answer start character offset
    :param char_end_idx: answer end character offset
    :return: 0-based token index sequence in the tokenized context.
    """
    spans = get_1d_spans(context, token_seq)
    idxs = []
    for word_idx, span in enumerate(spans):
        # keep every token whose span overlaps [char_start_idx, char_end_idx)
        if not (char_end_idx <= span[0] or char_start_idx >= span[1]):
            idxs.append(word_idx)
    assert len(idxs) > 0, "{} {} {} {}".format(context, token_seq, char_start_idx, char_end_idx)
    return idxs


def get_start_and_end_char_idx_for_word_idx_1d(context, token_seq, word_idx_seq):
    """
    0-based.
    :param context: original text
    :param token_seq: tokenized context
    :param word_idx_seq: token indices to cover
    :return: (start, end) character offsets covering the given tokens
    """
    spans = get_1d_spans(context, token_seq)
    correct_spans = [span for idx, span in enumerate(spans) if idx in word_idx_seq]
    return correct_spans[0][0], correct_spans[-1][-1]


# ----------------- for node target idx -----------------------
def calculate_idx_seq_f1_score(input_idx_seq, label_idx_seq, recall_factor=1.):
    """F1 between two index sequences; recall is divided by recall_factor
    before combining."""
    assert len(input_idx_seq) > 0 and len(label_idx_seq) > 0
    recall_counter = sum(1 for label_idx in label_idx_seq if label_idx in input_idx_seq)
    precision_counter = sum(1 for input_idx in input_idx_seq if input_idx in label_idx_seq)
    recall = 1.0 * recall_counter / len(label_idx_seq)
    precision = 1.0 * precision_counter / len(input_idx_seq)
    recall = recall / recall_factor
    if recall + precision <= 0.:
        return 0.
    else:
        return 2. * recall * precision / (recall + precision)

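
# --- usage sketch (added for illustration; not part of the original module) ---
# A minimal, hedged example of the 1d span helpers: map a character-level
# answer span to token indices and back. All inputs here are made up.
def _demo_1d_spans():
    context = 'The cat sat.'
    tokens = ['The', 'cat', 'sat', '.']
    # the character span of "cat" is [4, 7) -> token index 1
    token_idxs = get_word_idxs_1d(context, tokens, 4, 7)
    assert token_idxs == [1]
    # and back: token index 1 covers characters [4, 7)
    assert get_start_and_end_char_idx_for_word_idx_1d(context, tokens, token_idxs) == (4, 7)
    # overlap F1 between predicted and gold token index sequences
    assert calculate_idx_seq_f1_score([1], [1]) == 1.0
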

def get_best_node_idx(node_and_leaf_pair, answer_token_idx_seq, recall_factor=1.):
    """
    All indices in this function are 1-based.
    :param node_and_leaf_pair: list of (node_idx, leaf_idx_seq) pairs
    :param answer_token_idx_seq: token indices of the answer
    :return: node_idx whose leaves best match the answer by F1
    """
    f1_scores = []
    for node_idx, leaf_idx_seq in node_and_leaf_pair:
        f1_scores.append(calculate_idx_seq_f1_score(leaf_idx_seq, answer_token_idx_seq, recall_factor))
    max_idx = np.argmax(f1_scores)
    return node_and_leaf_pair[max_idx][0]


# ------------------ calculate text f1 -------------------
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    def tokenize(text):
        return ' '.join(nltk.word_tokenize(text))

    return white_space_fix(remove_articles(remove_punc(lower(tokenize(s)))))


def f1_score(prediction, ground_truth):
    """Token-level F1 between the normalized prediction and ground truth."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def check_rebuild_quality(prediction, ground_truth):
    em = exact_match_score(prediction, ground_truth)
    f1 = f1_score(prediction, ground_truth)
    return em, f1


def dynamic_length(lengthList, ratio, add=None, security=True, fileName=None):
    """Pick the smallest length whose cumulative frequency covers `ratio` of
    the samples; returns (chosen_length, max_length)."""
    ratio = float(ratio)
    if add is not None:
        ratio += add
    ratio = ratio if ratio < 1 else 1
    if security:
        ratio = ratio if ratio < 0.99 else 0.99

    def calculate_dynamic_len(pdf, ratio_=ratio):
        cdf = []
        previous = 0
        # accumulate frequencies into a cdf
        for length, freq in pdf:
            previous += freq
            cdf.append((length, previous))
        # first length whose cumulative share reaches the ratio
        for length, accu in cdf:
            if 1.0 * accu / previous >= ratio_:
                return length, cdf[-1][0]  # (chosen, max)
        return cdf[-1][0], cdf[-1][0]

    pdf = dict(nltk.FreqDist(lengthList))
    pdf = sorted(pdf.items(), key=lambda d: d[0])
    if fileName is not None:
        with open(fileName, 'w') as f:
            for length, freq in pdf:
                f.write('%d\t%d' % (length, freq))
                f.write(os.linesep)
    return calculate_dynamic_len(pdf, ratio)


def dynamic_keep(collect, ratio, fileName=None):
    """Keep the most frequent tokens whose cumulative frequency covers `ratio`."""
    pdf = dict(nltk.FreqDist(collect))
    pdf = sorted(pdf.items(), key=lambda d: d[1], reverse=True)
    cdf = []
    previous = 0
    # accumulate frequencies into a cdf
    for token, freq in pdf:
        previous += freq
        cdf.append((token, previous))
    # index of the last token needed to reach the ratio
    keepAnchor = 0
    for idx, (token, accu) in enumerate(cdf):
        keepAnchor = idx
        if 1.0 * accu / previous >= ratio:
            break
    tokenList = []
    for idx, (token, freq) in enumerate(pdf):
        if idx > keepAnchor:
            break
        tokenList.append(token)
    if fileName is not None:
        with open(fileName, 'w') as f:
            for idx, (token, freq) in enumerate(pdf):
                f.write('%s\t%d' % (token, freq))  # tokens are strings, not ints
                f.write(os.linesep)
                if idx == keepAnchor:
                    print(os.linesep * 20)  # debug marker at the keep/drop boundary
    return tokenList


def gene_question_explicit_class_tag(question_token):
    """Tag a question with the index of the first matching wh-word/be-verb class."""
    classes = ['what', 'how', 'who', 'when', 'which', 'where', 'why', 'whom', 'whose',
               ['am', 'is', 'are', 'was', 'were']]
    question_token = [token.lower() for token in question_token]
    for idx_c, cls in enumerate(classes):
        if not isinstance(cls, list):
            if cls in question_token:
                return idx_c
        else:
            for ccls in cls:
                if ccls in question_token:
                    return idx_c
    return len(classes)

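
# --- usage sketch (added for illustration; not part of the original module) ---
# Hedged examples for the text-F1 and dynamic-length helpers. normalize_answer
# calls nltk.word_tokenize, so this assumes the NLTK 'punkt' data is installed
# (nltk.download('punkt')). All inputs here are made up.
def _demo_text_f1_and_dynamic_length():
    # articles, case and punctuation are normalized away -> exact match
    em, f1 = check_rebuild_quality('The cat sat.', 'cat sat')
    assert em is True and f1 == 1.0
    # lengths [1, 2, 2, 3, 3, 3, 10]: length 3 covers 6/7 >= 0.8 of the samples
    chosen, max_len = dynamic_length([1, 2, 2, 3, 3, 3, 10], 0.8)
    assert (chosen, max_len) == (3, 10)
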

def gene_token_freq_info(context_token, question_token):
    """For each token, record its frequency in the context and in the question."""

    def look_up_dict(t_dict, t):
        try:
            return t_dict[t]
        except KeyError:
            return 0

    context_token_dict = dict(nltk.FreqDist(context_token))
    question_token_dict = dict(nltk.FreqDist(question_token))
    # context tokens in context and question dicts
    context_tf = []
    for token in context_token:
        context_tf.append((look_up_dict(context_token_dict, token),
                           look_up_dict(question_token_dict, token)))
    # question tokens in context and question dicts
    question_tf = []
    for token in question_token:  # was context_token; the question tokens belong here
        question_tf.append((look_up_dict(context_token_dict, token),
                            look_up_dict(question_token_dict, token)))
    return {'context': context_tf, 'question': question_tf}
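

# --- usage sketch (added for illustration; not part of the original module) ---
# Hedged example of the question-feature helpers; inputs are made up.
def _demo_question_features():
    # 'what' is class 0 in gene_question_explicit_class_tag's class list
    assert gene_question_explicit_class_tag(['What', 'is', 'X', '?']) == 0
    # each token is paired with its (context frequency, question frequency)
    tf = gene_token_freq_info(['the', 'cat', 'the'], ['cat'])
    assert tf['context'] == [(2, 0), (1, 1), (2, 0)]
    assert tf['question'] == [(1, 1)]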