from __future__ import print_function, division import os import json import argparse from textwrap import fill import matplotlib.pyplot as plt from matplotlib import cm, colors import seaborn as sns import numpy as np from mpl_toolkits.mplot3d import Axes3D from sklearn.decomposition import PCA from utils import parseDataFolder, getInputBuiltin, loadData # Init seaborn sns.set() INTERACTIVE_PLOT = True TITLE_MAX_LENGTH = 50 def updateDisplayMode(): """ Enable or disable interactive plot see: http://matplotlib.org/faq/usage_faq.html#what-is-interactive-mode """ if INTERACTIVE_PLOT: plt.ion() else: plt.ioff() def pauseOrClose(fig): """ :param fig: (matplotlib figure object) """ if INTERACTIVE_PLOT: plt.draw() plt.pause(0.0001) # Small pause to update the plot else: plt.close(fig) def plotRepresentation(states, rewards, name="Learned State Representation", add_colorbar=True, path=None, fit_pca=False, cmap='coolwarm', true_states=None): """ Plot learned state representation using rewards for coloring :param states: (np.ndarray) :param rewards: (numpy 1D array) :param name: (str) :param add_colorbar: (bool) :param path: (str) :param fit_pca: (bool) :param cmap: (str) :param true_states: project a 1D predicted states onto the ground_truth """ state_dim = states.shape[1] if state_dim != 1 and (fit_pca or state_dim > 3): name += " (PCA)" n_components = min(state_dim, 3) print("Fitting PCA with {} components".format(n_components)) states = PCA(n_components=n_components).fit_transform(states) if state_dim == 1: # Extend states as 2D: states_matrix = np.zeros((states.shape[0], 2)) states_matrix[:, 0] = states[:, 0] plot2dRepresentation(states_matrix, rewards, name, add_colorbar, path, cmap, true_states=true_states) elif state_dim == 2: plot2dRepresentation(states, rewards, name, add_colorbar, path, cmap) else: plot3dRepresentation(states, rewards, name, add_colorbar, path, cmap) def plot2dRepresentation(states, rewards, name="Learned State Representation", add_colorbar=True, path=None, cmap='coolwarm', true_states=None): updateDisplayMode() fig = plt.figure(name) plt.clf() if true_states is not None: plt.scatter(true_states[:len(states), 0], true_states[:len(states), 1], s=7, c=states[:, 0], cmap=cmap, linewidths=0.1) else: plt.scatter(states[:, 0], states[:, 1], s=7, c=rewards, cmap=cmap, linewidths=0.1) plt.xlabel('State dimension 1') plt.ylabel('State dimension 2') plt.title(fill(name, TITLE_MAX_LENGTH)) fig.tight_layout() if add_colorbar: plt.colorbar(label='Reward') if path is not None: plt.savefig(path) pauseOrClose(fig) def plot3dRepresentation(states, rewards, name="Learned State Representation", add_colorbar=True, path=None, cmap='coolwarm'): updateDisplayMode() fig = plt.figure(name) plt.clf() ax = fig.add_subplot(111, projection='3d') im = ax.scatter(states[:, 0], states[:, 1], states[:, 2], s=7, c=rewards, cmap=cmap, linewidths=0.1) ax.set_xlabel('State dimension 1') ax.set_ylabel('State dimension 2') ax.set_zlabel('State dimension 3') ax.set_title(fill(name, TITLE_MAX_LENGTH)) fig.tight_layout() if add_colorbar: fig.colorbar(im, label='Reward') if path is not None: plt.savefig(path) pauseOrClose(fig) def plotImage(image, name='Observation Sample'): """ Display an image :param image: (np.ndarray) (with values in [0, 1]) :param name: (str) """ # Reorder channels if image.shape[0] == 3 and len(image.shape) == 3: # (n_channels, height, width) -> (width, height, n_channels) image = np.transpose(image, (2, 1, 0)) updateDisplayMode() fig = plt.figure(name) plt.imshow(image, interpolation='nearest') # plt.gca().invert_yaxis() plt.xticks([]) plt.yticks([]) pauseOrClose(fig) def colorPerEpisode(episode_starts): """ :param episode_starts: (numpy 1D array) :return: (numpy 1D array) """ colors = np.zeros(len(episode_starts)) color_idx = -1 print(np.sum(episode_starts)) for i in range(len(episode_starts)): # New episode if episode_starts[i] == 1: color_idx += 1 colors[i] = color_idx return colors def prettyPlotAgainst(states, rewards, title="Representation", fit_pca=False, cmap='coolwarm'): """ State dimensions are plotted one against the other (it creates a matrix of 2d representation) using rewards for coloring, the diagonal is a distribution plot, and the scatter plots have a density outline. :param states: (np.ndarray) :param rewards: (np.ndarray) :param title: (str) :param fit_pca: (bool) :param cmap: (str) """ with sns.axes_style('white'): n = states.shape[1] fig, ax_mat = plt.subplots(n, n, figsize=(10, 10), sharex=False, sharey=False) fig.subplots_adjust(hspace=0.2, wspace=0.2) if fit_pca: title += " (PCA)" states = PCA(n_components=n).fit_transform(states) c_idx = cm.get_cmap(cmap) norm = colors.Normalize(vmin=np.min(rewards), vmax=np.max(rewards)) for i in range(n): for j in range(n): x, y = states[:, i], states[:, j] ax = ax_mat[i, j] if i != j: ax.scatter(x, y, c=rewards, cmap=cmap, s=5) sns.kdeplot(x, y, cmap="Greys", ax=ax, shade=True, shade_lowest=False, alpha=0.2) ax.set_xlim([np.min(x), np.max(x)]) ax.set_ylim([np.min(y), np.max(y)]) else: if len(np.unique(rewards)) < 10: for r in np.unique(rewards): sns.distplot(x[rewards == r], color=c_idx(norm(r)), ax=ax) else: sns.distplot(x, ax=ax) if i == 0: ax.set_title("Dim {}".format(j), y=1.2) if i != j: # Hide ticks if i != 0 and i != n - 1: ax.xaxis.set_visible(False) if j != 0 and j != n - 1: ax.yaxis.set_visible(False) # Set up ticks only on one side for the "edge" subplots... if j == 0: ax.yaxis.set_ticks_position('left') if j == n - 1: ax.yaxis.set_ticks_position('right') if i == 0: ax.xaxis.set_ticks_position('top') if i == n - 1: ax.xaxis.set_ticks_position('bottom') plt.suptitle(title, fontsize=16) plt.show() def plotAgainst(states, rewards, title="Representation", fit_pca=False, cmap='coolwarm'): """ State dimensions are plotted one against the other (it creates a matrix of 2d representation) using rewards for coloring :param states: (np.ndarray) :param rewards: (np.ndarray) :param title: (str) :param fit_pca: (bool) :param cmap: (str) """ n = states.shape[1] fig, ax_mat = plt.subplots(n, n, figsize=(10, 10), sharex=False, sharey=False) fig.subplots_adjust(hspace=0.0, wspace=0.0) if fit_pca: title += " (PCA)" states = PCA(n_components=n).fit_transform(states) for i in range(n): for j in range(n): x, y = states[:, i], states[:, j] ax = ax_mat[i, j] ax.scatter(x, y, c=rewards, cmap=cmap, s=5) ax.set_xlim([np.min(x), np.max(x)]) ax.set_ylim([np.min(y), np.max(y)]) # Hide ticks if i != 0 and i != n - 1: ax.xaxis.set_visible(False) if j != 0 and j != n - 1: ax.yaxis.set_visible(False) # Set up ticks only on one side for the "edge" subplots... if j == 0: ax.yaxis.set_ticks_position('left') if j == n - 1: ax.yaxis.set_ticks_position('right') if i == 0: ax.set_title("Dim {}".format(j), y=1.2) ax.xaxis.set_ticks_position('top') if i == n - 1: ax.xaxis.set_ticks_position('bottom') plt.suptitle(title, fontsize=16) plt.show() def plotCorrelation(states_rewards, ground_truth, target_positions, only_print=False): """ Correlation matrix: Target pos/ground truth states vs. States predicted :param states_rewards: (numpy dict) :param ground_truth: (numpy dict) :param target_positions: (np.ndarray) :param only_print: (bool) only print the correlation mesurements (max of correlation for each of Ground Truth's dimension) :return: returns the max correlation for each of Ground Truth's dimension with the predicted states as well as its mean """ np.set_printoptions(precision=2) correlation_max_vector = np.array([]) for index, ground_truth_name in enumerate([" Agent's position ", "Target Position"]): if ground_truth_name == " Agent's position ": key = 'ground_truth_states' if 'ground_truth_states' in ground_truth.keys() else 'arm_states' x = ground_truth[key][:len(rewards)] else: x = target_positions[:len(rewards)] # adding epsilon in case of little variance in samples of X & Ys eps = 1e-12 corr = np.corrcoef(x=x + eps, y=states_rewards['states'] + eps, rowvar=0) fig = plt.figure(figsize=(8, 6)) ax = fig.add_subplot(111) labels = [r'$\tilde{s}_' + str(i_) + '$' for i_ in range(x.shape[1])] labels += [r'$s_' + str(i_) + '$' for i_ in range(states_rewards['states'].shape[1])] cax = ax.matshow(corr, cmap=cmap, vmin=-1, vmax=1) ax.set_xticklabels([''] + labels) ax.set_yticklabels([''] + labels) ax.grid(False) plt.title(r'Correlation Matrix: S = Predicted states | $\tilde{S}$ = ' + ground_truth_name) fig.colorbar(cax, label='correlation coefficient') # Building the vector of max correlation ( a scalar for each of the Ground Truth's dimension) ground_truth_dim = x.shape[1] corr_copy = corr for idx in range(ground_truth_dim): corr_copy[idx, idx] = 0.0 correlation_max_vector = np.append(correlation_max_vector, max(abs(corr_copy[idx]))) # Printing the max correlation for each of Ground Truth's dimension with the predicted states # as well as the mean correlation_scalar = sum(correlation_max_vector) print("\nCorrelation value of the model's prediction with the Ground Truth:\n Max correlation vector (GTC): {}" "\n Mean : {:.2f}".format(correlation_max_vector, correlation_scalar / len(correlation_max_vector))) if not only_print: pauseOrClose(fig) return correlation_max_vector, correlation_scalar / len(correlation_max_vector) if __name__ == '__main__': parser = argparse.ArgumentParser(description='Plotting script for representation') parser.add_argument('-i', '--input-file', type=str, default="", help='Path to a npz file containing states and rewards') parser.add_argument('--data-folder', type=str, default="", help='Path to a dataset folder, it will plot ground truth states') parser.add_argument('--color-episode', action='store_true', default=False, help='Color states per episodes instead of reward') parser.add_argument('--plot-against', action='store_true', default=False, help='Plot against each dimension') parser.add_argument('--pretty-plot-against', action='store_true', default=False, help='Plot against each dimension (diagonals are distributions + cleaner look)') parser.add_argument('--correlation', action='store_true', default=False, help='Plot correlation coeff against each dimension') parser.add_argument('--projection', action='store_true', default=False, help='Plot 1D projection of predicted state on ground truth') parser.add_argument('--print-corr', action='store_true', default=False, help='Only print correlation measurements') args = parser.parse_args() cmap = "tab20" if args.color_episode else "coolwarm" assert not (args.color_episode and args.data_folder == ""), \ "You must specify a datafolder when using per-episode color" assert not (args.correlation and args.data_folder == ""), \ "You must specify a datafolder when using the correlation plot" # Force correlation plotting when `--print-cor` is passed if args.print_corr: args.correlation = True args.data_folder = parseDataFolder(args.data_folder) if args.input_file != "": print("Loading {}...".format(args.input_file)) states_rewards = np.load(args.input_file) rewards = states_rewards['rewards'] if args.color_episode: episode_starts = np.load('data/{}/preprocessed_data.npz'.format(args.data_folder))['episode_starts'] rewards = colorPerEpisode(episode_starts)[:len(rewards)] if args.plot_against: print("Plotting against") plotAgainst(states_rewards['states'], rewards, cmap=cmap) elif args.pretty_plot_against: print("Pretty plotting against") prettyPlotAgainst(states_rewards['states'], rewards, cmap=cmap) elif args.projection: training_data, ground_truth, true_states, _ = loadData(args.data_folder) plotRepresentation(states_rewards['states'], rewards, cmap=cmap, true_states=true_states) elif args.correlation: training_data, ground_truth, true_states, target_positions = loadData(args.data_folder) if args.color_episode: rewards = colorPerEpisode(training_data['episode_starts']) # Compute Ground Truth Correlation gt_corr, gt_corr_mean = plotCorrelation(states_rewards, ground_truth, target_positions, only_print=args.print_corr) result_dict = { 'gt_corr': gt_corr.tolist(), 'gt_corr_mean': gt_corr_mean } # Write the results in a json file log_folder = os.path.dirname(args.input_file) with open("{}/gt_correlation.json".format(log_folder), 'w') as f: json.dump(result_dict, f) else: plotRepresentation(states_rewards['states'], rewards, cmap=cmap) if not args.print_corr: getInputBuiltin()('\nPress any key to exit.') elif args.data_folder != "": print("Plotting ground truth...") training_data, ground_truth, true_states, _ = loadData(args.data_folder) rewards = training_data['rewards'] name = "Ground Truth States - {}".format(args.data_folder) if args.color_episode: rewards = colorPerEpisode(training_data['episode_starts']) if args.plot_against: plotAgainst(true_states, rewards, cmap=cmap) elif args.pretty_plot_against: prettyPlotAgainst(true_states, rewards, cmap=cmap) else: plotRepresentation(true_states, rewards, name, fit_pca=False, cmap=cmap) getInputBuiltin()('\nPress any key to exit.') else: print("You must specify one of --input-file or --data-folder")