# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for object_detection.core.matcher."""
import numpy as np
import tensorflow as tf

from object_detection.core import matcher


class AnchorMatcherTest(tf.test.TestCase):
  """Tests for the accessors of `matcher.Match`.

  Most tests share the fixture `_MATCH_RESULTS`. As the expected values in
  these tests demonstrate, an entry >= 0 marks a matched column (the value is
  the matched row index), -1 marks an unmatched column and -2 marks an
  ignored column.
  """

  # Shared match vector: columns 0, 1, 3, 5 are matched (to rows 3, 1, 0, 5),
  # columns 2 and 4 are unmatched, and column 6 is ignored.
  _MATCH_RESULTS = [3, 1, -1, 0, -1, 5, -2]

  def _make_fixture_match(self):
    """Returns a `matcher.Match` built from the shared `_MATCH_RESULTS`."""
    # Built inside each test so the constant lives in that test's graph.
    return matcher.Match(tf.constant(self._MATCH_RESULTS))

  def test_get_correct_matched_columnIndices(self):
    match = self._make_fixture_match()
    expected_column_indices = [0, 1, 3, 5]
    matched_column_indices = match.matched_column_indices()
    self.assertEqual(matched_column_indices.dtype, tf.int32)
    with self.test_session() as sess:
      matched_column_indices = sess.run(matched_column_indices)
      self.assertAllEqual(matched_column_indices, expected_column_indices)

  def test_get_correct_counts(self):
    match = self._make_fixture_match()
    exp_num_matched_columns = 4
    exp_num_unmatched_columns = 2
    exp_num_ignored_columns = 1
    num_matched_columns = match.num_matched_columns()
    num_unmatched_columns = match.num_unmatched_columns()
    num_ignored_columns = match.num_ignored_columns()
    self.assertEqual(num_matched_columns.dtype, tf.int32)
    self.assertEqual(num_unmatched_columns.dtype, tf.int32)
    self.assertEqual(num_ignored_columns.dtype, tf.int32)
    with self.test_session() as sess:
      (num_matched_columns_out, num_unmatched_columns_out,
       num_ignored_columns_out) = sess.run(
           [num_matched_columns, num_unmatched_columns, num_ignored_columns])
      self.assertAllEqual(num_matched_columns_out, exp_num_matched_columns)
      self.assertAllEqual(num_unmatched_columns_out, exp_num_unmatched_columns)
      self.assertAllEqual(num_ignored_columns_out, exp_num_ignored_columns)

  def testGetCorrectUnmatchedColumnIndices(self):
    match = self._make_fixture_match()
    expected_column_indices = [2, 4]
    unmatched_column_indices = match.unmatched_column_indices()
    self.assertEqual(unmatched_column_indices.dtype, tf.int32)
    with self.test_session() as sess:
      unmatched_column_indices = sess.run(unmatched_column_indices)
      self.assertAllEqual(unmatched_column_indices, expected_column_indices)

  def testGetCorrectMatchedRowIndices(self):
    match = self._make_fixture_match()
    # Row indices come out in column order: columns 0, 1, 3, 5 -> rows 3, 1, 0, 5.
    expected_row_indices = [3, 1, 0, 5]
    matched_row_indices = match.matched_row_indices()
    self.assertEqual(matched_row_indices.dtype, tf.int32)
    with self.test_session() as sess:
      matched_row_inds = sess.run(matched_row_indices)
      self.assertAllEqual(matched_row_inds, expected_row_indices)

  def test_get_correct_ignored_column_indices(self):
    match = self._make_fixture_match()
    expected_column_indices = [6]
    ignored_column_indices = match.ignored_column_indices()
    self.assertEqual(ignored_column_indices.dtype, tf.int32)
    with self.test_session() as sess:
      ignored_column_indices = sess.run(ignored_column_indices)
      self.assertAllEqual(ignored_column_indices, expected_column_indices)

  def test_get_correct_matched_column_indicator(self):
    match = self._make_fixture_match()
    expected_column_indicator = [True, True, False, True, False, True, False]
    matched_column_indicator = match.matched_column_indicator()
    self.assertEqual(matched_column_indicator.dtype, tf.bool)
    with self.test_session() as sess:
      matched_column_indicator = sess.run(matched_column_indicator)
      self.assertAllEqual(matched_column_indicator, expected_column_indicator)

  def test_get_correct_unmatched_column_indicator(self):
    match = self._make_fixture_match()
    expected_column_indicator = [False, False, True, False, True, False, False]
    unmatched_column_indicator = match.unmatched_column_indicator()
    self.assertEqual(unmatched_column_indicator.dtype, tf.bool)
    with self.test_session() as sess:
      unmatched_column_indicator = sess.run(unmatched_column_indicator)
      self.assertAllEqual(unmatched_column_indicator, expected_column_indicator)

  def test_get_correct_ignored_column_indicator(self):
    match = self._make_fixture_match()
    expected_column_indicator = [False, False, False, False, False, False, True]
    ignored_column_indicator = match.ignored_column_indicator()
    self.assertEqual(ignored_column_indicator.dtype, tf.bool)
    with self.test_session() as sess:
      ignored_column_indicator = sess.run(ignored_column_indicator)
      self.assertAllEqual(ignored_column_indicator, expected_column_indicator)

  def test_get_correct_unmatched_ignored_column_indices(self):
    match = self._make_fixture_match()
    expected_column_indices = [2, 4, 6]
    unmatched_ignored_column_indices = (
        match.unmatched_or_ignored_column_indices())
    self.assertEqual(unmatched_ignored_column_indices.dtype, tf.int32)
    with self.test_session() as sess:
      unmatched_ignored_column_indices = sess.run(
          unmatched_ignored_column_indices)
      self.assertAllEqual(unmatched_ignored_column_indices,
                          expected_column_indices)

  def test_all_columns_accounted_for(self):
    """Matched, unmatched and ignored indices partition all columns."""
    # Note: deliberately setting to small number so not always
    # all possibilities appear (matched, unmatched, ignored)
    num_matches = 10
    # Values drawn from [-2, 5), so -2 (ignored), -1 (unmatched) and
    # non-negative (matched) entries can all occur.
    match_results = tf.random_uniform(
        [num_matches], minval=-2, maxval=5, dtype=tf.int32)
    match = matcher.Match(match_results)
    matched_column_indices = match.matched_column_indices()
    unmatched_column_indices = match.unmatched_column_indices()
    ignored_column_indices = match.ignored_column_indices()
    with self.test_session() as sess:
      matched, unmatched, ignored = sess.run([
          matched_column_indices, unmatched_column_indices,
          ignored_column_indices
      ])
      # Every column must appear in exactly one of the three index sets.
      all_indices = np.hstack((matched, unmatched, ignored))
      all_indices_sorted = np.sort(all_indices)
      self.assertAllEqual(all_indices_sorted,
                          np.arange(num_matches, dtype=np.int32))


if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's test runner, which
  # discovers and runs the test cases defined in this module.
  tf.test.main()