# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
"""Helper functions for the figures in the paper."""
import os.path as op
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.font_manager import FontProperties


def plot_qi2(x_grid, ref_pdf, fit_pdf, ref_data, cutoff_idx, out_file=None):
    """
    Plot a reference distribution, its histogram and a fitted curve.

    Parameters
    ----------
    x_grid : array-like
        Abscissa on which ``ref_pdf`` and ``fit_pdf`` are evaluated.
    ref_pdf : array-like
        Reference curve, drawn with the legend entry "background".
    fit_pdf : array-like
        Fitted curve, drawn with the legend entry "chi2". Samples above 1.0
        are masked (not drawn). The caller's array is left untouched.
    ref_data : array-like
        Raw samples shown as a normalized histogram.
    cutoff_idx : int
        Offset, counted from the end of ``x_grid``, at which the vertical
        marker is drawn.
    out_file : str, optional
        Output path; defaults to ``qi2_plot.svg`` in the current directory.

    Returns
    -------
    out_file : str
        Path of the written figure.

    """
    fig, ax = plt.subplots()

    ax.plot(
        x_grid, ref_pdf, linewidth=2, alpha=0.5, label="background", color="dodgerblue"
    )

    refmax = np.percentile(ref_data, 99.95)
    x_max = x_grid[-1]

    # ``density`` is the supported spelling of the former ``normed`` keyword
    ax.hist(
        ref_data,
        40 * max(int(refmax / x_max), 1),
        fc="dodgerblue",
        histtype="stepfilled",
        alpha=0.2,
        density=True,
    )

    # Mask on a private float copy so the input array is not overwritten
    fit_curve = np.array(fit_pdf, dtype=float)
    fit_curve[fit_curve > 1.0] = np.nan
    ax.plot(
        x_grid, fit_curve, linewidth=2, alpha=0.5, label="chi2", color="darkorange"
    )

    # ``ymax`` of axvline is expressed as a fraction of the axes height
    ylims = ax.get_ylim()
    ax.axvline(
        x_grid[-cutoff_idx], ymax=ref_pdf[-cutoff_idx] / ylims[1], color="dodgerblue"
    )
    ax.set_xlabel('Intensity within "hat" mask')
    ax.set_ylabel("Frequency")
    ax.set_xlim([0, x_max])
    ax.legend()

    if out_file is None:
        out_file = op.abspath("qi2_plot.svg")

    fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300)
    return out_file


def plot_batches(fulldata, cols=None, out_file=None, site_labels="left"):
    """
    Draw a heatmap of min-max scaled features with one row per sample.

    Rows are sorted by ``database`` and then ``site``; a white horizontal
    line is drawn at the first row of every site but the first one, and each
    site name is placed at the average row position of its samples.

    Parameters
    ----------
    fulldata : pandas.DataFrame
        Table with (at least) the ``database`` and ``site`` columns.
    cols : list of str, optional
        Columns to display. When ``None``, all numeric columns are shown.
    out_file : str, optional
        If given, the figure is also saved to this path.
    site_labels : str
        ``"right"`` moves the site labels to the right side of the axes; any
        other value leaves them on the left.

    Returns
    -------
    fig : matplotlib.figure.Figure

    """
    fulldata = fulldata.sort_values(by=["database", "site"]).copy()
    if cols is None:
        numdata = fulldata.select_dtypes([np.number])
    else:
        numdata = fulldata[cols]

    # Min-max scaling, column by column
    numdata = numdata - numdata.min()
    numdata = numdata / numdata.max()

    fig, ax = plt.subplots(figsize=(20, 10))
    ax.imshow(
        numdata.values, cmap=plt.cm.viridis, interpolation="nearest", aspect="auto"
    )

    # Walk the sites in the order they appear in the sorted table, so that
    # the first entry of ``spines`` is row 0 (which needs no separator).
    site_column = fulldata.site.values
    site_names = fulldata.site.unique().tolist()
    row_numbers = np.arange(len(fulldata))

    locations = []
    spines = []
    for site in site_names:
        indices = row_numbers[site_column == site]
        locations.append(int(np.average(indices)))
        spines.append(indices[0])

    if site_labels == "right":
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

    ax.set_xticks(range(numdata.shape[1]))
    ax.set_xticklabels(numdata.columns.tolist(), rotation="vertical")
    ax.set_yticks(locations)
    ax.set_yticklabels(site_names)
    for line in spines[1:]:
        ax.axhline(y=line, color="w", linestyle="-")
    for side in ("right", "top", "bottom"):
        ax.spines[side].set_visible(False)
    ax.grid(False)

    # Larger font for the site labels, smaller one for the feature names
    for labels, fontsize in ((ax.get_yticklabels(), 14), (ax.get_xticklabels(), 12)):
        ticks_font = FontProperties(
            family="FreeSans",
            style="normal",
            size=fontsize,
            weight="normal",
            stretch="normal",
        )
        for label in labels:
            label.set_fontproperties(ticks_font)

    if out_file is not None:
        fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300)
    return fig


def plot_roc_curve(true_y, prob_y, out_file=None):
    """
    Plot the ROC curve of a set of predictions.

    Parameters
    ----------
    true_y : array-like
        Ground-truth labels, forwarded to ``sklearn.metrics.roc_curve``.
    prob_y : array-like
        Predicted scores, forwarded to ``sklearn.metrics.roc_curve``.
    out_file : str, optional
        If given, the figure is also saved to this path.

    Returns
    -------
    fig : matplotlib.figure.Figure

    """
    from sklearn.metrics import roc_curve

    false_pos, true_pos, _thresholds = roc_curve(true_y, prob_y)
    axis_range = [-0.025, 1.025]

    fig = plt.figure()
    ax = fig.gca()
    ax.plot(false_pos, true_pos, color="darkorange", lw=2, label="ROC curve")
    # Diagonal reference line from (0, 0) to (1, 1)
    ax.plot([0, 1], [0, 1], color="navy", lw=1, linestyle="--")
    ax.set_xlim(axis_range)
    ax.set_ylim(axis_range)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("RoC Curve")

    if out_file is None:
        return fig

    fig.savefig(out_file)
    return fig


def fill_matrix(matrix, width, value="n/a"):
    """
    Pad a 2D array with rows of ``value`` until it has ``width`` rows.

    Parameters
    ----------
    matrix : numpy.ndarray
        Two-dimensional array to be padded along its first axis.
    width : int
        Minimum number of rows of the result.
    value : str
        Filler written into every cell of the appended rows.

    Returns
    -------
    matrix : numpy.ndarray
        The input itself when it already has ``width`` rows or more,
        otherwise a new array with the filler rows stacked at the bottom.

    """
    missing = width - matrix.shape[0]
    if missing <= 0:
        return matrix

    # np.full keeps ``value`` as text; the legacy chararray stored bytes,
    # which never compare equal to the string that was asked for.
    filler = np.full((missing, matrix.shape[1]), value)
    return np.vstack((matrix, filler))


def plot_raters(dataframe, ax=None, width=101, size=0.40):
    """
    Draw one colored dot per (sample, rater) rating.

    Samples run along the x axis and raters along the y axis. Ratings are
    matched by their string form: ``"1.0"`` is green, ``"0.0"`` gray and
    ``"-1.0"`` red; anything else is drawn white. When there are more than
    ``width`` samples, they are wrapped into several blocks of ``width``
    samples stacked along the y axis. The share of each rating, computed over
    the non-missing ratings of every rater, is printed left of the dots.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        One column per rater, one row per sample. It is not modified.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.
    width : int
        Number of samples per block.
    size : float
        Radius of the dots.

    Returns
    -------
    ax : matplotlib.axes.Axes

    """
    raters = sorted(dataframe.columns.tolist())

    # Work on a copy so the helper column below never reaches the caller
    dataframe = dataframe.copy()

    # Count real ratings before stringifying (NaN would become "nan")
    nvalid = dataframe[raters].notnull().sum()

    # Flag rows with any missing rating so that complete rows sort first
    dataframe["notnan"] = dataframe[raters].isnull().any(axis=1).astype(int)
    dataframe = dataframe.sort_values(by=["notnan"] + raters, ascending=True)
    for rater in raters:
        dataframe[rater] = dataframe[rater].astype(str)

    matrix = dataframe.values
    # NOTE: ``nraters`` also counts the "notnan" helper column. Its integer
    # values match no palette key, so it is drawn as a row of white dots.
    nsamples, nraters = dataframe.shape
    matrix = fill_matrix(matrix, width)

    nrows = ((nsamples - 1) // width) + 1
    if matrix.shape[0] > width:
        # Wrap the samples into ``nrows`` blocks laid side by side along the
        # second axis, with a one-column blank separator between blocks.
        separator = np.full((width, 1), "n/a")
        blocks = []
        for i in range(nrows):
            if i > 0:
                blocks.append(separator)
            blocks.append(matrix[i * width:(i + 1) * width, ...])

        blocks[-1] = fill_matrix(blocks[-1], width)
        matrix = np.hstack(tuple(blocks))

    palette = {"1.0": "limegreen", "0.0": "dimgray", "-1.0": "tomato", "n/a": "w"}

    ax = ax if ax is not None else plt.gca()

    ax.set_aspect("equal", "box")
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    xlims = (-14.0, width)
    ylims = (-0.07 * nraters, nrows * nraters + nraters * 0.07 + (nrows - 1))

    ax.set_xlim(xlims)
    ax.set_ylim(ylims)

    offset = 0.5 * (size / 0.40)
    for (x, y), w in np.ndenumerate(matrix):
        # Unknown values (missing ratings, helper column, filler) are white
        color = palette.get(w, palette["n/a"])
        dot = plt.Circle(
            [x + offset, y + offset], size, facecolor=color, edgecolor=color
        )
        ax.add_patch(dot)

    text_x = -8.5
    for i, rname in enumerate(raters):
        total = int(nvalid[rname])
        ratings = dataframe[rname]
        if total > 0:
            good = 100 * (ratings == "1.0").sum() / total
            bad = 100 * (ratings == "-1.0").sum() / total
            other = max((0.0, 100 - good - bad))
        else:
            # No ratings at all for this rater: avoid dividing by zero
            good = bad = other = 0.0

        text_y = 1.5 * i + (nrows - 1) * 2.0
        for shift, percent, color in (
            (0.0, good, "limegreen"),
            (3.50, other, "dimgray"),
            (7.0, bad, "tomato"),
        ):
            ax.text(
                text_x + shift,
                text_y,
                "%2.0f%%" % percent,
                color=color,
                weight=1000,
                size=16,
                horizontalalignment="right",
                verticalalignment="center",
                transform=ax.transData,
            )

    ax.invert_yaxis()
    # Act on the axes being drawn, which need not be the current axes
    ax.grid(False)

    # Remove and redefine spines
    for side in ["top", "right", "bottom"]:
        ax.spines[side].set_color("none")
        ax.spines[side].set_visible(False)

    ax.spines["left"].set_linewidth(1.5)
    ax.spines["left"].set_color("dimgray")

    # Single tick at mid-height; callers may attach a label to it
    ax.set_yticks([0.5 * (ylims[0] + ylims[1])])
    ax.tick_params(axis="y", which="major", pad=15)

    ticks_font = FontProperties(
        family="FreeSans", style="normal", size=20, weight="normal", stretch="normal"
    )
    for label in ax.get_yticklabels():
        label.set_fontproperties(ticks_font)

    return ax


def raters_variability_plot(
    mdata,
    figsize=(22, 22),
    width=101,
    out_file=None,
    raters=("rater_1", "rater_2", "rater_3"),
    only_overlap=True,
    rater_names=("Rater 1", "Rater 2a", "Rater 2b"),
):
    """
    Plot the ratings of several raters, one panel per site, plus a summary.

    Sites are stacked from the smallest to the largest (by number of rows),
    each drawn with :func:`plot_raters`. An inset lists, for every rater, the
    aggregated percentage of each rating.

    Parameters
    ----------
    mdata : pandas.DataFrame
        Table with a ``site`` column and one column per rater.
    figsize : tuple
        Size of the figure.
    width : int
        Number of samples per block, forwarded to :func:`plot_raters`.
    out_file : str, optional
        Output path; defaults to ``raters.svg``. Extensions other than pdf,
        svg or png are replaced with ``.svg``.
    raters : sequence of str
        Names of the columns holding the ratings.
    only_overlap : bool
        Keep only the rows rated by all ``raters``.
    rater_names : sequence of str
        Display names for ``raters``; the column name is used for any rater
        without a matching entry.

    Returns
    -------
    fig : matplotlib.figure.Figure

    Raises
    ------
    ValueError
        If no rows are left to plot.

    """
    # A tuple would be read by pandas as ONE (multi-level) column key
    raters = list(raters)

    if only_overlap:
        mdata = mdata[mdata[raters].notnull().all(axis=1)]

    sites_list = sorted(set(mdata.site.values.ravel().tolist()))
    if not sites_list:
        raise ValueError("No samples left to plot")

    sites_len = [len(mdata.loc[mdata.site == site]) for site in sites_list]
    # Order the sites by increasing number of samples
    sites_len, sites_list = zip(*sorted(zip(sites_len, sites_list)))

    # Number of ``width``-wide blocks each site needs sets its panel height
    blocks = [(slen - 1) // width + 1 for slen in sites_len]
    fig = plt.figure(figsize=figsize)
    gs = GridSpec(
        len(sites_list), 1, width_ratios=[1], height_ratios=blocks, hspace=0.05
    )

    for i, site in enumerate(sites_list):
        ax = plt.subplot(gs[i])
        plot_raters(
            mdata.loc[mdata.site == site, raters],
            ax=ax,
            width=width,
            size=0.40 if len(raters) == 3 else 0.80,
        )
        ax.set_yticklabels([site])

    # Inset with the aggregated percentages
    newax = plt.axes([0.6, 0.65, 0.25, 0.16])
    newax.grid(False)
    newax.set_xticklabels([])
    newax.set_xticks([])
    newax.set_yticklabels([])
    newax.set_yticks([])

    text_x = 0.92
    for i, rater in enumerate(raters):
        ratings = mdata[rater]
        nsamples = int(ratings.notnull().sum())
        if nsamples > 0:
            good = 100 * (ratings == 1.0).sum() / nsamples
            bad = 100 * (ratings == -1.0).sum() / nsamples
            other = max((0.0, 100 - good - bad))
        else:
            # Rater without any rating: avoid dividing by zero
            good = bad = other = 0.0

        text_y = 0.5 - 0.17 * i
        for shift, percent, color in (
            (0.36, good, "limegreen"),
            (0.18, other, "dimgray"),
            (0.0, bad, "tomato"),
        ):
            newax.text(
                text_x - shift,
                text_y,
                "%2.1f%%" % percent,
                color=color,
                weight=1000,
                size=25,
                horizontalalignment="right",
                verticalalignment="center",
                transform=newax.transAxes,
            )

        newax.text(
            1 - text_x,
            text_y,
            rater_names[i] if i < len(rater_names) else rater,
            color="k",
            size=25,
            horizontalalignment="left",
            verticalalignment="center",
            transform=newax.transAxes,
        )

    for text_y, caption in ((0.95, "Imbalance of ratings"), (0.85, "(ABIDE, aggregated)")):
        newax.text(
            0.5,
            text_y,
            caption,
            color="k",
            size=25,
            horizontalalignment="center",
            verticalalignment="top",
            transform=newax.transAxes,
        )

    if out_file is None:
        out_file = "raters.svg"

    fname, ext = op.splitext(out_file)
    if ext[1:] not in ["pdf", "svg", "png"]:
        ext = ".svg"
        out_file = fname + ".svg"

    fig.savefig(
        op.abspath(out_file), format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=300
    )
    return fig


def plot_abide_stripplots(
    inputs, figsize=(15, 2), out_file=None, rating_label="rater_1", dpi=100
):
    """
    Draw strip plots of every feature, before and after batch scaling.

    One row is drawn per feature, with four panels: raw values by site, raw
    values by database, scaled values by database and scaled values by site.
    The scaled table is the output of ``BatchRobustScaler(by="site")``.

    Parameters
    ----------
    inputs : list of tuple
        ``(X, Y, sitename)`` triplets, each forwarded to ``read_dataset``.
    figsize : tuple
        Width of the figure and height of ONE row of panels.
    out_file : str, optional
        Output path; defaults to ``stripplot.svg``. Extensions other than
        pdf, svg or png are replaced with ``.svg``.
    rating_label : str
        Column with the ratings used to color the points.
    dpi : int
        Resolution of the saved figure.

    Returns
    -------
    fig : matplotlib.figure.Figure

    """
    import seaborn as sn
    from ..classifier.helper import FEATURE_NORM
    from ..classifier.data import read_dataset
    from ..classifier.sklearn.preprocessing import BatchRobustScaler

    sn.set(style="whitegrid")

    mdata = []
    pp_cols = []

    for X, Y, sitename in inputs:
        sitedata, cols = read_dataset(
            X, Y, rate_label=rating_label, binarize=False, site_name=sitename
        )
        sitedata["database"] = [sitename] * len(sitedata)

        if sitename == "DS030":
            sitedata["site"] = [sitename] * len(sitedata)

        mdata.append(sitedata)
        pp_cols.append(cols)

    mdata = pd.concat(mdata)
    # The feature list of the first input is used for all of them
    pp_cols = pp_cols[0]

    # Drop the ratings of every rater other than the selected one
    for col in mdata.columns.tolist():
        if col.startswith("rater_") and col != rating_label:
            del mdata[col]

    mdata = mdata.loc[mdata[rating_label].notnull()]

    for col in ["size_x", "size_y", "size_z", "spacing_x", "spacing_y", "spacing_z"]:
        # Tolerate absent columns, as the feature list below already does
        if col in mdata.columns:
            del mdata[col]
        try:
            pp_cols.remove(col)
        except ValueError:
            pass

    zscored = BatchRobustScaler(by="site", columns=FEATURE_NORM).fit_transform(mdata)

    nsites = len(set(mdata.site.values.ravel()))

    palette = ["limegreen", "tomato"]
    if len(set(mdata[[rating_label]].values.ravel().tolist())) == 3:
        palette = ["tomato", "gold", "limegreen"]
    nrows = len(pp_cols)

    fig = plt.figure(figsize=(figsize[0], figsize[1] * nrows))
    gs = GridSpec(nrows, 4, wspace=0.02)
    gs.set_width_ratios([nsites, len(inputs), len(inputs), nsites])

    for i, colname in enumerate(pp_cols):
        ax_nzs = plt.subplot(gs[i, 0])
        axg_nzs = plt.subplot(gs[i, 1])
        axg_zsc = plt.subplot(gs[i, 2])
        ax_zsc = plt.subplot(gs[i, 3])

        panels = (
            (ax_nzs, "site", mdata),
            (ax_zsc, "site", zscored),
            (axg_nzs, "database", mdata),
            (axg_zsc, "database", zscored),
        )
        for axis, groupby, table in panels:
            # ``dodge`` is the current name of the former ``split`` keyword
            sn.stripplot(
                x=groupby,
                y=colname,
                data=table,
                hue=rating_label,
                jitter=0.18,
                alpha=0.6,
                dodge=True,
                palette=palette,
                ax=axis,
            )
            legend = axis.get_legend()
            if legend is not None:
                legend.remove()

            # Only the bottom row keeps its x tick labels
            if i == nrows - 1:
                axis.set_xticklabels(axis.xaxis.get_majorticklabels(), rotation=80)
            else:
                axis.set_xticklabels([])

        ax_nzs.set_xlabel("", visible=False)
        ax_zsc.set_xlabel("", visible=False)
        ax_zsc.set_ylabel("", visible=False)
        ax_zsc.yaxis.tick_right()

        for axis in (axg_nzs, axg_zsc):
            axis.set_yticklabels([])
            axis.set_xlabel("", visible=False)
            axis.set_ylabel("", visible=False)

        # Keep only the first and last y tick labels of the raw panels
        for axis in (ax_nzs, axg_nzs):
            for yt in axis.yaxis.get_major_ticks()[1:-1]:
                yt.label1.set_visible(False)

        for yt in zip(
            ax_zsc.yaxis.get_majorticklabels(), axg_zsc.yaxis.get_majorticklabels()
        ):
            yt[0].set_visible(False)
            yt[1].set_visible(False)

    if out_file is None:
        out_file = "stripplot.svg"

    fname, ext = op.splitext(out_file)
    if ext[1:] not in ["pdf", "svg", "png"]:
        ext = ".svg"
        out_file = fname + ".svg"

    fig.savefig(
        op.abspath(out_file), format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=dpi
    )
    return fig


def plot_corrmat(in_csv, out_file=None):
    """
    Plot a hierarchically-clustered correlation matrix of a CSV table.

    The ``subject_id``, ``site`` and ``modality`` columns are ignored, and so
    is any non-numeric column. Rows and columns of the correlation matrix
    that are entirely undefined are dropped before plotting.

    Parameters
    ----------
    in_csv : str
        Path of the input CSV file.
    out_file : str, optional
        Output path; defaults to ``corr_matrix.svg``. Extensions other than
        pdf, svg or png are replaced with ``.svg``.

    Returns
    -------
    corrplot
        The object returned by ``seaborn.clustermap``.

    """
    import seaborn as sn

    sn.set(style="whitegrid")

    # NOTE(review): ``na_filter=False`` turns off missing-value detection, so
    # ``na_values`` looks ineffective here -- confirm which one is intended.
    dataframe = pd.read_csv(in_csv, index_col=False, na_values="n/a", na_filter=False)

    excluded = ("subject_id", "site", "modality")
    colnames = [col for col in dataframe.columns.tolist() if col not in excluded]

    # Correlation matrix, restricted explicitly to numeric/boolean columns
    # (recent pandas raises on text columns instead of skipping them)
    numeric = dataframe[colnames].select_dtypes(include=[np.number, "bool"])
    corr = numeric.corr()
    corr = corr.dropna(axis=0, how="all").dropna(axis=1, how="all")

    # Generate a custom diverging colormap
    cmap = sn.diverging_palette(220, 10, as_cmap=True)

    corrplot = sn.clustermap(
        corr, cmap=cmap, center=0.0, method="average", square=True, linewidths=0.5
    )
    plt.setp(corrplot.ax_heatmap.yaxis.get_ticklabels(), rotation="horizontal")

    if out_file is None:
        out_file = "corr_matrix.svg"

    fname, ext = op.splitext(out_file)
    if ext[1:] not in ["pdf", "svg", "png"]:
        ext = ".svg"
        out_file = fname + ".svg"

    corrplot.savefig(
        out_file, format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=100
    )
    return corrplot


def plot_histograms(X, Y, rating_label="rater_1", out_file=None):
    """
    Plot per-rating histograms of the spacing/summary/size features.

    Each selected feature gets one row with two panels. Both panels currently
    show the same data, because the second table is a plain copy of the first
    (see the TODO below).

    Parameters
    ----------
    X, Y
        Inputs forwarded to ``read_dataset``.
    rating_label : str
        Column with the ratings; rows without a rating are discarded.
    out_file : str, optional
        Output path; defaults to ``histograms.svg``. Extensions other than
        pdf, svg or png are replaced with ``.svg``.

    Returns
    -------
    fig : matplotlib.figure.Figure

    """
    import re
    import seaborn as sn
    from ..classifier.data import read_dataset

    sn.set(style="whitegrid")

    mdata, pp_cols = read_dataset(X, Y, rate_label=rating_label)
    mdata["rater"] = mdata[[rating_label]].values.ravel()

    # Only the freshly created "rater" column survives this clean-up
    for name in mdata.columns.ravel().tolist():
        if name.startswith("rater_"):
            del mdata[name]

    mdata = mdata.loc[mdata.rater.notnull()]
    zscored = mdata.copy()
    # TODO: zscore_dataset was removed
    # zscored = zscore_dataset(
    #     mdata, excl_columns=['rater', 'size_x', 'size_y', 'size_z',
    #                          'spacing_x', 'spacing_y', 'spacing_z'])
    selector = re.compile(r"^(spacing|summary|size)")
    colnames = sorted(name for name in pp_cols if selector.match(name))

    nrows = len(colnames)

    fig = plt.figure(figsize=(18, 2 * nrows))
    gs = GridSpec(nrows, 2, hspace=0.2)

    # (value of the rating, legend label, color)
    groups = ((0, "Accept", "dodgerblue"), (1, "Reject", "darkorange"))

    for row, col in enumerate(colnames):
        ax_nzs = plt.subplot(gs[row, 0])
        ax_zsd = plt.subplot(gs[row, 1])

        for table, axis in ((mdata, ax_nzs), (zscored, ax_zsd)):
            for rating, label, color in groups:
                sn.distplot(
                    table.loc[(table.rater == rating), col],
                    norm_hist=False,
                    label=label,
                    ax=axis,
                    color=color,
                )
            # Only the left panel carries a legend
            if axis is ax_nzs:
                axis.legend()

        # Clip the x range to the 0.2-99.8 percentiles of all the samples
        for table, axis in ((mdata, ax_nzs), (zscored, ax_zsd)):
            samples = table[[col]].values.ravel().tolist()
            lower = np.percentile(samples, 0.2)
            upper = np.percentile(samples, 99.8)
            axis.set_xlim([lower, upper])

    out_file = "histograms.svg" if out_file is None else out_file

    fname, ext = op.splitext(out_file)
    if ext[1:] not in ["pdf", "svg", "png"]:
        ext = ".svg"
        out_file = fname + ".svg"

    fig.savefig(out_file, format=ext[1:], bbox_inches="tight", pad_inches=0, dpi=100)
    return fig


def inter_rater_variability(
    y1, y2, figsize=(4, 4), normed=True, raters=None, labels=None, out_file=None
):
    """
    Plot the agreement between two raters as an annotated 2D histogram.

    The ratings of the first rater are negated before plotting so that the x
    axis reads in reverse order with respect to the y axis.

    Note that this function changes the global ``plt.rcParams`` font settings.

    Parameters
    ----------
    y1, y2 : array-like
        Ratings given by the first and the second rater.
    figsize : tuple
        Size of the figure.
    normed : bool
        Annotate the cells with percentages of the total instead of counts.
    raters : list of str, optional
        Axis labels; defaults to ``["Rater 1", "Rater 2"]``.
    labels : list of str, optional
        Names of the ratings in increasing order; defaults to
        ``["exclude", "doubtful", "accept"]``. When only two distinct ratings
        are found, the first and the last names are used.
    out_file : str, optional
        If given, the figure is also saved to this path.

    Returns
    -------
    fig : matplotlib.figure.Figure

    """
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = "FreeSans"
    plt.rcParams["font.size"] = 25
    plt.rcParams["axes.labelsize"] = 20
    plt.rcParams["axes.titlesize"] = 25
    plt.rcParams["xtick.labelsize"] = 15
    plt.rcParams["ytick.labelsize"] = 15

    if raters is None:
        raters = ["Rater 1", "Rater 2"]

    if labels is None:
        labels = ["exclude", "doubtful", "accept"]

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_aspect("equal")

    y1 = np.asarray(y1).ravel()
    y2 = np.asarray(y2).ravel()

    # Count distinct ratings over BOTH raters; a set union works for lists
    # and arrays alike (``y1 + y2`` would add arrays element-wise).
    nbins = len(set(y1.tolist()) | set(y2.tolist()))

    # Binary ratings only use the two extreme labels
    if nbins == 2:
        ylabels = [labels[0], labels[-1]]
    else:
        ylabels = list(labels)

    # Reverse x
    y1 = y1 * -1
    xlabels = list(reversed(ylabels))

    hist, xbins, ybins, _ = plt.hist2d(y1, y2, bins=nbins, cmap=plt.cm.viridis)
    xcenters = (xbins[:-1] + xbins[1:]) * 0.5
    ycenters = (ybins[:-1] + ybins[1:]) * 0.5

    total = np.sum(hist.reshape(-1))
    celfmt = "%d%%" if normed else "%d"
    for i, x in enumerate(xcenters):
        for j, y in enumerate(ycenters):
            count = hist[i, j]
            val = 100 * count / total if normed else count

            ax.text(
                x,
                y,
                celfmt % val,
                ha="center",
                va="center",
                fontweight="bold",
                # White text on the cells with few samples, black otherwise
                color="w" if count < 15 else "k",
            )

    plt.grid(False)
    plt.xticks(xcenters, xlabels)
    plt.yticks(ycenters, ylabels, rotation="vertical", va="center")
    plt.xlabel(raters[0])
    plt.ylabel(raters[1])
    ax.yaxis.tick_right()
    ax.xaxis.set_label_position("top")

    if out_file is not None:
        fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300)

    return fig


def plot_artifact(
    image_path,
    figsize=(20, 20),
    vmax=None,
    cut_coords=None,
    display_mode="ortho",
    size=None,
):
    """
    Plot an anatomical image with "L" and "R" markers on the top corners.

    Parameters
    ----------
    image_path
        Image forwarded to ``nilearn.plotting.plot_anat``.
    figsize : tuple
        Size of the figure.
    vmax, cut_coords, display_mode
        Forwarded to ``nilearn.plotting.plot_anat``.
    size : float, optional
        Font size of the markers; defaults to six times the figure width.

    Returns
    -------
    nplt_disp
        The object returned by ``plot_anat``.
    ax : matplotlib.axes.Axes
        The axes the markers were written on.

    """
    import nilearn.plotting as nplt

    fig = plt.figure(figsize=figsize)
    nplt_disp = nplt.plot_anat(
        image_path,
        display_mode=display_mode,
        cut_coords=cut_coords,
        vmax=vmax,
        figure=fig,
        annotate=False,
    )

    marker_size = figsize[0] * 6 if size is None else size

    bg_color, fg_color = "k", "w"
    ax = fig.gca()

    # (x position, letter, horizontal alignment, background box)
    markers = (
        (
            0.1,
            "L",
            "left",
            dict(boxstyle="square,pad=0", ec=bg_color, fc=bg_color, alpha=1),
        ),
        (
            0.9,
            "R",
            "right",
            dict(boxstyle="square,pad=0", ec=bg_color, fc=bg_color),
        ),
    )
    for xpos, letter, halign, box in markers:
        ax.text(
            xpos,
            0.95,
            letter,
            transform=ax.transAxes,
            horizontalalignment=halign,
            verticalalignment="top",
            size=marker_size,
            bbox=box,
            color=fg_color,
        )

    return nplt_disp, ax


def figure1_a(
    image_path, display_mode="y", vmax=300, cut_coords=None, figsize=(20, 20)
):
    """
    Plot one slice of an image and overlay two red arrows on it.

    Parameters
    ----------
    image_path
        Image forwarded to :func:`plot_artifact`.
    display_mode, vmax, figsize
        Forwarded to :func:`plot_artifact`.
    cut_coords : list, optional
        Cut coordinates; defaults to ``[15]``.

    Returns
    -------
    disp
        The display object returned by :func:`plot_artifact`.

    """
    import matplotlib.patches as patches

    coords = [15] if cut_coords is None else cut_coords

    disp, ax = plot_artifact(
        image_path,
        display_mode=display_mode,
        vmax=vmax,
        cut_coords=coords,
        figsize=figsize,
    )

    # (x, y, dx, dy) of each arrow, in axes coordinates
    arrows = (
        (0.2, 0.2, 0.1, 0.6),
        (0.8, 0.2, -0.1, 0.6),
    )
    for x, y, dx, dy in arrows:
        arrow = patches.Arrow(
            x, y, dx, dy, width=0.25, color="tomato", transform=ax.transAxes
        )
        ax.add_patch(arrow)

    return disp


def figure1_b(
    image_path, display_mode="z", vmax=400, cut_coords=None, figsize=(20, 20)
):
    """
    Plot one slice of an image and overlay two red and two green arrows.

    Parameters
    ----------
    image_path
        Image forwarded to :func:`plot_artifact`.
    display_mode, vmax, figsize
        Forwarded to :func:`plot_artifact`.
    cut_coords : list, optional
        Cut coordinates; defaults to ``[-24]``.

    Returns
    -------
    disp
        The display object returned by :func:`plot_artifact`.

    """
    import matplotlib.patches as patches

    coords = [-24] if cut_coords is None else cut_coords

    disp, ax = plot_artifact(
        image_path,
        display_mode=display_mode,
        vmax=vmax,
        cut_coords=coords,
        figsize=figsize,
    )

    # (x, y, dx, color) of each horizontal arrow, in axes coordinates
    arrows = (
        (0.02, 0.55, 0.1, "tomato"),
        (0.98, 0.55, -0.1, "tomato"),
        (0.02, 0.80, 0.1, "limegreen"),
        (0.98, 0.80, -0.1, "limegreen"),
    )
    for x, y, dx, color in arrows:
        arrow = patches.Arrow(
            x, y, dx, 0.0, width=0.10, color=color, transform=ax.transAxes
        )
        ax.add_patch(arrow)

    return disp


def figure1(artifact1, artifact2, out_file):
    """
    Build both panels of the figure and save them combined into one file.

    Parameters
    ----------
    artifact1
        Image rendered with :func:`figure1_a`.
    artifact2
        Image rendered with :func:`figure1_b`.
    out_file : str
        Path handed to the ``save`` method of the combined figure.

    """
    from .svg import svg2str, combine_svg

    # The figure1_b panel is rendered, and listed, first
    first_panel = svg2str(figure1_b(artifact2))
    second_panel = svg2str(figure1_a(artifact1))

    combined = combine_svg([first_panel, second_panel], axis="vertical")
    combined.save(out_file)