# Correctness test for memory_saving_gradients.
# MNIST model taken from github.com/tensorflow/models.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.client import timeline

# import memory_saving_gradients from ..
module_path = os.path.dirname(os.path.abspath(__file__))
sys.path.append(module_path + '/..')
import memory_saving_gradients
import mem_util

TEST_DEVICE = '/cpu:0'
USE_REAL_DATA = False
FLAGS_data_dir = '/tmp/mnist_data'
FLAGS_model_dir = '/tmp/mnist_model'
FLAGS_batch_size = 1
FLAGS_data_format = None


def train_dataset(data_dir):
  """Returns a tf.data.Dataset yielding (image, label) pairs for training."""
  data = input_data.read_data_sets(data_dir, one_hot=True).train
  return tf.data.Dataset.from_tensor_slices((data.images, data.labels))


def mnist_model(inputs, mode, data_format):
  """Takes the MNIST inputs and mode and outputs a tensor of logits."""
  # Input Layer
  # Reshape X to 4-D tensor: [batch_size, width, height, channels]
  # MNIST images are 28x28 pixels, and have one color channel
  inputs = tf.reshape(inputs, [-1, 28, 28, 1])

  if data_format is None:
    # When running on GPU, transpose the data from channels_last (NHWC) to
    # channels_first (NCHW) to improve performance.
    # See https://www.tensorflow.org/performance/performance_guide#data_formats
    data_format = ('channels_first' if tf.test.is_gpu_available()
                   else 'channels_last')

  if data_format == 'channels_first':
    inputs = tf.transpose(inputs, [0, 3, 1, 2])

  # Convolutional Layer #1
  # Computes 32 features using a 5x5 filter with ReLU activation.
  # Padding is added to preserve width and height.
  # Input Tensor Shape: [batch_size, 28, 28, 1]
  # Output Tensor Shape: [batch_size, 28, 28, 32]
  conv1 = tf.layers.conv2d(
      inputs=inputs,
      filters=32,
      kernel_size=[5, 5],
      padding='same',
      activation=tf.nn.relu,
      data_format=data_format)

  # Pooling Layer #1
  # First max pooling layer with a 2x2 filter and stride of 2
  # Input Tensor Shape: [batch_size, 28, 28, 32]
  # Output Tensor Shape: [batch_size, 14, 14, 32]
  pool1 = tf.layers.max_pooling2d(
      inputs=conv1, pool_size=[2, 2], strides=2, data_format=data_format)

  # Convolutional Layer #2
  # Computes 64 features using a 5x5 filter.
  # Padding is added to preserve width and height.
  # Input Tensor Shape: [batch_size, 14, 14, 32]
  # Output Tensor Shape: [batch_size, 14, 14, 64]
  conv2 = tf.layers.conv2d(
      inputs=pool1,
      filters=64,
      kernel_size=[5, 5],
      padding='same',
      activation=tf.nn.relu,
      data_format=data_format)

  # Pooling Layer #2
  # Second max pooling layer with a 2x2 filter and stride of 2
  # Input Tensor Shape: [batch_size, 14, 14, 64]
  # Output Tensor Shape: [batch_size, 7, 7, 64]
  pool2 = tf.layers.max_pooling2d(
      inputs=conv2, pool_size=[2, 2], strides=2, data_format=data_format)

  # Flatten tensor into a batch of vectors
  # Input Tensor Shape: [batch_size, 7, 7, 64]
  # Output Tensor Shape: [batch_size, 7 * 7 * 64]
  pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

  # Dense Layer
  # Densely connected layer with 1024 neurons
  # Input Tensor Shape: [batch_size, 7 * 7 * 64]
  # Output Tensor Shape: [batch_size, 1024]
  dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu)

  # Add dropout operation; 0.6 probability that element will be kept
  dropout = tf.layers.dropout(
      inputs=dense, rate=0.4,
      training=(mode == tf.estimator.ModeKeys.TRAIN))

  # Logits layer
  # Input Tensor Shape: [batch_size, 1024]
  # Output Tensor Shape: [batch_size, 10]
  logits = tf.layers.dense(inputs=dropout, units=10)
  return logits
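
# The sessrun() wrapper below traces every call with RunMetadata so that peak
# memory can be read back per device. As a minimal standalone sketch of that
# pattern (illustrative only; the test itself goes through sessrun(), and this
# helper name is our own, not part of mem_util):
def profile_peak_memory_mb(session, fetches, device=TEST_DEVICE):
  """Runs `fetches` once under full tracing; returns peak memory in MB."""
  metadata = tf.RunMetadata()
  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  session.run(fetches, options=options, run_metadata=metadata)
  # mem_util.peak_memory maps device name -> peak bytes for that run.
  return mem_util.peak_memory(metadata)[device] / 1e6
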
GLOBAL_PROFILE = True
DUMP_TIMELINES = False
run_metadata = None  # most recent RunMetadata, populated by sessrun()


def sessrun(*args, **kwargs):
  """Wraps sess.run to record RunMetadata and, optionally, Chrome timelines."""
  global sess, run_metadata

  if not GLOBAL_PROFILE:
    return sess.run(*args, **kwargs)

  run_metadata = tf.RunMetadata()
  kwargs['options'] = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  kwargs['run_metadata'] = run_metadata
  result = sess.run(*args, **kwargs)

  first_entry = args[0]
  if isinstance(first_entry, list):
    if len(first_entry) == 0 and len(args) == 1:
      return None
    first_entry = first_entry[0]

  if DUMP_TIMELINES:
    name = first_entry.name
    name = name.replace('/', '-')
    tl = timeline.Timeline(run_metadata.step_stats)
    ctf = tl.generate_chrome_trace_format()
    with open('timelines/%s.json' % (name,), 'w') as f:
      f.write(ctf)
    with open('timelines/%s.pbtxt' % (name,), 'w') as f:
      f.write(str(run_metadata))

  return result


def create_session():
  """Creates a session with graph rewriting disabled, so that memory
  measurements are not distorted by constant folding or other rewrites."""
  from tensorflow.core.protobuf import rewriter_config_pb2
  optimizer_options = tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)
  config = tf.ConfigProto(
      operation_timeout_in_ms=150000,
      graph_options=tf.GraphOptions(optimizer_options=optimizer_options))
  config.graph_options.rewrite_options.constant_folding = (
      rewriter_config_pb2.RewriterConfig.OFF)
  config.graph_options.place_pruned_graph = True
  return tf.Session(config=config)


sess = None


def train_mnist():
  global sess

  # restrict to cpu:0
  tf.reset_default_graph()
  tf.set_random_seed(1)
  np.random.seed(1)
  tf_dev = tf.device(TEST_DEVICE)
  tf_dev.__enter__()

  # Train the model.
  # Replace Dataset ops with constant images because gradient rewriting
  # tries to differentiate graphs containing IteratorGetNext.
  # TODO: make it work with Dataset ops
  images = tf.Variable(tf.random_uniform((FLAGS_batch_size, 28**2)))
  labels = tf.Variable(tf.concat([tf.ones((FLAGS_batch_size, 1)),
                                  tf.zeros((FLAGS_batch_size, 9))], axis=1))

  def train_input_fn():
    dataset = train_dataset(FLAGS_data_dir)
    dataset = dataset.batch(FLAGS_batch_size)
    (images, labels) = dataset.make_one_shot_iterator().get_next()
    num_images = FLAGS_batch_size
    return (images[:num_images], labels[:num_images])

  if USE_REAL_DATA:
    images, labels = train_input_fn()

  logits = mnist_model(images, tf.estimator.ModeKeys.TRAIN, 'channels_last')
  cross_entropy = tf.losses.softmax_cross_entropy(logits=logits,
                                                  onehot_labels=labels)
  loss = cross_entropy
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=1e-2)
  tvars = tf.trainable_variables()
  grads = tf.gradients(loss, tvars)
  grads_and_vars = zip(grads, tvars)
  train_op = optimizer.apply_gradients(grads_and_vars)

  sess = create_session()
  sess.run(tf.global_variables_initializer())
  print("Loss %.5f" % (sess.run(loss)))
  for i in range(10):
    sessrun(train_op)
    mem_use = mem_util.peak_memory(run_metadata)[TEST_DEVICE] / 1e6
    print("Loss %.5f, memory %.2f MB" % (sess.run(loss), mem_use))

  # Should print something like this for the actual dataset:
  # 2.12764
  # 1.87759
  # 1.54445
  # 1.29149
  # 1.18474
  # 0.884424
  # 0.69454
  # 0.770236
  # 0.629259
  # 0.654465
  assert sess.run(loss) < 100
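
# A minimal sketch of calling the library directly instead of monkey-patching
# tf.gradients as test_correctness() does below. The 'collection', 'memory',
# and 'speed' checkpoint strategies are the ones the library documents; this
# helper name is our own and is illustrative only, not used by the test.
def checkpointed_gradients(loss, variables, strategy='memory'):
  """Returns gradients of `loss` w.r.t. `variables`, recomputing activations
  between checkpoints instead of storing all of them."""
  return memory_saving_gradients.gradients([loss], variables,
                                           checkpoints=strategy)
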
def test_correctness(capsys):
  # To enable printing during a successful test run under pytest, uncomment:
  # if capsys:
  #   pytest_decorator = capsys.disabled()
  #   pytest_decorator.__enter__()

  # Expected output looks something like this:
  #
  # Running with memory saving
  # Extracting /tmp/mnist_data/train-images-idx3-ubyte.gz
  # Extracting /tmp/mnist_data/train-labels-idx1-ubyte.gz
  # Extracting /tmp/mnist_data/t10k-images-idx3-ubyte.gz
  # Extracting /tmp/mnist_data/t10k-labels-idx1-ubyte.gz
  # Loss 0.07283, memory 380.72 MB
  # Loss 0.00398, memory 351.23 MB
  # Loss 0.00035, memory 351.23 MB
  #
  # Running with regular gradient
  # Loss 0.01803, memory 399.10 MB
  # Loss 0.00002, memory 399.10 MB
  # Loss 0.00000, memory 399.10 MB

  def grads(ys, xs, grad_ys=None, **kwargs):
    return memory_saving_gradients.gradients(ys, xs, grad_ys,
                                             checkpoints='speed', **kwargs)

  # Monkey-patch tf.gradients so the graph built by train_mnist() uses the
  # memory-saving rewrite, then restore the original for the baseline run.
  old_grads = tf.gradients
  tf.__dict__["gradients"] = grads
  print("Running with memory saving")
  train_mnist()

  print("\nRunning with regular gradient")
  tf.__dict__["gradients"] = old_grads
  train_mnist()


def main(unused_argv):
  test_correctness(None)


if __name__ == '__main__':
  main([])