# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Gradients for operators defined in math_ops.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops


def _safe_shape_div(x, y):
  """Divides `x / y` assuming `x, y >= 0`, treating `0 / 0 = 0`."""
  return x // math_ops.maximum(y, 1)


@ops.RegisterGradient("Sum")
def _SumGrad(op, grad):
  """Gradient for Sum."""
  # Fast path for when reducing to a scalar and ndims is known: adds only
  # Reshape and Tile ops (and possibly a Shape).
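  # Illustrative example: for an input of shape [2, 3] reduced over all axes,
  # the scalar grad is reshaped to [1, 1] and tiled by [2, 3], so every input
  # element receives the same gradient value.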
  if (op.inputs[0].get_shape().ndims is not None and op.inputs[1].op.type ==
      "Const"):
    rank = op.inputs[0].get_shape().ndims
    axes = tensor_util.MakeNdarray(op.inputs[1].op.get_attr("value"))
    if np.array_equal(axes, np.arange(rank)):  # Reduce all dims.
      grad = array_ops.reshape(grad, [1] * rank)
      # If shape is not fully defined (but rank is), we use Shape.
      if op.inputs[0].get_shape().is_fully_defined():
        input_shape = op.inputs[0].get_shape().as_list()
      else:
        input_shape = array_ops.shape(op.inputs[0])
      return [array_ops.tile(grad, input_shape), None]

  input_shape = array_ops.shape(op.inputs[0])
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  return [array_ops.tile(grad, tile_scaling), None]


def _MinOrMaxGrad(op, grad):
  """Gradient for Min or Max. Amazingly it's precisely the same code."""
  input_shape = array_ops.shape(op.inputs[0])
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  y = op.outputs[0]
  y = array_ops.reshape(y, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)

  # Compute the number of selected (maximum or minimum) elements in each
  # reduction dimension. If there are multiple minimum or maximum elements
  # then the gradient will be divided between them.
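  # Illustrative example: reducing [3., 1., 3.] with Max yields y = 3; both
  # occurrences of 3 are selected, so each receives grad / 2.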
  indicators = math_ops.cast(math_ops.equal(y, op.inputs[0]), grad.dtype)
  num_selected = array_ops.reshape(
      math_ops.reduce_sum(indicators, op.inputs[1]),
      output_shape_kept_dims)

  return [math_ops.div(indicators, num_selected) * grad, None]


@ops.RegisterGradient("Max")
def _MaxGrad(op, grad):
  """Gradient for Max."""
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Min")
def _MinGrad(op, grad):
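  """Gradient for Min."""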
  return _MinOrMaxGrad(op, grad)


@ops.RegisterGradient("Mean")
def _MeanGrad(op, grad):
  """Gradient for Mean."""
  sum_grad = _SumGrad(op, grad)[0]
  input_shape = array_ops.shape(op.inputs[0])
  output_shape = array_ops.shape(op.outputs[0])
  factor = _safe_shape_div(math_ops.reduce_prod(input_shape),
                           math_ops.reduce_prod(output_shape))
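  # The factor is the number of input elements averaged into each output
  # element, e.g. 6 / 2 = 3 for an input of shape [2, 3] reduced over axis 1.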
  return sum_grad / math_ops.cast(factor, sum_grad.dtype), None


@ops.RegisterGradient("Prod")
def _ProdGrad(op, grad):
  """Gradient for Prod."""
  # The gradient can be expressed by dividing the product by each entry of the
  # input tensor, but this approach can't deal with zeros in the input.
  # Here, we avoid this problem by composing the output as a product of two
  # cumprod operations.
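  # Illustrative example: for x = [a, b, c] reduced over its single axis,
  #   left  = [1, a, a*b]    (exclusive cumprod)
  #   right = [b*c, c, 1]    (reverse exclusive cumprod)
  # and left * right = [b*c, a*c, a*b], which is d(prod)/dx even when x
  # contains zeros.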

  input_shape = array_ops.shape(op.inputs[0])
  # Reshape reduction indices for the case where the parameter is a scalar
  reduction_indices = array_ops.reshape(op.inputs[1], [-1])

  # Expand grad to full input shape
  output_shape_kept_dims = math_ops.reduced_shape(input_shape, op.inputs[1])
  tile_scaling = _safe_shape_div(input_shape, output_shape_kept_dims)
  grad = array_ops.reshape(grad, output_shape_kept_dims)
  grad = array_ops.tile(grad, tile_scaling)

  # Pack all reduced dimensions into a single one, so we can perform the
  # cumprod ops. If the reduction indices list is empty, its constant defaults
  # to float32, so we need to cast it to int32 here. We put all the
  # shape-related ops on the CPU to avoid copying back and forth, and because
  # setdiff1d (ListDiff) is CPU-only.
  with ops.device("/cpu:0"):
    reduced = math_ops.cast(reduction_indices, dtypes.int32)
    idx = math_ops.range(0, array_ops.rank(op.inputs[0]))
    other, _ = array_ops.setdiff1d(idx, reduced)
    perm = array_ops.concat(0, [reduced, other])
    reduced_num = math_ops.reduce_prod(array_ops.gather(input_shape, reduced))
    other_num = math_ops.reduce_prod(array_ops.gather(input_shape, other))
  permuted = array_ops.transpose(op.inputs[0], perm)
  permuted_shape = array_ops.shape(permuted)
  reshaped = array_ops.reshape(permuted, (reduced_num, other_num))

  # Calculate product, leaving out the current entry
  left = math_ops.cumprod(reshaped, axis=0, exclusive=True)
  right = math_ops.cumprod(reshaped, axis=0, exclusive=True, reverse=True)
  y = array_ops.reshape(left * right, permuted_shape)

  # Invert the transpose and reshape operations.
  # Make sure to set the statically known shape information through a reshape.
  out = grad * array_ops.transpose(y, array_ops.invert_permutation(perm))
  return array_ops.reshape(out, input_shape), None


@ops.RegisterGradient("SegmentSum")
def _SegmentSumGrad(op, grad):
  """Gradient for SegmentSum."""
  return array_ops.gather(grad, op.inputs[1]), None


@ops.RegisterGradient("SegmentMean")
def _SegmentMeanGrad(op, grad):
  """Gradient for SegmentMean."""
  input_rank = array_ops.rank(op.inputs[0])
  ones_shape = array_ops.concat(
      0, [array_ops.shape(op.inputs[1]),
          array_ops.fill(array_ops.expand_dims(input_rank - 1, 0), 1)])
  ones = array_ops.fill(ones_shape,
                        constant_op.constant(1, dtype=grad.dtype))
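  # segment_sum over this all-ones tensor of shape [N, 1, ..., 1] counts the
  # rows in each segment, so each output gradient row is divided by its
  # segment's size before being gathered back to the inputs.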
  scaled_grad = math_ops.div(grad, math_ops.segment_sum(ones, op.inputs[1]))
  return array_ops.gather(scaled_grad, op.inputs[1]), None


@ops.RegisterGradient("SparseSegmentSum")
def _SparseSegmentSumGrad(op, grad):
  """Gradient for SparseSegmentSum."""
  input_rows = array_ops.shape(op.inputs[0])[0]
  return (math_ops.unsorted_segment_sum(
      array_ops.gather(grad, op.inputs[2]),
      op.inputs[1], input_rows), None, None)


@ops.RegisterGradient("SparseSegmentMean")
def _SparseSegmentMeanGrad(op, grad):
  """Gradient for SparseSegmentMean."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_mean_grad(grad,
                                            op.inputs[1],
                                            op.inputs[2],
                                            dim0),
          None, None)


@ops.RegisterGradient("SparseSegmentSqrtN")
def _SparseSegmentSqrtNGrad(op, grad):
  """Gradient for SparseSegmentSqrtN."""
  dim0 = array_ops.shape(op.inputs[0])[0]
  return (math_ops.sparse_segment_sqrt_n_grad(grad,
                                              op.inputs[1],
                                              op.inputs[2],
                                              dim0),
          None, None)


def _SegmentMinOrMaxGrad(op, grad):
  """Gradient for SegmentMin and SegmentMax. Both share the same code."""
  zeros = array_ops.zeros(array_ops.shape(op.inputs[0]),
                          dtype=op.inputs[0].dtype)

  # Get the number of selected (minimum or maximum) elements in each segment.
  gathered_outputs = array_ops.gather(op.outputs[0], op.inputs[1])
  is_selected = math_ops.equal(op.inputs[0], gathered_outputs)
  num_selected = math_ops.segment_sum(math_ops.cast(is_selected, grad.dtype),
                                      op.inputs[1])

  # Compute the gradient for each segment. The gradient for the ith segment is
  # divided evenly among the selected elements in that segment.
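  # Illustrative example: data = [1, 3, 3] in a single segment has max 3, so
  # is_selected = [False, True, True] and each selected entry gets grad / 2.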
  weighted_grads = math_ops.div(grad, num_selected)
  gathered_grads = array_ops.gather(weighted_grads, op.inputs[1])

  return math_ops.select(is_selected, gathered_grads, zeros), None


@ops.RegisterGradient("SegmentMin")
def _SegmentMinGrad(op, grad):
  """Gradient for SegmentMin."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("SegmentMax")
def _SegmentMaxGrad(op, grad):
  """Gradient for SegmentMax."""
  return _SegmentMinOrMaxGrad(op, grad)


@ops.RegisterGradient("UnsortedSegmentSum")
def _UnsortedSegmentSumGrad(op, grad):
  """Gradient for SegmentSum."""
  return array_ops.gather(grad, op.inputs[1]), None, None


@ops.RegisterGradient("Abs")
def _AbsGrad(op, grad):
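  """Returns grad * sign(x)."""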
  x = op.inputs[0]
  return grad * math_ops.sign(x)


@ops.RegisterGradient("Neg")
def _NegGrad(_, grad):
  """Returns -grad."""
  return -grad


@ops.RegisterGradient("Inv")
def _InvGrad(op, grad):
  """Returns -grad * (1 / x^2)."""
  y = op.outputs[0]  # y = 1 / x
  # pylint: disable=protected-access
  return gen_math_ops._inv_grad(y, grad)


@ops.RegisterGradient("InvGrad")
def _InvGradGrad(op, grad):
  b = op.inputs[1]
  # op.outputs[0]: y = -b * conj(a)^2
  with ops.control_dependencies([grad.op]):
    ca = math_ops.conj(op.inputs[0])
    cg = math_ops.conj(grad)
    # pylint: disable=protected-access
    return cg * -2.0 * b * ca, gen_math_ops._inv_grad(ca, grad)


@ops.RegisterGradient("Square")
def _SquareGrad(op, grad):
  x = op.inputs[0]
  # Use a control dependency to prevent 2*x from being computed before the
  # incoming gradient is available.
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * (2.0 * x)


@ops.RegisterGradient("Sqrt")
def _SqrtGrad(op, grad):
  y = op.outputs[0]  # y = x^(1/2)
  return gen_math_ops._sqrt_grad(y, grad)


@ops.RegisterGradient("SqrtGrad")
def _SqrtGradGrad(op, grad):
  a = op.inputs[0]
  y = op.outputs[0]  # y = 0.5 * b / conj(a)
  with ops.control_dependencies([grad.op]):
    ga = grad / a
    return -math_ops.conj(ga) * y, 0.5 * ga


@ops.RegisterGradient("Rsqrt")
def _RsqrtGrad(op, grad):
  """Returns -0.5 * grad * conj(y)^3."""
  y = op.outputs[0]  # y = x^(-1/2)
  return gen_math_ops._rsqrt_grad(y, grad)


@ops.RegisterGradient("RsqrtGrad")
def _RsqrtGradGrad(op, grad):
  """Returns backprop gradient for f(a,b) = -0.5 * b * conj(a)^3."""
  a = op.inputs[0]  # a = x^{-1/2}
  b = op.inputs[1]  # backprop gradient for a
  with ops.control_dependencies([grad.op]):
    ca = math_ops.conj(a)
    cg = math_ops.conj(grad)
    grad_a = -1.5 * cg * b * math_ops.square(ca)
    # pylint: disable=protected-access
    grad_b = gen_math_ops._rsqrt_grad(ca, grad)
    return grad_a, grad_b


@ops.RegisterGradient("Exp")
def _ExpGrad(op, grad):
  """Returns grad * exp(x)."""
  y = op.outputs[0]  # y = e^x
  with ops.control_dependencies([grad.op]):
    y = math_ops.conj(y)
    return grad * y


@ops.RegisterGradient("Log")
def _LogGrad(op, grad):
  """Returns grad * (1/x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * math_ops.inv(x)


@ops.RegisterGradient("Log1p")
def _Log1pGrad(op, grad):
  """Returns grad * (1/(1 + x))."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * math_ops.inv(1 + x)


@ops.RegisterGradient("Tanh")
def _TanhGrad(op, grad):
  """Returns grad * (1 - tanh(x) * tanh(x))."""
  y = op.outputs[0]  # y = tanh(x)
  with ops.control_dependencies([grad.op]):
    y = math_ops.conj(y)
    # pylint: disable=protected-access
    return gen_math_ops._tanh_grad(y, grad)


@ops.RegisterGradient("TanhGrad")
def _TanhGradGrad(op, grad):
  with ops.control_dependencies([grad.op]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    # pylint: disable=protected-access
    return grad * -2.0 * b * a, gen_math_ops._tanh_grad(a, grad)


@ops.RegisterGradient("Erf")
def _ErfGrad(op, grad):
  """Returns grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  two_over_root_pi = constant_op.constant(2 / np.sqrt(np.pi), dtype=grad.dtype)
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Erfc")
def _ErfcGrad(op, grad):
  """Returns -grad * 2/sqrt(pi) * exp(-x**2)."""
  x = op.inputs[0]
  minus_two_over_root_pi = constant_op.constant(-2 / np.sqrt(np.pi),
                                                dtype=grad.dtype)
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * minus_two_over_root_pi * math_ops.exp(-math_ops.square(x))


@ops.RegisterGradient("Lgamma")
def _LgammaGrad(op, grad):
  """Returns grad * digamma(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * math_ops.digamma(x)


@ops.RegisterGradient("Digamma")
def _DigammaGrad(op, grad):
  """Compute gradient of the digamma function with respect to its argument."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * math_ops.polygamma(constant_op.constant(1, dtype=x.dtype), x)


@ops.RegisterGradient("Igamma")
def _IgammaGrad(op, grad):
  """Returns gradient of igamma(a, x) with respect to a and x."""
  # TODO(ebrevdo): Perhaps add the derivative w.r.t. a
  a = op.inputs[0]
  x = op.inputs[1]
  sa = array_ops.shape(a)
  sx = array_ops.shape(x)
  unused_ra, rx = gen_array_ops._broadcast_gradient_args(sa, sx)

  # Perform operations in log space before summing, because Gamma(a)
  # and Gamma'(a) can grow large.
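  # d/dx igamma(a, x) = x^(a - 1) * exp(-x) / Gamma(a); it is evaluated below
  # as the exp of its logarithm.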
  partial_x = math_ops.exp(-x + (a-1) * math_ops.log(x) - math_ops.lgamma(a))
  return (None,
          array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))


@ops.RegisterGradient("Igammac")
def _IgammacGrad(op, grad):
  """Returns gradient of igammac(a, x) = 1 - igamma(a, x) w.r.t. a and x."""
  return [-1 * g if g is not None else None for g in _IgammaGrad(op, grad)]


@ops.RegisterGradient("Zeta")
def _ZetaGrad(op, grad):
  """Returns gradient of zeta(x, q) with respect to x and q."""
  # TODO(tillahoffmann): Add derivative with respect to x
  x = op.inputs[0]
  q = op.inputs[1]
  # Broadcast gradients
  sx = array_ops.shape(x)
  sq = array_ops.shape(q)
  unused_rx, rq = gen_array_ops._broadcast_gradient_args(sx, sq)
  # Evaluate gradient
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    q = math_ops.conj(q)
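    # For the Hurwitz zeta function, d/dq zeta(x, q) = -x * zeta(x + 1, q).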
    partial_q = -x * math_ops.zeta(x + 1, q)
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_q * grad, rq), sq))


@ops.RegisterGradient("Polygamma")
def _PolygammaGrad(op, grad):
  """Returns gradient of psi(n, x) with respect to n and x."""
  # TODO(tillahoffmann): Add derivative with respect to n
  n = op.inputs[0]
  x = op.inputs[1]
  # Broadcast gradients
  sn = array_ops.shape(n)
  sx = array_ops.shape(x)
  unused_rn, rx = gen_array_ops._broadcast_gradient_args(sn, sx)
  # Evaluate gradient
  with ops.control_dependencies([grad.op]):
    n = math_ops.conj(n)
    x = math_ops.conj(x)
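    # polygamma(n, x) is the nth derivative of digamma(x), so
    # d/dx polygamma(n, x) = polygamma(n + 1, x).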
    partial_x = math_ops.polygamma(n + 1, x)
    return (None,
            array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx))


@ops.RegisterGradient("Sigmoid")
def _SigmoidGrad(op, grad):
  """Returns grad * sigmoid(x) * (1 - sigmoid(x))."""
  y = op.outputs[0]  # y = sigmoid(x)
  with ops.control_dependencies([grad.op]):
    y = math_ops.conj(y)
    # pylint: disable=protected-access
    return gen_math_ops._sigmoid_grad(y, grad)


@ops.RegisterGradient("SigmoidGrad")
def _SigmoidGradGrad(op, grad):
  with ops.control_dependencies([grad.op]):
    a = math_ops.conj(op.inputs[0])
    b = math_ops.conj(op.inputs[1])
    gb = grad * b
    # pylint: disable=protected-access
    return gb - 2.0 * gb * a, gen_math_ops._sigmoid_grad(a, grad)


@ops.RegisterGradient("Sign")
def _SignGrad(op, _):
  """Returns 0."""
  x = op.inputs[0]
  return array_ops.zeros(array_ops.shape(x), dtype=x.dtype)


@ops.RegisterGradient("Sin")
def _SinGrad(op, grad):
  """Returns grad * cos(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return grad * math_ops.cos(x)


@ops.RegisterGradient("Cos")
def _CosGrad(op, grad):
  """Returns grad * -sin(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    return -grad * math_ops.sin(x)


@ops.RegisterGradient("Tan")
def _TanGrad(op, grad):
  """Returns grad * 1/sec^2(x)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    secx = math_ops.inv(math_ops.cos(x))
    secx2 = math_ops.square(secx)
    return grad * secx2


@ops.RegisterGradient("Asin")
def _AsinGrad(op, grad):
  """Returns grad * 1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.sub(one, x2))
    inv = math_ops.inv(den)
    return grad * inv


@ops.RegisterGradient("Acos")
def _AcosGrad(op, grad):
  """Returns grad * -1/sqrt(1-x^2)."""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    den = math_ops.sqrt(math_ops.sub(one, x2))
    inv = math_ops.inv(den)
    return -grad * inv


@ops.RegisterGradient("Atan")
def _AtanGrad(op, grad):
  """Returns grad * 1/ (1 + x^2)"""
  x = op.inputs[0]
  with ops.control_dependencies([grad.op]):
    x = math_ops.conj(x)
    x2 = math_ops.square(x)
    one = constant_op.constant(1, dtype=grad.dtype)
    inv = math_ops.inv(math_ops.add(one, x2))
    return grad * inv


@ops.RegisterGradient("AddN")
def _AddNGrad(op, grad):
  """Copies the gradient to all inputs."""
  # Not broadcasting.
  return [grad] * len(op.inputs)


@ops.RegisterGradient("Add")
def _AddGrad(op, grad):
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
          array_ops.reshape(math_ops.reduce_sum(grad, ry), sy))


@ops.RegisterGradient("Sub")
def _SubGrad(op, grad):
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(grad, rx), sx),
          array_ops.reshape(-math_ops.reduce_sum(grad, ry), sy))


@ops.RegisterGradient("Mul")
def _MulGrad(op, grad):
  """The gradient of scalar multiplication."""
  x = op.inputs[0]
  y = op.inputs[1]
  assert x.dtype.base_dtype == y.dtype.base_dtype, (x.dtype, " vs. ", y.dtype)
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx),
          array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy))


@ops.RegisterGradient("Div")
def _DivGrad(op, grad):
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)  # pylint: disable=protected-access
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  return (array_ops.reshape(math_ops.reduce_sum(grad / y, rx), sx),
          array_ops.reshape(math_ops.reduce_sum(grad *
                                         (-x / math_ops.square(y)), ry), sy))


@ops.RegisterGradient("Pow")
def _PowGrad(op, grad):
  """Returns grad * (y*x^(y-1), z*log(x))."""
  x = op.inputs[0]
  y = op.inputs[1]
  z = op.outputs[0]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  x = math_ops.conj(x)
  y = math_ops.conj(y)
  z = math_ops.conj(z)
  gx = array_ops.reshape(
      math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
  # Avoid false singularity at x = 0
  if x.dtype.is_complex:
    # real(x) < 0 is fine for the complex case
    log_x = math_ops.select(
        math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x))
  else:
    # There's no sensible real value to return if x < 0, so return 0
    log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x))
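  # Without the select, x == 0 would produce log(0) = -inf and propagate
  # Inf/NaN values into the gradient for y.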
  gy = array_ops.reshape(
      math_ops.reduce_sum(grad * z * log_x, ry), sy)
  return gx, gy


def _MaximumMinimumGrad(op, grad, selector_op):
  """Factor out the code for the gradient of Maximum or Minimum."""
  x = op.inputs[0]
  y = op.inputs[1]
  gdtype = grad.dtype
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  gradshape = array_ops.shape(grad)
  zeros = array_ops.zeros(gradshape, gdtype)
  xmask = selector_op(x, y)
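  # Ties (x == y) send the whole gradient to x, since the selector passed in
  # is either greater_equal or less_equal.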
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  xgrad = math_ops.select(xmask, grad, zeros)
  ygrad = math_ops.select(math_ops.logical_not(xmask), grad, zeros)
  gx = array_ops.reshape(math_ops.reduce_sum(xgrad, rx), sx)
  gy = array_ops.reshape(math_ops.reduce_sum(ygrad, ry), sy)
  return (gx, gy)


@ops.RegisterGradient("Maximum")
def _MaximumGrad(op, grad):
  """Returns grad*(x > y, x <= y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.greater_equal)


@ops.RegisterGradient("Minimum")
def _MinimumGrad(op, grad):
  """Returns grad*(x < y, x >= y) with type of grad."""
  return _MaximumMinimumGrad(op, grad, math_ops.less_equal)


@ops.RegisterGradient("SquaredDifference")
def _SquaredDifferenceGrad(op, grad):
  """Returns the gradient for (x-y)^2."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  # pylint: disable=protected-access
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  # pylint: enable=protected-access
  # .op works with Tensors or IndexedSlices
  with ops.control_dependencies([grad.op]):
    # The parentheses ensure that if grad is an IndexedSlices, it is multiplied
    # by a Tensor (not by a Python number like 2.0), which converts it to a
    # Tensor.
    x_grad = math_ops.scalar_mul(2.0, grad) * (x - y)
  return (array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx),
          -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy))


# Logical operations have no gradients.
ops.NotDifferentiable("Less")
ops.NotDifferentiable("LessEqual")
ops.NotDifferentiable("Greater")
ops.NotDifferentiable("GreaterEqual")
ops.NotDifferentiable("Equal")
ops.NotDifferentiable("NotEqual")
ops.NotDifferentiable("LogicalAnd")
ops.NotDifferentiable("LogicalOr")
ops.NotDifferentiable("LogicalNot")


@ops.RegisterGradient("Select")
def _SelectGrad(op, grad):
  c = op.inputs[0]
  x = op.inputs[1]
  zeros = array_ops.zeros_like(x)
  return (None, math_ops.select(c, grad, zeros),
          math_ops.select(c, zeros, grad))


@ops.RegisterGradient("MatMul")
def _MatMulGrad(op, grad):
  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
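  # For C = A * B with no transposes, dA = dC * B^T and dB = A^T * dC; the
  # other cases follow by pushing the transposes through these identities.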
  if not t_a and not t_b:
    return (math_ops.matmul(grad, op.inputs[1], transpose_b=True),
            math_ops.matmul(op.inputs[0], grad, transpose_a=True))
  elif not t_a and t_b:
    return (math_ops.matmul(grad, op.inputs[1]),
            math_ops.matmul(grad, op.inputs[0], transpose_a=True))
  elif t_a and not t_b:
    return (math_ops.matmul(op.inputs[1], grad, transpose_b=True),
            math_ops.matmul(op.inputs[0], grad))
  elif t_a and t_b:
    return (math_ops.matmul(op.inputs[1], grad, transpose_a=True,
                            transpose_b=True),
            math_ops.matmul(grad, op.inputs[0], transpose_a=True,
                            transpose_b=True))


@ops.RegisterGradient("SparseMatMul")
def _SparseMatMulGrad(op, grad):
  """Gradient for SparseMatMul."""

  t_a = op.get_attr("transpose_a")
  t_b = op.get_attr("transpose_b")
  is_sparse = {
      op.inputs[0]: op.get_attr("a_is_sparse"),
      op.inputs[1]: op.get_attr("b_is_sparse"),
      # Use heuristic to figure out if grad might be sparse
      grad: (grad.op.type == "ReluGrad")
  }
  def _SparseMatMul(t1, t2, out_dtype,
                    transpose_a=False, transpose_b=False):
    """Helper function to create SparseMatMul op."""

    assert t1 in is_sparse and t2 in is_sparse
    t1_sparse = is_sparse[t1]
    t2_sparse = is_sparse[t2]
    if transpose_b:
      t2 = array_ops.transpose(t2)
      transpose_b = False
    prod = math_ops.matmul(t1, t2,
                           transpose_a=transpose_a,
                           transpose_b=transpose_b,
                           a_is_sparse=t1_sparse,
                           b_is_sparse=t2_sparse)
    if prod.dtype != out_dtype:
      prod = math_ops.cast(prod, out_dtype)
    return prod

  dtype_a = op.inputs[0].dtype
  dtype_b = op.inputs[1].dtype
  if not t_a and not t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b, transpose_a=True))
  elif not t_a and t_b:
    return (_SparseMatMul(grad, op.inputs[1], dtype_a),
            _SparseMatMul(grad, op.inputs[0], dtype_b, transpose_a=True))
  elif t_a and not t_b:
    return (_SparseMatMul(op.inputs[1], grad, dtype_a, transpose_b=True),
            _SparseMatMul(op.inputs[0], grad, dtype_b))
  elif t_a and t_b:
    return (_SparseMatMul(op.inputs[1], grad, dtype_a,
                          transpose_a=True, transpose_b=True),
            _SparseMatMul(grad, op.inputs[0], dtype_b,
                          transpose_a=True, transpose_b=True))


@ops.RegisterGradient("Floor")
def _FloorGrad(_, unused_grad):
  return [None]


@ops.RegisterGradient("BatchMatMul")
def _BatchMatMul(op, grad):
  """Returns the gradient of x and y given the gradient of x * y."""
  x = op.inputs[0]
  y = op.inputs[1]
  adj_x = op.get_attr("adj_x")
  adj_y = op.get_attr("adj_y")

  if not adj_x:
    if not adj_y:
      grad_x = math_ops.batch_matmul(grad, y, False, True)
      grad_y = math_ops.batch_matmul(x, grad, True, False)
    else:
      grad_x = math_ops.batch_matmul(grad, y, False, False)
      grad_y = math_ops.batch_matmul(grad, x, True, False)
  else:
    if not adj_y:
      grad_x = math_ops.batch_matmul(y, grad, False, True)
      grad_y = math_ops.batch_matmul(x, grad, False, False)
    else:
      grad_x = math_ops.batch_matmul(y, grad, True, True)
      grad_y = math_ops.batch_matmul(grad, x, True, True)

  return grad_x, grad_y


ops.NotDifferentiable("Range")
ops.NotDifferentiable("LinSpace")


@ops.RegisterGradient("Complex")
def _ComplexGrad(op, grad):
  """Returns the real and imaginary components of 'grad', respectively."""
  x = op.inputs[0]
  y = op.inputs[1]
  sx = array_ops.shape(x)
  sy = array_ops.shape(y)
  rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy)
  return (array_ops.reshape(math_ops.reduce_sum(math_ops.real(grad), rx), sx),
          array_ops.reshape(math_ops.reduce_sum(math_ops.imag(grad), ry), sy))


@ops.RegisterGradient("Real")
def _RealGrad(_, grad):
  """Returns 'grad' as the real part and set the imaginary part 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(grad, zero)


@ops.RegisterGradient("Imag")
def _ImagGrad(_, grad):
  """Returns 'grad' as the imaginary part and set the real part 0."""
  zero = constant_op.constant(0, dtype=grad.dtype)
  return math_ops.complex(zero, grad)


@ops.RegisterGradient("Conj")
def _ConjGrad(_, grad):
  """Returns the complex conjugate of grad."""
  return math_ops.conj(grad)


@ops.RegisterGradient("ComplexAbs")
def _ComplexAbsGrad(op, grad):
  """Returns the gradient of ComplexAbs."""
  # TODO(b/27786104): The cast to complex could be removed once arithmetic
  # supports mixtures of complex64 and real values.
  return (math_ops.complex(grad, array_ops.zeros_like(grad)) *
          math_ops.sign(op.inputs[0]))


@ops.RegisterGradient("Cast")
def _CastGrad(op, grad):
  t = [dtypes.float16, dtypes.float32, dtypes.float64,
       dtypes.bfloat16, dtypes.complex64, dtypes.complex128]
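  # Gradients only flow between floating-point and complex types; casts
  # involving other types (e.g. integer or bool) return None below.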
  src_type = op.inputs[0].dtype.base_dtype
  dst_type = grad.dtype.base_dtype
  if src_type in t and dst_type in t:
    return math_ops.cast(grad, src_type)
  else:
    return None


def _FFTSizeForGrad(grad, rank):
  return math_ops.reduce_prod(
      array_ops.slice(
          array_ops.reverse(array_ops.shape(grad), (True,)), (0,), (rank,)))
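# The DFT matrix is unitary up to a factor of N, the product of the innermost
# `rank` dimensions computed above; accordingly, the gradients below
# backpropagate fft through ifft scaled by N and ifft through fft scaled by
# 1 / N.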


@ops.RegisterGradient("FFT")
def _FFTGrad(_, grad):
  size = math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
  return math_ops.ifft(grad) * math_ops.complex(size, 0.)


@ops.RegisterGradient("IFFT")
def _IFFTGrad(_, grad):
  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 1), dtypes.float32)
  return math_ops.fft(grad) * math_ops.complex(rsize, 0.)


@ops.RegisterGradient("FFT2D")
def _FFT2DGrad(_, grad):
  size = math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
  return math_ops.ifft2d(grad) * math_ops.complex(size, 0.)


@ops.RegisterGradient("IFFT2D")
def _IFFT2DGrad(_, grad):
  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 2), dtypes.float32)
  return math_ops.fft2d(grad) * math_ops.complex(rsize, 0.)


@ops.RegisterGradient("FFT3D")
def _FFT3DGrad(_, grad):
  size = math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
  return math_ops.ifft3d(grad) * math_ops.complex(size, 0.)


@ops.RegisterGradient("IFFT3D")
def _IFFT3DGrad(_, grad):
  rsize = 1. / math_ops.cast(_FFTSizeForGrad(grad, 3), dtypes.float32)
  return math_ops.fft3d(grad) * math_ops.complex(rsize, 0.)


@ops.RegisterGradient("Cross")
def _CrossGrad(op, grad):
  u = op.inputs[0]
  v = op.inputs[1]
  return (math_ops.cross(v, grad), math_ops.cross(grad, u))


@ops.RegisterGradient("Cumsum")
def _CumsumGrad(op, grad):
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")
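  # The adjoint of a cumulative sum is a cumulative sum in the opposite
  # direction with the same exclusivity, hence reverse=not reverse below.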
  return [math_ops.cumsum(grad, axis, exclusive=exclusive,
                          reverse=not reverse), None]


@ops.RegisterGradient("Cumprod")
def _CumprodGrad(op, grad):
  x = op.inputs[0]
  axis = op.inputs[1]
  exclusive = op.get_attr("exclusive")
  reverse = op.get_attr("reverse")

  # TODO: This fails when x contains 0 (the division below produces Inf/NaN)
  # and should be fixed.
  prod = math_ops.cumprod(x, axis, exclusive=exclusive, reverse=reverse)
  out = math_ops.cumsum(prod * grad, axis, exclusive=exclusive,
                        reverse=not reverse)
  return [out / x, None]