"""Tests for op_regularizer_manager."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from absl.testing import parameterized
from morph_net.framework import batch_norm_source_op_handler
from morph_net.framework import concat_op_handler
from morph_net.framework import conv_source_op_handler
from morph_net.framework import depthwise_convolution_op_handler
from morph_net.framework import generic_regularizers
from morph_net.framework import grouping_op_handler
from morph_net.framework import op_regularizer_manager as orm
from morph_net.framework import output_non_passthrough_op_handler
from morph_net.testing import add_concat_model_stub
from morph_net.testing import grouping_concat_model_stub
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.contrib import framework
from tensorflow.contrib import layers

arg_scope = framework.arg_scope

DEBUG_PRINTS = False


def _get_op(name):
  return tf.get_default_graph().get_operation_by_name(name)


class OpRegularizerManagerTest(parameterized.TestCase, tf.test.TestCase):

  def setUp(self):
    super(OpRegularizerManagerTest, self).setUp()
    tf.set_random_seed(12)
    np.random.seed(665544)
    IndexOpRegularizer.reset_index()

    # Create default OpHandler dict for testing.
    self._default_op_handler_dict = collections.defaultdict(
        grouping_op_handler.GroupingOpHandler)
    self._default_op_handler_dict.update({
        'FusedBatchNormV3':
            IndexBatchNormSourceOpHandler(),
        'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
        'DepthwiseConv2dNative':
            depthwise_convolution_op_handler.DepthwiseConvolutionOpHandler(),
    })

  def _batch_norm_scope(self):
    params = {
        'trainable': True,
        'normalizer_fn': layers.batch_norm,
        'normalizer_params': {
            'scale': True
        }
    }

    with arg_scope([layers.conv2d], **params) as sc:
      return sc

  @parameterized.named_parameters(('Batch_no_par1', True, False, 'conv1'),
                                  ('Batch_par1', True, True, 'conv1'),
                                  ('NoBatch_no_par1', False, False, 'conv1'),
                                  ('NoBatch_par2', False, True, 'conv2'),
                                  ('Batch_no_par2', True, False, 'conv2'),
                                  ('Batch_par2', True, True, 'conv2'),
                                  ('Batch_par3', True, True, 'conv3'),
                                  ('NoBatch_par3', False, True, 'conv3'),
                                  ('NoBatch_no_par3', False, False, 'conv3'))
  def testSimpleOpGetRegularizer(self, use_batch_norm, use_partitioner, scope):
    # Tests the alive pattern of the conv and relu ops.
    # use_batch_norm: A Boolean. Indicates if batch norm should be used.
    # use_partitioner: A Boolean. Indicates if a fixed_size_partitioner should
    #   be used.
    # scope: A String with the scope to test.
    sc = self._batch_norm_scope() if use_batch_norm else []
    partitioner = tf.fixed_size_partitioner(2) if use_partitioner else None
    model_stub = add_concat_model_stub
    with arg_scope(sc):
      with tf.variable_scope(tf.get_variable_scope(), partitioner=partitioner):
        final_op = add_concat_model_stub.build_model()

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
        model_stub)
    if not use_batch_norm:
      op_handler_dict['Conv2D'] = StubConvSourceOpHandler(model_stub)
    op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)

    expected_alive = model_stub.expected_alive()
    conv_reg = op_reg_manager.get_regularizer(_get_op(scope + '/Conv2D'))
    self.assertAllEqual(expected_alive[scope], conv_reg.alive_vector)

    relu_reg = op_reg_manager.get_regularizer(_get_op(scope + '/Relu'))
    self.assertAllEqual(expected_alive[scope], relu_reg.alive_vector)

  @parameterized.named_parameters(('Batch_no_par', True, False),
                                  ('Batch_par', True, True),
                                  ('NoBatch_no_par', False, False),
                                  ('NoBatch_par', False, True))
  def testConcatOpGetRegularizer(self, use_batch_norm, use_partitioner):
    sc = self._batch_norm_scope() if use_batch_norm else []
    partitioner = tf.fixed_size_partitioner(2) if use_partitioner else None
    model_stub = add_concat_model_stub
    with arg_scope(sc):
      with tf.variable_scope(tf.get_variable_scope(), partitioner=partitioner):
        final_op = add_concat_model_stub.build_model()

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
        model_stub)
    if not use_batch_norm:
      op_handler_dict['Conv2D'] = StubConvSourceOpHandler(model_stub)
    op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)

    expected_alive = model_stub.expected_alive()
    expected = np.logical_or(expected_alive['conv4'], expected_alive['concat'])
    conv_reg = op_reg_manager.get_regularizer(_get_op('conv4/Conv2D'))
    self.assertAllEqual(expected, conv_reg.alive_vector)

    relu_reg = op_reg_manager.get_regularizer(_get_op('conv4/Relu'))
    self.assertAllEqual(expected, relu_reg.alive_vector)

  @parameterized.named_parameters(
      ('_conv1', 'conv1/Conv2D', 'conv1'),
      ('_conv2', 'conv2/Conv2D', 'conv2'),
      ('_conv3', 'conv3/Conv2D', 'conv3'),
      ('_conv4', 'conv4/Conv2D', 'conv4'),
  )
  def testGroupConcatOpGetRegularizerValues(self, op_name, short_name):
    model_stub = grouping_concat_model_stub
    with arg_scope(self._batch_norm_scope()):
      with tf.variable_scope(tf.get_variable_scope()):
        final_op = model_stub.build_model()

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
        model_stub)

    op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)

    expected_alive = model_stub.expected_alive()
    expected_reg = model_stub.expected_regularization()

    reg = op_reg_manager.get_regularizer(_get_op(op_name))
    self.assertAllEqual(expected_alive[short_name], reg.alive_vector)
    self.assertAllClose(expected_reg[short_name], reg.regularization_vector)

  def testGroupConcatOpGetRegularizerObjects(self):
    model_stub = grouping_concat_model_stub
    with arg_scope(self._batch_norm_scope()):
      with tf.variable_scope(tf.get_variable_scope()):
        final_op = model_stub.build_model()

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['FusedBatchNormV3'] = StubBatchNormSourceOpHandler(
        model_stub)

    op_reg_manager = orm.OpRegularizerManager([final_op], op_handler_dict)
    self.assertEqual(
        op_reg_manager.get_regularizer(_get_op('conv1/Conv2D')),
        op_reg_manager.get_regularizer(_get_op('conv2/Conv2D')))
    self.assertEqual(
        op_reg_manager.get_regularizer(_get_op('conv3/Conv2D')),
        op_reg_manager.get_regularizer(_get_op('conv4/Conv2D')))

  @parameterized.named_parameters(('Concat_5', True, 5),
                                  ('Concat_7', True, 7),
                                  ('Add_6', False, 6))
  def testGetRegularizerForConcatWithNone(self, test_concat, depth):
    image = tf.constant(0.0, shape=[1, 17, 19, 3])
    conv2 = layers.conv2d(image, 5, [1, 1], padding='SAME', scope='conv2')
    other_input = tf.add(
        tf.identity(tf.constant(3.0, shape=[1, 17, 19, depth])), 3.0)
    # other_input has None as regularizer.
    concat = tf.concat([other_input, conv2], 3)
    output = tf.add(concat, concat, name='output_out')
    op = concat.op if test_concat else output.op

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['Conv2D'] = StubConvSourceOpHandler(add_concat_model_stub)
    op_reg_manager = orm.OpRegularizerManager([output.op], op_handler_dict)

    expected_alive = add_concat_model_stub.expected_alive()
    alive = op_reg_manager.get_regularizer(op).alive_vector
    self.assertAllEqual([True] * depth, alive[:depth])
    self.assertAllEqual(expected_alive['conv2'], alive[depth:])

  @parameterized.named_parameters(('add', tf.add),
                                  ('div', tf.divide),
                                  ('mul', tf.multiply),
                                  ('max', tf.maximum),
                                  ('min', tf.minimum),
                                  ('l2', tf.squared_difference))
  def testGroupingOps(self, tested_op):
    th = 0.5
    image = tf.constant(0.5, shape=[1, 17, 19, 3])

    conv1 = layers.conv2d(image, 5, [1, 1], padding='SAME', scope='conv1')
    conv2 = layers.conv2d(image, 5, [1, 1], padding='SAME', scope='conv2')
    res = tested_op(conv1, conv2)

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['Conv2D'] = RandomConvSourceOpHandler(th)
    op_reg_manager = orm.OpRegularizerManager([res.op], op_handler_dict)

    alive = op_reg_manager.get_regularizer(res.op).alive_vector
    conv1_reg = op_reg_manager.get_regularizer(conv1.op).regularization_vector
    conv2_reg = op_reg_manager.get_regularizer(conv2.op).regularization_vector
    with self.session():
      self.assertAllEqual(alive, np.logical_or(conv1_reg.eval() > th,
                                               conv2_reg.eval() > th))

  def testCascadedGrouping(self):
    inputs = tf.zeros([6, 8, 8, 10], name='prev')
    with arg_scope(
        [layers.conv2d, layers.max_pool2d],
        kernel_size=1,
        stride=1,
        padding='SAME'):
      net = layers.conv2d(inputs, 17, scope='conv/input')

      first = layers.conv2d(net, num_outputs=17, scope='conv/first')
      add_0 = tf.add(first, net, 'Add/first')  # So conv/first must be 17.
      second = layers.conv2d(add_0, num_outputs=17, scope='conv/second')
      out = tf.add(net, second, 'Add/second')  # So conv/second must be 17.

    # Instantiate OpRegularizerManager.
    op_handler_dict = self._default_op_handler_dict
    op_handler_dict['Conv2D'] = IndexConvSourceOpHandler()
    op_reg_manager = orm.OpRegularizerManager([out.op], op_handler_dict)

    grouped_names = [
        [op_slice.op.name for op_slice in group.op_slices]
        for group in op_reg_manager._op_group_dict.values()]
    expected = set([
        'conv/second/Conv2D', 'Add/second', 'conv/first/Conv2D',
        'conv/input/Conv2D', 'Add/first'
    ])
    groups = []
    for group in grouped_names:
      filtered = []
      for op_name in group:
        if '/Conv2D' in op_name or 'Add/' in op_name:
          filtered.append(op_name)
      if filtered:
        groups.append(set(filtered))
        if DEBUG_PRINTS:
          print('Group Found = ', filtered)
    self.assertIn(expected, groups)

  def testBroadcast(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      tmp = c1 + c2
      final_op = layers.conv2d(
          tmp, num_outputs=13, kernel_size=3, scope='conv3')

    manager = orm.OpRegularizerManager(
        [final_op.op], self._default_op_handler_dict)

    c1_reg = manager.get_regularizer(_get_op('conv1/Conv2D'))
    c2_reg = manager.get_regularizer(_get_op('conv2/Conv2D'))
    add_reg = manager.get_regularizer(_get_op('add'))
    c3_reg = manager.get_regularizer(_get_op('conv3/Conv2D'))

    expected_c1_reg_size = 1
    self.assertEqual(expected_c1_reg_size, c1_reg.regularization_vector.shape)
    self.assertEqual(10, c2_reg.regularization_vector.shape)
    self.assertEqual(10, add_reg.regularization_vector.shape)
    self.assertEqual(13, c3_reg.regularization_vector.shape)

    c1_slice = manager.get_op_slices(c1.op)[0]
    c1_group = manager.get_op_group(c1_slice)
    c2_slice = manager.get_op_slices(c2.op)[0]
    c2_group = manager.get_op_group(c2_slice)
    add_slice = manager.get_op_slices(tmp.op)[0]
    add_group = manager.get_op_group(add_slice)
    c3_slice = manager.get_op_slices(final_op.op)[0]
    c3_group = manager.get_op_group(c3_slice)

    # Verify all OpSlice grouped with c1 have size 1.
    for op_slice in c1_group.op_slices:
      self.assertEqual(1, op_slice.slice.size)

    # Verify all OpSlice grouped with c2 have size 10.
    for op_slice in c2_group.op_slices:
      self.assertEqual(10, op_slice.slice.size)

    # Verify c1 is not grouped with c2, add, or c3.
    self.assertNotEqual(c1_group, c2_group)
    self.assertNotEqual(c1_group, add_group)
    self.assertNotEqual(c1_group, c3_group)

    # Verify c2 and add are grouped to each other, but not c3.
    self.assertEqual(c2_group, add_group)
    self.assertNotEqual(c2_group, c3_group)

  def testReuse(self):
    inputs = tf.zeros([2, 4, 4, 3])
    num_outputs = 3
    kernel_size = 1
    with arg_scope([layers.conv2d], normalizer_fn=layers.batch_norm):
      with tf.variable_scope('parallel', reuse=tf.AUTO_REUSE):
        mul0 = layers.conv2d(inputs, num_outputs, kernel_size, scope='conv1')
        mul1 = layers.conv2d(inputs, num_outputs, kernel_size,
                             activation_fn=tf.nn.sigmoid, scope='conv2')
        prev1 = np.prod([mul0, mul1])
      with tf.variable_scope('parallel', reuse=tf.AUTO_REUSE):
        mul0_1 = layers.conv2d(prev1, num_outputs, kernel_size, scope='conv1')
        mul1_1 = layers.conv2d(prev1, num_outputs, kernel_size,
                               activation_fn=tf.nn.sigmoid, scope='conv2')
      prev2 = np.prod([mul0_1, mul1_1])
      prev3 = prev2 + 0.0
      # This hack produces the desired grouping due to variable reuse.
      # prev3 = prev2 + 0.0 * (mul0 + mul1 + mul0_1 + mul1_1)

    manager = orm.OpRegularizerManager(
        [prev3.op], self._default_op_handler_dict)

    mul0_reg = manager.get_regularizer(_get_op('parallel/conv1/Conv2D'))
    mul1_reg = manager.get_regularizer(_get_op('parallel/conv2/Conv2D'))
    mul0_1_reg = manager.get_regularizer(_get_op('parallel_1/conv1/Conv2D'))
    mul1_1_reg = manager.get_regularizer(_get_op('parallel_1/conv2/Conv2D'))

    # Check that regularizers were grouped properly.
    self.assertEqual(mul0_reg, mul1_reg)
    self.assertEqual(mul0_1_reg, mul1_1_reg)
    # These regularizers should be grouped due to reused variables.
    # self.assertEqual(mul0_reg, mul0_1_reg)
    # self.assertEqual(mul1_reg, mul1_1_reg)

  def testGather(self):
    gather_index = [5, 6, 7, 8, 9, 0, 1, 2, 3, 4]
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      gather = tf.gather(c1, gather_index, axis=3)

    manager = orm.OpRegularizerManager(
        [gather.op], self._default_op_handler_dict)

    c1_reg = manager.get_regularizer(_get_op('conv1/Conv2D'))
    gather_reg = manager.get_regularizer(_get_op('GatherV2'))

    # Check regularizer indices.
    self.assertAllEqual(list(range(10)), c1_reg.regularization_vector)
    # This fails due to gather not being supported.  Once gather is supported,
    # this test can be enabled to verify that the regularization vector is
    # gathered in the same ordering as the tensor.
    # self.assertAllEqual(
    #     gather_index, gather_reg.regularization_vector)

    # This test shows that gather is not supported.  The regularization vector
    # has the same initial ordering after the gather op scrambled the
    # channels.  Remove this once gather is supported.
    self.assertAllEqual(list(range(10)), gather_reg.regularization_vector)

  def testConcat(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      concat = tf.concat([c1, c2], axis=3)
      tmp = c1 + c2

    manager = orm.OpRegularizerManager(
        [concat.op, tmp.op], self._default_op_handler_dict)

    # Fetch OpSlice to verify grouping.
    inputs_op_slice = manager.get_op_slices(inputs.op)[0]
    c1_op_slice = manager.get_op_slices(c1.op)[0]
    c2_op_slice = manager.get_op_slices(c2.op)[0]
    tmp_op_slice = manager.get_op_slices(tmp.op)[0]
    concat_op_slice0 = manager.get_op_slices(concat.op)[0]
    concat_op_slice1 = manager.get_op_slices(concat.op)[1]

    # Verify inputs and c1 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c1_op_slice))

    # Verify inputs and c2 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c2_op_slice))

    # Verify c1, c2, and add have the same group.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(c2_op_slice))
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(tmp_op_slice))

    # Verify concat slices are grouped with c1, c2, and add.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(concat_op_slice0))
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(concat_op_slice1))

  def testGroupingConcat(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv2')
      concat = tf.concat([c1, c2], axis=2)

    manager = orm.OpRegularizerManager([concat.op],
                                       self._default_op_handler_dict)

    # Fetch OpSlice to verify grouping.
    inputs_op_slice = manager.get_op_slices(inputs.op)[0]
    c1_op_slice = manager.get_op_slices(c1.op)[0]
    c2_op_slice = manager.get_op_slices(c2.op)[0]
    concat_op_slice = manager.get_op_slices(concat.op)[0]

    # Verify inputs and c1 have different group.
    self.assertNotEqual(
        manager.get_op_group(inputs_op_slice),
        manager.get_op_group(c1_op_slice))

    # Verify inputs and c2 have different group.
    self.assertNotEqual(
        manager.get_op_group(inputs_op_slice),
        manager.get_op_group(c2_op_slice))

    # Verify c1, c2, and concat have the same group.
    self.assertEqual(
        manager.get_op_group(c1_op_slice), manager.get_op_group(c2_op_slice))
    self.assertEqual(
        manager.get_op_group(c1_op_slice),
        manager.get_op_group(concat_op_slice))

  def testBatchNormAfterConcat(self):
    inputs = tf.zeros([2, 4, 4, 3])
    # BN before concat - one per conv.
    with arg_scope(
        [layers.conv2d],
        normalizer_fn=layers.batch_norm,
        normalizer_params={'fused': True, 'scale': True}):
      left = layers.conv2d(inputs, 2, kernel_size=3, scope='left')
      right = layers.conv2d(inputs, 3, kernel_size=3, scope='right')
      concat = tf.concat([left, right], -1)

    manager = orm.OpRegularizerManager(
        [concat.op], self._default_op_handler_dict)

    # Fetch OpSlice to verify grouping.
    left_op_slice = manager.get_op_slices(left.op)[0]
    right_op_slice = manager.get_op_slices(right.op)[0]
    concat_op_slice0 = manager.get_op_slices(concat.op)[0]
    concat_op_slice1 = manager.get_op_slices(concat.op)[1]

    # Verify that left op is grouped with left part of concat.
    self.assertEqual(manager.get_op_group(left_op_slice),
                     manager.get_op_group(concat_op_slice0))

    # Verify that right op is grouped with right part of concat.
    self.assertEqual(manager.get_op_group(right_op_slice),
                     manager.get_op_group(concat_op_slice1))

    # BN after concat
    tf.reset_default_graph()
    inputs = tf.zeros([2, 4, 4, 3])
    left = layers.conv2d(inputs, 3, kernel_size=3, scope='left_after')
    right = layers.conv2d(inputs, 4, kernel_size=3, scope='right_after')
    concat = tf.concat([left, right], -1)
    batch_norm = layers.batch_norm(concat, fused=True, scale=True)

    manager = orm.OpRegularizerManager(
        [batch_norm.op], self._default_op_handler_dict)

    # Fetch OpSlice to verify grouping.
    left_op_slice = manager.get_op_slices(left.op)[0]
    right_op_slice = manager.get_op_slices(right.op)[0]
    concat_op_slice0 = manager.get_op_slices(concat.op)[0]
    concat_op_slice1 = manager.get_op_slices(concat.op)[1]
    batch_norm_op_slice0 = manager.get_op_slices(batch_norm.op)[0]
    batch_norm_op_slice1 = manager.get_op_slices(batch_norm.op)[1]

    # Verify that left op is grouped with left part of concat and batch norm.
    self.assertEqual(manager.get_op_group(left_op_slice),
                     manager.get_op_group(concat_op_slice0))
    self.assertEqual(manager.get_op_group(left_op_slice),
                     manager.get_op_group(batch_norm_op_slice0))

    # Verify that right op is grouped with right part of concat and batch norm.
    self.assertEqual(manager.get_op_group(right_op_slice),
                     manager.get_op_group(concat_op_slice1))
    self.assertEqual(manager.get_op_group(right_op_slice),
                     manager.get_op_group(batch_norm_op_slice1))

    # Verify that original concat OpSlice is removed.
    old_concat_op_slice = orm.OpSlice(concat.op, orm.Slice(0, 7))
    self.assertIsNone(manager.get_op_group(old_concat_op_slice))

  def testNestedConcat(self):
    inputs = tf.zeros([2, 4, 4, 3])
    conv1 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv1')
    conv2 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv2')
    conv3 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv3')
    conv4 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv4')
    conv5 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv5')
    conv6 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv6')
    conv7 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv7')
    conv8 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv8')
    conv9 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv9')
    conv10 = layers.conv2d(inputs, num_outputs=1, kernel_size=3, scope='conv10')

    concat1 = tf.concat([conv1, conv2, conv3], axis=3)
    concat2 = tf.concat([conv4, conv5, conv6, conv7], axis=3)
    concat3 = tf.concat([conv8, conv9], axis=3)
    concat4 = tf.concat([conv10], axis=3)

    concat5 = tf.concat([concat1, concat2], axis=3)
    concat6 = tf.concat([concat3, concat4], axis=3)

    # This looks like [[[1, 2, 3], [4, 5, 6, 7]], [[8, 9], [10]]].
    concat7 = tf.concat([concat5, concat6], axis=3)
    batch_norm = layers.batch_norm(concat7)

    manager = orm.OpRegularizerManager(
        [batch_norm.op], self._default_op_handler_dict)

    # Verify that batch norm gets sliced into individual channels due to
    # concatenation of all the convolutions.
    batch_norm_op_slices = manager.get_op_slices(batch_norm.op)
    self.assertLen(batch_norm_op_slices, 10)
    for i in range(10):
      op_slice = batch_norm_op_slices[i]
      self.assertEqual(i, op_slice.slice.start_index)
      self.assertEqual(1, op_slice.slice.size)

      # Verify other OpSlice are not grouped with this one.
      group_op_slices = manager.get_op_group(op_slice).op_slices
      for j in range(10):
        if i != j:
          self.assertNotIn(batch_norm_op_slices[j], group_op_slices)

  def testSplit(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      split = tf.split(c1, [5, 5], axis=3)
      c2 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv2')
      c3 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv3')
      out1 = split[0] + c2
      out2 = split[1] + c3

    with self.assertRaises(RuntimeError):
      # Regularizer assignment fails because c2/c3 have size 5 while split has
      # size 10, so regularizer grouping fails.
      orm.OpRegularizerManager(
          [out1.op, out2.op], self._default_op_handler_dict,
          iteration_limit=100)

  @parameterized.named_parameters(('DepthMultiplier_1', 8, 1),
                                  ('DepthMultiplier_2', 8, 2),
                                  ('DepthMultiplier_7', 8, 7),
                                  ('DepthMultiplier_1_no_pointwise', None, 1),
                                  ('DepthMultiplier_2_no_pointwise', None, 2),
                                  ('DepthMultiplier_7_no_pointwise', None, 7))
  def testSeparableConv2D_DepthMultiplier(
      self, pointwise_outputs, depth_multiplier):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      num_outputs = 5
      c1 = layers.conv2d(
          inputs, num_outputs=num_outputs, kernel_size=3, scope='conv1')
      c2 = layers.separable_conv2d(
          c1, num_outputs=pointwise_outputs, kernel_size=3,
          depth_multiplier=depth_multiplier, scope='conv2')
      identity = tf.identity(c2)

    manager = orm.OpRegularizerManager(
        [identity.op], self._default_op_handler_dict)

    # If separable_conv2d is passed num_outputs=None, the name of the depthwise
    # convolution changes.
    depthwise_conv_name = 'conv2/separable_conv2d/depthwise'
    if pointwise_outputs is None:
      depthwise_conv_name = 'conv2/depthwise'

    dwise_op = _get_op(depthwise_conv_name)
    dwise_reg = manager.get_regularizer(dwise_op)

    # Verify that depthwise convolution has output tensor and regularization
    # vector of size 5 * depth_multiplier channels where 5 is the number of
    # input channels from c1.
    self.assertEqual(num_outputs * depth_multiplier,
                     dwise_op.outputs[0].shape[-1])
    self.assertEqual(num_outputs * depth_multiplier,
                     dwise_reg.regularization_vector.shape[-1])

    # Verify OpSlice in the depthwise convolution has the correct grouping.
    relu1_op_slices = manager.get_op_slices(c1.op)
    dwise_op_slices = manager.get_op_slices(dwise_op)
    relu2_op_slices = manager.get_op_slices(c2.op)

    # Expected input grouping has a pattern like [0, 0, 1, 1, 2, 2, ...].
    # pylint: disable=g-complex-comprehension
    expected_input_grouping = [j
                               for j in range(num_outputs)
                               for i in range(depth_multiplier)]
    # pylint: enable=g-complex-comprehension
    # Expected output grouping is just linear, but with
    # num_outputs * depth_multiplier channels (e.g. [0, 1, 2, 3, ...]).
    expected_output_grouping = range(num_outputs * depth_multiplier)

    for i, op_slice in enumerate(dwise_op_slices):
      group = manager.get_op_group(op_slice)
      group_op_slices = group.op_slices
      self.assertIn(relu1_op_slices[expected_input_grouping[i]],
                    group_op_slices)
      self.assertIn(dwise_op_slices[expected_output_grouping[i]],
                    group_op_slices)
      if pointwise_outputs is None:
        # When pointwise_outputs is None, the pointwise convolution is omitted
        # and the depthwise convolution is immediately followed by
        # BiasAdd -> Relu ops.  In that case, verify that input channels of the
        # depthwise convolution are correctly grouped with the output (relu2)
        # channels.  Otherwise, the depthwise convolution is immediately
        # followed by a pointwise convolution which is non-passthrough, so there
        # is no output grouping to verify.
        self.assertIn(relu2_op_slices[expected_output_grouping[i]],
                      group_op_slices)

  def testAddN(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)
    identity4 = tf.identity(inputs)
    add_n = tf.add_n([identity1, identity2, identity3, identity4])
    batch_norm = layers.batch_norm(add_n)

    manager = orm.OpRegularizerManager(
        [batch_norm.op], op_handler_dict=self._default_op_handler_dict)

    op_slices = manager.get_op_slices(identity1.op)
    self.assertLen(op_slices, 1)
    op_group = manager.get_op_group(op_slices[0]).op_slices

    # Verify all ops are in the same group.
    for test_op in (identity1.op, identity2.op, identity3.op, identity4.op,
                    add_n.op, batch_norm.op):
      test_op_slices = manager.get_op_slices(test_op)
      self.assertLen(test_op_slices, 1)
      self.assertIn(test_op_slices[0], op_group)

  def testAddN_Duplicates(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)
    add_n = tf.add_n([identity, identity, identity, identity])
    batch_norm = layers.batch_norm(add_n)

    manager = orm.OpRegularizerManager(
        [batch_norm.op], op_handler_dict=self._default_op_handler_dict)

    op_slices = manager.get_op_slices(identity.op)
    self.assertLen(op_slices, 1)
    op_group = manager.get_op_group(op_slices[0]).op_slices

    # Verify all ops are in the same group.
    for test_op in (identity.op, add_n.op, batch_norm.op):
      test_op_slices = manager.get_op_slices(test_op)
      self.assertLen(test_op_slices, 1)
      self.assertIn(test_op_slices[0], op_group)

  def testInit_Add(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      add = c1 + c2
      c3 = layers.conv2d(add, num_outputs=10, kernel_size=3, scope='conv3')
      out = tf.identity(c3)

    manager = orm.OpRegularizerManager(
        [out.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # Fetch OpSlice to verify grouping.
    inputs_op_slice = manager.get_op_slices(inputs.op)[0]
    c1_op_slice = manager.get_op_slices(c1.op)[0]
    c2_op_slice = manager.get_op_slices(c2.op)[0]
    add_op_slice = manager.get_op_slices(add.op)[0]
    c3_op_slice = manager.get_op_slices(c3.op)[0]
    out_op_slice = manager.get_op_slices(out.op)[0]

    # Verify inputs and c1 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c1_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c1.op))

    # Verify inputs and c2 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c2_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c2.op))

    # Verify c1, c2, and add have the same group.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(c2_op_slice))
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(add_op_slice))
    self.assertEqual(manager.get_regularizer(c1.op),
                     manager.get_regularizer(c2.op))
    self.assertEqual(manager.get_regularizer(c1.op),
                     manager.get_regularizer(add.op))

    # Verify c3 and out have the same group, which differs from c1 and c2.
    self.assertEqual(manager.get_op_group(c3_op_slice),
                     manager.get_op_group(out_op_slice))
    self.assertNotEqual(manager.get_op_group(c3_op_slice),
                        manager.get_op_group(c1_op_slice))
    self.assertEqual(manager.get_regularizer(c3.op),
                     manager.get_regularizer(out.op))
    self.assertNotEqual(manager.get_regularizer(c3.op),
                        manager.get_regularizer(c1.op))

  def testInit_Concat(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      concat = tf.concat([c1, c2], axis=3)
      out = tf.identity(concat)

    manager = orm.OpRegularizerManager(
        [out.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # Fetch OpSlice to verify grouping.
    inputs_op_slice = manager.get_op_slices(inputs.op)[0]
    c1_op_slice = manager.get_op_slices(c1.op)[0]
    c2_op_slice = manager.get_op_slices(c2.op)[0]
    out_op_slice0 = manager.get_op_slices(out.op)[0]
    out_op_slice1 = manager.get_op_slices(out.op)[1]

    # Verify inputs and c1 have different group and OpRegularizer.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c1_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c1.op))

    # Verify inputs and c2 have different group and OpRegularizer.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c2_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c2.op))

    # Verify c1 and c2 have different group and OpRegularizer.
    self.assertNotEqual(manager.get_op_group(c1_op_slice),
                        manager.get_op_group(c2_op_slice))
    self.assertNotEqual(manager.get_regularizer(c1.op),
                        manager.get_regularizer(c2.op))

    # Verify c1 is grouped with first slice of out.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(out_op_slice0))

    # Verify c2 is grouped with second slice of out.
    self.assertEqual(manager.get_op_group(c2_op_slice),
                     manager.get_op_group(out_op_slice1))

    # Verify out regularization_vector is the concat of c1 and c2
    # regularizertion_vector.
    self.assertAllEqual(
        manager.get_regularizer(c1.op).regularization_vector,
        manager.get_regularizer(out.op).regularization_vector[0:10])
    self.assertAllEqual(
        manager.get_regularizer(c2.op).regularization_vector,
        manager.get_regularizer(out.op).regularization_vector[10:20])

  def testInit_AddConcat(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      add = c1 + c2
      c3 = layers.conv2d(add, num_outputs=10, kernel_size=3, scope='conv3')
      out = tf.identity(c3)
      concat = tf.concat([c1, c2], axis=3)
      c4 = layers.conv2d(concat, num_outputs=10, kernel_size=3, scope='conv4')

    manager = orm.OpRegularizerManager(
        [out.op, c4.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # Fetch OpSlice to verify grouping.
    inputs_op_slice = manager.get_op_slices(inputs.op)[0]
    c1_op_slice = manager.get_op_slices(c1.op)[0]
    c2_op_slice = manager.get_op_slices(c2.op)[0]
    add_op_slice = manager.get_op_slices(add.op)[0]
    c3_op_slice = manager.get_op_slices(c3.op)[0]
    out_op_slice = manager.get_op_slices(out.op)[0]
    concat_op_slice0 = manager.get_op_slices(concat.op)[0]
    concat_op_slice1 = manager.get_op_slices(concat.op)[1]
    c4_op_slice = manager.get_op_slices(c4.op)[0]

    # Verify inputs and c1 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c1_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c1.op))

    # Verify inputs and c2 have different group.
    self.assertNotEqual(manager.get_op_group(inputs_op_slice),
                        manager.get_op_group(c2_op_slice))
    self.assertNotEqual(manager.get_regularizer(inputs.op),
                        manager.get_regularizer(c2.op))

    # Verify c1, c2, and add have the same group.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(c2_op_slice))
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(add_op_slice))
    self.assertEqual(manager.get_regularizer(c1.op),
                     manager.get_regularizer(c2.op))
    self.assertEqual(manager.get_regularizer(c1.op),
                     manager.get_regularizer(add.op))

    # Verify c3 and out have the same group, which differs from c1 and c2.
    self.assertEqual(manager.get_op_group(c3_op_slice),
                     manager.get_op_group(out_op_slice))
    self.assertNotEqual(manager.get_op_group(c3_op_slice),
                        manager.get_op_group(c1_op_slice))
    self.assertEqual(manager.get_regularizer(c3.op),
                     manager.get_regularizer(out.op))
    self.assertNotEqual(manager.get_regularizer(c3.op),
                        manager.get_regularizer(c1.op))

    # Verify concat slices are grouped with c1, c2, and add.
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(concat_op_slice0))
    self.assertEqual(manager.get_op_group(c1_op_slice),
                     manager.get_op_group(concat_op_slice1))

    # Verify concat regularization_vector is the concat of c1 and c2
    # regularizertion_vector.
    self.assertAllEqual(
        manager.get_regularizer(c1.op).regularization_vector,
        manager.get_regularizer(concat.op).regularization_vector[0:10])
    self.assertAllEqual(
        manager.get_regularizer(c2.op).regularization_vector,
        manager.get_regularizer(concat.op).regularization_vector[10:20])

    # Verify c4 has a different group than c1, c2, and add.
    self.assertNotEqual(manager.get_op_group(c1_op_slice),
                        manager.get_op_group(c4_op_slice))
    self.assertNotEqual(manager.get_regularizer(c1.op),
                        manager.get_regularizer(c4.op))

  def testInit_AddConcat_AllOps(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      add = c1 + c2
      c3 = layers.conv2d(add, num_outputs=10, kernel_size=3, scope='conv3')
      out = tf.identity(c3)
      concat = tf.concat([c1, c2], axis=3)
      c4 = layers.conv2d(concat, num_outputs=10, kernel_size=3, scope='conv4')

    manager = orm.OpRegularizerManager(
        [out.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # Op c4 is not in the DFS path of out.  Verify that OpRegularizerManager
    # does not process c4.
    self.assertNotIn(c4.op, manager.ops)
    self.assertNotIn(concat.op, manager.ops)

  def testInit_ForceGroup(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(c1, num_outputs=10, kernel_size=3, scope='conv2')
      c3 = layers.conv2d(c2, num_outputs=10, kernel_size=3, scope='conv3')

    # Initialize OpRegularizerManager with no force-grouping.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer)

    # Verify that c2 is not grouped with c3.
    c2_op_slices = manager.get_op_slices(c2.op)
    self.assertLen(c2_op_slices, 1)
    c2_op_slice = c2_op_slices[0]
    c2_group = manager.get_op_group(c2_op_slice)
    c3_op_slices = manager.get_op_slices(c3.op)
    self.assertLen(c3_op_slices, 1)
    c3_op_slice = c3_op_slices[0]
    self.assertNotIn(c3_op_slice, c2_group.op_slices)

    # Force-group c2 and c3.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer,
        force_group=['conv2|conv3'])

    # Verify that c2 is grouped with c3.
    c2_op_slices = manager.get_op_slices(c2.op)
    self.assertLen(c2_op_slices, 1)
    c2_op_slice = c2_op_slices[0]
    c2_group = manager.get_op_group(c2_op_slice)
    c3_op_slices = manager.get_op_slices(c3.op)
    self.assertLen(c3_op_slices, 1)
    c3_op_slice = c3_op_slices[0]
    self.assertIn(c3_op_slice, c2_group.op_slices)

  def testInit_ForceGroup_MultipleOpSlice(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=5, kernel_size=3, scope='conv2')
      concat = tf.concat([c1, c2], axis=3)
      c3 = layers.conv2d(concat, num_outputs=10, kernel_size=3, scope='conv3')

    # Verify force-group with multiple OpSlice raises ValueError.
    self.assertRaisesRegexp(
        ValueError,
        r'Cannot force-group ops with more than 1 OpSlice: \[u?\'concat\'\]',
        orm.OpRegularizerManager, [c3.op], self._default_op_handler_dict,
        SumGroupingRegularizer, force_group=['conv3|concat'])

  def testInit_ForceGroup_SizeMismatch(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(c1, num_outputs=10, kernel_size=3, scope='conv2')
      # c3 has size 9 instead of 10.
      c3 = layers.conv2d(c2, num_outputs=9, kernel_size=3, scope='conv3')

    # Verify size mismatch raises ValueError.
    self.assertRaisesRegexp(
        ValueError,
        r'Cannot force-group ops with different sizes: \[.*\]',
        orm.OpRegularizerManager, [c3.op], self._default_op_handler_dict,
        SumGroupingRegularizer, force_group=['conv2|conv3'])

  def testInit_ForceGroup_NotList(self):
    inputs = tf.zeros([2, 4, 4, 3])

    # Verify that force_group string instead of a list raises exception.
    self.assertRaisesRegexp(
        TypeError,
        r'force_group must be a list of regex.',
        orm.OpRegularizerManager, [inputs.op], self._default_op_handler_dict,
        SumGroupingRegularizer, force_group='conv')

  def testInit_Blacklist(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=3, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(c1, num_outputs=4, kernel_size=3, scope='conv2')
      c3 = layers.conv2d(c2, num_outputs=5, kernel_size=3, scope='conv3')

    # Verify c2 has a regularizer.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer)
    self.assertIsNotNone(manager.get_regularizer(c2.op))

    # Verify c2 has None regularizer after blacklisting.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer,
        regularizer_blacklist=['conv2'])
    self.assertIsNone(manager.get_regularizer(c2.op))

  def testInit_BlacklistGroup(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      add = c1 + c2
      c3 = layers.conv2d(add, num_outputs=10, kernel_size=3, scope='conv3')

    # Verify c2 has a regularizer.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer)
    self.assertIsNotNone(manager.get_regularizer(c2.op))

    # Verify c2 has None regularizer after blacklisting c1 which is grouped.
    manager = orm.OpRegularizerManager(
        [c3.op], self._default_op_handler_dict, SumGroupingRegularizer,
        regularizer_blacklist=['conv1'])
    self.assertIsNone(manager.get_regularizer(c2.op))

  def testInit_BlacklistGroup_NoMatch(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      add = c1 + c2
      c3 = layers.conv2d(add, num_outputs=10, kernel_size=3, scope='conv3')

    # Verify blacklist regex without match raises ValueError
    self.assertRaisesWithLiteralMatch(
        ValueError,
        'Blacklist regex never used: \'oops\'.',
        orm.OpRegularizerManager, [c3.op], self._default_op_handler_dict,
        SumGroupingRegularizer, regularizer_blacklist=['oops'])

  def testInit_BlacklistGroup_NotList(self):
    inputs = tf.zeros([2, 4, 4, 3])

    # Verify that regularizer_blacklist string instead of a list raises
    # exception.
    self.assertRaisesRegexp(
        TypeError,
        r'regularizer_blacklist must be a list of regex.',
        orm.OpRegularizerManager, [inputs.op], self._default_op_handler_dict,
        SumGroupingRegularizer, regularizer_blacklist='conv')

  def testInit_IterationLimit(self):
    inputs = tf.zeros([2, 4, 4, 3])

    # Verify that reaching iteration limit raises exception.
    self.assertRaisesRegexp(
        RuntimeError,
        r'OpRegularizerManager could not handle ops:',
        orm.OpRegularizerManager, [inputs.op], self._default_op_handler_dict,
        SumGroupingRegularizer, iteration_limit=0)

  def testGetRegularizer(self):
    op1 = tf.zeros([2, 4, 4, 3])
    op2 = tf.zeros([2, 4, 4, 3])
    op3 = tf.zeros([2, 4, 4, 10])

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)
    manager.slice_op(op3.op, [1, 2, 3, 4])

    # op2 has 1 OpSlice and op3 has 4 OpSlice of size [1, 2, 3, 4].
    op2_slices = manager.get_op_slices(op2.op)
    op3_slices = manager.get_op_slices(op3.op)

    op2_reg = IndexOpRegularizer(op2_slices[0], manager)
    op3_reg0 = IndexOpRegularizer(op3_slices[0], manager)
    op3_reg1 = IndexOpRegularizer(op3_slices[1], manager)
    op3_reg2 = IndexOpRegularizer(op3_slices[2], manager)
    op3_reg3 = IndexOpRegularizer(op3_slices[3], manager)

    # Map OpSlice to OpRegularizer.
    manager._op_regularizer_dict = {
        op2_slices[0]: op2_reg,
        op3_slices[0]: op3_reg0,
        op3_slices[1]: op3_reg1,
        op3_slices[2]: op3_reg2,
        op3_slices[3]: op3_reg3,
    }

    # Verify None is returned if OpSlice does not have OpRegularizer.
    self.assertIsNone(manager.get_regularizer(op1.op))

    # Verify OpRegularizer for op with single OpSlice.
    self.assertAllEqual([0, 1, 2],
                        manager.get_regularizer(op2.op).regularization_vector)

    # Verify OpRegularizer for op with multiple OpSlice.
    self.assertAllEqual(list(range(3, 13)),
                        manager.get_regularizer(op3.op).regularization_vector)

    # Verify OpRegularzier for op with multiple OpSlice but not all slices have
    # a regularizer.
    del manager._op_regularizer_dict[op3_slices[2]]
    expected_regularization_vector = [3, 4, 5, 0, 0, 0, 9, 10, 11, 12]
    self.assertAllEqual(expected_regularization_vector,
                        manager.get_regularizer(op3.op).regularization_vector)

  def testCreateOpGroupForOpSlice_Source(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    # Create OpSlice for each identity op.
    op_slice = manager.get_op_slices(identity.op)[0]

    # Create OpGroup for each OpSlice.
    op_group = manager.create_op_group_for_op_slice(op_slice)

    self.assertListEqual([op_slice], op_group.op_slices)
    self.assertListEqual([op_slice], op_group.source_op_slices)
    self.assertEqual(op_group, manager.get_op_group(op_slice))

  def testCreateOpGroupForOpSlice_NotSource(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    # Create OpSlice for each identity op.
    op_slice = manager.get_op_slices(identity.op)[0]

    # Create OpGroup for each OpSlice.
    op_group = manager.create_op_group_for_op_slice(op_slice, is_source=False)

    self.assertListEqual([op_slice], op_group.op_slices)
    self.assertListEqual([], op_group.source_op_slices)
    self.assertEqual(op_group, manager.get_op_group(op_slice))

  def testGroupOpSlices(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)
    identity4 = tf.identity(inputs)
    identity5 = tf.identity(inputs)
    identity6 = tf.identity(inputs)
    identity7 = tf.identity(inputs)
    identity8 = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    # Create OpSlice for each identity op.
    op_slice1 = manager.get_op_slices(identity1.op)[0]
    op_slice2 = manager.get_op_slices(identity2.op)[0]
    op_slice3 = manager.get_op_slices(identity3.op)[0]
    op_slice4 = manager.get_op_slices(identity4.op)[0]
    op_slice5 = manager.get_op_slices(identity5.op)[0]
    op_slice6 = manager.get_op_slices(identity6.op)[0]
    op_slice7 = manager.get_op_slices(identity7.op)[0]
    op_slice8 = manager.get_op_slices(identity8.op)[0]

    # Create OpGroup for each OpSlice.
    op_group1 = manager.create_op_group_for_op_slice(op_slice1)
    op_group2 = manager.create_op_group_for_op_slice(op_slice2)
    op_group3 = manager.create_op_group_for_op_slice(op_slice3)
    op_group4 = manager.create_op_group_for_op_slice(op_slice4)
    op_group5 = manager.create_op_group_for_op_slice(op_slice5)
    op_group6 = manager.create_op_group_for_op_slice(op_slice6)
    op_group7 = manager.create_op_group_for_op_slice(op_slice7)
    op_group8 = manager.create_op_group_for_op_slice(op_slice8)

    # Group all OpGroup together by grouping their OpSlice.
    manager.group_op_slices([op_slice1, op_slice2, op_slice3, op_slice4,
                             op_slice5, op_slice6, op_slice7, op_slice8])

    expected_group = orm.OpGroup(
        op_groups=[op_group1, op_group2, op_group3, op_group4, op_group5,
                   op_group6, op_group7, op_group8])

    # Check all OpSlice are in one big group.
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice1).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice2).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice3).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice4).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice5).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice6).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice7).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice8).op_slices)

  def testGroupOpSlices_TransitiveGrouping(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)
    identity4 = tf.identity(inputs)
    identity5 = tf.identity(inputs)
    identity6 = tf.identity(inputs)
    identity7 = tf.identity(inputs)
    identity8 = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    # Create OpSlice for each identity op.
    op_slice1 = manager.get_op_slices(identity1.op)[0]
    op_slice2 = manager.get_op_slices(identity2.op)[0]
    op_slice3 = manager.get_op_slices(identity3.op)[0]
    op_slice4 = manager.get_op_slices(identity4.op)[0]
    op_slice5 = manager.get_op_slices(identity5.op)[0]
    op_slice6 = manager.get_op_slices(identity6.op)[0]
    op_slice7 = manager.get_op_slices(identity7.op)[0]
    op_slice8 = manager.get_op_slices(identity8.op)[0]

    # Create OpGroup for each OpSlice.
    op_group1 = manager.create_op_group_for_op_slice(op_slice1)
    op_group2 = manager.create_op_group_for_op_slice(op_slice2)
    op_group3 = manager.create_op_group_for_op_slice(op_slice3)
    op_group4 = manager.create_op_group_for_op_slice(op_slice4)
    op_group5 = manager.create_op_group_for_op_slice(op_slice5)
    op_group6 = manager.create_op_group_for_op_slice(op_slice6)
    op_group7 = manager.create_op_group_for_op_slice(op_slice7)
    op_group8 = manager.create_op_group_for_op_slice(op_slice8)

    # Group all OpGroup together by grouping their OpSlice.
    manager.group_op_slices([op_slice1, op_slice2, op_slice3, op_slice4])
    manager.group_op_slices([op_slice5, op_slice6, op_slice7, op_slice8])
    # Transitively create one large group by grouping one OpSlice from each
    # group.
    manager.group_op_slices([op_slice3, op_slice6])

    expected_group = orm.OpGroup(
        op_groups=[op_group1, op_group2, op_group3, op_group4, op_group5,
                   op_group6, op_group7, op_group8])

    # Check all OpSlice are in one big group.
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice1).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice2).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice3).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice4).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice5).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice6).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice7).op_slices)
    self.assertListEqual(
        expected_group.op_slices,
        manager.get_op_group(op_slice8).op_slices)

  def testSliceOp_SingleSlice(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)
    identity4 = tf.identity(inputs)
    identity5 = tf.identity(inputs)
    identity6 = tf.identity(inputs)
    identity7 = tf.identity(inputs)
    identity8 = tf.identity(inputs)

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)

    # Create OpSlice for each identity op.
    op_slice1 = manager.get_op_slices(identity1.op)[0]
    op_slice2 = manager.get_op_slices(identity2.op)[0]
    op_slice3 = manager.get_op_slices(identity3.op)[0]
    op_slice4 = manager.get_op_slices(identity4.op)[0]
    op_slice5 = manager.get_op_slices(identity5.op)[0]
    op_slice6 = manager.get_op_slices(identity6.op)[0]
    op_slice7 = manager.get_op_slices(identity7.op)[0]
    op_slice8 = manager.get_op_slices(identity8.op)[0]

    # Create OpGroup for each OpSlice.
    manager.create_op_group_for_op_slice(op_slice1)
    manager.create_op_group_for_op_slice(op_slice2)
    manager.create_op_group_for_op_slice(op_slice3)
    manager.create_op_group_for_op_slice(op_slice4)
    manager.create_op_group_for_op_slice(op_slice5)
    manager.create_op_group_for_op_slice(op_slice6)
    manager.create_op_group_for_op_slice(op_slice7)
    manager.create_op_group_for_op_slice(op_slice8)

    # Group all OpGroup together by grouping their OpSlice.
    manager.group_op_slices([op_slice1, op_slice2, op_slice3, op_slice4])
    manager.group_op_slices([op_slice5, op_slice6, op_slice7, op_slice8])

    # Only slice identity1 op.  This will also slice identity2, identity3, and
    # identity4 because the slices are grouped.  The ops identity5, identity6,
    # identity7, and identity8 are unaffected.
    manager.slice_op(identity1.op, [1, 2])

    # Verify ops grouped with identity1 are sliced, while other ops are not.
    self.assertLen(manager.get_op_slices(identity1.op), 2)
    self.assertLen(manager.get_op_slices(identity2.op), 2)
    self.assertLen(manager.get_op_slices(identity3.op), 2)
    self.assertLen(manager.get_op_slices(identity4.op), 2)
    self.assertLen(manager.get_op_slices(identity5.op), 1)
    self.assertLen(manager.get_op_slices(identity6.op), 1)
    self.assertLen(manager.get_op_slices(identity7.op), 1)
    self.assertLen(manager.get_op_slices(identity8.op), 1)

    # Verify sliced ops have sizes [1, 2].
    for op in (identity1.op, identity2.op, identity3.op, identity4.op):
      op_slices = manager.get_op_slices(op)
      self.assertEqual(0, op_slices[0].slice.start_index)
      self.assertEqual(1, op_slices[0].slice.size)
      self.assertEqual(1, op_slices[1].slice.start_index)
      self.assertEqual(2, op_slices[1].slice.size)

  def testSliceOp_SingleSlice_Ungrouped(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)

    # Only slice identity1 op which is ungrouped.
    manager.slice_op(identity1.op, [1, 2])

    # Verify identity1 op is sliced.
    self.assertLen(manager.get_op_slices(identity1.op), 2)

    # Verify sliced op has size [1, 2].
    op_slices = manager.get_op_slices(identity1.op)
    self.assertEqual(0, op_slices[0].slice.start_index)
    self.assertEqual(1, op_slices[0].slice.size)
    self.assertEqual(1, op_slices[1].slice.start_index)
    self.assertEqual(2, op_slices[1].slice.size)

  def testSliceOp_MultipleSlices(self):
    inputs = tf.zeros([2, 4, 4, 20])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)

    # First op has sizes [4, 3, 7, 6].
    op_slice1_0_4 = orm.OpSlice(identity1.op, orm.Slice(0, 4))
    op_slice1_4_7 = orm.OpSlice(identity1.op, orm.Slice(4, 3))
    op_slice1_7_14 = orm.OpSlice(identity1.op, orm.Slice(7, 7))
    op_slice1_14_20 = orm.OpSlice(identity1.op, orm.Slice(14, 6))

    # Second op has sizes [3, 7, 10].
    op_slice2_0_3 = orm.OpSlice(identity2.op, orm.Slice(0, 3))
    op_slice2_3_10 = orm.OpSlice(identity2.op, orm.Slice(3, 7))
    op_slice2_10_20 = orm.OpSlice(identity2.op, orm.Slice(10, 10))

    # Third op has sizes [2, 2, 2, 2, 3, 7, 2].
    op_slice3_0_2 = orm.OpSlice(identity3.op, orm.Slice(0, 2))
    op_slice3_2_4 = orm.OpSlice(identity3.op, orm.Slice(2, 2))
    op_slice3_4_6 = orm.OpSlice(identity3.op, orm.Slice(4, 2))
    op_slice3_6_8 = orm.OpSlice(identity3.op, orm.Slice(6, 2))
    op_slice3_8_11 = orm.OpSlice(identity3.op, orm.Slice(8, 3))
    op_slice3_11_18 = orm.OpSlice(identity3.op, orm.Slice(11, 7))
    op_slice3_18_20 = orm.OpSlice(identity3.op, orm.Slice(18, 2))

    manager._op_slice_dict = {
        identity1.op: [op_slice1_0_4, op_slice1_4_7, op_slice1_7_14,
                       op_slice1_14_20],
        identity2.op: [op_slice2_0_3, op_slice2_3_10, op_slice2_10_20],
        identity3.op: [op_slice3_0_2, op_slice3_2_4, op_slice3_4_6,
                       op_slice3_6_8, op_slice3_8_11, op_slice3_11_18,
                       op_slice3_18_20],
    }

    # Only the [3, 7] slices of the ops are grouped.  Only the first op is a
    # source.
    manager.group_op_slices(
        [op_slice1_4_7, op_slice2_0_3, op_slice3_8_11],
        omit_source_op_slices=[op_slice2_0_3, op_slice3_8_11])
    manager.group_op_slices(
        [op_slice1_7_14, op_slice2_3_10, op_slice3_11_18],
        omit_source_op_slices=[op_slice2_3_10, op_slice3_11_18])

    # Slice the [3, 7] grouped slices into [1, 2, 3, 4].
    manager.slice_op(identity1.op, [4, 1, 2, 3, 4, 6])

    # Verify grouped ops are sliced into the correct sizes.
    op_slices1 = manager.get_op_slices(identity1.op)
    op_slices2 = manager.get_op_slices(identity2.op)
    op_slices3 = manager.get_op_slices(identity3.op)

    expected_sizes1 = [4, 1, 2, 3, 4, 6]
    expected_sizes2 = [1, 2, 3, 4, 10]
    expected_sizes3 = [2, 2, 2, 2, 1, 2, 3, 4, 2]

    self.assertListEqual(
        expected_sizes1, [s.slice.size for s in op_slices1])
    self.assertListEqual(
        expected_sizes2, [s.slice.size for s in op_slices2])
    self.assertListEqual(
        expected_sizes3, [s.slice.size for s in op_slices3])

    # Verify new slices are grouped.
    op_slice1_4_5 = orm.OpSlice(identity1.op, orm.Slice(4, 1))
    op_slice1_5_7 = orm.OpSlice(identity1.op, orm.Slice(5, 2))
    op_slice1_7_10 = orm.OpSlice(identity1.op, orm.Slice(7, 3))
    op_slice1_10_14 = orm.OpSlice(identity1.op, orm.Slice(10, 4))

    op_slice2_0_1 = orm.OpSlice(identity2.op, orm.Slice(0, 1))
    op_slice2_1_3 = orm.OpSlice(identity2.op, orm.Slice(1, 2))
    op_slice2_3_6 = orm.OpSlice(identity2.op, orm.Slice(3, 3))
    op_slice2_6_10 = orm.OpSlice(identity2.op, orm.Slice(6, 4))

    op_slice3_8_9 = orm.OpSlice(identity3.op, orm.Slice(8, 1))
    op_slice3_9_11 = orm.OpSlice(identity3.op, orm.Slice(9, 2))
    op_slice3_11_14 = orm.OpSlice(identity3.op, orm.Slice(11, 3))
    op_slice3_14_18 = orm.OpSlice(identity3.op, orm.Slice(14, 4))

    expected_group1 = [op_slice1_4_5, op_slice2_0_1, op_slice3_8_9]
    expected_group2 = [op_slice1_5_7, op_slice2_1_3, op_slice3_9_11]
    expected_group3 = [op_slice1_7_10, op_slice2_3_6, op_slice3_11_14]
    expected_group4 = [op_slice1_10_14, op_slice2_6_10, op_slice3_14_18]

    self.assertListEqual(
        expected_group1, manager.get_op_group(op_slice1_4_5).op_slices)
    self.assertListEqual(
        expected_group1, manager.get_op_group(op_slice2_0_1).op_slices)
    self.assertListEqual(
        expected_group1, manager.get_op_group(op_slice3_8_9).op_slices)

    self.assertListEqual(
        expected_group2, manager.get_op_group(op_slice1_5_7).op_slices)
    self.assertListEqual(
        expected_group2, manager.get_op_group(op_slice2_1_3).op_slices)
    self.assertListEqual(
        expected_group2, manager.get_op_group(op_slice3_9_11).op_slices)

    self.assertListEqual(
        expected_group3, manager.get_op_group(op_slice1_7_10).op_slices)
    self.assertListEqual(
        expected_group3, manager.get_op_group(op_slice2_3_6).op_slices)
    self.assertListEqual(
        expected_group3, manager.get_op_group(op_slice3_11_14).op_slices)

    self.assertListEqual(
        expected_group4, manager.get_op_group(op_slice1_10_14).op_slices)
    self.assertListEqual(
        expected_group4, manager.get_op_group(op_slice2_6_10).op_slices)
    self.assertListEqual(
        expected_group4, manager.get_op_group(op_slice3_14_18).op_slices)

  def testProcessOps(self):
    inputs = tf.zeros([2, 4, 4, 3])
    batch_norm = layers.batch_norm(inputs)
    identity1 = tf.identity(batch_norm)
    identity2 = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [identity1.op, identity2.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([identity1.op, identity2.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    self.assertEqual(batch_norm.op, manager._op_deque.pop())
    self.assertEqual(identity2.op, manager._op_deque.pop())
    self.assertEqual(identity1.op, manager._op_deque.pop())

  def testProcessOps_DuplicatesRemoved(self):
    inputs = tf.zeros([2, 4, 4, 3])
    batch_norm = layers.batch_norm(inputs)
    identity1 = tf.identity(batch_norm)
    identity2 = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [identity1.op, identity2.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([identity1.op, identity2.op, batch_norm.op])
    # Try to process the same ops again.
    manager.process_ops([identity1.op, identity2.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    self.assertEqual(batch_norm.op, manager._op_deque.pop())
    self.assertEqual(identity2.op, manager._op_deque.pop())
    self.assertEqual(identity1.op, manager._op_deque.pop())

  def testProcessOpsLast(self):
    inputs = tf.zeros([2, 4, 4, 3])
    batch_norm = layers.batch_norm(inputs)
    identity1 = tf.identity(batch_norm)
    identity2 = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [identity1.op, identity2.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([identity1.op])
    manager.process_ops_last([identity2.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    self.assertEqual(identity1.op, manager._op_deque.pop())
    self.assertEqual(identity2.op, manager._op_deque.pop())
    self.assertEqual(batch_norm.op, manager._op_deque.pop())

  def testProcessOpsLast_DuplicatesRemoved(self):
    inputs = tf.zeros([2, 4, 4, 3])
    batch_norm = layers.batch_norm(inputs)
    identity1 = tf.identity(batch_norm)
    identity2 = tf.identity(batch_norm)

    manager = orm.OpRegularizerManager(
        [identity1.op, identity2.op],
        op_handler_dict=self._default_op_handler_dict)
    manager.process_ops([identity1.op])
    manager.process_ops_last([identity2.op, batch_norm.op])
    # Try to process the same ops again.
    manager.process_ops_last([identity2.op, batch_norm.op])

    self.assertLen(manager._op_deque, 3)
    self.assertEqual(identity1.op, manager._op_deque.pop())
    self.assertEqual(identity2.op, manager._op_deque.pop())
    self.assertEqual(batch_norm.op, manager._op_deque.pop())

  def testIsSourceOp(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)
    batch_norm = layers.batch_norm(identity)

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)

    self.assertFalse(manager.is_source_op(identity.op))
    self.assertTrue(manager.is_source_op(batch_norm.op))

  def testIsPassthrough(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)
    layers.conv2d(identity, 5, 3, scope='conv1')

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)

    self.assertTrue(manager.is_passthrough(identity.op))
    # TODO(a1): Verify OutputNonPassthrough OpHandler returns False.

  def testGetOpSlices(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)

    # Create OpRegularizerManager with OpSlice mapping.
    manager = orm.OpRegularizerManager([])
    op_slice = orm.OpSlice(identity.op, orm.Slice(0, 3))
    manager._op_slice_dict[identity.op] = [op_slice]

    op_slices = manager.get_op_slices(identity.op)

    self.assertLen(op_slices, 1)
    self.assertEqual(op_slice, op_slices[0])

  def testGetOpSlices_CreateNew(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)

    # Create OpRegularizerManager with empty OpSlice dictionary.
    manager = orm.OpRegularizerManager([])
    manager._op_slice_dict = {}

    op_slices = manager.get_op_slices(identity.op)

    # Verify OpSlice is created correctly.
    self.assertLen(op_slices, 1)
    op_slice = op_slices[0]
    self.assertEqual(identity.op, op_slice.op)
    self.assertEqual(0, op_slice.slice.start_index)
    self.assertEqual(3, op_slice.slice.size)

  def testGetOpSlices_CreateNew_MultipleOutputs(self):
    inputs = tf.zeros([2, 4, 4, 10])
    split = tf.split(inputs, [3, 7], axis=3)
    split_op = split[0].op

    # Create OpRegularizerManager with empty OpSlice dictionary.
    manager = orm.OpRegularizerManager([])
    manager._op_slice_dict = {}

    op_slices = manager.get_op_slices(split_op)

    # Verify OpSlice is created correctly.
    self.assertLen(op_slices, 1)
    op_slice = op_slices[0]
    self.assertEqual(split_op, op_slice.op)
    self.assertEqual(0, op_slice.slice.start_index)
    self.assertEqual(10, op_slice.slice.size)

  def testGetOpSlices_ZeroSize(self):
    constant = tf.constant(123)

    # Create OpRegularizerManager with empty OpSlice dictionary.
    manager = orm.OpRegularizerManager([])
    manager._op_slice_dict = {}

    op_slices = manager.get_op_slices(constant.op)

    # Verify zero-size op has no slices.
    self.assertListEqual([], op_slices)

  def testSliceOpSlice(self):
    inputs = tf.zeros([2, 4, 4, 10])
    identity = tf.identity(inputs)

    op_slice1 = orm.OpSlice(identity.op, orm.Slice(0, 2))
    op_slice2 = orm.OpSlice(identity.op, orm.Slice(2, 6))
    op_slice3 = orm.OpSlice(identity.op, orm.Slice(8, 2))

    manager = orm.OpRegularizerManager([])
    manager._op_slice_dict[identity.op] = [op_slice1, op_slice2, op_slice3]

    # Original op has slice sizes [2, 6, 2].  The middle op is being sliced into
    # [1, 3, 2], so the new slice sizes are [2, 1, 3, 2, 2].
    sizes = [2, 1, 3, 2, 2]
    size_index = 1
    size_count = 3
    new_op_slice_group = [list() for _ in range(size_count)]
    manager._slice_op_slice(op_slice2, sizes, size_index, size_count,
                            new_op_slice_group)

    # Verify new slices are created.
    self.assertLen(new_op_slice_group, size_count)
    for i in range(size_count):
      self.assertLen(new_op_slice_group[i], 1)

    # Verify new slices are correct.
    new_slice1 = new_op_slice_group[0][0]
    self.assertEqual(2, new_slice1.slice.start_index)
    self.assertEqual(1, new_slice1.slice.size)

    new_slice2 = new_op_slice_group[1][0]
    self.assertEqual(3, new_slice2.slice.start_index)
    self.assertEqual(3, new_slice2.slice.size)

    new_slice3 = new_op_slice_group[2][0]
    self.assertEqual(6, new_slice3.slice.start_index)
    self.assertEqual(2, new_slice3.slice.size)

  def testSliceOpWithSizes(self):
    inputs = tf.zeros([2, 4, 4, 10])
    identity = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    sizes = [1, 2, 3, 4]
    is_source = [True, False, True, False]
    is_resliced = [True, True, True, True]
    op_slices = manager._slice_op_with_sizes(identity.op, sizes, is_source,
                                             is_resliced)

    # Verify OpSlice count and whether they are sources.
    self.assertLen(op_slices, 4)

    slice1 = op_slices[0]
    op_group1 = manager.get_op_group(slice1)
    self.assertIn(slice1, op_group1.source_op_slices)

    slice2 = op_slices[1]
    op_group2 = manager.get_op_group(slice2)
    self.assertIsNone(op_group2)

    slice3 = op_slices[2]
    op_group3 = manager.get_op_group(slice3)
    self.assertIn(slice3, op_group3.source_op_slices)

    slice4 = op_slices[3]
    op_group4 = manager.get_op_group(slice4)
    self.assertIsNone(op_group4)

  def testGetSourceSlices(self):
    inputs = tf.zeros([2, 4, 4, 10])
    identity = tf.identity(inputs)

    manager = orm.OpRegularizerManager([])

    # Create OpSlices with size [3, 7].
    identity_slice1 = orm.OpSlice(identity.op, orm.Slice(0, 3))
    identity_slice2 = orm.OpSlice(identity.op, orm.Slice(3, 7))

    # Create OpGroup where only first group has source OpSlice.
    manager.create_op_group_for_op_slice(identity_slice1)
    manager.create_op_group_for_op_slice(identity_slice2,
                                         is_source=False)

    # First slice of size 3 is sliced into [1, 2], so these are sources.  Second
    # slice of size 7 is sliced into [3, 4], which are not sources.
    sizes = [1, 2, 3, 4]
    expected_sources = [True, True, False, False]
    self.assertListEqual(
        expected_sources,
        manager._get_source_slices(sizes, [identity_slice1, identity_slice2]))

  def testDfsForSourceOps(self):
    with arg_scope(self._batch_norm_scope()):
      inputs = tf.zeros([2, 4, 4, 3])
      c1 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv1')
      c2 = layers.conv2d(inputs, num_outputs=10, kernel_size=3, scope='conv2')
      tmp = c1 + c2
      c3 = layers.conv2d(tmp, num_outputs=10, kernel_size=3, scope='conv3')
      out = tf.identity(c3)
      # Extra branch that is not a dependency of out.
      concat = tf.concat([c1, c2], axis=3)
      layers.conv2d(concat, num_outputs=10, kernel_size=3, scope='conv4')

    manager = orm.OpRegularizerManager([], self._default_op_handler_dict)
    manager._dfs_for_source_ops([out.op])

    # Verify source ops were found.
    expected_queue = collections.deque([
        _get_op('conv3/BatchNorm/FusedBatchNormV3'),
        _get_op('conv2/BatchNorm/FusedBatchNormV3'),
        _get_op('conv1/BatchNorm/FusedBatchNormV3')
    ])
    self.assertEqual(expected_queue, manager._op_deque)

    # Verify extra branch was not included.
    self.assertNotIn(
        _get_op('conv4/BatchNorm/FusedBatchNormV3'), manager._op_deque)

  def testOpGroup_NewSourceGroup(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)
    op_slice = orm.OpSlice(identity.op, None)
    op_group = orm.OpGroup(op_slice)

    self.assertListEqual([op_slice], op_group.op_slices)
    self.assertListEqual([op_slice], op_group.source_op_slices)

  def testOpGroup_NewGroupNoSource(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity = tf.identity(inputs)
    op_slice = orm.OpSlice(identity.op, None)
    op_group = orm.OpGroup(op_slice, omit_source_op_slices=[op_slice])

    self.assertListEqual([op_slice], op_group.op_slices)
    self.assertListEqual([], op_group.source_op_slices)

  def testOpGroup_NewSourceGroup_DuplicateOpSlice(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    op_slice1 = orm.OpSlice(identity1.op, None)
    op_slice2 = orm.OpSlice(identity2.op, None)
    op_group1 = orm.OpGroup(op_slice1)
    op_group2 = orm.OpGroup(
        op_slice2, [op_group1], omit_source_op_slices=[op_slice2])
    op_group3 = orm.OpGroup(op_groups=[op_group1, op_group2])

    self.assertListEqual([op_slice1, op_slice2], op_group3.op_slices)
    self.assertListEqual([op_slice1], op_group3.source_op_slices)

  def testOpGroup_MergeGroups(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)
    identity3 = tf.identity(inputs)
    identity4 = tf.identity(inputs)
    identity5 = tf.identity(inputs)
    identity6 = tf.identity(inputs)
    identity7 = tf.identity(inputs)
    identity8 = tf.identity(inputs)

    # Reset OpGroup counter.
    orm.OpGroup._static_index = 0

    # Create OpGroup where only identity3, identity6, and identity7 are sources.
    op_slice1 = orm.OpSlice(identity1.op, None)
    op_group1 = orm.OpGroup(op_slice1, omit_source_op_slices=[op_slice1])
    op_slice2 = orm.OpSlice(identity2.op, None)
    op_group2 = orm.OpGroup(op_slice2, omit_source_op_slices=[op_slice2])
    op_slice3 = orm.OpSlice(identity3.op, None)
    op_group3 = orm.OpGroup(op_slice3)
    op_slice4 = orm.OpSlice(identity4.op, None)
    op_group4 = orm.OpGroup(op_slice4, omit_source_op_slices=[op_slice4])
    op_slice5 = orm.OpSlice(identity5.op, None)
    op_group5 = orm.OpGroup(op_slice5, omit_source_op_slices=[op_slice5])
    op_slice6 = orm.OpSlice(identity6.op, None)
    op_group6 = orm.OpGroup(op_slice6)
    op_slice7 = orm.OpSlice(identity7.op, None)
    op_group7 = orm.OpGroup(op_slice7)
    op_slice8 = orm.OpSlice(identity8.op, None)
    op_group8 = orm.OpGroup(op_slice8, omit_source_op_slices=[op_slice8])

    # Merge group1 and group2 into group9.
    op_group9 = orm.OpGroup(op_groups=[op_group1, op_group2])
    self.assertListEqual([op_slice1, op_slice2], op_group9.op_slices)
    self.assertListEqual([], op_group9.source_op_slices)
    self.assertEqual(8, op_group9._index)  # OpGroup is zero-indexed.

    # Merge group3 and group4 into group10.
    op_group10 = orm.OpGroup(op_groups=[op_group3, op_group4])
    self.assertListEqual([op_slice3, op_slice4], op_group10.op_slices)
    self.assertListEqual([op_slice3], op_group10.source_op_slices)
    self.assertEqual(9, op_group10._index)  # OpGroup is zero-indexed.

    # Merge group5, group6, group7, and group8 into group 11.
    op_group11 = orm.OpGroup(
        op_groups=[op_group5, op_group6, op_group7, op_group8])
    self.assertListEqual(
        [op_slice5, op_slice6, op_slice7, op_slice8], op_group11.op_slices)
    self.assertListEqual([op_slice6, op_slice7], op_group11.source_op_slices)
    self.assertEqual(10, op_group11._index)  # OpGroup is zero-indexed.

    # Merge group9 and group10 into group12.
    op_group12 = orm.OpGroup(op_groups=[op_group9, op_group10])
    self.assertListEqual(
        [op_slice1, op_slice2, op_slice3, op_slice4], op_group12.op_slices)
    self.assertListEqual([op_slice3], op_group12.source_op_slices)
    self.assertEqual(11, op_group12._index)  # OpGroup is zero-indexed.

    # Merge group11 and group12 into group13.
    op_group13 = orm.OpGroup(op_groups=[op_group11, op_group12])
    self.assertListEqual(
        [op_slice5, op_slice6, op_slice7, op_slice8, op_slice1, op_slice2,
         op_slice3, op_slice4],
        op_group13.op_slices)
    self.assertListEqual([op_slice6, op_slice7, op_slice3],
                         op_group13.source_op_slices)
    self.assertEqual(12, op_group13._index)  # OpGroup is zero-indexed.

  def testCorrectSourceOpsWithSkipConnection(self):
    inputs = tf.zeros([2, 4, 4, 3])
    x0 = layers.conv2d(
        inputs, num_outputs=8, kernel_size=3, activation_fn=None, scope='conv0')
    x1 = tf.nn.relu(layers.batch_norm(x0, scale=True, scope='bn0'))
    x1 = layers.conv2d(
        x1, num_outputs=8, kernel_size=3, activation_fn=None, scope='conv1')
    x2 = tf.add_n([x0, x1], name='add')
    final_op = tf.nn.relu(layers.batch_norm(x2, scale=True, scope='bn1'))

    op_handler_dict = self._default_op_handler_dict
    op_reg_manager = orm.OpRegularizerManager([final_op.op], op_handler_dict)

    # All ops are in the same group
    group = list(op_reg_manager._op_group_dict.values())[0]
    source_op_names = [s.op.name for s in group.source_op_slices]
    self.assertSetEqual(set(['bn0/FusedBatchNormV3', 'bn1/FusedBatchNormV3']),
                        set(source_op_names))

  def testPrintOpSlices(self):
    inputs = tf.zeros([2, 4, 4, 3])
    identity1 = tf.identity(inputs)
    identity2 = tf.identity(inputs)

    manager = orm.OpRegularizerManager(
        [identity1.op, identity2.op],
        op_handler_dict=self._default_op_handler_dict)
    op_slices1 = manager.get_op_slices(identity1.op)
    op_slices2 = manager.get_op_slices(identity2.op)
    all_slices = op_slices1 + op_slices2

    self.assertEqual('[Identity (0, 3), Identity_1 (0, 3)]',
                     str(all_slices))


class IndexOpRegularizer(generic_regularizers.OpRegularizer):
  """A test OpRegularizer with a self-incrementing index.

  This class creates a regularizer where the regularization vector contains
  self-incrementing values (e.g. [0, 1, 2, ...]).  The index continues to
  increment as regularizers are created.  This is convenient for testing in
  order to track individual elements of the regularization vector (e.g. gather).

  For example, creating 2 regularizers of size 3 results in
  r1 = [0, 1, 2] and r2 = [3, 4, 5].
  """

  index = 0

  def __init__(self, op_slice, op_reg_manager):
    size = op_slice.slice.size
    self._alive_vector = tf.cast(tf.ones(size), tf.bool)
    self._regularization_vector = tf.constant(
        list(range(IndexOpRegularizer.index, IndexOpRegularizer.index + size)),
        tf.float32)
    IndexOpRegularizer.index += size

  @classmethod
  def reset_index(cls):
    IndexOpRegularizer.index = 0

  @property
  def regularization_vector(self):
    return self._regularization_vector

  @property
  def alive_vector(self):
    return self._alive_vector


class SumGroupingRegularizer(generic_regularizers.OpRegularizer):
  """A regularizer that groups others by summing their regularization values."""

  def __init__(self, regularizers_to_group):
    """Creates an instance.

    Args:
      regularizers_to_group: A list of generic_regularizers.OpRegularizer
        objects.Their regularization_vector (alive_vector) are expected to be of
        the same length.

    Raises:
      ValueError: regularizers_to_group is not of length at least 2.
    """
    if len(regularizers_to_group) < 2:
      raise ValueError('Groups must be of at least size 2.')
    self._regularization_vector = tf.add_n(
        [r.regularization_vector for r in regularizers_to_group])
    self._alive_vector = tf.cast(
        tf.ones(self._regularization_vector.get_shape()[-1]), tf.bool)

  @property
  def regularization_vector(self):
    return self._regularization_vector

  @property
  def alive_vector(self):
    return self._alive_vector


class IndexBatchNormSourceOpHandler(
    batch_norm_source_op_handler.BatchNormSourceOpHandler):
  """An OpHandler that creates OpRegularizer using IndexOpRegularizer.

  A wrapper around BatchNormSourceOpHandler that overrides the
  create_regularizer method to use IndexOpRegularizer for testing.
  """

  def __init__(self):
    super(IndexBatchNormSourceOpHandler, self).__init__(0.0)

  def create_regularizer(self, op_slice):
    return IndexOpRegularizer(op_slice, None)


class StubBatchNormSourceOpHandler(
    batch_norm_source_op_handler.BatchNormSourceOpHandler):
  """An OpHandler that creates OpRegularizer using stub values.

  A wrapper around BatchNormSourceOpHandler that overrides the
  create_regularizer method to use stub values for testing.
  """

  def __init__(self, model_stub):
    super(StubBatchNormSourceOpHandler, self).__init__(0.0)
    self._model_stub = model_stub

  def create_regularizer(self, op_slice):
    return _stub_create_regularizer(op_slice, self._model_stub)


class IndexConvSourceOpHandler(
    conv_source_op_handler.ConvSourceOpHandler):
  """An OpHandler that creates OpRegularizer using IndexOpRegularizer.

  A wrapper around ConvSourceOpHandler that overrides the create_regularizer
  method to use IndexOpRegularizer for testing.
  """

  def __init__(self):
    pass

  def create_regularizer(self, op_slice):
    return IndexOpRegularizer(op_slice, None)


class StubConvSourceOpHandler(conv_source_op_handler.ConvSourceOpHandler):
  """An OpHandler that creates OpRegularizer using stub values.

  A wrapper around ConvSourceOpHandler that overrides the create_regularizer
  method to use stub values for testing.
  """

  def __init__(self, model_stub):
    super(StubConvSourceOpHandler, self).__init__(0.1)
    self._model_stub = model_stub

  def create_regularizer(self, op_slice):
    return _stub_create_regularizer(op_slice, self._model_stub)


class RandomConvSourceOpHandler(
    conv_source_op_handler.ConvSourceOpHandler):
  """An OpHandler that creates OpRegularizer using random values.

  A wrapper around ConvSourceOpHandler that overrides the create_regularizer
  method to use random values for testing.
  """

  def create_regularizer(self, op_slice):
    regularization_vector = np.random.random(op_slice.slice.size)
    return StubOpRegularizer(regularization_vector,
                             regularization_vector > self._threshold)


def _stub_create_regularizer(op_slice, model_stub):
  """Create a StubOpRegularizer for a given OpSlice.

  Args:
    op_slice: A op_regularizer_manager.OpSlice.
    model_stub: Module name where REG_STUB and ALIVE_STUB will be found.

  Returns:
    StubOpRegularizer with stubbed regularization and alive vectors.
  """
  op = op_slice.op
  start_index = op_slice.slice.start_index
  size = op_slice.slice.size
  for key in model_stub.REG_STUB:
    if op.name.startswith(key):
      return StubOpRegularizer(
          model_stub.REG_STUB[key][start_index:start_index + size],
          model_stub.ALIVE_STUB[key][start_index:start_index + size])
  raise ValueError('No regularizer for %s' % op.name)


class StubOpRegularizer(generic_regularizers.OpRegularizer):
  """A test OpRegularizer with configured regularization vectors.

  Regularization values are stored in a dict and keyed on op name prefix.
  """

  def __init__(self, regularization_vector, alive_vector):
    self._regularization_vector = tf.constant(regularization_vector)
    self._alive_vector = tf.constant(alive_vector, dtype=tf.bool)

  @property
  def regularization_vector(self):
    return self._regularization_vector

  @property
  def alive_vector(self):
    return self._alive_vector


if __name__ == '__main__':
  tf.test.main()