import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from .decoder import DecoderBase


def he_init(m):
    """He (Kaiming) normal initialisation of a linear layer's weight.

    Draws ``m.weight`` from a normal distribution with mean 0 and standard
    deviation ``sqrt(2 / m.in_features)``. The bias is left untouched.
    """
    s = np.sqrt(2. / m.in_features)
    m.weight.data.normal_(0, s)


def _gated_activation(x, dim):
    """Return ``tanh(x[:, :dim]) * sigmoid(x[:, dim:])`` (the PixelCNN gate)."""
    return torch.tanh(x[:, :dim]) * torch.sigmoid(x[:, dim:])


class GatedMaskedConv2d(nn.Module):
    """One gated PixelCNN layer with separate vertical and horizontal stacks.

    Causality comes from padding and cropping rather than from masking kernel
    weights:

    * The vertical stack applies a ``(pad + 1, kernel_size)`` kernel to an
      input padded above. Its output is then shifted down by one row, so
      each position only sees rows strictly above it.
    * The horizontal stack applies a ``(1, pad + 1)`` kernel to an input
      padded on the left, so each position sees itself and the pixels to its
      left. With ``mask='A'`` the result is additionally shifted right by one
      column, which excludes the current pixel.

    ``kernel_size`` is expected to be odd. With an even size, the vertical
    convolution's output width would not equal the input width.

    Args:
        in_dim: number of input channels of both stacks.
        out_dim: number of output channels (defaults to ``in_dim``).
        kernel_size: spatial kernel size; ``pad = kernel_size // 2``.
        mask: ``'A'`` excludes the current pixel and has no residual
            connection. ``'B'`` includes the current pixel and adds a residual
            connection on the horizontal stack, which requires
            ``in_dim == out_dim``.
    """

    def __init__(self, in_dim, out_dim=None, kernel_size=3, mask='B'):
        super(GatedMaskedConv2d, self).__init__()
        if out_dim is None:
            out_dim = in_dim
        self.dim = out_dim
        self.size = kernel_size
        self.mask = mask
        pad = self.size // 2

        # Vertical stack. Conv and 1x1 outputs have 2*dim channels: half are
        # the tanh branch and half the sigmoid gate.
        self.v_conv = nn.Conv2d(in_dim, 2 * self.dim, kernel_size=(pad + 1, self.size))
        self.v_pad1 = nn.ConstantPad2d((pad, pad, pad, 0), 0)  # left, right, top
        self.v_pad2 = nn.ConstantPad2d((0, 0, 1, 0), 0)  # one row on top (shift down)
        self.vh_conv = nn.Conv2d(2 * self.dim, 2 * self.dim, kernel_size=1)

        # Horizontal stack.
        self.h_conv = nn.Conv2d(in_dim, 2 * self.dim, kernel_size=(1, pad + 1))
        self.h_pad1 = nn.ConstantPad2d((pad, 0, 0, 0), 0)  # left only
        self.h_pad2 = nn.ConstantPad2d((1, 0, 0, 0), 0)  # one column left (shift right)
        self.h_conv_res = nn.Conv2d(self.dim, self.dim, 1)

    def forward(self, v_map, h_map):
        """Apply the layer.

        Args:
            v_map: vertical-stack input, ``(N, in_dim, H, W)``.
            h_map: horizontal-stack input, ``(N, in_dim, H, W)``.

        Returns:
            ``(v_map_out, h_map_out)``, each with ``out_dim`` channels.
        """
        # Vertical stack: row i first covers input rows [i - pad, i]. Padding
        # one row on top and dropping the last row shifts it so row i only
        # sees rows < i.
        v_out = self.v_pad2(self.v_conv(self.v_pad1(v_map)))[:, :, :-1, :]
        v_map_out = _gated_activation(v_out, self.dim)
        # The (already shifted) vertical features are fed into the horizontal
        # stack through a 1x1 conv.
        vh = self.vh_conv(v_out)

        # Horizontal stack: column j covers input columns [j - pad, j].
        h_out = self.h_conv(self.h_pad1(h_map))
        if self.mask == 'A':
            # Shift right by one column so the current pixel is excluded.
            h_out = self.h_pad2(h_out)[:, :, :, :-1]
        h_out = _gated_activation(h_out + vh, self.dim)
        h_map_out = self.h_conv_res(h_out)
        if self.mask == 'B':
            h_map_out = h_map_out + h_map
        return v_map_out, h_map_out


class StackedGatedMaskedConv2d(nn.Module):
    """A stack of GatedMaskedConv2d layers, optionally conditioned on a latent.

    A linear layer maps the latent ``z`` to ``latent_feature_map`` extra
    channels of size ``H x W``. These are concatenated to the image before
    the first (mask ``'A'``) layer; all later layers use mask ``'B'``.

    Args:
        img_size: ``(nc, H, W)`` of the input images.
        layers: output channels of each gated layer.
        kernel_size: kernel size of each gated layer. One layer is built per
            entry, so it should have the same length as ``layers``.
        latent_dim: dimensionality of ``z``.
        latent_feature_map: number of channels produced from ``z``. With ``0``
            no latent projection is built, and ``forward`` must be called with
            ``q_z=None``.
    """

    def __init__(self, img_size=(1, 28, 28), layers=(64, 64, 64),
                 kernel_size=(7, 7, 7), latent_dim=64, latent_feature_map=1):
        super(StackedGatedMaskedConv2d, self).__init__()
        input_dim = img_size[0]
        self.latent_feature_map = latent_feature_map
        if latent_feature_map > 0:
            # Sized from img_size (previously hard-coded to 28 * 28) so the
            # reshape to (H, W) in ``forward`` works at any resolution.
            height, width = img_size[1], img_size[2]
            self.z_linear = nn.Linear(latent_dim, latent_feature_map * height * width)

        self.conv_layers = []
        for i, k in enumerate(kernel_size):
            if i == 0:
                layer = GatedMaskedConv2d(input_dim + latent_feature_map, layers[i], k, 'A')
            else:
                layer = GatedMaskedConv2d(layers[i - 1], layers[i], k)
            self.conv_layers.append(layer)
        # NOTE: ``modules`` shadows the name of nn.Module.modules(). It is kept
        # unchanged because it sets the state_dict key prefix ("modules.<i>.")
        # of existing checkpoints. This ModuleList is what registers the conv
        # layers' parameters; ``self.conv_layers`` is a plain list for iteration.
        self.modules = nn.ModuleList(self.conv_layers)

    def forward(self, img, q_z=None):
        """Run the stack.

        Args:
            img: (batch, nc, H, W)
            q_z: (batch, nsamples, nz), or None for an unconditioned stack.

        Returns:
            The horizontal-stack output of the last layer. Its leading
            dimension is ``batch * nsamples`` when ``q_z`` is given, otherwise
            ``batch``.
        """
        if q_z is not None:
            # Read the latent's shape only when a latent is given. Previously
            # this ran before the None check, so q_z=None always crashed.
            batch_size, nsamples, _ = q_z.size()
            z_img = self.z_linear(q_z)
            # (batch, nsamples, fm, H, W)
            z_img = z_img.view(img.size(0), nsamples, self.latent_feature_map,
                               img.size(2), img.size(3))
            # (batch, nsamples, nc, H, W)
            img = img.unsqueeze(1).expand(batch_size, nsamples, *img.size()[1:])
            # (batch, nsamples, nc + fm, H, W) --> (batch * nsamples, nc + fm, H, W)
            v_map = torch.cat([img, z_img], 2)
            v_map = v_map.view(-1, *v_map.size()[2:])
        else:
            v_map = img

        # Both stacks start from the same input.
        h_map = v_map
        for layer in self.conv_layers:
            v_map, h_map = layer(v_map, h_map)
        return h_map


class PixelCNNDecoder(DecoderBase):
    """PixelCNN decoder with a per-pixel Bernoulli likelihood p(x | z).

    A latent-conditioned StackedGatedMaskedConv2d is followed by a 1x1
    convolution and a sigmoid, which yields per-pixel probabilities.

    Args:
        args: configuration object. Reads ``img_size``, ``dec_layers``, ``nz``,
            ``dec_kernel_size`` and ``latent_feature_map``.
    """

    def __init__(self, args):
        super(PixelCNNDecoder, self).__init__()
        self.dec_cnn = StackedGatedMaskedConv2d(img_size=args.img_size,
                                                layers=args.dec_layers,
                                                latent_dim=args.nz,
                                                kernel_size=args.dec_kernel_size,
                                                latent_feature_map=args.latent_feature_map)
        self.dec_linear = nn.Conv2d(args.dec_layers[-1], args.img_size[0], kernel_size=1)
        self.reset_parameters()

    def reset_parameters(self):
        """He-initialise the weight of every nn.Linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                he_init(m)

    def decode(self, img, q_z):
        """Return per-pixel probabilities.

        Args:
            img: (batch, nc, H, W) conditioning image.
            q_z: (batch, nsamples, nz) latent samples.

        Returns:
            Probabilities in (0, 1), shape (batch * nsamples, nc, H, W).
        """
        dec_cnn_output = self.dec_cnn(img, q_z)
        pred = torch.sigmoid(self.dec_linear(dec_cnn_output))
        return pred

    def reconstruct_error(self, x, z):
        """Negative Bernoulli log-likelihood of ``x`` under each latent sample.

        Args:
            x: (batch_size, nc, H, W), pixel targets; presumably values in
                [0, 1] (TODO confirm with the data pipeline).
            z: (batch_size, n_sample, nz)

        Returns:
            loss: (batch_size, n_sample), the value of -log p(x|z) for each
            pair of x and z sample.
        """
        batch_size, nsamples, _ = z.size()
        # (batch * nsamples, nc, H, W)
        pred = self.decode(x, z)
        # Clamp away from 0 and 1 so both logs below stay finite.
        prob = torch.clamp(pred.view(pred.size(0), -1), min=1e-5, max=1. - 1e-5)

        # (batch, nc, H, W) -> (batch, nsamples, nc, H, W) -> (batch * nsamples, nc * H * W)
        tgt_vec = x.unsqueeze(1).expand(batch_size, nsamples, *x.size()[1:]).contiguous()
        tgt_vec = tgt_vec.view(batch_size * nsamples, -1)

        log_bernoulli = tgt_vec * torch.log(prob) + (1. - tgt_vec) * torch.log(1. - prob)
        log_bernoulli = log_bernoulli.view(batch_size, nsamples, -1)
        return -torch.sum(log_bernoulli, 2)

    def log_probability(self, x, z):
        """Bernoulli log-likelihood log p(x|z).

        Args:
            x: (batch_size, nc, H, W)
            z: (batch_size, n_sample, nz)

        Returns:
            log_p: (batch_size, n_sample), the value of log p(x|z) for each
            pair of x and z sample.
        """
        return -self.reconstruct_error(x, z)