"""
tint.visualization
==================

Visualization tools for Cell_tracks objects (frame rendering and animation).

"""
import warnings
warnings.filterwarnings('ignore')

import gc
import os
import shutil
import subprocess
import tempfile

import numpy as np
import pandas as pd
import matplotlib as mpl
from IPython.display import display, Image
from matplotlib import pyplot as plt
import cartopy.crs as ccrs

import pyart

from .grid_utils import get_grid_alt


class Tracer(object):
    """
    Draw the lon/lat paths ("tracers") of tracked cells, one colour per uid.

    Parameters
    ----------
    tobj : Cell_tracks
        Object whose ``tracks`` DataFrame has a MultiIndex whose first level
        is the scan/frame number and which has a ``uid`` level, plus ``lon``
        and ``lat`` columns.
    persist : bool
        If True, paths of every cell seen so far are drawn and keep their
        colour. If False, only cells present in the current scan are drawn,
        and colours of cells that disappeared are returned to the pool.
    """
    colors = ['m', 'r', 'lime', 'darkorange', 'k', 'b', 'darkgreen', 'yellow']
    colors.reverse()

    def __init__(self, tobj, persist):
        self.tobj = tobj
        self.persist = persist
        # Colour pool; new uids take colours from the end of this list.
        self.color_stack = self.colors * 10
        # uid -> colour. Explicit object dtype matches the string values and
        # avoids the deprecated dtype-less empty Series.
        self.cell_color = pd.Series(dtype=object)
        self.history = None
        self.current = None

    def update(self, nframe):
        """Advance to scan ``nframe``.

        Stores all tracks up to and including ``nframe`` in ``history`` and
        the tracks of ``nframe`` in ``current``. Without ``persist``, colours
        of cells absent from ``current`` go back onto the colour stack.
        """
        self.history = self.tobj.tracks.loc[:nframe]
        self.current = self.tobj.tracks.loc[nframe]
        if not self.persist:
            # Build the live-uid set once rather than scanning the index for
            # every coloured uid.
            live_uids = set(self.current.index.get_level_values('uid'))
            dead_cells = [key for key in self.cell_color.keys()
                          if key not in live_uids]
            self.color_stack.extend(self.cell_color[dead_cells])
            self.cell_color.drop(dead_cells, inplace=True)

    def _check_uid(self, uid):
        """Assign a colour to ``uid`` if it has none, refilling the pool
        when it is exhausted."""
        if uid not in self.cell_color.keys():
            try:
                self.cell_color.loc[uid] = self.color_stack.pop()
            except IndexError:
                self.color_stack += self.colors * 5
                self.cell_color.loc[uid] = self.color_stack.pop()

    def plot(self, ax):
        """Plot each cell's lon/lat history on ``ax`` in its colour."""
        for uid, group in self.history.groupby(level='uid'):
            if self.persist or (uid in self.current.index):
                # Only cells that are actually drawn consume a colour.
                self._check_uid(uid)
                ax.plot(group['lon'], group['lat'], self.cell_color.loc[uid])

def full_domain(tobj, grids, tmp_dir, vmin=-8, vmax=64,
                cmap=None, alt=None, isolated_only=False,
                tracers=False, persist=False,
                projection=None, **kwargs):
    """
    Render one PNG per grid showing the whole domain with cell uids.

    Frames are written to ``tmp_dir`` as ``frame_NNN.png`` where NNN is the
    position of the grid in ``grids``.

    Parameters
    ----------
    tobj : Cell_tracks
        Tracks object providing ``tracks``, ``field``, ``grid_size``,
        ``params`` and ``radar_info``.
    grids : iterable
        Grids in scan order, each passed to ``pyart.graph.GridMapDisplay``.
    tmp_dir : str
        Output directory for the frames.
    vmin, vmax : float
        Colormap limits.
    cmap : colormap, optional
        Defaults to ``pyart.graph.cm_colorblind.HomeyerRainbow``.
    alt : float, optional
        Altitude to plot; defaults to ``tobj.params['GS_ALT']``.
    isolated_only : bool
        If True, only annotate cells whose ``isolated`` column is truthy.
    tracers : bool
        If True, draw each cell's lon/lat path (see ``Tracer``).
    persist : bool
        Passed to ``Tracer``; keep drawing paths of cells no longer present.
    projection : cartopy CRS, optional
        Defaults to ``ccrs.PlateCarree()``.
    **kwargs
        Forwarded to ``GridMapDisplay.plot_grid``.
    """
    grid_size = tobj.grid_size
    if cmap is None:
        cmap = pyart.graph.cm_colorblind.HomeyerRainbow
    if alt is None:
        alt = tobj.params['GS_ALT']
    if projection is None:
        projection = ccrs.PlateCarree()
    tracer = Tracer(tobj, persist) if tracers else None

    radar_lon = tobj.radar_info['radar_lon']
    radar_lat = tobj.radar_info['radar_lat']
    level = get_grid_alt(grid_size, alt)

    # Scans that actually have rows. MultiIndex.levels may keep values whose
    # rows were filtered out, which would make tracks.loc[nframe] raise.
    tracked_scans = set(tobj.tracks.index.get_level_values(0))

    nframes = tobj.tracks.index.levels[0].max() + 1
    print('Animating', nframes, 'frames')

    for nframe, grid in enumerate(grids):
        fig = plt.figure(figsize=(10, 8))
        print('Frame:', nframe)
        grid_display = pyart.graph.GridMapDisplay(grid)
        ax = fig.add_subplot(111, projection=projection)
        # Crosshairs at the radar location.
        grid_display.plot_crosshairs(lon=radar_lon, lat=radar_lat)
        grid_display.plot_grid(tobj.field, level=level,
                               vmin=vmin, vmax=vmax, mask_outside=False,
                               cmap=cmap, transform=projection, ax=ax,
                               **kwargs)

        if nframe in tracked_scans:
            frame_tracks = tobj.tracks.loc[nframe]

            if tracer is not None:
                tracer.update(nframe)
                tracer.plot(ax)

            # frame_tracks is indexed by uid; label each cell at its lon/lat.
            for uid, row in frame_tracks.iterrows():
                if isolated_only and not row['isolated']:
                    continue
                ax.text(row['lon'], row['lat'], uid, transform=projection,
                        fontsize=20)

        frame_path = os.path.join(tmp_dir,
                                  'frame_' + str(nframe).zfill(3) + '.png')
        fig.savefig(frame_path, bbox_inches='tight', dpi=300)
        plt.close(fig)
        del grid, grid_display, ax
        gc.collect()


def lagrangian_view(tobj, grids, tmp_dir, uid=None, vmin=-8, vmax=64,
                    cmap=None, alt=None, box_rad=.1, projection=None):
    """
    Render one PNG per scan in which cell ``uid`` is tracked.

    Each frame contains a top-down view around the cell, latitude and
    longitude cross sections through it, and a time series of the cell's
    ``max`` column with the current scan marked. Frames are written to
    ``tmp_dir`` as ``frame_NNN.png`` where NNN is the scan number.

    Parameters
    ----------
    tobj : Cell_tracks
        Tracks object providing ``tracks``, ``field``, ``grid_size`` and
        ``params``.
    grids : iterable
        Grids in scan order; grids of scans without a row for ``uid`` are
        skipped.
    tmp_dir : str
        Output directory for the frames.
    uid : str
        Cell to follow. Required: if None, a message is printed and nothing
        is rendered.
    vmin, vmax : float
        Colormap limits.
    cmap : colormap, optional
        Defaults to ``pyart.graph.cm_colorblind.HomeyerRainbow``.
    alt : float, optional
        Altitude of the top-down view; defaults to ``tobj.params['GS_ALT']``.
    box_rad : float
        Half-width of the top-down view around the cell's ``lon``/``lat``,
        in the same units as those columns.
    projection : cartopy CRS, optional
        Defaults to ``ccrs.PlateCarree()``.
    """
    if uid is None:
        print("Please specify 'uid' keyword argument.")
        return
    stepsize = 0.05
    title_font = 18
    axes_font = 16
    tick_fonts = {'xtick.labelsize': 16, 'ytick.labelsize': 16}

    field = tobj.field
    grid_size = tobj.grid_size

    if cmap is None:
        cmap = pyart.graph.cm_colorblind.HomeyerRainbow
    if alt is None:
        alt = tobj.params['GS_ALT']
    if projection is None:
        projection = ccrs.PlateCarree()
    level = get_grid_alt(grid_size, alt)

    cell = tobj.tracks.xs(uid, level='uid')
    # Time-series data is identical for every frame; extract it once.
    max_field = cell['max']
    plttime = cell['time']

    nframes = len(cell)
    print('Animating', nframes, 'frames')
    cell_frame = 0

    # Scope the larger tick labels to this function instead of permanently
    # changing the session-wide matplotlib rcParams.
    with mpl.rc_context(tick_fonts):
        for nframe, grid in enumerate(grids):
            if nframe not in cell.index:
                continue

            print('Frame:', cell_frame)
            cell_frame += 1

            row = cell.loc[nframe]
            grid_display = pyart.graph.GridMapDisplay(grid)

            # Nearest grid indices of the cell centre. Built-in int replaces
            # np.int, which was removed in NumPy 1.24.
            tx = int(np.round(row['grid_x']))
            ty = int(np.round(row['grid_y']))
            tx_met = grid.x['data'][tx]
            ty_met = grid.y['data'][ty]
            lat = row['lat']
            lon = row['lon']

            box = np.array([-box_rad, box_rad])
            lvxlim = lon + box
            lvylim = lat + box
            # Cross-section window of +/-25000 grid-coordinate units, divided
            # by 1000 (presumably m -> km, matching the axis labels).
            xlim = (tx_met + np.array([-25000, 25000])) / 1000
            ylim = (ty_met + np.array([-25000, 25000])) / 1000

            fig = plt.figure(figsize=(20, 15))
            fig.suptitle('Cell ' + str(uid) + ' Scan ' + str(nframe),
                         fontsize=22)
            plt.axis('off')

            # Lagrangian (top-down) view centred on the cell
            ax = fig.add_subplot(3, 2, (1, 3), projection=projection)
            grid_display.plot_grid(field, level=level,
                                   vmin=vmin, vmax=vmax, mask_outside=False,
                                   cmap=cmap, colorbar_flag=False,
                                   ax=ax, projection=projection)
            grid_display.plot_crosshairs(lon=lon, lat=lat, linestyle='--',
                                         color='k', linewidth=3, ax=ax)
            ax.set_xlim(lvxlim[0], lvxlim[1])
            ax.set_ylim(lvylim[0], lvylim[1])
            ax.set_xticks(np.arange(lvxlim[0], lvxlim[1], stepsize))
            ax.set_yticks(np.arange(lvylim[0], lvylim[1], stepsize))
            ax.set_title('Top-Down View', fontsize=title_font)
            ax.set_xlabel('Longitude of grid cell center\n [degree_E]',
                          fontsize=axes_font)
            ax.set_ylabel('Latitude of grid cell center\n [degree_N]',
                          fontsize=axes_font)

            # Latitude cross section
            ax = fig.add_subplot(3, 2, 2)
            grid_display.plot_latitude_slice(field, lon=lon, lat=lat,
                                             title_flag=False,
                                             colorbar_flag=False, edges=False,
                                             vmin=vmin, vmax=vmax,
                                             mask_outside=False, cmap=cmap,
                                             ax=ax)
            xticks = np.arange(xlim[0], xlim[1], 6)
            ax.set_xlim(xlim[0], xlim[1])
            ax.set_xticks(xticks)
            ax.set_xticklabels(np.round(xticks, 2))
            ax.set_title('Latitude Cross Section', fontsize=title_font)
            ax.set_xlabel('East West Distance From Origin (km)' + '\n',
                          fontsize=axes_font)
            ax.set_ylabel('Distance Above Origin (km)', fontsize=axes_font)
            ax.set_aspect(aspect=1.3)

            # Longitude cross section
            ax = fig.add_subplot(3, 2, 4)
            grid_display.plot_longitude_slice(field, lon=lon, lat=lat,
                                              title_flag=False,
                                              colorbar_flag=False,
                                              edges=False,
                                              vmin=vmin, vmax=vmax,
                                              mask_outside=False, cmap=cmap,
                                              ax=ax)
            yticks = np.arange(ylim[0], ylim[1], 6)
            ax.set_xlim(ylim[0], ylim[1])
            ax.set_xticks(yticks)
            ax.set_xticklabels(np.round(yticks, 2))
            ax.set_title('Longitudinal Cross Section', fontsize=title_font)
            ax.set_xlabel('North South Distance From Origin (km)',
                          fontsize=axes_font)
            ax.set_ylabel('Distance Above Origin (km)', fontsize=axes_font)
            ax.set_aspect(aspect=1.3)

            # Time series of the cell's maximum, current scan marked in red
            ax = fig.add_subplot(3, 2, (5, 6))
            ax.plot(plttime, max_field, color='b', linewidth=3)
            ax.axvline(x=plttime.loc[nframe], linewidth=4, color='r')
            ax.set_title('Time Series', fontsize=title_font)
            ax.set_xlabel('Time (UTC) \n Lagrangian Viewer Time',
                          fontsize=axes_font)
            ax.set_ylabel('Maximum ' + field, fontsize=axes_font)

            frame_path = os.path.join(tmp_dir,
                                      'frame_' + str(nframe).zfill(3) + '.png')
            fig.savefig(frame_path)
            plt.close(fig)
            del grid, grid_display
            gc.collect()


def make_mp4_from_frames(tmp_dir, dest_dir, basename, fps):
    """
    Encode the PNG frames in ``tmp_dir`` into ``<basename>.mp4`` with ffmpeg
    and move the movie into ``dest_dir``.

    Parameters
    ----------
    tmp_dir : str
        Directory containing the ``*.png`` frames.
    dest_dir : str
        Destination directory. An existing ``<basename>.mp4`` there is
        replaced.
    basename : str
        Movie file name without the ``.mp4`` extension.
    fps : int
        Frame rate passed to ffmpeg's ``-framerate``.

    If ffmpeg cannot be run or produces no movie, a hint is printed instead
    of raising. The caller's working directory is not changed.
    """
    outfile = basename + '.mp4'
    # Argument list (no shell): '*.png' is expanded by ffmpeg's own glob
    # pattern support, and file names with spaces/metacharacters are safe.
    cmd = ['ffmpeg', '-framerate', str(fps),
           '-pattern_type', 'glob', '-i', '*.png',
           '-movflags', 'faststart', '-pix_fmt', 'yuv420p',
           '-vf', 'scale=trunc(iw/2)*2:trunc(ih/2)*2',
           '-y', outfile]
    try:
        # cwd= instead of os.chdir so the process working directory is
        # untouched (the caller may delete tmp_dir afterwards).
        subprocess.run(cmd, cwd=tmp_dir, stdin=subprocess.DEVNULL)
    except FileNotFoundError:
        print('Make sure ffmpeg is installed properly.')
        return

    movie = os.path.join(tmp_dir, outfile)
    if not os.path.exists(movie):
        print('Make sure ffmpeg is installed properly.')
        return
    # Moving onto the full target path (not just the directory) replaces an
    # existing movie instead of raising shutil.Error.
    shutil.move(movie, os.path.join(dest_dir, outfile))


def animate(tobj, grids, outfile_name, style='full', fps=1, keep_frames=False,
            overwrite=False, **kwargs):
    """
    Creates an mp4 animation of tracked cells.

    Parameters
    ----------
    tobj : Cell_tracks
        The Cell_tracks object to be visualized.
    grids : iterable
        An iterable containing all of the grids used to generate tobj.
    outfile_name : str
        Output path without extension; '.mp4' is appended. The movie is
        written to the directory part of this path (current directory if
        there is none).
    style : str
        'full' renders the whole domain (full_domain); 'lagrangian' follows a
        single cell (lagrangian_view).
    fps : int
        Frames per second of the output movie.
    keep_frames : bool
        If True, the rendered PNG frames are copied to
        '<outfile_name>_frames'.
    overwrite : bool
        If true, will overwrite existing mp4 if one already exists.
        False, won't overwrite if file already exists.
    **kwargs
        Forwarded to the style function, e.g. alt (altitude in meters), vmin,
        vmax, cmap, projection; isolated_only, tracers and persist for
        'full'; uid (required) and box_rad for 'lagrangian'.

    """
    styles = {'full': full_domain,
              'lagrangian': lagrangian_view}
    anim_func = styles[style]

    # Absolute path: frame rendering / encoding may change the working
    # directory, which would otherwise redirect a relative destination into
    # tmp_dir (where it gets deleted).
    dest_dir = os.path.abspath(os.path.dirname(outfile_name))
    basename = os.path.basename(outfile_name)

    # Check the real destination, not the current working directory.
    if (os.path.exists(os.path.join(dest_dir, basename + '.mp4'))
            and not overwrite):
        print('Filename already exists.')
        return

    orig_cwd = os.getcwd()
    tmp_dir = tempfile.mkdtemp()

    try:
        anim_func(tobj, grids, tmp_dir, **kwargs)
        if len(os.listdir(tmp_dir)) == 0:
            print('Grid generator is empty.')
            return
        make_mp4_from_frames(tmp_dir, dest_dir, basename, fps)
        if keep_frames:
            frame_dir = os.path.join(dest_dir, basename + '_frames')
            shutil.copytree(tmp_dir, frame_dir)
    finally:
        # Leave tmp_dir before deleting it so the process is not left with a
        # nonexistent working directory.
        os.chdir(orig_cwd)
        shutil.rmtree(tmp_dir)


def embed_mp4_as_gif(filename):
    """ Makes a temporary gif version of an mp4 using ffmpeg for embedding in
    IPython. Intended for use in Jupyter notebooks.

    Parameters
    ----------
    filename : str
        Path to the mp4 file.

    Prints a message and returns None if ``filename`` does not exist or no
    gif could be produced. The temporary gif is always removed, and the
    working directory is not changed.
    """
    if not os.path.exists(filename):
        print('file does not exist.')
        return

    tmp_dir = tempfile.mkdtemp()
    newname = os.path.join(tmp_dir, 'animation.gif')
    try:
        try:
            # Argument list (no shell) so paths with spaces/metacharacters
            # are passed through safely.
            subprocess.run(['ffmpeg', '-i', filename, newname],
                           stdin=subprocess.DEVNULL)
        except FileNotFoundError:
            print('Make sure ffmpeg is installed properly.')
            return
        if not os.path.exists(newname):
            print('gif conversion failed.')
            return
        with open(newname, 'rb') as f:
            # NOTE(review): gif bytes are embedded with format='png', as in
            # the original; presumably the browser sniffs the real type.
            display(Image(f.read(), format='png'))
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)