from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor

from torch_kalman.design import Design
from torch_kalman.state_belief import StateBelief
from torch_kalman.state_belief.families.censored_gaussian.utils import tobit_adjustment, tobit_probs, std_normal
from torch_kalman.state_belief.families.gaussian import Gaussian, GaussianOverTime
from torch_kalman.state_belief.utils import bmat_idx

Selector = Union[Sequence[int], slice]


class CensoredGaussian(Gaussian):
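    """
    A Gaussian state-belief whose measurements may be censored from below and/or above.

    The Kalman update uses a tobit-style adjustment: measures that are likely to be censored
    (per `tobit_probs`) are down-weighted, and the measurement mean/covariance are adjusted
    (per `tobit_adjustment`) before computing the Kalman gain.
    """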

    def update(self,
               obs: Tensor,
               lower: Optional[Tensor] = None,
               upper: Optional[Tensor] = None,
               **kwargs) -> 'StateBelief':
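        """
        Update the state-belief with (possibly censored) observations.

        If a `time` keyword-argument is passed, `obs`, `lower` and `upper` are indexed along their
        second (time) dimension before updating; times beyond the observed horizon leave the belief
        unchanged (a copy is returned).
        """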
        if 'time' in kwargs:
            time = kwargs.pop('time')
            if time >= obs.shape[1]:
                return self.copy()
            return self.update(
                obs=obs[:, time],
                lower=lower[:, time] if lower is not None else None,
                upper=upper[:, time] if upper is not None else None,
                **kwargs
            )

        return super().update(obs, lower=lower, upper=upper)

    def _update_group(self,
                      obs: Tensor,
                      group_idx: Union[slice, Sequence[int]],
                      which_valid: Union[slice, Sequence[int]],
                      lower: Optional[Tensor] = None,
                      upper: Optional[Tensor] = None
                      ) -> Tuple[Tensor, Tensor]:
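        """
        Tobit-adjusted Kalman update for one group of state-beliefs, restricted to the valid
        (non-missing) measures. Returns the updated means and covariances for that group.
        """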
        # indices:
        idx_2d = bmat_idx(group_idx, which_valid)
        idx_3d = bmat_idx(group_idx, which_valid, which_valid)

        # observed values, censoring limits
        obs = obs[idx_2d]
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))
        else:
            lower = lower[idx_2d]
            if torch.isnan(lower).any():
                raise ValueError("NaNs not allowed in `lower`")
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        else:
            upper = upper[idx_2d]
            if torch.isnan(upper).any():
                raise ValueError("NaNs not allowed in `upper`")

        if (lower == upper).any():
            raise RuntimeError("Elements of `lower` cannot be equal to the corresponding elements of `upper`.")

        # subset belief / design-mats:
        means = self.means[group_idx]
        covs = self.covs[group_idx]
        R = self.R[idx_3d]
        H = self.H[idx_2d]
        measured_means = H.matmul(means.unsqueeze(-1)).squeeze(-1)

        # probabilities of censoring below / above the limits:
        prob_lo, prob_up = tobit_probs(mean=measured_means,
                                       cov=R,
                                       lower=lower,
                                       upper=upper)
        # diagonal matrices with each measure's probability of being observed (un-censored):
        prob_obs = torch.diag_embed(1 - prob_up - prob_lo)

        mm_adj, R_adj = tobit_adjustment(mean=measured_means,
                                         cov=R,
                                         lower=lower,
                                         upper=upper,
                                         probs=(prob_lo, prob_up))

        # kalman gain:
        K = self.kalman_gain(covariance=covs, H=H, R_adjusted=R_adj, prob_obs=prob_obs)

        # update
        means_new = self.mean_update(mean=means, K=K, residuals=obs - mm_adj)
        covs_new = self.covariance_update(covariance=covs, K=K, H=H, prob_obs=prob_obs)
        return means_new, covs_new

    def _update_last_measured(self, obs: Tensor) -> Tensor:
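        """
        Reset the `last_measured` counter for any group that had at least one non-missing
        observation; if `obs` is 3D, only its first slice along the last dimension is checked.
        """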
        if obs.ndimension() == 3:
            obs = obs[..., 0]
        any_measured_group_idx = (torch.sum(~torch.isnan(obs), 1) > 0).nonzero().squeeze(-1)
        last_measured = self.last_measured.clone()
        last_measured[any_measured_group_idx] = 0
        return last_measured

    @staticmethod
    def mean_update(mean: Tensor, K: Tensor, residuals: Tensor) -> Tensor:
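        """Standard Kalman mean update: the prior mean plus the gain applied to the residuals."""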
        return mean + K.matmul(residuals.unsqueeze(-1)).squeeze(-1)

    @staticmethod
    def covariance_update(covariance: Tensor, H: Tensor, K: Tensor, prob_obs: Tensor) -> Tensor:
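        """
        Covariance update of the form (I - K @ prob_obs @ H) @ covariance, where `prob_obs`
        down-weights measures that are likely to be censored.
        """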
        num_groups, num_dim, *_ = covariance.shape
        I = torch.eye(num_dim, dtype=covariance.dtype, device=covariance.device).expand(num_groups, -1, -1)
        k = (I - K.matmul(prob_obs).matmul(H))
        return k.matmul(covariance)

    @staticmethod
    def kalman_gain(covariance: Tensor,
                    H: Tensor,
                    R_adjusted: Tensor,
                    prob_obs: Tensor) -> Tensor:
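        """
        Kalman gain for censored observations: `prob_obs` scales H on both sides, and `R_adjusted`
        is the tobit-adjusted measurement covariance.
        """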
        Ht = H.permute(0, 2, 1)
        state_uncertainty = covariance.matmul(Ht).matmul(prob_obs)
        system_uncertainty = prob_obs.matmul(H).matmul(covariance).matmul(Ht).matmul(prob_obs) + R_adjusted
        system_uncertainty_inv = torch.inverse(system_uncertainty)
        return state_uncertainty.matmul(system_uncertainty_inv)

    @classmethod
    def concatenate_over_time(cls,
                              state_beliefs: Sequence['CensoredGaussian'],
                              design: Design) -> 'CensoredGaussianOverTime':
        return CensoredGaussianOverTime(state_beliefs=state_beliefs, design=design)

    def sample_transition(self,
                          lower: Optional[Tensor] = None,
                          upper: Optional[Tensor] = None,
                          eps: Optional[Tensor] = None) -> Tensor:
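        """
        Sampling the state transition under censoring limits is not implemented; without limits,
        this defers to the un-censored Gaussian sampler.
        """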
        if lower is None and upper is None:
            return super().sample_transition(eps=eps)
        raise NotImplementedError


class CensoredGaussianOverTime(GaussianOverTime):
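    """
    CensoredGaussian state-beliefs concatenated over timesteps, with a censoring-aware (tobit)
    log-probability.
    """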
    def __init__(self,
                 state_beliefs: Sequence['CensoredGaussian'],
                 design: Design):
        super().__init__(state_beliefs=state_beliefs, design=design)

    def log_prob(self,
                 obs: Tensor,
                 lower: Optional[Tensor] = None,
                 upper: Optional[Tensor] = None):
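        """
        Log-probability of the observations given the predictions, accounting for censoring:
        observations at `lower`/`upper` contribute the normal tail mass rather than the density.
        """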
        return super().log_prob(obs=obs, lower=lower, upper=upper)

    def _log_prob_with_subsetting(self,
                                  obs: Tensor,
                                  group_idx: Selector,
                                  time_idx: Selector,
                                  measure_idx: Selector,
                                  method: str = 'independent',
                                  lower: Optional[Tensor] = None,
                                  upper: Optional[Tensor] = None) -> Tensor:
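        """
        Tobit log-likelihood for the requested subset, evaluated measure-by-measure:

        - un-censored observations contribute the normal log-density;
        - observations at the upper limit contribute log P(X >= upper);
        - observations at the lower limit contribute log P(X <= lower).

        The per-measure log-probs are then summed (i.e., measures are treated as independent).
        """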
        self._check_lp_sub_input(group_idx, time_idx)

        idx_3d = bmat_idx(group_idx, time_idx, measure_idx)
        idx_4d = bmat_idx(group_idx, time_idx, measure_idx, measure_idx)

        # subset obs, lower, upper:
        if upper is None:
            upper = torch.full_like(obs, float('inf'))
        if lower is None:
            lower = torch.full_like(obs, -float('inf'))
        obs, lower, upper = obs[idx_3d], lower[idx_3d], upper[idx_3d]

        # predicted means and covariances for this subset:
        pred_mean = self.predictions[idx_3d]
        pred_cov = self.prediction_uncertainty[idx_4d]

        # which observations sit at the censoring limits:
        cens_up = torch.isclose(obs, upper)
        cens_lo = torch.isclose(obs, lower)

        # per-measure log-likelihoods under each censoring regime:
        loglik_uncens = torch.zeros_like(obs)
        loglik_cens_up = torch.zeros_like(obs)
        loglik_cens_lo = torch.zeros_like(obs)
        for m in range(pred_mean.shape[-1]):
            std = pred_cov[..., m, m].sqrt()
            # note the sign: z = (mean - obs) / std, so cdf(z) is the probability mass *above* obs
            z = (pred_mean[..., m] - obs[..., m]) / std

            # the normal pdf is numerically well-behaved in the tails:
            loglik_uncens[..., m] = std_normal.log_prob(z) - std.log()

            # but the cdf is not, so clamp before taking its log:
            z = torch.clamp(z, -5., 5.)
            loglik_cens_up[..., m] = std_normal.cdf(z).log()
            loglik_cens_lo[..., m] = (1. - std_normal.cdf(z)).log()

        loglik = torch.zeros_like(obs)
        loglik[cens_up] = loglik_cens_up[cens_up]
        loglik[cens_lo] = loglik_cens_lo[cens_lo]
        loglik[~(cens_up | cens_lo)] = loglik_uncens[~(cens_up | cens_lo)]

        # sum the log-probs across measures (equivalently, the product of the probs, i.e. assume independence)
        return torch.sum(loglik, -1)

    def sample_measurements(self,
                            lower: Optional[Tensor] = None,
                            upper: Optional[Tensor] = None,
                            eps: Optional[Tensor] = None):
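        """
        Sampling measurements under censoring limits is not implemented; without limits, this
        defers to the un-censored Gaussian sampler.
        """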
        if lower is None and upper is None:
            return super().sample_measurements(eps=eps)
        raise NotImplementedError