# Copyright 2017 Google Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#      http://www.apache.org/licenses/LICENSE-2.0
#      Unless required by applicable law or agreed to in writing, software
#      distributed under the License is distributed on an "AS IS" BASIS,
#      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#      See the License for the specific language governing permissions and
#      limitations under the License.
"""TensorFlow extensions."""
from __future__ import absolute_import

from numbers import Number

import numpy as np
from tangent import grads
from tangent import non_differentiable
from tangent import tangents
from tangent import utils
from tangent.grads import adjoint
from tangent.tangents import tangent_
from tangent.utils import array_shapes_match
from tangent.utils import register_all_add_grad
from tangent.utils import register_all_shape_checker
from tangent.utils import register_init_grad
from tangent.utils import register_shape_function
from tangent.utils import register_unbroadcast
from tangent.utils import register_unreduce
import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

def size(x, axis):
  """Count the elements of `x` spanned by the given axes.

  Args:
    x: An object whose `shape` attribute can be indexed and iterated. Each
      entry is either a plain integer or exposes the integer as `.value`.
    axis: The axes to count over. `None` selects every axis; a single integer
      or an iterable of integers selects specific axes.

  Returns:
    The product of the selected dimension sizes, clamped to at least 1.
  """
  if axis is None:
    dims = tuple(x.shape)
  elif isinstance(axis, (int, np.integer)):
    # A bare integer axis is accepted elsewhere in this file (see
    # unreduce_tensor), so accept it here as well.
    dims = (x.shape[axis],)
  else:
    dims = tuple(x.shape[a] for a in axis)
  total = 1
  for dim in dims:
    # Unwrap dimension objects; plain integers pass through unchanged.
    total *= getattr(dim, 'value', dim)
  # An empty selection (e.g. a scalar) yields 1 rather than failing.
  return max(total, 1)

def dtype(t):
  """Return the `dtype` attribute of `t`."""
  element_type = t.dtype
  return element_type

def shape_as_list(t):
  """Return `t.shape` converted with its `as_list()` method."""
  shape = t.shape
  return shape.as_list()

def tensor_shapes_match(a, b):
  """Compare the shapes of the `tf.shape` results for `a` and `b`.

  NOTE(review): the comparison is on `tf.shape(...).shape`, not on the values
  of `tf.shape(...)` — this looks like it only checks that the ranks agree;
  confirm that is the intended check.
  """
  shape_of_a = tf.shape(a)
  shape_of_b = tf.shape(b)
  return shape_of_a.shape == shape_of_b.shape

# Shape lookup for the supported tensor types.
register_shape_function(ops.EagerTensor, shape_as_list)
register_shape_function(resource_variable_ops.ResourceVariable, shape_as_list)

# Functions that are not differentiated.
non_differentiable.register_non_differentiable_functions(
    tf.shape, tf.to_float, tf.equal, tf.constant,
    tf.zeros, tf.ones, tf.zeros_like, tf.ones_like,
    size, shape_as_list, dtype)

# Gradients of the supported tensor types are initialised with tf.zeros_like.
register_init_grad(ops.EagerTensor, tf.zeros_like)
register_init_grad(resource_variable_ops.ResourceVariable, tf.zeros_like)

# Gradients of the supported tensor types are accumulated with tf.add.
register_all_add_grad(
    tf.add, (ops.EagerTensor, resource_variable_ops.ResourceVariable))

# Shape checks between the supported tensor types use tensor_shapes_match.
register_all_shape_checker(
    tensor_shapes_match,
    (ops.EagerTensor, resource_variable_ops.ResourceVariable))

# Utilities

def unbroadcast_tfe_to(tensor, shape):
  """Reverse the broadcasting operation.

  See utils.py.

  Args:
    tensor: A Tensor.
    shape: A shape that could have been broadcasted to the shape of tensor.

  Returns:
    Tensor with dimensions summed to match `shape`.
  """
  # Axes to sum over, as computed by utils.create_unbroadcast_axis from the
  # target shape and the current shape of `tensor`.
  axis = utils.create_unbroadcast_axis(shape, shape_as_list(tensor))
  # Sum those axes, then reshape the result to exactly `shape`.
  return tf.reshape(tf.reduce_sum(tensor, axis=axis), shape)

def unbroadcast_tensor(tensor, like):
  """Reverse the broadcasting operation.

  See utils.py.

  Args:
    tensor: A Tensor.
    like: A Tensor that could have been broadcasted to the shape of tensor.

  Returns:
    Tensor with certain dimensions summed to match the shape of `like`.
  """
  # Only the shape of `like` is used; its values are never read.
  return unbroadcast_tfe_to(tensor, shape_as_list(like))

# Register unbroadcast_tensor as the unbroadcast handler for both tensor types.
register_unbroadcast(ops.EagerTensor, unbroadcast_tensor)
register_unbroadcast(resource_variable_ops.ResourceVariable, unbroadcast_tensor)

def unreduce_tensor(tensor, shape, axis, keepdims):
  """Reverse summing over a dimension.

  See utils.py.

  Args:
    tensor: The tensor that was reduced.
    shape: A list, the original shape of the tensor before reduction.
    axis: The axis or axes that were summed.
    keepdims: Whether these axes were kept as singleton axes.

  Returns:
    A tensor with axes broadcast to match the shape of the original tensor.
  """
  if not keepdims:
    rank = len(shape)
    if axis is None:
      axis = range(rank)
    elif isinstance(axis, int):
      axis = axis,
    # Resolve negative axes against the original rank so that re-inserting the
    # singleton axes in ascending order puts each one at its original position.
    for ax in sorted(a % rank for a in axis):
      tensor = tf.expand_dims(tensor, ax)
  # Floor division keeps the tile multiples integral; true division would
  # produce floating-point multiples on Python 3.
  tile_shape = np.array(shape) // np.array(shape_as_list(tensor))
  return tf.tile(tensor, tile_shape)

# Register unreduce_tensor as the unreduce handler for both tensor types.
register_unreduce(ops.EagerTensor, unreduce_tensor)
register_unreduce(resource_variable_ops.ResourceVariable, unreduce_tensor)

# TODO: Once the optimizer can handle multiple return values, consolidate.
def matmul_adjoint_x(dz, x, y, transpose_a, transpose_b):
  """Gradient of `tf.matmul(x, y, transpose_a, transpose_b)` with respect to x.

  Args:
    dz: Gradient with respect to the matmul result.
    x: Left operand of the matmul (unused; kept for a uniform signature).
    y: Right operand of the matmul.
    transpose_a: Whether `x` was transposed in the forward matmul.
    transpose_b: Whether `y` was transposed in the forward matmul.

  Returns:
    The result of the `tf.matmul` call matching the transpose flags.
  """
  if transpose_a:
    # The forward pass used x^T, so the gradient is built transposed.
    if transpose_b:
      return tf.matmul(y, dz, transpose_a=True, transpose_b=True)
    return tf.matmul(y, dz, transpose_b=True)
  if transpose_b:
    return tf.matmul(dz, y)
  return tf.matmul(dz, y, transpose_b=True)

def matmul_adjoint_y(dz, x, y, transpose_a, transpose_b):
  """Gradient of `tf.matmul(x, y, transpose_a, transpose_b)` with respect to y.

  Args:
    dz: Gradient with respect to the matmul result.
    x: Left operand of the matmul.
    y: Right operand of the matmul (unused; kept for a uniform signature).
    transpose_a: Whether `x` was transposed in the forward matmul.
    transpose_b: Whether `y` was transposed in the forward matmul.

  Returns:
    The result of the `tf.matmul` call matching the transpose flags.
  """
  if transpose_b:
    # The forward pass used y^T, so the gradient is built transposed.
    if transpose_a:
      return tf.matmul(dz, x, transpose_a=True, transpose_b=True)
    return tf.matmul(dz, x, transpose_a=True)
  if transpose_a:
    return tf.matmul(x, dz)
  return tf.matmul(x, dz, transpose_a=True)

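# Adjoints
#
# Reverse-mode templates. NOTE(review): `adjoint` is imported above but none of
# the functions below is decorated with it in this view — confirm that every
# template is registered for its TensorFlow op.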

def dtfexp(y, x):
  # Presumably the adjoint template for y = tf.exp(x). The derivative of exp is
  # exp itself, so the primal output y serves as the local derivative.
  d[x] = y * d[y]

def dtflog(y, x):
  # Presumably the adjoint template for y = tf.log(x): d/dx log(x) = 1 / x.
  d[x] = d[y] / x

def dtftanh(y, x):
  # Presumably the adjoint template for y = tf.tanh(x):
  # d/dx tanh(x) = 1 - tanh(x)^2, expressed with the primal output y.
  d[x] = d[y] * (1 - (y * y))

def dtfcosh(y, x):
  # Presumably the adjoint template for y = tf.cosh(x): d/dx cosh(x) = sinh(x).
  d[x] = d[y] * tf.sinh(x)

def dtfsinh(y, x):
  # Presumably the adjoint template for y = tf.sinh(x): d/dx sinh(x) = cosh(x).
  d[x] = d[y] * tf.cosh(x)

def drsqrt(y, x):
  # Presumably the adjoint template for y = tf.rsqrt(x) = x ** -0.5, for which
  # dy/dx = -0.5 * x ** -1.5 = -0.5 * y ** 3. The output is conjugated before
  # it is cubed.
  d[x] = -0.5 * d[y] * tf.pow(tf.conj(y), tf.constant(3.0))

def dtfnegative(y, x):
  # Presumably the adjoint template for y = tf.negative(x): the gradient is the
  # negated incoming gradient, summed down to the shape of x.
  # TODO: Remove the unbroadcast.
  d[x] = tangent.unbroadcast_tensor(tf.negative(d[y]), x)

def dtfexpand_dims(y, x, axis):
  # Presumably the adjoint template for y = tf.expand_dims(x, axis): squeezing
  # the same axis out of d[y] undoes the insertion.
  d[x] = tf.squeeze(d[y], axis)

def dtfsqueeze(y, x, axis=None):
  # Presumably the adjoint template for y = tf.squeeze(x, axis). Squeezing only
  # drops size-1 axes, so reshaping d[y] to the shape of x undoes it for every
  # form of `axis` (None, an int or a list); tf.expand_dims would need exactly
  # one concrete axis and fails for the default axis=None.
  d[x] = tf.reshape(d[y], tf.shape(x))

def dtfreshape(y, x, shape):
  # Presumably the adjoint template for y = tf.reshape(x, shape): the gradient
  # is d[y] reshaped back to the shape of x. `shape` itself is not needed.
  d[x] = tf.reshape(d[y], tf.shape(x))

def dtfreduce_sum(y, x, axis=None, keep_dims=False):
  # Presumably the adjoint template for y = tf.reduce_sum(x, axis, keep_dims):
  # d[y] is expanded back over the reduced axes to the shape of x.
  # TODO: We may be able to assume unreduce_tensor works throughout.
  d[x] = tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims)

def dtfreduce_mean(y, x, axis=None, keep_dims=False):
  # Presumably the adjoint template for y = tf.reduce_mean(x, axis, keep_dims).
  # n is the element count over the reduced axes (tangent.size); the gradient
  # is d[y] expanded back to the shape of x and divided by n.
  n = tf.constant(float(tangent.size(x, axis)))
  d[x] = tf.divide(
      tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims), n)

def dtfreduce_max(y, x, axis=None, keep_dims=False):
  # Presumably the adjoint template for y = tf.reduce_max(x, axis, keep_dims).
  # The mask is 1.0 wherever x equals the maximum (y expanded back to the shape
  # of x) and 0.0 elsewhere, so only maximal entries receive gradient; tied
  # maxima each receive the full incoming gradient.
  mask = tf.to_float(
      tf.equal(
          tangent.unreduce(y, tangent.shape_as_list(x), axis, keep_dims), x))
  d[x] = tf.multiply(
      tangent.unreduce(d[y], tangent.shape_as_list(x), axis, keep_dims), mask)

def dtfadd(z, x, y):
  # Presumably the adjoint template for z = tf.add(x, y): both operands get
  # d[z], passed through tangent.unbroadcast with the operand as the reference.
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = tangent.unbroadcast(d[z], y)

def dtfsubtract(z, x, y):
  # Presumably the adjoint template for z = tf.subtract(x, y): x gets d[z] and
  # y gets its negation, each passed through tangent.unbroadcast.
  d[x] = tangent.unbroadcast(d[z], x)
  d[y] = tangent.unbroadcast(tf.negative(d[z]), y)

def dtfmultiply(z, x, y):
  # Presumably the adjoint template for z = tf.multiply(x, y): product rule,
  # each operand's gradient is d[z] times the other operand.
  d[x] = tangent.unbroadcast(tf.multiply(d[z], y), x)
  d[y] = tangent.unbroadcast(tf.multiply(d[z], x), y)

def dtfdivide(z, x, y):
  # Presumably the adjoint template for z = tf.divide(x, y):
  # dz/dx = 1 / y and dz/dy = -x / y^2.
  d[x] = tangent.unbroadcast(tf.divide(d[z], y), x)
  d[y] = tangent.unbroadcast(
      tf.negative(tf.divide(tf.multiply(d[z], x), tf.multiply(y, y))), y)

def dtfmaximum(z, x, y):
  # Presumably the adjoint template for z = tf.maximum(x, y): d[z] flows to
  # whichever operand equals the result (to both on ties). Like the other
  # broadcasting binary adjoints in this file, each gradient goes through
  # tangent.unbroadcast so it matches its operand when the inputs broadcast.
  d[x] = tangent.unbroadcast(
      tf.multiply(d[z], tf.to_float(tf.equal(z, x))), x)
  d[y] = tangent.unbroadcast(
      tf.multiply(d[z], tf.to_float(tf.equal(z, y))), y)

def dtfsquared_difference(z, x, y):
  # Presumably the adjoint template for z = tf.squared_difference(x, y), i.e.
  # (x - y)^2: dz/dx = 2 * (x - y) and dz/dy = 2 * (y - x).
  d[x] = tangent.unbroadcast(2 * d[z] * (x - y), x)
  d[y] = tangent.unbroadcast(2 * d[z] * (y - x), y)

def dtfmatmul(z, x, y, transpose_a=False, transpose_b=False):
  # Presumably the adjoint template for z = tf.matmul(x, y, ...). The
  # per-operand work is delegated to matmul_adjoint_x / matmul_adjoint_y, which
  # select the right tf.matmul call for the transpose flags.
  d[x] = tangent.matmul_adjoint_x(d[z], x, y, transpose_a, transpose_b)
  d[y] = tangent.matmul_adjoint_y(d[z], x, y, transpose_a, transpose_b)

def dtfconv2d(z, x, y, strides, padding):
  # Presumably the adjoint template for z = tf.nn.conv2d(x, y, strides,
  # padding): the gradient for x comes from conv2d_backprop_input and the
  # gradient for y from conv2d_backprop_filter, each given the target's shape.
  d[x] = tf.nn.conv2d_backprop_input(tf.shape(x), y, d[z], strides, padding)
  d[y] = tf.nn.conv2d_backprop_filter(x, tf.shape(y), d[z], strides, padding)

def dtfconv2d_backprop_input(z, shape, x, y, strides, padding):
  # Presumably the adjoint template for
  # z = tf.nn.conv2d_backprop_input(shape, x, y, strides, padding), where x is
  # the filter and y the output gradient of the original convolution.
  # The filter gradient must have the shape of the filter x, so that shape is
  # passed as filter_sizes; `shape` holds the input sizes and is wrong there.
  # TODO: Add tests.
  d[x] = tf.nn.conv2d_backprop_filter(d[z], tf.shape(x), y, strides, padding)
  d[y] = tf.nn.conv2d(d[z], x, strides, padding)

def dtfconv2d_backprop_filter(z, x, shape, y, strides, padding):
  # Presumably the adjoint template for
  # z = tf.nn.conv2d_backprop_filter(x, shape, y, strides, padding), where x is
  # the input and y the output gradient of the original convolution.
  # The input gradient must have the shape of the input x, so that shape is
  # passed as input_sizes; `shape` holds the filter sizes and is wrong there.
  # TODO: Add tests.
  d[x] = tf.nn.conv2d_backprop_input(tf.shape(x), d[z], y, strides, padding)
  d[y] = tf.nn.conv2d(x, d[z], strides, padding)

def dtfavg_pool(y, x, sizes, strides, padding):
  # Presumably the adjoint template for y = tf.nn.avg_pool(x, sizes, strides,
  # padding). Calls TensorFlow's generated gradient op with the shape of x and
  # the incoming gradient.
  # TODO: We shouldn't rely on private modules.
  d[x] = tf.nn._nn_grad.gen_nn_ops._avg_pool_grad(
      tf.shape(x), d[y], sizes, strides, padding)

def dtfmax_pool(y, x, sizes, strides, padding):
  # Presumably the adjoint template for y = tf.nn.max_pool(x, sizes, strides,
  # padding). Calls TensorFlow's generated gradient op with the primal input,
  # the primal output and the incoming gradient.
  # TODO: We shouldn't rely on private modules.
  d[x] = tf.nn._nn_grad.gen_nn_ops._max_pool_grad(
      x, y, d[y], sizes, strides, padding)

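# Tangents
#
# Forward-mode templates. NOTE(review): `tangent_` is imported above but none
# of the functions below is decorated with it in this view — confirm that every
# template is registered for its TensorFlow op.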

def tshape_as_list(y, x):
  # Presumably the tangent template for y = shape_as_list(x): the tangent of
  # the output is the shape list of the input's tangent.
  d[y] = tangent.shape_as_list(d[x])

def ttfexp(y, x):
  # Presumably the tangent template for y = tf.exp(x): the derivative of exp is
  # exp itself, so the primal output y scales the input tangent.
  d[y] = d[x] * y

def ttflog(y, x):
  # Presumably the tangent template for y = tf.log(x): d/dx log(x) = 1 / x.
  d[y] = d[x] / x

def ttftanh(y, x):
  # Presumably the tangent template for y = tf.tanh(x):
  # d/dx tanh(x) = 1 / cosh(x)^2.
  cx = tf.cosh(x)
  d[y] = d[x] / (cx * cx)

def ttfcosh(y, x):
  # Presumably the tangent template for y = tf.cosh(x): d/dx cosh(x) = sinh(x).
  d[y] = d[x] * tf.sinh(x)

def ttfsinh(y, x):
  # Presumably the tangent template for y = tf.sinh(x): d/dx sinh(x) = cosh(x).
  d[y] = d[x] * tf.cosh(x)

def ttfexpand_dims(y, x, axis):
  # Presumably the tangent template for y = tf.expand_dims(x, axis): the same
  # operation is applied to the input tangent.
  d[y] = tf.expand_dims(d[x], axis)

def ttfsqueeze(y, x, axis=None):
  # Presumably the tangent template for y = tf.squeeze(x, axis): the same
  # operation is applied to the input tangent. `axis` defaults to None so the
  # signature agrees with the adjoint template dtfsqueeze.
  d[y] = tf.squeeze(d[x], axis)

def ttfreshape(y, x, shape):
  # Presumably the tangent template for y = tf.reshape(x, shape): the input
  # tangent is reshaped to the same target shape.
  d[y] = tf.reshape(d[x], shape)

def ttfreduce_sum(y, x, axis=None, keep_dims=False):
  # Presumably the tangent template for y = tf.reduce_sum(x, axis, keep_dims):
  # summation is linear, so the same reduction is applied to the input tangent.
  d[y] = tf.reduce_sum(d[x], axis, keep_dims)

def ttfreduce_mean(y, x, axis=None, keep_dims=False):
  # Presumably the tangent template for y = tf.reduce_mean(x, axis, keep_dims):
  # the mean is linear, so the same reduction is applied to the input tangent.
  d[y] = tf.reduce_mean(d[x], axis, keep_dims)

def ttfreduce_max(y, x, axis=None, keep_dims=False):
  # Presumably the tangent template for y = tf.reduce_max(x, axis, keep_dims).
  # The mask is 1.0 wherever x equals the maximum (y expanded back to the shape
  # of x). The masked input tangent is then reduced over the same axes so that
  # d[y] has the shape of y; tied maxima contribute the sum of their tangents,
  # mirroring the adjoint template dtfreduce_max.
  mask = tf.to_float(
      tf.equal(
          tangent.unreduce(y, tangent.shape_as_list(x), axis, keep_dims), x))
  d[y] = tf.reduce_sum(tf.multiply(d[x], mask), axis, keep_dims)

def ttfnegative(y, x):
  # Presumably the tangent template for y = tf.negative(x): negation is linear,
  # so the input tangent is negated.
  d[y] = tf.negative(d[x])

def ttfadd(z, x, y):
  # Presumably the tangent template for z = tf.add(x, y): the tangents add.
  d[z] = tf.add(d[x], d[y])

def ttfsubtract(z, x, y):
  # Presumably the tangent template for z = tf.subtract(x, y): the tangents
  # subtract.
  d[z] = tf.subtract(d[x], d[y])

def ttfmultiply(z, x, y):
  # Presumably the tangent template for z = tf.multiply(x, y): product rule,
  # dz = dx * y + x * dy.
  d[z] = tf.add(tf.multiply(d[x], y), tf.multiply(x, d[y]))

def ttfdivide(z, x, y):
  # Presumably the tangent template for z = tf.divide(x, y): quotient rule,
  # dz = (dx * y - x * dy) / y^2.
  d[z] = tf.divide(
          tf.subtract(tf.multiply(d[x], y), tf.multiply(x, d[y])),
          tf.multiply(y, y))

def ttfmaximum(z, x, y):
  # Presumably the tangent template for z = tf.maximum(x, y): each input
  # tangent is kept only where that input equals the result; on ties both
  # tangents are added.
  d[z] = d[x] * tf.to_float(tf.equal(z, x)) + d[y] * tf.to_float(tf.equal(z, y))

def ttfavg_pool(y, x, sizes, strides, padding):
  # No forward-mode rule is provided here: raises
  # tangent.ForwardNotImplementedError naming tf.nn.avg_pool.
  raise tangent.ForwardNotImplementedError(tf.nn.avg_pool)

def ttfmax_pool(y, x, sizes, strides, padding):
  # No forward-mode rule is provided here: raises
  # tangent.ForwardNotImplementedError naming tf.nn.max_pool.
  raise tangent.ForwardNotImplementedError(tf.nn.max_pool)

def tshape(y, x):
  # Presumably the tangent template for y = tf.shape(x): the tangent of the
  # output is the shape of the input's tangent.
  d[y] = tf.shape(d[x])

# Blacklist unimplemented Eager grads

    grads.get_module_functions((tf, tf.distributions, tf.image, tf.layers,
                                tf.linalg, tf.losses,
                                tf.nn)) - set(grads.adjoints))

    grads.get_module_functions((tf, tf.distributions, tf.image, tf.layers,
                                tf.linalg, tf.losses,
                                tf.nn)) - set(tangents.tangents))