"""
Utility functions for affinity/similarity matrices.
"""

# Author: Oualid Benkarim <oualid.benkarim@mcgill.ca>
# License: BSD 3 clause


import numpy as np
from scipy import sparse as ssp


def is_symmetric(x, tol=1E-10):
    """Check if input is symmetric.

    Parameters
    ----------
    x : 2D array-like or sparse matrix
        Input data.
    tol : float, optional
        Maximum allowed absolute difference between ``x[i, j]`` and
        ``x[j, i]``. Default is 1e-10.

    Returns
    -------
    is_symm : bool
        True if `x` is symmetric. False, otherwise.

    Raises
    ------
    ValueError
        If `x` is not square.

    """

    if not ssp.issparse(x):
        x = np.asarray(x)

    if x.ndim != 2 or x.shape[0] != x.shape[1]:
        raise ValueError('Array is not square.')

    if ssp.issparse(x):
        # Arithmetic (subtraction/transpose) is reliable on these formats.
        if x.format not in ['csr', 'csc', 'coo']:
            x = x.tocoo(copy=False)
        dif = x - x.T
        return bool(np.all(np.abs(dif.data) <= tol))

    # rtol=0: apply only the absolute tolerance `tol`, exactly as in the
    # sparse branch (np.allclose would otherwise add rtol=1e-5 * |x.T|).
    return bool(np.allclose(x, x.T, rtol=0, atol=tol))


def make_symmetric(x, check=True, tol=1E-10, copy=True, sparse_format=None):
    """Make array symmetric.

    The symmetrized matrix is ``0.5 * (x + x.T)``.

    Parameters
    ----------
    x : 2D ndarray or sparse matrix
        Input data.
    check : bool, optional
        If True, first check whether `x` is already symmetric.
        Default is True.
    tol : float, optional
        Maximum allowed tolerance for equivalence. Only used if `check` is
        True. Default is 1e-10.
    copy : bool, optional
        If True, return a copy. Otherwise, work on `x` in place when `x` is a
        dense array. Sparse matrices cannot be updated in place, so a new
        matrix is always returned for sparse input.
        If already symmetric, returns original array.
    sparse_format : {'coo', 'csr', 'csc', ...}, optional
        Format of output symmetric matrix. Only used if `x` is sparse.
        Default is None, uses original format.

    Returns
    -------
    sym : 2D ndarray or sparse matrix.
        Symmetrized version of `x`. Return `x` if it is already
        symmetric.

    Raises
    ------
    ValueError
        If `x` is not square.
    TypeError
        If `copy` is False and `x` is a dense integer or boolean array,
        which cannot hold the averaged values in place. `x` is left
        unmodified in that case.

    """

    if check and is_symmetric(x, tol=tol):
        return x

    if ssp.issparse(x):
        # scipy sparse matrices have no true in-place addition (``x += x.T``
        # just rebinds `x` to a new matrix), so always build a new matrix and
        # honor `sparse_format` regardless of `copy`.
        if sparse_format is None:
            sparse_format = x.format
        xs = .5 * (x + x.T)
        return getattr(xs, 'to' + sparse_format)(copy=False)

    if copy:
        return .5 * (x + x.T)

    # In-place averaging needs a dtype that can hold x.5 values; fail before
    # ``x += x.T`` so the caller's array is not left half-updated (doubled).
    if x.dtype.kind in 'biu':
        raise TypeError('Cannot symmetrize array of dtype {} in place; '
                        'use copy=True.'.format(x.dtype))

    # NumPy detects the overlap between `x` and its transposed view and
    # buffers the operands, so this in-place update is safe.
    x += x.T
    x *= .5
    return x


def _dominant_set_sparse(s, k, is_thresh=False, norm=False):
    """Compute dominant set for a sparse matrix.

    Builds a sparse matrix that holds only the selected entries of the
    dense input `s`.

    Parameters
    ----------
    s : 2D ndarray
        Dense similarity/affinity matrix.
    k : int or float
        Threshold when `is_thresh` is True; otherwise the number of largest
        entries kept in each row.
    is_thresh : bool, optional
        If True, keep entries strictly greater than `k`.
    norm : bool, optional
        If True, divide each kept entry by the sum of its row.

    Returns
    -------
    csr_matrix
        Sparse matrix with the same shape as `s`.
    """
    if not is_thresh:
        n_rows, n_cols = s.shape
        # After argpartition, the last `k` column indices of each row point
        # to that row's `k` largest values (in no particular order).
        top = np.argpartition(s, n_cols - k, axis=1)[:, -k:]
        rows = np.repeat(np.arange(n_rows), k)
        cols = top.ravel()
        vals = s[rows, cols].ravel()
    else:
        keep = s > k
        rows, cols = np.where(keep)
        vals = s[keep]

    out = ssp.coo_matrix((vals, (rows, cols)), shape=s.shape)

    if norm:
        # Divide every stored entry by the sum of its row.
        row_sums = np.asarray(out.sum(axis=1)).ravel()
        out.data /= row_sums[out.row]

    return out.tocsr(copy=False)


def _dominant_set_dense(s, k, is_thresh=False, norm=False, copy=True):
    """Compute dominant set for a dense matrix.

    Parameters
    ----------
    s : 2D ndarray
        Similarity/affinity matrix.
    k : int or float
        Threshold when `is_thresh` is True; otherwise the number of largest
        entries kept in each row.
    is_thresh : bool, optional
        If True, zero out entries less than or equal to `k`.
    norm : bool, optional
        If True, divide each row by its sum (NaNs ignored in the sum).
    copy : bool, optional
        If True, leave `s` untouched and work on a new array. Otherwise,
        modify `s` in place.

    Returns
    -------
    ndarray
        Array with the non-selected entries set to zero.
    """
    if is_thresh:
        if copy:
            s = s.copy()
        s[s <= k] = 0
    else:
        n_rows, n_cols = s.shape
        # Per row, the last `k` indices point to the `k` largest values.
        order = np.argpartition(s, n_cols - k, axis=1)
        rows = np.arange(n_rows)[:, None]
        if not copy:
            # Zero out everything but the k largest entries, in place.
            s[rows, order[:, :-k]] = 0
        else:
            keep = order[:, -k:]
            kept_vals = s[rows, keep]
            s = np.zeros_like(s)
            s[rows, keep] = kept_vals

    if norm:
        s /= np.nansum(s, axis=1, keepdims=True)

    return s


def dominant_set(s, k, is_thresh=False, norm=False, copy=True, as_sparse=True):
    """Keep largest elements for each row. Zero-out the rest.

    Parameters
    ----------
    s : 2D ndarray
        Similarity/affinity matrix.
    k :  int or float
        If int, keep top `k` elements for each row. If float, keep top `100*k`
        percent of elements. When float, must be in range (0, 1). An int
        larger than the number of columns keeps every element.
    is_thresh : bool, optional
        If True, `k` is used as threshold. Keep elements greater than `k`.
        Default is False.
    norm : bool, optional
        If True, normalize rows. Default is False.
    copy : bool, optional
        If True, make a copy of the input array. Otherwise, work on original
        array. Only relevant when `as_sparse` is False (the sparse output
        never modifies `s`). Default is True.
    as_sparse : bool, optional
        If True, return a sparse matrix. Otherwise, return the same type of the
        input array. Default is True.

    Returns
    -------
    output : 2D ndarray or sparse matrix
        Dominant set.

    Raises
    ------
    ValueError
        If `is_thresh` is False and either `k` is a float outside (0, 1) or
        `k` selects no elements.

    """

    if not is_thresh:
        _, nc = s.shape
        # Accept NumPy floating scalars (e.g. np.float32) as fractions too;
        # np.float64 already subclasses float, np.float32 does not.
        if isinstance(k, (float, np.floating)):
            if not 0 < k < 1:
                raise ValueError('When \'k\' is float, it must be 0<k<1.')
            k = int(nc * k)

        if k <= 0:
            raise ValueError('Cannot select 0 elements.')

        # Asking for more elements than a row holds means keeping them all;
        # without this, argpartition/indexing fail with a shape error.
        k = min(k, nc)

    if as_sparse:
        return _dominant_set_sparse(s, k, is_thresh=is_thresh, norm=norm)

    return _dominant_set_dense(s, k, is_thresh=is_thresh, norm=norm, copy=copy)