#!/usr/bin/env python3
"""Re-apply type annotations from .pyi stubs to your codebase."""

import os
import re
import sys
import tokenize
import traceback
from functools import singledispatch
from lib2to3 import pygram, pytree
from lib2to3.pgen2 import driver, token
from lib2to3.pgen2.parse import ParseError
from lib2to3.pygram import python_symbols as syms
from lib2to3.pytree import Leaf, Node
from pathlib import Path

from pathspec import PathSpec
from typed_ast import ast3

from .config import ReApplyFlags
from .version import __version__


def retype_path(
    src,
    pyi_dir,
    targets,
    *,
    src_explicitly_given=False,
    quiet=False,
    hg=False,
    flags=None,
):
    """Recursively retype files or directories given. Generate errors."""
    src = src.absolute()
    if src.is_dir():
        extra_ignore = []
        for folder in [pyi_dir, targets]:
            try:
                extra_ignore.append("/{}".format(folder.relative_to(src)))
            except ValueError:
                pass
        for file in walk_not_git_ignored(
            src, lambda p: p.suffix == ".py", extra_ignore
        ):
            nested = file.relative_to(src).parent
            yield from retype_path(
                file,
                pyi_dir / nested,
                targets / nested,
                quiet=quiet,
                hg=hg,
                flags=flags,
            )
    elif src.suffix == ".py" or src_explicitly_given:
        try:
            retype_file(src, pyi_dir, targets, quiet=quiet, hg=hg, flags=flags)
        except Exception as e:
            yield (src, str(e), type(e), traceback.format_tb(e.__traceback__))


def retype_file(src, pyi_dir, targets, *, quiet=False, hg=False, flags=None):
    """Retype `src`, finding types in `pyi_dir`. Save in `targets`.

    The file should remain formatted exactly as it was before, save for:
    - annotations
    - additional imports needed to satisfy annotations
    - additional module-level names needed to satisfy annotations

    Type comments in sources are normalized to type annotations.
    """
    if flags is None:
        flags = ReApplyFlags()
    with tokenize.open(src) as src_buffer:
        src_contents = src_buffer.read()
        if src_contents == "":
            return
        src_encoding = src_buffer.encoding
        src_node = lib2to3_parse(src_contents)
    try:
        with open((pyi_dir / src.name).with_suffix(".pyi")) as pyi_file:
            pyi_txt = pyi_file.read()
    except FileNotFoundError:
        if not quiet:
            print(
                f"warning: .pyi file for source {src} not found in {pyi_dir}",
                file=sys.stderr,
            )
    else:
        pyi_ast = ast3.parse(pyi_txt)
        assert isinstance(pyi_ast, ast3.Module)
        reapply_all(pyi_ast.body, src_node, flags)
    fix_remaining_type_comments(src_node, flags)
    targets.mkdir(parents=True, exist_ok=True)
    with open(targets / src.name, "w", encoding=src_encoding) as target_file:
        target_file.write(lib2to3_unparse(src_node, hg=hg))
    return targets / src.name
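

# Example usage (illustrative): `retype_path` is a generator, so iterate it to
# drive the retyping and collect errors.  The paths below are hypothetical.
#
#     for src, message, exc_type, tb in retype_path(
#         Path("src"), pyi_dir=Path("types"), targets=Path("typed-src")
#     ):
#         print(f"error: {src}: {message}", file=sys.stderr)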


def lib2to3_parse(src_txt):
    """Given a string with source, return the lib2to3 Node."""
    grammar = pygram.python_grammar_no_print_statement
    drv = driver.Driver(grammar, pytree.convert)
    if src_txt[-1] != "\n":
        nl = "\r\n" if "\r\n" in src_txt[:1024] else "\n"
        src_txt += nl
    try:
        result = drv.parse_string(src_txt, True)
    except ParseError as pe:
        lineno, column = pe.context[1]
        lines = src_txt.splitlines()
        try:
            faulty_line = lines[lineno - 1]
        except IndexError:
            faulty_line = "<line number missing in source>"
        raise ValueError(f"Cannot parse: {lineno}:{column}: {faulty_line}") from None

    if isinstance(result, Leaf):
        result = Node(syms.file_input, [result])

    return result


def lib2to3_unparse(node, *, hg=False):
    """Given a lib2to3 node, return its string representation."""
    code = str(node)
    if hg:
        from retype.retype_hgext import apply_job_security

        code = apply_job_security(code)
    return code
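

# Round-trip sketch (illustrative): lib2to3 stores comments and whitespace in
# node prefixes, so unparsing a freshly parsed file reproduces it exactly.
#
#     tree = lib2to3_parse("x = 1  # comment\n")
#     assert lib2to3_unparse(tree) == "x = 1  # comment\n"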


def reapply_all(ast_node, lib2to3_node, flags):
    """Reapplies the typed_ast node into the lib2to3 tree.

    Also does post-processing. This is done in reverse order to enable placing
    TypeVars and aliases that depend on one another.
    """
    late_processing = reapply(ast_node, lib2to3_node, flags)
    for lazy_func in reversed(late_processing):
        lazy_func()


@singledispatch
def reapply(ast_node, lib2to3_node, flags):
    """Reapplies the typed_ast node into the lib2to3 tree.

    By default does nothing.
    """
    return []


@reapply.register(list)
def _r_list(l, lib2to3_node, flags):
    if lib2to3_node.type not in (syms.file_input, syms.suite):
        return []

    result = []
    for pyi_node in l:
        result.extend(reapply(pyi_node, lib2to3_node, flags))
    return result


@reapply.register(ast3.ImportFrom)
def _r_importfrom(import_from, node, flags):
    assert node.type in (syms.file_input, syms.suite)
    level = import_from.level or 0
    module = "." * level + (import_from.module or "")
    names = import_from.names
    for child in flatten_some(node.children):
        if child.type != syms.simple_stmt:
            continue

        stmt = child.children[0]
        if stmt.type == syms.import_from:
            imp = stmt.children
            # if the module we're looking for is already imported, skip it.
            if str(imp[1]).strip() == module and names_already_imported(names, imp[3]):
                break
    else:
        import_stmt = make_import(*names, from_module=module)
        append_after_imports(import_stmt, node)
    return []


@reapply.register(ast3.Import)
def _r_import(import_, node, flags):
    assert node.type in (syms.file_input, syms.suite)
    names = import_.names
    for child in flatten_some(node.children):
        if child.type != syms.simple_stmt:
            continue

        stmt = child.children[0]
        if stmt.type == syms.import_name:
            imp = stmt.children
            # if the module we're looking for is already imported, skip it.
            if names_already_imported(names, imp[1]):
                break
    else:
        import_stmt = make_import(*names)
        append_after_imports(import_stmt, node)
    return []


@reapply.register(ast3.ClassDef)
def _r_classdef(cls, node, flags):
    assert node.type in (syms.file_input, syms.suite)
    name = Leaf(token.NAME, cls.name)
    for child in flatten_some(node.children):
        if child.type == syms.decorated:
            # skip decorators
            child = child.children[1]
        if child.type == syms.classdef and child.children[1] == name:
            cls_node = child.children[-1]
            break
    else:
        raise ValueError(f"Class {name.value!r} not found in source.")

    result = []
    for ast_elem in cls.body:
        result.extend(reapply(ast_elem, cls_node, flags))
    return result


@reapply.register(ast3.AsyncFunctionDef)
@reapply.register(ast3.FunctionDef)
def _r_functiondef(fun, node, flags):
    assert node.type in (syms.file_input, syms.suite)
    name = Leaf(token.NAME, fun.name)
    pyi_decorators = decorator_names(fun.decorator_list)
    pyi_method_decorators = list(
        filter(is_builtin_method_decorator, pyi_decorators)
    ) or ["instancemethod"]
    is_method = (
        node.parent is not None
        and node.parent.type == syms.classdef
        and "staticmethod" not in pyi_method_decorators
    )
    args, returns = get_function_signature(fun, is_method=is_method)
    for child in flatten_some(node.children):
        decorators = None
        if child.type == syms.decorated:
            # skip decorators
            decorators = child.children[0]
            child = child.children[1]

        if child.type in (syms.async_stmt, syms.async_funcdef):
            # async def in 3.5 and 3.6
            child = child.children[1]

        if child.type != syms.funcdef:
            continue

        offset = 1
        if child.children[offset] == name:
            lineno = child.get_lineno()
            column = 1

            if decorators:
                src_decorators = decorator_names(decorators)
                src_method_decorators = list(
                    filter(is_builtin_method_decorator, src_decorators)
                ) or ["instancemethod"]
                if pyi_method_decorators != src_method_decorators:
                    raise ValueError(
                        f"Incompatible method kind for {fun.name!r}: "
                        + f"{lineno}:{column}: Expected: "
                        + f"{pyi_method_decorators[0]}, actual: "
                        + f"{src_method_decorators[0]}"
                    )

                is_method = "staticmethod" not in pyi_decorators

            try:
                annotate_parameters(
                    child.children[offset + 1], args, is_method=is_method, flags=flags
                )
                annotate_return(child.children, returns, offset + 2, flags)
                reapply(fun.body, child.children[-1], flags)
                remove_function_signature_type_comment(child.children[-1])
            except ValueError as ve:
                raise ValueError(
                    f"Annotation problem in function {name.value!r}: "
                    + f"{lineno}:{column}: {ve}"
                )
            break
    else:
        raise ValueError(f"Function {name.value!r} not found in source.")

    return []


@reapply.register(ast3.AnnAssign)
def _r_annassign(annassign, body, flags):
    assert body.type in (syms.file_input, syms.suite)

    target = annassign.target
    if isinstance(target, ast3.Name):
        name = target.id
    elif isinstance(target, ast3.Attribute):
        name = serialize_attribute(target)
    else:
        raise NotImplementedError(f"unexpected assignment target: {target}")

    annotation = convert_annotation(annassign.annotation)
    annotation.prefix = " "
    annassign_node = Node(syms.annassign, [new(_colon), annotation])
    for child in flatten_some(body.children):
        if child.type != syms.simple_stmt:
            continue

        maybe_expr = child.children[0]
        if maybe_expr.type != syms.expr_stmt:
            continue

        expr = maybe_expr.children

        if (
            expr[0].type in (token.NAME, syms.power)
            and minimize_whitespace(str(expr[0])) == name
        ):
            if expr[1].type == syms.annassign:
                # variable already typed, let's just ensure it's sane
                if len(expr[1].children) > 2 and expr[1].children[2] != _eq:
                    raise NotImplementedError(
                        "unexpected element after annotation: "
                        f"{str(expr[1].children[2])}"
                    )
                expr[1].children[1] = maybe_replace_any_if_equal(
                    f"variable annotation for {name!r}",
                    annotation,
                    expr[1].children[1],
                    flags,
                )
                break

            if expr[1] != _eq:
                # If it's not an assignment, we're ignoring it. It could be:
                # - indexing
                # - tuple unpacking
                # - calls
                # - etc. etc.
                continue

            maybe_type_comment = _type_comment_re.match(child.children[1].prefix)
            if maybe_type_comment:
                # variable already typed by type comment, let's ensure it's sane...
                type_comment = parse_type_comment(maybe_type_comment.group("type"))
                actual_annotation = convert_annotation(type_comment)
                ensure_annotations_equal(
                    f"variable annotation for {name!r}",
                    annotation,
                    actual_annotation,
                    flags,
                )
                # ...and remove the redundant comment
                child.children[1].prefix = maybe_space_before_comment(
                    maybe_type_comment.group("nl")
                )

            if len(expr[2:]) > 0 and expr[2:] != [_ellipsis]:
                # copy the value unless it was an old-style variable type comment
                # with no actual value (but just a ... placeholder)
                annassign_node.children.append(new(_eq))
                annassign_node.children.extend(new(elem) for elem in expr[2:])

            maybe_expr.children = [expr[0], annassign_node]

            break
    else:
        # If the variable was used in some `if` statement, etc.; let's define
        # its type from the stub on the top level of the function.
        offset, prefix = get_offset_and_prefix(body, skip_assignments=True)
        body.children.insert(
            offset,
            Node(
                syms.simple_stmt,
                [
                    Node(syms.expr_stmt, [Leaf(token.NAME, name), annassign_node]),
                    new(_newline),
                ],
                prefix=prefix.lstrip("\n"),
            ),
        )

    return []


@reapply.register(ast3.Assign)
def _r_assign(assign, body, flags):
    assert body.type in (syms.file_input, syms.suite)

    if len(assign.targets) != 1:
        # Type aliases and old-style var type comments cannot have multiple
        # targets.
        return []

    if assign.type_comment:
        # old-style variable type comment, let's treat it exactly like
        # a new-style annotated assignment
        tc = parse_type_comment(assign.type_comment)
        annassign = ast3.AnnAssign(
            target=assign.targets[0], annotation=tc, value=assign.value, simple=False
        )
        return reapply(annassign, body, flags)

    if not isinstance(assign.targets[0], ast3.Name):
        # Type aliases cannot be attributes, etc.
        return []

    name = assign.targets[0].id
    value = convert_annotation(assign.value)
    value.prefix = " "

    for child in flatten_some(body.children):
        if child.type != syms.simple_stmt:
            continue

        maybe_expr = child.children[0]
        if maybe_expr.type != syms.expr_stmt:
            continue

        expr = maybe_expr.children

        if (
            isinstance(expr[0], Leaf)
            and expr[0].type == token.NAME
            and expr[0].value == name
            and expr[1] == _eq
        ):
            expr[2] = maybe_replace_any_if_equal(
                f"alias {name!r}", value, expr[2], flags=flags
            )
            break
    else:
        # We need to defer placing aliases because we need to place them
        # relative to their usage, and the type annotations likely come after
        # in the .pyi file.

        def lazy_aliasing() -> None:
            # We should find the first place where the alias is used and put it
            # right above.  This way we don't need to look at the value at all.
            _, prefix = get_offset_and_prefix(body, skip_assignments=True)
            name_node = Leaf(token.NAME, name)
            for _offset, stmt in enumerate(body.children):
                if name_used_in_node(stmt, name_node):
                    break
            else:
                _offset = -1

            body.children.insert(
                _offset,
                Node(
                    syms.simple_stmt,
                    [
                        Node(syms.expr_stmt, [Leaf(token.NAME, name), new(_eq), value]),
                        new(_newline),
                    ],
                    prefix=prefix.lstrip("\n"),
                ),
            )

        return [lazy_aliasing]

    return []


@singledispatch
def serialize_attribute(attr):
    """serialize_attribute(Attribute()) -> "self.f1.f2.f3"

    Change an AST object into its string representation."""
    return ""


@serialize_attribute.register(ast3.Attribute)
def _sa_attribute(attr):
    return f"{serialize_attribute(attr.value)}.{attr.attr}"


@serialize_attribute.register(ast3.Name)
def _sa_name(name):
    return name.id


@serialize_attribute.register(ast3.Expr)
def _sa_expr(expr):
    return serialize_attribute(expr.value)


@singledispatch
def convert_annotation(ann):
    """Converts an AST object into its lib2to3 equivalent."""
    raise NotImplementedError(f"unknown AST node type: {ann!r}")


@convert_annotation.register(ast3.Subscript)
def _c_subscript(sub):
    return Node(
        syms.power,
        [
            convert_annotation(sub.value),
            Node(syms.trailer, [new(_lsqb), convert_annotation(sub.slice), new(_rsqb)]),
        ],
    )


@convert_annotation.register(ast3.Name)
def _c_name(name):
    return Leaf(token.NAME, name.id)


@convert_annotation.register(ast3.NameConstant)
def _c_nameconstant(const):
    return Leaf(token.NAME, repr(const.value))


@convert_annotation.register(ast3.Ellipsis)
def _c_ellipsis(ell):
    return Node(syms.atom, [new(_dot), new(_dot), new(_dot)])


@convert_annotation.register(ast3.Str)
def _c_str(s):
    return Leaf(token.STRING, repr(s.s))


@convert_annotation.register(ast3.Num)
def _c_num(n):
    return Leaf(token.NUMBER, repr(n.n))


@convert_annotation.register(ast3.Index)
def _c_index(index):
    return convert_annotation(index.value)


@convert_annotation.register(ast3.Tuple)
def _c_tuple(tup):
    contents = [convert_annotation(elt) for elt in tup.elts]
    for index in range(len(contents) - 1, 0, -1):
        contents[index].prefix = " "
        contents.insert(index, new(_comma))

    return Node(syms.subscriptlist, contents)


@convert_annotation.register(ast3.Attribute)
def _c_attribute(attr):
    # This is hacky. ¯\_(ツ)_/¯
    return Leaf(token.NAME, f"{convert_annotation(attr.value)}.{attr.attr}")


@convert_annotation.register(ast3.Call)
def _c_call(call):
    contents = [convert_annotation(arg) for arg in call.args]
    contents.extend(convert_annotation(kwarg) for kwarg in call.keywords)
    for index in range(len(contents) - 1, 0, -1):
        contents[index].prefix = " "
        contents.insert(index, new(_comma))

    call_args = [new(_lpar), new(_rpar)]
    if contents:
        call_args.insert(1, Node(syms.arglist, contents))
    return Node(
        syms.power, [convert_annotation(call.func), Node(syms.trailer, call_args)]
    )


@convert_annotation.register(ast3.keyword)
def _c_keyword(kwarg):
    assert kwarg.arg
    return Node(
        syms.argument,
        [
            Leaf(token.NAME, kwarg.arg),
            new(_eq, prefix=""),
            convert_annotation(kwarg.value),
        ],
    )


@convert_annotation.register(ast3.List)
def _c_list(l):
    contents = [convert_annotation(elt) for elt in l.elts]
    for index in range(len(contents) - 1, 0, -1):
        contents[index].prefix = " "
        contents.insert(index, new(_comma))

    list_literal = [new(_lsqb), new(_rsqb)]
    if contents:
        list_literal.insert(1, Node(syms.listmaker, contents))
    return Node(syms.atom, list_literal)
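

# Taken together (illustrative), the converters above rebuild a stub annotation
# as lib2to3 nodes; for example `Dict[str, List[int]]` becomes a `power` node
# whose str() is exactly "Dict[str, List[int]]".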


@singledispatch
def names_already_imported(names, node):
    """Returns True if `node` represents `names`."""
    return False


@names_already_imported.register(list)
def _nai_list(names, node):
    return all(names_already_imported(name, node) for name in names)


@names_already_imported.register(ast3.alias)
def _nai_alias(alias, node):
    # Comments below show example imports that match the rule.
    name = Leaf(token.NAME, alias.name)
    if not alias.asname or alias.asname == alias.name:
        # import hay, x, stack
        # from field import hay, s, stack
        if node.type in (syms.dotted_as_names, syms.import_as_names):
            return name in node.children

        # import x as x
        # from field import x as x
        if node.type in (syms.dotted_as_name, syms.import_as_name):
            return [name, _as, name] == node.children

        # import x
        return node == name

    asname = Leaf(token.NAME, alias.asname)
    dotted_as_name = Node(syms.dotted_as_name, [name, _as, asname])
    # import hay as stack, x as y
    if node.type == syms.dotted_as_names:
        return dotted_as_name in node.children

    import_as_name = Node(syms.import_as_name, [name, _as, asname])
    # from field import hay as stack, x as y
    if node.type == syms.import_as_names:
        return import_as_name in node.children

    # import x as y
    # from field import x as y
    return node in (dotted_as_name, import_as_name)
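

# For example (illustrative): reapplying `from typing import List` is a no-op if
# the source already has `from typing import List` (or `... import List as List`),
# but `from typing import List as L` does not count as the same import.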


@singledispatch
def decorator_names(obj):
    return []


@decorator_names.register(Node)
def _dn_node(node):
    if node.type == syms.decorator:
        return [str(node.children[1])]

    if node.type == syms.decorators:
        return [str(decorator.children[1]) for decorator in node.children]

    return []


@decorator_names.register(list)
def _dn_list(l):
    result = []
    for elem in l:
        result.extend(decorator_names(elem))
    return result


@decorator_names.register(ast3.Name)
def _dn_name(name):
    return [name.id]


@decorator_names.register(ast3.Call)
def _dn_call(call):
    return decorator_names(call.func)


@decorator_names.register(ast3.Attribute)
def _dn_attribute(attr):
    return [serialize_attribute(attr)]


def fix_remaining_type_comments(node, flags):
    """Converts type comments in `node` to proper annotated assignments."""
    assert node.type == syms.file_input

    last_n = None
    for n in node.post_order():
        if last_n is not None:
            if n.type == token.NEWLINE and is_assignment(last_n):
                fix_variable_annotation_type_comment(n, last_n)
            elif n.type == syms.funcdef and last_n.type == syms.suite:
                fix_signature_annotation_type_comment(n, last_n, offset=1, flags=flags)
            elif n.type == syms.async_funcdef and last_n.type == syms.suite:
                fix_signature_annotation_type_comment(n, last_n, offset=2, flags=flags)
        last_n = n


def fix_variable_annotation_type_comment(node, last):
    m = _type_comment_re.match(node.prefix)
    if not m:
        return

    type_comment = parse_type_comment(m.group("type"))
    ann = convert_annotation(type_comment)
    ann.prefix = " "
    annassign_node = Node(syms.annassign, [new(_colon), ann])
    expr = last.children
    if len(expr[2:]) > 0 and expr[2:] != [_ellipsis]:
        # with assignment
        annassign_node.children.append(new(_eq))
        annassign_node.children.extend(new(elem) for elem in expr[2:])
    last.children = [expr[0], annassign_node]
    node.prefix = maybe_space_before_comment(m.group("nl"))


def fix_signature_annotation_type_comment(node, last, *, offset, flags):
    for ch in last.children:
        if ch.type == token.INDENT:
            break
    else:
        return

    m = _type_comment_re.match(ch.prefix)
    if not m:
        return

    parameters = node.children[offset + 1]
    args_tc, returns_tc = parse_signature_type_comment(m.group("type"))
    ast_args = parse_arguments(str(parameters))
    # `is_method=True` below only means we allow for missing first annotation.
    # It's not even worth checking at this point.
    copy_arguments_to_annotations(ast_args, args_tc, is_method=True)
    annotate_parameters(parameters, ast_args, is_method=True, flags=flags)
    annotate_return(node.children, returns_tc, offset + 2, flags)
    remove_function_signature_type_comment(last)


def is_assignment(node):
    if node.type != syms.expr_stmt:
        return False

    expr = node.children
    # The `bool()` below shuts up a "returning Any" warning from mypy.
    return expr[0].type in (token.NAME, syms.power) and bool(expr[1] == _eq)


def is_builtin_method_decorator(name):
    return name in {"classmethod", "staticmethod"}


def make_import(*names, from_module=None):
    assert names
    imports = []

    if from_module:
        statement = syms.import_from
        container = syms.import_as_names
        single = syms.import_as_name
        result = [
            Leaf(token.NAME, "from"),
            Leaf(token.NAME, from_module, prefix=" "),
            Leaf(token.NAME, "import", prefix=" "),
        ]
    else:
        statement = syms.import_name
        container = syms.dotted_as_names
        single = syms.dotted_as_name
        result = [Leaf(token.NAME, "import")]

    for alias in names:
        name = Leaf(token.NAME, alias.name, prefix=" ")
        if alias.asname:
            _as = Leaf(token.NAME, "as", prefix=" ")
            asname = Leaf(token.NAME, alias.asname, prefix=" ")
            imports.append(Node(single, [name, _as, asname]))
        else:
            imports.append(name)
    if len(imports) == 1:
        result.append(imports[0])
    else:
        imports_and_commas = []
        for imp in imports[:-1]:
            imports_and_commas.append(imp)
            imports_and_commas.append(Leaf(token.COMMA, ","))
        imports_and_commas.append(imports[-1])
        result.append(Node(container, imports_and_commas))
    return Node(
        syms.simple_stmt,
        [Node(statement, result), Leaf(token.NEWLINE, "\n")],  # FIXME: \r\n?
    )
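

# For example (illustrative):
#
#     imp = make_import(ast3.alias(name="List", asname="L"), from_module="typing")
#     str(imp)  # -> "from typing import List as L\n"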


def append_after_imports(stmt_to_insert, node):
    offset, stmt_to_insert.prefix = get_offset_and_prefix(node)
    node.children.insert(offset, stmt_to_insert)


def annotate_parameters(parameters, ast_args, *, is_method=False, flags):
    params = parameters.children[1:-1]
    if len(params) == 0:
        return  # FIXME: handle checking if the expected (AST) function is also empty.
    elif len(params) > 1:
        raise NotImplementedError(f"unknown AST structure in parameters: {params}")

    # Simplify the possible data structures so we can just pull from it.
    if params[0].type == syms.typedargslist:
        params = params[0].children

    typedargslist = []

    num_args_no_defaults = len(ast_args.args) - len(ast_args.defaults)
    defaults = [None] * num_args_no_defaults
    defaults.extend(ast_args.defaults)
    typedargslist.extend(
        gen_annotated_params(
            ast_args.args, defaults, params, is_method=is_method, flags=flags
        )
    )
    hopefully_vararg = None
    if ast_args.vararg or ast_args.kwonlyargs:
        try:
            hopefully_star, hopefully_vararg = pop_param(params)
            if hopefully_star != _star:
                raise ValueError
        except (IndexError, ValueError):
            raise ValueError(
                ".pyi file expects *args or keyword-only arguments in source"
            ) from None
        else:
            typedargslist.append(new(_comma))
            typedargslist.append(new(hopefully_star))

    if ast_args.vararg and hopefully_vararg:
        if hopefully_vararg.type == syms.tname:
            assert isinstance(hopefully_vararg.children[0], Leaf)
            hopefully_vararg_name = hopefully_vararg.children[0].value
        else:
            assert isinstance(hopefully_vararg, Leaf)
            hopefully_vararg_name = hopefully_vararg.value
        if hopefully_vararg_name != ast_args.vararg.arg:
            raise ValueError(f".pyi file expects *{ast_args.vararg.arg} in source")

        typedargslist.append(
            get_annotated_param(
                hopefully_vararg, ast_args.vararg, missing_ok=True, flags=flags
            )
        )

    if ast_args.kwonlyargs:
        if not ast_args.vararg:
            if hopefully_vararg != _comma:
                raise ValueError(
                    f".pyi file expects keyword-only arguments but "
                    + f"*{str(hopefully_vararg).strip()} found in source"
                )

        typedargslist.extend(
            gen_annotated_params(
                ast_args.kwonlyargs,
                ast_args.kw_defaults,
                params,
                implicit_default=True,
                flags=flags,
            )
        )

    if ast_args.kwarg:
        try:
            hopefully_dstar, hopefully_kwarg = pop_param(params)
            if not hopefully_kwarg:
                raise ValueError
            if hopefully_kwarg.type == syms.tname:
                assert isinstance(hopefully_kwarg.children[0], Leaf)
                hopefully_kwarg_name = hopefully_kwarg.children[0].value
            else:
                assert isinstance(hopefully_kwarg, Leaf)
                hopefully_kwarg_name = hopefully_kwarg.value
            if hopefully_dstar != _dstar or hopefully_kwarg_name != ast_args.kwarg.arg:
                raise ValueError
        except (IndexError, ValueError):
            raise ValueError(
                f".pyi file expects **{ast_args.kwarg.arg} in source"
            ) from None
        else:
            typedargslist.append(new(_comma))
            typedargslist.append(new(hopefully_dstar))
            typedargslist.append(
                get_annotated_param(
                    hopefully_kwarg, ast_args.kwarg, missing_ok=True, flags=flags
                )
            )

    if params:
        extra_params = minimize_whitespace(
            str(Node(syms.typedargslist, [new(p) for p in params]))
        )
        raise ValueError(f"extra arguments in source: {extra_params}")

    if typedargslist:
        typedargslist = typedargslist[1:]  # drop the initial comma
        if len(typedargslist) == 1:
            # don't pack a single argument to be consistent with how lib2to3
            # parses existing code.
            body = typedargslist[0]
        else:
            body = Node(syms.typedargslist, typedargslist)
        parameters.children = [
            parameters.children[0],  # (
            body,
            parameters.children[-1],  # )
        ]
        for arg in parameters.pre_order():
            # remove now spurious type comments
            arg.prefix = maybe_space_before_comment(
                _type_comment_re.sub(r"\g<nl>", arg.prefix)
            )
    else:
        parameters.children = [
            parameters.children[0],  # (
            parameters.children[-1],  # )
        ]


def annotate_return(function, ast_returns, offset, flags):
    if ast_returns is None:
        if function[offset] == _colon:
            if flags.incremental:
                return

            raise ValueError(
                ".pyi file is missing return value and source doesn't "
                "provide it either"
            )
        elif function[offset] == _rarrow:
            # Source-provided return value, this is fine.
            return

        raise NotImplementedError(f"unexpected return token: {str(function[offset])!r}")

    ret_stmt = convert_annotation(ast_returns)
    ret_stmt.prefix = " "
    if function[offset] == _rarrow:
        function[offset + 1] = maybe_replace_any_if_equal(
            "return value", ret_stmt, function[offset + 1], flags
        )
    elif function[offset] == _colon:
        function.insert(offset, new(_rarrow))
        function.insert(offset + 1, ret_stmt)
    else:
        raise NotImplementedError(f"unexpected return token: {str(function[offset])!r}")


def get_function_signature(fun, *, is_method=False):
    """Returns (args, returns).

    `args` is ast3.arguments, `returns` is the return type AST node. The kicker
    about this function is that it pushes type comments into proper annotation
    fields, standardizing type handling.
    """
    args = fun.args
    returns = fun.returns
    if fun.type_comment:
        try:
            args_tc, returns_tc = parse_signature_type_comment(fun.type_comment)
            if returns and returns_tc:
                raise ValueError(
                    "using both a type annotation and a type comment is not allowed"
                )
            returns = returns_tc
            copy_arguments_to_annotations(args, args_tc, is_method=is_method)
        except (SyntaxError, ValueError) as exc:
            raise ValueError(
                f"Annotation problem in function {fun.name!r}: "
                + f"{fun.lineno}:{fun.col_offset + 1}: {exc}"
            )
    copy_type_comments_to_annotations(args)

    return args, returns


def parse_signature_type_comment(type_comment):
    """Parse the fugly signature type comment into AST nodes.

    Caveats: ASTifying **kwargs is impossible with the current grammar so we
    hack it into unary subtraction (to differentiate from Starred in vararg).

    For example from:
    "(str, int, *int, **Any) -> 'SomeReturnType'"

    To:
    ([ast3.Name, ast3.Name, ast3.Name, ast3.Name], ast3.Str)
    """
    try:
        result = ast3.parse(type_comment, "<func_type>", "func_type")
    except SyntaxError:
        raise ValueError(f"invalid function signature type comment: {type_comment!r}")

    assert isinstance(result, ast3.FunctionType)
    if len(result.argtypes) == 1:
        argtypes = result.argtypes[0]
    else:
        argtypes = result.argtypes
    return argtypes, result.returns


def parse_type_comment(type_comment):
    """Parse a type comment string into AST nodes."""
    try:
        result = ast3.parse(type_comment, "<type_comment>", "eval")
    except SyntaxError:
        raise ValueError(f"invalid type comment: {type_comment!r}") from None

    assert isinstance(result, ast3.Expression)
    return result.body
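

# For example (illustrative):
#
#     node = parse_type_comment("List[int]")   # an ast3 expression node
#     str(convert_annotation(node))            # -> "List[int]"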


def parse_arguments(arguments):
    """parse_arguments('(a, b, *, c=False, **d)') -> ast3.arguments

    Parse a string with function arguments into an AST node.
    """
    arguments = f"def f{arguments}: ..."
    try:
        result = ast3.parse(arguments, "<arguments>", "exec")
    except SyntaxError:
        raise ValueError(f"invalid arguments: {arguments!r}") from None

    assert isinstance(result, ast3.Module)
    assert len(result.body) == 1
    assert isinstance(result.body[0], ast3.FunctionDef)
    args = result.body[0].args
    copy_type_comments_to_annotations(args)
    return args


def copy_arguments_to_annotations(args, type_comment, *, is_method=False):
    """Copies AST nodes from `type_comment` into the ast3.arguments in `args`.

    Does validation of argument count (allowing for untyped self/cls)
    and type (vararg and kwarg).
    """
    if isinstance(type_comment, ast3.Ellipsis):
        return

    expected = len(args.args)
    if args.vararg:
        expected += 1
    expected += len(args.kwonlyargs)
    if args.kwarg:
        expected += 1
    actual = len(type_comment) if isinstance(type_comment, list) else 1
    if expected != actual:
        if is_method and expected - actual == 1:
            pass  # fine, we're just skipping `self`, `cls`, etc.
        else:
            raise ValueError(
                f"number of arguments in type comment doesn't match; "
                + f"expected {expected}, found {actual}"
            )

    if isinstance(type_comment, list):
        next_value = type_comment.pop
    else:
        # If there's just one value, only one of the loops and ifs below will
        # be populated. We ensure this with the expected/actual length check
        # above.
        _tc = type_comment

        def next_value(index: int = 0) -> ast3.expr:
            return _tc

    for arg in args.args[expected - actual :]:
        ensure_no_annotation(arg.annotation)
        arg.annotation = next_value(0)

    if args.vararg:
        ensure_no_annotation(args.vararg.annotation)
        args.vararg.annotation = next_value(0)

    for arg in args.kwonlyargs:
        ensure_no_annotation(arg.annotation)
        arg.annotation = next_value(0)

    if args.kwarg:
        ensure_no_annotation(args.kwarg.annotation)
        args.kwarg.annotation = next_value(0)


def copy_type_comments_to_annotations(args):
    """Copies argument type comments from the legacy long form to annotations
    in the entire function signature.
    """
    for arg in args.args:
        copy_type_comment_to_annotation(arg)

    if args.vararg:
        copy_type_comment_to_annotation(args.vararg)

    for arg in args.kwonlyargs:
        copy_type_comment_to_annotation(arg)

    if args.kwarg:
        copy_type_comment_to_annotation(args.kwarg)


def copy_type_comment_to_annotation(arg):
    if not arg.type_comment:
        return

    ann = parse_type_comment(arg.type_comment)
    ensure_no_annotation(arg.annotation)
    arg.annotation = ann


def maybe_replace_any_if_equal(name, expected, actual, flags):
    """Return the type given in `expected`.

    Raise ValueError if `expected` isn't equal to `actual`.  If --replace-any is
    used, the Any type in `actual` is considered equal.

    The implementation is naively checking if the string representation of
    `actual` is one of "Any", "typing.Any", or "t.Any".  This is done for two
    reasons:

    1. I'm lazy.
    2. We want people to be able to explicitly state that they want Any without it
       being replaced.  This way they can use an alias.
    """
    is_equal = expected == actual
    if not is_equal and flags.replace_any:
        actual_str = minimize_whitespace(str(actual))
        if actual_str and actual_str[0] in {'"', "'"}:
            actual_str = actual_str[1:-1]
        is_equal = actual_str in {"Any", "typing.Any", "t.Any"}

    if not is_equal:
        expected_annotation = minimize_whitespace(str(expected))
        actual_annotation = minimize_whitespace(str(actual))
        raise ValueError(
            f"incompatible existing {name}. "
            + f"Expected: {expected_annotation!r}, actual: {actual_annotation!r}"
        )

    return expected or actual
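

# For example (illustrative): with --replace-any, a source annotation of `Any`
# (or `typing.Any`, `t.Any`, including the quoted forms) is replaced by the
# stub's type; any other mismatch still raises ValueError.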


def ensure_no_annotation(ann):
    if ann:
        raise ValueError(
            f"using both a type annotation and a type comment is not allowed: {ann}"
        )


def ensure_annotations_equal(name, expected, actual, flags):
    """Raise ValueError if `expected` isn't equal to `actual`.

    If --replace-any is used, the Any type in `actual` is considered equal.
    """
    maybe_replace_any_if_equal(name, expected, actual, flags)


def remove_function_signature_type_comment(body):
    """Removes the legacy signature type comment, leaving other comments if any."""
    for node in body.children:
        if node.type == token.INDENT:
            prefix = node.prefix.lstrip()
            if prefix.startswith("# type: "):
                node.prefix = "\n".join(prefix.split("\n")[1:])
            break


def minimize_whitespace(text):
    return re.sub(r"[\n\t ]+", " ", text).strip()


def maybe_space_before_comment(text):
    if not text:
        return ""

    if text.startswith("#"):
        return "  " + text

    return text


def flatten_some(children):
    """Generates nodes or leaves, unpacking bodies of try:except:finally: statements."""
    for node in children:
        if node.type in (syms.try_stmt, syms.suite):
            yield from flatten_some(node.children)
        else:
            yield node
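

# For example (illustrative): for a module whose body is a try/except block,
# flatten_some yields the statements inside the suites rather than the opaque
# try_stmt node, so imports guarded by `try:` are still found.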


def pop_param(params):
    """Pops the parameter and the "remainder" (comma, default value).

    Returns a tuple of ('name', default) or (_star, 'name') or (_dstar, 'name').
    """
    default = None

    name = params.pop(0)
    if name in (_star, _dstar):
        default = params.pop(0)
        if default == _comma:
            return name, default

    try:
        remainder = params.pop(0)
        if remainder == _eq:
            default = params.pop(0)
            remainder = params.pop(0)
        if remainder != _comma:
            raise ValueError(f"unexpected token: {remainder}")

    except IndexError:
        pass
    return name, default
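

# For example (illustrative): for the parameter list of `def f(a=1, *args):`,
# successive pop_param calls return (NAME('a'), NUMBER('1')) and then
# (_star, NAME('args')).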


def gen_annotated_params(
    args, defaults, params, *, implicit_default=False, is_method=False, flags
):
    missing_ok = is_method or flags.incremental
    for arg, expected_default in zip(args, defaults):
        yield new(_comma)

        try:
            param, actual_default = pop_param(params)
        except IndexError:
            raise ValueError(
                f"missing regular argument {arg.arg!r} in source"
            ) from None

        if param in (_star, _dstar):
            # unexpected *args, keyword-only args, or **kwargs
            raise ValueError(f"missing regular argument {arg.arg!r} in source")

        if expected_default is None and actual_default is not None:
            if not implicit_default or actual_default != _none:
                param_s = minimize_whitespace(str(param))
                raise ValueError(
                    f".pyi file does not specify default value for arg "
                    + f"`{param_s}` but the source does"
                )

        if expected_default is not None and actual_default is None:
            param_s = minimize_whitespace(str(param))
            raise ValueError(
                f"source file does not specify default value for arg `{param_s}` "
                + f"but the .pyi file does"
            )

        node = get_annotated_param(param, arg, missing_ok=missing_ok, flags=flags)
        yield node
        if actual_default:
            whitespace = " " if node.type == syms.tname else ""
            yield new(_eq, prefix=whitespace)
            yield new(actual_default, prefix=whitespace)

        missing_ok = flags.incremental


def get_annotated_param(node, arg, *, missing_ok=False, flags):
    if node.type not in (token.NAME, syms.tname):
        raise NotImplementedError(f"unexpected node token: `{node}`")

    actual_ann = None
    if node.type == syms.tname:
        actual_ann = node.children[2]
        node = node.children[0]
    if not isinstance(node, Leaf) or arg.arg != node.value:
        raise ValueError(
            f".pyi file expects argument {arg.arg!r} next but argument "
            + f"{minimize_whitespace(str(node))!r} found in source"
        )

    if arg.annotation is None:
        if actual_ann is None:
            if missing_ok:
                return new(node)

            raise ValueError(
                f".pyi file is missing annotation for {arg.arg!r} and source "
                + f"doesn't provide it either"
            )

        ann = new(actual_ann)
    else:
        ann = convert_annotation(arg.annotation)
        ann.prefix = " "

    if actual_ann is not None:
        ensure_annotations_equal(
            "annotation for {arg.arg!r}", ann, actual_ann, flags=flags
        )

    return Node(syms.tname, [new(node), new(_colon), ann])


def get_offset_and_prefix(body, skip_assignments=False):
    """Returns the offset after which a statement can be inserted to the `body`.

    This offset is calculated to come after all imports, and maybe existing
    (possibly annotated) assignments if `skip_assignments` is True.

    Also returns the indentation prefix that should be applied to the inserted
    node.
    """
    assert body.type in (syms.file_input, syms.suite)

    _offset = 0
    prefix = ""
    for _offset, child in enumerate(body.children):
        if child.type == syms.simple_stmt:
            stmt = child.children[0]
            if stmt.type == syms.expr_stmt:
                expr = stmt.children
                if not skip_assignments:
                    break

                if (
                    len(expr) != 2
                    or expr[0].type != token.NAME
                    or expr[1].type != syms.annassign
                    or _eq in expr[1].children
                ):
                    break

            elif stmt.type not in (syms.import_name, syms.import_from, token.STRING):
                break

        elif child.type == token.INDENT:
            assert isinstance(child, Leaf)
            prefix = child.value
        elif child.type != token.NEWLINE:
            break

    prefix, child.prefix = child.prefix, prefix
    return _offset, prefix
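

# For example (illustrative): in a module that starts with a docstring followed
# by two imports, the returned offset points just past the second import; with
# skip_assignments=True, leading annotated names without values (`x: int`) are
# skipped as well.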


@singledispatch
def name_used_in_node(node, name):
    """Returns True if `name` appears in `node`. False otherwise."""


@name_used_in_node.register(Node)
def _nuin_node(node, name):
    for n in node.pre_order():
        if n == name:
            return True

    return False


@name_used_in_node.register(Leaf)
def _nuin_leaf(leaf, name):
    return leaf == name


def fix_line_numbers(body):
    r"""Recomputes all line numbers based on the number of \n characters."""
    maxline = 0
    for node in body.pre_order():
        maxline += node.prefix.count("\n")
        if isinstance(node, Leaf):
            node.lineno = maxline
            maxline += str(node.value).count("\n")


def new(n, prefix=None):
    """lib2to3's AST requires unique objects as children."""

    if isinstance(n, Leaf):
        return Leaf(n.type, n.value, prefix=n.prefix if prefix is None else prefix)

    # this is hacky, we assume complex nodes are just being reused once from the
    # original AST.
    n.parent = None
    if prefix is not None:
        n.prefix = prefix
    return n
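

# For example (illustrative):
#
#     new(_comma)    # a fresh COMMA leaf, prefix copied from the template
#     new(_eq, "")   # a fresh EQUAL leaf with the usual " " prefix dropped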


def _load_ignore(at_path, parent_spec, ignores):
    ignore_file = at_path / ".gitignore"
    if not ignore_file.exists():
        return parent_spec
    lines = ignore_file.read_text().splitlines()
    spec = PathSpec.from_lines("gitwildmatch", lines)
    spec = PathSpec(parent_spec.patterns + spec.patterns)
    ignores[at_path] = spec
    return spec


def walk_not_git_ignored(path, keep, extra_ignore):
    spec = PathSpec.from_lines("gitwildmatch", [".git"] + extra_ignore)
    ignores = {}
    # detect git folder, collect ignores up to root
    at = path
    while True:
        git_exist = (at / ".git").exists()
        if git_exist:
            ignores[at.parent] = spec
            # walk back down from the git root, loading .gitignore files along the way
            for part in (".",) + path.relative_to(at).parts:
                at = at / part
                spec = _load_ignore(at, spec, ignores)
            break
        if at == at.parent:
            ignores[path] = spec
            break
        at = at.parent

    # now walk from root, collect new ignores and evaluate
    for root, dirs, files in os.walk(str(path)):
        root_path = Path(root).relative_to(path)  # current dir, relative to `path`
        current_path = path / root
        parent_spec = ignores.get(current_path) or next(
            ignores[p] for p in current_path.parents if p in ignores
        )
        spec = _load_ignore(Path(root), parent_spec, ignores)
        for file_name in files:
            result = root_path / file_name
            if (
                file_name != ".gitignore"
                and keep(result)
                and not spec.match_file(str(result))
            ):
                yield path / result
        for cur_dir in list(dirs):
            if spec.match_file(str(root_path / cur_dir)):
                dirs.remove(cur_dir)


_as = Leaf(token.NAME, "as", prefix=" ")
_colon = Leaf(token.COLON, ":")
_comma = Leaf(token.COMMA, ",")
_dot = Leaf(token.DOT, ".")
_dstar = Leaf(token.DOUBLESTAR, "**")
_eq = Leaf(token.EQUAL, "=", prefix=" ")
_lpar = Leaf(token.LPAR, "(")
_lsqb = Leaf(token.LSQB, "[")
_newline = Leaf(token.NEWLINE, "\n")
_none = Leaf(token.NAME, "None")
_rarrow = Leaf(token.RARROW, "->", prefix=" ")
_rpar = Leaf(token.RPAR, ")")
_rsqb = Leaf(token.RSQB, "]")
_star = Leaf(token.STAR, "*")
_ellipsis = Node(syms.atom, children=[new(_dot), new(_dot), new(_dot)])

_type_comment_re = re.compile(
    r"""
    ^
    [\t ]*
    \#[ ]type:[ ]*
    (?P<type>
        [^#\t\n]+?
    )
    (?<!ignore)     # note: this will force the non-greedy + in <type> to match
                    # a trailing space which is why we need the silliness below
    (?<!ignore[ ]{1})(?<!ignore[ ]{2})(?<!ignore[ ]{3})(?<!ignore[ ]{4})
    (?<!ignore[ ]{5})(?<!ignore[ ]{6})(?<!ignore[ ]{7})(?<!ignore[ ]{8})
    (?<!ignore[ ]{9})(?<!ignore[ ]{10})
    [\t ]*
    (?P<nl>
        (?:\#[^\n]*)?
        \n?
    )
    $
    """,
    re.MULTILINE | re.VERBOSE,
)
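

# For example (illustrative): in a prefix of "  # type: List[int]  # noqa\n",
# the regex captures type == "List[int]" and nl == "# noqa\n"; every flavor of
# "# type: ignore" is deliberately left unmatched.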


__all__ = ("__version__", "retype_path", "retype_file")