import os
import re
import sys
import fnmatch
import inspect
import importlib
import importlib.abc
import importlib.util
import importlib.machinery
import warnings
import subprocess
from types import ModuleType
from typing import Any, List, Dict, Iterator, Callable, Optional
from difflib import SequenceMatcher
from datetime import datetime
from argparse import Namespace as BaseNamespace
from inspect import formatannotation as format_anno

from nest import utils as U
from nest.logger import exception
from nest.settings import settings


class Context(BaseNamespace):
    """Helper class for storing module context.

    An ``argparse.Namespace`` that additionally supports dict-style access
    (``ctx['key']``), iteration over ``(name, value)`` pairs and the
    ``items``/``keys``/``values``/``clear`` helpers.
    """

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def __setitem__(self, key: str, val: Any) -> None:
        return setattr(self, key, val)

    def __iter__(self) -> Iterator:
        # iterate over (name, value) pairs rather than over names only
        return iter(self.__dict__.items())

    def items(self) -> Iterator:
        return self.__dict__.items()

    def keys(self) -> Iterator:
        return self.__dict__.keys()

    def values(self) -> Iterator:
        return self.__dict__.values()

    def clear(self) -> None:
        self.__dict__.clear()


class NestModule(object):
    """Base Nest module class.

    Wraps a fully annotated function together with its meta information and a
    dict of already-resolved params. Calling the module merges the resolved
    params with the call arguments, validates them against the function
    signature, calls the function and validates its returns.
    """

    __slots__ = ('__name__', 'func', 'sig', 'meta', 'params')

    def __init__(self, func: Callable, meta: Dict[str, object], params: Optional[dict] = None) -> None:
        """Create a Nest module.

        Parameters:
            func: The wrapped function (all params and the returns must be annotated)
            meta: Meta information of the module (copied)
            params: Already-resolved params (copied); defaults to none

        Raises:
            TypeError, KeyError: The module definition is invalid (see _check_definition).
        """
        # module func
        self.func = func
        self.__name__ = func.__name__
        # module signature
        self.sig = inspect.signature(func)
        # meta information (copied so the caller's dict is never shared)
        self.meta = U.merge_dict(dict(), meta, union=True)
        # record module params (copied for the same reason)
        self.params = U.merge_dict(dict(), params if params is not None else dict(), union=True)
        # init module context: a param named "ctx" annotated with a Context
        # subclass gets a fresh instance of that class. The isclass() guard keeps
        # non-class annotations (e.g. typing constructs) from crashing issubclass().
        ctx_param = self.sig.parameters.get('ctx')
        if ctx_param is not None and inspect.isclass(ctx_param.annotation) \
                and issubclass(ctx_param.annotation, Context):
            self.params['ctx'] = ctx_param.annotation()
        # check module
        self._check_definition()

    def _check_definition(self) -> None:
        """Raise errors if the module definition is invalid.

        Raises:
            TypeError: A param or the returns is not annotated, or a default
                value does not match its annotation.
            KeyError: The module has no documentation.
        """
        for v in self.sig.parameters.values():
            # type of parameters must be annotated
            if v.annotation is inspect.Parameter.empty:
                raise TypeError('The param "%s" of Nest module "%s" is not explicitly annotated.' % (v, self.__name__))
            # type of defaults must match annotations
            if v.default is not inspect.Parameter.empty and not U.is_annotation_matched(v.default, v.annotation):
                raise TypeError('The param "%s" of Nest module "%s" has an incompatible default value of type "%s".' %
                                (v, self.__name__, format_anno(type(v.default))))
        # type of returns must be annotated
        if self.sig.return_annotation is inspect.Parameter.empty:
            raise TypeError('The returns of Nest module "%s" is not explicitly annotated.' % self.__name__)
        # important meta data must be provided
        if getattr(self, '__doc__', None) is None:
            raise KeyError('Documentation of module "%s" is missing.' % self.__name__)

    def _check_params(self, params: dict) -> None:
        """Raise errors if invalid params are provided to the Nest module.

        Parameters:
            params: The provided params

        Raises:
            TypeError: Unknown params are given, or a param does not match its annotation.
            KeyError: A required param (no default) is missing or None.
        """
        unexpected_params = ', '.join(set(params.keys()) - set(self.sig.parameters.keys()))
        if len(unexpected_params) > 0:
            raise TypeError('Unexpected param(s) "%s" for Nest module: \n%s' % (unexpected_params, self))
        for k, v in self.sig.parameters.items():
            resolved = params.get(k)
            if resolved is None:
                # a missing (or None) param is only acceptable when it has a default
                if v.default is inspect.Parameter.empty:
                    raise KeyError('The required param "%s" of Nest module "%s" is missing.' % (v, self.__name__))
            elif not U.is_annotation_matched(resolved, v.annotation):
                if issubclass(type(resolved), NestModule):
                    # a partially-resolved Nest module was passed where its result was expected
                    detailed_msg = 'The param "%s" of Nest module "%s" should be type of "%s". Got \n%s\n' + \
                        'Please check if some important params of Nest module "%s" have been forgotten in use.'
                    raise TypeError(detailed_msg % (k, self.__name__, format_anno(v.annotation),
                                                    U.indent_text(str(resolved), 4), resolved.__name__))
                else:
                    raise TypeError('The param "%s" of Nest module "%s" should be type of "%s". Got "%s".' %
                                    (k, self.__name__, format_anno(v.annotation), resolved))

    def _check_returns(self, returns: Any) -> None:
        """Raise errors if invalid returns are generated by the Nest module.

        Parameters:
            returns: The generated returns

        Raises:
            TypeError: The returns do not match the return annotation.
        """
        if not U.is_annotation_matched(returns, self.sig.return_annotation):
            raise TypeError('The returns of Nest module "%s" should be type of "%s". Got "%s".' %
                            (self.__name__, format_anno(self.sig.return_annotation), returns))

    def __call__(self, *args, **kwargs):
        """Call the module with the resolved params plus the given arguments.

        Positional arguments fill, in order, the params that are neither
        resolved nor have defaults. If ``delay_resolve=True`` is passed and a
        "Nest module" KeyError occurs (e.g. a missing required param), a clone
        holding the params gathered so far is returned instead of raising.
        """
        # handle positional params
        num_args = len(args)
        if num_args > 0:
            # positional params should not be optional or resolved
            expected_param_names = [k for k, v in self.sig.parameters.items()
                                    if k not in self.params.keys() and v.default is inspect.Parameter.empty]
            num_expected_params = len(expected_param_names)
            if num_args != num_expected_params:
                raise TypeError('Nest module "%s" expects %d positional param(s) "%s". Got "%s".' %
                                (self.__name__, num_expected_params, ', '.join(expected_param_names),
                                 ', '.join([str(v) for v in args])))
            for idx, val in enumerate(args):
                key = expected_param_names[idx]
                if key in kwargs.keys():
                    raise TypeError('Nest module "%s" got multiple values for param "%s".' % (self.__name__, key))
                else:
                    kwargs[key] = val
        # resolve params: call arguments override the already-resolved ones
        resolved_params = dict()
        U.merge_dict(resolved_params, self.params, union=True)
        U.merge_dict(resolved_params, kwargs, union=True)
        if resolved_params.pop('delay_resolve', None):
            try:
                self._check_params(resolved_params)
                returns = self.func(**resolved_params)
            except KeyError as exc_info:
                if 'Nest module' in str(exc_info):
                    # wait for next call
                    return self.clone(resolved_params)
                else:
                    raise
        else:
            # parameters must be fulfilled
            self._check_params(resolved_params)
            returns = self.func(**resolved_params)
        # check returns
        self._check_returns(returns)
        return returns

    def __str__(self) -> str:
        # resolved params are marked with a check mark
        param_string = ', \n'.join(['[✓] ' + str(v) if k in self.params.keys() else ' ' + str(v)
                                    for k, v in self.sig.parameters.items()])
        return_string = ' -> ' + format_anno(self.sig.return_annotation)
        return self.__name__ + '(\n' + param_string + ')' + return_string

    def __repr__(self) -> str:
        return "nest.modules['%s']" % self.__name__

    def clone(self, params: Optional[dict] = None) -> Callable:
        """Clone the Nest module.

        Parameters:
            params: Module parameters of the clone (defaults to none resolved)

        Returns:
            A new module of the same class wrapping the same function and meta.
        """
        return type(self)(self.func, self.meta, params if params is not None else dict())


class ModuleManager(object):
    """Helper class for easy access to Nest modules.

    Keeps the available namespaces (from settings['SEARCH_PATHS'] plus the
    current directory as "main"), periodically (re)imports the python files
    under them, and exposes the registered Nest modules via attribute access,
    query strings, iteration and len().
    """

    def __init__(self) -> None:
        self.namespaces = dict()
        # python module id -> (modified timestamp, imported Nest module ids, spec)
        self.py_modules = dict()
        # Nest module id -> Nest module
        self.nest_modules = dict()
        self.update_timestamp = 0.0
        self.namespace_regex = re.compile(r'^[a-z][a-z0-9\_]*\Z')
        # get available namespaces
        self._update_namespaces()
        # import syntax
        self._add_module_finder()

    @staticmethod
    def _format_namespace(src: str) -> str:
        """Format namespace.

        Parameters:
            src: The original namespace

        Returns:
            Lower-cased namespace with "-" and "." replaced by "_".
        """
        return src.lower().replace('-', '_').replace('.', '_')

    @staticmethod
    def _register(*args, **kwargs) -> Callable:
        """Decorator for Nest modules registration.

        Usable bare (``@register``) or with keyword arguments
        (``@register(author=...)``).

        Parameters:
            ignored: Ignore the module (the function is returned unchanged)
            Any other keyword is module meta information which could be
            utilized by CLI and UI. For example:
                author: Module author(s), e.g., 'Zhou, Yanzhao'
                version: Module version, e.g., '1.2.0'
                backend: Module backend, e.g., 'pytorch'
                tags: Searchable tags, e.g., ['loss', 'cuda_only']
                etc.
        """
        # ignore the Nest module (could be used for debugging)
        if kwargs.pop('ignored', False):
            return lambda x: x
        # use the rest of kwargs to update the meta of the calling python module;
        # later registrations in the same file therefore inherit it.
        # context=0 skips reading source lines for every frame on the stack.
        frame = inspect.stack(context=0)[1]
        current_py_module = inspect.getmodule(frame[0])
        nest_meta = U.merge_dict(getattr(current_py_module, '__nest_meta__', dict()), kwargs, union=True)
        if current_py_module is not None:
            setattr(current_py_module, '__nest_meta__', nest_meta)

        def create_module(func):
            # append meta to doc
            doc = (func.__doc__ + '\n' + (U.yaml_format(nest_meta) if len(nest_meta) > 0 else '')) \
                if isinstance(func.__doc__, str) else None
            # a per-function subclass carries its own __doc__; it is named
            # "NestModule" because discovery matches on the class name
            return type('NestModule', (NestModule,), dict(__slots__=(), __doc__=doc))(func, nest_meta)

        if len(args) == 1 and inspect.isfunction(args[0]):
            return create_module(args[0])
        else:
            return create_module

    @staticmethod
    def _import_nest_modules_from_py_module(
            namespace: str,
            py_module: object,
            nest_modules: Dict[str, object]) -> List[str]:
        """Import registered Nest modules from a given python module.

        Parameters:
            namespace: A namespace that is used to avoid name conflicts
            py_module: The python module
            nest_modules: The dict for storing Nest modules (updated in place)

        Returns:
            The ids of the imported Nest modules
        """
        imported_ids = []
        # search for Nest modules (public names whose class is called "NestModule")
        for key, val in py_module.__dict__.items():
            if key.startswith('_') or type(val).__name__ != 'NestModule':
                continue
            module_id = U.encode_id(namespace, key)
            if module_id in nest_modules.keys():
                U.alert_msg('There are duplicate "%s" modules under namespace "%s".' % (key, namespace))
            else:
                nest_modules[module_id] = val
                imported_ids.append(module_id)
        return imported_ids

    @staticmethod
    def _import_nest_modules_from_file(
            path: str,
            namespace: str,
            py_modules: Dict[str, tuple],
            nest_modules: Dict[str, object],
            meta: Optional[Dict[str, object]] = None) -> None:
        """Import registered Nest modules form a given file.

        The file is (re)executed only if it is new or its modified time is
        newer than the recorded one. Import errors are reported via
        U.alert_msg (with an install tip when a declared requirement matches
        the missing module) instead of being raised.

        Parameters:
            path: The path to the file
            namespace: A namespace that is used to avoid name conflicts
            py_modules: The dict for storing python modules information,
                id -> (modified timestamp, imported Nest module ids, spec)
            nest_modules: The dict for storing Nest modules
            meta: Global meta information
        """
        py_module_name = os.path.basename(path).split('.')[0]
        py_module_id = U.encode_id(namespace, py_module_name)
        timestamp = os.path.getmtime(path)

        # check whether the python module have already been imported
        is_reload = False
        if py_module_id in py_modules.keys():
            if timestamp <= py_modules[py_module_id][0]:
                # skip
                return
            else:
                is_reload = True

        # import the python module
        # note that a python module could contain multiple Nest modules.
        ref_id = 'nest.' + namespace + '.' + py_module_name
        spec = importlib.util.spec_from_file_location(ref_id, path)
        if spec is None:
            return
        py_module = importlib.util.module_from_spec(spec)
        py_module.__nest_meta__ = U.merge_dict(dict(), meta if meta is not None else dict(), union=True)
        # no need to bind global requirements to individual Nest modules.
        requirements = py_module.__nest_meta__.pop('requirements', None)
        if requirements is not None:
            requirements = [dict(url=v, tool='pip') if isinstance(v, str) else v for v in requirements]
        # registered before execution so that inspect.getmodule() inside
        # _register can resolve the module being executed; the previous entry
        # is remembered so a failed (re)load can be rolled back
        previous_py_module = sys.modules.get(ref_id)
        sys.modules[ref_id] = py_module
        try:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                spec.loader.exec_module(py_module)
        except Exception as exc_info:
            # do not leave a half-executed module behind
            if previous_py_module is None:
                sys.modules.pop(ref_id, None)
            else:
                sys.modules[ref_id] = previous_py_module

            # helper function: best fuzzy match of a missing name against the
            # declared requirements, as (score, requirement), or None
            def find_requirement(name):
                candidates = [v for v in (requirements or [])
                              if isinstance(v, dict) and isinstance(v.get('url'), str) and 'tool' in v]
                if len(candidates) > 0:
                    scores = [(SequenceMatcher(None, name, v['url']).ratio(), v) for v in candidates]
                    return max(scores, key=lambda x: x[0])
                return None

            # install tip
            tip = ''
            if isinstance(exc_info, ImportError) and exc_info.name is not None:
                match = find_requirement(exc_info.name)
                if match and match[0] > settings['INSTALL_TIP_THRESHOLD']:
                    tip = 'Try to execute "%s install %s" to install the missing dependency.' % \
                        (match[1]['tool'], match[1]['url'])
            message = str(exc_info)
            message = message if message.endswith('.') else message + '.'
            U.alert_msg('%s The package "%s" under namespace "%s" could not be imported. %s' %
                        (message, py_module_name, namespace, tip))
        else:
            # remove old Nest modules
            if is_reload:
                for key in py_modules[py_module_id][1]:
                    nest_modules.pop(key, None)
            # import all Nest modules within the python module
            imported_ids = ModuleManager._import_nest_modules_from_py_module(namespace, py_module, nest_modules)
            # NOTE(review): files defining no Nest modules are not recorded, so
            # they are re-executed on every update; confirm this is intended.
            if len(imported_ids) > 0:
                # record modified time, id, and spec of imported Nest modules
                py_modules[py_module_id] = (timestamp, imported_ids, py_module.__spec__)

    @staticmethod
    def _import_nest_modules_from_dir(
            path: str,
            namespace: str,
            py_modules: Dict[str, tuple],
            nest_modules: Dict[str, object],
            meta: Optional[Dict[str, object]] = None) -> None:
        """Import registered Nest modules form a given directory (non-recursive).

        Parameters:
            path: The path to the directory
            namespace: A namespace that is used to avoid name conflicts
            py_modules: The dict for storing python modules information (updated in place)
            nest_modules: The dict for storing Nest modules (updated in place)
            meta: Global meta information
        """
        # sorted so that duplicate-name resolution does not depend on listdir order
        for entry in sorted(os.listdir(path)):
            file_path = os.path.join(path, entry)
            if entry.endswith('.py') and os.path.isfile(file_path):
                ModuleManager._import_nest_modules_from_file(file_path, namespace, py_modules, nest_modules, meta)

    @staticmethod
    def _fetch_nest_modules_from_url(url: str, dst: str) -> List[str]:
        """Fetch and unzip Nest modules from url.

        Parameters:
            url: URL of the zip file or git repo (git URLs may be prefixed
                with "-b <branch> ")
            dst: Save dir path

        Returns:
            Names of the fetched top-level directories under dst (empty on failure)

        Raises:
            NotImplementedError: The url is neither a zip file nor a git repo.
        """
        def _hook(count, block_size, total_size):
            # download progress in MB; total_size is not positive when unknown
            size = float(count * block_size) / (1024.0 * 1024.0)
            total_size = float(total_size / (1024.0 * 1024.0))
            if total_size > 0:
                size = min(size, total_size)
                percent = 100.0 * size / total_size
                sys.stdout.write("\rFetching...%d%%, %.2f MB / %.2f MB" % (percent, size, total_size))
            else:
                sys.stdout.write("\rFetching...%.2f MB" % size)
            sys.stdout.flush()

        # extract
        if url.endswith('zip'):
            import random
            import string
            import zipfile
            from urllib import request, error
            cache_name = ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(6)) + '.cache'
            cache_path = os.path.join(dst, cache_name)
            try:
                # download
                request.urlretrieve(url, cache_path, _hook)
                sys.stdout.write('\n')
                # unzip: only entries inside a top-level directory are extracted,
                # so only those directories are reported as namespaces
                with zipfile.ZipFile(cache_path, 'r') as f:
                    members = [v for v in f.namelist() if '/' in v]
                    namespaces = sorted({v.split('/', 1)[0] for v in members} - {''})
                    f.extractall(dst, members)
                return namespaces
            except error.URLError as exc_info:
                U.alert_msg('Could not fetch "%s". %s' % (url, exc_info))
                return []
            except Exception as exc_info:
                U.alert_msg('Error occurs during extraction. %s' % exc_info)
                return []
            finally:
                # remove cache
                if os.path.exists(cache_path):
                    os.remove(cache_path)
        elif url.endswith('.git'):
            # repo name is the last path component without ".git", suffixed by the branch if given
            repo_name = url[url.rfind('/') + 1: -4]
            match = re.search(r'(?:\s|^)(?:-b|--branch) (\w+)', url)
            if match:
                repo_name += '-' + match.group(1)
            try:
                subprocess.check_call(['git', 'clone'] + url.split() + [os.path.join(dst, repo_name)])
                return [repo_name]
            except subprocess.CalledProcessError:
                U.alert_msg('Failed to clone "%s".' % url)
                return []
        else:
            raise NotImplementedError('Only supports zip file and git repo for now. Got "%s".' % url)

    @staticmethod
    def _install_namespaces_from_url(url: str, namespace: Optional[str] = None) -> None:
        """Install namespaces from url.

        Short forms "github@owner/repo[:branch]", "gitlab@..." and
        "bitbucket@..." (branch defaults to master) and "file@<path>" are
        expanded first. If settings['AUTO_INSTALL_REQUIREMENTS'] is set, the
        requirements listed in each fetched namespace config are installed.

        Parameters:
            url: URL of the zip file or git repo
            namespace: Specified namespace
        """
        # pre-process short URL
        short_url_hosts = (('github@', 'github.com'), ('gitlab@', 'gitlab.com'), ('bitbucket@', 'bitbucket.org'))
        for prefix, host in short_url_hosts:
            if url.startswith(prefix):
                m = re.match(r'^' + re.escape(prefix) + r'([\w\-\_]+)/([\w\-\_]+)(:[\w\-\_]+)*$', url)
                if m is None:
                    U.alert_msg('Invalid short URL "%s". Expected "%s<owner>/<repo>[:<branch>]".' % (url, prefix))
                    return
                repo = m.group(1) + '/' + m.group(2)
                branch = m.group(3) or ':master'
                url = '-b %s https://%s/%s.git' % (branch[1:], host, repo)
                break
        else:
            if url.startswith('file@'):
                path = url[5:]
                url = 'file:///' + os.path.abspath(path)

        # helper function
        def install_dep(dep_url, tool, dirname):
            # only plain requirement specs are installed automatically
            # NOTE(review): `tool` comes from the fetched namespace config and is
            # run as "python -m <tool>"; only install namespaces you trust.
            if isinstance(dep_url, str) and re.match(r'^[a-zA-Z0-9<=>.-]+$', dep_url):
                try:
                    subprocess.check_call([sys.executable, '-m', tool, 'install', dep_url])
                except subprocess.CalledProcessError:
                    U.alert_msg('Failed to install "%s" for "%s". Please manually install it.' % (dep_url, dirname))
            else:
                U.alert_msg('Skipped the install requirement "%s" for "%s". Please manually install it.' %
                            (dep_url, dirname))

        for dirname in ModuleManager._fetch_nest_modules_from_url(url, './'):
            module_path = os.path.join('./', dirname)
            ModuleManager._install_namespaces_from_path(module_path, namespace)
            # parse config
            meta_path = os.path.join(module_path, settings['NAMESPACE_CONFIG_FILENAME'])
            meta = U.load_yaml(meta_path)[0] if os.path.exists(meta_path) else dict()
            if settings['AUTO_INSTALL_REQUIREMENTS']:
                # auto install deps
                for dep in meta.get('requirements', []):
                    if isinstance(dep, str):
                        # use pip by default
                        install_dep(dep, 'pip', dirname)
                    elif isinstance(dep, dict) and 'url' in dep and 'tool' in dep:
                        install_dep(dep['url'], dep['tool'], dirname)
                    else:
                        U.alert_msg('Invalid install requirement "%s".' % dep)

    @staticmethod
    def _install_namespaces_from_path(path: str, namespace: Optional[str] = None) -> None:
        """Install namespaces from path.

        Binds the absolute path to the namespace in settings['SEARCH_PATHS']
        and saves the settings, unless the namespace or path is already bound.

        Parameters:
            path: Path to the directory
            namespace: Specified namespace (defaults to the formatted directory name)
        """
        path = os.path.abspath(path)
        namespace = namespace or ModuleManager._format_namespace(os.path.basename(path))
        search_paths = settings['SEARCH_PATHS']
        for k, v in search_paths.items():
            if namespace == k:
                U.alert_msg('Namespace "%s" is already bound to the path "%s".' % (k, v))
                return
            if path == v:
                U.alert_msg('"%s" is already installed under the namespace "%s".' % (v, k))
                return
        search_paths[namespace] = path
        settings['SEARCH_PATHS'] = search_paths
        settings.save()

    @staticmethod
    def _remove_namespaces_from_path(src: str) -> Optional[str]:
        """Remove namespaces from path.

        Parameters:
            src: Namespace or path (treated as a path if it is an existing directory)

        Returns:
            The removed path, or None if nothing matched.
        """
        if os.path.isdir(src):
            path, namespace = os.path.abspath(src), None
        else:
            path, namespace = None, src
        delete_key = None
        search_paths = settings['SEARCH_PATHS']
        for k, v in search_paths.items():
            if namespace == k or path == v:
                delete_key = k
                break
        if delete_key is None:
            if namespace:
                U.alert_msg('The namespace "%s" is not installed.' % namespace)
            if path:
                U.alert_msg('The path "%s" is not installed.' % path)
            return None
        path = search_paths.pop(delete_key)
        settings['SEARCH_PATHS'] = search_paths
        settings.save()
        return path

    @staticmethod
    def _pack_namespaces(srcs: List[str], dst: str) -> Dict[str, List[str]]:
        """Pack namespaces to a zip file.

        Each namespace is stored under a top-level directory named after its
        source directory. Hidden entries, directories starting with "__" and
        files with unsupported extensions are skipped.

        Parameters:
            srcs: Path to the namespaces
            dst: Save path for the resulting zip file

        Returns:
            Archived files, keyed by namespace
        """
        import zipfile

        # helper function
        def check_extension(filename):
            splits = filename.split('.')
            if len(splits) > 1:
                # Python file, YAML config, Plain text, Markdown file, Image, and IPython Notebook
                return splits[-1] in ['py', 'yml', 'txt', 'md', 'jpg', 'png', 'gif', 'ipynb']
            else:
                return True

        save_list = dict()
        # open the archive once so that every namespace ends up in it
        with zipfile.ZipFile(dst, 'w', zipfile.ZIP_DEFLATED) as f:
            for src in srcs:
                namespace = os.path.basename(os.path.normpath(src))
                # scan files
                file_list = []
                for root, dirs, files in os.walk(src):
                    # prune hidden and dunder directories in place so os.walk skips them
                    dirs[:] = [v for v in dirs if not (v[0] == '.' or v.startswith('__'))]
                    file_list += [os.path.join(root, v) for v in files if not v[0] == '.' and check_extension(v)]
                save_list[namespace] = file_list
                # save to the zip file
                for v in file_list:
                    f.write(v, os.path.join(namespace, os.path.relpath(v, src)))
        return save_list

    def _add_module_finder(self) -> None:
        """Add a custom finder to support Nest module import syntax.

        Makes "import nest.<namespace>" produce an empty package whose
        __path__ is the namespace's module path, unless <namespace> is the
        name of a .py file of the nest package itself.
        """
        module_manager = self

        class NamespaceLoader(importlib.abc.Loader):
            def create_module(self, spec):
                _, namespace = spec.name.split('.')
                module = ModuleType(spec.name)
                module_manager._update_namespaces()
                meta = module_manager.namespaces.get(namespace)
                module.__path__ = [meta['module_path']] if meta else []
                return module

            def exec_module(self, module):
                pass

        class NestModuleFinder(importlib.abc.MetaPathFinder):
            def __init__(self):
                super(NestModuleFinder, self).__init__()
                # names of the nest package's own .py files must keep resolving normally
                self.reserved_namespaces = [
                    v[:-3] for v in os.listdir(os.path.dirname(os.path.realpath(__file__))) if v.endswith('.py')]

            def find_spec(self, fullname, path, target=None):
                if fullname.startswith('nest.'):
                    name = fullname.split('.')
                    if len(name) == 2 and name[1] not in self.reserved_namespaces:
                        return importlib.machinery.ModuleSpec(fullname, NamespaceLoader())
                return None

        sys.meta_path.insert(0, NestModuleFinder())

    def _update_namespaces(self) -> None:
        """Get the available namespaces.

        Rebuilds self.namespaces from settings['SEARCH_PATHS'] (each entry's
        config file may override "module_path") and adds the current directory
        as "main" unless it is already a namespace's module path.
        """
        # user defined search paths
        dir_list = set()
        self.namespaces = dict()
        for k, v in settings['SEARCH_PATHS'].items():
            if not os.path.isdir(v):
                continue
            meta_path = os.path.join(v, settings['NAMESPACE_CONFIG_FILENAME'])
            meta = U.load_yaml(meta_path)[0] if os.path.exists(meta_path) else dict()
            meta['module_path'] = os.path.abspath(os.path.join(v, meta.get('module_path', './')))
            if os.path.isdir(meta['module_path']):
                self.namespaces[k] = meta
                dir_list.add(meta['module_path'])
            else:
                U.alert_msg('Namespace "%s" has an invalid module path "%s".' % (k, meta['module_path']))
        # current path
        current_path = os.path.abspath(os.curdir)
        if current_path not in dir_list:
            self.namespaces['main'] = dict(module_path=current_path)

    def _update_modules(self) -> None:
        """Automatically import all available Nest modules.

        Runs at most once per settings['UPDATE_INTERVAL'] seconds.
        """
        timestamp = datetime.now().timestamp()
        if timestamp - self.update_timestamp > settings['UPDATE_INTERVAL']:
            # iterate over a snapshot: importing "nest.<namespace>" may rebuild self.namespaces
            for namespace, meta in list(self.namespaces.items()):
                importlib.import_module('nest.' + namespace)
                ModuleManager._import_nest_modules_from_dir(
                    meta['module_path'], namespace, self.py_modules, self.nest_modules, meta)
            self.update_timestamp = timestamp

    def __iter__(self) -> Iterator:
        """Iterator for Nest modules.

        Returns:
            An iterator over (module id, Nest module) pairs
        """
        self._update_modules()
        return iter(self.nest_modules.items())

    def __len__(self):
        """Number of Nest modules

        Returns:
            The number of Nest modules
        """
        self._update_modules()
        return len(self.nest_modules)

    def _ipython_key_completions_(self) -> List[str]:
        """Support IPython key completion.

        Returns:
            A list of module ids
        """
        self._update_modules()
        return list(self.nest_modules.keys())

    def __dir__(self) -> List[str]:
        """Support IDE auto-completion

        Returns:
            A list of module names
        """
        self._update_modules()
        return [U.decode_id(uid)[1] for uid in self.nest_modules.keys()]

    @exception
    def __getattr__(self, key: str) -> object:
        """Get a Nest module by name.

        Parameters:
            key: Name of the Nest module

        Returns:
            A clone of the Nest module (the first one if several match)

        Raises:
            AttributeError: key starts with "_" (never a Nest module name).
            KeyError: No Nest module has this name.
        """
        # Nest modules are never registered under names starting with "_", so
        # probes such as "__wrapped__", "__setstate__" or "_repr_html_" fail
        # fast with the conventional AttributeError (this also avoids infinite
        # recursion when attributes are looked up before __init__ has run).
        if key.startswith('_'):
            raise AttributeError(key)
        self._update_modules()
        matches = [uid for uid in self.nest_modules.keys() if U.decode_id(uid)[1] == key]
        if len(matches) == 0:
            raise KeyError('Could not find the Nest module "%s".' % key)
        elif len(matches) > 1:
            warnings.warn('Multiple Nest modules with this name have been found. \n'
                          'The returned module is "%s", but you can use nest.modules[regex] to specify others: \n%s' %
                          (matches[0], '\n'.join(['[%d] %s %s' % (k, v, self.nest_modules[v].sig)
                                                  for k, v in enumerate(matches)])))
        return self.nest_modules[matches[0]].clone()

    @exception
    def __getitem__(self, key: str) -> object:
        r"""Get a Nest module by a query string.

        There are three match modes:
        1. Exact match if the query string starts with '$':
            E.g., nest.modules['$nest/optimizer']
        2. Regex match if the query string starts with 'r/':
            E.g., nest.modules['r/.*optim\w+']
        3. Wildcard match if otherwise:
            E.g., nest.modules['optim*er'].
            Note that a wildcard is automatically added to the beginning of the string.

        Parameters:
            key: The query string

        Returns:
            A clone of the Nest module (the first one if several match)

        Raises:
            KeyError: Nothing matches the query.
            NotImplementedError: key is not a string.
        """
        self._update_modules()
        if not isinstance(key, str):
            raise NotImplementedError

        if key.startswith('$'):
            # exact match
            key = key[1:]
            if key in self.nest_modules.keys():
                return self.nest_modules[key].clone()
            else:
                raise KeyError('Could not find Nest module "%s".' % key)
        elif key.startswith('r/'):
            # regex match
            key = key[2:]
            r = re.compile(key)
            matches = list(filter(r.match, self.nest_modules.keys()))
            if len(matches) == 0:
                raise KeyError('Could not find a Nest module matches regex "%s".' % key)
            elif len(matches) > 1:
                warnings.warn('Multiple Nest modules match the given regex have been found. \n'
                              'The returned module is "%s", but you can adjust regex to specify others: \n%s' %
                              (matches[0], '\n'.join(['[%d] %s %s' % (k, v, self.nest_modules[v].sig)
                                                      for k, v in enumerate(matches)])))
            return self.nest_modules[matches[0]].clone()
        else:
            # wildcard match (startswith also copes with an empty query)
            if not key.startswith('*'):
                key = '*' + key
            matches = fnmatch.filter(self.nest_modules.keys(), key)
            if len(matches) == 0:
                raise KeyError('Could not find a Nest module matches query "%s".' % key)
            elif len(matches) > 1:
                warnings.warn('Multiple Nest modules match the given query have been found. \n'
                              'The returned module is "%s", but you can adjust query to specify others: \n%s' %
                              (matches[0], '\n'.join(['[%d] %s %s' % (k, v, self.nest_modules[v].sig)
                                                      for k, v in enumerate(matches)])))
            return self.nest_modules[matches[0]].clone()

    def __repr__(self) -> str:
        return 'nest.modules'

    def __str__(self) -> str:
        num = self.__len__()
        if num == 0:
            return 'No Nest module found.'
        elif num == 1:
            return 'Found 1 Nest module.'
        else:
            return '%d Nest modules are available.' % num


# global manager
module_manager = ModuleManager()