'''
recurrent tree networks

author: bcm
'''
from __future__ import absolute_import, print_function

from keras.layers import Recurrent, time_distributed_dense, LSTM
import keras.backend as K
from keras import activations, initializations, regularizers
from keras.engine import Layer, InputSpec
import ikelos.backend.theano_backend as IKE
import numpy as np


class DualCurrent(Recurrent):
    '''
    Modified from Keras's LSTM; the recurrent tree network.
    '''
    def __init__(self, output_dim,
                 init='glorot_uniform', inner_init='orthogonal',
                 forget_bias_init='one', activation='tanh',
                 inner_activation='hard_sigmoid',
                 W_regularizer=None, U_regularizer=None, b_regularizer=None,
                 dropout_W=0., dropout_U=0., **kwargs):
        self.output_dim = output_dim
        self.init = initializations.get(init)
        self.inner_init = initializations.get(inner_init)
        self.forget_bias_init = initializations.get(forget_bias_init)
        self.activation = activations.get(activation)
        self.inner_activation = activations.get(inner_activation)
        self.W_regularizer = regularizers.get(W_regularizer)
        self.U_regularizer = regularizers.get(U_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.dropout_W, self.dropout_U = dropout_W, dropout_U
        if self.dropout_W or self.dropout_U:
            self.uses_learning_phase = True
        super(DualCurrent, self).__init__(**kwargs)

    def get_initial_states(self, x):
        # build an all-zero tensor of shape (timesteps, samples, output_dim)
        initial_state = K.zeros_like(x)                                 # (samples, timesteps, input_dim)
        initial_state = K.permute_dimensions(initial_state, [1, 0, 2])  # (timesteps, samples, input_dim)
        reducer = K.zeros((self.input_dim, self.output_dim))
        initial_state = K.dot(initial_state, reducer)                   # (timesteps, samples, output_dim)
        initial_states = [initial_state for _ in range(len(self.states))]
        return initial_states

    def build(self, input_shapes):
        assert isinstance(input_shapes, list)
        rnn_shape, indices_shape = input_shapes
        self.input_spec = [InputSpec(shape=rnn_shape),
                           InputSpec(shape=indices_shape)]
        input_dim = rnn_shape[2]
        self.input_dim = input_dim

        if self.stateful:
            self.reset_states()
        else:
            # initial states: 2 all-zero tensors of shape (output_dim)
            self.states = [None, None]

        # add a second incoming recurrent connection ('me' and 'other')
        self.W_i = self.init((input_dim, self.output_dim),
                             name='{}_W_i'.format(self.name))
        self.U_i_me = self.inner_init((self.output_dim, self.output_dim),
                                      name='{}_U_i_me'.format(self.name))
        self.U_i_other = self.inner_init((self.output_dim, self.output_dim),
                                         name='{}_U_i_other'.format(self.name))
        self.b_i = K.zeros((self.output_dim,), name='{}_b_i'.format(self.name))

        self.W_f = self.init((input_dim, self.output_dim),
                             name='{}_W_f'.format(self.name))
        self.U_f_me = self.inner_init((self.output_dim, self.output_dim),
                                      name='{}_U_f_me'.format(self.name))
        self.U_f_other = self.inner_init((self.output_dim, self.output_dim),
                                         name='{}_U_f_other'.format(self.name))
        self.b_f = self.forget_bias_init((self.output_dim,),
                                         name='{}_b_f'.format(self.name))

        self.W_c = self.init((input_dim, self.output_dim),
                             name='{}_W_c'.format(self.name))
        self.U_c_me = self.inner_init((self.output_dim, self.output_dim),
                                      name='{}_U_c_me'.format(self.name))
        self.U_c_other = self.inner_init((self.output_dim, self.output_dim),
                                         name='{}_U_c_other'.format(self.name))
        self.b_c = K.zeros((self.output_dim,), name='{}_b_c'.format(self.name))

        self.W_o = self.init((input_dim, self.output_dim),
                             name='{}_W_o'.format(self.name))
        self.U_o_me = self.inner_init((self.output_dim, self.output_dim),
                                      name='{}_U_o_me'.format(self.name))
        self.U_o_other = self.inner_init((self.output_dim, self.output_dim),
                                         name='{}_U_o_other'.format(self.name))
        self.b_o = K.zeros((self.output_dim,), name='{}_b_o'.format(self.name))

        self.regularizers = []
        if self.W_regularizer:
            self.W_regularizer.set_param(K.concatenate([self.W_i,
                                                        self.W_f,
                                                        self.W_c,
                                                        self.W_o]))
            self.regularizers.append(self.W_regularizer)
        if self.U_regularizer:
            self.U_regularizer.set_param(K.concatenate([self.U_i_me, self.U_i_other,
                                                        self.U_f_me, self.U_f_other,
                                                        self.U_c_me, self.U_c_other,
                                                        self.U_o_me, self.U_o_other]))
            self.regularizers.append(self.U_regularizer)
        if self.b_regularizer:
            self.b_regularizer.set_param(K.concatenate([self.b_i,
                                                        self.b_f,
                                                        self.b_c,
                                                        self.b_o]))
            self.regularizers.append(self.b_regularizer)

        self.trainable_weights = [self.W_i, self.U_i_me, self.U_i_other, self.b_i,
                                  self.W_c, self.U_c_me, self.U_c_other, self.b_c,
                                  self.W_f, self.U_f_me, self.U_f_other, self.b_f,
                                  self.W_o, self.U_o_me, self.U_o_other, self.b_o]

        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def reset_states(self):
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape
        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete '
                            'input_shape must be provided '
                            '(including batch size).')
        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[1], input_shape[0], self.output_dim)))
            K.set_value(self.states[1],
                        np.zeros((input_shape[1], input_shape[0], self.output_dim)))
        else:
            self.states = [K.zeros((input_shape[1], input_shape[0], self.output_dim)),
                           K.zeros((input_shape[1], input_shape[0], self.output_dim))]

    def compute_mask(self, input, mask):
        if self.return_sequences:
            if isinstance(mask, list):
                return [mask[0], mask[0]]
            return [mask, mask]
        else:
            return [None, None]

    def get_output_shape_for(self, input_shapes):
        rnn_shape, indices_shape = input_shapes
        out_shape = super(DualCurrent, self).get_output_shape_for(rnn_shape)
        return [out_shape, out_shape]

    def preprocess_input(self, x):
        if self.consume_less == 'cpu':
            if 0 < self.dropout_W < 1:
                dropout = self.dropout_W
            else:
                dropout = 0
            input_shape = self.input_spec[0].shape
            input_dim = input_shape[2]
            timesteps = input_shape[1]

            x_i = time_distributed_dense(x, self.W_i, self.b_i, dropout,
                                         input_dim, self.output_dim, timesteps)
            x_f = time_distributed_dense(x, self.W_f, self.b_f, dropout,
                                         input_dim, self.output_dim, timesteps)
            x_c = time_distributed_dense(x, self.W_c, self.b_c, dropout,
                                         input_dim, self.output_dim, timesteps)
            x_o = time_distributed_dense(x, self.W_o, self.b_o, dropout,
                                         input_dim, self.output_dim, timesteps)
            return K.concatenate([x_i, x_f, x_c, x_o], axis=2)
        else:
            return x
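    # When consume_less == 'cpu', preprocess_input() precomputes the four
    # input-to-hidden projections (i, f, c, o) for every timestep and
    # concatenates them along the feature axis, so step() only slices out four
    # output_dim-sized blocks instead of redoing the matmuls inside the scan.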
    def step(self, x, states):
        (h_tm1_me, h_tm1_other) = states[0]
        (c_tm1_me, c_tm1_other) = states[1]
        B_U = states[2]
        B_W = states[3]

        if self.consume_less == 'cpu':
            x_i = x[:, :self.output_dim]
            x_f = x[:, self.output_dim: 2 * self.output_dim]
            x_c = x[:, 2 * self.output_dim: 3 * self.output_dim]
            x_o = x[:, 3 * self.output_dim:]
        else:
            x_i = K.dot(x * B_W[0], self.W_i) + self.b_i
            x_f = K.dot(x * B_W[1], self.W_f) + self.b_f
            x_c = K.dot(x * B_W[2], self.W_c) + self.b_c
            x_o = K.dot(x * B_W[3], self.W_o) + self.b_o

        i = self.inner_activation(x_i +
                                  K.dot(h_tm1_me * B_U[0], self.U_i_me) +
                                  K.dot(h_tm1_other * B_U[0], self.U_i_other))
        f_me = self.inner_activation(x_f +
                                     K.dot(h_tm1_me * B_U[1], self.U_f_me) +
                                     K.dot(h_tm1_other * B_U[1], self.U_f_me))
        f_other = self.inner_activation(x_f +
                                        K.dot(h_tm1_me * B_U[1], self.U_f_other) +
                                        K.dot(h_tm1_other * B_U[1], self.U_f_other))
        in_c = i * self.activation(x_c +
                                   K.dot(h_tm1_me * B_U[2], self.U_c_me) +
                                   K.dot(h_tm1_other * B_U[2], self.U_c_other))
        re_c = f_me * c_tm1_me + f_other * c_tm1_other
        c = in_c + re_c
        o = self.inner_activation(x_o +
                                  K.dot(h_tm1_me * B_U[3], self.U_o_me) +
                                  K.dot(h_tm1_other * B_U[3], self.U_o_other))
        h = o * self.activation(c)
        return h, [h, c]
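    # For reference, with the default activations the update in step() amounts
    # to the following (dropout masks omitted; '.' is a matrix product, '*' is
    # elementwise, and x_g = x . W_g + b_g for each gate g):
    #
    #   i       = hard_sigmoid(x_i + h_me . U_i_me    + h_other . U_i_other)
    #   f_me    = hard_sigmoid(x_f + h_me . U_f_me    + h_other . U_f_me)
    #   f_other = hard_sigmoid(x_f + h_me . U_f_other + h_other . U_f_other)
    #   c_tilde = tanh(x_c + h_me . U_c_me + h_other . U_c_other)
    #   c       = i * c_tilde + f_me * c_me + f_other * c_other
    #   o       = hard_sigmoid(x_o + h_me . U_o_me + h_other . U_o_other)
    #   h       = o * tanh(c)
    #
    # i.e. an LSTM cell that receives two incoming (h, c) pairs ('me' and
    # 'other'), with a separate forget gate for each incoming cell state.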
    def call(self, xpind, mask=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        # note that the .build() method of subclasses MUST define
        # self.input_spec with a complete input shape.
        x, indices = xpind
        if isinstance(mask, list):
            mask, _ = mask
        input_shape = self.input_spec[0].shape
        if K._BACKEND == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis. '
                                'Found input shape at layer ' + self.name +
                                ': ' + str(input_shape))
        if self.stateful:
            initial_states = self.states
        else:
            initial_states = self.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.preprocess_input(x)

        last_output, outputs, states = IKE.dualsignal_rnn(self.step,
                                                          preprocessed_input,
                                                          initial_states,
                                                          indices,
                                                          go_backwards=self.go_backwards,
                                                          mask=mask,
                                                          constants=constants,
                                                          unroll=self.unroll,
                                                          input_length=input_shape[1])
        last_tree, last_summary = last_output
        tree_outputs, summary_outputs = outputs

        if self.stateful:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.states[i], states[i]))
        self.cached_states = states

        return [tree_outputs, summary_outputs]

    def get_constants(self, x):
        constants = []
        if 0 < self.dropout_U < 1:
            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
            ones = K.concatenate([ones] * self.output_dim, 1)
            B_U = [K.in_train_phase(K.dropout(ones, self.dropout_U), ones)
                   for _ in range(4)]
            constants.append(B_U)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(4)])

        if 0 < self.dropout_W < 1:
            input_shape = self.input_spec[0].shape
            input_dim = input_shape[-1]
            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
            ones = K.concatenate([ones] * input_dim, 1)
            B_W = [K.in_train_phase(K.dropout(ones, self.dropout_W), ones)
                   for _ in range(4)]
            constants.append(B_W)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(4)])
        return constants

    def get_config(self):
        config = {"output_dim": self.output_dim,
                  "init": self.init.__name__,
                  "inner_init": self.inner_init.__name__,
                  "forget_bias_init": self.forget_bias_init.__name__,
                  "activation": self.activation.__name__,
                  "inner_activation": self.inner_activation.__name__,
                  "W_regularizer": self.W_regularizer.get_config() if self.W_regularizer else None,
                  "U_regularizer": self.U_regularizer.get_config() if self.U_regularizer else None,
                  "b_regularizer": self.b_regularizer.get_config() if self.b_regularizer else None,
                  "dropout_W": self.dropout_W,
                  "dropout_U": self.dropout_U}
        base_config = super(DualCurrent, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
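# ---------------------------------------------------------------------------
# Minimal usage sketch for DualCurrent (illustrative only, not part of the
# layer). The tensor shapes and the integer dtype of the `indices` input are
# assumptions here: the layer expects a (batch, timesteps, input_dim) sequence
# plus an index sequence that ikelos' dualsignal_rnn uses to route the second
# recurrent connection, and it returns [tree_sequence, summary_sequence].
# Requires Keras 1.x with the Theano backend and the ikelos package.
# ---------------------------------------------------------------------------
def _dualcurrent_example(batch_size=16, timesteps=20, input_dim=50):
    from keras.layers import Input
    from keras.models import Model
    seq = Input(batch_shape=(batch_size, timesteps, input_dim))
    idx = Input(batch_shape=(batch_size, timesteps), dtype='int32')
    tree_seq, summary_seq = DualCurrent(64, return_sequences=True)([seq, idx])
    return Model(input=[seq, idx], output=[tree_seq, summary_seq])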
class BranchLSTM(LSTM):
    def build(self, input_shapes):
        assert isinstance(input_shapes, list)
        rnn_shape, indices_shape = input_shapes
        super(BranchLSTM, self).build(rnn_shape)
        self.input_spec += [InputSpec(shape=indices_shape)]

    def get_initial_states(self, x):
        # build an all-zero tensor of shape (timesteps, samples, output_dim)
        initial_state = K.zeros_like(x)                                 # (samples, timesteps, input_dim)
        initial_state = K.permute_dimensions(initial_state, [1, 0, 2])  # (timesteps, samples, input_dim)
        reducer = K.zeros((self.input_dim, self.output_dim))
        initial_state = K.dot(initial_state, reducer)                   # (timesteps, samples, output_dim)
        initial_states = [initial_state for _ in range(len(self.states))]
        return initial_states

    def reset_states(self):
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape
        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete '
                            'input_shape must be provided '
                            '(including batch size).')
        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[1], input_shape[0], self.output_dim)))
            K.set_value(self.states[1],
                        np.zeros((input_shape[1], input_shape[0], self.output_dim)))
        else:
            self.states = [K.zeros((input_shape[1], input_shape[0], self.output_dim)),
                           K.zeros((input_shape[1], input_shape[0], self.output_dim))]

    def get_output_shape_for(self, input_shapes):
        rnn_shape, indices_shape = input_shapes
        return super(BranchLSTM, self).get_output_shape_for(rnn_shape)

    def compute_mask(self, input, mask):
        if self.return_sequences:
            if isinstance(mask, list):
                return mask[0]
            return mask
        else:
            return None

    def call(self, xpind, mask=None):
        # input shape: (nb_samples, time (padded with zeros), input_dim)
        # note that the .build() method of subclasses MUST define
        # self.input_spec with a complete input shape.
        x, indices = xpind
        if isinstance(mask, list):
            mask, _ = mask
        input_shape = self.input_spec[0].shape
        if K._BACKEND == 'tensorflow':
            if not input_shape[1]:
                raise Exception('When using TensorFlow, you should define '
                                'explicitly the number of timesteps of '
                                'your sequences.\n'
                                'If your first layer is an Embedding, '
                                'make sure to pass it an "input_length" '
                                'argument. Otherwise, make sure '
                                'the first layer has '
                                'an "input_shape" or "batch_input_shape" '
                                'argument, including the time axis. '
                                'Found input shape at layer ' + self.name +
                                ': ' + str(input_shape))
        if self.stateful:
            initial_states = self.states
        else:
            initial_states = self.get_initial_states(x)
        constants = self.get_constants(x)
        preprocessed_input = self.preprocess_input(x)

        last_output, outputs, states = IKE.stack_rnn(self.step,
                                                     preprocessed_input,
                                                     initial_states,
                                                     indices,
                                                     go_backwards=self.go_backwards,
                                                     mask=mask,
                                                     constants=constants,
                                                     unroll=self.unroll,
                                                     input_length=input_shape[1])
        if self.stateful:
            self.updates = []
            for i in range(len(states)):
                self.updates.append((self.states[i], states[i]))
        self.cached_states = states

        if self.return_sequences:
            return outputs
        else:
            return last_output
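# Note on BranchLSTM above: it is a standard Keras LSTM except that build/call
# take [sequence, indices] and the recurrence is driven by ikelos' stack_rnn,
# which (as assumed here) uses the integer indices to decide which earlier
# state each timestep continues from, so the hidden state follows a tree
# branch rather than the flat left-to-right order. Unlike DualCurrent, it
# returns a single output, exactly like LSTM.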
class RTTN(Recurrent):
    '''Recurrent Tree Traversal Network

    # Arguments
        See GRU

    # Notes
        -
    '''
    def __init__(self, output_dim,
                 init='glorot_uniform', inner_init='orthogonal',
                 activation='tanh', inner_activation='hard_sigmoid',
                 W_regularizer=None, U_regularizer=None, b_regularizer=None,
                 shape_key=None, dropout_W=0., dropout_U=0., **kwargs):
        self.output_dim = output_dim
        self.init = initializations.get(init)
        self.inner_init = initializations.get(inner_init)
        self.activation = activations.get(activation)
        self.inner_activation = activations.get(inner_activation)
        self.W_regularizer = regularizers.get(W_regularizer)
        self.U_regularizer = regularizers.get(U_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)
        self.dropout_W, self.dropout_U = dropout_W, dropout_U
        self.shape_key = shape_key or {}
        if self.dropout_W or self.dropout_U:
            self.uses_learning_phase = True
        kwargs['consume_less'] = 'gpu'
        super(RTTN, self).__init__(**kwargs)
        self.num_actions = 4

    def compute_mask(self, input, mask):
        if self.return_sequences:
            if isinstance(mask, list):
                return [mask[0] for _ in range(4)]
            return [mask for _ in range(4)]
        else:
            return [None, None, None, None]

    def get_output_shape_for(self, input_shapes):
        '''Given all inputs, compute the output shape for all outputs.

        Shape conventions:
        - normal in shapes are (batch, sequence, in_size)
        - normal out shapes are (batch, sequence, out_size)
        - horizon is (batch, sequence, horizon, features), where features is
          the concatenated branch and word feature vectors
        - p_horizon is (batch, sequence, horizon)
        '''
        in_shape = input_shapes[0]
        out_shape = super(RTTN, self).get_output_shape_for(in_shape)
        b, s, fin = in_shape
        b, s, fout = out_shape
        w = self.shape_key['word']
        h = self.shape_key['horizon']
        horizon_shape = (b, s, h, w + fout)
        p_horizon_shape = (b, s, h)
        return [out_shape, out_shape, horizon_shape, p_horizon_shape]

    def build(self, input_shapes):
        assert isinstance(input_shapes, list)
        rnn_shape, indices_shape = input_shapes[0], input_shapes[1]
        self.input_spec = [InputSpec(shape=rnn_shape),
                           InputSpec(shape=indices_shape)]
        self.input_spec += [InputSpec(shape=None)
                            for _ in range(len(input_shapes) - 2)]
        self.input_dim = rnn_shape[2]

        # initial states: all-zero tensors of shape (output_dim)
        self.states = [None, None]

        assert self.consume_less == 'gpu'

        # NOTE: the 4 here is for the 4 action types: sub/ins, left/right.
        self.W_x = self.init((self.num_actions, self.input_dim, 4 * self.output_dim),
                             name='{}_W_x'.format(self.name))
        self.b_x = K.variable(np.zeros(4 * self.output_dim),
                              name='{}_b_x'.format(self.name))

        # used for parent node and traversal node recurrence computations
        self.U_p = self.inner_init((self.output_dim, 3 * self.output_dim),
                                   name='{}_U_p'.format(self.name))
        self.U_v = self.inner_init((self.output_dim, 3 * self.output_dim),
                                   name='{}_U_v'.format(self.name))

        # used for the child node computation
        self.U_c = self.init((self.output_dim, 3 * self.output_dim),
                             name='{}_U_c'.format(self.name))
        self.b_c = K.variable(np.zeros(3 * self.output_dim),
                              name='{}_b_c'.format(self.name))

        self.W_ctx = self.init((self.output_dim, self.shape_key['word'] + self.output_dim),
                               name='{}_W_context'.format(self.name))

        self.trainable_weights = [self.W_x, self.U_c, self.U_p, self.U_v,
                                  self.b_x, self.b_c, self.W_ctx]

    def reset_states(self):
        assert self.stateful, 'Layer must be stateful.'
        input_shape = self.input_spec[0].shape
        if not input_shape[0]:
            raise Exception('If a RNN is stateful, a complete '
                            'input_shape must be provided '
                            '(including batch size).')
        if hasattr(self, 'states'):
            K.set_value(self.states[0],
                        np.zeros((input_shape[0], self.output_dim)))
            K.set_value(self.states[1],
                        np.zeros((input_shape[1], input_shape[0], self.output_dim)))
        else:
            self.states = [K.zeros((input_shape[0], self.output_dim)),
                           K.zeros((input_shape[1], input_shape[0], self.output_dim))]

    def get_initial_states(self, x):
        # build an all-zero tensor of shape (timesteps, samples, output_dim)
        initial_state = K.zeros_like(x)                                 # (samples, timesteps, input_dim)
        initial_state = K.permute_dimensions(initial_state, [1, 0, 2])  # (timesteps, samples, input_dim)
        reducer = K.zeros((self.input_dim, self.output_dim))
        initial_state = K.dot(initial_state, reducer)                   # (timesteps, samples, output_dim)
        initial_traversal = K.sum(initial_state, axis=0)                # traversal is (samples, output_dim)
        # this order matches assumptions in the rttn scan function
        initial_states = [initial_traversal, initial_state]
        return initial_states
    def step(self, x, states):
        (h_p,            # 0: parent state
         h_v,            # 1: traversal state
         x_type,         # 2: tree action type (ins/sub, left/right); ints of size (B,), in {0, 1, 2, 3}
         B_U,            # 3: U dropout mask
         B_W) = states   # 4: W dropout mask

        # matrix_x has all 4 x computations in it, selected per move
        this_Wx = self.W_x[x_type]                                  # (B, I, 4*O)
        matrix_x = K.batch_dot(x * B_W[0], this_Wx) + self.b_x
        x_zp = matrix_x[:, :self.output_dim]
        x_rp = matrix_x[:, self.output_dim: 2 * self.output_dim]
        x_rv = matrix_x[:, 2 * self.output_dim: 3 * self.output_dim]
        x_ih = matrix_x[:, 3 * self.output_dim:]

        # matrix_p has zp, rp; matrix_v has zv, rv
        matrix_p = K.dot(h_p * B_U[0], self.U_p[:, :2 * self.output_dim])
        # zp is for the parent unit update (resulting in the child unit)
        inner_zp = matrix_p[:, :self.output_dim]
        z_p = self.inner_activation(x_zp + inner_zp)
        # rp is for gating to the intermediate unit of the parent
        inner_rp = matrix_p[:, self.output_dim: 2 * self.output_dim]
        r_p = self.inner_activation(x_rp + inner_rp)

        matrix_v = K.dot(h_v * B_U[0], self.U_v[:, :2 * self.output_dim])
        # rv is for the intermediate gate on the traversal unit;
        # this gets reused for both the parent's and its own intermediate
        inner_rv = matrix_v[:, self.output_dim: 2 * self.output_dim]
        r_v = self.inner_activation(x_rv + inner_rv)

        # the actual recurrence calculations:
        # h_p * U and h_v * U, as gated by their r gates
        inner_hp = K.dot(r_p * h_p * B_U[0], self.U_p[:, 2 * self.output_dim:])
        inner_hv = K.dot(r_v * h_v * B_U[0], self.U_v[:, 2 * self.output_dim:])
        # h_c_tilde is the intermediate state
        h_c_tilde = self.activation(x_ih + inner_hp + inner_hv)
        # h_c is the new child state
        h_c = z_p * h_c_tilde + (1 - z_p) * h_p

        matrix_c = K.dot(h_c * B_U[0], self.U_c) + self.b_c
        hc_zv = matrix_c[:, :self.output_dim]
        hc_rv = matrix_c[:, self.output_dim: 2 * self.output_dim]
        hc_ih = matrix_c[:, 2 * self.output_dim:]
        # zv -> gate h_v and h_v_tilde
        # rv -> gate h_v's contribution to h_v_tilde
        # ih -> h_c's contribution to h_v_tilde

        # zv is for the traversal unit update
        inner_zv = matrix_v[:, :self.output_dim]
        z_v = self.inner_activation(hc_zv + inner_zv)
        # r_v is recalculated with h_c rather than x
        r_v = self.inner_activation(hc_rv + inner_rv)
        inner_hvplus = K.dot(r_v * h_v * B_U[0], self.U_v[:, 2 * self.output_dim:])
        h_vplus_tilde = self.activation(hc_ih + inner_hvplus)
        h_vplus = z_v * h_v + (1 - z_v) * h_vplus_tilde

        return h_c, h_vplus
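    # For reference, with the default activations step() performs two chained
    # GRU-style updates (dropout masks omitted; '.' is a matrix product, '*'
    # is elementwise, O = output_dim; x_zp, x_rp, x_rv, x_ih are the
    # action-specific input projections taken from W_x[x_type]):
    #
    #   child update (parent h_p -> child h_c):
    #     z_p   = hard_sigmoid(x_zp + h_p . U_p[:, :O])
    #     r_p   = hard_sigmoid(x_rp + h_p . U_p[:, O:2O])
    #     r_v   = hard_sigmoid(x_rv + h_v . U_v[:, O:2O])
    #     h_c~  = tanh(x_ih + (r_p * h_p) . U_p[:, 2O:] + (r_v * h_v) . U_v[:, 2O:])
    #     h_c   = z_p * h_c~ + (1 - z_p) * h_p
    #
    #   traversal update (h_v -> h_v+), driven by the new child h_c:
    #     [c_zv, c_rv, c_ih] = h_c . U_c + b_c
    #     z_v   = hard_sigmoid(c_zv + h_v . U_v[:, :O])
    #     r_v'  = hard_sigmoid(c_rv + h_v . U_v[:, O:2O])
    #     h_v+~ = tanh(c_ih + (r_v' * h_v) . U_v[:, 2O:])
    #     h_v+  = z_v * h_v + (1 - z_v) * h_v+~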
    def call(self, all_inputs, mask=None):
        x_in, topology, x_types, horizon_w, horizon_i = all_inputs
        horizon = [horizon_w, horizon_i]
        if isinstance(mask, list):
            mask = mask[0]
        assert not self.stateful
        initial_states = self.get_initial_states(x_in)
        constants = self.get_constants(x_in)

        states = IKE.rttn(self.step, x_in, initial_states,
                          topology, x_types, horizon,
                          self.shape_key, self.W_ctx,
                          mask=mask, constants=constants)
        branch_tensor, traversal_tensor, horizon_states, p_horizons = states

        return [branch_tensor, traversal_tensor, horizon_states, p_horizons]

    def get_constants(self, x):
        constants = []
        if 0 < self.dropout_U < 1:
            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
            ones = K.concatenate([ones] * self.output_dim, 1)
            B_U = [K.in_train_phase(K.dropout(ones, self.dropout_U), ones)
                   for _ in range(3)]
            constants.append(B_U)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(3)])

        if 0 < self.dropout_W < 1:
            input_shape = self.input_spec[0].shape
            input_dim = input_shape[-1]
            ones = K.ones_like(K.reshape(x[:, 0, 0], (-1, 1)))
            ones = K.concatenate([ones] * input_dim, 1)
            B_W = [K.in_train_phase(K.dropout(ones, self.dropout_W), ones)
                   for _ in range(3)]
            constants.append(B_W)
        else:
            constants.append([K.cast_to_floatx(1.) for _ in range(3)])
        return constants

    def get_config(self):
        config = {'output_dim': self.output_dim,
                  'init': self.init.__name__,
                  'inner_init': self.inner_init.__name__,
                  'activation': self.activation.__name__,
                  'inner_activation': self.inner_activation.__name__,
                  'W_regularizer': self.W_regularizer.get_config() if self.W_regularizer else None,
                  'U_regularizer': self.U_regularizer.get_config() if self.U_regularizer else None,
                  'b_regularizer': self.b_regularizer.get_config() if self.b_regularizer else None,
                  'dropout_W': self.dropout_W,
                  'dropout_U': self.dropout_U}
        base_config = super(RTTN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
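# ---------------------------------------------------------------------------
# Minimal usage sketch for RTTN (illustrative only, not part of the layer).
# The exact shapes and dtypes of the topology, node-type, and horizon inputs
# are assumptions based on the shape bookkeeping in get_output_shape_for():
# `shape_key` must provide the word-feature size and the horizon length, and
# the layer returns [branch, traversal, horizon, p_horizon] tensors. Requires
# Keras 1.x, the Theano backend, and the ikelos package.
# ---------------------------------------------------------------------------
def _rttn_example(batch_size=16, timesteps=20, input_dim=50,
                  word_dim=30, horizon=10, output_dim=64):
    from keras.layers import Input
    from keras.models import Model
    x_in = Input(batch_shape=(batch_size, timesteps, input_dim))
    topology = Input(batch_shape=(batch_size, timesteps), dtype='int32')
    x_types = Input(batch_shape=(batch_size, timesteps), dtype='int32')
    horizon_w = Input(batch_shape=(batch_size, timesteps, horizon, word_dim))
    horizon_i = Input(batch_shape=(batch_size, timesteps, horizon), dtype='int32')
    rttn = RTTN(output_dim,
                shape_key={'word': word_dim, 'horizon': horizon},
                return_sequences=True)
    branch, traversal, horizon_states, p_horizon = rttn(
        [x_in, topology, x_types, horizon_w, horizon_i])
    return Model(input=[x_in, topology, x_types, horizon_w, horizon_i],
                 output=[branch, traversal, horizon_states, p_horizon])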