from keras.engine import Layer
from keras import backend as K


class Attention(Layer):
    """Basic attention layer.

    Attention layers are normally used to find important tokens based on
    different labels. Uses the "max trick" for numerical stability.

    # Arguments:
        1. use_bias: whether to use a bias in the MLP projection
        2. use_context: whether to use a learned context vector
        3. return_attention: whether to return the attention weights as part of the output
        4. attention_dim: dimensionality of the inner attention projection
           (defaults to the input feature dimension)
        5. activation: whether to apply a tanh activation in the first MLP

    # Inputs:
        Tensor with shape (batch_size, time_steps, hidden_size)

    # Returns:
        Tensor with shape (batch_size, hidden_size).
        If `return_attention` is True, an additional tensor with shape
        (batch_size, time_steps) will be returned.
    """

    def __init__(self, use_bias=True, use_context=True, return_attention=False,
                 attention_dim=None, activation=True, **kwargs):
        self.use_bias = use_bias
        self.use_context = use_context
        self.return_attention = return_attention
        self.attention_dim = attention_dim
        self.activation = activation
        # Let masks (e.g. from an Embedding with mask_zero=True) reach `call`.
        self.supports_masking = True
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        if len(input_shape) < 3:
            raise ValueError(
                "Expected input shape of `(batch_size, time_steps, features)`, "
                "found `{}`".format(input_shape))

        if self.attention_dim is None:
            attention_dim = input_shape[-1]
        else:
            attention_dim = self.attention_dim

        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[-1], attention_dim),
                                      initializer="glorot_normal",
                                      trainable=True)
        if self.use_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=(attention_dim,),
                                        initializer="zeros",
                                        trainable=True)
        else:
            self.bias = None

        if self.use_context:
            self.context_kernel = self.add_weight(name='context_kernel',
                                                  shape=(attention_dim, 1),
                                                  initializer="glorot_normal",
                                                  trainable=True)
        else:
            self.context_kernel = None

        super(Attention, self).build(input_shape)

    def call(self, x, mask=None):
        # MLP projection: (batch_size, time_steps, features) -> (batch_size, time_steps, attention_dim)
        ut = K.dot(x, self.kernel)
        if self.use_bias:
            ut = K.bias_add(ut, self.bias)
        if self.activation:
            ut = K.tanh(ut)
        # Score each time step against the context vector, then drop the trailing axis.
        if self.context_kernel is not None:
            ut = K.dot(ut, self.context_kernel)
        ut = K.squeeze(ut, axis=-1)

        # Softmax over the time axis, using the max trick for numerical stability.
        at = K.exp(ut - K.max(ut, axis=-1, keepdims=True))
        if mask is not None:
            at *= K.cast(mask, K.floatx())
        att_weights = at / (K.sum(at, axis=1, keepdims=True) + K.epsilon())

        # Attention-weighted sum over time: (batch_size, hidden_size).
        atx = x * K.expand_dims(att_weights, axis=-1)
        output = K.sum(atx, axis=1)

        if self.return_attention:
            return [output, att_weights]
        return output

    def compute_mask(self, input, input_mask=None):
        # The time axis is collapsed, so no mask is propagated downstream.
        if isinstance(input_mask, list):
            return [None] * len(input_mask)
        else:
            return None

    def compute_output_shape(self, input_shape):
        output_len = input_shape[2]
        if self.return_attention:
            return [(input_shape[0], output_len),
                    (input_shape[0], input_shape[1])]
        return (input_shape[0], output_len)
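

# --- Usage sketch (illustrative, not part of the layer) ---------------------
# A minimal example of wiring the layer into a small classifier, assuming
# standalone Keras 2.x. The surrounding layers, vocabulary size, and sequence
# length below are arbitrary assumptions chosen only for demonstration.
if __name__ == "__main__":
    from keras.layers import Input, Embedding, LSTM, Dense
    from keras.models import Model

    inputs = Input(shape=(100,))                  # (batch_size, time_steps)
    h = Embedding(input_dim=20000, output_dim=128,
                  mask_zero=True)(inputs)         # padding mask is propagated to Attention
    h = LSTM(64, return_sequences=True)(h)        # keep the time axis for attention
    h = Attention()(h)                            # -> (batch_size, 64); pass return_attention=True
                                                  #    to also get the (batch_size, time_steps) weights
    outputs = Dense(1, activation="sigmoid")(h)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="binary_crossentropy")
    model.summary()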