import glob
from datetime import datetime, timedelta
import fasteners
import re
import os
from os import path

from saver import Saver


# Format of the timestamp that prefixes every log dir name, e.g. "0117_1704".
_LOG_DATE_FORMAT = "%m%d_%H%M"
# Marks the log dir name component that records which job was restored.
_RESTORE_PREFIX = 'RESTORE@'


def iter_ckpt_dirs(log_dir_root, job_ids_str):
    """
    Yield the checkpoint directory for each job id given in `job_ids_str`.

    Job ids for which zero or more than one directory matches are reported on stdout and skipped.

    :param log_dir_root: existing directory containing the log dirs
    :param job_ids_str: job ids (log dir name prefixes), separated by ',' or ';'
    :return: generator over the matching checkpoint directories
    """
    assert os.path.exists(log_dir_root), 'Invalid log dir: {}'.format(log_dir_root)
    job_ids = job_ids_str.strip().replace(';', ',').split(',')
    assert len(job_ids) > 0, 'No job_ids!'

    for job_id in job_ids:
        # ckpt_dir_for_log_dir appends 'ckpts', which ensures that we only get training log dirs as matches,
        # and not other or previous validation dir.
        ckpt_dir_glob = Saver.ckpt_dir_for_log_dir(path.join(log_dir_root, job_id + '*'))
        ckpt_dir_matches = glob.glob(ckpt_dir_glob)
        if len(ckpt_dir_matches) == 0:
            print('*** ERR: No matches for {}'.format(ckpt_dir_glob))
            continue
        if len(ckpt_dir_matches) > 1:
            print('*** ERR: Multiple matches for {}: {}'.format(ckpt_dir_glob, '\n'.join(ckpt_dir_matches)))
            continue
        yield ckpt_dir_matches[0]


def create_unique_log_dir(config_rel_paths, log_dir_root, line_breaking_chars_pat=r'[-]', restore_dir=None):
    """
    Create a new, uniquely dated log dir inside `log_dir_root` and return its path.

    The directory name is the current time followed by the space-separated, encoded config paths, and,
    if `restore_dir` is given, a component naming the restored job. Example:

        0117_1704 repr@soa3_med_8e*5_deePer_b50_noHM_C16 repr@v2_res_shallow RESTORE@path@to@restore@0115_1340

    :param config_rel_paths: iterable of relative config paths; none may contain ':'
    :param log_dir_root: directory in which the log dir is created (created if missing)
    :param line_breaking_chars_pat: regex; every match in a config path is replaced by '*'
    :param restore_dir: optional path of (or inside) the log dir of the job that is restored
    :return: path of the newly created log dir
    :raises ValueError: if any config path contains ':'
    """
    if any(':' in config_rel_path for config_rel_path in config_rel_paths):
        raise ValueError('":" not allowed in paths, got {}'.format(config_rel_paths))

    def prep_path(p):
        # Path separators cannot be part of a directory name, so encode them as '@'.
        p = p.replace(path.sep, '@')
        return re.sub(line_breaking_chars_pat, '*', p)

    postfix_dir_name = ' '.join(map(prep_path, config_rel_paths))

    if restore_dir:
        restore_dir_root, restore_job_component = _split_log_dir(restore_dir)
        restore_dir_root = restore_dir_root.replace(path.sep, '@')
        restore_job_id = log_date_from_log_dir(restore_job_component)
        postfix_dir_name += ' {restore_prefix}{root}@{job_id}'.format(
            restore_prefix=_RESTORE_PREFIX, root=restore_dir_root, job_id=restore_job_id)

    return _mkdir_threadsafe_unique(log_dir_root, datetime.now(), postfix_dir_name)


def _split_log_dir(log_dir):
    """
    given
        some/path/to/job/dir/0101_1818 ae_config pc_config/ckpts
    or
        some/path/to/job/dir/0101_1818 ae_config pc_config
    returns
        tuple some/path/to/job/dir, 0101_1818 ae_config pc_config

    Fails with an AssertionError if no path component starts with a log date.
    """
    log_dir_root = []
    job_component = None

    for comp in log_dir.split(path.sep):
        try:
            log_date_from_log_dir(comp)
        except ValueError:
            # Not a log dir yet, so this component still belongs to the root.
            log_dir_root.append(comp)
        else:
            # this component is an actual log dir. stop and return components
            job_component = comp
            break

    assert job_component is not None, 'Invalid log_dir: {}'.format(log_dir)
    return path.sep.join(log_dir_root), job_component


def _mkdir_threadsafe_unique(log_dir_root, log_date, postfix_dir_name):
    """
    Like `_mkdir_unique`, but guarded by an inter-process lock file `lock` inside `log_dir_root`.

    :return: path of the newly created log dir
    """
    os.makedirs(log_dir_root, exist_ok=True)
    # Make sure only one process at a time writes into log_dir_root
    with fasteners.InterProcessLock(os.path.join(log_dir_root, 'lock')):
        return _mkdir_unique(log_dir_root, log_date, postfix_dir_name)


def _mkdir_unique(log_dir_root, log_date, postfix_dir_name):
    """
    Create `log_dir_root/{log_date} {postfix_dir_name}`, advancing `log_date` minute by minute until no
    existing entry of `log_dir_root` starts with the same log date.

    :return: path of the newly created log dir
    """
    # Iterative rather than recursive, so a long run of taken minutes cannot exhaust the stack.
    while _log_dir_with_log_date_exists(log_dir_root, log_date):
        print('Log dir starting with {} exists...'.format(log_date.strftime(_LOG_DATE_FORMAT)))
        log_date = log_date + timedelta(minutes=1)

    log_dir = path.join(log_dir_root, '{log_date_str} {postfix_dir_name}'.format(
        log_date_str=log_date.strftime(_LOG_DATE_FORMAT),
        postfix_dir_name=postfix_dir_name))
    os.makedirs(log_dir)
    return log_dir


def _log_dir_with_log_date_exists(log_dir_root, log_date):
    """Return True if some entry of `log_dir_root` is named with the log date of `log_date`."""
    log_date_str = log_date.strftime(_LOG_DATE_FORMAT)
    for entry in os.listdir(log_dir_root):
        try:
            entry_log_date = log_date_from_log_dir(entry)
        except ValueError:
            # Entry is not a log dir (e.g., the lock file).
            continue
        if entry_log_date == log_date_str:
            return True
    return False


def log_date_from_log_dir(log_dir):
    """
    Extract {log_date} from LOG_DIR/{log_date} {netconfig} {probconfig}.

    :param log_dir: path or name of a log dir
    :return: the log date string, e.g. '0117_1704'
    :raises ValueError: if the base name of `log_dir` does not start with a valid log date
    """
    possible_log_date = os.path.basename(log_dir).split(' ')[0]
    if not is_log_date(possible_log_date):
        raise ValueError('Invalid log dir: {}'.format(log_dir))
    return possible_log_date


def is_log_date(possible_log_date):
    """
    Return True if `possible_log_date` is a string formatted according to `_LOG_DATE_FORMAT`.

    :param possible_log_date: candidate string, e.g. '0117_1704'
    :return: True if it parses as a log date, False otherwise
    """
    # The format has no year, so strptime would assume 1900, which is not a leap year, and reject
    # '0229_...'. Log dirs created on Feb 29 must be recognized, so parse with an explicit leap year.
    try:
        datetime.strptime('2000:' + possible_log_date, '%Y:' + _LOG_DATE_FORMAT)
    except ValueError:
        return False
    return True


def config_paths_from_log_dir(log_dir, base_dirs):
    """
    Recover the config paths that were encoded into the name of `log_dir`.

    The i-th config component of the name is resolved relative to the i-th entry of `base_dirs`.

    :param log_dir: path of a log dir named `{now} {netconfig} {probconfig} [RESTORE@some_dir@XXXX_YYYY]`
    :param base_dirs: base directories, one per config component, in order
    :return: tuple of config paths found on disk
    :raises ValueError: if a component does not resolve to exactly one file of matching length
    """
    log_dir = path.basename(log_dir.strip(path.sep))

    # log_dir == {now} {netconfig} {probconfig} [RESTORE@some_dir@XXXX_YYYY], get [netconfig, probconfig]
    comps = log_dir.split(' ')
    assert is_log_date(comps[0]), 'Invalid log_dir: {}'.format(log_dir)
    comps = [c for c in comps[1:] if _RESTORE_PREFIX not in c]
    assert len(comps) <= len(base_dirs), 'Expected as many config components as base dirs: {}, {}'.format(
        comps, base_dirs)

    def get_real_path(base, prepped_p):
        p_glob = prepped_p.replace('@', path.sep)
        p_glob = path.join(base, p_glob)  # e.g., ae_configs/p_glob
        glob_matches = glob.glob(p_glob)
        # We always only replace one character with *, so filter for those.
        # I.e. lr1e-5 will become lr1e*5, which will match lr1e-5 but also lr1e-4.5
        glob_matches_of_same_len = [g for g in glob_matches if len(g) == len(p_glob)]
        if len(glob_matches_of_same_len) != 1:
            raise ValueError('Cannot find config on disk: {} (matches: {})'.format(p_glob, glob_matches_of_same_len))
        return glob_matches_of_same_len[0]

    return tuple(get_real_path(base_dir, comp) for base_dir, comp in zip(base_dirs, comps))