import pickle import nltk class Trainer: def __init__(self): self.uni_dist = nltk.FreqDist() self.backward_bi_dist = nltk.FreqDist() self.forward_bi_dist = nltk.FreqDist() self.trigram_dist = nltk.FreqDist() self.word_casing_lookup = {} def __function_one(self, sentence, word, word_idx, word_lower): try: if (word_lower in self.word_casing_lookup and len(self.word_casing_lookup[word_lower]) >= 2): # Only if there are multiple options prev_word = sentence[word_idx - 1] self.backward_bi_dist[prev_word + "_" + word] += 1 next_word = sentence[word_idx + 1].lower() self.forward_bi_dist[word + "_" + next_word] += 1 except IndexError: pass def __function_two(self, sentence, word, word_idx): try: if word_idx - 1 < 0: return prev_word = sentence[word_idx - 1] cur_word = sentence[word_idx] cur_word_lower = word.lower() next_word_lower = sentence[word_idx + 1].lower() if (cur_word_lower in self.word_casing_lookup and len(self.word_casing_lookup[cur_word_lower]) >= 2): # Only if there are multiple options self.trigram_dist[prev_word + "_" + cur_word + "_" + next_word_lower] += 1 except IndexError: pass def train(self, corpus): for sentence in corpus: if not self.check_sentence_sanity(sentence): continue for word_idx, word in enumerate(sentence): self.uni_dist[word] += 1 word_lower = word.lower() if word_lower not in self.word_casing_lookup: self.word_casing_lookup[word_lower] = set() self.word_casing_lookup[word_lower].add(word) self.__function_one(sentence, word, word_idx, word_lower) self.__function_two(sentence, word, word_idx) def save_to_file(self, file_path): pickle_dict = { "uni_dist": self.uni_dist, "backward_bi_dist": self.backward_bi_dist, "forward_bi_dist": self.forward_bi_dist, "trigram_dist": self.trigram_dist, "word_casing_lookup": self.word_casing_lookup, } with open(file_path, "wb") as fp: pickle.dump(pickle_dict, fp) print("Model saved to " + file_path) @staticmethod def get_casing(word): """ Returns the casing of a word """ if len(word) == 0: return "other" elif word.isdigit(): # Is a digit return "numeric" elif word.islower(): # All lower case return "allLower" elif word.isupper(): # All upper case return "allUpper" # is a title, initial char upper, then all lower elif word[0].isupper(): return "initialUpper" return "other" def check_sentence_sanity(self, sentence): """ Checks the sanity of the sentence. If the sentence is for example all uppercase, it is rejected """ case_dist = nltk.FreqDist() for token in sentence: case_dist[self.get_casing(token)] += 1 if case_dist.most_common(1)[0][0] != "allLower": return False return True if __name__ == "__main__": corpus = (nltk.corpus.brown.sents() + nltk.corpus.reuters.sents() + nltk.corpus.semcor.sents() + nltk.corpus.conll2000.sents() + nltk.corpus.state_union.sents()) trainer = Trainer() trainer.train(corpus) trainer.save_to_file("data/english.dist")