# -*- coding: utf-8 -*-
"""
Created on 2019/8/15
File radam
@author: ZhengYuwei
"""
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.keras import backend as K


class RAdam(tf.keras.optimizers.Optimizer):
    """RAdam optimizer.

    Default parameters follow those provided in the original paper.

    Arguments:
        lr: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 1.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
        decay: float >= 0. Learning rate decay over each update.
        amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
            from the paper "On the Convergence of Adam and Beyond".
        warmup_coef: float. In the early training stage RAdam falls back to SGD with
            momentum (SGDM); to warm up that SGDM phase, the warmup learning rate is
            set to warmup_lr = warmup_coef * lr. Defaults to 1.
    """

    def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None,
                 decay=0., amsgrad=False, warmup_coef=1., **kwargs):
        super(RAdam, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad
        self.warmup_coef = warmup_coef
        # Maximum length of the approximated SMA (rho_infinity in the RAdam paper).
        self.rho_inf = 2. / (1. - self.beta_2) - 1

    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = []

        lr = self.lr
        if self.initial_decay > 0:
            lr = lr * (  # pylint: disable=g-no-augmented-assignment
                1. / (1. + self.decay * math_ops.cast(self.iterations, K.dtype(self.decay))))

        with ops.control_dependencies([state_ops.assign_add(self.iterations, 1)]):
            t = math_ops.cast(self.iterations, K.floatx())

        # First and second moment accumulators (plus AMSGrad maxima if enabled).
        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        else:
            vhats = [K.zeros(1) for _ in params]
        self.weights = [self.iterations] + ms + vs + vhats

        beta_1_power = math_ops.pow(self.beta_1, t)
        beta_2_power = math_ops.pow(self.beta_2, t)
        # Length of the approximated SMA at step t (rho_t in the RAdam paper).
        rho_t = self.rho_inf - 2.0 * t * beta_2_power / (1.0 - beta_2_power)
        # If the variance is tractable (rho_t >= 5), use the rectified Adam step size;
        # otherwise fall back to warmed-up SGDM with bias-corrected momentum.
        lr_t = tf.where(rho_t >= 5.0,
                        K.sqrt((rho_t - 4.) * (rho_t - 2.) * self.rho_inf /
                               ((self.rho_inf - 4.) * (self.rho_inf - 2.) * rho_t))
                        * lr * (K.sqrt(1. - beta_2_power) / (1. - beta_1_power)),
                        self.warmup_coef * lr / (1. - beta_1_power))

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
            v_t = (self.beta_2 * v) + (1. - self.beta_2) * math_ops.square(g)
            if self.amsgrad:
                vhat_t = math_ops.maximum(vhat, v_t)
                p_t = p - lr_t * tf.where(rho_t >= 5.0,
                                          m_t / (K.sqrt(vhat_t) + self.epsilon), m_t)
                self.updates.append(state_ops.assign(vhat, vhat_t))
            else:
                p_t = p - lr_t * tf.where(rho_t >= 5.0,
                                          m_t / (K.sqrt(v_t) + self.epsilon), m_t)

            self.updates.append(state_ops.assign(m, m_t))
            self.updates.append(state_ops.assign(v, v_t))
            new_p = p_t

            # Apply constraints.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)

            self.updates.append(state_ops.assign(p, new_p))
        return self.updates

    def get_config(self):
        config = {
            'lr': float(K.get_value(self.lr)),
            'beta_1': float(K.get_value(self.beta_1)),
            'beta_2': float(K.get_value(self.beta_2)),
            'decay': float(K.get_value(self.decay)),
            'epsilon': self.epsilon,
            'amsgrad': self.amsgrad,
            'warmup_coef': self.warmup_coef
        }
        base_config = super(RAdam, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
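

# A minimal usage sketch, assuming a TF 1.x environment in which
# `tf.keras.optimizers.Optimizer` is still the legacy graph-mode Keras base class
# (this subclass implements the legacy `get_updates(loss, params)` interface).
# The toy model, input shape, and warmup_coef value below are illustrative choices only.
if __name__ == '__main__':
    import numpy as np

    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(16,)),
        tf.keras.layers.Dense(1)
    ])
    # During the first steps the rectification term is inactive, so the optimizer behaves
    # like SGDM with learning rate warmup_coef * lr before switching to rectified Adam.
    model.compile(optimizer=RAdam(lr=1e-3, warmup_coef=0.1), loss='mse')

    x = np.random.rand(64, 16).astype('float32')
    y = np.random.rand(64, 1).astype('float32')
    model.fit(x, y, batch_size=8, epochs=1, verbose=0)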