"""test_gel.py: framework to test gel implementations.""" import itertools import os import unittest import cvxpy as cvx import numpy as np from scipy.spatial.distance import cosine import torch from gel.gelcd import ( block_solve_agd, block_solve_newton, gel_solve as gel_solve_cd, make_A as make_A_cd, ) from gel.gelfista import gel_solve as gel_solve_fista, make_A as make_A_fista def gel_solve_cvx(As, y, l_1, l_2, ns): """Solve a group elastic net problem with cvx. Arguments: As: list of tensors. y: tensor. l_1, l_2: floats. ns: iterable. """ # Convert everything to numpy dtype = As[0].dtype As = [A_j.cpu().numpy() for A_j in As] y = y.cpu().numpy() ns = np.array([int(n) for n in ns]) # Create the b variables. b_0 = cvx.Variable() bs = [] for _, n_j in zip(As, ns): bs.append(cvx.Variable(n_j)) # Form g(b). Ab = sum(A_j * b_j for A_j, b_j in zip(As, bs)) m = As[0].shape[0] g_b = cvx.square(cvx.norm(y - b_0 - Ab)) / (2 * m) # Form h(b). h_b = sum( np.sqrt(n_j) * (l_1 * cvx.norm(b_j) + l_2 * cvx.square(cvx.norm(b_j))) for n_j, b_j in zip(ns, bs) ) # Build the optimization problem. obj = cvx.Minimize(g_b + h_b) problem = cvx.Problem(obj, constraints=None) problem.solve(solver="CVXOPT") b_0 = b_0.value.item() # Form B as returned by gel_solve. p = len(As) B = torch.zeros(p, int(max(ns)), dtype=dtype) for j in range(p): b_j = np.asarray(bs[j].value) B[j, : ns[j]] = torch.from_numpy(b_j) return b_0, B def block_solve_cvx(r_j, A_j, a_1_j, a_2_j, m, b_j_init, verbose=False): # pylint: disable=unused-argument """Solve the gelcd optimization problem for a single block with cvx. b_j_init and verbose are ignored. b_j_init because cvx doesn't support it. verbose because it doesn't go together with tqdm. """ # Convert everything to numpy. device = A_j.device dtype = A_j.dtype r_j = r_j.cpu().numpy() A_j = A_j.cpu().numpy() # Create the b_j variable. b_j = cvx.Variable(A_j.shape[1]) # Form the objective. q_j = r_j - A_j * b_j obj_fun = cvx.square(cvx.norm(q_j)) / (2.0 * m) obj_fun += a_1_j * cvx.norm(b_j) + (a_2_j / 2.0) * cvx.square(cvx.norm(b_j)) # Build the optimization problem. obj = cvx.Minimize(obj_fun) problem = cvx.Problem(obj, constraints=None) problem.solve(solver="CVXOPT", verbose=False) b_j = np.asarray(b_j.value) return torch.from_numpy(b_j).to(device, dtype) def _b2vec(B, groups): """Convert B as returned by gel_solve functions to a single numpy vector.""" d = sum(len(group_j) for group_j in groups) # the total dimension b = np.zeros(d, dtype=B[0, 0].cpu().numpy().dtype) for j, group_j in enumerate(groups): b[group_j] = B[j, : len(group_j)].cpu().numpy() return b class TestGelBirthwtBase: """Base class to test different gel_solve implementations with the birth weight data.""" l_1_base = 4.0 l_2_base = 0.5 def __init__(self, device, dtype, *args, **kwargs): """Load data and solve with cvx to get ground truth solution.""" super().__init__(*args, **kwargs) self.device = device self.dtype = dtype dtype_np = torch.rand(0, dtype=dtype).numpy().dtype data_dir = os.path.join(os.path.dirname(__file__), "data", "birthwt") self.X = np.loadtxt( os.path.join(data_dir, "X.csv"), skiprows=1, delimiter=",", dtype=dtype_np ) self.y = np.loadtxt(os.path.join(data_dir, "y.csv"), skiprows=1, dtype=dtype_np) self.m = len(self.y) self.l_1 = self.l_1_base / (2 * self.m) self.l_2 = self.l_2_base / (2 * self.m) self.groups = [ [0, 1, 2], [3, 4, 5], [6, 7], [8], [9, 10], [11], [12], [13, 14, 15], ] self.done_setup = False def setUp(self): if self.device.type == "cuda" and not torch.cuda.is_available(): raise unittest.SkipTest("cuda unavailable") if self.done_setup: return self.ns = torch.tensor([len(g) for g in self.groups]) self.p = len(self.groups) # Convert things to gel format. self.As = [] for j in range(self.p): A_j = self.X[:, self.groups[j]] self.As.append(torch.from_numpy(A_j).to(self.device, self.dtype)) self.yt = torch.from_numpy(self.y).to(self.device, self.dtype) # Solve with cvx. self.b_0_cvx, self.B_cvx = gel_solve_cvx( self.As, self.yt, self.l_1, self.l_2, self.ns ) self.b_cvx = _b2vec(self.B_cvx, self.groups) self.obj_cvx = self._obj(self.b_0_cvx, self.b_cvx) self.done_setup = True def _obj(self, b_0, b): """Compute the objective function value for the given b_0, b.""" r = self.y - b_0 - self.X @ b g_b = r @ r / (2.0 * self.m) b_j_norms = [np.linalg.norm(b[self.groups[j]], ord=2) for j in range(self.p)] h_b = self.l_1 * sum( np.sqrt(len(self.groups[j])) * b_j_norms[j] for j in range(self.p) ) h_b += self.l_2 * sum( np.sqrt(len(self.groups[j])) * (b_j_norms[j] ** 2) for j in range(self.p) ) return g_b + h_b def _compare_to_cvx(self, b_0, b, obj): """Compare the given solution to the cvx solution.""" # pylint: disable=no-member self.assertAlmostEqual(obj, self.obj_cvx, places=2) self.assertAlmostEqual(b_0, self.b_0_cvx, places=2) if np.allclose(b, 0) or np.allclose(self.b_cvx, 0): for b_i, b_cvx_i in zip(b, self.b_cvx): self.assertAlmostEqual(b_i, b_cvx_i, places=2) else: self.assertAlmostEqual(cosine(b, self.b_cvx), 0, places=2) def _test_implementation(self, make_A, gel_solve, **gel_solve_kwargs): """Test the given implementation.""" A = make_A(self.As, self.ns, self.device, self.dtype) b_0, B = gel_solve(A, self.yt, self.l_1, self.l_2, self.ns, **gel_solve_kwargs) b = _b2vec(B, self.groups) obj = self._obj(b_0, b) self._compare_to_cvx(b_0, b, obj) def test_fista(self): """Test the FISTA implementation of gel_solve.""" self._test_implementation( make_A_fista, gel_solve_fista, t_init=0.1, ls_beta=0.9, max_iters=1000, rel_tol=1e-6, ) def test_cd_cvx(self): """Test the CD implementation with cvx internal solver.""" self._test_implementation( make_A_cd, gel_solve_cd, block_solve_fun=block_solve_cvx, block_solve_kwargs={}, max_cd_iters=100, rel_tol=1e-6, ) def test_cd_agd(self): """Test the CD implementation with AGD internal solver.""" self._test_implementation( make_A_cd, gel_solve_cd, block_solve_fun=block_solve_agd, block_solve_kwargs={ "t_init": 1, "ls_beta": 0.5, "max_iters": 100, "rel_tol": 1e-5, }, max_cd_iters=100, rel_tol=1e-6, ) def test_cd_newton(self): """Test the CD implementation with Newton internal solver.""" # Compute the C_js and I_js. Cs = [(A_j.t() @ A_j) / self.m for A_j in self.As] Is = [torch.eye(n_j, device=self.device, dtype=self.dtype) for n_j in self.ns] self._test_implementation( make_A_cd, gel_solve_cd, block_solve_fun=block_solve_newton, block_solve_kwargs={ "ls_alpha": 0.1, "ls_beta": 0.5, "max_iters": 10, "tol": 1e-10, }, max_cd_iters=100, rel_tol=1e-6, Cs=Cs, Is=Is, ) def create_gel_birthwt_test(device_name, dtype, *mods): # I'm so sorry. device = torch.device(device_name) def __init__(self, *args, **kwargs): TestGelBirthwtBase.__init__(self, device, dtype, *args, **kwargs) for mod in mods: if mod == "l10": self.l_1 = 0 elif mod == "l20": self.l_2 = 0 elif mod == "nj1": self.groups = [[i] for i in range(self.X.shape[1])] else: raise RuntimeError("unrecognized mod: " + mod) _doc = "Test gel implementations on {} with {}".format(device_name, dtype) if mods: _doc += " (mods: " + ", ".join(mods) + ")" test_name = "TestGelBirthwt" + device_name.upper() + str(dtype)[-2:] if mods: test_name += "_" + "".join(str(m) for m in mods) globals()[test_name] = type( test_name, (TestGelBirthwtBase, unittest.TestCase), {"__init__": __init__, "__doc__": _doc}, ) _mods = ["l10", "l20", "nj1"] _mod_subsets = set( frozenset(s) for s in itertools.combinations_with_replacement(_mods, len(_mods)) ) _mod_subsets.add(frozenset()) for _device_name, _dtype, _mod_subset in itertools.product( ["cpu", "cuda"], [torch.float32, torch.float64], _mod_subsets ): create_gel_birthwt_test(_device_name, _dtype, *list(sorted(_mod_subset)))