"""Handle recomputation of values of a function given its abstract syntax tree and function frame."""
import ast
import builtins
import platform
import sys
from typing import Any, Mapping, Dict, List, Optional, Union, Tuple, Set, Callable  # pylint: disable=unused-import


class Placeholder:
    """Represent a placeholder for variables local to the lambda such as targets in generator expressions."""

    def __repr__(self) -> str:
        """Represent the placeholder as <Placeholder>."""
        return "<Placeholder>"


PLACEHOLDER = Placeholder()


class Visitor(ast.NodeVisitor):
    """
    Traverse the abstract syntax tree and recompute the value of each node given the variables of the function frame.

    :ivar recomputed_values: mapping node -> value assigned to each visited node
    :type recomputed_values: Mapping[ast.AST, Any]

    """

    # pylint: disable=invalid-name
    # pylint: disable=missing-docstring
    # pylint: disable=too-many-public-methods

    def __init__(self, variable_lookup: List[Mapping[str, Any]]) -> None:
        """
        Initialize.

        :param variable_lookup: list of lookup tables to look up the values of the variables, sorted by precedence
        """
        # Resolve precedence of variable lookup: the first table that defines a name wins.
        self._name_to_value = dict()  # type: Dict[str, Any]
        for lookup in variable_lookup:
            for name, value in lookup.items():
                if name not in self._name_to_value:
                    self._name_to_value[name] = value

        # value assigned to each visited node
        self.recomputed_values = dict()  # type: Dict[ast.AST, Any]

    def visit_Num(self, node: ast.Num) -> Union[int, float, complex]:
        """Recompute the value as the number at the node."""
        result = node.n
        self.recomputed_values[node] = result

        # Imaginary literals such as ``1j`` are numbers, too.
        assert isinstance(result, (int, float, complex))
        return result

    def visit_Str(self, node: ast.Str) -> str:
        """Recompute the value as the string at the node."""
        result = node.s
        self.recomputed_values[node] = result
        return result

    def visit_Bytes(self, node: ast.Bytes) -> bytes:
        """Recompute the value as the bytes at the node."""
        result = node.s
        self.recomputed_values[node] = result
        return result

    def visit_List(self, node: ast.List) -> List[Any]:
        """Visit the elements and assemble the results into a list."""
        if isinstance(node.ctx, ast.Store):
            raise NotImplementedError("Cannot compute the value of a Store on a list")

        result = [self.visit(node=elt) for elt in node.elts]
        self.recomputed_values[node] = result
        return result

    def visit_Tuple(self, node: ast.Tuple) -> Tuple[Any, ...]:
        """Visit the elements and assemble the results into a tuple."""
        if isinstance(node.ctx, ast.Store):
            raise NotImplementedError("Cannot compute the value of a Store on a tuple")

        result = tuple(self.visit(node=elt) for elt in node.elts)
        self.recomputed_values[node] = result
        return result

    def visit_Set(self, node: ast.Set) -> Set[Any]:
        """Visit the elements and assemble the results into a set."""
        result = set(self.visit(node=elt) for elt in node.elts)
        self.recomputed_values[node] = result
        return result

    def visit_Dict(self, node: ast.Dict) -> Dict[Any, Any]:
        """Visit keys and values and assemble a dictionary with the results."""
        recomputed_dict = dict()  # type: Dict[Any, Any]
        for key, val in zip(node.keys, node.values):
            assert isinstance(val, ast.AST)

            if key is None:
                # A missing key marks dictionary unpacking such as ``{**other}``.
                recomputed_dict.update(self.visit(node=val))
            else:
                # Visit the key before the value to follow Python's evaluation order.
                recomputed_key = self.visit(node=key)
                recomputed_dict[recomputed_key] = self.visit(node=val)

        self.recomputed_values[node] = recomputed_dict
        return recomputed_dict

    def visit_NameConstant(self, node: ast.NameConstant) -> Any:
        """Forward the node value as a result."""
        self.recomputed_values[node] = node.value
        return node.value

    def visit_Name(self, node: ast.Name) -> Any:
        """Load the variable by looking it up in the variable look-up and in the built-ins."""
        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError("Can only compute a value of Load on a name {}, but got context: {}".format(
                node.id, node.ctx))

        # Check membership rather than testing the value for None so that variables bound to None are found.
        if node.id in self._name_to_value:
            result = self._name_to_value[node.id]
        elif hasattr(builtins, node.id):
            result = getattr(builtins, node.id)
        else:
            # The variable refers to a name local to the lambda (e.g., a target in the generator expression).
            # Since we evaluate generator expressions with runtime compilation, PLACEHOLDER is returned here
            # and no value is recorded for the node.
            return PLACEHOLDER

        self.recomputed_values[node] = result
        return result

    def visit_Expr(self, node: ast.Expr) -> Any:
        """Visit the node's ``value``."""
        result = self.visit(node=node.value)

        self.recomputed_values[node] = result
        return result

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        """Visit the node operand and apply the operation on the result."""
        if isinstance(node.op, ast.UAdd):
            result = +self.visit(node=node.operand)
        elif isinstance(node.op, ast.USub):
            result = -self.visit(node=node.operand)
        elif isinstance(node.op, ast.Not):
            result = not self.visit(node=node.operand)
        elif isinstance(node.op, ast.Invert):
            result = ~self.visit(node=node.operand)
        else:
            raise NotImplementedError("Unhandled op of {}: {}".format(node, node.op))

        self.recomputed_values[node] = result
        return result

    def visit_BinOp(self, node: ast.BinOp) -> Any:
        """Recursively visit the left and right operand, respectively, and apply the operation on the results."""
        # pylint: disable=too-many-branches
        left = self.visit(node=node.left)
        right = self.visit(node=node.right)

        if isinstance(node.op, ast.Add):
            result = left + right
        elif isinstance(node.op, ast.Sub):
            result = left - right
        elif isinstance(node.op, ast.Mult):
            result = left * right
        elif isinstance(node.op, ast.Div):
            result = left / right
        elif isinstance(node.op, ast.FloorDiv):
            result = left // right
        elif isinstance(node.op, ast.Mod):
            result = left % right
        elif isinstance(node.op, ast.Pow):
            result = left**right
        elif isinstance(node.op, ast.LShift):
            result = left << right
        elif isinstance(node.op, ast.RShift):
            result = left >> right
        elif isinstance(node.op, ast.BitOr):
            result = left | right
        elif isinstance(node.op, ast.BitXor):
            result = left ^ right
        elif isinstance(node.op, ast.BitAnd):
            result = left & right
        elif isinstance(node.op, ast.MatMult):
            result = left @ right
        else:
            raise NotImplementedError("Unhandled op of {}: {}".format(node, node.op))

        self.recomputed_values[node] = result
        return result

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        """Visit the operands from left to right and short-circuit like Python's ``and`` and ``or``."""
        if isinstance(node.op, ast.And):
            # ``and`` returns the first falsy operand, or the last operand if all are truthy.
            result = True  # type: Any
            for value_node in node.values:
                result = self.visit(node=value_node)
                if not result:
                    break
        elif isinstance(node.op, ast.Or):
            # ``or`` returns the first truthy operand, or the last operand if all are falsy.
            result = False
            for value_node in node.values:
                result = self.visit(node=value_node)
                if result:
                    break
        else:
            raise NotImplementedError("Unhandled op of {}: {}".format(node, node.op))

        self.recomputed_values[node] = result
        return result

    def visit_Compare(self, node: ast.Compare) -> Any:
        """Visit the comparators from left to right and chain the comparisons, short-circuiting like Python."""
        # pylint: disable=too-many-branches
        left = self.visit(node=node.left)

        result = True  # type: Any
        for op, comparator_node in zip(node.ops, node.comparators):
            comparator = self.visit(node=comparator_node)

            if isinstance(op, ast.Eq):
                result = left == comparator
            elif isinstance(op, ast.NotEq):
                result = left != comparator
            elif isinstance(op, ast.Lt):
                result = left < comparator
            elif isinstance(op, ast.LtE):
                result = left <= comparator
            elif isinstance(op, ast.Gt):
                result = left > comparator
            elif isinstance(op, ast.GtE):
                result = left >= comparator
            elif isinstance(op, ast.Is):
                result = left is comparator
            elif isinstance(op, ast.IsNot):
                result = left is not comparator
            elif isinstance(op, ast.In):
                result = left in comparator
            elif isinstance(op, ast.NotIn):
                result = left not in comparator
            else:
                raise NotImplementedError("Unhandled op of {}: {}".format(node, op))

            # ``a < b < c`` means ``a < b and b < c``, so stop at the first false comparison.
            if not result:
                break

            left = comparator

        self.recomputed_values[node] = result
        return result

    def visit_Call(self, node: ast.Call) -> Any:
        """Visit the function and the arguments and finally make the function call with them."""
        func = self.visit(node=node.func)

        args = []  # type: List[Any]
        for arg_node in node.args:
            if isinstance(arg_node, ast.Starred):
                # Unpack ``*iterable`` by visiting the value of the starred expression.
                args.extend(self.visit(node=arg_node.value))
            else:
                args.append(self.visit(node=arg_node))

        kwargs = dict()  # type: Dict[str, Any]
        for keyword in node.keywords:
            if keyword.arg is None:
                # ``keyword.arg`` is None for ``**mapping`` unpacking.
                kwargs.update(self.visit(node=keyword.value))
            else:
                kwargs[keyword.arg] = self.visit(node=keyword.value)

        result = func(*args, **kwargs)

        self.recomputed_values[node] = result
        return result

    def visit_IfExp(self, node: ast.IfExp) -> Any:
        """Visit the ``test``, and depending on its outcome, the ``body`` or ``orelse``."""
        test = self.visit(node=node.test)

        if test:
            result = self.visit(node=node.body)
        else:
            result = self.visit(node=node.orelse)

        self.recomputed_values[node] = result
        return result

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Visit the node's ``value`` and get the attribute from the result."""
        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError(
                "Can only compute a value of Load on the attribute {}, but got context: {}".format(node.attr, node.ctx))

        value = self.visit(node=node.value)

        result = getattr(value, node.attr)

        self.recomputed_values[node] = result
        return result

    def visit_Index(self, node: ast.Index) -> Any:
        """Visit the node's ``value``."""
        result = self.visit(node=node.value)

        self.recomputed_values[node] = result
        return result

    def visit_Slice(self, node: ast.Slice) -> slice:
        """Visit ``lower``, ``upper`` and ``step`` and recompute the node as a ``slice``."""
        lower = None  # type: Optional[int]
        if node.lower is not None:
            lower = self.visit(node=node.lower)

        upper = None  # type: Optional[int]
        if node.upper is not None:
            upper = self.visit(node=node.upper)

        step = None  # type: Optional[int]
        if node.step is not None:
            step = self.visit(node=node.step)

        result = slice(lower, upper, step)

        self.recomputed_values[node] = result
        return result

    def visit_ExtSlice(self, node: ast.ExtSlice) -> Tuple[Any, ...]:
        """Visit each dimension of the advanced slicing and assemble the dimensions in a tuple."""
        result = tuple(self.visit(node=dim) for dim in node.dims)

        self.recomputed_values[node] = result
        return result

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        """Visit the ``value`` and the ``slice`` and get the element."""
        value = self.visit(node=node.value)
        a_slice = self.visit(node=node.slice)

        result = value[a_slice]

        self.recomputed_values[node] = result
        return result

    def _execute_comprehension(self, node: Union[ast.ListComp, ast.SetComp, ast.GeneratorExp, ast.DictComp]) -> Any:
        """Compile the generator or comprehension from the node and execute the compiled code."""
        # Wrap the node as ``def generator_expr(<all known names>): return <node>``.
        args = [ast.arg(arg=name) for name in sorted(self._name_to_value.keys())]

        # Compare versions numerically: as string tuples, ("3", "10") would sort before ("3", "8").
        if sys.version_info < (3, ):
            raise NotImplementedError("Python versions below 3 are not supported, got: {}".format(
                platform.python_version()))

        if sys.version_info < (3, 8):
            func_def_node = ast.FunctionDef(
                name="generator_expr",
                args=ast.arguments(args=args, kwonlyargs=[], kw_defaults=[], defaults=[]),
                decorator_list=[],
                body=[ast.Return(node)])

            module_node = ast.Module(body=[func_def_node])
        else:
            # Python 3.8 added the required ``posonlyargs`` and ``type_ignores`` fields.
            func_def_node = ast.FunctionDef(
                name="generator_expr",
                args=ast.arguments(args=args, posonlyargs=[], kwonlyargs=[], kw_defaults=[], defaults=[]),
                decorator_list=[],
                body=[ast.Return(node)])

            module_node = ast.Module(body=[func_def_node], type_ignores=[])

        ast.fix_missing_locations(module_node)

        code = compile(source=module_node, filename='<ast>', mode='exec')

        module_locals = {}  # type: Dict[str, Any]
        module_globals = {}  # type: Dict[str, Any]
        exec(code, module_globals, module_locals)  # pylint: disable=exec-used

        generator_expr_func = module_locals["generator_expr"]

        return generator_expr_func(**self._name_to_value)

    def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
        """Compile the generator expression as a function and call it."""
        result = self._execute_comprehension(node=node)

        for generator in node.generators:
            self.visit(generator.iter)

        # Do not set the computed value of the node since its representation would be non-informative.
        return result

    def visit_ListComp(self, node: ast.ListComp) -> Any:
        """Compile the list comprehension as a function and call it."""
        result = self._execute_comprehension(node=node)

        for generator in node.generators:
            self.visit(generator.iter)

        self.recomputed_values[node] = result
        return result

    def visit_SetComp(self, node: ast.SetComp) -> Any:
        """Compile the set comprehension as a function and call it."""
        result = self._execute_comprehension(node=node)

        for generator in node.generators:
            self.visit(generator.iter)

        self.recomputed_values[node] = result
        return result

    def visit_DictComp(self, node: ast.DictComp) -> Any:
        """Compile the dictionary comprehension as a function and call it."""
        result = self._execute_comprehension(node=node)

        for generator in node.generators:
            self.visit(generator.iter)

        self.recomputed_values[node] = result
        return result

    def visit_Lambda(self, node: ast.Lambda) -> Callable[..., Any]:
        """Raise an exception: in-line lambdas are tricky to recompute and will be supported only on request."""
        raise NotImplementedError(
            "Recomputation of in-line lambda functions is not supported since it is quite tricky to implement and "
            "we decided to implement it only once there is a real need for it. "
            "Please make a feature request on https://github.com/Parquery/icontract")

    def visit_Return(self, node: ast.Return) -> Any:  # pylint: disable=no-self-use
        """Raise an exception that this node is unexpected."""
        raise AssertionError("Unexpected return node during the re-computation: {}".format(ast.dump(node)))

    def generic_visit(self, node: ast.AST) -> None:
        """Raise an exception that this node has not been handled."""
        raise NotImplementedError("Unhandled recomputation of the node: {} {}".format(type(node), node))
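

# Illustrative sketch, not used by ``Visitor``: the precedence rule in ``Visitor.__init__`` (the first
# lookup table that defines a name wins) matches how ``collections.ChainMap`` searches its maps in order.
# ``resolve_precedence`` is a hypothetical helper name chosen only for this example.
from collections import ChainMap
from typing import Any, Dict, List, Mapping


def resolve_precedence(variable_lookup: List[Mapping[str, Any]]) -> Dict[str, Any]:
    """Flatten the lookup tables into one dictionary; earlier tables take precedence."""
    # ChainMap returns the value from the first map that contains the key,
    # and dict() materializes that view for every name in any of the tables.
    return dict(ChainMap(*variable_lookup))


# Usage: resolve_precedence([{"x": 1}, {"x": 2, "y": 3}]) gives {"x": 1, "y": 3}.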
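

# Illustrative sketch: ``ShortCircuitSketch`` is a hypothetical, stripped-down recomputer, not part of this
# module. It shows why ``visit_BoolOp`` and ``visit_Compare`` should evaluate operands lazily: as in Python itself,
# ``x is not None and x > 0`` must never evaluate ``x > 0`` when ``x`` is None, and ``a < b < c`` stops at the
# first false comparison. It also dispatches comparisons through a table of functions from the standard
# ``operator`` module, one possible alternative to an if/elif chain. Assumes Python 3.8+, where the parser
# emits ``ast.Constant`` for literals.
import ast
import builtins
import operator
from typing import Any, Callable, Dict

_COMPARE_OPS = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.Is: operator.is_,
    ast.IsNot: operator.is_not,
    # ``operator.contains(container, item)`` tests ``item in container``, hence the swapped arguments.
    ast.In: lambda left, right: operator.contains(right, left),
    ast.NotIn: lambda left, right: not operator.contains(right, left),
}  # type: Dict[type, Callable[[Any, Any], Any]]


class ShortCircuitSketch(ast.NodeVisitor):
    """Recompute names, constants, boolean operations and comparisons, recording each visited node."""

    def __init__(self, name_to_value: Dict[str, Any]) -> None:
        self.name_to_value = name_to_value
        self.recomputed_values = dict()  # type: Dict[ast.AST, Any]

    def _record(self, node: ast.AST, value: Any) -> Any:
        self.recomputed_values[node] = value
        return value

    def visit_Expression(self, node: ast.Expression) -> Any:
        return self.visit(node.body)

    def visit_Constant(self, node: ast.Constant) -> Any:
        return self._record(node, node.value)

    def visit_Name(self, node: ast.Name) -> Any:
        if node.id in self.name_to_value:
            return self._record(node, self.name_to_value[node.id])
        return self._record(node, getattr(builtins, node.id))

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        # ``and`` stops at the first falsy operand, ``or`` at the first truthy one.
        stop_when = isinstance(node.op, ast.Or)
        result = None  # type: Any
        for value_node in node.values:
            result = self.visit(value_node)
            if bool(result) == stop_when:
                break
        return self._record(node, result)

    def visit_Compare(self, node: ast.Compare) -> Any:
        left = self.visit(node.left)
        result = True  # type: Any
        for op, comparator_node in zip(node.ops, node.comparators):
            # The comparator is visited only while the chain is still true.
            right = self.visit(comparator_node)
            result = _COMPARE_OPS[type(op)](left, right)
            if not result:
                break
            left = right
        return self._record(node, result)

    def generic_visit(self, node: ast.AST) -> Any:
        raise NotImplementedError("Unhandled node in the sketch: {}".format(type(node).__name__))


# Usage: with tree = ast.parse("x is not None and x > 0", mode="eval"),
# ShortCircuitSketch({"x": None}).visit(tree) returns False without ever visiting ``x > 0``.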
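

# Illustrative sketch of the runtime-compilation technique in ``Visitor._execute_comprehension``, made
# standalone; ``compile_and_run_comprehension`` is a hypothetical helper name, not part of this module.
# The comprehension is wrapped in ``def generator_expr(<names>): return <comprehension>`` so that every known
# variable becomes a function argument; names used inside the comprehension then resolve as closure variables.
# Python 3.8 added the required ``posonlyargs`` field to ``ast.arguments``, hence the version check.
import ast
import sys
from typing import Any, Dict


def compile_and_run_comprehension(expression: str, name_to_value: Dict[str, Any]) -> Any:
    """Parse ``expression`` (a comprehension or generator expression) and evaluate it with the given variables."""
    comprehension = ast.parse(expression, mode="eval").body
    args = [ast.arg(arg=name) for name in sorted(name_to_value)]

    if sys.version_info < (3, 8):
        arguments = ast.arguments(args=args, kwonlyargs=[], kw_defaults=[], defaults=[])
    else:
        arguments = ast.arguments(posonlyargs=[], args=args, kwonlyargs=[], kw_defaults=[], defaults=[])

    func_def = ast.FunctionDef(
        name="generator_expr", args=arguments, decorator_list=[], body=[ast.Return(comprehension)])

    if sys.version_info < (3, 8):
        module = ast.Module(body=[func_def])
    else:
        module = ast.Module(body=[func_def], type_ignores=[])

    # Fill in line numbers and column offsets of the synthetic nodes so that compile() accepts them.
    ast.fix_missing_locations(module)
    code = compile(module, "<ast>", "exec")

    namespace = {}  # type: Dict[str, Any]
    exec(code, namespace)  # pylint: disable=exec-used
    return namespace["generator_expr"](**name_to_value)


# Usage: compile_and_run_comprehension("[x * k for x in xs]", {"xs": [1, 2, 3], "k": 10}) gives [10, 20, 30].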