"""Tests for network_regularizers.cost_calculator."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from absl.testing import parameterized
from morph_net.framework import batch_norm_source_op_handler
from morph_net.framework import concat_op_handler
from morph_net.framework import grouping_op_handler
from morph_net.framework import op_regularizer_manager as orm
from morph_net.framework import output_non_passthrough_op_handler
from morph_net.network_regularizers import cost_calculator as cc
from morph_net.network_regularizers import resource_function
from morph_net.testing import add_concat_model_stub
import tensorflow.compat.v1 as tf

from tensorflow.contrib import layers
from tensorflow.contrib.framework import arg_scope


class CostCalculatorTest(parameterized.TestCase, tf.test.TestCase):

  def _batch_norm_scope(self):
    params = {
        'trainable': True,
        'normalizer_fn': layers.batch_norm,
        'normalizer_params': {
            'scale': True
        }
    }

    # Wrap conv2d so that every convolution is followed by a scaled batch norm;
    # the batch norm gammas act as the regularization sources in these tests.
    with arg_scope([layers.conv2d], **params) as sc:
      return sc

  def testImageIsNotZerothOutputOfOp(self):
    # Throughout the framework, we assume that the 0th output of each op is the
    # only one of interest. One exception that often happens is when the input
    # image comes from a queue or from a staging op. Then the image is one of
    # the outputs of the dequeue (or staging) op, not necessarily the 0th one.
    # Here we test that the CostCalculator deals correctly with this case.

    # Create an input op where the image is output number 1, not 0.
    # TODO(g1) Move this mechanism to add_concat_model_stub, possibly using
    # tf.split (instead of FIFOQueue) to produce an op where the image is not
    # the 0th output.
    image = add_concat_model_stub.image_stub()
    non_image_tensor = tf.zeros(shape=(41,))
    queue = tf.FIFOQueue(
        capacity=1,
        dtypes=(tf.float32,) * 2,
        shapes=(non_image_tensor.shape, image.shape))

    # Pass the image (output[1]) to the network.
    with arg_scope(self._batch_norm_scope()):
      output_op = add_concat_model_stub.build_model(queue.dequeue()[1])

    # Create OpHandler dict for test.
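    # Ops with no explicit entry fall back to GroupingOpHandler. Roughly:
    # FusedBatchNormV3 acts as the regularization source (its gammas determine
    # which activations are alive), Conv2D does not pass aliveness from its
    # output to its input, and ConcatV2 concatenates channel groups.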
    op_handler_dict = collections.defaultdict(
        grouping_op_handler.GroupingOpHandler)
    op_handler_dict.update({
        'FusedBatchNormV3':
            batch_norm_source_op_handler.BatchNormSourceOpHandler(0.1),
        'Conv2D':
            output_non_passthrough_op_handler.OutputNonPassthroughOpHandler(),
        'ConcatV2':
            concat_op_handler.ConcatOpHandler(),
    })

    # Create OpRegularizerManager and NetworkRegularizer for test.
    manager = orm.OpRegularizerManager([output_op], op_handler_dict)
    calculator = cc.CostCalculator(manager, resource_function.flop_function)

    # Calculate expected FLOP cost.
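    # Assuming flop_coeff(op) gives the FLOPs contributed by each
    # (input channel, output channel) pair of the convolution, the expected
    # cost is coeff * num_input_channels * num_alive_output_channels, with the
    # stub image providing 3 input channels.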
    expected_alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])
    conv1_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
    conv1_coeff = resource_function.flop_coeff(conv1_op)
    num_channels = 3
    expected_cost = conv1_coeff * num_channels * expected_alive_conv1

    with self.session():
      tf.global_variables_initializer().run()
      # Set gamma values to replicate aliveness in add_concat_model_stub.
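      # Assuming the 0.1 passed to BatchNormSourceOpHandler is the gamma
      # aliveness threshold, channels assigned gamma 0 below are treated as
      # dead, so these patterns mirror add_concat_model_stub.expected_alive().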
      name_to_var = {v.op.name: v for v in tf.global_variables()}
      gamma1 = name_to_var['conv1/BatchNorm/gamma']
      gamma1.assign([0, 1, 1, 0, 1, 0, 1]).eval()
      gamma4 = name_to_var['conv4/BatchNorm/gamma']
      gamma4.assign([0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0]).eval()

      queue.enqueue((non_image_tensor, image)).run()
      self.assertEqual(expected_cost,
                       calculator.get_cost([conv1_op]).eval())
      # For 0/1 assignments, cost and reg_term are equal:
      self.assertEqual(expected_cost,
                       calculator.get_regularization_term([conv1_op]).eval())

  @parameterized.named_parameters(
      ('_conv2d', 4, lambda x: layers.conv2d(x, 16, 3), 'Conv2D'),
      ('_convt', 4, lambda x: layers.conv2d_transpose(x, 16, 3),
       'conv2d_transpose'),
      ('_conv2s', 4, lambda x: layers.separable_conv2d(x, None, 3),
       'depthwise'),
      ('_conv3d', 5, lambda x: layers.conv3d(x, 16, 3), 'Conv3D'))
  def test_get_input_activation2(self, rank, fn, op_name):
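    # Build a single layer under scope 'test_layer' and check that
    # get_input_activation recovers the `inputs` tensor feeding the named conv
    # op, across several convolution variants and input ranks.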
    g = tf.get_default_graph()
    inputs = tf.zeros([6] * rank)
    with arg_scope([
        layers.conv2d, layers.conv2d_transpose, layers.separable_conv2d,
        layers.conv3d
    ],
                   scope='test_layer'):
      _ = fn(inputs)
    # Print op names to aid debugging if the lookup below fails.
    for op in g.get_operations():
      print(op.name)
    self.assertEqual(
        inputs,
        cc.get_input_activation(
            g.get_operation_by_name('test_layer/' + op_name)))


if __name__ == '__main__':
  tf.test.main()