# coding=utf-8 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for slim.data.parallel_reader.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import tempfile import tensorflow.compat.v1 as tf from tf_slim import queues from tf_slim.data import parallel_reader from tf_slim.data import test_utils # pylint:disable=g-direct-tensorflow-import from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib from tensorflow.python.training import supervisor # pylint:enable=g-direct-tensorflow-import def setUpModule(): tf.disable_eager_execution() class ParallelReaderTest(test.TestCase): def setUp(self): super(ParallelReaderTest, self).setUp() ops.reset_default_graph() def _verify_all_data_sources_read(self, shared_queue): with self.cached_session(): tfrecord_paths = test_utils.create_tfrecord_files( tempfile.mkdtemp(), num_files=3) num_readers = len(tfrecord_paths) p_reader = parallel_reader.ParallelReader( io_ops.TFRecordReader, shared_queue, num_readers=num_readers) data_files = parallel_reader.get_data_files(tfrecord_paths) filename_queue = input_lib.string_input_producer(data_files) key, value = p_reader.read(filename_queue) count0 = 0 count1 = 0 count2 = 0 num_reads = 50 sv = supervisor.Supervisor(logdir=tempfile.mkdtemp()) with sv.prepare_or_wait_for_session() as sess: sv.start_queue_runners(sess) for _ in range(num_reads): current_key, _ = sess.run([key, value]) if '0-of-3' in str(current_key): count0 += 1 if '1-of-3' in str(current_key): count1 += 1 if '2-of-3' in str(current_key): count2 += 1 self.assertGreater(count0, 0) self.assertGreater(count1, 0) self.assertGreater(count2, 0) self.assertEqual(count0 + count1 + count2, num_reads) def _verify_read_up_to_out(self, shared_queue): with self.cached_session(): num_files = 3 num_records_per_file = 7 tfrecord_paths = test_utils.create_tfrecord_files( tempfile.mkdtemp(), num_files=num_files, num_records_per_file=num_records_per_file) p_reader = parallel_reader.ParallelReader( io_ops.TFRecordReader, shared_queue, num_readers=5) data_files = parallel_reader.get_data_files(tfrecord_paths) filename_queue = input_lib.string_input_producer(data_files, num_epochs=1) key, value = p_reader.read_up_to(filename_queue, 4) count0 = 0 count1 = 0 count2 = 0 all_keys_count = 0 all_values_count = 0 sv = supervisor.Supervisor(logdir=tempfile.mkdtemp()) with sv.prepare_or_wait_for_session() as sess: sv.start_queue_runners(sess) while True: try: current_keys, current_values = sess.run([key, value]) self.assertEqual(len(current_keys), len(current_values)) all_keys_count += len(current_keys) all_values_count += len(current_values) for current_key in current_keys: if '0-of-3' in str(current_key): count0 += 1 if '1-of-3' in str(current_key): count1 += 1 if '2-of-3' in str(current_key): count2 += 1 except errors_impl.OutOfRangeError: break self.assertEqual(count0, num_records_per_file) self.assertEqual(count1, num_records_per_file) self.assertEqual(count2, num_records_per_file) self.assertEqual( all_keys_count, num_files * num_records_per_file) self.assertEqual(all_values_count, all_keys_count) self.assertEqual( count0 + count1 + count2, all_keys_count) def testRandomShuffleQueue(self): shared_queue = data_flow_ops.RandomShuffleQueue( capacity=256, min_after_dequeue=128, dtypes=[dtypes_lib.string, dtypes_lib.string]) self._verify_all_data_sources_read(shared_queue) def testFIFOSharedQueue(self): shared_queue = data_flow_ops.FIFOQueue( capacity=256, dtypes=[dtypes_lib.string, dtypes_lib.string]) self._verify_all_data_sources_read(shared_queue) def testReadUpToFromRandomShuffleQueue(self): shared_queue = data_flow_ops.RandomShuffleQueue( capacity=55, min_after_dequeue=28, dtypes=[dtypes_lib.string, dtypes_lib.string], shapes=[[], []]) self._verify_read_up_to_out(shared_queue) def testReadUpToFromFIFOQueue(self): shared_queue = data_flow_ops.FIFOQueue( capacity=99, dtypes=[dtypes_lib.string, dtypes_lib.string], shapes=[[], []]) self._verify_read_up_to_out(shared_queue) class ParallelReadTest(test.TestCase): def setUp(self): super(ParallelReadTest, self).setUp() ops.reset_default_graph() def testTFRecordReader(self): with self.cached_session(): self._tfrecord_paths = test_utils.create_tfrecord_files( tempfile.mkdtemp(), num_files=3) key, value = parallel_reader.parallel_read( self._tfrecord_paths, reader_class=io_ops.TFRecordReader, num_readers=3) sv = supervisor.Supervisor(logdir=tempfile.mkdtemp()) with sv.prepare_or_wait_for_session() as sess: sv.start_queue_runners(sess) flowers = 0 num_reads = 100 for _ in range(num_reads): current_key, _ = sess.run([key, value]) if 'flowers' in str(current_key): flowers += 1 self.assertGreater(flowers, 0) self.assertEqual(flowers, num_reads) class SinglePassReadTest(test.TestCase): def setUp(self): super(SinglePassReadTest, self).setUp() ops.reset_default_graph() def testOutOfRangeError(self): with self.cached_session(): [tfrecord_path] = test_utils.create_tfrecord_files( tempfile.mkdtemp(), num_files=1) key, value = parallel_reader.single_pass_read( tfrecord_path, reader_class=io_ops.TFRecordReader) init_op = variables.local_variables_initializer() with self.cached_session() as sess: sess.run(init_op) with queues.QueueRunners(sess): num_reads = 11 with self.assertRaises(errors_impl.OutOfRangeError): for _ in range(num_reads): sess.run([key, value]) def testTFRecordReader(self): with self.cached_session(): [tfrecord_path] = test_utils.create_tfrecord_files( tempfile.mkdtemp(), num_files=1) key, value = parallel_reader.single_pass_read( tfrecord_path, reader_class=io_ops.TFRecordReader) init_op = variables.local_variables_initializer() with self.cached_session() as sess: sess.run(init_op) with queues.QueueRunners(sess): flowers = 0 num_reads = 9 for _ in range(num_reads): current_key, _ = sess.run([key, value]) if 'flowers' in str(current_key): flowers += 1 self.assertGreater(flowers, 0) self.assertEqual(flowers, num_reads) if __name__ == '__main__': test.main()