# Copyright 2018 MLBenchmark Group. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the Transformer model, and its encoder and decoder stacks.

Model paper: https://arxiv.org/pdf/1706.03762.pdf
Transformer model code source: https://github.com/tensorflow/tensor2tensor
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from mlperf_compliance import mlperf_log

from model import attention_layer
from model import beam_search
from model import embedding_layer
from model import ffn_layer
from model import model_utils
from utils.tokenizer import EOS_ID

_NEG_INF = -1e9

# Mixed-precision policy: the layer-normalization layers below run in bfloat16.
policy = tf.keras.mixed_precision.experimental.Policy('mixed_bfloat16')


class Transformer(object):
  """Transformer model for sequence to sequence data.

  Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf

  The Transformer model consists of an encoder and decoder. The input is an
  int sequence (or a batch of sequences). The encoder produces a continuous
  representation, and the decoder uses the encoder output to generate
  probabilities for the output sequence.
  """

  def __init__(self, params, train):
    """Initialize layers to build Transformer model.

    Args:
      params: hyperparameter object defining layer sizes, dropout values, etc.
      train: boolean indicating whether the model is in training mode. Used to
        determine if dropout layers should be added.
    """
    self.train = train
    self.params = params

    self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
        params.vocab_size, params.hidden_size)
    self.encoder_stack = EncoderStack(params, train)
    self.decoder_stack = DecoderStack(params, train)

  def __call__(self, inputs, targets=None):
    """Calculate target logits or inferred target sequences.

    Args:
      inputs: int tensor with shape [batch_size, input_length].
      targets: None or int tensor with shape [batch_size, target_length].

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence: a float tensor with shape
      [batch_size, target_length, vocab_size].
      If targets is None, then generate the output sequence one token at a
      time and return a dictionary {
          "outputs": int tensor with shape [batch_size, decoded_length],
          "scores": float tensor with shape [batch_size]}.
    """
    # Variance scaling is used here because it seems to work in many problems.
    # Other reasonable initializers may also work just as well.
    mlperf_log.transformer_print(key=mlperf_log.MODEL_HP_INITIALIZER_GAIN,
                                 value=self.params.initializer_gain)
    initializer = tf.compat.v1.variance_scaling_initializer(
        self.params.initializer_gain, mode="fan_avg", distribution="uniform")
    with tf.compat.v1.variable_scope("Transformer", initializer=initializer):
      # Calculate attention bias for encoder self-attention and decoder
      # multi-headed attention layers.
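      # The padding bias marks padded input positions with a large negative
      # value so that softmax gives them near-zero attention weight. A
      # minimal sketch of the idea (assuming 0 is the padding id; the real
      # logic lives in model_utils.get_padding_bias):
      #
      #   padding = tf.cast(tf.equal(inputs, 0), tf.float32)
      #   bias = padding * _NEG_INF                  # [batch, length]
      #   bias = bias[:, tf.newaxis, tf.newaxis, :]  # [batch, 1, 1, length]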
      attention_bias = model_utils.get_padding_bias(inputs)
      attention_bias = tf.cast(attention_bias, tf.bfloat16)

      # Run the inputs through the encoder layer to map the symbol
      # representations to continuous representations.
      encoder_outputs = self.encode(inputs, attention_bias)

      # Generate output sequence if targets is None, or return logits if the
      # target sequence is known.
      if targets is None:
        return self.predict(encoder_outputs, attention_bias)
      else:
        return self.decode(targets, encoder_outputs, attention_bias)

  def encode(self, inputs, attention_bias):
    """Generate continuous representation for inputs.

    Args:
      inputs: int tensor with shape [batch_size, input_length].
      attention_bias: float tensor with shape [batch_size, 1, 1, input_length].

    Returns:
      float tensor with shape [batch_size, input_length, hidden_size]
    """
    with tf.compat.v1.name_scope("encode"):
      # Prepare inputs to the layer stack by adding positional encodings and
      # applying dropout.
      embedded_inputs = self.embedding_softmax_layer(inputs)
      inputs_padding = model_utils.get_padding(inputs)

      with tf.compat.v1.name_scope("add_pos_encoding"):
        length = tf.shape(input=embedded_inputs)[1]
        pos_encoding = model_utils.get_position_encoding(
            length, self.params.hidden_size)
        encoder_inputs = embedded_inputs + pos_encoding

      # Cast to bfloat16 for the mixed-precision stack; attention_bias was
      # already cast by the caller.
      with tf.compat.v1.tpu.bfloat16_scope():
        encoder_inputs = tf.cast(encoder_inputs, tf.bfloat16)
        inputs_padding = tf.cast(inputs_padding, tf.bfloat16)

      if self.train:
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
            value=self.params.layer_postprocess_dropout)
        encoder_inputs = tf.nn.dropout(
            encoder_inputs, rate=self.params.layer_postprocess_dropout)

      return self.encoder_stack(encoder_inputs, attention_bias, inputs_padding)

  def decode(self, targets, encoder_outputs, attention_bias):
    """Generate logits for each value in the target sequence.

    Args:
      targets: target values for the output sequence. int tensor with shape
        [batch_size, target_length]
      encoder_outputs: continuous representation of input sequence. float
        tensor with shape [batch_size, input_length, hidden_size]
      attention_bias: float tensor with shape [batch_size, 1, 1, input_length]

    Returns:
      float32 tensor with shape [batch_size, target_length, vocab_size]
    """
    with tf.compat.v1.name_scope("decode"):
      # Prepare inputs to decoder layers by shifting targets, adding positional
      # encoding and applying dropout.
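      # Teacher forcing: the targets are shifted right by one position, so
      # the prediction at position i is conditioned only on tokens before i.
      # E.g. targets [w1, w2, w3] become decoder inputs [<pad>, w1, w2].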
      decoder_inputs = self.embedding_softmax_layer(targets)
      with tf.compat.v1.name_scope("shift_targets"):
        # Shift targets to the right, and remove the last element.
        decoder_inputs = tf.pad(
            tensor=decoder_inputs,
            paddings=[[0, 0], [1, 0], [0, 0]])[:, :-1, :]
      with tf.compat.v1.name_scope("add_pos_encoding"):
        length = tf.shape(input=decoder_inputs)[1]
        decoder_inputs += model_utils.get_position_encoding(
            length, self.params.hidden_size)
      if self.train:
        mlperf_log.transformer_print(
            key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
            value=self.params.layer_postprocess_dropout)
        decoder_inputs = tf.nn.dropout(
            decoder_inputs, rate=self.params.layer_postprocess_dropout)

      # Cast to bfloat16 to match the encoder outputs and attention bias,
      # which the caller already cast.
      with tf.compat.v1.tpu.bfloat16_scope():
        decoder_inputs = tf.cast(decoder_inputs, tf.bfloat16)

      # Run values through the decoder stack.
      decoder_self_attention_bias = tf.cast(
          model_utils.get_decoder_self_attention_bias(length), tf.bfloat16)
      outputs = self.decoder_stack(
          decoder_inputs, encoder_outputs, decoder_self_attention_bias,
          attention_bias)
      logits = self.embedding_softmax_layer.linear(outputs)
      logits = tf.cast(logits, tf.float32)
      return logits

  def _get_symbols_to_logits_fn(self, max_decode_length):
    """Returns a decoding function that calculates logits of the next tokens."""

    timing_signal = model_utils.get_position_encoding(
        max_decode_length + 1, self.params.hidden_size)
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        max_decode_length)

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape
          [batch_size * beam_size, i + 1]
        i: Loop index
        cache: dictionary of values storing the encoder output,
          encoder-decoder attention bias, and previous decoder attention
          values.

      Returns:
        Tuple of (logits with shape [batch_size * beam_size, vocab_size],
          updated cache values)
      """
      # Set decoder input to the last generated IDs.
      decoder_input = ids[:, -1:]

      # Preprocess decoder input by getting embeddings and adding the timing
      # signal.
      decoder_input = self.embedding_softmax_layer(decoder_input)
      decoder_input += timing_signal[i:i + 1]

      self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1]
      with tf.compat.v1.tpu.bfloat16_scope():
        decoder_input = tf.cast(decoder_input, tf.bfloat16)
        self_attention_bias = tf.cast(self_attention_bias, tf.bfloat16)
      decoder_outputs = self.decoder_stack(
          decoder_input, cache.get("encoder_outputs"), self_attention_bias,
          cache.get("encoder_decoder_attention_bias"), cache)
      logits = self.embedding_softmax_layer.linear(decoder_outputs)
      logits = tf.squeeze(logits, axis=[1])
      logits = tf.cast(logits, tf.float32)
      return logits, cache

    return symbols_to_logits_fn

  def predict(self, encoder_outputs, encoder_decoder_attention_bias):
    """Return predicted sequence."""
    batch_size = tf.shape(input=encoder_outputs)[0]
    input_length = tf.shape(input=encoder_outputs)[1]
    max_decode_length = input_length + self.params.extra_decode_length

    symbols_to_logits_fn = self._get_symbols_to_logits_fn(max_decode_length)

    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create cache storing decoder attention values for each layer.
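    # Each layer's "k"/"v" entries start empty (length 0 along axis 1); the
    # attention layers append one position per decoding step, so keys and
    # values for earlier steps are never recomputed during beam search.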
    cache = {
        "layer_%d" % layer: {
            "k": tf.zeros([batch_size, 0, self.params.hidden_size]),
            "v": tf.zeros([batch_size, 0, self.params.hidden_size]),
        } for layer in range(self.params.num_hidden_layers)}

    # Add encoder output and attention bias to the cache.
    encoder_outputs = tf.cast(encoder_outputs, tf.bfloat16)
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             tf.bfloat16)
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    # Use beam search to find the top beam_size sequences and scores.
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_SEQ_BEAM_SEARCH,
        value={
            "vocab_size": self.params.vocab_size,
            "beam_size": self.params.beam_size,
            "alpha": self.params.alpha,
            "extra_decode_length": self.params.extra_decode_length})
    with tf.compat.v1.tpu.bfloat16_scope():
      decoded_ids, scores = beam_search.sequence_beam_search(
          symbols_to_logits_fn=symbols_to_logits_fn,
          initial_ids=initial_ids,
          initial_cache=cache,
          vocab_size=self.params.vocab_size,
          beam_size=self.params.beam_size,
          alpha=self.params.alpha,
          max_decode_length=max_decode_length,
          eos_id=EOS_ID)

    # Get the top sequence for each batch element.
    top_decoded_ids = decoded_ids[:, 0, 1:]
    top_scores = scores[:, 0]

    return {"outputs": top_decoded_ids, "scores": top_scores}


class LayerNormalization(tf.compat.v1.layers.Layer):
  """Applies layer normalization."""

  def __init__(self, hidden_size):
    super(LayerNormalization, self).__init__()
    self.hidden_size = hidden_size

  def build(self, _):
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NORM, value={"hidden_size": self.hidden_size})
    self.scale = tf.compat.v1.get_variable(
        "layer_norm_scale", [self.hidden_size],
        initializer=tf.compat.v1.ones_initializer())
    self.bias = tf.compat.v1.get_variable(
        "layer_norm_bias", [self.hidden_size],
        initializer=tf.compat.v1.zeros_initializer())
    self.built = True

  def call(self, x, epsilon=1e-6):
    mean = tf.reduce_mean(input_tensor=x, axis=[-1], keepdims=True)
    variance = tf.reduce_mean(
        input_tensor=tf.square(x - mean), axis=[-1], keepdims=True)
    norm_x = (x - mean) * tf.math.rsqrt(variance + epsilon)
    return norm_x * self.scale + self.bias


class PrePostProcessingWrapper(object):
  """Wrapper class that applies layer pre-processing and post-processing."""

  def __init__(self, layer, params, train):
    self.layer = layer
    self.postprocess_dropout = params.layer_postprocess_dropout
    self.train = train

    # Create normalization layer; it runs in bfloat16 via the module-level
    # mixed precision policy.
    self.layer_norm = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, dtype=policy)

  def __call__(self, x, *args, **kwargs):
    # Preprocessing: cast to bfloat16 and apply layer normalization.
    x = tf.cast(x, tf.bfloat16)
    y = self.layer_norm(x)

    # Get layer output.
    y = self.layer(y, *args, **kwargs)

    # Postprocessing: apply dropout and residual connection.
    if self.train:
      mlperf_log.transformer_print(
          key=mlperf_log.MODEL_HP_LAYER_POSTPROCESS_DROPOUT,
          value=self.postprocess_dropout)
      y = tf.nn.dropout(y, rate=self.postprocess_dropout)
    return x + y
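
# PrePostProcessingWrapper implements the pre-norm residual pattern shared by
# both stacks below. Schematically (a minimal sketch, not the exact call path):
#
#   def pre_norm_residual(x, sublayer, layer_norm, dropout_rate, train):
#     y = sublayer(layer_norm(x))                # normalize, then transform
#     if train:
#       y = tf.nn.dropout(y, rate=dropout_rate)  # regularize sublayer output
#     return x + y                               # residual connection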

class EncoderStack(tf.compat.v1.layers.Layer):
  """Transformer encoder stack.

  The encoder stack is made up of N identical layers. Each layer is composed
  of the sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
  """

  def __init__(self, params, train):
    super(EncoderStack, self).__init__()
    self.layers = []
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
        value=params.num_hidden_layers)
    for _ in range(params.num_hidden_layers):
      # Create sublayers for each layer.
      self_attention_layer = attention_layer.SelfAttention(
          params.hidden_size, params.num_heads, params.attention_dropout,
          train)
      feed_forward_network = ffn_layer.FeedFowardNetwork(
          params.hidden_size, params.filter_size, params.relu_dropout, train)

      self.layers.append([
          PrePostProcessingWrapper(self_attention_layer, params, train),
          PrePostProcessingWrapper(feed_forward_network, params, train)])

    # Create final layer normalization layer.
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, dtype=policy)

  def call(self, encoder_inputs, attention_bias, inputs_padding):
    for n, layer in enumerate(self.layers):
      # Run inputs through the sublayers.
      self_attention_layer = layer[0]
      feed_forward_network = layer[1]

      with tf.compat.v1.variable_scope("layer_%d" % n):
        with tf.compat.v1.variable_scope("self_attention"):
          encoder_inputs = self_attention_layer(encoder_inputs, attention_bias)
        with tf.compat.v1.variable_scope("ffn"):
          encoder_inputs = feed_forward_network(encoder_inputs, inputs_padding)

    return self.output_normalization(encoder_inputs)


class DecoderStack(tf.compat.v1.layers.Layer):
  """Transformer decoder stack.

  Like the encoder stack, the decoder stack is made up of N identical layers.
  Each layer is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results
       from the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self, params, train):
    super(DecoderStack, self).__init__()
    self.layers = []
    mlperf_log.transformer_print(
        key=mlperf_log.MODEL_HP_NUM_HIDDEN_LAYERS,
        value=params.num_hidden_layers)
    for _ in range(params.num_hidden_layers):
      self_attention_layer = attention_layer.SelfAttention(
          params.hidden_size, params.num_heads, params.attention_dropout,
          train)
      enc_dec_attention_layer = attention_layer.Attention(
          params.hidden_size, params.num_heads, params.attention_dropout,
          train)
      feed_forward_network = ffn_layer.FeedFowardNetwork(
          params.hidden_size, params.filter_size, params.relu_dropout, train)

      self.layers.append([
          PrePostProcessingWrapper(self_attention_layer, params, train),
          PrePostProcessingWrapper(enc_dec_attention_layer, params, train),
          PrePostProcessingWrapper(feed_forward_network, params, train)])

    # Create final layer normalization layer.
    self.output_normalization = tf.keras.layers.LayerNormalization(
        epsilon=1e-6, dtype=policy)

  def call(self, decoder_inputs, encoder_outputs, decoder_self_attention_bias,
           attention_bias, cache=None):
    for n, layer in enumerate(self.layers):
      self_attention_layer = layer[0]
      enc_dec_attention_layer = layer[1]
      feed_forward_network = layer[2]

      # Run inputs through the sublayers.
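      # During incremental decoding, each layer reads and updates its own
      # cache entry, keyed "layer_0", "layer_1", ...; during training
      # (cache is None) the whole target sequence is processed in one pass.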
      layer_name = "layer_%d" % n
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.compat.v1.variable_scope(layer_name):
        with tf.compat.v1.variable_scope("self_attention"):
          decoder_inputs = self_attention_layer(
              decoder_inputs, decoder_self_attention_bias, cache=layer_cache)
        with tf.compat.v1.variable_scope("encdec_attention"):
          decoder_inputs = enc_dec_attention_layer(
              decoder_inputs, encoder_outputs, attention_bias)
        with tf.compat.v1.variable_scope("ffn"):
          decoder_inputs = feed_forward_network(decoder_inputs)

    return self.output_normalization(decoder_inputs)
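
# -----------------------------------------------------------------------------
# Usage sketch (illustrative only; `params` is assumed to expose the attributes
# referenced above, e.g. vocab_size, hidden_size, num_heads, num_hidden_layers,
# filter_size, the dropout rates, beam_size, alpha, extra_decode_length):
#
#   model = Transformer(params, train=True)
#   logits = model(inputs, targets)  # [batch_size, target_length, vocab_size]
#
#   model = Transformer(params, train=False)
#   preds = model(inputs)            # {"outputs": [batch, decoded_length],
#                                    #  "scores": [batch]}
# -----------------------------------------------------------------------------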