import numpy as np #### CAUTION #### def _step_gamma(step, gamma): """Update gamma parameter for use inside of continuous proximal operator. Every proximal operator for a function with a continuous parameter, e.g. gamma ||x||_1, needs to update that parameter to account for the stepsize of the algorithm. Returns: gamma * step """ return gamma * step ################# def prox_id(X, step): """Identity proximal operator """ return X def prox_zero(X, step): """Proximal operator to project onto zero """ X[:] = np.zeros(X.shape, dtype=X.dtype) return X def prox_plus(X, step): """Projection onto non-negative numbers """ below = X < 0 X[below] = 0 return X def prox_unity(X, step, axis=0): """Projection onto sum=1 along an axis """ X[:] = X / np.sum(X, axis=axis, keepdims=True) return X def prox_unity_plus(X, step, axis=0): """Non-negative projection onto sum=1 along an axis """ X[:] = prox_unity(prox_plus(X, step), step, axis=axis) return X def prox_min(X, step, thresh=0, type="relative"): """Projection onto numbers above `thresh` If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ assert type in ["relative", "absolute"] if type == "relative": thresh_ = _step_gamma(step, thresh) else: thresh_ = thresh below = X - thresh_ < 0 X[below] = thresh_ return X def prox_max(X, step, thresh=0, type="relative"): """Projection onto numbers below `thresh` If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ assert type in ["relative", "absolute"] if type == "relative": thresh_ = _step_gamma(step, thresh) else: thresh_ = thresh above = X - thresh_ > 0 X[above] = thresh_ return X def prox_components(X, step, prox=None, axis=0): """Split X along axis and apply prox to each chunk. prox can be a list. """ K = X.shape[axis] if not hasattr(prox_list, "__iter__"): prox = [prox] * K assert len(prox_list) == K if axis == 0: Pk = [prox_list[k](X[k], step) for k in range(K)] if axis == 1: Pk = [prox_list[k](X[:, k], step) for k in range(K)] X[:] = np.stack(Pk, axis=axis) return X #### Regularization function below #### def prox_hard(X, step, thresh=0, type="relative"): """Hard thresholding X if |X| >= thresh, otherwise 0 NOTE: modifies X in place If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ assert type in ["relative", "absolute"] if type == "relative": thresh_ = _step_gamma(step, thresh) else: thresh_ = thresh below = np.abs(X) < thresh_ X[below] = 0 return X def prox_hard_plus(X, step, thresh=0, type="relative"): """Hard thresholding with projection onto non-negative numbers If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ X[:] = prox_plus(prox_hard(X, step, thresh=thresh, type=type), step) return X def prox_soft(X, step, thresh=0, type="relative"): """Soft thresholding proximal operator If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ assert type in ["relative", "absolute"] if type == "relative": thresh_ = _step_gamma(step, thresh) else: thresh_ = thresh X[:] = np.sign(X) * prox_plus(np.abs(X) - thresh_, step) return X def prox_soft_plus(X, step, thresh=0, type="relative"): """Soft thresholding with projection onto non-negative numbers If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ X[:] = prox_plus(prox_soft(X, step, thresh=thresh, type=type), step) return X def prox_max_entropy(X, step, gamma=1, type="relative"): """Proximal operator for maximum entropy regularization. g(x) = gamma sum_i x_i ln(x_i) has the analytical solution of gamma W(1/gamma exp((X-gamma)/gamma)), where W is the Lambert W function. If type == 'relative', the penalty is expressed in units of the function value; if type == 'absolute', it's expressed in units of the variable `X`. """ from scipy.special import lambertw assert type in ["relative", "absolute"] if type == "relative": gamma_ = _step_gamma(step, gamma) else: gamma_ = gamma # minimize entropy: return gamma_ * np.real(lambertw(np.exp((X - gamma_) / gamma_) / gamma_)) above = X > 0 X[above] = gamma_ * np.real(lambertw(np.exp(X[above] / gamma_ - 1) / gamma_)) return X class AlternatingProjections(object): """Combine several proximal operators in the form of Alternating Projections This implements the simple POCS method with several repeated executions of the projection sequence. Note: The operators are executed in the "natural" order, i.e. the first one in the list is applied last. """ def __init__(self, prox_list=None, repeat=1): self.operators = [] self.repeat = repeat if prox_list is not None: self.operators += prox_list def __call__(self, X, step): # simple POCS method, no Dykstra or averaging # TODO: no convergence test # NOTE: inline updates for r in range(self.repeat): # in reverse order (first one last, as expected from a sequence of ops) for prox in self.operators[::-1]: X = prox(X, step) return X def find(self, cls): import functools for i in range(len(self.operators)): prox = self.operators[i] if isinstance(prox, functools.partial): if prox.func is cls: return i else: if prox is cls: return i return -1