#!/usr/bin/env python # coding=utf-8 import datetime import os.path import sys import traceback as tb from sacred import metrics_logger from sacred.metrics_logger import linearize_metrics from sacred.randomness import set_global_seed from sacred.utils import SacredInterrupt, join_paths, IntervalTimer from sacred.stdout_capturing import get_stdcapturer class Run: """Represent and manage a single run of an experiment.""" def __init__( self, config, config_modifications, main_function, observers, root_logger, run_logger, experiment_info, host_info, pre_run_hooks, post_run_hooks, captured_out_filter=None, ): self._id = None """The ID of this run as assigned by the first observer""" self.captured_out = "" """Captured stdout and stderr""" self.config = config """The final configuration used for this run""" self.config_modifications = config_modifications """A ConfigSummary object with information about config changes""" self.experiment_info = experiment_info """A dictionary with information about the experiment""" self.host_info = host_info """A dictionary with information about the host""" self.info = {} """Custom info dict that will be sent to the observers""" self.root_logger = root_logger """The root logger that was used to create all the others""" self.run_logger = run_logger """The logger that is used for this run""" self.main_function = main_function """The main function that is executed with this run""" self.observers = observers """A list of all observers that observe this run""" self.pre_run_hooks = pre_run_hooks """List of pre-run hooks (captured functions called before this run)""" self.post_run_hooks = post_run_hooks """List of post-run hooks (captured functions called after this run)""" self.result = None """The return value of the main function""" self.status = None """The current status of the run, from QUEUED to COMPLETED""" self.start_time = None """The datetime when this run was started""" self.stop_time = None """The datetime when this run stopped""" self.debug = False """Determines whether this run is executed in debug mode""" self.pdb = False """If true the pdb debugger is automatically started after a failure""" self.meta_info = {} """A custom comment for this run""" self.beat_interval = 10.0 # sec """The time between two heartbeat events measured in seconds""" self.unobserved = False """Indicates whether this run should be unobserved""" self.force = False """Disable warnings about suspicious changes""" self.queue_only = False """If true then this run will only fire the queued_event and quit""" self.captured_out_filter = captured_out_filter """Filter function to be applied to captured output""" self.fail_trace = None """A stacktrace, in case the run failed""" self.capture_mode = None """Determines the way the stdout/stderr are captured""" self._heartbeat = None self._failed_observers = [] self._output_file = None self._metrics = metrics_logger.MetricsLogger() def open_resource(self, filename, mode="r"): """Open a file and also save it as a resource. Opens a file, reports it to the observers as a resource, and returns the opened file. In Sacred terminology a resource is a file that the experiment needed to access during a run. In case of a MongoObserver that means making sure the file is stored in the database (but avoiding duplicates) along its path and md5 sum. See also :py:meth:`sacred.Experiment.open_resource`. Parameters ---------- filename : str name of the file that should be opened mode : str mode that file will be open Returns ------- file the opened file-object """ filename = os.path.abspath(filename) self._emit_resource_added(filename) # TODO: maybe non-blocking? return open(filename, mode) def add_resource(self, filename): """Add a file as a resource. In Sacred terminology a resource is a file that the experiment needed to access during a run. In case of a MongoObserver that means making sure the file is stored in the database (but avoiding duplicates) along its path and md5 sum. See also :py:meth:`sacred.Experiment.add_resource`. Parameters ---------- filename : str name of the file to be stored as a resource """ filename = os.path.abspath(filename) self._emit_resource_added(filename) def add_artifact(self, filename, name=None, metadata=None, content_type=None): """Add a file as an artifact. In Sacred terminology an artifact is a file produced by the experiment run. In case of a MongoObserver that means storing the file in the database. See also :py:meth:`sacred.Experiment.add_artifact`. Parameters ---------- filename : str name of the file to be stored as artifact name : str, optional optionally set the name of the artifact. Defaults to the filename. metadata: dict optionally attach metadata to the artifact. This only has an effect when using the MongoObserver. content_type: str, optional optionally attach a content-type to the artifact. This only has an effect when using the MongoObserver. """ filename = os.path.abspath(filename) name = os.path.basename(filename) if name is None else name self._emit_artifact_added(name, filename, metadata, content_type) def __call__(self, *args): r"""Start this run. Parameters ---------- \*args parameters passed to the main function Returns ------- the return value of the main function """ if self.start_time is not None: raise RuntimeError( "A run can only be started once. " "(Last start was {})".format(self.start_time) ) if self.unobserved: self.observers = [] else: self.observers = sorted(self.observers, key=lambda x: -x.priority) self.warn_if_unobserved() set_global_seed(self.config["seed"]) if self.capture_mode is None and not self.observers: capture_mode = "no" else: capture_mode = self.capture_mode capture_mode, capture_stdout = get_stdcapturer(capture_mode) self.run_logger.debug('Using capture mode "%s"', capture_mode) if self.queue_only: self._emit_queued() return try: with capture_stdout() as self._output_file: self._emit_started() self._start_heartbeat() self._execute_pre_run_hooks() self.result = self.main_function(*args) self._execute_post_run_hooks() if self.result is not None: self.run_logger.info("Result: {}".format(self.result)) elapsed_time = self._stop_time() self.run_logger.info("Completed after %s", elapsed_time) self._get_captured_output() self._stop_heartbeat() self._emit_completed(self.result) except (SacredInterrupt, KeyboardInterrupt) as e: self._stop_heartbeat() status = getattr(e, "STATUS", "INTERRUPTED") self._emit_interrupted(status) raise except BaseException: exc_type, exc_value, trace = sys.exc_info() self._stop_heartbeat() self._emit_failed(exc_type, exc_value, trace.tb_next) raise finally: self._warn_about_failed_observers() self._wait_for_observers() return self.result def _get_captured_output(self): if self._output_file.closed: return text = self._output_file.get() if isinstance(text, bytes): text = text.decode("utf-8", "replace") if self.captured_out: text = self.captured_out + text if self.captured_out_filter is not None: text = self.captured_out_filter(text) self.captured_out = text def _start_heartbeat(self): self.run_logger.debug("Starting Heartbeat") if self.beat_interval > 0: self._stop_heartbeat_event, self._heartbeat = IntervalTimer.create( self._emit_heartbeat, self.beat_interval ) self._heartbeat.start() def _stop_heartbeat(self): self.run_logger.debug("Stopping Heartbeat") # only stop if heartbeat was started if self._heartbeat is not None: self._stop_heartbeat_event.set() self._heartbeat.join(timeout=2) def _emit_queued(self): self.status = "QUEUED" queue_time = datetime.datetime.utcnow() self.meta_info["queue_time"] = queue_time command = join_paths( self.main_function.prefix, self.main_function.signature.name ) self.run_logger.info("Queuing-up command '%s'", command) for observer in self.observers: _id = observer.queued_event( ex_info=self.experiment_info, command=command, host_info=self.host_info, queue_time=queue_time, config=self.config, meta_info=self.meta_info, _id=self._id, ) if self._id is None: self._id = _id # do not catch any exceptions on startup: # the experiment SHOULD fail if any of the observers fails if self._id is None: self.run_logger.info("Queued") else: self.run_logger.info('Queued-up run with ID "{}"'.format(self._id)) def _emit_started(self): self.status = "RUNNING" self.start_time = datetime.datetime.utcnow() command = join_paths( self.main_function.prefix, self.main_function.signature.name ) self.run_logger.info("Running command '%s'", command) for observer in self.observers: _id = observer.started_event( ex_info=self.experiment_info, command=command, host_info=self.host_info, start_time=self.start_time, config=self.config, meta_info=self.meta_info, _id=self._id, ) if self._id is None: self._id = _id # do not catch any exceptions on startup: # the experiment SHOULD fail if any of the observers fails if self._id is None: self.run_logger.info("Started") else: self.run_logger.info('Started run with ID "{}"'.format(self._id)) def _emit_heartbeat(self): beat_time = datetime.datetime.utcnow() self._get_captured_output() # Read all measured metrics since last heartbeat logged_metrics = self._metrics.get_last_metrics() metrics_by_name = linearize_metrics(logged_metrics) for observer in self.observers: self._safe_call( observer, "log_metrics", metrics_by_name=metrics_by_name, info=self.info ) self._safe_call( observer, "heartbeat_event", info=self.info, captured_out=self.captured_out, beat_time=beat_time, result=self.result, ) def _stop_time(self): self.stop_time = datetime.datetime.utcnow() elapsed_time = datetime.timedelta( seconds=round((self.stop_time - self.start_time).total_seconds()) ) return elapsed_time def _emit_completed(self, result): self.status = "COMPLETED" for observer in self.observers: self._final_call( observer, "completed_event", stop_time=self.stop_time, result=result ) def _emit_interrupted(self, status): self.status = status elapsed_time = self._stop_time() self.run_logger.warning("Aborted after %s!", elapsed_time) for observer in self.observers: self._final_call( observer, "interrupted_event", interrupt_time=self.stop_time, status=status, ) def _emit_failed(self, exc_type, exc_value, trace): self.status = "FAILED" elapsed_time = self._stop_time() self.run_logger.error("Failed after %s!", elapsed_time) self.fail_trace = tb.format_exception(exc_type, exc_value, trace) for observer in self.observers: self._final_call( observer, "failed_event", fail_time=self.stop_time, fail_trace=self.fail_trace, ) def _emit_resource_added(self, filename): for observer in self.observers: self._safe_call(observer, "resource_event", filename=filename) def _emit_artifact_added(self, name, filename, metadata, content_type): for observer in self.observers: self._safe_call( observer, "artifact_event", name=name, filename=filename, metadata=metadata, content_type=content_type, ) def _safe_call(self, obs, method, **kwargs): if obs not in self._failed_observers: try: getattr(obs, method)(**kwargs) except Exception as e: self._failed_observers.append(obs) self.run_logger.warning( "An error ocurred in the '{}' " "observer: {}".format(obs, e) ) def _final_call(self, observer, method, **kwargs): try: getattr(observer, method)(**kwargs) except Exception: # Feels dirty to catch all exceptions, but it is just for # finishing up, so we don't want one observer to kill the # others self.run_logger.error(tb.format_exc()) def _wait_for_observers(self): """Block until all observers finished processing.""" for observer in self.observers: self._safe_call(observer, "join") def _warn_about_failed_observers(self): for observer in self._failed_observers: self.run_logger.warning( "The observer '{}' failed at some point " "during the run.".format(observer) ) def _execute_pre_run_hooks(self): for pr in self.pre_run_hooks: pr() def _execute_post_run_hooks(self): for pr in self.post_run_hooks: pr() def warn_if_unobserved(self): if not self.observers and not self.debug and not self.unobserved: self.run_logger.warning("No observers have been added to this run") def log_scalar(self, metric_name, value, step=None): """ Add a new measurement. The measurement will be processed by the MongoDB observer during a heartbeat event. Other observers are not yet supported. :param metric_name: The name of the metric, e.g. training.loss :param value: The measured value :param step: The step number (integer), e.g. the iteration number If not specified, an internal counter for each metric is used, incremented by one. """ # Method added in change https://github.com/chovanecm/sacred/issues/4 # The same as Experiment.log_scalar (if something changes, # update the docstring too!) self._metrics.log_scalar_metric(metric_name, value, step)