Python torch.distributions.Distribution() Examples

The following are 27 code examples of torch.distributions.Distribution(), collected from open-source projects. You may also want to check out all available functions and classes of the torch.distributions module.
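Before the examples, a brief orienting sketch (not from any of the projects below): Distribution is the abstract base class of torch.distributions, and concrete subclasses such as Normal implement sample(), log_prob() and expand().

import torch
from torch.distributions import Distribution, Normal

# Normal is a concrete Distribution subclass
d = Normal(loc=0.0, scale=1.0)
assert isinstance(d, Distribution)

x = d.sample((3,))                  # three draws, shape torch.Size([3])
lp = d.log_prob(x)                  # log-density evaluated at each draw
d2 = d.expand(torch.Size([2, 3]))   # broadcast to batch_shape (2, 3)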
Example #1
Source File: module.py    From gpytorch with MIT License
def _validate_module_outputs(outputs):
    if isinstance(outputs, tuple):
        if not all(
            torch.is_tensor(output) or isinstance(output, Distribution) or isinstance(output, LazyTensor)
            for output in outputs
        ):
            raise RuntimeError(
                "All outputs must be a Distribution, torch.Tensor, or LazyTensor. "
                "Got {}".format([output.__class__.__name__ for output in outputs])
            )
        if len(outputs) == 1:
            outputs = outputs[0]
        return outputs
    elif torch.is_tensor(outputs) or isinstance(outputs, Distribution) or isinstance(outputs, LazyTensor):
        return outputs
    else:
        raise RuntimeError(
            "Output must be a Distribution, torch.Tensor, or LazyTensor. Got {}".format(outputs.__class__.__name__)
        ) 
Example #2
Source File: base_likelihood_test_case.py    From gpytorch with MIT License
def _test_marginal(self, batch_shape):
        likelihood = self.create_likelihood()
        likelihood.max_plate_nesting += len(batch_shape)
        input = self._create_marginal_input(batch_shape)
        output = likelihood(input)

        self.assertTrue(isinstance(output, Distribution))
        self.assertEqual(output.sample().shape[-len(input.sample().shape) :], input.sample().shape)

        # Compare against default implementation
        with gpytorch.settings.num_likelihood_samples(30000):
            default = Likelihood.marginal(likelihood, input)
        # print(output.mean, default.mean)
        default_mean = default.mean
        actual_mean = output.mean
        if default_mean.dim() > actual_mean.dim():
            default_mean = default_mean.mean(0)
        self.assertAllClose(default_mean, actual_mean, rtol=0.25, atol=0.25) 
Example #3
Source File: linear.py    From pyfilter with MIT License
def __init__(self, hidden, a=1., scale=1.):
        """
        Implements a state-space model that is linear in the observation equation but has arbitrary dynamics in the
        state process.
        :param hidden: The hidden dynamics
        :param a: The A-matrix
        :param scale: The scale of the observations
        """

        # ===== Convoluted way to decide number of dimensions ===== #
        dim, is_1d = _get_shape(a)

        # ====== Define distributions ===== #
        n = dists.Normal(0., 1.) if is_1d else dists.Independent(dists.Normal(torch.zeros(dim), torch.ones(dim)), 1)

        if not isinstance(scale, (torch.Tensor, float, dists.Distribution)):
            raise ValueError('`scale` parameter must be of numeric type!')

        super().__init__(hidden, a, scale, n) 
Example #4
Source File: module.py    From pyfilter with MIT License
def tensors(self) -> Tuple[torch.Tensor, ...]:
        """
        Finds and returns all tensors of the module.
        """

        res = tuple()

        # ===== Find all tensor types ====== #
        res += tuple(self._find_obj_helper(torch.Tensor).values())

        # ===== Tensor containers ===== #
        for tc in self._find_obj_helper(TensorContainerBase).values():
            res += tc.tensors
            for t in (t_ for t_ in tc.tensors if isinstance(t_, Parameter) and t_.trainable):
                res += _iterate_distribution(t.distr)

        # ===== Pytorch distributions ===== #
        for d in self._find_obj_helper(Distribution).values():
            res += _iterate_distribution(d)

        # ===== Modules ===== #
        for mod in self.modules().values():
            res += mod.tensors()

        return res 
Example #5
Source File: module.py    From pyfilter with MIT License
def _iterate_distribution(d: Distribution) -> Tuple[Distribution, ...]:
    """
    Helper method for iterating over distributions.
    :param d: The distribution
    """

    res = tuple()
    if not isinstance(d, TransformedDistribution):
        res += tuple(_find_types(d, torch.Tensor).values())

        for sd in _find_types(d, Distribution).values():
            res += _iterate_distribution(sd)

    else:
        res += _iterate_distribution(d.base_dist)

        for t in d.transforms:
            res += tuple(_find_types(t, torch.Tensor).values())

    return res 
Example #6
Source File: utils.py    From pyfilter with MIT License
def _mcmc_move(params: Iterable[Parameter], dist: Distribution, stacked: StackedObject, shape: int):
    """
    Performs an MCMC move to rejuvenate parameters.
    :param params: The parameters to use for defining the distribution
    :param dist: The distribution to use for sampling
    :param stacked: The stacked object holding the parameter masks and original shapes
    :param shape: The number of samples to draw
    :return: True (the parameters are updated in place)
    """

    rvs = dist.sample((shape,))

    for p, msk, ps in zip(params, stacked.mask, stacked.prev_shape):
        p.t_values = unflattify(rvs[:, msk], ps)

    return True 
Example #7
Source File: affine.py    From pyfilter with MIT License
def _define_transdist(loc: torch.Tensor, scale: torch.Tensor, inc_dist: Distribution, ndim: int):
    loc, scale = torch.broadcast_tensors(loc, scale)

    shape = loc.shape[:-ndim] if ndim > 0 else loc.shape

    return TransformedDistribution(
        inc_dist.expand(shape), AffineTransform(loc, scale, event_dim=ndim)
    ) 
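To illustrate what _define_transdist produces, here is a minimal sketch (values hypothetical): expanding a standard-normal increment distribution to the broadcast batch shape and wrapping it in an AffineTransform yields the distribution of loc + scale * eps.

import torch
from torch.distributions import Normal, TransformedDistribution, AffineTransform

loc = torch.zeros(4)
scale = 0.1 * torch.ones(4)
inc_dist = Normal(0., 1.)

# ndim == 0: scalar events, so the batch shape is taken from loc/scale
transdist = TransformedDistribution(
    inc_dist.expand(loc.shape), AffineTransform(loc, scale, event_dim=0)
)
sample = transdist.sample()  # distributed as loc + scale * eps, eps ~ N(0, 1)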
Example #8
Source File: continuous.py    From rising with MIT License
def __init__(self, distribution: TorchDistribution):
        """
        Args:
            distribution : the distribution to sample from
        """
        super().__init__()
        self.dist = distribution 
Example #9
Source File: base_likelihood_test_case.py    From gpytorch with MIT License
def _test_conditional(self, batch_shape):
        likelihood = self.create_likelihood()
        likelihood.max_plate_nesting += len(batch_shape)
        input = self._create_conditional_input(batch_shape)
        output = likelihood(input)

        self.assertTrue(isinstance(output, Distribution))
        self.assertEqual(output.sample().shape, input.shape) 
Example #10
Source File: test_softmax_likelihood.py    From gpytorch with MIT License
def _test_marginal(self, batch_shape):
        likelihood = self.create_likelihood()
        input = self._create_marginal_input(batch_shape)
        output = likelihood(input)

        self.assertTrue(isinstance(output, Distribution))
        self.assertEqual(output.sample().shape[-len(batch_shape) - 1 :], torch.Size([*batch_shape, 5])) 
Example #11
Source File: test_softmax_likelihood.py    From gpytorch with MIT License
def _test_conditional(self, batch_shape):
        likelihood = self.create_likelihood()
        input = self._create_conditional_input(batch_shape)
        output = likelihood(input)

        self.assertIsInstance(output, Distribution)
        self.assertEqual(output.sample().shape, torch.Size([*batch_shape, 5])) 
Example #12
Source File: action_sampler.py    From guacamol_baselines with MIT License
def __init__(self, max_batch_size, max_seq_length, device,
                 distribution_cls: Type[Distribution] = None) -> None:
        """
        Args:
            max_batch_size: maximal batch size for the RNN model
            max_seq_length: max length for a sampled SMILES string
            device: cuda | cpu
            distribution_cls: distribution type to sample from. If None, will be a multinomial distribution. Useful for testing purposes.
        """
        self.max_batch_size = max_batch_size
        self.max_seq_length = max_seq_length
        self.device = device

        self.distribution_cls = Categorical if distribution_cls is None else distribution_cls 
Example #13
Source File: action_replay.py    From guacamol_baselines with MIT License
def __init__(self, max_batch_size, device, distribution_cls: Type[Distribution] = None) -> None:
        """
        Args:
            max_batch_size: Max. batch size
            device: cuda | cpu
            distribution_cls: distribution type to sample from. If None, will be a multinomial distribution.
        """
        self.max_batch_size = max_batch_size
        self.device = device

        self.distribution_cls = Categorical if distribution_cls is None else distribution_cls 
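Both samplers above default to Categorical, so the sampling step they rely on boils down to something like the following sketch (tensor shapes hypothetical):

import torch
from torch.distributions import Categorical

logits = torch.randn(2, 5)      # batch of 2 states, 5 possible actions each
dist = Categorical(logits=logits)
actions = dist.sample()         # shape (2,): one action index per batch row
log_p = dist.log_prob(actions)  # log-probability of each sampled action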
Example #14
Source File: diffusion.py    From pyfilter with MIT License
def __init__(self, dynamics: Tuple[Callable[[torch.Tensor, Tuple[object, ...]], torch.Tensor], ...], theta,
                 initial_dist, increment_dist: Distribution, dt, **kwargs):
        """
        Euler-Maruyama method for SDEs of affine nature. A generalization of OneStepMaruyama that allows multiple
        recursions. The difference between this class and GeneralEulerMaruyama is that you need not specify prop_state,
        as it is assumed to follow the structure of OneStepEulerMaruyama.
        :param dynamics: A tuple of callables; should _not_ include `dt` as the last argument
        """

        super().__init__(theta, initial_dist, increment_dist=increment_dist, dt=dt, prop_state=self._prop, **kwargs)
        self.f, self.g = dynamics 
Example #15
Source File: affine.py    From pyfilter with MIT License
def __init__(self, std: Union[torch.Tensor, float, Distribution]):
        """
        Defines a random walk.
        :param std: The vector of standard deviations
        :type std: torch.Tensor|float|Distribution
        """

        if not isinstance(std, torch.Tensor):
            normal = Normal(0., 1.)
        else:
            normal = Normal(0., 1.) if std.shape[-1] < 2 else Independent(Normal(torch.zeros_like(std), std), 1)

        super().__init__((_f, _g), (std,), normal, normal) 
Example #16
Source File: base.py    From pyfilter with MIT License
def propagate(self, x: torch.Tensor, as_dist=False) -> Union[Distribution, torch.Tensor]:
        """
        Propagates the model forward conditional on the previous state and current parameters.
        :param x: The previous state
        :param as_dist: Whether to return the new value as a distribution
        :return: Samples from the model, or the predictive distribution if `as_dist` is True
        """

        return self._propagate(x, as_dist) 
Example #17
Source File: parameter.py    From pyfilter with MIT License
def trainable(self):
        """
        Boolean indicating whether the parameter is trainable, i.e. whether its prior is a `Distribution`.
        """

        return isinstance(self._prior, Distribution) 
Example #18
Source File: parameter.py    From pyfilter with MIT License
def sample_(self, shape: Union[int, Tuple[int, ...], torch.Size] = None):
        """
        Samples the parameter in place from its prior distribution.
        :param shape: The shape to use
        """
        if not self.trainable:
            raise ValueError('Cannot initialize parameter as its prior is not a `Distribution` instance!')

        self.data = self._prior.sample(size_getter(shape))

        return self 
Example #19
Source File: parameter.py    From pyfilter with MIT License
def bijection(self) -> Transform:
        """
        Returns the bijection mapping from the unconstrained to the constrained space of the prior.
        """
        if not self.trainable:
            raise ValueError('Parameter is not an instance of `Distribution`!')

        return biject_to(self._prior.support) 
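For context, a sketch of what biject_to yields (prior hypothetical): for a prior with positive support it returns an ExpTransform, whose .inv maps the constrained value back to the real line.

import torch
from torch.distributions import Gamma, biject_to

prior = Gamma(2.0, 2.0)          # support is the positive reals
bij = biject_to(prior.support)   # transform: unconstrained -> constrained

unconstrained = torch.tensor(-1.5)
constrained = bij(unconstrained)   # exp(-1.5), strictly positive
recovered = bij.inv(constrained)   # back to -1.5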
Example #20
Source File: parameter.py    From pyfilter with MIT License
def transformed_dist(self):
        """
        Returns the prior distribution transformed to the unconstrained space.
        """

        if not self.trainable:
            raise ValueError('Parameter is not an instance of `Distribution`!')

        return TransformedDistribution(self._prior, [self.bijection.inv]) 
Example #21
Source File: parameter.py    From pyfilter with MIT License
def __init__(self, parameter: Union[torch.Tensor, Distribution] = None, requires_grad=False):
        """
        The parameter class.
        """
        self._prior = parameter if isinstance(parameter, Distribution) else None 
Example #22
Source File: parameter.py    From pyfilter with MIT License
def __new__(cls, parameter: Union[torch.Tensor, Distribution] = None, requires_grad=False):
        if isinstance(parameter, Parameter):
            raise ValueError('The input cannot be of instance `{}`!'.format(parameter.__class__.__name__))
        elif isinstance(parameter, torch.Tensor):
            _data = parameter
        elif not isinstance(parameter, Distribution):
            _data = torch.tensor(parameter) if not isinstance(parameter, np.ndarray) else torch.from_numpy(parameter)
        else:
            # This is just a placeholder
            _data = torch.empty(parameter.event_shape)

        return torch.Tensor._make_subclass(cls, _data, requires_grad) 
Example #23
Source File: linear.py    From pyfilter with MIT License
def _get_shape(a):
    is_1d = False
    if isinstance(a, dists.Distribution):
        dim = a.event_shape
        is_1d = len(a.event_shape) == 1
    elif isinstance(a, float) or a.dim() < 2:
        dim = torch.Size([])
        is_1d = (torch.tensor(a) if isinstance(a, float) else a).dim() <= 1
    else:
        dim = a.shape[:1]

    return dim, is_1d 
Example #24
Source File: base.py    From pyfilter with MIT License
def dist(self) -> Distribution:
        """
        Returns the distribution.
        """

        raise NotImplementedError() 
Example #25
Source File: mh.py    From pyfilter with MIT License
def define_pdf(self, values: torch.Tensor, weights: torch.Tensor) -> Distribution:
        """
        The method to be overridden by the user for defining the kernel to propagate the parameters. Note that the
        parameters are propagated in the transformed space.
        :param values: The parameters as a single Tensor
        :param weights: The normalized weights of the particles
        :return: A distribution
        """

        raise NotImplementedError() 
Example #26
Source File: utils.py    From pyfilter with MIT License
def _eval_kernel(params: Iterable[Parameter], dist: Distribution, n_params: Iterable[Parameter]):
    """
    Evaluates the kernel used for performing the MCMC move.
    :param params: The current parameters
    :param dist: The kernel distribution used for evaluating the log-probabilities
    :param n_params: The new parameters to evaluate against
    :return: The difference in kernel log-probabilities of the current and new parameters
    """

    p_vals = stacker(params, lambda u: u.t_values)
    n_p_vals = stacker(n_params, lambda u: u.t_values)

    return dist.log_prob(p_vals.concated) - dist.log_prob(n_p_vals.concated) 
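The ratio computed above is the proposal term of a Metropolis-Hastings acceptance probability; stripped of the parameter stacking, it amounts to this sketch (kernel distribution hypothetical):

import torch
from torch.distributions import MultivariateNormal

kernel = MultivariateNormal(torch.zeros(2), torch.eye(2))
current = torch.randn(2)
proposed = torch.randn(2)

# log q(current) - log q(proposed), as returned by _eval_kernel
log_diff = kernel.log_prob(current) - kernel.log_prob(proposed)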
Example #27
Source File: stochastic.py    From probtorch with Apache License 2.0
def _autogen_trace_methods():
    import torch as _torch
    from torch import distributions as _distributions
    import inspect as _inspect
    import re as _re

    # monkey patch relaxed distributions
    def relaxed_bernoulli_log_pmf(self, value):
        return (value > self.probs).type('torch.FloatTensor')

    def relaxed_categorical_log_pmf(self, value):
        _, max_index = value.max(-1)
        return self.base_dist._categorical.log_prob(max_index)

    _distributions.RelaxedBernoulli.log_pmf = relaxed_bernoulli_log_pmf

    _distributions.RelaxedOneHotCategorical.log_pmf = relaxed_categorical_log_pmf

    def camel_to_snake(name):
        s1 = _re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
        return _re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()

    for name, obj in _inspect.getmembers(_distributions):
        if hasattr(obj, "__bases__") and issubclass(obj, _distributions.Distribution) and obj.has_rsample:
            f_name = camel_to_snake(name)
            doc = """Generates a random variable of type torch.distributions.%s""" % name
            try:
                # try python 3 first
                asp = _inspect.getfullargspec(obj.__init__)
            except Exception:
                # python 2
                asp = _inspect.getargspec(obj.__init__)

            arg_split = -len(asp.defaults) if asp.defaults else None
            args = ', '.join(asp.args[:arg_split])

            if arg_split:
                pairs = zip(asp.args[arg_split:], asp.defaults)
                kwargs = ', '.join(['%s=%s' % (n, v) for n, v in pairs])
                args = args + ', ' + kwargs

            env = {'obj': obj, 'torch': _torch}
            s = ("""def f({0}, name=None, value=None):
                    return self.variable(obj, {1}, name=name, value=value)""")
            input_args = ', '.join(asp.args[1:])
            exec(s.format(args, input_args), env)
            f = env['f']
            f.__name__ = f_name
            f.__doc__ = doc
            setattr(Trace, f_name, f)

    # add alias for relaxed_one_hot_categorical
    setattr(Trace, 'concrete', Trace.relaxed_one_hot_categorical)
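The net effect, assuming probtorch's Trace API as generated above: every Distribution subclass with has_rsample gets a snake_case factory method on Trace, so one can write for example:

import torch
import probtorch

q = probtorch.Trace()
# equivalent to q.variable(torch.distributions.Normal, loc, scale, name='z')
z = q.normal(torch.zeros(3), torch.ones(3), name='z')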