from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.autograd as autograd

from . import model


class Seq2Seq(model.Model):

    def __init__(self, freq_dim, vocab_size, config):
        super().__init__(freq_dim, config)

        # For decoding
        decoder_cfg = config["decoder"]
        rnn_dim = self.encoder_dim
        embed_dim = decoder_cfg["embedding_dim"]
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dec_rnn = nn.GRUCell(input_size=embed_dim,
                                  hidden_size=rnn_dim)

        self.attend = NNAttention(rnn_dim,
                                  log_t=decoder_cfg.get("log_t", False))

        self.sample_prob = decoder_cfg.get("sample_prob", 0)
        self.scheduled_sampling = (self.sample_prob != 0)

        # *NB* we predict vocab_size - 1 classes since we
        # never need to predict the start-of-sequence token.
        self.fc = model.LinearND(rnn_dim, vocab_size - 1)

    def set_eval(self):
        """
        Set the model to evaluation mode.
        """
        self.eval()
        self.volatile = True
        self.scheduled_sampling = False

    def set_train(self):
        """
        Set the model to training mode.
        """
        self.train()
        self.volatile = False
        self.scheduled_sampling = (self.sample_prob != 0)

    def loss(self, batch):
        x, y = self.collate(*batch)
        if self.is_cuda:
            x = x.cuda()
            y = y.cuda()
        out, alis = self.forward_impl(x, y)

        batch_size, _, out_dim = out.size()
        out = out.view((-1, out_dim))
        # Shift the targets by one: the start-of-sequence token is
        # never predicted.
        y = y[:, 1:].contiguous().view(-1)
        loss = nn.functional.cross_entropy(out, y,
                                           size_average=False)
        loss = loss / batch_size
        return loss

    def forward_impl(self, x, y):
        x = self.encode(x)
        out, alis = self.decode(x, y)
        return out, alis

    def forward(self, batch):
        x, y = self.collate(*batch)
        if self.is_cuda:
            x = x.cuda()
            y = y.cuda()
        return self.forward_impl(x, y)[0]

    def decode(self, x, y):
        """
        x should be shape (batch, time, hidden dimension)
        y should be shape (batch, label sequence length)
        """
        inputs = self.embedding(y[:, :-1])

        out = []
        aligns = []
        hx = torch.zeros((x.shape[0], x.shape[2]), requires_grad=False)
        if self.is_cuda:
            hx = hx.cuda()
        ax = None
        sx = None
        for t in range(y.size()[1] - 1):
            # Scheduled sampling: with probability sample_prob feed the
            # model's own previous prediction instead of the ground truth.
            sample = (out and self.scheduled_sampling)
            if sample and random.random() < self.sample_prob:
                ix = torch.max(out[-1], dim=2)[1]
                ix = self.embedding(ix)
            else:
                ix = inputs[:, t:t+1, :]

            if sx is not None:
                ix = ix + sx

            hx = self.dec_rnn(ix.squeeze(dim=1), hx)
            ox = hx.unsqueeze(dim=1)
            sx, ax = self.attend(x, ox, ax)
            aligns.append(ax)
            out.append(self.fc(ox + sx))

        out = torch.cat(out, dim=1)
        aligns = torch.stack(aligns, dim=1)
        return out, aligns

    def decode_step(self, x, y, state=None, softmax=False):
        """
        x should be shape (batch, time, hidden dimension)
        y should be shape (batch, 1): the previous label for
        each sequence in the batch.
        """
        if state is None:
            hx = torch.zeros((x.shape[0], x.shape[2]), requires_grad=False)
            if self.is_cuda:
                hx = hx.cuda()
            ax = None
            sx = None
        else:
            hx, ax, sx = state

        ix = self.embedding(y)
        if sx is not None:
            ix = ix + sx

        hx = self.dec_rnn(ix.squeeze(dim=1), hx=hx)
        ox = hx.unsqueeze(dim=1)
        sx, ax = self.attend(x, ox, ax=ax)
        out = ox + sx
        out = self.fc(out.squeeze(dim=1))
        if softmax:
            out = nn.functional.log_softmax(out, dim=1)
        return out, (hx, ax, sx)

    def predict(self, batch):
        probs = self(batch)
        argmaxs = torch.max(probs, dim=2)[1]
        argmaxs = argmaxs.cpu().data.numpy()
        return [seq.tolist() for seq in argmaxs]

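    # Note: decode_step threads its state tuple (hx, ax, sx) -- the decoder
    # hidden state, the previous alignment, and the previous attention
    # summary -- through successive calls. The inference routines below
    # run the decoder one token at a time in exactly this way:
    #
    #   out, state = self.decode_step(x, y, state=None)    # first step
    #   out, state = self.decode_step(x, y, state=state)   # later steps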
    def infer_decode(self, x, y, end_tok, max_len):
        probs = []
        argmaxs = [y]
        state = None
        for e in range(max_len):
            out, state = self.decode_step(x, y, state=state)
            probs.append(out)
            y = torch.max(out, dim=1)[1]
            y = y.unsqueeze(dim=1)
            argmaxs.append(y)
            # Stop once every sequence in the batch has produced
            # the end token.
            if torch.sum(y.data == end_tok) == y.numel():
                break

        probs = torch.cat(probs)
        argmaxs = torch.cat(argmaxs, dim=1)
        return probs, argmaxs

    def infer(self, batch, max_len=200):
        """
        Infer a likely output. No beam search yet.
        """
        x, y = self.collate(*batch)
        end_tok = y.data[0, -1]  # TODO
        t = y
        if self.is_cuda:
            x = x.cuda()
            t = y.cuda()
        x = self.encode(x)

        # needs to be the start token, TODO
        y = t[:, 0:1]
        _, argmaxs = self.infer_decode(x, y, end_tok, max_len)
        argmaxs = argmaxs.cpu().data.numpy()
        return [seq.tolist() for seq in argmaxs]

    def beam_search(self, batch, beam_size=10, max_len=200):
        x, y = self.collate(*batch)
        start_tok = y.data[0, 0]
        end_tok = y.data[0, -1]  # TODO
        if self.is_cuda:
            x = x.cuda()
            y = y.cuda()
        x = self.encode(x)

        y = y[:, 0:1].clone()

        beam = [((start_tok,), 0, None)]
        complete = []
        for _ in range(max_len):
            new_beam = []
            for hyp, score, state in beam:
                y[0] = hyp[-1]
                out, state = self.decode_step(x, y, state=state, softmax=True)
                out = out.cpu().data.numpy().squeeze(axis=0).tolist()
                for i, p in enumerate(out):
                    new_score = score + p
                    new_hyp = hyp + (i,)
                    new_beam.append((new_hyp, new_score, state))
            new_beam = sorted(new_beam, key=lambda x: x[1], reverse=True)

            # Remove complete hypotheses
            for cand in new_beam[:beam_size]:
                if cand[0][-1] == end_tok:
                    complete.append(cand)
            beam = [cand for cand in new_beam
                    if cand[0][-1] != end_tok]
            beam = beam[:beam_size]
            if len(beam) == 0:
                break

            # Stopping criterion: complete contains beam_size
            # candidates more probable than anything left in the beam.
            if sum(c[1] > beam[0][1] for c in complete) >= beam_size:
                break

        complete = sorted(complete, key=lambda x: x[1], reverse=True)
        if len(complete) == 0:
            complete = beam
        hyp, score, _ = complete[0]
        return [hyp]

    def collate(self, inputs, labels):
        inputs = model.zero_pad_concat(inputs)
        labels = end_pad_concat(labels)
        inputs = torch.from_numpy(inputs)
        labels = torch.from_numpy(labels)
        if self.volatile:
            inputs.volatile = True
            labels.volatile = True
        return inputs, labels


def end_pad_concat(labels):
    # Assumes the last item in each example is the end token.
    batch_size = len(labels)
    end_tok = labels[0][-1]
    max_len = max(len(l) for l in labels)
    cat_labels = np.full((batch_size, max_len),
                         fill_value=end_tok, dtype=np.int64)
    for e, l in enumerate(labels):
        cat_labels[e, :len(l)] = l
    return cat_labels

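# For illustration, with made-up label sequences where 1 is the end token,
# end_pad_concat pads the shorter example with the end token taken from the
# first example:
#
#   end_pad_concat([[4, 7, 1], [5, 1]])
#   # -> array([[4, 7, 1],
#   #           [5, 1, 1]])
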
class Attention(nn.Module):

    def __init__(self, kernel_size=11, log_t=False):
        """
        Module which performs a single attention step along the
        second axis of a given encoded input. The module uses
        both 'content' and 'location' based attention.

        The 'content' based attention is an inner product of the
        decoder hidden state with each time-step of the encoder
        state.

        The 'location' based attention performs a 1D convolution
        on the previous attention vector and adds the result to
        the next attention vector prior to normalization.

        *NB* Should compute attention differently if using cuda or
        cpu based on performance. See
        https://gist.github.com/awni/9989dd31642d42405903dec8ab91d1f0
        """
        super(Attention, self).__init__()
        assert kernel_size % 2 == 1, \
            "Kernel size should be odd for 'same' conv."
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=padding)
        self.log_t = log_t

    def forward(self, eh, dhx, ax=None):
        """
        Arguments:
            eh (FloatTensor): the encoder hidden state with
                shape (batch size, time, hidden dimension).
            dhx (FloatTensor): one time step of the decoder hidden
                state with shape (batch size, hidden dimension).
                The hidden dimension must match that of the
                encoder state.
            ax (FloatTensor): one time step of the attention
                vector.

        Returns the summary of the encoded hidden state
        and the corresponding alignment.
        """
        # Content attention: inner product of the decoder slice
        # with every encoder slice.
        pax = eh * dhx
        pax = torch.sum(pax, dim=2)

        # Location attention: convolve the previous alignment and
        # add it in before normalizing.
        if ax is not None:
            ax = ax.unsqueeze(dim=1)
            ax = self.conv(ax).squeeze(dim=1)
            pax = pax + ax

        if self.log_t:
            log_t = math.log(pax.size()[1])
            pax = log_t * pax
        ax = nn.functional.softmax(pax, dim=1)

        # At this point ax has size (batch size, time).
        # Reduce the encoder state across time, weighting each
        # slice by its corresponding value in ax.
        sx = ax.unsqueeze(2)
        sx = torch.sum(eh * sx, dim=1, keepdim=True)
        return sx, ax


class ProdAttention(nn.Module):

    def __init__(self):
        super(ProdAttention, self).__init__()

    def forward(self, eh, dhx, ax=None):
        # Pure content attention without a location term.
        pax = eh * dhx
        pax = torch.sum(pax, dim=2)
        ax = nn.functional.softmax(pax, dim=1)
        sx = ax.unsqueeze(2)
        sx = torch.sum(eh * sx, dim=1, keepdim=True)
        return sx, ax


class NNAttention(nn.Module):

    def __init__(self, n_channels, kernel_size=15, log_t=False):
        super(NNAttention, self).__init__()
        assert kernel_size % 2 == 1, \
            "Kernel size should be odd for 'same' conv."
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(1, n_channels, kernel_size, padding=padding)
        self.nn = nn.Sequential(
            nn.ReLU(),
            model.LinearND(n_channels, 1))
        self.log_t = log_t

    def forward(self, eh, dhx, ax=None):
        # Additive attention scored by a small network; the previous
        # alignment is convolved and added in as a location term.
        pax = eh + dhx
        if ax is not None:
            ax = ax.unsqueeze(dim=1)
            ax = self.conv(ax).transpose(1, 2)
            pax = pax + ax

        pax = self.nn(pax)
        pax = pax.squeeze(dim=2)
        if self.log_t:
            log_t = math.log(pax.size()[1])
            pax = log_t * pax
        ax = nn.functional.softmax(pax, dim=1)

        sx = ax.unsqueeze(2)
        sx = torch.sum(eh * sx, dim=1, keepdim=True)
        return sx, ax

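if __name__ == "__main__":
    # Illustrative smoke-test sketch of a single attention step. The shapes
    # below are arbitrary choices for the example, not values taken from any
    # config; because of the relative import (`from . import model`) this
    # only runs when the file is executed inside its package.
    batch, time, hidden = 4, 50, 256
    eh = torch.randn(batch, time, hidden)   # encoder states
    dhx = torch.randn(batch, 1, hidden)     # one decoder time step

    attend = NNAttention(hidden)
    sx, ax = attend(eh, dhx)        # first step: no previous alignment
    sx, ax = attend(eh, dhx, ax)    # later steps feed the alignment back
    print(sx.shape, ax.shape)       # (batch, 1, hidden), (batch, time)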