# Copyright 2017 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """This is the benchmark code used in the ICLR 2017 paper. The paper is entitled Deep Learning with Dynamic Computation graphs." """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import logging import random import time # import google3 import numpy as np import six import tensorflow as tf from tensorflow_fold.public import loom tf.flags.DEFINE_integer("vector_size", 1024, "Size of tree RNN output vector.") tf.flags.DEFINE_integer("tree_size", 128, "Size of trees to test.") tf.flags.DEFINE_integer("num_repeats", 2, "Numer of times to repeat test.") tf.flags.DEFINE_integer("num_epochs", 1, "Number of epochs.") tf.flags.DEFINE_boolean("tree_lstm", True, "Use a tree lstm.") tf.flags.DEFINE_string("tree_type", "random", "Type of tree to construct." "Valid values are: random, sequence, or balanced") tf.flags.DEFINE_boolean("log_device_placement", False, "Log device placement.") tf.flags.DEFINE_boolean("direct_feed_dict", True, "Use direct feed_dict.") tf.flags.DEFINE_boolean("serialize_and_merge", False, "Serialize each tree separately and merge them.") tf.flags.DEFINE_boolean("train_with_loss", True, "Run SGD on a loss.") tf.flags.DEFINE_boolean("quick_run", False, "Use a limited set of batch sizes.") FLAGS = tf.flags.FLAGS logging.basicConfig(format="%(asctime)s %(message)s") _logger = logging.getLogger("benchmark") _logger.setLevel(logging.INFO) def make_random_tree(size): """Make a random binary tree with size nodes.""" if size <= 1: return 0 r = random.randint(1, size-1) return (make_random_tree(r), make_random_tree(size-r)) def make_sequence_tree(size): """Make a maximally unbalanced tree (a sequence) with size nodes.""" if size <= 1: return 0 return (make_sequence_tree(size-1), 0) def make_balanced_tree(size): """Make a balanced binary tree with size nodes, where size is a power of 2.""" if size <= 1: return 0 return (make_balanced_tree(size/2), make_balanced_tree(size/2)) def make_input_tree(size): """Make a tree based on the value of the tree_type flag.""" if FLAGS.tree_type == "sequence": return make_sequence_tree(size) elif FLAGS.tree_type == "balanced": return make_balanced_tree(size) elif FLAGS.tree_type == "random": return make_random_tree(size) raise ValueError("Invalid tree type: %s." % FLAGS.tree_type) def index_type(): return loom.TypeShape("int32", ()) def vector_type(): return loom.TypeShape("float32", (FLAGS.vector_size,)) class LeafOp(loom.LoomOp): """Create a LoomOp for the leaf nodes. We use a simple embedding table.""" def __init__(self, embedding_size): super(LeafOp, self).__init__([index_type()], [vector_type()]) self._embedding_size = embedding_size self._embedding = None self._vscope = "Leaf" def instantiate_batch(self, inputs): return [self(*inputs)] def __call__(self, indices): if self._embedding is None: with tf.variable_scope(self._vscope): self._embedding = ( tf.get_variable("embedding_table", [self._embedding_size, FLAGS.vector_size], initializer=tf.random_uniform_initializer())) return tf.gather(self._embedding, indices) class NonTerminalOp(loom.LoomOp): """Create a LoomOp for the non-terminals -- either a tree RNN or LSTM.""" def __init__(self): super(NonTerminalOp, self).__init__([vector_type(), vector_type()], [vector_type()]) self._weights = None self._bias = None self._vscope = "NonTerminal" def instantiate_batch(self, inputs): return [self(*inputs)] def tree_fc(self, left, right): # A simple tree RNN with a single fully connected layer. if self._weights is None: with tf.variable_scope(self._vscope): self._weights = tf.get_variable( "weights", [FLAGS.vector_size*2, FLAGS.vector_size], initializer=tf.uniform_unit_scaling_initializer(1.43)) self._bias = tf.get_variable("bias", [FLAGS.vector_size], initializer=tf.zeros_initializer()) x = tf.concat([left, right], 1) result = tf.add(tf.matmul(x, self._weights), self._bias) return tf.nn.relu(result) def tree_lstm(self, left, right): # A variation on the tree LSTM -- we add an extra hidden layer. if self._weights is None: with tf.variable_scope(self._vscope): self._weights_0 = tf.get_variable( "weights_0", [FLAGS.vector_size*2, FLAGS.vector_size], initializer=tf.uniform_unit_scaling_initializer(1.43)) self._bias_0 = tf.get_variable("bias_0", [FLAGS.vector_size], initializer=tf.zeros_initializer()) self._weights = tf.get_variable( "weights", [FLAGS.vector_size, FLAGS.vector_size*4], initializer=tf.uniform_unit_scaling_initializer(1.0)) self._bias = tf.get_variable("bias", [FLAGS.vector_size*4], initializer=tf.zeros_initializer()) # One hidden layer x = tf.concat([left, right], 1) h0 = tf.nn.relu(tf.add(tf.matmul(x, self._weights_0), self._bias_0)) # Do a single matrix multiply to compute all gates h1 = tf.add(tf.matmul(h0, self._weights), self._bias) (hfl, hfr, hi, hg) = tf.split(h1, 4, axis=1) fl = tf.nn.sigmoid(hfl) # forget left fr = tf.nn.sigmoid(hfr) # forget right i = tf.nn.sigmoid(hi) # input gate g = tf.nn.tanh(hg) # computation ylr = tf.add(tf.multiply(fl, left), tf.multiply(fr, right)) ygi = tf.multiply(i, g) y = tf.add(ylr, ygi) return y def __call__(self, left, right): if FLAGS.tree_lstm: return self.tree_lstm(left, right) else: return self.tree_fc(left, right) class ModelBase(object): """Base class for the benchmark model.""" def __init__(self, batch_size): # Use the tree size as the number of entries in the embedding table. self._embedding_size = FLAGS.tree_size self.batch_size = batch_size self._leaf_op = LeafOp(self._embedding_size) self._non_terminal_op = NonTerminalOp() self.elapsed_times = [] self.elapsed_fd_times = [] def random_index(self): """Get a random index into the embedding table.""" return random.randint(0, self._embedding_size-1) def name(self): """Return the name of this model -- to be overridden by base classes.""" return "Undefined." def build_model(self): """Build self._output -- to be overridden by base classes.""" pass def build_model_loss(self): """Build model and add a loss function.""" self.build_model() _logger.info("Differentiating.") if FLAGS.train_with_loss: # We don't actually care what the loss function is; we're just timing. self._loss = tf.nn.l2_loss(self._output) optr = tf.train.GradientDescentOptimizer(0.00001) self._train = optr.minimize(self._loss) else: self._loss = tf.reduce_sum(self._output) self._train = tf.constant(0.0) def build_feed_dict(self): """Build a feed dict for the model -- to be overridden by base classes.""" return {} def evaluate(self, sess): """Run the model on random input data.""" _logger.info("Testing for batch size %d.", self.batch_size) # We run the graph twice without timing to force any hidden initialization # and/or caching behavior to occur. for i in six.moves.xrange(0, 1): _logger.info("Burn-in %d.", i) fd = self.build_feed_dict() sess.run([self._train, self._loss], feed_dict=fd) # Small batch sizes have greater timing variation, so we compensate # by doing more batches. batch_size = self.batch_size if batch_size < 32: num_batches = int(32/self.batch_size) * FLAGS.num_repeats else: num_batches = FLAGS.num_repeats for batch in six.moves.xrange(0, num_batches): _logger.info("Batch: %d", batch) _logger.info("Build feed_dict.") start_time_fd = time.time() fd = self.build_feed_dict() end_time_fd = time.time() elapsed_fd = end_time_fd - start_time_fd self.elapsed_fd_times.append(elapsed_fd) _logger.info("Run.") start_time = time.time() [_, loss_v] = sess.run([self._train, self._loss], feed_dict=fd) end_time = time.time() elapsed = end_time - start_time self.elapsed_times.append(elapsed) _logger.info("Done. Elapsed: %f [%f]. Loss: %f", elapsed, elapsed_fd, loss_v) def run(self): """Build a graph and run the model on random input data.""" _logger.info("Creating graph.") with tf.Graph().as_default(): _logger.info("Building model.") self.build_model_loss() _logger.info("Starting session.") config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement) with tf.Session(config=config) as sess: _logger.info("Initializing variables.") sess.run(tf.global_variables_initializer()) _logger.info("Starting timing test.") self.evaluate(sess) _logger.info("Ending session.") class TfModel(ModelBase): """Native tensorflow tree model. All trees have the same shape so that they can be batched together without using Loom. """ def __init__(self, batch_size): """Create a tree model that uses native TensorFlow.""" # Use the tree size as the number of entries in the embedding table. super(TfModel, self).__init__(batch_size) self._placeholders = [] def name(self): return "TensorFlow" def build_model(self): """The tensorflow model uses a tree of fixed shape.""" tree = make_input_tree(FLAGS.tree_size) _logger.info("Tree: %s", str(tree)) self._output = self.build_graph(tree) def build_graph(self, root): """Build a tensorflow computation graph with the same shape as root.""" if isinstance(root, tuple): left = self.build_graph(root[0]) right = self.build_graph(root[1]) # We don't use the LoomOp with Loom. # Just call it to get the computation graph. return self._non_terminal_op(left, right) else: # Insert a placeholder for every leaf node. # We substitute random indices for the placeholders in build_feed_dict. indices = tf.placeholder(dtype="int32", shape=[self.batch_size]) self._placeholders.append(indices) return self._leaf_op(indices) def build_feed_dict(self): """Create random indices for leaf nodes.""" # Pass the indices via feed_dict, to ensure a fair comparison with Loom. # Each leaf has one index (times the number of trees in the batch). def rand_indices(): return np.array( [self.random_index() for _ in six.moves.xrange(0, self.batch_size)], dtype="int32") return {p: rand_indices() for p in self._placeholders} class LoomModel(ModelBase): """Model which uses dynamic batching with the Loom API.""" def __init__(self, batch_size, proper_batching): """Create a tree model that uses the Loom API. Args: batch_size: The number of trees in the batch. proper_batching: If True, each tree has a different shape, otherwise each tree in the batch has the same shape for comparison with native TensorFlow. """ # Use the tree size as the number of entries in the embedding table. super(LoomModel, self).__init__(batch_size) self._proper_batching = proper_batching def name(self): return "Loom" def build_model(self): """Build a model using Loom.""" # Create a dictionary of the LoomOps that the model uses. named_tensors = {} named_ops = { "leaf": self._leaf_op, "non_terminal": self._non_terminal_op } # Make a random tree. self._tree = make_input_tree(FLAGS.tree_size) if not self._proper_batching: _logger.info("Tree: %s", str(self._tree)) # Register all the of LoomOps with Loom. self._loom = loom.Loom(named_tensors=named_tensors, named_ops=named_ops, direct_feed_dict=FLAGS.direct_feed_dict) # Grab the TensorFlow tensor that holds the Loom output. self._output = self._loom.output_tensor(vector_type()) def get_input_tree(self): """Get an input tree, either random or of fixed shape.""" if self._proper_batching: # Make a different tree shape for each input in the batch. return make_input_tree(FLAGS.tree_size) else: # Use the same shape for each input in the batch. return self._tree def build_feed_dict(self): if FLAGS.serialize_and_merge: return self.build_feed_dict_with_serialize_and_merge() _logger.info("Traversing trees.") # The weaver is an object that can invoke LoomOps, and create a feed_dict. weaver = self._loom.make_weaver() # Recurse over each tree in the batch for _ in six.moves.xrange(0, self.batch_size): root = self.traverse_tree(self.get_input_tree(), weaver) weaver.add_output(root) # Now build the feed_dict, which contains both indices into the embedding # tables, and indices for the gather nodes inserted by Loom. if FLAGS.direct_feed_dict: _logger.info("Calling build_feed_dict in direct mode.") else: _logger.info("Calling build_feed_dict with serialization.") return weaver.build_feed_dict() def build_feed_dict_with_serialize_and_merge(self): """Serialize each tree with a separate weaver, and merge them together.""" if FLAGS.direct_feed_dict: raise RuntimeError("Cannot serialize separately with direct_feed_dict.") _logger.info("Traversing and serializing trees.") # Recurse over each tree in the batch. serialized_trees = [] for _ in six.moves.xrange(0, self.batch_size): weaver = self._loom.make_weaver() root = self.traverse_tree(self.get_input_tree(), weaver) weaver.add_output(root) serialized_trees.append(weaver.serialize()) # Pass the serialized trees as the input tensors. return {self._loom.input_tensor: serialized_trees} def traverse_tree(self, node, weaver): # Recursive function to invoke a LoomOp on each node in the tree. if isinstance(node, tuple): left = self.traverse_tree(node[0], weaver) right = self.traverse_tree(node[1], weaver) # Invoke the Loom non_terminal op. return weaver.non_terminal(left, right) else: # Invoke the Loom leaf op, on a random index. idx = weaver(np.array(self.random_index(), dtype="int32")) return weaver.leaf(idx) def test_model(model_class, *args): """Do a timing test for a model on a range of batch sizes.""" test_results = {} if FLAGS.quick_run: batch_size_list = [1, 1024] else: batch_size_list = [1, 32, 64, 128, 256, 1024] for batch_size in batch_size_list: test_results[batch_size] = ([], []) for _ in six.moves.xrange(0, FLAGS.num_epochs): model = model_class(batch_size, *args) model.run() test_results[batch_size][0].extend(model.elapsed_times) test_results[batch_size][1].extend(model.elapsed_fd_times) return test_results def print_results(test_results, model_name): """Print the results of a timing test.""" def avg(lst): return sum(lst)/len(lst) _logger.info("Results for model %s:", model_name) result_list = list(six.iteritems(test_results)) for (b, r) in sorted(result_list, reverse=True): (times, times_fd) = r tree_times = [t/b for t in times] tree_times_fd = [t/b for t in times_fd] _logger.info("Batch size: %d | per batch: %f [%f, %f] | " "per tree: %f [%f, %f] | feed_dict: %f [%f, %f]", b, avg(times), min(times), max(times), avg(tree_times), min(tree_times), max(tree_times), avg(tree_times_fd), min(tree_times_fd), max(tree_times_fd)) def compare_results(results1, results2, model_name1, model_name2): """Compare the results between two models.""" def avg(lst): return sum(lst)/len(lst) _logger.info("Comparing %s to %s", model_name1, model_name2) rs1 = sorted(list(six.iteritems(results1)), reverse=True) rs2 = sorted(list(six.iteritems(results2)), reverse=True) for r in zip(rs1, rs2): ((b, (times1, _)), (_, (times2, _))) = r _logger.info("Batch size: %d | ratio: %f", b, avg(times2)/avg(times1)) def compare_total_speedup(test_results, baseline): """Get the total speedup over base line time for a single tree.""" def avg(lst): return sum(lst)/len(lst) baseline_tree_time = avg(baseline[0]) _logger.info("Speedup over baseline time %f.", baseline_tree_time) result_list = list(six.iteritems(test_results)) for (b, r) in sorted(result_list, reverse=True): (times, _) = r tree_times = [t/b for t in times] avg_time = avg(tree_times) _logger.info("Batch size: %d | tree time: %f, speedup: %f", b, avg_time, baseline_tree_time/avg_time) def main(unused_argv): tf.logging.set_verbosity(tf.logging.INFO) _logger.info("Tensorflow Version: %s", str(tf.__version__)) tf_results = test_model(TfModel) loom_results = test_model(LoomModel, False) loom_results_proper = test_model(LoomModel, True) if FLAGS.tree_lstm: model_type = "GRU" else: model_type = "FC" _logger.info("====================================================") _logger.info("Num epochs: %d; repeats per epoch %d", FLAGS.num_epochs, FLAGS.num_repeats) _logger.info("Model type: %s, %s", model_type, FLAGS.tree_type) _logger.info("Vector size: %d", FLAGS.vector_size) _logger.info("Tree size: %d", FLAGS.tree_size) print_results(tf_results, "TensorFlow") print_results(loom_results, "Loom") print_results(loom_results_proper, "Loom with random trees") compare_results(tf_results, loom_results, "TensorFlow", "Loom") compare_total_speedup(loom_results, tf_results[1]) _logger.info("Finished benchmarks.") if __name__ == "__main__": tf.app.run()