"""
Plotter 3D

The Plotter 3D plots data in 3D. It has options for setting a title and legend,
plotting 3D points or 3D Gaussians, and clipping data based on axis limits.

This is used to plot the 3D trajectories, including the trajectory samples,
policy samples, and the linear Gaussian controllers.
"""
import numpy as np
import matplotlib.pylab as plt
import matplotlib.gridspec as gridspec
from mpl_toolkits.mplot3d import Axes3D

class Plotter3D:
    """
    Grid of 3D axes with a shared legend strip above them.

    The figure region given by ``gs`` is split into 8 rows: the top row holds
    a legend-only axis (both axes hidden) and the remaining 7 rows hold
    ``num_plots`` 3D axes laid out in a ``rows`` x ``cols`` grid. Plot index
    ``i`` in the methods below selects one of those 3D axes.
    """

    def __init__(self, fig, gs, num_plots, rows=None, cols=None):
        """
        Args:
            fig       - the matplotlib Figure to create the axes in
            gs        - the subplot spec (e.g. one cell of a GridSpec) to occupy
            num_plots - number of 3D axes to create
            rows      - grid rows; defaults to ceil(num_plots / cols)
            cols      - grid columns; defaults to floor(sqrt(num_plots))
        """
        if cols is None:
            cols = int(np.floor(np.sqrt(num_plots)))
        if rows is None:
            rows = int(np.ceil(float(num_plots)/cols))
        assert num_plots <= rows*cols, 'Too many plots to put into gridspec.'

        self._fig = fig
        self._gs = gridspec.GridSpecFromSubplotSpec(8, 1, subplot_spec=gs)
        self._gs_legend = self._gs[0:1, 0]
        self._gs_plot   = self._gs[1:8, 0]

        # Create axes on self._fig explicitly: plt.subplot would add them to
        # whatever figure pyplot considers "current", which need not be fig.
        self._ax_legend = self._fig.add_subplot(self._gs_legend)
        self._ax_legend.get_xaxis().set_visible(False)
        self._ax_legend.get_yaxis().set_visible(False)

        self._gs_plots = gridspec.GridSpecFromSubplotSpec(rows, cols, subplot_spec=self._gs_plot)
        self._axarr = [self._fig.add_subplot(self._gs_plots[i], projection='3d')
                       for i in range(num_plots)]
        # Per-plot [xlim, ylim, zlim] set via set_lim, or None for no clipping.
        self._lims = [None] * num_plots
        # Per-plot list of line artists, kept so they can be redrawn/removed.
        self._plots = [[] for _ in range(num_plots)]

        for ax in self._axarr:
            ax.tick_params(pad=0)
            ax.locator_params(nbins=5)
            for item in (ax.get_xticklabels() + ax.get_yticklabels() + ax.get_zticklabels()):
                item.set_fontsize(10)

        self._fig.canvas.draw()
        self._fig.canvas.flush_events()   # Fixes bug with Qt4Agg backend

    def set_title(self, i, title):
        """Set the title of plot i (font size 10)."""
        self._axarr[i].set_title(title)
        self._axarr[i].title.set_fontsize(10)

    def add_legend(self, linestyle, marker, color, label):
        """
        Add an entry to the shared legend strip.

        An empty line carrying the given style is plotted on the (hidden)
        legend axis purely so the legend has a handle to show.
        """
        self._ax_legend.plot([], [], linestyle=linestyle, marker=marker,
                color=color, label=label)
        self._ax_legend.legend(ncol=2, mode='expand', fontsize=10)

    @staticmethod
    def _clip_to_lim(values, lim):
        """
        Return a float copy of values with entries outside [lim[0], lim[1]]
        replaced by NaN.

        A copy is made so the caller's array is never modified, and so integer
        input (which cannot hold NaN) is accepted.
        """
        clipped = np.array(values, dtype=float)
        clipped[(clipped < lim[0]) | (clipped > lim[1])] = np.nan
        return clipped

    def plot(self, i, xs, ys, zs, linestyle='-', linewidth=1.0, marker=None,
            markersize=5.0, markeredgewidth=1.0, color='black', alpha=1.0, label=''):
        """
        Plot a 3D line (xs, ys, zs) on plot i and remember its artist.

        If limits were set with set_lim, coordinates outside them become NaN,
        which leaves a gap in the drawn line. The input arrays are not modified.
        """
        # Manually clip at xlim, ylim, zlim: mplot3d does not clip data to the
        # axis limits, so out-of-range points are blanked out with NaN instead.
        if self._lims[i] is not None:
            xlim, ylim, zlim = self._lims[i]
            xs = self._clip_to_lim(xs, xlim)
            ys = self._clip_to_lim(ys, ylim)
            zs = self._clip_to_lim(zs, zlim)

        # Create and add plot
        plot = self._axarr[i].plot(xs, ys, zs=zs, linestyle=linestyle,
                linewidth=linewidth, marker=marker, markersize=markersize,
                markeredgewidth=markeredgewidth, color=color, alpha=alpha,
                label=label)[0]
        self._plots[i].append(plot)

    def plot_3d_points(self, i, points, linestyle='-', linewidth=1.0,
            marker=None, markersize=5.0, markeredgewidth=1.0, color='black',
            alpha=1.0, label=''):
        """Plot an (N x 3) array of points, columns (x, y, z), on plot i."""
        self.plot(i, points[:, 0], points[:, 1], points[:, 2],
                linestyle=linestyle, linewidth=linewidth, marker=marker,
                markersize=markersize, markeredgewidth=markeredgewidth,
                color=color, alpha=alpha, label=label)

    @staticmethod
    def _gaussian_ellipses(mu, sigma, edges):
        """
        Return a list of T arrays, each (edges x 3), tracing for every time step
        the 1-standard-deviation contour of the xy marginal of N(mu[t], sigma[t]),
        placed at height z = mu[t, 2].
        """
        mu = np.asarray(mu, dtype=float)
        p = np.linspace(0, 2*np.pi, edges)
        unit_circle = np.c_[np.cos(p), np.sin(p)]

        # Batched SVD of the 2x2 xy blocks. For a symmetric PSD covariance,
        # sigma_xy = u diag(s) u^T, so mapping the unit circle through
        # u diag(sqrt(s)) yields the Mahalanobis-distance-1 contour.
        u, s, _ = np.linalg.svd(np.asarray(sigma)[:, 0:2, 0:2])

        ellipses = []
        for t in range(mu.shape[0]):
            xyz = np.repeat(mu[t, :].reshape((1, 3)), edges, axis=0)
            xyz[:, 0:2] += np.dot(unit_circle, np.dot(np.diag(
                    np.sqrt(s[t, :])), u[t, :, :].T))
            ellipses.append(xyz)
        return ellipses

    def plot_3d_gaussian(self, i, mu, sigma, edges=100, linestyle='-.',
            linewidth=1.0, color='black', alpha=0.1, label=''):
        """
        Plots ellipses parallel to the xy plane representing the Gaussian
        distributions specified by mu and sigma: for each time step t, the
        1-standard-deviation contour of the (x, y) marginal, drawn at z = mu[t, 2].
        Args:
            mu    - Tx3 mean vector for (x, y, z)
            sigma - Tx3x3 covariance matrix for (x, y, z)
            edges - the number of points sampled around each ellipse (the first
                    and last coincide, closing the curve)
        """
        for xyz in self._gaussian_ellipses(mu, sigma, edges):
            self.plot_3d_points(i, xyz, linestyle=linestyle,
                    linewidth=linewidth, color=color, alpha=alpha, label=label)

    def set_lim(self, i, xlim, ylim, zlim):
        """
        Sets the xlim, ylim, and zlim for plot i
        WARNING: limits must be set before adding data to plots; they are only
        applied (as NaN clipping) when data is plotted.
        Args:
            xlim - a tuple of (x_start, x_end)
            ylim - a tuple of (y_start, y_end)
            zlim - a tuple of (z_start, z_end)
        """
        self._lims[i] = [xlim, ylim, zlim]

    def clear(self, i):
        """Remove every line previously plotted on plot i."""
        for plot in self._plots[i]:
            plot.remove()
        self._plots[i] = []

    def clear_all(self):
        """Remove every line from every plot."""
        for i in range(len(self._plots)):
            self.clear(i)

    def draw(self):
        """Redraw each axis background and its stored lines, then the canvas."""
        for ax in self._axarr:
            ax.draw_artist(ax.patch)
        for ax, plots in zip(self._axarr, self._plots):
            for plot in plots:
                ax.draw_artist(plot)
        self._fig.canvas.draw()
        self._fig.canvas.flush_events()   # Fixes bug with Qt4Agg backend