#!/usr/bin/env python
"""Provides drawing utilities."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import pylab as plt
import math
import networkx as nx
import json
from networkx.readwrite import json_graph
from matplotlib.font_manager import FontProperties

from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score

from eden.util import _serialize_list
from eden.display.graph_layout import KKEmbedder
from sklearn.preprocessing import minmax_scale
from matplotlib.offsetbox import OffsetImage, AnnotationBbox


class SetEncoder(json.JSONEncoder):
    """JSON encoder that serializes Python sets as JSON arrays."""

    def default(self, obj):
        """Return a list for a set; defer to the base encoder otherwise."""
        if not isinstance(obj, set):
            # The base implementation raises TypeError for unsupported types.
            return super(SetEncoder, self).default(obj)
        return list(obj)


def serialize_graph(graph):
    """Return the node-link representation of ``graph`` as a JSON string.

    Set-valued entries are written as JSON arrays through ``SetEncoder``.
    """
    node_link = json_graph.node_link_data(graph)
    dump_options = dict(separators=(',', ':'), indent=4, cls=SetEncoder)
    return json.dumps(node_link, **dump_options)


def map_labels_to_colors(graphs):
    """Assign each distinct node label a real value in [0, 1).

    Labels are read from the 'label' attribute of every node of every graph,
    sorted, and mapped to ``rank / number_of_distinct_labels``.
    """
    distinct_labels = {g.nodes[u]['label'] for g in graphs for u in g.nodes()}
    n_labels = len(distinct_labels)
    return {label: float(rank) / n_labels
            for rank, label in enumerate(sorted(distinct_labels))}


def draw_graph(graph_orig,
               vertex_label='label',
               vertex_color=None,
               vertex_color_dict=None,
               vertex_fixed_color=None,
               vertex_alpha=0.6,
               vertex_border=1,
               vertex_position=None,
               vertex_size=600,
               vertex_images=None,
               vertex_image_scale=0.1,
               vertex_image_alpha=.5,
               vertex_shape='o',
               compact=False,
               colormap='YlOrRd',
               vmin=None,
               vmax=None,
               invert_colormap=False,
               secondary_vertex_label=None,
               secondary_vertex_color=None,
               secondary_vertex_fixed_color=None,
               secondary_vertex_alpha=0.6,
               secondary_vertex_border=1,
               secondary_vertex_size=600,
               secondary_vertex_colormap='YlOrRd',
               secondary_vertex_vmin=None,
               secondary_vertex_vmax=None,

               edge_label='label',
               secondary_edge_label=None,
               edge_colormap='YlOrRd',
               edge_vmin=None,
               edge_vmax=None,
               edge_color=None,
               edge_fixed_color=None,
               edge_width=None,
               edge_alpha=0.5,

               dark_edge_colormap='YlOrRd',
               dark_edge_vmin=0,
               dark_edge_vmax=1,
               dark_edge_color=None,
               dark_edge_fixed_color=None,
               dark_edge_dotted=True,
               dark_edge_alpha=0.3,

               size=10,
               size_x_to_y_ratio=1,
               font_size=9,
               layout='graphviz',
               prog='neato',
               pos=None,

               verbose=True,
               file_name=None,
               title_key='id',
               ignore_for_layout="edge_attribute",

               logscale=False):
    """Plot a graph with matplotlib through the networkx drawing functions.

    Drawing happens on a copy of ``graph_orig`` whose nodes are relabeled to
    consecutive integers. Edges whose 'nesting' attribute is True ("dark"
    edges) are drawn separately from the remaining ("normal") edges.

    Parameters
    ----------
    graph_orig : networkx graph
        Graph to draw.
    vertex_label, secondary_vertex_label : str or None
        Node attribute keys used as label text; missing values are shown as
        'N/A'. The secondary label is printed on a second line. Node labels
        are disabled when ``vertex_label`` is None.
    vertex_color : str or None
        None draws white nodes. One of '-label-', '_labels_', '_label_',
        '__labels__', '__label__' colors nodes by their 'label' attribute
        (looked up in ``vertex_color_dict`` when given, hashed otherwise).
        Any other value is a node attribute key holding the color value
        (0 when missing).
    vertex_color_dict : dict or None
        Mapping from label to color value, used only in label-color mode.
    vertex_fixed_color : color or None
        When given, overrides every other node color setting.
    vertex_alpha, vertex_size, vertex_shape
        Forwarded to ``nx.draw_networkx_nodes``.
    vertex_border : number or False
        Node outline width; False is rendered as a 0.001 wide outline.
    vertex_position : str or None
        Node attribute key holding (x, y); when given it replaces the
        computed or supplied positions.
    vertex_images : iterable or None
        Images placed on the nodes, paired with the nodes in the iteration
        order of the position dictionary.
    vertex_image_scale, vertex_image_alpha
        Zoom and alpha of the node images.
    compact : bool
        Draw opaque white nodes first and the colored nodes on top of them.
    colormap, vmin, vmax
        Colormap name and value range for node colors; missing limits are
        taken from the node color values.
    invert_colormap : bool
        Negate attribute-based node and edge color values.
    secondary_vertex_color, secondary_vertex_fixed_color
        When either is given a second layer of nodes is drawn first, colored
        by the node attribute key or by the fixed color (the attribute key
        wins when both are given).
    secondary_vertex_alpha, secondary_vertex_border, secondary_vertex_size,
    secondary_vertex_colormap, secondary_vertex_vmin, secondary_vertex_vmax
        Same meaning as the primary counterparts, for the secondary layer.
    edge_label, secondary_edge_label : str or None
        Edge attribute keys used as edge label text; edge labels are
        disabled when ``edge_label`` is None.
    edge_color : str or None
        None draws black edges; the label aliases listed for
        ``vertex_color`` hash the 'label' attribute; any other value is an
        edge attribute key.
    edge_fixed_color : color or None
        When given, overrides the normal edge color.
    edge_width : int, str or None
        None means width 1, an int is a constant width, anything else is an
        edge attribute key (1 when missing).
    edge_colormap, edge_vmin, edge_vmax, edge_alpha
        Colormap, value range and alpha for normal edges.
    dark_edge_color, dark_edge_fixed_color, dark_edge_colormap,
    dark_edge_vmin, dark_edge_vmax, dark_edge_alpha
        Same as the ``edge_*`` settings, for nesting edges.
    dark_edge_dotted : bool
        Draw nesting edges dotted instead of solid.
    size : number or None
        Figure width in inches. When None no figure is created and nothing
        is shown or saved: the caller owns the canvas.
    size_x_to_y_ratio : number
        Figure height is ``int(size / size_x_to_y_ratio)``.
    font_size : number
        Font size of node and edge labels.
    layout : str
        One of 'graphviz', 'RNA', 'circular', 'random', 'spring', 'shell',
        'spectral', 'kk', 'KK'. Used only when ``pos`` is None; the computed
        coordinates are min-max scaled.
    prog : str
        Graphviz program used by the 'graphviz' layout.
    pos : dict or None
        Precomputed node positions keyed by integer node index.
    verbose : bool
        Currently unused.
    file_name : str or None
        When ``size`` is not None: the figure is saved to this file and
        closed, or shown when None.
    title_key : str or None
        Graph attribute key whose value is used as the plot title.
    ignore_for_layout : str
        Edge attribute key; edges with a true value for it are left out
        while the layout is computed.
    logscale : bool
        Take the logarithm of attribute-based node color values (values
        not above 0.01 are clamped to log(0.01)).

    Raises
    ------
    ValueError
        If ``pos`` is None and ``layout`` is not a known layout name.
    """
    graph = nx.convert_node_labels_to_integers(graph_orig)
    if size is not None:
        size_x = size
        size_y = int(float(size) / size_x_to_y_ratio)
        plt.figure(figsize=(size_x, size_y))
    # Resolve the target axes unconditionally: with size=None the caller owns
    # the canvas, yet vertex_images still needs an axes to attach artists to.
    axes = plt.gca()
    plt.grid(False)
    plt.axis('off')
    plt.axis('equal')

    if vertex_label is not None:
        if secondary_vertex_label:
            vertex_labels = dict()
            for u, d in graph.nodes(data=True):
                label1 = _serialize_list(d.get(vertex_label, 'N/A'))
                label2 = _serialize_list(d.get(secondary_vertex_label, 'N/A'))
                vertex_labels[u] = '%s\n%s' % (label1, label2)
        else:
            vertex_labels = dict()
            for u, d in graph.nodes(data=True):
                label = d.get(vertex_label, 'N/A')
                vertex_labels[u] = _serialize_list(label)

    # The per-edge width/color lists below use exactly these two predicates
    # so that they stay aligned with the edge lists they are drawn with.
    edges_normal = [(u, v) for (u, v, d) in graph.edges(data=True)
                    if d.get('nesting', False) is False]
    edges_nesting = [(u, v) for (u, v, d) in graph.edges(data=True)
                     if d.get('nesting', False) is True]

    if edge_label is not None:
        if secondary_edge_label:
            edge_labels = dict([((u, v,), '%s\n%s' %
                                 (d.get(edge_label, ''),
                                  d.get(secondary_edge_label, '')))
                                for u, v, d in graph.edges(data=True)])
        else:
            edge_labels = dict([((u, v,), d.get(edge_label, ''))
                                for u, v, d in graph.edges(data=True)])

    label_aliases = ['-label-', '_labels_', '_label_',
                     '__labels__', '__label__']

    if vertex_color is None:
        node_color = 'white'
    elif vertex_color in label_aliases:
        node_color = []
        for u, d in graph.nodes(data=True):
            label = d.get('label', '.')
            if vertex_color_dict is not None:
                node_color.append(vertex_color_dict.get(label, 0))
            else:
                # NOTE(review): string hashes can differ between interpreter
                # runs, so these colors may not be reproducible - confirm.
                node_color.append(hash(_serialize_list(label)))
    else:
        if invert_colormap:
            node_color = [- d.get(vertex_color, 0)
                          for u, d in graph.nodes(data=True)]
        else:
            node_color = [d.get(vertex_color, 0)
                          for u, d in graph.nodes(data=True)]
        if logscale is True:
            log_threshold = 0.01
            node_color = [math.log(c) if c > log_threshold
                          else math.log(log_threshold)
                          for c in node_color]
    if isinstance(node_color, list) and node_color:
        if vmax is None:
            vmax = max(node_color)
        if vmin is None:
            vmin = min(node_color)
    if edge_width is None:
        widths = 1
    elif isinstance(edge_width, int):
        widths = edge_width
    else:
        widths = [d.get(edge_width, 1)
                  for u, v, d in graph.edges(data=True)
                  if d.get('nesting', False) is False]
    if edge_color is None:
        edge_colors = 'black'
    elif edge_color in label_aliases:
        edge_colors = [hash(str(d.get('label', '.')))
                       for u, v, d in graph.edges(data=True)
                       if d.get('nesting', False) is False]
    else:
        if invert_colormap:
            edge_colors = [- d.get(edge_color, 0)
                           for u, v, d in graph.edges(data=True)
                           if d.get('nesting', False) is False]
        else:
            edge_colors = [d.get(edge_color, 0)
                           for u, v, d in graph.edges(data=True)
                           if d.get('nesting', False) is False]
    if isinstance(edge_colors, list) and edge_colors:
        if edge_vmax is None:
            edge_vmax = max(edge_colors)
        if edge_vmin is None:
            edge_vmin = min(edge_colors)
    if dark_edge_color is None:
        dark_edge_colors = 'black'
    else:
        dark_edge_colors = [d.get(dark_edge_color, 0)
                            for u, v, d in graph.edges(data=True)
                            if d.get('nesting', False) is True]

    # Temporarily drop the edges that must not influence the layout.
    tmp_edge_set = [(u, v)
                    for u, v in graph.edges()
                    if graph.edges[u, v].get(ignore_for_layout, False)]
    if len(tmp_edge_set) > 0:
        graph.remove_edges_from(tmp_edge_set)

    if pos is None:
        if layout == 'graphviz':
            # Strip bulky attributes from a copy before handing it over to
            # graphviz.
            graph_copy = graph.copy()
            for u in graph_copy.nodes():
                graph_copy.nodes[u].pop('label', None)
                graph_copy.nodes[u].pop('vec', None)
                graph_copy.nodes[u].pop('svec', None)
            pos = nx.nx_pydot.graphviz_layout(graph_copy,
                                              prog=prog)
        elif layout == "RNA":
            import RNA  # this is part of the vienna RNA package
            rna_object = RNA.get_xy_coordinates(graph.graph['structure'])
            pos = {i: (rna_object.get(i).X, rna_object.get(i).Y)
                   for i in range(len(graph.graph['structure']))}
        elif layout == 'circular':
            pos = nx.circular_layout(graph)
        elif layout == 'random':
            pos = nx.random_layout(graph)
        elif layout == 'spring':
            pos = nx.spring_layout(graph)
        elif layout == 'shell':
            pos = nx.shell_layout(graph)
        elif layout == 'spectral':
            pos = nx.spectral_layout(graph)
        elif layout == 'kk':
            pos = nx.kamada_kawai_layout(graph)
        elif layout == 'KK':
            pos = KKEmbedder().transform(graph)
        else:
            raise ValueError('Unknown layout format: %s' % layout)
        # Rescale every coordinate axis of the computed layout to [0, 1].
        _pos = minmax_scale(np.array([pos[i] for i in pos]))
        pos = {i: (p[0], p[1]) for i, p in zip(pos, _pos)}
    if vertex_position is not None:
        pos = {u: (graph.nodes[u][vertex_position][0],
                   graph.nodes[u][vertex_position][1])
               for u in graph.nodes()}

    if vertex_border is False:
        linewidths = 0.001
    else:
        linewidths = vertex_border

    # Restore the edges removed for the layout. They come back without
    # attributes; labels, widths and colors were already collected above.
    if len(tmp_edge_set) > 0:
        graph.add_edges_from(tmp_edge_set)

    if secondary_vertex_border is False:
        secondary_linewidths = 0.001
    else:
        secondary_linewidths = secondary_vertex_border
    secondary_node_color = None
    if secondary_vertex_fixed_color is not None:
        secondary_node_color = secondary_vertex_fixed_color
    if secondary_vertex_color is not None:
        secondary_node_color = [d.get(secondary_vertex_color, 0)
                                for u, d in graph.nodes(data=True)]
    # Fill in missing color limits from the data *before* drawing, as done
    # for the primary node layer.
    if isinstance(secondary_node_color, list) and secondary_node_color:
        if secondary_vertex_vmax is None:
            secondary_vertex_vmax = max(secondary_node_color)
        if secondary_vertex_vmin is None:
            secondary_vertex_vmin = min(secondary_node_color)
    if secondary_vertex_fixed_color is not None or \
            secondary_vertex_color is not None:
        secondary_nodes = nx.draw_networkx_nodes(
            graph, pos,
            node_color=secondary_node_color,
            alpha=secondary_vertex_alpha,
            node_size=secondary_vertex_size,
            linewidths=secondary_linewidths,
            cmap=plt.get_cmap(
                secondary_vertex_colormap),
            vmin=secondary_vertex_vmin, vmax=secondary_vertex_vmax)
        secondary_nodes.set_edgecolor('k')
    if vertex_fixed_color is not None:
        node_color = vertex_fixed_color
    if compact:
        # Opaque white background nodes hide whatever lies underneath; the
        # colored, possibly translucent nodes are then drawn on top.
        nodes = nx.draw_networkx_nodes(graph, pos,
                                       node_shape=vertex_shape,
                                       node_color='w',
                                       alpha=1,
                                       node_size=vertex_size,
                                       linewidths=linewidths)
        nodes.set_edgecolor('k')
        nx.draw_networkx_nodes(graph, pos,
                               node_shape=vertex_shape,
                               node_color=node_color,
                               alpha=vertex_alpha,
                               node_size=vertex_size,
                               linewidths=None,
                               cmap=plt.get_cmap(colormap),
                               vmin=vmin, vmax=vmax)

    else:
        nodes = nx.draw_networkx_nodes(graph, pos,
                                       node_shape=vertex_shape,
                                       node_color=node_color,
                                       alpha=vertex_alpha,
                                       node_size=vertex_size,
                                       linewidths=linewidths,
                                       cmap=plt.get_cmap(colormap),
                                       vmin=vmin, vmax=vmax)
        nodes.set_edgecolor('k')

    if edge_fixed_color is not None:
        edge_colors = edge_fixed_color
    nx.draw_networkx_edges(graph, pos,
                           edgelist=edges_normal,
                           width=widths,
                           edge_color=edge_colors,
                           edge_cmap=plt.get_cmap(edge_colormap),
                           edge_vmin=edge_vmin, edge_vmax=edge_vmax,
                           alpha=edge_alpha)
    if dark_edge_dotted:
        style = 'dotted'
    else:
        style = 'solid'
    if dark_edge_fixed_color is not None:
        dark_edge_colors = dark_edge_fixed_color
    nx.draw_networkx_edges(graph, pos,
                           edgelist=edges_nesting,
                           width=1,
                           edge_cmap=plt.get_cmap(dark_edge_colormap),
                           edge_vmin=dark_edge_vmin, edge_vmax=dark_edge_vmax,
                           edge_color=dark_edge_colors,
                           style=style,
                           alpha=dark_edge_alpha)
    if edge_label is not None:
        nx.draw_networkx_edge_labels(graph,
                                     pos,
                                     edge_labels=edge_labels,
                                     font_size=font_size)
    if vertex_label is not None:
        nx.draw_networkx_labels(graph,
                                pos,
                                vertex_labels,
                                font_size=font_size,
                                font_weight='normal',
                                font_color='black')
    if vertex_images is not None:
        for im, xy_pos in zip(vertex_images, pos):
            x, y = pos[xy_pos]
            oi = OffsetImage(im,
                             zoom=vertex_image_scale,
                             alpha=vertex_image_alpha)
            box = AnnotationBbox(oi, (x, y), frameon=False)
            axes.add_artist(box)
    if title_key:
        title = str(graph.graph.get(title_key, ''))
        font = FontProperties()
        font.set_family('monospace')
        plt.title(title, fontproperties=font)
    if size is not None:
        # here we decide if we output the image.
        # note: if size is not set, the canvas has been created outside
        # of this function.
        # we wont write on a canvas that we didn't create ourselves.
        if file_name is None:
            plt.show()
        else:
            plt.savefig(file_name,
                        bbox_inches='tight',
                        transparent=True,
                        pad_inches=0)
            plt.close()


def draw_adjacency_graph(adjacency_matrix,
                         node_color=None,
                         size=10,
                         layout='graphviz',
                         prog='neato',
                         node_size=80,
                         colormap='autumn'):
    """Draw and show the graph defined by a sparse adjacency matrix.

    Parameters
    ----------
    adjacency_matrix : scipy sparse matrix
        Adjacency matrix converted to a networkx graph.
    node_color : sequence or None
        Per-node color values mapped through ``colormap``; None or an empty
        sequence draws every node gray.
    size : number
        Side of the square figure, in inches.
    layout : str
        'graphviz' uses the graphviz program ``prog``; any other value
        falls back to the networkx spring layout.
    prog : str
        Graphviz layout program.
    node_size : number
        Node marker size.
    colormap : str
        Name of the matplotlib colormap used for ``node_color``.
    """
    # Recent networkx releases only ship from_scipy_sparse_array; older
    # ones only have from_scipy_sparse_matrix. Use whichever is available.
    from_sparse = getattr(nx, 'from_scipy_sparse_array', None)
    if from_sparse is None:
        from_sparse = nx.from_scipy_sparse_matrix
    graph = from_sparse(adjacency_matrix)

    plt.figure(figsize=(size, size))
    plt.grid(False)
    plt.axis('off')

    if layout == 'graphviz':
        # Same entry point as draw_graph; the top-level nx.graphviz_layout
        # alias is not available in current networkx.
        pos = nx.nx_pydot.graphviz_layout(graph, prog=prog)
    else:
        pos = nx.spring_layout(graph)

    # The default node_color=None used to crash on len(None).
    if node_color is None or len(node_color) == 0:
        node_color = 'gray'
    nx.draw_networkx_nodes(graph, pos,
                           node_color=node_color,
                           alpha=0.6,
                           node_size=node_size,
                           cmap=plt.get_cmap(colormap))
    nx.draw_networkx_edges(graph, pos, alpha=0.5)
    plt.show()


# Draw a whole set of graphs, one figure per row.
def draw_graph_set(graphs,
                   n_graphs_per_line=5,
                   size=4,
                   edge_label=None,
                   pos=None,
                   **args):
    """Draw a collection of graphs as consecutive rows.

    The graphs are split into chunks of ``n_graphs_per_line`` and each chunk
    is handed to ``draw_graph_row`` together with its 1-based row index.
    When ``pos`` is given, its items are stored pairwise in each graph's
    ``graph['pos_dict']`` attribute before drawing. Remaining keyword
    arguments are forwarded to ``draw_graph_row``.
    """
    remaining = list(graphs)
    if pos:
        for one_graph, layout in zip(remaining, pos):
            one_graph.graph['pos_dict'] = layout

    row_index = 0
    while len(remaining) > 0:
        row_index += 1
        current_row = remaining[:n_graphs_per_line]
        remaining = remaining[n_graphs_per_line:]
        draw_graph_row(current_row,
                       index=row_index,
                       n_graphs_per_line=n_graphs_per_line,
                       edge_label=edge_label,
                       size=size,
                       **args)

# draw a row of graphs


def draw_graph_row(graphs,
                   index=0,
                   contract=True,
                   n_graphs_per_line=5,
                   size=4,
                   xlim=None,
                   ylim=None,
                   **args):
    """Draw a sequence of graphs side by side in a single figure.

    Parameters
    ----------
    graphs : sequence of networkx graphs
        Graphs of this row; each is drawn in its own subplot by
        ``draw_graph``, using ``graph.graph['pos_dict']`` as positions when
        that attribute is present.
    index : int
        Row number, used as file name prefix when saving.
    contract : bool
        Currently unused.
    n_graphs_per_line : int
        Number of subplot columns of the figure.
    size : number
        Height of the figure and, multiplied by ``n_graphs_per_line`` and
        the optional ``size_x_to_y_ratio`` keyword, its width.
    xlim, ylim
        Axis limits applied when ``xlim`` is not None.
    **args
        Forwarded to ``draw_graph``. When ``file_name`` is among them the
        figure is saved as ``<index>_<file name>`` in the directory of
        ``file_name`` and closed; otherwise the figure is shown.
    """
    import os

    size_y = size
    size_x = size * n_graphs_per_line * args.get('size_x_to_y_ratio', 1)
    plt.figure(figsize=(size_x, size_y))

    if xlim is not None:
        plt.xlim(xlim)
        plt.ylim(ylim)
    else:
        plt.xlim(xmax=3)

    for i, graph in enumerate(graphs):
        plt.subplot(1, n_graphs_per_line, i + 1)
        # size=None: draw on the current subplot, do not show or save.
        draw_graph(graph,
                   size=None,
                   pos=graph.graph.get('pos_dict', None),
                   **args)
    file_name = args.get('file_name', None)
    if file_name is None:
        plt.show()
    else:
        # Prefix only the base name: prepending to the whole path would
        # turn 'out/fig.png' into the unrelated location '1_out/fig.png'.
        directory, base_name = os.path.split(file_name)
        row_file_name = os.path.join(directory,
                                     '%d_%s' % (index, base_name))
        plt.savefig(row_file_name,
                    bbox_inches='tight',
                    transparent=True,
                    pad_inches=0)
        plt.close()


def dendrogram(data,
               vectorizer,
               method="ward",
               color_threshold=1,
               size=10,
               filename=None):
    """Plot a hierarchical clustering dendrogram of a set of graphs.

    Parameters
    ----------
    data : iterable of networkx graphs
        Instances to cluster. The 'id' graph attribute is used as leaf
        label; instances without one are labeled with their position.
    vectorizer : object
        Its ``transform(data)`` result is used as the data matrix, one row
        per instance.
    method : str
        Linkage method: "median", "centroid", "weighted", "single", "ward",
        "complete" or "average".
    color_threshold : number
        Forwarded to the scipy dendrogram.
    size : number
        Side of the square figure, in inches.
    filename : str or None
        When given the figure is saved there, otherwise it is shown.
    """
    from sklearn import metrics
    from scipy.cluster.hierarchy import linkage
    # Aliased so that the scipy function does not shadow this function.
    from scipy.cluster.hierarchy import dendrogram as draw_dendrogram
    from scipy.spatial.distance import squareform

    data = list(data)
    # One label per instance, so that labels always line up with the rows
    # of the data matrix even when some graphs carry no id.
    labels = []
    for position, graph in enumerate(data):
        label = graph.graph.get('id', None)
        labels.append(str(position) if label is None else label)

    # transform input into sparse vectors
    data_matrix = vectorizer.transform(data)

    distance_matrix = metrics.pairwise.pairwise_distances(data_matrix)
    # linkage() interprets a 2-D input as observation vectors; pairwise
    # distances have to be passed in condensed (1-D) form instead.
    # checks=False tolerates tiny floating point asymmetries.
    condensed_distances = squareform(distance_matrix, checks=False)
    linkage_matrix = linkage(condensed_distances, method=method)
    plt.figure(figsize=(size, size))
    draw_dendrogram(linkage_matrix,
                    color_threshold=color_threshold,
                    labels=labels,
                    orientation='right')
    if filename is not None:
        plt.savefig(filename)
    else:
        plt.show()


def plot_embedding(data_matrix, y,
                   labels=None,
                   image_file_name=None,
                   title=None,
                   cmap='rainbow',
                   density=False):
    """Plot a two dimensional embedding on the current axes.

    Parameters
    ----------
    data_matrix : array
        Its first two columns are used as x and y coordinates, one row per
        instance.
    y : sequence
        Per-instance values used to color the points.
    labels : sequence or None
        When given, ``str(labels[i])`` is written at the position of
        instance ``i``.
    image_file_name : str or None
        When given, the image ``image_file_name + str(i) + '.png'`` is
        drawn at the position of instance ``i``.
    title : str or None
        Plot title.
    cmap : str
        Colormap name.
    density : bool
        When true the plot is delegated to
        ``eden.embedding.embed_dat_matrix_two_dimensions`` instead of a
        plain scatter plot.
    """
    import matplotlib.pyplot as plt

    if title is not None:
        plt.title(title)
    if density:
        # Imported on demand so that the plain scatter path does not need it.
        from eden.embedding import embed_dat_matrix_two_dimensions
        embed_dat_matrix_two_dimensions(data_matrix,
                                        y=y,
                                        instance_colormap=cmap)
    else:
        plt.scatter(data_matrix[:, 0], data_matrix[:, 1],
                    c=y,
                    cmap=cmap,
                    alpha=.7,
                    s=30,
                    edgecolors='black')
        plt.xticks([])
        plt.yticks([])
        plt.axis('off')
    if image_file_name is not None:
        # Pillow is only required when images are actually requested.
        from matplotlib import offsetbox
        from PIL import Image
        # Attach the images to the axes that holds the points. Requesting
        # subplot(111) here would leave the current subplot when this
        # function is used inside a grid of subplots.
        ax = plt.gca()
        for i in range(data_matrix.shape[0]):
            img = Image.open(image_file_name + str(i) + '.png')
            imagebox = offsetbox.AnnotationBbox(
                offsetbox.OffsetImage(img, zoom=1),
                data_matrix[i],
                pad=0,
                frameon=False)
            ax.add_artist(imagebox)
    if labels is not None:
        for i in range(data_matrix.shape[0]):
            plt.annotate(str(labels[i]),
                         xy=(data_matrix[i, 0], data_matrix[i, 1]),
                         xytext=(0, 0),
                         textcoords='offset points')


def plot_embeddings(data_matrix, y,
                    labels=None,
                    save_image_file_name=None,
                    image_file_name=None,
                    size=20,
                    cmap='rainbow',
                    density=False,
                    knn=16,
                    knn_density=16,
                    k_threshold=0.9,
                    metric='rbf',
                    **args):
    """Plot four 2D embeddings of ``data_matrix`` in a 2x2 grid.

    The panels are, in order: truncated SVD (skipped when the data has at
    most two columns), MDS on the pairwise distances, t-SNE and the quick
    shift tree embedding. Each panel title reports the time spent computing
    its embedding. ``y``, ``labels``, ``cmap``, ``density`` and
    ``image_file_name`` are passed to ``plot_embedding``; ``knn``,
    ``knn_density``, ``k_threshold`` and the extra keyword arguments go to
    ``quick_shift_tree_embedding``. ``metric`` is not used. The figure is
    saved to ``save_image_file_name`` when given, shown otherwise.
    """
    import matplotlib.pyplot as plt
    import time

    def embed_svd():
        if data_matrix.shape[1] <= 2:
            return data_matrix
        from sklearn.decomposition import TruncatedSVD
        return TruncatedSVD(n_components=2).fit_transform(data_matrix)

    def embed_mds():
        from sklearn import manifold
        from sklearn.metrics.pairwise import pairwise_distances
        distances = pairwise_distances(data_matrix)
        mds = manifold.MDS(n_components=2,
                           n_init=1,
                           max_iter=100,
                           dissimilarity='precomputed')
        return mds.fit_transform(distances)

    def embed_tsne():
        from sklearn import manifold
        tsne = manifold.TSNE(n_components=2,
                             init='random',
                             random_state=0)
        return tsne.fit_transform(data_matrix)

    def embed_kqst():
        from eden.embedding import quick_shift_tree_embedding
        return quick_shift_tree_embedding(data_matrix,
                                          knn=knn,
                                          knn_density=knn_density,
                                          k_threshold=k_threshold,
                                          **args)

    # (subplot position, embedding function, title builder)
    panels = (
        (221, embed_svd, lambda secs: "SVD (%.1f sec)" % secs),
        (222, embed_mds, lambda secs: "MDS (%.1f sec)" % secs),
        (223, embed_tsne, lambda secs: "t-SNE (%.1f sec)" % secs),
        (224, embed_kqst,
         lambda secs: "KQST knn=%d (%.1f sec)" % (knn, secs)),
    )

    plt.figure(figsize=(size, size))
    for position, embed, make_title in panels:
        tic = time.time()
        embedded = embed()
        elapsed = time.time() - tic
        plt.subplot(position)
        plot_embedding(embedded, y,
                       labels=labels,
                       title=make_title(elapsed),
                       cmap=cmap,
                       density=density,
                       image_file_name=image_file_name)

    if save_image_file_name:
        plt.savefig(save_image_file_name)
    else:
        plt.show()


def heatmap(values, xlabel, ylabel, xticklabels, yticklabels, cmap=None,
            vmin=None, vmax=None, ax=None, fmt="%0.2f"):
    """Draw a matrix as an annotated heatmap.

    Parameters
    ----------
    values : 2D array
        Cell values.
    xlabel, ylabel : str
        Axis labels.
    xticklabels, yticklabels : sequence
        Tick labels, placed at the center of the columns and rows.
    cmap, vmin, vmax
        Forwarded to ``ax.pcolor``.
    ax : matplotlib axes or None
        Axes to draw on; the current axes when None.
    fmt : str
        %-format applied to each cell value for the cell annotation.

    Returns
    -------
    The collection returned by ``ax.pcolor``.
    """
    if ax is None:
        ax = plt.gca()
    img = ax.pcolor(values, cmap=cmap, vmin=vmin, vmax=vmax)
    # Presumably needed so that get_facecolors() below already reflects the
    # colormapped cell values instead of waiting for the first draw.
    img.update_scalarmappable()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_xticks(np.arange(len(xticklabels)) + .5)
    ax.set_yticks(np.arange(len(yticklabels)) + .5)
    ax.set_xticklabels(xticklabels)
    ax.set_yticklabels(yticklabels)
    ax.set_aspect(1)

    # Flatten the value array: depending on the matplotlib version it is
    # 1-D (one entry per cell) or 2-D (shaped like ``values``); only the
    # flat form lines up with the paths and face colors.
    cell_values = img.get_array().ravel()
    for p, color, value in zip(img.get_paths(),
                               img.get_facecolors(),
                               cell_values):
        x, y = p.vertices[:-2, :].mean(0)
        # Dark text on light cells, light text on dark cells.
        if np.mean(color[:3]) > 0.5:
            c = 'k'
        else:
            c = 'w'
        ax.text(x, y, fmt % value, color=c, ha="center", va="center")
    return img


def plot_confusion_matrix(y_true, y_pred, size=None, normalize=False):
    """Plot the confusion matrix of a set of predictions as a heatmap.

    Parameters
    ----------
    y_true, y_pred : sequence
        True and predicted class labels.
    size : number or None
        When given a new square figure of that side (in inches) is
        created; otherwise the current figure is used.
    normalize : bool
        Divide every row by its total, so that each row shows the
        distribution of predictions for one true class.
    """
    cm = confusion_matrix(y_true, y_pred)
    fmt = "%d"
    if normalize:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        # A class that only occurs among the predictions has an all-zero
        # row; divide it by 1 so that it stays 0 instead of becoming NaN.
        cm = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)
        fmt = "%.2f"
    # confusion_matrix orders rows and columns by the sorted union of the
    # labels found in y_true and y_pred; both axes must use that same list.
    class_labels = sorted(set(y_true) | set(y_pred))
    if size is not None:
        plt.figure(figsize=(size, size))
    heatmap(cm, xlabel='Predicted label', ylabel='True label',
            xticklabels=class_labels, yticklabels=class_labels,
            cmap=plt.cm.Blues, fmt=fmt)
    if normalize:
        plt.title("Confusion matrix (norm.)")
    else:
        plt.title("Confusion matrix")
    plt.gca().invert_yaxis()


def plot_confusion_matrices(y_true, y_pred, size=12):
    """Show the raw and the normalized confusion matrix side by side."""
    plt.figure(figsize=(size, size))
    # Left panel: counts; right panel: normalized values.
    for position, normalized in ((121, False), (122, True)):
        plt.subplot(position)
        plot_confusion_matrix(y_true, y_pred, normalize=normalized)
    plt.tight_layout(w_pad=5)
    plt.show()


def plot_precision_recall_curve(y_true, y_score, size=None):
    """Plot precision against recall on the current axes.

    A new square figure is created when ``size`` is given. The average
    precision score is reported in the title.
    """
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    if size is not None:
        plt.figure(figsize=(size, size))
        plt.axis('equal')
    plt.plot(recall, precision, lw=2, color='navy')
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # Small margin around the unit square so the curve is not clipped.
    for set_limits in (plt.ylim, plt.xlim):
        set_limits([-0.05, 1.05])
    plt.grid()
    area = average_precision_score(y_true, y_score)
    plt.title('Precision-Recall AUC={0:0.2f}'.format(area))


def plot_roc_curve(y_true, y_score, size=None):
    """Plot the ROC curve on the current axes.

    A new square figure is created when ``size`` is given. The dashed
    diagonal joins (0, 0) and (1, 1); the ROC AUC score is reported in the
    title.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score)
    if size is not None:
        plt.figure(figsize=(size, size))
        plt.axis('equal')
    plt.plot(fpr, tpr, lw=2, color='navy')
    plt.plot([0, 1], [0, 1], color='gray', lw=1, linestyle='--')
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    # Small margin around the unit square so the curve is not clipped.
    for set_limits in (plt.ylim, plt.xlim):
        set_limits([-0.05, 1.05])
    plt.grid()
    area = roc_auc_score(y_true, y_score)
    plt.title('Receiver operating characteristic AUC={0:0.2f}'.format(area))


def plot_aucs(y_true, y_score, size=12):
    """Show the ROC curve and the precision-recall curve side by side."""
    plt.figure(figsize=(size, size / 2.0))
    panels = ((121, plot_roc_curve), (122, plot_precision_recall_curve))
    for position, plot_curve in panels:
        plt.subplot(position, aspect='equal')
        plot_curve(y_true, y_score)
    plt.tight_layout(w_pad=5)
    plt.show()