#!/usr/bin/env python # -*- coding: UTF-8 -*- ######################################################################## # GNU General Public License v3.0 # GNU GPLv3 # Copyright (c) 2019, Noureldien Hussein # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. ######################################################################## """ Configurations for project. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import os import platform import argparse import logging import yaml import pprint from ast import literal_eval from core.config import __C from core.utils import AttrDict from core import const, config, utils logger = logging.getLogger(__name__) # region Misc def get_machine_name(): return platform.node() def import_dl_platform(): if const.DL_FRAMEWORK == 'tensorflow': import tensorflow as tf elif const.DL_FRAMEWORK == 'pytorch': import torch elif const.DL_FRAMEWORK == 'caffe': import caffe elif const.DL_FRAMEWORK == 'keras': import keras.backend as K # endregion # region Config GPU def config_gpu(): if const.DL_FRAMEWORK == 'tensorflow': __config_gpu_for_tensorflow() elif const.DL_FRAMEWORK == 'pytorch': __config_gpu_for_pytorch() elif const.DL_FRAMEWORK == 'keras': __config_gpu_for_keras() elif const.DL_FRAMEWORK == 'caffe': __config_gpu_for_caffe() def __config_gpu_for_tensorflow(): import tensorflow as tf gpu_core_id = __parse_gpu_id() # import os # import tensorflow as tf # set the logging level of tensorflow # 1: filter out INFO # 2: filter out WARNING # 3: filter out ERROR # os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # or any {'0', '1', '2'} # set which device to be used const.GPU_CORE_ID = gpu_core_id pass def __config_gpu_for_keras(): import tensorflow as tf import keras.backend as K gpu_core_id = __parse_gpu_id() K.clear_session() config = tf.ConfigProto() config.gpu_options.visible_device_list = str(gpu_core_id) config.gpu_options.allow_growth = True session = tf.Session(config=config) K.set_session(session) # set which device to be used const.GPU_CORE_ID = gpu_core_id def __config_gpu_for_pytorch(): import torch gpu_core_id = __parse_gpu_id() torch.cuda.set_device(gpu_core_id) # set which device to be used const.GPU_CORE_ID = gpu_core_id def __config_gpu_for_caffe(): import os gpu_core_id = __parse_gpu_id() os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_core_id) # set which device to be used const.GPU_CORE_ID = gpu_core_id def __parse_gpu_id(): parser = argparse.ArgumentParser() parser.add_argument('-c', '--gpu_core_id', default='-1', type=int) args = parser.parse_args() gpu_core_id = args.gpu_core_id return gpu_core_id # endregion # region Config File Helpers def cfg_print_cfg(): logger.info('Config file is:') logger.info(pprint.pformat(__C)) def cfg_merge_dicts(dict_a, dict_b): from ast import literal_eval for key, value in dict_a.items(): if key not in dict_b: raise KeyError('Invalid key in config file: {}'.format(key)) if type(value) is dict: dict_a[key] = value = AttrDict(value) if isinstance(value, str): try: value = literal_eval(value) except BaseException: pass # the types must match, too old_type = type(dict_b[key]) if old_type is not type(value) and value is not None: raise ValueError('Type mismatch ({} vs. {}) for config key: {}'.format(type(dict_b[key]), type(value), key)) # recursively merge dicts if isinstance(value, AttrDict): try: cfg_merge_dicts(dict_a[key], dict_b[key]) except BaseException: raise Exception('Error under config key: {}'.format(key)) else: dict_b[key] = value def cfg_from_file(file_path, is_check=True): """ Load a config file and merge it into the default options. """ # read from file yaml_config = utils.yaml_load(file_path) # merge to project config cfg_merge_dicts(yaml_config, __C) # make sure everything is okay if is_check: cfg_sanity_check() def cfg_from_attrdict(attr_dict): cfg_merge_dicts(attr_dict, __C) def cfg_from_dict(args_dict): """Set config keys via list (e.g., from command line).""" for key, value in args_dict.iteritems(): key_list = key.split('.') cfg = __C for subkey in key_list[:-1]: assert subkey in cfg, 'Config key {} not found'.format(subkey) cfg = cfg[subkey] subkey = key_list[-1] if subkey not in cfg: raise Exception('Config key {} not found'.format(subkey)) try: # handle the case when v is a string literal val = literal_eval(value) except BaseException: val = value if isinstance(val, type(cfg[subkey])) or cfg[subkey] is None: pass else: type1 = type(val) type2 = type(cfg[subkey]) msg = 'type {} does not match original type {}'.format(type1, type2) raise Exception(msg) cfg[subkey] = val def cfg_from_list(args_list): """ Set config keys via list (e.g., from command line). """ from ast import literal_eval assert len(args_list) % 2 == 0, 'Specify values or keys for args' for key, value in zip(args_list[0::2], args_list[1::2]): key_list = key.split('.') cfg = __C for subkey in key_list[:-1]: assert subkey in cfg, 'Config key {} not found'.format(subkey) cfg = cfg[subkey] subkey = key_list[-1] assert subkey in cfg, 'Config key {} not found'.format(subkey) try: # handle the case when v is a string literal val = literal_eval(value) except BaseException: val = value msg = 'type {} does not match original type {}'.format(type(val), type(cfg[subkey])) assert isinstance(val, type(cfg[subkey])) or cfg[subkey] is None, msg cfg[subkey] = val def cfg_sanity_check(): assert __C.TRAIN.SCHEME in const.TRAIN_SCHEMES assert __C.MODEL.CLASSIFICATION_TYPE in const.MODEL_CLASSIFICATION_TYPES assert __C.MODEL.MULTISCALE_TYPE in const.MODEL_MULTISCALE_TYPES assert __C.SOLVER.NAME in const.SOLVER_NAMES assert __C.DATASET_NAME in const.DATASET_NAMES # endregion