"""
Filename: display_methods.py
Author: Gene Callahan
A collection of convenience functions
for using matplotlib.
"""
from functools import wraps
from math import ceil
import numpy as np
import networkx as nx
import logging
import io
from indra.prop_args2 import user_type
# True when matplotlib imported successfully; checked by expects_plt
# before any decorated plotting function runs.
plt_present = True

try:
    import matplotlib as mpl
    # The non-interactive Agg backend is selected before pyplot is imported.
    # NOTE(review): the "Web browser" string must match the value
    # indra.prop_args2 actually assigns to user_type — confirm.
    if user_type == "Web browser":  # you can change this to right value!
        mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    from matplotlib.figure import Figure
    # Turn on pyplot's interactive mode (per matplotlib docs, plt.show()
    # then does not block).
    plt.ion()
except ImportError:
    # matplotlib is optional: functions decorated with expects_plt will
    # print a message and return None instead of plotting.
    plt_present = False

# Module-level handle on the most recent animation created by LineGraph
# or ScatterPlot; matplotlib needs a live reference to an Animation
# object, otherwise it may be garbage-collected and stop updating.
anim_func = None

# matplotlib single-letter color codes.
BLUE = 'b'
RED = 'r'
GREEN = 'g'
YELLOW = 'y'
MAGENTA = 'm'
CYAN = 'c'
BLACK = 'k'
WHITE = 'w'
# Default palette, cycled through by get_color() for varieties that do
# not specify their own color.
colors = [BLUE, RED, GREEN, YELLOW, MAGENTA, CYAN, BLACK, WHITE]
NUM_COLORS = len(colors)
# Positions of the x and y sequences within a scatter-plot variety
# (see ScatterPlot.get_arrays).
X = 0
Y = 1

def expects_plt(fn):
    """
    Decorator for any function that relies on matplotlib's pyplot.

    When pyplot could not be imported (plt_present is False), the wrapped
    function is not called: a message naming it is printed and None is
    returned. Otherwise the call passes straight through.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if plt_present:
            return fn(*args, **kwargs)
        print(f"cannot plot with {fn.__qualname__}: matplotlib's pyplot is not installed")
        return None
    return wrapper


def hierarchy_pos(graph, root, width=1., vert_gap=0.2, vert_loc=0,
                  xcenter=0.5, pos=None, parent=None):
    """
    Compute tree-style (x, y) positions for the nodes reachable from root.

    This is an attempt to get a tree graph from networkx.
    If there is a cycle that is reachable from root, then this will
    infinitely recurse.

    graph: the graph; only graph.neighbors(node) is used.
    root: the root node of the current branch.
    width: horizontal space allocated for this branch
        - avoids overlap with other branches.
    vert_gap: gap between levels of hierarchy.
    vert_loc: vertical location of root.
    xcenter: horizontal location of root.
    pos: a dict saying where all nodes go if they have been assigned;
        updated in place when given.
    parent: parent of this branch; it is not treated as a child of root.

    Returns the pos dict mapping each placed node to (x, y).
    """
    if pos is None:
        pos = {}
    pos[root] = (xcenter, vert_loc)
    # graph.neighbors() returns an iterator in networkx >= 2.0, which has
    # no len() or remove(); copy it into a list. The copy also guarantees
    # we never mutate a graph's own adjacency list.
    children = list(graph.neighbors(root))
    # In an undirected graph the parent is one of root's neighbors; in a
    # directed graph (successors only) it usually is not.
    if parent is not None and parent in children:
        children.remove(parent)
    if children:
        dx = width / len(children)
        nextx = xcenter - width / 2 + dx / 2
        for child in children:
            pos = hierarchy_pos(graph, child, width=dx,
                                vert_gap=vert_gap,
                                vert_loc=vert_loc - vert_gap,
                                xcenter=nextx, pos=pos, parent=root)
            nextx += dx
    return pos


@expects_plt
def draw_graph(graph, title, hierarchy=False, root=None):
    """
    Draw a networkx graph with node labels and show it.

    graph: the networkx graph to draw.
    title: title placed on the current pyplot figure.
    hierarchy: if True, lay the graph out as a tree rooted at root (via
        hierarchy_pos); otherwise nx.draw chooses its default layout.
    root: the node used as the tree root when hierarchy is True.
    """
    plt.title(title)
    layout = hierarchy_pos(graph, root) if hierarchy else None
    nx.draw(graph, pos=layout, with_labels=True)
    plt.show()


def get_color(var, i):
    """
    Pick the plotting color for the i-th variety.

    var: the variety's entry; its "color" value is used when the key is
        present and the value is not None.
    i: the variety's position, used to cycle through `colors` otherwise.
    """
    chosen = var["color"] if "color" in var else None
    if chosen is not None:
        return chosen
    return colors[i % NUM_COLORS]


def assemble_lgraph_data(key, values, color, data=None):
    """
    Add one line's series to a line-graph data dict and return the dict.

    key: the line's label.
    values: the numbers to graph (stored under "data").
    color: the line's color (stored under "color").
    data: an existing dict to add to (modified in place); a new dict is
        created when omitted.
    """
    target = {} if data is None else data
    target[key] = {"data": values, "color": color}
    return target


class LineGraph():
    """
    Display a simple matplotlib line graph, optionally animated.

    We create a class here to save state for animation.
    The data (`varieties`) is a dictionary with the label as the key and,
    as the value, a dict whose "data" entry is the list of numbers to
    graph and whose optional "color" entry is passed through get_color.
    data_points is the length of the x-axis; when varieties is non-empty
    it is replaced by the length of the first variety's data.
    """

    def __init__(self, title, varieties, data_points,
                 anim=False, data_func=None, is_headless=False, legend_pos=4):
        """
        title: title shown on the axes.
        varieties: label -> {"data": [...], "color": ...} mapping.
        data_points: x-axis length, used only if varieties is empty.
        anim: if True (and not headless), FuncAnimation calls update_plot
            every 500 ms; data_func must then return
            (data_points, varieties).
        is_headless: if True, show() renders a PNG into memory instead of
            opening a window.
        legend_pos: legend location passed to ax.legend(loc=...).
        """
        global anim_func

        self.title = title
        self.anim = anim
        self.data_func = data_func
        self.headless = is_headless
        self.legend_pos = legend_pos
        self.draw_graph(self._x_len(data_points, varieties), varieties)

        # Without pyplot, expects_plt skipped draw_graph, so there is no
        # self.fig (and no `animation` module) to animate.
        if anim and not self.headless and plt_present:
            # Stored globally so the animation object stays referenced.
            anim_func = animation.FuncAnimation(self.fig,
                                                self.update_plot,
                                                frames=1000,
                                                interval=500,
                                                blit=False)

    @staticmethod
    def _x_len(data_points, varieties):
        """Return the length of the first variety's data, else data_points."""
        for var in varieties:
            return len(varieties[var]["data"])
        return data_points

    @expects_plt
    def draw_graph(self, data_points, varieties):
        """
        Create a new figure and axes, then draw all elements of the graph.
        """
        self.fig, self.ax = plt.subplots()
        self._draw_axes(data_points, varieties)

    def _draw_axes(self, data_points, varieties):
        """Draw the lines, legend and title onto the existing self.ax."""
        x = np.arange(0, data_points)
        self.create_lines(x, self.ax, varieties)
        self.ax.legend(loc=self.legend_pos)
        self.ax.set_title(self.title)

    def create_lines(self, x, ax, varieties):
        """
        Draw just the data portion: one line per variety, plotted on ax
        against x. The resulting Line2D objects are kept in self.lines.
        """
        self.lines = []
        for i, var in enumerate(varieties):
            color = get_color(varieties[var], i)
            y = np.array(varieties[var]["data"])
            self.lines.extend(ax.plot(x, y, linewidth=2, label=var,
                                      alpha=1.0, c=color))

    @expects_plt
    def show(self):
        """
        Display the plot.

        When headless, render this graph's figure to PNG instead and
        return the io.BytesIO holding it, rewound to the start;
        otherwise return None.
        """
        if not self.headless:
            plt.show()
            return None
        file = io.BytesIO()
        # Save our own figure, not whichever figure pyplot considers current.
        self.fig.savefig(file, format="png")
        file.seek(0)
        return file

    @expects_plt
    def update_plot(self, i):
        """
        This is our animation function: fetch fresh data from data_func
        and redraw the whole graph on the existing axes.

        i: frame number supplied by FuncAnimation (unused).
        """
        (data_points, varieties) = self.data_func()
        # Redraw in place: creating a new figure here (as draw_graph does)
        # would leave the animation attached to the old, cleared figure.
        self.ax.clear()
        self._draw_axes(self._x_len(data_points, varieties), varieties)


class ScatterPlot():
    """
    A matplotlib scatter plot, optionally animated.

    We are going to use a class here to save state for our animation.
    `varieties` maps each label to an entry indexable by X and Y (the
    x-values and y-values); each entry is also passed to get_color.
    """

    def update_plot(self, i):
        """
        This is our animation function: remove the current scatters and
        draw new ones from data_func().

        i: frame number supplied by FuncAnimation (unused).
        Returns the list of new scatter artists.
        """
        if self.scats is not None:
            for scat in self.scats:
                if scat is not None:
                    scat.remove()
        varieties = self.data_func()
        self.create_scats(varieties)
        return self.scats

    @expects_plt
    def __init__(self, title, varieties, width, height,
                 anim=True, data_func=None, is_headless=False, legend_pos=4):
        """
        Setup a scatter plot.
        varieties contains the different types of
        entities to show in the plot, which
        will get assigned different colors.

        width, height: axis limits; width also scales the marker size
            (ceil(4096 / width)).
        anim: if True (and not headless), FuncAnimation calls update_plot
            every 500 ms; data_func must then return new varieties.
        is_headless: if True, show() renders a PNG into memory instead of
            opening a window.
        legend_pos: legend location passed to ax.legend(loc=...).
        """
        global anim_func

        self.scats = None
        self.anim = anim
        self.data_func = data_func
        self.s = ceil(4096 / width)
        self.headless = is_headless

        # Keep our own figure/axes so later redraws never land on
        # whichever axes pyplot happens to consider current.
        self.fig, self.ax = plt.subplots()
        self.ax.set_xlim(0, width)
        self.ax.set_ylim(0, height)
        self.create_scats(varieties)
        self.ax.legend(loc=legend_pos)
        self.ax.set_title(title)
        self.ax.grid(True)

        if anim and not self.headless:
            # Stored globally so the animation object stays referenced.
            anim_func = animation.FuncAnimation(self.fig,
                                                self.update_plot,
                                                frames=1000,
                                                interval=500,
                                                blit=False)

    @expects_plt
    def show(self):
        """
        Display the plot.

        When headless, render this plot's figure to PNG instead and
        return the io.BytesIO holding it, rewound to the start;
        otherwise return None.
        """
        if not self.headless:
            plt.show()
            return None
        file = io.BytesIO()
        self.fig.savefig(file, format="png")
        file.seek(0)
        return file

    def get_arrays(self, varieties, var):
        """Return (x, y) numpy arrays for variety `var`."""
        x_array = np.array(varieties[var][X])
        y_array = np.array(varieties[var][Y])
        return (x_array, y_array)

    @expects_plt
    def create_scats(self, varieties):
        """
        Draw one scatter per variety on self.ax and store them in
        self.scats.

        Varieties with no points, or whose x and y arrays differ in
        length, are skipped (the latter is logged at DEBUG level).
        Colors follow each variety's position, so skipping one does not
        shift the colors of the others.
        """
        self.scats = []
        for i, var in enumerate(varieties):
            (x_array, y_array) = self.get_arrays(varieties, var)
            # `continue`, not a bare `next` (a no-op that let these fall
            # through to scatter()).
            if len(x_array) <= 0:  # no data to graph!
                continue
            if len(x_array) != len(y_array):
                logging.debug("Array length mismatch in scatter plot")
                continue
            color = get_color(varieties[var], i)
            scat = self.ax.scatter(x_array, y_array,
                                   c=color, label=var,
                                   alpha=1.0, marker="8",
                                   edgecolors='none', s=self.s)
            self.scats.append(scat)