"""The plot tool module contains a base plot tool for displaying viewgraphs """ from math import pi import os import matplotlib as mpl import imageio as io from PIL import Image import networkx as nx mpl.use('TkAgg') import matplotlib.pyplot as plt # noqa import pylab # noqa BASE = 10000000 IMAGE_LIMIT = 500 FRAMES = "graphs/" THUMBNAILS = "thumbs/" COLOURS = ["LightYellow", "Yellow", "Orange", "OrangeRed", "Red", "DarkRed", "Black"] class PlotTool(object): """A base object with functions for building, displaying, and saving viewgraphs""" def __init__(self, display, save, node_shape): self.display = display self.save = save self.node_shape = node_shape if save: self._create_graph_folder() self.report_number = 0 def _create_graph_folder(self): graph_path = os.path.dirname(os.path.abspath(__file__)) + '/../graphs/' # if there isn't a graph folder, make one! if not os.path.isdir(graph_path): os.makedirs(graph_path) # find the next name for the next plot! graph_num = 0 while True: new_plot = graph_path + 'graph_num_' + str(graph_num).zfill(3) graph_num += 1 if not os.path.isdir(new_plot): os.makedirs(new_plot) break self.graph_path = new_plot + "/" self.thumbnail_path = self.graph_path + "thumbnails/" os.makedirs(self.thumbnail_path) def build_viewgraph(self, view, validator_set, message_colors, message_labels, edges): """Creates and displays view graphs.""" graph = nx.Graph() nodes = view.justified_messages.values() fig_size = plt.rcParams["figure.figsize"] fig_size[0] = 20 fig_size[1] = 20 plt.rcParams["figure.figsize"] = fig_size for message in nodes: graph.add_edges_from([(message, message)]) edge = [] if edges == []: for message in nodes: for msg_in_justification in message.justification.values(): if msg_in_justification is not None: edge.append((msg_in_justification, message)) edges = [{'edges': edge, 'width': 3, 'edge_color': 'black', 'style': 'solid'}] positions = dict() sorted_validators = validator_set.sorted_by_name() for message in nodes: # Index of val in list may have some small performance concerns. if message.estimate is not None: xslot = sorted_validators.index(message.sender) + 1 else: xslot = (len(validator_set) + 1) / 2.0 positions[message] = ( (float)(xslot) / (float)(len(validator_set) + 1), 0.2 + 0.1 * message.display_height ) node_color_map = {} for message in nodes: if message not in message_colors: node_color_map[message] = 'white' elif message_colors[message] == len(validator_set) - 1: node_color_map[message] = "Black" else: node_color_map[message] = COLOURS[int(len(COLOURS) * message_colors[message] / len(validator_set))] color_values = [node_color_map.get(node) for node in nodes] labels = {} node_sizes = [] for message in nodes: node_sizes.append(350 * pow(message.sender.weight / pi, 0.5)) labels[message] = message_labels.get(message, '') nx.draw_networkx_nodes(graph, positions, alpha=0.5, node_color=color_values, nodelist=nodes, node_size=node_sizes, node_shape=self.node_shape, edge_color='black') for edge in edges: if isinstance(edge, dict): nx.draw_networkx_edges( graph, positions, edgelist=(edge['edges']), width=edge['width'], edge_color=edge['edge_color'], style=edge['style'], alpha=0.5 ) else: assert False, edge nx.draw_networkx_labels(graph, positions, labels=labels) ax = plt.gca() ax.collections[0].set_edgecolor("black") ax.text(-0.05, 0.1, "Weights: ", fontsize=20) for validator in validator_set: xpos = (float)(validator.name + 1) / (float)(len(validator_set) + 1) - 0.01 ax.text(xpos, 0.1, (str)((int)(validator.weight)), fontsize=20) def next_viewgraph( self, view, validator_set, message_colors=None, message_labels=None, edges=None ): """Generates the next viewgraph, and saves and/or displays it""" if message_colors is None: message_colors = {} if message_labels is None: message_labels = {} if edges is None: edges = [] self.report_number += 1 # TODO: if we save and plot the graph, we currently build it twice # issues as pyplot clears the graph otherwise, should try to fix this if self.save: self.build_viewgraph( view, validator_set, message_colors=message_colors, message_labels=message_labels, edges=edges ) plt.savefig(self.graph_path + '/' + str(1000 + self.report_number) + ".png") plt.close('all') if self.display: self.build_viewgraph( view, validator_set, message_colors=message_colors, message_labels=message_labels, edges=edges ) plt.show() def make_thumbnails(self, frame_count_limit=IMAGE_LIMIT, xsize=1000, ysize=1000): """Make thumbnail images in PNG format.""" file_names = sorted([fn for fn in os.listdir(self.graph_path) if fn.endswith('.png')]) if len(file_names) >= frame_count_limit: raise Exception("To many frames!") images = [] for file_name in file_names: images.append(Image.open(self.graph_path + file_name)) size = (xsize, ysize) iterator = 0 for image in images: image.thumbnail(size, Image.ANTIALIAS) image.save(self.thumbnail_path + str(1000 + iterator) + "thumbnail.png", "PNG") iterator += 1 def make_gif(self, frame_count_limit=IMAGE_LIMIT, gif_name="mygif.gif", frame_duration=0.4): """Make a GIF visualization of view graph.""" if not self.save: return self.make_thumbnails(frame_count_limit=frame_count_limit) file_names = sorted([file_name for file_name in os.listdir(self.thumbnail_path) if file_name.endswith('thumbnail.png')]) images = [] for file_name in file_names: images.append(Image.open(self.thumbnail_path + file_name)) destination_filename = self.graph_path + gif_name iterator = 0 with io.get_writer(destination_filename, mode='I', duration=frame_duration) as writer: for file_name in file_names: image = io.imread(self.thumbnail_path + file_name) writer.append_data(image) iterator += 1 writer.close()