"""Conll training algorithm""" import os import time import argparse import socket from datetime import datetime import numpy as np import torch import torch.nn as nn from torch.autograd import Variable from torch.optim import RMSprop from torch.utils.data import DataLoader from tensorboardX import SummaryWriter from neuralcoref.train.model import Model from neuralcoref.train.dataset import ( NCDataset, NCBatchSampler, load_embeddings_from_file, padder_collate, SIZE_PAIR_IN, SIZE_SINGLE_IN, ) from neuralcoref.train.utils import SIZE_EMBEDDING from neuralcoref.train.evaluator import ConllEvaluator PACKAGE_DIRECTORY = os.path.dirname(os.path.abspath(__file__)) STAGES = ["allpairs", "toppairs", "ranking"] def clipped_sigmoid(inputs): epsilon = 1.0e-7 return torch.sigmoid(inputs).clamp(epsilon, 1.0 - epsilon) def get_all_pairs_loss(n): def all_pair_loss(scores, targets): """ All pairs and single mentions probabilistic loss """ labels = targets[0] weights = targets[4].data if len(targets) == 5 else None loss_op = nn.BCEWithLogitsLoss(weight=weights, reduction="sum") loss = loss_op(scores, labels) return loss / n return all_pair_loss def get_top_pair_loss(n): def top_pair_loss(scores, targets, debug=False): """ Top pairs (best true and best mistaken) and single mention probabilistic loss """ true_ants = targets[2] false_ants = targets[3] if len(targets) == 5 else None s_scores = clipped_sigmoid(scores) true_pairs = torch.gather(s_scores, 1, true_ants) top_true, top_true_arg = torch.log(true_pairs).max( dim=1 ) # max(log(p)), p=sigmoid(s) if debug: print("true_pairs", true_pairs.data) print("top_true", top_true.data) print("top_true_arg", top_true_arg.data) out_score = torch.sum(top_true).neg() if ( false_ants is not None ): # We have no false antecedents when there are no pairs false_pairs = torch.gather(s_scores, 1, false_ants) top_false, _ = torch.log(1 - false_pairs).min( dim=1 ) # min(log(1-p)), p=sigmoid(s) out_score = out_score + torch.sum(top_false).neg() return out_score / n return top_pair_loss def get_ranking_loss(n): def ranking_loss(scores, targets): """ Slack-rescaled max margin loss """ costs = targets[1] true_ants = targets[2] weights = targets[4] if len(targets) == 5 else None true_ant_score = torch.gather(scores, 1, true_ants) top_true, _ = true_ant_score.max(dim=1) tmp_loss = scores.add(1).add( top_true.unsqueeze(1).neg() ) # 1 + scores - top_true if weights is not None: tmp_loss = tmp_loss.mul(weights) tmp_loss = tmp_loss.mul(costs) loss, _ = tmp_loss.max(dim=1) out_score = torch.sum(loss) return out_score / n return ranking_loss def decrease_lr(optim_func, factor=0.1, min_lrs=0, eps=0, verbose=True): for i, param_group in enumerate(optim_func.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * factor, min_lrs) if old_lr - new_lr > eps: param_group["lr"] = new_lr if verbose: print(f"Reducing learning rate" " of group {i} to {new_lr:.4e}.") return new_lr def load_model(model, path): print("⛄️ Reloading model from", path) model.load_state_dict( torch.load(path) if args.cuda else torch.load(path, map_location=lambda storage, loc: storage) ) def run_model(args): print( "Training for", args.all_pairs_epoch, args.top_pairs_epoch, args.ranking_epoch, "epochs", ) # Tensorboard server writer = SummaryWriter() # Load datasets and embeddings embed_path = args.weights if args.weights is not None else args.train tensor_embeddings, voc = load_embeddings_from_file(embed_path + "tuned_word") dataset = NCDataset(args.train, args) eval_dataset = NCDataset(args.eval, args) 
print("Vocabulary:", len(voc)) # Construct model print("🏝 Build model") model = Model( len(voc), SIZE_EMBEDDING, args.h1, args.h2, args.h3, SIZE_PAIR_IN, SIZE_SINGLE_IN, ) model.load_embeddings(tensor_embeddings) if args.cuda: model.cuda() if args.weights is not None: print("🏝 Loading pre-trained weights") model.load_weights(args.weights) if args.checkpoint_file is not None: print("⛄️ Loading model from", args.checkpoint_file) model.load_state_dict( torch.load(args.checkpoint_file) if args.cuda else torch.load( args.checkpoint_file, map_location=lambda storage, loc: storage ) ) print("🏝 Loading conll evaluator") eval_evaluator = ConllEvaluator( model, eval_dataset, args.eval, args.evalkey, embed_path, args ) train_evaluator = ConllEvaluator( model, dataset, args.train, args.trainkey, embed_path, args ) print("🏝 Testing evaluator and getting first eval score") eval_evaluator.test_model() start_time = time.time() eval_evaluator.build_test_file() score, f1_conll, ident = eval_evaluator.get_score() elapsed = time.time() - start_time print(f"|| s/evaluation {elapsed:5.2f}") writer.add_scalar("eval/" + "F1_conll", f1_conll, 0) # Preparing dataloader print("🏝 Preparing dataloader") print( "Dataloader parameters: batchsize", args.batchsize, "numworkers", args.numworkers, ) batch_sampler = NCBatchSampler( dataset.mentions_pair_length, shuffle=True, batchsize=args.batchsize ) dataloader = DataLoader( dataset, collate_fn=padder_collate, batch_sampler=batch_sampler, num_workers=args.numworkers, pin_memory=args.cuda, ) mentions_idx, n_pairs = batch_sampler.get_batch_info() print("🏝 Start training") g_step = 0 start_from = ( args.startstep if args.startstep is not None and args.startstage is not None else 0 ) def run_epochs( start_epoch, end_epoch, loss_func, optim_func, save_name, lr, g_step, debug=None ): best_model_path = args.save_path + "best_model" + save_name start_time_all = time.time() best_f1_conll = 0 lower_eval = 0 for epoch in range(start_epoch, end_epoch): """ Run an epoch """ print(f"🚘 {save_name} Epoch {epoch:d}") model.train() start_time_log = time.time() start_time_epoch = time.time() epoch_loss = 0 for batch_i, (m_idx, n_pairs_l, batch) in enumerate( zip(mentions_idx, n_pairs, dataloader) ): if debug is not None and (debug == -1 or debug in m_idx): l = list(dataset.flat_m_loc[m][2:] for m in m_idx) print( "🏔 Batch", batch_i, "m_idx:", "|".join(str(i) for i in m_idx), "mentions:", "|".join(dataset.docs[d]["mentions"][i] for u, i, d in l), ) print("Batch n_pairs:", "|".join(str(p) for p in n_pairs_l)) inputs, targets = batch inputs = tuple(Variable(inp, requires_grad=False) for inp in inputs) targets = tuple(Variable(tar, requires_grad=False) for tar in targets) if args.cuda: inputs = tuple(i.cuda() for i in inputs) targets = tuple(t.cuda() for t in targets) scores = model(inputs) if debug is not None and (debug == -1 or debug in m_idx): print( "Scores:\n" + "\n".join( "|".join(str(s) for s in s_l) for s_l in scores.data.cpu().numpy() ) ) print( "Labels:\n" + "\n".join( "|".join(str(s) for s in s_l) for s_l in targets[0].data.cpu().numpy() ) ) loss = loss_func(scores, targets) if debug is not None and (debug == -1 or debug in m_idx): print("Loss", loss.item()) # Zero gradients, perform a backward pass, and update the weights. 
def decrease_lr(optim_func, factor=0.1, min_lrs=0, eps=0, verbose=True):
    for i, param_group in enumerate(optim_func.param_groups):
        old_lr = float(param_group["lr"])
        new_lr = max(old_lr * factor, min_lrs)
        if old_lr - new_lr > eps:
            param_group["lr"] = new_lr
            if verbose:
                print(f"Reducing learning rate of group {i} to {new_lr:.4e}.")
    return new_lr


def load_model(model, path):
    # Relies on the module-level `args` set in the __main__ block below.
    print("⛄️ Reloading model from", path)
    model.load_state_dict(
        torch.load(path)
        if args.cuda
        else torch.load(path, map_location=lambda storage, loc: storage)
    )


def run_model(args):
    print(
        "Training for",
        args.all_pairs_epoch,
        args.top_pairs_epoch,
        args.ranking_epoch,
        "epochs",
    )

    # TensorBoard writer
    writer = SummaryWriter()

    # Load datasets and embeddings
    embed_path = args.weights if args.weights is not None else args.train
    tensor_embeddings, voc = load_embeddings_from_file(embed_path + "tuned_word")
    dataset = NCDataset(args.train, args)
    eval_dataset = NCDataset(args.eval, args)
    print("Vocabulary:", len(voc))

    # Construct model
    print("🏝 Build model")
    model = Model(
        len(voc),
        SIZE_EMBEDDING,
        args.h1,
        args.h2,
        args.h3,
        SIZE_PAIR_IN,
        SIZE_SINGLE_IN,
    )
    model.load_embeddings(tensor_embeddings)
    if args.cuda:
        model.cuda()
    if args.weights is not None:
        print("🏝 Loading pre-trained weights")
        model.load_weights(args.weights)
    if args.checkpoint_file is not None:
        print("⛄️ Loading model from", args.checkpoint_file)
        model.load_state_dict(
            torch.load(args.checkpoint_file)
            if args.cuda
            else torch.load(
                args.checkpoint_file, map_location=lambda storage, loc: storage
            )
        )

    print("🏝 Loading CoNLL evaluator")
    eval_evaluator = ConllEvaluator(
        model, eval_dataset, args.eval, args.evalkey, embed_path, args
    )
    train_evaluator = ConllEvaluator(
        model, dataset, args.train, args.trainkey, embed_path, args
    )
    print("🏝 Testing evaluator and getting first eval score")
    eval_evaluator.test_model()
    start_time = time.time()
    eval_evaluator.build_test_file()
    score, f1_conll, ident = eval_evaluator.get_score()
    elapsed = time.time() - start_time
    print(f"|| s/evaluation {elapsed:5.2f}")
    writer.add_scalar("eval/F1_conll", f1_conll, 0)

    # Prepare the dataloader
    print("🏝 Preparing dataloader")
    print(
        "Dataloader parameters: batchsize",
        args.batchsize,
        "numworkers",
        args.numworkers,
    )
    batch_sampler = NCBatchSampler(
        dataset.mentions_pair_length, shuffle=True, batchsize=args.batchsize
    )
    dataloader = DataLoader(
        dataset,
        collate_fn=padder_collate,
        batch_sampler=batch_sampler,
        num_workers=args.numworkers,
        pin_memory=args.cuda,
    )
    mentions_idx, n_pairs = batch_sampler.get_batch_info()

    print("🏝 Start training")
    g_step = 0
    start_from = (
        args.startstep
        if args.startstep is not None and args.startstage is not None
        else 0
    )

    def run_epochs(
        start_epoch, end_epoch, loss_func, optim_func, save_name, lr, g_step, debug=None
    ):
        best_model_path = args.save_path + "best_model" + save_name
        start_time_all = time.time()
        best_f1_conll = 0
        lower_eval = 0
        for epoch in range(start_epoch, end_epoch):
            # Run one epoch
            print(f"🚘 {save_name} Epoch {epoch:d}")
            model.train()
            start_time_log = time.time()
            start_time_epoch = time.time()
            epoch_loss = 0
            for batch_i, (m_idx, n_pairs_l, batch) in enumerate(
                zip(mentions_idx, n_pairs, dataloader)
            ):
                if debug is not None and (debug == -1 or debug in m_idx):
                    locs = list(dataset.flat_m_loc[m][2:] for m in m_idx)
                    print(
                        "🏔 Batch",
                        batch_i,
                        "m_idx:",
                        "|".join(str(i) for i in m_idx),
                        "mentions:",
                        "|".join(dataset.docs[d]["mentions"][i] for u, i, d in locs),
                    )
                    print("Batch n_pairs:", "|".join(str(p) for p in n_pairs_l))
                inputs, targets = batch
                # Variable is a no-op wrapper on PyTorch >= 0.4; kept for compatibility.
                inputs = tuple(Variable(inp, requires_grad=False) for inp in inputs)
                targets = tuple(Variable(tar, requires_grad=False) for tar in targets)
                if args.cuda:
                    inputs = tuple(i.cuda() for i in inputs)
                    targets = tuple(t.cuda() for t in targets)
                scores = model(inputs)
                if debug is not None and (debug == -1 or debug in m_idx):
                    print(
                        "Scores:\n"
                        + "\n".join(
                            "|".join(str(s) for s in s_l)
                            for s_l in scores.data.cpu().numpy()
                        )
                    )
                    print(
                        "Labels:\n"
                        + "\n".join(
                            "|".join(str(s) for s in s_l)
                            for s_l in targets[0].data.cpu().numpy()
                        )
                    )
                loss = loss_func(scores, targets)
                if debug is not None and (debug == -1 or debug in m_idx):
                    print("Loss", loss.item())
                # Zero gradients, perform a backward pass, and update the weights.
                optim_func.zero_grad()
                loss.backward()
                epoch_loss += loss.item()
                optim_func.step()
                writer.add_scalar(f"train/{save_name}_loss", loss.item(), g_step)
                writer.add_scalar("meta/lr", lr, g_step)
                writer.add_scalar("meta/stage", STAGES.index(save_name), g_step)
                g_step += 1
                if batch_i % args.log_interval == 0 and batch_i > 0:
                    elapsed = time.time() - start_time_log
                    lr = optim_func.param_groups[0]["lr"]
                    ea = elapsed * 1000 / args.log_interval
                    li = loss.item()
                    print(
                        f"| epoch {epoch:3d} | {batch_i:5d}/{len(dataloader):5d} batches | "
                        f"lr {lr:.2e} | ms/batch {ea:5.2f} | "
                        f"loss {li:.2e}"
                    )
                    start_time_log = time.time()

            elapsed_all = time.time() - start_time_all
            elapsed_epoch = time.time() - start_time_epoch
            ep = elapsed_epoch / 60
            ea = (
                elapsed_all
                / 3600
                * float(end_epoch - epoch)
                / float(epoch - start_epoch + 1)
            )
            print(
                f"|| min/epoch {ep:5.2f} | est. remaining time (h) {ea:5.2f} | loss {epoch_loss:.2e}"
            )
            writer.add_scalar("epoch/loss", epoch_loss, g_step)

            if epoch % args.conll_train_interval == 0:
                start_time = time.time()
                train_evaluator.build_test_file()
                score, f1_conll, ident = train_evaluator.get_score()
                elapsed = time.time() - start_time  # time spent on this evaluation
                print(
                    f"|| min/train evaluation {elapsed / 60:5.2f} | F1_conll {f1_conll:5.2f}"
                )
                writer.add_scalar("epoch/F1_conll", f1_conll, g_step)
            if epoch % args.conll_eval_interval == 0:
                start_time = time.time()
                eval_evaluator.build_test_file()
                score, f1_conll, ident = eval_evaluator.get_score()
                elapsed = time.time() - start_time
                print(f"|| min/evaluation {elapsed / 60:5.2f}")
                writer.add_scalar("eval/F1_conll", f1_conll, g_step)
                g_step += 1
                save_path = args.save_path + save_name + "_" + str(epoch)
                torch.save(model.state_dict(), save_path)
                if f1_conll > best_f1_conll:
                    best_f1_conll = f1_conll
                    torch.save(model.state_dict(), best_model_path)
                    lower_eval = 0
                elif args.on_eval_decrease != "nothing":
                    print("Evaluation metric decreases")
                    lower_eval += 1
                    if lower_eval >= args.patience:
                        if (
                            args.on_eval_decrease == "divide_lr"
                            or args.on_eval_decrease == "divide_then_next"
                        ):
                            print("reload best model and decrease lr")
                            load_model(model, best_model_path)
                            lr = decrease_lr(optim_func)
                        if args.on_eval_decrease == "next_stage" or lr <= args.min_lr:
                            print("Switch to next stage")
                            break

        # Save last step
        start_time = time.time()
        eval_evaluator.build_test_file()
        score, f1_conll, ident = eval_evaluator.get_score()
        elapsed = time.time() - start_time
        print(f"|| min/evaluation {elapsed / 60:5.2f}")
        writer.add_scalar("eval/F1_conll", f1_conll, g_step)
        g_step += 1
        save_path = args.save_path + save_name + "_" + str(epoch)
        torch.save(model.state_dict(), save_path)
        load_model(model, best_model_path)
        return g_step

    if args.startstage is None or args.startstage == "allpairs":
        optimizer = RMSprop(
            model.parameters(), lr=args.all_pairs_lr, weight_decay=args.all_pairs_l2
        )
        loss_func = get_all_pairs_loss(batch_sampler.pairs_per_batch)
        g_step = run_epochs(
            start_from,
            args.all_pairs_epoch,
            loss_func,
            optimizer,
            "allpairs",
            args.all_pairs_lr,
            g_step,
        )
        start_from = 0

    if args.startstage is None or args.startstage in ["allpairs", "toppairs"]:
        optimizer = RMSprop(
            model.parameters(), lr=args.top_pairs_lr, weight_decay=args.top_pairs_l2
        )
        loss_func = get_top_pair_loss(10 * batch_sampler.mentions_per_batch)
        g_step = run_epochs(
            start_from,
            args.top_pairs_epoch,
            loss_func,
            optimizer,
            "toppairs",
            args.top_pairs_lr,
            g_step,
        )
        start_from = 0

    if args.startstage is None or args.startstage in [
        "ranking",
        "allpairs",
        "toppairs",
    ]:
        optimizer = RMSprop(
            model.parameters(), lr=args.ranking_lr, weight_decay=args.ranking_l2
        )
        loss_func = get_ranking_loss(batch_sampler.mentions_per_batch)
        g_step = run_epochs(
            start_from,
            args.ranking_epoch,
            loss_func,
            optimizer,
            "ranking",
            args.ranking_lr,
            g_step,
        )
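# The three blocks above run the STAGES curriculum in order: "allpairs"
# pre-trains with a BCE loss over every mention-pair score, "toppairs" refines
# on the best-scoring true and false antecedents, and "ranking" finishes with
# the slack-rescaled max-margin loss. Each stage gets a fresh RMSprop optimizer
# with its own learning rate and L2 penalty, and run_epochs reloads the best
# checkpoint before returning, so each stage starts from the previous best.
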
if __name__ == "__main__":
    DIR_PATH = os.path.dirname(os.path.realpath(__file__))
    parser = argparse.ArgumentParser(
        description="Training the neural coreference model"
    )
    parser.add_argument(
        "--train",
        type=str,
        default=DIR_PATH + "/data/",
        help="Path to the train dataset",
    )
    parser.add_argument(
        "--eval", type=str, default=DIR_PATH + "/data/", help="Path to the eval dataset"
    )
    parser.add_argument(
        "--evalkey", type=str, help="Path to an optional key file for scoring"
    )
    parser.add_argument(
        "--weights",
        type=str,
        help="Path to pre-trained weights (e.g. if you only want to test the scoring)",
    )
    parser.add_argument(
        "--batchsize",
        type=int,
        default=20000,
        help="Size of a batch in total number of pairs",
    )
    parser.add_argument(
        "--numworkers",
        type=int,
        default=8,
        help="Number of workers for loading batches",
    )
    parser.add_argument(
        "--startstage",
        type=str,
        help='Start from a specific stage ("allpairs", "toppairs", "ranking")',
    )
    parser.add_argument("--startstep", type=int, help="Start from a specific step")
    parser.add_argument(
        "--checkpoint_file",
        type=str,
        help="Start from a previously saved checkpoint file",
    )
    parser.add_argument(
        "--log_interval",
        type=int,
        default=10,
        help="log training loss every X mini-batches",
    )
    parser.add_argument(
        "--conll_eval_interval",
        type=int,
        default=10,
        help="evaluate eval F1_conll every X epochs",
    )
    parser.add_argument(
        "--conll_train_interval",
        type=int,
        default=20,
        help="evaluate train F1_conll every X epochs",
    )
    parser.add_argument("--seed", type=int, default=1111, help="random seed")
    parser.add_argument("--costfn", type=float, default=0.8, help="cost of false new")
    parser.add_argument("--costfl", type=float, default=0.4, help="cost of false link")
    parser.add_argument("--costwl", type=float, default=1.0, help="cost of wrong link")
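    # The three cost flags above become args.costs = {"FN", "FL", "WL"} below;
    # they weight the mistake types (false new / false link / wrong link),
    # presumably surfacing as the per-pair costs read from targets[1] in
    # ranking_loss via the dataset's target construction.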
    parser.add_argument(
        "--h1", type=int, default=1000, help="number of hidden units in layer 1"
    )
    parser.add_argument(
        "--h2", type=int, default=500, help="number of hidden units in layer 2"
    )
    parser.add_argument(
        "--h3", type=int, default=500, help="number of hidden units in layer 3"
    )
    parser.add_argument(
        "--all_pairs_epoch",
        type=int,
        default=200,
        help="number of epochs for all-pairs pre-training",
    )
    parser.add_argument(
        "--top_pairs_epoch",
        type=int,
        default=200,
        help="number of epochs for top-pairs pre-training",
    )
    parser.add_argument(
        "--ranking_epoch",
        type=int,
        default=200,
        help="number of epochs for ranking training",
    )
    parser.add_argument(
        "--all_pairs_lr",
        type=float,
        default=2e-4,
        help="all pairs pre-training learning rate",
    )
    parser.add_argument(
        "--top_pairs_lr",
        type=float,
        default=2e-4,
        help="top pairs pre-training learning rate",
    )
    parser.add_argument(
        "--ranking_lr", type=float, default=2e-6, help="ranking training learning rate"
    )
    parser.add_argument(
        "--all_pairs_l2",
        type=float,
        default=1e-6,
        help="all pairs pre-training l2 regularization",
    )
    parser.add_argument(
        "--top_pairs_l2",
        type=float,
        default=1e-5,
        help="top pairs pre-training l2 regularization",
    )
    parser.add_argument(
        "--ranking_l2",
        type=float,
        default=1e-5,
        help="ranking training l2 regularization",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=3,
        help="patience (in epochs) before reacting to a decreased evaluation metric",
    )
    parser.add_argument("--min_lr", type=float, default=2e-8, help="min learning rate")
    parser.add_argument(
        "--on_eval_decrease",
        type=str,
        default="nothing",
        help='What to do when evaluation decreases ("nothing", "divide_lr", "next_stage", "divide_then_next")',
    )
    parser.add_argument(
        "--lazy",
        type=int,
        default=1,
        choices=(0, 1),
        help="Use lazy loading (1, default) or not (0) while loading the npy files",
    )
    args = parser.parse_args()
    args.costs = {"FN": args.costfn, "FL": args.costfl, "WL": args.costwl}
    args.lazy = bool(args.lazy)
    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    args.save_path = os.path.join(
        PACKAGE_DIRECTORY,
        "checkpoints",
        current_time + "_" + socket.gethostname() + "_",
    )
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    args.cuda = torch.cuda.is_available()
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    args.evalkey = args.evalkey if args.evalkey is not None else args.eval + "/key.txt"
    args.trainkey = args.train + "/key.txt"
    args.train = args.train + "/numpy/"
    args.eval = args.eval + "/numpy/"
    print(args)
    run_model(args)
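# Example invocation (hypothetical script name and paths; each dataset
# directory must contain the preprocessed numpy/ arrays and a key.txt CoNLL
# key file, as the path rewriting above expects):
#
#   python learn.py --train /path/to/train_dir --eval /path/to/eval_dir \
#       --batchsize 20000 --startstage toppairs --checkpoint_file ckpt.pt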