# -*- coding: utf-8 -*- # Created by li huayong on 2019/10/8 # import configargparse as argparse from pathlib import Path import argparse import yaml from types import SimpleNamespace from typing import Dict, List, Tuple def load_configs_from_yaml(yaml_file: str) -> Dict: """ 从yaml配置文件中加载参数,这里会将嵌套的二级映射调整为一级映射 Args: yaml_file: yaml】文件路径 Returns: yaml文件中的配置字典 """ yaml_config = yaml.load(open(yaml_file, encoding='utf-8'), Loader=yaml.FullLoader) configs_dict = {} for sub_k, sub_v in yaml_config.items(): # 读取嵌套的参数 if isinstance(sub_v, dict): for k, v in sub_v.items(): if k in configs_dict.keys(): raise ValueError(f'Duplicate parameter : {k}') configs_dict[k] = v else: configs_dict[sub_k] = sub_v return configs_dict def parse_args() -> SimpleNamespace: parser = argparse.ArgumentParser() # How to set `local_rank` argument? # Ref: https://github.com/huggingface/transformers/issues/1651 # The easiest way is to use the torch launch script. # It will automatically set the local rank correctly. # It would look something like this: # `python -m torch.distributed.launch --nproc_per_node 8 run_squad.py <your arguments>` parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank, torch.distributed.launch 启动器会自动赋值") parser.add_argument('--cpu', action='store_true', default=False, help='不使用cuda加速,仅使用cpu') subparsers = parser.add_subparsers(help='子命令:{train:训练|dev:验证|infer:推理}', dest='command') # ------------------------训练模式参数--------------------------------------------------------------- parser_train = subparsers.add_parser('train', help='训练模式') parser_train.add_argument('-c', '--config_file', required=True, help='训练时的yaml配置文件路径') parser_train.add_argument('--no_output', action='store_true', default=False, help='不输出文件,用于调试') parser_train.add_argument('--test_after_train', action='store_true', help='是否在训练完成之后立刻做一次测试') # ------------------------验证/推理模式参数------------------------------------------------------------- # 验证模式和推理模式所需要的参数完全相同(--saved_model_path,-i,-o),因此这里使用一个parent_parser来处理重复的参数 # 更详细的说明参考:https://stackoverflow.com/a/56595689 [Python argparse - Add argument to multiple subparsers] dev_infer_parent_parser = argparse.ArgumentParser(add_help=False) # add_help必须为False dev_infer_parent_parser.add_argument('-m', '--saved_model_path', required=True, help='预先训练好的模型路径,其下需包含一个config.yaml(配置信息)文件和一个model(模型参数)文件夹') dev_infer_parent_parser.add_argument('-i', '--input_conllu_path', required=True, help='输入CONLL-U文件,dev模式下是一个gold file,infer模式下是一个空conllu file') dev_infer_parent_parser.add_argument('-o', '--output_conllu_path', required=True, help='dev或者infer的输出文件路径') dev_infer_parent_parser.add_argument('-b', '--batch_size', default=5, type=int, help='dev或者infer时刻的batch大小') # -----------------------再处理dev和infer各自的参数(如果有)-------------------------------------------- parser_dev = subparsers.add_parser('dev', help='验证模式', parents=[dev_infer_parent_parser]) parser_infer = subparsers.add_parser('infer', help='推理模式', parents=[dev_infer_parent_parser]) # -------------------------------------------------------------------------------------------------- configs = vars(parser.parse_args()) configs['cuda'] = not configs['cpu'] # 加载yaml配置参数 if configs['command'] == 'train': yaml_config_file = Path(configs['config_file']) else: # dev 或者 inference 模式下配置文件config.yaml放置在saved_model_path文件夹下 # 而模型参数则放置在 saved_model_path/model 下 yaml_config_file = Path(configs['saved_model_path']) / 'config.yaml' # 获取yaml配置文件之后将 saved_model_path 调整为 saved_model_path/model,指向真正放置模型参数的文件夹 configs['saved_model_path'] = str(Path(configs['saved_model_path']) / 'model') if configs['batch_size']: configs['per_gpu_eval_batch_size'] = configs['batch_size'] del configs['batch_size'] if not yaml_config_file.exists(): raise RuntimeError(f'yaml config file {str(yaml_config_file)} not exist') yaml_configs = load_configs_from_yaml(str(yaml_config_file)) # 合并两个参数字典 for c, v in configs.items(): # 将命令行读取的参数字典覆盖写入到yaml配置参数字典中 # 在推理或者验证阶段时,yaml文件中的部分配置信息可能是训练时遗留下来的,已经过时 # 这里覆盖yaml配置的方式可以确保覆盖这些过时的配置信息 yaml_configs[c] = v # 转化为object格式 configs = SimpleNamespace(**yaml_configs) # if configs.skip_too_long_input: # print(f'skip_too_long_input is True, max_seq_len is {configs.max_seq_len}') return configs if __name__ == '__main__': from pprint import pprint pprint(vars(parse_args()))