""" Copyright 2019, ETH Zurich This file is part of L3C-PyTorch. L3C-PyTorch is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or any later version. L3C-PyTorch is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with L3C-PyTorch. If not, see <https://www.gnu.org/licenses/>. """ import glob from collections import namedtuple from datetime import datetime, timedelta import fasteners import re import os from os import path _LOG_DATE_FORMAT = "%m%d_%H%M" _RESTORE_PREFIX = 'r@' def create_unique_log_dir(config_rel_paths, log_dir_root, line_breaking_chars_pat=r'[-]', postfix=None, restore_dir=None, strip_ext=None): """ 0117_1704 repr@soa3_med_8e*5_deePer_b50_noHM_C16 repr@v2_res_shallow r@0115_1340 :param config_rel_paths: paths to the configs, relative to the config root dir :param log_dir_root: In this directory, all log dirs are stored. Created if needed. :param line_breaking_chars_pat: :param postfix: appended to the returned log dir :param restore_dir: if given, expected to be a log dir. the JOB_ID of that will be appended :param strip_ext: if given, do not store extension `strip_ext` of config_rel_paths :return: path to a newly created directory """ if any('@' in config_rel_path for config_rel_path in config_rel_paths): raise ValueError('"@" not allowed in paths, got {}'.format(config_rel_paths)) if strip_ext: assert all(strip_ext in c for c in config_rel_paths) config_rel_paths = [c.replace(strip_ext, '') for c in config_rel_paths] def prep_path(p): p = p.replace(path.sep, '@') return re.sub(line_breaking_chars_pat, '*', p) postfix_dir_name = ' '.join(map(prep_path, config_rel_paths)) if restore_dir: _, restore_job_component = _split_log_dir(restore_dir) restore_job_id = log_date_from_log_dir(restore_job_component) postfix_dir_name += ' {restore_prefix}{job_id}'.format( restore_prefix=_RESTORE_PREFIX, job_id=restore_job_id) if postfix: if isinstance(postfix, list): postfix = ' '.join(postfix) postfix_dir_name += ' ' + postfix return _mkdir_threadsafe_unique(log_dir_root, datetime.now(), postfix_dir_name) LogDirComps = namedtuple('LogDirComps', ['config_paths', 'postfix']) def parse_log_dir(log_dir, configs_dir, base_dirs, append_ext=''): """ Given a log_dir produced by `create_unique_log_dir`, return the full paths of all configs used. The log dir has thus the following format {now} {netconfig} {probconfig} [r@XXXX_YYYY] [{postfix} {postfix}] :param log_dir: the log dir to parse :param configs_dir: the root config dir, where all the configs live :param base_dirs: Prefixed to the paths of the configs, e.g., ['ae', 'pc'] :return: all config paths, as well as the postfix if one was given """ base_dirs = [path.join(configs_dir, base_dir) for base_dir in base_dirs] log_dir = path.basename(log_dir.strip(path.sep)) comps = log_dir.split(' ') assert is_log_date(comps[0]), 'Invalid log_dir: {}'.format(log_dir) assert len(comps) > len(base_dirs), 'Expected a base dir for every component, got {} and {}'.format( comps, base_dirs) config_components = comps[1:(1+len(base_dirs))] has_restore = any(_RESTORE_PREFIX in c for c in comps) postfix = comps[1+len(base_dirs)+has_restore:] def get_real_path(base, prepped_p): p_glob = prepped_p.replace('@', path.sep) p_glob = path.join(base, p_glob) + append_ext # e.g., ae_configs/p_glob.cf glob_matches = glob.glob(p_glob) # We always only replace one character with *, so filter for those. # I.e. lr1e-5 will become lr1e*5, which will match lr1e-5 but also lr1e-4.5 glob_matches_of_same_len = [g for g in glob_matches if len(g) == len(p_glob)] if len(glob_matches_of_same_len) != 1: raise ValueError('Cannot find config on disk: {} (matches: {})'.format(p_glob, glob_matches_of_same_len)) return glob_matches_of_same_len[0] return LogDirComps( config_paths=tuple(get_real_path(base_dir, comp) for base_dir, comp in zip(base_dirs, config_components)), postfix=tuple(postfix) if postfix else None) # ------------------------------------------------------------------------------ def _split_log_dir(log_dir): """ given some/path/to/job/dir/0101_1818 ae_config pc_config/ckpts or some/path/to/job/dir/0101_1818 ae_config pc_config returns tuple some/path/to/job/dir, 0101_1818 ae_config pc_config """ log_dir_root = [] job_component = None for comp in log_dir.split(path.sep): try: log_date_from_log_dir(comp) job_component = comp break # this component is an actual log dir. stop and return components except ValueError: log_dir_root.append(comp) assert job_component is not None, 'Invalid log_dir: {}'.format(log_dir) return path.sep.join(log_dir_root), job_component def _mkdir_threadsafe_unique(log_dir_root, log_date, postfix_dir_name): os.makedirs(log_dir_root, exist_ok=True) # Make sure only one process at a time writes into log_dir_root with fasteners.InterProcessLock(os.path.join(log_dir_root, 'lock')): return _mkdir_unique(log_dir_root, log_date, postfix_dir_name) def _mkdir_unique(log_dir_root, log_date, postfix_dir_name): log_date_str = log_date.strftime(_LOG_DATE_FORMAT) if _log_dir_with_log_date_exists(log_dir_root, log_date): print('Log dir starting with {} exists...'.format(log_date_str)) return _mkdir_unique(log_dir_root, log_date + timedelta(minutes=1), postfix_dir_name) log_dir = path.join(log_dir_root, '{log_date_str} {postfix_dir_name}'.format( log_date_str=log_date_str, postfix_dir_name=postfix_dir_name).strip()) os.makedirs(log_dir) return log_dir def _log_dir_with_log_date_exists(log_dir_root, log_date): log_date_str = log_date.strftime(_LOG_DATE_FORMAT) all_log_dates = set() for log_dir in os.listdir(log_dir_root): try: all_log_dates.add(log_date_from_log_dir(log_dir)) except ValueError: continue return log_date_str in all_log_dates def log_date_from_log_dir(log_dir): # extract {log_date} from LOG_DIR/{log_date} {netconfig} {probconfig} possible_log_date = os.path.basename(log_dir).split(' ')[0] if not is_log_date(possible_log_date): raise ValueError('Invalid log dir: {}'.format(log_dir)) return possible_log_date def is_log_dir(log_dir): try: log_date_from_log_dir(log_dir) return True except ValueError: return False def is_log_date(possible_log_date): try: datetime.strptime(possible_log_date, _LOG_DATE_FORMAT) return True except ValueError: return False