import torch
import torch.nn as nn
import torch.nn.functional as F


def _conv2d(in_channels, n_filters, k_size, stride, padding, bias, dilation):
    """Build the ``nn.Conv2d`` shared by the conv-BN(-ReLU) wrappers below.

    Channel counts are cast with ``int`` so float arguments are accepted;
    a ``dilation`` of 1 or less falls back to an undilated convolution.
    """
    return nn.Conv2d(int(in_channels), int(n_filters), kernel_size=k_size,
                     padding=padding, stride=stride, bias=bias,
                     dilation=dilation if dilation > 1 else 1)


class conv2DBatchNorm(nn.Module):
    """Conv2d followed by BatchNorm2d (no activation).

    Args:
        in_channels: input channels (cast with ``int``).
        n_filters: output channels (cast with ``int``).
        k_size, stride, padding: forwarded to ``nn.Conv2d``.
        bias: whether the convolution has a bias term.
        dilation: convolution dilation; values <= 1 mean no dilation.
    """

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1):
        super(conv2DBatchNorm, self).__init__()
        conv_mod = _conv2d(in_channels, n_filters, k_size, stride, padding, bias, dilation)
        self.cb_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)),)

    def forward(self, inputs):
        return self.cb_unit(inputs)


class deconv2DBatchNorm(nn.Module):
    """ConvTranspose2d followed by BatchNorm2d (no activation)."""

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNorm, self).__init__()
        self.dcb_unit = nn.Sequential(
            nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size,
                               padding=padding, stride=stride, bias=bias),
            nn.BatchNorm2d(int(n_filters)),
        )

    def forward(self, inputs):
        return self.dcb_unit(inputs)


class conv2DBatchNormRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> in-place ReLU.

    Arguments have the same meaning as for ``conv2DBatchNorm``.
    """

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True, dilation=1):
        super(conv2DBatchNormRelu, self).__init__()
        conv_mod = _conv2d(in_channels, n_filters, k_size, stride, padding, bias, dilation)
        self.cbr_unit = nn.Sequential(conv_mod, nn.BatchNorm2d(int(n_filters)),
                                      nn.ReLU(inplace=True),)

    def forward(self, inputs):
        return self.cbr_unit(inputs)


class deconv2DBatchNormRelu(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, in_channels, n_filters, k_size, stride, padding, bias=True):
        super(deconv2DBatchNormRelu, self).__init__()
        self.dcbr_unit = nn.Sequential(
            nn.ConvTranspose2d(int(in_channels), int(n_filters), kernel_size=k_size,
                               padding=padding, stride=stride, bias=bias),
            nn.BatchNorm2d(int(n_filters)),
            nn.ReLU(inplace=True),
        )

    def forward(self, inputs):
        return self.dcbr_unit(inputs)


class unetConv2(nn.Module):
    """Two unpadded 3x3 conv (+ optional BatchNorm) + ReLU stages.

    With no padding each stage shrinks H and W by 2 (4 in total).
    """

    def __init__(self, in_size, out_size, is_batchnorm):
        super(unetConv2, self).__init__()

        def _stage(stage_in):
            # Conv -> [BN] -> ReLU; module order matches the original indices.
            layers = [nn.Conv2d(stage_in, out_size, 3, 1, 0)]
            if is_batchnorm:
                layers.append(nn.BatchNorm2d(out_size))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        self.conv1 = _stage(in_size)
        self.conv2 = _stage(out_size)

    def forward(self, inputs):
        return self.conv2(self.conv1(inputs))


class unetUp(nn.Module):
    """U-Net decoder step: upsample ``inputs2``, align ``inputs1`` to it, concat, convolve.

    ``inputs1`` (the skip tensor) is padded, or cropped via negative padding,
    so its H and W match the upsampled ``inputs2``.
    """

    def __init__(self, in_size, out_size, is_deconv):
        super(unetUp, self).__init__()
        self.conv = unetConv2(in_size, out_size, False)
        if is_deconv:
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        else:
            # NOTE(review): this path keeps in_size channels, so after concatenating
            # the skip tensor self.conv receives more than in_size channels; confirm.
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        # Offsets are usually negative (skip tensor is larger): F.pad then crops.
        offset_h = outputs2.size(2) - inputs1.size(2)
        offset_w = outputs2.size(3) - inputs1.size(3)
        # Split each offset so the two sides sum exactly to it (odd offsets too).
        top = offset_h // 2
        left = offset_w // 2
        padding = [left, offset_w - left, top, offset_h - top]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))


class segnetDown2(nn.Module):
    """Two 3x3 conv-BN-ReLU layers, then 2x2 max-pool that also returns indices.

    Returns ``(pooled, indices, pre_pool_shape)`` for use by a matching unpool.
    """

    def __init__(self, in_size, out_size):
        super(segnetDown2, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv2(self.conv1(inputs))
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetDown3(nn.Module):
    """Three 3x3 conv-BN-ReLU layers, then 2x2 max-pool that also returns indices.

    Returns ``(pooled, indices, pre_pool_shape)``.
    """

    def __init__(self, in_size, out_size):
        super(segnetDown3, self).__init__()
        self.conv1 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(out_size, out_size, 3, 1, 1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        outputs = self.conv3(self.conv2(self.conv1(inputs)))
        unpooled_shape = outputs.size()
        outputs, indices = self.maxpool_with_argmax(outputs)
        return outputs, indices, unpooled_shape


class segnetUp2(nn.Module):
    """2x2 max-unpool (using stored indices/shape), then two 3x3 conv-BN-ReLU layers."""

    def __init__(self, in_size, out_size):
        super(segnetUp2, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        return self.conv2(self.conv1(outputs))


class segnetUp3(nn.Module):
    """2x2 max-unpool (using stored indices/shape), then three 3x3 conv-BN-ReLU layers."""

    def __init__(self, in_size, out_size):
        super(segnetUp3, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv1 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv2 = conv2DBatchNormRelu(in_size, in_size, 3, 1, 1)
        self.conv3 = conv2DBatchNormRelu(in_size, out_size, 3, 1, 1)

    def forward(self, inputs, indices, output_shape):
        outputs = self.unpool(input=inputs, indices=indices, output_size=output_shape)
        return self.conv3(self.conv2(self.conv1(outputs)))


class residualBlock(nn.Module):
    """Basic residual block: conv-BN-ReLU, conv-BN, add shortcut, ReLU.

    ``downsample`` (optional module) is applied to the shortcut when given.
    """

    expansion = 1

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBlock, self).__init__()
        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, n_filters, 3, stride, 1, bias=False)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, 3, 1, 1, bias=False)
        self.downsample = downsample
        self.stride = stride
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.convbn2(self.convbnrelu1(x))
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)


class residualBottleneck(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 (strided) -> 1x1 expanding by ``expansion``.

    ReLU follows the first two conv-BN layers and the residual addition.
    ``downsample`` (optional module) is applied to the shortcut when given.
    """

    expansion = 4

    def __init__(self, in_channels, n_filters, stride=1, downsample=None):
        super(residualBottleneck, self).__init__()
        # The original referenced nn.Conv2DBatchNorm (nonexistent) and omitted
        # stride/padding, so it could not be constructed.
        self.convbn1 = conv2DBatchNorm(in_channels, n_filters, k_size=1, stride=1,
                                       padding=0, bias=False)
        self.convbn2 = conv2DBatchNorm(n_filters, n_filters, k_size=3, stride=stride,
                                       padding=1, bias=False)
        self.convbn3 = conv2DBatchNorm(n_filters, n_filters * self.expansion, k_size=1,
                                       stride=1, padding=0, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.relu(self.convbn1(x))
        out = self.relu(self.convbn2(out))
        out = self.convbn3(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        return self.relu(out)


class linknetUp(nn.Module):
    """LinkNet-style decoder block: 1x1 reduce, 3x3 stride-2 transposed conv, 1x1 expand."""

    def __init__(self, in_channels, n_filters):
        super(linknetUp, self).__init__()
        mid_channels = n_filters // 2
        # B, in_channels, H, W -> B, n_filters/2, H, W (1x1 conv, no padding)
        self.convbnrelu1 = conv2DBatchNormRelu(in_channels, mid_channels, k_size=1,
                                               stride=1, padding=0)
        # B, n_filters/2, H, W -> B, n_filters/2, 2H+1, 2W+1 ((H - 1) * 2 + 3)
        self.deconvbnrelu2 = deconv2DBatchNormRelu(mid_channels, mid_channels, k_size=3,
                                                   stride=2, padding=0)
        # B, n_filters/2, H', W' -> B, n_filters, H', W' (1x1 conv, no padding)
        self.convbnrelu3 = conv2DBatchNormRelu(mid_channels, n_filters, k_size=1,
                                               stride=1, padding=0)

    def forward(self, x):
        x = self.convbnrelu1(x)
        x = self.deconvbnrelu2(x)
        return self.convbnrelu3(x)


class FRRU(nn.Module):
    """Full Resolution Residual Unit for FRRN.

    ``y`` is the pooled stream; ``z`` is the residual stream at ``scale`` times
    ``y``'s resolution. ``z`` must have 32 channels, since ``conv_res`` output
    (32 channels) is added to it.
    """

    def __init__(self, prev_channels, out_channels, scale):
        super(FRRU, self).__init__()
        self.scale = scale
        self.prev_channels = prev_channels
        self.out_channels = out_channels

        self.conv1 = conv2DBatchNormRelu(prev_channels + 32, out_channels, k_size=3,
                                         stride=1, padding=1)
        self.conv2 = conv2DBatchNormRelu(out_channels, out_channels, k_size=3,
                                         stride=1, padding=1)
        self.conv_res = nn.Conv2d(out_channels, 32, kernel_size=1, stride=1, padding=0)

    def forward(self, y, z):
        # Functional pooling avoids constructing a new nn.MaxPool2d every call.
        x = torch.cat([y, F.max_pool2d(z, self.scale, self.scale)], dim=1)
        y_prime = self.conv2(self.conv1(x))

        x = self.conv_res(y_prime)
        upsample_size = tuple(s * self.scale for s in y_prime.shape[-2:])
        x = F.interpolate(x, size=upsample_size, mode='nearest')
        z_prime = z + x
        return y_prime, z_prime


class RU(nn.Module):
    """Residual Unit for FRRN: conv-BN-ReLU, conv-BN, plus identity shortcut.

    Padding is ``kernel_size // 2`` so odd kernels preserve the spatial size.
    NOTE(review): ``strides`` > 1 shrinks the conv path and breaks the identity add.
    """

    def __init__(self, channels, kernel_size=3, strides=1):
        super(RU, self).__init__()
        padding = kernel_size // 2
        self.conv1 = conv2DBatchNormRelu(channels, channels, k_size=kernel_size,
                                         stride=strides, padding=padding)
        self.conv2 = conv2DBatchNorm(channels, channels, k_size=kernel_size,
                                     stride=strides, padding=padding)

    def forward(self, x):
        return self.conv2(self.conv1(x)) + x


class residualConvUnit(nn.Module):
    """RefineNet residual conv unit: (ReLU, conv) x 2 plus identity shortcut.

    Convs use ``kernel_size // 2`` padding so the output matches the input size.
    """

    def __init__(self, channels, kernel_size=3):
        super(residualConvUnit, self).__init__()
        padding = kernel_size // 2
        self.residual_conv_unit = nn.Sequential(
            # Not in-place: ``x`` is reused as the residual and belongs to the caller.
            nn.ReLU(inplace=False),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=kernel_size, padding=padding),
        )

    def forward(self, x):
        return self.residual_conv_unit(x) + x


class multiResolutionFusion(nn.Module):
    """Fuse a high- and optional low-resolution input by conv + bilinear upsample + sum.

    ``high_shape`` / ``low_shape`` supply input channel counts at index 1;
    ``low_shape=None`` means only the high branch exists. The 3x3 convs are
    padded so each branch is scaled exactly by its ``up_scale_*`` factor.
    """

    def __init__(self, channels, up_scale_high, up_scale_low, high_shape, low_shape):
        super(multiResolutionFusion, self).__init__()
        self.up_scale_high = up_scale_high
        self.up_scale_low = up_scale_low

        self.conv_high = nn.Conv2d(high_shape[1], channels, kernel_size=3, padding=1)
        if low_shape is not None:
            self.conv_low = nn.Conv2d(low_shape[1], channels, kernel_size=3, padding=1)

    def forward(self, x_high, x_low):
        high_upsampled = F.interpolate(self.conv_high(x_high), scale_factor=self.up_scale_high,
                                       mode='bilinear', align_corners=False)
        if x_low is None:
            return high_upsampled

        low_upsampled = F.interpolate(self.conv_low(x_low), scale_factor=self.up_scale_low,
                                      mode='bilinear', align_corners=False)
        return low_upsampled + high_upsampled


class chainedResidualPooling(nn.Module):
    """ReLU -> 5x5 stride-1 max-pool -> padded 3x3 conv, added to the input.

    The addition requires ``channels == input_shape[1]``.
    """

    def __init__(self, channels, input_shape):
        super(chainedResidualPooling, self).__init__()
        self.chained_residual_pooling = nn.Sequential(
            # Not in-place: ``x`` is reused as the residual and belongs to the caller.
            nn.ReLU(inplace=False),
            nn.MaxPool2d(5, 1, 2),
            nn.Conv2d(input_shape[1], channels, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.chained_residual_pooling(x) + x


class pyramidPooling(nn.Module):
    """PSPNet pyramid pooling: concat ``x`` with pooled, projected, re-upsampled paths.

    Each path average-pools to roughly ``pool_size x pool_size``, applies a
    1x1 conv-BN-ReLU to ``in_channels / len(pool_sizes)`` channels, and is
    bilinearly resized back to the input's (H, W).
    """

    def __init__(self, in_channels, pool_sizes):
        super(pyramidPooling, self).__init__()
        path_channels = int(in_channels / len(pool_sizes))
        self.paths = [conv2DBatchNormRelu(in_channels, path_channels, 1, 1, 0, bias=False)
                      for _ in pool_sizes]
        self.path_module_list = nn.ModuleList(self.paths)
        self.pool_sizes = pool_sizes

    def forward(self, x):
        h, w = x.shape[2:]
        output_slices = [x]
        for module, pool_size in zip(self.path_module_list, self.pool_sizes):
            # Separate kernel extents so non-square inputs pool per axis.
            kernel = (int(h / pool_size), int(w / pool_size))
            out = F.avg_pool2d(x, kernel, kernel, 0)
            out = module(out)
            out = F.interpolate(out, size=(h, w), mode='bilinear', align_corners=False)
            output_slices.append(out)
        return torch.cat(output_slices, dim=1)


class bottleNeckPSP(nn.Module):
    """Projection bottleneck: 1x1 -> 3x3 (strided or dilated) -> 1x1, plus 1x1 projected shortcut."""

    def __init__(self, in_channels, mid_channels, out_channels, stride, dilation=1):
        super(bottleNeckPSP, self).__init__()
        # Undilated 3x3 uses padding 1; dilated uses padding == dilation (size-preserving).
        dil = dilation if dilation > 1 else 1
        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
        # stride is applied in both branches so it always matches the shortcut (cb4).
        self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, stride=stride,
                                        padding=dil, bias=False, dilation=dil)
        self.cb3 = conv2DBatchNorm(mid_channels, out_channels, 1, 1, 0, bias=False)
        self.cb4 = conv2DBatchNorm(in_channels, out_channels, 1, stride, 0, bias=False)

    def forward(self, x):
        conv = self.cb3(self.cbr2(self.cbr1(x)))
        residual = self.cb4(x)
        return F.relu(conv + residual, inplace=True)


class bottleNeckIdentifyPSP(nn.Module):
    """Identity bottleneck: 1x1 -> 3x3 (optionally dilated) -> 1x1 back to ``in_channels``.

    ``stride`` is accepted for signature parity but not used; the block never
    changes spatial size.
    """

    def __init__(self, in_channels, mid_channels, stride, dilation=1):
        super(bottleNeckIdentifyPSP, self).__init__()
        dil = dilation if dilation > 1 else 1
        self.cbr1 = conv2DBatchNormRelu(in_channels, mid_channels, 1, 1, 0, bias=False)
        self.cbr2 = conv2DBatchNormRelu(mid_channels, mid_channels, 3, stride=1,
                                        padding=dil, bias=False, dilation=dil)
        self.cb3 = conv2DBatchNorm(mid_channels, in_channels, 1, 1, 0, bias=False)

    def forward(self, x):
        residual = x
        x = self.cb3(self.cbr2(self.cbr1(x)))
        return F.relu(x + residual, inplace=True)


class residualBlockPSP(nn.Module):
    """ResNet stage of ``n_blocks`` bottlenecks: one projection block, then identity blocks.

    When ``dilation > 1`` the stage does not downsample (stride forced to 1).
    """

    def __init__(self, n_blocks, in_channels, mid_channels, out_channels, stride, dilation=1):
        super(residualBlockPSP, self).__init__()
        if dilation > 1:
            stride = 1

        layers = [bottleNeckPSP(in_channels, mid_channels, out_channels, stride, dilation)]
        # The projection block counts toward n_blocks (the original built n_blocks + 1).
        layers.extend(bottleNeckIdentifyPSP(out_channels, mid_channels, stride, dilation)
                      for _ in range(n_blocks - 1))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)