import torch.nn as nn from modules.model import Model from modules.flows.mog_flow import MogFlow_batch import torch from tools.utils import * from tools.dico_builder import build_dictionary import torch.nn.functional as F from evaluation.word_translation import * from torch.nn import CosineEmbeddingLoss import codecs import scipy class E2E(Model): def __init__(self, args, src_dict, tgt_dict, src_embedding, tgt_embedding, device): super(E2E, self).__init__(args) self.args = args self.src_dict = src_dict self.tgt_dict = tgt_dict # src_flow: assume tgt embeddings are transformed from the src mog space self.register_buffer('src_embedding', src_embedding) self.register_buffer('tgt_embedding', tgt_embedding) if args.init_var: # initialize with gaussian variance self.register_buffer("s2t_s_var", src_dict.var) self.register_buffer("s2t_t_var", tgt_dict.var) self.register_buffer("t2s_s_var", src_dict.var) self.register_buffer("t2s_t_var", tgt_dict.var) else: self.s2t_s_var = args.s_var self.s2t_t_var = args.s2t_t_var self.t2s_t_var = args.t_var self.t2s_s_var = args.t2s_s_var self.register_buffer('src_freqs', torch.tensor(src_dict.freqs, dtype=torch.float)) self.register_buffer('tgt_freqs', torch.tensor(tgt_dict.freqs, dtype=torch.float)) # backward: t2s self.src_flow = MogFlow_batch(args, self.t2s_s_var) # backward: s2t self.tgt_flow = MogFlow_batch(args, self.s2t_t_var) self.s2t_valid_dico = None self.t2s_valid_dico = None self.device = device # use dict pairs from train data (supervise) or identical words (supervise_id) as supervisions self.supervise = args.supervise_id if self.supervise: self.load_training_dico() if args.sup_obj == 'mse': self.sup_loss_func = nn.MSELoss() elif args.sup_obj == 'cosine': self.sup_loss_func = CosineEmbeddingLoss() optim_fn, optim_params= get_optimizer(args.flow_opt_params) self.flow_optimizer = optim_fn(list(self.src_flow.parameters()) + list(self.tgt_flow.parameters()), **optim_params) self.flow_scheduler = torch.optim.lr_scheduler.ExponentialLR(self.flow_optimizer, gamma=args.lr_decay) self.best_valid_metric = 1e-12 self.sup_sw = args.sup_s_weight self.sup_tw = args.sup_t_weight self.mse_loss = nn.MSELoss() self.cos_loss = CosineEmbeddingLoss() # Evaluation on trained model if args.load_from_pretrain_s2t != "" or args.load_from_pretrain_t2s != "": self.load_from_pretrain() def orthogonalize(self): """ Orthogonalize the mapping. """ W1 = self.src_flow.W W2 = self.tgt_flow.W beta = 0.01 with torch.no_grad(): for _ in range(self.args.ortho_steps): W1.copy_((1 + beta) * W1 - beta * W1.mm(W1.transpose(0, 1).mm(W1))) W2.copy_((1 + beta) * W2 - beta * W2.mm(W2.transpose(0, 1).mm(W2))) def sup_step(self, src_emb, tgt_emb): src_to_tgt, tgt_to_src, _, _ = self.run_flow(src_emb, tgt_emb, 'both', False) if self.args.sup_obj == "mse": s2t_sim = (src_to_tgt * tgt_emb).sum(dim=1) s2t_sup_loss = self.sup_loss_func(s2t_sim, torch.ones_like(s2t_sim)) t2s_sim = (tgt_to_src * src_emb).sum(dim=1) t2s_sup_loss = self.sup_loss_func(t2s_sim, torch.ones_like(t2s_sim)) loss = s2t_sup_loss + t2s_sup_loss elif self.args.sup_obj == "cosine": target = torch.ones(src_emb.size(0)).to(self.device) s2t_sup_loss = self.sup_loss_func(src_to_tgt, tgt_emb, target) t2s_sup_loss = self.sup_loss_func(tgt_to_src, src_emb, target) loss = s2t_sup_loss + t2s_sup_loss else: raise NotImplementedError # check NaN if (loss != loss).data.any(): print("NaN detected (supervised loss)") exit() return s2t_sup_loss, t2s_sup_loss, loss def flow_step(self, base_src_ids, base_tgt_ids, src_ids, tgt_ids, training_stats, src_emb_in_dict=None, tgt_emb_in_dict=None): src_emb = self.src_embedding[src_ids] tgt_emb = self.tgt_embedding[tgt_ids] base_src_emb = self.src_embedding[base_src_ids] base_tgt_emb = self.tgt_embedding[base_tgt_ids] base_src_var = base_tgt_var = None if self.args.init_var: train_src_var = self.s2t_s_var[src_ids] base_src_var = self.t2s_s_var[base_src_ids] train_tgt_var = self.t2s_t_var[tgt_ids] base_tgt_var = self.s2t_t_var[base_tgt_ids] src_std = torch.sqrt(train_src_var).unsqueeze(1) tgt_std = torch.sqrt(train_tgt_var).unsqueeze(1) else: src_std = math.sqrt(self.s2t_s_var) tgt_std = math.sqrt(self.t2s_t_var) src_emb = src_emb + torch.randn_like(src_emb) * src_std tgt_emb = tgt_emb + torch.randn_like(tgt_emb) * tgt_std if self.args.cofreq: # ids of words are their frequency ranks train_src_freq = src_emb.new_tensor(src_ids) + 1. train_tgt_freq = tgt_emb.new_tensor(tgt_ids) + 1. base_src_freq = src_emb.new_tensor(base_src_ids) + 1. base_tgt_freq = tgt_emb.new_tensor(base_tgt_ids) + 1. else: train_src_freq = train_tgt_freq = None src_freq_normalized = self.src_freqs[base_src_ids] src_freq_normalized = src_freq_normalized / src_freq_normalized.sum() tgt_freq_normalized = self.tgt_freqs[base_tgt_ids] tgt_freq_normalized = tgt_freq_normalized / tgt_freq_normalized.sum() base_src_freq = torch.log(src_freq_normalized) base_tgt_freq = torch.log(tgt_freq_normalized) src_to_tgt, src_ll = self.tgt_flow.backward(src_emb, x=base_tgt_emb, x_freqs=base_tgt_freq, require_log_probs=True, var=base_tgt_var, y_freqs=train_src_freq) tgt_to_src, tgt_ll = self.src_flow.backward(tgt_emb, x=base_src_emb, x_freqs=base_src_freq, require_log_probs=True, var=base_src_var, y_freqs=train_tgt_freq) # the log density of observing src embeddings (transformm to target space) src_nll, tgt_nll = -src_ll.mean(), -tgt_ll.mean() loss = src_nll + tgt_nll if self.args.back_translate_src_w > 0 and self.args.back_translate_tgt_w > 0: target = torch.ones(src_emb.size(0)).to(self.device) tgt_to_src_to_tgt, src_to_tgt_to_src, _, _ = self.run_flow(tgt_to_src, src_to_tgt, 'both', False) src_bt_loss = self.cos_loss(src_emb, src_to_tgt_to_src, target) tgt_bt_loss = self.cos_loss(tgt_emb, tgt_to_src_to_tgt, target) bt_w_src = self.args.back_translate_src_w bt_w_tgt = self.args.back_translate_src_w loss = loss + bt_w_src * src_bt_loss + bt_w_tgt * tgt_bt_loss training_stats["BT_S2T"].append(src_bt_loss.item()) training_stats["BT_T2S"].append(tgt_bt_loss.item()) if self.supervise: assert src_emb_in_dict is not None, tgt_emb_in_dict is not None s2t_sup_loss, t2s_sup_loss, sup_loss = self.sup_step(src_emb_in_dict, tgt_emb_in_dict) loss = loss + self.sup_sw * s2t_sup_loss + self.sup_tw * t2s_sup_loss training_stats["Sup_S2T"].append(s2t_sup_loss.item()) training_stats["Sup_T2S"].append(t2s_sup_loss.item()) else: sup_loss = torch.tensor(0.0) loss.backward() self.flow_optimizer.step() self.flow_scheduler.step() self.flow_optimizer.zero_grad() loss, src_nll, tgt_nll, sup_loss = loss.item(), src_nll.item(), tgt_nll.item(), sup_loss.item() if self.args.cuda: torch.cuda.empty_cache() training_stats["S2T_nll"].append(src_nll) training_stats["T2S_nll"].append(tgt_nll) def load_training_dico(self): """ Load training dictionary. """ word2id1 = self.src_dict.word2id word2id2 = self.tgt_dict.word2id valid_dico_size = 1000 if self.args.supervise_id > 0: id_dict_1, id_dict_2 = load_identical_char_dico(word2id1, word2id2) print("Idenditical dictionary pairs = %d, %d" % (id_dict_1.size(0), id_dict_2.size(0))) dict = id_dict_1[:self.args.supervise_id, :] else: dict = torch.tensor(0) if self.args.valid_option == "train": dict_s2t = load_dictionary(self.args.sup_dict_path, word2id1, word2id2) t2s_dict_path = os.path.join(os.path.dirname(self.args.sup_dict_path), self.tgt_dict.lang + "-" + self.src_dict.lang + ".0-5000.txt") dict_t2s = load_dictionary(t2s_dict_path, word2id2, word2id1, reverse=True) ids_s2t = list(np.random.permutation(range(dict_s2t.size(0)))) ids_t2s = list(np.random.permutation(range(dict_t2s.size(0)))) self.s2t_valid_dico = dict_s2t[ids_s2t[0: valid_dico_size], :] self.t2s_valid_dico = dict_t2s[ids_t2s[0: valid_dico_size], :] self.t2s_valid_dico = torch.cat([dict_t2s[:, 1].unsqueeze(1), dict_t2s[:, 0].unsqueeze(1)], dim=1) print("Loading validation dictionary: %d %d" % (self.s2t_valid_dico.size(0), self.t2s_valid_dico.size(0))) for w1, w2 in self.s2t_valid_dico[:100]: print(self.src_dict.id2word[w1.item()], self.tgt_dict.id2word[w2.item()]) print("-" * 30) for w1, w2 in self.t2s_valid_dico[:100]: print(self.tgt_dict.id2word[w1.item()], self.src_dict.id2word[w2.item()]) print("Pruning dictionary pairs = %d" % dict.size(0)) # toch.LongTensor: [len(pairs), 2] self.dict = dict def run_flow(self, src_emb=None, tgt_emb=None, side="both", require_logll=True): if side == "src": # from src to tgt assert src_emb is not None src_to_tgt, src_log_ll = self.tgt_flow.backward(src_emb, require_log_probs=require_logll) return src_to_tgt, src_log_ll elif side == "tgt": assert tgt_emb is not None tgt_to_src, tgt_log_ll = self.src_flow.backward(tgt_emb, require_log_probs=require_logll) return tgt_to_src, tgt_log_ll elif side == "both": assert tgt_emb is not None and src_emb is not None src_to_tgt, src_log_ll = self.tgt_flow.backward(src_emb, require_log_probs=require_logll) tgt_to_src, tgt_log_ll = self.src_flow.backward(tgt_emb, require_log_probs=require_logll) return src_to_tgt, tgt_to_src, src_log_ll, tgt_log_ll def map_embs(self, src_emb, tgt_emb, s2t=True, t2s=True): src2tgt_emb = tgt2src_emb = None with torch.no_grad(): if s2t: src_to_tgt_list = [] for i, j in get_batches(src_emb.size(0), self.args.dico_batch_size): src_emb_batch = src_emb[i:j, :]#.to(self.device) src_to_tgt, _ = self.run_flow(src_emb=src_emb_batch, side="src", require_logll=False) src_to_tgt_list.append(src_to_tgt.cpu()) # reside on cpu src2tgt_emb = torch.cat(src_to_tgt_list, dim=0) if t2s: tgt_to_src_list = [] for i, j in get_batches(tgt_emb.size(0), self.args.dico_batch_size): tgt_emb_batch = tgt_emb[i:j, :]#.to(self.device) tgt_to_src, _ = self.run_flow(tgt_emb=tgt_emb_batch, side="tgt", require_logll=False) tgt_to_src_list.append(tgt_to_src.cpu()) tgt2src_emb = torch.cat(tgt_to_src_list, dim=0) return src2tgt_emb, tgt2src_emb def build_dictionary(self, src_emb, tgt_emb, s2t=True, t2s=True): # Build dictionary with current trained mappings to augment the original dictionary src_to_tgt_emb, tgt_to_src_emb = self.map_embs(src_emb, tgt_emb, s2t=s2t, t2s=t2s) # torch.longTensor topk = 50000 if s2t: self.build_s2t_dict = torch.cat([self.dict_s2t, build_dictionary(src_to_tgt_emb.cuda()[:topk], tgt_emb[:topk], self.args)], dim=0) s2t = self.build_s2t_dict for i in range(300, 320): print(self.src_dict.id2word[s2t[i, 0].item()], self.tgt_dict.id2word[s2t[i, 1].item()]) if t2s: self.build_t2s_dict = torch.cat([self.dict_t2s, build_dictionary(tgt_to_src_emb.cuda()[:topk], src_emb[:topk], self.args)], dim=0) t2s = self.build_t2s_dict print("---" * 20) for i in range(300, 320): print(self.src_dict.id2word[t2s[i, 1].item()], self.tgt_dict.id2word[t2s[i, 0].item()]) def procrustes(self, src_emb, tgt_emb, s2t=True, t2s=True): """ Find the best orthogonal matrix mapping using the Orthogonal Procrustes problem https://en.wikipedia.org/wiki/Orthogonal_Procrustes_problem """ if s2t: A = src_emb[self.build_s2t_dict[:, 0]] B = tgt_emb[self.build_s2t_dict[:, 1]] W = self.tgt_flow.W M = B.transpose(0, 1).mm(A).cpu().numpy() U, S, V_t = scipy.linalg.svd(M, full_matrices=True) with torch.no_grad(): W.copy_(torch.from_numpy((U.dot(V_t)).transpose()).type_as(W)) if t2s: A = tgt_emb[self.build_t2s_dict[:, 0]] B = src_emb[self.build_t2s_dict[:, 1]] W2 = self.src_flow.W M = B.transpose(0, 1).mm(A).cpu().numpy() U, S, V_t = scipy.linalg.svd(M, full_matrices=True) with torch.no_grad(): W2.copy_(torch.from_numpy((U.dot(V_t)).transpose()).type_as(W2)) def load_best_from_both_sides(self): self.load_best_s2t() self.load_best_t2s() def load_best_s2t(self): print("Load src to tgt mapping to %s" % self.s2t_save_to) to_reload = torch.from_numpy(torch.load(self.s2t_save_to)) with torch.no_grad(): W1 = self.tgt_flow.W W1.copy_(to_reload.type_as(W1)) def load_best_t2s(self): print("Load src to tgt mapping to %s" % self.t2s_save_to) to_reload = torch.from_numpy(torch.load(self.t2s_save_to)) with torch.no_grad(): W1 = self.src_flow.W W1.copy_(to_reload.type_as(W1)) def save_best_s2t(self): print("Save src to tgt mapping to %s" % self.s2t_save_to) with torch.no_grad(): torch.save(self.tgt_flow.W.cpu().numpy(), self.s2t_save_to) def save_best_t2s(self): print("Save tgt to src mapping to %s" % self.t2s_save_to) with torch.no_grad(): torch.save(self.src_flow.W.cpu().numpy(), self.t2s_save_to) def export_embeddings(self, src_emb, tgt_emb, exp_path): self.load_best_from_both_sides() mapped_src_emb, mapped_tgt_emb = self.map_embs(src_emb, tgt_emb) src_path = exp_path + self.src_dict.lang + "2" + self.tgt_dict.lang + "_emb.vec" tgt_path = exp_path + self.tgt_dict.lang + "2" + self.src_dict.lang + "_emb.vec" mapped_src_emb = mapped_src_emb.cpu().numpy() mapped_tgt_emb = mapped_tgt_emb.cpu().numpy() print(f'Writing source embeddings to {src_path}') with io.open(src_path, 'w', encoding='utf-8') as f: f.write(u"%i %i\n" % mapped_src_emb.shape) for i in range(len(self.src_dict)): f.write(u"%s %s\n" % (self.src_dict[i], " ".join('%.5f' % x for x in mapped_src_emb[i]))) print(f'Writing target embeddings to {tgt_path}') with io.open(tgt_path, 'w', encoding='utf-8') as f: f.write(u"%i %i\n" % mapped_tgt_emb.shape) for i in range(len(self.tgt_dict)): f.write(u"%s %s\n" % (self.tgt_dict[i], " ".join('%.5f' % x for x in mapped_tgt_emb[i]))) def load_from_pretrain(self): # load src to tgt W for tgt flow if self.args.load_from_pretrain_s2t is not None: print("Loading from pretrained model %s!" % self.args.load_from_pretrain_s2t) with torch.no_grad(): s2t = torch.from_numpy(torch.load(self.args.load_from_pretrain_s2t)) W1 = self.tgt_flow.W W1.copy_(s2t.type_as(W1)) if self.args.load_from_pretrain_t2s is not None: print("Loading from pretrained model %s!" % self.args.load_from_pretrain_t2s) with torch.no_grad(): t2s = torch.from_numpy(torch.load(self.args.load_from_pretrain_t2s)) W2 = self.src_flow.W W2.copy_(t2s.type_as(W2)) def write_topK(self, dico, topk, fname, id2word_1, id2word_2): dico = dico.cpu().numpy() topk = topk.cpu().numpy() assert dico.shape[0] == topk.shape[0] with codecs.open("../analysis/" + fname, "w", "utf-8") as fout: d = dict() for t, (w1, w2) in enumerate(dico): word_1 = id2word_1[w1] top_10 = [id2word_2[i] for i in topk[t, :]] if word_1 not in d: d[word_1] = [] d[word_1].append(top_10) if id2word_2[w2] in top_10: score = top_10.index(id2word_2[w2]) else: score = -1 d[word_1].append((id2word_2[w2], score)) for kword, ll in d.items(): best_score = -1 fout.write(kword + ": " + " ".join(["Top 10:"] + ll[0]) + "\n") groud_words = [] for word_2, s in ll[1:]: if s > best_score: best_score = s groud_words.append(word_2) fout.write("Ground Truth words: " + " ".join(groud_words) + "\n") fout.write("Best match: " + str(best_score) + "\n") fout.write("-" * 50 + "\n") def check_word_translation(self, full_src_emb, full_tgt_emb, topK=True, density=False): src_to_tgt_emb, tgt_to_src_emb = self.map_embs(full_src_emb, full_tgt_emb) s2t_path = self.src_dict.lang + "-" + self.tgt_dict.lang + ".topK" t2s_path = self.tgt_dict.lang + "-" + self.src_dict.lang + ".topK" if density: print("<%s> TO <%s> Evaluation!" % (self.src_dict.lang, self.tgt_dict.lang)) for method in ['density']: s2t_dico, s2t_top_k = get_word_translation_accuracy( self.src_dict.lang, self.src_dict.word2id, src_to_tgt_emb, # query self.tgt_dict.lang, self.tgt_dict.word2id, full_tgt_emb.cpu(), method=method, dico_eval=self.args.dico_eval, get_scores=topK, var=self.args.s2t_t_var ) self.write_topK(s2t_dico, s2t_top_k, s2t_path, self.src_dict.id2word, self.tgt_dict.id2word) print("<%s> TO <%s> Evaluation!" % (self.tgt_dict.lang, self.src_dict.lang)) tgt_to_src_path = os.path.join(os.path.dirname(self.args.dico_eval), self.args.tgt_lang + "-" + self.args.src_lang + ".5000-6500.txt") for method in ['density']: t2s_dico, t2s_top_k = get_word_translation_accuracy( self.tgt_dict.lang, self.tgt_dict.word2id, tgt_to_src_emb, # query self.src_dict.lang, self.src_dict.word2id, full_src_emb.cpu(), method=method, dico_eval=tgt_to_src_path, get_scores=topK, var=self.args.t2s_s_var ) self.write_topK(t2s_dico, t2s_top_k, t2s_path, self.tgt_dict.id2word, self.src_dict.id2word) return if topK: print("<%s> TO <%s> Evaluation!" % (self.src_dict.lang, self.tgt_dict.lang)) for method in ['nn', 'csls_knn_10']: s2t_dico, s2t_top_k = get_word_translation_accuracy( self.src_dict.lang, self.src_dict.word2id, src_to_tgt_emb, # query self.tgt_dict.lang, self.tgt_dict.word2id, full_tgt_emb.cpu(), method=method, dico_eval=self.args.dico_eval, get_scores=topK ) self.write_topK(s2t_dico, s2t_top_k, s2t_path, self.src_dict.id2word, self.tgt_dict.id2word) print("<%s> TO <%s> Evaluation!" % (self.tgt_dict.lang, self.src_dict.lang)) tgt_to_src_path = os.path.join(os.path.dirname(self.args.dico_eval), self.args.tgt_lang + "-" + self.args.src_lang + ".5000-6500.txt") for method in ['nn', 'csls_knn_10']: t2s_dico, t2s_top_k = get_word_translation_accuracy( self.tgt_dict.lang, self.tgt_dict.word2id, tgt_to_src_emb, # query self.src_dict.lang, self.src_dict.word2id, full_src_emb.cpu(), method=method, dico_eval=tgt_to_src_path, get_scores=topK ) self.write_topK(t2s_dico, t2s_top_k, t2s_path, self.tgt_dict.id2word, self.src_dict.id2word)