from abc import ABCMeta, abstractmethod
from collections import namedtuple
from functools import total_ordering
import logging
import signal
import time

from six import add_metaclass
from six.moves.queue import PriorityQueue
from subprocess import Popen, PIPE
from threading import Thread

try:
    from bokeh.document import Document
    from bokeh.plotting import figure
    from bokeh.session import Session
    BOKEH_AVAILABLE = True
except ImportError:
    BOKEH_AVAILABLE = False

from blocks.config import config
from blocks.extensions import SimpleExtension

logger = logging.getLogger(__name__)


@add_metaclass(ABCMeta)
class PlottingExtension(SimpleExtension):
    """Base class for extensions doing Bokeh plotting.

    Parameters
    ----------
    document_name : str
        The name of the Bokeh document. Use a different name for each
        experiment if you are storing your plots.
    start_server : bool, optional
        Whether to try and start the Bokeh plotting server. Defaults to
        ``False``. The server started is not persistent i.e. after shutting
        it down you will lose your plots. If you want to store your plots,
        start the server manually using the ``bokeh-server`` command. Also
        see the warning above.
    server_url : str, optional
        Url of the bokeh-server. Ex: when starting the bokeh-server with
        ``bokeh-server --ip 0.0.0.0`` at ``alice``, server_url should be
        ``http://alice:5006``. When not specified the default configured
        by ``bokeh_server`` in ``.blocksrc`` will be used. Defaults to
        ``http://localhost:5006/``.
    clear_document : bool, optional
        Whether or not to clear the contents of the server-side document
        upon creation. If `False`, previously existing plots within the
        document will be kept. Defaults to `True`.

    """
    def __init__(self, document_name, server_url=None, start_server=False,
                 clear_document=True, **kwargs):
        self.document_name = document_name
        self.server_url = (config.bokeh_server if server_url is None
                           else server_url)
        self.start_server = start_server
        # ``_start_server_process`` returns the process handle, so ``sub``
        # really refers to the spawned server (or ``None`` if not started).
        self.sub = self._start_server_process() if self.start_server else None
        self.session = Session(root_url=self.server_url)
        self.document = Document()
        self._setup_document(clear_document)
        super(PlottingExtension, self).__init__(**kwargs)

    def _start_server_process(self):
        """Spawn a ``bokeh-server`` subprocess and remember it in ``sub``.

        Returns
        -------
        sub : :class:`subprocess.Popen`
            Handle of the started server process (also stored as
            ``self.sub``).

        """
        def preexec_fn():
            """Prevents the server from dying on training interrupt."""
            signal.signal(signal.SIGINT, signal.SIG_IGN)

        # Only memory works with subprocess, need to wait for it to start
        logger.info('Starting plotting server on localhost:5006')
        # NOTE(review): stdout/stderr are piped but never read in this
        # file; a very chatty server could fill the pipes -- confirm.
        self.sub = Popen('bokeh-server --ip 0.0.0.0 '
                         '--backend memory'.split(),
                         stdout=PIPE, stderr=PIPE, preexec_fn=preexec_fn)
        time.sleep(2)
        logger.info('Plotting server PID: {}'.format(self.sub.pid))
        return self.sub

    def _setup_document(self, clear_document=False):
        """Bind the session to the named document and load it.

        Parameters
        ----------
        clear_document : bool, optional
            If ``True`` the loaded document is cleared. Defaults to
            ``False``.

        """
        self.session.use_doc(self.document_name)
        self.session.load_document(self.document)
        if clear_document:
            self.document.clear()
        self._document_setup_done = True

    def __getstate__(self):
        """Return the instance state without process/session/thread."""
        state = self.__dict__.copy()
        state['sub'] = None
        state.pop('session', None)
        state.pop('_push_thread', None)
        return state

    def __setstate__(self, state):
        """Restore state, restarting the server and session if needed."""
        self.__dict__.update(state)
        if self.start_server:
            self._start_server_process()
        self.session = Session(root_url=self.server_url)
        # The document is set up again lazily on the next ``do`` call.
        self._document_setup_done = False

    def do(self, which_callback, *args):
        """Make sure the document is set up before any plotting is done."""
        if not self._document_setup_done:
            self._setup_document()

    @property
    def push_thread(self):
        """The background :class:`PushThread`, created on first access."""
        if not hasattr(self, '_push_thread'):
            self._push_thread = PushThread(self.session, self.document)
            self._push_thread.start()
        return self._push_thread

    def store(self, obj):
        """Queue `obj` to be stored through the session."""
        self.push_thread.put(obj, PushThread.PUT)

    def push(self, which_callback):
        """Queue a push of the whole document for `which_callback`."""
        self.push_thread.put(which_callback, PushThread.PUSH)


class Plot(PlottingExtension):
    r"""Live plotting of monitoring channels.

    In most cases it is preferable to start the Bokeh plotting server
    manually, so that your plots are stored permanently.

    Alternatively, you can set the ``start_server`` argument of this
    extension to ``True``, to automatically start a server when training
    starts. However, in that case your plots will be deleted when you shut
    down the plotting server!

    .. warning::

       When starting the server automatically using the ``start_server``
       argument, the extension won't attempt to shut down the server at the
       end of training (to make sure that you do not lose your plots the
       moment training completes). You have to shut it down manually (the
       PID will be shown in the logs). If you don't do this, this extension
       will crash when you try and train another model with
       ``start_server`` set to ``True``, because it can't run two servers
       at the same time.

    Parameters
    ----------
    document_name : str
        See :class:`PlottingExtension` for details.
    channels : list of channel specifications
        A channel specification is either a list of channel names, or a
        dict with at least the entry ``channels`` mapping to a list of
        channel names. The channels in a channel specification will be
        plotted together in a single figure, so use e.g. ``[['test_cost',
        'train_cost'], ['weight_norms']]`` to plot a single figure with the
        training and test cost, and a second figure for the weight norms.

        When the channel specification is a list, a bokeh figure will
        be created with default arguments. When the channel specification
        is a dict, the field channels is used to specify the contents of
        the figure, and all remaining keys are passed as ``\*\*kwargs`` to
        the ``figure`` function.

    """
    # Tableau 10 colors
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
              '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    def __init__(self, document_name, channels, **kwargs):
        if not BOKEH_AVAILABLE:
            raise ImportError('Bokeh required for {} extension'
                              .format(self.__class__.__name__))
        self.plots = {}
        # Create figures for each group of channels
        self.p = []
        self.p_indices = {}
        self.color_indices = {}
        for i, channel_set in enumerate(channels):
            channel_set_opts = {}
            if isinstance(channel_set, dict):
                # Work on a copy so that the caller's specification is not
                # modified (popping 'channels' from the original would
                # make it unusable for a second extension).
                channel_set_opts = dict(channel_set)
                channel_set = channel_set_opts.pop('channels')
            channel_set_opts.setdefault('title',
                                        '{} #{}'.format(document_name, i + 1))
            channel_set_opts.setdefault('x_axis_label', 'iterations')
            channel_set_opts.setdefault('y_axis_label', 'value')
            self.p.append(figure(**channel_set_opts))
            for j, channel in enumerate(channel_set):
                self.p_indices[channel] = i
                self.color_indices[channel] = j

        kwargs.setdefault('after_epoch', True)
        kwargs.setdefault('before_first_epoch', True)
        kwargs.setdefault('after_training', True)
        super(Plot, self).__init__(document_name, **kwargs)

    def do(self, which_callback, *args):
        """Add the current log row's tracked channels to their figures."""
        super(Plot, self).do(which_callback, *args)
        log = self.main_loop.log
        iteration = log.status['iterations_done']
        for key, value in log.current_row.items():
            if key in self.p_indices:
                if key not in self.plots:
                    # First value for this channel: create the line and
                    # remember its data source for later updates.
                    line_color = self.colors[
                        self.color_indices[key] % len(self.colors)]
                    fig = self.p[self.p_indices[key]]
                    fig.line([iteration], [value], legend=key,
                             name=key, line_color=line_color)
                    self.document.add(fig)
                    renderer = fig.select(dict(name=key))
                    self.plots[key] = renderer[0].data_source
                else:
                    self.plots[key].data['x'].append(iteration)
                    self.plots[key].data['y'].append(value)
                    self.store(self.plots[key])
        self.push(which_callback)


@total_ordering
class _WorkItem(namedtuple('BaseWorkItem', ['priority', 'obj'])):
    """Queue entry whose ``<`` comparison looks at the priority only."""
    __slots__ = ()

    def __lt__(self, other):
        return self.priority < other.priority


class PushThread(Thread):
    """Daemon thread sending queued work items through a Bokeh session.

    Parameters
    ----------
    session : object
        Session whose ``store_objects`` and ``store_document`` methods
        are called.
    document : object
        Document handed to ``session.store_document``.

    """
    # Priority constants; the priority queue hands out the lowest first.
    PUSH, PUT = range(2)

    def __init__(self, session, document):
        self.session = session
        self.document = document
        super(PushThread, self).__init__()
        self.queue = PriorityQueue()
        self.daemon = True

    def put(self, obj, priority):
        """Enqueue `obj` with the given priority (``PUSH`` or ``PUT``)."""
        self.queue.put(_WorkItem(priority, obj))

    def run(self):
        """Process work items until an ``'after_training'`` push is seen."""
        while True:
            priority, obj = self.queue.get()
            if priority == PushThread.PUT:
                self.session.store_objects(obj)
            elif priority == PushThread.PUSH:
                self.session.store_document(self.document)
                # delete queued objects when training has finished
                if obj == 'after_training':
                    with self.queue.mutex:
                        del self.queue.queue[:]
                    break
            self.queue.task_done()