Python tensorflow.compat.v1.identity() Examples

The following are 30 code examples of tensorflow.compat.v1.identity(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.compat.v1, or try the search function.
Example #1
Source File: signal_conv_test.py    From compression with Apache License 2.0 6 votes vote down vote up
def test_2d_bias_activation(self):
    """Test 2D convolutions with bias and activation.

    Runs `self.run_valid` through `self.run_or_fail` once for every entry of
    `self.data_formats`, using an identity activation and a bias term.
    """
    # Positional arguments that precede `data_format`.
    leading_args = (
        1,        # batch
        (4, 6),   # input_support
        1,        # channels
        1,        # filters
        (2, 2),   # kernel_support
        True,     # corr
        (1, 1),   # strides_down
        (1, 1),   # strides_up
        "valid",  # padding
        True,     # extra_pad_end
        False,    # channel_separable
    )
    # Positional arguments that follow `data_format`: activation, use_bias.
    trailing_args = (tf.identity, True)
    for data_format in self.data_formats:
      call_args = leading_args + (data_format,) + trailing_args
      self.run_or_fail(self.run_valid, *call_args)
Example #2
Source File: aligned.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def infer(self,
            features=None,
            decode_length=1,
            beam_size=1,
            top_beams=1,
            alpha=0.0,
            use_tpu=False):
    """Predict outputs and scores from the model's logits.

    The "inputs" entry is copied into `features["targets"]` before the model
    is called, so the dict passed by the caller is modified in place.

    Args:
      features: dict with an "inputs" entry. Although the default is `None`,
        the dict is indexed immediately, so a real dict has to be supplied.
      decode_length: not read by this method.
      beam_size: not read by this method.
      top_beams: not read by this method.
      alpha: not read by this method.
      use_tpu: not read by this method.

    Returns:
      A dict with the keys "outputs" and "scores", both produced by
      `common_layers.argmax_with_score`.
    """
    features["targets"] = tf.identity(features["inputs"])
    logits, _ = self(features)
    best, best_scores = common_layers.argmax_with_score(
        common_layers.log_prob_from_logits(logits))
    return dict(outputs=best, scores=best_scores)
Example #3
Source File: tf_atari_wrappers.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def simulate(self, action):
    """Steps the wrapped batch environment and refreshes `self._observ`.

    Args:
      action: forwarded unchanged to `self._batch_env.simulate`.

    Returns:
      A `(reward, done)` pair of identity tensors. Both are created under a
      control dependency on the assignment to `self._observ`.
    """
    reward, done = self._batch_env.simulate(action)
    with tf.control_dependencies([reward, done]):
      # Insert a new axis 1 so the fresh observation can be concatenated with
      # the retained history along that axis below.
      new_observ = tf.expand_dims(self._batch_env.observ, axis=1)

      # If we shouldn't stack, i.e. self.history == 1, then just assign
      # new_observ to self._observ and return from here.
      if self.history == 1:
        with tf.control_dependencies([self._observ.assign(new_observ)]):
          return tf.identity(reward), tf.identity(done)

      # If we should stack, then do the required work.
      # Keep entries 1..history-1 along axis 1, i.e. drop the entry at index 0.
      old_observ = tf.gather(
          self._observ.read_value(),
          list(range(1, self.history)),
          axis=1)
      with tf.control_dependencies([new_observ, old_observ]):
        # Append the new observation after the retained ones and store it.
        with tf.control_dependencies([self._observ.assign(
            tf.concat([old_observ, new_observ], axis=1))]):
          return tf.identity(reward), tf.identity(done)
Example #4
Source File: optimize.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise.

  Args:
    loss: loss tensor to which the weight-decay term is added.
    hparams: object read for `weight_decay` and `weight_noise`.
    learning_rate: passed through to `weight_noise`.
    var_list: iterable of variables to consider. Defaults to
      `tf.trainable_variables()` when `None`.

  Returns:
    `loss`, re-created with a control dependency on the weight-noise ops,
    plus the weight-decay loss.
  """
  if var_list is None:
    var_list = tf.trainable_variables()

  # Materialize once so a one-shot iterable (e.g. a generator) is not already
  # exhausted when the noise variables are selected from it.
  decay_vars = list(var_list)
  # Weight noise is restricted to variables whose name contains "/body/".
  noise_vars = [v for v in decay_vars if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  # Make whatever consumes the returned loss also trigger the noise ops.
  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
Example #5
Source File: ppo_learner.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def simulate(self, action):
    """Steps the batch environment, enqueues the transition, updates `_observ`.

    Args:
      action: forwarded to `self._batch_env.simulate` and stored in the
        enqueued record.

    Returns:
      A `(reward, done)` pair of identity tensors created under a control
      dependency on the assignment to `self._observ`.
    """

    # There is subtlety here. We need to collect data
    # obs, action = policy(obs), done, reward = env(obs, action)
    # Thus we need to enqueue data before assigning new observation

    reward, done = self._batch_env.simulate(action)

    with tf.control_dependencies([reward, done]):
      # Record the observation currently stored in self._observ together with
      # the reward/done just produced and the action that was taken.
      enqueue_op = self.speculum.enqueue(
          [self._observ.read_value(), reward, done, action])

    with tf.control_dependencies([enqueue_op]):
      # The assignment is created under a dependency on the enqueue op.
      assign = self._observ.assign(self._batch_env.observ)

    with tf.control_dependencies([assign]):
      return tf.identity(reward), tf.identity(done)
Example #6
Source File: transformer_memory.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def post_attention(self, token, x):
    """Called after self-attention. The memory can be updated here.

    Args:
      token: Data returned by pre_attention, which can be used to carry over
        state related to the current memory operation. Indexed here as
        token[0], token[1] and token[2].
      x: a Tensor of data after self-attention and feed-forward
    Returns:
      a (possibly modified) version of the input x; here an identity of x
      created under control dependencies on the three assignments below.
    """
    # Store the three token entries in previous_segment / previous_vals /
    # previous_bias; the returned identity depends on all three assignments.
    with tf.control_dependencies([
        self.previous_segment.assign(token[0]),
        self.previous_vals.assign(token[1]),
        self.previous_bias.assign(token[2]),
        ]):
      return tf.identity(x)
Example #7
Source File: transformer_memory.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def post_attention(self, token, x):
    """Called after self-attention. The memory can be updated here.

    Args:
      token: Data returned by pre_attention, which can be used to carry over
        state related to the current memory operation. Read here through the
        keys "retrieved_mem", "x" and "access_logits".
      x: a Tensor of data after self-attention and feed-forward
    Returns:
      a (possibly modified) version of the input x; here the sum of a dense
      projection of x and a dense projection of the retrieved memory.
    """
    with tf.variable_scope(self.name + "/post_attention", reuse=tf.AUTO_REUSE):
      depth = common_layers.shape_list(x)[-1]
      actual_batch_size = common_layers.shape_list(x)[0]
      # Keep only the first `actual_batch_size` rows of the retrieved memory.
      memory_output = tf.gather(token["retrieved_mem"],
                                tf.range(actual_batch_size))
      output = tf.add(tf.layers.dense(x, depth, use_bias=False),
                      tf.layers.dense(memory_output, depth))
      with tf.control_dependencies([output]):
        # The write op is created under a dependency on `output`, and the
        # returned identity in turn depends on the write.
        with tf.control_dependencies([
            self.write(token["x"], token["access_logits"])]):
          return tf.identity(output)
Example #8
Source File: generator_utils.py    From tensor2tensor with Apache License 2.0 6 votes vote down vote up
def _finalize(self, _, contents):
    """Structure output and compute segment and position metadata.

    Args:
      _: ignored.
      contents: tensor whose static shape is set here to
        `(self._packed_length, self._num_sequences * 2)`.

    Returns:
      A dict with "contents" (the first `self._num_sequences` columns of
      `contents`), "segment" and "position".
    """

    # The output shape information is lost during the filter; however we can
    # guarantee the shape. (That's the point of this exercise, after all!)
    contents.set_shape((self._packed_length, self._num_sequences * 2))

    # Both the dummy branch of the scan step function and the eviction dataset
    # use vectors of minus one. The cost of this check is negligible and the
    # leakage of such dummy sequences would be difficult to debug downstream.
    check_leaks = tf.assert_none_equal(contents, -tf.ones_like(contents))
    with tf.control_dependencies([check_leaks]):
      # Re-create `contents` so downstream consumers depend on the check.
      contents = tf.identity(contents)

    segment, position = self._compute_auxiliary_structure(contents)
    return {"contents": contents[:, :self._num_sequences],
            "segment": segment, "position": position}
Example #9
Source File: tpu_util.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def write_to_variable(tensor, fail_if_exists=True):
  """Saves a tensor for later retrieval on CPU.

  Args:
    tensor: tensor to store; its name, shape and dtype define the variable.
    fail_if_exists: when True the variable scope is opened with `reuse=False`,
      otherwise with `AUTO_REUSE`.

  Returns:
    An identity copy of `tensor` created under a control dependency on the
    assignment of `tensor` to the variable.
  """
  # Only relevant for debugging.
  debug_name = 'tpu_util__' + tensor.name.split(':')[0]

  reuse = False if fail_if_exists else tf.compat.v1.AUTO_REUSE
  with tf.variable_scope(top_level_scope, reuse=reuse):
    variable = tf.get_variable(
        name=debug_name,
        shape=tensor.shape,
        dtype=tensor.dtype,
        trainable=False,
        use_resource=True)

  # Register the variable under both the original tensor and its copy.
  var_store[tensor] = variable
  with tf.control_dependencies([variable.assign(tensor)]):
    tensor_copy = tf.identity(tensor)
  var_store[tensor_copy] = variable
  return tensor_copy
Example #10
Source File: tpu_util_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def test_resource_variable(self):
    """Checks `maybe_convert_to_variable` on a model using resource variables."""
    with tf.variable_scope('', use_resource=True):
      model_output = self.build_model()
    gamma, gamma_op = self.get_gamma(model_output)
    converted = tpu_util.maybe_convert_to_variable(gamma)

    # The conversion must not simply hand back the tensor it was given.
    self.assertNotEqual(converted, gamma)

    # The converted result has to be driven by the same op as the input
    # tensor.
    self.assertEqual(converted.op, gamma_op)

    # With one extra Identity hop between the variable and the tensor,
    # maybe_convert_to_variable is expected to return the same result.
    via_identity = tf.identity(gamma)
    self.assertEqual(
        tpu_util.maybe_convert_to_variable(via_identity), converted)
Example #11
Source File: depth_to_space_op_handler_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def test_assign_grouping_no_neighbor_groups(self):
    """Checks handler behaviour when the op group dict is empty."""
    # Start from a state in which no op has a group.
    self.op_group_dict = {}

    # Let the handler assign grouping for the depth-to-space op.
    dts_handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    dts_handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # OpSlices must have been looked up for both ops of interest.
    expected_lookups = [mock.call(op) for op in (self.id1_op, self.id2_op)]
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(expected_lookups)

    # Nothing may have been grouped.
    self.mock_op_reg_manager.group_op_slices.assert_not_called()

    # Exactly one process_ops call, carrying only id1_op.
    self.mock_op_reg_manager.process_ops.assert_called_once_with(
        [self.id1_op])
Example #12
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testGetRegularizerForConcatWithNone(self, test_concat, depth):
    """Checks the alive vector when one concat input has None as regularizer.

    Args:
      test_concat: if truthy the concat op is queried, otherwise the op of
        the final add.
      depth: channel count of the constant branch.
    """
    image = tf.constant(0.0, shape=[1, 17, 19, 3])
    conv_branch = layers.conv2d(
        image, 5, [1, 1], padding='SAME', scope='conv2')
    # The constant branch has None as regularizer.
    const_branch = tf.constant(3.0, shape=[1, 17, 19, depth])
    const_branch = tf.add(tf.identity(const_branch), 3.0)
    concat = tf.concat([const_branch, conv_branch], 3)
    output = tf.add(concat, concat, name='output_out')
    queried_op = concat.op if test_concat else output.op

    # Build the manager with a stub source handler registered for Conv2D.
    handlers = self._default_op_handler_dict
    handlers['Conv2D'] = StubConvSourceOpHandler(add_concat_model_stub)
    manager = orm.OpRegularizerManager([output.op], handlers)

    expected_alive = add_concat_model_stub.expected_alive()
    alive = manager.get_regularizer(queried_op).alive_vector
    # The first `depth` entries come from the constant branch.
    self.assertAllEqual([True] * depth, alive[:depth])
    self.assertAllEqual(expected_alive['conv2'], alive[depth:])
Example #13
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testAddN_Duplicates(self):
    """Checks grouping when one tensor is fed to add_n four times."""
    shared = tf.identity(tf.zeros([2, 4, 4, 3]))
    summed = tf.add_n([shared] * 4)
    normalized = layers.batch_norm(summed)

    manager = orm.OpRegularizerManager(
        [normalized.op], op_handler_dict=self._default_op_handler_dict)

    identity_slices = manager.get_op_slices(shared.op)
    self.assertLen(identity_slices, 1)
    grouped_slices = manager.get_op_group(identity_slices[0]).op_slices

    # Each op must have a single slice, and that slice must be in the group
    # of the identity op.
    for op in (shared.op, summed.op, normalized.op):
      slices = manager.get_op_slices(op)
      self.assertLen(slices, 1)
      self.assertIn(slices[0], grouped_slices)
Example #14
Source File: nasnet_model.py    From benchmarks with Apache License 2.0 6 votes vote down vote up
def _build_aux_head(net, end_points, num_classes, hparams, scope):
  """Auxiliary head used for all models across all datasets.

  Args:
    net: input tensor of the head.
    end_points: dict that receives the result under the key 'AuxLogits'.
    num_classes: output size of the final fully connected layer.
    hparams: object whose `data_format` selects which slice of the feature
      map shape is used as kernel size of the second convolution.
    scope: variable scope in which the head is built.
  """
  with tf.variable_scope(scope):
    head = tf.identity(net)
    with tf.variable_scope('aux_logits'):
      head = slim.avg_pool2d(head, [5, 5], stride=3, padding='VALID')
      head = slim.conv2d(head, 128, [1, 1], scope='proj')
      head = tf.nn.relu(slim.batch_norm(head, scope='aux_bn0'))
      # Slice of the feature map shape before the final layer; it depends on
      # the data format.
      feature_shape = head.shape
      if hparams.data_format == 'NHWC':
        kernel_size = feature_shape[1:3]
      else:
        kernel_size = feature_shape[2:4]
      head = slim.conv2d(head, 768, kernel_size, padding='VALID')
      head = tf.nn.relu(slim.batch_norm(head, scope='aux_bn1'))
      head = slim.fully_connected(contrib_layers.flatten(head), num_classes)
      end_points['AuxLogits'] = head
Example #15
Source File: signal_conv_test.py    From compression with Apache License 2.0 6 votes vote down vote up
def test_1d_bias_activation(self):
    """Test 1D convolutions with bias and activation.

    Runs `self.run_valid` through `self.run_or_fail` once for every entry of
    `self.data_formats`, using an identity activation and a bias term.
    """
    # Positional arguments that precede `data_format`.
    leading_args = (
        1,        # batch
        (6,),     # input_support
        1,        # channels
        1,        # filters
        (3,),     # kernel_support
        True,     # corr
        (1,),     # strides_down
        (1,),     # strides_up
        "valid",  # padding
        True,     # extra_pad_end
        False,    # channel_separable
    )
    # Positional arguments that follow `data_format`: activation, use_bias.
    trailing_args = (tf.identity, True)
    for data_format in self.data_formats:
      call_args = leading_args + (data_format,) + trailing_args
      self.run_or_fail(self.run_valid, *call_args)
Example #16
Source File: layers_test.py    From tf-slim with Apache License 2.0 6 votes vote down vote up
def testCreateDropoutWithPlaceholder(self):
    """Checks dropout with a placeholder `is_training` builds a cond op."""
    side = 3
    tf.reset_default_graph()
    with self.cached_session():
      training_flag = array_ops.placeholder(dtype=dtypes.bool, shape=[])
      images = random_ops.random_uniform((5, side, side, 3), seed=1)
      # This verifies that we've inserted cond properly.
      dropped = _layers.dropout(images, is_training=training_flag)
      # In control_flow_v2 the op is called "If" and it is behind
      # identity op. In legacy mode cond we just go by name.
      # Might need to do something more robust here eventually.
      behind_if_op = dropped.op.inputs[0].op.type == 'If'
      is_cond_op = behind_if_op or dropped.op.name == 'Dropout/cond/Merge'
      self.assertTrue(is_cond_op,
                      'Expected cond_op got ' + repr(dropped))
      dropped.get_shape().assert_is_compatible_with(images.get_shape())
Example #17
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testGetSourceSlices(self):
    """Checks `_get_source_slices` for one source and one non-source slice."""
    identity = tf.identity(tf.zeros([2, 4, 4, 10]))

    manager = orm.OpRegularizerManager([])

    # Two OpSlices of sizes 3 and 7 on the same op.
    source_slice = orm.OpSlice(identity.op, orm.Slice(0, 3))
    non_source_slice = orm.OpSlice(identity.op, orm.Slice(3, 7))

    # Only the first slice is registered as a source.
    manager.create_op_group_for_op_slice(source_slice)
    manager.create_op_group_for_op_slice(non_source_slice, is_source=False)

    # The size-3 slice is re-sliced into [1, 2], which are sources; the
    # size-7 slice is re-sliced into [3, 4], which are not.
    resliced_sizes = [1, 2, 3, 4]
    actual_sources = manager._get_source_slices(
        resliced_sizes, [source_slice, non_source_slice])
    self.assertListEqual([True, True, False, False], actual_sources)
Example #18
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testGetOpSlices_CreateNew(self):
    """Checks `get_op_slices` creates an OpSlice when none is stored."""
    identity = tf.identity(tf.zeros([2, 4, 4, 3]))

    # Manager whose OpSlice dictionary is explicitly emptied.
    manager = orm.OpRegularizerManager([])
    manager._op_slice_dict = {}

    created = manager.get_op_slices(identity.op)

    # A single OpSlice on the identity op, starting at 0 with size 3.
    self.assertLen(created, 1)
    only_slice = created[0]
    self.assertEqual(identity.op, only_slice.op)
    self.assertEqual(0, only_slice.slice.start_index)
    self.assertEqual(3, only_slice.slice.size)
Example #19
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testProcessOpsLast_DuplicatesRemoved(self):
    """Checks repeated `process_ops_last` calls do not duplicate ops."""
    batch_norm = layers.batch_norm(tf.zeros([2, 4, 4, 3]))
    first = tf.identity(batch_norm)
    second = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [first.op, second.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([first.op])
    # Queue the same ops twice; the second call must not add duplicates.
    for _ in range(2):
      manager.process_ops_last([second.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    # Expected pop order.
    for expected_op in (first.op, second.op, batch_norm.op):
      self.assertEqual(expected_op, manager._op_deque.pop())
Example #20
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testProcessOpsLast(self):
    """Checks the deque order after `process_ops` and `process_ops_last`."""
    batch_norm = layers.batch_norm(tf.zeros([2, 4, 4, 3]))
    first = tf.identity(batch_norm)
    second = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [first.op, second.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([first.op])
    manager.process_ops_last([second.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    # Expected pop order.
    for expected_op in (first.op, second.op, batch_norm.op):
      self.assertEqual(expected_op, manager._op_deque.pop())
Example #21
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testProcessOps_DuplicatesRemoved(self):
    """Checks repeated `process_ops` calls do not duplicate ops."""
    batch_norm = layers.batch_norm(tf.zeros([2, 4, 4, 3]))
    first = tf.identity(batch_norm)
    second = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [first.op, second.op],
        op_handler_dict=self._default_op_handler_dict)
    # Queue the same ops twice; the second call must not add duplicates.
    for _ in range(2):
      manager.process_ops([first.op, second.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    # Expected pop order.
    for expected_op in (batch_norm.op, second.op, first.op):
      self.assertEqual(expected_op, manager._op_deque.pop())
Example #22
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testInit_AddConcat_AllOps(self):
    """Checks ops outside the path of the output op are not processed."""
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      conv1 = layers.conv2d(
          inputs, num_outputs=10, kernel_size=3, scope='conv1')
      conv2 = layers.conv2d(
          inputs, num_outputs=10, kernel_size=3, scope='conv2')
      summed = conv1 + conv2
      conv3 = layers.conv2d(
          summed, num_outputs=10, kernel_size=3, scope='conv3')
      out = tf.identity(conv3)
      # Side branch that does not feed `out`.
      concat = tf.concat([conv1, conv2], axis=3)
      conv4 = layers.conv2d(
          concat, num_outputs=10, kernel_size=3, scope='conv4')

    manager = orm.OpRegularizerManager(
        [out.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # conv4 and concat are not in the DFS path of out, so the manager must
    # not have processed either of them.
    for unreachable in (conv4, concat):
      self.assertNotIn(unreachable.op, manager.ops)
Example #23
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 6 votes vote down vote up
def testAddN(self):
    """Checks that add_n groups its four identity inputs with batch norm."""
    inputs = tf.zeros([2, 4, 4, 3])
    identities = [tf.identity(inputs) for _ in range(4)]
    summed = tf.add_n(identities)
    normalized = layers.batch_norm(summed)

    manager = orm.OpRegularizerManager(
        [normalized.op], op_handler_dict=self._default_op_handler_dict)

    first_slices = manager.get_op_slices(identities[0].op)
    self.assertLen(first_slices, 1)
    grouped_slices = manager.get_op_group(first_slices[0]).op_slices

    # Each op must have a single slice, and that slice must be in the group
    # of the first identity op.
    ops_to_check = [t.op for t in identities] + [summed.op, normalized.op]
    for op in ops_to_check:
      slices = manager.get_op_slices(op)
      self.assertLen(slices, 1)
      self.assertIn(slices[0], grouped_slices)
Example #24
Source File: batch_allreduce.py    From benchmarks with Apache License 2.0 5 votes vote down vote up
def _add_put_op_control_deps(all_device_tensors, num_splits, put_ops):
  """Add control dependencies from `put_ops` to `all_device_tensors`.

  Intended for use only when deferred tensors are in play. Because every
  returned tensor depends on a put op, running the returned tensors also runs
  the put ops, so the caller does not need to run them explicitly.

  Args:
    all_device_tensors: A list of list of tensors. `all_device_tensors[i][j]` is
      a tensor where `i` is the device index and `j` is the tensor index.
    num_splits: The number of splits that were used for the all-reduce.
    put_ops: A list of put ops from deferring the tensors.
  Returns:
    A list in the same form as `all_device_tensors`, except each tensor has a
    control dependency on an op in `put_ops`.
  """
  def apply_func(tensor, device_index, tensor_index):
    device_put_ops = put_ops[device_index]
    # With zero splits the put op is picked per tensor; otherwise the
    # device's put ops are used as they are.
    deps = ([device_put_ops[tensor_index]] if num_splits == 0
            else device_put_ops)
    assert len(deps) == 1
    with tf.control_dependencies(deps):
      return tf.identity(tensor, name='control_dependency')

  return _apply_to_all_device_tensors(all_device_tensors, apply_func)
Example #25
Source File: svg_decoder_loss.py    From magenta with Apache License 2.0 5 votes vote down vote up
def real_svg_bottom(features, unused_model_hparams, unused_vocab_size):
  """Returns an identity of `features` built in the 'real_bottom' scope."""
  with tf.variable_scope('real_bottom', reuse=tf.AUTO_REUSE):
    bottom = tf.identity(features)
    return bottom
Example #26
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 5 votes vote down vote up
def testPrintOpSlices(self):
    """Checks the string form of a list of OpSlices."""
    inputs = tf.zeros([2, 4, 4, 3])
    first = tf.identity(inputs)
    second = tf.identity(inputs)

    manager = orm.OpRegularizerManager(
        [first.op, second.op],
        op_handler_dict=self._default_op_handler_dict)
    # Concatenate the slices of both ops, first op first.
    combined = (manager.get_op_slices(first.op) +
                manager.get_op_slices(second.op))

    self.assertEqual('[Identity (0, 3), Identity_1 (0, 3)]',
                     str(combined))
Example #27
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 5 votes vote down vote up
def testOpGroup_NewSourceGroup_DuplicateOpSlice(self):
    """Checks merging groups that share an OpSlice keeps it only once."""
    inputs = tf.zeros([2, 4, 4, 3])
    slice_a = orm.OpSlice(tf.identity(inputs).op, None)
    slice_b = orm.OpSlice(tf.identity(inputs).op, None)

    group_a = orm.OpGroup(slice_a)
    # group_b contains group_a and omits slice_b from its sources.
    group_b = orm.OpGroup(
        slice_b, [group_a], omit_source_op_slices=[slice_b])
    merged = orm.OpGroup(op_groups=[group_a, group_b])

    self.assertListEqual([slice_a, slice_b], merged.op_slices)
    self.assertListEqual([slice_a], merged.source_op_slices)
Example #28
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 5 votes vote down vote up
def testOpGroup_NewGroupNoSource(self):
    """Checks an omitted OpSlice is in the group but not among its sources."""
    identity = tf.identity(tf.zeros([2, 4, 4, 3]))
    only_slice = orm.OpSlice(identity.op, None)
    group = orm.OpGroup(only_slice, omit_source_op_slices=[only_slice])

    self.assertListEqual([only_slice], group.op_slices)
    self.assertListEqual([], group.source_op_slices)
Example #29
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 5 votes vote down vote up
def testSliceOpWithSizes(self):
    """Checks `_slice_op_with_sizes` groups source slices only."""
    identity = tf.identity(tf.zeros([2, 4, 4, 10]))

    manager = orm.OpRegularizerManager([])

    sizes = [1, 2, 3, 4]
    is_source = [True, False, True, False]
    is_resliced = [True, True, True, True]
    op_slices = manager._slice_op_with_sizes(identity.op, sizes, is_source,
                                             is_resliced)

    # One OpSlice per requested size.
    self.assertLen(op_slices, 4)

    # Slices requested as sources must be listed as sources of their group;
    # the others must have no group.
    expect_source = (True, False, True, False)
    for op_slice, should_be_source in zip(op_slices, expect_source):
      group = manager.get_op_group(op_slice)
      if should_be_source:
        self.assertIn(op_slice, group.source_op_slices)
      else:
        self.assertIsNone(group)
Example #30
Source File: op_regularizer_manager_test.py    From morph-net with Apache License 2.0 5 votes vote down vote up
def testGetOpSlices(self):
    """Checks `get_op_slices` returns the OpSlice stored for an op."""
    identity = tf.identity(tf.zeros([2, 4, 4, 3]))

    # Manager with one OpSlice stored for the identity op.
    manager = orm.OpRegularizerManager([])
    stored_slice = orm.OpSlice(identity.op, orm.Slice(0, 3))
    manager._op_slice_dict[identity.op] = [stored_slice]

    fetched = manager.get_op_slices(identity.op)

    self.assertLen(fetched, 1)
    self.assertEqual(stored_slice, fetched[0])