#!/usr/bin/env python3
"""Command-line entry point: train, evaluate, test the data loader, or run
inference on a single file.

Usage examples::

    run.py train -c config.json
    run.py eval -r checkpoint.pth
    run.py testloader -c config.json
    run.py path/to/file -r checkpoint.pth
"""
import argparse
import json
import os

import torch

from utils import Logger
#from data import FolderDataManager, ImageTransforms
import data as data_module
import net as net_module
from train import Trainer
from eval import ClassificationEvaluator, AudioInference


def _get_transform(config, name):
    """Instantiate the transform class named in ``config['transforms']``.

    The class is looked up by name in the ``data`` module and called with
    ``name`` (the split the callers pass, 'train' or 'val') and the
    configured args.
    """
    tsf_name = config['transforms']['type']
    tsf_args = config['transforms']['args']
    return getattr(data_module, tsf_name)(name, tsf_args)


def _get_model_att(checkpoint):
    """Return ``(model type name, state dict, classes)`` read from a checkpoint."""
    m_name = checkpoint['config']['model']['type']
    sd = checkpoint['state_dict']
    classes = checkpoint['classes']
    return m_name, sd, classes


def eval_main(checkpoint):
    """Evaluate the model stored in ``checkpoint`` on the 'val' loader.

    The configuration embedded in the checkpoint drives the data manager,
    transforms, model and metrics. The evaluation result is printed and
    returned.
    """
    config = checkpoint['config']

    tsf = _get_transform(config, 'val')
    data_manager = getattr(data_module, config['data']['type'])(config['data'])
    test_loader = data_manager.get_loader('val', tsf)

    m_name, sd, classes = _get_model_att(checkpoint)
    model = getattr(net_module, m_name)(classes, config, state_dict=sd)
    print(model)
    model.load_state_dict(checkpoint['state_dict'])

    num_classes = len(classes)
    metrics = getattr(net_module, config['metrics'])(num_classes)

    evaluation = ClassificationEvaluator(test_loader, model)
    ret = evaluation.evaluate(metrics)
    print(ret)
    return ret


def infer_main(file_path, config, checkpoint):
    """Run inference on ``file_path``, print the result and draw it.

    When ``checkpoint`` is None the model named in ``config`` is built with
    no arguments; otherwise it is built from the checkpoint's classes and
    state dict and the weights are loaded.
    """
    if checkpoint is None:
        model = getattr(net_module, config['model']['type'])()
    else:
        m_name, sd, classes = _get_model_att(checkpoint)
        model = getattr(net_module, m_name)(classes, config, state_dict=sd)
        model.load_state_dict(checkpoint['state_dict'])

    tsf = _get_transform(config, 'val')
    inference = AudioInference(model, transforms=tsf)
    label, conf = inference.infer(file_path)
    print(label, conf)
    inference.draw(file_path, label, conf)


def train_main(config, resume):
    """Build loaders, model, optimizer and scheduler from ``config`` and train.

    ``resume`` is forwarded untouched to ``Trainer`` (the CLI passes the
    ``--resume`` path or None). Returns the ``Trainer`` after ``train()``
    has finished.
    """
    train_logger = Logger()

    t_transforms = _get_transform(config, 'train')
    v_transforms = _get_transform(config, 'val')
    print(t_transforms)

    data_manager = getattr(data_module, config['data']['type'])(config['data'])
    classes = data_manager.classes

    t_loader = data_manager.get_loader('train', t_transforms)
    v_loader = data_manager.get_loader('val', v_transforms)

    m_name = config['model']['type']
    model = getattr(net_module, m_name)(classes, config=config)
    num_classes = len(classes)

    loss = getattr(net_module, config['train']['loss'])
    metrics = getattr(net_module, config['metrics'])(num_classes)

    # Only hand parameters that require gradients to the optimizer.
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())

    opt_name = config['optimizer']['type']
    opt_args = config['optimizer']['args']
    optimizer = getattr(torch.optim, opt_name)(trainable_params, **opt_args)

    lr_name = config['lr_scheduler']['type']
    lr_args = config['lr_scheduler']['args']
    # The config uses the literal string 'None' to disable scheduling.
    if lr_name == 'None':
        lr_scheduler = None
    else:
        lr_scheduler = getattr(torch.optim.lr_scheduler, lr_name)(optimizer, **lr_args)

    trainer = Trainer(model, loss, metrics, optimizer,
                      resume=resume,
                      config=config,
                      data_loader=t_loader,
                      valid_data_loader=v_loader,
                      lr_scheduler=lr_scheduler,
                      train_logger=train_logger)

    trainer.train()
    return trainer
    #duration = 1; freq = 440
    #os.system('play --no-show-progress --null --channels 1 synth %s sine %f'%(duration, freq))


def _test_loader(config):
    """Iterate the 'train' loader, printing a summary of every batch.

    For the first and last element of each batch, elements whose ``size()``
    has length 1 are printed as-is; all others are replaced by their shape.
    """

    def disp_batch(batch):
        return [b.shape if len(b.size()) != 1 else b for b in batch]

    tsf = _get_transform(config, 'train')
    data_manager = getattr(data_module, config['data']['type'])(config['data'])
    loader = data_manager.get_loader('train', tsf)

    print(tsf.transfs)

    for batch in loader:
        print(disp_batch([batch[0], batch[-1]]))


def _build_arg_parser():
    """Create the command-line parser for this script."""
    argparser = argparse.ArgumentParser(description='PyTorch Template')

    argparser.add_argument('action', type=str,
                           help='what action to take (train, eval, testloader), '
                                'or the path of a file to run inference on')
    argparser.add_argument('-c', '--config', default=None, type=str,
                           help='config file path (default: None)')
    argparser.add_argument('-r', '--resume', default=None, type=str,
                           help='path to latest checkpoint (default: None)')
    argparser.add_argument('--net_mode', default='init', type=str,
                           help='type of transfer learning to use')
    argparser.add_argument('--cfg', default=None, type=str,
                           help='nn layer config file')
    return argparser


def _resolve_config(args, argparser):
    """Return ``(config, checkpoint)`` for the parsed arguments.

    A config file given with ``-c`` takes precedence; in that case
    ``checkpoint`` is None. Otherwise the config embedded in the ``-r``
    checkpoint is used. Exits with a usage error when neither is given.
    """
    checkpoint = None
    if args.config:
        # Context manager so the file handle is closed deterministically.
        with open(args.config) as config_file:
            config = json.load(config_file)
        config['net_mode'] = args.net_mode
        config['cfg'] = args.cfg
    elif args.resume:
        checkpoint = torch.load(args.resume)
        config = checkpoint['config']
    else:
        argparser.error("Configuration file need to be specified. "
                        "Add '-c config.json', for example.")
    return config, checkpoint


def _ensure_checkpoint(args, checkpoint):
    """Return ``checkpoint``, loading it from ``--resume`` if not loaded yet.

    Covers the case where both ``-c`` and ``-r`` were given: the config file
    wins for the configuration, but the checkpoint is still needed by the
    eval and inference actions. Returns None when no ``--resume`` path exists.
    """
    if checkpoint is None and args.resume:
        checkpoint = torch.load(args.resume)
    return checkpoint


def main():
    """Parse the command line and dispatch to the requested action."""
    argparser = _build_arg_parser()
    args = argparser.parse_args()

    # Resolve config vs. resume
    config, checkpoint = _resolve_config(args, argparser)

    # Pick mode to run
    if args.action == 'train':
        train_main(config, args.resume)

    elif args.action == 'eval':
        checkpoint = _ensure_checkpoint(args, checkpoint)
        if checkpoint is None:
            argparser.error("The 'eval' action needs a checkpoint. "
                            "Add '-r checkpoint.pth', for example.")
        # eval_main reads its configuration from the checkpoint itself.
        eval_main(checkpoint)

    elif args.action == 'testloader':
        _test_loader(config)

    elif os.path.isfile(args.action):
        file_path = args.action
        infer_main(file_path, config, _ensure_checkpoint(args, checkpoint))

    else:
        # Previously an unknown action fell through and did nothing at all.
        argparser.error("Unknown action %r: expected 'train', 'eval', "
                        "'testloader' or the path of an existing file."
                        % args.action)


if __name__ == '__main__':
    main()