"""
Visualization functions for displaying spikes, filters, and cells.
"""

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation, cm, gridspec
from matplotlib.patches import Ellipse

from . import filtertools as ft
from .utils import plotwrapper

__all__ = ['raster', 'psth', 'raster_and_psth', 'spatial', 'temporal',
           'plot_sta', 'play_sta', 'ellipse', 'plot_cells', 'play_rates',
           'anim_to_html']


@plotwrapper
def raster(spikes, labels, title='Spike raster', marker_string='ko', **kwargs):
    """
    Plot a raster of spike times.

    Each spike is drawn as one marker at ``(spike time, label)``, so spikes
    sharing a label (e.g. the same cell or trial) line up in one row.

    Parameters
    ----------
    spikes : array_like
        An array of spike times, drawn along the x-axis.

    labels : array_like
        One label per spike (e.g. the cell or trial it came from), drawn
        along the y-axis. Must have the same length as ``spikes``.

    title : string, optional
        Title of the plot (Default: 'Spike raster').

    marker_string : string, optional
        Format string handed to matplotlib's ``plot`` (Default: 'ko').

    ax : matplotlib.axes.Axes instance, optional
        Axes onto which the data is plotted.

    fig : matplotlib.figure.Figure instance, optional
        Figure onto which the data is plotted.

    kwargs : dict
        Remaining keyword arguments are forwarded to matplotlib's ``plot``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Matplotlib Figure object into which raster is plotted.

    ax : matplotlib.axes.Axes
        Matplotlib Axes object into which raster is plotted.
    """
    assert len(spikes) == len(labels), "Spikes and labels must have the same length"

    # The decorator injects both handles; only the axes is drawn into, and
    # neither may leak through to ``plot``.
    del kwargs['fig']
    axes = kwargs.pop('ax')

    title_font = {'fontsize': 24}
    label_font = {'fontsize': 20}

    axes.plot(spikes, labels, marker_string, **kwargs)
    axes.set_title(title, fontdict=title_font)
    axes.set_xlabel('Time (s)', fontdict=label_font)


@plotwrapper
def psth(spikes, trial_length=None, binsize=0.01, **kwargs):
    """
    Plot a PSTH from the given spike times.

    Parameters
    ----------
    spikes : array_like
        An array of spike times, in seconds.

    trial_length : float, optional
        The length of each trial to stack, in seconds. If None (the
        default), a single PSTH is plotted. If a float is passed, PSTHs
        from each trial of the given length are averaged together before
        plotting. A falsy value (None or 0) treats all spikes as one trial.

    binsize : float, optional
        The size of bins used in computing the PSTH, in seconds
        (Default: 0.01).

    ax : matplotlib.axes.Axes instance, optional
        An optional axes onto which the data is plotted.

    fig : matplotlib.figure.Figure instance, optional
        An optional figure onto which the data is plotted.

    kwargs : dict
        Keyword arguments applied to the plotted PSTH line (any
        ``matplotlib.lines.Line2D`` property, e.g. ``color`` or ``lw``).
        They override the default style (black, solid, width 2).

    Returns
    -------
    fig : matplotlib.figure.Figure
        Matplotlib Figure object into which PSTH is plotted.

    ax : matplotlib.axes.Axes
        Matplotlib Axes object into which PSTH is plotted.
    """
    _ = kwargs.pop('fig')
    ax = kwargs.pop('ax')

    # Accept any array_like (e.g. a list), not only objects with ``.max()``.
    spikes = np.asarray(spikes)

    # Input-checking: without a trial length, everything is one trial.
    if not trial_length:
        trial_length = spikes.max()

    # Histogram bin edges: row ``k`` holds the edges for trial ``k``, i.e. the
    # base edges offset by ``k * trial_length``.
    ntrials = int(np.ceil(spikes.max() / trial_length))
    basebins = np.arange(0, trial_length + binsize, binsize)
    tbins = basebins[np.newaxis, :] + np.arange(ntrials)[:, np.newaxis] * trial_length

    # Count the spikes falling in each bin of each trial.
    bspk = np.empty((ntrials, basebins.size - 1))
    for trial in range(ntrials):
        bspk[trial, :], _ = np.histogram(spikes, bins=tbins[trial, :])

    # Average the counts over trials and divide by the bin size to get Hz.
    firing_rate = np.mean(bspk, axis=0) / binsize

    # Plot the PSTH. Caller kwargs are applied afterwards through
    # ``Line2D.set`` so they override the defaults, including via aliases
    # such as ``lw``/``ls``, without duplicate-keyword errors.
    for line in ax.plot(tbins[0, :-1], firing_rate, color='k', marker=None,
                        linestyle='-', linewidth=2):
        line.set(**kwargs)

    # Labels etc
    ax.set_title('PSTH', fontsize=24)
    ax.set_xlabel('Time (s)', fontsize=20)
    ax.set_ylabel('Firing rate (Hz)', fontsize=20)


@plotwrapper
def raster_and_psth(spikes, trial_length=None, binsize=0.01, **kwargs):
    """
    Plot a spike raster and a PSTH on the same set of axes.

    The PSTH is drawn in red against the left y-axis (firing rate). The
    per-trial raster is drawn in black against a twin right y-axis
    (trial number), sharing the same time axis.

    Parameters
    ----------
    spikes : array_like
        An array of spike times, in seconds.

    trial_length : float, optional
        The length of each trial to stack, in seconds. If None (the default),
        all spikes are plotted as part of the same trial.

    binsize : float, optional
        The size of bins used in computing the PSTH, in seconds
        (Default: 0.01).

    ax : matplotlib.axes.Axes instance, optional
        An optional axes onto which the data is plotted.

    fig : matplotlib.figure.Figure instance, optional
        An optional figure onto which the data is plotted.

    kwargs : dict
        Keyword arguments applied to the PSTH line (any
        ``matplotlib.lines.Line2D`` property). They override its default
        style (red, solid, width 2); the firing-rate axis label and tick
        labels stay red.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Matplotlib Figure instance onto which the data is plotted.

    ax : matplotlib.axes.Axes
        Matplotlib Axes instance onto which the data is plotted.
    """
    _ = kwargs.pop('fig')
    ax = kwargs.pop('ax')

    # Accept any array_like (e.g. a list), not only objects with ``.max()``.
    spikes = np.asarray(spikes)

    # Input-checking: without a trial length, everything is one trial.
    if not trial_length:
        trial_length = spikes.max()

    # Histogram bin edges: row ``k`` holds the edges for trial ``k``.
    ntrials = int(np.ceil(spikes.max() / trial_length))
    basebins = np.arange(0, trial_length + binsize, binsize)
    tbins = basebins[np.newaxis, :] + np.arange(ntrials)[:, np.newaxis] * trial_length

    # Count the spikes falling in each bin of each trial.
    bspk = np.empty((ntrials, basebins.size - 1))
    for trial in range(ntrials):
        bspk[trial, :], _ = np.histogram(spikes, bins=tbins[trial, :])

    # Average the counts over trials and divide by the bin size to get Hz.
    firing_rate = np.mean(bspk, axis=0) / binsize

    # Plot the PSTH on the primary axis, with red labels and tick labels.
    for line in ax.plot(tbins[0, :-1], firing_rate, color='r', marker=None,
                        linestyle='-', linewidth=2):
        line.set(**kwargs)
    ax.set_xlabel('Time (s)', fontdict={'fontsize': 20})
    ax.set_ylabel('Firing rate (Hz)', color='r', fontdict={'fontsize': 20})
    ax.tick_params(axis='y', labelcolor='r')

    # Plot the raster on a twin y-axis: one row of dots per trial, with
    # spike times re-referenced to the start of their trial.
    rastax = ax.twinx()
    for trial in range(ntrials):
        start, stop = tbins[trial, 0], tbins[trial, -1]
        idx = (spikes > start) & (spikes <= stop)
        rastax.plot(spikes[idx] - start,
                    trial * np.ones(np.count_nonzero(idx)),
                    color='k', marker='.', linestyle='none')
    rastax.set_ylabel('Trial #', color='k', fontdict={'fontsize': 20})
    # Color the raster axis's own tick labels; the PSTH axis keeps its red ones.
    rastax.tick_params(axis='y', labelcolor='k')


def play_sta(sta, repeat=True, frametime=100, cmap='seismic_r',
        clim=None, dx=1.0):
    """
    Plays a spatiotemporal spike-triggered average as a movie.

    Parameters
    ----------
    sta : array_like
        Spike-triggered average array, shaped as ``(nt, nx, ny)``.

    repeat : boolean, optional
        Whether or not to repeat the animation (default is True).

    frametime : float, optional
        Length of time each frame is displayed for in milliseconds
        (default is 100).

    cmap : string, optional
        Name of the colormap to use (Default: ``'seismic_r'``).

    clim : array_like, optional
        2-element color limit for animation; e.g. [0, 255]. If not given,
        the limits are ``[-m, m]``, where ``m`` is the largest absolute
        value of the mean-subtracted STA.

    dx : float, optional
        The spatial sampling rate of the STA, setting the scale of the
        x- and y-axes.

    Returns
    -------
    anim : matplotlib.animation.FuncAnimation
        The animation object.
    """
    # Mean-subtract a floating-point copy. Converting to float first means
    # integer-valued input (and plain nested lists) can be mean-subtracted
    # in place without a casting error, and the caller's array is untouched.
    X = np.array(sta, dtype=float)
    X -= X.mean()

    # Create exactly one axes for the movie. Calling ``plt.axis`` on an empty
    # figure and then ``plt.axes`` would create two stacked axes whose tick
    # labels overlap.
    fig, ax = plt.subplots()
    spatial_range = (0.0, X.shape[1] * dx, 0.0, X.shape[2] * dx)
    img = ax.imshow(X[0], extent=spatial_range)
    ax.set_xlim(spatial_range[:2])
    ax.set_ylim(spatial_range[2:])
    ax.set_aspect('equal')

    # Set up the colors; default limits are symmetric about zero.
    img.set_cmap(cmap)
    img.set_interpolation('nearest')
    if clim is not None:
        img.set_clim(clim)
    else:
        maxval = np.max(np.abs(X))
        img.set_clim([-maxval, maxval])

    # Animation function (called sequentially, once per frame index)
    def animate(i):
        ax.set_title('Frame {0:#d}'.format(i + 1))
        img.set_data(X[i])

    # Call the animator
    anim = animation.FuncAnimation(fig, animate, np.arange(X.shape[0]),
                                   interval=frametime, repeat=repeat)
    plt.show()
    plt.draw()

    return anim


@plotwrapper
def spatial(filt, dx=1.0, maxval=None, **kwargs):
    """
    Plot the spatial component of a full linear filter.

    If the given filter is 2D, it is assumed to be a purely spatial filter,
    and is plotted directly. If the filter has more than two dimensions, it
    is decomposed (with ``filtertools.decompose``) into its spatial and
    temporal components, and the spatial component is plotted.

    Parameters
    ----------
    filt : array_like
        The filter whose spatial component is to be plotted. It may have
        temporal components.

    dx : float, optional
        The spatial sampling rate of the STA, setting the scale of the
        x- and y-axes.

    maxval : float, optional
        The value to use as minimal and maximal values when normalizing the
        colormap for this plot (``vmin=-maxval``, ``vmax=maxval``). If not
        given (or 0), the spatial filter is mean-subtracted and ``maxval``
        is set to its largest absolute value.

    ax : matplotlib Axes object, optional
        The axes on which to plot the data; defaults to creating a new figure.

    kwargs : dict
        Additional keyword arguments passed to ``ax.imshow``. They may also
        override the defaults used here (``cmap``, ``interpolation``,
        ``aspect``, ``vmin``, ``vmax``, ``extent``).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure onto which the spatial STA is plotted.

    ax : matplotlib Axes object
        Axes into which the spatial STA is plotted.
    """
    _ = kwargs.pop('fig')
    ax = kwargs.pop('ax')

    filt = np.asanyarray(filt)
    if filt.ndim > 2:
        spatial_filter, _ = ft.decompose(filt)
        spatial_filter = np.asanyarray(spatial_filter, dtype=float)
    else:
        # A float copy: the caller's array is never modified, and the
        # in-place mean subtraction below also works for integer input.
        spatial_filter = np.array(filt, dtype=float, subok=True)

    # adjust color limits if necessary
    if not maxval:
        spatial_filter -= np.mean(spatial_filter)
        maxval = np.max(np.abs(spatial_filter))

    # plot the spatial component; caller kwargs take precedence
    spatial_range = (0.0, spatial_filter.shape[0] * dx,
                     0.0, spatial_filter.shape[1] * dx)
    imshow_kwargs = {
        'cmap': 'seismic_r',
        'interpolation': 'nearest',
        'aspect': 'equal',
        'vmin': -maxval,
        'vmax': maxval,
        'extent': spatial_range,
    }
    imshow_kwargs.update(kwargs)
    ax.imshow(spatial_filter, **imshow_kwargs)


@plotwrapper
def temporal(time, filt, **kwargs):
    """
    Plot the temporal component of a full linear filter.

    If the given linear filter is 1D, it is assumed to be a temporal filter,
    and is plotted directly. If the filter is 2 or 3D, it is decomposed into
    its spatial and temporal components, and the temporal component is plotted.
    A dotted black zero line spanning ``time[0]`` to ``time[-1]`` is also drawn.

    Parameters
    ----------
    time : array_like
        A time vector to plot against.

    filt : array_like
        The full filter to plot. It may have more than one dimension, but its
        first dimension must match the length of the ``time`` input.

    ax : matplotlib Axes object, optional
        the axes on which to plot the data; defaults to creating a new figure

    kwargs : dict
        Keyword arguments applied to the filter line (any
        ``matplotlib.lines.Line2D`` property). They override its default
        style (solid, width 2, ``'LightCoral'``). The zero line is unaffected.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure onto which the temporal STA is plotted.

    ax : matplotlib Axes object
        Axes into which the temporal STA is plotted
    """
    # Same handle convention as the other plotting helpers in this module.
    kwargs.pop('fig', None)
    ax = kwargs.pop('ax')

    filt = np.asanyarray(filt)
    if filt.ndim > 1:
        _, temporal_filter = ft.decompose(filt)
    else:
        temporal_filter = filt.copy()

    # Apply caller kwargs after plotting so they override the defaults
    # (aliases such as ``lw`` included) without duplicate-keyword errors.
    for line in ax.plot(time, temporal_filter,
                        linestyle='-', linewidth=2, color='LightCoral'):
        line.set(**kwargs)
    ax.plot([time[0], time[-1]], [0, 0],
            linestyle=':', linewidth=2, color='k')


def plot_sta(time, sta, dx=1.0):
    """
    Plot a linear filter.

    If the given filter is 1D, it is directly plotted. If it is 2D, it is
    shown as an image, with space and time as its axes. If the filter is 3D,
    it is decomposed into its spatial and temporal components, each of which
    is plotted on its own axis.

    Parameters
    ----------
    time : array_like
        A time vector to plot against.

    sta : array_like
        The filter to plot.

    dx : float, optional
        The spatial sampling rate of the STA, setting the scale of the
        x- and y-axes.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure onto which the STA is plotted.

    ax : matplotlib Axes object or tuple
        Axes into which the STA is plotted. For a 3D filter this is the
        tuple ``(axspatial, axtemporal)``.

    Raises
    ------
    ValueError
        If ``sta`` does not have 1, 2 or 3 dimensions.
    """
    sta = np.asarray(sta)

    # plot 1D temporal filter
    if sta.ndim == 1:
        fig = plt.figure()
        fig, ax = temporal(time, sta, ax=fig.add_subplot(111))

    # plot 2D spatiotemporal filter
    elif sta.ndim == 2:

        # normalize: remove the mean and scale by the variance. A constant
        # filter has zero variance; dividing would fill the image with NaNs,
        # so it is left merely mean-subtracted (all zeros).
        centered = sta - np.mean(sta)
        variance = np.var(sta)
        stan = centered / variance if variance > 0 else centered

        # create new axes
        fig = plt.figure()
        fig, ax = spatial(stan, dx=dx, ax=fig.add_subplot(111))

    # plot 3D spatiotemporal filter
    elif sta.ndim == 3:

        # build the figure: tall spatial panel above a short temporal panel
        fig = plt.figure()
        gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1])

        # decompose
        spatial_profile, temporal_filter = ft.decompose(sta)

        # plot spatial profile
        _, axspatial = spatial(spatial_profile, dx=dx,
                ax=fig.add_subplot(gs[0]))

        # plot temporal profile, with only the left and bottom spines shown
        fig, axtemporal = temporal(time, temporal_filter,
                ax=fig.add_subplot(gs[1]))
        axtemporal.set_xlim(time[0], time[-1])
        axtemporal.spines['right'].set_color('none')
        axtemporal.spines['top'].set_color('none')
        axtemporal.yaxis.set_ticks_position('left')
        axtemporal.xaxis.set_ticks_position('bottom')

        # return handles
        ax = (axspatial, axtemporal)

    else:
        raise ValueError('The sta parameter has an invalid '
                'number of dimensions (must be 1-3)')

    return fig, ax


@plotwrapper
def ellipse(filt, sigma=2.0, alpha=0.8, fc='none', ec='black',
        lw=3, dx=1.0, **kwargs):
    """
    Plot an ellipse fitted to the given receptive field.

    The ellipse is fitted to the spatial profile of ``filt`` with
    ``filtertools.get_ellipse``. Its center and widths are scaled by ``dx``,
    and the axes limits are set to the full spatial extent of the filter.

    Parameters
    ----------
    filt : array_like
        A linear filter whose spatial extent is to be plotted. A 2D filter is
        taken to be the spatial component of the receptive field. A 3D filter
        is taken to be a full spatiotemporal receptive field, and its spatial
        component is extracted with ``filtertools.decompose``.

    sigma : float, optional
        Threshold of the ellipse contour, given as the number of standard
        deviations of a Gaussian fitted to the filter (Default: 2.0).

    alpha : float, optional
        Alpha blending value, from 0 (transparent) to 1 (opaque)
        (Default: 0.8).

    fc : string, optional
        Ellipse face color (Default: 'none').

    ec : string, optional
        Ellipse edge color (Default: 'black').

    lw : int, optional
        Line width (Default: 3).

    dx : float, optional
        The spatial sampling rate of the STA, setting the scale of the
        x- and y-axes.

    ax : matplotlib Axes object, optional
        The axes onto which the ellipse should be plotted.
        Defaults to a new figure.

    kwargs : dict
        Extra keyword arguments passed to ``matplotlib.patches.Ellipse``.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure onto which the ellipse is plotted.

    ax : matplotlib.axes.Axes
        The axes onto which the ellipse is plotted.

    Raises
    ------
    ValueError
        If ``filt`` is neither 2- nor 3-dimensional.
    """
    # The decorator supplies both handles; only the axes is drawn into.
    kwargs.pop('fig')
    target_ax = kwargs.pop('ax')

    # Isolate the spatial profile of the receptive field.
    if filt.ndim not in (2, 3):
        raise ValueError('Linear filter must be 2- or 3-D')
    spatial_filter = filt.copy() if filt.ndim == 2 else ft.decompose(filt)[0]

    # Fit on the filter's sample grid, then scale into spatial units.
    center, widths, theta = ft.get_ellipse(spatial_filter, sigma=sigma)
    center = np.asarray(center) * dx
    widths = np.asarray(widths) * dx

    target_ax.add_artist(Ellipse(
        xy=center, width=widths[0], height=widths[1], angle=theta,
        alpha=alpha, ec=ec, fc=fc, lw=lw, **kwargs))

    # Frame the whole filter footprint.
    target_ax.set_xlim(0, spatial_filter.shape[0] * dx)
    target_ax.set_ylim(0, spatial_filter.shape[1] * dx)


@plotwrapper
def plot_cells(cells, dx=1.0, **kwargs):
    """
    Plot the spatial receptive fields for multiple cells.

    Each cell's spatial profile (from ``filtertools.decompose``) is drawn
    with ``ellipse`` as a translucent ellipse. Its color is drawn at random
    from the ``Set1`` colormap. Cells whose decomposition raises
    ``LinAlgError``, or whose ellipse fit raises ``RuntimeError``, are
    skipped.

    Parameters
    ----------
    cells : list of array_like
        A list of spatiotemporal receptive fields, each of which is
        a spatiotemporal array.

    dx : float, optional
        The spatial sampling rate of the STA, setting the scale of the
        x- and y-axes.

    ax : matplotlib Axes object, optional
        The axes onto which the ellipses should be plotted.
        Defaults to a new figure.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure onto which the ellipses are plotted.

    ax : matplotlib.axes.Axes
        The axes onto which the ellipses are plotted.
    """
    kwargs.pop('fig')
    target_ax = kwargs.pop('ax')
    palette = cm.Set1(np.random.rand(len(cells)))

    for cell_color, cell_sta in zip(palette, cells):
        # Cells whose spatiotemporal decomposition fails are skipped.
        try:
            profile = ft.decompose(cell_sta)[0]
        except np.linalg.LinAlgError:
            continue

        # Best effort: a failed ellipse fit just leaves this cell out.
        try:
            ellipse(profile, fc=cell_color, ec=cell_color,
                    lw=2, dx=dx, alpha=0.3, ax=target_ax)
        except RuntimeError:
            continue


def play_rates(rates, patches, num_levels=255, time=None,
        repeat=True, frametime=100):
    """
    Plays a movie representation of the firing rate of a list of cells, by
    coloring a list of patches with a color proportional to the firing rate. This
    is useful, for example, in conjunction with ``plot_cells``, to color the
    ellipses fitted to a set of receptive fields proportional to the firing rate.

    The animation is drawn into the current figure and axes
    (``plt.gcf()`` / ``plt.gca()``).

    Parameters
    ----------
    rates : array_like
        An ``(N, T)`` matrix of firing rates. ``N`` is the number of cells, and
        ``T`` gives the firing rate at a each time point. A 1D array is
        treated as a single cell.

    patches : list
        A list of ``N`` matplotlib patch elements. The facecolor of these patches is
        altered according to the rates values. A single ``Ellipse`` is also
        accepted.

    num_levels : int, optional
        Number of levels of the ``gray`` colormap used to quantize the rates.
        Rates are scaled linearly so that the minimum maps to level 0 and the
        maximum to level ``num_levels - 1`` (Default: 255). If all rates are
        equal, every patch uses level 0.

    time : array_like, optional
        Time of each of the ``T`` frames, shown in seconds in the axes title.
        Defaults to ``np.arange(T)``.

    repeat : bool, optional
        Whether or not to repeat the animation (Default: True).

    frametime : float, optional
        Length of time each frame is displayed for, in milliseconds
        (Default: 100).

    Returns
    -------
    anim : matplotlib.animation.Animation
        The object representing the full animation.
    """
    # Validate input
    rates = np.asarray(rates)
    if rates.ndim == 1:
        rates = rates.reshape(1, -1)
    if isinstance(patches, Ellipse):
        patches = [patches]
    N, T = rates.shape

    # Quantize each rate to an integer colormap level. A zero rate range
    # (constant rates) would otherwise divide 0 by 0 and produce invalid
    # integer indices.
    colors = cm.gray(np.arange(num_levels))
    rmin = rates.min()
    rrange = rates.max() - rmin
    if rrange > 0:
        rscale = np.round((num_levels - 1) * (rates - rmin) / rrange).astype(int)
    else:
        rscale = np.zeros((N, T), dtype=int)

    # set up
    fig = plt.gcf()
    ax = plt.gca()
    if time is None:
        time = np.arange(T)

    # Animation function (called sequentially, once per time index)
    def animate(t):
        for i in range(N):
            patches[i].set_facecolor(colors[rscale[i, t]])
        ax.set_title('Time: %0.2f seconds' % (time[t]), fontsize=20)

    # Call the animator
    anim = animation.FuncAnimation(fig, animate,
                                   np.arange(T), interval=frametime, repeat=repeat)
    return anim


def anim_to_html(anim):
    """
    Convert an animation into an embeddable HTML element.

    The animation objects returned by ``play_sta()`` and ``play_rates()`` are
    encoded with ``anim.to_html5_video()``. The result is wrapped in an
    ``IPython.display.HTML`` object that can be embedded, for example, in a
    Jupyter notebook.

    Parameters
    ----------
    anim : matplotlib.animation.Animation
        The animation object to embed.

    Returns
    -------
    html : IPython.display.HTML
        An HTML object with the encoded video. This can be directly embedded
        into an IPython notebook.

    Raises
    ------
    ImportError
        If the IPython modules required to convert the animation are not
        installed.
    """
    # Imported lazily so that IPython is only required by this function.
    from IPython import display

    video_tag = anim.to_html5_video()
    return display.HTML(video_tag)