import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
from torch.nn.utils import weight_norm
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.parameter import Parameter
import numpy as np

import config
import word_embedding
from reuse_modules import Fusion, FCNet


class Net(nn.Module):
    """Sparse-graph VQA model.

    Learns a question-conditioned adjacency over detected objects
    (GraphLearner), then applies two Gaussian-kernel graph convolutions
    (GraphConv) and max-pools over objects before classification.
    """

    def __init__(self, words_list):
        super(Net, self).__init__()
        mid_features = 1024
        question_features = mid_features
        vision_features = config.output_features
        self.top_k_sparse = 16   # neighbours kept per object in the sparse graph
        num_kernels = 8          # Gaussian kernels per graph-conv layer
        sparse_graph = True

        self.text = word_embedding.TextProcessor(
            classes=words_list,
            embedding_features=300,
            lstm_features=question_features,
            drop=0.0,
        )
        self.pseudo_coord = PseudoCoord()
        self.graph_learner = GraphLearner(
            v_features=vision_features + 4,  # +4: bounding-box coords appended to v
            q_features=question_features,
            mid_features=512,
            dropout=0.5,
            sparse_graph=sparse_graph,
        )
        self.graph_conv1 = GraphConv(
            v_features=vision_features + 4,
            mid_features=mid_features * 2,
            num_kernels=num_kernels,
            bias=False,
        )
        self.graph_conv2 = GraphConv(
            v_features=mid_features * 2,
            mid_features=mid_features,
            num_kernels=num_kernels,
            bias=False,
        )
        self.classifier = Classifier(
            in_features=mid_features,
            mid_features=mid_features * 2,
            out_features=config.max_answers,
            drop=0.5,
        )
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, v, b, q, v_mask, q_mask, q_len):
        '''
        v:      visual feature      [batch, num_obj, 2048]
        b:      bounding box        [batch, num_obj, 4]
        q:      question            [batch, max_q_len]
        v_mask: number of obj       [batch, max_obj] 1 is obj, 0 is none
        q_mask: question length     [batch, max_len] 1 is word, 0 is none
        answer: predict logits      [batch, config.max_answers]
        '''
        q = self.text(q, list(q_len.data))  # [batch, 1024]
        v = self.dropout(v)
        v = torch.cat((v, b), dim=2)        # [batch, num_obj, 2048+4]

        new_coord = self.pseudo_coord(b)    # [batch, num_obj, num_obj, 2]
        adj_matrix, top_ind = self.graph_learner(
            v, q, v_mask, top_K=self.top_k_sparse)  # [batch, num_obj, K]

        # First conv weights neighbours by the learned adjacency; second does not.
        hid_v1 = self.graph_conv1(v, v_mask, new_coord, adj_matrix, top_ind,
                                  weight_adj=True)
        hid_v1 = self.dropout(self.relu(hid_v1))
        hid_v2 = self.graph_conv2(hid_v1, v_mask, new_coord, adj_matrix, top_ind,
                                  weight_adj=False)
        hid_v2 = self.relu(hid_v2)  # [batch, num_obj, dim]
        # hid_v2 = hid_v2 * v_mask.unsqueeze(-1)
        max_pooled_v = torch.max(hid_v2, dim=1)[0]  # [batch, dim]

        answer = self.classifier(max_pooled_v, q)
        return answer


class Classifier(nn.Module):
    """Two-layer answer head over the fused (visual x question) feature."""

    def __init__(self, in_features, mid_features, out_features, drop=0.0):
        super(Classifier, self).__init__()
        self.lin1 = FCNet(in_features, mid_features, activate='relu')
        self.lin2 = FCNet(mid_features, out_features, drop=drop)
        self.relu = nn.ReLU()

    def forward(self, v, q):
        # Gate the pooled visual feature by the (rectified) question feature.
        x = v * self.relu(q)
        x = self.lin1(x)
        x = self.lin2(x)
        return x


class PseudoCoord(nn.Module):
    """Convert boxes into pairwise polar pseudo-coordinates (rho, theta)."""

    def __init__(self):
        super(PseudoCoord, self).__init__()

    def forward(self, b):
        '''
        Input:
            b: bounding box        [batch, num_obj, 4]  (x1,y1,x2,y2)
        Output:
            pseudo_coord           [batch, num_obj, num_obj, 2]  (rho, theta)
        '''
        batch_size, num_obj, _ = b.shape
        # Box centers: midpoint of (x1,y1) and (x2,y2).
        centers = (b[:, :, 2:] + b[:, :, :2]) * 0.5
        # Pairwise center differences via broadcasting:
        # relative_coord[b, i, j] = center_i - center_j
        relative_coord = centers.view(batch_size, num_obj, 1, 2) - \
            centers.view(batch_size, 1, num_obj, 2)  # [batch, num_obj, num_obj, 2]

        rho = torch.sqrt(relative_coord[:, :, :, 0]**2
                         + relative_coord[:, :, :, 1]**2)
        # NOTE(review): atan2 is called as atan2(dx, dy), i.e. the angle is
        # measured from the y-axis rather than the conventional atan2(dy, dx).
        # Kept as-is since downstream Gaussian kernels are learned and only
        # need a consistent angle definition — confirm against training code.
        theta = torch.atan2(relative_coord[:, :, :, 0],
                            relative_coord[:, :, :, 1])
        new_coord = torch.cat((rho.unsqueeze(-1), theta.unsqueeze(-1)), dim=-1)
        return new_coord


class GraphLearner(nn.Module):
    """Learn a question-conditioned (optionally sparse) object adjacency."""

    def __init__(self, v_features, q_features, mid_features, dropout=0.0,
                 sparse_graph=True):
        super(GraphLearner, self).__init__()
        self.sparse_graph = sparse_graph
        self.lin1 = FCNet(v_features + q_features, mid_features, activate='relu')
        self.lin2 = FCNet(mid_features, mid_features, activate='relu')

    def forward(self, v, q, v_mask, top_K):
        '''
        Input:
            v:      visual feature      [batch, num_obj, 2048]
            q:      question feature    [batch, 1024]
            v_mask: number of obj       [batch, max_obj] 1 is obj, 0 is none
        Return:
            adjacent_matrix   [batch, num_obj, K]  (softmax over K, sum=1)
            top_ind           [batch, num_obj, K]  neighbour indices
        '''
        batch_size, num_obj, _ = v.shape

        # Condition every object on the question, then embed.
        q_repeated = q.unsqueeze(1).repeat(1, num_obj, 1)
        v_cat_q = torch.cat((v, q_repeated), dim=2)

        h = self.lin1(v_cat_q)
        h = self.lin2(h)
        h = h.view(batch_size, num_obj, -1)  # [batch, num_obj, feat_size]

        # Pairwise affinities via inner product of embeddings.
        adjacent_logits = torch.matmul(h, h.transpose(1, 2))  # [batch, num_obj, num_obj]

        # object mask
        # mask = torch.matmul(v_mask.unsqueeze(2), v_mask.unsqueeze(1))
        # adjacent_logits = adjacent_logits * mask

        if self.sparse_graph:
            # Keep only the K strongest neighbours per object.
            top_value, top_ind = torch.topk(adjacent_logits, k=top_K,
                                            dim=-1, sorted=False)  # [batch, num_obj, K]
        else:
            # Fix: the original left adjacent_matrix/top_ind undefined here
            # (NameError). Dense fallback: every object attends to all objects,
            # with neighbour indices 0..num_obj-1 for each row.
            top_value = adjacent_logits
            top_ind = torch.arange(num_obj, device=v.device) \
                .view(1, 1, num_obj).expand(batch_size, num_obj, num_obj)

        # Softmax attention over the selected neighbours.
        adjacent_matrix = F.softmax(top_value, dim=-1)
        return adjacent_matrix, top_ind


class GraphConv(nn.Module):
    """Graph convolution with learned Gaussian kernels over pseudo-coordinates.

    Each of the `num_kernels` kernels is a Gaussian over (rho, theta); the
    kernel responses weight each object's sparse neighbourhood before a
    per-kernel linear projection.
    """

    def __init__(self, v_features, mid_features, num_kernels, bias=False):
        super(GraphConv, self).__init__()
        self.num_kernels = num_kernels

        # One linear projection per kernel; outputs are concatenated, so each
        # kernel produces mid_features // num_kernels channels.
        self.conv_weights = nn.ModuleList([
            nn.Linear(v_features, mid_features // (num_kernels), bias=bias)
            for i in range(num_kernels)])

        # Gaussian kernel parameters over the polar pseudo-coordinates.
        self.mean_rho = Parameter(torch.FloatTensor(num_kernels, 1))
        self.mean_theta = Parameter(torch.FloatTensor(num_kernels, 1))
        self.precision_rho = Parameter(torch.FloatTensor(num_kernels, 1))
        self.precision_theta = Parameter(torch.FloatTensor(num_kernels, 1))

        self.init_param()

    def init_param(self):
        # Angles span the full circle; rho/precisions start in (0, 1).
        self.mean_rho.data.uniform_(0, 1.0)
        self.mean_theta.data.uniform_(-np.pi, np.pi)
        self.precision_rho.data.uniform_(0, 1.0)
        self.precision_theta.data.uniform_(0, 1.0)

    def forward(self, v, v_mask, coord, adj_matrix, top_ind, weight_adj=True):
        """
        Input:
            v:          visual feature  [batch, num_obj, 2048]
            v_mask:     number of obj   [batch, max_obj] 1 is obj, 0 is none
            coord:      relative coord  [batch, num_obj, num_obj, 2]
            adj_matrix: sparse          [batch, num_obj, K]  (sum=1)
            top_ind:                    [batch, num_obj, K]
        Output:
            v: visual feature           [batch, num_obj, dim]
        """
        batch_size, num_obj, feat_dim = v.shape
        K = adj_matrix.shape[-1]

        # conv_v[b, i, j] = v[b, j]: candidate neighbour features per object.
        conv_v = v.unsqueeze(1).expand(batch_size, num_obj, num_obj, feat_dim)
        coord_weight = self.get_gaussian_weights(coord)  # [batch, num_obj, num_obj, n_kernels]

        # Gather the K selected neighbours (features and kernel weights).
        slice_idx1 = top_ind.unsqueeze(-1).expand(
            batch_size, num_obj, K, feat_dim)          # [batch, num_obj, K, feat_dim]
        slice_idx2 = top_ind.unsqueeze(-1).expand(
            batch_size, num_obj, K, self.num_kernels)  # [batch, num_obj, K, num_kernels]
        sparse_v = torch.gather(conv_v, dim=2, index=slice_idx1)
        sparse_weight = torch.gather(coord_weight, dim=2, index=slice_idx2)

        if weight_adj:
            adj_mat = adj_matrix.unsqueeze(-1)   # [batch, num_obj, K, 1]
            attentive_v = sparse_v * adj_mat     # [batch, num_obj, K, feat_dim]
        else:
            attentive_v = sparse_v               # [batch, num_obj, K, feat_dim]

        # Per-kernel weighted sum over neighbours, then per-kernel projection.
        weighted_neighbourhood = torch.matmul(
            sparse_weight.transpose(2, 3), attentive_v)  # [batch, num_obj, n_kernels, feat_dim]
        weighted_neighbourhood = [
            self.conv_weights[i](weighted_neighbourhood[:, :, i, :])
            for i in range(self.num_kernels)]            # each: [batch, num_obj, dim/K]
        output = torch.cat(weighted_neighbourhood, dim=2)  # [batch, num_obj, dim]
        return output

    def get_gaussian_weights(self, coord):
        """
        Input:
            coord: relative coord [batch, num_obj, num_obj, 2]
        Output:
            weights               [batch, num_obj, num_obj, n_kernels]
        """
        batch_size, num_obj, _, _ = coord.shape

        # rho weights: squared distance to each kernel mean.
        diff = (coord[:, :, :, 0].contiguous().view(-1, 1)
                - self.mean_rho.view(1, -1))**2  # [batch*num_obj*num_obj, n_kernels]
        weights_rho = torch.exp(-0.5 * diff /
                                (1e-14 + self.precision_rho.view(1, -1)**2))

        # theta weights: wrap-around angular distance (min of the two arcs).
        first_angle = torch.abs(coord[:, :, :, 1].contiguous().view(-1, 1)
                                - self.mean_theta.view(1, -1))
        second_angle = torch.abs(2 * np.pi - first_angle)
        weights_theta = torch.exp(
            -0.5 * (torch.min(first_angle, second_angle)**2) /
            (1e-14 + self.precision_theta.view(1, -1)**2))

        weights = weights_rho * weights_theta
        # Zero out NaNs (e.g. rho=0 self-pairs can yield NaN angles). Explicit
        # replacement for the original in-place self-inequality mask.
        weights = torch.where(torch.isnan(weights),
                              torch.zeros_like(weights), weights)

        # NOTE(review): dim=1 of the flattened [batch*num_obj*num_obj, n_kernels]
        # view is the *kernel* dimension, so this normalises across kernels, not
        # across neighbouring objects — the author's own note ("weights
        # Normalized on object dim") flags this as suspect; confirm intent
        # before changing, since both conv layers reuse these weights.
        weights = weights / (torch.sum(weights, dim=1, keepdim=True) + 1e-14)
        return weights.view(batch_size, num_obj, num_obj, self.num_kernels)