Python tensorflow.python.ops.control_flow_ops.cond() Examples

The following are 30 code examples of tensorflow.python.ops.control_flow_ops.cond(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.ops.control_flow_ops, or try the search function.
Example #1
Source File: utils.py    From lambda-packs with MIT License 6 votes vote down vote up
def smart_cond(pred, fn1, fn2, name=None):
  """Pick between `fn1()` and `fn2()` according to the predicate `pred`.

  The predicate is first resolved with `constant_value`. When that yields a
  value, the choice is delegated to `static_cond`; when it yields `None`, the
  choice is deferred to `control_flow_ops.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix; only forwarded to `control_flow_ops.cond`.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  resolved = constant_value(pred)
  if resolved is None:
    # No constant value could be extracted: build a dynamic conditional.
    return control_flow_ops.cond(pred, fn1, fn2, name)
  # A concrete value is available: let static_cond choose the branch.
  return static_cond(resolved, fn1, fn2)
Example #2
Source File: tensor_forest.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def average_impurity(self):
    """Builds the ops that evaluate the average leaf impurity of a tree.

    In regression mode this is the leaf variance; in classification mode it
    is the gini impurity.

    Returns:
      The output of a `cond`: the weighted gini of the leaf counts when at
      least one leaf index was found, otherwise a float32 tensor shaped like
      that gini value and filled with 10000000.
    """
    # First column of the tree variable, flattened to a vector.
    first_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
    children = array_ops.squeeze(first_column, squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaf_positions = array_ops.squeeze(array_ops.where(is_leaf),
                                       squeeze_dims=[1])
    leaves = math_ops.to_int32(leaf_positions)
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)

    def no_data_penalty():
      # Average impurity can serve as a loss; with no data yet, report a
      # large number so that the loss always decreases afterwards.
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.

    # Guard against step 1, when there often are no leaves yet.
    has_leaves = math_ops.greater(array_ops.shape(leaves)[0], 0)
    return control_flow_ops.cond(has_leaves, lambda: gini, no_data_penalty)
Example #3
Source File: topn.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def insert(self, ids, scores):
    """Insert the ids and scores into the TopN.

    All ops are created under a control dependency on `self.last_ops`, and
    `self.last_ops` is replaced by the two update ops built here.

    Args:
      ids: Ids to record; scattered into `self.id_to_score` with `scores`.
      scores: Scores associated with `ids`.
    """
    with ops.control_dependencies(self.last_ops):
      score_update = state_ops.scatter_update(self.id_to_score, ids, scores)
      beats_threshold = math_ops.greater(scores, self.sl_scores[0])

      def update_shortlist():
        """Pushes the above-threshold entries into the shortlist variables."""
        candidate_ids = array_ops.boolean_mask(
            math_ops.to_int64(ids), beats_threshold)
        candidate_scores = array_ops.boolean_mask(scores, beats_threshold)
        slots, slot_ids, slot_scores = tensor_forest_ops.top_n_insert(
            self.sl_ids, self.sl_scores, candidate_ids, candidate_scores)
        ids_update = state_ops.scatter_update(self.sl_ids, slots, slot_ids)
        scores_update = state_ops.scatter_update(
            self.sl_scores, slots, slot_scores)
        return control_flow_ops.group(ids_update, scores_update)

      # The shortlist only needs touching when some score is larger than
      # the threshold; otherwise the conditional resolves to a no-op.
      any_candidates = math_ops.reduce_any(beats_threshold)
      maybe_update = control_flow_ops.cond(
          any_candidates, update_shortlist, control_flow_ops.no_op)
      with ops.control_dependencies([maybe_update]):
        self.last_ops = [score_update, maybe_update]
Example #4
Source File: metric_ops.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, yielding 0 when the denominator equals 0.

  Both inputs have their static shape checked with `with_rank_at_most(1)`
  before the conditional is built.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  for operand in (numerator, denominator):
    operand.get_shape().with_rank_at_most(1)

  denominator_is_zero = math_ops.equal(
      array_ops.constant(0.0, dtype=dtypes.float64), denominator)

  def _zero():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _quotient():
    return math_ops.div(numerator, denominator)

  return control_flow_ops.cond(
      denominator_is_zero, _zero, _quotient, name=name)
Example #5
Source File: utils.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def static_cond(pred, fn1, fn2):
  """Call `fn1` when `pred` is truthy, otherwise call `fn2`.

  Takes the same arguments as `control_flow_ops.cond()`, except that the
  predicate is evaluated as a plain Python boolean.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    Whatever the selected callable returns.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  # Validate both branches up front, `fn1` first, even though only one runs.
  for label, branch in (('fn1', fn1), ('fn2', fn2)):
    if not callable(branch):
      raise TypeError('%s must be callable.' % label)
  chosen = fn1 if pred else fn2
  return chosen()
Example #6
Source File: utils.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def smart_cond(pred, fn1, fn2, name=None):
  """Return `fn1()` or `fn2()` depending on the predicate `pred`.

  `constant_value(pred)` is consulted first. A non-`None` result is handed
  to `static_cond`; a `None` result falls back to `control_flow_ops.cond`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix, used only by `control_flow_ops.cond`.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  known_pred = constant_value(pred)
  has_constant = known_pred is not None
  if has_constant:
    result = static_cond(known_pred, fn1, fn2)
  else:
    result = control_flow_ops.cond(pred, fn1, fn2, name)
  return result
Example #7
Source File: bernoulli.py    From lambda-packs with MIT License 6 votes vote down vote up
def _log_prob(self, event):
    """Returns the negated sigmoid cross entropy of `event` under the logits.

    `event` is validated with `_maybe_assert_valid_sample`, cast to the dtype
    of `self.logits`, and both operands are brought to a common shape before
    the cross entropy is evaluated.

    Args:
      event: The sample(s) to score.

    Returns:
      `-nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)`.
    """
    event = self._maybe_assert_valid_sample(event)
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent  behavior for logits = inf/-inf.
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    def _expand_both(logits, event):
      # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so each
      # operand is multiplied by ones shaped like the other one.
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    shapes_known = (event.get_shape().is_fully_defined() and
                    logits.get_shape().is_fully_defined())
    if not shapes_known:
      # Static shapes are incomplete: decide inside the graph.
      logits, event = control_flow_ops.cond(
          distribution_util.same_dynamic_shape(logits, event),
          lambda: (logits, event),
          lambda: _expand_both(logits, event))
    elif event.get_shape() != logits.get_shape():
      # Static shapes are complete and differ: expand right away.
      logits, event = _expand_both(logits, event)
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example #8
Source File: tensor_forest.py    From lambda-packs with MIT License 6 votes vote down vote up
def average_impurity(self):
    """Constructs the graph ops for the average leaf impurity of a tree.

    In regression mode this is the leaf variance; in classification mode it
    is the gini impurity.

    Returns:
      The `cond` output: the weighted gini of the gathered leaf counts if
      any leaf index exists, else a float32 tensor with the shape of that
      gini value whose entries are all 10000000.
    """
    tree_column = array_ops.slice(self.variables.tree, [0, 0], [-1, 1])
    children = array_ops.squeeze(tree_column, squeeze_dims=[1])
    leaf_mask = math_ops.equal(constants.LEAF_NODE, children)
    leaf_rows = array_ops.squeeze(array_ops.where(leaf_mask),
                                  squeeze_dims=[1])
    leaves = math_ops.to_int32(leaf_rows)
    leaf_counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(leaf_counts)

    def _impurity():
      return gini

    def _large_placeholder():
      # Used when there is no data: a big number, so that a loss based on
      # this value always decreases once real impurities arrive.
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.

    # Guard against step 1, when there often are no leaves yet.
    num_leaves = array_ops.shape(leaves)[0]
    return control_flow_ops.cond(
        math_ops.greater(num_leaves, 0), _impurity, _large_placeholder)
Example #9
Source File: bernoulli.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _log_prob(self, event):
    """Returns the negated sigmoid cross entropy of `event` under the logits.

    `event` is converted to a tensor, cast to the dtype of `self.logits`,
    and both operands are brought to a common shape before the cross entropy
    is evaluated.

    Args:
      event: The sample(s) to score.

    Returns:
      `-nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)`.
    """
    # TODO(jaana): The current sigmoid_cross_entropy_with_logits has
    # inconsistent  behavior for logits = inf/-inf.
    event = ops.convert_to_tensor(event, name="event")
    event = math_ops.cast(event, self.logits.dtype)
    logits = self.logits

    def _match_shapes(logits, event):
      # sigmoid_cross_entropy_with_logits doesn't broadcast shape, so each
      # operand is scaled by ones shaped like the other one.
      return (array_ops.ones_like(event) * logits,
              array_ops.ones_like(logits) * event)

    static_shapes_complete = (event.get_shape().is_fully_defined() and
                              logits.get_shape().is_fully_defined())
    if static_shapes_complete:
      # First check static shape.
      if event.get_shape() != logits.get_shape():
        logits, event = _match_shapes(logits, event)
    else:
      shapes_agree = distribution_util.same_dynamic_shape(logits, event)
      logits, event = control_flow_ops.cond(
          shapes_agree,
          lambda: (logits, event),
          lambda: _match_shapes(logits, event))
    return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits)
Example #10
Source File: image_ops_impl.py    From lambda-packs with MIT License 6 votes vote down vote up
def _assert(cond, ex_type, msg):
  """Asserts `cond`, whether it is a tensor or a plain Python value.

  For a tensor, a one-element list holding a `control_flow_ops.Assert` op is
  returned. For anything else the check happens immediately: a falsy `cond`
  raises `ex_type(msg)`, a truthy one returns an empty list.

  Args:
    cond: Something evaluates to a boolean value. May be a tensor.
    ex_type: The exception class to use.
    msg: The error message.

  Returns:
    A list, containing at most one assert op.
  """
  if _is_tensor(cond):
    return [control_flow_ops.Assert(cond, [msg])]
  if cond:
    return []
  raise ex_type(msg)
Example #11
Source File: variables.py    From lambda-packs with MIT License 6 votes vote down vote up
def initialized_value(self):
    """Returns the value of the initialized variable.

    Use this rather than the variable itself when initializing another
    variable whose value depends on this one.

    ```python
    # Initialize 'v' with a random tensor.
    v = tf.Variable(tf.truncated_normal([10, 40]))
    # Use `initialized_value` to guarantee that `v` has been
    # initialized before its value is used to initialize `w`.
    # The random values are picked only once.
    w = tf.Variable(v.initialized_value() * 2.0)
    ```

    Returns:
      A `Tensor` holding the value of this variable after its initializer
      has run.
    """
    # The conditional is created with control dependencies reset to None.
    with ops.control_dependencies(None):
      value = control_flow_ops.cond(
          is_variable_initialized(self),
          self.read_value,
          lambda: self.initial_value)
    return value
Example #12
Source File: metrics_impl.py    From lambda-packs with MIT License 6 votes vote down vote up
def _safe_scalar_div(numerator, denominator, name):
  """Computes `numerator / denominator`, or 0 for a zero denominator.

  The static shapes of both inputs are first checked with
  `with_rank_at_most(1)`.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)

  zero_for_compare = array_ops.constant(0.0, dtype=dtypes.float64)
  is_zero_denominator = math_ops.equal(zero_for_compare, denominator)

  def _zero_result():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _divide():
    return math_ops.div(numerator, denominator)

  return control_flow_ops.cond(
      is_zero_denominator, _zero_result, _divide, name=name)
Example #13
Source File: session_debug_testlib.py    From lambda-packs with MIT License 6 votes vote down vote up
def testDebugCondWatchingWholeGraphWorks(self):
    """Runs a `cond` with the whole graph watched and checks the dump.

    The cond compares variables x=10.0 and y=20.0 and adds 1 to the selected
    one; the run must return 21 and the dumped "cond/Merge" DebugIdentity
    tensor must be close to [21.0].
    """
    with session.Session() as sess:
      var_x = variables.Variable(10.0, name="x")
      var_y = variables.Variable(20.0, name="y")
      x_is_larger = var_x > var_y
      cond_output = control_flow_ops.cond(
          x_is_larger,
          lambda: math_ops.add(var_x, 1),
          lambda: math_ops.add(var_y, 1))

      sess.run(variables.global_variables_initializer())

      # Watch every node of the session graph and request partition graphs.
      run_options = config_pb2.RunOptions(output_partition_graphs=True)
      debug_utils.watch_graph(
          run_options, sess.graph, debug_urls=self._debug_urls())
      run_metadata = config_pb2.RunMetadata()
      fetched = sess.run(
          cond_output, options=run_options, run_metadata=run_metadata)
      self.assertEqual(21, fetched)

      dump_dir = debug_data.DebugDumpDir(
          self._dump_root, partition_graphs=run_metadata.partition_graphs)
      merged = dump_dir.get_tensors("cond/Merge", 0, "DebugIdentity")
      self.assertAllClose([21.0], merged)
Example #14
Source File: utils.py    From lambda-packs with MIT License 6 votes vote down vote up
def static_cond(pred, fn1, fn2):
  """Evaluate exactly one of `fn1` / `fn2`, chosen by the truth of `pred`.

  Mirrors the argument list of `control_flow_ops.cond()`, but `pred` is
  tested with an ordinary Python `if`.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    The return value of the callable that was invoked.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  # Both callables are checked before either one is invoked.
  if not callable(fn1):
    raise TypeError('fn1 must be callable.')
  if not callable(fn2):
    raise TypeError('fn2 must be callable.')
  return fn1() if pred else fn2()
Example #15
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 6 votes vote down vote up
def _wrap_computation_in_while_loop_with_stopping_signals(device, op_fn):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  Args:
    device: Device under which the while loop is created.
    op_fn: Callable returning a mapping with an 'ops' entry (used as control
      dependencies) and a 'signals' entry (turned into the next loop value).

  Returns:
    The result of `control_flow_ops.while_loop`.
  """

  def keep_running(scalar_stopping_signal):
    stop_requested = _StopSignals.should_stop(scalar_stopping_signal)
    return math_ops.logical_not(stop_requested)

  def run_once(unused_scalar_stopping_signal):
    produced = op_fn()
    pending_ops = produced['ops']
    stop_signals = produced['signals']
    # The next loop value is only produced after the ops have run.
    with ops.control_dependencies(pending_ops):
      return _StopSignals.as_scalar_stopping_signal(stop_signals)

  initial_loop_vars = [_StopSignals.NON_STOPPING_SIGNAL]
  with ops.device(device):
    # parallel_iterations=1 basically turns off the parallel execution in
    # while_loop.
    return control_flow_ops.while_loop(
        keep_running, run_once, initial_loop_vars, parallel_iterations=1)
Example #16
Source File: metrics_impl.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _safe_scalar_div(numerator, denominator, name):
  """Zero-safe division of two values.

  Each input's static shape is passed through `with_rank_at_most(1)` before
  the conditional division is built.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator_shape = numerator.get_shape()
  numerator_shape.with_rank_at_most(1)
  denominator_shape = denominator.get_shape()
  denominator_shape.with_rank_at_most(1)

  def _when_zero():
    return array_ops.constant(0.0, dtype=dtypes.float64)

  def _when_nonzero():
    return math_ops.div(numerator, denominator)

  zero_denominator = math_ops.equal(
      array_ops.constant(0.0, dtype=dtypes.float64), denominator)
  return control_flow_ops.cond(
      zero_denominator, _when_zero, _when_nonzero, name=name)
Example #17
Source File: utils.py    From tensornets with MIT License 6 votes vote down vote up
def static_cond(pred, fn1, fn2):
  """Select and run `fn1` or `fn2` using the Python truth value of `pred`.

  The parameters line up with `control_flow_ops.cond()`; the difference is
  that `pred` is consumed by a regular `if` statement.

  Args:
    pred: A value determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.

  Returns:
    The value produced by the selected callable.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  candidates = (('fn1', fn1), ('fn2', fn2))
  # Report the first non-callable argument, in declaration order.
  invalid = [label for label, fn in candidates if not callable(fn)]
  if invalid:
    raise TypeError('%s must be callable.' % invalid[0])
  if pred:
    return fn1()
  return fn2()
Example #18
Source File: utils.py    From tensornets with MIT License 6 votes vote down vote up
def smart_cond(pred, fn1, fn2, name=None):
  """Choose `fn1()` or `fn2()` based on the boolean predicate/value `pred`.

  If `constant_value(pred)` produces a value, that value drives
  `static_cond`; if it produces `None`, `control_flow_ops.cond` is used with
  the original `pred`.

  Args:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix; passed only to `control_flow_ops.cond`.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.
  """
  constant_pred = constant_value(pred)
  if constant_pred is None:
    # Dynamic case: no constant could be derived from the predicate.
    return control_flow_ops.cond(pred, fn1, fn2, name)
  # Static case: dispatch on the derived constant.
  return static_cond(constant_pred, fn1, fn2)
Example #19
Source File: tf_image.py    From seglink with GNU General Public License v3.0 6 votes vote down vote up
def _assert(cond, ex_type, msg):
    """Polymorphic assert that accepts tensors and boolean expressions.

    A tensor `cond` yields a list with a single `control_flow_ops.Assert`
    op. A non-tensor `cond` is checked on the spot: falsy values raise
    `ex_type(msg)`, truthy values give an empty list.

    Args:
      cond: Something evaluates to a boolean value. May be a tensor.
      ex_type: The exception class to use.
      msg: The error message.

    Returns:
      A list, containing at most one assert op.
    """
    if not _is_tensor(cond):
        # Plain Python value: behave like an ordinary assert statement.
        if not cond:
            raise ex_type(msg)
        return []
    return [control_flow_ops.Assert(cond, [msg])]
Example #20
Source File: tf_image.py    From seglink with GNU General Public License v3.0 6 votes vote down vote up
def random_flip_left_right(image, bboxes, seed=None):
    """Randomly mirrors an image left-right together with its bounding boxes.

    A single uniform draw in [0, 1) is compared against .5; the same outcome
    drives both the image flip and the box flip.

    Args:
      image: Image to flip; converted to a tensor and passed through
        `_Check3DImage`.
      bboxes: Boxes indexed as `bboxes[:, k]` for k in 0..3. Presumably
        normalized [ymin, xmin, ymax, xmax] -- TODO confirm with callers.
      seed: Optional seed forwarded to `random_ops.random_uniform`.

    Returns:
      A pair `(image, bboxes)` with the possibly flipped image (after
      `fix_image_flip_shape`) and the possibly flipped boxes.
    """
    def mirrored(boxes):
        """Swaps columns 1 and 3 and replaces each with one minus itself."""
        return tf.stack([boxes[:, 0], 1 - boxes[:, 3],
                         boxes[:, 2], 1 - boxes[:, 1]], axis=-1)

    with tf.name_scope('random_flip_left_right'):
        image = ops.convert_to_tensor(image, name='image')
        _Check3DImage(image, require_static=False)
        draw = random_ops.random_uniform([], 0, 1.0, seed=seed)
        should_flip = math_ops.less(draw, .5)
        # Image: reversed along axis 1 when flipping.
        flipped_image = control_flow_ops.cond(
            should_flip,
            lambda: array_ops.reverse_v2(image, [1]),
            lambda: image)
        # Boxes: mirrored under the same predicate.
        flipped_boxes = control_flow_ops.cond(
            should_flip,
            lambda: mirrored(bboxes),
            lambda: bboxes)
        return fix_image_flip_shape(image, flipped_image), flipped_boxes
Example #21
Source File: image_ops_impl.py    From auto-alt-text-lambda-api with MIT License 6 votes vote down vote up
def _assert(cond, ex_type, msg):
  """Assert helper that works for tensors and boolean expressions alike.

  When `cond` is a tensor, the result is a list with one
  `control_flow_ops.Assert` op. Otherwise the condition is tested right
  away: `ex_type(msg)` is raised if it is falsy, and an empty list is
  returned if it is truthy.

  Args:
    cond: Something evaluates to a boolean value. May be a tensor.
    ex_type: The exception class to use.
    msg: The error message.

  Returns:
    A list, containing at most one assert op.
  """
  cond_is_tensor = _is_tensor(cond)
  if cond_is_tensor:
    assert_op = control_flow_ops.Assert(cond, [msg])
    return [assert_op]
  if not cond:
    raise ex_type(msg)
  return []
Example #22
Source File: input.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
  """Enqueue `tensor_list_list` in `queue`.

  One enqueue op is built per element of `tensor_list_list`, and a
  `QueueRunner` over those ops is registered via
  `queue_runner.add_queue_runner`.

  Args:
    queue: The queue to enqueue into.
    tensor_list_list: Iterable of tensor lists, one enqueue op each.
    enqueue_many: If truthy, `queue.enqueue_many` is used instead of
      `queue.enqueue`.
    keep_input: If `None`, every element is enqueued unconditionally;
      otherwise it is the predicate of a `cond` whose other branch is a
      no-op.
  """
  enqueue_fn = queue.enqueue_many if enqueue_many else queue.enqueue

  def _enqueue_one(tensor_list):
    if keep_input is None:
      return enqueue_fn(tensor_list)
    return control_flow_ops.cond(
        keep_input,
        lambda: enqueue_fn(tensor_list),
        control_flow_ops.no_op)

  enqueue_ops = [_enqueue_one(tl) for tl in tensor_list_list]
  runner = queue_runner.QueueRunner(queue, enqueue_ops)
  queue_runner.add_queue_runner(runner)
Example #23
Source File: utils.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def smart_cond(pred, fn1, fn2, name=None):
  """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`.

  Both callables are validated first. Then `constant_value(pred)` is
  consulted: a non-`None` result selects `fn1()` or `fn2()` directly, while
  `None` defers the choice to `control_flow_ops.cond`.

  Arguments:
    pred: A scalar determining whether to return the result of `fn1` or `fn2`.
    fn1: The callable to be performed if pred is true.
    fn2: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `fn1` or `fn2`.

  Raises:
    TypeError: if `fn1` or `fn2` is not callable.
  """
  for label, branch in (('fn1', fn1), ('fn2', fn2)):
    if not callable(branch):
      raise TypeError('`%s` must be callable.' % label)

  static_pred = constant_value(pred)
  if static_pred is None:
    return control_flow_ops.cond(pred, fn1, fn2, name)
  return fn1() if static_pred else fn2()
Example #24
Source File: drop_stale_gradient_optimizer.py    From lambda-packs with MIT License 5 votes vote down vote up
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Applies the gradients unless they are staler than allowed.

    Staleness is `global_step - self._local_step`. When it is at most
    `self._staleness`, the wrapped optimizer `self._opt` applies the
    gradients and 0.0 is added to the "stale_counter" variable; otherwise
    the gradients are dropped and 1.0 is added.

    Args:
      grads_and_vars: Iterable of (gradient, variable) pairs. It is
        materialized once, so one-shot iterables such as `zip(...)` are
        accepted.
      global_step: Step variable used for the staleness computation and
        forwarded to the wrapped optimizer. NOTE(review): despite the `None`
        default it is used in arithmetic below, so callers presumably must
        supply it -- confirm.
      name: Optional name forwarded to the wrapped optimizer.

    Returns:
      The `assign_add` op on the stale counter, which wraps the conditional
      apply/drop.
    """
    # The pairs are walked below to collect dependencies and only afterwards
    # handed to the wrapped optimizer; a one-shot iterable would already be
    # exhausted by then, so take a snapshot first.
    grads_and_vars = tuple(grads_and_vars)

    gradients = []
    # Number of stale gradients.
    stale_counter = variable_scope.get_variable(
        "stale_counter", [],
        initializer=init_ops.zeros_initializer(),
        trainable=False)

    def _AcceptGradientOp():
      # Contributes 0.0 to the counter, after the gradients are applied.
      with ops.control_dependencies(
          [self._opt.apply_gradients(
              grads_and_vars, global_step=global_step, name=name)]):
        return gen_array_ops.identity(0.0)

    def _DropGradientOp():
      # Contributes 1.0 to the counter; nothing is applied.
      return gen_array_ops.identity(1.0)

    for grad_and_var in grads_and_vars:
      grad = grad_and_var[0]
      if isinstance(grad, ops.Tensor):
        gradients.append(grad)
      elif grad is not None:
        # Non-Tensor gradients are tracked through their producing op.
        gradients.append(grad.op)

    # Staleness is computed only after all gradients are available.
    with ops.control_dependencies(gradients), ops.colocate_with(global_step):
      staleness = gen_array_ops.reshape(
          global_step - self._local_step, shape=())

    conditional_update = stale_counter.assign_add(control_flow_ops.cond(
        gen_math_ops.less_equal(staleness, self._staleness),
        _AcceptGradientOp, _DropGradientOp))

    summary.scalar(
        "Gradient staleness percentage",
        stale_counter / (math_ops.cast(global_step + 1, dtypes.float32)))
    return conditional_update
Example #25
Source File: check_ops.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _get_diff_for_monotonic_comparison(x):
  """Gets the difference x[1:] - x[:-1] of `x` flattened to a vector.

  Args:
    x: Numeric tensor; it is reshaped to rank 1 before differencing.

  Returns:
    The adjacent differences, or an empty tensor of `x`'s dtype when `x` has
    fewer than 2 elements.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  x = array_ops.reshape(x, [-1])
  if not is_numeric_tensor(x):
    raise TypeError('Expected x to be numeric, instead found: %s' % x)

  # With fewer than 2 elements there is nothing to compare.
  too_short = math_ops.less(array_ops.size(x), 2)

  def _empty():
    return ops.convert_to_tensor([], dtype=x.dtype)

  last_start = array_ops.shape(x) - 1

  def _adjacent_diff():
    # x[1:] minus x[:-1], expressed through strided slices.
    upper = array_ops.strided_slice(x, [1], [1] + last_start)
    lower = array_ops.strided_slice(x, [0], last_start)
    return upper - lower

  return control_flow_ops.cond(too_short, _empty, _adjacent_diff)
Example #26
Source File: operator_pd.py    From lambda-packs with MIT License 5 votes vote down vote up
def _dispatch_based_on_batch(self, batch_method, singleton_method, **args):
    """Calls the batch or the singleton operation, whichever applies.

    With a known static rank the choice is made immediately (rank > 2 means
    batch). With an unknown static rank the choice is made by a `cond` on
    `self.rank() > 2`.

    Args:
      batch_method: Callable used for the batch case.
      singleton_method: Callable used otherwise.
      **args: Keyword arguments forwarded to the selected callable.

    Returns:
      The selected callable's result, or the `cond` output.
    """
    if self.get_shape().ndims is None:
      # Static rank unknown: defer the decision to a conditional.
      return control_flow_ops.cond(
          self.rank() > 2,
          lambda: batch_method(**args),
          lambda: singleton_method(**args)
      )
    if self.get_shape().ndims > 2:
      return batch_method(**args)
    return singleton_method(**args)
Example #27
Source File: tensor_forest.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _get_loss(self, features, labels):
    """Constructs, caches, and returns the inference-based loss.

    The loss is built once; later calls return the value stored in
    `self._loss` and ignore their arguments.

    Args:
      features: Input passed to `self.inference_graph`.
      labels: Labels passed to `self.loss_fn`; the size of their first
        dimension is the divisor of the summed loss.

    Returns:
      The cached loss. It is a `cond` yielding the averaged loss when
      `self.average_size() > 0`, else a float32 constant of `sys.maxsize`.
    """
    if self._loss is None:
      def _mean_loss():
        probs = self.inference_graph(features)
        total = math_ops.reduce_sum(self.loss_fn(probs, labels))
        num_examples = math_ops.to_float(array_ops.shape(labels)[0])
        return total / num_examples

      def _max_loss():
        return constant_op.constant(sys.maxsize, dtype=dtypes.float32)

      self._loss = control_flow_ops.cond(
          self.average_size() > 0, _mean_loss, _max_loss)

    return self._loss
Example #28
Source File: operator_pd.py    From auto-alt-text-lambda-api with MIT License 5 votes vote down vote up
def _dispatch_based_on_batch(self, batch_method, singleton_method, **args):
    """Helper that routes to the batch or the singleton operation.

    If the static rank is known, rank > 2 selects `batch_method` and
    anything else selects `singleton_method`. If it is unknown, both are
    wrapped in a `cond` keyed on `self.rank() > 2`.

    Args:
      batch_method: Callable used for the batch case.
      singleton_method: Callable used otherwise.
      **args: Keyword arguments forwarded to the selected callable.

    Returns:
      The selected callable's result, or the `cond` output.
    """
    rank_is_static = self.get_shape().ndims is not None
    if rank_is_static:
      chosen = (
          batch_method if self.get_shape().ndims > 2 else singleton_method)
      return chosen(**args)

    def _run_batch():
      return batch_method(**args)

    def _run_singleton():
      return singleton_method(**args)

    is_batch = self.rank() > 2
    return control_flow_ops.cond(is_batch, _run_batch, _run_singleton)
Example #29
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 5 votes vote down vote up
def slice_tensor_or_dict(tensor_or_dict, signals):
    """Slice the real Tensors according to padding mask in signals.

    Args:
      tensor_or_dict: A structure of tensors; `slice_fn` is mapped over it
        with `nest.map_structure`.
      signals: Mapping that provides 'padding_mask' and is also passed to
        `_StopSignals.as_scalar_stopping_signal`.

    Returns:
      A structure like `tensor_or_dict` in which each tensor is either passed
      through unchanged (full batch or stopping signal) or cut down to its
      first `real_batch_size` rows.
    """

    padding_mask = signals['padding_mask']
    # Dynamic size of the first dimension of the padding mask.
    batch_size = array_ops.shape(padding_mask)[0]

    def verify_batch_size(tensor):
      # NOTE(review): `math_ops.equal` only computes a boolean tensor; it
      # does not look like it raises on a mismatch. Confirm whether an
      # explicit Assert op was intended here.
      check_batch_size = math_ops.equal(batch_size, tensor.shape[0])
      with ops.control_dependencies([check_batch_size]):
        return array_ops.identity(tensor)

    def slice_single_tensor(tensor):
      rank = len(tensor.shape)
      assert rank > 0
      # Rows that are not padding: batch size minus the sum of the mask.
      real_batch_size = batch_size - math_ops.reduce_sum(padding_mask)
      return verify_batch_size(tensor)[0:real_batch_size]

    # As we split the Tensors to all TPU cores and concat them back, it is
    # important to ensure the real data is placed before padded ones, i.e.,
    # order is preserved. By that, the sliced padding mask should have all 0's.
    # If this assertion failed, the slice logic here would not hold.
    # NOTE(review): as above, this is a boolean computed with
    # `math_ops.equal`, not an Assert op -- confirm the intent.
    sliced_padding_mask = slice_single_tensor(padding_mask)
    assert_padding_mask = math_ops.equal(
        math_ops.reduce_sum(sliced_padding_mask), 0)

    with ops.control_dependencies([assert_padding_mask]):
      should_stop = _StopSignals.should_stop(
          _StopSignals.as_scalar_stopping_signal(signals))

    # A batch is full when the padding mask sums to 0.
    is_full_batch = math_ops.equal(math_ops.reduce_sum(padding_mask), 0)

    def slice_fn(tensor):
      # If the current batch is full batch or part of stopping signals, we do
      # not need to slice to save performance.
      return control_flow_ops.cond(
          math_ops.logical_or(should_stop, is_full_batch),
          (lambda: verify_batch_size(tensor)),
          (lambda: slice_single_tensor(tensor)))

    return nest.map_structure(slice_fn, tensor_or_dict) 
Example #30
Source File: tf_example_decoder.py    From Person-Detection-and-Tracking with MIT License 5 votes vote down vote up
def _decode_png_instance_masks(self, keys_to_tensors):
    """Decodes PNG instance segmentation masks into one dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors. The entries
        'image/object/mask', 'image/height' and 'image/width' are read.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """

    def png_to_binary_mask(encoded_png):
      # Decode to a single channel, drop that channel axis, then binarize.
      decoded = tf.image.decode_image(encoded_png, channels=1)
      mask = tf.squeeze(decoded, axis=2)
      mask.set_shape([None, None])
      return tf.to_float(tf.greater(mask, 0))

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')

    def decode_all():
      return tf.map_fn(png_to_binary_mask, png_masks, dtype=tf.float32)

    def no_masks():
      # No encoded masks: an empty [0, height, width] tensor of zeros.
      return tf.zeros(tf.to_int32(tf.stack([0, height, width])))

    has_masks = tf.greater(tf.size(png_masks), 0)
    return tf.cond(has_masks, decode_all, no_masks)