"""Test the functions in utilities.py""" import pytest try: import torch from safe_exploration.ssm_pytorch import compute_jacobian, update_cholesky, SetTorchDtype except: pass @pytest.fixture(autouse = True) def check_has_ssm_pytorch_module(check_has_ssm_pytorch): pass class TestJacobian(object): def test_error(self): """Test assertion error raised when grad is missing.""" with pytest.raises(AssertionError): compute_jacobian(None, torch.ones(2, 1)) def test_0d(self): x = torch.ones(2, 2, requires_grad=True) A = torch.tensor([[1., 2.], [3., 4.]]) f = A * x f = torch.sum(f) jac = compute_jacobian(f, x) torch.testing.assert_allclose(jac, A) def test_1d(self): """Test jacobian function for 1D inputs.""" x = torch.ones(1, requires_grad=True) f = 2 * x jac = compute_jacobian(f, x) torch.testing.assert_allclose(jac[0, 0], 2) def test_2d(self): """Test jacobian computation.""" x = torch.ones(2, 1, requires_grad=True) A = torch.tensor([[1., 2.], [3., 4.]]) f = A @ x jac = compute_jacobian(f, x) torch.testing.assert_allclose(A, jac[:, 0, :, 0]) # Test both multiple runs jac = compute_jacobian(f.squeeze(-1), x) torch.testing.assert_allclose(A, jac.squeeze(-1)) def test_2d_output(self): """Test jacobian with 2d input and output""" x = torch.ones(2, 2, requires_grad=True) A = torch.tensor([[1., 2.], [3., 4.]]) f = A * x jac = compute_jacobian(f, x) torch.testing.assert_allclose(jac.shape, 2) torch.testing.assert_allclose(jac.sum(dim=0).sum(dim=0), A) @pytest.mark.xfail(reason="There seems to be a dimensionality error.") def test_update_cholesky(): """Test that the update cholesky function returns correct values.""" n = 6 new_A = torch.rand(n, n, dtype=torch.float64) new_A = new_A @ new_A.t() new_A += torch.eye(len(new_A), dtype=torch.float64) A = new_A[:n - 1, :n - 1] old_chol = torch.cholesky(A, upper=False) new_row = new_A[-1] # Test updateing overall new_chol = update_cholesky(old_chol, new_row) error = new_chol - torch.cholesky(new_A, upper=False) assert torch.all(torch.abs(error) <= 1e-15) # Test updating inplace new_chol = torch.zeros(n, n, dtype=torch.float64) new_chol[:n - 1, :n - 1] = old_chol update_cholesky(old_chol, new_row, chol_row_out=new_chol[-1]) error = new_chol - torch.cholesky(new_A, upper=False) assert torch.all(torch.abs(error) <= 1e-15) def test_set_torch_dtype(): """Test dtype context manager.""" dtype = torch.get_default_dtype() torch.set_default_dtype(torch.float32) with SetTorchDtype(torch.float64): a = torch.zeros(1) assert a.dtype is torch.float64 b = torch.zeros(1) assert b.dtype is torch.float32 torch.set_default_dtype(dtype)