# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Polyharmonic spline interpolation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

EPSILON = 1e-10


def _cross_squared_distance_matrix(x, y):
  """Pairwise squared distance between two (batch) matrices' rows (2nd dim).

  Computes the pairwise squared distances between the rows of x and the rows
  of y.

  Args:
    x: [batch_size, n, d] float `Tensor`
    y: [batch_size, m, d] float `Tensor`

  Returns:
    squared_dists: [batch_size, n, m] float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - y[b,j,:]||^2
  """
  x_norm_squared = math_ops.reduce_sum(math_ops.square(x), 2)
  y_norm_squared = math_ops.reduce_sum(math_ops.square(y), 2)

  # Expand so that we can broadcast.
  x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)
  y_norm_squared_tile = array_ops.expand_dims(y_norm_squared, 1)

  x_y_transpose = math_ops.matmul(x, y, adjoint_b=True)

  # squared_dists[b,i,j] = ||x_bi - y_bj||^2 = x_bi'x_bi - 2x_bi'y_bj + y_bj'y_bj
  squared_dists = x_norm_squared_tile - 2 * x_y_transpose + y_norm_squared_tile

  return squared_dists


def _pairwise_squared_distance_matrix(x):
  """Pairwise squared distance among a (batch) matrix's rows (2nd dim).

  This saves a bit of computation vs. using
  _cross_squared_distance_matrix(x, x).

  Args:
    x: `[batch_size, n, d]` float `Tensor`

  Returns:
    squared_dists: `[batch_size, n, n]` float `Tensor`, where
    squared_dists[b,i,j] = ||x[b,i,:] - x[b,j,:]||^2
  """

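  # The squared norms ||x_bi||^2 are the diagonal of x x^T, so the same matmul
  # supplies both the cross term and the norm terms of the expansion below.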
  x_x_transpose = math_ops.matmul(x, x, adjoint_b=True)
  x_norm_squared = array_ops.matrix_diag_part(x_x_transpose)
  x_norm_squared_tile = array_ops.expand_dims(x_norm_squared, 2)

  # squared_dists[b,i,j] = ||x_bi - x_bj||^2 = x_bi'x_bi - 2x_bi'x_bj + x_bj'x_bj
  squared_dists = x_norm_squared_tile - 2 * x_x_transpose + array_ops.transpose(
      x_norm_squared_tile, [0, 2, 1])

  return squared_dists


def _solve_interpolation(train_points, train_values, order,
                         regularization_weight):
  """Solve for interpolation coefficients.

  Computes the coefficients of the polyharmonic interpolant for the 'training'
  data defined by (train_points, train_values) using the kernel phi.

  Args:
    train_points: `[b, n, d]` interpolation centers
    train_values: `[b, n, k]` function values
    order: order of the interpolation
    regularization_weight: weight to place on smoothness regularization term

  Returns:
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d, k]` weights on each input dimension
  """

  # n, d, and k must be statically known; the batch size may be dynamic.
  _, n, d = train_points.get_shape().as_list()
  _, _, k = train_values.get_shape().as_list()
  b = array_ops.shape(train_points)[0]

  # First, rename variables so that the notation (c, f, w, v, A, B, etc.)
  # follows https://en.wikipedia.org/wiki/Polyharmonic_spline.
  # To account for python style guidelines we use
  # matrix_a for A and matrix_b for B.

  c = train_points
  f = train_values

  # Next, construct the linear system.
  with ops.name_scope('construct_linear_system'):

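    # The coefficients (w, v) are the solution of the block linear system
    #   [ A    B ] [ w ]   [ f ]
    #   [ B^T  0 ] [ v ] = [ 0 ],
    # where A[i, j] = phi(||c_i - c_j||) and B = [c, 1].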
    matrix_a = _phi(_pairwise_squared_distance_matrix(c), order)  # [b, n, n]
    if regularization_weight > 0:
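      # Adding regularization_weight * I to A yields a smoothing spline rather
      # than an exact interpolant.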
      batch_identity_matrix = np.expand_dims(np.eye(n), 0)
      batch_identity_matrix = constant_op.constant(
          batch_identity_matrix, dtype=train_points.dtype)

      matrix_a += regularization_weight * batch_identity_matrix

    # Append a column of ones to the point coordinates, for the bias term in
    # the linear model.
    ones = array_ops.ones([b, n, 1], train_points.dtype)
    matrix_b = array_ops.concat([c, ones], 2)  # [b, n, d + 1]

    # [b, n + d + 1, n]
    left_block = array_ops.concat(
        [matrix_a, array_ops.transpose(matrix_b, [0, 2, 1])], 1)

    num_b_cols = matrix_b.get_shape()[2]  # d + 1
    lhs_zeros = array_ops.zeros([b, num_b_cols, num_b_cols], train_points.dtype)
    right_block = array_ops.concat([matrix_b, lhs_zeros],
                                   1)  # [b, n + d + 1, d + 1]
    lhs = array_ops.concat([left_block, right_block],
                           2)  # [b, n + d + 1, n + d + 1]

    rhs_zeros = array_ops.zeros([b, d + 1, k], train_points.dtype)
    rhs = array_ops.concat([f, rhs_zeros], 1)  # [b, n + d + 1, k]

  # Then, solve the linear system and unpack the results.
  with ops.name_scope('solve_linear_system'):
    w_v = linalg_ops.matrix_solve(lhs, rhs)
    w = w_v[:, :n, :]
    v = w_v[:, n:, :]

  return w, v


def _apply_interpolation(query_points, train_points, w, v, order):
  """Apply polyharmonic interpolation model to data.

  Given coefficients w and v for the interpolation model, we evaluate
  interpolated function values at query_points.

  Args:
    query_points: `[b, m, d]` x values to evaluate the interpolation at
    train_points: `[b, n, d]` x values that act as the interpolation centers
                    (the c variables in the wikipedia article)
    w: `[b, n, k]` weights on each interpolation center
    v: `[b, d, k]` weights on each input dimension
    order: order of the interpolation

  Returns:
    Polyharmonic interpolation evaluated at points defined in query_points.
  """

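  # The interpolant is f(x) = sum_i phi(||x - c_i||) w_i + [x, 1] v, with the
  # bias absorbed into v via the column of ones appended below.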
  # The number of query points must be statically known; the batch size may be
  # dynamic.
  batch_size = array_ops.shape(train_points)[0]
  num_query_points = query_points.get_shape()[1].value

  # First, compute the contribution from the rbf term.
  pairwise_dists = _cross_squared_distance_matrix(query_points, train_points)
  phi_pairwise_dists = _phi(pairwise_dists, order)

  rbf_term = math_ops.matmul(phi_pairwise_dists, w)

  # Then, compute the contribution from the linear term.
  # Pad query_points with ones, for the bias term in the linear model.
  query_points_pad = array_ops.concat([
      query_points,
      array_ops.ones([batch_size, num_query_points, 1], train_points.dtype)
  ], 2)
  linear_term = math_ops.matmul(query_points_pad, v)

  return rbf_term + linear_term


def _phi(r, order):
  """Coordinate-wise nonlinearity used to define the order of the interpolation.

  See https://en.wikipedia.org/wiki/Polyharmonic_spline for the definition.

  Args:
    r: input op (here, the matrix of squared pairwise distances)
    order: interpolation order

  Returns:
    phi_k evaluated coordinate-wise on r, for k = order
  """

  # Using EPSILON prevents log(0), sqrt(0), etc.
  # sqrt(0) is well-defined, but its gradient is not.
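  # Since r is a squared distance, the exponents below are halved
  # (r ** (0.5 * order)) and the log term carries a factor of 0.5, so that phi
  # matches its standard definition in terms of the unsquared distance.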
  with ops.name_scope('phi'):
    if order == 1:
      r = math_ops.maximum(r, EPSILON)
      r = math_ops.sqrt(r)
      return r
    elif order == 2:
      return 0.5 * r * math_ops.log(math_ops.maximum(r, EPSILON))
    elif order == 4:
      return 0.5 * math_ops.square(r) * math_ops.log(
          math_ops.maximum(r, EPSILON))
    elif order % 2 == 0:
      r = math_ops.maximum(r, EPSILON)
      return 0.5 * math_ops.pow(r, 0.5 * order) * math_ops.log(r)
    else:
      r = math_ops.maximum(r, EPSILON)
      return math_ops.pow(r, 0.5 * order)


def interpolate_spline(train_points,
                       train_values,
                       query_points,
                       order,
                       regularization_weight=0.0,
                       name='interpolate_spline'):
  r"""Interpolate signal using polyharmonic interpolation.

  The interpolant has the form
  $$f(x) = \sum_{i = 1}^n w_i \phi(||x - c_i||) + v^T x + b.$$

  This is a sum of two terms: (1) a weighted sum of radial basis function (RBF)
  terms, with the centers \\(c_1, ... c_n\\), and (2) a linear term with a bias.
  The \\(c_i\\) vectors are 'training' points. In the code, b is absorbed into v
  by appending 1 as a final dimension to x. The coefficients w and v are
  estimated such that the interpolant exactly fits the value of the function at
  the \\(c_i\\) points, the vector w is orthogonal to each \\(c_i\\), and the
  vector w sums to 0. With these constraints, the coefficients can be obtained
  by solving a linear system.

  \\(\phi\\) is an RBF, parametrized by an interpolation
  order. Using order=2 produces the well-known thin-plate spline.

  We also provide the option to perform regularized interpolation. Here, the
  interpolant is selected to trade off between the squared loss on the training
  data and a certain measure of its curvature
  ([details](https://en.wikipedia.org/wiki/Polyharmonic_spline)).
  Using a regularization weight greater than zero has the effect that the
  interpolant will no longer exactly fit the training data. However, it may be
  less vulnerable to overfitting, particularly for high-order interpolation.

  Note the interpolation procedure is differentiable with respect to all inputs
  besides the order parameter.

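  For example, fitting a thin-plate spline (order 2) to three 2-d points and
  evaluating it at a single query location might look like the following
  (a minimal sketch; the values are illustrative):

  ```python
  train_points = tf.constant([[[0., 0.], [1., 0.], [0., 1.]]])  # [1, 3, 2]
  train_values = tf.constant([[[0.], [1.], [2.]]])              # [1, 3, 1]
  query_points = tf.constant([[[0.5, 0.5]]])                    # [1, 1, 2]
  query_values = interpolate_spline(
      train_points, train_values, query_points, order=2)
  # query_values has shape [1, 1, 1].
  ```
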
  Args:
    train_points: `[batch_size, n, d]` float `Tensor` of n d-dimensional
      locations. These do not need to be regularly-spaced.
    train_values: `[batch_size, n, k]` float `Tensor` of n k-dimensional values
      evaluated at train_points.
    query_points: `[batch_size, m, d]` `Tensor` of m d-dimensional locations
      where we will output the interpolant's values.
    order: order of the interpolation. Common values are 1 for
      \\(\phi(r) = r\\), 2 for \\(\phi(r) = r^2 * log(r)\\) (thin-plate spline),
       or 3 for \\(\phi(r) = r^3\\).
    regularization_weight: weight placed on the regularization term.
      This will depend substantially on the problem, and it should always be
      tuned. For many problems, it is reasonable to use no regularization.
      If using a non-zero value, we recommend a small value like 0.001.
    name: name prefix for ops created by this function

  Returns:
    `[b, m, k]` float `Tensor` of query values. We use train_points and
    train_values to perform polyharmonic interpolation. The query values are
    the values of the interpolant evaluated at the locations specified in
    query_points.
  """
  with ops.name_scope(name):
    train_points = ops.convert_to_tensor(train_points)
    train_values = ops.convert_to_tensor(train_values)
    query_points = ops.convert_to_tensor(query_points)

    # First, fit the spline to the observed data.
    with ops.name_scope('solve'):
      w, v = _solve_interpolation(train_points, train_values, order,
                                  regularization_weight)

    # Then, evaluate the spline at the query locations.
    with ops.name_scope('predict'):
      query_values = _apply_interpolation(query_points, train_points, w, v,
                                          order)

  return query_values