import tensorflow as tf
from keras import backend as K
from keras import regularizers, constraints, initializers, activations
from keras.layers.recurrent import Recurrent
from keras.engine import InputSpec

from .tdd import _time_distributed_dense

# Debug helper: wraps tensor T in a tf.Print op that emits T and its shape,
# prefixed with the message d, whenever the returned tensor is evaluated.
tfPrint = lambda d, T: tf.Print(input_=T, data=[T, tf.shape(T)], message=d)


class AttentionDecoder(Recurrent):
    """Recurrent decoder that attends over the full encoded input sequence.

    At every timestep the layer scores each input position against the
    previous hidden state, builds a context vector from those scores, updates
    a gated hidden state and emits a softmax over ``output_dim`` labels.

    references:
        Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio.
        "Neural machine translation by jointly learning to align and
        translate." arXiv preprint arXiv:1409.0473 (2014).
    """

    def __init__(self, units, output_dim,
                 activation='tanh',
                 return_probabilities=False,
                 name='AttentionDecoder',
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        """
        Implements an AttentionDecoder that takes in a sequence encoded by an
        encoder and outputs the decoded states

        :param units: dimension of the hidden state and the attention matrices
        :param output_dim: the number of labels in the output space
        :param activation: resolved with ``activations.get`` and stored on
            the layer. NOTE(review): ``step`` calls ``activations.tanh``
            directly, so this value does not appear to be consulted there.
        :param return_probabilities: if True the layer outputs the attention
            weights instead of the label distribution
        :param name: layer name, assigned after the base initializer runs
        :param kernel_initializer: initializer for the attention weights
            (V_a, W_a, U_a)
        :param recurrent_initializer: initializer for the gate, proposal,
            output and initial-state weights
        :param bias_initializer: initializer for every bias vector
        :param kernel_regularizer: regularizer for the attention weights; it
            is also used as the recurrent regularizer (see note below)
        :param bias_regularizer: regularizer for every bias vector
        :param activity_regularizer: resolved and stored on the layer
        :param kernel_constraint: constraint for the attention weights; it is
            also used as the recurrent constraint (see note below)
        :param bias_constraint: constraint for every bias vector
        :param kwargs: forwarded unchanged to ``Recurrent.__init__``

        references:
            Bahdanau, Dzmitry, Kyunghyun Cho, and Yoshua Bengio.
            "Neural machine translation by jointly learning to align and
            translate." arXiv preprint arXiv:1409.0473 (2014).
        """
        self.units = units
        self.output_dim = output_dim
        self.return_probabilities = return_probabilities
        self.activation = activations.get(activation)
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        # NOTE: the constructor exposes no separate recurrent_regularizer /
        # recurrent_constraint argument; the kernel_* values are reused so
        # the public signature stays unchanged.
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        super(AttentionDecoder, self).__init__(**kwargs)
        self.name = name
        self.return_sequences = True  # must return sequences

    def build(self, input_shape):
        """
        Create the layer weights for a 3D input shape
        (batch_size, timesteps, input_dim).

        See Appendix 2 of Bahdanau 2014, arXiv:1409.0473
        for model details that correspond to the matrices here.
        """
        self.batch_size, self.timesteps, self.input_dim = input_shape

        if self.stateful:
            super(AttentionDecoder, self).reset_states()

        self.states = [None, None]  # y, s

        # Matrices for creating the context vector
        self.V_a = self.add_weight(shape=(self.units,),
                                   name='V_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.W_a = self.add_weight(shape=(self.units, self.units),
                                   name='W_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.U_a = self.add_weight(shape=(self.input_dim, self.units),
                                   name='U_a',
                                   initializer=self.kernel_initializer,
                                   regularizer=self.kernel_regularizer,
                                   constraint=self.kernel_constraint)
        self.b_a = self.add_weight(shape=(self.units,),
                                   name='b_a',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Matrices for the r (reset) gate
        self.C_r = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_r = self.add_weight(shape=(self.units, self.units),
                                   name='U_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_r = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_r',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_r = self.add_weight(shape=(self.units, ),
                                   name='b_r',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Matrices for the z (update) gate
        self.C_z = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_z = self.add_weight(shape=(self.units, self.units),
                                   name='U_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_z = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_z',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_z = self.add_weight(shape=(self.units, ),
                                   name='b_z',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Matrices for the proposal
        self.C_p = self.add_weight(shape=(self.input_dim, self.units),
                                   name='C_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_p = self.add_weight(shape=(self.units, self.units),
                                   name='U_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_p = self.add_weight(shape=(self.output_dim, self.units),
                                   name='W_p',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_p = self.add_weight(shape=(self.units, ),
                                   name='b_p',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # Matrices for making the final prediction vector
        self.C_o = self.add_weight(shape=(self.input_dim, self.output_dim),
                                   name='C_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.U_o = self.add_weight(shape=(self.units, self.output_dim),
                                   name='U_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.W_o = self.add_weight(shape=(self.output_dim, self.output_dim),
                                   name='W_o',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)
        self.b_o = self.add_weight(shape=(self.output_dim, ),
                                   name='b_o',
                                   initializer=self.bias_initializer,
                                   regularizer=self.bias_regularizer,
                                   constraint=self.bias_constraint)

        # For creating the initial state:
        self.W_s = self.add_weight(shape=(self.input_dim, self.units),
                                   name='W_s',
                                   initializer=self.recurrent_initializer,
                                   regularizer=self.recurrent_regularizer,
                                   constraint=self.recurrent_constraint)

        self.input_spec = [
            InputSpec(shape=(self.batch_size, self.timesteps, self.input_dim))]
        self.built = True

    def call(self, x):
        """
        Cache the input sequence and its attention projection, then defer to
        ``Recurrent.call`` to run the recurrence.
        """
        # store the whole sequence so we can "attend" to it at each timestep
        self.x_seq = x

        # apply a dense layer over the time dimension of the sequence.
        # do it here because it doesn't depend on any previous steps,
        # therefore we can save computation time:
        self._uxpb = _time_distributed_dense(self.x_seq, self.U_a, b=self.b_a,
                                             input_dim=self.input_dim,
                                             timesteps=self.timesteps,
                                             output_dim=self.units)

        return super(AttentionDecoder, self).call(x)

    def get_initial_state(self, inputs):
        """
        Build the initial ``[y0, s0]`` states from the input sequence.

        ``s0`` is tanh of the first timestep projected through ``W_s``;
        ``y0`` is an all-zero tensor of shape (samples, output_dim).
        """
        # apply the matrix on the first time step to get the initial s0.
        s0 = activations.tanh(K.dot(inputs[:, 0], self.W_s))

        # from keras.layers.recurrent to initialize a vector of
        # (batchsize, output_dim)
        y0 = K.zeros_like(inputs)  # (samples, timesteps, input_dims)
        y0 = K.sum(y0, axis=(1, 2))  # (samples, )
        y0 = K.expand_dims(y0)  # (samples, 1)
        y0 = K.tile(y0, [1, self.output_dim])

        return [y0, s0]

    def step(self, x, states):
        """
        Run one decoding step.

        :param x: the current timestep slice supplied by the recurrence; it
            is not read here because attention uses the cached full sequence
        :param states: ``[ytm, stm]``, the previous output and hidden state
        :return: ``(output, [yt, st])`` where ``output`` is the attention
            weights if ``return_probabilities`` is set, otherwise ``yt``
        """
        ytm, stm = states

        # repeat the hidden state to the length of the sequence
        _stm = K.repeat(stm, self.timesteps)

        # now multiply the weight matrix with the repeated hidden state
        _Wxstm = K.dot(_stm, self.W_a)

        # calculate the attention probabilities
        # this relates how much other timesteps contributed to this one.
        et = K.dot(activations.tanh(_Wxstm + self._uxpb),
                   K.expand_dims(self.V_a))
        # Softmax is invariant to a constant shift along the normalised axis;
        # subtracting the per-sequence maximum keeps exp() from overflowing
        # to inf (and the division below from producing NaN).
        et = et - K.max(et, axis=1, keepdims=True)
        at = K.exp(et)
        at_sum = K.sum(at, axis=1)
        at_sum_repeated = K.repeat(at_sum, self.timesteps)
        at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

        # calculate the context vector
        context = K.squeeze(K.batch_dot(at, self.x_seq, axes=1), axis=1)

        # ~~~> calculate new hidden state
        # first calculate the "r" gate:
        rt = activations.sigmoid(
            K.dot(ytm, self.W_r)
            + K.dot(stm, self.U_r)
            + K.dot(context, self.C_r)
            + self.b_r)

        # now calculate the "z" gate
        zt = activations.sigmoid(
            K.dot(ytm, self.W_z)
            + K.dot(stm, self.U_z)
            + K.dot(context, self.C_z)
            + self.b_z)

        # calculate the proposal hidden state:
        s_tp = activations.tanh(
            K.dot(ytm, self.W_p)
            + K.dot((rt * stm), self.U_p)
            + K.dot(context, self.C_p)
            + self.b_p)

        # new hidden state:
        st = (1 - zt) * stm + zt * s_tp

        # the prediction is computed from the previous hidden state (stm)
        yt = activations.softmax(
            K.dot(ytm, self.W_o)
            + K.dot(stm, self.U_o)
            + K.dot(context, self.C_o)
            + self.b_o)

        if self.return_probabilities:
            return at, [yt, st]
        else:
            return yt, [yt, st]

    def compute_output_shape(self, input_shape):
        """
        For Keras internal compatibility checking
        """
        if self.return_probabilities:
            return (None, self.timesteps, self.timesteps)
        else:
            return (None, self.timesteps, self.output_dim)

    def get_config(self):
        """
        For rebuilding models on load time.

        Every constructor argument is serialized so that a reloaded layer is
        configured like the original instead of silently falling back to the
        default activation, initializers, regularizers and constraints.
        """
        config = {
            'output_dim': self.output_dim,
            'units': self.units,
            'return_probabilities': self.return_probabilities,
            'activation': activations.serialize(self.activation),
            'kernel_initializer':
                initializers.serialize(self.kernel_initializer),
            'recurrent_initializer':
                initializers.serialize(self.recurrent_initializer),
            'bias_initializer':
                initializers.serialize(self.bias_initializer),
            'kernel_regularizer':
                regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer':
                regularizers.serialize(self.bias_regularizer),
            'activity_regularizer':
                regularizers.serialize(self.activity_regularizer),
            'kernel_constraint':
                constraints.serialize(self.kernel_constraint),
            'bias_constraint':
                constraints.serialize(self.bias_constraint),
        }
        base_config = super(AttentionDecoder, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


# check to see if it compiles
if __name__ == '__main__':
    from keras.layers import Input, LSTM
    from keras.models import Model
    from keras.layers.wrappers import Bidirectional
    i = Input(shape=(100, 104), dtype='float32')
    enc = Bidirectional(LSTM(64, return_sequences=True),
                        merge_mode='concat')(i)
    dec = AttentionDecoder(32, 4)(enc)
    model = Model(inputs=i, outputs=dec)
    model.summary()