# Copyright 2018 Daniel Povey, Hossein Hadian
# Apache 2.0

""" This module contains codes and algorithms for post-processing the output
    of the nnet to find the objects in the image. Specifically it is a greedy
    algorithm in which the only operation is merging objects, starting from
    individual pixels, and the only choice is in which order to merge objects.
    At all stages of optimization, objects will maintain their optimal class
    assignment.
"""

import os
import sys
import struct
import warnings
import resource
import zlib
from collections import namedtuple
from heapq import heappush, heappop

import numpy as np

# scipy.misc.imsave (used by ObjectSegmenter.visualize) no longer exists in
# current SciPy releases.  Importing it must not make this whole module
# unusable, so the import is guarded and visualize() falls back to a small
# pure-stdlib PNG writer when imsave is unavailable.
try:
    import scipy.misc
    _scipy_imsave = getattr(scipy.misc, 'imsave', None)
except (ImportError, AttributeError):
    _scipy_imsave = None


SegmenterOptions = namedtuple('SegmenterOptions',
                              ['same_different_bias',
                               'object_merge_factor',
                               'merge_logprob_bias'])

# Priority assigned to an adjacency record that has been absorbed into another
# one; it makes every queue entry of that record stale.
_DEAD_MERGE_PRIORITY = -100000.0


def _max_mem_gb():
    """ Returns the peak resident set size of this process, converted to GB
        under the assumption that ru_maxrss is reported in kilobytes (which is
        the case on Linux).
    """
    return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024


def _save_grayscale_png(path, img):
    """ Writes the 2-D array 'img' to 'path' as an 8-bit grayscale PNG using
        only the standard library.  Values are scaled linearly so that the
        minimum maps to 0 and the maximum to 255 (a constant image maps to 0),
        which is intended to mirror what the removed scipy.misc.imsave did for
        2-D float input.
    """
    data = np.asarray(img, dtype=np.float64)
    lowest = data.min()
    value_range = data.max() - lowest
    if value_range == 0:
        value_range = 1.0
    scaled = (data - lowest) * (255.0 / value_range)
    gray = np.ascontiguousarray((scaled.clip(0, 255) + 0.5).astype(np.uint8))
    height, width = gray.shape

    def chunk(tag, payload):
        # PNG chunk layout: length, type, data, CRC over type+data.
        body = tag + payload
        return (struct.pack('>I', len(payload)) + body +
                struct.pack('>I', zlib.crc32(body) & 0xffffffff))

    # every scanline is prefixed with filter type 0 ("None").
    raw = b''.join(b'\x00' + row.tobytes() for row in gray)
    # IHDR: width, height, bit depth 8, colour type 0 (grayscale),
    # default compression / filter / no interlace.
    header = struct.pack('>IIBBBBB', width, height, 8, 0, 0, 0, 0)
    with open(path, 'wb') as f:
        f.write(b'\x89PNG\r\n\x1a\n')
        f.write(chunk(b'IHDR', header))
        f.write(chunk(b'IDAT', zlib.compress(raw)))
        f.write(chunk(b'IEND', b''))


class Object:
    """ This class represents an "object" in the output image.
    Attributes:
        object_class: A record of the current assigned class (an integer)
        pixels: A set of pixels (2-tuples) that are part of the object
        class_logprobs: An array indexed by class, of the total (over all
                     pixels in this object) of the log-prob of assigning this
                     pixel to this class; the object_class corresponds to the
                     index of the max element of this.
        adjacency_list: A list of adjacency records, recording other objects
                     to which this object is adjacent. ('Adjacent' means
                     "linked by an offset", not adjacency in the normal
                     sense). It's actually a map (from adjacency record to
                     itself, hashed by the pair of object ids) for faster
                     search and access.
        id: The identifier used for hashing and equality.
        sameness_logprob: Running total of the log(p(same)) terms of all pixel
                     pairs that ended up inside this object through merges.
    """

    def __init__(self, pixels, id, segmenter):
        self.pixels = pixels
        self.compute_class_logprobs(segmenter)
        self.object_class = np.argmax(self.class_logprobs)
        self.adjacency_list = {}
        self.id = id
        self.sameness_logprob = 0

    def compute_class_logprobs(self, segmenter):
        """ Sets self.class_logprobs to the per-class sum, over all pixels of
            this object, of segmenter.get_class_logprob(pixel, class).
        """
        self.class_logprobs = np.zeros(segmenter.num_classes)
        for c in range(len(self.class_logprobs)):
            for p in self.pixels:
                self.class_logprobs[c] += segmenter.get_class_logprob(p, c)

    def class_logprob(self):
        """ Returns the total log-prob of the currently assigned class. """
        return self.class_logprobs[self.object_class]

    def print(self):
        """ Prints this object, its adjacency list and its pixels. """
        print("Object {}. Adj list:".format(self))
        for obj_pair in self.adjacency_list:
            print("\t{} --> {}".format(obj_pair,
                                       self.adjacency_list[obj_pair]))
        print("")
        print("Pixel list: {}".format(self.pixels))

    def compute_sameness_logprob(self, segmenter):
        """ This is only used for debugging purposes.  Recomputes
            self.sameness_logprob from scratch as the sum of log(p(same))
            over all (pixel, offset) pairs whose both ends are in this object.
        """
        self.sameness_logprob = 0
        for i, o in enumerate(segmenter.offsets):
            for p1 in self.pixels:
                p2 = (p1[0] + o[0], p1[1] + o[1])
                if p2 in self.pixels:
                    same_prob = segmenter.get_sameness_prob(p1, i)
                    self.sameness_logprob += np.log(same_prob)

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        if not isinstance(other, Object):
            # let Python fall back to its default comparison instead of
            # failing on a missing 'id' attribute.
            return NotImplemented
        return self.id == other.id

    def __ne__(self, other):
        return not(self == other)

    def __str__(self):
        return "<OBJ:{} class:{} npix:{} nadj:{}>".format(
            self.id, self.object_class, len(self.pixels),
            len(self.adjacency_list))


class AdjacencyRecord:
    """ This class implements an adjacency record with functions for computing
        object merge log-probs and class merge log-probs.
    Attributes:
        obj1, obj2: The two objects to which it refers (kept ordered so that
               obj1.id <= obj2.id)
        obj_merge_logprob: This is the change in log-probability from merging
               these two objects, without considering the effect arising from
               changes of class assignments. This is the sum of the following:
               For each p,o such that o is in "offsets", p is in one of the
               two objects and p+o is in the other, the value
               log(p(same) / p(different)), i.e. log(b_{p,o} / (1-b_{p,o})).
               Note: if the sum above had no terms in it, this adjacency
               record should not exist because the objects would not be
               "adjacent" in the sense which we refer to.
        sameness_logprob, differentness_logprob: The sums of log(p(same)) and
               of log(p(different)) over the same set of (p,o) pairs, so that
               obj_merge_logprob = sameness_logprob - differentness_logprob.
        merge_priority: This merge priority is a heuristic which will determine
               what kinds of objects will get merged first, and is a key choice
               that we'll have to experiment with. The general idea is:
                  merge_priority = merge_log_prob / den
               where merge_log_prob is the log-prob change from doing this
               merge, and "den" currently is the product of the num-pixels in
               object1 and object2. We can experiment with different heuristics
               for "den" though.
        class_delta_logprob: It is a term representing a change in the total
               log-prob that we'll get from merging the two objects, that
               arises from forcing the class assignments to be the same. If the
               two objects already have the same assigned class, this will be
               zero. If different, then this is a value <= 0 which can be
               obtained by summing the objects' 'class_logprobs' arrays,
               finding the largest log-prob in the total, and subtracting the
               total from the current class-assignments of the two objects.
        merged_class: The class that the merged object would have, obtained
               when figuring out class_delta_logprob.
    """

    def __init__(self, obj1, obj2, segmenter, pixel=None, offset=None):
        """ If both 'pixel' and 'offset' (an index into segmenter.offsets) are
            given, the record is initialized from that single pixel pair;
            otherwise the merge log-prob is computed from all pixels of the
            two objects, and an Exception is raised if they are not adjacent.
        """
        self.obj1 = obj1
        self.obj2 = obj2
        self.sort_and_update_hash()
        if pixel is not None and offset is not None:
            same_prob = segmenter.get_sameness_prob(pixel, offset)
            log_same_prob = np.log(same_prob)
            log_different_prob = np.log(1.0 - same_prob)
            self.differentness_logprob = log_different_prob
            self.sameness_logprob = log_same_prob
            self.obj_merge_logprob = log_same_prob - log_different_prob
        else:
            self.compute_obj_merge_logprob(segmenter)
            if self.obj_merge_logprob is None:
                raise Exception(
                    "Bad adjacency record. The given objects are not adjacent.")
        self.class_delta_logprob = None
        self.merged_class = None
        self.merge_priority = None
        self.update_merge_priority(segmenter)

    def compute_obj_merge_logprob(self, segmenter):
        """ Recomputes obj_merge_logprob, sameness_logprob and
            differentness_logprob from the pixels of the two objects.
            obj_merge_logprob is set to None if no pixel pair links them.
        """
        logprob = 0
        adjacent = False
        self.differentness_logprob = 0
        self.sameness_logprob = 0
        # an offset can link the objects in either direction, so check both.
        directions = ((self.obj1, self.obj2), (self.obj2, self.obj1))
        for i, o in enumerate(segmenter.offsets):
            for src, dst in directions:
                for p1 in src.pixels:
                    p2 = (p1[0] + o[0], p1[1] + o[1])
                    if p2 not in dst.pixels:
                        continue
                    adjacent = True
                    same_prob = segmenter.get_sameness_prob(p1, i)
                    log_same_prob = np.log(same_prob)
                    log_different_prob = np.log(1.0 - same_prob)
                    self.differentness_logprob += log_different_prob
                    self.sameness_logprob += log_same_prob
                    logprob += log_same_prob - log_different_prob
        self.obj_merge_logprob = logprob if adjacent else None

    def compute_class_delta_logprob(self):
        """ Sets class_delta_logprob and merged_class (see class docstring).
        """
        if self.obj1.object_class == self.obj2.object_class:
            self.class_delta_logprob = 0.0
            self.merged_class = self.obj1.object_class
        else:
            joint_class_logprobs = (self.obj1.class_logprobs +
                                    self.obj2.class_logprobs)
            self.merged_class = np.argmax(joint_class_logprobs)
            merged_class_joint_logprob = joint_class_logprobs[self.merged_class]
            self.class_delta_logprob = (merged_class_joint_logprob -
                                        self.obj1.class_logprob() -
                                        self.obj2.class_logprob())

    def update_merge_priority(self, segmenter):
        """ Recomputes class_delta_logprob, merged_class and merge_priority
            using the options in segmenter.opts.
        """
        self.compute_class_delta_logprob()
        den = len(self.obj1.pixels) * len(self.obj2.pixels)
        self.merge_priority = (
            self.obj_merge_logprob * segmenter.opts.object_merge_factor +
            self.class_delta_logprob +
            segmenter.opts.merge_logprob_bias) / den

    def obj_pair(self):
        """ Returns the (obj1, obj2) pair of this record as a tuple.  (This
            used to reference a name, ObjPair, that is not defined anywhere in
            this module and so could only raise NameError.)
        """
        return (self.obj1, self.obj2)

    def sort_and_update_hash(self):
        """ Orders the objects by id and caches the hash of the id pair.  Must
            be called whenever obj1 or obj2 is reassigned.
        """
        if self.obj1.id > self.obj2.id:
            # swap them
            self.obj1, self.obj2 = self.obj2, self.obj1
        self.cached_hash = hash((self.obj1.id, self.obj2.id))

    def __hash__(self):
        return self.cached_hash

    def __eq__(self, other):
        if not isinstance(other, AdjacencyRecord):
            return NotImplemented
        return ((self.obj1.id, self.obj2.id) ==
                (other.obj1.id, other.obj2.id))

    def __ne__(self, other):
        return not(self == other)

    def print(self):
        """ Prints this record and both of its objects. """
        print("Objects in arec {}:".format(self))
        self.obj1.print()
        self.obj2.print()

    def __lt__(self, other):
        # used by heapq to break ties between equal queue priorities.
        return self.merge_priority < other.merge_priority

    def __str__(self):
        return "<AREC-{}: [{}, {}] oml:{:0.2f} cdl:{:0.2f} mp:{:0.2f}>".format(
            id(self), self.obj1, self.obj2, self.obj_merge_logprob,
            self.class_delta_logprob, self.merge_priority)


class ObjectSegmenter:
    """ Greedy segmenter: starts from one object per pixel and repeatedly
        merges the pair of adjacent objects with the best merge priority.

        nnet_class_probs must have shape (num_classes, height, width) and
        nnet_sameness_probs shape (len(offsets), height, width); both are
        clipped away from 0 and 1 before logs are taken.  'offsets' is a list
        of (row, col) tuples.  'opts' is a SegmenterOptions (defaults are used
        if None).

        The attributes 'verbose' (0, 1 or 2) and 'do_debugging' control extra
        output / consistency checks and may be set after construction.
    """

    def __init__(self, nnet_class_probs, nnet_sameness_probs, num_classes,
                 offsets, opts=None):
        self.opts = opts
        if self.opts is None:
            self.opts = self.default_options()
        print(self.opts)
        # These used to be set only inside run_segmentation(), which made
        # prune() and merge() fail with AttributeError when called first.
        self.verbose = 0
        self.do_debugging = False

        epsilon = np.finfo(np.float32).eps
        self.class_probs = nnet_class_probs.clip(epsilon, 1.0 - epsilon)
        self.sameness_probs = nnet_sameness_probs.clip(epsilon, 1.0 - epsilon)
        if self.opts.same_different_bias != 0.0:
            # add the bias in logit space and map back to a probability.
            sameness_probs_biased_logit = (
                np.log(self.sameness_probs) -
                np.log(1.0 - self.sameness_probs) +
                self.opts.same_different_bias)
            self.sameness_probs = 1.0 / \
                (1.0 + np.exp(-sameness_probs_biased_logit))
        self.num_classes = num_classes
        self.offsets = offsets  # should be a list of tuples
        # the pixels here are python tuples (x,y) not numpy arrays
        self.pixel2obj = {}
        class_dim, self.img_height, self.img_width = self.class_probs.shape
        offset_dim, img_height, img_width = self.sameness_probs.shape
        assert class_dim == self.num_classes
        assert offset_dim == len(self.offsets)
        assert self.img_height == img_height
        assert self.img_width == img_width

        self.objects = {}
        self.adjacency_records = {}
        self.queue = []  # Python's heapq
        self.init_objects_and_adjacency_records()

    def default_options(self):
        """ Returns the options used when none are given. """
        return SegmenterOptions(same_different_bias=0.0,
                                object_merge_factor=1.0,
                                merge_logprob_bias=0.0)

    def init_objects_and_adjacency_records(self):
        """ Creates one object per pixel and one adjacency record per pair of
            pixels linked by an offset, and queues every record whose merge
            priority is non-negative.
        """
        print("Initializing the segmenter...")
        print("Max mem: {} GB".format(_max_mem_gb()))
        obj_id = 0
        for row in range(self.img_height):
            for col in range(self.img_width):
                obj = Object({(row, col)}, obj_id, self)
                self.objects[obj_id] = obj
                self.pixel2obj[(row, col)] = obj
                obj_id += 1

        for row in range(self.img_height):
            for col in range(self.img_width):
                obj1 = self.pixel2obj[(row, col)]
                for idx, (i, j) in enumerate(self.offsets):
                    row2, col2 = row + i, col + j
                    if not (0 <= row2 < self.img_height and
                            0 <= col2 < self.img_width):
                        continue
                    obj2 = self.pixel2obj[(row2, col2)]
                    if obj2 is obj1:
                        # a zero offset links a pixel to itself; an object
                        # cannot be adjacent to (or merged with) itself.
                        continue
                    arec = AdjacencyRecord(obj1, obj2, self, (row, col), idx)
                    existing = self.adjacency_records.get(arec)
                    if existing is None:
                        self.adjacency_records[arec] = arec
                        obj1.adjacency_list[arec] = arec
                        obj2.adjacency_list[arec] = arec
                    else:
                        # The same pair of objects is linked by more than one
                        # (pixel, offset), e.g. when both o and -o are in
                        # 'offsets'.  Accumulate into the record we already
                        # have instead of overwriting it (which would drop
                        # the earlier term and leave a stale record queued).
                        existing.obj_merge_logprob += arec.obj_merge_logprob
                        existing.differentness_logprob += \
                            arec.differentness_logprob
                        existing.sameness_logprob += arec.sameness_logprob
                        existing.update_merge_priority(self)
                        arec = existing
                    if arec.merge_priority >= 0:
                        heappush(self.queue, (-arec.merge_priority, arec))

    def get_class_logprob(self, pixel, class_index):
        """ Returns log p(class_index) at 'pixel' (a (row, col) tuple). """
        return np.log(self.class_probs[class_index, pixel[0], pixel[1]])

    def get_sameness_prob(self, pixel, offset_index):
        """ Returns p(same) for 'pixel' and the offset with the given index.
        """
        return self.sameness_probs[offset_index, pixel[0], pixel[1]]

    def show_stats(self):
        """ Prints a summary of the current state of the segmentation. """
        print("Total logprob: {:.3f}".format(self.compute_total_logprob()))
        print("Total number of objects: {}".format(len(self.objects)))
        print("Total number of adjacency records: {}".format(
            len(self.adjacency_records)))
        print("Total number of records in the queue: {}".format(
            len(self.queue)))
        pixperobj = sorted((len(obj.pixels) for obj in self.objects.values()),
                           reverse=True)
        print("Top 10 biggest objs (#pixels): {}".format(pixperobj[:10]))
        adjlistsize = sorted((len(obj.adjacency_list)
                              for obj in self.objects.values()), reverse=True)
        print("Top 10 biggest objs (adj_list size): {}".format(
            adjlistsize[:10]))

    def compute_total_logprob_from_scratch(self):
        """ This is for debugging only.  Recomputes the total log-prob from
            the pixel-level probabilities.  As a side effect it refreshes
            self.pixel2obj, which merge() does not keep up to date.
        """
        tot_class_logprob = 0
        tot_differentness_logprob = 0
        tot_sameness_logprob = 0
        for obj in self.objects.values():
            for p in obj.pixels:
                self.pixel2obj[p] = obj
                tot_class_logprob += self.get_class_logprob(
                    p, obj.object_class)
        for row in range(self.img_height):
            for col in range(self.img_width):
                p1 = (row, col)
                obj1 = self.pixel2obj[p1]
                for i, o in enumerate(self.offsets):
                    if (0 <= row + o[0] < self.img_height and
                            0 <= col + o[1] < self.img_width):
                        obj2 = self.pixel2obj[(row + o[0], col + o[1])]
                        if obj1 is obj2 or obj1 == obj2:
                            tot_sameness_logprob += np.log(
                                self.get_sameness_prob(p1, i))
                        else:
                            tot_differentness_logprob += np.log(
                                1.0 - self.get_sameness_prob(p1, i))
        return tot_class_logprob + (
            tot_differentness_logprob +
            tot_sameness_logprob) * self.opts.object_merge_factor

    def compute_total_logprob(self):
        """ Returns the total log-prob using the quantities cached in the
            objects and adjacency records.
        """
        tot_class_logprob = 0
        tot_differentness_logprob = 0
        tot_sameness_logprob = 0
        for obj in self.objects.values():
            tot_class_logprob += obj.class_logprob()
            tot_sameness_logprob += obj.sameness_logprob
        for arec in self.adjacency_records.values():
            tot_differentness_logprob += arec.differentness_logprob
        return tot_class_logprob + (
            tot_differentness_logprob +
            tot_sameness_logprob) * self.opts.object_merge_factor

    def visualize(self, iter):
        """ Writes '<iter>.png', an image in which each object is painted with
            its own gray level and the pixel at each object's centroid is set
            to zero.
        """
        img = np.zeros((self.img_height, self.img_width))
        k = 1
        for obj in self.objects.values():
            for p in obj.pixels:
                img[p] = k
            center = tuple(np.array(list(obj.pixels)).mean(axis=0))
            img[int(center[0]), int(center[1])] = 0.0
            k += 1
        path = '{}.png'.format(iter)
        if _scipy_imsave is not None:
            _scipy_imsave(path, img)
        else:
            # scipy.misc.imsave is gone from current SciPy; without this
            # fallback run_segmentation() would crash right at the end.
            _save_grayscale_png(path, img)

    def prune(self, threshold=200.0):
        """ Merges into the biggest background (class 0) object every other
            object whose score
               class_logprob() - class_logprobs[0]
            is below 'threshold'.  Only the pixel sets and self.objects are
            updated.  Does nothing if there is no background object.
        """
        # Find the biggest background object:
        background_obj = None
        num_pixels = 0
        for obj in self.objects.values():
            if obj.object_class == 0 and len(obj.pixels) > num_pixels:
                background_obj = obj
                num_pixels = len(obj.pixels)
        if background_obj is None:
            # previously this case crashed on the unbound 'background_obj'.
            print("No background object found; nothing to prune.")
            return

        objects_to_be_merged = []
        for obj in self.objects.values():
            nonbackground_score = obj.class_logprob() - obj.class_logprobs[0]
            if self.verbose >= 2:
                print("obj: {}  --> {:0.2f}".format(len(obj.pixels),
                                                    nonbackground_score))
            if nonbackground_score < threshold and obj is not background_obj:
                objects_to_be_merged.append(obj)

        for obj in objects_to_be_merged:
            if self.verbose >= 1:
                print("Merging obj with {} pixels to "
                      "background...".format(len(obj.pixels)))
            background_obj.pixels = background_obj.pixels.union(obj.pixels)
            del self.objects[obj.id]
        print("Pruned {} objects (merged into background). Final objects:"
              " {}".format(len(objects_to_be_merged), len(self.objects)))

    def output_mask(self):
        """ Returns (mask, object_class): 'mask' is an integer image in which
            the pixels of the k'th non-background object have value k (k >= 1,
            background stays 0), and object_class[k-1] is that object's class.
        """
        mask = np.zeros((self.img_height, self.img_width), dtype=int)
        k = 1
        object_class = []
        for obj in self.objects.values():
            # skip background object
            if obj.object_class == 0:
                continue
            object_class.append(obj.object_class)
            for p in obj.pixels:
                mask[p] = k
            k += 1
        return mask, object_class

    def debug(self):
        """ Do some sanity checks and make sure certain quantities have
            values that they should have. This function is quite
            time-consuming and should not be called too many times.
            Exits the process (sys.exit(1)) if an inconsistency is found.
        """
        # check if the current set of objects exactly cover the whole image
        pix2count = np.zeros((self.img_height, self.img_width))
        for obj in self.objects.values():
            for p in obj.pixels:
                pix2count[p] += 1
        if not (pix2count == 1).all():
            print("Error: pixels are not all covered or they are double counted")
            np.set_printoptions(threshold=20000)
            print(pix2count)
            sys.exit(1)

        # check the adjacency lists of the objects
        tot_obj_adj_records = 0
        for obj in self.objects.values():
            tot_obj_adj_records += len(obj.adjacency_list)
            for arec in obj.adjacency_list.values():
                assert arec in self.adjacency_records
                assert (arec.obj1 is obj) ^ (arec.obj2 is obj)
                # make sure that re-computing obj-merge-logprob does not
                # change it. this is too costly to run, so only do it
                # randomly with a small chance
                if np.random.random() > 0.95:
                    obj_merge_logprob = arec.obj_merge_logprob
                    arec.compute_obj_merge_logprob(self)
                    if np.abs(arec.obj_merge_logprob -
                              obj_merge_logprob) > 0.001:
                        print("Error re-computing obj-merge logprob changed it for "
                              "arec {}".format(arec))
                        print("Old logprob: {} new logprob: {}".format(
                            obj_merge_logprob, arec.obj_merge_logprob))
                        arec.print()
                        sys.exit(1)
        # every record is listed by exactly two objects.
        assert tot_obj_adj_records == 2 * len(self.adjacency_records)

    def run_segmentation(self):
        """ This is the top-level function that performs the optimization.
            It returns the result of output_mask().
            This is the overview:
            - While the queue is non-empty:
               - Pop (merge_priority, arec) from the queue.
               - If merge_priority != arec.merge_priority
                   continue   # don't worry, the queue will have the right
                              # merge_priority for this arec somewhere
                              # else in it.
               - Recompute arec.merge_priority, which involves recomputing
                 class_delta_logprob. This is needed because as we merge
                 objects, the value of class_delta_logprob and/or the number
                 of pixels may have changed and the adjacency record may not
                 have been updated.
               - If the newly computed arec.merge_priority is >= the old value
                 (i.e. this merge is at least as good a merge as we thought it
                 was when we got it from the queue), go ahead and merge the
                 objects.
               - Otherwise if arec.merge_priority >= 0 then re-insert "arec"
                 into the queue with its newly computed merge priority.
        """
        print("Starting segmentation...")
        n = 0
        while self.queue:
            if n % 500000 == 0:
                print("At iteration {}: max mem: {:0.2f} GB".format(
                    n, _max_mem_gb()))
                self.show_stats()
                if self.do_debugging:
                    print("Logprob from scratch: {}".format(
                        self.compute_total_logprob_from_scratch()))
                print("")
            n += 1
            merge_cost, arec = heappop(self.queue)
            merge_priority = -merge_cost
            if merge_priority != arec.merge_priority:
                continue  # stale queue entry
            arec.update_merge_priority(self)
            if arec.merge_priority >= merge_priority:
                self.merge(arec)
            elif arec.merge_priority >= 0:
                heappush(self.queue, (-arec.merge_priority, arec))

        # The queue is necessarily empty here.  Reporting after the loop
        # (rather than at the end of its body) guarantees the final report is
        # produced even when the last popped entry was a stale one.
        print("Finished. Queue is empty.")
        self.show_stats()
        self.visualize('final')
        if self.verbose >= 1:
            print("Final logprob from scratch: {}".format(
                self.compute_total_logprob_from_scratch()))
        return self.output_mask()

    def merge(self, arec):
        """ This is the most nontrivial aspect of the algorithm: how to merge
            objects. The basic steps in this function are as follows:
              - Swap object1 and object2 as necessary to ensure that the
                num-pixels in object1 is >= the num-pixels in object2. We will
                assimilate object2 into object1 and we can let object2 be
                deleted.
              - Set object1's object_class to merged_class
              - Append object2's pixels to object1's pixels
              - Add object2's class_logprobs to object1's class_logprobs
              - Merge object2's adjacency records into object1's adjacency
                records: more specifically,
                  - For each element "this_arec" in object2.adjacency_list,
                    change whichever of this_arec.obj1 or this_arec.obj2
                    equals "object2" to "object1". That is, make it point to
                    the merged object "object1", instead of to the doomed
                    "object2".
                  - If object1.adjacency_list already contains an adjacency
                    record with same pair of objects that are now in this_arec
                    (viewing them as an unordered pair), then add
                    this_arec.obj_merge_logprob to that adjacency record's
                    obj_merge_logprob. Otherwise, add this_arec to
                    object1.adjacency_list.
              - For each adjacency record that is directly touched during the
                process above:
                  - Recompute its class_delta_logprob, merged_class and
                    merge_priority; if its merge_priority is >= 0, re-insert
                    it into the queue.
            Nothing is done if either object of 'arec' no longer exists.
        """
        obj1, obj2 = arec.obj1, arec.obj2
        if obj1.id not in self.objects or obj2.id not in self.objects:
            return
        if obj1 is obj2:
            return
        if len(obj2.pixels) > len(obj1.pixels):
            obj1, obj2 = obj2, obj1
        assert np.abs(arec.obj_merge_logprob -
                      (arec.sameness_logprob -
                       arec.differentness_logprob)) < 0.001
        if self.do_debugging:
            old_logprob = arec.obj_merge_logprob
            arec.compute_obj_merge_logprob(self)
            if np.abs(arec.obj_merge_logprob - old_logprob) > 0.001:
                print("Error: object merge logprob changed unexpectedly. "
                      "{} != {}".format(arec.obj_merge_logprob, old_logprob))
                arec.print()
                sys.exit(1)

        # now we are sure that obj1 has equal/more pixels
        obj1.object_class = arec.merged_class
        obj1.pixels = obj1.pixels.union(obj2.pixels)
        obj1.class_logprobs += obj2.class_logprobs
        obj1.sameness_logprob += arec.sameness_logprob + obj2.sameness_logprob
        del self.adjacency_records[arec]
        del obj1.adjacency_list[arec]
        del obj2.adjacency_list[arec]
        # Snapshot the records first: their hashes change below while they
        # are still keys of obj2.adjacency_list.
        for this_arec in list(obj2.adjacency_list.values()):
            # obj3 is any object adjacent to obj2 (never is obj1):
            obj3 = this_arec.obj2 if this_arec.obj1 is obj2 else this_arec.obj1
            assert obj3 is not obj1
            # remove the record from the maps while it still hashes as
            # (obj2, obj3) ...
            del obj3.adjacency_list[this_arec]
            del self.adjacency_records[this_arec]
            # ... then redirect it to obj1 and re-hash it as (obj1, obj3).
            if this_arec.obj1 is obj2:
                this_arec.obj1 = obj1
            if this_arec.obj2 is obj2:
                this_arec.obj2 = obj1
            this_arec.sort_and_update_hash()
            assert this_arec.obj1 is not this_arec.obj2
            if this_arec in obj1.adjacency_list:
                # obj1 was already adjacent to obj3: fold this_arec into the
                # existing record.
                that_arec = obj1.adjacency_list[this_arec]
                that_arec.obj_merge_logprob += this_arec.obj_merge_logprob
                that_arec.differentness_logprob += \
                    this_arec.differentness_logprob
                that_arec.sameness_logprob += this_arec.sameness_logprob
                # make sure it is practically deleted from the queue
                this_arec.merge_priority = _DEAD_MERGE_PRIORITY
                self.adjacency_records[that_arec] = that_arec
                obj3.adjacency_list[that_arec] = that_arec
                that_arec.update_merge_priority(self)
                if that_arec.merge_priority >= 0:
                    heappush(self.queue, (-that_arec.merge_priority, that_arec))
            else:
                obj1.adjacency_list[this_arec] = this_arec
                obj3.adjacency_list[this_arec] = this_arec
                self.adjacency_records[this_arec] = this_arec
                this_arec.update_merge_priority(self)
                if this_arec.merge_priority >= 0:
                    heappush(self.queue, (-this_arec.merge_priority, this_arec))
        if self.verbose >= 2:
            print("Deleting {} being merged to {} according "
                  "to {}".format(obj2, obj1, arec))
        del self.objects[obj2.id]