import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate

############################################
##### Utility functions for evaluation #####
############################################


def init_iou(im_batch, thresh):
    '''Build an empty IoU accumulator.

    Returns {n: {th: []}} for every n in 1..im_batch and every threshold th
    in ``thresh``; the lists are filled by ``update_iou``.
    '''
    return {ix + 1: {k: [] for k in thresh} for ix in range(im_batch)}


def update_iou(batch_iou, iou):
    '''Append the per-sample scores of ``batch_iou`` to the accumulator ``iou``.

    Both arguments use the {n: {th: [iou, ...]}} layout. Every key of ``iou``
    must also exist in ``batch_iou``. ``iou`` is modified in place and also
    returned for convenience.
    '''
    for ix in iou.keys():
        for th in iou[ix].keys():
            iou[ix][th].extend(batch_iou[ix][th])
    return iou


def eval_seq_iou(pred, gt, im_batch, thresh=(0.1,)):
    '''Compute the IoU of thresholded predictions against ground truth.

    pred: indexed as pred[sample][bx] for bx in range(im_batch); each entry
        is binarised with ``> th`` and compared with gt[sample].
        NOTE(review): presumably one occupancy prediction per number of
        input images -- confirm against the caller.
    gt: array whose first axis is the sample axis; it is cast to bool.
    im_batch: number of predictions per sample to score.
    thresh: iterable of binarisation thresholds.

    Returns {bx + 1: {th: [iou of sample 0, sample 1, ...]}}. When both the
    binarised prediction and the ground truth are empty (union == 0), the
    IoU is NaN.
    '''
    bs = gt.shape[0]
    # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` maps to
    # the same dtype.
    gt = gt.astype(bool)
    iu = init_iou(im_batch, thresh)
    for bx in range(im_batch):
        for th in thresh:
            for ix in range(bs):
                pred_t = pred[ix][bx] > th
                inter = np.sum(np.logical_and(pred_t, gt[ix]))
                union = np.sum(np.logical_or(pred_t, gt[ix]))
                # Explicit NaN instead of a 0/0 numpy division (same value,
                # no RuntimeWarning).
                thiou = float(inter) / float(union) if union else float('nan')
                iu[bx + 1][th].append(thiou)
    return iu


def print_iou_stats(mids, iou, thresh, statistic='mean'):
    '''Aggregate per-model IoU scores by shape id and format them as tables.

    mids: [(shape_id, model_id), ...] -- one entry per model, in the same
        order as the per-model lists stored in ``iou``.
    iou: {'#images': {'threshold': [iou, ...]}}
    thresh: thresholds to report; each must be a key of every iou[n] dict.
    statistic: 'mean' or 'median', applied per shape id.

    Returns (siou, table):
      siou  -- {threshold: {shape_id: {#images: [iou, ...]}}}
      table -- one section per threshold (IoU Thresh: Shape_ids - statistic
               of iou in percent); rows are shape ids and columns are the
               sorted #images keys. Thresholds are printed with one decimal.

    Raises ValueError if ``statistic`` is not 'mean' or 'median'.
    '''
    reducers = {'mean': np.mean, 'median': np.median}
    if statistic not in reducers:
        raise ValueError(
            "statistic must be 'mean' or 'median', got {!r}".format(statistic))
    reduce_fn = reducers[statistic]

    def pline(s):
        return '\n' + '*' * 5 + ' ' + s + ' ' + '*' * 5

    shape_ids = np.unique([m[0] for m in mids])
    siou = {th: {sid: {ix: [] for ix in iou.keys()} for sid in shape_ids}
            for th in thresh}
    for th in sorted(thresh):
        for mx, m in enumerate(mids):
            for ix in iou.keys():
                siou[th][m[0]][ix].append(iou[ix][th][mx])

    full_table = []
    for th in sorted(thresh):
        full_table.append(pline('IoU Thresh: {:.1f}'.format(th)))
        print_table = []
        for sid in shape_ids:
            row = [sid]
            for ix in sorted(iou.keys()):
                row.append(reduce_fn(np.array(siou[th][sid][ix])) * 100)
            print_table.append(row)
        full_table.append(
            tabulate(print_table, headers=sorted(iou.keys()), floatfmt=".2f"))
    return siou, '\n'.join(full_table)


def vis_ims(ims, mask=None):
    '''Tile a batch of images horizontally and display channel 0 with pyplot.

    ims: array whose first two axes are flattened into one sequence of
        images; the images are concatenated along their second axis.
        NOTE(review): layout presumably (batch, im_batch, H, W, C) with
        values in [0, 1] (they are scaled by 255 and cast to uint8) -- confirm.
    mask: optional boolean array matching the leading dimensions of ``ims``;
        entries where it is False are blanked (NaN for float input).

    The caller's ``ims`` array is not modified.
    '''
    if mask is not None:
        # Copy first: the blanking assignment must not leak into the caller.
        ims = np.array(ims, copy=True)
        ims[np.logical_not(mask)] = None
    im_disp = np.reshape(ims, [-1] + list(ims.shape[2:]))
    im_d = np.concatenate([i for i in im_disp], axis=1)
    plt.imshow(np.uint8(im_d[..., 0] * 255))
    plt.axis('off')


def eval_l1_err(pred, gt, mask=None, vis=False):
    '''Per-sample median absolute (L1) error between ``pred`` and ``gt``.

    pred: array indexed as pred[:, 0, ...]; only index 0 of axis 1 is
        evaluated. After that slice, axis 0 is the sample axis and axis 1 the
        per-sample image axis.
    gt: ground truth compatible with the sliced ``pred``. Pixels equal to
        ``np.max(gt)`` are treated as invalid.
        NOTE(review): presumably the background / no-depth value -- confirm.
    mask: optional boolean array selecting the pixels to evaluate. When None,
        valid gt pixels whose prediction lies strictly inside
        2 +/- sqrt(3)/2 are used.
    vis: if True, plot the masks, prediction, ground truth and error.

    Returns a list with one value per sample: the NaN-mean over its images
    of the per-image median L1 error over the selected pixels. An image with
    no selected pixels contributes NaN and is therefore skipped (a sample
    where every image is empty yields NaN).
    '''
    pred = pred[:, 0, ...]
    bs, im_batch = pred.shape[0], pred.shape[1]
    # Computed unconditionally: the visualisation below needs it even when
    # the caller supplies ``mask``.
    nanmask = (gt < np.max(gt))
    if mask is None:
        # NOTE(review): 2 +/- sqrt(3)/2 looks like the depth extent of a unit
        # cube centred 2 units from the camera -- confirm.
        half_range = np.sqrt(3) * 0.5
        range_mask = np.logical_and(pred > 2.0 - half_range,
                                    pred < 2.0 + half_range)
        mask = np.logical_and(nanmask, range_mask)
    if vis:
        plt.subplot(5, 1, 1)
        vis_ims(mask)
        plt.title("Eval Mask")
        plt.subplot(5, 1, 2)
        vis_ims(pred / 10.0, mask=mask)
        plt.title("Pred")
        plt.subplot(5, 1, 3)
        vis_ims(gt / 10.0, mask=nanmask)
        plt.title("Gt")
        plt.subplot(5, 1, 4)
        vis_ims(np.logical_xor(mask, nanmask))
        plt.title("Gt Mask - Mask")
        plt.subplot(5, 1, 5)
        vis_ims(np.abs(pred - gt) / 10.0, mask=mask)
        plt.title("Masked L1 error")
        plt.show()

    l1_err = np.abs(pred - gt)
    l1_err_masked = np.ma.array(l1_err, mask=np.logical_not(mask))
    batch_err = []
    for b in range(bs):
        # NaN by default, so fully-masked images are ignored by nanmean below.
        # Previously the `masked` median was stored into a float array as 0.0,
        # which biased the mean.
        tmp = np.full((im_batch, ), np.nan)
        for imb in range(im_batch):
            err_im = l1_err_masked[b, imb]
            if np.ma.count(err_im) > 0:
                tmp[imb] = np.ma.median(err_im)
        batch_err.append(np.nanmean(tmp))
    return batch_err


def print_depth_stats(mids, err):
    '''Average per-model depth errors by shape id and format them as a table.

    mids: [(shape_id, model_id), ...], aligned index-by-index with ``err``.
    err: per-model error values (e.g. the output of ``eval_l1_err``).

    Returns (smean, table):
      smean -- NaN-mean error per shape id, in sorted shape-id order
      table -- rows of (shape id, error) plus a final 'Mean' row, which is
               the NaN-mean of the per-shape means (each shape weighted
               equally).
    '''
    shape_ids = np.unique([m[0] for m in mids])
    serr = {sid: [] for sid in shape_ids}
    for ex, e in enumerate(err):
        serr[mids[ex][0]].append(e)

    table = []
    smean = []
    for sid, errs in serr.items():
        sm = np.nanmean(errs)
        table.append([sid, sm])
        smean.append(sm)
    table.append(['Mean', np.nanmean(smean)])
    ptable = tabulate(table, headers=['SID', 'L1 error'], floatfmt=".4f")
    return smean, ptable