# coding=utf-8 # Copyright 2020 The Mesh TensorFlow Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Lint as: python3 """MeshTensorFlow implementation of BERT. The code is ported from https://github.com/google-research/bert. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import copy import json import math import re import mesh_tensorflow as mtf import mesh_tensorflow.transformer.moe as moe import six import tensorflow.compat.v1 as tf class BertConfig(object): """Configuration for `BertModel`.""" def __init__(self, vocab_size, d_model=768, position_signal="embedding", max_position_embeddings=512, num_blocks=12, block_layers="attention,feedforward", layer_output_dropout_prob=0.1, residual_structure="original", use_bias=True, attention_num_heads=12, attention_head_size=None, attention_num_key_heads=None, attention_key_head_size=None, attention_num_value_heads=None, attention_value_head_size=None, attention_probs_dropout_prob=0.1, feedforward_intermediate_size=3072, feedforward_intermediate_act="gelu", feedforward_intermediate_dropout_prob=0.0, moe_num_experts=32, moe_intermediate_size=6144, type_vocab_size=16, initializer_range=0.02): """Constructs BertConfig. residual_structure="original" TODO(noam): describe residual_structure="direct" TODO(noam): describe Args: vocab_size: Vocabulary size of `inputs_ids` in `BertModel`. d_model: Number of channels in input/output of each layer. position_signal: A string specifying the type of position signal. Implemented values are "embedding", "relative_attention_bias". max_position_embeddings: For models using positional embeddings, this is the maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048). num_blocks: Number of (attention+feed-forward) blocks in the Transformer encoder. block_layers: a comma-separated string specifying the sequence of layers in each block. layer_output_dropout_prob: The dropout probability for the output of each layer. residual_structure: a string. Legal values are "original" and "direct". use_bias: a boolean - If true, then we use biases for dense layers and in layer normalization, and subtract off the mean in layer normalization. attention_num_heads: Number of attention heads for each attention layer in the Transformer encoder. attention_head_size: Size of attention keys and values. If set to None, a default value is used equal to (d_model / attention_num_heads) attention_num_key_heads: Number of attention key heads. attention_key_head_size: Size of attention keys. attention_num_value_heads: Number of attention value heads. attention_value_head_size: Size of attention values. attention_probs_dropout_prob: The dropout ratio for the attention probabilities. feedforward_intermediate_size: The size of the "intermediate" layer in the feed-forward layer in the Transformer encoder (a.k.a. d_ff). feedforward_intermediate_act: The non-linear activation function (function or string) applied to the feedforward intermediate layer and the pooler layer. feedforward_intermediate_dropout_prob: The dropout probability for feed-forward intermediate layer. moe_num_experts: an integer - number of experts in moe layer moe_intermediate_size: an integer - size of intermediate layer in each expert type_vocab_size: The vocabulary size of the `token_type_ids` passed into `BertModel`. initializer_range: The stdev of the truncated_normal_initializer for initializing all weight matrices. """ self.vocab_size = vocab_size self.d_model = d_model self.position_signal = position_signal self.max_position_embeddings = max_position_embeddings self.num_blocks = num_blocks self.block_layers = block_layers.split(",") self.layer_output_dropout_prob = layer_output_dropout_prob self.residual_structure = residual_structure self.use_bias = use_bias self.attention_probs_dropout_prob = attention_probs_dropout_prob self.attention_num_heads = attention_num_heads self.attention_head_size = attention_head_size self.attention_num_key_heads = attention_num_key_heads self.attention_key_head_size = attention_key_head_size self.attention_num_value_heads = attention_num_value_heads self.attention_value_head_size = attention_value_head_size self.feedforward_intermediate_size = feedforward_intermediate_size self.feedforward_intermediate_act = feedforward_intermediate_act self.feedforward_intermediate_dropout_prob = ( feedforward_intermediate_dropout_prob) self.moe_num_experts = moe_num_experts self.moe_intermediate_size = moe_intermediate_size self.type_vocab_size = type_vocab_size self.initializer_range = initializer_range if self.position_signal not in ["embedding", "relative_attention_bias"]: raise ValueError("unknown position_signal") if self.residual_structure not in ["original", "direct"]: raise ValueError("unknown residual_structure") @classmethod def from_dict(cls, json_object): """Constructs a `BertConfig` from a Python dictionary of parameters.""" # Dictionary for compatibility for tf BertConfig files. hparam_name_conversion = { "hidden_size": "d_model", "num_hidden_layers": "num_blocks", "num_attention_heads": "attention_num_heads", "intermediate_size": "feedforward_intermediate_size", "hidden_act": "feedforward_intermediate_act", "hidden_dropout_prob": "layer_output_dropout_prob", } config = BertConfig(vocab_size=None) for (key, value) in six.iteritems(json_object): config.__dict__[hparam_name_conversion.get(key, key)] = value return config @classmethod def from_json_file(cls, json_file): """Constructs a `BertConfig` from a json file of parameters.""" with tf.gfile.GFile(json_file, "r") as reader: text = reader.read() return cls.from_dict(json.loads(text)) def to_dict(self): """Serializes this instance to a Python dictionary.""" output = copy.deepcopy(self.__dict__) return output def to_json_string(self): """Serializes this instance to a JSON string.""" return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" class BertModel(object): """BERT model ("Bidirectional Encoder Representations from Transformers").""" def __init__(self, config, is_training, input_ids, input_mask=None, token_type_ids=None, scope=None, mesh_shape="", layout=""): self.config = copy.deepcopy(config) del config if not is_training: self.config.layer_output_dropout_prob = 0.0 self.config.attention_probs_dropout_prob = 0.0 self.config.feedforward_intermediate_dropout_prob = 0.0 input_shape = input_ids.shape assert input_shape.ndims == 2 self._seq_dim = input_shape.dims[1] self._memory_seq_dim = mtf.Dimension("memory_seq", self.seq_dim.size) self._extra_losses = [] mesh = input_ids.mesh if token_type_ids is None: token_type_ids = mtf.zeros(mesh, input_shape, dtype=tf.int32) with tf.variable_scope(scope, default_name="bert"): with tf.variable_scope("embeddings"): # Perform embedding lookup on the word ids. self.embedding_table = mtf.get_variable( mesh, "word_embeddings", mtf.Shape([self.vocab_dim, self.model_dim]), initializer=self.embedding_initializer) self.word_embedding_output = mtf.gather( self.embedding_table, input_ids, self.vocab_dim) # Add positional embeddings and token type embeddings, then layer # normalize and perform dropout. self.embedding_output = self.word_embedding_output token_type_table = mtf.get_variable( mesh, "token_type_embeddings", mtf.Shape([self.token_type_vocab_dim, self.model_dim]), initializer=self.embedding_initializer) if token_type_ids is not None: self.embedding_output += mtf.gather( token_type_table, token_type_ids, self.token_type_vocab_dim) if self.config.position_signal == "embedding": full_position_table = mtf.get_variable( mesh, "position_embeddings", mtf.Shape([self.max_position_embeddings_dim, self.model_dim]), initializer=self.embedding_initializer) short_position_table = mtf.rename_dimension( mtf.slice(full_position_table, 0, self.seq_dim.size, self.max_position_embeddings_dim.name), self.max_position_embeddings_dim.name, self.seq_dim.name) self.embedding_output += short_position_table self.embedding_output = self.normalize(self.embedding_output) self.embedding_output = mtf.dropout( self.embedding_output, keep_prob=1.0 - self.config.layer_output_dropout_prob) with tf.variable_scope("encoder"): attention_biases = [] if input_mask: # [batch_dim, memory_seq_dim] attention_biases.append( (1.0 - mtf.to_float(mtf.replace_dimensions( input_mask, self.seq_dim, self.memory_seq_dim))) * -10000.0) if self.config.position_signal == "relative_attention_bias": buckets_dim = mtf.Dimension("buckets", 32) rp_bucket = _relative_position_bucket( mtf.range(mesh, self.memory_seq_dim, tf.int32) - mtf.range(mesh, self.seq_dim, tf.int32), num_buckets=buckets_dim.size) bias_var = mtf.get_variable( mesh, "relative_attention_bias", [self.num_heads_dim, buckets_dim], initializer=tf.zeros_initializer()) attention_biases.append(mtf.gather(bias_var, rp_bucket, buckets_dim)) attention_bias = mtf.add_n(attention_biases) prev_layer_output = self.embedding_output self.all_encoder_layers = [] for block_num in range(self.config.num_blocks): with tf.variable_scope("block_%d" % block_num): for layer_idx, layer_type in enumerate(self.config.block_layers): layer_name = layer_type count = self.config.block_layers[:layer_idx].count(layer_type) if count: layer_name += "_%d" % count with tf.variable_scope(layer_name): x = prev_layer_output if self.config.residual_structure == "direct": x = self.normalize(x) if layer_type == "attention": x = self.self_attention(x, attention_bias) elif layer_type == "feedforward": x = self.feedforward(x) elif layer_type == "moe": x = self.moe(x, layout, mesh_shape, input_mask, is_training) else: raise ValueError("unknown layer type " + layer_type) x = mtf.dropout( x, keep_prob=1.0 - self.config.layer_output_dropout_prob) layer_output = prev_layer_output + x if self.config.residual_structure == "original": layer_output = self.normalize(layer_output) prev_layer_output = layer_output self.all_encoder_layers.append(layer_output) self.sequence_output = prev_layer_output if self.config.residual_structure == "direct": self.sequence_output = self.normalize(self.sequence_output) # The "pooler" converts the encoded sequence tensor of shape # [batch_dim, seq_dim, hidden_size] to a tensor of shape # [batch_dim, hidden_size]. This is necessary for segment-level # (or segment-pair-level) classification tasks where we need a fixed # dimensional representation of the segment. with tf.variable_scope("pooler"): # We "pool" the model by simply taking the hidden state corresponding # to the first token. We assume that this has been pre-trained first_token_tensor = mtf.gather(self.sequence_output, 0, self.seq_dim) self.pooled_output = mtf.layers.dense( first_token_tensor, reduced_dims=[self.model_dim], new_dims=[self.model_dim], activation=mtf.tanh, kernel_initializer=self.dense_initializer, use_bias=self.config.use_bias) def self_attention(self, x, attention_bias): """Performs multi-headed self-attention with output projection. Args: x: output of previous layer attention_bias: optional float32 Tensor broadcastable to shape x.shape - self.model_dim + self.memory_seq_dim to be added to attention logits. This may used to mask out padding regions of the memory. Returns: float Tensor with the same shape as x """ queries = mtf.layers.dense( x, reduced_dims=[self.model_dim], new_dims=[self.num_heads_dim, self.size_per_head_dim], kernel_initializer=self.dense_initializer, name="query", use_bias=self.config.use_bias) keys = mtf.layers.dense( mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim), reduced_dims=[self.model_dim], new_dims=[self.num_heads_dim, self.size_per_head_dim], kernel_initializer=self.dense_initializer, name="key", use_bias=self.config.use_bias) values = mtf.layers.dense( mtf.replace_dimensions(x, self.seq_dim, self.memory_seq_dim), reduced_dims=[self.model_dim], new_dims=[self.num_heads_dim, self.size_per_head_dim], kernel_initializer=self.dense_initializer, name="value", use_bias=self.config.use_bias) # Take the dot product between "query" and "key" to get the raw # attention scores. attention_scores = mtf.einsum( [queries, keys], reduced_dims=[self.size_per_head_dim]) attention_scores *= self.size_per_head_dim.size ** -0.5 if attention_bias is not None: attention_scores += attention_bias # Normalize the attention scores to probabilities. attention_probs = mtf.softmax(attention_scores, self.memory_seq_dim) # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. attention_probs = mtf.dropout( attention_probs, keep_prob=1.0 - self.config.attention_probs_dropout_prob) output = mtf.einsum([attention_probs, values], reduced_dims=[self.memory_seq_dim]) # linear transformation back to shape of query_antecedent output = mtf.layers.dense( output, reduced_dims=[self.num_heads_dim, self.size_per_head_dim], new_dims=[self.model_dim], kernel_initializer=self.dense_initializer, name="output", use_bias=self.config.use_bias) output = mtf.transpose(output, x.shape) return output def feedforward(self, x): intermediate = mtf.layers.dense( x, reduced_dims=[self.model_dim], new_dims=[self.feedforward_intermediate_dim], activation=get_activation(self.config.feedforward_intermediate_act), kernel_initializer=self.dense_initializer, name="dense_1", use_bias=self.config.use_bias) return mtf.layers.dense( intermediate, reduced_dims=[self.feedforward_intermediate_dim], new_dims=[self.model_dim], kernel_initializer=self.dense_initializer, name="dense_2", use_bias=self.config.use_bias) def moe(self, x, layout, mesh_shape, input_mask, is_training): """Mixture of experts layer. TODO(noam): clean up the mixture-of-experts code in Transformer. Args: x: layer input layout: a mtf.LayoutRules mesh_shape: a mtf.Shape input_mask: a mtf.Tensor is_training: a boolean Returns: a mtf.Tensor (the layer output) """ hparams = moe.HParams( moe_gating="top_2", moe_num_experts=self.config.moe_num_experts, moe_loss_coef=1e-3, moe_hidden_size=self.config.moe_intermediate_size, moe_group_size=2048, moe_capacity_factor_train=1.25, moe_capacity_factor_eval=8.0, moe_use_second_place_loss=False, moe_second_policy_train="random", moe_second_policy_eval="random", moe_second_threshold_train=0.2, moe_second_threshold_eval=0.2) layer_output, loss = moe.transformer_moe_layer_v1( inputs=x, output_dim=self.model_dim, hparams=hparams, train=is_training, variable_dtype=tf.float32, layout=layout, mesh_shape=mesh_shape, nonpadding=(mtf.cast(input_mask, tf.float32) if input_mask else None), activation=get_activation(self.config.feedforward_intermediate_act)) self._extra_losses.append(loss) return layer_output def get_masked_lm_output(self, positions, label_ids, label_weights): """Get loss and logits for the masked LM.""" input_tensor = self.get_sequence_output() output_weights = self.get_embedding_table() # [batch_size, num_position, hidden] input_tensor = mtf.gather(input_tensor, positions, self.seq_dim) with tf.variable_scope("cls/predictions"): # We apply one more non-linear transformation before the output layer. # This matrix is not used after pre-training. with tf.variable_scope("transform"): input_tensor = mtf.layers.dense( input_tensor, reduced_dims=[self.model_dim], new_dims=[self.model_dim], activation=get_activation(self.config.feedforward_intermediate_act), kernel_initializer=self.dense_initializer, use_bias=self.config.use_bias) input_tensor = self.normalize(input_tensor) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. output_bias = mtf.get_variable( input_tensor.mesh, name="output_bias", shape=[self.vocab_dim], initializer=tf.zeros_initializer()) logits = mtf.einsum([input_tensor, output_weights], reduced_dims=[self.model_dim]) + output_bias per_example_loss = mtf.layers.softmax_cross_entropy_with_logits( logits, label_ids, self.vocab_dim, z_loss=1e-4) # The `positions` tensor might be zero-padded (if the sequence is too # short to have the maximum number of predictions). The `label_weights` # tensor has a value of 1.0 for every real prediction and 0.0 for the # padding predictions. numerator = mtf.reduce_sum(label_weights * per_example_loss) denominator = mtf.reduce_sum(label_weights) + mtf.constant( input_tensor.mesh, 1e-5, dtype=tf.float32) loss = numerator / denominator return (loss, per_example_loss, logits) def get_next_sentence_output(self, labels): """Get loss and logits for the next sentence prediction.""" class_dim = mtf.Dimension("class", 2) input_tensor = self.get_pooled_output() # Simple binary classification. Note that 0 is "next sentence" and 1 is # "random sentence". This weight matrix is not used after pre-training. logits = mtf.layers.dense( input_tensor, reduced_dims=[self.model_dim], new_dims=[class_dim], kernel_initializer=self.dense_initializer, name="cls/seq_relationship", use_bias=self.config.use_bias) per_example_loss = mtf.layers.softmax_cross_entropy_with_logits( logits, labels, class_dim, z_loss=1e-4) loss = mtf.reduce_mean(per_example_loss) return (loss, per_example_loss, logits) def get_pooled_output(self): return self.pooled_output def get_sequence_output(self): """Gets final hidden layer of encoder. Returns: float Tensor of shape [batch_dim, seq_dim, model_dim] corresponding to the final hidden of the transformer encoder. """ return self.sequence_output def get_all_encoder_layers(self): return self.all_encoder_layers def get_word_embedding_output(self): """Get output of the word(piece) embedding lookup. This is BEFORE positional embeddings and token type embeddings have been added. Returns: float Tensor of shape [batch_dim, seq_dim, model_dim] corresponding to the output of the word(piece) embedding layer. """ return self.word_embedding_output def get_embedding_output(self): """Gets output of the embedding lookup (i.e., input to the transformer). Returns: float Tensor of shape [batch_dim, seq_dim, model_dim] corresponding to the output of the embedding layer, after summing the word embeddings with the positional embeddings and the token type embeddings, then performing layer normalization. This is the input to the transformer. """ return self.embedding_output def normalize(self, x): return layer_norm(x, self.model_dim, subtract_mean=self.config.use_bias, use_bias=self.config.use_bias) def get_embedding_table(self): return self.embedding_table def get_extra_loss(self): return mtf.add_n(self._extra_losses) @property def vocab_dim(self): # pad vocab to a multiple of 128 so as to be splittable. # TODO(noam): This creates issues in checkpoint compatibility n = self.config.vocab_size return mtf.Dimension("vocab", n + (-n % 128)) @property def model_dim(self): return mtf.Dimension("hidden", self.config.d_model) @property def token_type_vocab_dim(self): return mtf.Dimension("token_type_vocab", self.config.type_vocab_size) @property def feedforward_intermediate_dim(self): return mtf.Dimension("intermediate", self.config.feedforward_intermediate_size) @property def num_heads_dim(self): return mtf.Dimension("num_heads", self.config.attention_num_heads) @property def softmax_heads_dims(self): return self.num_heads_dim @property def max_position_embeddings_dim(self): return mtf.Dimension("max_position_embeddings", self.config.max_position_embeddings) @property def seq_dim(self): return self._seq_dim @property def memory_seq_dim(self): return self._memory_seq_dim @property def dense_initializer(self): if self.config.initializer_range: return tf.truncated_normal_initializer( stddev=self.config.initializer_range) else: return mtf.layers.VarianceScalingInitializer(scale=0.4) @property def embedding_initializer(self): initializer = self.dense_initializer if isinstance(initializer, mtf.layers.DenseInitializer): # embedding matrix is also used as classifier weight matrix. # scale it appropriately. return initializer( reduced_dims=[self.model_dim], new_dims=[self.vocab_dim]) else: return initializer @property def size_per_head_dim(self): """Dimensionality of attention queries/keys/values.""" if self.config.attention_head_size: attention_head_size = self.config.attention_head_size else: if self.model_dim.size % self.num_heads_dim.size != 0: raise ValueError( "The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (self.model_dim.size, self.num_heads_dim.size)) attention_head_size = int(self.model_dim.size / self.num_heads_dim.size) return mtf.Dimension("attention_head", attention_head_size) @property def key_dim(self): """Dimensionality of attention key.""" if self.config.attention_key_head_size is None: raise ValueError("The key head size is not defined.") return mtf.Dimension("d_k", self.config.attention_key_head_size) @property def key_heads_dims(self): """Dimensionality of number of key heads.""" if self.config.attention_num_key_heads is None: raise ValueError("The number of key heads is not defined.") return mtf.Dimension("key_heads", self.config.attention_num_key_heads) @property def value_dim(self): """Dimensionality of attention value.""" if self.config.attention_value_head_size is None: raise ValueError("The value head size is not defined.") return mtf.Dimension("d_v", self.config.attention_value_head_size) @property def value_heads_dims(self): """Dimensionality of number of value heads.""" if self.config.attention_num_value_heads is None: raise ValueError("The number of value heads is not defined.") return mtf.Dimension("value_heads", self.config.attention_num_value_heads) def get_activation(activation_string): """Maps a string to a Python function, e.g., "relu" => `mtf.relu`. Args: activation_string: String name of the activation function. Returns: A Python function corresponding to the activation function. If `activation_string` is None, empty, or "linear", this will return None. If `activation_string` is not a string, it will return `activation_string`. Raises: ValueError: The `activation_string` does not correspond to a known activation. """ # We assume that anything that"s not a string is already an activation # function, so we just return it. if not isinstance(activation_string, six.string_types): return activation_string if not activation_string: return None act = activation_string.lower() if act == "linear": return None elif act == "gelu": return mtf.gelu elif act == "relu": return mtf.relu elif act == "tanh": return mtf.tanh else: raise ValueError("Unsupported activation: %s" % act) def get_assignment_map_from_checkpoint(tvars, init_checkpoint): """Compute the union of the current variables and checkpoint variables.""" assignment_map = {} initialized_variable_names = {} name_to_variable = collections.OrderedDict() for var in tvars: name = var.name m = re.match("^(.*):\\d+$", name) if m is not None: name = m.group(1) if "global_step" in name or "adam_" in name or "slot_" in name: continue name_to_variable[name] = var tf.logging.info("init_checkpoint:{} ".format(init_checkpoint)) init_vars = tf.train.list_variables(init_checkpoint) assignment_map = collections.OrderedDict() for x in init_vars: (name, var) = (x[0], x[1]) if name not in name_to_variable: continue assignment_map[name] = name initialized_variable_names[name] = 1 initialized_variable_names[name + ":0"] = 1 return (assignment_map, initialized_variable_names) def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): """Translate relative position to a bucket number for relative attention. The relative position is defined as memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for small absolute relative_position and larger buckets for larger absolute relative_positions. All relative positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. This should allow for more graceful generalization to longer sequences than the model has been trained on. Args: relative_position: an int32 Tensor bidirectional: a boolean - whether the attention is bidirectional num_buckets: an integer max_distance: an integer Returns: a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) """ ret = 0 n = -relative_position if bidirectional: num_buckets //= 2 ret += mtf.to_int32(mtf.less(n, 0)) * num_buckets n = mtf.abs(n) else: n = mtf.maximum(n, 0) # now n is in the range [0, inf) max_exact = num_buckets // 2 is_small = mtf.less(n, max_exact) val_if_large = max_exact + mtf.to_int32( mtf.log(mtf.to_float(n) / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)) val_if_large = mtf.minimum(val_if_large, num_buckets - 1) ret += mtf.where(is_small, n, val_if_large) return ret def layer_norm(x, dim, epsilon=1e-6, subtract_mean=True, use_scale=True, use_bias=True, name=None): """Layer normalization over dimension dim. TODO(noam): This is cleaner than the version in mtf.layers Move this version into mtf.layers to replace the one there. Args: x: a mtf.Tensor whose shape contains dim. dim: a mtf.Dimension epsilon: a floating point number subtract_mean: a boolean use_scale: a boolean use_bias: a boolean name: a string used for tf.variable_scope. Returns: a mtf.Tensor with same shape as x. """ with tf.variable_scope(name, default_name="layer_norm"): if subtract_mean: x -= mtf.reduce_mean(x, reduced_dim=dim) variance = mtf.reduce_mean(mtf.square(x), reduced_dim=dim) x *= mtf.rsqrt(variance + epsilon) if use_scale: x *= mtf.get_variable( x.mesh, "scale", mtf.Shape([dim]), initializer=tf.ones_initializer(), activation_dtype=x.dtype) if use_bias: x += mtf.get_variable( x.mesh, "bias", mtf.Shape([dim]), initializer=tf.zeros_initializer(), activation_dtype=x.dtype) return x