import math
import os
import sys
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from numpy import random
from torch.nn.parameter import Parameter

from utils import *


def get_batches(pairs, neighbors, batch_size):
    """Yield consecutive mini-batches of training pairs as tensors.

    Args:
        pairs: sequence of records whose first three items are
            (source node index, context node index, edge type index).
        neighbors: per-node neighbor table, indexed by node index; each entry
            is looked up for the source node of every pair.
        batch_size: maximum number of pairs per batch; the last batch may be
            shorter.

    Yields:
        A tuple ``(x, y, t, neigh)`` of tensors holding the source indices,
        context indices, edge type indices and the neighbor table rows of the
        source nodes for one batch.
    """
    for start in range(0, len(pairs), batch_size):
        chunk = pairs[start : start + batch_size]
        x = [pair[0] for pair in chunk]
        y = [pair[1] for pair in chunk]
        t = [pair[2] for pair in chunk]
        neigh = [neighbors[pair[0]] for pair in chunk]
        yield torch.tensor(x), torch.tensor(y), torch.tensor(t), torch.tensor(neigh)


class GATNEModel(nn.Module):
    """Node embedding model combining a base embedding with an
    attention-weighted aggregate of edge-type-specific neighbor embeddings.

    Args:
        num_nodes: number of nodes (rows of the embedding tables).
        embedding_size: size of the base and of the final node embedding.
        embedding_u_size: size of the edge-type-specific embeddings.
        edge_type_count: number of edge types.
        dim_a: hidden size of the attention scoring layer.
    """

    def __init__(
        self, num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
    ):
        super(GATNEModel, self).__init__()
        self.num_nodes = num_nodes
        self.embedding_size = embedding_size
        self.embedding_u_size = embedding_u_size
        self.edge_type_count = edge_type_count
        self.dim_a = dim_a

        # All tensors are allocated uninitialized and filled in
        # reset_parameters().
        self.node_embeddings = Parameter(
            torch.empty(num_nodes, embedding_size, dtype=torch.float32)
        )
        self.node_type_embeddings = Parameter(
            torch.empty(
                num_nodes, edge_type_count, embedding_u_size, dtype=torch.float32
            )
        )
        self.trans_weights = Parameter(
            torch.empty(
                edge_type_count, embedding_u_size, embedding_size, dtype=torch.float32
            )
        )
        self.trans_weights_s1 = Parameter(
            torch.empty(edge_type_count, embedding_u_size, dim_a, dtype=torch.float32)
        )
        self.trans_weights_s2 = Parameter(
            torch.empty(edge_type_count, dim_a, 1, dtype=torch.float32)
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize embeddings uniformly in [-1, 1] and the transformation
        weights from a zero-mean normal with std ``1 / sqrt(embedding_size)``.
        """
        std = 1.0 / math.sqrt(self.embedding_size)
        nn.init.uniform_(self.node_embeddings, -1.0, 1.0)
        nn.init.uniform_(self.node_type_embeddings, -1.0, 1.0)
        nn.init.normal_(self.trans_weights, std=std)
        nn.init.normal_(self.trans_weights_s1, std=std)
        nn.init.normal_(self.trans_weights_s2, std=std)

    def forward(self, train_inputs, train_types, node_neigh):
        """Compute L2-normalized embeddings for a batch of (node, edge type).

        Args:
            train_inputs: node indices, shape (batch,).
            train_types: edge type indices, shape (batch,).
            node_neigh: neighbor node indices, shape
                (batch, edge_type_count, neighbors_per_type).

        Returns:
            Tensor of shape (batch, embedding_size), normalized along dim 1.
        """
        node_embed = self.node_embeddings[train_inputs]

        # (batch, edge_type_count, neighbors, edge_type_count, embedding_u_size)
        node_embed_neighbors = self.node_type_embeddings[node_neigh]
        # For edge type i keep only the type-i embedding of the type-i
        # neighbors, then sum over those neighbors.
        node_embed_tmp = torch.stack(
            [
                node_embed_neighbors[:, i, :, i, :]
                for i in range(self.edge_type_count)
            ],
            dim=1,
        )
        node_type_embed = torch.sum(node_embed_tmp, dim=2)

        trans_w = self.trans_weights[train_types]
        trans_w_s1 = self.trans_weights_s1[train_types]
        trans_w_s2 = self.trans_weights_s2[train_types]

        # One attention score per edge type, softmax-normalized over the
        # edge types: (batch, 1, edge_type_count).
        scores = torch.matmul(
            torch.tanh(torch.matmul(node_type_embed, trans_w_s1)), trans_w_s2
        ).squeeze(2)
        attention = F.softmax(scores, dim=1).unsqueeze(1)

        # Attention-weighted mix of the per-type aggregates, projected into
        # the base embedding space and added to the base embedding.
        node_type_embed = torch.matmul(attention, node_type_embed)
        node_embed = node_embed + torch.matmul(node_type_embed, trans_w).squeeze(1)

        return F.normalize(node_embed, dim=1)


class NSLoss(nn.Module):
    """Negative-sampling loss over a table of per-node output weights.

    Args:
        num_nodes: number of nodes (rows of the output weight table).
        num_sampled: number of negative samples drawn per training example.
        embedding_size: size of the input embeddings and of each weight row.
    """

    def __init__(self, num_nodes, num_sampled, embedding_size):
        super(NSLoss, self).__init__()
        self.num_nodes = num_nodes
        self.num_sampled = num_sampled
        self.embedding_size = embedding_size
        self.weights = Parameter(
            torch.empty(num_nodes, embedding_size, dtype=torch.float32)
        )
        # Sampling weight of node k is proportional to log((k + 2) / (k + 1)),
        # which decreases with k, so low node indices are drawn more often.
        sample_weights = F.normalize(
            torch.Tensor(
                [
                    (math.log(k + 2) - math.log(k + 1)) / math.log(num_nodes + 1)
                    for k in range(num_nodes)
                ]
            ),
            dim=0,
        )
        # Registered as a buffer (not a parameter) so that .to(device) moves
        # it together with the module and negatives are sampled on-device.
        self.register_buffer("sample_weights", sample_weights)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the output weights from a zero-mean normal with std
        ``1 / sqrt(embedding_size)``.
        """
        nn.init.normal_(self.weights, std=1.0 / math.sqrt(self.embedding_size))

    def forward(self, input, embs, label):
        """Return the mean negative-sampling loss of a batch.

        Args:
            input: batch of source node indices; only its first dimension
                (the batch size) is used.
            embs: input embeddings, shape (batch, embedding_size).
            label: indices of the positive (context) nodes, shape (batch,).

        Returns:
            Scalar tensor: the negated sum of the per-example log-likelihoods
            divided by the batch size.
        """
        n = input.shape[0]
        # logsigmoid is the numerically stable form of log(sigmoid(x)); the
        # naive composition underflows to -inf for large negative x.
        log_target = F.logsigmoid(torch.sum(embs * self.weights[label], dim=1))
        negs = torch.multinomial(
            self.sample_weights, self.num_sampled * n, replacement=True
        ).view(n, self.num_sampled)
        noise = torch.neg(self.weights[negs])
        # bmm gives (batch, num_sampled, 1); summing over the samples leaves
        # (batch, 1), and only that trailing axis is squeezed away.
        sum_log_sampled = torch.sum(
            F.logsigmoid(torch.bmm(noise, embs.unsqueeze(2))), dim=1
        ).squeeze(1)

        loss = log_target + sum_log_sampled
        return -loss.sum() / n


def _build_neighbors(network_data, edge_types, vocab, num_nodes, neighbor_samples):
    """Build a fixed-size neighbor table: for every node and edge type exactly
    ``neighbor_samples`` neighbor indices.

    Nodes without neighbors of a type get themselves repeated; shorter lists
    are padded and longer lists are resampled with ``np.random.choice``.
    """
    neighbors = [[[] for __ in edge_types] for _ in range(num_nodes)]
    for r, edge_type in enumerate(edge_types):
        # Edges are treated as undirected: each endpoint lists the other.
        for (x, y) in network_data[edge_type]:
            ix = vocab[x].index
            iy = vocab[y].index
            neighbors[ix][r].append(iy)
            neighbors[iy][r].append(ix)
        for i in range(num_nodes):
            current = neighbors[i][r]
            if not current:
                neighbors[i][r] = [i] * neighbor_samples
            elif len(current) < neighbor_samples:
                current.extend(
                    list(
                        np.random.choice(
                            current, size=neighbor_samples - len(current)
                        )
                    )
                )
            elif len(current) > neighbor_samples:
                neighbors[i][r] = list(
                    np.random.choice(current, size=neighbor_samples)
                )
    return neighbors


def _collect_embeddings(model, neighbors, edge_types, index2word, device):
    """Run the model once per node and return
    ``{edge_type: {node_name: numpy embedding}}``.
    """
    edge_type_count = len(edge_types)
    final_model = {edge_type: {} for edge_type in edge_types}
    # The edge type indices are the same for every node.
    train_types = torch.tensor(list(range(edge_type_count))).to(device)
    # Pure inference: no autograd graph is needed for the exported vectors.
    with torch.no_grad():
        for i, node_neighbors in enumerate(neighbors):
            train_inputs = torch.tensor([i] * edge_type_count).to(device)
            node_neigh = torch.tensor([node_neighbors] * edge_type_count).to(device)
            node_emb = model(train_inputs, train_types, node_neigh)
            for j, edge_type in enumerate(edge_types):
                final_model[edge_type][index2word[i]] = (
                    node_emb[j].cpu().detach().numpy()
                )
    return final_model


def train_model(network_data):
    """Train the embedding model and return test metrics.

    Reads the module-level ``args``, ``file_name`` and the validation/testing
    edge dictionaries that the ``__main__`` section defines.

    Args:
        network_data: mapping from edge type to an iterable of (node, node)
            edges used for training.

    Returns:
        Tuple ``(auc, f1, pr)`` with the mean test metrics recorded at the
        epoch with the best mean validation AUC, or ``(0.0, 0.0, 0.0)`` if no
        epoch ever improved on the initial score of 0.
    """
    all_walks = generate_walks(
        network_data, args.num_walks, args.walk_length, args.schema, file_name
    )
    vocab, index2word = generate_vocab(all_walks)
    train_pairs = generate_pairs(all_walks, vocab, args.window_size)

    edge_types = list(network_data.keys())

    num_nodes = len(index2word)
    edge_type_count = len(edge_types)
    epochs = args.epoch
    batch_size = args.batch_size
    embedding_size = args.dimensions
    embedding_u_size = args.edge_dim
    num_sampled = args.negative_samples
    dim_a = args.att_dim
    neighbor_samples = args.neighbor_samples

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    neighbors = _build_neighbors(
        network_data, edge_types, vocab, num_nodes, neighbor_samples
    )

    model = GATNEModel(
        num_nodes, embedding_size, embedding_u_size, edge_type_count, dim_a
    )
    nsloss = NSLoss(num_nodes, num_sampled, embedding_size)

    model.to(device)
    nsloss.to(device)

    optimizer = torch.optim.Adam(
        [{"params": model.parameters()}, {"params": nsloss.parameters()}], lr=1e-4
    )

    eval_types = args.eval_type.split(",")
    n_batches = (len(train_pairs) + (batch_size - 1)) // batch_size

    best_score = 0
    # Test metrics at the best validation epoch; this is what gets returned,
    # so early stopping does not report a later, worse epoch.
    test_score = (0.0, 0.0, 0.0)
    patience = 0
    for epoch in range(epochs):
        random.shuffle(train_pairs)
        batches = get_batches(train_pairs, neighbors, batch_size)

        data_iter = tqdm.tqdm(
            batches,
            desc="epoch %d" % (epoch),
            total=n_batches,
            bar_format="{l_bar}{r_bar}",
        )
        avg_loss = 0.0

        for i, data in enumerate(data_iter):
            optimizer.zero_grad()
            inputs = data[0].to(device)
            embs = model(inputs, data[2].to(device), data[3].to(device))
            loss = nsloss(inputs, embs, data[1].to(device))
            loss.backward()
            optimizer.step()

            avg_loss += loss.item()

            if i % 5000 == 0:
                post_fix = {
                    "epoch": epoch,
                    "iter": i,
                    "avg_loss": avg_loss / (i + 1),
                    "loss": loss.item(),
                }
                data_iter.write(str(post_fix))

        final_model = _collect_embeddings(
            model, neighbors, edge_types, index2word, device
        )

        valid_aucs, valid_f1s, valid_prs = [], [], []
        test_aucs, test_f1s, test_prs = [], [], []
        for edge_type in edge_types:
            if args.eval_type == "all" or edge_type in eval_types:
                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_type],
                    valid_true_data_by_edge[edge_type],
                    valid_false_data_by_edge[edge_type],
                )
                valid_aucs.append(tmp_auc)
                valid_f1s.append(tmp_f1)
                valid_prs.append(tmp_pr)

                tmp_auc, tmp_f1, tmp_pr = evaluate(
                    final_model[edge_type],
                    testing_true_data_by_edge[edge_type],
                    testing_false_data_by_edge[edge_type],
                )
                test_aucs.append(tmp_auc)
                test_f1s.append(tmp_f1)
                test_prs.append(tmp_pr)
        print("valid auc:", np.mean(valid_aucs))
        print("valid pr:", np.mean(valid_prs))
        print("valid f1:", np.mean(valid_f1s))

        average_auc = np.mean(test_aucs)
        average_f1 = np.mean(test_f1s)
        average_pr = np.mean(test_prs)

        cur_score = np.mean(valid_aucs)
        if cur_score > best_score:
            best_score = cur_score
            test_score = (average_auc, average_f1, average_pr)
            patience = 0
        else:
            patience += 1
            if patience > args.patience:
                print("Early Stopping")
                break
    return test_score


if __name__ == "__main__":
    # These names are intentionally module-level: train_model() reads args,
    # file_name and the validation/testing dictionaries as globals.
    args = parse_args()
    file_name = args.input
    print(args)

    training_data_by_type = load_training_data(file_name + "/train.txt")
    valid_true_data_by_edge, valid_false_data_by_edge = load_testing_data(
        file_name + "/valid.txt"
    )
    testing_true_data_by_edge, testing_false_data_by_edge = load_testing_data(
        file_name + "/test.txt"
    )

    average_auc, average_f1, average_pr = train_model(training_data_by_type)

    print("Overall ROC-AUC:", average_auc)
    print("Overall PR-AUC", average_pr)
    print("Overall F1:", average_f1)