from typing import Tuple, List, Dict, Optional

import warnings

import matplotlib
import matplotlib.pyplot as plt
from IPython.display import clear_output
from livelossplot.main_logger import MainLogger, LogItem
from livelossplot.outputs.base_output import BaseOutput


class MatplotlibPlot(BaseOutput):
    """Render logged metrics as a grid of matplotlib subplots.

    Each metric group returned by ``logger.grouped_log_history()`` gets its own
    subplot, followed by one subplot per callable in ``extra_plots``.

    NOTE: The ``figsize`` and ``dynamic_x_axis`` options were removed; the
    figure size is derived from ``cell_size`` and the number of grid rows.
    """
    def __init__(
        self,
        cell_size: Tuple[int, int] = (6, 4),
        max_cols: int = 2,
        max_epoch: Optional[int] = None,
        skip_first: int = 2,
        extra_plots: Optional[List] = None,
        figpath: Optional[str] = None
    ):
        """
        Args:
            cell_size: (width, height) of a single subplot cell, in inches.
            max_cols: number of subplot columns in the grid.
            max_epoch: if given, fixes the x-axis range of metric subplots to
                ``[0, max_epoch]``.
            skip_first: stored on the instance; not used by the methods of
                this class (TODO confirm whether callers rely on it).
            extra_plots: callables invoked as ``extra_plot(logger)`` after the
                metric subplots, each drawing into its own subplot.
                Defaults to no extra plots.
            figpath: optional format string; when set, each rendered figure is
                saved to ``figpath.format(i=<counter>)``.
        """
        self.cell_size = cell_size
        self.max_cols = max_cols
        self.max_epoch = max_epoch
        self.skip_first = skip_first  # think about it
        # A fresh list per instance; a shared `[]` default would leak
        # extra plots between instances.
        self.extra_plots = extra_plots if extra_plots is not None else []
        self.figpath = figpath
        self.file_idx = 0  # now only for saving files

    def send(self, logger: MainLogger) -> None:
        """Draw one subplot per metric group plus any extra plots, then show them.

        If ``figpath`` is set, the figure is also saved to
        ``figpath.format(i=self.file_idx)`` and ``file_idx`` is incremented.
        """
        log_groups = logger.grouped_log_history()
        n_plots = len(log_groups) + len(self.extra_plots)
        # Ceiling division: exactly enough rows for every subplot, but at
        # least one so the figure never gets zero height.
        max_rows = max(1, (n_plots + self.max_cols - 1) // self.max_cols)
        # Height follows the same row count as the subplot grid, so every
        # cell is `cell_size` regardless of how many extra plots there are.
        figsize = (self.max_cols * self.cell_size[0], max_rows * self.cell_size[1])

        clear_output(wait=True)
        fig = plt.figure(figsize=figsize)

        for group_id, (group_name, group_logs) in enumerate(log_groups.items()):
            plt.subplot(max_rows, self.max_cols, group_id + 1)
            self._draw_metric_subplot(group_logs, group_name=group_name)

        for i, extra_plot in enumerate(self.extra_plots):
            # Extra plots fill the grid cells right after the metric groups.
            plt.subplot(max_rows, self.max_cols, i + len(log_groups) + 1)
            extra_plot(logger)

        plt.tight_layout()
        if self.figpath is not None:
            fig.savefig(self.figpath.format(i=self.file_idx))
            self.file_idx += 1
        plt.show()
        # A new figure is created on every call; close it so repeated calls
        # do not keep accumulating open figures in pyplot.
        plt.close(fig)

    def _draw_metric_subplot(self, group_logs: Dict[str, List[LogItem]], group_name: str = '') -> None:
        """Plot every non-empty series of one metric group into the current axes.

        Args:
            group_logs: maps a series name to its logged items; each item's
                ``step`` is used for x and ``value`` for y.
            group_name: used as the subplot title.
        """
        # NOTE: skipping the first `skip_first` epochs is not implemented here.
        if self.max_epoch is not None:
            plt.xlim(0, self.max_epoch)

        plotted_any = False
        for name, logs in group_logs.items():
            if len(logs) > 0:
                xs = [log.step for log in logs]
                ys = [log.value for log in logs]
                plt.plot(xs, ys, label=name)
                plotted_any = True

        plt.title(group_name)
        plt.xlabel('epoch')
        # Without any labelled line, matplotlib would only emit a
        # "No artists with labels" warning and draw an empty legend.
        if plotted_any:
            plt.legend(loc='center right')

    def _not_inline_warning(self) -> None:
        """Warn if the active matplotlib backend is not the inline backend."""
        backend = matplotlib.get_backend()
        if "backend_inline" not in backend:
            warnings.warn(
                "livelossplot requires inline plots.\nYour current backend is: {}"
                "\nRun in a Jupyter environment and execute '%matplotlib inline'.".format(backend)
            )