import math
import numpy as np
import torch
import copy

# eps for numerical stability
DEBUG = True
eps = 1e-6

if DEBUG:
  import logging
  logging.basicConfig(filename="./num.log", level=logging.DEBUG)


class YFOptimizer(object):
  def __init__(self, var_list, lr=0.1, mu=0.0, clip_thresh=None, weight_decay=0.0,
    beta=0.999, curv_win_width=20, zero_debias=True, sparsity_debias=True, delta_mu=0.0,
    auto_clip_fac=None, force_non_inc_step=False, lr_grad_norm_thresh=1.0,
    h_max_log_smooth=False, h_min_log_smooth=False, checkpoint_interval=500):
    '''
    clip_thresh is the threshold value on ||lr * gradient||.
    delta_mu can be a placeholder/variable/python scalar. It is used for additional
    momentum in situations such as asynchronous-parallel training. The default is 0.0
    for basic usage of the optimizer.
    Args:
      lr: python scalar. The initial value of the learning rate; we use 1.0 in our paper.
      mu: python scalar. The initial value of momentum; we use 0.0 in our paper.
      clip_thresh: python scalar. The manually-set threshold for gradient norm clipping
        (see torch.nn.utils.clip_grad_norm in step()). If None, automatic clipping can be
        carried out. The automatic clipping feature is parameterized by argument
        auto_clip_fac; it can be switched off with auto_clip_fac = None.
      beta: python scalar. The smoothing parameter for the estimations.
      sparsity_debias: gradient norm and curvature are biased toward larger values when
        calculated with sparse gradients. This is useful when the model is very sparse,
        e.g. an LSTM with word embeddings. For non-sparse CNNs, turning it off can
        slightly speed things up.
      delta_mu: for extensions. Not necessary in basic use.
      force_non_inc_step: in some very rare cases, it is necessary to force
        ||lr * gradient|| not to increase dramatically for stability after some
        iterations. In practice, if turned on, we enforce lr * sqrt(smoothed ||grad||^2)
        to be less than 2x the minimal historical value of smoothed ||lr * grad||.
        This feature is turned off by default.
    Other features:
      If you want to manually control the learning rate, self.lr_factor is an interface
      to the outside: it is a multiplier for the internal learning rate in YellowFin.
      It is helpful when you want to do additional hand tuning or apply a decaying
      scheme to the tuned learning rate in YellowFin. An example of using lr_factor
      can be found here:
      https://github.com/JianGoForIt/YellowFin_Pytorch/blob/master/pytorch-cifar/main.py#L109
    '''
    self._lr = lr
    self._mu = mu
    # we convert var_list from a generator to a list so that
    # it can be iterated over multiple times
    self._var_list = list(var_list)
    self._clip_thresh = clip_thresh
    self._auto_clip_fac = auto_clip_fac
    self._beta = beta
    self._curv_win_width = curv_win_width
    self._zero_debias = zero_debias
    self._sparsity_debias = sparsity_debias
    self._force_non_inc_step = force_non_inc_step
    self._optimizer = torch.optim.SGD(self._var_list, lr=self._lr,
      momentum=self._mu, weight_decay=weight_decay)
    self._iter = 0
    # global states are the statistics
    self._global_state = {}
    # for decaying learning rate, etc.
    self._lr_factor = 1.0
    # lr threshold
    self._lr_grad_norm_thresh = lr_grad_norm_thresh
    # smoothing options
    self._h_max_log_smooth = h_max_log_smooth
    self._h_min_log_smooth = h_min_log_smooth
    # checkpoint interval
    self._checkpoint_interval = checkpoint_interval
    if DEBUG:
      logging.debug('This message should go to the log file')

  def state_dict(self):
    # for checkpoint saving
    sgd_state_dict = self._optimizer.state_dict()
    # for recovering the model internally in case of numerical issues
    model_state_list = [p.data \
      for group in self._optimizer.param_groups for p in group['params'] ]
    global_state = self._global_state
    lr_factor = self._lr_factor
    iter = self._iter
    lr = self._lr
    mu = self._mu
    clip_thresh = self._clip_thresh
    beta = self._beta
    curv_win_width = self._curv_win_width
    zero_debias = self._zero_debias
    h_min = self._h_min
    h_max = self._h_max
    return {
      "sgd_state_dict": sgd_state_dict,
      "model_state_list": model_state_list,
      "global_state": global_state,
      "lr_factor": lr_factor,
      "iter": iter,
      "lr": lr,
      "mu": mu,
      "clip_thresh": clip_thresh,
      "beta": beta,
      "curv_win_width": curv_win_width,
      "zero_debias": zero_debias,
      "h_min": h_min,
      "h_max": h_max
    }

  def load_state_dict(self, state_dict):
    # for checkpoint loading
    self._optimizer.load_state_dict(state_dict['sgd_state_dict'])
    # for recovering the model internally if any numerical issue happens
    param_id = 0
    for group in self._optimizer.param_groups:
      for p in group["params"]:
        p.data = state_dict["model_state_list"][param_id]
        param_id += 1
    self._global_state = state_dict['global_state']
    self._lr_factor = state_dict['lr_factor']
    self._iter = state_dict['iter']
    self._lr = state_dict['lr']
    self._mu = state_dict['mu']
    self._clip_thresh = state_dict['clip_thresh']
    self._beta = state_dict['beta']
    self._curv_win_width = state_dict['curv_win_width']
    self._zero_debias = state_dict['zero_debias']
    self._h_min = state_dict["h_min"]
    self._h_max = state_dict["h_max"]
    return

  def set_lr_factor(self, factor):
    self._lr_factor = factor
    return

  def get_lr_factor(self):
    return self._lr_factor

  def zero_grad(self):
    self._optimizer.zero_grad()
    return

  def zero_debias_factor(self):
    return 1.0 - self._beta ** (self._iter + 1)

  def zero_debias_factor_delay(self, delay):
    # for an exponentially averaged stat which starts at a non-zero iter
    return 1.0 - self._beta ** (self._iter - delay + 1)

  def curvature_range(self):
    global_state = self._global_state
    if self._iter == 0:
      global_state["curv_win"] = torch.FloatTensor(self._curv_win_width, 1).zero_()
    curv_win = global_state["curv_win"]
    grad_norm_squared = self._global_state["grad_norm_squared"]
    # curv_win[self._iter % self._curv_win_width] = np.log(grad_norm_squared + eps)
    curv_win[self._iter % self._curv_win_width] = grad_norm_squared
    valid_end = min(self._curv_win_width, self._iter + 1)
    # we use a running average over log scale, accelerating
    # h_max / h_min in the beginning to follow the varying trend of curvature.
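    # h_min / h_max below are exponential moving averages of the smallest and
    # largest squared gradient norm observed over the last curv_win_width
    # iterations; they serve as the curvature-range estimates consumed by
    # get_lr() and get_mu(). When h_min_log_smooth / h_max_log_smooth are set,
    # the averaging is done in log scale and exponentiated when read out.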
    beta = self._beta
    if self._iter == 0:
      global_state["h_min_avg"] = 0.0
      global_state["h_max_avg"] = 0.0
      self._h_min = 0.0
      self._h_max = 0.0
    if self._h_min_log_smooth:
      global_state["h_min_avg"] = \
        global_state["h_min_avg"] * beta + (1 - beta) * torch.min(np.log(curv_win[:valid_end] + eps) )
    else:
      global_state["h_min_avg"] = \
        global_state["h_min_avg"] * beta + (1 - beta) * torch.min(curv_win[:valid_end] )
    if self._h_max_log_smooth:
      global_state["h_max_avg"] = \
        global_state["h_max_avg"] * beta + (1 - beta) * torch.max(np.log(curv_win[:valid_end] + eps) )
    else:
      global_state["h_max_avg"] = \
        global_state["h_max_avg"] * beta + (1 - beta) * torch.max(curv_win[:valid_end] )
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      if self._h_min_log_smooth:
        self._h_min = np.exp(global_state["h_min_avg"] / debias_factor)
      else:
        self._h_min = global_state["h_min_avg"] / debias_factor
      if self._h_max_log_smooth:
        self._h_max = np.exp(global_state["h_max_avg"] / debias_factor)
      else:
        self._h_max = global_state["h_max_avg"] / debias_factor
    else:
      if self._h_min_log_smooth:
        self._h_min = np.exp(global_state["h_min_avg"] )
      else:
        self._h_min = global_state["h_min_avg"]
      if self._h_max_log_smooth:
        self._h_max = np.exp(global_state["h_max_avg"] )
      else:
        self._h_max = global_state["h_max_avg"]
    if self._sparsity_debias:
      self._h_min *= self._sparsity_avg
      self._h_max *= self._sparsity_avg
    return

  def grad_variance(self):
    global_state = self._global_state
    beta = self._beta
    self._grad_var = np.array(0.0, dtype=np.float32)
    for group_id, group in enumerate(self._optimizer.param_groups):
      for p_id, p in enumerate(group['params'] ):
        if p.grad is None:
          continue
        grad = p.grad.data
        state = self._optimizer.state[p]
        if self._iter == 0:
          state["grad_avg"] = grad.new().resize_as_(grad).zero_()
          state["grad_avg_squared"] = 0.0
        state["grad_avg"].mul_(beta).add_(1 - beta, grad)
        self._grad_var += torch.sum(state["grad_avg"] * state["grad_avg"] )
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
    else:
      debias_factor = 1.0
    self._grad_var /= -(debias_factor**2)
    self._grad_var += global_state['grad_norm_squared_avg'] / debias_factor
    # guard against negative variance: the two terms use different debias factors
    self._grad_var = max(self._grad_var, eps)
    if self._sparsity_debias:
      self._grad_var *= self._sparsity_avg
    return

  def dist_to_opt(self):
    global_state = self._global_state
    beta = self._beta
    if self._iter == 0:
      global_state["grad_norm_avg"] = 0.0
      global_state["dist_to_opt_avg"] = 0.0
    global_state["grad_norm_avg"] = \
      global_state["grad_norm_avg"] * beta + (1 - beta) * math.sqrt(global_state["grad_norm_squared"] )
    global_state["dist_to_opt_avg"] = \
      global_state["dist_to_opt_avg"] * beta \
      + (1 - beta) * global_state["grad_norm_avg"] / (global_state['grad_norm_squared_avg'] + eps)
    if self._zero_debias:
      debias_factor = self.zero_debias_factor()
      self._dist_to_opt = global_state["dist_to_opt_avg"] / debias_factor
    else:
      self._dist_to_opt = global_state["dist_to_opt_avg"]
    if self._sparsity_debias:
      self._dist_to_opt /= (np.sqrt(self._sparsity_avg) + eps)
    return

  def grad_sparsity(self):
    global_state = self._global_state
    if self._iter == 0:
      global_state["sparsity_avg"] = 0.0
    non_zero_cnt = 0.0
    all_entry_cnt = 0.0
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        grad_non_zero = grad.nonzero()
        if grad_non_zero.dim() > 0:
          non_zero_cnt += grad_non_zero.size()[0]
        all_entry_cnt += torch.numel(grad)
    beta = self._beta
    global_state["sparsity_avg"] = beta * \
global_state["sparsity_avg"] \ + (1 - beta) * non_zero_cnt / float(all_entry_cnt) self._sparsity_avg = \ global_state["sparsity_avg"] / self.zero_debias_factor() if DEBUG: logging.debug("sparsity %f, sparsity avg %f", non_zero_cnt / float(all_entry_cnt), self._sparsity_avg) return def lr_grad_norm_avg(self): # this is for enforcing lr * grad_norm not # increasing dramatically in case of instability. # Not necessary for basic use. global_state = self._global_state beta = self._beta if "lr_grad_norm_avg" not in global_state: global_state['grad_norm_squared_avg_log'] = 0.0 global_state['grad_norm_squared_avg_log'] = \ global_state['grad_norm_squared_avg_log'] * beta \ + (1 - beta) * np.log(global_state['grad_norm_squared'] + eps) if "lr_grad_norm_avg" not in global_state: global_state["lr_grad_norm_avg"] = \ 0.0 * beta + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared'] ) + eps) # we monitor the minimal smoothed ||lr * grad|| global_state["lr_grad_norm_avg_min"] = \ np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor() ) else: global_state["lr_grad_norm_avg"] = global_state["lr_grad_norm_avg"] * beta \ + (1 - beta) * np.log(self._lr * np.sqrt(global_state['grad_norm_squared'] ) + eps) global_state["lr_grad_norm_avg_min"] = \ min(global_state["lr_grad_norm_avg_min"], np.exp(global_state["lr_grad_norm_avg"] / self.zero_debias_factor() ) ) def before_apply(self): # compute running average of gradient and norm of gradient beta = self._beta global_state = self._global_state if self._iter == 0: global_state["grad_norm_squared_avg"] = 0.0 global_state["grad_norm_squared"] = 0.0 for group_id, group in enumerate(self._optimizer.param_groups): for p_id, p in enumerate(group['params'] ): if p.grad is None: continue grad = p.grad.data param_grad_norm_squared = torch.sum(grad * grad) global_state['grad_norm_squared'] += param_grad_norm_squared if DEBUG: logging.debug("Iteration %f", self._iter) logging.debug("param grad squared gid %d, pid %d, %f, %f", group_id, p_id, param_grad_norm_squared, np.log(param_grad_norm_squared) / np.log(10) ) global_state['grad_norm_squared_avg'] = \ global_state['grad_norm_squared_avg'] * beta + (1 - beta) * global_state['grad_norm_squared'] if DEBUG: logging.debug("overall grad norm squared %f, %f", global_state['grad_norm_squared'], np.log(global_state['grad_norm_squared'] ) / np.log(10)) if self._sparsity_debias: self.grad_sparsity() self.curvature_range() self.grad_variance() self.dist_to_opt() if DEBUG: logging.debug("h_min %f, %f", self._h_min, np.log(self._h_min) ) logging.debug("h_max %f, %f", self._h_max, np.log(self._h_max) ) logging.debug("dist %f, %f", self._dist_to_opt, np.log(self._dist_to_opt) ) logging.debug("var %f, %f", self._grad_var, np.log(self._grad_var) ) if self._iter > 0: self.get_mu() self.get_lr() self._lr = beta * self._lr + (1 - beta) * self._lr_t self._mu = beta * self._mu + (1 - beta) * self._mu_t if DEBUG: logging.debug("lr_t %f", self._lr_t) logging.debug("mu_t %f", self._mu_t) logging.debug("lr %f", self._lr) logging.debug("mu %f", self._mu) return def get_lr(self): self._lr_t = (1.0 - math.sqrt(self._mu_t) )**2 / (self._h_min + eps) return def get_cubic_root(self): # We have the equation x^2 D^2 + (1-x)^4 * C / h_min^2 # where x = sqrt(mu). # We substitute x, which is sqrt(mu), with x = y + 1. # It gives y^3 + py = q # where p = (D^2 h_min^2)/(2*C) and q = -p. # We use the Vieta's substution to compute the root. # There is only one real solution y (which is in [0, 1] ). 
    # http://mathworld.wolfram.com/VietasSubstitution.html
    # eps in the numerator is to prevent momentum = 1 in case of zero gradient
    p = (self._dist_to_opt + eps)**2 * (self._h_min + eps)**2 / 2 / (self._grad_var + eps)
    w3 = (-math.sqrt(p**2 + 4.0 / 27.0 * p**3) - p) / 2.0
    w = math.copysign(1.0, w3) * math.pow(math.fabs(w3), 1.0 / 3.0)
    y = w - p / 3.0 / (w + eps)
    x = y + 1
    if DEBUG:
      logging.debug("p %f, den %f", p, self._grad_var + eps)
      logging.debug("w3 %f", w3)
      logging.debug("y %f, den %f", y, w + eps)
    return x

  def get_mu(self):
    root = self.get_cubic_root()
    dr = (self._h_max + eps) / (self._h_min + eps)
    self._mu_t = max(root**2, ( (np.sqrt(dr) - 1) / (np.sqrt(dr) + 1) )**2 )
    return

  def update_hyper_param(self):
    for group in self._optimizer.param_groups:
      group['momentum'] = self._mu
      if not self._force_non_inc_step:
        group['lr'] = min(self._lr * self._lr_factor,
          self._lr_grad_norm_thresh / (math.sqrt(self._global_state["grad_norm_squared"] ) + eps) )
      elif self._iter > self._curv_win_width:
        # force lr * grad_norm not to increase dramatically.
        # Not necessary for basic use. Please refer to the comments
        # in YFOptimizer.__init__ for more details.
        self.lr_grad_norm_avg()
        debias_factor = self.zero_debias_factor()
        group['lr'] = min(self._lr * self._lr_factor,
          2.0 * self._global_state["lr_grad_norm_avg_min"] \
          / (np.sqrt(np.exp(self._global_state['grad_norm_squared_avg_log'] / debias_factor) ) + eps) )
    return

  def auto_clip_thresh(self):
    # Heuristic to automatically prevent sudden exploding gradients.
    # Not necessary for basic use.
    return math.sqrt(self._h_max) * self._auto_clip_fac

  def step(self):
    # add weight decay
    for group in self._optimizer.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        grad = p.grad.data
        if group['weight_decay'] != 0:
          grad = grad.add(group['weight_decay'], p.data)
    if self._clip_thresh is not None:
      torch.nn.utils.clip_grad_norm(self._var_list, self._clip_thresh)
    elif self._iter != 0 and self._auto_clip_fac is not None:
      # do not clip on the first iteration
      torch.nn.utils.clip_grad_norm(self._var_list, self.auto_clip_thresh() )
    if True:  # try:
      # before apply
      self.before_apply()
      # update learning rate and momentum
      self.update_hyper_param()
      # apply update
      self._optimizer.step()
      # periodically save model and optimizer states
      if self._iter % self._checkpoint_interval == 1:
        self._state_checkpoint = copy.deepcopy(self.state_dict() )
      self._iter += 1
    # except:
    #   # load the last checkpoint
    #   logging.warning("Numerical issue triggered restore with backup states. "
    #     "Resumed from last internal checkpoint.")
    #   self.load_state_dict(self._state_checkpoint)
    return
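
# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): it only illustrates
# the intended call pattern of YFOptimizer in a training loop. The toy model,
# data, and loop length are hypothetical, and the sketch assumes a PyTorch
# version compatible with the legacy tensor API used above (e.g. the
# Tensor.add_(scalar, tensor) overload in grad_variance()).
# -----------------------------------------------------------------------------
if __name__ == "__main__":
  import torch.nn as nn

  # hypothetical toy model and data, just to exercise the optimizer
  model = nn.Linear(10, 1)
  optimizer = YFOptimizer(model.parameters(), lr=0.1, mu=0.0)
  loss_fn = nn.MSELoss()

  for step_id in range(5):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    # optional: external decay on top of the internally tuned learning rate
    optimizer.set_lr_factor(0.5 ** (step_id // 2))
    optimizer.step()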