from __future__ import division import os import sys import time import glob import logging from tqdm import tqdm import torch import torch.nn as nn import torch.utils from tensorboardX import SummaryWriter import numpy as np import matplotlib # Force matplotlib to not use any Xwindows backend. matplotlib.use('Agg') from matplotlib import pyplot as plt from thop import profile from config_search import config from dataloader import get_train_loader from datasets import Cityscapes from utils.init_func import init_weight from seg_opr.loss_opr import ProbOhemCrossEntropy2d from eval import SegEvaluator from architect import Architect from utils.darts_utils import create_exp_dir, save, plot_op, plot_path_width, objective_acc_lat from model_search import Network_Multi_Path as Network from model_seg import Network_Multi_Path_Infer def main(pretrain=True): config.save = 'search-{}-{}'.format(config.save, time.strftime("%Y%m%d-%H%M%S")) create_exp_dir(config.save, scripts_to_save=glob.glob('*.py')+glob.glob('*.sh')) logger = SummaryWriter(config.save) log_format = '%(asctime)s %(message)s' logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p') fh = logging.FileHandler(os.path.join(config.save, 'log.txt')) fh.setFormatter(logging.Formatter(log_format)) logging.getLogger().addHandler(fh) assert type(pretrain) == bool or type(pretrain) == str update_arch = True if pretrain == True: update_arch = False logging.info("args = %s", str(config)) # preparation ################ torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True seed = config.seed np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) # config network and criterion ################ min_kept = int(config.batch_size * config.image_height * config.image_width // (16 * config.gt_down_sampling ** 2)) ohem_criterion = ProbOhemCrossEntropy2d(ignore_label=255, thresh=0.7, min_kept=min_kept, use_weight=False) # Model ####################################### model = Network(config.num_classes, config.layers, ohem_criterion, Fch=config.Fch, width_mult_list=config.width_mult_list, prun_modes=config.prun_modes, stem_head_width=config.stem_head_width) flops, params = profile(model, inputs=(torch.randn(1, 3, 1024, 2048),), verbose=False) logging.info("params = %fMB, FLOPs = %fGB", params / 1e6, flops / 1e9) model = model.cuda() if type(pretrain) == str: partial = torch.load(pretrain + "/weights.pt", map_location='cuda:0') state = model.state_dict() pretrained_dict = {k: v for k, v in partial.items() if k in state and state[k].size() == partial[k].size()} state.update(pretrained_dict) model.load_state_dict(state) else: init_weight(model, nn.init.kaiming_normal_, nn.BatchNorm2d, config.bn_eps, config.bn_momentum, mode='fan_in', nonlinearity='relu') device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) architect = Architect(model, config) # Optimizer ################################### base_lr = config.lr parameters = [] parameters += list(model.stem.parameters()) parameters += list(model.cells.parameters()) parameters += list(model.refine32.parameters()) parameters += list(model.refine16.parameters()) parameters += list(model.head0.parameters()) parameters += list(model.head1.parameters()) parameters += list(model.head2.parameters()) parameters += list(model.head02.parameters()) parameters += list(model.head12.parameters()) optimizer = torch.optim.SGD( parameters, lr=base_lr, momentum=config.momentum, weight_decay=config.weight_decay) # lr policy ############################## lr_policy = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.978) # data loader ########################### data_setting = {'img_root': config.img_root_folder, 'gt_root': config.gt_root_folder, 'train_source': config.train_source, 'eval_source': config.eval_source, 'down_sampling': config.down_sampling} train_loader_model = get_train_loader(config, Cityscapes, portion=config.train_portion) train_loader_arch = get_train_loader(config, Cityscapes, portion=config.train_portion-1) evaluator = SegEvaluator(Cityscapes(data_setting, 'val', None), config.num_classes, config.image_mean, config.image_std, model, config.eval_scale_array, config.eval_flip, 0, config=config, verbose=False, save_path=None, show_image=False) if update_arch: for idx in range(len(config.latency_weight)): logger.add_scalar("arch/latency_weight%d"%idx, config.latency_weight[idx], 0) logging.info("arch_latency_weight%d = "%idx + str(config.latency_weight[idx])) tbar = tqdm(range(config.nepochs), ncols=80) valid_mIoU_history = []; FPSs_history = []; latency_supernet_history = []; latency_weight_history = []; valid_names = ["8s", "16s", "32s", "8s_32s", "16s_32s"] arch_names = {0: "teacher", 1: "student"} for epoch in tbar: logging.info(pretrain) logging.info(config.save) logging.info("lr: " + str(optimizer.param_groups[0]['lr'])) logging.info("update arch: " + str(update_arch)) # training tbar.set_description("[Epoch %d/%d][train...]" % (epoch + 1, config.nepochs)) train(pretrain, train_loader_model, train_loader_arch, model, architect, ohem_criterion, optimizer, lr_policy, logger, epoch, update_arch=update_arch) torch.cuda.empty_cache() lr_policy.step() # validation tbar.set_description("[Epoch %d/%d][validation...]" % (epoch + 1, config.nepochs)) with torch.no_grad(): if pretrain == True: model.prun_mode = "min" valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False) for i in range(5): logger.add_scalar('mIoU/val_min_%s'%valid_names[i], valid_mIoUs[i], epoch) logging.info("Epoch %d: valid_mIoU_min_%s %.3f"%(epoch, valid_names[i], valid_mIoUs[i])) if len(model._width_mult_list) > 1: model.prun_mode = "max" valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False) for i in range(5): logger.add_scalar('mIoU/val_max_%s'%valid_names[i], valid_mIoUs[i], epoch) logging.info("Epoch %d: valid_mIoU_max_%s %.3f"%(epoch, valid_names[i], valid_mIoUs[i])) model.prun_mode = "random" valid_mIoUs = infer(epoch, model, evaluator, logger, FPS=False) for i in range(5): logger.add_scalar('mIoU/val_random_%s'%valid_names[i], valid_mIoUs[i], epoch) logging.info("Epoch %d: valid_mIoU_random_%s %.3f"%(epoch, valid_names[i], valid_mIoUs[i])) else: valid_mIoUss = []; FPSs = [] model.prun_mode = None for idx in range(len(model._arch_names)): # arch_idx model.arch_idx = idx valid_mIoUs, fps0, fps1 = infer(epoch, model, evaluator, logger) valid_mIoUss.append(valid_mIoUs) FPSs.append([fps0, fps1]) for i in range(5): # preds logger.add_scalar('mIoU/val_%s_%s'%(arch_names[idx], valid_names[i]), valid_mIoUs[i], epoch) logging.info("Epoch %d: valid_mIoU_%s_%s %.3f"%(epoch, arch_names[idx], valid_names[i], valid_mIoUs[i])) if config.latency_weight[idx] > 0: logger.add_scalar('Objective/val_%s_8s_32s'%arch_names[idx], objective_acc_lat(valid_mIoUs[3], 1000./fps0), epoch) logging.info("Epoch %d: Objective_%s_8s_32s %.3f"%(epoch, arch_names[idx], objective_acc_lat(valid_mIoUs[3], 1000./fps0))) logger.add_scalar('Objective/val_%s_16s_32s'%arch_names[idx], objective_acc_lat(valid_mIoUs[4], 1000./fps1), epoch) logging.info("Epoch %d: Objective_%s_16s_32s %.3f"%(epoch, arch_names[idx], objective_acc_lat(valid_mIoUs[4], 1000./fps1))) valid_mIoU_history.append(valid_mIoUss) FPSs_history.append(FPSs) if update_arch: latency_supernet_history.append(architect.latency_supernet) latency_weight_history.append(architect.latency_weight) save(model, os.path.join(config.save, 'weights.pt')) if type(pretrain) == str: # contains arch_param names: {"alphas": alphas, "betas": betas, "gammas": gammas, "ratios": ratios} for idx, arch_name in enumerate(model._arch_names): state = {} for name in arch_name['alphas']: state[name] = getattr(model, name) for name in arch_name['betas']: state[name] = getattr(model, name) for name in arch_name['ratios']: state[name] = getattr(model, name) state["mIoU02"] = valid_mIoUs[3] state["mIoU12"] = valid_mIoUs[4] if pretrain is not True: state["latency02"] = 1000. / fps0 state["latency12"] = 1000. / fps1 torch.save(state, os.path.join(config.save, "arch_%d_%d.pt"%(idx, epoch))) torch.save(state, os.path.join(config.save, "arch_%d.pt"%(idx))) if update_arch: for idx in range(len(config.latency_weight)): if config.latency_weight[idx] > 0: if (int(FPSs[idx][0] >= config.FPS_max[idx]) + int(FPSs[idx][1] >= config.FPS_max[idx])) >= 1: architect.latency_weight[idx] /= 2 elif (int(FPSs[idx][0] <= config.FPS_min[idx]) + int(FPSs[idx][1] <= config.FPS_min[idx])) > 0: architect.latency_weight[idx] *= 2 logger.add_scalar("arch/latency_weight_%s"%arch_names[idx], architect.latency_weight[idx], epoch+1) logging.info("arch_latency_weight_%s = "%arch_names[idx] + str(architect.latency_weight[idx])) def train(pretrain, train_loader_model, train_loader_arch, model, architect, criterion, optimizer, lr_policy, logger, epoch, update_arch=True): model.train() bar_format = '{desc}[{elapsed}<{remaining},{rate_fmt}]' pbar = tqdm(range(config.niters_per_epoch), file=sys.stdout, bar_format=bar_format, ncols=80) dataloader_model = iter(train_loader_model) dataloader_arch = iter(train_loader_arch) for step in pbar: optimizer.zero_grad() minibatch = dataloader_model.next() imgs = minibatch['data'] target = minibatch['label'] imgs = imgs.cuda(non_blocking=True) target = target.cuda(non_blocking=True) if update_arch: # get a random minibatch from the search queue with replacement pbar.set_description("[Arch Step %d/%d]" % (step + 1, len(train_loader_model))) minibatch = dataloader_arch.next() imgs_search = minibatch['data'] target_search = minibatch['label'] imgs_search = imgs_search.cuda(non_blocking=True) target_search = target_search.cuda(non_blocking=True) loss_arch = architect.step(imgs, target, imgs_search, target_search) if (step+1) % 10 == 0: logger.add_scalar('loss_arch/train', loss_arch, epoch*len(pbar)+step) logger.add_scalar('arch/latency_supernet', architect.latency_supernet, epoch*len(pbar)+step) loss = model._loss(imgs, target, pretrain) logger.add_scalar('loss/train', loss, epoch*len(pbar)+step) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip) optimizer.step() optimizer.zero_grad() pbar.set_description("[Step %d/%d]" % (step + 1, len(train_loader_model))) torch.cuda.empty_cache() # del loss # if update_arch: del loss_arch def infer(epoch, model, evaluator, logger, FPS=True): model.eval() mIoUs = [] for idx in range(5): evaluator.out_idx = idx # _, mIoU = evaluator.run_online() _, mIoU = evaluator.run_online_multiprocess() mIoUs.append(mIoU) if FPS: fps0, fps1 = arch_logging(model, config, logger, epoch) return mIoUs, fps0, fps1 else: return mIoUs def arch_logging(model, args, logger, epoch): input_size = (1, 3, 1024, 2048) net = Network_Multi_Path_Infer( [getattr(model, model._arch_names[model.arch_idx]["alphas"][0]).clone().detach(), getattr(model, model._arch_names[model.arch_idx]["alphas"][1]).clone().detach(), getattr(model, model._arch_names[model.arch_idx]["alphas"][2]).clone().detach()], [None, getattr(model, model._arch_names[model.arch_idx]["betas"][0]).clone().detach(), getattr(model, model._arch_names[model.arch_idx]["betas"][1]).clone().detach()], [getattr(model, model._arch_names[model.arch_idx]["ratios"][0]).clone().detach(), getattr(model, model._arch_names[model.arch_idx]["ratios"][1]).clone().detach(), getattr(model, model._arch_names[model.arch_idx]["ratios"][2]).clone().detach()], num_classes=model._num_classes, layers=model._layers, Fch=model._Fch, width_mult_list=model._width_mult_list, stem_head_width=model._stem_head_width[model.arch_idx]) plot_op(net.ops0, net.path0, F_base=args.Fch).savefig("table.png", bbox_inches="tight") logger.add_image("arch/ops0_arch%d"%model.arch_idx, np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch) plot_op(net.ops1, net.path1, F_base=args.Fch).savefig("table.png", bbox_inches="tight") logger.add_image("arch/ops1_arch%d"%model.arch_idx, np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch) plot_op(net.ops2, net.path2, F_base=args.Fch).savefig("table.png", bbox_inches="tight") logger.add_image("arch/ops2_arch%d"%model.arch_idx, np.swapaxes(np.swapaxes(plt.imread("table.png"), 0, 2), 1, 2), epoch) net.build_structure([2, 0]) net = net.cuda() net.eval() latency0, _ = net.forward_latency(input_size[1:]) logger.add_scalar("arch/fps0_arch%d"%model.arch_idx, 1000./latency0, epoch) logger.add_figure("arch/path_width_arch%d_02"%model.arch_idx, plot_path_width([2, 0], [net.path2, net.path0], [net.widths2, net.widths0]), epoch) net.build_structure([2, 1]) net = net.cuda() net.eval() latency1, _ = net.forward_latency(input_size[1:]) logger.add_scalar("arch/fps1_arch%d"%model.arch_idx, 1000./latency1, epoch) logger.add_figure("arch/path_width_arch%d_12"%model.arch_idx, plot_path_width([2, 1], [net.path2, net.path1], [net.widths2, net.widths1]), epoch) return 1000./latency0, 1000./latency1 if __name__ == '__main__': main(pretrain=config.pretrain)