# Copyright 2017 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for attention.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import tensorflow as tf from third_party.tensor2tensor import common_layers # def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e5): """Adds a bunch of sinusoids of different frequencies to a Tensor. Each channel of the input Tensor is incremented by a sinusoid of a different frequency and phase. This allows attention to learn to use absolute and relative positions. Timing signals should be added to some precursors of both the query and the memory inputs to attention. The use of relative position is possible because sin(x+y) and cos(x+y) can be experessed in terms of y, sin(x) and cos(x). In particular, we use a geometric sequence of timescales starting with min_timescale and ending with max_timescale. The number of different timescales is equal to channels / 2. For each timescale, we generate the two sinusoidal signals sin(timestep/timescale) and cos(timestep/timescale). All of these sinusoids are concatenated in the channels dimension. Args: x: a Tensor with shape [batch, length, channels] min_timescale: a float max_timescale: a float Returns: a Tensor the same shape as x. """ length = tf.shape(x)[1] channels = tf.shape(x)[2] position = tf.to_float(tf.range(length)) num_timescales = channels // 2 log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (tf.to_float(num_timescales) - 1)) inv_timescales = min_timescale * tf.exp( tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0) signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) signal = tf.pad(signal, [[0, 0], [0, tf.mod(channels, 2)]]) signal = tf.reshape(signal, [1, length, channels]) return x + signal def add_timing_signal_nd(x, min_timescale=1.0, max_timescale=1.0e4): """Adds a bunch of sinusoids of different frequencies to a Tensor. Each channel of the input Tensor is incremented by a sinusoid of a different frequency and phase in one of the positional dimensions. This allows attention to learn to use absolute and relative positions. Timing signals should be added to some precursors of both the query and the memory inputs to attention. The use of relative position is possible because sin(a+b) and cos(a+b) can be experessed in terms of b, sin(a) and cos(a). x is a Tensor with n "positional" dimensions, e.g. one dimension for a sequence or two dimensions for an image We use a geometric sequence of timescales starting with min_timescale and ending with max_timescale. The number of different timescales is equal to channels // (n * 2). For each timescale, we generate the two sinusoidal signals sin(timestep/timescale) and cos(timestep/timescale). All of these sinusoids are concatenated in the channels dimension. Args: x: a Tensor with shape [batch, d1 ... dn, channels] min_timescale: a float max_timescale: a float Returns: a Tensor the same shape as x. """ static_shape = x.get_shape().as_list() num_dims = len(static_shape) - 2 channels = tf.shape(x)[-1] num_timescales = channels // (num_dims * 2) log_timescale_increment = ( math.log(float(max_timescale) / float(min_timescale)) / (tf.to_float(num_timescales) - 1)) inv_timescales = min_timescale * tf.exp( tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) for dim in xrange(num_dims): length = tf.shape(x)[dim + 1] position = tf.to_float(tf.range(length)) scaled_time = tf.expand_dims(position, 1) * tf.expand_dims( inv_timescales, 0) signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1) prepad = dim * 2 * num_timescales postpad = channels - (dim + 1) * 2 * num_timescales signal = tf.pad(signal, [[0, 0], [prepad, postpad]]) for _ in xrange(1 + dim): signal = tf.expand_dims(signal, 0) for _ in xrange(num_dims - 1 - dim): signal = tf.expand_dims(signal, -2) x += signal return x def add_positional_embedding_nd(x, max_length, name): """Add n-dimensional positional embedding. Adds embeddings to represent the positional dimensions of the tensor. The input tensor has n positional dimensions - i.e. 1 for text, 2 for images, 3 for video, etc. Args: x: a Tensor with shape [batch, p1 ... pn, depth] max_length: an integer. static maximum size of any dimension. name: a name for this layer. Returns: a Tensor the same shape as x. """ static_shape = x.get_shape().as_list() dynamic_shape = tf.shape(x) num_dims = len(static_shape) - 2 depth = static_shape[-1] base_shape = [1] * (num_dims + 1) + [depth] base_start = [0] * (num_dims + 2) base_size = [-1] + [1] * num_dims + [depth] for i in xrange(num_dims): shape = base_shape[:] start = base_start[:] size = base_size[:] shape[i + 1] = max_length size[i + 1] = dynamic_shape[i + 1] var = (tf.get_variable( name + "_%d" % i, shape, initializer=tf.random_normal_initializer(0, depth ** -0.5)) * (depth ** 0.5)) x += tf.slice(var, start, size) return x def embedding_to_padding(emb): """Input embeddings -> is_padding. We have hacked symbol_modality to return all-zero embeddings for padding. Args: emb: a Tensor with shape [..., depth]. Returns: a boolean Tensor with shape [...]. """ emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1) return tf.equal(emb_sum, 0.0) def attention_bias_lower_triangle(length): """Create an bias tensor to be added to attention logits. Args: length: a Scalar. Returns: a `Tensor` with shape [1, 1, length, length]. """ lower_triangle = tf.matrix_band_part(tf.ones([length, length]), -1, 0) ret = -1e9 * (1.0 - lower_triangle) return tf.reshape(ret, [1, 1, length, length]) def attention_bias_ignore_padding(memory_padding): """Create an bias tensor to be added to attention logits. Args: memory_padding: a boolean `Tensor` with shape [batch, memory_length]. Returns: a `Tensor` with shape [batch, 1, 1, memory_length]. """ ret = tf.to_float(memory_padding) * -1e9 return tf.expand_dims(tf.expand_dims(ret, 1), 1) def split_last_dimension(x, n): """Reshape x so that the last dimension becomes two dimensions. The first of these two dimensions is n. Args: x: a Tensor with shape [..., m] n: an integer. Returns: a Tensor with shape [..., n, m/n] """ old_shape = x.get_shape().dims last = old_shape[-1] new_shape = old_shape[:-1] + [n] + [last // n if last else None] ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [n, -1]], 0)) ret.set_shape(new_shape) return ret def combine_last_two_dimensions(x): """Reshape x so that the last two dimension become one. Args: x: a Tensor with shape [..., a, b] Returns: a Tensor with shape [..., ab] """ old_shape = x.get_shape().dims a, b = old_shape[-2:] new_shape = old_shape[:-2] + [a * b if a and b else None] ret = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [-1]], 0)) ret.set_shape(new_shape) return ret def split_heads(x, num_heads): """Split channels (dimension 3) into multiple heads (becomes dimension 1). Args: x: a Tensor with shape [batch, length, channels] num_heads: an integer Returns: a Tensor with shape [batch, num_heads, length, channels / num_heads] """ return tf.transpose(split_last_dimension(x, num_heads), [0, 2, 1, 3]) def combine_heads(x): """Inverse of split_heads. Args: x: a Tensor with shape [batch, num_heads, length, channels / num_heads] Returns: a Tensor with shape [batch, length, channels] """ return combine_last_two_dimensions(tf.transpose(x, [0, 2, 1, 3])) def attention_image_summary(attn, image_shapes=None): """Compute color image summary. Args: attn: a Tensor with shape [batch, num_heads, query_length, memory_length] image_shapes: optional quadruple of integer scalars. If the query positions and memory positions represent the pixels of a flattened image, then pass in their dimensions: (query_rows, query_cols, memory_rows, memory_cols). """ num_heads = attn.get_shape().as_list()[1] # [batch, query_length, memory_length, num_heads] image = tf.transpose(attn, [0, 2, 3, 1]) image = tf.pow(image, 0.2) # for high-dynamic-range # Each head will correspond to one of RGB. # pad the heads to be a multiple of 3 image = tf.pad(image, [[0, 0], [0, 0], [0, 0], [0, -num_heads % 3]]) image = split_last_dimension(image, 3) image = tf.reduce_max(image, 4) if image_shapes is not None: q_rows, q_cols, m_rows, m_cols = list(image_shapes) image = tf.reshape(image, [-1, q_rows, q_cols, m_rows, m_cols, 3]) image = tf.transpose(image, [0, 1, 3, 2, 4, 5]) image = tf.reshape(image, [-1, q_rows * m_rows, q_cols * m_cols, 3]) tf.summary.image("attention", image, max_outputs=1) def dot_product_attention(q, k, v, bias, dropout_rate=0.0, summaries=False, image_shapes=None, name=None): """dot-product attention. Args: q: a Tensor with shape [batch, heads, length_q, depth_k] k: a Tensor with shape [batch, heads, length_kv, depth_k] v: a Tensor with shape [batch, heads, length_kv, depth_v] bias: bias Tensor (see attention_bias()) dropout_rate: a floating point number summaries: a boolean image_shapes: optional quadruple of integer scalars for image summary. If the query positions and memory positions represent the pixels of a flattened image, then pass in their dimensions: (query_rows, query_cols, memory_rows, memory_cols). name: an optional string Returns: A Tensor. """ with tf.variable_scope( name, default_name="dot_product_attention", values=[q, k, v]): # [batch, num_heads, query_length, memory_length] logits = tf.matmul(q, k, transpose_b=True) if bias is not None: logits += bias weights = tf.nn.softmax(logits, name="attention_weights") # dropping out the attention links for each of the heads weights = tf.nn.dropout(weights, 1.0 - dropout_rate) if summaries and not tf.get_variable_scope().reuse: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) def dot_product_attention_sigmoid(q, k, v, bias, dropout_rate=0.0, summaries=False, image_shapes=None, name=None): """dot-product attention. Args: q: a Tensor with shape [batch, heads, length_q, depth_k] k: a Tensor with shape [batch, heads, length_kv, depth_k] v: a Tensor with shape [batch, heads, length_kv, depth_v] bias: bias Tensor (see attention_bias()) dropout_rate: a floating point number summaries: a boolean image_shapes: optional quadruple of integer scalars for image summary. If the query positions and memory positions represent the pixels of a flattened image, then pass in their dimensions: (query_rows, query_cols, memory_rows, memory_cols). name: an optional string Returns: A Tensor. """ with tf.variable_scope( name, default_name="dot_product_attention", values=[q, k, v]): # [batch, num_heads, query_length, memory_length] logits = tf.matmul(q, k, transpose_b=True) if bias is not None: logits += bias weights = tf.nn.sigmoid(logits, name="attention_weights") # dropping out the attention links for each of the heads weights = tf.nn.dropout(weights, 1.0 - dropout_rate) if summaries and not tf.get_variable_scope().reuse: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v) def multihead_attention(query_antecedent, memory_antecedent, bias, total_key_depth, total_value_depth, output_depth, num_heads, dropout_rate, summaries=False, image_shapes=None, name=None): """Multihead scaled-dot-product attention with input/output transformations. Args: query_antecedent: a Tensor with shape [batch, length_q, channels] memory_antecedent: a Tensor with shape [batch, length_m, channels] bias: bias Tensor (see attention_bias()) total_key_depth: an integer total_value_depth: an integer output_depth: an integer num_heads: an integer dividing total_key_depth and total_value_depth dropout_rate: a floating point number summaries: a boolean image_shapes: optional quadruple of integer scalars for image summary. If the query positions and memory positions represent the pixels of a flattened image, then pass in their dimensions: (query_rows, query_cols, memory_rows, memory_cols). name: an optional string Returns: A Tensor. """ with tf.variable_scope( name, default_name="multihead_attention", values=[query_antecedent, memory_antecedent]): if memory_antecedent is None: # self attention combined = common_layers.conv1d( query_antecedent, total_key_depth * 2 + total_value_depth, 1, name="qkv_transform") q, k, v = tf.split( combined, [total_key_depth, total_key_depth, total_value_depth], axis=2) else: q = common_layers.conv1d( query_antecedent, total_key_depth, 1, name="q_transform") combined = common_layers.conv1d( memory_antecedent, total_key_depth + total_value_depth, 1, name="kv_transform") k, v = tf.split(combined, [total_key_depth, total_value_depth], axis=2) q = split_heads(q, num_heads) k = split_heads(k, num_heads) v = split_heads(v, num_heads) key_depth_per_head = total_key_depth // num_heads import model as model if model.is_attention_smoothing: x = dot_product_attention_sigmoid( q, k, v, bias, dropout_rate, summaries, image_shapes) else: q *= key_depth_per_head ** -0.5 x = dot_product_attention( q, k, v, bias, dropout_rate, summaries, image_shapes) x = combine_heads(x) x = common_layers.conv1d(x, output_depth, 1, name="output_transform") return x