Python tensorflow.python.ops.control_flow_ops.cond() Examples
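The following are 30 code examples of `tensorflow.python.ops.control_flow_ops.cond()` (a few of them instead use other members of the same module, such as `Assert()` or `while_loop()`).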
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module `tensorflow.python.ops.control_flow_ops`, or try the search function.
Example #1
Source File: utils.py From lambda-packs with MIT License | 6 votes |
def smart_cond(pred, fn1, fn2, name=None):
    """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

    If `pred` is a bool or has a constant value, `static_cond` chooses the
    branch in Python. Otherwise the choice is deferred to run time through
    `control_flow_ops.cond`.

    Args:
      pred: A scalar determining whether to return the result of `fn1` or `fn2`.
      fn1: The callable to be performed if pred is true.
      fn2: The callable to be performed if pred is false.
      name: Optional name prefix when using tf.cond.

    Returns:
      Tensors returned by the call to either `fn1` or `fn2`.
    """
    pred_value = constant_value(pred)
    if pred_value is not None:
        # Use static_cond if pred has a constant value.
        return static_cond(pred_value, fn1, fn2)
    # Pass `name` by keyword. In the TF 1.4+ signature
    # `cond(pred, true_fn, false_fn, strict=False, name=None, ...)` the fourth
    # positional slot is `strict`, so a positional name would be silently
    # taken as `strict` and the op would not be named.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)
Example #2
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def average_impurity(self):
    """Build the graph computing the average impurity over the tree's leaves.

    In regression mode this is the leaf variance; in classification mode it
    is the Gini impurity (both computed by `self._weighted_gini`).

    Returns:
      The final op of the constructed graph. It yields the impurity when at
      least one leaf exists, and a large constant (1e7) otherwise.
    """
    first_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
    child_ids = array_ops.squeeze(first_column, squeeze_dims=[1])
    leaf_flags = math_ops.equal(constants.LEAF_NODE, child_ids)
    leaf_indices = math_ops.to_int32(
        array_ops.squeeze(array_ops.where(leaf_flags), squeeze_dims=[1]))
    leaf_counts = array_ops.gather(self.variables.node_sums, leaf_indices)
    gini = self._weighted_gini(leaf_counts)

    # Guard against step 1, when there often are no leaves yet. Because the
    # impurity may be used as a loss, return a big number when there is no
    # data, so that the loss always decreases afterwards.
    return control_flow_ops.cond(
        math_ops.greater(array_ops.shape(leaf_indices)[0], 0),
        lambda: gini,
        lambda: array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.)
Example #3
Source File: topn.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def insert(self, ids, scores):
    """Record `scores` for `ids` and refresh the shortlist when warranted.

    Every score is written into `self.id_to_score`. The shortlist
    (`self.sl_ids` / `self.sl_scores`) is only updated when at least one score
    exceeds `self.sl_scores[0]`. The resulting ops are stored in
    `self.last_ops`, which the next call depends on.
    """
    with ops.control_dependencies(self.last_ops):
        score_write = state_ops.scatter_update(self.id_to_score, ids, scores)
        beats_threshold = math_ops.greater(scores, self.sl_scores[0])

        def _merge_into_shortlist():
            candidate_ids = array_ops.boolean_mask(
                math_ops.to_int64(ids), beats_threshold)
            candidate_scores = array_ops.boolean_mask(scores, beats_threshold)
            slots, ids_to_write, scores_to_write = (
                tensor_forest_ops.top_n_insert(
                    self.sl_ids, self.sl_scores,
                    candidate_ids, candidate_scores))
            return control_flow_ops.group(
                state_ops.scatter_update(self.sl_ids, slots, ids_to_write),
                state_ops.scatter_update(
                    self.sl_scores, slots, scores_to_write))

        # Touch the shortlist only if some score is above the threshold.
        shortlist_op = control_flow_ops.cond(
            math_ops.reduce_any(beats_threshold),
            _merge_into_shortlist,
            control_flow_ops.no_op)
        with ops.control_dependencies([shortlist_op]):
            self.last_ops = [score_write, shortlist_op]
Example #4
Source File: metric_ops.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _safe_scalar_div(numerator, denominator, name):
    """Compute `numerator / denominator`, or 0.0 when the denominator is 0.

    Args:
      numerator: A scalar `float64` `Tensor` (the static check accepts rank
        at most 1).
      denominator: A scalar `float64` `Tensor` (same static rank check).
      name: Name given to the resulting `cond` op.

    Returns:
      A `float64` tensor: 0.0 if `denominator == 0`, else the quotient.
    """
    for operand in (numerator, denominator):
        operand.get_shape().with_rank_at_most(1)

    def _zero():
        return array_ops.constant(0.0, dtype=dtypes.float64)

    def _quotient():
        return math_ops.div(numerator, denominator)

    denominator_is_zero = math_ops.equal(_zero(), denominator)
    return control_flow_ops.cond(
        denominator_is_zero, _zero, _quotient, name=name)
Example #5
Source File: utils.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def static_cond(pred, fn1, fn2):
    """Call `fn1` if `pred` is truthy, otherwise call `fn2`, and return the result.

    Mirrors the signature of `control_flow_ops.cond()`, but the decision is
    made in Python from the truthiness of `pred`.

    Args:
      pred: A Python value whose truthiness selects the branch.
      fn1: Callable invoked when `pred` is truthy.
      fn2: Callable invoked when `pred` is falsy.

    Returns:
      Whatever the selected callable returns.

    Raises:
      TypeError: if `fn1` or `fn2` is not callable. Both are checked before
        either one is called.
    """
    if not callable(fn1):
        raise TypeError('fn1 must be callable.')
    if not callable(fn2):
        raise TypeError('fn2 must be callable.')
    chosen = fn1 if pred else fn2
    return chosen()
Example #6
Source File: utils.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def smart_cond(pred, fn1, fn2, name=None):
    """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

    If `pred` is a bool or has a constant value, `static_cond` selects the
    branch in Python. Otherwise `control_flow_ops.cond` selects it at run time.

    Args:
      pred: A scalar determining whether to return the result of `fn1` or `fn2`.
      fn1: The callable to be performed if pred is true.
      fn2: The callable to be performed if pred is false.
      name: Optional name prefix when using tf.cond.

    Returns:
      Tensors returned by the call to either `fn1` or `fn2`.
    """
    pred_value = constant_value(pred)
    if pred_value is not None:
        # Use static_cond if pred has a constant value.
        return static_cond(pred_value, fn1, fn2)
    # `name` must be passed by keyword. With the TF 1.4+ signature
    # `cond(pred, true_fn, false_fn, strict=False, name=None, ...)` a
    # positional fourth argument is bound to `strict`, not `name`.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)
Example #7
Source File: bernoulli.py From lambda-packs with MIT License | 6 votes |
def _log_prob(self, event):
    """Bernoulli log-probability of `event` under `self.logits`.

    Computed as the negated `nn.sigmoid_cross_entropy_with_logits`, after
    `logits` and `event` have been broadcast against each other.
    """
    event = self._maybe_assert_valid_sample(event)
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    # sigmoid_cross_entropy_with_logits does not broadcast, so each side is
    # multiplied by ones shaped like the other.
    def _broadcast(lhs, rhs):
        return (array_ops.ones_like(rhs) * lhs,
                array_ops.ones_like(lhs) * rhs)

    shapes_static = (event.get_shape().is_fully_defined() and
                     logits.get_shape().is_fully_defined())
    if not shapes_static:
        # Shapes are only known at run time: broadcast only when they differ.
        logits, event = control_flow_ops.cond(
            distribution_util.same_dynamic_shape(logits, event),
            lambda: (logits, event),
            lambda: _broadcast(logits, event))
    elif event.get_shape() != logits.get_shape():
        logits, event = _broadcast(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example #8
Source File: tensor_forest.py From lambda-packs with MIT License | 6 votes |
def average_impurity(self):
    """Construct the graph for the average impurity over all tree leaves.

    Regression mode yields the leaf variance; classification mode yields the
    Gini impurity. Both come from `self._weighted_gini`.

    Returns:
      The last op in the graph.
    """
    tree_children = array_ops.squeeze(
        array_ops.slice(self.variables.tree, [0, 0], [-1, 1]),
        squeeze_dims=[1])
    is_leaf_node = math_ops.equal(constants.LEAF_NODE, tree_children)
    leaf_positions = array_ops.squeeze(
        array_ops.where(is_leaf_node), squeeze_dims=[1])
    leaf_positions = math_ops.to_int32(leaf_positions)
    per_leaf_counts = array_ops.gather(
        self.variables.node_sums, leaf_positions)
    weighted_gini = self._weighted_gini(per_leaf_counts)

    def _use_gini():
        return weighted_gini

    # With no leaves (typical at step 1), report a huge value instead. When
    # this is used as a loss, the loss can then only decrease.
    def _use_sentinel():
        return array_ops.ones_like(
            weighted_gini, dtype=dtypes.float32) * 10000000.

    any_leaves = math_ops.greater(array_ops.shape(leaf_positions)[0], 0)
    return control_flow_ops.cond(any_leaves, _use_gini, _use_sentinel)
Example #9
Source File: bernoulli.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _log_prob(self, event):
    """Return the Bernoulli log-probability of `event` given `self.logits`.

    `event` is converted to a tensor and cast to the logits' dtype. Both
    operands are broadcast to a common shape, and the negated sigmoid
    cross-entropy is returned.
    """
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent behavior for logits = inf/-inf.
    event = ops.convert_to_tensor(event, name="event")
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    # sigmoid_cross_entropy_with_logits doesn't broadcast shape,
    # so we do it by hand.
    def broadcast(logit_t, event_t):
        return (array_ops.ones_like(event_t) * logit_t,
                array_ops.ones_like(logit_t) * event_t)

    event_static = event.get_shape().is_fully_defined()
    if event_static and logits.get_shape().is_fully_defined():
        # Static shapes known: broadcast only on a mismatch.
        if event.get_shape() != logits.get_shape():
            logits, event = broadcast(logits, event)
    else:
        logits, event = control_flow_ops.cond(
            distribution_util.same_dynamic_shape(logits, event),
            lambda: (logits, event),
            lambda: broadcast(logits, event))
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example #10
Source File: image_ops_impl.py From lambda-packs with MIT License | 6 votes |
def _assert(cond, ex_type, msg):
    """Assert `cond` either immediately (Python value) or via a TF op (tensor).

    Args:
      cond: Something that evaluates to a boolean value; may be a tensor.
      ex_type: Exception class raised when a non-tensor `cond` is falsy.
      msg: The error message.

    Returns:
      A list holding one `control_flow_ops.Assert` op if `cond` is a tensor,
      otherwise an empty list.

    Raises:
      ex_type: if `cond` is not a tensor and is falsy.
    """
    if not _is_tensor(cond):
        if not cond:
            raise ex_type(msg)
        return []
    return [control_flow_ops.Assert(cond, [msg])]
Example #11
Source File: variables.py From lambda-packs with MIT License | 6 votes |
def initialized_value(self):
    """Return this variable's value, falling back to its initial value.

    The returned tensor evaluates to the variable's current value once it has
    been initialized, and to `self.initial_value` otherwise. Use it, rather
    than the variable itself, when initializing another variable from this
    one.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` produced by a `cond` on `is_variable_initialized(self)`.
    """
    # Build the ops with any enclosing control dependencies cleared.
    with ops.control_dependencies(None):
        ready = is_variable_initialized(self)
        return control_flow_ops.cond(
            ready,
            self.read_value,
            lambda: self.initial_value)
Example #12
Source File: metrics_impl.py From lambda-packs with MIT License | 6 votes |
def _safe_scalar_div(numerator, denominator, name):
    """Safely divide two scalar tensors, returning 0.0 for a zero denominator.

    Args:
      numerator: A scalar `float64` `Tensor`.
      denominator: A scalar `float64` `Tensor`.
      name: Name for the returned op.

    Returns:
      0.0 if `denominator` == 0, else `numerator` / `denominator`.
    """
    numerator.get_shape().with_rank_at_most(1)
    denominator.get_shape().with_rank_at_most(1)
    zero_f64 = array_ops.constant(0.0, dtype=dtypes.float64)
    is_zero = math_ops.equal(zero_f64, denominator)
    return control_flow_ops.cond(
        is_zero,
        lambda: array_ops.constant(0.0, dtype=dtypes.float64),
        lambda: math_ops.div(numerator, denominator),
        name=name)
Example #13
Source File: session_debug_testlib.py From lambda-packs with MIT License | 6 votes |
def testDebugCondWatchingWholeGraphWorks(self):
    """Watching the whole graph dumps the output of the cond's Merge node."""
    with session.Session() as sess:
        x = variables.Variable(10.0, name="x")
        y = variables.Variable(20.0, name="y")
        # x > y is false (10 vs 20), so the false branch y + 1 = 21 is taken.
        branch_out = control_flow_ops.cond(
            x > y,
            lambda: math_ops.add(x, 1),
            lambda: math_ops.add(y, 1))

        sess.run(variables.global_variables_initializer())

        opts = config_pb2.RunOptions(output_partition_graphs=True)
        debug_utils.watch_graph(
            opts, sess.graph, debug_urls=self._debug_urls())
        metadata = config_pb2.RunMetadata()
        self.assertEqual(
            21, sess.run(branch_out, options=opts, run_metadata=metadata))

        dump_dir = debug_data.DebugDumpDir(
            self._dump_root, partition_graphs=metadata.partition_graphs)
        self.assertAllClose(
            [21.0], dump_dir.get_tensors("cond/Merge", 0, "DebugIdentity"))
Example #14
Source File: utils.py From lambda-packs with MIT License | 6 votes |
def static_cond(pred, fn1, fn2):
    """Pick and run a branch in Python based on the truthiness of `pred`.

    Same signature as `control_flow_ops.cond()`, but no graph op is built.

    Args:
      pred: Value whose truthiness chooses the branch.
      fn1: Callable run when `pred` is truthy.
      fn2: Callable run when `pred` is falsy.

    Returns:
      The return value of the chosen callable.

    Raises:
      TypeError: if `fn1` or `fn2` is not callable (checked in that order).
    """
    if not callable(fn1):
        raise TypeError('fn1 must be callable.')
    if not callable(fn2):
        raise TypeError('fn2 must be callable.')
    if not pred:
        return fn2()
    return fn1()
Example #15
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0 | 6 votes |
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
    """Run the ops produced by `op_fn` in a `while_loop` until a stop signal.

    Args:
      device: Device on which the loop is placed.
      op_fn: Zero-argument callable returning a dict with 'ops' (ops to run
        each iteration) and 'signals' (used to build the next stopping
        signal).

    Returns:
      The result of `control_flow_ops.while_loop`.
    """
    def _keep_running(scalar_stopping_signal):
        stop_now = _StopSignals.should_stop(scalar_stopping_signal)
        return math_ops.logical_not(stop_now)

    def _one_iteration(unused_scalar_stopping_signal):
        produced = op_fn()
        with ops.control_dependencies(produced['ops']):
            return _StopSignals.as_scalar_stopping_signal(produced['signals'])

    # parallel_iterations=1 basically turns off parallel execution of
    # loop iterations.
    with ops.device(device):
        return control_flow_ops.while_loop(
            _keep_running,
            _one_iteration,
            [_StopSignals.NON_STOPPING_SIGNAL],
            parallel_iterations=1)
Example #16
Source File: metrics_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _safe_scalar_div(numerator, denominator, name):
    """Return `numerator / denominator`, or 0.0 if `denominator` is 0.

    Args:
      numerator: A scalar `float64` `Tensor`.
      denominator: A scalar `float64` `Tensor`.
      name: Name for the returned op.

    Returns:
      A `float64` tensor selected by a `cond` on `denominator == 0`.
    """
    f64 = dtypes.float64
    numerator.get_shape().with_rank_at_most(1)
    denominator.get_shape().with_rank_at_most(1)
    predicate = math_ops.equal(array_ops.constant(0.0, dtype=f64), denominator)

    def when_zero():
        return array_ops.constant(0.0, dtype=f64)

    def otherwise():
        return math_ops.div(numerator, denominator)

    return control_flow_ops.cond(predicate, when_zero, otherwise, name=name)
Example #17
Source File: utils.py From tensornets with MIT License | 6 votes |
def static_cond(pred, fn1, fn2):
    """Return `fn1()` when `pred` is truthy and `fn2()` otherwise.

    This is the Python-side counterpart of `control_flow_ops.cond()` and
    takes the same arguments.

    Args:
      pred: Value tested for truthiness.
      fn1: Callable for the true case.
      fn2: Callable for the false case.

    Returns:
      Result of the selected callable.

    Raises:
      TypeError: if `fn1` or `fn2` is not callable.
    """
    if not callable(fn1):
        raise TypeError('fn1 must be callable.')
    if not callable(fn2):
        raise TypeError('fn2 must be callable.')
    return (fn1 if pred else fn2)()
Example #18
Source File: utils.py From tensornets with MIT License | 6 votes |
def smart_cond(pred, fn1, fn2, name=None):
    """Return either fn1() or fn2() based on the boolean predicate/value `pred`.

    When `constant_value(pred)` is known, `static_cond` selects the branch in
    Python. Otherwise `control_flow_ops.cond` builds a run-time conditional.

    Args:
      pred: A scalar determining whether to return the result of `fn1` or `fn2`.
      fn1: The callable to be performed if pred is true.
      fn2: The callable to be performed if pred is false.
      name: Optional name prefix when using tf.cond.

    Returns:
      Tensors returned by the call to either `fn1` or `fn2`.
    """
    pred_value = constant_value(pred)
    if pred_value is not None:
        # Use static_cond if pred has a constant value.
        return static_cond(pred_value, fn1, fn2)
    # Keyword `name`: in TF 1.4+ the fourth positional parameter of `cond` is
    # `strict`, so passing `name` positionally would misbind it.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)
Example #19
Source File: tf_image.py From seglink with GNU General Public License v3.0 | 6 votes |
def _assert(cond, ex_type, msg):
    """Polymorphic assert that works with tensors and plain Python values.

    Args:
      cond: Something that evaluates to a boolean value; may be a tensor.
      ex_type: The exception class used for a failing non-tensor `cond`.
      msg: The error message.

    Returns:
      A list containing at most one assert op: one `Assert` op for a tensor
      `cond`, or an empty list for a truthy non-tensor `cond`.

    Raises:
      ex_type: if `cond` is a falsy non-tensor.
    """
    if _is_tensor(cond):
        return [control_flow_ops.Assert(cond, [msg])]
    if cond:
        return []
    raise ex_type(msg)
Example #20
Source File: tf_image.py From seglink with GNU General Public License v3.0 | 6 votes |
def random_flip_left_right(image, bboxes, seed=None):
    """Randomly mirror an image and its bounding boxes left-to-right.

    A single uniform draw decides (with probability 0.5) whether both the
    image and the boxes are flipped.

    Args:
      image: 3-D image tensor (checked by `_Check3DImage`).
      bboxes: Box tensor with 4 columns. It is presumably
        [ymin, xmin, ymax, xmax] in normalized coordinates, since columns 1
        and 3 are mirrored as `1 - x`. TODO confirm.
      seed: Optional seed for the random draw.

    Returns:
      A tuple `(image, bboxes)` with the (possibly) flipped values.
    """
    def mirror_boxes(boxes):
        """Mirror the x-coordinates of `boxes` (columns 1 and 3, swapped)."""
        return tf.stack([boxes[:, 0], 1 - boxes[:, 3],
                         boxes[:, 2], 1 - boxes[:, 1]], axis=-1)

    # Random flip. Tensorflow implementation.
    with tf.name_scope('random_flip_left_right'):
        image = ops.convert_to_tensor(image, name='image')
        _Check3DImage(image, require_static=False)
        coin = random_ops.random_uniform([], 0, 1.0, seed=seed)
        do_flip = math_ops.less(coin, .5)
        flipped_image = control_flow_ops.cond(
            do_flip,
            lambda: array_ops.reverse_v2(image, [1]),
            lambda: image)
        new_boxes = control_flow_ops.cond(
            do_flip,
            lambda: mirror_boxes(bboxes),
            lambda: bboxes)
        return fix_image_flip_shape(image, flipped_image), new_boxes
Example #21
Source File: image_ops_impl.py From auto-alt-text-lambda-api with MIT License | 6 votes |
def _assert(cond, ex_type, msg):
    """Check `cond`, deferring to a TensorFlow `Assert` op when it is a tensor.

    Args:
      cond: Something that evaluates to a boolean value; may be a tensor.
      ex_type: The exception class raised for a falsy non-tensor `cond`.
      msg: The error message.

    Returns:
      `[Assert op]` for a tensor `cond`; `[]` for a truthy non-tensor.

    Raises:
      ex_type: when `cond` is not a tensor and is falsy.
    """
    is_symbolic = _is_tensor(cond)
    if is_symbolic:
        assert_op = control_flow_ops.Assert(cond, [msg])
        return [assert_op]
    if not cond:
        raise ex_type(msg)
    return []
Example #22
Source File: input.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
    """Build one enqueue op per tensor list and register them with a QueueRunner.

    Args:
      queue: Queue whose `enqueue_many` (if `enqueue_many`) or `enqueue`
        method is used.
      tensor_list_list: Iterable of tensor lists, each enqueued separately.
      enqueue_many: Selects `queue.enqueue_many` over `queue.enqueue`.
      keep_input: None to always enqueue. Otherwise a predicate for
        `control_flow_ops.cond`; a `no_op` replaces the enqueue when it is
        false.
    """
    enqueue_fn = queue.enqueue_many if enqueue_many else queue.enqueue
    enqueue_ops = []
    for tensor_list in tensor_list_list:
        if keep_input is None:
            enqueue_ops.append(enqueue_fn(tensor_list))
        else:
            # Bind the current list explicitly rather than via late-binding
            # closure (cond calls the branch immediately either way).
            enqueue_ops.append(control_flow_ops.cond(
                keep_input,
                lambda tl=tensor_list: enqueue_fn(tl),
                control_flow_ops.no_op))
    queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
Example #23
Source File: utils.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def smart_cond(pred, fn1, fn2, name=None):
    """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.

    If `pred` is a bool or has a constant value, only the selected callable is
    invoked. Otherwise `control_flow_ops.cond` routes between both at run time.

    Arguments:
      pred: A scalar determining whether to return the result of `fn1` or `fn2`.
      fn1: The callable to be performed if pred is true.
      fn2: The callable to be performed if pred is false.
      name: Optional name prefix when using `tf.cond`.

    Returns:
      Tensors returned by the call to either `fn1` or `fn2`.

    Raises:
      TypeError: if `fn1` or `fn2` is not callable.
    """
    if not callable(fn1):
        raise TypeError('`fn1` must be callable.')
    if not callable(fn2):
        raise TypeError('`fn2` must be callable.')

    pred_value = constant_value(pred)
    if pred_value is not None:
        return fn1() if pred_value else fn2()
    # Pass `name` by keyword: with the TF 1.4+ signature
    # `cond(pred, true_fn, false_fn, strict=False, name=None, ...)` a
    # positional fourth argument would be bound to `strict`.
    return control_flow_ops.cond(pred, fn1, fn2, name=name)
Example #24
Source File: drop_stale_gradient_optimizer.py From lambda-packs with MIT License | 5 votes |
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients via the wrapped optimizer unless they are too stale.

    Staleness is `global_step - self._local_step`, computed after every
    gradient is available. Updates are applied only when the staleness is
    `<= self._staleness`.

    Args:
      grads_and_vars: (gradient, variable) pairs, forwarded unchanged to
        `self._opt.apply_gradients`.
      global_step: Step tensor/variable used for staleness, colocation and
        the summary.
      name: Forwarded to `self._opt.apply_gradients`.

    Returns:
      The result of `stale_counter.assign_add(...)`. The counter is
      incremented by 0.0 when the update is applied and by 1.0 when it is
      dropped.
    """
    # Running count of dropped (stale) gradient updates.
    stale_counter = variable_scope.get_variable(
        "stale_counter", [],
        initializer=init_ops.zeros_initializer(),
        trainable=False)

    def _apply_and_report_fresh():
        inner_update = self._opt.apply_gradients(
            grads_and_vars, global_step=global_step, name=name)
        with ops.control_dependencies([inner_update]):
            return gen_array_ops.identity(0.0)

    def _report_stale():
        return gen_array_ops.identity(1.0)

    # Dependencies: dense gradients directly, other non-None gradients by op.
    grad_deps = []
    for pair in grads_and_vars:
        grad = pair[0]
        if grad is None:
            continue
        grad_deps.append(grad if isinstance(grad, ops.Tensor) else grad.op)

    with ops.control_dependencies(grad_deps), ops.colocate_with(global_step):
        staleness = gen_array_ops.reshape(
            global_step - self._local_step, shape=())

    is_fresh = gen_math_ops.less_equal(staleness, self._staleness)
    conditional_update = stale_counter.assign_add(
        control_flow_ops.cond(is_fresh, _apply_and_report_fresh, _report_stale))

    # NOTE(review): despite its name, this summary logs a fraction
    # (dropped updates / (global_step + 1)), not a percentage.
    summary.scalar(
        "Gradient staleness percentage",
        stale_counter / (math_ops.cast(global_step + 1, dtypes.float32)))
    return conditional_update
Example #25
Source File: check_ops.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _get_diff_for_monotonic_comparison(x):
    """Gets the difference x[1:] - x[:-1] of the flattened `x`.

    Returns an empty tensor (of `x`'s dtype) when `x` has fewer than two
    elements.

    Raises:
      TypeError: if `x` is not a numeric tensor.
    """
    x = array_ops.reshape(x, [-1])
    if not is_numeric_tensor(x):
        raise TypeError('Expected x to be numeric, instead found: %s' % x)

    # Fewer than 2 elements: nothing to compare, so the result is [].
    is_shorter_than_two = math_ops.less(array_ops.size(x), 2)

    def _empty_result():
        return ops.convert_to_tensor([], dtype=x.dtype)

    # With 2 or more elements: x[1:] - x[:-1].
    s_len = array_ops.shape(x) - 1

    def _pairwise_diff():
        tail = array_ops.strided_slice(x, [1], [1] + s_len)
        head = array_ops.strided_slice(x, [0], s_len)
        return tail - head

    return control_flow_ops.cond(
        is_shorter_than_two, _empty_result, _pairwise_diff)
Example #26
Source File: operator_pd.py From lambda-packs with MIT License | 5 votes |
def _dispatch_based_on_batch(self, batch_method, singleton_method, **args):
    """Call `batch_method` for rank > 2 operators, else `singleton_method`.

    With a statically known rank the choice is made in Python. Otherwise a
    `control_flow_ops.cond` on `self.rank() > 2` decides at run time.
    `**args` are forwarded to whichever method is called.
    """
    static_ndims = self.get_shape().ndims
    if static_ndims is None:
        return control_flow_ops.cond(
            self.rank() > 2,
            lambda: batch_method(**args),
            lambda: singleton_method(**args))
    chosen = batch_method if static_ndims > 2 else singleton_method
    return chosen(**args)
Example #27
Source File: tensor_forest.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _get_loss(self, features, labels):
    """Construct (once), cache, and return the inference-based loss.

    The loss is memoized in `self._loss`. Once set, later calls return it
    unchanged and do not look at `features` or `labels`.
    """
    cached = self._loss
    if cached is not None:
        return cached

    def _mean_loss():
        probs = self.inference_graph(features)
        total = math_ops.reduce_sum(self.loss_fn(probs, labels))
        return total / math_ops.to_float(array_ops.shape(labels)[0])

    # When `self.average_size()` is not positive, use `sys.maxsize` (as
    # float32) instead of evaluating the inference graph.
    self._loss = control_flow_ops.cond(
        self.average_size() > 0,
        _mean_loss,
        lambda: constant_op.constant(sys.maxsize, dtype=dtypes.float32))
    return self._loss
Example #28
Source File: operator_pd.py From auto-alt-text-lambda-api with MIT License | 5 votes |
def _dispatch_based_on_batch(self, batch_method, singleton_method, **args):
    """Helper that routes to the batch or singleton version of an operation.

    `batch_method(**args)` is used when the operator has more than 2
    dimensions, and `singleton_method(**args)` otherwise. If the static rank
    is unknown, the routing happens at run time via `control_flow_ops.cond`.
    """
    ndims = self.get_shape().ndims
    if ndims is not None:
        if ndims > 2:
            return batch_method(**args)
        return singleton_method(**args)

    def _run_batch():
        return batch_method(**args)

    def _run_singleton():
        return singleton_method(**args)

    return control_flow_ops.cond(self.rank() > 2, _run_batch, _run_singleton)
Example #29
Source File: tpu_estimator.py From Chinese-XLNet with Apache License 2.0 | 5 votes |
def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slice the real (non-padded) rows out of a tensor or a structure of tensors.

    Args:
      tensor_or_dict: A `Tensor` or nested structure of tensors; every leaf
        is processed via `nest.map_structure`.
      signals: Dict providing 'padding_mask'. It is also passed to
        `_StopSignals.as_scalar_stopping_signal`.

    Returns:
      A structure like `tensor_or_dict`. Each leaf keeps its first
      `batch_size - sum(padding_mask)` rows, or is passed through whole for a
      full batch or a stopping-signal batch.
    """
    mask = signals['padding_mask']
    full_batch_size = array_ops.shape(mask)[0]

    def _with_batch_check(tensor):
        # NOTE(review): `math_ops.equal` only computes a boolean. As a control
        # dependency it is evaluated first but never fails on a mismatch; a
        # real check would need an assert op.
        batch_matches = math_ops.equal(full_batch_size, tensor.shape[0])
        with ops.control_dependencies([batch_matches]):
            return array_ops.identity(tensor)

    def _keep_real_rows(tensor):
        assert len(tensor.shape) > 0
        num_real = full_batch_size - math_ops.reduce_sum(mask)
        return _with_batch_check(tensor)[0:num_real]

    # Tensors are split across TPU cores and concatenated back, so real data
    # must precede padding (order is preserved). The sliced mask should then
    # be all 0's; otherwise the slicing logic here would not hold.
    # NOTE(review): like the check above, this equality is not enforced.
    sliced_mask = _keep_real_rows(mask)
    mask_is_clean = math_ops.equal(math_ops.reduce_sum(sliced_mask), 0)
    with ops.control_dependencies([mask_is_clean]):
        should_stop = _StopSignals.should_stop(
            _StopSignals.as_scalar_stopping_signal(signals))
    is_full_batch = math_ops.equal(math_ops.reduce_sum(mask), 0)

    def _slice_leaf(tensor):
        # Full batches and stopping-signal batches need no slicing, which
        # saves work.
        return control_flow_ops.cond(
            math_ops.logical_or(should_stop, is_full_batch),
            lambda: _with_batch_check(tensor),
            lambda: _keep_real_rows(tensor))

    return nest.map_structure(_slice_leaf, tensor_or_dict)
Example #30
Source File: tf_example_decoder.py From Person-Detection-and-Tracking with MIT License | 5 votes |
def _decode_png_instance_masks(self, keys_to_tensors):
    """Decode PNG instance segmentation masks and stack them into a dense tensor.

    Args:
      keys_to_tensors: Dict providing 'image/object/mask' (PNG buffers,
        dense or `tf.SparseTensor`), 'image/height' and 'image/width'.

    Returns:
      A float32 tensor of shape [num_instances, height, width] with values in
      {0, 1}. When there are no masks, it is an all-zero tensor of shape
      [0, height, width].
    """
    def _decode_one(png_bytes):
        mask = tf.squeeze(
            tf.image.decode_image(png_bytes, channels=1), axis=2)
        mask.set_shape([None, None])
        # Any nonzero pixel counts as foreground.
        return tf.to_float(tf.greater(mask, 0))

    encoded = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(encoded, tf.SparseTensor):
        encoded = tf.sparse_tensor_to_dense(encoded, default_value='')

    has_masks = tf.greater(tf.size(encoded), 0)
    return tf.cond(
        has_masks,
        lambda: tf.map_fn(_decode_one, encoded, dtype=tf.float32),
        lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))