# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Categorical calibration layer with monotonicity and bound constraints.

Keras implementation of tensorflow lattice categorical calibration layer. This
layer takes single or multi-dimensional input and transforms it using lookup
tables satisfying monotonicity and bounds constraints if specified.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from . import categorical_calibration_lib
import tensorflow as tf
from tensorflow import keras

DEFAULT_INPUT_VALUE_NAME = "default_input_value"
CATEGORICAL_CALIBRATION_KERNEL_NAME = "categorical_calibration_kernel"

# TODO: implement variation/variance regularizer.


class CategoricalCalibration(keras.layers.Layer):
  # pyformat: disable
  """Categorical calibration layer with monotonicity and bound constraints.

  This layer takes input of shape `(batch_size, units)` or `(batch_size, 1)` and
  transforms it using `units` number of lookup tables satisfying monotonicity
  and bounds constraints if specified. If multi dimensional input is provided,
  each output will be for the corresponding input, otherwise all calibration
  functions will act on the same input. All units share the same layer
  configuration, but each one has their separate set of trained parameters.

  Input shape:
  Rank-2 tensor with shape:  `(batch_size, units)` or `(batch_size, 1)`.

  Output shape:
  Rank-2 tensor with shape: `(batch_size, units)`.

  Attributes:
    - All `__init__` args.
    kernel: TF variable of shape `(batch_size, units)` which stores the lookup
    table.

  Example:

  ```python
  calibrator = tfl.layers.CategoricalCalibration(
      # Number of categories.
      num_buckets=3,
      # Output can be bounded.
      output_min=0.0,
      output_max=1.0,
      # For categorical calibration layer monotonicity is specified for pairs of
      # indices of categories. Output for first category in pair will be less
      # than or equal to output for second category.
      monotonicities=[(0, 1), (0, 2)])
  ```

  Usage with functional models:

  ```python
  input_feature = keras.layers.Input(shape=[1])
  calibrated_feature = tfl.layers.CategoricalCalibration(
      num_buckets=3,
      output_min=0.0,
      output_max=1.0,
      monotonicities=[(0, 1), (0, 2)],
  )(feature)
  ...
  model = keras.models.Model(
      inputs=[input_feature, ...],
      outputs=...)
  ```
  """
  # pyformat: enable

  def __init__(self,
               num_buckets,
               units=1,
               output_min=None,
               output_max=None,
               monotonicities=None,
               kernel_initializer="uniform",
               kernel_regularizer=None,
               default_input_value=None,
               **kwargs):
    # pyformat: disable
    """Initializes a `CategoricalCalibration` instance.

    Args:
      num_buckets: Number of categories.
      units: Output dimension of the layer. See class comments for details.
      output_min: Minimum output of calibrator.
      output_max: Maximum output of calibrator.
      monotonicities: List of pairs with `(i, j)` indices indicating `output(i)`
        should be less than or equal to `output(j)`.
      kernel_initializer: None or one of:
        - `'uniform'`: If `output_min` and `output_max` are provided initial
          values will be uniformly sampled from `[output_min, output_max]`
          range.
        - `'constant'`: If `output_min` and `output_max` are provided all output
          values will be initlized to the constant
          `(output_min + output_max) / 2`.
        - Any Keras initializer object.
      kernel_regularizer: None or single element or list of any Keras
        regularizer objects.
      default_input_value: If set, all inputs which are equal to this value will
        be treated as default and mapped to the last bucket.
      **kwargs: Other args passed to `tf.keras.layers.Layer` initializer.

    Raises:
      ValueError: If layer hyperparameters are invalid.
    """
    # pyformat: enable
    dtype = kwargs.pop("dtype", tf.float32)  # output dtype
    super(CategoricalCalibration, self).__init__(dtype=dtype, **kwargs)

    categorical_calibration_lib.verify_hyperparameters(
        num_buckets=num_buckets,
        output_min=output_min,
        output_max=output_max,
        monotonicities=monotonicities)
    self.num_buckets = num_buckets
    self.units = units
    self.output_min = output_min
    self.output_max = output_max
    self.monotonicities = monotonicities
    if output_min is not None and output_max is not None:
      if kernel_initializer == "constant":
        kernel_initializer = keras.initializers.Constant(
            (output_min + output_max) / 2)
      elif kernel_initializer == "uniform":
        kernel_initializer = keras.initializers.RandomUniform(
            output_min, output_max)
    self.kernel_initializer = keras.initializers.get(kernel_initializer)
    self.kernel_regularizer = []
    if kernel_regularizer:
      if callable(kernel_regularizer):
        kernel_regularizer = [kernel_regularizer]
      for reg in kernel_regularizer:
        self.kernel_regularizer.append(keras.regularizers.get(reg))
    self.default_input_value = default_input_value

  def build(self, input_shape):
    """Standard Keras build() method."""
    if (self.output_min is not None or self.output_max is not None or
        self.monotonicities):
      constraints = CategoricalCalibrationConstraints(
          output_min=self.output_min,
          output_max=self.output_max,
          monotonicities=self.monotonicities)
    else:
      constraints = None

    if not self.kernel_regularizer:
      kernel_reg = None
    elif len(self.kernel_regularizer) == 1:
      kernel_reg = self.kernel_regularizer[0]
    else:
      # Keras interface assumes only one regularizer, so summ all regularization
      # losses which we have.
      kernel_reg = lambda x: tf.add_n([r(x) for r in self.kernel_regularizer])

    # categorical calibration layer kernel is units-column matrix with value of
    # output(i) = self.kernel[i]. Default value converted to the last index.
    self.kernel = self.add_weight(
        CATEGORICAL_CALIBRATION_KERNEL_NAME,
        shape=[self.num_buckets, self.units],
        initializer=self.kernel_initializer,
        regularizer=kernel_reg,
        constraint=constraints,
        dtype=self.dtype)

    if self.kernel_regularizer and not tf.executing_eagerly():
      # Keras has its own mechanism to handle regularization losses which
      # does not use GraphKeys, but we want to also add losses to graph keys so
      # they are easily accessable when layer is being used outside of Keras.
      # Adding losses to GraphKeys will not interfer with Keras.
      for reg in self.kernel_regularizer:
        tf.compat.v1.add_to_collection(
            tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, reg(self.kernel))

    super(CategoricalCalibration, self).build(input_shape)

  def call(self, inputs):
    """Standard Keras call() method."""
    if inputs.dtype not in [tf.uint8, tf.int32, tf.int64]:
      inputs = tf.cast(inputs, dtype=tf.int32)

    if self.default_input_value is not None:
      default_input_value_tensor = tf.constant(
          int(self.default_input_value),
          dtype=inputs.dtype,
          name=DEFAULT_INPUT_VALUE_NAME)
      replacement = tf.zeros_like(inputs) + (self.num_buckets - 1)
      inputs = tf.where(
          tf.equal(inputs, default_input_value_tensor), replacement, inputs)

    # We can't use tf.gather_nd(self.kernel, inputs) as it doesn't support
    # constraints (constraint functions are not supported for IndexedSlices).
    # Instead we use matrix multiplication by one-hot encoding of the index.
    if self.units == 1:
      # This can be slightly faster as it uses matmul.
      return tf.matmul(
          tf.one_hot(tf.squeeze(inputs, axis=[-1]), depth=self.num_buckets),
          self.kernel)
    return tf.reduce_sum(
        tf.one_hot(inputs, axis=1, depth=self.num_buckets) * self.kernel,
        axis=1)

  def compute_output_shape(self, input_shape):
    """Standard Keras compute_output_shape() method."""
    del input_shape
    return [None, self.units]

  def get_config(self):
    """Standard Keras config for serialization."""
    config = {
        "num_buckets": self.num_buckets,
        "units": self.units,
        "output_min": self.output_min,
        "output_max": self.output_max,
        "monotonicities": self.monotonicities,
        "kernel_initializer":
            keras.initializers.serialize(self.kernel_initializer),
        "kernel_regularizer":
            [keras.regularizers.serialize(r) for r in self.kernel_regularizer],
        "default_input_value": self.default_input_value,
    }  # pyformat: disable
    config.update(super(CategoricalCalibration, self).get_config())
    return config

  def assert_constraints(self, eps=1e-6):
    """Asserts that layer weights satisfy all constraints.

    In graph mode builds and returns list of assertion ops. Note that ops will
    be created at the moment when this function is being called.
    In eager mode directly executes assertions.

    Args:
      eps: Allowed constraints violation.

    Returns:
      List of assertion ops in graph mode or immediately asserts in eager mode.
    """
    return categorical_calibration_lib.assert_constraints(
        weights=self.kernel,
        output_min=self.output_min,
        output_max=self.output_max,
        monotonicities=self.monotonicities,
        eps=eps)


class CategoricalCalibrationConstraints(keras.constraints.Constraint):
  # pyformat: disable
  """Monotonicity and bounds constraints for categorical calibration layer.

  Updates the weights of CategoricalCalibration layer to satify bound and
  monotonicity constraints. The update is an approximate L2 projection into the
  constrained parameter space.

  Attributes:
    - All `__init__` arguments.
  """
  # pyformat: enable

  def __init__(self, output_min=None, output_max=None, monotonicities=None):
    """Initializes an instance of `CategoricalCalibrationConstraints`.

    Args:
      output_min: Minimum possible output of categorical function.
      output_max: Maximum possible output of categorical function.
      monotonicities: Monotonicities of CategoricalCalibration layer.
    """
    categorical_calibration_lib.verify_hyperparameters(
        output_min=output_min,
        output_max=output_max,
        monotonicities=monotonicities)
    self.monotonicities = monotonicities
    self.output_min = output_min
    self.output_max = output_max

  def __call__(self, w):
    """Applies constraints to w."""
    return categorical_calibration_lib.project(
        weights=w,
        output_min=self.output_min,
        output_max=self.output_max,
        monotonicities=self.monotonicities)

  def get_config(self):
    """Standard Keras config for serialization."""
    return {
        "output_min": self.output_min,
        "output_max": self.output_max,
        "monotonicities": self.monotonicities,
    }  # pyformat: disable