"""Conll parser""" import re import argparse import time import os import io import pickle import spacy import numpy as np from tqdm import tqdm from neuralcoref.train.compat import unicode_ from neuralcoref.train.document import ( Mention, Document, Speaker, EmbeddingExtractor, MISSING_WORD, extract_mentions_spans, ) from neuralcoref.train.utils import parallel_process PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) REMOVED_CHAR = ["/", "%", "*"] NORMALIZE_DICT = { "/.": ".", "/?": "?", "-LRB-": "(", "-RRB-": ")", "-LCB-": "{", "-RCB-": "}", "-LSB-": "[", "-RSB-": "]", } CONLL_GENRES = {"bc": 0, "bn": 1, "mz": 2, "nw": 3, "pt": 4, "tc": 5, "wb": 6} FEATURES_NAMES = [ "mentions_features", # 0 "mentions_labels", # 1 "mentions_pairs_length", # 2 "mentions_pairs_start_index", # 3 "mentions_spans", # 4 "mentions_words", # 5 "pairs_ant_index", # 6 "pairs_features", # 7 "pairs_labels", # 8 "locations", # 9 "conll_tokens", # 10 "spacy_lookup", # 11 "doc", # 12 ] MISSED_MENTIONS_FILE = os.path.join( PACKAGE_DIRECTORY, "test_mentions_identification.txt" ) SENTENCES_PATH = os.path.join(PACKAGE_DIRECTORY, "test_sentences.txt") ################### ### UTILITIES ##### def clean_token(token): cleaned_token = token if cleaned_token in NORMALIZE_DICT: cleaned_token = NORMALIZE_DICT[cleaned_token] if cleaned_token not in REMOVED_CHAR: for char in REMOVED_CHAR: cleaned_token = cleaned_token.replace(char, "") if len(cleaned_token) == 0: cleaned_token = "," return cleaned_token def mention_words_idx(embed_extractor, mention, debug=False): # index of the word in the tuned embeddings no need for normalizing, # it is already performed in set_mentions_features() # We take them in the tuned vocabulary which is a smaller voc tailored from conll words = [] for _, w in sorted(mention.words_embeddings_.items()): if w not in embed_extractor.tun_idx: if debug: print( "No matching tokens in tuned voc for word ", w, "surrounding or inside mention", mention, ) words.append(MISSING_WORD) else: words.append(w) return [embed_extractor.tun_idx[w] for w in words] def check_numpy_array(feature, array, n_mentions_list, compressed=True): for n_mentions in n_mentions_list: if feature == FEATURES_NAMES[0]: assert array.shape[0] == len(n_mentions) if compressed: assert np.array_equiv( array[:, 3], np.array([len(n_mentions)] * len(n_mentions)) ) assert np.max(array[:, 2]) == len(n_mentions) - 1 assert np.min(array[:, 2]) == 0 elif feature == FEATURES_NAMES[1]: assert array.shape[0] == len(n_mentions) elif feature == FEATURES_NAMES[2]: assert array.shape[0] == len(n_mentions) assert np.array_equiv(array[:, 0], np.array(list(range(len(n_mentions))))) elif feature == FEATURES_NAMES[3]: assert array.shape[0] == len(n_mentions) assert np.array_equiv( array[:, 0], np.array([p * (p - 1) / 2 for p in range(len(n_mentions))]) ) elif feature == FEATURES_NAMES[4]: assert array.shape[0] == len(n_mentions) elif feature == FEATURES_NAMES[5]: assert array.shape[0] == len(n_mentions) elif feature == FEATURES_NAMES[6]: assert array.shape[0] == len(n_mentions) * (len(n_mentions) - 1) / 2 assert np.max(array) == len(n_mentions) - 2 elif feature == FEATURES_NAMES[7]: if compressed: assert array.shape[0] == len(n_mentions) * (len(n_mentions) - 1) / 2 assert np.max(array[:, 7]) == len(n_mentions) - 2 assert np.min(array[:, 7]) == 0 elif feature == FEATURES_NAMES[8]: assert array.shape[0] == len(n_mentions) * (len(n_mentions) - 1) / 2 ############################################################################################### ### PARALLEL 
###############################################################################################
### PARALLEL FCT (has to be at top-level of the module to be pickled for multiprocessing) #####


def load_file(full_name, debug=False):
    """
    Load a *._conll file
    Input: full_name: path to the file
    Output: list of tuples for each conll doc in the file, where the tuple contains:
        (utts_text ([str]): list of the utterances in the document
         utts_tokens ([[str]]): list of the tokens (conll words) in the document
         utts_corefs: list of coref objects (dicts) with the following properties:
            coref['label']: id of the coreference cluster,
            coref['start']: start index (index of first token in the utterance),
            coref['end']: end index (index of last token in the utterance).
         utts_speakers ([str]): list of the speaker associated with each utterance in the document
         name (str): name of the document
         part (str): part of the document
        )
    """
    docs = []
    with io.open(full_name, "rt", encoding="utf-8", errors="strict") as f:
        lines = list(f)  # .readlines()
        utts_text = []
        utts_tokens = []
        utts_corefs = []
        utts_speakers = []
        tokens = []
        corefs = []
        index = 0
        speaker = ""
        name = ""
        part = ""
        for li, line in enumerate(lines):
            cols = line.split()
            if debug:
                print("line", li, "cols:", cols)
            # End of utterance
            if len(cols) == 0:
                if tokens:
                    if debug:
                        print("End of utterance")
                    utts_text.append("".join(t + " " for t in tokens))
                    utts_tokens.append(tokens)
                    utts_speakers.append(speaker)
                    utts_corefs.append(corefs)
                    tokens = []
                    corefs = []
                    index = 0
                    speaker = ""
                    continue
            # End of doc
            elif len(cols) == 2:
                if debug:
                    print("End of doc")
                if cols[0] == "#end":
                    if debug:
                        print("Saving doc")
                    docs.append(
                        (utts_text, utts_tokens, utts_corefs, utts_speakers, name, part)
                    )
                    utts_text = []
                    utts_tokens = []
                    utts_corefs = []
                    utts_speakers = []
                else:
                    raise ValueError("Error on end line " + line)
            # New doc
            elif len(cols) == 5:
                if debug:
                    print("New doc")
                if cols[0] == "#begin":
                    name = re.match(r"\((.*)\);", cols[2]).group(1)
                    try:
                        part = cols[4]
                    except ValueError:
                        print("Error parsing document part " + line)
                    if debug:
                        print("New doc", name, part, name[:2])
                    tokens = []
                    corefs = []
                    index = 0
                else:
                    raise ValueError("Error on begin line " + line)
            # Inside utterance
            elif len(cols) > 7:
                if debug:
                    print("Inside utterance")
                assert cols[0] == name and int(cols[1]) == int(part), (
                    "Doc name or part error " + line
                )
                assert int(cols[2]) == index, "Index error on " + line
                if speaker:
                    assert cols[9] == speaker, "Speaker changed in " + line + speaker
                else:
                    speaker = cols[9]
                    if debug:
                        print("speaker", speaker)
                if cols[-1] != "-":
                    coref_expr = cols[-1].split("|")
                    if debug:
                        print("coref_expr", coref_expr)
                    if not coref_expr:
                        raise ValueError("Coref expression empty " + line)
                    for tok in coref_expr:
                        if debug:
                            print("coref tok", tok)
                        try:
                            match = re.match(r"^(\(?)(\d+)(\)?)$", tok)
                        except Exception:
                            print("error getting coreferences for line " + line)
                        assert match is not None, (
                            "Error parsing coref " + tok + " in " + line
                        )
                        num = match.group(2)
                        assert num != "", (
                            "Error parsing coref " + tok + " in " + line
                        )
                        if match.group(1) == "(":
                            if debug:
                                print("New coref", num)
                            corefs.append({"label": num, "start": index, "end": None})
                        if match.group(3) == ")":
                            j = None
                            for i in range(len(corefs) - 1, -1, -1):
                                if debug:
                                    print("i", i)
                                if (
                                    corefs[i]["label"] == num
                                    and corefs[i]["end"] is None
                                ):
                                    j = i
                                    break
                            assert j is not None, "coref closing error " + line
                            if debug:
                                print("End coref", num)
                            corefs[j]["end"] = index
                tokens.append(clean_token(cols[3]))
                index += 1
            else:
                raise ValueError("Line not standard " + line)
    return docs
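
# Minimal usage sketch (hypothetical file name, not executed on import): each
# element returned by load_file() is one document of a *_conll file, already
# split into utterances.
#
#     docs = load_file("sample.v4_gold_conll")
#     for utts_text, utts_tokens, utts_corefs, utts_speakers, name, part in docs:
#         print(name, part, len(utts_text), "utterances")
#         print(utts_corefs[0])  # e.g. [{"label": "3", "start": 0, "end": 2}, ...]
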
def set_feats(doc):
    doc.set_mentions_features()


def get_feats(doc, i):
    return doc.get_feature_array(doc_id=i)


def gather_feats(gathering_array, array, feat_name, pairs_ant_index, pairs_start_index):
    if gathering_array is None:
        gathering_array = array
    else:
        if feat_name == FEATURES_NAMES[6]:
            array = [a + pairs_ant_index for a in array]
        elif feat_name == FEATURES_NAMES[3]:
            array = [a + pairs_start_index for a in array]
        gathering_array += array
    return feat_name, gathering_array


def read_file(full_name):
    doc = ""
    with io.open(full_name, "rt", encoding="utf-8", errors="strict") as f:
        doc = f.read()
    return doc


###################
### ConllDoc #####


class ConllDoc(Document):
    def __init__(self, name, part, *args, **kwargs):
        self.name = name
        self.part = part
        self.feature_matrix = {}
        self.conll_tokens = []
        self.conll_lookup = []
        self.gold_corefs = []
        self.missed_gold = []
        super(ConllDoc, self).__init__(*args, **kwargs)

    def get_conll_spacy_lookup(self, conll_tokens, spacy_tokens, debug=False):
        """
        Compute a look up table between spacy tokens (from spacy tokenizer)
        and conll pre-tokenized tokens
        Output: list[conll_index] => list of associated spacy tokens
        (assume spacy tokenizer has a finer granularity)
        """
        lookup = []
        c_iter = (t for t in conll_tokens)
        s_iter = enumerate(t for t in spacy_tokens)
        i, s_tok = next(s_iter)
        for c_tok in c_iter:
            # if debug: print("conll", c_tok, "spacy", s_tok, "index", i)
            c_lookup = []
            while i is not None and len(c_tok) and c_tok.startswith(s_tok.text):
                c_lookup.append(i)
                c_tok = c_tok[len(s_tok) :]
                i, s_tok = next(s_iter, (None, None))
                if debug and len(c_tok):
                    print("eating token: conll", c_tok, "spacy", s_tok, "index", i)
            assert len(c_lookup), "Unmatched conll and spacy tokens"
            lookup.append(c_lookup)
        return lookup
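
    # Alignment sketch (comments only): with conll_tokens ["I", "don't", "know"]
    # and a spaCy tokenization ["I", "do", "n't", "know"], the lookup table is
    # [[0], [1, 2], [3]] -- conll token 1 ("don't") is covered by spaCy tokens 1
    # and 2. The assumption is that spaCy only ever splits a conll token further,
    # never merges several of them; otherwise the assertion above fires.
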
    def add_conll_utterance(
        self, parsed, tokens, corefs, speaker_id, use_gold_mentions, debug=False
    ):
        conll_lookup = self.get_conll_spacy_lookup(tokens, parsed)
        self.conll_tokens.append(tokens)
        self.conll_lookup.append(conll_lookup)
        # Convert conll tokens coref index in spacy tokens indexes
        identified_gold = [False] * len(corefs)
        for coref in corefs:
            missing_values = [
                key for key in ["label", "start", "end"] if coref.get(key, None) is None
            ]
            if missing_values:
                found_values = {
                    key: coref[key]
                    for key in ["label", "start", "end"]
                    if coref.get(key, None) is not None
                }
                raise Exception(
                    f"Coref {self.name} with fields {found_values} "
                    f"has empty values for the keys {missing_values}."
                )
            coref["start"] = conll_lookup[coref["start"]][0]
            coref["end"] = conll_lookup[coref["end"]][-1]
        if speaker_id not in self.speakers:
            speaker_name = speaker_id.split("_")
            if debug:
                print("New speaker: ", speaker_id, "name: ", speaker_name)
            self.speakers[speaker_id] = Speaker(speaker_id, speaker_name)
        if use_gold_mentions:
            for coref in corefs:
                # print("coref['label']", coref['label'])
                # print("coref text", parsed[coref['start']:coref['end']+1])
                mention = Mention(
                    parsed[coref["start"] : coref["end"] + 1],
                    len(self.mentions),
                    len(self.utterances),
                    self.n_sents,
                    speaker=self.speakers[speaker_id],
                    gold_label=coref["label"],
                )
                self.mentions.append(mention)
                # print("mention: ", mention, "label", mention.gold_label)
        else:
            mentions_spans = extract_mentions_spans(
                doc=parsed, blacklist=self.blacklist
            )
            self._process_mentions(
                mentions_spans,
                len(self.utterances),
                self.n_sents,
                self.speakers[speaker_id],
            )
            # Assign a gold label to mentions which have one
            if debug:
                print("Check corefs", corefs)
            for i, coref in enumerate(corefs):
                for m in self.mentions:
                    if m.utterance_index != len(self.utterances):
                        continue
                    # if debug: print("Checking mention", m, m.utterance_index, m.start, m.end)
                    if coref["start"] == m.start and coref["end"] == m.end - 1:
                        m.gold_label = coref["label"]
                        identified_gold[i] = True
                        # if debug: print("Gold mention found:", m, coref['label'])
            for found, coref in zip(identified_gold, corefs):
                if not found:
                    self.missed_gold.append(
                        [
                            self.name,
                            self.part,
                            str(len(self.utterances)),
                            parsed.text,
                            parsed[coref["start"] : coref["end"] + 1].text,
                        ]
                    )
                    if debug:
                        print(
                            "❄️ gold mention not in predicted mentions",
                            coref,
                            parsed[coref["start"] : coref["end"] + 1],
                        )
        self.utterances.append(parsed)
        self.gold_corefs.append(corefs)
        self.utterances_speaker.append(self.speakers[speaker_id])
        self.n_sents += len(list(parsed.sents))

    def get_single_mention_features_conll(self, mention, compressed=True):
        """ Compressed or not single mention features"""
        if not compressed:
            _, features = self.get_single_mention_features(mention)
            return features[np.newaxis, :]
        feat_l = [
            mention.features_["01_MentionType"],
            mention.features_["02_MentionLength"],
            mention.index,
            len(self.mentions),
            mention.features_["04_IsMentionNested"],
            self.genre_,
        ]
        return feat_l

    def get_pair_mentions_features_conll(self, m1, m2, compressed=True):
        """ Compressed or not pair of mentions features"""
        if not compressed:
            _, features = self.get_pair_mentions_features(m1, m2)
            return features[np.newaxis, :]
        features_, _ = self.get_pair_mentions_features(m1, m2)
        feat_l = [
            features_["00_SameSpeaker"],
            features_["01_AntMatchMentionSpeaker"],
            features_["02_MentionMatchSpeaker"],
            features_["03_HeadsAgree"],
            features_["04_ExactStringMatch"],
            features_["05_RelaxedStringMatch"],
            features_["06_SentenceDistance"],
            features_["07_MentionDistance"],
            features_["08_Overlapping"],
        ]
        return feat_l
    def get_feature_array(self, doc_id, feature=None, compressed=True, debug=False):
        """
        Prepare feature arrays:
            mentions_spans: (N, S)
            mentions_words: (N, W)
            mentions_features: (N, Fs)
            mentions_labels: (N, 1)
            mentions_pairs_start_index: (N, 1) index of beginning of pair list in pair_labels
            mentions_pairs_length: (N, 1) number of pairs (i.e. number of antecedents) for each mention
            pairs_features: (P, Fp)
            pairs_labels: (P, 1)
            pairs_ant_idx: (P, 1) => indexes of antecedent mentions for each pair (mention index in doc)
        """
        if not self.mentions:
            if debug:
                print("No mention in this doc !")
            return {}
        if debug:
            print("🛎 features matrices")
        mentions_spans = []
        mentions_words = []
        mentions_features = []
        pairs_ant_idx = []
        pairs_features = []
        pairs_labels = []
        mentions_labels = []
        mentions_pairs_start = []
        mentions_pairs_length = []
        mentions_location = []
        n_mentions = 0
        total_pairs = 0
        if debug:
            print("mentions", self.mentions, str([m.gold_label for m in self.mentions]))
        for mention_idx, antecedents_idx in list(
            self.get_candidate_pairs(max_distance=None, max_distance_with_match=None)
        ):
            n_mentions += 1
            mention = self.mentions[mention_idx]
            mentions_spans.append(mention.spans_embeddings)
            w_idx = mention_words_idx(self.embed_extractor, mention)
            if w_idx is None:
                print("error in", self.name, self.part, mention.utterance_index)
            mentions_words.append(w_idx)
            mentions_features.append(
                self.get_single_mention_features_conll(mention, compressed)
            )
            mentions_location.append(
                [
                    mention.start,
                    mention.end,
                    mention.utterance_index,
                    mention_idx,
                    doc_id,
                ]
            )
            ants = [self.mentions[ant_idx] for ant_idx in antecedents_idx]
            no_antecedent = (
                not any(ant.gold_label == mention.gold_label for ant in ants)
                or mention.gold_label is None
            )
            if antecedents_idx:
                pairs_ant_idx += [idx for idx in antecedents_idx]
                pairs_features += [
                    self.get_pair_mentions_features_conll(ant, mention, compressed)
                    for ant in ants
                ]
                ant_labels = (
                    [0 for ant in ants]
                    if no_antecedent
                    else [
                        1 if ant.gold_label == mention.gold_label else 0 for ant in ants
                    ]
                )
                pairs_labels += ant_labels
            mentions_labels.append(1 if no_antecedent else 0)
            mentions_pairs_start.append(total_pairs)
            total_pairs += len(ants)
            mentions_pairs_length.append(len(ants))

        out_dict = {
            FEATURES_NAMES[0]: mentions_features,
            FEATURES_NAMES[1]: mentions_labels,
            FEATURES_NAMES[2]: mentions_pairs_length,
            FEATURES_NAMES[3]: mentions_pairs_start,
            FEATURES_NAMES[4]: mentions_spans,
            FEATURES_NAMES[5]: mentions_words,
            FEATURES_NAMES[6]: pairs_ant_idx if pairs_ant_idx else None,
            FEATURES_NAMES[7]: pairs_features if pairs_features else None,
            FEATURES_NAMES[8]: pairs_labels if pairs_labels else None,
            FEATURES_NAMES[9]: [mentions_location],
            FEATURES_NAMES[10]: [self.conll_tokens],
            FEATURES_NAMES[11]: [self.conll_lookup],
            FEATURES_NAMES[12]: [
                {
                    "name": self.name,
                    "part": self.part,
                    "utterances": list(str(u) for u in self.utterances),
                    "mentions": list(str(m) for m in self.mentions),
                }
            ],
        }
        if debug:
            print("🚘 Summary")
            for k, v in out_dict.items():
                print(k, len(v))
        return n_mentions, total_pairs, out_dict
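
# Shape bookkeeping sketch (comments only, values are an illustrative assumption):
# when every mention is paired with all of its predecessors, a document with
# N mentions yields P = N * (N - 1) / 2 pairs. For N = 4:
#
#     mentions_pairs_length       = [0, 1, 2, 3]        # number of antecedents per mention
#     mentions_pairs_start_index  = [0, 0, 1, 3]        # offset of each mention's pair block
#     pairs_ant_index             = [0, 0, 1, 0, 1, 2]  # P = 6 antecedent indexes
#
# so mention i's candidate pairs are the slice
# pairs_*[start_index[i] : start_index[i] + pairs_length[i]] of the pair arrays.
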
###################
### ConllCorpus #####


class ConllCorpus(object):
    def __init__(
        self,
        n_jobs=4,
        embed_path=PACKAGE_DIRECTORY + "/weights/",
        gold_mentions=False,
        blacklist=False,
    ):
        self.n_jobs = n_jobs
        self.features = {}
        self.utts_text = []
        self.utts_tokens = []
        self.utts_corefs = []
        self.utts_speakers = []
        self.utts_doc_idx = []
        self.docs_names = []
        self.docs = []
        if embed_path is not None:
            self.embed_extractor = EmbeddingExtractor(embed_path)
        self.trainable_embed = []
        self.trainable_voc = []
        self.gold_mentions = gold_mentions
        self.blacklist = blacklist

    def check_words_in_embeddings_voc(self, embedding, tuned=True, debug=False):
        print("🌋 Checking if words are in embedding voc")
        if tuned:
            embed_voc = embedding.tun_idx
        else:
            embed_voc = embedding.stat_idx
        missing_words = []
        missing_words_sents = []
        missing_words_doc = []
        for doc in self.docs:
            # if debug: print("Checking doc", doc.name, doc.part)
            for sent in doc.utterances:
                # if debug: print(sent.text)
                for word in sent:
                    w = embedding.normalize_word(word)
                    # if debug: print(w)
                    if w not in embed_voc:
                        missing_words.append(w)
                        missing_words_sents.append(sent.text)
                        missing_words_doc.append(doc.name + doc.part)
                        if debug:
                            out_str = (
                                "No matching tokens in tuned voc for "
                                + w
                                + " in sentence "
                                + sent.text
                                + " in doc "
                                + doc.name
                                + doc.part
                            )
                            print(out_str)
        return missing_words, missing_words_sents, missing_words_doc

    def test_sentences_words(self, save_file, debug=False):
        print("🌋 Saving sentence list")
        with io.open(save_file, "w", encoding="utf-8") as f:
            if debug:
                print("Sentences saved in", save_file)
            for doc in self.docs:
                out_str = "#begin document (" + doc.name + "); part " + doc.part + "\n"
                f.write(out_str)
                for sent in doc.utterances:
                    f.write(sent.text + "\n")
                out_str = "#end document\n\n"
                f.write(out_str)

    def save_sentences(self, save_file, debug=False):
        print("🌋 Saving sentence list")
        with io.open(save_file, "w", encoding="utf-8") as f:
            if debug:
                print("Sentences saved in", save_file)
            for doc in self.docs:
                out_str = "#begin document (" + doc.name + "); part " + doc.part + "\n"
                f.write(out_str)
                for sent in doc.utterances:
                    f.write(sent.text + "\n")
                out_str = "#end document\n\n"
                f.write(out_str)

    def build_key_file(self, data_path, key_file, debug=False):
        print("🌋 Building key file from corpus")
        print("Saving in", key_file)
        # Create a pool of processes. By default, one is created for each CPU in your machine.
        with io.open(key_file, "w", encoding="utf-8") as kf:
            if debug:
                print("Key file saved in", key_file)
            for dirpath, _, filenames in os.walk(data_path):
                print("In", dirpath)
                file_list = [
                    os.path.join(dirpath, f)
                    for f in filenames
                    if f.endswith(".v4_auto_conll") or f.endswith(".v4_gold_conll")
                ]
                cleaned_file_list = []
                for f in file_list:
                    fn = f.split(".")
                    if fn[1] == "v4_auto_conll":
                        gold = fn[0] + "." + "v4_gold_conll"
                        if gold not in file_list:
                            cleaned_file_list.append(f)
                    else:
                        cleaned_file_list.append(f)
                # self.load_file(file_list[0])
                doc_list = parallel_process(cleaned_file_list, read_file)
                for doc in doc_list:
                    kf.write(doc)

    def list_undetected_mentions(self, data_path, save_file, debug=True):
        self.read_corpus(data_path)
        print("🌋 Listing undetected mentions")
        with io.open(save_file, "w", encoding="utf-8") as out_file:
            for doc in tqdm(self.docs):
                for name, part, utt_i, utt, coref in doc.missed_gold:
                    out_str = name + "\t" + part + "\t" + utt_i + '\t"' + utt + '"\n'
                    out_str += coref + "\n"
                    out_file.write(out_str)
                    if debug:
                        print(out_str)
+ "v4_gold_conll" if gold not in file_list: cleaned_file_list.append(f) else: cleaned_file_list.append(f) doc_list = parallel_process(cleaned_file_list, load_file) for docs in doc_list: # executor.map(self.load_file, cleaned_file_list): for ( utts_text, utt_tokens, utts_corefs, utts_speakers, name, part, ) in docs: if debug: print("Imported", name) print("utts_text", utts_text) print("utt_tokens", utt_tokens) print("utts_corefs", utts_corefs) print("utts_speakers", utts_speakers) print("name, part", name, part) self.utts_text += utts_text self.utts_tokens += utt_tokens self.utts_corefs += utts_corefs self.utts_speakers += utts_speakers self.utts_doc_idx += [len(self.docs_names)] * len(utts_text) self.docs_names.append((name, part)) print("utts_text size", len(self.utts_text)) print("utts_tokens size", len(self.utts_tokens)) print("utts_corefs size", len(self.utts_corefs)) print("utts_speakers size", len(self.utts_speakers)) print("utts_doc_idx size", len(self.utts_doc_idx)) print("🌋 Building docs") for name, part in self.docs_names: self.docs.append( ConllDoc( name=name, part=part, nlp=None, blacklist=self.blacklist, consider_speakers=True, embedding_extractor=self.embed_extractor, conll=CONLL_GENRES[name[:2]], ) ) print("🌋 Loading spacy model") if model is None: model_options = ["en_core_web_lg", "en_core_web_md", "en_core_web_sm", "en"] for model_option in model_options: if not model: try: spacy.info(model_option) model = model_option print("Loading model", model_option) except: print("Could not detect model", model_option) if not model: print("Could not detect any suitable English model") return else: spacy.info(model) print("Loading model", model) nlp = spacy.load(model) print( "🌋 Parsing utterances and filling docs with use_gold_mentions=" + (str(bool(self.gold_mentions))) ) doc_iter = (s for s in self.utts_text) for utt_tuple in tqdm( zip( nlp.pipe(doc_iter), self.utts_tokens, self.utts_corefs, self.utts_speakers, self.utts_doc_idx, ) ): spacy_tokens, conll_tokens, corefs, speaker, doc_id = utt_tuple if debug: print(unicode_(self.docs_names[doc_id]), "-", spacy_tokens) doc = spacy_tokens if debug: out_str = ( "utterance " + unicode_(doc) + " corefs " + unicode_(corefs) + " speaker " + unicode_(speaker) + "doc_id" + unicode_(doc_id) ) print(out_str.encode("utf-8")) self.docs[doc_id].add_conll_utterance( doc, conll_tokens, corefs, speaker, use_gold_mentions=self.gold_mentions ) def build_and_gather_multiple_arrays(self, save_path): print(f"🌋 Extracting mentions features with {self.n_jobs} job(s)") parallel_process(self.docs, set_feats, n_jobs=self.n_jobs) print(f"🌋 Building and gathering array with {self.n_jobs} job(s)") arr = [{"doc": doc, "i": i} for i, doc in enumerate(self.docs)] arrays_dicts = parallel_process( arr, get_feats, use_kwargs=True, n_jobs=self.n_jobs ) gathering_dict = dict((feat, None) for feat in FEATURES_NAMES) n_mentions_list = [] pairs_ant_index = 0 pairs_start_index = 0 for npaidx in tqdm(range(len(arrays_dicts))): try: n, p, arrays_dict = arrays_dicts[npaidx] except: # empty array dict, cannot extract the dict values for this doc continue for f in FEATURES_NAMES: if gathering_dict[f] is None: gathering_dict[f] = arrays_dict[f] else: if f == FEATURES_NAMES[6]: array = [a + pairs_ant_index for a in arrays_dict[f]] elif f == FEATURES_NAMES[3]: array = [a + pairs_start_index for a in arrays_dict[f]] else: array = arrays_dict[f] gathering_dict[f] += array pairs_ant_index += n pairs_start_index += p n_mentions_list.append(n) for feature in FEATURES_NAMES[:9]: 
    def build_and_gather_multiple_arrays(self, save_path):
        print(f"🌋 Extracting mentions features with {self.n_jobs} job(s)")
        parallel_process(self.docs, set_feats, n_jobs=self.n_jobs)

        print(f"🌋 Building and gathering array with {self.n_jobs} job(s)")
        arr = [{"doc": doc, "i": i} for i, doc in enumerate(self.docs)]
        arrays_dicts = parallel_process(
            arr, get_feats, use_kwargs=True, n_jobs=self.n_jobs
        )
        gathering_dict = dict((feat, None) for feat in FEATURES_NAMES)
        n_mentions_list = []
        pairs_ant_index = 0
        pairs_start_index = 0
        for npaidx in tqdm(range(len(arrays_dicts))):
            try:
                n, p, arrays_dict = arrays_dicts[npaidx]
            except Exception:
                # empty array dict, cannot extract the dict values for this doc
                continue
            for f in FEATURES_NAMES:
                if gathering_dict[f] is None:
                    gathering_dict[f] = arrays_dict[f]
                else:
                    if f == FEATURES_NAMES[6]:
                        array = [a + pairs_ant_index for a in arrays_dict[f]]
                    elif f == FEATURES_NAMES[3]:
                        array = [a + pairs_start_index for a in arrays_dict[f]]
                    else:
                        array = arrays_dict[f]
                    gathering_dict[f] += array
            pairs_ant_index += n
            pairs_start_index += p
            n_mentions_list.append(n)

        for feature in FEATURES_NAMES[:9]:
            feature_data = gathering_dict[feature]
            if not feature_data:
                print("No data for", feature)
                continue
            print("Building numpy array for", feature, "length", len(feature_data))
            if feature != "mentions_spans":
                array = np.array(feature_data)
                if array.ndim == 1:
                    array = np.expand_dims(array, axis=1)
            else:
                array = np.stack(feature_data)
            # check_numpy_array(feature, array, n_mentions_list)
            print("Saving numpy", feature, "size", array.shape)
            np.save(save_path + feature, array)
        for feature in FEATURES_NAMES[9:]:
            feature_data = gathering_dict[feature]
            if feature_data:
                print("Saving pickle", feature, "size", len(feature_data))
                with open(save_path + feature + ".bin", "wb") as fp:
                    pickle.dump(feature_data, fp)

    def save_vocabulary(self, save_path, debug=False):
        def _vocabulary_to_file(path, vocabulary):
            print("🌋 Saving vocabulary")
            with io.open(path, "w", encoding="utf-8") as f:
                if debug:
                    print(f"voc saved in {path}, length: {len(vocabulary)}")
                for w in vocabulary:
                    f.write(w + "\n")

        print("🌋 Building tunable vocabulary matrix from static vocabulary")
        tunable_voc = self.embed_extractor.tun_voc
        _vocabulary_to_file(
            path=save_path + "tuned_word_vocabulary.txt", vocabulary=tunable_voc
        )

        static_voc = self.embed_extractor.stat_voc
        _vocabulary_to_file(
            path=save_path + "static_word_vocabulary.txt", vocabulary=static_voc
        )

        tuned_word_embeddings = np.vstack(
            [self.embed_extractor.get_stat_word(w)[1] for w in tunable_voc]
        )
        print("Saving tunable voc, size:", tuned_word_embeddings.shape)
        np.save(save_path + "tuned_word_embeddings", tuned_word_embeddings)

        static_word_embeddings = np.vstack(
            [self.embed_extractor.static_embeddings[w] for w in static_voc]
        )
        print("Saving static voc, size:", static_word_embeddings.shape)
        np.save(save_path + "static_word_embeddings", static_word_embeddings)
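
# Command-line usage sketch (the module name and data path are assumptions based
# on the package imports above; adapt them to the actual checkout):
#
#     python -m neuralcoref.train.conllparser --function all --path ./data/ --n_jobs 4
#
# "parse" builds the numpy feature arrays under <path>/numpy/, "key" concatenates
# the gold conll files into a scoring key file, and "find_undetected" lists gold
# mentions that the mention-extraction heuristics missed.
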
if __name__ == "__main__":
    DIR_PATH = os.path.dirname(os.path.realpath(__file__))
    parser = argparse.ArgumentParser(
        description="Training the neural coreference model"
    )
    parser.add_argument(
        "--function",
        type=str,
        default="all",
        help='Function ("all", "key", "parse", "find_undetected")',
    )
    parser.add_argument(
        "--path", type=str, default=DIR_PATH + "/data/", help="Path to the dataset"
    )
    parser.add_argument(
        "--key", type=str, help="Path to an optional key file for scoring"
    )
    parser.add_argument(
        "--n_jobs", type=int, default=1, help="Number of parallel jobs (default 1)"
    )
    parser.add_argument(
        "--gold_mentions",
        type=int,
        default=0,
        help="Use gold mentions (1) or not (0, default)",
    )
    parser.add_argument(
        "--blacklist", type=int, default=0, help="Use blacklist (1) or not (0, default)"
    )
    parser.add_argument("--spacy_model", type=str, default=None, help="model name")
    args = parser.parse_args()
    if args.key is None:
        args.key = args.path + "/key.txt"
    CORPUS = ConllCorpus(
        n_jobs=args.n_jobs, gold_mentions=args.gold_mentions, blacklist=args.blacklist
    )
    if args.function == "parse" or args.function == "all":
        SAVE_DIR = args.path + "/numpy/"
        if not os.path.exists(SAVE_DIR):
            os.makedirs(SAVE_DIR)
        else:
            if os.listdir(SAVE_DIR):
                print("There are already data in", SAVE_DIR)
                print("Erasing")
                for file in os.listdir(SAVE_DIR):
                    print(file)
                    os.remove(SAVE_DIR + file)
        start_time = time.time()
        CORPUS.read_corpus(args.path, model=args.spacy_model)
        print("=> read_corpus time elapsed", time.time() - start_time)
        if not CORPUS.docs:
            print("Could not parse any valid docs")
        else:
            start_time2 = time.time()
            CORPUS.build_and_gather_multiple_arrays(SAVE_DIR)
            print(
                "=> build_and_gather_multiple_arrays time elapsed",
                time.time() - start_time2,
            )
            start_time2 = time.time()
            CORPUS.save_vocabulary(SAVE_DIR)
            print("=> save_vocabulary time elapsed", time.time() - start_time2)
            print("=> total time elapsed", time.time() - start_time)
    if args.function == "key" or args.function == "all":
        CORPUS.build_key_file(args.path, args.key)
    if args.function == "find_undetected":
        CORPUS.list_undetected_mentions(
            args.path, args.path + "/undetected_mentions.txt"
        )