import threading
from pyrealtime import FPSTimer
from pyrealtime.layer import TransformMixin, ProcessLayer, ThreadLayer
import time
import copy
import matplotlib.animation as animation
from matplotlib import pyplot as plt
import numpy as np
from matplotlib.widgets import Button


def _blit_draw(self, artists, bg_cache):
    """Blit ``artists`` using whole-figure bounding boxes instead of per-axes ones.

    Drop-in replacement for matplotlib's private ``Animation._blit_draw`` with the
    ``(self, artists, bg_cache)`` signature. The cached background and the blitted
    region are the full figure bbox, presumably so that artists drawn outside their
    axes area are refreshed as well.

    NOTE(review): this targets a private matplotlib API whose signature may differ in
    the installed matplotlib version -- confirm before re-enabling the monkeypatch.
    """
    updated_axes = []
    for artist in artists:
        # Cache the background the first time an axes is seen (an attempt to automate
        # what matplotlib does; may not always be reliable). Cache the whole figure,
        # not just artist.axes.bbox as upstream does.
        if artist.axes not in bg_cache:
            bg_cache[artist.axes] = artist.figure.canvas.copy_from_bbox(artist.axes.figure.bbox)
        artist.axes.draw_artist(artist)
        updated_axes.append(artist.axes)
    # Blit once per touched axes, again over the full figure bbox.
    for ax in set(updated_axes):
        ax.figure.canvas.blit(ax.figure.bbox)


class FigureManager(ProcessLayer):
    """Process layer that owns a matplotlib figure and animates registered plot layers.

    Plot layers register via :meth:`register_plot`. :meth:`initialize` creates the figure,
    gives each layer its axes, and starts a blitted ``FuncAnimation`` that calls
    :meth:`update_func` roughly ``fps`` times per second.

    Args:
        create_fig: callable ``create_fig(fig) -> dict`` mapping plot keys to axes.
            Defaults to :meth:`default_create_fig` (one subplot under key ``None``).
        fps: target animation frame rate (sets the animation interval).
        keep_plot_open: if True, :meth:`shutdown` does not close the figure.
    """

    def __init__(self, create_fig=None, fps=30, keep_plot_open=False, *args, **kwargs):
        if 'name' not in kwargs:
            kwargs['name'] = "FigureManager"
        super().__init__(*args, **kwargs)
        self.fig = None
        self.axes_dict = None
        self.plot_layers = {}
        self.create_fig = create_fig if create_fig is not None else FigureManager.default_create_fig
        self.anim = None
        self.fps = fps
        # Report the animation frame rate instead of the layer's own rate: remember the
        # requested setting, then clear ``print_fps`` (presumably set by the base layer
        # constructor) so the base layer does not print as well.
        self.print_anim_fps = self.print_fps
        self.anim_fps_timer = FPSTimer(5)
        if self.print_fps:
            self.print_fps = False
        self.keep_plot_open = keep_plot_open

    @staticmethod
    def default_create_fig(fig):
        """Add a single subplot to ``fig`` and map it to the plot key ``None``."""
        ax = fig.add_subplot(111)
        return {None: ax}

    def initialize(self):
        """Create the figure, hand each registered plot layer its axes, start the animation.

        Raises:
            KeyError: if ``create_fig`` produced no axes for a registered plot key.
        """
        self.fig = plt.figure()
        self.axes_dict = self.create_fig(self.fig)
        for plot_key, plot_layer in self.plot_layers.items():
            if plot_key not in self.axes_dict:
                raise KeyError("No axis created for plot %s" % plot_key)
            plot_layer.create_fig(self.fig, self.axes_dict[plot_key])
        # NOTE: whole-figure blitting via the module-level ``_blit_draw`` is disabled.
        # To re-enable: ``animation.Animation._blit_draw = _blit_draw`` (this module
        # only binds ``animation``, not ``matplotlib`` itself).
        self.anim = animation.FuncAnimation(self.fig, self.update_func, init_func=self.init_func, frames=None,
                                            interval=1000 / self.fps, blit=True, save_count=0)

    def init_func(self):
        """Animation init callback: collect the initial artists of every plot layer."""
        artists = []
        for plot_layer in self.plot_layers.values():
            artists += plot_layer.init_fig()
        return artists

    def update_func(self, frame):
        """Animation frame callback: let every plot layer update and return its artists."""
        if self.print_anim_fps:
            self.anim_fps_timer.tick()
        artists = []
        for plot_layer in self.plot_layers.values():
            artists += plot_layer.anim_update(frame)
        return artists

    def get_input(self):
        """Produce no data; just sleep so the layer loop does not spin."""
        time.sleep(1)
        return None

    def main_thread_post_init(self):
        """Show the figure (blocking) and shut the layer down once the window closes."""
        try:
            plt.show()
        except KeyboardInterrupt:
            print("exiting figure")
        self.shutdown()

    def register_plot(self, key, plot_layer):
        """Register ``plot_layer`` to be drawn on the axes stored under ``key``.

        Raises:
            NameError: if ``key`` is already registered.
        """
        if key in self.plot_layers:
            raise NameError("plot key already exists: %s" % key)
        self.plot_layers[key] = plot_layer

    def shutdown(self):
        """Close the figure (unless ``keep_plot_open``), signal stop, and shut down the base layer."""
        if not self.keep_plot_open:
            def close_fig():
                plt.close(self.fig)

            # plt.close hangs for some reason, so doing this in a daemon thread
            threading.Thread(target=close_fig, daemon=True).start()
        self.stop_event.set()
        super().shutdown()


class InProcFigureManager(ThreadLayer):
    """Thread-layer variant of :class:`FigureManager` that shows the figure in interactive mode.

    Same arguments as :class:`FigureManager`. :meth:`initialize` creates the figure,
    starts the animation and calls ``plt.ion()`` + ``plt.show()`` itself.
    """

    def __init__(self, create_fig=None, fps=30, keep_plot_open=False, *args, **kwargs):
        if 'name' not in kwargs:
            kwargs['name'] = "FigureManager"
        super().__init__(*args, **kwargs)
        self.fig = None
        self.axes_dict = None
        self.plot_layers = {}
        self.create_fig = create_fig if create_fig is not None else FigureManager.default_create_fig
        self.anim = None
        self.fps = fps
        self.keep_plot_open = keep_plot_open

    @staticmethod
    def default_create_fig(fig):
        """Add a single subplot to ``fig`` and map it to the plot key ``None``."""
        ax = fig.add_subplot(111)
        return {None: ax}

    def initialize(self):
        """Create the figure and axes, start the animation, and show the figure interactively.

        Raises:
            KeyError: if ``create_fig`` produced no axes for a registered plot key.
        """
        self.fig = plt.figure()
        self.axes_dict = self.create_fig(self.fig)
        for plot_key, plot_layer in self.plot_layers.items():
            if plot_key not in self.axes_dict:
                raise KeyError("No axis created for plot %s" % plot_key)
            plot_layer.create_fig(self.fig, self.axes_dict[plot_key])
        # NOTE: whole-figure blitting via ``_blit_draw`` is disabled here as well.
        self.anim = animation.FuncAnimation(self.fig, self.update_func, init_func=self.init_func, frames=None,
                                            interval=1000 / self.fps, blit=True)
        print("show")
        plt.ion()
        plt.show()
        print("continue")

    def init_func(self):
        """Animation init callback: collect the initial artists of every plot layer."""
        artists = []
        for plot_layer in self.plot_layers.values():
            artists += plot_layer.init_fig()
        return artists

    def update_func(self, frame):
        """Animation frame callback: let every plot layer update and return its artists."""
        artists = []
        for plot_layer in self.plot_layers.values():
            artists += plot_layer.anim_update(frame)
        return artists

    def get_input(self):
        """Produce no data; just sleep so the layer loop does not spin."""
        time.sleep(1)
        return None

    def register_plot(self, key, plot_layer):
        """Register ``plot_layer`` to be drawn on the axes stored under ``key``.

        Raises:
            NameError: if ``key`` is already registered.
        """
        if key in self.plot_layers:
            raise NameError("plot key already exists: %s" % key)
        self.plot_layers[key] = plot_layer

    def shutdown(self):
        """Close the figure (unless ``keep_plot_open``), signal stop, and shut down the base layer."""
        if not self.keep_plot_open:
            def close_fig():
                plt.close(self.fig)

            # plt.close hangs for some reason, so doing this in a daemon thread
            threading.Thread(target=close_fig, daemon=True).start()
        self.stop_event.set()
        super().shutdown()


class PlotLayer(TransformMixin, ThreadLayer):
    """Base class for layers that plot their input on axes owned by a figure manager.

    :meth:`transform` stores the latest data. :meth:`anim_update`, called from the
    figure manager's animation, draws it. ``data_lock`` guards that hand-off.
    Subclasses implement :meth:`draw_empty_plot`, :meth:`init_fig` and :meth:`update_fig`.

    Args:
        port_in: input port passed to the base layer.
        samples: number of samples to display (subclasses may reset it in ``post_init``).
        fig_manager: manager to draw on; a new :class:`FigureManager` is created if omitted.
        plot_config: optional callable ``plot_config(ax)`` applied after the empty plot is drawn.
        plot_key: key of the axes (in the manager's ``create_fig`` result) this layer draws on.
        create_fig: figure factory, used only when a new FigureManager is created here.
        legend: if True, add a legend whose entries toggle their line's visibility on click.
    """

    def __init__(self, port_in, samples=10, fig_manager=None, plot_config=None, plot_key=None, create_fig=None,
                 legend=False, *args, **kwargs):
        # Created in create_fig rather than here -- presumably because the layer is moved
        # into the figure-manager process and a Lock cannot be pickled. TODO confirm.
        self.data_lock = None
        self.samples = samples
        self.buf_data = None
        self.fig_manager = fig_manager if fig_manager is not None else FigureManager(create_fig=create_fig)
        self.fig_manager.register_plot(plot_key, self)  # TODO
        self.plot_config = plot_config
        self.ax = None
        self.series = None
        self.to_return = None  # events raised via raise_event, drained by transform
        self.legend = legend
        self.h_legend = None
        self.legend_dict = {}  # legend proxy line -> plotted line
        self.trigger_legend_redraw = False
        super().__init__(port_in, parent_proc=self.fig_manager, *args, **kwargs)

    def transform(self, data):
        """Publish ``data`` for the next animation frame and drain pending events.

        Returns:
            A shallow copy of the ``{key: value}`` events raised via :meth:`raise_event`
            since the previous call, or ``None`` if none were raised.
        """
        # `with` guarantees the lock is released even if something raises.
        with self.data_lock:
            self.buf_data = data
            to_return_copy = copy.copy(self.to_return)
            self.to_return = None
        return to_return_copy

    def raise_event(self, key, value):
        """Queue ``value`` under ``key``; it is returned by the next :meth:`transform` call.

        Guarded by ``data_lock``, which :meth:`transform` also holds while draining events.
        """
        with self.data_lock:
            if self.to_return is None:
                self.to_return = {}
            self.to_return[key] = value

    def anim_update(self, _):
        """Animation callback: redraw from the latest data, returning the artists to blit."""
        # Holding the lock via `with` prevents a raising update_fig from leaving it
        # locked forever (which would deadlock transform).
        with self.data_lock:
            if self.buf_data is None:
                return []
            # NOTE(review): trigger_legend_redraw is set by on_pick but not consumed here,
            # so legend alpha changes are presumably only visible after a full redraw.
            return self.update_fig(self.buf_data)

    def create_fig(self, fig, ax):
        """Bind this layer to ``ax``, draw its empty plot and apply ``plot_config``."""
        self.data_lock = threading.Lock()
        self.ax = ax
        self.series = self.draw_empty_plot(ax)
        if self.plot_config is not None:
            self.plot_config(ax)

    def post_init(self, data):
        """Run base post-init, then (if ``legend``) build a clickable legend."""
        super().post_init(data)
        if self.legend:
            self.h_legend = self.ax.legend(loc='upper left')
            for legline, origline in zip(self.h_legend.get_lines(), self.series):
                legline.set_picker(5)  # 5 pts tolerance
                self.legend_dict[legline] = origline
            self.fig_manager.fig.canvas.mpl_connect('pick_event', self.on_pick)

    def on_pick(self, event):
        """Toggle visibility of the line whose legend entry was clicked."""
        legline = event.artist
        origline = self.legend_dict.get(legline)
        if origline is None:
            # Not one of our legend entries (e.g. another layer's legend on the same canvas).
            return
        vis = not origline.get_visible()
        origline.set_visible(vis)
        # Dim the legend entry so toggled-off lines are recognizable.
        legline.set_alpha(1 if vis else 0.2)
        self.trigger_legend_redraw = True

    def draw_empty_plot(self, ax):
        """Draw the initial (empty) plot on ``ax`` and return its artists. Subclass hook."""
        raise NotImplementedError

    def update_fig(self, data):
        """Update the artists from ``data`` and return them. Subclass hook."""
        raise NotImplementedError

    def init_fig(self):
        """Reset the artists for the animation's init step and return them. Subclass hook."""
        raise NotImplementedError


class SimplePlotLayer(PlotLayer):
    """Plot each incoming frame in full: one line per channel against sample index.

    A 2-D ndarray is treated as (samples, channels). A 1-D ndarray or a sequence is a
    single channel.

    Args:
        ylim: optional fixed y-axis limits.
    """

    def __init__(self, port_in, ylim=None, *args, **kwargs):
        super().__init__(port_in, *args, **kwargs)
        self.ylim = ylim

    def draw_empty_plot(self, ax):
        return []

    def post_init(self, data):
        """Size the x-axis from the first frame and create one line per channel."""
        n_channels = 1
        if isinstance(data, np.ndarray):
            # Columns are channels; a 1-D array is a single channel (indexing it with
            # [:, i] in update_fig would fail).
            if data.ndim > 1:
                n_channels = data.shape[-1]
            self.samples = data.shape[0]
        else:
            self.samples = len(data)
        self.series = []
        self.ax.set_xlim(0, self.samples)
        if self.ylim is not None:
            self.ax.set_ylim(self.ylim)
        for _ in range(n_channels):
            handle, = self.ax.plot([], [], '-', lw=1)
            self.series.append(handle)
        super().post_init(data)

    def init_fig(self):
        for series in self.series:
            series.set_data([], [])
        return self.series

    def update_fig(self, data):
        x = np.linspace(1, self.samples, self.samples)
        multi_channel = isinstance(data, np.ndarray) and data.ndim > 1
        for i, series in enumerate(self.series):
            series.set_data(x, data[:, i] if multi_channel else data)
        return self.series


class TimePlotLayer(PlotLayer):
    """Scrolling plot of the last ``window_size`` samples, one line per channel.

    Adds a "Pause" button. While paused, incoming data is ignored.

    Args:
        window_size: number of most recent samples shown.
        n_channels: number of channels; inferred from the first sample if None.
        ylim: optional fixed y-axis limits.
        lw: line width.
    """

    def __init__(self, port_in, window_size=100, n_channels=None, ylim=None, lw=1, *args, **kwargs):
        super().__init__(port_in, *args, **kwargs)
        self.window_size = window_size
        self.n_channels = n_channels
        self.ylim = ylim
        self.lw = lw
        self.buffer = None  # (window_size, n_channels) rolling buffer
        self.pause_button = None
        self.paused = False
        self.use_np = False
        self.num_ticks = 5
        self.x_data = np.arange(window_size)

    def pause(self, _):
        """Button callback: toggle the paused state."""
        self.paused = not self.paused

    def draw_empty_plot(self, ax):
        """Add the Pause button to the current figure; lines are created in post_init."""
        ax_pause = plt.axes([0.81, 0.005, 0.1, 0.075])
        self.pause_button = Button(ax_pause, 'Pause')
        self.pause_button.on_clicked(self.pause)
        return []

    def post_init(self, data):
        """Infer channel count from the first sample, allocate the buffer, create lines."""
        if isinstance(data, np.ndarray):
            self.use_np = True
        if self.n_channels is None:
            if self.use_np:
                self.n_channels = data.shape[-1] if len(data.shape) > 1 else 1
            else:
                self.n_channels = 1
        self.buffer = np.zeros((self.window_size, self.n_channels))
        self.series = []
        self.ax.set_xlim(0, self.window_size)
        if self.ylim is not None:
            self.ax.set_ylim(self.ylim)
        for channel in range(self.n_channels):
            handle, = self.ax.plot([], [], '-', lw=self.lw, label=channel)
            self.series.append(handle)
        super().post_init(data)

    def init_fig(self):
        for series in self.series:
            series.set_data([], [])
        return self.series

    def update_fig(self, data):
        for i, series in enumerate(self.series):
            if isinstance(data, np.ndarray):
                series.set_data(self.x_data, data[:, i])
            else:
                series.set_data(self.x_data, data)
        return self.series

    def transform(self, data):
        """Append ``data`` to the rolling buffer and publish the whole buffer.

        Accepted inputs:

        * **ndarray mode** (first sample was an ndarray): a 2-D array of
          (samples, channels); or a 1-D array, which is one sample if its length equals
          ``n_channels`` and otherwise several single-channel samples; or a numpy scalar.
        * **list mode**: a list of length ``n_channels`` is one sample. Any other list
          holds several samples, as rows or as flat single-channel values.
        * **scalar mode**: anything else is written into the newest row.

        Returns:
            ``None`` while paused. Otherwise the result of :meth:`PlotLayer.transform`,
            i.e. any events raised since the previous call.
        """
        if self.paused:
            return None

        if self.use_np:
            if len(data.shape) == 1:
                if len(data) == self.n_channels:
                    data = np.atleast_2d(data)  # one multi-channel sample -> (1, C)
                else:
                    data = np.atleast_2d(data).T  # several single-channel samples -> (N, 1)
            if np.isscalar(data):
                self.buffer = np.roll(self.buffer, shift=-1, axis=0)
                self.buffer[-1, :] = data
            else:
                data_size = data.shape[0]
                self.buffer = np.roll(self.buffer, shift=-data_size, axis=0)
                self.buffer[-data_size:, :] = data
        elif isinstance(data, list):
            data_size = 1 if len(data) == self.n_channels else len(data)
            self.buffer = np.roll(self.buffer, shift=-data_size, axis=0)
            # Reshape to (samples, channels) so a flat list of single-channel samples
            # becomes a column, mirroring the 1-D ndarray handling above (assigning the
            # flat list directly fails to broadcast).
            self.buffer[-data_size:, :] = np.reshape(data, (data_size, -1))
        else:
            self.buffer = np.roll(self.buffer, shift=-1, axis=0)
            self.buffer[-1, :] = data

        # Propagate pending events exactly like PlotLayer.transform does.
        return super().transform(self.buffer)


class BarPlotLayer(PlotLayer):
    """Bar chart with one bar per channel, heights set from the latest sample.

    Args:
        ylim: optional fixed y-axis limits.
    """

    def __init__(self, port_in, ylim=None, *args, **kwargs):
        super().__init__(port_in, *args, **kwargs)
        self.ylim = ylim
        self.series = []

    def draw_empty_plot(self, ax):
        rects = ax.bar([], [])
        return rects

    def init_fig(self):
        # Reset every bar to zero height. Rectangle heights must be numeric; an empty
        # list here would break drawing when the animation re-inits (e.g. on resize).
        for bar in self.series:
            bar.set_height(0)
        return self.series

    def post_init(self, data):
        """Create one zero-height bar per channel (last-axis length of an ndarray, else 1)."""
        n_channels = data.shape[-1] if isinstance(data, np.ndarray) else 1
        self.series = self.ax.bar(list(range(n_channels)), np.zeros((n_channels,)))
        if self.ylim is not None:
            self.ax.set_ylim(self.ylim)
        # Run the shared PlotLayer post-init (legend setup), as the other plot layers do.
        super().post_init(data)

    def update_fig(self, data):
        """Set bar ``i`` to ``data[i]`` for sequences/arrays, or every bar to a scalar ``data``."""
        for i, bar in enumerate(self.series.get_children()):
            if isinstance(data, (list, np.ndarray)):
                bar.set_height(data[i])
            else:
                bar.set_height(data)
        return self.series


class TextPlotLayer(PlotLayer):
    """Display the incoming data as centered text on an axis-less plot."""

    def __init__(self, port_in, *args, **kwargs):
        super().__init__(port_in, *args, **kwargs)
        self.h_text = None

    def draw_empty_plot(self, ax):
        ax.set_axis_off()
        self.h_text = ax.text(0.5, 0.5, "", horizontalalignment='center', verticalalignment='center', fontsize=15)
        return self.h_text,

    def init_fig(self):
        return self.update_fig("")

    def update_fig(self, data):
        self.h_text.set_text(data)
        return self.h_text,