#!/usr/local/bin/python3 import os import argparse import numpy as np import torch from torch import nn from torch import optim from torch.optim.lr_scheduler import MultiStepLR from torch.utils.data import DataLoader from torchvision import transforms from dataset.lip import LIPWithClass from net.pspnet import PSPNet models = { 'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'), 'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'), 'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'), 'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'), 'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'), 'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'), 'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152') } parser = argparse.ArgumentParser(description="Human Parsing") parser.add_argument('--data-path', type=str, help='Path to dataset folder') parser.add_argument('--backend', type=str, default='densenet', help='Feature extractor') parser.add_argument('--snapshot', type=str, default=None, help='Path to pre-trained weights') parser.add_argument('--batch-size', type=int, default=16, help="Number of images sent to the network in one step.") parser.add_argument('--epochs', type=int, default=20, help='Number of training epochs to run') parser.add_argument('--crop_x', type=int, default=256, help='Horizontal random crop size') parser.add_argument('--crop_y', type=int, default=256, help='Vertical random crop size') parser.add_argument('--alpha', type=float, default=1.0, help='Coefficient for classification loss term') parser.add_argument('--start-lr', type=float, default=0.001, help='Learning rate') parser.add_argument('--milestones', type=str, default='10,20,30', help='Milestones for LR decreasing') args = parser.parse_args() def build_network(snapshot, backend): epoch = 0 backend = backend.lower() net = models[backend]() net = nn.DataParallel(net) if snapshot is not None: _, epoch = os.path.basename(snapshot).split('_') epoch = int(epoch) net.load_state_dict(torch.load(snapshot)) print("Snapshot for epoch {} loaded from {}".format(epoch, snapshot)) net = net.cuda() return net, epoch def get_transform(): transform_image_list = [ transforms.Resize((256, 256), 3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] transform_gt_list = [ transforms.Resize((256, 256), 0), transforms.Lambda(lambda img: np.asarray(img, dtype=np.uint8)), ] data_transforms = { 'img': transforms.Compose(transform_image_list), 'gt': transforms.Compose(transform_gt_list), } return data_transforms def get_dataloader(): ''' To follow this training routine you need a DataLoader that yields the tuples of the following format: (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls) where x - batch of input images, y - batch of groung truth seg maps, y_cls - batch of 1D tensors of dimensionality N: N total number of classes, y_cls[i, T] = 1 if class T is present in image i, 0 otherwise ''' data_transform = get_transform() train_loader = DataLoader(LIPWithClass(root=args.data_path, transform=data_transform['img'], gt_transform=data_transform['gt']), batch_size=args.batch_size, shuffle=True, ) return train_loader if __name__ == '__main__': models_path = os.path.join('./checkpoints', args.backend) os.makedirs(models_path, exist_ok=True) train_loader = get_dataloader() net, starting_epoch = build_network(args.snapshot, args.backend) optimizer = optim.Adam(net.parameters(), lr=args.start_lr) scheduler = MultiStepLR(optimizer, milestones=[int(x) for x in args.milestones.split(',')]) for epoch in range(1+starting_epoch, 1+starting_epoch+args.epochs): seg_criterion = nn.NLLLoss(weight=None) cls_criterion = nn.BCEWithLogitsLoss(weight=None) epoch_losses = [] net.train() for count, (x, y, y_cls) in enumerate(train_loader): # input data x, y, y_cls = x.cuda(), y.cuda().long(), y_cls.cuda().float() # forward out, out_cls = net(x) seg_loss, cls_loss = seg_criterion(out, y), cls_criterion(out_cls, y_cls) loss = seg_loss + args.alpha * cls_loss # backward optimizer.zero_grad() loss.backward() optimizer.step() # print epoch_losses.append(loss.item()) status = '[{0}] step = {1}/{2}, loss = {3:0.4f} avg = {4:0.4f}, LR = {5:0.7f}'.format( epoch, count, len(train_loader), loss.item(), np.mean(epoch_losses), scheduler.get_lr()[0]) print(status) scheduler.step() if epoch % 10 == 0: torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", str(epoch)]))) torch.save(net.state_dict(), os.path.join(models_path, '_'.join(["PSPNet", 'last'])))