# -*- coding: utf-8 -*-
from dataclasses import dataclass
from typing import List
from typing import Optional

import matplotlib.colors as clr
import numpy as np

from nata.plots.types import BasePlot


@dataclass
class ScatterPlot(BasePlot):
    """Scatter plot class.

    Parameters
    ----------
    s: ``float``, optional
        Marker size in points**2. If not provided, defaults to ``0.1``.

    c: ``str``, optional
        Color of the markers. See :mod:`matplotlib.colors` for available
        options.

    marker: ``str``, optional
        Marker style. See :mod:`matplotlib.markers` for available options.

    alpha: ``float``, optional
        Marker alpha value. Must be between ``0`` and ``1``.

    vmin: ``float``, optional
        Minimum of the colorbar axis. If not provided, it is
        inferred from the dataset represented in the plot.

    vmax: ``float``, optional
        Same as ``vmin`` for the maximum of the colorbar axis.

    cb_title: ``str``, optional
        Colorbar title. If not provided, it is inferred from the dataset
        represented in the plot.

    cb_scale: ``{'linear','log', 'symlog'}``, optional
        Scale of the colorbar. If not provided, defaults to ``'linear'``.

    cb_map: ``str``, optional
        Colormap used to represent the data. See
        :func:`matplotlib.pyplot.colormaps` for available options. If not
        provided, defaults to ``rainbow``.

    cb_linthresh: ``float``, optional
        Range within which the colorbar axis is linear. Applicable only when
        ``cb_scale`` is set to ``'symlog'``. If not provided, defaults to
        ``1e-5``.

    Notes
    -----
    All colorbar parameters are only applicable if the dataset represented in
    the plot has a quantity to be represented in color, i.e. if it has more
    than two axes. In this case, ``c`` is overridden if set: it is replaced by
    the third data component, which is then mapped to colors through
    ``cb_map`` and ``cb_scale``.

    """

    s: Optional[float] = 0.1
    # NOTE: annotated as a color string, but replaced by the color data
    # (``self.data.data[2]``) in ``__post_init__`` when ``has_cb`` is true.
    c: Optional[str] = None
    marker: Optional[str] = None
    alpha: Optional[float] = None
    vmin: Optional[float] = None
    vmax: Optional[float] = None
    cb_map: Optional[str] = "rainbow"
    cb_scale: Optional[str] = "linear"
    cb_linthresh: Optional[float] = 1e-5
    cb_title: Optional[str] = None

    def __post_init__(self):
        """Fill in color-related defaults, then run the base initialisation.

        When the dataset has a third axis, its label (with units) becomes the
        colorbar title unless one was given, and the third data component
        replaces ``c``.
        """
        if self.has_cb:
            if self.cb_title is None:
                self.cb_title = self.data.axes[2].get_label(units=True)

            self.c = self.data.data[2]

        super().__post_init__()

    @property
    def has_cb(self):
        """Whether the dataset has more than two axes (i.e. color data)."""
        return len(self.data.axes) > 2

    def _default_xlim(self):
        """Return ``(min, max)`` of the first data component."""
        return (np.min(self.data.data[0]), np.max(self.data.data[0]))

    def _default_ylim(self):
        """Return ``(min, max)`` of the second data component."""
        return (np.min(self.data.data[1]), np.max(self.data.data[1]))

    def _default_xlabel(self, units=True):
        """Return the label of the first axis, forwarding ``units``."""
        return self.data.axes[0].get_label(units)

    def _default_ylabel(self, units=True):
        """Return the label of the second axis, forwarding ``units``."""
        return self.data.axes[1].get_label(units)

    def _default_title(self):
        """Return the time label of the dataset."""
        return self.data.get_time_label()

    def _default_label(self):
        """Return the dataset label without units."""
        return self.data.get_label(units=False)

    def _xunits(self):
        """Return the first axis units wrapped in ``$``, or ``""`` if unset."""
        return f"${self.data.axes[0].units}$" if self.data.axes[0].units else ""

    def _yunits(self):
        """Return the second axis units wrapped in ``$``, or ``""`` if unset."""
        return f"${self.data.axes[1].units}$" if self.data.axes[1].units else ""

    def build_canvas(self):
        """Draw the scatter plot on ``self.axes.ax``.

        The color norm matching ``cb_scale`` is stored in ``self.cb_norm`` and
        the scatter handle in ``self.h``. A colorbar is initialised on the
        axes only when the dataset carries color data (``has_cb``).
        """
        # get plot axes and data
        x = self.data.data[0]
        y = self.data.data[1]

        # build color map norm
        if self.cb_scale == "log":
            self.cb_norm = clr.LogNorm(vmin=self.vmin, vmax=self.vmax)
        elif self.cb_scale == "symlog":
            self.cb_norm = clr.SymLogNorm(
                vmin=self.vmin,
                vmax=self.vmax,
                linthresh=self.cb_linthresh,
                base=10,
            )
        else:
            self.cb_norm = clr.Normalize(vmin=self.vmin, vmax=self.vmax)

        # The colormap and norm only make sense when ``c`` holds data values.
        # Without color data matplotlib ignores them (and recent versions
        # warn about it), so they are only forwarded when needed.
        cb_kwargs = {}
        if self.has_cb:
            cb_kwargs = {"cmap": self.cb_map, "norm": self.cb_norm}

        # build plot
        self.h = self.axes.ax.scatter(
            x,
            y,
            s=self.s,
            c=self.c,
            marker=self.marker,
            alpha=self.alpha,
            label=self.label,
            **cb_kwargs,
        )

        if self.has_cb:
            self.axes.init_colorbar(plot=self)

    @classmethod
    def style_attrs(cls) -> List[str]:
        """Return the style attribute names of this plot and of ``BasePlot``."""
        return [
            "s",
            "c",
            "marker",
            "alpha",
            "vmin",
            "vmax",
            "cb_map",
            "cb_scale",
            "cb_linthresh",
            "cb_title",
        ] + BasePlot.style_attrs()