import matplotlib
matplotlib.use('agg')

import matplotlib.pyplot as plt


################################################################################
# Graphing Functions #
################################################################################

def plot_learn(values, yAxis, xAxis, title=None, saveTo=None):
    """
    Plots the learning curve with train/val for all values and saves the
    graph to disk. Only 7 distinct colours are available, so colours are
    reused (cycled) when more than 7 models are plotted on the same graph.

    Args:
        1. values: Dictionary of tuples of lists, where the tuple is ([train values], [dev values])
            and the key is the name of the model for the graph label. The dev
            curve is skipped for a model when its dev values are empty/falsy.
        2. yAxis: Either 'Loss', 'F1' or 'Exact Match'
        3. xAxis: 'Epochs' or 'Iterations'
        4. title: optional title to the graph
        5. saveTo: save location for graph (required despite the None default)

    Raises:
        ValueError: if saveTo is None or empty.
    """
    # Validate before drawing anything: an explicit raise (unlike `assert`)
    # survives `python -O`, and failing early avoids leaving half-drawn
    # curves on the global pyplot figure.
    if not saveTo:
        raise ValueError("saveTo must be a non-empty path for the output graph")

    colours = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

    try:
        for i, (k, (train_values, dev_values)) in enumerate(values.items()):
            # Cycle through the palette instead of raising IndexError when
            # there are more curves than colours.
            colour = colours[i % len(colours)]

            # Build real lists: on Python 3 `map` returns a lazy iterator,
            # which matplotlib cannot interpret as plottable data.
            plt.plot([float(v) for v in train_values], linewidth=2,
                     color=colour, linestyle='--',
                     label="Train {} {}".format(yAxis, k))
            if dev_values:
                plt.plot([float(v) for v in dev_values], linewidth=2,
                         color=colour, linestyle='-',
                         label="Dev {} {}".format(yAxis, k))

        plt.xlabel(xAxis)
        plt.ylabel(yAxis)
        if title:
            plt.title(title)

        # Loss graphs get the legend on the right, everything else on the
        # left (presumably to keep it clear of falling/rising curves).
        if yAxis == "Loss":
            plt.legend(loc='upper right', shadow=True, prop={'size': 6})
        else:
            plt.legend(loc='upper left', shadow=True, prop={'size': 6})

        plt.savefig("{}".format(saveTo))
    finally:
        # Always reset pyplot's global state, even if plotting or saving
        # failed, so the next graph starts from a clean figure.
        plt.cla()
        plt.clf()
        plt.close()

def plot_metrics(values, yAxis, xAxis, title=None, saveTo=None):
    """
    Plots train/dev curves for several metrics on one graph and saves the
    graph to disk. Only 7 distinct colours are available, so colours are
    reused (cycled) when more than 7 metrics are plotted.

    Args:
        1. values: Iterable of (train values, dev values, metric) triples, where
            metric is the name used in the legend labels. The dev curve is
            skipped for a metric when its dev values are empty/falsy.
        2. yAxis: label for the y axis; 'Loss' places the legend upper right,
            any other value places it upper left
        3. xAxis: label for the x axis
        4. title: optional title to the graph
        5. saveTo: save location for graph (required despite the None default)

    Raises:
        ValueError: if saveTo is None or empty.
    """
    # Validate before drawing anything: an explicit raise (unlike `assert`)
    # survives `python -O`, and failing early avoids leaving half-drawn
    # curves on the global pyplot figure.
    if not saveTo:
        raise ValueError("saveTo must be a non-empty path for the output graph")

    colours = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

    try:
        for i, (train_values, dev_values, metric) in enumerate(values):
            # Cycle through the palette instead of raising IndexError when
            # there are more curves than colours.
            colour = colours[i % len(colours)]

            # Build real lists: on Python 3 `map` returns a lazy iterator,
            # which matplotlib cannot interpret as plottable data.
            plt.plot([float(v) for v in train_values], linewidth=2,
                     color=colour, linestyle='-',
                     label="Train {}".format(metric))
            if dev_values:
                plt.plot([float(v) for v in dev_values], linewidth=2,
                         color=colour, linestyle='--',
                         label="Dev {}".format(metric))

        plt.xlabel(xAxis)
        plt.ylabel(yAxis)
        if title:
            plt.title(title)

        if yAxis == "Loss":
            plt.legend(loc='upper right', shadow=True, prop={'size': 6})
        else:
            plt.legend(loc='upper left', shadow=True, prop={'size': 6})

        plt.savefig("{}".format(saveTo))
    finally:
        # Always reset pyplot's global state, even if plotting or saving
        # failed, so the next graph starts from a clean figure.
        plt.cla()
        plt.clf()
        plt.close()