# Lint as: python3 # Copyright 2020 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for batch_major_attention.""" from absl.testing import parameterized from lingvo import compat as tf from lingvo.core import attention as tm_attention from lingvo.core import base_layer from lingvo.core import batch_major_attention as attention from lingvo.core import hyperparams from lingvo.core import py_utils from lingvo.core import test_utils import numpy as np from six.moves import range from six.moves import zip class MultiHeadSelfAttentionTest(test_utils.TestCase, parameterized.TestCase): """Test attention models.""" def _AttentionInputs(self, input_dim=4, dtype=tf.float32): np.random.seed(6348575) batch_size = 6 seq_len = 6 input_vecs_p = [ np.random.rand(seq_len, input_dim) for _ in range(batch_size) ] input_vecs = tf.stack([tf.constant(x, dtype=dtype) for x in input_vecs_p]) # pyformat: disable input_padding_p = [[0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 1, 0], [0, 0, 1, 1, 0, 0], [1, 0, 0, 0, 1, 0], [0, 0, 1, 0, 1, 0]] # pyformat: enable input_padding = tf.constant(input_padding_p, dtype=dtype) return input_vecs, input_padding, input_vecs_p, input_padding_p def testDotProductAttention(self): (input_vecs, input_padding, input_vecs_p, input_padding_p) = self._AttentionInputs() p = attention.MultiHeadedAttention.Params().Set( name='self_atten', input_dim=4, hidden_dim=4) l = p.Instantiate() probs = l.AttenProbs( l.theta, tf.expand_dims(input_vecs, 2), tf.expand_dims(input_vecs, 2), input_padding, segment_mask=None) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() prob_out = sess.run(tf.squeeze(probs)) # Use numpy to perform the same computation to generate expected results. input_vecs_p = np.array(input_vecs_p) target_vecs_p = np.transpose(input_vecs_p, (0, 2, 1)) expected_logit = np.matmul(input_vecs_p, target_vecs_p) expected_logit = np.transpose(expected_logit, (0, 2, 1)) elexp = np.exp(expected_logit) input_padding_p = np.array(input_padding_p) input_padding_p = np.expand_dims(input_padding_p, axis=1) input_padding_p = np.tile(input_padding_p, (1, 6, 1)) elexp *= (1 - input_padding_p) expected_prob_out = elexp / np.expand_dims(np.sum(elexp, axis=-1), axis=-1) expected_prob_out = np.reshape(expected_prob_out, (6, 6, 6)) self.assertAllClose(expected_prob_out, prob_out) def testMultiHeadedAttentionDotProduct(self): # input_batch:6, seq_len:6. Test n = 2 case. with self.session(use_gpu=True) as sess: input_vecs, input_padding, _, _ = self._AttentionInputs() p = attention.MultiHeadedAttention.Params().Set( name='self_atten', num_heads=2, input_dim=4, hidden_dim=4) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, input_vecs, input_vecs, input_vecs, input_padding, segment_mask=None) context_vec_out = sess.run(ctx_vec) context_vec_out = np.reshape(context_vec_out, (6, 24)) self.assertAllClose( [27.417763, 31.783672, 19.99568, 23.907103, 21.078259, 28.429199], np.sum(context_vec_out, axis=1)) def testMultiHeadedAttentionDotProductSegmentMask(self): # input_batch:6, seq_len:6. Test n = 2 case. with self.session(use_gpu=True) as sess: input_vecs, input_padding, _, _ = self._AttentionInputs() p = attention.MultiHeadedAttention.Params().Set( name='self_atten', num_heads=2, input_dim=4, hidden_dim=4, packed_input=True) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) segment_id = tf.zeros([6, 6]) segment_mask = attention.SegmentMask(segment_id, segment_id) padding = tf.tile(tf.reshape(input_padding, [6, 1, 1, 6]), [1, 1, 6, 1]) padding_mask = padding * segment_mask.dtype.max * tf.constant( -0.7, dtype=segment_mask.dtype) segment_mask += padding_mask l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, input_vecs, input_vecs, input_vecs, input_padding, segment_mask=segment_mask) context_vec_out = sess.run(ctx_vec) context_vec_out = np.reshape(context_vec_out, (6, 24)) self.assertAllClose( [27.417763, 31.783672, 19.99568, 23.907103, 21.078259, 28.429199], np.sum(context_vec_out, axis=1)) class MultiHeadedAttentionXLOracle(object): """Oracle layer used for computing ground truths for MultiHeadedAttention. Written in a non-vectorized way. """ def __init__(self, u, v, pos_proj, sinusoid_emb): """Constructor. Args: u: A numpy ndarray of shape [N, H] v: A numpy ndarray of shape [N, H] pos_proj: A numpy ndarray of shape [embed_dim, N, H] sinusoid_emb: A numpy ndarray of shape [seqlen, emb_dim]. """ assert u.shape == v.shape assert u.shape == pos_proj.shape[1:] assert sinusoid_emb.shape[-1] == pos_proj.shape[0] # [N, H] self._u = u # [N, H] self._v = v # [?, N, H] self._pos_proj = pos_proj self._num_heads = u.shape[0] self._atten_dim = u.shape[-1] self._hidden_dim = u.shape[0] * u.shape[-1] self._sinusoid_emb = sinusoid_emb def _GetPositionEnc(self, tgt_t, src_t, head, seqlen): """Gets positional encoding. Args: tgt_t: A Python int, time step of target seq. src_t: A Python int, time step of source seq. head: A Python int, num of heads of the attention. seqlen: A Python int, sequence length of target/source seq. Returns: A numpy array of shape [head, emb_dim // head]. """ # [emb_dim] sinusoid_enc = self._sinusoid_emb[tgt_t - src_t + seqlen - 1] return np.einsum('DNH,D->NH', self._pos_proj, sinusoid_enc)[head] def AttenProbs(self, key, query, paddings, per_step_padding): """Computes attention probs in a non vectorized way. Args: key: A numpy ndarray of shape [batch, seqlen, heads, dim]. query: A numpy ndarray of the same shape as `key`. paddings: A numpy ndarray of shape [batch, seqlen]. per_step_padding: A numpy ndarray of shape [batch, seqlen, seqlen]. Returns: A numpy ndarray of shape [batch, query_seqlen, key_seqlen] """ assert query.ndim == 4 assert paddings.ndim == 2 assert key.shape == query.shape batch, seqlen = query.shape[:2] tgtlen, srclen = seqlen, seqlen assert query.shape[2] == self._num_heads assert query.shape[3] == self._atten_dim assert paddings.shape == query.shape[:2] logits = np.zeros((batch, self._num_heads, tgtlen, srclen)) probs = np.zeros((batch, self._num_heads, tgtlen, srclen)) def Normalize(vec): expx = np.exp(vec) expxsum = np.sum(expx, axis=-1) return expx / expxsum # [b, tgtlen, srclen] paddings = np.broadcast_to( np.reshape(paddings, (batch, 1, seqlen)), (batch, seqlen, seqlen)) for b in range(batch): for h in range(self._num_heads): for i in range(tgtlen): for j in range(srclen): pos_enc = self._GetPositionEnc(i, j, h, seqlen) logits[b][h][i][j] = ( np.dot(query[b][i][h], key[b][j][h]) + np.dot(query[b][i][h], pos_enc) + np.dot(self._u[h], key[b][j][h]) + np.dot(self._v[h], pos_enc)) total_padding = paddings[b][i] + per_step_padding[b][i] logits[b][h][i] = np.where(total_padding > 0, np.finfo(np.float32).max * (-0.7), logits[b][h][i]) probs[b][h][i] = Normalize(logits[b][h][i]) return probs def _AttentionInputs(input_dim=4, dtype=tf.float32, is_causal=True): np.random.seed(6348575) batch_size = 6 seq_len = 6 query_vec_p = [np.random.rand(seq_len, input_dim) for _ in range(batch_size)] query_vec_p = np.array(query_vec_p).astype(dtype.as_numpy_dtype) query_vec = tf.convert_to_tensor(query_vec_p) memory_vec_p = [np.random.rand(seq_len, input_dim) for _ in range(batch_size)] memory_vec_p = np.array(memory_vec_p).astype(dtype.as_numpy_dtype) memory_vec = tf.convert_to_tensor(memory_vec_p) # pyformat: disable paddings_p = np.array( [[0, 0, 1, 1, 1, 1], [0, 1, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1], [0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 0, 1]]).astype(dtype.as_numpy_dtype) paddings = tf.convert_to_tensor(paddings_p) # causal padding. if is_causal: per_step_padding_p = [ [0, 1, 1, 1, 1, 1], [0, 0, 1, 1, 1, 1], [0, 0, 0, 1, 1, 1], [0, 0, 0, 0, 1, 1], [0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0]] else: per_step_padding_p = np.zeros((seq_len, seq_len)) per_step_padding_p = [per_step_padding_p for _ in range(batch_size)] per_step_padding_p = np.array(per_step_padding_p).astype(dtype.as_numpy_dtype) per_step_padding = tf.convert_to_tensor(per_step_padding_p) # pyformat: enable return (query_vec, memory_vec, paddings, per_step_padding, query_vec_p, memory_vec_p, paddings_p, per_step_padding_p) class MultiHeadedAttentionTest(test_utils.TestCase, parameterized.TestCase): """Test dot-product multiheaded attention.""" def _AttentionExtendStepInputs(self, input_dim=4, num_heads=2, dtype=tf.float32): np.random.seed(6348575) batch_size = 6 seq_len = 6 query_vec_p = [np.random.rand(1, input_dim) for _ in range(batch_size)] query_vec = tf.stack([tf.constant(x, dtype=dtype) for x in query_vec_p]) # pyformat: disable per_step_padding_p = [[0, 1, 1, 1, 1, 1]] per_step_padding_p = [per_step_padding_p for _ in range(batch_size)] # pyformat: enable per_step_padding = tf.stack( [tf.constant(x, dtype=dtype) for x in per_step_padding_p]) source_vecs = tf.zeros( [seq_len, batch_size, num_heads, input_dim // num_heads], dtype=dtype) source_ctxs = tf.zeros( [seq_len, batch_size, num_heads, input_dim // num_heads], dtype=dtype) return source_vecs, source_ctxs, query_vec, per_step_padding def testAttenProbs(self): (query_vec, key_vec, paddings, per_step_padding, query_vec_p, key_vec_p, paddings_p, per_step_padding_p) = _AttentionInputs() p = attention.MultiHeadedAttention.Params().Set( name='atten', input_dim=4, hidden_dim=4) l = p.Instantiate() probs = l.AttenProbs( l.theta, tf.expand_dims(query_vec, 2), tf.expand_dims(key_vec, 2), paddings, segment_mask=None, per_step_padding=per_step_padding) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() prob_out = sess.run(tf.squeeze(probs)) # Use numpy to perform the same computation to generate expected results. query_vec_p = np.array(query_vec_p) key_vec_p = np.array(key_vec_p) key_vec_p = np.transpose(key_vec_p, (0, 2, 1)) expected_logit = np.matmul(query_vec_p, key_vec_p) paddings_p = np.array(paddings_p) paddings_p = np.expand_dims(paddings_p, axis=1) paddings_p = np.tile(paddings_p, (1, 6, 1)) per_step_padding_p = np.array(per_step_padding_p) paddings_p = 1.0 * np.logical_or(paddings_p, per_step_padding_p) elexp = np.exp(expected_logit) elexp *= (1.0 - paddings_p) elexp += 1e-9 expected_prob_out = elexp / np.expand_dims(np.sum(elexp, axis=-1), axis=-1) expected_prob_out = np.reshape(expected_prob_out, (6, 6, 6)) self.assertAllClose(expected_prob_out, prob_out) def testFPropCrossAttention(self): # input_batch:6, seq_len:6. Test n = 2 case. with self.session(use_gpu=True) as sess: query_vec, memory_vec, paddings, per_step_padding, _, _, _, _ = ( _AttentionInputs()) p = attention.MultiHeadedAttention.Params().Set( name='cross_atten', num_heads=2, input_dim=4, hidden_dim=4) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, query_vec, memory_vec, memory_vec, paddings, segment_mask=None, per_step_padding=per_step_padding) context_vec_out = sess.run(ctx_vec) context_vec_out = np.reshape(context_vec_out, (6, 24)) self.assertAllClose( [24.624561, 27.805634, 23.358835, 11.085404, 27.165989, 23.750813], np.sum(context_vec_out, axis=1)) @parameterized.named_parameters( { 'testcase_name': '_short_seq', 'use_short_seq_opt': True, }, { 'testcase_name': '_long_seq', 'use_short_seq_opt': False, }) def testExtendStepSelfAttention(self, use_short_seq_opt): # input_batch:6, seq_len:6, query_len: 1. Test n = 2 case. with self.session(use_gpu=True) as sess: source_vecs, source_ctxs, query_vec, per_step_padding = ( self._AttentionExtendStepInputs()) p = attention.MultiHeadedAttention.Params().Set( name='atten', num_heads=2, input_dim=4, hidden_dim=4) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, new_src_vecs, _ = l.ExtendStep(l.theta, query_vec, source_vecs, source_ctxs, None, None, per_step_padding, 0, use_short_seq_opt) context_vec_out = sess.run(ctx_vec) new_source_vecs = sess.run(new_src_vecs) context_vec_out = np.reshape(context_vec_out, (6, 4)) self.assertAllClose( [5.381485, 5.384035, 4.493689, 3.544395, 3.424472, 3.311054], np.sum(context_vec_out, axis=1)) new_source_vecs = np.reshape(new_source_vecs, (6, 24)) self.assertAllClose([4.116683, 0.0, 0.0, 0.0, 0.0, 0.0], np.sum(new_source_vecs, axis=1)) class MultiSourceMultiHeadedAttentionTest(MultiHeadedAttentionTest): def testAttenProbs(self): (query_vec, key_vec, paddings, per_step_padding, query_vec_p, key_vec_p, paddings_p, per_step_padding_p) = _AttentionInputs() # Two-source attention. mha_params = attention.MultiHeadedAttention.Params().Set( name='atten', input_dim=4, hidden_dim=4) atten_merger_p = tm_attention.MergerLayer.Params().Set( params_init=py_utils.WeightInit.Uniform(0.04), merger_op='concat', # concatenate attention pre_proj_input_dims=[4, 4], pre_proj_output_dims=[4, 4]) params = attention.MultiSourceAttention.Params().Set( name='two_source_atten', input_dim=4, hidden_dim=4, source_atten_tpls=[('src_1', mha_params), ('src_2', mha_params.Copy().Set(name='atten2'))], primary_source_key='src_1', atten_merger_tpl=atten_merger_p) l = params.Instantiate() probs = l.AttenProbs( l.theta, tf.expand_dims(query_vec, 2), py_utils.NestedMap({ 'src_1': tf.expand_dims(key_vec, 2), 'src_2': tf.expand_dims(key_vec, 2) }), py_utils.NestedMap({ 'src_1': paddings, 'src_2': paddings }), segment_mask=None, per_step_padding=per_step_padding) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() prob_out = sess.run(tf.squeeze(probs)) # Use numpy to perform the same computation to generate expected results. query_vec_p = np.array(query_vec_p) key_vec_p = np.array(key_vec_p) key_vec_p = np.transpose(key_vec_p, (0, 2, 1)) expected_logit = np.matmul(query_vec_p, key_vec_p) paddings_p = np.array(paddings_p) paddings_p = np.expand_dims(paddings_p, axis=1) paddings_p = np.tile(paddings_p, (1, 6, 1)) per_step_padding_p = np.array(per_step_padding_p) paddings_p = 1.0 * np.logical_or(paddings_p, per_step_padding_p) elexp = np.exp(expected_logit) elexp *= (1.0 - paddings_p) elexp += 1e-9 expected_prob_out = elexp / np.expand_dims(np.sum(elexp, axis=-1), axis=-1) expected_prob_out = np.reshape(expected_prob_out, (6, 6, 6)) self.assertAllClose(expected_prob_out, prob_out) def testFPropCrossAttention(self): # input_batch:6, seq_len:6. Test n = 2 case. with self.session(use_gpu=True) as sess: query_vec, memory_vec, paddings, per_step_padding, _, _, _, _ = ( _AttentionInputs()) mha_params = attention.MultiHeadedAttention.Params().Set( name='cross_atten', num_heads=2, input_dim=4, hidden_dim=4) mha_params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) atten_merger_p = tm_attention.MergerLayer.Params().Set( params_init=py_utils.WeightInit.Uniform(0.04), merger_op='concat', # concatenate attention pre_proj_input_dims=[4, 4], pre_proj_output_dims=[4, 4]) # Two-source attention. p = attention.MultiSourceAttention.Params().Set( name='two_source_atten', input_dim=4, hidden_dim=4, source_atten_tpls=[('src_1', mha_params), ('src_2', mha_params.Copy().Set(name='atten2'))], primary_source_key='src_1', atten_merger_tpl=atten_merger_p) l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, query_vec, py_utils.NestedMap({ 'src_1': memory_vec, 'src_2': memory_vec }), py_utils.NestedMap({ 'src_1': memory_vec, 'src_2': memory_vec }), py_utils.NestedMap({ 'src_1': paddings, 'src_2': paddings }), segment_mask=None, per_step_padding=per_step_padding) context_vec_out = sess.run(ctx_vec) context_vec_out = np.reshape(context_vec_out, (12, 24)) self.assertAllClose([ 5.6162043, 5.0109887, 6.0565553, 6.0565553, 4.5718207, 5.253615, 2.0541124, 2.490314, 6.049119, 5.5567484, 4.409875, 5.8939424 ], np.sum(context_vec_out, axis=1)) class MultiHeadedAttentionXLTest(test_utils.TestCase, parameterized.TestCase): """Test dot-product multiheaded attention.""" def _AttentionExtendStepInputs(self, input_dim, batch_size, seq_len, dtype=tf.float32): np.random.seed(6348575) query_vec_p = [ np.random.rand(seq_len, input_dim) for _ in range(batch_size) ] query_vec = tf.stack([tf.constant(x, dtype=dtype) for x in query_vec_p]) paddings_p = [[0] * seq_len] * batch_size paddings = tf.constant(paddings_p, dtype=dtype) return query_vec, paddings @parameterized.named_parameters(('OneHead', 1), ('OneHeadCausal', 1, True), ('MultiHead', 2), ('MultiHeadCausal', 2, True)) def testAttenProbs(self, num_heads, is_causal=False): batch, slen = 6, 6 atten_dim = 4 input_dim = num_heads * atten_dim (input_vecs, _, input_padding, per_step_padding, input_vecs_p, _, input_padding_p, per_step_padding_p) = _AttentionInputs( input_dim=input_dim, is_causal=is_causal) p = attention.MultiHeadedAttentionXL.Params().Set( name='self_atten', input_dim=input_dim, num_heads=num_heads, hidden_dim=input_dim, rel_pos_emb_dim=input_dim) l = p.Instantiate() query = tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim)) probs = l.AttenProbs( l.theta, query, query, input_padding, segment_mask=None, per_step_padding=per_step_padding) # [1, 2 * slen - 1] positions = np.expand_dims(np.arange(-(slen - 1), slen), 0) sinusoid_emb = l.pos_emb.FPropWithPosition(l.theta.pos_emb, tf.convert_to_tensor(positions)) # [ 2 * slen - 1, emb_dim=input_dim] sinusoid_emb = tf.squeeze(sinusoid_emb, 0) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() u, v, pos_proj = sess.run([l.vars.u, l.vars.v, l.pos_proj.vars.w]) actual_probs = sess.run(probs) sinusoid_emb_p = sess.run(sinusoid_emb) # Compute ground truth with oracle class. # Use numpy to perform the same computation to generate expected results. # [B, tgt_t, H] input_vecs_p = np.array(input_vecs_p) # [B, tgt_t, N, H] input_vecs_p = np.reshape(input_vecs_p, (batch, slen, num_heads, atten_dim)) input_padding_p = np.array(input_padding_p) oracle = MultiHeadedAttentionXLOracle(u, v, pos_proj, sinusoid_emb_p) expected_probs = oracle.AttenProbs(input_vecs_p, input_vecs_p, input_padding_p, per_step_padding_p) self.assertAllClose(expected_probs, actual_probs) def testFPropSelfAttention(self): # input_batch:6, seq_len:6. Test n = 2 case. with self.session(use_gpu=True) as sess: query_vec, _, paddings, _, _, _, _, _ = _AttentionInputs() num_heads, input_dim, hidden_dim = 2, 4, 4 p = attention.MultiHeadedAttentionXL.Params().Set( name='self_atten', num_heads=num_heads, input_dim=input_dim, hidden_dim=hidden_dim, rel_pos_emb_dim=num_heads * hidden_dim) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec, _ = l.FPropDefaultTheta( query_vec, query_vec, query_vec, paddings, segment_mask=None) tf.global_variables_initializer().run() context_vec_out = sess.run(ctx_vec) context_vec_out = np.reshape(context_vec_out, (6, 24)) self.assertAllClose( [32.33513, 28.584404, 20.54517, 23.407812, 18.616188, 24.212755], np.sum(context_vec_out, axis=1)) def testExtendStepSelfAttention(self): num_heads, input_dim, hidden_dim, batch, seqlen = 2, 4, 4, 6, 6 emb_dim = 4 with self.session(use_gpu=True): tf.random.set_seed(12345) query_vec, paddings = self._AttentionExtendStepInputs( input_dim, batch, seqlen) p = attention.MultiHeadedAttentionXL.Params().Set( name='atten', num_heads=num_heads, input_dim=input_dim, hidden_dim=hidden_dim, rel_pos_emb_dim=emb_dim, random_seed=0) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() tf.global_variables_initializer().run() # Verify ExtendStep() via compare N ExtendStep() with one FProp() call on # a seq with length N. per_step_padding = 1 - tf.linalg.band_part( tf.ones((seqlen, seqlen)), -1, 0) per_step_padding = tf.stack([per_step_padding] * batch) expected_ctx_vec, _ = l.FPropDefaultTheta( query_vec, query_vec, query_vec, paddings, segment_mask=None, per_step_padding=per_step_padding) dims_per_head = hidden_dim // num_heads cached_source_vecs = tf.zeros([seqlen, batch, num_heads, dims_per_head]) cached_source_ctxs = tf.zeros([seqlen, batch, num_heads, dims_per_head]) encoded_all = [] for i in range(seqlen): per_step_paddings = 1. - tf.cast( tf.sequence_mask([i + 1] * batch, seqlen), tf.float32) per_step_paddings = tf.expand_dims(per_step_paddings, 1) encoded, cached_source_vecs, cached_source_ctxs = l.ExtendStep( l.theta, query_vec[:, i:i + 1, :], cached_source_vecs, cached_source_ctxs, paddings, None, per_step_paddings, i) # [batch, 1, dims_per_head] encoded_all.append(encoded) # [batch, T, dims_per_head] actual_ctx_vec = tf.concat(encoded_all, axis=1) self.assertAllClose(expected_ctx_vec.eval(), actual_ctx_vec.eval()) class MultiHeadedAttentionRPEOracle(object): """Computes ground truths for MultiHeadedfAttentionRPE. Written in a non-vectorized way. """ def __init__(self, num_heads, key_embs, value_embs): """Constructor. Args: num_heads: A Python int. key_embs: A numpy array of shape [2 * radius + 1, hidden_dim] value_embs: A numpy array of shape [2 * radius + 1, hidden_dim] """ assert key_embs.shape == value_embs.shape self._num_heads = num_heads self._hidden_dim = key_embs.shape[-1] self._atten_dim = self._hidden_dim // self._num_heads assert self._atten_dim * self._num_heads == self._hidden_dim self._key_embs = np.reshape( key_embs, [key_embs.shape[0], self._num_heads, self._atten_dim]) self._value_embs = np.reshape( value_embs, [value_embs.shape[0], self._num_heads, self._atten_dim]) self._radius = key_embs.shape[0] // 2 def _GetEmb(self, tgt_t, src_t, head, emb_wt): radius = self._radius distance = np.clip(src_t - tgt_t, -radius, radius) return emb_wt[distance][head] def GetKeyEmb(self, tgt_t, src_t, head): return self._GetEmb(tgt_t, src_t, head, self._key_embs) def GetValueEmb(self, tgt_t, src_t, head): return self._GetEmb(tgt_t, src_t, head, self._value_embs) def AttenProbs(self, key, query, paddings): assert query.ndim == 4 assert paddings.ndim == 2 assert key.shape == query.shape batch, seqlen = query.shape[:2] tgtlen, srclen = seqlen, seqlen assert query.shape[2] == self._num_heads assert query.shape[3] == self._atten_dim assert paddings.shape == query.shape[:2] # [B, N, T, T] logits = np.zeros((batch, self._num_heads, tgtlen, srclen)) # [B, N, T, T] probs = np.zeros((batch, self._num_heads, tgtlen, srclen)) paddings = np.broadcast_to( np.reshape(paddings, (batch, 1, 1, seqlen)), (batch, self._num_heads, seqlen, seqlen)) def Normalize(vec): expx = np.exp(vec) expxsum = np.sum(expx, axis=-1) return expx / expxsum for b in range(batch): for h in range(self._num_heads): for i in range(tgtlen): for j in range(srclen): logits[b][h][i][j] = np.dot(query[b][i][h], key[b][j][h] + self.GetKeyEmb(i, j, h)) logits[b][h][i] = np.where(paddings[b][h][i] > 0, np.finfo(np.float32).max * (-0.7), logits[b][h][i]) probs[b][h][i] = Normalize(logits[b][h][i]) return probs def AttenContext(self, probs, values): assert probs.ndim == 4 assert values.ndim == 4 assert probs.shape[0] == values.shape[0] # batch assert probs.shape[1] == values.shape[2] # head assert probs.shape[2] == values.shape[1] # tgtlen assert probs.shape[3] == probs.shape[2] # slen assert values.shape[-1] == self._atten_dim batch, _, tgtlen, srclen = probs.shape # [B, N, T, H] ctx = np.zeros((batch, self._num_heads, tgtlen, self._atten_dim)) for b in range(batch): for h in range(self._num_heads): for i in range(tgtlen): for j in range(srclen): ctx[b][h][i] += probs[b][h][i][j] * ( values[b][j][h] + self.GetValueEmb(i, j, h)) # [B, T, N, H] return np.transpose(ctx, (0, 2, 1, 3)) class MultiHeadedAttentionRPETest(test_utils.TestCase, parameterized.TestCase): @parameterized.named_parameters(('OneHead', 1), ('MultiHead', 2)) def testAttenProbs(self, num_heads): batch, slen = 6, 6 atten_dim = 4 radius = 3 input_dim = num_heads * atten_dim (input_vecs, _, input_padding, _, input_vecs_p, _, input_padding_p, _) = _AttentionInputs(input_dim=input_dim) p = attention.MultiHeadedAttentionRPE.Params().Set( name='self_atten', input_dim=input_dim, num_heads=num_heads, hidden_dim=input_dim, rel_pos_radius=radius) l = p.Instantiate() query = tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim)) probs = l.AttenProbs( l.theta, query, query, input_padding, segment_mask=None) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() # [radius * 2 + 1, hidden_dim], [B, tgt_t, src_t] key_emb, value_emb, actual_probs = sess.run( [l.key_emb.vars.w, l.value_emb.vars.w, probs]) oracle = MultiHeadedAttentionRPEOracle(num_heads, key_emb, value_emb) # Use numpy to perform the same computation to generate expected results. # [B, tgt_t, N, H] input_vecs_p = np.reshape(input_vecs_p, (batch, slen, num_heads, atten_dim)) expected_probs = oracle.AttenProbs(input_vecs_p, input_vecs_p, input_padding_p) self.assertAllClose(expected_probs, actual_probs) @parameterized.named_parameters(('OneHead', 1), ('MultiHead', 2)) def testAttenContext(self, num_heads): batch, slen = 6, 6 atten_dim = 4 radius = 3 input_dim = num_heads * atten_dim (input_vecs, _, _, _, input_vecs_p, _, _, _) = _AttentionInputs(input_dim=input_dim) p = attention.MultiHeadedAttentionRPE.Params().Set( name='self_atten', input_dim=input_dim, num_heads=num_heads, hidden_dim=input_dim, rel_pos_radius=radius) l = p.Instantiate() probs = np.random.rand(batch, num_heads, slen, slen).astype(np.float32) probs = np.exp(probs) / np.sum(np.exp(probs), axis=-1, keepdims=True) ctx = l._AttenContext( l.theta, tf.convert_to_tensor(probs), tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim))) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() key_emb, value_emb, actual_ctx = sess.run( [l.key_emb.vars.w, l.value_emb.vars.w, ctx]) oracle = MultiHeadedAttentionRPEOracle(num_heads, key_emb, value_emb) # [B, tgt_t, N, H] input_vecs_p = np.reshape(input_vecs_p, (batch, slen, num_heads, atten_dim)) expected_ctx = oracle.AttenContext(probs, input_vecs_p) self.assertAllClose(expected_ctx, actual_ctx) @parameterized.named_parameters(('OneHead', 1), ('MultiHead', 2)) def testAttenLogitsOneStep(self, num_heads): batch, slen = 6, 6 atten_dim = 4 radius = 3 input_dim = num_heads * atten_dim (input_vecs, _, _, per_step_padding, _, _, _, _) = _AttentionInputs( input_dim=input_dim, is_causal=True) p = attention.MultiHeadedAttentionRPE.Params().Set( name='self_atten', input_dim=input_dim, num_heads=num_heads, hidden_dim=input_dim, rel_pos_radius=radius) l = p.Instantiate() # [B, T, N, H] query = tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim)) # Causal self attention. # [B, N, T, S] logits = l._AttenLogits(l.theta, query, query, per_step_padding) one_step_logits = [] # [S=T, B, N, H] key = tf.transpose(query, [1, 0, 2, 3]) for i in range(slen): local_logits = l._AttenLogitsOneStep(l.theta, query[:, i, :, :], key, i) one_step_logits.append(local_logits) # [T, S, B, N] stacked_logits = tf.stack(one_step_logits) stacked_logits = tf.transpose(stacked_logits, [2, 3, 0, 1]) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() expected_logits, actual_logits = sess.run([logits, stacked_logits]) self.assertAllClose(expected_logits, actual_logits) @parameterized.named_parameters(('OneHead', 1), ('MultiHead', 2)) def testAttenContextsOneStep(self, num_heads): batch, slen = 6, 6 atten_dim = 4 radius = 3 input_dim = num_heads * atten_dim (input_vecs, _, _, per_step_padding, _, _, _, _) = _AttentionInputs( input_dim=input_dim, is_causal=True) p = attention.MultiHeadedAttentionRPE.Params().Set( name='self_atten', input_dim=input_dim, num_heads=num_heads, hidden_dim=input_dim, rel_pos_radius=radius) l = p.Instantiate() # [B, N, T, S=T] # Make causal attention probs. probs = np.random.rand(batch, num_heads, slen, slen).astype(np.float32) per_step_padding = 1 - np.tril(np.ones((slen, slen))).astype(np.float32) probs *= per_step_padding # Normalize probs = np.exp(probs) / np.sum(np.exp(probs), axis=-1, keepdims=True) # Causal self attention. # [B, N, T, S] ctx = l._AttenContext( l.theta, tf.convert_to_tensor(probs), tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim))) one_step_ctx = [] # [B, T, N, H] -> [S=T, B, N, H] value = tf.reshape(input_vecs, (batch, slen, num_heads, atten_dim)) value = tf.transpose(value, [1, 0, 2, 3]) for i in range(slen): # [B, N, S] local_prob = probs[:, :, i, :] # [S, B, N] local_prob = tf.transpose(local_prob, [2, 0, 1]) # [B, N, H] local_ctx = l._AttenContextOneStep(l.theta, local_prob, value, i) one_step_ctx.append(local_ctx) # [T, B, N, H] stacked_ctx = tf.stack(one_step_ctx) stacked_ctx = tf.transpose(stacked_ctx, [1, 0, 2, 3]) with self.session(use_gpu=False) as sess: tf.global_variables_initializer().run() expected_ctx, actual_ctx = sess.run([ctx, stacked_ctx]) self.assertAllClose(expected_ctx, actual_ctx) class LocalCausalSelfAttentionTest(test_utils.TestCase, parameterized.TestCase): """Test local causual self attention.""" def _LocalCasualPadding(self, b, t, l, r): padding = np.ones((b, t, t)) for i in range(t): padding[:, i, max(0, i - l + 1):i + r + 1] = 0 return tf.constant(padding, dtype=tf.float32) @parameterized.named_parameters( { 'testcase_name': 'block_size_unspecified', 'block_size': None, 'left_context': 4, 'right_context': 1 }, { 'testcase_name': 'left_context_only', 'block_size': 3, 'left_context': 4, 'right_context': 0, }, { 'testcase_name': 'block_longer_than_sequence', 'block_size': 10, 'left_context': 7, 'right_context': 0, }, { 'testcase_name': 'pos_emb_left_context_only', 'block_size': 3, 'left_context': 4, 'right_context': 0, 'pos_emb_dim': 8, }, { 'testcase_name': 'pos_emb_left_and_right_context', 'block_size': 3, 'left_context': 4, 'right_context': 2, 'pos_emb_dim': 8, }, { 'testcase_name': 'lite_pos_emb_left_and_right_context', 'block_size': 3, 'left_context': 4, 'right_context': 2, 'pos_emb_dim': 8, 'skip_term_b': True, }) def testFPropAgainstReference(self, block_size, left_context, right_context, pos_emb_dim=0, num_heads=2, input_dim=4, hidden_dim=4, skip_term_b=False, use_additional_per_step_padding=False): tf.reset_default_graph() with self.session(use_gpu=True) as sess: query_vec, _, paddings, _, _, _, _, _ = _AttentionInputs(input_dim) if use_additional_per_step_padding: # Generate a random binary mask of shape [N, T, S]. additional_per_step_padding_val = np.random.random_integers( low=0, high=1, size=(6, 6, 6)) additional_per_step_padding = tf.constant( additional_per_step_padding_val, tf.float32) else: additional_per_step_padding = None # Use the reference implementation + local casual padding to verify # correctness. if pos_emb_dim == 0: p_cls = attention.LocalCausalSelfAttention expected_p_cls = attention.MultiHeadedAttention else: p_cls = attention.LocalCausalSelfAttentionXL expected_p_cls = attention.MultiHeadedAttentionXL p = p_cls.Params().Set( name='self_atten', num_heads=num_heads, input_dim=input_dim, hidden_dim=hidden_dim, block_size=block_size, left_context=left_context, right_context=right_context) expected_p = expected_p_cls.Params().Set( name='expected_self_atten', num_heads=num_heads, input_dim=input_dim, hidden_dim=hidden_dim) if pos_emb_dim != 0: p.rel_pos_emb_dim = pos_emb_dim expected_p.rel_pos_emb_dim = pos_emb_dim p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) expected_p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() expected_l = expected_p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, query_vec, query_vec, query_vec, paddings, segment_mask=None, per_step_padding=additional_per_step_padding) context_vec_out = sess.run(ctx_vec) per_step_padding = self._LocalCasualPadding(6, 6, left_context, right_context) if additional_per_step_padding is not None: per_step_padding += additional_per_step_padding expected_ctx_vec, _ = expected_l.FProp(expected_l.theta, query_vec, query_vec, query_vec, paddings, None, per_step_padding) expected_context_vec_out = sess.run(expected_ctx_vec) # Don't compare if the query position is padded, or if all key positions # are padded. paddings_val = sess.run(paddings) per_step_padding_val = sess.run(per_step_padding) per_step_padding_val += paddings_val[:, :, np.newaxis] per_step_padding_val += paddings_val[:, np.newaxis, :] dont_compare = np.sum( per_step_padding_val > 0, axis=-1) == per_step_padding_val.shape[-1] expected_context_vec_out *= (1 - dont_compare)[..., np.newaxis] context_vec_out *= (1 - dont_compare)[..., np.newaxis] self.assertAllClose(context_vec_out, expected_context_vec_out) def testFPropWithDropout(self): with self.session(use_gpu=True) as sess: query_vec, _, paddings, _, _, _, _, _ = _AttentionInputs(input_dim=4) p = attention.LocalCausalSelfAttention.Params().Set( name='self_atten', num_heads=2, input_dim=4, hidden_dim=4, block_size=2, left_context=2, right_context=0, atten_dropout_prob=0.3, ) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() tf.global_variables_initializer().run() ctx_vec, _ = l.FProp( l.theta, query_vec, query_vec, query_vec, paddings, segment_mask=None) ctx_vec_val = sess.run(ctx_vec) print(ctx_vec_val) class TransformerLayerTest(test_utils.TestCase, parameterized.TestCase): """Test Transformer decoder layers.""" def _TransformerAttentionLayerInputs(self, input_dim=4, dtype=tf.float32): np.random.seed(6348575) query_vec = tf.transpose( tf.stack([ tf.constant(np.random.rand(2, input_dim), dtype=dtype) for _ in range(5) ]), [1, 0, 2]) paddings = tf.constant([[0, 0, 1, 1, 0], [1, 0, 0, 0, 1]], dtype=dtype) aux_vec = tf.transpose( tf.stack([ tf.constant(np.random.rand(2, input_dim), dtype=dtype) for _ in range(7) ]), [1, 0, 2]) aux_paddings = tf.constant([[0, 1, 0, 1, 0, 1, 0], [1, 0, 1, 0, 1, 0, 1]], dtype=dtype) return query_vec, paddings, aux_vec, aux_paddings def testTransformerAttentionLayerFPropMaskedSelfAttention(self): with self.session(use_gpu=True) as sess: query_vec, paddings, _, _ = self._TransformerAttentionLayerInputs() p = attention.TransformerAttentionLayer.Params().Set( name='transformer_masked_self_atten', input_dim=4, is_masked=True, num_heads=2) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec, _ = l.FProp(l.theta, query_vec, None, paddings) tf.global_variables_initializer().run() actual_ctx = sess.run(ctx_vec) actual_ctx = np.reshape(actual_ctx, (10, 4)) tf.logging.info(np.array_repr(actual_ctx)) expected_ctx = [7.777687, 5.219166, 6.305151, 4.817311] self.assertAllClose(expected_ctx, np.sum(actual_ctx, axis=0)) def testAttentionLayerFPropMaskedSelfAttentionPaddingOverride(self): with self.session(use_gpu=True) as sess: query_vec, paddings, _, _ = self._TransformerAttentionLayerInputs() p = attention.TransformerAttentionLayer.Params().Set( name='transformer_masked_self_atten', input_dim=4, is_masked=True, num_heads=2) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() triangle_padding = 1.0 - tf.linalg.band_part( tf.ones([5, 5], dtype=query_vec.dtype), -1, 0) per_step_padding_override = tf.tile( tf.expand_dims(triangle_padding, 0), [2, 1, 1]) ctx_vec1, _ = l.FProp(l.theta, query_vec, None, paddings, per_step_padding_override) expected_ctx1, _ = l.FProp(l.theta, query_vec, None, paddings) per_step_padding_override = tf.zeros([2, 5, 5]) ctx_vec2, _ = l.FProp(l.theta, query_vec, None, paddings, per_step_padding_override) tf.global_variables_initializer().run() actual_ctx1, actual_ctx2, actual_expected_ctx1 = sess.run( [ctx_vec1, ctx_vec2, expected_ctx1]) tf.logging.info(np.array_repr(actual_ctx1)) tf.logging.info(np.array_repr(actual_ctx2)) expected_ctx2 = [7.9491496, 5.2976646, 6.5383415, 5.0169916] self.assertAllClose(actual_expected_ctx1, ctx_vec1) self.assertAllClose(expected_ctx2, np.sum(np.reshape(actual_ctx2, (10, 4)), axis=0)) def testTransformerAttentionLayerFPropCrossAttention(self): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() p = attention.TransformerAttentionLayer.Params().Set( name='transformer_cross_atten', input_dim=4, is_masked=False, num_heads=2) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec, _ = l.FProp(l.theta, query_vec, aux_vec, aux_paddings) tf.global_variables_initializer().run() actual_ctx = sess.run(ctx_vec) actual_ctx = np.reshape(actual_ctx, (10, 4)) tf.logging.info(np.array_repr(actual_ctx)) expected_ctx = [19.345360, 15.057412, 13.744134, 13.387347] self.assertAllClose(expected_ctx, np.sum(actual_ctx, axis=0)) def testMultiSourceTransformerAttentionLayerFPropCrossAttention(self): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() p = attention.TransformerMultiSourceAttentionLayer.Params().Set( name='transformer_multi_source_cross_atten', input_dim=4, is_masked=False, num_heads=2, num_source=2) p.multi_source_atten.atten_merger_tpl = ( tm_attention.MergerLayer.Params().Set(merger_op='sum')) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec, _ = l.FProp( l.theta, query_vec, py_utils.NestedMap({ 'source_0': aux_vec, 'source_1': aux_vec }), py_utils.NestedMap({ 'source_0': aux_paddings, 'source_1': aux_paddings })) tf.global_variables_initializer().run() actual_ctx = sess.run(ctx_vec) actual_ctx = np.reshape(actual_ctx, (10, 4)) tf.logging.info(np.array_repr(actual_ctx)) expected_ctx = [32.4878, 25.145725, 21.534966, 22.007454] self.assertAllClose(expected_ctx, np.sum(actual_ctx, axis=0)) @parameterized.named_parameters( { 'testcase_name': '_short_seq', 'use_short_seq_opt': True, }, { 'testcase_name': '_long_seq', 'use_short_seq_opt': False, }) def testTransformerAttentionLayerExtendStep(self, use_short_seq_opt): with self.session(use_gpu=True) as sess: query_vec, _, _, _ = self._TransformerAttentionLayerInputs() paddings = tf.zeros([2, 5]) cached_key = tf.zeros([5, 2, 2, 2]) cached_value = tf.zeros([5, 2, 2, 2]) prefix_states = py_utils.NestedMap(key=cached_key, value=cached_value) p = attention.TransformerAttentionLayer.Params().Set( name='transformer_atten', input_dim=4, is_masked=True, num_heads=2) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec1, _ = l.FProp(l.theta, query_vec, None, paddings) ctx_vec2 = [] for i in range(5): ctx_vec, prefix_states = l.ExtendStep( l.theta, tf.expand_dims(query_vec[:, i, :], 1), prefix_states, i, use_short_seq_opt) ctx_vec2.append(tf.squeeze(ctx_vec, 1)) ctx_vec2 = tf.transpose(tf.stack(ctx_vec2), [1, 0, 2]) tf.global_variables_initializer().run() ctx1, ctx2 = sess.run([ctx_vec1, ctx_vec2]) self.assertAllClose(ctx1, ctx2) def _ConstructTransformerDecoderLayer(self, use_relative_atten=False): p = attention.TransformerDecoderLayer.Params() p.name = 'transformer_decoder_layer' p.input_dim = 4 p.tr_fflayer_tpl.hidden_dim = 7 p.tr_atten_tpl.num_heads = 2 if use_relative_atten: p = attention.UseRelativeAttentionInTransformerLayer(p, 4) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) return attention.TransformerDecoderLayer(p) @parameterized.named_parameters(('SingleBatch', 1), ('DoubleBatch', 2)) def testTransformerLayerFPropWithCrossAttention(self, multiplier): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() query_vec = tf.tile(query_vec, [multiplier, 1, 1]) paddings = tf.zeros([2 * multiplier, 5]) p = attention.TransformerLayer.Params() p.name = 'transformer_layer' p.input_dim = 4 p.tr_fflayer_tpl.hidden_dim = 7 p.tr_atten_tpl.num_heads = 2 p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() ctx_vec, _ = l.FProp(l.theta, query_vec, paddings, aux_vec, aux_paddings) tf.global_variables_initializer().run() actual_ctx = sess.run(ctx_vec) actual_ctx = np.reshape(actual_ctx, (10 * multiplier, 4)) tf.logging.info(np.array_repr(actual_ctx)) expected_ctx = [ 4.7839108, 4.5303655, 5.5551023, 5.065767, 5.0493064, 3.2142467, 2.8200178, 5.659971, 4.3814187, 2.60475 ] * multiplier self.assertAllClose(expected_ctx, np.sum(actual_ctx, axis=1)) @parameterized.named_parameters(('SingleBatch', 1), ('DoubleBatch', 2)) def testMultiSourceTransformerLayerFPropWithCrossAttention(self, multiplier): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() query_vec = tf.tile(query_vec, [multiplier, 1, 1]) paddings = tf.zeros([2 * multiplier, 5]) p = attention.TransformerLayer.Params() p.name = 'transformer_layer' p.input_dim = 4 p.tr_fflayer_tpl.hidden_dim = 7 p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) # multi-source cross attention p.tr_atten_tpl = ( attention.TransformerMultiSourceAttentionLayer.Params().Set( num_source=2, primary_source_index=0, num_heads=2)) p.tr_self_atten_tpl = attention.TransformerAttentionLayer.Params().Set( input_dim=4, num_heads=2) l = p.Instantiate() ctx_vec, _ = l.FProp( l.theta, query_vec, paddings, py_utils.NestedMap({ 'source_0': aux_vec, 'source_1': aux_vec }), py_utils.NestedMap({ 'source_0': aux_paddings, 'source_1': aux_paddings })) tf.global_variables_initializer().run() actual_ctx = sess.run(ctx_vec) actual_ctx = np.reshape(actual_ctx, (10 * multiplier, 4)) tf.logging.info(np.array_repr(actual_ctx)) expected_ctx = [ 4.7839108, 4.5303655, 5.5551023, 5.0657663, 5.0493064, 3.2142467, 2.820018, 5.659971, 4.3814187, 2.60475 ] * multiplier self.assertAllClose(expected_ctx, np.sum(actual_ctx, axis=1)) @parameterized.named_parameters(('Base', False), ('RelativeAtten', True)) def testTransformerDecoderLayerConstruction(self, use_relative_atten): _ = self._ConstructTransformerDecoderLayer( use_relative_atten=use_relative_atten) def testTransformerDecoderLayerFProp(self): with self.session(use_gpu=True) as sess: (query_vec, paddings, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() l = self._ConstructTransformerDecoderLayer() layer_output, _ = l.FProp(l.theta, query_vec, paddings, aux_vec, aux_paddings) tf.global_variables_initializer().run() actual_layer_output = sess.run(layer_output) actual_layer_output = np.reshape(actual_layer_output, (10, 4)) tf.logging.info(np.array_repr(actual_layer_output)) expected_layer_output = [16.939590, 24.121685, 19.975197, 15.924350] self.assertAllClose(expected_layer_output, np.sum(actual_layer_output, axis=0)) def _ConstructTransformerEncoderLayerStack(self): p = attention.StackedTransformerLayers.Params() p.name = 'encoder_layers' p.has_aux_atten = False p.mask_self_atten = False p.num_layers = 2 p.mdl_dim = 4 p.hidden_dim = 8 p.num_atten_heads = 2 p.dropout_prob = 0.2 p.params_init = py_utils.WeightInit.Xavier() p.random_seed = 12345 return p.Instantiate() def _ConstructTransformerDecoderLayerStack(self, dropout_prob=0.2): p = attention.StackedTransformerLayers.Params() p.name = 'decoder_layers' p.has_aux_atten = True p.mask_self_atten = True p.num_layers = 2 p.mdl_dim = 4 p.hidden_dim = 8 p.num_atten_heads = 2 p.dropout_prob = dropout_prob p.params_init = py_utils.WeightInit.Xavier() p.random_seed = 12345 return p.Instantiate() def testTransformerEncoderLayerStackFProp(self): with self.session(use_gpu=True) as sess: (query_vec, paddings, _, _) = self._TransformerAttentionLayerInputs() l = self._ConstructTransformerEncoderLayerStack() layer_output, _ = l.FProp(l.theta, query_vec=query_vec, paddings=paddings) tf.global_variables_initializer().run() actual_layer_output = sess.run(layer_output) actual_layer_output = np.reshape(actual_layer_output, (10, 4)) tf.logging.info(np.array_repr(actual_layer_output)) expected_layer_output = [6.178955, -11.376661, 7.032681, -1.532627] self.assertAllClose(expected_layer_output, np.sum(actual_layer_output, axis=0)) def testTransformerDecoderLayerStackFProp(self): with self.session(use_gpu=True) as sess: (query_vec, paddings, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() l = self._ConstructTransformerDecoderLayerStack() layer_output, _ = l.FProp( l.theta, query_vec=query_vec, paddings=paddings, aux_vec=aux_vec, aux_paddings=aux_paddings) tf.global_variables_initializer().run() actual_layer_output = sess.run(layer_output) actual_layer_output = np.reshape(actual_layer_output, (10, 4)) tf.logging.info(np.array_repr(actual_layer_output)) expected_layer_output = [9.926413, -4.491376, 27.051598, 2.112684] self.assertAllClose(expected_layer_output, np.sum(actual_layer_output, axis=0)) @parameterized.named_parameters( { 'testcase_name': '_short_seq', 'use_short_seq_opt': True, }, { 'testcase_name': '_long_seq', 'use_short_seq_opt': False, }) def testTransformerDecoderLayerStackExtendStep(self, use_short_seq_opt): def _Rnd(seed): return tf.random.normal([5, 2, 2, 2], seed=seed) graph = tf.Graph() with graph.as_default(): tf.random.set_seed(123456) query_vec, _, aux_vec, aux_paddings = ( self._TransformerAttentionLayerInputs()) paddings = tf.zeros([2, 5]) layer_prefix_states_1 = py_utils.NestedMap(key=_Rnd(1), value=_Rnd(2)) layer_prefix_states_2 = py_utils.NestedMap(key=_Rnd(3), value=_Rnd(4)) prefix_states = py_utils.NestedMap( x_layers=[layer_prefix_states_1, layer_prefix_states_2]) l = self._ConstructTransformerDecoderLayerStack(dropout_prob=0.) layer_output1, _ = l.FProp(l.theta, query_vec, paddings, aux_vec, aux_paddings) layer_output2 = [] for i in range(5): layer_output, prefix_states = l.ExtendStep( l.theta, tf.expand_dims(query_vec[:, i, :], 1), aux_vec, aux_paddings, prefix_states, i, use_short_seq_opt) layer_output2.append(tf.squeeze(layer_output, 1)) layer_output2 = tf.transpose(tf.stack(layer_output2), [1, 0, 2]) with self.session(graph=graph, use_gpu=True) as sess: tf.global_variables_initializer().run() actual_layer_output1, actual_layer_output2 = sess.run( [layer_output1, layer_output2]) self.assertAllClose(actual_layer_output1, actual_layer_output2) @parameterized.named_parameters( { 'testcase_name': '_short_seq', 'use_short_seq_opt': True, }, { 'testcase_name': '_long_seq', 'use_short_seq_opt': False, }) def testTransformerDecoderLayerExtendStep(self, use_short_seq_opt): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() paddings = tf.zeros([2, 5]) cached_key = tf.zeros([5, 2, 2, 2]) cached_value = tf.zeros([5, 2, 2, 2]) prefix_states = py_utils.NestedMap(key=cached_key, value=cached_value) l = self._ConstructTransformerDecoderLayer() layer_output1, _ = l.FProp(l.theta, query_vec, paddings, aux_vec, aux_paddings) layer_output2 = [] for i in range(5): layer_output, prefix_states = l.ExtendStep( l.theta, tf.expand_dims(query_vec[:, i, :], 1), aux_vec, aux_paddings, prefix_states, i, use_short_seq_opt) layer_output2.append(tf.squeeze(layer_output, 1)) layer_output2 = tf.transpose(tf.stack(layer_output2), [1, 0, 2]) tf.global_variables_initializer().run() actual_layer_output1, actual_layer_output2 = sess.run( [layer_output1, layer_output2]) self.assertAllClose(actual_layer_output1, actual_layer_output2) def _ConstructMultiSourceTransformerDecoderLayer(self, use_relative_atten=False): p = attention.MultiSourceTransformerDecoderLayer.Params().Set(num_source=2) p.name = 'multi_source_transformer_decoder_layer' p.input_dim = 4 p.tr_fflayer_tpl.hidden_dim = 7 # multi-source cross attention p.tr_atten_tpl = ( attention.TransformerMultiSourceAttentionLayer.Params().Set( num_source=2, primary_source_index=0, num_heads=2)) p.tr_self_atten_tpl = attention.TransformerAttentionLayer.Params().Set( input_dim=4, num_heads=2) p.tr_atten_tpl.multi_source_atten.atten_merger_tpl = ( tm_attention.MergerLayer.Params().Set(merger_op='sum')) if use_relative_atten: p = attention.UseRelativeAttentionInTransformerLayer(p, 4) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) return attention.MultiSourceTransformerDecoderLayer(p) @parameterized.named_parameters( { 'testcase_name': '_short_seq', 'use_short_seq_opt': True, }, { 'testcase_name': '_long_seq', 'use_short_seq_opt': False, }) def testMultiSourceTransformerDecoderLayerExtendStep(self, use_short_seq_opt): with self.session(use_gpu=True) as sess: (query_vec, _, aux_vec, aux_paddings) = self._TransformerAttentionLayerInputs() paddings = tf.zeros([2, 5]) cached_key = tf.zeros([5, 2, 2, 2]) cached_value = tf.zeros([5, 2, 2, 2]) prefix_states = py_utils.NestedMap(key=cached_key, value=cached_value) l = self._ConstructMultiSourceTransformerDecoderLayer() ms_aux_vec = py_utils.NestedMap({ 'source_0': aux_vec, 'source_1': aux_vec }) ms_aux_paddings = py_utils.NestedMap({ 'source_0': aux_paddings, 'source_1': aux_paddings }) layer_output1, _ = l.FProp(l.theta, query_vec, paddings, ms_aux_vec, ms_aux_paddings) layer_output2 = [] for i in range(5): layer_output, prefix_states = l.ExtendStep( l.theta, tf.expand_dims(query_vec[:, i, :], 1), ms_aux_vec, ms_aux_paddings, prefix_states, i, use_short_seq_opt) layer_output2.append(tf.squeeze(layer_output, 1)) layer_output2 = tf.transpose(tf.stack(layer_output2), [1, 0, 2]) tf.global_variables_initializer().run() actual_layer_output1, actual_layer_output2 = sess.run( [layer_output1, layer_output2]) self.assertAllClose(actual_layer_output1, actual_layer_output2) def testGPipeTransformerLayerConstruction(self): p = attention.GPipeTransformerLayer.Params() p.name = 'gpipe_transformer_layer' p.input_dim = 4 p.tr_fflayer_tpl.hidden_dim = 7 p.tr_atten_tpl.num_heads = 2 p.tr_atten_tpl.residual_dropout_prob = 0.1 p.cls.SetupDeterministicDropout(p) layer = p.Instantiate() self.assertEqual(0.1, layer.params.tr_atten_tpl.residual_dropout_prob) class BuilderTest(test_utils.TestCase, parameterized.TestCase): def _testGraph(self, glu_with_tanh=False, dtype=tf.float32): tf.random.set_seed(398847392) np.random.seed(12345) atten_builder = attention.Builder.Params().Set( model_dim=4, num_heads=2, ff_hidden_dim=16, glu_with_tanh=glu_with_tanh) params = atten_builder.Instantiate().LConvStack( name='lightconv', kernel_sizes=[3, 3]) params.dtype = dtype params.random_seed = 0 params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = params.Instantiate() l_in = tf.constant(np.random.rand(2, 3, 4), dtype=dtype) l_padding = tf.zeros([2, 3], dtype=dtype) l_out = l.FPropDefaultTheta( py_utils.NestedMap(vec=l_in, paddings=l_padding)) return l_out.vec @parameterized.parameters((False, 38.163662), (True, 35.88797)) def testFprop(self, glu_with_tanh, expected_result): with self.session(use_gpu=False, graph=tf.Graph()) as sess: l_out = self._testGraph(glu_with_tanh) l_out = tf.reduce_sum(l_out) tf.global_variables_initializer().run() l_out_eval = sess.run(l_out) self.assertAllClose(expected_result, l_out_eval) def testBProp(self): with self.session(use_gpu=True) as sess: output = self._testGraph(dtype=tf.float64) loss = tf.reduce_sum(output) all_vars = tf.trainable_variables() grads = tf.gradients(loss, all_vars) tf.global_variables_initializer().run() sym_grads = [sg.eval() for sg in grads] num_grads = [ test_utils.ComputeNumericGradient(sess, loss, v) for v in all_vars ] for ng, sg in zip(num_grads, sym_grads): self.assertAllClose(ng, sg, rtol=5e-02, atol=5e-02) @parameterized.named_parameters( { 'testcase_name': '_baseline', 'strides': [1, 1], }, { 'testcase_name': '_stride_2', 'strides': [2, 1], }, { 'testcase_name': '_first_token', 'strides': [2, 0], }) def testTransformerStackWithStride(self, strides): with self.session(use_gpu=False) as sess: bs = 2 sl = 10 d = 16 tf.random.set_seed(12345) atten_builder = attention.Builder.Params().Set( model_dim=d, num_heads=2, ff_hidden_dim=5).Instantiate() layers = [] accumulate_stride = 1 for layer_i, stride in enumerate(strides): accumulate_stride *= stride layers.append( atten_builder.TransformerEncoderLayer( name='atten_{}'.format(layer_i), stride=stride)) p = atten_builder.Seq('model', *layers) p.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0) l = p.Instantiate() input_embs = tf.constant( np.random.random(size=[bs, sl, d]), dtype=np.float) paddings = tf.zeros([bs, sl]) l_out = l.FPropDefaultTheta( py_utils.NestedMap(vec=input_embs, paddings=paddings)) enc_out = l_out.vec tf.global_variables_initializer().run() actual_enc_out = sess.run(enc_out) seq_len = sl // accumulate_stride if accumulate_stride != 0 else 1 self.assertAllEqual([bs, seq_len, d], actual_enc_out.shape) def _CreateDummyParams(field_names): p = hyperparams.Params() for name in field_names: p.Define(name, None, 'Dummy') return p class DummyDecoderRNNT(base_layer.BaseLayer): @classmethod def Params(cls): p = super(DummyDecoderRNNT, cls).Params() p.name = 'dummy_decoder_rnnt' p.Define('emb', _CreateDummyParams(['vocab_size']), 'Dummy emb.') p.Define('target_seq_len', 20, 'Dummy target seq len.') p.Define('num_classes', None, 'Dummy num classes.') return p @classmethod def UpdateTargetVocabSize(cls, p, vocab_size, wpm_model=None): p.emb.vocab_size = vocab_size p.num_classes = vocab_size return p class RelativeAttentionHelperTest(test_utils.TestCase, parameterized.TestCase): @parameterized.named_parameters( ('MultiHeadedAttentionXL', attention.MultiHeadedAttentionXL, attention.MultiHeadedAttention), ('LocalCausalSelfAttentionXL', attention.LocalCausalSelfAttentionXL, attention.LocalCausalSelfAttention)) def testClearRelativeAttentionInTransformerLayer(self, atten_cls, expected_atten_cls): """Tests scenarios in clear relative attention in transformer layer.""" trans_p = attention.TransformerLayer.Params() # set attention params in transformer layer. input_dim = 4 rel_pos_emb_dim = 4 # Set rel_pos_emb_dim in attention params. trans_p.tr_atten_tpl.atten_tpl = ( atten_cls.Params().Set( input_dim=input_dim, rel_pos_emb_dim=rel_pos_emb_dim)) new_trans_p = attention.ClearRelativeAttentionInTransformerLayer(trans_p) tr_atten_tpl = new_trans_p.tr_self_atten_tpl.atten_tpl self.assertEqual(tr_atten_tpl.cls, expected_atten_cls) self.assertEqual(tr_atten_tpl.input_dim, input_dim) def testClearRelativeAttentionTransformerLayerNotSupportedError(self): transformer_params = DummyDecoderRNNT.Params() with self.assertRaises(ValueError): _ = attention.ClearRelativeAttentionInTransformerLayer(transformer_params) def testClearRelativeAttentionAttentionParamsNotSupportedError(self): trans_p = attention.TransformerLayer.Params() # MultiHeadedAttention is not supported in ClearRelativeAttention. attention_params = attention.MultiHeadedAttention.Params() trans_p.tr_atten_tpl.atten_tpl = attention_params with self.assertRaises(ValueError): _ = attention.ClearRelativeAttentionInTransformerLayer(trans_p) @parameterized.named_parameters( ('AttentionParamsNotSupported', _CreateDummyParams( ['name', 'cls']), attention.ATTEN_TRANSFORMER_XL), ('AttentionTypeNotSupported', attention.MultiHeadedAttention.Params(), 'unsupported_atten_type')) def testUseRelativeAttentionInTransformerLayerValueError( self, attention_params, attention_type): """Tests unsupported Use Relative Attention cases.""" transformer_param = attention.TransformerLayer.Params() transformer_param.tr_atten_tpl.atten_tpl = attention_params rel_pos_emb_dim = 4 with self.assertRaises(ValueError): _ = attention.UseRelativeAttentionInTransformerLayer( transformer_param, rel_pos_emb_dim, atten_type=attention_type) def testUseRelativeAttentionInTransformerLayerNotSupportedError(self): """Tests unsupported input transformer params in Use Relative Attention.""" transformer_params = DummyDecoderRNNT.Params() with self.assertRaises(ValueError): _ = attention.UseRelativeAttentionInTransformerLayer( transformer_params, 4, atten_type=attention.ATTEN_TRANSFORMER_XL) @parameterized.named_parameters( ('MultiHeadedAttention', attention.MultiHeadedAttention, attention.MultiHeadedAttentionXL, attention.ATTEN_TRANSFORMER_XL), ('LocalCausalSelfAttention', attention.LocalCausalSelfAttention, attention.LocalCausalSelfAttentionXL, attention.ATTEN_TRANSFORMER_XL), ('MultiHeadedAttentionRPE', attention.MultiHeadedAttention, attention.MultiHeadedAttentionRPE, attention.ATTEN_RPE)) def testUseRelativeAttentionInTransformerLayer(self, atten_cls, expected_atten_cls, atten_type): """Tests different scenarios in Use Relative Attention.""" trans_p = attention.TransformerLayer.Params() # set attenion params in transformer layer. input_dim = 4 trans_p.tr_atten_tpl.atten_tpl = atten_cls.Params().Set(input_dim=input_dim) rel_pos_emb_dim = 4 new_trans_p = attention.UseRelativeAttentionInTransformerLayer( trans_p, rel_pos_emb_dim, atten_type=atten_type) tr_atten_tpl = new_trans_p.tr_self_atten_tpl.atten_tpl self.assertEqual(tr_atten_tpl.cls, expected_atten_cls) self.assertEqual(tr_atten_tpl.rel_pos_emb_dim, rel_pos_emb_dim) self.assertEqual(tr_atten_tpl.input_dim, input_dim) if __name__ == '__main__': tf.test.main()