#!/usr/bin/env python # -*- coding: utf-8 -*- from __future__ import absolute_import from __future__ import division import numpy as np import tensorflow as tf from zhusuan.distributions.base import Distribution from zhusuan.distributions.utils import ( maybe_explicit_broadcast, assert_same_float_dtype, assert_dtype_is_int_or_float, assert_rank_at_least, assert_rank_at_least_one, assert_scalar, assert_positive_int32_scalar, get_shape_at, open_interval_standard_uniform, log_combination ) __all__ = [ 'MultivariateNormalCholesky', 'Multinomial', 'UnnormalizedMultinomial', 'BagofCategoricals', 'OnehotCategorical', 'OnehotDiscrete', 'Dirichlet', 'ExpConcrete', 'ExpGumbelSoftmax', 'Concrete', 'GumbelSoftmax', 'MatrixVariateNormalCholesky', ] class MultivariateNormalCholesky(Distribution): """ The class of multivariate normal distribution, where covariance is parameterized with the lower triangular matrix :math:`L` in Cholesky decomposition :math:`LL^T = \Sigma`. See :class:`~zhusuan.distributions.base.Distribution` for details. :param mean: An N-D `float` Tensor of shape `[..., n_dim]`. Each slice `[i, j, ..., k, :]` represents the mean of a single multivariate normal distribution. :param cov_tril: An (N+1)-D `float` Tensor of shape `[..., n_dim, n_dim]`. Each slice `[i, ..., k, :, :]` represents the lower triangular matrix in the Cholesky decomposition of the covariance of a single distribution. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. :param is_reparameterized: A Bool. If True, gradients on samples from this distribution are allowed to propagate into inputs, using the reparametrization trick from (Kingma, 2013). :param use_path_derivative: A bool. Whether when taking the gradients of the log-probability to propagate them through the parameters of the distribution (False meaning you do propagate them). This is based on the paper "Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference" :param check_numerics: Bool. Whether to check numeric issues. """ def __init__(self, mean, cov_tril, group_ndims=0, is_reparameterized=True, use_path_derivative=False, check_numerics=False, **kwargs): self._check_numerics = check_numerics self._mean = tf.convert_to_tensor(mean) self._mean = assert_rank_at_least_one( self._mean, 'MultivariateNormalCholesky.mean') self._n_dim = get_shape_at(self._mean, -1) self._cov_tril = tf.convert_to_tensor(cov_tril) self._cov_tril = assert_rank_at_least( self._cov_tril, 2, 'MultivariateNormalCholesky.cov_tril') # Static shape check expected_shape = self._mean.get_shape().concatenate( [self._n_dim if isinstance(self._n_dim, int) else None]) self._cov_tril.get_shape().assert_is_compatible_with(expected_shape) # Dynamic expected_shape = tf.concat( [tf.shape(self._mean), [self._n_dim]], axis=0) actual_shape = tf.shape(self._cov_tril) msg = ['MultivariateNormalCholesky.cov_tril should have compatible ' 'shape with mean. Expected', expected_shape, ' got ', actual_shape] assert_ops = [tf.assert_equal(expected_shape, actual_shape, msg)] with tf.control_dependencies(assert_ops): self._cov_tril = tf.identity(self._cov_tril) dtype = assert_same_float_dtype( [(self._mean, 'MultivariateNormalCholesky.mean'), (self._cov_tril, 'MultivariateNormalCholesky.cov_tril')]) super(MultivariateNormalCholesky, self).__init__( dtype=dtype, param_dtype=dtype, is_continuous=True, is_reparameterized=is_reparameterized, use_path_derivative=use_path_derivative, group_ndims=group_ndims, **kwargs) @property def mean(self): """The mean of the normal distribution.""" return self._mean @property def cov_tril(self): """ The lower triangular matrix in the cholosky decomposition of the covariance. """ return self._cov_tril def _value_shape(self): return tf.convert_to_tensor([self._n_dim], tf.int32) def _get_value_shape(self): if isinstance(self._n_dim, int): return tf.TensorShape([self._n_dim]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.mean)[:-1] def _get_batch_shape(self): if self.mean.get_shape(): return self.mean.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): mean, cov_tril = self.mean, self.cov_tril if not self.is_reparameterized: mean = tf.stop_gradient(mean) cov_tril = tf.stop_gradient(cov_tril) def tile(t): new_shape = tf.concat([[n_samples], tf.ones_like(tf.shape(t))], 0) return tf.tile(tf.expand_dims(t, 0), new_shape) batch_mean = tile(mean) batch_cov = tile(cov_tril) # n_dim -> n_dim x 1 for matmul batch_mean = tf.expand_dims(batch_mean, -1) noise = tf.random_normal(tf.shape(batch_mean), dtype=self.dtype) samples = tf.matmul(batch_cov, noise) + batch_mean samples = tf.squeeze(samples, -1) # Update static shape static_n_samples = n_samples if isinstance(n_samples, int) else None samples.set_shape(tf.TensorShape([static_n_samples]) .concatenate(self.get_batch_shape()) .concatenate(self.get_value_shape())) return samples def _log_prob(self, given): mean, cov_tril = (self.path_param(self.mean), self.path_param(self.cov_tril)) log_det = 2 * tf.reduce_sum( tf.log(tf.matrix_diag_part(cov_tril)), axis=-1) n_dim = tf.cast(self._n_dim, self.dtype) log_z = - n_dim / 2 * tf.log( 2 * tf.constant(np.pi, dtype=self.dtype)) - log_det / 2 # log_z.shape == batch_shape if self._check_numerics: log_z = tf.check_numerics(log_z, "log[det(Cov)]") # (given-mean)' Sigma^{-1} (given-mean) = # (g-m)' L^{-T} L^{-1} (g-m) = |x|^2, where Lx = g-m =: y. y = tf.expand_dims(given - mean, -1) L, _ = maybe_explicit_broadcast( cov_tril, y, 'MultivariateNormalCholesky.cov_tril', 'expand_dims(given, -1)') x = tf.matrix_triangular_solve(L, y, lower=True) x = tf.squeeze(x, -1) stoc_dist = -0.5 * tf.reduce_sum(tf.square(x), axis=-1) return log_z + stoc_dist def _prob(self, given): return tf.exp(self._log_prob(given)) class Multinomial(Distribution): """ The class of Multinomial distribution. See :class:`~zhusuan.distributions.base.Distribution` for details. :param logits: A N-D (N >= 1) `float` Tensor of shape `[..., n_categories]`. Each slice `[i, j, ..., k, :]` represents the log probabilities for all categories. By default (when `normalize_logits=True`), the probabilities could be un-normalized. .. math:: \\mathrm{logits} \\propto \\log p :param n_experiments: A 0-D `int32` Tensor or `None`. When it is a 0-D `int32` integer, it represents the number of experiments for each sample, which should be invariant among samples. In this situation `_sample` function is supported. When it is `None`, `_sample` function is not supported, and when calculating probabilities the number of experiments will be inferred from `given`, so it could vary among samples. :param normalize_logits: A bool indicating whether `logits` should be normalized when computing probability. If you believe `logits` is already normalized, set it to `False` to speed up. Default is True. :param dtype: The value type of samples from the distribution. Can be int (`tf.int16`, `tf.int32`, `tf.int64`) or float (`tf.float16`, `tf.float32`, `tf.float64`). Default is `int32`. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. A single sample is a N-D Tensor with the same shape as logits. Each slice `[i, j, ..., k, :]` is a vector of counts for all categories. """ def __init__(self, logits, n_experiments, normalize_logits=True, dtype=tf.int32, group_ndims=0, **kwargs): self._logits = tf.convert_to_tensor(logits) param_dtype = assert_same_float_dtype( [(self._logits, 'Multinomial.logits')]) assert_dtype_is_int_or_float(dtype) self._logits = assert_rank_at_least_one( self._logits, 'Multinomial.logits') self._n_categories = get_shape_at(self._logits, -1) if n_experiments is None: self._n_experiments = None else: self._n_experiments = assert_positive_int32_scalar( n_experiments, 'Multinomial.n_experiments') self.normalize_logits = normalize_logits super(Multinomial, self).__init__( dtype=dtype, param_dtype=param_dtype, is_continuous=False, is_reparameterized=False, group_ndims=group_ndims, **kwargs) @property def logits(self): """The un-normalized log probabilities.""" return self._logits @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories @property def n_experiments(self): """The number of experiments for each sample.""" return self._n_experiments def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.logits)[:-1] def _get_batch_shape(self): if self.logits.get_shape(): return self.logits.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): if self.n_experiments is None: raise ValueError('Cannot sample when `n_experiments` is None') if self.logits.get_shape().ndims == 2: logits_flat = self.logits else: logits_flat = tf.reshape(self.logits, [-1, self.n_categories]) samples_flat = tf.transpose( tf.random.categorical(logits_flat, n_samples * self.n_experiments)) shape = tf.concat([[n_samples, self.n_experiments], self.batch_shape], 0) samples = tf.reshape(samples_flat, shape) static_n_samples = n_samples if isinstance(n_samples, int) else None static_n_exps = self.n_experiments \ if isinstance(self.n_experiments, int) else None samples.set_shape( tf.TensorShape([static_n_samples, static_n_exps]). concatenate(self.get_batch_shape())) samples = tf.reduce_sum( tf.one_hot(samples, self.n_categories, dtype=self.dtype), axis=1) return samples def _log_prob(self, given): given = tf.cast(given, self.param_dtype) given, logits = maybe_explicit_broadcast( given, self.logits, 'given', 'logits') if self.normalize_logits: logits = logits - tf.reduce_logsumexp( logits, axis=-1, keepdims=True) if self.n_experiments is None: n = tf.reduce_sum(given, -1) else: n = tf.cast(self.n_experiments, self.param_dtype) log_p = log_combination(n, given) + \ tf.reduce_sum(given * logits, -1) return log_p def _prob(self, given): return tf.exp(self._log_prob(given)) class UnnormalizedMultinomial(Distribution): """ The class of UnnormalizedMultinomial distribution. UnnormalizedMultinomial distribution calculates probabilities differently from :class:`Multinomial`: It considers the bag-of-words `given` as a statistics of an ordered result sequence, and calculates the probability of the (imagined) ordered sequence. Hence it does not multiply the term .. math:: \\binom{n}{k_1, k_2, \\dots} = \\frac{n!}{\\prod_{i} k_i!} See :class:`~zhusuan.distributions.base.Distribution` for details. :param logits: A N-D (N >= 1) `float` Tensor of shape `[..., n_categories]`. Each slice `[i, j, ..., k, :]` represents the log probabilities for all categories. By default (when `normalize_logits=True`), the probabilities could be un-normalized. .. math:: \\mathrm{logits} \\propto \\log p :param normalize_logits: A bool indicating whether `logits` should be normalized when computing probability. If you believe `logits` is already normalized, set it to `False` to speed up. Default is True. :param dtype: The value type of samples from the distribution. Can be int (`tf.int16`, `tf.int32`, `tf.int64`) or float (`tf.float16`, `tf.float32`, `tf.float64`). Default is `int32`. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. A single sample is a N-D Tensor with the same shape as logits. Each slice `[i, j, ..., k, :]` is a vector of counts for all categories. """ def __init__(self, logits, normalize_logits=True, dtype=tf.int32, group_ndims=0, **kwargs): self._logits = tf.convert_to_tensor(logits) param_dtype = assert_same_float_dtype( [(self._logits, 'UnnormalizedMultinomial.logits')]) assert_dtype_is_int_or_float(dtype) self._logits = assert_rank_at_least_one( self._logits, 'UnnormalizedMultinomial.logits') self._n_categories = get_shape_at(self._logits, -1) self.normalize_logits = normalize_logits super(UnnormalizedMultinomial, self).__init__( dtype=dtype, param_dtype=param_dtype, is_continuous=False, is_reparameterized=False, group_ndims=group_ndims, **kwargs) @property def logits(self): """The un-normalized log probabilities.""" return self._logits @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.logits)[:-1] def _get_batch_shape(self): if self.logits.get_shape(): return self.logits.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): raise NotImplementedError("Unnormalized multinomial distribution" " does not support sampling because" " n_experiments is not given. Please use" " class Multinomial to sample") def _log_prob(self, given): given = tf.cast(given, self.param_dtype) given, logits = maybe_explicit_broadcast( given, self.logits, 'given', 'logits') if self.normalize_logits: logits = logits - tf.reduce_logsumexp( logits, axis=-1, keepdims=True) log_p = tf.reduce_sum(given * logits, -1) return log_p def _prob(self, given): return tf.exp(self._log_prob(given)) BagofCategoricals = UnnormalizedMultinomial class OnehotCategorical(Distribution): """ The class of one-hot Categorical distribution. See :class:`~zhusuan.distributions.base.Distribution` for details. :param logits: A N-D (N >= 1) `float` Tensor of shape (..., n_categories). Each slice `[i, j, ..., k, :]` represents the un-normalized log probabilities for all categories. .. math:: \\mathrm{logits} \\propto \\log p :param dtype: The value type of samples from the distribution. Can be int (`tf.int16`, `tf.int32`, `tf.int64`) or float (`tf.float16`, `tf.float32`, `tf.float64`). Default is `int32`. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. A single sample is a N-D Tensor with the same shape as logits. Each slice `[i, j, ..., k, :]` is a one-hot vector of the selected category. """ def __init__(self, logits, dtype=tf.int32, group_ndims=0, **kwargs): self._logits = tf.convert_to_tensor(logits) param_dtype = assert_same_float_dtype( [(self._logits, 'OnehotCategorical.logits')]) assert_dtype_is_int_or_float(dtype) self._logits = assert_rank_at_least_one( self._logits, 'OnehotCategorical.logits') self._n_categories = get_shape_at(self._logits, -1) super(OnehotCategorical, self).__init__( dtype=dtype, param_dtype=param_dtype, is_continuous=False, is_reparameterized=False, group_ndims=group_ndims, **kwargs) @property def logits(self): """The un-normalized log probabilities.""" return self._logits @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.logits)[:-1] def _get_batch_shape(self): if self.logits.get_shape(): return self.logits.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): if self.logits.get_shape().ndims == 2: logits_flat = self.logits else: logits_flat = tf.reshape(self.logits, [-1, self.n_categories]) samples_flat = tf.transpose( tf.random.categorical(logits_flat, n_samples)) if self.logits.get_shape().ndims == 2: samples = samples_flat else: shape = tf.concat([[n_samples], self.batch_shape], 0) samples = tf.reshape(samples_flat, shape) static_n_samples = n_samples if isinstance(n_samples, int) else None samples.set_shape( tf.TensorShape([static_n_samples]). concatenate(self.get_batch_shape())) samples = tf.one_hot(samples, self.n_categories, dtype=self.dtype) return samples def _log_prob(self, given): given = tf.cast(given, self.param_dtype) given, logits = maybe_explicit_broadcast( given, self.logits, 'given', 'logits') if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2): given_flat = given logits_flat = logits else: given_flat = tf.reshape(given, [-1, self.n_categories]) logits_flat = tf.reshape(logits, [-1, self.n_categories]) log_p_flat = -tf.nn.softmax_cross_entropy_with_logits( labels=given_flat, logits=logits_flat) if (given.get_shape().ndims == 2) or (logits.get_shape().ndims == 2): log_p = log_p_flat else: log_p = tf.reshape(log_p_flat, tf.shape(logits)[:-1]) if given.get_shape() and logits.get_shape(): log_p.set_shape(tf.broadcast_static_shape( given.get_shape(), logits.get_shape())[:-1]) return log_p def _prob(self, given): return tf.exp(self._log_prob(given)) OnehotDiscrete = OnehotCategorical class Dirichlet(Distribution): """ The class of Dirichlet distribution. See :class:`~zhusuan.distributions.base.Distribution` for details. :param alpha: A N-D (N >= 1) `float` Tensor of shape (..., n_categories). Each slice `[i, j, ..., k, :]` represents the concentration parameter of a Dirichlet distribution. Should be positive. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. A single sample is a N-D Tensor with the same shape as alpha. Each slice `[i, j, ..., k, :]` of the sample is a vector of probabilities of a Categorical distribution `[x_1, x_2, ... ]`, which lies on the simplex .. math:: \\sum_{i} x_i = 1, 0 < x_i < 1 """ def __init__(self, alpha, group_ndims=0, check_numerics=False, **kwargs): self._alpha = tf.convert_to_tensor(alpha) dtype = assert_same_float_dtype( [(self._alpha, 'Dirichlet.alpha')]) static_alpha_shape = self._alpha.get_shape() shape_err_msg = "alpha should have rank >= 1." cat_err_msg = "n_categories (length of the last axis " \ "of alpha) should be at least 2." if static_alpha_shape and (static_alpha_shape.ndims < 1): raise ValueError(shape_err_msg) elif static_alpha_shape and ( static_alpha_shape[-1].value is not None): self._n_categories = static_alpha_shape[-1].value if self._n_categories < 2: raise ValueError(cat_err_msg) else: _assert_shape_op = tf.assert_rank_at_least( self._alpha, 1, message=shape_err_msg) with tf.control_dependencies([_assert_shape_op]): self._alpha = tf.identity(self._alpha) self._n_categories = tf.shape(self._alpha)[-1] _assert_cat_op = tf.assert_greater_equal( self._n_categories, 2, message=cat_err_msg) with tf.control_dependencies([_assert_cat_op]): self._alpha = tf.identity(self._alpha) self._check_numerics = check_numerics super(Dirichlet, self).__init__( dtype=dtype, param_dtype=dtype, is_continuous=True, is_reparameterized=False, group_ndims=group_ndims, **kwargs) @property def alpha(self): """The concentration parameter of the Dirichlet distribution.""" return self._alpha @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.alpha)[:-1] def _get_batch_shape(self): if self.alpha.get_shape(): return self.alpha.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): samples = tf.random_gamma([n_samples], self.alpha, beta=1, dtype=self.dtype) return samples / tf.reduce_sum(samples, -1, keepdims=True) def _log_prob(self, given): given, alpha = maybe_explicit_broadcast( given, self.alpha, 'given', 'alpha') lbeta_alpha = tf.lbeta(alpha) # fix of no static shape inference for tf.lbeta if alpha.get_shape(): lbeta_alpha.set_shape(alpha.get_shape()[:-1]) log_given = tf.log(given) if self._check_numerics: lbeta_alpha = tf.check_numerics(lbeta_alpha, "lbeta(alpha)") log_given = tf.check_numerics(log_given, "log(given)") log_p = -lbeta_alpha + tf.reduce_sum((alpha - 1) * log_given, -1) return log_p def _prob(self, given): return tf.exp(self._log_prob(given)) class ExpConcrete(Distribution): """ The class of ExpConcrete distribution from (Maddison, 2016), transformed from :class:`~Concrete` by taking logarithm. See :class:`~zhusuan.distributions.base.Distribution` for details. .. seealso:: :class:`~zhusuan.distributions.univariate.BinConcrete` and :class:`~Concrete` :param temperature: A 0-D `float` Tensor. The temperature of the relaxed distribution. The temperature should be positive. :param logits: A N-D (N >= 1) `float` Tensor of shape (..., n_categories). Each slice `[i, j, ..., k, :]` represents the un-normalized log probabilities for all categories. .. math:: \\mathrm{logits} \\propto \\log p :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. :param is_reparameterized: A Bool. If True, gradients on samples from this distribution are allowed to propagate into inputs, using the reparametrization trick from (Kingma, 2013). :param use_path_derivative: A bool. Whether when taking the gradients of the log-probability to propagate them through the parameters of the distribution (False meaning you do propagate them). This is based on the paper "Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference" :param check_numerics: Bool. Whether to check numeric issues. """ def __init__(self, temperature, logits, group_ndims=0, is_reparameterized=True, use_path_derivative=False, check_numerics=False, **kwargs): self._logits = tf.convert_to_tensor(logits) self._temperature = tf.convert_to_tensor(temperature) param_dtype = assert_same_float_dtype( [(self._logits, 'ExpConcrete.logits'), (self._temperature, 'ExpConcrete.temperature')]) self._logits = assert_rank_at_least_one( self._logits, 'ExpConcrete.logits') self._n_categories = get_shape_at(self._logits, -1) self._temperature = assert_scalar( self._temperature, 'ExpConcrete.temperature') self._check_numerics = check_numerics super(ExpConcrete, self).__init__( dtype=param_dtype, param_dtype=param_dtype, is_continuous=True, is_reparameterized=is_reparameterized, use_path_derivative=use_path_derivative, group_ndims=group_ndims, **kwargs) @property def temperature(self): """The temperature of ExpConcrete.""" return self._temperature @property def logits(self): """The un-normalized log probabilities.""" return self._logits @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.logits)[:-1] def _get_batch_shape(self): if self.logits.get_shape(): return self.logits.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): logits, temperature = self.logits, self.temperature if not self.is_reparameterized: logits = tf.stop_gradient(logits) temperature = tf.stop_gradient(temperature) shape = tf.concat([[n_samples], tf.shape(self.logits)], 0) uniform = open_interval_standard_uniform(shape, self.dtype) gumbel = -tf.log(-tf.log(uniform)) samples = tf.nn.log_softmax((logits + gumbel) / temperature) static_n_samples = n_samples if isinstance(n_samples, int) else None samples.set_shape( tf.TensorShape([static_n_samples]).concatenate(logits.get_shape())) return samples def _log_prob(self, given): logits, temperature = self.path_param(self.logits),\ self.path_param(self.temperature) n = tf.cast(self.n_categories, self.dtype) log_temperature = tf.log(temperature) if self._check_numerics: log_temperature = tf.check_numerics( log_temperature, "log(temperature)") temp = logits - temperature * given return tf.lgamma(n) + (n - 1) * log_temperature + \ tf.reduce_sum(temp, axis=-1) - \ n * tf.reduce_logsumexp(temp, axis=-1) def _prob(self, given): return tf.exp(self._log_prob(given)) ExpGumbelSoftmax = ExpConcrete class Concrete(Distribution): """ The class of Concrete (or Gumbel-Softmax) distribution from (Maddison, 2016; Jang, 2016), served as the continuous relaxation of the :class:`~OnehotCategorical`. See :class:`~zhusuan.distributions.base.Distribution` for details. .. seealso:: :class:`~zhusuan.distributions.univariate.BinConcrete` and :class:`~ExpConcrete` :param temperature: A 0-D `float` Tensor. The temperature of the relaxed distribution. The temperature should be positive. :param logits: A N-D (N >= 1) `float` Tensor of shape (..., n_categories). Each slice `[i, j, ..., k, :]` represents the un-normalized log probabilities for all categories. .. math:: \\mathrm{logits} \\propto \\log p :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. :param is_reparameterized: A Bool. If True, gradients on samples from this distribution are allowed to propagate into inputs, using the reparametrization trick from (Kingma, 2013). :param use_path_derivative: A bool. Whether when taking the gradients of the log-probability to propagate them through the parameters of the distribution (False meaning you do propagate them). This is based on the paper "Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference" :param check_numerics: Bool. Whether to check numeric issues. """ def __init__(self, temperature, logits, group_ndims=0, is_reparameterized=True, use_path_derivative=False, check_numerics=False, **kwargs): self._logits = tf.convert_to_tensor(logits) self._temperature = tf.convert_to_tensor(temperature) param_dtype = assert_same_float_dtype( [(self._logits, 'Concrete.logits'), (self._temperature, 'Concrete.temperature')]) self._logits = assert_rank_at_least_one( self._logits, 'Concrete.logits') self._n_categories = get_shape_at(self._logits, -1) self._temperature = assert_scalar( self._temperature, 'Concrete.temperature') self._check_numerics = check_numerics super(Concrete, self).__init__( dtype=param_dtype, param_dtype=param_dtype, is_continuous=True, is_reparameterized=is_reparameterized, use_path_derivative=use_path_derivative, group_ndims=group_ndims, **kwargs) @property def temperature(self): """The temperature of Concrete.""" return self._temperature @property def logits(self): """The un-normalized log probabilities.""" return self._logits @property def n_categories(self): """The number of categories in the distribution.""" return self._n_categories def _value_shape(self): return tf.convert_to_tensor([self.n_categories], tf.int32) def _get_value_shape(self): if isinstance(self.n_categories, int): return tf.TensorShape([self.n_categories]) return tf.TensorShape([None]) def _batch_shape(self): return tf.shape(self.logits)[:-1] def _get_batch_shape(self): if self.logits.get_shape(): return self.logits.get_shape()[:-1] return tf.TensorShape(None) def _sample(self, n_samples): logits, temperature = self.logits, self.temperature if not self.is_reparameterized: logits = tf.stop_gradient(logits) temperature = tf.stop_gradient(temperature) shape = tf.concat([[n_samples], tf.shape(self.logits)], 0) uniform = open_interval_standard_uniform(shape, self.dtype) # TODO: Add Gumbel distribution gumbel = -tf.log(-tf.log(uniform)) samples = tf.nn.softmax((logits + gumbel) / temperature) static_n_samples = n_samples if isinstance(n_samples, int) else None samples.set_shape( tf.TensorShape([static_n_samples]).concatenate(logits.get_shape())) return samples def _log_prob(self, given): logits, temperature = self.path_param(self.logits), \ self.path_param(self.temperature) log_given = tf.log(given) log_temperature = tf.log(temperature) n = tf.cast(self.n_categories, self.dtype) if self._check_numerics: log_given = tf.check_numerics(log_given, "log(given)") log_temperature = tf.check_numerics( log_temperature, "log(temperature)") temp = logits - temperature * log_given return tf.lgamma(n) + (n - 1) * log_temperature + \ tf.reduce_sum(temp - log_given, axis=-1) - \ n * tf.reduce_logsumexp(temp, axis=-1) def _prob(self, given): return tf.exp(self._log_prob(given)) GumbelSoftmax = Concrete class MatrixVariateNormalCholesky(Distribution): """ The class of matrix variate normal distribution, where covariances :math:`U` and :math:`V` are parameterized with the lower triangular matrix in Cholesky decomposition, .. math :: L_u \\text{s.t.} L_u L_u^T = U,\\; L_v \\text{s.t.} L_v L_v^T = V See :class:`~zhusuan.distributions.base.Distribution` for details. :param mean: An N-D `float` Tensor of shape [..., n_row, n_col]. Each slice [i, j, ..., k, :, :] represents the mean of a single matrix variate normal distribution. :param u_tril: An N-D `float` Tensor of shape [..., n_row, n_row]. Each slice [i, j, ..., k, :, :] represents the lower triangular matrix in the Cholesky decomposition of the among-row covariance of a single matrix variate normal distribution. :param v_tril: An N-D `float` Tensor of shape [..., n_col, n_col]. Each slice [i, j, ..., k, :, :] represents the lower triangular matrix in the Cholesky decomposition of the among-column covariance of a single matrix variate normal distribution. :param group_ndims: A 0-D `int32` Tensor representing the number of dimensions in `batch_shape` (counted from the end) that are grouped into a single event, so that their probabilities are calculated together. Default is 0, which means a single value is an event. See :class:`~zhusuan.distributions.base.Distribution` for more detailed explanation. :param is_reparameterized: A Bool. If True, gradients on samples from this distribution are allowed to propagate into inputs, using the reparametrization trick from (Kingma, 2013). :param use_path_derivative: A bool. Whether when taking the gradients of the log-probability to propagate them through the parameters of the distribution (False meaning you do propagate them). This is based on the paper "Sticking the Landing: Simple, Lower-Variance Gradient Estimators for Variational Inference" :param check_numerics: Bool. Whether to check numeric issues. """ def __init__(self, mean, u_tril, v_tril, group_ndims=0, is_reparameterized=True, use_path_derivative=False, check_numerics=False, **kwargs): self._check_numerics = check_numerics self._mean = tf.convert_to_tensor(mean) self._mean = assert_rank_at_least( self._mean, 2, 'MatrixVariateNormalCholesky.mean') self._n_row = get_shape_at(self._mean, -2) self._n_col = get_shape_at(self._mean, -1) self._u_tril = tf.convert_to_tensor(u_tril) self._u_tril = assert_rank_at_least( self._u_tril, 2, 'MatrixVariateNormalCholesky.u_tril') self._v_tril = tf.convert_to_tensor(v_tril) self._v_tril = assert_rank_at_least( self._v_tril, 2, 'MatrixVariateNormalCholesky.v_tril') # Static shape check expected_u_shape = self._mean.get_shape()[:-1].concatenate( [self._n_row if isinstance(self._n_row, int) else None]) self._u_tril.get_shape().assert_is_compatible_with(expected_u_shape) expected_v_shape = self._mean.get_shape()[:-2].concatenate( [self._n_col if isinstance(self._n_col, int) else None] * 2) self._v_tril.get_shape().assert_is_compatible_with(expected_v_shape) # Dynamic expected_u_shape = tf.concat( [tf.shape(self._mean)[:-1], [self._n_row]], axis=0) actual_u_shape = tf.shape(self._u_tril) msg = ['MatrixVariateNormalCholesky.u_tril should have compatible ' 'shape with mean. Expected', expected_u_shape, ' got ', actual_u_shape] assert_u_ops = tf.assert_equal(expected_u_shape, actual_u_shape, msg) expected_v_shape = tf.concat( [tf.shape(self._mean)[:-2], [self._n_col, self._n_col]], axis=0) actual_v_shape = tf.shape(self._v_tril) msg = ['MatrixVariateNormalCholesky.v_tril should have compatible ' 'shape with mean. Expected', expected_v_shape, ' got ', actual_v_shape] assert_v_ops = tf.assert_equal(expected_v_shape, actual_v_shape, msg) with tf.control_dependencies([assert_u_ops, assert_v_ops]): self._u_tril = tf.identity(self._u_tril) self._v_tril = tf.identity(self._v_tril) dtype = assert_same_float_dtype( [(self._mean, 'MatrixVariateNormalCholesky.mean'), (self._u_tril, 'MatrixVariateNormalCholesky.u_tril'), (self._v_tril, 'MatrixVariateNormalCholesky.v_tril')]) super(MatrixVariateNormalCholesky, self).__init__( dtype=dtype, param_dtype=dtype, is_continuous=True, is_reparameterized=is_reparameterized, use_path_derivative=use_path_derivative, group_ndims=group_ndims, **kwargs) @property def mean(self): """The mean of the matrix variate normal distribution.""" return self._mean @property def u_tril(self): """ The lower triangular matrix in the Cholesky decomposition of the among-row covariance. """ return self._u_tril @property def v_tril(self): """ The lower triangular matrix in the Cholesky decomposition of the among-column covariance. """ return self._v_tril def _value_shape(self): return tf.convert_to_tensor([self._n_row, self._n_col], tf.int32) def _get_value_shape(self): shape_ = tf.TensorShape([ self._n_row if isinstance(self._n_row, int) else None, self._n_col if isinstance(self._n_col, int) else None]) return shape_ def _batch_shape(self): return tf.shape(self.mean)[:-2] def _get_batch_shape(self): if self.mean.get_shape(): return self.mean.get_shape()[:-2] return tf.TensorShape(None) def _sample(self, n_samples): mean, u_tril, v_tril = self.mean, self.u_tril, self.v_tril if not self.is_reparameterized: mean = tf.stop_gradient(mean) u_tril = tf.stop_gradient(u_tril) v_tril = tf.stop_gradient(v_tril) def tile(t): new_shape = tf.concat([[n_samples], tf.ones_like(tf.shape(t))], 0) return tf.tile(tf.expand_dims(t, 0), new_shape) batch_u_tril = tile(u_tril) batch_v_tril = tile(v_tril) noise = tf.random_normal( tf.concat([[n_samples], tf.shape(mean)], axis=0), dtype=self.dtype) samples = mean + \ tf.matmul(tf.matmul(batch_u_tril, noise), tf.matrix_transpose(batch_v_tril)) # Update static shape static_n_samples = n_samples if isinstance(n_samples, int) else None samples.set_shape(tf.TensorShape([static_n_samples]) .concatenate(self.get_batch_shape()) .concatenate(self.get_value_shape())) return samples def _log_prob(self, given): mean, u_tril, v_tril = (self.path_param(self.mean), self.path_param(self.u_tril), self.path_param(self.v_tril)) log_det_u = 2 * tf.reduce_sum( tf.log(tf.matrix_diag_part(u_tril)), axis=-1) log_det_v = 2 * tf.reduce_sum( tf.log(tf.matrix_diag_part(v_tril)), axis=-1) n_row = tf.cast(self._n_row, self.dtype) n_col = tf.cast(self._n_col, self.dtype) logZ = - (n_row * n_col) / 2. * \ tf.log(2. * tf.constant(np.pi, dtype=self.dtype)) - \ n_row / 2. * log_det_v - n_col / 2. * log_det_u # logZ.shape == batch_shape if self._check_numerics: logZ = tf.check_numerics(logZ, "log[det(Cov)]") y = given - mean y_with_last_dim_changed = tf.expand_dims(tf.ones(tf.shape(y)[:-1]), -1) Lu, _ = maybe_explicit_broadcast( u_tril, y_with_last_dim_changed, 'MatrixVariateNormalCholesky.u_tril', 'expand_dims(given, -1)') y_with_sec_last_dim_changed = tf.expand_dims(tf.ones( tf.concat([tf.shape(y)[:-2], tf.shape(y)[-1:]], axis=0)), -1) Lv, _ = maybe_explicit_broadcast( v_tril, y_with_sec_last_dim_changed, 'MatrixVariateNormalCholesky.v_tril', 'expand_dims(given, -1)') x_Lb_inv_t = tf.matrix_triangular_solve(Lu, y, lower=True) x_t = tf.matrix_triangular_solve(Lv, tf.matrix_transpose(x_Lb_inv_t), lower=True) stoc_dist = -0.5 * tf.reduce_sum(tf.square(x_t), [-1, -2]) return logZ + stoc_dist def _prob(self, given): return tf.exp(self._log_prob(given))