# Two types: # 1. "Hard" gated : use 1/0 to update mask. Image Inpainting for Irregular Holes Using Partial Convolutions # 2. "Soft" gated : use sigmoid to update both feature & mask Free-Form Image Inpainting with Gated Convolution import torch from torch import nn from torch.nn import functional as F from torch.nn.functional import avg_pool2d from .BaseModels import BaseModule try: from .inplace_abn import InPlaceABN # only works in GPU inplace_batch_norm = True except ImportError: inplace_batch_norm = False class PartialConv(BaseModule): # reference: # Image Inpainting for Irregular Holes Using Partial Convolutions # http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10 # https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py # https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py # mask is binary, 0 is holes; 1 is not def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, same_holes=False): # same holes: holes are in the same position in all layers. used in the encoder part super(PartialConv, self).__init__() self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) nn.init.kaiming_normal_(self.feature_conv.weight) self.same_holes = same_holes mask_in_channel = 1 if same_holes else in_channels mask_out_channel = 1 if same_holes else out_channels mask_groups = 1 if same_holes else groups self.mask_conv = nn.Conv2d(mask_in_channel, mask_out_channel, kernel_size, stride, padding, dilation, mask_groups, bias=False) torch.nn.init.constant_(self.mask_conv.weight, 1.0) # torch.nn.init.constant_(self.mask_conv.bias, 0.0) for param in self.mask_conv.parameters(): param.requires_grad = False def forward(self, args): x, mask = args output = self.feature_conv(x * mask) if self.feature_conv.bias is not None: output_bias = self.feature_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output) with torch.no_grad(): if self.same_holes: output_mask = self.mask_conv(mask[:, :1]) # mask sums no_update_holes = output_mask == 0 output_mask *= self.feature_conv.in_channels else: output_mask = self.mask_conv(mask) # mask sums no_update_holes = output_mask == 0 mask_sum = output_mask.masked_fill_(no_update_holes, 1.0) # See 2nd reference, but takes more time to run # scale = torch.div(ones, mask_sum) output_pre = (output - output_bias) / mask_sum + output_bias output = output_pre.masked_fill_(no_update_holes, 0.0) new_mask = torch.ones_like(output_mask) new_mask = new_mask.masked_fill_(no_update_holes, 0.0) if self.same_holes: new_mask = new_mask.expand_as(output) # output = output_pre * new_mask return output, new_mask class PartialConv1x1(BaseModule): """ Optimization for encoder : if the input mask have holes in the same positions across channels, then 1x1 partial convolution is equivalent to a standard 1x1 convolution because holes are not updated. By assert checking, encoder and feature pooling are eligible, but decoder needs to concatenate encoder's mask, so it fails. 
""" def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bias=True): super(PartialConv1x1, self).__init__() assert kernel_size == 1 and stride == 1 and padding == 0 self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) nn.init.kaiming_normal_(self.feature_conv.weight) def forward(self, args): x, mask = args out_x = self.feature_conv(x) out_m = mask[:, :1, :, :].expand_as(out_x) return out_x, out_m class PartialConvNoHoles(PartialConv): """ Optimization for encoder : Used for the decoder part. After successive partial convolution, the decoder should have no holes in the masks. The u-net structure links the encoder mask with decoder mask, so 1x1 convolution will fill all holes. """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(PartialConvNoHoles, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) assert self.feature_conv.groups == 1 def forward(self, args): x, mask = args output = self.feature_conv(x * mask) if self.feature_conv.bias is not None: output_bias = self.feature_conv.bias.view(1, -1, 1, 1).expand_as(output) else: output_bias = torch.zeros_like(output) with torch.no_grad(): output_mask = self.mask_conv(mask) mask_sum = output_mask output = (output - output_bias) / mask_sum + output_bias new_mask = torch.ones_like(output) return output, new_mask class SoftPartialConv(BaseModule): def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, ): super(SoftPartialConv, self).__init__() self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias) self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) def forward(self, args): x, mask = args output = self.feature_conv(x) mask_output = self.mask_conv(1 - mask) # holes are 1; else 0 mask_attention = F.tanh(mask_output) # non-holes positions are 0 output = output + mask_attention * output valid_idx = mask_attention == 0 new_mask = torch.where(valid_idx, torch.ones_like(output), F.sigmoid(mask_output)) return output, new_mask def partial_convolution_block(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, BN=True, activation=True, use_1_conv=False, no_holes_1_conv=False, same_holes=False): if use_1_conv: m = [PartialConv1x1(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)] elif no_holes_1_conv: m = [PartialConvNoHoles(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)] else: m = [PartialConv(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, same_holes)] if BN: m += [PartialActivatedBN(out_channels, activation)] if not BN and activation: m += [PartialActivation(activation)] return nn.Sequential(*m) class PartialActivatedBN(BaseModule): def __init__(self, channel, act_fn): super(PartialActivatedBN, self).__init__() if inplace_batch_norm: if act_fn: self.bn_act = InPlaceABN(channel, activation="leaky_relu", slope=0.3) else: self.bn_act = InPlaceABN(channel, activation='none') else: if act_fn: self.bn_act = nn.Sequential(nn.BatchNorm2d(channel), act_fn) else: self.bn_act = nn.Sequential(nn.BatchNorm2d(channel)) def forward(self, args): x, mask = args return self.bn_act(x), mask class PartialActivation(BaseModule): def __init__(self, 
class PartialActivation(BaseModule):
    def __init__(self, activation):
        super(PartialActivation, self).__init__()
        self.act_fn = activation

    def forward(self, args):
        x, mask = args
        return self.act_fn(x), mask


class DoubleAvdPool(nn.AvgPool2d):
    def __init__(self, kernel_size):
        super(DoubleAvdPool, self).__init__(kernel_size=kernel_size)
        self.kernel_size = kernel_size

    def forward(self, args):
        # pool the feature map and the mask with the same kernel
        return tuple(map(lambda x: avg_pool2d(x, kernel_size=self.kernel_size), args))


class DoubleUpSample(nn.Module):
    def __init__(self, scale_factor, mode='nearest'):
        super(DoubleUpSample, self).__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode=mode)

    def forward(self, args):
        x, mask = args
        return self.upsample(x), self.upsample(mask)


class PartialGatedConv(BaseModule):
    # mask is binary: 0 is a masked point, 1 is not
    # https://github.com/JiahuiYu/generative_inpainting/issues/62
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, BN=False,
                 activation=nn.SELU()):
        super(PartialGatedConv, self).__init__()
        self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups, bias)
        if BN:
            self.bn_act = nn.Sequential(nn.BatchNorm2d(out_channels), activation)
        else:
            self.bn_act = activation

    def forward(self, x):
        output = self.feature_conv(x)
        mask = self.mask_conv(x)
        return self.bn_act(output * torch.sigmoid(mask))


class PartialGatedActivatedBN(BaseModule):
    def __init__(self, channel, activation):
        super(PartialGatedActivatedBN, self).__init__()
        self.bn_act = nn.Sequential(nn.BatchNorm2d(channel), activation)

    def forward(self, x):
        return self.bn_act(x)


def partial_gated_conv_block(in_channels, out_channels, kernel_size, stride=1,
                             padding=0, dilation=1, groups=1, bias=True,
                             BN=True, activation=nn.SELU()):
    m = [PartialGatedConv(in_channels, out_channels, kernel_size, stride, padding,
                          dilation, groups, bias, BN, activation)]
    return nn.Sequential(*m)
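
# Hedged sketch for the "soft" gated path (shapes and input layout are
# illustrative assumptions, not the original training config). Unlike the
# Partial* layers above, PartialGatedConv takes a plain tensor rather than an
# (x, mask) tuple: the gate is learned from the input itself, so one common
# convention is to feed the masked image with the hole mask appended as an
# extra input channel.
def _demo_gated_conv():
    block = partial_gated_conv_block(4, 16, kernel_size=3, padding=1,
                                     BN=True, activation=nn.SELU())
    img = torch.randn(1, 3, 64, 64)
    hole_mask = (torch.rand(1, 1, 64, 64) > 0.5).float()  # 1 = valid, 0 = hole
    out = block(torch.cat([img * hole_mask, hole_mask], dim=1))
    assert out.shape == (1, 16, 64, 64)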