from keras.engine import Layer, InputSpec
from keras import backend as K
from keras import initializers
from keras import regularizers
from keras import constraints


def dot_product(x, kernel):
    """Backend-agnostic dot product of ``x`` with ``kernel``.

    Wrapper around ``K.dot`` so that the same call works with both the
    Theano and the TensorFlow backend.

    Args:
        x: input tensor.
        kernel: weight tensor.

    Returns:
        The result of ``K.dot``. On the TensorFlow backend ``kernel`` is
        given a trailing axis before the product and that axis is squeezed
        out of the result again; on any other backend ``K.dot(x, kernel)``
        is returned directly.
    """
    if K.backend() == 'tensorflow':
        return K.squeeze(K.dot(x, K.expand_dims(kernel)), axis=-1)
    return K.dot(x, kernel)


class AttentionWithContext(Layer):
    """Attention operation, with a context/query vector, for temporal data.

    Supports masking: when an input mask is supplied, masked timesteps get
    zero attention weight and the remaining weights are renormalized. The
    mask is consumed by this layer and not propagated further, because the
    time axis is summed out of the output.

    Follows the work of Yang et al.
    [https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf]
    "Hierarchical Attention Networks for Document Classification"
    by using a context vector to assist the attention.

    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.

    How to use:
    Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with
    return_sequences=True. The dimensions are inferred based on the output
    shape of the RNN.

    Note: The layer has been tested with Keras 2.0.6

    Example:
        model.add(LSTM(64, return_sequences=True))
        model.add(AttentionWithContext())
        # next add a Dense layer (for classification/regression) or whatever...
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        """Store the regularizers, constraints and bias flag.

        Args:
            W_regularizer: regularizer for the projection matrix ``W``.
            u_regularizer: regularizer for the context vector ``u``.
            b_regularizer: regularizer for the bias ``b``.
            W_constraint: constraint for ``W``.
            u_constraint: constraint for ``u``.
            b_constraint: constraint for ``b``.
            bias: whether to add a bias term before the ``tanh``.
            **kwargs: forwarded to the base ``Layer`` constructor.
        """
        super(AttentionWithContext, self).__init__(**kwargs)
        # Set after the base initializer so it cannot be reset by it.
        self.supports_masking = True

        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

    def build(self, input_shape):
        """Create ``W``, the optional bias ``b`` and the context vector ``u``.

        All three are sized from the last (feature) dimension of
        ``input_shape``, which must have exactly three entries.
        """
        assert len(input_shape) == 3, 'expected a 3D input shape'
        feature_dim = input_shape[-1]

        # shape/name are passed by keyword so the calls do not rely on the
        # positional parameter order of ``add_weight``.
        self.W = self.add_weight(shape=(feature_dim, feature_dim),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(feature_dim,),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight(shape=(feature_dim,),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        """Return ``None``: the time axis is reduced, so no mask is passed on."""
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over axis 1.

        Args:
            x: input tensor; documented input shape is
                `(samples, steps, features)`.
            mask: optional mask over the timesteps (presumably
                `(samples, steps)`, as produced by upstream Keras layers).
                Masked positions receive zero attention weight.

        Returns:
            ``K.sum(x * a, axis=1)`` where ``a`` holds the attention weights.
        """
        uit = dot_product(x, self.W)
        if self.bias:
            uit = uit + self.b
        uit = K.tanh(uit)

        ait = dot_product(uit, self.u)
        a = K.softmax(ait)

        if mask is not None:
            # Zero out masked timesteps, then renormalize so the remaining
            # weights sum to one again; epsilon guards fully masked rows.
            a = a * K.cast(mask, K.floatx())
            a = a / (K.sum(a, axis=1, keepdims=True) + K.epsilon())

        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """Return ``(input_shape[0], input_shape[-1])``."""
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the layer config, including the constructor arguments."""
        config = {
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'u_regularizer': regularizers.serialize(self.u_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'u_constraint': constraints.serialize(self.u_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
        }
        base_config = super(AttentionWithContext, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))