import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from misc import ops


class ActNorm(nn.Module):
    def __init__(self, num_channels, scale=1., logscale_factor=3., batch_variance=False):
        """
        Activation normalization layer

        :param num_channels: number of channels
        :type num_channels: int
        :param scale: scale
        :type scale: float
        :param logscale_factor: factor for logscale
        :type logscale_factor: float
        :param batch_variance: whether to use the variance of the whole batch
        :type batch_variance: bool
        """
        super().__init__()
        self.num_channels = num_channels
        self.scale = scale
        self.logscale_factor = logscale_factor
        self.batch_variance = batch_variance
        self.bias_inited = False
        self.logs_inited = False
        self.register_parameter('bias', nn.Parameter(torch.zeros(1, self.num_channels, 1, 1)))
        self.register_parameter('logs', nn.Parameter(torch.zeros(1, self.num_channels, 1, 1)))

    def actnorm_center(self, x, reverse=False):
        """
        Center operation of activation normalization

        :param x: input
        :type x: torch.Tensor
        :param reverse: whether to run in reverse
        :type reverse: bool
        :return: centered input
        :rtype: torch.Tensor
        """
        if not self.bias_inited:
            self.initialize_bias(x)
        if not reverse:
            return x + self.bias
        else:
            return x - self.bias

    def actnorm_scale(self, x, logdet, reverse=False):
        """
        Scale operation of activation normalization

        :param x: input
        :type x: torch.Tensor
        :param logdet: log determinant
        :type logdet: torch.Tensor or None
        :param reverse: whether to run in reverse
        :type reverse: bool
        :return: scaled input and logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        if not self.logs_inited:
            self.initialize_logs(x)
        # TODO condition for non 4-dims input
        logs = self.logs * self.logscale_factor
        # out-of-place multiplication: autograd needs the pre-scale activations
        # to compute gradients w.r.t. logs
        if not reverse:
            x = x * torch.exp(logs)
        else:
            x = x * torch.exp(-logs)
        if logdet is not None:
            logdet_factor = ops.count_pixels(x)  # H * W
            dlogdet = torch.sum(logs) * logdet_factor
            if reverse:
                dlogdet *= -1
            logdet += dlogdet
        return x, logdet

    def initialize_bias(self, x):
        """
        Initialize bias with the negative per-channel mean of the first batch

        :param x: input
        :type x: torch.Tensor
        """
        if not self.training:
            return
        with torch.no_grad():
            # Compute initial value
            x_mean = -1. * ops.reduce_mean(x, dim=[0, 2, 3], keepdim=True)
            # Copy to parameters
            self.bias.data.copy_(x_mean.data)
            self.bias_inited = True

    def initialize_logs(self, x):
        """
        Initialize logs from the per-channel standard deviation of the first batch

        :param x: input
        :type x: torch.Tensor
        """
        if not self.training:
            return
        with torch.no_grad():
            if self.batch_variance:
                x_var = ops.reduce_mean(x ** 2, keepdim=True)
            else:
                x_var = ops.reduce_mean(x ** 2, dim=[0, 2, 3], keepdim=True)
            logs = torch.log(self.scale / (torch.sqrt(x_var) + 1e-6)) / self.logscale_factor
            # Copy to parameters
            self.logs.data.copy_(logs.data)
            self.logs_inited = True

    def forward(self, x, logdet=None, reverse=False):
        """
        Forward activation normalization layer

        :param x: input
        :type x: torch.Tensor
        :param logdet: log determinant
        :type logdet: torch.Tensor or None
        :param reverse: whether to run in reverse
        :type reverse: bool
        :return: normalized input and logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        assert len(x.shape) == 4
        assert x.shape[1] == self.num_channels, \
            'Input shape should be NxCxHxW with C = {}, got C = {}'.format(
                self.num_channels, x.shape[1])
        assert x.device == self.bias.device and x.device == self.logs.device, \
            'Expected input on device {}, got {}'.format(self.bias.device, x.device)
        if not reverse:
            # center and scale
            x = self.actnorm_center(x, reverse=False)
            x, logdet = self.actnorm_scale(x, logdet, reverse=False)
        else:
            # scale and center
            x, logdet = self.actnorm_scale(x, logdet, reverse=True)
            x = self.actnorm_center(x, reverse=True)
        return x, logdet
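

# Illustrative usage sketch (not part of the original module): a minimal check
# of ActNorm's data-dependent initialization and invertibility.  The helper
# name `_example_actnorm` and all sizes are arbitrary, and running it assumes
# the `misc.ops` helpers used above (`reduce_mean`, `count_pixels`) behave as
# their names suggest.
def _example_actnorm():
    layer = ActNorm(num_channels=4)
    layer.train()  # data-dependent initialization only happens in training mode
    x = torch.randn(8, 4, 16, 16)
    # The first forward pass initializes bias/logs from batch statistics and
    # adds H * W * sum(logs * logscale_factor) to logdet.
    y, logdet = layer(x, logdet=torch.zeros(1))
    # Running in reverse reconstructs the input up to numerical precision.
    x_rec, _ = layer(y, reverse=True)
    assert torch.allclose(x, x_rec, atol=1e-5)
    return y, logdet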


class LinearZeros(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, logscale_factor=3.):
        """
        Linear layer with zero initialization

        :param in_features: size of each input sample
        :type in_features: int
        :param out_features: size of each output sample
        :type out_features: int
        :param bias: whether to learn an additive bias
        :type bias: bool
        :param logscale_factor: factor of logscale
        :type logscale_factor: float
        """
        super().__init__(in_features, out_features, bias)
        self.logscale_factor = logscale_factor
        # zero initialization
        self.weight.data.zero_()
        self.bias.data.zero_()
        # register parameter
        self.register_parameter('logs', nn.Parameter(torch.zeros(out_features)))

    def forward(self, x):
        """
        Forward linear zero layer

        :param x: input
        :type x: torch.Tensor
        :return: output
        :rtype: torch.Tensor
        """
        output = super().forward(x)
        output = output * torch.exp(self.logs * self.logscale_factor)
        return output
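

# Illustrative usage sketch (not part of the original module): LinearZeros
# starts as the zero map, because weight, bias and logs are all zero, so the
# output is 0 * exp(0) = 0 for any input until the parameters are trained.
def _example_linear_zeros():
    layer = LinearZeros(in_features=16, out_features=8)
    x = torch.randn(4, 16)
    y = layer(x)
    assert y.shape == (4, 8)
    assert torch.all(y == 0)
    return y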


class Conv2d(nn.Conv2d):
    @staticmethod
    def get_padding(padding_type, kernel_size, stride):
        """
        Get padding size.

        As mentioned in https://github.com/pytorch/pytorch/issues/3867#issuecomment-361775080,
        this behaves like 'SAME' padding in TensorFlow, independent of the input size, when stride is 1.

        :param padding_type: type of padding in ['SAME', 'VALID']
        :type padding_type: str
        :param kernel_size: kernel size
        :type kernel_size: tuple(int) or int
        :param stride: stride
        :type stride: int
        :return: padding size
        :rtype: tuple(int)
        """
        assert padding_type in ['SAME', 'VALID'], "Unsupported padding type: {}".format(padding_type)
        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]
        if padding_type == 'SAME':
            assert stride == 1, "'SAME' padding only supports stride=1"
            return tuple((k - 1) // 2 for k in kernel_size)
        return tuple(0 for _ in kernel_size)

    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=1, padding_type='SAME',
                 do_weightnorm=False, do_actnorm=True,
                 dilation=1, groups=1):
        """
        Wrapper of nn.Conv2d with weight normalization and activation normalization

        :param padding_type: type of padding in ['SAME', 'VALID']
        :type padding_type: str
        :param do_weightnorm: whether to do weight normalization after convolution
        :type do_weightnorm: bool
        :param do_actnorm: whether to do activation normalization after convolution
        :type do_actnorm: bool
        """
        padding = self.get_padding(padding_type, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias=(not do_actnorm))
        self.do_weight_norm = do_weightnorm
        self.do_actnorm = do_actnorm
        self.weight.data.normal_(mean=0.0, std=0.05)
        if self.do_actnorm:
            self.actnorm = ActNorm(out_channels)
        else:
            self.bias.data.zero_()

    def forward(self, x):
        """
        Forward wrapped Conv2d layer

        :param x: input
        :type x: torch.Tensor
        :return: output
        :rtype: torch.Tensor
        """
        x = super().forward(x)
        # if self.do_weight_norm:
        #     # normalize N, H and W dims
        #     F.normalize(x, p=2, dim=0)
        #     F.normalize(x, p=2, dim=2)
        #     F.normalize(x, p=2, dim=3)
        if self.do_actnorm:
            x, _ = self.actnorm(x)
        return x


class Conv2dZeros(nn.Conv2d):
    def __init__(self, in_channels, out_channels,
                 kernel_size=(3, 3), stride=1, padding_type='SAME',
                 logscale_factor=3,
                 dilation=1, groups=1, bias=True):
        """
        Wrapper of nn.Conv2d with zero initialization and logs

        :param padding_type: type of padding in ['SAME', 'VALID']
        :type padding_type: str
        :param logscale_factor: factor for logscale
        :type logscale_factor: float
        """
        padding = Conv2d.get_padding(padding_type, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, padding,
                         dilation, groups, bias)
        self.logscale_factor = logscale_factor
        # initialize variables with zero
        self.bias.data.zero_()
        self.weight.data.zero_()
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))

    def forward(self, x):
        """
        Forward wrapped Conv2d layer

        :param x: input
        :type x: torch.Tensor
        :return: output
        :rtype: torch.Tensor
        """
        x = super().forward(x)
        x = x * torch.exp(self.logs * self.logscale_factor)
        return x


def f(in_channels, hidden_channels, out_channels):
    """
    Convolution block

    :param in_channels: number of input channels
    :type in_channels: int
    :param hidden_channels: number of hidden channels
    :type hidden_channels: int
    :param out_channels: number of output channels
    :type out_channels: int
    :return: desired convolution block
    :rtype: nn.Module
    """
    return nn.Sequential(
        Conv2d(in_channels, hidden_channels),
        nn.ReLU(inplace=True),
        Conv2d(hidden_channels, hidden_channels, kernel_size=1),
        nn.ReLU(inplace=True),
        Conv2dZeros(hidden_channels, out_channels)
    )
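

# Illustrative usage sketch (not part of the original module): with stride 1,
# 'SAME' padding of (k - 1) // 2 keeps the spatial size, e.g. a 3x3 kernel gets
# padding 1.  Conv2dZeros additionally starts as the zero map, so the `f` block
# outputs zeros at initialization.  `do_actnorm=False` is used here only so the
# sketch does not exercise `misc.ops` through ActNorm; `f` itself keeps
# activation normalization enabled.
def _example_same_padding():
    x = torch.randn(2, 3, 16, 16)
    conv = Conv2d(3, 8, kernel_size=(3, 3), do_actnorm=False)
    assert conv(x).shape == (2, 8, 16, 16)  # spatial size preserved
    conv_zeros = Conv2dZeros(3, 8)
    assert torch.all(conv_zeros(x) == 0)    # zero output at initialization
    return conv, conv_zeros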


class Invertible1x1Conv(nn.Module):
    def __init__(self, num_channels, lu_decomposition=False):
        """
        Invertible 1x1 convolution layer

        :param num_channels: number of channels
        :type num_channels: int
        :param lu_decomposition: whether to use LU decomposition
        :type lu_decomposition: bool
        """
        super().__init__()
        self.num_channels = num_channels
        self.lu_decomposition = lu_decomposition
        if self.lu_decomposition:
            raise NotImplementedError()
        else:
            w_shape = [num_channels, num_channels]
            # Sample a random orthogonal matrix
            w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype('float32')
            self.register_parameter('weight', nn.Parameter(torch.Tensor(w_init)))

    def forward(self, x, logdet=None, reverse=False):
        """
        Forward invertible 1x1 convolution layer

        :param x: input
        :type x: torch.Tensor
        :param logdet: log determinant
        :type logdet: torch.Tensor or None
        :param reverse: whether to run in reverse
        :type reverse: bool
        :return: output and logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        logdet_factor = ops.count_pixels(x)  # H * W
        dlogdet = torch.log(torch.abs(torch.det(self.weight))) * logdet_factor
        if not reverse:
            weight = self.weight.view(*self.weight.shape, 1, 1)
            z = F.conv2d(x, weight)
            if logdet is not None:
                logdet = logdet + dlogdet
            return z, logdet
        else:
            weight = self.weight.inverse().view(*self.weight.shape, 1, 1)
            z = F.conv2d(x, weight)
            if logdet is not None:
                logdet = logdet - dlogdet
            return z, logdet


class Permutation2d(nn.Module):
    def __init__(self, num_channels, shuffle=False):
        """
        Perform permutation on the channel dimension

        :param num_channels: number of channels
        :type num_channels: int
        :param shuffle: whether to shuffle channels randomly instead of reversing their order
        :type shuffle: bool
        """
        super().__init__()
        self.num_channels = num_channels
        self.indices = np.arange(self.num_channels - 1, -1, -1, dtype=np.int64)
        if shuffle:
            np.random.shuffle(self.indices)
        self.indices_inverse = np.zeros(self.num_channels, dtype=np.int64)
        for i in range(self.num_channels):
            self.indices_inverse[self.indices[i]] = i

    def forward(self, x, reverse=False):
        assert len(x.shape) == 4
        if not reverse:
            return x[:, self.indices, :, :]
        else:
            return x[:, self.indices_inverse, :, :]
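

# Illustrative usage sketch (not part of the original module): Permutation2d
# only reorders channels, so applying it forward and then in reverse recovers
# the input exactly.  Invertible1x1Conv has the same round-trip property up to
# numerical precision.
def _example_permutation():
    layer = Permutation2d(num_channels=6)
    x = torch.randn(2, 6, 4, 4)
    y = layer(x)
    x_rec = layer(y, reverse=True)
    assert torch.equal(x, x_rec)
    return y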


class GaussianDiag:
    """
    Helpers for a diagonal Gaussian distribution
    """
    log_2pi = float(np.log(2 * np.pi))

    @staticmethod
    def eps(shape_tensor, eps_std=None):
        """
        Returns a tensor of the same shape as `shape_tensor`, filled with
        zero-mean normal noise with standard deviation `eps_std`

        :param shape_tensor: tensor whose shape is used for the noise
        :type shape_tensor: torch.Tensor
        :param eps_std: standard deviation of eps
        :type eps_std: float
        :return: noise tensor
        :rtype: torch.Tensor
        """
        eps_std = eps_std or 1.
        return torch.normal(mean=torch.zeros_like(shape_tensor),
                            std=torch.ones_like(shape_tensor) * eps_std)

    @staticmethod
    def flatten_sum(tensor):
        """
        Sum over all dimensions except the first (batch) dimension

        :param tensor: input tensor
        :type tensor: torch.Tensor
        :return: summed tensor
        :rtype: torch.Tensor
        """
        assert len(tensor.shape) == 4
        return ops.reduce_sum(tensor, dim=[1, 2, 3])

    @staticmethod
    def logps(mean, logs, x):
        """
        Element-wise log-likelihood of x under a diagonal Gaussian

        :param mean: mean of the Gaussian
        :type mean: torch.Tensor
        :param logs: log standard deviation of the Gaussian
        :type logs: torch.Tensor
        :param x: input tensor
        :type x: torch.Tensor
        :return: element-wise log-likelihood
        :rtype: torch.Tensor
        """
        return -0.5 * (GaussianDiag.log_2pi + 2. * logs + ((x - mean) ** 2) / torch.exp(2. * logs))

    @staticmethod
    def logp(mean, logs, x):
        """
        Log-likelihood summed over all dimensions except the batch dimension

        :param mean: mean of the Gaussian
        :type mean: torch.Tensor
        :param logs: log standard deviation of the Gaussian
        :type logs: torch.Tensor
        :param x: input tensor
        :type x: torch.Tensor
        :return: summed log-likelihood
        :rtype: torch.Tensor
        """
        s = GaussianDiag.logps(mean, logs, x)
        return GaussianDiag.flatten_sum(s)

    @staticmethod
    def sample(mean, logs, eps_std=None):
        """
        Generate a sample from the diagonal Gaussian

        :param mean: mean of the Gaussian
        :type mean: torch.Tensor
        :param logs: log standard deviation of the Gaussian
        :type logs: torch.Tensor
        :param eps_std: standard deviation of eps
        :type eps_std: float
        :return: sample
        :rtype: torch.Tensor
        """
        eps = GaussianDiag.eps(mean, eps_std)
        return mean + torch.exp(logs) * eps


class Split2d(nn.Module):
    def __init__(self, num_channels):
        """
        Split2d layer

        :param num_channels: number of channels
        :type num_channels: int
        """
        super().__init__()
        self.num_channels = num_channels
        self.conv2d_zeros = Conv2dZeros(num_channels // 2, num_channels)

    def prior(self, z):
        """
        Compute the mean and log standard deviation of the prior from the kept half

        :param z: input tensor
        :type z: torch.Tensor
        :return: mean and logs of the prior
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        h = self.conv2d_zeros(z)
        mean, logs = ops.split_channel(h, 'cross')
        return mean, logs

    def forward(self, x, logdet=0., reverse=False, eps_std=None):
        """
        Forward Split2d layer

        :param x: input tensor
        :type x: torch.Tensor
        :param logdet: log determinant
        :type logdet: float
        :param reverse: whether to reverse flow
        :type reverse: bool
        :param eps_std: standard deviation of eps
        :type eps_std: float
        :return: output and logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        if not reverse:
            z1, z2 = ops.split_channel(x, 'simple')
            mean, logs = self.prior(z1)
            logdet = GaussianDiag.logp(mean, logs, z2) + logdet
            return z1, logdet
        else:
            z1 = x
            mean, logs = self.prior(z1)
            z2 = GaussianDiag.sample(mean, logs, eps_std)
            z = ops.cat_channel(z1, z2)
            return z, logdet


class Squeeze2d(nn.Module):
    def __init__(self, factor=2):
        """
        Squeeze2d layer

        :param factor: squeeze factor
        :type factor: int
        """
        super().__init__()
        self.factor = factor

    @staticmethod
    def unsqueeze(x, factor=2):
        """
        Unsqueeze tensor

        :param x: input tensor
        :type x: torch.Tensor
        :param factor: unsqueeze factor
        :type factor: int
        :return: unsqueezed tensor
        :rtype: torch.Tensor
        """
        assert factor >= 1
        if factor == 1:
            return x
        _, nc, nh, nw = x.shape
        assert nc >= factor ** 2 and nc % factor ** 2 == 0
        x = x.view(-1, nc // factor ** 2, factor, factor, nh, nw)
        x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
        x = x.view(-1, nc // factor ** 2, nh * factor, nw * factor)
        return x

    @staticmethod
    def squeeze(x, factor=2):
        """
        Squeeze tensor

        :param x: input tensor
        :type x: torch.Tensor
        :param factor: squeeze factor
        :type factor: int
        :return: squeezed tensor
        :rtype: torch.Tensor
        """
        assert factor >= 1
        if factor == 1:
            return x
        _, nc, nh, nw = x.shape
        assert nh % factor == 0 and nw % factor == 0
        x = x.view(-1, nc, nh // factor, factor, nw // factor, factor)
        x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
        x = x.view(-1, nc * factor * factor, nh // factor, nw // factor)
        return x

    def forward(self, x, logdet=None, reverse=False):
        """
        Forward Squeeze2d layer

        :param x: input tensor
        :type x: torch.Tensor
        :param logdet: log determinant
        :type logdet: torch.Tensor or None
        :param reverse: whether to reverse flow
        :type reverse: bool
        :return: output and logdet
        :rtype: tuple(torch.Tensor, torch.Tensor)
        """
        if not reverse:
            output = self.squeeze(x, self.factor)
        else:
            output = self.unsqueeze(x, self.factor)
        return output, logdet
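

# Illustrative usage sketch (not part of the original module): Squeeze2d trades
# spatial resolution for channels, so with factor 2 an N x C x H x W tensor
# becomes N x 4C x H/2 x W/2, and unsqueeze undoes it exactly.  The __main__
# guard only runs the sketches that do not exercise the `misc.ops` helpers.
def _example_squeeze():
    layer = Squeeze2d(factor=2)
    x = torch.randn(2, 3, 8, 8)
    y, _ = layer(x)
    assert y.shape == (2, 12, 4, 4)
    x_rec, _ = layer(y, reverse=True)
    assert torch.equal(x, x_rec)
    return y


if __name__ == '__main__':
    _example_linear_zeros()
    _example_same_padding()
    _example_permutation()
    _example_squeeze()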