import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F
from functools import reduce
import operator

eps = 1e-8


class LinearARD(nn.Module):
    """
    Dense layer with a fully factorized ARD prior on the weights
    (Molchanov et al., "Variational Dropout Sparsifies Deep Neural Networks",
    arXiv:1701.05369).
    """

    def __init__(self, in_features, out_features, bias=True, thresh=3, ard_init=-10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty(out_features, in_features))
        self.thresh = thresh
        if bias:
            self.bias = Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.ard_init = ard_init
        self.log_sigma2 = Parameter(torch.empty(out_features, in_features))
        self.reset_parameters()

    def forward(self, input):
        """
        Forward pass with all regularized connections and stochastic activations
        (Bayesian mode). Typically used during training.
        """
        if not self.training:
            # Deterministic pass: weights with log alpha above the threshold are pruned.
            return F.linear(input, self.weights_clipped, self.bias)
        # Local reparameterization trick: sample the pre-activations instead of
        # the weights, which gives lower-variance gradient estimates.
        W = self.weight
        mu = input.matmul(W.t())
        log_alpha = self.clip(self.log_alpha)
        si = torch.sqrt((input * input).matmul(
            (torch.exp(log_alpha) * W * W + eps).t()))
        activation = mu + torch.randn_like(mu) * si
        if self.bias is not None:
            activation = activation + self.bias
        return activation

    @property
    def weights_clipped(self):
        clip_mask = self.get_clip_mask()
        return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight)

    def reset_parameters(self):
        self.weight.data.normal_(std=0.01)
        if self.bias is not None:
            self.bias.data.zero_()
        # self.log_sigma2.data = 2 * torch.log(torch.abs(self.weight) + eps).clone().detach() + self.ard_init * torch.ones_like(self.log_sigma2)
        self.log_sigma2.data = self.ard_init * torch.ones_like(self.log_sigma2)

    @staticmethod
    def clip(tensor, to=8):
        """
        Clamp all tensor values to the range [-to, to]
        """
        return torch.clamp(tensor, -to, to)

    def get_clip_mask(self):
        log_alpha = self.clip(self.log_alpha)
        return torch.ge(log_alpha, self.thresh)

    def train(self, mode=True):
        self.training = mode
        return super().train(mode)

    def get_reg(self, **kwargs):
        """
        Get weights regularization (KL(q(w) || p(w)) approximation)
        """
        k1, k2, k3 = 0.63576, 1.8732, 1.48695
        C = -k1
        log_alpha = self.clip(self.log_alpha)
        mdkl = (k1 * torch.sigmoid(k2 + k3 * log_alpha)
                - 0.5 * torch.log1p(torch.exp(-log_alpha)) + C)
        return -torch.sum(mdkl)

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

    def get_dropped_params_cnt(self):
        """
        Get the number of dropped weights (those with log alpha at least the
        "thresh" parameter)
        :returns: number of dropped weights
        """
        return self.get_clip_mask().sum().item()

    @property
    def log_alpha(self):
        return self.log_sigma2 - 2 * torch.log(torch.abs(self.weight) + eps)
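
# Illustrative usage sketch (not part of the original module): LinearARD is a
# drop-in replacement for nn.Linear. In train mode it samples noisy
# pre-activations via the local reparameterization trick; in eval mode it uses
# hard-pruned deterministic weights. The shapes below are arbitrary assumptions.
#
#     layer = LinearARD(in_features=20, out_features=10)
#     x = torch.randn(4, 20)
#     layer.train(True)
#     y_stochastic = layer(x)     # mu + N(0, 1) * si
#     layer.train(False)
#     y_deterministic = layer(x)  # F.linear with weights_clipped
#     kl = layer.get_reg()        # KL term to add to the task loss
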
""" Shrink all tensor's values to range [-to,to] """ return torch.clamp(tensor, -to, to) def forward(self, input): """ Forward with all regularized connections and random activations (Beyesian mode). Typically used for train """ if self.training == False: return F.conv2d(input, self.weights_clipped, self.bias, self.stride, self.padding, self.dilation, self.groups) eps = 1e-8 W = self.weight zeros = torch.zeros_like(W) clip_mask = self.get_clip_mask() conved_mu = F.conv2d(input, W, self.bias, self.stride, self.padding, self.dilation, self.groups) log_alpha = self.clip(self.log_alpha) conved_si = torch.sqrt(eps + F.conv2d(input*input, torch.exp(log_alpha) * W * W, self.bias, self.stride, self.padding, self.dilation, self.groups)) conved = conved_mu + \ conved_si * torch.normal(torch.zeros_like(conved_mu), torch.ones_like(conved_mu)) return conved @property def weights_clipped(self): clip_mask = self.get_clip_mask() return torch.where(clip_mask, torch.zeros_like(self.weight), self.weight) def get_clip_mask(self): log_alpha = self.clip(self.log_alpha) return torch.ge(log_alpha, self.thresh) def train(self, mode): self.training = mode super(Conv2dARD, self).train(mode) def get_reg(self, **kwargs): """ Get weights regularization (KL(q(w)||p(w)) approximation) """ k1, k2, k3 = 0.63576, 1.8732, 1.48695; C = -k1 log_alpha = self.clip(self.log_alpha) mdkl = k1 * torch.sigmoid(k2 + k3 * log_alpha) - 0.5 * torch.log1p(torch.exp(-log_alpha)) + C return -torch.sum(mdkl) def extra_repr(self): return 'in_features={}, out_features={}, bias={}'.format( self.in_channels, self.out_channels, self.bias is not None ) def get_dropped_params_cnt(self): """ Get number of dropped weights (greater than "thresh" parameter) :returns (number of dropped weights, number of all weight) """ return self.get_clip_mask().sum().cpu().numpy() @property def log_alpha(self): eps = 1e-8 return self.log_sigma2 - 2 * torch.log(torch.abs(self.weight) + eps) def get_ard_reg(module, reg=0): """ :param module: model to evaluate ard regularization for :param reg: auxilary cumulative variable for recursion :return: total regularization for module """ if isinstance(module, LinearARD) or isinstance(module, Conv2dARD): return reg + module.get_reg() if hasattr(module, 'children'): return reg + sum([get_ard_reg(submodule) for submodule in module.children()]) return reg def _get_dropped_params_cnt(module, cnt=0): if hasattr(module, 'get_dropped_params_cnt'): return cnt + module.get_dropped_params_cnt() if hasattr(module, 'children'): return cnt + sum([_get_dropped_params_cnt(submodule) for submodule in module.children()]) return cnt def _get_params_cnt(module, cnt=0): if any([isinstance(module, LinearARD), isinstance(module, Conv2dARD)]): return cnt + reduce(operator.mul, module.weight.shape, 1) if hasattr(module, 'children'): return cnt + sum( [_get_params_cnt(submodule) for submodule in module.children()]) return cnt + sum(p.numel() for p in module.parameters()) def get_dropped_params_ratio(model): return _get_dropped_params_cnt(model)*1.0/_get_params_cnt(model)