import time

import torch
from torch import nn, optim
# NOTE: Variable is no longer used in this module (tensors and Variables were
# merged in PyTorch 0.4); the import is kept so nothing that imports it from
# here breaks.
from torch.autograd import Variable  # noqa: F401
from torch.nn import functional as F
from tqdm import tqdm

from logger import get_logger
from utils import (batch_generator, encode_text, generate_seed, ID2CHAR,
                   main, make_dirs, VOCAB_SIZE)

logger = get_logger(__name__)


class Model(nn.Module):
    """
    character embeddings LSTM text generation model.

    pipeline: embedding -> dropout -> multi-layer LSTM -> dropout -> linear
    projection to vocabulary logits. inputs are sequence-major, i.e. shaped
    [seq_len, batch_size].
    """

    def __init__(self, vocab_size=VOCAB_SIZE, embedding_size=32,
                 rnn_size=128, num_layers=2, drop_rate=0.0):
        super(Model, self).__init__()
        # constructor arguments are stored so that save() / load() can
        # rebuild an identical architecture from a checkpoint.
        self.args = {"vocab_size": vocab_size,
                     "embedding_size": embedding_size,
                     "rnn_size": rnn_size,
                     "num_layers": num_layers,
                     "drop_rate": drop_rate}
        self.encoder = nn.Embedding(vocab_size, embedding_size)
        self.dropout = nn.Dropout(drop_rate)
        self.rnn = nn.LSTM(embedding_size, rnn_size, num_layers,
                           dropout=drop_rate)
        self.decoder = nn.Linear(rnn_size, vocab_size)

    def forward(self, inputs, state):
        """
        computes unnormalised next-character scores.

        inputs: index tensor, shape [seq_len, batch_size].
        state: (h, c) tuple, each [num_layers, batch_size, rnn_size].
        returns (logits, state) with logits flattened to
        [seq_len * batch_size, vocab_size] and state the updated (h, c).
        """
        # shape: [seq_len, batch_size, embedding_size]
        embed_seq = self.dropout(self.encoder(inputs))
        # rnn_out shape: [seq_len, batch_size, rnn_size]
        rnn_out, state = self.rnn(embed_seq, state)
        rnn_out = self.dropout(rnn_out)
        # flatten time and batch so the linear layer sees one row per step.
        logits = self.decoder(rnn_out.reshape(-1, rnn_out.size(2)))
        return logits, state

    def predict(self, input, hidden):
        """
        like forward(), but returns probabilities shaped per time step.

        input: index tensor, shape [seq_len, batch_size].
        hidden: (h, c) tuple as accepted by forward().
        returns (probs, hidden) with probs shaped
        [seq_len, batch_size, vocab_size].
        """
        logits, hidden = self(input, hidden)
        # explicit dim: the implicit-dim form of softmax is deprecated.
        probs = F.softmax(logits, dim=1)
        probs = probs.view(input.size(0), input.size(1), probs.size(1))
        return probs, hidden

    def init_state(self, batch_size=1):
        """
        initialises rnn states.

        returns a (h, c) tuple of zero tensors, each shaped
        [num_layers, batch_size, rnn_size], allocated with the same dtype
        and device as the model's own weights.
        """
        shape = (self.args["num_layers"], batch_size, self.args["rnn_size"])
        reference = self.decoder.weight
        return (reference.new_zeros(shape), reference.new_zeros(shape))

    def save(self, checkpoint_path="model.ckpt"):
        """
        saves model and args to checkpoint_path.
        """
        checkpoint = {"args": self.args, "state_dict": self.state_dict()}
        with open(checkpoint_path, "wb") as f:
            torch.save(checkpoint, f)
        logger.info("model saved: %s.", checkpoint_path)

    @classmethod
    def load(cls, checkpoint_path):
        """
        loads model from checkpoint_path.

        the checkpoint must be a dict with "args" and "state_dict" keys, as
        written by save().
        """
        # SECURITY: torch.load unpickles data; only load trusted checkpoints.
        # map_location lets a checkpoint saved from a GPU model be restored
        # on a CPU-only machine (the freshly built model lives on CPU anyway).
        with open(checkpoint_path, "rb") as f:
            checkpoint = torch.load(f, map_location="cpu")
        model = cls(**checkpoint["args"])
        model.load_state_dict(checkpoint["state_dict"])
        logger.info("model loaded: %s.", checkpoint_path)
        return model


def sample_from_probs(probs, top_n=10):
    """
    truncated weighted random choice.

    probs: 1-D tensor of non-negative weights, one per vocabulary index.
    top_n: number of highest-probability candidates to sample from; values
        larger than the vocabulary are clamped, and values <= 0 disable
        truncation (sample from the full distribution).
    returns a LongTensor of shape [1] holding the sampled index.

    the input tensor is left untouched.
    """
    vocab_size = probs.size(-1)
    k = vocab_size if top_n <= 0 else min(top_n, vocab_size)
    top_probs, top_indices = torch.topk(probs, k)
    # multinomial accepts unnormalised weights, so no renormalisation needed.
    choice = torch.multinomial(top_probs, 1)
    return top_indices[choice]


def generate_text(model, seed, length=512, top_n=10):
    """
    generates text of specified length from trained model
    with given seed character sequence.

    returns the seed followed by `length` sampled characters.
    raises ValueError if the seed encodes to an empty sequence.
    """
    logger.info("generating %s characters from top %s choices.",
                length, top_n)
    logger.info('generating with seed: "%s".', seed)
    encoded = torch.from_numpy(encode_text(seed))
    if encoded.size(0) == 0:
        raise ValueError("seed must contain at least one character.")
    encoded = encoded.to(next(model.parameters()).device)

    pieces = [seed]
    model.eval()
    # sampling never needs gradients; no_grad replaces the removed
    # `volatile=True` flag and avoids building autograd graphs.
    with torch.no_grad():
        state = model.init_state()
        # warm up the rnn state on everything except the last seed
        # character; skipped for one-character seeds, since the LSTM
        # cannot take a zero-length sequence.
        prefix = encoded[:-1]
        if prefix.size(0) > 0:
            # input shape: [seq_len, 1]
            _, state = model.predict(prefix.unsqueeze(1), state)

        next_index = encoded[-1:]
        for _ in range(length):
            # input shape: [1, 1], output shape: [1, 1, vocab_size]
            probs, state = model.predict(next_index.unsqueeze(1), state)
            next_index = sample_from_probs(probs[0, 0], top_n)
            # .item() yields a plain int, usable as a list index or dict key.
            pieces.append(ID2CHAR[next_index.item()])

    generated = "".join(pieces)
    logger.info("generated text: \n%s\n", generated)
    return generated


def train_main(args):
    """
    trains model specified in args.
    main method for train subcommand.

    reads args.text_path, restores or builds a model, trains it for
    args.num_epochs epochs, checkpoints to args.checkpoint_path after every
    epoch and returns the trained model.
    """
    # load text
    with open(args.text_path) as f:
        text = f.read()
    logger.info("corpus length: %s.", len(text))

    # load or build model
    if args.restore:
        logger.info("restoring model.")
        # args.restore is either True (reuse checkpoint_path) or a path.
        if args.restore is True:
            load_path = args.checkpoint_path
        else:
            load_path = args.restore
        model = Model.load(load_path)
    else:
        model = Model(vocab_size=VOCAB_SIZE,
                      embedding_size=args.embedding_size,
                      rnn_size=args.rnn_size,
                      num_layers=args.num_layers,
                      drop_rate=args.drop_rate)

    # make checkpoint directory
    make_dirs(args.checkpoint_path)
    model.save(args.checkpoint_path)

    # loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

    # training start
    num_batches = (len(text) - 1) // (args.batch_size * args.seq_len)
    data_iter = batch_generator(encode_text(text),
                                args.batch_size, args.seq_len)
    state = model.init_state(args.batch_size)
    logger.info("start of training.")
    time_train = time.time()

    for i in range(args.num_epochs):
        batch_losses = []
        time_epoch = time.time()
        # generate_text() at the end of each epoch switches the model to
        # eval mode, so training mode is re-enabled at every epoch start.
        model.train()

        # training epoch
        progress = tqdm(range(num_batches),
                        desc="epoch {}/{}".format(i + 1, args.num_epochs))
        for _ in progress:
            # prepare inputs; batches are transposed to the sequence-major
            # layout the model expects.
            x, y = next(data_iter)
            x = torch.from_numpy(x).t()
            y = torch.from_numpy(y).t()
            # carry the state values across batches, but cut the autograd
            # history so backprop stops at the batch boundary.
            state = tuple(s.detach() for s in state)
            model.zero_grad()
            # calculate loss; targets are flattened in the same
            # sequence-major order as the logits.
            logits, state = model(x, state)
            loss = criterion(logits, y.reshape(-1))
            batch_losses.append(loss.item())
            # calculate gradients
            loss.backward()
            # clip gradient norm
            nn.utils.clip_grad_norm_(model.parameters(), args.clip_norm)
            # apply gradient update
            optimizer.step()

        # logs
        duration_epoch = time.time() - time_epoch
        if batch_losses:
            mean_loss = sum(batch_losses) / len(batch_losses)
        else:
            # corpus too short for a single batch: nothing was trained.
            mean_loss = float("nan")
        logger.info("epoch: %s, duration: %ds, loss: %.6g.",
                    i + 1, duration_epoch, mean_loss)
        # checkpoint
        model.save(args.checkpoint_path)
        # generate text
        seed = generate_seed(text)
        generate_text(model, seed)

    # training end
    duration_train = time.time() - time_train
    logger.info("end of training, duration: %ds.", duration_train)
    # generate text
    seed = generate_seed(text)
    generate_text(model, seed, 1024, 3)
    return model


def generate_main(args):
    """
    generates text from trained model specified in args.
    main method for generate subcommand.

    returns the generated text.
    """
    # load model
    inference_model = Model.load(args.checkpoint_path)

    # create seed if not specified
    if args.seed is None:
        with open(args.text_path) as f:
            text = f.read()
        seed = generate_seed(text)
        logger.info("seed sequence generated from %s.", args.text_path)
    else:
        seed = args.seed

    return generate_text(inference_model, seed, args.length, args.top_n)


if __name__ == "__main__":
    main("PyTorch", train_main, generate_main)