"""Support code for YAML parsing of experiment descriptions."""
from collections import namedtuple
import importlib
import logging
import re
import warnings

import yaml
from theano.compat import six

from pylearn2.utils import serial
from pylearn2.utils.exc import reraise_as
from pylearn2.utils.string_utils import preprocess
from pylearn2.utils.call_check import checked_call
from pylearn2.utils.string_utils import match


# Plain scalars matching this pattern are implicitly resolved to `!float`
# (see `initialize`). The mantissa is mandatory: with an optional mantissa a
# plain string such as "e5" would be tagged as a float and `float("e5")`
# would raise ValueError while loading.
SCIENTIFIC_NOTATION_REGEXP = r'^[\-\+]?(\d+\.?\d*|\d*\.?\d+)[eE][\-\+]?\d+$'

is_initialized = False
additional_environ = None
logger = logging.getLogger(__name__)

# Lightweight container for initial YAML evaluation.
#
# This is intended as a robust, forward-compatible intermediate representation
# for either internal consumption or external consumption by another tool e.g.
# hyperopt.
#
# We've included a slot for positionals just in case, though they are
# unsupported by the instantiation mechanism as yet.
BaseProxy = namedtuple('BaseProxy', ['callable', 'positionals',
                                     'keywords', 'yaml_src'])


class Proxy(BaseProxy):
    """
    An intermediate representation between initial YAML parse and object
    instantiation.

    Parameters
    ----------
    callable : callable
        The function/class to call to instantiate this node.
    positionals : iterable
        Placeholder for future support for positional arguments (`*args`).
    keywords : dict-like
        A mapping from keywords to arguments (`**kwargs`), which may be
        `Proxy`s or `Proxy`s nested inside `dict` or `list` instances.
        Keys must be strings that are valid Python variable names.
    yaml_src : str
        The YAML source that created this node, if available.

    Notes
    -----
    This is intended as a robust, forward-compatible intermediate
    representation for either internal consumption or external consumption
    by another tool e.g. hyperopt.

    This particular class mainly exists to override `BaseProxy`'s `__hash__`
    (to avoid hashing unhashable namedtuple elements).
    """
    __slots__ = []

    def __hash__(self):
        """
        Return a hash based on the object ID (to avoid hashing unhashable
        namedtuple elements).
        """
        return hash(id(self))


def do_not_recurse(value):
    """
    Function symbol used for wrapping an unpickled object (which should
    not be recursively expanded). This is recognized and respected by the
    instantiation parser. Implementationally, no-op (returns the value
    passed in as an argument).

    Parameters
    ----------
    value : object
        The value to be returned.

    Returns
    -------
    value : object
        The same object passed in as an argument.
    """
    return value


def _instantiate_proxy_tuple(proxy, bindings=None):
    """
    Helper function for `_instantiate` that handles objects of the `Proxy`
    class.

    Parameters
    ----------
    proxy : Proxy object
        The `Proxy` object to instantiate.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values. A fresh dictionary is used when
        omitted.

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.

    Raises
    ------
    NotImplementedError
        If the proxy carries positional arguments.
    """
    # The default used to be passed straight to `proxy in bindings`, which
    # raises TypeError for None.
    if bindings is None:
        bindings = {}
    if proxy in bindings:
        return bindings[proxy]

    # Respect do_not_recurse by just un-packing it (same as calling).
    if proxy.callable == do_not_recurse:
        obj = proxy.keywords['value']
    else:
        # TODO: add (requested) support for positionals (needs to be added
        # to checked_call also).
        if len(proxy.positionals) > 0:
            raise NotImplementedError('positional arguments not yet '
                                      'supported in proxy instantiation')
        kwargs = dict((k, _instantiate(v, bindings))
                      for k, v in six.iteritems(proxy.keywords))
        obj = checked_call(proxy.callable, kwargs)
    try:
        obj.yaml_src = proxy.yaml_src
    except AttributeError:  # Some classes won't allow this.
        pass
    # Memoize so that a Proxy referenced several times (YAML anchors/aliases)
    # maps to one shared instance.
    bindings[proxy] = obj
    return obj


def _instantiate(proxy, bindings=None):
    """
    Instantiate a (hierarchy of) Proxy object(s).

    Parameters
    ----------
    proxy : object
        A `Proxy` object or list/dict/literal. Strings are run through
        `preprocess`.
    bindings : dict, optional
        A dictionary mapping previously instantiated `Proxy` objects
        to their instantiated values.

    Returns
    -------
    obj : object
        The result object from recursively instantiating the object DAG.

    Notes
    -----
    This should not be considered part of the stable, public API.
    """
    if bindings is None:
        bindings = {}
    if isinstance(proxy, Proxy):
        return _instantiate_proxy_tuple(proxy, bindings)
    elif isinstance(proxy, dict):
        # Recurse on the keys too, for backward compatibility.
        # Is the key instantiation feature ever actually used, by anyone?
        return dict((_instantiate(k, bindings), _instantiate(v, bindings))
                    for k, v in six.iteritems(proxy))
    elif isinstance(proxy, list):
        return [_instantiate(v, bindings) for v in proxy]
    # In the future it might be good to consider a dict argument that provides
    # a type->callable mapping for arbitrary transformations like this.
    elif isinstance(proxy, six.string_types):
        # NOTE(review): `additional_environ` is not forwarded here, unlike in
        # `multi_constructor_pkl` -- confirm whether that is intended.
        return preprocess(proxy)
    else:
        return proxy


def load(stream, environ=None, instantiate=True, **kwargs):
    """
    Loads a YAML configuration from a string or file-like object.

    Parameters
    ----------
    stream : str or object
        Either a string containing valid YAML or a file-like object
        supporting the .read() interface.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`. Unless the caller
    supplies a `Loader`, `yaml.Loader` is used: it is the loader on which
    `initialize` registers the custom tags, and recent PyYAML releases
    require the argument to be given explicitly.

    The custom tags import modules, evaluate names and unpickle files, so
    only load configuration files from trusted sources.
    """
    global additional_environ
    if not is_initialized:
        initialize()
    additional_environ = environ

    if isinstance(stream, six.string_types):
        string = stream
    else:
        string = stream.read()

    kwargs.setdefault('Loader', yaml.Loader)
    proxy_graph = yaml.load(string, **kwargs)
    if instantiate:
        return _instantiate(proxy_graph)
    else:
        return proxy_graph


def load_path(path, environ=None, instantiate=True, **kwargs):
    """
    Convenience function for loading a YAML configuration from a file.

    Parameters
    ----------
    path : str
        The path to the file to load on disk.
    environ : dict, optional
        A dictionary used for ${FOO} substitutions in addition to
        environment variables. If a key appears both in `os.environ`
        and this dictionary, the value in this dictionary is used.
    instantiate : bool, optional
        If `False`, do not actually instantiate the objects but instead
        produce a nested hierarchy of `Proxy` objects.

    Returns
    -------
    graph : dict or object
        The dictionary or object (if the top-level element specified
        a Python object to instantiate), or a nested hierarchy of
        `Proxy` objects.

    Notes
    -----
    Other keyword arguments are passed on to `yaml.load`.
    """
    with open(path, 'r') as f:
        content = f.read()

    # This is apparently here to avoid the odd instance where a file gets
    # loaded as Unicode instead (see 03f238c6d). It's a rare instance where
    # basestring is not the right call.
    if not isinstance(content, str):
        raise AssertionError("Expected content to be of type str, got " +
                             str(type(content)))

    return load(content, instantiate=instantiate, environ=environ, **kwargs)


def try_to_import(tag_suffix):
    """
    Import the module named by a dotted path and return the named attribute.

    Parameters
    ----------
    tag_suffix : str
        A dotted path such as `package.module.name`. Everything before the
        last dot is imported as a module; the last component is looked up
        as an attribute of that module.

    Returns
    -------
    obj : object
        The attribute named by the last component of `tag_suffix`.

    Raises
    ------
    ImportError
        If the module (or one of its parent packages) cannot be imported.
    AttributeError
        If the module does not define the requested attribute. When
        possible the message suggests a similarly named attribute.

    Notes
    -----
    Uses `importlib.import_module` and `getattr` rather than `exec`/`eval`
    on strings built from the YAML source. The `exec`-then-`eval` approach
    depended on `exec` mutating the function's local namespace, which is
    not guaranteed on current Python versions.
    """
    components = tag_suffix.split('.')
    pcomponents = components[:-1]
    field = components[-1]
    modulename = '.'.join(pcomponents)
    try:
        module = importlib.import_module(modulename)
    except ImportError as e:
        # We know it's an ImportError, but is it an ImportError related to
        # this path, or did the module we're importing have an unrelated
        # ImportError? And yes, this test can still have false positives,
        # feel free to improve it.
        str_e = str(e)
        # NOTE(review): this checks whether the error message is contained
        # in a path component, which looks inverted (one would expect the
        # component to be searched for in the message). Behaviour is kept
        # as it was; confirm the intent before changing it.
        found = any(str_e in piece for piece in pcomponents)

        if found:
            # The yaml file is probably to blame.
            # Report the problem with the full module path from the YAML
            # file.
            reraise_as(ImportError("Could not import %s; ImportError was %s" %
                                   (modulename, str_e)))
        else:
            # Import the parent packages one level at a time to report the
            # first level that fails.
            for j in range(1, len(pcomponents) + 1):
                partial_name = '.'.join(pcomponents[:j])
                try:
                    module = importlib.import_module(partial_name)
                except Exception:
                    base_msg = 'Could not import %s' % partial_name
                    if j > 1:
                        importable = '.'.join(pcomponents[:j - 1])
                        base_msg += ' but could import %s' % importable
                    reraise_as(ImportError(base_msg + '. Original exception: '
                                           + str_e))

    try:
        obj = getattr(module, field)
    except AttributeError as e:
        try:
            # Try to figure out what the wrong field name was.
            # If we fail to do it, just fall back to giving the usual
            # attribute error.
            candidates = dir(module)

            msg = ('Could not evaluate %s. ' % tag_suffix +
                   'Did you mean ' + match(field, candidates) + '? ' +
                   'Original error was ' + str(e))

        except Exception:
            warnings.warn("Attempt to decipher AttributeError failed")
            reraise_as(AttributeError('Could not evaluate %s. ' % tag_suffix +
                                      'Original error was ' + str(e)))
        reraise_as(AttributeError(msg))
    return obj


def initialize():
    """
    Initialize the configuration system by installing YAML handlers.
    Automatically done on first call to load() specified in this file.
    """
    global is_initialized

    # Add the custom multi-constructor
    yaml.add_multi_constructor('!obj:', multi_constructor_obj)
    yaml.add_multi_constructor('!pkl:', multi_constructor_pkl)
    yaml.add_multi_constructor('!import:', multi_constructor_import)

    yaml.add_constructor('!import', constructor_import)
    yaml.add_constructor("!float", constructor_float)

    # Resolve plain scalars in scientific notation (e.g. 1e-3) as floats.
    pattern = re.compile(SCIENTIFIC_NOTATION_REGEXP)
    yaml.add_implicit_resolver('!float', pattern)

    is_initialized = True


###############################################################################
# Callbacks used by PyYAML

def multi_constructor_obj(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!obj:" tag is encountered.

    See PyYAML documentation for details on the call signature.

    Returns
    -------
    rval : Proxy
        A `Proxy` whose callable is the object named by `tag_suffix` and
        whose keywords are the mapping found under the tag.

    Raises
    ------
    TypeError
        If a key of the mapping is not a string.
    """
    yaml_src = yaml.serialize(node)
    # Called only for its side effect: it raises on duplicate keys, which
    # `loader.construct_mapping` would silently accept.
    construct_mapping(node)
    mapping = loader.construct_mapping(node)

    assert hasattr(mapping, 'keys')
    assert hasattr(mapping, 'values')

    for key in mapping.keys():
        if not isinstance(key, six.string_types):
            message = "Received non string object (%s) as " \
                      "key in mapping." % str(key)
            raise TypeError(message)
    if '.' not in tag_suffix:
        # TODO: I'm not sure how this was ever working without eval().
        # SECURITY: this evaluates text taken from the YAML file; only load
        # trusted configuration files.
        target = eval(tag_suffix)
    else:
        target = try_to_import(tag_suffix)
    rval = Proxy(callable=target, yaml_src=yaml_src, positionals=(),
                 keywords=mapping)
    return rval


def multi_constructor_pkl(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!pkl:" tag is encountered.

    Returns
    -------
    proxy : Proxy
        A `Proxy` wrapping the unpickled object with `do_not_recurse`, so
        that instantiation returns the object as-is.

    Raises
    ------
    AssertionError
        If text follows "!pkl:" without a separating space.
    """
    if tag_suffix != "":
        raise AssertionError('Expected tag_suffix to be "" but it is "'
                             + tag_suffix +
                             '": Put space between !pkl: and the filename.')

    mapping = loader.construct_yaml_str(node)
    # SECURITY: this loads a pickle named by the YAML file; only use it with
    # trusted files.
    obj = serial.load(preprocess(mapping, additional_environ))
    proxy = Proxy(callable=do_not_recurse, positionals=(),
                  keywords={'value': obj}, yaml_src=yaml.serialize(node))
    return proxy


def multi_constructor_import(loader, tag_suffix, node):
    """
    Callback used by PyYAML when a "!import:" tag is encountered.

    Raises
    ------
    yaml.YAMLError
        If the tag suffix contains no '.'.
    """
    if '.' not in tag_suffix:
        raise yaml.YAMLError("!import: tag suffix contains no '.'")
    return try_to_import(tag_suffix)


def constructor_import(loader, node):
    """
    Callback used by PyYAML when a "!import <str>" tag is encountered.
    This tag expects a (quoted) string as argument.

    Raises
    ------
    yaml.YAMLError
        If the string contains no '.'.
    """
    value = loader.construct_scalar(node)
    if '.' not in value:
        raise yaml.YAMLError("import tag suffix contains no '.'")
    return try_to_import(value)


def constructor_float(loader, node):
    """
    Callback used by PyYAML when a "!float <str>" tag is encountered.
    This tag expects a (quoted) string as argument.
    """
    value = loader.construct_scalar(node)
    return float(value)


def construct_mapping(node, deep=False):
    """
    Build a dict from a YAML mapping node, rejecting repeated keys.

    This is a modified version of yaml.BaseConstructor.construct_mapping
    in which a repeated key raises a ConstructorError.

    Parameters
    ----------
    node : yaml.nodes.MappingNode
        The mapping node to construct.
    deep : bool, optional
        Accepted for signature compatibility; currently unused (keys and
        values are always constructed with `deep=False`).

    Returns
    -------
    mapping : dict
        The constructed mapping.

    Raises
    ------
    yaml.constructor.ConstructorError
        If `node` is not a mapping node, a key is unhashable, or a key
        appears more than once.
    """
    if not isinstance(node, yaml.nodes.MappingNode):
        const = yaml.constructor
        message = "expected a mapping node, but found"
        raise const.ConstructorError(None, None,
                                     "%s %s " % (message, node.id),
                                     node.start_mark)
    mapping = {}
    constructor = yaml.constructor.BaseConstructor()
    for key_node, value_node in node.value:
        key = constructor.construct_object(key_node, deep=False)
        try:
            hash(key)
        except TypeError as exc:
            const = yaml.constructor
            # The key's mark is the fourth ConstructorError argument; it
            # used to be folded into the "%" operands, which raised
            # TypeError instead of the intended error.
            reraise_as(const.ConstructorError("while constructing a mapping",
                                              node.start_mark,
                                              "found unacceptable key (%s)" %
                                              exc,
                                              key_node.start_mark))
        if key in mapping:
            const = yaml.constructor
            raise const.ConstructorError("while constructing a mapping",
                                         node.start_mark,
                                         "found duplicate key (%s)" % key)
        value = constructor.construct_object(value_node, deep=False)
        mapping[key] = value
    return mapping


if __name__ == "__main__":
    initialize()
    # Demonstration of how to specify objects, reference them
    # later in the configuration, etc.
    yamlfile = """{
        "corruptor" : !obj:pylearn2.corruption.GaussianCorruptor &corr {
            "corruption_level" : 0.9
        },
        "dae" : !obj:pylearn2.models.autoencoder.DenoisingAutoencoder {
            "nhid" : 20,
            "nvis" : 30,
            "act_enc" : null,
            "act_dec" : null,
            "tied_weights" : true,
            # we could have also just put the corruptor definition here
            "corruptor" : *corr
        }
    }"""
    # load can take a string or a file object. It must be this module's
    # `load` rather than `yaml.load`: the latter only yields `Proxy`
    # objects, which have no `corruptor` attribute.
    loaded = load(yamlfile)
    logger.info(loaded)
    # These two things should be the same object
    assert loaded['corruptor'] is loaded['dae'].corruptor