"""roc.py Provides functions for computing ROC AUC and plotting then saving ROC curves in a multiclass setting. Requires: NumPy, SciPy, matplotlib, scikit-learn (and their dependencies) Author: Ji-Sung Kim, Rzhetsky Lab Copyright: 2018, all rights reserved """ from __future__ import print_function import numpy as np import matplotlib matplotlib.use('Agg') # do not run X server import matplotlib.pyplot as plt plt.style.use('ggplot') from scipy import interp from sklearn.metrics import roc_curve, auc from sklearn.preprocessing import label_binarize def compute_and_plot_roc_scores(y_test, y_test_probas, num_class, path=None): """Compute ROC statistics and plot ROC curves. Arguments: y_test: [int] list of test class labels as integer indices y_test_probas: np.ndarray, float array of predicted probabilities with shape (num_sample, num_class) num_class: int number of classes path: string filepath where to save the ROC curve plot; if None will not perform plotting Returns: roc_auc_dict: {int: float} dictionary mapping classes to ROC AUC scores fpr_dict: {string: np.ndarray} dictionary mapping names of classes or an averaging method to arrays of increasing false positive rates tpr_dict: {string: float} dictionary mapping names of classes or an averaging method to arrays of increasing true positive rates """ roc_auc_dict, fpr_dict, tpr_dict = _compute_roc_stats(y_test, y_test_probas, num_class) if path is not None: _create_roc_plot(roc_auc_dict, fpr_dict, tpr_dict, num_class, path) return roc_auc_dict, fpr_dict, tpr_dict def _create_roc_plot(roc_auc_dict, fpr_dict, tpr_dict, num_class, path): """Create and save a combined ROC plot to file. Arguments: roc_auc_dict: {int: float} dictionary mapping classes to ROC AUC scores fpr_dict: {string: np.ndarray} dictionary mapping names of classes or an averaging method to arrays of increasing false positive rates tpr_dict: {string: float} dictionary mapping names of classes or an averaging method to arrays of increasing true positive rates num_class: int number of classes path: string filepath where to save the plot """ # aggregate all false positive rates all_fpr = np.unique(np.concatenate( [fpr_dict[i] for i in range(num_class)])) # interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(num_class): mean_tpr += interp(all_fpr, fpr_dict[i], tpr_dict[i]) # average and compute AUC mean_tpr /= num_class fpr_dict["macro"] = all_fpr tpr_dict["macro"] = mean_tpr roc_auc_dict["macro"] = auc(fpr_dict["macro"], tpr_dict["macro"]) # plot plt.figure() plt.plot(fpr_dict["micro"], tpr_dict["micro"], label='micro-average ROC curve (area = {0:0.2f})'.format( roc_auc_dict["micro"]), linewidth=2) plt.plot(fpr_dict["macro"], tpr_dict["macro"], label='macro-average ROC curve (area = {0:0.2f})'.format( roc_auc_dict["macro"]), linewidth=2) for i in range(num_class): plt.plot(fpr_dict[i], tpr_dict[i], label='ROC curve of class {0} (area = {1:0.2f})'.format( i, roc_auc_dict[i])) plt.plot([0, 1], [0, 1], 'k--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Some extension of Receiver operating characteristic to multi-class') plt.legend(loc="lower right") plt.savefig(path) # save plot plt.close() def _compute_roc_stats(y_test, y_test_probas, num_class): """Compute ROC AUC statistics and visualize ROC curves. Arguments: y_test: [int] list of test class labels as integer indices y_test_probas: np.ndarray, float array of predicted probabilities with shape (num_sample, num_class) num_class: int number of classes Returns: roc_auc_dict: {int: float} dictionary mapping classes to ROC AUC scores fpr_dict: {string: np.ndarray} dictionary mapping names of classes or an averaging method to arrays of increasing false positive rates tpr_dict: {string: float} dictionary mapping names of classes or an averaging method to arrays of increasing true positive rates """ y_test = label_binarize(y_test, classes=range(0, num_class)) fpr_dict, tpr_dict, roc_auc_dict = {}, {}, {} for i in range(num_class): fpr_dict[i], tpr_dict[i], _ = roc_curve( y_test[:, i], y_test_probas[:, i]) roc_auc_dict[i] = auc(fpr_dict[i], tpr_dict[i]) # Compute micro-average ROC curve and ROC area fpr_dict["micro"], tpr_dict["micro"], _ = roc_curve( y_test.ravel(), y_test_probas.ravel()) roc_auc_dict["micro"] = auc(fpr_dict["micro"], tpr_dict["micro"]) return roc_auc_dict, fpr_dict, tpr_dict