# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common testing utilities."""
import contextlib
from copy import deepcopy

from autograd import grad as ag_grad
from autograd import value_and_grad as ag_value_and_grad
from autograd.misc.flatten import flatten
import autograd.numpy as ag_np
import numpy as np
import tangent

# Autograd's NumPy implementation may be missing the definition for _NoValue.
if not hasattr(ag_np, '_NoValue'):
  ag_np._NoValue = np._NoValue  # pylint: disable=protected-access

# Sentinel recording that a function's globals had no `np` binding at all.
_UNSET = object()


@contextlib.contextmanager
def _numpy_module_for(func, module):
  """Temporarily rebind the global name `np` seen by `func` to `module`.

  The previous binding (or its absence) is restored on exit, even if the body
  raises, so one test cannot leak its choice of NumPy implementation into the
  functions' module for later tests.

  Args:
    func: A Python function; its `__globals__` dict is modified in place.
    module: The object to bind to `np` while the context is active, e.g.
        plain `numpy` or `autograd.numpy`.
  """
  func_globals = func.__globals__
  previous = func_globals.get('np', _UNSET)
  func_globals['np'] = module
  try:
    yield
  finally:
    if previous is _UNSET:
      func_globals.pop('np', None)
    else:
      func_globals['np'] = previous


def assert_forward_not_implemented(func, wrt):
  """Assert that forward-mode differentiation of `func` is not implemented.

  Args:
    func: The function to pass to `tangent.autodiff`.
    wrt: Argument indices to differentiate with respect to.

  Raises:
    AssertionError: If `tangent.autodiff` succeeds instead of raising
        `NotImplementedError`.
  """
  try:
    tangent.autodiff(func, mode='forward', preserve_result=False, wrt=wrt)
  except NotImplementedError:
    return
  # Raise explicitly: a bare `assert False` is stripped under `python -O`,
  # which would make this check pass unconditionally.
  raise AssertionError('Remove this when implementing.')


def _assert_allclose(a, b, tol=1e-5):
  """Recursively assert that `a` and `b` are element-wise close.

  When both values are tuples/lists they are compared pairwise (recursively);
  otherwise both are passed through `np.nan_to_num` and compared with
  `np.allclose`.

  Args:
    a: The value under test (reported as "Got" on failure).
    b: The expected value (reported as "Expected" on failure).
    tol: Relative tolerance, passed as `rtol` to `np.allclose` (`atol` keeps
        NumPy's default).

  Raises:
    AssertionError: If the values are not close.
    TypeError: If the values cannot be compared numerically.
  """
  if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
    # NOTE: `zip` stops at the shorter sequence, so surplus trailing elements
    # on either side are not compared.
    for ia, ib in zip(a, b):
      _assert_allclose(ia, ib, tol)
    return
  try:
    a = np.nan_to_num(a)
    b = np.nan_to_num(b)
    close = np.allclose(a, b, rtol=tol)
  except TypeError:
    raise TypeError('Could not compare values %s and %s' % (a, b))
  # Raise explicitly rather than via `assert`, which `python -O` strips.
  if not close:
    raise AssertionError('Expected: %s\nGot: %s' % (b, a))


def assert_result_matches_reference(
    tangent_func, reference_func, backup_reference_func, tolerance=1e-7):
  """Test Tangent functionality against reference implementation.

  Args:
    tangent_func: Returns the Tangent derivative.
    reference_func: Returns the derivative calculated by the reference
        implementation.
    backup_reference_func: Returns the derivative calculated by a catch-all
        implementation, used when `reference_func` raises `ImportError` or
        `TypeError`.
    tolerance: Relative tolerance for FP comparisons (passed as `rtol` to
        `np.allclose`).

  Raises:
    AssertionError: If the Tangent result differs from the reference.
  """
  tangent_value = tangent_func()
  try:
    reference_value = reference_func()
  except (ImportError, TypeError) as e:
    if __debug__:
      print('WARNING: Reference function call failed. The test will revert to '
            'the backup reference.\nReason for failure: %s' % e)
    # TODO: Try to narrow the exception handler.
    reference_value = backup_reference_func()
  _assert_allclose(tangent_value, reference_value, tolerance)


def numeric_grad(func, eps=1e-6):
  """Generate a central finite-differences gradient of `func`.

  Usage:

    def f(x, *args):
      ...
      return scalar

    g = numeric_grad(f, eps=1e-4)
    finite_difference_grad_of_x = g(x, *args)

  Only the first argument `x` is perturbed; `args` are passed through as-is.
  Each difference quotient is contracted with an all-ones seed shaped like
  `func`'s output via `np.dot`.

  Adapted from github.com/hips/autograd

  Args:
    func: The function to differentiate.
    eps: Total step width; each flattened element of `x` is moved by
        +eps / 2 and -eps / 2.

  Returns:
    A function `g(x, *args)` returning the finite-difference gradient with
    the same structure as `tangent.init_grad(x)`.
  """
  def g(x, *args):
    fd_grad, unflatten_fd = flatten(tangent.init_grad(x))
    y = func(deepcopy(x), *args)
    seed = np.ones_like(y)
    for d in range(fd_grad.size):
      # Re-flatten a fresh copy for each evaluation so that `func` mutating
      # its input cannot contaminate later evaluations.
      x_flat, unflatten_x = flatten(deepcopy(x))
      x_flat[d] += eps / 2
      a = np.array(func(unflatten_x(x_flat), *args))
      x_flat, unflatten_x = flatten(deepcopy(x))
      x_flat[d] -= eps / 2
      b = np.array(func(unflatten_x(x_flat), *args))
      fd_grad[d] = np.dot((a - b) / eps, seed)
    return unflatten_fd(fd_grad)

  return g


def _make_reference_funcs(func, preserve_result, args):
  """Build the autograd reference and finite-difference backup for `func`.

  Both returned callables take no arguments and return either the gradient
  or, when `preserve_result` is true, a `(gradient, value)` pair. Each one
  rebinds `func`'s global `np` only for the duration of its own call.

  NOTE(review): `ag_grad` / `ag_value_and_grad` are called with their default
  argument index, and `numeric_grad` only perturbs the first argument, so
  these references presumably assume differentiation w.r.t. argument 0 —
  confirm for callers passing other `wrt` values.

  Args:
    func: The function under test.
    preserve_result: Whether to also return the function's value.
    args: Tuple of positional arguments; deep-copied for every call.

  Returns:
    A `(reference_func, backup_reference_func)` pair.
  """
  def reference_func():
    with _numpy_module_for(func, ag_np):
      if preserve_result:
        # Note: ag_value_and_grad returns (val, grad) but we need (grad, val).
        val, gradval = ag_value_and_grad(func)(*deepcopy(args))
        return gradval, val
      return ag_grad(func)(*deepcopy(args))

  def backup_reference_func():
    with _numpy_module_for(func, np):
      df_num = numeric_grad(func)
      gradval = df_num(*deepcopy(args))
      if preserve_result:
        val = func(*deepcopy(args))
        return gradval, val
      return gradval

  return reference_func, backup_reference_func


def test_reverse_array(func, motion, optimized, preserve_result, *args):
  """Test gradients of functions with NumPy-compatible signatures.

  Args:
    func: The function under test; it looks up `np` in its module globals,
        which are temporarily rebound to plain NumPy or autograd's NumPy.
    motion: Passed to `tangent.autodiff`. With 'joint' the seed gradient is
        passed positionally, otherwise as the `init_grad` keyword.
    optimized: Passed to `tangent.autodiff`.
    preserve_result: Passed to `tangent.autodiff`; the references then also
        return the function value.
    *args: Positional arguments for `func`; deep-copied for every call.

  Raises:
    AssertionError: If Tangent's result differs from the reference.
  """
  def tangent_func():
    # Rebind `np` before evaluating `func` so the seed shape is computed with
    # plain NumPy rather than a binding left over from another test.
    with _numpy_module_for(func, np):
      y = func(*deepcopy(args))
      # Scalar seed for single-element outputs, all-ones array otherwise.
      if np.array(y).size > 1:
        init_grad = np.ones_like(y)
      else:
        init_grad = 1
      df = tangent.autodiff(
          func,
          mode='reverse',
          motion=motion,
          optimized=optimized,
          preserve_result=preserve_result,
          verbose=1)
      if motion == 'joint':
        return df(*(deepcopy(args) + (init_grad,)))
      return df(*deepcopy(args), init_grad=init_grad)

  reference_func, backup_reference_func = _make_reference_funcs(
      func, preserve_result, args)
  assert_result_matches_reference(tangent_func, reference_func,
                                  backup_reference_func)


def test_forward_array(func, wrt, preserve_result, *args):
  """Test derivatives of functions with NumPy-compatible signatures.

  Args:
    func: The function under test; it looks up `np` in its module globals,
        which are temporarily rebound to plain NumPy or autograd's NumPy.
    wrt: Argument indices passed to `tangent.autodiff`.
    preserve_result: Passed to `tangent.autodiff`; the references then also
        return the function value.
    *args: Positional arguments for `func`; deep-copied for every call.

  Raises:
    AssertionError: If Tangent's result differs from the reference.
  """
  def tangent_func():
    with _numpy_module_for(func, np):
      df = tangent.autodiff(
          func,
          mode='forward',
          preserve_result=preserve_result,
          wrt=wrt,
          optimized=True,
          verbose=1)
      args_ = args + (1.0,)  # seed gradient
      return df(*deepcopy(args_))

  reference_func, backup_reference_func = _make_reference_funcs(
      func, preserve_result, args)
  assert_result_matches_reference(tangent_func, reference_func,
                                  backup_reference_func)