"""Node transformers for manipulating the abstract syntax tree

Currently there are transformers for exploding functions, rewriting syntax,
and adding annotations.

"""

import ast
import copy
import doctest
import random
import string

import astor

# Names handed out by `__random_string__`, in creation order.
N = []
def __random_string__():
    """Return a fresh 10-letter lowercase name and remember it.

    The name is appended to the module-level list `N` so that a later pass
    can take it back out (`RestIterableFor` pops from the front).

    """
    letters = []
    for _ in range(10):
        letters.append(random.choice(string.ascii_lowercase))
    name = ''.join(letters)
    N.append(name)
    return name

def upcase(s):
    """Return `s` with its first character upper-cased.

    Unlike `str.capitalize`, the rest of the string is left untouched. An
    empty string is returned unchanged.

    >>> upcase('foo')
    'Foo'
    >>> upcase('fooBar')
    'FooBar'
    >>> upcase('')
    ''

    """
    # Slicing instead of indexing makes the empty string a no-op rather than
    # an IndexError.
    return s[:1].upper() + s[1:]

def make_annotation(node=None, buffer='outside', content=None, cell_type='code', lineno=None):
    """Return an ``ast.Expr`` statement that calls the ``__cell__`` hook.

    The generated statement looks like::

        __cell__(content, buffer, cell_type, lineno)

    where every argument is a string literal.

    Args:
        node: Optional AST node. When given, its source (rendered with
            ``astor``) becomes ``content``, and its ``lineno`` (if it has
            one) takes precedence over the ``lineno`` argument.
        buffer: Buffer the cell belongs to (converted with ``str``-formatting).
        content: Cell text; used only when ``node`` is not given.
        cell_type: Cell type tag, e.g. ``'code'``, ``'markdown'``, ``'1'``.
        lineno: Fallback line number; a falsy value becomes ``'-1'``.

    Returns:
        ast.Expr: the ``__cell__(...)`` call statement.

    """
    content = astor.to_source(node).strip() if node is not None else content
    if hasattr(node, 'lineno'):
        lineno = str(node.lineno)
    elif not lineno:
        lineno = str(-1)
    else:
        lineno = str(lineno)
    # ast.Constant replaces ast.Str, which has been deprecated since Python 3.8
    # and was removed in 3.14.
    call = ast.Call(
        func=ast.Name(id='__cell__', ctx=ast.Load()),
        args=[ast.Constant(value=arg) for arg in (content, f'{buffer}', cell_type, lineno)],
        keywords=[],
    )
    return ast.Expr(call)

class ExpressionFinder(ast.NodeTransformer):
    """Reduce a module to the node that starts on a given line number.

    The first node met in a parent-before-child walk whose ``lineno`` equals
    the requested line is kept; the rest of the module body is dropped.

    """

    def __init__(self, lineno):
        super(__class__, self).__init__()
        self.lineno = lineno
        self.target_node = None  # first node found on `self.lineno`

    def generic_visit(self, node):
        """Record the first node on the target line, then keep walking.

        A parent is checked before its children, so the outermost node that
        starts on the target line wins. The tree is not modified here.

        """
        if self.target_node is None and getattr(node, 'lineno', -1) == self.lineno:
            self.target_node = node
        return super().generic_visit(node)

    def visit_Module(self, module):
        """Replace the module body with the node found on ``self.lineno``.

        If nothing starts on that line (for example a blank or comment
        line), the body becomes empty rather than ``[None]``, which is not a
        valid statement list.

        >>> finder = ExpressionFinder(lineno=4)
        >>> code = '''
        ... foo()
        ... if 1:
        ...     if 2:
        ...         print(12)
        ... '''
        >>> module = finder.visit(ast.parse(code))
        >>> type(module.body[0]).__name__
        'If'

        """
        self.generic_visit(module)
        module.body = [] if self.target_node is None else [self.target_node]
        return module

class DefunFinder(ast.NodeTransformer):
    """Find the function or method which is defined at a particular line number.

    A match is reported by raising a plain ``Exception`` whose message is the
    qualified name: ``'func'`` for a function, or ``'Class.method'`` for a
    method defined directly in a class body. If nothing matches, the visit
    finishes normally and the tree is returned unchanged.

    Both ``def`` and ``async def`` are matched. The bodies of functions are
    not searched, and only the direct children of a class body are checked.

    """

    # Definition node types that can match.
    _DEF_TYPES = (ast.FunctionDef, ast.AsyncFunctionDef)

    def __init__(self, func_name, lineno):
        """
        Args:
            func_name: name of the function or method to look for.
            lineno: line number of its definition (``node.lineno``).

        """
        super(__class__, self).__init__()
        self.func_name = func_name
        self.lineno = lineno

    def _matches(self, func):
        """Return True if `func` has the target name and line number."""
        return func.name == self.func_name and func.lineno == self.lineno

    def visit_ClassDef(self, classdef):
        """Check the name and line number of each method in the class body.

        Raises:
            Exception: with message ``'<Class>.<method>'`` on a match.

        >>> finder = DefunFinder(func_name='bar', lineno=3)
        >>> code = '''
        ... class Foo:
        ...     def bar(self):
        ...         pass
        ... '''
        >>> finder.visit(ast.parse(code))
        Traceback (most recent call last):
          ...
        Exception: Foo.bar

        """
        for stmt in classdef.body:
            if isinstance(stmt, self._DEF_TYPES) and self._matches(stmt):
                raise Exception(f'{classdef.name}.{stmt.name}')
        return classdef

    def visit_FunctionDef(self, func):
        """Check whether this (non-method) function is the one being sought.

        Raises:
            Exception: with message ``'<func>'`` on a match.

        """
        if self._matches(func):
            raise Exception(func.name)
        return func

    # `async def` is checked exactly like `def`.
    visit_AsyncFunctionDef = visit_FunctionDef

class IPythonEmbedder(ast.NodeTransformer):
    """Replaces the body of a target with a forked ``IPython.start_kernel()``.

    The target is named by a dotted ``namespace``:

    * ``'module'``              -- the whole module body is replaced;
    * ``'module.func'``         -- the body of function ``func`` is replaced
      (methods are never touched in this mode);
    * ``'module.Class.method'`` -- the body of ``method`` in class ``Class``
      is replaced (free functions are never touched in this mode).

    """
    def __init__(self, namespace):
        """
        >>> embedder = IPythonEmbedder('foo.bar')
        >>> embedder.func_type, embedder.module, embedder.func_name
        ('function', 'foo', 'bar')

        Raises:
            ValueError: if ``namespace`` has more than three dotted parts.

        """
        super(__class__, self).__init__()
        self.namespace = namespace
        tokens = namespace.split('.')
        if len(tokens) == 1:
            self.module, = tokens
            self.func_type = 'module'
        elif len(tokens) == 2:
            self.module, self.func_name = tokens
            self.func_type = 'function'
        elif len(tokens) == 3:
            self.module, self.class_name, self.func_name = tokens
            self.func_type = 'method'
        else:
            raise ValueError(f'namespace must have at most 3 dotted parts, got {namespace!r}')

    @staticmethod
    def get_kernel_embed():
        """A list of kernel embed nodes

        Returns:
            nodes (list): AST nodes which form the following code.

            ```
            import os
            pid = os.fork()
            if pid == 0:
                open(f'{os.environ["HOME"]}/.pynt', 'a').close()
                import IPython
                IPython.start_kernel(user_ns={**locals(), **globals(), **vars()})
            os.waitpid(pid, 0)
            ```

        Fresh nodes are built on every call, so the result may be mutated.

        NOTE(review): the child (``pid == 0``) is not terminated after
        ``start_kernel`` returns, so it also reaches ``os.waitpid(pid, 0)``
        and then returns from the embedding site -- confirm this is intended.

        """
        return [
            ast.Import(names=[ast.alias(name='os', asname=None)]),
            ast.Assign(
                targets=[ast.Name(id='pid', ctx=ast.Store())],
                value=ast.Call(
                    func=ast.Attribute(value=ast.Name(id='os', ctx=ast.Load()), attr='fork', ctx=ast.Load()),
                    args=[],
                    keywords=[],
                ),
            ),
            ast.If(
                test=ast.Compare(
                    left=ast.Name(id='pid', ctx=ast.Load()),
                    ops=[ast.Eq()],
                    comparators=[ast.Constant(value=0)],
                ),
                body=[
                    # open(f'{os.environ["HOME"]}/.pynt', 'a').close()
                    ast.Expr(value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Call(
                                func=ast.Name(id='open', ctx=ast.Load()),
                                args=[
                                    ast.JoinedStr(values=[
                                        ast.FormattedValue(
                                            value=ast.Subscript(
                                                value=ast.Attribute(
                                                    value=ast.Name(id='os', ctx=ast.Load()),
                                                    attr='environ',
                                                    ctx=ast.Load(),
                                                ),
                                                slice=ast.Index(value=ast.Constant(value='HOME')),
                                                ctx=ast.Load(),
                                            ),
                                            conversion=-1,
                                            format_spec=None,
                                        ),
                                        ast.Constant(value='/.pynt'),
                                    ]),
                                    ast.Constant(value='a'),
                                ],
                                keywords=[],
                            ),
                            attr='close',
                            ctx=ast.Load(),
                        ),
                        args=[],
                        keywords=[],
                    )),
                    ast.Import(names=[ast.alias(name='IPython', asname=None)]),
                    # IPython.start_kernel(user_ns={**locals(), **globals(), **vars()})
                    ast.Expr(value=ast.Call(
                        func=ast.Attribute(
                            value=ast.Name(id='IPython', ctx=ast.Load()),
                            attr='start_kernel',
                            ctx=ast.Load(),
                        ),
                        args=[],
                        keywords=[
                            ast.keyword(arg='user_ns', value=ast.Dict(
                                keys=[None, None, None],  # None key == `**mapping`
                                values=[
                                    ast.Call(func=ast.Name(id='locals', ctx=ast.Load()), args=[], keywords=[]),
                                    ast.Call(func=ast.Name(id='globals', ctx=ast.Load()), args=[], keywords=[]),
                                    ast.Call(func=ast.Name(id='vars', ctx=ast.Load()), args=[], keywords=[]),
                                ],
                            )),
                        ],
                    )),
                ],
                orelse=[],
            ),
            ast.Expr(value=ast.Call(
                func=ast.Attribute(value=ast.Name(id='os', ctx=ast.Load()), attr='waitpid', ctx=ast.Load()),
                args=[ast.Name(id='pid', ctx=ast.Load()), ast.Constant(value=0)],
                keywords=[],
            )),
        ]

    def _embed(self, func):
        """Swap `func`'s body for the kernel-embedding statements; return `func`."""
        func.body = self.get_kernel_embed()
        return func

    def visit_Module(self, module):
        """Maybe replace the entire module with a kernel

        If namespace is targeting the top-level then we do it; otherwise the
        module's children are visited.

        >>> embedder = IPythonEmbedder(namespace='foo')
        >>> module = embedder.visit(ast.parse('import random'))
        >>> type(module.body[0]).__name__, len(module.body)
        ('Import', 4)

        """
        if self.func_type == 'module':
            module.body = self.get_kernel_embed()
        else:
            module = self.generic_visit(module)
        return module

    def visit_ClassDef(self, classdef):
        """Embed a kernel into method ``self.func_name`` of ``self.class_name``.

        Any other class, or any class when not in ``'method'`` mode, is
        returned untouched (its methods are not visited).

        Raises:
            ValueError: if the target class does not define exactly one
                method named ``self.func_name``.

        >>> embedder = IPythonEmbedder(namespace='ast_server.Foo.biz')
        >>> code = '''
        ... class Foo:
        ...     x = 1
        ...     def bar(self):
        ...         pass
        ...     def biz(self):
        ...         pass
        ... '''
        >>> classdef = embedder.visit(ast.parse(code)).body[0]
        >>> [type(fn.body[0]).__name__ for fn in classdef.body[1:]]
        ['Pass', 'Import']

        """
        if self.func_type != 'method' or self.class_name != classdef.name:
            return classdef
        # Index into `classdef.body` itself: docstrings and class attributes
        # may precede the methods, so an index into a filtered method list
        # would point at the wrong statement.
        matches = [
            idx for idx, stmt in enumerate(classdef.body)
            if isinstance(stmt, ast.FunctionDef) and stmt.name == self.func_name
        ]
        if len(matches) != 1:
            raise ValueError(
                f'expected exactly one method {self.func_name!r} in class '
                f'{classdef.name!r}, found {len(matches)}'
            )
        idx, = matches
        classdef.body[idx] = self._embed(classdef.body[idx])
        return classdef

    def visit_FunctionDef(self, func):
        """Embed a kernel into function ``self.func_name`` (``'function'`` mode).

        Methods are handled by `visit_ClassDef`, which never recurses into
        class bodies through this visitor. This method therefore only
        embeds in ``'function'`` mode, so a free function that shares its
        name with a target *method* is left alone.

        >>> embedder = IPythonEmbedder(namespace='ast_server.foo')
        >>> code = '''
        ... x
        ... def foo():
        ...     return 1
        ... '''
        >>> func = embedder.visit(ast.parse(code)).body[1]
        >>> type(func.body[0]).__name__
        'Import'

        """
        if self.func_type == 'function' and func.name == self.func_name:
            return self._embed(func)
        return func

class NamespacePromoter(ast.NodeTransformer):
    """Takes a body of a function and pushes it into the global namespace

    Every ``def`` it promotes is replaced by a flat list of statements:
    annotation cells, statements binding the parameters (from default values
    and doctest examples), then the function body with each ``return``
    rewritten by `visit_Return`. Functions defined *inside* a promoted
    function are left intact, because their ``return`` statements belong to
    them rather than to the promoted body.

    """

    def __init__(self, buffer):
        super(__class__, self).__init__()
        self.buffer = buffer
        # True while the body of a function being promoted is visited.
        self._promoting = False

    def visit_Return(self, return_):
        """Convert returns into assignment/exception pairs

        Since the body of this function will be in the global namespace we
        can't have any returns. An acceptable alternative is to set a variable
        called 'RETURN' and then immediately raise an exception. A bare
        ``return`` binds ``RETURN`` to ``None``.

        >>> promoter = NamespacePromoter(buffer='foo')
        >>> module = promoter.visit(ast.parse('return 5'))
        >>> [type(node).__name__ for node in module.body]
        ['Assign', 'Raise']

        """
        # A bare `return` has `value=None`, which is not a valid Assign value.
        value = ast.Constant(value=None) if return_.value is None else return_.value
        return [
            ast.Assign(targets=[ast.Name(id='RETURN', ctx=ast.Store())], value=value, lineno=return_.lineno),
            ast.Raise(
                exc=ast.Call(func=ast.Name(id='Exception', ctx=ast.Load()), args=[ast.Constant(value='return')], keywords=[]),
                cause=None,
            ),
        ]

    @staticmethod
    def _bind_default(name, value):
        """Build ``try: name`` / ``except NameError: name = value``.

        An existing binding of ``name`` is kept; otherwise the default is used.

        """
        return ast.Try(
            body=[ast.Expr(value=ast.Name(id=name, ctx=ast.Load()))],
            handlers=[
                ast.ExceptHandler(
                    type=ast.Name(id='NameError', ctx=ast.Load()),
                    name=None,
                    body=[ast.Assign(targets=[ast.Name(id=name, ctx=ast.Store())], value=value)],
                ),
            ],
            orelse=[],
            finalbody=[],
        )

    def visit_AsyncFunctionDef(self, func):
        """Leave an ``async def`` nested in a promoted function untouched."""
        if getattr(self, '_promoting', False):
            return func
        return self.generic_visit(func)

    def visit_FunctionDef(self, func):
        """Roll out a function definition

        The ``def`` is replaced by, in order:

        1. an annotation cell with the function name (cell type ``'1'``);
        2. when there is a docstring, a markdown cell with the text that
           precedes its first doctest example;
        3. for each parameter with a default value,
           ``try: name`` / ``except NameError: name = <default>``;
        4. every statement of the docstring's doctest examples, which
           therefore override the defaults;
        5. an ``Arguments`` cell followed by one expression per parameter;
        6. a ``Body`` cell followed by the body, with returns rewritten.

        A ``def`` nested inside a function being promoted is returned as-is.

        """
        if getattr(self, '_promoting', False):
            return func

        # Transform returns -- only those belonging to this function.
        self._promoting = True
        try:
            func = self.generic_visit(func)
        finally:
            self._promoting = False

        # Extract doctests.
        docstring = ast.get_docstring(func, clean=True)
        docstring_prefix, docstring_stmts = None, []
        if docstring:
            func.body = func.body[1:]
            results = doctest.DocTestParser().parse(docstring)
            # The leading prose (before the first example) becomes markdown.
            docstring_prefix = results[0].strip()
            for example in results:
                if isinstance(example, doctest.Example):
                    # Keep every statement, e.g. both halves of `a = 1; b = 2`.
                    docstring_stmts.extend(ast.parse(example.source.strip()).body)

        # Insert function name and docstring sans doctests.
        exprs = [
            make_annotation(buffer=self.buffer, content=f'`{func.name}`', cell_type='1', lineno=func.lineno),
        ]
        if docstring:
            exprs.append(make_annotation(buffer=self.buffer, content=docstring_prefix, cell_type='markdown'))

        # Keyword (default) values. `defaults` lines up with the *last*
        # positional parameters (positional-only ones included), while
        # `kw_defaults` holds None for keyword-only parameters without one.
        args = func.args
        positional = list(getattr(args, 'posonlyargs', [])) + list(args.args)
        defaulted = list(zip(reversed(positional), reversed(args.defaults)))
        defaulted += [(arg, value) for arg, value in zip(args.kwonlyargs, args.kw_defaults) if value is not None]
        exprs.extend(self._bind_default(arg.arg, value) for arg, value in defaulted)

        # Docstring values override keyword values.
        exprs.extend(docstring_stmts)

        # Final dump of all arguments. `ast.arg` is not an expression, so each
        # parameter is referenced through an `ast.Name`.
        exprs.append(make_annotation(buffer=self.buffer, content='Arguments', cell_type='1'))
        exprs.extend(
            ast.Expr(value=ast.Name(id=arg.arg, ctx=ast.Load()))
            for arg in positional + list(args.kwonlyargs)
        )

        exprs.append(make_annotation(buffer=self.buffer, content='Body', cell_type='1'))

        return exprs + func.body

class UnpackTry(ast.NodeTransformer):
    """Flatten a ``try`` statement into a straight run of annotated sections."""

    def __init__(self, buffer, only_try):
        self.buffer = buffer
        self.only_try = only_try

    def visit_Try(self, tryexp):
        """Unpack a try/except statement.

        Produces a ``try`` header cell followed by the ``try`` body. Unless
        ``only_try`` is set, each handler contributes an ``except ...`` header
        and its body, followed by the ``else`` and ``finally`` sections when
        present. Every header carries the line number of the ``try``.

        >>> self = UnpackTry('foo', False)
        >>> code = '''
        ...
        ... try:
        ...     data = requests.get(url, stream=True).raw
        ... except:
        ...     print('error')
        ...
        ... '''
        >>> module = ast.parse(code)
        >>> tryexp, = module.body

        """
        def header(text):
            return make_annotation(buffer=self.buffer, content=text, cell_type='2', lineno=tryexp.lineno)

        flat = [header('try'), *tryexp.body]
        if self.only_try:
            return flat

        for handler in tryexp.handlers:
            words = ['except']
            if handler.type:
                words.append(astor.to_source(handler.type).strip())
            if handler.name:
                words += ['as', handler.name]
            flat.append(header(' '.join(words)))
            flat += handler.body

        for keyword, section in (('else', tryexp.orelse), ('finally', tryexp.finalbody)):
            if section:
                flat.append(header(keyword))
                flat += section
        return flat

class UnpackIf(ast.NodeTransformer):
    def __init__(self, buffer):
        self.buffer = buffer

    def visit_If(self, ifexp):
        """Pure syntax rewrite of a for loop

        Unroll only the first iteration through the loop.

        >>> self = FirstPassForSimple('foo')
        >>> code = '''
        ...
        ... if 1:
        ...     print(1)
        ... elif 2:
        ...     print(2)
        ... elif 3:
        ...     print(3)
        ... else:
        ...     print(4)
        ...
        ... '''
        >>> module = ast.parse(code)
        >>> ifexp, = module.body

        """
        # ifexp = self.generic_visit(ifexp)
        nodes = []
        content = f'if {astor.to_source(ifexp.test).strip()}'
        nodes.append(make_annotation(buffer=self.buffer, content=content, cell_type='2', lineno=ifexp.lineno))
        nodes.extend([ifexp.test])
        nodes.extend(ifexp.body)
        nodes.extend(ifexp.orelse)
        return nodes

class FirstPassForSimple(ast.NodeTransformer):
    """Unroll only the first iteration of a ``for`` loop.

    ``break`` and ``continue`` statements visited by this transformer become
    ``pass``. An empty iterable is not guarded against: the generated
    ``next()`` call raises ``StopIteration``.

    """
    def __init__(self, buffer):
        self.buffer = buffer

    def visit_Continue(self, cont):
        """Replace ``continue`` with ``pass``.

        >>> self = FirstPassForSimple('foo')
        >>> cont = ast.parse('for i in x:\\n    continue').body[0].body[0]
        >>> type(self.visit(cont)).__name__
        'Pass'

        """
        return ast.Pass()

    def visit_Break(self, broken):
        """Replace ``break`` with ``pass``.

        >>> self = FirstPassForSimple('foo')
        >>> broken = ast.parse('for i in x:\\n    break').body[0].body[0]
        >>> type(self.visit(broken)).__name__
        'Pass'

        """
        return ast.Pass()

    def visit_For(self, loop):
        """Pure syntax rewrite of a for loop

        Unroll only the first iteration through the loop: an annotation cell,
        the iterable as an expression statement, ``target = next(iter(iterable))``
        and then the loop body. The ``else`` clause is dropped.

        NOTE(review): the iterable expression is evaluated twice (once as an
        expression statement, once inside ``iter()``).

        """
        # iter(<iterable>) -- pass the iterable expression itself; it is an
        # AST node, not an identifier, so it cannot be a Name's `id`.
        iter_call = ast.Call(
            func=ast.Name(id='iter', ctx=ast.Load()),
            args=[loop.iter],
            keywords=[]
        )

        # <target> = next(iter(<iterable>))
        get_first = ast.Assign(
            targets=[loop.target],
            value=ast.Call(
                func=ast.Name(id='next', ctx=ast.Load()),
                args=[iter_call],
                keywords=[]
            )
        )
        content = f'`for {astor.to_source(loop.target).strip()} in {astor.to_source(loop.iter).strip()} ...`'
        nodes = [
            make_annotation(buffer=self.buffer, content=content, cell_type='2', lineno=loop.lineno),
            ast.Expr(loop.iter),
            get_first,
        ]
        nodes.extend(loop.body)
        return nodes

class FirstPassFor(ast.NodeTransformer):
    """Unroll the first iteration of ``for`` loops.

    Each loop is replaced by statements that store an iterator over the
    loop's iterable under a random name, bind the loop target to its first
    element, and run the body once (skipped when the iterable is empty).
    The random name comes from `__random_string__`, which also records it
    in `N`. `RestIterableFor` pops names from `N` to resume the same
    iterators. Nested loops are unrolled first.

    """
    def __init__(self, buffer):
        self.buffer = buffer

    def visit_For(self, loop_):
        """
        >>> self = FirstPassFor(buffer='foo')
        >>> code = '''
        ...
        ... for i in range(5):
        ...     for j in range(5):
        ...         k = i + j
        ...         print(k)
        ...
        ... '''
        >>> tree = ast.parse(code)
        >>> loop_, = tree.body

        """
        # Children first, so inner loops record their names before this one.
        loop = self.generic_visit(loop_)
        iter_name = __random_string__()
        # <iter_name> = iter(<iterable>)
        assign = ast.Assign(
            targets=[ast.Name(id=iter_name, ctx=ast.Store())],
            value=ast.Call(func=ast.Name(id='iter', ctx=ast.Load()), args=[loop.iter], keywords=[]),
        )
        # try: <target> = next(<iter_name>)
        # except StopIteration: pass
        # else: <body>
        first_pass = ast.Try(
            body=[ast.Assign(
                targets=[loop.target],
                value=ast.Call(
                    func=ast.Name(id='next', ctx=ast.Load()),
                    # Name ids must be strings: use the generated name, not the Name node.
                    args=[ast.Name(id=iter_name, ctx=ast.Load())],
                    keywords=[],
                ),
            )],
            handlers=[ast.ExceptHandler(type=ast.Name(id='StopIteration', ctx=ast.Load()), name=None, body=[ast.Pass()])],
            orelse=loop.body,
            finalbody=[]
        )
        content = f'`for {astor.to_source(loop.target).strip()} in {astor.to_source(loop.iter).strip()} ...`'
        # NOTE(review): the iterable expression is evaluated twice (the bare
        # expression statement and the iter() call).
        return [
            make_annotation(buffer=self.buffer, content=content, cell_type='2', lineno=loop.lineno),
            ast.Expr(loop.iter),
            assign,
            first_pass
        ]

class RestIterableFor(ast.NodeTransformer):
    """Point ``for`` loops at iterators created by an earlier pass.

    Each loop's iterable is replaced by a Name popped from the front of the
    module-level list `N`. Loops are processed children-first, matching the
    order in which `FirstPassFor` records names. `SyntaxRewriter` runs both
    passes on copies of the same loop.

    """
    def __init__(self, buffer):
        self.buffer = buffer

    def visit_For(self, loop_):
        """
        >>> self = RestIterableFor('foo')
        >>> code = '''
        ...
        ... for i in range(5):
        ...     for j in range(5):
        ...         k = i + j
        ...         print(k)
        ... '''
        >>> tree = ast.parse(code)
        >>> loop_, = tree.body

        """
        loop = self.generic_visit(loop_)
        varname = N.pop(0)
        # The iterable is read, so it needs a Load context; a Store context
        # makes compile() reject the tree.
        loop.iter = ast.Name(id=varname, ctx=ast.Load())
        return loop

class SyntaxRewriter(ast.NodeTransformer):
    """Performs pure syntax rewrites

    Currently the only rewrite splits each ``for`` loop into its unrolled
    first iteration (`FirstPassFor`) followed by the original loop re-pointed
    at the same iterator (`RestIterableFor`). Future rewrites include context
    managers and decorators.

    """
    def __init__(self, buffer):
        super(__class__, self).__init__()
        self.buffer = buffer

    def visit_For(self, loop):
        """Split a ``for`` loop into its first iteration and the remainder.

        Each pass works on its own deep copy of ``loop``. `FirstPassFor`
        records one iterator name per loop in `N`, and `RestIterableFor`
        pops them in the same order.

        >>> self = SyntaxRewriter(buffer='foo')
        >>> code = '''
        ...
        ... for i in range(5):
        ...     for j in range(5):
        ...         k = i + j
        ...         print(k)
        ...
        ... '''
        >>> tree = ast.parse(code)
        >>> loop, = tree.body

        """
        first = FirstPassFor(self.buffer).visit(copy.deepcopy(loop))
        rest = RestIterableFor(self.buffer).visit(copy.deepcopy(loop))
        return first + [rest]

class ShallowAnnotator(ast.NodeTransformer):
    """Does a shallow annotation on the code given to it

    The module is traversed. Every other node reached is replaced by a
    ``__cell__`` annotation carrying its source, and assignments also append
    the source of their first target. Expressions that are already
    ``__cell__`` calls pass through unchanged.

    """
    def __init__(self, buffer):
        super().__init__()
        self.buffer = buffer

    def visit_Assign(self, assign):
        """Annotate an assignment, appending the source of its first target."""
        source = astor.to_source(assign)
        target = astor.to_source(assign.targets[0]).strip()
        return make_annotation(
            buffer=self.buffer,
            content=source + target,
            lineno=getattr(assign, 'lineno', None),
        )

    def visit_Expr(self, expr):
        """Annotate an expression statement unless it already is an annotation.

        An annotation is recognised as an expression whose value is a call to
        a plain name ``__cell__``. Calls through attributes (``a.b()``) have
        a ``func`` without ``id`` and so never count. Annotations may already
        be present because `NamespacePromoter` inserts them.

        """
        call = getattr(expr, 'value', None)
        is_annotation = isinstance(call, ast.Call) and getattr(call.func, 'id', None) == '__cell__'
        return expr if is_annotation else make_annotation(expr, buffer=self.buffer)

    def generic_visit(self, node):
        """Traverse the module; wrap any other unhandled node in an annotation."""
        if not isinstance(node, ast.Module):
            return make_annotation(node, buffer=self.buffer)
        return super().generic_visit(node)


class DeepAnnotator(ShallowAnnotator):
    """Annotates code with commands to create jupyter notebook cells

    Unlike `ShallowAnnotator`, the original statements are kept next to their
    annotations. The bodies of ``if`` and ``try`` statements are annotated
    recursively.

    """

    def _annotate_nodes(self, nodes):
        """Visit every node in ``nodes`` and return the flattened results.

        A visitor may return a single node or a list of nodes. Lists are
        spliced in so the result is a flat statement list.

        """
        exprs = []
        for node in nodes:
            new_nodes = self.visit(node)
            if isinstance(new_nodes, list):
                exprs.extend(new_nodes)
            else:
                exprs.append(new_nodes)
        return exprs

    def visit_If(self, iff):
        """Emit an ``if`` header cell and a cell for the test, then an
        ``if`` whose branches are annotated recursively."""
        header = f'`if {astor.to_source(iff.test).strip()} ...`'
        return [
            # Carry the line number like the other header cells in this module.
            make_annotation(buffer=self.buffer, content=header, cell_type='2', lineno=getattr(iff, 'lineno', None)),
            make_annotation(iff.test, buffer=self.buffer),
            ast.If(
                test=iff.test,
                body=self._annotate_nodes(iff.body),
                orelse=self._annotate_nodes(iff.orelse)
            )
        ]

    def visit_Try(self, try_):
        """Rebuild a ``try`` with every section annotated recursively."""
        handlers = [
            ast.ExceptHandler(
                type=handler.type,
                # Keep `except E as name`: handler bodies may use the name.
                name=handler.name,
                body=self._annotate_nodes(handler.body)
            )
            for handler in try_.handlers
        ]
        return ast.Try(
            body=self._annotate_nodes(try_.body),
            handlers=handlers,
            orelse=self._annotate_nodes(try_.orelse),
            finalbody=self._annotate_nodes(try_.finalbody)
        )

    def visit_Assign(self, assign):
        """Keep the assignment and follow it with its annotation.

        The annotation comes from `ShallowAnnotator.visit_Assign`, which
        appends the assigned target to the source.

        """
        annotated_assign = super().visit_Assign(assign)
        return [assign, annotated_assign]

    def visit_Expr(self, expr):
        """Precede an expression statement with its annotation.

        Existing ``__cell__`` annotations come back unchanged from the parent
        class and are not annotated again.

        """
        annotated_expr = super().visit_Expr(expr)
        if annotated_expr is expr:
            return expr
        return [annotated_expr, expr]

    def generic_visit(self, node):
        """Catch-all for nodes that slip through

        Anything without a dedicated visitor (e.g. context managers, loops)
        is preceded by an annotation with its source and kept as-is; its
        nested body is not annotated. The module itself is simply traversed.

        """
        if isinstance(node, ast.Module):
            return super().generic_visit(node)
        else:
            return [make_annotation(node, buffer=self.buffer), node]