"""Tests for depth_to_space_op_handler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import mock

from morph_net.framework import depth_to_space_op_handler
from morph_net.framework import op_regularizer_manager as orm
import tensorflow.compat.v1 as tf


class DepthToSpaceOpHandlerTest(tf.test.TestCase):

  def setUp(self):
    super(DepthToSpaceOpHandlerTest, self).setUp()
    # Test a Identity -> DepthToSpace -> Identity chain of ops.
    inputs = tf.zeros([2, 4, 4, 4])
    id1 = tf.identity(inputs)
    dts = tf.depth_to_space(id1, 2)
    tf.identity(dts)

    g = tf.get_default_graph()

    # Declare OpSlice and OpGroup for ops of interest.
    self.id1_op = g.get_operation_by_name('Identity')
    self.id1_op_slice = orm.OpSlice(self.id1_op, orm.Slice(0, 4))
    self.id1_op_group = orm.OpGroup(self.id1_op_slice,
                                    omit_source_op_slices=[self.id1_op_slice])
    self.id1_op_slice0 = orm.OpSlice(self.id1_op, orm.Slice(0, 1))
    self.id1_op_slice1 = orm.OpSlice(self.id1_op, orm.Slice(1, 1))
    self.id1_op_slice2 = orm.OpSlice(self.id1_op, orm.Slice(2, 1))
    self.id1_op_slice3 = orm.OpSlice(self.id1_op, orm.Slice(3, 1))

    self.dts_op = g.get_operation_by_name('DepthToSpace')
    self.dts_op_slice = orm.OpSlice(self.dts_op, orm.Slice(0, 1))
    self.dts_op_group = orm.OpGroup(self.dts_op_slice,
                                    omit_source_op_slices=[self.dts_op_slice])

    self.id2_op = g.get_operation_by_name('Identity_1')
    self.id2_op_slice = orm.OpSlice(self.id2_op, orm.Slice(0, 1))
    self.id2_op_group = orm.OpGroup(self.id2_op_slice,
                                    omit_source_op_slices=[self.id2_op_slice])

    # Create mock OpRegularizerManager with custom mapping of OpSlice and
    # OpGroup.
    self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)

    self.op_slice_dict = {
        self.id1_op: [self.id1_op_slice],
        self.dts_op: [self.dts_op_slice],
        self.id2_op: [self.id2_op_slice],
    }
    def get_op_slices(op):
      return self.op_slice_dict.get(op)

    def get_op_group(op_slice):
      return self.op_group_dict.get(op_slice)

    self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
    self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
    self.mock_op_reg_manager.is_source_op.return_value = False
    self.mock_op_reg_manager.ops = [self.id1_op, self.dts_op, self.id2_op]

  def test_assign_grouping_no_neighbor_groups(self):
    # No ops have groups.
    self.op_group_dict = {}

    # Call handler to assign grouping.
    handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        [mock.call(self.id1_op),
         mock.call(self.id2_op)])

    # Verify manager does not group.
    self.mock_op_reg_manager.group_op_slices.assert_not_called()

    # Verify manager processes grouping for identity ops.
    self.mock_op_reg_manager.process_ops.assert_called_once_with(
        [self.id1_op])

  def test_assign_grouping_all_inputs_grouped(self):
    # Map ops to slices.
    self.op_slice_dict[self.id1_op] = [
        self.id1_op_slice0,
        self.id1_op_slice1,
        self.id1_op_slice2,
        self.id1_op_slice3]

    # All inputs have groups.
    self.op_group_dict = {
        self.id1_op_slice0: self.id1_op_group,
        self.id1_op_slice1: self.id1_op_group,
        self.id1_op_slice2: self.id1_op_group,
        self.id1_op_slice3: self.id1_op_group,
    }

    # Call handler to assign grouping.
    handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.id1_op),
         mock.call(self.id2_op),
         # Reslicing.
         mock.call(self.id1_op),
         mock.call(self.dts_op),
         mock.call(self.id2_op),
         # Refreshing slice data.
         mock.call(self.dts_op),
         mock.call(self.id1_op)])

    # Verify manager groups DepthToSpace channel with individual input channels.
    self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
        [self.id1_op_slice0, self.id1_op_slice1, self.id1_op_slice2,
         self.id1_op_slice3, self.dts_op_slice])

    # Verify manager processes grouping for identity ops.
    self.mock_op_reg_manager.process_ops.assert_called_once_with([self.id2_op])

  def test_assign_grouping_all_outputs_grouped(self):
    # All outputs have groups.
    self.op_group_dict = {
        self.id2_op_slice: self.id2_op_group,
    }

    # Call handler to assign grouping.
    handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.id1_op),
         mock.call(self.id2_op)])

    # Verify manager does not group.
    self.mock_op_reg_manager.group_op_slices.assert_not_called()

    # Verify manager processes grouping for identity ops.
    self.mock_op_reg_manager.process_ops.assert_called_once_with(
        [self.id1_op])

  def test_assign_grouping_all_neighbors_grouped(self):
    # Map ops to slices.
    self.op_slice_dict[self.id1_op] = [
        self.id1_op_slice0,
        self.id1_op_slice1,
        self.id1_op_slice2,
        self.id1_op_slice3]

    # All neighbors have groups.
    self.op_group_dict = {
        self.id1_op_slice0: self.id1_op_group,
        self.id1_op_slice1: self.id1_op_group,
        self.id1_op_slice2: self.id1_op_group,
        self.id1_op_slice3: self.id1_op_group,
        self.id2_op_slice: self.id2_op_group,
    }

    # Call handler to assign grouping.
    handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.id1_op),
         mock.call(self.id2_op),
         # Reslicing.
         mock.call(self.id1_op),
         mock.call(self.dts_op),
         mock.call(self.id2_op),
         # Refreshing slice data.
         mock.call(self.dts_op),
         mock.call(self.id1_op)])

    # Verify manager groups DepthToSpace channel with individual input channels.
    self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
        [self.id1_op_slice0, self.id1_op_slice1, self.id1_op_slice2,
         self.id1_op_slice3, self.dts_op_slice])

    # Verify manager processes grouping for identity ops.
    self.mock_op_reg_manager.process_ops.assert_not_called()

  def test_assign_grouping_all_neighbors_grouped_same_group(self):
    # Map ops to slices.
    self.op_slice_dict[self.id1_op] = [
        self.id1_op_slice0,
        self.id1_op_slice1,
        self.id1_op_slice2,
        self.id1_op_slice3]

    # All neighbors have the same group.
    self.op_group_dict = {
        self.id1_op_slice0: self.id1_op_group,
        self.id1_op_slice1: self.id1_op_group,
        self.id1_op_slice2: self.id1_op_group,
        self.id1_op_slice3: self.id1_op_group,
        self.id2_op_slice: self.id1_op_group,
    }

    # Call handler to assign grouping.
    handler = depth_to_space_op_handler.DepthToSpaceOpHandler()
    handler.assign_grouping(self.dts_op, self.mock_op_reg_manager)

    # Verify manager looks up OpSlice for ops of interest.
    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
        # Checking for ops to process.
        [mock.call(self.id1_op),
         mock.call(self.id2_op),
         # Reslicing.
         mock.call(self.id1_op),
         mock.call(self.dts_op),
         mock.call(self.id2_op),
         # Refreshing slice data.
         mock.call(self.dts_op),
         mock.call(self.id1_op)])

    # Verify manager groups DepthToSpace channel with individual input channels.
    self.mock_op_reg_manager.group_op_slices.assert_called_once_with(
        [self.id1_op_slice0, self.id1_op_slice1, self.id1_op_slice2,
         self.id1_op_slice3, self.dts_op_slice])

    # Verify manager processes grouping for identity ops.
    self.mock_op_reg_manager.process_ops.assert_not_called()


if __name__ == '__main__':
  tf.test.main()