""" Script for training model on MXNet/Gluon. """ import argparse import time import logging import os import random import numpy as np import mxnet as mx from mxnet import gluon from mxnet import autograd as ag from common.logger_utils import initialize_logging from common.train_log_param_saver import TrainLogParamSaver from gluon.lr_scheduler import LRScheduler from gluon.utils import prepare_mx_context, prepare_model, validate from gluon.utils import report_accuracy, get_composite_metric, get_metric_name, get_initializer, get_loss from gluon.dataset_utils import get_dataset_metainfo from gluon.dataset_utils import get_train_data_source, get_val_data_source from gluon.dataset_utils import get_batch_fn def add_train_cls_parser_arguments(parser): """ Create python script parameters (for training/classification specific subpart). Parameters: ---------- parser : ArgumentParser ArgumentParser instance. """ parser.add_argument( "--model", type=str, required=True, help="type of model to use. see model_provider for options") parser.add_argument( "--use-pretrained", action="store_true", help="enable using pretrained model from github repo") parser.add_argument( "--dtype", type=str, default="float32", help="data type for training") parser.add_argument( '--not-hybridize', action='store_true', help='do not hybridize model') parser.add_argument( "--resume", type=str, default="", help="resume from previously saved parameters if not None") parser.add_argument( "--resume-state", type=str, default="", help="resume from previously saved optimizer state if not None") parser.add_argument( "--initializer", type=str, default="MSRAPrelu", help="initializer name. options are MSRAPrelu, Xavier and Xavier-gaussian-out-2") parser.add_argument( "--num-gpus", type=int, default=0, help="number of gpus to use") parser.add_argument( "-j", "--num-data-workers", dest="num_workers", default=4, type=int, help="number of preprocessing workers") parser.add_argument( "--batch-size", type=int, default=512, help="training batch size per device (CPU/GPU)") parser.add_argument( "--batch-size-scale", type=int, default=1, help="manual batch-size increasing factor") parser.add_argument( "--num-epochs", type=int, default=120, help="number of training epochs") parser.add_argument( "--start-epoch", type=int, default=1, help="starting epoch for resuming, default is 1 for new training") parser.add_argument( "--attempt", type=int, default=1, help="current attempt number for training") parser.add_argument( "--optimizer-name", type=str, default="nag", help="optimizer name") parser.add_argument( "--lr", type=float, default=0.1, help="learning rate") parser.add_argument( "--lr-mode", type=str, default="cosine", help="learning rate scheduler mode. options are step, poly and cosine") parser.add_argument( "--lr-decay", type=float, default=0.1, help="decay rate of learning rate") parser.add_argument( "--lr-decay-period", type=int, default=0, help="interval for periodic learning rate decays. 
default is 0 to disable") parser.add_argument( "--lr-decay-epoch", type=str, default="40,60", help="epoches at which learning rate decays") parser.add_argument( "--target-lr", type=float, default=1e-8, help="ending learning rate") parser.add_argument( "--poly-power", type=float, default=2, help="power value for poly LR scheduler") parser.add_argument( "--warmup-epochs", type=int, default=0, help="number of warmup epochs") parser.add_argument( "--warmup-lr", type=float, default=1e-8, help="starting warmup learning rate") parser.add_argument( "--warmup-mode", type=str, default="linear", help="learning rate scheduler warmup mode. options are linear, poly and constant") parser.add_argument( "--momentum", type=float, default=0.9, help="momentum value for optimizer") parser.add_argument( "--wd", type=float, default=0.0001, help="weight decay rate") parser.add_argument( "--gamma-wd-mult", type=float, default=1.0, help="weight decay multiplier for batchnorm gamma") parser.add_argument( "--beta-wd-mult", type=float, default=1.0, help="weight decay multiplier for batchnorm beta") parser.add_argument( "--bias-wd-mult", type=float, default=1.0, help="weight decay multiplier for bias") parser.add_argument( "--grad-clip", type=float, default=None, help="max_norm for gradient clipping") parser.add_argument( "--label-smoothing", action="store_true", help="use label smoothing") parser.add_argument( "--mixup", action="store_true", help="use mixup strategy") parser.add_argument( "--mixup-epoch-tail", type=int, default=12, help="number of epochs without mixup at the end of training") parser.add_argument( "--log-interval", type=int, default=50, help="number of batches to wait before logging") parser.add_argument( "--save-interval", type=int, default=4, help="saving parameters epoch interval, best model will always be saved") parser.add_argument( "--save-dir", type=str, default="", help="directory of saved models and log-files") parser.add_argument( "--logging-file-name", type=str, default="train.log", help="filename of training log") parser.add_argument( "--seed", type=int, default=-1, help="random seed to be fixed") parser.add_argument( "--log-packages", type=str, default="mxnet, numpy", help="list of python packages for logging") parser.add_argument( "--log-pip-packages", type=str, default="mxnet-cu101", help="list of pip packages for logging") parser.add_argument( "--tune-layers", type=str, default="", help="regexp for selecting layers for fine tuning") def parse_args(): """ Parse python script parameters (common part). Returns ------- ArgumentParser Resulted args. """ parser = argparse.ArgumentParser( description="Train a model for image classification (Gluon)", formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument( "--dataset", type=str, default="ImageNet1K_rec", help="dataset name. options are ImageNet1K, ImageNet1K_rec, CUB200_2011, CIFAR10, CIFAR100, SVHN") parser.add_argument( "--work-dir", type=str, default=os.path.join("..", "imgclsmob_data"), help="path to working directory only for dataset root path preset") args, _ = parser.parse_known_args() dataset_metainfo = get_dataset_metainfo(dataset_name=args.dataset) dataset_metainfo.add_dataset_parser_arguments( parser=parser, work_dir_path=args.work_dir) add_train_cls_parser_arguments(parser) args = parser.parse_args() return args def init_rand(seed): """ Initialize all random generators by seed. Parameters: ---------- seed : int Seed value. Returns ------- int Generated seed value. 
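# A hypothetical invocation for reference (the script file name, model, dataset,
# and paths below are placeholders, not the only valid choices; any option left
# out falls back to the defaults declared above):
#
#   python train_gluon.py --model resnet18 --dataset CIFAR10 --num-gpus 1 \
#       --batch-size 256 --num-epochs 200 --lr 0.1 --lr-mode cosine \
#       --save-dir ../imgclsmob_data/resnet18_cifar10 --seed 42
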
""" if seed <= 0: seed = np.random.randint(10000) random.seed(seed) np.random.seed(seed) mx.random.seed(seed) return seed def prepare_trainer(net, optimizer_name, wd, momentum, lr_mode, lr, lr_decay_period, lr_decay_epoch, lr_decay, target_lr, poly_power, warmup_epochs, warmup_lr, warmup_mode, batch_size, num_epochs, num_training_samples, dtype, gamma_wd_mult=1.0, beta_wd_mult=1.0, bias_wd_mult=1.0, state_file_path=None): """ Prepare trainer. Parameters: ---------- net : HybridBlock Model. optimizer_name : str Name of optimizer. wd : float Weight decay rate. momentum : float Momentum value. lr_mode : str Learning rate scheduler mode. lr : float Learning rate. lr_decay_period : int Interval for periodic learning rate decays. lr_decay_epoch : str Epoches at which learning rate decays. lr_decay : float Decay rate of learning rate. target_lr : float Final learning rate. poly_power : float Power value for poly LR scheduler. warmup_epochs : int Number of warmup epochs. warmup_lr : float Starting warmup learning rate. warmup_mode : str Learning rate scheduler warmup mode. batch_size : int Training batch size. num_epochs : int Number of training epochs. num_training_samples : int Number of training samples in dataset. dtype : str Base data type for tensors. gamma_wd_mult : float Weight decay multiplier for batchnorm gamma. beta_wd_mult : float Weight decay multiplier for batchnorm beta. bias_wd_mult : float Weight decay multiplier for bias. state_file_path : str, default None Path for file with trainer state. Returns ------- Trainer Trainer. LRScheduler Learning rate scheduler. """ if gamma_wd_mult != 1.0: for k, v in net.collect_params(".*gamma").items(): v.wd_mult = gamma_wd_mult if beta_wd_mult != 1.0: for k, v in net.collect_params(".*beta").items(): v.wd_mult = beta_wd_mult if bias_wd_mult != 1.0: for k, v in net.collect_params(".*bias").items(): v.wd_mult = bias_wd_mult if lr_decay_period > 0: lr_decay_epoch = list(range(lr_decay_period, num_epochs, lr_decay_period)) else: lr_decay_epoch = [int(i) for i in lr_decay_epoch.split(",")] num_batches = num_training_samples // batch_size lr_scheduler = LRScheduler( mode=lr_mode, base_lr=lr, n_iters=num_batches, n_epochs=num_epochs, step=lr_decay_epoch, step_factor=lr_decay, target_lr=target_lr, power=poly_power, warmup_epochs=warmup_epochs, warmup_lr=warmup_lr, warmup_mode=warmup_mode) optimizer_params = {"learning_rate": lr, "wd": wd, "momentum": momentum, "lr_scheduler": lr_scheduler} if dtype != "float32": optimizer_params["multi_precision"] = True trainer = gluon.Trainer( params=net.collect_params(), optimizer=optimizer_name, optimizer_params=optimizer_params) if (state_file_path is not None) and state_file_path and os.path.exists(state_file_path): logging.info("Loading trainer states: {}".format(state_file_path)) trainer.load_states(state_file_path) if trainer._optimizer.wd != wd: trainer._optimizer.wd = wd logging.info("Reset the weight decay: {}".format(wd)) # lr_scheduler = trainer._optimizer.lr_scheduler trainer._optimizer.lr_scheduler = lr_scheduler return trainer, lr_scheduler def save_params(file_stem, net, trainer): """ Save current model/trainer parameters. Parameters: ---------- file_stem : str File stem (with path). net : HybridBlock Model. trainer : Trainer Trainer. 
""" net.save_parameters(file_stem + ".params") trainer.save_states(file_stem + ".states") def train_epoch(epoch, net, train_metric, train_data, batch_fn, data_source_needs_reset, dtype, ctx, loss_func, trainer, lr_scheduler, batch_size, log_interval, mixup, mixup_epoch_tail, label_smoothing, num_classes, num_epochs, grad_clip_value, batch_size_scale): """ Train model on particular epoch. Parameters: ---------- epoch : int Epoch number. net : HybridBlock Model. train_metric : EvalMetric Metric object instance. train_data : DataLoader or ImageRecordIter Data loader or ImRec-iterator. batch_fn : func Function for splitting data after extraction from data loader. data_source_needs_reset : bool Whether to reset data (if test_data is ImageRecordIter). dtype : str Base data type for tensors. ctx : Context MXNet context. loss_func : Loss Loss function. trainer : Trainer Trainer. lr_scheduler : LRScheduler Learning rate scheduler. batch_size : int Training batch size. log_interval : int Batch count period for logging. mixup : bool Whether to use mixup. mixup_epoch_tail : int Number of epochs without mixup at the end of training. label_smoothing : bool Whether to use label-smoothing. num_classes : int Number of model classes. num_epochs : int Number of training epochs. grad_clip_value : float Threshold for gradient clipping. batch_size_scale : int Manual batch-size increasing factor. Returns ------- float Loss value. """ labels_list_inds = None batch_size_extend_count = 0 tic = time.time() if data_source_needs_reset: train_data.reset() train_metric.reset() train_loss = 0.0 i = 0 btic = time.time() for i, batch in enumerate(train_data): data_list, labels_list = batch_fn(batch, ctx) if label_smoothing: eta = 0.1 on_value = 1 - eta + eta / num_classes off_value = eta / num_classes labels_list_inds = labels_list labels_list = [Y.one_hot(depth=num_classes, on_value=on_value, off_value=off_value) for Y in labels_list] if mixup: if not label_smoothing: labels_list_inds = labels_list labels_list = [Y.one_hot(depth=num_classes) for Y in labels_list] if epoch < num_epochs - mixup_epoch_tail: alpha = 1 lam = np.random.beta(alpha, alpha) data_list = [lam * X + (1 - lam) * X[::-1] for X in data_list] labels_list = [lam * Y + (1 - lam) * Y[::-1] for Y in labels_list] with ag.record(): outputs_list = [net(X.astype(dtype, copy=False)) for X in data_list] loss_list = [loss_func(yhat, y.astype(dtype, copy=False)) for yhat, y in zip(outputs_list, labels_list)] for loss in loss_list: loss.backward() lr_scheduler.update(i, epoch) if grad_clip_value is not None: grads = [v.grad(ctx[0]) for v in net.collect_params().values() if v._grad is not None] gluon.utils.clip_global_norm(grads, max_norm=grad_clip_value) if batch_size_scale == 1: trainer.step(batch_size) else: if (i + 1) % batch_size_scale == 0: batch_size_extend_count = 0 trainer.step(batch_size * batch_size_scale) for p in net.collect_params().values(): p.zero_grad() else: batch_size_extend_count += 1 train_loss += sum([loss.mean().asscalar() for loss in loss_list]) / len(loss_list) train_metric.update( labels=(labels_list if not (mixup or label_smoothing) else labels_list_inds), preds=outputs_list) if log_interval and not (i + 1) % log_interval: speed = batch_size * log_interval / (time.time() - btic) btic = time.time() train_accuracy_msg = report_accuracy(metric=train_metric) logging.info("Epoch[{}] Batch [{}]\tSpeed: {:.2f} samples/sec\t{}\tlr={:.5f}".format( epoch + 1, i, speed, train_accuracy_msg, trainer.learning_rate)) if (batch_size_scale != 1) and 
(batch_size_extend_count > 0): trainer.step(batch_size * batch_size_extend_count) for p in net.collect_params().values(): p.zero_grad() throughput = int(batch_size * (i + 1) / (time.time() - tic)) logging.info("[Epoch {}] speed: {:.2f} samples/sec\ttime cost: {:.2f} sec".format( epoch + 1, throughput, time.time() - tic)) train_loss /= (i + 1) train_accuracy_msg = report_accuracy(metric=train_metric) logging.info("[Epoch {}] training: {}\tloss={:.4f}".format( epoch + 1, train_accuracy_msg, train_loss)) return train_loss def train_net(batch_size, num_epochs, start_epoch1, train_data, val_data, batch_fn, data_source_needs_reset, dtype, net, trainer, lr_scheduler, lp_saver, log_interval, mixup, mixup_epoch_tail, label_smoothing, num_classes, grad_clip_value, batch_size_scale, val_metric, train_metric, loss_func, ctx): """ Main procedure for training model. Parameters: ---------- batch_size : int Training batch size. num_epochs : int Number of training epochs. start_epoch1 : int Number of starting epoch (1-based). train_data : DataLoader or ImageRecordIter Data loader or ImRec-iterator (training subset). val_data : DataLoader or ImageRecordIter Data loader or ImRec-iterator (validation subset). batch_fn : func Function for splitting data after extraction from data loader. data_source_needs_reset : bool Whether to reset data (if test_data is ImageRecordIter). dtype : str Base data type for tensors. net : HybridBlock Model. trainer : Trainer Trainer. lr_scheduler : LRScheduler Learning rate scheduler. lp_saver : TrainLogParamSaver Model/trainer state saver. log_interval : int Batch count period for logging. mixup : bool Whether to use mixup. mixup_epoch_tail : int Number of epochs without mixup at the end of training. label_smoothing : bool Whether to use label-smoothing. num_classes : int Number of model classes. grad_clip_value : float Threshold for gradient clipping. batch_size_scale : int Manual batch-size increasing factor. val_metric : EvalMetric Metric object instance (validation subset). train_metric : EvalMetric Metric object instance (training subset). loss_func : Loss Loss object instance. ctx : Context MXNet context. 
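# The batch_size_scale > 1 path above is plain gradient accumulation. A minimal
# standalone sketch of the idea, with hypothetical `net`, `loss_func`, `data`,
# `trainer`, `scale` and `batch_size` names (not this script's API):
#
#   for p in net.collect_params().values():
#       p.grad_req = "add"                     # accumulate instead of overwrite
#   for i, (X, y) in enumerate(data):
#       with ag.record():
#           loss = loss_func(net(X), y)
#       loss.backward()                        # grads add up across iterations
#       if (i + 1) % scale == 0:
#           trainer.step(batch_size * scale)   # normalize by effective batch
#           for p in net.collect_params().values():
#               p.zero_grad()                  # reset before next accumulation
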
""" if batch_size_scale != 1: for p in net.collect_params().values(): p.grad_req = "add" if isinstance(ctx, mx.Context): ctx = [ctx] # loss_func = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=(not (mixup or label_smoothing))) assert (type(start_epoch1) == int) assert (start_epoch1 >= 1) if start_epoch1 > 1: logging.info("Start training from [Epoch {}]".format(start_epoch1)) validate( metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx) val_accuracy_msg = report_accuracy(metric=val_metric) logging.info("[Epoch {}] validation: {}".format(start_epoch1 - 1, val_accuracy_msg)) gtic = time.time() for epoch in range(start_epoch1 - 1, num_epochs): train_loss = train_epoch( epoch=epoch, net=net, train_metric=train_metric, train_data=train_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx, loss_func=loss_func, trainer=trainer, lr_scheduler=lr_scheduler, batch_size=batch_size, log_interval=log_interval, mixup=mixup, mixup_epoch_tail=mixup_epoch_tail, label_smoothing=label_smoothing, num_classes=num_classes, num_epochs=num_epochs, grad_clip_value=grad_clip_value, batch_size_scale=batch_size_scale) validate( metric=val_metric, net=net, val_data=val_data, batch_fn=batch_fn, data_source_needs_reset=data_source_needs_reset, dtype=dtype, ctx=ctx) val_accuracy_msg = report_accuracy(metric=val_metric) logging.info("[Epoch {}] validation: {}".format(epoch + 1, val_accuracy_msg)) if lp_saver is not None: lp_saver_kwargs = {"net": net, "trainer": trainer} val_acc_values = val_metric.get()[1] train_acc_values = train_metric.get()[1] val_acc_values = val_acc_values if type(val_acc_values) == list else [val_acc_values] train_acc_values = train_acc_values if type(train_acc_values) == list else [train_acc_values] lp_saver.epoch_test_end_callback( epoch1=(epoch + 1), params=(val_acc_values + train_acc_values + [train_loss, trainer.learning_rate]), **lp_saver_kwargs) logging.info("Total time cost: {:.2f} sec".format(time.time() - gtic)) if lp_saver is not None: opt_metric_name = get_metric_name(val_metric, lp_saver.acc_ind) logging.info("Best {}: {:.4f} at {} epoch".format( opt_metric_name, lp_saver.best_eval_metric_value, lp_saver.best_eval_metric_epoch)) def main(): """ Main body of script. 
""" args = parse_args() args.seed = init_rand(seed=args.seed) _, log_file_exist = initialize_logging( logging_dir_path=args.save_dir, logging_file_name=args.logging_file_name, script_args=args, log_packages=args.log_packages, log_pip_packages=args.log_pip_packages) ctx, batch_size = prepare_mx_context( num_gpus=args.num_gpus, batch_size=args.batch_size) ds_metainfo = get_dataset_metainfo(dataset_name=args.dataset) ds_metainfo.update(args=args) net = prepare_model( model_name=args.model, use_pretrained=args.use_pretrained, pretrained_model_file_path=args.resume.strip(), dtype=args.dtype, net_extra_kwargs=ds_metainfo.train_net_extra_kwargs, tune_layers=args.tune_layers, classes=args.num_classes, in_channels=args.in_channels, do_hybridize=(not args.not_hybridize), initializer=get_initializer(initializer_name=args.initializer), ctx=ctx) assert (hasattr(net, "classes")) num_classes = net.classes train_data = get_train_data_source( ds_metainfo=ds_metainfo, batch_size=batch_size, num_workers=args.num_workers) val_data = get_val_data_source( ds_metainfo=ds_metainfo, batch_size=batch_size, num_workers=args.num_workers) batch_fn = get_batch_fn(ds_metainfo=ds_metainfo) num_training_samples = len(train_data._dataset) if not ds_metainfo.use_imgrec else ds_metainfo.num_training_samples trainer, lr_scheduler = prepare_trainer( net=net, optimizer_name=args.optimizer_name, wd=args.wd, momentum=args.momentum, lr_mode=args.lr_mode, lr=args.lr, lr_decay_period=args.lr_decay_period, lr_decay_epoch=args.lr_decay_epoch, lr_decay=args.lr_decay, target_lr=args.target_lr, poly_power=args.poly_power, warmup_epochs=args.warmup_epochs, warmup_lr=args.warmup_lr, warmup_mode=args.warmup_mode, batch_size=batch_size, num_epochs=args.num_epochs, num_training_samples=num_training_samples, dtype=args.dtype, gamma_wd_mult=args.gamma_wd_mult, beta_wd_mult=args.beta_wd_mult, bias_wd_mult=args.bias_wd_mult, state_file_path=args.resume_state) if args.save_dir and args.save_interval: param_names = ds_metainfo.val_metric_capts + ds_metainfo.train_metric_capts + ["Train.Loss", "LR"] lp_saver = TrainLogParamSaver( checkpoint_file_name_prefix="{}_{}".format(ds_metainfo.short_label, args.model), last_checkpoint_file_name_suffix="last", best_checkpoint_file_name_suffix=None, last_checkpoint_dir_path=args.save_dir, best_checkpoint_dir_path=None, last_checkpoint_file_count=2, best_checkpoint_file_count=2, checkpoint_file_save_callback=save_params, checkpoint_file_exts=(".params", ".states"), save_interval=args.save_interval, num_epochs=args.num_epochs, param_names=param_names, acc_ind=ds_metainfo.saver_acc_ind, # bigger=[True], # mask=None, score_log_file_path=os.path.join(args.save_dir, "score.log"), score_log_attempt_value=args.attempt, best_map_log_file_path=os.path.join(args.save_dir, "best_map.log")) else: lp_saver = None val_metric = get_composite_metric(ds_metainfo.val_metric_names, ds_metainfo.val_metric_extra_kwargs) train_metric = get_composite_metric(ds_metainfo.train_metric_names, ds_metainfo.train_metric_extra_kwargs) loss_kwargs = {"sparse_label": not (args.mixup or args.label_smoothing)} if ds_metainfo.loss_extra_kwargs is not None: loss_kwargs.update(ds_metainfo.loss_extra_kwargs) loss_func = get_loss(ds_metainfo.loss_name, loss_kwargs) train_net( batch_size=batch_size, num_epochs=args.num_epochs, start_epoch1=args.start_epoch, train_data=train_data, val_data=val_data, batch_fn=batch_fn, data_source_needs_reset=ds_metainfo.use_imgrec, dtype=args.dtype, net=net, trainer=trainer, lr_scheduler=lr_scheduler, lp_saver=lp_saver, 
log_interval=args.log_interval, mixup=args.mixup, mixup_epoch_tail=args.mixup_epoch_tail, label_smoothing=args.label_smoothing, num_classes=num_classes, grad_clip_value=args.grad_clip, batch_size_scale=args.batch_size_scale, val_metric=val_metric, train_metric=train_metric, loss_func=loss_func, ctx=ctx) if __name__ == "__main__": main()