import torch import torch.nn.functional as F import numpy as np from path import Path import argparse from tqdm import tqdm import imageio from models import DepthNet, PoseNet from inverse_warp import pose_vec2mat, compensate_pose, invert_mat, inverse_rotate from utils import tensor2array parser = argparse.ArgumentParser(description='Script for DispNet testing with corresponding groundTruth', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument("--pretrained-depthnet", required=True, type=str, help="pretrained DispNet path") parser.add_argument("--pretrained-posenet", default=None, type=str, help="pretrained PoseNet path (for scale factor)") parser.add_argument("--img-height", default=128, type=int, help="Image height") parser.add_argument("--img-width", default=416, type=int, help="Image width") parser.add_argument("--no-resize", action='store_true', help="no resizing is done") parser.add_argument("--min-depth", default=1e-3) parser.add_argument("--max-depth", default=80, type=float) parser.add_argument("--stabilize-from-GT", action='store_true') parser.add_argument("--nominal-displacement", type=float, default=0.3) parser.add_argument("--output-dir", default='.', type=str, help="Output directory for saving") parser.add_argument("--log-best-worst", action='store_true', help="if selected, will log depthNet outputs") parser.add_argument("--save-output", action='store_true', help="if selected, will save all predictions in a big 3D numpy file") parser.add_argument("--dataset-dir", default='.', type=str, help="Dataset directory") parser.add_argument("--dataset-list", default=None, type=str, help="Dataset list file") parser.add_argument("--gt-type", default='KITTI', type=str, help="GroundTruth data type", choices=['npy', 'png', 'KITTI', 'stillbox']) parser.add_argument("--img-exts", default=['png', 'jpg', 'bmp'], nargs='*', type=str, help="images extensions to glob") parser.add_argument("--rotation-mode", default='euler', choices=['euler', 'quat'], type=str) target_mean_depthnet_output = 50 best_error = np.inf worst_error = 0 def select_best_map(maps, target_mean): unraveled_maps = maps.view(maps.size(0), -1) means = unraveled_maps.mean(1) # this should be a 1D tensor best_index = torch.min((means-target_mean).abs(), 0)[1].item() best_map = maps[best_index,0] return best_map, best_index def log_result(pred_depth, GT, input_batch, selected_index, folder, prefix): def save(path, to_save): to_save = (255*to_save.transpose(1,2,0)).astype(np.uint8) imageio.imsave(path, to_save) pred_to_save = tensor2array(pred_depth, max_value=100) gt_to_save = tensor2array(torch.from_numpy(GT), max_value=100) prefix = folder/prefix save('{}_depth_pred.jpg'.format(prefix), pred_to_save) save('{}_depth_gt.jpg'.format(prefix), gt_to_save) disp_to_save = tensor2array(1/pred_depth, max_value=None, colormap='magma') gt_disp = np.zeros_like(GT) valid_depth = GT > 0 gt_disp[valid_depth] = 1/GT[valid_depth] gt_disp_to_save = tensor2array(torch.from_numpy(gt_disp), max_value=None, colormap='magma') save('{}_disp_pred.jpg'.format(prefix), disp_to_save) save('{}_disp_gt.jpg'.format(prefix), gt_disp_to_save) to_save = tensor2array(input_batch.cpu().data[selected_index,:3]) save('{}_input0.jpg'.format(prefix), to_save) to_save = tensor2array(input_batch.cpu()[selected_index,3:]) save('{}_input1.jpg'.format(prefix), to_save) for i, batch_elem in enumerate(input_batch.cpu().data): to_save = tensor2array(batch_elem[:3]) save('{}_batch_{}_0.jpg'.format(prefix, i), to_save) to_save = tensor2array(batch_elem[3:]) save('{}_batch_{}_1.jpg'.format(prefix, i), to_save) @torch.no_grad() def main(): global best_error, worst_error device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") args = parser.parse_args() if args.gt_type == 'KITTI': from kitti_eval.depth_evaluation_utils import test_framework_KITTI as test_framework elif args.gt_type == 'stillbox': from stillbox_eval.depth_evaluation_utils import test_framework_stillbox as test_framework weights = torch.load(args.pretrained_depthnet) depthnet_params = {"depth_activation":"elu", "batch_norm":"bn" in weights.keys() and weights['bn']} if not args.no_resize: depthnet_params['input_size'] = (args.img_height, args.img_width) depthnet_params['upscale'] = True depth_net = DepthNet(**depthnet_params).to(device) depth_net.load_state_dict(weights['state_dict']) depth_net.eval() if args.pretrained_posenet is None: args.stabilize_from_GT = True print('no PoseNet specified, stab will be done from ground truth') seq_length = 5 else: weights = torch.load(args.pretrained_posenet) seq_length = int(weights['state_dict']['conv1.0.weight'].size(1)/3) posenet_params = {'seq_length':seq_length} if not args.no_resize: posenet_params['input_size'] = (args.img_eight, args.img_width) pose_net = PoseNet(**posenet_params).to(device) pose_net.load_state_dict(weights['state_dict'], strict=False) dataset_dir = Path(args.dataset_dir) if args.dataset_list is not None: with open(args.dataset_list, 'r') as f: test_files = list(f.read().splitlines()) else: test_files = [file.relpathto(dataset_dir) for file in sum([dataset_dir.files('*.{}'.format(ext)) for ext in args.img_exts], [])] framework = test_framework(dataset_dir, test_files, seq_length, args.min_depth, args.max_depth) print('{} files to test'.format(len(test_files))) errors = np.zeros((9, len(test_files)), np.float32) args.output_dir = Path(args.output_dir) args.output_dir.makedirs_p() for j, sample in enumerate(tqdm(framework)): intrinsics = torch.from_numpy(sample['intrinsics']).unsqueeze(0).to(device) imgs = sample['imgs'] imgs = [torch.from_numpy(np.transpose(img, (2,0,1))) for img in imgs] imgs = torch.stack(imgs).unsqueeze(0).to(device) imgs = 2*(imgs/255 - 0.5) tgt_img = imgs[:,sample['tgt_index']] # Construct a batch of all possible stabilized pairs, with PoseNet or with GT orientation, will take the output closest to target mean depth if args.stabilize_from_GT: poses_GT = torch.from_numpy(sample['poses']).unsqueeze(0).to(device) inv_poses_GT = invert_mat(poses_GT) tgt_pose = inv_poses_GT[:,sample['tgt_index']] inv_transform_matrices_tgt = compensate_pose(inv_poses_GT, tgt_pose) else: poses = pose_net(imgs) inv_transform_matrices = pose_vec2mat(poses, rotation_mode=args.rotation_mode) tgt_pose = inv_transform_matrices[:,sample['tgt_index']] inv_transform_matrices_tgt = compensate_pose(inv_transform_matrices, tgt_pose) stabilized_pairs = [] corresponding_displ = [] for i in range(seq_length): if i == sample['tgt_index']: continue img = imgs[:,i] img_pose = inv_transform_matrices_tgt[:,i] stab_img = inverse_rotate(img, img_pose[:,:,:3], intrinsics) pair = torch.cat([stab_img, tgt_img], dim=1) # [1, 6, H, W] stabilized_pairs.append(pair) GT_translations = sample['poses'][:,:,-1] real_displacement = np.linalg.norm(GT_translations[sample['tgt_index']] - GT_translations[i]) corresponding_displ.append(real_displacement) stab_batch = torch.cat(stabilized_pairs) # [seq, 6, H, W] depth_maps = depth_net(stab_batch) # [seq, 1 , H/4, W/4] selected_depth, selected_index = select_best_map(depth_maps, target_mean_depthnet_output) pred_depth = selected_depth * corresponding_displ[selected_index] / args.nominal_displacement if args.save_output: if j == 0: predictions = np.zeros((len(test_files), *pred_depth.shape)) predictions[j] = 1/pred_depth gt_depth = sample['gt_depth'] pred_depth_zoomed = F.interpolate(pred_depth.view(1,1,*pred_depth.shape), gt_depth.shape[:2], mode='bilinear', align_corners=False).clamp(args.min_depth, args.max_depth)[0,0] if sample['mask'] is not None: pred_depth_zoomed_masked = pred_depth_zoomed.cpu().numpy()[sample['mask']] gt_depth = gt_depth[sample['mask']] errors[:,j] = compute_errors(gt_depth, pred_depth_zoomed_masked) if args.log_best_worst: if best_error > errors[0,j]: best_error = errors[0,j] log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'best') if worst_error < errors[0,j]: worst_error = errors[0,j] log_result(pred_depth_zoomed, sample['gt_depth'], stab_batch, selected_index, args.output_dir, 'worst') mean_errors = errors.mean(1) error_names = ['mean_abs', 'abs_rel','abs_log','sq_rel','rms','log_rms','a1','a2','a3'] print("Results : ") print("{:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}, {:>10}".format(*error_names)) print("{:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}, {:10.4f}".format(*mean_errors)) if args.save_output: np.save(args.output_dir/'predictions.npy', predictions) def compute_errors(gt, pred): thresh = np.maximum((gt / pred), (pred / gt)) a1 = (thresh < 1.25 ).mean() a2 = (thresh < 1.25 ** 2).mean() a3 = (thresh < 1.25 ** 3).mean() mabs = np.mean(np.abs(gt - pred)) rmse = (gt - pred) ** 2 rmse = np.sqrt(rmse.mean()) rmse_log = (np.log(gt) - np.log(pred)) ** 2 rmse_log = np.sqrt(rmse_log.mean()) abs_log = np.mean(np.abs(np.log(gt) - np.log(pred))) abs_rel = np.mean(np.abs(gt - pred) / gt) sq_rel = np.mean(((gt - pred)**2) / gt) return mabs, abs_rel, abs_log, sq_rel, rmse, rmse_log, a1, a2, a3 if __name__ == '__main__': main()