from __future__ import absolute_import

import numpy as np
from keras import backend as K
from keras.layers import Layer
from keras import initializers, regularizers, constraints


def _softmax(x, dim):
    """Computes softmax along a specified dim. Keras currently lacks this feature.
    """
    if K.backend() == 'tensorflow':
        import tensorflow as tf
        return tf.nn.softmax(x, dim)
    elif K.backend() == 'cntk':
        import cntk
        return cntk.softmax(x, dim)
    elif K.backend() == 'theano':
        # Theano cannot softmax along an arbitrary dim.
        # So, we will shuffle `dim` to -1 and un-shuffle after softmax.
        perm = np.arange(K.ndim(x))
        perm[dim], perm[-1] = perm[-1], perm[dim]
        x_perm = K.permute_dimensions(x, perm)
        output = K.softmax(x_perm)

        # Permute back.
        perm[dim], perm[-1] = perm[-1], perm[dim]
        output = K.permute_dimensions(output, perm)
        return output
    else:
        raise ValueError("Backend '{}' not supported".format(K.backend()))


class AttentionLayer(Layer):
    """Attention layer that computes a learned attention over an input sequence.

    For details, see papers:
    - https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf
    - http://colinraffel.com/publications/iclr2016feed.pdf (fig 1)

    Input:
        x: Input tensor of shape `(..., time_steps, features)` where `features` must be static (known).

    Output:
        2D tensor of shape `(..., features)`. i.e., the `time_steps` axis is attended over and reduced.
    """

    def __init__(self, kernel_initializer='he_normal',
                 kernel_regularizer=None,
                 kernel_constraint=None,
                 use_bias=True,
                 bias_initializer='zeros',
                 bias_regularizer=None,
                 bias_constraint=None,
                 use_context=True,
                 context_initializer='he_normal',
                 context_regularizer=None,
                 context_constraint=None,
                 attention_dims=None,
                 **kwargs):
        """
        Args:
            attention_dims: The dimensionality of the inner attention calculating neural network.
                For input `(32, 10, 300)`, with `attention_dims` of 100, the output is `(32, 10, 100)`.
                i.e., the attended words are 100 dimensional. This is then collapsed via summation to
                `(32, 10, 1)` to indicate the attention weights for the 10 words.
                If set to None, `features` dims are used as `attention_dims`.
                (Default value: None)
        """
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super(AttentionLayer, self).__init__(**kwargs)

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)

        self.use_bias = use_bias
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias_constraint = constraints.get(bias_constraint)

        self.use_context = use_context
        self.context_initializer = initializers.get(context_initializer)
        self.context_regularizer = regularizers.get(context_regularizer)
        self.context_constraint = constraints.get(context_constraint)

        self.attention_dims = attention_dims
        self.supports_masking = True

    def build(self, input_shape):
        if len(input_shape) < 3:
            raise ValueError("Expected input shape of `(..., time_steps, features)`, found `{}`".format(input_shape))

        attention_dims = input_shape[-1] if self.attention_dims is None else self.attention_dims
        self.kernel = self.add_weight(shape=(input_shape[-1], attention_dims),
                                      initializer=self.kernel_initializer,
                                      name='kernel',
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(attention_dims, ),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None

        if self.use_context:
            self.context_kernel = self.add_weight(shape=(attention_dims, ),
                                                  initializer=self.context_initializer,
                                                  name='context_kernel',
                                                  regularizer=self.context_regularizer,
                                                  constraint=self.context_constraint)
        else:
            self.context_kernel = None

        super(AttentionLayer, self).build(input_shape)

    def call(self, x, mask=None):
        # x: [..., time_steps, features]
        # ut: [..., time_steps, attention_dims]
        ut = K.dot(x, self.kernel)
        if self.use_bias:
            ut = K.bias_add(ut, self.bias)

        ut = K.tanh(ut)
        if self.use_context:
            ut = ut * self.context_kernel

        # Collapse `attention_dims` to 1. This indicates the weight for each time_step.
        ut = K.sum(ut, axis=-1, keepdims=True)

        # Convert those weights into a distribution along the time axis.
        # i.e., the sum of alphas along the `time_steps` axis should be 1.
        self.at = _softmax(ut, dim=1)
        if mask is not None:
            self.at *= K.cast(K.expand_dims(mask, -1), K.floatx())

        # Weighted sum along the `time_steps` axis.
        return K.sum(x * self.at, axis=-2)

    def compute_mask(self, input, input_mask=None):
        # Do not pass the mask to the next layers.
        return None

    def compute_output_shape(self, input_shape):
        return input_shape[0], input_shape[-1]

    def get_attention_tensor(self):
        if not hasattr(self, 'at'):
            raise ValueError('Attention tensor is only available after calling this layer with an input')
        return self.at

    def get_config(self):
        config = {
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'use_bias': self.use_bias,
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'use_context': self.use_context,
            'context_initializer': initializers.serialize(self.context_initializer),
            'context_regularizer': regularizers.serialize(self.context_regularizer),
            'context_constraint': constraints.serialize(self.context_constraint),
            'attention_dims': self.attention_dims
        }
        base_config = super(AttentionLayer, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


class ConsumeMask(Layer):
    """Layer that prevents mask propagation.
    """

    def compute_mask(self, input, input_mask=None):
        # Do not pass the mask to the next layers.
        return None

    def call(self, x, mask=None):
        return x
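

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a toy sequence-classification setup with made-up dimensions
# (vocab size 1000, sequence length 50, embedding/LSTM size 64); swap in your
# own data and model. The Embedding mask (mask_zero=True) is propagated by the
# LSTM and consumed by AttentionLayer, which reduces the time_steps axis and
# returns a `(batch, features)` tensor.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from keras.models import Model
    from keras.layers import Input, Embedding, LSTM, Dense

    inputs = Input(shape=(50,), dtype='int32')
    x = Embedding(input_dim=1000, output_dim=64, mask_zero=True)(inputs)
    x = LSTM(64, return_sequences=True)(x)      # (batch, time_steps, features)
    x = AttentionLayer(attention_dims=32)(x)    # (batch, features)
    outputs = Dense(1, activation='sigmoid')(x)

    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    model.summary()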