# coding=utf-8
# Copyright 2020 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities related to using bfloat16 activations and/or parameters."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import function


def bfloat16_activations_var_getter(getter, *args, **kwargs):
  """A custom getter function for float32 parameters and bfloat16 activations.

  Args:
    getter: custom getter
    *args: arguments
    **kwargs: keyword arguments

  Returns:
    variables with the correct dtype.

  Raises:
    KeyError: if "dtype" is not provided as a kwarg.
  """
  requested_dtype = kwargs["dtype"]
  if requested_dtype == tf.bfloat16:
    kwargs["dtype"] = tf.float32
  var = getter(*args, **kwargs)
  # This if statement is needed to guard the cast, because batch norm
  # assigns directly to the return value of this custom getter. The cast
  # makes the return value not a variable so it cannot be assigned. Batch
  # norm variables are always in fp32 so this if statement is never
  # triggered for them.
  if var.dtype.base_dtype != requested_dtype:
    var = tf.cast(var, requested_dtype)
  return var


def float16_activations_var_getter(getter, *args, **kwargs):
  """A custom getter function for float32 parameters and float16 activations.

  This function ensures the following:
    1. All variables requested with type fp16 are stored as type fp32.
    2. All variables requested with type fp32 are returned as type fp16.

  See https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/
  #training_tensorflow for more information on this strategy.

  Args:
    getter: custom getter
    *args: arguments
    **kwargs: keyword arguments

  Returns:
    variables with the correct dtype.

  Raises:
    KeyError: if "dtype" is not provided as a kwarg.
  """
  requested_dtype = kwargs["dtype"]
  if requested_dtype == tf.float16:
    kwargs["dtype"] = tf.float32
  if requested_dtype == tf.float32:
    requested_dtype = tf.float16
  var = getter(*args, **kwargs)
  # This if statement is needed to guard the cast, because batch norm
  # assigns directly to the return value of this custom getter. The cast
  # makes the return value not a variable so it cannot be assigned. Batch
  # norm variables are always in fp32 so this if statement is never
  # triggered for them.
  if var.dtype.base_dtype != requested_dtype:
    var = tf.cast(var, requested_dtype)
  return var
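

# Illustrative sketch (not part of the original library): how one of the
# custom getters above is typically installed on a variable scope. The scope
# name, shapes, and helper name below are hypothetical.
def _example_float16_activations_matmul(x):
  """Sketch: matmul with float16 activations and float32 parameter storage."""
  with tf.variable_scope(
      "example_mixed_precision",  # hypothetical scope name
      custom_getter=float16_activations_var_getter):
    # The variable is requested as float16; the getter stores it as float32
    # and returns it cast back to float16 for the computation.
    w = tf.get_variable("w", [16, 32], dtype=tf.float16)
    return tf.matmul(tf.cast(x, tf.float16), w)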
def simulated_quantize(x, num_bits, noise):
  """Simulate quantization to num_bits bits, with externally-stored scale.

  num_bits is the number of bits used to store each value.
  noise is a float32 Tensor containing values in [0, 1).
  Each value in noise should take different values across
  different steps, approximating a uniform distribution over [0, 1).
  In the case of replicated TPU training, noise should be identical
  across replicas in order to keep the parameters identical across replicas.

  The natural choice for noise would be tf.random_uniform(),
  but this is not possible for TPU, since there is currently no way to seed
  the different cores to produce identical values across replicas. Instead we
  use noise_from_step_num() (see below).

  The quantization scheme is as follows:

  Compute the maximum absolute value by row (call this max_abs).
  Store this either in an auxiliary variable or in an extra column.

  Divide the parameters by (max_abs / (2^(num_bits-1)-1)).  This gives a
  float32 value in the range [-(2^(num_bits-1)-1), 2^(num_bits-1)-1].

  Unbiased randomized roundoff by adding noise and rounding down.

  This produces a signed integer with num_bits bits which can then be stored.

  Args:
    x: a float32 Tensor
    num_bits: an integer between 1 and 22
    noise: a float Tensor broadcastable to the shape of x.

  Returns:
    a float32 Tensor
  """
  shape = x.get_shape().as_list()
  if not (len(shape) >= 2 and shape[-1] > 1):
    return x
  max_abs = tf.reduce_max(tf.abs(x), -1, keepdims=True) + 1e-9
  max_int = 2 ** (num_bits - 1) - 1
  scale = max_abs / max_int
  x /= scale
  x = tf.floor(x + noise)
  # dequantize before storing (since this is a simulation)
  x *= scale
  return x


def noise_from_step_num():
  """Quantization noise equal to (phi * (step_num + 1)) mod 1.0.

  Not using random_uniform here due to a problem on TPU in that random seeds
  are not respected, which may cause the parameters on different replicas
  to go out-of-sync.

  Returns:
    a float32 scalar
  """
  step = tf.to_int32(tf.train.get_or_create_global_step()) + 1
  phi = ((5 ** 0.5) - 1) / 2
  # Naive computation tf.mod(phi * step, 1.0) in float32 would be disastrous
  # due to loss of precision when the step number gets large.
  # Computation in doubles does not work on TPU, so we use this complicated
  # alternative computation which does not suffer from these roundoff errors.
  ret = 0.0
  for i in range(30):
    ret += (((phi * (2 ** i)) % 1.0)  # double-precision computation in python
            * tf.to_float(tf.mod(step // (2 ** i), 2)))
  return tf.mod(ret, 1.0)
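

# Illustrative sketch (not part of the original library): combining
# simulated_quantize() with noise_from_step_num(). The helper name and the
# bit width below are hypothetical.
def _example_simulated_quantize(weights, num_bits=8):
  """Sketch: simulate num_bits-bit quantization of `weights` for this step."""
  # The same scalar noise is broadcast to every element and is identical on
  # every replica, which keeps replicated TPU parameters in sync.
  noise = noise_from_step_num()
  return simulated_quantize(weights, num_bits, noise)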
def _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2):
  """Round-off x to cand1 or to cand2 in an unbiased way.

  Cand1 and cand2 are the same shape as x.
  For every element of x, the corresponding elements of cand1 and cand2 should
  be the two closest bfloat16 values to x.  Order does not matter.
  cand1 and cand2 must differ from each other.

  Args:
    x: A float32 Tensor.
    noise: A Tensor broadcastable to the shape of x containing
      random uniform values in [0.0, 1.0].
    cand1: A bfloat16 Tensor the same shape as x.
    cand2: A bfloat16 Tensor the same shape as x.

  Returns:
    A bfloat16 Tensor.
  """
  cand1_f = tf.to_float(cand1)
  cand2_f = tf.to_float(cand2)
  step_size = cand2_f - cand1_f
  fpart = (x - cand1_f) / step_size
  ret = tf.where(tf.greater(fpart, noise), cand2, cand1)
  return ret


def _to_bfloat16_unbiased(x, noise):
  """Convert a float32 to a bfloat16 using randomized roundoff.

  Args:
    x: A float32 Tensor.
    noise: a float32 Tensor with values in [0, 1), broadcastable to tf.shape(x)

  Returns:
    A bfloat16 Tensor.
  """
  x_sign = tf.sign(x)
  # Make sure x is positive.  If it is zero, the two candidates are identical.
  x = x * x_sign + 1e-30
  cand1 = tf.to_bfloat16(x)
  cand1_f = tf.to_float(cand1)
  # This relies on the fact that for a positive bfloat16 b,
  # b * 1.005 gives you the next higher bfloat16 and b * 0.995 gives you the
  # next lower one.  Both 1.005 and 0.995 are ballpark estimates.
  cand2 = tf.to_bfloat16(
      tf.where(tf.greater(x, cand1_f), cand1_f * 1.005, cand1_f * 0.995))
  ret = _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2)
  return ret * tf.to_bfloat16(x_sign)


class ParameterEncoding(object):
  """Helper class for encoding weights as bfloat16.

  For now, the parameters are always stored (encoded) as bfloat16 and decoded
  to float32.  Confusingly, the custom getter then converts the float32 back
  to a bfloat16 to use as an activation, assuming that we use bfloat16 for
  activations.

  TODO(noam): Add options for activation dtype=float32, and for different
  storage dtypes.
  """

  def encode(self, x, noise):
    """Encode float32 to bfloat16.

    Args:
      x: a float32 Tensor
      noise: a float32 Tensor with values in [0, 1), broadcastable to shape(x)

    Returns:
      a bfloat16 Tensor
    """
    raise NotImplementedError("encode not implemented")

  def decode(self, x):
    """Decode bfloat16 to float32."""
    raise NotImplementedError("decode not implemented")

  def _decode_with_identity_gradient(self, x):
    # identity backprop through the decoder.
    # This means that the optimizer must call encode when updating weights.
    @function.Defun(python_grad_func=lambda op, dy: dy,
                    shape_func=lambda op: [op.inputs[0].get_shape()])
    def my_fn(x):
      return self.decode(x)
    return my_fn(x)

  def custom_getter(self, activation_dtype=tf.bfloat16):
    """A custom getter that uses the encoding for bfloat16 and float32 vars.

    When a bfloat16 or float32 variable is requested, an encoded bfloat16
    variable is created, which is then decoded and cast to a bfloat16
    activation.

    Args:
      activation_dtype: a dtype to which to convert the decoded value.

    Returns:
      a function.
    """
    def getter_fn(getter, *args, **kwargs):
      requested_dtype = kwargs["dtype"]
      if requested_dtype in (tf.bfloat16, tf.float32):
        kwargs["dtype"] = tf.bfloat16
        kwargs["initializer"] = _EncodingInitializer(
            kwargs["initializer"], self)
        ret = self._decode_with_identity_gradient(getter(*args, **kwargs))
        return tf.cast(ret, activation_dtype)
      return getter(*args, **kwargs)
    return getter_fn


class _EncodingInitializer(object):
  """Helper class for ParameterEncoding.

  Initializes variables by calling base initializer, then encoding.
  """

  def __init__(self, base_initializer, parameter_encoding):
    self._base_initializer = base_initializer
    self._parameter_encoding = parameter_encoding

  def __call__(self, shape, dtype, partition_info=None):
    if self._base_initializer is None:
      # mimic default initialization in tf.get_variable()
      if dtype.is_floating:
        ret = tf.glorot_uniform_initializer()(shape, dtype)
      else:
        ret = tf.zeros(shape, dtype)
    else:
      ret = self._base_initializer(shape, dtype, partition_info=partition_info)
    noise = 0.0  # no random noise in the initializer.
    return tf.cast(self._parameter_encoding.encode(ret, noise), dtype)


class EighthPowerEncoding(ParameterEncoding):
  """enc(x) = sign(x) * (abs(x)*128)^8.

  This provides less range and more resolution than bfloat16.
  The range of representable positive values is approximately [2^-23, 2^9].
  Resolution is 8x better than bfloat16.
  """

  def encode(self, x, noise):
    x = tf.to_float(x)
    # we can't use tf.pow(..., 8.0) because of a high-error approximation
    # on TPU.  Instead we square three times.
    x = tf.sign(x) * tf.square(tf.square(tf.square(tf.abs(x) * 128.0)))
    x = _to_bfloat16_unbiased(x, noise)
    return x

  def decode(self, x):
    x = tf.to_float(x)
    # we can't use tf.pow(..., 0.125) because of a high-error approximation
    # on TPU.  Instead we sqrt three times.
    return tf.sign(x) * (tf.sqrt(tf.sqrt(tf.sqrt(tf.abs(x)))) / 128.0)
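

# Illustrative sketch (not part of the original library): using
# EighthPowerEncoding's custom getter so that weights are stored
# eighth-power-encoded in bfloat16 while being consumed as bfloat16
# activations. The scope name, shapes, and helper name below are hypothetical.
def _example_encoded_parameter_matmul(x):
  """Sketch: matmul whose weight is stored via EighthPowerEncoding."""
  encoding = EighthPowerEncoding()
  with tf.variable_scope(
      "example_encoded",  # hypothetical scope name
      custom_getter=encoding.custom_getter(activation_dtype=tf.bfloat16)):
    # The variable is created as an encoded bfloat16; the getter decodes it
    # with an identity gradient and casts it to bfloat16 for the matmul.
    # Because the gradient bypasses decode(), the training loop is expected
    # to call encode() again when applying parameter updates.
    w = tf.get_variable("w", [16, 32], dtype=tf.bfloat16)
    return tf.matmul(tf.cast(x, tf.bfloat16), w)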