"""Plotting helpers for trajectories, mixture weights and image sequences.

Every public ``plot_*`` function renders a matplotlib figure and writes it to
``filename``. The non-interactive 'Agg' backend is selected below, so nothing
is shown on screen.
"""
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
sns.set_style("whitegrid", {'axes.grid': False})
import pandas as pd
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection
from matplotlib import cm
from kvae.utils.movie import movie_to_frame
import mpl_toolkits.mplot3d.art3d as art3d
from scipy.spatial.distance import hamming

matplotlib.rcParams['xtick.labelsize'] = 14
matplotlib.rcParams['ytick.labelsize'] = 14


def plot_auxiliary(all_vars, filename, table_size=4):
    """Plot one or more batches of trajectories and save the figure.

    For 2-D data a ``table_size`` x ``table_size`` grid is drawn, one batch
    element per panel. Every array in ``all_vars`` is overlaid in each panel,
    and the first point of each trajectory is marked in red. For any other
    dimension, a seaborn pair plot of all time steps is drawn instead,
    coloured by the position of the array in ``all_vars``.

    :param all_vars: sequence of arrays shaped (batch_size, sequence_length, dim);
        2-D arrays (sequence_length, dim) are treated as a batch of one
    :param filename: path the figure is written to
    :param table_size: number of grid rows/columns (2-D case only)
    :return: None
    """
    # Build a new list so the caller's list is not modified in place.
    all_vars = [np.expand_dims(a, 0) if a.ndim == 2 else a for a in all_vars]
    dim = all_vars[0].shape[-1]

    if dim == 2:
        # squeeze=False keeps ``ax`` 2-D even when table_size == 1.
        fig, ax = plt.subplots(table_size, table_size, sharex='col', sharey='row',
                               figsize=[12, 12], squeeze=False)
        idx = 0
        for x in range(table_size):
            for y in range(table_size):
                for a in all_vars:
                    # Leave the panel empty once this array's batch is exhausted.
                    if idx >= a.shape[0]:
                        continue
                    ax[x, y].plot(a[idx, :, 0], a[idx, :, 1], linestyle='-', marker='o', markersize=3)
                    # Plot starting point of the trajectory
                    ax[x, y].plot(a[idx, 0, 0], a[idx, 0, 1], 'r.', ms=12)
                idx += 1
        plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
        plt.close(fig)
    else:
        df_list = []
        for i, a in enumerate(all_vars):
            df = pd.DataFrame(a.reshape(-1, dim))
            df['class'] = i
            df_list.append(df)
        # ignore_index avoids duplicate row labels across the concatenated frames.
        df_all = pd.concat(df_list, ignore_index=True)
        sns_plot = sns.pairplot(df_all, hue="class", vars=range(dim))
        sns_plot.savefig(filename)
        plt.close()


def plot_alpha(alpha, filename, idx=0):
    """Plot the mixture weights of one sequence over time and save the figure.

    :param alpha: array indexed as ``alpha[idx, step, k]``; presumably
        (batch_size, sequence_length, n_components) — TODO confirm
    :param filename: path the PNG is written to
    :param idx: batch element to plot
    :return: None
    """
    fig, ax = plt.subplots(figsize=[6, 6])
    # Swap to (k, step) so each iteration yields one component's curve.
    for line in np.swapaxes(alpha[idx], 1, 0):
        ax.plot(line, linestyle='-')
    ax.set_xlabel('Steps', fontsize=30)
    ax.set_ylabel('Mixture weight', fontsize=30)
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_alpha_grid(alpha, filename, table_size=4, idx=0):
    """Plot mixture weights for several sequences on a grid and save the figure.

    Each panel shows every component ``alpha[b, :, k]`` of one batch element
    ``b``. Panels are filled row by row, starting at batch element ``idx``.
    Panels past the end of the batch are left empty.

    :param alpha: array (batch_size, sequence_length, n_components); a 2-D
        array (sequence_length, n_components) is treated as a batch of one
    :param filename: path the PNG is written to
    :param table_size: number of grid rows/columns
    :param idx: batch element shown in the top-left panel
    :return: None
    """
    if alpha.ndim == 2:
        alpha = np.expand_dims(alpha, 0)

    # squeeze=False keeps ``ax`` 2-D even when table_size == 1.
    fig, ax = plt.subplots(table_size, table_size, sharex='col', sharey='row',
                           figsize=[12, 12], squeeze=False)
    for x in range(table_size):
        for y in range(table_size):
            if idx < alpha.shape[0]:
                for i in range(alpha.shape[-1]):
                    ax[x, y].plot(alpha[idx, :, i], linestyle='-', marker='o', markersize=3)
            ax[x, y].set_ylim([-0.01, 1.01])
            idx += 1
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def construct_ball_trajectory(var, r=1., cmap='Blues', start_color=0.4, shape='c'):
    """Build a PatchCollection with one patch per position in ``var``.

    Patch colours run linearly from ``start_color`` to 0.9 along ``cmap``
    (colour limits fixed to [0, 1]), so later positions get later colours.

    :param var: iterable of 2-D positions
    :param r: patch radius
    :param cmap: name of a matplotlib colormap
        (https://matplotlib.org/stable/gallery/color/colormap_reference.html)
    :param start_color: colormap value of the first patch
    :param shape: 'c' for circles, 'r' for 4-vertex regular polygons,
        's' for 6-vertex regular polygons
    :return: matplotlib.collections.PatchCollection
    :raises ValueError: if ``shape`` is not one of 'c', 'r', 's'
    """
    if shape not in ('c', 'r', 's'):
        raise ValueError("Unknown shape %r; expected 'c', 'r' or 's'" % (shape,))

    patches = []
    for pos in var:
        if shape == 'c':
            patches.append(mpatches.Circle(pos, radius=r))
        elif shape == 'r':
            # ``radius`` is keyword-only in recent matplotlib releases.
            patches.append(mpatches.RegularPolygon(pos, 4, radius=r))
        else:
            patches.append(mpatches.RegularPolygon(pos, 6, radius=r))

    colors = np.linspace(start_color, .9, len(patches))
    # pyplot.get_cmap works across matplotlib versions; cm.get_cmap was removed in 3.9.
    collection = PatchCollection(patches, cmap=plt.get_cmap(cmap), alpha=1.)
    collection.set_array(colors)
    collection.set_clim(0, 1)
    return collection


def plot_ball_trajectory(var, filename, idx=0, scale=30, cmap='Blues'):
    """Draw one trajectory as a chain of coloured balls and save the figure.

    The ball radius is the larger of the x/y extents of the whole batch,
    divided by ``scale``.

    :param var: array (batch_size, sequence_length, dim) with dim >= 2; only
        the first two coordinates are plotted
    :param filename: path the PNG is written to
    :param idx: batch element to draw
    :param scale: divisor applied to the data extent to get the radius
    :param cmap: colormap name passed to construct_ball_trajectory
    :return: None
    """
    x_min, y_min = np.min(var[:, :, :2], axis=(0, 1))
    x_max, y_max = np.max(var[:, :, :2], axis=(0, 1))
    r = max(x_max - x_min, y_max - y_min) / scale

    fig, ax = plt.subplots(figsize=[4, 4])
    # Use the computed radius (previously a hard-coded 1 made ``scale`` a no-op).
    collection = construct_ball_trajectory(var[idx, :, :2], r=r, cmap=cmap)
    ax.add_collection(collection)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.axis("equal")
    ax.set_xlabel('$a_{t,1}$', fontsize=24)
    ax.set_ylabel('$a_{t,2}$', fontsize=24)
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_ball_trajectories(all_vars, filename, table_size=4, scale=30):
    """
    Plot trajectory with balls
    :param all_vars: batch size x sequence length x dimensions (dim >= 2; only
        the first two coordinates are drawn)
    :param filename: path and filename to save to
    :param table_size: grid size; panels past the end of the batch stay empty
    :param scale: the ball radius is the x extent of the data divided by this
    :return: None
    """
    x_min = np.min(all_vars[:, :, 0])
    x_max = np.max(all_vars[:, :, 0])
    r = (x_max - x_min) / scale

    # squeeze=False keeps ``axes`` a 2-D array (with ``.flat``) for table_size == 1.
    fig, axes = plt.subplots(table_size, table_size, sharex=True, sharey=True,
                             figsize=[12, 12], squeeze=False)
    for idx, ax in enumerate(axes.flat):
        if idx >= len(all_vars):
            break
        collection = construct_ball_trajectory(all_vars[idx, :, :2], r=r)
        ax.axis("equal")
        ax.add_collection(collection)
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_ball_trajectories_comparison(enc, gen, impute, filename, idx=0, scale=60, nrows=1, ncols=3,
                                      mask=None):
    """Overlay encoded, generated and smoothed trajectories on a grid and save it.

    Each panel draws ``enc`` (red), ``gen`` (blue) and ``impute`` (green) for
    one batch element, as a line plus coloured balls. It also marks the
    start of ``enc`` with a large grey hexagon. When ``mask`` is given, the
    ``enc`` samples where ``mask == 1`` are marked with grey squares.

    :param enc: array (batch_size, sequence_length, 2); also sets the ball radius
    :param gen: array shaped like ``enc``
    :param impute: array shaped like ``enc``
    :param filename: path the PNG is written to
    :param idx: either an integer start index (panels show elements
        ``idx, idx + 1, ...``) or an explicit sequence of batch indices
    :param scale: the ball radius is the x extent of ``enc`` divided by this
    :param nrows: grid rows
    :param ncols: grid columns
    :param mask: optional array (batch_size, sequence_length); 1 marks a step
        to highlight (the original code calls these "observed samples")
    :return: None
    """
    if np.ndim(idx) == 0:
        # An integer (Python or NumPy) selects a contiguous run of elements.
        start = int(idx)
        idx = np.arange(start, start + nrows * ncols)

    x_min = np.min(enc[:, :, 0])
    x_max = np.max(enc[:, :, 0])
    r = (x_max - x_min) / scale

    # squeeze=False keeps ``axes`` 2-D so ``axes[0, 0]`` works for nrows == 1.
    fig, axes = plt.subplots(nrows, ncols, figsize=[ncols * 6, nrows * 6], squeeze=False)
    for i, ax in enumerate(axes.flat):
        sample = idx[i]
        for var, cmap, c in zip([enc, gen, impute], ['Reds', 'Blues', 'Greens'], ['red', 'blue', 'green']):
            ax.plot(var[sample, :, 0], var[sample, :, 1], color=c, alpha=1, linewidth=2)
            collection = construct_ball_trajectory(var[sample], r=r, cmap=cmap)
            ax.add_collection(collection)

        # Add the observed samples last so they sit on top of the trajectories.
        if mask is not None:
            collection = construct_ball_trajectory(enc[sample, mask[sample] == 1], r * 1.5, cmap='Greys',
                                                   start_color=.9, shape='r')
            ax.add_collection(collection)

        # Add the starting point
        collection = construct_ball_trajectory(enc[sample, [0]], r * 2, cmap='Greys', start_color=.9, shape='s')
        ax.add_collection(collection)

        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis("equal")
        # Axis labels only on the bottom row and the left column.
        if i >= ncols * (nrows - 1):
            ax.set_xlabel('$a_{t,1}$', fontsize=30)
        if i % ncols == 0:
            ax.set_ylabel('$a_{t,2}$', fontsize=30)

    axes[0, 0].legend(['Encoded', 'Generated', 'Smoothed'], fontsize=30, loc=0)
    plt.tight_layout()
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_3d_ball_trajectory(var, filename, r=0.05):
    """Draw the projections of a 3-D trajectory onto the coordinate planes.

    Each column of ``var`` is rescaled to [0, 1]; constant columns become 0.
    Each point then gets a circle of radius ``r`` in each of the planes
    z=0, y=0 and x=0. The input is not modified.

    :param var: array-like (sequence_length, 3)
    :param filename: path the PNG is written to
    :param r: circle radius in normalised units
    :return: None
    """
    # Copy as float: the normalisation below is in place and must not touch the caller's array.
    var = np.array(var, dtype=float)
    var -= var.min(axis=0)
    span = var.max(axis=0)
    span[span == 0] = 1.  # avoid 0/0 for constant columns
    var /= span

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for x, y, z in var:
        # Projection onto the plane z = 0.
        p = mpatches.Circle((x, y), r, ec="none")
        ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=0, zdir="z")
        # Projection onto the plane y = 0.
        p = mpatches.Circle((x, z), r, ec="none")
        ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=0, zdir="y")
        # Projection onto the plane x = 0.
        p = mpatches.Circle((y, z), r, ec="none")
        ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=0, zdir="x")

    ax.view_init(azim=45, elev=30)
    ax.set_xlim3d(-0.1, 1.1)
    ax.set_ylim3d(-0.1, 1.1)
    ax.set_zlim3d(-0.1, 1.1)
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_trajectory_and_video(trajectory, images, filename, idx=5, cmap='Blues', sidebyside=True, n_steps=20):
    """Show the first ``n_steps`` of a trajectory together with its image frames.

    :param trajectory: array (batch_size, sequence_length, 2)
    :param images: frame array indexed as ``images[idx, :n_steps]``; the frames
        are combined by ``kvae.utils.movie.movie_to_frame`` (presumably into a
        single 2-D image — confirm there)
    :param filename: path the PNG is written to
    :param idx: batch element to draw
    :param cmap: colormap for the trajectory (and the image when side by side)
    :param sidebyside: if True, draw trajectory and image in two panels;
        otherwise overlay them in one panel (the image then uses 'Reds')
    :param n_steps: number of leading time steps to draw
    :return: None
    """
    collection = construct_ball_trajectory(trajectory[idx, :n_steps], 1, cmap=cmap)
    image = movie_to_frame(images[idx, :n_steps])
    # Reverse the column order (left-right mirror) of the composed image.
    image = np.fliplr(image)

    if sidebyside:
        fig, ax = plt.subplots(ncols=2, figsize=[20, 10])
        x_min, y_min = np.min(trajectory, axis=(0, 1))
        x_max, y_max = np.max(trajectory, axis=(0, 1))
        ax[0].axis("equal")
        ax[0].add_collection(collection)
        ax[0].set_xlim([x_min, x_max])
        ax[0].set_ylim([y_min, y_max])
        ax[0].set_xticks([])
        ax[0].set_yticks([])

        ax[1].imshow(image, cmap=plt.get_cmap(cmap), interpolation='none', vmin=0, vmax=1)
        # Limits presumably match 32-pixel frames — TODO confirm.
        ax[1].set_xlim([1, 31])
        ax[1].set_ylim([1, 31])
        ax[1].set_xticks([])
        ax[1].set_yticks([])
    else:
        fig, ax = plt.subplots(ncols=1, figsize=[12, 12])
        ax.add_collection(collection)
        ax.imshow(image, cmap=plt.get_cmap('Reds'), interpolation='none', vmin=0, vmax=1)
        ax.set_xlim([1, 31])
        ax.set_ylim([1, 31])
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.tick_params(bottom=False, left=False)
    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_ball_and_alpha(alpha, trajectory, filename, cmap='Blues'):
    """Draw a trajectory next to its mixture weights and save the figure.

    :param alpha: array (sequence_length, n_components), or None to leave the
        right-hand panel empty
    :param trajectory: array (sequence_length, 2)
    :param filename: path the PNG is written to
    :param cmap: colormap name for the trajectory balls
    :return: None
    """
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=[12, 6])
    collection = construct_ball_trajectory(trajectory, r=1., cmap=cmap)
    x_min, y_min = np.min(trajectory, axis=0)
    x_max, y_max = np.max(trajectory, axis=0)
    ax[0].add_collection(collection)
    ax[0].set_xlim([x_min, x_max])
    ax[0].set_ylim([y_min, y_max])
    ax[0].axis("equal")

    if alpha is not None:
        # Swap to (k, step) so each iteration yields one component's curve.
        for line in np.swapaxes(alpha, 1, 0):
            ax[1].plot(line, linestyle='-')

    plt.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


def plot_trajectory_uncertainty(true, gen, filter, smooth, filename):
    """Plot the per-step Hamming distance between binarised image sequences.

    At every step the images of all sequences are thresholded at 0.5,
    flattened, and compared with ``true`` using scipy's Hamming distance
    (fraction of differing entries). The result is one curve each for
    ``gen``, ``filter`` and ``smooth``.

    :param true: reference images, indexed as ``true[:, step]``; presumably
        (sequences, timesteps, height, width) — TODO confirm
    :param gen: generated images shaped like ``true``
    :param filter: filtered images shaped like ``true`` (name kept for
        keyword-argument compatibility although it shadows the builtin)
    :param smooth: smoothed images shaped like ``true``
    :param filename: path the figure is written to
    :return: None
    """
    timesteps = true.shape[1]
    steps = np.arange(1, timesteps + 1)

    # Draw on a fresh figure instead of whatever figure happens to be current.
    fig = plt.figure()
    for label, var in zip(('Generated', 'Filtered', 'Smoothed'), (gen, filter, smooth)):
        errors = [hamming(true[:, step].ravel() > 0.5, var[:, step].ravel() > 0.5)
                  for step in range(timesteps)]
        plt.plot(steps, errors, linewidth=3, label=label)
    plt.xlabel('Steps', fontsize=20)
    plt.ylabel('Hamming distance', fontsize=20)
    plt.legend(fontsize=20)
    plt.savefig(filename)
    plt.close(fig)


def hinton(matrix, max_weight=None, ax=None):
    """Draw Hinton diagram for visualizing a weight matrix.

    Each entry becomes a square whose area is proportional to
    ``abs(w) / max_weight``: white for positive entries, black otherwise,
    on a grey background. Element ``(i, j)`` is drawn at x=i, y=j.

    :param matrix: 2-D array-like of weights
    :param max_weight: weight mapped to a full-size square; when falsy, the
        largest absolute entry rounded up to a power of two is used
    :param ax: axes to draw on (defaults to the current axes)
    :return: None
    """
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        max_abs = np.abs(matrix).max()
        # log2 is exact for powers of two; an all-zero matrix falls back to 1
        # to avoid dividing by zero below.
        max_weight = 2 ** np.ceil(np.log2(max_abs)) if max_abs > 0 else 1.

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    ax.autoscale_view()
    ax.invert_yaxis()


def plot_kalman_transfers(matrices, filename):
    """Draw one Hinton diagram per matrix side by side and save the figure.

    :param matrices: non-empty sequence of 2-D weight matrices
    :param filename: path the PNG is written to
    :return: None
    """
    # squeeze=False keeps a 2-D axes array, so a single matrix works too.
    fig, axarr = plt.subplots(1, len(matrices), squeeze=False)
    for ax, mat in zip(axarr[0], matrices):
        hinton(mat, ax=ax)
    fig.savefig(filename, format='png', bbox_inches='tight', dpi=80)
    plt.close(fig)


if __name__ == '__main__':
    # Ad-hoc driver: loads a dataset and renders one plot; uncomment the
    # calls below to produce the other figures.
    # hinton(np.random.rand(20, 20) - 0.5)
    # plt.show()
    # filename = 'box_rnd'
    # npzfile = np.load("../../data/%s.npz" % filename)
    # states = npzfile['state'][:, :, :2]
    # plot_auxiliary([states], 'plot_true_%s.png' % filename)
    # save_frames_to_png(images, 'training_sequence_img')

    filename = 'box_rnd'
    npzfile = np.load("../../data/%s.npz" % filename)
    state = npzfile['state']
    images = npzfile['images']

    # plot_ball_trajectories(state, 'trajectory_grid')
    # plot_trajectory_and_video(state, images, 'training_combined', sidebyside=True)
    # plot_3d_ball_trajectory(np.concatenate((state[0, :, :2], np.random.rand(state.shape[1], 1) * 32), 1),
    #                         '3d_plot')
    # mask = np.random.choice([0, 1], size=(16, 20), p=[0.5, 0.5])
    # plot_ball_trajectories_comparison(state[:16, :, :2], state[16:32, :, :2], state[32:48, :, :2],
    #                                   'training_ball_comp', mask=mask)
    # plot_trajectory_uncertainty(images[:16, :], images[16:32, :], images[32:48, :], images[48:64, :],
    #                             'training_uncertainty')

    # No mixture weights are available here, so only the trajectory panel is drawn.
    plot_ball_and_alpha(None, state[0, :, :2], 'ball_and_alpha')