# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for converting TensorFlow control flow op to Relay."""
import pytest
try:
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
except ImportError:
    import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np
from tvm import nd
from tvm import relay
from tvm.relay.frontend.tensorflow import from_tensorflow


def check_equal(graph, tf_out, input_map=None):
    mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
    if input_map is not None:
        params.update(input_map)
    ex = relay.create_executor('vm', mod=mod)
    relay_out = ex.evaluate()(**params)
    if isinstance(relay_out, nd.NDArray):
        np.testing.assert_allclose(tf_out, relay_out.asnumpy())
    else:
        if not isinstance(tf_out, (list, tuple)):
            tf_out = [tf_out]
        for x, y in zip(tf_out, [r.asnumpy() for r in relay_out]):
            np.testing.assert_allclose(x, y)


def test_vanilla_loop():
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(0, name="while/constant")

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            tf_out = sess.run(r)

        check_equal(graph, tf_out)


def test_callnode_loop_vars():
    graph = tf.Graph()
    with graph.as_default():
        i = tf.add(tf.constant(0), 1)

        def c(i): return tf.less(i, 10)

        def b(i): return tf.add(i, 1)

        r = tf.while_loop(c, b, [i])

        with tf.Session() as sess:
            tf_out = sess.run(r)

        check_equal(graph, tf_out)


def test_loop_2_vars():
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(0)
        j0 = tf.ones([2, 2])

        def c(i, j): return i < 10

        def b(i, j): return [tf.add(i, 1), j]

        i1, i2 = tf.while_loop(c, b, loop_vars=[i0, j0])
        i1 += tf.constant(1337)

        with tf.Session() as sess:
            tf_out = sess.run(i1)

    check_equal(graph, tf_out)


def test_loop_3_vars():
    graph = tf.Graph()
    with graph.as_default():
        i0 = tf.constant(1)
        j0 = tf.constant(2)
        k0 = tf.constant(4)

        def c(i, j, k): return i < 10

        def b(i, j, k): return [i+1, j * k, k + i]
        r = tf.while_loop(c, b, loop_vars=[i0, j0, k0])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)


def test_loop_conditions():
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(1)
        k = tf.constant(5)

        def c(i, j, k): return \
            tf.equal(tf.not_equal(tf.less(i + j, 10),
                                  tf.less(j * k, 100)),
                     tf.greater_equal(k, i + j))

        def b(i, j, k): return [i+j, j+k, k+1]
        r = tf.while_loop(c, b, loop_vars=[i, j, k])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)


@pytest.mark.skip
def test_loop_bodies():
    graph = tf.Graph()
    with graph.as_default():
        def body(x):
            a = tf.constant(np.array([[5, 6], [7, 8]]), dtype=tf.int32)
            b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
            c = a + b
            return tf.nn.relu(x + c)

        def condition(x):
            return tf.reduce_sum(x) < 100
        x = tf.constant(0, shape=[2, 2])
        r = tf.while_loop(condition, body, [x])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)


def test_nested_loop():
    graph = tf.Graph()
    with graph.as_default():

        def body(x):
            def nest_body(c):
                return tf.multiply(c, 2)
            def cd(c): return tf.less(c, 10)
            c = tf.constant(2)
            res = tf.while_loop(cd, nest_body, loop_vars=[c])
            return tf.nn.relu(x + res)

        def condition(x):
            return tf.greater(x, 100)
        x = tf.constant(3)
        r = tf.while_loop(condition, body, loop_vars=[x])

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)


def test_vanilla_cond():
    graph = tf.Graph()
    with graph.as_default():
        i = tf.constant(1)
        j = tf.constant(4)

        def f1():
            return tf.multiply(1, 17)

        def f2():
            return tf.add(4, 23)
        r = tf.cond(tf.less(i, j), f1, f2)

    with tf.Session(graph=graph) as sess:
        tf_out = sess.run(r)

    check_equal(graph, tf_out)


def test_multiple_cond_vars():
    graph = tf.Graph()
    with graph.as_default():
        x1 = tf.constant(7)
        x2 = tf.constant(12)
        z = tf.constant(20)
        r = tf.cond(tf.less(tf.add(x1, x2), 10),
                    lambda: tf.add(10, 2), lambda: tf.square(5))

        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)


def test_cond_fn_parameters():
    graph = tf.Graph()
    with graph.as_default():
        def fn1(x, y):
            return tf.multiply(5, 6)

        def fn2(x, y):
            return tf.add(3, 4)

        i = tf.constant(1)
        j = tf.constant(2)
        k = tf.constant(3)
        r = tf.cond(tf.less(i, j), lambda: fn1(i, k), lambda: fn2(j, k))

        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={i: 1, j: 2, k: 3})

    check_equal(graph, tf_out)


def test_nested_cond():
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            def nest_fn1():
                return tf.add(1, 2)

            def nest_fn2():
                return tf.subtract(10, 5)

            res = tf.cond(tf.less(1, 2), nest_fn1, nest_fn2)
            return tf.multiply(tf.add(87, res), 10)

        def fn2(a, b):
            return tf.add(10, 10)

        x = tf.constant(5)
        y = tf.constant(6)
        z = tf.constant(7)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)


def test_loop_in_cond():
    graph = tf.Graph()
    with graph.as_default():
        def fn1(a, b):
            i = tf.constant(0)

            def cd(i): return tf.less(i, 10)

            def bd(i): return tf.add(i, 1)
            res = tf.while_loop(cd, bd, [i])
            return tf.multiply(tf.add(20, res), 10)

        def fn2(a, b):
            return tf.add(10, 20)

        x = tf.constant(7)
        y = tf.constant(20)
        z = tf.constant(10)
        pred = tf.less(x, y)
        r = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={x: 1, y: 2, z: 3, pred: True})

    check_equal(graph, tf_out)


def test_cond_in_loop():
    graph = tf.Graph()
    with graph.as_default():
        def body(x):
            x = tf.constant(7)
            z = tf.constant(20)
            res = tf.cond(tf.less(x, 10), lambda: tf.add(
                10, 20), lambda: tf.square(10))
            return tf.multiply(res, x)

        x = tf.constant(21)
        def condition(x):
            return tf.less(x, 100)

        r = tf.while_loop(condition, body, loop_vars=[x])
        with tf.Session() as sess:
            tf_out = sess.run(r)

    check_equal(graph, tf_out)

def test_vanilla_loop_bound():
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        x = tf.slice(data, [1, 4], [1, 4])
        outer = x + 5.0
        def body(x, y):
            res = tf.cond(tf.less(y, 10), lambda: tf.add(
                10.0, 20.0), lambda: tf.square(10.0))
            z = tf.constant(7)
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            return tf.multiply(res, x * outer), y + 1

        y = tf.constant(0)
        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})

    check_equal(graph, tf_out, {dname: np_data})

def test_nested_loop_bound():
    graph = tf.Graph()
    with graph.as_default():
        dshape = (2, 10)
        dtype = "float32"
        dname = "data"
        np_data = np.random.uniform(size=dshape).astype(dtype)
        data = tf.placeholder(shape=dshape, dtype=dtype, name=dname)
        x = tf.slice(data, [1, 4], [1, 4])
        outer = x + 5.0
        def body(x, y):
            res = tf.cond(tf.less(y, 10), lambda: tf.add(
                10.0, 20.0), lambda: tf.square(10.0))
            def nested_body(nx, ny):
                return nx + 1, res + 2.0
            def nested_cond(nx, ny):
                return tf.less(nx, 15)
            nx = tf.constant(0)
            ny = tf.constant(0.0)
            nested_res = tf.while_loop(nested_cond, nested_body, loop_vars=[nx, ny])
            res = res + nested_res[1]
            z = tf.constant(7)
            res = tf.cond(tf.less(z, 10), lambda: res * 5, lambda: res + 10)
            return tf.multiply(res, x * outer), y + 1

        y = tf.constant(0)
        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[x, y])
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={"%s:0" % dname: np_data})

    check_equal(graph, tf_out, {dname: np_data})

def test_switch():
    graph = tf.Graph()

    with graph.as_default():
        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
        dname = 'data'
        flag_name = 'flag'
        data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
        split = tf.split(data, 2, axis=0)
        flag = tf.placeholder(shape={}, dtype=tf.bool, name=flag_name)
        output_false, output_true = control_flow_ops.switch(split[1], flag)
        with tf.Session() as sess:
            tf_out = sess.run(output_false, feed_dict={data.name: data_np, flag.name: False})

    check_equal(graph, tf_out, {dname: data_np, flag_name: False})

def test_loop_tuple_input():
    graph = tf.Graph()

    with graph.as_default():
        data_np = np.random.uniform(0, 5, size=(2, 4, 5, 1)).astype('float32')
        dname = 'data'
        data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name=dname)
        split = tf.split(data, 2, axis=0)

        def body(x, y):
            return x + 2, y + 1

        start = tf.constant(0)
        def condition(x, y):
            return tf.less(y, 20)

        r = tf.while_loop(condition, body, loop_vars=[split[1], start])
        with tf.Session() as sess:
            tf_out = sess.run(r, feed_dict={data.name: data_np})

    check_equal(graph, tf_out, {dname: data_np})


if __name__ == "__main__":
    # tf.while_loop
    test_vanilla_loop()
    test_loop_2_vars()
    test_loop_3_vars()
    test_loop_conditions()
    # TODO(@jroesch): Need to fix memory alloc to support closure
    # test_loop_bodies()
    test_callnode_loop_vars()

    # tf.cond
    test_vanilla_cond()
    test_multiple_cond_vars()
    test_cond_fn_parameters()

    # nested cases
    test_nested_loop()
    test_nested_cond()
    test_loop_in_cond()
    test_cond_in_loop()
    test_vanilla_loop_bound()
    test_nested_loop_bound()

    test_switch()
    test_loop_tuple_input()