Python torch.nn.SyncBatchNorm() Examples

The following are 23 code examples of torch.nn.SyncBatchNorm(), drawn from open-source projects. The source file, project, and license for each example are noted above its code. You may also want to check out the other functions and classes available in the torch.nn module.
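Before the project-specific examples, here is a minimal usage sketch (not taken from any of the projects below): on a single process nn.SyncBatchNorm behaves like nn.BatchNorm2d, and the usual workflow is to convert an existing model's BatchNorm layers and then wrap the model in DistributedDataParallel.

import torch
import torch.nn as nn

# A minimal sketch, assuming torch.distributed has been initialized elsewhere
# and that local_rank is the current process's GPU index (hypothetical name).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(inplace=True),
)
# Replace every BatchNorm*d layer with SyncBatchNorm before wrapping in DDP.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])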
Example #1
Source File: batch_norm.py    From Parsing-R-CNN with MIT License
def forward(self, input):
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
        # Per-process statistics: E[x] and E[x^2] over the batch and spatial dims.
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        # Average the concatenated statistics across all processes.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        # var(x) = E[x^2] - E[x]^2; update the running statistics with momentum.
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        # Fold normalization and the affine parameters into a single per-channel
        # scale and bias, then apply them to the input.
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias 
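The forward above relies on an AllReduce autograd function defined elsewhere in the project (the near-identical forwards in the following examples do the same). A simplified, hypothetical stand-in with the same role could look like this; it is a sketch, not the projects' actual implementation, and it assumes torch.distributed is initialized:

import torch
import torch.distributed as dist

class AllReduce(torch.autograd.Function):
    # Sum a tensor across processes in forward and sum the incoming
    # gradients across processes in backward.
    @staticmethod
    def forward(ctx, input):
        output = input.clone()
        dist.all_reduce(output, op=dist.ReduceOp.SUM)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad = grad_output.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        return grad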
Example #2
Source File: norm.py    From Det3D with Apache License 2.0
def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty input"
        C = input.shape[1]
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias 
Example #3
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias 
Example #4
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm(out_channels) 
Example #5
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }[norm]
    return norm(out_channels) 
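A short, hedged usage sketch for the factory above (the channel count and norm names are illustrative):

# Hypothetical usage of get_norm inside a model-building function.
norm_layer = get_norm("SyncBN", 64)  # nn.SyncBatchNorm(64) on recent PyTorch, NaiveSyncBatchNorm(64) on <= 1.5
gn_layer = get_norm("GN", 64)        # nn.GroupNorm(32, 64)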
Example #6
Source File: batch_norm.py    From SegmenTron with Apache License 2.0
def get_norm(norm):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    support_norm_type = ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN']
    assert norm in support_norm_type, 'Unknown norm type {}, support norm types are {}'.format(
                                                                        norm, support_norm_type)
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": nn.BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": groupNorm,
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm 
Example #7
Source File: batch_norm.py    From SegmenTron with Apache License 2.0
def forward(self, input):
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
        C = input.shape[1]
        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)

        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)
        return input * scale + bias 
Example #8
Source File: operations.py    From lightDSFD with MIT License
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, bn=False):
        super(SepConv, self).__init__()
        if not bn:
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=True,),
                nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
            )
        else:
            if cfg['GN']:
                bn_layer = nn.GroupNorm(32, C_out)
            elif cfg["syncBN"]:
                bn_layer = nn.SyncBatchNorm(C_out)
            else:
                bn_layer = nn.BatchNorm2d(C_out)
                
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(
                    C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False,
                ),
                nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
                bn_layer,
            )

        if RELU_FIRST:
            self.op = nn.Sequential(nn.ReLU())
            # self.op.add_module('0', nn.ReLU())
            for i in range(1, len(op)+1):
                self.op.add_module(str(i), op[i-1])
        else:
            self.op = op
            self.op.add_module(str(len(op)), nn.ReLU())
        # self.op = op 
Example #9
Source File: fuse_conv_bn.py    From mmdetection with Apache License 2.0
def fuse_module(m):
    last_conv = None
    last_conv_name = None

    for name, child in m.named_children():
        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = fuse_conv_bn(last_conv, child)
            m._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            m._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_module(child)
    return m 
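fuse_module relies on a fuse_conv_bn helper defined elsewhere in mmdetection. As a rough, hedged sketch of what such a helper typically does (fold the frozen BN statistics and affine parameters into the preceding convolution), assuming a plain Conv2d followed by BatchNorm2d or SyncBatchNorm:

import torch
import torch.nn as nn

def fuse_conv_bn_sketch(conv, bn):
    # y = gamma * (conv(x) - mean) / sqrt(var + eps) + beta, folded into conv.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      stride=conv.stride, padding=conv.padding,
                      dilation=conv.dilation, groups=conv.groups, bias=True)
    w = conv.weight.detach()
    b = conv.bias.detach() if conv.bias is not None else torch.zeros(conv.out_channels)
    std = torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = w * (bn.weight.detach() / std).reshape(-1, 1, 1, 1)
    fused.bias.data = (b - bn.running_mean) / std * bn.weight.detach() + bn.bias.detach()
    return fused

Hypothetical usage: call model = fuse_module(model) on a model in eval mode before export or benchmarking.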
Example #10
Source File: eval.py    From SegmenTron with Apache License 2.0
def set_batch_norm_attr(self, named_modules, attr, value):
        for m in named_modules:
            if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm):
                setattr(m[1], attr, value) 
Example #11
Source File: batch_norm.py    From SegmenTron with Apache License 2.0
def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data + module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res 
Example #12
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data + module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res 
Example #13
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            vec = AllReduce.apply(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias 
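In the non-empty stats_mode branch above, each process scales its statistics by its batch size B before the all-reduce, so workers with an empty batch contribute nothing, and dividing by the summed batch count recovers the global mean. A toy check of that weighting with illustrative values (no distributed setup needed):

import torch

mean_a, b_a = torch.tensor([1.0, 3.0]), 2   # worker with 2 samples
mean_b, b_b = torch.tensor([0.0, 0.0]), 0   # worker with an empty batch
total_batch = max(b_a + b_b, 1)             # avoid div-by-zero, as in the code above
global_mean = (mean_a * b_a + mean_b * b_b) / total_batch
print(global_mean)                          # tensor([1., 3.])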
Example #14
Source File: batch_norm.py    From detectron2 with Apache License 2.0
def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res 
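A hedged usage sketch for the classmethod above, assuming it is exposed on FrozenBatchNorm2d as referenced by the get_norm factories earlier:

# Hypothetical usage: freeze all BatchNorm/SyncBatchNorm layers in a backbone.
backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)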
Example #15
Source File: operations.py    From eval-nas with MIT License
def __init__(self, kernel_size, full_input_size, full_output_size, curr_vtx_id=None, args=None):
        super(DynamicReLUConvBN, self).__init__()
        self.args = args
        padding = 1 if kernel_size == 3 else 0
        # assign layers.
        self.relu = nn.ReLU(inplace=False)
        self.conv = DynamicConv2d(
            full_input_size, full_output_size, kernel_size, padding=padding, bias=False,
            dynamic_conv_method=args.dynamic_conv_method, dynamic_conv_dropoutw=args.dynamic_conv_dropoutw
        )
        self.curr_vtx_id = curr_vtx_id
        tracking_stat = args.wsbn_track_stat
        if args.wsbn_sync:
            # logging.debug("Using sync bn.")
            self.bn = SyncBatchNorm(full_output_size, momentum=base_ops.BN_MOMENTUM, eps=base_ops.BN_EPSILON,
                                    track_running_stats=tracking_stat)
        else:
            self.bn = nn.BatchNorm2d(full_output_size, momentum=base_ops.BN_MOMENTUM, eps=base_ops.BN_EPSILON,
                                     track_running_stats=tracking_stat)

        self.bn_train = args.wsbn_train     # store whether the BN layer stays in train mode or not.
        if self.bn_train:
            self.bn.train()
        else:
            self.bn.eval()

        # for dynamic channel
        self.channel_drop = ChannelDropout(args.channel_dropout_method, args.channel_dropout_dropouto)
        self.output_size = full_output_size
        self.current_outsize = full_output_size # may later change to a different value.
        self.current_insize = full_input_size 
Example #16
Source File: operations.py    From lightDSFD with MIT License
def forward(self, weights, temp_coeff=1.0):
        gumbel = -1e-3 * torch.log(-torch.log(torch.rand_like(weights))).to(weights.device)
        weights = _GumbelSoftMax.apply((weights + gumbel) / temp_coeff)
        return weights


# class D_Conv(nn.Module):
#     """ Deformable Conv V2 """

#     def __init__(self, C_in, C_out, kernel_size, padding, affine=True, bn=False):
#         super(D_Conv, self).__init__()
#         if bn:
#             if cfg["syncBN"]:
#                 bn_layer = nn.SyncBatchNorm(C_out)
#             else:
#                 bn_layer = nn.BatchNorm2d(C_out)
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(
#                     C_in, C_in, kernel_size=kernel_size, padding=padding, stride=1, deformable_groups=C_in, groups=C_in
#                 ),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
#                 bn_layer,
#             )
#         else:
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(
#                     C_in, C_in, kernel_size=kernel_size, padding=padding, stride=1, deformable_groups=C_in, groups=C_in
#                 ),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
#             )

#     def forward(self, x):
#         return self.op(x) 
Example #17
Source File: operations.py    From lightDSFD with MIT License
def __init__(self, C_in, C_out, affine=True, bn=False, **kwargs):
        super(Normal_Relu_Conv, self).__init__()
        if not bn:
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(C_in, C_in, bias=True, **kwargs),
            )
        else:
            if cfg['GN']:
                bn_layer = nn.GroupNorm(32, C_out)
            elif cfg["syncBN"]:
                bn_layer = nn.SyncBatchNorm(C_out)
            else:
                bn_layer = nn.BatchNorm2d(C_out)
                
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(C_in, C_in, bias=False, **kwargs),
                bn_layer,
            )
        
        if RELU_FIRST:
            self.op = nn.Sequential()
            self.op.add_module('0', nn.ReLU())
            for i in range(1, len(op)+1):
                self.op.add_module(str(i), op[i-1])
        else:
            self.op = op
            self.op.add_module(str(len(op)), nn.ReLU())
        # self.op = op 
Example #18
Source File: model_search.py    From lightDSFD with MIT License
def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        
        if cfg['GN']:
            self.bn = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-5)
        elif cfg['syncBN']:
            self.bn = nn.SyncBatchNorm(out_channels, eps=1e-5)
        else:
            self.bn = nn.BatchNorm2d(out_channels, eps=1e-5) 
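The constructor above only wires the convolution and the chosen normalization layer; the original forward is not included in this snippet, but a plausible sketch (hypothetical, for orientation only) simply chains them:

def forward(self, x):
    # Hypothetical forward for BasicConv2d: convolution followed by the
    # configured norm (GroupNorm, SyncBatchNorm, or BatchNorm2d).
    return self.bn(self.conv(x))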
Example #19
Source File: eval.py    From awesome-semantic-segmentation-pytorch with Apache License 2.0
def __init__(self, args):
        self.args = args
        self.device = torch.device(args.device)

        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])

        # dataset and dataloader
        val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='testval', transform=input_transform)
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
        self.val_loader = data.DataLoader(dataset=val_dataset,
                                          batch_sampler=val_batch_sampler,
                                          num_workers=args.workers,
                                          pin_memory=True)

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(model=args.model, dataset=args.dataset, backbone=args.backbone,
                                            aux=args.aux, pretrained=True, pretrained_base=False,
                                            local_rank=args.local_rank,
                                            norm_layer=BatchNorm2d).to(self.device)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(self.model,
                device_ids=[args.local_rank], output_device=args.local_rank)
        self.model.to(self.device)

        self.metric = SegmentationMetric(val_dataset.num_class) 
Example #20
Source File: optimizer.py    From SegmenTron with Apache License 2.0
def _set_batch_norm_attr(named_modules, attr, value):
    for m in named_modules:
        if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
            setattr(m[1], attr, value) 
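A hedged usage sketch for the helper above (the attribute name, value, and model variable are illustrative):

# Hypothetical usage: stop every BN layer from updating its running statistics.
_set_batch_norm_attr(model.named_modules(), 'momentum', 0.0)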
Example #21
Source File: operations.py    From lightDSFD with MIT License
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, bn=False):
        super(DilConv, self).__init__()
        if not bn:
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(
                    C_in,
                    C_in,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=C_in,
                    bias=True,
                ),
                nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
            )
        else:
            if cfg['GN']:
                bn_layer = nn.GroupNorm(32, C_out)
            elif cfg["syncBN"]:
                bn_layer = nn.SyncBatchNorm(C_out)
            else:
                bn_layer = nn.BatchNorm2d(C_out)
            
            op = nn.Sequential(
                # nn.ReLU(),
                nn.Conv2d(
                    C_in,
                    C_in,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=C_in,
                    bias=False,
                ),
                nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
                bn_layer,
            )

        if RELU_FIRST:
            self.op = nn.Sequential()
            self.op.add_module('0', nn.ReLU())
            for i in range(1, len(op)+1):
                self.op.add_module(str(i), op[i-1])
        else:
            self.op = op
            self.op.add_module(str(len(op)), nn.ReLU())
        # self.op = op 
Example #22
Source File: module_helper.py    From openseg.pytorch with MIT License
def BatchNorm2d(bn_type='torch', ret_cls=False):
        if bn_type == 'torchbn':
            return nn.BatchNorm2d

        elif bn_type == 'torchsyncbn':
            return nn.SyncBatchNorm

        elif bn_type == 'syncbn':
            from lib.extensions.syncbn.module import BatchNorm2d
            return BatchNorm2d

        elif bn_type == 'sn':
            from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
            return SwitchNorm2d

        elif bn_type == 'gn':
            return functools.partial(nn.GroupNorm, num_groups=32)    

        elif bn_type == 'inplace_abn':
            torch_ver = torch.__version__[:3]
            if torch_ver == '0.4':
                from lib.extensions.inplace_abn.bn import InPlaceABNSync
                if ret_cls:
                    return InPlaceABNSync

                return functools.partial(InPlaceABNSync, activation='none')

            elif torch_ver in ('1.0', '1.1'):
                from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
                if ret_cls:
                    return InPlaceABNSync

                return functools.partial(InPlaceABNSync, activation='none')  
                          
            elif torch_ver == '1.2':
                from inplace_abn import InPlaceABNSync
                if ret_cls:
                    return InPlaceABNSync

                return functools.partial(InPlaceABNSync, activation='identity')

        else:
            Log.error('Not support BN type: {}.'.format(bn_type))
            exit(1) 
Example #23
Source File: module_helper.py    From openseg.pytorch with MIT License
def BNReLU(num_features, bn_type=None, **kwargs):
        if bn_type == 'torchbn':
            return nn.Sequential(
                nn.BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif bn_type == 'torchsyncbn':
            return nn.Sequential(
                nn.SyncBatchNorm(num_features, **kwargs),
                nn.ReLU()
            )
        elif bn_type == 'syncbn':
            from lib.extensions.syncbn.module import BatchNorm2d
            return nn.Sequential(
                BatchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif bn_type == 'sn':
            from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
            return nn.Sequential(
                SwitchNorm2d(num_features, **kwargs),
                nn.ReLU()
            )
        elif bn_type == 'gn':
            return nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=num_features, **kwargs),
                nn.ReLU()
            )
        elif bn_type == 'fn':
            Log.error('Not support Filter-Response-Normalization: {}.'.format(bn_type))
            exit(1)
        elif bn_type == 'inplace_abn':
            torch_ver = torch.__version__[:3]
            # Log.info('Pytorch Version: {}'.format(torch_ver))
            if torch_ver == '0.4':
                from lib.extensions.inplace_abn.bn import InPlaceABNSync
                return InPlaceABNSync(num_features, **kwargs)
            elif torch_ver in ('1.0', '1.1'):
                from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
                return InPlaceABNSync(num_features, **kwargs)
            elif torch_ver == '1.2':
                from inplace_abn import InPlaceABNSync
                return InPlaceABNSync(num_features, **kwargs)

        else:
            Log.error('Not support BN type: {}.'.format(bn_type))
            exit(1)
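A hedged usage sketch for the two factories above, assuming they are exposed as staticmethods on a ModuleHelper class as in openseg.pytorch's module_helper.py:

# Hypothetical usage: build a SyncBatchNorm layer and a BN + ReLU block.
norm_cls = ModuleHelper.BatchNorm2d(bn_type='torchsyncbn')   # returns nn.SyncBatchNorm
bn = norm_cls(256)
bn_relu = ModuleHelper.BNReLU(256, bn_type='torchsyncbn')    # SyncBatchNorm(256) + ReLU()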