# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """The Hyperspherical Uniform distribution class.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import distribution from tensorflow.python.ops import nn_impl from tensorflow.python.framework.tensor_shape import TensorShape from tensorflow.python.ops import gen_math_ops class HypersphericalUniform(distribution.Distribution): """Hyperspherical Uniform distribution with `dim` parameter. #### Mathematical Details """ def __init__(self, dim, dtype=dtypes.float32, validate_args=False, allow_nan_stats=True, name="HypersphericalUniform"): """Initialize a batch of Hyperspherical Uniform distributions. Args: dim: Integer tensor, dimensionality of the distribution(s). Must be `dim > 0`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect outputs. allow_nan_stats: Python `bool`, default `True`. When `True`, statistics (e.g., mean, mode, variance) use the value "`NaN`" to indicate the result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. Raises: InvalidArgumentError: if `dim > 0` and `validate_args=False`. """ parameters = locals() with ops.name_scope(name, values=[dim]): with ops.control_dependencies([check_ops.assert_positive(dim), check_ops.assert_integer(dim), check_ops.assert_scalar(dim)] if validate_args else []): self._dim = dim super(HypersphericalUniform, self).__init__( dtype=dtype, reparameterization_type=distribution.FULLY_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=[], name=name) @staticmethod def _param_shapes(sample_shape): return {} @property def dim(self): """Dimensionality of the distribution(s).""" return self._dim def _batch_shape_tensor(self): return constant_op.constant([self._dim + 1], dtype=dtypes.int32) def _batch_shape(self): return TensorShape(self._dim + 1) def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) def _event_shape(self): return tensor_shape.scalar() def _sample_n(self, n, seed=None): return nn_impl.l2_normalize(random_ops.random_normal(shape=array_ops.concat(([n], [self._dim + 1]), 0), dtype=self.dtype, seed=seed), axis=-1) def _log_prob(self, x): return - array_ops.ones(shape=array_ops.shape(x)[:-1], dtype=self.dtype) * self.__log_surface_area() def _prob(self, x): return math_ops.exp(self._log_prob(x)) def _entropy(self): return self.__log_surface_area() def __log_surface_area(self): return math.log(2) + ((self._dim + 1) / 2) * math.log(math.pi) - gen_math_ops.lgamma( math_ops.cast((self._dim + 1) / 2, dtype=self.dtype))