# Copyright 2017 Calico LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ========================================================================= from __future__ import print_function import sys import matplotlib matplotlib.use('agg') import matplotlib.pyplot as plt import seaborn as sns import numpy as np from scipy.stats import spearmanr, pearsonr ################################################################################ # scatter plots def jointplot(vals1, vals2, out_pdf, alpha=0.5, point_size=10, square=False, cor='pearsonr', x_label=None, y_label=None, figsize=(6, 6), sample=None, table=False, kind='scatter', text_means=False, tight=False, outlier_low=None, outlier_high=None): if table: out_txt = '%s.txt' % out_pdf[:-4] out_open = open(out_txt, 'w') for i in range(len(vals1)): print(vals1[i], vals2[i], file=out_open) out_open.close() if sample is not None and sample < len(vals1): indexes = np.random.choice(np.arange(0, len(vals1)), sample, replace=False) vals1 = vals1[indexes] vals2 = vals2[indexes] if type(figsize) == tuple: if figsize[0] != figsize[1]: print('Figure size must be square', file=sys.stderr) figsize = figsize[0] plt.figure() if cor is None: cor_func = None elif cor.lower() in ['spearman', 'spearmanr']: cor_func = spearmanr elif cor.lower() in ['pearson', 'pearsonr']: cor_func = pearsonr else: cor_func = None if kind == 'hex': joint_kws = {} elif kind == 'scatter': joint_kws = {'alpha':alpha, 's':point_size} else: gold = 
sns.color_palette('husl',8)[1] joint_kws = {} joint_kws['scatter_kws'] = {'color':'black', 's':point_size, 'alpha':alpha} joint_kws['line_kws'] = {'color':gold} # compute summary stat pre-filter u1 = np.mean(vals1) u2 = np.mean(vals2) # filter outliers for aesthetic purposes if outlier_low is not None: vals1 = vals1[vals1 > outlier_low] vals2 = vals2[vals2 > outlier_low] if outlier_high is not None: vals1 = vals1[vals1 < outlier_high] vals2 = vals2[vals2 < outlier_high] assert(len(vals1) > 0) assert(len(vals1) == len(vals2)) g = sns.jointplot(vals1, vals2, color='black', height=figsize, space=0, stat_func=cor_func, kind=kind, joint_kws=joint_kws) ax = g.ax_joint if square: vmin, vmax = scatter_lims(vals1, vals2) xmin = vmin ymin = vmin xmax = vmax ymax = vmax ax.plot([vmin, vmax], [vmin, vmax], linestyle='--', color='black') else: xmin, xmax = scatter_lims(vals1) ymin, ymax = scatter_lims(vals2) ax.set_xlim(xmin, xmax) ax.set_ylim(ymin, ymax) if y_label is not None: ax.set_ylabel(y_label) if x_label is not None: ax.set_xlabel(x_label) if text_means: eps = .05 text_xeps = eps*(xmax-xmin) test_yeps = eps*(ymax-ymin) # ax.text(xmax+text_xeps, ymin-test_yeps, 'mean %.3f'%u1, horizontalalignment='right', fontsize=14) # ax.text(xmin-text_xeps, ymax+test_yeps, 'mean %.3f'%u2, horizontalalignment='left', fontsize=14) ax.text(1-eps, eps, 'Mean %.3f'%u1, horizontalalignment='right', transform=ax.transAxes) ax.text(eps, 1-eps, 'Mean %.3f'%u2, verticalalignment='top', transform=ax.transAxes) # ax.grid(True, linestyle=':') if tight: plt.tight_layout(w_pad=0, h_pad=0) plt.savefig(out_pdf) plt.close() def regplot(vals1, vals2, out_pdf, poly_order=1, alpha=0.5, point_size=10, colors=None, cor='pearsonr', print_sig=False, square=False, x_label=None, y_label=None, title=None, figsize=(6, 6), sample=None, table=False, tight=False): if table: out_txt = '%s.txt' % out_pdf[:-4] out_open = open(out_txt, 'w') for i in range(len(vals1)): print(vals1[i], vals2[i], file=out_open) 
out_open.close() if sample is not None and sample < len(vals1): indexes = np.random.choice(np.arange(0, len(vals1)), sample, replace=False) vals1 = vals1[indexes] vals2 = vals2[indexes] plt.figure(figsize=figsize) gold = sns.color_palette('husl', 8)[1] if colors is None: ax = sns.regplot(vals1, vals2, color='black', order=poly_order, scatter_kws={'color': 'black', 's': point_size, 'alpha': alpha}, line_kws={'color': gold}) else: plt.scatter(vals1, vals2, c=colors, s=point_size, alpha=alpha, cmap='RdBu') plt.colorbar() ax = sns.regplot(vals1, vals2, scatter=False, order=poly_order, line_kws={'color':gold}) if square: xmin, xmax = scatter_lims(vals1, vals2) ymin, ymax = xmin, xmax else: xmin, xmax = scatter_lims(vals1) ymin, ymax = scatter_lims(vals2) ax.set_xlim(xmin, xmax) ax.set_ylim(ymin, ymax) if x_label is not None: ax.set_xlabel(x_label) if y_label is not None: ax.set_ylabel(y_label) if title is not None: plt.title(title) if cor is None: corr = None elif cor.lower() in ['spearman', 'spearmanr']: corr, csig = spearmanr(vals1, vals2) corr_str = 'SpearmanR: %.3f' % corr elif cor.lower() in ['pearson', 'pearsonr']: corr, csig = pearsonr(vals1, vals2) corr_str = 'PearsonR: %.3f' % corr else: corr = None if print_sig: if csig > .001: corr_str += '\n p %.3f' % csig else: corr_str += '\n p %.1e' % csig if corr is not None: xlim_eps = (xmax - xmin) * .03 ylim_eps = (ymax - ymin) * .05 ax.text( xmin + xlim_eps, ymax - 2 * ylim_eps, corr_str, horizontalalignment='left', fontsize=12) # ax.grid(True, linestyle=':') sns.despine() if tight: plt.tight_layout() plt.savefig(out_pdf) plt.close() def scatter_lims(vals1, vals2=None, buffer=.05): if vals2 is not None: vals = np.concatenate((vals1, vals2)) else: vals = vals1 vmin = np.nanmin(vals) vmax = np.nanmax(vals) buf = .05 * (vmax - vmin) if vmin == 0: vmin -= buf / 2 else: vmin -= buf vmax += buf return vmin, vmax ################################################################################ # nucleotides # Thanks to 
# Thanks to Anshul Kundaje, Avanti Shrikumar
# https://github.com/kundajelab/deeplift/tree/master/deeplift/visualization
#
# Each plot_* helper draws one letter by adding patches to ax. The visible
# glyph spans x in [left_edge, left_edge + 1] and y in [base, base + height].


def plot_a(ax, base, left_edge, height, color):
    """Draw an 'A' from three filled polygons whose y is scaled by height."""
    a_polygon_coords = [
        np.array([[0.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.2, 0.0]]),
        np.array([[1.0, 0.0], [0.5, 1.0], [0.5, 0.8], [0.8, 0.0]]),
        np.array([[0.225, 0.45], [0.775, 0.45], [0.85, 0.3], [0.15, 0.3]])
    ]
    # unit-square coordinates -> scale y by height, then shift to position
    scale = np.array([1, height])[None, :]
    offset = np.array([left_edge, base])[None, :]
    for polygon_coords in a_polygon_coords:
        ax.add_patch(
            matplotlib.patches.Polygon(
                scale * polygon_coords + offset,
                facecolor=color,
                edgecolor=color))


def plot_c(ax, base, left_edge, height, color):
    """Draw a 'C' as a colored ellipse with a white inner ellipse and a white
    rectangle masking its right side.

    The white patches paint over anything already drawn in their area. The
    mask rectangle covers x in [left_edge + 1, left_edge + 2] over this
    glyph's height.
    """
    ax.add_patch(
        matplotlib.patches.Ellipse(
            xy=[left_edge + 0.65, base + 0.5 * height],
            width=1.3,
            height=height,
            facecolor=color,
            edgecolor=color))
    ax.add_patch(
        matplotlib.patches.Ellipse(
            xy=[left_edge + 0.65, base + 0.5 * height],
            width=0.7 * 1.3,
            height=0.7 * height,
            facecolor='white',
            edgecolor='white'))
    ax.add_patch(
        matplotlib.patches.Rectangle(
            xy=[left_edge + 1, base],
            width=1.0,
            height=height,
            facecolor='white',
            edgecolor='white',
            fill=True))


def plot_g(ax, base, left_edge, height, color):
    """Draw a 'G' as the 'C' glyph plus a vertical and a horizontal bar."""
    plot_c(ax, base, left_edge, height, color)
    ax.add_patch(
        matplotlib.patches.Rectangle(
            xy=[left_edge + 0.825, base + 0.085 * height],
            width=0.174,
            height=0.415 * height,
            facecolor=color,
            edgecolor=color,
            fill=True))
    ax.add_patch(
        matplotlib.patches.Rectangle(
            xy=[left_edge + 0.625, base + 0.35 * height],
            width=0.374,
            height=0.15 * height,
            facecolor=color,
            edgecolor=color,
            fill=True))


def plot_t(ax, base, left_edge, height, color):
    """Draw a 'T' as a vertical stem plus a full-width top bar."""
    ax.add_patch(
        matplotlib.patches.Rectangle(
            xy=[left_edge + 0.4, base],
            width=0.2,
            height=height,
            facecolor=color,
            edgecolor=color,
            fill=True))
    ax.add_patch(
        matplotlib.patches.Rectangle(
            xy=[left_edge, base + 0.8 * height],
            width=1.0,
            height=0.2 * height,
            facecolor=color,
            edgecolor=color,
            fill=True))


################################################################################
# sequences

# nucleotide index -> glyph color / glyph drawing function (A, C, G, T order)
default_colors = {0: 'red', 1: 'blue', 2: 'orange', 3: 'green'}
default_plot_funcs = {0: plot_a, 1: plot_c, 2: plot_g, 3: plot_t}


def seqlogo(seq_scores, ax=None, colors=None, plot_funcs=None):
    """Draw a sequence logo of per-position nucleotide scores.

    At each position, the positive scores are stacked from smallest (bottom)
    to largest (top). Each letter's height equals its score. Scores <= 0 are
    not drawn.

    Args:
      seq_scores: 2-D array indexed [position, nucleotide].
      ax: matplotlib Axes to draw on; defaults to plt.gca().
      colors: nucleotide index -> color (dict or sequence); defaults to
        default_colors.
      plot_funcs: nucleotide index -> glyph function with the signature
        (ax, base, left_edge, height, color); defaults to default_plot_funcs.
    """
    if ax is None:
        ax = plt.gca()
    if colors is None:
        colors = default_colors
    if plot_funcs is None:
        plot_funcs = default_plot_funcs

    seq_len = seq_scores.shape[0]
    seq_depth = seq_scores.shape[1]

    max_height = 0
    for li in range(seq_len):
        # sort nucleotides by ascending score so the largest lands on top
        pos_scores = sorted((seq_scores[li, ni], ni) for ni in range(seq_depth))

        # running top of the stack at this position
        current_height = 0
        for score, ni in pos_scores:
            if score > 0:
                plot_funcs[ni](
                    ax=ax,
                    base=current_height,
                    left_edge=li,
                    height=score,
                    color=colors[ni])
                current_height += score

        max_height = max(max_height, current_height)

    # adjust limits
    xbuf = .005 * seq_len
    ax.set_xlim(0, seq_len + xbuf)
    ybuf = .05 * max_height
    ax.set_ylim(-ybuf, max_height + ybuf)

    # adjust line widths
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(0.5)