'''
Layers used after an RNN with return_sequences=True to summarize the sequence of
output vectors into a single sentence encoding.
'''
from keras.engine import Layer
from keras import initializations
from keras import backend as K

from keras_extensions import switch


class AveragePooling(Layer):
    '''
    This layer takes sequential output from an RNN and simply computes the average of it.
    '''
    def __init__(self, **kwargs):
        self.supports_masking = True
        super(AveragePooling, self).__init__(**kwargs)

    def compute_mask(self, input_, mask=None):  # pylint: disable=unused-argument
        # The output is a single vector per sample, so there is no mask to pass on.
        return None

    def get_output_shape_for(self, input_shape):
        return (input_shape[0], input_shape[2])

    def call(self, x, mask=None):
        # x: (batch_size, input_length, input_dim)
        if mask is None:
            return K.mean(x, axis=1)  # (batch_size, input_dim)
        else:
            # This is to remove padding from the computational graph.
            if K.ndim(mask) > K.ndim(x):
                # This is due to a bug in Bidirectional that passes along the input mask
                # instead of computing an output mask.
                # TODO: Fix the implementation of Bidirectional.
                mask = K.any(mask, axis=(-2, -1))
            if K.ndim(mask) < K.ndim(x):
                mask = K.expand_dims(mask)  # (batch_size, input_length, 1)
            masked_input = switch(mask, x, K.zeros_like(x))
            # Average each sample over its own unmasked timesteps. Summing the mask over
            # all axes (instead of per sample) would make the average depend on the
            # composition of the batch.
            mask = K.cast(mask, 'float32')
            weights = mask / (K.sum(mask, axis=1, keepdims=True) + K.epsilon())
            return K.sum(masked_input * weights, axis=1)  # (batch_size, input_dim)


class IntraAttention(AveragePooling):
    '''
    This layer returns an average of the input, weighted by how close the vector from
    each timestep is to the overall mean.
    '''
    def __init__(self, init='uniform', projection_dim=50, weights=None, **kwargs):
        self.intra_attention_weights = weights
        self.init = initializations.get(init)
        self.projection_dim = projection_dim
        super(IntraAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # pylint: disable=attribute-defined-outside-init
        input_dim = input_shape[-1]
        self.vector_projector = self.init((input_dim, self.projection_dim))
        self.mean_projector = self.init((input_dim, self.projection_dim))
        self.scorer = self.init((self.projection_dim,))
        super(IntraAttention, self).build(input_shape)
        self.trainable_weights = [self.vector_projector, self.mean_projector, self.scorer]
        if self.intra_attention_weights is not None:
            self.set_weights(self.intra_attention_weights)
            del self.intra_attention_weights

    def call(self, x, mask=None):
        # x: (batch_size, input_length, input_dim)
        # mean: (batch_size, input_dim), the (masked) average computed by AveragePooling.
        mean = super(IntraAttention, self).call(x, mask)
        ones = K.expand_dims(K.mean(K.ones_like(x), axis=(0, 2)), dim=0)  # (1, input_length)
        # Tile the mean along the time dimension: (batch_size, input_length, input_dim)
        tiled_mean = K.permute_dimensions(K.dot(K.expand_dims(mean), ones), (0, 2, 1))
        if mask is not None:
            if K.ndim(mask) > K.ndim(x):
                # Assuming this is because of the bug in Bidirectional. Temporary fix follows.
                # TODO: Fix Bidirectional.
                mask = K.any(mask, axis=(-2, -1))
            if K.ndim(mask) < K.ndim(x):
                mask = K.expand_dims(mask)
            x = switch(mask, x, K.zeros_like(x))
        # Project each timestep and the tiled mean into the same space, combine them, and
        # score the combination: (batch_size, input_length, projection_dim)
        projected_combination = K.tanh(K.dot(x, self.vector_projector) +
                                       K.dot(tiled_mean, self.mean_projector))
        scores = K.dot(projected_combination, self.scorer)  # (batch_size, input_length)
        # Note that the softmax is not masked, so some attention mass can leak to padded
        # timesteps; their contribution to the sum below is still zero because x is zeroed.
        weights = K.softmax(scores)  # (batch_size, input_length)
        attended_x = K.sum(K.expand_dims(weights) * x, axis=1)  # (batch_size, input_dim)
        return attended_x

    def get_config(self):
        config = {"init": self.init.__name__,
                  "projection_dim": self.projection_dim}
        base_config = super(IntraAttention, self).get_config()
        config.update(base_config)
        return config
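

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the library: it shows how these layers
# are meant to sit on top of an RNN with return_sequences=True. It assumes
# Keras 1.x (consistent with the `initializations` / `get_output_shape_for`
# API used above); all dimensions and variable names below are illustrative.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy
    from keras.layers import Input, Embedding, LSTM
    from keras.models import Model

    vocab_size, max_length, embedding_dim, hidden_dim = 100, 10, 8, 6

    sentence = Input(shape=(max_length,), dtype='int32')
    # mask_zero=True produces the padding mask that the pooling layers consume.
    embedded = Embedding(vocab_size, embedding_dim, mask_zero=True)(sentence)
    # return_sequences=True gives one output vector per timestep to summarize.
    encoded_words = LSTM(hidden_dim, return_sequences=True)(embedded)

    averaged = AveragePooling()(encoded_words)  # (batch_size, hidden_dim)
    attended = IntraAttention(projection_dim=5)(encoded_words)  # (batch_size, hidden_dim)

    model = Model(input=sentence, output=[averaged, attended])
    model.summary()

    # Forward pass on random, partially padded input to check output shapes.
    dummy_input = numpy.random.randint(1, vocab_size, size=(2, max_length)).astype('int32')
    dummy_input[:, -3:] = 0  # zeros are treated as padding by the Embedding mask
    average_output, attention_output = model.predict(dummy_input)
    print(average_output.shape, attention_output.shape)  # (2, 6) (2, 6)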