#!/usr/bin/env python
"""
# Author: Xiong Lei
# Created Time : Mon 09 Apr 2018 07:36:48 PM CST

# File Name: plotting.py
# Description: plotting helpers (confusion matrix, heatmaps, embeddings,
#              correlation maps and clustering-metric line plots).

"""

import numpy as np
import pandas as pd
import matplotlib
# The non-interactive backend must be selected before pyplot is imported.
matplotlib.use('agg')
from matplotlib import pyplot as plt
import seaborn as sns
# import os

# plt.rcParams['savefig.dpi'] = 300
# plt.rcParams['figure.dpi'] = 300


def _class_colors(classes, colormap=None):
    """
    Build a ``{class label: colour}`` dictionary.

    Params:
        classes: iterable of class labels; the position of a label selects
            its colour. A label that occurs twice keeps the colour of its
            last position.
        colormap: None to call ``plt.cm.tab20`` with the position, otherwise
            an object indexed by position (``colormap[i]``).
    """
    if colormap is None:
        tab20 = plt.cm.tab20
        return {c: tab20(i) for i, c in enumerate(classes)}
    return {c: colormap[i] for i, c in enumerate(classes)}


def sort_by_classes(X, y, classes):
    """
    Reorder the columns of X (and the entries of y) so that samples are
    grouped by class, in the order given by `classes`.

    Params:
        X: object supporting ``X.iloc[:, index]`` (e.g. a DataFrame) with one
            column per sample
        y: labels, one per column of X (array, list or pandas Series)
        classes: class order; None means ``np.unique(y)``

    Returns:
        (X, y, classes, index) where `index` holds the positional column
        order that was applied. Samples whose label is not in `classes`
        are dropped.
    """
    if classes is None:
        classes = np.unique(y)

    # Compare on an ndarray so that list input yields an element-wise mask
    # instead of a single scalar ``False``.
    y_values = np.asarray(y)
    groups = [np.where(y_values == c)[0] for c in classes]
    # np.concatenate rejects an empty list, so handle "no classes" explicitly.
    index = np.concatenate(groups) if groups else np.array([], dtype=int)

    X = X.iloc[:, index]
    # `index` holds positions; pandas objects must therefore be indexed with
    # .iloc, because plain [] on a Series is label-based.
    y = y.iloc[index] if hasattr(y, 'iloc') else y_values[index]
    return X, y, classes, index


def plot_confusion_matrix(cm, x_classes=None, y_classes=None,
                          normalize=False,
                          title='',
                          cmap=plt.cm.Blues,
                          figsize=(4, 4),
                          mark=True,
                          save=None,
                          rotation=45,
                          show_cbar=True,
                          show_xticks=True,
                          show_yticks=True,
                          ):
    """
    Plot a confusion matrix as an image.
    Normalization (per column) can be applied by setting `normalize=True`.

    Params:
        cm: confusion matrix, MxN
        x_classes: N tick labels for the columns; defaults to 0..N-1
        y_classes: M tick labels for the rows; defaults to 0..M-1
        normalize: divide every column by its sum before plotting
        title: figure title
        cmap: colormap passed to ``plt.imshow``
        figsize: figure size
        mark: write the value of every cell larger than 0.1 into the cell
        save: if truthy, path the figure is written to (PDF format)
        rotation: rotation of the x tick labels
        show_cbar: draw a colorbar
        show_xticks / show_yticks: keep the ticks and labels of that axis
    """
    import itertools

    cm = np.asarray(cm)
    # The defaults used to crash on ``len(None)``; fall back to plain indices.
    if x_classes is None:
        x_classes = np.arange(cm.shape[1])
    if y_classes is None:
        y_classes = np.arange(cm.shape[0])

    if normalize:
        col_sums = cm.sum(axis=0)[np.newaxis, :]
        # Columns summing to zero stay at 0 instead of becoming NaN, which
        # would also turn `thresh` below into NaN.
        cm = np.divide(cm.astype('float'), col_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=col_sums != 0)

    plt.figure(figsize=figsize)
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)

    plt.xticks(np.arange(len(x_classes)), x_classes,
               rotation=rotation, ha='right')
    plt.yticks(np.arange(len(y_classes)), y_classes)

    ax = plt.gca()
    if not show_xticks:
        ax.axes.get_xaxis().set_ticks([])
        ax.axes.get_xaxis().set_ticklabels([])
    if not show_yticks:
        ax.axes.get_yaxis().set_ticks([])
        ax.axes.get_yaxis().set_ticklabels([])
    else:
        plt.ylabel('Predicted Cluster')

    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2.
    if mark:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            if cm[i, j] > 0.1:
                plt.text(j, i, format(cm[i, j], fmt),
                         horizontalalignment="center",
                         color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    if show_cbar:
        plt.colorbar(shrink=0.8)
    if save:
        plt.savefig(save, format='pdf', bbox_inches='tight')
    plt.show()


def plot_heatmap(X, y, classes=None, y_pred=None, row_labels=None, colormap=None, row_cluster=False,
                 cax_title='', xlabel='', ylabel='', yticklabels='', legend_font=10,
                 show_legend=True, show_cax=True, tick_color='black', ncol=3,
                 bbox_to_anchor=(0.5, 1.3), position=(0.8, 0.78, .1, .04), return_grid=False,
                 save=None, **kw):
    """
    plot hidden code heatmap with labels

    Params:
        X: fxn array, n is sample number, f is feature
        y: a array of labels for n elements or a list of array
        classes: class order used to sort the columns; None means
            ``np.unique(y)``
        y_pred: optional second label array, drawn as a second colour row
        row_labels: optional labels used to colour the rows
        colormap: None for ``plt.cm.tab20``, otherwise colours indexed by
            class position
        row_cluster: passed to ``sns.clustermap``
        cax_title, position, show_cax: title, position and visibility of
            the colorbar axis
        xlabel, ylabel, yticklabels, tick_color: heatmap axis decoration;
            non-string `yticklabels` are reordered to follow the row
            clustering when `row_cluster` is set
        show_legend, legend_font, ncol, bbox_to_anchor: class legend options
        return_grid: return the seaborn ClusterGrid
        save: if truthy, path the figure is written to (PDF format);
            otherwise ``plt.show()`` is called
        **kw: forwarded to ``sns.clustermap``
    """
    import matplotlib.patches as mpatches  # add legend

    X, y, classes, index = sort_by_classes(X, y, classes)

    if y_pred is not None:
        y_pred = y_pred[index]
        classes = list(classes) + list(np.unique(y_pred))
        colors = _class_colors(classes, colormap)
        # Two colour rows: reference labels on top of predicted labels.
        col_colors = [[colors[c] for c in y],
                      [colors[c] for c in y_pred]]
    else:
        colors = _class_colors(classes, colormap)
        col_colors = [colors[c] for c in y]

    legend_TN = [mpatches.Patch(color=color, label=c) for c, color in colors.items()]

    if row_labels is not None:
        kw['row_colors'] = [colors[c] for c in row_labels]
    kw['col_colors'] = col_colors

    cbar_kws = {"orientation": "horizontal"}
    grid = sns.clustermap(X, yticklabels=True,
                          col_cluster=False,
                          row_cluster=row_cluster,
                          cbar_kws=cbar_kws, **kw)
    if show_cax:
        grid.cax.set_position(position)
        grid.cax.tick_params(length=1, labelsize=4, rotation=0)
        grid.cax.set_title(cax_title, fontsize=6, y=0.35)

    if show_legend:
        grid.ax_heatmap.legend(loc='upper center',
                               bbox_to_anchor=bbox_to_anchor,
                               handles=legend_TN,
                               fontsize=legend_font,
                               frameon=False,
                               ncol=ncol)
        grid.ax_col_colors.tick_params(labelsize=6, length=0, labelcolor='orange')

    # Only real label sequences are reordered; the '' default (or any other
    # string) is passed through untouched. The previous ``is not ''`` test
    # compared object identity with a literal, which is not reliable.
    if row_cluster and not isinstance(yticklabels, str):
        yticklabels = np.asarray(yticklabels)[grid.dendrogram_row.reordered_ind]

    grid.ax_heatmap.set_xlabel(xlabel)
    grid.ax_heatmap.set_ylabel(ylabel, fontsize=8)
    grid.ax_heatmap.set_xticklabels('')
    grid.ax_heatmap.set_yticklabels(yticklabels, color=tick_color)
    grid.ax_heatmap.yaxis.set_label_position('left')
    grid.ax_heatmap.tick_params(axis='x', length=0)
    grid.ax_heatmap.tick_params(axis='y', labelsize=6, length=0, rotation=0,
                                labelleft=True, labelright=False)
    grid.ax_row_dendrogram.set_visible(False)
    grid.cax.set_visible(show_cax)
    grid.row_color_labels = classes

    if save:
        plt.savefig(save, format='pdf', bbox_inches='tight')
    else:
        plt.show()
    if return_grid:
        return grid


def plot_embedding(X, labels, classes=None, method='tSNE', cmap='tab20', figsize=(4, 4), markersize=4, marker=None,
                   return_emb=False, save=False, save_emb=False, show_legend=True, show_axis_label=True, **legend_params):
    """
    Scatter-plot a 2-D embedding of X coloured by label.

    Params:
        X: samples x features; if it does not already have exactly two
            columns it is reduced to two with `method`
        labels: one label per sample
        classes: labels to draw (and their order); None means
            ``np.unique(labels)``
        method: 'tSNE', 'UMAP' or 'PCA'; also used in the axis labels
        cmap: seaborn palette name; None picks 'tab10', 'tab20' or 'husl'
            from the number of classes
        figsize, markersize: figure size and scatter point size
        marker: optional extra points appended to X before the embedding
            and drawn as black stars
        return_emb: return the 2-D coordinates (including marker rows)
        save: if truthy, path the figure is written to (PDF format);
            otherwise ``plt.show()`` is called
        save_emb: if truthy, path the coordinates are written to with
            ``np.savetxt``
        show_legend, show_axis_label: toggle legend / axis labels
        **legend_params: override the default ``plt.legend`` arguments

    Raises:
        ValueError: X needs to be reduced but `method` is not supported.
    """
    if marker is not None:
        X = np.concatenate([X, marker], axis=0)
    # Element-wise comparison below needs an array, not a plain list.
    labels = np.asarray(labels)
    N = len(labels)

    if X.shape[1] != 2:
        if method == 'tSNE':
            from sklearn.manifold import TSNE
            X = TSNE(n_components=2, random_state=124).fit_transform(X)
        elif method == 'UMAP':
            from umap import UMAP
            X = UMAP(n_neighbors=30, min_dist=0.1, metric='correlation').fit_transform(X)
        elif method == 'PCA':
            from sklearn.decomposition import PCA
            X = PCA(n_components=2, random_state=124).fit_transform(X)
        else:
            # Previously the first two raw columns were plotted silently
            # under a misleading axis label.
            raise ValueError(
                "Unknown embedding method {!r}; expected 'tSNE', 'UMAP' or 'PCA'".format(method))

    plt.figure(figsize=figsize)
    if classes is None:
        classes = np.unique(labels)

    if cmap is None:
        if len(classes) <= 10:
            cmap = 'tab10'
        elif len(classes) <= 20:
            cmap = 'tab20'
        else:
            cmap = 'husl'
    colors = sns.color_palette(cmap, n_colors=len(classes))

    # Rows [0, N) are the labelled samples, rows [N, ...) are the markers.
    samples = X[:N]
    for i, c in enumerate(classes):
        mask = labels == c
        plt.scatter(samples[mask, 0], samples[mask, 1], s=markersize, color=colors[i], label=c)
    if marker is not None:
        plt.scatter(X[N:, 0], X[N:, 1], s=10*markersize, color='black', marker='*')

    legend_params_ = {'loc': 'center left',
                      'bbox_to_anchor': (1.0, 0.45),
                      'fontsize': 10,
                      'ncol': 1,
                      'frameon': False,
                      'markerscale': 1.5
                      }
    legend_params_.update(**legend_params)
    if show_legend:
        plt.legend(**legend_params_)
    sns.despine(offset=10, trim=True)
    if show_axis_label:
        plt.xlabel(method+' dim 1', fontsize=12)
        plt.ylabel(method+' dim 2', fontsize=12)

    if save:
        plt.savefig(save, format='pdf', bbox_inches='tight')
    else:
        plt.show()

    if save_emb:
        np.savetxt(save_emb, X)
    if return_emb:
        return X


def corr_heatmap(X, y=None, classes=None, cmap='RdBu_r', show_legend=True, show_cbar=True,
                 figsize=(5, 5), ncol=3, distance='pearson', ticks=None, save=None, **kw):
    """
    Plot cell-to-cell correlation matrix heatmap

    Params:
        X: DataFrame whose columns are correlated with ``X.corr``
        y: optional labels, one per column; when given the columns are
            grouped by class and a colour row is drawn
        classes: class order; None means ``np.unique(y)``
        cmap: colormap of the heatmap
        show_legend: draw the class legend (only when `y` is given)
        show_cbar: show the colorbar, titled with `distance`
        figsize, ncol: figure size and number of legend columns
        distance: `method` argument of ``X.corr``
        ticks: colorbar ticks
        save: if truthy, path the figure is written to (PDF format);
            otherwise ``plt.show()`` is called
        **kw: forwarded to ``sns.clustermap``
    """
    import matplotlib.patches as mpatches  # add legend

    has_labels = y is not None
    col_colors = None
    if has_labels:
        if classes is None:
            classes = np.unique(y)
        X, y, classes, index = sort_by_classes(X, y, classes)
        colors = _class_colors(classes)
        col_colors = [colors[c] for c in y]
        bbox_to_anchor = (0.4, 1.2)
        legend_TN = [mpatches.Patch(color=color, label=c) for c, color in colors.items()]

    corr = X.corr(method=distance)

    cbar_kws = {"orientation": "horizontal", "ticks": ticks}
    grid = sns.clustermap(corr, cmap=cmap,
                          col_colors=col_colors,
                          figsize=figsize,
                          row_cluster=False,
                          col_cluster=False,
                          cbar_kws=cbar_kws,
                          **kw
                          )
    grid.ax_heatmap.set_xticklabels('')
    grid.ax_heatmap.set_yticklabels('')
    grid.ax_heatmap.tick_params(axis='x', length=0)
    grid.ax_heatmap.tick_params(axis='y', length=0)

    if show_legend and has_labels:
        grid.ax_heatmap.legend(loc='upper center',
                               bbox_to_anchor=bbox_to_anchor,
                               handles=legend_TN,
                               fontsize=6,
                               frameon=False,
                               ncol=ncol)

    if show_cbar:
        grid.cax.set_position((0.8, 0.76, .1, .02))
        grid.cax.tick_params(length=1, labelsize=4, rotation=0)
        grid.cax.set_title(distance, fontsize=6, y=0.8)
    else:
        grid.cax.set_visible(False)

    if save:
        plt.savefig(save, format='pdf', bbox_inches='tight')
    else:
        plt.show()


def feature_specifity(feature, ref, classes, figsize=(6, 6), save=None):
    """
    Calculate the feature specifity:

    For every (feature, cluster) pair a one-way ANOVA (``f_oneway``)
    compares the feature values inside the cluster with those outside of
    it; the -log10 p-values are drawn as a heatmap.

    Input:
        feature: latent feature, DataFrame of samples x features
        ref: cluster assignments, one per sample
        classes: cluster classes
        figsize: figure size
        save: if truthy, path the figure is written to (PDF format);
            otherwise ``plt.show()`` is called
    """
    from scipy.stats import f_oneway

    n_cluster = len(classes)
    dim = feature.shape[1]  # feature dimension
    pvalue_mat = np.zeros((dim, n_cluster))
    for i, cluster in enumerate(classes):
        inside = ref == cluster
        outside = ref != cluster
        for feat in range(dim):
            column = feature.iloc[:, feat]
            pvalue_mat[feat, i] = f_oneway(column[inside], column[outside])[1]

    plt.figure(figsize=figsize)
    # One tick label per feature; this used to be hard-coded to 10 labels,
    # which does not match the matrix for any other feature dimension.
    feature_ids = np.arange(dim) + 1
    grid = sns.heatmap(-np.log10(pvalue_mat), cmap='RdBu_r',
                       vmax=20,
                       yticklabels=feature_ids,
                       xticklabels=classes[:n_cluster],
                       )
    grid.set_ylabel('Feature', fontsize=18)
    grid.set_xticklabels(labels=classes[:n_cluster], rotation=45, fontsize=18)
    grid.set_yticklabels(labels=feature_ids, fontsize=16)
    cbar = grid.collections[0].colorbar
    cbar.set_label('-log10 (Pvalue)', fontsize=18)
    if save:
        plt.savefig(save, format='pdf', bbox_inches='tight')
    else:
        plt.show()


# Imports used by the metric plots below; kept at their original position so
# the module's import order is unchanged.
import os
from .utils import read_labels, reassign_cluster_with_ref
from sklearn.metrics import f1_score, normalized_mutual_info_score, adjusted_rand_score


def lineplot(data, name, title='', cbar=False):
    """
    Draw one metric against 'fraction', one line per 'method'.

    Params:
        data: DataFrame with the columns 'fraction', 'method' and `name`
        name: column plotted on the y axis
        title: figure title
        cbar: if True the legend is shown to the right of the axes,
            otherwise it is hidden
    """
    sns.lineplot(x='fraction', y=name, hue='method', data=data,
                 markers=True, style='method', sort=False)
    plt.title(title)
    if cbar:
        plt.legend(loc='right', bbox_to_anchor=(1.25, 0.2), frameon=False)
    else:
        plt.legend().set_visible(False)
    plt.show()


def plot_metrics(path, dataset, ref, fraction):
    """
    Compare the clustering of scABC, SC3, scVI and SCALE against `ref`
    for every entry of `fraction` and draw ARI, NMI and F1 line plots.

    Params:
        path: root directory of the results
        dataset: dataset sub-directory, also used as plot title
        ref: reference labels the predictions are scored against
        fraction: iterable of sub-directory names below ``path/dataset``;
            a 'corrupt_' substring is removed for the x-axis labels

    Each directory is read for 'scABC_predict.txt', 'scVI_predict.txt' and
    'cluster_assignments.txt'; 'SC3_predict.txt' is optional and scores 0
    for all metrics when the file is missing.
    """
    methods = ['scABC', 'SC3', 'scVI', 'SCALE']
    records = []
    # `fraction` is iterated once only, so one-shot iterables work as well.
    for frac in fraction:
        outdir = os.path.join(path, dataset, frac)
        scABC_pred, _ = read_labels(os.path.join(outdir, 'scABC_predict.txt'))
        sc3_file = os.path.join(outdir, 'SC3_predict.txt')
        if os.path.isfile(sc3_file):
            SC3_pred, _ = read_labels(sc3_file)
        else:
            SC3_pred = None
        scVI_pred, _ = read_labels(os.path.join(outdir, 'scVI_predict.txt'))
        scale_pred, _ = read_labels(os.path.join(outdir, 'cluster_assignments.txt'))

        frac_label = frac.replace('corrupt_', '')
        predictions = [scABC_pred, SC3_pred, scVI_pred, scale_pred]
        for pred, method in zip(predictions, methods):
            if pred is None:
                ari, nmi, f1 = 0, 0, 0
            else:
                pred = reassign_cluster_with_ref(pred, ref)
                ari = adjusted_rand_score(ref, pred)
                nmi = normalized_mutual_info_score(ref, pred)
                f1 = f1_score(ref, pred, average='micro')
            records.append({'ARI': ari, 'NMI': nmi, 'F1': f1,
                            'method': method, 'fraction': frac_label})

    metrics = pd.DataFrame(records, columns=['ARI', 'NMI', 'F1', 'method', 'fraction'])

    lineplot(metrics, 'ARI', dataset, False)
    lineplot(metrics, 'NMI', dataset, False)
    lineplot(metrics, 'F1', dataset, True)