import os import re import cPickle import copy import numpy import torch import nltk from nltk.corpus import ptb from nltk import Tree word_tags = ['CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', 'MD', 'NN', 'NNS', 'NNP', 'NNPS', 'PDT', 'POS', 'PRP', 'PRP$', 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP$', 'WRB'] currency_tags_words = ['#', '$', 'C$', 'A$'] ellipsis = ['*', '*?*', '0', '*T*', '*ICH*', '*U*', '*RNR*', '*EXP*', '*PPA*', '*NOT*'] punctuation_tags = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``'] punctuation_words = ['.', ',', ':', '-LRB-', '-RRB-', '\'\'', '``', '--', ';', '-', '?', '!', '...', '-LCB-', '-RCB-'] file_ids = ptb.fileids() train_file_ids = [] valid_file_ids = [] test_file_ids = [] rest_file_ids = [] for id in file_ids: if 'WSJ/00/WSJ_0000.MRG' <= id <= 'WSJ/24/WSJ_2499.MRG': train_file_ids.append(id) elif 'WSJ/22/WSJ_2200.MRG' <= id <= 'WSJ/22/WSJ_2299.MRG': valid_file_ids.append(id) elif 'WSJ/23/WSJ_2300.MRG' <= id <= 'WSJ/23/WSJ_2399.MRG': test_file_ids.append(id) elif 'WSJ/00/WSJ_0000.MRG' <= id <= 'WSJ/01/WSJ_0199.MRG' or 'WSJ/24/WSJ_2400.MRG' <= id <= 'WSJ/24/WSJ_2499.MRG': rest_file_ids.append(id) data_path = '/misc/vlgscratch4/BowmanGroup/pmh330/datasets/' train_files = data_path + 'all_nli/all_nli_train.jsonl' valid_files = data_path + 'all_nli/all_nli_valid.jsonl' test_files_snli = data_path + 'snli_1.0/snli_1.0_test.jsonl' test_files_mnli_match = data_path + 'multinli_1.0/multinli_1.0_dev_matched.jsonl' class Dictionary(object): def __init__(self): self.word2idx = {'<unk>': 0} self.idx2word = ['<unk>'] self.word2frq = {} def add_word(self, word): if word not in self.word2idx: self.idx2word.append(word) self.word2idx[word] = len(self.idx2word) - 1 if word not in self.word2frq: self.word2frq[word] = 1 else: self.word2frq[word] += 1 return self.word2idx[word] def __len__(self): return len(self.idx2word) def __getitem__(self, item): if self.word2idx.has_key(item): return self.word2idx[item] else: return self.word2idx['<unk>'] def rebuild_by_freq(self, thd=3): self.word2idx = {'<unk>': 0} self.idx2word = ['<unk>'] for k, v in self.word2frq.iteritems(): if v >= thd and (not k in self.idx2word): self.idx2word.append(k) self.word2idx[k] = len(self.idx2word) - 1 print 'Number of words:', len(self.idx2word) return len(self.idx2word) class Corpus(object): def __init__(self, path): dict_file_name = os.path.join(path, 'dict_nli.pkl') if os.path.exists(dict_file_name): self.dictionary = cPickle.load(open(dict_file_name, 'rb')) print("loading: ", dict_file_name) else: self.dictionary = Dictionary() self.add_words(train_files) self.add_words(valid_files) self.add_words(test_files_snli) self.add_words(test_files_mnli_match) self.dictionary.rebuild_by_freq() cPickle.dump(self.dictionary, open(dict_file_name, 'wb')) self.train, self.train_sens, self.train_trees = self.tokenize(train_files) self.valid, self.valid_sens, self.valid_trees = self.tokenize(valid_files) self.test_snli, self.test_snli_sens, self.test_snli_trees = self.tokenize(test_files_snli) self.test_mnli, self.test_mnli_sens, self.test_mnli_trees = self.tokenize(test_files_mnli_match) self.test, self.test_sens, self.test_trees = self.tokenize(test_file_ids) def filter_words(self, tree): words = [] for w, tag in tree.pos(): if tag in word_tags: w = w.lower() w = re.sub('[0-9]+', 'N', w) # if tag == 'CD': # w = 'N' words.append(w) return words def add_words(self, file_name): # Add words to the dictionary f_in = open(file_name, 'r') for line in f_in: if line.strip() == '': continue data = eval(line) sen_tree = Tree.fromstring(data['sentence1_parse']) words = self.filter_words(sen_tree) words = ['<s>'] + words + ['</s>'] for word in words: self.dictionary.add_word(word) sen_tree = Tree.fromstring(data['sentence2_parse']) words = self.filter_words(sen_tree) words = ['<s>'] + words + ['</s>'] for word in words: self.dictionary.add_word(word) f_in.close() def tokenize(self, file_name): def tree2list(tree): if isinstance(tree, nltk.Tree): if tree.label() in word_tags: return tree.leaves()[0] else: root = [] for child in tree: c = tree2list(child) if c != []: root.append(c) if len(root) > 1: return root elif len(root) == 1: return root[0] return [] sens_idx = [] sens = [] sentences = [] trees = [] f_in = open(file_name, 'r') for line in f_in: if line.strip() == '': continue data = eval(line) sentences = [] sentences.append(Tree.fromstring(data['sentence1_parse'])) sentences.append(Tree.fromstring(data['sentence2_parse'])) for sen_tree in sentences: words = self.filter_words(sen_tree) if not words: continue words = ['<s>'] + words + ['</s>'] # if len(words) > 50: # continue sens.append(words) idx = [] for word in words: idx.append(self.dictionary[word]) sens_idx.append(torch.LongTensor(idx)) trees.append(tree2list(sen_tree)) f_in.close() return sens_idx, sens, trees