Python tensorflow.python.ops.math_ops.equal() Examples

The following are 30 code examples of tensorflow.python.ops.math_ops.equal(), drawn from open-source projects. The source file and project are noted above each example. You may also want to check out the other available functions and classes of the module tensorflow.python.ops.math_ops.
Example #1
Source File: tensor_forest.py    From lambda-packs with MIT License
def average_impurity(self):
    """Constructs a TF graph for evaluating the average leaf impurity of a tree.

    If in regression mode, this is the leaf variance. If in classification mode,
    this is the Gini impurity.

    Returns:
      The last op in the graph.
    """
    children = array_ops.squeeze(array_ops.slice(
        self.variables.tree, [0, 0], [-1, 1]), squeeze_dims=[1])
    is_leaf = math_ops.equal(constants.LEAF_NODE, children)
    leaves = math_ops.to_int32(array_ops.squeeze(array_ops.where(is_leaf),
                                                 squeeze_dims=[1]))
    counts = array_ops.gather(self.variables.node_sums, leaves)
    gini = self._weighted_gini(counts)
    # Guard against step 1, when there often are no leaves yet.
    def impurity():
      return gini
    # Since average impurity can be used for loss, when there's no data just
    # return a big number so that loss always decreases.
    def big():
      return array_ops.ones_like(gini, dtype=dtypes.float32) * 10000000.
    return control_flow_ops.cond(math_ops.greater(
        array_ops.shape(leaves)[0], 0), impurity, big) 
Example #2
Source File: metrics_impl.py    From lambda-packs with MIT License
def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return control_flow_ops.cond(
      math_ops.equal(
          array_ops.constant(0.0, dtype=dtypes.float64), denominator),
      lambda: array_ops.constant(0.0, dtype=dtypes.float64),
      lambda: math_ops.div(numerator, denominator),
      name=name) 
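As a usage sketch (not part of the original source, and assuming the public TensorFlow 1.x API, where tf.equal mirrors math_ops.equal), the same guard pattern looks like this:

import tensorflow as tf

num = tf.constant(6.0, dtype=tf.float64)
den = tf.constant(0.0, dtype=tf.float64)
# tf.cond only runs the branch selected by the tf.equal test, so a zero
# denominator yields 0.0 instead of inf.
safe = tf.cond(tf.equal(tf.constant(0.0, tf.float64), den),
               lambda: tf.constant(0.0, tf.float64),
               lambda: num / den)
with tf.Session() as sess:
  print(sess.run(safe))  # 0.0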
Example #3
Source File: metrics_impl.py    From auto-alt-text-lambda-api with MIT License
def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return control_flow_ops.cond(
      math_ops.equal(
          array_ops.constant(0.0, dtype=dtypes.float64), denominator),
      lambda: array_ops.constant(0.0, dtype=dtypes.float64),
      lambda: math_ops.div(numerator, denominator),
      name=name) 
Example #4
Source File: check_ops.py    From auto-alt-text-lambda-api with MIT License
def is_non_decreasing(x, name=None):
  """Returns `True` if `x` is non-decreasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
  If `x` has less than two elements, it is trivially non-decreasing.

  See also:  `is_strictly_increasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).  Defaults to "is_non_decreasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_non_decreasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less_equal(zero, diff)) 
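A minimal sketch of the same reduction with the public TF 1.x API (illustrative, not from the original file):

import tensorflow as tf

x = tf.constant([1, 2, 2, 5])
diff = x[1:] - x[:-1]  # adjacent differences; empty when len(x) < 2
non_decreasing = tf.reduce_all(tf.less_equal(0, diff))
with tf.Session() as sess:
  print(sess.run(non_decreasing))  # True, since every x[i] <= x[i+1]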
Example #5
Source File: loss_ops.py    From lambda-packs with MIT License
def _safe_div(numerator, denominator, name="value"):
  """Computes a safe divide which returns 0 if the denominator is zero.

  Note that the function contains an additional conditional check that is
  necessary for avoiding situations where the loss is zero causing NaNs to
  creep into the gradient computation.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: A `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator.
  """
  return array_ops.where(
      math_ops.greater(denominator, 0),
      math_ops.div(numerator, array_ops.where(
          math_ops.equal(denominator, 0),
          array_ops.ones_like(denominator), denominator)),
      array_ops.zeros_like(numerator),
      name=name) 
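The double-where guard can be exercised directly with public TF 1.x ops; a hedged sketch (not part of loss_ops.py):

import tensorflow as tf

num = tf.constant([4.0, 5.0, 0.0])
den = tf.constant([2.0, 0.0, 0.0])
# Zero denominators are first replaced by 1 so the division never produces
# inf/NaN, then those positions are masked back to 0.
safe = tf.where(tf.greater(den, 0),
                num / tf.where(tf.equal(den, 0), tf.ones_like(den), den),
                tf.zeros_like(num))
with tf.Session() as sess:
  print(sess.run(safe))  # [2. 0. 0.]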
Example #6
Source File: binomial.py    From lambda-packs with MIT License
def _bdtr(k, n, p):
  """The binomial cumulative distribution function.

  Args:
    k: floating point `Tensor`.
    n: floating point `Tensor`.
    p: floating point `Tensor`.

  Returns:
    `sum_{j=0}^k p^j (1 - p)^(n - j)`.
  """
  # Trick for getting safe backprop/gradients into n, k when
  #   betainc(a = 0, ..) = nan
  # Write:
  #   where(unsafe, safe_output, betainc(where(unsafe, safe_input, input)))
  ones = array_ops.ones_like(n - k)
  k_eq_n = math_ops.equal(k, n)
  safe_dn = array_ops.where(k_eq_n, ones, n - k)
  dk = math_ops.betainc(a=safe_dn, b=k + 1, x=1 - p)
  return array_ops.where(k_eq_n, ones, dk) 
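The same double-where trick generalizes to any op that is unstable at some inputs. A small illustration with tf.log (assumes TF 1.x; not from binomial.py): masking only the output still lets NaNs leak into gradients (0 * inf = NaN), whereas masking the input as well keeps both the value and the gradient finite.

import tensorflow as tf

x = tf.constant([0.0, 4.0])
unsafe = tf.equal(x, 0.0)
# Feed a harmless input into the unstable op, then discard its output at the
# unsafe positions; both the forward value and d(y)/d(x) stay finite.
safe_x = tf.where(unsafe, tf.ones_like(x), x)
y = tf.where(unsafe, tf.zeros_like(x), tf.log(safe_x))
grad = tf.gradients(y, x)[0]
with tf.Session() as sess:
  print(sess.run([y, grad]))  # [0., 1.386...] and [0., 0.25] -- no NaNs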
Example #7
Source File: check_ops.py    From lambda-packs with MIT License
def is_strictly_increasing(x, name=None):
  """Returns `True` if `x` is strictly increasing.

  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
  If `x` has less than two elements, it is trivially strictly increasing.

  See also:  `is_non_decreasing`

  Args:
    x: Numeric `Tensor`.
    name: A name for this operation (optional).
      Defaults to "is_strictly_increasing"

  Returns:
    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.

  Raises:
    TypeError: if `x` is not a numeric tensor.
  """
  with ops.name_scope(name, 'is_strictly_increasing', [x]):
    diff = _get_diff_for_monotonic_comparison(x)
    # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
    return math_ops.reduce_all(math_ops.less(zero, diff)) 
Example #8
Source File: util.py    From lambda-packs with MIT License
def assert_integer_form(
    x, data=None, summarize=None, message=None, name="assert_integer_form"):
  """Assert that x has integer components (or floats equal to integers).

  Args:
    x: Floating-point `Tensor`
    data: The tensors to print out if the condition is `False`. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).

  Returns:
    Op raising `InvalidArgumentError` if round(x) != x.
  """

  message = message or "x has non-integer components"
  x = ops.convert_to_tensor(x, name="x")
  casted_x = math_ops.to_int64(x)
  return check_ops.assert_equal(
      x, math_ops.cast(math_ops.round(casted_x), x.dtype),
      data=data, summarize=summarize, message=message, name=name) 
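A usage sketch with the public TF 1.x API (illustrative; tf.assert_equal and tf.round stand in for the check_ops/math_ops internals used above):

import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.5])
assert_op = tf.assert_equal(x, tf.round(x),
                            message="x has non-integer components")
with tf.control_dependencies([assert_op]):
  y = tf.identity(x)
with tf.Session() as sess:
  sess.run(y)  # raises InvalidArgumentError because of the 3.5 entry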
Example #9
Source File: affine_impl.py    From lambda-packs with MIT License
def _process_matrix(self, matrix, min_rank, event_ndims):
    """Helper to __init__ which gets matrix in batch-ready form."""
    # Pad the matrix so that matmul works in the case of a matrix and vector
    # input. Keep track if the matrix was padded, to distinguish between a
    # rank 3 tensor and a padded rank 2 tensor.
    # TODO(srvasude): Remove side-effects from functions. It's currently
    # unbroken but error-prone since the function call order may change in the
    # future.
    self._rank_two_event_ndims_one = math_ops.logical_and(
        math_ops.equal(array_ops.rank(matrix), min_rank),
        math_ops.equal(event_ndims, 1))
    left = array_ops.where(self._rank_two_event_ndims_one, 1, 0)
    pad = array_ops.concat(
        [array_ops.ones(
            [left], dtype=dtypes.int32), array_ops.shape(matrix)],
        0)
    return array_ops.reshape(matrix, pad) 
Example #10
Source File: math_grad.py    From lambda-packs with MIT License
def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code."""
  input_shape = array_ops.shape(op.inputs[0])
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  y = op.outputs[0]
  y = array_ops.reshape(y, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
  indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
  num_selected = array_ops.reshape(
      math_ops.reduce_sum(indicators, op.inputs[1]), output_shape_kept_dims)

  return [math_ops.div(indicators, num_selected) * grad, None] 
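The tie-splitting behaviour this gradient implements can be observed directly (a sketch assuming TF 1.x):

import tensorflow as tf

x = tf.constant([3.0, 1.0, 3.0])
y = tf.reduce_max(x)
grad = tf.gradients(y, x)[0]
with tf.Session() as sess:
  # Two elements tie for the maximum, so the equality indicator marks both
  # and the unit gradient is split evenly between them.
  print(sess.run(grad))  # [0.5 0.  0.5]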
Example #11
Source File: math_grad.py    From lambda-packs with MIT License
def _SegmentMinOrMaxGrad(op, grad, is_sorted):
  """Gradient for SegmentMin and (unsorted) SegmentMax. They share similar code."""
  zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
                          dtype=op.inputs[0].dtype)

  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  if is_sorted:
    num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
                                        op.inputs[1])
  else:
    num_selected = math_ops.unsorted_segment_sum(math_ops.cast(is_selected, grad.dtype),
                                                 op.inputs[1], op.inputs[2])

  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
  weighted_grads = math_ops.div(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])

  if is_sorted:
    return array_ops.where(is_selected, gathered_grads, zeros), None
  else:
    return array_ops.where(is_selected, gathered_grads, zeros), None, None 
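Per-segment tie splitting, shown with tf.segment_min (a sketch assuming TF 1.x):

import tensorflow as tf

data = tf.constant([1.0, 1.0, 5.0, 2.0])
segment_ids = tf.constant([0, 0, 1, 1])
y = tf.segment_min(data, segment_ids)  # [1. 2.]
grad = tf.gradients(y, data)[0]
with tf.Session() as sess:
  # Segment 0 has a tied minimum, so its gradient is split between the two
  # selected elements; segment 1 routes its full gradient to the single min.
  print(sess.run(grad))  # [0.5 0.5 0.  1. ]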
Example #12
Source File: losses_impl.py    From lambda-packs with MIT License
def _safe_div(numerator, denominator, name="value"):
  """Computes a safe divide which returns 0 if the denominator is zero.

  Note that the function contains an additional conditional check that is
  necessary for avoiding situations where the loss is zero causing NaNs to
  creep into the gradient computation.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator.
  """
  return array_ops.where(
      math_ops.greater(denominator, 0),
      math_ops.div(numerator, array_ops.where(
          math_ops.equal(denominator, 0),
          array_ops.ones_like(denominator), denominator)),
      array_ops.zeros_like(numerator),
      name=name) 
Example #13
Source File: tensor_util.py    From lambda-packs with MIT License
def _is_rank(expected_rank, actual_tensor):
  """Returns whether actual_tensor's rank is expected_rank.

  Args:
    expected_rank: Integer defining the expected rank, or tensor of same.
    actual_tensor: Tensor to test.
  Returns:
    Boolean scalar `Tensor`, `True` iff `actual_tensor` has rank `expected_rank`.
  """
  with ops.name_scope('is_rank', values=[actual_tensor]) as scope:
    expected = ops.convert_to_tensor(expected_rank, name='expected')
    actual = array_ops.rank(actual_tensor, name='actual')
    return math_ops.equal(expected, actual, name=scope) 
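A dynamic-rank check like this only resolves at run time; a minimal sketch with the public TF 1.x API (illustrative):

import tensorflow as tf

t = tf.placeholder(tf.float32)  # rank unknown until a value is fed
is_matrix = tf.equal(2, tf.rank(t))
with tf.Session() as sess:
  print(sess.run(is_matrix, feed_dict={t: [[1.0, 2.0]]}))  # True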
Example #14
Source File: metrics_impl.py    From auto-alt-text-lambda-api with MIT License
def _select_class_id(ids, selected_id):
  """Filter all but `selected_id` out of `ids`.

  Args:
    ids: `int64` `Tensor` or `SparseTensor` of IDs.
    selected_id: Int id to select.

  Returns:
    `SparseTensor` of same dimensions as `ids`. This contains only the entries
    equal to `selected_id`.
  """
  ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
  if isinstance(ids, sparse_tensor.SparseTensor):
    return sparse_ops.sparse_retain(
        ids, math_ops.equal(ids.values, selected_id))

  # TODO(ptucker): Make this more efficient, maybe add a sparse version of
  # tf.equal and tf.reduce_any?

  # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
  ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
  ids_last_dim = array_ops.size(ids_shape) - 1
  filled_selected_id_shape = math_ops.reduced_shape(
      ids_shape, array_ops.reshape(ids_last_dim, [1]))

  # Intersect `ids` with the selected ID.
  filled_selected_id = array_ops.fill(
      filled_selected_id_shape, math_ops.to_int64(selected_id))
  result = sets.set_intersection(filled_selected_id, ids)
  return sparse_tensor.SparseTensor(
      indices=result.indices, values=result.values, dense_shape=ids_shape) 
Example #15
Source File: binomial.py    From lambda-packs with MIT License
def _maybe_assert_valid_sample(self, counts, check_integer=True):
    """Check counts for proper shape, values, then return tensor version."""
    if not self.validate_args:
      return counts

    counts = distribution_util.embed_check_nonnegative_discrete(
        counts, check_integer=check_integer)
    return control_flow_ops.with_dependencies([
        check_ops.assert_less_equal(
            counts, self.total_count,
            message="counts are not less than or equal to n."),
    ], counts) 
Example #16
Source File: binomial.py    From lambda-packs with MIT License 5 votes vote down vote up
def _cdf(self, counts):
    counts = self._maybe_assert_valid_sample(counts)
    probs = self.probs
    if not (counts.shape.is_fully_defined()
            and self.probs.shape.is_fully_defined()
            and counts.shape.is_compatible_with(self.probs.shape)):
      # Broadcast probs and counts to a common shape; this block is skipped
      # when both static shapes are fully defined and already compatible.
      probs += array_ops.zeros_like(counts)
      counts += array_ops.zeros_like(self.probs)

    return _bdtr(k=counts, n=self.total_count, p=probs) 
Example #17
Source File: tensor_util.py    From lambda-packs with MIT License 5 votes vote down vote up
def _all_equal(tensor0, tensor1):
  with ops.name_scope('all_equal', values=[tensor0, tensor1]) as scope:
    return math_ops.reduce_all(
        math_ops.equal(tensor0, tensor1, name='equal'), name=scope) 
Example #18
Source File: check_ops.py    From auto-alt-text-lambda-api with MIT License
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.assert_equal(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have `x[i] == y[i]`.
  If both `x` and `y` are empty, this is trivially satisfied.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False.
  """
  message = message or ''
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')
    if data is None:
      data = [
          message,
          'Condition x == y did not hold element-wise: x = ', x.name, x, 'y = ',
          y.name, y
      ]
    condition = math_ops.reduce_all(math_ops.equal(x, y))
    return control_flow_ops.Assert(condition, data, summarize=summarize) 
Example #19
Source File: metrics_impl.py    From auto-alt-text-lambda-api with MIT License
def false_negatives(labels, predictions, weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes the total number of false positives.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
      be cast to `bool`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  with variable_scope.variable_scope(
      name, 'false_negatives', (predictions, labels, weights)):

    labels = math_ops.cast(labels, dtype=dtypes.bool)
    predictions = math_ops.cast(predictions, dtype=dtypes.bool)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    is_false_negative = math_ops.logical_and(math_ops.equal(labels, True),
                                             math_ops.equal(predictions, False))
    return _count_condition(is_false_negative, weights, metrics_collections,
                            updates_collections) 
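The core mask can be computed without the metrics machinery; a hedged sketch (assumes TF 1.x, not part of metrics_impl.py):

import tensorflow as tf

labels = tf.constant([True, True, False, True])
predictions = tf.constant([True, False, False, False])
# A false negative is a positive label predicted as negative.
is_fn = tf.logical_and(tf.equal(labels, True), tf.equal(predictions, False))
fn_count = tf.reduce_sum(tf.cast(is_fn, tf.float32))
with tf.Session() as sess:
  print(sess.run(fn_count))  # 2.0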
Example #20
Source File: check_ops.py    From auto-alt-text-lambda-api with MIT License
def _dynamic_rank_in(actual_rank, given_ranks):
  if len(given_ranks) < 1:
    return ops.convert_to_tensor(False)
  result = math_ops.equal(given_ranks[0], actual_rank)
  for given_rank in given_ranks[1:]:
    result = math_ops.logical_or(
        result, math_ops.equal(given_rank, actual_rank))
  return result 
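The same OR-of-equalities test, written against the public TF 1.x API (an illustrative sketch):

import tensorflow as tf

t = tf.placeholder(tf.float32)
actual_rank = tf.rank(t)
# OR together one equality test per allowed rank.
rank_ok = tf.equal(1, actual_rank)
for given_rank in (2, 3):
  rank_ok = tf.logical_or(rank_ok, tf.equal(given_rank, actual_rank))
with tf.Session() as sess:
  print(sess.run(rank_ok, feed_dict={t: [[1.0]]}))  # True: rank 2 is allowed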
Example #21
Source File: sequence_queueing_state_saver.py    From lambda-packs with MIT License
def _check_multiple_of(value, multiple_of):
  """Checks that value `value` is a non-zero multiple of `multiple_of`.

  Args:
    value: an int32 scalar Tensor.
    multiple_of: an int or int32 scalar Tensor.

  Returns:
    new_value: an int32 scalar Tensor matching `value`, but which includes an
      assertion that `value` is a multiple of `multiple_of`.
  """
  assert isinstance(value, ops.Tensor)
  with ops.control_dependencies([
      control_flow_ops.Assert(
          math_ops.logical_and(
              math_ops.equal(math_ops.mod(value, multiple_of), 0),
              math_ops.not_equal(value, 0)), [
                  string_ops.string_join([
                      "Tensor %s should be a multiple of: " % value.name,
                      string_ops.as_string(multiple_of), ", but saw value: ",
                      string_ops.as_string(value),
                      ". Consider setting pad=True."
                  ])
              ])
  ]):
    new_value = array_ops.identity(value, name="multiple_of_checked")
    return new_value 
Example #22
Source File: head.py    From lambda-packs with MIT License
def _class_labels_streaming_mean(labels, weights, class_id):
  return metrics_lib.streaming_mean(
      array_ops.where(
          math_ops.equal(
              math_ops.to_int32(class_id), math_ops.to_int32(labels)),
          array_ops.ones_like(labels), array_ops.zeros_like(labels)),
      weights=weights) 
Example #23
Source File: head.py    From lambda-packs with MIT License
def _class_predictions_streaming_mean(predictions, weights, class_id):
  return metrics_lib.streaming_mean(
      array_ops.where(
          math_ops.equal(
              math_ops.to_int32(class_id), math_ops.to_int32(predictions)),
          array_ops.ones_like(predictions),
          array_ops.zeros_like(predictions)),
      weights=weights) 
Example #24
Source File: boolean_mask.py    From lambda-packs with MIT License
def sparse_boolean_mask(sparse_tensor, mask, name="sparse_boolean_mask"):
  """Boolean mask for `SparseTensor`s.

  Args:
    sparse_tensor: a `SparseTensor`.
    mask: a 1D boolean dense `Tensor` whose length is equal to the 0th dimension
      of `sparse_tensor`.
    name: optional name for this operation.
  Returns:
    A `SparseTensor` that contains row `k` of `sparse_tensor` iff `mask[k]` is
    `True`.
  """
  # TODO(jamieas): consider mask dimension > 1 for symmetry with `boolean_mask`.
  with ops.name_scope(name, values=[sparse_tensor, mask]):
    mask = ops.convert_to_tensor(mask)
    mask_rows = array_ops.where(mask)
    first_indices = array_ops.squeeze(array_ops.slice(sparse_tensor.indices,
                                                      [0, 0], [-1, 1]))

    # Identify indices corresponding to the rows identified by mask_rows.
    sparse_entry_matches = functional_ops.map_fn(
        lambda x: math_ops.equal(first_indices, x),
        mask_rows,
        dtype=dtypes.bool)
    # Combine the rows of sparse_entry_matches to form a mask for the sparse
    # indices and values.
    to_retain = array_ops.reshape(
        functional_ops.foldl(math_ops.logical_or, sparse_entry_matches), [-1])

    return sparse_ops.sparse_retain(sparse_tensor, to_retain) 
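A simplified single-row variant of the same idea, using only public TF 1.x ops (illustrative, not the contrib implementation):

import tensorflow as tf

st = tf.SparseTensor(indices=[[0, 0], [1, 1], [2, 0]],
                     values=[10, 20, 30],
                     dense_shape=[3, 2])
# Keep only the entries whose row index equals 1.
row_match = tf.equal(st.indices[:, 0], 1)
kept = tf.sparse_retain(st, row_match)
with tf.Session() as sess:
  print(sess.run(kept))  # one entry left: value 20 at index [1, 1]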
Example #25
Source File: helper.py    From lambda-packs with MIT License
def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state) 
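End-of-sequence detection with math_ops.equal reduces to an element-wise compare plus a reduction; a standalone sketch (assumes TF 1.x; END_TOKEN is an illustrative value):

import tensorflow as tf

END_TOKEN = 2
sample_ids = tf.constant([5, 2, 2])  # one sampled id per batch entry
finished = tf.equal(sample_ids, END_TOKEN)  # [False, True, True]
all_finished = tf.reduce_all(finished)  # decoding stops only when all are done
with tf.Session() as sess:
  print(sess.run([finished, all_finished]))  # [False True True], False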
Example #26
Source File: tf_helpers.py    From Counterfactual-StoryRW with MIT License
def initialize(self, name=None):
        with ops.name_scope(name, "TrainingHelperInitialize"):
            finished = math_ops.equal(0, self._sequence_length)
            all_finished = math_ops.reduce_all(finished)
            next_inputs = control_flow_ops.cond(
                all_finished, lambda: self._zero_inputs,
                lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
            return (finished, next_inputs) 
Example #27
Source File: helper.py    From lambda-packs with MIT License
def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs) 
Example #28
Source File: beam_search_decoder.py    From lambda-packs with MIT License
def _tensor_gather_helper(gather_indices, gather_from, batch_size,
                          range_size, gather_shape):
  """Helper for gathering the right indices from the tensor.

  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
  gathering from that according to the gather_indices, which are offset by
  the right amounts in order to preserve the batch order.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The input batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
  """
  range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1)
  gather_indices = array_ops.reshape(gather_indices + range_, [-1])
  output = array_ops.gather(
      array_ops.reshape(gather_from, gather_shape), gather_indices)
  final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
  static_batch_size = tensor_util.constant_value(batch_size)
  final_static_shape = (tensor_shape.TensorShape([static_batch_size])
                        .concatenate(
                            gather_from.shape[1:1 + len(gather_shape)]))
  output = array_ops.reshape(output, final_shape)
  output.set_shape(final_static_shape)
  return output 
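The offset arithmetic is the interesting part; a small numeric sketch (assumes TF 1.x, with toy batch_size/beam_width values):

import tensorflow as tf

batch_size, beam_width = 2, 3
gather_from = tf.reshape(tf.range(6), [batch_size, beam_width])  # [[0 1 2] [3 4 5]]
gather_indices = tf.constant([[2, 0], [1, 1]])  # per-batch beam choices
# Offset each batch's indices into the flattened [batch * beam] range so that
# batch order is preserved after the reshape.
range_ = tf.expand_dims(tf.range(batch_size) * beam_width, 1)  # [[0] [3]]
flat_indices = tf.reshape(gather_indices + range_, [-1])  # [2 0 4 4]
output = tf.gather(tf.reshape(gather_from, [-1]), flat_indices)
with tf.Session() as sess:
  print(sess.run(output))  # [2 0 4 4]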
Example #29
Source File: beam_search_decoder.py    From lambda-packs with MIT License
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
  """Maybe applies _tensor_gather_helper.

  This applies _tensor_gather_helper when the rank of gather_from is at least
  as large as the length of gather_shape. This is used in conjunction with nest
  so that we don't apply _tensor_gather_helper to inapplicable values like
  scalars.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
      or the original tensor if its dimensions are too small.
  """
  _check_maybe(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=gather_shape)
  else:
    return gather_from 
Example #30
Source File: sequence_queueing_state_saver.py    From lambda-packs with MIT License
def sequence_count(self):
    """An int32 vector, length `batch_size`: the sequence count of each entry.

    When an input is split up, the number of splits is equal to:
    `padded_length / num_unroll`.  This is the sequence_count.

    Returns:
      An int32 vector `Tensor`.
    """
    return self._state_saver._received_sequence_count