from __future__ import absolute_import from __future__ import division from __future__ import print_function import torch import torch.nn as nn import torch.nn.functional as F from core.controller import Controller class LSTMController(Controller): def __init__(self, args): super(LSTMController, self).__init__(args) # build model self.in_2_hid = nn.LSTMCell(self.input_dim + self.read_vec_dim, self.hidden_dim, 1) self._reset() def _init_weights(self): pass def forward(self, input_vb, read_vec_vb): self.lstm_hidden_vb = self.in_2_hid(torch.cat((input_vb.contiguous().view(-1, self.input_dim), read_vec_vb.contiguous().view(-1, self.read_vec_dim)), 1), self.lstm_hidden_vb) # we clip the controller hidden states here self.lstm_hidden_vb = [self.lstm_hidden_vb[0].clamp(min=-self.clip_value, max=self.clip_value), self.lstm_hidden_vb[1]] return self.lstm_hidden_vb[0]