# -*- coding: utf-8 -*-

"""Color routines."""

#------------------------------------------------------------------------------
# Imports
#------------------------------------------------------------------------------

import colorcet as cc
import logging

from phylib.utils import Bunch
from phylib.io.array import _index_of

import numpy as np
from numpy.random import uniform
from matplotlib.colors import hsv_to_rgb, rgb_to_hsv

logger = logging.getLogger(__name__)


#------------------------------------------------------------------------------
# Random colors
#------------------------------------------------------------------------------

def _random_color(h_range=(0., 1.), s_range=(.5, 1.), v_range=(.5, 1.)):
    """Draw a random color and return it as an `(r, g, b)` tuple.

    Hue, saturation and value are each sampled with `uniform()` from the corresponding
    `(low, high)` range, and the resulting HSV triplet is converted to RGB.

    """
    # Sample in the order hue, saturation, value.
    hsv = [uniform(*bounds) for bounds in (h_range, s_range, v_range)]
    # hsv_to_rgb() is given a (1, 1, 3) array; flatten the result back to three scalars.
    red, green, blue = hsv_to_rgb(np.array([[hsv]])).flat
    return red, green, blue


def _is_bright(rgb):
    """Return whether a RGB color is bright or not.
    see https://stackoverflow.com/a/3943023/1595060
    """
    L = 0
    for c, coeff in zip(rgb, (0.2126, 0.7152, 0.0722)):
        if c <= 0.03928:
            c = c / 12.92
        else:
            c = ((c + 0.055) / 1.055) ** 2.4
        L += c * coeff
    if (L + 0.05) / (0.0 + 0.05) > (1.0 + 0.05) / (L + 0.05):
        return True


def _random_bright_color():
    """Return a random RGB color that passes the `_is_bright()` test."""
    # Rejection sampling: keep drawing until a bright color comes up.
    while True:
        candidate = _random_color()
        if _is_bright(candidate):
            return candidate


def _hex_to_triplet(h):
    """Convert an hexadecimal color to a triplet of int8 integers."""
    if h.startswith('#'):
        h = h[1:]
    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))


def _override_hsv(rgb, h=None, s=None, v=None):
    """Return an `(r, g, b)` tuple for `rgb` with some HSV components replaced.

    The color is converted to HSV, every component for which a non-None value is passed
    (`h`, `s`, `v`) is overridden, and the result is converted back to RGB.

    """
    hue, sat, val = rgb_to_hsv(np.array([[rgb]])).flat
    if h is not None:
        hue = h
    if s is not None:
        sat = s
    if v is not None:
        val = v
    red, green, blue = hsv_to_rgb(np.array([[[hue, sat, val]]])).flat
    return red, green, blue


#------------------------------------------------------------------------------
# Colormap utilities
#------------------------------------------------------------------------------

def _selected_cluster_idx(selected_clusters, cluster_ids):
    """Locate the selected clusters that are present in `cluster_ids`.

    Both inputs are converted to int32 arrays. Selected clusters that do not appear in
    `cluster_ids` are dropped.

    Returns
    -------

    clu_idx : array
        The result of `_index_of()` for the kept selected clusters against `cluster_ids`.
    cmap_idx : array
        The positions of the kept clusters within `selected_clusters`.

    """
    selected = np.asarray(selected_clusters, dtype=np.int32)
    all_ids = np.asarray(cluster_ids, dtype=np.int32)
    # Boolean mask of the selected clusters that exist in cluster_ids.
    present = np.isin(selected, all_ids)
    clu_idx = _index_of(selected[present], all_ids)
    positions = np.arange(len(selected))
    return clu_idx, positions[present]


def _continuous_colormap(colormap, values, vmin=None, vmax=None):
    """Convert values into colors given a specified continuous colormap."""
    assert values is not None
    assert colormap.shape[1] == 3
    n = colormap.shape[0]
    vmin = vmin if vmin is not None else values.min()
    vmax = vmax if vmax is not None else values.max()
    assert vmin is not None
    assert vmax is not None
    denom = vmax - vmin
    denom = denom if denom != 0 else 1
    # NOTE: clipping is necessary when a view using color selector (like the raster view)
    # is updated right after a clustering update, but before the vmax had a chance to
    # be updated.
    i = np.clip(np.round((n - 1) * (values - vmin) / denom).astype(np.int32), 0, n - 1)
    return colormap[i, :]


def _categorical_colormap(colormap, values, vmin=None, vmax=None, categorize=None):
    """Convert values into colors given a specified categorical colormap."""
    assert np.issubdtype(values.dtype, np.integer)
    assert colormap.shape[1] == 3
    n = colormap.shape[0]
    if categorize is True or (categorize is None and vmin is None and vmax is None):
        # Find unique values and keep the order.
        _, idx = np.unique(values, return_index=True)
        lookup = values[np.sort(idx)]
        x = _index_of(values, lookup)
    else:
        x = values
    return colormap[x % n, :]


#------------------------------------------------------------------------------
# Colormaps
#------------------------------------------------------------------------------

# Default color map for the selected clusters.
# see https://colorcet.pyviz.org/user_guide/Categorical.html
def _make_default_colormap():
    """Build the default colormap from a colorcet palette, customizing its first colors.

    The palette `glasbey_bw_minc_20_minl_30` is converted to an array, its first six rows
    are permuted, and the first two rows are then overwritten with custom colors.

    """
    palette = np.array(cc.glasbey_bw_minc_20_minl_30)
    # Permute the first six colors.
    palette[:6] = palette[[3, 0, 4, 5, 2, 1]]
    # Overwrite the first two colors.
    palette[0] = [0.03137, 0.5725, 0.9882]
    palette[1] = [1.0000, 0.0078, 0.0078]
    return palette


def _make_cluster_group_colormap():
    """Return cluster group colormap."""
    return np.array([
        [0.4, 0.4, 0.4],  # noise
        [0.5, 0.5, 0.5],  # mua
        [0.5254, 0.8196, 0.42745],  # good
        [0.75, 0.75, 0.75],  # '' (None = '' = unsorted)
    ])


"""Built-in colormaps."""
# Registry of the built-in colormaps, keyed by name. Every entry is built with np.array();
# the colormap helpers in this file expect `(N, 3)` arrays of RGB colors.
colormaps = Bunch(
    # A single color.
    blank=np.array([[.75, .75, .75]]),
    # The colorcet glasbey palette with custom first colors (see _make_default_colormap()).
    default=_make_default_colormap(),
    # One color per cluster group (see _make_cluster_group_colormap()).
    cluster_group=_make_cluster_group_colormap(),
    # The unmodified colorcet glasbey palette.
    categorical=np.array(cc.glasbey_bw_minc_20_minl_30),
    # Other colorcet palettes.
    rainbow=np.array(cc.rainbow_bgyr_35_85_c73),
    linear=np.array(cc.linear_wyor_100_45_c55),
    diverging=np.array(cc.diverging_linear_bjy_30_90_c45),
)


def selected_cluster_color(i, alpha=1.):
    """Return the RGBA 4-tuple of the i-th selected cluster.

    The color is taken from the default colormap, cycling when `i` exceeds its length.

    """
    palette = colormaps.default
    rgb = tuple(palette[i % len(palette)])
    return add_alpha(rgb, alpha=alpha)


def spike_colors(spike_clusters, cluster_ids):
    """Color every spike according to the index of its cluster within `cluster_ids`.

    Parameters
    ----------

    spike_clusters : array-like
        The cluster assignment of every spike.
    cluster_ids : array-like
        The unique selected cluster ids appearing in `spike_clusters`, in a given order.

    Returns
    -------

    spike_colors : array-like
        One RGBA color (in [0,1]) per spike, taken from the default colormap at the index
        of the spike's cluster within `cluster_ids` (cycling through the colormap). The
        alpha channel is left to the default of `add_alpha()`.

    """
    palette = colormaps.default
    n_colors = palette.shape[0]
    cluster_rank = _index_of(spike_clusters, cluster_ids)
    rgb = palette[np.mod(cluster_rank, n_colors)]
    return add_alpha(rgb)


def _add_selected_clusters_colors(selected_clusters, cluster_ids, cluster_colors=None):
    """Take an array with colors of clusters as input, and add colors of selected clusters."""
    # clu_idx contains the index of the selected clusters within cluster_ids
    # cmap_idx contains 0, 1, 2... as the colormap index, but without the selected clusters
    # that are missing in cluster_ids.
    clu_idx, cmap_idx = _selected_cluster_idx(selected_clusters, cluster_ids)
    colormap = _categorical_colormap(colormaps.default, cmap_idx, categorize=False)
    # Inject those colors in cluster_colors.
    cluster_colors[clu_idx] = add_alpha(colormap, 1)
    return cluster_colors


#------------------------------------------------------------------------------
# Cluster color selector
#------------------------------------------------------------------------------

def add_alpha(c, alpha=1.):
    """Return a color, or an array of colors, with its alpha channel set to `alpha`.

    If the input already has four components, the existing alpha channel is replaced.

    Parameters
    ----------

    c : array (last dimension of size 3 or 4) or tuple
    alpha : float

    """
    if isinstance(c, tuple):
        rgb = c[:3] if len(c) == 4 else c
        return rgb + (alpha,)
    if isinstance(c, np.ndarray):
        rgb = c[..., :3] if c.shape[-1] == 4 else c
        assert rgb.shape[-1] == 3
        # Alpha column with the same leading dimensions as the colors.
        alpha_channel = alpha * np.ones(rgb.shape[:-1] + (1,))
        rgba = np.concatenate([rgb, alpha_channel], axis=-1)
        assert rgba.ndim == rgb.ndim
        assert rgba.shape[-1] == rgb.shape[-1] + 1
        return rgba
    raise ValueError("Unknown value given in add_alpha().")  # pragma: no cover


def _categorize(values):
    """Categorize a list of values by replacing strings and None values by integers."""
    if any(isinstance(v, str) for v in values):
        # HACK: replace None by empty string to avoid error when sorting the unique values.
        values = [str(v).lower() if v is not None else '' for v in values]
        uv = sorted(set(values))
        values = [uv.index(v) for v in values]
    return values


class ClusterColorSelector(object):
    """Assign a color to clusters depending on cluster labels or metrics.

    Parameters
    ----------

    fun : function
        Function cluster_id => value. A non-callable is used as a constant value for all
        clusters.
    colormap : array-like or str
        A `(N, 3)` array with the colormap colors, or the name of a built-in colormap.
    categorical : boolean
        Whether the colormap is categorical or continuous.
    logarithmic : boolean
        Whether to use a logarithmic transform for the mapping.
    cluster_ids : array-like
        The clusters used to compute the value range.

    """
    # Defaults, used when the corresponding option is not specified.
    _colormap = colormaps.categorical
    _categorical = True
    _logarithmic = False

    def __init__(
            self, fun=None, colormap=None, categorical=None, logarithmic=None, cluster_ids=None):
        self.cluster_ids = cluster_ids if cluster_ids is not None else ()
        self._fun = fun
        self.set_color_mapping(
            fun=fun, colormap=colormap, categorical=categorical, logarithmic=logarithmic)

    def set_color_mapping(
            self, fun=None, colormap=None, categorical=None, logarithmic=None):
        """Set the field used to choose the cluster colors, and the associated colormap.

        Every option left to None keeps its current value.

        Parameters
        ----------

        fun : function
            Function cluster_id => value
        colormap : array-like or str
            A `(N, 3)` array with the colormaps colors, or the name of a built-in colormap
        categorical : boolean
            Whether the colormap is categorical (one value = one color) or continuous (values
            are continuously mapped from their initial interval to the colors).
        logarithmic : boolean
            Whether to use a logarithmic transform for the mapping.

        """
        # A newly specified function replaces the current one.
        if fun is not None:
            self._fun = fun
        if isinstance(colormap, str):
            colormap = colormaps[colormap]
        if colormap is not None:
            self._colormap = colormap
        if categorical is not None:
            self._categorical = categorical
        if logarithmic is not None:
            self._logarithmic = logarithmic
        # Recompute the value range.
        self.set_cluster_ids(self.cluster_ids)

    def set_cluster_ids(self, cluster_ids):
        """Store the cluster ids and precompute the value range for all of them."""
        self.cluster_ids = cluster_ids
        values = self.get_values(self.cluster_ids)
        if values is not None and len(values):
            self.vmin, self.vmax = values.min(), values.max()
        else:  # pragma: no cover
            # No cluster: fall back to a default range.
            self.vmin, self.vmax = 0, 1

    def map(self, values):
        """Convert values to colors using the selected colormap.

        Parameters
        ----------

        values : array-like (1D)

        Returns
        -------

        colors : array-like (2D, shape[1] == 3)

        """
        if self._logarithmic:
            assert np.all(values > 0)
            # Transform both the values and the precomputed value range.
            values = np.log(values)
            vmin, vmax = np.log(self.vmin), np.log(self.vmax)
        else:
            vmin, vmax = self.vmin, self.vmax
        assert values is not None
        # Use categorical or continuous colormap depending on the categorical option.
        # Non-integer values always use the continuous colormap.
        f = (_categorical_colormap
             if self._categorical and np.issubdtype(values.dtype, np.integer)
             else _continuous_colormap)
        return f(self._colormap, values, vmin=vmin, vmax=vmax)

    def _get_cluster_value(self, cluster_id):
        """Return the field value for a given cluster."""
        if callable(self._fun):
            return self._fun(cluster_id)
        # Not a function: use it as a constant value, defaulting to 0.
        return self._fun or 0

    def get(self, cluster_id, alpha=None):
        """Return the RGBA color of a single cluster.

        Parameters
        ----------

        cluster_id : int
        alpha : float
            The alpha channel. None (the default) means 1.

        """
        assert self.cluster_ids is not None
        assert self._colormap is not None
        val = [self._get_cluster_value(cluster_id)]
        if self._categorical:
            # NOTE(review): a string value is categorized on its own here, so it is always
            # mapped to category 0, whatever the values of the other clusters.
            val = _categorize(val)
        col = tuple(self.map(np.array(val))[0].tolist())
        # Never put None in the alpha channel of the returned color.
        return add_alpha(col, alpha=alpha if alpha is not None else 1.)

    def get_values(self, cluster_ids):
        """Get the values of clusters for the selected color field, as an array.

        With a categorical colormap, the values are passed through `_categorize()`.

        """
        values = [self._get_cluster_value(cluster_id) for cluster_id in cluster_ids]
        if self._categorical:
            values = _categorize(values)
        return np.array(values)

    def get_colors(self, cluster_ids, alpha=1.):
        """Return the RGBA colors of some clusters, one row per cluster."""
        values = self.get_values(cluster_ids)
        assert values is not None
        assert len(values) == len(cluster_ids)
        return add_alpha(self.map(values), alpha=alpha)