import numpy as np import pymc3 as pm import theano from arviz import from_pymc3 from bambi.priors import Prior from .base import BackEnd class PyMC3BackEnd(BackEnd): """PyMC3 model-fitting back-end.""" # Available link functions links = { "identity": lambda x: x, "logit": theano.tensor.nnet.sigmoid, "inverse": theano.tensor.inv, "inverse_squared": lambda x: theano.tensor.inv(theano.tensor.sqrt(x)), "log": theano.tensor.exp, } dists = {"HalfFlat": pm.Bound(pm.Flat, lower=0)} def __init__(self): self.reset() # Attributes defined elsewhere self.mu = None # build() self.spec = None # build() self.trace = None # build() self.advi_params = None # build() def reset(self): """Reset PyMC3 model and all tracked distributions and parameters.""" self.model = pm.Model() self.mu = None self.par_groups = {} def _build_dist(self, spec, label, dist, **kwargs): """Build and return a PyMC3 Distribution.""" if isinstance(dist, str): if hasattr(pm, dist): dist = getattr(pm, dist) elif dist in self.dists: dist = self.dists[dist] else: raise ValueError( f"The Distribution {dist} was not found in PyMC3 or the PyMC3BackEnd." ) # Inspect all args in case we have hyperparameters def _expand_args(key, value, label): if isinstance(value, Prior): label = f"{label}_{key}" return self._build_dist(spec, label, value.name, **value.args) return value kwargs = {k: _expand_args(k, v, label) for (k, v) in kwargs.items()} # Non-centered parameterization for hyperpriors if ( spec.noncentered and "sigma" in kwargs and "observed" not in kwargs and isinstance(kwargs["sigma"], pm.model.TransformedRV) ): old_sigma = kwargs["sigma"] _offset = pm.Normal(label + "_offset", mu=0, sigma=1, shape=kwargs["shape"]) return pm.Deterministic(label, _offset * old_sigma) return dist(label, **kwargs) def build(self, spec, reset=True): # pylint: disable=arguments-differ """Compile the PyMC3 model from an abstract model specification. Parameters ---------- spec : Bambi model A bambi Model instance containing the abstract specification of the model to compile. reset : Bool If True (default), resets the PyMC3BackEnd instance before compiling. """ if reset: self.reset() with self.model: self.mu = 0.0 for t in spec.terms.values(): data = t.data label = t.name dist_name = t.prior.name dist_args = t.prior.args n_cols = t.data.shape[1] coef = self._build_dist(spec, label, dist_name, shape=n_cols, **dist_args) if t.random: self.mu += coef[t.group_index][:, None] * t.predictor else: self.mu += pm.math.dot(data, coef)[:, None] y = spec.y.data y_prior = spec.family.prior link_f = spec.family.link if isinstance(link_f, str): link_f = self.links[link_f] y_prior.args[spec.family.parent] = link_f(self.mu) y_prior.args["observed"] = y self._build_dist(spec, spec.y.name, y_prior.name, **y_prior.args) self.spec = spec # pylint: disable=arguments-differ, inconsistent-return-statements def run(self, start=None, method="mcmc", init="auto", n_init=50000, **kwargs): """Run the PyMC3 MCMC sampler. Parameters ---------- start: dict, or array of dict Starting parameter values to pass to sampler; see ``'pm.sample()'`` for details. method: str The method to use for fitting the model. By default, 'mcmc', in which case the PyMC3 sampler will be used. Alternatively, 'advi', in which case the model will be fitted using automatic differentiation variational inference as implemented in PyMC3. Finally, 'laplace', in wich case a laplace approximation is used, 'laplace' is not recommended other than for pedagogical use. init: str Initialization method (see PyMC3 sampler documentation). Currently, this is ``'jitter+adapt_diag'``, but this can change in the future. n_init: int Number of initialization iterations if init = 'advi' or 'nuts'. Default is kind of in PyMC3 for the kinds of models we expect to see run with bambi, so we lower it considerably. Returns ------- An ArviZ InferenceData instance. """ model = self.model if method.lower() == "mcmc": samples = kwargs.pop("samples", 1000) with model: self.trace = pm.sample(samples, start=start, init=init, n_init=n_init, **kwargs) return from_pymc3(self.trace, model=model) elif method.lower() == "advi": with model: self.advi_params = pm.variational.ADVI(start, **kwargs) return ( self.advi_params ) # this should return an InferenceData object (once arviz adds support for VI) elif method.lower() == "laplace": return _laplace(model) def _laplace(model): """Fit a model using a laplace approximation. Mainly for pedagogical use. ``mcmc`` and ``advi`` are better approximations. Parameters ---------- model: PyMC3 model Returns ------- Dictionary, the keys are the names of the variables and the values tuples of modes and standard deviations. """ with model: varis = [v for v in model.unobserved_RVs if not pm.util.is_transformed_name(v.name)] maps = pm.find_MAP(start=model.test_point, vars=varis) hessian = pm.find_hessian(maps, vars=varis) if np.linalg.det(hessian) == 0: raise np.linalg.LinAlgError("Singular matrix. Use mcmc or advi method") stds = np.diag(np.linalg.inv(hessian) ** 0.5) maps = [v for (k, v) in maps.items() if not pm.util.is_transformed_name(k)] modes = [v.item() if v.size == 1 else v for v in maps] names = [v.name for v in varis] shapes = [np.atleast_1d(mode).shape for mode in modes] stds_reshaped = [] idx0 = 0 for shape in shapes: idx1 = idx0 + sum(shape) stds_reshaped.append(np.reshape(stds[idx0:idx1], shape)) idx0 = idx1 return dict(zip(names, zip(modes, stds_reshaped)))