from .base import BaseApproximation
import torch
from torch.distributions import Independent, Normal
from ..utils import stacker
from ...timeseries.base import StochasticProcess
from ...timeseries import Parameter
from typing import Tuple


class StateMeanField(BaseApproximation):
    def __init__(self, model: StochasticProcess):
        """
        Implements a mean field approximation of the state space.
        :param model: The model
        """

        super().__init__()
        self._mean = None
        self._log_std = None
        self._model = model

    def initialize(self, data: torch.Tensor, *args):
        # One mean and log standard deviation per time point, including the initial state
        self._mean = torch.zeros((data.shape[0] + 1, *self._model.increment_dist.event_shape), requires_grad=True)
        self._log_std = torch.ones_like(self._mean, requires_grad=True)

        return self

    def dist(self):
        # Diagonal Gaussian treating the time axis (and any state dimensions) as a single event
        return Independent(Normal(self._mean, self._log_std.exp()), self._model.ndim + 1)

    def get_parameters(self):
        return self._mean, self._log_std
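

# Usage sketch (hypothetical names): ``proc`` stands for any StochasticProcess and ``y`` for the observed series.
# The approximation exposes its trainable tensors via ``get_parameters`` and a reparameterizable distribution via
# ``dist``, so a variational loop might look roughly like:
#
#   approx = StateMeanField(proc).initialize(y)
#   optimizer = torch.optim.Adam(approx.get_parameters(), lr=1e-2)
#   samples = approx.dist().rsample((num_samples,))  # ``num_samples`` Monte Carlo draws, e.g. for an ELBO estimate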


# TODO: Only supports 1D parameters currently
class ParameterMeanField(BaseApproximation):
    def __init__(self):
        """
        Implements the mean field for parameters.
        """

        super().__init__()
        self._mean = None
        self._log_std = None

    def get_parameters(self):
        return self._mean, self._log_std

    def initialize(self, parameters: Tuple[Parameter, ...], *args):
        # Stack the parameters' unconstrained (transformed) values into a single vector
        stacked = stacker(parameters, lambda u: u.t_values)

        self._mean = torch.zeros(stacked.concated.shape[1:], device=stacked.concated.device)
        self._log_std = torch.ones_like(self._mean)

        for p, msk in zip(parameters, stacked.mask):
            try:
                # Initialize at the prior mean, mapped to the unconstrained space
                self._mean[msk] = p.bijection.inv(p.distr.mean)
            except NotImplementedError:
                # Some priors do not define an analytic mean; keep the zero initialization
                pass

        self._mean.requires_grad_(True)
        self._log_std.requires_grad_(True)

        return self

    def dist(self):
        # Diagonal Gaussian over the stacked, unconstrained parameter vector
        return Independent(Normal(self._mean, self._log_std.exp()), 1)
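

# Usage sketch (hypothetical names): ``params`` stands for a tuple of the model's Parameter objects. The
# approximation is initialized at the priors' means in the unconstrained space, so draws from ``dist`` live in that
# space and must be mapped back through each parameter's bijection (using the masks from ``stacker``) before use:
#
#   approx = ParameterMeanField().initialize(params)
#   unconstrained = approx.dist().rsample()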