import itertools
import os
import signal
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from random import uniform
from tempfile import NamedTemporaryFile, gettempdir
from threading import Lock, Semaphore
from time import sleep

import fire
import sh

from .common_utils import parse_timedelta
from .config import (Embed300FineRandomConfigBuilder, active_config,
                     write_to_file)
from .datasets import get_dataset_instance
from .io_utils import logging, mkdir_p


class HyperparamSearch(object):
    """Spawns and schedules training scripts.

    Runs one training subprocess per GPU at a time. A semaphore throttles
    spawning; a per-process done-callback returns the GPU to the pool.
    Exactly one of `epochs` or `time_limit` must be given.
    """

    def __init__(self, training_label_prefix, dataset_name=None, epochs=None,
                 time_limit=None, num_gpus=None):
        if not ((epochs is None) ^ (time_limit is None)):
            raise ValueError('epochs or time_limit must present, '
                             'but not both!')
        self._training_label_prefix = training_label_prefix
        self._dataset_name = dataset_name or active_config().dataset_name
        self._validate_training_label_prefix()
        self._epochs = epochs
        self._time_limit = time_limit
        fixed_config_keys = dict(dataset_name=self._dataset_name,
                                 epochs=self._epochs,
                                 time_limit=self._time_limit)
        self._config_builder = Embed300FineRandomConfigBuilder(
            fixed_config_keys)
        try:
            # `nvidia-smi -L` prints one line per GPU plus a trailing newline.
            self._num_gpus = len(sh.nvidia_smi('-L').split('\n')) - 1
        except sh.CommandNotFound:
            # No NVIDIA driver available; assume a single device.
            self._num_gpus = 1
        self._num_gpus = num_gpus or self._num_gpus
        # Mutated both by run() (pop) and by done callbacks executed on sh's
        # waiter threads (add), so every access is guarded by self._lock.
        self._available_gpus = set(range(self.num_gpus))
        self._semaphore = Semaphore(self.num_gpus)
        self._running_commands = []  # a list of (index, sh.RunningCommand)
        self._stop_search = False
        self._lock = Lock()

    @property
    def training_label_prefix(self):
        return self._training_label_prefix

    @property
    def num_gpus(self):
        return self._num_gpus

    @property
    def running_commands(self):
        return self._running_commands

    @property
    def lock(self):
        return self._lock

    def run(self):
        """Start the hyperparameter search."""
        for search_index in itertools.count():
            sleep(uniform(0.1, 1))
            self._semaphore.acquire()
            with self.lock:
                if self._stop_search:
                    break
                # Safe to pop under the lock: the acquired semaphore
                # guarantees at least one GPU is back in the pool.
                gpu_index = self._available_gpus.pop()
            training_label = self.training_label(search_index)
            config = self._config_builder.build_config()
            done_callback = self._create_done_callback(gpu_index)
            command = TrainingCommand(training_label=training_label,
                                      config=config,
                                      gpu_index=gpu_index,
                                      background=True,
                                      done_callback=done_callback)
            self.running_commands.append((search_index, command.execute()))
            logging('Running training {}..'.format(training_label))
            self._remove_finished_commands()
        self._wait_running_commands()

    def stop(self):
        """Stop the hyperparameter search."""
        self._stop_search = True

    def training_label(self, search_index):
        """Return the label `prefix/NNNN` for a given search index."""
        return '{}/{:04d}'.format(self.training_label_prefix, search_index)

    def _validate_training_label_prefix(self):
        """Raise ValueError if results for this prefix already exist."""
        dataset = get_dataset_instance(self._dataset_name)
        prefix_dir = os.path.join(dataset.training_results_dir,
                                  self._training_label_prefix)
        if os.path.exists(prefix_dir):
            raise ValueError('Training label prefix {} exists!'.format(
                self._training_label_prefix))

    def _create_done_callback(self, gpu_index):
        """Build a callback that recycles `gpu_index` when a run finishes."""
        def done_callback(cmd, success, exit_code):
            # NEVER write anything to stdout in done_callback
            # OR a deadlock will happen
            with self._lock:
                self._available_gpus.add(gpu_index)
            self._semaphore.release()
        return done_callback

    def _remove_finished_commands(self):
        """Drop finished commands from the running list, logging each."""
        running_commands = []
        for search_index, running_command in self.running_commands:
            if running_command.process.is_alive()[0]:
                running_commands.append((search_index, running_command))
            else:
                training_label = self.training_label(search_index)
                logging('Training {} has finished.'.format(training_label))
        self._running_commands = running_commands

    def _wait_running_commands(self):
        """Block until every still-running training process exits."""
        for search_index, running_command in self.running_commands:
            training_label = self.training_label(search_index)
            logging('Waiting {} to finish..'.format(training_label))
            try:
                running_command.wait()
            except sh.ErrorReturnCode:
                logging('{} returned a non-zero code!'.format(training_label))
            except Exception:
                traceback.print_exc(file=sys.stderr)


class TrainingCommand(object):
    """Executes and manages a training script."""

    COMMAND = sh.python.bake('-m', 'keras_image_captioning.training')

    def __init__(self, training_label, config, gpu_index, background=False,
                 done_callback=None):
        self._training_label = training_label
        self._config = config
        self._gpu_index = gpu_index
        self._background = background
        if done_callback is not None:
            self._done_callback = done_callback
        else:
            # No-op with the same arity sh expects from a _done callback.
            self._done_callback = lambda cmd, success, exit_code: None
        self._init_config_filepath()
        self._init_log_filepath()

    @property
    def training_label(self):
        return self._training_label

    @property
    def config(self):
        return self._config

    @property
    def gpu_index(self):
        return self._gpu_index

    @property
    def config_filepath(self):
        return self._config_filepath

    def execute(self):
        """Execute the training."""
        env = os.environ.copy()
        # Pin the subprocess to a single GPU.
        env['CUDA_VISIBLE_DEVICES'] = str(self.gpu_index)
        return self.COMMAND(training_label=self.training_label,
                            config_file=self.config_filepath,
                            _env=env,
                            _out=self._log_filepath,
                            _err_to_out=True,
                            _bg=self._background,
                            _done=self._done_callback)

    def _init_config_filepath(self):
        """Write the run's config to a uniquely-named temp YAML file."""
        tmp_dir = os.path.join(gettempdir(), 'keras_image_captioning')
        mkdir_p(tmp_dir)
        # delete=False: the file must outlive this process so the spawned
        # training subprocess can read it.
        config_file = NamedTemporaryFile(suffix='.yaml', dir=tmp_dir,
                                         delete=False)
        config_file.close()
        self._config_filepath = config_file.name
        write_to_file(self._config, self._config_filepath)

    def _init_log_filepath(self):
        """Create the result dir and choose the training log file path."""
        LOG_FILENAME = 'training-log.txt'
        dataset = get_dataset_instance(self._config.dataset_name,
                                       self._config.lemmatize_caption)
        result_dir = os.path.join(dataset.training_results_dir,
                                  self._training_label)
        mkdir_p(result_dir)
        self._log_filepath = os.path.join(result_dir, LOG_FILENAME)


def main(training_label_prefix, dataset_name=None, epochs=None,
         time_limit=None, num_gpus=None):
    """CLI entry point: run a search, forwarding SIGINT to children."""
    epochs = int(epochs) if epochs else None
    time_limit = parse_timedelta(time_limit) if time_limit else None
    num_gpus = int(num_gpus) if num_gpus else None
    search = HyperparamSearch(training_label_prefix=training_label_prefix,
                              dataset_name=dataset_name,
                              epochs=epochs,
                              time_limit=time_limit,
                              num_gpus=num_gpus)

    def handler(signum, frame):
        logging('Stopping hyperparam search..')
        with search.lock:
            search.stop()
            for index, running_command in search.running_commands:
                try:
                    label = search.training_label(index)
                    logging('Sending SIGINT to {}..'.format(label))
                    running_command.signal(signal.SIGINT)
                except OSError:
                    # The process might have exited before
                    logging('{} might have terminated before.'.format(label))
                except Exception:
                    traceback.print_exc(file=sys.stderr)
        logging('All training processes have been sent SIGINT.')

    signal.signal(signal.SIGINT, handler)

    # We need to execute search.run() in another thread in order for Semaphore
    # inside it doesn't block the signal handler. Otherwise, the signal handler
    # will be executed after any training process finishes the whole epoch.
    executor = ThreadPoolExecutor(max_workers=1)
    executor.submit(search.run)
    # wait must be True in order for the mock works,
    # see the unit test for more details
    executor.shutdown(wait=True)


if __name__ == '__main__':
    fire.Fire(main)