""" /*****************************************************************************/ BatchNorm2dSync with multi-gpu /*****************************************************************************/ """ from __future__ import absolute_import from __future__ import division from __future__ import print_function try: # python 3 from queue import Queue except ImportError: # python 2 from Queue import Queue import torch import torch.nn as nn from torch.nn import functional as F from torch.nn.parameter import Parameter from modules.functional import batchnorm2d_sync class _BatchNorm(nn.Module): """ Customized BatchNorm from nn.BatchNorm >> added freeze attribute to enable bn freeze. """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(_BatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats self.freezed = False if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) else: self.register_parameter('running_mean', None) self.register_parameter('running_var', None) self.reset_parameters() def reset_parameters(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) if self.affine: self.weight.data.uniform_() self.bias.data.zero_() def _check_input_dim(self, input): return NotImplemented def forward(self, input): self._check_input_dim(input) compute_stats = not self.freezed and \ self.training and self.track_running_stats ret = F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, compute_stats, self.momentum, self.eps) return ret def extra_repr(self): return '{num_features}, eps={eps}, momentum={momentum}, '\ 'affine={affine}, ' \ 'track_running_stats={track_running_stats}'.format( **self.__dict__) class BatchNorm2dNoSync(_BatchNorm): """ Equivalent to nn.BatchNorm2d """ def _check_input_dim(self, input): if input.dim() != 4: raise ValueError('expected 4D input (got {}D input)' .format(input.dim())) class BatchNorm2dSync(BatchNorm2dNoSync): """ BatchNorm2d with automatic multi-GPU Sync """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True): super(BatchNorm2dSync, self).__init__( num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) self.sync_enabled = True self.devices = list(range(torch.cuda.device_count())) if len(self.devices) > 1: # Initialize queues self.worker_ids = self.devices[1:] self.master_queue = Queue(len(self.worker_ids)) self.worker_queues = [Queue(1) for _ in self.worker_ids] def forward(self, x): compute_stats = not self.freezed and \ self.training and self.track_running_stats if self.sync_enabled and compute_stats and len(self.devices) > 1: if x.get_device() == self.devices[0]: # Master mode extra = { "is_master": True, "master_queue": self.master_queue, "worker_queues": self.worker_queues, "worker_ids": self.worker_ids } else: # Worker mode extra = { "is_master": False, "master_queue": self.master_queue, "worker_queue": self.worker_queues[ self.worker_ids.index(x.get_device())] } return batchnorm2d_sync(x, self.weight, self.bias, self.running_mean, self.running_var, extra, 
compute_stats, self.momentum, self.eps) return super(BatchNorm2dSync, self).forward(x) def __repr__(self): """repr""" rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 'affine={affine}, ' \ 'track_running_stats={track_running_stats},' \ 'devices={devices})' return rep.format(name=self.__class__.__name__, **self.__dict__) #BatchNorm2d = BatchNorm2dNoSync BatchNorm2d = BatchNorm2dSync
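

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module proper). Assumptions: CUDA is
# available, `modules.functional.batchnorm2d_sync` is importable, and the toy
# model and tensor shapes below are hypothetical, chosen only to show
# BatchNorm2d dropping in for nn.BatchNorm2d under nn.DataParallel. The
# DataParallel replicas share the master/worker queues created in __init__,
# which is what lets the per-GPU sub-batches synchronize their statistics.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        BatchNorm2d(16),  # synchronized when more than one GPU is visible
        nn.ReLU(inplace=True),
    )
    if torch.cuda.device_count() > 1:
        # Replicas scatter the batch; BN stats are reduced across devices.
        model = nn.DataParallel(model.cuda())
    elif torch.cuda.is_available():
        model = model.cuda()  # single GPU: falls back to the NoSync path
    x = torch.randn(8, 3, 32, 32)
    if torch.cuda.is_available():
        x = x.cuda()
    model.train()
    out = model(x)
    print(out.shape)  # expected: torch.Size([8, 16, 32, 32])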