# coding: utf-8
from __future__ import division, print_function

from collections import OrderedDict
from time import time

import models
import data
import theano

try:
    import cPickle
except ImportError:
    import _pickle as cPickle

import sys
import os.path

try:
    input = raw_input
except NameError:
    pass

import theano.tensor as T
import numpy as np

MAX_EPOCHS = 50
MINIBATCH_SIZE = 128
L2_REG = 0.0
CLIPPING_THRESHOLD = 2.0
PATIENCE_EPOCHS = 1

"""
Bi-directional RNN with attention
For a sequence of N words, the model makes N punctuation decisions
(no punctuation before the first word, but there's a decision after
the last word, i.e. before </S>)
"""


def get_minibatch(file_name, batch_size, shuffle, with_pauses=False):

    dataset = data.load(file_name)

    if shuffle:
        np.random.shuffle(dataset)

    X_batch = []
    Y_batch = []
    if with_pauses:
        P_batch = []

    if len(dataset) < batch_size:
        print("WARNING: Not enough samples in '%s'. Reduce mini-batch size to %d "
              "or use a dataset with at least %d words." % (
                  file_name, len(dataset), MINIBATCH_SIZE * data.MAX_SEQUENCE_LEN))

    for subsequence in dataset:

        X_batch.append(subsequence[0])
        Y_batch.append(subsequence[1])
        if with_pauses:
            P_batch.append(subsequence[2])

        if len(X_batch) == batch_size:

            # Transpose, because the model assumes the first axis is time
            X = np.array(X_batch, dtype=np.int32).T
            Y = np.array(Y_batch, dtype=np.int32).T
            if with_pauses:
                P = np.array(P_batch, dtype=theano.config.floatX).T

            if with_pauses:
                yield X, Y, P
            else:
                yield X, Y

            X_batch = []
            Y_batch = []
            if with_pauses:
                P_batch = []
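
# A minimal sketch of the shapes get_minibatch yields (illustration only,
# never called by this script). It assumes, as the code above does, that
# data.load() returns equal-length subsequences of (word_ids, punct_ids);
# the fake dataset below is hypothetical. With two sequences of length
# three, X and Y come out time-major, i.e. shape (sequence_length,
# batch_size), which is why the batches are transposed above.
def _minibatch_shape_demo():
    fake_dataset = [([1, 2, 3], [0, 0, 1]),
                    ([4, 5, 6], [0, 1, 0])]
    X = np.array([s[0] for s in fake_dataset], dtype=np.int32).T
    Y = np.array([s[1] for s in fake_dataset], dtype=np.int32).T
    assert X.shape == (3, 2)  # rows are time steps, columns are batch items
    assert Y.shape == (3, 2)
    return X, Y
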
if __name__ == "__main__":

    if len(sys.argv) > 1:
        model_name = sys.argv[1]
    else:
        sys.exit("'Model name' argument missing!")

    if len(sys.argv) > 2:
        num_hidden = int(sys.argv[2])
    else:
        sys.exit("'Hidden layer size' argument missing!")

    if len(sys.argv) > 3:
        learning_rate = float(sys.argv[3])
    else:
        sys.exit("'Learning rate' argument missing!")

    model_file_name = "Model_%s_h%d_lr%s.pcl" % (model_name, num_hidden, learning_rate)

    print(num_hidden, learning_rate, model_file_name)

    word_vocabulary = data.read_vocabulary(data.WORD_VOCAB_FILE)
    punctuation_vocabulary = data.iterable_to_dict(data.PUNCTUATION_VOCABULARY)

    x = T.imatrix('x')
    y = T.imatrix('y')
    lr = T.scalar('lr')

    continue_with_previous = False
    if os.path.isfile(model_file_name):

        while True:
            resp = input("Found an existing model with the name %s. Do you want to:\n"
                         "[c]ontinue training the existing model?\n"
                         "[r]eplace the existing model and train a new one?\n"
                         "[e]xit?\n>" % model_file_name)
            resp = resp.lower().strip()
            if resp not in ('c', 'r', 'e'):
                continue
            if resp == 'e':
                sys.exit()
            elif resp == 'c':
                continue_with_previous = True
            break

    if continue_with_previous:
        print("Loading previous model state")

        net, state = models.load(model_file_name, MINIBATCH_SIZE, x)

        gsums, learning_rate, validation_ppl_history, starting_epoch, rng = state
        best_ppl = min(validation_ppl_history)
    else:
        rng = np.random
        rng.seed(1)

        print("Building model...")
        net = models.GRU(
            rng=rng,
            x=x,
            minibatch_size=MINIBATCH_SIZE,
            n_hidden=num_hidden,
            x_vocabulary=word_vocabulary,
            y_vocabulary=punctuation_vocabulary
        )

        starting_epoch = 0
        best_ppl = np.inf
        validation_ppl_history = []

        # Adagrad accumulators: one running sum of squared gradients per parameter
        gsums = [theano.shared(np.zeros_like(param.get_value(borrow=True)))
                 for param in net.params]

    cost = net.cost(y) + L2_REG * net.L2_sqr

    gparams = T.grad(cost, net.params)
    updates = OrderedDict()

    # Compute the global L2 norm of all gradients
    norm = T.sqrt(T.sum([T.sum(gparam ** 2) for gparam in gparams]))

    # Adagrad: "Adaptive subgradient methods for online learning and
    # stochastic optimization" (Duchi et al., 2011)
    for gparam, param, gsum in zip(gparams, net.params, gsums):
        # Gradient clipping: rescale to the threshold when the global norm exceeds it
        gparam = T.switch(
            T.ge(norm, CLIPPING_THRESHOLD),
            gparam / norm * CLIPPING_THRESHOLD,
            gparam
        )
        updates[gsum] = gsum + (gparam ** 2)
        updates[param] = param - lr * (gparam / (T.sqrt(updates[gsum] + 1e-6)))

    train_model = theano.function(
        inputs=[x, y, lr],
        outputs=cost,
        updates=updates
    )

    validate_model = theano.function(
        inputs=[x, y],
        outputs=net.cost(y)
    )

    print("Training...")
    for epoch in range(starting_epoch, MAX_EPOCHS):
        t0 = time()
        total_neg_log_likelihood = 0
        total_num_output_samples = 0
        iteration = 0
        for X, Y in get_minibatch(data.TRAIN_FILE, MINIBATCH_SIZE, shuffle=True):
            total_neg_log_likelihood += train_model(X, Y, learning_rate)
            total_num_output_samples += np.prod(Y.shape)
            iteration += 1
            if iteration % 100 == 0:
                sys.stdout.write("PPL: %.4f; Speed: %.2f sps\n" % (
                    np.exp(total_neg_log_likelihood / total_num_output_samples),
                    total_num_output_samples / max(time() - t0, 1e-100)))
                sys.stdout.flush()
        print("Total number of training labels: %d" % total_num_output_samples)

        total_neg_log_likelihood = 0
        total_num_output_samples = 0
        for X, Y in get_minibatch(data.DEV_FILE, MINIBATCH_SIZE, shuffle=False):
            total_neg_log_likelihood += validate_model(X, Y)
            total_num_output_samples += np.prod(Y.shape)
        print("Total number of validation labels: %d" % total_num_output_samples)

        # Perplexity = exp(mean negative log-likelihood per output label)
        ppl = np.exp(total_neg_log_likelihood / total_num_output_samples)
        validation_ppl_history.append(ppl)

        print("Validation perplexity is %s" % np.round(ppl, 4))

        if ppl <= best_ppl:
            best_ppl = ppl
            net.save(model_file_name,
                     gsums=gsums,
                     learning_rate=learning_rate,
                     validation_ppl_history=validation_ppl_history,
                     best_validation_ppl=best_ppl,
                     epoch=epoch,
                     random_state=rng.get_state())
        elif best_ppl not in validation_ppl_history[-PATIENCE_EPOCHS:]:
            # Early stopping: no new best within the last PATIENCE_EPOCHS epochs
            print("Finished!")
            print("Best validation perplexity was %s" % best_ppl)
            break
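
# For reference, a plain-NumPy sketch of one clipped-Adagrad step as built
# symbolically above (an illustration under the same constants; it is not
# part of the Theano graph and is never called by this script): gradients
# are rescaled to CLIPPING_THRESHOLD when their global L2 norm exceeds it,
# squared gradients accumulate per parameter, and each parameter moves by
# lr * g / sqrt(accumulator + 1e-6).
def _adagrad_step_sketch(params, grads, accumulators, lr_value, eps=1e-6):
    norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if norm >= CLIPPING_THRESHOLD:
        grads = [g / norm * CLIPPING_THRESHOLD for g in grads]
    new_params = []
    for p, g, acc in zip(params, grads, accumulators):
        acc += g ** 2  # mirrors updates[gsum] = gsum + gparam ** 2
        new_params.append(p - lr_value * g / np.sqrt(acc + eps))
    return new_params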