""" HiddenLayer Implementation of the Canvas class to render visualizations. Written by Waleed Abdulla Licensed under the MIT License """ import itertools import math import numpy as np import matplotlib import matplotlib.pyplot as plt import IPython.display from mpl_toolkits.mplot3d import Axes3D from matplotlib.collections import PolyCollection DEFAULT_THEME = { "fig_width": 12, # inches "hist_outline_color": [0, 0, 0.9], "hist_color": [0.5, 0, 0.9], } def norm(image): """Normalize an image to [0, 1] range.""" min_value = image.min() max_value = image.max() if min_value == max_value: return image - min_value return (image - min_value) / (max_value - min_value) # TODO: Move inside Canvas and merge with draw_images def show_images(images, titles=None, cols=5, **kwargs): """ images: A list of images. I can be either: - A list of Numpy arrays. Each array represents an image. - A list of lists of Numpy arrays. In this case, the images in the inner lists are concatentated to make one image. """ # The images param can be a list or an array titles = titles or [""] * len(images) rows = math.ceil(len(images) / cols) height_ratio = 1.2 * (rows/cols) * (0.5 if type(images[0]) is not np.ndarray else 1) plt.figure(figsize=(11, 11 * height_ratio)) i = 1 for image, title in zip(images, titles): plt.subplot(rows, cols, i) plt.axis("off") # Is image a list? If so, merge them into one image. if type(image) is not np.ndarray: image = [norm(g) for g in image] image = np.concatenate(image, axis=1) else: image = norm(image) plt.title(title, fontsize=9) plt.imshow(image, cmap="Greys_r", **kwargs) i += 1 plt.tight_layout(h_pad=0, w_pad=0) ############################################################################### # Canvas Class ############################################################################### class Canvas(): def __init__(self): self._context = None self.theme = DEFAULT_THEME self.figure = None self.backend = matplotlib.get_backend() self.drawing_calls = [] self.theme = DEFAULT_THEME def __enter__(self): self._context = "build" self.drawing_calls = [] return self def __exit__(self, exc_type, exc_val, exc_tb): self.render() def render(self): self._context = "run" # Clear output if 'inline' in self.backend: IPython.display.clear_output(wait=True) self.figure = None # Separate the draw_*() calls that generate a grid cell grid_calls = [] silent_calls = [] for c in self.drawing_calls: if c[0] == "draw_summary": silent_calls.append(c) else: grid_calls.append(c) # Header area # TODO: ideally, compute how much header area we need based on the # length of text to show there. Right now, we're just using # a fixed number multiplied by the number of calls. Since there # is only one silent call, draw_summary(), then the header padding # is either 0 or 0.1 head_pad = 0.1 * len(silent_calls) width = self.theme['fig_width'] if not self.figure: self.figure = plt.figure(figsize=(width, width/3 * (head_pad + len(grid_calls)))) self.figure.clear() # Divide figure area by number of grid calls gs = matplotlib.gridspec.GridSpec(len(grid_calls), 1) # Call silent calls for c in silent_calls: getattr(self, c[0])(*c[1], **c[2]) # Call grid methods for i, c in enumerate(grid_calls): method = c[0] # Create an axis for each call # Save in in self.ax so the drawing function has access to it self.ax = self.figure.add_subplot(gs[i]) # Save the GridSpec as well self.gs = gs[i] # Call the method getattr(self, method)(*c[1], **c[2]) # Cleanup after drawing self.ax = None self.gs = None gs.tight_layout(self.figure, rect=(0, 0, 1, 1-head_pad)) # TODO: pause() allows the GUI to render but it's sluggish because it # only has 0.1 seconds of CPU time at each step. A better solution would be to # launch a separate process to render the GUI and pipe data to it. plt.pause(0.1) plt.show(block=False) self.drawing_calls = [] self._context = None def __getattribute__(self, name): if name.startswith("draw_") and self._context != "run": def wrapper(*args, **kwargs): self.drawing_calls.append((name, args, kwargs)) if not self._context: self.render() return wrapper else: return object.__getattribute__(self, name) def save(self, file_name): self.figure.savefig(file_name) def draw_summary(self, history, title=""): """Inserts a text summary at the top that lists the number of steps and total training time.""" # Generate summary string time_str = str(history.get_total_time()).split(".")[0] # remove microseconds summary = "Step: {} Time: {}".format(history.step, time_str) if title: summary = title + "\n\n" + summary self.figure.suptitle(summary) def draw_plot(self, metrics, labels=None, ylabel="", title=None): """ metrics: One or more metrics parameters. Each represents the history of one metric. """ metrics = metrics if isinstance(metrics, list) else [metrics] # Loop through metrics default_title = "" for i, m in enumerate(metrics): label = labels[i] if labels else m.name # TODO: use a standard formating function for values default_title += (" " if default_title else "") + "{}: {}".format(label, m.data[-1]) self.ax.plot(m.formatted_steps, m.data, label=label) title = default_title if title is None else title self.ax.set_title(title) self.ax.set_ylabel(ylabel) self.ax.legend() self.ax.set_xlabel("Steps") self.ax.xaxis.set_major_locator(plt.AutoLocator()) def draw_image(self, metric, limit=5): """Display a series of images at different time steps.""" rows = 1 cols = limit self.ax.axis("off") # Take the Axes gridspec and divide it into a grid gs = matplotlib.gridspec.GridSpecFromSubplotSpec( rows, cols, subplot_spec=self.gs) # Loop through images in last few steps for i, image in enumerate(metric.data[-cols:]): ax = self.figure.add_subplot(gs[0, i]) ax.axis('off') ax.set_title(metric.formatted_steps[-cols:][i]) ax.imshow(norm(image)) def draw_hist(self, metric, title=""): """Draw a series of histograms of the selected keys over different training steps. """ # TODO: assert isinstance(list(values.values())[0], np.ndarray) rows = 1 cols = 1 limit = 10 # max steps to show # We need a 3D projection Subplot, so ignore the one provided to # as an create a new one. ax = self.figure.add_subplot(self.gs, projection="3d") ax.view_init(30, -80) # Compute histograms verts = [] area_colors = [] edge_colors = [] for i, s in enumerate(metric.steps[-limit:]): hist, edges = np.histogram(metric.data[-i-1:]) # X is bin centers x = np.diff(edges)/2 + edges[:-1] # Y is hist values y = hist x = np.concatenate([x[0:1], x, x[-1:]]) y = np.concatenate([[0], y, [0]]) # Ranges if i == 0: x_min = x.min() x_max = x.max() y_min = y.min() y_max = y.max() x_min = np.minimum(x_min, x.min()) x_max = np.maximum(x_max, x.max()) y_min = np.minimum(y_min, y.min()) y_max = np.maximum(y_max, y.max()) alpha = 0.8 * (i+1) / min(limit, len(metric.steps)) verts.append(list(zip(x, y))) area_colors.append(np.array(self.theme["hist_color"] + [alpha])) edge_colors.append(np.array(self.theme["hist_outline_color"] + [alpha])) poly = PolyCollection(verts, facecolors=area_colors, edgecolors=edge_colors) ax.add_collection3d(poly, zs=list(range(min(limit, len(metric.steps)))), zdir='y') ax.set_xlim(x_min, x_max) ax.set_ylim(0, limit) ax.set_yticklabels(metric.formatted_steps[-limit:]) ax.set_zlim(y_min, y_max) ax.set_title(metric.name)