from keras import backend as K, initializers, regularizers, constraints

try:
    # Original import location used by this module.
    from keras.engine.topology import Layer
except ImportError:
    # Fall back to the public re-export where the module path above is unavailable.
    from keras.layers import Layer


def dot_product(x, kernel):
    """Contract the last axis of ``x`` with the 1-D ``kernel``.

    On the TensorFlow backend the kernel is first expanded to shape ``(d, 1)``
    and the resulting trailing size-1 axis is squeezed away afterwards; other
    backends call ``K.dot`` with the 1-D kernel directly.

    Args:
        x: input tensor whose last dimension matches ``kernel``'s length.
        kernel: 1-D weight tensor.

    Returns:
        A tensor with the last axis of ``x`` removed.
    """
    if K.backend() == 'tensorflow':
        return K.squeeze(K.dot(x, K.expand_dims(kernel)), axis=-1)
    return K.dot(x, kernel)


class Attention(Layer):
    """Feed-forward attention that pools a 3D sequence over its time axis.

    The input must be rank 3. It is treated as ``(batch, timesteps, features)``:
    scores are computed per timestep, softmax-normalized over axis 1, and used to
    take a weighted sum of the inputs over that axis. The output has shape
    ``(batch, features)``.

    Args:
        W_regularizer: regularizer (name, dict or instance) for the score vector.
        b_regularizer: regularizer for the per-timestep bias.
        W_constraint: constraint for the score vector.
        b_constraint: constraint for the per-timestep bias.
        bias: if true, add a learned bias with one entry per timestep. This
            requires the timestep dimension to be known when the layer is built.
        return_attention: if true, the layer returns ``[pooled, weights]``
            instead of only the pooled tensor.
        **kwargs: forwarded to ``Layer``.
    """

    def __init__(self,
                 W_regularizer=None, b_regularizer=None,
                 W_constraint=None, b_constraint=None,
                 bias=True,
                 return_attention=False,
                 **kwargs):
        self.supports_masking = True
        self.return_attention = return_attention
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the score vector ``W`` and, optionally, the bias ``b``.

        Raises:
            ValueError: if the input is not rank 3, or if ``bias`` is enabled
                and the timestep dimension (``input_shape[1]``) is unknown.
        """
        if len(input_shape) != 3:
            raise ValueError(
                'Attention expects a 3D input (batch, timesteps, features); '
                'got input_shape={}'.format(input_shape))
        if self.bias and input_shape[1] is None:
            # The bias has one entry per timestep, so its size must be static.
            raise ValueError(
                'Attention with bias=True needs a known timestep dimension; '
                'got input_shape={}'.format(input_shape))

        # Pass shape/name as keywords: positional order differs across Keras versions.
        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        """Consume the incoming mask: the time axis is pooled away.

        Returns one ``None`` per output, i.e. ``[None, None]`` when
        ``return_attention`` is set and ``None`` otherwise.
        """
        # Parameter names kept as-is for backward compatibility with callers.
        if self.return_attention:
            return [None, None]
        return None

    def call(self, x, mask=None):
        """Pool ``x`` over axis 1 with learned attention weights.

        Args:
            x: rank-3 input tensor.
            mask: optional mask multiplied into the weights; presumably shaped
                ``(batch, timesteps)`` (TODO confirm against upstream layers).

        Returns:
            The pooled tensor, or ``[pooled, weights]`` if ``return_attention``.
        """
        eij = dot_product(x, self.W)

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        # Softmax over timesteps; exp cannot overflow because tanh bounds eij to [-1, 1].
        a = K.exp(eij)

        # Zero masked timesteps *before* normalizing so they receive no weight.
        if mask is not None:
            a *= K.cast(mask, K.floatx())

        # epsilon keeps the division finite when every timestep is masked.
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        weighted_input = x * K.expand_dims(a)
        result = K.sum(weighted_input, axis=1)

        if self.return_attention:
            return [result, a]
        return result

    def compute_output_shape(self, input_shape):
        """Drop axis 1; with ``return_attention`` also report the weights' shape."""
        if self.return_attention:
            return [(input_shape[0], input_shape[-1]),
                    (input_shape[0], input_shape[1])]
        return input_shape[0], input_shape[-1]

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and reloaded."""
        config = {
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
            'return_attention': self.return_attention,
        }
        base_config = super(Attention, self).get_config()
        return dict(base_config, **config)