# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp

from fairseq.modules import (
    MaskedConvolution
)

# FIXME: non-zero padding is not effective


class DenseNetCascade(nn.Module):
    """Single-block DenseNet with checkpointing."""

    def __init__(self, num_init_features, args):
        super().__init__()
        divide_channels = args.divide_channels
        num_layers = args.num_layers
        num_cascade_layers = args.num_cascade_layers
        growth_rate = args.growth_rate
        num_features = num_init_features

        # Optionally reduce the input channels before the dense block:
        self.reduce_channels = Linear(
            num_features,
            num_features // divide_channels
        ) if divide_channels > 1 else None
        num_features = num_features // divide_channels

        # Dense block: each layer consumes all previous feature maps
        # and contributes `growth_rate` new channels.
        self.dense_layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.dense_layers.append(_DenseLayer(num_features, args))
            num_features += growth_rate

        # Cascade of GLU layers, each halving the channel dimension.
        self.cascade_layers = nn.Sequential()
        for i in range(num_cascade_layers):
            assert not num_features % 2, "Number of features should be even"
            self.cascade_layers.add_module(
                'cascade%d' % i,
                _CascadeLayer(num_features)
            )
            print('Cascade', i, num_features, '>>', num_features // 2)
            num_features = num_features // 2
        self.output_channels = num_features

    def forward(self, x,
                encoder_mask=None,
                decoder_mask=None,
                incremental_state=None):
        """
        Input:  (B, Tt, Ts, C)
        Output: (B, Tt, Ts, self.output_channels)
        """
        if self.reduce_channels is not None:
            x = self.reduce_channels(x)
        # B, Tt, Ts, C >> B, C, Tt, Ts
        x = x.permute(0, 3, 1, 2)
        features = [x]
        for i, layer in enumerate(self.dense_layers):
            x = layer(features,
                      decoder_mask=decoder_mask,
                      encoder_mask=encoder_mask,
                      incremental_state=incremental_state)
            features.append(x)
        x = torch.cat(features, 1)
        # Back to the original shape B, Tt, Ts, C
        x = x.permute(0, 2, 3, 1)
        x = self.cascade_layers(x)
        return x


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, args):
        super().__init__()
        self.memory_efficient = args.memory_efficient
        self.drop_rate = args.convolution_dropout
        bn_size = args.bn_size
        growth_rate = args.growth_rate
        inter_features = bn_size * growth_rate
        kernel_size = args.kernel_size
        # 1x1 bottleneck convolution:
        self.conv1 = nn.Conv2d(num_input_features,
                               inter_features,
                               kernel_size=1,
                               stride=1,
                               bias=args.conv_bias)
        dilsrc = args.source_dilation
        diltrg = args.target_dilation
        padding_trg = diltrg * (kernel_size - 1) // 2
        padding_src = dilsrc * (kernel_size - 1) // 2
        padding = (padding_trg, padding_src)
        # Masked 2D convolution over the (target, source) grid:
        self.mconv2 = MaskedConvolution(
            inter_features, growth_rate,
            kernel_size, args,
            padding=padding,
        )

    def bottleneck_function(self, *inputs):
        x = torch.cat(inputs, 1)
        x = F.relu(x)
        x = self.conv1(x)
        return x

    def forward(self, prev_features,
                encoder_mask=None,
                decoder_mask=None,
                incremental_state=None):
        """
        Memory-efficient forward pass with checkpointing.
        Only the bottleneck (concat + ReLU + 1x1 conv) is checkpointed;
        the masked convolution and dropout run normally.
        `prev_features` is a list of features, each of shape (B, C, Tt, Ts).
        Returns the new features alone: (B, growth_rate, Tt, Ts).
        """
        if self.memory_efficient and any(
            prev_feature.requires_grad for prev_feature in prev_features
        ):
            # Does not keep intermediate values,
            # but recomputes them in the backward pass:
            # trade-off between memory & compute.
            x = cp.checkpoint(
                self.bottleneck_function,
                *prev_features
            )
        else:
            x = self.bottleneck_function(*prev_features)
        x = self.mconv2(x, incremental_state)
        if encoder_mask is not None:
            x = x.masked_fill(encoder_mask.unsqueeze(1).unsqueeze(1), 0)
        if decoder_mask is not None:
            x = x.masked_fill(decoder_mask.unsqueeze(1).unsqueeze(-1), 0)
        if self.drop_rate:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        return x


class _CascadeLayer(nn.Module):
    def __init__(self, num_features):
        super().__init__()
        self.linear = Linear(num_features, num_features)

    def forward(self, x):
        # GLU halves the last (channel) dimension.
        return F.glu(self.linear(x), dim=-1)


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.)
    return m
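

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It exercises the self-contained _CascadeLayer above to show the GLU channel
# halving. Instantiating DenseNetCascade itself additionally requires a full
# `args` namespace (divide_channels, num_layers, num_cascade_layers,
# growth_rate, bn_size, kernel_size, conv_bias, convolution_dropout,
# memory_efficient, source_dilation, target_dilation) plus whatever
# MaskedConvolution expects in your fairseq fork, so it is not attempted here.
if __name__ == '__main__':
    layer = _CascadeLayer(num_features=64)
    x = torch.rand(2, 5, 7, 64)        # (B, Tt, Ts, C)
    y = layer(x)
    # GLU halves the channel dimension: 64 -> 32
    assert y.shape == (2, 5, 7, 32)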