from keras.layers import Layer
from keras import backend as K
from keras import initializers


class Multi_Dim_Attention(Layer):
    """
    2D self-attention from "A Structured Self-Attentive Sentence Embedding"
    (Lin et al., 2017).

    Takes a 3D input (batch, timesteps, features) and produces `ws2` attention
    hops over the timesteps. If `punish` is True, the layer additionally
    returns the normalized (A.A^T - I) matrix used for the redundancy penalty.
    """

    def __init__(self, ws1, ws2, punish, init='glorot_normal', **kwargs):
        self.kernel_initializer = initializers.get(init)
        self.weight_ws1 = ws1
        self.weight_ws2 = ws2
        self.punish = punish
        super(Multi_Dim_Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        # Ws1: (features, ws1), Ws2: (ws1, ws2)
        self.Ws1 = self.add_weight(shape=(input_shape[-1], self.weight_ws1),
                                   initializer=self.kernel_initializer,
                                   trainable=True,
                                   name='{}_Ws1'.format(self.name))
        self.Ws2 = self.add_weight(shape=(self.weight_ws1, self.weight_ws2),
                                   initializer=self.kernel_initializer,
                                   trainable=True,
                                   name='{}_Ws2'.format(self.name))
        super(Multi_Dim_Attention, self).build(input_shape)

    def compute_mask(self, inputs, input_mask=None):
        # The output has no timestep axis left to mask.
        return None

    def call(self, x, mask=None):
        # x: (batch, timesteps, features)
        uit = K.tanh(K.dot(x, self.Ws1))            # (batch, timesteps, ws1)
        ait = K.dot(uit, self.Ws2)                  # (batch, timesteps, ws2)
        ait = K.permute_dimensions(ait, (0, 2, 1))  # (batch, ws2, timesteps)
        # Normalize each attention hop over the timesteps (last axis), as in
        # the paper; the original axis=1 normalized over the hops instead.
        A = K.softmax(ait, axis=-1)
        M = K.batch_dot(A, x)                       # (batch, ws2, features)
        if self.punish:
            A_T = K.permute_dimensions(A, (0, 2, 1))
            # Broadcast the identity instead of tiling it by batch size, so
            # the layer also works when the batch dimension is None.
            AA_T = K.batch_dot(A, A_T) - K.eye(self.weight_ws2)
            P = K.l2_normalize(AA_T, axis=(1, 2))
            return [M, P]
        return M

    def compute_output_shape(self, input_shape):
        if self.punish:
            out1 = (input_shape[0], self.weight_ws2, input_shape[-1])
            out2 = (input_shape[0], self.weight_ws2, self.weight_ws2)
            return [out1, out2]
        return (input_shape[0], self.weight_ws2, input_shape[-1])
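

# A minimal usage sketch, not part of the original layer: it assumes the layer
# sits on top of a Bidirectional LSTM encoder and feeds a flattened sentence
# embedding into a classifier. All shapes and hyper-parameters below are
# illustrative only.
if __name__ == '__main__':
    from keras.layers import Input, Embedding, Bidirectional, LSTM, Flatten, Dense
    from keras.models import Model

    inp = Input(shape=(50,), dtype='int32')                   # 50 timesteps
    emb = Embedding(input_dim=10000, output_dim=128)(inp)
    h = Bidirectional(LSTM(64, return_sequences=True))(emb)   # (batch, 50, 128)
    m = Multi_Dim_Attention(ws1=32, ws2=8, punish=False)(h)   # (batch, 8, 128)
    out = Dense(1, activation='sigmoid')(Flatten()(m))

    model = Model(inp, out)
    model.compile(optimizer='adam', loss='binary_crossentropy')
    model.summary()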