# coding=utf-8
# Copyright 2019 The Interval Bound Propagation Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definition of input bounds to each layer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import itertools

import six
import sonnet as snt
import tensorflow.compat.v1 as tf


@six.add_metaclass(abc.ABCMeta)
class AbstractBounds(object):
  """Abstract bounds class."""

  def __init__(self):
    self._update_cache_op = None

  @classmethod
  @abc.abstractmethod
  def convert(cls, bounds):
    """Converts another bound type to this type."""

  @abc.abstractproperty
  def shape(self):
    """Returns shape (as list) of the tensor, including batch dimension."""

  def concretize(self):
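    """Returns concrete bounds (the default bounds are already concrete)."""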
    return self

  def _raise_not_implemented(self, name):
    raise NotImplementedError(
        '{} modules are not supported by "{}".'.format(
            name, self.__class__.__name__))

  def apply_linear(self, wrapper, w, b):  # pylint: disable=unused-argument
    self._raise_not_implemented('snt.Linear')

  def apply_conv1d(self, wrapper, w, b, padding, stride):  # pylint: disable=unused-argument
    self._raise_not_implemented('snt.Conv1D')

  def apply_conv2d(self, wrapper, w, b, padding, strides):  # pylint: disable=unused-argument
    self._raise_not_implemented('snt.Conv2D')

  def apply_increasing_monotonic_fn(self, wrapper, fn, *args, **parameters):  # pylint: disable=unused-argument
    self._raise_not_implemented(fn.__name__)

  def apply_piecewise_monotonic_fn(self, wrapper, fn, boundaries, *args):  # pylint: disable=unused-argument
    self._raise_not_implemented(fn.__name__)

  def apply_batch_norm(self, wrapper, mean, variance, scale, bias, epsilon):  # pylint: disable=unused-argument
    self._raise_not_implemented('ibp.BatchNorm')

  def apply_batch_reshape(self, wrapper, shape):  # pylint: disable=unused-argument
    self._raise_not_implemented('snt.BatchReshape')

  def apply_softmax(self, wrapper):  # pylint: disable=unused-argument
    self._raise_not_implemented('tf.nn.softmax')

  @property
  def update_cache_op(self):
    """TF op to update cached bounds for re-use across session.run calls."""
    if self._update_cache_op is None:
      raise ValueError('Bounds not cached: enable_caching() not called.')
    return self._update_cache_op

  def enable_caching(self):
    """Enables caching the bounds for re-use across session.run calls."""
    if self._update_cache_op is not None:
      raise ValueError('Bounds already cached: enable_caching() called twice.')
    self._update_cache_op = self._set_up_cache()

  def _set_up_cache(self):
    """Replace fields with cached versions.

    Returns:
      TensorFlow op to update the cache.
    """
    return tf.no_op()  # By default, don't cache.

  def _cache_with_update_op(self, tensor):
    """Creates non-trainable variable to cache the tensor across sess.run calls.

    Args:
      tensor: Tensor to cache.

    Returns:
      cached_tensor: Non-trainable variable to contain the cached value
        of `tensor`.
      update_op: TensorFlow op to re-evaluate `tensor` and assign the result
        to `cached_tensor`.
    """
    cached_tensor = tf.get_variable(
        tensor.name.replace(':', '__') + '_ibp_cache',
        shape=tensor.shape, dtype=tensor.dtype, trainable=False)
    update_op = tf.assign(cached_tensor, tensor)
    return cached_tensor, update_op


class IntervalBounds(AbstractBounds):
  """Axis-aligned bounding box."""

  def __init__(self, lower, upper):
    super(IntervalBounds, self).__init__()
    self._lower = lower
    self._upper = upper

  @property
  def lower(self):
    return self._lower

  @property
  def upper(self):
    return self._upper

  @property
  def shape(self):
    return self.lower.shape.as_list()

  def __iter__(self):
    yield self.lower
    yield self.upper

  @classmethod
  def convert(cls, bounds):
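    """Converts a Tensor or concretizable bounds to IntervalBounds."""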
    if isinstance(bounds, tf.Tensor):
      return cls(bounds, bounds)
    bounds = bounds.concretize()
    if not isinstance(bounds, cls):
      raise ValueError('Cannot convert "{}" to "{}"'.format(bounds,
                                                            cls.__name__))
    return bounds

  def apply_linear(self, wrapper, w, b):
    return self._affine(w, b, tf.matmul)

  def apply_conv1d(self, wrapper, w, b, padding, stride):
    return self._affine(w, b, tf.nn.conv1d, padding=padding, stride=stride)

  def apply_conv2d(self, wrapper, w, b, padding, strides):
    return self._affine(w, b, tf.nn.convolution,
                        padding=padding, strides=strides)

  def _affine(self, w, b, fn, **kwargs):
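    """Propagates the bounds through the affine map `fn(x, w) + b`.

    Standard interval arithmetic: write the input as x = c + r * e with
    midpoint c, radius r >= 0 and |e| <= 1 elementwise. Since `fn` is linear
    in its first argument, fn(x, w) + b = (fn(c, w) + b) + fn(r * e, w), and
    the last term is bounded elementwise by fn(r, |w|).
    """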
    c = (self.lower + self.upper) / 2.
    r = (self.upper - self.lower) / 2.
    c = fn(c, w, **kwargs)
    if b is not None:
      c = c + b
    r = fn(r, tf.abs(w), **kwargs)
    return IntervalBounds(c - r, c + r)

  def apply_increasing_monotonic_fn(self, wrapper, fn, *args, **parameters):
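    """Propagates bounds through an elementwise non-decreasing function.

    Because `fn` is non-decreasing in each of its (tensor) arguments, its
    minimum over the input boxes is attained at all lower bounds and its
    maximum at all upper bounds.
    """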
    args_lower = [self.lower] + [a.lower for a in args]
    args_upper = [self.upper] + [a.upper for a in args]
    return IntervalBounds(fn(*args_lower), fn(*args_upper))

  def apply_piecewise_monotonic_fn(self, wrapper, fn, boundaries, *args):
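    """Propagates bounds through a piecewise-monotonic elementwise function.

    `fn` is monotonic between consecutive `boundaries`, so its extrema over
    the input boxes are attained either at the interval endpoints or at the
    boundary points, clipped into [lower, upper]. All such candidate inputs
    are evaluated and the elementwise min/max of the outputs is returned.
    """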
    valid_values = []
    for a in [self] + list(args):
      vs = []
      vs.append(a.lower)
      vs.append(a.upper)
      for b in boundaries:
        vs.append(
            tf.maximum(a.lower, tf.minimum(a.upper, b * tf.ones_like(a.lower))))
      valid_values.append(vs)
    outputs = []
    for inputs in itertools.product(*valid_values):
      outputs.append(fn(*inputs))
    outputs = tf.stack(outputs, axis=-1)
    return IntervalBounds(tf.reduce_min(outputs, axis=-1),
                          tf.reduce_max(outputs, axis=-1))

  def apply_batch_norm(self, wrapper, mean, variance, scale, bias, epsilon):
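    """Propagates bounds through batch normalization.

    Batch norm is folded into an elementwise affine map x -> w * x + b with
    w = scale * rsqrt(variance + epsilon) and b = bias - w * mean (scale and
    bias default to 1 and 0 when None; the statistics carry a leading batch
    dimension of one, removed by the squeeze below). Since w may be negative,
    the same midpoint/radius scheme as in `_affine` is applied.
    """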
    # Element-wise multiplier.
    multiplier = tf.rsqrt(variance + epsilon)
    if scale is not None:
      multiplier *= scale
    w = multiplier
    # Element-wise bias.
    b = -multiplier * mean
    if bias is not None:
      b += bias
    b = tf.squeeze(b, axis=0)
    # Because the scale might be negative, we need to apply a strategy similar
    # to linear.
    c = (self.lower + self.upper) / 2.
    r = (self.upper - self.lower) / 2.
    c = tf.multiply(c, w) + b
    r = tf.multiply(r, tf.abs(w))
    return IntervalBounds(c - r, c + r)

  def apply_batch_reshape(self, wrapper, shape):
    return IntervalBounds(snt.BatchReshape(shape)(self.lower),
                          snt.BatchReshape(shape)(self.upper))

  def apply_softmax(self, wrapper):
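    """Propagates bounds through a softmax over the last dimension.

    For each coordinate i, softmax(x)_i is minimized by taking the lower
    bound of logit i and the upper bound of every other logit (and vice
    versa for the upper bound). The matrix_diag constructions below build
    these worst-case logit vectors for all coordinates simultaneously.
    """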
    ub = self.upper
    lb = self.lower
    # Keep diagonal and take opposite bound for non-diagonals.
    lbs = tf.matrix_diag(lb) + tf.expand_dims(ub, axis=-2) - tf.matrix_diag(ub)
    ubs = tf.matrix_diag(ub) + tf.expand_dims(lb, axis=-2) - tf.matrix_diag(lb)
    # Get diagonal entries after softmax operation.
    ubs = tf.matrix_diag_part(tf.nn.softmax(ubs))
    lbs = tf.matrix_diag_part(tf.nn.softmax(lbs))
    return IntervalBounds(lbs, ubs)

  def _set_up_cache(self):
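    """Caches the lower and upper bounds in non-trainable variables."""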
    self._lower, update_lower_op = self._cache_with_update_op(self._lower)
    self._upper, update_upper_op = self._cache_with_update_op(self._upper)
    return tf.group([update_lower_op, update_upper_op])