"""Fixer that inserts mypy annotations into all methods. This transforms e.g. def foo(self, bar, baz=12): return bar + baz into a type annoted version: def foo(self, bar, baz=12): # type: (Any, int) -> Any # noqa: F821 return bar + baz or (when setting options['annotation_style'] to 'py3'): def foo(self, bar : Any, baz : int = 12) -> Any: return bar + baz It does not do type inference but it recognizes some basic default argument values such as numbers and strings (and assumes their type implies the argument type). It also uses some basic heuristics to decide whether to ignore the first argument: - always if it's named 'self' - if there's a @classmethod decorator Finally, it knows that __init__() is supposed to return None. """ from __future__ import print_function import os import re from lib2to3.fixer_base import BaseFix from lib2to3.fixer_util import syms, touch_import, find_indentation from lib2to3.patcomp import compile_pattern from lib2to3.pgen2 import token from lib2to3.pytree import Leaf, Node class FixAnnotate(BaseFix): # This fixer is compatible with the bottom matcher. BM_compatible = True # This fixer shouldn't run by default. explicit = True # The pattern to match. PATTERN = """ funcdef< 'def' name=any parameters=parameters< '(' [args=any] rpar=')' > ':' suite=any+ > """ _maxfixes = os.getenv('MAXFIXES') counter = None if not _maxfixes else int(_maxfixes) def transform(self, node, results): if FixAnnotate.counter is not None: if FixAnnotate.counter <= 0: return # Check if there's already a long-form annotation for some argument. parameters = results.get('parameters') if parameters is not None: for ch in parameters.pre_order(): if ch.prefix.lstrip().startswith('# type:'): return args = results.get('args') if args is not None: for ch in args.pre_order(): if ch.prefix.lstrip().startswith('# type:'): return children = results['suite'][0].children # NOTE: I've reverse-engineered the structure of the parse tree. # It's always a list of nodes, the first of which contains the # entire suite. Its children seem to be: # # [0] NEWLINE # [1] INDENT # [2...n-2] statements (the first may be a docstring) # [n-1] DEDENT # # Comments before the suite are part of the INDENT's prefix. # # "Compact" functions (e.g. "def foo(x, y): return max(x, y)") # have a different structure (no NEWLINE, INDENT, or DEDENT). # Check if there's already an annotation. for ch in children: if ch.prefix.lstrip().startswith('# type:'): return # There's already a # type: comment here; don't change anything. # Python 3 style return annotation are already skipped by the pattern ### Python 3 style argument annotation structure # # Structure of the arguments tokens for one positional argument without default value : # + LPAR '(' # + NAME_NODE_OR_LEAF arg1 # + RPAR ')' # # NAME_NODE_OR_LEAF is either: # 1. Just a leaf with value NAME # 2. A node with children: NAME, ':", node expr or value leaf # # Structure of the arguments tokens for one args with default value or multiple # args, with or without default value, and/or with extra arguments : # + LPAR '(' # + node # [ # + NAME_NODE_OR_LEAF # [ # + EQUAL '=' # + node expr or value leaf # ] # ( # + COMMA ',' # + NAME_NODE_OR_LEAF positional argn # [ # + EQUAL '=' # + node expr or value leaf # ] # )* # ] # [ # + STAR '*' # [ # + NAME_NODE_OR_LEAF positional star argument name # ] # ] # [ # + COMMA ',' # + DOUBLESTAR '**' # + NAME_NODE_OR_LEAF positional keyword argument name # ] # + RPAR ')' # Let's skip Python 3 argument annotations it = iter(args.children) if args else iter([]) for ch in it: if ch.type == token.STAR: # *arg part ch = next(it) if ch.type == token.COMMA: continue elif ch.type == token.DOUBLESTAR: # *arg part ch = next(it) if ch.type > 256: # this is a node, therefore an annotation assert ch.children[0].type == token.NAME return try: ch = next(it) if ch.type == token.COLON: # this is an annotation return elif ch.type == token.EQUAL: ch = next(it) ch = next(it) assert ch.type == token.COMMA continue except StopIteration: break # Compute the annotation annot = self.make_annotation(node, results) if annot is None: return argtypes, restype = annot if self.options['annotation_style'] == 'py3': self.add_py3_annot(argtypes, restype, node, results) else: self.add_py2_annot(argtypes, restype, node, results) # Common to py2 and py3 style annotations: if FixAnnotate.counter is not None: FixAnnotate.counter -= 1 # Also add 'from typing import Any' at the top if needed. self.patch_imports(argtypes + [restype], node) def add_py3_annot(self, argtypes, restype, node, results): args = results.get('args') argleaves = [] if args is None: # function with 0 arguments it = iter([]) elif len(args.children) == 0: # function with 1 argument it = iter([args]) else: # function with multiple arguments or 1 arg with default value it = iter(args.children) for ch in it: argstyle = 'name' if ch.type == token.STAR: # *arg part argstyle = 'star' ch = next(it) if ch.type == token.COMMA: continue elif ch.type == token.DOUBLESTAR: # *arg part argstyle = 'keyword' ch = next(it) assert ch.type == token.NAME argleaves.append((argstyle, ch)) try: ch = next(it) if ch.type == token.EQUAL: ch = next(it) ch = next(it) assert ch.type == token.COMMA continue except StopIteration: break # when self or cls is not annotated, argleaves == argtypes+1 argleaves = argleaves[len(argleaves) - len(argtypes):] for ch_withstyle, chtype in zip(argleaves, argtypes): style, ch = ch_withstyle if style == 'star': assert chtype[0] == '*' assert chtype[1] != '*' chtype = chtype[1:] elif style == 'keyword': assert chtype[0:2] == '**' assert chtype[2] != '*' chtype = chtype[2:] ch.value = '%s: %s' % (ch.value, chtype) # put spaces around the equal sign if ch.next_sibling and ch.next_sibling.type == token.EQUAL: nextch = ch.next_sibling if not nextch.prefix[:1].isspace(): nextch.prefix = ' ' + nextch.prefix nextch = nextch.next_sibling assert nextch != None if not nextch.prefix[:1].isspace(): nextch.prefix = ' ' + nextch.prefix # Add return annotation rpar = results['rpar'] rpar.value = '%s -> %s' % (rpar.value, restype) rpar.changed() def add_py2_annot(self, argtypes, restype, node, results): children = results['suite'][0].children # Insert '# type: {annot}' comment. # For reference, see lib2to3/fixes/fix_tuple_params.py in stdlib. if len(children) >= 1 and children[0].type != token.NEWLINE: # one liner function if children[0].prefix.strip() == '': children[0].prefix = '' children.insert(0, Leaf(token.NEWLINE, '\n')) children.insert( 1, Leaf(token.INDENT, find_indentation(node) + ' ')) children.append(Leaf(token.DEDENT, '')) if len(children) >= 2 and children[1].type == token.INDENT: degen_str = '(...) -> %s' % restype short_str = '(%s) -> %s' % (', '.join(argtypes), restype) if (len(short_str) > 64 or len(argtypes) > 5) and len(short_str) > len(degen_str): self.insert_long_form(node, results, argtypes) annot_str = degen_str else: annot_str = short_str children[1].prefix = '%s# type: %s\n%s' % (children[1].value, annot_str, children[1].prefix) children[1].changed() else: self.log_message("%s:%d: cannot insert annotation for one-line function" % (self.filename, node.get_lineno())) def insert_long_form(self, node, results, argtypes): argtypes = list(argtypes) # We destroy it args = results['args'] if isinstance(args, Node): children = args.children elif isinstance(args, Leaf): children = [args] else: children = [] # Interpret children according to the following grammar: # (('*'|'**')? NAME ['=' expr] ','?)* flag = False # Set when the next leaf should get a type prefix indent = '' # Will be set by the first child def set_prefix(child): if argtypes: arg = argtypes.pop(0).lstrip('*') else: arg = 'Any' # Somehow there aren't enough args if not arg: # Skip self (look for 'check_self' below) prefix = child.prefix.rstrip() else: prefix = ' # type: ' + arg old_prefix = child.prefix.strip() if old_prefix: assert old_prefix.startswith('#') prefix += ' ' + old_prefix child.prefix = prefix + '\n' + indent check_self = self.is_method(node) for child in children: if check_self and isinstance(child, Leaf) and child.type == token.NAME: check_self = False if child.value in ('self', 'cls'): argtypes.insert(0, '') if not indent: indent = ' ' * child.column if isinstance(child, Leaf) and child.value == ',': flag = True elif isinstance(child, Leaf) and flag: set_prefix(child) flag = False need_comma = len(children) >= 1 and children[-1].type != token.COMMA if need_comma and len(children) >= 2: if (children[-1].type == token.NAME and (children[-2].type in (token.STAR, token.DOUBLESTAR))): need_comma = False if need_comma: children.append(Leaf(token.COMMA, u",")) # Find the ')' and insert a prefix before it too. parameters = args.parent close_paren = parameters.children[-1] assert close_paren.type == token.RPAR, close_paren set_prefix(close_paren) assert not argtypes, argtypes def patch_imports(self, types, node): for typ in types: if 'Any' in typ: touch_import('typing', 'Any', node) break def make_annotation(self, node, results): name = results['name'] assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) decorators = self.get_decorators(node) is_method = self.is_method(node) if name.value == '__init__' or not self.has_return_exprs(node): restype = 'None' else: restype = 'Any' args = results.get('args') argtypes = [] if isinstance(args, Node): children = args.children elif isinstance(args, Leaf): children = [args] else: children = [] # Interpret children according to the following grammar: # (('*'|'**')? NAME ['=' expr] ','?)* stars = inferred_type = '' in_default = False at_start = True for child in children: if isinstance(child, Leaf): if child.value in ('*', '**'): stars += child.value elif child.type == token.NAME and not in_default: if not is_method or not at_start or 'staticmethod' in decorators: inferred_type = 'Any' else: # Always skip the first argument if it's named 'self'. # Always skip the first argument of a class method. if child.value == 'self' or 'classmethod' in decorators: pass else: inferred_type = 'Any' elif child.value == '=': in_default = True elif in_default and child.value != ',': if child.type == token.NUMBER: if re.match(r'\d+[lL]?$', child.value): inferred_type = 'int' else: inferred_type = 'float' # TODO: complex? elif child.type == token.STRING: if child.value.startswith(('u', 'U')): inferred_type = 'unicode' else: inferred_type = 'str' elif child.type == token.NAME and child.value in ('True', 'False'): inferred_type = 'bool' elif child.value == ',': if inferred_type: argtypes.append(stars + inferred_type) # Reset stars = inferred_type = '' in_default = False at_start = False if inferred_type: argtypes.append(stars + inferred_type) return argtypes, restype # The parse tree has a different shape when there is a single # decorator vs. when there are multiple decorators. DECORATED = "decorated< (d=decorator | decorators< dd=decorator+ >) funcdef >" decorated = compile_pattern(DECORATED) def get_decorators(self, node): """Return a list of decorators found on a function definition. This is a list of strings; only simple decorators (e.g. @staticmethod) are returned. If the function is undecorated or only non-simple decorators are found, return []. """ if node.parent is None: return [] results = {} if not self.decorated.match(node.parent, results): return [] decorators = results.get('dd') or [results['d']] decs = [] for d in decorators: for child in d.children: if isinstance(child, Leaf) and child.type == token.NAME: decs.append(child.value) return decs def is_method(self, node): """Return whether the node occurs (directly) inside a class.""" node = node.parent while node is not None: if node.type == syms.classdef: return True if node.type == syms.funcdef: return False node = node.parent return False RETURN_EXPR = "return_stmt< 'return' any >" return_expr = compile_pattern(RETURN_EXPR) def has_return_exprs(self, node): """Traverse the tree below node looking for 'return expr'. Return True if at least 'return expr' is found, False if not. (If both 'return' and 'return expr' are found, return True.) """ results = {} if self.return_expr.match(node, results): return True for child in node.children: if child.type not in (syms.funcdef, syms.classdef): if self.has_return_exprs(child): return True return False YIELD_EXPR = "yield_expr< 'yield' [any] >" yield_expr = compile_pattern(YIELD_EXPR) def is_generator(self, node): """Traverse the tree below node looking for 'yield [expr]'.""" results = {} if self.yield_expr.match(node, results): return True for child in node.children: if child.type not in (syms.funcdef, syms.classdef): if self.is_generator(child): return True return False