# # Copyright (C) 2018-2020 University of Oxford # # This file is part of tsinfer. # # tsinfer is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # tsinfer is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with tsinfer. If not, see <http://www.gnu.org/licenses/>. # """ Central module for high-level inference. The actual implementation of of the core tasks like ancestor generation and matching are delegated to other modules. """ import collections import queue import time import logging import threading import json import heapq import numpy as np import humanize import tskit import _tsinfer import tsinfer.formats as formats import tsinfer.algorithm as algorithm import tsinfer.threads as threads import tsinfer.provenance as provenance import tsinfer.constants as constants logger = logging.getLogger(__name__) def is_pc_ancestor(flags): """ Returns True if the path compression ancestor flag is set on the specified flags value. """ return (flags & constants.NODE_IS_PC_ANCESTOR) != 0 def is_srb_ancestor(flags): """ Returns True if the shared recombination breakpoint flag is set on the specified flags value. """ return (flags & constants.NODE_IS_SRB_ANCESTOR) != 0 def count_pc_ancestors(flags): """ Returns the number of values in the specified array which have the NODE_IS_PC_ANCESTOR set. """ flags = np.array(flags, dtype=np.uint32, copy=False) return np.sum(np.bitwise_and(flags, constants.NODE_IS_PC_ANCESTOR) != 0) def count_srb_ancestors(flags): """ Returns the number of values in the specified array which have the NODE_IS_SRB_ANCESTOR set. """ flags = np.array(flags, dtype=np.uint32, copy=False) return np.sum(np.bitwise_and(flags, constants.NODE_IS_SRB_ANCESTOR) != 0) class DummyProgress(object): """ Class that mimics the subset of the tqdm API that we use in this module. """ def update(self): pass def close(self): pass class DummyProgressMonitor(object): """ Simple class to mimic the interface of the real progress monitor. """ def get(self, key, total): return DummyProgress() def set_detail(self, info): pass def _get_progress_monitor(progress_monitor): if progress_monitor is None: progress_monitor = DummyProgressMonitor() return progress_monitor def verify(samples, tree_sequence, progress_monitor=None): """ verify(samples, tree_sequence) Verifies that the specified sample data and tree sequence files encode the same data. :param SampleData samples: The input :class:`SampleData` instance representing the observed data that we wish to compare to. :param TreeSequence tree_sequence: The input :class:`tskit.TreeSequence` instance an encoding of the specified samples that we wish to verify. """ progress_monitor = _get_progress_monitor(progress_monitor) if samples.num_sites != tree_sequence.num_sites: raise ValueError("numbers of sites not equal") if samples.num_samples != tree_sequence.num_samples: raise ValueError("numbers of samples not equal") if samples.sequence_length != tree_sequence.sequence_length: raise ValueError("Sequence lengths not equal") progress = progress_monitor.get("verify", tree_sequence.num_sites) for var1, var2 in zip(samples.variants(), tree_sequence.variants()): if var1.site.position != var2.site.position: raise ValueError( "site positions not equal: {} != {}".format( var1.site.position, var2.site.position ) ) if var1.alleles != var2.alleles: raise ValueError( "alleles not equal: {} != {}".format(var1.alleles, var2.alleles) ) if not np.array_equal(var1.genotypes, var2.genotypes): raise ValueError("Genotypes not equal at site {}".format(var1.site.id)) progress.update() progress.close() def infer( sample_data, num_threads=0, path_compression=True, simplify=True, recombination_rate=None, mismatch_rate=None, precision=None, engine=constants.C_ENGINE, progress_monitor=None, ): """ infer(sample_data, num_threads=0, path_compression=True, simplify=True) Runs the full :ref:`inference pipeline <sec_inference>` on the specified :class:`SampleData` instance and returns the inferred :class:`tskit.TreeSequence`. :param SampleData sample_data: The input :class:`SampleData` instance representing the observed data that we wish to make inferences from. :param int num_threads: The number of worker threads to use in parallelised sections of the algorithm. If <= 0, do not spawn any threads and use simpler sequential algorithms (default). :param bool path_compression: Whether to merge edges that share identical paths (essentially taking advantage of shared recombination breakpoints). :param bool simplify: Whether to remove extra tree nodes and edges that are not on a path between the root and any of the samples. To do so, the final tree sequence is simplified by appling the :meth:`tskit.TreeSequence.simplify` method with ``keep_unary`` set to True (default = ``True``). :returns: The :class:`tskit.TreeSequence` object inferred from the input sample data. :rtype: tskit.TreeSequence """ ancestor_data = generate_ancestors( sample_data, num_threads=num_threads, engine=engine, progress_monitor=progress_monitor, ) ancestors_ts = match_ancestors( sample_data, ancestor_data, engine=engine, num_threads=num_threads, recombination_rate=recombination_rate, mismatch_rate=mismatch_rate, precision=precision, path_compression=path_compression, progress_monitor=progress_monitor, ) inferred_ts = match_samples( sample_data, ancestors_ts, engine=engine, num_threads=num_threads, recombination_rate=recombination_rate, mismatch_rate=mismatch_rate, precision=precision, path_compression=path_compression, simplify=simplify, progress_monitor=progress_monitor, ) return inferred_ts def generate_ancestors( sample_data, num_threads=0, path=None, engine=constants.C_ENGINE, progress_monitor=None, **kwargs ): """ generate_ancestors(sample_data, num_threads=0, path=None, **kwargs) Runs the ancestor generation :ref:`algorithm <sec_inference_generate_ancestors>` on the specified :class:`SampleData` instance and returns the resulting :class:`AncestorData` instance. If you wish to store generated ancestors persistently on file you must pass the ``path`` keyword argument to this function. For example, .. code-block:: python ancestor_data = tsinfer.generate_ancestors(sample_data, path="mydata.ancestors") Other keyword arguments are passed to the :class:`AncestorData` constructor, which may be used to control the storage properties of the generated file. :param SampleData sample_data: The :class:`SampleData` instance that we are genering putative ancestors from. :param int num_threads: The number of worker threads to use. If < 1, use a simpler synchronous algorithm. :param str path: The path of the file to store the sample data. If None, the information is stored in memory and not persistent. :rtype: AncestorData :returns: The inferred ancestors stored in an :class:`AncestorData` instance. """ sample_data._check_finalised() progress_monitor = _get_progress_monitor(progress_monitor) with formats.AncestorData(sample_data, path=path, **kwargs) as ancestor_data: generator = AncestorsGenerator( sample_data, ancestor_data, num_threads=num_threads, engine=engine, progress_monitor=progress_monitor, ) generator.add_sites() generator.run() ancestor_data.record_provenance("generate-ancestors") return ancestor_data def match_ancestors( sample_data, ancestor_data, num_threads=0, path_compression=True, recombination_rate=None, mismatch_rate=None, precision=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, ): """ match_ancestors(sample_data, ancestor_data, num_threads=0, path_compression=True) Runs the ancestor matching :ref:`algorithm <sec_inference_match_ancestors>` on the specified :class:`SampleData` and :class:`AncestorData` instances, returning the resulting :class:`tskit.TreeSequence` representing the complete ancestry of the putative ancestors. :param SampleData sample_data: The :class:`SampleData` instance representing the input data. :param AncestorData ancestor_data: The :class:`AncestorData` instance representing the set of ancestral haplotypes that we are finding a history for. :param int num_threads: The number of match worker threads to use. If this is <= 0 then a simpler sequential algorithm is used (default). :param bool path_compression: Whether to merge edges that share identical paths (essentially taking advantage of shared recombination breakpoints). :return: The ancestors tree sequence representing the inferred history of the set of ancestors. :rtype: tskit.TreeSequence """ sample_data._check_finalised() ancestor_data._check_finalised() matcher = AncestorMatcher( sample_data, ancestor_data, num_threads=num_threads, recombination_rate=recombination_rate, mismatch_rate=mismatch_rate, precision=precision, path_compression=path_compression, extended_checks=extended_checks, engine=engine, progress_monitor=progress_monitor, ) return matcher.match_ancestors() def augment_ancestors( sample_data, ancestors_ts, indexes, num_threads=0, path_compression=True, recombination_rate=None, mismatch_rate=None, precision=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, ): """ augment_ancestors(sample_data, ancestors_ts, indexes, num_threads=0,\ path_compression=True) Runs the sample matching :ref:`algorithm <sec_inference_match_samples>` on the specified :class:`SampleData` instance and ancestors tree sequence, for the specified subset of sample indexes, returning the :class:`tskit.TreeSequence` instance including these samples. This tree sequence can then be used as an ancestors tree sequence for subsequent matching against all samples. :param SampleData sample_data: The :class:`SampleData` instance representing the input data. :param tskit.TreeSequence ancestors_ts: The :class:`tskit.TreeSequence` instance representing the inferred history among ancestral ancestral haplotypes. :param array indexes: The sample indexes to insert into the ancestors tree sequence. :param int num_threads: The number of match worker threads to use. If this is <= 0 then a simpler sequential algorithm is used (default). :param bool path_compression: Whether to merge edges that share identical paths (essentially taking advantage of shared recombination breakpoints). :return: The specified ancestors tree sequence augmented with copying paths for the specified sample. :rtype: tskit.TreeSequence """ sample_data._check_finalised() manager = SampleMatcher( sample_data, ancestors_ts, num_threads=num_threads, recombination_rate=recombination_rate, mismatch_rate=mismatch_rate, precision=precision, path_compression=path_compression, extended_checks=extended_checks, engine=engine, progress_monitor=progress_monitor, ) indexes = np.array(indexes) if len(indexes) == 0: raise ValueError("Must supply at least one sample to augment") if np.any(indexes < 0) or np.any(indexes >= sample_data.num_samples): raise ValueError("Sample index out of bounds") manager.match_samples(indexes) ts = manager.get_augmented_ancestors_tree_sequence(indexes) return ts def match_samples( sample_data, ancestors_ts, num_threads=0, path_compression=True, simplify=True, recombination_rate=None, mismatch_rate=None, precision=None, extended_checks=False, stabilise_node_ordering=False, engine=constants.C_ENGINE, progress_monitor=None, ): """ match_samples(sample_data, ancestors_ts, num_threads=0, path_compression=True,\ simplify=True) Runs the sample matching :ref:`algorithm <sec_inference_match_samples>` on the specified :class:`SampleData` instance and ancestors tree sequence, returning the final :class:`tskit.TreeSequence` instance containing the full inferred history for all samples and sites. :param SampleData sample_data: The :class:`SampleData` instance representing the input data. :param tskit.TreeSequence ancestors_ts: The :class:`tskit.TreeSequence` instance representing the inferred history among ancestral ancestral haplotypes. :param int num_threads: The number of match worker threads to use. If this is <= 0 then a simpler sequential algorithm is used (default). :param bool path_compression: Whether to merge edges that share identical paths (essentially taking advantage of shared recombination breakpoints). :param bool simplify: Whether to remove extra tree nodes and edges that are not on a path between the root and any of the samples. To do so, the final tree sequence is simplified by appling the :meth:`tskit.TreeSequence.simplify` method with ``keep_unary`` set to True (default = ``True``). :return: The tree sequence representing the inferred history of the sample. :rtype: tskit.TreeSequence """ sample_data._check_finalised() manager = SampleMatcher( sample_data, ancestors_ts, num_threads=num_threads, recombination_rate=recombination_rate, mismatch_rate=mismatch_rate, precision=precision, path_compression=path_compression, extended_checks=extended_checks, engine=engine, progress_monitor=progress_monitor, ) manager.match_samples() ts = manager.finalise( simplify=simplify, stabilise_node_ordering=stabilise_node_ordering ) return ts class AncestorsGenerator(object): """ Manages the process of building ancestors. """ def __init__( self, sample_data, ancestor_data, num_threads=0, engine=constants.C_ENGINE, progress_monitor=None, ): self.sample_data = sample_data self.ancestor_data = ancestor_data self.progress_monitor = progress_monitor self.num_sites = sample_data.num_inference_sites self.num_samples = sample_data.num_samples self.num_threads = num_threads if engine == constants.C_ENGINE: logger.debug("Using C AncestorBuilder implementation") self.ancestor_builder = _tsinfer.AncestorBuilder( self.num_samples, self.num_sites ) elif engine == constants.PY_ENGINE: logger.debug("Using Python AncestorBuilder implementation") self.ancestor_builder = algorithm.AncestorBuilder( self.num_samples, self.num_sites ) else: raise ValueError("Unknown engine:{}".format(engine)) def add_sites(self): """ Add all sites marked for inference in the sample_data object into the ancestor builder. """ logger.info("Starting addition of {} sites".format(self.num_sites)) progress = self.progress_monitor.get("ga_add_sites", self.num_sites) for j, variant in enumerate(self.sample_data.variants(inference_sites=True)): time = variant.site.time if time == constants.TIME_UNSPECIFIED: counts = formats.allele_counts(variant.genotypes) # Non-variable sites have no obvious freq-as-time values assert counts.known != counts.derived assert counts.known != counts.ancestral assert counts.known > 0 # Time = freq of *all* derived alleles. Note that if n_alleles > 2 this # may not be sensible: https://github.com/tskit-dev/tsinfer/issues/228 time = counts.derived / counts.known self.ancestor_builder.add_site(j, time, variant.genotypes) progress.update() progress.close() logger.info("Finished adding sites") def _run_synchronous(self, progress): a = np.zeros(self.num_sites, dtype=np.int8) for t, focal_sites in self.descriptors: before = time.perf_counter() s, e = self.ancestor_builder.make_ancestor(focal_sites, a) duration = time.perf_counter() - before logger.debug( "Made ancestor in {:.2f}s at timepoint {} (epoch {}) " "from {} to {} (len={}) with {} focal sites ({})".format( duration, t, self.timepoint_to_epoch[t], s, e, e - s, focal_sites.shape[0], focal_sites, ) ) self.ancestor_data.add_ancestor( start=s, end=e, time=t, focal_sites=focal_sites, haplotype=a[s:e] ) progress.update() def _run_threaded(self, progress): # This works by pushing the ancestor descriptors onto the build_queue, # which the worker threads pop off and process. We need to add ancestors # in the the ancestor_data object in the correct order, so we maintain # a priority queue (add_queue) which allows us to track the next smallest # index of the generated ancestor. We add build ancestors to this queue # as they are built, and drain it when we can. queue_depth = 8 * self.num_threads # Seems like a reasonable limit build_queue = queue.Queue(queue_depth) add_lock = threading.Lock() next_add_index = 0 add_queue = [] def drain_add_queue(): nonlocal next_add_index num_drained = 0 while len(add_queue) > 0 and add_queue[0][0] == next_add_index: _, t, focal_sites, s, e, haplotype = heapq.heappop(add_queue) self.ancestor_data.add_ancestor( start=s, end=e, time=t, focal_sites=focal_sites, haplotype=haplotype ) progress.update() next_add_index += 1 num_drained += 1 logger.debug("Drained {} ancestors from add queue".format(num_drained)) def build_worker(thread_index): a = np.zeros(self.num_sites, dtype=np.int8) while True: work = build_queue.get() if work is None: break index, t, focal_sites = work start, end = self.ancestor_builder.make_ancestor(focal_sites, a) with add_lock: haplotype = a[start:end].copy() heapq.heappush( add_queue, (index, t, focal_sites, start, end, haplotype) ) drain_add_queue() build_queue.task_done() build_queue.task_done() build_threads = [ threads.queue_consumer_thread( build_worker, build_queue, name="build-worker-{}".format(j), index=j ) for j in range(self.num_threads) ] logger.debug("Started {} build worker threads".format(self.num_threads)) for index, (t, focal_sites) in enumerate(self.descriptors): build_queue.put((index, t, focal_sites)) # Stop the the worker threads. for j in range(self.num_threads): build_queue.put(None) for j in range(self.num_threads): build_threads[j].join() drain_add_queue() def run(self): self.descriptors = self.ancestor_builder.ancestor_descriptors() self.num_ancestors = len(self.descriptors) # Maps epoch numbers to their corresponding ancestor times. self.timepoint_to_epoch = {} for t, _ in reversed(self.descriptors): if t not in self.timepoint_to_epoch: self.timepoint_to_epoch[t] = len(self.timepoint_to_epoch) + 1 if self.num_ancestors > 0: logger.info("Starting build for {} ancestors".format(self.num_ancestors)) progress = self.progress_monitor.get("ga_generate", self.num_ancestors) a = np.zeros(self.num_sites, dtype=np.int8) root_time = max(self.timepoint_to_epoch.keys()) + 1 ultimate_ancestor_time = root_time + 1 # Add the ultimate ancestor. This is an awkward hack really; we don't # ever insert this ancestor. The only reason to add it here is that # it makes sure that the ancestor IDs we have in the ancestor file are # the same as in the ancestor tree sequence. This seems worthwhile. self.ancestor_data.add_ancestor( start=0, end=self.num_sites, time=ultimate_ancestor_time, focal_sites=[], haplotype=a, ) # Hack to ensure we always have a root with zeros at every position. self.ancestor_data.add_ancestor( start=0, end=self.num_sites, time=root_time, focal_sites=np.array([], dtype=np.int32), haplotype=a, ) if self.num_threads <= 0: self._run_synchronous(progress) else: self._run_threaded(progress) progress.close() logger.info("Finished building ancestors") class Matcher(object): def __init__( self, sample_data, inference_site_position, num_threads=1, path_compression=True, recombination_rate=None, mismatch_rate=None, precision=None, extended_checks=False, engine=constants.C_ENGINE, progress_monitor=None, ): self.sample_data = sample_data self.num_threads = num_threads self.path_compression = path_compression self.num_samples = self.sample_data.num_samples self.num_sites = len(inference_site_position) self.progress_monitor = _get_progress_monitor(progress_monitor) self.match_progress = None # Allocated by subclass self.extended_checks = extended_checks # Map of site index to tree sequence position. Bracketing # values of 0 and L are used for simplicity. self.position_map = np.hstack( [inference_site_position, [sample_data.sequence_length]] ) self.position_map[0] = 0 if precision is None: # TODO Is this a good default? Need to investigate the effects. precision = 2 if recombination_rate is None: # TODO is this a good value? Will need to tune recombination_rate = 1e-8 self.recombination_rate = np.zeros(self.num_sites) # FIXME not quite right: we should check the rho[0] = 0 self.recombination_rate[:] = recombination_rate if mismatch_rate is None: # Setting a very small value for now. mismatch_rate = 1e-20 self.mismatch_rate = np.zeros(self.num_sites) self.mismatch_rate[:] = mismatch_rate self.precision = precision if engine == constants.C_ENGINE: logger.debug("Using C matcher implementation") self.tree_sequence_builder_class = _tsinfer.TreeSequenceBuilder self.ancestor_matcher_class = _tsinfer.AncestorMatcher elif engine == constants.PY_ENGINE: logger.debug("Using Python matcher implementation") self.tree_sequence_builder_class = algorithm.TreeSequenceBuilder self.ancestor_matcher_class = algorithm.AncestorMatcher else: raise ValueError("Unknown engine:{}".format(engine)) self.tree_sequence_builder = None all_sites = self.sample_data.sites_position[:] index = np.searchsorted(all_sites, inference_site_position) if not np.all(all_sites[index] == inference_site_position): raise ValueError( "Site positions for inference must be a subset of those in " "the sample data file." ) num_alleles = sample_data.num_alleles()[index] # Allocate 64K nodes and edges initially. This will double as needed and will # quickly be big enough even for very large instances. max_edges = 64 * 1024 max_nodes = 64 * 1024 self.tree_sequence_builder = self.tree_sequence_builder_class( num_alleles=num_alleles, max_nodes=max_nodes, max_edges=max_edges ) logger.debug( "Allocated tree sequence builder with max_nodes={}".format(max_nodes) ) # Allocate the matchers and statistics arrays. num_threads = max(1, self.num_threads) self.match = [ np.full(self.num_sites, tskit.MISSING_DATA, np.int8) for _ in range(num_threads) ] self.results = ResultBuffer() self.mean_traceback_size = np.zeros(num_threads) self.num_matches = np.zeros(num_threads) self.matcher = [ self.ancestor_matcher_class( self.tree_sequence_builder, recombination_rate=self.recombination_rate, mismatch_rate=self.mismatch_rate, precision=precision, extended_checks=self.extended_checks, ) for _ in range(num_threads) ] def encode_metadata(self, value): return json.dumps(value).encode() def _find_path(self, child_id, haplotype, start, end, thread_index=0): """ Finds the path of the specified haplotype and upates the results for the specified thread_index. """ matcher = self.matcher[thread_index] match = self.match[thread_index] missing = haplotype == tskit.MISSING_DATA left, right, parent = matcher.find_path(haplotype, start, end, match) self.results.set_path(child_id, left, right, parent) match[missing] = tskit.MISSING_DATA diffs = start + np.where(haplotype[start:end] != match[start:end])[0] derived_state = haplotype[diffs] self.results.set_mutations(child_id, diffs.astype(np.int32), derived_state) self.match_progress.update() self.mean_traceback_size[thread_index] += matcher.mean_traceback_size self.num_matches[thread_index] += 1 logger.debug( "matched node {}; num_edges={} tb_size={:.2f} match_mem={}".format( child_id, left.shape[0], matcher.mean_traceback_size, humanize.naturalsize(matcher.total_memory, binary=True), ) ) def convert_inference_mutations(self, tables): """ Convert the mutations stored in the tree sequence builder into the output format. """ mut_site, node, derived_state, _ = self.tree_sequence_builder.dump_mutations() site_id = 0 mutation_id = 0 num_mutations = len(mut_site) for site in self.sample_data.sites(): if site.inference: tables.sites.add_row( site.position, ancestral_state=site.alleles[0], metadata=self.encode_metadata(site.metadata), ) while mutation_id < num_mutations and mut_site[mutation_id] == site_id: tables.mutations.add_row( site_id, node=node[mutation_id], derived_state=site.alleles[derived_state[mutation_id]], ) mutation_id += 1 site_id += 1 class AncestorMatcher(Matcher): def __init__(self, sample_data, ancestor_data, **kwargs): super().__init__(sample_data, ancestor_data.sites_position[:], **kwargs) self.ancestor_data = ancestor_data self.num_ancestors = self.ancestor_data.num_ancestors self.epoch = self.ancestor_data.ancestors_time[:] # Add nodes for all the ancestors so that the ancestor IDs are equal # to the node IDs. for ancestor_id in range(self.num_ancestors): self.tree_sequence_builder.add_node(self.epoch[ancestor_id]) self.ancestors = self.ancestor_data.ancestors() # Consume the first ancestor. a = next(self.ancestors, None) self.num_epochs = 0 if a is not None: # assert np.array_equal(a.haplotype, np.zeros(self.num_sites, dtype=np.int8)) # Create a list of all ID ranges in each epoch. breaks = np.where(self.epoch[1:] != self.epoch[:-1])[0] start = np.hstack([[0], breaks + 1]) end = np.hstack([breaks + 1, [self.num_ancestors]]) self.epoch_slices = np.vstack([start, end]).T self.num_epochs = self.epoch_slices.shape[0] self.start_epoch = 1 def __epoch_info_dict(self, epoch_index): start, end = self.epoch_slices[epoch_index] return collections.OrderedDict( [("epoch", str(self.epoch[start])), ("nanc", str(end - start))] ) def __ancestor_find_path(self, ancestor, thread_index=0): # NOTE we're no longer using the ancestor's focal sites as a way # of knowing where mutations happen but instead having a non-zero # mutation rate and letting the mismatches do the work. We might # want to have a version with a zero mutation rate. haplotype = np.full(self.num_sites, tskit.MISSING_DATA, dtype=np.int8) start = ancestor.start end = ancestor.end assert ancestor.haplotype.shape[0] == (end - start) haplotype[start:end] = ancestor.haplotype self._find_path(ancestor.id, haplotype, start, end, thread_index) def __start_epoch(self, epoch_index): start, end = self.epoch_slices[epoch_index] info = collections.OrderedDict( [("epoch", str(self.epoch[start])), ("nanc", str(end - start))] ) self.progress_monitor.set_detail(info) self.tree_sequence_builder.freeze_indexes() def __complete_epoch(self, epoch_index): start, end = map(int, self.epoch_slices[epoch_index]) num_ancestors_in_epoch = end - start current_time = self.epoch[start] nodes_before = self.tree_sequence_builder.num_nodes for child_id in range(start, end): left, right, parent = self.results.get_path(child_id) self.tree_sequence_builder.add_path( child_id, left, right, parent, compress=self.path_compression, extended_checks=self.extended_checks, ) site, derived_state = self.results.get_mutations(child_id) self.tree_sequence_builder.add_mutations(child_id, site, derived_state) extra_nodes = self.tree_sequence_builder.num_nodes - nodes_before mean_memory = np.mean([matcher.total_memory for matcher in self.matcher]) logger.debug( "Finished epoch {} with {} ancestors; {} extra nodes inserted; " "mean_tb_size={:.2f} edges={}; mean_matcher_mem={}".format( current_time, num_ancestors_in_epoch, extra_nodes, np.sum(self.mean_traceback_size) / np.sum(self.num_matches), self.tree_sequence_builder.num_edges, humanize.naturalsize(mean_memory, binary=True), ) ) self.mean_traceback_size[:] = 0 self.num_matches[:] = 0 self.results.clear() def __match_ancestors_single_threaded(self): for j in range(self.start_epoch, self.num_epochs): self.__start_epoch(j) start, end = map(int, self.epoch_slices[j]) for ancestor_id in range(start, end): a = next(self.ancestors) assert ancestor_id == a.id self.__ancestor_find_path(a) self.__complete_epoch(j) def __match_ancestors_multi_threaded(self, start_epoch=1): # See note on match samples multithreaded below. Should combine these # into a single function. Possibly when trying to make the thread # error handling more robust. queue_depth = 8 * self.num_threads # Seems like a reasonable limit match_queue = queue.Queue(queue_depth) def match_worker(thread_index): while True: work = match_queue.get() if work is None: break self.__ancestor_find_path(work, thread_index) match_queue.task_done() match_queue.task_done() match_threads = [ threads.queue_consumer_thread( match_worker, match_queue, name="match-worker-{}".format(j), index=j ) for j in range(self.num_threads) ] logger.debug("Started {} match worker threads".format(self.num_threads)) for j in range(self.start_epoch, self.num_epochs): self.__start_epoch(j) start, end = map(int, self.epoch_slices[j]) for ancestor_id in range(start, end): a = next(self.ancestors) assert a.id == ancestor_id match_queue.put(a) # Block until all matches have completed. match_queue.join() self.__complete_epoch(j) # Stop the the worker threads. for j in range(self.num_threads): match_queue.put(None) for j in range(self.num_threads): match_threads[j].join() def match_ancestors(self): logger.info("Starting ancestor matching for {} epochs".format(self.num_epochs)) self.match_progress = self.progress_monitor.get("ma_match", self.num_ancestors) if self.num_threads <= 0: self.__match_ancestors_single_threaded() else: self.__match_ancestors_multi_threaded() ts = self.store_output() self.match_progress.close() logger.info("Finished ancestor matching") return ts def get_ancestors_tree_sequence(self): """ Return the ancestors tree sequence. Only inference sites are included in this tree sequence. All nodes have the sample flag bit set. """ logger.debug("Building ancestors tree sequence") tsb = self.tree_sequence_builder tables = tskit.TableCollection( sequence_length=self.ancestor_data.sequence_length ) flags, times = tsb.dump_nodes() num_pc_ancestors = count_pc_ancestors(flags) # TODO Write out the metadata here etc also tables.nodes.set_columns(flags=flags, time=times) left, right, parent, child = tsb.dump_edges() tables.edges.set_columns( left=self.position_map[left], right=self.position_map[right], parent=parent, child=child, ) self.convert_inference_mutations(tables) logger.debug("Sorting ancestors tree sequence") tables.sort() # Note: it's probably possible to compute the mutation parents from the # tsb data structures but we're not doing it for now. tables.build_index() tables.compute_mutation_parents() logger.debug("Sorting ancestors tree sequence done") for timestamp, record in self.ancestor_data.provenances(): tables.provenances.add_row(timestamp=timestamp, record=json.dumps(record)) record = provenance.get_provenance_dict( command="match_ancestors", source={"uuid": self.ancestor_data.uuid} ) tables.provenances.add_row(record=json.dumps(record)) logger.info( "Built ancestors tree sequence: {} nodes ({} pc ancestors); {} edges; " "{} sites; {} mutations".format( len(tables.nodes), num_pc_ancestors, len(tables.edges), len(tables.mutations), len(tables.sites), ) ) return tables.tree_sequence() def store_output(self): if self.num_ancestors > 0: ts = self.get_ancestors_tree_sequence() else: # Allocate an empty tree sequence. tables = tskit.TableCollection( sequence_length=self.ancestor_data.sequence_length ) ts = tables.tree_sequence() return ts class SampleMatcher(Matcher): def __init__(self, sample_data, ancestors_ts, **kwargs): self.ancestors_ts_tables = ancestors_ts.dump_tables() super().__init__(sample_data, self.ancestors_ts_tables.sites.position, **kwargs) self.restore_tree_sequence_builder() self.sample_ids = np.zeros(self.num_samples, dtype=np.int32) def restore_tree_sequence_builder(self): tables = self.ancestors_ts_tables if self.sample_data.sequence_length != tables.sequence_length: raise ValueError( "Ancestors tree sequence not compatible: sequence length is different to " "sample data file." ) if np.any(tables.nodes.time <= 0): raise ValueError("All nodes must have time > 0") edges = tables.edges # Get the indexes into the position array. left = np.searchsorted(self.position_map, edges.left) if np.any(self.position_map[left] != edges.left): raise ValueError("Invalid left coordinates") right = np.searchsorted(self.position_map, edges.right) if np.any(self.position_map[right] != edges.right): raise ValueError("Invalid right coordinates") # Need to sort by child ID here and left so that we can efficiently # insert the child paths. index = np.lexsort((left, edges.child)) nodes = tables.nodes self.tree_sequence_builder.restore_nodes(nodes.time, nodes.flags) self.tree_sequence_builder.restore_edges( left[index].astype(np.int32), right[index].astype(np.int32), edges.parent[index], edges.child[index], ) mutations = tables.mutations derived_state = np.zeros(len(mutations), dtype=np.int8) mutation_site = mutations.site site_id = 0 mutation_id = 0 for site in self.sample_data.sites(): if site.inference: while ( mutation_id < len(mutations) and mutation_site[mutation_id] == site_id ): allele = mutations[mutation_id].derived_state derived_state[mutation_id] = site.alleles.index(allele) mutation_id += 1 site_id += 1 self.tree_sequence_builder.restore_mutations( mutation_site, mutations.node, derived_state, mutations.parent ) logger.info( "Loaded {} samples {} nodes; {} edges; {} sites; {} mutations".format( self.num_samples, len(nodes), len(edges), self.num_sites, len(mutations), ) ) def __process_sample(self, sample_id, haplotype, thread_index=0): self._find_path(sample_id, haplotype, 0, self.num_sites, thread_index) def __match_samples_single_threaded(self, indexes): sample_haplotypes = self.sample_data.haplotypes(indexes, inference_sites=True) for j, a in sample_haplotypes: self.__process_sample(self.sample_ids[j], a) def __match_samples_multi_threaded(self, indexes): # Note that this function is not almost identical to the match_ancestors # multithreaded function above. All we need to do is provide a function # to do the matching and some producer for the actual items and we # can bring this into a single function. queue_depth = 8 * self.num_threads # Seems like a reasonable limit match_queue = queue.Queue(queue_depth) def match_worker(thread_index): while True: work = match_queue.get() if work is None: break sample_id, a = work self.__process_sample(sample_id, a, thread_index) match_queue.task_done() match_queue.task_done() match_threads = [ threads.queue_consumer_thread( match_worker, match_queue, name="match-worker-{}".format(j), index=j ) for j in range(self.num_threads) ] logger.debug("Started {} match worker threads".format(self.num_threads)) sample_haplotypes = self.sample_data.haplotypes(indexes, inference_sites=True) for j, a in sample_haplotypes: match_queue.put((self.sample_ids[j], a)) # Stop the the worker threads. for j in range(self.num_threads): match_queue.put(None) for j in range(self.num_threads): match_threads[j].join() def match_samples(self, indexes=None): if indexes is None: indexes = np.arange(self.num_samples) # Add in sample nodes. for j in indexes: self.sample_ids[j] = self.tree_sequence_builder.add_node(0) logger.info("Started matching for {} samples".format(len(indexes))) if self.num_sites > 0: self.match_progress = self.progress_monitor.get("ms_match", len(indexes)) if self.num_threads <= 0: self.__match_samples_single_threaded(indexes) else: self.__match_samples_multi_threaded(indexes) self.match_progress.close() logger.info( "Inserting sample paths: {} edges in total".format( self.results.total_edges ) ) progress_monitor = self.progress_monitor.get("ms_paths", len(indexes)) for j in indexes: sample_id = int(self.sample_ids[j]) left, right, parent = self.results.get_path(sample_id) self.tree_sequence_builder.add_path( sample_id, left, right, parent, compress=self.path_compression ) site, derived_state = self.results.get_mutations(sample_id) self.tree_sequence_builder.add_mutations(sample_id, site, derived_state) progress_monitor.update() progress_monitor.close() def finalise(self, simplify, stabilise_node_ordering): logger.info("Finalising tree sequence") ts = self.get_samples_tree_sequence() if simplify: logger.info( "Running simplify(keep_unary=True) on {} nodes and {} edges".format( ts.num_nodes, ts.num_edges ) ) if stabilise_node_ordering: # Ensure all the node times are distinct so that they will have # stable IDs after simplifying. This could possibly also be done # by reversing the IDs within a time slice. This is used for comparing # tree sequences produced by perfect inference. tables = ts.dump_tables() times = tables.nodes.time for t in range(1, int(times[0])): index = np.where(times == t)[0] k = index.shape[0] times[index] += np.arange(k)[::-1] / k tables.nodes.set_columns(flags=tables.nodes.flags, time=times) tables.sort() ts = tables.tree_sequence() ts = ts.simplify( samples=self.sample_ids, filter_sites=False, keep_unary=True ) logger.info( "Finished simplify; now have {} nodes and {} edges".format( ts.num_nodes, ts.num_edges ) ) return ts def insert_sites(self, tables): """ Insert the sites in the sample data that were not marked for inference, updating the specified site and mutation tables. This is done by iterating over the trees """ # NOTE: This is all quite confusing and can hopefully be cleaned up. num_sites = self.sample_data.num_sites num_non_inference_sites = num_sites - self.num_sites progress_monitor = self.progress_monitor.get("ms_sites", num_sites) site_id, node, derived_state, _ = self.tree_sequence_builder.dump_mutations() ts = tables.tree_sequence() if num_non_inference_sites > 0: assert ts.num_edges > 0 logger.info( "Starting mutation positioning for {} non inference sites".format( num_non_inference_sites ) ) inferred_mutation = 0 inferred_site = 0 trees = ts.trees() tree = next(trees) for variant in self.sample_data.variants(): site = variant.site predefined_anc_state = site.ancestral_state while tree.interval[1] <= site.position: tree = next(trees) assert tree.interval[0] <= site.position < tree.interval[1] tables.sites.add_row( position=site.position, ancestral_state=predefined_anc_state, metadata=self.encode_metadata(site.metadata), ) if site.inference == 1: while ( inferred_mutation < len(site_id) and site_id[inferred_mutation] == inferred_site ): tables.mutations.add_row( site=site.id, node=node[inferred_mutation], derived_state=variant.alleles[ derived_state[inferred_mutation] ], ) inferred_mutation += 1 inferred_site += 1 else: if np.all(variant.genotypes == tskit.MISSING_DATA): # Map_mutations has to have at least 1 non-missing value to work inferred_anc_state = predefined_anc_state mapped_mutations = [] else: inferred_anc_state, mapped_mutations = tree.map_mutations( variant.genotypes, variant.alleles ) if inferred_anc_state != predefined_anc_state: # The user specified a specific ancestral state. However, the # map_mutations method has reconstructed a different state at the # root, so insert an extra mutation over each root to allow the # ancestral state to be as the user specified for root_node in tree.roots: tables.mutations.add_row( site=site.id, node=root_node, derived_state=inferred_anc_state, ) for mutation in mapped_mutations: tables.mutations.add_row( site=site.id, node=mutation.node, derived_state=mutation.derived_state, ) progress_monitor.update() else: # Simple case where all sites are inference sites. We save a lot of time here # by not decoding the genotypes. logger.info("Inserting detailed site information") position = self.sample_data.sites_position[:] alleles = self.sample_data.sites_alleles[:] metadata = self.sample_data.sites_metadata[:] k = 0 for j in range(self.num_sites): tables.sites.add_row( position=position[j], ancestral_state=alleles[j][0], metadata=self.encode_metadata(metadata[j]), ) while k < len(site_id) and site_id[k] == j: tables.mutations.add_row( site=j, node=node[k], derived_state=alleles[j][derived_state[k]] ) k += 1 progress_monitor.update() progress_monitor.close() def get_samples_tree_sequence(self): """ Returns the current state of the build tree sequence. All samples and ancestors will have the sample node flag set. For correct sample reconstruction, the non-inference sites also need to be placed into the resulting tree sequence. """ tsb = self.tree_sequence_builder tables = self.ancestors_ts_tables.copy() num_ancestral_individuals = len(tables.individuals) # Currently there's no information about populations etc stored in the # ancestors ts. for metadata in self.sample_data.populations_metadata[:]: tables.populations.add_row(self.encode_metadata(metadata)) for ind in self.sample_data.individuals(): tables.individuals.add_row( location=ind.location, metadata=self.encode_metadata(ind.metadata) ) logger.debug("Adding tree sequence nodes") flags, times = tsb.dump_nodes() num_pc_ancestors = count_pc_ancestors(flags) # All true ancestors are samples in the ancestors tree sequence. We unset # the SAMPLE flag but keep other flags intact. new_flags = np.bitwise_and(tables.nodes.flags, ~tskit.NODE_IS_SAMPLE) tables.nodes.set_columns( flags=new_flags.astype(np.uint32), time=tables.nodes.time, population=tables.nodes.population, individual=tables.nodes.individual, metadata=tables.nodes.metadata, metadata_offset=tables.nodes.metadata_offset, ) assert len(tables.nodes) == self.sample_ids[0] # Now add in the sample nodes with metadata, etc. for sample_id, metadata, population, individual in zip( self.sample_ids, self.sample_data.samples_metadata[:], self.sample_data.samples_population[:], self.sample_data.samples_individual[:], ): tables.nodes.add_row( flags=flags[sample_id], time=times[sample_id], population=population, individual=num_ancestral_individuals + individual, metadata=self.encode_metadata(metadata), ) # Add in the remaining non-sample nodes. for u in range(self.sample_ids[-1] + 1, tsb.num_nodes): tables.nodes.add_row(flags=flags[u], time=times[u]) logger.debug("Adding tree sequence edges") tables.edges.clear() left, right, parent, child = tsb.dump_edges() if self.num_sites == 0: # We have no inference sites, so no edges have been estimated. To ensure # we have a rooted tree, we add in edges for each sample to an artificial # root. assert left.shape[0] == 0 root = tables.nodes.add_row(flags=0, time=tables.nodes.time.max() + 1) for sample_id in self.sample_ids: tables.edges.add_row(0, tables.sequence_length, root, sample_id) else: tables.edges.set_columns( left=self.position_map[left], right=self.position_map[right], parent=parent, child=child, ) logger.debug("Sorting and building intermediate tree sequence.") tables.sites.clear() tables.mutations.clear() tables.sort() self.insert_sites(tables) # FIXME this is a shortcut. We should be computing the mutation parent above # during insertion (probably) tables.build_index() tables.compute_mutation_parents() # We don't have a source here because tree sequence files don't have a # UUID yet. record = provenance.get_provenance_dict(command="match-samples") tables.provenances.add_row(record=json.dumps(record)) logger.info( "Built samples tree sequence: {} nodes ({} pc); {} edges; " "{} sites; {} mutations".format( len(tables.nodes), num_pc_ancestors, len(tables.edges), len(tables.sites), len(tables.mutations), ) ) return tables.tree_sequence() def get_augmented_ancestors_tree_sequence(self, sample_indexes): """ Return the ancestors tree sequence augmented with samples as extra ancestors. """ logger.debug("Building augmented ancestors tree sequence") tsb = self.tree_sequence_builder tables = self.ancestors_ts_tables.copy() num_pc_ancestors = count_pc_ancestors(tables.nodes.flags) flags, times = tsb.dump_nodes() s = 0 for j in range(len(tables.nodes), len(flags)): if times[j] == 0.0: # This is an augmented ancestor node. tables.nodes.add_row( flags=constants.NODE_IS_SAMPLE_ANCESTOR, time=times[j], metadata=self.encode_metadata({"sample": int(sample_indexes[s])}), ) s += 1 else: tables.nodes.add_row(flags=flags[j], time=times[j]) assert s == len(sample_indexes) assert len(tables.nodes) == len(flags) # Increment the time for all nodes so the augmented samples are no longer # at timepoint 0. tables.nodes.set_columns( flags=tables.nodes.flags, time=tables.nodes.time + 1, population=tables.nodes.population, individual=tables.nodes.individual, metadata=tables.nodes.metadata, metadata_offset=tables.nodes.metadata_offset, ) num_pc_ancestors = count_pc_ancestors(tables.nodes.flags) - num_pc_ancestors # TODO - check this works for augmented ancestors with missing data left, right, parent, child = tsb.dump_edges() tables.edges.set_columns( left=self.position_map[left], right=self.position_map[right], parent=parent, child=child, ) tables.sites.clear() tables.mutations.clear() self.convert_inference_mutations(tables) record = provenance.get_provenance_dict(command="augment_ancestors") tables.provenances.add_row(record=json.dumps(record)) logger.debug("Sorting ancestors tree sequence") tables.sort() logger.debug("Sorting ancestors tree sequence done") logger.info( "Augmented ancestors tree sequence: {} nodes ({} extra pc ancestors); " "{} edges; {} sites; {} mutations".format( len(tables.nodes), num_pc_ancestors, len(tables.edges), len(tables.mutations), len(tables.sites), ) ) return tables.tree_sequence() class ResultBuffer(object): """ A wrapper for numpy arrays representing the results of a copying operations. """ def __init__(self): self.paths = {} self.mutations = {} self.lock = threading.Lock() self.total_edges = 0 def clear(self): """ Clears this result buffer. """ self.paths.clear() self.mutations.clear() self.total_edges = 0 def set_path(self, node_id, left, right, parent): with self.lock: assert node_id not in self.paths self.paths[node_id] = left, right, parent self.total_edges += len(left) def set_mutations(self, node_id, site, derived_state=None): if derived_state is None: derived_state = np.ones(site.shape[0], dtype=np.int8) with self.lock: self.mutations[node_id] = site, derived_state def get_path(self, node_id): return self.paths[node_id] def get_mutations(self, node_id): return self.mutations[node_id] def minimise(ts): """ Returns a tree sequence with the minimal information required to represent the tree topologies at its sites. This is a convenience function used when we wish to use a subset of the sites in a tree sequence for ancestor matching. It is a thin-wrapper over the simplify method. """ return ts.simplify(reduce_to_site_topology=True, filter_sites=False)