"""Train or evaluate a GluonCV semantic-segmentation model with MXNet Gluon."""
import argparse
import os
import shutil

from tqdm import tqdm

import mxnet as mx
from mxnet import gluon, autograd
from mxnet.gluon.data.vision import transforms

import gluoncv
from gluoncv.loss import MixSoftmaxCrossEntropyLoss
from gluoncv.utils import LRScheduler
from gluoncv.model_zoo.segbase import get_segmentation_model, SegEvalModel
from gluoncv.model_zoo import get_model
from gluoncv.utils.parallel import DataParallelModel, DataParallelCriterion
from gluoncv.data import get_segmentation_dataset


def parse_args():
    """Training Options for Segmentation Experiments.

    Parses the command line and then augments the resulting namespace with
    derived attributes:

    * ``ctx`` -- list of MXNet contexts (a single CPU context when
      ``--no-cuda`` is given, otherwise one GPU context per ``--ngpus``).
    * ``norm_layer`` / ``norm_kwargs`` -- the normalization layer class and
      its constructor kwargs, selected by ``--syncbn``.

    With ``--no-cuda`` the ``kvstore`` option is forced to ``'local'``.

    Returns:
        argparse.Namespace: the parsed and augmented options.
    """
    parser = argparse.ArgumentParser(description='MXNet Gluon Segmentation')
    parser.add_argument('--model', type=str, default='fcn',
                        help='model name (default: fcn)')
    parser.add_argument('--backbone', type=str, default='resnet50',
                        help='backbone name (default: resnet50)')
    parser.add_argument('--dataset', type=str, default='pascalaug',
                        help='dataset name (default: pascalaug)')
    parser.add_argument('--dataset-dir', type=str,
                        default='../imgclsmob_data/voc',
                        help='dataset path')
    parser.add_argument('--workers', type=int, default=16, metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base-size', type=int, default=520,
                        help='base image size')
    parser.add_argument('--crop-size', type=int, default=480,
                        help='crop image size')
    parser.add_argument('--train-split', type=str, default='train',
                        help='dataset train split (default: train)')
    parser.add_argument('--aux', action='store_true', default=False,
                        help='Auxiliary loss')
    parser.add_argument('--aux-weight', type=float, default=0.5,
                        help='auxiliary loss weight')
    parser.add_argument('--epochs', type=int, default=50, metavar='N',
                        help='number of epochs to train (default: 50)')
    # '--start_epoch' is the historical spelling; '--start-epoch' is accepted
    # as an alias so the flag matches the hyphenated style of the others.
    parser.add_argument('--start_epoch', '--start-epoch', dest='start_epoch',
                        type=int, default=0, metavar='N',
                        help='start epochs (default:0)')
    parser.add_argument('--batch-size', type=int, default=16, metavar='N',
                        help='input batch size for training (default: 16)')
    parser.add_argument('--test-batch-size', type=int, default=16,
                        metavar='N',
                        help='input batch size for testing (default: 16)')
    parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                        help='learning rate (default: 1e-3)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-4,
                        metavar='M',
                        help='w-decay (default: 1e-4)')
    parser.add_argument('--no-wd', action='store_true',
                        help='whether to remove weight decay on bias, and beta/gamma for batchnorm layers.')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--ngpus', type=int,
                        default=len(mx.test_utils.list_gpus()),
                        help='number of GPUs (default: all detected GPUs)')
    parser.add_argument('--kvstore', type=str, default='device',
                        help='kvstore to use for trainer/module.')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for training. default is float32')
    # checking point
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname', type=str, default='default',
                        help='set the checkpoint name')
    parser.add_argument('--model-zoo', type=str, default=None,
                        help='evaluating on model zoo model')
    # evaluation only
    parser.add_argument('--eval', action='store_true', default=False,
                        help='evaluation only')
    parser.add_argument('--no-val', action='store_true', default=False,
                        help='skip validation during training')
    # synchronized Batch Normalization
    parser.add_argument('--syncbn', action='store_true', default=False,
                        help='using Synchronized Cross-GPU BatchNorm')
    # the parser
    args = parser.parse_args()
    # handle contexts
    if args.no_cuda:
        print('Using CPU')
        args.kvstore = 'local'
        args.ctx = [mx.cpu(0)]
    else:
        print('Number of GPUs:', args.ngpus)
        args.ctx = [mx.gpu(i) for i in range(args.ngpus)]
    # Synchronized BatchNorm
    args.norm_layer = mx.gluon.contrib.nn.SyncBatchNorm if args.syncbn \
        else mx.gluon.nn.BatchNorm
    args.norm_kwargs = {'num_devices': args.ngpus} if args.syncbn else {}
    print(args)
    return args


class Trainer(object):
    """Bundle the data loaders, network, loss, optimizer and metric.

    All configuration is read from the namespace produced by ``parse_args``
    and kept on ``self.args``; the methods never rely on a module-level
    ``args`` so the class can be used when this file is imported.
    """

    def __init__(self, args):
        """Build every object needed for training and validation.

        Args:
            args: options namespace as returned by ``parse_args`` (must
                carry the derived ``ctx``, ``norm_layer`` and
                ``norm_kwargs`` attributes).

        Raises:
            RuntimeError: if ``args.resume`` is set but is not an existing
                file.
        """
        self.args = args
        # image transform
        input_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
        ])
        # dataset and dataloader
        data_kwargs = {
            'transform': input_transform,
            'base_size': args.base_size,
            'crop_size': args.crop_size,
            'root': args.dataset_dir}
        trainset = get_segmentation_dataset(
            args.dataset, split=args.train_split, mode='train', **data_kwargs)
        valset = get_segmentation_dataset(
            args.dataset, split='val', mode='val', **data_kwargs)
        self.train_data = gluon.data.DataLoader(
            trainset, args.batch_size, shuffle=True, last_batch='rollover',
            num_workers=args.workers)
        self.eval_data = gluon.data.DataLoader(
            valset, args.test_batch_size, last_batch='rollover',
            num_workers=args.workers)
        # create network
        if args.model_zoo is not None:
            model = get_model(args.model_zoo, pretrained=True)
        else:
            model = get_segmentation_model(
                model=args.model, dataset=args.dataset,
                backbone=args.backbone, norm_layer=args.norm_layer,
                norm_kwargs=args.norm_kwargs, aux=args.aux,
                crop_size=args.crop_size)
        model.cast(args.dtype)
        print(model)
        # The training wrapper and the evaluation wrapper share one model.
        self.net = DataParallelModel(model, args.ctx, args.syncbn)
        self.evaluator = DataParallelModel(SegEvalModel(model), args.ctx)
        # resume checkpoint if needed
        if args.resume is not None:
            if os.path.isfile(args.resume):
                model.load_parameters(args.resume, ctx=args.ctx)
            else:
                raise RuntimeError(
                    "=> no checkpoint found at '{}'".format(args.resume))
        # create criterion
        criterion = MixSoftmaxCrossEntropyLoss(
            args.aux, aux_weight=args.aux_weight)
        self.criterion = DataParallelCriterion(
            criterion, args.ctx, args.syncbn)
        # optimizer and lr scheduling
        self.lr_scheduler = LRScheduler(
            mode='poly', base_lr=args.lr, nepochs=args.epochs,
            iters_per_epoch=len(self.train_data), power=0.9)
        kv = mx.kv.create(args.kvstore)
        optimizer_params = {
            'lr_scheduler': self.lr_scheduler,
            'wd': args.weight_decay,
            'momentum': args.momentum}
        if args.dtype == 'float16':
            optimizer_params['multi_precision'] = True
        if args.no_wd:
            # Disable weight decay on the parameters matched by the pattern.
            no_decay = self.net.module.collect_params('.*beta|.*gamma|.*bias')
            for param in no_decay.values():
                param.wd_mult = 0.0
        self.optimizer = gluon.Trainer(
            self.net.module.collect_params(), 'sgd', optimizer_params,
            kvstore=kv)
        # evaluation metrics
        self.metric = gluoncv.utils.metrics.SegmentationMetric(
            trainset.num_class)

    def training(self, epoch):
        """Run one pass over ``self.train_data`` and save a checkpoint.

        Args:
            epoch: epoch index; only used in the progress-bar description.
        """
        args = self.args
        tbar = tqdm(self.train_data)
        train_loss = 0.0
        for i, (data, target) in enumerate(tbar):
            with autograd.record(True):
                outputs = self.net(data.astype(args.dtype, copy=False))
                losses = self.criterion(outputs, target)
                mx.nd.waitall()
                autograd.backward(losses)
            self.optimizer.step(args.batch_size)
            # Accumulate the mean over the entries of `losses`
            # (presumably one entry per context -- confirm against
            # DataParallelCriterion).
            for loss in losses:
                train_loss += loss.asnumpy()[0] / len(losses)
            tbar.set_description(
                'Epoch {}, training loss {}'.format(
                    epoch, train_loss / (i + 1)))
            mx.nd.waitall()
        # save every epoch
        save_checkpoint(self.net.module, args, False)

    def validation(self, epoch):
        """Evaluate on ``self.eval_data`` and show pixAcc / mIoU.

        The metric is reset first and the running values are shown in the
        progress bar after every batch.

        Args:
            epoch: epoch index; only used in the progress-bar description.
        """
        args = self.args
        self.metric.reset()
        tbar = tqdm(self.eval_data)
        for i, (data, target) in enumerate(tbar):
            outputs = self.evaluator(data.astype(args.dtype, copy=False))
            outputs = [x[0] for x in outputs]
            targets = mx.gluon.utils.split_and_load(
                target, args.ctx, even_split=False)
            self.metric.update(targets, outputs)
            pixAcc, mIoU = self.metric.get()
            tbar.set_description(
                'Epoch {}, validation pixAcc: {}, mIoU: {}'.format(
                    epoch, pixAcc, mIoU))
            mx.nd.waitall()


def save_checkpoint(net, args, is_best=False):
    """Save Checkpoint.

    Writes the parameters of ``net`` to ``checkpoint.params`` inside
    ``../imgclsmob_data/<dataset>/<model>/<checkname>/`` (created when
    missing) and, when ``is_best`` is true, copies that file to
    ``model_best.params`` in the same directory.

    Args:
        net: object exposing ``save_parameters(filename)``.
        args: options namespace providing ``dataset``, ``model`` and
            ``checkname``.
        is_best: whether to also store a ``model_best.params`` copy.
    """
    directory = "../imgclsmob_data/{}/{}/{}/".format(
        args.dataset, args.model, args.checkname)
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(directory, exist_ok=True)
    filename = os.path.join(directory, 'checkpoint.params')
    net.save_parameters(filename)
    if is_best:
        shutil.copyfile(
            filename, os.path.join(directory, 'model_best.params'))


if __name__ == "__main__":
    args = parse_args()
    trainer = Trainer(args)
    if args.eval:
        print('Evaluating model: ', args.resume)
        trainer.validation(args.start_epoch)
    else:
        print('Starting Epoch:', args.start_epoch)
        print('Total Epochs:', args.epochs)
        for epoch in range(args.start_epoch, args.epochs):
            trainer.training(epoch)
            if not trainer.args.no_val:
                trainer.validation(epoch)