import torch
import torchvision
import torch.nn as nn
from models.ResBlock import SemiResnetBlock_bn, ResnetBlock_bn, SemiResnetBlock, ResnetBlock
# import mmdet.ops.dcn.deform_conv as dc# DeformConv, ModulatedDeformConv
from mmdet.ops.dcn.deform_conv import DeformConv, ModulatedDeformConv, _pair, modulated_deform_conv
import torch.nn.functional as F


class DCN_sep(ModulatedDeformConv):
    """Modulated deformable conv whose offsets/masks come from a separate map.

    Unlike a self-contained DCN layer, ``forward`` takes the tensor to be
    sampled (``x``) and a second tensor (``fea``) from which the sampling
    offsets and modulation masks are predicted.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``ModulatedDeformConv`` and add the
        offset/mask predictor convolution."""
        super(DCN_sep, self).__init__(*args, **kwargs)

        # 3 values per kernel tap and deformable group; ``forward`` splits
        # them into two offset chunks and one mask chunk.
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
            kernel_size=self.kernel_size,
            stride=_pair(self.stride),
            padding=_pair(self.padding),
            bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the predictor so it initially outputs all-zero offsets/masks."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, x, fea):
        """Sample ``x`` with offsets and masks predicted from ``fea``.

        :param x: tensor passed to ``modulated_deform_conv`` as its input
        :param fea: tensor fed to the offset/mask predictor convolution
        :return: result of ``modulated_deform_conv``
        """
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)

        # Diagnostic only: report suspiciously large offsets, do not clamp.
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 100:
            print('Offset mean is {}, larger than 100.'.format(offset_mean))

        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups)


class ReVggBlock(nn.Module):
    def __init__(self, inchannel, outchannel, upsampling=False, end=False):
        """
        Reverse Vgg19_bn block: replication pad + 3x3 conv, optionally
        followed by 2x bilinear upsampling and an activation/BatchNorm pair.

        :param inchannel: input channel
        :param outchannel: output channel
        :param upsampling: if True, append a 2x bilinear upsampling layer
        :param end: if True, omit the trailing LeakyReLU + BatchNorm2d
        """
        super(ReVggBlock, self).__init__()
        model = [nn.ReplicationPad2d(1), nn.Conv2d(inchannel, outchannel, 3)]
        if upsampling:
            model.append(nn.UpsamplingBilinear2d(scale_factor=2))
        if not end:
            # NOTE(review): the first positional argument of nn.LeakyReLU is
            # ``negative_slope``, so LeakyReLU(True) has slope 1.0 (identity),
            # not ``inplace=True``. Left unchanged so the numerics of existing
            # checkpoints are preserved -- confirm intent before changing.
            model += [nn.LeakyReLU(True), nn.BatchNorm2d(outchannel)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.model(x)


class ReVgg19(nn.Module):
    def __init__(self):
        """
        A reverse operation network for vgg19_bn
        """
        super(ReVgg19, self).__init__()
        self.dcn0 = DeformableConv(3)
        self.dcn1 = DeformableConv(64, 2)
        self.dcn2 = DeformableConv(128, 2)
        self.dcn3 = DeformableConv(256, 2)
        self.dcn4 = DeformableConv(512, 2)

        self.model_4 = nn.Sequential(ReVggBlock(512, 512),
                                     ReVggBlock(512, 512),
                                     ReVggBlock(512, 256, True, True))
        self.model_3 = nn.Sequential(ReVggBlock(256, 256),
                                     ReVggBlock(256, 256),
                                     ReVggBlock(256, 128, True, True))
        self.model_2 = nn.Sequential(ReVggBlock(128, 128),
                                     ReVggBlock(128, 64, True, True))

        model_1 = [ReVggBlock(64, 64),
                   nn.ReplicationPad2d(1),
                   nn.Conv2d(64, 64, 3)]
        self.model_1 = nn.Sequential(*model_1)

        self.model_0 = nn.Sequential(SemiResnetBlock_bn(64, 64),
                                     ResnetBlock_bn(64),
                                     ResnetBlock_bn(64),
                                     SemiResnetBlock_bn(64, 3, end=True),
                                     nn.Tanh())
        # dcn0 and offset_tran are only referenced by the commented-out lines
        # in forward().
        self.offset_tran = nn.Sequential(ReVggBlock(128, 128),
                                         ReVggBlock(128, 12, end=True))

    def forward(self, ft_img0, ft_img1):
        """Fuse two feature pyramids coarse-to-fine and decode an image.

        NOTE(review): ``DeformableConv.forward`` as defined in this file takes
        ``(x, y, last_offset=None, up=True)`` and returns three values, while
        the calls below pass ``last_up_out=`` and unpack four. As written this
        method raises at the first call; the two need to be reconciled.

        :param ft_img0: indexable of feature maps; indices 1..4 are read
        :param ft_img1: indexable of feature maps; indices 1..4 are read
        :return: ``(imgt, [0, ft1, ft2, ft3, ft4], last, test)``
        """
        ft4, offset, _, _ = self.dcn4(ft_img0[4], ft_img1[4])
        out4 = self.model_4(ft4)

        ft3, offset, _, _ = self.dcn3(ft_img0[3], ft_img1[3],
                                      last_offset=offset, last_up_out=out4)
        out3 = self.model_3(ft3)

        ft2, offset, _, _ = self.dcn2(ft_img0[2], ft_img1[2],
                                      last_offset=offset, last_up_out=out3)
        out2 = self.model_2(ft2)

        ft1, _, last, test = self.dcn1(ft_img0[1], ft_img1[1],
                                       last_offset=offset, last_up_out=out2)
        out1 = self.model_1(ft1)

        # offset = self.offset_tran(offset)
        # ft0, offset = self.dcn0(torch.cat([ft_img0[0], ft_img1[0]], dim=1), last_offset=offset, last_up_out=out1, up=False)
        imgt = self.model_0(out1)
        return imgt, [0, ft1, ft2, ft3, ft4], last, test


class ExtractFeatures(nn.Module):
    """ResNet-50 feature extractor with a 4-channel replacement stem conv."""

    def __init__(self):
        super(ExtractFeatures, self).__init__()
        self.net = torchvision.models.resnet50(pretrained=True)
        # Fresh stem taking 4 input channels; used in place of self.net.conv1.
        self.conv1_4in = nn.Conv2d(4, 64, kernel_size=(7, 7), stride=(2, 2),
                                   padding=(3, 3), bias=False)
        # xavier_uniform (no trailing underscore) is deprecated; the in-place
        # variant performs the same initialisation.
        nn.init.xavier_uniform_(self.conv1_4in.weight)

    def forward(self, x):
        """Return ``(0, 0, ft1, ft2, ft3)`` for input ``x``.

        ``ft1``..``ft3`` are the outputs of ResNet ``layer1``..``layer3``.
        The two leading zeros are presumably placeholders so that indices
        line up with a five-level pyramid elsewhere -- TODO confirm.
        """
        output = self.conv1_4in(x)
        output = self.net.bn1(output)
        output = self.net.relu(output)
        output = self.net.maxpool(output)
        ft1 = self.net.layer1(output)
        ft2 = self.net.layer2(ft1)
        ft3 = self.net.layer3(ft2)
        return 0, 0, ft1, ft2, ft3


class ValidationFeatures(nn.Module):
    """Frozen VGG16-BN feature extractor split into four consecutive stages."""

    def __init__(self):
        super(ValidationFeatures, self).__init__()
        vgg = torchvision.models.vgg16_bn(pretrained=True)
        layers = list(vgg.features.children())

        self.extract_feature1 = nn.Sequential(*layers[:4])
        self.extract_feature2 = nn.Sequential(*layers[4:11])
        self.extract_feature3 = nn.Sequential(*layers[11:21])
        self.extract_feature4 = nn.Sequential(*layers[21:31])

        # None of the VGG weights are trained.
        for stage in (self.extract_feature1, self.extract_feature2,
                      self.extract_feature3, self.extract_feature4):
            for param in stage.parameters(True):
                param.requires_grad = False

    def forward(self, x):
        """Return the outputs of the four stages, shallowest first."""
        ft_1 = self.extract_feature1(x)
        ft_2 = self.extract_feature2(ft_1)
        ft_3 = self.extract_feature3(ft_2)
        ft_4 = self.extract_feature4(ft_3)
        return ft_1, ft_2, ft_3, ft_4


class StructureExtractor(nn.Module):
    """First four VGG16 feature layers (frozen) followed by 2x poolings."""

    def __init__(self):
        super(StructureExtractor, self).__init__()
        vgg = torchvision.models.vgg16(pretrained=True)
        self.extract_feature1 = nn.Sequential(*list(vgg.features.children())[:4])
        for param in self.extract_feature1.parameters(True):
            param.requires_grad = False
        self.ap = nn.AvgPool2d(kernel_size=2, stride=2)
        self.mp = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(ap_edge1, ap_edge2, mp_edge1, mp_edge2)``.

        The VGG features are average-pooled once and twice, and max-pooled
        once and twice.
        """
        edge = self.extract_feature1(x)
        ap_edge1 = self.ap(edge)
        ap_edge2 = self.ap(ap_edge1)
        mp_edge1 = self.mp(edge)
        mp_edge2 = self.mp(mp_edge1)
        return ap_edge1, ap_edge2, mp_edge1, mp_edge2


class DeformableConv(nn.Module):
    """Deformably samples two feature maps and blends them with learned maps."""

    def __init__(self, inchannel, dg=2):
        """
        :param inchannel: channel count of each of the two inputs
        :param dg: number of deformable groups
        """
        super(DeformableConv, self).__init__()
        self.dg = dg

        # NOTE(review): nn.LeakyReLU(True) sets negative_slope=1.0 (identity),
        # not inplace=True. Left unchanged to preserve existing numerics.
        self.offset_cnn1 = nn.Sequential(
            nn.Conv2d(2 * inchannel, inchannel, 3, padding=1),
            nn.BatchNorm2d(inchannel),
            nn.LeakyReLU(True))
        self.offset_cnn2 = nn.Sequential(
            nn.Conv2d(2 * 2 * 2 * dg * 9, 2 * 2 * dg * 9, 3, padding=1),
            nn.BatchNorm2d(2 * 2 * dg * 9),
            nn.LeakyReLU(True))

        # NOTE(review): ``[module] * n`` repeats the *same* module instance,
        # so the repeated blocks below share one set of weights. Confirm this
        # is intended before replacing it with independently built blocks.
        self.offset_cnn3 = nn.Sequential(
            *([ResnetBlock_bn(inchannel)] * 5
              + [ResnetBlock(inchannel)] * 3
              + [nn.Conv2d(inchannel, 2 * 2 * dg * 9, 3, padding=1)]))
        self.emb = nn.Conv2d(inchannel, inchannel, 3, padding=1)
        self.mix_map = nn.Sequential(
            nn.Conv2d(2 * inchannel, inchannel, 3, padding=1),
            nn.LeakyReLU(True),
            *([ResnetBlock(inchannel)] * 3),
            nn.Conv2d(inchannel, 2 * dg, 3, padding=1))
        self.dcn = DeformConv(inchannel, inchannel, 3, padding=1,
                              deformable_groups=dg)
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, y, last_offset=None, up=True):
        """Sample ``x`` and ``y`` with predicted offsets and blend the results.

        :param x: first feature map
        :param y: second feature map
        :param last_offset: optional offsets from a previous level
        :param up: if True, upsample ``last_offset`` by 2 before use
        :return: ``(out, out_x, out_y)`` -- the blend and the two sampled maps
        """
        if last_offset is not None:
            if up:
                last_offset = self.up(last_offset)
            offset = self.offset_cnn1(torch.cat([x, y], dim=1))
            offset = self.offset_cnn2(torch.cat([offset, last_offset * 2], dim=1))
            # NOTE(review): the result of offset_cnn3 is not used on this
            # branch (``offset`` stays the offset_cnn2 output), and the
            # channel counts declared in __init__ only line up when
            # inchannel == 2*2*dg*9. Call kept as in the original; confirm
            # the intended refinement path.
            offset_com = self.offset_cnn3(offset)
        else:
            offset = self.offset_cnn1(torch.cat([x, y], dim=1))
            offset = self.offset_cnn3(offset)

        # One half of the offset channels drives each input.
        offset_x, offset_y = torch.chunk(offset, 2, dim=1)
        out_x = self.dcn(x, offset_x)
        out_y = self.dcn(y, offset_y)

        embedded = torch.cat([self.emb(out_x), self.emb(out_y)], dim=1)
        vmap_x, vmap_y = torch.chunk(torch.sigmoid(self.mix_map(embedded)), 2, dim=1)

        # Blend group by group: one weight map per deformable group and input.
        vmap_x = torch.chunk(vmap_x, self.dg, dim=1)
        vmap_y = torch.chunk(vmap_y, self.dg, dim=1)
        out_x_d = torch.chunk(out_x, self.dg, dim=1)
        out_y_d = torch.chunk(out_y, self.dg, dim=1)
        out = [vmap_x[i] * out_x_d[i] + vmap_y[i] * out_y_d[i]
               for i in range(self.dg)]
        out = torch.cat(out, dim=1)
        return out, out_x, out_y


class ExtractAlignedFeatures(nn.Module):
    """
    Extract features
    """

    def __init__(self, nf=64, n_res=5):
        """
        :param nf: number of feature channels
        :param n_res: number of entries in the level-1 residual stack
        """
        super(ExtractAlignedFeatures, self).__init__()
        self.deblur = Predeblur_ResNet_Pyramid(nf=nf)
        # NOTE(review): ``[ResnetBlock(nf)] * n_res`` repeats one module
        # instance, so all n_res stages share weights -- confirm intent.
        self.fea_L1_conv = nn.Sequential(*([ResnetBlock(nf)] * n_res))
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Return ``[ft_L1, ft_L2, ft_L3]``.

        L2 and L3 are each produced by a stride-2 conv followed by a
        stride-1 conv and a LeakyReLU, applied to the previous level.
        """
        ft_L1 = self.fea_L1_conv(self.deblur(x))
        ft_L2 = self.lrelu(self.fea_L2_conv2(self.fea_L2_conv1(ft_L1)))
        ft_L3 = self.lrelu(self.fea_L3_conv2(self.fea_L3_conv1(ft_L2)))
        return [ft_L1, ft_L2, ft_L3]


class PCD_Align(nn.Module):
    """ Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.
    """

    def __init__(self, nf=64, groups=8):
        """
        :param nf: number of feature channels
        :param groups: deformable groups passed to every DCN_sep layer
        """
        super(PCD_Align, self).__init__()
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L3_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack = DCN_sep(nf, nf, 3, stride=1, padding=1, dilation=1,
                                  deformable_groups=groups)
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L2_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L2_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack = DCN_sep(nf, nf, 3, stride=1, padding=1, dilation=1,
                                  deformable_groups=groups)
        self.L2_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.L1_offset_conv2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for offset
        self.L1_offset_conv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack = DCN_sep(nf, nf, 3, stride=1, padding=1, dilation=1,
                                  deformable_groups=groups)
        self.L1_fea_conv = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # Cascading DCN
        self.cas_offset_conv1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN_sep(nf, nf, 3, stride=1, padding=1, dilation=1,
                                   deformable_groups=groups)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nbr_fea_l, ref_fea_l):
        """align other neighboring frames to the reference frame in the feature level

        nbr_fea_l, ref_fea_l: [L1, L2, L3], each with [B,C,H,W] features

        :return: the aligned level-1 feature map
        """
        # L3
        L3_offset = torch.cat([nbr_fea_l[2], ref_fea_l[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack(nbr_fea_l[2], L3_offset))
        # L2
        L2_offset = torch.cat([nbr_fea_l[1], ref_fea_l[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1(L2_offset))
        L3_offset = F.interpolate(L3_offset, scale_factor=2, mode='bilinear',
                                  align_corners=False)
        # The coarser offset features are doubled after the 2x upsampling.
        L2_offset = self.lrelu(
            self.L2_offset_conv2(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3(L2_offset))
        L2_fea = self.L2_dcnpack(nbr_fea_l[1], L2_offset)
        L3_fea = F.interpolate(L3_fea, scale_factor=2, mode='bilinear',
                               align_corners=False)
        L2_fea = self.lrelu(self.L2_fea_conv(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L1_offset = torch.cat([nbr_fea_l[0], ref_fea_l[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1(L1_offset))
        L2_offset = F.interpolate(L2_offset, scale_factor=2, mode='bilinear',
                                  align_corners=False)
        L1_offset = self.lrelu(
            self.L1_offset_conv2(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3(L1_offset))
        L1_fea = self.L1_dcnpack(nbr_fea_l[0], L1_offset)
        L2_fea = F.interpolate(L2_fea, scale_factor=2, mode='bilinear',
                               align_corners=False)
        L1_fea = self.L1_fea_conv(torch.cat([L1_fea, L2_fea], dim=1))
        # Cascading
        offset = torch.cat([L1_fea, ref_fea_l[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack(L1_fea, offset))

        # Denoise
        # L1_fea = self.lrelu(self.cas_offset_conv2(L1_fea))
        # L1_fea = self.cas_offset_conv2(L1_fea)

        return L1_fea


class TSA_Fusion(nn.Module):
    """ Temporal Spatial Attention fusion module
    Temporal: correlation;
    Spatial: 3 pyramid levels.
    """

    def __init__(self, nf=64, nframes=5, center=2):
        """
        :param nf: number of feature channels per frame
        :param nframes: number of frames being fused
        :param center: index of the reference frame along the frame axis
        """
        super(TSA_Fusion, self).__init__()
        self.center = center
        # temporal attention (before fusion conv)
        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        # fusion conv: using 1x1 to save parameters and computation
        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)

        # spatial attention (after fusion conv)
        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=True)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=True)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=True)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, aligned_fea):
        """Fuse per-frame aligned features into a single feature map.

        :param aligned_fea: 5-D tensor unpacked as ``B, N, C, H, W``
        :return: fused feature map
        """
        B, N, C, H, W = aligned_fea.size()  # N video frames
        #### temporal attention
        emb_ref = self.tAtt_2(aligned_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(aligned_fea.view(-1, C, H, W)).view(B, N, -1, H, W)  # [B, N, C(nf), H, W]

        # Per-frame correlation with the reference embedding, summed over
        # channels.
        cor_l = [
            torch.sum(emb[:, i, :, :, :] * emb_ref, 1).unsqueeze(1)  # B, 1, H, W
            for i in range(N)
        ]
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # B, N, H, W
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1).view(B, -1, H, W)
        aligned_fea = aligned_fea.view(B, -1, H, W) * cor_prob

        #### fusion
        fea = self.lrelu(self.fea_fusion(aligned_fea))

        #### spatial attention
        att = self.lrelu(self.sAtt_1(aligned_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))
        # pyramid levels
        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear',
                              align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, scale_factor=2, mode='bilinear',
                            align_corners=False)
        att = self.sAtt_5(att)
        # The additive term is computed from the pre-sigmoid attention.
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        fea = fea * att * 2 + att_add
        return fea


class Reconstruct(nn.Module):
    """Residual trunk plus output convs; the result is added to ``x_center``."""

    def __init__(self, nf=64, n_res=10):
        """
        :param nf: number of feature channels entering the trunk
        :param n_res: number of entries in the residual trunk
        """
        super(Reconstruct, self).__init__()

        #### reconstruction
        # NOTE(review): ``[ResnetBlock(nf)] * n_res`` repeats one module
        # instance, so all n_res stages share weights -- confirm intent.
        self.recon_trunk = nn.Sequential(*([ResnetBlock(nf)] * n_res))
        #### upsampling
        # upconv1, upconv2, pixel_shuffle, down and tanh are not used by
        # forward(); they are kept so the registered modules are unchanged.
        self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, 64 * 4, 3, 1, 1, bias=True)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.down = nn.AvgPool2d(2, 2)
        self.HRconv = nn.Conv2d(64, 64, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(64, 3, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.tanh = nn.Tanh()

    def forward(self, fea, x_center):
        """Decode ``fea`` to 3 channels and add ``x_center`` in place.

        :param fea: fused feature map
        :param x_center: tensor added to the decoded output
        :return: decoded output plus ``x_center``
        """
        out = self.recon_trunk(fea)
        out = self.lrelu(self.HRconv(out))
        out = self.conv_last(out)
        out += x_center
        return out


class Predeblur_ResNet_Pyramid(nn.Module):
    def __init__(self, nf=128, HR_in=False):
        """
        HR_in: True if the inputs are high spatial size

        :param nf: number of feature channels
        """
        super(Predeblur_ResNet_Pyramid, self).__init__()
        self.HR_in = bool(HR_in)
        if self.HR_in:
            # One stride-1 conv followed by two stride-2 convs.
            self.conv_first_1 = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
            self.conv_first_2 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
            self.conv_first_3 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        else:
            self.conv_first = nn.Conv2d(3, nf, 3, 1, 1, bias=True)
        self.RB_L1_1 = ResnetBlock(nf)
        self.RB_L1_2 = ResnetBlock(nf)
        self.RB_L1_3 = ResnetBlock(nf)
        self.RB_L1_4 = ResnetBlock(nf)
        self.RB_L1_5 = ResnetBlock(nf)
        self.RB_L2_1 = ResnetBlock(nf)
        self.RB_L2_2 = ResnetBlock(nf)
        self.RB_L3_1 = ResnetBlock(nf)
        self.deblur_L2_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.deblur_L3_conv = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Run the three-level pyramid and return the level-1 features.

        Coarser levels are processed first, upsampled by 2 and added to the
        next finer level.
        """
        if self.HR_in:
            L1_fea = self.lrelu(self.conv_first_1(x))
            L1_fea = self.lrelu(self.conv_first_2(L1_fea))
            L1_fea = self.lrelu(self.conv_first_3(L1_fea))
        else:
            L1_fea = self.lrelu(self.conv_first(x))
        L2_fea = self.lrelu(self.deblur_L2_conv(L1_fea))
        L3_fea = self.lrelu(self.deblur_L3_conv(L2_fea))
        L3_fea = F.interpolate(self.RB_L3_1(L3_fea), scale_factor=2,
                               mode='bilinear', align_corners=False)
        L2_fea = self.RB_L2_1(L2_fea) + L3_fea
        L2_fea = F.interpolate(self.RB_L2_2(L2_fea), scale_factor=2,
                               mode='bilinear', align_corners=False)
        L1_fea = self.RB_L1_2(self.RB_L1_1(L1_fea)) + L2_fea
        out = self.RB_L1_5(self.RB_L1_4(self.RB_L1_3(L1_fea)))
        return out