import numpy as np
import pymc3 as pm
import theano
from arviz import from_pymc3

from bambi.priors import Prior

from .base import BackEnd

class PyMC3BackEnd(BackEnd):
    """PyMC3 model-fitting back-end.

    Builds a PyMC3 model from an abstract model specification (``build``) and fits it
    (``run``).

    NOTE(review): parts of this class were reconstructed from a damaged copy of the source
    in which several tokens were missing. Every reconstructed expression is flagged with a
    ``NOTE(review)`` comment and should be confirmed before merging.
    """

    # Available link functions. Each entry is the callable applied to the linear predictor
    # in ``build`` (so ``"logit"`` maps to the sigmoid and ``"log"`` to ``exp``).
    links = {
        "identity": lambda x: x,
        "logit": theano.tensor.nnet.sigmoid,
        "inverse": theano.tensor.inv,
        "inverse_squared": lambda x: theano.tensor.inv(theano.tensor.sqrt(x)),
        "log": theano.tensor.exp,
    }

    # Distributions that are looked up here when PyMC3 has no attribute of that name.
    dists = {"HalfFlat": pm.Bound(pm.Flat, lower=0)}

    def __init__(self):
        # NOTE(review): the call to reset() is reconstructed; it guarantees that
        # ``self.model`` exists even when ``build(reset=False)`` or ``run`` is called first.
        self.reset()

        # Attributes defined elsewhere
        self.mu = None  # build()  # NOTE(review): attribute name reconstructed
        self.spec = None  # build()
        self.trace = None  # build()
        self.advi_params = None  # build()

    def reset(self):
        """Reset PyMC3 model and all tracked distributions and parameters."""
        self.model = pm.Model()
        self.mu = None  # NOTE(review): attribute name reconstructed
        self.par_groups = {}

    def _build_dist(self, spec, label, dist, **kwargs):
        """Build and return a PyMC3 Distribution.

        Parameters
        ----------
        spec : Bambi model
            The model specification; passed along to nested calls for hyperpriors.
        label : str
            Name given to the created PyMC3 variable.
        dist : str or callable
            Either the name of a distribution (looked up first in ``pm``, then in
            ``self.dists``) or a callable that is invoked as ``dist(label, **kwargs)``.
        **kwargs
            Arguments of the distribution. Any value that is a ``Prior`` is itself built
            into a distribution first, labelled ``f"{label}_{key}"``.

        Returns
        -------
        The PyMC3 variable created for ``label``.

        Raises
        ------
        ValueError
            If ``dist`` is a string that names neither a PyMC3 attribute nor an entry of
            ``self.dists``.
        """
        if isinstance(dist, str):
            if hasattr(pm, dist):
                dist = getattr(pm, dist)
            elif dist in self.dists:
                dist = self.dists[dist]
            else:
                raise ValueError(
                    f"The Distribution {dist} was not found in PyMC3 or the PyMC3BackEnd."
                )

        # Inspect all args in case we have hyperparameters
        def _expand_args(key, value, label):
            if isinstance(value, Prior):
                label = f"{label}_{key}"
                # NOTE(review): ``value.name`` is reconstructed (the distribution argument
                # was missing in the damaged source) -- confirm the Prior attribute name.
                return self._build_dist(spec, label, value.name, **value.args)
            return value

        kwargs = {k: _expand_args(k, v, label) for (k, v) in kwargs.items()}

        # Non-centered parameterization for hyperpriors
        # NOTE(review): the first condition, ``spec.noncentered``, is reconstructed -- confirm.
        if (
            spec.noncentered
            and "sigma" in kwargs
            and "observed" not in kwargs
            and isinstance(kwargs["sigma"], pm.model.TransformedRV)
        ):
            old_sigma = kwargs["sigma"]
            # Sample a standard-normal offset and scale it by sigma instead of sampling
            # the variable with that sigma directly.
            _offset = pm.Normal(label + "_offset", mu=0, sigma=1, shape=kwargs["shape"])
            return pm.Deterministic(label, _offset * old_sigma)

        return dist(label, **kwargs)

    def build(self, spec, reset=True):  # pylint: disable=arguments-differ
        """Compile the PyMC3 model from an abstract model specification.

        Parameters
        ----------
        spec : Bambi model
            A bambi Model instance containing the abstract specification of the model to compile.
        reset : Bool
            If True (default), resets the PyMC3BackEnd instance before compiling.
        """
        if reset:
            self.reset()

        with self.model:
            # Linear predictor, accumulated term by term.
            self.mu = 0.0

            for t in spec.terms.values():
                # NOTE(review): ``t.data``, ``t.name``, ``t.prior.name`` and the column
                # count ``t.data.shape[1]`` are reconstructed -- confirm the Term attributes.
                data = t.data
                label = t.name
                dist_name = t.prior.name
                dist_args = t.prior.args

                n_cols = t.data.shape[1]

                coef = self._build_dist(spec, label, dist_name, shape=n_cols, **dist_args)

                if t.random:
                    self.mu += coef[t.group_index][:, None] * t.predictor
                else:
                    # NOTE(review): ``pm.math.dot(data, ...)`` is reconstructed -- confirm.
                    self.mu += pm.math.dot(data, coef)[:, None]

            # NOTE(review): the three right-hand sides below are reconstructed -- confirm.
            y = spec.y.data
            y_prior = spec.family.prior
            link_f = spec.family.link

            if isinstance(link_f, str):
                link_f = self.links[link_f]

            # NOTE(review): the key ``spec.family.parent`` and the response label/name
            # arguments below are reconstructed -- confirm.
            y_prior.args[spec.family.parent] = link_f(self.mu)
            y_prior.args["observed"] = y
            self._build_dist(spec, spec.y.name, y_prior.name, **y_prior.args)
            self.spec = spec

    # pylint: disable=arguments-differ, inconsistent-return-statements
    def run(self, start=None, method="mcmc", init="auto", n_init=50000, **kwargs):
        """Run the PyMC3 MCMC sampler.

        Parameters
        ----------
        start: dict, or array of dict
            Starting parameter values to pass to sampler; see ``'pm.sample()'`` for details.
        method: str
            The method to use for fitting the model. By default, 'mcmc', in which case the
            PyMC3 sampler will be used. Alternatively, 'advi', in which case the model will be
            fitted using automatic differentiation variational inference as implemented in
            PyMC3. Finally, 'laplace', in which case a Laplace approximation is used; 'laplace'
            is not recommended other than for pedagogical use. The comparison is
            case-insensitive.
        init: str
            Initialization method (see PyMC3 sampler documentation). Currently, this is
            ``'jitter+adapt_diag'``, but this can change in the future.
        n_init: int
            Number of initialization iterations if init = 'advi' or 'nuts'. Defaults to 50000.
        **kwargs
            For 'mcmc', ``samples`` (default 1000) is taken as the number of draws and the
            remaining arguments are forwarded to ``pm.sample``. For 'advi', all arguments are
            forwarded to ``pm.variational.ADVI``.

        Returns
        -------
        For 'mcmc', an ArviZ InferenceData instance. For 'advi', the ``pm.variational.ADVI``
        object (also stored in ``self.advi_params``). For 'laplace', the dictionary returned
        by ``_laplace``. Any other ``method`` returns None.
        """
        model = self.model
        method = method.lower()

        if method == "mcmc":
            samples = kwargs.pop("samples", 1000)
            with model:
                self.trace = pm.sample(samples, start=start, init=init, n_init=n_init, **kwargs)

            return from_pymc3(self.trace, model=model)

        elif method == "advi":
            with model:
                self.advi_params = pm.variational.ADVI(start, **kwargs)
            # NOTE(review): returned value reconstructed (the damaged source returned an
            # empty tuple). This should return an InferenceData object once arviz adds
            # support for VI.
            return self.advi_params

        elif method == "laplace":
            return _laplace(model)

def _laplace(model):
    """Fit a model using a laplace approximation.

    Mainly for pedagogical use. ``mcmc`` and ``advi`` are better approximations.

    Parameters
    ----------
    model: PyMC3 model

    Returns
    -------
    Dictionary, the keys are the names of the variables and the values tuples of modes and
    standard deviations.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the determinant of the Hessian at the MAP estimate is exactly zero.
    """
    with model:
        # Only work with the untransformed variables.
        varis = [v for v in model.unobserved_RVs if not pm.util.is_transformed_name(v.name)]
        maps = pm.find_MAP(start=model.test_point, vars=varis)
        hessian = pm.find_hessian(maps, vars=varis)
        if np.linalg.det(hessian) == 0:
            raise np.linalg.LinAlgError("Singular matrix. Use mcmc or advi method")
        # Take the diagonal before the square root: same values as rooting the whole
        # inverse, but negative off-diagonal entries no longer trigger NaN warnings.
        stds = np.sqrt(np.diag(np.linalg.inv(hessian)))
        maps = [v for (k, v) in maps.items() if not pm.util.is_transformed_name(k)]
        modes = [v.item() if v.size == 1 else v for v in maps]
        names = [v.name for v in varis]
        shapes = [np.atleast_1d(mode).shape for mode in modes]
        stds_reshaped = []
        idx0 = 0
        for shape in shapes:
            # A variable occupies prod(shape) entries of the flat ``stds`` vector; using
            # the sum of the dimensions would be wrong for multi-dimensional variables.
            idx1 = idx0 + int(np.prod(shape))
            stds_reshaped.append(np.reshape(stds[idx0:idx1], shape))
            idx0 = idx1
    return dict(zip(names, zip(modes, stds_reshaped)))