#!/usr/bin/env python3 # # Run these tests with: nosetests -v -d test-adact-np.py # This will run all functions even if one throws an assertion. # # For debugging: ./test-adact-back.py # Easier to print statements. # This will exit after the first assertion. import os import sys import torch import numpy as np import numpy.random as npr import numpy.testing as npt np.set_printoptions(precision=2) import numdifftools as nd import cvxpy as cp import adact import adact_forward_ip as aip from solver import BlockSolver as Solver from nose.tools import with_setup, assert_almost_equal ATOL=1e-2 RTOL=1e-7 npr.seed(1) nz, neq, nineq = 5,0,4 # nz, neq, nineq = 3,3,3 L = np.tril(np.random.randn(nz,nz)) + 2.*np.eye(nz,nz) Q = L.dot(L.T)+1e-8*np.eye(nz) G = 1000.*npr.randn(nineq,nz) A = 10000.*npr.randn(neq,nz) z0 = 1.*npr.randn(nz) s0 = 100.*np.ones(nineq) p = npr.randn(nz) truez = npr.randn(nz) af = adact.AdactFunction() zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) dl_dzhat = zhat-truez # dp, dL, dG, dA, dz0, ds0 = af.backward_single_np(zhat, nu, lam, dl_dzhat, L, G, A, z0, s0) S = Solver(L, A, G, z0, s0, 1e-8) S.reinit(lam, zhat) dp, dL, dG, dA, dz0, ds0 = af.backward_single_np_solver(S, zhat, nu, lam, dl_dzhat, L, G, A, z0, s0) verbose = True def test_ip_forward(): p_t, Q_t, G_t, A_t, z0_t, s0_t = [torch.Tensor(x) for x in [p, Q, G, A, z0, s0]] b = torch.mv(A_t, z0_t) if neq > 0 else None h = torch.mv(G_t,z0_t)+s0_t L_Q, L_S, R = aip.pre_factor_kkt(Q_t, G_t, A_t) zhat_ip, nu_ip, lam_ip = aip.forward_single(p_t, Q_t, G_t, A_t, b, h, L_Q, L_S, R) # Unnecessary clones here because of a pytorch bug when calling numpy # on a tensor with a non-zero offset. npt.assert_allclose(zhat, zhat_ip.clone().numpy(), rtol=RTOL, atol=ATOL) if neq > 0: npt.assert_allclose(nu, nu_ip.clone().numpy(), rtol=RTOL, atol=ATOL) npt.assert_allclose(lam, lam_ip.clone().numpy(), rtol=RTOL, atol=ATOL) def test_dl_dz0(): def f(z0): zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dz0_fd = df(z0) if verbose: print('dz0_fd: ', dz0_fd) print('dz0: ', dz0) npt.assert_allclose(dz0_fd, dz0, rtol=RTOL, atol=ATOL) def test_dl_ds0(): def f(s0): zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) ds0_fd = df(s0) if verbose: print('ds0_fd: ', ds0_fd) print('ds0: ', ds0) npt.assert_allclose(ds0_fd, ds0, rtol=RTOL, atol=ATOL) def test_dl_dp(): def f(p): zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dp_fd = df(p) if verbose: print('dp_fd: ', dp_fd) print('dp: ', dp) npt.assert_allclose(dp_fd, dp, rtol=RTOL, atol=ATOL) def test_dl_dp_batch(): def f(p): zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dp_fd = df(p) if verbose: print('dp_fd: ', dp_fd) print('dp: ', dp) npt.assert_allclose(dp_fd, dp, rtol=RTOL, atol=ATOL) def test_dl_dA(): def f(A): A = A.reshape(neq,nz) zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dA_fd = df(A.ravel()).reshape(neq, nz) if verbose: print('dA_fd[1,:]: ', dA_fd[1,:]) print('dA[1,:]: ', dA[1,:]) npt.assert_allclose(dA_fd, dA, rtol=RTOL, atol=ATOL) def test_dl_dG(): def f(G): G = G.reshape(nineq,nz) zhat, nu, lam = af.forward_single_np(p, L, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dG_fd = df(G.ravel()).reshape(nineq, nz) if verbose: print('dG_fd[1,:]: ', dG_fd[1,:]) print('dG[1,:]: ', dG[1,:]) npt.assert_allclose(dG_fd, dG, rtol=RTOL, atol=ATOL) def test_dl_dL(): def f(l0): L_ = np.copy(L) L_[:,0] = l0 zhat, nu, lam = af.forward_single_np(p, L_, G, A, z0, s0) return 0.5*np.sum(np.square(zhat - truez)) df = nd.Gradient(f) dL_fd = df(L[:,0]) dl0 = np.array(dL[:,0]).ravel() if verbose: print('dL_fd: ', dL_fd) print('dL: ', dl0) npt.assert_allclose(dL_fd, dl0, rtol=RTOL, atol=ATOL) if __name__=='__main__': # test_ip_forward() test_dl_dp() # test_dl_dp_batch() # test_dl_dz0() # test_dl_ds0() # if neq > 0: # test_dl_dA() # test_dl_dG() # test_dl_dL()