"""Train and evaluate an RNN language model (Penn Treebank / WikiText-2).

This section of the script parses the command line, normalises a few
derived arguments and configures logging to both stdout and
``<save>/log.txt``.
"""
import os
import gc
import sys
import glob
import time
import math
import numpy as np
import torch
import torch.nn as nn
import logging
import argparse
import torch.backends.cudnn as cudnn

import data
import model

from utils import batchify, get_batch, repackage_hidden, create_exp_dir, save_checkpoint, parse_arch

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank/WikiText2 Language Model')
parser.add_argument('--data', type=str, default='../data/penn/',
                    help='location of the data corpus')
parser.add_argument('--emsize', type=int, default=850,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=850,
                    help='number of hidden units per layer')
parser.add_argument('--nhidlast', type=int, default=850,
                    help='number of hidden units for the last rnn layer')
parser.add_argument('--lr', type=float, default=20,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=8000,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=35,
                    help='sequence length')
parser.add_argument('--dropout', type=float, default=0.75,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.25,
                    help='dropout for hidden nodes in rnn layers (0 = no dropout)')
parser.add_argument('--dropoutx', type=float, default=0.75,
                    help='dropout for input nodes rnn layers (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.2,
                    help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--seed', type=int, default=1267,
                    help='random seed')
# The optimizer is switched from SGD to ASGD once the validation loss is worse
# than the best loss recorded more than `nonmono` epochs ago.
parser.add_argument('--nonmono', type=int, default=5,
                    help='non-monotone interval: number of most recent epochs ignored when '
                         'checking whether validation loss stopped improving (triggers the switch to ASGD)')
# NOTE: `store_false` means CUDA is ON by default and passing --cuda turns it
# OFF. The flag semantics are kept for backward compatibility; only the help
# text was corrected.
parser.add_argument('--cuda', action='store_false',
                    help='disable CUDA (CUDA is used by default; passing this flag turns it off)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')
parser.add_argument('--save', type=str, default='EXP',
                    help='path to save the final model')
parser.add_argument('--alpha', type=float, default=0,
                    help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
parser.add_argument('--beta', type=float, default=1e-3,
                    help='beta slowness regularization applied on RNN activation (beta = 0 means no regularization)')
parser.add_argument('--wdecay', type=float, default=8e-7,
                    help='weight decay applied to all weights')
parser.add_argument('--continue_train', action='store_true',
                    help='continue train from a checkpoint')
parser.add_argument('--small_batch_size', type=int, default=-1,
                    help='the batch size for computation. batch_size should be divisible by small_batch_size. '
                         'In our implementation, we compute gradients with small_batch_size multiple times, '
                         'and accumulate the gradients until batch_size is reached. '
                         'An update step is then performed.')
parser.add_argument('--max_seq_len_delta', type=int, default=20,
                    help='maximum amount by which a sampled sequence length may exceed bptt')
# NOTE: same inverted pattern as --cuda: single-GPU mode is the default and
# passing --single_gpu switches to nn.DataParallel.
parser.add_argument('--single_gpu', default=True, action='store_false',
                    help='use multiple GPUs via nn.DataParallel '
                         '(a single GPU is used by default; passing this flag turns that off)')
parser.add_argument('--gpu', type=int, default=0, help='GPU device to use')
parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
args = parser.parse_args()

# Negative values mean "derive from another argument".
if args.nhidlast < 0:
    args.nhidlast = args.emsize
if args.small_batch_size < 0:
    args.small_batch_size = args.batch_size

# Disabled in the original source (it was wrapped in a string literal).
# NOTE(review): presumably disabled so that `args.save` stays stable and an
# existing checkpoint in it can be resumed automatically -- confirm.
# if not args.continue_train:
#     args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
#     create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

# With `create_exp_dir` disabled nothing creates the output directory, and
# `logging.FileHandler` below would fail on a fresh run.
if not os.path.isdir(args.save):
    os.makedirs(args.save)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

# Set the random seed manually for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
    # Was misspelled `benckmark`, which silently created an unused attribute.
    # NOTE(review): benchmark mode lets cuDNN auto-tune its algorithms, which
    # can work against run-to-run determinism -- confirm this is intended.
    cudnn.benchmark = True

if torch.cuda.is_available():
    if not args.cuda:
        # `--cuda` is a store_false flag: `args.cuda` is False only when the
        # user passed `--cuda`, so the advice is to drop the flag.
        print("WARNING: You have a CUDA device, so you should probably run without --cuda")
    else:
        torch.cuda.set_device(args.gpu)
        cudnn.benchmark = True
        cudnn.enabled = True
        torch.cuda.manual_seed_all(args.seed)

corpus = data.Corpus(args.data)

eval_batch_size = 10
test_batch_size = 1
train_data = batchify(corpus.train, args.batch_size, args)
val_data = batchify(corpus.valid, eval_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)

ntokens = len(corpus.dictionary)

# NOTE(review): `genotypes` does not appear to be imported in this file, so the
# lookup raises NameError and the `parse_arch` fallback is what actually runs.
# `eval` on a command-line string is unsafe for untrusted input.
try:
    genotype = eval("genotypes.%s" % args.arch)
except Exception:
    genotype = parse_arch(args.arch)

# An existing checkpoint in the save directory always triggers a resume.
if os.path.exists(os.path.join(args.save, 'model.pt')):
    print("Found model.pt in {}, automatically continue training.".format(args.save))
    args.continue_train = True

if args.continue_train:
    model = torch.load(os.path.join(args.save, 'model.pt'))
else:
    # NOTE: this rebinds the name `model` from the imported module to the
    # network instance; `model.DARTSCell` is resolved before the rebinding.
    model = model.RNNModel(ntokens, args.emsize, args.nhid, args.nhidlast,
                           args.dropout, args.dropouth, args.dropoutx, args.dropouti, args.dropoute,
                           cell_cls=model.DARTSCell, genotype=genotype)

if args.cuda:
    if args.single_gpu:
        parallel_model = model.cuda()
    else:
        parallel_model = nn.DataParallel(model, dim=1).cuda()
else:
    parallel_model = model

total_params = sum(x.data.nelement() for x in model.parameters())
logging.info('Args: {}'.format(args))
logging.info('Model total parameters: {}'.format(total_params))
logging.info('Genotype: {}'.format(genotype))

# Prefer the in-place `clip_grad_norm_`; fall back to the legacy name on
# PyTorch builds that do not provide it.
_clip_grad_norm = getattr(torch.nn.utils, 'clip_grad_norm_', None) or torch.nn.utils.clip_grad_norm


def _to_float(value):
    """Return a plain Python number for an accumulated scalar loss.

    Accepts a 0-dim or single-element tensor, or a plain number (the
    accumulator is the int ``0`` until the first batch is added). Indexing
    with ``[0]`` fails on 0-dim tensors and on ints, hence this helper.
    """
    if hasattr(value, 'item'):
        return value.item()
    if torch.is_tensor(value):
        # Legacy tensors without `.item()`: single-element indexing.
        return value[0]
    return float(value)


def evaluate(data_source, batch_size=10):
    """Return the average loss of the global model over `data_source`.

    Args:
        data_source: batchified data as produced by `batchify`.
        batch_size: batch size used to initialise the hidden state; it must
            match the batch size `data_source` was batchified with.

    Returns:
        The NLL loss averaged over `data_source`, each chunk weighted by its
        length along the first dimension.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    for i in range(0, data_source.size(0) - 1, args.bptt):
        inputs, targets = get_batch(data_source, i, args, evaluation=True)
        targets = targets.view(-1)

        log_prob, hidden = parallel_model(inputs, hidden)
        loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data

        total_loss += loss * len(inputs)

        hidden = repackage_hidden(hidden)
    return _to_float(total_loss) / len(data_source)


def train():
    """Run one training epoch over `train_data`.

    Uses the module-level `model`, `parallel_model`, `optimizer` and `epoch`
    (the latter two are defined by the epoch loop further down the script).
    Each batch is split into `small_batch_size` slices whose gradients are
    accumulated before a single optimizer step.

    Raises:
        AssertionError: if `batch_size` is not divisible by `small_batch_size`.
        OverflowError: if the accumulated loss becomes NaN.
    """
    assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'

    # Turn on training mode which enables dropout.
    total_loss = 0
    start_time = time.time()
    # One hidden state per gradient-accumulation slice.
    hidden = [model.init_hidden(args.small_batch_size)
              for _ in range(args.batch_size // args.small_batch_size)]
    batch, i = 0, 0
    while i < train_data.size(0) - 1 - 1:
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        seq_len = min(seq_len, args.bptt + args.max_seq_len_delta)

        # Temporarily rescale the learning rate in proportion to the sampled
        # sequence length; it is restored after the optimizer step.
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        model.train()
        inputs, targets = get_batch(train_data, i, args, seq_len=seq_len)

        optimizer.zero_grad()

        start, end, s_id = 0, args.small_batch_size, 0
        while start < args.batch_size:
            cur_data = inputs[:, start: end]
            cur_targets = targets[:, start: end].contiguous().view(-1)

            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            hidden[s_id] = repackage_hidden(hidden[s_id])

            log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs = parallel_model(cur_data, hidden[s_id], return_h=True)
            raw_loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), cur_targets)

            loss = raw_loss
            # Activation Regularization
            if args.alpha > 0:
                loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness)
            loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
            # Scale so that the accumulated gradient equals the full-batch mean.
            loss *= args.small_batch_size / args.batch_size
            total_loss += raw_loss.data * args.small_batch_size / args.batch_size
            loss.backward()

            s_id += 1
            start = end
            end = start + args.small_batch_size

            gc.collect()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs.
        _clip_grad_norm(model.parameters(), args.clip)
        optimizer.step()

        # total_loss += raw_loss.data
        optimizer.param_groups[0]['lr'] = lr2

        running_loss = _to_float(total_loss)
        if np.isnan(running_loss):
            raise OverflowError('NAN loss')

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = running_loss / args.log_interval
            elapsed = time.time() - start_time
            logging.info('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                         'loss {:5.2f} | ppl {:8.2f}'.format(
                             epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                             elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        batch += 1
        i += seq_len


# Loop over epochs.
lr = args.lr
best_val_loss = []
stored_loss = 100000000

# At any point you can hit Ctrl + C to break out of training early.
def _new_optimizer(net, averaged):
    """Build a fresh ASGD (`averaged=True`) or SGD optimizer for `net`."""
    if averaged:
        return torch.optim.ASGD(net.parameters(), lr=args.lr, t0=0, lambd=0., weight_decay=args.wdecay)
    return torch.optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.wdecay)


def _restore_optimizer(net):
    """Recreate the optimizer saved in `<save>/optimizer.pt` for `net`.

    The presence of a `t0` entry in the saved param group identifies ASGD.
    """
    state = torch.load(os.path.join(args.save, 'optimizer.pt'))
    restored = _new_optimizer(net, 't0' in state['param_groups'][0])
    restored.load_state_dict(state)
    return restored


def _place_model(net):
    """Return `net` wrapped/moved the same way as at start-up.

    Mirrors the initial `parallel_model` setup so that reloaded checkpoints
    respect `args.cuda` and `args.single_gpu` instead of always calling
    `.cuda()`.
    """
    if not args.cuda:
        return net
    if args.single_gpu:
        return net.cuda()
    return nn.DataParallel(net, dim=1).cuda()


try:
    if args.continue_train:
        optimizer = _restore_optimizer(model)
        epoch = torch.load(os.path.join(args.save, 'misc.pt'))['epoch']
    else:
        optimizer = _new_optimizer(model, averaged=False)
        epoch = 1

    while epoch < args.epochs + 1:
        epoch_start_time = time.time()
        try:
            train()
        except Exception:
            # Only `Exception` is caught: a bare `except:` would also swallow
            # KeyboardInterrupt and make the Ctrl+C exit below unreachable
            # while training.
            if not os.path.exists(os.path.join(args.save, 'model.pt')):
                # Nothing to roll back to; surface the real failure.
                raise
            logging.exception('rolling back to the previous best model ...')
            model = torch.load(os.path.join(args.save, 'model.pt'))
            parallel_model = _place_model(model)
            optimizer = _restore_optimizer(model)
            epoch = torch.load(os.path.join(args.save, 'misc.pt'))['epoch']
            continue

        if 't0' in optimizer.param_groups[0]:
            # ASGD: evaluate (and possibly checkpoint) with the averaged
            # weights `ax`, then put the raw weights back for further training.
            tmp = {}
            for prm in model.parameters():
                tmp[prm] = prm.data.clone()
                prm.data = optimizer.state[prm]['ax'].clone()

            try:
                val_loss2 = evaluate(val_data, eval_batch_size)
                logging.info('-' * 89)
                logging.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                             'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                                        val_loss2, math.exp(val_loss2)))
                logging.info('-' * 89)

                if val_loss2 < stored_loss:
                    save_checkpoint(model, optimizer, epoch, args.save)
                    logging.info('Saving Averaged!')
                    stored_loss = val_loss2
            finally:
                # Restore the raw weights even if evaluation/saving failed.
                for prm in model.parameters():
                    prm.data = tmp[prm].clone()

        else:
            val_loss = evaluate(val_data, eval_batch_size)
            logging.info('-' * 89)
            logging.info('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
                         'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                                    val_loss, math.exp(val_loss)))
            logging.info('-' * 89)

            if val_loss < stored_loss:
                save_checkpoint(model, optimizer, epoch, args.save)
                logging.info('Saving Normal!')
                stored_loss = val_loss

            # Switch to ASGD once the current loss is worse than the best loss
            # seen more than `nonmono` epochs ago.
            if 't0' not in optimizer.param_groups[0] and (
                    len(best_val_loss) > args.nonmono
                    and val_loss > min(best_val_loss[:-args.nonmono])):
                logging.info('Switching!')
                optimizer = _new_optimizer(model, averaged=True)
            best_val_loss.append(val_loss)

        epoch += 1

except KeyboardInterrupt:
    logging.info('-' * 89)
    logging.info('Exiting from training early')

# Load the best saved model.
model = torch.load(os.path.join(args.save, 'model.pt'))
parallel_model = _place_model(model)

# Run on test data.
test_loss = evaluate(test_data, test_batch_size)
logging.info('=' * 89)
logging.info('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
logging.info('=' * 89)