# Lint as: python3
"""Tests for morph_net.framework.tpu_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
from morph_net.framework import tpu_util
import tensorflow.compat.v1 as tf
from tensorflow.contrib import slim


class TpuUtilTest(parameterized.TestCase, tf.test.TestCase):
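  """Tests for tpu_util.maybe_convert_to_variable and write_to_variable.

  The model under test is a slim conv2d with a fused batch norm; the tests
  walk the resulting graph to locate the batch-norm gamma tensor and check
  how tpu_util resolves it back to the variable that drives it.
  """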

  def build_model(self):
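    """Builds conv2d -> fused batch norm -> ReLU under scope 'conv1'."""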
    return slim.conv2d(
        tf.zeros([64, 10, 10, 3]),
        32, [5, 5],
        scope='conv1',
        trainable=True,
        normalizer_fn=slim.batch_norm,
        normalizer_params={
            'scale': True,
            'fused': True
        })
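
  # For reference, a sketch of the op chain build_model produces (op names
  # assume TF1-style fused batch norm; the exact FusedBatchNorm version
  # suffix may differ across TF releases):
  #
  #   conv1/Conv2D -> conv1/BatchNorm/FusedBatchNormV3 -> conv1/Relu
  #
  # The FusedBatchNorm op consumes the conv output plus gamma (scale), beta
  # (offset), and the moving mean and variance, in that input order.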

  def get_gamma(self, activation_tensor):
    # The input to the activation op is a FusedBatchNorm tensor; its name
    # should be conv1/BatchNorm/FusedBatchNormV3:0, but the version suffix
    # may change across TF releases.
    batch_norm_tensor, = activation_tensor.op.inputs
    assert 'FusedBatchNorm' in batch_norm_tensor.name

    # The gamma tensor is the one used by the MorphNet regularizer. It is:
    # conv1/BatchNorm/ReadVariableOp:0 for a ResourceVariable,
    # conv1/BatchNorm/gamma/read:0 for a VariableV2.
    (unused_input_tensor, gamma_tensor, unused_beta_tensor,
     unused_population_mean_tensor,
     unused_population_variance_tensor) = batch_norm_tensor.op.inputs

    # gamma_source_op is the op that drives the value of gamma:
    # 'conv1/BatchNorm/gamma' of type VarHandleOp for a ResourceVariable,
    # 'conv1/BatchNorm/gamma' of type VariableV2 for a VariableV2.
    gamma_source_op = gamma_tensor.op.inputs[0].op

    return gamma_tensor, gamma_source_op
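
  # As exercised by the tests below (inferred from this test rather than a
  # formal API spec), maybe_convert_to_variable is expected to:
  #   * return VariableV2 reads and non-variable tensors unchanged, and
  #   * resolve a ResourceVariable read (a ReadVariableOp output) to the
  #     underlying variable, looking through an intervening Identity op.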

  def test_variable_v2(self):
    with tf.variable_scope('', use_resource=False):
      relu = self.build_model()
    gamma_tensor, _ = self.get_gamma(relu)
    # Check that maybe_convert_to_variable ignores VariableV2 reads
    # (i.e., it is a no-op for them).
    self.assertEqual(
        tpu_util.maybe_convert_to_variable(gamma_tensor), gamma_tensor)

  def test_resource_variable(self):
    with tf.variable_scope('', use_resource=True):
      relu = self.build_model()
    gamma_tensor, gamma_source_op = self.get_gamma(relu)
    variable = tpu_util.maybe_convert_to_variable(gamma_tensor)

    # First, assert that the original tensor was not returned.
    self.assertNotEqual(variable, gamma_tensor)

    # Now check that the variable created by maybe_convert_to_variable is
    # driven by the same op as the tensor passed as input.
    self.assertEqual(variable.op, gamma_source_op)

    # If the input tensor is separated from its variable by an extra Identity
    # hop, maybe_convert_to_variable looks through the Identity op.
    identity_tensor = tf.identity(gamma_tensor)
    self.assertEqual(
        tpu_util.maybe_convert_to_variable(identity_tensor), variable)

  def test_noop(self):
    with tf.variable_scope('', use_resource=True):
      relu = self.build_model()
    # Check that tensors that are not variable reads are returned unchanged.
    self.assertEqual(tpu_util.maybe_convert_to_variable(relu), relu)

  def test_write_to_variable(self):
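    # Semantics exercised below (inferred from the assertions rather than a
    # formal API contract): write_to_variable copies the value of a tensor
    # into a newly created variable and returns a tensor that reads from it;
    # tpu_util.var_store maps each such tensor to its backing variable.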
    foo = tf.constant(0., name='foo')
    tpu_util.write_to_variable(foo)

    with self.assertRaises(ValueError):
      tpu_util.write_to_variable(foo, fail_if_exists=True)

    # Variable sharing behavior is dictated by `fail_if_exists`, which
    # overrides the `reuse` setting of any enclosing variable scope.
    with tf.variable_scope('', reuse=True):
      # Should fail rather than return the existing variable, even though
      # reuse=True.
      with self.assertRaises(ValueError):
        tpu_util.write_to_variable(foo, fail_if_exists=True)

    with tf.variable_scope('', reuse=False):
      # Should return the existing variable even though reuse=False.
      foo_copy = tpu_util.write_to_variable(foo, fail_if_exists=False)
      self.assertEqual(tpu_util.var_store[foo], tpu_util.var_store[foo_copy])
      self.assertLen(set(tpu_util.var_store.values()), 1)

    with tf.variable_scope('', reuse=True):
      # Should create a new variable even though reuse=True.
      bar = tf.constant(0., name='bar')
      tpu_util.write_to_variable(bar)
      self.assertLen(set(tpu_util.var_store.values()), 2)


if __name__ == '__main__':
  tf.test.main()