# -*- coding: utf-8 -*-
# @Time    : 2018/8/23 22:20
# @Author  : zhoujun
"""Command-line entry point that trains a text-recognition model from a config file."""
import os


def _get_input_shape(dataset_args):
    """Return ``(channels, height, width)`` for the sample network input.

    Args:
        dataset_args: mapping with an ``img_mode`` entry and a
            ``pre_processes`` list of ``{'type': ..., 'args': ...}`` entries.

    Returns:
        A tuple ``(img_channel, img_h, img_w)``. Height and width are taken
        from the first ``Resize`` pre-process and default to 32 x 100 when no
        such entry exists. The channel count is 1 when ``img_mode`` is
        ``'GRAY'`` and 3 otherwise.
    """
    img_h, img_w = 32, 100
    for process in dataset_args['pre_processes']:
        if process['type'] == "Resize":
            img_h = process['args']['img_h']
            img_w = process['args']['img_w']
            # Only the first Resize entry is used.
            break
    img_channel = 1 if dataset_args['img_mode'] == 'GRAY' else 3
    return img_channel, img_h, img_w


def main(config):
    """Build the model, data loaders and trainer from ``config`` and train.

    ``config`` is modified in place: ``config['dataset']['alphabet']`` is
    replaced by the loaded alphabet string when it names an existing file, and
    ``config['lr_scheduler']['args']['step']`` is multiplied by the number of
    batches in the training loader.

    Args:
        config: nested configuration mapping with (at least) the ``dataset``,
            ``arch``, ``trainer`` and ``lr_scheduler`` sections read below.

    Raises:
        NotImplementedError: if the configured prediction type is not ``'CTC'``.
    """
    # These imports stay local to the function: when run as a script, the
    # ``__main__`` section extends ``sys.path`` before ``main`` is called.
    from mxnet import nd
    from mxnet.gluon.loss import CTCLoss
    from models import get_model
    from data_loader import get_dataloader
    from trainer import Trainer
    from utils import get_ctx, load

    # The alphabet entry may be a path to a file; if so, replace it with the
    # joined contents returned by ``load``.
    if os.path.isfile(config['dataset']['alphabet']):
        config['dataset']['alphabet'] = ''.join(load(config['dataset']['alphabet']))

    prediction_type = config['arch']['args']['prediction']['type']
    num_class = len(config['dataset']['alphabet'])

    # Loss setup: only CTC is implemented.
    if prediction_type == 'CTC':
        criterion = CTCLoss()
    else:
        raise NotImplementedError

    ctx = get_ctx(config['trainer']['gpus'])
    model = get_model(num_class, ctx, config['arch']['args'])
    model.hybridize()
    model.initialize(ctx=ctx)

    # A dummy batch of two images is used to query the model for the label
    # length and is also handed to the trainer.
    img_channel, img_h, img_w = _get_input_shape(config['dataset']['train']['dataset']['args'])
    sample_input = nd.zeros((2, img_channel, img_h, img_w), ctx[0])
    num_label = model.get_batch_max_length(sample_input)

    train_loader = get_dataloader(config['dataset']['train'], num_label, config['dataset']['alphabet'])
    assert train_loader is not None
    if 'validate' in config['dataset']:
        validate_loader = get_dataloader(config['dataset']['validate'], num_label, config['dataset']['alphabet'])
    else:
        validate_loader = None

    # Scale the scheduler step by the number of training batches (presumably
    # converting epochs to iterations -- confirm against the Trainer).
    # NOTE(review): if ``step`` is ever a list, ``*=`` repeats the list instead
    # of scaling its entries -- verify the expected config type.
    config['lr_scheduler']['args']['step'] *= len(train_loader)

    trainer = Trainer(config=config,
                      model=model,
                      criterion=criterion,
                      train_loader=train_loader,
                      validate_loader=validate_loader,
                      sample_input=sample_input,
                      ctx=ctx)
    trainer.train()


def init_args():
    """Parse command-line arguments.

    Returns:
        The ``argparse.Namespace`` with a ``config_file`` attribute
        (default ``'config/icdar2015.yaml'``).
    """
    import argparse

    # NOTE(review): the description and default path look copied from another
    # project; left unchanged to keep the CLI output identical.
    parser = argparse.ArgumentParser(description='DBNet.pytorch')
    parser.add_argument('--config_file', default='config/icdar2015.yaml', type=str)
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    import sys

    import anyconfig

    project = 'crnn.gluon'  # name of the project root directory
    # Put the project root (the part of the cwd up to and including
    # ``project``) on sys.path so the project-local imports resolve.
    sys.path.append(os.getcwd().split(project)[0] + project)
    from utils import parse_config

    args = init_args()
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not os.path.exists(args.config_file):
        raise FileNotFoundError('config file not found: {}'.format(args.config_file))
    # Close the file handle deterministically once the config is parsed.
    with open(args.config_file, 'rb') as config_stream:
        config = anyconfig.load(config_stream)
    if 'base' in config:
        config = parse_config(config)
    main(config)