#!/usr/bin/env python
# Copyright 2017 Calico LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
from __future__ import print_function

from optparse import OptionParser
import os
import pdb
import random
import sys

import h5py
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns

from basenji import plots

'''
basenji_sat_plot.py

Generate plots from scores HDF5 file output by saturation mutagenesis analysis
via basenji_sat_bed.py
'''

################################################################################
# main
################################################################################
def main():
    """Plot saturation mutagenesis scores for every sequence and target.

    Parses the command line, opens the scores HDF5 file read-only, and for
    each selected sequence and target writes one figure named
    ``seq<si>_t<ti>.<ext>`` into the output directory. Exits with status 1
    if the requested statistic is not a key of the HDF5 file.
    """
    usage = 'usage: %prog [options] <scores_file>'
    parser = OptionParser(usage)
    # NOTE: -a is parsed but not used anywhere in main().
    parser.add_option('-a', dest='activity_enrich',
        default=1, type='float',
        help='Enrich for the most active top % of sequences [Default: %default]')
    parser.add_option('-f', dest='figure_width',
        default=20, type='float',
        help='Figure width [Default: %default]')
    parser.add_option('-g', dest='gain',
        default=False, action='store_true',
        help='Draw a sequence logo for the gain score, too [Default: %default]')
    parser.add_option('-l', dest='plot_len',
        default=300, type='int',
        help='Length of centered sequence to mutate [Default: %default]')
    parser.add_option('-m', dest='min_limit',
        default=0.05, type='float',
        help='Minimum heatmap limit [Default: %default]')
    parser.add_option('-o', dest='out_dir',
        default='sat_plot',
        help='Output directory [Default: %default]')
    parser.add_option('--png', dest='save_png',
        default=False, action='store_true',
        help='Write PNG as opposed to PDF [Default: %default]')
    # The seed is handed to np.random.seed, which only accepts integers, so
    # the option must be parsed as an int (a float value raises TypeError).
    parser.add_option('-r', dest='rng_seed',
        default=1, type='int',
        help='Random number generator seed [Default: %default]')
    parser.add_option('-s', dest='sample',
        default=None, type='int',
        help='Sample N sequences from the set [Default:%default]')
    parser.add_option('--stat', dest='sad_stat',
        default='sum',
        help='SAD stat to display [Default: %default]')
    parser.add_option('-t', dest='targets_file',
        default=None, type='str',
        help='File specifying target indexes and labels in table format')
    (options, args) = parser.parse_args()

    if len(args) != 1:
        # parser.error prints the message and exits.
        parser.error('Must provide scores HDF5 file')
    scores_h5_file = args[0]

    if not os.path.isdir(options.out_dir):
        os.mkdir(options.out_dir)

    save_ext = 'png' if options.save_png else 'pdf'

    np.random.seed(options.rng_seed)

    # open scores read-only; the context manager closes the file on every
    # exit path, including sys.exit below
    with h5py.File(scores_h5_file, 'r') as scores_h5:
        # check for stat
        if options.sad_stat not in scores_h5:
            print('%s does not have key %s' % (scores_h5_file, options.sad_stat),
                  file=sys.stderr)
            sys.exit(1)

        seqs_ds = scores_h5['seqs']
        stat_ds = scores_h5[options.sad_stat]

        # extract shapes
        num_seqs = seqs_ds.shape[0]
        mut_len = stat_ds.shape[1]

        if options.plot_len > mut_len:
            print('Decreasing plot_len=%d to maximum %d' % (options.plot_len, mut_len),
                  file=sys.stderr)
            options.plot_len = mut_len

        # determine targets: the index column of the targets table when
        # given, otherwise every entry along the last axis of the stat
        if options.targets_file is not None:
            targets_df = pd.read_table(options.targets_file, index_col=0)
            target_indexes = list(targets_df.index)
        else:
            target_indexes = list(range(stat_ds.shape[-1]))

        # determine plot region, centered on the mutated region
        mut_mid = mut_len // 2
        plot_start = mut_mid - (options.plot_len // 2)
        plot_end = plot_start + options.plot_len

        # plot attributes
        sns.set(style='white', font_scale=1)
        spp = subplot_params(options.plot_len)

        # determine sequences
        seq_indexes = np.arange(num_seqs)
        if options.sample and options.sample < num_seqs:
            seq_indexes = np.random.choice(
                seq_indexes, size=options.sample, replace=False)

        for si in seq_indexes:
            # read sequence; coerce to bool so it acts as a mask below even
            # if the dataset stores the one-hot coding as 0/1 integers
            seq_1hot = np.asarray(seqs_ds[si, plot_start:plot_end], dtype=bool)

            # read scores
            scores = stat_ds[si, plot_start:plot_end, :, :]

            # reference scores: the entries selected by the one-hot mask
            # NOTE(review): assumes exactly one nucleotide is set at every
            # position, so this keeps one row per position - confirm for
            # sequences containing N's.
            ref_scores = scores[seq_1hot]

            for ti in target_indexes:
                # compute scores relative to reference
                delta_ti = scores[:, :, ti] - ref_scores[:, [ti]]

                out_file = '%s/seq%d_t%d.%s' % (options.out_dir, si, ti, save_ext)
                _plot_satmut_figure(seq_1hot, delta_ti, spp, options, out_file)


def _plot_satmut_figure(seq_1hot, delta_ti, spp, options, out_file):
    """Draw and save the figure for one sequence and one target.

    Args:
        seq_1hot (array): One-hot coding of the plotted sequence region.
        delta_ti (array): Scores relative to the reference, indexed as
            [position, nucleotide].
        spp (dict): Layout parameters from subplot_params.
        options: Parsed command-line options; reads figure_width, gain and
            min_limit.
        out_file (str): Path the figure is saved to.
    """
    # compute loss and gain: most negative / most positive delta per position
    delta_loss = delta_ti.min(axis=1)
    delta_gain = delta_ti.max(axis=1)

    # setup plot
    fig = plt.figure(figsize=(options.figure_width, 6))
    try:
        grid_rows = 4 if options.gain else 3
        grid_shape = (grid_rows, spp['heat_cols'])

        row_i = 0
        ax_logo_loss = plt.subplot2grid(
            grid_shape, (row_i, spp['logo_start']), colspan=spp['logo_span'])
        row_i += 1
        if options.gain:
            ax_logo_gain = plt.subplot2grid(
                grid_shape, (row_i, spp['logo_start']), colspan=spp['logo_span'])
            row_i += 1
        ax_sad = plt.subplot2grid(
            grid_shape, (row_i, spp['sad_start']), colspan=spp['sad_span'])
        row_i += 1
        ax_heat = plt.subplot2grid(
            grid_shape, (row_i, 0), colspan=spp['heat_cols'])

        # plot sequence logo; the loss is negated so the logo gets
        # non-negative heights
        plot_seqlogo(ax_logo_loss, seq_1hot, -delta_loss)
        if options.gain:
            plot_seqlogo(ax_logo_gain, seq_1hot, delta_gain)

        # plot SAD
        plot_sad(ax_sad, delta_loss, delta_gain)

        # plot heat map
        plot_heat(ax_heat, delta_ti.T, options.min_limit)

        plt.tight_layout()
        plt.savefig(out_file, dpi=600)
    finally:
        # always release the figure, even if a plotting call fails
        plt.close(fig)


def enrich_activity(seqs, seqs_1hot, targets, activity_enrich, target_indexes):
    """Filter data for the most active sequences in the set.

    Each sequence is scored by taking the max of its targets over axis 1
    and then the mean over the selected target indexes. The top
    ``activity_enrich`` fraction of sequences by that score is kept, and the
    kept sequences are returned in their original order.

    Args:
        seqs (list): Per-sequence items, indexed by sequence.
        seqs_1hot (array): Per-sequence array, indexed by sequence on axis 0.
        targets (array): Indexed as [sequence, position, target].
        activity_enrich (float): Fraction of the sequences to keep.
        target_indexes: Indexes into the last axis of ``targets``.

    Returns:
        (seqs, seqs_1hot, targets) restricted to the kept sequences.
    """
    # compute the max across sequence lengths and mean across targets
    seq_scores = targets[:, :, target_indexes].max(axis=1).mean(
        axis=1, dtype='float64')
    num_seqs = seq_scores.shape[0]

    # rank the sequences from most to least active
    scores_indexes = sorted(
        ((seq_scores[si], si) for si in range(num_seqs)), reverse=True)

    # filter for the top; truncate the ranked list BEFORE sorting by index,
    # otherwise the ranking is discarded and the first sequences are kept
    num_keep = max(0, min(num_seqs, int(activity_enrich * num_seqs)))
    enrich_indexes = sorted(si for _, si in scores_indexes[:num_keep])

    seqs = [seqs[ei] for ei in enrich_indexes]
    seqs_1hot = seqs_1hot[enrich_indexes]
    targets = targets[enrich_indexes]

    return seqs, seqs_1hot, targets


def expand_4l(sat_lg_ti, seq_1hot):
    """Spread per-position scores onto the one-hot coded nucleotides.

    In:
        sat_lg_ti (l array): Sat mut loss/gain scores for a single sequence
            and target.
        seq_1hot (Lx4 array): One-hot coding for a single sequence.

    Out:
        sat_lg_4l (lx4 array): The centered l rows of seq_1hot, each
            multiplied by that position's score.
    """
    # determine satmut length
    satmut_len = sat_lg_ti.shape[0]

    # jump to satmut region in one hot coded sequence
    ssi = int((seq_1hot.shape[0] - satmut_len) // 2)

    # filter sequence for satmut region
    seq_1hot_sm = seq_1hot[ssi:ssi + satmut_len, :]

    # tile loss scores to align
    sat_lg_tile = np.tile(sat_lg_ti, (4, 1)).T

    # element-wise multiply
    sat_lg_4l = np.multiply(seq_1hot_sm, sat_lg_tile)

    return sat_lg_4l


def plot_heat(ax, sat_delta_ti, min_limit):
    """Plot satmut deltas as a heatmap.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_delta_ti (4 x L_sm array): Single target delta matrix for
            saturated mutagenesis region.
        min_limit (float): Minimum heatmap limit.
    """
    # symmetric color limits around zero, never tighter than min_limit
    vlim = max(min_limit, np.nanmax(np.abs(sat_delta_ti)))
    sns.heatmap(
        sat_delta_ti,
        linewidths=0,
        cmap='RdBu_r',
        vmin=-vlim,
        vmax=vlim,
        xticklabels=False,
        ax=ax)
    ax.yaxis.set_ticklabels(list('ACGT'), rotation='horizontal')


def plot_predictions(ax, preds, satmut_len, seq_len, buffer):
    """Plot the raw predictions for a sequence and target across the
    specified saturated mutagenesis region.

    Args:
        ax (Axis): matplotlib axis to plot to.
        preds (L array): Target predictions for one sequence.
        satmut_len (int): Satmut length from which to determine the values
            to plot.
        seq_len (int): Full sequence length.
        buffer (int): Ignored buffer sequence on each side.
    """
    # repeat preds across pool width
    target_pool = (seq_len - 2 * buffer) // preds.shape[0]
    epreds = preds.repeat(target_pool)

    # take the centered satmut window of the expanded predictions
    satmut_start = (epreds.shape[0] - satmut_len) // 2
    satmut_end = satmut_start + satmut_len

    ax.plot(epreds[satmut_start:satmut_end], linewidth=1)
    ax.set_xlim(0, satmut_len)
    ax.axhline(0, c='black', linewidth=1, linestyle='--')
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(0.5)


def plot_sad(ax, sat_loss_ti, sat_gain_ti):
    """Plot loss and gain SAD scores.

    Args:
        ax (Axis): matplotlib axis to plot to.
        sat_loss_ti (L_sm array): Minimum mutation delta across satmut
            length.
        sat_gain_ti (L_sm array): Maximum mutation delta across satmut
            length.
    """
    rdbu = sns.color_palette('RdBu_r', 10)

    # the loss is negated before plotting
    ax.plot(-sat_loss_ti, c=rdbu[0], label='loss', linewidth=1)
    ax.plot(sat_gain_ti, c=rdbu[-1], label='gain', linewidth=1)
    ax.set_xlim(0, len(sat_loss_ti))
    ax.legend()

    ax.xaxis.set_ticks([])
    for axis in ['top', 'bottom', 'left', 'right']:
        ax.spines[axis].set_linewidth(0.5)


def plot_seqlogo(ax, seq_1hot, sat_score_ti, pseudo_pct=0.05):
    """Plot a sequence logo for the loss/gain scores.

    Args:
        ax (Axis): matplotlib axis to plot to.
        seq_1hot (Lx4 array): One-hot coding of a sequence.
        sat_score_ti (L_sm array): Minimum mutation delta across satmut
            length.
        pseudo_pct (float): % of the max to add as a pseudocount.
    """
    # add pseudocounts; computed out of place so the caller's array is left
    # untouched and integer score arrays do not fail on an in-place add
    sat_score_cp = sat_score_ti + pseudo_pct * sat_score_ti.max()

    # expand
    sat_score_4l = expand_4l(sat_score_cp, seq_1hot)

    plots.seqlogo(sat_score_4l, ax)


def subplot_params(seq_len):
    """Specify subplot layout parameters for various sequence lengths.

    Args:
        seq_len (int): Length of the plotted sequence.

    Returns:
        dict: Grid column count ('heat_cols') and the start column and
            column span of the SAD and logo axes.
    """
    if seq_len < 500:
        spp = {
            'heat_cols': 400,
            'sad_start': 1,
            'sad_span': 321,
            'logo_start': 0,
            'logo_span': 323
        }
    else:
        spp = {
            'heat_cols': 400,
            'sad_start': 1,
            'sad_span': 320,
            'logo_start': 0,
            'logo_span': 322
        }

    return spp


################################################################################
# __main__
################################################################################
if __name__ == '__main__':
    main()