import sys

import argparse
import json
import matplotlib
import os
import pickle

from linguistic_style_transfer_model.config import global_config
from linguistic_style_transfer_model.utils import log_initializer

logger = None
colors = ['b', 'r', 'g', 'c', 'm', 'y', 'k']
plot_markers = ['x', '+']


def plot_coordinates(coordinates, plot_path, markers, label_names, fig_num):
    matplotlib.use('svg')
    import matplotlib.pyplot as plt

    plt.figure(fig_num)
    for i in range(len(markers) - 1):
        plt.scatter(x=coordinates[markers[i]:markers[i + 1], 0],
                    y=coordinates[markers[i]:markers[i + 1], 1],
                    marker=plot_markers[i % len(plot_markers)],
                    c=colors[i % len(colors)],
                    label=label_names[i], alpha=0.75)

    plt.legend(loc='upper right', fontsize='x-large')
    plt.axis('off')
    plt.savefig(fname=plot_path, format="svg", bbox_inches='tight', transparent=True)
    plt.close()


def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument("--saved-model-path", type=str)

    global logger
    logger = log_initializer.setup_custom_logger(global_config.logger_name, "INFO")

    args = vars(parser.parse_args(args=argv))
    logger.info(args)

    with open(os.path.join(args["saved_model_path"],
                           global_config.index_to_label_dict_file), 'r') as file:
        label_names = json.load(file)
    logger.info("label_names: {}".format(label_names))

    with open(os.path.join(args["saved_model_path"],
                           global_config.style_coordinates_file), 'rb') as pickle_file:
        (style_coordinates, markers) = pickle.load(pickle_file)
        plot_coordinates(style_coordinates,
                         os.path.join(args["saved_model_path"],
                                      global_config.style_embedding_plot_file),
                         markers, label_names, 0)

    with open(os.path.join(args["saved_model_path"],
                           global_config.content_coordinates_file), 'rb') as pickle_file:
        (content_coordinates, markers) = pickle.load(pickle_file)
        plot_coordinates(content_coordinates,
                         os.path.join(args["saved_model_path"],
                                      global_config.content_embedding_plot_file),
                         markers, label_names, 1)


if __name__ == "__main__":
    main(sys.argv[1:])