from __future__ import division
from contextlib import contextmanager
import pytest
import warnings
import autograd.numpy as np
from autograd import grad, deriv
from autograd.extend import primitive
from autograd.test_util import check_grads
from autograd.core import primitive_vjps


def test_assert():
    """Gradient-check a function whose body contains an ``assert``."""
    # from https://github.com/HIPS/autograd/issues/43
    def fun(x):
        assert np.allclose(x, (x*3.0)/3.0)
        return np.sum(x)

    check_grads(fun)(np.array([1.0, 2.0, 3.0]))


def test_nograd():
    """Expect ``TypeError`` from ``grad`` of a function returning ``np.allclose``."""
    # we want this to raise non-differentiability error
    fun = lambda x: np.allclose(x, (x*3.0)/3.0)
    with pytest.raises(TypeError):
        # Record warnings instead of letting them reach the test output.
        with warnings.catch_warnings(record=True):
            grad(fun)(np.array([1., 2., 3.]))


def test_no_vjp_def():
    """Expect ``NotImplementedError`` from ``grad`` of a bare primitive."""
    fun = primitive(lambda x: 2. * x)
    with pytest.raises(NotImplementedError):
        grad(fun)(1.)


def test_no_jvp_def():
    """Expect ``NotImplementedError`` from ``deriv`` of a bare primitive."""
    fun = primitive(lambda x: 2. * x)
    with pytest.raises(NotImplementedError):
        deriv(fun)(1.)


def test_falseyness():
    """Gradient-check a function that branches on ``np.iscomplex(x)``."""
    fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x))
    check_grads(fun)(2.)
    check_grads(fun)(2. + 1j)


def test_unimplemented_falseyness():
    """Repeat the ``np.iscomplex`` branching check with its
    ``primitive_vjps`` entry temporarily removed.
    """
    @contextmanager
    def remove_grad_definitions(fun):
        """Pop ``fun`` from ``primitive_vjps`` for the duration of the block.

        The popped entry (if there was one) is put back on exit, including
        when the body raises, so a failure here cannot leave the module-level
        registry mutated for tests that run afterwards.
        """
        vjpmaker = primitive_vjps.pop(fun, None)
        try:
            yield
        finally:
            # ``None`` is the pop default, i.e. there was nothing to restore.
            # Compare by identity so a present-but-falsy entry is restored too.
            if vjpmaker is not None:
                primitive_vjps[fun] = vjpmaker

    with remove_grad_definitions(np.iscomplex):
        fun = lambda x: np.real(x**2 if np.iscomplex(x) else np.sum(x))
        check_grads(fun)(5.)
        check_grads(fun)(2. + 1j)