# -*- coding: utf-8 -*-
from typing import Iterable, Iterator, List, Sequence, Tuple, cast, Set

from multiset import Multiset

from ..expressions.expressions import (
    Expression, Pattern, Operation, Symbol, SymbolWildcard, Wildcard, AssociativeOperation, CommutativeOperation, OneIdentityOperation
)
from ..expressions.constraints import Constraint
from ..expressions.substitution import Substitution
from ..expressions.functions import (
    is_constant, preorder_iter_with_position, match_head, create_operation_expression, op_iter, op_len
)
from ..utils import (
    VariableWithCount, commutative_sequence_variable_partition_iter, fixed_integer_vector_iter, weak_composition_iter,
    generator_chain, optional_iter
)
from ._common import CommutativePatternsParts, check_one_identity

__all__ = ['match', 'match_anywhere']


def match(subject: Expression, pattern: Pattern) -> Iterator[Substitution]:
    r"""Tries to match the given *pattern* to the given *subject*.

    Yields each match in form of a substitution.

    Parameters:
        subject:
            An subject to match.
        pattern:
            The pattern to match.

    Yields:
        All possible match substitutions.

    Raises:
        ValueError:
            If the subject is not constant.
    """
    if not is_constant(subject):
        raise ValueError("The subject for matching must be constant.")
    global_constraints = [c for c in pattern.constraints if not c.variables]
    local_constraints = set(c for c in pattern.constraints if c.variables)
    for subst in _match([subject], pattern.expression, Substitution(), local_constraints):
        for constraint in global_constraints:
            if not constraint(subst):
                break
        else:
            yield subst


def match_anywhere(subject: Expression, pattern: Pattern) -> Iterator[Tuple[Substitution, Tuple[int, ...]]]:
    """Tries to match the given *pattern* to the any subexpression of the given *subject*.

    Yields each match in form of a substitution and a position tuple.
    The position is a tuple of indices, e.g. the empty tuple refers to the *subject* itself,
    :code:`(0, )` refers to the first child (operand) of the subject, :code:`(0, 0)` to the first child of
    the first child etc.

    Parameters:
        subject:
            An subject to match.
        pattern:
            The pattern to match.

    Yields:
        All possible substitution and position pairs.

    Raises:
        ValueError:
            If the subject is not constant.
    """
    if not is_constant(subject):
        raise ValueError("The subject for matching must be constant.")
    for child, pos in preorder_iter_with_position(subject):
        if match_head(child, pattern):
            for subst in match(child, pattern):
                yield subst, pos


def _match(subjects: List[Expression], pattern: Expression, subst: Substitution,
           constraints: Set[Constraint]) -> Iterator[Substitution]:
    match_iter = None
    expr = subjects[0] if subjects else None
    if isinstance(pattern, Wildcard):
        # All size checks are already handled elsewhere
        # When called directly from match, len(subjects) = 1
        # The operation matching also already only assigns valid number of subjects to a wildcard
        # So all we need to check here is the symbol type for SymbolWildcards
        if isinstance(pattern, SymbolWildcard) and not isinstance(subjects[0], pattern.symbol_type):
            return
        match_iter = iter([subst])
        if pattern.optional is not None and not subjects:
            expr = pattern.optional
        elif not pattern.fixed_size:
            expr = tuple(subjects)

    elif isinstance(pattern, Symbol):
        if len(subjects) == 1 and isinstance(subjects[0], type(pattern)) and subjects[0].name == pattern.name:
            match_iter = iter([subst])

    elif isinstance(pattern, Operation):
        if isinstance(pattern, OneIdentityOperation):
            yield from _match_one_identity(subjects, pattern, subst, constraints)
        if len(subjects) != 1 or not isinstance(subjects[0], pattern.__class__):
            return
        op_expr = cast(Operation, subjects[0])
        # if not op_expr.symbols >= pattern.symbols:
        #     return
        match_iter = _match_operation(op_expr, pattern, subst, constraints)

    else:
        if len(subjects) == 1 and subjects[0] == pattern:
            match_iter = iter([subst])

    if match_iter is not None:
        if getattr(pattern, 'variable_name', False):
            for new_subst in match_iter:
                try:
                    if expr is None and getattr(pattern, 'optional', None) is not None:
                        expr = pattern.optional
                    new_subst = new_subst.union_with_variable(pattern.variable_name, expr)
                except ValueError:
                    pass
                else:
                    yield from _check_constraints(new_subst, constraints)
        else:
            yield from match_iter


def _check_constraints(substitution, constraints):
    restore_constraints = set()
    try:
        for constraint in list(constraints):
            for var in constraint.variables:
                if var not in substitution:
                    break
            else:
                if not constraint(substitution):
                    break
                restore_constraints.add(constraint)
                constraints.remove(constraint)
        else:
            yield substitution
    finally:
        for constraint in restore_constraints:
            constraints.add(constraint)


def _match_factory(subjects, operand, constraints):
    def factory(subst):
        yield from _match(subjects, operand, subst, constraints)

    return factory


def _count_seq_vars(subjects, operation):
    remaining = op_len(subjects)
    sequence_var_count = 0
    optional_count = 0
    for operand in op_iter(operation):
        if isinstance(operand, Wildcard):
            if not operand.fixed_size or isinstance(operation, AssociativeOperation):
                sequence_var_count += 1
                if operand.optional is None:
                    remaining -= operand.min_count
            elif operand.optional is not None:
                optional_count += 1
            else:
                remaining -= operand.min_count
        else:
            remaining -= 1
        if remaining < 0:
            raise ValueError
    return remaining, sequence_var_count, optional_count


def _build_full_partition(
        optional_parts, sequence_var_partition: Sequence[int], subjects: Sequence[Expression], operation: Operation
) -> List[Sequence[Expression]]:
    """Distribute subject operands among pattern operands.

    Given a partitoning for the variable part of the operands (i.e. a list of how many extra operands each sequence
    variable gets assigned).
    """
    i = 0
    var_index = 0
    opt_index = 0
    result = []
    for operand in op_iter(operation):
        wrap_associative = False
        if isinstance(operand, Wildcard):
            count = operand.min_count if operand.optional is None else 0
            if not operand.fixed_size or isinstance(operation, AssociativeOperation):
                count += sequence_var_partition[var_index]
                var_index += 1
                wrap_associative = operand.fixed_size and operand.min_count
            elif operand.optional is not None:
                count = optional_parts[opt_index]
                opt_index += 1
        else:
            count = 1

        operand_expressions = list(op_iter(subjects))[i:i + count]
        i += count

        if wrap_associative and len(operand_expressions) > wrap_associative:
            fixed = wrap_associative - 1
            operand_expressions = tuple(operand_expressions[:fixed]) + (
                create_operation_expression(operation, operand_expressions[fixed:]),
            )

        result.append(operand_expressions)

    return result


def _non_commutative_match(subjects, operation, subst, constraints):
    try:
        remaining, sequence_var_count, optional_count = _count_seq_vars(subjects, operation)
    except ValueError:
        return
    for new_remaining, optional in optional_iter(remaining, optional_count):
        if new_remaining < 0:
            continue
        for part in weak_composition_iter(new_remaining, sequence_var_count):
            partition = _build_full_partition(optional, part, subjects, operation)
            factories = [_match_factory(e, o, constraints) for e, o in zip(partition, op_iter(operation))]

            for new_subst in generator_chain(subst, *factories):
                yield new_subst


def _match_one_identity(subjects, operation, subst, constraints):
    non_optional, added_subst = check_one_identity(operation)
    if non_optional is not None:
        try:
            new_subst = subst.union(added_subst)
        except ValueError:
            return
        yield from _match(subjects, non_optional, new_subst, constraints)


def _match_operation(subjects, operation, subst, constraints):
    if op_len(operation) == 0:
        if op_len(subjects) == 0:
            yield subst
        return
    if not isinstance(operation, CommutativeOperation):
        yield from _non_commutative_match(subjects, operation, subst, constraints)
    else:
        parts = CommutativePatternsParts(type(operation), *op_iter(operation))
        yield from _match_commutative_operation(subjects, parts, subst, constraints)


def _match_commutative_operation(
        subject_operands: Iterable[Expression],
        pattern: CommutativePatternsParts,
        substitution: Substitution,
        constraints
) -> Iterator[Substitution]:
    subjects = Multiset(op_iter(subject_operands))  # type: Multiset
    if not pattern.constant <= subjects:
        return
    subjects -= pattern.constant
    rest_expr = pattern.rest + pattern.syntactic
    needed_length = (
        pattern.sequence_variable_min_length + pattern.fixed_variable_length + len(rest_expr) +
        pattern.wildcard_min_length
    )

    if len(subjects) < needed_length:
        return

    fixed_vars = Multiset(pattern.fixed_variables)  # type: Multiset[str]
    for name, count in pattern.fixed_variables.items():
        if name in substitution:
            replacement = substitution[name]
            if issubclass(pattern.operation, AssociativeOperation) and isinstance(replacement, pattern.operation):
                needed_count = Multiset(op_iter(substitution[name]))  # type: Multiset
            else:
                if isinstance(replacement, (tuple, list, Multiset)):
                    return
                needed_count = Multiset({replacement: 1})
            if count > 1:
                needed_count *= count
            if not needed_count <= subjects:
                return
            subjects -= needed_count
            del fixed_vars[name]

    factories = [_fixed_expr_factory(e, constraints) for e in rest_expr]

    if not issubclass(pattern.operation, AssociativeOperation):
        for name, count in fixed_vars.items():
            min_count, symbol_type, default = pattern.fixed_variable_infos[name]
            factory = _fixed_var_iter_factory(name, count, min_count, symbol_type, constraints, default)
            factories.append(factory)

        if pattern.wildcard_fixed is True:
            factory = _fixed_var_iter_factory(None, 1, pattern.wildcard_min_length, None, constraints, None)
            factories.append(factory)
    else:
        for name, count in fixed_vars.items():
            min_count, symbol_type, default = pattern.fixed_variable_infos[name]
            if symbol_type is not None:
                factory = _fixed_var_iter_factory(name, count, min_count, symbol_type, constraints, default)
                factories.append(factory)

    for rem_expr, substitution in generator_chain((subjects, substitution), *factories):
        sequence_vars = _variables_with_counts(pattern.sequence_variables, pattern.sequence_variable_infos)
        if issubclass(pattern.operation, AssociativeOperation):
            sequence_vars += _variables_with_counts(fixed_vars, pattern.fixed_variable_infos)
            if pattern.wildcard_fixed is True:
                sequence_vars += (VariableWithCount(None, 1, pattern.wildcard_min_length, None), )
        if pattern.wildcard_fixed is False:
            sequence_vars += (VariableWithCount(None, 1, pattern.wildcard_min_length, None), )

        for sequence_subst in commutative_sequence_variable_partition_iter(Multiset(rem_expr), sequence_vars):
            if issubclass(pattern.operation, AssociativeOperation):
                for v in fixed_vars.distinct_elements():
                    if v not in sequence_subst:
                        continue
                    l = pattern.fixed_variable_infos[v].min_count
                    value = cast(Sequence, sequence_subst[v])
                    if isinstance(value, (list, tuple, Multiset)):
                        if len(value) > l:
                            normal = Multiset(list(value)[:l - 1])
                            wrapped = pattern.operation(*(value - normal))
                            normal.add(wrapped)
                            sequence_subst[v] = normal if l > 1 else next(iter(normal))
                        else:
                            assert len(value) == 1 and l == 1, "Fixed variables with length != 1 are not supported."
                            sequence_subst[v] = next(iter(value))
            try:
                result = substitution.union(sequence_subst)
            except ValueError:
                pass
            else:
                yield from _check_constraints(result, constraints)


def _variables_with_counts(variables, infos):
    return tuple(
        VariableWithCount(name, count, infos[name].min_count, infos[name].default)
        for name, count in variables.items() if infos[name].type is None
    )


def _fixed_expr_factory(expression, constraints):
    def factory(data):
        subjects, substitution = data
        for expr in subjects.distinct_elements():
            if match_head(expr, expression):
                for subst in _match([expr], expression, substitution, constraints):
                    yield subjects - Multiset({expr: 1}), subst

    return factory


def _fixed_var_iter_factory(variable_name, count, length, symbol_type, constraints, optional):
    def factory(data):
        subjects, substitution = data
        if variable_name in substitution:
            value = ([substitution[variable_name]]
                     if not isinstance(substitution[variable_name], (tuple, list, Multiset)) else substitution[variable_name])
            if optional is not None and value == [optional]:
                yield subjects, substitution
            existing = Multiset(value) * count
            if not existing <= subjects:
                return
            yield subjects - existing, substitution
        else:
            if optional is not None:
                new_substitution = Substitution(substitution)
                new_substitution[variable_name] = optional
                yield subjects, new_substitution
            if length == 1:
                for expr, expr_count in subjects.items():
                    if expr_count >= count and (symbol_type is None or isinstance(expr, symbol_type)):
                        if variable_name is not None:
                            new_substitution = Substitution(substitution)
                            new_substitution[variable_name] = expr
                            for new_substitution in _check_constraints(new_substitution, constraints):
                                yield subjects - Multiset({expr: count}), new_substitution
                        else:
                            yield subjects - Multiset({expr: count}), substitution
            else:
                assert variable_name is None, "Fixed variables with length != 1 are not supported."
                exprs_with_counts = list(subjects.items())
                counts = tuple(c // count for _, c in exprs_with_counts)
                for subset in fixed_integer_vector_iter(counts, length):
                    sub_counter = Multiset(dict((exprs_with_counts[i][0], c * count) for i, c in enumerate(subset)))
                    yield subjects - sub_counter, substitution

    return factory