import matplotlib

matplotlib.use("Agg")  # forces matplotlib to not launch X11 window
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import seaborn as sns
import pandas as pd
import cv2

from scipy import stats
from matplotlib.patches import Patch
from matplotlib.lines import Line2D

__all__ = ["focus_plot"]


def make_scatter(twas_df):
    """
    Make a scatterplot of zscore values with gene names as xtick labels.

    :param twas_df: pandas.DataFrame containing at least zscores and gene-names

    :return: numpy.ndarray (RGB) formatted scatterplot of zscores
    """
    mpl.rcParams["figure.figsize"] = [6.4, 4.8]
    fig, ax = plt.subplots()

    size_arr = []
    color_arr = []
    custom_palette = ["#e4f1fe", "#bdd7e7", "#6baed6", "#3182bd", "#08519c"]
    size_palette = [2, 4, 8, 10, 12]

    for i, row in twas_df.iterrows():
        pip = row["pip"]
        if pip < 0.2:
            color_arr.append(custom_palette[0])
            size_arr.append(size_palette[0])
        elif 0.2 <= pip < 0.40:
            color_arr.append(custom_palette[1])
            size_arr.append(size_palette[1])
        elif 0.40 <= pip < 0.60:
            color_arr.append(custom_palette[2])
            size_arr.append(size_palette[2])
        elif 0.60 <= pip < 0.80:
            color_arr.append(custom_palette[3])
            size_arr.append(size_palette[3])
        else:
            color_arr.append(custom_palette[4])
            size_arr.append(size_palette[4])

    n_rows = len(twas_df.index)
    x_values = np.arange(1, n_rows + 1)
    ax.scatter(x=x_values, y=twas_df["logp"].values, s=size_arr, c=color_arr, edgecolor="black")

    # create legend
    legend_elements = [
        Line2D([0], [0], marker="o", color="w", label="[0.0, 0.2)",
               markerfacecolor=custom_palette[0], markersize=size_palette[0], markeredgecolor="k"),
        Line2D([1], [1], marker="o", color="w", label="[0.2, 0.4)",
               markerfacecolor=custom_palette[1], markersize=size_palette[1], markeredgecolor="k"),
        Line2D([2], [2], marker="o", color="w", label="[0.4, 0.6)",
               markerfacecolor=custom_palette[2], markersize=size_palette[2], markeredgecolor="k"),
        Line2D([3], [3], marker="o", color="w", label="[0.6, 0.8)",
               markerfacecolor=custom_palette[3], markersize=size_palette[3], markeredgecolor="k"),
        Line2D([4], [4], marker="o", color="w", label="[0.8, 1.0]",
               markerfacecolor=custom_palette[4], markersize=size_palette[4], markeredgecolor="k")]
    plt.legend(handles=legend_elements, loc="best", title="PIP")

    n_rows = len(twas_df.index)
    gene_names = twas_df["mol_name"].values

    plt.xticks(np.arange(1, n_rows + 1, 1.0))
    plt.xticks(x_values, labels=gene_names, rotation="vertical")
    plt.subplots_adjust(bottom=0.25)  # make room for xlabel

    plt.ylabel("$-\log_{10}(p)$", fontsize=18)
    plt.xlabel("")

    # drop right/top axis bars
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)

    fig.tight_layout()
    fig.canvas.draw()

    # save as numpy.ndarray format
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    scatter_plot = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    return scatter_plot


def heatmap(wcor):
    """
    Make a scatterplot of zscore values with gene names as xtick labels.

    :param wcor: numpy.ndarray matrix of sample correlation structure for predicted expression

    :return: numpy.ndarray (RGB) formatted heatmap of correlation structure
    """
    mpl.rcParams["figure.figsize"] = [6.4, 6.4]
    fig = plt.figure()
    fig.subplots_adjust(bottom=0.20, left=0.28)
    mask = np.zeros_like(wcor, dtype=np.bool)
    mask[np.triu_indices_from(mask)] = True
    ax = sns.heatmap(wcor, mask=mask, cmap="RdBu_r", square=True,
                     linewidths=0, cbar=False, xticklabels=False, yticklabels=False, ax=None,
                     vmin=-1, vmax=1)
    ax.margins(2)
    ax.set_aspect("equal", "box")
    fig.canvas.draw()

    # save image as numpy array
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    img = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    # rotate heatmap to make upside-down triangle shape
    rows, cols, ch = img.shape
    M = cv2.getRotationMatrix2D((cols / 2, rows / 2), 45, 1)
    dst = cv2.warpAffine(img, M, (cols, rows), borderMode=cv2.BORDER_CONSTANT,
                         borderValue=(255, 255, 255))

    # trim extra whitespace
    crop_img = dst[int(dst.shape[0] / 2.5):int(dst.shape[0] / 1.1)]

    return crop_img


def heatmap_colorbar():
    """
    Make a colorbar legend for correlation heatmap for range [-1,1].

    :param: None

    :return: numpy.ndarray (RGB) formatted colorbar for range [-1,1]
    """
    fig = plt.figure(figsize=(3.0, 1.0))
    ax1 = fig.add_axes([0.05, 0.80, 0.9, 0.15])
    norm = mpl.colors.Normalize(vmin=-1, vmax=1)
    cmap = mpl.cm.get_cmap("RdBu_r")
    mpl.colorbar.ColorbarBase(ax1, cmap=cmap, norm=norm, orientation="horizontal", ticks=[-1, -.50, 0, .50, 1])

    # convert to numpy array format
    fig.canvas.draw()
    data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
    colorbar = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

    # reshape and fill to match width of heatmap
    new_size = (colorbar.shape[0], colorbar.shape[1])
    desired_size_w = 640
    desired_size_h = 100
    delta_w = desired_size_w - new_size[1]
    delta_h = desired_size_h - new_size[0]
    top, bottom = delta_h // 2, delta_h - (delta_h // 2)
    left, right = delta_w // 2, delta_w - (delta_w // 2)
    left += 18  # make colorbar line up with heatmap
    right -= 18

    # fill in extra space with whitespace
    color = [255, 255, 255]
    colorbar = cv2.copyMakeBorder(colorbar, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)

    return colorbar


def focus_plot(wcor, twas_df):
    """
    Plot zscores and local correlation structure for TWAS fine-mapping.

    :param twas_df: pandas.DataFrame containing at least zscores and gene-names
    :param wcor: numpy.ndarray matrix of sample correlation structure for predicted expression

    :return: matplotlib.figure.Figure object containing plot of zscores and local correlation heatmap
    """

    # filter out the null model
    twas_df = twas_df[twas_df["ens_gene_id"] != "NULL.MODEL"]

    # add p-value to compute -log10 p
    twas_df = twas_df.assign(logp=-stats.chi2.logsf(twas_df["twas_z"].values ** 2, 1))

    scatter_plot = make_scatter(twas_df)
    crop_img = heatmap(wcor)
    colorbar = heatmap_colorbar()

    # combine plots
    numpy_vertical_concat = np.concatenate((scatter_plot, crop_img), axis=0)
    numpy_vertical_concat = np.concatenate((numpy_vertical_concat, colorbar), axis=0)
    numpy_vertical_concat = cv2.resize(numpy_vertical_concat, (0, 0), fx=2.5, fy=2.5)

    fig = plt.figure()
    plt.imshow(numpy_vertical_concat)
    plt.title("")
    plt.axis("off")
    plot_arr = [fig]
    plt.close("all")

    return plot_arr