"Implements plot_sweep1d function"
import matplotlib.pyplot as plt
from ..exceptions import InvalidGPConstraint


def assign_axes(var, posys, axes):
    """Assigns axes to posys, creating and formatting them if necessary.

    Arguments
    ---------
    var : the swept variable; only used (for the x-axis label) when new
        axes are created here.
    posys : a single quantity or a sequence of quantities. A string is
        treated as a single quantity, not iterated character by character.
    axes : None to create a new figure with one row per posy; otherwise a
        single Axes (when there is exactly one posy) or a sequence of Axes.

    Returns
    -------
    (posys, axes) : posys as a sequence, and axes wrapped in a list when a
        lone Axes was created or passed in for a single posy.
    """
    # In Python 3 str defines __iter__, but a string names ONE quantity.
    if isinstance(posys, str) or not hasattr(posys, "__iter__"):
        posys = [posys]
    N = len(posys)
    if axes is None:
        _, axes = plt.subplots(N, 1, sharex="col", figsize=(4.5, 3+1.5*N))
        if N == 1:
            axes = [axes]  # a 1x1 plt.subplots grid returns a bare Axes
        format_and_label_axes(var, posys, axes)
    elif N == 1 and not hasattr(axes, "__len__"):
        axes = [axes]
    return posys, axes


def _axis_label(quantity):
    "Returns '<label or name> [<units>]' for keyed quantities, else str()"
    if hasattr(quantity, "key"):
        key = quantity.key
        return (key.descr.get("label", key.name)
                + " [%s]" % key.unitstr(dimless="-"))
    return str(quantity)


def format_and_label_axes(var, posys, axes, ylabel=True):
    """Formats and labels a column of axes sharing the x-variable `var`.

    Each Axes in `axes` is paired with the matching entry of `posys`
    (extra entries of the longer sequence are ignored). The x-axis label
    is applied only to the last paired Axes. Does nothing if there are no
    pairs.

    Arguments
    ---------
    var : the swept variable. If it has a `.key`, its label (or name) and
        units form the x-axis label; otherwise str(var) is used.
    posys : sequence of plotted quantities, labelled the same way as `var`.
    axes : sequence of matplotlib Axes, one per posy.
    ylabel : bool
        If False, y-axis labels are left unset (e.g. for columns after
        the first in a shared-y grid).
    """
    last_ax = None
    for posy, ax in zip(posys, axes):
        if ylabel:
            ax.set_ylabel(_axis_label(posy))
        ax.grid(color="0.6")
        for label in (ax.xaxis.label, ax.yaxis.label):
            label.set_fontsize(12)
        ax.tick_params(length=0, labelsize=9)
        ax.spines['left'].set_visible(False)
        ax.spines['top'].set_visible(False)
        for spine in ax.spines.values():
            spine.set_linewidth(0.6)
            spine.set_color("0.6")
            spine.set_linestyle("dotted")
        # per-axes: plt.locator_params would only touch the current axes
        ax.locator_params(nbins=4)
        last_ax = ax
    if last_ax is None:
        return  # nothing to format (empty posys or axes)
    last_ax.set_xlabel(_axis_label(var))
    # adjust the figure these axes belong to, not whatever figure is current
    last_ax.figure.subplots_adjust(wspace=0.15)


# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def plot_1dsweepgrid(model, sweeps, posys, origsol=None, tol=0.01, **solveargs):
    """Creates and plots a grid of 1-D sweeps from an existing model.

    One column of subplots is drawn per entry of `sweeps` and one row per
    entry of `posys`. For swept variables that have a substitution in the
    model, the original solution is marked with a black dot.

    Example usage:
    f, _ = plot_1dsweepgrid(m, {'x': np.linspace(1, 2, 5)}, 'y')
    f.savefig('mysweep.png')

    Arguments
    ---------
    model : the model to sweep; its substitutions for swept variables are
        restored after each sweep.
    sweeps : dict mapping each swept variable to either a 2-tuple of
        bounds (passed to model.autosweep with `tol`) or the values passed
        to model.sweep.
    posys : a quantity or sequence of quantities to plot, one row each.
        A string is treated as a single quantity.
    origsol : optional solution at the original substitutions. If None and
        any swept variable is substituted, the model is solved (falling
        back to localsolve on InvalidGPConstraint).
    tol : tolerance passed to model.autosweep.
    **solveargs : passed to solve/localsolve/sweep/autosweep.

    Returns
    -------
    f : the matplotlib Figure
    axes : the Axes object(s) exactly as returned by plt.subplots
    """
    origsubs = {swept: model.substitutions[swept] for swept in sweeps
                if swept in model.substitutions}
    if origsubs and origsol is None:
        try:
            origsol = model.solve(**solveargs)
        except InvalidGPConstraint:
            origsol = model.localsolve(**solveargs)
    # In Python 3 str defines __iter__, but a string names ONE quantity.
    if isinstance(posys, str) or not hasattr(posys, "__iter__"):
        posys = [posys]

    N, S = len(posys), len(sweeps)
    f, axes = plt.subplots(N, S, sharex='col', sharey='row',
                           figsize=(4+2*S, 4+2*N))
    f.subplots_adjust(hspace=0.15)

    for i, (swept, swept_over) in enumerate(sweeps.items()):
        try:
            if isinstance(swept_over, tuple) and len(swept_over) == 2:
                sol = model.autosweep({swept: swept_over}, tol=tol,
                                      **solveargs)
            else:
                sol = model.sweep({swept: swept_over}, **solveargs)
        finally:
            # restore original substitutions even if the sweep fails
            model.substitutions.update(origsubs)

        # plt.subplots squeezes singleton dimensions; pick column i
        if S == 1:
            subaxes = [axes] if N == 1 else axes
        elif N == 1:
            subaxes = [axes[i]]
        else:
            subaxes = axes[:, i]

        sol.plot(posys, subaxes)
        # only variables that had a substitution have an original point
        if swept in origsubs:
            for posy, ax in zip(posys, subaxes):
                ax.plot(origsubs[swept], origsol(posy), "ko", markersize=4)
        format_and_label_axes(swept, posys, subaxes, ylabel=(i == 0))

    return f, axes