import os
import sys
import base64

import yaml
import tensorflow as tf
from docopt import docopt

from mayo.log import log
from mayo.config import Config
from mayo.session import Test, Evaluate, Train, Search, Profile


# Directory containing this module; `meta.yaml` is looked up next to it.
_root = os.path.dirname(__file__)


def _vigenere(key, string, decode=False):
    """Apply a Vigenere-style character shift wrapped in base64.

    Each character of `string` has its code point shifted by the code point
    (modulo 256) of the key character at the same position, with the key
    repeated cyclically.

    Args:
        key: The key string whose characters provide the shifts.
        string: Plain text when encoding; base64 text when decoding.
        decode: If true, base64-decode `string` first and subtract the
            shifts; otherwise add the shifts and base64-encode the result.

    Returns:
        The decoded plain text if `decode` is true, otherwise the
        base64-encoded shifted text.
    """
    if decode:
        string = base64.b64decode(string.encode('utf-8')).decode('utf-8')
    sign = -1 if decode else 1
    shifted_chars = []
    for index, char in enumerate(string):
        key_c = ord(key[index % len(key)]) % 256
        shifted_chars.append(chr(ord(char) + sign * key_c))
    shifted_str = ''.join(shifted_chars)
    if decode:
        return shifted_str
    return base64.b64encode(shifted_str.encode('utf-8')).decode('utf-8')


def meta():
    """Load `meta.yaml` and derive the values used to format CLI text.

    On top of the file contents, the returned dict gets `__root__`,
    `__executable__`, a decoded `__email__` and a `__credits__` string that
    pairs each author with the email at the same position.

    Returns:
        A dict of metadata values keyed by dunder-style names.
    """
    meta_file = os.path.join(_root, 'meta.yaml')
    # `safe_load` works on every PyYAML version (plain `yaml.load` requires
    # an explicit Loader since PyYAML 6), and the context manager closes the
    # file handle deterministically.
    with open(meta_file, 'r') as f:
        meta_dict = yaml.safe_load(f)
    meta_dict['__root__'] = _root
    meta_dict['__executable__'] = os.path.basename(sys.argv[0])
    # The literal '__email__' is both the dict key and the cipher key.
    email = '__email__'
    encrypted_email = meta_dict[email].replace('\n', '').replace(' ', '')
    meta_dict[email] = _vigenere(email, encrypted_email, decode=True)
    authors_emails = zip(
        meta_dict['__author__'].split(', '), meta_dict[email].split(', '))
    credits = ', '.join('{} ({})'.format(a, e) for a, e in authors_emails)
    meta_dict['__credits__'] = credits
    return meta_dict


class CLI(object):
    """Command-line front-end that updates a config and runs actions.

    Every method named `cli_<name>` is exposed as the command `<name>`
    (underscores become dashes), and its docstring is printed verbatim in
    the usage text, so those docstrings are user-facing strings.
    """

    _DOC = """
{__mayo__} {__version__} ({__date__})
{__description__}
{__credits__}
"""
    _USAGE = """
Usage:
    {__executable__} <anything>...
    {__executable__} (-h | --help)

Arguments:
  <anything> can be one of the following given in sequence:
     * A YAML file with a `.yaml` or `.yml` suffix. If a YAML file is given,
       it will attempt to load the YAML file to update the config.
     * An overrider argument to update the config, formatted as
       "<dot_key_path>=<yaml_value>", e.g., "system.num_gpus=2".
     * An action to execute, one of:
{commands}
"""

    def __init__(self):
        super().__init__()
        self.config = Config()
        # Created lazily by `_get_session`; reset whenever config changes.
        self.session = None

    def doc(self):
        """Return the banner text formatted with the values from `meta()`."""
        return self._DOC.format(**meta())

    def commands(self):
        """Collect the available commands from the `cli_*` methods.

        Returns:
            A dict mapping command names (prefix stripped, underscores
            replaced by dashes) to the corresponding bound methods.
        """
        prefix = 'cli_'
        commands = {}
        for method in dir(self):
            if not method.startswith(prefix):
                continue
            name = method[len(prefix):].replace('_', '-')
            commands[name] = getattr(self, method)
        return commands

    def usage(self):
        """Return the banner followed by the docopt usage text.

        The `{commands}` placeholder is filled with one line per command:
        the name padded to the longest command name, then the stripped
        docstring of its method.
        """
        usage_meta = meta()
        commands = self.commands()
        name_len = max(len(name) for name in commands)
        descriptions = []
        for name, func in commands.items():
            doc = func.__doc__ or ''
            doc = '{}{:{l}} {}'.format(' ' * 9, name, doc.strip(), l=name_len)
            descriptions.append(doc)
        usage_meta['commands'] = '\n'.join(descriptions)
        return self.doc() + self._USAGE.format(**usage_meta)

    def _validate_config(self, keys, action, test=False):
        """Check that every key in `keys` is present in the config.

        Args:
            keys: Config key paths that must be present.
            action: Action name, used only in the error message.
            test: If true, return False on the first missing key instead of
                reporting it through `log.error_exit`.

        Returns:
            True if no key was reported missing; False if `test` is true
            and a key is missing.
        """
        for k in keys:
            if k in self.config:
                continue
            if test:
                return False
            log.error_exit(
                'Please ensure config content {!r} is imported before '
                'executing {!r}.'.format(k, action))
        return True

    # Config keys required by every session type.
    _model_keys = [
        'model.name',
        'model.layers',
        'model.graph',
    ]
    _dataset_keys = [
        'dataset.name',
        'dataset.task',
    ]
    # Additional config keys required per action.
    _validate_keys = [
        'dataset.path.validate',
        'dataset.num_examples_per_epoch.validate',
    ]
    _test_keys = []
    _train_keys = [
        'dataset.path.train',
        'dataset.num_examples_per_epoch.train',
        'train.learning_rate',
        'train.optimizer',
    ]
    _search_keys = [
        'search',
    ]
    # Action name -> session class to instantiate.
    _session_map = {
        'train': Train,
        'search': Search,
        'test': Test,
        'validate': Evaluate,
        'profile': Profile,
    }
    # Action name -> extra config keys validated before starting it.
    _keys_map = {
        'train': _train_keys,
        'search': _train_keys,
        'profile': _train_keys,
        'test': _test_keys,
        'validate': _validate_keys,
    }

    def _get_session(self, action=None):
        """Return a session for `action`, creating one if necessary.

        Args:
            action: A key of `_session_map`.  If falsy, the current session
                is reused when there is one; otherwise a 'train' session is
                chosen when all training keys are in the config, and a
                'validate' session if not.

        Returns:
            The session instance, which is also stored in `self.session`.

        Raises:
            TypeError: If `action` is not a recognized action name.
        """
        if not action:
            if self.session:
                return self.session
            keys = self._train_keys
            if self._validate_config(keys, 'train', test=True):
                self.session = self._get_session('train')
            else:
                self.session = self._get_session('validate')
            return self.session
        # The concatenation builds a fresh list, so the `+=` below never
        # mutates the class-level key lists.
        keys = self._model_keys + self._dataset_keys
        try:
            cls = self._session_map[action]
            keys += self._keys_map[action]
        except KeyError:
            raise TypeError(
                'Action {!r} not recognized.'.format(action)) from None
        self._validate_config(keys, action)
        # Reuse the current session only if it already has the right type.
        if not isinstance(self.session, cls):
            log.info('Starting a {} session...'.format(action))
            self.session = cls(self.config)
        return self.session

    def cli_profile_timeline(self):
        """Performs training profiling to produce timeline.json. """
        # TODO integrate this into Profile.
        from tensorflow.python.client import timeline
        options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        session = self._get_session('train')
        # run 100 iterations to warm up
        max_iterations = 100
        for i in range(max_iterations):
            log.info(
                'Running {}/{} iterations to warm up...'
                .format(i, max_iterations), update=True)
            session.run(session._train_op)
        log.info('Running the final iteration to generate timeline...')
        session.run(
            session._train_op, options=options, run_metadata=run_metadata)
        fetched_timeline = timeline.Timeline(run_metadata.step_stats)
        chrome_trace = fetched_timeline.generate_chrome_trace_format()
        with open('timeline.json', 'w') as f:
            f.write(chrome_trace)

    def cli_plot(self):
        """Plots activation maps as images and parameters as histograms."""
        return self._get_session('validate').plot()

    def cli_train(self):
        """Performs training. """
        return self._get_session('train').train()

    def cli_search(self):
        """Performs automated hyperparameter search. """
        return self._get_session('search').search()

    def cli_profile(self):
        """Performs profiling. """
        return self._get_session('profile').profile()

    def cli_eval(self):
        """Evaluates the accuracy of a saved model. """
        return self._get_session('validate').eval()

    def cli_eval_all(self):
        """Evaluates all checkpoints for accuracy. """
        result = self._get_session('validate').eval_all()
        file_name = 'eval_all.csv'
        with open(file_name, 'w') as f:
            f.write(result.csv())
        log.info(
            'Evaluation results saved in {!r}.'.format(file_name))

    def cli_test(self):
        """Perform inference for custom test data. """
        return self._get_session('test').test()

    def cli_overriders_update(self):
        """Updates variable overriders in the training session. """
        self._get_session('train').overriders_update()

    def cli_overriders_assign(self):
        """Assign overridden values to original parameters. """
        self._get_session('train').overriders_assign()
        self._get_session('train').save_checkpoint('assigned')

    def cli_overriders_reset(self):
        """Reset the internal state of overriders. """
        self._get_session('train').overriders_reset()

    def cli_overriders_dump(self):
        """Export the internal parameters of overriders. """
        self._get_session().overriders_dump()

    def cli_reset_num_epochs(self):
        """Resets the number of training epochs. """
        self._get_session('train').reset_num_epochs()

    def cli_export(self):
        """Exports the current config. """
        name = 'export.yaml'
        with open(name, 'w') as f:
            f.write(self.config.to_yaml())
        log.info('Config successfully exported to {!r}.'.format(name))

    def cli_info(self):
        """Prints parameter and layer info of the model. """
        plumbing = self.config.system.info.get('plumbing')
        info = self._get_session().info(plumbing)
        if plumbing:
            # Plumbing mode dumps the info to a file instead of printing.
            with open('info.yaml', 'w') as f:
                yaml.dump(info, f)
        else:
            for key in ('trainables', 'nontrainables', 'layers'):
                print(info[key].format())
            for table in info.get('overriders', {}).values():
                print(table.format())

    def cli_interact(self):
        """Interacts with the train/eval session using iPython. """
        self._get_session().interact()

    def cli_save(self):
        """Saves the latest checkpoint. """
        # NOTE(review): unlike the other commands this reads `self.session`
        # directly, so it fails if no earlier command created a session.
        # Presumably intentional (nothing to save yet); confirm.
        self.session.checkpoint.save('latest')

    def _purge_session(self):
        """Drop the current session, if any, after a config change."""
        if not self.session:
            return
        log.info('Purging current session because config is updated...')
        self.session = None

    def main(self, args=None):
        """Process the command-line arguments in the order given.

        Each argument is, in order of precedence, a YAML file to merge into
        the config, a `<key>=<value>` override, or a command to execute.
        Processing stops at the first argument that is none of these.

        Args:
            args: A docopt-style dict with an '<anything>' list.  If None,
                the arguments are parsed by docopt from the usage text.
        """
        if args is None:
            args = docopt(self.usage(), version=meta()['__version__'])
        anything = args['<anything>']
        commands = self.commands()
        for each in anything:
            # some problems with `\r` when running through Linux
            # subsystem on Windows
            each = each.strip()
            if each.endswith(('.yaml', '.yml')):
                self.config.yaml_update(each)
                log.key('Using config yaml {!r}...'.format(each))
                self._purge_session()
            elif '=' in each:
                # Split on the first '=' only, so a value that itself
                # contains '=' still yields exactly (key, value).
                self.config.override_update(*each.split('=', 1))
                log.key('Overriding config with {!r}...'.format(each))
                self._purge_session()
            elif each in commands:
                log.key('Executing command {!r}...'.format(each))
                commands[each]()
            else:
                with log.use_pause_level('off'):
                    log.error(
                        'We don\'t know what you mean by {!r}.'.format(each))
                return