# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import time
import h5py
import torch
import random
import argparse
from torch import nn
from utils import get_logger
from functools import partial
from utils import AverageMeter
from nli.models import PpoModel
from utils import EarlyStopping
from utils import get_lr_scheduler
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from nli.data_preprocessing import NliDataset


# Module-level training state shared by train(), validate() and main().
# They are initialised here (and again in the ``__main__`` section) so that the
# functions below also work when this module is imported instead of executed.
global_step = 0
best_model_path = None

# Names of the running statistics collected by every evaluation / training pass.
_METER_NAMES = ("loading_time", "batch_time", "ce_loss", "accuracy", "entropy", "n_entropy")


def _parse_bool(val):
    """Command-line boolean parser: only the exact string "True" maps to True."""
    return val == "True"


def make_path_preparations(args):
    """Seed the RNGs and create the log, checkpoint and tensorboard locations.

    Side effects: seeds ``torch``, ``torch.cuda`` and ``random``; creates
    directories; rewrites ``args.model_dir`` to the run-specific sub-directory.

    Returns:
        (logger, summary_writer) where ``summary_writer`` is a dict with the
        keys "train" and "valid", each holding a ``SummaryWriter``.
    """
    # NOTE(review): hash() of a str is salted per interpreter process unless
    # PYTHONHASHSEED is fixed, so both the seed and the directory names below
    # differ between runs of the same configuration. Behaviour kept as is.
    raw_hash = hash(str(args))
    seed = raw_hash % 1000_000
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)

    # logger path
    args_hash = str(raw_hash)
    os.makedirs(args.logs_path, exist_ok=True)
    log_file = f"{args.logs_path}/l{args_hash}.log"
    logger = get_logger(log_file)
    print(log_file)
    logger.info(f"args: {str(args)}")
    logger.info(f"args hash: {args_hash}")
    logger.info(f"random seed: {seed}")

    # model path
    args.model_dir = f"{args.model_dir}/m{args_hash}"
    os.makedirs(args.model_dir, exist_ok=True)
    logger.info(f"checkpoint's dir is: {args.model_dir}")

    # tensorboard path
    tensorboard_path = f"{args.tensorboard_path}/t{args_hash}"
    os.makedirs(tensorboard_path, exist_ok=True)
    summary_writer = {
        split: SummaryWriter(log_dir=os.path.join(tensorboard_path, 'log' + args_hash, split))
        for split in ("train", "valid")
    }
    return logger, summary_writer


def _make_loader(dataset, batch_size, shuffle, drop_last):
    """Build a DataLoader with the settings shared by all three splits."""
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=4, drop_last=drop_last,
                      collate_fn=NliDataset.collate_fn, pin_memory=True)


def get_data(args):
    """Load the NLI splits and the GloVe matrix selected by ``args``.

    Sets ``args.vocab_size`` (first dimension of the GloVe array) and
    ``args.label_size`` as a side effect.

    Returns:
        (train_loader, valid_loader, test_loader, glove); ``test_loader`` is
        None for "multi_nli", for which no test split is loaded.

    Raises:
        ValueError: if ``args.nli`` is neither "snli" nor "multi_nli".
    """
    if args.nli == "snli":
        train_data = NliDataset.load_data(f"data/nli/snli_1.0/train_lower={args.lower}.pckl")
        valid_data = NliDataset.load_data(f"data/nli/snli_1.0/valid_lower={args.lower}.pckl")
        test_data = NliDataset.load_data(f"data/nli/snli_1.0/test_lower={args.lower}.pckl")
    elif args.nli == "multi_nli":
        # MultiNLI training is augmented with the SNLI training set.
        train_data = NliDataset.load_data(f"data/nli/multinli_1.0/train_lower={args.lower}.pckl")
        train_data.extend(NliDataset.load_data(f"data/nli/snli_1.0/train_lower={args.lower}.pckl"))
        valid_data = NliDataset.load_data(f"data/nli/multinli_1.0/valid_matched_lower={args.lower}.pckl")
        test_data = None
    else:
        raise ValueError(f"unknown nli dataset: {args.nli!r}")

    # Sizes of the raw data, before the datasets are built ...
    print(f"train len: {len(train_data)}")
    print(f"valid len: {len(valid_data)}")
    train_dataset = NliDataset(train_data, max_len=args.max_len)
    valid_dataset = NliDataset(valid_data)
    test_dataset = None if test_data is None else NliDataset(test_data)
    # ... and of what the datasets actually hold.
    print(f"train len: {len(train_dataset.data)}")
    print(f"valid len: {len(valid_dataset.data)}")

    train_loader = _make_loader(train_dataset, args.batch_size, shuffle=True, drop_last=True)
    valid_loader = _make_loader(valid_dataset, args.batch_size, shuffle=False, drop_last=False)
    test_loader = None if test_dataset is None else \
        _make_loader(test_dataset, args.batch_size, shuffle=False, drop_last=False)

    with h5py.File(f"data/nli/glove_lower={args.lower}.h5", 'r') as f:
        glove = f["glove"][...]
    args.vocab_size = glove.shape[0]
    args.label_size = NliDataset.label_size
    return train_loader, valid_loader, test_loader, glove


def _optimizer_class(name):
    """Map an optimiser name to its constructor; any unlisted name falls back to SGD."""
    if name == "adam":
        return torch.optim.Adam
    if name == "amsgrad":
        return partial(torch.optim.Adam, amsgrad=True)
    if name == "adadelta":
        return torch.optim.Adadelta
    return torch.optim.SGD


def prepare_optimisers(args, logger, policy_parameters, environment_parameters):
    """Create the optimisers, LR schedulers and early stopping helper.

    Returns:
        (optimizer, lr_scheduler, es) where ``optimizer`` and ``lr_scheduler``
        are dicts keyed by "policy" and "env", and ``es`` is an
        ``EarlyStopping`` instance in "max" mode.
    """
    pol_opt_class = _optimizer_class(args.pol_optimizer)
    env_opt_class = _optimizer_class(args.env_optimizer)
    optimizer = {"policy": pol_opt_class(params=policy_parameters, lr=args.pol_lr, weight_decay=args.l2_weight),
                 "env": env_opt_class(params=environment_parameters, lr=args.env_lr, weight_decay=args.l2_weight)}
    lr_scheduler = {"policy": get_lr_scheduler(logger, optimizer["policy"], patience=args.lr_scheduler_patience,
                                               threshold=args.lr_scheduler_threshold),
                    "env": get_lr_scheduler(logger, optimizer["env"], patience=args.lr_scheduler_patience,
                                            threshold=args.lr_scheduler_threshold)}
    es = EarlyStopping(mode="max", patience=args.es_patience, threshold=args.es_threshold)
    return optimizer, lr_scheduler, es


def perform_env_optimizer_step(optimizer, model, args):
    """Optionally clip (inf-norm) the environment gradients, then step and reset that optimiser."""
    if args.clip_grad_norm > 0:
        nn.utils.clip_grad_norm_(parameters=model.get_environment_parameters(),
                                 max_norm=args.clip_grad_norm,
                                 norm_type=float("inf"))
    optimizer["env"].step()
    optimizer["env"].zero_grad()


def perform_policy_optimizer_step(optimizer, model, args):
    """Optionally clip (inf-norm) the policy gradients, then step and reset that optimiser."""
    if args.clip_grad_norm > 0:
        nn.utils.clip_grad_norm_(parameters=model.get_policy_parameters(),
                                 max_norm=args.clip_grad_norm,
                                 norm_type=float("inf"))
    optimizer["policy"].step()
    optimizer["policy"].zero_grad()


def _to_device(batch, device):
    """Move every tensor of a batch tuple to ``device`` with non-blocking copies."""
    return tuple(tensor.to(device=device, non_blocking=True) for tensor in batch)


def _format_meters(meters):
    """Render the shared statistics suffix used by the train / valid / test log lines."""
    return (f"ce_loss: {meters['ce_loss'].avg:.4f} accuracy: {meters['accuracy'].avg:.4f} "
            f"entropy: {meters['entropy'].avg:.4f} n_entropy: {meters['n_entropy'].avg:.4f} "
            f"loading_time: {meters['loading_time'].avg:.4f} batch_time: {meters['batch_time'].avg:.4f}")


def _evaluate(data, model, device):
    """Run ``model`` in eval mode over ``data`` without gradients.

    The model is left in eval mode; callers decide whether to switch back.

    Returns:
        dict mapping each name in ``_METER_NAMES`` to its filled ``AverageMeter``.
    """
    meters = {name: AverageMeter() for name in _METER_NAMES}
    model.eval()
    start = time.time()
    with torch.no_grad():
        for batch in data:
            labels, premises, p_mask, hypotheses, h_mask = _to_device(batch, device)
            meters["loading_time"].update(time.time() - start)

            pred_labels, ce_loss, _, _, _, entropy, normalized_entropy = \
                model(premises, p_mask, hypotheses, h_mask, labels)
            accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()

            # Weight every statistic by the batch size so the averages are per example.
            n = p_mask.shape[0]
            meters["accuracy"].update(accuracy.item(), n)
            meters["ce_loss"].update(ce_loss.item(), n)
            meters["entropy"].update(entropy.mean().item(), n)
            meters["n_entropy"].update(normalized_entropy.mean().item(), n)
            meters["batch_time"].update(time.time() - start)
            start = time.time()
    return meters


def test(test_data, model, device, logger):
    """Evaluate on the test split and log the results.

    Returns:
        The average accuracy, or None when ``test_data`` is None.
    """
    if test_data is None:
        return None
    meters = _evaluate(test_data, model, device)
    logger.info(f"Test: {_format_meters(meters)}")
    logger.info("done")
    return meters["accuracy"].avg


def validate(valid_data, model, epoch, device, logger, summary_writer):
    """Evaluate on the validation split, log and write tensorboard scalars.

    Scalars are written at the current module-level ``global_step``. The model
    is switched back to training mode before returning.

    Returns:
        The average validation accuracy.
    """
    meters = _evaluate(valid_data, model, device)
    logger.info(f"Valid: epoch: {epoch} {_format_meters(meters)}")

    writer = summary_writer["valid"]
    writer.add_scalar(tag="ce", scalar_value=meters["ce_loss"].avg, global_step=global_step)
    writer.add_scalar(tag="accuracy", scalar_value=meters["accuracy"].avg, global_step=global_step)
    writer.add_scalar(tag="n_entropy", scalar_value=meters["n_entropy"].avg, global_step=global_step)

    model.train()
    return meters["accuracy"].avg


def train(train_data, valid_data, model, optimizer, lr_scheduler, es, epoch, args, logger, summary_writer):
    """Train for one epoch, validating roughly three times per epoch.

    After each validation the LR schedulers and early stopping are stepped with
    the validation accuracy; an improved model is saved and its path stored in
    the module-level ``best_model_path``. Returns early once ``es`` reports
    convergence. Increments the module-level ``global_step`` once per batch.
    """
    global global_step, best_model_path

    meters = {name: AverageMeter() for name in _METER_NAMES + ("prob_ratio",)}
    device = args.gpu_id
    # Guard against loaders with fewer than three batches, for which the
    # original ``len(train_data) // 3`` is zero and the modulo below would raise.
    validate_every = max(1, len(train_data) // 3)

    model.train()
    start = time.time()
    for batch_idx, batch in enumerate(train_data):
        labels, premises, p_mask, hypotheses, h_mask = _to_device(batch, device)
        meters["loading_time"].update(time.time() - start)

        pred_labels, ce_loss, rewards, actions, actions_log_prob, entropy, normalized_entropy = \
            model(premises, p_mask, hypotheses, h_mask, labels)

        # Environment update from the cross-entropy loss.
        ce_loss.backward()
        perform_env_optimizer_step(optimizer, model, args)

        # Policy updates with the clipped probability-ratio objective.
        # NOTE: assumes args.ppo_updates >= 1; otherwise ``prob_ratio`` is unbound below.
        for k in range(args.ppo_updates):
            if k == 0:
                # The first update reuses the statistics of the sampling forward pass.
                new_normalized_entropy, new_actions_log_prob = normalized_entropy, actions_log_prob
            else:
                new_normalized_entropy, new_actions_log_prob = \
                    model.evaluate_actions(premises, p_mask, actions["p_actions"],
                                           hypotheses, h_mask, actions["h_actions"])
            prob_ratio = (new_actions_log_prob - actions_log_prob.detach()).exp()
            clamped_prob_ratio = prob_ratio.clamp(1.0 - args.epsilon, 1.0 + args.epsilon)
            ppo_loss = torch.max(prob_ratio * rewards, clamped_prob_ratio * rewards).mean()
            loss = ppo_loss - args.entropy_weight * new_normalized_entropy.mean()
            loss.backward()
            perform_policy_optimizer_step(optimizer, model, args)

        mean_entropy = entropy.mean()
        mean_n_entropy = normalized_entropy.mean()
        accuracy = (labels == pred_labels).to(dtype=torch.float32).mean()
        n = p_mask.shape[0]
        meters["accuracy"].update(accuracy.item(), n)
        meters["ce_loss"].update(ce_loss.item(), n)
        meters["entropy"].update(mean_entropy.item(), n)
        meters["n_entropy"].update(mean_n_entropy.item(), n)
        meters["prob_ratio"].update((1.0 - prob_ratio.detach()).abs().mean().item(), n)
        meters["batch_time"].update(time.time() - start)

        writer = summary_writer["train"]
        writer.add_scalar(tag="ce", scalar_value=ce_loss.item(), global_step=global_step)
        writer.add_scalar(tag="accuracy", scalar_value=accuracy.item(), global_step=global_step)
        writer.add_scalar(tag="n_entropy", scalar_value=mean_n_entropy.item(), global_step=global_step)
        writer.add_scalar(tag="prob_ratio", scalar_value=meters["prob_ratio"].value, global_step=global_step)
        global_step += 1

        if (batch_idx + 1) % validate_every == 0:
            logger.info(f"Train: epoch: {epoch} batch_idx: {batch_idx + 1} {_format_meters(meters)}")
            val_accuracy = validate(valid_data, model, epoch, device, logger, summary_writer)
            lr_scheduler["env"].step(val_accuracy)
            lr_scheduler["policy"].step(val_accuracy)
            es.step(val_accuracy)
            if es.is_converged:
                return
            if es.is_improved():
                logger.info("saving model...")
                best_model_path = f"{args.model_dir}/{epoch}-{batch_idx}.mdl"
                torch.save({"epoch": epoch, "batch_idx": batch_idx, "state_dict": model.state_dict()},
                           best_model_path)
            model.train()
        start = time.time()


def main(args):
    """Build the data, model and optimisers, train until convergence, then test the best checkpoint."""
    logger, summary_writer = make_path_preparations(args)
    train_data, valid_data, test_data, vectors = get_data(args)

    model = PpoModel(vocab_size=args.vocab_size,
                     word_dim=args.word_dim,
                     hidden_dim=args.hidden_dim,
                     mlp_hidden_dim=args.mlp_hidden_dim,
                     label_dim=args.label_size,
                     dropout_prob=args.dropout_prob,
                     parser_leaf_transformation=args.parser_leaf_transformation,
                     parser_trans_hidden_dim=args.parser_trans_hidden_dim,
                     tree_leaf_transformation=args.tree_leaf_transformation,
                     tree_trans_hidden_dim=args.tree_trans_hidden_dim,
                     baseline_type=args.baseline_type,
                     var_normalization=args.var_normalization,
                     use_batchnorm=args.use_batchnorm).cuda(args.gpu_id)

    # Initialise both embedding tables from the pre-trained vectors, keeping
    # the dtype and device the model already uses.
    dtype = model.embd_parser.weight.data.dtype
    device = model.embd_parser.weight.data.device
    model.embd_parser.weight.data = torch.tensor(vectors, dtype=dtype, device=device)
    model.embd_tree.weight.data = torch.tensor(vectors, dtype=dtype, device=device)
    if args.freeze_embeddings:
        model.embd_parser.weight.requires_grad = False
        model.embd_tree.weight.requires_grad = False
        logger.info("Embeddings is frozen!")

    optimizer, lr_scheduler, es = prepare_optimisers(args, logger,
                                                     policy_parameters=model.get_policy_parameters(),
                                                     environment_parameters=model.get_environment_parameters())

    # Baseline validation of the untrained model.
    validate(valid_data, model, 0, args.gpu_id, logger, summary_writer)
    for epoch in range(args.max_epoch):
        train(train_data, valid_data, model, optimizer, lr_scheduler, es, epoch, args, logger, summary_writer)
        if es.is_converged:
            break

    print(best_model_path)
    if best_model_path is None:
        # No checkpoint was ever written (e.g. zero epochs or no improvement);
        # torch.load(None) would fail, so evaluate the weights currently in memory.
        logger.info("no checkpoint was saved; testing the current model")
    else:
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint["state_dict"])
    test(test_data, model, args.gpu_id, logger)


if __name__ == "__main__":
    # SNLI
    args = {"nli": "snli",
            "freeze-embeddings": "True",
            "use-batchnorm": "True",
            "dropout-prob": 0.1,
            "lower": "True",
            "mlp-hidden-dim": 1024,
            "word-dim": 300,
            "hidden-dim": 300,
            "parser-leaf-transformation": "lstm_transformation",
            "parser-trans-hidden_dim": 300,
            "tree-leaf-transformation": "lstm_transformation",
            "tree-trans-hidden_dim": 300,
            "baseline-type": "self_critical",
            "var-normalization": "True",
            "entropy-weight": 0.0,
            "clip-grad-norm": 0.0,
            "env-optimizer": "adadelta",
            "pol-optimizer": "adadelta",
            "env-lr": 1.0,
            "pol-lr": 1.0,
            "ppo-updates": 1,
            "epsilon": 0.2,
            "lr-scheduler-patience": 8,
            "lr-scheduler-threshold": 0.005,
            "l2-weight": 0.0,
            "batch-size": 64,
            "max-len": 120,
            "max-epoch": 150,
            "es-patience": 20,
            "es-threshold": 0.005,
            "gpu-id": 0,
            "model-dir": "data/snli/ppo/models/exp0",
            "logs-path": "data/snli/ppo/logs/exp0",
            "tensorboard-path": "data/snli/ppo/tensorboard/exp0"
            }

    # MultiNLI
    # args = {"nli": "multi_nli",
    #         "freeze-embeddings": "True",
    #         "use-batchnorm": "True",
    #         "dropout-prob": 0.1,
    #         "lower": "True",
    #         "mlp-hidden-dim": 1024,
    #         "word-dim": 300,
    #         "hidden-dim": 300,
    #         "parser-leaf-transformation": "lstm_transformation",
    #         "parser-trans-hidden_dim": 300,
    #         "tree-leaf-transformation": "lstm_transformation",
    #         "tree-trans-hidden_dim": 300,
    #         "baseline-type": "self_critical",
    #         "var-normalization": "True",
    #         "entropy-weight": 0.0,
    #         "clip-grad-norm": 0.0,
    #         "env-optimizer": "adadelta",
    #         "pol-optimizer": "adadelta",
    #         "env-lr": 1.0,
    #         "pol-lr": 1.0,
    #         "ppo-updates": 1,
    #         "epsilon": 0.2,
    #         "lr-scheduler-patience": 8,
    #         "lr-scheduler-threshold": 0.005,
    #         "l2-weight": 0.0,
    #         "batch-size": 64,
    #         "max-len": 120,
    #         "max-epoch": 150,
    #         "es-patience": 20,
    #         "es-threshold": 0.005,
    #         "gpu-id": 0,
    #         "model-dir": "data/multi_nli/ppo/models/exp0",
    #         "logs-path": "data/multi_nli/ppo/logs/exp0",
    #         "tensorboard-path": "data/multi_nli/ppo/tensorboard/exp0"
    #         }

    leaf_transformations = ["no_transformation", "lstm_transformation",
                            "bi_lstm_transformation", "conv_transformation"]
    optimizer_names = ["adam", "amsgrad", "sgd", "adadelta"]

    parser = argparse.ArgumentParser()
    parser.add_argument("--nli", default=args["nli"], choices=["multi_nli", "snli"])
    # String defaults are passed through ``type`` by argparse, so "True" becomes True.
    parser.add_argument("--freeze-embeddings", default=args["freeze-embeddings"], type=_parse_bool)
    parser.add_argument("--use-batchnorm", default=args["use-batchnorm"], type=_parse_bool)
    parser.add_argument("--dropout-prob", default=args["dropout-prob"], type=float)
    parser.add_argument("--lower", default=args["lower"], type=_parse_bool)
    parser.add_argument("--mlp-hidden-dim", default=args["mlp-hidden-dim"], type=int)
    parser.add_argument("--word-dim", required=False, default=args["word-dim"], type=int)
    parser.add_argument("--hidden-dim", required=False, default=args["hidden-dim"], type=int)
    parser.add_argument("--parser-leaf-transformation", required=False,
                        default=args["parser-leaf-transformation"], choices=leaf_transformations)
    parser.add_argument("--parser-trans-hidden_dim", required=False,
                        default=args["parser-trans-hidden_dim"], type=int)
    parser.add_argument("--tree-leaf-transformation", required=False,
                        default=args["tree-leaf-transformation"], choices=leaf_transformations)
    parser.add_argument("--tree-trans-hidden_dim", required=False,
                        default=args["tree-trans-hidden_dim"], type=int)
    parser.add_argument("--baseline-type", default=args["baseline-type"],
                        choices=["no_baseline", "ema", "self_critical"])
    parser.add_argument("--var-normalization", default=args["var-normalization"], type=_parse_bool)
    parser.add_argument("--entropy-weight", default=args["entropy-weight"], type=float)
    parser.add_argument("--clip-grad-norm", default=args["clip-grad-norm"], type=float,
                        help="If the value is less or equal to zero clipping is not performed.")
    parser.add_argument("--env-optimizer", required=False, default=args["env-optimizer"], choices=optimizer_names)
    parser.add_argument("--pol-optimizer", required=False, default=args["pol-optimizer"], choices=optimizer_names)
    parser.add_argument("--env-lr", required=False, default=args["env-lr"], type=float)
    parser.add_argument("--pol-lr", required=False, default=args["pol-lr"], type=float)
    parser.add_argument("--ppo-updates", required=False, default=args["ppo-updates"], type=int)
    parser.add_argument("--epsilon", required=False, default=args["epsilon"], type=float)
    parser.add_argument("--lr-scheduler-patience", required=False, default=args["lr-scheduler-patience"], type=int)
    parser.add_argument("--lr-scheduler-threshold", required=False, default=args["lr-scheduler-threshold"],
                        type=float)
    parser.add_argument("--l2-weight", required=False, default=args["l2-weight"], type=float)
    parser.add_argument("--batch-size", required=False, default=args["batch-size"], type=int)
    parser.add_argument("--max-len", default=args["max-len"], type=int)
    parser.add_argument("--max-epoch", required=False, default=args["max-epoch"], type=int)
    parser.add_argument("--es-patience", required=False, default=args["es-patience"], type=int)
    parser.add_argument("--es-threshold", required=False, default=args["es-threshold"], type=float)
    parser.add_argument("--gpu-id", required=False, default=args["gpu-id"], type=int)
    parser.add_argument("--model-dir", required=False, default=args["model-dir"], type=str)
    parser.add_argument("--logs-path", required=False, default=args["logs-path"], type=str)
    parser.add_argument("--tensorboard-path", required=False, default=args["tensorboard-path"], type=str)

    global_step = 0
    best_model_path = None
    args = parser.parse_args()
    with torch.cuda.device(args.gpu_id):
        main(args)