import os import warnings import numpy as np import pandas as pd import matplotlib from matplotlib.colors import hex2color from matplotlib import font_manager from scipy.stats import gaussian_kde from cycler import cycler from mpl_toolkits.axes_grid1 import make_axes_locatable try: os.environ['DISPLAY'] except KeyError: matplotlib.use('Agg') import matplotlib.pyplot as plt with warnings.catch_warnings(): warnings.simplefilter('ignore') # catch warnings that system can't find fonts fm = font_manager.fontManager fm.findfont('Raleway') fm.findfont('Lato') warnings.filterwarnings(action="ignore", module="matplotlib", message="^tight_layout") dark_gray = '.15' _colors = ['#4C72B0', '#55A868', '#C44E52', '#8172B2', '#CCB974', '#64B5CD'] style_dictionary = { 'figure.figsize': (3, 3), 'figure.facecolor': 'white', 'figure.dpi': 200, 'savefig.dpi': 200, 'text.color': 'k', "legend.frameon": False, "legend.numpoints": 1, "legend.scatterpoints": 1, 'font.family': ['sans-serif'], 'font.serif': ['Computer Modern Roman', 'serif'], 'font.monospace': ['Inconsolata', 'Computer Modern Typewriter', 'Monaco'], 'font.sans-serif': ['Helvetica', 'Lato', 'sans-serif'], 'patch.facecolor': _colors[0], 'patch.edgecolor': 'none', 'grid.linestyle': "-", 'axes.labelcolor': dark_gray, 'axes.facecolor': 'white', 'axes.linewidth': 1., 'axes.grid': False, 'axes.axisbelow': False, 'axes.edgecolor': dark_gray, 'axes.prop_cycle': cycler('color', _colors), 'lines.solid_capstyle': 'round', 'lines.color': _colors[0], 'lines.markersize': 4, 'image.cmap': 'viridis', 'image.interpolation': 'none', 'xtick.direction': 'in', 'xtick.major.size': 4, 'xtick.minor.size': 2, 'xtick.color': dark_gray, 'ytick.direction': 'in', 'ytick.major.size': 4, 'ytick.minor.size': 2, "ytick.color": dark_gray, } matplotlib.rcParams.update(style_dictionary) def refresh_rc(): matplotlib.rcParams.update(style_dictionary) print('rcParams updated') class FigureGrid: """ Generates a grid of axes for plotting axes can be iterated over or selected by number. e.g.: >>> # iterate over axes and plot some nonsense >>> fig = FigureGrid(4, max_cols=2) >>> for i, ax in enumerate(fig): >>> plt.plot(np.arange(10) * i) >>> # select axis using indexing >>> ax3 = fig[3] >>> ax3.set_title("I'm axis 3") """ def __init__(self, n: int, max_cols=3, scale=3): """ :param n: number of axes to generate :param max_cols: maximum number of axes in a given row """ self.n = n self.nrows = int(np.ceil(n / max_cols)) self.ncols = int(min((max_cols, n))) figsize = self.ncols * scale, self.nrows * scale # create figure self.gs = plt.GridSpec(nrows=self.nrows, ncols=self.ncols) self.figure = plt.figure(figsize=figsize) # create axes self.axes = {} for i in range(n): row = int(i // self.ncols) col = int(i % self.ncols) self.axes[i] = plt.subplot(self.gs[row, col]) def __getitem__(self, item): return self.axes[item] def __iter__(self): for i in range(self.n): yield self[i] def tight_layout(self, **kwargs): """wrapper for plt.tight_layout""" self.gs.tight_layout(self.figure, **kwargs) def despine(self, top=True, right=True, bottom=False, left=False): """removes axis spines (default=remove top and right)""" despine(ax=self, top=top, right=right, bottom=bottom, left=left) def detick(self, x=True, y=True): """ removes tick labels :param x: bool, if True, remove tick labels from x-axis :param y: bool, if True, remove tick labels from y-axis """ for ax in self: detick(ax, x=x, y=y) def savefig(self, filename, pad_inches=0.1, bbox_inches='tight', *args, **kwargs): """ wrapper for savefig, including necessary paramters to avoid cut-off :param filename: str, name of output file :param pad_inches: float, number of inches to pad :param bbox_inches: str, method to use when considering bbox inches :param args: additional args for plt.savefig() :param kwargs: additional kwargs for plt.savefig() :return: """ self.figure.savefig( filename, pad_inches=pad_inches, bbox_inches=bbox_inches, *args, **kwargs) def detick(ax=None, x=True, y=True): """helper function for removing tick labels from an axis""" if not ax: ax = plt.gca() if x: ax.xaxis.set_major_locator(plt.NullLocator()) if y: ax.yaxis.set_major_locator(plt.NullLocator()) def despine(ax=None, top=True, right=True, bottom=False, left=False) -> None: """helper function for removing axis spines""" if not ax: ax = plt.gca() # set spines if top: ax.spines['top'].set_visible(False) if right: ax.spines['right'].set_visible(False) if bottom: ax.spines['bottom'].set_visible(False) if left: ax.spines['left'].set_visible(False) # set ticks if top and bottom: ax.xaxis.set_ticks_position('none') elif top: ax.xaxis.set_ticks_position('bottom') elif bottom: ax.xaxis.set_ticks_position('top') if left and right: ax.yaxis.set_ticks_position('none') elif left: ax.yaxis.set_ticks_position('right') elif right: ax.yaxis.set_ticks_position('left') def xtick_vertical(ax=None): """set xticklabels on ax to vertical instead of the horizontal default orientation""" if ax is None: ax = plt.gca() xt = ax.get_xticks() if np.all(xt.astype(int) == xt): # ax.get_xticks() returns floats xt = xt.astype(int) ax.set_xticklabels(xt, rotation='vertical') def equalize_numerical_tick_number(ax=None): if ax is None: ax = plt.gca() xticks = ax.get_xticks() yticks = ax.get_yticks() nticks = min(len(xticks), len(yticks)) ax.set_xticks(np.round(np.linspace(min(xticks), max(xticks), nticks), 1)) ax.set_yticks(np.round(np.linspace(min(yticks), max(yticks), nticks), 1)) def equalize_axis_size(ax=None): if ax is None: ax = plt.gca() xlim = ax.get_xlim() ylim = ax.get_ylim() ax_min = min(xlim[0], ylim[0]) ax_max = max(xlim[1], ylim[1]) ax.set_xlim((ax_min, ax_max)) ax.set_ylim((ax_min, ax_max)) def map_categorical_to_cmap(data: np.ndarray, cmap=plt.get_cmap()): """ create a discrete colormap from cmap appropriate for data :param data: categorical vector to map to colormap :param cmap: cmap to discretize, or 'random' :return np.ndarray, dict: vector of colors matching input data, dictionary of labels to their respective colors """ categories = np.unique(data) n = len(categories) if isinstance(cmap, str) and 'random' in cmap: colors = np.random.rand(n, 3) else: colors = cmap(np.linspace(0, 1, n)) category_to_color = dict(zip(categories, colors)) return np.array([category_to_color[i] for i in data]), category_to_color def add_legend_to_categorical_vector( colors: np.ndarray, labels, ax, loc='best', # bbox_to_anchor=(0.98, 0.5), markerscale=0.75, **kwargs): """ Add a legend to a plot where the color scale was set by discretizing a colormap. :param colors: np.ndarray, output of map_categorical_vector_to_cmap() :param labels: np.ndarray, category labels :param ax: axis on which the legend should be plotted :param kwargs: additional kwargs for legend :return: None """ artists = [] for c in colors: artists.append(plt.Line2D((0, 1), (0, 0), color=c, marker='o', linestyle='')) ax.legend( artists, labels, loc=loc, markerscale=markerscale, # bbox_to_anchor=bbox_to_anchor, **kwargs) class scatter: @staticmethod def categorical( x, y, c, ax=None, cmap=plt.get_cmap(), legend=True, legend_kwargs=None, randomize=True, remove_ticks=False, *args, **kwargs): """ wrapper for scatter wherein the output should be colored by a categorical vector c :param x, y: np.ndarray, coordinate data to be scattered :param c: categories for data :param ax: axis on which to scatter data :param cmap: color map :param legend: bool, if True, plot legend :param legend_kwargs: additional kwargs for legend :param randomize: if True, randomize order of plotting :param remove_ticks: if True, removes axes ticks and labels :param args: additional args for scatter :param kwargs: additional kwargs for scatter :return: ax """ if not ax: # todo replace with plt.gridspec() method ax = plt.gca() if legend_kwargs is None: legend_kwargs = dict() color_vector, category_to_color = map_categorical_to_cmap(c, cmap) if randomize: ind = np.random.permutation(len(x)) else: ind = np.argsort(np.ravel(c)) ax.scatter(np.ravel(x)[ind], np.ravel(y)[ind], c=color_vector[ind], *args, **kwargs) if remove_ticks: ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) labels, colors = zip(*sorted(category_to_color.items())) if legend: add_legend_to_categorical_vector(colors, labels, ax, markerscale=2, **legend_kwargs) return ax @staticmethod def continuous(x, y, c=None, ax=None, colorbar=True, randomize=True, remove_ticks=False, **kwargs): """ wrapper for scatter wherein the coordinates x and y are colored according to a continuous vector c :param x, y: np.ndarray, coordinate data :param c: np.ndarray, continuous vector by which to color data points :param remove_ticks: remove axis ticks and labels :param args: additional args for scatter :param kwargs: additional kwargs for scatter :return: ax """ if ax is None: ax = plt.gca() if c is None: # plot density if no color vector is provided x, y, c = scatter.density_2d(x, y) if randomize: ind = np.random.permutation(len(x)) else: ind = np.argsort(c) sm = ax.scatter(x[ind], y[ind], c=c[ind], **kwargs) if remove_ticks: ax.xaxis.set_major_locator(plt.NullLocator()) ax.yaxis.set_major_locator(plt.NullLocator()) if colorbar: cb = plt.colorbar(sm) cb.ax.xaxis.set_major_locator(plt.NullLocator()) cb.ax.yaxis.set_major_locator(plt.NullLocator()) return ax @staticmethod def density_2d(x, y): """return x and y and their density z, sorted by their density (smallest to largest) :param x, y: np.ndarray: coordinate data :return: sorted x, y, and density """ xy = np.vstack([np.ravel(x), np.ravel(y)]) z = gaussian_kde(xy)(xy) return np.ravel(x), np.ravel(y), np.arcsinh(z) def tatarize(n): """ Return n-by-3 RGB color matrix using the "tatarize" color alphabet (n <= 269) :param n: :return: """ with open(os.path.expanduser('~/.seqc/tools/tatarize_269.txt')) as f: s = f.read().split('","') s[0] = s[0].replace('{"', '') s[-1] = s[-1].replace('"}', '') s = [hex2color(s) for s in s] return s[:n] class Diagnostics: @staticmethod def mitochondrial_fraction(data: pd.DataFrame, ax=None): """plot the fraction of mRNA that are of mitochondrial origin for each cell. :param data: DataFrame of cells x genes containing gene expression information :param ax: matplotlib axis :return: ax """ mt_genes = data.molecules.columns[data.molecules.columns.str.contains('MT-')] mt_counts = data.molecules[mt_genes].sum(axis=1) library_size = data.molecules.sum(axis=1) if ax is None: ax = plt.gca() scatter.continuous(library_size, mt_counts / library_size) ax.set_title('Mitochondrial Fraction') ax.set_xlabel('Total Gene Expression') ax.set_ylabel('Mitochondrial Gene Expression') _, xmax = ax.get_xlim() ax.set_xlim((None, xmax)) _, ymax = ax.get_ylim() ax.set_ylim((None, ymax)) despine(ax) return ax @staticmethod def pca_components(fig_name, variance_ratio, pca_comps): ''' :param fig_name: name for the figure :param variance_ratio: variance ratios of at least 20 pca components :param pca_comps: pca components of cells ''' fig = FigureGrid(4, max_cols=2) ax_pca, ax_pca12, ax_pca13, ax_pca23 = iter(fig) ax_pca.plot(variance_ratio[0:20]*100.0, c = '#1f77b4') ax_pca.set_xlabel('pca components') ax_pca.set_ylabel('explained variance') ax_pca.set_xlim([0,20.5]) ax_pca12.scatter(pca_comps[:, 0], pca_comps[:, 1], s=3, c = '#1f77b4') ax_pca12.set_xlabel("pca 1") ax_pca12.set_ylabel("pca 2") xtick_vertical(ax=ax_pca12) ax_pca13.scatter(pca_comps[:, 0], pca_comps[:, 2], s=3, c = '#1f77b4') ax_pca13.set_xlabel("pca 1") ax_pca13.set_ylabel("pca 3") xtick_vertical(ax=ax_pca13) ax_pca23.scatter(pca_comps[:, 1], pca_comps[:, 2], s=3, c = '#1f77b4') ax_pca23.set_xlabel("pca 2") ax_pca23.set_ylabel("pca 3") xtick_vertical(ax=ax_pca23) fig.tight_layout() fig.savefig(fig_name, dpi=300, transparent=True) @staticmethod def phenograph_clustering(fig_name, cell_sizes, clust_info, tsne_comps): # sketching tSNE and Phenograph figure fig = FigureGrid(2, max_cols=2) ax_tsne, ax_phenograph = iter(fig) cl = np.log10(cell_sizes) splot = ax_tsne.scatter(tsne_comps[:, 0], tsne_comps[:, 1], c=cl, s=3, cmap=plt.cm.coolwarm, vmin = np.min(cl), vmax=np.percentile(cl, 98)) ax_tsne.set_title("UMI Counts (log10)") ax_tsne.set_xticks([]) ax_tsne.set_yticks([]) divider = make_axes_locatable(ax_tsne) cax = divider.append_axes('right', size='3%', pad=0.04) fig.figure.colorbar(splot, cax=cax, orientation='vertical') # this is a list of contrast colors for clutering cmap=["#010067","#D5FF00","#FF0056","#9E008E","#0E4CA1","#FFE502","#005F39","#00FF00","#95003A", "#FF937E","#A42400","#001544","#91D0CB","#620E00","#6B6882","#0000FF","#007DB5","#6A826C", "#00AE7E","#C28C9F","#BE9970","#008F9C","#5FAD4E","#FF0000","#FF00F6","#FF029D","#683D3B", "#FF74A3","#968AE8","#98FF52","#A75740","#01FFFE","#FFEEE8","#FE8900","#BDC6FF","#01D0FF", "#BB8800","#7544B1","#A5FFD2","#FFA6FE","#774D00","#7A4782","#263400","#004754","#43002C", "#B500FF","#FFB167","#FFDB66","#90FB92","#7E2DD2","#BDD393","#E56FFE","#DEFF74","#00FF78", "#009BFF","#006401","#0076FF","#85A900","#00B917","#788231","#00FFC6","#FF6E41","#E85EBE"] colors = [] for i in range(len(clust_info)): colors.append(cmap[clust_info[i]]) for ci in range(np.min(clust_info),np.max(clust_info)+1): x1 = [] y1 = [] for i in range(len(clust_info)): if clust_info[i] == ci: x1.append(tsne_comps[i, 0]) y1.append(tsne_comps[i, 1]) cl = colors[i] ax_phenograph.scatter(x1, y1, c=cl, s=3, label="C"+str(ci+1)) ax_phenograph.set_title('Phenograph Clustering') ax_phenograph.set_xticks([]) ax_phenograph.set_yticks([]) ax_phenograph.legend(bbox_to_anchor=(1, 1), loc=2, borderaxespad=0., markerscale=2) fig.tight_layout() fig.savefig(fig_name, dpi=300, transparent=True) @staticmethod def cell_size_histogram(data, f=None, ax=None, save=None): if ax is None: f, ax = plt.subplots(figsize=(3.5, 3.5)) if f is None: f = plt.gcf() cell_size = data.sum(axis=1) plt.hist(np.log10(cell_size), bins=25, log=True) ax.set_xlabel('log10(cell size)') ax.set_ylabel('frequency') despine(ax) xtick_vertical(ax) if save is not None: if not isinstance(save, str): raise TypeError('save must be the string filename of the ' 'figure-to-be-saved') plt.tight_layout() f.savefig(save, dpi=300)