# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Routines for manipulating numpy arrays of segmentation data."""

from collections import Counter

import numpy as np
import scipy.sparse
import skimage.measure

# Monkey patch fix for indexing overflow problems with 64 bit IDs.
# See also:
#   http://scipy-user.10969.n7.nabble.com/SciPy-User-strange-error-when-creating-csr-matrix-td20129.html
#   https://github.com/scipy/scipy/pull/4678
if scipy.__version__ in ('0.14.0', '0.14.1', '0.15.1'):

    def _get_index_dtype(*unused_args, **unused_kwargs):
        return np.int64

    scipy.sparse.compressed.get_index_dtype = _get_index_dtype
    scipy.sparse.csr.get_index_dtype = _get_index_dtype
    scipy.sparse.csc.get_index_dtype = _get_index_dtype
    scipy.sparse.bsr.get_index_dtype = _get_index_dtype


def make_labels_contiguous(labels):
    """Relabels 'labels' so that its ID space is dense.

    If N is the number of unique ids in 'labels', the new IDs will cover the
    range [0..N-1]. New IDs are assigned in increasing order of the original
    IDs.

    Args:
      labels: ndarray of segment IDs

    Returns:
      tuple of:
        ndarray of dense segment IDs, same shape as `labels`
        list of (old_id, new_id) pairs
    """
    # The inverse indices returned by np.unique are, by construction, the
    # position of every element within the sorted unique IDs, i.e. exactly the
    # dense relabeling. This works for arbitrarily large IDs without building
    # a lookup table sized by the largest ID.
    orig_ids, inverse = np.unique(labels, return_inverse=True)
    new_ids = np.arange(len(orig_ids))

    # The shape of `inverse` for N-D input differs between NumPy versions, so
    # always reshape explicitly.
    relabeled = inverse.reshape(labels.shape).astype(new_ids.dtype, copy=False)

    # Materialize the pairs: a bare zip() is a one-shot iterator on Python 3.
    return relabeled, list(zip(orig_ids, new_ids))


def clear_dust(data, min_size=10):
    """Removes small objects from a segmentation array.

    Replaces objects smaller than `min_size` with 0 (background). Size is the
    total number of voxels carrying a given ID, regardless of connectivity.

    Args:
      data: numpy array of segment IDs
      min_size: minimum size in voxels of an object to be retained

    Returns:
      the data array (modified in place)
    """
    ids, sizes = np.unique(data, return_counts=True)
    small = ids[sizes < min_size]
    # np.isin keeps the shape of `data`, so no flatten/reshape round trip is
    # needed (and np.in1d is deprecated in recent NumPy releases).
    data[np.isin(data, small)] = 0
    return data


def reduce_id_bits(segmentation):
    """Reduces the number of bits used for IDs.

    Assumes that one additional ID beyond the max of 'segmentation' is
    necessary (used by GALA to mark boundary areas), so a type is only chosen
    if the maximum ID is strictly smaller than the largest value the type can
    represent.

    Args:
      segmentation: ndarray of int type

    Returns:
      segmentation ndarray converted to minimal uint type large enough to keep
      all the IDs. Falls back to uint64 when no smaller type is sufficient.
    """
    if segmentation.size == 0:
        # There is no maximum to inspect; the smallest type trivially fits.
        return segmentation.astype(np.uint8)

    max_id = segmentation.max()
    for dtype in (np.uint8, np.uint16, np.uint32):
        # Strict comparison leaves room for the one extra ID (max_id + 1).
        if max_id < np.iinfo(dtype).max:
            return segmentation.astype(dtype)

    # Previously this case fell off the end of the function and returned None.
    return segmentation.astype(np.uint64)


def split_disconnected_components(labels):
    """Relabels the connected components of a 3-D integer array.

    Connected components are determined based on 6-connectivity, where two
    neighboring positions are considered part of the same component if they
    have identical labels.

    The label 0 is treated specially: all positions labeled 0 in the input are
    labeled 0 in the output, regardless of whether they are contiguous.

    Connected components of the input array (other than segment id 0) are
    given consecutive ids in the output, starting from 1.

    Args:
      labels: 3-D integer numpy array.

    Returns:
      The relabeled numpy array, same dtype as `labels`.
    """
    has_zero = 0 in labels
    fixed_labels = skimage.measure.label(labels, connectivity=1, background=0)

    if has_zero or 0 in fixed_labels:
        # If the zero positions of the labeling do not coincide with the zero
        # positions of the input, shift all component ids up by one so that 0
        # is free again, then restore the input's background.
        if np.any((fixed_labels == 0) != (labels == 0)):
            fixed_labels[...] += 1
            fixed_labels[labels == 0] = 0

    # np.cast was removed in NumPy 2.0; astype is the direct equivalent.
    return fixed_labels.astype(labels.dtype)


def clean_up(seg, split_cc=True, min_size=0, return_id_map=False):  # pylint: disable=invalid-name
    """Runs connected components and removes small objects.

    Args:
      seg: segmentation to clean as a uint64 ndarray
      split_cc: whether to recompute connected components
      min_size: connected components smaller than this value get removed from
        the segmentation; if 0, no filtering by size is done
      return_id_map: whether to compute and return a map from new IDs to
        original IDs

    Returns:
      None if not return_id_map, otherwise a dictionary mapping new IDs to
      original IDs. `seg` is modified in place.
    """
    if return_id_map:
        # Keep the input around so that new IDs can be traced back.
        seg_orig = seg.copy()

    if split_cc:
        seg[...] = split_disconnected_components(seg)
    if min_size > 0:
        clear_dust(seg, min_size)

    if not return_id_map:
        return None

    # For every ID present in the output, look up the original ID at the
    # position of that ID's first occurrence.
    cc_ids, cc_idx = np.unique(seg.ravel(), return_index=True)
    orig_ids = seg_orig.ravel()[cc_idx]
    return dict(zip(cc_ids, orig_ids))


def split_segmentation_by_intersection(a, b, min_size):
    """Computes the intersection of two segmentations.

    Intersects two spatially overlapping segmentations and assigns a new ID to
    every unique (id1, id2) pair of overlapping voxels. If 'id2' is the largest
    object overlapping 'id1', their intersection retains the 'id1' label. If
    the fragment created by intersection is smaller than 'min_size', it gets
    removed from the segmentation (assigned an id of 0 in the output).

    `a` is modified in place, `b` is not changed.

    Note that (id1, 0) is considered a valid pair and will be mapped to a
    non-zero ID as long as the size of the overlapping region is >= min_size,
    but (0, id2) will always be mapped to 0 in the output.

    Newly created IDs are allocated consecutively, starting right after the
    largest ID present in `a`.

    Args:
      a: First segmentation.
      b: Second segmentation.
      min_size: Minimum size intersection segment to keep (not map to 0).

    Raises:
      TypeError: if a or b don't have a dtype of uint64
      ValueError: if a.shape != b.shape, or if `a` or `b` contain more than
        2**32-1 unique labels.
    """
    if a.shape != b.shape:
        raise ValueError('Segmentations must have the same shape: %r != %r' %
                         (a.shape, b.shape))
    for name, array in (('a', a), ('b', b)):
        if array.dtype != np.uint64:
            raise TypeError('`%s` must have dtype uint64, got %s' %
                            (name, array.dtype))
    if a.size == 0:
        # Nothing to intersect; avoids calling max() on an empty array.
        return

    def remap_input(x):
        """Remaps `x` if needed to fit within a 32-bit ID space.

        Args:
          x: 1-D uint64 numpy array.

        Returns:
          `remapped, max_id, orig_values_map`, where:
            `remapped` contains the remapped version of `x` containing only
              values < 2**32.
            `max_id = x.max()`, as a Python int.
            `orig_values_map` is None if `remapped == x`, or otherwise an
              array such that `x = orig_values_map[remapped]`.

        Raises:
          ValueError: if `x` has more than 2**32-1 unique values.
        """
        max_uint32 = 2**32 - 1
        # A Python int keeps later `max_id += 1` arithmetic exact; a np.uint64
        # scalar plus a Python int is promoted to float64 by older NumPy.
        max_id = int(x.max())
        orig_values_map = None
        if max_id > max_uint32:
            orig_values_map, x = np.unique(x, return_inverse=True)
            if len(orig_values_map) > max_uint32:
                raise ValueError(
                    'More than 2**32-1 unique labels not supported')
            x = x.astype(np.uint64)
            if orig_values_map[0] != 0:
                # Reserve remapped ID 0 for background even though the input
                # does not contain it.
                orig_values_map = np.concatenate(
                    [np.array([0], dtype=np.uint64), orig_values_map])
                x += np.uint64(1)
        return x, max_id, orig_values_map

    remapped_a, max_id, a_reverse_map = remap_input(a.ravel())
    remapped_b, _, _ = remap_input(b.ravel())

    # Pack every (id_a, id_b) pair into one 64-bit value: id_a in the low 32
    # bits, id_b in the high 32 bits.
    joint_ids = np.bitwise_or(remapped_a, remapped_b << np.uint64(32))
    unique_joint, joint_inverse, joint_counts = np.unique(
        joint_ids, return_inverse=True, return_counts=True)

    # Plain Python ints make the loops below cheaper and keep all ID
    # arithmetic exact.
    labels_a = np.bitwise_and(unique_joint, np.uint64(0xFFFFFFFF)).tolist()
    labels_b = (unique_joint >> np.uint64(32)).tolist()
    counts = joint_counts.tolist()

    # Maps each segment id `id_a` in `remapped_a` to `(id_b, joint_count)`
    # where `id_b` is the segment id in `remapped_b` with maximum overlap, and
    # `joint_count` is the number of voxels of overlap. Ties are resolved in
    # favor of the pair seen first.
    max_overlap_ids = {}
    for label_a, label_b, count in zip(labels_a, labels_b, counts):
        best = max_overlap_ids.get(label_a)
        if best is None or best[1] < count:
            max_overlap_ids[label_a] = (label_b, count)

    # Relabel map to apply to joint_inverse to obtain the output ids. Entries
    # that are too small or belong to background of `a` stay 0.
    new_labels = np.zeros(len(unique_joint), np.uint64)
    for i, (label_a, label_b, count) in enumerate(
            zip(labels_a, labels_b, counts)):
        if count < min_size or label_a == 0:
            continue
        if label_b == max_overlap_ids[label_a][0]:
            # The dominant fragment keeps the original ID from `a`.
            if a_reverse_map is not None:
                new_labels[i] = a_reverse_map[label_a]
            else:
                new_labels[i] = label_a
        else:
            max_id += 1
            new_labels[i] = max_id

    # Assign through `a` itself rather than through `a.ravel()`: ravel()
    # returns a copy for non-contiguous arrays, which would silently drop the
    # result.
    a[...] = new_labels[joint_inverse].reshape(a.shape)