#!/usr/bin/env python3
"""Script entry point that trains or evaluates ``models.Model``.

When run as ``__main__`` the parsed configuration's ``mode`` selects
:func:`_train` (``'train'``) or :func:`_test` (``'test'``).
"""
import datetime
import gc
import logging
import pickle
import os
import sys
import time
import json
from typing import Optional

import torch

import data_utils
import models
from data_utils import to_torch
from eval_metric import mrr
from model_utils import get_gold_pred_str, get_eval_string, get_output_index
from tensorboardX import SummaryWriter
from torch import optim
from tqdm import tqdm
import numpy as np

sys.path.insert(0, './resources')
import config_parser
import constant
import eval_metric


class TensorboardWriter:
    """
    Wraps a pair of ``SummaryWriter`` instances but is a no-op if they're ``None``.
    Allows Tensorboard logging without always checking for Nones first.
    """

    def __init__(self, train_log: Optional[SummaryWriter] = None,
                 validation_log: Optional[SummaryWriter] = None) -> None:
        self._train_log = train_log
        self._validation_log = validation_log

    def add_train_scalar(self, name: str, value: float, global_step: int) -> None:
        """Record ``value`` under ``name`` in the train log, if one was given."""
        if self._train_log is not None:
            self._train_log.add_scalar(name, value, global_step)

    def add_validation_scalar(self, name: str, value: float, global_step: int) -> None:
        """Record ``value`` under ``name`` in the validation log, if one was given."""
        if self._validation_log is not None:
            self._validation_log.add_scalar(name, value, global_step)


def _scalar(tensor):
    """Return the single value held by a loss tensor/Variable as a Python float.

    ``view(-1)`` makes this work both for the 1-element tensors produced by old
    PyTorch releases and for 0-dim tensors (PyTorch >= 0.4), on which a plain
    ``[0]`` index raises ``IndexError``.
    """
    return float(tensor.data.cpu().view(-1)[0])


def _save_checkpoint(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` with ``torch.save``."""
    torch.save({'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, path)


def get_data_gen(dataname, mode, args, vocab_set, goal):
    """Build a batch generator for ``constant.FILE_ROOT + dataname``.

    ``mode == 'train'`` uses ``args.batch_size`` / ``args.num_epoch`` with
    ``eval_data=False``; every other mode uses ``args.eval_batch_size`` for a
    single epoch with ``eval_data=True``.  Only ``'dev'`` asks for
    ``forever=True``.
    """
    dataset = data_utils.TypeDataset(constant.FILE_ROOT + dataname, lstm_type=args.lstm_type,
                                     goal=goal, vocab=vocab_set)
    simple_mention = not args.enhanced_mention
    if mode == 'train':
        return dataset.get_batch(args.batch_size, args.num_epoch, forever=False,
                                 eval_data=False, simple_mention=simple_mention)
    return dataset.get_batch(args.eval_batch_size, 1, forever=(mode == 'dev'),
                             eval_data=True, simple_mention=simple_mention)


def get_joint_datasets(args):
    """Build the generators used by the ``'joint'`` data setup.

    Returns:
        ``(train_gen_list, valid_gen_list, crowd_dev_gen)`` where the two lists
        hold ``(type_name, generator)`` pairs (filled only when
        ``args.mode == 'train'``) and ``crowd_dev_gen`` iterates
        ``crowd/dev.json``.
    """
    vocab = data_utils.get_vocab()
    train_gen_list = []
    valid_gen_list = []
    if args.mode == 'train':
        if not args.remove_open and not args.only_crowd:
            train_gen_list.append(
                # ("open", get_data_gen('train/open*.json', 'train', args, vocab, "open")))
                ("open", get_data_gen('distant_supervision/headword_train.json', 'train', args, vocab, "open")))
            valid_gen_list.append(
                ("open", get_data_gen('distant_supervision/headword_dev.json', 'dev', args, vocab, "open")))
        if not args.remove_el and not args.only_crowd:
            wiki_goal = "wiki" if args.multitask else "open"
            valid_gen_list.append(
                ("wiki", get_data_gen('distant_supervision/el_dev.json', 'dev', args, vocab, wiki_goal)))
            train_gen_list.append(
                ("wiki", get_data_gen('distant_supervision/el_train.json', 'train', args, vocab, wiki_goal)))
            # get_data_gen('train/el_train.json', 'train', args, vocab, "wiki" if args.multitask else "open")))
        if args.add_crowd or args.only_crowd:
            train_gen_list.append(
                ("open", get_data_gen('crowd/train_m.json', 'train', args, vocab, "open")))
    crowd_dev_gen = get_data_gen('crowd/dev.json', 'dev', args, vocab, "open")
    return train_gen_list, valid_gen_list, crowd_dev_gen


def get_datasets(data_lists, args):
    """Return one generator per ``(dataname, mode, goal)`` triple, sharing one vocab."""
    vocab_set = data_utils.get_vocab()
    return [get_data_gen(dataname, mode, args, vocab_set, goal)
            for dataname, mode, goal in data_lists]


def _train(args):
    """Train the model until any training generator is exhausted.

    Every ``args.eval_period`` steps one batch from the evaluation generator is
    scored with :func:`evaluate_batch`; every ``args.save_period`` steps the
    model is written to ``<EXP_ROOT>/<model_id>_best.pt`` if the most recent
    evaluation score beats the best seen so far.  When a training generator
    raises ``StopIteration`` the final state is written to
    ``<EXP_ROOT>/<model_id>.pt`` and the function returns.
    """
    if args.data_setup == 'joint':
        train_gen_list, val_gen_list, crowd_dev_gen = get_joint_datasets(args)
        eval_gen, eval_type_name, eval_goal = crowd_dev_gen, "open", "open"
    else:
        train_fname = args.train_data
        dev_fname = args.dev_data
        data_gens = get_datasets([(train_fname, 'train', args.goal),
                                  (dev_fname, 'dev', args.goal)], args)
        train_gen_list = [(args.goal, data_gens[0])]
        val_gen_list = [(args.goal, data_gens[1])]
        # There is no crowd dev set outside the joint setup (the periodic
        # evaluation used to fail with a NameError here), so score the
        # configured dev data instead.
        eval_gen, eval_type_name, eval_goal = data_gens[1], args.goal, args.goal
    # NOTE: ``val_gen_list`` is only consumed by the disabled per-task
    # evaluation further down.

    train_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "train"))
    validation_log = SummaryWriter(os.path.join(constant.EXP_ROOT, args.model_id, "log", "validation"))
    tensorboard = TensorboardWriter(train_log, validation_log)

    model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
    model.cuda()
    total_loss = 0
    batch_num = 0
    start_time = time.time()  # only read by the disabled train-loss logging below
    init_time = time.time()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    if args.load:
        load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model, optimizer)

    for idx, m in enumerate(model.modules()):
        logging.info(str(idx) + '->' + str(m))

    final_fname = '{0:s}/{1:s}.pt'.format(constant.EXP_ROOT, args.model_id)
    best_eval_ma_f1 = 0
    # Stays None until the first evaluation so that a save step which arrives
    # before any evaluation is skipped instead of raising.
    crowd_eval_ma_f1 = None

    try:
        while True:
            batch_num += 1  # single batch composed of all train signal passed by.
            for (type_name, data_gen) in train_gen_list:
                try:
                    batch = next(data_gen)
                    batch, _ = to_torch(batch)
                except StopIteration:
                    # Training finished: the loop has no other exit.
                    logging.info(type_name + " finished at " + str(batch_num))
                    _save_checkpoint(model, optimizer, final_fname)
                    return
                optimizer.zero_grad()
                loss, output_logits = model(batch, type_name)
                loss.backward()
                total_loss += _scalar(loss)
                optimizer.step()

                # Disabled: periodic train-loss / train-accuracy logging and per-task dev evaluation.
                # if batch_num % args.log_period == 0 and batch_num > 0:
                #     gc.collect()
                #     cur_loss = float(1.0 * loss.data.cpu().clone()[0])
                #     elapsed = time.time() - start_time
                #     train_loss_str = ('|loss {0:3f} | at {1:d}step | @ {2:.2f} ms/batch'.format(cur_loss, batch_num,
                #                                                                          elapsed * 1000 / args.log_period))
                #     start_time = time.time()
                #     print(train_loss_str)
                #     logging.info(train_loss_str)
                #     tensorboard.add_train_scalar('train_loss_' + type_name, cur_loss, batch_num)
                #
                # if batch_num % args.eval_period == 0 and batch_num > 0:
                #     output_index = get_output_index(output_logits)
                #     gold_pred_train = get_gold_pred_str(output_index, batch['y'].data.cpu().clone(), args.goal)
                #     accuracy = sum([set(y) == set(yp) for y, yp in gold_pred_train]) * 1.0 / len(gold_pred_train)
                #     train_acc_str = '{1:s} Train accuracy: {0:.1f}%'.format(accuracy * 100, type_name)
                #     print(train_acc_str)
                #     logging.info(train_acc_str)
                #     tensorboard.add_train_scalar('train_acc_' + type_name, accuracy, batch_num)
                #     for (val_type_name, val_data_gen) in val_gen_list:
                #         if val_type_name == type_name:
                #             eval_batch, _ = to_torch(next(val_data_gen))
                #             evaluate_batch(batch_num, eval_batch, model, tensorboard, val_type_name, args.goal)

            if batch_num % args.eval_period == 0 and batch_num > 0:
                # Evaluate loss on one batch of the dev data (the crowd dev set in the joint setup).
                print('---- eval at step {0:d} ---'.format(batch_num))
                feed_dict = next(eval_gen)
                eval_batch, _ = to_torch(feed_dict)
                crowd_eval_loss, crowd_eval_ma_f1 = evaluate_batch(batch_num, eval_batch, model, tensorboard,
                                                                   eval_type_name, eval_goal)

            # The score compared here is the latest one available; it comes from
            # an earlier step whenever this step did not evaluate.
            if (batch_num % args.save_period == 0 and batch_num > 0
                    and crowd_eval_ma_f1 is not None and crowd_eval_ma_f1 > best_eval_ma_f1):
                best_eval_ma_f1 = crowd_eval_ma_f1
                save_fname = '{0:s}/{1:s}_best.pt'.format(constant.EXP_ROOT, args.model_id)
                _save_checkpoint(model, optimizer, save_fname)
                print(
                    'Total {0:.2f} minutes have passed, saving at {1:s} '.format((time.time() - init_time) / 60,
                                                                               save_fname))
    finally:
        # Close the event files so buffered scalars are not lost.
        train_log.close()
        validation_log.close()


def evaluate_batch(batch_num, eval_batch, model, tensorboard, val_type_name, goal):
    """Score one evaluation batch, log the results and return ``(eval_loss, ma_f1)``.

    The model is switched to eval mode for the forward pass and back to train
    mode afterwards.  Exact-match accuracy, loss and the two scores returned by
    ``get_eval_string`` are written to the validation Tensorboard log at step
    ``batch_num``.
    """
    model.eval()
    loss, output_logits = model(eval_batch, val_type_name)
    output_index = get_output_index(output_logits)
    eval_loss = _scalar(loss)
    eval_loss_str = 'Eval loss: {0:.7f} at step {1:d}'.format(eval_loss, batch_num)
    gold_pred = get_gold_pred_str(output_index, eval_batch['y'].data.cpu().clone(), goal)
    # Fraction of examples whose predicted label set equals the gold set;
    # an empty batch scores 0 instead of dividing by zero.
    exact_matches = sum([set(y) == set(yp) for y, yp in gold_pred])
    eval_accu = exact_matches * 1.0 / len(gold_pred) if gold_pred else 0.0
    tensorboard.add_validation_scalar('eval_acc_' + val_type_name, eval_accu, batch_num)
    tensorboard.add_validation_scalar('eval_loss_' + val_type_name, eval_loss, batch_num)
    eval_str, ma_f1, f1 = get_eval_string(gold_pred)
    print(val_type_name + ":" + eval_loss_str)
    print(gold_pred[:3])
    print(val_type_name + ":" + eval_str)
    logging.info(val_type_name + ":" + eval_loss_str)
    logging.info(val_type_name + ":" + eval_str)
    model.train()
    tensorboard.add_validation_scalar('ma_f1' + val_type_name, ma_f1, batch_num)
    tensorboard.add_validation_scalar('f1' + val_type_name, f1, batch_num)
    return eval_loss, ma_f1


def load_model(reload_model_name, save_dir, model_id, model, optimizer=None):
    """Load ``<save_dir>/<name>.pt`` into ``model`` (and ``optimizer`` if given).

    ``name`` is ``reload_model_name`` when it is truthy, otherwise ``model_id``.
    Without an optimizer, the size of every saved tensor and the total
    parameter count are printed and logged instead.
    """
    checkpoint_name = reload_model_name if reload_model_name else model_id
    model_file_name = '{0:s}/{1:s}.pt'.format(save_dir, checkpoint_name)
    checkpoint = torch.load(model_file_name)
    model.load_state_dict(checkpoint['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    else:
        total_params = 0
        # Log params
        for k, elem in checkpoint['state_dict'].items():
            print(k, elem.size())
            total_params += elem.numel()
        param_str = ('Number of total parameters..{0:d}'.format(total_params))
        logging.info(param_str)
        print(param_str)
    logging.info("Loading old file from {0:s}".format(model_file_name))
    print('Loading model from ... {0:s}'.format(model_file_name))


def _test(args):
    """Sweep the decision threshold over the first batch of ``args.eval_data``.

    Loads ``<EXP_ROOT>/<model_id>_best.pt``, runs the model once on the first
    batch and, for each threshold in ``np.arange(0, 1, 0.005)``, records the
    two scores returned by ``get_eval_string``.  The collected pairs are saved
    with ``np.save`` to ``<EXP_ROOT>/baseline_pr_dev``.
    """
    assert args.load
    test_fname = args.eval_data
    # Built once; the original constructed this dataset twice and discarded the first.
    data_gens = get_datasets([(test_fname, 'test', args.goal)], args)
    model = models.Model(args, constant.ANSWER_NUM_DICT[args.goal])
    model.cuda()
    model.eval()
    # load_model(args.reload_model_name, constant.EXP_ROOT, args.model_id, model)
    saved_path = constant.EXP_ROOT
    model.load_state_dict(torch.load(saved_path + '/' + args.model_id + '_best.pt')["state_dict"])
    for name, dataset in [(test_fname, data_gens[0])]:
        print('Processing... ' + name)
        batch = next(dataset)
        eval_batch, annot_ids = to_torch(batch)
        loss, output_logits = model(eval_batch, args.goal)

        threshes = np.arange(0, 1, 0.005)
        p_and_r = []
        for thresh in tqdm(threshes):
            total_gold_pred = []
            total_annot_ids = []
            total_probs = []
            total_ys = []
            print('thresh {}'.format(thresh))
            output_index = get_output_index(output_logits, thresh)
            output_prob = model.sigmoid_fn(output_logits).data.cpu().clone().numpy()
            y = eval_batch['y'].data.cpu().clone().numpy()
            gold_pred = get_gold_pred_str(output_index, y, args.goal)
            total_probs.extend(output_prob)
            total_ys.extend(y)
            total_gold_pred.extend(gold_pred)
            total_annot_ids.extend(annot_ids)
            # mrr_val = mrr(total_probs, total_ys)
            # print('mrr_value: ', mrr_val)
            # pickle.dump({'gold_id_array': total_ys, 'pred_dist': total_probs},
            #             open('./{0:s}.p'.format(args.reload_model_name), "wb"))
            # with open('./{0:s}.json'.format(args.reload_model_name), 'w') as f_out:
            #     output_dict = {}
            #     for a_id, (gold, pred) in zip(total_annot_ids, total_gold_pred):
            #         output_dict[a_id] = {"gold": gold, "pred": pred}
            #     json.dump(output_dict, f_out)
            # NOTE(review): evaluate_batch names these two return values
            # ma_f1 / f1, here they are stored as precision / recall -- confirm
            # against get_eval_string.
            eval_str, p, r = get_eval_string(total_gold_pred)
            p_and_r.append([p, r])
            print(eval_str)
        np.save(saved_path + '/baseline_pr_dev', p_and_r)

    # Disabled: full-dataset evaluation with MRR and prediction dumps.
    # for name, dataset in [(test_fname, data_gens[0])]:
    #     print('Processing... ' + name)
    #     total_gold_pred = []
    #     total_annot_ids = []
    #     total_probs = []
    #     total_ys = []
    #     for batch_num, batch in enumerate(dataset):
    #         eval_batch, annot_ids = to_torch(batch)
    #         loss, output_logits = model(eval_batch, args.goal)
    #         output_index = get_output_index(output_logits)
    #         output_prob = model.sigmoid_fn(output_logits).data.cpu().clone().numpy()
    #         y = eval_batch['y'].data.cpu().clone().numpy()
    #         gold_pred = get_gold_pred_str(output_index, y, args.goal)
    #         total_probs.extend(output_prob)
    #         total_ys.extend(y)
    #         total_gold_pred.extend(gold_pred)
    #         total_annot_ids.extend(annot_ids)
    #     mrr_val = mrr(total_probs, total_ys)
    #     print('mrr_value: ', mrr_val)
    #     pickle.dump({'gold_id_array': total_ys, 'pred_dist': total_probs},
    #                 open('./{0:s}.p'.format(args.reload_model_name), "wb"))
    #     with open('./{0:s}.json'.format(args.reload_model_name), 'w') as f_out:
    #         output_dict = {}
    #         for a_id, (gold, pred) in zip(total_annot_ids, total_gold_pred):
    #             output_dict[a_id] = {"gold": gold, "pred": pred}
    #         json.dump(output_dict, f_out)
    #     eval_str = get_eval_string(total_gold_pred)
    #     print(eval_str)
    #     logging.info('processing: ' + name)
    #     logging.info(eval_str)


if __name__ == '__main__':
    config = config_parser.parser.parse_args()
    torch.cuda.manual_seed(config.seed)
    logging.basicConfig(
        filename=constant.EXP_ROOT + "/" + config.model_id + datetime.datetime.now().strftime("_%m-%d_%H") + config.mode + '.txt',
        level=logging.INFO, format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s', datefmt='%m-%d %H:%M')
    logging.info(config)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if config.mode == 'train':
        _train(config)
    elif config.mode == 'test':
        _test(config)
    else:
        raise ValueError("invalid value for 'mode': {}".format(config.mode))