from gensim.models import Word2Vec
from gensim.models.word2vec import LineSentence
from gensim.models.phrases import Phrases, Phraser
import gensim
from mat2vec.training.helpers.utils import EpochSaver, compute_epoch_accuracies, \
    keep_simple_formula, load_obj, COMMON_TERMS, EXCLUDE_PUNCT, INCLUDE_PHRASES
import logging
import os
import argparse
import regex
import pickle
from tqdm import tqdm

logging.basicConfig(format="%(asctime)s : %(levelname)s : %(message)s",
                    level=logging.INFO)


def exclude_words(phrasegrams, words):
    """Given a list of words, excludes those from the keys of the phrase dictionary.

    Parameters:
        phrasegrams: dict mapping grams (tuples of byte-string tokens, as in
            a gensim ``Phraser.phrasegrams``) to their scores.
        words: iterable of words to exclude; a gram is dropped if any of its
            sub-tokens contains one of these words as a whole "_"-delimited
            token (alone, at either end, or in the middle).

    Returns:
        A new dict containing only the grams that match none of ``words``.
    """
    new_phrasergrams = {}
    words_re_list = []
    for word in words:
        we = regex.escape(word)
        # Match the excluded word as a complete token within an
        # underscore-joined gram: exact, prefix, suffix, or interior.
        words_re_list.append("^" + we + "$|^" + we + "_|_" + we + "$|_" + we + "_")
    word_reg = regex.compile("|".join(words_re_list))
    for gram in tqdm(phrasegrams):
        valid = True
        for sub_gram in gram:
            if word_reg.search(sub_gram.decode("unicode_escape", "ignore")) is not None:
                valid = False
                break
        # Fix: the original had a redundant "if not valid: continue" followed
        # by "if valid:" — the first branch was dead code.
        if valid:
            new_phrasergrams[gram] = phrasegrams[gram]
    return new_phrasergrams


# Generating word grams.
def wordgrams(sent, depth, pc, th, ct, et, ip, d=0):
    """Builds word grams according to the specification.

    Recursively applies gensim phrase detection ``depth`` times, filtering
    out phrases that contain excluded words after each pass.

    Parameters:
        sent: iterable of tokenized sentences.
        depth: number of phrase-generation passes; 0 returns ``sent`` as-is.
        pc: minimum phrase count (``min_count`` for ``Phrases``).
        th: phrase scoring threshold.
        ct: common terms allowed inside phrases (e.g. "of", "the").
        et: words whose phrases should be excluded (see ``exclude_words``).
        ip: phrases to force-include — currently unused in this function,
            kept for caller interface compatibility.
        d: current pass counter (internal recursion state).

    Returns:
        Tuple of (transformed sentences, final ``Phraser``). The phraser is
        ``None`` when ``depth == 0``.
    """
    if depth == 0:
        return sent, None
    phrases = Phrases(sent, common_terms=ct, min_count=pc, threshold=th)
    grams = Phraser(phrases)
    grams.phrasegrams = exclude_words(grams.phrasegrams, et)
    d += 1
    if d < depth:
        return wordgrams(grams[sent], depth, pc, th, ct, et, ip, d)
    return grams[sent], grams


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--corpus", required=True,
                        help="The path to the corpus to train on.")
    parser.add_argument("--model_name", required=True,
                        help="Name for saving the model (in the models folder).")
    parser.add_argument("--epochs", default=30, help="Number of epochs.")
    parser.add_argument("--size", default=200, help="Size of the embedding.")
    parser.add_argument("--window", default=8, help="Context window size.")
    parser.add_argument("--min_count", default=5,
                        help="Minimum number of occurrences for word.")
    parser.add_argument("--workers", default=16, help="Number of workers.")
    parser.add_argument("--alpha", default=0.01, help="Learning rate.")
    parser.add_argument("--batch", default=10000, help="Minibatch size.")
    parser.add_argument("--negative", default=15,
                        help="Number of negative samples.")
    parser.add_argument("--subsample", default=0.0001, help="Subsampling rate.")
    parser.add_argument("--phrase_depth", default=2,
                        help="The number of passes to perform for phrase generation.")
    parser.add_argument("--phrase_count", default=10,
                        help="Minimum number of occurrences for phrase to be considered.")
    parser.add_argument("--phrase_threshold", default=15.0,
                        help="Phrase importance threshold.")
    parser.add_argument("-include_extra_phrases", action="store_true",
                        help="If true, will look for all_ents.p and add extra phrases.")
    parser.add_argument("-sg", action="store_true",
                        help="If set, will train a skip-gram, otherwise a CBOW.")
    parser.add_argument("-hs", action="store_true",
                        help="If set, hierarchical softmax will be used.")
    parser.add_argument("-keep_formula", action="store_true",
                        help="If set, keeps simple chemical formula independent on count.")
    parser.add_argument("-notmp", action="store_true",
                        help="If set, will not store the progress in tmp folder.")
    args = parser.parse_args()

    # Decide how chemical formulae survive vocabulary trimming.
    all_formula = []
    if args.keep_formula:
        try:
            all_formula = load_obj(args.corpus + "_formula")

            # A list of formula is supplied: always keep any word in it.
            def keep_formula_list(word, count, min_count):
                if word in all_formula:
                    return gensim.utils.RULE_KEEP
                else:
                    return gensim.utils.RULE_DEFAULT

            trim_rule_formula = keep_formula_list
            logging.info("Using a supplied list of formula to keep simple formula.")
        # Fix: was a bare "except:", which also swallowed SystemExit and
        # KeyboardInterrupt. The intended fallback behavior is preserved.
        except Exception:
            # No list is supplied, use the simple formula rule.
            trim_rule_formula = keep_simple_formula
            logging.info("Using a function to keep material mentions.")
    else:
        logging.info("Basic min_count trim rule for formula.")
        trim_rule_formula = None

    # The trim rule for extra phrases to always keep them, similar to the formulae.
    if args.include_extra_phrases:
        INCLUDE_PHRASES_SET = set(INCLUDE_PHRASES)
        try:
            with open("all_ents.p", "rb") as f:
                INCLUDE_PHRASES += list(set(pickle.load(f)))
            # Forced phrases are joined with the "$@$@$" marker during
            # pre-processing below, so the lookup set uses the same form.
            INCLUDE_PHRASES_SET = set(
                p.replace("_", "$@$@$") for p in INCLUDE_PHRASES)
            logging.info("Included the supplied {} additional phrases.".format(
                len(INCLUDE_PHRASES)))
        # Fix: narrowed from a bare "except:".
        except Exception:
            logging.info("No specific phrases supplied, using the defaults.")

        # Keep forced phrases (and formulae, if a formula rule is active)
        # regardless of their corpus count.
        def keep_extra_phrases(word, count, min_count):
            if word in INCLUDE_PHRASES_SET or trim_rule_formula is not None and \
                    trim_rule_formula(word, 1, 2) == gensim.utils.RULE_KEEP:
                return gensim.utils.RULE_KEEP
            else:
                return gensim.utils.RULE_DEFAULT

        trim_rule = keep_extra_phrases
        logging.info("Keeping the extra phrases independent on their count.")
    else:
        trim_rule = trim_rule_formula
        logging.info("Not including extra phrases, option not specified.")

    # Excluding all formula from the phrases: total the occurrence counts of
    # every writing of each formula, and exclude frequent formulae so the
    # phraser does not merge them into grams.
    # NOTE(review): all_formula appears to be a dict of {formula: {writing:
    # count}} when loaded, and an empty list otherwise — both iterate safely.
    formula_counts = [0] * len(all_formula)
    for i, formula in enumerate(all_formula):
        for writing in all_formula[formula]:
            formula_counts[i] += all_formula[formula][writing]
    formula_strings = [formula for i, formula in enumerate(all_formula)
                       if formula_counts[i] > int(args.phrase_count)]

    # Loading text and generating the phrases.
    sentences = LineSentence(args.corpus)

    # Pre-process everything to force the supplied phrases before it even
    # goes to the phraser.
    processed_sentences = sentences
    if args.include_extra_phrases:
        # Bucket forced phrases by token length so longer phrases are
        # matched (and collapsed) before shorter ones.
        phrases_by_length = dict()
        for phrase in INCLUDE_PHRASES:
            phrase_split = phrase.split("_")
            if len(phrase_split) not in phrases_by_length:
                phrases_by_length[len(phrase_split)] = [phrase]
            else:
                phrases_by_length[len(phrase_split)].append(phrase)
        max_len = max(phrases_by_length.keys())

        processed_sentences = []
        for sentence in tqdm(sentences):
            for cl in reversed(range(2, max_len + 1)):
                # Fix: phrase lengths may be non-contiguous (e.g. lengths 2
                # and 4 but no 3); direct indexing raised KeyError.
                repl_phrases = set(phrases_by_length.get(cl, ()))
                si = 0
                while si <= len(sentence) - cl:
                    if "_".join(sentence[si:cl + si]) in repl_phrases:
                        # Collapse the matched window into a single token
                        # using the "$@$@$" marker (matches the trim rule set).
                        sentence[si] = "$@$@$".join(sentence[si:cl + si])
                        del sentence[si + 1:cl + si]
                    else:
                        si += 1
            processed_sentences.append(sentence)

    # Process sentences to force the extra phrases.
    sentences, phraser = wordgrams(processed_sentences,
                                   depth=int(args.phrase_depth),
                                   pc=int(args.phrase_count),
                                   th=float(args.phrase_threshold),
                                   ct=COMMON_TERMS,
                                   et=EXCLUDE_PUNCT + formula_strings,
                                   ip=INCLUDE_PHRASES)
    # Fix: wordgrams returns None for the phraser when phrase_depth == 0;
    # calling save() on it crashed with AttributeError.
    if phraser is not None:
        phraser.save(os.path.join("models", args.model_name + "_phraser.pkl"))

    if not args.notmp:
        callbacks = [EpochSaver(path_prefix=args.model_name)]
    else:
        callbacks = []

    my_model = Word2Vec(
        sentences,
        size=int(args.size),
        window=int(args.window),
        min_count=int(args.min_count),
        sg=bool(args.sg),
        hs=bool(args.hs),
        trim_rule=trim_rule,
        workers=int(args.workers),
        alpha=float(args.alpha),
        sample=float(args.subsample),
        negative=int(args.negative),
        compute_loss=True,
        sorted_vocab=True,
        batch_words=int(args.batch),
        iter=int(args.epochs),
        callbacks=callbacks)

    my_model.save(os.path.join("models", args.model_name))

    analogy_file = os.path.join("data", "analogies.txt")
    # Save the accuracies in the tmp folder.
    compute_epoch_accuracies("tmp", args.model_name, analogy_file)