"""Various functions and classes used to analyze and manipulate ast. - flatten, apply_hooks and restore_hooks are used to compare sets of ast. - get_default_nbits returns the default bitsize of an ast if it is different from zero, returns 8 otherwise. - GetIdentifiers collects every identifiers of an ast. - GetNums collects all numerals of an ast. - GetSize computes the default bitsize of an ast from its constants. - GetConstExpr gathers all constants math expressions from an ast. - CheckConstExpr checks if a given node is a constant expression. - ConstFolding applies constant folding (computes constants expressions). - ReplaceBitwiseOp replaces bitwise operators with functions. - ReplaceBitwiseFunctions replaces functions with bitwise operators. - GetConstMod replaces constants with their value modulo 2^n - Comparator is used to compare ast (modulo commutativity / associativity) """ import ast from sspam.tools.flattening import Unflattening def flatten(lis): 'Flatten a list' res = [] for elem in lis: if isinstance(elem, list): res.extend(flatten(elem)) else: res.append(elem) return res def apply_hooks(): 'Apply hooks to change hash and eq functions for ast elements' # pylint: disable=protected-access,unnecessary-lambda # backup ! backup_expr_hash = ast.expr.__hash__ backup_expr_eq = ast.expr.__eq__ backup_expr_context_hash = ast.expr_context.__hash__ backup_operator_hash = ast.operator.__hash__ # used for ast set comparison list_fields = lambda self: flatten([getattr(self, field) for field in self._fields]) hashboolop = lambda self: hash(tuple(sorted(map(hash, list_fields(self))))) ast.expr.__hash__ = hashboolop ast.expr.__eq__ = lambda self, other: Comparator().visit(self, other) ast.expr_context.__hash__ = lambda self: hash(type(self)) ast.operator.__hash__ = lambda self: hash(type(self)) ast.unaryop.__hash__ = lambda self: hash(type(self)) return (backup_expr_hash, backup_expr_eq, backup_expr_context_hash, backup_operator_hash) def restore_hooks(hooks): 'Restore classic hash and eq function for ast elements' ast.expr.__hash__ = hooks[0] ast.expr.__eq__ = hooks[1] ast.expr_context.__hash__ = hooks[2] ast.operator.__hash__ = hooks[3] def get_default_nbits(expr_ast): 'Computes default number of bits with size of constants' getsize = GetSize() getsize.visit(expr_ast) if getsize.result: nbits = getsize.result else: # default bitsize is 8 nbits = 8 return nbits class GetIdentifiers(ast.NodeVisitor): """ Get all identifiers (instances of ast.Name) of an ast. """ def __init__(self): 'Result contains identifiers of the ast' self.variables = set() self.functions = set() def reset(self): 'Empty result set, so that instance may be re-used' self.variables = set() self.functions = set() def visit_Name(self, node): 'Add node id to result' self.variables.add(node.id) def visit_Call(self, node): 'Add func id to result and visit argument' self.functions.add(node.func.id) for arg in node.args: self.visit(arg) class GetNums(ast.NodeVisitor): """ Get all numeric values (instances of ast.Num) of an ast. """ def __init__(self): 'Result contains numeric values of ast.Num nodes' self.result = set() def visit_Num(self, node): 'Add node value to result' self.result.add(node.n) class GetSize(ast.NodeVisitor): """ Get bitsize of ast: approximate with 2**8, 2**16... """ def __init__(self): 'Init nbits' self.result = 0 def reset(self): 'Empty result set, so that instance may be re-used' self.result = 0 def visit_Num(self, node): 'Approximate nbits with n power of two' bitlen = (abs(node.n)).bit_length() if bitlen > self.result: # didn't find a way to do this cleanly... if bitlen == 1 or bitlen == 2: self.result = bitlen elif bitlen == 3 or bitlen == 4: self.result = 4 elif bitlen > 4 and bitlen < 9: self.result = 8 elif bitlen > 8 and bitlen < 17: self.result = 16 elif bitlen > 16 and bitlen < 33: self.result = 32 elif bitlen > 32 and bitlen < 65: self.result = 64 else: raise Exception("Nbits not supported") class GetConstExpr(ast.NodeVisitor): """ Gathers all constant math expressions (with numbers only). Used in ConstantFolding. Code shamefully stolen from pythran. """ def __init__(self): 'Result contains all constant expressions' self.result = set() def reset(self): 'Empty result set, so that instance may be re-used' self.result = set() def add(self, node): 'Add a node to result and return True' self.result.add(node) return True # A Num node is constant by definition visit_Num = add def visit_BinOp(self, node): 'A BinOp node is a const expr if both its operands are' rec = all(map(self.visit, (node.left, node.right))) return rec and self.add(node) def visit_UnaryOp(self, node): 'A UnaryOp is a const expr if its operand is' return self.visit(node.operand) and self.add(node) class CheckConstExpr(ast.NodeVisitor): """ Check if given node is exactly a constant expression. """ def visit_Num(self, node): 'Num is always a constant' # pylint: disable=unused-argument, no-self-use return True def visit_BinOp(self, node): 'A BinOp node is a const expr if both its operands are' return all(map(self.visit, (node.left, node.right))) def visit_BoolOp(self, node): 'A BoolOp is a const expr if all its operands are' return all(map(self.visit, node.values)) def visit_UnaryOp(self, node): 'A UnaryOp is a const expr if its operand is' return self.visit(node.operand) class ConstFolding(ast.NodeTransformer): """ Applies constant folding on an ast. Also stolen from pythran. """ # pylint: disable=exec-used def __init__(self, node, nbits): 'Gather constant expressions' analyzer = GetConstExpr() analyzer.visit(node) self.constexpr = analyzer.result self.mod = 2**nbits def visit_BinOp(self, node): 'If node is a constant expression, replace it with its evaluated value' if node in self.constexpr: # evaluation fake_node = ast.Expression(ast.BinOp(node, ast.Mod(), ast.Num(self.mod))) ast.fix_missing_locations(fake_node) code = compile(fake_node, '<constant folding>', 'eval') obj_env = globals().copy() exec code in obj_env value = eval(code, obj_env) new_node = ast.Num(value) return new_node else: return self.generic_visit(node) def visit_BoolOp(self, node): 'A custom BoolOp can be used in flattened AST' if type(node.op) not in (ast.Add, ast.Mult, ast.BitXor, ast.BitAnd, ast.BitOr): return self.generic_visit(node) # get constant parts of node: list_cste = [child for child in node.values if isinstance(child, ast.Num)] if len(list_cste) < 2: return self.generic_visit(node) rest_values = [n for n in node.values if n not in list_cste] fake_node = Unflattening().visit(ast.BoolOp(node.op, list_cste)) fake_node = ast.Expression(fake_node) ast.fix_missing_locations(fake_node) code = compile(fake_node, '<constant folding>', 'eval') obj_env = globals().copy() exec code in obj_env value = eval(code, obj_env) new_node = ast.Num(value) rest_values.append(new_node) return ast.BoolOp(node.op, rest_values) def visit_UnaryOp(self, node): 'Same idea as visit_BinOp' if node in self.constexpr: # evaluation fake_node = ast.Expression(ast.BinOp(node, ast.Mod(), ast.Num(self.mod))) ast.fix_missing_locations(fake_node) code = compile(fake_node, '<constant folding>', 'eval') obj_env = globals().copy() exec code in obj_env value = eval(code, obj_env) new_node = ast.Num(value) return new_node else: return self.generic_visit(node) class ReplaceBitwiseOp(ast.NodeTransformer): """ Replace bitwise operations (&, |, ^, ~) with custom functions (mand, mor, mxor, mnot) so that expression may be used in sympy. """ def visit_BinOp(self, node): 'Replace bitwise operation with function call' self.generic_visit(node) if isinstance(node.op, ast.BitAnd): return ast.Call(ast.Name('mand', ast.Load()), [node.left, node.right], [], None, None) if isinstance(node.op, ast.BitOr): return ast.Call(ast.Name('mor', ast.Load()), [node.left, node.right], [], None, None) if isinstance(node.op, ast.BitXor): return ast.Call(ast.Name('mxor', ast.Load()), [node.left, node.right], [], None, None) if isinstance(node.op, ast.LShift): return ast.Call(ast.Name('mlshift', ast.Load()), [node.left, node.right], [], None, None) if isinstance(node.op, ast.RShift): return ast.Call(ast.Name('mrshift', ast.Load()), [node.left, node.right], [], None, None) return node def visit_UnaryOp(self, node): 'Replace bitwise unaryop with function call' self.generic_visit(node) if isinstance(node.op, ast.Invert): return ast.Call(ast.Name('mnot', ast.Load()), [node.operand], [], None, None) return node class ReplaceBitwiseFunctions(ast.NodeTransformer): """ Replace mand, mxor and mor with their respective operations. """ def visit_Call(self, node): 'Replace custom function with bitwise operators' self.generic_visit(node) if isinstance(node.func, ast.Name): if len(node.args) == 2: if node.func.id == "mand": op = ast.BitAnd() elif node.func.id == "mxor": op = ast.BitXor() elif node.func.id == "mor": op = ast.BitOr() elif node.func.id == "mlshift": op = ast.LShift() elif node.func.id == "mrshift": op = ast.RShift() else: return node return ast.BinOp(node.args[0], op, node.args[1]) elif len(node.args) == 1 and node.func.id == "mnot": arg = node.args[0] self.generic_visit(node) return ast.UnaryOp(ast.Invert(), arg) return self.generic_visit(node) class GetConstMod(ast.NodeTransformer): """ Replace constants with their value mod 2^n """ def __init__(self, nbits): self.nbits = nbits def visit_Num(self, node): 'Replace constant value with value mod 2^n' node.n = node.n % 2**self.nbits return node class Comparator(object): """ Compare two ast to check if they're equivalent """ # pylint: disable=no-self-use def __init__(self, commut=True): 'Specify if comparator is commutative or not' self.commut = commut def visit(self, node1, node2): 'Call appropriate visitor for matching types' if type(node1) != type(node2): return False # get type of node to call the right visit_ method nodetype = node1.__class__.__name__ comp = getattr(self, "visit_%s" % nodetype, None) if not comp: raise Exception("no comparison function for %s" % nodetype) return comp(node1, node2) def visit_Module(self, node1, node2): 'Check if body of are equivalent' if len(node1.body) != len(node2.body): return False for i in range(len(node1.body)): if not self.visit(node1.body[i], node2.body[i]): return False return True def visit_Expression(self, node1, node2): 'Check if bodies are the same' return self.visit(node1.body, node2.body) def visit_Expr(self, node1, node2): 'Check if value are equivalent' return self.visit(node1.value, node2.value) def visit_Call(self, node1, node2): 'Check func id and arguments' if node1.func.id != node2.func.id: return False return all(self.visit(arg1, arg2) for arg1, arg2 in zip(node1.args, node2.args)) def visit_BinOp(self, node1, node2): 'Check type of operation and operands' if type(node1.op) != type(node2.op): return False # if operation is commutative, left and right operands are # interchangeable cond1 = (self.visit(node1.left, node2.left) and self.visit(node1.right, node2.right)) cond2 = (self.visit(node1.left, node2.right) and self.visit(node1.right, node2.left)) # non-commutative comparator if not self.commut: return cond1 if isinstance(node1.op, (ast.Add, ast.Mult, ast.BitAnd, ast.BitOr, ast.BitXor)): if cond1 or cond2: return True else: return False else: if cond1: return True return False def visit_BoolOp(self, node1, node2): 'Check type of operation and operands (not considering order)' if type(node1.op) != type(node2.op): return False if len(node1.values) != len(node2.values): return False # redefine __hash__ for set comparison hooks = apply_hooks() # this implies that operation is associative / commutative result = set(node1.values) == set(node2.values) restore_hooks(hooks) return result def visit_UnaryOp(self, node1, node2): 'Check type of operation and operand' if type(node1.op) != type(node2.op): return False return self.visit(node1.operand, node2.operand) def visit_Assign(self, node1, node2): 'Compare targets and values' return (self.visit(node1.targets[0], node2.targets[0]) and self.visit(node1.value, node2.value)) def visit_Name(self, node1, node2): 'Check id' return node1.id == node2.id and type(node1.ctx) == type(node2.ctx) def visit_Num(self, node1, node2): 'Check num value' return node1.n == node2.n