import ast

from ward.tests.utilities import testable_test
from ward import test, fixture
from ward.rewrite import (
    rewrite_assertions_in_tests,
    RewriteAssert,
    get_assertion_msg,
    make_call_node,
    is_binary_comparison,
    is_comparison_type,
)
from ward.testing import Test, each


# Sample test function whose single assertion holds (1 == 1).
# It is wrapped into a `Test` by the `passing` fixture below.
@testable_test
def passing_fn():
    assert 1 == 1


# Sample test function whose single assertion does not hold (1 == 2).
# It is wrapped into a `Test` by the `failing` fixture below.
@testable_test
def failing_fn():
    assert 1 == 2


@fixture
def passing():
    # Yields a Test built around `passing_fn`, with fixed module name and id.
    passing_test = Test(fn=passing_fn, module_name="m", id="id-pass")
    yield passing_test


@fixture
def failing():
    # Yields a Test built around `failing_fn`, with fixed module name and id.
    failing_test = Test(fn=failing_fn, module_name="m", id="id-fail")
    yield failing_test


@test("rewrite_assertions_in_tests returns all tests, keeping metadata")
def _(p=passing, f=failing):
    # Collects the metadata fields that are compared before and after rewriting.
    def describe(t):
        return t.description, t.id, t.module_name, t.fn.ward_meta

    originals = [p, f]
    rewritten = rewrite_assertions_in_tests(originals)

    # Same number of tests, same order, same metadata on each.
    before = [describe(t) for t in originals]
    after = [describe(t) for t in rewritten]
    assert before == after


@test("RewriteAssert.visit_Assert doesn't transform `{src}`")
def _(
    src=each(
        "assert x",
        "assert f(x)",
        "assert x + y + z",
        "assert 1 < 2 < 3",
        "assert 1 == 1 == 3",
        "print(x)",
        "yield",
    )
):
    # Parse the single statement in `src` and run it through the transformer;
    # the visitor is expected to hand back the node it was given.
    module = ast.parse(src)
    statement = module.body[0]
    visited = RewriteAssert().visit(statement)
    assert statement == visited


@test("RewriteAssert.visit_Assert transforms `{src}` correctly")
def _(
    src=each(
        "assert x == y",
        "assert x != y",
        "assert x in y",
        "assert x not in y",
        "assert x is y",
        "assert x is not y",
        "assert x < y",
        "assert x <= y",
        "assert x > y",
        "assert x >= y",
    ),
    fn=each(
        "assert_equal",
        "assert_not_equal",
        "assert_in",
        "assert_not_in",
        "assert_is",
        "assert_is_not",
        "assert_less_than",
        "assert_less_than_equal_to",
        "assert_greater_than",
        "assert_greater_than_equal_to",
    ),
):
    # Each `src` entry lines up with the `fn` entry at the same position:
    # the comparison in `src` should be rewritten into a call to `fn`.
    original = ast.parse(src).body[0]
    rewritten = RewriteAssert().visit(original)

    # The replacement statement sits at the original statement's position.
    assert rewritten.lineno == original.lineno
    assert rewritten.col_offset == original.col_offset

    # So does the call expression it wraps.
    call = rewritten.value
    assert call.lineno == original.lineno
    assert call.col_offset == original.col_offset

    # The call is `fn(x, y, "")`: both operands, then an empty message.
    assert call.func.id == fn
    assert call.args[0].id == "x"
    assert call.args[1].id == "y"
    assert call.args[2].s == ""


@test("RewriteAssert.visit_Assert transforms `{src}`")
def _(src="assert 1 == 2, 'msg'"):
    # The assertion message should end up as the third argument of the call.
    rewritten = RewriteAssert().visit(ast.parse(src).body[0])
    message_arg = rewritten.value.args[2]
    assert message_arg.s == "msg"


@test("get_assertion_msg({src}) returns '{msg}'")
def _(
    src=each("assert 1 == 2, 'msg'", "assert 1 == 2", "assert 1 == 2, 1"),
    msg=each("msg", "", ""),
):
    # Cases, in order: a string message, no message at all, and a non-string
    # message; the latter two are expected to produce an empty string.
    assert_node = ast.parse(src).body[0]
    assert get_assertion_msg(assert_node) == msg


@test("make_call_node converts `{src}` to correct function call node")
def _(
    src=each(
        "assert x == y",
        "assert x == y, 'message'",
        "assert x < y",
        "assert x in y",
        "assert x is y",
        "assert x is not y",
    ),
    func="my_assert",
):
    assert_node = ast.parse(src).body[0]
    call = make_call_node(assert_node, func)

    # check that `assert x OP y[, msg]` becomes `my_assert(x, y, msg)`,
    # where msg is '' when the assert statement carries no message
    lhs = assert_node.test.left.id
    rhs = assert_node.test.comparators[0].id
    msg = assert_node.msg.s if assert_node.msg else ""

    # The generated call must target the function name that was requested.
    assert call.value.func.id == func
    assert call.value.args[0].id == lhs
    assert call.value.args[1].id == rhs
    assert call.value.args[2].s == msg


@test("is_binary_comparison returns True for assert binary comparisons")
def _(src=each("assert x == y", "assert x is y", "assert x < y", "assert x is not y")):
    # Every source here is an assert whose test compares exactly two operands.
    statement = ast.parse(src).body[0]
    result = is_binary_comparison(statement)
    assert result


@test("is_binary_comparison('{src}') is False")
def _(src=each("assert True", "assert x < y < z", "assert not False")):
    # Covers a bare constant, a chained comparison and a unary `not`.
    statement = ast.parse(src).body[0]
    result = is_binary_comparison(statement)
    assert not result


@test("is_comparison_type returns True if node is of given type")
def _(
    src=each("assert x == y", "assert x is y", "assert x < y", "assert x is not y"),
    node_type=each(ast.Eq, ast.Is, ast.Lt, ast.IsNot),
):
    # Each `src` is lined up with the ast operator class its comparison uses.
    statement = ast.parse(src).body[0]
    matches = is_comparison_type(statement, node_type)
    assert matches


@test("is_comparison_type returns False if node is not of given type")
def _(
    src=each("assert x == y", "assert x is y", "assert x < y", "assert x is not y"),
    node_type=each(ast.Add, ast.Add, ast.Add, ast.Add),
):
    # `ast.Add` is not a comparison operator, so none of the sources match it.
    statement = ast.parse(src).body[0]
    matches = is_comparison_type(statement, node_type)
    assert not matches


# The tests below are deliberately declared inside `if True:` blocks so that
# their source is indented one and two levels deep.
# NOTE(review): presumably this checks that assertion rewriting copes with
# indented test source — confirm against ward.rewrite.
if True:

    @test("test with indentation level of 1")
    def _():
        assert 1 + 2 == 3

    if True:

        @test("test with indentation level of 2")
        def _():
            assert 2 + 3 == 5