import matplotlib.pyplot as plt
import ete3
import numpy as np
np.seterr(divide='ignore', invalid='ignore')
import wquantiles
import itertools
import argparse
import gzip
import sys


def getMidPos_method2(pos1, pos2, depth1, depth2, depth3):
    """Return a position between pos1 and pos2, weighted by depth differences.

    pos2 is weighted by (depth2 - depth1) and pos1 by (depth3 - depth1);
    the result is divided by the sum of those two weights.
    """
    weight2 = depth2 - depth1
    weight1 = depth3 - depth1
    return (1.*pos2*weight2 + pos1*weight1) / (weight1 + weight2)


def getNodePos(node, method = 1):
    """Return the plotting position of an internal node from its children.

    node   -- a tree node whose children already carry `pos` (and, for
              method 1, `depth`) attributes.
    method -- 1: each child's position is weighted by the depth difference
                 between the node and the *other* child (bifurcating nodes only).
              2: plain mean of the children's positions.

    Raises ValueError if method 1 is used on a node that does not have
    exactly two children, or if method is neither 1 nor 2.
    """
    children = node.get_children()
    if method == 1:
        #check arity before unpacking so the informative message is what the user sees
        if len(children) != 2:
            raise ValueError("Position method one only works for bifurcating nodes.")
        c1, c2 = children
        w2 = c1.depth - node.depth
        w1 = c2.depth - node.depth
        return (1.*c2.pos*w2 + c1.pos*w1) / (w1 + w2)
    elif method == 2:
        return np.mean([c.pos for c in children])
    else:
        #raising a plain string is itself a TypeError; raise a real exception
        raise ValueError("Position method can only be 1 or 2.")


def drawTree(tree, leafPos = None, depthDict = None, depthRangeDict=None, extendTips=False, rootIsZero=False,
             show=True, posMethod=1, col="black",linewidth=2,alpha=1,direction="down",taxColDict=None):
    """Draw a tree on the current matplotlib axes.

    tree           -- tree object; a copy is made so the input is not annotated.
    leafPos        -- dict of leaf name -> position; if None, leaves are
                      numbered 0..n-1 in the order given by get_leaves().
    depthDict      -- dict of node name -> depth; if None, depth is the
                      node's distance from the root.
    depthRangeDict -- optional dict of node name -> (lower, upper) depth, drawn
                      as a thin range line with a tick at the node depth.
    extendTips     -- if true, set every leaf depth to the maximum leaf depth.
    rootIsZero     -- if false, subtract the maximum leaf depth from all depths.
    show           -- if true, call plt.show() at the end.
    posMethod      -- passed to getNodePos as `method`.
    col, linewidth, alpha -- line styling for the branches.
    direction      -- "down" puts position on x and depth on y; any other
                      value swaps the two axes.
    taxColDict     -- optional dict of leaf name -> label colour (default black).
    """
    tree = tree.copy("newick")

    #get node depths
    for node in tree.traverse():
        if depthDict is None:
            node.add_feature("depth", node.get_distance(tree))
        else:
            node.add_feature("depth", depthDict[node.name])

    #extend tips to align them if needed
    if extendTips:
        maxDP = max([l.depth for l in tree.iter_leaves()])
        for l in tree.iter_leaves():
            l.depth = maxDP

    #adjust depths so that they are aligned at zero
    if not rootIsZero:
        maxDP = max([l.depth for l in tree.iter_leaves()])
        for node in tree.traverse():
            node.depth -= maxDP

    #set leaf positions
    if leafPos is None:
        leafPos = dict((l.name, i) for i, l in enumerate(tree.get_leaves()))
    for leaf in tree.iter_leaves():
        leaf.add_feature("pos", leafPos[leaf.name])

    #set positions for all other nodes relative to their children
    #(postorder guarantees children are positioned before their parent)
    for node in tree.traverse(strategy="postorder"):
        if not node.is_leaf():
            node.add_feature("pos", getNodePos(node, method=posMethod))

    #draw
    #compare strings by value, not identity
    vertical = direction == "down"
    style = dict(color=col, alpha=alpha, solid_capstyle="round")

    def _xy(posValues, depthValues):
        #position runs along x when drawing downwards, otherwise along y
        return (posValues, depthValues) if vertical else (depthValues, posValues)

    #hide ticks on the position axis (done once rather than once per node)
    if vertical:
        plt.setp(plt.gca(), xticks=[])
    else:
        plt.setp(plt.gca(), yticks=[])

    for node in tree.traverse():
        for child in node.get_children():
            xs, ys = _xy([node.pos, child.pos], [node.depth, child.depth])
            plt.plot(xs, ys, linewidth=linewidth, **style)

        #NOTE(review): the source had lost its indentation; the tick at the node
        #depth is assumed to belong with the range line - confirm against the original
        if depthRangeDict:
            xs, ys = _xy([node.pos]*2, depthRangeDict[node.name])
            plt.plot(xs, ys, linewidth=1, **style)
            xs, ys = _xy([node.pos-.1, node.pos+.1], [node.depth]*2)
            plt.plot(xs, ys, linewidth=1, **style)

        if node.is_leaf():
            labelCol = taxColDict[node.name] if taxColDict else "black"
            if vertical:
                plt.text(node.pos, node.depth - 0.1, node.name,
                         horizontalalignment='center', verticalalignment='center', color=labelCol)
            else:
                plt.text(node.depth - 0.1, node.pos, node.name,
                         horizontalalignment='left', verticalalignment='center', color=labelCol)

    if show:
        plt.show()


def subset(things,subLen):
    """Split a sliceable sequence into consecutive chunks of length subLen.

    The final chunk is shorter when len(things) is not a multiple of subLen.
    """
    return [things[start:start+subLen] for start in range(0, len(things), subLen)]


def asColumn(a):
    """Return array a with a new trailing axis (a[:, np.newaxis])."""
    return a[:,np.newaxis]


def allIsNaN(x):
    """Return True if every element of x is NaN."""
    return np.all(np.isnan(x))


def normLength(tree, outgroup):
    """Divide every branch length in tree by the root-to-outgroup distance (in place)."""
    l = 1.*tree.get_distance(outgroup)
    for node in tree.traverse():
        node.dist /= l


def addNodeNames(tree):
    """Name every unnamed node after its sorted leaf names joined by '_' (in place).

    This function used to be defined twice; only the later definition
    (rename nodes whose name is the empty string) was ever in effect, so that is
    the behaviour kept here.
    """
    for n in tree.traverse():
        #compare by value: identity comparison against a literal is unreliable
        if n.name == "":
            n.name = "_".join(sorted(n.get_leaf_names()))


def treeToParentChildTable(tree):
    """Return a list of (parent name, child name, branch length) for every non-root node."""
    return [(n.up.name,n.name,n.dist) for n in tree.traverse() if n.up is not None]


def getLeafPairs(node):
    """Return an iterator over all (leaf, leaf) name pairs spanning the node's two children."""
    assert len(node.children) == 2, "Node {} does not have two children.".format(node.name)
    left, right = node.children
    return itertools.product(left.get_leaf_names(), right.get_leaf_names())


def make2DarrayFrom1DupperTriangle(upperTriangle1D,N,includesDiagnol=False):
    """Build a symmetric N x N array from a flattened upper triangle.

    upperTriangle1D -- values in row-major upper-triangle order.
    N               -- size of the square output.
    includesDiagnol -- whether upperTriangle1D includes the diagonal; if not,
                       the diagonal of the result is left at zero.
    """
    a = np.zeros([N,N])
    n = N if includesDiagnol else N-1
    rows, cols = np.triu_indices(n)
    #shift columns right by one to skip the diagonal
    if not includesDiagnol:
        cols = cols + 1
    #index with explicit (rows, cols) tuples: a list of index arrays is no longer
    #interpreted as a multidimensional index by numpy
    a[rows, cols] = upperTriangle1D
    a[cols, rows] = upperTriangle1D
    return a


########################## plot topos with average branch lengths

parser = argparse.ArgumentParser()
parser.add_argument("-w", "--weightsFiles", help="Input weights file(s) from Twisst", action = "store", nargs= "+", required = True)
parser.add_argument("-d", "--distsFiles", help="Input dists file(s) from Twisst", action = "store", nargs = "+", required = True)
parser.add_argument("-f", "--figFile", help="File for output figure", action = "store", required = True)
parser.add_argument("--figFormat", help="Format of figFile", action = "store", default="pdf")
parser.add_argument("--figSize", help="Size of figFile", action = "store", nargs=2, type=float, default=(10,10,))
parser.add_argument("--posMethod", help="Node positioning method", choices=(1,2,), type=int, action = "store", default = 2)
parser.add_argument("--quantiles", help="Add quantiles for each node in tree", type=float, nargs = 2, action = "store")
parser.add_argument("--plotTaxa", help="Prune tree to include on the specifed taxa", nargs = "+", action = "store")
parser.add_argument("--taxOrder", choices = ("levelorder", "preorder", "postorder", "predefined"), action = "store", default="levelorder",
                    help="How to determine order of taxa in plots")
parser.add_argument("--lineWidth", help="Width for tree lines", type=float, action = "store", default= 4)
parser.add_argument("--scaleLinesByWeights", help="Scale tree lines in figure by weights", action = "store_true")
parser.add_argument("--orderByWeights", help="Order tree plots by weights", action = "store_true")
parser.add_argument("--cols", help="Topology colours", nargs = "+", action = "store")
parser.add_argument("--taxCols", help="Taxon name colours", nargs = "+", action = "store")
parser.add_argument("--alpha", help="Topology alpha", type=float, action = "store", default=1.)
parser.add_argument("--layout", help="Rows and columns to plot", nargs=2, type=int, action = "store")
parser.add_argument("--tight", help="Pading for tight edges", nargs=2, type=float, action = "store")

args = parser.parse_args()

sys.stderr.write("\nReading distances file...")

#one header line per file; rows from all files are stacked
dists = np.vstack([np.loadtxt(f, skiprows=1) for f in args.distsFiles])

#get topologies
sys.stderr.write("\nGetting topologies...")

#The topologies are read from the leading lines of the first weights file:
#the last whitespace-separated token of each line is parsed as a tree, and
#reading stops at the first line that cannot be parsed.
#Open in text mode so that lines are str rather than bytes.
topos = []
with gzip.open(args.weightsFiles[0], "rt") as wf:
    for line in wf:
        try:
            topos.append(ete3.Tree(line.split()[-1]))
        except Exception:
            #first non-topology line (or a blank line) ends the topology block;
            #unlike a bare except this does not swallow KeyboardInterrupt/SystemExit
            break

nTopos = len(topos)

if nTopos == 0:
    raise ValueError("No topologies could be read from {}".format(args.weightsFiles[0]))

for t in topos: addNodeNames(t)

nTaxa = len(topos[0].get_leaves())

sys.stderr.write("\nThere are {} topologies and {} taxa".format(nTopos,nTaxa))

#make a separate set of topologies for plotting
plotTopos = [t.copy("newick") for t in topos]

if args.plotTaxa:
    for t in plotTopos: t.prune(args.plotTaxa)

if args.taxOrder == "predefined":
    #explicit check rather than assert, which is stripped under python -O
    if args.plotTaxa is None:
        raise ValueError("Predefined taxa order must be given using --plotTaxa.")
    taxOrder = [args.plotTaxa]*nTopos
else:
    taxOrder = [[node.name for node in topo.traverse(strategy=args.taxOrder) if node.is_leaf()] for topo in plotTopos]

if args.taxCols:
    if args.plotTaxa is None:
        raise ValueError("To plot coloured taxon labels, you must specify names of taxa using --plotTaxa")
    taxColDict = dict(zip(args.plotTaxa,args.taxCols))
else:
    taxColDict = None

#default grid of subplots for the usual numbers of topologies
if args.layout: nRow,nCol = args.layout
elif nTopos == 3: nRow,nCol = (1,3,)
elif nTopos == 15: nRow,nCol = (3,5,)
#105 topologies need 105 panels; the previous 3x5 grid only had room for 15
elif nTopos == 105: nRow,nCol = (7,15,)
else: raise ValueError("Please specify number of rows and columns in plot using --layout")

#get pair names in the dists file. The order is essential here.
#We use the order of the first N headers, but assume that the rest follow the same pattern.
#for example, if the taxa are called A, B, C and D, the headers should be:
#Topo1_A_B Topo1_A_C Topo1_A_D Topo1_B_C Topo1_B_D Topo1_C_D Topo2_A_B Topo2_A_C ... etc.
#Open in text mode so the header is str (splitting bytes on "_" fails).
with gzip.open(args.distsFiles[0], "rt") as df:
    pairs = df.readline().split()

pairs = [pair.split("_")[1:] for pair in pairs][:nTopos]

#the first nTaxa-1 pairs all start with the first taxon, so they name every taxon once
#(slicing up to nTaxa would append a duplicate from the following pair)
taxonNames = pairs[0] + [pair[1] for pair in pairs[1:nTaxa-1]]

#get columns for dists separated by topology
#integer division: the chunk length is used as a range step
topo_column_indices = subset(list(range(dists.shape[1])), dists.shape[1]//nTopos)

#split topologies into a third dimension
dists = np.dstack([dists[:,i] for i in topo_column_indices])
dists = np.swapaxes(dists,1,2)
#dists is now indexed as [line, topology, pair]

#set all missing dists to zero. Necessary for averaging, and doesn't impact results, because weighting for these is zero
dists = np.nan_to_num(dists)

############# weights

#read weights and convert to proportions
sys.stderr.write("\nReading weights file...")

#skip the topology lines and the header line
weights = np.vstack([np.loadtxt(f, skiprows=nTopos+1) for f in args.weightsFiles])

if weights.shape[0] != dists.shape[0]:
    raise ValueError("Weights and dists files have different numbers of lines ({} and {}).".format(weights.shape[0], dists.shape[0]))
if weights.shape[1] != dists.shape[1]:
    raise ValueError("Weights and dists files have different numbers of topologies ({} and {}).".format(weights.shape[1], dists.shape[1]))

#convert to proportions
rowSums = weights.sum(axis=1)
weights = weights / rowSums[:,np.newaxis]

#convert any nan to zero (lines whose weights sum to zero)
weights = np.nan_to_num(weights)

#get means
meanWeights = weights.mean(axis=0)

# now we need to get the average distance between leaves for each node in each topo
# the first step here is to get the two sets of leaves that descend from each node
nodes_all = [list(tree.traverse()) for tree in topos]
nodeNames = [[n.name for n in nodes] for nodes in nodes_all]

#for each internal node, a (row indices, column indices) tuple into the taxon distance matrix;
#None for leaves. Materialised as tuples because it is reused for every line of input
#and must be a valid numpy multidimensional index.
nodeLeafPairs = [[tuple(zip(*[(taxonNames.index(a),taxonNames.index(b),) for a,b in getLeafPairs(node)]))
                  if not node.is_leaf() else None
                  for node in nodes]
                 for nodes in nodes_all]

#make a depths array that gives the depth of each node for each topo at each window
depths = np.zeros([dists.shape[0], dists.shape[1], len(nodeNames[0])])

#now we go line by line, topology by topology and retrieve the depth
# as the average pairwise distance between all leaf pairs for each node
# unless the node is a leaf, in which case depth is zero.
sys.stderr.write("\nComputing depth for each node for each topology for each line in input...")

for x in range(depths.shape[0]):
    for y in range(nTopos):
        distMat = make2DarrayFrom1DupperTriangle(dists[x,y,:], nTaxa)
        for z, leafPairIndex in enumerate(nodeLeafPairs[y]):
            #leaves keep the zero the array was initialised with
            if leafPairIndex is not None:
                depths[x,y,z] = distMat[leafPairIndex].mean()

#scale depths by dividing by the depth of the root
#the first node in each topo is the root, as the traversal goes to the root first
depths = depths / depths[:,:,:1]

#anywhere we have nan is where the root depth was zero. This happens where we had missing data. So we can set all these tree depths to zero.
depths = np.nan_to_num(depths)

depths_average = np.average(depths, axis = 0, weights=np.repeat(weights[:,:,np.newaxis], depths.shape[2], axis=2))

#weighted median depth of each node in each topology, indexed [topology][node]
depths_median = [[wquantiles.median(depths[:,j,k], weights=weights[:,j]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]

if args.quantiles:
    depths_qL = [[wquantiles.quantile(depths[:,j,k], weights[:,j], args.quantiles[0]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]
    depths_qU = [[wquantiles.quantile(depths[:,j,k], weights[:,j], args.quantiles[1]) for k in range(depths.shape[2])] for j in range(depths.shape[1])]

#alternative palette, kept for reference:
#cols = np.array([
    #"#2BCE48", #Green
    #"#005C31", #Forest
    #"#94FFB5", #Jade
    #"#9DCC00", #Lime
    #"#426600", #Quagmire
    #"#00998F", #Turquoise
    #"#5EF1F2", #Sky
    #"#0075DC", #Blue
    #"#003380", #Navy
    #"#740AFF", #Violet
    #"#FF5005", #Zinnia
    #"#F0A3FF", #Amethyst
    #"#FFA405", #Orpiment
    #"#FF0010", #Red
    #"#C20088"]) #Mallow

cols = args.cols if args.cols else ["#000000"]*nTopos

#each topology is coloured by index, so fail early with a clear message
if len(cols) < nTopos:
    raise ValueError("--cols must provide at least one colour per topology ({} given, {} needed).".format(len(cols), nTopos))

lineWidths = np.array([args.lineWidth]*nTopos, dtype=float)
if args.scaleLinesByWeights: lineWidths *= (meanWeights/meanWeights.max())

#most heavily weighted topology first, if requested
plotOrder = np.argsort(meanWeights)[::-1] if args.orderByWeights else range(nTopos)
sys.stderr.write("\nMaking plot.")

plt.figure(figsize=args.figSize, frameon=False)

#one panel per topology, in plotting order
for i, x in enumerate(plotOrder):
    plt.subplot(nRow,nCol,i+1)

    nLeaves = len(taxOrder[x])

    #light grey horizontal guide lines at every 0.1 of relative depth
    for y in np.arange(0,1.1,0.1): plt.plot([0,nLeaves+1],[y,y],color="#CCCCCC")

    drawTree(plotTopos[x], leafPos = dict(zip(taxOrder[x],np.arange(1,nLeaves+1))),
             depthDict = dict(zip(nodeNames[x],depths_median[x])),
             depthRangeDict = dict(zip(nodeNames[x],zip(depths_qL[x],depths_qU[x]))) if args.quantiles else None,
             show=False, alpha = args.alpha, posMethod = args.posMethod, linewidth = lineWidths[x], col=cols[x], taxColDict=taxColDict)

    axes = plt.gca()
    axes.set_ylim([-0.1,1.1])
    axes.set_xlim([0.5,nLeaves+.5])
    for side in ("top", "right", "bottom", "left"):
        axes.spines[side].set_visible(False)

    #label the panel with the 1-based topology number
    plt.text(1,.95,"T"+str(x+1),color=cols[x], horizontalalignment='left', verticalalignment='center', size=12)

#plt.show()

if args.tight: plt.tight_layout(h_pad=args.tight[0], w_pad = args.tight[1])

#figure size and frame are properties of the figure created above;
#savefig does not accept figsize/frameon keywords
plt.savefig(args.figFile, format=args.figFormat)

plt.close()

sys.stderr.write("\nDone.\n")