import numpy as np
import torch
from torch import nn


def _data_dependent_init(module, x, init_scale=0.05, eps=1e-8):
    """Run the data-dependent weight-norm initialization shared by the layers below.

    Resets ``module.linear`` (direction ``weight_v`` ~ N(0, init_scale), gain
    ``weight_g`` = 1, ``bias`` = 0), pushes ``x`` through it, then rescales the
    gain and shifts the bias so that the output on ``x`` is centred with a
    per-feature standard deviation of about 0.01. Finally flips the module's
    ``initialized`` buffer.

    Args:
        module: layer owning a weight-normalized ``linear`` sub-module and an
            ``initialized`` buffer.
        x: 2-D input of shape (rows, in_features).
        init_scale: standard deviation used to draw ``weight_v``. Note that
            the target output scale is the fixed constant 0.01, not this value.
        eps: added to the variance before the square root.

    Returns:
        The normalized pre-update output, ``scale * (linear(x) - mean)``.
    """
    params = module.linear._parameters

    # initial values
    params['weight_v'].data.normal_(mean=0, std=init_scale)
    params['weight_g'].data.fill_(1.)
    params['bias'].data.fill_(0.)
    init_scale = .01

    # data dependent init
    x = module.linear(x)
    m_init = torch.mean(x, 0)
    if x.size(0) > 1:
        v_init = torch.var(x, 0)
    else:
        # The unbiased variance of a single row is NaN, which would be written
        # into weight_g / bias for good. Treat it as a zero-variance batch,
        # exactly what the eps-regularised formula yields for constant rows.
        v_init = torch.zeros_like(m_init)
    scale_init = init_scale / torch.sqrt(v_init + eps)

    gain = params['weight_g']
    gain.data = gain.data * scale_init.view(gain.data.size())
    params['bias'].data = params['bias'].data - m_init * scale_init

    # `True + buffer` yields a tensor on the buffer's device, so the flag
    # stays a registered buffer wherever the module lives.
    module.initialized = True + module.initialized
    return scale_init[None, :] * (x - m_init[None, :])


class NormedLinear(nn.Module):
    """
    Linear layer with normalization, applied along an arbitrary dimension.

    args:
        in_features (int): number of input features
        out_features (int): number of output features
        dim (int): dimension to apply the transformation to
        weightnorm (bool): use weight normalization
    """

    def __init__(self, in_features, out_features, dim=-1, weightnorm=True):
        """Build the linear map; see the class docstring for the arguments."""
        super().__init__()
        self.register_buffer("initialized", torch.tensor(False))
        self._in_features = in_features
        self._out_features = out_features
        self.dim = dim
        self.linear = nn.Linear(self._in_features, out_features)

        # optionally re-parameterize the weight as gain * direction
        self.weightnorm = bool(weightnorm)
        if self.weightnorm:
            self.linear = nn.utils.weight_norm(self.linear, dim=0, name="weight")

    def forward(self, x):
        """Apply the linear map along ``self.dim``.

        Returns a tensor shaped like ``x`` except that dimension ``self.dim``
        has size ``out_features``.
        """
        # reshape in: bring the feature dimension last and flatten to 2-D
        shp = list(x.size())
        dim = self.dim if self.dim >= 0 else x.dim() + self.dim
        permute = dim < x.dim() - 1
        if permute:
            # collapse everything after `dim`, then swap it with `dim`
            x = x.view(*x.size()[:dim + 1], -1)
            x = x.transpose(-1, -2).contiguous()
            shp_2 = list(x.shape)
        x = x.view(-1, x.size(-1))

        # init and transform
        if not self.initialized:
            self.init_parameters(x)
        x = self.linear(x)

        # reshape out: undo the flattening (and the swap, if any)
        shp[dim] = self._out_features
        if permute:
            shp_2[-1] = self._out_features
            x = x.view(shp_2).transpose(-1, -2)
        return x.view(shp)

    @property
    def input_shape(self):
        return (-1, self._in_features)

    @property
    def output_shape(self):
        return (-1, self._out_features)

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        """Data-dependent initialization; does nothing (returns None) without weight norm."""
        if self.weightnorm:
            return _data_dependent_init(self, x, init_scale=init_scale, eps=eps)


class NormedDense(nn.Module):
    """
    Dense layer with normalization; flattens every non-batch dimension.

    args:
        tensor_shape (tuple): input tensor shape (B x C x D)
        out_features (int): number of output features
        weightnorm (bool): use weight normalization
    """

    def __init__(self, tensor_shape, out_features, weightnorm=True):
        """Build the dense map; see the class docstring for the arguments."""
        super().__init__()
        self.register_buffer("initialized", torch.tensor(False))
        self._input_shp = tensor_shape
        self.input_features = int(np.prod(tensor_shape[1:]))
        self._output_shp = (-1, out_features)
        self.linear = nn.Linear(self.input_features, out_features)

        # optionally re-parameterize the weight as gain * direction
        self.weightnorm = bool(weightnorm)
        if self.weightnorm:
            self.linear = nn.utils.weight_norm(self.linear, dim=0, name="weight")

    def forward(self, x):
        """Flatten ``x`` to (batch, features) and apply the linear map."""
        x = x.view(x.size()[0], -1)
        if not self.initialized:
            self.init_parameters(x)
        return self.linear(x)

    @property
    def input_shape(self):
        return self._input_shp

    @property
    def output_shape(self):
        return self._output_shp

    def init_parameters(self, x, init_scale=0.05, eps=1e-8):
        """Data-dependent initialization; does nothing (returns None) without weight norm."""
        if self.weightnorm:
            return _data_dependent_init(self, x, init_scale=init_scale, eps=eps)


class AsFeatureMap(nn.Module):
    """
    Project a lower-rank input onto ``target_shape`` with a dense layer.

    When ``input_shape`` has fewer dimensions than ``target_shape`` the input
    goes through a `NormedDense` layer and is viewed as ``target_shape``;
    otherwise the module is the identity.

    args:
        input_shape (tuple): input tensor shape, batch dimension first
        target_shape (tuple): output tensor shape, batch dimension first
        weightnorm (bool): use weight normalization in the dense layer
    """

    def __init__(self, input_shape, target_shape, weightnorm=True, **kwargs):
        super().__init__()
        self._input_shp = input_shape
        if len(input_shape) < len(target_shape):
            # plain int, like NormedDense does for its own feature count
            out_features = int(np.prod(target_shape[1:]))
            self.linear = NormedDense(input_shape, out_features, weightnorm=weightnorm)
            self._output_shp = target_shape
        else:
            self.linear = None
            self._output_shp = input_shape

    def forward(self, x):
        """Return ``x`` unchanged, or its dense projection viewed as ``output_shape``."""
        if self.linear is None:
            return x
        x = self.linear(x)
        return x.view(self.output_shape)

    @property
    def input_shape(self):
        return self._input_shp

    @property
    def output_shape(self):
        return self._output_shp