# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TFE-specific test utils."""
import numpy as np

import pytest
from tangent.grad_util import autodiff, jvp
import utils

try:
  import tensorflow as tf
  from tensorflow.contrib.eager.python import tfe
except ImportError:
  tf = None
  tfe = None
else:
  tfe.enable_eager_execution()


def register_parametrizations(metafunc, short):
  """Create additional parametrizations required for TF tests.

  Args:
    metafunc: The pytest `Metafunc` object passed to `pytest_generate_tests`.
    short: If True, parametrize each fixture with a single representative
        value instead of the full sweep of shapes.
  """

  for arg in ['t', 't1', 't2']:
    # Note: take care when sharing tensor objects across parametrizations.
    # Although tensors are immutable, two references to the same object are
    # treated as the same variable, with unexpected side effects.
    if tf:
      vectors = [
          np.random.randn(i)
          for i in (
              (3,) if short else (3, 5, 10))]
      tensors = [tf.constant(v, dtype=tf.float32) for v in vectors]
    else:
      tensors = [pytest.mark.skip(reason='tensorflow not present')(None)]
    if arg in metafunc.fixturenames:
      metafunc.parametrize(arg, tensors)

  for arg in ['mat1', 'mat2']:
    if tf:
      matrices = [
          np.random.randn(*i)
          for i in (
              ((3, 3),) if short else (
                  (1, 1),
                  (3, 3),
                  (5, 5)))]
      tensors = [tf.constant(m, dtype=tf.float32) for m in matrices]
    else:
      tensors = [pytest.mark.skip(reason='tensorflow not present')(None)]
    if arg in metafunc.fixturenames:
      metafunc.parametrize(arg, tensors)

  if 's' in metafunc.fixturenames:
    if tf:
      if short:
        scalars = [tf.constant(1.0)]
      else:
        scalars = [tf.constant(c) for c in (0.0, 1.0, 2.0)]
    else:
      scalars = [pytest.mark.skip(reason='tensorflow not present')(None)]
    metafunc.parametrize('s', scalars)

  for arg in ['timage', 'timage1', 'timage2']:
    if arg in metafunc.fixturenames:
      if tf:
        images = [
            np.random.randn(*i)
            for i in (
                ((2, 3, 3, 3),) if short else (
                    (2, 1, 1, 3),
                    (2, 3, 3, 3),
                    (2, 5, 5, 3),
                ))
        ]
        timages = [tf.constant(v, dtype=tf.float32) for v in images]
      else:
        timages = [pytest.mark.skip(reason='tensorflow not present')(None)]
      metafunc.parametrize(arg, timages)

  if 'tkernel' in metafunc.fixturenames:
    if tf:
      kernels = [
          np.random.randn(*i)
          for i in (
              ((3, 3, 3, 1),) if short else (
                  (3, 3, 3, 1),
                  (3, 3, 3, 2),
                  (5, 5, 3, 3),
              ))
      ]
      tkernels = [tf.constant(v, dtype=tf.float32) for v in kernels]
    else:
      tkernels = [pytest.mark.skip(reason='tensorflow not present')(None)]
    metafunc.parametrize('tkernel', tkernels)

  if 'conv2dstrides' in metafunc.fixturenames:
    strides = [(1, 2, 2, 1)] if short else [
        (1, 1, 1, 1),
        (1, 2, 2, 1),
        (1, 2, 2, 2),
    ]
    metafunc.parametrize('conv2dstrides', strides)

  if 'pool2dsizes' in metafunc.fixturenames:
    sizes = [(1, 2, 2, 1)] if short else [
        (1, 1, 1, 1),
        (1, 2, 2, 1),
        (1, 3, 3, 1),
    ]
    metafunc.parametrize('pool2dsizes', sizes)
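
# Typical wiring (a sketch; the project's actual conftest.py may differ):
# `register_parametrizations` is meant to be called from pytest's
# `pytest_generate_tests(metafunc)` hook, e.g.
#
#   def pytest_generate_tests(metafunc):
#     # `--short` is a hypothetical option name, shown for illustration only.
#     short = metafunc.config.getoption('--short', default=False)
#     register_parametrizations(metafunc, short)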


def tensors_to_numpy(tensors):
  """Recursively convert a structure of TF tensors into NumPy values."""
  if isinstance(tensors, (tuple, list)):
    return tuple(tensors_to_numpy(t) for t in tensors)
  if isinstance(tensors, tf.Tensor):
    return tensors.numpy()
  raise ValueError("Don't know how to handle %s" % type(tensors))
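
# A small example (a sketch, assuming eager execution so `.numpy()` works):
#
#   tensors_to_numpy((tf.constant(1.0), [tf.constant(2.0)]))
#
# returns `(np.float32(1.0), (np.float32(2.0),))`; nesting is preserved and
# lists are converted to tuples.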


def as_numpy_sig(func):
  """Wrap a TF Eager function into a signature that uses NumPy arrays."""
  def wrapped(*args):
    # Convert incoming NumPy arrays to tensors, call func, convert back.
    tf_args = [tf.constant(a) if isinstance(a, np.ndarray) else a
               for a in args]
    return tensors_to_numpy(func(*tf_args))
  return wrapped
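
# Usage sketch (`tf.reduce_sum` is just an illustrative choice of function):
#
#   np_sum = as_numpy_sig(tf.reduce_sum)
#   np_sum(np.ones(3))  # accepts a NumPy array, returns a NumPy value
#
# This NumPy-in/NumPy-out signature is what lets `utils.numeric_grad`, which
# perturbs NumPy inputs, serve as a numerical reference for TFE functions.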


def test_forward_tensor(func, wrt, *args):
  """Test gradients of functions with TFE signatures."""

  def tangent_func():
    df = jvp(func, wrt=wrt, optimized=True, verbose=True)
    args_ = args + tuple(tf.ones_like(args[i]) for i in wrt)  # seed gradient
    return tensors_to_numpy(df(*args_))

  def reference_func():
    return tensors_to_numpy(tfe.gradients_function(func, params=wrt)(*args))

  def backup_reference_func():
    func_ = as_numpy_sig(func)
    args_ = tensors_to_numpy(args)
    return utils.numeric_grad(func_)(*args_)

  # TODO: Should results really be that far off?
  utils.assert_result_matches_reference(
      tangent_func, reference_func, backup_reference_func,
      tolerance=1e-4)


def test_gradgrad_tensor(func, optimized, *args):
  """Test gradients of functions with TFE signatures."""

  def tangent_func():
    df = autodiff(func, motion='joint', optimized=optimized, verbose=True)
    ddf = autodiff(df, motion='joint', optimized=optimized, verbose=True)
    dxx = ddf(*args)
    return tensors_to_numpy(dxx)

  def reference_func():
    dxx = tfe.gradients_function(tfe.gradients_function(func))(*args)
    return tensors_to_numpy(dxx)

  def backup_reference_func():
    func_ = as_numpy_sig(func)
    args_ = tensors_to_numpy(args)
    return utils.numeric_grad(utils.numeric_grad(func_))(*args_)

  utils.assert_result_matches_reference(
      tangent_func, reference_func, backup_reference_func,
      tolerance=1e-2)  # extra loose bounds for 2nd order grad


def test_rev_tensor(func, motion, optimized, preserve_result, wrt, *args):
  """Test gradients of functions with TFE signatures."""

  def tangent_func():
    y = func(*args)
    if isinstance(y, (tuple, list)):
      init_grad = tuple(tf.ones_like(t) for t in y)
    else:
      init_grad = tf.ones_like(y)
    df = autodiff(
        func,
        motion=motion,
        optimized=optimized,
        preserve_result=preserve_result,
        wrt=wrt,
        verbose=True)
    if motion == 'joint':
      # TODO: This won't work if func has default args unspecified.
      dx = df(*args + (init_grad,))
    else:
      dx = df(*args, init_grad=init_grad)
    return tensors_to_numpy(dx)

  def reference_func():
    gradval = tensors_to_numpy(tfe.gradients_function(func, params=wrt)(*args))
    if preserve_result:
      val = tensors_to_numpy(func(*args))
      if isinstance(gradval, tuple):
        return gradval + (val,)
      return gradval, val
    else:
      return gradval

  def backup_reference_func():
    func_ = as_numpy_sig(func)
    args_ = tensors_to_numpy(args)
    gradval = utils.numeric_grad(func_)(*args_)
    if preserve_result:
      val = tensors_to_numpy(func(*args))
      return gradval, val
    else:
      return gradval

  utils.assert_result_matches_reference(
      tangent_func, reference_func, backup_reference_func,
      # Some ops like tf.divide diverge significantly due to what looks like
      # numerical instability.
      tolerance=1e-5)
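

# End-to-end sketch (with a hypothetical test target `f`; the real test
# modules pass in fixtures produced by `register_parametrizations`):
#
#   def f(x):
#     return tf.reduce_sum(x * x)
#
#   test_rev_tensor(f, 'joint', True, False, (0,), tf.constant([1.0, 2.0]))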