import argparse
import math

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable

import data
import model

from utils import batchify, get_batch, repackage_hidden

parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')
parser.add_argument('--data', type=str, default='data/penn',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (LSTM, QRNN)')
parser.add_argument('--save', type=str, default='best.pt',
                    help='model to use the pointer over')
parser.add_argument('--cuda', action='store_false',
                    help='use CUDA (enabled by default; pass --cuda to run on CPU)')
parser.add_argument('--bptt', type=int, default=5000,
                    help='sequence length')
parser.add_argument('--window', type=int, default=3785,
                    help='pointer window length')
parser.add_argument('--theta', type=float, default=0.6625523432485668,
                    help='mix between uniform distribution and pointer softmax distribution over previous words')
parser.add_argument('--lambdasm', type=float, default=0.12785920428335693,
                    help='linear mix between only pointer (1) and only vocab (0) distribution')
args = parser.parse_args()

###############################################################################
# Load data
###############################################################################

corpus = data.Corpus(args.data)

eval_batch_size = 1
test_batch_size = 1
#train_data = batchify(corpus.train, args.batch_size)
val_data = batchify(corpus.valid, test_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
criterion = nn.CrossEntropyLoss()


def one_hot(idx, size, cuda=args.cuda):
    # Return a (1, size) one-hot row vector with a 1 at token index idx.
    # Defaults to the runtime CUDA flag so CPU-only runs don't crash.
    a = np.zeros((1, size), np.float32)
    a[0][idx] = 1
    v = Variable(torch.from_numpy(a))
    if cuda:
        v = v.cuda()
    return v
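# The mixture computed in evaluate() below is in the style of the continuous
# cache pointer (Grave et al., 2016): attention over the last `window` cached
# hidden states selects which previously seen words to point at, and the final
# distribution linearly interpolates that pointer distribution with the vocab
# softmax. A minimal sketch on toy tensors, kept as comments so it never runs;
# H, h_t, cached_one_hots, p_vocab, and nhid are illustrative names only, not
# part of this script:
#
#   H = torch.randn(window, nhid)                  # cached RNN hidden states
#   h_t = torch.randn(nhid)                        # current hidden state
#   cached_one_hots = torch.eye(ntokens)[:window]  # one-hots of cached targets
#   attn = torch.nn.functional.softmax(theta * torch.mv(H, h_t)).view(-1, 1)
#   p_ptr = (attn.expand_as(cached_one_hots) * cached_one_hots).sum(0)
#   p = lambdah * p_ptr + (1 - lambdah) * p_vocab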
def evaluate(data_source, batch_size=10, window=args.window):
    # Turn on evaluation mode which disables dropout.
    if args.model == 'QRNN':
        model.reset()
    model.eval()
    total_loss = 0
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    next_word_history = None
    pointer_history = None
    for i in range(0, data_source.size(0) - 1, args.bptt):
        if i > 0:
            print(i, len(data_source), math.exp(total_loss / i))
        data, targets = get_batch(data_source, i, evaluation=True, args=args)
        output, hidden, rnn_outs, _ = model(data, hidden, return_h=True)
        rnn_out = rnn_outs[-1].squeeze()
        output_flat = output.view(-1, ntokens)
        ###
        # Fill pointer history: append the one-hot targets and the hidden
        # states that produced them to the running cache.
        start_idx = len(next_word_history) if next_word_history is not None else 0
        one_hot_targets = torch.cat([one_hot(t.data[0], ntokens) for t in targets])
        next_word_history = one_hot_targets if next_word_history is None else torch.cat([next_word_history, one_hot_targets])
        pointer_history = Variable(rnn_out.data) if pointer_history is None else torch.cat([pointer_history, Variable(rnn_out.data)], dim=0)
        ###
        # Built-in cross entropy:
        # total_loss += len(data) * criterion(output_flat, targets).data[0]
        ###
        # Manual cross entropy:
        # softmax_output_flat = torch.nn.functional.softmax(output_flat)
        # soft = torch.gather(softmax_output_flat, dim=1, index=targets.view(-1, 1))
        # entropy = -torch.log(soft)
        # total_loss += len(data) * entropy.mean().data[0]
        ###
        # Pointer manual cross entropy
        loss = 0
        softmax_output_flat = torch.nn.functional.softmax(output_flat)
        for idx, vocab_loss in enumerate(softmax_output_flat):
            p = vocab_loss
            if start_idx + idx > window:
                # Attend over the last `window` cached hidden states, then
                # form the pointer distribution as the attention-weighted sum
                # of the corresponding one-hot target vectors.
                valid_next_word = next_word_history[start_idx + idx - window:start_idx + idx]
                valid_pointer_history = pointer_history[start_idx + idx - window:start_idx + idx]
                logits = torch.mv(valid_pointer_history, rnn_out[idx])
                theta = args.theta
                ptr_attn = torch.nn.functional.softmax(theta * logits).view(-1, 1)
                ptr_dist = (ptr_attn.expand_as(valid_next_word) * valid_next_word).sum(0).squeeze()
                lambdah = args.lambdasm
                p = lambdah * ptr_dist + (1 - lambdah) * vocab_loss
            ###
            target_loss = p[targets[idx].data]
            loss += (-torch.log(target_loss)).data[0]
        total_loss += loss / batch_size
        ###
        hidden = repackage_hidden(hidden)
        # Keep only the most recent `window` entries of the cache.
        next_word_history = next_word_history[-window:]
        pointer_history = pointer_history[-window:]
    return total_loss / len(data_source)


# Load the best saved model.
with open(args.save, 'rb') as f:
    if not args.cuda:
        model = torch.load(f, map_location=lambda storage, loc: storage)
    else:
        model = torch.load(f)
print(model)

# Run on val data.
val_loss = evaluate(val_data, test_batch_size)
print('=' * 89)
print('| End of pointer | val loss {:5.2f} | val ppl {:8.2f}'.format(
    val_loss, math.exp(val_loss)))
print('=' * 89)

# Run on test data.
test_loss = evaluate(test_data, test_batch_size)
print('=' * 89)
print('| End of pointer | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)
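# Example invocation (the script and checkpoint names are illustrative; the
# flag values shown are simply this file's own argparse defaults):
#
#   python pointer.py --data data/penn --save best.pt --window 3785 --bptt 5000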