# -*- coding: utf-8 -*-
"""
Functions for computing various metrics to aid interpretation of similarity
network fusion outputs.
"""

import numpy as np
from sklearn.cluster import spectral_clustering
from sklearn.metrics import v_measure_score
from sklearn.utils.validation import check_random_state
from . import compute


def nmi(labels):
    """
    Calculates normalized mutual information for all combinations of `labels`

    Every unordered pair of label arrays (including each array with itself)
    is scored once with :py:func:`sklearn.metrics.v_measure_score`; the
    score is then written to both mirrored positions of the output matrix.
    Refer to the `sklearn` codebase for details of the algorithm.

    Parameters
    ----------
    labels : m-length list of (N,) array_like
        List of label arrays

    Returns
    -------
    nmi : (m x m) np.ndarray
        Symmetric matrix of NMI scores for all combinations of `labels`

    Examples
    --------
    >>> import numpy as np
    >>> label1 = np.array([1, 1, 1, 2, 2, 2])
    >>> label2 = np.array([1, 1, 2, 2, 2, 2])

    >>> from snf import metrics
    >>> metrics.nmi([label1, label2])
    array([[1.        , 0.47870397],
           [0.47870397, 1.        ]])
    """

    n_sets = len(labels)
    scores = np.empty(shape=(n_sets, n_sets))

    # visit the upper triangle (diagonal included) and mirror each score,
    # since NMI is symmetric in its two arguments
    rows, cols = np.triu_indices(n_sets)
    for i, j in zip(rows, cols):
        score = v_measure_score(labels[i], labels[j])
        scores[i, j] = score
        scores[j, i] = score

    return scores


def rank_feature_by_nmi(inputs, W, *, K=20, mu=0.5, n_clusters=None,
                        random_state=None):
    """
    Calculates NMI of each feature in `inputs` with `W`

    `W` is spectrally clustered once. Each individual feature of every data
    array is then turned into its own affinity matrix
    (:py:func:`snf.compute.make_affinity`), spectrally clustered into the
    same number of clusters, and compared to the `W` clustering with
    :py:func:`sklearn.metrics.v_measure_score`.

    Parameters
    ----------
    inputs : list-of-tuple
        Each tuple should contain (1) an (N, M) data array, where N is samples
        M is features, and (2) a string indicating the metric to use to compute
        a distance matrix for the given data. This MUST be one of the options
        available in :py:func:`scipy.spatial.distance.cdist`. Data may be any
        array_like (e.g., nested lists); it is converted with `np.asarray`.
    W : (N, N) array_like
        Similarity array generated by :py:func:`snf.compute.snf`
    K : (0, N) int, optional
        Hyperparameter normalization factor for scaling. Default: 20
    mu : (0, 1) float, optional
        Hyperparameter normalization factor for scaling. Default: 0.5
    n_clusters : int, optional
        Number of desired clusters. Default: determined by eigengap (see
        `snf.get_n_clusters()`)
    random_state : int, RandomState instance or None, optional
        Passed to every :py:func:`sklearn.cluster.spectral_clustering` call
        to make the (otherwise random) cluster assignments reproducible.
        Default: None

    Returns
    -------
    nmi : list of (M,) np.ndarray
        Normalized mutual information scores for each feature of input arrays
    """

    if n_clusters is None:
        n_clusters = compute.get_n_clusters(W)[0]

    # `n_clusters` must be given by keyword: it is keyword-only in current
    # scikit-learn, where a positional value raises a TypeError
    snf_labels = spectral_clustering(W, n_clusters=n_clusters,
                                     random_state=random_state)

    nmi = []
    for data, metric in inputs:
        # convert once so `.shape` and `.T` work for any array_like input
        data = np.asarray(data)
        scores = np.empty(shape=(data.shape[-1],))
        for nfeature, feature in enumerate(data.T):
            # np.vstack turns the 1-D feature into an (N, 1) column copy
            aff = compute.make_affinity(np.vstack(feature), K=K, mu=mu,
                                        metric=metric)
            aff_labels = spectral_clustering(aff, n_clusters=n_clusters,
                                             random_state=random_state)
            scores[nfeature] = v_measure_score(snf_labels, aff_labels)
        nmi.append(scores)

    return nmi


def _silhouette_samples(arr, labels):
    """
    Calculates modified silhouette score from affinity matrix

    The Silhouette Coefficient is calculated using the mean intra-cluster
    affinity (`a`) and the mean nearest-cluster affinity (`b`) for each
    sample. Because larger affinities indicate greater similarity, the
    Silhouette Coefficient for a sample is `(a - b) / max(a, b)`. To clarify,
    `b` is the mean affinity between a sample and the nearest cluster that
    the sample is not a part of. This corresponds to the cluster with the
    next *highest* affinity (opposite how this metric would be computed for a
    distance matrix).

    Parameters
    ----------
    arr : (N, N) array_like
        Array of pairwise affinities between samples. Affinities are assumed
        to be non-negative. Sparse (CSR/CSC) input is converted to a dense
        array, and non-floating input is converted to float64.
    labels : (N,) array_like
        Predicted labels for each sample

    Returns
    -------
    sil_samples : (N,) np.ndarray
        Modified (affinity) silhouette scores for each sample. Samples in
        clusters of size 1, and samples for which `max(a, b)` is 0, are
        assigned a score of 0.

    Raises
    ------
    ValueError
        If the number of unique `labels` is not between 2 and N - 1
        (inclusive)

    Notes
    -----
    Code is *lightly* modified from the `sklearn` implementation. See:
    `sklearn.metrics.silhouette_samples`

    References
    ----------
    .. [1] `Peter J. Rousseeuw (1987). Silhouettes: a Graphical Aid to the
       Interpretation and Validation of Cluster Analysis. Computational
       and Applied Mathematics, 20, 53-65.
       <http://www.sciencedirect.com/science/article/pii/0377042787901257>`_
    .. [2] `Wikipedia entry on the Silhouette Coefficient
       <https://en.wikipedia.org/wiki/Silhouette_(clustering)>`_
    .. [3] `Pedregosa, F., Varoquaux, G., Gramfort, A., Michel, V., Thirion,
       B., Grisel, O., ... & Vanderplas, J. (2011). Scikit-learn: Machine
       learning in Python. Journal of Machine Learning Research, 12, 2825-2830.
       <https://github.com/scikit-learn/>`_
    """

    from scipy.sparse import issparse
    from sklearn.preprocessing import LabelEncoder
    from sklearn.utils import check_X_y

    def check_number_of_labels(n_labels, n_samples):
        # a silhouette needs at least two clusters and at least one cluster
        # with more than one member
        if not 1 < n_labels < n_samples:
            raise ValueError("Number of labels is %d. Valid values are 2 "
                             "to n_samples - 1 (inclusive)" % n_labels)

    # copy=True so that zeroing the diagonal never modifies the caller's data
    arr, labels = check_X_y(arr, labels, accept_sparse=['csc', 'csr'],
                            copy=True)
    # the per-cluster row sums/means below require a dense ndarray
    if issparse(arr):
        arr = arr.toarray()
    # integer/boolean affinities would make the accumulators below integer,
    # truncating the cluster means and breaking the final division
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(np.float64)
    # exclude each sample's self-affinity from its intra-cluster mean
    arr[np.diag_indices_from(arr)] = 0

    le = LabelEncoder()
    labels = le.fit_transform(labels)
    n_labels = len(le.classes_)
    check_number_of_labels(n_labels, arr.shape[0])

    n_samples_per_label = np.bincount(labels, minlength=n_labels)

    # For sample i, store the mean affinity to the other members of its own
    # cluster in intra_clust_aff[i]
    intra_clust_aff = np.zeros(arr.shape[0], dtype=arr.dtype)

    # For sample i, store the highest mean affinity to any other cluster in
    # inter_clust_aff[i]; starting from 0 relies on non-negative affinities
    inter_clust_aff = intra_clust_aff.copy()

    for curr_label in range(n_labels):
        mask = labels == curr_label
        current_aff = arr[mask]

        # Leave out the current sample (its diagonal entry is already 0).
        n_samples_curr_lab = n_samples_per_label[curr_label] - 1
        if n_samples_curr_lab != 0:
            intra_clust_aff[mask] = np.sum(
                current_aff[:, mask], axis=1) / n_samples_curr_lab

        # The "nearest" other cluster is the one with the highest mean
        # affinity to each sample.
        for other_label in range(n_labels):
            if other_label != curr_label:
                other_mask = labels == other_label
                other_aff = np.mean(current_aff[:, other_mask], axis=1)
                inter_clust_aff[mask] = np.maximum(
                    inter_clust_aff[mask], other_aff)

    # (a - b) / max(a, b); samples with max(a, b) == 0 (e.g. zero affinity
    # to every other sample) get 0 instead of NaN and a RuntimeWarning
    denom = np.maximum(intra_clust_aff, inter_clust_aff)
    sil_samples = np.divide(intra_clust_aff - inter_clust_aff, denom,
                            out=np.zeros_like(denom), where=denom != 0)

    # score 0 for clusters of size 1, according to the paper
    sil_samples[n_samples_per_label.take(labels) == 1] = 0

    return sil_samples


def silhouette_score(arr, labels):
    """
    Calculates modified silhouette score from affinity matrix

    The Silhouette Coefficient is calculated using the mean intra-cluster
    affinity (`a`) and the mean nearest-cluster affinity (`b`) for each
    sample. Because larger affinities indicate greater similarity, the
    Silhouette Coefficient for a sample is `(a - b) / max(a, b)`. To clarify,
    `b` is the mean affinity between a sample and the nearest cluster that
    the sample is not a part of. This corresponds to the cluster with the
    next *highest* affinity (opposite how this metric would be computed for a
    distance matrix). The returned score is the mean coefficient over all
    samples.

    Parameters
    ----------
    arr : (N, N) array_like
        Array of pairwise affinities between samples
    labels : (N,) array_like
        Predicted labels for each sample

    Returns
    -------
    silhouette_score : float
        Modified (affinity) silhouette score

    Raises
    ------
    ValueError
        If the number of unique `labels` is not between 2 and N - 1
        (inclusive)

    Notes
    -----
    Code is *lightly* modified from the ``sklearn`` implementation. See:
    `sklearn.metrics.silhouette_score`
    """

    # average the per-sample coefficients into a single summary score
    return np.mean(_silhouette_samples(arr, labels))


def affinity_zscore(arr, labels, n_perms=1000, seed=None):
    """
    Calculates z-score of silhouette (affinity) score by permutation

    A null distribution is built by scoring `n_perms` random shuffles of
    `labels` with :py:func:`silhouette_score` on the same `arr`; the observed
    score of the unshuffled `labels` is then standardized against that
    distribution's mean and (population) standard deviation.

    Parameters
    ----------
    arr : (N, N) array_like
        Array of pairwise affinities between samples
    labels : (N,) array_like
        Predicted labels for each sample
    n_perms : int, optional
        Number of permutations. Default: 1000
    seed : int, optional
        Random seed. Default: None

    Returns
    -------
    z_aff : float
        Z-score of silhouette (affinity) score. Not finite if every
        permutation yields the same score (zero standard deviation).
    """

    rng = check_random_state(seed)

    # null distribution from shuffled labels; preallocating keeps the
    # np.empty validation of `n_perms` (e.g. negative values are rejected)
    null_scores = np.empty(shape=(n_perms,))
    null_scores[:] = [silhouette_score(arr, rng.permutation(labels))
                      for _ in range(n_perms)]

    observed = silhouette_score(arr, labels)
    return (observed - null_scores.mean()) / null_scores.std()