"""Train and evaluate a model described by a YAML settings file (MXNet Gluon)."""

import logging
import math
import os
import random

import yaml

# Configure logging before the remaining imports, as the original file did.
logging.basicConfig(level=logging.INFO)

import mxnet as mx
from mxnet import nd, gluon, autograd, init
import numpy as np

np.set_printoptions(formatter={'float': '{: 0.2f}'.format})

import data.dataloader
import config
from config import PARAM_PATH
from helper.callback import Speedometer, Logger
from helper.metric import MAE, RMSE, IndexMAE, IndexRMSE
import model


class ModelTrainer:
    """Drive training / evaluation passes of a Gluon network over data loaders.

    Attributes are stored exactly as passed to the constructor:
        net: network called on every batch.
        trainer: ``gluon.Trainer`` used to reduce gradients and apply updates.
        clip_gradient: base threshold for global-norm gradient clipping.
        logger: object whose ``log(epoch, metrics)`` is called after evaluation.
        ctx: list of device contexts that batches are split across.
    """

    def __init__(self, net, trainer, clip_gradient, logger, ctx):
        self.net = net
        self.trainer = trainer
        self.clip_gradient = clip_gradient
        self.logger = logger
        self.ctx = ctx

    def step(self, batch_size):
        """Apply one optimizer step with global-norm gradient clipping.

        Gradients are all-reduced, the gradients of every parameter whose
        ``grad_req`` is ``'write'`` are clipped together, and the trainer
        update is applied with ``ignore_stale_grad=True``.

        Args:
            batch_size: value forwarded to ``trainer.update``.
        """
        self.trainer.allreduce_grads()

        grads = [
            grad
            for param in self.trainer._params
            if param.grad_req == 'write'
            for grad in param.list_grad()
        ]

        # The clipping threshold is scaled by sqrt(number of contexts).
        max_norm = self.clip_gradient * math.sqrt(len(self.ctx))
        gluon.utils.clip_global_norm(grads, max_norm)
        self.trainer.update(batch_size, ignore_stale_grad=True)

    def process_data(self, epoch, dataloader, metrics=None, is_training=True, title='[TRAIN]'):
        """Run one pass over ``dataloader``, optionally training the network.

        Each network output is used as a pair: ``out[0]`` is backpropagated
        when training and ``out[1]`` is unpacked into ``metric.update``.

        Args:
            epoch: epoch number, passed to the ``Speedometer``.
            dataloader: iterable of batches; every batch is a sequence of
                arrays that are split across ``self.ctx``.
            metrics: optional list of metrics; they are reset at the start of
                the pass and updated after every batch.
            is_training: when true, record gradients and call ``step``.
            title: label passed to the ``Speedometer``.
        """
        speedometer = Speedometer(title, epoch, frequent=50)
        speedometer.reset()
        if metrics is not None:
            for metric in metrics:
                metric.reset()

        for nbatch, batch_data in enumerate(dataloader):
            # One list of per-device slices for every array in the batch.
            inputs = [gluon.utils.split_and_load(x, self.ctx) for x in batch_data]

            if is_training:
                # NOTE(review): the decoder's step counter is advanced once per
                # training batch; what consumes it is not visible here.
                self.net.decoder.global_steps += 1.0

                with autograd.record():
                    outputs = [self.net(*x, is_training) for x in zip(*inputs)]

                for out in outputs:
                    out[0].backward()

                self.step(batch_data[0].shape[0])
            else:
                outputs = [self.net(*x, False) for x in zip(*inputs)]

            if metrics is not None:
                for metric in metrics:
                    for out in outputs:
                        metric.update(*out[1])

            speedometer.log_metrics(nbatch + 1, metrics)

        speedometer.finish(metrics)

    def fit(self, begin_epoch, num_epochs, train, eval, test, metrics=None):
        """Run ``num_epochs`` epochs of train / eval / test passes.

        Args:
            begin_epoch: number of the first epoch.
            num_epochs: how many epochs to run.
            train: training data loader, or ``None`` to skip training.
            eval: evaluation data loader, or ``None`` to skip evaluation.
            test: test data loader, or ``None`` to skip testing.
            metrics: optional list of metrics shared by all passes.
        """
        for epoch in range(begin_epoch, begin_epoch + num_epochs):
            if train is not None:
                self.process_data(epoch, train, metrics)

            if eval is not None:
                self.process_data(epoch, eval, metrics, is_training=False, title='[EVAL]')
                # Metrics were reset at the start of the eval pass, so the
                # logger receives the evaluation results. Only log when this
                # epoch actually trained.
                if (train is not None) and (metrics is not None):
                    self.logger.log(epoch, metrics)

            if test is not None:
                self.process_data(epoch, test, metrics, is_training=False, title='[TEST]')
            print('')


def _build_metrics(scaler):
    """Return a fresh list of the metrics used for training and final evaluation."""
    return [
        MAE(scaler),
        RMSE(scaler),
        IndexMAE(scaler, [0, 1, 2]),
        IndexRMSE(scaler, [0, 1, 2]),
    ]


def main(args):
    """Train the model described by ``args.file`` and evaluate its best epoch.

    Args:
        args: namespace with ``file`` (path of the YAML settings), ``epochs``
            (number of epochs to train) and ``gpus`` (list of MXNet contexts).

    Raises:
        AssertionError: if the settings file name (without extension) does not
            end with the model name given in the settings.
    """
    with open(args.file, 'r') as f:
        # safe_load: plain yaml.load(f) without a Loader fails on PyYAML >= 6
        # and can construct arbitrary objects on older versions.
        settings = yaml.safe_load(f)

    # Strip the extension instead of assuming a 5-character '.yaml' suffix.
    file_stem = os.path.splitext(args.file)[0]
    assert file_stem.endswith(settings['model']['name']), \
        'The model name is not consistent! %s != %s' % (file_stem, settings['model']['name'])

    seed = settings['seed']
    mx.random.seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    dataset_setting = settings['dataset']
    model_setting = settings['model']
    train_setting = settings['training']

    # Set meta hiddens.
    if 'meta_hiddens' in model_setting:
        config.MODEL['meta_hiddens'] = model_setting['meta_hiddens']

    name = os.path.join(PARAM_PATH, model_setting['name'])
    model_type = getattr(model, model_setting['type'])
    net = model_type.net(settings)

    try:
        # Resume from the best saved epoch when a previous run exists.
        logger = Logger.load('%s.yaml' % name)
        net.load_parameters('%s-%04d.params' % (name, logger.best_epoch()), ctx=args.gpus)
        logger.set_net(net)
        print('Successfully loading the model %s [epoch: %d]' % (model_setting['name'], logger.best_epoch()))

        num_params = 0
        for v in net.collect_params().values():
            num_params += np.prod(v.shape)
        print(net.collect_params())
        print('NUMBER OF PARAMS:', num_params)
    except Exception as exc:
        # Best effort: any failure to restore falls back to a fresh model, but
        # KeyboardInterrupt / SystemExit are no longer swallowed and the
        # reason is reported.
        logging.info('Could not restore a saved model (%s); starting from scratch.', exc)
        logger = Logger(name, net, train_setting['early_stop_metric'], train_setting['early_stop_epoch'])
        net.initialize(init.Orthogonal(), ctx=args.gpus)
        print('Initialize the model')

    # net.hybridize()
    lr_scheduler = mx.lr_scheduler.FactorScheduler(
        step=train_setting['lr_decay_step'] * len(args.gpus),
        factor=train_setting['lr_decay_factor'],
        stop_factor_lr=1e-6,
    )
    optimizer = mx.optimizer.Adam(
        learning_rate=train_setting['lr'],
        multi_precision=True,
        lr_scheduler=lr_scheduler,
    )
    model_trainer = ModelTrainer(
        net=net,
        trainer=gluon.Trainer(
            net.collect_params(),
            optimizer,
            update_on_kvstore=False,
        ),
        clip_gradient=train_setting['clip_gradient'],
        logger=logger,
        ctx=args.gpus,
    )

    loader_factory = getattr(data.dataloader, dataset_setting['dataloader'])
    train_loader, eval_loader, test_loader, scaler = loader_factory(settings)

    model_trainer.fit(
        begin_epoch=logger.best_epoch(),
        num_epochs=args.epochs,
        train=train_loader,
        eval=eval_loader,
        test=test_loader,
        metrics=_build_metrics(scaler),
    )

    # Reload the best parameters and report them with a single eval/test pass.
    net.load_parameters('%s-%04d.params' % (name, logger.best_epoch()), ctx=args.gpus)
    model_trainer.fit(
        begin_epoch=0,
        num_epochs=1,
        train=None,
        eval=eval_loader,
        test=test_loader,
        metrics=_build_metrics(scaler),
    )


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # All three options are needed below; fail early with a usage message
    # instead of crashing later on a None value.
    parser.add_argument('--file', type=str, required=True)
    parser.add_argument('--epochs', type=int, required=True)
    parser.add_argument('--gpus', type=str, required=True)
    args = parser.parse_args()

    # '--gpus 0,1' becomes [mx.gpu(0), mx.gpu(1)].
    args.gpus = [mx.gpu(int(i)) for i in args.gpus.split(',')]
    main(args)