Python torch.nn.Softmax2d() Examples

The following are 18 code examples of torch.nn.Softmax2d(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch.nn, or try the search function.
Example #1
Source File: attention.py    From Conditional-Batch-Norm with MIT License 6 votes vote down vote up
def __init__(self, config):
		"""Build the attention MLP and the softmax used on its output.

		Args:
			config: nested mapping; the entry
				``config['model']['image']['attention']['no_attention_mlp']``
				gives the number of hidden units of the MLP.
		"""
		super(Attention, self).__init__()

		# hidden layer units of the MLP
		self.mlp_units = config['model']['image']['attention']['no_attention_mlp']

		# MLP for concatenated feature map and question embedding.
		# NOTE(review): the input width is hard-coded to 3072 and the MLP is
		# moved to the GPU unconditionally.
		self.fc = nn.Sequential(
			nn.Linear(3072, self.mlp_units),
			nn.ReLU(inplace=True),
			nn.Linear(self.mlp_units, 1),
			nn.ReLU(inplace=True),
		).cuda()

		# Meant to produce probability values across the height and width of the
		# feature map.
		# NOTE(review): nn.Softmax2d normalizes over the channel dimension at each
		# spatial location - confirm the tensor layout fed to it matches that intent.
		self.softmax = nn.Softmax2d()

		# Use the in-place initializers; the variants without the trailing
		# underscore are deprecated aliases.
		for m in self.modules():
			if isinstance(m, nn.Linear):
				nn.init.xavier_uniform_(m.weight)
				nn.init.constant_(m.bias, 0.1)

		# Input dimensions are unknown at construction time; presumably filled in
		# later by other methods of the class - TODO confirm.
		self.batch_size = None
		self.channels = None
		self.height = None
		self.width = None
		self.len_emb = None
Example #2
Source File: net10a.py    From IIC with MIT License 6 votes vote down vote up
def __init__(self, config, output_k, cfg):
    """Create ``config.num_sub_heads`` identical 1x1-conv + Softmax2d heads.

    :param config: object providing ``batchnorm_track``, ``num_sub_heads``
        and ``input_sz``
    :param output_k: number of output channels of every head
    :param cfg: sequence whose last entry's first element is used as the
        number of input channels of every head
    """
    super(SegmentationNet10aHead, self).__init__()

    self.batchnorm_track = config.batchnorm_track

    self.cfg = cfg
    num_features = self.cfg[-1][0]

    self.num_sub_heads = config.num_sub_heads

    # `range` replaces the Python-2-only `xrange`; it behaves the same here
    # on Python 2 and also works on Python 3.
    # NOTE(review): padding=1 with a 1x1 kernel enlarges the output by two
    # pixels per spatial dimension - kept as in the original, confirm intended.
    self.heads = nn.ModuleList([nn.Sequential(
      nn.Conv2d(num_features, output_k, kernel_size=1,
                stride=1, dilation=1, padding=1, bias=False),
      nn.Softmax2d()) for _ in range(self.num_sub_heads)])

    self.input_sz = config.input_sz
Example #3
Source File: loss.py    From CLAN with MIT License 6 votes vote down vote up
def forward(self, predict, target):
        """Compute the focal loss between class scores and integer labels.

        :param predict: score tensor of size (N, C, H, W); a softmax over the
            C dimension is applied inside
        :param target: integer label tensor holding N*H*W entries; entries
            that are negative or equal to ``self.ignore_label`` are skipped
        :return: mean of the per-pixel losses when ``self.size_average`` is
            true, otherwise their sum
        """
        N, C, H, W = predict.size()
        sm = nn.Softmax2d()

        P = sm(predict)
        # keep probabilities away from 0 so that log() stays finite
        P = torch.clamp(P, min=1e-9, max=1 - (1e-9))

        target = target.reshape(N, H, W)
        target_mask = (target >= 0) & (target != self.ignore_label)
        target = target[target_mask].view(1, -1)
        # Move the class dimension to the front before masking so that row c
        # holds the class-c probability of every kept pixel, in the same
        # (n, h, w) order as `target`. Masking the (N, C, H, W) tensor and
        # viewing it as (C, -1) is only correct for N == 1.
        predict = P.permute(1, 0, 2, 3)[:, target_mask]
        probs = torch.gather(predict, dim=0, index=target)
        log_p = probs.log()
        batch_loss = -(torch.pow((1 - probs), self.gamma)) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
Example #4
Source File: models.py    From UNet-Zoo with MIT License 5 votes vote down vote up
def __init__(self, num_channels=1, num_classes=2):
        """Create the encoder, decoder and classification layers of the UNet.

        :param num_channels: channel count passed to the first Conv3x3 block
        :param num_classes: output channels of the final 1x1 convolution
        """
        super(UNet, self).__init__()
        w0, w1, w2, w3, w4 = 64, 128, 256, 512, 1024

        def pooled_stage(in_width, out_width):
            # 2x2 max-pool followed by a Conv3x3 block
            return nn.Sequential(nn.MaxPool2d(kernel_size=2),
                                 Conv3x3(in_width, out_width))

        # encoder side
        self.down1 = nn.Sequential(Conv3x3(num_channels, w0))
        self.down2 = pooled_stage(w0, w1)
        self.down3 = pooled_stage(w1, w2)
        self.down4 = pooled_stage(w2, w3)
        self.bottom = pooled_stage(w3, w4)

        # decoder side: every UpConcat is paired with a Conv3x3 of the same widths
        self.up1 = UpConcat(w4, w3)
        self.upconv1 = Conv3x3(w4, w3)

        self.up2 = UpConcat(w3, w2)
        self.upconv2 = Conv3x3(w3, w2)

        self.up3 = UpConcat(w2, w1)
        self.upconv3 = Conv3x3(w2, w1)

        self.up4 = UpConcat(w1, w0)
        self.upconv4 = Conv3x3(w1, w0)

        # 1x1 convolution to num_classes channels, then Softmax2d
        head = nn.Conv2d(w0, num_classes, kernel_size=1)
        self.final = nn.Sequential(head, nn.Softmax2d())
Example #5
Source File: CLSTM.py    From UNet-Zoo with MIT License 5 votes vote down vote up
def __init__(self, input_channels=64, hidden_channels=None,
                 kernel_size=5, bias=True, num_classes=2):
        """Build the forward and reverse CLSTM nets and the classification head.

        :param input_channels: passed to both CLSTM nets
        :param hidden_channels: list passed to both CLSTM nets; its last entry,
            doubled, is the input width of the final convolution. ``None``
            (the default) stands for ``[64]``
        :param kernel_size: passed to both CLSTM nets
        :param bias: passed to both CLSTM nets
        :param num_classes: output channels of the final 1x1 convolution
        """
        super(BDCLSTM, self).__init__()
        # A fresh list per call: a list literal as the default value would be
        # one object shared by every instance built with the default.
        if hidden_channels is None:
            hidden_channels = [64]
        self.forward_net = CLSTM(
            input_channels, hidden_channels, kernel_size, bias)
        self.reverse_net = CLSTM(
            input_channels, hidden_channels, kernel_size, bias)
        self.conv = nn.Conv2d(
            2 * hidden_channels[-1], num_classes, kernel_size=1)
        self.soft = nn.Softmax2d()

    # Forward propagation
    # x --> BatchSize x NumChannels x Height x Width
    #       BatchSize x 64 x 240 x 240
Example #6
Source File: sinet.py    From imgclsmob with MIT License 5 votes vote down vote up
def __init__(self, channels, bn_eps):
        """Set up the upsampling, batch-norm and softmax layers of the block.

        :param channels: number of features given to the batch normalization
        :param bn_eps: epsilon given to the batch normalization
        """
        super(SBDecodeBlock, self).__init__()
        # 2x interpolation without corner alignment
        self.up = InterpolationBlock(scale_factor=2, align_corners=False)
        self.bn = nn.BatchNorm2d(num_features=channels, eps=bn_eps)
        # stored under the name `conf`; presumably used for confidence maps - TODO confirm
        self.conf = nn.Softmax2d()
Example #7
Source File: deeplab_v2.py    From SceneChangeDet with MIT License 5 votes vote down vote up
def __init__(self, norm_flag='l2'):
        """Create the deeplab_V2 backbone and pick the normalization layer.

        :param norm_flag: 'l2' selects ``fun.l2normalization(scale=1)``,
            'exp' selects ``nn.Softmax2d()``
        """
        super(SiameseNet, self).__init__()
        self.CNN = deeplab_V2()
        # NOTE(review): any other flag value leaves `self.norm` unset.
        if norm_flag == 'l2':
            self.norm = fun.l2normalization(scale=1)
        elif norm_flag == 'exp':
            self.norm = nn.Softmax2d()
Example #8
Source File: gcnet.py    From DSMnet with Apache License 2.0 5 votes vote down vote up
def __init__(self, num_F=32):
        """Create the 3D convolution / deconvolution layers l19 to l37.

        :param num_F: base feature width; layers use num_F, 2*num_F and 4*num_F
        """
        super(feature3d, self).__init__()
        self.F = num_F
        f1, f2, f4 = self.F, self.F*2, self.F*4

        # options shared by every layer except the last one (module-level flags)
        shared = dict(kernel_size=3, flag_bias=flag_bias_t, bn=flag_bn, activefun=activefun_t)

        def conv(c_in, c_out, stride):
            return conv3d_bn(c_in, c_out, stride=stride, **shared)

        def deconv(c_in, c_out):
            return deconv3d_bn(c_in, c_out, stride=2, **shared)

        self.l19 = conv(f2, f1, 1)
        self.l20 = conv(f1, f1, 1)

        # three groups of one stride-2 layer followed by two stride-1 layers
        self.l21 = conv(f2, f2, 2)
        self.l22 = conv(f2, f2, 1)
        self.l23 = conv(f2, f2, 1)

        self.l24 = conv(f2, f2, 2)
        self.l25 = conv(f2, f2, 1)
        self.l26 = conv(f2, f2, 1)

        self.l27 = conv(f2, f2, 2)
        self.l28 = conv(f2, f2, 1)
        self.l29 = conv(f2, f2, 1)

        # same pattern, widening to 4*num_F
        self.l30 = conv(f2, f4, 2)
        self.l31 = conv(f4, f4, 1)
        self.l32 = conv(f4, f4, 1)

        # stride-2 deconvolutions back down to num_F channels
        self.l33 = deconv(f4, f2)
        self.l34 = deconv(f2, f2)
        self.l35 = deconv(f2, f2)
        self.l36 = deconv(f2, f1)
        # last layer: single output channel, no batch norm, no activation
        self.l37 = deconv3d_bn(self.F, 1, kernel_size=3, stride=2, bn=False, activefun=None)
        self.softmax = nn.Softmax2d()
#        self.m = nn.Upsample(scale_factor=2, mode='bilinear')
Example #9
Source File: srresnet.py    From 3D_Appearance_SR with MIT License 5 votes vote down vote up
def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1, norm_f=None):
        """Create the mask branch and the convolution body of the block.

        :param conv: factory called as ``conv(n_feats, n_feats, kernel_size)``
        :param n_feats: channel width of the block's input and output
        :param kernel_size: kernel size of every convolution built here
        :param reg_act: not used in this constructor
        :param rescale: not used in this constructor
        :param norm_f: not used in this constructor
        """
        super(JointAttention, self).__init__()
        # mask branch: stride-4 convolution to 16 channels + PReLU ...
        self.mask_conv = nn.Sequential(
            nn.Conv2d(n_feats, 16, kernel_size=kernel_size, stride=4, padding=kernel_size//2),
            nn.PReLU(),
        )
        # ... then a stride-4 transposed convolution back to n_feats channels,
        # followed by Softmax2d
        self.mask_deconv = nn.ConvTranspose2d(16, n_feats, kernel_size=kernel_size, stride=4, padding=1)
        self.mask_deconv_act = nn.Softmax2d()
        # self.ca = CALayer(n_feats)
        self.conv_body = nn.Sequential(conv(n_feats, n_feats, kernel_size), nn.PReLU())
Example #10
Source File: output.py    From pytorch-mono-depth with MIT License 5 votes vote down vote up
def __init__(self, num_gaussians):
        """Record the mixture size and create the softmax layer.

        :param num_gaussians: stored as-is; the channel count is derived from it
        """
        super().__init__()
        # two channels per Gaussian plus one extra channel
        # (presumably two parameter maps per component - TODO confirm)
        channel_count = 2 * num_gaussians + 1
        self.num_gaussians = num_gaussians
        self.num_channels = channel_count
        self.softmax = nn.Softmax2d()
Example #11
Source File: utils.py    From deepsaber with GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, weights=None, num_class=3):
        """Choose the activation and store the per-class weights.

        :param weights: sequence of per-class weights; ``None`` means a weight
            of 1 for each of the ``num_class`` classes
        :param num_class: number of classes; Softmax2d is used when it is
            greater than 1, Sigmoid otherwise
        """
        super(MultiLabelSoftDiceLoss, self).__init__()
        if num_class>1:
            self.sm = nn.Softmax2d()
        else:
            self.sm = nn.Sigmoid()
        # Test for None explicitly: `np.array(weights) or default` raises
        # "truth value of an array ... is ambiguous" for any weights with more
        # than one element and silently discards an all-zero single weight.
        if weights is None:
            weights = [1 for i in range(num_class)]
        weight_tensor = torch.from_numpy(np.array(weights)).type(torch.FloatTensor)
        # stored as a non-trainable parameter
        self.weights = nn.Parameter(weight_tensor, requires_grad=False)
Example #12
Source File: SINet.py    From ext_portrait_segmentation with MIT License 5 votes vote down vote up
def forward(self, input):
        '''
        Run the encoder, then a two-stage decoder in which the second stage is
        gated by one minus the maximum softmax value of the first stage.

        :param input: RGB image
        :return: output of ``self.classifier`` applied to the second decoder stage
        '''
        enc = self.encoder
        feat1 = enc.level1(input)  # 8h 8w (scale notes are the original author's)
        feat2_in = enc.level2_0(feat1)  # 4h 4w

        # chain the level-2 blocks; the first one is fed the stage input
        for idx, block in enumerate(enc.level2):
            feat2 = block(feat2_in if idx == 0 else feat2)  # 2h 2w

        # stage input and stage output are concatenated along dim 1
        feat3_in = enc.level3_0(enc.BR2(torch.cat([feat2_in, feat2], 1)))  # h w

        for idx, block in enumerate(enc.level3):
            feat3 = block(feat3_in if idx == 0 else feat3)

        fused3 = enc.BR3(torch.cat([feat3_in, feat3], 1))
        enc_out = enc.classifier(fused3)  # 1/8

        dec1 = self.bn_3(self.up(enc_out))  # 1/4
        # per-pixel maximum of the softmax output over dim 1
        confidence = nn.Softmax2d()(dec1).max(dim=1)[0]
        b, c, h, w = dec1.size()

        # gate = 1 - confidence, broadcast over all channels
        gate = (1 - confidence).unsqueeze(1).expand(b, c, h, w)

        skip2 = self.level2_C(feat2)  # 2h 2w
        dec2 = self.bn_2(self.up(skip2 * gate + dec1))  # 4h 4w

        return self.classifier(dec2)
Example #13
Source File: oth_sinet.py    From imgclsmob with MIT License 4 votes vote down vote up
def forward(self, input, train=False):
        '''
        Run the encoder, then a three-stage decoder in which each later stage
        is gated by one minus the maximum softmax value of the previous stage.

        :param input: RGB image
        :param train: when true, the encoder classifier output is returned as well
        :return: the 2x bilinearly upsampled classifier output, preceded by the
            encoder classifier output when ``train`` is true
        '''
        enc = self.encoder
        feat1 = enc.level1(input)  # 8h 8w (scale notes are the original author's)

        feat2_in = enc.level2_0(feat1)  # 4h 4w
        feat3_in = enc.level3_0(feat2_in)  # 2h 2w

        # chain the level-3 blocks; the first one is fed the stage input
        for idx, block in enumerate(enc.level3):
            feat3 = block(feat3_in if idx == 0 else feat3)  # 2h 2w

        # stage input and stage output are concatenated along dim 1
        feat4_in = enc.level4_0(enc.BR3(torch.cat([feat3_in, feat3], 1)))  # h w

        for idx, block in enumerate(enc.level4):
            feat4 = block(feat4_in if idx == 0 else feat4)

        fused4 = enc.BR4(torch.cat([feat4_in, feat4], 1))
        Enc_final = enc.classifier(fused4)

        dec1 = self.bn_4(self.up(Enc_final, scale_factor=2, mode="bilinear"))  # 2h 2w
        conf1 = nn.Softmax2d()(dec1)
        b, c, h, w = dec1.size()
        # per-pixel maximum over dim 1, broadcast over all channels
        blocking1 = conf1.max(dim=1)[0].unsqueeze(1).expand(b, c, h, w)

        skip3 = self.level3_C(feat3)  # 2h 2w
        dec2 = self.bn_3(
            self.up(skip3 * (1 - blocking1) + dec1, scale_factor=2, mode="bilinear"))  # 4h 4w

        conf2 = nn.Softmax2d()(dec2)  # 4h 4w
        b, c, h, w = dec2.size()
        blocking2 = conf2.max(dim=1)[0].unsqueeze(1).expand(b, c, h, w)

        dec3 = feat2_in * (1 - blocking2) + dec2

        logits = self.classifier(dec3)

        import torch.nn.functional as F
        logits = F.interpolate(logits, scale_factor=2, mode="bilinear", align_corners=True)

        if train:
            return Enc_final, logits
        return logits
Example #14
Source File: back2future.py    From cc with MIT License 4 votes vote down vote up
def __init__(self, nlevels):
        """Create the index tensors, feature blocks and decoders of the model.

        :param nlevels: stored on the instance; not otherwise used here
        """
        super(Model, self).__init__()

        self.nlevels = nlevels

        # Nine rows of nine indices (row k is 80-k, 71-k, ..., 8-k), flattened
        # row by row; the backward index is the same sequence reversed.
        rows = [list(range(start, -1, -9)) for start in range(80, 71, -1)]
        idx = list(np.array(rows).flatten())
        fwd_index = torch.LongTensor(np.array(idx)).cuda()
        self.idx_fwd = Variable(fwd_index, requires_grad=False)
        bwd_index = torch.LongTensor(np.array(list(reversed(idx)))).cuda()
        self.idx_bwd = Variable(bwd_index, requires_grad=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.softmax2d = nn.Softmax2d()

        # conv1a..conv6c: three feature blocks (a, b, c) per level
        widths = [3, 16, 32, 64, 96, 128, 192]
        for level, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), 1):
            for branch in 'abc':
                setattr(self, 'conv{}{}'.format(level, branch), conv_feat_block(c_in, c_out))

        # the original carried a commented-out alternative:
        # Correlation(pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1)
        self.corr = correlate

        # decoder_fwd6/decoder_bwd6 ... decoder_fwd2/decoder_bwd2
        for level, width in ((6, 162), (5, 292), (4, 260), (3, 228), (2, 196)):
            setattr(self, 'decoder_fwd{}'.format(level), conv_dec_block(width))
            setattr(self, 'decoder_bwd{}'.format(level), conv_dec_block(width))

        # decoder_occ6 ... decoder_occ2
        for level, width in ((6, 354), (5, 292), (4, 260), (3, 228), (2, 196)):
            setattr(self, 'decoder_occ{}'.format(level), conv_dec_block(width))
Example #15
Source File: unet.py    From grouped-ssd-pytorch with MIT License 4 votes vote down vote up
def __init__(self, feature_scale=4, n_classes=21, is_deconv=True, in_channels=3, is_batchnorm=True):
        """
        Create the layers of the network.

        :param feature_scale: divisor applied to the base widths 64..1024
        :param n_classes: output channels of the final 1x1 convolution
        :param is_deconv: flag passed to every unetUp block
        :param in_channels: input channels of the first unetConv2 block
        :param is_batchnorm: flag passed to every unetConv2 block
        """
        super(unet, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        # base widths divided by the feature scale and truncated to int
        f0, f1, f2, f3, f4 = [int(base / self.feature_scale)
                              for base in (64, 128, 256, 512, 1024)]
        use_bn = self.is_batchnorm
        use_deconv = self.is_deconv

        # downsampling: unetConv2 followed by a 2x2 max-pool, four times
        self.conv1 = unetConv2(self.in_channels, f0, use_bn)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2 = unetConv2(f0, f1, use_bn)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3 = unetConv2(f1, f2, use_bn)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4 = unetConv2(f2, f3, use_bn)
        self.maxpool4 = nn.MaxPool2d(kernel_size=2)

        self.center = unetConv2(f3, f4, use_bn)

        # upsampling, widest first
        self.up_concat4 = unetUp(f4, f3, use_deconv)
        self.up_concat3 = unetUp(f3, f2, use_deconv)
        self.up_concat2 = unetUp(f2, f1, use_deconv)
        self.up_concat1 = unetUp(f1, f0, use_deconv)

        # final 1x1 conv (without any concat) and softmax
        self.final = nn.Conv2d(f0, n_classes, 1)
        self.softmax = nn.Softmax2d()
Example #16
Source File: attention_module.py    From Attentive-Filtering-Network with MIT License 4 votes vote down vote up
def __init__(self, in_channels, out_channels, size1=(128,545), size2=(120,529), size3=(104,497), size4=(72,186), l1weight=0.2):
        """Create the trunk branch, the softmax (mask) branch and the output block.

        :param in_channels: channel width used by every layer built here
        :param out_channels: not used in this constructor
        :param size1: target size of the ``up1`` bilinear upsampling
        :param size2: target size of the ``up2`` bilinear upsampling
        :param size3: target size of the ``up3`` bilinear upsampling
        :param size4: target size of the ``up4`` bilinear upsampling
        :param l1weight: stored on the instance
        """
        super(AttentionModule_stg0, self).__init__()
        self.l1weight = l1weight

        def res(arg):
            # ResidualBlock on in_channels with the given second argument
            return ResidualBlock(in_channels, arg)

        def bn_relu_conv():
            # batch norm -> ReLU -> bias-free 1x1 convolution, width unchanged
            return [
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, bias=False),
            ]

        self.pre = res(1)

        ## trunk branch
        self.trunk = nn.Sequential(res(1), res(1))

        ## softmax branch: bottom-up
        self.mp1 = nn.MaxPool2d(kernel_size=3, stride=(1,1))
        self.sm1 = res((4,8))
        self.skip1 = res(1)

        self.mp2 = nn.MaxPool2d(kernel_size=3, stride=(1,1))
        self.sm2 = res((8,16))
        self.skip2 = res(1)

        self.mp3 = nn.MaxPool2d(kernel_size=3, stride=(1,2))
        self.sm3 = res((16,32))
        self.skip3 = res(1)

        self.mp4 = nn.MaxPool2d(kernel_size=3, stride=(2,2))
        self.sm4 = nn.Sequential(res((16,32)), res(1))

        ## softmax branch: top-down
        self.up4 = nn.UpsamplingBilinear2d(size=size4)
        self.sm5 = res(1)
        self.up3 = nn.UpsamplingBilinear2d(size=size3)
        self.sm6 = res(1)
        self.up2 = nn.UpsamplingBilinear2d(size=size2)
        self.sm7 = res(1)
        self.up1 = nn.UpsamplingBilinear2d(size=size1)

        # 1*1 convolution blocks, closed by Softmax2d (nn.Sigmoid() was the
        # commented-out alternative in the original)
        mask_layers = bn_relu_conv() + bn_relu_conv()
        mask_layers.append(nn.Softmax2d())
        self.conv1 = nn.Sequential(*mask_layers)

        self.post = res(1)
Example #17
Source File: fcn32s_tiny.py    From SceneChangeDet with MIT License 4 votes vote down vote up
def __init__(self, distance_flag):
        """Create the five convolution stages, the embedding layer and ``fc8``.

        :param distance_flag: 'softmax' selects ``nn.Softmax2d()`` for ``fc8``,
            'l2' selects ``fun.l2normalization(scale=1)``
        """
        super(fcn32s, self).__init__()

        def stage(widths, pool_stride, dilation=1):
            # One 3x3 conv + ReLU per consecutive pair of widths, closed by a
            # 3x3 ceil-mode max-pool; padding equals the dilation.
            layers = []
            for c_in, c_out in zip(widths[:-1], widths[1:]):
                layers.append(nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                                        dilation=dilation, padding=dilation))
                layers.append(nn.ReLU(inplace=True))
            layers.append(nn.MaxPool2d(kernel_size=3, stride=pool_stride, padding=1, ceil_mode=True))
            return nn.Sequential(*layers)

        # stages 1-3 pool with stride 2, stages 4-5 with stride 1
        self.conv1 = stage((3, 64, 64), pool_stride=2)
        self.conv2 = stage((64, 128, 128), pool_stride=2)
        self.conv3 = stage((128, 256, 256, 256), pool_stride=2)
        self.conv4 = stage((256, 512, 512, 512), pool_stride=1)
        # stage 5 uses dilated convolutions
        self.conv5 = stage((512, 512, 512, 512), pool_stride=1, dilation=2)

        self.embedding_layer = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, dilation=2, padding=2)

        # NOTE(review): any other flag value leaves `self.fc8` unset.
        if distance_flag == 'softmax':
            self.fc8 = nn.Softmax2d()
        elif distance_flag == 'l2':
            self.fc8 = fun.l2normalization(scale=1)
Example #18
Source File: model.py    From Semi-supervised-segmentation-cycleGAN with MIT License 4 votes vote down vote up
def __init__(self, args):
        """Set up the segmentation generator, its loss, optimizer and logging.

        :param args: options object; this constructor reads ``dataset``,
            ``ngf``, ``norm``, ``no_dropout``, ``gpu_ids``, ``crop_height``,
            ``crop_width``, ``lr`` and ``checkpoint_dir``
        """

        # number of output channels of the generator, per dataset
        # NOTE(review): any other dataset name leaves `n_channels` unset here.
        if args.dataset == 'voc2012':
            self.n_channels = 21
        elif args.dataset == 'cityscapes':
            self.n_channels = 20
        elif args.dataset == 'acdc':
            self.n_channels = 4

        # Define the network
        self.Gsi = define_Gen(input_nc=3, output_nc=self.n_channels, ngf=args.ngf, netG='deeplab', norm=args.norm,
                              use_dropout=not args.no_dropout, gpu_ids=args.gpu_ids)  # for image to segmentation

        ### Now we put in the pretrained weights in Gsi
        ### These will only be used in the case of VOC and cityscapes
        if args.dataset != 'acdc':
            saved_state_dict = torch.load(pretrained_loc)
            new_params = self.Gsi.state_dict().copy()
            # Copy every pretrained tensor whose name and size match, in place.
            # NOTE(review): the explicit load_state_dict call below is commented
            # out; presumably the in-place copy_ already updates the live
            # parameters through the state dict - TODO confirm.
            for name, param in new_params.items():
                # print(name)
                if name in saved_state_dict and param.size() == saved_state_dict[name].size():
                    new_params[name].copy_(saved_state_dict[name])
                    # print('copy {}'.format(name))
            # self.Gsi.load_state_dict(new_params)

        utils.print_networks([self.Gsi], ['Gsi'])

        ###Defining an interpolation function so as to match the output of network to feature map size
        self.interp = nn.Upsample(size = (args.crop_height, args.crop_width), mode='bilinear', align_corners=True)
        self.interp_val = nn.Upsample(size = (512, 512), mode='bilinear', align_corners=True)

        self.CE = nn.CrossEntropyLoss()
        self.activation_softmax = nn.Softmax2d()
        self.gsi_optimizer = torch.optim.Adam(self.Gsi.parameters(), lr=args.lr, betas=(0.9, 0.999))

        ### writer for tensorboard
        self.writer_supervised = SummaryWriter(tensorboard_loc + '_supervised')
        self.running_metrics_val = utils.runningScore(self.n_channels, args.dataset)

        self.args = args

        if not os.path.isdir(args.checkpoint_dir):
            os.makedirs(args.checkpoint_dir)

        # Resume from the latest checkpoint when it can be loaded, otherwise
        # start from scratch. `except Exception` keeps that best-effort
        # behavior but, unlike a bare `except:`, lets KeyboardInterrupt and
        # SystemExit propagate.
        try:
            ckpt = utils.load_checkpoint('%s/latest_supervised_model.ckpt' % (args.checkpoint_dir))
            self.start_epoch = ckpt['epoch']
            self.Gsi.load_state_dict(ckpt['Gsi'])
            self.gsi_optimizer.load_state_dict(ckpt['gsi_optimizer'])
            self.best_iou = ckpt['best_iou']
        except Exception:
            print(' [*] No checkpoint!')
            self.start_epoch = 0
            self.best_iou = -100