from collections import OrderedDict
from itertools import repeat

try:
    # python 3
    from collections.abc import Iterable
    from queue import Queue
except ImportError:
    # python 2
    from collections import Iterable
    from Queue import Queue

import torch
import torch.nn as nn
import torch.autograd as autograd

from .functions import inplace_abn, inplace_abn_sync


def _pair(x):
    # Broadcast a scalar to a 2-tuple, mirroring `nn.modules.utils._pair`.
    if isinstance(x, Iterable):
        return x
    return tuple(repeat(x, 2))


class ABN(nn.Sequential):
    """Activated Batch Normalization

    This gathers a `BatchNorm2d` and an activation function in a single module.
    """

    def __init__(self, num_features, activation=nn.ReLU(inplace=True), **kwargs):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        activation : nn.Module
            Module used as an activation function.
        kwargs
            All other arguments are forwarded to the `BatchNorm2d` constructor.
        """
        super(ABN, self).__init__(OrderedDict([
            ("bn", nn.BatchNorm2d(num_features, **kwargs)),
            ("act", activation)
        ]))


class InPlaceABN(nn.Module):
    """InPlace Activated Batch Normalization"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.fill_(1)
            self.bias.data.zero_()

    def forward(self, x):
        # The running statistics are wrapped as Variables because the custom
        # autograd function updates them in-place during training.
        return inplace_abn(x, self.weight, self.bias, autograd.Variable(self.running_mean),
                           autograd.Variable(self.running_var), self.training, self.momentum, self.eps,
                           self.activation, self.slope)

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)
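# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): `InPlaceABN` is meant
# as a drop-in, memory-saving replacement for a BatchNorm2d + activation pair
# such as `ABN` above. The sketch assumes the compiled CUDA extension behind
# `.functions.inplace_abn` is built and a GPU is available; the channel count
# and input shape are arbitrary.
#
#   abn = ABN(64, activation=nn.ReLU(inplace=True))    # plain PyTorch, runs anywhere
#   iabn = InPlaceABN(64, activation="leaky_relu", slope=0.01).cuda()
#   x = torch.randn(8, 64, 32, 32)
#   y = abn(x)           # standard BN followed by the activation
#   z = iabn(x.cuda())   # fused, in-place normalization + activation
# ---------------------------------------------------------------------------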
class InPlaceABNSync(nn.Module):
    """InPlace Activated Batch Normalization with cross-GPU synchronization

    This assumes that it will be replicated across GPUs using the same
    mechanism as in `nn.DataParallel`.
    """

    def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
                 slope=0.01):
        """Creates a synchronized, InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        devices : list of int or None
            IDs of the GPUs that will run the replicas of this module.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABNSync, self).__init__()
        self.num_features = num_features
        self.devices = devices if devices else list(range(torch.cuda.device_count()))
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

        # Initialize queues: the replica on `devices[0]` acts as master and
        # exchanges per-GPU statistics with the workers through these queues.
        self.worker_ids = self.devices[1:]
        self.master_queue = Queue(len(self.worker_ids))
        self.worker_queues = [Queue(1) for _ in self.worker_ids]

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)
        if self.affine:
            self.weight.data.fill_(1)
            self.bias.data.zero_()

    def forward(self, x):
        if x.get_device() == self.devices[0]:
            # Master mode
            extra = {
                "is_master": True,
                "master_queue": self.master_queue,
                "worker_queues": self.worker_queues,
                "worker_ids": self.worker_ids
            }
        else:
            # Worker mode
            extra = {
                "is_master": False,
                "master_queue": self.master_queue,
                "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
            }

        return inplace_abn_sync(x, self.weight, self.bias, autograd.Variable(self.running_mean),
                                autograd.Variable(self.running_var), extra, self.training, self.momentum,
                                self.eps, self.activation, self.slope)

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, devices={devices}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)


class InPlaceABNWrapper(nn.Module):
    """Wrapper module to make `InPlaceABN` compatible with `ABN`"""

    def __init__(self, *args, **kwargs):
        super(InPlaceABNWrapper, self).__init__()
        self.bn = InPlaceABN(*args, **kwargs)

    def forward(self, input):
        return self.bn(input)


class InPlaceABNSyncWrapper(nn.Module):
    """Wrapper module to make `InPlaceABNSync` compatible with `ABN`"""

    def __init__(self, *args, **kwargs):
        super(InPlaceABNSyncWrapper, self).__init__()
        self.bn = InPlaceABNSync(*args, **kwargs)

    def forward(self, input):
        return self.bn(input)
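# ---------------------------------------------------------------------------
# Replication sketch (illustrative; assumes a 2-GPU machine and the compiled
# CUDA extension behind `.functions.inplace_abn_sync`). As its docstring
# notes, `InPlaceABNSync` relies on being replicated with the same mechanism
# as `nn.DataParallel`: the replica on `devices[0]` runs the master branch of
# `forward` and gathers per-GPU statistics through the queues created in
# `__init__`.
#
#   block = nn.Sequential(
#       nn.Conv2d(3, 64, 3, padding=1, bias=False),
#       InPlaceABNSync(64, devices=[0, 1]),
#   )
#   model = nn.DataParallel(block.cuda(), device_ids=[0, 1])
#   out = model(torch.randn(8, 3, 32, 32).cuda())
#
# The `*Wrapper` classes above expose the in-place variants through the same
# `.bn` child attribute as `ABN`, so networks written against the `ABN`
# interface can swap implementations without renaming parameters.
# ---------------------------------------------------------------------------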