# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tensorflow_transform.graph_tools."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

# GOOGLE-INITIALIZATION

import six

import tensorflow as tf
from tensorflow_transform import graph_tools
from tensorflow_transform import test_case

from tensorflow.python.ops import control_flow_ops  # pylint: disable=g-direct-tensorflow-import

# Alias for the `mock` module exposed via `tf.compat.v1.test`; used below to
# build `mock.call(...)` expectations.
mock = tf.compat.v1.test.mock


def _create_lookup_table_from_file(filename):
  """Builds a string -> int64 `StaticHashTable` backed by a text file.

  Every whole line of the file becomes a key, mapped to its line number.
  Lookups of keys that are not in the file return -1.

  Args:
    filename: A string tensor holding the path of the file to read.

  Returns:
    A `tf.lookup.StaticHashTable`.
  """
  return tf.lookup.StaticHashTable(
      tf.lookup.TextFileInitializer(
          filename,
          key_dtype=tf.string,
          key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
          value_dtype=tf.int64,
          value_index=tf.lookup.TextFileIndex.LINE_NUMBER),
      default_value=-1)


def _create_graph_with_y_function_of_x():
  """Builds `y = x + 1` for an int64 placeholder `x`.

  Returns:
    A dict mapping 'x' to the placeholder and 'y' to `x + 1`.
  """
  x = tf.compat.v1.placeholder(tf.int64)
  # The dict literal is evaluated left to right, so the ops are created in the
  # same order as when `y` is a separate statement.
  return {'x': x, 'y': x + 1}


def _create_graph_with_tf_function():
  """Builds a graph that calls a three-output `tf.function` twice.

  Returns:
    A dict with placeholders 'x' and 'y', plus three derived entries:
    - 'z' = a + b + c, where (a, b, c) = foo(x, y).
    - 'r' = a alone (x * 2). It reads only `x`, but it is produced by a call
      that also takes `y`.
    - 'q', the first output of a second call foo(x, x), which is fed only
      `x`.
  """
  x = tf.compat.v1.placeholder(tf.int64)
  y = tf.compat.v1.placeholder(tf.int64)

  @tf.function
  def foo(x, y):
    return x * 2, x + y, y * 2

  a, b, c = foo(x, y)
  return {'x': x, 'y': y, 'z': a + b + c, 'r': a, 'q': foo(x, x)[0]}


def _create_graph_with_placeholder_in_tf_function():
  """Builds a graph whose `tf.function` creates its own placeholder.

  Returns:
    A dict with three entries:
    - 'x', the outer placeholder.
    - 'y' = foo(x + 1)[0] + 1.
    - 'z', the second output of `foo`, which returns the placeholder `a`
      created inside the function body.
  """
  x = tf.compat.v1.placeholder(tf.int64)

  @tf.function
  def foo(x):
    # A placeholder created inside the function graph rather than outside it.
    a = tf.compat.v1.placeholder(tf.int64)
    return x * a, a

  y, a = foo(x + 1)
  return {'x': x, 'y': y + 1, 'z': a}


def _create_graph_with_mixed_dependencies():
  """Builds a graph mixing a `tf.function` output with an outer input.

  Returns:
    A dict with placeholders 'x' and 'y', and 'z' = foo(x) + y. Here `foo`
    only receives `x`; `y` is combined outside the function.
  """
  x = tf.compat.v1.placeholder(tf.int64)
  y = tf.compat.v1.placeholder(tf.int64)

  @tf.function
  def foo(x):
    return x * 2

  return {'x': x, 'y': y, 'z': foo(x) + y}


def _create_graph_with_chained_tf_function():
  """Builds a graph where one `tf.function` (`foo`) calls another (`goo`).

  Returns:
    A dict with placeholder 'x' and 'y' = foo(x) / 2, where
    foo(x) = goo(x) * 2 and goo(x) = x + 1.
  """
  x = tf.compat.v1.placeholder(tf.int64)

  @tf.function
  def goo(x):
    return x + 1

  @tf.function
  def foo(x):
    return goo(x) * 2

  return {'x': x, 'y': foo(x) / 2}


def _create_graph_with_y_function_of_x_with_unused_inputs():
  """Builds a graph with three int64 inputs, one of which is never used.

  Returns:
    A dict with placeholders 'x', 'x2' and 'x_unused', plus 'y' = x + 1 and
    'z' = x2 + 2. Nothing reads 'x_unused'.
  """
  # Placeholders are created in the order x, x2, x_unused.
  x, x2, x_unused = [tf.compat.v1.placeholder(tf.int64) for _ in range(3)]
  return {'x': x, 'x2': x2, 'x_unused': x_unused, 'y': x + 1, 'z': x2 + 2}


def _create_graph_with_y_function_of_x_sparse():
  """Builds a dense `y` from a sparse int64 placeholder `x`.

  Returns:
    A dict mapping 'x' to the sparse placeholder and 'y' to the sum of its
    values plus 1.
  """
  x = tf.compat.v1.sparse_placeholder(tf.int64)
  return {'x': x, 'y': tf.sparse.reduce_sum(x) + 1}


def _create_graph_with_z_function_of_x_ragged():
  """Builds a graph that converts a ragged placeholder to sparse and reduces it.

  Returns:
    A dict with three entries:
    - 'x', a ragged int64 placeholder (`2` is passed as its ragged rank).
    - 'y', `x` converted to a SparseTensor.
    - 'z', the sum of `y`'s values plus 1.
  """
  x = tf.compat.v1.ragged.placeholder(tf.int64, 2)
  y = x.to_sparse()
  z = tf.sparse.reduce_sum(y) + 1
  return {'x': x, 'y': y, 'z': z}


def _create_graph_with_ragged_tensor():
  """Builds RaggedTensors from a dense and a sparse placeholder.

  Returns:
    A dict with four entries:
    - 'x1', a dense (1, 3, 3) int64 placeholder.
    - 'x2', a sparse (4, 3) int64 placeholder.
    - 'y1', a RaggedTensor built from `x1` with ragged_rank=2.
    - 'y2', a RaggedTensor built from `x2`, plus 1.
  """
  x1 = tf.compat.v1.placeholder(tf.int64, (1, 3, 3))
  x2 = tf.compat.v1.sparse.placeholder(tf.int64, (4, 3))
  return {
      'x1': x1,
      'x2': x2,
      'y1': tf.RaggedTensor.from_tensor(x1, ragged_rank=2),
      'y2': tf.RaggedTensor.from_sparse(x2) + 1,
  }


def _create_graph_with_y_sparse_function_of_x_sparse():
  """Builds a sparse `y` derived from a sparse placeholder `x`.

  Returns:
    A dict with three entries:
    - 'x', the sparse placeholder.
    - 'y', a SparseTensor with `x`'s indices and dense_shape, whose values
      are `x.values + 1`.
    - 'z', `y` added to a dense tensor of ones shaped like `y.dense_shape`.
  """
  x = tf.compat.v1.sparse_placeholder(tf.int64)
  # Reuse x's sparsity pattern; only the values change.
  y = tf.SparseTensor(
      indices=x.indices,
      values=x.values + 1,
      dense_shape=x.dense_shape)
  return {
      'x': x,
      'y': y,
      'z': tf.compat.v1.sparse.add(y, tf.ones(y.dense_shape, tf.int64))
  }


def _create_graph_with_y_function_of_x_and_table():
  """Builds a table lookup whose table is initialized from a fed filename.

  The table initializer depends on the 'filename' placeholder, and the
  lookup depends on the 'x' placeholder.

  Returns:
    A dict with three entries:
    - 'filename', a scalar string placeholder.
    - 'x', a string vector placeholder.
    - 'y', the table lookup of `x`.
  """
  filename = tf.raw_ops.Placeholder(dtype=tf.string, shape=())
  table = _create_lookup_table_from_file(filename)
  x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,))
  y = table.lookup(x)
  return {'filename': filename, 'x': x, 'y': y}


def _create_graph_with_y_function_of_x_and_table_in_first_phase():
  """Builds a table lookup whose table filename is a constant.

  The table initializer depends on no placeholder. Only the lookup key 'x'
  is a placeholder.

  Returns:
    A dict mapping 'x' (string vector placeholder) and 'y' (lookup of `x`).
  """
  table = _create_lookup_table_from_file(tf.constant('not_a_file_name_but_ok'))
  x = tf.raw_ops.Placeholder(dtype=tf.string, shape=(None,))
  y = table.lookup(x)
  return {'x': x, 'y': y}


def _create_graph_with_y_function_of_x_and_untracked_table():
  """Builds a table lookup whose initializer is dropped from its collection.

  After the table is built, the graph's TABLE_INITIALIZERS collection is
  emptied. The table is still used by the lookup, but its initializer is no
  longer listed there.

  Returns:
    A dict with 'filename' (scalar string placeholder), 'x' (string vector
    placeholder) and 'y' (lookup of `x`).
  """
  filename = tf.compat.v1.placeholder(tf.string, ())
  table = _create_lookup_table_from_file(filename)

  x = tf.compat.v1.placeholder(tf.string, (None,))
  y = table.lookup(x)
  # Clear the collection in place (via its ref) so the initializer is untracked.
  del tf.compat.v1.get_collection_ref(
      tf.compat.v1.GraphKeys.TABLE_INITIALIZERS)[:]
  return {'filename': filename, 'x': x, 'y': y}


def _create_graph_with_table_initialized_by_table_output():
  """Builds two tables where the second table's keys come from the first.

  `table1` is initialized from the 'filename' placeholder. Its lookup results
  (converted to strings) become the keys of `table2`. The final output looks
  up 'x' in `table2`.

  Returns:
    A dict with 'filename' (scalar string placeholder), 'x' (string vector
    placeholder) and 'y' (lookup of `x` in `table2`).
  """
  filename = tf.compat.v1.placeholder(tf.string, ())
  table1 = _create_lookup_table_from_file(filename)

  # Use output from the first table to initialize the second table.
  keys = ['a', 'b', 'c']
  tensor_keys = tf.as_string(
      table1.lookup(tf.constant(keys, tf.string)))
  initializer2 = tf.lookup.KeyValueTensorInitializer(
      keys=tensor_keys,
      values=tf.range(len(keys), dtype=tf.int64),
      key_dtype=tf.string,
      value_dtype=tf.int64)
  table2 = tf.lookup.StaticHashTable(initializer2, default_value=-1)
  x = tf.compat.v1.placeholder(tf.string, (None,))
  y = table2.lookup(x)
  return {'filename': filename, 'x': x, 'y': y}


def _create_graph_with_assert_equal():
  """Builds `z = x` gated by a control dependency on `Assert(x == y)`.

  Returns:
    A dict with raw int64 placeholders 'x' and 'y', and 'z'. 'z' is produced
    by `control_flow_ops.with_dependencies` from `x`, with the Assert op as a
    dependency.
  """
  x = tf.raw_ops.Placeholder(dtype=tf.int64)
  y = tf.raw_ops.Placeholder(dtype=tf.int64)
  z = control_flow_ops.with_dependencies(
      [tf.raw_ops.Assert(condition=tf.raw_ops.Equal(x=x, y=y), data=[x, y])], x)
  return {'x': x, 'y': y, 'z': z}


def _create_graph_with_y_function_of_x_with_tf_while():
  """Builds `y = x - 10` using a raw `While` op.

  The loop condition and body are `tf.function`s with fixed input signatures
  (int32 counter, int64 value). Their concrete functions are passed to
  `tf.raw_ops.While`. The counter runs from 0 to 10, and the value is
  decremented once per iteration.

  Returns:
    A dict with 'x' (scalar int64 placeholder) and 'y' (second loop output).
  """
  x = tf.raw_ops.Placeholder(dtype=tf.int64, shape=())

  # Subtract 10 from x using a raw While op (not tf.while_loop).
  @tf.function(input_signature=[
      tf.TensorSpec([], tf.int32),
      tf.TensorSpec([], tf.int64)
  ])
  def stop_condition(counter, x_minus_counter):
    del x_minus_counter  # unused
    return tf.less(counter, 10)

  @tf.function(input_signature=[
      tf.TensorSpec([], tf.int32),
      tf.TensorSpec([], tf.int64)
  ])
  def iteration(counter, x_minus_counter):
    return tf.add(counter, 1), tf.add(x_minus_counter, -1)
  initial_values = [tf.constant(0), x]
  final_values = tf.raw_ops.While(
      cond=stop_condition.get_concrete_function(),
      body=iteration.get_concrete_function(),
      input=initial_values)

  # final_values[0] is the counter; final_values[1] is x decremented 10 times.
  y = final_values[1]
  return {'x': x, 'y': y}


def _create_graph_with_tf_function_while():
  """Builds a graph with a Python `while` loop inside a `tf.function`.

  Returns:
    A dict with 'x' (scalar float32 placeholder) and 'y'. 'y' is `x` doubled
    repeatedly while it is less than 100.
  """
  x = tf.raw_ops.Placeholder(dtype=tf.float32, shape=())

  @tf.function
  def larger_than_100(x):
    while x < 100:
      x *= 2
    return x

  return {'x': x, 'y': larger_than_100(x)}


@six.add_metaclass(abc.ABCMeta)
class _Matcher(object):
  """Base class for objects that compare equal to matching TF graph objects.

  A matcher is equal to `other` when two conditions hold:
  - `other` is an instance of `expected_class`.
  - The fields extracted by `expected_fields(other)` equal
    `expected_fields_values`, either as-is or after `_future_proof` is
    applied to them.

  This lets matchers stand in for graph objects inside `mock.call(...)`
  expectations.
  """

  def _future_proof(self, value):
    """Applies the `new_to_old` string replacements to string `value`.

    `new_to_old` is currently empty, so this returns `value` unchanged.
    NOTE(review): presumably a hook for tolerating renamed graph objects.

    Args:
      value: Any value; only string-like values are rewritten.

    Returns:
      `value`, with replacements applied if it is a string.
    """
    if isinstance(value, (six.text_type, str, bytes)):
      new_to_old = {}
      for new, old in new_to_old.items():
        value = value.replace(new, old)
    return value

  @abc.abstractmethod
  def expected_fields(self, other):
    """Returns a tuple of the fields of `other` to compare."""
    raise NotImplementedError

  @abc.abstractproperty
  def expected_fields_values(self):
    """The tuple of field values this matcher expects."""
    raise NotImplementedError

  @abc.abstractproperty
  def expected_class(self):
    """The class (or tuple of classes) a matching object must be."""
    raise NotImplementedError

  def __eq__(self, other):
    if not isinstance(other, self.expected_class):
      tf.compat.v1.logging.error('Types do not match, got: %s, expected: %s',
                                 type(other), self.expected_class)
      return False

    # Compute the actual fields once; used for comparison and logging.
    actual_fields = self.expected_fields(other)
    future_expected_fields = tuple(
        self._future_proof(f) for f in self.expected_fields_values)
    if (self.expected_fields_values != actual_fields and
        future_expected_fields != actual_fields):
      tf.compat.v1.logging.error('Fields do not match: %s != %s',
                                 self.expected_fields_values,
                                 actual_fields)
      return False

    return True

  def __ne__(self, other):
    # Without this, tuple-based subclasses (e.g. namedtuple matchers) would
    # resolve `!=` to tuple.__ne__, which ignores the custom __eq__ above.
    return not self == other


class _TensorMatcher(_Matcher, collections.namedtuple('_TensorMatcher',
                                                      ['name'])):
  """Compares equal to any `tf.Tensor` whose name equals `name`."""

  def expected_fields(self, other):
    tensor_name = str(other.name)
    return (tensor_name,)

  @property
  def expected_fields_values(self):
    # `name` is the only namedtuple field, so this equals tuple(self).
    return (self.name,)

  @property
  def expected_class(self):
    return tf.Tensor


class _OpMatcher(_Matcher, collections.namedtuple('_OpMatcher', ['name'])):
  """Compares equal to any `tf.Operation` whose name equals `name`."""

  def expected_fields(self, other):
    op_name = str(other.name)
    return (op_name,)

  @property
  def expected_fields_values(self):
    # `name` is the only namedtuple field, so this equals tuple(self).
    return (self.name,)

  @property
  def expected_class(self):
    return tf.Operation


class GraphToolsTest(test_case.TransformTestCase):
  """Tests for InitializableGraphAnalyzer and get_dependent_inputs."""

  @test_case.named_parameters(
      dict(
          testcase_name='_y_function_of_x_nothing_ready',
          create_graph_fn=_create_graph_with_y_function_of_x,
          feeds=[],
          replaced_tensors_ready={'x': False},
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_unused_input_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs,
          feeds=[],
          replaced_tensors_ready={
              'x': False,
              'x2': True,
              'x_unused': True
          },
          should_be_ready={
              'y': False,
              'z': True
          },
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_no_feeds_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x,
          feeds=[],
          replaced_tensors_ready={'x': True},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_feeds_x_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x,
          feeds=['x'],
          replaced_tensors_ready={},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_sparse_nothing_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_sparse,
          feeds=[],
          replaced_tensors_ready={'x': False},
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_sparse_no_feeds_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_sparse,
          feeds=[],
          replaced_tensors_ready={'x': True},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_sparse_feeds_x_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_sparse,
          feeds=['x'],
          replaced_tensors_ready={},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_sparse_function_of_x_sparse_nothing_ready',
          create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse,
          feeds=[],
          replaced_tensors_ready={'x': False},
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_sparse_function_of_x_sparse_no_feeds_y_is_ready',
          create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse,
          feeds=[],
          replaced_tensors_ready={'x': True},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_sparse_function_of_x_sparse_feeds_x_y_is_ready',
          create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse,
          feeds=['x'],
          replaced_tensors_ready={},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_with_tf_while_nothing_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
          feeds=[],
          replaced_tensors_ready={'x': False},
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_with_tf_while_no_feeds_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
          feeds=[],
          replaced_tensors_ready={'x': True},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_with_tf_while_feeds_x_y_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
          feeds=['x'],
          replaced_tensors_ready={},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_and_table_nothing_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=[],
          replaced_tensors_ready={
              'x': False,
              'filename': False
          },
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_and_table_filename_ready_y_is_not',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=[],
          replaced_tensors_ready={
              'x': False,
              'filename': True
          },
          should_be_ready={'y': False},
          num_ready_table_initializers=1),
      dict(
          testcase_name='_y_function_of_x_and_table_x_ready_filename_is_not',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=[],
          replaced_tensors_ready={
              'x': True,
              'filename': False
          },
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_and_table_everything_is_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=[],
          replaced_tensors_ready={
              'x': True,
              'filename': True,
          },
          should_be_ready={'y': True},
          num_ready_table_initializers=1),
      dict(
          testcase_name='_y_function_of_x_and_table_feeds_x_nothing_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=['x'],
          replaced_tensors_ready={'filename': False},
          should_be_ready={'y': False},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_y_function_of_x_and_table_feeds_x_everything_ready',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table,
          feeds=['x'],
          replaced_tensors_ready={'filename': True},
          should_be_ready={'y': True},
          num_ready_table_initializers=1),
      dict(
          testcase_name='_assert_equal',
          create_graph_fn=_create_graph_with_assert_equal,
          feeds=['x', 'y'],
          replaced_tensors_ready={
              'x': True,
              'y': True,
          },
          should_be_ready={'z': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_tf_function',
          create_graph_fn=_create_graph_with_tf_function,
          feeds=['x', 'y'],
          replaced_tensors_ready={},
          should_be_ready={'z': True},
          num_ready_table_initializers=0),
      dict(
          testcase_name='_tf_function_not_ready',
          create_graph_fn=_create_graph_with_tf_function,
          feeds=[],
          replaced_tensors_ready={
              'x': True,
              'y': False,
          },
          should_be_ready={
              'z': False,
              'q': True,
          },
          num_ready_table_initializers=0),
      dict(
          testcase_name='_chained_tf_function',
          create_graph_fn=_create_graph_with_chained_tf_function,
          feeds=['x'],
          replaced_tensors_ready={},
          should_be_ready={'y': True},
          num_ready_table_initializers=0),
  )
  def testDetermineReadyTensorsAndTableInitializers(
      self, create_graph_fn, feeds, replaced_tensors_ready, should_be_ready,
      num_ready_table_initializers):
    """Test determine_ready_tensors_and_table_initializers.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor` or `SparseTensor`s.
      feeds: A list of keys in the dict returned by create_graph_fn that are fed
          in the main run (but not table initialization run).
      replaced_tensors_ready: A dict whose keys are keys in the dict returned by
          create_graph_fn and values are bools indicating whether that tensor
          is ready to be replaced in this phase.
      should_be_ready: A dict whose keys are keys in the dict returned by
          create_graph_fn and values are bools indicating whether a tensor can
          be calculated in this phase.
      num_ready_table_initializers: The number of table initializers that are
          ready to run in the table initialization run of this phase.
    """
    with tf.compat.v1.Graph().as_default() as graph:
      tensors = create_graph_fn()
    replaced_tensors_ready = [(tensors[name], ready)
                              for name, ready in replaced_tensors_ready.items()]

    graph_analyzer = graph_tools.InitializableGraphAnalyzer(
        graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready)
    self.assertEqual(
        len(graph_analyzer.ready_table_initializers),
        num_ready_table_initializers)

    for name, ready in should_be_ready.items():
      tensor = tensors[name]
      self.assertEqual(
          graph_analyzer.ready_to_run(tensor),
          ready,
          msg='Expected tensor {} to be ready={}'.format(name, ready))

  @test_case.parameters(
      (_create_graph_with_y_function_of_x_and_table,
       [], {'x': False},
       'placeholders will not be fed during table initialization'),
      (_create_graph_with_y_function_of_x_and_table,
       [], {'x': True},
       'placeholders will not be fed during table initialization'),
      (_create_graph_with_y_function_of_x_and_table,
       ['filename'], {'x': False},
       'placeholders will not be fed during table initialization'),
      (_create_graph_with_y_function_of_x_and_table,
       ['filename'], {'x': True},
       'placeholders will not be fed during table initialization'),
      (_create_graph_with_y_function_of_x_and_table,
       ['filename', 'x'], {},
       'placeholders will not be fed during table initialization'),
      (_create_graph_with_table_initialized_by_table_output,
       ['x'], {'filename': True},
       'tables are initialized in one pass')
  )
  def testInitializableGraphAnalyzerConstructorRaises(
      self, create_graph_fn, feeds, replaced_tensors_ready,
      error_msg_regex):
    """Tests that the InitializableGraphAnalyzer constructor raises ValueError.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor` or `SparseTensor`s.
      feeds: A list of keys in the dict returned by create_graph_fn that are fed
          in the main run (but not table initialization run).
      replaced_tensors_ready: A dict whose keys are keys in the dict returned by
          create_graph_fn and values are bools indicating whether that tensor
          is ready to be replaced in this phase.
      error_msg_regex: A regex the raised error message must match.
    """
    with tf.compat.v1.Graph().as_default() as graph:
      tensors = create_graph_fn()
    replaced_tensors_ready = [(tensors[name], ready)
                              for name, ready in replaced_tensors_ready.items()]
    # six.assertRaisesRegex works on Python 2 and 3; the `assertRaisesRegexp`
    # alias is deprecated and was removed in Python 3.12.
    with six.assertRaisesRegex(self, ValueError, error_msg_regex):
      graph_tools.InitializableGraphAnalyzer(graph,
                                             {x: tensors[x] for x in feeds},
                                             replaced_tensors_ready)

  @test_case.parameters(
      (_create_graph_with_y_function_of_x, [], {}, 'y',
       'may have be caused by manually adding a placeholder to the graph'),
      (_create_graph_with_placeholder_in_tf_function, ['x'], {}, 'z',
       r'that is part of a tf.function graph \(foo\), this is not supported. '
       'This may be a result of calling a tf.Transform analyzer in a '
       'tf.function'),
      (_create_graph_with_y_function_of_x_and_untracked_table, ['x'], {
          'filename': True
      }, 'y', 'may be caused by adding an initializable table without'),
  )
  def testInitializableGraphAnalyzerReadyToRunRaises(
      self, create_graph_fn, feeds, replaced_tensors_ready, fetch,
      error_msg_regex):
    """Tests that InitializableGraphAnalyzer.ready_to_run raises ValueError.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor` or `SparseTensor`s.
      feeds: A list of keys in the dict returned by create_graph_fn that are fed
          in the main run (but not table initialization run).
      replaced_tensors_ready: A dict whose keys are keys in the dict returned by
          create_graph_fn and values are bools indicating whether that tensor
          is ready to be replaced in this phase.
      fetch: The tensor to fetch.  Should be a key in the dict returned by
          create_graph_fn.
      error_msg_regex: A regex the raised error message must match.
    """
    with tf.compat.v1.Graph().as_default() as graph:
      tensors = create_graph_fn()
    replaced_tensors_ready = [(
        tensors[name], ready) for name, ready in replaced_tensors_ready.items()]
    graph_analyzer = graph_tools.InitializableGraphAnalyzer(
        graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready)
    # Look the tensor up outside the context so only ready_to_run is tested.
    tensor = tensors[fetch]
    with six.assertRaisesRegex(self, ValueError, error_msg_regex):
      graph_analyzer.ready_to_run(tensor)

  @test_case.named_parameters(
      dict(
          testcase_name='_y_function_of_x',
          create_graph_fn=_create_graph_with_y_function_of_x,
          feeds=['x'],
          fetches=['y'],
          expected_dependent_inputs=['x']),
      dict(
          testcase_name='_tf_function',
          create_graph_fn=_create_graph_with_tf_function,
          feeds=['x', 'y'],
          fetches=['z'],
          expected_dependent_inputs=['x', 'y']),
      dict(
          testcase_name='_tf_function_signature_forces_dependencies',
          create_graph_fn=_create_graph_with_tf_function,
          feeds=['x', 'y'],
          fetches=['r'],
          expected_dependent_inputs=['x', 'y']),
      dict(
          testcase_name='_tf_function_mixed_dependencies',
          create_graph_fn=_create_graph_with_mixed_dependencies,
          feeds=['x', 'y'],
          fetches=['z'],
          expected_dependent_inputs=['x', 'y']),
      dict(
          testcase_name='_chained_tf_function',
          create_graph_fn=_create_graph_with_chained_tf_function,
          feeds=['x'],
          fetches=['y'],
          expected_dependent_inputs=['x']),
      dict(
          testcase_name='_y_function_of_x_with_unused_inputs',
          create_graph_fn=_create_graph_with_y_function_of_x_with_unused_inputs,
          feeds=['x', 'x2', 'x_unused'],
          fetches=['y', 'z'],
          expected_dependent_inputs=['x', 'x2']),
      dict(
          testcase_name='_y_function_of_sparse_x',
          create_graph_fn=_create_graph_with_y_function_of_x_sparse,
          feeds=['x'],
          fetches=['y'],
          expected_dependent_inputs=['x']),
      dict(
          testcase_name='_y_sparse_function_of_sparse_x',
          create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse,
          feeds=['x'],
          fetches=['y'],
          expected_dependent_inputs=['x']),
      dict(
          testcase_name='_y_function_of_ragged_x',
          create_graph_fn=_create_graph_with_ragged_tensor,
          feeds=['x1', 'x2'],
          fetches=['y1', 'y2'],
          expected_dependent_inputs=['x1', 'x2']),
      dict(
          testcase_name='_z_function_of_x_ragged',
          create_graph_fn=_create_graph_with_z_function_of_x_ragged,
          feeds=['x'],
          fetches=['y', 'z'],
          expected_dependent_inputs=['x']),
      dict(
          testcase_name='z_function_of_x_y_with_control_dependencies',
          create_graph_fn=_create_graph_with_assert_equal,
          feeds=['x', 'y'],
          fetches=['z'],
          expected_dependent_inputs=['x', 'y']),
      dict(
          testcase_name='_y_function_of_x_with_tf_while',
          create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
          feeds=['x'],
          fetches=['y'],
          expected_dependent_inputs=['x']),
  )
  def testGetDependentInputs(self, create_graph_fn, feeds, fetches,
                             expected_dependent_inputs):
    """Tests graph_tools.get_dependent_inputs.

    Args:
      create_graph_fn: A function that adds ops to a graph and returns a dict
          mapping tensor names to `Tensor`, `SparseTensor` or `RaggedTensor`s.
      feeds: Keys in the dict returned by create_graph_fn that are passed as
          the candidate inputs.
      fetches: Keys in the dict returned by create_graph_fn that are passed as
          the outputs.
      expected_dependent_inputs: Keys of the inputs expected in the result,
          each mapped to the same object that create_graph_fn returned.
    """
    with tf.compat.v1.Graph().as_default() as graph:
      tensors = create_graph_fn()
    got = graph_tools.get_dependent_inputs(graph,
                                           {x: tensors[x] for x in feeds},
                                           {y: tensors[y] for y in fetches})
    self.assertCountEqual(expected_dependent_inputs, got.keys())
    for input_name in expected_dependent_inputs:
      self.assertEqual(tensors[input_name], got[input_name])


class GraphToolsTestUniquePath(test_case.TransformTestCase):
  """Tests for `graph_tools.InitializableGraphAnalyzer.get_unique_path`."""

  # Each case maps a fetched tensor name to the exact, ordered sequence of
  # calls expected on the path callback given to InitializableGraphAnalyzer.
  # Order matters: the test requires an exact ordered match of these calls.
  @test_case.named_parameters(
      dict(
          testcase_name='_y_function_of_x',
          create_graph_fn=_create_graph_with_y_function_of_x,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'x': [mock.call('x$tensor'),],
              'y': [
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call('x$tensor'),
                  mock.call(
                      _OpMatcher('add'), parents=['x$tensor', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
              ]
          }),
      dict(
          testcase_name='_y_function_of_x_and_tf_function',
          create_graph_fn=_create_graph_with_tf_function,
          feeds=['x', 'y'],
          replaced_tensors_ready={
              'x': False,
              'y': False
          },
          expected_calls_dict={
              'x': [mock.call('x$tensor'),],
              'y': [mock.call('y$tensor'),],
              'z': [
                  mock.call('y$tensor'),
                  mock.call('x$tensor'),
                  mock.call(_OpMatcher('mul/y'), parents=[]),
                  mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(
                      _OpMatcher('mul'),
                      parents=['FuncGraphInput[0]', u'mul/y:0']),
                  mock.call(_TensorMatcher('mul:0'), parents=[u'mul']),
                  mock.call(_OpMatcher('Identity'), parents=[u'mul:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call('FuncGraphInput[1]'),
                  mock.call(
                      _OpMatcher('add'),
                      parents=['FuncGraphInput[0]', 'FuncGraphInput[1]']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
                  mock.call(_OpMatcher('Identity_1'), parents=[u'add:0']),
                  mock.call(
                      _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']),
                  mock.call(_OpMatcher('mul_1/y'), parents=[]),
                  mock.call(_TensorMatcher('mul_1/y:0'), parents=[u'mul_1/y']),
                  mock.call(
                      _OpMatcher('mul_1'),
                      parents=['FuncGraphInput[1]', u'mul_1/y:0']),
                  mock.call(_TensorMatcher('mul_1:0'), parents=[u'mul_1']),
                  mock.call(_OpMatcher('Identity_2'), parents=[u'mul_1:0']),
                  mock.call(
                      _TensorMatcher('Identity_2:0'), parents=[u'Identity_2']),
                  mock.call(
                      _OpMatcher('PartitionedCall'),
                      parents=[
                          'x$tensor', 'y$tensor', u'Identity:0',
                          u'Identity_1:0', u'Identity_2:0'
                      ]),
                  mock.call(
                      _TensorMatcher('PartitionedCall:2'),
                      parents=[u'PartitionedCall']),
                  mock.call(
                      _TensorMatcher('PartitionedCall:1'),
                      parents=[u'PartitionedCall']),
                  mock.call(
                      _TensorMatcher('PartitionedCall:0'),
                      parents=[u'PartitionedCall']),
                  mock.call(
                      _OpMatcher('add'),
                      parents=[u'PartitionedCall:0', u'PartitionedCall:1']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
                  mock.call(
                      _OpMatcher('add_1'),
                      parents=[u'add:0', u'PartitionedCall:2']),
                  mock.call(_TensorMatcher('add_1:0'), parents=[u'add_1']),
              ]
          }),
      dict(
          testcase_name='_y_function_of_x_and_chained_tf_function',
          create_graph_fn=_create_graph_with_chained_tf_function,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'x': [mock.call('x$tensor'),],
              'y': [
                  mock.call(_OpMatcher('truediv/y'), parents=[]),
                  mock.call(
                      _TensorMatcher('truediv/y:0'), parents=[u'truediv/y']),
                  mock.call(
                      _OpMatcher('truediv/Cast_1'), parents=[u'truediv/y:0']),
                  mock.call(
                      _TensorMatcher('truediv/Cast_1:0'),
                      parents=[u'truediv/Cast_1']),
                  mock.call('x$tensor'),
                  mock.call(_OpMatcher('mul/y'), parents=[]),
                  mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(
                      _OpMatcher('add'),
                      parents=['FuncGraphInput[0]', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
                  mock.call(_OpMatcher('Identity'), parents=[u'add:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(
                      _OpMatcher('PartitionedCall'),
                      parents=['FuncGraphInput[0]', u'Identity:0']),
                  mock.call(
                      _TensorMatcher('PartitionedCall:0'),
                      parents=[u'PartitionedCall']),
                  mock.call(
                      _OpMatcher('mul'),
                      parents=[u'PartitionedCall:0', u'mul/y:0']),
                  mock.call(_TensorMatcher('mul:0'), parents=[u'mul']),
                  mock.call(_OpMatcher('Identity'), parents=[u'mul:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(
                      _OpMatcher('PartitionedCall'),
                      parents=['x$tensor', u'Identity:0']),
                  mock.call(
                      _TensorMatcher('PartitionedCall:0'),
                      parents=[u'PartitionedCall']),
                  mock.call(
                      _OpMatcher('truediv/Cast'),
                      parents=[u'PartitionedCall:0']),
                  mock.call(
                      _TensorMatcher('truediv/Cast:0'),
                      parents=[u'truediv/Cast']),
                  mock.call(
                      _OpMatcher('truediv'),
                      parents=[u'truediv/Cast:0', u'truediv/Cast_1:0']),
                  mock.call(_TensorMatcher('truediv:0'), parents=[u'truediv']),
              ],
          }),
      dict(
          testcase_name='_y_function_of_x_sparse',
          create_graph_fn=_create_graph_with_y_function_of_x_sparse,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'y': [
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call(_OpMatcher('range/delta'), parents=[]),
                  mock.call(
                      _TensorMatcher('range/delta:0'),
                      parents=[u'range/delta']),
                  mock.call('x$composite_tensor_2'),
                  mock.call(
                      _OpMatcher('Rank'), parents=['x$composite_tensor_2']),
                  mock.call(_TensorMatcher('Rank:0'), parents=[u'Rank']),
                  mock.call(_OpMatcher('range/start'), parents=[]),
                  mock.call(
                      _TensorMatcher('range/start:0'),
                      parents=[u'range/start']),
                  mock.call(
                      _OpMatcher('range'),
                      parents=[u'range/start:0', u'Rank:0', u'range/delta:0']),
                  mock.call(_TensorMatcher('range:0'), parents=[u'range']),
                  mock.call('x$composite_tensor_1'),
                  mock.call('x$composite_tensor_0'),
                  mock.call(
                      _OpMatcher('SparseReduceSum'),
                      parents=[
                          'x$composite_tensor_0', 'x$composite_tensor_1',
                          'x$composite_tensor_2', u'range:0'
                      ]),
                  mock.call(
                      _TensorMatcher('SparseReduceSum:0'),
                      parents=[u'SparseReduceSum']),
                  mock.call(
                      _OpMatcher('add'),
                      parents=[u'SparseReduceSum:0', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
              ]
          }),
      dict(
          testcase_name='_y_sparse_function_of_x_sparse',
          create_graph_fn=_create_graph_with_y_sparse_function_of_x_sparse,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'z': [
                  mock.call(_OpMatcher('ones/Const'), parents=[]),
                  mock.call(
                      _TensorMatcher('ones/Const:0'), parents=[u'ones/Const']),
                  mock.call('x$composite_tensor_2'),
                  mock.call(
                      _OpMatcher('ones'),
                      parents=['x$composite_tensor_2', u'ones/Const:0']),
                  mock.call(_TensorMatcher('ones:0'), parents=[u'ones']),
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call('x$composite_tensor_1'),
                  mock.call(
                      _OpMatcher('add'),
                      parents=['x$composite_tensor_1', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
                  mock.call('x$composite_tensor_0'),
                  mock.call(
                      _OpMatcher('SparseTensorDenseAdd'),
                      parents=[
                          'x$composite_tensor_0', u'add:0',
                          'x$composite_tensor_2', u'ones:0'
                      ]),
                  mock.call(
                      _TensorMatcher('SparseTensorDenseAdd:0'),
                      parents=[u'SparseTensorDenseAdd']),
              ],
          }),
      dict(
          testcase_name='_z_function_of_x_ragged',
          create_graph_fn=_create_graph_with_z_function_of_x_ragged,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'z': [
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call(_OpMatcher('range/delta'), parents=[]),
                  mock.call(
                      _TensorMatcher('range/delta:0'),
                      parents=[u'range/delta']),
                  mock.call('x$composite_tensor_0'),
                  mock.call('x$composite_tensor_2'),
                  mock.call('x$composite_tensor_1'),
                  mock.call(
                      _OpMatcher('RaggedToSparse/RaggedTensorToSparse'),
                      parents=[
                          'x$composite_tensor_1', 'x$composite_tensor_2',
                          'x$composite_tensor_0'
                      ]),
                  mock.call(
                      _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:2'),
                      parents=['RaggedToSparse/RaggedTensorToSparse']),
                  mock.call(
                      _OpMatcher('Rank'),
                      parents=['RaggedToSparse/RaggedTensorToSparse:2']),
                  mock.call(_TensorMatcher('Rank:0'), parents=[u'Rank']),
                  mock.call(_OpMatcher('range/start'), parents=[]),
                  mock.call(
                      _TensorMatcher('range/start:0'),
                      parents=[u'range/start']),
                  mock.call(
                      _OpMatcher('range'),
                      parents=[u'range/start:0', u'Rank:0', u'range/delta:0']),
                  mock.call(_TensorMatcher('range:0'), parents=[u'range']),
                  mock.call(
                      _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:1'),
                      parents=['RaggedToSparse/RaggedTensorToSparse']),
                  mock.call(
                      _TensorMatcher('RaggedToSparse/RaggedTensorToSparse:0'),
                      parents=['RaggedToSparse/RaggedTensorToSparse']),
                  mock.call(
                      _OpMatcher('SparseReduceSum'),
                      parents=[
                          'RaggedToSparse/RaggedTensorToSparse:0',
                          'RaggedToSparse/RaggedTensorToSparse:1',
                          'RaggedToSparse/RaggedTensorToSparse:2', u'range:0'
                      ]),
                  mock.call(
                      _TensorMatcher('SparseReduceSum:0'),
                      parents=[u'SparseReduceSum']),
                  mock.call(
                      _OpMatcher('add'),
                      parents=[u'SparseReduceSum:0', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
              ],
          }),
      # NOTE: this case builds its graph with
      # _create_graph_with_y_function_of_x_with_tf_while; the expected calls
      # reference a raw 'While' op.
      dict(
          testcase_name='_y_function_of_x_with_raw_ops_while',
          skip_test_check_fn=test_case.skip_if_external_environment,
          create_graph_fn=_create_graph_with_y_function_of_x_with_tf_while,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'y': [
                  mock.call('x$tensor'),
                  mock.call(_OpMatcher('Const'), parents=[]),
                  mock.call(_TensorMatcher('Const:0'), parents=[u'Const']),
                  mock.call(_OpMatcher('Less/y'), parents=[]),
                  mock.call(_TensorMatcher('Less/y:0'), parents=[u'Less/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(
                      _OpMatcher('Less'),
                      parents=[u'FuncGraphInput[0]', 'Less/y:0']),
                  mock.call(_TensorMatcher('Less:0'), parents=[u'Less']),
                  mock.call(_OpMatcher('Identity'), parents=[u'Less:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(_OpMatcher('Add/y'), parents=[]),
                  mock.call(_TensorMatcher('Add/y:0'), parents=[u'Add/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(
                      _OpMatcher('Add'),
                      parents=[u'FuncGraphInput[0]', 'Add/y:0']),
                  mock.call(_TensorMatcher('Add:0'), parents=[u'Add']),
                  mock.call(_OpMatcher('Identity'), parents=[u'Add:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(_OpMatcher('Add_1/y'), parents=[]),
                  mock.call(_TensorMatcher('Add_1/y:0'), parents=[u'Add_1/y']),
                  mock.call('FuncGraphInput[1]'),
                  mock.call(
                      _OpMatcher('Add_1'),
                      parents=[u'FuncGraphInput[1]', 'Add_1/y:0']),
                  mock.call(_TensorMatcher('Add_1:0'), parents=[u'Add_1']),
                  mock.call(_OpMatcher('Identity_1'), parents=[u'Add_1:0']),
                  mock.call(
                      _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']),
                  mock.call(
                      _OpMatcher('While'),
                      parents=[
                          u'Const:0', 'x$tensor', 'Identity:0', 'Identity:0',
                          'Identity_1:0'
                      ]),
                  mock.call(_TensorMatcher('While:1'), parents=[u'While']),
              ],
          }),
      # NOTE: this case builds its graph with _create_graph_with_tf_function_while;
      # the expected calls reference a lowercase 'while' op reached through a
      # 'PartitionedCall'.
      dict(
          testcase_name='_y_function_of_x_with_tf_while',
          skip_test_check_fn=test_case.skip_if_external_environment,
          create_graph_fn=_create_graph_with_tf_function_while,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'y': [
                  mock.call('x$tensor'),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(_OpMatcher('while/maximum_iterations'), parents=[]),
                  mock.call(
                      _TensorMatcher('while/maximum_iterations:0'),
                      parents=[u'while/maximum_iterations']),
                  mock.call(_OpMatcher('while/loop_counter'), parents=[]),
                  mock.call(
                      _TensorMatcher('while/loop_counter:0'),
                      parents=[u'while/loop_counter']),
                  mock.call(_OpMatcher('Less/y'), parents=[]),
                  mock.call(_TensorMatcher('Less/y:0'), parents=[u'Less/y']),
                  mock.call('FuncGraphInput[2]'),
                  mock.call(
                      _OpMatcher('Less'),
                      parents=['FuncGraphInput[2]', u'Less/y:0']),
                  mock.call(_TensorMatcher('Less:0'), parents=[u'Less']),
                  mock.call(_OpMatcher('Identity'), parents=[u'Less:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(_OpMatcher('add/y'), parents=[]),
                  mock.call(_TensorMatcher('add/y:0'), parents=[u'add/y']),
                  mock.call('FuncGraphInput[0]'),
                  mock.call(
                      _OpMatcher('add'),
                      parents=['FuncGraphInput[0]', u'add/y:0']),
                  mock.call(_TensorMatcher('add:0'), parents=[u'add']),
                  mock.call(_OpMatcher('Identity'), parents=[u'add:0']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call('FuncGraphInput[1]'),
                  mock.call(
                      _OpMatcher('Identity_1'), parents=['FuncGraphInput[1]']),
                  mock.call(
                      _TensorMatcher('Identity_1:0'), parents=[u'Identity_1']),
                  mock.call(_OpMatcher('mul/y'), parents=[]),
                  mock.call(_TensorMatcher('mul/y:0'), parents=[u'mul/y']),
                  mock.call('FuncGraphInput[2]'),
                  mock.call(
                      _OpMatcher('mul'),
                      parents=['FuncGraphInput[2]', u'mul/y:0']),
                  mock.call(_TensorMatcher('mul:0'), parents=[u'mul']),
                  mock.call(_OpMatcher('Identity_2'), parents=[u'mul:0']),
                  mock.call(
                      _TensorMatcher('Identity_2:0'), parents=[u'Identity_2']),
                  mock.call(
                      _OpMatcher('while'),
                      parents=[
                          u'while/loop_counter:0',
                          u'while/maximum_iterations:0', 'FuncGraphInput[0]',
                          u'Identity:0', u'Identity:0', u'Identity_1:0',
                          u'Identity_2:0'
                      ]),
                  mock.call(_TensorMatcher('while:2'), parents=[u'while']),
                  mock.call(_OpMatcher('Identity'), parents=[u'while:2']),
                  mock.call(
                      _TensorMatcher('Identity:0'), parents=[u'Identity']),
                  mock.call(
                      _OpMatcher('PartitionedCall'),
                      parents=['x$tensor', u'Identity:0']),
                  mock.call(
                      _TensorMatcher('PartitionedCall:0'),
                      parents=[u'PartitionedCall']),
              ],
          }),
      dict(
          testcase_name='_y_function_of_x_and_table',
          create_graph_fn=_create_graph_with_y_function_of_x_and_table_in_first_phase,
          feeds=['x'],
          replaced_tensors_ready={'x': False},
          expected_calls_dict={
              'x': [
                  mock.call(_OpMatcher('Const'), parents=[]),
                  mock.call(_TensorMatcher('Const:0'), parents=['Const']),
                  mock.call(_OpMatcher('hash_table'), parents=['Const:0']),
                  mock.call('x$tensor'),
              ],
              'y': [
                  mock.call(_OpMatcher('Const'), parents=[]),
                  mock.call(_TensorMatcher('Const:0'), parents=['Const']),
                  mock.call(_OpMatcher('hash_table'), parents=['Const:0']),
                  mock.call(_OpMatcher('Const_1'), parents=[]),
                  mock.call(_TensorMatcher('Const_1:0'), parents=['Const_1']),
                  mock.call('x$tensor'),
                  mock.call('hash_table'),
                  mock.call(
                      _TensorMatcher('hash_table:0'), parents=['hash_table']),
                  mock.call(
                      _OpMatcher('hash_table_Lookup/LookupTableFindV2'),
                      parents=['hash_table:0', 'x$tensor', 'Const_1:0']),
                  mock.call(
                      _TensorMatcher('hash_table_Lookup/LookupTableFindV2:0'),
                      parents=['hash_table_Lookup/LookupTableFindV2']),
              ],
          }),
      dict(
          testcase_name='_with_assert_equal',
          create_graph_fn=_create_graph_with_assert_equal,
          feeds=['x', 'y'],
          replaced_tensors_ready={
              'x': False,
              'y': False
          },
          expected_calls_dict={
              'x': [mock.call('x$tensor'),],
              'y': [mock.call('y$tensor'),],
              'z': [
                  mock.call('y$tensor'),
                  mock.call('x$tensor'),
                  mock.call(
                      _OpMatcher('Equal'), parents=['x$tensor', 'y$tensor']),
                  mock.call(_TensorMatcher('Equal:0'), parents=[u'Equal']),
                  mock.call(
                      _OpMatcher('Assert'),
                      parents=[u'Equal:0', 'x$tensor', 'y$tensor']),
                  mock.call(
                      _OpMatcher('control_dependency'),
                      parents=['x$tensor', u'Assert']),
                  mock.call(
                      _TensorMatcher('control_dependency:0'),
                      parents=[u'control_dependency']),
              ]
          }),
  )
  def testGetUniquePath(self,
                        create_graph_fn,
                        feeds,
                        replaced_tensors_ready,
                        expected_calls_dict,
                        skip_test_check_fn=None):
    """Checks the path-callback calls made by `get_unique_path`.

    For every tensor name in `expected_calls_dict`, this builds a fresh
    `InitializableGraphAnalyzer` with a mocked path callback and calls
    `get_unique_path` on that tensor. It then checks that the callback saw
    exactly the expected calls, in order.

    Args:
      create_graph_fn: Callable that builds a graph and returns a dict mapping
        names to the tensors it created.
      feeds: Names (keys of that dict) of tensors passed to the analyzer as
        its name-to-tensor mapping.
      replaced_tensors_ready: Dict mapping a tensor name to a bool. Each entry
        is passed to the analyzer as a `(tensor, bool)` pair; presumably the
        bool says whether that replaced tensor is ready (TODO confirm against
        graph_tools).
      expected_calls_dict: Dict mapping a tensor name to the exact, ordered
        list of `mock.call`s expected on the path callback for that tensor.
      skip_test_check_fn: Optional callable. If given, it is called with a
        message before the test body runs, for example to skip the case in
        some environments.
    """

    # TODO(b/138934800): Remove this once TF 1.15 has the same results in all
    # environments.
    if skip_test_check_fn:
      skip_test_check_fn('This test is not currently supported.')

    with tf.compat.v1.Graph().as_default() as graph:
      tensors = create_graph_fn()
    # The analyzer receives a list of (tensor, ready) pairs, not a name-keyed
    # dict.
    replaced_tensors_ready = [(tensors[name], ready)
                              for name, ready in replaced_tensors_ready.items()]
    for name in expected_calls_dict:

      # This is used to construct the debugging string below.
      actual_needed_matchers_to_pass = []

      # Side effect of the path-callback mock. It does two things:
      # - records a matcher-style description of each call for the failure
      #   log below;
      # - returns a string identifier for `x`.
      # In expected_calls_dict, these identifiers appear as the `parents` of
      # later calls.
      def describe_path_fn(x, parents=None):
        if parents is None:
          parents_str = ''
        else:
          parents_str = ', parents={}'.format(
              list(map(_value_to_matcher, parents)))
        actual_needed_matchers_to_pass.append('({}{}),'.format(  # pylint: disable=cell-var-from-loop
            _value_to_matcher(x, True), parents_str))

        if isinstance(x, tf.Operation):
          return x.node_def.name
        if isinstance(x, tf.Tensor):
          # A tensor should have at most one parent, the op that produces it.
          # NOTE(review): len(None) would raise TypeError if a tensor were
          # ever reported without `parents`.
          self.assertLessEqual(len(parents), 1)
          return x.name
        # Values that are already strings (e.g. 'x$tensor' or
        # 'FuncGraphInput[0]' in the expectations) are passed through as-is.
        if isinstance(x, (six.text_type, str, bytes)):
          return x
        raise ValueError('Unexpected type: {}'.format(x))

      path_cb_mock = mock.MagicMock(side_effect=describe_path_fn)

      graph_analyzer = graph_tools.InitializableGraphAnalyzer(
          graph, {x: tensors[x] for x in feeds}, replaced_tensors_ready,
          path_cb_mock)

      # Only the callback invocations are checked; the return value is unused.
      graph_analyzer.get_unique_path(tensors[name])

      try:
        # assert_has_calls requires the expected calls to appear as a
        # contiguous, ordered run. Combined with the call-count check, this
        # demands an exact match.
        path_cb_mock.assert_has_calls(expected_calls_dict[name])
        self.assertEqual(
            path_cb_mock.call_count, len(expected_calls_dict[name]),
            'Number of expected calls != number of actual calls for {}: {}'
            .format(name, path_cb_mock.call_args_list))
      except AssertionError:
        # Log the calls actually observed, in the same shape as the
        # `mock.call(...)` arguments above, to make updating expectations
        # easier. Then re-raise so the test still fails.
        tf.compat.v1.logging.error(
            'The following is a list of matchers for {}:\n{}'.format(
                name, '\n'.join(actual_needed_matchers_to_pass)))
        raise


def _value_to_matcher(value, add_quotes=False):
  """Maps a path-callback value to a matcher, for debugging output only.

  Args:
    value: A `tf.Operation`, a `tf.Tensor`, or a string/bytes identifier.
    add_quotes: If True, string values are wrapped in single quotes so they
      can be pasted as Python literals.

  Returns:
    An `_OpMatcher` for operations, a `_TensorMatcher` for tensors, or the
    (optionally quoted) string itself.

  Raises:
    ValueError: If `value` is of any other type.
  """
  if isinstance(value, tf.Operation):
    return _OpMatcher(str(value.node_def.name))
  if isinstance(value, tf.Tensor):
    return _TensorMatcher(str(value.name))
  if not isinstance(value, (six.text_type, str, bytes)):
    raise ValueError('Cannot get a matcher for: {}, {}'.format(
        type(value), value))
  return '\'{}\''.format(value) if add_quotes else value


if __name__ == '__main__':
  # Entry point when this file is run directly: delegate to the
  # tensorflow_transform test_case runner.
  test_case.main()