import numpy as np
import math
import mxnet as mx


@mx.optimizer.Optimizer.register
class YFOptimizer(mx.optimizer.Optimizer):
    """The YF (YellowFin) optimizer, built upon the SGD optimizer with
    momentum and weight decay.

    The optimizer updates the weight by::

        state = momentum * state + lr * rescale_grad * clip(grad, clip_gradient) + wd * weight
        weight = weight - state

    For details of the update algorithm see :class:`~mxnet.ndarray.sgd_update`
    and :class:`~mxnet.ndarray.sgd_mom_update`.

    This optimizer accepts the following parameters in addition to those
    accepted by :class:`.Optimizer`.

    Parameters
    ----------
    momentum : float, optional
        The initial momentum value.
    beta : float, optional
        The smoothing parameter for the moving-average estimations.
    curv_win_width : int, optional
        Length of the sliding window used to estimate the curvature range.
    zero_debias : bool, optional
        Whether to zero-debias the exponential moving averages.
    """

    def __init__(self, momentum=0.0, beta=0.999, curv_win_width=20,
                 zero_debias=True, **kwargs):
        super(YFOptimizer, self).__init__(**kwargs)
        self.momentum = momentum
        self.beta = beta
        self.curv_win_width = curv_win_width
        self.zero_debias = zero_debias
        # The following are global states for the YF tuner.
        # 1. Gradient norm accumulated over all indices.
        self._grad_norm = None
        # 2. Gradient norm squared accumulated over all indices.
        self._grad_norm_squared = None
        # 3. State parameters updated for YF after each iteration:
        #    a. used in curvature estimation
        self._h_min = 0.0
        self._h_max = 0.0
        self._h_window = np.zeros(curv_win_width)
        #    b. used in grad_variance
        self._grad_var = None
        #    c. used in distance-to-optimum estimation
        self._grad_norm_avg = 0.0
        self._grad_norm_squared_avg = 0.0
        self._h_avg = 0.0
        self._dist_to_opt_avg = 0.0
        # For testing purposes only.
        self._test_res = []

    def create_state(self, index, weight):
        momentum = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)
        grad_avg = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)
        grad_avg_squared = mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)
        return momentum, grad_avg, grad_avg_squared

    def zero_debias_factor(self):
        # Bias-correction factor for the exponential moving averages.
        if not self.zero_debias:
            return 1.0
        return 1.0 - self.beta ** self.num_update

    def clear_grad_norm_info(self):
        # Reset the per-iteration accumulators once the tuner has consumed them.
        self._grad_norm_squared = None
        self._grad_var = None

    def update_grad_norm_and_var(self, index, grad, state):
        # Accumulate the statistics the tuner needs for this weight index.
        _, grad_avg, grad_avg_squared = state
        grad_avg[:] = self.beta * grad_avg + (1. - self.beta) * grad
        grad_avg_squared[:] = self.beta * grad_avg_squared + (1. - self.beta) * mx.nd.square(grad)

        grad_norm_squared = mx.nd.sum(grad * grad)
        if self._grad_norm_squared is None:
            self._grad_norm_squared = grad_norm_squared
        else:
            self._grad_norm_squared += grad_norm_squared

        if self._grad_var is None:
            self._grad_var = mx.nd.sum(grad_avg * grad_avg)
        else:
            self._grad_var += mx.nd.sum(grad_avg * grad_avg)

    def curvature_range(self):
        # Estimate h_min and h_max from a sliding window of squared gradient norms.
        curv_win = self._h_window
        beta = self.beta
        curv_win[(self.num_update - 1) % self.curv_win_width] = self._grad_norm_squared
        valid_end = min(self.curv_win_width, self.num_update)
        self._h_min = beta * self._h_min + (1 - beta) * curv_win[:valid_end].min()
        self._h_max = beta * self._h_max + (1 - beta) * curv_win[:valid_end].max()
        debias_factor = self.zero_debias_factor()
        return self._h_min / debias_factor, self._h_max / debias_factor

    def grad_variance(self):
        # Gradient variance C = E[||g||^2] - ||E[g]||^2, with zero-debiasing.
        debias_factor = self.zero_debias_factor()
        self._grad_var /= -(debias_factor ** 2)
        self._grad_var += self._grad_norm_squared_avg / debias_factor
        return self._grad_var

    def dist_to_opt(self):
        # Estimate the distance to the optimum from the smoothed gradient norms.
        beta = self.beta
        self._grad_norm_avg = beta * self._grad_norm_avg + \
            (1 - beta) * math.sqrt(self._grad_norm_squared)
        self._dist_to_opt_avg = beta * self._dist_to_opt_avg + \
            (1 - beta) * self._grad_norm_avg / self._grad_norm_squared_avg
        debias_factor = self.zero_debias_factor()
        return self._dist_to_opt_avg / debias_factor

    def single_step_mu_lr(self, C, D, h_min, h_max):
        # Solve the cubic for the momentum root in (0, 1), then derive the
        # single-step momentum and learning rate.
        coef = np.array([-1.0, 3.0, 0.0, 1.0])
        coef[2] = -(3 + D ** 2 * h_min ** 2 / 2 / C)
        roots = np.roots(coef)
        root = roots[np.logical_and(np.logical_and(np.real(roots) > 0.0,
                                                   np.real(roots) < 1.0),
                                    np.imag(roots) < 1e-5)]
        assert root.size == 1
        dr = h_max / h_min
        mu_t = max(np.real(root)[0] ** 2, ((np.sqrt(dr) - 1) / (np.sqrt(dr) + 1)) ** 2)
        lr_t = (1.0 - math.sqrt(mu_t)) ** 2 / h_min
        return mu_t, lr_t

    def after_apply(self):
        # Run the YF tuner once every weight has been updated in this iteration:
        # refresh the statistics and smooth the new momentum and learning rate.
        beta = self.beta
        self._grad_norm_squared = self._grad_norm_squared.asscalar()
        self._grad_norm_squared_avg = beta * self._grad_norm_squared_avg + \
            (1 - beta) * self._grad_norm_squared

        h_min, h_max = self.curvature_range()
        C = self.grad_variance().asscalar()
        D = self.dist_to_opt()
        if self.num_update > 1:
            mu_t, lr_t = self.single_step_mu_lr(C, D, h_min, h_max)
            self.momentum = beta * self.momentum + (1 - beta) * mu_t
            self.lr = beta * self.lr + (1 - beta) * lr_t

        self._test_res = [h_max, h_min, C, D, self.lr, self.momentum]

        self.clear_grad_norm_info()

    def is_end_iter(self):
        # True once every weight index has been updated for the current iteration.
        if (self.num_update == 1) and (len(self._index_update_count) == len(self.idx2name)):
            return True
        elif (self.num_update > 1) and \
                (np.min(list(self._index_update_count.values())) == self.num_update):
            return True
        else:
            return False

    def update(self, index, weight, grad, state):
        assert isinstance(weight, mx.nd.NDArray)
        assert isinstance(grad, mx.nd.NDArray)
        lr = self._get_lr(index)
        wd = self._get_wd(index)
        momentum = self.momentum
        self._update_count(index)

        kwargs = {'rescale_grad': self.rescale_grad}
        if self.momentum > 0:
            kwargs['momentum'] = momentum
        if self.clip_gradient:
            kwargs['clip_gradient'] = self.clip_gradient

        if state is not None:
            mx.optimizer.sgd_mom_update(weight, grad, state[0], out=weight,
                                        lr=lr, wd=wd, **kwargs)
            self.update_grad_norm_and_var(index, grad * self.rescale_grad, state)
            if self.is_end_iter():
                self.after_apply()
        else:
            mx.optimizer.sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)
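

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the optimizer). It
# assumes an MXNet 1.x NDArray API; the shapes and hyper-parameter values
# below are arbitrary placeholders, not recommendations. Because the class is
# registered with @mx.optimizer.Optimizer.register, it can be created by name.
if __name__ == '__main__':
    opt = mx.optimizer.create('YFOptimizer', learning_rate=0.1, momentum=0.0)

    weight = mx.nd.random.normal(shape=(4, 3))
    grad = mx.nd.random.normal(shape=(4, 3))
    state = opt.create_state(0, weight)

    # A single manual update step; in practice the optimizer would be driven
    # by Module.fit or gluon.Trainer, which call update() for every index.
    opt.update(0, weight, grad, state)
    print('updated weight norm:', mx.nd.norm(weight).asscalar())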