Python mmcv.Config.fromfile() Examples

The following are 30 code examples of mmcv.Config.fromfile(), collected from open-source projects. The line above each example names the source file, the project it comes from, that project's license, and the number of votes the example received. See also the other functions and classes of the mmcv.Config module.
Example #1
Source File: train.py (from PolarMask, Apache License 2.0; 7 votes)
def test():
    """Debug entry point: build the training dataset, then open a shell.

    Prints a debug banner, loads the config named on the command line,
    forces ``cfg.gpus = 1``, builds ``cfg.data.train`` and drops into an
    ``embed`` session (the locals below stay visible there).
    """
    # NOTE(review): neither import is used by the live code (cv2 was only
    # referenced by a removed, commented-out visualisation helper). They
    # are kept so the names remain available inside the embed shell.
    from tqdm import trange  # noqa: F401
    import cv2  # noqa: F401

    banner = 'debug mode ' * 10
    print(banner)

    args = parse_args()
    cfg = Config.fromfile(args.config)
    cfg.gpus = 1

    # Left as a local on purpose: it is inspected interactively via embed.
    dataset = build_dataset(cfg.data.train)  # noqa: F841
    embed(header='123123')
Example #2
Source File: test_predictor.py (from mmfashion, Apache License 2.0; 6 votes)
def main():
    """Predict attribute probabilities for one input image.

    Loads the config and checkpoint named on the command line, runs the
    predictor on ``args.input`` (on GPU when ``args.use_cuda`` is set) and
    hands the result to ``AttrPredictor.show_prediction``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    use_cuda = args.use_cuda

    image = get_img_tensor(args.input, use_cuda)

    # Weights come from the checkpoint below; presumably this stops
    # build_predictor from loading separate pretrained weights.
    cfg.model.pretrained = None
    predictor_net = build_predictor(cfg.model)
    load_checkpoint(predictor_net, args.checkpoint, map_location='cpu')
    if use_cuda:
        predictor_net.cuda()
    predictor_net.eval()

    probs = predictor_net(image, attr=None, landmark=None, return_loss=False)
    AttrPredictor(cfg.data.test).show_prediction(probs)
Example #3
Source File: upgrade_model_version.py (from mmdetection, Apache License 2.0; 6 votes)
def parse_config(config_strings):
    """Load config source text and classify the detector it describes.

    The text is written to a uniquely named temporary ``.py`` file so that
    ``Config.fromfile`` can load it. The file is always deleted afterwards,
    even if loading fails.

    Args:
        config_strings (str): Python source of a detection config.

    Returns:
        tuple: ``(is_two_stage, is_ssd, is_retina, reg_cls_agnostic)``.
        ``is_two_stage`` is False when the model has no ``rpn_head``.
        ``is_ssd`` / ``is_retina`` report the bbox head type of such
        single-stage models. ``reg_cls_agnostic`` is True for a list of bbox
        heads, otherwise it is the head's ``reg_class_agnostic`` flag
        (False when absent).
    """
    import os

    # The previous version wrote to f'{NamedTemporaryFile().name}.py'. That is
    # a second path which was never deleted, so every call leaked a file.
    # delete=False lets the file be closed (required for reopening on
    # Windows) and then removed explicitly in ``finally``.
    with tempfile.NamedTemporaryFile(
            mode='w', suffix='.py', delete=False) as temp_file:
        temp_file.write(config_strings)
        config_path = temp_file.name
    try:
        config = Config.fromfile(config_path)
    finally:
        os.remove(config_path)

    is_two_stage = True
    is_ssd = False
    is_retina = False
    reg_cls_agnostic = False
    if 'rpn_head' not in config.model:
        # No RPN: single-stage detector, so identify its head type.
        is_two_stage = False
        if config.model.bbox_head.type == 'SSDHead':
            is_ssd = True
        elif config.model.bbox_head.type == 'RetinaHead':
            is_retina = True
    elif isinstance(config.model['bbox_head'], list):
        # Several bbox heads (presumably a cascade) are treated as
        # class-agnostic.
        reg_cls_agnostic = True
    elif 'reg_class_agnostic' in config.model.bbox_head:
        reg_cls_agnostic = config.model.bbox_head.reg_class_agnostic
    return is_two_stage, is_ssd, is_retina, reg_cls_agnostic
Example #4
Source File: test_config.py (from mmcv, Apache License 2.0; 6 votes)
def test_merge_from_base():
    """``d.py`` merges its base config; fields and ``cfg.text`` reflect both.

    Also checks that loading ``e.py`` raises ``TypeError``.
    """
    def _read(path):
        # Read through a context manager so the handle is closed right away
        # instead of leaking until garbage collection (ResourceWarning).
        with open(path, 'r') as f:
            return f.read()

    cfg_file = osp.join(osp.dirname(__file__), 'data/config/d.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file

    # cfg.text is expected to be each file's absolute path followed by its
    # contents, the base file first, joined with newlines.
    base_cfg_file = osp.join(osp.dirname(__file__), 'data/config/base.py')
    merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \
        _read(base_cfg_file)
    merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
        _read(cfg_file)
    assert cfg.text == merge_text

    assert cfg.item1 == [2, 3]
    assert cfg.item2.a == 1
    assert cfg.item3 is False
    assert cfg.item4 == 'test_base'

    with pytest.raises(TypeError):
        Config.fromfile(osp.join(osp.dirname(__file__), 'data/config/e.py'))
Example #5
Source File: test_config.py (from mmcv, Apache License 2.0; 6 votes)
def test_merge_from_multiple_bases():
    """``l.py`` loads with the merged values below; ``m.py`` raises KeyError."""
    here = osp.dirname(__file__)
    cfg_file = osp.join(here, 'data/config/l.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file

    # Merged field values.
    assert cfg.item1 == [1, 2]
    assert cfg.item2.a == 0
    assert cfg.item3 is False
    for name, expected in (
            ('item4', 'test'),
            ('item5', dict(a=0, b=1)),
            ('item6', [dict(a=0), dict(b=1)]),
            ('item7', dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))),
    ):
        assert getattr(cfg, name) == expected, name

    with pytest.raises(KeyError):
        Config.fromfile(osp.join(here, 'data/config/m.py'))
Example #6
Source File: extract_features.py (from mmfashion, Apache License 2.0; 6 votes)
def main():
    """Extract features for one dataset split and save them.

    ``args.data_type`` selects the split and must be ``train``, ``query``
    or ``gallery``; anything else raises ``TypeError``. A checkpoint given
    on the command line is stored as ``cfg.load_from`` before extraction.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    split = args.data_type
    if split not in ('train', 'query', 'gallery'):
        raise TypeError('So far only support train/query/gallery dataset')
    image_set = build_dataset(getattr(cfg.data, split))

    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    extract_features(image_set, cfg, args.save_dir)
Example #7
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_merge_recursive_bases():
    """``f.py`` (nested bases, per the test name) loads with merged values."""
    cfg_file = osp.join(osp.dirname(__file__), 'data/config/f.py')
    cfg = Config.fromfile(cfg_file)
    assert isinstance(cfg, Config)
    assert cfg.filename == cfg_file

    item1, item2, item3, item4 = cfg.item1, cfg.item2, cfg.item3, cfg.item4
    assert item1 == [2, 3]
    assert item2.a == 1
    assert item3 is False
    assert item4 == 'test_recursive_bases'
Example #8
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_dump_mapping():
    """Dumping ``n.py``'s config to a ``.py`` file and reloading is lossless."""
    source_cfg = Config.fromfile(
        osp.join(osp.dirname(__file__), 'data/config/n.py'))

    with tempfile.TemporaryDirectory() as tmp_dir:
        dumped_path = osp.join(tmp_dir, '_text_config.py')
        source_cfg.dump(dumped_path)
        reloaded_cfg = Config.fromfile(dumped_path)

    assert reloaded_cfg._cfg_dict == source_cfg._cfg_dict
Example #9
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_pretty_text():
    """``cfg.pretty_text`` written to a ``.py`` file reloads to the same dict."""
    original = Config.fromfile(
        osp.join(osp.dirname(__file__), 'data/config/l.py'))
    with tempfile.TemporaryDirectory() as tmp_dir:
        text_path = osp.join(tmp_dir, '_text_config.py')
        with open(text_path, 'w') as out:
            out.write(original.pretty_text)
        reloaded = Config.fromfile(text_path)
    assert reloaded._cfg_dict == original._cfg_dict
Example #10
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_fromfile_in_config():
    """``code.py`` exposes a nested config as ``cfg.cfg`` plus its own item5.

    Judging by the test name, the nested config is produced by calling
    ``Config.fromfile`` inside ``code.py``.
    """
    cfg = Config.fromfile(
        osp.join(osp.dirname(__file__), 'data/config/code.py'))
    inner = cfg.cfg
    assert inner.item1 == [1, 2]
    assert inner.item2 == dict(a=0)
    assert inner.item3 is True
    assert inner.item4 == 'test'
    assert cfg.item5 == 1
Example #11
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_merge_intermediate_variable():
    """``i_child.py`` loads with the merged values below, incl. ``item_cfg``."""
    cfg = Config.fromfile(
        osp.join(osp.dirname(__file__), 'data/config/i_child.py'))

    assert cfg.item3 is True
    expected_fields = (
        ('item1', [1, 2]),
        ('item2', dict(a=0)),
        ('item4', 'test'),
        ('item_cfg', dict(b=2)),
        ('item5', dict(cfg=dict(b=1))),
        ('item6', dict(cfg=dict(b=2))),
    )
    for name, expected in expected_fields:
        assert getattr(cfg, name) == expected, name
Example #12
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_merge_delete():
    """``delete.py`` loads with ``item2 == dict(b=0)`` and no ``_delete_`` key."""
    cfg = Config.fromfile(
        osp.join(osp.dirname(__file__), 'data/config/delete.py'))
    item2 = cfg.item2
    assert cfg.item1 == [1, 2]
    assert item2 == dict(b=0)
    assert cfg.item3 is True
    assert cfg.item4 == 'test'
    # The marker key itself must not survive the merge.
    assert '_delete_' not in item2
Example #13
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_reserved_key():
    """Loading ``g.py`` raises ``KeyError`` (a reserved key, per the name)."""
    reserved_cfg_path = osp.join(osp.dirname(__file__), 'data/config/g.py')
    with pytest.raises(KeyError):
        Config.fromfile(reserved_cfg_path)
Example #14
Source File: browse_dataset.py (from mmdetection, Apache License 2.0; 5 votes)
def retrieve_data_cfg(config_path, skip_type):
    """Load a config and drop selected steps from its training pipeline.

    Args:
        config_path (str): Path of the config file.
        skip_type: Container of pipeline step ``type`` names to remove from
            ``cfg.data.train.pipeline``.

    Returns:
        Config: The loaded config; its training pipeline is replaced in
        place by the filtered list.
    """
    cfg = Config.fromfile(config_path)
    train_cfg = cfg.data.train
    kept_steps = []
    for step in train_cfg.pipeline:
        if step['type'] in skip_type:
            continue
        kept_steps.append(step)
    train_cfg['pipeline'] = kept_steps
    return cfg
Example #15
Source File: test_fashion_recommender.py (from mmfashion, Apache License 2.0; 5 votes)
def main():
    """Evaluate a fashion recommender on the test split.

    Command-line ``work_dir`` / ``checkpoint`` override ``cfg.work_dir`` /
    ``cfg.load_from`` when given. Any launcher other than ``'none'``
    initialises the distributed environment first.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    if args.checkpoint is not None:
        cfg.load_from = args.checkpoint

    # The distributed environment must be initialised before anything else.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    logger = get_root_logger(cfg.log_level)
    logger.info('Distributed test: {}'.format(distributed))

    dataset = build_dataset(cfg.data.test)
    print('dataset loaded')

    model = build_fashion_recommender(cfg.model)
    load_checkpoint(model, cfg.load_from, map_location='cpu')
    print('load checkpoint from: {}'.format(cfg.load_from))

    # NOTE(review): ``distributed`` and ``logger`` are computed above, but
    # this call hard-codes distributed=False and logger=None, so testing
    # always runs the non-distributed path. Confirm the intent before
    # forwarding them.
    test_fashion_recommender(
        model, dataset, cfg, distributed=False, validate=False, logger=None)
Example #16
Source File: test_retriever.py (from mmfashion, Apache License 2.0; 5 votes)
def main():
    """Retrieve gallery images similar to a single query image.

    Seeds torch (CPU and, if available, all GPUs) with 0. It then embeds
    ``args.input`` with the retriever from the config and checkpoint,
    embeds the gallery split, and shows the retrieved images.
    """
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    args = parse_args()
    cfg = Config.fromfile(args.config)

    retriever_net = build_retriever(cfg.model)
    load_checkpoint(retriever_net, args.checkpoint)
    print('load checkpoint from {}'.format(args.checkpoint))
    if args.use_cuda:
        retriever_net.cuda()
    retriever_net.eval()

    query = get_img_tensor(args.input, args.use_cuda)
    query_feat = retriever_net(query, landmark=None, return_loss=False)
    query_feat = query_feat.data.cpu().numpy()

    gallery_cfg = cfg.data.gallery
    gallery_set = build_dataset(gallery_cfg)
    gallery_embeds = _process_embeds(gallery_set, retriever_net, cfg)

    viewer = ClothesRetriever(gallery_cfg.img_file, cfg.data_root,
                              gallery_cfg.img_path)
    viewer.show_retrieved_images(query_feat, gallery_embeds)
Example #17
Source File: test_landmark_detector.py (from mmfashion, Apache License 2.0; 5 votes)
def main():
    """Detect landmarks on one image and draw those predicted visible.

    A landmark is kept when its visibility score is >= 0.5. Its coordinates
    are printed after rescaling from a 224x224 frame to the image's
    ``w`` x ``h`` size.
    """
    torch.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(0)

    args = parse_args()
    cfg = Config.fromfile(args.config)

    img_tensor, w, h = get_img_tensor(args.input, args.use_cuda, get_size=True)

    model = build_landmark_detector(cfg.model)
    print('model built')
    load_checkpoint(model, args.checkpoint)
    print('load checkpoint from: {}'.format(args.checkpoint))
    if args.use_cuda:
        model.cuda()

    model.eval()
    pred_vis, pred_lm = model(img_tensor, return_loss=False)
    pred_lm = pred_lm.data.cpu().numpy()

    # 224 is presumably the network input resolution (TODO confirm); the
    # factors map predicted coordinates back to the original image size.
    x_scale = w / 224.
    y_scale = h / 224.

    vis_lms = []
    for i, vis in enumerate(pred_vis):
        if not vis >= 0.5:
            continue
        landmark = pred_lm[i]
        print('detected landmark {} {}'.format(
            landmark[0] * x_scale, landmark[1] * y_scale))
        vis_lms.append(landmark)

    draw_landmarks(args.input, vis_lms)
Example #18
Source File: test_forward.py (from mmdetection, Apache License 2.0; 5 votes)
def _get_config_module(fname):
    """Load ``fname`` from the config directory via ``Config.fromfile``.

    Despite the name, the return value is an ``mmcv.Config`` object, not
    a Python module.
    """
    from mmcv import Config
    return Config.fromfile(join(_get_config_directory(), fname))
Example #19
Source File: test_config.py (from mmcv, Apache License 2.0; 5 votes)
def test_syntax_error():
    """A config file containing invalid Python raises a descriptive SyntaxError.

    The error message must name the offending file.
    """
    import os
    import re

    # A temporary directory replaces the NamedTemporaryFile that used to be
    # reopened by name. Reopening a still-open NamedTemporaryFile fails on
    # Windows, and the directory is cleaned up even if the assertion fails.
    with tempfile.TemporaryDirectory() as temp_cfg_dir:
        temp_cfg_path = os.path.join(temp_cfg_dir, 'syntax_error_config.py')
        # write a file with syntax error
        with open(temp_cfg_path, 'w') as f:
            f.write('a=0b=dict(c=1)')
        # ``match`` is a regular expression, so escape the path; Windows
        # paths contain backslashes.
        with pytest.raises(
                SyntaxError,
                match='There are syntax errors in config '
                f'file {re.escape(temp_cfg_path)}'):
            Config.fromfile(temp_cfg_path)
Example #20
Source File: get_flops.py (from kaggle-kuzushiji-recognition, MIT License; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #21
Source File: get_flops.py (from RDSNet, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #22
Source File: get_flops.py (from IoU-Uniform-R-CNN, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #23
Source File: config.py (from mmskeleton, Apache License 2.0; 5 votes)
def fromfile(filename):
    """Load a config, falling back to a path relative to ``mmskl_home``.

    ``filename`` is first loaded as given. Only when that fails with an
    ``OSError`` (``FileNotFoundError`` is a subclass) is it retried as
    ``os.path.join(mmskl_home, filename)``.

    The previous bare ``except:`` also caught errors such as a syntax error
    inside an existing config, or even ``KeyboardInterrupt``, and replaced
    them with a misleading error from the fallback path. Those errors now
    propagate unchanged.
    """
    try:
        return BaseConfig.fromfile(filename)
    except OSError:
        return BaseConfig.fromfile(os.path.join(mmskl_home, filename))
Example #24
Source File: get_flops.py (from FoveaBox, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #25
Source File: get_flops.py (from Cascade-RPN, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #26
Source File: get_flops.py (from Feature-Selective-Anchor-Free-Module-for-Single-Shot-Object-Detection, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #27
Source File: eda.py (from kaggle-imaterialist, MIT License; 5 votes)
def main():
    """Save 500 randomly drawn training samples with masks and boxes drawn.

    Samples are drawn with replacement via ``np.random.randint``. Each image
    is denormalised with ``cfg.img_norm_cfg``, overlaid with its ground-truth
    masks and boxes, and written to ``args.output`` as a JPEG (channels
    reversed with ``[..., ::-1]`` before writing).
    """
    args = parse_args()
    os.makedirs(args.output, exist_ok=True)
    cfg = Config.fromfile(args.config)
    dataset = get_dataset(cfg.data.train)

    sample_ids = np.random.randint(0, len(dataset), 500)
    for i in tqdm(sample_ids):
        sample = dataset[i]
        img = sample['img'].data.numpy().transpose(1, 2, 0)
        masks = sample['gt_masks'].data.transpose(1, 2, 0).astype(bool)
        bboxes = sample['gt_bboxes'].data.numpy()

        img = mmcv.imdenormalize(
            img,
            mean=cfg.img_norm_cfg.mean,
            std=cfg.img_norm_cfg.std,
            to_bgr=False)
        img = draw_masks(img, masks).astype(np.uint8)
        draw_bounding_boxes_on_image_array(
            img, bboxes, use_normalized_coordinates=False, thickness=5)

        # The random suffix keeps repeated draws of the same index from
        # (usually) overwriting each other.
        out_name = f'{i}_{np.random.randint(0, 10000)}.jpg'
        cv2.imwrite(osp.join(args.output, out_name), img[..., ::-1])
Example #28
Source File: get_flops.py (from CenterNet, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #29
Source File: get_flops.py (from ttfnet, Apache License 2.0; 5 votes)
def main():
    """Print the FLOPs and parameter count of the detector in ``args.config``.

    ``args.shape`` gives the spatial input size: one value ``S`` means an
    ``S x S`` input, and two values are used as given. A 3-channel input is
    always assumed.

    Raises:
        ValueError: If ``args.shape`` has neither one nor two values.
        NotImplementedError: If the model has no ``forward_dummy`` method.
    """
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg).cuda()
    model.eval()

    # The complexity counter drives ``model.forward``; route it to the
    # dummy forward, which presumably takes only the image tensor.
    if hasattr(model, 'forward_dummy'):
        model.forward = model.forward_dummy
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.format(
                model.__class__.__name__))

    flops, params = get_model_complexity_info(model, input_shape)
    split_line = '=' * 30
    print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
        split_line, input_shape, flops, params))
Example #30
Source File: main.py (from learn-to-cluster, MIT License; 5 votes)
def main():
    """Configure and run one phase of the clustering pipeline.

    Steps:
      * Set ``cfg.cuda`` from ``--no_cuda`` and CUDA availability.
      * Enable cudnn flags requested by the config.
      * Choose a work dir and create it.
      * Copy run options from the command line onto the config.
      * Optionally seed the RNGs.
      * Build the model and dispatch to the handler for ``args.phase``.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    cfg.cuda = not args.no_cuda and torch.cuda.is_available()

    # Config flags can only switch these cudnn options on, never off.
    for cfg_key, cudnn_attr in (('cudnn_benchmark', 'benchmark'),
                                ('cudnn_deterministic', 'deterministic')):
        if cfg.get(cfg_key, False):
            setattr(torch.backends.cudnn, cudnn_attr, True)

    # NOTE: a work_dir already defined in the config takes precedence; the
    # command-line value is only used when the config has none.
    if not hasattr(cfg, 'work_dir'):
        if args.work_dir is None:
            cfg_name = rm_suffix(os.path.basename(args.config))
            cfg.work_dir = os.path.join('./data/work_dir', cfg_name)
        else:
            cfg.work_dir = args.work_dir
    mkdir_if_no_exists(cfg.work_dir, is_folder=True)

    for option in ('load_from', 'resume_from', 'gpus', 'distributed',
                   'save_output', 'force'):
        setattr(cfg, option, getattr(args, option))

    logger = create_logger()

    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    model = build_model(cfg.model['type'], **cfg.model['kwargs'])
    handler = build_handler(args.phase)
    handler(model, cfg, logger)