""" Script to pretrain the LSTM -> linear model for definition to attributes """ from data.dictionary_dataset import DictionaryChallengeDataset from lib.bucket_iterator import DictionaryChallengeIter from config import ModelConfig from lib.att_prediction import DictionaryModel from torch import optim import os import torch from lib.misc import CosineRankingLoss, optimize, cosine_ranking_loss, get_cosine_ranking import numpy as np import time from torch.nn.utils.rnn import pad_packed_sequence # Recommended hyperparameters args = ModelConfig(margin=0.1, lr=1e-4, batch_size=64, eps=1e-8, save_dir='def2atts_pretrain', dropout=0.2) train_data, val_data = DictionaryChallengeDataset.splits() train_iter = DictionaryChallengeIter(train_data, batch_size=args.batch_size) val_iter = DictionaryChallengeIter(val_data, batch_size=args.batch_size * 10, sort=False, shuffle=False) m = DictionaryModel(train_data.fields['text'].vocab, 300) optimizer = optim.Adam(m.parameters(), lr=args.lr, eps=args.eps, betas=(args.beta1, args.beta2)) crit = CosineRankingLoss(size_average=True, margin=args.margin) if torch.cuda.is_available(): m.cuda() crit.cuda() train_data.embeds = train_data.embeds.cuda() val_data.embeds = val_data.embeds.cuda() @optimize def train_batch(word_inds, defns, optimizers=None): pred_embs = m(defns) return crit(pred_embs, train_data.embeds[word_inds]) def deploy(word_inds, defns): pred_embs = m(defns) cost, correct_contrib, inc_contrib = cosine_ranking_loss(pred_embs, val_data.embeds[word_inds], margin=args.margin) cost = cost.data.cpu().numpy() correct_contrib = correct_contrib.data.cpu().numpy() rank, ranking = get_cosine_ranking(pred_embs, val_data.embeds, word_inds) return cost, correct_contrib, rank, ranking def log_val(word_inds, defns, cost, rank, ranking, num_ex=10): print("mean rank {:.1f}------------------------------------------".format(np.mean(rank))) engl_defns, ls = pad_packed_sequence(defns, batch_first=True, padding_value=0) spacing = np.linspace(len(ls) // num_ex, len(ls), endpoint=False, num=num_ex, dtype=np.int64) engl_defns = [' '.join([val_data.fields['text'].vocab.itos[x] for x in d[1:(l - 1)]]) for d, l in zip(engl_defns.cpu().data.numpy()[spacing], [ls[s] for s in spacing])] top_scorers = [[val_data.fields['label'].vocab.itos[int(x)] for x in t] for t in ranking.data.cpu().numpy()[spacing, :3]] words = [val_data.fields['label'].vocab.itos[int(wi)] for wi in word_inds.cpu().numpy()[spacing]] for w, (word, rank_, top3, l, defn) in enumerate( zip(words, rank, top_scorers, cost, engl_defns)): print("w{:2d}/{:2d}, R{:5d} {:>30} ({:.3f}){:>13}: {}".format( w, 64, rank_, ' '.join(top3), l, word, defn)) print("------------------------------------------------------------") last_best_epoch = 1 prev_best = 0.0 for epoch in range(1, 101): val_l = [] val_l_correct = [] train_l = [] m.eval() for val_b, (word_inds, defns) in enumerate(val_iter): cost, correct_contrib, rank, ranking = deploy(word_inds, defns) val_l.append(np.mean(cost)) val_l_correct.append(np.mean(correct_contrib)) if val_b == 0: log_val(word_inds, defns, cost, rank, ranking, num_ex=10) print("--- \n E{:2d} (VAL) Cost {:.3f} Correct score {:.3f} \n --- \n".format( epoch, np.mean(val_l), np.mean(val_l_correct), ), flush=True) if np.mean(val_l_correct) > prev_best: prev_best = np.mean(val_l_correct) last_best_epoch = epoch else: if last_best_epoch < (epoch - 3): print("Early stopping at epoch {}".format(epoch)) break m.train() start_epoch = time.time() for b, (word_inds, defns) in enumerate(train_iter): start = 
last_best_epoch = 1
prev_best = 0.0
for epoch in range(1, 101):
    val_l = []
    val_l_correct = []
    train_l = []

    # Validation pass.
    m.eval()
    for val_b, (word_inds, defns) in enumerate(val_iter):
        cost, correct_contrib, rank, ranking = deploy(word_inds, defns)
        val_l.append(np.mean(cost))
        val_l_correct.append(np.mean(correct_contrib))
        if val_b == 0:
            log_val(word_inds, defns, cost, rank, ranking, num_ex=10)

    print("--- \n E{:2d} (VAL) Cost {:.3f} Correct score {:.3f} \n --- \n".format(
        epoch, np.mean(val_l), np.mean(val_l_correct)), flush=True)

    # Early stopping: quit if the correct-pair score hasn't improved for more
    # than 3 epochs.
    if np.mean(val_l_correct) > prev_best:
        prev_best = np.mean(val_l_correct)
        last_best_epoch = epoch
    elif last_best_epoch < (epoch - 3):
        print("Early stopping at epoch {}".format(epoch))
        break

    # Training pass.
    m.train()
    start_epoch = time.time()
    for b, (word_inds, defns) in enumerate(train_iter):
        start = time.time()
        l = train_batch(word_inds, defns, optimizers=[optimizer])
        train_l.append(l)
        dur = time.time() - start

        if b % 1000 == 0 and b >= 100:
            print("e{:2d}b{:5d} Cost {:.3f} , {:.3f} s/batch".format(
                epoch, b, np.mean(train_l), dur), flush=True)

    dur_epoch = time.time() - start_epoch
    print("Duration of epoch was {:.3f} s/batch, overall loss was {:.3f}".format(
        dur_epoch / b, np.mean(train_l)))

    torch.save({
        'args': args.args,
        'epoch': epoch,
        'm_state_dict': m.state_dict(),
        'optimizer': optimizer.state_dict(),
    }, os.path.join(args.save_dir, 'ckpt_{}.tar'.format(epoch)))
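

# ---------------------------------------------------------------------------
# Usage sketch for resuming (not called above): the helper name and the idea
# of restarting from a saved epoch are assumptions, but the dictionary keys
# match the torch.save call in the training loop.
# ---------------------------------------------------------------------------
def load_checkpoint(epoch_to_load):
    """Restore model and optimizer state from ckpt_<epoch>.tar and return the
    epoch it was saved at."""
    ckpt = torch.load(os.path.join(args.save_dir, 'ckpt_{}.tar'.format(epoch_to_load)))
    m.load_state_dict(ckpt['m_state_dict'])
    optimizer.load_state_dict(ckpt['optimizer'])
    return ckpt['epoch']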