"""
@Article{Banerjee2005vMF,
  author  = {Banerjee, Arindam and Dhillon, Inderjit S and Ghosh, Joydeep and Sra, Suvrit},
  title   = {Clustering on the unit hypersphere using von {M}ises-{F}isher distributions},
  journal = {Journal of Machine Learning Research},
  year    = {2005},
  volume  = {6},
  number  = {Sep},
  pages   = {1345--1382},
}

@article{Wood1994Simulation,
  title={Simulation of the von Mises Fisher distribution},
  author={Wood, Andrew TA},
  journal={Communications in statistics-simulation and computation},
  volume={23},
  number={1},
  pages={157--164},
  year={1994},
  publisher={Taylor \& Francis}
}
"""
from dataclasses import dataclass
from scipy.special import ive
import numpy as np
from pb_bss.distribution.utils import _ProbabilisticModel
from pb_bss.utils import is_broadcast_compatible


@dataclass
class VonMisesFisher(_ProbabilisticModel):
    """Von Mises-Fisher distribution on the unit hypersphere.

    For a unit-norm observation ``y`` the density evaluated by this class is
    ``exp(concentration * mean^T y - log_norm())``, see [Banerjee2005vMF].

    Attributes:
        mean: Mean direction with shape (..., D).
        concentration: Concentration parameter with shape (...,).
    """
    mean: np.array  # (..., D)
    concentration: np.array  # (...,)

    def log_norm(self):
        """Logarithm of the normalization term subtracted in `log_pdf`.

        Is fairly stable, when concentration > 1e-10.

        Returns: Array with the shape of `concentration`, i.e. (...,).
        """
        D = self.mean.shape[-1]
        return (
            (D / 2) * np.log(2 * np.pi)
            # `ive` is the exponentially scaled Bessel function
            # iv(v, z) * exp(-abs(z)); the scaling avoids overflow for large
            # concentrations and is undone in the log domain by adding
            # abs(concentration) below.
            + np.log(ive(D / 2 - 1, self.concentration))
            + (
                np.abs(self.concentration)
                - (D / 2 - 1) * np.log(self.concentration)
            )
        )

    def sample(self, size):
        """
        Sampling according to [Wood1994Simulation].

        Not implemented; calling this method always raises.

        Args:
            size: Unused, since sampling is not implemented.

        Raises:
            NotImplementedError: Always.

        """
        raise NotImplementedError(
            'A good implementation can be found in libdirectional: '
            'https://github.com/libDirectional/libDirectional/blob/master/lib/distributions/Hypersphere/VMFDistribution.m#L239'
        )

    def norm(self):
        """Normalization term in the linear domain, i.e. exp(log_norm()).

        Returns: Array with the shape of `concentration`, i.e. (...,).
        """
        # `log_norm` is a method and has to be called; handing the bound
        # method itself to `np.exp` raises a TypeError.
        return np.exp(self.log_norm())

    def log_pdf(self, y):
        """ Logarithm of probability density function.

        The observations are normalized to unit length before evaluation.
        Zero vectors are divided by the smallest positive normal float
        instead of zero and therefore stay zero.

        Args:
            y: Observations with shape (..., D), i.e. (1, N, D).

        Returns: Log-probability density with properly broadcasted shape.
        """
        y = y / np.maximum(
            np.linalg.norm(y, axis=-1, keepdims=True), np.finfo(y.dtype).tiny
        )
        # Inner product mean^T y; the inserted axis of `mean` broadcasts
        # against the observation axis N of `y`.
        result = np.einsum("...d,...d", y, self.mean[..., None, :])
        result *= self.concentration[..., None]
        result -= self.log_norm()[..., None]
        return result

    def pdf(self, y):
        """ Probability density function.

        Args:
            y: Observations with shape (..., D), i.e. (1, N, D).

        Returns: Probability density with properly broadcasted shape.
        """
        return np.exp(self.log_pdf(y))


class VonMisesFisherTrainer:
    """Estimates the parameters of a `VonMisesFisher` distribution.

    The estimator follows [Banerjee2005vMF]: the mean direction is the
    normalized (weighted) sum of the observations and the concentration is
    obtained from the approximation in Equation 4.4.
    """

    def fit(
        self, y, saliency=None, min_concentration=1e-10, max_concentration=500
    ) -> VonMisesFisher:
        """ Fits a von Mises Fisher distribution.

        Broadcasting (for sources) has to be done outside this function.

        The observations are normalized to unit length before the
        parameters are estimated.

        Args:
            y: Real-valued observations with shape (..., N, D)
            saliency: Either None or weights with shape (..., N)
            min_concentration: Lower bound the estimated concentration is
                clipped to.
            max_concentration: Upper bound the estimated concentration is
                clipped to.

        Returns: The fitted `VonMisesFisher` with mean of shape (..., D) and
            concentration of shape (...,).
        """
        assert np.isrealobj(y), y.dtype
        # Guard the division with the smallest positive normal float such
        # that zero vectors stay zero instead of producing NaNs.
        y = y / np.maximum(
            np.linalg.norm(y, axis=-1, keepdims=True), np.finfo(y.dtype).tiny
        )
        if saliency is not None:
            assert is_broadcast_compatible(y.shape[:-1], saliency.shape), (
                y.shape,
                saliency.shape,
            )
        return self._fit(
            y,
            saliency=saliency,
            min_concentration=min_concentration,
            max_concentration=max_concentration,
        )

    def _fit(
        self, y, saliency, min_concentration, max_concentration
    ) -> VonMisesFisher:
        """Estimate mean and concentration from normalized observations.

        Args:
            y: Observations with shape (..., N, D); `fit` normalizes them to
                unit length before calling this method.
            saliency: Either None (uniform weights) or weights with shape
                (..., N).
            min_concentration: Lower clipping bound for the concentration.
            max_concentration: Upper clipping bound for the concentration.

        Returns: The fitted `VonMisesFisher`.
        """
        D = y.shape[-1]

        if saliency is None:
            saliency = np.ones(y.shape[:-1])

        # [Banerjee2005vMF] Equation 2.4
        r = np.einsum("...n,...nd->...d", saliency, y)
        norm = np.linalg.norm(r, axis=-1)
        mean = r / np.maximum(norm, np.finfo(y.dtype).tiny)[..., None]

        # [Banerjee2005vMF] Equation 2.5
        r_bar = norm / np.sum(saliency, axis=-1)
        # For unit-norm observations and non-negative weights r_bar <= 1
        # holds in exact arithmetic. Rounding can push it marginally above 1
        # (e.g. a single observation or perfectly aligned observations),
        # which would flip the sign of the denominator below and wrongly
        # yield `min_concentration` instead of `max_concentration`.
        r_bar = np.minimum(r_bar, 1)

        # [Banerjee2005vMF] Equation 4.4
        # r_bar == 1 divides by zero and yields inf, which the clipping below
        # maps to `max_concentration`; hence the warning is not of interest.
        with np.errstate(divide="ignore"):
            concentration = (r_bar * D - r_bar ** 3) / (1 - r_bar ** 2)
        concentration = np.clip(
            concentration, min_concentration, max_concentration
        )
        return VonMisesFisher(mean=mean, concentration=concentration)