# wdecoster
"""
This module provides functions for plotting data extracted from Oxford Nanopore sequencing
reads and alignments, but some of it's functions can also be used for other applications.


FUNCTIONS
* Check if a specified color is a valid matplotlib color
check_valid_color(color)
* Check if a specified output format is valid
checkvalidFormat(format)
* Create a bivariate plot with dots, hexbins and/or kernel density estimates.
Also arguments for specifying axis names, color and xlim/ylim
scatter(x, y, names, path, color, format, plots, stat=None, log=False, minvalx=0, minvaly=0)
* Create cumulative yield plot and evaluate read length and quality over time
timePlots(df, path, color, format)
* Create length distribution histogram and density curve
lengthPlots(array, name, path, n50, color, format, log=False)
* Create flowcell physical layout in numpy array
makeLayout()
* Present the activity (number of reads) per channel on the flowcell as a heatmap
spatialHeatmap(array, title, path, color, format)

"""


import logging
import sys
import pandas as pd
import numpy as np
from collections import namedtuple
from nanoplotter.plot import Plot
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import colors as mcolors
import seaborn as sns
from pauvre.marginplot import margin_plot
from nanoplotter.timeplots import time_plots
from nanoplotter.spatial_heatmap import spatial_heatmap
from matplotlib import cm
import plotly
import plotly.graph_objs as go


def check_valid_color(color):
    """Check if the color provided by the user is valid.

    If color is invalid the default is returned.
    """
    if color in list(mcolors.CSS4_COLORS.keys()) + ["#4CB391"]:
        logging.info("NanoPlot:  Valid color {}.".format(color))
        return color
    else:
        logging.info("NanoPlot:  Invalid color {}, using default.".format(color))
        sys.stderr.write("Invalid color {}, using default.\n".format(color))
        return "#4CB391"


def check_valid_colormap(colormap):
    """Check if the colormap provided by the user is valid.

    If colormap is invalid the default is returned.
    """
    if colormap in list(cm.cmap_d.keys()):
        logging.info("NanoPlot:  Valid colormap {}.".format(colormap))
        return colormap
    else:
        logging.info("NanoPlot:  Invalid colormap {}, using default.".format(colormap))
        sys.stderr.write("Invalid colormap {}, using default.\n".format(colormap))
        return "Greens"


def check_valid_format(figformat):
    """Check if the specified figure format is valid.

    If format is invalid the default is returned.
    Probably installation-dependent
    """
    fig = plt.figure()
    if figformat in list(fig.canvas.get_supported_filetypes().keys()):
        logging.info("NanoPlot:  valid output format {}".format(figformat))
        return figformat
    else:
        logging.info("NanoPlot:  invalid output format {}".format(figformat))
        sys.stderr.write("Invalid format {}, using default.\n".format(figformat))
        return "png"


def plot_settings(plot_settings, dpi):
    sns.set(**plot_settings)
    mpl.rcParams['savefig.dpi'] = dpi


def scatter(x, y, names, path, plots, color="#4CB391", figformat="png",
            stat=None, log=False, minvalx=0, minvaly=0, title=None,
            plot_settings={}, xmax=None, ymax=None):
    """Create bivariate plots.

    Create four types of bivariate plots of x vs y, containing marginal summaries
    -A scatter plot with histograms on axes
    -A hexagonal binned plot with histograms on axes
    -A kernel density plot with density curves on axes
    -A pauvre-style plot using code from https://github.com/conchoecia/pauvre
    """
    logging.info("NanoPlot:  Creating {} vs {} plots using statistics from {} reads.".format(
        names[0], names[1], x.size))
    if not contains_variance([x, y], names):
        return []
    sns.set(style="ticks", **plot_settings)
    maxvalx = xmax or np.amax(x)
    maxvaly = ymax or np.amax(y)

    plots_made = []

    if plots["hex"]:
        if log:
            hex_plot = Plot(
                path=path + "_loglength_hex." + figformat,
                title="{} vs {} plot using hexagonal bins "
                      "after log transformation of read lengths".format(names[0], names[1]))
        else:
            hex_plot = Plot(
                path=path + "_hex." + figformat,
                title="{} vs {} plot using hexagonal bins".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x,
            y=y,
            kind="hex",
            color=color,
            stat_func=stat,
            space=0,
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            height=10)
        plot.set_axis_labels(names[0], names[1])
        if log:
            ticks = [10**i for i in range(10) if not 10**i > 10 * (10**maxvalx)]
            plot.ax_joint.set_xticks(np.log10(ticks))
            plot.ax_marg_x.set_xticks(np.log10(ticks))
            plot.ax_joint.set_xticklabels(ticks)
        plt.subplots_adjust(top=0.90)
        plot.fig.suptitle(title or "{} vs {} plot".format(names[0], names[1]), fontsize=25)
        hex_plot.fig = plot
        hex_plot.save(format=figformat)
        plots_made.append(hex_plot)

    sns.set(style="darkgrid", **plot_settings)
    if plots["dot"]:
        if log:
            dot_plot = Plot(
                path=path + "_loglength_dot." + figformat,
                title="{} vs {} plot using dots "
                      "after log transformation of read lengths".format(names[0], names[1]))
        else:
            dot_plot = Plot(
                path=path + "_dot." + figformat,
                title="{} vs {} plot using dots".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x,
            y=y,
            kind="scatter",
            color=color,
            stat_func=stat,
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            space=0,
            height=10,
            joint_kws={"s": 1})
        plot.set_axis_labels(names[0], names[1])
        if log:
            ticks = [10**i for i in range(10) if not 10**i > 10 * (10**maxvalx)]
            plot.ax_joint.set_xticks(np.log10(ticks))
            plot.ax_marg_x.set_xticks(np.log10(ticks))
            plot.ax_joint.set_xticklabels(ticks)
        plt.subplots_adjust(top=0.90)
        plot.fig.suptitle(title or "{} vs {} plot".format(names[0], names[1]), fontsize=25)
        dot_plot.fig = plot
        dot_plot.save(format=figformat)
        plots_made.append(dot_plot)

    if plots["kde"]:
        idx = np.random.choice(x.index, min(2000, len(x)), replace=False)
        if log:
            kde_plot = Plot(
                path=path + "_loglength_kde." + figformat,
                title="{} vs {} plot using a kernel density estimation "
                      "after log transformation of read lengths".format(names[0], names[1]))
        else:
            kde_plot = Plot(
                path=path + "_kde." + figformat,
                title="{} vs {} plot using a kernel density estimation".format(names[0], names[1]))
        plot = sns.jointplot(
            x=x[idx],
            y=y[idx],
            kind="kde",
            clip=((0, np.Inf), (0, np.Inf)),
            xlim=(minvalx, maxvalx),
            ylim=(minvaly, maxvaly),
            space=0,
            color=color,
            stat_func=stat,
            shade_lowest=False,
            height=10)
        plot.set_axis_labels(names[0], names[1])
        if log:
            ticks = [10**i for i in range(10) if not 10**i > 10 * (10**maxvalx)]
            plot.ax_joint.set_xticks(np.log10(ticks))
            plot.ax_marg_x.set_xticks(np.log10(ticks))
            plot.ax_joint.set_xticklabels(ticks)
        plt.subplots_adjust(top=0.90)
        plot.fig.suptitle(title or "{} vs {} plot".format(names[0], names[1]), fontsize=25)
        kde_plot.fig = plot
        kde_plot.save(format=figformat)
        plots_made.append(kde_plot)

    if plots["pauvre"] and names == ['Read lengths', 'Average read quality'] and log is False:
        pauvre_plot = Plot(
            path=path + "_pauvre." + figformat,
            title="{} vs {} plot using pauvre-style @conchoecia".format(names[0], names[1]))
        sns.set(style="white", **plot_settings)
        margin_plot(df=pd.DataFrame({"length": x, "meanQual": y}),
                    Y_AXES=False,
                    title=title or "Length vs Quality in Pauvre-style",
                    plot_maxlen=None,
                    plot_minlen=0,
                    plot_maxqual=None,
                    plot_minqual=0,
                    lengthbin=None,
                    qualbin=None,
                    BASENAME="whatever",
                    path=pauvre_plot.path,
                    fileform=[figformat],
                    dpi=600,
                    TRANSPARENT=True,
                    QUIET=True)
        plots_made.append(pauvre_plot)
    plt.close("all")
    return plots_made


def contains_variance(arrays, names):
    """
    Make sure both arrays for bivariate ("scatter") plot have a stddev > 0
    """
    for ar, name in zip(arrays, names):
        if np.std(ar) == 0:
            sys.stderr.write(
                "No variation in '{}', skipping bivariate plots.\n".format(name.lower()))
            logging.info("NanoPlot:  No variation in {}, skipping bivariate plot".format(name))
            return False
    else:
        return True


def length_plots(array, name, path, title=None, n50=None, color="#4CB391", figformat="png"):
    """Create histogram of normal and log transformed read lengths."""
    logging.info("NanoPlot:  Creating length plots for {}.".format(name))
    maxvalx = np.amax(array)
    if n50:
        logging.info("NanoPlot:  Using {} reads with read length N50 of {}bp and maximum of {}bp."
                     .format(array.size, n50, maxvalx))
    else:
        logging.info("NanoPlot:  Using {} reads maximum of {}bp.".format(array.size, maxvalx))

    plots = []
    HistType = namedtuple('HistType', 'weight name ylabel')
    for h_type in [HistType(None, "", "Number of reads"),
                   HistType(array, "Weighted ", "Number of bases")]:
        histogram = Plot(
            path=path + h_type.name.replace(" ", "_") + "Histogram" +
            name.replace(' ', '') + "." + figformat,
            title=h_type.name + "Histogram of read lengths")
        ax = sns.distplot(
            a=array,
            kde=False,
            hist=True,
            bins=max(round(int(maxvalx) / 500), 10),
            color=color,
            hist_kws=dict(weights=h_type.weight,
                          edgecolor=color,
                          linewidth=0.2,
                          alpha=0.8))
        if n50:
            plt.axvline(n50)
            plt.annotate('N50', xy=(n50, np.amax([h.get_height() for h in ax.patches])), size=8)
        ax.set(
            xlabel='Read length',
            ylabel=h_type.ylabel,
            title=title or histogram.title)
        plt.ticklabel_format(style='plain', axis='y')
        histogram.fig = ax.get_figure()
        histogram.save(format=figformat)
        plt.close("all")

        log_histogram = Plot(
            path=path + h_type.name.replace(" ", "_") + "LogTransformed_Histogram" +
            name.replace(' ', '') + "." + figformat,
            title=h_type.name + "Histogram of read lengths after log transformation")
        ax = sns.distplot(
            a=np.log10(array),
            kde=False,
            hist=True,
            color=color,
            hist_kws=dict(weights=h_type.weight,
                          edgecolor=color,
                          linewidth=0.2,
                          alpha=0.8))
        ticks = [10**i for i in range(10) if not 10**i > 10 * maxvalx]
        ax.set(
            xticks=np.log10(ticks),
            xticklabels=ticks,
            xlabel='Read length',
            ylabel=h_type.ylabel,
            title=title or log_histogram.title)
        if n50:
            plt.axvline(np.log10(n50))
            plt.annotate('N50', xy=(np.log10(n50), np.amax(
                [h.get_height() for h in ax.patches])), size=8)
        plt.ticklabel_format(style='plain', axis='y')
        log_histogram.fig = ax.get_figure()
        log_histogram.save(format=figformat)
        plt.close("all")
        plots.extend([histogram, log_histogram])
    plots.append(dynamic_histogram(array=array, name=name, path=path, title=title, color=color))
    plots.append(yield_by_minimal_length_plot(array=array,
                                              name=name,
                                              path=path,
                                              title=title,
                                              color=color,
                                              figformat=figformat))
    return plots


def dynamic_histogram(array, name, path, title=None, color="#4CB391"):
    """
    Use plotly to a histogram
    Return html code, but also save as png
    """
    dynhist = Plot(path=path + "Dynamic_Histogram_{}.html".format(name.replace(' ', '_')),
                   title=title or "Dynamic histogram of {}".format(name))
    ylabel = "Number of reads" if len(array) <= 10000 else "Downsampled number of reads"
    dynhist.html, dynhist.fig = plotly_histogram(array=array.sample(min(len(array), 10000)),
                                                 color=color,
                                                 title=dynhist.title,
                                                 xlabel=name,
                                                 ylabel=ylabel)
    dynhist.save()
    return dynhist


def plotly_histogram(array, color="#4CB391", title=None, xlabel=None, ylabel=None):
    data = [go.Histogram(x=array,
                         opacity=0.4,
                         marker=dict(color=color))]
    html = plotly.offline.plot(
        {"data": data,
         "layout": go.Layout(barmode='overlay',
                             title=title,
                             yaxis_title=ylabel,
                             xaxis_title=xlabel)},
        output_type="div",
        show_link=False)
    fig = go.Figure(
        {"data": data,
         "layout": go.Layout(barmode='overlay',
                             title=title)})
    return html, fig


def yield_by_minimal_length_plot(array, name, path,
                                 title=None, color="#4CB391", figformat="png"):
    df = pd.DataFrame(data={"lengths": np.sort(array)[::-1]})
    df["cumyield_gb"] = df["lengths"].cumsum() / 10**9
    yield_by_length = Plot(
        path=path + "Yield_By_Length." + figformat,
        title="Yield by length")
    ax = sns.regplot(
        x='lengths',
        y="cumyield_gb",
        data=df,
        x_ci=None,
        fit_reg=False,
        color=color,
        scatter_kws={"s": 3})
    ax.set(
        xlabel='Read length',
        ylabel='Cumulative yield for minimal length',
        title=title or yield_by_length.title)
    yield_by_length.fig = ax.get_figure()
    yield_by_length.save(format=figformat)
    plt.close("all")
    return yield_by_length


def run_tests():
    import pickle
    df = pickle.load(open("nanotest/sequencing_summary.pickle", "rb"))
    scatter(
        x=df["lengths"],
        y=df["quals"],
        names=['Read lengths', 'Average read quality'],
        path="LengthvsQualityScatterPlot",
        plots={'dot': 1, 'kde': 1, 'hex': 1, 'pauvre': 1},
        plot_settings=dict(font_scale=1))
    time_plots(
        df=df,
        path=".",
        color="#4CB391",
        plot_settings=dict(font_scale=1))
    length_plots(
        array=df["lengths"],
        name="lengths",
        path=".")
    spatial_heatmap(
        array=df["channelIDs"],
        title="Number of reads generated per channel",
        path="ActivityMap_ReadsPerChannel")


if __name__ == "__main__":
    run_tests()