# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for object_detection.trainer."""

import tensorflow as tf

from google.protobuf import text_format

from object_detection.core import losses
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.legacy import trainer
from object_detection.protos import train_pb2


# Number of object classes shared by the synthetic inputs (class labels and
# multiclass scores in get_input_function) and by FakeDetectionModel.
NUMBER_OF_CLASSES = 2


def get_input_function():
  """Builds one synthetic training example for the trainer tests.

  Returns:
    A dict keyed by `fields.InputDataFields` entries holding a random
    [32, 32, 3] float32 image, a constant string key, one groundtruth box with
    a random int32 class label, and a [1, NUMBER_OF_CLASSES] tensor of random
    multiclass scores.
  """
  input_fields = fields.InputDataFields
  tensor_dict = {}
  tensor_dict[input_fields.image] = tf.random_uniform(
      [32, 32, 3], dtype=tf.float32)
  tensor_dict[input_fields.key] = tf.constant('image_000000')
  # A single class id drawn with minval=0 and maxval=NUMBER_OF_CLASSES.
  tensor_dict[input_fields.groundtruth_classes] = tf.random_uniform(
      [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
  # A single box; every coordinate is drawn from [0.4, 0.6).
  tensor_dict[input_fields.groundtruth_boxes] = tf.random_uniform(
      [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
  # One score per class for the single box, also drawn from [0.4, 0.6).
  tensor_dict[input_fields.multiclass_scores] = tf.random_uniform(
      [1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)
  return tensor_dict


class FakeDetectionModel(model.DetectionModel):
  """A simple (and poor) DetectionModel for use in test.

  Resizes images to 28x28 and predicts a single box and a single set of class
  scores per image. The predictions come from fully connected layers over the
  flattened image.
  """

  def __init__(self):
    super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
    self._classification_loss = losses.WeightedSigmoidClassificationLoss()
    self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()

  def preprocess(self, inputs):
    """Input preprocessing, resizes images to 28x28.

    Args:
      inputs: a [batch, height_in, width_in, channels] float32 tensor
        representing a batch of images with values between 0 and 255.0.

    Returns:
      preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
    """
    preprocessed_inputs = tf.image.resize_images(inputs, [28, 28])
    # Resizing stretches rather than pads, so every image fills the whole
    # output. Each row is therefore [28, 28, channels], repeated once per
    # image in the batch.
    resized_shape = tf.shape(preprocessed_inputs)
    true_image_shapes = tf.tile(
        tf.expand_dims(resized_shape[1:], 0), [resized_shape[0], 1])
    return preprocessed_inputs, true_image_shapes

  def predict(self, preprocessed_inputs, true_image_shapes):
    """Prediction tensors from inputs tensor.

    Args:
      preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros. Unused by this fake model.

    Returns:
      prediction_dict: a dictionary holding prediction tensors to be
        passed to the Loss or Postprocess functions.
    """
    flattened_inputs = tf.contrib.layers.flatten(preprocessed_inputs)
    class_prediction = tf.contrib.layers.fully_connected(
        flattened_inputs, self._num_classes)
    box_prediction = tf.contrib.layers.fully_connected(flattened_inputs, 4)

    # Reshape so each image has exactly one prediction slot (the middle
    # dimension of size 1).
    return {
        'class_predictions_with_background': tf.reshape(
            class_prediction, [-1, 1, self._num_classes]),
        'box_encodings': tf.reshape(box_prediction, [-1, 1, 4])
    }

  def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Convert predicted output tensors to final detections. Unused.

    Args:
      prediction_dict: a dictionary holding prediction tensors.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
      **params: Additional keyword arguments for specific implementations of
        DetectionModel.

    Returns:
      detections: a dictionary whose detection fields are all None.
    """
    return {
        'detection_boxes': None,
        'detection_scores': None,
        'detection_classes': None,
        'num_detections': None
    }

  def loss(self, prediction_dict, true_image_shapes):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been
    provided via the provide_groundtruth function.

    Args:
      prediction_dict: a dictionary holding predicted tensors
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros. Unused by this fake model.

    Returns:
      a dictionary mapping strings (loss names) to scalar tensors representing
        loss values.
    """
    groundtruth_boxes = self.groundtruth_lists(fields.BoxListFields.boxes)
    groundtruth_classes = self.groundtruth_lists(fields.BoxListFields.classes)
    batch_reg_targets = tf.stack(groundtruth_boxes)
    batch_cls_targets = tf.stack(groundtruth_classes)
    # One all-ones weight per image for its single prediction slot, giving
    # shape [num_images, 1].
    weights = tf.constant(
        1.0, dtype=tf.float32, shape=[len(groundtruth_boxes), 1])

    location_losses = self._localization_loss(
        prediction_dict['box_encodings'], batch_reg_targets,
        weights=weights)
    cls_losses = self._classification_loss(
        prediction_dict['class_predictions_with_background'], batch_cls_targets,
        weights=weights)

    loss_dict = {
        'localization_loss': tf.reduce_sum(location_losses),
        'classification_loss': tf.reduce_sum(cls_losses),
    }
    return loss_dict

  def regularization_losses(self):
    """Returns the regularization losses for this model.

    This fake model defines no regularization losses and returns None.

    Returns:
      None.
    """
    return None

  def restore_map(self, fine_tune_checkpoint_type='detection'):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.
        Ignored by this fake model.

    Returns:
      A dict mapping every global variable's op name to that variable.
    """
    return {var.op.name: var for var in tf.global_variables()}

  def updates(self):
    """Returns the update operators for this model.

    This fake model defines no model-specific update ops and returns None.
    None is kept deliberately: some callers treat None differently from an
    empty list (e.g. by falling back to the UPDATE_OPS collection).

    Returns:
      None.
    """
    return None


class TrainerTest(tf.test.TestCase):
  """Smoke tests that run the legacy trainer on FakeDetectionModel."""

  # TrainConfig shared by all tests: Adam with a constant learning rate, two
  # data augmentation options and a two-step training budget.
  _TRAIN_CONFIG_TEXT_PROTO = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
    """

  def _train_two_steps(self, use_multiclass_scores=False):
    """Parses the shared TrainConfig and runs trainer.train on the fake model.

    Args:
      use_multiclass_scores: if True, sets `use_multiclass_scores` to true on
        the TrainConfig. Otherwise the field is left unset.
    """
    train_config = train_pb2.TrainConfig()
    text_format.Merge(self._TRAIN_CONFIG_TEXT_PROTO, train_config)
    if use_multiclass_scores:
      # Set the field only when requested, so the default config is exactly
      # the one parsed from the shared text proto.
      train_config.use_multiclass_scores = True

    train_dir = self.get_temp_dir()

    trainer.train(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=train_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=train_dir)

  def test_configure_trainer_and_train_two_steps(self):
    """Trains two steps with the shared config."""
    self._train_two_steps()

  def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Trains two steps with `use_multiclass_scores` enabled."""
    self._train_two_steps(use_multiclass_scores=True)


if __name__ == '__main__':
  # Entry point when this file is executed directly: run the tests through
  # tf.test.main().
  tf.test.main()