# Public colouring strategies exported by `from <module> import *`.
# The abstract Colour base class is not listed, so it is not exported by a star import.
__all__ = ['GradientColour', 'CmapColour', 'KMeansColour', 'GreyscaleMeanColour']

from typing import Union, Iterable, Tuple, List
from abc import ABC, abstractmethod

import cv2
import numpy as np
import matplotlib as mpl


class Colour(ABC):
    """
    Abstract base for line-colouring strategies passed to SpeckPlot.draw.

    Subclasses get __repr__, __hash__ and __eq__ derived automatically from their
    public instance attributes (names not starting with an underscore).
    """

    def __repr__(self):
        """
        Auto __repr__ based on instance __dict__
        Parameters with a leading _ are omitted from the repr and thus from the hash
        """

        d = [f'{k}={v}' for k, v in self._public_items()]
        return f'{self.__class__.__name__}({", ".join(d)})'

    def _public_items(self):
        """Return (name, value) pairs of instance attributes whose names do not start with '_'."""
        return [(k, v) for k, v in self.__dict__.items() if not k.startswith('_')]

    @staticmethod
    def _freeze(value):
        """
        Convert an attribute value into a hashable, comparable equivalent.

        Lists and tuples are frozen element-wise, so nested colour lists such as
        [[1, 0, 0], [0, 0, 1]] work. NumPy arrays (e.g. image data) are reduced to
        their shape, dtype and raw bytes. Any other unhashable object falls back to
        its repr. Already-hashable values are returned unchanged, so hashes of
        previously supported attribute values are not altered.
        """
        if isinstance(value, (list, tuple)):
            return tuple(Colour._freeze(v) for v in value)
        if isinstance(value, np.ndarray):
            return (value.shape, value.dtype.str, value.tobytes())
        try:
            hash(value)
        except TypeError:
            return repr(value)
        return value

    def _key(self):
        """Return the hashable identity tuple built from the public attributes."""
        return tuple(self._freeze(v) for _, v in self._public_items())

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        # Non-Colour objects are left to Python's default comparison (-> False)
        # instead of being hashed, which raised for unhashable operands.
        if not isinstance(other, Colour):
            return NotImplemented
        # Compare the actual keys, not just their hashes, so hash collisions and
        # different subclasses with equal attributes are not treated as equal.
        return type(self) is type(other) and self._key() == other._key()

    @abstractmethod
    def __call__(self, m: int) -> Iterable[Tuple[float, ...]]:
        """
        Return colour tuples for the lines to be drawn.
        :param m: requested number of lines (some subclasses derive the count from their image instead)
        """
        pass


class GradientColour(Colour):
    def __init__(self, colour_list: Union[List, Tuple]):
        """
        Create GradientColour object to be passed to SpeckPlot.draw method.
        Colours each line according to a generated colour between the provided checkpoint colours.
        :param colour_list: colours between which colour gradients are generated to colour each line;
            a single colour yields the same colour for every line
        :raises ValueError: if colour_list is empty
        """

        if len(colour_list) == 0:
            # Fail with a clear message instead of an IndexError from inside matplotlib.
            raise ValueError('GradientColour requires at least one colour')
        if len(colour_list) == 1:
            # Duplicate the single colour so the gradient has a start and end stop and is flat.
            colour_list = [colour_list[0], colour_list[0]]
        self.colour_list = colour_list
        self._cmap = mpl.colors.LinearSegmentedColormap.from_list("", colour_list)

    def __call__(self, m: int) -> Iterable[Tuple[float, ...]]:
        """
        Return m RGBA tuples sampled evenly along the gradient.
        Samples are taken at i / m for i in range(m) (endpoint=False), so the last
        checkpoint colour is approached but not reached exactly.
        """
        return [self._cmap(x) for x in np.linspace(0, 1, m, endpoint=False)]


class CmapColour(Colour):
    def __init__(self, cmap: Union[str, mpl.colors.Colormap]):
        """
        Create CmapColour object to be passed to SpeckPlot.draw method.
        Colours each line according to pre-defined matplotlib cmap.
        :param cmap: matplotlib cmap object or name to generate line colours according to
        :raises ValueError: if cmap is a string that is not a registered colormap name
        """

        self.cmap = self._lookup_cmap(cmap) if isinstance(cmap, str) else cmap

    @staticmethod
    def _lookup_cmap(name: str) -> mpl.colors.Colormap:
        """Resolve a registered colormap name, preferring the registry that replaced cm.get_cmap."""
        registry = getattr(mpl, 'colormaps', None)
        if registry is None:
            # Older matplotlib without the `colormaps` registry: use the legacy API.
            return mpl.cm.get_cmap(name)
        try:
            return registry[name]
        except KeyError as err:
            # The registry raises KeyError; keep the ValueError that cm.get_cmap raised.
            raise ValueError(f'{name!r} is not a known colormap name') from err

    def __call__(self, m: int) -> Iterable[Tuple[float, ...]]:
        """Return m RGBA tuples sampled at i / m for i in range(m) from the colormap."""
        return [self.cmap(x) for x in np.linspace(0, 1, m, endpoint=False)]


class KMeansColour(Colour):
    # Stop k-means after 200 iterations or once centres move less than 0.1.
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 200, 0.1)
    flags = cv2.KMEANS_RANDOM_CENTERS

    def __init__(self, speck_plot, k: int = 5):
        """
        Create KMeansColour object to be passed to SpeckPlot.draw method.
        Clusters each horizontal line of pixel colour values into k groups using k-means to determine
        the dominant colour of that row, and then sets the line colour to that.
        :param speck_plot: SpeckPlot object to base colours on
        :param k: number of groups for k-means
        :raises AssertionError: if the image mode is not 'RGB' or 'RGBA'
        """

        if speck_plot.image.mode not in ('RGB', 'RGBA'):
            raise AssertionError('KMeansColour requires RGB image mode')

        # RGBA input is accepted, but the alpha channel is dropped by convert('RGB').
        # NOTE(review): speck_plot.image is presumably a PIL Image (it exposes .mode/.convert),
        # making self.im a (height, width, 3) array — confirm.
        self.im = np.array(speck_plot.image.convert('RGB'))
        self.k = k

    def _kmeans_colour(self, row: np.ndarray) -> Tuple[float, ...]:
        """
        Return the dominant colour of a single pixel row as an RGB tuple scaled to [0, 1].
        :param row: pixels of one image row, reshaped here to (n_pixels, 3) with values 0-255
        """
        # cv2.kmeans expects one float32 sample per row: here one pixel with 3 features.
        samples = np.ascontiguousarray(row, dtype=np.float32).reshape(-1, 3)
        # cv2.kmeans cannot form more clusters than there are samples.
        k = max(1, min(self.k, len(samples)))
        _, labels, palette = cv2.kmeans(samples, k, None, self.criteria, 10, self.flags)
        # bincount is indexed by label, so empty clusters cannot misalign counts and palette.
        counts = np.bincount(labels.ravel(), minlength=k)
        return tuple(float(c) for c in palette[np.argmax(counts)] / 255)

    def __call__(self, m: int) -> Iterable[Tuple[float, ...]]:
        """
        Return one dominant colour per image row, from top to bottom.
        NOTE(review): m is not used; the number of colours equals the image height.
        """
        pixels = self.im.astype(np.float32)
        # Cluster whole pixels of each row together, not each colour channel separately.
        return [self._kmeans_colour(row) for row in pixels]


class GreyscaleMeanColour(Colour):
    def __init__(self, speck_plot):
        """
        Create GreyscaleMeanColour object to be passed to SpeckPlot.draw method.
        Colours each line with a grey whose intensity is the mean pixel value of the matching image row.
        :param speck_plot: SpeckPlot object to base colours on (its `im` attribute is used)
        """

        self.im = speck_plot.im

    def __call__(self, m: int) -> Iterable[Tuple[float, ...]]:
        """
        Return one (v, v, v) tuple per row of self.im, where v is the row mean divided by 255.
        NOTE(review): m is not used; assumes self.im is a 2-D greyscale image — confirm.
        """
        row_means = np.array(self.im).mean(axis=1) / 255.0
        return [(value,) * 3 for value in row_means]