# -*- coding: utf-8 -*-
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for SessionRunHooks.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import tempfile
import shutil
import time

import tensorflow as tf
from tensorflow.python.training import monitored_session  # pylint: disable=E0611
from tensorflow import gfile

from seq2seq import graph_utils
from seq2seq.training import hooks


class TestPrintModelAnalysisHook(tf.test.TestCase):
  """Tests the `PrintModelAnalysisHook` hook"""

  def test_begin(self):
    """Checks that `begin()` writes the expected `model_analysis.txt`."""
    model_dir = tempfile.mkdtemp()
    # Registered immediately so the scratch directory is removed even when
    # an assertion below fails.
    self.addCleanup(shutil.rmtree, model_dir)

    tf.get_variable("weigths", [128, 128])
    hook = hooks.PrintModelAnalysisHook(
        params={}, model_dir=model_dir, run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    analysis_path = os.path.join(model_dir, "model_analysis.txt")
    with gfile.GFile(analysis_path) as analysis_file:
      file_contents = analysis_file.read().strip()

    # `.decode()` only exists on byte strings, so it fails whenever `read()`
    # hands back text; `as_text` accepts both bytes and text.
    self.assertEqual(
        tf.compat.as_text(file_contents),
        "_TFProfRoot (--/16.38k params)\n"
        " weigths (128x128, 16.38k/16.38k params)")


class TestTrainSampleHook(tf.test.TestCase):
  """Tests `TrainSampleHook` class.
  """

  def setUp(self):
    """Creates a scratch model directory and the "predictions" collection."""
    super(TestTrainSampleHook, self).setUp()
    self.model_dir = tempfile.mkdtemp()
    self.sample_dir = os.path.join(self.model_dir, "samples")

    # The hook expects these collections to be in the graph
    pred_dict = {}
    pred_dict["predicted_tokens"] = tf.constant([["Hello", "World", "笑w"]])
    pred_dict["labels.target_tokens"] = tf.constant([["Hello", "World", "笑w"]])
    # NOTE(review): this value is a one-element tuple wrapping the tensor (the
    # original spelled it with a trailing comma). Kept unchanged because the
    # hook's handling of it is not visible here -- confirm it is intentional.
    pred_dict["labels.target_len"] = (tf.constant(2),)
    graph_utils.add_dict_to_collection(pred_dict, "predictions")

  def tearDown(self):
    """Removes the scratch model directory created in `setUp`."""
    super(TestTrainSampleHook, self).tearDown()
    shutil.rmtree(self.model_dir)

  def test_sampling(self):
    """With `every_n_steps=10`, samples appear at steps 0 and 10, not 9."""
    hook = hooks.TrainSampleHook(
        params={"every_n_steps": 10},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())

    global_step = tf.contrib.framework.get_or_create_global_step()
    no_op = tf.no_op()
    hook.begin()

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())
      sess.run(tf.local_variables_initializer())
      sess.run(tf.tables_initializer())

      # pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])

      # Should trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000000.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 0",
                      readfile.read().decode("utf-8"))

      # Should not trigger for step 9
      sess.run(tf.assign(global_step, 9))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000009.txt")
      self.assertFalse(os.path.exists(outfile))

      # Should trigger for step 10
      sess.run(tf.assign(global_step, 10))
      mon_sess.run(no_op)
      outfile = os.path.join(self.sample_dir, "samples_000010.txt")
      with open(outfile, "rb") as readfile:
        self.assertIn("Prediction followed by Target @ Step 10",
                      readfile.read().decode("utf-8"))


class TestMetadataCaptureHook(tf.test.TestCase):
  """Test for the MetadataCaptureHook"""

  def setUp(self):
    """Creates a scratch model directory for the hook's output."""
    super(TestMetadataCaptureHook, self).setUp()
    self.model_dir = tempfile.mkdtemp()

  def tearDown(self):
    """Removes the scratch model directory created in `setUp`."""
    super(TestMetadataCaptureHook, self).tearDown()
    shutil.rmtree(self.model_dir)

  def test_capture(self):
    """With `step=5`, output files only appear on the run after step 5."""
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Some test computation
    some_weights = tf.get_variable("weigths", [2, 128])
    computation = tf.nn.softmax(some_weights)

    hook = hooks.MetadataCaptureHook(
        params={"step": 5},
        model_dir=self.model_dir,
        run_config=tf.contrib.learn.RunConfig())
    hook.begin()

    with self.test_session() as sess:
      sess.run(tf.global_variables_initializer())

      # pylint: disable=W0212
      mon_sess = monitored_session._HookedSession(sess, [hook])

      # Should not trigger for step 0
      sess.run(tf.assign(global_step, 0))
      mon_sess.run(computation)
      self.assertEqual(gfile.ListDirectory(self.model_dir), [])

      # Should trigger *after* step 5: the first run at step 5 still leaves
      # the directory empty, the following run produces the files.
      sess.run(tf.assign(global_step, 5))
      mon_sess.run(computation)
      self.assertEqual(gfile.ListDirectory(self.model_dir), [])
      mon_sess.run(computation)
      self.assertEqual(
          set(gfile.ListDirectory(self.model_dir)),
          set(["run_meta", "tfprof_log", "timeline.json"]))


if __name__ == "__main__":
  tf.test.main()