import numpy as np import torch import torch.nn as nn import ablation_vgg16_c def crop(data1, data2, crop_h, crop_w): _, _, h1, w1 = data1.size() _, _, h2, w2 = data2.size() assert(h2 <= h1 and w2 <= w1) data = data1[:, :, crop_h:crop_h+h2, crop_w:crop_w+w2] return data def get_upsampling_weight(in_channels, out_channels, kernel_size): """Make a 2D bilinear kernel suitable for upsampling""" factor = (kernel_size + 1) // 2 if kernel_size % 2 == 1: center = factor - 1 else: center = factor - 0.5 og = np.ogrid[:kernel_size, :kernel_size] filt = (1 - abs(og[0] - center) / factor) * \ (1 - abs(og[1] - center) / factor) weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64) weight[range(in_channels), range(out_channels), :, :] = filt return torch.from_numpy(weight).float() class MSBlock(nn.Module): def __init__(self, c_in, k=3, rate=4): super(MSBlock, self).__init__() c_out = c_in self.k = k self.rate = rate self.conv = nn.Conv2d(c_in, 32, 3, stride=1, padding=1) self.relu = nn.ReLU(inplace=True) if k>=1: dilation = self.rate*1 if self.rate >= 1 else 1 self.conv1 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) self.relu1 = nn.ReLU(inplace=True) if k>=2: dilation = self.rate*2 if self.rate >= 1 else 1 self.conv2 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) self.relu2 = nn.ReLU(inplace=True) if k>=3: dilation = self.rate*3 if self.rate >= 1 else 1 self.conv3 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) self.relu3 = nn.ReLU(inplace=True) if k>=4: dilation = self.rate*4 if self.rate >= 1 else 1 self.conv4 = nn.Conv2d(32, 32, 3, stride=1, dilation=dilation, padding=dilation) self.relu4 = nn.ReLU(inplace=True) self._initialize_weights() def forward(self, x): o = self.relu(self.conv(x)) if self.k>=1: o1 = self.relu1(self.conv1(o)) if self.k>=2: o2 = self.relu2(self.conv2(o)) if self.k>=3: o3 = self.relu3(self.conv3(o)) if self.k>=4: o4 = self.relu4(self.conv4(o)) if self.k < 1: return o elif self.k>=4: return o+o1+o2+o3+o4 elif self.k>=3: return o + o1 + o2 + o3 elif self.k>=2: return o + o1 + o2 elif self.k>=1: return o + o1 def _initialize_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): m.weight.data.normal_(0, 0.01) if m.bias is not None: m.bias.data.zero_() class BDCN(nn.Module): def __init__(self, pretrain=None, logger=None, ms=True, block=5, bdcn=True, direction='both', k=3, rate=4): super(BDCN, self).__init__() if logger: logger.info(ms) logger.info(block) logger.info(bdcn) self.pretrain = pretrain self.ms = ms self.block = block self.bdcn = bdcn self.dir = direction self.k = k t = 1 self.features = ablation_vgg16_c.VGG16_C(pretrain, logger, block=block) if ms: self.msblock1_1 = MSBlock(64, k, rate) self.msblock1_2 = MSBlock(64, k, rate) else: t = 2 self.conv1_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv1_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.score_dsn1 = nn.Conv2d(21, 1, (1, 1), stride=1) self.score_dsn1_1 = nn.Conv2d(21, 1, 1, stride=1) if block >= 2: if ms: self.msblock2_1 = MSBlock(128, k, rate) self.msblock2_2 = MSBlock(128, k, rate) else: t = 4 self.conv2_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv2_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.score_dsn2 = nn.Conv2d(21, 1, (1, 1), stride=1) self.score_dsn2_1 = nn.Conv2d(21, 1, (1, 1), stride=1) self.upsample_2 = nn.ConvTranspose2d(1, 1, 4, stride=2, bias=False) if block >= 3: if ms: self.msblock3_1 = MSBlock(256, k, rate) self.msblock3_2 = MSBlock(256, k, rate) self.msblock3_3 = MSBlock(256, k, rate) else: t = 8 self.conv3_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv3_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv3_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.score_dsn3 = nn.Conv2d(21, 1, (1, 1), stride=1) self.score_dsn3_1 = nn.Conv2d(21, 1, (1, 1), stride=1) self.upsample_4 = nn.ConvTranspose2d(1, 1, 8, stride=4, bias=False) if block >= 4: if ms: self.msblock4_1 = MSBlock(512, k, rate) self.msblock4_2 = MSBlock(512, k, rate) self.msblock4_3 = MSBlock(512, k, rate) else: t = 16 self.conv4_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv4_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv4_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.score_dsn4 = nn.Conv2d(21, 1, (1, 1), stride=1) self.score_dsn4_1 = nn.Conv2d(21, 1, (1, 1), stride=1) self.upsample_8 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) if block >=5: if ms: self.msblock5_1 = MSBlock(512, k, rate) self.msblock5_2 = MSBlock(512, k, rate) self.msblock5_3 = MSBlock(512, k, rate) else: t = 16 self.conv5_1_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv5_2_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.conv5_3_down = nn.Conv2d(32*t, 21, (1, 1), stride=1) self.score_dsn5 = nn.Conv2d(21, 1, (1, 1), stride=1) self.score_dsn5_1 = nn.Conv2d(21, 1, (1, 1), stride=1) self.upsample_8_5 = nn.ConvTranspose2d(1, 1, 16, stride=8, bias=False) if bdcn and self.dir == 'both': c = block * 2 else: c = block self.fuse = nn.Conv2d(c, 1, 1, stride=1) self._initialize_weights(logger) def forward(self, x): features = self.features(x) if self.ms: sum1 = self.conv1_1_down(self.msblock1_1(features[0])) + \ self.conv1_2_down(self.msblock1_2(features[1])) else: sum1 = self.conv1_1_down(features[0]) + \ self.conv1_2_down(features[1]) s1 = self.score_dsn1(sum1) if self.bdcn: s11 = self.score_dsn1_1(sum1) if self.block >= 2: if self.ms: sum2 = self.conv2_1_down(self.msblock2_1(features[2])) + \ self.conv2_2_down(self.msblock2_2(features[3])) else: sum2 = self.conv2_1_down(features[2]) + \ self.conv2_2_down(features[3]) s2 = self.score_dsn2(sum2) s2 = self.upsample_2(s2) s2 = crop(s2, x, 1, 1) if self.bdcn: s21 = self.score_dsn2_1(sum2) s21 = self.upsample_2(s21) s21 = crop(s21, x, 1, 1) if self.block >= 3: if self.ms: sum3 = self.conv3_1_down(self.msblock3_1(features[4])) + \ self.conv3_2_down(self.msblock3_2(features[5])) + \ self.conv3_3_down(self.msblock3_3(features[6])) else: sum3 = self.conv3_1_down(features[4]) + \ self.conv3_2_down(features[5]) + \ self.conv3_3_down(features[6]) s3 = self.score_dsn3(sum3) s3 =self.upsample_4(s3) s3 = crop(s3, x, 2, 2) if self.bdcn: s31 = self.score_dsn3_1(sum3) s31 =self.upsample_4(s31) s31 = crop(s31, x, 2, 2) if self.block >= 4: if self.ms: sum4 = self.conv4_1_down(self.msblock4_1(features[7])) + \ self.conv4_2_down(self.msblock4_2(features[8])) + \ self.conv4_3_down(self.msblock4_3(features[9])) else: sum4 = self.conv4_1_down(features[7]) + \ self.conv4_2_down(features[8]) + \ self.conv4_3_down(features[9]) s4 = self.score_dsn4(sum4) s4 = self.upsample_8(s4) s4 = crop(s4, x, 4, 4) if self.bdcn: s41 = self.score_dsn4_1(sum4) s41 = self.upsample_8(s41) s41 = crop(s41, x, 4, 4) if self.block >= 5: if self.ms: sum5 = self.conv5_1_down(self.msblock5_1(features[10])) + \ self.conv5_2_down(self.msblock5_2(features[11])) + \ self.conv5_3_down(self.msblock5_3(features[12])) else: sum5 = self.conv5_1_down(features[10]) + \ self.conv5_2_down(features[11]) + \ self.conv5_3_down(features[12]) s5 = self.score_dsn5(sum5) s5 = self.upsample_8_5(s5) s5 = crop(s5, x, 0, 0) if self.bdcn: s51 = self.score_dsn5_1(sum5) s51 = self.upsample_8_5(s51) s51 = crop(s51, x, 0, 0) if self.bdcn: if self.block >= 5: o1, o2, o3, o4, o5 = s1.detach(), s2.detach(), s3.detach(), s4.detach(), s5.detach() o11, o21, o31, o41, o51 = s11.detach(), s21.detach(), s31.detach(), s41.detach(), s51.detach() p1_1 = s1 p2_1 = s2 + o1 p3_1 = s3 + o2 + o1 p4_1 = s4 + o3 + o2 + o1 p5_1 = s5 + o4 + o3 + o2 + o1 p1_2 = s11 + o21 + o31 + o41 + o51 p2_2 = s21 + o31 + o41 + o51 p3_2 = s31 + o41 + o51 p4_2 = s41 + o51 p5_2 = s51 if self.dir == 'both': fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2], 1)) return [p1_1, p2_1, p3_1, p4_1, p5_1, p1_2, p2_2, p3_2, p4_2, p5_2, fuse] if self.dir == 'd2s': fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p4_1, p5_1], 1)) return [p1_1, p2_1, p3_1, p4_1, p5_1, fuse] elif self.dir == 's2d': fuse = self.fuse(torch.cat([p1_2, p2_2, p3_2, p4_2, p5_2], 1)) return [p1_2, p2_2, p3_2, p4_2, p5_2, fuse] elif self.block >= 4: o1, o2, o3, o4 = s1.detach(), s2.detach(), s3.detach(), s4.detach() o11, o21, o31, o41 = s11.detach(), s21.detach(), s31.detach(), s41.detach() p1_1 = s1 p2_1 = s2 + o1 p3_1 = s3 + o2 + o1 p4_1 = s4 + o3 + o2 + o1 p1_2 = s11 + o21 + o31 + o41 p2_2 = s21 + o31 + o41 p3_2 = s31 + o41 p4_2 = s41 fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p4_1,p1_2, p2_2, p3_2, p4_2], 1)) return [p1_1, p2_1, p3_1, p4_1, p1_2, p2_2, p3_2, p4_2, fuse] elif self.block >= 3: o1, o2, o3 = s1.detach(), s2.detach(), s3.detach() o11, o21, o31 = s11.detach(), s21.detach(), s31.detach() p1_1 = s1 p2_1 = s2 + o1 p3_1 = s3 + o2 + o1 p1_2 = s11 + o21 + o31 p2_2 = s21 + o31 p3_2 = s31 fuse = self.fuse(torch.cat([p1_1, p2_1, p3_1, p1_2, p2_2, p3_2], 1)) return [p1_1, p2_1, p3_1, p1_2, p2_2, p3_2, fuse] elif self.block >= 2: o1, o2 = s1.detach(), s2.detach() o11, o21 = s11.detach(), s21.detach() p1_1 = s1 p2_1 = s2 + o1 p1_2 = s11 + o21 p2_2 = s21 fuse = self.fuse(torch.cat([p1_1, p2_1, p1_2, p2_2], 1)) return [p1_1, p2_1, p1_2, p2_2, fuse] concat = s1 res = [s1] if self.block >= 2: concat = torch.cat([concat, s2], 1) res = [s1, s2] if self.block >= 3: concat = torch.cat([concat, s3], 1) res = [s1, s2, s3] if self.block >= 4: concat = torch.cat([concat, s4], 1) res = [s1, s2, s3, s4] if self.block >= 5: concat = torch.cat([concat, s5], 1) res = [s1, s2, s3, s4, s5] fuse = self.fuse(concat) res.append(fuse) return res def _initialize_weights(self, logger=None): for name, param in self.state_dict().items(): if self.pretrain and 'features' in name: continue elif 'down' in name: param.zero_() elif 'upsample' in name: if logger: logger.info('init upsamle layer %s ' % name) k = int(name.split('.')[0].split('_')[1]) param.copy_(get_upsampling_weight(1, 1, k*2)) elif 'fuse' in name: if logger: logger.info('init params %s ' % name) if 'bias' in name: param.zero_() else: nn.init.constant(param, 0.080) else: if logger: logger.info('init params %s ' % name) if 'bias' in name: param.zero_() else: param.normal_(0, 0.01) if __name__ == '__main__': model = BDCN('./caffemodel2pytorch/vgg16.pth') for name, param in model.state_dict().items(): print name, param