import unittest

import torch
import torchdiffeq

from problems import construct_problem

# Absolute tolerance when comparing adjoint-method gradients against
# gradients from plain autograd through the solver.
eps = 1e-12

torch.set_default_dtype(torch.float64)
TEST_DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def max_abs(tensor):
    """Return the maximum absolute entry of `tensor` (L-infinity norm)."""
    return torch.max(torch.abs(tensor))


class TestGradient(unittest.TestCase):
    """Verify odeint gradients w.r.t. y0 and t via torch.autograd.gradcheck."""

    def _check_gradient(self, method):
        """Run gradcheck on odeint with the given solver `method`.

        All per-solver tests below are identical except for the method
        string, so the shared driver lives here.
        """
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method=method)
        self.assertTrue(torch.autograd.gradcheck(func, (y0, t_points)))

    def test_midpoint(self):
        self._check_gradient('midpoint')

    def test_rk4(self):
        self._check_gradient('rk4')

    def test_dopri5(self):
        self._check_gradient('dopri5')

    def test_adams(self):
        self._check_gradient('adams')

    def test_adaptive_heun(self):
        self._check_gradient('adaptive_heun')

    def test_dopri8(self):
        self._check_gradient('dopri8')

    def test_adjoint(self):
        """
        Test against dopri5
        """
        # Reference gradients: plain autograd through the dopri5 solver.
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)
        func = lambda y0, t_points: torchdiffeq.odeint(f, y0, t_points, method='dopri5')
        ys = func(y0, t_points)
        torch.manual_seed(0)
        gradys = torch.rand_like(ys)
        ys.backward(gradys)

        # reg_y0_grad = y0.grad
        reg_t_grad = t_points.grad
        reg_a_grad = f.a.grad
        reg_b_grad = f.b.grad

        # Recompute with the adjoint method on a freshly constructed problem
        # (fresh tensors so .grad starts out clean), reusing the same gradys.
        f, y0, t_points, _ = construct_problem(TEST_DEVICE)
        func = lambda y0, t_points: torchdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
        ys = func(y0, t_points)
        ys.backward(gradys)

        # adj_y0_grad = y0.grad
        adj_t_grad = t_points.grad
        adj_a_grad = f.a.grad
        adj_b_grad = f.b.grad

        # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
        self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
        self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
        self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)


class TestCompareAdjointGradient(unittest.TestCase):
    """Compare adjoint-method gradients against direct backprop on a fixed ODE."""

    def problem(self):
        """Build a small 2-D cubic ODE with an unused submodule.

        Returns (func, y0, t_points). The unused Linear layer lets the tests
        assert that parameters not touched by the dynamics receive exactly
        zero gradient from the adjoint pass.
        """
        class Odefunc(torch.nn.Module):
            def __init__(self):
                super(Odefunc, self).__init__()
                self.A = torch.nn.Parameter(torch.tensor([[-0.1, 2.0], [-2.0, -0.1]]))
                self.unused_module = torch.nn.Linear(2, 5)

            def forward(self, t, y):
                return torch.mm(y**3, self.A)

        y0 = torch.tensor([[2., 0.]]).to(TEST_DEVICE).requires_grad_(True)
        t_points = torch.linspace(0., 25., 10).to(TEST_DEVICE).requires_grad_(True)
        func = Odefunc().to(TEST_DEVICE)
        return func, y0, t_points

    def _compare_adjoint_against_dopri5(self, adjoint_method, tol_y0, tol_t, tol_A):
        """Shared driver: gradients from odeint_adjoint with `adjoint_method`
        must match direct backprop through dopri5 within the given
        per-quantity tolerances."""
        func, y0, t_points = self.problem()
        ys = torchdiffeq.odeint_adjoint(func, y0, t_points, method=adjoint_method)
        gradys = torch.rand_like(ys) * 0.1
        ys.backward(gradys)

        adj_y0_grad = y0.grad
        adj_t_grad = t_points.grad
        adj_A_grad = func.A.grad
        # Parameters unused by the dynamics must get exactly zero gradient.
        self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
        self.assertEqual(max_abs(func.unused_module.bias.grad), 0)

        # Reference gradients: direct backprop through dopri5 on a fresh
        # problem (fresh tensors so .grad starts out clean), same gradys.
        func, y0, t_points = self.problem()
        ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
        ys.backward(gradys)

        self.assertLess(max_abs(y0.grad - adj_y0_grad), tol_y0)
        self.assertLess(max_abs(t_points.grad - adj_t_grad), tol_t)
        self.assertLess(max_abs(func.A.grad - adj_A_grad), tol_A)

    def test_dopri5_adjoint_against_dopri5(self):
        self._compare_adjoint_against_dopri5('dopri5', tol_y0=3e-4, tol_t=1e-4, tol_A=2e-3)

    def test_adams_adjoint_against_dopri5(self):
        # Adams is less accurate here, hence the looser tolerances.
        self._compare_adjoint_against_dopri5('adams', tol_y0=5e-2, tol_t=5e-4, tol_A=2e-2)


if __name__ == '__main__':
    unittest.main()