# -*- coding: utf-8 -*-
# @Author  : DevinYang(pistonyang@gmail.com)
"""Callable weight initializers that each operate on a single ``nn.Module``."""

import math

from torch import nn
from torch.nn.init import (
    xavier_normal_,
    xavier_uniform_,
    kaiming_normal_,
    kaiming_uniform_,
    zeros_,
)


def _zero_bias(module):
    """Zero ``module.bias`` in place when the module has a bias."""
    if module.bias is not None:
        module.bias.data.zero_()


def _reset_norm(module):
    """Fill a normalization layer's weight with 1 and its bias with 0."""
    if module.weight is not None:
        module.weight.data.fill_(1)
    _zero_bias(module)


class XavierInitializer(object):
    """Initialize a module's parameters with Xavier (Glorot) initialization.

    Fills weights with values according to the method described in
    `Understanding the difficulty of training deep feedforward neural
    networks` - Glorot, X. & Bengio, Y. (2010).

    Calling the instance on a module:
      * ``nn.Conv2d`` / ``nn.Conv3d`` / ``nn.Linear``: Xavier-initializes the
        weight and zeroes the bias (if present);
      * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d`` / ``nn.GroupNorm``: fills the
        weight with 1 and the bias with 0 (when they exist);
      * any other module type is left untouched.

    Args:
        random_type (string): ``'uniform'`` or ``'normal'``; selects
            ``xavier_uniform_`` or ``xavier_normal_``.
        gain (float): an optional scaling factor, default is sqrt(2.0).
    """

    def __init__(self, random_type='uniform', gain=math.sqrt(2.0)):
        assert random_type in ('uniform', 'normal')
        self.initializer = xavier_uniform_ if random_type == 'uniform' else xavier_normal_
        self.gain = gain

    def __call__(self, module):
        """Initialize ``module`` in place; unsupported module types are ignored."""
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            self.initializer(module.weight.data, gain=self.gain)
            _zero_bias(module)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            _reset_norm(module)
        elif isinstance(module, nn.Linear):
            self.initializer(module.weight.data, gain=self.gain)
            _zero_bias(module)


class KaimingInitializer(object):
    """Initialize a module's parameters with Kaiming (He) initialization.

    Calling the instance on a module:
      * ``nn.Conv2d`` / ``nn.Conv3d`` / ``nn.Linear``: Kaiming-initializes the
        weight and zeroes the bias (if present);
      * ``nn.BatchNorm2d`` / ``nn.BatchNorm3d`` / ``nn.GroupNorm``: fills the
        weight with 1 and the bias with 0 (when they exist);
      * any other module type is left untouched.

    Args:
        slope (float): forwarded as the ``a`` argument of the ``kaiming_*``
            initializer.
        mode (string): forwarded as ``mode`` (``'fan_in'`` or ``'fan_out'``).
        nonlinearity (string): forwarded as ``nonlinearity``.
        random_type (string): ``'uniform'`` or ``'normal'``; selects
            ``kaiming_uniform_`` or ``kaiming_normal_``.
    """

    def __init__(
            self,
            slope=0,
            mode='fan_out',
            nonlinearity='relu',
            random_type='normal'):
        assert random_type in ('uniform', 'normal')
        self.slope = slope
        self.mode = mode
        self.nonlinearity = nonlinearity
        self.initializer = kaiming_uniform_ if random_type == 'uniform' else kaiming_normal_

    def _init_weight(self, module):
        """Apply the configured Kaiming initializer to ``module.weight``."""
        self.initializer(
            module.weight.data,
            a=self.slope,
            mode=self.mode,
            nonlinearity=self.nonlinearity)

    def __call__(self, module):
        """Initialize ``module`` in place; unsupported module types are ignored."""
        if isinstance(module, (nn.Conv2d, nn.Conv3d)):
            self._init_weight(module)
            _zero_bias(module)
        elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
            _reset_norm(module)
        elif isinstance(module, nn.Linear):
            self._init_weight(module)
            _zero_bias(module)


class ZeroLastGamma(object):
    """Zero the weight (gamma) of one named norm layer inside matching blocks.

    Note that this must be applied after the other initializers in this
    file: they fill normalization weights with 1, which would overwrite the
    zeros written here.

    Args:
        block_name (string): class name a module must have to be processed.
        bn_name (string): attribute name, on the matching block, of the
            layer whose weight is zeroed.
    """

    def __init__(self, block_name='Bottleneck', bn_name='bn3'):
        self.block_name = block_name
        self.bn_name = bn_name

    def __call__(self, module):
        """Zero ``getattr(module, bn_name).weight`` if ``module`` matches.

        Raises:
            AttributeError: if a matching block has no ``bn_name`` attribute.
        """
        if module.__class__.__name__ != self.block_name:
            return
        # Use the builtin rather than calling ``__getattr__`` directly so that
        # ordinary attribute lookup is honoured as well as registered
        # submodules.
        target_bn = getattr(module, self.bn_name)
        # A layer without a learnable scale (e.g. ``affine=False``) has
        # ``weight is None``; skip it, mirroring the guards used by the other
        # initializers.
        weight = getattr(target_bn, 'weight', None)
        if weight is not None:
            zeros_(weight)