''' * Building units for neural networks: conv23D units, residual units, unet units, upsampling unit and so on. * all kinds of loss functions: softmax, 2d softmax, 3d softmax, dice, multi-organ dice, focal loss, attention based loss... * kinds of test units * First implemented in Dec. 2016, and the latest updation is Dec. 2017. * Dong Nie ''' import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import torch.nn.init as init import torch.autograd as autograd from torch.autograd import Variable from torch.autograd import Function from itertools import repeat ''' To have an easy switch between nn.Conv2d and nn.Conv3d ''' class conv23DUnit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True, dilation=1, nd=2): super(conv23DUnit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) elif nd==3: self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) else: self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) def forward(self, x): return self.conv(x) ''' To have an easy switch between nn.Conv2d and nn.Conv3d, together with BN ''' class conv23D_bn_Unit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True, dilation=1, nd=2): super(conv23D_bn_Unit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, 
bias=bias, dilation=dilation) self.bn = nn.BatchNorm2d(out_channels) elif nd==3: self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm3d(out_channels) else: self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm1d(out_channels) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) # self.relu = nn.ReLU() def forward(self, x): return self.bn(self.conv(x)) ''' To have an easy switch between nn.Conv2d and nn.Conv3d, together with BN and relu ''' class conv23D_bn_relu_Unit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, groups=1, bias=True, dilation=1, nd=2): super(conv23D_bn_relu_Unit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm2d(out_channels) elif nd==3: self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm3d(out_channels) else: self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm1d(out_channels) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) self.relu = nn.ReLU() def forward(self, x): # print 'x.shape: ',x.shape # xx = self.conv(x) # print 'xx.shape: ', xx.shape return self.relu(self.bn(self.conv(x))) ''' To have an easy switch between nn.ConvTranspose2d and nn.ConvTranspose3d ''' class convTranspose23DUnit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, 
padding=0, output_padding=0, groups=1, bias=True, dilation=1, nd=2): super(convTranspose23DUnit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) elif nd==3: self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) else: self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) def forward(self, x): return self.conv(x) ''' To have an easy switch between nn.ConvTranspose2d and nn.ConvTranspose3d, together with BN ''' class convTranspose23D_bn_Unit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, nd=2): super(convTranspose23D_bn_Unit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm2d(out_channels) elif nd==3: self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm3d(out_channels) else: self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = 
nn.BatchNorm1d(out_channels) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) # self.relu = nn.ReLU() def forward(self, x): return self.bn(self.conv(x)) ''' To have an easy switch between nn.ConvTranspose2d and nn.ConvTranspose3d, together with BN and relu ''' class convTranspose23D_bn_relu_Unit(nn.Module): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=True, dilation=1, nd=2): super(convTranspose23D_bn_relu_Unit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm2d(out_channels) elif nd==3: self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm3d(out_channels) else: self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding, groups=groups, bias=bias, dilation=dilation) self.bn = nn.BatchNorm1d(out_channels) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) self.relu = nn.ReLU() def forward(self, x): return self.relu(self.bn(self.conv(x))) ''' To have an easy switch between nn.Dropout2d and nn.Dropout3d ''' class dropout23DUnit(nn.Module): def __init__(self, prob=0, nd=2): super(dropout23DUnit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.dp = nn.Dropout2d(p=prob) elif nd==3: self.dp = nn.Dropout3d(p=prob) else: self.dp = nn.Dropout(p=prob) def forward(self, x): return self.dp(x) ''' To have an easy switch between nn.maxPool2D and nn.maxPool3D ''' class 
maxPool23DUinit(nn.Module): def __init__(self, kernel_size, stride, padding=1, dilation=1, nd=2): super(maxPool23DUinit, self).__init__() assert nd==1 or nd==2 or nd==3, 'nd is not correctly specified!!!!, it should be {1,2,3}' if nd==2: self.pool1 = nn.MaxPool2d(kernel_size=kernel_size,stride=stride,padding=padding, dilation=dilation) elif nd==3: self.pool1 = nn.MaxPool3d(kernel_size=kernel_size,stride=stride,padding=padding, dilation=dilation) else: self.pool1 = nn.MaxPool1d(kernel_size=kernel_size,stride=stride,padding=padding, dilation=dilation) def forward(self, x): return self.pool1(x) ''' ordinary conv block ''' class convUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu): super(convUnit, self).__init__() self.conv = nn.Conv2d(in_size, out_size, kernel_size, stride, padding) init.xavier_uniform(self.conv.weight, gain = np.sqrt(2.0)) init.constant(self.conv.bias, 0) self.bn = nn.BatchNorm2d(out_size) self.relu = nn.ReLU() def forward(self, x): return self.relu(self.bn(self.conv(x))) ''' two-layer residual unit: two conv without BN and identity mapping ''' class residualUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu, nd=2): super(residualUnit, self).__init__() self.conv1 = conv23DUnit(in_size, out_size, kernel_size, stride, padding, nd=nd) # init.xavier_uniform(self.conv1.weight, gain = np.sqrt(2.0)) #or gain=1 # init.constant(self.conv1.bias, 0) self.conv2 = conv23DUnit(out_size, out_size, kernel_size, stride, padding, nd=nd) # init.xavier_uniform(self.conv2.weight, gain = np.sqrt(2.0)) #or gain=1 # init.constant(self.conv2.bias, 0) def forward(self, x): return F.relu(self.conv2(F.elu(self.conv1(x))) + x) ''' two-layer residual unit: two conv with relu and identity mapping ''' class residualUnit1(nn.Module): def __init__(self, in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu, nd=2): super(residualUnit1, self).__init__() 
self.conv1_bn_relu = conv23D_bn_relu_Unit(in_size, out_size, kernel_size, stride, padding, nd=nd) # self.conv1 = nn.Conv2d(in_size, out_size, kernel_size, stride, padding, bias=False) # init.xavier_uniform(self.conv1.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv1.bias, 0) # self.bn1 = nn.BatchNorm2d(out_size) self.relu = nn.ReLU() self.conv2_bn_relu = nn.conv23D_bn_relu_Unit(out_size, out_size, kernel_size, stride, padding, nd=nd) # self.conv2 = nn.Conv2d(out_size, out_size, kernel_size, stride, padding, bias=False) # init.xavier_uniform(self.conv2.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv2.bias, 0) # self.bn2 = nn.BatchNorm2d(out_size) def forward(self, x): identity_data = x output = self.conv1_bn_relu(x) output = self.conv2_bn_relu(output) # output = self.relu(self.bn1(self.conv1(x))) # output = self.bn2(self.conv2(output)) output = torch.add(output,identity_data) output = self.relu(output) return output ''' three-layer residual unit: three conv with BN and identity mapping this one doesn't change the size of channels, which means the in_size is same with out_size input: x output: bottleneck residual block By Dong Nie ''' class residualUnit3(nn.Module): def __init__(self, in_size, out_size, isDilation=None, isEmptyBranch1=None, activation=F.relu, nd=2): super(residualUnit3, self).__init__() # mid_size = in_size/2 mid_size = out_size/2 ###I think it should better be half the out size instead of the input size # print 'line 74, in and out size are, ',in_size,' ',mid_size if isDilation: self.conv1_bn_relu = conv23D_bn_relu_Unit(in_channels=in_size, out_channels=mid_size, kernel_size=1, stride=1, padding=0, dilation=2, nd=nd) else: self.conv1_bn_relu = conv23D_bn_relu_Unit(in_channels=in_size, out_channels=mid_size, kernel_size=1, stride=1, padding=0, nd=nd) # init.xavier_uniform(self.conv1.weight, gain=np.sqrt(2.0)) #or gain=1 # # init.constant(self.conv1.bias, 0) # self.bn1 = nn.BatchNorm2d(mid_size) self.relu = nn.ReLU() if 
isDilation: self.conv2_bn_relu = conv23D_bn_relu_Unit(in_channels=mid_size, out_channels=mid_size, kernel_size=3, stride=1, padding=2, dilation=2, nd=nd) else: self.conv2_bn_relu = conv23D_bn_relu_Unit(in_channels=mid_size, out_channels=mid_size, kernel_size=3, stride=1, padding=1, nd=nd) # init.xavier_uniform(self.conv2.weight, gain=np.sqrt(2.0)) #or gain=1 # # init.constant(self.conv2.bias, 0) # self.bn2 = nn.BatchNorm2d(mid_size) if isDilation: self.conv3_bn = conv23D_bn_Unit(in_channels=mid_size, out_channels=out_size, kernel_size=1, stride=1, padding=0, dilation=2, nd=nd) else: self.conv3_bn = conv23D_bn_Unit(in_channels=mid_size, out_channels=out_size, kernel_size=1, stride=1, padding=0, nd=nd) # init.xavier_uniform(self.conv3.weight, gain=np.sqrt(2.0)) #or gain=1 # # init.constant(self.conv3.bias, 0) # self.bn3 = nn.BatchNorm2d(out_size) self.isEmptyBranch1 = isEmptyBranch1 if in_size!=out_size or isEmptyBranch1==False: if isDilation: self.convX_bn = conv23D_bn_Unit(in_channels=in_size, out_channels=out_size, kernel_size=1, stride=1, padding=0, dilation=2, nd=nd) else: self.convX_bn = conv23D_bn_Unit(in_channels=in_size, out_channels=out_size, kernel_size=1, stride=1, padding=0, nd=nd) # self.bnX = nn.BatchNorm2d(out_size) def forward(self, x): identity_data = x # print 'line 94, size of x is ', x.size() # output = self.relu(self.bn1(self.conv1(x))) # output = self.relu(self.bn2(self.conv2(output))) # output = self.bn3(self.conv3(output)) output = self.conv1_bn_relu(x) output = self.conv2_bn_relu(output) output = self.conv3_bn(output) outSZ = output.size() idSZ = identity_data.size() if outSZ[1]!=idSZ[1] or self.isEmptyBranch1==False: identity_data = self.convX_bn(identity_data) # identity_data = self.bnX(self.convX(identity_data)) # print output.size(), identity_data.size() output = torch.add(output,identity_data) output = self.relu(output) return output ''' long-term residual unit, there is a long way (a lot of convs) before the residual addition By Dong 
Nie ''' class longResidualUnit(nn.Module): def __init__(self,in_size, out_size, kernel_size=3,stride=1, padding=1, activation=F.relu, nd=2): super(residualUnit1, self).__init__() self.conv1_bn = conv23D_bn_Unit(in_channels=in_size, out_channels=out_size, kernel_size=kernel_size, stride=stride, padding=padding, nd=nd) # init.xavier_uniform(self.conv1.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv1.bias, 0) # self.bn1 = nn.BatchNorm2d(out_size) self.relu = nn.ReLU() def forward(self, x): identity_data = x output = self.conv1_bn(x) output = torch.add(output,identity_data) output = self.relu(output) return output ''' Residual upsampling block with long-range residual connection input: x: the current layer you want to consider bridge: the one you want to combine (using summary instead of concatenation) from lower layers output: residual output space_dropout_rate: the rate to make several of the feature maps to be zero (to avoid the correlation within a feature map, we use spatial dropout instead of traditional dropout By Dong Nie ''' class ResUpUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, spatial_dropout_rate=0, isConvDilation=None, nd=2): super(ResUpUnit, self).__init__() # self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=True) # init.xavier_uniform(self.up.weight, gain = np.sqrt(2.0)) #or gain=1 # init.constant(self.up.bias, 0) self.nd = nd self.up = convTranspose23D_bn_relu_Unit(in_size, out_size, kernel_size=4, stride=2, padding=1, nd=nd) self.conv = residualUnit3(out_size, out_size, isDilation=isConvDilation, nd=nd) # self.SpatialDroput = nn.SpatialDropout(space_dropout_rate) self.dp = dropout23DUnit(prob=spatial_dropout_rate,nd=nd) # self.dropout2d = nn.Dropout2d(spatial_dropout_rate) self.spatial_dropout_rate = spatial_dropout_rate self.conv2 = residualUnit3(out_size, out_size, isDilation=isConvDilation, isEmptyBranch1=False, nd=nd) # print 'line 147, in_size is 
',out_size,' out_size is ',out_size self.relu = nn.ReLU() def center_crop(self, layer, target_size): #we should make it adust to 2d/3d if self.nd ==2: batch_size, n_channels, layer_width, layer_height = layer.size() elif self.nd==3: batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size() xy1 = (layer_width - target_size) // 2 if self.nd==3: return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)] return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)] def forward(self, x, bridge):#bridge is the corresponding lower layer # print 'x.shape: ', x.size() up = self.up(x) # print 'up.shape: ',up.size() # print 'line 158 ',up.size() #crop1 = self.center_crop(bridge, up.size()[2]) # print 'bridge.size: ', bridge.size() crop1 = bridge # crop1_dp = self.SpatialDroput(crop1) if self.spatial_dropout_rate>0: crop1 = self.dp(crop1) out = self.relu(torch.add(up, crop1)) # print 'line 161' out = self.conv(out) # print 'line 161 is ', a.size() # out = self.relu(a) # out = self.relu(self.conv2(out)) out = self.conv2(out) # print 'line 163 is ', out.size() return out ''' Dilated residual module with two residual block (with diltion k) Note, here we keep the same resolution size among successive layers input: x: input feature map k: dilation k output: y ''' class DilatedResUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3, stride=1, dilation=2, nd=2): super(DilatedResUnit,self).__init__() self.nd = nd mid_size = out_size/1 padding = dilation*(kernel_size-1)/2 self.conv1_bn_relu = conv23D_bn_relu_Unit(in_channels=in_size, out_channels=mid_size, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, nd=nd) self.conv2_bn_relu = conv23D_bn_relu_Unit(in_channels=mid_size, out_channels=mid_size, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, nd=nd) self.relu = nn.ReLU() def forward(self, x): #dilated module 1 conv1_1 = self.conv1_bn_relu(x) conv1_2 = 
self.conv2_bn_relu(conv1_1) out1 = torch.add(x,conv1_2) # we should make sure x is same size with conv2 #dilated module 2 conv2_1 = self.conv1_bn_relu(out1) conv2_2 = self.conv2_bn_relu(conv2_1) out = torch.add(conv2_1,conv2_2) return out ''' Basic Residual upsampling block with long-range residual connection. Note we didn't have two short residual blocks after the long-range residual operation. input: x: the current layer you want to consider bridge: the one you want to combine (using summary instead of concatenation) from lower layers output: residual output By Dong Nie ''' class BaseResUpUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False, nd=2): super(BaseResUpUnit, self).__init__() # self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=True) # init.xavier_uniform(self.up.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.up.bias, 0) self.nd = nd self.up = convTranspose23D_bn_relu_Unit(in_size, out_size, kernel_size=4, stride=2, padding=1, nd=nd) self.relu = nn.ReLU() def center_crop(self, layer, target_size): #we should make it adust to 2d/3d if self.nd ==2: batch_size, n_channels, layer_width, layer_height = layer.size() elif self.nd==3: batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size() xy1 = (layer_width - target_size) // 2 if self.nd==3: return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)] return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)] def forward(self, x, bridge):#bridge is the corresponding lower layer up = self.up(x) # print 'line 158 ',up.size() # crop1 = self.center_crop(bridge, up.size()[2]) crop1 = bridge out = self.relu(torch.add(up, crop1)) # print 'line 161' # a = self.conv(out) # # print 'line 161 is ', a.size() # out = self.relu(a) # out = self.relu(self.conv2(out)) # print 'line 163 is ', out.size() return out ''' upsample unit: first upsample (directly 
interpolation, and then conv_bn_relu ''' class upsampleUnit(nn.Module): # Implements resize-convolution def __init__(self, in_channels, out_channels, nd=2): super(upsampleUnit, self).__init__() self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest') self.conv1_bn_relu = conv23D_bn_relu_Unit(in_channels, out_channels, 3, stride=1, padding=1, nd=nd) def forward(self, x): return self.conv1_bn_relu(x) ''' unetConvUnit: actually, a basic unit composed of two-layer convolutional layers ''' class unetConvUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, nd=2): super(unetConvUnit, self).__init__() # self.conv = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True) # init.xavier_uniform(self.conv.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv.bias, 0) self.conv = conv23DUnit(in_size, out_size, kernel_size=3, stride=1, padding=1, nd=nd) self.conv2 = conv23DUnit(out_size, out_size, kernel_size=3, stride=1, padding=1, nd=nd) # self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True) # init.xavier_uniform(self.conv2.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv2.bias, 0) self.activation = activation def forward(self, x): out = self.activation(self.conv(x)) out = self.activation(self.conv2(out)) return out ''' unet upsampling block ''' class unetUpUnit(nn.Module): def __init__(self, in_size, out_size, kernel_size=3, activation=F.relu, space_dropout=False, nd=2): super(unetUpUnit, self).__init__() # self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=4, stride=2, padding=1, bias=True) # init.xavier_uniform(self.up.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.up.bias, 0) self.up = convTranspose23DUnit(in_size, out_size, kernel_size=4, stride=2, padding=1, nd=nd) # self.conv = nn.Conv2d(in_size, out_size, kernel_size=3, stride=1, padding=1, bias=True) # init.xavier_uniform(self.conv.weight, gain=np.sqrt(2.0)) #or gain=1 # 
init.constant(self.conv.bias, 0) self.conv = conv23DUnit(in_size, out_size, kernel_size=3, stride=1, padding=1, nd=nd) #has some problem with the in_size # self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, stride=1, padding=1, bias=True) # init.xavier_uniform(self.conv2.weight, gain=np.sqrt(2.0)) #or gain=1 # init.constant(self.conv2.bias, 0) self.conv2 = conv23DUnit(out_size, out_size, kernel_size=3, stride=1, padding=1, nd=nd) self.activation = activation self.nd = nd def center_crop(self, layer, target_size): #we should make it adust to 2d/3d if self.nd ==2: batch_size, n_channels, layer_width, layer_height = layer.size() elif self.nd==3: batch_size, n_channels, layer_width, layer_height, layer_depth = layer.size() xy1 = (layer_width - target_size) // 2 if self.nd==3: return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size), xy1:(xy1 + target_size)] return layer[:, :, xy1:(xy1 + target_size), xy1:(xy1 + target_size)] def forward(self, x, bridge):#bridge is the corresponding lower layer up = self.up(x) # crop1 = self.center_crop(bridge, up.size()[2]) crop1 = bridge out = torch.cat([up, crop1], 1) out = self.activation(self.conv(out)) out = self.activation(self.conv2(out)) return out ''' The weighted cross entropy loss for 3D data, usually used in voxel wise segmentation. input: predict: 5D tensor, even Variable: NxCxHXWXD target: 4D tensor, even Variable: NxHxWxD weight_map: 4D tensor, NxHxWxD output: loss Feb, 2018 By Dong Nie ''' class WeightedCrossEntropy3d(nn.Module): def __init__(self, weight = None, size_average=True, reduce = True, ignore_label=255): '''weight (Tensor, optional): a manual rescaling weight given to each class. 
If given, has to be a Tensor of size "nclasses"''' super(WeightedCrossEntropy3d, self).__init__() self.weight = weight self.size_average = size_average self.ignore_label = ignore_label self.nll_loss = nn.NLLLoss(weight, size_average=False, reduce=False) self.logsoftmax = nn.LogSoftmax(dim=1) def forward(self, predict, target, weight_map=None): """ Args: predict:(n, c, h, w, d) target:(n, h, w, d): 0,1,...,C-1 """ assert not target.requires_grad assert predict.dim() == 5 assert target.dim() == 4 assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2)) assert predict.size(4) == target.size(3), "{0} vs {1} ".format(predict.size(4), target.size(3)) n, c, h, w, d = predict.size() logits = self.logsoftmax(predict) #NxCxWxHxD voxel_loss = self.nll_loss(logits, target) #NxWxHxD weighted_voxel_loss = weight_map*voxel_loss loss = torch.sum(weighted_voxel_loss)/(n*h*w*d) # print 'cross-entropy-loss: ',type(loss) return loss ''' The cross entropy loss for 3D data, usually used in voxel wise segmentation. input: predict: 5D tensor, even Variable: NxCxHXWXD target: 4D tensor, even Variable: NxHxWxD output: loss By Dong Nie ''' class CrossEntropy3d(nn.Module): def __init__(self, weight = None, size_average=True, ignore_label=255): '''weight (Tensor, optional): a manual rescaling weight given to each class. 
If given, has to be a Tensor of size "nclasses"''' super(CrossEntropy3d, self).__init__() self.weight = weight self.size_average = size_average self.ignore_label = ignore_label def forward(self, predict, target): """ Args: predict:(n, c, h, w, d) target:(n, h, w, d): 0,1,...,C-1 """ assert not target.requires_grad assert predict.dim() == 5 assert target.dim() == 4 assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(2)) assert predict.size(4) == target.size(3), "{0} vs {1} ".format(predict.size(4), target.size(3)) n, c, h, w, d = predict.size() target_mask = (target >= 0) * (target != self.ignore_label) #actually, it doesn't convert to one-hot format target = target[target_mask] #N*1 predict = predict.transpose(1, 2).transpose(2, 3).transpose(3, 4).contiguous() # n, h, w, d, c predict = predict[target_mask.view(n, h, w, d, 1).repeat(1, 1, 1, 1, c)].view(-1, c) #N*C loss = F.cross_entropy(predict, target, weight = self.weight, size_average = self.size_average) # print 'cross-entropy-loss: ',type(loss) return loss ''' The cross entropy loss for 2D data, usually used in pixel wise segmentation. input: predict: 4D tensor, even Variable target: 3D tensor, even Variable output: loss By Dong Nie ''' class CrossEntropy2d(nn.Module): def __init__(self, weight = None, size_average=True, ignore_label=255): '''weight (Tensor, optional): a manual rescaling weight given to each class. 
If given, has to be a Tensor of size "nclasses"''' super(CrossEntropy2d, self).__init__() self.weight = weight self.size_average = size_average self.ignore_label = ignore_label def forward(self, predict, target): """ Args: predict:(n, c, h, w) target:(n, h, w): 0,1,...,C-1 """ assert not target.requires_grad assert predict.dim() == 4 assert target.dim() == 3 assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3)) n, c, h, w = predict.size() target_mask = (target >= 0) * (target != self.ignore_label) target = target[target_mask] predict = predict.transpose(1, 2).transpose(2, 3).contiguous() predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) loss = F.cross_entropy(predict, target, weight = self.weight, size_average = self.size_average) # print 'cross-entropy-loss: ',type(loss) return loss ''' The cross entropy loss for 2D dataset (usually used in FCN based segmentation) implemented by using nn.LLLoss2d, we use F.log_softmax to implement log_softmax in 2D. This one is believed to be stable and also faster ''' class CrossEntropyLoss2d(nn.Module): def __init__(self, weight=None, size_average=True): super(CrossEntropyLoss2d, self).__init__() self.nll_loss = nn.NLLLoss2d(weight, size_average) def forward(self, inputs, targets): return self.nll_loss(F.log_softmax(inputs), targets) ''' This criterion is a implementation of Focal Loss, which is proposed in Focal Loss for Dense Object Detection. Now I implement it in the environment of segmentation Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class]) The losses are averaged across observations for each minibatch. 
class myFocalLoss(nn.Module):
    '''
    Focal loss (from "Focal Loss for Dense Object Detection") adapted to
    segmentation:

        Loss(x, class) = - alpha[class] * (1 - softmax(x)[class])^gamma
                           * log(softmax(x)[class])

    Input:
        class_num: the number of categories
        alpha: per-class scaling factors (tensor or sequence with class_num
               entries, or a single value). When None, every class uses 0.25,
               which is the constant the previous implementation always
               applied regardless of this argument.
        gamma: reduces the relative loss of well classified examples so the
               training focuses on hard, misclassified ones
        size_average: True (default) averages the loss over all pixels,
               False sums it
        inputs: 4D tensor of predicted maps (before softmax), NxCxWxH
        targets: 3D tensor of integer ground-truth labels, NxWxH
    Output:
        The averaged (or summed) loss as a scalar tensor
    10/31/2017
    By Dong Nie
    '''
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(myFocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.full((class_num, 1), 0.25)
        elif not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        assert inputs.dim() == 4, 'inputs size should be 4: NXCXWXH'

        # log_softmax stays finite where log(softmax(x)) would underflow to -inf.
        log_probs = F.log_softmax(inputs, dim=1)
        index = targets.detach().long().unsqueeze(1)  # Nx1xWxH
        # Picking the target class with gather is equivalent to multiplying by
        # a one-hot mask and summing over the channel axis.
        log_p = log_probs.gather(1, index).view(-1, 1)
        probs = log_p.exp()

        # Use a local copy on the right device instead of mutating self.alpha.
        alpha = self.alpha.detach().to(device=inputs.device, dtype=inputs.dtype).reshape(-1)
        if alpha.numel() == 1:
            alpha = alpha.expand(inputs.size(1))
        alpha_t = alpha[index.reshape(-1)].view(-1, 1)

        batch_loss = -alpha_t * torch.pow(1 - probs, self.gamma) * log_p
        if self.size_average:
            return batch_loss.mean()
        return batch_loss.sum()


def _dice_one_hot(labels, like):
    '''One-hot encode integer `labels` (N x ...) along dim 1, shaped and typed like `like` (N x C x ...).'''
    one_hot = torch.zeros_like(like)
    one_hot.scatter_(1, labels.detach().long().unsqueeze(1), 1)
    return one_hot


def myDiceLoss4Organ(input, target):
    '''
    Soft dice loss computed jointly over all classes.
    input:
        input: logits of size BatchxnclassesxHxW (or ...xHxWxD); softmax is
               applied inside
        target: integer label map of size BatchxHxW (or BatchxHxWxD)
    output:
        tensor of shape (1,) holding 1 - 2*sum(p*t) / (sum(p) + sum(t) + 2*eps)
    Runs on whatever device `input` lives on.
    '''
    assert input.dim() == 4 or input.dim() == 5, "Input must be a 4D Tensor or a 5D tensor."
    eps = 0.000001

    probs = F.softmax(input, dim=1)
    target_one_hot = _dice_one_hot(target, probs)

    intersect = torch.sum(probs * target_one_hot)
    # Not a true set union: it is the sum of both volumes, as in the dice formula.
    union = torch.sum(probs) + torch.sum(target_one_hot) + 2 * eps
    return (1 - 2 * intersect / union).reshape(1)


def myDiceLoss4Organs(inputs, targets):
    '''
    Soft dice loss averaged over classes, so several organs are handled in
    one call.
    input:
        inputs: logits, NxCxHxW (or NxCxHxWxD); softmax is applied inside
        targets: integer label map, NxHxW (or NxHxWxD)
    output:
        loss: tensor of shape (1,), the mean over classes of
              1 - 2*sum(p_c*t_c) / (sum(p_c) + sum(t_c) + 2*eps)
    Runs on whatever device `inputs` lives on.
    '''
    eps = 0.000001
    num_classes = inputs.size(1)

    probs = F.softmax(inputs, dim=1)
    targets_one_hot = _dice_one_hot(targets, probs)

    # Bring the class axis to the front so every class becomes one flat row.
    flat_probs = probs.transpose(0, 1).reshape(num_classes, -1)
    flat_targets = targets_one_hot.transpose(0, 1).reshape(num_classes, -1)

    intersect = (flat_probs * flat_targets).sum(1)
    union = flat_probs.sum(1) + flat_targets.sum(1) + 2 * eps
    per_class_loss = 1 - 2 * intersect / union
    return (per_class_loss.sum() / num_classes).reshape(1)


# Function to calculate the Generalised Dice Loss defined in Sudre, C. et al.
# (2017), "Generalised Dice overlap as a deep learning loss function for
# highly unbalanced segmentations" (description of generalised_dice_loss).
def generalised_dice_loss(prediction, ground_truth, weight_map=None, type_weight='Square'):
    '''
    Generalised Dice Loss as defined in Sudre, C. et al. (2017), "Generalised
    Dice overlap as a deep learning loss function for highly unbalanced
    segmentations", DLMIA 2017.
    The only issue I have is that it may only suit two types.

    The previous body was written against TensorFlow (`tf.*`), which this
    file's visible imports do not provide; this is a PyTorch port of the same
    computation.

    Input:
        prediction: the logits (before softmax), shape (n_voxels, n_classes)
        ground_truth: the segmentation ground truth, shape (n_voxels,)
        weight_map: optional per-voxel weights, shape (n_voxels,); each
            voxel's weight is applied to all of its class entries
        type_weight: type of weighting allowed between labels (choice
            between Square (square of inverse of volume),
            Simple (inverse of volume) and Uniform (no weighting))
    Output:
        the loss, a scalar tensor
    Raises:
        ValueError: if type_weight is not one of the three choices
    '''
    if type_weight not in ('Square', 'Simple', 'Uniform'):
        raise ValueError("The variable type_weight \"{}\"" \
                         "is not defined.".format(type_weight))

    probs = F.softmax(prediction, dim=1)
    labels = ground_truth.detach().long().reshape(-1, 1)
    one_hot = torch.zeros_like(probs).scatter_(1, labels, 1)

    if weight_map is not None:
        # A column vector broadcasts one weight per voxel across the classes.
        voxel_weights = weight_map.to(device=probs.device, dtype=probs.dtype).reshape(-1, 1)
        ref_vol = (voxel_weights * one_hot).sum(0)
        intersect = (voxel_weights * one_hot * probs).sum(0)
        seg_vol = (voxel_weights * probs).sum(0)
    else:
        ref_vol = one_hot.sum(0)
        intersect = (one_hot * probs).sum(0)
        seg_vol = probs.sum(0)

    if type_weight == 'Square':
        weights = torch.reciprocal(ref_vol ** 2)
    elif type_weight == 'Simple':
        weights = torch.reciprocal(ref_vol)
    else:
        weights = torch.ones_like(ref_vol)

    # Classes absent from the ground truth get an infinite weight; replace it
    # with the largest finite weight so the ratio stays finite.
    is_inf = torch.isinf(weights)
    finite_weights = torch.where(is_inf, torch.zeros_like(weights), weights)
    weights = torch.where(is_inf,
                          torch.ones_like(weights) * finite_weights.max(),
                          weights)

    generalised_dice_numerator = 2 * (weights * intersect).sum()
    generalised_dice_denominator = (weights * (seg_vol + ref_vol)).sum()
    generalised_dice_score = generalised_dice_numerator / generalised_dice_denominator
    return 1 - generalised_dice_score


class myWeightedDiceLoss4Organs(nn.Module):
    '''
    Weighted soft dice loss over several classes at once; it adapts to
    0/1/2/3... (background, organ1, organ2, ...) label maps.
    input:
        inputs: logits, tensor of size NxCxHxWxD (softmax is applied inside)
        targets: integer label map, tensor of size NxHxWxD
    output:
        loss: tensor of shape (1,), the organWeights-weighted mean of the
              per-class dice losses
    organWeights must hold one weight per class (background included).
    Runs on whatever device `inputs` lives on.
    '''
    def __init__(self, organIDs=None, organWeights=None):
        super(myWeightedDiceLoss4Organs, self).__init__()
        # None stands for the historical default [1]; a fresh list per
        # instance avoids sharing one mutable default between instances.
        self.organIDs = [1] if organIDs is None else organIDs
        self.organWeights = [1] if organWeights is None else organWeights

    def forward(self, inputs, targets, save=True):
        """
            Args:
                inputs:(n, c, h, w, d)
                targets:(n, h, w, d): 0,1,...,C-1
                save: unused, kept for call compatibility
        """
        assert not targets.requires_grad
        assert inputs.dim() == 5, inputs.shape
        assert targets.dim() == 4, targets.shape
        assert inputs.size(0) == targets.size(0), "{0} vs {1} ".format(inputs.size(0), targets.size(0))
        assert inputs.size(2) == targets.size(1), "{0} vs {1} ".format(inputs.size(2), targets.size(1))
        assert inputs.size(3) == targets.size(2), "{0} vs {1} ".format(inputs.size(3), targets.size(2))
        assert inputs.size(4) == targets.size(3), "{0} vs {1} ".format(inputs.size(4), targets.size(3))

        eps = 0.000001
        num_classes = inputs.size(1)
        assert num_classes == len(self.organWeights), 'organ weights is not matched with organs (bg should be included)'

        probs = F.softmax(inputs, dim=1)
        targets_one_hot = torch.zeros_like(probs)
        targets_one_hot.scatter_(1, targets.detach().long().unsqueeze(1), 1)

        # Bring the class axis to the front so every class becomes one flat row.
        flat_probs = probs.transpose(0, 1).reshape(num_classes, -1)
        flat_targets = targets_one_hot.transpose(0, 1).reshape(num_classes, -1)

        intersect = (flat_probs * flat_targets).sum(1)
        union = flat_probs.sum(1) + flat_targets.sum(1) + 2 * eps
        per_class_loss = 1 - 2 * intersect / union

        weights = torch.as_tensor(self.organWeights, dtype=probs.dtype, device=probs.device)
        out = (weights * per_class_loss).sum() / weights.sum()
        return out.reshape(1)


# This is generalized dice loss for organs, which means you can compute dice
# for more than one organ at a time (description of GeneralizedDiceLoss4Organs).
class GeneralizedDiceLoss4Organs(nn.Module):
    '''
    Generalized dice loss for organs, which means you can compute dice for
    more than one organ at a time. Each class is weighted by
    1 / (target_volume^2 + eps); classes absent from the target get weight 0.
    input:
        inputs: logits, tensor of size NxCxHxWxD (softmax is applied inside)
        targets: integer label map, tensor of size NxHxWxD
    output:
        loss: tensor of shape (1,)
    organIDs must list one entry per class (background included).
    Runs on whatever device `inputs` lives on.
    Sep, 2017
    '''
    def __init__(self, organIDs=None, size_average=True):
        super(GeneralizedDiceLoss4Organs, self).__init__()
        # None stands for the historical default [1]; a fresh list per
        # instance avoids sharing one mutable default between instances.
        self.organIDs = [1] if organIDs is None else organIDs
        self.size_average = size_average

    def forward(self, inputs, targets, save=True):
        """
            Args:
                inputs:(n, c, h, w, d)
                targets:(n, h, w, d): 0,1,...,C-1
                save: unused, kept for call compatibility
        """
        assert not targets.requires_grad
        assert inputs.dim() == 5, inputs.shape
        assert targets.dim() == 4, targets.shape
        assert inputs.size(0) == targets.size(0), "{0} vs {1} ".format(inputs.size(0), targets.size(0))
        assert inputs.size(2) == targets.size(1), "{0} vs {1} ".format(inputs.size(2), targets.size(1))
        assert inputs.size(3) == targets.size(2), "{0} vs {1} ".format(inputs.size(3), targets.size(2))
        assert inputs.size(4) == targets.size(3), "{0} vs {1} ".format(inputs.size(4), targets.size(3))

        eps = 0.000001
        num_classes = inputs.size(1)
        assert num_classes == len(self.organIDs), 'organ weights is not matched with organs (bg should be included)'

        probs = F.softmax(inputs, dim=1)
        targets_one_hot = torch.zeros_like(probs)
        targets_one_hot.scatter_(1, targets.detach().long().unsqueeze(1), 1)

        # Bring the class axis to the front so every class becomes one flat row.
        flat_probs = probs.transpose(0, 1).reshape(num_classes, -1)
        flat_targets = targets_one_hot.transpose(0, 1).reshape(num_classes, -1)

        target_sum = flat_targets.sum(1)
        result_sum = flat_probs.sum(1)
        # An absent organ would otherwise receive a huge weight (1/eps) and
        # dominate the union, so its weight is forced to zero.
        organ_weight = torch.where(target_sum == 0,
                                   torch.zeros_like(target_sum),
                                   1 / (target_sum ** 2 + eps))

        intersect = (organ_weight * (flat_probs * flat_targets).sum(1)).sum()
        # 2*eps is added once per class, as the per-class accumulation did.
        union = (organ_weight * (result_sum + target_sum)).sum() + num_classes * 2 * eps
        return (1 - 2 * intersect / union).reshape(1)


class WeightedDiceLoss4Organs(Function):
    '''
    Dice score with a hand-written backward pass (author's note: so far, not
    succeed).
    input:
        inputs: predicted segmentation map, NxCxHxW (the class with the
                largest value per pixel is taken as the prediction)
        targets: integer label map, NxHxW
    output:
        tensor of shape (1,) holding the SUM over classes of
        2*intersect/(|result| + |target| + 2*eps), i.e. a score, not a loss

    NOTE(review): this is a legacy-style autograd Function (instance methods
    and state on self). Recent PyTorch releases refuse to call such Functions
    through autograd; porting it to static forward/backward with ctx would
    change how callers invoke it, so that is left for a follow-up.
    '''
    def __init__(self, *args, **kwargs):
        self.numOfCategories = 2

    def forward(self, inputs, targets, save=True):
        if save:
            self.save_for_backward(inputs, targets)
        eps = 0.000001
        num_classes = inputs.size(1)
        self.numOfCategories = num_classes

        _, results = inputs.max(1)
        # Only drop the channel axis if max() kept it; a blanket squeeze()
        # would also drop a batch (or spatial) axis of size 1.
        if results.dim() == inputs.dim():
            results = results.squeeze(1)
        results = results.long()

        self.inputs = inputs.detach().float().clone()
        self.targets_ = targets.detach().long().clone()
        targets = self.targets_

        # One-hot maps are stored channel-last: N x H x W x C.
        channel_dim = results.dim()
        one_hot_shape = tuple(results.size()) + (num_classes,)
        self.results_one_hot = torch.zeros(one_hot_shape, dtype=torch.float32, device=inputs.device)
        self.results_one_hot.scatter_(channel_dim, results.unsqueeze(channel_dim), 1)
        self.targets_one_hot = torch.zeros(one_hot_shape, dtype=torch.float32, device=inputs.device)
        self.targets_one_hot.scatter_(channel_dim, targets.unsqueeze(channel_dim), 1)

        flat_results = self.results_one_hot.reshape(-1, num_classes)
        flat_targets = self.targets_one_hot.reshape(-1, num_classes)
        # Binary values, so the sum equals the sum of squares.
        self.intersect = (flat_results * flat_targets).sum(0)
        self.union = flat_results.sum(0) + flat_targets.sum(0) + (2 * eps)

        return (2 * self.intersect / self.union).sum().reshape(1)

    def backward(self, grad_output):
        inputs, _ = self.saved_tensors  # we need probabilities for input
        targets = self.targets_one_hot  # we need binary for targets
        intersects, unions = self.intersect, self.union

        per_class_grads = []
        for i in range(0, self.numOfCategories):
            pred_i = inputs[:, i, ...]
            target_i = targets[..., i]
            union = unions[i]
            intersect = intersects[i]
            gt = torch.div(target_i, union)
            IoU2 = intersect / (union * union)
            pred = torch.mul(pred_i, IoU2)
            dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
            # Sign convention kept from the original: class 0 uses
            # +grad_output, every other class uses -grad_output.
            sign = 1 if i == 0 else -1
            per_class_grads.append(torch.mul(dDice, sign * grad_output[0]))
        # Stack along the channel axis so the gradient has the input's shape
        # (N x C x ...); concatenating along dim 0 produced (2N x ...).
        grad_input = torch.stack(per_class_grads, 1)
        return grad_input, None


class topK_RegLoss(nn.Module):
    '''
    topK_Loss for regression problems. Only the largest `topK` fraction of
    the element-wise absolute errors contributes to the loss.
    topK: fraction in (0, 1] of all elements (batch included) to keep
    size_average: stored but unused
    '''
    def __init__(self, topK, size_average=True):
        super(topK_RegLoss, self).__init__()
        self.size_average = size_average
        self.topK = topK

    def forward(self, preds, targets):
        """
            Args:
                preds: tensor of any shape
                targets: tensor with the same shape as preds
            Returns:
                mean of the k largest absolute errors, k = int(numel * topK)
                clamped to [1, numel]
        """
        assert not targets.requires_grad
        assert preds.shape == targets.shape, 'dim of preds and targets are different'

        errors = torch.abs(preds - targets).reshape(-1)
        # Count every element whatever the rank; clamp so k is never 0
        # (mean of nothing is NaN) nor larger than the number of elements.
        k = int(errors.numel() * self.topK)
        k = max(1, min(k, errors.numel()))
        # Ordering is irrelevant for a mean, so skip the sort.
        values, _ = torch.topk(errors, k, largest=True, sorted=False)
        return torch.mean(values)


# topK_Loss for regression problems (description of RelativeThreshold_RegLoss).
class RelativeThreshold_RegLoss(nn.Module):
    '''
    Thresholded loss for regression problems. Only the elements whose
    relative error |pred - target| / (|target| + eps) reaches `threshold`
    contribute; the loss is the mean absolute error over those elements.
    size_average: stored but unused
    '''
    def __init__(self, threshold, size_average=True):
        super(RelativeThreshold_RegLoss, self).__init__()
        self.size_average = size_average
        self.eps = 1e-7
        self.threshold = threshold

    def forward(self, preds, targets):
        """
            Args:
                preds: tensor of any shape
                targets: tensor with the same shape as preds
            Returns:
                scalar tensor; zero when no element reaches the threshold
        """
        assert not targets.requires_grad
        assert preds.shape == targets.shape, 'dim of preds and targets are different'

        dist = torch.abs(preds - targets).reshape(-1)
        # eps is added after abs() so negative targets are protected from a
        # zero denominator just like positive ones.
        baseV = torch.abs(targets.reshape(-1)) + self.eps
        relativeDist = torch.div(dist, baseV)
        mask = relativeDist.ge(self.threshold)
        largerLossVec = torch.masked_select(dist, mask)
        if largerLossVec.numel() == 0:
            # The mean of an empty selection is NaN and would poison the
            # optimiser; return a zero that is still attached to the graph.
            return dist.sum() * 0
        return torch.mean(largerLossVec)


def unique(tensor1d):
    '''Return the sorted unique values of a tensor as a numpy array.'''
    # detach() lets this work on tensors that require grad.
    return np.unique(tensor1d.detach().cpu().numpy())


class DiceLoss4Organs(Function):
    '''
    Hard (argmax based) dice loss for organs, which means you can compute
    dice for more than one organ at a time.
    Warning (from the author): the backward cannot be done automatically, and
    this kind of dice expansion to multi-class is not correct; so far, not
    succeed.
    input:
        inputs: predicted segmentation map, NxCx...; the class with the
                largest value per element is taken as the prediction
        targets: integer label map, Nx...
    output:
        loss: tensor of shape (1,), the per-sample weighted dice losses
              averaged (size_average=True) or summed over the batch

    NOTE(review): legacy-style autograd Function. forward() computes the
    value outside of autograd and never stores the state backward() reads
    (saved tensors, self.intersect, self.union, self.target_), so backward()
    cannot run as written.
    '''
    def __init__(self, organIDs=None, organWeights=None, size_average=True):
        super(DiceLoss4Organs, self).__init__()
        # None stands for the historical default [1]; a fresh list per
        # instance avoids sharing one mutable default between instances.
        self.organIDs = [1] if organIDs is None else organIDs
        self.organWeights = [1] if organWeights is None else organWeights
        self.size_average = size_average

    def forward(self, inputs, targets):
        eps = 0.000001
        totalWeight = np.sum(self.organWeights)

        _, results = inputs.max(1)  # maximize the channels
        # Only drop the channel axis if max() kept it; a blanket squeeze()
        # would also drop a batch axis of size 1.
        if results.dim() == inputs.dim():
            results = results.squeeze(1)

        batch_size = targets.size(0)
        loss = np.zeros(batch_size)
        for k in range(batch_size):
            result_k = results[k].detach().cpu()
            target_k = targets[k].detach().cpu()
            for ind, organ_id in enumerate(self.organIDs):
                # Binary masks of the current organ.
                result = (result_k == organ_id).float()
                target = (target_k == organ_id).float()
                # Elementwise product + sum works for any rank (torch.dot
                # only accepts 1-D tensors); intersect is floored at eps.
                intersect = max(eps, float((result * target).sum()))
                union = float(result.sum()) + float(target.sum()) + 2 * eps  # not the real union
                loss[k] += 1.0 * self.organWeights[ind] / totalWeight * (1.0 - (2.0 * intersect / union))

        if self.size_average:
            diceLoss = loss.mean()
        else:
            diceLoss = loss.sum()
        return torch.tensor([float(diceLoss)], dtype=torch.float32,
                            device=inputs.device, requires_grad=True)

    def backward(self, grad_output):
        input, _ = self.saved_tensors
        intersect, union = self.intersect, self.union
        target = self.target_
        gt = torch.div(target, union)
        IoU2 = intersect / (union * union)
        pred = torch.mul(input[:, 1], IoU2)
        dDice = torch.add(torch.mul(gt, 2), torch.mul(pred, -4))
        grad_input = torch.cat((torch.mul(dDice, -grad_output[0]),
                                torch.mul(dDice, grad_output[0])), 0)
        return grad_input, None


# GDL loss for the image reconstruction.
# Note, this copy of gdl loss is successful (description of gdl_loss).
class gdl_loss(nn.Module):
    '''
    GDL (gradient difference) loss for the image reconstruction.
    Note, this copy of gdl loss is successful. It may not be able to pass the
    gradcheck, that's because of the float/double issue, but it is fine. If we
    change the filter to torch.DoubleTensor, and also cast the pred and gt as
    double tensor, it will pass the grad check. However, float is fine with
    the real application.

    pNorm: exponent applied to the absolute gradient differences.
    Every channel is filtered independently, so inputs with more than one
    channel are supported.
    By Dong Nie
    '''
    def __init__(self, pNorm=2):
        super(gdl_loss, self).__init__()
        self.convX = nn.Conv2d(1, 1, kernel_size=(1, 2), stride=1, padding=(0, 1), bias=False)
        self.convY = nn.Conv2d(1, 1, kernel_size=(2, 1), stride=1, padding=(1, 0), bias=False)

        # Fixed finite-difference filters; they are never trained.
        filterX = torch.FloatTensor([[[[-1, 1]]]])  # 1x2
        filterY = torch.FloatTensor([[[[1], [-1]]]])  # 2x1

        self.convX.weight = torch.nn.Parameter(filterX, requires_grad=False)
        self.convY.weight = torch.nn.Parameter(filterY, requires_grad=False)
        self.pNorm = pNorm

    def forward(self, pred, gt):
        """
            Args:
                pred: (n, c, h, w) prediction
                gt: (n, c, h, w) ground truth, must not require grad
            Returns:
                sum of |abs(grad gt) - abs(grad pred)|^pNorm over both
                directions, divided by n*c*h*w
        """
        assert not gt.requires_grad
        assert pred.dim() == 4
        assert gt.dim() == 4
        assert pred.size() == gt.size(), "{0} vs {1} ".format(pred.size(), gt.size())

        n, c, h, w = gt.shape
        # The filters are single-channel, so fold the channels into the batch.
        pred_planes = pred.reshape(n * c, 1, h, w)
        gt_planes = gt.reshape(n * c, 1, h, w)

        pred_dx = torch.abs(self.convX(pred_planes))
        pred_dy = torch.abs(self.convY(pred_planes))
        gt_dx = torch.abs(self.convX(gt_planes))
        gt_dy = torch.abs(self.convY(gt_planes))

        mat_loss_x = torch.abs(gt_dx - pred_dx) ** self.pNorm
        mat_loss_y = torch.abs(gt_dy - pred_dy) ** self.pNorm

        # Batch x Channel x width x height
        mean_loss = (torch.sum(mat_loss_x) + torch.sum(mat_loss_y)) / (n * c * h * w)
        return mean_loss


def Wasserstein_Distance(D_real, D_fake):
    '''Wasserstein loss: the critic output on real data minus that on fake data.'''
    Wasserstein_Dist = D_real - D_fake
    return Wasserstein_Dist


def calc_gradient_penalty(netD, real_data, fake_data):
    '''
    Calculate the gradient penalty for WGAN-GP.
    Input:
        netD: the critic; called on a batch shaped like real_data
        real_data, fake_data: tensors of identical shape, batch first
    Output:
        scalar tensor mean((||grad_x D(x)||_2 - 1)^2), evaluated at random
        points on the lines between real and fake samples. The penalty
        coefficient (lambda) is NOT applied here.
    Runs on whatever device real_data lives on.
    '''
    batch_size = real_data.shape[0]
    # One mixing coefficient per sample, drawn uniformly from [0, 1) so the
    # interpolates lie between the real and the fake sample.
    alpha_shape = [batch_size] + [1] * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

    fake_data = fake_data.to(real_data.device)
    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    # Cut the link to the generator graph; gradients are taken w.r.t. the
    # interpolated points themselves.
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)

    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones_like(disc_interpolates),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]

    # The norm must cover all of a sample's elements, not just its channels.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def computeAttentionWeight(prob):
    '''
    Compute the attention weight of a probability. The probability is first
    folded around 0.5 (p -> 1 - p for p > 0.5); folded values in [0.3, 0.5]
    map to 10/3*p^2 + 7/3*p - 1, which rises from 0 at p=0.3 to 1 at p=0.5;
    probabilities in [0, 0.3) and (0.7, 1] get weight 0.
    Inputs:
        prob: probability, a Python number, a torch tensor or a numpy array
    Outputs:
        res: the attention weight(s), same kind as the input
    '''
    if torch.is_tensor(prob):
        folded = torch.where(prob > 0.5, 1 - prob, prob)
        res = 10.0 / 3 * folded * folded + 7.0 / 3 * folded - 1
        return torch.where(folded < 0.3, torch.zeros_like(res), res)
    if isinstance(prob, np.ndarray):
        folded = np.where(prob > 0.5, 1 - prob, prob)
        res = 10.0 / 3 * folded * folded + 7.0 / 3 * folded - 1
        return np.where(folded < 0.3, 0.0, res)

    if prob > 0.5:
        prob = 1 - prob
    if prob < 0.3:
        return 0
    return 10.0 / 3 * prob * prob + 7.0 / 3 * prob - 1


# compute the attention weights: for prob belongs to [0.3,0.7] we use negative
# weights (easy samples for the segmenter, hard samples for the discriminator);
# and we use positive weights belonging to [0,0.3) and (0.7,1] (hard samples
# for the segmenter, easy samples for the discriminator)
# (description of computeSampleAttentionWeight).
def computeSampleAttentionWeight(prob):
    '''
    Compute the sample attention weight 12.5*(prob-0.5)^2 - 0.5.
    For prob in [0.3, 0.7] the weight is negative (easy samples for the
    segmenter, hard samples for the discriminator); for prob in [0, 0.3) and
    (0.7, 1] it is positive (hard samples for the segmenter, easy samples for
    the discriminator).
    Inputs:
        prob: probability
    Outputs:
        res: attention we compute
    '''
    centered = prob - 0.5
    return 12.5 * centered * centered - 0.5


def computeVoxelAttentionWeight(prob):
    '''
    Compute the difficulty-aware (confidence-aware) voxel-wise weight, which
    is something like focal loss: high confidence regions keep a small weight
    while low confidence regions get more attention during training.
    Inputs:
        prob: probability = 1 - confidence probability
    Outputs:
        res: attention we compute, prob squared
    '''
    return prob * prob


class FeatureExtractor(nn.Module):
    '''
    Wrap the first (feature_layer + 1) children of `cnn.features` as a module.
    cnn: a network exposing a `features` container
    feature_layer: index of the last child of cnn.features to keep
    '''
    def __init__(self, cnn, feature_layer=8):
        super(FeatureExtractor, self).__init__()
        kept_layers = list(cnn.features.children())[:(feature_layer + 1)]
        self.features = nn.Sequential(*kept_layers)

    def forward(self, x):
        return self.features(x)


def adjust_learning_rate(optimizer, lr):
    '''
    Lower the learning rate of every parameter group that is above `lr`
    down to `lr`; groups already at or below `lr` are left alone. The
    current rate of each group is printed before it is adjusted.
    Inputs:
        optimizer: object exposing `param_groups`, a list of dicts with an
                   "lr" entry
        lr: the upper bound to enforce
    Outputs:
        lr, unchanged
    '''
    for param_group in optimizer.param_groups:
        # Function-call form of print works on both Python 2 and 3; the
        # double space reproduces the output of the old print statement.
        print("current lr is  {}".format(param_group["lr"]))
        if param_group["lr"] > lr:
            param_group["lr"] = lr
    return lr