# Copyright (c) 2018, salesforce.com, inc.
# All rights reserved.
# Licensed under the BSD 3-Clause license.
# For full license text, see the LICENSE file in the repo root
# or https://opensource.org/licenses/BSD-3-Clause
import torch
import numpy as np
from torchtext import data
from torchtext import datasets
from torch.nn import functional as F
from torch.autograd import Variable

import revtok
import logging
import random
import string
import traceback
import math
import uuid
import argparse
import os
import copy
import time

from tqdm import tqdm, trange
from model import Transformer, FastTransformer, INF, TINY, softmax
from utils import NormalField, NormalTranslationDataset, TripleTranslationDataset, ParallelDataset
from utils import Metrics, Best, computeGLEU, computeBLEU, Cache, Batch, masked_sort, unsorted, computeGroupBLEU
from time import gmtime, strftime

import sys
from traceback import extract_tb
from code import interact


def interactive_exception(e_class, e_value, tb):
    """Exception hook that opens an interactive console in every frame.

    The default hook is invoked first so the usual traceback is printed.
    Then, starting from the innermost frame and walking outwards, the frame's
    location is printed and ``code.interact`` is started with that frame's
    local variables.

    Args:
        e_class: exception class, as passed to ``sys.excepthook``.
        e_value: exception instance.
        tb: traceback object of the exception.
    """
    sys.__excepthook__(e_class, e_value, tb)
    tb_stack = extract_tb(tb)
    locals_stack = []
    while tb is not None:
        locals_stack.append(tb.tb_frame.f_locals)
        tb = tb.tb_next
    # Both stacks are ordered outermost -> innermost; pop() visits the
    # innermost frame first.
    while tb_stack:
        frame = tb_stack.pop()
        ls = locals_stack.pop()
        print('\nInterpreter at file "{}", line {}, in {}:'.format(
            frame.filename, frame.lineno, frame.name))
        # FrameSummary.line is None when the source text is unavailable.
        print(' {}'.format((frame.line or '').strip()))
        interact(local=ls)

#sys.excepthook = interactive_exception

# check dirs
# makedirs(exist_ok=True) avoids the check-then-create race of
# os.path.exists() + os.mkdir() when several runs start at the same time.
for d in ['models', 'runs', 'logs']:
    os.makedirs('./{}'.format(d), exist_ok=True)

# params
parser = argparse.ArgumentParser(description='Train a Transformer model.')

# data
parser.add_argument('--data_prefix', type=str, default='../data/')
parser.add_argument('--dataset', type=str, default='iwslt', help='"flickr" or "iwslt"')
parser.add_argument('--language', type=str, default='ende', help='a combination of two language markers to show the language pair.')
parser.add_argument('--load_vocab', action='store_true', help='load a pre-computed vocabulary')
parser.add_argument('--load_dataset', action='store_true', help='load a pre-processed dataset')
parser.add_argument('--use_revtok', action='store_true', help='use reversible tokenization')
parser.add_argument('--level', type=str, default='subword', help='for BPE, we must preprocess the dataset')
parser.add_argument('--good_course', action='store_true', help='use beam-search output for distillation')
parser.add_argument('--test_set', type=str, default=None, help='which test set to use')
parser.add_argument('--max_len', type=int, default=None, help='limit the train set sentences to this many tokens')
parser.add_argument('--remove_eos', action='store_true', help='possibly remove <eos> tokens for FastTransformer')

# model basic
parser.add_argument('--prefix', type=str, default='', help='prefix to denote the model, nothing or [time]')
parser.add_argument('--params', type=str, default='james-iwslt', help='pamarater sets: james-iwslt, t2t-base, etc')
parser.add_argument('--fast', dest='model', action='store_const', const=FastTransformer, default=Transformer, help='use a single self-attn stack')

# model variants
parser.add_argument('--local', dest='windows', action='store_const', const=[1, 3, 5, 7, -1], default=None, help='use local attention')
parser.add_argument('--causal', action='store_true', help='use causal attention')
parser.add_argument('--positional_attention', action='store_true', help='incorporate positional information in key/value')
parser.add_argument('--no_source', action='store_true')
parser.add_argument('--use_mask', action='store_true', help='use src/trg mask during attention')
parser.add_argument('--diag', action='store_true', help='ignore diagonal attention when doing self-attention.')
parser.add_argument('--convblock', action='store_true', help='use ConvBlock instead of ResNet')
parser.add_argument('--cosine_output', action='store_true', help='use cosine similarity as output layer')
parser.add_argument('--noisy', action='store_true', help='inject noise in the attention mechanism: Beta-Gumbel softmax')
parser.add_argument('--noise_samples', type=int, default=0, help='only useful for noisy parallel decoding')
parser.add_argument('--critic', action='store_true', help='use critic')
parser.add_argument('--kernel_sizes', type=str, default='2,3,4,5', help='kernel sizes of convnet critic')
parser.add_argument('--kernel_num', type=int, default=128, help='number of each kind of kernel')
parser.add_argument('--use_wo', action='store_true', help='use output weight matrix in multihead attention')
parser.add_argument('--share_embeddings', action='store_true', help='share embeddings between encoder and decoder')
parser.add_argument('--use_alignment', action='store_true', help='use the aligned fake data to initialize')
parser.add_argument('--hard_inputs', action='store_true', help='use hard selection as inputs, instead of soft-attention over embeddings.')
parser.add_argument('--preordering', action='store_true', help='use the ground-truth reordering information')
parser.add_argument('--use_posterior_order', action='store_true', help='directly use the groud-truth alignment for reordering.')
parser.add_argument('--train_decoder_with_order', action='store_true', help='when training the decoder, use the ground-truth')
parser.add_argument('--postordering', action='store_true', help='just have a try...')
parser.add_argument('--fertility_only', action='store_true')
parser.add_argument('--highway', action='store_true', help='usually false')
parser.add_argument('--mix_of_experts', action='store_true')
parser.add_argument('--orderless', action='store_true', help='for the inputs, remove the order information')
parser.add_argument('--cheating', action='store_true', help='disable decoding, always use real fertility')

# running
parser.add_argument('--mode', type=str, default='train', help='train, test or build')
parser.add_argument('--gpu', type=int, default=0, help='GPU to use or -1 for CPU')
parser.add_argument('--seed', type=int, default=19920206, help='seed for randomness')
parser.add_argument('--eval-every', type=int, default=1000, help='run dev every')
parser.add_argument('--maximum_steps', type=int, default=1000000, help='maximum steps you take to train a model')
parser.add_argument('--disable_lr_schedule', action='store_true', help='disable the transformer learning rate')
parser.add_argument('--batchsize', type=int, default=2048, help='# of tokens processed per batch')
parser.add_argument('--hidden_size', type=int, default=None, help='input the hidden size')
parser.add_argument('--length_ratio', type=int, default=2, help='maximum lengths of decoding')
parser.add_argument('--optimizer', type=str, default='Adam')
parser.add_argument('--beam_size', type=int, default=1, help='beam-size used in Beamsearch, default using greedy decoding')
parser.add_argument('--alpha', type=float, default=0.6, help='length normalization weights')
parser.add_argument('--temperature', type=float, default=1, help='smoothing temperature for noisy decoding')
parser.add_argument('--multi_run', type=int, default=1, help='we can run the code multiple times to get the best')

parser.add_argument('--load_from', type=str, default=None, help='load from checkpoint')
parser.add_argument('--resume', action='store_true', help='when loading from the saved model, it resumes from that.')
parser.add_argument('--teacher', type=str, default=None, help='load a pre-trained auto-regressive model.')
parser.add_argument('--share_encoder', action='store_true', help='use teacher-encoder to initialize student')
parser.add_argument('--finetune_encoder', action='store_true', help='if further train the encoder')

parser.add_argument('--seq_dist', action='store_true', help='knowledge distillation at sequence level')
parser.add_argument('--word_dist', action='store_true', help='knowledge distillation at word level')
parser.add_argument('--greedy_fertility', action='store_true', help='using the fertility generated by autoregressive model (only for seq_dist)')
# The iwslt setup below reads args.dist_set; without this option that read
# raised AttributeError. Mirrors --test_set (None means "use the default").
parser.add_argument('--dist_set', type=str, default=None, help='value of args.dist_set; the iwslt setup defaults it to ".dec.b1"')

parser.add_argument('--fertility_mode', type=str, default='argmax', help='mean, argmax or reinforce')
parser.add_argument('--finetuning_truth', action='store_true', help='use ground-truth for finetuning')

parser.add_argument('--trainable_teacher', action='store_true', help='have a trainable teacher')
parser.add_argument('--only_update_errors', action='store_true', help='have a trainable teacher')
parser.add_argument('--teacher_use_real', action='store_true', help='teacher also trained with MLE on real data')
parser.add_argument('--max_cache', type=int, default=0, help='save most recent max_cache decoded translations')
parser.add_argument('--replay_every', type=int, default=1000, help='every 1k updates, train the teacher again')
parser.add_argument('--replay_times', type=int, default=250, help='train the teacher again for 250k steps')

parser.add_argument('--margin', type=float, default=1.5, help='margin to make sure teacher will give higher score to real data')
parser.add_argument('--real_data', action='store_true', help='only used in the reverse kl setting')
parser.add_argument('--beta1', type=float, default=0.5, help='balancing MLE and KL loss.')
parser.add_argument('--beta2', type=float, default=0.01, help='balancing the GAN loss.')
parser.add_argument('--critic_only', type=int, default=0, help='pre-training the critic model.')
parser.add_argument('--st', action='store_true', help='straight through estimator')
parser.add_argument('--entropy', action='store_true')

parser.add_argument('--no_bpe', action='store_true', help='output files without BPE')
parser.add_argument('--no_write', action='store_true', help='do not write the decoding into the decoding files.')
parser.add_argument('--output_fer', action='store_true', help='decoding and output fertilities')

# debugging
parser.add_argument('--check', action='store_true', help='on training, only used to check on the test set.')
parser.add_argument('--debug', action='store_true', help='debug mode: no saving or tensorboard')
parser.add_argument('--tensorboard', action='store_true', help='use TensorBoard')

# old params
parser.add_argument('--old', action='store_true', help='this is used for solving conflicts of new codes')
parser.add_argument('--hyperopt', action='store_true', help='use HyperOpt')
parser.add_argument('--scst', action='store_true', help='use HyperOpt')
parser.add_argument('--serve', type=int, default=None, help='serve at port')
parser.add_argument('--attention_discrimination', action='store_true')
# ---------------------------------------------------------------------------------------------------------------- #

args = parser.parse_args()
if args.prefix == '[time]':
    args.prefix = strftime("%m.%d_%H.%M.", gmtime())
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]

# get the language pair: first two characters are the source, the rest the target
args.src = args.language[:2]  # source language
args.trg = args.language[2:]  # target language

# logger settings: everything at DEBUG and above goes to both console and file
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s %(levelname)s: - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

fh = logging.FileHandler('./logs/log-{}.txt'.format(args.prefix))
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
ch.setFormatter(formatter)

logger.addHandler(ch)
logger.addHandler(fh)

# setup random seeds
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# setup data-field
DataField = data.ReversibleField if args.use_revtok else NormalField
tokenizer = revtok.tokenize if args.use_revtok else lambda x: x.replace('@@ ', '').split()

TRG = DataField(init_token='<init>', eos_token='<eos>', batch_first=True)
SRC = DataField(batch_first=True) if not args.share_embeddings else TRG
# Alignment tokens look like "<int>-..."; only the part before the first '-' is kept.
ALIGN = data.Field(sequential=True, preprocessing=data.Pipeline(lambda tok: int(tok.split('-')[0])),
                   use_vocab=False, pad_token=0, batch_first=True)
FER = data.Field(sequential=True, preprocessing=data.Pipeline(lambda tok: int(tok)),
                 use_vocab=False, pad_token=0, batch_first=True)
align_dict, align_table = None, None


def _load_alignment_dict(path):
    """Read a whitespace-separated file into ``{first column: second column}``.

    Uses a context manager so the file handle is closed deterministically
    (the previous inline ``open()`` inside a comprehension never closed it).
    """
    with open(path) as dict_file:
        return {cols[0]: cols[1] for cols in (line.split() for line in dict_file)}


# Field layouts shared by several dataset branches below.
_real_exts = ('.{}'.format(args.src), '.{}'.format(args.trg))
_decode_exts = ('.src.b1', '.trg.b1', '.dec.b1', '.real.aligned', '.fake.aligned', '.real.fer', '.fake.fer')
_ts_fields = [('src', SRC), ('trg', TRG), ('dec', TRG), ('align', ALIGN), ('align_dec', ALIGN),
              ('fer', FER), ('fer_dec', FER)]
_is_test_mode = (args.mode == 'test') or (args.mode == 'test_noisy')
_wants_alignment = args.use_alignment and (args.model is FastTransformer)

# NOTE(review): block nesting in this section was reconstructed from a
# whitespace-collapsed source; confirm against the original layout.

# setup many datasets (need to manually setup)
data_prefix = args.data_prefix
if args.dataset == 'iwslt':
    if args.test_set is None:
        args.test_set = 'IWSLT16.TED.tst2013'
    if args.dist_set is None:
        args.dist_set = '.dec.b1'

    # This used to be an `elif` of the dist_set default above, so whenever the
    # default was applied no dataset was loaded at all. It is now its own chain.
    if args.greedy_fertility:
        logger.info('use the fertility predicted by autoregressive model (instead of fast-align)')
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'iwslt/en-de/', train='train.en-de.bpe.new',
            validation='IWSLT16.TED.tst2013.en-de.bpe.new.dev',
            exts=('.src.b1', '.trg.b1', '.dec.b1', '.fer', '.fer'),
            fields=[('src', SRC), ('trg', TRG), ('dec', TRG), ('fer', FER), ('fer_dec', FER)],
            load_dataset=args.load_dataset, prefix='ts')

    elif _is_test_mode:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'iwslt/en-de/',
            train='train.tags.en-de{}'.format('.bpe' if not args.use_revtok else ''),
            validation='{}.en-de{}'.format(args.test_set, '.bpe' if not args.use_revtok else ''),
            exts=('.en', '.de'),
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='normal')

    else:
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'iwslt/en-de/', train='train.tags.en-de.bpe',
            validation='train.tags.en-de.bpe.dev',
            exts=('.en2', '.de2', '.decoded2', '.aligned', '.decode.aligned', '.fer', '.decode.fer'),
            fields=_ts_fields,
            load_dataset=args.load_dataset, prefix='ts')

    decoding_path = data_prefix + 'iwslt/en-de/{}.en-de.bpe.new'

    if _wants_alignment:
        align_dict = _load_alignment_dict(data_prefix + 'iwslt/en-de/train.tags.en-de.dict')

elif args.dataset == 'wmt16-ende':
    if args.test_set is None:
        args.test_set = 'newstest2013'

    if _is_test_mode:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-ende/', train='newstest2013.tok.bpe.32000',
            validation='{}.tok.bpe.32000'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-ende/test.{}.{}'.format(args.prefix, args.test_set)

    elif not args.seq_dist:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-ende/', train='train.tok.clean.bpe.32000',
            validation='{}.tok.bpe.32000'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-ende/{}.tok.bpe.decode'

    else:
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'wmt16-ende/', train='train.tok.bpe.decode',
            validation='newstest2013.tok.bpe.decode.dev',
            exts=_decode_exts,
            fields=_ts_fields,
            load_dataset=args.load_dataset, prefix='ts')
        decoding_path = data_prefix + 'wmt16-ende/{}.tok.bpe.na'

    if _wants_alignment:
        # Stored in align_dict (as for iwslt) so it is converted into an index
        # table below; the raw str->str dict was previously put in align_table,
        # which cannot be turned into a LongTensor later on.
        align_dict = _load_alignment_dict(data_prefix + 'wmt16-ende/train.tok.bpe.decode.full.fastlign2.dict')

elif args.dataset == 'wmt16-deen':
    if args.test_set is None:
        args.test_set = 'newstest2013'

    if _is_test_mode:
        # NOTE(review): test mode reads from and decodes into 'wmt16-ende/'
        # even for the deen dataset; kept as in the original - confirm intent.
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-ende/', train='newstest2013.tok.bpe.32000',
            validation='{}.tok.bpe.32000'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-ende/test.{}.{}'.format(args.prefix, args.test_set)

    elif not args.seq_dist:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-deen/', train='train.tok.clean.bpe.32000',
            validation='{}.tok.bpe.32000'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-deen/{}.tok.bpe.decode'

    else:
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'wmt16-deen/', train='train.tok.bpe.decode',
            validation='{}.tok.bpe.decode.dev'.format(args.test_set),
            exts=_decode_exts,
            fields=_ts_fields,
            load_dataset=args.load_dataset, prefix='ts')
        decoding_path = data_prefix + 'wmt16-deen/{}.tok.bpe.na'

    if _wants_alignment:
        align_dict = _load_alignment_dict(data_prefix + 'wmt16-deen/train.tok.bpe.decode.full.fastlign2.dict')

elif args.dataset == 'wmt16-enro':
    if args.test_set is None:
        args.test_set = 'dev'

    if _is_test_mode:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-enro/', train='dev.bpe',
            validation='{}.bpe'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-enro/{}.bpe.decode'

    elif not args.seq_dist:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-enro/', train='corpus.bpe',
            validation='{}.bpe'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-enro/{}.bpe.decode'

    else:
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'wmt16-enro/', train='train.bpe.decode',
            validation='dev.bpe.decode.dev',
            exts=_decode_exts,
            fields=_ts_fields,
            load_dataset=args.load_dataset, prefix='ts')
        decoding_path = data_prefix + 'wmt16-enro/{}.tok.bpe.na'

    if _wants_alignment:
        align_dict = _load_alignment_dict(data_prefix + 'wmt16-enro/train.bpe.decode.full.fastlign2.dict')

elif args.dataset == 'wmt16-roen':
    if args.test_set is None:
        args.test_set = 'dev'

    if _is_test_mode:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-roen/', train='dev.bpe',
            validation='{}.bpe'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-roen/{}.bpe.decode'

    elif not args.seq_dist:
        train_data, dev_data = NormalTranslationDataset.splits(
            path=data_prefix + 'wmt16-roen/', train='corpus.bpe',
            validation='{}.bpe'.format(args.test_set),
            exts=_real_exts,
            fields=(SRC, TRG), load_dataset=args.load_dataset, prefix='real')
        decoding_path = data_prefix + 'wmt16-roen/{}.bpe.decode'

    else:
        train_data, dev_data = ParallelDataset.splits(
            path=data_prefix + 'wmt16-roen/', train='train.bpe.decode',
            validation='dev.bpe.decode.dev',
            exts=_decode_exts,
            fields=_ts_fields,
            load_dataset=args.load_dataset, prefix='ts')
        decoding_path = data_prefix + 'wmt16-roen/{}.tok.bpe.na'

    if _wants_alignment:
        align_dict = _load_alignment_dict(data_prefix + 'wmt16-roen/train.bpe.decode.full.fastlign2.dict')

else:
    raise NotImplementedError

# build word-level vocabularies
# The cache file name encodes the dataset, whether embeddings are shared, and
# the language pair.
vocab_path = data_prefix + '{}/vocab{}_{}.pt'.format(
    args.dataset, 'shared' if args.share_embeddings else '', '{}-{}'.format(args.src, args.trg))

if args.load_vocab and os.path.exists(vocab_path):
    logger.info('load saved vocabulary.')
    src_vocab, trg_vocab = torch.load(vocab_path)
    SRC.vocab = src_vocab
    TRG.vocab = trg_vocab
else:
    logger.info('save the vocabulary')
    if not args.share_embeddings:
        SRC.build_vocab(train_data, dev_data, max_size=50000)
    # With shared embeddings SRC is the same object as TRG, so this one call
    # builds the vocabulary for both.
    TRG.build_vocab(train_data, dev_data, max_size=50000)
    torch.save([SRC.vocab, TRG.vocab], vocab_path)
args.__dict__.update({'trg_vocab': len(TRG.vocab), 'src_vocab': len(SRC.vocab)})

# build alignments ---
# Index table with one entry per source vocabulary id; entries without a
# dictionary match keep the target id of '<init>'.
if align_dict is not None:
    align_table = [TRG.vocab.stoi['<init>'] for _ in range(len(SRC.vocab.itos))]
    for src in align_dict:
        align_table[SRC.vocab.stoi[src]] = TRG.vocab.stoi[align_dict[src]]
    align_table[0] = 0  # --<unk>
    align_table[1] = 1  # --<pad>


def dyn_batch_with_padding(new, i, sofar):
    """Batch-size function that counts padded tokens.

    Args:
        new: example being added; ``src`` and ``trg`` (and ``dec`` when
            ``args.seq_dist`` is set) are measured with ``len``.
        i: number of examples in the batch including ``new``.
        sofar: value this function returned for the previous example.

    Returns:
        ``i`` times the longest sequence length seen in the batch so far.
    """
    # sofar == previous_max_len * (i - 1), so dividing recovers that maximum.
    prev_max_len = sofar / (i - 1) if i > 1 else 0
    if args.seq_dist:
        return max(len(new.src), len(new.trg), len(new.dec), prev_max_len) * i
    return max(len(new.src), len(new.trg), prev_max_len) * i


def dyn_batch_without_padding(new, i, sofar):
    """Batch-size function that adds up per-example lengths (no padding).

    Args:
        new: example being added; ``src`` and ``trg`` (and ``dec`` when
            ``args.seq_dist`` is set) are measured with ``len``.
        i: number of examples in the batch including ``new`` (unused).
        sofar: running total returned for the previous example.

    Returns:
        ``sofar`` plus the longest of the measured sequences of ``new``.
    """
    if args.seq_dist:
        return sofar + max(len(new.src), len(new.trg), len(new.dec))
    return sofar + max(len(new.src), len(new.trg))


# build the dataset iterators

# work around torchtext making it hard to share vocabs without sharing other field properties
if args.share_embeddings:
    SRC = copy.deepcopy(SRC)
    SRC.init_token = None
    SRC.eos_token = None
    train_data.fields['src'] = SRC
    dev_data.fields['src'] = SRC
# Optionally drop the <eos> token from the target field (FastTransformer only).
if args.remove_eos and (args.model is FastTransformer):
    TRG.eos_token = None

# Keep only training examples whose target has at most max_len tokens.
if args.max_len is not None:
    train_data.examples = list(
        filter(lambda ex: len(ex.trg) <= args.max_len, train_data.examples))

# Choose how the iterator measures the size of a batch.
if args.batchsize == 1:  # speed-test: one sentence per batch.
    def batch_size_fn(new, count, sofar):
        return count
elif args.model is Transformer:
    batch_size_fn = dyn_batch_without_padding
else:
    batch_size_fn = dyn_batch_with_padding

_in_train_mode = args.mode == 'train'
train_real, dev_real = data.BucketIterator.splits(
    (train_data, dev_data),
    batch_sizes=(args.batchsize, args.batchsize),
    device=args.gpu,
    batch_size_fn=batch_size_fn,
    repeat=None if _in_train_mode else False)
logger.info("build the dataset. done!")


def _t2t_base_hparams():
    """Return a fresh copy of the t2t-base hyper-parameter set."""
    return {'d_model': 512, 'd_hidden': 512, 'n_layers': 6, 'n_heads': 8,
            'drop_ratio': 0.1, 'warmup': 16000}  # ~32


# model hyper-params:
# NOTE(review): `teacher_hparams` is assumed to sit at the iwslt-branch level
# (indentation was lost in this source) - confirm.
hparams = None
if args.dataset == 'iwslt':
    _iwslt_presets = {
        'james-iwslt': {'d_model': 278, 'd_hidden': 507, 'n_layers': 5, 'n_heads': 2,
                        'drop_ratio': 0.079, 'warmup': 746},  # ~32
        'james-iwslt2': {'d_model': 278, 'd_hidden': 2048, 'n_layers': 5, 'n_heads': 2,
                         'drop_ratio': 0.079, 'warmup': 746},  # ~32
    }
    # Unknown --params values leave hparams as None and hit the fallback below.
    hparams = _iwslt_presets.get(args.params)
    teacher_hparams = {'d_model': 278, 'd_hidden': 507, 'n_layers': 5, 'n_heads': 2,
                       'drop_ratio': 0.079, 'warmup': 746}
elif args.dataset in ('wmt16-ende', 'wmt16-deen', 'wmt16-enro', 'wmt16-roen'):
    logger.info('use default parameters of t2t-base')
    hparams = _t2t_base_hparams()
    teacher_hparams = hparams

if hparams is None:
    logger.info('use default parameters of t2t-base')
    hparams = _t2t_base_hparams()

# The teacher gets its own copy of the arguments carrying the teacher settings.
if args.teacher is not None:
    teacher_args = copy.deepcopy(args)
    vars(teacher_args).update(teacher_hparams)

vars(args).update(hparams)
if args.hidden_size is not None:
    args.d_hidden = args.hidden_size

# show the arg:
logger.info(args)

# Run identifier; the xe_* fields are empty when the attribute is not set.
_fast_tag = 'fast_' if args.model is FastTransformer else ''
_xe_until = getattr(args, 'xe_until', '')
_xe_ratio = f'{args.xe_ratio:.3f}' if hasattr(args, 'xe_ratio') else ''
_xe_every = getattr(args, 'xe_every', '')
hp_str = (
    f"{args.dataset}_{args.level}_{_fast_tag}"
    f"{args.d_model}_{args.d_hidden}_{args.n_layers}_{args.n_heads}_"
    f"{args.drop_ratio:.3f}_{args.warmup}_"
    f"{_xe_until}_{_xe_ratio}_{_xe_every}"
)
logger.info(f'Starting with HPARAMS: {hp_str}')
model_name = f'./models/{args.prefix}{hp_str}'

# build the model
model = args.model(SRC, TRG, args)
if args.load_from is not None:
    _checkpoint_path = './models/' + args.load_from + '.pt'
    with torch.cuda.device(args.gpu):  # very important.
        # load the pretrained models.
        _state = torch.load(_checkpoint_path,
                            map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(_state)
# NOTE(review): block nesting below was reconstructed from a
# whitespace-collapsed source; confirm against the original layout.
if args.critic:
    model.install_critic()

# logger.info(str(model))

# if using a teacher: build it, load its checkpoint, and set whether its
# parameters receive gradients
if args.teacher is not None:
    teacher_model = Transformer(SRC, TRG, teacher_args)
    with torch.cuda.device(args.gpu):
        teacher_model.load_state_dict(torch.load('./models/' + args.teacher + '.pt',
                                                 map_location=lambda storage, loc: storage.cuda()))
    for params in teacher_model.parameters():
        if args.trainable_teacher:
            params.requires_grad = True
        else:
            params.requires_grad = False

    # Initialise the student's encoder from a copy of the teacher's encoder,
    # but only when the student was not itself loaded from a checkpoint.
    if (args.share_encoder) and (args.load_from is None):
        model.encoder = copy.deepcopy(teacher_model.encoder)
        for params in model.encoder.parameters():
            if args.finetune_encoder:
                params.requires_grad = True
            else:
                params.requires_grad = False

else:
    teacher_model = None

# use cuda
if args.gpu > -1:
    model.cuda(args.gpu)
    if align_table is not None:
        # Attach the alignment table to the model as a LongTensor Variable.
        align_table = torch.LongTensor(align_table).cuda(args.gpu)
        align_table = Variable(align_table)
        model.alignment = align_table

    if args.teacher is not None:
        teacher_model.cuda(args.gpu)


def register_nan_checks(m):
    """Register a backward hook on every sub-module of ``m`` that aborts on NaN.

    The hook prints the type name of the module whose incoming gradient
    contains a NaN and then deliberately raises ZeroDivisionError.
    """
    def check_grad(module, grad_input, grad_output):
        if any(np.any(np.isnan(gi.data.cpu().numpy())) for gi in grad_input if gi is not None):
            print('NaN gradient in ' + type(module).__name__)
            1/0  # deliberate crash so the failure cannot go unnoticed
    m.apply(lambda module: module.register_backward_hook(check_grad))


def get_learning_rate(i, lr0=0.1):
    """Return the learning rate for update step ``i``.

    With the schedule enabled the rate is
    ``lr0 * 10 / sqrt(d_model) * min(1 / sqrt(i), i / warmup ** 1.5)``,
    so ``i`` must be positive. With ``--disable_lr_schedule`` a constant
    ``0.00002`` is returned.
    """
    if not args.disable_lr_schedule:
        return lr0 * 10 / math.sqrt(args.d_model) * min(
            1 / math.sqrt(i), i / (args.warmup * math.sqrt(args.warmup)))
    return 0.00002


def export(x):
    """Return the mean of ``x.data`` as computed on CPU in float precision.

    Any exception (for example ``x`` having no ``.data``) is swallowed and
    ``0`` is returned instead.
    """
    try:
        with torch.cuda.device(args.gpu):
            return x.data.cpu().float().mean()
    except Exception:
        return 0


def devol(batch):
    """Return a shallow copy of ``batch`` whose ``src`` is a volatile Variable.

    Only ``src`` is re-wrapped; all other attributes are shared with the
    original batch.
    """
    new_batch = copy.copy(batch)
    new_batch.src = Variable(batch.src.data, volatile=True)
    return new_batch


# register_nan_checks(model)
# register_nan_checks(teacher_model)


def valid_model(model, dev, dev_metrics=None, distillation=False, print_out=False, teacher_model=None):
    """Decode every batch of ``dev`` with ``model`` and score the output.

    Args:
        model: model to evaluate; put into eval mode here. The FastTransformer
            path additionally calls ``prepare_initial``.
        dev: iterable of dev batches.
        dev_metrics: optional accumulator; when given, per-batch values are
            passed to ``dev_metrics.accumulate`` and it is logged at the end.
        distillation: forwarded to ``model.quick_prepare``; for a
            FastTransformer it also selects ``fer_dec`` instead of ``fer``.
        print_out: when true, log the first sequence of each output kind for
            every batch.
        teacher_model: optional teacher; when given, its losses on the real
            and the decoded sequences are added to the accumulated values.

    Returns:
        dict with keys ``'corpus_gleu'`` and ``'corpus_bleu'``.
    """
    print_seqs = ['[sources]', '[targets]', '[decoded]', '[fertili]', '[origind]']
    trg_outputs, dec_outputs = [], []
    outputs = {}

    model.eval()
    if teacher_model is not None:
        teacher_model.eval()

    for j, dev_batch in enumerate(dev):

        # decode from the model (whatever Transformer or FastTransformer)
        # Each stage is wrapped in an nvtx range; push/pop calls must stay paired.
        torch.cuda.nvtx.range_push('quick_prepare')
        inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(dev_batch, distillation)
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push('prepare_initial')
        decoder_inputs, input_reorder, reordering_cost = inputs, None, None
        if type(model) is FastTransformer:
            # batch_align = dev_batch.align_dec if distillation else dev_batch.align
            batch_align = None
            batch_fer = dev_batch.fer_dec if distillation else dev_batch.fer

            # if args.postordering:
            #
            #     targets_sorted = targets.gather(1, align_index)
            # batch_align_sorted, align_index = masked_sort(batch_align, target_masks)  # change the target index, batch x max_trg

            decoder_inputs, input_reorder, decoder_masks, reordering_cost = model.prepare_initial(encoding, sources, source_masks, input_masks,
                                                                                                batch_align, batch_fer, decoding=(not args.cheating), mode='argmax')
        else:
            decoder_masks = input_masks
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push('model')
        decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True)
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push('batched_cost')
        loss = 0  # the loss computation below is commented out, so this stays 0
        if args.postordering:
            if args.cheating:
                # NOTE(review): `align_index` is only assigned in the
                # commented-out masked_sort call above, so this branch looks
                # like it raises NameError - confirm.
                decoding1 = unsorted(decoding, align_index)
            else:
                positions = model.predict_offset(out, decoder_masks, None)
                shifted_index = positions.sort(1)[1]
                decoding1 = unsorted(decoding, shifted_index)
        else:
            decoding1 = decoding

        # loss = model.batched_cost(targets, target_masks, probs)
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push('output_decoding')
        # Order matches print_seqs: sources, targets, decoded, reordered inputs.
        dev_outputs = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding1), ('src', input_reorder)]]
        if args.postordering:
            dev_outputs += [model.output_decoding(('trg', decoding))]
        torch.cuda.nvtx.range_pop()

        torch.cuda.nvtx.range_push('computeGLEU')
        gleu = computeGLEU(dev_outputs[2], dev_outputs[1], corpus=False, tokenizer=tokenizer)
        torch.cuda.nvtx.range_pop()

        if print_out:
            for k, d in enumerate(dev_outputs):
                logger.info("{}: {}".format(print_seqs[k], d[0]))
            logger.info('------------------------------------------------------------------')

        if teacher_model is not None:  # teacher is Transformer, student is FastTransformer
            inputs_student, _, targets_student, _, _, _, encoding_teacher, _ = teacher_model.quick_prepare(
                dev_batch, False, decoding, decoding, input_masks, target_masks, source_masks)

            # Teacher loss on the real targets, on the student's decoding, and
            # on the real targets given the student's decoding as input.
            teacher_real_loss = teacher_model.cost(
                targets, target_masks,
                out=teacher_model(encoding_teacher, source_masks, inputs, input_masks))
            teacher_fake_out = teacher_model(encoding_teacher, source_masks, inputs_student, input_masks)
            teacher_fake_loss = teacher_model.cost(targets_student, target_masks, out=teacher_fake_out)
            teacher_alter_loss = teacher_model.cost(targets, target_masks, out=teacher_fake_out)

        trg_outputs += dev_outputs[1]
        dec_outputs += dev_outputs[2]

        if dev_metrics is not None:
            values = [loss, gleu]
            if teacher_model is not None:
                values += [teacher_real_loss, teacher_fake_loss,
                           teacher_real_loss - teacher_fake_loss,
                           teacher_alter_loss,
                           teacher_alter_loss - teacher_fake_loss]
            if reordering_cost is not None:
                values += [reordering_cost]

            dev_metrics.accumulate(batch_size, *values)

    corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
    outputs['corpus_gleu'] = corpus_gleu
    outputs['corpus_bleu'] = corpus_bleu
    if dev_metrics is not None:
        logger.info(dev_metrics)

    logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu))
    logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
    return outputs


# Head of train_model only; the rest of the definition follows this span.
def train_model(model, train, dev, teacher_model=None):
    if args.tensorboard and (not args.debug):
        from tensorboardX import SummaryWriter
writer = SummaryWriter('./runs/{}'.format(args.prefix+hp_str)) # optimizer if args.optimizer == 'Adam': opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9) if args.trainable_teacher: opt_teacher = torch.optim.Adam([p for p in teacher_model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9) elif args.optimizer == 'RMSprop': opt = torch.optim.RMSprop([p for p in model.parameters() if p.requires_grad], eps=1e-9) if args.trainable_teacher: opt_teacher = torch.optim.RMSprop([p for p in teacher_model.parameters() if p.requires_grad], eps=1e-9) else: raise NotImplementedError # if resume training if (args.load_from is not None) and (args.resume): with torch.cuda.device(args.gpu): # very important. offset, opt_states = torch.load('./models/' + args.load_from + '.pt.states', map_location=lambda storage, loc: storage.cuda()) opt.load_state_dict(opt_states) else: offset = 0 # metrics best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i', model=model, opt=opt, path=model_name, gpu=args.gpu) train_metrics = Metrics('train', 'loss', 'real', 'fake') dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'reordering_loss', 'corpus_gleu') progressbar = tqdm(total=args.eval_every, desc='start training.') # cache if args.max_cache > 0: caches = Cache(args.max_cache, args.gpu) for iters, batch in enumerate(train): iters += offset if iters > args.maximum_steps: logger.info('reach the maximum updating steps.') break if iters % args.eval_every == 0: progressbar.close() dev_metrics.reset() if args.seq_dist: outputs_course = valid_model(model, dev, dev_metrics, distillation=True, teacher_model=None)#teacher_model=teacher_model) if args.trainable_teacher: outputs_teacher = valid_model(teacher_model, dev, None) outputs_data = valid_model(model, dev, None if args.seq_dist else dev_metrics, teacher_model=None, print_out=True) if args.tensorboard and (not 
args.debug): writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters) writer.add_scalar('dev/Loss', dev_metrics.loss, iters) writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters) writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters) if args.seq_dist: writer.add_scalar('dev/GLEU_corpus_dis', outputs_course['corpus_gleu'], iters) writer.add_scalar('dev/BLEU_corpus_dis', outputs_course['corpus_bleu'], iters) if args.trainable_teacher: writer.add_scalar('dev/GLEU_corpus_teacher', outputs_teacher['corpus_gleu'], iters) writer.add_scalar('dev/BLEU_corpus_teacher', outputs_teacher['corpus_bleu'], iters) if args.teacher is not None: writer.add_scalar('dev/Teacher_real_loss', dev_metrics.real_loss, iters) writer.add_scalar('dev/Teacher_fake_loss', dev_metrics.fake_loss, iters) writer.add_scalar('dev/Teacher_alter_loss', dev_metrics.alter_loss, iters) writer.add_scalar('dev/Teacher_distance', dev_metrics.distance, iters) writer.add_scalar('dev/Teacher_distance2', dev_metrics.distance2, iters) if args.preordering: writer.add_scalar('dev/Reordering_loss', dev_metrics.reordering_loss, iters) if not args.debug: best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'], dev_metrics.gleu, dev_metrics.loss, iters) logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format( best.i, best.gleu, best.corpus_gleu, best.corpus_bleu)) logger.info('model:' + args.prefix + hp_str) # ---set-up a new progressor--- progressbar = tqdm(total=args.eval_every, desc='start training.') # --- training --- # # try: model.train() opt.param_groups[0]['lr'] = get_learning_rate(iters + 1) opt.zero_grad() # prepare the data inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(batch, args.seq_dist) input_reorder, reordering_cost, decoder_inputs = None, None, inputs batch_align = None # batch.align_dec if args.seq_dist else batch.align 
batch_fer = batch.fer_dec if args.seq_dist else batch.fer # batch_align_sorted, align_index = masked_sort(batch_align, target_masks) # change the target indexxx, batch x max_trg # print(batch_fer.size(), input_masks.size(), source_masks.size(), sources.size()) # Prepare_Initial if type(model) is FastTransformer: inputs, input_reorder, input_masks, reordering_cost = model.prepare_initial(encoding, sources, source_masks, input_masks, batch_align, batch_fer) # Maximum Likelihood Training feedback = {} if not args.word_dist: loss = model.cost(targets, target_masks, out=model(encoding, source_masks, inputs, input_masks, positions= None, feedback=feedback)) # train the reordering also using MLE?? if args.preordering: loss += reordering_cost else: # only used for FastTransformer: word-level adjustment if not args.preordering: decoding, out, probs = model(encoding, source_masks, inputs, input_masks, return_probs=True, decoding=True) loss_student = model.batched_cost(targets, target_masks, probs) # student-loss (MLE) decoder_masks = input_masks else: # Note that MLE and decoding has different translations. 
We need to run the same code twice if args.finetuning_truth: decoding, out, probs = model(encoding, source_masks, inputs, input_masks, decoding=True, return_probs=True, feedback=feedback) loss_student = model.cost(targets, target_masks, out=out) decoder_masks = input_masks else: if args.fertility_mode != 'reinforce': loss_student = model.cost(targets, target_masks, out=model(encoding, source_masks, inputs, input_masks, positions=None, feedback=feedback)) decoder_inputs, _, decoder_masks, _ = model.prepare_initial(encoding, sources, source_masks, input_masks, batch_align, batch_fer, decoding=True, mode=args.fertility_mode) decoding, out, probs = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True) # decode again else: # truth decoding, out, probs = model(encoding, source_masks, inputs, input_masks, decoding=True, return_probs=True, feedback=feedback) loss_student = model.cost(targets, target_masks, out=out) decoder_masks = input_masks # baseline decoder_inputs_b, _, decoder_masks_b, _ = model.prepare_initial(encoding, sources, source_masks, input_masks, batch_align, batch_fer, decoding=True, mode='mean') decoding_b, out_b, probs_b = model(encoding, source_masks, decoder_inputs_b, decoder_masks_b, decoding=True, return_probs=True) # decode again # reinforce decoder_inputs_r, _, decoder_masks_r, _ = model.prepare_initial(encoding, sources, source_masks, input_masks, batch_align, batch_fer, decoding=True, mode='reinforce') decoding_r, out_r, probs_r = model(encoding, source_masks, decoder_inputs_r, decoder_masks_r, decoding=True, return_probs=True) # decode again # train the reordering also using MLE?? 
if args.preordering: loss_student += reordering_cost # teacher tries translation + look-at student's output teacher_model.eval() if args.fertility_mode != 'reinforce': inputs_student_index, _, targets_student_soft, _, _, _, encoding_teacher, _ = model.quick_prepare(batch, False, decoding, probs, decoder_masks, decoder_masks, source_masks) out_teacher, probs_teacher = teacher_model(encoding_teacher, source_masks, inputs_student_index.detach(), decoder_masks, return_probs=True) loss_teacher = teacher_model.batched_cost(targets_student_soft, decoder_masks, probs_teacher.detach()) loss = (1 - args.beta1) * loss_teacher + args.beta1 * loss_student # final results else: inputs_student_index, _, targets_student_soft, _, _, _, encoding_teacher, _ = model.quick_prepare(batch, False, decoding, probs, decoder_masks, decoder_masks, source_masks) out_teacher, probs_teacher = teacher_model(encoding_teacher, source_masks, inputs_student_index.detach(), decoder_masks, return_probs=True) loss_teacher = teacher_model.batched_cost(targets_student_soft, decoder_masks, probs_teacher.detach()) inputs_student_index, _ = model.prepare_inputs(batch, decoding_b, False, decoder_masks_b) targets_student_soft, _ = model.prepare_targets(batch, probs_b, False, decoder_masks_b) out_teacher, probs_teacher = teacher_model(encoding_teacher, source_masks, inputs_student_index.detach(), decoder_masks_b, return_probs=True) _, loss_1= teacher_model.batched_cost(targets_student_soft, decoder_masks_b, probs_teacher.detach(), True) inputs_student_index, _ = model.prepare_inputs(batch, decoding_r, False, decoder_masks_r) targets_student_soft, _ = model.prepare_targets(batch, probs_r, False, decoder_masks_r) out_teacher, probs_teacher = teacher_model(encoding_teacher, source_masks, inputs_student_index.detach(), decoder_masks_r, return_probs=True) _, loss_2= teacher_model.batched_cost(targets_student_soft, decoder_masks_r, probs_teacher.detach(), True) rewards = -(loss_2 - loss_1).data # if rewards.size(0) 
!= 1: rewards = rewards - rewards.mean() # ) / (rewards.std() + TINY) rewards = rewards.expand_as(source_masks) rewards = rewards * source_masks # print(model.predictor.saved_fertilities) # print(batch.src.size()) model.predictor.saved_fertilities.reinforce(0.1 * rewards.contiguous().view(-1, 1)) loss = (1 - args.beta1) * loss_teacher + args.beta1 * loss_student #+ 0 * model.predictor.saved_fertilities.float().sum() # detect reinforce # loss = 0 * model.predictor.saved_fertilities.float().sum() # detect reinforce # accmulate the training metrics train_metrics.accumulate(batch_size, loss, print_iter=None) train_metrics.reset() # train the student if args.preordering and args.fertility_mode == 'reinforce': torch.autograd.backward((loss, model.predictor.saved_fertilities), (torch.ones(1).cuda(loss.get_device()), None)) else: loss.backward() # torch.nn.utils.clip_grad_norm(model.parameters(), 1) opt.step() info = 'training step={}, loss={:.3f}, lr={:.5f}'.format(iters, export(loss), opt.param_groups[0]['lr']) if args.word_dist: info += '| NA:{:.3f}, AR:{:.3f}'.format(export(loss_student), export(loss_teacher)) if args.trainable_teacher and (args.max_cache <= 0): loss_alter, loss_worse = export(loss_alter), export(loss_worse) info += '| AL:{:.3f}, WO:{:.3f}'.format(loss_alter, loss_worse) if args.preordering: info += '| RE:{:.3f}'.format(export(reordering_cost)) if args.fertility_mode == 'reinforce': info += '| RL: {:.3f}'.format(export(rewards.mean())) if args.max_cache > 0: info += '| caches={}'.format(len(caches.cache)) if args.tensorboard and (not args.debug): writer.add_scalar('train/Loss', export(loss), iters) progressbar.update(1) progressbar.set_description(info) # continue-training the teacher model if args.trainable_teacher: if args.max_cache > 0: caches.add([batch.src, batch.trg, batch.dec, decoding]) # experience-reply # trainable teacher: used old experience to train if (iters+1) % args.replay_every == 0: # ---set-up a new progressor: teacher training--- # 
progressbar_teacher = tqdm(total=args.replay_times, desc='start training the teacher.') for j in range(args.replay_times): opt_teacher.param_groups[0]['lr'] = get_learning_rate(iters + 1) opt_teacher.zero_grad() src, trg, dec, decoding = caches.sample() batch = Batch(src, trg, dec) inputs, input_masks, targets, target_masks, sources, source_masks, encoding_teacher, batch_size = teacher_model.quick_prepare(batch, (not args.teacher_use_real)) inputs_students, _ = teacher_model.prepare_inputs(batch, decoding, masks=input_masks) loss_alter = teacher_model.cost(targets, target_masks, out=teacher_model(encoding_teacher, source_masks, inputs_students, input_masks)) loss_worse = teacher_model.cost(targets, target_masks, out=teacher_model(encoding_teacher, source_masks, inputs, input_masks)) loss2 = loss_alter + loss_worse loss2.backward() opt_teacher.step() info = 'teacher step={}, loss={:.3f}, alter={:.3f}, worse={:.3f}'.format(j, export(loss2), export(loss_alter), export(loss_worse)) progressbar_teacher.update(1) progressbar_teacher.set_description(info) progressbar_teacher.close() # except Exception as e: # logger.warn('caught an exception: {}'.format(e)) def decode_model(model, train_real, dev_real, evaluate=True, decoding_path=None, names=['en', 'de', 'decode']): if train_real is None: logger.info('decoding from the devlopment set. beamsize={}, alpha={}'.format(args.beam_size, args.alpha)) dev = dev_real else: logger.info('decoding from the training set. 
beamsize={}, alpha={}'.format(args.beam_size, args.alpha)) dev = train_real dev.train = False # make the Iterator create Variables with volatile=True so no graph is built progressbar = tqdm(total=sum([1 for _ in dev]), desc='start decoding') model.eval() if decoding_path is not None: decoding_path = decoding_path.format(args.test_set if train_real is None else 'train') handle_dec = open(decoding_path + '.{}'.format(names[2]), 'w') handle_src = open(decoding_path + '.{}'.format(names[0]), 'w') handle_trg = open(decoding_path + '.{}'.format(names[1]), 'w') if args.output_fer: handle_fer = open(decoding_path + '.{}'.format('fer'), 'w') corpus_size = 0 src_outputs, trg_outputs, dec_outputs, timings = [], [], [], [] decoded_words, target_words, decoded_info = 0, 0, 0 attentions = None #{'source': None, 'target': None} pad_id = model.decoder.field.vocab.stoi['<pad>'] eos_id = model.decoder.field.vocab.stoi['<eos>'] curr_time = 0 for iters, dev_batch in enumerate(dev): start_t = time.time() inputs, input_masks, targets, target_masks, sources, source_masks, encoding, batch_size = model.quick_prepare(dev_batch) if args.model is FastTransformer: decoder_inputs, input_reorder, decoder_masks, _ = model.prepare_initial(encoding, sources, source_masks, input_masks, None, None, decoding=True, mode=args.fertility_mode) else: decoder_inputs, decoder_masks = inputs, input_masks decoding = model(encoding, source_masks, decoder_inputs, decoder_masks, beam=args.beam_size, alpha=args.alpha, decoding=True, feedback=attentions) used_t = time.time() - start_t curr_time += used_t real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float() outputs = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]] def DHondt(approx, mask): L = mask.size(1) w = torch.arange(1, 2 * L, 2) if approx.is_cuda: w = w.cuda(approx.get_device()) w = 1 / w # 1, 1/2, 1/3, ... 
approx = approx[:, :, None] @ w[None, :] # B x Ts x Tt approx = approx.view(approx.size(0), -1) # B x (Ts x Tt) appinx = approx.topk(L, 1)[1] # B x Tt (index) fertility = approx.new(*approx.size()).fill_(0).scatter_(1, appinx, mask) fertility = fertility.contiguous().view(mask.size(0), -1, mask.size(1)).sum(2).long() return fertility def cutoff(s, t): for i in range(len(s), 0, -1): if s[i-1] != t: return s[:i] raise IndexError if args.output_fer: source_attention = attentions['source'].data.mean(1).transpose(2, 1) # B x Ts x Tt source_attention *= real_mask[:, None, :] approx_fertility = source_attention.sum(2) # B x Ts fertility = DHondt(approx_fertility, real_mask) corpus_size += batch_size src_outputs += outputs[0] trg_outputs += outputs[1] dec_outputs += outputs[2] timings += [used_t] if decoding_path is not None: for s, t, d in zip(outputs[0], outputs[1], outputs[2]): if args.no_bpe: s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '') print(s, file=handle_src, flush=True) print(t, file=handle_trg, flush=True) print(d, file=handle_dec, flush=True) if args.output_fer: with torch.cuda.device_of(fertility): fertility = fertility.tolist() for f in fertility: f = ' '.join([str(fi) for fi in cutoff(f, 0)]) print(f, file=handle_fer, flush=True) progressbar.update(1) progressbar.set_description('finishing sentences={}/batches={}, speed={} sec/batch'.format(corpus_size, iters, curr_time / (1 + iters))) if evaluate: corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu)) logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu)) computeGroupBLEU(dec_outputs, trg_outputs, tokenizer=tokenizer) torch.save([src_outputs, trg_outputs, dec_outputs, timings], './space/data.pt') def noisy_decode_model(model, dev_real, samples=1, alpha=1, tau=1, teacher_model=None, 
evaluate=True, decoding_path=None, names=['en', 'de', 'decode'], saveall=False): assert type(model) is FastTransformer, 'only works for fastTransformer' logger.info('decoding from the devlopment set. beamsize={}, alpha={}, tau={}'.format(args.beam_size, args.alpha, args.temperature)) dev = dev_real progressbar = tqdm(total=sum([1 for _ in dev]), desc='start decoding') model.eval() teacher_model.eval() if decoding_path is not None: decoding_path = decoding_path.format(args.test_set if train_real is None else 'train') handle_dec = open(decoding_path + '.{}'.format(names[2]), 'w') handle_src = open(decoding_path + '.{}'.format(names[0]), 'w') handle_trg = open(decoding_path + '.{}'.format(names[1]), 'w') # if saveall: # handle_fer = open(decoding_path + '.{}'.format(names[3]), 'w') corpus_size = 0 src_outputs, trg_outputs, dec_outputs, timings = [], [], [], [] all_dec_outputs = [] decoded_words, target_words, decoded_info = 0, 0, 0 attentions = None #{'source': None, 'target': None} pad_id = model.decoder.field.vocab.stoi['<pad>'] eos_id = model.decoder.field.vocab.stoi['<eos>'] curr_time = 0 for iters, dev_batch in enumerate(dev): start_t = time.time() inputs, input_masks, targets, target_masks, sources, source_masks0, encoding0, batch_size = model.quick_prepare(dev_batch) if teacher_model is not None: encoding_teacher = teacher_model.encoding(sources, source_masks0) batch_size, src_len, hsize = encoding0[0].size() if samples > 1: source_masks = source_masks0[:, None, :].expand(batch_size, samples, src_len).contiguous().view(batch_size * samples, src_len) encoding = [None for _ in encoding0] for i in range(len(encoding)): encoding[i] = encoding0[i][:, None, :].expand( batch_size, samples, src_len, hsize).contiguous().view(batch_size * samples, src_len, hsize) if teacher_model is not None: for i in range(len(encoding)): encoding_teacher[i] = encoding_teacher[i][:, None, :].expand( batch_size, samples, src_len, hsize).contiguous().view(batch_size * samples, src_len, 
hsize) def parallel(): decoder_inputs, input_reorder, decoder_masks, logits_fer = model.prepare_initial(encoding0, sources, source_masks0, input_masks, None, None, decoding=True, mode=args.fertility_mode, N=samples, tau=tau) if teacher_model is not None: decoding = model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, feedback=attentions) student_inputs, _ = teacher_model.prepare_inputs(dev_batch, decoding, decoder_masks) student_targets, _ = teacher_model.prepare_targets(dev_batch, decoding, decoder_masks) out, probs = teacher_model(encoding_teacher, source_masks, student_inputs, decoder_masks, return_probs=True, decoding=False) _, teacher_loss = model.batched_cost(student_targets, decoder_masks, probs, batched=True) # student-loss (MLE) # reranking the translation teacher_loss = teacher_loss.view(batch_size, samples) decoding = decoding.view(batch_size, samples, -1) lp = decoder_masks.sum(1).view(batch_size, samples) ** (1 - alpha) teacher_loss = teacher_loss * Variable(lp) return decoding, teacher_loss, input_reorder if args.multi_run > 1: decodings, teacher_losses, _ = zip(*[parallel() for _ in range(args.multi_run)]) maxl = max([d.size(2) for d in decodings]) decoding = Variable(sources.data.new(batch_size, samples * args.multi_run, maxl).fill_(1).long()) for i, d in enumerate(decodings): decoding[:, i * samples: (i+1) * samples, :d.size(2)] = d teacher_loss = torch.cat(teacher_losses, 1) else: decoding, teacher_loss, input_reorder = parallel() all_dec_outputs += [(decoding.view(batch_size * samples, -1), input_reorder)] selected_idx = (-teacher_loss).topk(1, 1)[1] # batch x 1 decoding = decoding.gather(1, selected_idx[:, :, None].expand(batch_size, 1, decoding.size(-1)))[:, 0, :] used_t = time.time() - start_t curr_time += used_t real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float() outputs = [model.output_decoding(d) for d in [('src', sources), ('trg', targets), ('trg', decoding)]] corpus_size += batch_size 
src_outputs += outputs[0] trg_outputs += outputs[1] dec_outputs += outputs[2] timings += [used_t] if decoding_path is not None: for s, t, d in zip(outputs[0], outputs[1], outputs[2]): if args.no_bpe: s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '') print(s, file=handle_src, flush=True) print(t, file=handle_trg, flush=True) print(d, file=handle_dec, flush=True) # if saveall: # for d, f in all_dec_outputs: # ds = model.output_decoding(('trg', d)) # fs = model.output_decoding(('src', f)) # for dd, ff in zip(ds, fs): # print(dd, file=handle_fer, flush=True) # print(ff, file=handle_fer, flush=True) progressbar.update(1) progressbar.set_description('finishing sentences={}/batches={} speed={} sec/batch'.format(corpus_size, iters, curr_time / (1 + iters))) if evaluate: corpus_gleu = computeGLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer) logger.info("The dev-set corpus GLEU = {}".format(corpus_gleu)) logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu)) computeGroupBLEU(dec_outputs, trg_outputs, tokenizer=tokenizer) torch.save([src_outputs, trg_outputs, dec_outputs, timings], './space/data.pt') def self_improving_model(model, train, dev): if args.tensorboard and (not args.debug): from tensorboardX import SummaryWriter writer = SummaryWriter('./runs/self-{}'.format(args.prefix+hp_str)) # optimizer if args.optimizer == 'Adam': opt = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9) if args.trainable_teacher: opt_teacher = torch.optim.Adam([p for p in teacher_model.parameters() if p.requires_grad], betas=(0.9, 0.98), eps=1e-9) elif args.optimizer == 'RMSprop': opt = torch.optim.RMSprop([p for p in model.parameters() if p.requires_grad], eps=1e-9) if args.trainable_teacher: opt_teacher = torch.optim.RMSprop([p for p in teacher_model.parameters() if p.requires_grad], eps=1e-9) else: raise 
NotImplementedError # if resume training -- if (args.load_from is not None) and (args.resume): with torch.cuda.device(args.gpu): # very important. offset, opt_states = torch.load('./models/' + args.load_from + '.pt.states', map_location=lambda storage, loc: storage.cuda()) opt.load_state_dict(opt_states) else: offset = 0 # metrics best = Best(max, 'corpus_bleu', 'corpus_gleu', 'gleu', 'loss', 'i', model=model, opt=opt, path=model_name, gpu=args.gpu) train_metrics = Metrics('train', 'loss', 'real', 'fake') dev_metrics = Metrics('dev', 'loss', 'gleu', 'real_loss', 'fake_loss', 'distance', 'alter_loss', 'distance2', 'reordering_loss', 'corpus_gleu') progressbar = tqdm(total=args.eval_every, desc='start training.') # cache samples = 100 tau = 1 caches = Cache(args.max_cache, ['src', 'trg', 'dec', 'fer']) best_model = copy.deepcopy(model) # used for decoding best_score = 0 # start loop iters = offset train = iter(train) counters = 0 while iters <= args.maximum_steps: iters += 1 counters += 1 batch = devol(next(train)) # prepare inputs model.eval() inputs, input_masks, \ targets, target_masks, \ sources, source_masks0, encoding, batch_size = model.quick_prepare(batch) _, src_len, hsize = encoding[0].size() trg_len = targets.size(1) # prepare parallel -- noisy sampling decoder_inputs, input_reorder, decoder_masks, _, pred_fer \ = model.prepare_initial(encoding, sources, source_masks0, input_masks, None, None, decoding=True, mode='reinforce', N=samples, tau=tau, return_samples=True) # repeating for decoding source_masks = source_masks0[:, None, :].expand(batch_size, samples, src_len).contiguous().view(batch_size * samples, src_len) for i in range(len(encoding)): encoding[i] = encoding[i][:, None, :].expand( batch_size, samples, src_len, hsize).contiguous().view(batch_size * samples, src_len, hsize) # run decoding decoding, _, probs = best_model(encoding, source_masks, decoder_inputs, decoder_masks, decoding=True, return_probs=True) # compute GLEU score to select the best 
translation trg_output = best_model.output_decoding(('trg', targets[:, None, :].expand(batch_size, samples, trg_len).contiguous().view(batch_size * samples, trg_len))) dec_output = best_model.output_decoding(('trg', decoding)) bleu_score = computeBLEU(dec_output, trg_output, corpus=False, tokenizer=tokenizer).contiguous().view(batch_size, samples).cuda(args.gpu) best_index = bleu_score.max(1)[1] def index_gather(data, index, samples): batch_size = index.size(0) data = data.contiguous().view(batch_size, samples, -1) # batch x samples x dim index = index[:, None, None].expand(batch_size, 1, data.size(2)) return data.gather(1, index)[:, 0, :] best_decoding, best_decoder_masks, best_fertilities = [index_gather(x, best_index, samples) for x in [decoding, decoder_masks, pred_fer]] caches.add([sources, targets, best_decoding, best_fertilities], [source_masks0, target_masks, best_decoder_masks, source_masks0], ['src', 'trg', 'dec', 'fer']) progressbar.update(1) progressbar.set_description('caching sentences={}/batches={}'.format(len(caches.cache), iters)) if counters == args.eval_every: logger.info('build a new dataset from the caches') print(len(caches.cache)) cache_data = ParallelDataset(examples=caches.cache, fields=[('src', SRC), ('trg', TRG), ('dec', TRG), ('fer', FER)]) cache_iter = data.BucketIterator(cache_data, batch_sizes=2048, device=args.gpu, batch_size_fn=batch_size_fn) print('done') import sys;sys.exit(1) if False: # iters % args.eval_every == 0: progressbar.close() dev_metrics.reset() outputs_data = valid_model(model, dev, None if args.seq_dist else dev_metrics, teacher_model=None, print_out=True) if args.tensorboard and (not args.debug): writer.add_scalar('dev/GLEU_sentence_', dev_metrics.gleu, iters) writer.add_scalar('dev/Loss', dev_metrics.loss, iters) writer.add_scalar('dev/GLEU_corpus_', outputs_data['corpus_gleu'], iters) writer.add_scalar('dev/BLEU_corpus_', outputs_data['corpus_bleu'], iters) if not args.debug: 
best.accumulate(outputs_data['corpus_bleu'], outputs_data['corpus_gleu'], dev_metrics.gleu, dev_metrics.loss, iters) logger.info('the best model is achieved at {}, average greedy GLEU={}, corpus GLEU={}, corpus BLEU={}'.format( best.i, best.gleu, best.corpus_gleu, best.corpus_bleu)) logger.info('model:' + args.prefix + hp_str) # ---set-up a new progressor--- progressbar = tqdm(total=args.eval_every, desc='start training.') if args.mode == 'train': logger.info('starting training') train_model(model, train_real, dev_real, teacher_model) elif args.mode == 'self': logger.info('starting self-training') self_improving_model(model, train_real, dev_real) elif args.mode == 'test': logger.info('starting decoding from the pre-trained model, test...') names = ['dev.src.b{}={}.{}'.format(args.beam_size, args.load_from, args,fertility_mode), 'dev.trg.b{}={}.{}'.format(args.beam_size, args.load_from, args,fertility_mode), 'dev.dec.b{}={}.{}'.format(args.beam_size, args.load_from, args,fertility_mode)] decode_model(model, None, dev_real, evaluate=True, decoding_path=decoding_path if not args.no_write else None, names=names) elif args.mode == 'test_noisy': logger.info('starting decoding from the pre-trained model, test...') names = ['dev.src.b{}={}.noise{}'.format(args.beam_size, args.load_from, args.beam_size), 'dev.trg.b{}={}.noise{}'.format(args.beam_size, args.load_from, args.beam_size), 'dev.dec.b{}={}.noise{}'.format(args.beam_size, args.load_from, args.beam_size), 'dev.fer.b{}={}.noise{}'.format(args.beam_size, args.load_from, args.beam_size)] noisy_decode_model(model, dev_real, samples=args.beam_size, alpha=args.alpha, tau=args.temperature, teacher_model=teacher_model, evaluate=True, decoding_path=decoding_path if not args.no_write else None, names=names, saveall=True) else: logger.info('starting decoding from the pre-trained model, build the course dataset...') names = ['src.b{}'.format(args.beam_size), 'trg.b{}'.format(args.beam_size), 'dec.b{}'.format(args.beam_size)] 
decode_model(model, train_real, dev_real, decoding_path=decoding_path if not args.no_write else None, names=names) logger.info("done.")