#import matplotlib.mlab as mlab import matplotlib.pyplot as plt import matplotlib.patheffects as PathEffects import seaborn as sns import numpy as np from os.path import join as pjoin from config import FLAGS from sklearn.metrics import confusion_matrix, roc_curve, auc from scipy import interp ''' interpolation options: [None, 'none', 'nearest', 'bilinear', 'bicubic', 'spline16', 'spline36', 'hanning', 'hamming', 'hermite', 'kaiser', 'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc', 'lanczos'] ''' def scatter(x, y, plot_name): """ Used to plot t-SNE projections """ num_colors = len(np.unique(y)) # We choose a color palette with seaborn. palette = np.array(sns.color_palette("hls", num_colors)) # We create a scatter plot. f = plt.figure(figsize=(8, 8)) ax = plt.subplot(aspect='equal') sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40, c=palette[y.astype(np.int)]) plt.xlim(-25, 25) plt.ylim(-25, 25) ax.axis('off') ax.axis('tight') # We add the labels for each digit. txts = [] for i in range(num_colors): # Position of each label. xtext, ytext = np.median(x[y == i, :], axis=0) # if np.isnan(xtext) or np.isnan(ytext): # break txt = ax.text(xtext, ytext, str(i), fontsize=24) txt.set_path_effects([ PathEffects.Stroke(linewidth=5, foreground="w"), PathEffects.Normal()]) txts.append(txt) plt.savefig(plot_name, dpi=120) plt.close() def plot_confusion_matrix(cm, target_names, title='Confusion matrix', cmap=plt.cm.BuGn): imgplot = plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.grid(False) plt.colorbar(imgplot) plt.title(title) tick_marks = np.arange(len(target_names)) plt.xticks(tick_marks, target_names, rotation=90) plt.yticks(tick_marks, target_names) plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label') plt.savefig(pjoin(FLAGS.output_dir, title.replace(' ', '_') + '_CM.png')) plt.close() def plot_roc_curve(y_pred, y_true, n_classes, title='ROC_Curve'): # Compute ROC curve and ROC area for each class fpr = dict() tpr = dict() tresholds = dict() roc_auc = dict() for i in range(n_classes): fpr[i], tpr[i], tresholds[i] = roc_curve(y_true, y_pred, pos_label=i, drop_intermediate=False) roc_auc[i] = auc(fpr[i], tpr[i]) # Compute micro-average ROC curve and ROC area # fpr["micro"], tpr["micro"], _ = roc_curve(np.asarray(y_true).ravel(), np.asarray(y_pred).ravel(), pos_label=0, drop_intermediate=True) # roc_auc["micro"] = auc(fpr["micro"], tpr["micro"]) # Aggregate all false positive rates all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)])) # print("Thresholds:") # Interpolate all ROC curves at this points mean_tpr = np.zeros_like(all_fpr) for i in range(n_classes): mean_tpr += interp(all_fpr, fpr[i], tpr[i]) # print("Class_{0}: {1}".format(i, tresholds[i])) # Average it and compute AUC mean_tpr /= n_classes fpr["macro"] = all_fpr tpr["macro"] = mean_tpr roc_auc["macro"] = auc(fpr["macro"], tpr["macro"]) # Plot all ROC curves fig = plt.figure() ax = fig.add_subplot(111) # plt.plot(fpr["micro"], tpr["micro"], # label='micro-average ROC curve (area = {0:0.2f})' # ''.format(roc_auc["micro"]), # linewidth=3, ls='--', color='red') plt.plot(fpr["macro"], tpr["macro"], label='macro-average ROC curve (area = {0:0.2f})' ''.format(roc_auc["macro"]), linewidth=3, ls='--', color='green') for i in range(n_classes): plt.plot(fpr[i], tpr[i], label='ROC curve of class {0} (area = {1:0.2f})' ''.format(i, roc_auc[i])) plt.plot([0, 1], [0, 1], 'k--', linewidth=2) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title('Multi-class Receiver Operating Characteristic') lgd = ax.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) plt.savefig(pjoin(FLAGS.output_dir, title.replace(' ', '_') + '_ROC.png'), bbox_extra_artists=(lgd,), bbox_inches='tight') plt.close() def hist_comparison(data1, data2): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) f.suptitle('Histogram Before and After Normalization') ax1.hist(data1, 10, facecolor='green', alpha=0.75) ax1.set_xlabel("Values") ax1.set_ylabel("# of Examples") ax1.grid(True) ax2.hist(data2, 10, facecolor='green', alpha=0.75) ax2.set_xlabel("Values") ax2.grid(True) f.savefig(pjoin(FLAGS.output_dir, 'hist_comparison.png')) # plt.show() plt.close() def make_heatmap(data, name): f = plt.figure() ax1 = f.add_axes([0.1,0.1,0.8,0.8]) ax1.grid(False) imgplot = ax1.imshow(data, interpolation="none") imgplot.set_cmap('seismic') f.colorbar(imgplot) f.savefig(pjoin(FLAGS.output_dir, name + '.png')) plt.close() def make_2d_hist(data, name): f = plt.figure() X,Y = np.meshgrid(range(data.shape[0]), range(data.shape[1])) im = plt.pcolormesh(X,Y,data.transpose(), cmap='seismic') plt.colorbar(im, orientation='vertical') # plt.hexbin(data,data) # plt.show() f.savefig(pjoin(FLAGS.output_dir, name + '.png')) plt.close() # def make_2d_hexbin(data, name): # f = plt.figure() # X,Y = np.meshgrid(range(data.shape[0]), range(data.shape[1])) # plt.hexbin(X, data) # # plt.show() # f.savefig(pjoin(FLAGS.output_dir, name + '.png')) def heatmap_comparison(data1, label1, data2, label2, data3, label3): interpolation = 'none' fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True) fig.suptitle('Heatmap Comparison of Normal and Noisy Data') ax1.imshow(data3, interpolation=interpolation) ax1.set_title(label1) ax1.set_ylabel("Examples") ax1.set_xlabel("Features") ax1.set_aspect('equal') ax2.imshow(data2, interpolation=interpolation) ax2.set_title(label2) ax2.set_xlabel("Features") ax2.set_aspect('equal') ax3.imshow(data1, interpolation=interpolation) ax3.set_title(label3) ax3.set_xlabel("Features") ax3.set_aspect('equal') cax = fig.add_axes([0, 0, .1, .1]) cax.get_xaxis().set_visible(False) cax.get_yaxis().set_visible(False) cax.patch.set_alpha(0.5) cax.set_frame_on(True) # plt.colorbar(ax1, ax2, orientation='vertical') plt.show() plt.close() # # fig = plt.figure(figsize=(6, 3.2)) # # ax = fig.add_subplot(111) # ax.set_title('colorMap') # plt.imshow(data1) # ax.set_aspect('equal') # # cax = fig.add_axes([0.12, 0.1, 0.78, 0.8]) # cax.get_xaxis().set_visible(False) # cax.get_yaxis().set_visible(False) # cax.patch.set_alpha(0) # cax.set_frame_on(False) # plt.colorbar(orientation='vertical') # plt.show() #