from operator import xor

import numpy as np
from dataclasses import dataclass
from pb_bss.distribution.mixture_model_utils import (
    estimate_mixture_weight,
    log_pdf_to_affiliation,
)
from pb_bss.utils import labels_to_one_hot
from sklearn.cluster import KMeans

from . import Gaussian, GaussianTrainer
from .utils import _ProbabilisticModel


@dataclass
class GMM(_ProbabilisticModel):
    weight: np.ndarray  # (..., K)
    gaussian: Gaussian

    def predict(self, x):
        """Return the posterior affiliations, shape (..., K, N).

        Args:
            x: Observations with shape (..., N, D).
        """
        return log_pdf_to_affiliation(
            self.weight,
            self.gaussian.log_pdf(x[..., None, :, :]),
        )


class GMMTrainer:
    def __init__(self, eps=1e-10):
        self.eps = eps
        self.log_likelihood_history = []

    def fit(
        self,
        y,
        initialization=None,
        num_classes=None,
        iterations=100,
        *,
        saliency=None,
        weight_constant_axis=(-1,),
        covariance_type="full",
        fixed_covariance=None,
    ):
        """

        Args:
            y: Shape (..., N, D)
            initialization: Affiliations between 0 and 1. Shape (..., K, N)
            num_classes: Scalar >0
            iterations: Scalar >0
            saliency: Importance weighting for each observation, shape (..., N)
            weight_constant_axis: The axis that is used to calculate the mean
                over the affiliations. The affiliations have the
                shape (..., K, N), so the default value means averaging over
                the sample dimension. Note that averaging over an independent
                axis is supported.
            covariance_type: Either 'full', 'diagonal', or 'spherical'
            fixed_covariance: Learned if None. Otherwise, you need to provide
                the correct shape.

        Returns:
            The fitted GMM model (mixture `weight` and `gaussian`).
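
        Example (a minimal sketch; the toy data and parameter values are
        illustrative only):
            >>> y = np.random.normal(size=(1000, 2))
            >>> model = GMMTrainer().fit(y, num_classes=3, iterations=10)
            >>> affiliation = model.predict(y)  # shape (K, N) = (3, 1000)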
        """
        assert xor(initialization is None, num_classes is None), (
            "Incompatible input combination. "
            "Exactly one of the two inputs has to be None: "
            f"{initialization is None} xor {num_classes is None}"
        )
        assert np.isrealobj(y), y.dtype

        if initialization is None and num_classes is not None:
            *independent, num_observations, _ = y.shape
            affiliation_shape = (*independent, num_classes, num_observations)
            # Random initialization of the affiliations, normalized over the
            # class axis so that they sum to one for each observation.
            initialization = np.random.uniform(size=affiliation_shape)
            initialization = (
                initialization
                / np.einsum("...kn->...n", initialization)[..., None, :]
            )

        if saliency is None:
            saliency = np.ones_like(initialization[..., 0, :])

        return self._fit(
            y,
            initialization=initialization,
            iterations=iterations,
            saliency=saliency,
            weight_constant_axis=weight_constant_axis,
            covariance_type=covariance_type,
            fixed_covariance=fixed_covariance,
        )

    def fit_predict(
        self,
        y,
        initialization=None,
        num_classes=None,
        iterations=100,
        *,
        saliency=None,
        weight_constant_axis=(-1,),
        covariance_type="full",
        fixed_covariance=None,
    ):
        """Fit a model. Then just return the posterior affiliations."""
        model = self.fit(
            y=y,
            initialization=initialization,
            num_classes=num_classes,
            iterations=iterations,
            saliency=saliency,
            weight_constant_axis=weight_constant_axis,
            covariance_type=covariance_type,
            fixed_covariance=fixed_covariance,
        )
        return model.predict(y)

    def _fit(
            self,
            y,
            initialization,
            iterations,
            saliency,
            covariance_type,
            fixed_covariance,
            weight_constant_axis,
    ):
        # `affiliation` is only rebound in the loop below and never mutated
        # in-place, so a copy of the initialization is not needed.
        affiliation = initialization
        model = None
        for iteration in range(iterations):
            if model is not None:
                # E-step: recompute the affiliations under the current model.
                affiliation = model.predict(y)

            # M-step: re-estimate mixture weights and Gaussian parameters.
            model = self._m_step(
                y,
                affiliation=affiliation,
                saliency=saliency,
                weight_constant_axis=weight_constant_axis,
                covariance_type=covariance_type,
                fixed_covariance=fixed_covariance,
            )

        return model

    def _m_step(
            self,
            x,
            affiliation,
            saliency,
            weight_constant_axis,
            covariance_type,
            fixed_covariance,
    ):
        weight = estimate_mixture_weight(
            affiliation=affiliation,
            saliency=saliency,
            weight_constant_axis=weight_constant_axis,
        )

        # Fit one Gaussian per class; the affiliations weighted by the
        # saliency act as per-class observation weights.
        gaussian = GaussianTrainer()._fit(
            y=x[..., None, :, :],
            saliency=affiliation * saliency[..., None, :],
            covariance_type=covariance_type,
        )

        if fixed_covariance is not None:
            assert fixed_covariance.shape == gaussian.covariance.shape, (
                f'{fixed_covariance.shape} != {gaussian.covariance.shape}'
            )
            # Keep the estimated mean, but pin the covariance to the
            # provided fixed value.
            gaussian = gaussian.__class__(
                mean=gaussian.mean,
                covariance=fixed_covariance
            )

        return GMM(weight=weight, gaussian=gaussian)
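

# For orientation: `_fit` above iterates the standard saliency-weighted EM
# updates for a GMM (this is the textbook form, written out as a comment
# only; the actual computation lives in `estimate_mixture_weight` and
# `GaussianTrainer`). With affiliations gamma_{kn} and saliencies s_n:
#   E-step:  gamma_{kn} = w_k N(y_n; mu_k, Sigma_k)
#                         / sum_j w_j N(y_n; mu_j, Sigma_j)
#   M-step:  w_k     proportional to sum_n s_n gamma_{kn}
#            mu_k    = sum_n s_n gamma_{kn} y_n / sum_n s_n gamma_{kn}
#            Sigma_k = sum_n s_n gamma_{kn} (y_n - mu_k)(y_n - mu_k)^T
#                      / sum_n s_n gamma_{kn}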


@dataclass
class BinaryGMM(_ProbabilisticModel):
    kmeans: KMeans  # from sklearn

    def predict(self, x):
        """

        Args:
            x: Shape (N, D)

        Returns: Affiliation with shape (K, N)

        """
        N, D = x.shape
        assert np.isrealobj(x), x.dtype

        labels = self.kmeans.predict(x)
        affiliations = labels_to_one_hot(
            labels, self.kmeans.n_clusters, axis=-2, keepdims=False,
            dtype=x.dtype
        )
        assert affiliations.shape == (self.kmeans.n_clusters, N)
        return affiliations


class BinaryGMMTrainer:
    """k-means trainer.
    This is a specific wrapper of sklearn's kmeans for Deep Clustering
    embeddings. This explains the variable names and also the fixed shape for
    the embeddings.
    """
    def fit(
        self,
        x,
        num_classes,
        saliency=None
    ):
        """

        Args:
            x: Shape (N, D)
            num_classes: Scalar >0
            saliency: Importance weighting for each observation, shape (N,).
                The saliency has to be boolean and acts as a mask:
                non-salient observations are dropped before fitting.

        Returns:
            BinaryGMM wrapping the fitted `sklearn.cluster.KMeans` model.
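
        Example (a minimal sketch with random embeddings):
            >>> x = np.random.normal(size=(100, 20))
            >>> model = BinaryGMMTrainer().fit(x, num_classes=2)
            >>> model.predict(x).shape
            (2, 100)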

        """
        N, D = x.shape
        if saliency is not None:
            assert saliency.dtype == np.bool_, (
                'Only boolean saliency supported. '
                f'Current dtype: {saliency.dtype}.'
            )
            assert saliency.shape == (N,)
            # Use the boolean saliency as a mask and drop non-salient
            # observations before fitting.
            x = x[saliency, :]
        return BinaryGMM(kmeans=KMeans(n_clusters=num_classes).fit(x))