import argparse
import Bio
import Bio.Phylo
import gzip
import os, json, sys
import pandas as pd
import subprocess
import shlex
from contextlib import contextmanager
from treetime.utils import numeric_date
from collections import defaultdict
from pkg_resources import resource_stream
from io import TextIOWrapper
from .__version__ import __version__
import packaging.version as packaging_version
from .validate import validate, ValidateError, load_json_schema

from augur.util_support.date_disambiguator import DateDisambiguator
from augur.util_support.metadata_file import MetadataFile
from augur.util_support.shell_command_runner import ShellCommandRunner


class AugurException(Exception):
    pass


@contextmanager
def open_file(fname, mode):
    """Open a file using either gzip.open() or open() depending on file name. Semantics identical to open()."""
    if fname.endswith('.gz'):
        if "t" not in mode:
            # gzip opens files in binary mode by default, and an encoding is
            # only valid in text mode, so request text mode explicitly.
            mode = mode + "t"
        with gzip.open(fname, mode, encoding='utf-8') as fh:
            yield fh
    else:
        with open(fname, mode, encoding='utf-8') as fh:
            yield fh


def is_vcf(fname):
    """Convenience method to check if a file is a VCF file.

    >>> is_vcf("./foo")
    False
    >>> is_vcf("./foo.vcf")
    True
    >>> is_vcf("./foo.vcf.GZ")
    True
    """
    return fname.lower().endswith(".vcf") or fname.lower().endswith(".vcf.gz")


def myopen(fname, mode):
    """Like open_file(), but returns the file handle directly rather than
    acting as a context manager."""
    if fname.endswith('.gz'):
        if "t" not in mode:
            # As in open_file(): an encoding requires gzip text mode.
            mode = mode + "t"
        return gzip.open(fname, mode, encoding='utf-8')
    else:
        return open(fname, mode, encoding='utf-8')


def get_json_name(args, default=None):
    if args.output_node_data:
        return args.output_node_data
    else:
        if default:
            print("WARNING: no name for the output file was specified. Writing results to %s." % default, file=sys.stderr)
            return default
        else:
            raise ValueError("Please specify a name for the JSON file containing the results.")


def ambiguous_date_to_date_range(uncertain_date, fmt, min_max_year=None):
    return DateDisambiguator(uncertain_date, fmt=fmt, min_max_year=min_max_year).range()


def read_metadata(fname, query=None):
    return MetadataFile(fname, query).read()


def get_numerical_dates(meta_dict, name_col=None, date_col='date', fmt=None, min_max_year=None):
    if fmt:
        from datetime import datetime
        numerical_dates = {}
        for k, m in meta_dict.items():
            v = m[date_col]
            if type(v) != str:
                print("WARNING: %s has an invalid date string:" % k, v)
                continue
            elif 'XX' in v:
                ambig_date = ambiguous_date_to_date_range(v, fmt, min_max_year)
                if ambig_date is None or None in ambig_date:
                    numerical_dates[k] = [None, None]  # don't send to numeric_date or it will be set to today
                else:
                    numerical_dates[k] = [numeric_date(d) for d in ambig_date]
            else:
                try:
                    numerical_dates[k] = numeric_date(datetime.strptime(v, fmt))
                except ValueError:
                    numerical_dates[k] = None
    else:
        # Without a format, assume the date column already holds numeric dates.
        numerical_dates = {k: float(m[date_col]) for k, m in meta_dict.items()}

    return numerical_dates


class InvalidTreeError(Exception):
    """Represents an error loading a phylogenetic tree from a filename."""
    pass


def read_tree(fname, min_terminals=3):
    """Safely load a tree from a given filename or raise an error if the file does
    not contain a valid tree.

    Parameters
    ----------
    fname : str
        name of a file containing a phylogenetic tree

    min_terminals : int
        minimum number of terminals required for the parsed tree as a sanity
        check on the tree

    Raises
    ------
    InvalidTreeError
        If the given file exists but does not seem to contain a valid tree format.

    Returns
    -------
    Bio.Phylo :
        BioPython tree instance
    """
    T = None
    supported_tree_formats = ["newick", "nexus"]
    for fmt in supported_tree_formats:
        try:
            T = Bio.Phylo.read(fname, fmt)

            # Check the sanity of the parsed tree to handle cases when non-tree
            # data are still successfully parsed by BioPython. Too few terminals
            # in a tree indicates that the input is not valid.
            if T.count_terminals() < min_terminals:
                T = None
            else:
                break
        except ValueError:
            # We cannot open the tree in the current format, so we will try
            # another.
            pass

    # If the tree cannot be loaded, raise an error to that effect.
    if T is None:
        raise InvalidTreeError(
            "Could not read the given tree %s using the following supported formats: %s" % (fname, ", ".join(supported_tree_formats))
        )

    return T
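
# Illustrative sketch (not part of the augur API): a minimal round trip
# through open_file() and read_tree(). The path "_example.nwk" is a
# hypothetical scratch file created purely for demonstration.
def _example_read_tree(tmp_name="_example.nwk"):
    with open_file(tmp_name, "w") as fh:
        fh.write("(A:1,(B:1,C:1):2,(D:1,E:1):2);")
    T = read_tree(tmp_name)  # raises InvalidTreeError for non-tree input
    return T.count_terminals()  # 5
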
def read_node_data(fnames, tree=None):
    """
    Parses one or more "node-data" JSON files and combines them using custom
    logic. Will exit with a (hopefully) helpful message if errors are detected.

    For each JSON, we expect the top-level key "nodes" to be a dict.
    "generated_by" fields are not included in the dict returned by this function.
    """
    if isinstance(fnames, str):
        fnames = [fnames]

    node_data = {"nodes": {}}
    for fname in fnames:
        if os.path.isfile(fname):
            with open(fname, encoding='utf-8') as jfile:
                tmp_data = json.load(jfile)
            if tmp_data.get("annotations"):
                try:
                    validate(tmp_data.get("annotations"), load_json_schema("schema-annotations.json"), fname)
                except ValidateError as err:
                    print("{} contains an `annotations` block of an invalid JSON format. "
                          "Was it produced by a different version of augur than the one you are currently using ({})? "
                          "Please check the script / program which produced that JSON file.".format(fname, get_augur_version()))
                    print(err)
                    sys.exit(2)
            try:
                for k, v in tmp_data.items():
                    if k == "nodes":
                        if not isinstance(v, dict):
                            raise AugurException("\"nodes\" key in {} is not a dictionary. Please check the formatting of this JSON!".format(fname))
                        for n, nv in v.items():
                            if n in node_data["nodes"]:
                                node_data["nodes"][n].update(nv)
                            else:
                                node_data["nodes"][n] = nv
                    elif k == "generated_by":
                        # Note that this key is _not_ part of the dict returned from this fn.
                        # Check that the augur version, if provided, is compatible;
                        # ignore version checking of non-augur-produced JSONs.
                        if v.get("program") == "augur" and not is_augur_version_compatable(v.get("version")):
                            raise AugurException("Augur version incompatibility detected -- the JSON {} was generated by augur version {} which is "
                                                 "incompatible with the current augur version ({}). We suggest you rerun the pipeline using the current "
                                                 "version of augur.".format(fname, v.get("version"), get_augur_version()))
                    elif k in node_data:
                        # Behavior as of 2019-11-07 is to do a top-level merge
                        # of dictionaries. If the value is not a dictionary, we
                        # now have a fatal error with a nice message (note that
                        # before 2019-11-07 this was an unhandled error).
                        # This should be revisited in the future. TODO.
                        if isinstance(node_data[k], dict) and isinstance(v, dict):
                            node_data[k].update(v)
                        else:
                            raise AugurException("\"{}\" key found in multiple JSONs. This is not currently handled by augur, "
                                                 "unless all values are dictionaries. "
                                                 "Please check the source of these JSONs.".format(k))
                    else:
                        node_data[k] = v
            except AugurException as e:
                print(e)
                sys.exit(2)
        else:
            print("ERROR: node_data JSON file %s not found. Attempting to proceed without it." % fname)

    if tree and os.path.isfile(tree):
        try:
            T = Bio.Phylo.read(tree, 'newick')
        except Exception:
            print("Failed to read tree from file " + tree, file=sys.stderr)
        else:
            tree_node_names = set([l.name for l in T.find_clades()])
            meta_node_names = set(node_data["nodes"].keys())
            if tree_node_names != meta_node_names:
                print("Names of nodes (including internal nodes) of tree %s don't"
                      " match node names in the node data files." % tree, file=sys.stderr)

    return node_data
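
# Illustrative sketch (not part of the augur API): how read_node_data()
# merges per-node annotations from several JSONs. File names here are
# hypothetical; write_json() (defined below) stamps each file with the
# running augur version, so the compatibility check passes.
def _example_merge_node_data():
    write_json({"nodes": {"A": {"clade": 1}}}, "_ex_clades.json")
    write_json({"nodes": {"A": {"num_date": 2015.3}}}, "_ex_dates.json")
    merged = read_node_data(["_ex_clades.json", "_ex_dates.json"])
    return merged["nodes"]["A"]  # {'clade': 1, 'num_date': 2015.3}
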
def write_json(data, file_name, indent=(None if os.environ.get("AUGUR_MINIFY_JSON") else 2), include_version=True):
    """
    Write ``data`` as JSON to the given ``file_name``, creating parent
    directories if necessary. The augur version is included as a top-level
    key "augur_version".

    Parameters
    ----------
    data : dict
        data to write out to JSON

    file_name : str
        file name to write to

    indent : int or None, optional
        JSON indentation level. Default is `None` if the environment variable
        `AUGUR_MINIFY_JSON` is truthy, else 2.

    include_version : bool, optional
        Include the augur version. Default: `True`.

    Raises
    ------
    OSError
    """
    # In case the parent folder does not exist yet.
    parent_directory = os.path.dirname(file_name)
    if parent_directory and not os.path.exists(parent_directory):
        try:
            os.makedirs(parent_directory)
        except OSError:  # Guard against race condition
            if not os.path.isdir(parent_directory):
                raise

    if include_version:
        data["generated_by"] = {"program": "augur", "version": get_augur_version()}
    with open(file_name, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=indent, sort_keys=True)


def load_features(reference, feature_names=None):
    # Read in appropriately whether GFF or GenBank:
    # checks explicitly for GFF, otherwise assumes GenBank.
    if not os.path.isfile(reference):
        print("ERROR: reference sequence not found. looking for", reference)
        return None

    features = {}
    if '.gff' in reference.lower():
        # Looks for 'gene' and 'locus_tag' as best for TB.
        try:
            from BCBio import GFF  # Package name is confusing - tell user exactly what they need!
        except ImportError:
            print("ERROR: Package BCBio.GFF not found! Please install using 'pip install bcbio-gff' before re-running.")
            return None

        limit_info = dict(gff_type=['gene'])

        with open(reference, encoding='utf-8') as in_handle:
            for rec in GFF.parse(in_handle, limit_info=limit_info):
                for feat in rec.features:
                    if feature_names is not None:
                        # Check both tags; the user may have used either.
                        if "gene" in feat.qualifiers and feat.qualifiers["gene"][0] in feature_names:
                            fname = feat.qualifiers["gene"][0]
                        elif "locus_tag" in feat.qualifiers and feat.qualifiers["locus_tag"][0] in feature_names:
                            fname = feat.qualifiers["locus_tag"][0]
                        else:
                            fname = None
                    else:
                        if "gene" in feat.qualifiers:
                            fname = feat.qualifiers["gene"][0]
                        else:
                            fname = feat.qualifiers["locus_tag"][0]
                    if fname:
                        features[fname] = feat

            if feature_names is not None:
                for fe in feature_names:
                    if fe not in features:
                        print("Couldn't find gene {} in GFF or GenBank file".format(fe))

    else:
        from Bio import SeqIO
        for feat in SeqIO.read(reference, 'genbank').features:
            if feat.type == 'CDS':
                if "locus_tag" in feat.qualifiers:
                    fname = feat.qualifiers["locus_tag"][0]
                    if feature_names is None or fname in feature_names:
                        features[fname] = feat
                elif "gene" in feat.qualifiers:
                    fname = feat.qualifiers["gene"][0]
                    if feature_names is None or fname in feature_names:
                        features[fname] = feat
            elif feat.type == 'source':
                # Read 'nuc' as well for annotations - need start/end of the whole sequence!
                features['nuc'] = feat

    return features
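
# Illustrative sketch (not part of the augur API): typical load_features()
# usage against a GenBank reference. "reference.gb" and the gene names are
# hypothetical placeholders; load_features() returns None if the file is
# missing.
def _example_load_features(reference="reference.gb"):
    features = load_features(reference, feature_names=["HA1", "HA2"])
    if features is None:
        return None
    # Map each found feature (plus 'nuc', the whole-genome source feature)
    # to its zero-indexed start coordinate.
    return {name: int(feat.location.start) for name, feat in features.items()}
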
def read_config(fname):
    if not (fname and os.path.isfile(fname)):
        print("ERROR: config file %s not found." % fname)
        return defaultdict(dict)

    try:
        with open(fname, 'rb') as ifile:
            config = json.load(ifile)
    except json.decoder.JSONDecodeError as err:
        print("FATAL ERROR:")
        print("\tCouldn't parse the JSON file {}".format(fname))
        print("\tError message: '{}'".format(err.msg))
        print("\tLine number: '{}'".format(err.lineno))
        print("\tColumn number: '{}'".format(err.colno))
        print("\tYou must correct this file in order to proceed.")
        sys.exit(2)

    return config


def read_lat_longs(overrides=None, use_defaults=True):
    coordinates = {}
    # TODO: make parsing of TSV files more robust while still allowing
    # whitespace delimiting for backwards compatibility.
    def add_line_to_coordinates(line):
        if line.startswith('#') or line.strip() == "":
            return
        fields = line.strip().split() if '\t' not in line else line.strip().split('\t')
        if len(fields) == 4:
            geo_field, loc = fields[0].lower(), fields[1].lower()
            lat, long = float(fields[2]), float(fields[3])
            coordinates[(geo_field, loc)] = {
                "latitude": lat,
                "longitude": long
            }
        else:
            print("WARNING: geo-coordinate file contains invalid line. Please make sure not to mix tabs and spaces as delimiters (use only tabs):", line)
    if use_defaults:
        with resource_stream(__package__, "data/lat_longs.tsv") as stream:
            with TextIOWrapper(stream, "utf-8") as defaults:
                for line in defaults:
                    add_line_to_coordinates(line)
    if overrides:
        if os.path.isfile(overrides):
            with open(overrides, encoding='utf-8') as ifile:
                for line in ifile:
                    add_line_to_coordinates(line)
        else:
            print("WARNING: input lat/long file %s not found." % overrides)
    return coordinates


def read_colors(overrides=None, use_defaults=True):
    colors = {}
    # TODO: make parsing of TSV files more robust while still allowing
    # whitespace delimiting for backwards compatibility.
    def add_line(line):
        if line.startswith('#'):
            return
        fields = line.strip().split() if '\t' not in line else line.strip().split('\t')
        if not fields:
            return  # blank lines
        if len(fields) != 3:
            print("WARNING: Color map file contains invalid line. Please make sure not to mix tabs and spaces as delimiters (use only tabs):", line)
            return
        trait, trait_value, hex_code = fields[0].lower(), fields[1].lower(), fields[2]
        if not hex_code.startswith("#") or len(hex_code) != 7:
            print("WARNING: Color map file contained this invalid hex code: ", hex_code)
            return

        # If it was already added, delete it entirely so that the ordering can
        # follow the order in the user-specified file (even though dicts
        # shouldn't be relied on to have order).
        if (trait, trait_value) in colors:
            del colors[(trait, trait_value)]

        colors[(trait, trait_value)] = hex_code

    if use_defaults:
        with resource_stream(__package__, "data/colors.tsv") as stream:
            with TextIOWrapper(stream, "utf-8") as defaults:
                for line in defaults:
                    add_line(line)

    if overrides:
        if os.path.isfile(overrides):
            with open(overrides, encoding='utf-8') as fh:
                for line in fh:
                    add_line(line)
        else:
            print("WARNING: Couldn't open color definitions file {}.".format(overrides))

    color_map = defaultdict(list)
    for (trait, trait_value), hex_code in colors.items():
        color_map[trait].append((trait_value, hex_code))

    return color_map
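
# Illustrative sketch (not part of the augur API): the tab-delimited layouts
# expected by read_lat_longs() and read_colors(). File names and the
# "narnia" entry are hypothetical; the bundled defaults stay in effect.
def _example_geo_and_color_overrides():
    with open("_ex_lat_longs.tsv", "w", encoding="utf-8") as fh:
        fh.write("country\tnarnia\t42.0\t-71.0\n")  # type, name, latitude, longitude
    with open("_ex_colors.tsv", "w", encoding="utf-8") as fh:
        fh.write("country\tnarnia\t#4499bb\n")  # type, name, hex code
    coords = read_lat_longs(overrides="_ex_lat_longs.tsv")
    colors = read_colors(overrides="_ex_colors.tsv")
    # Overridden entries are re-appended, so they sort after the defaults.
    return coords[("country", "narnia")], colors["country"][-1]
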
def write_VCF_translation(prot_dict, vcf_file_name, ref_file_name):
    """
    Writes out a VCF-style file (which seems to be minimally handleable by
    vcftools and pyvcf) of the AA differences between sequences and the reference.
    This is a similar format to that created/used by read_in_vcf, except that
    there is one of these dicts (with sequences, reference, positions) for EACH gene.

    Also writes out a fasta of the reference alignment.

    EBH 12 Dec 2017
    """
    import numpy as np

    # For the header
    seqNames = list(prot_dict[list(prot_dict.keys())[0]]['sequences'].keys())

    # Prepare the header of the VCF & write out
    header = ["#CHROM", "POS", "ID", "REF", "ALT", "QUAL", "FILTER", "INFO", "FORMAT"] + seqNames
    with open(vcf_file_name, 'w', encoding='utf-8') as the_file:
        the_file.write("##fileformat=VCFv4.2\n" +
                       "##source=NextStrain_Protein_Translation\n" +
                       "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">\n")
        the_file.write("\t".join(header) + "\n")

    refWrite = []
    vcfWrite = []

    # Go through for every gene/protein
    for fname, prot in prot_dict.items():
        sequences = prot['sequences']
        ref = prot['reference']
        positions = prot['positions']

        # Write out the reference fasta
        refWrite.append(">" + fname)
        refWrite.append(ref)

        # Go through every variable position.
        # There are no deletions here, so it's simpler than for VCF nuc sequences!
        for pi in positions:
            pos = pi + 1  # change numbering to match VCF, not python
            refb = ref[pi]  # reference base at this position

            # try/except is (much) faster than list comprehension!
            pattern = []
            for k, v in sequences.items():
                try:
                    pattern.append(sequences[k][pi])
                except KeyError:
                    pattern.append('.')
            pattern = np.array(pattern)

            # Get the list of ALTs - minus any '.'!
            uniques = np.unique(pattern)
            uniques = uniques[np.where(uniques != '.')]

            # Convert bases to the number that matches the ALT
            j = 1
            for u in uniques:
                pattern[np.where(pattern == u)[0]] = str(j)
                j += 1
            # Now convert these calls to #/# (VCF format)
            calls = [j + "/" + j if j != '.' else '.' for j in pattern]
            if len(uniques) == 0:
                print("UNEXPECTED ERROR WHILE CONVERTING TO VCF AT POSITION {}".format(str(pi)))
                break

            # Put it all together and write it out
            output = [fname, str(pos), ".", refb, ",".join(uniques), ".", "PASS", ".", "GT"] + calls

            vcfWrite.append("\t".join(output))

    # Write it all out
    with open(ref_file_name, 'w', encoding='utf-8') as the_file:
        the_file.write("\n".join(refWrite))
    with open(vcf_file_name, 'a', encoding='utf-8') as the_file:
        the_file.write("\n".join(vcfWrite))

    if vcf_file_name.lower().endswith('.gz'):
        # Must temporarily remove the .gz ending, or gzip won't zip it!
        os.rename(vcf_file_name, vcf_file_name[:-3])
        call = ["gzip", vcf_file_name[:-3]]
        run_shell_command(" ".join(call), raise_errors=True)


shquote = shlex.quote


def run_shell_command(cmd, raise_errors=False, extra_env=None):
    """
    Run the given command string via Bash with error checking.

    Returns True if the command exits normally. Returns False if the command
    exits with failure and "raise_errors" is False (the default). When
    "raise_errors" is True, exceptions are rethrown.

    If an *extra_env* mapping is passed, the provided keys and values are
    overlaid onto the default subprocess environment.
    """
    return ShellCommandRunner(cmd, raise_errors=raise_errors, extra_env=extra_env).run()


def first_line(text):
    """
    Returns the first line of the given text, ignoring leading and trailing
    whitespace.
    """
    return text.strip().splitlines()[0]


def available_cpu_cores(fallback: int = 1) -> int:
    """
    Returns the number (an int) of CPU cores available to this **process**, if
    determinable, otherwise the number of CPU cores available to the
    **computer**, if determinable, otherwise the *fallback* number (which
    defaults to 1).
    """
    try:
        # Note that this is the correct function to use, not os.cpu_count(), as
        # described in the latter's documentation.
        #
        # The reason, which the documentation does not detail, is that
        # processes may be pinned or restricted to certain CPUs by setting
        # their "affinity". This is not typical except in high-performance
        # computing environments, but if it is done, then a computer with say
        # 24 total cores may only allow our process to use 12. If we tried to
        # naively use all 24, we'd end up with two threads competing for each
        # of the 12 cores we can use. This would degrade performance rather
        # than improve it!
        return len(os.sched_getaffinity(0))
    except AttributeError:
        # os.sched_getaffinity() is not available on all platforms (e.g. macOS).
        # cpu_count() returns None if the value is indeterminable.
        return os.cpu_count() or fallback


def nthreads_value(value):
    """
    Argument value validation and casting function for --nthreads.
    """
    if value.lower() == 'auto':
        return available_cpu_cores()

    try:
        return int(value)
    except ValueError:
        raise argparse.ArgumentTypeError("'%s' is not an integer or the word 'auto'" % value) from None


def get_parent_name_by_child_name_for_tree(tree):
    '''
    Return dictionary mapping child node names to parent node names
    '''
    parents = {}
    for clade in tree.find_clades(order='level'):
        for child in clade:
            parents[child.name] = clade.name
    return parents


def annotate_parents_for_tree(tree):
    """Annotate each node in the given tree with its parent.

    >>> import io
    >>> tree = Bio.Phylo.read(io.StringIO("(A, (B, C))"), "newick")
    >>> not any([hasattr(node, "parent") for node in tree.find_clades()])
    True
    >>> tree = annotate_parents_for_tree(tree)
    >>> tree.root.parent is None
    True
    >>> all([hasattr(node, "parent") for node in tree.find_clades()])
    True
    """
    tree.root.parent = None
    for node in tree.find_clades(order="level"):
        for child in node.clades:
            child.parent = node

    # Return the tree.
    return tree
def json_to_tree(json_dict, root=True):
    """Returns a Bio.Phylo tree corresponding to the given JSON dictionary exported
    by `tree_to_json`.

    Assigns links back to parent nodes for the root of the tree.

    Test opening a JSON from augur export v1.

    >>> import json
    >>> json_fh = open("tests/data/json_tree_to_nexus/flu_h3n2_ha_3y_tree.json", "r")
    >>> json_dict = json.load(json_fh)
    >>> tree = json_to_tree(json_dict)
    >>> tree.name
    'NODE_0002020'
    >>> len(tree.clades)
    2
    >>> tree.clades[0].name
    'NODE_0001489'
    >>> hasattr(tree, "attr")
    True
    >>> "dTiter" in tree.attr
    True
    >>> tree.clades[0].parent.name
    'NODE_0002020'
    >>> tree.clades[0].branch_length > 0
    True

    Test opening a JSON from augur export v2.

    >>> json_fh = open("tests/data/zika.json", "r")
    >>> json_dict = json.load(json_fh)
    >>> tree = json_to_tree(json_dict)
    >>> hasattr(tree, "name")
    True
    >>> len(tree.clades) > 0
    True
    >>> tree.clades[0].branch_length > 0
    True
    """
    # Check for v2 JSON which has combined metadata and tree data.
    if root and "meta" in json_dict and "tree" in json_dict:
        json_dict = json_dict["tree"]

    node = Bio.Phylo.Newick.Clade()

    # v1 and v2 JSONs use different keys for strain names.
    if "name" in json_dict:
        node.name = json_dict["name"]
    else:
        node.name = json_dict["strain"]

    if "children" in json_dict:
        # Recursively add children to the current node.
        node.clades = [json_to_tree(child, root=False) for child in json_dict["children"]]

    # Assign all non-children attributes.
    for attr, value in json_dict.items():
        if attr != "children":
            setattr(node, attr, value)

    # Only v1 JSONs support a single `attr` attribute.
    if hasattr(node, "attr"):
        node.numdate = node.attr.get("num_date")
        node.branch_length = node.attr.get("div")

        if "translations" in node.attr:
            node.translations = node.attr["translations"]
    elif hasattr(node, "node_attrs"):
        node.branch_length = node.node_attrs.get("div")

    if root:
        node = annotate_parents_for_tree(node)

    return node


def get_augur_version():
    """
    Returns a string of the current augur version.
    """
    return __version__


def is_augur_version_compatable(version):
    """
    Checks if the provided **version** is the same major version as the
    currently running version of augur.

    Parameters
    ----------
    version : str
        version to check against the current version

    Returns
    -------
    bool
    """
    current_version = packaging_version.parse(get_augur_version())
    this_version = packaging_version.parse(version)
    return this_version.release[0] == current_version.release[0]


def read_bed_file(bed_file):
    """Read a BED file and return a list of excluded sites.

    Note: This function assumes the given file is a BED file. On parsing
    failures, it will attempt to skip the first line and retry, but no other
    error checking is attempted. Incorrectly formatted files will raise errors.

    Parameters
    ----------
    bed_file : str
        Path to the BED file

    Returns
    -------
    list[int]
        Sorted list of unique zero-indexed sites
    """
    mask_sites = []
    try:
        bed = pd.read_csv(bed_file, sep='\t', header=None, usecols=[1, 2],
                          dtype={1: int, 2: int})
    except ValueError:
        # Check if we have a header row. Otherwise, just fail.
        bed = pd.read_csv(bed_file, sep='\t', header=None, usecols=[1, 2],
                          dtype={1: int, 2: int}, skiprows=1)
        print("Skipped row 1 of %s, assuming it is a header." % bed_file)
    for _, row in bed.iterrows():
        mask_sites.extend(range(row[1], row[2]))
    return sorted(set(mask_sites))


def read_mask_file(mask_file):
    """Read a masking file and return a list of excluded sites.

    Masking files have a single masking site per line, either alone or as the
    second column of a tab-separated file. These sites are assumed to be
    one-indexed, NOT zero-indexed. Incorrectly formatted lines are reported
    and raise an error.

    Parameters
    ----------
    mask_file : str
        Path to the masking file

    Returns
    -------
    list[int]
        Sorted list of unique zero-indexed sites
    """
    mask_sites = []
    with open(mask_file, encoding='utf-8') as mf:
        for idx, line in enumerate((l.strip() for l in mf), start=1):
            if "\t" in line:
                line = line.split("\t")[1]
            try:
                mask_sites.append(int(line) - 1)
            except ValueError as err:
                print("Could not read line %s of %s: '%s' - %s" %
                      (idx, mask_file, line, err), file=sys.stderr)
                raise
    return sorted(set(mask_sites))


def load_mask_sites(mask_file):
    """Load masking sites from either a BED file or a masking file.

    Parameters
    ----------
    mask_file : str
        Path to the BED or masking file

    Returns
    -------
    list[int]
        Sorted list of unique zero-indexed sites
    """
    if mask_file.lower().endswith(".bed"):
        mask_sites = read_bed_file(mask_file)
    else:
        mask_sites = read_mask_file(mask_file)
    print("%d masking sites read from %s" % (len(mask_sites), mask_file))
    return mask_sites
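
# Illustrative sketch (not part of the augur API): one-indexed sites in a
# masking file come back zero-indexed from load_mask_sites(). The file name
# is hypothetical.
def _example_load_mask_sites(path="_ex_mask.txt"):
    with open(path, "w", encoding="utf-8") as fh:
        fh.write("100\n101\n250\n")
    return load_mask_sites(path)  # [99, 100, 249]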