# copypasta from main.py of pytorch word_language_model code
# coding: utf-8
import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx
import datetime
import shutil
import pickle

import data
from relational_rnn_models import RelationalMemory

# is it faster?
torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch Wikitext-2 RNN/LSTM Language Model')

# hyperparams for text data
parser.add_argument('--data', type=str, default='./data/wikitext-2',
                    help='location of the data corpus')
parser.add_argument('--emsize', type=int, default=192,
                    help='size of word embeddings')

# NEW!: hyperparams for relational memory core (RMC)
parser.add_argument('--memslots', type=int, default=1,
                    help='number of memory slots of the relational memory core')
parser.add_argument('--headsize', type=int, default=192,
                    help='size of each head for multihead attention')
parser.add_argument('--numheads', type=int, default=4,
                    help='total number of heads for multihead attention')
parser.add_argument('--numblocks', type=int, default=1,
                    help='Number of times to compute attention per time step')
parser.add_argument('--forgetbias', type=float, default=1.,
                    help='Bias to use for the forget gate, assuming we are using some form of gating')
parser.add_argument('--inputbias', type=float, default=0.,
                    help='Bias to use for the input gate, assuming we are using some form of gating')
parser.add_argument('--gatestyle', type=str, default='unit',
                    help='Whether to use per-element gating (\'unit\'), '
                         'per-memory-slot gating (\'memory\'), or no gating at all (None)')
parser.add_argument('--attmlplayers', type=int, default=3,
                    help='Number of layers to use in the post-attention MLP')
parser.add_argument('--keysize', type=int, default=64,
                    help='Size of vector to use for key & query vectors in the attention '
                         'computation. Defaults to 64 here; the upstream RMC falls back to '
                         '`head_size` when this is None.')

# parameters for adaptive softmax
parser.add_argument('--adaptivesoftmax', action='store_true',
                    help='use adaptive softmax when projecting the hidden state to output logits. '
                         'it uses less memory by approximating the softmax of a large vocabulary.')
parser.add_argument('--cutoffs', nargs="*", type=int, default=[10000, 50000, 100000],
                    help='cutoff values for adaptive softmax. list of integers. '
                         'optimal values are based on word frequency and vocabulary size of the dataset.')

# other hyperparams for general RNN mechanics
parser.add_argument('--lr', type=float, default=0.001,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.1,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=100,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=100,
                    help='sequence length')
# dropout of RMC is hard-bound to 0.5 at the embedding layer
# parser.add_argument('--dropout', type=float, default=0.2,
#                     help='dropout applied to layers (0 = no dropout)')
# embed weight tying is always set to true
# parser.add_argument('--tied', action='store_true',
#                     help='tie the word embedding and softmax weights')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=100, metavar='N',
                    help='report interval')
parser.add_argument('--onnx-export', type=str, default='',
                    help='path to export the final model in onnx format')
parser.add_argument('--resume', type=int, default=None,
                    help='if specified with the 1-indexed global epoch, loads the checkpoint and resumes training')
# experiment name for this run
parser.add_argument('--name', type=str, default=None,
                    help='name for this experiment. generates a folder with the name if specified.')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
device = torch.device("cuda" if args.cuda else "cpu")
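# Example invocation (the file name and cutoff values are hypothetical; any
# flag can be omitted to fall back to its default above):
#
#   python train_rmc.py --cuda --name rmc-wt2 --data ./data/wikitext-2 \
#       --adaptivesoftmax --cutoffs 1000 5000 20000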
###############################################################################
# Load data
###############################################################################

corpus_name = os.path.basename(os.path.normpath(args.data))
corpus_filename = './data/corpus-' + str(corpus_name) + '.pkl'
if os.path.isfile(corpus_filename):
    print("loading pre-built " + str(corpus_name) + " corpus file...")
    with open(corpus_filename, 'rb') as loadfile:
        corpus = pickle.load(loadfile)
else:
    print("building " + str(corpus_name) + " corpus...")
    corpus = data.Corpus(args.data)
    # save the corpus for later
    with open(corpus_filename, 'wb') as savefile:
        pickle.dump(corpus, savefile)
    print("corpus saved to pickle")


# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e.g. 'g' on 'f' cannot be learned, but allows more efficient
# batch processing.

def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)
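# A quick sanity check of batchify on toy data (hypothetical sizes, not part
# of training; the assertion mirrors the alphabet diagram above):
_toy = batchify(torch.arange(26), 4)  # 26 "tokens", batch size 4
assert _toy.shape == (6, 4)           # 24 tokens kept as 4 columns of 6; trailing 2 trimmed
del _toy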
eval_batch_size = 32
train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, eval_batch_size)
test_data = batchify(corpus.test, eval_batch_size)

# create folder for current experiments
# name: args.name + current time
# includes: entire scripts for faithful reproduction, train & test logs
folder_name = str(datetime.datetime.now())[:-7]
if args.name is not None:
    folder_name = str(args.name) + ' ' + folder_name
os.mkdir(folder_name)
for file in os.listdir(os.getcwd()):
    if file.endswith(".py"):
        shutil.copy2(file, os.path.join(os.getcwd(), folder_name))
logger_train = open(os.path.join(os.getcwd(), folder_name, 'train_log.txt'), 'w+')
logger_test = open(os.path.join(os.getcwd(), folder_name, 'test_log.txt'), 'w+')

# save args to logger
logger_train.write(str(args) + '\n')

# define saved model file location
savepath = os.path.join(os.getcwd(), folder_name)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
print("vocabulary size (ntokens): " + str(ntokens))

if args.adaptivesoftmax:
    print("Adaptive Softmax is on: the performance depends on cutoff values. check that the cutoffs are properly set")
    print("Cutoffs: " + str(args.cutoffs))
    if args.cutoffs[-1] > ntokens:
        raise ValueError("the last element of the cutoff list must not exceed the vocab size of the dataset")

# note: --gatestyle is parsed above but not forwarded to the model here
model = RelationalMemory(mem_slots=args.memslots, head_size=args.headsize, input_size=args.emsize,
                         num_tokens=ntokens, num_heads=args.numheads, num_blocks=args.numblocks,
                         forget_bias=args.forgetbias, input_bias=args.inputbias,
                         attention_mlp_layers=args.attmlplayers, key_size=args.keysize,
                         use_adaptive_softmax=args.adaptivesoftmax, cutoffs=args.cutoffs).to(device)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

model = nn.DataParallel(model)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)

###############################################################################
# Load the model checkpoint if specified and restore the global & best epoch
###############################################################################

global_epoch = args.resume if args.resume is not None else 0
best_epoch = args.resume if args.resume is not None else 0
if args.resume is not None:
    print("--resume detected. loading checkpoint...")
    loadpath = os.path.join(os.getcwd(), "model_{}.pt".format(args.resume))
    if not os.path.isfile(loadpath):
        raise FileNotFoundError(
            "model_{}.pt not found. place the model checkpoint file in the current working directory.".format(
                args.resume))
    checkpoint = torch.load(loadpath)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    scheduler.load_state_dict(checkpoint["scheduler"])
    global_epoch = checkpoint["global_epoch"]
    best_epoch = checkpoint["best_epoch"]

print("model built, total trainable params: " + str(total_params))
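# Checkpoint layout used throughout this script: "model_{epoch}.pt" stores the
# dict sketched below (written in the training loop, read by the --resume
# branch above), while "model_dump_{epoch}.pt" stores the whole pickled model
# for generate_rmc.py:
#
#   {"state_dict": model.state_dict(),     # keys carry a "module." prefix,
#    "optimizer": optimizer.state_dict(),  # since model is DataParallel-wrapped
#    "scheduler": scheduler.state_dict(),
#    "global_epoch": global_epoch,         # 1-indexed epoch counter
#    "best_epoch": best_epoch}             # epoch with the best validation loss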
###############################################################################
# Training code
###############################################################################

# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivision of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i):
    seq_len = min(args.bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].view(-1)
    return data, target


def evaluate(data_source):
    # Turn on evaluation mode, which disables dropout.
    model.eval()
    total_loss = 0.
    memory = model.module.initial_state(eval_batch_size, trainable=False).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            data, targets = get_batch(data_source, i)
            # transpose the data to [batch, seq]
            data = torch.t(data)
            loss, memory = model(data, memory, targets)
            loss = torch.mean(loss)
            # weight each chunk by its actual length: the last chunk can be
            # shorter than args.bptt
            total_loss += data.size(1) * loss.item()
    return total_loss / len(data_source)


def train():
    # Turn on training mode, which enables dropout.
    model.train()
    total_loss = 0.
    forward_elapsed_time = 0.
    start_time = time.time()
    # in RMC, the "hidden state" is called "memory" instead, so use the name "memory"
    memory = model.module.initial_state(args.batch_size, trainable=True).to(device)

    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # transpose the data to [batch, seq]
        data = torch.t(data)

        # synchronize cuda for a proper speed benchmark
        if args.cuda:
            torch.cuda.synchronize()
        forward_start_time = time.time()

        model.zero_grad()

        # the forward pass of RMC just returns the loss and does not return logits
        # (a DataParallel code optimization)
        loss, memory = model(data, memory, targets)
        loss = torch.mean(loss)
        total_loss += loss.item()

        # synchronize cuda for a proper speed benchmark
        if args.cuda:
            torch.cuda.synchronize()
        forward_elapsed = time.time() - forward_start_time
        forward_elapsed_time += forward_elapsed

        loss.backward()
        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss / args.log_interval
            elapsed = time.time() - start_time
            printlog = ('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | '
                        'forward ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}').format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, forward_elapsed_time * 1000 / args.log_interval,
                cur_loss, math.exp(cur_loss))
            # print and save the log
            print(printlog)
            logger_train.write(printlog + '\n')
            logger_train.flush()
            total_loss = 0.
            # reset timers
            start_time = time.time()
            forward_elapsed_time = 0.


def export_onnx(path, batch_size, seq_len):
    print('The model is also exported in ONNX format at {}'.format(os.path.realpath(args.onnx_export)))
    model.eval()
    dummy_input = torch.LongTensor(seq_len * batch_size).zero_().view(-1, batch_size).to(device)
    # RMC exposes initial_state() rather than the init_hidden() of the LSTM
    # baseline this script was adapted from
    memory = model.module.initial_state(batch_size, trainable=False).to(device)
    torch.onnx.export(model, (dummy_input, memory), path)
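# A minimal illustration of get_batch on the same kind of toy data as the
# batchify check above (hypothetical sizes; targets are the inputs shifted
# one step, so the assertion holds for any args.bptt):
_src = batchify(torch.arange(24), 4)      # shape [6, 4]
_data, _target = get_batch(_src, 0)       # seq_len = min(args.bptt, 5)
assert _target.numel() == _data.numel()   # one next-token target per input token
del _src, _data, _target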
# Loop over epochs.
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
    print("training started...")
    if global_epoch > args.epochs:
        raise ValueError("global_epoch is higher than args.epochs when resuming training.")
    for epoch in range(global_epoch + 1, args.epochs + 1):
        global_epoch += 1
        epoch_start_time = time.time()
        train()
        val_loss = evaluate(val_data)
        print('-' * 89)
        testlog = '| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
            epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))
        print(testlog)
        logger_test.write(testlog + '\n')
        logger_test.flush()
        print('-' * 89)
        scheduler.step(val_loss)

        # Save the model if the validation loss is the best we've seen so far.
        # model_{} contains the state_dict and other states; model_dump_{} contains
        # all the dependencies for generate_rmc.py
        if not best_val_loss or val_loss < best_val_loss:
            try:
                os.remove(os.path.join(savepath, "model_{}.pt".format(best_epoch)))
                os.remove(os.path.join(savepath, "model_dump_{}.pt".format(best_epoch)))
            except FileNotFoundError:
                pass
            best_epoch = global_epoch
            torch.save(model, os.path.join(savepath, "model_dump_{}.pt".format(global_epoch)))
            with open(os.path.join(savepath, "model_{}.pt".format(global_epoch)), 'wb') as f:
                torch.save({"state_dict": model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "scheduler": scheduler.state_dict(),
                            "global_epoch": global_epoch,
                            "best_epoch": best_epoch}, f)
            best_val_loss = val_loss

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early: loading checkpoint from the best epoch {}...'.format(best_epoch))

# Load the best saved model.
with open(os.path.join(savepath, "model_{}.pt".format(best_epoch)), 'rb') as f:
    checkpoint = torch.load(f)
    model.load_state_dict(checkpoint["state_dict"])

# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
testlog = '| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss))
print(testlog)
logger_test.write(testlog + '\n')
logger_test.flush()
print('=' * 89)

if len(args.onnx_export) > 0:
    # Export the model in ONNX format.
    export_onnx(args.onnx_export, batch_size=1, seq_len=args.bptt)
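# To reuse the dumped model later (generate_rmc.py consumes it), a minimal
# sketch looks like the following; the epoch number is hypothetical, and
# map_location lets a GPU-trained dump be inspected elsewhere:
#
#   model = torch.load("model_dump_3.pt", map_location=device)
#   model.eval()
#   memory = model.module.initial_state(1, trainable=False).to(device)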