Python tensorflow.python.util.nest.flatten() Examples

The following are 30 code examples of tensorflow.python.util.nest.flatten(), collected from open-source projects. Each example is preceded by a note of its original source file and project. You may also want to check out the other available functions and classes of the module tensorflow.python.util.nest.
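Before the examples, here is a minimal sketch of what nest.flatten() and its inverse, nest.pack_sequence_as(), do. The structure and values are hypothetical stand-ins:

from tensorflow.python.util import nest

# nest.flatten() returns the leaves of an arbitrarily nested structure
# (tuples, lists, dicts) as a flat list, in a deterministic depth-first
# order; dict entries are visited in sorted-key order.
structure = {'a': (1, 2), 'b': [3, {'c': 4}]}
flat = nest.flatten(structure)
print(flat)  # [1, 2, 3, 4]

# nest.pack_sequence_as() is the inverse: it rebuilds the original
# nested structure from a flat list of values.
print(nest.pack_sequence_as(structure, [x * 10 for x in flat]))
# {'a': (10, 20), 'b': [30, {'c': 40}]}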
Example #1
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_one_shot_iterator(self):
    """Creates an `Iterator` for enumerating the elements of this dataset.

    **N.B.** The returned iterator will be initialized automatically.
    A "one-shot" iterator does not currently support re-initialization.

    Returns:
      An `Iterator` over the elements of this dataset.
    """
    # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
    # a 0-argument function.
    @function.Defun(capture_by_value=True)
    def _make_dataset():
      return self.make_dataset_resource()

    _make_dataset.add_to_graph(ops.get_default_graph())

    return Iterator(
        gen_dataset_ops.one_shot_iterator(
            dataset_factory=_make_dataset,
            output_types=nest.flatten(self.output_types),
            output_shapes=nest.flatten(self.output_shapes)), None,
        self.output_types, self.output_shapes) 
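For context, a hedged usage sketch (TF 1.x graph mode, with a hypothetical nested element structure): the flattened output_types/output_shapes passed to the kernel above are packed back into the nested structure when the iterator's get_next() is called.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(
    {'x': [[1., 2.], [3., 4.]], 'y': [0, 1]})
iterator = dataset.make_one_shot_iterator()
element = iterator.get_next()  # a dict: {'x': <Tensor>, 'y': <Tensor>}

with tf.Session() as sess:
    print(sess.run(element))  # {'x': array([1., 2.], ...), 'y': 0}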
Example #2
Source File: attention_ops.py    From hart with GNU General Public License v3.0
def _zero_state(self, img, att, presence, state, transform_features, transform_state=False):

    with tf.variable_scope(self.__class__.__name__) as vs:
        features = self.extract_features(img, att)[1]

        if transform_features:
            features_flat = tf.reshape(features, (-1, self.n_units))
            features_flat = AffineLayer(features_flat, self.n_units, name='init_feature_transform').output
            features = tf.reshape(features_flat, tf.shape(features))

        rnn_outputs, hidden_state = self._propagate(features, state)

        hidden_state = nest.flatten(hidden_state)

        if transform_state:
            for i, hs in enumerate(hidden_state):
                name = 'init_state_transform_{}'.format(i)
                hidden_state[i] = AffineLayer(hs, self.n_units, name=name).output

        state = nest.pack_sequence_as(structure=state, flat_sequence=hidden_state)
    self.rnn_vs = vs
    return state, rnn_outputs 
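The flatten, transform-each-leaf, pack_sequence_as round-trip used above is a common way to apply a per-tensor operation to every component of a nested RNN state. A standalone sketch with plain Python strings as stand-in leaves:

from tensorflow.python.util import nest

state = (('h0', 'c0'), 'h1')                 # arbitrarily nested state
flat = nest.flatten(state)                   # ['h0', 'c0', 'h1']
flat = [leaf.upper() for leaf in flat]       # transform each leaf
state = nest.pack_sequence_as(state, flat)   # (('H0', 'C0'), 'H1')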
Example #3
Source File: nn.py    From hart with GNU General Public License v3.0
def __init__(self, inpt, n_hidden, n_output, transfer_hidden=tf.nn.elu, transfer=None,
             hidden_weight_init=None, hidden_bias_init=None, weight_init=None, bias_init=None,
             name=None):
    """
    :param inpt: input tensor
    :param n_hidden: scalar or list, number of hidden units
    :param n_output: scalar, number of output units
    :param transfer_hidden: scalar or list, transfers for hidden units. If a list, its length must equal len(n_hidden).
    :param transfer: tf.Op or None
    """

    self.n_hidden = nest.flatten(n_hidden)
    self.n_output = n_output
    self.hidden_weight_init = hidden_weight_init
    self.hidden_bias_init = hidden_bias_init

    transfer_hidden = nest.flatten(transfer_hidden)
    if len(transfer_hidden) == 1:
        transfer_hidden *= len(self.n_hidden)
    self.transfer_hidden = transfer_hidden

    self.transfer = transfer
    super(MLP, self).__init__(inpt, name, weight_init, bias_init) 
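Note how nest.flatten() doubles as an argument normalizer here: a bare scalar and a list both come out as a flat list, so n_hidden and transfer_hidden can be given either way.

from tensorflow.python.util import nest

nest.flatten(128)        # [128]
nest.flatten([128, 64])  # [128, 64]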
Example #4
Source File: dynamic_rnn_estimator.py    From lambda-packs with MIT License
def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
    have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None
                     else array_ops.identity(state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict 
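The dict keys follow the depth-first order produced by nest.flatten(). An illustration with hypothetical string stand-ins for the state tensors:

from tensorflow.python.util import nest

state = (('c0', 'h0'), ('c1', 'h1'))  # e.g. a two-layer LSTM state
for i, leaf in enumerate(nest.flatten(state)):
    print('STATE_PREFIX_{} -> {}'.format(i, leaf))
# STATE_PREFIX_0 -> c0
# STATE_PREFIX_1 -> h0
# STATE_PREFIX_2 -> c1
# STATE_PREFIX_3 -> h1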
Example #5
Source File: state_saving_rnn_estimator.py    From lambda-packs with MIT License
def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
    have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None else array_ops.identity(
          state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict 
Example #6
Source File: dynamic_rnn_estimator.py    From auto-alt-text-lambda-api with MIT License
def state_tuple_to_dict(state):
  """Returns a dict containing flattened `state`.

  Args:
    state: A `Tensor` or a nested tuple of `Tensors`. All of the `Tensor`s must
    have the same rank and agree on all dimensions except the last.

  Returns:
    A dict containing the `Tensor`s that make up `state`. The keys of the dict
    are of the form "STATE_PREFIX_i" where `i` is the place of this `Tensor`
    in a depth-first traversal of `state`.
  """
  with ops.name_scope('state_tuple_to_dict'):
    flat_state = nest.flatten(state)
    state_dict = {}
    for i, state_component in enumerate(flat_state):
      state_name = _get_state_name(i)
      state_value = (None if state_component is None
                     else array_ops.identity(state_component, name=state_name))
      state_dict[state_name] = state_value
  return state_dict 
Example #7
Source File: beam_search_decoder_from_tensorflow.py    From tensorflow_end2end_speech_recognition with MIT License
def tile_batch(t, multiplier, name=None):
    """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.
    For each tensor t in a (possibly nested structure) of tensors,
    this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
    minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
    `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
    `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
    `multiplier` times.
    Args:
      t: `Tensor` shaped `[batch_size, ...]`.
      multiplier: Python int.
      name: Name scope for any created operations.
    Returns:
      A (possibly nested structure of) `Tensor` shaped
      `[batch_size * multiplier, ...]`.
    Raises:
      ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
    """
    flat_t = nest.flatten(t)
    with tf.name_scope(name, "tile_batch", flat_t + [multiplier]):
        return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) 
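The tiling semantics, sketched with the equivalent NumPy operation (values are hypothetical):

import numpy as np

t = np.array([[1, 2], [3, 4]])   # batch_size = 2
tiled = np.repeat(t, 3, axis=0)  # multiplier = 3, result shape [6, 2]
# [[1 2], [1 2], [1 2], [3 4], [3 4], [3 4]]
# i.e. t[0], t[0], t[0], t[1], t[1], t[1]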
Example #8
Source File: dataset_ops.py    From lambda-packs with MIT License
def get_next(self, name=None):
    """Returns a nested structure of `tf.Tensor`s containing the next element.

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A nested structure of `tf.Tensor` objects.
    """
    return nest.pack_sequence_as(
        self._output_types,
        gen_dataset_ops.iterator_get_next(
            self._iterator_resource,
            output_types=nest.flatten(self._output_types),
            output_shapes=nest.flatten(self._output_shapes),
            name=name)) 
Example #9
Source File: beam_search_decoder.py    From lambda-packs with MIT License
def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, start_inputs, initial_state)`.
    """
    finished, start_inputs = self._finished, self._start_inputs

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=array_ops.zeros(
            [self._batch_size, self._beam_width],
            dtype=nest.flatten(self._initial_cell_state)[0].dtype),
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int32))

    return (finished, start_inputs, initial_state) 
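The nest.flatten(...)[0].dtype idiom (also used in Examples #13 and #24 below) picks the dtype of the first leaf of an arbitrarily nested state as a representative dtype. A minimal sketch:

import tensorflow as tf
from tensorflow.python.util import nest

state = (tf.zeros([2, 3]), (tf.zeros([2, 4]), tf.zeros([2, 5])))
dtype = nest.flatten(state)[0].dtype  # tf.float32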
Example #10
Source File: bridge.py    From tensorflow_end2end_speech_recognition with MIT License
def _create(self):
    # Concat bridge inputs on the depth dimension
    bridge_input = nest.map_structure(
        lambda x: tf.reshape(x, [self.batch_size, _total_tensor_depth(x)]),
        self._bridge_input)
    bridge_input_flat = nest.flatten([bridge_input])
    bridge_input_concat = tf.concat(bridge_input_flat, axis=1)

    state_size_splits = nest.flatten(self.decoder_state_size)
    total_decoder_state_size = sum(state_size_splits)

    # Pass bridge inputs through a fully connected layer
    initial_state_flat = tf.contrib.layers.fully_connected(
        bridge_input_concat,
        num_outputs=total_decoder_state_size,
        activation_fn=self._activation_fn,
        weights_initializer=tf.truncated_normal_initializer(
            stddev=self.parameter_init),
        biases_initializer=tf.zeros_initializer(),
        scope=None)

    # Shape back into required state size
    initial_state = tf.split(initial_state_flat, state_size_splits, axis=1)
    return nest.pack_sequence_as(self.decoder_state_size, initial_state) 
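A worked-size sketch (hypothetical decoder whose state is a (c, h) pair of 256 units each):

from tensorflow.python.util import nest

decoder_state_size = (256, 256)                        # hypothetical (c, h)
state_size_splits = nest.flatten(decoder_state_size)   # [256, 256]
total_decoder_state_size = sum(state_size_splits)      # 512
# The dense layer outputs 512 units; tf.split(..., [256, 256], axis=1)
# cuts it back apart, and pack_sequence_as() restores the (c, h) pair.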
Example #11
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_dataset_resource(self):
    input_resource = self._input_dataset.make_dataset_resource()
    if self._num_threads is None:
      return gen_dataset_ops.map_dataset(
          input_resource,
          self._map_func.captured_inputs,
          f=self._map_func,
          output_types=nest.flatten(self.output_types),
          output_shapes=nest.flatten(self.output_shapes))
    else:
      return gen_dataset_ops.parallel_map_dataset(
          input_resource,
          self._map_func.captured_inputs,
          f=self._map_func,
          num_threads=self._num_threads,
          output_buffer_size=self._output_buffer_size,
          output_types=nest.flatten(self.output_types),
          output_shapes=nest.flatten(self.output_shapes)) 
Example #12
Source File: shapes.py    From Counterfactual-StoryRW with MIT License
def transpose_batch_time(inputs):
    """Transposes inputs between time-major and batch-major.

    Args:
        inputs: A Tensor of shape `[batch_size, max_time, ...]` (batch-major)
            or `[max_time, batch_size, ...]` (time-major), or a (possibly
            nested) tuple of such elements.

    Returns:
        A (possibly nested tuple of) Tensor with transposed batch and
        time dimensions of inputs.
    """
    flat_input = nest.flatten(inputs)
    flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
    # pylint: disable=protected-access
    flat_input = [rnn._transpose_batch_time(input_) for input_ in flat_input]
    return nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input) 
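The transposition applied to each leaf, sketched in NumPy:

import numpy as np

x = np.zeros([32, 10, 128])     # [batch_size, max_time, depth]
y = np.transpose(x, [1, 0, 2])  # [max_time, batch_size, depth]
print(y.shape)                  # (10, 32, 128)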
Example #13
Source File: rnn_decoders.py    From Counterfactual-StoryRW with MIT License
def output_dtype(self):
    """Types of output of one step.
    """
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and the sample_ids_dtype from the helper.
    dtype = nest.flatten(self._initial_state)[0].dtype
    return AttentionRNNDecoderOutput(
        logits=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        sample_id=self._helper.sample_ids_dtype,
        cell_output=nest.map_structure(
            lambda _: dtype, self._cell.output_size),
        attention_scores=nest.map_structure(
            lambda _: dtype, self._alignments_size()),
        attention_context=nest.map_structure(
            lambda _: dtype, self._cell.state_size.attention)) 
Example #14
Source File: fisher_factors.py    From kfac with Apache License 2.0
def _var_scope(self):
    return "ff_fc_multi_" + scope_string_from_params(
        tuple(nest.flatten(self._tensors))
        + (self._num_timesteps, self._has_bias,)) 
Example #15
Source File: fisher_factors.py    From kfac with Apache License 2.0
def _var_scope(self):
    return "ff_convdiag_" + scope_string_from_params(
        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) 
Example #16
Source File: curvature_matrix_vector_products.py    From kfac with Apache License 2.0
def _multiply_jacobian_transpose(self, loss_vecs):
    """Multiply vecs by the transpose Jacobian of losses."""
    loss_vecs_flat = nest.flatten(loss_vecs)
    # We stop gradients at wrt_tensors to produce partial derivatives (which is
    # what we want for Jacobians).
    return tf.gradients(
        self._inputs_to_losses_flat,
        self._wrt_tensors,
        grad_ys=loss_vecs_flat,
        stop_gradients=self._wrt_tensors,
        colocate_gradients_with_ops=self._colocate_gradients_with_ops)

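The grad_ys trick above is the standard way to get a vector-Jacobian product out of reverse-mode autodiff. A minimal sketch (TF 1.x graph mode, hypothetical f):

import tensorflow as tf

x = tf.constant([1.0, 2.0])
y = tf.stack([x[0] * x[1], x[0] + x[1]])  # f: R^2 -> R^2
v = tf.constant([1.0, 1.0])
# tf.gradients(y, x, grad_ys=v) computes J^T v, where J = dy/dx.
jtv = tf.gradients(y, x, grad_ys=v)[0]    # here J^T v = [3.0, 2.0]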
Example #17
Source File: fisher_factors.py    From kfac with Apache License 2.0
def _var_scope(self):
    return "ff_diag_kron_" + scope_string_from_params(
        nest.flatten(self._tensors)) 
Example #18
Source File: fisher_factors.py    From kfac with Apache License 2.0
def _var_scope(self):
    return "ff_diagfc_" + scope_string_from_params(
        tuple(self._inputs) + tuple(nest.flatten(self._outputs_grads))) 
Example #19
Source File: dataset_ops.py    From lambda-packs with MIT License
def __init__(self, input_dataset, key_func, reduce_func, window_size):
    """See `Dataset.group_by_window()` for details."""
    super(GroupByWindowDataset, self).__init__()
    self._input_dataset = input_dataset
    self._window_size = window_size

    @function.Defun(*nest.flatten(input_dataset.output_types))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      if nest.is_sequence(nested_args):
        ret = key_func(*nested_args)
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

    @function.Defun(dtypes.int64, dtypes.resource)
    def tf_reduce_func(key, window_dataset_resource):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      window_dataset = _ResourceDataset(window_dataset_resource,
                                        input_dataset.output_types,
                                        input_dataset.output_shapes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset.make_dataset_resource()

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph()) 
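The wrapper pattern above in miniature: the Defun receives the flattened argument list, and nest.pack_sequence_as() restores the user-visible nested structure before the user's function is called. A sketch with string stand-ins:

from tensorflow.python.util import nest

output_types = {'image': 'float32', 'label': 'int64'}  # nested structure
flat_args = ['arg0', 'arg1']                           # flat Defun args
nested = nest.pack_sequence_as(output_types, flat_args)
# {'image': 'arg0', 'label': 'arg1'}  (dict keys in sorted order)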
Example #20
Source File: control_flow_ops.py    From auto-alt-text-lambda-api with MIT License
def BuildLoop(self, pred, body, loop_vars, shape_invariants):
    """Add the loop termination condition and body to the graph."""

    # Keep original_loop_vars to identify which are TensorArrays
    original_loop_vars = loop_vars
    flat_loop_vars = nest.flatten(loop_vars)
    # Convert TensorArrays to their flow variables
    loop_vars = _convert_tensorarrays_to_flows(flat_loop_vars)
    loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
    try:
      self.Enter()
      original_body_result, exit_vars = self._BuildLoop(
          pred, body, original_loop_vars, loop_vars, shape_invariants)
    finally:
      self.Exit()

    flat_result = nest.flatten(original_body_result)
    # Convert TensorArray flow variables outside the context back into
    # their associated TensorArrays for returning to caller.
    exit_vars_with_tensor_arrays = (
        _convert_flows_to_tensorarrays(flat_result, exit_vars))
    packed_exit_vars = nest.pack_sequence_as(
        structure=original_body_result,
        flat_sequence=exit_vars_with_tensor_arrays)
    return (packed_exit_vars[0] if len(exit_vars) == 1
            else packed_exit_vars) 
Example #21
Source File: rnn_cell_impl.py    From auto-alt-text-lambda-api with MIT License
def zero_state(self, batch_size, dtype):
    """Return zero-filled state tensor(s).

    Args:
      batch_size: int, float, or unit Tensor representing the batch size.
      dtype: the data type to use for the state.

    Returns:
      If `state_size` is an int or TensorShape, then the return value is a
      `N-D` tensor of shape `[batch_size x state_size]` filled with zeros.

      If `state_size` is a nested list or tuple, then the return value is
      a nested list or tuple (of the same structure) of `2-D` tensors with
      the shapes `[batch_size x s]` for each s in `state_size`.
    """
    state_size = self.state_size
    if nest.is_sequence(state_size):
      state_size_flat = nest.flatten(state_size)
      zeros_flat = [
          array_ops.zeros(
              array_ops.stack(_state_size_with_prefix(
                  s, prefix=[batch_size])),
              dtype=dtype) for s in state_size_flat
      ]
      for s, z in zip(state_size_flat, zeros_flat):
        z.set_shape(_state_size_with_prefix(s, prefix=[None]))
      zeros = nest.pack_sequence_as(structure=state_size,
                                    flat_sequence=zeros_flat)
    else:
      zeros_size = _state_size_with_prefix(state_size, prefix=[batch_size])
      zeros = array_ops.zeros(array_ops.stack(zeros_size), dtype=dtype)
      zeros.set_shape(_state_size_with_prefix(state_size, prefix=[None]))

    return zeros 
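A hedged usage sketch (TF 1.x): because of the flatten/pack round-trip, zero_state() mirrors whatever nesting state_size has.

import tensorflow as tf

cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.LSTMCell(64) for _ in range(2)])
state = cell.zero_state(batch_size=32, dtype=tf.float32)
# state is a 2-tuple of LSTMStateTuple(c=<32x64 zeros>, h=<32x64 zeros>)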
Example #22
Source File: rnn.py    From auto-alt-text-lambda-api with MIT License
def _infer_state_dtype(explicit_dtype, state):
  """Infer the dtype of an RNN state.

  Args:
    explicit_dtype: explicitly declared dtype or None.
    state: RNN's hidden state. Must be a Tensor or a nested iterable containing
      Tensors.

  Returns:
    dtype: inferred dtype of hidden state.

  Raises:
    ValueError: if `state` has heterogeneous dtypes or is empty.
  """
  if explicit_dtype is not None:
    return explicit_dtype
  elif nest.is_sequence(state):
    inferred_dtypes = [element.dtype for element in nest.flatten(state)]
    if not inferred_dtypes:
      raise ValueError("Unable to infer dtype from empty state.")
    all_same = all([x == inferred_dtypes[0] for x in inferred_dtypes])
    if not all_same:
      raise ValueError(
          "State has tensors of different inferred_dtypes. Unable to infer a "
          "single representative dtype.")
    return inferred_dtypes[0]
  else:
    return state.dtype 
Example #23
Source File: state_saving_rnn_estimator.py    From lambda-packs with MIT License
def _get_initial_states(cell):
  """Gets the initial state of the `RNNCell` used in the RNN.

  Args:
    cell: A `RNNCell` to be used in the RNN.

  Returns:
    A Python dict mapping state names to the `RNNCell`'s initial state for
    consumption by the SQSS.
  """
  names = nest.flatten(_get_state_names(cell))
  values = nest.flatten(cell.zero_state(1, dtype=dtypes.float32))
  return {n: array_ops.squeeze(v, axis=0) for [n, v] in zip(names, values)} 
Example #24
Source File: basic_decoder.py    From lambda-packs with MIT License
def output_dtype(self):
    # Assume the dtype of the cell is the output_size structure
    # containing the input_state's first component's dtype.
    # Return that structure and int32 (the id)
    dtype = nest.flatten(self._initial_state)[0].dtype
    return BasicDecoderOutput(
        nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        dtypes.int32) 
Example #25
Source File: beam_search_decoder.py    From lambda-packs with MIT License
def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.

  For each tensor t in a (possibly nested structure) of tensors,
  this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of
  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape
  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
  `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated
  `multiplier` times.

  Args:
    t: `Tensor` shaped `[batch_size, ...]`.
    multiplier: Python int.
    name: Name scope for any created operations.

  Returns:
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`.

  Raises:
    ValueError: if tensor(s) `t` do not have a statically known rank or
    the rank is < 1.
  """
  flat_t = nest.flatten(t)
  with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
    return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) 
Example #26
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_dataset_resource(self):
    return gen_dataset_ops.filter_dataset(
        self._input_dataset.make_dataset_resource(),
        other_arguments=self._predicate.captured_inputs,
        predicate=self._predicate,
        output_types=nest.flatten(self.output_types),
        output_shapes=nest.flatten(self.output_shapes)) 
Example #27
Source File: dataset_ops.py    From lambda-packs with MIT License
def __init__(self, input_dataset, predicate):
    """See `Dataset.filter()` for details."""
    super(FilterDataset, self).__init__()
    self._input_dataset = input_dataset

    @function.Defun(*nest.flatten(input_dataset.output_types))
    def tf_predicate(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)):
        arg.set_shape(shape)

      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)

      if nest.is_sequence(nested_args):
        ret = predicate(*nested_args)
      else:
        ret = predicate(nested_args)

      ret = ops.convert_to_tensor(ret, dtype=dtypes.bool)
      if not (ret.dtype == dtypes.bool and
              ret.shape.is_compatible_with(tensor_shape.scalar())):
        raise ValueError("`predicate` must return a scalar boolean tensor.")

      return ret

    self._predicate = tf_predicate
    self._predicate.add_to_graph(ops.get_default_graph()) 
Example #28
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_dataset_resource(self):
    return gen_dataset_ops.flat_map_dataset(
        self._input_dataset.make_dataset_resource(),
        self._map_func.captured_inputs,
        f=self._map_func,
        output_types=nest.flatten(self.output_types),
        output_shapes=nest.flatten(self.output_shapes)) 
Example #29
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_dataset_resource(self):
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset.make_dataset_resource(),
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        output_types=nest.flatten(self.output_types),
        output_shapes=nest.flatten(self.output_shapes)) 
Example #30
Source File: dataset_ops.py    From lambda-packs with MIT License
def make_dataset_resource(self):
    return gen_dataset_ops.padded_batch_dataset(
        self._input_dataset.make_dataset_resource(),
        batch_size=self._batch_size,
        padded_shapes=[
            ops.convert_to_tensor(s, dtype=dtypes.int64)
            for s in nest.flatten(self._padded_shapes)
        ],
        padding_values=nest.flatten(self._padding_values),
        output_shapes=nest.flatten(self.output_shapes))