#!/usr/bin/env python
"""Plotting helpers for comparing unmodified (A) and modified (m6A) signals.

Every function draws with matplotlib/seaborn and writes a PDF; figures are
saved *before* ``plt.show()`` so the file is never written from an already
closed (empty) figure.
"""
import os
import sys
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr, percentileofscore
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score

# Label of the unmodified base, label of the modified base, and the colour
# used for each of them in the plots below.
base = 'A'
modbase = 'm6A'
base_colours = {base: '#55B196', modbase: '#B4656F'}


def plot_w_labels(klabels, labels, currents, strategy, kmer, pos, outdir,
                  base_colours, train=False, alpha=1):
    """Plot one current trace per observation and save it as a PDF.

    Each trace is coloured by its base label and drawn with a line style
    chosen by its cluster label. Traces are only drawn when there are fewer
    than four distinct cluster labels. The figure is written to
    ``<outdir>/signals_cluster_<pos>.pdf`` and then shown.

    Args:
        klabels: cluster label per trace; each must be one of -1, 0, 1, 2
            (the keys of the line-style table) for the trace to be drawn.
        labels: base label per trace; used as keys into ``base_colours``.
        currents: one sequence per trace, plotted against x = 1..6.
        strategy: clustering strategy name, used in the title when ``train``.
        kmer: text used as the plot title.
        pos: value embedded in the output file name.
        outdir: directory the PDF is written to.
        base_colours: mapping from base label to line colour.
        train: when true, compute the adjusted Rand index between the
            base labels and ``klabels`` and append it to the title.
        alpha: line opacity passed to matplotlib.

    Returns:
        The adjusted Rand index when ``train`` is true, otherwise ``None``.
    """
    warnings.filterwarnings("ignore", module="matplotlib")
    line_styles = {0: '-', 1: '--', -1: ':', 2: ':'}
    sns.set_style('white')
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)

    ars = None
    if train:
        # Unmodified base -> 1, anything else -> 0.
        bin_labels = [1 if x == base else 0 for x in labels]
        # The last observation is left out of the score, as in the original
        # code -- NOTE(review): reason not visible here, confirm with caller.
        ars = adjusted_rand_score(bin_labels[:-1], klabels[:-1])

    if len(set(klabels)) < 4:
        for current, label, kl in zip(currents, labels, klabels):
            plt.plot(range(1, 7), current,
                     label='{}, {}'.format(label, kl),
                     color=base_colours[label],
                     linestyle=line_styles[kl],
                     alpha=alpha)

    plt.ylabel('observed-expected current (pA)')
    plt.xlabel('position in kmer')

    # Keep only the first handle for every distinct legend text. A separate
    # name is used so the ``labels`` argument is not overwritten.
    handles, legend_labels = ax.get_legend_handles_labels()
    seen = set()
    unique_handles, unique_labels = [], []
    for handle, text in zip(handles, legend_labels):
        if text not in seen:
            seen.add(text)
            unique_labels.append(text)
            unique_handles.append(handle)
    ax.legend(unique_handles, unique_labels,
              loc='center left', bbox_to_anchor=(1, 0.5))

    title = kmer
    if train:
        title = (title + ', clustered by ' + strategy
                 + "\nAdjusted Rand Index: " + str(np.round(ars, 3)))
    plt.title(title)

    # Save first: after a blocking show() the figure may already be closed.
    plt.savefig(os.path.join(outdir, 'signals_cluster_' + str(pos) + '.pdf'),
                dpi=500, bbox_inches='tight', transparent=True)
    plt.show()
    return ars


def plot_correlation_matrix(curmat, elevenmer, labels, outdir):
    """Draw a clustered heatmap of ``curmat`` and save it as a PDF.

    Args:
        curmat: matrix passed to ``sns.clustermap`` (euclidean metric).
        elevenmer: text used as plot title and in the output file name.
        labels: tick labels for both heatmap axes.
        outdir: directory the PDF
            ``correlation_matrix_<elevenmer>.pdf`` is written to.
    """
    # NOTE(review): sns.clustermap builds its own figure, so this figure's
    # size probably never reaches the saved PDF -- confirm, and pass
    # ``figsize`` to clustermap instead if 7x6 is what is wanted.
    plt.figure(figsize=(7, 6))
    cg = sns.clustermap(curmat, metric='euclidean',
                        xticklabels=labels, yticklabels=labels)
    # Keep the row labels horizontal.
    plt.setp(cg.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
    plt.title(elevenmer)
    # os.path.join supplies the separator the original concatenation lacked.
    plt.savefig(os.path.join(outdir,
                             'correlation_matrix_' + elevenmer + '.pdf'),
                dpi=500, transparent=True)
    plt.show()


def plot_change_by_pos(diffs_by_context, plottype='box'):
    """Plot the distribution of current differences per k-mer position.

    ``diffs_by_context`` is walked as label -> context -> list of entries.
    For every entry, all values except the last are recorded together with
    their 1-based index and the label. The result is drawn as a box or
    violin plot and saved to ``change_by_position_box.pdf`` in the current
    working directory.

    Args:
        diffs_by_context: nested mapping ``{label: {context: [entry, ...]}}``
            where each entry is a sequence of numbers.
        plottype: ``'box'`` or ``'violin'``; any other value draws no
            distribution (axes, limits and legend are still set up).
    """
    plt.figure(figsize=(6, 4))

    changes_by_position = {'position': [], 'base': [], 'diff': []}
    for lab, contexts in diffs_by_context.items():
        for entries in contexts.values():
            for entry in entries:
                # The final element of each entry is not a per-position value.
                for pos, diff in enumerate(entry[:-1], 1):
                    changes_by_position['position'].append(pos)
                    changes_by_position['base'].append(lab)
                    changes_by_position['diff'].append(diff)
    dPos = pd.DataFrame(changes_by_position)

    # The original referenced undefined ``cols``/``methbase``; the module's
    # colour table and modified-base label are used instead.
    palette = [base_colours[base], base_colours[modbase]]
    plotters = {'box': sns.boxplot, 'violin': sns.violinplot}
    plotter = plotters.get(plottype)
    if plotter is not None:
        plotter(x="position", y="diff", hue="base", data=dPos,
                palette=palette)

    sns.despine(trim=False)
    plt.xlabel('Adenine Position in 6-mer')
    plt.ylabel('Measured - Expected Current (pA)')
    plt.ylim([-20, 20])
    plt.legend(title='', loc='upper center', bbox_to_anchor=(0.5, 1.05),
               ncol=3, fancybox=True)
    # NOTE(review): the file name says "box" even for violin plots; kept
    # unchanged so existing outputs keep their name.
    plt.savefig('change_by_position_box.pdf', transparent=True, dpi=500,
                bbox_inches='tight')


def plot_training_probabilities(prob_scores, tb):
    """Box-plot the probabilities for each base and save them as a PDF.

    Args:
        prob_scores: mapping with one sequence of probabilities per base,
            e.g. ``{'m6A': [0.9, 0.4, ...], 'A': [0.1, 0.5, 0.2, ...]}``.
        tb: text embedded in the output file name
            ``training_probability_<tb>_model_boxplot.pdf``.
    """
    sns.set_style('darkgrid')
    sns.set_palette([base_colours[base], base_colours[modbase]])
    plt.figure(figsize=(3, 4))

    # list() makes "+" a concatenation even if array-like scores are given.
    unmod_scores = list(prob_scores[base])
    mod_scores = list(prob_scores[modbase])
    prob_db = pd.DataFrame({
        'probability': unmod_scores + mod_scores,
        'base': [base] * len(unmod_scores) + [modbase] * len(mod_scores),
    })

    sns.boxplot(x="base", y="probability", data=prob_db)
    sns.despine()
    # Save first: after a blocking show() the figure may already be closed.
    plt.savefig('training_probability_' + tb + '_model_boxplot.pdf',
                transparent=True, dpi=500, bbox_inches='tight')
    plt.show()