# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow testing subclass to automate numerical testing.

Reference tests detect when behavior deviates from some "gold standard," and
are useful for catching changes to layer definitions without performing full
regression testing, which is generally prohibitively expensive. This class
handles the symbolic graph comparison as well as loading weights, so that tests
do not rely on random number generation, which can change between versions.

The tests performed by this class are:

1) Compare a generated graph against a reference graph. Differences are not
   necessarily fatal.
2) Attempt to load known weights for the graph. If this step succeeds but
   changes are present in the graph, a warning is issued but no exception is
   raised.
3) Perform a calculation and compare the result to a reference value.

This class also provides a method to generate reference data.

Note:
  The test class is responsible for fixing the random seed during graph
  definition. A convenience method name_to_seed() is provided to make this
  process easier.

The test class should also define a .regenerate() method which (usually) just
calls the op definition function with test=False for all relevant tests.

A concise example of this class in action is provided in:
  official/utils/testing/reference_data_test.py
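
A minimal sketch of a subclass (all names below are illustrative, and importing
this module as reference_data is assumed):

  class ExampleTest(reference_data.BaseTest):

    @property
    def test_name(self):
      return "example"

    def _dense_case(self, test=True):
      name = "dense"
      g = tf.Graph()
      with g.as_default():
        tf.set_random_seed(self.name_to_seed(name))
        y = tf.layers.dense(inputs=tf.random_uniform((4, 8)), units=2)
      self._save_or_test_ops(
          name=name, graph=g, ops_to_eval=[y], test=test,
          correctness_function=self.default_correctness_function)

    def test_dense(self):
      self._dense_case()

    def regenerate(self):
      self._dense_case(test=False)

  if __name__ == "__main__":
    reference_data.main(argv=sys.argv, test_class=ExampleTest)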
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import hashlib
import json
import os
import shutil
import sys

import numpy as np
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow


class BaseTest(tf.test.TestCase):
  """TestCase subclass for performing reference data tests."""

  def regenerate(self):
    """Subclasses should override this function to generate a new reference."""
    raise NotImplementedError

  @property
  def test_name(self):
    """Subclass should define its own name."""
    raise NotImplementedError

  @property
  def data_root(self):
    """Use the subclass directory rather than the parent directory.

    Returns:
      The path prefix for reference data.
    """
    return os.path.join(os.path.split(
        os.path.abspath(__file__))[0], "reference_data", self.test_name)

  ckpt_prefix = "model.ckpt"

  @staticmethod
  def name_to_seed(name):
    """Convert a string into a 32 bit integer.

    This function allows test cases to easily generate fixed random seeds by
    hashing the name of the test. The hash digest is a hex string rather than
    base 10, which is why int() is called with base 16, and the modulo projects
    the seed from a 128 bit integer down to 32 bits for readability.

    Args:
      name: A string containing the name of a test.

    Returns:
      A pseudo-random 32 bit integer derived from name.
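
    Example (a sketch; "my_test" is an arbitrary test name):

      graph = tf.Graph()
      with graph.as_default():
        tf.set_random_seed(self.name_to_seed("my_test"))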
    """
    seed = hashlib.md5(name.encode("utf-8")).hexdigest()
    return int(seed, 16) % (2**32 - 1)

  @staticmethod
  def common_tensor_properties(input_array):
    """Convenience function for matrix testing.

    In tests we wish to determine whether a result has changed. However,
    storing an entire n-dimensional array is impractical. A better approach is
    to calculate several values from that array and test that those derived
    values are unchanged. The properties themselves are arbitrary and should
    be chosen to be good proxies for a full equality test.

    Args:
      input_array: A numpy array from which key values are extracted.

    Returns:
      A list of values derived from the input_array for equality tests.
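
    For example, common_tensor_properties(np.arange(4).reshape(2, 2)) returns
    [2, 2, 0.0, 3.0, 6.0]: the shape, followed by the first element, the last
    element, and the sum of the flattened array.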
    """
    output = list(input_array.shape)
    flat_array = input_array.flatten()
    output.extend([float(i) for i in
                   [flat_array[0], flat_array[-1], np.sum(flat_array)]])
    return output

  def default_correctness_function(self, *args):
    """Returns a vector with the concatenation of common properties.

    This function simply calls common_tensor_properties() for every element.
    It is useful as it allows one to easily construct tests of layers without
    having to worry about the details of result checking.

    Args:
      *args: A list of numpy arrays corresponding to tensors which have been
        evaluated.

    Returns:
      A list of values containing properties for every element in args.
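
    For instance, if two evaluated arrays of shapes (2, 2) and (3,) are passed,
    the result is a single flat list of 5 + 4 = 9 values (the shape entries
    plus the first element, last element, and sum for each array).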
    """
    output = []
    for arg in args:
      output.extend(self.common_tensor_properties(arg))
    return output

  def _construct_and_save_reference_files(
      self, name, graph, ops_to_eval, correctness_function):
    """Save reference data files.

    Constructs a serialized graph_def, layer weights, and computation results,
    and then saves them to files that are read at test time.

    Args:
      name: String defining the run. This will be used to define folder names
        and will be used for random seed construction.
      graph: The graph in which the test is conducted.
      ops_to_eval: Ops which the user wishes to be evaluated under a controlled
        session.
      correctness_function: This function accepts the evaluated results of
        ops_to_eval, and returns a list of values. This list must be JSON
        serializable; in particular it is up to the user to convert numpy
        dtypes into builtin dtypes.
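
    The files written under data_root/name are "expected_graph" (the serialized
    GraphDef), the "model.ckpt" checkpoint files (the variable values),
    "results.json" (the correctness_function output, if one is provided), and
    "tf_version.json" (the TensorFlow version and git version used to generate
    the data).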
    """
    data_dir = os.path.join(self.data_root, name)

    # Make sure there is a clean space for results.
    if os.path.exists(data_dir):
      shutil.rmtree(data_dir)
    os.makedirs(data_dir)

    # Serialize graph for comparison.
    graph_bytes = graph.as_graph_def().SerializeToString()
    expected_file = os.path.join(data_dir, "expected_graph")
    with tf.gfile.Open(expected_file, "wb") as f:
      f.write(graph_bytes)

    with graph.as_default():
      init = tf.global_variables_initializer()
      saver = tf.train.Saver()

    with self.test_session(graph=graph) as sess:
      sess.run(init)
      saver.save(sess=sess, save_path=os.path.join(data_dir, self.ckpt_prefix))

      # These files are not needed for this test.
      os.remove(os.path.join(data_dir, "checkpoint"))
      os.remove(os.path.join(data_dir, self.ckpt_prefix + ".meta"))

      # Ops are evaluated even if there is no correctness function, to ensure
      # that they can be evaluated at all.
      eval_results = [op.eval() for op in ops_to_eval]

      if correctness_function is not None:
        results = correctness_function(*eval_results)
        with tf.gfile.Open(os.path.join(data_dir, "results.json"), "w") as f:
          json.dump(results, f)

      with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "w") as f:
        json.dump([tf.VERSION, tf.GIT_VERSION], f)

  def _evaluate_test_case(self, name, graph, ops_to_eval, correctness_function):
    """Determine if a graph agrees with the reference data.

    Args:
      name: String defining the run. This will be used to define folder names
        and will be used for random seed construction.
      graph: The graph in which the test is conducted.
      ops_to_eval: Ops which the user wishes to be evaluated under a controlled
        session.
      correctness_function: This function accepts the evaluated results of
        ops_to_eval, and returns a list of values. This list must be JSON
        serializable; in particular it is up to the user to convert numpy
        dtypes into builtin dtypes.
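
    The check proceeds in three steps: the serialized graph is compared against
    the stored "expected_graph" (differences alone only produce a warning), the
    reference weights are restored (failure to restore is fatal), and, if a
    correctness_function is provided, its output is compared against
    "results.json" with assertAllClose.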
    """
    data_dir = os.path.join(self.data_root, name)

    # Serialize graph for comparison.
    graph_bytes = graph.as_graph_def().SerializeToString()
    expected_file = os.path.join(data_dir, "expected_graph")
    with tf.gfile.Open(expected_file, "rb") as f:
      expected_graph_bytes = f.read()
      # The serialization is not deterministic byte-for-byte, so instead a
      # utility is used which compares the semantics of the two graphs to test
      # for equality. This has the added benefit of providing some information
      # on what changed.
      #   Note: The summary only shows the first difference detected. It is
      #         not an exhaustive summary of differences.
    differences = pywrap_tensorflow.EqualGraphDefWrapper(
        graph_bytes, expected_graph_bytes).decode("utf-8")

    with graph.as_default():
      init = tf.global_variables_initializer()
      saver = tf.train.Saver()

    with tf.gfile.Open(os.path.join(data_dir, "tf_version.json"), "r") as f:
      tf_version_reference, tf_git_version_reference = json.load(f)  # pylint: disable=unpacking-non-sequence

    tf_version_comparison = ""
    if tf.GIT_VERSION != tf_git_version_reference:
      tf_version_comparison = (
          "Test was built using:     {} (git = {})\n"
          "Local TensorFlow version: {} (git = {})"
          .format(tf_version_reference, tf_git_version_reference,
                  tf.VERSION, tf.GIT_VERSION)
      )

    with self.test_session(graph=graph) as sess:
      sess.run(init)
      try:
        saver.restore(sess=sess, save_path=os.path.join(
            data_dir, self.ckpt_prefix))
        if differences:
          tf.logging.warn(
              "The provided graph is different than expected:\n  {}\n"
              "However the weights were still able to be loaded.\n{}".format(
                  differences, tf_version_comparison)
          )
      except:  # pylint: disable=bare-except
        raise self.failureException(
            "Weight load failed. Graph comparison:\n  {}{}"
            .format(differences, tf_version_comparison))

      eval_results = [op.eval() for op in ops_to_eval]
      if correctness_function is not None:
        results = correctness_function(*eval_results)
        with tf.gfile.Open(os.path.join(data_dir, "results.json"), "r") as f:
          expected_results = json.load(f)
        self.assertAllClose(results, expected_results)

  def _save_or_test_ops(self, name, graph, ops_to_eval=None, test=True,
                        correctness_function=None):
    """Utility function to automate repeated work of graph checking and saving.

    The philosophy of this function is that the user need only define ops on
    a graph and specify which results should be validated. The actual work of
    managing snapshots and calculating results should be automated away.

    Args:
      name: String defining the run. This will be used to define folder names
        and will be used for random seed construction.
      graph: The graph in which the test is conducted.
      ops_to_eval: Ops which the user wishes to be evaluated under a controlled
        session.
      test: Boolean. If True this function will test graph correctness, load
        weights, and compute numerical values. If False the necessary test data
        will be generated and saved.
      correctness_function: This function accepts the evaluated results of
        ops_to_eval, and returns a list of values. This list must be JSON
        serializable; in particular it is up to the user to convert numpy
        dtypes into builtin dtypes.
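
    A typical call, assuming a graph g containing an op y that was built under
    a seed derived from name_to_seed (names here are illustrative):

      self._save_or_test_ops(
          name="dense", graph=g, ops_to_eval=[y], test=test,
          correctness_function=self.default_correctness_function)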
    """

    ops_to_eval = ops_to_eval or []

    if test:
      try:
        self._evaluate_test_case(
            name=name, graph=graph, ops_to_eval=ops_to_eval,
            correctness_function=correctness_function
        )
      except:  # pylint: disable=bare-except
        tf.logging.error("Failed unittest {}".format(name))
        raise
    else:
      self._construct_and_save_reference_files(
          name=name, graph=graph, ops_to_eval=ops_to_eval,
          correctness_function=correctness_function
      )


class ReferenceDataActionParser(argparse.ArgumentParser):
  """Minimal arg parser so that test regeneration can be called from the CLI."""

  def __init__(self):
    super(ReferenceDataActionParser, self).__init__()
    self.add_argument(
        "--regenerate", "-regen",
        action="store_true",
        help="Enable this flag to regenerate test data. If not set unit tests"
             "will be run."
    )


def main(argv, test_class):
  """Simple switch function to allow test regeneration from the CLI."""
  flags = ReferenceDataActionParser().parse_args(argv[1:])
  if flags.regenerate:
    if sys.version_info[0] == 2:
      raise NameError("\nPython2 unittest does not support being run as a "
                      "standalone class.\nAs a result tests must be "
                      "regenerated using Python3.\n"
                      "Tests can be run under 2 or 3.")
    test_class().regenerate()
  else:
    tf.test.main()