# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#      Unless required by applicable law or agreed to in writing, software
#      distributed under the License is distributed on an "AS IS" BASIS,
#      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#      See the License for the specific language governing permissions and
#      limitations under the License.
"""TensorFlow extensions."""
from __future__ import absolute_import

from numbers import Number

import numpy as np
from tangent import grads
from tangent import non_differentiable
from tangent import tangents
from tangent import utils
from tangent.grads import adjoint
from tangent.tangents import tangent_
from tangent.utils import array_shapes_match
from tangent.utils import register_all_add_grad
from tangent.utils import register_all_shape_checker
from tangent.utils import register_init_grad
from tangent.utils import register_shape_function
from tangent.utils import register_unbroadcast
from tangent.utils import register_unreduce
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops


def size(x, axis):
  """Calculate the number of elements of `x` along `axis` dimensions only."""
  if isinstance(axis, int):
    axis = (axis,)
  axis_shape = x.shape if axis is None else tuple(x.shape[a] for a in axis)
  return max(np.prod(axis_shape).value, 1)
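  # Illustrative example: for an eager tensor `t` of shape [2, 3, 4],
  # size(t, axis=None) == 24 and size(t, axis=(1, 2)) == 12. The result feeds
  # the averaging term in the reduce_mean adjoint below.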


def dtype(t):
  """Return the dtype of a tensor."""
  return t.dtype


def shape_as_list(t):
  """Return the static shape of a tensor as a Python list."""
  return t.shape.as_list()


def tensor_shapes_match(a, b):
  """Check whether two tensors have identical static shapes."""
  return shape_as_list(a) == shape_as_list(b)


register_shape_function(ops.EagerTensor, shape_as_list)
register_shape_function(resource_variable_ops.ResourceVariable, shape_as_list)


non_differentiable.register_non_differentiable_functions(
    tf.shape, tf.to_float, tf.equal, tf.constant,
    tf.zeros, tf.ones, tf.zeros_like, tf.ones_like,
    size, shape_as_list, dtype)
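# These functions either return quantities with no useful derivative (shape
# and dtype queries, comparisons, the float casts applied to boolean masks) or
# ignore their inputs entirely (constant constructors), so no derivative code
# is generated for calls to them.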


register_init_grad(ops.EagerTensor, tf.zeros_like)
register_init_grad(resource_variable_ops.ResourceVariable, tf.zeros_like)


register_all_add_grad(
    tf.add, (ops.EagerTensor, resource_variable_ops.ResourceVariable))

register_all_shape_checker(
    tensor_shapes_match,
    (ops.EagerTensor, resource_variable_ops.ResourceVariable))
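# tf.add is registered as the function used to accumulate (sum) gradients of
# eager tensors and resource variables, and tensor_shapes_match is used to
# check that gradients being combined have matching static shapes.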

#
# Utilities
#


def unbroadcast_tfe_to(tensor, shape):
  """Reverse the broadcasting operation.

  See utils.py.

  Args:
    tensor: A Tensor.
    shape: A shape that could have been broadcasted to the shape of tensor.

  Returns:
    Tensor with dimensions summed to match `shape`.
  """
  axis = utils.create_unbroadcast_axis(shape, shape_as_list(tensor))
  return tf.reshape(tf.reduce_sum(tensor, axis=axis), shape)
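  # Illustrative example (hypothetical values): if a gradient `g` has shape
  # [2, 3] but the corresponding primal value had shape [3] (i.e. it was
  # broadcast along the leading axis), unbroadcast_tfe_to(g, [3]) sums over
  # axis 0 and reshapes, producing a [3] tensor that matches the original.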


def unbroadcast_tensor(tensor, like):
  """Reverse the broadcasting operation.

  See utils.py.

  Args:
    tensor: A Tensor.
    like: A Tensor that could have been broadcasted to the shape of tensor.

  Returns:
    Tensor with certain dimensions summed to match the shape of `like`.
  """
  return unbroadcast_tfe_to(tensor, shape_as_list(like))


register_unbroadcast(ops.EagerTensor, unbroadcast_tensor)
register_unbroadcast(resource_variable_ops.ResourceVariable, unbroadcast_tensor)


def unreduce_tensor(tensor, shape, axis, keepdims):
  """Reverse summing over a dimension.

  See utils.py.

  Args:
    tensor: The tensor that was reduced.
    shape: A list, the original shape of the tensor before reduction.
    axis: The axis or axes that were summed.
    keepdims: Whether these axes were kept as singleton axes.

  Returns:
    A tensor with axes broadcast to match the shape of the original tensor.
  """
  if not keepdims:
    if axis is None:
      axis = range(len(shape))
    elif isinstance(axis, int):
      axis = (axis,)
    for ax in sorted(axis):
      tensor = tf.expand_dims(tensor, ax)
  # Integer division keeps the tile multiples integral, as tf.tile requires.
  tile_shape = np.array(shape) // np.array(shape_as_list(tensor))
  return tf.tile(tensor, tile_shape)
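  # Illustrative example: if `s = tf.reduce_sum(t, axis=1)` reduced a [2, 3]
  # tensor to shape [2], then unreduce_tensor(s, [2, 3], axis=1, keepdims=False)
  # re-inserts the summed axis and tiles it, yielding a [2, 3] tensor whose
  # rows repeat the corresponding entries of `s`.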


register_unreduce(ops.EagerTensor, unreduce_tensor)
register_unreduce(resource_variable_ops.ResourceVariable, unreduce_tensor)


# TODO: Once the optimizer can handle multiple return values, consolidate.
def matmul_adjoint_x(dz, x, y, transpose_a, transpose_b):
  """Implementation of dtfmatmul wrt x, separate for readability."""
  if not transpose_a and not transpose_b:
    return tf.matmul(dz, y, transpose_b=True)
  elif not transpose_a and transpose_b:
    return tf.matmul(dz, y)
  elif transpose_a and not transpose_b:
    return tf.matmul(y, dz, transpose_b=True)
  else:  # transpose_a and transpose_b
    return tf.matmul(y, dz, transpose_a=True, transpose_b=True)


def matmul_adjoint_y(dz, x, y, transpose_a, transpose_b):
  """Implementation of dtfmatmul, separate for readability."""
  if not transpose_a and not transpose_b:
    return tf.matmul(x, dz, transpose_a=True)
  elif not transpose_a and transpose_b:
    return tf.matmul(dz, x, transpose_a=True)
  elif transpose_a and not transpose_b:
    return tf.matmul(x, dz)
  else:  # transpose_a and transpose_b
    return tf.matmul(dz, x, transpose_a=True, transpose_b=True)
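# For z = tf.matmul(x, y) with no transposes, the chain rule gives
#   dL/dx = dL/dz @ y^T   and   dL/dy = x^T @ dL/dz;
# the remaining branches above follow by substituting the transposed operands
# and transposing the results as needed.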


#
# Adjoints
#
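# The functions below are templates consumed by Tangent's code generator, not
# ordinary Python: `d[v]` stands for the gradient associated with `v`, and
# `tangent.*` names refer to runtime helpers (e.g. tangent.unbroadcast) that
# are resolved when the generated gradient code runs. The bodies are spliced
# into that code and are never called directly.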


@adjoint(tf.exp)
def dtfexp(y, x):
  d[x] = y * d[y]


@adjoint(tf.log)
def dtflog(y, x):
  d[x] = d[y] / x


@adjoint(tf.tanh)
def dtftanh(y, x):
  d[x] = d[y] * (1 - (y * y))


@adjoint(tf.cosh)
def dtfcosh(y, x):
  d[x] = d[y] * tf.sinh(x)


@adjoint(tf.sinh)
def dtfsinh(y, x):
  d[x] = d[y] * tf.cosh(x)


@adjoint(tf.rsqrt)
def drsqrt(y, x):
  d[x] = -0.5 * d[y] * tf.pow(tf.conj(y), tf.constant(3.0))
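  # Since y = x**-0.5, dy/dx = -0.5 * x**-1.5 = -0.5 * y**3, which is why the
  # adjoint is written in terms of the (conjugated) primal output y.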


@adjoint(tf.negative)
def dtfnegative(y, x):
  # TODO: Remove the unbroadcast.
  d[x] = tangent.unbroadcast_tensor(tf.negative(d[y]), x)


@adjoint(tf.expand_dims)
def dtfexpand_dims(y, x, axis):
  d[x] = tf.squeeze(d[y], axis)


@adjoint(tf.squeeze)
def dtfsqueeze(y, x, axis=None):
  d[x] = tf.expand_dims(d[y], axis)


@adjoint(tf.reshape)
def dtfreshape(y, x, shape):
  d[x] = tf.reshape(d[y], tf.shape(x))


@adjoint(tf.reduce_sum)
def dtfreduce_sum(y, x, axis=None, keep_dims=False):
  # TODO: We may be able to assume unreduce_tensor works throughout.
  d[x] = tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims)


@adjoint(tf.reduce_mean)
def dtfreduce_mean(y, x, axis=None, keep_dims=False):
  n = tf.constant(float(tangent.size(x, axis)))
  d[x] = tf.divide(
      tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims), n)


@adjoint(tf.reduce_max)
def dtfreduce_max(y, x, axis=None, keep_dims=False):
  mask = tf.to_float(
      tf.equal(
          tangent.unreduce(y, tangent.shape_as_list(x), axis, keep_dims), x))
  d[x] = tf.multiply(
      tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims), mask)
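  # The mask is 1.0 exactly where an entry of x equals the (re-broadcast)
  # maximum and 0.0 elsewhere, so only maximizing entries receive gradient;
  # when there are ties, every maximal entry receives the full upstream
  # gradient.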


@adjoint(tf.add)
def dtfadd(z, x, y):
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = tangent.unbroadcast(d[z], y)


@adjoint(tf.subtract)
def dtfsubtract(z, x, y):
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = tangent.unbroadcast(tf.negative(d[z]), y)


@adjoint(tf.multiply)
def dtfmultiply(z, x, y):
  d[x] = tangent.unbroadcast(tf.multiply(d[z], y), x)
  d[y] = tangent.unbroadcast(tf.multiply(d[z], x), y)


@adjoint(tf.divide)
def dtfdivide(z, x, y):
  d[x] = tangent.unbroadcast(tf.divide(d[z], y), x)
  d[y] = tangent.unbroadcast(
      tf.negative(tf.divide(tf.multiply(d[z], x), tf.multiply(y, y))), y)


@adjoint(tf.maximum)
def dtfmaximum(z, x, y):
  d[x] = tf.multiply(d[z], tf.to_float(tf.equal(z, x)))
  d[y] = tf.multiply(d[z], tf.to_float(tf.equal(z, y)))


@adjoint(tf.squared_difference)
def dtfsquared_difference(z, x, y):
  d[x] = tangent.unbroadcast(2 * d[z] * (x - y), x)
  d[y] = tangent.unbroadcast(2 * d[z] * (y - x), y)


@adjoint(tf.matmul)
def dtfmatmul(z, x, y, transpose_a=False, transpose_b=False):
  d[x] = tangent.matmul_adjoint_x(d[z], x, y, transpose_a, transpose_b)
  d[y] = tangent.matmul_adjoint_y(d[z], x, y, transpose_a, transpose_b)


@adjoint(tf.nn.conv2d)
def dtfconv2d(z, x, y, strides, padding):
  d[x] = tf.nn.conv2d_backprop_input(tf.shape(x), y, d[z], strides, padding)
  d[y] = tf.nn.conv2d_backprop_filter(x, tf.shape(y), d[z], strides, padding)


@adjoint(tf.nn.conv2d_backprop_input)
def dtfconv2d_backprop_input(z, shape, x, y, strides, padding):
  # TODO: Add tests.
  # d[x] is the gradient with respect to the filter, so pass the filter's
  # shape, not the primal input_sizes argument.
  d[x] = tf.nn.conv2d_backprop_filter(d[z], tf.shape(x), y, strides, padding)
  d[y] = tf.nn.conv2d(d[z], x, strides, padding)


@adjoint(tf.nn.conv2d_backprop_filter)
def dtfconv2d_backprop_filter(z, x, shape, y, strides, padding):
  # TODO: Add tests.
  # d[x] is the gradient with respect to the primal input, so pass the input's
  # shape, not the filter_sizes argument.
  d[x] = tf.nn.conv2d_backprop_input(tf.shape(x), d[z], y, strides, padding)
  d[y] = tf.nn.conv2d(x, d[z], strides, padding)


@adjoint(tf.nn.avg_pool)
def dtfavg_pool(y, x, sizes, strides, padding):
  # TODO: We shouldn't rely on private modules.
  d[x] = tf.nn._nn_grad.gen_nn_ops._avg_pool_grad(
      tf.shape(x), d[y], sizes, strides, padding)


@adjoint(tf.nn.max_pool)
def dtfmax_pool(y, x, sizes, strides, padding):
  # TODO: We shouldn't rely on private modules.
  d[x] = tf.nn._nn_grad.gen_nn_ops._max_pool_grad(
      x, y, d[y], sizes, strides, padding)


#
# Tangents
#
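# Forward-mode templates: here `d[v]` denotes the tangent (directional
# derivative) of `v`, propagated from inputs to outputs, mirroring the
# reverse-mode adjoints defined above.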


@tangent_(shape_as_list)
def tshape_as_list(y, x):
  d[y] = tangent.shape_as_list(d[x])


@tangent_(tf.exp)
def ttfexp(y, x):
  d[y] = d[x] * y


@tangent_(tf.log)
def ttflog(y, x):
  d[y] = d[x] / x


@tangent_(tf.tanh)
def ttftanh(y, x):
  cx = tf.cosh(x)
  d[y] = d[x] / (cx * cx)


@tangent_(tf.cosh)
def ttfcosh(y, x):
  d[y] = d[x] * tf.sinh(x)


@tangent_(tf.sinh)
def ttfsinh(y, x):
  d[y] = d[x] * tf.cosh(x)


@tangent_(tf.expand_dims)
def ttfexpand_dims(y, x, axis):
  d[y] = tf.expand_dims(d[x], axis)


@tangent_(tf.squeeze)
def ttfsqueeze(y, x, axis):
  d[y] = tf.squeeze(d[x], axis)


@tangent_(tf.reshape)
def ttfreshape(y, x, shape):
  d[y] = tf.reshape(d[x], shape)


@tangent_(tf.reduce_sum)
def ttfreduce_sum(y, x, axis=None, keep_dims=False):
  d[y] = tf.reduce_sum(d[x], axis, keep_dims)


@tangent_(tf.reduce_mean)
def ttfreduce_mean(y, x, axis=None, keep_dims=False):
  d[y] = tf.reduce_mean(d[x], axis, keep_dims)


@tangent_(tf.reduce_max)
def ttfreduce_max(y, x, axis=None, keep_dims=False):
  mask = tf.to_float(
      tf.equal(
          tangent.unreduce(
              tf.ones_like(y), tangent.shape_as_list(x), axis, keep_dims), x))
  d[y] = tf.multiply(d[x], mask)


@tangent_(tf.negative)
def ttfnegative(y, x):
  d[y] = tf.negative(d[x])


@tangent_(tf.add)
def ttfadd(z, x, y):
  d[z] = tf.add(d[x], d[y])


@tangent_(tf.subtract)
def ttfsubtract(z, x, y):
  d[z] = tf.subtract(d[x], d[y])


@tangent_(tf.multiply)
def ttfmultiply(z, x, y):
  d[z] = tf.add(tf.multiply(d[x], y), tf.multiply(x, d[y]))


@tangent_(tf.divide)
def ttfdivide(z, x, y):
  d[z] = tf.divide(
      tf.subtract(tf.multiply(d[x], y), tf.multiply(x, d[y])),
      tf.multiply(y, y))


@tangent_(tf.maximum)
def ttfmaximum(z, x, y):
  d[z] = d[x] * tf.to_float(tf.equal(z, x)) + d[y] * tf.to_float(tf.equal(z, y))


@tangent_(tf.nn.avg_pool)
def ttfavg_pool(y, x, sizes, strides, padding):
  raise tangent.ForwardNotImplementedError(tf.nn.avg_pool)


@tangent_(tf.nn.max_pool)
def ttfmax_pool(y, x, sizes, strides, padding):
  raise tangent.ForwardNotImplementedError(tf.nn.max_pool)


@tangent_(tf.shape)
def tshape(y, x):
  d[y] = tf.shape(d[x])


#
# Blacklist unimplemented Eager grads
#
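# Every function from these TF modules without a registered adjoint or tangent
# above is recorded as unimplemented, so that attempting to differentiate code
# that calls it produces an explicit error rather than a silently wrong
# derivative.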

grads.UNIMPLEMENTED_ADJOINTS.update(
    grads.get_module_functions((tf, tf.distributions, tf.image, tf.layers,
                                tf.linalg, tf.losses,
                                tf.nn)) - set(grads.adjoints))

tangents.UNIMPLEMENTED_TANGENTS.update(
    grads.get_module_functions((tf, tf.distributions, tf.image, tf.layers,
                                tf.linalg, tf.losses,
                                tf.nn)) - set(tangents.tangents))