# coding=utf-8 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for slim.data.prefetch_queue.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np import tensorflow.compat.v1 as tf from tf_slim.data import prefetch_queue # pylint:disable=g-direct-tensorflow-import from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib from tensorflow.python.training import queue_runner_impl # pylint:enable=g-direct-tensorflow-import def setUpModule(): tf.disable_eager_execution() class PrefetchQueueTest(test.TestCase): def testOneThread(self): with self.cached_session() as sess: batch_size = 10 image_size = 32 num_batches = 5 zero64 = constant_op.constant(0, dtype=dtypes.int64) examples = variables.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) image = random_ops.random_normal( [image_size, image_size, 3], dtype=dtypes.float32, name='images') label = random_ops.random_uniform( [1], 0, 10, dtype=dtypes.int32, name='labels') batches = input_lib.batch( [counter, image, label], batch_size=batch_size, num_threads=1) batches = prefetch_queue.prefetch_queue(batches).dequeue() variables.global_variables_initializer().run() threads = queue_runner_impl.start_queue_runners() for i in range(num_batches): results = sess.run(batches) self.assertAllEqual(results[0], np.arange(i * batch_size, (i + 1) * batch_size)) self.assertEqual(results[1].shape, (batch_size, image_size, image_size, 3)) self.assertEqual(results[2].shape, (batch_size, 1)) # Reached the limit. with self.assertRaises(errors_impl.OutOfRangeError): sess.run(batches) for thread in threads: thread.join() def testMultiThread(self): with self.cached_session() as sess: batch_size = 10 image_size = 32 num_batches = 5 zero64 = constant_op.constant(0, dtype=dtypes.int64) examples = variables.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) image = random_ops.random_normal( [image_size, image_size, 3], dtype=dtypes.float32, name='images') label = random_ops.random_uniform( [1], 0, 10, dtype=dtypes.int32, name='labels') batches = input_lib.batch( [counter, image, label], batch_size=batch_size, num_threads=4) batches = prefetch_queue.prefetch_queue(batches).dequeue() variables.global_variables_initializer().run() threads = queue_runner_impl.start_queue_runners() value_counter = [] for _ in range(num_batches): results = sess.run(batches) value_counter.append(results[0]) self.assertEqual(results[1].shape, (batch_size, image_size, image_size, 3)) self.assertEqual(results[2].shape, (batch_size, 1)) self.assertAllEqual( np.sort(np.concatenate(value_counter)), np.arange(0, num_batches * batch_size)) # Reached the limit. with self.assertRaises(errors_impl.OutOfRangeError): sess.run(batches) for thread in threads: thread.join() def testMultipleDequeue(self): with self.cached_session() as sess: batch_size = 10 image_size = 32 num_batches = 4 zero64 = constant_op.constant(0, dtype=dtypes.int64) examples = variables.Variable(zero64) counter = examples.count_up_to(num_batches * batch_size) image = random_ops.random_normal( [image_size, image_size, 3], dtype=dtypes.float32, name='images') label = random_ops.random_uniform( [1], 0, 10, dtype=dtypes.int32, name='labels') batches = input_lib.batch( [counter, image, label], batch_size=batch_size, num_threads=4) batcher = prefetch_queue.prefetch_queue(batches) batches_list = [batcher.dequeue() for _ in range(2)] variables.global_variables_initializer().run() threads = queue_runner_impl.start_queue_runners() value_counter = [] for _ in range(int(num_batches / 2)): for batches in batches_list: results = sess.run(batches) value_counter.append(results[0]) self.assertEqual(results[1].shape, (batch_size, image_size, image_size, 3)) self.assertEqual(results[2].shape, (batch_size, 1)) self.assertAllEqual( np.sort(np.concatenate(value_counter)), np.arange(0, num_batches * batch_size)) # Reached the limit. with self.assertRaises(errors_impl.OutOfRangeError): sess.run(batches) for thread in threads: thread.join() def testDynamicPad_failure(self): with ops.Graph().as_default(): variable_tensor = array_ops.placeholder(dtypes.int32, shape=[None, 3]) with self.assertRaisesRegexp(ValueError, 'shapes must be fully defined'): prefetch_queue.prefetch_queue([variable_tensor]) def testDynamicPad(self): with self.cached_session() as sess: # Create 3 tensors of variable but compatible shapes. var_shape = [None, 2] p1 = constant_op.constant([[1, 2], [3, 4]]) p1.set_shape(var_shape) p2 = constant_op.constant([[5, 6], [7, 8], [9, 10]]) p2.set_shape(var_shape) p3 = constant_op.constant([[11, 12]]) p3.set_shape(var_shape) batch = [p1, p2, p3] batch_size = len(batch) zero64 = constant_op.constant(0, dtype=dtypes.int64) examples = variables.Variable(zero64) counter = examples.count_up_to(batch_size) # Create a PaddingFIFOQueue to enqueue these tensors. q = data_flow_ops.PaddingFIFOQueue( capacity=10, dtypes=[dtypes.int32], shapes=[var_shape]) for tensor in [p1, p2, p3]: q.enqueue([tensor]).run() # Dequeue from the queue and batch them using batch(). batches = input_lib.batch([q.dequeue(), counter], batch_size=batch_size, num_threads=1, dynamic_pad=True) self.assertEqual([batch_size, None, 2], batches[0].shape.as_list()) # Finally, assemble them into prefetch_queue with dynamic_pad. batcher = prefetch_queue.prefetch_queue(batches, dynamic_pad=True) batches = batcher.dequeue() self.assertEqual([batch_size, None, 2], batches[0].shape.as_list()) variables.global_variables_initializer().run() threads = queue_runner_impl.start_queue_runners() values, _ = sess.run(batches) # We enqueued 3 tensors of [None, 2] shapes, so using dynamic_pad # they should be padded to the fixed size [3, 3, 2], where 3 # is the maximum length of the batch. self.assertTrue(np.array_equal( np.array([[[1, 2], [3, 4], [0, 0]], [[5, 6], [7, 8], [9, 10]], [[11, 12], [0, 0], [0, 0]]]), values)) with self.assertRaises(errors_impl.OutOfRangeError): sess.run(batches) for thread in threads: thread.join() def testDictConstruction(self): with ops.Graph().as_default(): batches = { 'first': constant_op.constant([1]), 'second': constant_op.constant([2.0, 2.1]) } prefetcher = prefetch_queue.prefetch_queue(batches) dequeued = prefetcher.dequeue() self.assertTrue(isinstance(dequeued, dict)) self.assertEqual(2, len(dequeued)) self.assertEqual(dtypes.int32, dequeued['first'].dtype) self.assertEqual(dtypes.float32, dequeued['second'].dtype) if __name__ == '__main__': test.main()