"""Basic definitions for the distributions module.""" import torch from torch import nn import utils class NoMeanException(Exception): """Exception to be thrown when a mean function doesn't exist.""" pass class Distribution(nn.Module): """Base class for all distribution objects.""" def forward(self, *args): raise RuntimeError('Forward method cannot be called for a Distribution object.') def log_prob(self, inputs, context=None): """Calculate log probability under the distribution. Args: inputs: Tensor, input variables. context: Tensor or None, conditioning variables. If a Tensor, it must have the same number or rows as the inputs. If None, the context is ignored. Returns: A Tensor of shape [input_size], the log probability of the inputs given the context. """ inputs = torch.as_tensor(inputs) if context is not None: context = torch.as_tensor(context) if inputs.shape[0] != context.shape[0]: raise ValueError('Number of input items must be equal to number of context items.') return self._log_prob(inputs, context) def _log_prob(self, inputs, context): raise NotImplementedError() def sample(self, num_samples, context=None, batch_size=None): """Generates samples from the distribution. Samples can be generated in batches. Args: num_samples: int, number of samples to generate. context: Tensor or None, conditioning variables. If None, the context is ignored. batch_size: int or None, number of samples per batch. If None, all samples are generated in one batch. Returns: A Tensor containing the samples, with shape [num_samples, ...] if context is None, or [context_size, num_samples, ...] if context is given. """ if not utils.is_positive_int(num_samples): raise TypeError('Number of samples must be a positive integer.') if context is not None: context = torch.as_tensor(context) if batch_size is None: return self._sample(num_samples, context) else: if not utils.is_positive_int(batch_size): raise TypeError('Batch size must be a positive integer.') num_batches = num_samples // batch_size num_leftover = num_samples % batch_size samples = [self._sample(batch_size, context) for _ in range(num_batches)] if num_leftover > 0: samples.append(self._sample(num_leftover, context)) return torch.cat(samples, dim=0) def _sample(self, num_samples, context): raise NotImplementedError() def sample_and_log_prob(self, num_samples, context=None): """Generates samples from the distribution together with their log probability. Args: num_samples: int, number of samples to generate. context: Tensor or None, conditioning variables. If None, the context is ignored. Returns: A tuple of: * A Tensor containing the samples, with shape [num_samples, ...] if context is None, or [context_size, num_samples, ...] if context is given. * A Tensor containing the log probabilities of the samples, with shape [num_samples, ...] if context is None, or [context_size, num_samples, ...] if context is given. """ samples = self.sample(num_samples, context=context) if context is not None: # Merge the context dimension with sample dimension in order to call log_prob. samples = utils.merge_leading_dims(samples, num_dims=2) context = utils.repeat_rows(context, num_reps=num_samples) assert samples.shape[0] == context.shape[0] log_prob = self.log_prob(samples, context=context) if context is not None: # Split the context dimension from sample dimension. samples = utils.split_leading_dim(samples, shape=[-1, num_samples]) log_prob = utils.split_leading_dim(log_prob, shape=[-1, num_samples]) return samples, log_prob def mean(self, context=None): if context is not None: context = torch.as_tensor(context) return self._mean(context) def _mean(self, context): raise NoMeanException()