""" The projects module defines the Project class, which stores information about a computation-based project and contains a number of methods for managing and running computational experiments, whether simulations, analyses or whatever. This is the main class that is used directly when using Sumatra within your own scripts. Classes ------- Project - stores information about a computational project, and enables launching, annotating, deleting and retrieving information about simulation/analysis runs. Functions --------- load_project() - read project information from the working directory and return a Project object. :copyright: Copyright 2006-2015 by the Sumatra team, see doc/authors.txt :license: BSD 2-clause, see LICENSE for details. """ from __future__ import print_function from __future__ import unicode_literals from future import standard_library standard_library.install_aliases() from builtins import str from builtins import object import os import re import importlib import pickle from copy import deepcopy import uuid import sumatra import django import sqlite3 import time import shutil import textwrap from datetime import datetime from importlib import import_module from sumatra.records import Record from sumatra import programs, datastore from sumatra.formatting import get_formatter, get_diff_formatter from sumatra.recordstore import DefaultRecordStore from sumatra.versioncontrol import UncommittedModificationsError, get_working_copy, VersionControlError from sumatra.core import TIMESTAMP_FORMAT import mimetypes import json import logging logger = logging.getLogger("Sumatra") DEFAULT_PROJECT_FILE = "project" LABEL_GENERATORS = { 'timestamp': lambda: None, # this is the default, implemented in the Record class 'uuid': lambda: str(uuid.uuid4()).split('-')[-1] } def _remove_left_margin(s): # replace this by textwrap.dedent? lines = s.strip().split('\n') return "\n".join(line.strip() for line in lines) def _get_project_file(path): return os.path.join(path, ".smt", DEFAULT_PROJECT_FILE) class Project(object): valid_name_pattern = r'(?P<project>\w+[\w\- ]*)' def __init__(self, name, default_executable=None, default_repository=None, default_main_file=None, default_launch_mode=None, data_store='default', record_store='default', on_changed='error', description='', data_label=None, input_datastore=None, label_generator='timestamp', timestamp_format=TIMESTAMP_FORMAT, allow_command_line_parameters=True, plugins=[]): self.path = os.getcwd() if not os.path.exists(".smt"): os.mkdir(".smt") if os.path.exists(_get_project_file(self.path)): raise Exception("Sumatra project already exists in this directory.") if re.match(Project.valid_name_pattern, name): self.name = name else: raise ValueError("Invalid project name. Names may only contain letters, numbers, spaces and hyphens") self.default_executable = default_executable self.default_repository = default_repository # maybe we should be storing the working copy instead, as this has a ref to the repository anyway self.default_main_file = default_main_file self.default_launch_mode = default_launch_mode if data_store == 'default': data_store = datastore.FileSystemDataStore(None) self.data_store = data_store # a data store object self.input_datastore = input_datastore or self.data_store if record_store == 'default': record_store = DefaultRecordStore(os.path.abspath(".smt/records")) self.record_store = record_store self.on_changed = on_changed self.description = description self.data_label = data_label self.label_generator = label_generator self.timestamp_format = timestamp_format self.sumatra_version = sumatra.__version__ self.allow_command_line_parameters = allow_command_line_parameters self._most_recent = None self.plugins = [] self.load_plugins(*plugins) self.save() print("Sumatra project successfully set up") def __set_data_label(self, value): assert value in (None, 'parameters', 'cmdline') self._data_label = value def __get_data_label(self): return self._data_label data_label = property(fset=__set_data_label, fget=__get_data_label) def save(self): """Save state to some form of persistent storage. (file, database).""" state = {} for name in ('name', 'default_executable', 'default_repository', 'default_launch_mode', 'data_store', 'record_store', 'default_main_file', 'on_changed', 'description', 'data_label', '_most_recent', 'input_datastore', 'label_generator', 'timestamp_format', 'sumatra_version', 'allow_command_line_parameters', 'plugins'): try: attr = getattr(self, name) except: # Some parameters which need special treatment to avoid # unexpected behaviour. if name == 'allow_command_line_parameters': print(textwrap.dedent("""\ Upgrading from a Sumatra version which did not have the --plain configuration option. After this upgrade, arguments to 'smt run' of the form 'name=value' will continue to overwrite default parameter values, but this is now configurable. If it is desired that they should be passed straight through to the program, run the command 'smt configure --plain' after this upgrade. """)) attr = True else: # Default value for unrecognised parameters attr = None if hasattr(attr, "__getstate__"): state[name] = {'type': attr.__class__.__module__ + "." + attr.__class__.__name__} for key, value in attr.__getstate__().items(): state[name][key] = value else: state[name] = attr f = open(_get_project_file(self.path), 'w') # should check if file exists? json.dump(state, f, indent=2) f.close() def info(self): """Show some basic information about the project.""" template = """ Project name : %(name)s Default executable : %(default_executable)s Default repository : %(default_repository)s Default main file : %(default_main_file)s Default launch mode : %(default_launch_mode)s Data store (output) : %(data_store)s . (input) : %(input_datastore)s Record store : %(record_store)s Code change policy : %(on_changed)s Append label to : %(_data_label)s Label generator : %(label_generator)s Timestamp format : %(timestamp_format)s Plug-ins : %(plugins)s Sumatra version : %(sumatra_version)s """ return _remove_left_margin(template % self.__dict__) def new_record(self, parameters={}, input_data=[], script_args="", executable='default', repository='default', main_file='default', version='current', launch_mode='default', diff='', label=None, reason=None, timestamp_format='default'): logger.debug("Creating new record") if executable == 'default': executable = deepcopy(self.default_executable) if repository == 'default': repository = deepcopy(self.default_repository) if main_file == 'default': main_file = self.default_main_file if launch_mode == 'default': launch_mode = deepcopy(self.default_launch_mode) if timestamp_format == 'default': timestamp_format = self.timestamp_format working_copy = repository.get_working_copy() version, diff = self.update_code(working_copy, version, diff) if label is None: label = LABEL_GENERATORS[self.label_generator]() record = Record(executable, repository, main_file, version, launch_mode, self.data_store, parameters, input_data, script_args, label=label, reason=reason, diff=diff, on_changed=self.on_changed, input_datastore=self.input_datastore, timestamp_format=timestamp_format) self.add_record(record) if not isinstance(executable, programs.MatlabExecutable): record.register(working_copy) return record def launch(self, parameters={}, input_data=[], script_args="", executable='default', repository='default', main_file='default', version='current', launch_mode='default', diff='', label=None, reason=None, timestamp_format='default', repeats=None): """Launch a new simulation or analysis.""" record = self.new_record(parameters, input_data, script_args, executable, repository, main_file, version, launch_mode, diff, label, reason, timestamp_format) record.run(with_label=self.data_label, project=self) if 'matlab' in record.executable.name.lower(): record.register(record.repository.get_working_copy()) if repeats: record.repeats = repeats self.save_record(record) logger.debug("Record saved @ completion.") self.save() return record.label def update_code(self, working_copy, version='current', diff=''): """Check if the working copy has modifications and prompt to commit or revert them.""" # we really need to extend this to the dependencies, but we need to take extra special care that the # code ends up in the same condition as before the run logger.debug("Updating working copy to use version: %s" % version) changed = working_copy.has_changed() if (version == 'current' or version == working_copy.current_version) and not diff: if changed: if self.on_changed == "error": raise UncommittedModificationsError("Code has changed, please commit your changes") elif self.on_changed == "store-diff": diff = working_copy.diff() else: raise ValueError("store-diff must be either 'error' or 'store-diff'") elif diff: if changed: raise UncommittedModificationsError( "Code has changed. These changes will be lost when switching " "to a different version, so please commit or stash your " "changes and then retry.") else: working_copy.use_version(version) working_copy.patch(diff) elif version == 'latest': working_copy.use_latest_version() else: working_copy.use_version(version) version = working_copy.current_version() return version, diff def add_record(self, record): """Add a simulation or analysis record.""" success = False cnt = 0 max_tries = 200 sleep_seconds = 5 while not success and cnt < max_tries: try: self.save_record(record) success = True self._most_recent = record.label logger.debug("Created record: %s" % self.most_recent()) except (django.db.utils.DatabaseError, sqlite3.OperationalError): print("Failed to save record due to database error. Trying again in {0} seconds. (Attempt {1}/{2})".format(sleep_seconds, cnt, max_tries)) time.sleep(sleep_seconds) cnt += 1 if cnt == max_tries: print("Reached maximum number of attempts to save record. Aborting.") def save_record(self, record): self.record_store.save(self.name, record) def get_record(self, label): """Search for a record with the supplied label and return it if found. Otherwise return None.""" return self.record_store.get(self.name, label) def delete_record(self, label, delete_data=False): """Delete a record. Return 1 if the record is found. Otherwise return 0.""" if delete_data: self.get_record(label).delete_data() self.record_store.delete(self.name, label) self._most_recent = self.record_store.most_recent(self.name) def delete_by_tag(self, tag, delete_data=False): """Delete all records with a given tag. Return the number of records deleted.""" if delete_data: for record in self.record_store.list(self.name, tag): record.delete_data() n = self.record_store.delete_by_tag(self.name, tag) self._most_recent = self.record_store.most_recent(self.name) return n def get_labels(self, tags=None, reverse=False, *args, **kwargs): labels = self.record_store.labels(self.name, tags=tags, *args, **kwargs) if reverse: labels.reverse() return labels def find_records(self, tags=None, reverse=False, parameters=None, *args, **kwargs): records = self.record_store.list(self.name, tags=tags, *args, **kwargs) if reverse: records.reverse() if parameters is not None: records = [rec for rec in records if len(rec.parameters.diff(parameters)[-1]) == 0] return records def find_input_data(self, *args, **kwargs): records = self.find_records(*args, **kwargs) if len(records) == 0: return [] input_data = [] for record in records: for input_file in record.input_data: input_data.append(input_file) return input_data def find_output_data(self, *args, **kwargs): records = self.find_records(*args, **kwargs) if (records) == 0: return [] output_data = [] for record in records: for output_file in record.output_data: output_data.append(output_file) return output_data def find_data(self, *args, **kwargs): input_data = self.find_input_data(*args, **kwargs) output_data = self.find_output_data(*args, **kwargs) return {'input_data': input_data, 'output_data': output_data} def format_records(self, format='text', mode='short', tags=None, reverse=False, *args, **kwargs): if format=='text' and mode=='short' and ('parameters' not in kwargs.keys()): return '\n'.join(self.get_labels(tags=tags, reverse=reverse, *args, **kwargs)) else: records = self.find_records(tags=tags, reverse=reverse, *args, **kwargs) formatter = get_formatter(format)(records, project=self, tags=tags) return formatter.format(mode) def most_recent(self): try: return self.get_record(self._most_recent) except KeyError: # the record pointed to by self._most_recent has been deleted self._most_recent = self.record_store.most_recent(self.name) return self.most_recent() def add_comment(self, label, comment, replace=False): try: record = self.record_store.get(self.name, label) except Exception as e: raise Exception("%s. label=<%s>" % (e, label)) if replace or record.outcome is "": record.outcome = comment else: record.outcome = record.outcome + "\n" + comment self.save_record(record) def add_tag(self, label, tag): record = self.record_store.get(self.name, label) record.add_tag(tag) self.save_record(record) def remove_tag(self, label, tag): record = self.record_store.get(self.name, label) record.tags.remove(tag) self.save_record(record) def compare(self, label1, label2, ignore_mimetypes=[], ignore_filenames=[]): record1 = self.record_store.get(self.name, label1) record2 = self.record_store.get(self.name, label2) return record1.difference(record2, ignore_mimetypes, ignore_filenames) def show_diff(self, label1, label2, mode='short', ignore_mimetypes=[], ignore_filenames=[]): diff = self.compare(label1, label2, ignore_mimetypes, ignore_filenames) formatter = get_diff_formatter()(diff) return formatter.format(mode) def export(self): # copy the project data shutil.copy(".smt/project", ".smt/project_export.json") # export the record data f = open(".smt/records_export.json", 'w') f.write(self.record_store.export(self.name)) f.close() def repeat(self, original_label, new_label=None): if original_label == 'last': tmp = self.most_recent() else: tmp = self.get_record(original_label) original = deepcopy(tmp) if hasattr(tmp.parameters, '_url'): # for some reason, _url is not copied. original.parameters._url = tmp.parameters._url # this is a hackish solution - needs fixed properly try: working_copy = get_working_copy() except VersionControlError: original.repository.checkout() working_copy = original.repository.get_working_copy() if working_copy.repository != original.repository: raise NotImplementedError("Ability to switch repositories not yet implemented.") current_version = working_copy.current_version() new_label = self.launch(parameters=original.parameters, input_data=original.input_data, script_args=original.script_arguments, executable=original.executable, main_file=original.main_file, repository=original.repository, version=original.version, launch_mode=original.launch_mode, diff=original.diff, label=new_label, reason="Repeat experiment %s" % original.label, repeats=original.label) working_copy.reset() working_copy.use_version(current_version) # ensure we switch back to the original working copy state return new_label, original.label def backup(self, remove_original=False): """ Create a new backup directory in the same location as the project directory and copy the contents of the project directory into the backup directory. Uses `_get_project_file` to extract the path to the project directory. :return: - `backup_dir`: the directory used for the backup """ if remove_original: self.record_store.remove() # creates backup then removes original file else: self.record_store.backup() smt_dir = os.path.split(_get_project_file(self.path))[0] backup_dir = smt_dir + "_backup_%s" % datetime.now().strftime(TIMESTAMP_FORMAT) shutil.copytree(smt_dir, backup_dir) if remove_original: shutil.rmtree(smt_dir) return backup_dir def change_record_store(self, new_store): """ Change the record store that is used by this project. """ self.backup() old_store = self.record_store new_store.sync(old_store, self.name) self.record_store = new_store def load_plugins(self, *plugins): for plugin in plugins: import_module(plugin) if plugin not in self.plugins: self.plugins.append(plugin) def remove_plugins(self, *plugins): # note that we do not unimport plugins, we just remove them from the list # in future, we should unregister the component from the registry for plugin in plugins: self.plugins.remove(plugin) def _load_project_from_json(path): f = open(_get_project_file(path), 'r') data = json.load(f) f.close() prj = Project.__new__(Project) prj.path = path for key, value in data.items(): if isinstance(value, dict) and "type" in value: parts = str(value["type"]).split(".") # make sure not unicode, see http://stackoverflow.com/questions/1971356/haystack-whoosh-index-generation-error/2683624#2683624 module_name = ".".join(parts[:-1]) class_name = parts[-1] module = importlib.import_module(module_name) cls = getattr(module, class_name) args = {} for k, v in value.items(): if k != 'type': args[str(k)] = v # need to use str() as json module uses all unicode setattr(prj, key, cls(**args)) else: setattr(prj, key, value) if hasattr(prj, "plugins"): prj.load_plugins(*prj.plugins) else: prj.plugins = [] return prj def _load_project_from_pickle(path): # earlier versions of Sumatra saved Projects using pickle f = open(_get_project_file(path), 'r') prj = pickle.load(f) f.close() return prj def load_project(path=None): """ Read project from directory passed as the argument and return Project object. If no argument is given, the project is read from the current directory. """ if not path: p = os.getcwd() else: p = os.path.abspath(path) while not os.path.isdir(os.path.join(p, ".smt")): oldp, p = p, os.path.dirname(p) if p == oldp: raise IOError("No Sumatra project exists in the current directory or above it.") mimetypes.init(mimetypes.knownfiles + [os.path.join(p, ".smt", "mime.types")]) # try: prj = _load_project_from_json(p) # except Exception: # prj = _load_project_from_pickle(p) return prj