# coding=utf-8
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf_slim.layers.summaries."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# pylint: disable=g-direct-tensorflow-import
import tensorflow.compat.v1 as tf
from tf_slim.layers import summaries as summaries_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def setUpModule():
  """Runs every test in graph mode.

  The tests below inspect graph collections (`ops.get_collection`), which
  requires eager execution to be disabled.
  """
  tf.disable_eager_execution()


class SummariesTest(test.TestCase):
  """Tests for the summary helpers in `tf_slim.layers.summaries`."""

  def _summary_names(self):
    """Returns the op names of all entries in the SUMMARIES collection.

    NOTE(review): the exact-count assertions that use this helper assume each
    test starts with an empty SUMMARIES collection (i.e. a fresh default graph
    per test) -- confirm against `test.TestCase.setUp`.
    """
    return [summary.op.name
            for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)]

  def test_summarize_scalar_tensor(self):
    """A scalar tensor is summarized with a ScalarSummary op."""
    with self.cached_session():
      scalar_var = variables.Variable(1)
      summary_op = summaries_lib.summarize_tensor(scalar_var)
      self.assertEqual(summary_op.op.type, 'ScalarSummary')

  def test_summarize_multidim_tensor(self):
    """A non-scalar tensor is summarized with a HistogramSummary op."""
    with self.cached_session():
      tensor_var = variables.Variable([1, 2, 3])
      summary_op = summaries_lib.summarize_tensor(tensor_var)
      self.assertEqual(summary_op.op.type, 'HistogramSummary')

  def test_summarize_activation(self):
    """A plain activation gets exactly one (histogram) summary."""
    with self.cached_session():
      var = variables.Variable(1)
      op = array_ops.identity(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(op)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 1)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_activation_relu(self):
    """A ReLU activation also gets an extra 'zeros' summary."""
    with self.cached_session():
      var = variables.Variable(1)
      op = nn_ops.relu(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(op)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 2)
      self.assertIn(u'SummaryTest/zeros', names)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_activation_relu6(self):
    """A ReLU6 activation also gets extra 'zeros' and 'sixes' summaries."""
    with self.cached_session():
      var = variables.Variable(1)
      op = nn_ops.relu6(var, name='SummaryTest')
      summary_op = summaries_lib.summarize_activation(op)

      self.assertEqual(summary_op.op.type, 'HistogramSummary')
      names = self._summary_names()
      self.assertEqual(len(names), 3)
      self.assertIn(u'SummaryTest/zeros', names)
      self.assertIn(u'SummaryTest/sixes', names)
      self.assertIn(u'SummaryTest/activation', names)

  def test_summarize_collection_regex(self):
    """Only 'foo' members whose names match the regex are summarized."""
    with self.cached_session():
      var = variables.Variable(1)
      # Name matches the regex, but it is never added to 'foo': must be skipped.
      array_ops.identity(var, name='Test1')
      ops.add_to_collection('foo', array_ops.identity(var, name='Test2'))
      # In 'foo', but the name does not match the regex: must be skipped.
      ops.add_to_collection('foo', array_ops.identity(var, name='Foobar'))
      ops.add_to_collection('foo', array_ops.identity(var, name='Test3'))

      summaries = summaries_lib.summarize_collection('foo', r'Test[123]')
      names = [summary.op.name for summary in summaries]
      self.assertEqual(len(names), 2)
      self.assertIn(u'Test2_summary', names)
      self.assertIn(u'Test3_summary', names)


if __name__ == '__main__':
  test.main()