#!/usr/bin/env python
"""Topology weighting of trees using root-to-leaf "node chains".

For every input tree and a set of taxa (groups of leaf names), the code
considers combinations made of one leaf per taxon, works out which of the
possible taxon topologies each combination matches, and accumulates a
weight per topology.  Optionally, mean pairwise leaf-to-leaf distances are
accumulated per topology as well.

Trees are either ``ete3`` trees or, with ``treeFormat="ts"``, trees in the
tree sequence format (msprime / tsinfer).
"""

import argparse
import gzip
import itertools
import random
import sys
from collections import defaultdict
from collections import deque

import ete3
import numpy as np

np.seterr(divide='ignore', invalid="ignore")

verbose = False

##############################################################################################################################


def sample(things, n=None, replace=False):
    """Return ``n`` random items from ``things``.

    Args:
        things: sequence to draw from.
        n: number of items to draw; defaults to ``len(things)``.
        replace: if true, draw with replacement using ``random.choice``;
            otherwise draw without replacement using ``random.sample``.
    """
    if n is None:
        n = len(things)
    if not replace:
        return random.sample(things, n)
    return [random.choice(things) for _ in range(n)]


def randomComboGen(lists):
    """Endlessly yield tuples holding one random element from each list."""
    while True:
        yield tuple(random.choice(l) for l in lists)


def readTree(newick_tree):
    """Parse a newick string into an ``ete3.Tree``; return None on failure.

    A leading ``[...]`` prefix is skipped before parsing.
    """
    try:
        if newick_tree[0] == "[":
            return ete3.Tree(newick_tree[newick_tree.index("]") + 1:])
        return ete3.Tree(newick_tree)
    except Exception:
        # Deliberately best-effort: any parse problem (empty line, missing
        # "]", malformed newick) is reported to the caller as None.
        return None


def asciiTrees(trees, nColumns=5):
    """Return the ASCII drawings of ``trees`` laid out ``nColumns`` per row."""
    treeLines = [tree.get_ascii().split("\n") for tree in trees]
    maxLines = max(map(len, treeLines))
    for tl in treeLines:
        #add lines if needed
        tl += [""] * (maxLines - len(tl))
        #add spaces to each line to make all even
        maxLen = max(map(len, tl))
        tl[:] = [ln.ljust(maxLen) for ln in tl]
    #now join lines that will be on the same row
    treeLinesChunked = [treeLines[x:(x + nColumns)] for x in range(0, len(trees), nColumns)]
    return "\n\n".join("\n".join(" ".join(row) for row in zip(*chunk))
                       for chunk in treeLinesChunked)


def getPrunedCopy(tree, leavesToKeep, preserve_branch_length):
    """Return a copy of ``tree`` pruned down to the leaves named in ``leavesToKeep``."""
    pruned = tree.copy("newick")
    # Build the lookup once: membership is tested for every leaf of the tree.
    keep = set(leavesToKeep)
    ##prune function was too slow for big trees
    ## speeding up by first deleting all other leaves
    for leaf in pruned.iter_leaves():
        if leaf.name not in keep:
            leaf.delete(preserve_branch_length=preserve_branch_length)
    #and then prune to fix the root (not sure why this is necessary, but it is)
    #but at least it's faster than pruning the full tree
    pruned.prune(leavesToKeep, preserve_branch_length=preserve_branch_length)
    return pruned


class NodeChain(deque):
    """A deque of nodes with an optional parallel deque of inter-node distances.

    ``dists`` is either None (distances not tracked) or a deque holding one
    distance per adjacent pair of nodes, i.e. ``len(self) - 1`` entries.
    ``_set_`` caches the nodes as a set once ``setSet`` has been called.
    """

    def __init__(self, nodeList, dists=None):
        super().__init__(nodeList)
        if dists is None:
            self.dists = None
        else:
            assert len(dists) == len(self) - 1, "incorrect number of iternode distances"
            self.dists = deque(dists)
        self._set_ = None

    def addNode(self, name, dist=0):
        """Append a node on the right, with its distance if distances are tracked."""
        self.append(name)
        if self.dists is not None:
            self.dists.append(dist)

    def addNodeLeft(self, name, dist=0):
        """Append a node on the left, with its distance if distances are tracked."""
        self.appendleft(name)
        if self.dists is not None:
            self.dists.appendleft(dist)

    def addNodeChain(self, chainToAdd, joinDist=0):
        """Extend on the right with ``chainToAdd``; ``joinDist`` links the two chains."""
        self.extend(chainToAdd)
        if self.dists is not None:
            assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
            self.dists.append(joinDist)
            self.dists.extend(chainToAdd.dists)

    def addNodeChainLeft(self, chainToAdd, joinDist=0):
        """Extend on the left with ``chainToAdd``; ``joinDist`` links the two chains.

        ``deque.extendleft`` adds items one at a time, so the nodes (and
        distances) of ``chainToAdd`` end up in reversed order.
        """
        self.extendleft(chainToAdd)
        if self.dists is not None:
            assert chainToAdd.dists is not None, "Cannot add a chain without distances to one with distances"
            self.dists.appendleft(joinDist)
            self.dists.extendleft(chainToAdd.dists)

    def chopLeft(self):
        """Drop the leftmost node (and its distance, if tracked)."""
        self.popleft()
        if self.dists is not None:
            self.dists.popleft()

    def chop(self):
        """Drop the rightmost node (and its distance, if tracked)."""
        self.pop()
        if self.dists is not None:
            self.dists.pop()

    def fuseLeft(self, chainToFuse):
        """Return a new chain joining this chain and ``chainToFuse`` end to end.

        Both chains must start at the same node.  Leading nodes shared by
        both chains are removed until only the last shared node remains, then
        the remaining nodes of ``chainToFuse`` are added on the left.
        ``self`` is not modified.
        """
        new = NodeChain(self, self.dists)
        assert new[0] == chainToFuse[0], "No common nodes"
        i = 1
        # drop shared leading nodes, keeping the last one both chains have in common
        while new[1] == chainToFuse[i]:
            new.chopLeft()
            i += 1
        m = len(chainToFuse)
        while i < m:
            new.addNodeLeft(chainToFuse[i], chainToFuse.dists[i - 1] if self.dists is not None else None)
            i += 1
        return new

    def simplifyToEnds(self, newDist=None):
        """Reduce the chain in place to its first and last node.

        If distances are tracked they are replaced by the single value
        ``newDist``, or by the sum of the old distances when ``newDist`` is
        None.  An explicit ``newDist`` of 0 is respected.
        """
        if self.dists is not None:
            if newDist is None:
                newDist = sum(self.dists)
            self.dists.clear()
        leftNode = self.popleft()
        rightNode = self.pop()
        self.clear()
        self.append(leftNode)
        self.append(rightNode)
        if self.dists is not None:
            self.dists.append(newDist)

    def setSet(self):
        """Cache the current nodes as a set in ``_set_``."""
        self._set_ = set(self)


##simpler version that only collapses monophyletic clades
#def getChainsToLeaves(node, collapseDict = None):
    #children = node.get_children()
    #if children == []:
        #node.add_feature("weight", 1)
        #return [NodeChain(node)]
    #chains = list(itertools.chain(*[getChainsToLeaves(child, collapseDict) for child in children]))
    #if (collapseDict and sum([len(chain) for chain in chains]) == len(chains) and
        #len(set([collapseDict[chain[0].name] for chain in chains])) == 1):
        ##all chains are a leaf from same group, so we collapse
        #newWeight = sum([chain[0].weight for chain in chains])
        #newDist = node.dist + sum([chain[0].dist * chain[0].weight * 1. for chain in chains]) / newWeight
        #chains[0][0].dist = newDist
        #chains[0][0].weight = newWeight
        #chains = [chains[0]]
    #else:
        #for chain in chains: chain.addNodeLeft(node, dist=chain[0].dist)
    #return chains


def getChainsToLeaves(node, collapseDict=None, preserveDists=False):
    """Return the chains from ``node`` down to leaves of an ete3 (sub)tree.

    Each returned ``NodeChain`` starts at ``node``, ends at a leaf and
    carries a ``weight`` attribute.  When ``collapseDict`` (leaf name ->
    group) is given, chains are merged while they are built and the weights
    of the merged chains are summed onto the chain that is kept.

    Args:
        node: ete3 node to start from.
        collapseDict: optional mapping of leaf name to group.
        preserveDists: if true, chains track branch lengths (``node.dist``)
            and only the single-group collapse is applied.
    """
    children = node.get_children()
    if children == []:
        #if it has no children it is a leaf, so just record a weight and return it as a new 1-node chain
        chain = NodeChain([node], dists=[] if preserveDists else None)
        chain.weight = 1
        return [chain]
    #otherwise get chains for all children
    childrenChains = [getChainsToLeaves(child, collapseDict, preserveDists) for child in children]
    #now we have the chains from all children, we need to add the current node
    for childChains in childrenChains:
        for chain in childChains:
            chain.addNodeLeft(node, dist=chain[0].dist if preserveDists else None)

    #flat list of every chain, in child order
    allChains = [chain for childChains in childrenChains for chain in childChains]

    #if not collapsing, just list all chains
    if not collapseDict:
        return allChains

    #if collapsing, check groups for each chain's leaf
    nodeGroupsAll = np.array([collapseDict[chain[-1].name] for chain in allChains])
    nodeGroups = list(set(nodeGroupsAll))
    nGroups = len(nodeGroups)

    if nGroups == 1 and len(nodeGroupsAll) > 1:
        #all leaves are from same group, so collapse to one chain
        #we can also preserve distances when doing this type of collapsing
        chains = allChains
        newWeight = sum([chain.weight for chain in chains])
        if preserveDists:
            #weighted mean of the total length of each chain
            newDist = sum([sum(chain.dists) * chain.weight * 1. for chain in chains]) / newWeight
            chains[0].simplifyToEnds(newDist=newDist)
        else:
            chains[0].simplifyToEnds()
        chains[0].weight = newWeight
        chains = [chains[0]]

    elif nGroups == 2 and len(nodeGroupsAll) > 2 and not preserveDists:
        #all chains end in a leaf from one of two groups, so we can simplify.
        chains = allChains
        #Start by getting index of each chain for each group
        indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups]
        #the new weight for each chain we keep will be the total node weight of all from each group
        newWeights = [sum([chains[i].weight for i in idx]) for idx in indices]
        #now reduce to just a chain for each group
        chains = [chains[idx[0]] for idx in indices]
        for j in range(nGroups):
            chains[j].simplifyToEnds()
            chains[j].weight = newWeights[j]

    #if we couldn't simply collapse completely, we might still be able to merge down a side branch
    #Side branches are child chains ending in a single leaf
    #If there is a lower level child branch that is itself a side branch, we can merge to it
    elif (not preserveDists and len(childrenChains) == 2 and
          ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or
           (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))):
        if len(childrenChains[0]) == 1:
            chains, sideChain = childrenChains[1], childrenChains[0][0]
        else:
            chains, sideChain = childrenChains[0], childrenChains[1][0]
        #now check if any main chain is suitable (should be length 3, and the only one that is such,
        #and have correct group)
        targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0]
        if len(targets) == 1 and collapseDict[chains[targets[0]][-1].name] == collapseDict[sideChain[-1].name]:
            #we have found a suitable chain to merge to
            targetChain = chains[targets[0]]
            newWeight = targetChain.weight + sideChain.weight
            targetChain.simplifyToEnds()
            targetChain.weight = newWeight
        else:
            #if we didn't find a suitable match, just add side chain
            chains.append(sideChain)
    else:
        #if there was no side chain, just list all chains
        chains = allChains
    return chains


#version for tree sequence tree format from msprime and tsinfer
def getChainsToLeaves_ts(tree, node=None, collapseDict=None):
    """Return the chains from ``node`` down to leaves of a tree-sequence tree.

    Nodes are whatever ``tree.root`` / ``tree.children`` return and are used
    directly as keys of ``collapseDict``.  Leaves missing from a supplied
    ``collapseDict`` produce no chain.  Distances are not tracked.

    Args:
        tree: object providing ``root`` and ``children(node)``.
        node: node to start from; defaults to ``tree.root``.
        collapseDict: optional mapping of leaf node to group.
    """
    if node is None:
        node = tree.root
    children = tree.children(node)
    if children == ():
        #if it has no children it is a leaf
        #if it's in the collapseDict or there is no collapseDict
        #just record a weight for the node and return it as a new 1-node chain
        if collapseDict is None or node in collapseDict:
            chain = NodeChain([node])
            chain.weight = 1
            return [chain]
        return []
    #otherwise get chains for all children
    childrenChains = [getChainsToLeaves_ts(tree, child, collapseDict) for child in children]
    #now we have the chains from all children, we need to add the current node
    for childChains in childrenChains:
        for chain in childChains:
            chain.addNodeLeft(node)

    #flat list of every chain, in child order
    allChains = [chain for childChains in childrenChains for chain in childChains]

    #if not collapsing, just list all chains
    if not collapseDict:
        return allChains

    #if collapsing, check groups for each chain's leaf
    nodeGroupsAll = np.array([collapseDict[chain[-1]] for chain in allChains])
    nodeGroups = list(set(nodeGroupsAll))
    nGroups = len(nodeGroups)

    if (nGroups == 1 and len(nodeGroupsAll) > 1) or (nGroups == 2 and len(nodeGroupsAll) > 2):
        #all chains end in a leaf from one or two groups, so we can simplify.
        chains = allChains
        #Start by getting index of each chain for each group
        indices = [(nodeGroupsAll == group).nonzero()[0] for group in nodeGroups]
        #the new weight for each chain we keep will be the total node weight of all from each group
        newWeights = [sum([chains[i].weight for i in idx]) for idx in indices]
        #now reduce to just a chain for each group
        chains = [chains[idx[0]] for idx in indices]
        for j in range(nGroups):
            chains[j].simplifyToEnds()
            chains[j].weight = newWeights[j]

    #if we couldn't simply collapse completely, we might still be able to merge down a side branch
    #Side branches are child chains ending in a single leaf
    #If there is a lower level child branch that is itself a side branch, we can merge to it
    elif (len(childrenChains) == 2 and
          ((len(childrenChains[0]) == 1 and len(childrenChains[1]) > 1) or
           (len(childrenChains[1]) == 1 and len(childrenChains[0]) > 1))):
        if len(childrenChains[0]) == 1:
            chains, sideChain = childrenChains[1], childrenChains[0][0]
        else:
            chains, sideChain = childrenChains[0], childrenChains[1][0]
        #now check if any main chain is suitable (should be length 3, and the only one that is such,
        #and have correct group)
        targets = (np.array([len(chain) for chain in chains]) == 3).nonzero()[0]
        if len(targets) == 1 and collapseDict[chains[targets[0]][-1]] == collapseDict[sideChain[-1]]:
            #we have found a suitable internal chain to merge to
            targetChain = chains[targets[0]]
            newWeight = targetChain.weight + sideChain.weight
            targetChain.simplifyToEnds()
            targetChain.weight = newWeight
        else:
            #if we didn't find a suitable match, just add side chain
            chains.append(sideChain)
    else:
        #if there was no side chain, just list all chains
        chains = allChains
    return chains


def makeRootLeafChainDict(tree, collapseDict=None, preserveDists=False, treeFormat="ete3"):
    """Map each retained leaf to its root-to-leaf chain.

    Keys are leaf nodes for ``treeFormat="ts"`` and leaf names otherwise.
    ``preserveDists`` is only used for the ete3 format.
    """
    if treeFormat == "ts":
        chains = getChainsToLeaves_ts(tree, collapseDict=collapseDict)
        return {chain[-1]: chain for chain in chains}
    chains = getChainsToLeaves(tree, collapseDict=collapseDict, preserveDists=preserveDists)
    return {chain[-1].name: chain for chain in chains}


def makeLeafLeafChainDict(rootLeafChainDict, pairs):
    """Build ``d[a][b]`` = leaf-to-leaf chain for every ``(a, b)`` in ``pairs``."""
    leafLeafChainDict = defaultdict(defaultdict)
    for first, second in pairs:
        #get the leaf leaf chain by removing the unshared ancestors and joining root leaf chains end to end
        leafLeafChainDict[first][second] = rootLeafChainDict[first].fuseLeft(rootLeafChainDict[second])
    return leafLeafChainDict


def checkDisjointChains(leafLeafChains, pairsOfPairs, samples=None):
    """For each pair of pairs, report whether the two leaf-leaf chains share no node.

    Without ``samples`` the entries of ``pairsOfPairs`` are used directly as
    keys of ``leafLeafChains``; with ``samples`` they are indices into it.
    Every chain looked up must already have had ``setSet`` called.
    """
    if not samples:
        return [leafLeafChains[a][b]._set_.isdisjoint(leafLeafChains[c][d]._set_)
                for (a, b), (c, d) in pairsOfPairs]
    return [leafLeafChains[samples[a]][samples[b]]._set_.isdisjoint(
                leafLeafChains[samples[c]][samples[d]]._set_)
            for (a, b), (c, d) in pairsOfPairs]


def pairsDisjoint(pair1, pair2):
    """Return True if the two pairs have no element in common."""
    return pair1[0] != pair2[0] and pair1[0] != pair2[1] and pair1[1] != pair2[0] and pair1[1] != pair2[1]


def makeTopoDict(taxonNames, topos=None, outgroup=None):
    """Build the lookup used to recognise topologies.

    Returns a dict with keys ``"topos"`` (the topologies, generated with
    ``allTopos`` unless supplied), ``"n"`` (their number), ``"pairsOfPairs"``
    and ``"pairsOfPairsNumeric"`` (all non-overlapping pairs of taxon pairs,
    by name and by index), and ``"chainsDisjoint"`` (for each topology, the
    result of ``checkDisjointChains`` over ``"pairsOfPairs"``).
    If ``outgroup`` is given every topology is rooted on it in place.
    """
    output = {}
    output["topos"] = allTopos(taxonNames, []) if topos is None else topos
    if outgroup:
        for topo in output["topos"]:
            topo.set_outgroup(outgroup)
    output["n"] = len(output["topos"])
    pairs = list(itertools.combinations(taxonNames, 2))
    pairsNumeric = list(itertools.combinations(range(len(taxonNames)), 2))
    output["pairsOfPairs"] = [y for y in itertools.combinations(pairs, 2) if pairsDisjoint(y[0], y[1])]
    output["pairsOfPairsNumeric"] = [y for y in itertools.combinations(pairsNumeric, 2) if pairsDisjoint(y[0], y[1])]
    output["chainsDisjoint"] = []
    for tree in output["topos"]:
        rootLeafChains = makeRootLeafChainDict(tree)
        leafLeafChains = makeLeafLeafChainDict(rootLeafChains, pairs)
        for pair in pairs:
            leafLeafChains[pair[0]][pair[1]].setSet()
        output["chainsDisjoint"].append(checkDisjointChains(leafLeafChains, output["pairsOfPairs"]))
    return output


def makeGroupDict(groups, names=None):
    """Map every member of every group to its group's name (or index if no names)."""
    groupDict = {}
    for x, group in enumerate(groups):
        label = x if not names else names[x]
        for member in group:
            groupDict[member] = label
    return groupDict


#Main weighting function that uses "chains" to check topologies and simplifies while generating chains
def weightTree(tree, taxa, taxonDict=None, pairs=None, topoDict=None, nIts=None, getDists=False, simplify=True,
               abortCutoff=None, treeFormat="ete3", verbose=True, taxonNames=None, outgroup=None):
    """Compute topology weights (and optionally mean distances) for one tree.

    Args:
        tree: the tree to analyse (ete3 tree, or a tree-sequence tree when
            ``treeFormat="ts"``).
        taxa: list of lists of leaf identifiers, one list per taxon.
        taxonDict: leaf -> taxon mapping; built from ``taxa`` if not given.
        pairs: leaf pairs between different taxa; built from ``taxa`` if None.
        topoDict: output of ``makeTopoDict``; built if None.
        nIts: number of combinations to consider; all of them if None.
        getDists: also accumulate pairwise distances (needs ``taxonNames``).
        simplify: collapse chains while building them; must be false when
            only a random subset of combinations is considered.
        abortCutoff: if set, give up when the number of combinations for
            complete weighting exceeds it.
        verbose: write progress messages to stderr.
        taxonNames: taxon names, in the same order as ``taxa``.
        outgroup: passed to ``makeTopoDict`` when ``topoDict`` is built here.

    Returns:
        A dict with ``"topos"``, ``"weights"`` (one count per topology) and
        ``"dists"`` (array of shape nTaxa x nTaxa x nTopos, or NaN when
        ``getDists`` is false); or None if aborted because of ``abortCutoff``.
    """
    nTaxa = len(taxa)
    # Use the taxon names as groups when available so that distance lookups
    # via taxonNames.index() work with the default mapping too.
    if not taxonDict:
        taxonDict = makeGroupDict(taxa, names=taxonNames)
    if pairs is None:
        pairs = [pair for taxPair in itertools.combinations(taxa, 2) for pair in itertools.product(*taxPair)]
    rootLeafChains = makeRootLeafChainDict(tree, collapseDict=taxonDict if simplify else None,
                                           preserveDists=getDists, treeFormat=treeFormat)
    leavesRetainedSet = set(rootLeafChains)
    leafWeights = {ind: chain.weight for ind, chain in rootLeafChains.items()}
    _pairs = [pair for pair in pairs if pair[0] in leavesRetainedSet and pair[1] in leavesRetainedSet]
    leafLeafChains = makeLeafLeafChainDict(rootLeafChains, pairs=_pairs)
    #make a set for each chain so that disjointness checks are fast
    for pair in _pairs:
        leafLeafChains[pair[0]][pair[1]].setSet()
    if topoDict is None:
        if taxonNames is None:
            taxonNames = [str(x) for x in range(len(taxa))]
        topoDict = makeTopoDict(taxonNames, outgroup=outgroup)
    _taxa = [[ind for ind in taxon if ind in leavesRetainedSet] for taxon in taxa]
    if getDists:
        assert taxonNames is not None, "taxonNames required for recording pairwise distances"
        dists = np.zeros([nTaxa, nTaxa, topoDict["n"]])
    #number of combos, as an exact Python int so that it cannot overflow
    nCombos = 1
    for t in _taxa:
        nCombos *= len(t)
    #if not specified assume all combinations must be considered
    if nIts is None:
        nIts = nCombos
    #if doing all combos, we use an exhaustive combo generator
    if nIts >= nCombos:
        if verbose:
            sys.stderr.write("Complete weighting for {} combinations\n".format(nCombos))
        #unless there are too many combos, in which case we abort
        if abortCutoff and nCombos > abortCutoff:
            if verbose:
                sys.stderr.write("Aborting\n")
            return None
        comboGenerator = itertools.product(*_taxa)
    #if we are doing a subset, then use a random combo generator, but make sure simplify was false
    else:
        #sys.stderr.write("Approximate weighting with {} combinations\n".format(nIts))
        assert not simplify, "Tree simplification should be turned off when considering only a subset of combinations."
        comboGenerator = randomComboGen(_taxa)
    #initialise counts array
    counts = [0] * topoDict["n"]
    i = 0
    for combo in comboGenerator:
        chainsDisjoint = checkDisjointChains(leafLeafChains, topoDict["pairsOfPairsNumeric"], samples=combo)
        x = topoDict["chainsDisjoint"].index(chainsDisjoint)
        comboWeight = np.prod([leafWeights[ind] for ind in combo])
        counts[x] += comboWeight
        #get pairwise dists if necessary
        if getDists:
            comboPairs = [(combo[pair[0]], combo[pair[1]],)
                          for pairOfPairs in topoDict["pairsOfPairsNumeric"] for pair in pairOfPairs]
            currentDists = np.zeros([nTaxa, nTaxa])
            for comboPair in comboPairs:
                taxPair = (taxonNames.index(taxonDict[comboPair[0]]), taxonNames.index(taxonDict[comboPair[1]]))
                currentDists[taxPair[0], taxPair[1]] = currentDists[taxPair[1], taxPair[0]] = sum(leafLeafChains[comboPair[0]][comboPair[1]].dists)
            dists[:, :, x] += currentDists * comboWeight
        i += 1
        if i == nIts:
            break
    # np.nan rather than np.NaN: the latter alias was removed in NumPy 2.0
    meanDists = dists / counts if getDists else np.nan
    return {"topos": topoDict["topos"], "weights": counts, "dists": meanDists}


def weightTrees(trees, taxa=None, taxonDict=None, pairs=None, topoDict=None, nIts=None, getDists=False, simplify=True,
                abortCutoff=None, treeFormat="ete3", verbose=True, taxonNames=None, outgroup=None):
    """Run ``weightTree`` on every tree and stack the results.

    For ``treeFormat="ts"``, ``trees`` is a tree sequence and ``taxa`` may be
    omitted, in which case taxa are built from its populations and samples.
    Otherwise ``trees`` is an iterable of trees and ``taxa`` is required.

    Returns a dict with ``"topos"``, ``"dists"``, ``"weights"`` (array with
    one row per tree) and ``"weights_norm"`` (rows divided by their sum).
    NOTE: a tree for which ``weightTree`` returns None (``abortCutoff``
    exceeded) is not handled here.
    """
    if taxa is None:
        assert(treeFormat == "ts"), "Taxa must be specified as a list of lists."
        if taxonNames is None:
            taxonNames = [str(pop.id) for pop in trees.populations()]
        taxa = [[s for s in trees.samples() if str(trees.get_population(s)) == t] for t in taxonNames]
    if topoDict is None:
        if taxonNames is None:
            taxonNames = [str(x) for x in range(len(taxa))]
        topoDict = makeTopoDict(taxonNames, outgroup=outgroup)
    if not taxonDict:
        taxonDict = makeGroupDict(taxa, names=taxonNames)
    if pairs is None:
        pairs = [pair for taxPair in itertools.combinations(taxa, 2) for pair in itertools.product(*taxPair)]
    _trees_ = trees.trees() if treeFormat == "ts" else trees
    # taxonNames must be forwarded: weightTree requires it when getDists is set
    allTreeData = [weightTree(tree, taxa, taxonDict=taxonDict, pairs=pairs, topoDict=topoDict, nIts=nIts,
                              getDists=getDists, simplify=simplify, abortCutoff=abortCutoff,
                              treeFormat=treeFormat, verbose=verbose, taxonNames=taxonNames,
                              outgroup=outgroup)
                   for tree in _trees_]
    # same object weightTree returns for every tree; does not need a first tree to exist
    output = {"topos": topoDict["topos"]}
    output["dists"] = np.array([x["dists"] for x in allTreeData])
    output["weights"] = np.array([x["weights"] for x in allTreeData])
    output["weights_norm"] = np.apply_along_axis(lambda x: x / x.sum(), 1, output["weights"])
    return output


def summary(weightsData):
    """Print each topology with its mean normalised weight across trees."""
    if "weights_norm" not in weightsData:
        weights = np.apply_along_axis(lambda x: x / x.sum(), 1, weightsData["weights"])
    else:
        weights = weightsData["weights_norm"]
    meanWeights = weights.mean(axis=0)
    for i, meanWeight in enumerate(meanWeights):
        print("Topo", i + 1)
        print(weightsData["topos"][i].get_ascii())
        print(round(meanWeight, 3))
        print("\n\n")


def listToNwk(t):
    """Turn a nested list such as ``[['A', 'B'], 'C']`` into a newick string."""
    t = str(t)
    t = t.replace("[", "(")
    t = t.replace("]", ")")
    t = t.replace("'", "")
    t += ";"
    return t


def allTopos(branches, _topos=None, _topo_IDs=None):
    """Return all distinct topologies for the given branches as ete3 trees.

    Branches are joined pairwise, recursively, until three remain; the
    resulting tree is kept if its ``get_topology_id()`` has not been seen.
    ``_topos`` and ``_topo_IDs`` are accumulators for the recursion; unless
    both are supplied they are started afresh.
    """
    if _topos is None or _topo_IDs is None:
        _topos = []
        _topo_IDs = set([])
    assert 4 <= len(branches) <= 8, "Please specify between 4 and 8 unique taxon names."
    #print("topos contains", len(_topos), "topologies.")
    #print("current tree is:", branches)
    for x in range(len(branches) - 1):
        for y in range(x + 1, len(branches)):
            #print("Joining branch", x, branches[x], "with branch", y, branches[y])
            new_branches = list(branches)
            new_branches[x] = [new_branches[x], new_branches.pop(y)]
            #print("New tree is:", new_branches)
            if len(new_branches) == 3:
                #print("Tree has three branches, so appending to topos.")
                #now check that the topo doesn't match a topology already in trees, and if not add it
                t = ete3.Tree(listToNwk(new_branches))
                ID = t.get_topology_id()
                if ID not in _topo_IDs:
                    _topos.append(t)
                    _topo_IDs.add(ID)
            else:
                #print("Tree still unresolved, so re-calling function.")
                _topos = allTopos(new_branches, _topos, _topo_IDs)
    #print(_topo_IDs)
    return _topos


def writeWeights(filename, weightsData):
    """Write topologies, a header and the weights table to ``filename``.

    The file is gzip-compressed when ``filename`` ends with ``.gz``.
    """
    nTopos = len(weightsData["topos"])
    with gzip.open(filename, "wt") if filename.endswith(".gz") else open(filename, "wt") as weightsFile:
        #write topologies
        for x in range(nTopos):
            weightsFile.write("#topo" + str(x + 1) + " " + weightsData["topos"][x].write(format=9) + "\n")
        #write headers
        weightsFile.write("\t".join(["topo" + str(x + 1) for x in range(nTopos)]) + "\n")
        #write weights
        weightsFile.write("\n".join(["\t".join(row) for row in weightsData["weights"].astype(str)]) + "\n")


def writeTsWindowData(filename, ts):
    """Write the interval of every tree in ``ts`` to ``filename``.

    Columns are ``chrom``, ``start`` and ``end``; the chromosome is always
    written as ``chr1``.
    """
    # the path is the ``filename`` argument, not the literal string "filename"
    with open(filename, "wt") as dataFile:
        dataFile.write("chrom\tstart\tend\n")
        dataFile.write("\n".join(["\t".join(["chr1", str(tree.interval[0]), str(tree.interval[1])])
                                  for tree in ts.trees()]) + "\n")


#################################################################################################################################


def main():
    """Command-line entry point: weight every tree in a file of newick trees."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--treeFile", help="File containing tree(s) to analyse", action = "store")
    parser.add_argument("-w", "--weightsFile", help="Output file of all weights", action = "store")
    parser.add_argument("-D", "--distsFile", help="Output file of mean pairwise dists", action = "store", required = False)
    parser.add_argument("--inputTopos", help="Input file for user-defined topologies (optional)", action = "store", required = False)
    parser.add_argument("--outputTopos", help="Output file for topologies used", action = "store", required = False)
    parser.add_argument("--outgroup", help="Outgroup for rooting - only affects speed", action = "store")
    parser.add_argument("--method", help="Tree sampling method", choices=["fixed", "complete"], action = "store", default = "complete")
    parser.add_argument("--iterations", help="Number of iterations for fixed partial sampling", type=int, action = "store", default = 10000)
    parser.add_argument("--abortCutoff", help="# tips in simplified tree to abort 'complete' weighting", type=int, action = "store", default = 100000)
    parser.add_argument("-g", "--group", help="Group name and individual names (separated by commas)", action='append', nargs="+", required = True, metavar=("name","[inds]"))
    parser.add_argument("--groupsFile", help="Optional file of sample names and groups", action = "store", required = False)
    parser.add_argument("--verbose", help="Verbose output", action="store_true")
    parser.add_argument("--skip_simplify", help="", action="store_true")
    parser.add_argument("--silent", help="No stderr output", action="store_true")

    args = parser.parse_args()
    #args = parser.parse_args("-t examples/msms_4of10_l1Mb_r10k_sweep.seq_gen.SNP.w50sites.phyml_bionj.trees.gz -g A 1,2,3,4,5,6,7,8,9,10 -g B 11,12,13,14,15,16,17,18,19,20 -g C 21,22,23,24,25,26,27,28,29,30 -g D 31,32,33,34,35,36,37,38,39,40".split())

    getDists = args.distsFile is not None
    method = args.method

    #################################################################################################################################
    #parse taxa
    assert len(args.group) >= 4, "Please specify at least four groups."
    taxonNames = []
    taxa = []
    for g in args.group:
        taxonNames.append(g[0])
        if len(g) > 1:
            taxa.append(g[1].split(","))
        else:
            taxa.append([])

    if args.groupsFile:
        with open(args.groupsFile, "rt") as gf:
            #blank lines are skipped; every other line is "<sample> <group>"
            groupDict = dict([ln.split() for ln in gf if ln.strip()])
        for sampleName, groupName in groupDict.items():
            try:
                taxa[taxonNames.index(groupName)].append(sampleName)
            except ValueError:
                #sample belongs to a group that was not requested with -g
                pass

    nTaxa = len(taxa)
    assert min([len(t) for t in taxa]) >= 1, "Please specify at least one sample name per group."

    names = [t for taxon in taxa for t in taxon]
    namesSet = set(names)
    assert len(names) == len(namesSet), "Each sample should only be in one group."

    taxonDict = makeGroupDict(taxa, taxonNames)

    #get all topologies
    if args.inputTopos:
        with open(args.inputTopos, "rt") as tf:
            topos = [ete3.Tree(ln) for ln in tf.readlines()]
    else:
        topos = None

    topoDict = makeTopoDict(taxonNames, topos, args.outgroup if args.outgroup else None)
    nTopos = topoDict["n"]

    if not args.silent:
        sys.stderr.write(asciiTrees(topoDict["topos"], 5) + "\n")

    if args.outputTopos:
        with open(args.outputTopos, "wt") as tf:
            tf.write("\n".join([t.write(format = 9) for t in topoDict["topos"]]) + "\n")

    pairs = [pair for taxPair in itertools.combinations(taxa, 2) for pair in itertools.product(*taxPair)]

    #################################################################################################################################
    ### file for weights
    if args.weightsFile:
        weightsFile = gzip.open(args.weightsFile, "wt") if args.weightsFile.endswith(".gz") else open(args.weightsFile, "wt")
    else:
        weightsFile = sys.stdout

    for x in range(nTopos):
        weightsFile.write("#topo" + str(x + 1) + " " + topoDict["topos"][x].write(format = 9) + "\n")

    weightsFile.write("\t".join(["topo" + str(x + 1) for x in range(nTopos)]) + "\n")

    ### file for lengths
    if getDists:
        if args.distsFile[-3:] == ".gz":
            distsFile = gzip.open(args.distsFile, "wt")
        else:
            distsFile = open(args.distsFile, "wt")
        #NOTE: each topology's block of column names is followed by a tab, including the
        #last one, so the header line ends with a trailing tab (kept for compatibility)
        for x in range(nTopos):
            distsFile.write("\t".join(["topo" + str(x + 1) + "_" + "_".join(pair) for pair in itertools.combinations(taxonNames, 2)]) + "\t")
        distsFile.write("\n")

    ################################################################################################################################
    #open tree file
    if args.treeFile:
        treeFile = gzip.open(args.treeFile, "rt") if args.treeFile.endswith(".gz") else open(args.treeFile, "rt")
    else:
        treeFile = sys.stdin

    ################################################################################################################################
    nTrees = 0

    for line in treeFile:
        tree = readTree(line)

        if tree:
            #remove unneccesary leaves (speeds up downstream steps)
            leafNamesSet = set([leaf.name for leaf in tree.get_leaves()])
            if namesSet != leafNamesSet:
                assert namesSet.issubset(leafNamesSet), "Named samples not present in tree:" + " ".join(list(namesSet.difference(leafNamesSet)))
                tree = getPrunedCopy(tree, leavesToKeep=names, preserve_branch_length=True)

            #root tree (this only helps speed up analysis, but does not change results)
            if args.outgroup:
                tree.set_outgroup(taxa[taxonNames.index(args.outgroup)][-1])

            weightsData = None

            if method == "complete":
                weightsData = weightTree(tree=tree, taxa=taxa, taxonDict=taxonDict, pairs=pairs,
                                         topoDict=topoDict, getDists=getDists,
                                         simplify=not args.skip_simplify, abortCutoff=args.abortCutoff,
                                         verbose=args.verbose, taxonNames=taxonNames)

            #fixed sampling, also used as the fallback when complete weighting was aborted
            if method == "fixed" or weightsData is None:
                weightsData = weightTree(tree=tree, taxa=taxa, taxonDict=taxonDict, pairs=pairs,
                                         topoDict=topoDict, nIts=args.iterations, getDists=getDists,
                                         simplify=False, verbose=args.verbose, taxonNames=taxonNames)

            weightsLine = "\t".join([str(x) for x in weightsData["weights"]])

            if getDists:
                distsByTopo = []
                for x in range(nTopos):
                    distsByTopo.append("\t".join([str(round(weightsData["dists"][pair[0], pair[1], x], 4)) for pair in itertools.combinations(range(nTaxa), 2)]))
                distsLine = "\t".join(distsByTopo)

        else:
            if not args.silent:
                sys.stderr.write("Warning - failed to read tree.\n")
            weightsLine = "\t".join(["nan"] * nTopos)
            if getDists:
                distsLine = "\t".join(["nan"] * nTopos * len(list(itertools.combinations(range(nTaxa), 2))))

        weightsFile.write(weightsLine + "\n")
        if getDists:
            distsFile.write(distsLine + "\n")

        nTrees += 1

        if not args.silent:
            print(".", end="", file=sys.stderr, flush=True)
            if nTrees % 100 == 0:
                sys.stderr.write(str(nTrees) + "\n")

    treeFile.close()
    weightsFile.close()
    if getDists:
        distsFile.close()

    if not args.silent:
        sys.stderr.write(str(nTrees) + "\nDone.\n")

    sys.exit()


if __name__ == "__main__":
    main()

#############################################################################################################################################