# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Operations for np_box_mask_list.BoxMaskList.

Example box operations that are supported:
  * Areas: compute mask areas
  * IOU: pairwise intersection-over-union scores
"""
import numpy as np

import ava.np_box_list_ops as np_box_list_ops
import ava.np_box_mask_list as np_box_mask_list
import ava.np_mask_ops as np_mask_ops


def box_list_to_box_mask_list(boxlist):
  """Builds a BoxMaskList from a BoxList that carries a 'masks' field.

  The box coordinates and the 'masks' field seed the new BoxMaskList; every
  other extra field found on `boxlist` is then attached under the same key.

  Args:
    boxlist: An np_box_list.BoxList object holding a 'masks' field.

  Returns:
    An np_box_mask_list.BoxMaskList object.

  Raises:
    ValueError: If boxlist does not contain `masks` as a field.
  """
  if not boxlist.has_field('masks'):
    raise ValueError('boxlist does not contain mask field.')
  converted = np_box_mask_list.BoxMaskList(
      box_data=boxlist.get(), mask_data=boxlist.get_field('masks'))
  for field_name in boxlist.get_extra_fields():
    if field_name == 'masks':
      # Already handed to the constructor as mask_data.
      continue
    converted.data[field_name] = boxlist.get_field(field_name)
  return converted


def area(box_mask_list):
  """Returns the area of every mask held by a BoxMaskList.

  Delegates to np_mask_ops.area on the masks of `box_mask_list`.

  Args:
    box_mask_list: np_box_mask_list.BoxMaskList holding N boxes and masks

  Returns:
    a numpy array with shape [N*1] representing mask areas
  """
  masks = box_mask_list.get_masks()
  return np_mask_ops.area(masks)


def intersection(box_mask_list1, box_mask_list2):
  """Returns pairwise intersection areas between two sets of masks.

  Delegates to np_mask_ops.intersection on the masks of both inputs.

  Args:
    box_mask_list1: BoxMaskList holding N boxes and masks
    box_mask_list2: BoxMaskList holding M boxes and masks

  Returns:
    a numpy array with shape [N*M] representing pairwise intersection area
  """
  masks1 = box_mask_list1.get_masks()
  masks2 = box_mask_list2.get_masks()
  return np_mask_ops.intersection(masks1, masks2)


def iou(box_mask_list1, box_mask_list2):
  """Returns pairwise intersection-over-union between two sets of masks.

  Delegates to np_mask_ops.iou on the masks of both inputs; the box
  coordinates are not consulted.

  Args:
    box_mask_list1: BoxMaskList holding N boxes and masks
    box_mask_list2: BoxMaskList holding M boxes and masks

  Returns:
    a numpy array with shape [N, M] representing pairwise iou scores.
  """
  masks1 = box_mask_list1.get_masks()
  masks2 = box_mask_list2.get_masks()
  return np_mask_ops.iou(masks1, masks2)


def ioa(box_mask_list1, box_mask_list2):
  """Returns pairwise intersection-over-area between two sets of masks.

  Intersection-over-area (ioa) between two masks mask1 and mask2 is defined as
  their intersection area over mask2's area. The measure is therefore not
  symmetric: IOA(mask1, mask2) != IOA(mask2, mask1) in general. The
  computation is delegated to np_mask_ops.ioa.

  Args:
    box_mask_list1: np_box_mask_list.BoxMaskList holding N boxes and masks
    box_mask_list2: np_box_mask_list.BoxMaskList holding M boxes and masks

  Returns:
    a numpy array with shape [N, M] representing pairwise ioa scores.
  """
  masks1 = box_mask_list1.get_masks()
  masks2 = box_mask_list2.get_masks()
  return np_mask_ops.ioa(masks1, masks2)


def gather(box_mask_list, indices, fields=None):
  """Gather boxes from np_box_mask_list.BoxMaskList according to indices.

  By default, gather returns boxes corresponding to the input index list, as
  well as all additional fields stored in the box_mask_list (indexing into the
  first dimension).  However one can optionally only gather from a
  subset of fields. The 'masks' field is always gathered, because the result
  is converted back into a BoxMaskList, which requires it.

  Args:
    box_mask_list: np_box_mask_list.BoxMaskList holding N boxes
    indices: a 1-d numpy array of type int_
    fields: (optional) list of fields to also gather from.  If None (default),
        all fields are gathered from.  Pass an empty fields list to only gather
        the box coordinates and masks. The passed-in list is never modified.

  Returns:
    subbox_mask_list: a np_box_mask_list.BoxMaskList corresponding to the subset
        of the input box_mask_list specified by indices

  Raises:
    ValueError: if specified field is not contained in box_mask_list or if the
        indices are not of type int_
  """
  if fields is not None and 'masks' not in fields:
    # Build a new list instead of appending in place so that the caller's
    # `fields` argument is left untouched.
    fields = list(fields) + ['masks']
  return box_list_to_box_mask_list(
      np_box_list_ops.gather(
          boxlist=box_mask_list, indices=indices, fields=fields))


def sort_by_field(box_mask_list, field,
                  order=np_box_list_ops.SortOrder.DESCEND):
  """Reorders boxes, masks and other fields by the values of one field.

  A common use case is reordering the boxes according to descending scores.
  Sorting is delegated to np_box_list_ops.sort_by_field and the result is
  converted back into a BoxMaskList.

  Args:
    box_mask_list: BoxMaskList holding N boxes.
    field: A BoxMaskList field for sorting and reordering the BoxMaskList.
    order: (Optional) sort order passed through to
      np_box_list_ops.sort_by_field. Defaults to
      np_box_list_ops.SortOrder.DESCEND.

  Returns:
    sorted_box_mask_list: A sorted BoxMaskList with the field in the specified
      order.
  """
  sorted_boxlist = np_box_list_ops.sort_by_field(
      boxlist=box_mask_list, field=field, order=order)
  return box_list_to_box_mask_list(sorted_boxlist)


def non_max_suppression(box_mask_list,
                        max_output_size=10000,
                        iou_threshold=1.0,
                        score_threshold=-10.0):
  """Non maximum suppression.

  This op greedily selects a subset of detections, pruning away those whose
  mask has high IOU (intersection over union) overlap (> thresh) with the mask
  of an already selected detection. In each iteration, the detection with
  highest score in the available pool is selected.

  Args:
    box_mask_list: np_box_mask_list.BoxMaskList holding N boxes.  Must contain
      a 'scores' field representing detection scores. All scores belong to the
      same class.
    max_output_size: maximum number of retained boxes
    iou_threshold: intersection over union threshold.
    score_threshold: minimum score threshold. Remove the boxes with scores
                     less than this value. Default value is set to -10. A very
                     low threshold to pass pretty much all the boxes, unless
                     the user sets a different score threshold.

  Returns:
    an np_box_mask_list.BoxMaskList holding M boxes where M <= max_output_size

  Raises:
    ValueError: if 'scores' field does not exist
    ValueError: if threshold is not in [0, 1]
    ValueError: if max_output_size < 0
  """
  if not box_mask_list.has_field('scores'):
    raise ValueError('Field scores does not exist')
  if iou_threshold < 0. or iou_threshold > 1.0:
    raise ValueError('IOU threshold must be in [0, 1]')
  if max_output_size < 0:
    raise ValueError('max_output_size must be bigger than 0.')

  box_mask_list = filter_scores_greater_than(box_mask_list, score_threshold)
  if box_mask_list.num_boxes() == 0:
    return box_mask_list

  # Default sort order is descending, so index 0 is the best-scoring entry.
  box_mask_list = sort_by_field(box_mask_list, 'scores')

  # Prevent further computation if NMS is disabled: an IOU can never exceed
  # 1.0, so nothing would be suppressed and only the size cap applies.
  if iou_threshold == 1.0:
    if box_mask_list.num_boxes() > max_output_size:
      return gather(box_mask_list, np.arange(max_output_size))
    return box_mask_list

  masks = box_mask_list.get_masks()
  num_masks = box_mask_list.num_boxes()

  # is_index_valid[i] is True while entry i has been neither selected nor
  # suppressed.
  is_index_valid = np.ones(num_masks, dtype=bool)
  selected_indices = []
  for i in range(num_masks):
    if len(selected_indices) >= max_output_size:
      # Output quota reached; the remaining candidates cannot be selected.
      break
    if not is_index_valid[i]:
      continue
    selected_indices.append(i)
    is_index_valid[i] = False
    valid_indices = np.where(is_index_valid)[0]
    if valid_indices.size == 0:
      break

    # IOU of the newly selected mask against every remaining candidate.
    overlaps = np_mask_ops.iou(
        np.expand_dims(masks[i], axis=0), masks[valid_indices])
    overlaps = np.squeeze(overlaps, axis=0)
    # Every entry at valid_indices is currently True, so the new validity is
    # simply whether the overlap stays within the threshold.
    is_index_valid[valid_indices] = overlaps <= iou_threshold
  # Force an integer dtype: an empty Python list would otherwise become a
  # float64 array, which numpy does not accept as an index array.
  return gather(box_mask_list, np.array(selected_indices, dtype=int))


def multi_class_non_max_suppression(box_mask_list, score_thresh, iou_thresh,
                                    max_output_size):
  """Runs non maximum suppression independently for every class.

  For each class column of the 'scores' field, entries whose score does not
  exceed `score_thresh` are dropped and non_max_suppression is applied to the
  rest. The per-class survivors are tagged with their class index,
  concatenated, and sorted by score.

  Args:
    box_mask_list: np_box_mask_list.BoxMaskList holding N boxes.  Must contain a
      'scores' field representing detection scores.  The scores array may be
      1 dimensional (in the case of a single class) or 2-dimensional, in which
      case it is taken to have shape [num_boxes, num_classes].
    score_thresh: scalar threshold for score (low scoring boxes are removed).
    iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap
      with previously selected boxes are removed).
    max_output_size: maximum number of retained boxes per class.

  Returns:
    a box_mask_list holding M boxes with a rank-1 scores field representing
      corresponding scores for each box with scores sorted in decreasing order
      and a rank-1 classes field representing a class label for each box.
  Raises:
    ValueError: if iou_thresh is not in [0, 1] or if input box_mask_list does
      not have a valid scores field.
  """
  if not 0 <= iou_thresh <= 1.0:
    raise ValueError('thresh must be between 0 and 1')
  if not isinstance(box_mask_list, np_box_mask_list.BoxMaskList):
    raise ValueError('box_mask_list must be a box_mask_list')
  if not box_mask_list.has_field('scores'):
    raise ValueError('input box_mask_list must have \'scores\' field')

  scores = box_mask_list.get_field('scores')
  rank = len(scores.shape)
  if rank not in (1, 2):
    raise ValueError('scores field must be of rank 1 or 2')
  if rank == 1:
    # Treat a flat score vector as a single-class [num_boxes, 1] matrix.
    scores = np.reshape(scores, [-1, 1])
  elif scores.shape[1] is None:
    raise ValueError('scores field must have statically defined second '
                     'dimension')

  num_boxes = box_mask_list.num_boxes()
  num_scores, num_classes = scores.shape[0], scores.shape[1]
  if num_boxes != num_scores:
    raise ValueError('Incorrect scores field length: actual vs expected.')

  per_class_results = []
  for class_idx in range(num_classes):
    # Fresh list carrying only boxes, masks and this class's scores.
    single_class_list = np_box_mask_list.BoxMaskList(
        box_data=box_mask_list.get(),
        mask_data=box_mask_list.get_masks())
    single_class_list.add_field(
        'scores', np.reshape(scores[:, class_idx], [-1]))
    above_thresh = filter_scores_greater_than(single_class_list, score_thresh)
    kept = non_max_suppression(
        above_thresh,
        max_output_size=max_output_size,
        iou_threshold=iou_thresh,
        score_threshold=score_thresh)
    kept_scores = kept.get_field('scores')
    # One class label per surviving entry, shaped like its scores.
    kept.add_field('classes', np.zeros_like(kept_scores) + class_idx)
    per_class_results.append(kept)

  merged = np_box_list_ops.concatenate(per_class_results)
  ranked = np_box_list_ops.sort_by_field(merged, 'scores')
  return box_list_to_box_mask_list(boxlist=ranked)


def prune_non_overlapping_masks(box_mask_list1, box_mask_list2, minoverlap=0.0):
  """Drops entries of list1 whose masks overlap too little with list2.

  A mask in box_mask_list1 is kept when its best IOA against the masks of
  box_mask_list2 is greater than or equal to minoverlap; otherwise it is
  removed. The IOA is computed as ioa(box_mask_list2, box_mask_list1), i.e.
  the intersection is normalized by the area of the list1 mask.

  Args:
    box_mask_list1: np_box_mask_list.BoxMaskList holding N boxes and masks.
    box_mask_list2: np_box_mask_list.BoxMaskList holding M boxes and masks.
    minoverlap: Minimum required overlap between masks, to count them as
                overlapping.

  Returns:
    A np_box_mask_list.BoxMaskList holding the N' entries of box_mask_list1
    that were kept.
  """
  pairwise_ioa = ioa(box_mask_list2, box_mask_list1)  # [M, N]
  # Best overlap of each list1 mask with any list2 mask.
  # NOTE(review): np.amax raises on a zero-length axis, so M == 0 is
  # presumably unsupported here - confirm against callers.
  best_ioa = np.amax(pairwise_ioa, axis=0)  # [N]
  keep_mask = np.greater_equal(best_ioa, np.array(minoverlap))
  kept_indices = np.where(keep_mask)[0]
  return gather(box_mask_list1, kept_indices)


def concatenate(box_mask_lists, fields=None):
  """Concatenate list of box_mask_lists.

  This op concatenates a list of input box_mask_lists into a larger
  box_mask_list.  It also handles concatenation of box_mask_list fields as
  long as the field array shapes are equal except for the first dimension.
  The 'masks' field is always concatenated, because the result is converted
  back into a BoxMaskList, which requires it.

  Args:
    box_mask_lists: list of np_box_mask_list.BoxMaskList objects
    fields: optional list of fields to also concatenate.  By default, all
      fields from the first BoxMaskList in the list are included in the
      concatenation. The passed-in list is never modified.

  Returns:
    a box_mask_list with number of boxes equal to
      sum([box_mask_list.num_boxes() for box_mask_list in box_mask_list])
  Raises:
    ValueError: if box_mask_lists is invalid (i.e., is not a list, is empty, or
      contains non box_mask_list objects), or if requested fields are not
      contained in all box_mask_lists
  """
  if fields is not None and 'masks' not in fields:
    # Build a new list instead of appending in place so that the caller's
    # `fields` argument is left untouched.
    fields = list(fields) + ['masks']
  return box_list_to_box_mask_list(
      np_box_list_ops.concatenate(boxlists=box_mask_lists, fields=fields))


def filter_scores_greater_than(box_mask_list, thresh):
  """Filter to keep only boxes and masks with score exceeding a given threshold.

  This op keeps the collection of boxes and masks whose corresponding scores are
  strictly greater than the input threshold.

  Args:
    box_mask_list: BoxMaskList holding N boxes and masks.  Must contain a
      'scores' field representing detection scores, of rank 1 or of shape
      [N, 1].
    thresh: scalar threshold

  Returns:
    a BoxMaskList holding M boxes and masks where M <= N

  Raises:
    ValueError: if box_mask_list not a np_box_mask_list.BoxMaskList object or
      if it does not have a scores field
  """
  if not isinstance(box_mask_list, np_box_mask_list.BoxMaskList):
    raise ValueError('box_mask_list must be a BoxMaskList')
  if not box_mask_list.has_field('scores'):
    raise ValueError('input box_mask_list must have \'scores\' field')
  scores = box_mask_list.get_field('scores')
  if len(scores.shape) > 2:
    raise ValueError('Scores should have rank 1 or 2')
  if len(scores.shape) == 2 and scores.shape[1] != 1:
    raise ValueError('Scores should have rank 1 or have shape '
                     'consistent with [None, 1]')
  # Flatten before calling np.where. On an [N, 1] array np.where returns a
  # (row_indices, column_indices) pair; flattening that pair would append the
  # all-zero column indices to the row indices and gather box 0 repeatedly.
  flat_scores = np.reshape(scores, [-1])
  high_score_indices = np.where(
      np.greater(flat_scores, thresh))[0].astype(np.int32)
  return gather(box_mask_list, high_score_indices)