"""gelcd.py: coordinate descent for group elastic net. This implementation uses block-wise coordinate descent to solve the optimization problem: min_b (1/2m)||y - b_0 - sum_j A_j@b_j||^2 + sum_j {sqrt{n_j}(l_1*||b_j|| + l_2*||b_j||^2)}. The algorithm repeatedly minimizes with respect to b_0 and the b_js. For b_0, the minimization has a closed form solution. For each b_j, the minimization objective is convex, twice differentiable; and can be solved using gradient descent, Newton's method etc. This implementation allows the internal minimizer to be chosen. These are the block_solve_* functions. Each solves the above optimization problem with respect to one of the b_js, while keeping the others fixed. The gradient with respect to b_j is (-1/m)A_j'@(y - b_0 - sum_j A_j@b_j) + sqrt{n_j}(l_1*b_j/||b_j|| + 2*l_2*b_j). and the Hessian is (1/m)A_j'@A_j + sqrt{n_j}(l_1*(I/||b_j|| - b_j@b_j'/||b_j||^3) + 2*l_2*I). The coefficients are represented using a scalar bias b_0 and a matrix B where each row corresponds to a single group (with appropriate 0 padding). The root of the group sizes are stored in a vector sns. The features are stored in a list of tensors each of size m x n_j where m is the number of samples, and n_j is the number of features in group j. For any j, r_j = y - b_0 - sum_{k =/= j} A_k@b_k, q_j = r_j - A_j@b_j, C_j = A_j'@A_j / m, and I_j is an identity matrix of size n_j. Finally, a_1 = l_1*sns and a_2 = 2*l_2*sns. """ import torch import tqdm def _f_j(q_j, b_j_norm, a_1_j, a_2_j, m): """Compute the objective with respect to one of the coefficients i.e. (1/2m)||q_j||^2 + a_1||b_j|| + (a_2/2)||b_j||^2. """ return ( ((q_j @ q_j) / (2.0 * m)) + (a_1_j * b_j_norm) + ((a_2_j / 2.0) * (b_j_norm ** 2)) ) def _grad_j(q_j, A_j, b_j, b_j_norm, a_1_j, a_2_j, m): """Compute the gradient with respect to one of the coefficients.""" return (A_j.t() @ q_j / (-m)) + (b_j * (a_1_j / b_j_norm + a_2_j)) def _hess_j(C_j, I_j, b_j, b_j_norm, a_1_j, a_2_j): """Compute the Hessian with respect to one of the coefficients.""" D_j = torch.ger(b_j, b_j) return C_j + (a_1_j / b_j_norm) * (I_j - D_j / (b_j_norm ** 2)) + a_2_j * I_j def block_solve_agd( r_j, A_j, a_1_j, a_2_j, m, b_j_init, t_init=None, ls_beta=None, max_iters=None, rel_tol=1e-6, verbose=False, zero_thresh=1e-6, zero_fill=1e-3, ): """Solve the optimization problem for a single block with accelerated gradient descent.""" b_j = b_j_init b_j_prev = b_j k = 1 # iteration number t = 1 # initial step length (used if t_init is None) pbar_stats = {} # stats for the progress bar pbar = tqdm.tqdm(desc="Solving block with AGD", disable=not verbose, leave=False) while True: # Compute the v terms. mom = (k - 2) / (k + 1.0) v_j = b_j + mom * (b_j - b_j_prev) q_v_j = r_j - A_j @ v_j v_j_norm = v_j.norm(p=2) f_v_j = _f_j(q_v_j, v_j_norm, a_1_j, a_2_j, m) grad_v_j = _grad_j(q_v_j, A_j, v_j, v_j_norm, a_1_j, a_2_j, m) b_j_prev = b_j # Adjust the step size with backtracking line search. if t_init is not None: t = t_init while True: b_j = v_j - t * grad_v_j # gradient descent update if ls_beta is None: # Don't perform line search. break # Line search: exit when f_j(b_j) <= f_j(v_j) + grad_v_j'@(b_j - # v_j) + (1/2t)||b_j - v_j||^2. q_b_j = r_j - A_j @ b_j b_j_norm = b_j.norm(p=2) f_b_j = _f_j(q_b_j, b_j_norm, a_1_j, a_2_j, m) b_v_diff = b_j - v_j c2 = grad_v_j @ b_v_diff c3 = b_v_diff @ b_v_diff / 2.0 if t * f_b_j <= t * (f_v_j + c2) + c3: break else: t *= ls_beta # Make b_j non-zero if it is 0. 
        if len((b_j.abs() < zero_thresh).nonzero()) == len(b_j):
            b_j.fill_(zero_fill)

        b_j_norm = b_j.norm(p=2)
        b_diff_norm = (b_j - b_j_prev).norm(p=2)
        pbar_stats["t"] = "{:.2g}".format(t)
        pbar_stats["rel change"] = "{:.2g}".format(b_diff_norm / b_j_norm)
        pbar.set_postfix(pbar_stats)
        pbar.update()

        # Check max iterations exit criterion.
        if max_iters is not None and k == max_iters:
            break
        k += 1

        # Check tolerance exit criterion. Exit when the relative change is
        # less than the tolerance.
        # k > 2 ensures that at least 2 iterations are completed.
        if b_diff_norm <= rel_tol * b_j_norm and k > 2:
            break

    pbar.close()
    return b_j


def block_solve_newton(
    r_j,
    A_j,
    a_1_j,
    a_2_j,
    m,
    b_j_init,
    C_j,
    I_j,
    ls_alpha=0.5,
    ls_beta=0.9,
    max_iters=20,
    tol=1e-8,
    verbose=False,
):
    """Solve the optimization problem for a single block with Newton's method."""
    b_j = b_j_init
    k = 1
    pbar_stats = {}  # stats for the progress bar
    pbar = tqdm.tqdm(
        desc="Solving block with Newton's method", disable=not verbose, leave=False
    )

    while True:
        # First, compute the Newton step and decrement.
        q_b_j = r_j - A_j @ b_j
        b_j_norm = b_j.norm(p=2)
        grad_b_j = _grad_j(q_b_j, A_j, b_j, b_j_norm, a_1_j, a_2_j, m)
        hess_b_j = _hess_j(C_j, I_j, b_j, b_j_norm, a_1_j, a_2_j)
        hessinv_b_j = torch.inverse(hess_b_j)
        v_j = hessinv_b_j @ grad_b_j
        dec_j = grad_b_j @ (hessinv_b_j @ grad_b_j)

        # Check tolerance stopping criterion. Exit if dec_j / 2 is less than
        # the tolerance.
        if dec_j / 2 <= tol:
            break

        # Perform backtracking line search.
        t = 1
        f_b_j = _f_j(q_b_j, b_j_norm, a_1_j, a_2_j, m)
        k_j = grad_b_j @ v_j
        while True:
            # Compute the update and evaluate the function at that point.
            bp_j = b_j - t * v_j
            q_bp_j = r_j - A_j @ bp_j
            bp_j_norm = bp_j.norm(p=2)
            f_bp_j = _f_j(q_bp_j, bp_j_norm, a_1_j, a_2_j, m)
            if f_bp_j <= f_b_j - ls_alpha * t * k_j:
                b_j = bp_j
                break
            else:
                t *= ls_beta

        # Make b_j non-zero if it is 0.
        if all(b_j.abs() < tol):
            b_j.fill_(1e-3)

        pbar_stats["t"] = "{:.2g}".format(t)
        pbar_stats["1/2 newton decrement"] = "{:.2g}".format(dec_j / 2)
        pbar.set_postfix(pbar_stats)
        pbar.update()

        # Check max iterations stopping criterion.
        if max_iters is not None and k == max_iters and k > 2:
            break
        k += 1

    pbar.close()
    return b_j


def make_A(
    As, ns, device=torch.device("cpu"), dtype=torch.float32
):  # pylint: disable=unused-argument
    """Move the As to the provided device, convert them to the provided dtype,
    and return them as A."""
    As = [A_j.to(device, dtype) for A_j in As]
    return As


def gel_solve(
    A,
    y,
    l_1,
    l_2,
    ns,
    b_init=None,
    block_solve_fun=block_solve_agd,
    block_solve_kwargs=None,
    max_cd_iters=None,
    rel_tol=1e-6,
    Cs=None,
    Is=None,
    verbose=False,
):
    """Solve a group elastic net problem.

    Arguments:
        A: list of feature tensors, one per group (size m x n_j).
        y: response tensor (the vector of values being predicted).
        l_1: the 2-norm coefficient.
        l_2: the squared 2-norm coefficient.
        ns: iterable of group sizes.
        b_init: tuple (b_0, B) to initialize b.
        block_solve_fun: the function to be used for minimizing individual
            blocks (should be one of the block_solve_* functions).
        block_solve_kwargs: dictionary of arguments to be passed to
            block_solve_fun.
        max_cd_iters: maximum number of outer CD iterations.
        rel_tol: tolerance for exit criterion; after every outer CD iteration,
            b is compared to the previous value, and if the relative
            difference is less than this value, the loop is terminated.
        Cs, Is: lists of C_j and I_j matrices respectively; used if the
            internal solver is Newton's method.
        verbose: boolean to enable/disable verbosity.
""" p = len(A) m = len(y) device = A[0].device dtype = A[0].dtype y = y.to(device, dtype) if block_solve_kwargs is None: block_solve_kwargs = dict() # Create initial values if not specified. if b_init is None: b_init = 0.0, torch.zeros(p, max(ns), device=device, dtype=dtype) if not isinstance(ns, torch.Tensor): ns = torch.tensor(ns) sns = ns.to(device, dtype).sqrt() a_1 = l_1 * sns ma_1 = m * a_1 a_2 = 2 * l_2 * sns b_0, B = b_init b_0_prev, B_prev = b_0, B k = 1 # iteration number pbar_stats = {} # stats for the outer progress bar pbar = tqdm.tqdm( desc="Solving gel with CD (l_1 {:.2g}, l_2 {:.2g})".format(l_1, l_2), disable=not verbose, ) while True: # First minimize with respect to b_0. This has a closed form solution # given by b_0 = 1'@(y - sum_j A_j@b_j) / m. b_0 = (y - sum(A[j] @ B[j, : ns[j]] for j in range(p))).sum() / m # Now, minimize with respect to each b_j. for j in tqdm.trange( p, desc="Solving individual blocks", disable=not verbose, leave=False ): r_j = y - b_0 - sum(A[k] @ B[k, : ns[k]] for k in range(p) if k != j) # Check if b_j must be set to 0. The condition is ||A_j'@r_j|| <= # m*a_1. if (A[j].t() @ r_j).norm(p=2) <= ma_1[j]: B[j] = 0 else: # Otherwise, minimize. First make sure initial value is not 0. if len((B[j, : ns[j]].abs() < 1e-6).nonzero()) == ns[j]: B[j, : ns[j]] = 1e-3 # Add C_j and I_j to the arguments if using Newton's method. if block_solve_fun is block_solve_newton: block_solve_kwargs["C_j"] = Cs[j] block_solve_kwargs["I_j"] = Is[j] B[j, : ns[j]] = block_solve_fun( r_j, A[j], a_1[j].item(), a_2[j].item(), m, B[j, : ns[j]], verbose=verbose, **block_solve_kwargs, ) # Compute relative change in b. b_0_diff = b_0 - b_0_prev B_diff = B - B_prev delta_norm = (b_0_diff ** 2 + (B_diff ** 2).sum()).sqrt() b_norm = (b_0 ** 2 + (B ** 2).sum()).sqrt() pbar_stats["rel change"] = "{:.2g}".format(delta_norm.item() / b_norm.item()) pbar.set_postfix(pbar_stats) pbar.update() # Check max iterations exit criterion. if max_cd_iters is not None and k == max_cd_iters: break k += 1 # Check tolerance exit criterion. if delta_norm.item() <= rel_tol * b_norm.item() and k > 2: break b_0_prev, B_prev = b_0, B pbar.close() return b_0.item(), B