# -*- coding: utf-8 -*-
#
# Copyright 2018 Pascual Martinez-Gomez
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

import numpy as np

from gather import gather3
from gather import gather_output_shape3

# The numpy seed is fixed before Keras is imported (order kept from the
# original file).
seed = 23
np.random.seed(seed=seed)

import keras.backend as K
from keras import initializers
from keras.layers import Input, Dense, TimeDistributed
from keras.layers import Reshape
from keras.layers import Permute
from keras.layers import RepeatVector
from keras.layers import Add, Multiply
from keras.layers import GlobalMaxPooling1D
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.layers import Activation


def make_gather_layer():
    """Return a Lambda layer wrapping `gather3` / `gather_output_shape3`."""
    return Lambda(gather3, output_shape=gather_output_shape3)


# Single module-level gather layer, reused by every call to tp1_node_update.
gather_layer = make_gather_layer()


# TODO: make a class and have hyperparameters as class attributes.
def tp1_node_update(graph_node_embs, node_rel, node_rel_weight, max_nodes,
                    max_bi_relations, embed_dim, label):
    """Apply one relation-based update to the node embeddings of a graph.

    Args:
        graph_node_embs: tensor of shape
            (batch_size, max_nodes per graph, embed_dim feats).
        node_rel: int tensor fed to the gather layer together with
            `graph_node_embs`. Callers in this file pass shape
            (batch, max_nodes, max_bi_relations, 2); presumably node index
            pairs -- confirm against `gather3`.
        node_rel_weight: float tensor of shape
            (batch, max_nodes, max_bi_relations) that scales each relation
            before the relations of a node are summed.
        max_nodes: maximum number of nodes per graph.
        max_bi_relations: maximum number of binary relations per node.
        embed_dim: size of a node embedding; also used as the width of the
            two dense layers.
        label: prefix for the names of the layers created here.

    Returns:
        Tensor of shape (batch, max_nodes, embed_dim).
    """
    dense_dim = embed_dim

    x = gather_layer([graph_node_embs, node_rel])
    logging.debug('After gather3 shape: %s', x.shape)
    # Flatten (node, relation) into a single axis so that the dense layers
    # are applied independently to every relation.
    x = Reshape((max_nodes * max_bi_relations, 2 * embed_dim))(x)

    x = TimeDistributed(
        Dense(
            dense_dim,
            kernel_initializer=initializers.Ones(),
            bias_initializer=initializers.Zeros(),
            name=label + '_dense1'))(x)
    # TODO: re-enable the batch normalization.
    # x = BatchNormalization(axis=2, name=label + '_bn1')(x)
    x = Activation('relu')(x)

    x = TimeDistributed(
        Dense(
            dense_dim,
            kernel_initializer=initializers.Ones(),
            bias_initializer=initializers.Zeros(),
            name=label + '_dense2'))(x)
    # x = BatchNormalization(axis=2, name=label + '_bn2')(x)
    x = Activation('relu')(x)

    # Broadcast the per-relation weight across the feature axis:
    # (batch, N, R) -> (batch, N*R) -> (batch, D, N*R) -> (batch, N*R, D).
    normalizer = Reshape((max_nodes * max_bi_relations,))(node_rel_weight)
    normalizer = RepeatVector(dense_dim)(normalizer)
    normalizer = Permute((2, 1))(normalizer)
    x = Multiply()([x, normalizer])

    # Restore the (node, relation) axes and sum over the relation axis.
    x = Reshape((max_nodes, max_bi_relations, dense_dim))(x)
    x = Lambda(
        lambda xin: K.sum(xin, axis=2),
        # A tuple output_shape excludes the batch axis; summing over the
        # relation axis leaves (max_nodes, dense_dim) per sample.
        output_shape=(max_nodes, dense_dim),
        name=label + '_integrate')(x)
    return x


# TODO: Dense use_bias=True
def make_pair_branch(graph_node_embs, max_nodes, max_bi_relations,
                     label='child', embed_dim=2, num_updates=1):
    """Build the relation inputs and node-update stack for one branch.

    Args:
        graph_node_embs: node embeddings tensor, see `tp1_node_update`.
        max_nodes: maximum number of nodes per graph.
        max_bi_relations: maximum number of binary relations per node.
        label: prefix for the input and layer names of this branch.
        embed_dim: node embedding size forwarded to `tp1_node_update`.
            NOTE(review): presumably must equal the last dimension of
            `graph_node_embs` -- confirm against the embedding layer.
        num_updates: how many times `tp1_node_update` is applied in sequence.

    Returns:
        A pair `(outputs, inputs)` where `outputs` is a one-element list
        with the updated node embeddings and `inputs` is
        `[node_rel, node_rel_weight]`, the Keras inputs created here.
    """
    node_rel = Input(
        shape=(max_nodes, max_bi_relations, 2),
        dtype='int32',
        name=label + '_rel')
    # Weight to compute weighted sum (i.e. average).
    node_rel_weight = Input(
        shape=(max_nodes, max_bi_relations),
        dtype='float32',
        name=label + '_rel_weight')

    for i in range(num_updates):
        graph_node_embs = tp1_node_update(
            graph_node_embs,
            node_rel,
            node_rel_weight,
            max_nodes,
            max_bi_relations,
            embed_dim,
            label + '_it' + str(i))
    return [graph_node_embs], [node_rel, node_rel_weight]


def make_child_parent_branch(token_emb, max_nodes, max_bi_relations):
    """Build the graph-embedding sub-network from child and parent relations.

    Args:
        token_emb: callable layer applied to the int32 node-index input to
            produce the node embeddings (presumably an Embedding layer --
            confirm against the caller).
        max_nodes: maximum number of nodes per graph.
        max_bi_relations: maximum number of binary relations per node.

    Returns:
        A pair `(outputs, inputs)` where `outputs` is a one-element list
        with the pooled graph embedding and `inputs` lists the node-index
        input followed by the child and then the parent relation inputs.
    """
    node_indices = Input(
        shape=(max_nodes,),
        dtype='int32',
        name='node_inds')
    graph_node_embs = token_emb(node_indices)

    child_rel_outputs, child_rel_inputs = make_pair_branch(
        graph_node_embs,
        max_nodes,
        max_bi_relations,
        label='child')
    parent_rel_outputs, parent_rel_inputs = make_pair_branch(
        graph_node_embs,
        max_nodes,
        max_bi_relations,
        label='parent')

    # List concatenation: Add receives [child_embs, parent_embs].
    x = Add(name='child_parent_add')(
        child_rel_outputs + parent_rel_outputs)
    # Integrate node embeddings into a single graph embedding.
    x = GlobalMaxPooling1D()(x)

    outputs = [x]
    inputs = [node_indices] + child_rel_inputs + parent_rel_inputs
    return outputs, inputs