import xml.etree.ElementTree as ET
import torch
import pickle
import glob
import argparse
import numpy as np
from pytorch_pretrained_bert import BertTokenizer, BertModel
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances
from nltk.stem import WordNetLemmatizer
from tqdm import tqdm, trange
from copy import deepcopy
import warnings

warnings.filterwarnings('ignore')


class BERT:
    """Thin wrapper around a pre-trained bert-large-uncased model and its tokenizer."""

    def __init__(self, device_number='cuda:2', use_cuda=True):
        self.device_number = device_number
        self.use_cuda = use_cuda

        self.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
        self.model = BertModel.from_pretrained('bert-large-uncased')
        self.model.eval()

        if use_cuda:
            self.model.to(device_number)


class Word_Sense_Model:
    """k-nearest-neighbour word sense disambiguation over contextual BERT embeddings."""

    def __init__(self, device_number='cuda:2', use_cuda=True):
        self.device_number = device_number
        self.use_cuda = use_cuda

        # Maps the first letter of a POS tag to the lexical-category digit used in WordNet sense keys.
        self.sense_number_map = {'N': 1, 'V': 2, 'J': 3, 'R': 4}

        self.Bert_Model = BERT(device_number, use_cuda)
        self.lemmatizer = WordNetLemmatizer()

    def open_xml_file(self, file_name):
        tree = ET.parse(file_name)
        root = tree.getroot()
        return root, tree

    def wngt_sent_sense_collect(self, xml_struct):
        """Collect words and sense labels from a WNGT (WordNet Gloss Tagged) sentence.

        For every sense key attached to the sentence element, a synthetic sentence
        "<lemma> is <gloss words ...>" is generated so the target lemma appears in context.
        """
        _sent = []
        _sent1 = []
        _senses = []
        temp_list_pos = []

        _back_sent = []
        _back_sent1 = ""
        _back_senses = []

        for idx, j in enumerate(xml_struct.iter('word')):
            _temp_dict = j.attrib

            if 'lemma' in _temp_dict:
                _word = _temp_dict['lemma'].lower()
            else:
                _word = _temp_dict['surface_form'].lower()

            _back_sent.extend([_word])
            _back_sent1 += _word + " "

            if 'wn30_key' in _temp_dict:
                _back_senses.extend([_temp_dict['wn30_key']] * len([_word]))
            else:
                _back_senses.extend([0] * len([_word]))

        _temp_dict = xml_struct.attrib

        if 'wn30_key' in _temp_dict:
            _senses1 = _temp_dict['wn30_key'].split(';')

            for i in _senses1:
                _word = [str(i.split('%')[0]), 'is']

                _temp_sent = []
                _temp_sent1 = ""
                _temp_senses = []

                _temp_sent.extend(_word)
                _temp_sent.extend(_back_sent)
                _temp_sent1 += ' '.join(_word) + " " + _back_sent1

                _temp_senses.extend([i, 0])
                _temp_senses.extend(_back_senses)

                _sent.append(_temp_sent)
                _sent1.append(_temp_sent1)
                _senses.append(_temp_senses)

        return _sent, _sent1, _senses, temp_list_pos

    def semcor_sent_sense_collect(self, xml_struct):
        """Collect words and sense labels from a SemCor sentence.

        Multi-word surface forms (joined by '_') are split; the sense key is assigned
        to the first token only when its lemma matches the annotated lemma.
        """
        _sent = []
        _sent1 = ""
        _senses = []
        temp_list_pos = []

        for idx, j in enumerate(xml_struct.iter('word')):
            _temp_dict = j.attrib
            flag = 0

            if 'lemma' not in _temp_dict:
                words = _temp_dict['surface_form'].lower()
                _sent1 += words + " "
                words = words.split('_')
                words1 = words[0:1]
                words2 = words[1:]
            else:
                _pos = _temp_dict['pos'].lower()[0]
                if _pos not in ['a', 'v', 'n']:
                    _pos = 'n'

                w2 = _temp_dict['lemma'].lower().split('_')
                words = _temp_dict['surface_form'].lower()
                _sent1 += words + " "
                words = words.split('_')

                l = self.lemmatizer.lemmatize(words[0], pos=_pos)
                if str(l).startswith(w2[0]) or str(w2[0]).startswith(l):
                    words1 = words[0:1]
                    words2 = words[1:]
                else:
                    # Lemma mismatch: do not trust the annotation for this token.
                    flag = 1

            _sent.extend(words)

            if 'wn30_key' in _temp_dict:
                if not flag:
                    _senses.extend([_temp_dict['wn30_key']] * len(words1))
                    _senses.extend([0] * len(words2))
                else:
                    _senses.extend([0] * len(words))
            else:
                _senses.extend([0] * len(words))

        return _sent, _sent1, _senses, temp_list_pos
    def semeval_sent_sense_collect(self, xml_struct):
        """Collect words, sense labels and POS tags from a SemEval-style sentence."""
        _sent = []
        _sent1 = ""
        _senses = []
        pos = []

        for idx, j in enumerate(xml_struct.iter('word')):
            _temp_dict = j.attrib

            if 'lemma' in _temp_dict:
                words = _temp_dict['lemma'].lower()
            else:
                words = _temp_dict['surface_form'].lower()

            if '*' not in words:
                _sent1 += words + " "
                _sent.extend([words])

                if 'pos' in _temp_dict:
                    pos.extend([_temp_dict['pos']] * len([words]))
                else:
                    pos.extend([0] * len([words]))

                if 'wn30_key' in _temp_dict:
                    _senses.extend([_temp_dict['wn30_key']] * len([words]))
                else:
                    _senses.extend([0] * len([words]))

        return _sent, _sent1, _senses, pos

    def apply_bert_tokenizer(self, word):
        return self.Bert_Model.tokenizer.tokenize(word)

    def collect_bert_tokens(self, _sent, lemma=False):
        """Wrap the sentence in [CLS]/[SEP] and WordPiece-tokenize every word."""
        _bert_tokens = ['[CLS]', ]

        if lemma:
            for idx, j in enumerate(_sent):
                _sent[idx] = self.lemmatizer.lemmatize(_sent[idx])
                _tokens = self.apply_bert_tokenizer(_sent[idx])
                _bert_tokens.extend(_tokens)
        else:
            for idx, j in enumerate(_sent):
                _tokens = self.apply_bert_tokenizer(_sent[idx])
                _bert_tokens.extend(_tokens)

        _bert_tokens.append('[SEP]')
        return _bert_tokens

    def get_bert_embeddings(self, tokens):
        """Return per-token embeddings: the last four encoder layers concatenated."""
        _ib = self.Bert_Model.tokenizer.convert_tokens_to_ids(tokens)
        _st = [0] * len(_ib)

        if self.use_cuda:
            _t1, _t2 = torch.tensor([_ib]).to(self.device_number), torch.tensor([_st]).to(self.device_number)
        else:
            _t1, _t2 = torch.tensor([_ib]), torch.tensor([_st])

        with torch.no_grad():
            _encoded_layers, _ = self.Bert_Model.model(_t1, _t2, output_all_encoded_layers=True)

            # Concatenate the last four hidden layers along the feature dimension.
            _e1 = _encoded_layers[-4:]
            _e2 = torch.cat((_e1[0], _e1[1], _e1[2], _e1[3]), 2)

            if self.use_cuda:
                _final_layer = _e2[0].cpu().numpy()
            else:
                _final_layer = _e2[0].numpy()

            return _final_layer

    def create_word_sense_maps(self, _word_sense_emb):
        """Build sense -> embeddings/sentences maps plus word <-> sense lookup tables."""
        _sense_emb = {}
        _sentence_maps = {}
        _sense_word_map = {}
        _word_sense_map = {}

        for i in _word_sense_emb:
            if i not in _word_sense_map:
                _word_sense_map[i] = []

            for j in _word_sense_emb[i]:
                if j not in _sense_word_map:
                    _sense_word_map[j] = []

                _sense_word_map[j].append(i)
                _word_sense_map[i].append(j)

                if j not in _sense_emb:
                    _sense_emb[j] = []
                    _sentence_maps[j] = []

                _sense_emb[j].extend(_word_sense_emb[i][j]['embs'])
                _sentence_maps[j].extend(_word_sense_emb[i][j]['sentences'])

        return _sense_emb, _sentence_maps, _sense_word_map, _word_sense_map

    def train(self, train_file, training_data_type):
        """Compute a BERT embedding for every sense-annotated word in the training corpus."""
        print("Training Embeddings!!")

        _word_sense_emb = {}
        _train_root, _train_tree = self.open_xml_file(train_file)

        for i in tqdm(_train_root.iter('sentence')):
            if training_data_type == "SE":
                all_sent, all_sent1, all_senses, _ = self.semeval_sent_sense_collect(i)
                all_sent, all_sent1, all_senses = [all_sent], [all_sent1], [all_senses]
            elif training_data_type == "SEM":
                all_sent, all_sent1, all_senses, _ = self.semcor_sent_sense_collect(i)
                all_sent, all_sent1, all_senses = [all_sent], [all_sent1], [all_senses]
            elif training_data_type == "WNGT":
                all_sent, all_sent1, all_senses, _ = self.wngt_sent_sense_collect(i)
            else:
                print("Argument train_type not specified properly!!")
                quit()

            for sent, sent1, senses in zip(all_sent, all_sent1, all_senses):
                try:
                    bert_tokens = self.collect_bert_tokens(sent)
                    final_layer = self.get_bert_embeddings(bert_tokens)

                    count = 1  # skip the [CLS] token
                    for idx, j in enumerate(zip(senses, sent)):
                        sense = j[0]
                        word = j[1]

                        if sense != 0:
                            # Average the WordPiece vectors that make up this word.
                            embedding = np.mean(final_layer[count: count + len(self.apply_bert_tokenizer(word))], 0)

                            if word not in _word_sense_emb:
                                _word_sense_emb[word] = {}

                            for s in sense.split(';'):
                                if s not in _word_sense_emb[word]:
                                    _word_sense_emb[word][s] = {}
                                    _word_sense_emb[word][s]['embs'] = []
                                    _word_sense_emb[word][s]['sentences'] = []

                                _word_sense_emb[word][s]['embs'].append(embedding)
                                _word_sense_emb[word][s]['sentences'].append(sent1)

                        count += len(self.apply_bert_tokenizer(word))
                except Exception as e:
                    print(e)

        return _word_sense_emb
    def load_embeddings(self, pickle_file_name, train_file, training_data_type):
        """Load cached sense embeddings from a pickle, or train and cache them."""
        try:
            with open(pickle_file_name, 'rb') as h:
                _x = pickle.load(h)
                print("EMBEDDINGS FOUND!")
                return _x
        except Exception:
            print("Embedding File Not Found!! \n")

            word_sense_emb = self.train(train_file, training_data_type)
            with open(pickle_file_name, 'wb') as h:
                pickle.dump(word_sense_emb, h)

            print("Embeddings Saved to " + pickle_file_name)
            return word_sense_emb

    def test(self, train_file, test_file, emb_pickle_file, training_data_type, save_to,
             k=1, use_euclidean=False, reduced_search=True):
        """Disambiguate the test corpus by k-nearest-neighbour search over sense embeddings."""
        word_sense_emb = self.load_embeddings(emb_pickle_file, train_file, training_data_type)

        print("Testing!")
        sense_emb, sentence_maps, sense_word_map, word_sense_map = self.create_word_sense_maps(word_sense_emb)

        _test_root, _test_tree = self.open_xml_file(test_file)
        _correct, _wrong = [], []

        open(save_to, "w").close()

        for i in tqdm(_test_root.iter('sentence')):
            sent, sent1, senses, pos = self.semeval_sent_sense_collect(i)

            bert_tokens = self.collect_bert_tokens(sent)
            final_layer = self.get_bert_embeddings(bert_tokens)

            count, tag, nn_sentences = 1, [], []

            for idx, j in enumerate(zip(senses, sent, pos)):
                word = j[1]
                pos_tag = j[2][0] if j[2] != 0 else 'N'  # fall back to noun when no POS tag is present

                if j[0] != 0:
                    _temp_tag = 0
                    max_score = -99
                    nearest_sent = 'NONE'

                    embedding = np.mean(final_layer[count:count + len(self.apply_bert_tokenizer(word))], 0)
                    min_span = 10000

                    if word in word_sense_map:
                        concat_senses = []
                        concat_sentences = []
                        index_maps = {}
                        _reduced_sense_map = []

                        # Optionally restrict the candidate senses to those matching the word's POS.
                        if reduced_search:
                            for sense_id in word_sense_map[word]:
                                if self.sense_number_map[pos_tag] == int(sense_id.split('%')[1][0]):
                                    _reduced_sense_map.append(sense_id)

                        if len(_reduced_sense_map) == 0:
                            _reduced_sense_map = list(word_sense_map[word])

                        # Concatenate the training embeddings of all candidate senses and
                        # remember which index range belongs to which sense.
                        for sense_id in _reduced_sense_map:
                            index_maps[sense_id] = {}
                            index_maps[sense_id]['start'] = len(concat_senses)
                            concat_senses.extend(sense_emb[sense_id])
                            concat_sentences.extend(sentence_maps[sense_id])
                            index_maps[sense_id]['end'] = len(concat_senses) - 1
                            index_maps[sense_id]['count'] = 0

                            if min_span > (index_maps[sense_id]['end'] - index_maps[sense_id]['start'] + 1):
                                min_span = (index_maps[sense_id]['end'] - index_maps[sense_id]['start'] + 1)

                        min_nearest = min(min_span, k)
                        concat_senses = np.array(concat_senses)

                        if use_euclidean:
                            simis = euclidean_distances(embedding.reshape(1, -1), concat_senses)[0]
                            nearest_indexes = simis.argsort()[:min_nearest]
                        else:
                            simis = cosine_similarity(embedding.reshape(1, -1), concat_senses)[0]
                            nearest_indexes = simis.argsort()[-min_nearest:][::-1]

                        # Vote: the sense owning the most of the k nearest neighbours wins.
                        for idx1 in nearest_indexes:
                            for sense_id in _reduced_sense_map:
                                if index_maps[sense_id]['start'] <= idx1 and index_maps[sense_id]['end'] >= idx1:
                                    index_maps[sense_id]['count'] += 1
                                    score = index_maps[sense_id]['count']

                                    if score > max_score:
                                        max_score = score
                                        _temp_tag = sense_id
                                        nearest_sent = concat_sentences[idx1]

                    tag.append(_temp_tag)
                    nn_sentences.append(nearest_sent)

                count += len(self.apply_bert_tokenizer(word))

            # Write the predictions back into the XML and score them against the gold keys.
            _counter = 0
            for j in i.iter('word'):
                temp_dict = j.attrib
                try:
                    if 'wn30_key' in temp_dict:
                        if tag[_counter] == 0:
                            pass
                        else:
                            j.attrib['WSD'] = str(tag[_counter])

                            if j.attrib['WSD'] in str(temp_dict['wn30_key']).split(';'):
                                _correct.append([temp_dict['wn30_key'], j.attrib['WSD'], (sent1), nn_sentences[_counter]])
                            else:
                                _wrong.append([temp_dict['wn30_key'], j.attrib['WSD'], (sent1), nn_sentences[_counter]])

                        _counter += 1
                except Exception as e:
                    print(e)

        with open(save_to, "w") as f:
            _test_tree.write(f, encoding="unicode")

        print("OUTPUT STORED TO FILE: " + str(save_to))
        return _correct, _wrong
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description='WSD using BERT')

    parser.add_argument('--use_cuda', type=bool, default=True, help='Use GPU?')
    parser.add_argument('--device', type=str, default='cuda:2', help='GPU Device to Use?')
    parser.add_argument('--train_corpus', type=str, required=True, help='Training Corpus')
    parser.add_argument('--train_type', type=str, required=True, help='SEM/WNGT/SE')
    parser.add_argument('--trained_pickle', type=str, help='Pickle file of Trained Bert Embeddings/Save Embeddings to this file')
    parser.add_argument('--test_corpus', type=str, required=True, help='Testing Corpus')
    parser.add_argument('--start_k', type=int, default=1, help='Start value of Nearest Neighbour')
    parser.add_argument('--end_k', type=int, default=1, help='End value of Nearest Neighbour')
    parser.add_argument('--save_xml_to', type=str, help='Save the final output to?')
    parser.add_argument('--use_euclidean', type=int, default=0, help='Use Euclidean Distance to Find NNs?')
    parser.add_argument('--reduced_search', type=int, default=0, help='Apply Reduced POS Search?')

    args = parser.parse_args()

    print("Training Corpus is: " + args.train_corpus)
    print("Testing Corpus is: " + args.test_corpus)
    print("Nearest Neighbour start: " + str(args.start_k))
    print("Nearest Neighbour end: " + str(args.end_k))

    if args.reduced_search:
        print("Using Reduced POS Search!")
    else:
        print("Using the Search without POS!")

    if args.use_euclidean:
        print("Using Euclidean Distance!")
    else:
        print("Using Cosine Similarity!")

    print("Loading WSD Model!")
    WSD = Word_Sense_Model(device_number=args.device, use_cuda=args.use_cuda)
    print("Loaded WSD Model!")

    for nn in range(args.start_k, args.end_k + 1):
        correct, wrong = WSD.test(train_file=args.train_corpus,
                                  test_file=args.test_corpus,
                                  training_data_type=args.train_type,
                                  emb_pickle_file=args.trained_pickle,
                                  save_to=args.save_xml_to[:-4] + "_" + str(nn) + args.save_xml_to[-4:],
                                  k=nn,
                                  use_euclidean=args.use_euclidean,
                                  reduced_search=args.reduced_search)
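
# A minimal sketch of how the script might be invoked, assuming hypothetical file names
# (the script name and the .xml/.pkl paths below are placeholders, not shipped with the code):
#
#   python wsd_bert_knn.py \
#       --train_corpus semcor.xml --train_type SEM \
#       --test_corpus senseval3.xml \
#       --trained_pickle semcor_embeddings.pkl \
#       --save_xml_to output.xml \
#       --start_k 1 --end_k 5 --reduced_search 1
#
# Note that --save_xml_to should end in a four-character extension such as ".xml", since the
# script splices "_<k>" in front of the last four characters when saving each run's output.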