#!/usr/local/bin/python3 # -*- coding: utf-8 -*- import os import getpass import argparse import logging import numpy as np import matplotlib import matplotlib.pyplot as plt import torch import torch.nn as nn from torch.utils.data import DataLoader from torchvision import transforms from dataset.lip import LIP from net.pspnet import PSPNet models = { 'squeezenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='squeezenet'), 'densenet': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=1024, deep_features_size=512, backend='densenet'), 'resnet18': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet18'), 'resnet34': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=512, deep_features_size=256, backend='resnet34'), 'resnet50': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet50'), 'resnet101': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet101'), 'resnet152': lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=2048, deep_features_size=1024, backend='resnet152') } parser = argparse.ArgumentParser(description="Pyramid Scene Parsing Network") parser.add_argument('--data-path', type=str, help='Path to dataset folder') parser.add_argument('--models-path', type=str, default='./checkpoints', help='Path for storing model snapshots') parser.add_argument('--backend', type=str, default='densenet', help='Feature extractor') parser.add_argument('--gpu', type=str, default='0', help='List of GPUs for parallel training, e.g. 0,1,2,3') parser.add_argument('--batch-size', type=int, default=1, help="Number of images sent to the network in one step.") parser.add_argument('--num-classes', type=int, default=20, help="Number of classes.") parser.add_argument('-v', '--visualize', action='store_true', help="Display output and ground truth.") args = parser.parse_args() def build_network(snapshot, backend): epoch = 0 backend = backend.lower() net = models[backend]() net = nn.DataParallel(net) if snapshot is not None: _, epoch = os.path.basename(snapshot).split('_') if not epoch == 'last': epoch = int(epoch) net.load_state_dict(torch.load(snapshot)) logging.info("Snapshot for epoch {} loaded from {}".format(epoch, snapshot)) net = net.cuda() return net, epoch def get_transform(): transform_image_list = [ transforms.Resize((256, 256), 3), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), ] transform_gt_list = [ transforms.Resize((256, 256), 0), transforms.Lambda(lambda img: np.asarray(img, dtype=np.uint8)), ] data_transforms = { 'img': transforms.Compose(transform_image_list), 'gt': transforms.Compose(transform_gt_list), } return data_transforms def show_image(img, pred, gt): fig, axes = plt.subplots(1, 3) ax0, ax1, ax2 = axes ax0.get_xaxis().set_ticks([]) ax0.get_yaxis().set_ticks([]) ax1.get_xaxis().set_ticks([]) ax1.get_yaxis().set_ticks([]) ax2.get_xaxis().set_ticks([]) ax2.get_yaxis().set_ticks([]) classes = np.array(('Background', # always index 0 'Hat', 'Hair', 'Glove', 'Sunglasses', 'UpperClothes', 'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe', )) colormap = [(0,0,0), (1,0.25,0), (0,0.25,0), (0.5,0,0.25), (1,1,1), (1,0.75,0), (0,0,0.5), (0.5,0.25,0), (0.75,0,0.25), (1,0,0.25), (0,0.5,0), (0.5,0.5,0), (0.25,0,0.5), (1,0,0.75), (0,0.5,0.5), (0.25,0.5,0.5), (1,0,0), (1,0.25,0), (0,0.75,0), (0.5,0.75,0), ] cmap = matplotlib.colors.ListedColormap(colormap) bounds=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20] norm = matplotlib.colors.BoundaryNorm(bounds, cmap.N) h, w, _ = pred.shape def denormalize(img, mean, std): c, _, _ = img.shape for idx in range(c): img[idx, :, :] = img[idx, :, :] * std[idx] + mean[idx] return img img = denormalize(img[0].numpy(), [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) img = img.transpose(1,2,0).reshape((h,w,3)) pred = pred.reshape((h,w)) gt = gt.reshape((h, w)) # show image ax0.set_title('img') ax0.imshow(img) ax1.set_title('pred') mappable = ax1.imshow(pred, cmap=cmap, norm=norm) ax2.set_title('gt') mappable = ax2.imshow(gt, cmap=cmap, norm=norm) # colorbar legend cbar = plt.colorbar(mappable, ax=axes, shrink=0.7,) # orientation='horizontal' cbar.ax.get_yaxis().set_ticks([]) for j, lab in enumerate(classes): cbar.ax.text(1.3, (j+0.45) / 20.0, lab, ha='left', va='center',) # fontsize=7 # cbar.ax.get_yaxis().labelpad = 5 plt.savefig(fname="result.jpg") plt.show() def get_pixel_acc(pred, gt): valid = (gt >= 0) acc_sum = (valid * (pred == gt)).sum() valid_sum = valid.sum() acc = float(acc_sum) / (valid_sum + 1e-10) return acc def get_mean_acc(pred, gt, numClass): imPred = pred.copy() imLabel = gt.copy() imPred += 1 imLabel += 1 imPred = imPred * (imLabel > 0) # Compute area intersection: intersection = imPred * (imPred == imLabel) (area_intersection, _) = np.histogram( intersection, bins=numClass, range=(1, numClass)) # Compute area label: (area_label, _) = np.histogram(imLabel, bins=numClass, range=(1, numClass)) # Remove classes from unlabeled pixels in gt image. # We should not penalize detections in unlabeled portions of the image. valid = area_label > 0 # Compute intersection over union: classes_acc = area_intersection / (area_label + 1e-10) mean_acc = np.average(classes_acc, weights=valid) return mean_acc def get_mean_IoU(pred, gt, numClass): imPred = pred.copy() imLabel = gt.copy() imPred += 1 imLabel += 1 imPred = imPred * (imLabel > 0) # Compute area intersection: intersection = imPred * (imPred == imLabel) (area_intersection, _) = np.histogram( intersection, bins=numClass, range=(1, numClass)) # Compute area union: (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) (area_label, _) = np.histogram(imLabel, bins=numClass, range=(1, numClass)) area_union = area_pred + area_label - area_intersection # Remove classes from unlabeled pixels in gt image. # We should not penalize detections in unlabeled portions of the image. valid = area_label > 0 # Compute intersection over union: IoU = area_intersection / (area_union + 1e-10) mean_IoU = np.average(IoU, weights=valid) return mean_IoU def get_mean_acc_and_IoU(pred, gt, numClass): imPred = pred.copy() imLabel = gt.copy() imPred += 1 imLabel += 1 imPred = imPred * (imLabel > 0) # Compute area intersection: intersection = imPred * (imPred == imLabel) (area_intersection, _) = np.histogram( intersection, bins=numClass, range=(1, numClass)) # Compute area union: (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) (area_label, _) = np.histogram(imLabel, bins=numClass, range=(1, numClass)) area_union = area_pred + area_label - area_intersection # Remove classes from unlabeled pixels in gt image. # We should not penalize detections in unlabeled portions of the image. valid = area_label > 0 # Compute mean acc. classes_acc = area_intersection / (area_label + 1e-10) mean_acc = np.average(classes_acc, weights=valid) # Compute intersection over union: IoU = area_intersection / (area_union + 1e-10) mean_IoU = np.average(IoU, weights=valid) return mean_acc, mean_IoU def main(): # --------------- model --------------- # snapshot = os.path.join(args.models_path, args.backend, 'PSPNet_last') net, starting_epoch = build_network(snapshot, args.backend) net.eval() # ------------ data loader ------------ # data_transform = get_transform() val_loader = DataLoader(LIP(args.data_path, train=False, transform=data_transform['img'], gt_transform=data_transform['gt']), batch_size=args.batch_size, shuffle=False, ) # --------------- eval --------------- # overall_acc_list = [] mean_acc_list = [] mean_IoU_list = [] with torch.no_grad(): for index, (img, gt) in enumerate(val_loader): pred_seg, pred_cls = net(img.cuda()) pred_seg = pred_seg[0] pred = pred_seg.cpu().numpy().transpose(1, 2, 0) pred = np.asarray(np.argmax(pred, axis=2), dtype=np.uint8).reshape((256, 256, 1)) gt = np.asarray(gt.numpy(), dtype=np.uint8).transpose(1, 2, 0) if args.visualize: show_image(img, pred, gt) overall_acc_list.append(get_pixel_acc(pred, gt)) _mean_acc, _mean_IoU = get_mean_acc_and_IoU(pred, gt, args.num_classes) mean_acc_list.append(_mean_acc) mean_IoU_list.append(_mean_IoU) print(' %d / %d ' % (index, len(val_loader))) print(' overall acc. : %f ' % (np.mean(overall_acc_list))) print(' mean acc. : %f ' % (np.mean(mean_acc_list))) print(' mean IoU : %f ' % (np.mean(mean_IoU_list))) if __name__ == '__main__': main()