# coding:utf-8 import torch import torch.nn as nn from torch.autograd import Variable import torch.optim as optim import os import time import sys import datetime import ctypes import json import numpy as np from tqdm import tqdm use_gpu=False device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): use_gpu=True class MyDataParallel(nn.DataParallel): def _getattr__(self, name): return getattr(self.module, name) def to_var(x): return Variable(torch.from_numpy(x).to(device)) class Config(object): def __init__(self): base_file = os.path.abspath( os.path.join(os.path.dirname(__file__), "release/Base.so") ) self.lib = ctypes.cdll.LoadLibrary(base_file) """argtypes""" """'sample""" self.lib.sampling.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_int64, ctypes.c_int64, ctypes.c_int64, ] """'valid""" self.lib.getValidHeadBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.getValidTailBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.validHead.argtypes = [ctypes.c_void_p] self.lib.validTail.argtypes = [ctypes.c_void_p] """test link prediction""" self.lib.getHeadBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.getTailBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.testHead.argtypes = [ctypes.c_void_p] self.lib.testTail.argtypes = [ctypes.c_void_p] """test triple classification""" self.lib.getValidBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.getTestBatch.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.getBestThreshold.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] self.lib.test_triple_classification.argtypes = [ ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, ] """restype""" self.lib.getValidHit10.restype = ctypes.c_float # for triple classification self.lib.test_triple_classification.restype = ctypes.c_float """set essential parameters""" self.in_path = "./" self.batch_size = 100 self.bern = 0 self.work_threads = 8 self.hidden_size = 100 self.negative_ent = 1 self.negative_rel = 0 self.ent_size = self.hidden_size self.rel_size = self.hidden_size self.margin = 1.0 self.valid_steps = 5 self.save_steps = 5 self.opt_method = "SGD" self.optimizer = None self.lr_decay = 0 self.weight_decay = 0 self.lmbda = 0.0 self.lmbda_two = 0.0 self.alpah = 0.001 self.early_stopping_patience = 10 self.nbatches = 100 self.p_norm = 1 self.test_link = True self.test_triple = False self.model = None self.trainModel = None self.testModel = None self.pretrain_model = None self.ent_dropout = 0 self.rel_dropout = 0 self.use_init_embeddings = False self.test_file_path = None def init(self): self.lib.setInPath( ctypes.create_string_buffer(self.in_path.encode(), len(self.in_path) * 2) ) self.lib.setTestFilePath( ctypes.create_string_buffer(self.test_file_path.encode(), len(self.test_file_path) * 2) ) self.lib.setBern(self.bern) self.lib.setWorkThreads(self.work_threads) self.lib.randReset() self.lib.importTrainFiles() self.lib.importTestFiles() self.lib.importTypeFiles() self.relTotal = self.lib.getRelationTotal() self.entTotal = self.lib.getEntityTotal() self.trainTotal = self.lib.getTrainTotal() self.testTotal = self.lib.getTestTotal() self.validTotal = self.lib.getValidTotal() self.batch_size = int(self.trainTotal / self.nbatches) self.batch_seq_size = self.batch_size * ( 1 + self.negative_ent + self.negative_rel ) self.batch_h = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_t = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_r = np.zeros(self.batch_seq_size, dtype=np.int64) self.batch_y = np.zeros(self.batch_seq_size, dtype=np.float32) self.batch_h_addr = self.batch_h.__array_interface__["data"][0] self.batch_t_addr = self.batch_t.__array_interface__["data"][0] self.batch_r_addr = self.batch_r.__array_interface__["data"][0] self.batch_y_addr = self.batch_y.__array_interface__["data"][0] self.valid_h = np.zeros(self.entTotal, dtype=np.int64) self.valid_t = np.zeros(self.entTotal, dtype=np.int64) self.valid_r = np.zeros(self.entTotal, dtype=np.int64) self.valid_h_addr = self.valid_h.__array_interface__["data"][0] self.valid_t_addr = self.valid_t.__array_interface__["data"][0] self.valid_r_addr = self.valid_r.__array_interface__["data"][0] self.test_h = np.zeros(self.entTotal, dtype=np.int64) self.test_t = np.zeros(self.entTotal, dtype=np.int64) self.test_r = np.zeros(self.entTotal, dtype=np.int64) self.test_h_addr = self.test_h.__array_interface__["data"][0] self.test_t_addr = self.test_t.__array_interface__["data"][0] self.test_r_addr = self.test_r.__array_interface__["data"][0] self.valid_pos_h = np.zeros(self.validTotal, dtype=np.int64) self.valid_pos_t = np.zeros(self.validTotal, dtype=np.int64) self.valid_pos_r = np.zeros(self.validTotal, dtype=np.int64) self.valid_pos_h_addr = self.valid_pos_h.__array_interface__["data"][0] self.valid_pos_t_addr = self.valid_pos_t.__array_interface__["data"][0] self.valid_pos_r_addr = self.valid_pos_r.__array_interface__["data"][0] self.valid_neg_h = np.zeros(self.validTotal, dtype=np.int64) self.valid_neg_t = np.zeros(self.validTotal, dtype=np.int64) self.valid_neg_r = np.zeros(self.validTotal, dtype=np.int64) self.valid_neg_h_addr = self.valid_neg_h.__array_interface__["data"][0] self.valid_neg_t_addr = self.valid_neg_t.__array_interface__["data"][0] self.valid_neg_r_addr = self.valid_neg_r.__array_interface__["data"][0] self.test_pos_h = np.zeros(self.testTotal, dtype=np.int64) self.test_pos_t = np.zeros(self.testTotal, dtype=np.int64) self.test_pos_r = np.zeros(self.testTotal, dtype=np.int64) self.test_pos_h_addr = self.test_pos_h.__array_interface__["data"][0] self.test_pos_t_addr = self.test_pos_t.__array_interface__["data"][0] self.test_pos_r_addr = self.test_pos_r.__array_interface__["data"][0] self.test_neg_h = np.zeros(self.testTotal, dtype=np.int64) self.test_neg_t = np.zeros(self.testTotal, dtype=np.int64) self.test_neg_r = np.zeros(self.testTotal, dtype=np.int64) self.test_neg_h_addr = self.test_neg_h.__array_interface__["data"][0] self.test_neg_t_addr = self.test_neg_t.__array_interface__["data"][0] self.test_neg_r_addr = self.test_neg_r.__array_interface__["data"][0] self.relThresh = np.zeros(self.relTotal, dtype=np.float32) self.relThresh_addr = self.relThresh.__array_interface__["data"][0] def set_test_link(self, test_link): self.test_link = test_link def set_test_triple(self, test_triple): self.test_triple = test_triple def set_margin(self, margin): self.margin = margin def set_in_path(self, in_path): self.in_path = in_path def set_test_file_path(self, test_file_path): self.test_file_path = test_file_path def set_nbatches(self, nbatches): self.nbatches = nbatches def set_p_norm(self, p_norm): self.p_norm = p_norm def set_valid_steps(self, valid_steps): self.valid_steps = valid_steps def set_save_steps(self, save_steps): self.save_steps = save_steps def set_checkpoint_dir(self, checkpoint_dir): self.checkpoint_dir = checkpoint_dir def set_result_dir(self, result_dir): self.result_dir = result_dir def set_alpha(self, alpha): self.alpha = alpha def set_lmbda(self, lmbda): self.lmbda = lmbda def set_lmbda_two(self, lmbda_two): self.lmbda_two = lmbda_two def set_lr_decay(self, lr_decay): self.lr_decay = lr_decay def set_weight_decay(self, weight_decay): self.weight_decay = weight_decay def set_opt_method(self, opt_method): self.opt_method = opt_method def set_bern(self, bern): self.bern = bern def set_init_embeddings(self, entity_embs, rel_embs): self.use_init_embeddings = True self.init_ent_embs = torch.from_numpy(entity_embs).to(device) self.init_rel_embs = torch.from_numpy(rel_embs).to(device) def set_config_CNN(self, num_of_filters, drop_prob, kernel_size=1): self.out_channels = num_of_filters self.convkb_drop_prob = drop_prob self.kernel_size = kernel_size def set_dimension(self, dim): self.hidden_size = dim self.ent_size = dim self.rel_size = dim def set_ent_dimension(self, dim): self.ent_size = dim def set_rel_dimension(self, dim): self.rel_size = dim def set_train_times(self, train_times): self.train_times = train_times def set_work_threads(self, work_threads): self.work_threads = work_threads def set_ent_neg_rate(self, rate): self.negative_ent = rate def set_rel_neg_rate(self, rate): self.negative_rel = rate def set_ent_dropout(self, ent_dropout): self.ent_dropout = ent_dropout def set_rel_dropout(self, rel_dropout): self.rel_dropout = rel_dropout def set_early_stopping_patience(self, early_stopping_patience): self.early_stopping_patience = early_stopping_patience def set_pretrain_model(self, pretrain_model): self.pretrain_model = pretrain_model def get_parameters(self, param_dict, mode="numpy"): for param in param_dict: param_dict[param] = param_dict[param].cpu() res = {} for param in param_dict: if mode == "numpy": res[param] = param_dict[param].numpy() elif mode == "list": res[param] = param_dict[param].numpy().tolist() else: res[param] = param_dict[param] return res def save_embedding_matrix(self, best_model): path = os.path.join(self.result_dir, self.model.__name__ + ".json") f = open(path, "w") f.write(json.dumps(self.get_parameters(best_model, "list"))) f.close() def set_train_model(self, model): print("Initializing training model...") self.model = model self.trainModel = self.model(config=self) #self.trainModel = nn.DataParallel(self.trainModel, device_ids=[2,3,4]) self.trainModel.to(device) if self.optimizer != None: pass elif self.opt_method == "Adagrad" or self.opt_method == "adagrad": self.optimizer = optim.Adagrad( self.trainModel.parameters(), lr=self.alpha, lr_decay=self.lr_decay, weight_decay=self.weight_decay, ) elif self.opt_method == "Adadelta" or self.opt_method == "adadelta": self.optimizer = optim.Adadelta( self.trainModel.parameters(), lr=self.alpha, weight_decay=self.weight_decay, ) elif self.opt_method == "Adam" or self.opt_method == "adam": self.optimizer = optim.Adam( self.trainModel.parameters(), lr=self.alpha, weight_decay=self.weight_decay, ) else: self.optimizer = optim.SGD( self.trainModel.parameters(), lr=self.alpha, weight_decay=self.weight_decay, ) print("Finish initializing") def set_test_model(self, model, path=None): print("Initializing test model...") self.model = model self.testModel = self.model(config=self) if path == None: path = os.path.join(self.result_dir, self.model.__name__ + ".ckpt") self.testModel.load_state_dict(torch.load(path)) self.testModel.to(device) self.testModel.eval() print("Finish initializing") def sampling(self): self.lib.sampling( self.batch_h_addr, self.batch_t_addr, self.batch_r_addr, self.batch_y_addr, self.batch_size, self.negative_ent, self.negative_rel, ) def save_checkpoint(self, model, epoch): path = os.path.join( self.checkpoint_dir, self.model.__name__ + "-" + str(epoch) + ".ckpt" ) torch.save(model, path) def save_best_checkpoint(self, best_model): path = os.path.join(self.result_dir, self.model.__name__ + ".ckpt") torch.save(best_model, path) def train_one_step(self): self.trainModel.train() self.trainModel.batch_h = to_var(self.batch_h) self.trainModel.batch_t = to_var(self.batch_t) self.trainModel.batch_r = to_var(self.batch_r) self.trainModel.batch_y = to_var(self.batch_y) self.optimizer.zero_grad() loss = self.trainModel() loss.backward() torch.nn.utils.clip_grad_norm_(self.trainModel.parameters(), 0.5) self.optimizer.step() return loss.item() def test_one_step(self, model, test_h, test_t, test_r): model.eval() with torch.no_grad(): model.batch_h = to_var(test_h) model.batch_t = to_var(test_t) model.batch_r = to_var(test_r) return model.predict() def valid(self, model): self.lib.validInit() for i in range(self.validTotal): sys.stdout.write("%d\r" % (i)) sys.stdout.flush() self.lib.getValidHeadBatch( self.valid_h_addr, self.valid_t_addr, self.valid_r_addr ) res = self.test_one_step(model, self.valid_h, self.valid_t, self.valid_r) self.lib.validHead(res.__array_interface__["data"][0]) self.lib.getValidTailBatch( self.valid_h_addr, self.valid_t_addr, self.valid_r_addr ) res = self.test_one_step(model, self.valid_h, self.valid_t, self.valid_r) self.lib.validTail(res.__array_interface__["data"][0]) return self.lib.getValidHit10() def training_model(self): if not os.path.exists(self.checkpoint_dir): os.mkdir(self.checkpoint_dir) best_epoch = 0 best_hit10 = 0.0 best_model = None bad_counts = 0 training_range = tqdm(range(self.train_times)) for epoch in training_range: res = 0.0 for batch in range(self.nbatches): self.sampling() loss = self.train_one_step() res += loss training_range.set_description("Epoch %d | loss: %f" % (epoch, res)) # print("Epoch %d | loss: %f" % (epoch, res)) if (epoch + 1) % self.save_steps == 0: training_range.set_description("Epoch %d has finished, saving..." % (epoch)) self.save_checkpoint(self.trainModel.state_dict(), epoch) if (epoch + 1) % self.valid_steps == 0: training_range.set_description("Epoch %d has finished | loss: %f, validating..." % (epoch, res)) hit10 = self.valid(self.trainModel) if hit10 > best_hit10: best_hit10 = hit10 best_epoch = epoch best_model = self.trainModel.state_dict() print("Best model | hit@10 of valid set is %f" % (best_hit10)) bad_counts = 0 else: print("Hit@10 of valid set is %f | bad count is %d" % (hit10, bad_counts)) bad_counts += 1 if bad_counts == self.early_stopping_patience: print("Early stopping at epoch %d" % (epoch)) break if best_model == None: best_model = self.trainModel.state_dict() best_epoch = self.train_times - 1 best_hit10 = self.valid(self.trainModel) print("Best epoch is %d | hit@10 of valid set is %f" % (best_epoch, best_hit10)) print("Store checkpoint of best result at epoch %d..." % (best_epoch)) if not os.path.isdir(self.result_dir): os.mkdir(self.result_dir) self.save_best_checkpoint(best_model) self.save_embedding_matrix(best_model) print("Finish storing") print("Testing...") self.set_test_model(self.model) self.test() print("Finish test") return best_model def valid_triple_classification(self, model): self.lib.getValidBatch( self.valid_pos_h_addr, self.valid_pos_t_addr, self.valid_pos_r_addr, self.valid_neg_h_addr, self.valid_neg_t_addr, self.valid_neg_r_addr, ) res_pos = self.test_one_step( model, self.valid_pos_h, self.valid_pos_t, self.valid_pos_r ) res_neg = self.test_one_step( model, self.valid_neg_h, self.valid_neg_t, self.valid_neg_r ) self.lib.getBestThreshold( self.relThresh_addr, res_pos.__array_interface__["data"][0], res_neg.__array_interface__["data"][0], ) return self.lib.test_triple_classification( self.relThresh_addr, res_pos.__array_interface__["data"][0], res_neg.__array_interface__["data"][0], ) def training_triple_classification(self): if not os.path.exists(self.checkpoint_dir): os.mkdir(self.checkpoint_dir) best_epoch = 0 best_acc = 0.0 best_model = None bad_counts = 0 training_range = tqdm(range(self.train_times)) for epoch in training_range: res = 0.0 for batch in range(self.nbatches): self.sampling() loss = self.train_one_step() res += loss training_range.set_description("Epoch %d | loss: %f" % (epoch, res)) if (epoch + 1) % self.save_steps == 0: training_range.set_description("Epoch %d has finished, saving..." % (epoch)) self.save_checkpoint(self.trainModel.state_dict(), epoch) if (epoch + 1) % self.valid_steps == 0: training_range.set_description("Epoch %d has finished | loss: %f, validating..." % (epoch, res)) acc = self.valid_triple_classification(self.trainModel) if acc > best_acc: best_acc = acc best_epoch = epoch best_model = self.trainModel.state_dict() print("Best model | Acc of valid set is %f" % (best_acc)) bad_counts = 0 else: print("Acc of valid set is %f | bad count is %d" % (acc, bad_counts)) bad_counts += 1 if bad_counts == self.early_stopping_patience: print("Early stopping at epoch %d" % (epoch)) break if best_model == None: best_model = self.trainModel.state_dict() best_epoch = self.train_times - 1 best_acc = self.valid_triple_classification(self.trainModel) print("Best epoch is %d | Acc of valid set is %f" % (best_epoch, best_acc)) print("Store checkpoint of best result at epoch %d..." % (best_epoch)) if not os.path.isdir(self.result_dir): os.mkdir(self.result_dir) self.save_best_checkpoint(best_model) self.save_embedding_matrix(best_model) print("Finish storing") print("Testing...") self.set_test_model(self.model) self.test() print("Finish test") return best_model def link_prediction(self): print("The total of test triple is %d" % (self.testTotal)) for i in range(self.testTotal): sys.stdout.write("%d\r" % (i)) sys.stdout.flush() self.lib.getHeadBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr) res = self.test_one_step( self.testModel, self.test_h, self.test_t, self.test_r ) self.lib.testHead(res.__array_interface__["data"][0]) self.lib.getTailBatch(self.test_h_addr, self.test_t_addr, self.test_r_addr) res = self.test_one_step( self.testModel, self.test_h, self.test_t, self.test_r ) self.lib.testTail(res.__array_interface__["data"][0]) self.lib.test_link_prediction() def triple_classification(self): self.lib.getValidBatch( self.valid_pos_h_addr, self.valid_pos_t_addr, self.valid_pos_r_addr, self.valid_neg_h_addr, self.valid_neg_t_addr, self.valid_neg_r_addr, ) res_pos = self.test_one_step( self.testModel, self.valid_pos_h, self.valid_pos_t, self.valid_pos_r ) res_neg = self.test_one_step( self.testModel, self.valid_neg_h, self.valid_neg_t, self.valid_neg_r ) self.lib.getBestThreshold( self.relThresh_addr, res_pos.__array_interface__["data"][0], res_neg.__array_interface__["data"][0], ) self.lib.getTestBatch( self.test_pos_h_addr, self.test_pos_t_addr, self.test_pos_r_addr, self.test_neg_h_addr, self.test_neg_t_addr, self.test_neg_r_addr, ) res_pos = self.test_one_step( self.testModel, self.test_pos_h, self.test_pos_t, self.test_pos_r ) res_neg = self.test_one_step( self.testModel, self.test_neg_h, self.test_neg_t, self.test_neg_r ) self.lib.test_triple_classification( self.relThresh_addr, res_pos.__array_interface__["data"][0], res_neg.__array_interface__["data"][0], ) def test(self): if self.test_link: self.link_prediction() if self.test_triple: self.triple_classification()