# Copyright 2019 Babylon Partners. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Static graph batching example.

Generates a batch of random relational graphs and runs a single RGAT layer
over them in two ways: once per graph, and once on the whole batch combined
into a single support. The script then checks that both routes produce the
same node representations.
"""

from collections import OrderedDict

import numpy as np
import tensorflow as tf
from scipy import sparse

from rgat.layers import RGAT
from rgat.utils import graph_utils

tf.logging.set_verbosity(tf.logging.DEBUG)

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_integer("seed", 42, "The random seed.")
tf.flags.DEFINE_integer("relations", 3, "The number of relations.")
tf.flags.DEFINE_integer("nodes_min", 2,
                        "The smallest number of nodes any graph can have. This "
                        "is used for random graph generation.")
tf.flags.DEFINE_integer("nodes_max", 9,
                        "The largest number of nodes any graph can have. This "
                        "is used for random graph generation.")
tf.flags.DEFINE_integer("batch_size", 37, "The batch size.")
tf.flags.DEFINE_integer("attention_heads", 7,
                        "The number of attention heads.")
tf.flags.DEFINE_integer("features_dim", 3, "The input dimensionality.")
tf.flags.DEFINE_integer("units", 5, "The number of units in the layer.")


def _build_support(size):
    """Return a random binary `size` x `size` matrix in sparse COO format.

    Each entry is drawn uniformly from [0, 1) and set to 1 when the draw
    exceeds 0.75, otherwise 0.

    Args:
        size: Number of rows and columns of the square matrix.

    Returns:
        A `scipy.sparse.coo_matrix` holding the thresholded integer matrix.
    """
    sup = np.random.uniform(size=(size, size))
    sup = (sup > 0.75).astype(int)
    return sparse.coo_matrix(sup)


def _built_relational_support(size, names):
    """Return one random support per relation name, in the order given.

    Args:
        size: Number of nodes; every support is `size` x `size`.
        names: Iterable of relation names used as dictionary keys.

    Returns:
        An `OrderedDict` mapping each name to a `scipy.sparse.coo_matrix`.
    """
    return OrderedDict([(r, _build_support(size)) for r in names])


def get_architecture():
    """Build the placeholders and the RGAT layer applied to them.

    Returns:
        A tuple `(inputs_ph, support_ph, outputs)` of the dense feature
        placeholder, the sparse support placeholder and the layer output.
    """
    inputs_ph = tf.placeholder(
        dtype=tf.float32, shape=[None, FLAGS.features_dim], name="features_")
    support_ph = tf.sparse_placeholder(
        dtype=tf.float32, shape=[None, None], name="support_")

    tf.logging.info("Reordering indices of support - this is extremely "
                    "important as sparse operations assume sparse indices have "
                    "been ordered.")
    support_reorder = tf.sparse_reorder(support_ph)

    rgat_layer = RGAT(units=FLAGS.units, relations=FLAGS.relations)

    outputs = rgat_layer(inputs=inputs_ph, support=support_reorder)

    return inputs_ph, support_ph, outputs


def get_batch_of_features_supports_values():
    """Generate random node features and relational supports for a batch.

    The number of nodes of each graph is drawn uniformly from the inclusive
    range [`FLAGS.nodes_min`, `FLAGS.nodes_max`].

    Returns:
        A tuple `(features_val, supports_val)` of two lists with
        `FLAGS.batch_size` entries each: per-graph feature arrays of shape
        `(graph_size, FLAGS.features_dim)`, and per-graph `OrderedDict`s
        mapping relation names to `scipy.sparse.coo_matrix` supports.
    """
    tf.logging.info("Generating support names.")
    rel_names = ["rel_{}".format(i) for i in range(FLAGS.relations)]

    tf.logging.info("Generating number of nodes in each element of the batch.")
    # `np.random.random_integers` is deprecated; `randint` excludes its upper
    # bound, so `+ 1` keeps `nodes_max` itself a possible graph size.
    graph_sizes = [
        np.random.randint(low=FLAGS.nodes_min, high=FLAGS.nodes_max + 1)
        for _ in range(FLAGS.batch_size)]

    tf.logging.info("Generating fake input features for each node in each "
                    "graph.")
    features_val = [np.random.uniform(size=(graph_size, FLAGS.features_dim))
                    for graph_size in graph_sizes]

    supports_val = [
        _built_relational_support(size=graph_size, names=rel_names)
        for graph_size in graph_sizes]

    return features_val, supports_val


def main(unused_argv):
    """Check that per-graph and batched RGAT evaluation agree.

    Args:
        unused_argv: Command line arguments passed by `tf.app.run`; ignored.

    Raises:
        ValueError: If the batched results are not close to the concatenated
            per-graph results.
    """
    tf.logging.info("{} Flags {}".format('*'*15, '*'*15))
    for k, v in FLAGS.flag_values_dict().items():
        tf.logging.info("FLAG `{}`: {}".format(k, v))
    tf.logging.info('*' * (2 * 15 + len(' Flags ')))

    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    features_ph, support_ph, outputs = get_architecture()

    features_val, supports_val = get_batch_of_features_supports_values()

    # Route 1 inputs: one support per graph.
    # NOTE(review): `triple_from_coo` presumably yields the value form that a
    # sparse placeholder accepts in `feed_dict` - confirm in `graph_utils`.
    individual_supports = [
        graph_utils.relational_supports_to_support(d) for d in supports_val]
    individual_supports = [
        graph_utils.triple_from_coo(s) for s in individual_supports]

    # Route 2 inputs: the whole batch combined into a single graph.
    combined_features_val = np.concatenate(features_val, axis=0)
    combined_supports = graph_utils.batch_of_relational_supports_to_support(
        supports_val)
    combined_supports = graph_utils.triple_from_coo(combined_supports)

    # The context manager releases the session's resources even when a run
    # fails, instead of leaving the session open.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Route 1: Run RGAT on each element in the batch separately and
        # combine the results
        individual_results = [
            sess.run(outputs, feed_dict={features_ph: fv, support_ph: sv})
            for fv, sv in zip(features_val, individual_supports)]
        individual_results = np.concatenate(individual_results, axis=0)

        # Route 2: First combine the batch into a single graph and pass
        # everything through in one go
        combined_results = sess.run(
            outputs, feed_dict={features_ph: combined_features_val,
                                support_ph: combined_supports})

    if np.allclose(combined_results, individual_results):
        tf.logging.info("The approaches match!")
    else:
        raise ValueError(
            "Doing each element in a batch independently does not produce the "
            "same results as doing all the batch in one go. Something has "
            "clearly broken. Please contact the author ASAP :).")


if __name__ == '__main__':
    tf.app.run(main=main)