import tensorflow as tf
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops
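
# This module registers a second-order gradient for MaxPoolGradWithArgmax so
# that TensorFlow can differentiate through a max-pool gradient a second time
# (double backprop). The blocks left commented out below mirror first-order
# gradients that the TensorFlow distribution presumably already registers
# itself; they are kept only for reference.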


# @ops.RegisterGradient("MaxPoolWithArgmax")
# def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
#   """The gradients for `MaxPoolWithArgmax`.
#   Args:
#     op: The `MaxPoolWithArgmax` `Operation` that we are differentiating, which
#       we can use to find the inputs and outputs of the original op.
#     grad: Gradient with respect to the output of the `MaxPoolWithArgmax` op.
#     op.inputs[0]: x
#     op.outputs[0]: y
#     op.outputs[1]: argmax_in_x
#   Returns:
#     Gradients with respect to the input of `MaxPoolWithArgmax`.
#   """
# 
#   return gen_nn_ops._max_pool_grad_with_argmax(
#       op.inputs[0],
#       grad,
#       op.outputs[1],
#       op.get_attr("ksize"),
#       op.get_attr("strides"),
#       padding=op.get_attr("padding"))


def _max_pool_grad_grad(dy, x, y, ksize, strides, padding, argmax=None):
  """Gradients of MaxPoolGrad."""
  if argmax is None:
    _, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)
  grad = dy
  grad_flat = tf.reshape(grad, [-1])
  argmax_flat = tf.reshape(argmax, [-1])

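  # The argmax indices are flattened within each example (no batch offset),
  # so build per-example offsets of b * (height * width * channels of one
  # input example) to make them index into the fully flattened `grad`.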
  x_shape = tf.cast(tf.shape(x), argmax.dtype)
  batch_dim = tf.reshape(
      tf.range(x_shape[0], dtype=argmax.dtype), [-1, 1, 1, 1])
  nelem = tf.reduce_prod(x_shape[1:])
  batch_dim *= nelem

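  # Broadcast the per-example offsets from [batch, 1, 1, 1] to the shape of
  # the pooled output, then flatten so they pair up with argmax_flat.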
  y_zero = tf.zeros_like(y, dtype=argmax.dtype)
  batch_dim += y_zero
  batch_dim = tf.reshape(batch_dim, [-1])

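  # Apply the offsets, gather dy at the resulting global indices, and shape
  # the result like the pooled output.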
  argmax_flat += batch_dim
  grad_input = tf.gather(grad_flat, argmax_flat)
  grad_input = tf.reshape(grad_input, tf.shape(y))
  return grad_input
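
# A minimal usage sketch of the helper above (TF 1.x graph mode; shapes and
# pooling parameters are illustrative assumptions, not requirements):
#
#   x = tf.random_normal([2, 4, 4, 3])
#   y, argmax = tf.nn.max_pool_with_argmax(
#       x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
#   dy = tf.ones_like(x)  # gradient w.r.t. MaxPoolGrad's output (shape of x)
#   ddy = _max_pool_grad_grad(
#       dy, x, y, [1, 2, 2, 1], [1, 2, 2, 1], "SAME", argmax=argmax)
#   # ddy has the shape of y: dy gathered at the argmax locations.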


@ops.RegisterGradient("MaxPoolGradWithArgmax")
def _MaxPoolGradWithArgmaxGrad(op, grad):
  """The gradients for `MaxPoolGradWithArgmax`.
  Args:
    op: The `MaxPoolGradWithArgmax` `Operation` that we are differentiating,
      which we can use to find the inputs and outputs of the original op.
    grad: Gradient with respect to the output of the `MaxPoolGradWithArgmax` op.
    op.inputs[0]: x
    op.inputs[1]: dl/dy
    op.inputs[2]: argmax_in_x
    op.outputs[0]: dl/dx
  Returns:
    Gradients with respect to the input of `MaxPoolGradWithArgmax`.
  """
  ksize = op.get_attr("ksize")
  strides = op.get_attr("strides")
  padding = op.get_attr("padding")
  return [
      None, _max_pool_grad_grad(
          grad,
          op.inputs[0],
          op.inputs[1],
          ksize,
          strides,
          padding,
          argmax=op.inputs[2]), None
  ]
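
# With the registration above in place, TensorFlow can differentiate through
# MaxPoolGradWithArgmax a second time. A sketch (TF 1.x graph mode; all names
# below are illustrative):
#
#   y, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)
#   dy = tf.ones_like(y)
#   dx = tf.gradients(y, x, grad_ys=dy)[0]  # emits MaxPoolGradWithArgmax
#   ddy = tf.gradients(dx, dy)[0]           # uses _MaxPoolGradWithArgmaxGrad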


# @ops.RegisterGradient("MaxPoolGrad")
# def _MaxPoolGradGrad(op, grad):
#   """The gradients for `MaxPoolGrad`.
#   Args:
#     op: The `MaxPoolGrad` `Operation` that we are differentiating, which
#       we can use to find the inputs and outputs of the original op.
#     grad: Gradient with respect to the output of the `MaxPoolGrad` op.
#     op.inputs[0]: x
#     op.inputs[1]: y
#     op.inputs[2]: dl/dy
#     op.outputs[0]: dl/dx
#   Returns:
#     Gradients with respect to the input of `MaxPoolGrad`.
#   """
#   ksize = op.get_attr("ksize")
#   strides = op.get_attr("strides")
#   padding = op.get_attr("padding")
#   return [
#       None, None, _max_pool_grad_grad(grad, op.inputs[0], op.inputs[1], ksize,
#                                       strides, padding)
#   ]