import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F

import s2s.modules
from torch.nn.utils.rnn import pad_packed_sequence as unpack
from torch.nn.utils.rnn import pack_padded_sequence as pack

try:
    import ipdb
except ImportError:
    pass


class Encoder(nn.Module):
    def __init__(self, opt, dicts):
        self.layers = opt.layers
        self.num_directions = 2 if opt.brnn else 1
        assert opt.enc_rnn_size % self.num_directions == 0
        self.hidden_size = opt.enc_rnn_size // self.num_directions
        input_size = opt.word_vec_size

        super(Encoder, self).__init__()
        self.word_lut = nn.Embedding(dicts.size(),
                                     opt.word_vec_size,
                                     padding_idx=s2s.Constants.PAD)
        self.bio_lut = nn.Embedding(8, 16, padding_idx=s2s.Constants.PAD)  # TODO: Fix this magic number
        self.feat_lut = nn.Embedding(64, 16, padding_idx=s2s.Constants.PAD)  # TODO: Fix this magic number
        # word embedding + BIO embedding + three 16-dim lexical feature embeddings
        input_size = input_size + 16 + 16 * 3
        self.rnn = nn.GRU(input_size, self.hidden_size,
                          num_layers=opt.layers,
                          dropout=opt.dropout,
                          bidirectional=opt.brnn)

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_enc is not None:
            pretrained = torch.load(opt.pre_word_vecs_enc)
            self.word_lut.weight.data.copy_(pretrained)

    def forward(self, input, bio, feats, hidden=None):
        """
        input: (wrap(srcBatch), wrap(srcBioBatch), lengths)
        """
        lengths = input[-1].data.view(-1).tolist()  # lengths data is wrapped inside a Variable
        wordEmb = self.word_lut(input[0])
        bioEmb = self.bio_lut(bio[0])
        featsEmb = [self.feat_lut(feat) for feat in feats[0]]
        featsEmb = torch.cat(featsEmb, dim=-1)
        input_emb = torch.cat((wordEmb, bioEmb, featsEmb), dim=-1)
        emb = pack(input_emb, lengths)
        outputs, hidden_t = self.rnn(emb, hidden)
        if isinstance(input, tuple):
            outputs = unpack(outputs)[0]
        return hidden_t, outputs


class StackedGRU(nn.Module):
    def __init__(self, num_layers, input_size, rnn_size, dropout):
        super(StackedGRU, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.num_layers = num_layers
        self.layers = nn.ModuleList()

        for i in range(num_layers):
            self.layers.append(nn.GRUCell(input_size, rnn_size))
            input_size = rnn_size

    def forward(self, input, hidden):
        h_0 = hidden
        h_1 = []
        for i, layer in enumerate(self.layers):
            h_1_i = layer(input, h_0[i])
            input = h_1_i
            # apply dropout between layers, but not after the last one
            if i + 1 != self.num_layers:
                input = self.dropout(input)
            h_1 += [h_1_i]

        h_1 = torch.stack(h_1)

        return input, h_1


class Decoder(nn.Module):
    def __init__(self, opt, dicts):
        self.opt = opt
        self.layers = opt.layers
        self.input_feed = opt.input_feed
        input_size = opt.word_vec_size
        if self.input_feed:
            input_size += opt.enc_rnn_size

        super(Decoder, self).__init__()
        self.word_lut = nn.Embedding(dicts.size(),
                                     opt.word_vec_size,
                                     padding_idx=s2s.Constants.PAD)
        self.rnn = StackedGRU(opt.layers, input_size, opt.dec_rnn_size, opt.dropout)
        self.attn = s2s.modules.ConcatAttention(opt.enc_rnn_size, opt.dec_rnn_size, opt.att_vec_size)
        self.dropout = nn.Dropout(opt.dropout)
        self.readout = nn.Linear((opt.enc_rnn_size + opt.dec_rnn_size + opt.word_vec_size),
                                 opt.dec_rnn_size)
        self.maxout = s2s.modules.MaxOut(opt.maxout_pool_size)
        self.maxout_pool_size = opt.maxout_pool_size
        self.copySwitch = nn.Linear(opt.enc_rnn_size + opt.dec_rnn_size, 1)

        self.hidden_size = opt.dec_rnn_size

    def load_pretrained_vectors(self, opt):
        if opt.pre_word_vecs_dec is not None:
            pretrained = torch.load(opt.pre_word_vecs_dec)
            self.word_lut.weight.data.copy_(pretrained)

    def forward(self, input, hidden, context, src_pad_mask, init_att):
        emb = self.word_lut(input)

        g_outputs = []
        c_outputs = []
        copyGateOutputs = []
        cur_context = init_att
        self.attn.applyMask(src_pad_mask)
        precompute = None
        for emb_t in emb.split(1):
            emb_t = emb_t.squeeze(0)
            input_emb = emb_t
            if self.input_feed:
                input_emb = torch.cat([emb_t, cur_context], 1)
            output, hidden = self.rnn(input_emb, hidden)
            cur_context, attn, precompute = self.attn(output, context.transpose(0, 1), precompute)
            # copy gate: probability of copying a source word at this step
            copyProb = self.copySwitch(torch.cat((output, cur_context), dim=1))
            copyProb = F.sigmoid(copyProb)

            readout = self.readout(torch.cat((emb_t, output, cur_context), dim=1))
            maxout = self.maxout(readout)
            output = self.dropout(maxout)
            g_outputs += [output]
            c_outputs += [attn]
            copyGateOutputs += [copyProb]

        g_outputs = torch.stack(g_outputs)
        c_outputs = torch.stack(c_outputs)
        copyGateOutputs = torch.stack(copyGateOutputs)
        return g_outputs, c_outputs, copyGateOutputs, hidden, attn, cur_context


class DecInit(nn.Module):
    def __init__(self, opt):
        super(DecInit, self).__init__()
        self.num_directions = 2 if opt.brnn else 1
        assert opt.enc_rnn_size % self.num_directions == 0
        self.enc_rnn_size = opt.enc_rnn_size
        self.dec_rnn_size = opt.dec_rnn_size
        self.initer = nn.Linear(self.enc_rnn_size // self.num_directions, self.dec_rnn_size)
        self.tanh = nn.Tanh()

    def forward(self, last_enc_h):
        # batchSize = last_enc_h.size(0)
        # dim = last_enc_h.size(1)
        return self.tanh(self.initer(last_enc_h))


class NMTModel(nn.Module):
    def __init__(self, encoder, decoder, decIniter):
        super(NMTModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.decIniter = decIniter

    def make_init_att(self, context):
        batch_size = context.size(1)
        h_size = (batch_size, self.encoder.hidden_size * self.encoder.num_directions)
        return Variable(context.data.new(*h_size).zero_(), requires_grad=False)

    def forward(self, input):
        """
        (wrap(srcBatch), lengths), \
        (wrap(bioBatch), lengths), ((wrap(x) for x in featBatches), lengths), \
        (wrap(tgtBatch), wrap(copySwitchBatch), wrap(copyTgtBatch)), \
        indices
        """
        # ipdb.set_trace()
        src = input[0]
        tgt = input[3][0][:-1]  # exclude last target from inputs
        src_pad_mask = Variable(src[0].data.eq(s2s.Constants.PAD).transpose(0, 1).float(),
                                requires_grad=False, volatile=False)
        bio = input[1]
        feats = input[2]
        enc_hidden, context = self.encoder(src, bio, feats)

        init_att = self.make_init_att(context)
        enc_hidden = self.decIniter(enc_hidden[1]).unsqueeze(0)  # [1] is the last backward hidden

        g_out, c_out, c_gate_out, dec_hidden, _attn, _attention_vector = self.decoder(
            tgt, enc_hidden, context, src_pad_mask, init_att)

        return g_out, c_out, c_gate_out
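

# ---------------------------------------------------------------------------
# Illustrative wiring sketch (not part of the original model code): it shows
# how the Encoder, Decoder, DecInit, and NMTModel classes defined above could
# be assembled into a full model. The hyperparameter values in `opt` and the
# `_DummyDict` vocabulary stub are assumptions invented for this example; in
# the real pipeline the options object and the dictionaries come from the
# surrounding s2s training scripts.
if __name__ == "__main__":
    from argparse import Namespace

    class _DummyDict(object):
        """Minimal stand-in for the s2s vocabulary object; only size() is used here."""

        def size(self):
            return 20000

    # hypothetical hyperparameters, chosen only so the module sizes line up
    opt = Namespace(layers=1, brnn=True, enc_rnn_size=512, dec_rnn_size=512,
                    word_vec_size=300, att_vec_size=512, maxout_pool_size=2,
                    input_feed=1, dropout=0.5,
                    pre_word_vecs_enc=None, pre_word_vecs_dec=None)

    dicts = _DummyDict()
    encoder = Encoder(opt, dicts)
    decoder = Decoder(opt, dicts)
    dec_initer = DecInit(opt)
    model = NMTModel(encoder, decoder, dec_initer)
    print(model)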