Python tensorflow.python.util.nest.pack_sequence_as() Examples

The following are 30 code examples of tensorflow.python.util.nest.pack_sequence_as(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.util.nest, or try the search function.
Example #1
Source File: nested_utils.py    From multilabel-image-classification-tensorflow with MIT License 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #2
Source File: tf_utils.py    From video_prediction with MIT License 6 votes vote down vote up
def static_rnn(cell, inputs, scope=None):
    """Simple version of static_rnn.

    Unstacks every leaf of `inputs` along axis 0 (time), feeds the time steps
    through `cell` one by one (reusing variables after the first step), and
    stacks the per-step outputs back along axis 0.

    Args:
        cell: An RNN cell providing `zero_state(batch_size, dtype)` and callable
            as `cell(input_, state) -> (output, state)`.
        inputs: A Tensor or (possibly nested) structure of Tensors. Axis 0 is
            unstacked as the time axis; `dimension(inputs, axis=1)` is used as
            the batch size.
        scope: Optional variable scope; defaults to "rnn".

    Returns:
        A tuple `(outputs, state)`: `outputs` has the structure of the cell
        output with each leaf stacked over time along axis 0, and `state` is
        the state returned by the last cell call.

    Raises:
        ValueError: If `inputs` yields no time steps, since there would be no
            cell output whose structure could be used to pack `outputs`.
    """
    with tf.variable_scope(scope or "rnn") as varscope:
        batch_size = dimension(inputs, axis=1)
        state = cell.zero_state(batch_size, tf.float32)
        flat_inputs = nest.flatten(inputs)
        # Regroup from "per-leaf list of time steps" to "per-step tuple of leaves".
        flat_inputs = list(zip(*[tf.unstack(flat_input, axis=0) for flat_input in flat_inputs]))
        if not flat_inputs:
            # Without this check `output` below would be unbound (UnboundLocalError).
            raise ValueError("static_rnn: `inputs` must contain at least one time step.")
        flat_outputs = []
        for time, flat_input in enumerate(flat_inputs):
            if time > 0:
                # Share the cell's variables across all time steps.
                varscope.reuse_variables()
            input_ = nest.pack_sequence_as(inputs, flat_input)
            output, state = cell(input_, state)
            flat_outputs.append(nest.flatten(output))
        # Regroup from "per-step list of leaves" to "per-leaf list of steps", then stack.
        flat_outputs = [tf.stack(flat_output, axis=0) for flat_output in zip(*flat_outputs)]
        outputs = nest.pack_sequence_as(output, flat_outputs)
        return outputs, state
Example #3
Source File: tpu_estimator.py    From Chinese-XLNet with Apache License 2.0 6 votes vote down vote up
def unflatten_features_and_labels(self, flattened_inputs):
      """Rebuilds the features/labels structure from a flat list of tensors.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` built from the 'features' entry and the optional 'labels'
        and 'signals' entries (None when absent) of the structure recorded in
        `self._feature_structure`.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      restored = data_nest.pack_sequence_as(
          self._feature_structure, flattened_inputs)
      features = restored['features']
      labels = restored.get('labels')
      signals = restored.get('signals')
      return _Inputs(features, labels, signals=signals)
Example #4
Source File: beam_search.py    From NJUNMT-tf with Apache License 2.0 6 votes vote down vote up
def gather_states(states, beam_ids):
    """ Gathers states according to beam ids.

    Args:
        states: A Tensor or a list/tuple/dict of Tensors. For each Tensor, the
          first dimension must be batch_size, otherwise, unknown errors may occur.
        beam_ids: A tensor with shape [batch_size, ] used to gather states.

    Returns: A Tensor or a list/tuple/dict of Tensors with the same structure
      as `states`.

    Raises:
        TypeError: If any leaf of `states` is not a `tf.Tensor`.
    """

    def _gather(x):
        # Explicit check rather than `assert`, which is stripped under -O.
        if not isinstance(x, tf.Tensor):
            raise TypeError(
                "gather_states expects tf.Tensor leaves, got {}".format(type(x)))
        return tf.gather(x, beam_ids)

    # map_structure already returns a result with the same nesting as
    # `states`, so flattening and re-packing by hand is unnecessary.
    return nest.map_structure(_gather, states)
Example #5
Source File: attention_ops.py    From hart with GNU General Public License v3.0 6 votes vote down vote up
def _zero_state(self, img, att, presence, state, transform_features, transform_state=False):
        """Runs one propagation step from `state` to produce an initial state.

        Inside a variable scope named after the class, features are extracted
        from `img`/`att`, optionally passed through an affine transform,
        propagated from `state`, and the resulting hidden state is optionally
        transformed leaf by leaf. The variable scope is stored on
        `self.rnn_vs`. `presence` is not used by this body.

        Returns:
            A tuple `(state, rnn_outputs)` where `state` has the same structure
            as the input `state`.
        """
        with tf.variable_scope(self.__class__.__name__) as vs:
            features = self.extract_features(img, att)[1]

            if transform_features:
                flat = tf.reshape(features, (-1, self.n_units))
                flat = AffineLayer(flat, self.n_units, name='init_feature_transform').output
                # Restore the pre-transform shape of `features`.
                features = tf.reshape(flat, tf.shape(features))

            rnn_outputs, hidden_state = self._propagate(features, state)
            flat_hidden = nest.flatten(hidden_state)

            if transform_state:
                flat_hidden = [
                    AffineLayer(hs, self.n_units, name='init_state_transform_{}'.format(idx)).output
                    for idx, hs in enumerate(flat_hidden)
                ]

            state = nest.pack_sequence_as(structure=state, flat_sequence=flat_hidden)

        self.rnn_vs = vs
        return state, rnn_outputs
Example #6
Source File: layers.py    From neural-combinatorial-rl-tensorflow with MIT License 6 votes vote down vote up
def trainable_initial_state(batch_size, state_size,
                            initializer=None, name="initial_state"):
  """Creates trainable initial-state variables, tiled along the batch axis.

  For each leaf `size` in `state_size`, a variable of shape `[1, size]` named
  "<name>_<i>" is created with `tf.get_variable` and tiled `batch_size` times
  along axis 0.

  Args:
    batch_size: Number of times each variable is tiled along axis 0.
    state_size: An int or a nested structure of ints (as accepted by
      `nest.flatten`) giving the size of each state component.
    initializer: Optional initializer passed to `tf.get_variable` for every
      component. Defaults to `tf.zeros_initializer()`.
    name: Prefix for the variable names and their "_tiled" outputs.

  Returns:
    The tiled tensors, packed into the same structure as `state_size`.
  """
  flat_state_size = nest.flatten(state_size)

  # `range` works on Python 2 and 3 (xrange is Python-2 only).
  names = ["{}_{}".format(name, i) for i in range(len(flat_state_size))]
  tiled_states = []

  # `var_name` avoids shadowing the `name` parameter.
  for var_name, size in zip(names, flat_state_size):
    # Honour a caller-supplied initializer; fall back to zeros only when none
    # is given.
    init = tf.zeros_initializer() if initializer is None else initializer
    initial_state_variable = tf.get_variable(
        var_name, shape=[1, size], initializer=init)

    tiled_state = tf.tile(initial_state_variable,
                          [batch_size, 1], name=(var_name + "_tiled"))
    tiled_states.append(tiled_state)

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=tiled_states)
Example #7
Source File: layers.py    From pointer-network-tensorflow with MIT License 6 votes vote down vote up
def trainable_initial_state(batch_size, state_size,
                            initializer=None, name="initial_state"):
  """Creates trainable initial-state variables, tiled along the batch axis.

  For each leaf `size` in `state_size`, a variable of shape `[1, size]` named
  "<name>_<i>" is created with `tf.get_variable` and tiled `batch_size` times
  along axis 0.

  Args:
    batch_size: Number of times each variable is tiled along axis 0.
    state_size: An int or a nested structure of ints (as accepted by
      `nest.flatten`) giving the size of each state component.
    initializer: Optional initializer passed to `tf.get_variable` for every
      component. Defaults to `tf.zeros_initializer()`.
    name: Prefix for the variable names and their "_tiled" outputs.

  Returns:
    The tiled tensors, packed into the same structure as `state_size`.
  """
  flat_state_size = nest.flatten(state_size)

  # `range` works on Python 2 and 3 (xrange is Python-2 only).
  names = ["{}_{}".format(name, i) for i in range(len(flat_state_size))]
  tiled_states = []

  # `var_name` avoids shadowing the `name` parameter.
  for var_name, size in zip(names, flat_state_size):
    # Honour a caller-supplied initializer; fall back to zeros only when none
    # is given.
    init = tf.zeros_initializer() if initializer is None else initializer
    initial_state_variable = tf.get_variable(
        var_name, shape=[1, size], initializer=init)

    tiled_state = tf.tile(initial_state_variable,
                          [batch_size, 1], name=(var_name + "_tiled"))
    tiled_states.append(tiled_state)

  return nest.pack_sequence_as(structure=state_size,
                               flat_sequence=tiled_states)
Example #8
Source File: nested_utils.py    From yolo_v2 with Apache License 2.0 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #9
Source File: shapes.py    From texar with Apache License 2.0 6 votes vote down vote up
def transpose_batch_time(inputs):
    """Swaps the batch and time axes of a tensor or nested structure of tensors.

    Works in either direction: batch-major `[batch_size, max_time, ...]` to
    time-major `[max_time, batch_size, ...]` or back.

    Args:
        inputs: A Tensor (or value convertible to one) of shape
            `[batch_size, max_time, ...]` or `[max_time, batch_size, ...]`,
            or a (possibly nested) tuple of such elements.

    Returns:
        The same (possibly nested) structure as `inputs`, with each leaf
        converted to a Tensor and its first two dimensions swapped.
    """
    leaves = [ops.convert_to_tensor(leaf) for leaf in nest.flatten(inputs)]
    # pylint: disable=protected-access
    swapped = [rnn._transpose_batch_time(leaf) for leaf in leaves]
    return nest.pack_sequence_as(structure=inputs, flat_sequence=swapped)
Example #10
Source File: tpu_estimator.py    From transformer-xl with Apache License 2.0 6 votes vote down vote up
def unflatten_features_and_labels(self, flattened_inputs):
      """Rebuilds the features/labels structure from a flat list of tensors.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` built from the 'features' entry and the optional 'labels'
        and 'signals' entries (None when absent) of the structure recorded in
        `self._feature_structure`.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      restored = data_nest.pack_sequence_as(
          self._feature_structure, flattened_inputs)
      features = restored['features']
      labels = restored.get('labels')
      signals = restored.get('signals')
      return _Inputs(features, labels, signals=signals)
Example #11
Source File: tpu_estimator.py    From embedding-as-service with MIT License 6 votes vote down vote up
def unflatten_features_and_labels(self, flattened_inputs):
      """Rebuilds the features/labels structure from a flat list of tensors.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` built from the 'features' entry and the optional 'labels'
        and 'signals' entries (None when absent) of the structure recorded in
        `self._feature_structure`.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      restored = data_nest.pack_sequence_as(
          self._feature_structure, flattened_inputs)
      features = restored['features']
      labels = restored.get('labels')
      signals = restored.get('signals')
      return _Inputs(features, labels, signals=signals)
Example #12
Source File: nested_utils.py    From object_detection_with_tensorflow with MIT License 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #13
Source File: nested_utils.py    From MultitaskAIS with MIT License 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Applies map_fn to every leaf of a (potentially) nested structure.

  Args:
    map_fn: Callable invoked once per leaf of 'nested', in flatten order.
    nested: A potentially nested combination of sequence objects (tuples,
      lists, namedtuples and other non-string sequences; see
      nest.is_sequence). For example [1, ('hello', 4.3)] has the leaves
      1, 'hello' and 4.3.
  Returns:
    A structure shaped like 'nested' whose leaves are the map_fn results.
    For example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  flat_results = [map_fn(leaf) for leaf in nest.flatten(nested)]
  return nest.pack_sequence_as(structure=nested, flat_sequence=flat_results)
Example #14
Source File: tf_utils.py    From pixelsnail-public with MIT License 6 votes vote down vote up
def batch(self, batch_size=None):
    """Get a batch of tensors.

    Dequeues from `self._queue` and regroups the flat dequeued values into the
    (possibly nested) structure of each entry in `self.placeholders`.

    Args:
      batch_size: Number of elements to dequeue with `dequeue_many`. Must be
        None when `self.produces_batches` is true.

    Returns:
      A `Struct` mapping each placeholder name to its dequeued value(s), packed
      into the same structure as the corresponding placeholder.
    """
    if self.produces_batches:
      assert batch_size is None, 'Cannot enforce a batch size if `func()` returns batches!'
      flat_batch = self._queue.dequeue()
      # Presumably restores the static shapes that plain `dequeue` does not
      # carry; shapes are copied from the matching flat placeholders.
      for name, pl in self.flat_placeholders.items():
        flat_batch[name].set_shape(pl.shape)

    else:
      flat_batch = self._queue.dequeue_many(batch_size)

    batch = Struct()
    for name, pl in self.placeholders.items():
      # NOTE(review): keys are matched by string prefix and ordered
      # lexicographically; this assumes the flat-key format keeps sibling
      # names distinct and their leaves in sort order -- confirm against how
      # `flat_placeholders` is built.
      flat_vals = sorted((k, v)
                         for k, v in flat_batch.items() if k.startswith(name))
      vals = [v for _, v in flat_vals]
      # No single-value special case: pack_sequence_as already returns the lone
      # element when `pl` is a single (non-nested) placeholder. The previous
      # `vals[0] if len(vals) == 0` branch could only raise IndexError.
      batch[name] = nest.pack_sequence_as(pl, vals)

    return batch
Example #15
Source File: tpu_estimator.py    From xlnet with Apache License 2.0 6 votes vote down vote up
def unflatten_features_and_labels(self, flattened_inputs):
      """Rebuilds the features/labels structure from a flat list of tensors.

      Args:
        flattened_inputs: Flattened inputs for each shard.

      Returns:
        An `_Inputs` built from the 'features' entry and the optional 'labels'
        and 'signals' entries (None when absent) of the structure recorded in
        `self._feature_structure`.

      Raises:
        ValueError: If the number of expected tensors from `flattened_inputs`
          mismatches the recorded structure.
      """
      restored = data_nest.pack_sequence_as(
          self._feature_structure, flattened_inputs)
      features = restored['features']
      labels = restored.get('labels')
      signals = restored.get('signals')
      return _Inputs(features, labels, signals=signals)
Example #16
Source File: nested_utils.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #17
Source File: nested_utils.py    From g-tensorflow-models with Apache License 2.0 6 votes vote down vote up
def where_tensors(condition, x_tensors, y_tensors):
  """Performs a tf.where operation on two sets of Tensors.

  Args:
    condition: The condition tensor to use for the where operation.
    x_tensors: A potentially nested tuple or list of Tensors.
    y_tensors: A potentially nested tuple or list of Tensors. Must have the
      same structure as x_tensors.
  Returns:
    whered_tensors: A potentially nested tuple or list of Tensors with the
      same structure as the 'x_tensors' input argument. Contains the result of
      applying tf.where(condition, x, y) on each pair of elements in x_tensors
      and y_tensors.
  Raises:
    ValueError: If x_tensors and y_tensors flatten to different numbers of
      elements.
  """
  flat_x = nest.flatten(x_tensors)
  flat_y = nest.flatten(y_tensors)
  # zip() would silently drop unmatched trailing elements; fail loudly instead.
  if len(flat_x) != len(flat_y):
    raise ValueError(
        "x_tensors and y_tensors must have the same number of elements, "
        "got {} and {}.".format(len(flat_x), len(flat_y)))
  # Built-in zip works on Python 2 and 3 (itertools.izip is Python-2 only).
  result = [tf.where(condition, x, y) for x, y in zip(flat_x, flat_y)]

  return nest.pack_sequence_as(x_tensors, result)
Example #18
Source File: nested_utils.py    From models with Apache License 2.0 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #19
Source File: nested_utils.py    From models with Apache License 2.0 6 votes vote down vote up
def where_tensors(condition, x_tensors, y_tensors):
  """Performs a tf.where operation on two sets of Tensors.

  Args:
    condition: The condition tensor to use for the where operation.
    x_tensors: A potentially nested tuple or list of Tensors.
    y_tensors: A potentially nested tuple or list of Tensors. Must have the
      same structure as x_tensors.
  Returns:
    whered_tensors: A potentially nested tuple or list of Tensors with the
      same structure as the 'x_tensors' input argument. Contains the result of
      applying tf.where(condition, x, y) on each pair of elements in x_tensors
      and y_tensors.
  Raises:
    ValueError: If x_tensors and y_tensors flatten to different numbers of
      elements.
  """
  flat_x = nest.flatten(x_tensors)
  flat_y = nest.flatten(y_tensors)
  # zip() would silently drop unmatched trailing elements; fail loudly instead.
  if len(flat_x) != len(flat_y):
    raise ValueError(
        "x_tensors and y_tensors must have the same number of elements, "
        "got {} and {}.".format(len(flat_x), len(flat_y)))
  # Built-in zip works on Python 2 and 3 (itertools.izip is Python-2 only).
  result = [tf.where(condition, x, y) for x, y in zip(flat_x, flat_y)]

  return nest.pack_sequence_as(x_tensors, result)
Example #20
Source File: nested_utils.py    From Gun-Detector with Apache License 2.0 6 votes vote down vote up
def map_nested(map_fn, nested):
  """Executes map_fn on every element in a (potentially) nested structure.

  Args:
    map_fn: A callable to execute on each element in 'nested'.
    nested: A potentially nested combination of sequence objects. Sequence
      objects include tuples, lists, namedtuples, and all subclasses of
      collections.abc.Sequence except strings. See nest.is_sequence for
      details. For example [1, ('hello', 4.3)] is a nested structure
      containing elements 1, 'hello', and 4.3.
  Returns:
    out_structure: A potentially nested combination of sequence objects with the
      same structure as the 'nested' input argument. out_structure
      contains the result of applying map_fn to each element in 'nested'. For
      example map_nested(lambda x: x+1, [1, (3, 4.3)]) returns [2, (4, 5.3)].
  """
  # Build a real list: on Python 3 `map` returns a lazy iterator, which is not
  # a sequence that nest.pack_sequence_as accepts as its flat_sequence.
  out = [map_fn(element) for element in nest.flatten(nested)]
  return nest.pack_sequence_as(nested, out)
Example #21
Source File: dataset_ops.py    From lambda-packs with MIT License 6 votes vote down vote up
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    The iterator op is created with the flattened `self._output_types` and
    `self._output_shapes`; its flat outputs are then packed back into the
    structure of `self._output_types`.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    flat_types = nest.flatten(self._output_types)
    flat_shapes = nest.flatten(self._output_shapes)
    flat_next = gen_dataset_ops.iterator_get_next(
        self._iterator_resource,
        output_types=flat_types,
        output_shapes=flat_shapes,
        name=name)
    return nest.pack_sequence_as(self._output_types, flat_next)
Example #22
Source File: shapes.py    From Counterfactual-StoryRW with MIT License 6 votes vote down vote up
def transpose_batch_time(inputs):
    """Swaps the batch and time axes of a tensor or nested structure of tensors.

    Works in either direction: batch-major `[batch_size, max_time, ...]` to
    time-major `[max_time, batch_size, ...]` or back.

    Args:
        inputs: A Tensor (or value convertible to one) of shape
            `[batch_size, max_time, ...]` or `[max_time, batch_size, ...]`,
            or a (possibly nested) tuple of such elements.

    Returns:
        The same (possibly nested) structure as `inputs`, with each leaf
        converted to a Tensor and its first two dimensions swapped.
    """
    leaves = [ops.convert_to_tensor(leaf) for leaf in nest.flatten(inputs)]
    # pylint: disable=protected-access
    swapped = [rnn._transpose_batch_time(leaf) for leaf in leaves]
    return nest.pack_sequence_as(structure=inputs, flat_sequence=swapped)
Example #23
Source File: bridge.py    From tensorflow_end2end_speech_recognition with MIT License 6 votes vote down vote up
def _create(self):
        """Maps the bridge inputs to an initial decoder state.

        Each bridge input is reshaped to `[self.batch_size, depth]`. The
        reshaped inputs are concatenated along axis 1 and projected by a fully
        connected layer to the total decoder state size. The projection is
        then split back into the structure of `self.decoder_state_size`.

        Returns:
            Tensors packed into the structure of `self.decoder_state_size`.
        """
        def _to_2d(tensor):
            return tf.reshape(tensor, [self.batch_size, _total_tensor_depth(tensor)])

        # Concat bridge inputs on the depth dimension.
        reshaped = nest.map_structure(_to_2d, self._bridge_input)
        flattened = nest.flatten([reshaped])
        concatenated = tf.concat(flattened, axis=1)

        split_sizes = nest.flatten(self.decoder_state_size)

        # Pass bridge inputs through a fully connected layer.
        projected = tf.contrib.layers.fully_connected(
            concatenated,
            num_outputs=sum(split_sizes),
            activation_fn=self._activation_fn,
            weights_initializer=tf.truncated_normal_initializer(
                stddev=self.parameter_init),
            biases_initializer=tf.zeros_initializer(),
            scope=None)

        # Shape back into the required state size.
        pieces = tf.split(projected, split_sizes, axis=1)
        return nest.pack_sequence_as(self.decoder_state_size, pieces)
Example #24
Source File: rnn_cell.py    From ecm with Apache License 2.0 5 votes vote down vote up
def zero_state(self, batch_size, dtype):
        """Return zero-filled state tensor(s).

        Args:
            batch_size: int, float, or 0-D (unit) Tensor giving the batch size.
            dtype: the data type to use for the state.

        Returns:
            If `state_size` is an int or TensorShape, a tensor of shape
            `[batch_size x state_size]` filled with zeros.

            If `state_size` is a nested list or tuple, a nested list or tuple
            of the same structure holding zero tensors of shape
            `[batch_size x s]` for each s in `state_size`.

            Every returned tensor has its static batch dimension set to None.
        """
        state_size = self.state_size

        def _dynamic_shape(size):
            return array_ops.pack(_state_size_with_prefix(size, prefix=[batch_size]))

        def _static_shape(size):
            return _state_size_with_prefix(size, prefix=[None])

        if not nest.is_sequence(state_size):
            single = array_ops.zeros(_dynamic_shape(state_size), dtype=dtype)
            single.set_shape(_static_shape(state_size))
            return single

        flat_sizes = nest.flatten(state_size)
        flat_zeros = [array_ops.zeros(_dynamic_shape(size), dtype=dtype)
                      for size in flat_sizes]
        for size, tensor in zip(flat_sizes, flat_zeros):
            tensor.set_shape(_static_shape(size))
        return nest.pack_sequence_as(structure=state_size, flat_sequence=flat_zeros)
Example #25
Source File: nest_test.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def testFlattenAndPack(self):
    """nest.flatten and nest.pack_sequence_as round-trip nested structures."""
    # Plain nested tuples: flatten order, and repacking with new leaves.
    nested = ((3, 4), 5, (6, 7, (9, 10), 8))
    leaves = ["a", "b", "c", "d", "e", "f", "g", "h"]
    self.assertEqual(nest.flatten(nested), [3, 4, 5, 6, 7, 9, 10, 8])
    self.assertEqual(nest.pack_sequence_as(nested, leaves),
                     (("a", "b"), "c", ("d", "e", ("f", "g"), "h")))

    # Namedtuples survive a flatten/pack round trip with their fields intact.
    Point = collections.namedtuple("Point", ["x", "y"])
    nested = (Point(x=4, y=2), ((Point(x=1, y=0),),))
    leaves = [4, 2, 1, 0]
    self.assertEqual(nest.flatten(nested), leaves)
    rebuilt = nest.pack_sequence_as(nested, leaves)
    self.assertEqual(rebuilt, nested)
    outer = rebuilt[0]
    inner = rebuilt[1][0][0]
    self.assertEqual(outer.x, 4)
    self.assertEqual(outer.y, 2)
    self.assertEqual(inner.x, 1)
    self.assertEqual(inner.y, 0)

    # Non-sequence values (including numpy arrays and strings) are single leaves.
    self.assertEqual([5], nest.flatten(5))
    self.assertEqual([np.array([5])], nest.flatten(np.array([5])))

    self.assertEqual("a", nest.pack_sequence_as(5, ["a"]))
    self.assertEqual(
        np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])]))

    # Mismatched structure / flat_sequence combinations are rejected.
    with self.assertRaisesRegexp(ValueError, "Structure is a scalar"):
      nest.pack_sequence_as("scalar", [4, 5])

    with self.assertRaisesRegexp(TypeError, "flat_sequence"):
      nest.pack_sequence_as([4, 5], "bad_sequence")

    with self.assertRaises(ValueError):
      nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
Example #26
Source File: rnn_cell.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or 0-D (unit) Tensor giving the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, a tensor of shape
      `[batch_size x state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, a nested list or tuple of the
      same structure holding zero tensors of shape `[batch_size x s]` for each
      s in `state_size`.

      Every returned tensor has its static batch dimension set to None.
    """
    state_size = self.state_size

    def _dynamic_shape(size):
      return array_ops.pack(_state_size_with_prefix(size, prefix=[batch_size]))

    def _static_shape(size):
      return _state_size_with_prefix(size, prefix=[None])

    if not nest.is_sequence(state_size):
      single = array_ops.zeros(_dynamic_shape(state_size), dtype=dtype)
      single.set_shape(_static_shape(state_size))
      return single

    flat_sizes = nest.flatten(state_size)
    flat_zeros = [array_ops.zeros(_dynamic_shape(size), dtype=dtype)
                  for size in flat_sizes]
    for size, tensor in zip(flat_sizes, flat_zeros):
      tensor.set_shape(_static_shape(size))
    return nest.pack_sequence_as(structure=state_size, flat_sequence=flat_zeros)
Example #27
Source File: utils.py    From Gun-Detector with Apache License 2.0 5 votes vote down vote up
def structure_map_split(func, value):
  """Applies `func` to each leaf of `value` and regroups the results.

  `func` is called once per leaf (in flatten order) and must return an
  iterable. The i-th items of all those iterables are packed back into the
  structure of `value`. As with `zip`, the output stops at the shortest
  per-leaf result.

  Returns:
    A list whose i-th entry has the structure of `value` and holds the i-th
    item produced for every leaf.
  """
  per_leaf = [func(leaf) for leaf in nest.flatten(value)]
  return [nest.pack_sequence_as(value, group) for group in zip(*per_leaf)]
Example #28
Source File: rnn.py    From MIMN with MIT License 5 votes vote down vote up
def _reverse_seq(input_seq, lengths):
  """Reverse a list of Tensors up to specified lengths.

  Args:
    input_seq: Sequence of seq_len tensors of dimension (batch_size, n_features)
               or nested tuples of tensors.
    lengths:   A `Tensor` of dimension batch_size, containing lengths for each
               sequence in the batch. If "None" is specified, simply reverses
               the list.

  Returns:
    time-reversed sequence

  Raises:
    ValueError: If a leaf has incompatible static shapes across time steps
      (raised by `TensorShape.merge_with`).
  """
  if lengths is None:
    return list(reversed(input_seq))

  flat_input_seq = tuple(nest.flatten(input_) for input_ in input_seq)

  flat_results = [[] for _ in range(len(input_seq))]
  for sequence in zip(*flat_input_seq):
    # Merge the static shapes of this leaf across all time steps. merge_with
    # returns a new TensorShape instead of updating in place, so the result
    # must be reassigned; otherwise input_shape would stay unknown.
    input_shape = tensor_shape.unknown_shape(
        ndims=sequence[0].get_shape().ndims)
    for input_ in sequence:
      input_shape = input_shape.merge_with(input_.get_shape())
    # Apply the fully merged shape to every time step.
    for input_ in sequence:
      input_.set_shape(input_shape)

    # Join into (time, batch_size, depth)
    s_joined = array_ops.stack(sequence)

    # Reverse along dimension 0
    s_reversed = array_ops.reverse_sequence(s_joined, lengths, 0, 1)
    # Split again into list
    result = array_ops.unstack(s_reversed)
    for r, flat_result in zip(result, flat_results):
      r.set_shape(input_shape)
      flat_result.append(r)

  results = [nest.pack_sequence_as(structure=input_, flat_sequence=flat_result)
             for input_, flat_result in zip(input_seq, flat_results)]
  return results
Example #29
Source File: control_flow_ops.py    From deep_image_model with Apache License 2.0 5 votes vote down vote up
def BuildLoop(self, pred, body, loop_vars, shape_invariants):
    """Add the loop termination condition and body to the graph.

    Args:
      pred: Loop condition; passed through to `self._BuildLoop`.
      body: Loop body; passed through to `self._BuildLoop`.
      loop_vars: A (possibly nested) structure of loop variables. TensorArrays
        among them are converted to their flow tensors before the loop is
        built, and the rest are converted to tensors / IndexedSlices.
      shape_invariants: Passed through to `self._BuildLoop`.

    Returns:
      The loop exit values, with flow tensors converted back to TensorArrays,
      packed into the structure of the body's result. When there is exactly
      one exit value, element [0] of that packed structure is returned
      instead of the structure itself.
    """
    # Keep original_loop_vars to identify which are TensorArrays
    original_loop_vars = loop_vars
    flat_loop_vars = nest.flatten(loop_vars)
    # Convert TensorArrays to their flow variables
    loop_vars = _convert_tensorarrays_to_flows(flat_loop_vars)
    loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
    try:
      # Enter()/Exit() bracket the loop construction; `finally` guarantees
      # Exit() runs even if _BuildLoop raises.
      self.Enter()
      original_body_result, exit_vars = self._BuildLoop(
          pred, body, original_loop_vars, loop_vars, shape_invariants)
    finally:
      self.Exit()

    flat_result = nest.flatten(original_body_result)
    # Convert TensorArray flow variables outside the context back into
    # their associated TensorArrays for returning to caller.
    exit_vars_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_result, exit_vars))
    packed_exit_vars = nest.pack_sequence_as(
        structure=original_body_result,
        flat_sequence=exit_vars_with_tensor_arrays)
    # NOTE(review): with a single exit value this indexes [0] into the packed
    # structure, presumably to unwrap a one-element loop_vars list -- confirm.
    return (packed_exit_vars[0] if len(exit_vars) == 1
            else packed_exit_vars)
Example #30
Source File: utils.py    From g-tensorflow-models with Apache License 2.0 5 votes vote down vote up
def structure_map_split(func, value):
  """Applies `func` to each leaf of `value` and regroups the results.

  `func` is called once per leaf (in flatten order) and must return an
  iterable. The i-th items of all those iterables are packed back into the
  structure of `value`. As with `zip`, the output stops at the shortest
  per-leaf result.

  Returns:
    A list whose i-th entry has the structure of `value` and holds the i-th
    item produced for every leaf.
  """
  per_leaf = [func(leaf) for leaf in nest.flatten(value)]
  return [nest.pack_sequence_as(value, group) for group in zip(*per_leaf)]