# -*- coding: UTF8 -*-
"""
some plotting functionality for different tasks
author: Michael Grupp

This file is part of evo (github.com/MichaelGrupp/evo).

evo is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

evo is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with evo.  If not, see <http://www.gnu.org/licenses/>.
"""

from __future__ import print_function  # Python 2.7 backwards compatibility

import os
import logging
import pickle
import collections
from enum import Enum

import matplotlib as mpl
from evo.tools.settings import SETTINGS

mpl.use(SETTINGS.plot_backend)
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.art3d as art3d
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.collections import LineCollection

import numpy as np
import seaborn as sns

from evo import EvoException
from evo.tools import user
from evo.core import trajectory

# configure matplotlib and seaborn according to package settings
# TODO: 'color_codes=False' to work around this bug:
# https://github.com/mwaskom/seaborn/issues/1546
# NOTE(review): sns.set() writes matplotlib rcParams itself (e.g. the font
# family), so the overrides in `rc` below are applied afterwards to take
# precedence - keep this order.
sns.set(style=SETTINGS.plot_seaborn_style, font=SETTINGS.plot_fontfamily,
        font_scale=SETTINGS.plot_fontscale, color_codes=False,
        palette=SETTINGS.plot_seaborn_palette)
# global matplotlib overrides taken from the package settings: line width,
# LaTeX text rendering, font family and the TeX system for the pgf backend
rc = {
    "lines.linewidth": SETTINGS.plot_linewidth,
    "text.usetex": SETTINGS.plot_usetex,
    "font.family": SETTINGS.plot_fontfamily,
    "pgf.texsystem": SETTINGS.plot_texsystem
}
mpl.rcParams.update(rc)

# module-level logger named after this module
logger = logging.getLogger(__name__)


class PlotException(EvoException):
    """Raised by the plotting helpers of this module on invalid input."""


class PlotMode(Enum):
    """
    Axis layout of a trajectory plot.

    The letters of each value name the data axes shown on the horizontal
    and vertical plot axis (and, for xyz, the depth axis), in that order.
    """
    # 2D modes: first letter -> horizontal axis, second -> vertical axis
    xy = "xy"
    xz = "xz"
    yx = "yx"
    yz = "yz"
    zx = "zx"
    zy = "zy"
    # 3D mode
    xyz = "xyz"


class PlotCollection:
    """
    Ordered collection of named matplotlib figures that can be shown in one
    tabbed window, serialized with pickle or exported to files.
    """

    def __init__(self, title="", deserialize=None):
        """
        :param title: collection/window title; line breaks become spaces
        :param deserialize: optional path to a file written by serialize()
                            from which the figures are loaded
        """
        self.title = " ".join(title.splitlines())  # one line title
        self.figures = collections.OrderedDict()  # remember placement order
        # hack to avoid premature garbage collection with Qt
        # (stackoverflow.com/questions/600289)
        self.root_window = None  # for now: init later in tabbed_qt_window
        if deserialize is not None:
            logger.debug("Deserializing PlotCollection from " + deserialize +
                         "...")
            # NOTE: unpickling can execute arbitrary code - only load files
            # from trusted sources.
            with open(deserialize, 'rb') as pickled_figures:
                self.figures = pickle.load(pickled_figures)

    def __str__(self):
        return self.title + " (" + str(len(self.figures)) + " figure(s))"

    def add_figure(self, name, fig):
        """
        Tighten the figure's layout and store it under `name`; an existing
        figure with the same name is replaced.
        """
        fig.tight_layout()
        self.figures[name] = fig

    @staticmethod
    def _enable_3d_mouse_rotation(fig):
        """Explicitly allow mouse dragging for all 3D axes of a figure."""
        for axes in fig.get_axes():
            if isinstance(axes, Axes3D):
                axes.mouse_init()

    def tabbed_qt4_window(self):
        """Show all figures as tabs of one PyQt4 window (blocks until closed)."""
        from PyQt4 import QtGui
        from matplotlib.backends.backend_qt4agg import FigureCanvasQTAgg, NavigationToolbar2QT
        # mpl backend can already create instance
        # https://stackoverflow.com/a/40031190
        app = QtGui.QApplication.instance()
        if app is None:
            app = QtGui.QApplication([self.title])
        self.root_window = QtGui.QTabWidget()
        self.root_window.setWindowTitle(self.title)
        for name, fig in self.figures.items():
            tab = QtGui.QWidget(self.root_window)
            tab.canvas = FigureCanvasQTAgg(fig)
            vbox = QtGui.QVBoxLayout(tab)
            vbox.addWidget(tab.canvas)
            toolbar = NavigationToolbar2QT(tab.canvas, tab)
            vbox.addWidget(toolbar)
            tab.setLayout(vbox)
            self._enable_3d_mouse_rotation(fig)
            self.root_window.addTab(tab, name)
        self.root_window.show()
        app.exec_()

    def tabbed_qt5_window(self):
        """Show all figures as tabs of one PyQt5 window (blocks until closed)."""
        from PyQt5 import QtGui, QtWidgets
        from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT
        # mpl backend can already create instance
        # https://stackoverflow.com/a/40031190
        app = QtGui.QGuiApplication.instance()
        if app is None:
            app = QtWidgets.QApplication([self.title])
        self.root_window = QtWidgets.QTabWidget()
        self.root_window.setWindowTitle(self.title)
        for name, fig in self.figures.items():
            tab = QtWidgets.QWidget(self.root_window)
            tab.canvas = FigureCanvasQTAgg(fig)
            vbox = QtWidgets.QVBoxLayout(tab)
            vbox.addWidget(tab.canvas)
            toolbar = NavigationToolbar2QT(tab.canvas, tab)
            vbox.addWidget(toolbar)
            tab.setLayout(vbox)
            self._enable_3d_mouse_rotation(fig)
            self.root_window.addTab(tab, name)
        self.root_window.show()
        app.exec_()

    def tabbed_tk_window(self):
        """Show all figures as tabs of one Tk window (blocks until closed)."""
        from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
        import sys
        if sys.version_info[0] < 3:
            import Tkinter as tkinter
            import ttk
        else:
            import tkinter
            from tkinter import ttk
        self.root_window = tkinter.Tk()
        self.root_window.title(self.title)
        # quit if the window is deleted
        self.root_window.protocol("WM_DELETE_WINDOW", self.root_window.quit)
        nb = ttk.Notebook(self.root_window)
        nb.grid(row=1, column=0, sticky='NESW')
        for name, fig in self.figures.items():
            fig.tight_layout()
            tab = ttk.Frame(nb)
            canvas = FigureCanvasTkAgg(fig, master=tab)
            canvas.draw()
            canvas.get_tk_widget().pack(side=tkinter.TOP, fill=tkinter.BOTH,
                                        expand=True)
            toolbar = NavigationToolbar2Tk(canvas, tab)
            toolbar.update()
            canvas.get_tk_widget().pack(side=tkinter.TOP, fill=tkinter.BOTH,
                                        expand=True)
            self._enable_3d_mouse_rotation(fig)
            nb.add(tab, text=name)
        nb.pack(side=tkinter.TOP, fill=tkinter.BOTH, expand=True)
        self.root_window.mainloop()
        self.root_window.destroy()

    def show(self):
        """
        Show the figures: as one tabbed window for the Qt4/Qt5/Tk backends
        unless SETTINGS.plot_split is set, otherwise via plt.show().
        Does nothing if the collection is empty.
        """
        if not self.figures:
            return
        if SETTINGS.plot_split:
            plt.show()
            return
        backend = SETTINGS.plot_backend.lower()
        if backend == "qt4agg":
            self.tabbed_qt4_window()
        elif backend == "qt5agg":
            self.tabbed_qt5_window()
        elif backend == "tkagg":
            self.tabbed_tk_window()
        else:
            plt.show()

    def serialize(self, dest, confirm_overwrite=True):
        """
        Pickle the figures dictionary to the file `dest`.
        :param dest: destination file path
        :param confirm_overwrite: ask the user before overwriting a file
        """
        logger.debug("Serializing PlotCollection to " + dest + "...")
        if confirm_overwrite and not user.check_and_confirm_overwrite(dest):
            return
        # the context manager guarantees the pickle is flushed and closed
        with open(dest, 'wb') as pickle_file:
            pickle.dump(self.figures, pickle_file)

    def export(self, file_path, confirm_overwrite=True):
        """
        Save the figures in SETTINGS.plot_export_format. For "pdf" without
        SETTINGS.plot_split, all figures go into one multi-page PDF at
        `file_path`; otherwise each figure is written to
        <base>_<figure name><ext> derived from `file_path`.
        :param file_path: destination path
        :param confirm_overwrite: ask the user before overwriting a file;
                                  declining stops the export
        """
        fmt = SETTINGS.plot_export_format.lower()
        if fmt == "pdf" and not SETTINGS.plot_split:
            if confirm_overwrite and not user.check_and_confirm_overwrite(
                    file_path):
                return
            import matplotlib.backends.backend_pdf
            # the context manager finalizes the PDF even if savefig fails
            with matplotlib.backends.backend_pdf.PdfPages(file_path) as pdf:
                for fig in self.figures.values():
                    # fig.tight_layout()  # TODO
                    pdf.savefig(fig)
            logger.info("Plots saved to " + file_path)
        else:
            base, ext = os.path.splitext(file_path)
            for name, fig in self.figures.items():
                dest = base + '_' + name + ext
                if confirm_overwrite and not user.check_and_confirm_overwrite(
                        dest):
                    return
                fig.tight_layout()
                # 'format' is the savefig keyword; 'fmt' was silently ignored
                fig.savefig(dest, format=fmt)
                logger.info("Plot saved to " + dest)


def set_aspect_equal_3d(ax):
    """
    Emulate an equal aspect ratio on a 3D axis: every axis is given the same
    half-range (the largest one among x, y and z), centered on the midpoint
    of its current limits.
    kudos to https://stackoverflow.com/a/35126679
    :param ax: matplotlib 3D axes object
    """
    limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
    centers = [np.mean(lim) for lim in limits]

    radius = max(
        abs(bound - center)
        for lim, center in zip(limits, centers)
        for bound in lim
    )

    x_center, y_center, z_center = centers
    ax.set_xlim3d([x_center - radius, x_center + radius])
    ax.set_ylim3d([y_center - radius, y_center + radius])
    ax.set_zlim3d([z_center - radius, z_center + radius])


def prepare_axis(fig, plot_mode=PlotMode.xy, subplot_arg="111"):
    """
    prepares an axis according to the plot mode (for trajectory plotting)
    :param fig: matplotlib figure object
    :param plot_mode: PlotMode
    :param subplot_arg: optional if using subplots - the subplot id (e.g. '122')
    :return: the matplotlib axis (a 3D axis for PlotMode.xyz)
    """
    if plot_mode == PlotMode.xyz:
        ax = fig.add_subplot(subplot_arg, projection="3d", aspect="equal")
    else:
        ax = fig.add_subplot(subplot_arg, aspect="equal")
    # configure the new axis directly: plt.axis() / plt.gca() act on
    # pyplot's current figure, which is not necessarily `fig`
    ax.axis("equal")
    # the PlotMode value spells out the plotted data axes in order,
    # e.g. "zx" -> z on the horizontal axis, x on the vertical axis
    ax.set_xlabel("${}$ (m)".format(plot_mode.value[0]))
    ax.set_ylabel("${}$ (m)".format(plot_mode.value[1]))
    if plot_mode == PlotMode.xyz:
        ax.set_zlabel('$z$ (m)')
    if SETTINGS.plot_invert_xaxis:
        ax.invert_xaxis()
    if SETTINGS.plot_invert_yaxis:
        ax.invert_yaxis()
    return ax


def plot_mode_to_idx(plot_mode):
    """
    map a plot mode to the columns of an xyz position array that are drawn
    on the plot's axes
    :param plot_mode: PlotMode
    :return: tuple (x_idx, y_idx, z_idx) with the column indices for the
             horizontal, vertical and depth axis; z_idx is None unless
             plot_mode is PlotMode.xyz
    :raises PlotException: if plot_mode is not a PlotMode member
    """
    # PlotMode.xyz uses the same horizontal/vertical columns as PlotMode.xy
    columns = {
        PlotMode.xy: (0, 1),
        PlotMode.xyz: (0, 1),
        PlotMode.xz: (0, 2),
        PlotMode.yx: (1, 0),
        PlotMode.yz: (1, 2),
        PlotMode.zx: (2, 0),
        PlotMode.zy: (2, 1),
    }
    if plot_mode not in columns:
        # previously this fell through to an UnboundLocalError
        raise PlotException("unsupported plot mode: " + str(plot_mode))
    x_idx, y_idx = columns[plot_mode]
    z_idx = 2 if plot_mode == PlotMode.xyz else None
    return x_idx, y_idx, z_idx


def traj(ax, plot_mode, traj, style='-', color='black', label="", alpha=1.0):
    """
    draw the positions of a path/trajectory as a line into an axis
    :param ax: the matplotlib axis (a 3D axis if plot_mode is PlotMode.xyz)
    :param plot_mode: PlotMode selecting which coordinates are drawn
    :param traj: trajectory.PosePath3D or trajectory.PoseTrajectory3D object
    :param style: matplotlib line style
    :param color: matplotlib color
    :param label: label (for legend); the legend is only drawn if non-empty
    :param alpha: alpha value for transparency
    """
    x_idx, y_idx, z_idx = plot_mode_to_idx(plot_mode)
    xyz = traj.positions_xyz
    is_3d = plot_mode == PlotMode.xyz

    coords = [xyz[:, x_idx], xyz[:, y_idx]]
    if is_3d:
        coords.append(xyz[:, z_idx])
    ax.plot(*(coords + [style]), color=color, label=label, alpha=alpha)

    if is_3d and SETTINGS.plot_xyz_realistic:
        set_aspect_equal_3d(ax)
    if label:
        ax.legend(frameon=True)


def colored_line_collection(xyz, colors, plot_mode=PlotMode.xy,
                            linestyles="solid"):
    """
    build a line collection whose i-th segment connects positions i and i+1
    of xyz and is drawn with colors[i]
    :param xyz: array of xyz positions (one row per position)
    :param colors: matplotlib colors, one per position (N positions form
                   only N-1 segments)
    :param plot_mode: PlotMode selecting the drawn coordinates
    :param linestyles: matplotlib line style(s) of the segments
    :return: a LineCollection (2D modes) or Line3DCollection (PlotMode.xyz)
    :raises PlotException: if xyz and colors differ in length
    """
    if len(xyz) != len(colors):
        raise PlotException(
            "color values must have same length as xyz data: %d vs. %d" %
            (len(xyz), len(colors)))
    x_idx, y_idx, z_idx = plot_mode_to_idx(plot_mode)
    is_3d = plot_mode == PlotMode.xyz
    columns = [x_idx, y_idx, z_idx] if is_3d else [x_idx, y_idx]
    # one row per position with only the plotted coordinates
    points = np.asarray(xyz)[:, columns]
    # vectorized segment construction: segs[i] = [points[i], points[i + 1]],
    # i.e. an array of shape (N-1, 2, number of plotted coordinates)
    segs = np.stack((points[:-1], points[1:]), axis=1)
    if is_3d:
        return art3d.Line3DCollection(segs, colors=colors,
                                      linestyles=linestyles)
    return LineCollection(segs, colors=colors, linestyles=linestyles)


def traj_colormap(ax, traj, array, plot_mode, min_map, max_map, title=""):
    """
    color map a path/trajectory in xyz coordinates according to
    an array of values
    :param ax: plot axis
    :param traj: trajectory.PosePath3D or trajectory.PoseTrajectory3D object
    :param array: Nx1 array of values used for color mapping
    :param plot_mode: PlotMode
    :param min_map: lower bound value for color mapping
    :param max_map: upper bound value for color mapping
    :param title: plot title (if given, a legend is drawn as well)
    """
    pos = traj.positions_xyz
    # values outside [min_map, max_map] are clipped to the colormap ends
    norm = mpl.colors.Normalize(vmin=min_map, vmax=max_map, clip=True)
    mapper = cm.ScalarMappable(
        norm=norm,
        cmap=SETTINGS.plot_trajectory_cmap)  # cm.*_r is reversed cmap
    mapper.set_array(array)
    colors = [mapper.to_rgba(a) for a in array]
    line_collection = colored_line_collection(pos, colors, plot_mode)
    ax.add_collection(line_collection)
    if plot_mode == PlotMode.xyz:
        # add_collection does not rescale the z axis
        ax.set_zlim(np.amin(pos[:, 2]), np.amax(pos[:, 2]))
        if SETTINGS.plot_xyz_realistic:
            set_aspect_equal_3d(ax)
    # use the figure that owns `ax` and let the colorbar take its space from
    # `ax` - not from whatever pyplot considers the current figure/axes
    fig = ax.get_figure()
    mid_map = max_map - (max_map - min_map) / 2
    ticks = [min_map, mid_map, max_map]
    cbar = fig.colorbar(mapper, ax=ax, ticks=ticks)
    cbar.ax.set_yticklabels(["{0:0.3f}".format(tick) for tick in ticks])
    if title:
        ax.legend(frameon=True)
        ax.set_title(title)


def traj_xyz(axarr, traj, style='-', color='black', label="", alpha=1.0,
             start_timestamp=None):
    """
    plot the x, y and z coordinates of a path/trajectory into three axes
    :param axarr: an axis array (for x, y & z)
                  e.g. from 'fig, axarr = plt.subplots(3)'
    :param traj: trajectory.PosePath3D or trajectory.PoseTrajectory3D object
    :param style: matplotlib line style
    :param color: matplotlib color
    :param label: label (for legend)
    :param alpha: alpha value for transparency
    :param start_timestamp: optional start time of the reference
                            (for x-axis alignment); defaults to the first
                            timestamp of traj
    :raises PlotException: if axarr does not hold exactly 3 axes
    """
    if len(axarr) != 3:
        raise PlotException("expected an axis array with 3 subplots - got " +
                            str(len(axarr)))
    if isinstance(traj, trajectory.PoseTrajectory3D):
        t_zero = (traj.timestamps[0] if start_timestamp is None
                  else start_timestamp)
        x_values, x_label = traj.timestamps - t_zero, "$t$ (s)"
    else:
        # paths carry no timestamps - use the pose index instead
        x_values, x_label = range(0, len(traj.positions_xyz)), "index"
    y_labels = ("$x$ (m)", "$y$ (m)", "$z$ (m)")
    xyz = traj.positions_xyz
    for column, (axis, y_label) in enumerate(zip(axarr, y_labels)):
        axis.plot(x_values, xyz[:, column], style, color=color, label=label,
                  alpha=alpha)
        axis.set_ylabel(y_label)
    axarr[2].set_xlabel(x_label)
    if label:
        axarr[0].legend(frameon=True)


def traj_rpy(axarr, traj, style='-', color='black', label="", alpha=1.0,
             start_timestamp=None):
    """
    plot the Euler roll/pitch/yaw angles of a path/trajectory (in degrees)
    into three axes
    :param axarr: an axis array (for R, P & Y)
                  e.g. from 'fig, axarr = plt.subplots(3)'
    :param traj: trajectory.PosePath3D or trajectory.PoseTrajectory3D object
    :param style: matplotlib line style
    :param color: matplotlib color
    :param label: label (for legend)
    :param alpha: alpha value for transparency
    :param start_timestamp: optional start time of the reference
                            (for x-axis alignment); defaults to the first
                            timestamp of traj
    :raises PlotException: if axarr does not hold exactly 3 axes
    """
    axis_count = len(axarr)
    if axis_count != 3:
        raise PlotException("expected an axis array with 3 subplots - got " +
                            str(axis_count))
    euler = traj.orientations_euler
    if isinstance(traj, trajectory.PoseTrajectory3D):
        offset = (traj.timestamps[0] if start_timestamp is None
                  else start_timestamp)
        x_values = traj.timestamps - offset
        x_label = "$t$ (s)"
    else:
        # paths carry no timestamps - use the pose index instead
        x_values = range(0, len(euler))
        x_label = "index"
    angles_deg = np.rad2deg(euler)  # stored in radians, plotted in degrees
    y_labels = ("$roll$ (deg)", "$pitch$ (deg)", "$yaw$ (deg)")
    for column, (axis, y_label) in enumerate(zip(axarr, y_labels)):
        axis.plot(x_values, angles_deg[:, column], style, color=color,
                  label=label, alpha=alpha)
        axis.set_ylabel(y_label)
    axarr[2].set_xlabel(x_label)
    if label:
        axarr[0].legend(frameon=True)


def trajectories(fig, trajectories, plot_mode=PlotMode.xy, title="",
                 subplot_arg="111"):
    """
    high-level function for plotting multiple trajectories
    :param fig: matplotlib figure
    :param trajectories: instances or iterables of PosePath3D or derived
    - if it's a dictionary, the keys (names) will be used as labels
    :param plot_mode: e.g. plot.PlotMode.xy
    :param title: optional plot title
    :param subplot_arg: optional matplotlib subplot ID if used as subplot
    """
    ax = prepare_axis(fig, plot_mode, subplot_arg)
    if title:
        ax.set_title(title)

    # normalize the accepted input types to (label, trajectory) pairs;
    # materializing a list also provides a length for generators and for a
    # single trajectory, which the colormap below needs
    if isinstance(trajectories, trajectory.PosePath3D):
        named_trajectories = [("", trajectories)]
    elif isinstance(trajectories, dict):
        named_trajectories = list(trajectories.items())
    else:
        named_trajectories = [("", t) for t in trajectories]

    cmap_colors = None
    if SETTINGS.plot_multi_cmap.lower() != "none":
        cmap = getattr(cm, SETTINGS.plot_multi_cmap)
        # evenly spaced colors across the colormap, one per trajectory
        cmap_colors = iter(cmap(np.linspace(0, 1, len(named_trajectories))))

    for name, t in named_trajectories:
        if cmap_colors is None:
            # NOTE(review): relies on matplotlib's private color cycler
            color = next(ax._get_lines.prop_cycler)['color']
        else:
            color = next(cmap_colors)
        if SETTINGS.plot_usetex:
            name = name.replace("_", "\\_")  # '_' is special in LaTeX
        traj(ax, plot_mode, t, '-', color, name)


def error_array(fig, err_array, x_array=None, statistics=None, threshold=None,
                cumulative=False, color='grey', name="error", title="",
                xlabel="index", ylabel=None, subplot_arg='111', linestyle="-",
                marker=None):
    """
    high-level function for plotting raw error values of a metric
    :param fig: matplotlib figure
    :param err_array: an nx1 array of values
    :param x_array: optional nx1 array of x-axis values (default: index)
    :param statistics: optional dictionary of {statistic name: value};
                       "mean", "median" and "rmse" are drawn as horizontal
                       lines, "std" as a band around "mean" (if present)
    :param threshold: optional value for horizontal threshold line
    :param cumulative: set to True for cumulative plot
    :param color: matplotlib color of the error curve
    :param name: optional name of the value array
    :param title: optional plot title
    :param xlabel: optional x-axis label
    :param ylabel: optional y-axis label (default: name)
    :param subplot_arg: optional matplotlib subplot ID if used as subplot
    :param linestyle: matplotlib linestyle
    :param marker: optional matplotlib marker style for points
    :return: the matplotlib figure with the plot
    """
    ax = fig.add_subplot(subplot_arg)
    y_values = np.cumsum(err_array) if cumulative else err_array
    # explicit None check: the truth value of a numpy array with more than
    # one element is ambiguous and raises a ValueError
    if x_array is not None:
        ax.plot(x_array, y_values, linestyle=linestyle, marker=marker,
                color=color, label=name)
    else:
        ax.plot(y_values, linestyle=linestyle, marker=marker, color=color,
                label=name)
    if statistics is not None:
        for stat_name, value in statistics.items():
            # every statistic consumes the next color of the axis' cycle
            # NOTE(review): relies on matplotlib's private color cycler
            stat_color = next(ax._get_lines.prop_cycler)['color']
            if stat_name in {"mean", "median", "rmse"}:
                ax.axhline(y=value, color=stat_color, linewidth=2.0,
                           label=stat_name)
            if stat_name == "std" and "mean" in statistics:
                # band of total width std, centered on the mean
                mean, std = statistics["mean"], statistics["std"]
                ax.axhspan(mean - std / 2, mean + std / 2, color=stat_color,
                           alpha=0.5, label=stat_name)
    if threshold is not None:
        ax.axhline(y=threshold, color='red', linestyle='dashed', linewidth=2.0,
                   label="threshold")
    # label the axis created above, not pyplot's current axes
    ax.set_ylabel(ylabel if ylabel else name)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    ax.legend(frameon=True)
    return fig