import argparse
import copy

import numpy as np
import torch
from scipy.spatial import KDTree
from scipy.spatial.distance import cosine
from allennlp.commands.elmo import ElmoEmbedder

parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    '--elmo_weights_path',
    type=str,
    default='models/$l_weights.hdf5',
    help="Path to the ELMo weights files. Use $l as a placeholder for the language.")
parser.add_argument(
    '--elmo_options_path',
    type=str,
    default='models/options262.json',
    help="Path to the ELMo options file. n_characters in the file should be 262.")
parser.add_argument(
    '--align_path',
    type=str,
    default='models/align/$l_best_mapping.pth',
    help="Path to the alignment matrix saved in PyTorch format. Use $l as a placeholder for the language.")
parser.add_argument(
    '-l1', '--language1', type=str, default='en',
    help="Language of sentence 1.")
parser.add_argument(
    '-s1', '--sent1',
    type=str,
    default='A house cat is valued by humans for companionship and for its ability to hunt rodents.',
    help="Sentence in language 1.")
parser.add_argument(
    '-w1', '--word1', type=str, default='cat',
    help="Examined word from the sentence in language 1 (first occurrence will be used).")
parser.add_argument(
    '-l2', '--language2', type=str, default='es',
    help="Language of sentence 2.")
parser.add_argument(
    '-s2', '--sent2',
    type=str,
    default='el gato doméstico está incluido en la lista 100 de las especies exóticas invasoras más dañinas del mundo.',
    help="Sentence in language 2.")
parser.add_argument(
    '-w2', '--word2', type=str, default='gato',
    help="Examined word from the sentence in language 2 (first occurrence will be used).")
parser.add_argument(
    '--layer', type=int, default=1,
    help="Layer of ELMo to compute embeddings from.")
parser.add_argument(
    '-c', '--cuda_device', type=int, default=-1,
    help="CUDA device (-1 for CPU).")
args = parser.parse_args()


def parse_config(args):
    '''
    Replaces $l with each of the two languages and prints the resulting args.
    '''
    new_args = copy.deepcopy(args)
    for k in vars(args):
        val = getattr(args, k)
        if isinstance(val, str) and "$l" in val:
            # Create per-language copies of every path that uses the $l placeholder,
            # e.g. align_path -> align_path_l1, align_path_l2.
            new_val = val.replace("$l", args.language1)
            new_k = "{}_{}".format(k, "l1")
            setattr(new_args, new_k, new_val)
            new_val = val.replace("$l", args.language2)
            new_k = "{}_{}".format(k, "l2")
            setattr(new_args, new_k, new_val)
    print('-' * 30)
    for k in vars(new_args):
        print("{}: {}".format(k, getattr(new_args, k)))
    print('-' * 30)
    return new_args


def get_sent_embeds(sent, elmo_options_file, elmo_weights_file, layer, cuda_device):
    '''
    Get the embeddings of the sentence tokens.

    sent - list of tokens
    elmo_options_file - json options file for the model (n_characters should be 262)
    elmo_weights_file - saved model weights
    layer - which layer of ELMo to output
    cuda_device - CUDA device (-1 for CPU)

    Returns a numpy array with one embedding per token for the selected layer.
    '''
    elmo = ElmoEmbedder(elmo_options_file, elmo_weights_file, cuda_device)
    s_embeds = elmo.embed_sentence(sent)  # shape: (num_layers, num_tokens, dim)
    layer_embeds = s_embeds[layer, :, :]
    return layer_embeds


def analyze_sents(embeds_l1, embeds_l2, sent1, sent2, w1_ind, w2_ind, k=5):
    '''
    Prints the k nearest neighbors of word 2 among the tokens of sentence 1,
    and the cosine distance between the embeddings of word 1 and word 2.
    '''
    kdt = KDTree(embeds_l1)
    emb2 = embeds_l2[w2_ind]
    top_k_inds = kdt.query(emb2, k)[1]
    top_k_words = [sent1[i] for i in top_k_inds]
    print('Nearest {} neighbors of "{}" in "{}":\n{}'.format(
        k, sent2[w2_ind], ' '.join(sent1), ', '.join(top_k_words)))
    emb1 = embeds_l1[w1_ind]
    dist = cosine(emb1, emb2)
    print('Cosine distance between "{}" and "{}": {}'.format(
        sent1[w1_ind], sent2[w2_ind], dist))


if __name__ == '__main__':
    args = parse_config(args)

    # Language 1
    sent1_tokens = args.sent1.strip().split()
    w1_ind = sent1_tokens.index(args.word1)
    s1_embeds = get_sent_embeds(sent1_tokens, args.elmo_options_path,
                                args.elmo_weights_path_l1, args.layer,
                                args.cuda_device)
    align1 = torch.load(args.align_path_l1)
    if isinstance(align1, torch.Tensor):
        # The saved mapping may be a torch tensor; torch.Tensor.transpose()
        # requires explicit dims, so convert to numpy before the matmul below.
        align1 = align1.cpu().numpy()
    s1_embeds_aligned = np.matmul(s1_embeds, align1.transpose())

    # Language 2
    sent2_tokens = args.sent2.strip().split()
    w2_ind = sent2_tokens.index(args.word2)
    s2_embeds = get_sent_embeds(sent2_tokens, args.elmo_options_path,
                                args.elmo_weights_path_l2, args.layer,
                                args.cuda_device)
    align2 = torch.load(args.align_path_l2)
    if isinstance(align2, torch.Tensor):
        align2 = align2.cpu().numpy()
    s2_embeds_aligned = np.matmul(s2_embeds, align2.transpose())

    # Analyze
    print("--- Before alignment:")
    analyze_sents(s1_embeds, s2_embeds, sent1_tokens, sent2_tokens,
                  w1_ind, w2_ind)
    print("\n--- After alignment:")
    analyze_sents(s1_embeds_aligned, s2_embeds_aligned, sent1_tokens,
                  sent2_tokens, w1_ind, w2_ind)
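
# Example invocation (a sketch; the filename align_demo.py is an assumption,
# not part of the source). With the default arguments, this compares "cat" in
# the English sentence against "gato" in the Spanish one on CPU, provided the
# per-language weight and alignment files exist under models/ as in the
# defaults above:
#
#   python align_demo.py
#
# Or with the examined words and layer spelled out explicitly:
#
#   python align_demo.py -l1 en -w1 cat -l2 es -w2 gato --layer 1 -c -1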