import torch
from torch import nn
import torch.nn.functional as F
import numpy as np

import utils


class Bernoulli():
    """Factorized Bernoulli over binary observations, parameterized by mean mu."""

    def __init__(self, mu):
        # mu: [batch_size, x_dim], per-dimension success probability
        self.mu = mu

    def log_probability(self, x):
        # Clamp into (0, 1) so the logs below stay finite; use a local variable
        # so the stored parameters are not mutated as a side effect.
        mu = torch.clamp(self.mu, min=1e-5, max=1.0 - 1e-5)
        return (x * torch.log(mu) + (1.0 - x) * torch.log(1.0 - mu)).sum(1)

    def sample(self):
        # rand_like already places the draw on mu's device, so no extra .to() is needed
        return (torch.rand_like(self.mu) < self.mu).to(torch.float)


class DiagonalGaussian():
    """Gaussian with diagonal covariance, parameterized by mean and log-variance."""

    def __init__(self, mu, logvar):
        # mu, logvar: [batch_size, z_dim]
        self.mu = mu
        self.logvar = logvar

    def log_probability(self, x):
        return -0.5 * torch.sum(
            np.log(2.0 * np.pi) + self.logvar
            + ((x - self.mu) ** 2) / torch.exp(self.logvar), dim=1)

    def sample(self):
        # Reparameterization trick: x = mu + sigma * eps, with eps ~ N(0, I)
        eps = torch.randn_like(self.mu)
        return self.mu + torch.exp(0.5 * self.logvar) * eps

    def repeat(self, n):
        # Tile each row n times: [batch_size, z_dim] -> [batch_size * n, z_dim]
        mu = self.mu.unsqueeze(1).repeat(1, n, 1).view(-1, self.mu.shape[-1])
        logvar = self.logvar.unsqueeze(1).repeat(1, n, 1).view(-1, self.logvar.shape[-1])
        return DiagonalGaussian(mu, logvar)

    @staticmethod
    def kl_div(p, q):
        # Closed-form KL(p || q) between two diagonal Gaussians
        return 0.5 * torch.sum(
            q.logvar - p.logvar - 1.0
            + (torch.exp(p.logvar) + (p.mu - q.mu) ** 2) / torch.exp(q.logvar), dim=1)


class Gaussian():
    """Full-covariance Gaussian parameterized by mean and precision matrix."""

    def __init__(self, mu, precision):
        # mu: [batch_size, z_dim]
        self.mu = mu
        # precision: [batch_size, z_dim, z_dim], assumed symmetric positive definite
        self.precision = precision
        # L is the Cholesky factor of the covariance: L L^T = precision^{-1}.
        # torch.linalg.cholesky replaces the deprecated torch.cholesky.
        # TODO: get rid of the inverse for efficiency (e.g. factor the precision
        # directly and use triangular solves).
        self.L = torch.linalg.cholesky(torch.inverse(precision))
        self.dim = self.mu.shape[1]

    def log_probability(self, x):
        # log N(x; mu, Sigma) = -0.5 * (d*log(2*pi) + log|Sigma| + (x-mu)^T P (x-mu)),
        # where log|Sigma| = 2 * sum(log(diag(L))) and P is the precision.
        indices = np.arange(self.L.shape[-1])
        return -0.5 * (self.dim * np.log(2.0 * np.pi)
                       + 2.0 * torch.log(self.L[:, indices, indices]).sum(1)
                       + torch.matmul(torch.matmul((x - self.mu).unsqueeze(1), self.precision),
                                      (x - self.mu).unsqueeze(-1)).sum([1, 2]))

    def sample(self):
        # x = mu + L eps with eps ~ N(0, I), so Cov(x) = L L^T = precision^{-1}
        eps = torch.randn_like(self.mu)
        return self.mu + torch.matmul(self.L, eps.unsqueeze(-1)).squeeze(-1)

    def repeat(self, n):
        # Tile each row n times, mirroring DiagonalGaussian.repeat
        mu = self.mu.unsqueeze(1).repeat(1, n, 1).view(-1, self.mu.shape[-1])
        precision = self.precision.unsqueeze(1).repeat(1, n, 1, 1).view(-1, *self.precision.shape[1:])
        return Gaussian(mu, precision)
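
# A minimal self-check sketch, not part of the original module: it assumes the
# classes above are used exactly as defined, with batched CPU tensors and
# arbitrary example shapes. Two consistency checks: (i) the closed-form
# DiagonalGaussian.kl_div should match a Monte Carlo estimate of
# E_p[log p(z) - log q(z)], and (ii) a full-covariance Gaussian built from a
# diagonal precision should reproduce the diagonal log-density.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, z_dim = 4, 3

    p = DiagonalGaussian(torch.randn(batch_size, z_dim),
                         torch.randn(batch_size, z_dim))
    q = DiagonalGaussian(torch.randn(batch_size, z_dim),
                         torch.randn(batch_size, z_dim))

    # (i) Monte Carlo KL vs. closed form: draw many samples from p via repeat(),
    # then average the per-sample log-density ratio within each batch row.
    n_samples = 50000
    p_rep = p.repeat(n_samples)   # [batch_size * n_samples, z_dim]
    q_rep = q.repeat(n_samples)
    z = p_rep.sample()
    mc_kl = (p_rep.log_probability(z)
             - q_rep.log_probability(z)).view(batch_size, n_samples).mean(1)
    print("closed-form KL:", DiagonalGaussian.kl_div(p, q))
    print("Monte Carlo KL:", mc_kl)

    # (ii) A diagonal precision (1 / variance on the diagonal) should make the
    # full-covariance Gaussian agree with DiagonalGaussian.log_probability.
    precision = torch.diag_embed(torch.exp(-p.logvar))
    full = Gaussian(p.mu, precision)
    x = p.sample()
    print("diagonal log-prob:", p.log_probability(x))
    print("full-cov log-prob:", full.log_probability(x))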