# coding=utf-8
# Copyright 2019 The Interval Bound Propagation Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the output specifications."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from absl import logging

from interval_bound_propagation.src import bounds as bounds_lib
from interval_bound_propagation.src import verifiable_wrapper
import six
import sonnet as snt
import tensorflow.compat.v1 as tf


@six.add_metaclass(abc.ABCMeta)
class Specification(snt.AbstractModule):
  """Defines a specification."""

  def __init__(self, name, collapse=True):
    super(Specification, self).__init__(name=name)
    self._collapse = collapse

  @abc.abstractmethod
  def _build(self, modules):
    """Computes the worst-case specification value."""

  @abc.abstractmethod
  def evaluate(self, logits):
    """Computes the specification value.

    Args:
      logits: The logits Tensor can have one of the following shapes:
        [batch_size, num_classes]: The output should be [batch_size, num_specs].
        [num_restarts, batch_size, num_classes]: The output should be
          [num_restarts, batch_size, num_specs]. Used by UntargetedPGDAttack.
        [num_restarts, num_specs, batch_size, num_classes]: The output should
          be [num_restarts, batch_size, num_specs]. In this case, each
          specification must be evaluated individually on its own slice along
          axis 1. Used by MultiTargetedPGDAttack.

    Returns:
      The specification values evaluated at the network output.
    """

  @abc.abstractproperty
  def num_specifications(self):
    """Returns the number of specifications."""

  @property
  def collapse(self):
    return self._collapse


class LinearSpecification(Specification):
  """Linear specifications: c^T * z_K + d <= 0."""

  def __init__(self, c, d=None, prune_irrelevant=True, collapse=True):
    """Builds a linear specification module."""
    super(LinearSpecification, self).__init__(name='specs', collapse=collapse)
    # c has shape [batch_size, num_specifications, num_outputs]
    # d has shape [batch_size, num_specifications]
    # Some specifications may be irrelevant (not a function of the output).
    # We automatically remove them for clarity. We expect the number of
    # irrelevant specs to be equal for all elements of a batch.
    # Shape is [batch_size, num_specifications]
    if prune_irrelevant:
      irrelevant = tf.equal(tf.reduce_sum(
          tf.cast(tf.abs(c) > 1e-6, tf.int32), axis=-1, keepdims=True), 0)
      batch_size = tf.shape(c)[0]
      num_outputs = tf.shape(c)[2]
      irrelevant = tf.tile(irrelevant, [1, 1, num_outputs])
      self._c = tf.reshape(
          tf.boolean_mask(c, tf.logical_not(irrelevant)),
          [batch_size, -1, num_outputs])
    else:
      self._c = c
    self._d = d

  def _build(self, modules):
    """Outputs specification value."""
    # inputs have shape [batch_size, num_outputs].
    if not (self.collapse and
            isinstance(modules[-1], verifiable_wrapper.LinearFCWrapper)):
      logging.info('Elision of last layer disabled.')
      bounds = modules[-1].output_bounds
      w = self._c
      b = self._d
    else:
      logging.info('Elision of last layer active.')
      # Collapse the last layer.
      bounds = modules[-1].input_bounds
      w = modules[-1].module.w
      b = modules[-1].module.b
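      # Fold the specification through the last linear layer: since
      # z_K = z w + b, the specification c^T z_K + d becomes a linear function
      # of the layer inputs, with offset c^T b (+ d). Bounds then only need to
      # be propagated up to those inputs.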
      w = tf.einsum('ijk,lk->ijl', self._c, w)
      b = tf.einsum('ijk,k->ij', self._c, b)
      if self._d is not None:
        b += self._d

    # Maximize z * w + b s.t. lower <= z <= upper.
    bounds = bounds_lib.IntervalBounds.convert(bounds)
    c = (bounds.lower + bounds.upper) / 2.
    r = (bounds.upper - bounds.lower) / 2.
    c = tf.einsum('ij,ikj->ik', c, w)
    if b is not None:
      c += b
    r = tf.einsum('ij,ikj->ik', r, tf.abs(w))
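    # c + r is the maximum of <w, z> + b over the interval bounds, attained at
    # z = midpoint + radius * sign(w).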

    # output has shape [batch_size, num_specifications].
    return c + r

  def evaluate(self, logits):
    if len(logits.shape) == 2:
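      # [batch_size, num_classes] -> [batch_size, num_specs].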
      output = tf.einsum('ij,ikj->ik', logits, self._c)
    elif len(logits.shape) == 3:
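      # [num_restarts, batch_size, num_classes] ->
      # [num_restarts, batch_size, num_specs].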
      output = tf.einsum('rij,ikj->rik', logits, self._c)
    else:
      assert len(logits.shape) == 4
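      # [num_restarts, num_specs, batch_size, num_classes] ->
      # [num_restarts, batch_size, num_specs]; each specification only sees
      # its own slice along axis 1.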
      output = tf.einsum('rsbo,bso->rbs', logits, self._c)
    if self._d is not None:
      output += self._d
    return output

  @property
  def num_specifications(self):
    return tf.shape(self._c)[1]

  @property
  def c(self):
    return self._c

  @property
  def d(self):
    return self._d


class ClassificationSpecification(Specification):
  """Creates a linear specification that corresponds to a classification.

  This class is not a standard LinearSpecification as it does not materialize
  the c and d tensors.
  """

  def __init__(self, label, num_classes, collapse=True):
    super(ClassificationSpecification, self).__init__(name='specs',
                                                      collapse=collapse)
    self._label = label
    self._num_classes = num_classes
    # Precompute gather indices: row i of `indices` lists every class except
    # i, so gathering rows with the true label yields each example's "wrong"
    # classes.
    with self._enter_variable_scope():
      indices = []
      for i in range(self._num_classes):
        indices.append(list(range(i)) + list(range(i + 1, self._num_classes)))
      indices = tf.constant(indices, dtype=tf.int32)
      self._correct_idx, self._wrong_idx = self._build_indices(label, indices)

  def _build(self, modules):
    if not (self.collapse and
            isinstance(modules[-1], verifiable_wrapper.LinearFCWrapper)):
      logging.info('Elision of last layer disabled.')
      bounds = modules[-1].output_bounds
      bounds = bounds_lib.IntervalBounds.convert(bounds)
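      # The worst case of each specification is the upper bound of the wrong
      # class logit minus the lower bound of the true class logit.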
      correct_class_logit = tf.gather_nd(bounds.lower, self._correct_idx)
      wrong_class_logits = tf.gather_nd(bounds.upper, self._wrong_idx)
      return wrong_class_logits - tf.expand_dims(correct_class_logit, 1)

    logging.info('Elision of last layer active.')
    bounds = modules[-1].input_bounds
    bounds = bounds_lib.IntervalBounds.convert(bounds)
    batch_size = tf.shape(bounds.lower)[0]
    w = modules[-1].module.w
    b = modules[-1].module.b
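    # Elide the last layer: each specification reads (wrong logit - true
    # logit), so its effective weights and bias are differences of the
    # corresponding columns of w and entries of b.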
    w_t = tf.tile(tf.expand_dims(tf.transpose(w), 0), [batch_size, 1, 1])
    b_t = tf.tile(tf.expand_dims(b, 0), [batch_size, 1])
    w_correct = tf.expand_dims(tf.gather_nd(w_t, self._correct_idx), -1)
    b_correct = tf.expand_dims(tf.gather_nd(b_t, self._correct_idx), 1)
    w_wrong = tf.transpose(tf.gather_nd(w_t, self._wrong_idx), [0, 2, 1])
    b_wrong = tf.gather_nd(b_t, self._wrong_idx)
    w = w_wrong - w_correct
    b = b_wrong - b_correct
    # Maximize z * w + b s.t. lower <= z <= upper.
    c = (bounds.lower + bounds.upper) / 2.
    r = (bounds.upper - bounds.lower) / 2.
    c = tf.einsum('ij,ijk->ik', c, w)
    if b is not None:
      c += b
    r = tf.einsum('ij,ijk->ik', r, tf.abs(w))
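    # As above, c + r is the maximum of <w, z> + b over the interval bounds.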
    return c + r

  def evaluate(self, logits):
    if len(logits.shape) == 2:
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      correct_class_logit = tf.expand_dims(correct_class_logit, -1)
      wrong_class_logits = tf.gather_nd(logits, self._wrong_idx)
    elif len(logits.shape) == 3:
      # [num_restarts, batch_size, num_classes] to
      # [num_restarts, batch_size, num_specs]
      # Put the restart dimension last so that the precomputed [batch, class]
      # gather indices can be used directly.
      logits = tf.transpose(logits, [1, 2, 0])
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      correct_class_logit = tf.transpose(correct_class_logit)
      correct_class_logit = tf.expand_dims(correct_class_logit, -1)
      wrong_class_logits = tf.gather_nd(logits, self._wrong_idx)
      wrong_class_logits = tf.transpose(wrong_class_logits, [2, 0, 1])
    else:
      assert len(logits.shape) == 4
      # [num_restarts, num_specs, batch_size, num_classes] to
      # [num_restarts, batch_size, num_specs].
      logits = tf.transpose(logits, [2, 3, 1, 0])
      correct_class_logit = tf.gather_nd(logits, self._correct_idx)
      correct_class_logit = tf.transpose(correct_class_logit, [2, 0, 1])
      batch_size = tf.shape(logits)[0]
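      # After the transpose, logits is [batch, class, spec, restart]; each
      # specification must read its own slice of the spec dimension, so append
      # the specification index as a third gather coordinate.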
      wrong_idx = tf.concat([
          self._wrong_idx,
          tf.tile(tf.reshape(tf.range(self.num_specifications, dtype=tf.int32),
                             [1, self.num_specifications, 1]),
                  [batch_size, 1, 1])], axis=-1)
      wrong_class_logits = tf.gather_nd(logits, wrong_idx)
      wrong_class_logits = tf.transpose(wrong_class_logits, [2, 0, 1])
    return wrong_class_logits - correct_class_logit

  @property
  def num_specifications(self):
    return self._num_classes - 1

  @property
  def correct_idx(self):
    return self._correct_idx

  @property
  def wrong_idx(self):
    return self._wrong_idx

  def _build_indices(self, label, indices):
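    """Returns gather indices for the true class and all other classes.

    `correct_idx` has shape [batch_size, 2] and indexes each example's true
    class; `wrong_idx` has shape [batch_size, num_classes - 1, 2] and indexes
    every other class.
    """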
    batch_size = tf.shape(label)[0]
    i = tf.range(batch_size, dtype=tf.int32)
    correct_idx = tf.stack([i, tf.cast(label, tf.int32)], axis=1)
    wrong_idx = tf.stack([
        tf.tile(tf.reshape(i, [batch_size, 1]), [1, self._num_classes - 1]),
        tf.gather(indices, label),
    ], axis=2)
    return correct_idx, wrong_idx


class TargetedClassificationSpecification(ClassificationSpecification):
  """Defines a specification that compares the true class with another."""

  def __init__(self, label, num_classes, target_class, collapse=True):
    super(TargetedClassificationSpecification, self).__init__(
        label, num_classes, collapse=collapse)
    batch_size = tf.shape(label)[0]
    if len(target_class.shape) == 1:
      target_class = tf.reshape(target_class, [batch_size, 1])
    self._num_specifications = target_class.shape[1].value
    if self._num_specifications is None:
      raise ValueError('Cannot retrieve the number of target classes')
    self._target_class = target_class
    i = tf.range(batch_size, dtype=tf.int32)
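    # Override the wrong-class indices computed by the parent class so that
    # only the target classes are compared against the true class.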
    self._wrong_idx = tf.stack([
        tf.tile(tf.reshape(i, [batch_size, 1]), [1, self.num_specifications]),
        target_class
    ], axis=2)

  @property
  def target_class(self):
    """Returns the target class index."""
    return self._target_class

  @property
  def num_specifications(self):
    return self._num_specifications


class RandomClassificationSpecification(TargetedClassificationSpecification):
  """Creates a single random specification that targets a random class."""

  def __init__(self, label, num_classes, num_targets=1, seed=None,
               collapse=True):
    # Overwrite the target indices. Each session.run() call gets new target
    # indices, but the indices remain the same across restarts.
    batch_size = tf.shape(label)[0]
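    # An offset j drawn from [1, num_classes - 1] and added modulo num_classes
    # guarantees that the target class differs from the true label.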
    j = tf.random.uniform(shape=(batch_size, num_targets), minval=1,
                          maxval=num_classes, dtype=tf.int32, seed=seed)
    target_class = tf.mod(tf.cast(tf.expand_dims(label, -1), tf.int32) + j,
                          num_classes)
    super(RandomClassificationSpecification, self).__init__(
        label, num_classes, target_class, collapse=collapse)


class LeastLikelyClassificationSpecification(
    TargetedClassificationSpecification):
  """Creates a single specification that targets the least likely class."""

  def __init__(self, label, num_classes, logits, num_targets=1, collapse=True):
    # Do not target the true class. If the true class is the least likely to
    # be predicted, it is fine to target any other class as the attack will
    # be successful anyway.
    j = tf.nn.top_k(-logits, k=num_targets, sorted=False).indices
    l = tf.expand_dims(label, 1)
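    # If a selected index equals the true label, shift it by one (modulo
    # num_classes) so that the true class is never targeted.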
    target_class = tf.mod(
        j + tf.cast(tf.equal(j, tf.cast(l, tf.int32)), tf.int32), num_classes)
    super(LeastLikelyClassificationSpecification, self).__init__(
        label, num_classes, target_class, collapse=collapse)