Python torch.nn.SyncBatchNorm() Examples
The following are 23 code examples of torch.nn.SyncBatchNorm(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions and classes of the module torch.nn, or try the search function.
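Before diving into the examples, a minimal sketch of the built-in API may help. nn.SyncBatchNorm is constructed like nn.BatchNorm2d, and nn.SyncBatchNorm.convert_sync_batchnorm retrofits an existing model in one call; statistics are only synchronized across processes when training inside an initialized distributed process group, typically under DistributedDataParallel.

import torch
import torch.nn as nn

# Constructed like any other norm layer; the argument is the channel count.
sync_bn = nn.SyncBatchNorm(64)

# Retrofit an existing model: convert_sync_batchnorm recursively replaces
# every BatchNorm*d in the module tree with an equivalent SyncBatchNorm,
# copying weights and running statistics.
model = nn.Sequential(nn.Conv2d(3, 64, 3), nn.BatchNorm2d(64), nn.ReLU())
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Synchronization only kicks in during distributed training, e.g.:
#   dist.init_process_group("nccl", ...)
#   model = nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[rank])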
Example #1
Source File: batch_norm.py From Parsing-R-CNN with MIT License
def forward(self, input):
    if get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
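This NaiveSyncBatchNorm forward computes global batch statistics with a single all-reduce: each worker contributes its per-channel mean and mean of squares, the averaged vector is split back apart, and the variance falls out of the identity Var[x] = E[x^2] - E[x]^2. A small single-process sketch (hypothetical tensors, no distributed setup) checking that identity:

import torch

# Two "workers" with equal batch sizes; their concatenation is the global batch.
x1 = torch.randn(4, 3, 8, 8)
x2 = torch.randn(4, 3, 8, 8)

# Per-worker statistics, averaged the way the all-reduce above averages them.
mean = (x1.mean(dim=[0, 2, 3]) + x2.mean(dim=[0, 2, 3])) / 2
meansqr = ((x1 * x1).mean(dim=[0, 2, 3]) + (x2 * x2).mean(dim=[0, 2, 3])) / 2
var = meansqr - mean * mean  # E[x^2] - E[x]^2

# Global (biased) variance computed directly on the full batch agrees.
x = torch.cat([x1, x2], dim=0)
assert torch.allclose(var, x.var(dim=[0, 2, 3], unbiased=False), atol=1e-5)

Note that averaging per-worker means with equal weights is only exact when every worker sees the same batch size; Example #13 below handles heterogeneous and even empty per-worker batches.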
Example #2
Source File: norm.py From Det3D with Apache License 2.0
def forward(self, input):
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty input"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
Example #3
Source File: batch_norm.py From detectron2 with Apache License 2.0
def forward(self, input):
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
Example #4
Source File: batch_norm.py From detectron2 with Apache License 2.0
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm(out_channels)
Example #5
Source File: batch_norm.py From detectron2 with Apache License 2.0
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }[norm]
    return norm(out_channels)
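Call sites simply pass a string key or their own factory; a hypothetical example of the three paths through this function:

norm1 = get_norm("GN", 64)                            # nn.GroupNorm(32, 64)
norm2 = get_norm("", 64)                              # None: no normalization
norm3 = get_norm(lambda c: nn.InstanceNorm2d(c), 64)  # custom callable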
Example #6
Source File: batch_norm.py From SegmenTron with Apache License 2.0
def get_norm(norm):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    support_norm_type = ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN']
    assert norm in support_norm_type, 'Unknown norm type {}, support norm types are {}'.format(
        norm, support_norm_type)
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": nn.BatchNorm2d,
            "SyncBN": NaiveSyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": groupNorm,
            "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
Example #7
Source File: batch_norm.py From SegmenTron with Apache License 2.0
def forward(self, input):
    if get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    C = input.shape[1]
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    vec = torch.cat([mean, meansqr], dim=0)
    vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())

    mean, meansqr = torch.split(vec, C)
    var = meansqr - mean * mean
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)
    return input * scale + bias
Example #8
Source File: operations.py From lightDSFD with MIT License
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True, bn=False):
    super(SepConv, self).__init__()
    if not bn:
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, groups=C_in, bias=True,),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
        )
    else:
        if cfg['GN']:
            bn_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            bn_layer = nn.SyncBatchNorm(C_out)
        else:
            bn_layer = nn.BatchNorm2d(C_out)
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, groups=C_in, bias=False,),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            bn_layer,
        )

    if RELU_FIRST:
        self.op = nn.Sequential(nn.ReLU())
        # self.op.add_module('0', nn.ReLU())
        for i in range(1, len(op) + 1):
            self.op.add_module(str(i), op[i - 1])
    else:
        self.op = op
        self.op.add_module(str(len(op)), nn.ReLU())
    # self.op = op
Example #9
Source File: fuse_conv_bn.py From mmdetection with Apache License 2.0
def fuse_module(m):
    last_conv = None
    last_conv_name = None

    for name, child in m.named_children():
        if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            if last_conv is None:  # only fuse BN that is after Conv
                continue
            fused_conv = fuse_conv_bn(last_conv, child)
            m._modules[last_conv_name] = fused_conv
            # To reduce changes, set BN as Identity instead of deleting it.
            m._modules[name] = nn.Identity()
            last_conv = None
        elif isinstance(child, nn.Conv2d):
            last_conv = child
            last_conv_name = name
        else:
            fuse_module(child)
    return m
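fuse_module walks the module tree but relies on a fuse_conv_bn helper that is not shown in this snippet. For context, here is a generic sketch of the standard fusion algebra such a helper implements, folding BN parameters (gamma, beta) and running statistics (mu, sigma^2) into the preceding convolution so that W' = W * gamma / sqrt(sigma^2 + eps) and b' = (b - mu) * gamma / sqrt(sigma^2 + eps) + beta. This is a reimplementation under those assumptions, not necessarily mmdetection's exact code:

import torch
import torch.nn as nn

def fuse_conv_bn(conv, bn):
    # Build a conv with the same geometry but with a bias term.
    fused = nn.Conv2d(conv.in_channels, conv.out_channels,
                      kernel_size=conv.kernel_size, stride=conv.stride,
                      padding=conv.padding, dilation=conv.dilation,
                      groups=conv.groups, bias=True)
    # gamma / sqrt(var + eps), one factor per output channel.
    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * factor.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = (conv_bias - bn.running_mean) * factor + bn.bias.data
    return fused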
Example #10
Source File: eval.py From SegmenTron with Apache License 2.0
def set_batch_norm_attr(self, named_modules, attr, value):
    for m in named_modules:
        if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm):
            setattr(m[1], attr, value)
Example #11
Source File: batch_norm.py From SegmenTron with Apache License 2.0
@classmethod
def convert_frozen_batchnorm(cls, module):
    """
    Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

    Args:
        module (torch.nn.Module):

    Returns:
        If module is BatchNorm/SyncBatchNorm, returns a new module.
        Otherwise, in-place convert module and return it.

    Similar to convert_sync_batchnorm in
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
    """
    bn_module = nn.modules.batchnorm
    bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
    res = module
    if isinstance(module, bn_module):
        res = cls(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data + module.eps
    else:
        for name, child in module.named_children():
            new_child = cls.convert_frozen_batchnorm(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
Example #12
Source File: batch_norm.py From detectron2 with Apache License 2.0
@classmethod
def convert_frozen_batchnorm(cls, module):
    """
    Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

    Args:
        module (torch.nn.Module):

    Returns:
        If module is BatchNorm/SyncBatchNorm, returns a new module.
        Otherwise, in-place convert module and return it.

    Similar to convert_sync_batchnorm in
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
    """
    bn_module = nn.modules.batchnorm
    bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
    res = module
    if isinstance(module, bn_module):
        res = cls(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data + module.eps
    else:
        for name, child in module.named_children():
            new_child = cls.convert_frozen_batchnorm(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
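Note a subtle difference in how eps is carried over: this version (like Example #11) folds module.eps into the stored running_var at conversion time, whereas Example #14 below copies running_var unchanged and sets res.eps = module.eps instead. Which is appropriate depends on whether the FrozenBatchNorm forward applies its own eps when computing gamma / sqrt(var + eps); applying it twice would slightly change the output.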
Example #13
Source File: batch_norm.py From detectron2 with Apache License 2.0
def forward(self, input):
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    B, C = input.shape[0], input.shape[1]

    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    if self._stats_mode == "":
        assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
        mean, meansqr = torch.split(vec, C)
        momentum = self.momentum
    else:
        if B == 0:
            vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
            vec = vec + input.sum()  # make sure there is gradient w.r.t input
        else:
            vec = torch.cat(
                [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
            )
        vec = AllReduce.apply(vec * B)

        total_batch = vec[-1].detach()
        momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
        total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
        mean, meansqr, _ = torch.split(vec / total_batch, C)

    var = meansqr - mean * mean
    invstd = torch.rsqrt(var + self.eps)
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)

    self.running_mean += momentum * (mean.detach() - self.running_mean)
    self.running_var += momentum * (var.detach() - self.running_var)
    return input * scale + bias
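The stats_mode="N" branch addresses the weakness noted under Example #1: equal-weight averaging of per-worker means is only correct when batch sizes match. By all-reducing B * [mean, meansqr, 1], each worker contributes in proportion to its batch size; a worker with an empty batch contributes zeros (input.sum() merely keeps the autograd graph connected), and scaling the momentum by the clamped total batch size leaves the running statistics untouched when no process saw any sample at all.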
Example #14
Source File: batch_norm.py From detectron2 with Apache License 2.0
@classmethod
def convert_frozen_batchnorm(cls, module):
    """
    Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

    Args:
        module (torch.nn.Module):

    Returns:
        If module is BatchNorm/SyncBatchNorm, returns a new module.
        Otherwise, in-place convert module and return it.

    Similar to convert_sync_batchnorm in
    https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
    """
    bn_module = nn.modules.batchnorm
    bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
    res = module
    if isinstance(module, bn_module):
        res = cls(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = cls.convert_frozen_batchnorm(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
Example #15
Source File: operations.py From eval-nas with MIT License
def __init__(self, kernel_size, full_input_size, full_output_size, curr_vtx_id=None, args=None):
    super(DynamicReLUConvBN, self).__init__()
    self.args = args
    padding = 1 if kernel_size == 3 else 0
    # assign layers.
    self.relu = nn.ReLU(inplace=False)
    self.conv = DynamicConv2d(
        full_input_size, full_output_size, kernel_size, padding=padding, bias=False,
        dynamic_conv_method=args.dynamic_conv_method,
        dynamic_conv_dropoutw=args.dynamic_conv_dropoutw
    )
    self.curr_vtx_id = curr_vtx_id
    tracking_stat = args.wsbn_track_stat
    if args.wsbn_sync:
        # logging.debug("Using sync bn.")
        self.bn = SyncBatchNorm(full_output_size, momentum=base_ops.BN_MOMENTUM,
                                eps=base_ops.BN_EPSILON, track_running_stats=tracking_stat)
    else:
        self.bn = nn.BatchNorm2d(full_output_size, momentum=base_ops.BN_MOMENTUM,
                                 eps=base_ops.BN_EPSILON, track_running_stats=tracking_stat)
    self.bn_train = args.wsbn_train  # store whether to keep the bn in train mode.
    if self.bn_train:
        self.bn.train()
    else:
        self.bn.eval()

    # for dynamic channel
    self.channel_drop = ChannelDropout(args.channel_dropout_method, args.channel_dropout_dropouto)
    self.output_size = full_output_size
    self.current_outsize = full_output_size  # may change according to different value.
    self.current_insize = full_input_size
Example #16
Source File: operations.py From lightDSFD with MIT License
def forward(self, weights, temp_coeff=1.0):
    gumbel = -1e-3 * torch.log(-torch.log(torch.rand_like(weights))).to(weights.device)
    weights = _GumbelSoftMax.apply((weights + gumbel) / temp_coeff)
    return weights


# class D_Conv(nn.Module):
#     """ Deformable Conv V2 """
#     def __init__(self, C_in, C_out, kernel_size, padding, affine=True, bn=False):
#         super(D_Conv, self).__init__()
#         if bn:
#             if cfg["syncBN"]:
#                 bn_layer = nn.SyncBatchNorm(C_out)
#             else:
#                 bn_layer = nn.BatchNorm2d(C_out)
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(C_in, C_in, kernel_size=kernel_size, padding=padding,
#                     stride=1, deformable_groups=C_in, groups=C_in),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
#                 bn_layer,
#             )
#         else:
#             self.op = nn.Sequential(
#                 nn.ReLU(inplace=False),
#                 DCN(C_in, C_in, kernel_size=kernel_size, padding=padding,
#                     stride=1, deformable_groups=C_in, groups=C_in),
#                 nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
#             )
#
#     def forward(self, x):
#         return self.op(x)
Example #17
Source File: operations.py From lightDSFD with MIT License
def __init__(self, C_in, C_out, affine=True, bn=False, **kwargs):
    super(Normal_Relu_Conv, self).__init__()
    if not bn:
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, bias=True, **kwargs),
        )
    else:
        if cfg['GN']:
            bn_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            bn_layer = nn.SyncBatchNorm(C_out)
        else:
            bn_layer = nn.BatchNorm2d(C_out)
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, bias=False, **kwargs),
            bn_layer,
        )

    if RELU_FIRST:
        self.op = nn.Sequential()
        self.op.add_module('0', nn.ReLU())
        for i in range(1, len(op) + 1):
            self.op.add_module(str(i), op[i - 1])
    else:
        self.op = op
        self.op.add_module(str(len(op)), nn.ReLU())
    # self.op = op
Example #18
Source File: model_search.py From lightDSFD with MIT License
def __init__(self, in_channels, out_channels, **kwargs):
    super(BasicConv2d, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
    if cfg['GN']:
        self.bn = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-5)
    elif cfg['syncBN']:
        self.bn = nn.SyncBatchNorm(out_channels, eps=1e-5)
    else:
        self.bn = nn.BatchNorm2d(out_channels, eps=1e-5)
Example #19
Source File: eval.py From awesome-semantic-segmentation-pytorch with Apache License 2.0
def __init__(self, args):
    self.args = args
    self.device = torch.device(args.device)

    # image transform
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])

    # dataset and dataloader
    val_dataset = get_segmentation_dataset(args.dataset, split='val', mode='testval',
                                           transform=input_transform)
    val_sampler = make_data_sampler(val_dataset, False, args.distributed)
    val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=1)
    self.val_loader = data.DataLoader(dataset=val_dataset,
                                      batch_sampler=val_batch_sampler,
                                      num_workers=args.workers,
                                      pin_memory=True)

    # create network
    BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
    self.model = get_segmentation_model(model=args.model, dataset=args.dataset,
                                        backbone=args.backbone, aux=args.aux,
                                        pretrained=True, pretrained_base=False,
                                        local_rank=args.local_rank,
                                        norm_layer=BatchNorm2d).to(self.device)

    if args.distributed:
        self.model = nn.parallel.DistributedDataParallel(
            self.model, device_ids=[args.local_rank], output_device=args.local_rank)
    self.model.to(self.device)

    self.metric = SegmentationMetric(val_dataset.num_class)
Example #20
Source File: optimizer.py From SegmenTron with Apache License 2.0
def _set_batch_norm_attr(named_modules, attr, value):
    for m in named_modules:
        if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
            setattr(m[1], attr, value)
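A hypothetical call, e.g. to stop every BN/SyncBN layer from updating its running statistics during fine-tuning (with momentum set to zero, the exponential-moving-average update becomes a no-op):

_set_batch_norm_attr(model.named_modules(), 'momentum', 0.0)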
Example #21
Source File: operations.py From lightDSFD with MIT License
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True, bn=False):
    super(DilConv, self).__init__()
    if not bn:
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=C_in, bias=True,),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=True),
        )
    else:
        if cfg['GN']:
            bn_layer = nn.GroupNorm(32, C_out)
        elif cfg["syncBN"]:
            bn_layer = nn.SyncBatchNorm(C_out)
        else:
            bn_layer = nn.BatchNorm2d(C_out)
        op = nn.Sequential(
            # nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=C_in, bias=False,),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            bn_layer,
        )

    if RELU_FIRST:
        self.op = nn.Sequential()
        self.op.add_module('0', nn.ReLU())
        for i in range(1, len(op) + 1):
            self.op.add_module(str(i), op[i - 1])
    else:
        self.op = op
        self.op.add_module(str(len(op)), nn.ReLU())
    # self.op = op
Example #22
Source File: module_helper.py From openseg.pytorch with MIT License
def BatchNorm2d(bn_type='torch', ret_cls=False):
    if bn_type == 'torchbn':
        return nn.BatchNorm2d

    elif bn_type == 'torchsyncbn':
        return nn.SyncBatchNorm

    elif bn_type == 'syncbn':
        from lib.extensions.syncbn.module import BatchNorm2d
        return BatchNorm2d

    elif bn_type == 'sn':
        from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
        return SwitchNorm2d

    elif bn_type == 'gn':
        return functools.partial(nn.GroupNorm, num_groups=32)

    elif bn_type == 'inplace_abn':
        torch_ver = torch.__version__[:3]
        if torch_ver == '0.4':
            from lib.extensions.inplace_abn.bn import InPlaceABNSync
            if ret_cls:
                return InPlaceABNSync
            return functools.partial(InPlaceABNSync, activation='none')
        elif torch_ver in ('1.0', '1.1'):
            from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
            if ret_cls:
                return InPlaceABNSync
            return functools.partial(InPlaceABNSync, activation='none')
        elif torch_ver == '1.2':
            from inplace_abn import InPlaceABNSync
            if ret_cls:
                return InPlaceABNSync
            return functools.partial(InPlaceABNSync, activation='identity')

    else:
        Log.error('Not support BN type: {}.'.format(bn_type))
        exit(1)
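The factory returns a class (or functools.partial), which model code then instantiates with a channel count; a hypothetical call site:

norm_layer = BatchNorm2d(bn_type='torchsyncbn')  # -> nn.SyncBatchNorm
bn = norm_layer(256)                             # SyncBatchNorm over 256 channels

gn_layer = BatchNorm2d(bn_type='gn')             # functools.partial(nn.GroupNorm, num_groups=32)
gn = gn_layer(num_channels=256)                  # GroupNorm(32, 256)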
Example #23
Source File: module_helper.py From openseg.pytorch with MIT License
def BNReLU(num_features, bn_type=None, **kwargs):
    if bn_type == 'torchbn':
        return nn.Sequential(
            nn.BatchNorm2d(num_features, **kwargs),
            nn.ReLU()
        )
    elif bn_type == 'torchsyncbn':
        return nn.Sequential(
            nn.SyncBatchNorm(num_features, **kwargs),
            nn.ReLU()
        )
    elif bn_type == 'syncbn':
        from lib.extensions.syncbn.module import BatchNorm2d
        return nn.Sequential(
            BatchNorm2d(num_features, **kwargs),
            nn.ReLU()
        )
    elif bn_type == 'sn':
        from lib.extensions.switchablenorms.switchable_norm import SwitchNorm2d
        return nn.Sequential(
            SwitchNorm2d(num_features, **kwargs),
            nn.ReLU()
        )
    elif bn_type == 'gn':
        return nn.Sequential(
            nn.GroupNorm(num_groups=8, num_channels=num_features, **kwargs),
            nn.ReLU()
        )
    elif bn_type == 'fn':
        Log.error('Not support Filter-Response-Normalization: {}.'.format(bn_type))
        exit(1)
    elif bn_type == 'inplace_abn':
        torch_ver = torch.__version__[:3]
        # Log.info('Pytorch Version: {}'.format(torch_ver))
        if torch_ver == '0.4':
            from lib.extensions.inplace_abn.bn import InPlaceABNSync
            return InPlaceABNSync(num_features, **kwargs)
        elif torch_ver in ('1.0', '1.1'):
            from lib.extensions.inplace_abn_1.bn import InPlaceABNSync
            return InPlaceABNSync(num_features, **kwargs)
        elif torch_ver == '1.2':
            from inplace_abn import InPlaceABNSync
            return InPlaceABNSync(num_features, **kwargs)
    else:
        Log.error('Not support BN type: {}.'.format(bn_type))
        exit(1)