"""Fixer that inserts mypy annotations from json file into code. This fixer consumes json from TYPE_COLLECTION_JSON env variable in the following format: [ { "path": "/Users/svorobev/src/client/build_number/__init__.py", "func_name": "is_test", "arg_types": ["int", "str"], "ret_type": "Any" }, ... ] (The old format with "type_comment" instead of "arg_types" and "ret_type" is also still supported.) """ from __future__ import print_function import json # noqa import os import re from contextlib import contextmanager from lib2to3.fixer_util import syms, touch_import from lib2to3.pgen2 import token from lib2to3.pytree import Base, Leaf, Node from typing import __all__ as typing_all # type: ignore from typing import Any, Dict, List, Optional, Tuple try: from typing import Text except ImportError: # In Python 3.5.1 stdlib, typing.py does not define Text Text = str # type: ignore from .fix_annotate import FixAnnotate # Taken from mypy codebase: # https://github.com/python/mypy/blob/745d300b8304c3dcf601477762bf9d70b9a4619c/mypy/main.py#L503 PY_EXTENSIONS = ['.pyi', '.py'] def crawl_up(arg): # type: (str) -> Tuple[str, str] """Given a .py[i] filename, return (root directory, module). We crawl up the path until we find a directory without __init__.py[i], or until we run out of path components. """ dir, mod = os.path.split(arg) mod = strip_py(mod) or mod while dir and get_init_file(dir): dir, base = os.path.split(dir) if not base: break if mod == '__init__' or not mod: mod = base else: mod = base + '.' + mod return dir, mod def strip_py(arg): # type: (str) -> Optional[str] """Strip a trailing .py or .pyi suffix. Return None if no such suffix is found. """ for ext in PY_EXTENSIONS: if arg.endswith(ext): return arg[:-len(ext)] return None def get_init_file(dir): # type: (str) -> Optional[str] """Check whether a directory contains a file named __init__.py[i]. If so, return the file's name (with dir prefixed). If not, return None. This prefers .pyi over .py (because of the ordering of PY_EXTENSIONS). """ for ext in PY_EXTENSIONS: f = os.path.join(dir, '__init__' + ext) if os.path.isfile(f): return f return None def get_funcname(node): # type: (Optional[Node]) -> Text """Get function name by (approximately) the following rules: - function -> function_name - method -> ClassName.function_name More specifically, we include every class and function name that the node is a child of, so nested classes and functions get names like OuterClass.InnerClass.outer_fn.inner_fn. """ components = [] # type: List[str] while node: if node.type in (syms.classdef, syms.funcdef): name = node.children[1] assert name.type == token.NAME, repr(name) assert isinstance(name, Leaf) # Same as previous, for mypy components.append(name.value) node = node.parent return '.'.join(reversed(components)) def count_args(node, results): # type: (Node, Dict[str, Base]) -> Tuple[int, bool, bool, bool] """Count arguments and check for self and *args, **kwds. Return (selfish, count, star, starstar) where: - count is total number of args (including *args, **kwds) - selfish is True if the initial arg is named 'self' or 'cls' - star is True iff *args is found - starstar is True iff **kwds is found """ count = 0 selfish = False star = False starstar = False args = results.get('args') if isinstance(args, Node): children = args.children elif isinstance(args, Leaf): children = [args] else: children = [] # Interpret children according to the following grammar: # (('*'|'**')? NAME ['=' expr] ','?)* skip = False previous_token_is_star = False for child in children: if skip: skip = False elif isinstance(child, Leaf): # A single '*' indicates the rest of the arguments are keyword only # and shouldn't be counted as a `*`. if child.type == token.STAR: previous_token_is_star = True elif child.type == token.DOUBLESTAR: starstar = True elif child.type == token.NAME: if count == 0: if child.value in ('self', 'cls'): selfish = True count += 1 if previous_token_is_star: star = True elif child.type == token.EQUAL: skip = True if child.type != token.STAR: previous_token_is_star = False return count, selfish, star, starstar class FixAnnotateJson(FixAnnotate): needed_imports = None line_drift = 5 def add_import(self, mod, name): if mod == self.current_module(): return if self.needed_imports is None: self.needed_imports = set() self.needed_imports.add((mod, name)) def patch_imports(self, types, node): if self.needed_imports: for mod, name in sorted(self.needed_imports): touch_import(mod, name, node) self.needed_imports = None def set_filename(self, filename): super(FixAnnotateJson, self).set_filename(filename) self._current_module = crawl_up(filename)[1] def current_module(self): return self._current_module def make_annotation(self, node, results): name = results['name'] assert isinstance(name, Leaf), repr(name) assert name.type == token.NAME, repr(name) funcname = get_funcname(node) res = self.get_annotation_from_stub(node, results, funcname) # If we couldn't find an annotation and this is a classmethod or # staticmethod, try again with just the funcname, since the # type collector can't figure out class names for those. # (We try with the full name above first so that tools that *can* figure # that out, like dmypy suggest, can use it.) if not res: decs = self.get_decorators(node) if 'staticmethod' in decs or 'classmethod' in decs: res = self.get_annotation_from_stub(node, results, name.value) return res stub_json_file = os.getenv('TYPE_COLLECTION_JSON') # JSON data for the current file stub_json = None # type: List[Dict[str, Any]] @classmethod @contextmanager def max_line_drift_set(cls, max_drift): old_drift = cls.line_drift cls.line_drift = max_drift try: yield finally: cls.line_drift = old_drift @classmethod def init_stub_json_from_data(cls, data, filename): cls.stub_json = data cls.top_dir = crawl_up(os.path.abspath(filename))[0] def init_stub_json(self): with open(self.__class__.stub_json_file) as f: data = json.load(f) self.__class__.init_stub_json_from_data(data, self.filename) def get_annotation_from_stub(self, node, results, funcname): if not self.__class__.stub_json: self.init_stub_json() data = self.__class__.stub_json # We are using relative paths in the JSON. items = [ it for it in data if it['func_name'] == funcname and (it['path'] == self.filename or os.path.join(self.__class__.top_dir, it['path']) == os.path.abspath(self.filename)) ] if len(items) > 1: # this can happen, because of # 1) nested functions # 2) method decorators # as a cheap and dirty solution we just return the nearest one by the line number # (keep the commented-out log_message call in case we need to come back to this) ## self.log_message("%s:%d: duplicate signatures for %s (at lines %s)" % ## (items[0]['path'], node.get_lineno(), items[0]['func_name'], ## ", ".join(str(it['line']) for it in items))) items.sort(key=lambda it: abs(node.get_lineno() - it['line'])) if items: it = items[0] # If the line number is too far off, the source probably drifted # since the trace was collected; it's better to skip this node. # (Allow some drift, since decorators also cause an offset.) if abs(node.get_lineno() - it['line']) >= self.line_drift: self.log_message("%s:%d: '%s' signature from line %d too far away -- skipping" % (self.filename, node.get_lineno(), it['func_name'], it['line'])) return None if 'signature' in it: sig = it['signature'] arg_types = sig['arg_types'] # Passes 1-2 don't always understand *args or **kwds, # so add '*Any' or '**Any' at the end if needed. count, selfish, star, starstar = count_args(node, results) for arg_type in arg_types: if arg_type.startswith('**'): starstar = False elif arg_type.startswith('*'): star = False if star: arg_types.append('*Any') if starstar: arg_types.append('**Any') # Pass 1 omits the first arg iff it's named 'self' or 'cls', # even if it's not a method, so insert `Any` as needed # (but only if it's not actually a method). if selfish and len(arg_types) == count - 1: if self.is_method(node): count -= 1 # Leave out the type for 'self' or 'cls' else: arg_types.insert(0, 'Any') # If after those adjustments the count is still off, # print a warning and skip this node. if len(arg_types) != count: self.log_message("%s:%d: source has %d args, annotation has %d -- skipping" % (self.filename, node.get_lineno(), count, len(arg_types))) return None ret_type = sig['return_type'] arg_types = [self.update_type_names(arg_type) for arg_type in arg_types] # Avoid common error "No return value expected" if ret_type == 'None' and self.has_return_exprs(node): ret_type = 'Optional[Any]' # Special case for generators. if (self.is_generator(node) and not (ret_type == 'Iterator' or ret_type.startswith('Iterator['))): if ret_type.startswith('Optional['): assert ret_type[-1] == ']' ret_type = ret_type[9:-1] ret_type = 'Iterator[%s]' % ret_type ret_type = self.update_type_names(ret_type) return arg_types, ret_type return None def update_type_names(self, type_str): # Replace e.g. `List[pkg.mod.SomeClass]` with # `List[SomeClass]` and remember to import it. return re.sub(r'[\w.:]+', self.type_updater, type_str) def type_updater(self, match): # Replace `pkg.mod.SomeClass` with `SomeClass` # and remember to import it. word = match.group() if word == '...': return word if '.' not in word and ':' not in word: # Assume it's either builtin or from `typing` if word in typing_all: self.add_import('typing', word) return word # If there is a :, treat that as the separator between the # module and the class. Otherwise assume everything but the # last element is the module. if ':' in word: mod, name = word.split(':') to_import = name.split('.', 1)[0] else: mod, name = word.rsplit('.', 1) to_import = name self.add_import(mod, to_import) return name