# coding=utf-8 # Copyright 2020 The Tensor2Tensor Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Utilities for area attention.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np from six.moves import range # pylint: disable=redefined-builtin from tensor2tensor.layers import common_layers import tensorflow.compat.v1 as tf def lengths_to_area_mask(feature_length, length, max_area_size): """Generates a non-padding mask for areas based on lengths. Args: feature_length: a tensor of [batch_size] length: the length of the batch max_area_size: the maximum area size considered Returns: mask: a tensor in shape of [batch_size, num_areas] """ paddings = tf.cast(tf.expand_dims( tf.logical_not( tf.sequence_mask(feature_length, maxlen=length)), 2), tf.float32) _, _, area_sum, _, _ = compute_area_features(paddings, max_area_width=max_area_size) mask = tf.squeeze(tf.logical_not(tf.cast(area_sum, tf.bool)), [2]) return mask def _pool_one_shape(features_2d, area_width, area_height, batch_size, width, height, depth, fn=tf.reduce_max, name=None): """Pools for an area in features_2d. Args: features_2d: a Tensor in a shape of [batch_size, height, width, depth]. area_width: the max width allowed for an area. area_height: the max height allowed for an area. batch_size: the batch size. width: the width of the memory. height: the height of the memory. depth: the depth of the features. fn: the TF function for the pooling. name: the op name. Returns: pool_tensor: A Tensor of shape [batch_size, num_areas, depth] """ with tf.name_scope(name, default_name="pool_one_shape"): images = [] for y_shift in range(area_height): image_height = tf.maximum(height - area_height + 1 + y_shift, 0) for x_shift in range(area_width): image_width = tf.maximum(width - area_width + 1 + x_shift, 0) area = features_2d[:, y_shift:image_height, x_shift:image_width, :] flatten_area = tf.reshape(area, [batch_size, -1, depth, 1]) images.append(flatten_area) image_tensor = tf.concat(images, axis=3) max_tensor = fn(image_tensor, axis=3) return max_tensor def basic_pool(features, max_area_width, max_area_height=1, height=1, fn=tf.reduce_max, name=None): """Pools for each area based on a given pooling function (fn). Args: features: a Tensor in a shape of [batch_size, height * width, depth]. max_area_width: the max width allowed for an area. max_area_height: the max height allowed for an area. height: the height of the image. fn: the TF function for the pooling. name: the namescope. Returns: pool_results: A Tensor of shape [batch_size, num_areas, depth] area_heights: A Tensor of shape [batch_size, num_areas, 1] area_widths: A Tensor of shape [batch_size, num_areas, 1] """ with tf.name_scope(name, default_name="basic_pool"): feature_shape = common_layers.shape_list(features) batch_size = feature_shape[0] length = feature_shape[-2] depth = feature_shape[-1] width = length // height features_2d = tf.reshape(features, [batch_size, height, width, depth]) height_list = [] width_list = [] pool_list = [] size_tensor = tf.ones_like(features_2d[:, :, :, 0], dtype=tf.int32) for area_height in range(max_area_height): for area_width in range(max_area_width): pool_tensor = _pool_one_shape(features_2d, area_width=area_width + 1, area_height=area_height + 1, batch_size=batch_size, width=width, height=height, depth=depth, fn=fn) pool_list.append( tf.reshape(pool_tensor, [batch_size, -1, depth])) height_list.append( tf.reshape( size_tensor[:, area_height:, area_width:] *\ (area_height + 1), [batch_size, -1])) width_list.append( tf.reshape( size_tensor[:, area_height:, area_width:] *\ (area_width + 1), [batch_size, -1])) pool_results = tf.concat(pool_list, axis=1) area_heights = tf.expand_dims(tf.concat(height_list, axis=1), 2) area_widths = tf.expand_dims(tf.concat(width_list, axis=1), 2) return pool_results, area_heights, area_widths def _compute_sum_image(features, max_area_width, max_area_height=1, height=1, name=None): """Computes area sums for features. Args: features: a Tensor in a shape of [batch_size, height * width, depth]. max_area_width: the max width allowed for an area. max_area_height: the max height allowed for an area. height: the height of the image. name: the namescope. Returns: sum_image: A Tensor of shape [batch_size, num_areas, depth] area_heights: A Tensor of shape [batch_size, num_areas, 1] area_widths: A Tensor of shape [batch_size, num_areas, 1] """ with tf.name_scope(name, default_name="compute_sum_image"): feature_shape = common_layers.shape_list(features) batch_size = feature_shape[0] length = feature_shape[-2] depth = feature_shape[-1] width = length // height features_2d = tf.reshape(features, [batch_size, height, width, depth]) width_cum = tf.cumsum(features_2d, axis=-2, name="compute_integral_h") integral_image = tf.cumsum(width_cum, axis=-3, name="compute_integral_v") padded_image = tf.pad( integral_image, [[0, 0], [1, 0], [1, 0], [0, 0]], constant_values=0) height_list = [] width_list = [] dst_images = [] src_images_diag = [] src_images_h = [] src_images_v = [] size_tensor = tf.ones_like(padded_image[:, :, :, 0], dtype=tf.int32) for area_height in range(max_area_height): for area_width in range(max_area_width): dst_images.append( tf.reshape( padded_image[:, area_height + 1:, area_width + 1:, :], [batch_size, -1, depth])) src_images_diag.append( tf.reshape( padded_image[:, :-area_height - 1, :-area_width - 1, :], [batch_size, -1, depth])) src_images_h.append( tf.reshape( padded_image[:, area_height + 1:, :-area_width - 1, :], [batch_size, -1, depth])) src_images_v.append( tf.reshape( padded_image[:, :-area_height - 1, area_width + 1:, :], [batch_size, -1, depth])) height_list.append( tf.reshape( size_tensor[:, area_height + 1:, area_width + 1:] *\ (area_height + 1), [batch_size, -1])) width_list.append( tf.reshape( size_tensor[:, area_height + 1:, area_width + 1:] *\ (area_width + 1), [batch_size, -1])) sum_image = tf.subtract( tf.concat(dst_images, axis=1) + tf.concat(src_images_diag, axis=1), tf.concat(src_images_v, axis=1) + tf.concat(src_images_h, axis=1)) area_heights = tf.expand_dims(tf.concat(height_list, axis=1), 2) area_widths = tf.expand_dims(tf.concat(width_list, axis=1), 2) return sum_image, area_heights, area_widths def compute_area_features(features, max_area_width, max_area_height=1, height=1, epsilon=1e-6): """Computes features for each area. Args: features: a Tensor in a shape of [batch_size, height * width, depth]. max_area_width: the max width allowed for an area. max_area_height: the max height allowed for an area. height: the height of the image. epsilon: the epsilon added to the variance for computing standard deviation. Returns: area_mean: A Tensor of shape [batch_size, num_areas, depth] area_std: A Tensor of shape [batch_size, num_areas, depth] area_sum: A Tensor of shape [batch_size, num_areas, depth] area_heights: A Tensor of shape [batch_size, num_areas, 1] area_widths: A Tensor of shape [batch_size, num_areas, 1] """ with tf.name_scope("compute_area_features"): tf.logging.info("area_attention compute_area_features: %d x %d", max_area_height, max_area_width) area_sum, area_heights, area_widths = _compute_sum_image( features, max_area_width=max_area_width, max_area_height=max_area_height, height=height) area_squared_sum, _, _ = _compute_sum_image( tf.pow(features, 2), max_area_width=max_area_width, max_area_height=max_area_height, height=height) sizes = tf.multiply(area_heights, area_widths) float_area_sizes = tf.to_float(sizes) area_mean = tf.div(area_sum, float_area_sizes) s2_n = tf.div(area_squared_sum, float_area_sizes) area_variance = tf.subtract(s2_n, tf.pow(area_mean, 2)) area_std = tf.sqrt(tf.abs(area_variance) + epsilon) return area_mean, area_std, area_sum, area_heights, area_widths def compute_area_key(features, max_area_width, max_area_height=1, height=1, mode="mean", training=True, name=None): """Computes the key for each area. Args: features: a Tensor in a shape of [batch_size, height * width, depth]. max_area_width: the max width allowed for an area. max_area_height: the max height allowed for an area. height: the height of the image. mode: whether to combine different area features or only use the vector mean of each area, which can be "mean", "concat", "sum", "sample_concat", and "sample_sum". training: indicating if it is in the training mode. name: the name for setting the variable scope. Returns: area_key: a Tensor in the shape of [batch_size, num_areas, depth] """ tf.logging.info("area_attention mode=%s", mode) area_mean, area_std, _, area_heights, area_widths =\ compute_area_features(features, max_area_width=max_area_width, max_area_height=max_area_height, height=height) if mode == "mean": return area_mean elif mode == "max": area_max, _, _ = basic_pool(features, max_area_width=max_area_width, max_area_height=max_area_height, height=height) return area_max elif mode == "sample": if training: area_mean += (area_std * tf.random_normal(tf.shape(area_std))) return area_mean with tf.variable_scope( name, default_name="combine_area_features", values=[area_mean, area_std, area_heights, area_widths]): depth = common_layers.shape_list(area_mean)[-1] height_embed = tf.nn.embedding_lookup( params=tf.get_variable("area_height_emb", [max_area_height, depth // 2]), ids=area_heights[:, :, 0] - 1) width_embed = tf.nn.embedding_lookup( params=tf.get_variable("area_width_emb", [max_area_width, depth // 2]), ids=area_widths[:, :, 0] - 1) size_embed = tf.concat([height_embed, width_embed], -1) if mode == "concat": feature_concat = tf.concat([area_mean, area_std, size_embed], -1) elif mode == "max_concat": area_max, _, _ = basic_pool(features, max_area_width=max_area_width, max_area_height=max_area_height, height=height) feature_concat = tf.concat([area_max, size_embed], -1) elif mode == "sum": feature_concat = size_embed + area_mean + area_std elif mode == "sample_concat": if training: area_mean += (area_std * tf.random_normal(tf.shape(area_std))) feature_concat = tf.concat([area_mean, size_embed], -1) elif mode == "sample_sum": if training: area_mean += (area_std * tf.random_normal(tf.shape(area_std))) feature_concat = area_mean + size_embed else: raise ValueError("Unsupported area key mode=%s" % mode) feature_hidden = tf.layers.dense(inputs=feature_concat, units=depth, activation=tf.nn.relu) area_key = tf.layers.dense(feature_hidden, units=depth) return area_key def dot_product_area_attention(q, k, v, bias, dropout_rate=0.0, image_shapes=None, name=None, attention_image_summary=None, save_weights_to=None, dropout_broadcast_dims=None, max_area_width=1, max_area_height=1, memory_height=1, area_key_mode="mean", area_value_mode="sum", top_k_areas=0, area_temperature=1.0, training=True): """Dot-product area attention. Args: q: Tensor with shape [..., length_q, depth_k]. k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must match with q. v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must match with q. bias: bias Tensor (see attention_bias()) dropout_rate: a float. image_shapes: optional tuple of integer scalars. see comments for attention_image_summary() name: an optional string attention_image_summary: the callback for making image summary of attention. save_weights_to: an optional dictionary to capture attention weights for visualization; the weights tensor will be appended there under a string key created from the variable scope (including name). dropout_broadcast_dims: an optional list of integers less than rank of q. Specifies in which dimensions to broadcast the dropout decisions. max_area_width: the max width allowed for an area. max_area_height: the max height allowed for an area. memory_height: the height of the memory. area_key_mode: the mode for computing area keys, which can be "mean", "concat", "sum", "sample_concat", and "sample_sum". area_value_mode: the mode for computing area values, which can be either "mean", or "sum". top_k_areas: Use the top key areas for attention. area_temperature: the temperature for attention softmax. training: indicating if it is in the training mode. Returns: Tensor with shape [..., length_q, depth_v]. """ tf.logging.info("dot_product_area_attention: " "area_h=%d, area_w=%d, mem_h=%d, " "area_key_mode=%s, area_value_mode=%s, " "area_temperature=%f", max_area_height, max_area_width, memory_height, area_key_mode, area_value_mode, area_temperature) with tf.variable_scope( name, default_name="dot_product_area_attention", values=[q, k, v]) as scope: mem_shape = common_layers.shape_list(k) batch_size = mem_shape[0] head_size = mem_shape[1] length = mem_shape[2] depth = mem_shape[3] k_area = compute_area_key( tf.reshape(k, [-1, length, depth]), max_area_width=max_area_width, max_area_height=max_area_height, height=memory_height, mode=area_key_mode, training=training) if area_value_mode == "mean": v_area, _, _, _, _ = compute_area_features( tf.reshape(v, [-1, length, depth]), max_area_width=max_area_width, max_area_height=max_area_height, height=memory_height) elif area_value_mode == "max": v_area, _, _ = basic_pool(tf.reshape(v, [-1, length, depth]), max_area_width=max_area_width, max_area_height=max_area_height, height=memory_height, fn=tf.reduce_max) elif area_value_mode == "sum": _, _, v_area, _, _ = compute_area_features( tf.reshape(v, [-1, length, depth]), max_area_width=max_area_width, max_area_height=max_area_height, height=memory_height) else: raise ValueError("Unsupported area value mode=%s" % area_value_mode) k = tf.reshape(k_area, [batch_size, head_size, -1, depth]) v = tf.reshape(v_area, [batch_size, head_size, -1, depth]) logits = tf.matmul(q, k, transpose_b=True) # [..., length_q, length_kv] if bias is not None: bias = common_layers.cast_like(bias, logits) with tf.name_scope("compute_area_att_bias", values=[bias]): bias_shape = common_layers.shape_list(bias) mem_length = bias_shape[-1] bias_values = tf.reshape( tf.to_float(tf.less(bias, -1)), [-1, mem_length, 1]) _, _, padding_sum, _, _ = compute_area_features( bias_values, max_area_width=max_area_width, max_area_height=max_area_height, height=memory_height) bias = tf.where( tf.cast(tf.to_int32(padding_sum), tf.bool), tf.fill(tf.shape(padding_sum), -np.inf), tf.zeros_like(padding_sum, dtype=tf.float32)) bias = tf.reshape(bias, [bias_shape[0], bias_shape[1], bias_shape[2], -1]) logits += bias logits = logits / area_temperature weights = tf.nn.softmax(logits, name="attention_weights") if top_k_areas > 0: tf.logging.info("area_attention top_k_areas=%d", top_k_areas) top_k = tf.minimum(common_layers.shape_list(weights)[-1], top_k_areas) top_weights, _ = tf.nn.top_k(weights, k=top_k) min_values = tf.reduce_min(top_weights, -1, keepdims=True) weights = tf.where(tf.greater_equal(weights, min_values), weights, tf.zeros_like(weights)) weights = tf.div(weights, tf.reduce_sum(weights, -1, keepdims=True)) if save_weights_to is not None: save_weights_to[scope.name] = weights save_weights_to[scope.name + "/logits"] = logits # Drop out attention links for each head. weights = common_layers.dropout_with_broadcast_dims( weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims) if common_layers.should_generate_summaries() and attention_image_summary: attention_image_summary(weights, image_shapes) return tf.matmul(weights, v)