# -*- coding: utf-8 -*-
"""Contains the :class:`ManyToOneMatcher` which can be used for fast many-to-one matching.

You can initialize the matcher with a list of the patterns that you wish to match:

>>> pattern1 = Pattern(f(a, x_))
>>> pattern2 = Pattern(f(y_, b))
>>> matcher = ManyToOneMatcher(pattern1, pattern2)

You can also add patterns later:

>>> pattern3 = Pattern(f(a, b))
>>> matcher.add(pattern3)

A pattern can be added with a label which is yielded instead of the pattern during matching:

>>> pattern4 = Pattern(f(x_, y_))
>>> matcher.add(pattern4, "some label")

Then you can match a subject against all the patterns at once:

>>> subject = f(a, b)
>>> matches = matcher.match(subject)
>>> for matched_pattern, substitution in sorted(map(lambda m: (str(m[0]), str(m[1])), matches)):
...     print('{} matched with {}'.format(matched_pattern, substitution))
f(a, b) matched with {}
f(a, x_) matched with {x ↦ b}
f(y_, b) matched with {y ↦ a}
some label matched with {x ↦ a, y ↦ b}

Also contains the :class:`ManyToOneReplacer` which can apply a set of :class:`ReplacementRule` objects at once,
using a :class:`ManyToOneMatcher` for finding the matches.
"""
import math
import html
import itertools
from collections import deque
from operator import itemgetter
from typing import Container, Dict, Iterable, Iterator, List, NamedTuple, Optional, Sequence, Set, Tuple, Type, Union

try:
    from graphviz import Digraph, Graph
except ImportError:
    Digraph = None
    Graph = None
from multiset import Multiset

from ..expressions.expressions import (
    Expression, Operation, Symbol, SymbolWildcard, Wildcard, Pattern, AssociativeOperation, CommutativeOperation,
    OneIdentityOperation
)
from ..expressions.substitution import Substitution
from ..expressions.functions import (
    is_anonymous, contains_variables_from_set, create_operation_expression, preorder_iter_with_position,
    rename_variables, op_iter, preorder_iter, op_len
)
from ..utils import (VariableWithCount, commutative_sequence_variable_partition_iter)
from .. import functions
from .bipartite import BipartiteGraph, enum_maximum_matchings_iter, LEFT
from .syntactic import OPERATION_END, is_operation
from ._common import check_one_identity

__all__ = ['ManyToOneMatcher', 'ManyToOneReplacer']

LabelType = Union[Expression, Type[Operation]]
HeadType = Optional[Union[Expression, Type[Operation], Type[Symbol]]]
MultisetOfInt = Multiset
MultisetOfExpression = Multiset

# Label of epsilon transitions, i.e. transitions that do not consume a subject. They are created for
# one-identity operations (see ManyToOneMatcher._process_pattern_stack).
_EPS = object()

# Marker for "variable was not bound" when saving substitution values for backtracking. A dedicated object is
# used instead of None so that a variable legitimately bound to None is restored rather than deleted.
_UNBOUND = object()

# A state of the automaton:
#   number:      unique id, taken from the class-wide counter ManyToOneMatcher._state_id.
#   transitions: maps a head (or a label such as OPERATION_END / a commutative sub-pattern id) to the
#                *list* of outgoing transitions for it.
#   matcher:     the CommutativeMatcher that handles the operands of a commutative operation, else None.
_State = NamedTuple('_State', [
    ('number', int),
    ('transitions', Dict[LabelType, '_Transition']),
    ('matcher', Optional['CommutativeMatcher'])
])  # yapf: disable

# A transition of the automaton:
#   label:             what is consumed (symbol, wildcard, operation type, OPERATION_END, commutative
#                      sub-pattern id, or _EPS).
#   target:            the destination state.
#   variable_name:     variable bound to the consumed subject, or None.
#   patterns:          indices of the patterns that use this transition.
#   check_constraints: constraint indices to check once variable_name is bound (None if there is no variable).
#   subst:             extra bindings added when following the transition (epsilon transitions), or None.
_Transition = NamedTuple('_Transition', [
    ('label', LabelType),
    ('target', _State),
    ('variable_name', Optional[str]),
    ('patterns', Set[int]),
    ('check_constraints', Optional[Set[int]]),
    ('subst', Substitution),
])  # yapf: disable

# Numbers of every state entered by any _MatchIter in this process. Only read when drawing graphs (visited
# states are drawn red); it is never cleared.
_VISITED = set()


class _MatchIter:
    """Backtracking traversal of a :class:`ManyToOneMatcher` automaton for a single subject.

    The traversal state (queue of remaining subjects, still-viable pattern indices, partial substitution and
    not-yet-evaluated constraint indices) is mutated while descending into the automaton and restored when
    backtracking.

    Note that this object is iterable but not an iterator: use ``iter()`` before calling ``next()``.
    """

    def __init__(self, matcher, subject, intial_associative=None):
        """
        Args:
            matcher: The :class:`ManyToOneMatcher` whose automaton is traversed.
            subject: The subject expression. ``None`` starts with an empty subject queue.
            intial_associative: The associative operation type whose operands are being matched, or ``None``.
                The parameter name's spelling is kept as-is so existing keyword callers keep working.
        """
        self.matcher = matcher
        # Subjects still to be consumed at the current nesting level.
        self.subjects = deque([subject]) if subject is not None else deque()
        # Indices into matcher.patterns of the patterns that can still match on the current path.
        self.patterns = set(range(len(matcher.patterns)))
        self.substitution = Substitution()
        # Indices into matcher.constraints of the constraints not yet evaluated on the current path.
        self.constraints = set(range(len(matcher.constraints)))
        # Stack of enclosing associative operation types (None for non-associative nesting levels).
        self.associative = [intial_associative]

    def __iter__(self):
        """Yield ``(label, substitution)`` for every match, with variables renamed back to the pattern's names."""
        for _ in self._match(self.matcher.root):
            yield from self._internal_iter()

    def grouped(self):
        """
        Yield the matches grouped by their final state in the automaton, i.e. structurally identical patterns
        only differing in constraints will be yielded together. Each group is yielded as a list of tuples consisting
        of a pattern and a match substitution.

        Yields:
            The grouped matches.
        """
        for _ in self._match(self.matcher.root):
            yield list(self._internal_iter())

    def any(self):
        """Check whether the subject matches at least one pattern.

        Returns:
            True, if any match is found.
        """
        # This class only implements __iter__ (a generator function), not __next__, so next(self) would raise
        # TypeError; the generator has to be created explicitly first.
        try:
            next(iter(self))
        except StopIteration:
            return False
        return True

    def _internal_iter(self):
        """Yield ``(label, substitution)`` for every still-viable pattern at the current final state.

        The substitution is translated from the automaton's internal variable names back to the pattern's own
        names, and only yielded if all of the pattern's global constraints hold for it.
        """
        for pattern_index in self.patterns:
            renaming = self.matcher.pattern_vars[pattern_index]
            # pattern_vars maps original name -> internal name; invert it to undo the renaming.
            new_substitution = self.substitution.rename({renamed: original for original, renamed in renaming.items()})
            pattern, label, _ = self.matcher.patterns[pattern_index]
            if all(constraint(new_substitution) for constraint in pattern.global_constraints):
                yield label, new_substitution

    def _match(self, state: _State) -> Iterator[_State]:
        """Search depth-first from *state*.

        Yields every reached state at which the subject queue is empty and which is either final or can close the
        current operation (has an ``OPERATION_END`` transition).
        """
        _VISITED.add(state.number)
        if len(self.subjects) == 0:
            if state.number in self.matcher.finals or OPERATION_END in state.transitions:
                yield state
            heads = [None]
        else:
            heads = list(self._get_heads(self.subjects[0]))
        for head in heads:
            for transition in state.transitions.get(head, []):
                yield from self._match_transition(transition)

    def _match_transition(self, transition: _Transition) -> Iterator[_State]:
        """Dispatch on the kind of *transition* label and continue the search through it."""
        if self.patterns.isdisjoint(transition.patterns):
            return
        label = transition.label
        if label is _EPS:
            # Epsilon transitions do not consume the next subject; it is only peeked for a possible variable.
            subject = self.subjects[0] if self.subjects else None
            yield from self._check_transition(transition, subject, False)
            return
        if is_operation(label):
            if transition.target.matcher:
                yield from self._match_commutative_operation(transition.target)
            else:
                yield from self._match_regular_operation(transition)
            return
        if isinstance(label, Wildcard) and not isinstance(label, SymbolWildcard):
            min_count = label.min_count
            if label.optional is not None and min_count > 0:
                # Also try matching the wildcard with its default value without consuming a subject.
                yield from self._check_transition(transition, label.optional, False)
            if label.fixed_size and not self.associative[-1]:
                assert min_count == 1, "Fixed wildcards with length != 1 are not supported."
                if not self.subjects:
                    return
            else:
                yield from self._match_sequence_variable(label, transition)
                return
        subject = self.subjects.popleft() if self.subjects else None
        yield from self._check_transition(transition, subject)

    def _check_transition(self, transition, subject, restore_subject=True):
        """Follow *transition* after applying its bindings and continue matching from its target state.

        Args:
            transition: The transition to follow.
            subject: The subject consumed by the transition (already removed from the queue), or the value to
                bind to the transition's variable.
            restore_subject: Whether to push *subject* back onto the subject queue when backtracking.

        Yields:
            The states yielded by :meth:`_match` from the transition's target.
        """
        if self.patterns.isdisjoint(transition.patterns):
            return
        restore_constraints = set()
        restore_patterns = self.patterns - transition.patterns
        self.patterns &= transition.patterns
        # Value each touched variable had before this transition (or _UNBOUND); restored exactly once in finally.
        old_values = {}
        try:
            if transition.subst is not None:
                try:
                    for name, value in transition.subst.items():
                        old_values.setdefault(name, self.substitution.get(name, _UNBOUND))
                        self.substitution.try_add_variable(name, value)
                except ValueError:
                    # Conflicting binding. The finally block undoes the partial additions; undoing them here as
                    # well would delete the same keys twice and raise KeyError.
                    return
            if transition.variable_name is not None:
                variable_name = transition.variable_name
                old_values.setdefault(variable_name, self.substitution.get(variable_name, _UNBOUND))
                try:
                    self.substitution.try_add_variable(variable_name, subject)
                except ValueError:
                    return
                # Only transitions with a variable carry a constraint set; it is None otherwise.
                self._check_constraints(transition.check_constraints, restore_constraints, restore_patterns)
                if not self.patterns:
                    return
            yield from self._match(transition.target)
        finally:
            if restore_subject and subject is not None:
                self.subjects.appendleft(subject)
            self.constraints |= restore_constraints
            self.patterns |= restore_patterns
            for name, old_value in old_values.items():
                if old_value is _UNBOUND:
                    del self.substitution[name]
                else:
                    self.substitution[name] = old_value

    def _check_constraints(
            self, variable: Union[str, Iterable[int]], restore_constraints: Set[int], restore_patterns: Set[int]
    ) -> None:
        """Evaluate the pending constraints that have become checkable and prune the patterns they reject.

        A constraint is evaluated once all of its variables are bound and it still concerns a viable pattern.

        Args:
            variable: Either a variable name (its constraints are looked up in ``matcher.constraint_vars``) or an
                explicit collection of constraint indices.
            restore_constraints: Output set; indices of the constraints evaluated here are added so the caller can
                mark them as pending again when backtracking.
            restore_patterns: Output set; indices of the patterns pruned here are added so the caller can restore
                them when backtracking.
        """
        if isinstance(variable, str):
            check_constraints = self.matcher.constraint_vars.get(variable, [])
        else:
            check_constraints = variable
        variables = set(self.substitution.keys())
        for constraint_index in check_constraints:
            if constraint_index not in self.constraints:
                continue
            constraint, patterns = self.matcher.constraints[constraint_index]
            if constraint.variables <= variables and not self.patterns.isdisjoint(patterns):
                self.constraints.remove(constraint_index)
                restore_constraints.add(constraint_index)
                if not constraint(self.substitution):
                    restore_patterns |= self.patterns & patterns
                    self.patterns -= patterns
                    if not self.patterns:
                        break

    @staticmethod
    def _get_heads(expression: Expression) -> Iterator[HeadType]:
        """Yield the transition keys under which *expression* can be consumed.

        These are the classes in its MRO (except ``object``), the expression itself for non-operations, and
        ``None`` (the key of plain wildcard transitions).
        """
        for base in type(expression).__mro__:
            if base is not object:
                yield base
        if not isinstance(expression, Operation):
            yield expression
        yield None

    def _match_sequence_variable(self, wildcard: Wildcard, transition: _Transition) -> Iterator[_State]:
        """Try every possible prefix length (starting at ``min_count``) of the subject queue for *wildcard*."""
        min_count = wildcard.min_count
        if len(self.subjects) < min_count:
            return
        matched_subject = []
        for _ in range(min_count):
            matched_subject.append(self.subjects.popleft())
        while True:
            if self.associative[-1] and wildcard.fixed_size:
                assert min_count == 1, "Fixed wildcards with length != 1 are not supported."
                # Inside an associative operation several operands are wrapped into a new operation.
                if len(matched_subject) > 1:
                    wrapped = self.associative[-1](*matched_subject)
                else:
                    wrapped = matched_subject[0]
            else:
                if len(matched_subject) == 0 and wildcard.optional is not None:
                    wrapped = wildcard.optional
                else:
                    wrapped = tuple(matched_subject)
            yield from self._check_transition(transition, wrapped, False)
            if not self.subjects:
                break
            matched_subject.append(self.subjects.popleft())
        # Put all consumed subjects back in their original order.
        self.subjects.extendleft(reversed(matched_subject))

    def _match_commutative_operation(self, state: _State) -> Iterator[_State]:
        """Match the operands of the next subject with the state's :class:`CommutativeMatcher`."""
        subject = self.subjects.popleft()
        matcher = state.matcher
        substitution = self.substitution
        matcher.add_subject(None)
        for operand in op_iter(subject):
            matcher.add_subject(operand)
        for matched_pattern, new_substitution in matcher.match(subject, substitution):
            restore_constraints = set()
            # Only the variables newly bound by the sub-matcher can make new constraints checkable.
            diff = set(new_substitution.keys()) - set(substitution.keys())
            self.substitution = new_substitution
            transition_set = state.transitions[matched_pattern]
            t_iter = iter(t.patterns for t in transition_set)
            potential_patterns = next(t_iter).union(*t_iter)
            restore_patterns = self.patterns - potential_patterns
            self.patterns &= potential_patterns
            for variable in diff:
                self._check_constraints(variable, restore_constraints, restore_patterns)
                if not self.patterns:
                    break
            if self.patterns:
                transition_set = state.transitions[matched_pattern]
                for next_transition in transition_set:
                    yield from self._check_transition(next_transition, subject, False)
            self.constraints |= restore_constraints
            self.patterns |= restore_patterns
        self.substitution = substitution
        self.subjects.appendleft(subject)

    def _match_regular_operation(self, transition: _Transition) -> Iterator[_State]:
        """Descend into the operands of the next subject (a non-commutative operation) and close it afterwards."""
        subject = self.subjects.popleft()
        after_subjects = self.subjects
        operand_subjects = self.subjects = deque(op_iter(subject))
        new_associative = transition.label if issubclass(transition.label, AssociativeOperation) else None
        self.associative.append(new_associative)
        for new_state in self._check_transition(transition, subject, False):
            # All operands are consumed: switch back to the outer level and follow the OPERATION_END transitions.
            self.subjects = after_subjects
            self.associative.pop()
            for end_transition in new_state.transitions[OPERATION_END]:
                yield from self._check_transition(end_transition, None, False)
            self.subjects = operand_subjects
            self.associative.append(new_associative)
        self.subjects = after_subjects
        self.subjects.appendleft(subject)
        self.associative.pop()


class ManyToOneMatcher:
    """Automaton-based matcher that matches a subject against many patterns at once."""

    __slots__ = ('patterns', 'states', 'root', 'pattern_vars', 'constraints', 'constraint_vars', 'finals', 'rename')

    # Class-wide counter, so state numbers are unique across all matcher instances (including the automata of
    # nested CommutativeMatchers); this keeps node names distinct when sub-automata are drawn into one graph.
    _state_id = 0

    def __init__(self, *patterns: Expression, rename=True) -> None:
        """
        Args:
            *patterns: The patterns which the matcher should match.
            rename: If true (default), pattern variables are renamed according to their position in the pattern
                (see :meth:`_collect_variable_renaming`) so structurally identical patterns share states.
        """
        # List of (pattern, label, constraint indices) tuples, indexed by pattern index.
        self.patterns = []
        self.states = []
        self.root = self._create_state()
        # For each pattern index: mapping of original variable name -> internal variable name.
        self.pattern_vars = []
        # List of (constraint, set of pattern indices using it) tuples.
        self.constraints = []
        # Variable name -> set of indices of constraints that use the variable.
        self.constraint_vars = {}
        # Numbers of the final states.
        self.finals = set()
        self.rename = rename
        for pattern in patterns:
            self.add(pattern)

    def add(self, pattern: Pattern, label=None) -> Optional[int]:
        """Add a new pattern to the matcher.

        The optional label defaults to the pattern itself and is yielded during matching. The same pattern can be
        added with different labels which means that every match for the pattern will result in every associated
        label being yielded with that match individually.

        Equivalent patterns with the same label are not added again. However, patterns that are structurally
        equivalent, but have different constraints or different variable names are distinguished by the matcher.

        Args:
            pattern: The pattern to add.
            label: An optional label for the pattern. Defaults to the pattern itself.

        Returns:
            The index of the already registered identical ``(pattern, label)`` entry if there is one; ``None``
            when the pattern was newly added.
        """
        if label is None:
            label = pattern
        for i, (p, l, _) in enumerate(self.patterns):
            if pattern == p and label == l:
                return i
        # TODO: Avoid renaming in the pattern, use variable indices instead
        renaming = self._collect_variable_renaming(pattern.expression) if self.rename else {}
        self._internal_add(pattern, label, renaming)

    def _internal_add(self, pattern: Pattern, label, renaming) -> int:
        """Add a new pattern to the matcher's automaton.

        No duplicate check is done here; callers are responsible for that. Patterns that are structurally
        equivalent, but have different constraints or different variable names are distinguished by the matcher.

        Args:
            pattern: The pattern to add.
            label: The label yielded for matches of this pattern.
            renaming: Mapping of the pattern's variable names to the internal names used in the automaton.

        Returns:
            The internal id for the pattern. This is mainly used by the :class:`CommutativeMatcher`.
        """
        pattern_index = len(self.patterns)
        renamed_constraints = [c.with_renamed_vars(renaming) for c in pattern.local_constraints]
        constraint_indices = [self._add_constraint(c, pattern_index) for c in renamed_constraints]
        self.patterns.append((pattern, label, constraint_indices))
        self.pattern_vars.append(renaming)
        pattern = rename_variables(pattern.expression, renaming)
        state = self.root
        patterns_stack = [deque([pattern])]

        self._process_pattern_stack(state, patterns_stack, renamed_constraints, pattern_index)

        return pattern_index

    def _process_pattern_stack(self, state, patterns_stack, renamed_constraints, pattern_index):
        """Create the transitions for the remaining pattern parts in *patterns_stack*, starting at *state*.

        *patterns_stack* holds one queue of pending sub-expressions per nesting level. Exhausting a level adds an
        ``OPERATION_END`` transition, and the state reached at the very end is marked as final. For one-identity
        operations an additional branch is built (recursively) through an epsilon transition in which the operation
        is replaced by the expression returned by ``check_one_identity`` (presumably its single non-optional
        operand, with ``added_subst`` binding the optional ones - see that helper).
        """
        while patterns_stack:
            if patterns_stack[-1]:
                subpattern = patterns_stack[-1].popleft()
                variable_name = getattr(subpattern, 'variable_name', None)
                if isinstance(subpattern, Operation):
                    if isinstance(subpattern, OneIdentityOperation):
                        non_optional, added_subst = check_one_identity(subpattern)
                        if non_optional is not None:
                            stack = [q.copy() for q in patterns_stack]
                            stack[-1].appendleft(non_optional)
                            new_state = self._create_expression_transition(
                                state, _EPS, variable_name, pattern_index, added_subst
                            )
                            self._process_pattern_stack(new_state, stack, renamed_constraints, pattern_index)
                    if not isinstance(subpattern, CommutativeOperation):
                        patterns_stack.append(deque(op_iter(subpattern)))
                state = self._create_expression_transition(state, subpattern, variable_name, pattern_index)
                if isinstance(subpattern, CommutativeOperation):
                    # Operands of commutative operations are handled by the state's CommutativeMatcher.
                    subpattern_id = state.matcher.add_pattern(subpattern, renamed_constraints)
                    state = self._create_simple_transition(state, subpattern_id, pattern_index)
            else:
                patterns_stack.pop()
                if len(patterns_stack) > 0:
                    state = self._create_simple_transition(state, OPERATION_END, pattern_index)
        self.finals.add(state.number)

    def _add_constraint(self, constraint, pattern):
        """Register *constraint* for the pattern index *pattern* and return the constraint's index.

        Equal constraints are stored once and shared by all patterns using them.
        """
        index = None
        for i, (c, patterns) in enumerate(self.constraints):
            if c == constraint:
                patterns.add(pattern)
                index = i
                break
        else:
            index = len(self.constraints)
            self.constraints.append((constraint, set([pattern])))
        for var in constraint.variables:
            self.constraint_vars.setdefault(var, set()).add(index)
        return index

    def match(self, subject: Expression) -> Iterable[Tuple[Expression, Substitution]]:
        """Match the subject against all the matcher's patterns.

        Args:
            subject: The subject to match.

        Returns:
            An iterable (not an iterator - wrap it in ``iter()`` before calling ``next()``) that yields, for every
            match, a tuple of the matching pattern (or its label) and the match substitution.
        """
        return _MatchIter(self, subject)

    def is_match(self, subject: Expression) -> bool:
        """Check if the subject matches any of the matcher's patterns.

        Args:
            subject: The subject to match.

        Return:
            True, if the subject is matched by any of the matcher's patterns.
            False, otherwise.
        """
        return _MatchIter(self, subject).any()

    def _create_expression_transition(
            self, state: _State, expression: Expression, variable_name: Optional[str], index: int, subst=None
    ) -> _State:
        """Add a transition for *expression* from *state* for pattern *index* and return its target state.

        An existing transition with the same variable name, label and extra substitution is reused (the pattern
        index and the relevant constraints are added to it); otherwise a new transition and target state are
        created. Targets of commutative operations get a :class:`CommutativeMatcher`.
        """
        label, head = self._get_label_and_head(expression)
        transitions = state.transitions.setdefault(head, [])
        commutative = isinstance(expression, CommutativeOperation)
        matcher = None
        for transition in transitions:
            if transition.variable_name == variable_name and transition.label == label and transition.subst == subst:
                transition.patterns.add(index)
                if variable_name is not None:
                    constraints = set(
                        self.constraint_vars[variable_name] if variable_name in self.constraint_vars else []
                    )
                    for c in list(constraints):
                        patterns = self.constraints[c][1]
                        if patterns.isdisjoint(transition.patterns):
                            constraints.discard(c)
                    transition.check_constraints.update(constraints)
                state = transition.target
                break
        else:
            if commutative:
                matcher = CommutativeMatcher(type(expression) if isinstance(expression, AssociativeOperation) else None)
            state = self._create_state(matcher)
            if variable_name is not None:
                # Only the constraints on this variable that belong to the pattern being added.
                constraints = set(self.constraint_vars[variable_name] if variable_name in self.constraint_vars else [])
                for c in list(constraints):
                    patterns = self.constraints[c][1]
                    if index not in patterns:
                        constraints.discard(c)
            else:
                constraints = None
            transition = _Transition(label, state, variable_name, {index}, constraints, subst)
            transitions.append(transition)
        return state

    def _create_simple_transition(self, state: _State, label: LabelType, index: int, variable_name=None) -> _State:
        """Return the target of the (single) transition stored under *label*, creating it if it does not exist."""
        if label in state.transitions:
            transition = state.transitions[label][0]
            transition.patterns.add(index)
            return transition.target
        new_state = self._create_state()
        transition = _Transition(label, new_state, variable_name, {index}, None, None)
        state.transitions[label] = [transition]
        return new_state

    @staticmethod
    def _get_label_and_head(expression: Expression) -> Tuple[LabelType, HeadType]:
        """Return the transition label for *expression* and the head key it is stored under.

        Wildcards and symbols are normalized (variable names dropped) so equivalent transitions can be shared.
        """
        if expression is _EPS:
            return _EPS, None
        if isinstance(expression, Operation):
            head = label = type(expression)
        else:
            label = expression
            if isinstance(label, SymbolWildcard):
                head = label.symbol_type
                label = SymbolWildcard(symbol_type=label.symbol_type)
            elif isinstance(label, Wildcard):
                head = None
                label = Wildcard(label.min_count, label.fixed_size, optional=label.optional)
            elif isinstance(label, Symbol):
                head = label = type(label)(label.name)
            else:
                head = expression
        return label, head

    def _create_state(self, matcher: 'CommutativeMatcher'=None) -> _State:
        """Create, register and return a new state with a fresh number."""
        state = _State(ManyToOneMatcher._state_id, dict(), matcher)
        self.states.append(state)
        ManyToOneMatcher._state_id += 1
        return state

    @classmethod
    def _collect_variable_renaming(
            cls, expression: Expression, position: List[int]=None, variables: Dict[str, str]=None
    ) -> Dict[str, str]:
        """Return renaming for the variables in the expression.

        The variable names are generated according to the position of the variable in the expression. The goal is
        to rename variables in structurally identical patterns so that the automaton contains less redundant states.

        Args:
            expression: The (sub-)expression to scan.
            position: Current position counter per commutative nesting level (internal, for recursion).
            variables: The renaming collected so far (internal, for recursion).

        Returns:
            Mapping of original variable name -> generated name.
        """
        if position is None:
            position = [0]
        if variables is None:
            variables = {}
        if getattr(expression, 'variable_name', False):
            if expression.variable_name not in variables:
                variables[expression.variable_name] = cls._get_name_for_position(position, variables.values())
        position[-1] += 1
        if isinstance(expression, Operation):
            if isinstance(expression, CommutativeOperation):
                # Operand order is irrelevant for commutative operations, so each operand gets its own sub-counter.
                for operand in op_iter(expression):
                    position.append(0)
                    cls._collect_variable_renaming(operand, position, variables)
                    position.pop()
            else:
                for operand in op_iter(expression):
                    cls._collect_variable_renaming(operand, position, variables)

        return variables

    @staticmethod
    def _get_name_for_position(position: List[int], variables: Container[str]) -> str:
        """Return ``'i<positions>'``, suffixed with ``_<n>`` if that name is already in *variables*."""
        new_name = 'i{}'.format('.'.join(map(str, position)))
        if new_name in variables:
            counter = 1
            while '{}_{}'.format(new_name, counter) in variables:
                counter += 1
            new_name = '{}_{}'.format(new_name, counter)
        return new_name

    def as_graph(self) -> Digraph:  # pragma: no cover
        """Return a :class:`graphviz.Digraph` drawing of the automaton (requires the graphviz package)."""
        return self._as_graph(None)

    _PATTERN_COLORS = [
        '#2E4272',
        '#7887AB',
        '#4F628E',
        '#162955',
        '#061539',
        '#403075',
        '#887CAF',
        '#615192',
        '#261758',
        '#13073A',
        '#226666',
        '#669999',
        '#407F7F',
        '#0D4D4D',
        '#003333',
    ]

    _CONSTRAINT_COLORS = [
        '#AA3939',
        '#D46A6A',
        '#801515',
        '#550000',
        '#AA6C39',
        '#D49A6A',
        '#804515',
        '#552600',
        '#882D61',
        '#AA5585',
        '#661141',
        '#440027',
    ]

    _VARIABLE_COLORS = [
        '#8EA336',
        '#B9CC66',
        '#677B14',
        '#425200',
        '#5C9632',
        '#B5E196',
        '#85BC5E',
        '#3A7113',
        '#1F4B00',
        '#AAA139',
        '#807715',
        '#554E00',
    ]

    @classmethod
    def _colored_pattern(cls, pid):  # pragma: no cover
        """Return the HTML label ``p<pid>`` in the pattern color for *pid*."""
        color = cls._PATTERN_COLORS[pid % len(cls._PATTERN_COLORS)]
        return '<font color="{}"><b>p{}</b></font>'.format(color, pid)

    @classmethod
    def _colored_constraint(cls, cid):  # pragma: no cover
        """Return the HTML label ``c<cid>`` in the constraint color for *cid*."""
        color = cls._CONSTRAINT_COLORS[cid % len(cls._CONSTRAINT_COLORS)]
        return '<font color="{}"><b>c{}</b></font>'.format(color, cid)

    @classmethod
    def _colored_variable(cls, var):  # pragma: no cover
        """Return the HTML label for variable *var*, colored by its hash."""
        # NOTE(review): str hashes are randomized per process, so colors may differ between runs.
        color = cls._VARIABLE_COLORS[hash(var) % len(cls._VARIABLE_COLORS)]
        return '<font color="{}"><b>{}</b></font>'.format(color, var)

    @classmethod
    def _format_pattern_set(cls, patterns):  # pragma: no cover
        """Return ``{p0, p1, ...}`` as colored HTML."""
        return '{{{}}}'.format(', '.join(map(cls._colored_pattern, patterns)))

    @classmethod
    def _format_constraint_set(cls, constraints):  # pragma: no cover
        """Return ``{c0, c1, ...}`` as colored HTML."""
        return '{{{}}}'.format(', '.join(map(cls._colored_constraint, constraints)))

    def _as_graph(self, finals: Optional[List[str]]) -> Digraph:  # pragma: no cover
        """Build the graph drawing.

        Args:
            finals: ``None`` for the top-level graph (which then also gets pattern and constraint legend nodes).
                For a sub-automaton, a list that receives the names of the final output nodes.
        """
        if Digraph is None:
            raise ImportError('The graphviz package is required to draw the graph.')
        graph = Digraph()
        if finals is None:
            patterns = [
                '{}: {} with {}'.format(
                    self._colored_pattern(i), html.escape(str(p.expression)), self._format_constraint_set(c)
                ) for i, (p, l, c) in enumerate(self.patterns)
            ]
            graph.node('patterns', '<<b>Patterns:</b><br/>\n{}>'.format('<br/>\n'.join(patterns)), {'shape': 'box'})

        self._make_graph_nodes(graph, finals)
        if finals is None:
            constraints = [
                '{}: {} for {}'.format(self._colored_constraint(i), html.escape(str(c)), self._format_pattern_set(p))
                for i, (c, p) in enumerate(self.constraints)
            ]
            graph.node(
                'constraints', '<<b>Constraints:</b><br/>\n{}>'.format('<br/>\n'.join(constraints)), {'shape': 'box'}
            )
        self._make_graph_edges(graph)
        return graph

    def _make_graph_nodes(self, graph: Digraph, finals: Optional[List[str]]) -> None:  # pragma: no cover
        """Add a node for every state (and sub-matcher / pattern-variable nodes) to *graph*."""
        # State number -> indices of the patterns whose transitions lead into that state.
        state_patterns = {}
        for state in self.states:
            state_patterns.setdefault(state.number, set())
            for transition in itertools.chain.from_iterable(state.transitions.values()):
                state_patterns.setdefault(transition.target.number, set()).update(transition.patterns)

        for state in self.states:
            name = 'n{!s}'.format(state.number)
            if state.matcher:
                has_states = len(state.matcher.automaton.states) > 1
                if has_states:
                    graph.node(name, 'Sub Matcher', {'shape': 'box'})
                subfinals = []
                if has_states:
                    graph.subgraph(state.matcher.automaton._as_graph(subfinals))
                submatch_label = '<<b>Sub Matcher End</b>' if has_states else '<<b>Sub Matcher</b>'
                for pattern_index, subpatterns, variables in state.matcher.patterns.values():
                    var_formatted = ', '.join(
                        '{}[{}]x{}{}{}'.format(
                            self._colored_variable(n), m, c, 'W' if w else '', ': {}'.format(d) if d is not None else ''
                        ) for (n, c, m, d), w in variables
                    )
                    submatch_label += '<br/>\n{}: {} {}'.format(
                        self._colored_pattern(pattern_index), subpatterns, var_formatted
                    )
                submatch_label += '>'
                end_name = (name + '-end') if has_states else name
                graph.node(end_name, submatch_label, {'shape': 'box'})
                for f in subfinals:
                    graph.edge(f, end_name)
                if has_states:
                    graph.edge(name, 'n{}'.format(state.matcher.automaton.root.number))
            else:
                attrs = {'shape': ('doublecircle' if state.number in self.finals else 'circle')}
                if state.number in _VISITED:
                    attrs['color'] = 'red'
                graph.node(name, str(state.number), attrs)
                if state.number in self.finals:
                    sp = state_patterns[state.number]
                    if finals is not None:
                        finals.append(name + '-out')
                    variables = [
                        '{}: {}'.format(
                            self._colored_pattern(i),
                            ', '.join('{} -> {}'.format(self._colored_variable(o), n) for n, o in r.items())
                        ) for i, r in enumerate(self.pattern_vars) if i in sp
                    ]
                    graph.node(
                        name + '-out', '<<b>Pattern Variables:</b><br/>\n{}>'.format('<br/>\n'.join(variables)),
                        {'shape': 'box'}
                    )
                    graph.edge(name, name + '-out')

    def _make_graph_edges(self, graph: Digraph) -> None:  # pragma: no cover
        """Add an edge for every transition to *graph*, labelled with its label, patterns, constraints and subst."""
        for state in self.states:
            for _, transitions in state.transitions.items():
                for transition in transitions:
                    t_label = '<'
                    if transition.variable_name:
                        t_label += '{}: '.format(self._colored_variable(transition.variable_name))
                    t_label += 'ε' if transition.label is _EPS else html.escape(str(transition.label))
                    if is_operation(transition.label):
                        t_label += '('
                    t_label += '<br/>{}'.format(self._format_pattern_set(transition.patterns))
                    if transition.check_constraints is not None:
                        t_label += '<br/>{}'.format(self._format_constraint_set(transition.check_constraints))
                    if transition.subst is not None:
                        t_label += '<br/>{}'.format(html.escape(str(transition.subst)))
                    t_label += '>'

                    start = 'n{!s}'.format(state.number)
                    if state.matcher and len(state.matcher.automaton.states) > 1:
                        start += '-end'
                    end = 'n{!s}'.format(transition.target.number)
                    graph.edge(start, end, t_label)


class ManyToOneReplacer:
    """Class that contains a set of replacement rules and can apply them efficiently to an expression."""

    def __init__(self, *rules):
        """
        A replacement rule consists of a *pattern*, that is matched against any subexpression
        of the expression. If a match is found, the *replacement* callback of the rule is called with
        the variables from the match substitution. Whatever the callback returns is used as a replacement for the
        matched subexpression. This can either be a single expression or a sequence of expressions, which is then
        integrated into the surrounding operation in place of the subexpression.

        Note that the pattern can therefore not be a single sequence variable/wildcard, because only single expressions
        will be matched.

        Args:
            *rules:
                The replacement rules.
        """
        self.matcher = ManyToOneMatcher()
        for rule in rules:
            self.add(rule)

    def add(self, rule: 'functions.ReplacementRule') -> None:
        """Add a new rule to the replacer.

        Args:
            rule:
                The rule to add. Its replacement callback is used as the pattern's label in the matcher.
        """
        self.matcher.add(rule.pattern, rule.replacement)

    def replace(self, expression: Expression, max_count: int=math.inf) -> Union[Expression, Sequence[Expression]]:
        """Replace all occurrences of the patterns according to the replacement rules.

        Subexpressions are visited in pre-order; after each replacement the traversal restarts from the root.

        Args:
            expression:
                The expression to which the replacement rules are applied.
            max_count:
                If given, at most *max_count* applications of the rules are performed. Otherwise, the rules
                are applied until there is no more match. If the set of replacement rules is not confluent,
                the replacement might not terminate without a *max_count* set.

        Returns:
            The resulting expression after the application of the replacement rules. This can also be a sequence of
            expressions, if the root expression is replaced with a sequence of expressions by a rule.
        """
        replaced = True
        replace_count = 0
        while replaced and replace_count < max_count:
            replaced = False
            for subexpr, pos in preorder_iter_with_position(expression):
                # Only "no match" is absorbed here; a StopIteration raised by a replacement callback propagates
                # instead of being silently treated as "no match".
                first_match = next(iter(self.matcher.match(subexpr)), None)
                if first_match is None:
                    continue
                replacement, subst = first_match
                result = replacement(**subst)
                expression = functions.replace(expression, pos, result)
                replaced = True
                break
            replace_count += 1
        return expression

    def replace_post_order(self, expression: Expression) -> Union[Expression, Sequence[Expression]]:
        """Replace all occurrences of the patterns according to the replacement rules.

        Replaces innermost expressions first. Rules are applied until there is no more match, so a set of rules
        that is not confluent might not terminate.

        Args:
            expression:
                The expression to which the replacement rules are applied.

        Returns:
            The resulting expression after the application of the replacement rules. This can also be a sequence of
            expressions, if the root expression is replaced with a sequence of expressions by a rule.
        """
        return self._replace_post_order(expression)[0]

    def _replace_post_order(self, expression):
        """Return ``(new_expression, any_replaced)`` after exhaustively rewriting *expression* bottom-up."""
        any_replaced = False
        while True:
            if isinstance(expression, Operation):
                new_operands = [self._replace_post_order(o) for o in op_iter(expression)]
                if any(r for _, r in new_operands):
                    new_operands = [o for o, _ in new_operands]
                    expression = create_operation_expression(expression, new_operands)
                    any_replaced = True
            # Only "no match" ends the loop; a StopIteration from a replacement callback propagates.
            first_match = next(iter(self.matcher.match(expression)), None)
            if first_match is None:
                break
            replacement, subst = first_match
            expression = replacement(**subst)
            any_replaced = True
        return expression, any_replaced


Subgraph = BipartiteGraph[Tuple[int, int], Tuple[int, int], Substitution]
Matching = Dict[Tuple[int, int], Tuple[int, int]]


class CommutativeMatcher(object):
    __slots__ = (
        'patterns', 'subjects', 'subjects_by_id', 'automaton', 'bipartite', 'associative', 'max_optional_count',
        'anonymous_patterns'
    )

    def __init__(self, associative: Optional[type]) -> None:
        self.patterns = {}
        self.subjects = {}
        self.subjects_by_id = {}
        self.automaton = ManyToOneMatcher()
        self.bipartite = BipartiteGraph()
        self.associative = associative
        self.max_optional_count = 0
        self.anonymous_patterns = set()

    def add_pattern(self, operands: Iterable[Expression], constraints) -> int:
        pattern_set, pattern_vars = self._extract_sequence_wildcards(operands, constraints)
        sorted_vars = tuple(sorted(pattern_vars.values(), key=lambda v: (v[0][0] or '', v[0][1], v[0][2], v[1])))
        sorted_subpatterns = tuple(sorted(pattern_set))
        pattern_key = sorted_subpatterns + sorted_vars
        if pattern_key not in self.patterns:
            inserted_id = len(self.patterns)
            self.patterns[pattern_key] = (inserted_id, pattern_set, sorted_vars)
        else:
            inserted_id = self.patterns[pattern_key][0]
        return inserted_id

    def get_match_iter(self, subject):
        match_iter = _MatchIter(self.automaton, subject, self.associative)
        for _ in match_iter._match(self.automaton.root):
            for pattern_index in match_iter.patterns:
                substitution = Substitution(match_iter.substitution)
                yield
pattern_index, substitution def add_subject(self, subject: Expression) -> None: if subject not in self.subjects: subject_id, pattern_set = self.subjects[subject] = (len(self.subjects), set()) self.subjects_by_id[subject_id] = subject for pattern_index, substitution in self.get_match_iter(subject): self.bipartite.setdefault((subject_id, pattern_index), []).append(Substitution(substitution)) pattern_set.add(pattern_index) else: subject_id, _ = self.subjects[subject] return subject_id def match(self, subjects: Sequence[Expression], substitution: Substitution) -> Iterator[Tuple[int, Substitution]]: subject_ids = Multiset() pattern_ids = Multiset() if self.max_optional_count > 0: subject_id, subject_pattern_ids = self.subjects[None] subject_ids.add(subject_id) for _ in range(self.max_optional_count): pattern_ids.update(subject_pattern_ids) for subject in op_iter(subjects): subject_id, subject_pattern_ids = self.subjects[subject] subject_ids.add(subject_id) pattern_ids.update(subject_pattern_ids) for pattern_index, pattern_set, pattern_vars in self.patterns.values(): if pattern_set: if not pattern_set <= pattern_ids: continue bipartite_match_iter = self._match_with_bipartite(subject_ids, pattern_set, substitution) for bipartite_substitution, matched_subjects in bipartite_match_iter: ids = subject_ids - matched_subjects remaining = Multiset(self.subjects_by_id[id] for id in ids if self.subjects_by_id[id] is not None) if pattern_vars: sequence_var_iter = self._match_sequence_variables( remaining, pattern_vars, bipartite_substitution ) for result_substitution in sequence_var_iter: yield pattern_index, result_substitution elif len(remaining) == 0: yield pattern_index, bipartite_substitution elif pattern_vars: sequence_var_iter = self._match_sequence_variables(Multiset(op_iter(subjects)), pattern_vars, substitution) for variable_substitution in sequence_var_iter: yield pattern_index, variable_substitution elif op_len(subjects) == 0: yield pattern_index, substitution def 
_extract_sequence_wildcards(self, operands: Iterable[Expression], constraints) -> Tuple[MultisetOfInt, Dict[str, Tuple[VariableWithCount, bool]]]: pattern_set = Multiset() pattern_vars = dict() opt_count = 0 for operand in op_iter(operands): if isinstance(operand, Wildcard) and operand.optional is not None: opt_count += 1 if not self._is_sequence_wildcard(operand): actual_constraints = [c for c in constraints if contains_variables_from_set(operand, c.variables)] pattern = Pattern(operand, *actual_constraints) index = None for i, (p, _, _) in enumerate(self.automaton.patterns): if pattern == p: index = i break else: vnames = set(e.variable_name for e in preorder_iter(pattern.expression) if hasattr(e, 'variable_name') and e.variable_name is not None) renaming = {n: n for n in vnames} index = self.automaton._internal_add(pattern, None, renaming) if is_anonymous(pattern.expression): self.anonymous_patterns.add(index) pattern_set.add(index) else: varname = getattr(operand, 'variable_name', None) if varname is None: if varname in pattern_vars: (_, _, min_count, _), _ = pattern_vars[varname] else: min_count = 0 pattern_vars[varname] = (VariableWithCount(varname, 1, operand.min_count + min_count, None), False) else: if varname in pattern_vars: (_, count, _, _), wrap = pattern_vars[varname] else: count = 0 wrap = operand.fixed_size and self.associative pattern_vars[varname] = ( VariableWithCount(varname, count + 1, operand.min_count, operand.optional), wrap ) if opt_count > self.max_optional_count: self.max_optional_count = opt_count return pattern_set, pattern_vars def _is_sequence_wildcard(self, expression: Expression) -> bool: if isinstance(expression, SymbolWildcard): return False if isinstance(expression, Wildcard): return not expression.fixed_size or self.associative return False def _match_with_bipartite( self, subject_ids: MultisetOfInt, pattern_set: MultisetOfInt, substitution: Substitution, ) -> Iterator[Tuple[Substitution, MultisetOfInt]]: bipartite = 
self._build_bipartite(subject_ids, pattern_set) for matching in enum_maximum_matchings_iter(bipartite): if len(matching) < len(pattern_set): break if not self._is_canonical_matching(matching): continue for substs in itertools.product(*(bipartite[edge] for edge in matching.items())): try: bipartite_substitution = substitution.union(*substs) except ValueError: continue matched_subjects = Multiset(subexpression for subexpression, _ in matching) yield bipartite_substitution, matched_subjects def _match_sequence_variables( self, subjects: MultisetOfExpression, pattern_vars: Sequence[VariableWithCount], substitution: Substitution, ) -> Iterator[Substitution]: only_counts = [info for info, _ in pattern_vars] wrapped_vars = [name for (name, _, _, _), wrap in pattern_vars if wrap and name] for variable_substitution in commutative_sequence_variable_partition_iter(subjects, only_counts): for var in wrapped_vars: operands = variable_substitution[var] if isinstance(operands, (tuple, list, Multiset)): if len(operands) > 1: variable_substitution[var] = self.associative(*operands) else: variable_substitution[var] = next(iter(operands)) try: result_substitution = substitution.union(variable_substitution) except ValueError: continue yield result_substitution def _build_bipartite(self, subjects: MultisetOfInt, patterns: MultisetOfInt) -> Subgraph: bipartite = BipartiteGraph() n = 0 m = 0 p_states = {} for subject, s_count in subjects.items(): if (LEFT, subject) in self.bipartite._graph: any_patterns = False for _, pattern in self.bipartite._graph[LEFT, subject]: if pattern in patterns: any_patterns = True subst = self.bipartite[subject, pattern] p_count = patterns[pattern] if pattern in p_states: p_start = p_states[pattern] else: p_start = p_states[pattern] = m m += p_count for i in range(n, n + s_count): for j in range(p_start, p_start + p_count): bipartite[(subject, i), (pattern, j)] = subst if any_patterns: n += s_count return bipartite def _is_canonical_matching(self, matching: 
Matching) -> bool: anonymous_patterns = self.anonymous_patterns for (s1, n1), (p1, m1) in matching.items(): for (s2, n2), (p2, m2) in matching.items(): if p1 in anonymous_patterns and p2 in anonymous_patterns: if n1 < n2 and m1 > m2: return False elif s1 == s2 and n1 < n2 and m1 > m2: return False return True def bipartite_as_graph(self) -> Graph: # pragma: no cover """Returns a :class:`graphviz.Graph` representation of this bipartite graph.""" if Graph is None: raise ImportError('The graphviz package is required to draw the graph.') graph = Graph() nodes_left = {} # type: Dict[TLeft, str] nodes_right = {} # type: Dict[TRight, str] node_id = 0 for (left, right), value in self.bipartite._edges.items(): if left not in nodes_left: name = 'node{:d}'.format(node_id) nodes_left[left] = name label = str(self.subjects_by_id[left]) graph.node(name, label=label) node_id += 1 if right not in nodes_right: name = 'node{:d}'.format(node_id) nodes_right[right] = name label = str(self.automaton.patterns[right][0]) graph.node(name, label=label) node_id += 1 edge_label = value is not True and str(value) or '' graph.edge(nodes_left[left], nodes_right[right], edge_label) return graph def concrete_bipartite_as_graph(self, subjects, patterns) -> Graph: # pragma: no cover """Returns a :class:`graphviz.Graph` representation of this bipartite graph.""" if Graph is None: raise ImportError('The graphviz package is required to draw the graph.') bipartite = self._build_bipartite(subjects, patterns) graph = Graph() nodes_left = {} # type: Dict[TLeft, str] nodes_right = {} # type: Dict[TRight, str] node_id = 0 for (left, right), value in bipartite._edges.items(): if left not in nodes_left: subject_id, i = left name = 'node{:d}'.format(node_id) nodes_left[left] = name label = '{}, {}'.format(i, self.subjects_by_id[subject_id]) graph.node(name, label=label) node_id += 1 if right not in nodes_right: pattern, i = right name = 'node{:d}'.format(node_id) nodes_right[right] = name label = '{}, 
{}'.format(i, self.automaton.patterns[pattern][0]) graph.node(name, label=label) node_id += 1 edge_label = value is not True and str(value) or '' graph.edge(nodes_left[left], nodes_right[right], edge_label) return graph class SecondaryAutomaton(): # pragma: no cover # TODO: Decide whether to integrate this def __init__(self, k): self.k = k self.states = self._build(k) def match(self, edges): raise NotImplementedError @staticmethod def _build(k): states = dict() queue = [frozenset([0])] while queue: state_id = queue.pop(0) state = states[state_id] = dict() for i in range(1, 2**k): new_state = set() for t in [2**j for j in range(k) if i & 2**j]: for v in state_id: new_state.add(t | v) new_state = frozenset(new_state - state_id) if new_state: if new_state != state_id: state[i] = new_state if new_state not in states and new_state not in queue: queue.append(new_state) keys = sorted(states.keys()) new_states = [] for state in keys: new_states.append(states[state]) for i, state in enumerate(new_states): new_state = dict() for key, value in state.items(): new_state[key] = keys.index(value) new_states[i] = new_state return new_states def as_graph(self): if Digraph is None: raise ImportError('The graphviz package is required to draw the graph.') graph = Digraph() for i in range(len(self.states)): graph.node(str(i), str(i)) for state, edges in enumerate(self.states): for target, labels in itertools.groupby(sorted(edges.items()), key=itemgetter(1)): label = '\n'.join(bin(l)[2:].zfill(self.k) for l, _ in labels) graph.edge(str(state), str(target), label) return graph