import numpy as np
from matplotlib import cm
from mindpark.stats.figure import Figure
from mindpark.plot import Scalar, Histogram
from mpl_toolkits.axes_grid1 import make_axes_locatable


class Metrics(Figure):

    """
    Generate figures showing multiple metrics as plots. Choose the plot type
    automatically, based on the data.

    The figure has one column per metric and two rows: the top row shows the
    entries recorded during training, the bottom row the entries recorded
    during evaluation.
    """

    def __init__(self):
        self._plot_scalar = Scalar()
        self._plot_counts = Histogram()
        self._plot_distribution = Histogram(normalize=True)

    def __call__(self, metrics, title, filepath):
        """
        Render the given metrics into a single figure and save it.

        Args:
            metrics: Sequence of ``(name, metric)`` pairs. ``name`` must be a
                string. ``metric`` must provide ``values()`` yielding numpy
                arrays, the fields ``training``, ``epoch`` and ``data``, and
                support boolean-mask indexing.
            title: Title drawn above the whole figure.
            filepath: Destination handed to ``self._save()``.
        """
        self._validate_input(metrics)
        fig, ax = self._create_subplots(2, len(metrics))
        fig.suptitle(title, fontsize=16)
        names, metrics = zip(*metrics)
        self._label_columns(ax, names)
        self._label_rows(ax, ['Training', 'Evaluation'])
        for index, metric in enumerate(metrics):
            train = metric[metric.training == 1]
            test = metric[metric.training == 0]
            # NOTE(review): evaluation epochs are shifted back by one,
            # presumably to line up with the training epoch they follow;
            # confirm against how epochs are recorded.
            test.epoch -= 1
            self._process_metric(ax[0, index], train)
            self._process_metric(ax[1, index], test)
        self._save(fig, filepath)

    def _validate_input(self, metrics):
        """
        Check that ``metrics`` is a sequence of ``(str, metric)`` pairs whose
        metric values are all numpy arrays.

        Uses ``assert``, so the checks are skipped under ``python -O``.
        """
        assert all(len(x) == 2 for x in metrics), \
            'expected a sequence of (name, metric) pairs'
        for name, metric in metrics:
            assert isinstance(name, str), 'metric name must be a string'
            assert all(isinstance(x, np.ndarray) for x in metric.values()), \
                'all metric values must be numpy arrays'

    def _process_metric(self, ax, metric):
        """
        Draw one metric into ``ax``, choosing the plot type from its data.

        - No entries: draw an empty placeholder.
        - One column of non-categorical values: scalar plot over the domain.
        - One column of categorical values: normalized histogram of the
          one-hot encoded categories.
        - Several columns: histogram of the raw column values.
        """
        if not metric.data.size:
            self._draw_empty(ax)
            return
        domain = self._domain(metric)
        columns = metric.data.shape[1]
        if columns == 1 and not self._is_categorical(metric.data):
            self._plot_scalar(ax, domain, metric.data[:, 0])
        elif columns == 1:
            count = self._one_hot(metric.data[:, 0])
            self._plot_distribution(ax, domain, count)
        elif columns > 1:
            self._plot_counts(ax, domain, metric.data)

    def _draw_empty(self, ax):
        """Render a placeholder for a metric that has no entries."""
        # Make tick marks and tick labels fully transparent.
        ax.tick_params(colors=(0, 0, 0, 0))
        # Axes.set_axis_bgcolor() was removed in matplotlib 2.2 in favour of
        # set_facecolor(); fall back to the old name on older versions.
        set_background = (
            getattr(ax, 'set_facecolor', None) or ax.set_axis_bgcolor)
        set_background(self._colormap('viridis')(0))
        # Reserve an invisible strip on the right, presumably matching the
        # colorbar other plots attach there so that columns stay aligned.
        divider = make_axes_locatable(ax)
        divider.append_axes('right', size='7%', pad=0.1).axis('off')

    @staticmethod
    def _colormap(name):
        """Look up a registered colormap across matplotlib versions."""
        # matplotlib.cm.get_cmap() was deprecated in 3.7 and removed in 3.9;
        # the matplotlib.colormaps registry replaces it from 3.5 onwards.
        import matplotlib
        registry = getattr(matplotlib, 'colormaps', None)
        if registry is not None:
            return registry[name]
        return cm.get_cmap(name)

    @staticmethod
    def _one_hot(values):
        """
        One-hot encode near-integer ``values`` into a float matrix with one
        row per value and one column per integer from the minimum to the
        maximum value (inclusive).
        """
        # Round rather than truncate so that e.g. 2.9999999 maps to 3.
        indices = np.round(values).astype(int)
        offsets = indices - indices.min()
        # Fill the matrix directly instead of indexing into np.eye(width),
        # which would allocate width**2 entries for sparse, wide ranges.
        encoded = np.zeros((len(offsets), offsets.max() + 1))
        encoded[np.arange(len(offsets)), offsets] = 1
        return encoded

    def _is_categorical(self, data):
        """
        Return whether ``data`` holds integer values (up to numpy's default
        ``allclose`` tolerance) with at most eight distinct values.
        """
        rounded = np.round(data)
        # Compare against rounded values: truncating via astype(int) would
        # reject values slightly below an integer, such as 2.9999999.
        if not np.allclose(data, rounded):
            return False
        return np.unique(rounded.astype(int)).size <= 8