# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tf_cnn_benchmarks.cnn_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import time

import tensorflow.compat.v1 as tf

import cnn_util


class CnnUtilBarrierTest(tf.test.TestCase):

  def testBarrier(self):
    num_tasks = 20
    num_waits = 4
    barrier = cnn_util.Barrier(num_tasks)
    threads = []
    sync_matrix = []
    for i in range(num_tasks):
      sync_times = [0] * num_waits
      thread = threading.Thread(
          target=self._run_task, args=(barrier, sync_times))
      thread.start()
      threads.append(thread)
      sync_matrix.append(sync_times)
    for thread in threads:
      thread.join()
    for wait_index in range(num_waits - 1):
      # Max of times at iteration i < min of times at iteration i + 1
      self.assertLessEqual(
          max([sync_matrix[i][wait_index] for i in range(num_tasks)]),
          min([sync_matrix[i][wait_index + 1] for i in range(num_tasks)]))

  def _run_task(self, barrier, sync_times):
    for wait_index in range(len(sync_times)):
      sync_times[wait_index] = time.time()
      barrier.wait()

  def testBarrierAbort(self):
    num_tasks = 2
    num_waits = 1
    sync_times = [0] * num_waits
    barrier = cnn_util.Barrier(num_tasks)
    thread = threading.Thread(
        target=self._run_task, args=(barrier, sync_times))
    thread.start()
    barrier.abort()
    # thread won't be blocked by done barrier.
    thread.join()


class ImageProducerTest(tf.test.TestCase):

  def _slow_tensorflow_op(self):
    """Returns a TensorFlow op that takes approximately 0.1s to complete."""
    def slow_func(v):
      time.sleep(0.1)
      return v
    return tf.py_func(slow_func, [tf.constant(0.)], tf.float32).op

  def _test_image_producer(self, batch_group_size, put_slower_than_get):
    # We use the variable x to simulate a staging area of images. x represents
    # the number of batches in the staging area.
    x = tf.Variable(0, dtype=tf.int32)
    if put_slower_than_get:
      put_dep = self._slow_tensorflow_op()
      get_dep = tf.no_op()
    else:
      put_dep = tf.no_op()
      get_dep = self._slow_tensorflow_op()
    with tf.control_dependencies([put_dep]):
      put_op = x.assign_add(batch_group_size, use_locking=True)
    with tf.control_dependencies([get_dep]):
      get_op = x.assign_sub(1, use_locking=True)
    with self.test_session() as sess:
      sess.run(tf.variables_initializer([x]))
      image_producer = cnn_util.ImageProducer(sess, put_op, batch_group_size,
                                              use_python32_barrier=False)
      image_producer.start()
      for _ in range(5 * batch_group_size):
        sess.run(get_op)
        # We assert x is nonnegative, to ensure image_producer never causes
        # an unstage op to block. We assert x is at most 2 * batch_group_size,
        # to ensure it doesn't use too much memory by storing too many batches
        # in the staging area.
        self.assertGreaterEqual(sess.run(x), 0)
        self.assertLessEqual(sess.run(x), 2 * batch_group_size)
        image_producer.notify_image_consumption()
        self.assertGreaterEqual(sess.run(x), 0)
        self.assertLessEqual(sess.run(x), 2 * batch_group_size)

      image_producer.done()
      time.sleep(0.1)
      self.assertGreaterEqual(sess.run(x), 0)
      self.assertLessEqual(sess.run(x), 2 * batch_group_size)

  def test_image_producer(self):
    self._test_image_producer(1, False)
    self._test_image_producer(1, True)
    self._test_image_producer(2, False)
    self._test_image_producer(2, True)
    self._test_image_producer(3, False)
    self._test_image_producer(3, True)
    self._test_image_producer(8, False)
    self._test_image_producer(8, True)


if __name__ == '__main__':
  tf.disable_v2_behavior()
  tf.test.main()