"""Pattern matching module. This module contains classes to detect a pattern in an expression, and eventually replace it with another expression if found. Classes and methods included in this module are: - EvalPattern: replaces wildcards in a pattern with their supposed values. - PatternMatcher: returns true if pattern is matched on expression. - match: same as PatternMatcher, but with pre-processing applied first - PatternReplacement: takes pattern, replacement expression and target expression as input ; if pattern is found in target expression, replaces it with replacement expression - replace: same as PatternReplacement, but with pre-processing applied first. """ import ast from copy import deepcopy import itertools import astunparse try: import z3 except ImportError: raise Exception("z3 module is needed to use this pattern matcher") from sspam.tools import asttools from sspam.tools.flattening import Flattening, Unflattening from sspam import pre_processing # If set to true, pattern matcher will use z3 to match patterns FLEXIBLE = True class EvalPattern(ast.NodeTransformer): """ Replace wildcards in pattern with supposed values. """ def __init__(self, wildcards): self.wildcards = wildcards def visit_Name(self, node): 'Replace wildcards with supposed value' if node.id in self.wildcards: return deepcopy(self.wildcards[node.id]) return node class PatternMatcher(asttools.Comparator): """ Try to match desired pattern with given ast. Wildcards are indicated with upper letters : A, B, ... Example : A + B will match (x | 34) + (y*67) """ def __init__(self, root, nbits=0): 'Init different components of pattern matcher' super(PatternMatcher, self).__init__() # wildcards used in the pattern with their possible values self.wildcards = {} # wildcards <-> values that are known not to work self.no_solution = [] # root node of expression if isinstance(root, ast.Module): self.root = root.body[0].value elif isinstance(root, ast.Expression): self.root = root.body else: self.root = root if not nbits: self.nbits = asttools.get_default_nbits(self.root) else: self.nbits = nbits # identifiers for z3 evaluation getid = asttools.GetIdentifiers() getid.visit(self.root) self.variables = getid.variables self.functions = getid.functions @staticmethod def is_wildcard(node): 'Check if node is wildcard' return isinstance(node, ast.Name) and node.id.isupper() def check_eq_z3(self, target, pattern): 'Check equivalence with z3' # pylint: disable=exec-used getid = asttools.GetIdentifiers() getid.visit(target) if getid.functions: # not checking exprs with functions for now, because Z3 # does not seem to support function declaration with # arbitrary number of arguments return False for var in self.variables: exec("%s = z3.BitVec('%s', %d)" % (var, var, self.nbits)) target_ast = deepcopy(target) target_ast = Unflattening().visit(target_ast) ast.fix_missing_locations(target_ast) code1 = compile(ast.Expression(target_ast), '<string>', mode='eval') eval_pattern = deepcopy(pattern) EvalPattern(self.wildcards).visit(eval_pattern) eval_pattern = Unflattening().visit(eval_pattern) ast.fix_missing_locations(eval_pattern) getid.reset() getid.visit(eval_pattern) if getid.functions: # same reason as before, not using Z3 if there are # functions return False gvar = asttools.GetIdentifiers() gvar.visit(eval_pattern) if any(var.isupper() for var in gvar.variables): # do not check if all patterns have not been replaced return False code2 = compile(ast.Expression(eval_pattern), '<string>', mode='eval') sol = z3.Solver() if isinstance(eval(code1), int) and eval(code1) == 0: # cases where target == 0 are too permissive return False sol.add(eval(code1) != eval(code2)) return sol.check().r == -1 def check_wildcard(self, target, pattern): 'Check wildcard value or affect it' if pattern.id in self.wildcards: wild_value = self.wildcards[pattern.id] exact_comp = asttools.Comparator().visit(wild_value, target) if exact_comp: return True if FLEXIBLE: return self.check_eq_z3(target, self.wildcards[pattern.id]) else: return False else: self.wildcards[pattern.id] = target return True def get_model(self, target, pattern): 'When target is constant and wildcards have no value yet' # pylint: disable=exec-used if target.n == 0: # zero is too permissive return False getwild = asttools.GetIdentifiers() getwild.visit(pattern) if getwild.functions: # not getting model for expr with functions return False wilds = getwild.variables # let's reduce the model to one wildcard for now # otherwise it adds a lot of checks... if len(wilds) > 1: return False wil = wilds.pop() if wil in self.wildcards: if not isinstance(self.wildcards[wil], ast.Num): return False folded = deepcopy(pattern) folded = Unflattening().visit(folded) EvalPattern(self.wildcards).visit(folded) folded = asttools.ConstFolding(folded, self.nbits).visit(folded) return folded.n == target.n else: exec("%s = z3.BitVec('%s', %d)" % (wil, wil, self.nbits)) eval_pattern = deepcopy(pattern) eval_pattern = Unflattening().visit(eval_pattern) ast.fix_missing_locations(eval_pattern) code = compile(ast.Expression(eval_pattern), '<string>', mode='eval') sol = z3.Solver() sol.add(target.n == eval(code)) if sol.check().r == 1: model = sol.model() for inst in model.decls(): self.wildcards[str(inst)] = ast.Num(int(model[inst].as_long())) return True return False def check_not(self, target, pattern): 'Check NOT pattern node that could be in another form' if self.is_wildcard(pattern.operand): wkey = pattern.operand.id if isinstance(target, ast.Num): if wkey not in self.wildcards: mod = 2**self.nbits self.wildcards[wkey] = ast.Num((~target.n) % mod) return True else: wilds2 = self.wildcards[pattern.operand.id] num = ast.Num((~target.n) % 2**self.nbits) return asttools.Comparator().visit(wilds2, num) else: if wkey not in self.wildcards: self.wildcards[wkey] = ast.UnaryOp(ast.Invert(), target) return True return self.check_eq_z3(target, pattern) else: subpattern = pattern.operand newtarget = ast.UnaryOp(ast.Invert(), target) return self.check_eq_z3(newtarget, subpattern) def check_neg(self, target, pattern): 'Check (-1)*... pattern that could be in another form' if self.is_wildcard(pattern.right): wkey = pattern.right.id if isinstance(target, ast.Num): if wkey not in self.wildcards: mod = 2**self.nbits self.wildcards[wkey] = ast.Num((-target.n) % mod) return True else: wilds2 = self.wildcards[pattern.right.id] num = ast.Num((-target.n) % 2**self.nbits) return asttools.Comparator().visit(wilds2, num) else: if wkey not in self.wildcards: self.wildcards[wkey] = ast.BinOp(ast.Num(-1), ast.Mult(), target) return True return self.check_eq_z3(target, pattern) def check_twomult(self, target, pattern): 'Check 2*... pattern that could be in another form' if isinstance(pattern.left, ast.Num) and pattern.left.n == 2: operand = pattern.right elif isinstance(pattern.right, ast.Num) and pattern.right.n == 2: operand = pattern.left else: return False # deal with case where wildcard operand and target are const values if isinstance(target, ast.Num) and isinstance(operand, ast.Name): conds = (operand.id in self.wildcards and isinstance(self.wildcards[operand.id], ast.Num)) if conds: eva = (self.wildcards[operand.id].n)*2 % 2**(self.nbits) if eva == target.n: return True else: if target.n % 2 == 0: self.wildcards[operand.id] = ast.Num(target.n / 2) return True return False # get all wildcards in operand and check if they have value getwild = asttools.GetIdentifiers() getwild.visit(operand) wilds = getwild.variables for wil in wilds: if wil not in self.wildcards: return False return self.check_eq_z3(target, pattern) def general_check(self, target, pattern): 'General check, very time-consuming, not used at the moment' getwild = asttools.GetIdentifiers() getwild.visit(pattern) wilds = list(getwild.variables) if all(wil in self.wildcards for wil in wilds): eval_pattern = deepcopy(pattern) eval_pattern = EvalPattern(self.wildcards).visit(eval_pattern) return self.check_eq_z3(target, eval_pattern) return False def check_pattern(self, target, pattern): 'Try to match pattern written in different ways' if asttools.CheckConstExpr().visit(pattern): if isinstance(target, ast.Num): # if pattern is only a constant, evaluate and compare # to target pattcopy = deepcopy(pattern) eval_pat = asttools.ConstFolding(pattcopy, self.nbits).visit(pattcopy) return self.visit(target, eval_pat) if isinstance(target, ast.Num): # check that wildcards in pattern have not been affected return self.get_model(target, pattern) # deal with NOT that could have been evaluated before notnode = (isinstance(pattern, ast.UnaryOp) and isinstance(pattern.op, ast.Invert)) if notnode: return self.check_not(target, pattern) # deal with (-1)*B that could have been evaluated negnode = (isinstance(pattern, ast.BinOp) and isinstance(pattern.op, ast.Mult) and isinstance(pattern.left, ast.Num) and pattern.left.n == -1) if negnode: return self.check_neg(target, pattern) # deal with 2*B multnode = (isinstance(pattern, ast.BinOp) and isinstance(pattern.op, ast.Mult)) if multnode: return self.check_twomult(target, pattern) # return self.general_check(target, pattern) return False def visit(self, target, pattern): 'Deal with corner cases before using classic comparison' # if pattern contains is a wildcard, check value against target # or affect it if self.is_wildcard(pattern): return self.check_wildcard(target, pattern) # if types are different, we might be facing the same pattern # written differently if type(target) != type(pattern): if FLEXIBLE: return self.check_pattern(target, pattern) else: return False # get type of node to call the right visit_ method nodetype = target.__class__.__name__ comp = getattr(self, "visit_%s" % nodetype, None) if not comp: raise Exception("no comparison function for %s" % nodetype) return comp(target, pattern) def visit_Num(self, target, pattern): 'Check if num values are equal modulo 2**nbits' mod = 2**self.nbits return (target.n % mod) == (pattern.n % mod) def visit_BinOp(self, target, pattern): 'Check type of operation and operands' # pylint: disable=too-many-branches if type(target.op) != type(pattern.op): if FLEXIBLE: return self.check_pattern(target, pattern) else: return False # if operation is commutative, left and right operands are # interchangeable previous_state = deepcopy(self.wildcards) cond1 = (self.visit(target.left, pattern.left) and self.visit(target.right, pattern.right)) state = asttools.apply_hooks() nos = self.wildcards in self.no_solution asttools.restore_hooks(state) if cond1 and not nos: return True if nos: self.wildcards = deepcopy(previous_state) if not cond1 and not nos: # different visiting order might give different results wildsbackup = deepcopy(self.wildcards) self.wildcards = deepcopy(previous_state) cond1_prime = (self.visit(target.right, pattern.right) and self.visit(target.left, pattern.left)) if cond1_prime: return True else: self.wildcards = deepcopy(wildsbackup) # commutative operators if isinstance(target.op, (ast.Add, ast.Mult, ast.BitAnd, ast.BitOr, ast.BitXor)): cond2 = (self.visit(target.left, pattern.right) and self.visit(target.right, pattern.left)) if cond2: return True wildsbackup = deepcopy(self.wildcards) self.wildcards = deepcopy(previous_state) cond2_prime = (self.visit(target.right, pattern.left) and self.visit(target.left, pattern.right)) if cond2_prime: return True else: self.wildcards = deepcopy(wildsbackup) # if those affectations don't work, try with another order if target == self.root: self.no_solution.append(self.wildcards) self.wildcards = deepcopy(previous_state) cond1 = (self.visit(target.left, pattern.left) and self.visit(target.right, pattern.right)) if cond1: return True cond2 = (self.visit(target.left, pattern.right) and self.visit(target.right, pattern.left)) return cond1 or cond2 self.wildcards = deepcopy(previous_state) return False def visit_BoolOp(self, target, pattern): 'Match pattern on flattened operators of same length and same type' conds = (type(target.op) == type(pattern.op) and len(target.values) == len(pattern.values)) if not conds: return False # try every combination wildcard <=> value old_context = deepcopy(self.wildcards) for perm in itertools.permutations(target.values): self.wildcards = deepcopy(old_context) res = True i = 0 for i in range(len(pattern.values)): res &= self.visit(perm[i], pattern.values[i]) if res: return res return False def visit_UnaryOp(self, target, pattern): 'Match type of UnaryOp and operands' if type(target.op) != type(pattern.op): return False return self.visit(target.operand, pattern.operand) def visit_Call(self, target, pattern): 'Match name of Call and arguments' if (not self.visit(target.func, pattern.func) or len(target.args) != len(pattern.args)): return False if (not all([self.visit(t_arg, p_arg) for t_arg, p_arg in zip(target.args, pattern.args)]) or not all([self.visit(t_key, p_key) for t_key, p_key in zip(target.keywords, pattern.keywords)])): return False # only dealing with None starags and kwards for the moment if (not (target.starargs is None and pattern.starargs is None) or not (target.kwargs is None and pattern.kwargs is None)): return False return True def match(target_str, pattern_str): 'Apply all pre-processing, then pattern matcher' target_ast = ast.parse(target_str, mode="eval").body target_ast = pre_processing.all_preprocessings(target_ast) target_ast = Flattening(ast.Add).visit(target_ast) pattern_ast = ast.parse(pattern_str, mode="eval").body pattern_ast = pre_processing.all_preprocessings(pattern_ast) pattern_ast = Flattening(ast.Add).visit(pattern_ast) return PatternMatcher(target_ast).visit(target_ast, pattern_ast) class PatternReplacement(ast.NodeTransformer): """ Test if a pattern is included in an expression, and replace it if found. """ def __init__(self, patt_ast, target_ast, rep_ast, nbits=0): 'Pattern ast should have as root: BinOp, BoolOp, UnaryOp or Call' if isinstance(patt_ast, ast.Module): self.patt_ast = patt_ast.body[0].value elif isinstance(patt_ast, ast.Expression): self.patt_ast = patt_ast.body else: self.patt_ast = patt_ast if isinstance(rep_ast, ast.Module): self.rep_ast = deepcopy(rep_ast.body[0].value) elif isinstance(rep_ast, ast.Expression): self.rep_ast = deepcopy(rep_ast.body) else: self.rep_ast = deepcopy(rep_ast) if not nbits: getsize = asttools.GetSize() getsize.visit(target_ast) if getsize.result: self.nbits = getsize.result # default bitsize is 8 else: self.nbits = 8 else: self.nbits = nbits def basic_visit(self, node): 'Check if node is matching the pattern, if not, visit children' pat = PatternMatcher(node, self.nbits) matched = pat.visit(node, self.patt_ast) if matched: repc = deepcopy(self.rep_ast) new_node = EvalPattern(pat.wildcards).visit(repc) return new_node else: return self.generic_visit(node) def visit_Call(self, node): 'No particular case for Call replacement' return self.basic_visit(node) def visit_BinOp(self, node): 'No particular case for BinOp replacement' return self.basic_visit(node) def visit_UnaryOp(self, node): 'No particular case for UnaryOp replacement' return self.basic_visit(node) def visit_BoolOp(self, node): 'Check if BoolOp is exaclty matching or contain pattern' if isinstance(self.patt_ast, ast.BoolOp): if len(node.values) == len(self.patt_ast.values): return self.basic_visit(node) elif len(node.values) > len(self.patt_ast.values): # associativity n to m for combi in itertools.combinations(node.values, len(self.patt_ast.values)): rest = [elem for elem in node.values if elem not in combi] testnode = ast.BoolOp(node.op, list(combi)) pat = PatternMatcher(testnode, self.nbits) matched = pat.visit(testnode, self.patt_ast) if matched: new = EvalPattern(pat.wildcards).visit(self.rep_ast) new = ast.BoolOp(node.op, [new] + rest) new = Unflattening().visit(new) return new return self.generic_visit(node) if isinstance(self.patt_ast, ast.BinOp): if type(node.op) != type(self.patt_ast.op): return self.generic_visit(node) op = node.op for combi in itertools.combinations(node.values, 2): rest = [elem for elem in node.values if elem not in combi] testnode = ast.BinOp(combi[0], op, combi[1]) pat = PatternMatcher(testnode, self.nbits) matched = pat.visit(testnode, self.patt_ast) if matched: new_node = EvalPattern(pat.wildcards).visit(self.rep_ast) new_node = ast.BoolOp(op, [new_node] + rest) new_node = Unflattening().visit(new_node) return new_node return self.generic_visit(node) def replace(target_str, pattern_str, replacement_str): 'Apply pre-processing and replace' target_ast = ast.parse(target_str, mode="eval").body target_ast = pre_processing.all_preprocessings(target_ast) target_ast = Flattening(ast.Add).visit(target_ast) patt_ast = ast.parse(pattern_str, mode="eval").body patt_ast = pre_processing.all_preprocessings(patt_ast) patt_ast = Flattening(ast.Add).visit(patt_ast) rep_ast = ast.parse(replacement_str) rep = PatternReplacement(patt_ast, target_ast, rep_ast) return rep.visit(target_ast) # Used for debug purposes: if __name__ == '__main__': # pylint: disable=invalid-name patt_string = "A + B - (A | B)" test = "f(g(x + x) + 3 + 4)" repl = "A & B" print match(test, patt_string) print "-"*80 out = replace(test, patt_string, repl) print ast.dump(out) out = Unflattening().visit(out) print astunparse.unparse(out)