# coding=utf-8
# Copyright 2020 The Mesh TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Mesh TensorFlow."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import test_util  # pylint:disable=g-direct-tensorflow-import


class LaidOutTensor(object):
  """Minimal test double for a mesh implementation's LaidOutTensor.

  The real classes live in placement_mesh_impl.py and simd_mesh_impl.py; the
  tests in this file only need an object that exposes `tensor_list`.
  """

  def __init__(self, tensor_list):
    # Keep a reference to the caller's list as-is (no copy is made).
    self.tensor_list = tensor_list


class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for conversion helpers, TensorLayout, Graph, Mesh and lowering."""

  @parameterized.parameters(
      (mtf.Dimension("x", 5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    """A Dimension and a (name, size) tuple both convert to Dimension x:5."""
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    """None passes through as None; a bare integer raises TypeError."""
    dimension = mtf.convert_to_dimension(None)
    self.assertIsNone(dimension)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension("x", 4),
                  mtf.Dimension("y", 8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    """Checks that every accepted spelling converts to the shape [x:4, y:8].

    Inputs are a Shape object, or "name:size" strings separated by ';', '.',
    ' ' or ','.
    """
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(shape, mtf.Shape([mtf.Dimension("x", 4),
                                       mtf.Dimension("y", 8)]))

  def testConvertToShapeGenericInputs(self):
    """Checks the empty-list, None and malformed-string inputs."""
    shape = mtf.convert_to_shape([])
    self.assertEqual(shape.dims, [])
    shape = mtf.convert_to_shape(None)
    self.assertIsNone(shape)
    # "x;4" lacks the "name:size" form accepted above, so parsing must fail.
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    """Checks that every accepted spelling yields the same rule pairs."""
    layout_rules = mtf.convert_to_layout_rules(inputs)
    # LayoutRules is compared through its private `_pairs` attribute.
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    """A rule string without "dim:mesh_axis" pairs raises ValueError."""
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")

  def testTensorLayout(self):
    """Checks mesh_axis_to_tensor_axis and is_fully_replicated."""
    tensor_layout = mtf.TensorLayout([0, 2, 1])
    # The returned tuple gains one entry per increment of the argument.
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0,))
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
    # A layout with any mesh axis assigned is not fully replicated ...
    tensor_layout = mtf.TensorLayout([None, 0])
    self.assertFalse(tensor_layout.is_fully_replicated)
    # ... while an all-None layout is.
    tensor_layout = mtf.TensorLayout([None, None, None])
    self.assertTrue(tensor_layout.is_fully_replicated)

  def testGraph(self):
    """Checks operation and variable bookkeeping on a Graph."""
    graph = mtf.Graph()
    self.assertEmpty(graph.operations)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)
    mesh = mtf.Mesh(graph, "mesh_test")
    # Importing a tensor adds one operation but no variables.
    _ = mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    self.assertLen(graph.operations, 1)
    self.assertEmpty(graph.trainable_variables)
    self.assertEmpty(graph.all_variables)
    # A trainable variable is counted in every collection.
    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    self.assertLen(graph.operations, 2)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 1)
    # A non-trainable variable is counted only in all_variables.
    _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
    self.assertLen(graph.operations, 3)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 2)

  def testGraphNames(self):
    """Checks Graph.unique_name suffixing, including user-chosen "a_1"."""
    # Standard Usage.
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a"), "a_1")
    self.assertEqual(graph.unique_name("a"), "a_2")

    # Edge cases, the user may choose the name "a_1".
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a"), "a_1")
    self.assertEqual(graph.unique_name("a_1"), "a_1_1")

    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("a_1"), "a_1")
    self.assertEqual(graph.unique_name("a"), "a_2")

    # Case insensitive: "A" collides with the earlier "a".
    graph = mtf.Graph()
    self.assertEqual(graph.unique_name("a"), "a")
    self.assertEqual(graph.unique_name("A"), "A_1")

  @test_util.run_in_graph_and_eager_modes()
  def testLowering(self):
    """Checks that a scalar survives import -> lowering -> export."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    inputs = tf.constant(0.)
    mtf_inputs = mtf.import_tf_tensor(mesh,
                                      tf_tensor=inputs,
                                      shape=mtf.Shape([]))
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    outputs = lowering.export_to_tf_tensor(mtf_inputs)
    inputs_value, outputs_value = self.evaluate([inputs, outputs])
    self.assertEqual(inputs_value, outputs_value)

    # Check that methods run without error.
    _ = lowering.copy_masters_to_slices()
    _ = lowering.copy_slices_to_masters()

  def testMesh(self):
    """A Mesh keeps a reference to the Graph it was created on."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    self.assertEqual(mesh.graph, graph)

  def testMeshImpl(self):
    """Checks MeshImpl properties and layout-rule driven axis assignment."""
    shape = mtf.Shape([mtf.Dimension("batch", 4),
                       mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"),
                                    ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertLen(shape, mesh_impl.ndims)
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    # "length" has no layout rule, so its slot in the layout is None.
    self.assertEqual(mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
                     mtf.TensorLayout([0, None, 1]))


class OperationSplittabilityTest(tf.test.TestCase):
  """Checks splittable_dims / unsplittable_dims reported by mtf operations.

  Each test builds one or more operations from the fixtures created in
  setUp() and asserts which dimension names are reported as splittable and
  which as unsplittable.
  """

  def setUp(self):
    """Creates a fresh graph/mesh plus the [a, b] tensor and image fixtures."""
    super(OperationSplittabilityTest, self).setUp()
    self.graph = mtf.Graph()
    self.mesh = mtf.Mesh(self.graph, "my_mesh")

    self.a_dim = mtf.Dimension("a", 5)
    self.b_dim = mtf.Dimension("b", 10)
    self.c_dim = mtf.Dimension("c", 15)

    self.ab_shape = mtf.Shape([self.a_dim, self.b_dim])
    self.x = mtf.zeros(self.mesh, self.ab_shape)

    # Image-style fixtures used by the pooling and convolution tests.
    self.batch_dim = mtf.Dimension("batch", 100)
    self.grid_h_dim = mtf.Dimension("grid_h", 10)
    self.grid_w_dim = mtf.Dimension("grid_w", 10)
    self.filter_h_dim = mtf.Dimension("filter_h", 5)
    self.filter_w_dim = mtf.Dimension("filter_w", 5)
    self.in_dim = mtf.Dimension("in", 10)
    self.out_dim = mtf.Dimension("out", 10)
    self.image = mtf.zeros(self.mesh, [self.batch_dim, self.grid_h_dim,
                                       self.grid_w_dim, self.in_dim])

  def testOperation(self):
    """Exercises the base-class splittability initializers."""
    operation = mtf.Operation([self.x], name="operation")

    # Everything is splittable.
    self.assertEqual(
        operation._initialize_all_dimensions_as_splittable(),
        (frozenset(["a", "b"]), frozenset()))

    # Everything is unsplittable.
    self.assertEqual(
        operation._initialize_splittable_and_unsplittable_dims("unsplittable"),
        (frozenset(), frozenset(["a", "b"])))

    # Everything is unsplittable except dimension "b".
    self.assertEqual(
        operation._initialize_splittable_and_unsplittable_dims(
            "unsplittable", ["b"]),
        (frozenset(["b"]), frozenset(["a"])))

    # An unrecognized mode string is rejected.
    with self.assertRaises(ValueError):
      operation._initialize_splittable_and_unsplittable_dims("invalid")

  def testSlicewiseOperationAndGenericGradOperation(self):
    """Declared splittable_dims are honored; the generic grad op splits all."""
    slicewise_operation = mtf.SlicewiseOperation(
        tf.exp,
        [self.x],
        [self.x.shape],
        [self.x.dtype],
        splittable_dims=[self.a_dim],  # pretend only dim "a" can be split.
        grad_function=lambda op, dy: [dy * op.outputs[0]],
        name="component-wise exp")

    self.assertEqual(slicewise_operation.splittable_dims, frozenset(["a"]))
    self.assertEqual(slicewise_operation.unsplittable_dims, frozenset(["b"]))

    generic_grad_operation = mtf.GenericGradOperation(slicewise_operation,
                                                      [self.x])

    self.assertEqual(generic_grad_operation.splittable_dims,
                     frozenset(["a", "b"]))
    self.assertEqual(generic_grad_operation.unsplittable_dims,
                     frozenset())

  def testScalarMultiplyOperationandScalarAddOperation(self):
    """Scalar multiply/add leave every dimension splittable."""
    scalar = 2.0
    scalar_multiply_operation = mtf.ScalarMultiplyOperation(self.x, scalar)
    self.assertEqual(scalar_multiply_operation.splittable_dims,
                     frozenset(["a", "b"]))
    self.assertEqual(scalar_multiply_operation.unsplittable_dims, frozenset())

    scalar_add_operation = mtf.ScalarAddOperation(self.x, scalar)
    self.assertEqual(scalar_add_operation.splittable_dims,
                     frozenset(["a", "b"]))
    self.assertEqual(scalar_add_operation.unsplittable_dims, frozenset())

  def testBinaryOpWithBroadcasting(self):
    """A broadcasting binary op can split every output dimension."""
    x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim]))
    binary_op_with_broadcasting = mtf.BinaryOpWithBroadcasting(
        tf.less,
        self.x,
        x2,
        mtf.Shape([self.a_dim, self.b_dim, self.c_dim]),
        tf.bool,
        name="less with broadcasting")

    self.assertEqual(binary_op_with_broadcasting.splittable_dims,
                     frozenset(["a", "b", "c"]))
    self.assertEqual(binary_op_with_broadcasting.unsplittable_dims, frozenset())

  def testBroadcastOperation(self):
    """Broadcasting keeps both the input and the new dims splittable."""
    broadcast_operation = mtf.BroadcastOperation(
        self.x, mtf.Shape([self.b_dim, self.c_dim, self.a_dim]))
    self.assertEqual(broadcast_operation.splittable_dims,
                     frozenset(["a", "b", "c"]))
    self.assertEqual(broadcast_operation.unsplittable_dims, frozenset())

  def testReduceOperation(self):
    """Reducing "a" away still reports both a and b as splittable."""
    reduce_operation = mtf.ReduceOperation(self.x, mtf.Shape([self.b_dim]),
                                           "sum")
    self.assertEqual(reduce_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(reduce_operation.unsplittable_dims, frozenset())

  def testPoolOperation(self):
    """2x2 average pooling makes the spatial grid dims unsplittable."""
    pool_operation = mtf.PoolOperation(self.image, [2, 2], [2, 2], "AVG_2D")
    self.assertEqual(pool_operation.splittable_dims,
                     frozenset(["batch", "in"]))
    self.assertEqual(pool_operation.unsplittable_dims,
                     frozenset(["grid_h", "grid_w"]))

  def testConcatOperation(self):
    """The concatenated dimension is unsplittable; the others are not."""
    concat_dim1 = mtf.Dimension("concat", 5)
    concat_dim2 = mtf.Dimension("concat", 7)

    x1 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.b_dim, concat_dim1]))
    x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.b_dim, concat_dim2]))

    concat_operation = mtf.ConcatOperation([x1, x2], "concat")
    self.assertEqual(concat_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(concat_operation.unsplittable_dims, frozenset(["concat"]))

  def testSplitOperation(self):
    """The dimension being split into pieces is unsplittable."""
    split_operation = mtf.SplitOperation(self.x, self.b_dim, [3, 7])
    self.assertEqual(split_operation.splittable_dims, frozenset(["a"]))
    self.assertEqual(split_operation.unsplittable_dims, frozenset(["b"]))

  def testStackOperation(self):
    """The newly created stack dimension is unsplittable."""
    stack_operation = mtf.StackOperation([self.x, self.x], "stack", axis=0)
    self.assertEqual(stack_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(stack_operation.unsplittable_dims, frozenset(["stack"]))

  def testUnstackOperation(self):
    """The dimension being unstacked is unsplittable."""
    unstack_operation = mtf.UnstackOperation(self.x, self.b_dim)
    self.assertEqual(unstack_operation.splittable_dims, frozenset(["a"]))
    self.assertEqual(unstack_operation.unsplittable_dims, frozenset(["b"]))

  def testEinsumOperation(self):
    """Einsum reports all dims, including contracted "a", as splittable."""
    x2 = mtf.zeros(self.mesh, mtf.Shape([self.a_dim, self.c_dim]))
    einsum_operation = mtf.EinsumOperation([self.x, x2],
                                           mtf.Shape([self.b_dim, self.c_dim]))
    self.assertEqual(einsum_operation.splittable_dims,
                     frozenset(["a", "b", "c"]))
    self.assertEqual(einsum_operation.unsplittable_dims, frozenset())

  def testConv2dOperations(self):
    """Forward conv fixes filter/grid dims; the backprop ops split all dims."""
    conv_input = mtf.zeros(
        self.mesh,
        mtf.Shape([self.batch_dim, self.grid_h_dim, self.grid_w_dim,
                   self.in_dim]))
    conv_filter = mtf.zeros(
        self.mesh,
        mtf.Shape([self.filter_h_dim, self.filter_w_dim, self.in_dim,
                   self.out_dim]))
    strides = [1, 1, 1, 1]
    padding = "SAME"

    conv2d_operation = mtf.Conv2dOperation(conv_input, conv_filter, strides,
                                           padding)
    self.assertEqual(conv2d_operation.splittable_dims,
                     frozenset(["batch", "in", "out"]))
    self.assertEqual(conv2d_operation.unsplittable_dims,
                     frozenset(["filter_h", "filter_w", "grid_h", "grid_w"]))

    output = conv2d_operation.outputs[0]
    d_output = mtf.zeros(self.mesh, output.shape)

    conv2d_backprop_input_operation = mtf.Conv2or3dBackpropInputOperation(
        2, False, conv_input.shape, conv_filter, d_output, strides, padding)
    self.assertEqual(conv2d_backprop_input_operation.splittable_dims,
                     frozenset(["batch", "filter_h", "filter_w", "grid_h",
                                "grid_w", "in", "out"]))
    self.assertEqual(conv2d_backprop_input_operation.unsplittable_dims,
                     frozenset())

    conv2d_backprop_filter_operation = mtf.Conv2or3dBackpropFilterOperation(
        2, False, conv_input, conv_filter.shape, d_output, strides, padding)
    self.assertEqual(conv2d_backprop_filter_operation.splittable_dims,
                     frozenset(["batch", "filter_h", "filter_w", "grid_h",
                                "grid_w", "in", "out"]))
    self.assertEqual(conv2d_backprop_filter_operation.unsplittable_dims,
                     frozenset())

  def testShiftOperation(self):
    """A wrapped shift along "b" keeps both dims splittable."""
    shift_operation = mtf.ShiftOperation(self.x, -5, self.b_dim, wrap=True)
    self.assertEqual(shift_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(shift_operation.unsplittable_dims, frozenset())

  def testSliceOperation(self):
    """The sliced dimension is unsplittable."""
    slice_operation = mtf.SliceOperation(self.x, begin=3, size=4,
                                         slice_dim_name="b")
    self.assertEqual(slice_operation.splittable_dims, frozenset(["a"]))
    self.assertEqual(slice_operation.unsplittable_dims, frozenset(["b"]))

  def testPadOperation(self):
    """The padded dimension is unsplittable."""
    pad_operation = mtf.PadOperation(self.x, [7, 2], "a")
    self.assertEqual(pad_operation.splittable_dims, frozenset(["b"]))
    self.assertEqual(pad_operation.unsplittable_dims, frozenset(["a"]))

  def testOneHotOperation(self):
    """One-hot keeps the input dims and the new "c" dim splittable."""
    x = mtf.zeros(self.mesh, self.ab_shape, dtype=tf.int32)
    one_hot_operation = mtf.OneHotOperation(x, self.c_dim, 1, 0, dtype=tf.bool)
    self.assertEqual(one_hot_operation.splittable_dims,
                     frozenset(["a", "b", "c"]))
    self.assertEqual(one_hot_operation.unsplittable_dims, frozenset())

  def testImportOperation(self):
    """Importing a tf.Tensor allows splitting every dimension."""
    tf_x = tf.zeros([5, 10])
    import_operation = mtf.ImportOperation(self.mesh, tf_x, self.ab_shape)
    self.assertEqual(import_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(import_operation.unsplittable_dims, frozenset())

  def testImportLaidOutTensorOperation(self):
    """Importing an already laid-out tensor reports no splittable dims."""
    laid_out_x = LaidOutTensor([self.x])

    import_laid_out_tensor_operation = mtf.ImportLaidOutTensorOperation(
        self.mesh, laid_out_x, self.ab_shape)
    self.assertEqual(import_laid_out_tensor_operation.splittable_dims,
                     frozenset())
    self.assertEqual(import_laid_out_tensor_operation.unsplittable_dims,
                     frozenset(["a", "b"]))

  def testVariableOperations(self):
    """Variable, ReadVariable, Assign and Depend all split both a and b."""
    var = mtf.Variable(self.mesh,
                       "test_variable",
                       self.ab_shape,
                       mtf.VariableDType(tf.int32, tf.int32, tf.int32),
                       initializer=tf.zeros_initializer(),
                       trainable=True)

    self.assertEqual(var.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(var.unsplittable_dims, frozenset())

    read_variable = mtf.ReadVariable(var)
    self.assertEqual(read_variable.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(read_variable.unsplittable_dims, frozenset())

    assign = mtf.Assign([var], [self.x])
    self.assertEqual(assign.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(assign.unsplittable_dims, frozenset())

    depend = mtf.Depend(read_variable.outputs[0], [assign])
    self.assertEqual(depend.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(depend.unsplittable_dims, frozenset())

  def testConstant(self):
    """A constant can be split along every dimension."""
    constant = mtf.Constant(self.mesh, 0, self.ab_shape, dtype=tf.int32)
    self.assertEqual(constant.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(constant.unsplittable_dims, frozenset())

  def testStopGradient(self):
    """StopGradient can be split along every dimension."""
    stop_gradient = mtf.StopGradient(self.x)
    self.assertEqual(stop_gradient.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(stop_gradient.unsplittable_dims, frozenset())

  def testPrintOperation(self):
    """PrintOperation can be split along every dimension."""
    print_operation = mtf.PrintOperation(self.x, [self.x], "Tensor x: ")
    self.assertEqual(print_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(print_operation.unsplittable_dims, frozenset())

  def testReshapeOperation(self):
    """Reshape reports both the old and new dimension names as splittable."""
    reshape_operation = mtf.ReshapeOperation(
        self.x, mtf.Shape([mtf.Dimension("x", 25), mtf.Dimension("y", 2)]))
    self.assertEqual(reshape_operation.splittable_dims,
                     frozenset(["a", "b", "x", "y"]))
    self.assertEqual(reshape_operation.unsplittable_dims, frozenset())

  def testRandomOperation(self):
    """RandomOperation can be split along every dimension."""
    random_operation = mtf.RandomOperation(self.mesh, self.ab_shape,
                                           tf.random_uniform)
    self.assertEqual(random_operation.splittable_dims, frozenset(["a", "b"]))
    self.assertEqual(random_operation.unsplittable_dims, frozenset())

  def testWhileLoopOperation(self):
    """A while loop over a scalar counter and self.x splits a and b."""
    # This test case implements the following:
    # for i in range(10):
    #   x = x * 2
    i = mtf.constant(self.mesh, 0, mtf.Shape([]))
    cond_fn = lambda i, x: mtf.less(i, 10)
    body_fn = lambda i, x: [mtf.add(i, 1), mtf.multiply(x, 2)]

    while_loop_operation = mtf.WhileLoopOperation(cond_fn, body_fn, [i, self.x])
    self.assertEqual(while_loop_operation.splittable_dims,
                     frozenset(["a", "b"]))
    self.assertEqual(while_loop_operation.unsplittable_dims, frozenset())


class NthSmallestTest(tf.test.TestCase):
  """Tests for mtf.nth_largest_element and mtf.nth_smallest_element."""

  # Shared 6x2 input: column 0 ascends 1..6, column 1 descends 10..5.
  _INPUT_ROWS = [[1, 10], [2, 9], [3, 8], [4, 7], [5, 6], [6, 5]]

  def _lower_and_compare(self, graph, mesh, mtf_outputs, expected_outputs):
    """Lowers `graph` onto a 2-device placement mesh and checks the result."""
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="all:2", layout="a:all", devices=["", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
    self.assertAllEqual(self.evaluate(actual_outputs),
                        self.evaluate(expected_outputs))

  def testNthLargest(self):
    """n=1 along "a" selects each column's second-largest entry."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    tf_inputs = tf.constant(self._INPUT_ROWS)
    # n is zero-indexed, so n=1 picks the second largest: 5 and 9.
    expected = tf.constant([5, 9])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, tf_inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_largest_element(
        mtf_inputs, 1, a_dim, "test_nth_largest")
    self._lower_and_compare(graph, mesh, mtf_outputs, expected)

  def testNthSmallestReduceSecondDim(self):
    """n=0 along "b" selects each row's smallest entry."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    tf_inputs = tf.constant(self._INPUT_ROWS)
    # n is zero-indexed, so n=0 picks the minimum of every row.
    expected = tf.constant([1, 2, 3, 4, 5, 5])

    mtf_inputs = mtf.import_tf_tensor(
        mesh, tf_inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_outputs = mtf.nth_smallest_element(
        mtf_inputs, 0, b_dim, "test_nth_smallest")
    self._lower_and_compare(graph, mesh, mtf_outputs, expected)


class TopKTest(tf.test.TestCase):
  """Tests mtf.top_k values, indices and gradients on a 2x2 placement mesh."""

  def testTopK(self):
    """Checks top-2 along "a" per column, plus the routed gradient."""
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    a_dim = mtf.Dimension("a", 6)
    b_dim = mtf.Dimension("b", 2)
    inputs = tf.constant([[1, 10],
                          [2, 9],
                          [3, 8],
                          [4, 7],
                          [5, 6],
                          [6, 5]],
                         dtype=tf.float32)
    k_dim = mtf.Dimension("k", 2)
    # Upstream gradient for the top-k values, laid out as [b, k].
    d_values = tf.constant([[11, 12], [13, 14]], dtype=tf.float32)
    reduced_dim = a_dim
    # Column 0's two largest are 6 and 5 (rows 5 and 4); column 1's are 10
    # and 9 (rows 0 and 1). Results are laid out as [b, k].
    expected_values = tf.constant([[6, 5], [10, 9]], dtype=tf.float32)
    expected_indices = tf.constant([[5, 4], [0, 1]])
    # Each d_values entry flows back to the input row its value came from;
    # all other input positions receive zero gradient.
    expected_d_inputs = tf.constant([[0, 13],
                                     [0, 14],
                                     [0, 0],
                                     [0, 0],
                                     [12, 0],
                                     [11, 0]],
                                    dtype=tf.float32)

    mtf_inputs = mtf.import_fully_replicated(
        mesh, inputs, shape=mtf.Shape([a_dim, b_dim]))
    mtf_d_values = mtf.import_tf_tensor(
        mesh, d_values, shape=mtf.Shape([b_dim, k_dim]))
    # Named after this test (previously a copy-pasted "test_nth_smallest").
    mtf_values, mtf_indices = mtf.top_k(mtf_inputs,
                                        reduced_dim=reduced_dim,
                                        k_dim=k_dim,
                                        name="test_top_k")
    [mtf_d_inputs] = mtf.gradients([mtf_values], [mtf_inputs], [mtf_d_values])
    mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
        shape="rows:2,cols:2", layout="a:rows,b:cols", devices=["", "", "", ""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    actual_values = lowering.export_to_tf_tensor(mtf_values)
    actual_indices = lowering.export_to_tf_tensor(mtf_indices)
    actual_d_inputs = lowering.export_to_tf_tensor(mtf_d_inputs)
    actual_inputs = lowering.export_to_tf_tensor(mtf_inputs)
    # The fully replicated import must round-trip unchanged.
    self.assertAllEqual(self.evaluate(actual_inputs),
                        self.evaluate(inputs))
    self.assertAllEqual(self.evaluate(actual_values),
                        self.evaluate(expected_values))
    self.assertAllEqual(self.evaluate(actual_indices),
                        self.evaluate(expected_indices))
    self.assertAllEqual(self.evaluate(actual_d_inputs),
                        self.evaluate(expected_d_inputs))


class RecomputeGradTest(tf.test.TestCase):
  """Tests mtf.recompute_grad on f(x) = x^2 + x, whose derivative is 2x + 1."""

  def testRecomputeGrad(self):
    """Forward values and recomputed gradients match hand-computed results."""

    def polynomial(t):
      return t * t + t

    two = mtf.Dimension("two", 2)
    tf_x = tf.constant([5, 10], dtype=tf.float32)
    tf_dy = tf.constant([2, 3], dtype=tf.float32)
    # y = x^2 + x = [25 + 5, 100 + 10].
    want_y = tf.constant([30, 110], dtype=tf.float32)
    # dx = dy * (2x + 1) = [2 * 11, 3 * 21].
    want_dx = tf.constant([22, 63], dtype=tf.float32)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    mtf_x = mtf.import_fully_replicated(mesh, tf_x, shape=mtf.Shape([two]))
    mtf_dy = mtf.import_tf_tensor(mesh, tf_dy, shape=mtf.Shape([two]))
    mtf_y = mtf.recompute_grad(polynomial, [mtf_x])
    [mtf_dx] = mtf.gradients([mtf_y], [mtf_x], [mtf_dy])

    lowering = mtf.Lowering(graph, {
        mesh: mtf.placement_mesh_impl.PlacementMeshImpl(
            shape="processors:2", layout="two:processors", devices=["", ""]),
    })
    got_y = lowering.export_to_tf_tensor(mtf_y)
    got_dx = lowering.export_to_tf_tensor(mtf_dx)
    self.assertAllEqual(self.evaluate(got_y), self.evaluate(want_y))
    self.assertAllEqual(self.evaluate(got_dx), self.evaluate(want_dx))


if __name__ == "__main__":
  # Switch to TF1 (compat.v1) semantics, then re-enable eager execution.
  # NOTE(review): the order appears deliberate -- presumably
  # disable_v2_behavior() also turns eager off, so it must run first; confirm.
  tf.disable_v2_behavior()
  tf.enable_eager_execution()
  tf.test.main()