# coding=utf-8 # Copyright 2020 The Mesh TensorFlow Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """MTF implementation of Transformer sequence/seq2seq model. This implementation is meant to be extensible, allowing users to define their own custom layers. It is meant to eventually replace the existing mesh-tensorflow Transformer implementation in the Tensor2Tensor library. The interface is for the user to create a Unitransformer or Bitransformer object and then call its methods (call_simple, sample_autoregressive, etc.) The Unitransformer or Bitransformer is configured by creating a LayerStack object containing instances of TransformerLayer. Users can subclass TransformerLayer to create new types of layers. Supported so far: - autoregressive single-stack Transformer (e.g. a simple language model) - encoder-decoder models (requires two Transformers) - non-autoregressive single-stack Transformer (e.g. BERT) - fast autoregressive sampling with temperature - beam search - mixture of experts layer - local attention layer - shared embedding / shared embedding and softmax weights Not yet supported: TODO(noam) - compressed attention layer """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import json import math import gin import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf class TransformerLayer(object): """Abstract base class for transformer layers. The point of this object hierarchy is to make Transformer extensible. You can configure a Transformer with your own custom layers without changing the base library. Transformer layers should subclass TransformerLayer. In the constructor, the subclasses simply record their hyperparameters. Subclasses must implement a call() method, representing a call to that layer. The call method is passed an input tensor and a Context object. Variables should be created inside of the call(). Examples of subclasses can be found in transformer_layers.py. In addition to other global hyperparameters, the Context has a "mode", which might be tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL, or another special value. Autoregressive decoding uses the special modes "first_part" and "incremental". In "first_part" mode, the known first part of the sequence is passed through all layers so that they can create any necessary initial states. In "incremental" mode (which is called from the body of a while loop), the input consists of only one position. Layers with recurrent states must generate both output and the new states. """ def call(self, context, x): """Call the layer. Args: context: a Context x: an input Tensor Returns: y: a Tensor """ raise NotImplementedError("Not implemented") def to_json(self): return json.dumps(self, cls=json.JSONEncoder) def set_name(self, name): self._name = name @property def name(self): return getattr(self, "_name", None) class Context(object): """Extra information that layers need at call time. This structure is created by Unitransformer and is passed to the layers.
It contains information that may be necessary to some layers in some modes. In "first_part" and "incremental" modes, some layers modify the context by producing and consuming "states" and "constant_states". The "states" are loop variables that change at each decoding step. The "constant_states" are produced once in "first_part" mode and read in the iterative decoding step. """ def __init__(self, model, mesh, batch_dims, length_dim, variable_dtype, beam_dim=None, mode=tf.estimator.ModeKeys.TRAIN, position=None, position_is_default=False, sequence_id=None, subsequence_id=None, states=None, new_states=None, losses=None, initial_position=None, layer_outputs=None, encoder_output=None, encoder_sequence_id=None, constant_states=None, shared_params=None, encoder_layer_outputs=None, write_priority=None, read_priority=None, inputs=None, encoder_inputs=None, num_microbatches=1): """Create a context. Args: model: a pointer back at the unitransformer object mesh: a mtf.Mesh batch_dims: a list of mtf.Dimension length_dim: a mtf.Dimension variable_dtype: a mtf.VariableDType beam_dim: an optional mtf.Dimension (present in beam search) mode: either a tf.estimator.ModeKeys or one of the following: "first_part" "incremental" position: an optional Tensor - represents position in the sequence. Passing None means that the position should be considered to be the index in the Tensor (along length_dim). position_is_default: a boolean - is the position equal to mtf.range(mesh, length_dim, tf.int32). This allows a shortcut in embedding lookup, as we can just slice the embedding variable. sequence_id: an optional int32 Tensor aligned with position - used to separate out different sequences which have been concatenated to form a single training example. Also used to mark padding. Id 0 is used for padding, and different positive values are used for the different sequences. subsequence_id: an optional int32 Tensor - used to represent multiple targets corresponding to the same input. Should only be provided when being called as a decoder. If provided, then position should line up with this rather than sequence_id. The sequence_id will represent the groups of sub-targets corresponding to each input. states: an optional list of Tensors representing loop variables (consumed in "incremental" mode) new_states: an optional list of Tensors onto which to append the new values of loop variables. (produced in "first_part" and "incremental" modes) losses: an optional list of Tensors onto which to append losses initial_position: an optional Tensor ("first_part" mode) layer_outputs: an optional list onto which to append layer outputs encoder_output: an optional Tensor (output of the encoder stack) encoder_sequence_id: an optional int32 Tensor (similar to sequence_id) but aligned with the encoder output. constant_states: an optional list of structures produced during "first_part" mode and consumed during "incremental" mode. shared_params: an optional dictionary which can be populated by parameters that are shared between Transformers - e.g. between the encoder and decoder Unitransformers in a Bitransformer. 
encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer write_priority: an optional Tensor in self-attention, position a can see position b iff read_priority[a] >= write_priority[b] read_priority: an optional Tensor inputs: an optional int32 Tensor with the input token ids encoder_inputs: an optional int32 Tensor with the input token ids to the encoder half of the Bitransformer of which this Unitransformer is the decoder. num_microbatches: integer - greater than one if the step has been serialized into multiple microbatches to save memory. """ self.model = model self.mesh = mesh self.batch_dims = batch_dims self.length_dim = length_dim self.variable_dtype = variable_dtype self.beam_dim = beam_dim self.mode = mode self.position = position self.position_is_default = position_is_default self.sequence_id = sequence_id self.subsequence_id = subsequence_id self.states = states self.new_states = new_states self.losses = losses self.initial_position = initial_position self.layer_outputs = layer_outputs if self.layer_outputs is None: self.layer_outputs = [] self.encoder_output = encoder_output self.encoder_sequence_id = encoder_sequence_id self.constant_states = constant_states self.next_constant_state = 0 self.shared_params = shared_params or {} self.layer_index = 0 self.encoder_layer_outputs = encoder_layer_outputs # put values here to share them between layers self.cache = {} self.write_priority = write_priority self.read_priority = read_priority self.inputs = inputs self.encoder_inputs = encoder_inputs self.num_microbatches = num_microbatches @property def train(self): return self.mode == tf.estimator.ModeKeys.TRAIN @property def activation_dtype(self): return self.variable_dtype.activation_dtype @property def encoder_length_dim(self): ret, = [d for d in self.encoder_output.shape.dims if d.name == "memory_length"] return ret def get_states(self, n): """Get the next n recurrent states. Called by layers in "incremental" mode. Args: n: an integer Returns: a list of n Tensors """ return self.states[len(self.new_states):len(self.new_states) + n] def record_new_states(self, new_states): """Record the new values of recurrent states. Called by layers in "first_part" or "incremental" mode. Args: new_states: a list of Tensors """ self.new_states.extend(new_states) def record_constant_state(self, s): """Record state in "first_part" mode to be read in "incremental" mode. This is to record state that is computed once and does not change at every decoding step. Args: s: a structure """ self.constant_states.append(s) def get_constant_state(self): """Read state that was written in "first_part" mode. Returns: a structure """ ret = self.constant_states[self.next_constant_state] self.next_constant_state += 1 return ret @property def nonpadding(self): """Tensor with zeros in padding positions and ones elsewhere.""" if self.sequence_id is None: return None if self.sequence_id == 1: return 1 else: return mtf.cast( mtf.not_equal(self.sequence_id, 0), self.activation_dtype) def get_position(self): if self.position_is_default: return mtf.range(self.mesh, self.length_dim, tf.int32) else: return self.position @gin.configurable class LayerStack(TransformerLayer): """A stack of layers with residual connections and layer norms.""" def __init__(self, layers, sublayers_initial=None, sublayers_per_layer=None, sublayers_final=None, dropout_rate=None, norm_epsilon=None, recompute_grads=False): """Create a LayerStack. 
`layers` is a list of TransformerLayer objects representing the building blocks of the transformer model, e.g. transformer_layers.SelfAttention. In addition, there are a bunch of other transformations which occur around the layer body, and at the beginning and the end of the layer stack. We call these "sublayers". They are configurable with the `sublayers_initial`, `sublayers_per_layer`, and `sublayers_final` arguments, each of which takes a list of sublayer functions. Each of the sublayer functions has signature: x, layer_stack, context -> y where x is the input tensor and y is the output tensor. The default sublayers specified in defaults.gin are: transformer.LayerStack.sublayers_initial = [ @transformer.sublayer_dropout, ] transformer.LayerStack.sublayers_per_layer = [ @transformer.sublayer_rms_norm, @transformer.sublayer_call_layer, @transformer.sublayer_dropout, @transformer.sublayer_residual, ] transformer.LayerStack.sublayers_final = [ @transformer.sublayer_rms_norm, @transformer.sublayer_dropout, ] Refer to these as examples of how to write and call your own sublayer functions. `dropout_rate` and `norm_epsilon` should only be specified in legacy mode, for compatibility with older checkpoints. Args: layers: a list of TransformerLayer sublayers_initial: an optional list of sublayer functions sublayers_per_layer: an optional list of sublayer functions sublayers_final: an optional list of sublayer functions dropout_rate: DEPRECATED - a floating-point number norm_epsilon: DEPRECATED - a floating-point number recompute_grads: a boolean """ self._layers = layers self._recompute_grads = recompute_grads self._sublayers_initial = sublayers_initial self._sublayers_per_layer = sublayers_per_layer self._sublayers_final = sublayers_final if (dropout_rate is not None) != (norm_epsilon is not None): raise ValueError( "LayerStack.dropout_rate and LayerStack.norm_epsilon should either " "be both not None (legacy mode) or both None (normal mode)") if dropout_rate is not None: self._legacy_init(dropout_rate, norm_epsilon) def _legacy_init(self, dropout_rate, norm_epsilon): """Legacy initialization for use with old checkpoints. dropout_rate and norm_epsilon are specified in LayerStack. Custom sublayers are not specified.
Args: dropout_rate: a float norm_epsilon: a float """ self.dropout_rate = dropout_rate self.norm_epsilon = norm_epsilon if (self._sublayers_initial is not None or self._sublayers_per_layer is not None or self._sublayers_final is not None): tf.logging.warning("legacy mode - ignoring custom sublayers") self._sublayers_initial = [sublayer_legacy_dropout] self._sublayers_per_layer = [sublayer_legacy_rms_norm, sublayer_call_layer, sublayer_legacy_dropout, sublayer_residual] self._sublayers_final = [sublayer_legacy_final_rms_norm, sublayer_legacy_dropout] def call(self, context, x): """Call the layer stack.""" x = self._call_sublayers(self._sublayers_initial, x, context) context.layer_outputs.append(x) for lnum, layer in enumerate(self._layers): with tf.variable_scope(layer.name or ""): if self._recompute_grads: def fn(x, l=layer, c=context): return self._layer_fn(x, l, c) x = mtf.recompute_grad(fn, [x]) else: x = self._layer_fn(x, layer, context) if lnum != len(self._layers) - 1: context.layer_outputs.append(x) context.layer_index += 1 x = self._call_sublayers(self._sublayers_final, x, context) x = sublayer_mask_padding(x, self, context) context.layer_outputs.append(x) return x def _call_sublayers(self, sublayers, x, context): for s in sublayers: x = s(x, self, context) return x def _layer_fn(self, x, layer, context): """Call the layer and its associated sublayers. Args: x: a Tensor layer: a Layer context: a Context Returns: a Tensor """ context.current_layer = layer context.current_layer_input = x y = self._call_sublayers(self._sublayers_per_layer, x, context) if y.shape != x.shape: raise ValueError( "Layer %s returned misshaped output x=%s y=%s" % (layer.__class__.__name__, x, y)) return y @property def num_layers(self): return len(self.layers) @property def layers(self): return self._layers @gin.configurable def sublayer_call_layer(x, layer_stack, context): x = sublayer_mask_padding(x, layer_stack, context) layer = context.current_layer with tf.variable_scope(layer.__class__.__name__): return layer.call(context, x) @gin.configurable def sublayer_mask_padding(x, layer_stack, context): """Zero out padding regions. This "fixes" a bug where extreme values leak from the padding into the non-padding regions. TODO(noam): understand this better and make a more principled fix. Args: x: a Tensor layer_stack: ignored context: a Context Returns: a Tensor """ del layer_stack if isinstance(context.sequence_id, mtf.Tensor): return x * mtf.cast( mtf.not_equal(context.sequence_id, 0), context.activation_dtype) else: return x @gin.configurable def sublayer_rms_norm(x, layer_stack, context, epsilon=1e-6, name="rms_norm"): """RMS normalization.
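Computes y = x * rsqrt(mean(x**2) + epsilon) * scale, where the mean is taken over model_dim and scale is a learned weight of shape [model_dim] (see the body below).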
Args: x: an input mtf.Tensor layer_stack: a LayerStack context: a Context epsilon: a float name: a string Returns: a mtf.Tensor """ del layer_stack model_dim = context.model.model_dim with tf.variable_scope(name): scale = mtf.get_variable( context.mesh, "scale", mtf.Shape(context.model.ensemble_dims + [model_dim]), initializer=tf.ones_initializer(), dtype=context.variable_dtype) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=model_dim) return x * mtf.rsqrt(variance + epsilon) * scale @gin.configurable def sublayer_legacy_rms_norm(x, layer_stack, context): """Deprecated - keep for checkpoint/operative_config.gin compatibility.""" return sublayer_rms_norm(x, layer_stack, context, name="layer_norm") @gin.configurable def sublayer_legacy_final_rms_norm(x, layer_stack, context): """Deprecated - keep for checkpoint/operative_config.gin compatibility.""" return sublayer_rms_norm(x, layer_stack, context, name="final_layer_norm") @gin.configurable def sublayer_rms_norm_subsampled(x, layer_stack, context, percentage=100., epsilon=1e-6): """RMS normalization.""" del layer_stack model_dim = context.model.model_dim with tf.variable_scope("layer_norm_subsampled"): scale = mtf.get_variable( context.mesh, "scale", mtf.Shape(context.model.ensemble_dims + [model_dim]), initializer=tf.ones_initializer(), dtype=context.variable_dtype) var_dim = mtf.Dimension( model_dim.name, int(math.ceil(model_dim.size * percentage/100))) var_activations = mtf.slice(x, 0, var_dim.size, var_dim.name) variance = mtf.reduce_mean( mtf.square(var_activations), reduced_dim=var_dim) return x * mtf.rsqrt(variance + epsilon) * scale @gin.configurable def sublayer_residual(x, layer_stack, context): del layer_stack return x + context.current_layer_input @gin.configurable def sublayer_dropout(x, layer_stack, context, dropout_rate=0.0): del layer_stack if context.train and dropout_rate > 0: return mtf.dropout( x, rate=dropout_rate, noise_shape=mtf.Shape(context.batch_dims + [context.model.model_dim])) else: return x @gin.configurable def sublayer_legacy_dropout(x, layer_stack, context): return sublayer_dropout(x, layer_stack, context, dropout_rate=layer_stack.dropout_rate) @gin.configurable def sublayer_rezero(x, layer_stack, context, initial_value=0.0): """Multiply by zero-initialized scalar (residual not included).""" del layer_stack rezero_weight = mtf.get_variable( x.mesh, "rezero_weight", shape=context.model.ensemble_dims, dtype=context.variable_dtype, initializer=tf.constant_initializer(initial_value)) return x * rezero_weight @gin.configurable class ReversibleLayerStack(LayerStack): """A version of LayerStack that uses a revnet. This should be very memory-efficient if LayerStack.recompute_grads is set to True. Also, sublayers_per_layer should be overridden in gin, so as to remove the residual. "Reformer" https://arxiv.org/abs/2001.04451 uses something like this. 
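For example (an illustrative gin binding, not a tested configuration), the residual sublayer can be dropped like this: ReversibleLayerStack.sublayers_per_layer = [ @transformer.sublayer_rms_norm, @transformer.sublayer_call_layer, @transformer.sublayer_dropout, ] i.e. the default per-layer sublayers minus @transformer.sublayer_residual, since the reversible-half-residual wrapper applies the residual itself.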
""" def call(self, context, x): """Call the layer stack.""" x = self._call_sublayers(self._sublayers_initial, x, context) context.layer_outputs.append(x) x1, x1_backwards, x2, x2_backwards = x, None, x, None for lnum, layer in enumerate(self._layers): with tf.variable_scope(layer.name or ""): def fn(x, l=layer, c=context): return self._layer_fn(x, l, c) x1, x1_backwards, x2, x2_backwards = ( mtf.layers.reversible_half_residual_and_swap( x1, x1_backwards, x2, x2_backwards, fn, recompute_grads=self._recompute_grads)) if lnum != len(self._layers) - 1: context.layer_outputs.append(x) context.layer_index += 1 x = x1 + x2 x = self._call_sublayers(self._sublayers_final, x, context) context.layer_outputs.append(x) return x @gin.configurable def sublayer_true_layer_norm(x, layer_stack, context, epsilon=1e-6): """True (aka normal) Normalization.""" del layer_stack model_dim = context.model.model_dim with tf.variable_scope("true_layer_norm"): return mtf.layers.layer_norm(x, model_dim, epsilon) @gin.configurable class Unitransformer(object): """A Transformer model with only one layer stack, e.g. a language model. This class is also used as part of Bitransformer, which contains two Unitransformers. """ def __init__(self, layer_stack, d_model=1024, input_vocab_size=gin.REQUIRED, output_vocab_size=gin.REQUIRED, autoregressive=gin.REQUIRED, max_length=gin.REQUIRED, shared_embedding_and_softmax_weights=False, label_smoothing=0.0, z_loss=1e-4, name="transformer", layout=None, mesh_shape=None, vocab_divisor=128, ensemble=None, loss_fn=None, positional_embedding=True, sinusoid_positional_embedding=False, input_full_attention=False, loss_on_targets_only=False, loss_denominator=None, token_dropout_rate=0.0): """Create a Unitransformer. Args: layer_stack: a LayerStack d_model: an integer input_vocab_size: an integer output_vocab_size: an integer autoregressive: a boolean max_length: an integer shared_embedding_and_softmax_weights: a boolean label_smoothing: a float z_loss: a float name: a string layout: optional - an input to mtf.convert_to_layout_rules Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape mesh_shape: optional - an input to mtf.convert_to_shape Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape vocab_divisor: an integer ensemble: an optional integer (for creating an ensemble of models) loss_fn: an optional function to override self._compute_loss positional_embedding: a boolean sinusoid_positional_embedding: a boolean, whether to use the sinusoid positional embedding from the "Attention Is All You Need" paper. If True, this will override the positional_embedding setting. input_full_attention: a boolean This is an option for seq-to-seq as a language model. Each example consists of [<inputs>, EOS=1, <targets>, EOS=1]. In the self-attention layers, positions in the inputs portion of the sequence can see the entire inputs portion, while positions in the targets portion of the sequence cannot see future positions. loss_on_targets_only: a boolean This is an option for seq-to-seq as a language model. Each example consists of [<inputs>, EOS=1, <targets>, EOS=1]. We zero-out the loss for the inputs portion of the example. loss_denominator: an optional float. The default behavior is to compute the mean loss across all tokens in the batch, making the denomiator the size of the targets tensor (omitting ensemble dimensions). Passing a float here provides an alternative denomiator. 
One use case is that when fine-tuning a model using a much smaller batch size than the original training batch, one might want to use the same denominator as was used for the pretraining. This complication might be avoided by always using loss_denominator = 1.0. token_dropout_rate: an optional floating point value """ self.layer_stack = layer_stack self.model_dim = mtf.Dimension("d_model", d_model) self.input_vocab_size_unpadded = input_vocab_size self.input_vocab_dim = mtf.Dimension( "vocab", _round_up_to_multiple(input_vocab_size, vocab_divisor)) self.output_vocab_size_unpadded = output_vocab_size if output_vocab_size: self.output_vocab_dim = mtf.Dimension( "vocab", _round_up_to_multiple(output_vocab_size, vocab_divisor)) else: self.output_vocab_dim = None if autoregressive: raise ValueError("autoregressive Transformer needs output vocabulary") self.autoregressive = autoregressive if sinusoid_positional_embedding: positional_embedding = True if positional_embedding: self.max_length_dim = mtf.Dimension("max_length", max_length) else: self.max_length_dim = None self.shared_embedding_and_softmax_weights = ( shared_embedding_and_softmax_weights) self.label_smoothing = label_smoothing self.z_loss = z_loss self.name = name self.layout = layout self.mesh_shape = mesh_shape self.ensemble_dim = ( mtf.Dimension("ensemble", ensemble) if ensemble else None) self._loss_fn = loss_fn self.positional_embedding = positional_embedding self.sinusoid_positional_embedding = sinusoid_positional_embedding self.input_full_attention = input_full_attention self.loss_on_targets_only = loss_on_targets_only self._loss_denominator = loss_denominator if self.input_full_attention and not self.autoregressive: raise ValueError( "input_full_attention only makes sense with autoregressive") self.token_dropout_rate = token_dropout_rate @property def fully_autoregressive(self): return self.autoregressive and not self.input_full_attention @property def ensemble_dims(self): return [self.ensemble_dim] if self.ensemble_dim else [] def _compute_loss(self, context, logits, targets, output_vocab_dim): """Regular cross entropy loss. Args: context: a Context logits: a Tensor, the logits from the decoder targets: a Tensor output_vocab_dim: a Dimension Returns: A 0-dimensional tensor of the loss. """ # Use a custom loss function if one is injected. if self._loss_fn: return self._loss_fn(self, context, logits, targets, output_vocab_dim) off_value = self.label_smoothing / output_vocab_dim.size on_value = 1.0 - self.label_smoothing + off_value soft_targets = mtf.one_hot( targets, output_vocab_dim, dtype=context.activation_dtype, on_value=on_value, off_value=off_value) loss = mtf.layers.softmax_cross_entropy_with_logits( logits, soft_targets, output_vocab_dim, z_loss=self.z_loss if context.train else 0.0) weights = mtf.layers.weights_nonzero( targets, dtype=context.activation_dtype) if self.loss_on_targets_only: weights *= mtf.cast(mtf.logical_not(delimited_lm_inputs_mask(targets)), dtype=context.activation_dtype) return (mtf.reduce_sum(loss * weights) / self.loss_denominator(targets, context.num_microbatches)) def _call_internal(self, context, inputs, targets=None): """Compute logits based on inputs (all positions in parallel). Also updates context if applicable.
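In outline, the steps implemented below are: x = vocab_embedding.ids_to_embedding(inputs) (with optional token dropout when training); x += positional embedding (learned, shared, or sinusoidal, if enabled); x = self.layer_stack.call(context, x); logits = vocab_embedding.hidden_to_logits(x, context) if the embedding and softmax weights are shared, else a dense projection.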
Args: context: a Context inputs: a Tensor targets: an optional Tensor Returns: logits: a Tensor with shape [<batch_dims>, length_dim, output_vocab_dim] """ mesh = inputs.mesh if self.ensemble_dim and self.ensemble_dim not in inputs.shape.dims: # Training an ensemble where all models are trained on the same examples. inputs = mtf.broadcast(inputs, [self.ensemble_dim] + inputs.shape.dims) if targets: targets = mtf.broadcast( targets, [self.ensemble_dim] + targets.shape.dims) if "embedding" in context.shared_params: vocab_embedding = context.shared_params["embedding"] else: vocab_embedding = get_vocab_embedding_cls()( mesh, self.input_vocab_dim, self.model_dim, context.variable_dtype, name="embedding", ensemble_dim=self.ensemble_dim) if context.train: inputs = mtf.dropout(inputs, rate=self.token_dropout_rate) x = vocab_embedding.ids_to_embedding(inputs) if self.positional_embedding or self.sinusoid_positional_embedding: if self.sinusoid_positional_embedding: pos_emb_var = sinusoid_positional_embedding_weights( mesh, self.max_length_dim, self.model_dim, context.variable_dtype.activation_dtype) elif "positional_embedding" in context.shared_params: pos_emb_var = context.shared_params["positional_embedding"] else: pos_emb_var = mtf.layers.embedding_weights( mesh, self.max_length_dim, self.model_dim, context.variable_dtype, "positional_embedding", ensemble_dim=self.ensemble_dim) if (context.length_dim is not None and context.length_dim.size > self.max_length_dim.size): message = ( "Length dimension exceeds size of positional embedding table. " "length_dim.size > max_length_dim.size %s vs %s." % (context.length_dim, self.max_length_dim)) if context.position_is_default: # Definitely getting overflow in this case. raise ValueError(message) else: tf.logging.warning( message + " This may be OK if there are several shorter sequences packed " "together. Otherwise, the later positions will get zeros.") if context.position_is_default: pos_emb = mtf.rename_dimension( mtf.slice(pos_emb_var, 0, context.length_dim.size, self.max_length_dim.name), self.max_length_dim.name, context.length_dim.name) else: pos_emb = mtf.gather( pos_emb_var, context.position, self.max_length_dim, output_shape=x.shape) x += pos_emb x = self.layer_stack.call(context, x) if self.output_vocab_dim is None: return x if self.shared_embedding_and_softmax_weights: logits = vocab_embedding.hidden_to_logits(x, context) else: logits = mtf.layers.dense( x, self.output_vocab_dim, use_bias=False, variable_dtype=context.variable_dtype, reduced_dims=x.shape.dims[-1:], name="logits") if targets is not None and context.losses is not None: context.losses.append( self._compute_loss(context, logits, targets, self.output_vocab_dim)) if self.ensemble_dim: logits = reduce_ensemble_logits( logits, self.ensemble_dim, self.output_vocab_dim) return logits def loss_denominator(self, targets, num_microbatches): """Denominator applied to losses. This is usually the size of the targets tensor (omitting ensemble dimensions). Alternatively, it is an override value passed to the class constructor. Args: targets: a mtf.Tensor num_microbatches: an integer - greater than one if the step has been serialized into multiple microbatches to save memory.
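For example, with the default behavior (no constructor override), a targets tensor of shape [batch=8, length=128] and num_microbatches=1 give a denominator of 8 * 128 = 1024.0.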
Returns: a float """ if self._loss_denominator is not None: return float(self._loss_denominator) else: ret = float(targets.shape.size) * num_microbatches if self.ensemble_dim: # The ensembling should not decrease the gradient to each model ret /= self.ensemble_dim.size return float(ret) def call_simple(self, inputs, targets, compute_loss, mode=tf.estimator.ModeKeys.TRAIN, variable_dtype=mtf.VariableDType(tf.float32), sequence_id=None, subsequence_id=None, position=None, encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, layer_outputs=None, encoder_layer_outputs=None, num_microbatches=1): """Compute logits based on inputs (all positions in parallel). This is called during training and evaluation. Args: inputs: an int32 Tensor with shape [<batch_dims>, length_dim] For training autoregressive models this should be equal to mtf.shift(targets, offset=1, dim=length_dim, wrap=False) targets: an optional int32 Tensor with shape [<batch_dims>, length_dim] compute_loss: a boolean mode: a tf.estimator.ModeKeys variable_dtype: a mtf.VariableDType sequence_id: an optional Tensor subsequence_id: an optional Tensor position: an optional Tensor encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary layer_outputs: an optional list to append Tensor layer activations to encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer num_microbatches: integer - greater than one if the step has been serialized into multiple microbatches to save memory. Returns: logits: a Tensor with shape [<batch_dims>, output_vocab_dim] loss: an optional Scalar (if compute_loss=True) """ batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] length_range = mtf.range(inputs.mesh, length_dim, dtype=tf.int32) if not self.positional_embedding: # To make relative attention faster, we drop the information about the # position in the subsequence. The relative attention code then # assumes that the positions are given by index in the tensor, # which still leads to the correct computation of relative position. position = None if position is None: position_is_default = True position = length_range else: position_is_default = False if self.input_full_attention: # The inputs part of each sequence can fully attend within itself. full_attention_region = delimited_lm_inputs_mask(targets) # We can include one additional position to the right - the position # where the final EOS of the inputs is read and the first target token # is predicted. full_attention_region = mtf.logical_or( full_attention_region, mtf.shift(full_attention_region, offset=1, dim=length_dim, wrap=False) ) # We set read_priority and write_priority to 0 in the full-attention # region and equal to the position elsewhere. read_priority = write_priority = length_range * mtf.cast( mtf.logical_not(full_attention_region), tf.int32) elif self.autoregressive: # Vanilla autoregressive model - each position can see previous positions. 
read_priority = write_priority = length_range else: read_priority = write_priority = None context = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode=mode, losses=[] if compute_loss else None, sequence_id=sequence_id, subsequence_id=subsequence_id, position=position, position_is_default=position_is_default, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, shared_params=shared_params, layer_outputs=layer_outputs, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs, num_microbatches=num_microbatches) with tf.variable_scope(self.name): logits = self._call_internal(context, inputs, targets) if compute_loss: loss = mtf.add_n(context.losses) else: loss = None return logits, loss @gin.configurable(module="Unitransformer") def sample_autoregressive(self, partial_sequences, stop_at_token=1, max_steps=None, temperature=0.0, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, shared_params=None, has_partial_sequences=True, encoder_layer_outputs=None, never_end=False, remove_partial_sequences=False, sampling_keep_top_k=-1, bos_id=0): """Sample randomly one token at a time. The partial_sequences represent partial sequences to be continued. The first tokens of each sequence are nonzero representing the given partial sequences and the last tokens of each sequence are zeros, representing what needs to be filled in. If there are no partial sequences (you want to sample from the beginning), then pass partial_sequences=mtf.zeros(mesh, shape, dtype=tf.int32) and has_partial_sequences=False (so we can skip computation). Args: partial_sequences: an int32 Tensor with shape [<batch_dims>, length_dim] stop_at_token: an optional integer eos id. Stop when we produce it. max_steps: an optional integer, the max number of steps to decode. temperature: an optional floating point value between 0.0 and 1.0 0.0 means argmax, 1.0 means sample according to predicted distribution. variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor shared_params: an optional dictionary has_partial_sequences: a boolean encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer never_end: a boolean - if set, then avoid generating stop_at_token remove_partial_sequences: a boolean - whether to remove the partial sequences from the output sampling_keep_top_k: an integer - if not -1, only sample from the top k logits. 
bos_id: beginning of sequence id Returns: a Tensor with shape [<batch_dims>, length_dim] """ if not self.autoregressive: raise ValueError("must be autoregressive") inputs = partial_sequences batch_dims = inputs.shape.dims[:-1] length_dim = inputs.shape.dims[-1] initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None length_range = mtf.range(inputs.mesh, length_dim, tf.int32) if self.input_full_attention: read_priority = write_priority = length_range * mtf.to_int32( mtf.greater(length_range, initial_position)) else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs) del logits constant_states = context_first_part.constant_states if not has_partial_sequences: initial_states = [ mtf.zeros_like(t) for t in context_first_part.new_states] partial_sequences_eos_count = 0 else: initial_states = context_first_part.new_states partial_sequences_eos_count = mtf.reduce_sum( mtf.to_int32(mtf.equal(partial_sequences, stop_at_token)), reduced_dim=length_dim) def cond_fn(position, ids, *unused_states): """Should we run another loop iteration.""" past_end = mtf.greater_equal(position, length_dim.size) if max_steps: past_end = mtf.logical_or( past_end, mtf.greater_equal(position - initial_position, max_steps)) is_done = past_end if stop_at_token is not None: eos_count = mtf.reduce_sum( mtf.to_int32(mtf.equal(ids, stop_at_token)), reduced_dim=length_dim) has_additional_eos = mtf.greater(eos_count, partial_sequences_eos_count) is_done = mtf.logical_or(is_done, has_additional_eos) all_done = mtf.reduce_all(is_done) return mtf.logical_not(all_done) def body_fn(position, ids, *states): """One step in the decode loop.""" inputs_this_step = mtf.gather(ids, position - 1, length_dim) # Setting proper bos_id for position == 0. No-op otherwise. 
if bos_id: inputs_this_step += bos_id * mtf.ones_like(inputs_this_step) * mtf.cast( mtf.equal(position, 0), tf.int32) context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims, length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=position, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=position, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step) if never_end: logits += mtf.one_hot( mtf.constant(logits.mesh, stop_at_token, dtype=tf.int32), self.output_vocab_dim, on_value=-1e9, off_value=0.0, dtype=logits.dtype) # TBD whether this should be before or after never_end: # Note for adding top_p sampling in the future, in other code bases, the # option to apply temperature is done before the top-k truncation. This # implementation does this in the opposite order. For top-k this doesn't # matter, but for top_p it will. if sampling_keep_top_k != -1: if sampling_keep_top_k <= 0: raise ValueError("sampling_keep_top_k must either be -1 or positive.") k_largest = mtf.nth_largest_element( logits, n=sampling_keep_top_k, reduced_dim=self.output_vocab_dim) logits = mtf.where(mtf.less_equal(logits, k_largest), mtf.ones_like(logits)*-1e6, logits) ids_this_step = mtf.sample_with_temperature( logits, self.output_vocab_dim, temperature) new_position = position + 1 new_ids = ids + ids_this_step * mtf.one_hot( position, length_dim, dtype=tf.int32) return [new_position, new_ids] + context_incremental.new_states while_loop_inputs = [initial_position, inputs] + initial_states final_position, outputs = mtf.while_loop( cond_fn, body_fn, while_loop_inputs)[:2] del final_position if has_partial_sequences and remove_partial_sequences: # remove partial sequences from outputs partial_length = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(partial_sequences, 0)), reduced_dim=length_dim) outputs = mtf.dynamic_shift( outputs, -partial_length, length_dim, wrap=False) return outputs def beam_search(self, inputs, decode_length, variable_dtype=mtf.VariableDType(tf.float32), encoder_output=None, encoder_sequence_id=None, encoder_inputs=None, alpha=0.6, shared_params=None, encoder_layer_outputs=None, bos_id=0): """Beam search. Args: inputs: an int32 zero-Tensor with shape [<batch_dims>, beam_dim, length_dim]. decode_length: an int32 mtf scalar. Maximum decode length. 
variable_dtype: a mtf.VariableDType encoder_output: an optional Tensor encoder_sequence_id: an optional Tensor encoder_inputs: an optional Tensor alpha: a floating point value (length bonus) shared_params: an optional dictionary encoder_layer_outputs: optional - readonly list of tensor activations when decoding, one per each input layer + the embedding layer bos_id: beginning of sequence id Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ if not self.autoregressive: raise ValueError("must be autoregressive") batch_dims = inputs.shape.dims[:-2] if len(batch_dims) != 1: raise NotImplementedError( "beam search supports exactly one batch dimension.") beam_dim = inputs.shape.dims[-2] length_dim = inputs.shape.dims[-1] length_range = mtf.range(inputs.mesh, length_dim, tf.int32) initial_position = mtf.reduce_sum( mtf.to_int32(mtf.not_equal(inputs, 0)), reduced_dim=length_dim) sequence_id = 1 if encoder_sequence_id is not None else None if self.input_full_attention: # This only makes sense in the case of beam search with given partial # sequences, which is not yet implemented. # TODO(noam): implement raise NotImplementedError( "Beam search for language models not yet implemented") else: read_priority = write_priority = length_range context_first_part = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="first_part", position=length_range, position_is_default=True, new_states=[], initial_position=initial_position, sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=[], shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=read_priority, inputs=inputs, encoder_inputs=encoder_inputs) shifted_inputs = mtf.shift(inputs, offset=1, dim=length_dim, wrap=False) with tf.variable_scope(self.name): logits = self._call_internal(context_first_part, shifted_inputs) del logits # There are no partial targets. # Replace initial states by zeros to avoid computing them. initial_states = [mtf.zeros_like(t) for t in context_first_part.new_states] constant_states = context_first_part.constant_states def logits_fn(step_num, ids, states): """logits_fn for mtf.beam_search.beam_search().""" inputs_this_step = mtf.gather(ids, step_num - 1, length_dim) # Setting proper bos_id for step_num == 0. No-op otherwise. 
if bos_id: inputs_this_step += bos_id * mtf.ones_like(inputs_this_step) * mtf.cast( mtf.equal(step_num, 0), tf.int32) context_incremental = Context( model=self, mesh=inputs.mesh, batch_dims=batch_dims + [beam_dim], length_dim=length_dim, variable_dtype=variable_dtype, mode="incremental", position=step_num, states=states, new_states=[], sequence_id=sequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, constant_states=constant_states, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, write_priority=write_priority, read_priority=step_num, inputs=inputs_this_step, encoder_inputs=encoder_inputs) with tf.variable_scope(self.name, reuse=True): logits = self._call_internal(context_incremental, inputs_this_step) return mtf.to_float(logits), context_incremental.new_states beams, unused_scores = mtf.beam_search.beam_search( logits_fn, inputs, alpha, states=initial_states, decode_length=decode_length, use_tpu=True, dtype=tf.float32, mesh_shape=self.mesh_shape, layout=self.layout) return mtf.gather( beams, mtf.constant(inputs.mesh, 0, dtype=tf.int32), beam_dim) @gin.configurable def shift_targets(targets, bos_id=0, eos_id=1): """Transforms decoder labels to decoder inputs. Args: targets: decoder labels bos_id: begin of sequence id, defaults to 0 eos_id: end of sequence id, defaults to 1 Returns: Decoder inputs. """ length_dim = targets.shape.dims[-1] shifted_targets = mtf.shift(targets, offset=1, dim=length_dim, wrap=False) # We should have a 0 at the beginning of each sequence rather than the # shifted EOS (e.g. 1) from the previous sequence. shifted_targets *= mtf.to_int32(mtf.not_equal(shifted_targets, eos_id)) if bos_id: shifted_targets += mtf.to_int32( mtf.logical_and( mtf.equal(shifted_targets, 0), mtf.not_equal(targets, 0))) * bos_id return shifted_targets @gin.configurable class Bitransformer(object): """A Transformer sequence-to-sequence model with two layer stacks.""" def __init__(self, encoder, decoder, shared_embedding=True): """Create a Bitransformer. Args: encoder: a mtf.unitransformer decoder: a mtf.unitransformer shared_embedding: a boolean """ self.encoder = encoder self.decoder = decoder self.shared_embedding = shared_embedding @property def output_vocab_dim(self): return self.decoder.output_vocab_dim def loss_denominator(self, targets, num_microbatches): return self.decoder.loss_denominator(targets, num_microbatches) @property def z_loss(self): return self.decoder.z_loss def _shared_params(self, mesh, variable_dtype): """Create parameters that are shared between encoder and decoder. 
Args: mesh: a Mesh variable_dtype: a VariableDType Returns: a dictionary """ shared_params = {} if self.shared_embedding: with tf.variable_scope("shared"): if not (self.encoder.model_dim == self.decoder.model_dim and self.encoder.input_vocab_dim == self.decoder.input_vocab_dim): raise ValueError( "shared_embedding requires encoder and decoder to have identical" " d_model and vocabulary sizes") shared_params["embedding"] = get_vocab_embedding_cls()( mesh, self.encoder.input_vocab_dim, self.encoder.model_dim, variable_dtype, name="embedding", ensemble_dim=self.encoder.ensemble_dim) if (self.encoder.positional_embedding and self.decoder.positional_embedding and self.encoder.max_length_dim == self.decoder.max_length_dim): if (self.encoder.sinusoid_positional_embedding and self.decoder.sinusoid_positional_embedding): pos_emb_var = sinusoid_positional_embedding_weights( mesh, self.encoder.max_length_dim, self.encoder.model_dim, variable_dtype.activation_dtype) else: pos_emb_var = mtf.layers.embedding_weights( mesh, self.encoder.max_length_dim, self.encoder.model_dim, variable_dtype, "positional_embedding", ensemble_dim=self.encoder.ensemble_dim) shared_params["positional_embedding"] = pos_emb_var return shared_params def call_simple(self, inputs, targets, compute_loss, mode=tf.estimator.ModeKeys.TRAIN, variable_dtype=mtf.VariableDType(tf.float32), encoder_sequence_id=None, decoder_sequence_id=None, decoder_subsequence_id=None, encoder_position=None, decoder_position=None, num_microbatches=1): """Compute logits based on inputs (all positions in parallel). This is called during training and evaluation. Args: inputs: an int32 Tensor with shape [<batch_dims>, length_dim] targets: an optional int32 Tensor with shape [<batch_dims>, length_dim] compute_loss: a boolean mode: a tf.estimator.ModeKeys variable_dtype: a mtf.VariableDType encoder_sequence_id: an optional Tensor decoder_sequence_id: an optional Tensor decoder_subsequence_id: an optional Tensor encoder_position: an optional Tensor decoder_position: an optional Tensor num_microbatches: integer - greater than one if the step has been serialized into multiple microbatches to save memory. Returns: logits: a Tensor with shape [<batch_dims>, output_vocab_dim] loss: an optional Scalar (if compute_loss=True) """ # encoder_sequence_id and decoder_sequence_id are used to delineate packed # examples but are also necessary to indicate padding where sequence_id==0. # If they are absent, then we assume that padding is indicated by zeros in # the inputs/targets, and we make up sequence_id tensors to indicate this.
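# (mtf.minimum(ids, 1) below maps the padding id 0 to 0 and any real token id to 1, which is exactly that convention.)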
if encoder_sequence_id is None: encoder_sequence_id = mtf.minimum(inputs, 1) if decoder_sequence_id is None: decoder_sequence_id = mtf.minimum(targets, 1) encoder_layer_outputs = [] shared_params = self._shared_params(inputs.mesh, variable_dtype) encoder_output, encoder_loss = self.encoder.call_simple( inputs, None, compute_loss, mode=mode, variable_dtype=variable_dtype, sequence_id=encoder_sequence_id, position=encoder_position, shared_params=shared_params, layer_outputs=encoder_layer_outputs, num_microbatches=num_microbatches) encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output) if encoder_sequence_id is not None: encoder_sequence_id = mtf.layers.rename_length_to_memory_length( encoder_sequence_id) logits, loss = self.decoder.call_simple( shift_targets(targets), targets, compute_loss, mode=mode, variable_dtype=variable_dtype, sequence_id=decoder_sequence_id, subsequence_id=decoder_subsequence_id, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs), position=decoder_position, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs, num_microbatches=num_microbatches) if loss is not None and encoder_loss is not None: loss += encoder_loss return logits, loss @gin.configurable(module="Bitransformer") def decode(self, inputs, variable_dtype=mtf.VariableDType(tf.float32), beam_size=1, alpha=0.6, temperature=0.0, decode_length_multiplier=1.5, decode_length_constant=10, max_decode_length=None): """Sampling or beam search. TODO(noam): should we make the output length dimension different from the input length dimension? Args: inputs: a Tensor with shape [<batch_dims>, beam_dim, length_dim] variable_dtype: a mtf.VariableDType beam_size: an integer >= 1 alpha: a floating point value (length bonus for beam search) temperature: a value between 0 and 1 (must be 0 if beam_size > 1) 0.0 means argmax, 1.0 means sample according to predicted distribution. 
decode_length_multiplier: a float decode_length_constant: a float max_decode_length: an optional integer Returns: a Tensor with shape [<batch_dims>, beam_dim, length_dim] """ encoder_layer_outputs = [] shared_params = self._shared_params(inputs.mesh, variable_dtype) encoder_sequence_id = mtf.minimum(inputs, 1) encoder_output, encoder_loss = self.encoder.call_simple( inputs=inputs, targets=None, compute_loss=False, mode=tf.estimator.ModeKeys.PREDICT, variable_dtype=variable_dtype, sequence_id=encoder_sequence_id, shared_params=shared_params, layer_outputs=encoder_layer_outputs) del encoder_loss encoder_output = mtf.layers.rename_length_to_memory_length(encoder_output) encoder_sequence_id = mtf.layers.rename_length_to_memory_length( encoder_sequence_id) batch_dims = inputs.shape[:-1] length_dim = inputs.shape[-1] if max_decode_length is None: decode_length_dim = length_dim else: decode_length_dim = mtf.Dimension("length", max_decode_length) if beam_size == 1: ids_shape = mtf.Shape(batch_dims + [decode_length_dim]) partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) return self.decoder.sample_autoregressive( partial_sequences, temperature=temperature, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, encoder_inputs=mtf.layers.rename_length_to_memory_length(inputs), shared_params=shared_params, has_partial_sequences=False, encoder_layer_outputs=encoder_layer_outputs) else: if temperature != 0: raise ValueError( "don't know how to beam search with nonzero temperature") # beam search beam_dim = mtf.Dimension("beam", beam_size) ids_shape = mtf.Shape(batch_dims + [beam_dim, decode_length_dim]) partial_sequences = mtf.zeros(inputs.mesh, ids_shape, dtype=tf.int32) input_length = mtf.reduce_sum( mtf.to_float(mtf.cast(inputs, tf.bool)), reduced_dim=length_dim) max_input_length = mtf.reduce_max(input_length) decode_length = mtf.cast( max_input_length * decode_length_multiplier + decode_length_constant, tf.int32) return self.decoder.beam_search( partial_sequences, decode_length, variable_dtype=variable_dtype, encoder_output=encoder_output, encoder_sequence_id=encoder_sequence_id, encoder_inputs=inputs, alpha=alpha, shared_params=shared_params, encoder_layer_outputs=encoder_layer_outputs) @gin.configurable class StudentTeacher(object): """A teacher and a student to be taught via distillation.""" def __init__(self, student, teacher, temperature=None, fraction_soft=None, teacher_checkpoint=None): """Create a StudentTeacher. Args: student: a Unitransformer or Bitransformer teacher: a Unitransformer or Bitransformer temperature: a float, the temperature of the softmax for distilling from the teacher. Required only when training. fraction_soft: a float between 0 and 1, the contribution of the soft target cross entropy to the training loss. The rest of the loss will be the cross entropy with the one-hot actual label. Required only when training. teacher_checkpoint: a string, the path to the teacher checkpoint that we wish to use. Required only when training. """ self.student = student self.teacher = teacher self.temperature = temperature self.fraction_soft = fraction_soft self.teacher_checkpoint = teacher_checkpoint def call_simple(self, inputs, targets, compute_loss, variable_dtype=mtf.VariableDType(tf.float32), num_microbatches=1, **kargs): """Compute logits based on inputs (all positions in parallel). This is called during training and evaluation. 
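When training with fraction_soft > 0, the returned loss combines the usual cross entropy against the one-hot targets with a soft cross entropy against the teacher's tempered distribution, as computed in the body below: loss = (1 - fraction_soft) * hard_loss + temperature**2 * fraction_soft * soft_loss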
Args: inputs: an int32 Tensor with shape [<batch_dims>, length_dim] For training autoregressive models this should be equal to mtf.shift(targets, offset=1, dim=length_dim, wrap=False) targets: an optional int32 Tensor with shape [<batch_dims>, length_dim] compute_loss: a boolean variable_dtype: a mtf.VariableDType num_microbatches: integer - greater than one if the step has been serialized into multiple microbatches to save memory. **kargs: additional arguments to pass to the student.call_simple and teacher.call_simple Returns: logits: a Tensor with shape [<batch_dims>, output_vocab_dim] loss: an optional Scalar (if compute_loss=True) """ with tf.variable_scope("student"): student_logits, hard_loss = self.student.call_simple( inputs, targets, compute_loss=True, variable_dtype=variable_dtype, num_microbatches=num_microbatches, **kargs) if not compute_loss: return student_logits elif self.fraction_soft == 0.0: # Do not create the teacher if we do not need it. return student_logits, hard_loss assert self.student.output_vocab_dim == self.teacher.output_vocab_dim assert self.student.z_loss == self.teacher.z_loss output_vocab_dim = self.student.output_vocab_dim z_loss = self.student.z_loss graph = inputs.mesh.graph with tf.variable_scope("teacher"): teacher_logits, _ = self.teacher.call_simple( inputs, targets, compute_loss=True, variable_dtype=variable_dtype, num_microbatches=num_microbatches, **kargs) graph.make_variables_untrainable( [v for v in graph.trainable_variables if v.name.startswith("teacher/")]) soft_targets = mtf.softmax(teacher_logits / self.temperature, output_vocab_dim) soft_loss = mtf.layers.softmax_cross_entropy_with_logits( student_logits / self.temperature, mtf.stop_gradient(soft_targets), output_vocab_dim, z_loss=z_loss) # Ignore losses from padding regions. weights = mtf.layers.weights_nonzero( targets, dtype=variable_dtype.activation_dtype) soft_loss = (mtf.reduce_sum(soft_loss * weights) / self.student.loss_denominator(targets, num_microbatches)) loss = (1.0 - self.fraction_soft) * hard_loss \ + self.temperature**2 * self.fraction_soft * soft_loss return student_logits, loss def decode(self, *args, **kargs): """Sample from the student. Args: *args: arguments to Unitransformer.sample_autoregressive or Bitransformer.decode **kargs: arguments to Unitransformer.sample_autoregressive or Bitransformer.decode Returns: a Tensor with the same shape as the output of Unitransformer.sample_autoregressive or Bitransformer.decode """ with tf.variable_scope("student"): if isinstance(self.student, Unitransformer): return self.student.sample_autoregressive(*args, **kargs) elif isinstance(self.student, Bitransformer): return self.student.decode(*args, **kargs) else: raise ValueError("unrecognized class") def initialize(self): """Initialize the teacher model from the checkpoint. This function will be called after the graph has been constructed. """ if self.fraction_soft == 0.0: # Do nothing if we do not need the teacher. return vars_to_restore = tf.get_collection( tf.GraphKeys.GLOBAL_VARIABLES, scope="teacher") tf.train.init_from_checkpoint( self.teacher_checkpoint, {v.name[len("teacher/"):].split(":")[0]: v for v in vars_to_restore}) # gin-configurable constructors @gin.configurable def make_layer_stack(layers=gin.REQUIRED, layer_stack_cls=LayerStack, num_layers=6, block_scope=True): """Configurable layer stack. The "layers" argument specifies the layers in each block. It is a list of specifications. 
Each specification is either a subclass of TransformerLayer or a list/tuple containing such a subclass as well as other optional items. Each optional item is either a string (the class name), or a dictionary of kwargs to be passed to the class constructor. Example: layers=[ transformer_layers.SelfAttention, [transformer_layers.DenseReluDense, "feedforward", {"hidden_size": 2048, "dropout_rate": 0.2}], ] The "num_layers" argument specifies the number of blocks. Args: layers: a list (see above) layer_stack_cls: a class, e.g. LayerStack or ReversibleLayerStack num_layers: an integer block_scope: a bool, if True then use scopes of the format ``` block_000/layer_000/... block_000/layer_001/... ... block_001/layer_000/... block_001/layer_001/... ``` If False then use scopes of the format ``` layer_000/... layer_001/... layer_002/... ... ``` Returns: a LayerStack """ layer_stack = [] for block in range(num_layers): for n, cls in enumerate(layers): # Set name to None if it wasn't provided, which simplifies the logic below name = None kwargs = {} if isinstance(cls, (list, tuple)): for x in cls: if isinstance(x, str): name = x elif isinstance(x, dict): kwargs = x else: cls = x if block_scope: name = "block_{:03d}/{}".format(block, name or "layer_{:03d}".format(n)) else: name = name or "layer_{:03d}".format(len(layer_stack)) layer = cls(**kwargs) layer.set_name(name) layer_stack.append(layer) return layer_stack_cls(layer_stack) @gin.configurable def make_bitransformer( input_vocab_size=gin.REQUIRED, output_vocab_size=gin.REQUIRED, layout=None, mesh_shape=None, encoder_name="encoder", decoder_name="decoder"): """Gin-configurable bitransformer constructor. In your config file you need to set the encoder and decoder layers like this: encoder/make_layer_stack.layers = [ @transformer_layers.SelfAttention, @transformer_layers.DenseReluDense, ] decoder/make_layer_stack.layers = [ @transformer_layers.SelfAttention, @transformer_layers.EncDecAttention, @transformer_layers.DenseReluDense, ] Args: input_vocab_size: an integer output_vocab_size: an integer layout: optional - an input to mtf.convert_to_layout_rules Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape mesh_shape: optional - an input to mtf.convert_to_shape Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape encoder_name: optional - a string giving the Unitransformer encoder name. decoder_name: optional - a string giving the Unitransformer decoder name. Returns: a Bitransformer """ with gin.config_scope("encoder"): encoder = Unitransformer( layer_stack=make_layer_stack(), input_vocab_size=input_vocab_size, output_vocab_size=None, autoregressive=False, name=encoder_name, layout=layout, mesh_shape=mesh_shape) with gin.config_scope("decoder"): decoder = Unitransformer( layer_stack=make_layer_stack(), input_vocab_size=output_vocab_size, output_vocab_size=output_vocab_size, autoregressive=True, name=decoder_name, layout=layout, mesh_shape=mesh_shape) return Bitransformer(encoder, decoder) @gin.configurable def make_bi_student_teacher(input_vocab_size=gin.REQUIRED, output_vocab_size=gin.REQUIRED, layout=None, mesh_shape=None): """Gin-configurable bitransformer student teacher constructor.


@gin.configurable
def make_bitransformer(
    input_vocab_size=gin.REQUIRED,
    output_vocab_size=gin.REQUIRED,
    layout=None,
    mesh_shape=None,
    encoder_name="encoder",
    decoder_name="decoder"):
  """Gin-configurable bitransformer constructor.

  In your config file you need to set the encoder and decoder layers like
  this:
    encoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.DenseReluDense,
    ]
    decoder/make_layer_stack.layers = [
      @transformer_layers.SelfAttention,
      @transformer_layers.EncDecAttention,
      @transformer_layers.DenseReluDense,
    ]

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    encoder_name: optional - a string giving the Unitransformer encoder name.
    decoder_name: optional - a string giving the Unitransformer decoder name.

  Returns:
    a Bitransformer
  """
  with gin.config_scope("encoder"):
    encoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=input_vocab_size,
        output_vocab_size=None,
        autoregressive=False,
        name=encoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("decoder"):
    decoder = Unitransformer(
        layer_stack=make_layer_stack(),
        input_vocab_size=output_vocab_size,
        output_vocab_size=output_vocab_size,
        autoregressive=True,
        name=decoder_name,
        layout=layout,
        mesh_shape=mesh_shape)
  return Bitransformer(encoder, decoder)


@gin.configurable
def make_bi_student_teacher(input_vocab_size=gin.REQUIRED,
                            output_vocab_size=gin.REQUIRED,
                            layout=None,
                            mesh_shape=None):
  """Gin-configurable bitransformer student-teacher constructor.

  In your config file you need to set the encoder and decoder layers like
  this:
    encoder_layers = [
      @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
      @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ]
    decoder_layers = [
      @mesh_tensorflow.transformer.transformer_layers.SelfAttention,
      @mesh_tensorflow.transformer.transformer_layers.EncDecAttention,
      @mesh_tensorflow.transformer.transformer_layers.DenseReluDense,
    ]
    teacher/encoder/transformer.make_layer_stack.layers = %encoder_layers
    teacher/decoder/transformer.make_layer_stack.layers = %decoder_layers
    student/encoder/transformer.make_layer_stack.layers = %encoder_layers
    student/decoder/transformer.make_layer_stack.layers = %decoder_layers

  Args:
    input_vocab_size: an integer
    output_vocab_size: an integer
    layout: optional - an input to mtf.convert_to_layout_rules
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape
    mesh_shape: optional - an input to mtf.convert_to_shape
      Some layers (e.g. MoE layers) cheat by looking at layout and mesh_shape

  Returns:
    a StudentTeacher
  """
  with gin.config_scope("student"):
    student = make_bitransformer(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        layout=layout,
        mesh_shape=mesh_shape)
  with gin.config_scope("teacher"):
    teacher = make_bitransformer(
        input_vocab_size=input_vocab_size,
        output_vocab_size=output_vocab_size,
        layout=layout,
        mesh_shape=mesh_shape)
  return StudentTeacher(student=student, teacher=teacher)


def _round_up_to_multiple(n, divisor):
  """Rounds n up to the nearest multiple of divisor."""
  return n + -n % divisor


def delimited_lm_inputs_mask(ids, eos_id=1):
  """Binary mask indicating which parts of the ids represent the inputs.

  Assumes that the ids consist of packed sequences where each example is
  represented by two eos-terminated sequences, i.e.

  [<inputs0>, EOS, <targets0>, EOS, <inputs1>, EOS, <targets1>, EOS ...]

  As such, the inputs are the parts where the number of previous EOS tokens
  is even.

  Args:
    ids: an int32 mtf.Tensor with shape [..., length_dim]
    eos_id: an integer

  Returns:
    a boolean mtf.Tensor with the same shape as ids
  """
  length_dim = ids.shape.dims[-1]
  return mtf.equal(
      mtf.mod(
          mtf.cumsum(mtf.to_int32(mtf.equal(ids, eos_id)),
                     length_dim,
                     exclusive=True),
          2),
      0)
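

# A plain-numpy sketch (ours, not used by the library) of the mask logic in
# delimited_lm_inputs_mask above, for checking the behavior on concrete ids.
# Assumes numpy is available.
def _example_delimited_lm_inputs_mask(ids, eos_id=1):
  """Returns the inputs mask for a 1-d list of python ints."""
  import numpy as np
  eos = np.equal(ids, eos_id).astype(np.int32)
  # Exclusive cumulative sum: number of EOS tokens strictly before each token.
  num_prior_eos = np.cumsum(eos) - eos
  return np.equal(num_prior_eos % 2, 0)
# e.g. _example_delimited_lm_inputs_mask([3, 4, 1, 5, 6, 1, 7, 1])
#      -> [True, True, True, False, False, False, True, True]
# i.e. each inputs segment, including its terminating EOS, is marked True.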


@gin.configurable
def reduce_ensemble_logits_select(logits, ensemble_dim, vocab_dim, model_id=0):
  """Select logits from the model_id-th element of the ensemble."""
  del vocab_dim
  return mtf.gather(logits, model_id % ensemble_dim.size, ensemble_dim)


@gin.configurable
def reduce_ensemble_logits_mean_prob(logits, ensemble_dim, vocab_dim):
  """Probabilities equal to arithmetic mean probability across models."""
  probs = mtf.softmax(logits, reduced_dim=vocab_dim)
  probs = mtf.reduce_mean(probs, reduced_dim=ensemble_dim)
  return mtf.log(mtf.maximum(probs, 1e-20))


@gin.configurable
def reduce_ensemble_logits_mean_logit(logits, ensemble_dim, vocab_dim):
  """Probabilities proportional to geometric mean probability across models."""
  del vocab_dim
  return mtf.reduce_mean(logits, reduced_dim=ensemble_dim)


@gin.configurable
def reduce_ensemble_logits(logits, ensemble_dim, vocab_dim,
                           reduce_fn=reduce_ensemble_logits_mean_prob):
  """Configurable reduction function for decoding from an ensemble.

  reduce_fn is a function which takes:
    - a logits tensor containing ensemble_dim (logits from all models)
    - ensemble_dim
    - vocab_dim
  and returns a logits tensor without ensemble_dim.

  Args:
    logits: a mtf.Tensor containing ensemble_dim
    ensemble_dim: a mtf.Dimension
    vocab_dim: a mtf.Dimension
    reduce_fn: a function

  Returns:
    a mtf.Tensor with shape logits.shape - ensemble_dim
  """
  return reduce_fn(logits, ensemble_dim, vocab_dim)


@gin.configurable
class VocabEmbedding(object):
  """A class to go from vocab ids to model states and model states to logits."""

  def __init__(self,
               mesh,
               vocab_dim,
               output_dim,
               variable_dtype,
               name,
               ensemble_dim,
               scale_variable_like_classifier_weights=False):
    """Embedding for the vocabulary.

    Most of the arguments get passed to `mtf.layers.embedding_weights`.

    Args:
      mesh: a mtf.Mesh
      vocab_dim: a mtf.Dimension
      output_dim: a mtf.Dimension
      variable_dtype: a mtf.VariableDType
      name: a string
      ensemble_dim: a mtf.Dimension
      scale_variable_like_classifier_weights: a boolean
    """
    self._vocab_dim = vocab_dim
    self._output_dim = output_dim
    self._scale_variable_like_classifier_weights = (
        scale_variable_like_classifier_weights)
    if self._scale_variable_like_classifier_weights:
      initializer = tf.random_normal_initializer(
          stddev=self._output_dim.size ** -0.5)
    else:
      initializer = None
    self._embedding_weights = mtf.layers.embedding_weights(
        mesh=mesh,
        vocab_dim=vocab_dim,
        output_dim=output_dim,
        variable_dtype=variable_dtype,
        name=name,
        ensemble_dim=ensemble_dim,
        initializer=initializer)

  def ids_to_embedding(self, ids):
    ret = mtf.gather(self._embedding_weights, ids, self._vocab_dim)
    if self._scale_variable_like_classifier_weights:
      ret *= self._output_dim.size ** 0.5
    return ret

  def hidden_to_logits(self, hidden, context):
    del context
    if not self._scale_variable_like_classifier_weights:
      hidden *= self._output_dim.size ** -0.5
    return mtf.einsum([hidden, self._embedding_weights],
                      reduced_dims=[self._output_dim])


@gin.configurable
def get_vocab_embedding_cls(cls=VocabEmbedding):
  """Configurable function to get the class to use for vocab embeddings.

  Args:
    cls: a class implementing the interface of mtf.transformer.VocabEmbedding

  Returns:
    the class
  """
  return cls


def sinusoid_positional_embedding_weights(mesh,
                                          max_length_dim,
                                          model_dim,
                                          dtype,
                                          min_timescale=1.0,
                                          max_timescale=1.0e4):
  """Gets a bunch of sinusoids of different frequencies.

  Mostly copied from tensor2tensor's get_timing_signal_1d.

  Args:
    mesh: a mtf.Mesh
    max_length_dim: a mtf.Dimension
    model_dim: a mtf.Dimension
    dtype: a tf.DType
    min_timescale: a float
    max_timescale: a float

  Returns:
    an mtf.Tensor of timing signals with shape [max_length_dim, model_dim]

  Raises:
    ValueError: If the model_dim is not divisible by 2.
  """
  if model_dim.size % 2:
    raise ValueError("model_dim must be divisible by 2")
  num_timescales = model_dim.size // 2
  timescale_dim = mtf.Dimension(model_dim.name, num_timescales)
  log_timescale_increment = (
      math.log(float(max_timescale) / float(min_timescale)) /
      max(float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * mtf.exp(
      mtf.mtf_range(mesh, timescale_dim, dtype=tf.float32) *
      -log_timescale_increment)
  position = mtf.mtf_range(mesh, max_length_dim, dtype=tf.float32)
  scaled_time = mtf.einsum([position, inv_timescales])
  # Please note that this slightly differs from the published paper.
  # See a discussion here:
  # https://github.com/tensorflow/tensor2tensor/pull/177
  embeddings = mtf.concat(
      [mtf.sin(scaled_time), mtf.cos(scaled_time)], model_dim.name)
  return mtf.cast(embeddings, dtype)
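

# A plain-numpy sketch (ours, not used by the library) of the computation in
# sinusoid_positional_embedding_weights above, useful for checking expected
# values without building a mesh.  Assumes numpy is available.
def _example_sinusoid_embeddings(max_length=3, model_size=4,
                                 min_timescale=1.0, max_timescale=1.0e4):
  """Returns a [max_length, model_size] array of [sin || cos] signals."""
  import numpy as np
  num_timescales = model_size // 2
  log_timescale_increment = (
      math.log(max_timescale / min_timescale) /
      max(float(num_timescales) - 1, 1))
  inv_timescales = min_timescale * np.exp(
      np.arange(num_timescales) * -log_timescale_increment)
  # Outer product of positions and inverse timescales, mirroring the
  # mtf.einsum above (which reduces no dimensions).
  scaled_time = np.arange(max_length)[:, None] * inv_timescales[None, :]
  return np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)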