# coding=utf-8
# Copyright 2020 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mesh TensorFlow Optimizers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import gin

from mesh_tensorflow import layers
from mesh_tensorflow import ops_with_redefined_builtins as mtf

import tensorflow.compat.v1 as tf


def make_optimizer(hparams, lr):
  if hparams.optimizer == "SGD":
    return SgdOptimizer(lr)
  elif hparams.optimizer == "Adafactor":
    return adafactor_optimizer_from_hparams(hparams, lr)
  else:
    raise ValueError("Unknown Optimizer")


class Optimizer(object):
  """Base optimizer class.

  Constructors of subclasses must take `learning_rate` as an argument.
  """

  def apply_grads(self, grads, variables):
    """Apply gradients to variables.

    Call this function externally instead of apply_grad().  This causes the
    operations to be combined, which is necessary for stacking variables;
    see mtf.rewrite_stack_variables().

    Args:
      grads: a list of Tensors
      variables: a list of Variables

    Returns:
      a list of Operations
    """
    ops = []
    for grad, var in zip(grads, variables):
      ops.extend(self.apply_grad(grad, var))
    if not ops:
      return ops
    return variables[0].graph.combine_assignments(ops)

  def apply_grad(self, grad, var):
    """Update variable and accumulators.

    Args:
      grad: a Tensor
      var: a Variable

    Returns:
      a list of Operations
    """
    raise ValueError("apply_grad not implemented %s %s" % (grad, var))
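# Typical usage (illustrative sketch only; the `graph`, `loss`, and
# `optimizer` names are assumed to come from the caller's model-building
# code): gradients are computed with mtf.gradients() and then handed to
# apply_grads() together with the matching list of variables, e.g.
#
#   var_grads = mtf.gradients(
#       [loss], [v.outputs[0] for v in graph.trainable_variables])
#   update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)
#
# The returned update operations are lowered along with the rest of the graph.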
@gin.configurable
class SgdOptimizer(Optimizer):
  """Optimizer implementing SGD."""

  def __init__(self, learning_rate):
    self._lr = learning_rate

  @property
  def lr(self):
    return self._lr

  def apply_grad(self, grad, var):
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []
    # It is critical to use assign_sub instead of mtf.assign(var - ...)
    # for the case of bfloat16 activations, so as to avoid repeatedly
    # rounding the slice value, which results in poor quality.
    return [mtf.assign_sub(var, grad * self.lr)]


@gin.configurable
class MomentumOptimizer(Optimizer):
  """SGD with momentum."""

  def __init__(self, learning_rate, momentum):
    self._lr = learning_rate
    self._momentum = momentum

  @property
  def lr(self):
    return self._lr

  @property
  def momentum(self):
    return self._momentum

  def apply_grad(self, grad, var):
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []

    updates = []
    v = mtf.get_variable(
        var.mesh, var.name + "_momentum_v", var.shape,
        dtype=var.dtype, initializer=tf.zeros_initializer(), trainable=False)

    with tf.variable_scope(var.name + "/sgd_momentum"):
      updates.append(mtf.assign(v, grad * self.lr + v * self.momentum))
      updates.append(mtf.assign_sub(var, v))
    return updates


@gin.configurable
class AdamWeightDecayOptimizer(Optimizer):
  """A basic Adam optimizer that includes "correct" L2 weight decay."""

  def __init__(self,
               learning_rate,
               weight_decay_rate=0.0,
               beta_1=0.9,
               beta_2=0.999,
               epsilon=1e-6,
               exclude_from_weight_decay=None):
    """Constructs an AdamWeightDecayOptimizer."""

    self.learning_rate = learning_rate
    self.weight_decay_rate = weight_decay_rate
    self.beta_1 = beta_1
    self.beta_2 = beta_2
    self.epsilon = epsilon
    self.exclude_from_weight_decay = exclude_from_weight_decay

  def apply_grad(self, grad, var):
    """See base class."""
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []

    grad = mtf.to_float(grad)

    assignments = []

    m = mtf.get_variable(
        var.mesh, var.name + "/adam_m", var.shape,
        initializer=tf.zeros_initializer(), trainable=False)

    v = mtf.get_variable(
        var.mesh, var.name + "/adam_v", var.shape,
        initializer=tf.zeros_initializer(), trainable=False)

    # Standard Adam update.
    next_m = self.beta_1 * m + (1.0 - self.beta_1) * grad
    next_v = self.beta_2 * v + (1.0 - self.beta_2) * mtf.square(grad)

    update = next_m / (mtf.sqrt(next_v) + self.epsilon)

    # Just adding the square of the weights to the loss function is *not*
    # the correct way of using L2 regularization/weight decay with Adam,
    # since that will interact with the m and v parameters in strange ways.
    #
    # Instead we want to decay the weights in a manner that doesn't interact
    # with the m/v parameters.  This is equivalent to adding the square
    # of the weights to the loss with plain (non-momentum) SGD.
    if self._do_use_weight_decay(var.name):
      update += self.weight_decay_rate * var.value

    update_with_lr = self.learning_rate * update

    var_update = mtf.assign_sub(var, update_with_lr)

    assignments.extend(
        [var_update, mtf.assign(m, next_m), mtf.assign(v, next_v)])
    return assignments

  def _do_use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        if re.search(r, param_name) is not None:
          return False
    return True
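# AdamWeightDecayOptimizer update, written out for a single parameter w with
# gradient g (an illustrative restatement of apply_grad() above, not extra
# behavior; note that this variant applies no bias correction to m and v):
#
#   m <- beta_1 * m + (1 - beta_1) * g
#   v <- beta_2 * v + (1 - beta_2) * g**2
#   w <- w - learning_rate * (m / (sqrt(v) + epsilon)
#                             + weight_decay_rate * w)
#
# The decay term is scaled only by the learning rate and never by the
# adaptive 1 / sqrt(v) factor, which is what "correct" weight decay means in
# the class docstring.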
@gin.configurable
class AdafactorOptimizer(Optimizer):
  """Adafactor."""

  def __init__(self,
               multiply_by_parameter_scale=True,
               learning_rate=None,
               decay_rate=None,
               beta1=0.0,
               clipping_threshold=1.0,
               factored=True,
               epsilon1=1e-30,
               epsilon2=1e-3,
               min_dim_size_to_factor=128):
    """Construct a new Adafactor optimizer.

    See class comment.

    Args:
      multiply_by_parameter_scale: a boolean
      learning_rate: an optional Scalar.
      decay_rate: an optional Scalar.
      beta1: a float value between 0 and 1
      clipping_threshold: an optional float >= 1
      factored: a boolean - whether to use factored second-moment estimator
        for 2d variables
      epsilon1: Regularization constant for squared gradient.
      epsilon2: Regularization constant for parameter scale.
      min_dim_size_to_factor: only factor accumulator if two tensor dimensions
        are at least this size.
    """
    self._multiply_by_parameter_scale = multiply_by_parameter_scale
    if learning_rate is None:
      learning_rate = self._learning_rate_default(multiply_by_parameter_scale)
    self._learning_rate = learning_rate
    if decay_rate is None:
      decay_rate = self._decay_rate_default()
    self._decay_rate = decay_rate
    self._beta1 = beta1
    self._clipping_threshold = clipping_threshold
    self._factored = factored
    self._epsilon1 = epsilon1
    self._epsilon2 = epsilon2
    self._min_dim_size_to_factor = min_dim_size_to_factor

  def _factored_dims(self, shape):
    """Whether to use a factored second moment estimator.

    Based on the shape of the variable.

    If we factor the accumulator, then this function returns a list of two
    mtf.Dimensions to reduce over.  We always pick the two largest dimensions.
    If there are not two dimensions of size >= min_dim_size_to_factor, then we
    do not factor.

    Args:
      shape: a Shape

    Returns:
      either a list of 2 Dimensions or None
    """
    if not self._factored or shape.ndims < 2:
      return None
    sorted_dims = sorted(shape.dims, key=lambda d: -d.size)
    if sorted_dims[1].size < self._min_dim_size_to_factor:
      return None
    return sorted_dims[:2]

  def _parameter_scale(self, var):
    """Estimate the scale of the parameters from the current values.

    We include a minimum value of epsilon2 (default 1e-3) to give it a chance
    to escape 0 if it was zero-initialized.

    Instead of using the value, we could impute the scale from the shape, as
    initializers do.

    Args:
      var: a variable or Tensor.

    Returns:
      a Scalar
    """
    return mtf.maximum(reduce_rms(var), self._epsilon2)
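  # Sketch of the factored estimator used in apply_grad() below (an
  # illustrative numpy restatement, not part of the class): for a 2d variable,
  # the full running average of grad**2 is replaced by a per-row and a
  # per-column accumulator, and their outer product (rescaled by the row-mean)
  # serves as the second-moment estimate.
  #
  #   import numpy as np
  #   g2 = np.random.rand(4, 3)              # running mean of grad**2, [4, 3]
  #   vr = g2.mean(axis=0)                   # reduced over rows, shape [3]
  #   vc = g2.mean(axis=1)                   # reduced over columns, shape [4]
  #   v_hat = np.outer(vc, vr) / vr.mean()   # rank-1 reconstruction, [4, 3]
  #
  # When g2 is itself rank 1 the reconstruction is exact; in general it stores
  # O(rows + cols) values instead of O(rows * cols).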
  def apply_grad(self, grad, var):
    if grad is None:
      tf.logging.warning("Gradient is None for variable %s" % var.name)
      return []
    # create slots
    grad = mtf.to_float(grad)
    factored_dims = self._factored_dims(var.shape)
    if factored_dims:
      d0, d1 = factored_dims
      vr_shape = var.shape - d0
      vc_shape = var.shape - d1
      vr = mtf.get_variable(
          var.mesh, var.name + "_slot_vr", vr_shape,
          initializer=tf.zeros_initializer(), trainable=False)
      vc = mtf.get_variable(
          var.mesh, var.name + "_slot_vc", vc_shape,
          initializer=tf.zeros_initializer(), trainable=False)
    else:
      v = mtf.get_variable(
          var.mesh, var.name + "_slot_v", var.shape,
          initializer=tf.zeros_initializer(), trainable=False)
    if self._beta1:
      m = mtf.get_variable(
          var.mesh, var.name + "_slot_m", var.shape,
          initializer=tf.zeros_initializer(), trainable=False)

    with tf.variable_scope(var.name + "/adafactor"):
      grad_squared = mtf.square(grad) + self._epsilon1
      decay_rate = self._decay_rate
      old_val = mtf.to_float(var.value)
      if self._multiply_by_parameter_scale:
        update_scale = self._parameter_scale(old_val) * self._learning_rate
      else:
        update_scale = self._learning_rate
      mixing_rate = 1.0 - decay_rate
      updates = []
      if factored_dims:
        grad_squared_row_mean = mtf.reduce_mean(
            grad_squared, output_shape=vr_shape)
        grad_squared_col_mean = mtf.reduce_mean(
            grad_squared, output_shape=vc_shape)
        new_vr = vr * decay_rate + grad_squared_row_mean * mixing_rate
        new_vc = vc * decay_rate + grad_squared_col_mean * mixing_rate
        vr_update = mtf.assign(vr, new_vr)
        vc_update = mtf.assign(vc, new_vc)
        updates.extend([vr_update, vc_update])
        long_term_mean = mtf.reduce_mean(new_vr, reduced_dim=d1)
        r_factor = mtf.rsqrt(new_vr / long_term_mean)
        c_factor = mtf.rsqrt(new_vc)
        x = grad * r_factor * c_factor
      else:
        new_v = v * decay_rate + grad_squared * mixing_rate
        v_update = mtf.assign(v, new_v)
        updates.append(v_update)
        x = grad * mtf.rsqrt(new_v)
      if self._clipping_threshold is not None:
        clipping_denom = mtf.maximum(
            1.0, reduce_rms(x) / self._clipping_threshold)
        x /= clipping_denom
      subtrahend = x * update_scale
      if self._beta1:
        new_m = (m * tf.constant(self._beta1)
                 + subtrahend * tf.constant(1.0 - self._beta1))
        subtrahend = new_m
        updates.append(mtf.assign(m, new_m))
      # It is critical to use assign_sub instead of mtf.assign(var - subtrahend)
      # for the case of bfloat16 activations, so as to avoid repeatedly
      # rounding the slice value, which results in poor quality.
      var_update = mtf.assign_sub(var, subtrahend)
      updates.append(var_update)
      return updates

  def _decay_rate_default(self):
    return adafactor_decay_rate_pow(0.8)

  def _learning_rate_default(self, multiply_by_parameter_scale):
    learning_rate = tf.minimum(tf.math.rsqrt(step_num() + 1.0), 0.01)
    if (not multiply_by_parameter_scale
        and not layers.unit_scaling_convention()):
      learning_rate *= 0.05
    return learning_rate


def adafactor_decay_rate_adam(beta2):
  """Second-moment decay rate like Adam, subsuming the correction factor.

  Args:
    beta2: a float between 0 and 1

  Returns:
    a scalar
  """
  t = tf.cast(tf.train.get_or_create_global_step(), tf.float32) + 1.0
  decay = beta2 * (1.0 - tf.pow(beta2, t - 1.0)) / (1.0 - tf.pow(beta2, t))
  return decay


def adafactor_decay_rate_pow(exponent):
  """Second moment decay rate where memory-length grows as step_num^exponent.

  Args:
    exponent: a float between 0 and 1

  Returns:
    a scalar
  """
  return 1.0 - tf.pow((step_num() + 1.0), -exponent)


def step_num():
  return tf.cast(tf.train.get_or_create_global_step(), tf.float32)


def adafactor_optimizer_from_hparams(hparams, lr):
  """Create an Adafactor optimizer based on model hparams.

  Args:
    hparams: model hyperparameters
    lr: learning rate scalar.

  Returns:
    an AdafactorOptimizer

  Raises:
    ValueError: on illegal values
  """
  if hparams.optimizer_adafactor_decay_type == "Adam":
    decay_rate = adafactor_decay_rate_adam(
        hparams.optimizer_adafactor_beta2)
  elif hparams.optimizer_adafactor_decay_type == "pow":
    decay_rate = adafactor_decay_rate_pow(
        hparams.optimizer_adafactor_memory_exponent)
  else:
    raise ValueError("unknown optimizer_adafactor_decay_type")
  return AdafactorOptimizer(
      multiply_by_parameter_scale=(
          hparams.optimizer_adafactor_multiply_by_parameter_scale),
      learning_rate=lr,
      decay_rate=decay_rate,
      beta1=hparams.optimizer_adafactor_beta1,
      clipping_threshold=hparams.optimizer_adafactor_clipping_threshold,
      factored=hparams.optimizer_adafactor_factored)


def reduce_rms(x):
  return mtf.sqrt(mtf.reduce_mean(mtf.square(x)))
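

# Default schedules (illustrative values, rounded): when neither learning_rate
# nor decay_rate is given, AdafactorOptimizer uses
# decay_rate = adafactor_decay_rate_pow(0.8) = 1 - (step + 1) ** -0.8 and,
# with the default multiply_by_parameter_scale=True,
# learning_rate = min(1 / sqrt(step + 1), 0.01), e.g.
#
#   step          0        1        99       9999     39999
#   decay_rate    0.0      0.43     0.97     0.9994   0.9998
#   learning_rate 0.01     0.01     0.01     0.01     0.005
#
# With multiply_by_parameter_scale=False (and the unit-scaling convention
# disabled), the default learning rate is additionally multiplied by 0.05.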