# -*- coding: utf-8 -*-
import inspect
import itertools

import pytest
from multiset import Multiset

from matchpy.expressions.expressions import (Arity, Operation, Symbol, SymbolWildcard, Wildcard, Expression)
from .common import *

# Sample expressions covering plain symbols, an operation, an operation with a
# variable name attached, and several wildcard forms. Used to parametrize
# ``TestExpression.test_hash`` (pairwise) and ``TestExpression.test_copy``.
SIMPLE_EXPRESSIONS = [a, b, f(a, b), x_, ___, f(_, variable_name='x'), s_, _s]


class SpecialF(f):
    """Subclass of ``f`` whose only difference is its operation name.

    ``TestExpression.test_lt`` relies on it to check that ``f(a) < SpecialF(a)``,
    i.e. that two operations with identical operands are still strictly
    ordered (presumably by name -- confirm against ``Operation.__lt__``).
    """

    name = 'special'


class TestExpression:
    """Tests for the core behaviour of matchpy expressions.

    Covers construction-time simplification and operand-count validation,
    structural properties (``is_constant``, ``is_syntactic``, ``symbols``,
    ``variables``), pre-order traversal, positional indexing and slicing,
    ordering, hashing, copying, variable renaming and containment.
    """

    @staticmethod
    def _is_exception_class(value):
        """Return True if *value* is an exception class.

        Several parametrized tables use an exception class in place of an
        expected value to mean "this call must raise that exception".
        """
        return inspect.isclass(value) and issubclass(value, Exception)

    @pytest.mark.parametrize(
        '   expression,                                                         simplified',
        [
            (f_i(a),                                                            a),
            (f_i(a, b),                                                         f_i(a, b)),
            (f_i(_),                                                            _),
            (f_i(___),                                                          f_i(___)),
            (f_i(__),                                                           f_i(__)),
            (f_i(x_),                                                           x_),
            (f_i(x___),                                                         f_i(x___)),
            (f_i(x__),                                                          f_i(x__)),
            (f_a(f_a(a)),                                                       f_a(a)),
            (f_a(f_a(a, b)),                                                    f_a(a, b)),
            (f_a(a, f_a(b)),                                                    f_a(a, b)),
            (f_a(f_a(a), b),                                                    f_a(a, b)),
            (f_a(f(a)),                                                         f_a(f(a))),
            (f_c(a, b),                                                         f_c(a, b)),
            (f_c(b, a),                                                         f_c(a, b)),
        ]
    )  # yapf: disable
    def test_operation_simplify(self, expression, simplified):
        """Check the simplifications applied when an operation is constructed.

        Per the table: ``f_i`` with a single non-sequence operand collapses to
        that operand, nested ``f_a`` calls are flattened, and ``f_c`` operands
        are put into a canonical order.
        """
        assert expression == simplified

    @pytest.mark.parametrize(
        '   operation,                                              operands,       expected_error',
        [
            (Operation.new('f', Arity.unary),                       [],             ValueError),
            (Operation.new('f', Arity.unary),                       [a, b],         ValueError),
            (Operation.new('f', Arity.variadic),                    [],             None),
            (Operation.new('f', Arity.variadic),                    [a],            None),
            (Operation.new('f', Arity.variadic),                    [a, b],         None),
            (Operation.new('f', Arity.binary, associative=True),    [a, a, b],      ValueError),
            (Operation.new('f', Arity.binary),                      [x_, x___],     None),
            (Operation.new('f', Arity.binary),                      [x_, x__],      None),
            (Operation.new('f', Arity.binary),                      [x_, x_, x__],  ValueError),
            (Operation.new('f', Arity.binary),                      [x_, x_, x___], None),
            (Operation.new('f', Arity.binary),                      [x_, x_],       None),
            (Operation.new('f', Arity.binary),                      [x_, x_, x_],   ValueError),
        ]
    )  # yapf: disable
    def test_operation_errors(self, operation, operands, expected_error):
        """Check operand-count validation against the operation's arity.

        ``expected_error`` is the exception the call must raise, or ``None``
        if construction must succeed.
        """
        if expected_error is not None:
            with pytest.raises(expected_error):
                operation(*operands)
        else:
            _ = operation(*operands)

    @pytest.mark.parametrize(
        '   expression,     is_constant',
        [
            (a,             True),
            (x_,            False),
            (_,             False),
            (f(a),          True),
            (f(a, b),       True),
            (f(x_),         False),
        ]
    )  # yapf: disable
    def test_is_constant(self, expression, is_constant):
        """Check ``is_constant``: in these cases, wildcard-free expressions are constant."""
        assert expression.is_constant == is_constant

    @pytest.mark.parametrize(
        '   expression,     is_syntactic',
        [
            (a,             True),
            (x_,            True),
            (_,             True),
            (x___,          False),
            (___,           False),
            (x__,           False),
            (__,            False),
            (f(a),          True),
            (f(a, b),       True),
            (f(x_),         True),
            (f(x__),        False),
            (f_a(a),        False),
            (f_a(a, b),     False),
            (f_a(x_),       False),
            (f_a(x__),      False),
            (f_c(a),        False),
            (f_c(a, b),     False),
            (f_c(x_),       False),
            (f_c(x__),      False),
            (f_ac(a),       False),
            (f_ac(a, b),    False),
            (f_ac(x_),      False),
            (f_ac(x__),     False),
        ]
    )  # yapf: disable
    def test_is_syntactic(self, expression, is_syntactic):
        """Check ``is_syntactic``.

        Per the table, sequence wildcards and the ``f_a`` / ``f_c`` / ``f_ac``
        operations make an expression non-syntactic; symbols, single
        wildcards and plain ``f`` over them do not.
        """
        assert expression.is_syntactic == is_syntactic

    @pytest.mark.parametrize(
        '   expression,         symbols',
        [
            (a,                 ['a']),
            (x_,                []),
            (_,                 []),
            (f(a),              ['a', 'f']),
            (f(a, b),           ['a', 'b', 'f']),
            (f(x_),             ['f']),
            (f(a, a),           ['a', 'a', 'f']),
            (f(f(a), f(b, c)),  ['a', 'b', 'c', 'f', 'f', 'f']),
        ]
    )  # yapf: disable
    def test_symbols(self, expression, symbols):
        """Check that ``symbols`` is the multiset of symbol and operation names (with multiplicity)."""
        assert expression.symbols == Multiset(symbols)

    @pytest.mark.parametrize(
        '   expression,                 variables',
        [
            (a,                         []),
            (x_,                        ['x']),
            (_,                         []),
            (f(a),                      []),
            (f(x_),                     ['x']),
            (f(x_, x_),                 ['x', 'x']),
            (f(x_, a),                  ['x']),
            (f(x_, a, y_),              ['x', 'y']),
            (f(f(x_), f(b, x_)),        ['x', 'x']),
            (f(a, variable_name='x'),        ['x']),
            (f(f(y_), variable_name='x'),    ['x', 'y']),
        ]
    )  # yapf: disable
    def test_variables(self, expression, variables):
        """Check that ``variables`` is the multiset of variable names, including ``variable_name`` on operations."""
        assert expression.variables == Multiset(variables)

    @pytest.mark.parametrize(
        '   expression,     predicate,                  preorder_list',
        [                                               # expression        position
            (f(a, x_),      None,                       [(f(a, x_),         ()),
                                                         (a,                (0, )),
                                                         (x_,               (1, ))]),
            (f(a, f(x_)),   lambda e: e.head is None,   [(x_,               (1, 0))]),
            (f(a, f(x_)),   lambda e: e.head == f,      [(f(a, f(x_)),      ()),
                                                         (f(x_),            (1, ))])
        ]
    )  # yapf: disable
    def test_preorder_iter(self, expression, predicate, preorder_list):
        """Check that ``preorder_iter`` yields ``(subexpression, position)`` pairs in pre-order, filtered by *predicate*."""
        result = list(expression.preorder_iter(predicate))
        assert result == preorder_list

    GETITEM_TEST_EXPRESSION = f(a, f(x_, b), _)

    @pytest.mark.parametrize(
        '   position,       expected_result',
        [
            ((),            GETITEM_TEST_EXPRESSION),
            ((0, ),         a),
            ((0, 0),        IndexError),
            ((1, ),         f(x_, b)),
            ((1, 0),        x_),
            ((1, 0, 0),     IndexError),
            ((1, 1),        b),
            ((1, 1, 0),     IndexError),
            ((1, 2),        IndexError),
            ((2, ),         _),
            ((3, ),         IndexError),
        ]
    )  # yapf: disable
    def test_getitem(self, position, expected_result):
        """Check indexing by position tuple; invalid positions must raise ``IndexError``."""
        if self._is_exception_class(expected_result):
            with pytest.raises(expected_result):
                _ = self.GETITEM_TEST_EXPRESSION[position]
        else:
            assert self.GETITEM_TEST_EXPRESSION[position] == expected_result

    @pytest.mark.parametrize(
        '   start,          end,    expected_result',
        [
            ((),            (),     [GETITEM_TEST_EXPRESSION]),
            ((0, ),         (0, ),  [a]),
            ((0, ),         (1, ),  [a, f(x_, b)]),
            ((0, ),         (2, ),  [a, f(x_, b), _]),
            ((0, ),         (3, ),  [a, f(x_, b), _]),
            ((1, ),         (2, ),  [f(x_, b), _]),
            ((1, 0),        (1, 1), [x_, b]),
            ((1, 0),        (2, ),  IndexError),
            ((1, ),         (0, ),  IndexError),
            ((1, 0),        (2, 0), IndexError),
        ]
    )  # yapf: disable
    def test_getitem_slice(self, start, end, expected_result):
        """Check slicing between two sibling positions.

        Per the table, the end position is inclusive and may overshoot the
        last sibling; positions with different parents, or a start after the
        end, raise ``IndexError``.
        """
        if self._is_exception_class(expected_result):
            with pytest.raises(expected_result):
                _ = self.GETITEM_TEST_EXPRESSION[start:end]
        else:
            assert self.GETITEM_TEST_EXPRESSION[start:end] == expected_result

    def test_getitem_slice_symbol(self):
        """A symbol only supports the empty slice ``():()``, which yields ``[symbol]``."""
        with pytest.raises(IndexError):
            _ = a[(0, ):()]
        with pytest.raises(IndexError):
            _ = a[(0, ):(1, )]
        assert a[():()] == [a]

    @pytest.mark.parametrize(
        '   expression1,                    expression2',
        [
            (a,                             b),
            (a,                             Symbol('a', variable_name='x')),
            (Symbol('a', variable_name='x'),     Symbol('a', variable_name='y')),
            (a,                             _),
            (a,                             _s),
            (a,                             x_),
            (_,                             x_),
            (_s,                            x_),
            (x_,                            y_),
            (x_,                            x__),
            (f(a),                          f(b)),
            (f(a),                          f2(a)),
            (f(a),                          f(a, a)),
            (f(b),                          f(a, a)),
            (f(a, a),                       f(a, b)),
            (f(a, a),                       f(a, a, a)),
            (a,                             f(a)),
            (x_,                            f(a)),
            (_,                             f(a)),
            (_s,                            f(a)),
            (_s,                            s_),
            (SymbolWildcard(variable_name='x'),  SymbolWildcard(variable_name='y')),
            (s_,                            ss_),
            (_s,                            __),
            (_,                             _s),
            (SymbolWildcard(SpecialSymbol), SymbolWildcard(Symbol)),
            (f(a),                          SpecialF(a)),
        ]
    )  # yapf: disable
    def test_lt(self, expression1, expression2):
        """Check that ``<`` is a consistent strict order: ``e1 < e2`` holds and ``e2 < e1`` does not."""
        assert expression1 < expression2, "{!s} < {!s} did not hold".format(expression1, expression2)
        assert not (expression2 < expression1), (
            "Inconsistent order: Both {0} < {1} and {1} < {0}".format(expression2, expression1)
        )

    @pytest.mark.parametrize('expression', [a, f(a), x_, _])
    def test_lt_error(self, expression):
        """Comparing an expression with an arbitrary object must raise ``TypeError``."""
        with pytest.raises(TypeError):
            _ = expression < object()

    def test_operation_new_error(self):
        """``Operation.new`` must reject the names ``'if'`` and ``'+'`` with ``ValueError``."""
        with pytest.raises(ValueError):
            _ = Operation.new('if', Arity.variadic)

        with pytest.raises(ValueError):
            _ = Operation.new('+', Arity.variadic)

    def test_wildcard_error(self):
        """Invalid ``Wildcard`` constructor arguments must raise ``ValueError``.

        Assumes the positional arguments are ``(min_count, fixed_size)`` --
        i.e. a negative minimum, and a fixed-size wildcard with minimum 0.
        """
        with pytest.raises(ValueError):
            _ = Wildcard(-1, False)

        with pytest.raises(ValueError):
            _ = Wildcard(0, True)

    def test_symbol_wildcard_error(self):
        """``SymbolWildcard`` must reject a non-symbol type such as ``object`` with ``TypeError``."""
        with pytest.raises(TypeError):
            _ = SymbolWildcard(object)

    @pytest.mark.parametrize(
        '   expression,                         renaming,       expected_result',
        [
            (a,                                 {},             a),
            (a,                                 {'x': 'y'},     a),
            (x_,                                {},             x_),
            (x_,                                {'x': 'y'},     y_),
            (SymbolWildcard(),                  {},             SymbolWildcard()),
            (SymbolWildcard(),                  {'x': 'y'},     SymbolWildcard()),
            (f(x_),                             {},             f(x_)),
            (f(x_),                             {'x': 'y'},     f(y_)),
        ]
    )  # yapf: disable
    def test_with_renamed_vars(self, expression, renaming, expected_result):
        """Check that ``with_renamed_vars`` renames variables per the mapping and leaves the rest unchanged."""
        new_expr = expression.with_renamed_vars(renaming)
        assert new_expr == expected_result

    @pytest.mark.parametrize('expression', SIMPLE_EXPRESSIONS)
    @pytest.mark.parametrize('other', SIMPLE_EXPRESSIONS)
    def test_hash(self, expression, other):
        """Check hashing over every pair of sample expressions.

        Equal expressions must hash equally. For this small sample, unequal
        expressions are also required to hash differently -- stricter than
        Python's hash contract, but a useful guard against degenerate hashes.
        """
        if expression != other:
            assert hash(expression) != hash(other), "hash({!s}) == hash({!s})".format(expression, other)
        else:
            assert hash(expression) == hash(other), "hash({!s}) != hash({!s})".format(expression, other)

    @pytest.mark.parametrize('expression', SIMPLE_EXPRESSIONS)
    def test_copy(self, expression):
        """``__copy__`` must return an equal but distinct object."""
        other = expression.__copy__()
        assert other == expression
        assert other is not expression

    @pytest.mark.parametrize(
        '   expression,     subexpression,  contains',
        [
            (a,             a,              True),
            (a,             b,              False),
            (f(a),          a,              True),
            (f(a),          b,              False),
            (f(a),          f(a),           True),
            (f(a, b),       f(a),           False),
            (f(a),          f(a, b),        False),
            (f(x_, y_),     x_,             True),
            (f(x_, y_),     y_,             True),
            (f(x_, y_),     a,              False),
        ]
    )  # yapf: disable
    def test_contains(self, expression, subexpression, contains):
        """Check ``in``: an expression contains itself and each of its subexpressions."""
        if contains:
            assert subexpression in expression, "{!s} should be contained in {!s}".format(subexpression, expression)
        else:
            assert subexpression not in expression, "{!s} should not be contained in {!s}".format(subexpression, expression)


class TestOperation:
    def test_one_identity_error(self):
        """Fixed-arity operations (unary, binary) must reject ``one_identity=True`` with TypeError."""
        for fixed_arity in (Arity.unary, Arity.binary):
            with pytest.raises(TypeError):
                Operation.new('Invalid', fixed_arity, one_identity=True)

    def test_infix_error(self):
        """``Operation.new`` must reject ``infix=True`` for a unary operation with TypeError."""
        with pytest.raises(TypeError):
            Operation.new('Invalid', Arity.unary, infix=True)