import torch
import torch.nn as nn
import torch.sparse as sparse
import torch.nn.functional as F

from tool import pyutils
import network.vgg16d


class Net(network.vgg16d.Net):
    """Pairwise feature-affinity network built on ``network.vgg16d.Net``.

    Three backbone feature maps (dictionary keys ``'conv4'``, ``'conv5'`` and
    ``'conv5fc'`` of ``forward_as_dict``) are each projected by a bias-free
    1x1 convolution, group-normalised and passed through ELU.  The results
    are concatenated (64 + 128 + 256 = 448 channels), mixed by another 1x1
    convolution (``f9``) and compared position-against-position to produce
    affinities ``exp(-mean_c |f_to - f_from|)``.
    """

    def __init__(self):
        super().__init__(fc6_dilation=4)

        # Creation order is kept stable on purpose: it determines the order
        # of ``self.modules()`` / state-dict keys and the RNG draws made by
        # the default Conv2d initialisation.
        self.f8_3 = nn.Conv2d(512, 64, 1, bias=False)
        self.f8_4 = nn.Conv2d(512, 128, 1, bias=False)
        self.f8_5 = nn.Conv2d(1024, 256, 1, bias=False)
        self.gn8_3 = nn.modules.normalization.GroupNorm(8, 64)
        self.gn8_4 = nn.modules.normalization.GroupNorm(16, 128)
        self.gn8_5 = nn.modules.normalization.GroupNorm(32, 256)

        # Input width 448 == 64 + 128 + 256, the concatenation of the
        # three projected branches.
        self.f9 = torch.nn.Conv2d(448, 448, 1, bias=False)

        for conv in (self.f8_3, self.f8_4, self.f8_5):
            torch.nn.init.kaiming_normal_(conv.weight)
        torch.nn.init.xavier_uniform_(self.f9.weight, gain=4)

        # NOTE(review): this list is not consumed anywhere in this class;
        # presumably the base class uses it to freeze these layers - confirm.
        self.not_training = [self.conv1_1, self.conv1_2,
                             self.conv2_1, self.conv2_2]
        # Layers introduced by this class; ``get_parameter_groups`` places
        # their parameters in separate groups.
        self.from_scratch_layers = [self.f8_3, self.f8_4, self.f8_5, self.f9]

        # Pair indices are precomputed once for a 56x56 feature map
        # (448 // 8) and reused by ``forward`` whenever the feature map has
        # exactly that size.  They are kept as plain CPU tensor attributes.
        self.predefined_featuresize = 448 // 8
        self.ind_from, self.ind_to = self._pair_indices(
            self.predefined_featuresize, self.predefined_featuresize)

    @staticmethod
    def _pair_indices(height, width):
        """Return ``(ind_from, ind_to)`` index tensors for a feature map."""
        # NOTE(review): the leading 5 is presumably the neighbourhood radius
        # of ``pyutils.get_indices_of_pairs`` - confirm in tool.pyutils.
        ind_from, ind_to = pyutils.get_indices_of_pairs(5, (height, width))
        return torch.from_numpy(ind_from), torch.from_numpy(ind_to)

    def forward(self, x, to_dense=False):
        """Compute pairwise feature affinities for ``x``.

        Args:
            x: Input batch, forwarded unchanged to the base class's
                ``forward_as_dict``.
            to_dense: If ``True``, scatter the affinities into a dense
                matrix instead of returning them in per-pair layout.

        Returns:
            With ``to_dense=False``: a tensor of shape ``(B, K, N)`` where
            ``N = len(ind_from)`` and ``K = len(ind_to) // N``.
            With ``to_dense=True``: a dense ``(H*W, H*W)`` matrix on the
            same device as the features.  It contains each affinity at both
            ``(from, to)`` and ``(to, from)`` and ones on the diagonal.
            Note that on this path the flattened values only line up with
            the index list when the batch size is 1.
        """
        feats = super().forward_as_dict(x)

        f8_3 = F.elu(self.gn8_3(self.f8_3(feats['conv4'])))
        f8_4 = F.elu(self.gn8_4(self.f8_4(feats['conv5'])))
        f8_5 = F.elu(self.gn8_5(self.f8_5(feats['conv5fc'])))

        x = torch.cat([f8_3, f8_4, f8_5], dim=1)
        x = F.elu(self.f9(x))

        height, width = x.size(2), x.size(3)
        if (height == self.predefined_featuresize
                and width == self.predefined_featuresize):
            ind_from, ind_to = self.ind_from, self.ind_to
        else:
            # Any other resolution needs its own index set.
            ind_from, ind_to = self._pair_indices(height, width)

        # Follow the device of the features rather than assuming the
        # default CUDA device, so the model also runs on CPU and on a
        # non-default GPU.
        device = x.device

        # (B, C, H, W) -> (B, C, H*W) so positions can be gathered by index.
        x = x.view(x.size(0), x.size(1), -1)

        ff = torch.index_select(
            x, dim=2, index=ind_from.to(device, non_blocking=True))
        ft = torch.index_select(
            x, dim=2, index=ind_to.to(device, non_blocking=True))

        # ff: (B, C, 1, N); ft: (B, C, K, N) -> broadcast over K.
        ff = torch.unsqueeze(ff, dim=2)
        ft = ft.view(ft.size(0), ft.size(1), -1, ff.size(3))

        # Mean absolute feature difference over channels, mapped to (0, 1].
        aff = torch.exp(-torch.mean(torch.abs(ft - ff), dim=1))

        if not to_dense:
            return aff

        # The sparse matrix is assembled on the CPU, where the index
        # tensors live, and only the dense result is moved to ``device``.
        aff = aff.view(-1).cpu()

        # Repeat ind_from K times so it pairs element-wise with ind_to.
        ind_from_exp = (torch.unsqueeze(ind_from, dim=0)
                        .expand(ft.size(2), -1).contiguous().view(-1))
        indices = torch.stack([ind_from_exp, ind_to])
        indices_tp = torch.stack([ind_to, ind_from_exp])

        area = x.size(2)
        diagonal = torch.arange(0, area).long()
        indices_id = torch.stack([diagonal, diagonal])

        # ``torch.sparse_coo_tensor`` replaces the deprecated
        # ``sparse.FloatTensor(indices, values)`` constructor; float32 and
        # the (area, area) size match what the legacy call produced.
        aff_mat = torch.sparse_coo_tensor(
            torch.cat([indices, indices_id, indices_tp], dim=1),
            torch.cat([aff, torch.ones([area]), aff]),
            (area, area),
            dtype=torch.float32,
        )
        return aff_mat.to_dense().to(device)

    def get_parameter_groups(self):
        """Split trainable Conv2d/GroupNorm parameters into four lists.

        Returns:
            A 4-tuple of lists:
            ``[0]`` weights and ``[1]`` biases of layers not listed in
            ``self.from_scratch_layers``; ``[2]`` weights and ``[3]`` biases
            of layers that are.  Parameters with ``requires_grad == False``
            are omitted.
        """
        groups = ([], [], [], [])

        for m in self.modules():
            if not isinstance(m, (nn.Conv2d, nn.GroupNorm)):
                continue

            from_scratch = m in self.from_scratch_layers

            # ``weight`` is None for a non-affine GroupNorm; guard it the
            # same way ``bias`` is guarded.
            if m.weight is not None and m.weight.requires_grad:
                groups[2 if from_scratch else 0].append(m.weight)

            if m.bias is not None and m.bias.requires_grad:
                groups[3 if from_scratch else 1].append(m.bias)

        return groups