#!/usr/bin/env python3

# Copyright 2018 Johns Hopkins University (author: Yiwen Shao)
# Apache 2.0

"""Scoring script for dsb2018 segmentation results.

Ground truth and predictions are both read from csv files in which every row
describes one object: the first column is the image id and the second column
is a space-separated run-length encoding (RLE) made of ``start length`` pairs.
For every ground-truth image the script computes the IoU between every
predicted and every ground-truth object, scores the image at the IoU
thresholds 0.50, 0.55, ..., 0.95 as TP / (TP + FP + FN), averages those
scores into a per-image average precision, and averages again over images.
"""

import argparse
import csv

import numpy as np

parser = argparse.ArgumentParser(description='Scoring script for dsb2018')
parser.add_argument('--ground-truth', type=str,
                    default='data/download/stage1_solution.csv',
                    help='Ground truth segmentation result csv file')
parser.add_argument('--predict', type=str, required=True,
                    help='predicted segmentation result csv file')
parser.add_argument('--result', type=str, required=True,
                    help='the file to store final statistical results')


def get_iou_from_csvs(gt_csv, pred_csv):
    """Compute per-image IoU matrices between predicted and ground-truth objects.

    Both arguments are csv filenames in the format read by
    ``read_csv_as_dict``: the first is the ground truth, the second the
    segmentation result (e.g. as generated by segment.py).

    Returns:
        A dict mapping image_id to ``iou_matrix``, a numpy array of shape
        (num_pred_objects, num_gt_objects) where ``iou_matrix[i, j]`` is the
        IoU of predicted object i and ground-truth object j.

    Raises:
        ValueError: if an image present in the ground truth has no prediction.
    """
    gt_dict = read_csv_as_dict(gt_csv)
    pred_dict = read_csv_as_dict(pred_csv)
    iou_dict = {}
    for image_id, gt_rles in gt_dict.items():
        if image_id not in pred_dict:
            raise ValueError(
                'All images need to be segmented but now miss: {}'.format(image_id))
        # Decode every object once instead of once per (pred, gt) pair.
        gt_runs = [_rle_to_runs(rle) for rle in gt_rles]
        pred_runs = [_rle_to_runs(rle) for rle in pred_dict[image_id]]
        iou_matrix = np.zeros((len(pred_runs), len(gt_runs)))
        for i, p_runs in enumerate(pred_runs):
            for j, g_runs in enumerate(gt_runs):
                iou_matrix[i, j] = _runs_iou(g_runs, p_runs)
        iou_dict[image_id] = iou_matrix
    return iou_dict


def read_csv_as_dict(csv_file):
    """Read a segmentation csv file into a run-length encoding (rle) dict.

    Each row is one object: column 0 is the image_id, column 1 the rle as
    space-separated integers. A header row (first cell ``'ImageId'``) and
    blank lines are skipped.

    Returns:
        A dict mapping image_id to a list of rles, one per object, each rle
        being a flat list of ints. The rles may differ in length.
    """
    rle_dict = {}
    # newline='' is what the csv module requires for correct line handling.
    with open(csv_file, 'r', newline='') as csv_fh:
        for row in csv.reader(csv_fh):
            if not row or row[0] == 'ImageId':  # blank line or header row
                continue
            image_id = row[0]
            encoded_pixels = [int(value) for value in row[1].split()]
            rle_dict.setdefault(image_id, []).append(encoded_pixels)
    return rle_dict


def _rle_to_runs(rle):
    """Convert a flat rle ``[start, length, ...]`` into sorted half-open
    ``(start, end)`` runs.

    Raises:
        ValueError: if the rle does not consist of start/length pairs.
    """
    if len(rle) % 2 != 0:
        raise ValueError(
            'Run-length encoding must consist of start/length pairs, '
            'got {} values'.format(len(rle)))
    return sorted((start, start + length)
                  for start, length in zip(rle[0::2], rle[1::2]))


def _runs_iou(gt_runs, pred_runs):
    """Return the IoU of two masks given as sorted ``(start, end)`` run lists.

    Assumes the runs within one mask do not overlap each other.
    """
    gt_total_length = sum(end - start for start, end in gt_runs)
    pred_total_length = sum(end - start for start, end in pred_runs)

    # Two-pointer sweep. Advance whichever run ends first: every later run of
    # the other mask starts at or after that end, so it cannot overlap it.
    intersection = 0
    i = j = 0
    while i < len(pred_runs) and j < len(gt_runs):
        p_s, p_e = pred_runs[i]
        g_s, g_e = gt_runs[j]
        overlap = min(p_e, g_e) - max(p_s, g_s)
        if overlap > 0:
            intersection += overlap
        if p_e <= g_e:
            i += 1
        else:
            j += 1

    union = gt_total_length + pred_total_length - intersection
    if union == 0:
        # Both masks are empty; avoid dividing by zero.
        return 0.0
    return float(intersection) / union


def compute_iou(gt_rle, pred_rle):
    """Return the intersection over union (iou) of two rle lists.

    Each argument is a flat list ``[start, length, start, length, ...]``. Runs
    may be given in any order. Returns 0.0 when both masks are empty. The
    result is symmetric in its two arguments.

    Raises:
        ValueError: if either rle has an odd number of values.
    """
    return _runs_iou(_rle_to_runs(gt_rle), _rle_to_runs(pred_rle))


def statistical_hypothesis_testing(iou_matrix, threshold):
    """Score one image's iou_matrix at a given IoU threshold.

    ``iou_matrix`` has shape (num_pred_objects, num_gt_objects).
    - A predicted object is a true positive (TP) if its best IoU with any
      ground-truth object is >= threshold. Otherwise it is a false positive (FP).
    - A ground-truth object is a false negative (FN) if its best IoU with every
      predicted object is < threshold.

    Returns:
        TP / (TP + FP + FN) as a float. Returns 1.0 when there are no objects
        on either side, since there is nothing to miss and nothing spurious.
    """
    iou_matrix = np.asarray(iou_matrix)
    num_pred, num_gt = iou_matrix.shape
    if num_pred == 0 or num_gt == 0:
        # Nothing can match. max() over a zero-size axis would raise.
        tp = 0
        fp = num_pred
        fn = num_gt
    else:
        # max IoU of each predicted object over all ground-truth objects
        tp = int(np.count_nonzero(iou_matrix.max(axis=1) >= threshold))
        fp = num_pred - tp
        # max IoU of each ground-truth object over all predicted objects
        fn = int(np.count_nonzero(iou_matrix.max(axis=0) < threshold))
    total = tp + fp + fn
    if total == 0:
        return 1.0
    return float(tp) / total


def main():
    """Parse the command line, score the predictions and write the report."""
    args = parser.parse_args()
    iou_dict = get_iou_from_csvs(args.ground_truth, args.predict)
    if not iou_dict:
        raise ValueError(
            'No images found in ground truth file: {}'.format(args.ground_truth))

    # Thresholds 0.50, 0.55, ..., 0.95 are built from integers so that each
    # one is the correctly rounded decimal. np.arange with a fractional step
    # accumulates floating-point drift.
    thresholds = np.arange(10, 20) / 20.0

    stat_dict = {}
    for image_id, iou_matrix in iou_dict.items():
        scores = [statistical_hypothesis_testing(iou_matrix, thred)
                  for thred in thresholds]
        stat_dict[image_id] = sum(scores) / len(thresholds)
    mean_ap = sum(stat_dict.values()) / len(stat_dict)

    with open(args.result, 'w') as fh:
        fh.write('Altogether Mean Average Precision: {}\n'.format(mean_ap))
        fh.write('ImageID\tAveragePrecision\n')
        for key, average_prec in stat_dict.items():
            fh.write('{}\t{}\n'.format(key, average_prec))


if __name__ == '__main__':
    main()