""" script for generating the loss plots from the Loss logs """ import argparse import matplotlib.pyplot as plt def read_loss_log(file_name, delimiter='\t'): """ read and load the loss values from a loss.log file :param file_name: path of the loss.log file :param delimiter: delimiter used to delimit the two columns :return: loss_val => numpy array [Iterations x 2] """ from numpy import genfromtxt losses = genfromtxt(file_name, delimiter=delimiter) return losses def plot_loss(*loss_vals, plot_name="Loss plot", fig_size=(17, 7), save_path=None, legends=("discriminator", "generator")): """ plot the discriminator loss values and save the plot if required :param loss_vals: (Variable Arg) numpy array or Sequence like for plotting values :param plot_name: Name of the plot :param fig_size: size of the generated figure (column_width, row_width) :param save_path: path to save the figure :param legends: list containing labels for loss plots' legends len(legends) == len(loss_vals) :return: """ assert len(loss_vals) == len(legends), "Not enough labels for legends" plt.figure(figsize=fig_size).suptitle(plot_name) plt.grid(True, which="both") plt.ylabel("loss value") plt.xlabel("spaced iterations") plt.axhline(y=0, color='k') plt.axvline(x=0, color='k') # plot all the provided loss values in a single plot plts = [] for loss_val in loss_vals: plts.append(plt.plot(loss_val)[0]) plt.legend(plts, legends, loc="upper right", fontsize=16) if save_path is not None: plt.savefig(save_path) def parse_arguments(): """ command line arguments parser :return: args => parsed command line arguments """ parser = argparse.ArgumentParser() parser.add_argument("--loss_file", action="store", type=str, default=None, help="path to loss log file") parser.add_argument("--plot_file", action="store", type=str, default=".", help="path to the file where plots are to be saved") args = parser.parse_args() return args def main(args): """ Main function for the script :param args: parsed command line arguments :return: None """ # Make sure input logs directory is provided assert args.loss_file is not None, "Loss-Log file not specified" # read the loss file loss_vals = read_loss_log(args.loss_file) # plot the loss: plot_loss(loss_vals[:, 1], loss_vals[:, 2], save_path=args.plot_file) print("Loss plots have been successfully generated ...") print("Please check: ", args.plot_file) if __name__ == '__main__': main(parse_arguments())