""" This script contains the functions related to Gromove-Wasserstein Learning """ import copy from dev.util import logger import matplotlib.pyplot as plt import numpy as np import pickle from preprocess.DataIO import IndexSampler, cost_sampler1, cost_sampler2 from sklearn.manifold import TSNE import torch import torch.nn as nn from torch.utils.data import DataLoader class GromovWassersteinEmbedding(nn.Module): """ Learning embeddings from Cosine similarity """ def __init__(self, num1: int, num2: int, dim: int, cost_type: str = 'cosine', loss_type: str = 'L2'): super(GromovWassersteinEmbedding, self).__init__() self.num1 = num1 self.num2 = num2 self.dim = dim self.cost_type = cost_type self.loss_type = loss_type emb1 = nn.Embedding(self.num1, self.dim) emb1.weight = nn.Parameter( torch.FloatTensor(self.num1, self.dim).uniform_(-1 / self.dim, 1 / self.dim)) emb2 = nn.Embedding(self.num2, self.dim) emb2.weight = nn.Parameter( torch.FloatTensor(self.num2, self.dim).uniform_(-1 / self.dim, 1 / self.dim)) self.emb_model = nn.ModuleList([emb1, emb2]) def orthogonal(self, index, idx): embs = self.emb_model[idx](index) orth = torch.matmul(torch.t(embs), embs) orth -= torch.eye(embs.size(1)) return (orth**2).sum() def self_cost_mat(self, index, idx): embs = self.emb_model[idx](index) # (batch_size, dim) if self.cost_type == 'cosine': # cosine similarity energy = torch.sqrt(torch.sum(embs ** 2, dim=1, keepdim=True)) # (batch_size, 1) cost = 1-torch.exp(-5*(1-torch.matmul(embs, torch.t(embs)) / (torch.matmul(energy, torch.t(energy)) + 1e-5))) else: # Euclidean distance embs = torch.matmul(embs, torch.t(embs)) # (batch_size, batch_size) embs_diag = torch.diag(embs).view(-1, 1).repeat(1, embs.size(0)) # (batch_size, batch_size) cost = 1-torch.exp(-(embs_diag + torch.t(embs_diag) - 2 * embs)/embs.size(1)) return cost def mutual_cost_mat(self, index1, index2): embs1 = self.emb_model[0](index1) # (batch_size1, dim) embs2 = self.emb_model[1](index2) # (batch_size2, dim) if self.cost_type == 'cosine': # cosine similarity energy1 = torch.sqrt(torch.sum(embs1 ** 2, dim=1, keepdim=True)) # (batch_size1, 1) energy2 = torch.sqrt(torch.sum(embs2 ** 2, dim=1, keepdim=True)) # (batch_size2, 1) cost = 1-torch.exp(-(1-torch.matmul(embs1, torch.t(embs2))/(torch.matmul(energy1, torch.t(energy2))+1e-5))) else: # Euclidean distance embs = torch.matmul(embs1, torch.t(embs2)) # (batch_size1, batch_size2) # (batch_size1, batch_size2) embs_diag1 = torch.diag(torch.matmul(embs1, torch.t(embs1))).view(-1, 1).repeat(1, embs2.size(0)) # (batch_size2, batch_size1) embs_diag2 = torch.diag(torch.matmul(embs2, torch.t(embs2))).view(-1, 1).repeat(1, embs1.size(0)) cost = 1-torch.exp(-(embs_diag1 + torch.t(embs_diag2) - 2 * embs)/embs1.size(1)) return cost def tensor_times_mat(self, cost_s, cost_t, trans, mu_s, mu_t): if self.loss_type == 'L2': # f1(a) = a^2, f2(b) = b^2, h1(a) = a, h2(b) = 2b # cost_st = f1(cost_s)*mu_s*1_nt^T + 1_ns*mu_t^T*f2(cost_t)^T # cost = cost_st - h1(cost_s)*trans*h2(cost_t)^T f1_st = torch.matmul(cost_s ** 2, mu_s).repeat(1, trans.size(1)) f2_st = torch.matmul(torch.t(mu_t), torch.t(cost_t ** 2)).repeat(trans.size(0), 1) cost_st = f1_st + f2_st cost = cost_st - 2 * torch.matmul(torch.matmul(cost_s, trans), torch.t(cost_t)) else: # f1(a) = a*log(a) - a, f2(b) = b, h1(a) = a, h2(b) = log(b) # cost_st = f1(cost_s)*mu_s*1_nt^T + 1_ns*mu_t^T*f2(cost_t)^T # cost = cost_st - h1(cost_s)*trans*h2(cost_t)^T f1_st = torch.matmul(cost_s * torch.log(cost_s + 1e-5) - cost_s, mu_s).repeat(1, trans.size(1)) f2_st = torch.matmul(torch.t(mu_t), torch.t(cost_t)).repeat(trans.size(0), 1) cost_st = f1_st + f2_st cost = cost_st - torch.matmul(torch.matmul(cost_s, trans), torch.t(torch.log(cost_t + 1e-5))) return cost def similarity(self, cost_pred, cost_truth, mask=None): if mask is None: if self.loss_type == 'L2': loss = ((cost_pred - cost_truth) ** 2) * torch.exp(-cost_truth) else: loss = cost_pred * torch.log(cost_pred / (cost_truth + 1e-5)) else: if self.loss_type == 'L2': # print(mask.size()) # print(cost_truth.size()) # print(cost_pred.size()) loss = mask.data * ((cost_pred - cost_truth) ** 2) * torch.exp(-cost_truth) else: loss = mask.data * (cost_pred * torch.log(cost_pred / (cost_truth + 1e-5))) loss = loss.sum() return loss def forward(self, index1, index2, trans, mu_s, mu_t, cost1, cost2, prior=None, mask1=None, mask2=None, mask12=None): cost_s = self.self_cost_mat(index1, 0) cost_t = self.self_cost_mat(index2, 1) cost_st = self.mutual_cost_mat(index1, index2) cost = self.tensor_times_mat(cost_s, cost_t, trans, mu_s, mu_t) d_gw = (cost * trans).sum() d_w = (cost_st * trans).sum() regularizer = self.similarity(cost_s, cost1, mask1) + self.similarity(cost_t, cost2, mask2) regularizer += self.orthogonal(index1, 0) + self.orthogonal(index2, 1) if prior is not None: regularizer += self.similarity(cost_st, prior, mask12) return d_gw, d_w, regularizer def plot_and_save(self, index1: torch.Tensor, index2: torch.Tensor, output_name: str = None): """ Plot and save cost matrix Args: index1: a (batch_size, 1) Long/CudaLong Tensor indicating the indices of entities index2: a (batch_size, 1) Long/CudaLong Tensor indicating the indices of entities output_name: a string indicating the output image's name Returns: save cost matrix as a .png file """ cost_s = self.self_cost_mat(index1, 0).data.cpu().numpy() cost_t = self.self_cost_mat(index2, 0).data.cpu().numpy() cost_st = self.mutual_cost_mat(index1, index2).data.cpu().numpy() pc_kwargs = {'rasterized': True, 'cmap': 'viridis'} fig, axs = plt.subplots(1, 3, figsize=(5, 5), constrained_layout=True) im = axs[0, 0].pcolormesh(cost_s, **pc_kwargs) fig.colorbar(im, ax=axs[0, 0]) axs[0, 0].set_title('source cost') axs[0, 0].set_aspect('equal') im = axs[0, 1].pcolormesh(cost_t, **pc_kwargs) fig.colorbar(im, ax=axs[0, 1]) axs[0, 1].set_title('target cost') axs[0, 1].set_aspect('equal') im = axs[0, 2].pcolormesh(cost_st, **pc_kwargs) fig.colorbar(im, ax=axs[0, 2]) axs[0, 2].set_title('mutual cost') axs[0, 2].set_aspect('equal') if output_name is None: plt.savefig('result.png') else: plt.savefig(output_name) plt.close("all") class GromovWassersteinLearning(object): """ Learning Gromov-Wasserstein distance in a nonparametric way. """ def __init__(self, hyperpara_dict): """ Initialize configurations Args: hyperpara_dict: a dictionary containing the configurations of model dict = {'src_number': the number of entities in the source domain, 'tar_number': the number of entities in the target domain, 'dimension': the proposed dimension of entities' embeddings, 'loss_type': 'KL' or 'L2' } """ self.src_num = hyperpara_dict['src_number'] self.tar_num = hyperpara_dict['tar_number'] self.dim = hyperpara_dict['dimension'] self.loss_type = hyperpara_dict['loss_type'] self.cost_type = hyperpara_dict['cost_type'] self.ot_method = hyperpara_dict['ot_method'] self.gwl_model = GromovWassersteinEmbedding(self.src_num, self.tar_num, self.dim, self.loss_type) self.d_gw = [] self.trans = np.zeros((self.src_num, self.tar_num)) self.Prec = [] self.Recall = [] self.F1 = [] self.NC1 = [] self.NC2 = [] self.EC1 = [] self.EC2 = [] def plot_result(self, index_s, index_t, epoch, prefix): # tsne embs_s = self.gwl_model.emb_model[0](index_s) embs_t = self.gwl_model.emb_model[1](index_t) embs = np.concatenate((embs_s.cpu().data.numpy(), embs_t.cpu().data.numpy()), axis=0) embs = TSNE(n_components=2).fit_transform(embs) plt.figure(figsize=(5, 5)) plt.scatter(embs[:embs_s.size(0), 0], embs[:embs_s.size(0), 1], marker='.', s=0.5, c='b', edgecolors='b', label='graph 1') plt.scatter(embs[-embs_t.size(0):, 0], embs[-embs_t.size(0):, 1], marker='o', s=8, c='', edgecolors='r', label='graph 2') leg = plt.legend(loc='upper left', ncol=1, shadow=True, fancybox=True) leg.get_frame().set_alpha(0.5) plt.title('T-SNE of node embeddings') plt.savefig('{}/emb_epoch{}_{}_{}.pdf'.format(prefix, epoch, self.ot_method, self.cost_type)) plt.close("all") trans_b = np.zeros(self.trans.shape) for i in range(trans_b.shape[0]): idx = np.argmax(self.trans[i, :]) trans_b[i, idx] = 1 plt.imshow(trans_b) plt.savefig('{}/trans_epoch{}_{}_{}.png'.format(prefix, epoch, self.ot_method, self.cost_type)) plt.close('all') def evaluation_matching(self, trans: np.ndarray, cost_s: np.ndarray, cost_t: np.ndarray, index_s: np.ndarray, index_t: np.ndarray, mask_s: np.ndarray, mask_t: np.ndarray): """ Evaluate graph matching result Args: trans: (ns, nt) ndarray cost_s: (ns, ns) ndarray of source cost cost_t: (nt, nt) ndarray of target cost index_s: (ns, ) ndarray of source index index_t: (nt, ) ndarray of target index Returns: nc1: node correctness based on trans: #correctly-matched nodes/#nodes * 100% ec1: edge correctness based on trans: #correctly-matched edges/#edges * 100% nc2: node correctness based on cost_st ec2: edge correctness based on cost_st """ nc1 = 0 nc2 = 0 ec1 = 0 ec2 = 0 num_edges = np.sum(mask_s) cost_s += np.eye(trans.shape[0]) cost_s = 1 / cost_s cost_s -= 1 cost_s[cost_s < 1] = 0 cost_t += np.eye(trans.shape[1]) cost_t = 1 / cost_t cost_t -= 1 cost_t[cost_t < 1] = 0 # edge correctness cost_st = self.gwl_model.mutual_cost_mat(index_s, index_t) cost_st = cost_st.cpu().data.numpy() pair1 = [] pair2 = [] for i in range(trans.shape[0]): j1 = np.argmax(trans[i, :]) j2 = np.argmin(cost_st[i, :]) pair1.append(j1) pair2.append(j2) if index_s[i] == index_t[j1]: nc1 += 1 if index_s[i] == index_t[j2]: nc2 += 1 nc1 = nc1 / trans.shape[0] * 100. nc2 = nc2 / trans.shape[0] * 100. idx = np.transpose(np.nonzero(cost_s)) for n in range(idx.shape[0]): rs = idx[n, 0] cs = idx[n, 1] rt1 = pair1[rs] rt2 = pair2[rs] ct1 = pair1[cs] ct2 = pair2[cs] if mask_t[rt1, ct1] > 0 or mask_t[ct1, rt1] > 0: ec1 += 1 if mask_t[rt2, ct2] > 0 or mask_t[ct2, rt2] > 0: ec2 += 1 ec1 = ec1 / num_edges * 100. ec2 = ec2 / num_edges * 100. return nc1, ec1, nc2, ec2 def evaluation_recommendation(self, database): index_s = torch.LongTensor(list(range(self.src_num))) index_t = torch.LongTensor(list(range(self.tar_num))) cost_st = self.gwl_model.mutual_cost_mat(index_s, index_t) cost_st = cost_st.cpu().data.numpy() prec = np.zeros((3,)) recall = np.zeros((3,)) f1 = np.zeros((3,)) tops = [1, 3, 5] num = 0 for n in range(len(database['mutual_interactions'])): pair = database['mutual_interactions'][n] source_list = pair[0] target_list = pair[1] prec_n = np.zeros((3,)) recall_n = np.zeros((3,)) for i in range(len(source_list)): s = source_list[i] if i == 0: items = cost_st[s, :] else: items += cost_st[s, :] idx = np.argsort(items) for i in range(len(tops)): top = tops[i] top_items = idx[:(top*len(target_list))] for recommend_item in top_items: if recommend_item in target_list: prec_n[i] += 1/top recall_n[i] += 1/len(target_list) prec_n *= 100 recall_n *= 100 f1_n = (2*prec_n*recall_n)/(prec_n+recall_n+1e-8) prec += prec_n recall += recall_n f1 += f1_n num += 1 # for n in range(len(database['mutual_interactions'])): # pair = database['mutual_interactions'][n] # source_list = pair[0] # target_list = pair[1] # for s in source_list: # prec_s = np.zeros((3,)) # recall_s = np.zeros((3,)) # # items = cost_st[s, :] # idx = np.argsort(items) # from small to large # for i in range(len(tops)): # top = tops[i] # top_items = idx[:top] # for recommend_item in top_items: # if recommend_item in target_list: # prec_s[i] += 1/top # recall_s[i] += 1/len(target_list) # prec_s *= 100 # recall_s *= 100 # f1_s = (2*prec_s*recall_s)/(prec_s+recall_s+1e-8) # # prec += prec_s # recall += recall_s # f1 += f1_s # num += 1 prec /= num recall /= num f1 /= num return prec, recall, f1 def regularized_gromov_wasserstein_discrepancy(self, cost_s, cost_t, cost_mutual, mu_s, mu_t, hyperpara_dict): """ Learning optimal transport from source to target domain Args: cost_s: (Ns, Ns) matrix representing the relationships among source entities cost_t: (Nt, Nt) matrix representing the relationships among target entities cost_mutual: (Ns, Nt) matrix representing the prior of proposed optimal transport mu_s: (Ns, 1) vector representing marginal probability of source entities mu_t: (Nt, 1) vector representing marginal probability of target entities hyperpara_dict: a dictionary of hyperparameters dict = {epochs: the number of epochs, batch_size: batch size, use_cuda: use cuda or not, strategy: hard or soft, beta: the weight of proximal term outer_iter: the outer iteration of ipot inner_iter: the inner iteration of sinkhorn prior: True or False } Returns: """ ns = mu_s.size(0) nt = mu_t.size(0) trans = torch.matmul(mu_s, torch.t(mu_t)) a = mu_s.sum().repeat(ns, 1) a /= a.sum() b = 0 beta = hyperpara_dict['beta'] if self.loss_type == 'L2': # f1(a) = a^2, f2(b) = b^2, h1(a) = a, h2(b) = 2b # cost_st = f1(cost_s)*mu_s*1_nt^T + 1_ns*mu_t^T*f2(cost_t)^T # cost = cost_st - h1(cost_s)*trans*h2(cost_t)^T f1_st = torch.matmul(cost_s ** 2, mu_s).repeat(1, nt) f2_st = torch.matmul(torch.t(mu_t), torch.t(cost_t ** 2)).repeat(ns, 1) cost_st = f1_st + f2_st for t in range(hyperpara_dict['outer_iteration']): cost = cost_st - 2 * torch.matmul(torch.matmul(cost_s, trans), torch.t(cost_t)) + 0.1*cost_mutual if self.ot_method == 'proximal': kernel = torch.exp(-cost / beta) * trans else: kernel = torch.exp(-cost / beta) for l in range(hyperpara_dict['inner_iteration']): b = mu_t / torch.matmul(torch.t(kernel), a) a = mu_s / torch.matmul(kernel, b) # print((b**2).sum()) # print((a**2).sum()) # print((b**2).sum()*(a**2).sum()) trans = torch.matmul(torch.matmul(torch.diag(a[:, 0]), kernel), torch.diag(b[:, 0])) if t % 100 == 0: print('sinkhorn iter {}/{}'.format(t, hyperpara_dict['outer_iteration'])) cost = cost_st - 2 * torch.matmul(torch.matmul(cost_s, trans), torch.t(cost_t)) else: # f1(a) = a*log(a) - a, f2(b) = b, h1(a) = a, h2(b) = log(b) # cost_st = f1(cost_s)*mu_s*1_nt^T + 1_ns*mu_t^T*f2(cost_t)^T # cost = cost_st - h1(cost_s)*trans*h2(cost_t)^T f1_st = torch.matmul(cost_s * torch.log(cost_s + 1e-5) - cost_s, mu_s).repeat(1, nt) f2_st = torch.matmul(torch.t(mu_t), torch.t(cost_t)).repeat(ns, 1) cost_st = f1_st + f2_st for t in range(hyperpara_dict['outer_iteration']): cost = cost_st - torch.matmul(torch.matmul(cost_s, trans), torch.t(torch.log(cost_t + 1e-5))) if self.ot_method == 'proximal': kernel = torch.exp(-cost / beta) * trans else: kernel = torch.exp(-cost / beta) for l in range(hyperpara_dict['inner_iteration']): b = mu_t / torch.matmul(torch.t(kernel), a) a = mu_s / torch.matmul(kernel, b) trans = torch.matmul(torch.matmul(torch.diag(a[:, 0]), kernel), torch.diag(b[:, 0])) cost = cost_st - torch.matmul(torch.matmul(cost_s, trans), torch.t(torch.log(cost_t + 1e-5))) d_gw = (cost * trans).sum() return trans, d_gw, cost def train_without_prior(self, database, optimizer, hyperpara_dict, scheduler=None): """ Regularized Gromov-Wasserstein Embedding Args: database: proposed database optimizer: the pytorch optimizer hyperpara_dict: a dictionary of hyperparameters dict = {epochs: the number of epochs, batch_size: batch size, use_cuda: use cuda or not, strategy: hard or soft, beta: the weight of proximal term outer_iter: the outer iteration of ipot inner_iter: the inner iteration of sinkhorn prior: True or False } scheduler: scheduler of learning rate. Returns: d_gw, trans """ device = torch.device('cuda:0' if hyperpara_dict['use_cuda'] else 'cpu') if hyperpara_dict['use_cuda']: torch.cuda.manual_seed(1) kwargs = {'num_workers': 1, 'pin_memory': True} if hyperpara_dict['use_cuda'] else {} self.gwl_model.to(device) self.gwl_model.train() num_src_node = len(database['src_interactions']) num_tar_node = len(database['tar_interactions']) src_loader = DataLoader(IndexSampler(num_src_node), batch_size=hyperpara_dict['batch_size'], shuffle=True, **kwargs) tar_loader = DataLoader(IndexSampler(num_tar_node), batch_size=hyperpara_dict['batch_size'], shuffle=True, **kwargs) for epoch in range(hyperpara_dict['epochs']): gw = 0 trans_tmp = np.zeros(self.trans.shape) if scheduler is not None: scheduler.step() for src_idx, indices1 in enumerate(src_loader): for tar_idx, indices2 in enumerate(tar_loader): # Estimate Gromov-Wasserstein discrepancy give current costs cost_s, cost_t, mu_s, mu_t, index_s, index_t, mask_s, mask_t = \ cost_sampler2(database, indices1, indices2, device) if hyperpara_dict['display']: self.plot_result(index_s, index_t, epoch, prefix=hyperpara_dict['prefix']) if hyperpara_dict['strategy'] == 'hard': z = np.random.rand() if z < epoch/hyperpara_dict['epochs']: # cost1 = mask_s.data * self.gwl_model.self_cost_mat(index_s, 0).data # cost2 = mask_t.data * self.gwl_model.self_cost_mat(index_t, 1).data cost1 = self.gwl_model.self_cost_mat(index_s, 0).data cost2 = self.gwl_model.self_cost_mat(index_t, 1).data cost12 = self.gwl_model.mutual_cost_mat(index_s, index_t).data else: cost1 = cost_s cost2 = cost_t cost12 = 0 else: # cost_s_emb = mask_s.data * self.gwl_model.self_cost_mat(index_s, 0).data # cost_t_emb = mask_t.data * self.gwl_model.self_cost_mat(index_t, 1).data cost_s_emb = self.gwl_model.self_cost_mat(index_s, 0).data cost_t_emb = self.gwl_model.self_cost_mat(index_t, 1).data cost_st_12 = self.gwl_model.mutual_cost_mat(index_s, index_t).data # alpha = max([(hyperpara_dict['epochs'] - epoch) / hyperpara_dict['epochs'], 0.5]) alpha = (hyperpara_dict['epochs'] - epoch) / hyperpara_dict['epochs'] cost1 = alpha * cost_s + (1-alpha) * cost_s_emb cost2 = alpha * cost_t + (1-alpha) * cost_t_emb cost12 = (1-alpha) * cost_st_12 trans, d_gw, cost_12 = self.regularized_gromov_wasserstein_discrepancy(cost1, cost2, cost12, mu_s, mu_t, hyperpara_dict) # estimate optimal transport trans_np = trans.cpu().data.numpy() index_s_np = index_s.cpu().data.numpy() index_t_np = index_t.cpu().data.numpy() patch = self.trans[index_s_np, :] patch = patch[:, index_t_np] energy = np.sum(patch) + 1 for row in range(trans_np.shape[0]): for col in range(trans_np.shape[1]): trans_tmp[index_s_np[row], index_t_np[col]] += (energy * trans_np[row, col]) gw += d_gw if epoch == 0: sgd_iter = hyperpara_dict['sgd_iteration'] else: sgd_iter = 100 # inner iteration based on SGD for num in range(sgd_iter): # zero the parameter gradients optimizer.zero_grad() # Update source and target embeddings alternatively loss_gw, loss_w, regularizer = self.gwl_model(index_s, index_t, trans, mu_s, mu_t, cost_s, cost_t, prior=cost_12, mask1=mask_s, mask2=mask_t, mask12=None) loss = 1e3 * loss_gw + 1e3 * loss_w + regularizer loss.backward() optimizer.step() if num % 10 == 0: print('inner {}/{}: loss={:.6f}.'.format(num, sgd_iter, loss.data)) nc1, ec1, nc2, ec2 = self.evaluation_matching(trans_np, cost_s.cpu().data.numpy(), cost_t.cpu().data.numpy(), index_s, index_t, mask_s.cpu().data.numpy(), mask_t.cpu().data.numpy()) self.NC1.append(nc1) self.NC2.append(nc2) self.EC1.append(ec1) self.EC2.append(ec2) logger.info('Train Epoch: {}'.format(epoch)) logger.info('- node correctness: {:.4f}%, {:.4f}%'.format(nc1, nc2)) logger.info('- edge correctness: {:.4f}%, {:.4f}%'.format(ec1, ec2)) if src_idx % 100 == 1: logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]'.format( epoch, src_idx * hyperpara_dict['batch_size'], len(src_loader.dataset), 100. * src_idx / len(src_loader))) logger.info('- GW distance = {:.4f}.'.format(gw/len(src_loader))) trans_tmp /= np.max(trans_tmp) self.trans = trans_tmp self.d_gw.append(gw/len(src_loader)) def train_with_prior(self, database, optimizer, hyperpara_dict, scheduler=None): """ Regularized Gromov-Wasserstein Embedding Args: database: proposed database optimizer: the pytorch optimizer hyperpara_dict: a dictionary of hyperparameters dict = {epochs: the number of epochs, batch_size: batch size, use_cuda: use cuda or not, strategy: hard or soft, beta: the weight of proximal term outer_iter: the outer iteration of ipot inner_iter: the inner iteration of sinkhorn prior: True or False } scheduler: scheduler of learning rate. Returns: d_gw, trans """ device = torch.device('cuda:0' if hyperpara_dict['use_cuda'] else 'cpu') if hyperpara_dict['use_cuda']: torch.cuda.manual_seed(1) kwargs = {'num_workers': 1, 'pin_memory': True} if hyperpara_dict['use_cuda'] else {} self.gwl_model.to(device) self.gwl_model.train() num_interaction = len(database['mutual_interactions']) train_base = copy.deepcopy(database) test_base = copy.deepcopy(database) train_base['mutual_interactions'] = train_base['mutual_interactions'][:int(0.75*num_interaction)] test_base['mutual_interactions'] = test_base['mutual_interactions'][int(0.75*num_interaction):] num_interaction_train = len(train_base['mutual_interactions']) dataloader = DataLoader(IndexSampler(num_interaction_train), batch_size=hyperpara_dict['batch_size'], shuffle=True, **kwargs) for epoch in range(hyperpara_dict['epochs']): gw = 0 trans_tmp = np.zeros(self.trans.shape) if scheduler is not None: scheduler.step() for batch_idx, indices in enumerate(dataloader): # Estimate Gromov-Wasserstein discrepancy give current costs cost_s, cost_t, mu_s, mu_t, index_s, index_t, prior, mask_s, mask_t, mask_st = \ cost_sampler1(train_base, indices, device) self.plot_result(index_s, index_t, epoch, prefix=hyperpara_dict['prefix']) if hyperpara_dict['strategy'] == 'hard': z = np.random.rand() if z < epoch / hyperpara_dict['epochs']: cost1 = mask_s.data * self.gwl_model.self_cost_mat(index_s, 0).data cost2 = mask_t.data * self.gwl_model.self_cost_mat(index_t, 1).data cost12 = mask_st.data * self.gwl_model.mutual_cost_mat(index_s, index_t).data else: cost1 = cost_s cost2 = cost_t cost12 = prior else: cost_s_emb = mask_s.data * self.gwl_model.self_cost_mat(index_s, 0).data cost_t_emb = mask_t.data * self.gwl_model.self_cost_mat(index_t, 1).data cost_st_12 = mask_st.data * self.gwl_model.mutual_cost_mat(index_s, index_t).data alpha = max([(hyperpara_dict['epochs'] - epoch) / hyperpara_dict['epochs'], 0.7]) cost1 = alpha * cost_s + (1 - alpha) * cost_s_emb cost2 = alpha * cost_t + (1 - alpha) * cost_t_emb cost12 = alpha * prior + (1 - alpha) * cost_st_12 trans, d_gw, cost_12 = self.regularized_gromov_wasserstein_discrepancy(cost1, cost2, cost12, mu_s, mu_t, hyperpara_dict) # estimate optimal transport trans_np = trans.cpu().data.numpy() index_s_np = index_s.cpu().data.numpy() index_t_np = index_t.cpu().data.numpy() patch = self.trans[index_s_np, :] patch = patch[:, index_t_np] energy = np.sum(patch) + 1 for row in range(trans_np.shape[0]): for col in range(trans_np.shape[1]): trans_tmp[index_s_np[row], index_t_np[col]] += (energy * trans_np[row, col]) gw += d_gw # inner iteration based on SGD if epoch == 0: sgd_iter = hyperpara_dict['sgd_iteration'] else: sgd_iter = 20 for num in range(sgd_iter): # zero the parameter gradients optimizer.zero_grad() # Update source and target embeddings alternatively loss_gw, loss_w, regularizer = self.gwl_model(index_s, index_t, trans, mu_s, mu_t, cost_s, cost_t, prior, # +0.5*cost_12, mask_s, mask_t, mask_st) loss = loss_gw + loss_w + regularizer loss.backward() optimizer.step() if num % 10 == 0: print('inner {}/{}: loss={:.6f}.'.format(num, sgd_iter, loss.data)) prec, recall, f1 = self.evaluation_recommendation(test_base) self.Prec.append(prec) self.Recall.append(recall) self.F1.append(f1) logger.info('Train Epoch: {}'.format(epoch)) logger.info('- OT method={}, Distance={}'.format(self.ot_method, self.cost_type)) tops = [1, 3, 5] for top in range(3): logger.info('- Top-{}, precision={:.4f}%, recall={:.4f}%, f1={:.4f}%'.format( tops[top], prec[top], recall[top], f1[top])) if batch_idx % 100 == 1: logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]'.format( epoch, batch_idx * hyperpara_dict['batch_size'], len(dataloader.dataset), 100. * batch_idx / len(dataloader))) logger.info('- GW distance = {:.4f}.'.format(gw / len(dataloader))) trans_tmp /= np.max(trans_tmp) self.trans = trans_tmp self.d_gw.append(gw / len(dataloader)) def obtain_embedding(self, hyperpara_dict, index, idx): device = torch.device('cuda:0' if hyperpara_dict['use_cuda'] else 'cpu') self.gwl_model.to(device) self.gwl_model.eval() return self.gwl_model.emb_model[idx](index) def save_model(self, full_path, mode: str = 'entire'): """ Save trained model :param full_path: the path of directory :param mode: 'parameter' for saving only parameters of the model, 'entire' for saving entire model """ if mode == 'entire': torch.save(self.gwl_model, full_path) logger.info('The entire model is saved in {}.'.format(full_path)) elif mode == 'parameter': torch.save(self.gwl_model.state_dict(), full_path) logger.info('The parameters of the model is saved in {}.'.format(full_path)) else: logger.warning("'{}' is a undefined mode, we use 'entire' mode instead.".format(mode)) torch.save(self.gwl_model, full_path) logger.info('The entire model is saved in {}.'.format(full_path)) def load_model(self, full_path, mode: str = 'entire'): """ Load pre-trained model :param full_path: the path of directory :param mode: 'parameter' for saving only parameters of the model, 'entire' for saving entire model """ if mode == 'entire': self.gwl_model = torch.load(full_path) elif mode == 'parameter': self.gwl_model.load_state_dict(torch.load(full_path)) else: logger.warning("'{}' is a undefined mode, we use 'entire' mode instead.".format(mode)) self.gwl_model = torch.load(full_path) def save_matching(self, full_path): with open(full_path, 'wb') as f: # Python 3: open(..., 'wb') pickle.dump([self.NC1, self.EC1, self.NC2, self.EC2, self.d_gw], f) def save_recommend(self, full_path): with open(full_path, 'wb') as f: # Python 3: open(..., 'wb') pickle.dump([self.Prec, self.Recall, self.F1, self.d_gw, self.trans], f)