#============================================
__author__ = "Sachin Mehta"
__license__ = "MIT"
__maintainer__ = "Sachin Mehta"
# File Description: This file contains the 3-D CNN building blocks and the
# ESPNet segmentation model; it is adapted from ESPNet and Y-Net.
# ESPNET: https://arxiv.org/pdf/1803.06815.pdf
# Y-Net: https://arxiv.org/abs/1806.01313
# ==============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Not used by the code below any more; the import is retained so the module
# namespace is unchanged.
from torch.autograd import Variable  # noqa: F401

# Batch-norm hyper-parameters shared by the CBR / CB / BR blocks.
# NOTE(review): in PyTorch, `momentum` is the weight given to the *new* batch
# statistic, so 0.95 makes the running estimates follow the latest batch very
# closely. Left unchanged to preserve training behaviour -- confirm intent.
_BN_MOMENTUM = 0.95
_BN_EPS = 1e-03


def _same_padding(kSize, dilation=1):
    """Return the padding that keeps the spatial size for an odd kernel at stride 1."""
    return (kSize - 1) // 2 * dilation


def _apply_layers(layers, x):
    """Feed ``x`` through every module of ``layers`` in order and return the result."""
    for layer in layers:
        x = layer(x)
    return x


class CBR(nn.Module):
    """3-D convolution followed by batch normalization and ReLU."""

    def __init__(self, nIn, nOut, kSize, stride=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: cubic kernel size
        :param stride: convolution stride
        """
        super().__init__()
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride,
                              padding=_same_padding(kSize), bias=False)
        self.bn = nn.BatchNorm3d(nOut, momentum=_BN_MOMENTUM, eps=_BN_EPS)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        """
        :param input: 5-D feature map accepted by ``nn.Conv3d``
        :return: activated, normalized convolution output
        """
        return self.act(self.bn(self.conv(input)))


class CB(nn.Module):
    """3-D convolution followed by batch normalization (no activation)."""

    def __init__(self, nIn, nOut, kSize, stride=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: cubic kernel size
        :param stride: convolution stride
        """
        super().__init__()
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride,
                              padding=_same_padding(kSize), bias=False)
        self.bn = nn.BatchNorm3d(nOut, momentum=_BN_MOMENTUM, eps=_BN_EPS)

    def forward(self, input):
        """
        :param input: 5-D feature map accepted by ``nn.Conv3d``
        :return: normalized convolution output
        """
        return self.bn(self.conv(input))


class C(nn.Module):
    """Plain (optionally grouped) 3-D convolution without bias."""

    def __init__(self, nIn, nOut, kSize, stride=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: cubic kernel size
        :param stride: convolution stride
        :param groups: number of convolution groups
        """
        super().__init__()
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride,
                              padding=_same_padding(kSize), bias=False,
                              groups=groups)

    def forward(self, input):
        """
        :param input: 5-D feature map accepted by ``nn.Conv3d``
        :return: convolution output
        """
        return self.conv(input)


class DownSamplerA(nn.Module):
    """Down-sample with a single 3x3x3 CBR block of stride 2."""

    def __init__(self, nIn, nOut):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        """
        super().__init__()
        self.conv = CBR(nIn, nOut, 3, 2)

    def forward(self, input):
        """
        :param input: 5-D feature map
        :return: output of the strided CBR block
        """
        return self.conv(input)


class DownSamplerB(nn.Module):
    """
    Down-sampling block: a 1x1x1 reduction and a strided 3x3x3 convolution,
    followed by four parallel dilated convolutions whose outputs are summed
    hierarchically, concatenated, and passed through BN + ReLU.
    """

    def __init__(self, nIn, nOut):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        """
        super().__init__()
        k = 4
        n = nOut // k
        # The first branch absorbs the remainder so the concatenation has nOut channels.
        n1 = nOut - (k - 1) * n
        self.c1 = nn.Sequential(CBR(nIn, n, 1, 1), C(n, n, 3, 2))
        # Attribute names are kept as-is (they are state_dict keys); the actual
        # dilation rates are 1, 2, 3 and 4.
        self.d1 = CDilated(n, n1, 3, 1, 1)
        self.d2 = CDilated(n, n, 3, 1, 2)
        self.d4 = CDilated(n, n, 3, 1, 3)
        self.d8 = CDilated(n, n, 3, 1, 4)
        self.bn = BR(nOut)

    def forward(self, input):
        """
        :param input: 5-D feature map
        :return: down-sampled feature map with ``nOut`` channels
        """
        output1 = self.c1(input)
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        # Hierarchical addition of the dilated branches before concatenation.
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        combine = torch.cat([d1, add1, add2, add3], 1)

        # Residual connection only when the shapes happen to agree.
        if input.size() == combine.size():
            combine = input + combine
        return self.bn(combine)


class BR(nn.Module):
    """Batch normalization followed by ReLU."""

    def __init__(self, nOut):
        """
        :param nOut: number of channels to normalize
        """
        super().__init__()
        self.bn = nn.BatchNorm3d(nOut, momentum=_BN_MOMENTUM, eps=_BN_EPS)
        self.act = nn.ReLU(inplace=True)  # nn.PReLU(nOut)

    def forward(self, input):
        """
        :param input: 5-D feature map
        :return: normalized and activated feature map
        """
        return self.act(self.bn(input))


class CDilated(nn.Module):
    """Dilated (optionally grouped) 3-D convolution; no normalization or activation."""

    def __init__(self, nIn, nOut, kSize, stride=1, d=1, groups=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param kSize: cubic kernel size
        :param stride: convolution stride
        :param d: dilation rate
        :param groups: number of convolution groups
        """
        super().__init__()
        self.conv = nn.Conv3d(nIn, nOut, kSize, stride=stride,
                              padding=_same_padding(kSize, d), bias=False,
                              dilation=d, groups=groups)

    def forward(self, input):
        """
        :param input: 5-D feature map accepted by ``nn.Conv3d``
        :return: convolution output
        """
        return self.conv(input)


class InputProjectionA(nn.Module):
    """
    Repeatedly down-samples the input so that it matches the spatial size of a
    deeper feature map. Each step is a 3x3x3 average pooling with stride 2,
    i.e. every step roughly halves each spatial dimension.
    """

    def __init__(self, samplingTimes):
        """
        :param samplingTimes: number of stride-2 pooling steps to apply
        """
        super().__init__()
        self.pool = nn.ModuleList()
        for _ in range(samplingTimes):
            # pyramid-based approach for down-sampling
            self.pool.append(nn.AvgPool3d(3, stride=2, padding=1))

    def forward(self, input):
        """
        :param input: 5-D input volume
        :return: down-sampled volume (pyramid-based approach)
        """
        return _apply_layers(self.pool, input)


class DilatedParllelResidualBlockB1(nn.Module):  # with k=4
    """
    ESP-style block: a 1x1x1 reduction, four parallel dilated 3x3x3
    convolutions fused by hierarchical addition, concatenation, batch
    normalization, an optional residual connection, and ReLU.
    """

    def __init__(self, nIn, nOut, stride=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param stride: stride of the dilated convolutions (2 down-samples)
        """
        super().__init__()
        k = 4
        n = nOut // k
        # The first branch absorbs the remainder so the concatenation has nOut channels.
        n1 = nOut - (k - 1) * n
        self.c1 = CBR(nIn, n, 1, 1)
        # Attribute names are kept as-is (they are state_dict keys); the actual
        # dilation rates are 1, 1, 2 and 2.
        self.d1 = CDilated(n, n1, 3, stride, 1)
        self.d2 = CDilated(n, n, 3, stride, 1)
        self.d4 = CDilated(n, n, 3, stride, 2)
        self.d8 = CDilated(n, n, 3, stride, 2)
        self.bn = nn.BatchNorm3d(nOut)

    def forward(self, input):
        """
        :param input: 5-D feature map
        :return: feature map with ``nOut`` channels
        """
        output1 = self.c1(input)
        d1 = self.d1(output1)
        d2 = self.d2(output1)
        d4 = self.d4(output1)
        d8 = self.d8(output1)

        # Hierarchical addition of the dilated branches before concatenation.
        add1 = d2
        add2 = add1 + d4
        add3 = add2 + d8
        combine = self.bn(torch.cat([d1, add1, add2, add3], 1))

        # Residual connection only when input and output shapes agree.
        if input.size() == combine.size():
            combine = input + combine
        return F.relu(combine, inplace=True)


class ASPBlock(nn.Module):  # with k=4
    """
    Sums four parallel CB branches with kernel sizes 3, 5, 7 and 9, adds the
    input as a residual when shapes agree, and applies ReLU.
    """

    def __init__(self, nIn, nOut, stride=1):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param stride: accepted for interface compatibility; not used
        """
        super().__init__()
        self.d1 = CB(nIn, nOut, 3, 1)
        self.d2 = CB(nIn, nOut, 5, 1)
        self.d4 = CB(nIn, nOut, 7, 1)
        self.d8 = CB(nIn, nOut, 9, 1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, input):
        """
        :param input: 5-D feature map
        :return: fused multi-kernel feature map
        """
        combine = self.d1(input) + self.d2(input) + self.d4(input) + self.d8(input)
        if input.size() == combine.size():
            combine = input + combine
        return self.act(combine)


class UpSampler(nn.Module):
    """
    Up-sample the feature maps by 2
    """

    def __init__(self, nIn, nOut):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        """
        super().__init__()
        self.up = CBR(nIn, nOut, 3, 1)

    def forward(self, inp):
        """
        :param inp: 5-D feature map
        :return: CBR output, trilinearly up-sampled by a factor of 2
        """
        # F.upsample is deprecated; F.interpolate is its replacement.
        # align_corners=False is the default behaviour, stated explicitly.
        return F.interpolate(self.up(inp), scale_factor=2, mode='trilinear',
                             align_corners=False)


class PSPDec(nn.Module):
    """
    Inspired or Adapted from Pyramid Scene Network paper: pool the input to a
    fraction of its size, convolve, and resize back to the input size.
    """

    def __init__(self, nIn, nOut, downSize):
        """
        :param nIn: number of input channels
        :param nOut: number of output channels
        :param downSize: factor applied to each spatial dimension before pooling
        """
        super().__init__()
        self.scale = downSize
        self.features = CBR(nIn, nOut, 3, 1)

    def forward(self, x):
        """
        :param x: 5-D feature map
        :return: feature map with the same spatial size as ``x``
        """
        assert x.dim() == 5, "PSPDec expects a 5-D tensor"
        spatial = tuple(x.size()[2:])
        # Clamp to 1 so that small feature maps never request an empty pooled
        # volume (int(dim * scale) is 0 whenever dim * scale < 1).
        pooled_size = tuple(max(1, int(dim * self.scale)) for dim in spatial)
        x_down = F.adaptive_avg_pool3d(x, output_size=pooled_size)
        return F.interpolate(self.features(x_down), size=spatial,
                             mode='trilinear', align_corners=False)


class ESPNet(nn.Module):
    """
    3-D encoder-decoder segmentation network built from ESP-style blocks, with
    skip connections between encoder and decoder levels and a PSP-style
    multi-scale module in front of the classifier.
    """

    def __init__(self, classes=4, channels=1):
        """
        :param classes: number of output classes (channels of the prediction)
        :param channels: number of channels of the input volume
        """
        super().__init__()
        # Not used in forward(); kept so the registered sub-modules are unchanged.
        self.input1 = InputProjectionA(1)
        self.input2 = InputProjectionA(1)

        initial = 16  # feature maps at level 1
        config = [32, 128, 256, 256]  # feature maps at level 2 and onwards
        reps = [2, 2, 3]

        ### ENCODER
        # all dimensions are listed with respect to an input of size 4 x 128 x 128 x 128
        self.level0 = CBR(channels, initial, 7, 2)  # initial x 64 x 64 x 64

        self.level1 = nn.ModuleList()
        for i in range(reps[0]):
            in_ch = initial if i == 0 else config[0]
            self.level1.append(DilatedParllelResidualBlockB1(in_ch, config[0]))  # config[0] x 64 x 64 x 64

        # downsample the feature maps
        self.level2 = DilatedParllelResidualBlockB1(config[0], config[1], stride=2)  # config[1] x 32 x 32 x 32
        self.level_2 = nn.ModuleList()
        for _ in range(reps[1]):
            self.level_2.append(DilatedParllelResidualBlockB1(config[1], config[1]))  # config[1] x 32 x 32 x 32

        # downsample the feature maps
        self.level3_0 = DilatedParllelResidualBlockB1(config[1], config[2], stride=2)  # config[2] x 16 x 16 x 16
        self.level_3 = nn.ModuleList()
        for _ in range(reps[2]):
            self.level_3.append(DilatedParllelResidualBlockB1(config[2], config[2]))  # config[2] x 16 x 16 x 16

        ### DECODER
        # upsample the feature maps
        self.up_l3_l2 = UpSampler(config[2], config[1])  # config[1] x 32 x 32 x 32
        # Note the 2 in below line. You need this because you are concatenating feature maps from encoder
        # with upsampled feature maps
        self.merge_l2 = DilatedParllelResidualBlockB1(2 * config[1], config[1])  # config[1] x 32 x 32 x 32
        self.dec_l2 = nn.ModuleList()
        for _ in range(reps[0]):
            self.dec_l2.append(DilatedParllelResidualBlockB1(config[1], config[1]))  # config[1] x 32 x 32 x 32

        self.up_l2_l1 = UpSampler(config[1], config[0])  # config[0] x 64 x 64 x 64
        # Note the 2 in below line. You need this because you are concatenating feature maps from encoder
        # with upsampled feature maps
        self.merge_l1 = DilatedParllelResidualBlockB1(2 * config[0], config[0])  # config[0] x 64 x 64 x 64
        self.dec_l1 = nn.ModuleList()
        for _ in range(reps[0]):
            self.dec_l1.append(DilatedParllelResidualBlockB1(config[0], config[0]))  # config[0] x 64 x 64 x 64

        self.dec_l1.append(CBR(config[0], classes, 3, 1))  # classes x 64 x 64 x 64
        # We use ESP block without reduction step because the number of input feature maps are very small (i.e. 4 in
        # our case)
        self.dec_l1.append(ASPBlock(classes, classes))

        # Using PSP module to learn the representations at different scales
        self.pspModules = nn.ModuleList()
        scales = [0.2, 0.4, 0.6, 0.8]
        for sc in scales:
            self.pspModules.append(PSPDec(classes, classes, sc))

        # Classifier
        self.classifier = nn.Sequential(
            CBR((len(scales) + 1) * classes, classes, 3, 1),
            ASPBlock(classes, classes),  # classes x 64 x 64 x 64
            nn.Upsample(scale_factor=2),  # classes x 128 x 128 x 128
            CBR(classes, classes, 7, 1),  # classes x 128 x 128 x 128
            C(classes, classes, 1, 1),  # classes x 128 x 128 x 128
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize conv weights from N(0, sqrt(2 / n)) and reset batch-norm affine parameters."""
        for m in self.modules():
            if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
                # n = kernel volume * number of output channels
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, input1, inp_res=(128, 128, 128), inpSt2=False):
        """
        :param input1: 5-D input volume (batch, channels, and three spatial dims)
        :param inp_res: resolution the input is resized to before the encoder.
            In training mode, or when ``None``, it is replaced by the input's
            spatial size rounded up to the next multiple of 8.
        :param inpSt2: accepted for interface compatibility; not used
        :return: class scores resized to the spatial size of ``input1``
        """
        dim0 = input1.size(2)
        dim1 = input1.size(3)
        dim2 = input1.size(4)

        if self.training or inp_res is None:
            # input resolution should be divisible by 8
            inp_res = (math.ceil(dim0 / 8) * 8,
                       math.ceil(dim1 / 8) * 8,
                       math.ceil(dim2 / 8) * 8)
        if inp_res:
            input1 = F.adaptive_avg_pool3d(input1, output_size=inp_res)

        # Encoder
        out_l0 = self.level0(input1)
        out_l1 = _apply_layers(self.level1, out_l0)
        out_l2 = _apply_layers(self.level_2, self.level2(out_l1))
        out_l3 = _apply_layers(self.level_3, self.level3_0(out_l2))

        # Decoder with skip connections from the encoder
        dec_l3_l2 = self.up_l3_l2(out_l3)
        merge_l2 = self.merge_l2(torch.cat([dec_l3_l2, out_l2], 1))
        dec_l2 = _apply_layers(self.dec_l2, merge_l2)

        dec_l2_l1 = self.up_l2_l1(dec_l2)
        merge_l1 = self.merge_l1(torch.cat([dec_l2_l1, out_l1], 1))
        dec_l1 = _apply_layers(self.dec_l1, merge_l1)

        # Multi-scale context: one concatenation instead of growing the tensor
        # step by step (same channel order: dec_l1 first, then each PSP scale).
        psp_outs = torch.cat(
            [dec_l1] + [layer(dec_l1) for layer in self.pspModules], 1)

        decoded = self.classifier(psp_outs)
        return F.interpolate(decoded, size=(dim0, dim1, dim2),
                             mode='trilinear', align_corners=False)


if __name__ == '__main__':
    # Smoke test: run one random volume through the network and print the output size.
    channels = 4
    bSz = 1
    classes = 4
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    volume = torch.randn(bSz, channels, 80, 80, 80, device=device)
    model = ESPNet(classes=classes, channels=channels).eval().to(device)
    with torch.no_grad():
        out = model(volume)
    print(out.size())