from __future__ import absolute_import
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function

import re

import torch as th
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Joint text/video embedding model producing a similarity matrix.

    Text features go through ``Sentence_Maxpool`` followed by a
    ``Gated_Embedding_Unit``; video features go through their own
    ``Gated_Embedding_Unit``. Both gated units L2-normalise their output, so
    the dot products returned by ``forward`` are cosine similarities.

    Args:
        embd_dim: Size of the shared embedding space.
        video_dim: Size of the incoming video feature vectors.
        n_pair: Stored on the instance as ``self.n_pair``; not used by the
            code in this module.
        we_dim: Size of the incoming word vectors.
        max_words: Accepted for interface compatibility; not used by the
            code in this module.
        sentence_dim: Output size of the sentence pooling layer. Any value
            ``<= 0`` means "use ``embd_dim``".
        we: Stored on the instance as ``self.we``; not used by the code in
            this module (presumably a word-embedding table -- confirm with
            the caller).
    """

    def __init__(
        self,
        embd_dim=1024,
        video_dim=2048,
        n_pair=1,
        we_dim=300,
        max_words=30,
        sentence_dim=-1,
        we=None,
    ):
        super(Net, self).__init__()
        # A non-positive sentence_dim falls back to the joint embedding size.
        pooled_dim = embd_dim if sentence_dim <= 0 else sentence_dim
        # NOTE: sub-modules are created in a fixed order (text pooling, text
        # unit, video unit) so random initialisation and state_dict keys stay
        # compatible with existing checkpoints.
        self.text_pooling = Sentence_Maxpool(we_dim, pooled_dim)
        self.GU_text = Gated_Embedding_Unit(
            self.text_pooling.out_dim, embd_dim, gating=True)
        self.GU_video = Gated_Embedding_Unit(
            video_dim, embd_dim, gating=True)
        self.n_pair = n_pair
        self.embd_dim = embd_dim
        self.we = we
        self.we_dim = we_dim

    def save_checkpoint(self, path):
        """Write this module's ``state_dict`` to ``path`` with ``th.save``."""
        th.save(self.state_dict(), path)

    def load_checkpoint(self, path, cpu=False):
        """Load a ``state_dict`` from ``path`` into this module.

        Args:
            path: Location handed to ``th.load``.
            cpu: When true, every stored tensor is mapped onto the CPU while
                loading; otherwise tensors are restored with ``th.load``'s
                default device handling.
        """
        # SECURITY: th.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        map_location = "cpu" if cpu else None
        self.load_state_dict(th.load(path, map_location=map_location))

    def forward(self, video, text):
        """Return the text-by-video similarity matrix.

        Args:
            video: Video features whose last axis has size ``video_dim``.
                The embedded result is transposed with ``.t()``, so it must
                be at most 2-D (presumably ``(n_videos, video_dim)``).
            text: Word features whose last axis has size ``we_dim`` and whose
                axis 1 is max-pooled away (presumably
                ``(n_sentences, n_words, we_dim)``).

        Returns:
            ``th.matmul(text_embedding, video_embedding.t())``.
        """
        video_emb = self.GU_video(video)
        text_emb = self.GU_text(self.text_pooling(text))
        return th.matmul(text_emb, video_emb.t())


class Gated_Embedding_Unit(nn.Module):
    """Linear projection, optional context gating, then L2 normalisation.

    Args:
        input_dimension: Size of the last axis of the input.
        output_dimension: Size of the last axis of the output.
        gating: When true, ``Context_Gating`` is applied after the linear
            layer. The gating module is always constructed (even when
            unused) so that ``state_dict`` keys do not depend on this flag.
    """

    def __init__(self, input_dimension, output_dimension, gating=True):
        super(Gated_Embedding_Unit, self).__init__()
        self.fc = nn.Linear(input_dimension, output_dimension)
        self.cg = Context_Gating(output_dimension)
        self.gating = gating

    def forward(self, x):
        """Embed ``x`` and scale it to unit L2 norm along the feature axis.

        The feature axis is the last one (the axis ``nn.Linear`` acts on), so
        both batched ``(batch, dim)`` inputs and a single un-batched ``(dim,)``
        vector are accepted.
        """
        x = self.fc(x)
        if self.gating:
            x = self.cg(x)
        # dim=-1 equals the former implicit dim=1 for 2-D input, while also
        # handling 1-D and higher-rank input correctly.
        return F.normalize(x, dim=-1)


class Sentence_Maxpool(nn.Module):
    """Per-word linear layer (optionally ReLU) followed by a max over axis 1.

    Args:
        word_dimension: Size of the last axis of the input.
        output_dim: Output size of the linear layer; exposed as ``out_dim``.
        relu: When true, a ReLU is applied before pooling.
    """

    def __init__(self, word_dimension, output_dim, relu=True):
        super(Sentence_Maxpool, self).__init__()
        self.fc = nn.Linear(word_dimension, output_dim)
        self.out_dim = output_dim
        self.relu = relu

    def forward(self, x):
        """Project ``x`` and return the maximum values taken over axis 1.

        Presumably ``x`` is ``(batch, n_words, word_dimension)``, which makes
        axis 1 the word axis -- confirm against the caller.
        """
        projected = self.fc(x)
        if self.relu:
            projected = F.relu(projected)
        # th.max with a dim returns (values, indices); only values are needed.
        pooled, _ = th.max(projected, dim=1)
        return pooled


class Context_Gating(nn.Module):
    """Element-wise gate: ``x * sigmoid(fc(x))`` computed through ``F.glu``.

    Args:
        dimension: Size of the feature (last) axis; input and output match.
        add_batch_norm: When true, the gate pre-activations go through
            ``nn.BatchNorm1d`` before the sigmoid. The batch-norm layer is
            always constructed so ``state_dict`` keys do not depend on this
            flag.
    """

    def __init__(self, dimension, add_batch_norm=False):
        super(Context_Gating, self).__init__()
        self.fc = nn.Linear(dimension, dimension)
        self.add_batch_norm = add_batch_norm
        self.batch_norm = nn.BatchNorm1d(dimension)

    def forward(self, x):
        """Return ``x`` gated by a sigmoid of its own linear projection.

        Works along the last axis, so an un-batched ``(dimension,)`` vector is
        accepted as well as ``(batch, dimension)``. With
        ``add_batch_norm=True`` the input must additionally satisfy
        ``nn.BatchNorm1d``'s own shape requirements.
        """
        gate = self.fc(x)
        if self.add_batch_norm:
            gate = self.batch_norm(gate)
        # GLU splits its input in half along `dim` and returns
        # first_half * sigmoid(second_half).
        return F.glu(th.cat((x, gate), dim=-1), dim=-1)