import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.autograd import Variable


class conv2DBatchNorm(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding,
                 bias=True, dilation=1, is_batchnorm=True):
        super(conv2DBatchNorm, self).__init__()

        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        if is_batchnorm:
            self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)))
        else:
            self.cb_unit = nn.Sequential(conv_mod)

    def forward(self, inputs):
        outputs = self.cb_unit(inputs)
        return outputs


class conv2DGroupNorm(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding,
                 bias=True, dilation=1, n_groups=16):
        super(conv2DGroupNorm, self).__init__()

        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        self.cg_unit = nn.Sequential(conv_mod, nn.GroupNorm(n_groups, int(n_filters)))

    def forward(self, inputs):
        outputs = self.cg_unit(inputs)
        return outputs


class deconv2DBatchNorm(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNorm, self).__init__()

        self.dcb_unit = nn.Sequential(
            nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size,
                               padding=padding, stride=stride, bias=bias),
            nn.BatchNorm2d(int(n_filters)),
        )

    def forward(self, inputs):
        outputs = self.dcb_unit(inputs)
        return outputs


class conv2DBatchNormRelu(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding,
                 bias=True, dilation=1, is_batchnorm=True):
        super(conv2DBatchNormRelu, self).__init__()

        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        if is_batchnorm:
            self.cbr_unit = nn.Sequential(
                conv_mod, nn.BatchNorm2d(int(n_filters)), nn.ReLU(inplace=True)
            )
        else:
            self.cbr_unit = nn.Sequential(conv_mod, nn.ReLU(inplace=True))

    def forward(self, inputs):
        outputs = self.cbr_unit(inputs)
        return outputs


class conv2DGroupNormRelu(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding,
                 bias=True, dilation=1, n_groups=16):
        super(conv2DGroupNormRelu, self).__init__()

        conv_mod = nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                             padding=padding, stride=stride, bias=bias, dilation=dilation)

        self.cgr_unit = nn.Sequential(
            conv_mod, nn.GroupNorm(n_groups, int(n_filters)), nn.ReLU(inplace=True)
        )

    def forward(self, inputs):
        outputs = self.cgr_unit(inputs)
        return outputs


class deconv2DBatchNormRelu(nn.Module):
    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNormRelu, self).__init__()

        self.dcbr_unit = nn.Sequential(
            nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size,
                               padding=padding, stride=stride, bias=bias),
            nn.BatchNorm2d(int(n_filters)),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        outputs = self.dcbr_unit(inputs)
        return outputs


class unetConv2(nn.Module):
    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()

        if is_batchnorm:
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(out_size, out_size, 3, 1, 0), nn.BatchNorm2d(out_size), nn.ReLU()
            )
        else:
            self.conv1 = nn.Sequential(nn.Conv2d(in_size, out_size, 3, 1, 0), nn.ReLU())
            self.conv2 = nn.Sequential(nn.Conv2d(out_size, out_size, 3, 1, 0), nn.ReLU())

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs
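# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity, not part of the original
# module): a quick shape check of the conv/deconv wrapper blocks and
# unetConv2 above. The helper name `_demo_conv_blocks` is an assumption.
def _demo_conv_blocks():
    x = torch.randn(2, 3, 64, 64)
    # 3x3 conv + BN + ReLU keeps the spatial size with padding=1.
    cbr = conv2DBatchNormRelu(3, 16, k_size=3, stride=1, padding=1)
    y = cbr(x)                      # -> [2, 16, 64, 64]
    # 2x2 transposed conv with stride 2 doubles the spatial size.
    up = deconv2DBatchNormRelu(16, 8, k_size=2, stride=2, padding=0)
    z = up(y)                       # -> [2, 8, 128, 128]
    # unetConv2 uses two unpadded 3x3 convs, so H and W shrink by 4.
    uc = unetConv2(3, 16, is_batchnorm=True)
    w = uc(x)                       # -> [2, 16, 60, 60]
    return y.shape, z.shape, w.shape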
class unetUp(nn.Module):
    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        offset = outputs2.size()[2] - inputs1.size()[2]
        padding = 2 * [offset // 2, offset // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))


class segnetDown2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown2, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetDown3, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetUp2(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetUp2, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        return outputs


class segnetUp3(nn.Module):
    def __init__(self, in_size, out_size):
        super(segnetUp3, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        outputs = self.conv1(outputs)
        outputs = self.conv2(outputs)
        outputs = self.conv3(outputs)
        return outputs


class residualBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBlock, self).__init__()

        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x

        out = self.convbnrelu1(x)
        out = self.convbn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class residualBottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBottleneck, self).__init__()
        # Use the local conv2DBatchNorm wrapper (the original referenced a
        # non-existent nn.Conv2DBatchNorm) and pass the required stride/padding.
        self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1, padding=0, bias=False)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, k_size=3, stride=stride, padding=1, bias=False)
        self.convbn3 = conv2DBatchNorm(n_filters, n_filters * 4, k_size=1, stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.convbn1(x)
        out = self.convbn2(out)
        out = self.convbn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
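# ---------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original module): the
# segnet blocks pass max-pooling indices from the encoder to the decoder, and
# residualBlock needs a `downsample` projection when channels or stride change.
# The helper name `_demo_segnet_residual` is an assumption.
def _demo_segnet_residual():
    x = torch.randn(2, 3, 64, 64)
    down = segnetDown2(3, 64)
    pooled, indices, unpooled_shape = down(x)       # pooled: [2, 64, 32, 32]
    up = segnetUp2(64, 32)
    restored = up(pooled, indices, unpooled_shape)  # -> [2, 32, 64, 64]
    # Residual block with a 1x1 projection on the skip path to match shapes.
    downsample = conv2DBatchNorm(64, 128, k_size=1, stride=2, padding=0, bias=False)
    block = residualBlock(64, 128, stride=2, downsample=downsample)
    out = block(pooled)                             # -> [2, 128, 16, 16]
    return restored.shape, out.shape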
class linknetUp(nn.Module):
    def __init__(self, in_channels, n_filters):
        super(linknetUp, self).__init__()

        # B, 2C, H, W -> B, C/2, H, W
        self.convbnrelu1 = conv2DBatchNormRelu(
            in_channels, n_filters // 2, k_size=1, stride=1, padding=1
        )

        # B, C/2, H, W -> B, C/2, H, W
        self.deconvbnrelu2 = deconv2DBatchNormRelu(
            n_filters // 2, n_filters // 2, k_size=3, stride=2, padding=0
        )

        # B, C/2, H, W -> B, C, H, W
        self.convbnrelu3 = conv2DBatchNormRelu(
            n_filters // 2, n_filters, k_size=1, stride=1, padding=1
        )

    def forward(self, x):
        x = self.convbnrelu1(x)
        x = self.deconvbnrelu2(x)
        x = self.convbnrelu3(x)
        return x


class FRRU(nn.Module):
    """Full Resolution Residual Unit for FRRN."""

    def __init__(self, prev_channels, out_channels, scale, group_norm=False, n_groups=None):
        super(FRRU, self).__init__()
        self.scale = scale
        self.prev_channels = prev_channels
        self.out_channels = out_channels
        self.group_norm = group_norm
        self.n_groups = n_groups

        if self.group_norm:
            conv_unit = conv2DGroupNormRelu
            self.conv1 = conv_unit(
                prev_channels + 32, out_channels, k_size=3, stride=1, padding=1,
                bias=False, n_groups=self.n_groups,
            )
            self.conv2 = conv_unit(
                out_channels, out_channels, k_size=3, stride=1, padding=1,
                bias=False, n_groups=self.n_groups,
            )
        else:
            conv_unit = conv2DBatchNormRelu
            self.conv1 = conv_unit(
                prev_channels + 32, out_channels, k_size=3, stride=1, padding=1, bias=False
            )
            self.conv2 = conv_unit(
                out_channels, out_channels, k_size=3, stride=1, padding=1, bias=False
            )

        self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)

    def forward(self, y, z):
        x = torch.cat([y, nn.MaxPool2d(self.scale, self.scale)(z)], dim=1)
        y_prime = self.conv1(x)
        y_prime = self.conv2(y_prime)

        x = self.conv_res(y_prime)
        upsample_size = torch.Size([_s * self.scale for _s in y_prime.shape[-2:]])
        x = F.interpolate(x, size=upsample_size, mode="nearest")
        z_prime = z + x

        return y_prime, z_prime


class RU(nn.Module):
    """Residual Unit for FRRN."""

    def __init__(self, channels, kernel_size=3, strides=1, group_norm=False, n_groups=None):
        super(RU, self).__init__()
        self.group_norm = group_norm
        self.n_groups = n_groups

        if self.group_norm:
            self.conv1 = conv2DGroupNormRelu(
                channels, channels, k_size=kernel_size, stride=strides, padding=1,
                bias=False, n_groups=self.n_groups,
            )
            self.conv2 = conv2DGroupNorm(
                channels, channels, k_size=kernel_size, stride=strides, padding=1,
                bias=False, n_groups=self.n_groups,
            )
        else:
            self.conv1 = conv2DBatchNormRelu(
                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
            )
            self.conv2 = conv2DBatchNorm(
                channels, channels, k_size=kernel_size, stride=strides, padding=1, bias=False
            )

    def forward(self, x):
        incoming = x
        x = self.conv1(x)
        x = self.conv2(x)
        return x + incoming


class residualConvUnit(nn.Module):
    def __init__(self, channels, kernel_size=3):
        super(residualConvUnit, self).__init__()

        # padding=kernel_size // 2 keeps the spatial size so the residual
        # addition in forward() is well defined (the original had no padding).
        self.residual_conv_unit = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=kernel_size // 2),
        )

    def forward(self, x):
        input = x
        x = self.residual_conv_unit(x)
        return x + input
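# ---------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original module): in the
# FRRN layout, `y` is the pooled stream and `z` is the full-resolution
# 32-channel residual stream; FRRU consumes both and returns updated versions
# of each. The helper name `_demo_frru_ru` is an assumption.
def _demo_frru_ru():
    y = torch.randn(2, 48, 32, 32)   # pooled stream at half resolution
    z = torch.randn(2, 32, 64, 64)   # full-resolution residual stream
    frru = FRRU(prev_channels=48, out_channels=96, scale=2)
    y_prime, z_prime = frru(y, z)    # y_prime: [2, 96, 32, 32], z_prime: [2, 32, 64, 64]
    ru = RU(channels=32)
    z_refined = ru(z_prime)          # same shape as z_prime
    return y_prime.shape, z_refined.shape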
class multiResolutionFusion(nn.Module):
    def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
        super(multiResolutionFusion, self).__init__()

        self.up_scale_high = up_scale_high
        self.up_scale_low = up_scale_low

        # padding=1 keeps the 3x3 convolutions size-preserving so the two
        # upsampled paths can be summed (the original had no padding).
        self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3, padding=1)

        if low_shape is not None:
            self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3, padding=1)

    def forward(self, x_high, x_low):
        high_upsampled = F.interpolate(
            self.conv_high(x_high), scale_factor=self.up_scale_high,
            mode="bilinear", align_corners=True
        )

        if x_low is None:
            return high_upsampled

        low_upsampled = F.interpolate(
            self.conv_low(x_low), scale_factor=self.up_scale_low,
            mode="bilinear", align_corners=True
        )

        return low_upsampled + high_upsampled


class chainedResidualPooling(nn.Module):
    def __init__(self, channels, input_shape):
        super(chainedResidualPooling, self).__init__()

        # padding=1 keeps the 3x3 convolution size-preserving so the residual
        # addition in forward() is well defined (the original had no padding).
        self.chained_residual_pooling = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.MaxPool2d(5, 1, 2),
            nn.Conv2d(input_shape[1], channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        input = x
        x = self.chained_residual_pooling(x)
        return x + input


class pyramidPooling(nn.Module):
    def __init__(self, in_channels, pool_sizes, model_name="pspnet",
                 fusion_mode="cat", is_batchnorm=True):
        super(pyramidPooling, self).__init__()

        bias = not is_batchnorm

        self.paths = []
        for i in range(len(pool_sizes)):
            self.paths.append(
                conv2DBatchNormRelu(
                    in_channels, int(in_channels / len(pool_sizes)), 1, 1, 0,
                    bias=bias, is_batchnorm=is_batchnorm,
                )
            )

        self.path_module_list = nn.ModuleList(self.paths)
        self.pool_sizes = pool_sizes
        self.model_name = model_name
        self.fusion_mode = fusion_mode

    def forward(self, x):
        h, w = x.shape[2:]

        if self.training or self.model_name != "icnet":  # general settings or pspnet
            k_sizes = []
            strides = []
            for pool_size in self.pool_sizes:
                k_sizes.append((int(h / pool_size), int(w / pool_size)))
                strides.append((int(h / pool_size), int(w / pool_size)))
        else:  # eval mode and icnet: pre-trained for 1025 x 2049
            k_sizes = [(8, 15), (13, 25), (17, 33), (33, 65)]
            strides = [(5, 10), (10, 20), (16, 32), (33, 65)]

        if self.fusion_mode == "cat":  # pspnet: concat (including x)
            output_slices = [x]

            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
                if self.model_name != "icnet":
                    out = module(out)
                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
                output_slices.append(out)

            return torch.cat(output_slices, dim=1)
        else:  # icnet: element-wise sum (including x)
            pp_sum = x

            for i, (module, pool_size) in enumerate(zip(self.path_module_list, self.pool_sizes)):
                out = F.avg_pool2d(x, k_sizes[i], stride=strides[i], padding=0)
                # out = F.adaptive_avg_pool2d(x, output_size=(pool_size, pool_size))
                if self.model_name != "icnet":
                    out = module(out)
                out = F.interpolate(out, size=(h, w), mode="bilinear", align_corners=True)
                pp_sum = pp_sum + out

            return pp_sum


class bottleNeckPSP(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, stride,
                 dilation=1, is_batchnorm=True):
        super(bottleNeckPSP, self).__init__()

        bias = not is_batchnorm

        self.cbr1 = conv2DBatchNormRelu(
            in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        if dilation > 1:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels, mid_channels, 3, stride=stride, padding=dilation,
                bias=bias, dilation=dilation, is_batchnorm=is_batchnorm,
            )
        else:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels, mid_channels, 3, stride=stride, padding=1,
                bias=bias, dilation=1, is_batchnorm=is_batchnorm,
            )
        self.cb3 = conv2DBatchNorm(
            mid_channels, out_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        self.cb4 = conv2DBatchNorm(
            in_channels, out_channels, 1, stride=stride, padding=0, bias=bias,
            is_batchnorm=is_batchnorm,
        )

    def forward(self, x):
        conv = self.cb3(self.cbr2(self.cbr1(x)))
        residual = self.cb4(x)
        return F.relu(conv + residual, inplace=True)
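# ---------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original module): the PSP
# pyramid pooling module concatenates the input with len(pool_sizes) pooled
# branches of in_channels // len(pool_sizes) channels each. eval() is used
# here because the 1x1 pooled branch has a single value per channel, which
# BatchNorm cannot normalise in training mode for this toy input.
# The helper name `_demo_psp_blocks` is an assumption.
def _demo_psp_blocks():
    feats = torch.randn(1, 512, 48, 48)
    ppm = pyramidPooling(512, pool_sizes=[6, 3, 2, 1]).eval()
    bottleneck = bottleNeckPSP(512, 128, 512, stride=1, dilation=2).eval()
    with torch.no_grad():
        fused = ppm(feats)        # -> [1, 1024, 48, 48]: x plus four 128-channel branches
        out = bottleneck(feats)   # -> [1, 512, 48, 48]
    return fused.shape, out.shape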
class bottleNeckIdentifyPSP(nn.Module):
    def __init__(self, in_channels, mid_channels, stride, dilation=1, is_batchnorm=True):
        super(bottleNeckIdentifyPSP, self).__init__()

        bias = not is_batchnorm

        self.cbr1 = conv2DBatchNormRelu(
            in_channels, mid_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )
        if dilation > 1:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels, mid_channels, 3, stride=1, padding=dilation,
                bias=bias, dilation=dilation, is_batchnorm=is_batchnorm,
            )
        else:
            self.cbr2 = conv2DBatchNormRelu(
                mid_channels, mid_channels, 3, stride=1, padding=1,
                bias=bias, dilation=1, is_batchnorm=is_batchnorm,
            )
        self.cb3 = conv2DBatchNorm(
            mid_channels, in_channels, 1, stride=1, padding=0, bias=bias, is_batchnorm=is_batchnorm
        )

    def forward(self, x):
        residual = x
        x = self.cb3(self.cbr2(self.cbr1(x)))
        return F.relu(x + residual, inplace=True)


class residualBlockPSP(nn.Module):
    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride,
                 dilation=1, include_range="all", is_batchnorm=True):
        super(residualBlockPSP, self).__init__()

        if dilation > 1:
            stride = 1

        # residualBlockPSP = convBlockPSP + identityBlockPSPs
        layers = []
        if include_range in ["all", "conv"]:
            layers.append(
                bottleNeckPSP(
                    in_channels, mid_channels, out_channels, stride, dilation,
                    is_batchnorm=is_batchnorm,
                )
            )
        if include_range in ["all", "identity"]:
            for i in range(n_blocks - 1):
                layers.append(
                    bottleNeckIdentifyPSP(
                        out_channels, mid_channels, stride, dilation, is_batchnorm=is_batchnorm
                    )
                )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class cascadeFeatureFusion(nn.Module):
    def __init__(self, n_classes, low_in_channels, high_in_channels, out_channels, is_batchnorm=True):
        super(cascadeFeatureFusion, self).__init__()

        bias = not is_batchnorm

        self.low_dilated_conv_bn = conv2DBatchNorm(
            low_in_channels, out_channels, 3, stride=1, padding=2, bias=bias,
            dilation=2, is_batchnorm=is_batchnorm,
        )
        self.low_classifier_conv = nn.Conv2d(
            int(low_in_channels), int(n_classes), kernel_size=1, padding=0, stride=1,
            bias=True, dilation=1,
        )  # Train only
        self.high_proj_conv_bn = conv2DBatchNorm(
            high_in_channels, out_channels, 1, stride=1, padding=0, bias=bias,
            is_batchnorm=is_batchnorm,
        )

    def forward(self, x_low, x_high):
        x_low_upsampled = F.interpolate(
            x_low, size=get_interp_size(x_low, z_factor=2), mode="bilinear", align_corners=True
        )

        low_cls = self.low_classifier_conv(x_low_upsampled)

        low_fm = self.low_dilated_conv_bn(x_low_upsampled)
        high_fm = self.high_proj_conv_bn(x_high)
        high_fused_fm = F.relu(low_fm + high_fm, inplace=True)

        return high_fused_fm, low_cls


def get_interp_size(input, s_factor=1, z_factor=1):  # for caffe
    ori_h, ori_w = input.shape[2:]

    # shrink (s_factor >= 1)
    ori_h = (ori_h - 1) / s_factor + 1
    ori_w = (ori_w - 1) / s_factor + 1

    # zoom (z_factor >= 1)
    ori_h = ori_h + (ori_h - 1) * (z_factor - 1)
    ori_w = ori_w + (ori_w - 1) * (z_factor - 1)

    resize_shape = (int(ori_h), int(ori_w))
    return resize_shape
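# ---------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original module): the
# Caffe-style size arithmetic used by cascadeFeatureFusion. A zoom factor of 2
# maps a 33x65 feature map to 65x129 (2x - 1), and a shrink factor of 2 maps
# it back. The helper name `_demo_interp_size` is an assumption.
def _demo_interp_size():
    feats = torch.randn(1, 256, 33, 65)
    zoomed = get_interp_size(feats, z_factor=2)                          # -> (65, 129)
    shrunk = get_interp_size(torch.randn(1, 256, 65, 129), s_factor=2)   # -> (33, 65)
    return zoomed, shrunk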
def interp(input, output_size, mode="bilinear"):
    n, c, ih, iw = input.shape
    oh, ow = output_size

    # normalize to [-1, 1]
    h = torch.arange(0, oh, dtype=torch.float, device=input.device) / (oh - 1) * 2 - 1
    w = torch.arange(0, ow, dtype=torch.float, device=input.device) / (ow - 1) * 2 - 1

    grid = torch.zeros(oh, ow, 2, dtype=torch.float, device=input.device)
    grid[:, :, 0] = w.unsqueeze(0).repeat(oh, 1)
    grid[:, :, 1] = h.unsqueeze(0).repeat(ow, 1).transpose(0, 1)
    grid = grid.unsqueeze(0).repeat(n, 1, 1, 1)  # grid.shape: [n, oh, ow, 2]
    grid = Variable(grid)
    if input.is_cuda:
        grid = grid.cuda()

    # The grid above maps the first/last samples exactly onto the input
    # corners, i.e. align_corners=True sampling semantics.
    return F.grid_sample(input, grid, mode=mode, align_corners=True)


def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling"""
    factor = (kernel_size + 1) // 2
    if kernel_size % 2 == 1:
        center = factor - 1
    else:
        center = factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size), dtype=np.float64)
    weight[range(in_channels), range(out_channels), :, :] = filt
    return torch.from_numpy(weight).float()
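# ---------------------------------------------------------------------------
# Illustrative usage sketch (added, not part of the original module): the
# classic FCN-style initialisation, where a stride-2 transposed convolution is
# seeded with a fixed bilinear kernel. The kernel is diagonal across channels,
# so it assumes in_channels == out_channels. The helper name
# `_demo_bilinear_init` is an assumption.
def _demo_bilinear_init():
    n_classes = 21
    up = nn.ConvTranspose2d(n_classes, n_classes, kernel_size=4, stride=2, bias=False)
    up.weight.data.copy_(get_upsampling_weight(n_classes, n_classes, kernel_size=4))
    x = torch.randn(1, n_classes, 32, 32)
    y = up(x)  # -> [1, 21, 66, 66]
    return y.shape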