#  Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS-IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
import textwrap
from typing import List, Tuple, Dict, Optional, Union, Callable, Iterator, Set

import ast

from _py2tmp.compiler.output_files import ObjectFileContent
from _py2tmp.coverage import SourceBranch
from _py2tmp.ir2 import ir2, get_free_variables, get_return_type


class Symbol:
    """
    A named entity known to the compiler, together with its type.

    `source_module` is the module the symbol was imported from; callers in this file pass None for symbols that were
    not imported. Only symbols whose type is an ir2.FunctionType may be flagged as possibly throwing.
    """
    def __init__(self,
                 name: str,
                 expr_type: ir2.ExprType,
                 is_function_that_may_throw: bool,
                 source_module: Optional[str]):
        # "may throw" only makes sense for functions.
        assert not is_function_that_may_throw or isinstance(expr_type, ir2.FunctionType)
        self.name = name
        self.expr_type = expr_type
        self.source_module = source_module
        self.is_function_that_may_throw = is_function_that_may_throw

class SymbolLookupResult:
    """
    What a successful symbol lookup returns: the symbol, the AST node it was registered with, whether it is only
    partially defined, and the symbol table in which the definition was found.
    """
    def __init__(self,
                 symbol: Symbol,
                 ast_node: ast.AST,
                 is_only_partially_defined: bool,
                 symbol_table: 'SymbolTable'):
        # The table that holds the definition (not necessarily the one the lookup started from).
        self.symbol_table = symbol_table
        self.symbol = symbol
        self.ast_node = ast_node
        self.is_only_partially_defined = is_only_partially_defined

class SymbolTable:
    """
    One scope of symbol definitions, optionally chained to the table of an enclosing scope via `parent`.
    """
    def __init__(self, parent: 'SymbolTable' = None):
        self.parent = parent
        # name -> (symbol, AST node passed when the symbol was added, is_only_partially_defined)
        self.symbols_by_name: Dict[str, Tuple[Symbol, ast.AST, bool]] = dict()

    def get_symbol_definition(self, name: str):
        """
        Looks `name` up in this table and then, if not found, in the parent tables.

        Returns a SymbolLookupResult (whose symbol_table is the table that holds the definition), or None if no table
        in the chain defines `name`.
        """
        entry = self.symbols_by_name.get(name)
        if not entry:
            # Not defined in this scope: defer to the enclosing scope, if there is one.
            return self.parent.get_symbol_definition(name) if self.parent else None
        symbol, ast_node, is_only_partially_defined = entry
        return SymbolLookupResult(symbol, ast_node, is_only_partially_defined, self)

    def add_symbol(self,
                   name: str,
                   expr_type: ir2.ExprType,
                   definition_ast_node: ast.AST,
                   is_only_partially_defined: bool,
                   is_function_that_may_throw: bool,
                   source_module: Optional[str]):
        """
        Registers (or overwrites) the definition of `name` in this table. No already-defined check is done here.
        """
        assert not is_function_that_may_throw or isinstance(expr_type, ir2.FunctionType)
        new_symbol = Symbol(name, expr_type, is_function_that_may_throw, source_module)
        self.symbols_by_name[name] = (new_symbol, definition_ast_node, is_only_partially_defined)

class CompilationContext:
    """
    The state threaded through the AST -> ir2 conversion for one scope.

    It bundles the symbol table of the current scope, the table of custom types, the symbols importable from other
    modules, the source being compiled (used for diagnostics), an identifier generator, and information about the
    enclosing function / except statement (if any).

    `partially_typechecked_function_definitions_by_name` holds the functions whose return type is not known yet; the
    same dict object is shared between a context and the child contexts created from it.
    """
    def __init__(self,
                 symbol_table: SymbolTable,
                 custom_types_symbol_table: SymbolTable,
                 external_ir2_symbols_by_name_by_module: Dict[str, Dict[str, Union[ir2.FunctionDefn, ir2.CustomType]]],
                 filename: str,
                 source_lines: List[str],
                 identifier_generator: Iterator[str],
                 function_name: Optional[str] = None,
                 function_definition_line: Optional[int] = None,
                 first_enclosing_except_stmt_line: Optional[int] = None,
                 partially_typechecked_function_definitions_by_name: Optional[Dict[str, ast.FunctionDef]] = None):
        # The function name and its definition line must be either both provided or both omitted.
        assert (function_name is None) == (function_definition_line is None)
        # Compare against None explicitly: an empty dict is falsy, so `x or dict()` would replace a (still empty) dict
        # passed by the caller with a fresh one, and the caller would silently stop seeing updates made through
        # this context.
        if partially_typechecked_function_definitions_by_name is None:
            partially_typechecked_function_definitions_by_name = dict()
        self.symbol_table = symbol_table
        self.custom_types_symbol_table = custom_types_symbol_table
        self.external_ir2_symbols_by_name_by_module = external_ir2_symbols_by_name_by_module
        self.partially_typechecked_function_definitions_by_name = partially_typechecked_function_definitions_by_name
        self.filename = filename
        self.source_lines = source_lines
        self.current_function_name = function_name
        self.current_function_definition_line = function_definition_line
        self.first_enclosing_except_stmt_line = first_enclosing_except_stmt_line
        self.identifier_generator = identifier_generator

    def create_child_context(self,
                             function_name: Optional[str] = None,
                             function_definition_line: Optional[int] = None,
                             first_enclosing_except_stmt_line: Optional[int] = None):
        """
        Creates a context for a nested scope.

        The child gets a new SymbolTable whose parent is this context's table and shares everything else with this
        context. The function name/line and the enclosing except line are inherited unless overridden here.
        """
        assert (function_name is None) == (function_definition_line is None)
        return CompilationContext(SymbolTable(parent=self.symbol_table),
                                  self.custom_types_symbol_table,
                                  self.external_ir2_symbols_by_name_by_module,
                                  self.filename,
                                  self.source_lines,
                                  self.identifier_generator,
                                  function_name=function_name or self.current_function_name,
                                  function_definition_line=function_definition_line or self.current_function_definition_line,
                                  first_enclosing_except_stmt_line=first_enclosing_except_stmt_line or self.first_enclosing_except_stmt_line,
                                  partially_typechecked_function_definitions_by_name=self.partially_typechecked_function_definitions_by_name)

    def add_symbol(self,
                   name: str,
                   expr_type: ir2.ExprType,
                   definition_ast_node: ast.AST,
                   is_only_partially_defined: bool,
                   is_function_that_may_throw: bool,
                   source_module: Optional[str] = None):
        """
        Adds a symbol to the symbol table of the current scope.

        This raises a CompilationError if `name` is already known, i.e. if it's found in the symbol table (including
        the tables of the enclosing scopes), in the custom types table or among the functions whose return type is
        still unknown. The type of the previous definition is not taken into account.
        """
        if is_function_that_may_throw:
            assert isinstance(expr_type, ir2.FunctionType)

        self._check_not_already_defined(name, definition_ast_node)

        self.symbol_table.add_symbol(name=name,
                                     expr_type=expr_type,
                                     definition_ast_node=definition_ast_node,
                                     is_only_partially_defined=is_only_partially_defined,
                                     is_function_that_may_throw=is_function_that_may_throw,
                                     source_module=source_module)

    def add_custom_type_symbol(self,
                               custom_type: ir2.CustomType,
                               definition_ast_node: Union[ast.ClassDef, ast.ImportFrom, None],
                               source_module: Optional[str] = None):
        """
        Registers a custom type, both as a (non-throwing) constructor function in the symbol table and as a type in
        the custom types table.

        Raises a CompilationError if the name is already defined (see add_symbol).
        """
        # NOTE(review): source_module is only recorded in the custom types table, not on the constructor symbol;
        # confirm that this is intended.
        self.add_symbol(name=custom_type.name,
                        expr_type=ir2.FunctionType(argtypes=tuple(arg.expr_type
                                                                  for arg in custom_type.arg_types),
                                                   argnames=tuple(arg.name
                                                                  for arg in custom_type.arg_types),
                                                   returns=custom_type),
                        definition_ast_node=definition_ast_node,
                        is_only_partially_defined=False,
                        is_function_that_may_throw=False)
        self.custom_types_symbol_table.add_symbol(name=custom_type.name,
                                                  expr_type=custom_type,
                                                  definition_ast_node=definition_ast_node,
                                                  is_only_partially_defined=False,
                                                  is_function_that_may_throw=False,
                                                  source_module=source_module)

    def add_symbol_for_function_with_unknown_return_type(self,
                                                         name: str,
                                                         definition_ast_node: ast.FunctionDef):
        """
        Records a function whose return type is not known yet; set_function_type() completes the definition later.

        Raises a CompilationError if the name is already defined.
        """
        self._check_not_already_defined(name, definition_ast_node)

        self.partially_typechecked_function_definitions_by_name[name] = definition_ast_node

    def add_symbol_for_external_elem(self,
                                     elem: Union[ir2.FunctionDefn, ir2.CustomType],
                                     import_from_ast_node: ast.ImportFrom,
                                     source_module: str):
        """
        Adds the symbol for a function or custom type imported from `source_module`.

        Imported functions are registered as functions that may throw.
        """
        if isinstance(elem, ir2.FunctionDefn):
            self.add_symbol(name=elem.name,
                            expr_type=ir2.FunctionType(argtypes=tuple(arg.expr_type for arg in elem.args),
                                                       argnames=tuple(arg.name for arg in elem.args),
                                                       returns=elem.return_type),
                            definition_ast_node=import_from_ast_node,
                            is_only_partially_defined=False,
                            is_function_that_may_throw=True,
                            source_module=source_module)
        elif isinstance(elem, ir2.CustomType):
            self.add_custom_type_symbol(custom_type=elem,
                                        definition_ast_node=import_from_ast_node,
                                        source_module=source_module)
        else:
            raise NotImplementedError('Unexpected elem type: %s' % elem.__class__.__name__)

    def get_symbol_definition(self, name: str):
        """
        Returns the SymbolLookupResult for `name` (searching the enclosing scopes too), or None if not found.
        """
        return self.symbol_table.get_symbol_definition(name)

    def get_partial_function_definition(self, name: str):
        """
        Returns the AST node of the function `name` if its return type is still unknown, None otherwise.
        """
        return self.partially_typechecked_function_definitions_by_name.get(name)

    def get_type_symbol_definition(self, name: str):
        """
        Returns the SymbolLookupResult for the custom type `name`, or None if there's no such type.
        """
        return self.custom_types_symbol_table.get_symbol_definition(name)

    def set_function_type(self, name: str, expr_type: ir2.FunctionType):
        """
        Sets the type of a function once it's fully known.

        If the function was recorded with an unknown return type, it's moved to the symbol table of this scope (as a
        function that may throw). Otherwise it must already be defined, with exactly this type.
        """
        if name in self.partially_typechecked_function_definitions_by_name:
            ast_node = self.partially_typechecked_function_definitions_by_name.pop(name)
            self.symbol_table.add_symbol(name=name,
                                         expr_type=expr_type,
                                         definition_ast_node=ast_node,
                                         is_only_partially_defined=False,
                                         is_function_that_may_throw=True,
                                         source_module=None)
        else:
            assert self.get_symbol_definition(name).symbol.expr_type == expr_type

    def _check_not_already_defined(self, name: str, definition_ast_node: ast.AST):
        """
        Raises a CompilationError (pointing at definition_ast_node, with a note on the previous definition) if `name`
        is already defined as a symbol, as a custom type or as a function with unknown return type.
        """
        symbol_lookup_result = self.symbol_table.get_symbol_definition(name)
        if not symbol_lookup_result:
            symbol_lookup_result = self.custom_types_symbol_table.get_symbol_definition(name)
        if symbol_lookup_result:
            is_only_partially_defined = symbol_lookup_result.is_only_partially_defined
            previous_definition_ast_node = symbol_lookup_result.ast_node
        elif name in self.partially_typechecked_function_definitions_by_name:
            is_only_partially_defined = False
            previous_definition_ast_node = self.partially_typechecked_function_definitions_by_name[name]
        else:
            is_only_partially_defined = None
            previous_definition_ast_node = None
        # Note that a previous definition registered with a None AST node is not reported here.
        if previous_definition_ast_node:
            if is_only_partially_defined:
                raise CompilationError(self, definition_ast_node,
                                       '%s could be already initialized at this point.' % name,
                                       notes=[(previous_definition_ast_node, 'It might have been initialized here (depending on which branch is taken).')])
            else:
                raise CompilationError(self, definition_ast_node,
                                       '%s was already defined in this scope.' % name,
                                       notes=[(previous_definition_ast_node, 'The previous declaration was here.')])


class CompilationError(Exception):
    """
    Raised when the code being compiled is invalid.

    The exception message is the rendered diagnostic for the error itself, followed by the rendered diagnostics of the
    notes (if any). Each diagnostic has the form:

        <filename>:<line>:<column>: <message>
        <source line>
        <caret under the reported column>
    """
    def __init__(self, compilation_context: CompilationContext, ast_node: ast.AST, error_message: str, notes: List[Tuple[ast.AST, str]] = ()):
        diagnostics = [CompilationError._diagnostic_to_string(compilation_context=compilation_context,
                                                              ast_node=ast_node,
                                                              message='error: ' + error_message)]
        for note_ast_node, note_message in notes:
            diagnostics.append(CompilationError._diagnostic_to_string(compilation_context=compilation_context,
                                                                      ast_node=note_ast_node,
                                                                      message='note: ' + note_message))
        super().__init__(''.join(diagnostics))

    @staticmethod
    def _diagnostic_to_string(compilation_context: CompilationContext, ast_node: ast.AST, message: str):
        """
        Renders a single diagnostic that points at the position of ast_node in the source file.
        """
        line_number = ast_node.lineno
        column_number = ast_node.col_offset
        # A caret placed under the column where the node starts.
        caret_line = ' ' * column_number + '^'
        template = textwrap.dedent('''\
            {filename}:{first_line_number}:{first_column_number}: {message}
            {line}
            {error_marker}
            ''')
        # AST line numbers are 1-based, source_lines is 0-based.
        return template.format(filename=compilation_context.filename,
                               first_line_number=line_number,
                               first_column_number=column_number,
                               message=message,
                               line=compilation_context.source_lines[line_number - 1],
                               error_marker=caret_line)


def _first_stmt_line(stmt: ast.AST):
    if isinstance(stmt, ast.ClassDef) and stmt.decorator_list:
        return stmt.decorator_list[0].lineno
    else:
        return stmt.lineno

def module_ast_to_ir2(module_ast_node: ast.Module,
                      filename: str,
                      source_lines: List[str],
                      identifier_generator: Iterator[str],
                      context_object_files: ObjectFileContent):
    """
    Converts the AST of a whole module into an ir2.Module.

    The toplevel statements are processed in two passes: the first one registers function signatures, imports and
    custom types (and collects the toplevel PassStmts), the second one converts function bodies and toplevel
    assertions, which need the symbols registered by the first pass.

    The public names of the resulting module are the names of the functions that don't start with an underscore.

    Raises a CompilationError if the module contains an unsupported construct.
    """
    # The public symbols of the modules in the object files, i.e. the ones that can be imported from here.
    external_ir2_symbols_by_name_by_module = dict()
    for module_name, module_info in context_object_files.modules_by_name.items():
        if not module_info.ir2_module:
            continue
        external_ir2_symbols_by_name_by_module[module_name] = {
            elem.name: elem
            for elem in itertools.chain(module_info.ir2_module.function_defns, module_info.ir2_module.custom_types)
            if elem.name in module_info.ir2_module.public_names}

    compilation_context = CompilationContext(SymbolTable(),
                                             SymbolTable(),
                                             external_ir2_symbols_by_name_by_module,
                                             filename,
                                             source_lines,
                                             identifier_generator)

    toplevel_stmts = module_ast_node.body
    first_line = _first_stmt_line(toplevel_stmts[0]) if toplevel_stmts else 1

    # For each toplevel statement, the line where the following statement starts. Nothing follows the last statement,
    # -first_line is used for that one.
    next_stmt_lines = [_first_stmt_line(following_stmt) for following_stmt in toplevel_stmts[1:]] + [-first_line]

    def toplevel_pass_stmt(stmt, line_of_next_stmt):
        # A PassStmt for the branch going from stmt to the statement after it.
        return ir2.PassStmt(SourceBranch(compilation_context.filename,
                                         stmt.lineno,
                                         line_of_next_stmt))

    function_defns = []
    toplevel_assertions = []
    custom_types = []
    pass_stmts = []

    # First pass: process everything except function bodies and toplevel assertions
    for ast_node, next_stmt_line in zip(toplevel_stmts, next_stmt_lines):
        if isinstance(ast_node, ast.FunctionDef):
            function_name, arg_types, arg_names, return_type = function_def_ast_to_symbol_info(ast_node, compilation_context)

            if not return_type:
                # The return type will be inferred in the 2nd pass.
                compilation_context.add_symbol_for_function_with_unknown_return_type(
                    name=function_name,
                    definition_ast_node=ast_node)
            else:
                function_type = ir2.FunctionType(argtypes=arg_types,
                                                 argnames=arg_names,
                                                 returns=return_type)
                compilation_context.add_symbol(
                    name=function_name,
                    expr_type=function_type,
                    definition_ast_node=ast_node,
                    is_only_partially_defined=False,
                    is_function_that_may_throw=True)
        elif isinstance(ast_node, ast.ImportFrom):
            _import_from_ast_to_ir2(ast_node, compilation_context)
            pass_stmts.append(toplevel_pass_stmt(ast_node, next_stmt_line))
        elif isinstance(ast_node, ast.Import):
            raise CompilationError(compilation_context, ast_node,
                                   'TMPPy only supports imports of the form "from some_module import some_symbol, some_other_symbol".')
        elif isinstance(ast_node, ast.ClassDef):
            custom_type, additional_pass_stmts = class_definition_ast_to_ir2(ast_node, compilation_context, next_stmt_line)
            pass_stmts.extend(additional_pass_stmts)
            compilation_context.add_custom_type_symbol(custom_type=custom_type,
                                                       definition_ast_node=ast_node)
            custom_types.append(custom_type)
        elif isinstance(ast_node, ast.Assert):
            # We'll process this in the 2nd pass (since we need to infer function return types first).
            pass
        elif isinstance(ast_node, ast.Pass):
            pass_stmts.append(toplevel_pass_stmt(ast_node, next_stmt_line))
        else:
            raise CompilationError(compilation_context, ast_node, 'This Python construct is not supported in TMPPy')

    # 2nd pass: process function bodies and toplevel assertions
    for ast_node, next_stmt_line in zip(toplevel_stmts, next_stmt_lines):
        if isinstance(ast_node, ast.Assert):
            toplevel_assertions.append(assert_ast_to_ir2(ast_node, compilation_context, next_stmt_line))
        elif isinstance(ast_node, ast.FunctionDef):
            new_function_defn = function_def_ast_to_ir2(ast_node, compilation_context, next_stmt_line)
            function_defns.append(new_function_defn)
            pass_stmts.append(toplevel_pass_stmt(ast_node, next_stmt_line))

            # Now the full type of the function is known (including the return type, if it had to be inferred).
            inferred_function_type = ir2.FunctionType(returns=new_function_defn.return_type,
                                                      argtypes=tuple(arg.expr_type
                                                                     for arg in new_function_defn.args),
                                                      argnames=tuple(arg.name
                                                                     for arg in new_function_defn.args))
            compilation_context.set_function_type(name=ast_node.name,
                                                  expr_type=inferred_function_type)

    # NOTE(review): presumably this is the branch that enters the module (with negated line numbers marking the
    # entry/exit points) - confirm against SourceBranch.
    pass_stmts.append(ir2.PassStmt(source_branch=SourceBranch(file_name=compilation_context.filename,
                                                              source_line=-first_line,
                                                              dest_line=first_line)))
    if not toplevel_stmts:
        # There is an implicit Pass statement in empty modules.
        pass_stmts.append(ir2.PassStmt(source_branch=SourceBranch(file_name=filename,
                                                                  source_line=1,
                                                                  dest_line=-1)))

    # Only the functions whose name doesn't start with an underscore are exported.
    public_names = frozenset(function_defn.name
                             for function_defn in function_defns
                             if not function_defn.name.startswith('_'))

    return ir2.Module(function_defns=tuple(function_defns),
                      assertions=tuple(toplevel_assertions),
                      custom_types=tuple(custom_types),
                      public_names=public_names,
                      pass_stmts=tuple(pass_stmts))


def _import_from_ast_to_ir2(ast_node: ast.ImportFrom, compilation_context: CompilationContext):
    if len(ast_node.names) == 0:
        raise CompilationError(compilation_context, ast_node,
                               'Imports must import at least 1 symbol.')  # pragma: no cover
    for imported_name in ast_node.names:
        if not isinstance(imported_name, ast.alias) or imported_name.asname:
            raise CompilationError(compilation_context, ast_node,
                                   'TMPPy only supports imports of the form "from some_module import some_symbol, some_other_symbol".')

    builtin_imports_by_module = {
        'tmppy': ('Type', 'empty_list', 'empty_set', 'match'),
        'typing': ('List', 'Set', 'Callable'),
        'dataclasses': ('dataclass',),
    }

    importable_names = builtin_imports_by_module.get(ast_node.module)
    action = lambda symbol_name: None

    if not importable_names:
        # TODO: require all directly imported modules to be specified directly instead of allowing transitively-present
        # modules.
        symbols_by_name = compilation_context.external_ir2_symbols_by_name_by_module.get(ast_node.module)
        if symbols_by_name is not None:
            importable_names = symbols_by_name.keys()
            action = lambda imported_name: compilation_context.add_symbol_for_external_elem(elem=compilation_context.external_ir2_symbols_by_name_by_module[ast_node.module][imported_name.name],
                                                                                            import_from_ast_node=ast_node,
                                                                                            source_module=ast_node.module)

    if importable_names is None:
        raise CompilationError(compilation_context, ast_node,
                               'Module not found. The only modules that can be imported are the builtin modules ('
                               + ', '.join(sorted(builtin_imports_by_module.keys()))
                               + ') and the modules in the specified object files ('
                               + (', '.join(sorted(name for name in compilation_context.external_ir2_symbols_by_name_by_module.keys() if name != '_py2tmp.compiler._tmppy_builtins')) or 'none')
                               + ')')


    for imported_name in ast_node.names:
        if imported_name.name not in importable_names:
            raise CompilationError(compilation_context, ast_node, 'The only supported imports from %s are: %s.' % (
                ast_node.module, ', '.join(sorted(importable_names))))
        action(imported_name)


def match_expression_ast_to_ir2(ast_node: ast.Call,
                                compilation_context: CompilationContext,
                                in_match_pattern: bool,
                                check_var_reference: Callable[[ast.Name], None],
                                match_lambda_argument_names: Set[str],
                                current_stmt_line: int):
    """
    Converts a match expression, i.e. a call of the form `match(expr1, ...)(lambda x, ...: {pattern: result, ...})`,
    into an ir2.MatchExpr.

    ast_node is the outer call: ast_node.func is the inner `match(...)` call (whose arguments are the matched
    expressions) and the only argument of ast_node must be a lambda whose body is a dict that maps type patterns (a
    tuple of patterns when matching multiple expressions) to result expressions.

    in_match_pattern, check_var_reference, match_lambda_argument_names and current_stmt_line are forwarded to
    expression_ast_to_ir2 when converting the matched expressions and the result expressions. The patterns are
    always converted with in_match_pattern=True and with the argument names of this match's lambda.

    Must be called while compiling a function (compilation_context.current_function_name must be set).

    Raises a CompilationError if the match expression is malformed or doesn't typecheck: keyword arguments, no matched
    expressions, matched expressions or patterns whose type is not Type, a wrong number of patterns, an empty dict,
    branches returning different types, lambda arguments used in a result but not in the patterns of that branch,
    lambda arguments not used in any pattern, more than one branch that specializes nothing.
    """
    assert isinstance(ast_node.func, ast.Call)
    if ast_node.keywords:
        raise CompilationError(compilation_context, ast_node.keywords[0].value, 'Keyword arguments are not allowed in match()')
    if ast_node.func.keywords:
        raise CompilationError(compilation_context, ast_node.func.keywords[0].value, 'Keyword arguments are not allowed in match()')
    if not ast_node.func.args:
        raise CompilationError(compilation_context, ast_node.func, 'Found match() with no arguments; it must have at least 1 argument.')
    # The expressions being matched, i.e. the arguments of the inner match(...) call. They must all have type Type.
    matched_exprs = []
    for expr_ast in ast_node.func.args:
        expr = expression_ast_to_ir2(expr_ast, compilation_context, in_match_pattern, check_var_reference, match_lambda_argument_names, current_stmt_line)
        if expr.expr_type != ir2.TypeType():
            raise CompilationError(compilation_context, expr_ast,
                                   'All arguments passed to match must have type Type, but an argument with type %s was specified.' % str(expr.expr_type))
        matched_exprs.append(expr)

    if len(ast_node.args) != 1 or not isinstance(ast_node.args[0], ast.Lambda):
        raise CompilationError(compilation_context, ast_node, 'Malformed match()')
    [lambda_expr_ast] = ast_node.args
    lambda_args = lambda_expr_ast.args
    if lambda_args.vararg:
        raise CompilationError(compilation_context, lambda_args.vararg,
                               'Malformed match(): vararg lambda arguments are not supported')
    # NOTE(review): these are asserts rather than CompilationErrors, so e.g. a lambda argument with a default value
    # fails with an AssertionError instead of a diagnostic. Also, lambda_args.kwarg is not checked here.
    assert not lambda_args.kwonlyargs
    assert not lambda_args.kw_defaults
    assert not lambda_args.defaults

    lambda_arg_ast_node_by_name = {arg.arg: arg
                                   for arg in lambda_args.args}
    lambda_arg_index_by_name = {arg.arg: i
                                for i, arg in enumerate(lambda_args.args)}
    lambda_arg_names = {arg.arg for arg in lambda_args.args}
    # Names are removed from this set as they are found in some pattern; what's left at the end is reported as unused.
    unused_lambda_arg_names = {arg.arg for arg in lambda_args.args}

    if not isinstance(lambda_expr_ast.body, ast.Dict):
        raise CompilationError(compilation_context, ast_node, 'Malformed match()')
    dict_expr_ast = lambda_expr_ast.body

    if not dict_expr_ast.keys:
        raise CompilationError(compilation_context, dict_expr_ast,
                               'An empty mapping dict was passed to match(), but at least 1 mapping is required.')

    parent_function_name = compilation_context.current_function_name
    assert parent_function_name

    main_definition = None
    main_definition_key_expr_ast = None
    last_result_expr_type = None
    last_result_expr_ast_node = None
    match_cases = []
    # Each dict entry is a match case: the key holds the type pattern(s), the value is the result expression.
    for key_expr_ast, value_expr_ast in zip(dict_expr_ast.keys, dict_expr_ast.values):
        # Each case is compiled in its own child context (with its own symbol table).
        match_case_compilation_context = compilation_context.create_child_context()

        # A tuple key provides one pattern per matched expression; any other key is a single pattern.
        if isinstance(key_expr_ast, ast.Tuple):
            pattern_ast_nodes = key_expr_ast.elts
        else:
            pattern_ast_nodes = [key_expr_ast]

        if len(pattern_ast_nodes) != len(matched_exprs):
            raise CompilationError(match_case_compilation_context, key_expr_ast,
                                   '%s type patterns were provided, while %s were expected' % (len(pattern_ast_nodes), len(matched_exprs)),
                                   [(ast_node.func, 'The corresponding match() was here')])

        pattern_exprs = []
        for pattern_ast_node in pattern_ast_nodes:
            pattern_expr = expression_ast_to_ir2(pattern_ast_node,
                                                 match_case_compilation_context,
                                                 in_match_pattern=True,
                                                 check_var_reference=check_var_reference,
                                                 match_lambda_argument_names=lambda_arg_names,
                                                 current_stmt_line=current_stmt_line)
            pattern_exprs.append(pattern_expr)
            if pattern_expr.expr_type != ir2.TypeType():
                raise CompilationError(match_case_compilation_context, pattern_ast_node,
                                       'Type patterns must have type Type but this pattern has type %s' % str(pattern_expr.expr_type),
                                       [(ast_node.func, 'The corresponding match() was here')])

        # The lambda arguments that occur as free variables in the patterns of this case, with their types.
        lambda_args_used_in_pattern = {var.name: var.expr_type
                                       for pattern_expr in pattern_exprs
                                       for var in get_free_variables(pattern_expr).values()
                                       if var.name in lambda_arg_names}
        for var in lambda_args_used_in_pattern.keys():
            unused_lambda_arg_names.discard(var)

        # Wraps check_var_reference to also reject, in the result expression, any lambda argument that doesn't appear
        # in the patterns of this case.
        def check_var_reference_in_result_expr(ast_node: ast.Name):
            check_var_reference(ast_node)
            if ast_node.id in lambda_arg_names and not ast_node.id in lambda_args_used_in_pattern:
                raise CompilationError(match_case_compilation_context, ast_node,
                                       '%s was used in the result of this match branch but not in any of its patterns' % ast_node.id)

        # Unlike the patterns, the result is converted with the caller's in_match_pattern and
        # match_lambda_argument_names.
        result_expr = expression_ast_to_ir2(value_expr_ast,
                                            match_case_compilation_context,
                                            in_match_pattern=in_match_pattern,
                                            check_var_reference=check_var_reference_in_result_expr,
                                            match_lambda_argument_names=match_lambda_argument_names,
                                            current_stmt_line=current_stmt_line)

        # Each result is compared with the result of the previous case, so all cases end up having the same type.
        if last_result_expr_type and result_expr.expr_type != last_result_expr_type:
            raise CompilationError(match_case_compilation_context, value_expr_ast,
                                   'All branches in a match() must return the same type, but this branch returns a %s '
                                   'while a previous branch in this match expression returns a %s' % (
                                       str(result_expr.expr_type), str(last_result_expr_type)),
                                   notes=[(last_result_expr_ast_node,
                                           'A previous branch returning a %s was here.' % str(last_result_expr_type))])
        last_result_expr_type = result_expr.expr_type
        last_result_expr_ast_node = value_expr_ast

        # Split the lambda arguments used in the patterns by type: Type vs List[Type] (the variadic ones).
        matched_var_names = set()
        matched_variadic_var_names = set()
        for arg, arg_type in lambda_args_used_in_pattern.items():
            if arg_type == ir2.TypeType():
                matched_var_names.add(arg)
            elif arg_type == ir2.ListType(ir2.TypeType()):
                matched_variadic_var_names.add(arg)
            else:
                raise NotImplementedError('Unexpected arg type: %s' % str(arg_type))

        # The start/end branches go from the line of the current statement to the negated line of the lambda and back.
        match_case = ir2.MatchCase(matched_var_names=frozenset(matched_var_names),
                                   matched_variadic_var_names=frozenset(matched_variadic_var_names),
                                   type_patterns=tuple(pattern_exprs),
                                   expr=result_expr,
                                   match_case_start_branch=SourceBranch(compilation_context.filename,
                                                                        current_stmt_line,
                                                                        -lambda_expr_ast.lineno),
                                   match_case_end_branch=SourceBranch(compilation_context.filename,
                                                                      -lambda_expr_ast.lineno,
                                                                      current_stmt_line))
        match_cases.append(match_case)

        # At most one case can be a "main definition" (per the error message below: a specialization that specializes
        # nothing).
        if match_case.is_main_definition():
            if main_definition:
                assert main_definition_key_expr_ast
                raise CompilationError(match_case_compilation_context, key_expr_ast,
                                       'Found multiple specializations that specialize nothing',
                                       notes=[(main_definition_key_expr_ast, 'A previous specialization that specializes nothing was here')])
            main_definition = match_case
            main_definition_key_expr_ast = key_expr_ast

    if unused_lambda_arg_names:
        # If multiple lambda arguments are unused, the one that comes last in the lambda's argument list is reported.
        unused_arg_name = max(unused_lambda_arg_names, key=lambda arg_name: lambda_arg_index_by_name[arg_name])
        unused_arg_ast_node = lambda_arg_ast_node_by_name[unused_arg_name]
        raise CompilationError(compilation_context, unused_arg_ast_node,
                               'The lambda argument %s was not used in any pattern, it should be removed.' % unused_arg_name)

    return ir2.MatchExpr(matched_exprs=tuple(matched_exprs),
                         match_cases=tuple(match_cases))

def return_stmt_ast_to_ir2(ast_node: ast.Return,
                           compilation_context: CompilationContext):
    """Convert a `return <expr>` statement into an ir2.ReturnStmt.

    Raises a CompilationError for a bare `return`, since there would be no expression to convert.
    """
    if not ast_node.value:
        raise CompilationError(compilation_context, ast_node,
                               'Return statements with no returned expression are not supported.')

    returned_expr = expression_ast_to_ir2(ast_node.value,
                                          compilation_context,
                                          in_match_pattern=False,
                                          check_var_reference=lambda ast_node: None,
                                          match_lambda_argument_names=set(),
                                          current_stmt_line=ast_node.lineno)

    # The branch is recorded from the line of the return statement to the negated line of the
    # enclosing function definition.
    branch = SourceBranch(compilation_context.filename,
                          ast_node.lineno,
                          -compilation_context.current_function_definition_line)
    return ir2.ReturnStmt(expr=returned_expr, source_branch=branch)

def if_stmt_ast_to_ir2(ast_node: ast.If,
                       compilation_context: CompilationContext,
                       previous_return_stmt: Optional[Tuple[ir2.ExprType, ast.Return]],
                       check_always_returns: bool,
                       next_stmt_line: int):
    """Convert an if/else statement into an ir2.IfStmt.

    Each branch is compiled in its own child context; the symbols defined in the two branches are
    then merged back into `compilation_context`.

    Returns a pair (if_stmt, return_stmt_info). return_stmt_info is `previous_return_stmt` when
    that was provided, otherwise the info of the first return statement found in the branches
    (if any).

    Raises a CompilationError if the condition is not a bool, or if `check_always_returns` is set
    and there is no else branch.
    """
    condition = expression_ast_to_ir2(ast_node.test,
                                      compilation_context,
                                      in_match_pattern=False,
                                      check_var_reference=lambda ast_node: None,
                                      match_lambda_argument_names=set(),
                                      current_stmt_line=ast_node.lineno)
    if condition.expr_type != ir2.BoolType():
        raise CompilationError(compilation_context, ast_node,
                               'The condition in an if statement must have type bool, but was: %s' % str(condition.expr_type))

    # "then" branch
    then_context = compilation_context.create_child_context()
    then_stmts, then_first_return_stmt = statements_ast_to_ir2(ast_node.body,
                                                               then_context,
                                                               previous_return_stmt=previous_return_stmt,
                                                               check_block_always_returns=check_always_returns,
                                                               stmts_are_toplevel_in_function=False,
                                                               next_stmt_line=next_stmt_line)
    if then_first_return_stmt and not previous_return_stmt:
        previous_return_stmt = then_first_return_stmt

    # "else" branch (possibly empty)
    else_context = compilation_context.create_child_context()
    if not ast_node.orelse:
        else_stmts = []
        if check_always_returns:
            raise CompilationError(compilation_context, ast_node,
                                   'Missing return statement. You should add an else branch that returns, or a return after the if.')
    else:
        else_stmts, else_first_return_stmt = statements_ast_to_ir2(ast_node.orelse,
                                                                   else_context,
                                                                   previous_return_stmt=previous_return_stmt,
                                                                   check_block_always_returns=check_always_returns,
                                                                   stmts_are_toplevel_in_function=False,
                                                                   next_stmt_line=next_stmt_line)
        if else_first_return_stmt and not previous_return_stmt:
            previous_return_stmt = else_first_return_stmt

    _join_definitions_in_branches(compilation_context,
                                  then_context,
                                  then_stmts,
                                  else_context,
                                  else_stmts)

    # Line of the first statement in each branch; for an empty else branch this is next_stmt_line.
    then_entry_line = compute_next_stmt_line_number_by_index([None, *ast_node.body], next_stmt_line)[0]
    else_entry_line = compute_next_stmt_line_number_by_index([None, *ast_node.orelse], next_stmt_line)[0]

    # Each branch starts with a PassStmt carrying the branch from the `if` line into that branch.
    then_entry_stmt = ir2.PassStmt(SourceBranch(compilation_context.filename,
                                                ast_node.lineno,
                                                then_entry_line))
    else_entry_stmt = ir2.PassStmt(SourceBranch(compilation_context.filename,
                                                ast_node.lineno,
                                                else_entry_line))
    if_stmt = ir2.IfStmt(cond_expr=condition,
                         if_stmts=(then_entry_stmt, *then_stmts),
                         else_stmts=(else_entry_stmt, *else_stmts))

    return if_stmt, previous_return_stmt

def _join_definitions_in_branches(parent_context: CompilationContext,
                                  branch1_context: CompilationContext,
                                  branch1_stmts: Tuple[ir2.Stmt, ...],
                                  branch2_context: CompilationContext,
                                  branch2_stmts: Tuple[ir2.Stmt, ...]):
    """Merge the symbols defined in two alternative branches into `parent_context`.

    A branch that always returns contributes no symbols, since execution never continues past it.
    For the remaining symbols:

    * If a symbol is defined in both branches the two types must match (otherwise a
      CompilationError is raised); the merged symbol is only partially defined if it was only
      partially defined in either branch.
    * If a symbol is defined in a single branch, it becomes only partially defined, unless the
      other branch always returns (in which case the flag from the defining branch is kept).

    Symbols are processed in definition order (branch 1 first, then the ones only in branch 2), so
    that the reported error and the order of the add_symbol() calls are the same on every run.
    """
    branch1_always_returns = get_return_type(branch1_stmts).always_returns
    branch2_always_returns = get_return_type(branch2_stmts).always_returns

    branch1_symbols = {} if branch1_always_returns else branch1_context.symbol_table.symbols_by_name
    branch2_symbols = {} if branch2_always_returns else branch2_context.symbol_table.symbols_by_name

    # dict.fromkeys() removes duplicates while keeping insertion order; iterating a set here would
    # make the order depend on string hashing, which is randomized across interpreter runs.
    symbol_names = dict.fromkeys([*branch1_symbols, *branch2_symbols])

    for symbol_name in symbol_names:
        branch1_entry = branch1_symbols.get(symbol_name)
        branch2_entry = branch2_symbols.get(symbol_name)

        if branch1_entry and branch2_entry:
            branch1_symbol, branch1_definition_ast_node, branch1_is_partial = branch1_entry
            branch2_symbol, branch2_definition_ast_node, branch2_is_partial = branch2_entry
            if branch1_symbol.expr_type != branch2_symbol.expr_type:
                raise CompilationError(parent_context, branch2_definition_ast_node,
                                       'The variable %s is defined with type %s here, but it was previously defined with type %s in another branch.' % (
                                           symbol_name, str(branch2_symbol.expr_type), str(branch1_symbol.expr_type)),
                                       notes=[(branch1_definition_ast_node, 'A previous definition with type %s was here.' % str(branch1_symbol.expr_type))])
            symbol = branch1_symbol
            definition_ast_node = branch1_definition_ast_node
            is_only_partially_defined = branch1_is_partial or branch2_is_partial
        elif branch1_entry:
            symbol, definition_ast_node, branch1_is_partial = branch1_entry
            # Only defined in branch 1: if branch 2 can fall through, the symbol may be undefined.
            is_only_partially_defined = branch1_is_partial if branch2_always_returns else True
        else:
            assert branch2_entry
            symbol, definition_ast_node, branch2_is_partial = branch2_entry
            # Only defined in branch 2: if branch 1 can fall through, the symbol may be undefined.
            is_only_partially_defined = branch2_is_partial if branch1_always_returns else True

        parent_context.add_symbol(name=symbol.name,
                                  expr_type=symbol.expr_type,
                                  definition_ast_node=definition_ast_node,
                                  is_only_partially_defined=is_only_partially_defined,
                                  is_function_that_may_throw=isinstance(symbol.expr_type, ir2.FunctionType))

def raise_stmt_ast_to_ir2(ast_node: ast.Raise, compilation_context: CompilationContext):
    """Convert a `raise <expr>` statement into an ir2.RaiseStmt.

    Raises a CompilationError for `raise ... from ...`, for a bare `raise` (which has no exception
    expression to convert) and when the raised expression is not an instance of a custom exception
    class.
    """
    if ast_node.cause:
        raise CompilationError(compilation_context, ast_node.cause,
                               '"raise ... from ..." is not supported. Use a plain "raise ..." instead.')
    # For a bare "raise" the ast module sets exc to None; report that explicitly instead of trying
    # to convert None as if it were an expression.
    if not ast_node.exc:
        raise CompilationError(compilation_context, ast_node,
                               'Bare "raise" statements are not supported. Use "raise SomeException(...)" instead.')
    exception_expr = expression_ast_to_ir2(ast_node.exc,
                                           compilation_context,
                                           in_match_pattern=False,
                                           check_var_reference=lambda ast_node: None,
                                           match_lambda_argument_names=set(),
                                           current_stmt_line=ast_node.lineno)
    if not (isinstance(exception_expr.expr_type, ir2.CustomType) and exception_expr.expr_type.is_exception_class):
        if isinstance(exception_expr.expr_type, ir2.CustomType):
            # Point the user at the definition of the (non-exception) custom type.
            custom_type_defn = compilation_context.get_type_symbol_definition(exception_expr.expr_type.name).ast_node
            notes = [(custom_type_defn, 'The type %s was defined here.' % exception_expr.expr_type.name)]
        else:
            notes = []
        raise CompilationError(compilation_context, ast_node.exc,
                               'Can\'t raise an exception of type "%s", because it\'s not a subclass of Exception.' % str(exception_expr.expr_type),
                               notes=notes)
    # The branch target is the enclosing "except" line if there is one, otherwise the negated line
    # of the enclosing function definition.
    return ir2.RaiseStmt(expr=exception_expr, source_branch=SourceBranch(compilation_context.filename,
                                                                         ast_node.lineno,
                                                                         compilation_context.first_enclosing_except_stmt_line
                                                                         if compilation_context.first_enclosing_except_stmt_line
                                                                         else -compilation_context.current_function_definition_line))

def try_stmt_ast_to_ir2(ast_node: ast.Try,
                        compilation_context: CompilationContext,
                        previous_return_stmt: Optional[Tuple[ir2.ExprType, ast.Return]],
                        check_always_returns: bool,
                        is_toplevel_in_function: bool,
                        next_stmt_line: int):
    """Convert a try/except statement into an ir2.TryExcept.

    Only the form `try: ... except SomeType as some_var: ...` is accepted, and only when the
    statement is top-level in a function; any other form raises a CompilationError.

    The try body and the except body are compiled in separate child contexts; the symbols they
    define are then merged back into `compilation_context`.

    Returns a pair (try_except_stmt, return_stmt_info), where return_stmt_info is
    `previous_return_stmt` when that was provided, otherwise the info of the first return
    statement found in the try body or, failing that, in the except body.
    """
    if not is_toplevel_in_function:
        raise CompilationError(compilation_context, ast_node,
                               'try-except blocks are only supported at top-level in functions (not e.g. inside if-else statements).')

    handlers = ast_node.handlers
    if not handlers:
        raise CompilationError(compilation_context, ast_node,
                               '"try" blocks must have an "except" clause.')
    # TODO: consider supporting this case too.
    if len(handlers) > 1:
        raise CompilationError(compilation_context, handlers[1],
                               '"try" blocks with multiple "except" clauses are not currently supported.')
    except_handler = handlers[0]

    # The try body is compiled first, in a context that records the line of the except clause.
    try_context = compilation_context.create_child_context(first_enclosing_except_stmt_line=except_handler.lineno)
    try_stmts, try_first_return_stmt = statements_ast_to_ir2(ast_node.body,
                                                             try_context,
                                                             previous_return_stmt=previous_return_stmt,
                                                             check_block_always_returns=check_always_returns,
                                                             stmts_are_toplevel_in_function=False,
                                                             next_stmt_line=next_stmt_line)
    if not previous_return_stmt:
        previous_return_stmt = try_first_return_stmt

    if (not isinstance(except_handler, ast.ExceptHandler)
            or not isinstance(except_handler.type, ast.Name)
            or not isinstance(except_handler.type.ctx, ast.Load)
            or not except_handler.name):
        raise CompilationError(compilation_context, except_handler,
                               '"except" clauses must be of the form: except SomeType as some_var')

    # TODO: consider adding support for this.
    if except_handler.type.id == 'Exception':
        raise CompilationError(compilation_context, except_handler.type,
                               'Catching all exceptions is not supported, you must catch a specific exception type.')

    caught_exception_type = type_declaration_ast_to_ir2_expression_type(except_handler.type, compilation_context)

    if ast_node.orelse:
        raise CompilationError(compilation_context, ast_node.orelse[0], '"else" clauses are not supported in try-except.')

    if ast_node.finalbody:
        raise CompilationError(compilation_context, ast_node.finalbody[0], '"finally" clauses are not supported.')

    # The except body sees the caught exception as a fully-defined local variable.
    except_context = compilation_context.create_child_context()
    except_context.add_symbol(name=except_handler.name,
                              expr_type=caught_exception_type,
                              definition_ast_node=except_handler,
                              is_only_partially_defined=False,
                              is_function_that_may_throw=False)
    except_stmts, except_first_return_stmt = statements_ast_to_ir2(except_handler.body,
                                                                   except_context,
                                                                   previous_return_stmt=previous_return_stmt,
                                                                   check_block_always_returns=check_always_returns,
                                                                   stmts_are_toplevel_in_function=False,
                                                                   next_stmt_line=next_stmt_line)
    if not previous_return_stmt:
        previous_return_stmt = except_first_return_stmt

    _join_definitions_in_branches(compilation_context,
                                  try_context,
                                  try_stmts,
                                  except_context,
                                  except_stmts)

    # Line of the first statement in the try body and in the except body, respectively.
    try_entry_line = compute_next_stmt_line_number_by_index([None, *ast_node.body], next_stmt_line)[0]
    except_entry_line = compute_next_stmt_line_number_by_index([None, *except_handler.body], next_stmt_line)[0]

    try_branch = SourceBranch(compilation_context.filename,
                              ast_node.lineno,
                              try_entry_line)
    except_branch = SourceBranch(compilation_context.filename,
                                 except_handler.lineno,
                                 except_entry_line)
    try_except_stmt = ir2.TryExcept(try_body=try_stmts,
                                    caught_exception_type=caught_exception_type,
                                    caught_exception_name=except_handler.name,
                                    except_body=except_stmts,
                                    try_branch=try_branch,
                                    except_branch=except_branch)

    return try_except_stmt, previous_return_stmt

def statements_ast_to_ir2(ast_nodes: List[ast.AST],
                          compilation_context: CompilationContext,
                          previous_return_stmt: Optional[Tuple[ir2.ExprType, ast.Return]],
                          check_block_always_returns: bool,
                          stmts_are_toplevel_in_function: bool,
                          next_stmt_line: int):
    """Convert a non-empty block of statements into a tuple of ir2 statements.

    `previous_return_stmt`, when set, is a pair (type, ast node) describing an earlier return
    statement; every return statement in this block must return that same type.

    Returns a pair (statements, first_return_stmt), where first_return_stmt is the
    (type, ast node) info of the first return statement found while converting this block
    (None if there was none).

    Raises a CompilationError for unreachable or unsupported statements, for return type
    mismatches and, if `check_block_always_returns` is set, when the block might not return.
    """
    assert ast_nodes

    following_line_by_index = compute_next_stmt_line_number_by_index(ast_nodes, next_stmt_line)
    last_node = ast_nodes[-1]

    ir_stmts = []
    first_return_stmt = None
    for position, node in enumerate(ast_nodes):
        following_line = following_line_by_index[position]
        # Anything after a statement that always returns can never run.
        if get_return_type(ir_stmts).always_returns:
            raise CompilationError(compilation_context, node, 'Unreachable statement.')

        # Only the last statement of the block is responsible for making the block return.
        stmt_must_return = check_block_always_returns and node is last_node

        if isinstance(node, ast.Assert):
            ir_stmts.append(assert_ast_to_ir2(node, compilation_context, following_line))
        elif isinstance(node, (ast.Assign, ast.AnnAssign, ast.AugAssign)):
            ir_stmts.append(assignment_ast_to_ir2(node, compilation_context, following_line))
        elif isinstance(node, ast.Return):
            return_ir = return_stmt_ast_to_ir2(node, compilation_context)
            returned_type = return_ir.expr.expr_type
            if previous_return_stmt:
                expected_type, expected_type_ast_node = previous_return_stmt
                if returned_type != expected_type:
                    raise CompilationError(compilation_context, node,
                                           'Found return statement with different return type: %s instead of %s.' % (
                                               str(returned_type), str(expected_type)),
                                           notes=[(expected_type_ast_node, 'A previous return statement returning a %s was here.' % (
                                               str(expected_type),))])
            first_return_stmt = first_return_stmt or (returned_type, node)
            previous_return_stmt = previous_return_stmt or first_return_stmt
            ir_stmts.append(return_ir)
        elif isinstance(node, ast.If):
            if_ir, return_stmt_in_if = if_stmt_ast_to_ir2(node,
                                                          compilation_context,
                                                          previous_return_stmt=previous_return_stmt,
                                                          check_always_returns=stmt_must_return,
                                                          next_stmt_line=following_line)
            first_return_stmt = first_return_stmt or return_stmt_in_if
            previous_return_stmt = previous_return_stmt or first_return_stmt
            ir_stmts.append(if_ir)
        elif isinstance(node, ast.Raise):
            ir_stmts.append(raise_stmt_ast_to_ir2(node, compilation_context))
        elif isinstance(node, ast.Try):
            try_except_ir, return_stmt_in_try_except = try_stmt_ast_to_ir2(node,
                                                                           compilation_context,
                                                                           previous_return_stmt=previous_return_stmt,
                                                                           check_always_returns=stmt_must_return,
                                                                           is_toplevel_in_function=stmts_are_toplevel_in_function,
                                                                           next_stmt_line=following_line)
            first_return_stmt = first_return_stmt or return_stmt_in_try_except
            previous_return_stmt = previous_return_stmt or first_return_stmt
            ir_stmts.append(try_except_ir)
        elif isinstance(node, ast.Pass):
            pass_branch = SourceBranch(compilation_context.filename,
                                       node.lineno,
                                       following_line)
            ir_stmts.append(ir2.PassStmt(pass_branch))
        elif isinstance(node, ast.Expr):
            ir_stmts.append(expression_stmt_to_ir2(node, compilation_context, following_line))
        else:
            raise CompilationError(compilation_context, node, 'Unsupported statement.')

    if check_block_always_returns and not get_return_type(ir_stmts).always_returns:
        raise CompilationError(compilation_context, last_node,
                               'Missing return statement.')

    return tuple(ir_stmts), first_return_stmt


def compute_next_stmt_line_number_by_index(ast_nodes: List[Optional[ast.AST]], next_stmt_line: int) -> List[Optional[int]]:
    """Return, for each position in `ast_nodes`, the line of the next non-None node after it.

    Positions with no non-None node after them get `next_stmt_line`. None entries are placeholders:
    they get a result like any other position but never count as a "next" node.
    """
    following_line = next_stmt_line
    lines_back_to_front = []
    # Walk backwards, so that the line of the closest following node is always at hand.
    for node in reversed(ast_nodes):
        lines_back_to_front.append(following_line)
        if node is not None:
            following_line = node.lineno
    lines_back_to_front.reverse()
    return lines_back_to_front

def function_def_ast_to_symbol_info(ast_node: ast.FunctionDef, compilation_context: CompilationContext):
    """Validate the signature of a function definition and extract its symbol information.

    Returns a tuple (name, arg_types, arg_names, return_type), where return_type is None if the
    function has no return type annotation.

    Raises a CompilationError if an argument has no type annotation, if the function has no
    arguments, or if the signature uses an unsupported feature (positional-only arguments, varargs,
    keyword-only arguments, default values, keyword arguments, decorators).
    """
    function_body_compilation_context = compilation_context.create_child_context(function_name=ast_node.name,
                                                                                 function_definition_line=ast_node.lineno)
    # Positional-only arguments (the ones before a "/") are stored separately from args.args, so
    # without this check they would be silently dropped from the signature. getattr() is used
    # because older ast versions have no such field.
    if getattr(ast_node.args, 'posonlyargs', None):
        raise CompilationError(compilation_context, ast_node, 'Positional-only function arguments are not supported.')

    arg_types = []
    arg_names = []
    for arg in ast_node.args.args:
        if not arg.annotation:
            if arg.type_comment:
                raise CompilationError(compilation_context, arg, 'All function arguments must have a type annotation. Note that type comments are not supported.')
            else:
                raise CompilationError(compilation_context, arg, 'All function arguments must have a type annotation.')
        arg_type = type_declaration_ast_to_ir2_expression_type(arg.annotation, compilation_context)
        function_body_compilation_context.add_symbol(name=arg.arg,
                                                     expr_type=arg_type,
                                                     definition_ast_node=arg,
                                                     is_only_partially_defined=False,
                                                     is_function_that_may_throw=isinstance(arg_type, ir2.FunctionType))
        arg_types.append(arg_type)
        arg_names.append(arg.arg)
    if not arg_types:
        raise CompilationError(compilation_context, ast_node, 'Functions with no arguments are not supported.')

    if ast_node.args.vararg:
        raise CompilationError(compilation_context, ast_node, 'Function vararg arguments are not supported.')
    if ast_node.args.kwonlyargs:
        raise CompilationError(compilation_context, ast_node, 'Keyword-only function arguments are not supported.')
    if ast_node.args.kw_defaults or ast_node.args.defaults:
        raise CompilationError(compilation_context, ast_node, 'Default values for function arguments are not supported.')
    if ast_node.args.kwarg:
        raise CompilationError(compilation_context, ast_node, 'Keyword function arguments are not supported.')
    if ast_node.decorator_list:
        raise CompilationError(compilation_context, ast_node, 'Function decorators are not supported.')

    if ast_node.returns:
        return_type = type_declaration_ast_to_ir2_expression_type(ast_node.returns, compilation_context)
    else:
        return_type = None

    return ast_node.name, tuple(arg_types), tuple(arg_names), return_type

def function_def_ast_to_ir2(ast_node: ast.FunctionDef, compilation_context: CompilationContext, next_stmt_line: int):
    """Convert a function definition into an ir2.FunctionDefn.

    The return type is the declared one if there is a return type annotation (which must then
    match the type of the first return statement, if any). Otherwise it's the type of the first
    return statement, or ir2.BottomType() if the body contains no return statement.
    """
    body_context = compilation_context.create_child_context(function_name=ast_node.name,
                                                            function_definition_line=ast_node.lineno)
    arg_decls = []
    for arg in ast_node.args.args:
        arg_type = type_declaration_ast_to_ir2_expression_type(arg.annotation, compilation_context)
        body_context.add_symbol(name=arg.arg,
                                expr_type=arg_type,
                                definition_ast_node=arg,
                                is_only_partially_defined=False,
                                is_function_that_may_throw=isinstance(arg_type, ir2.FunctionType))
        arg_decls.append(ir2.FunctionArgDecl(expr_type=arg_type, name=arg.arg))

    # Within the body, the line after the last statement is the negated line of the definition.
    body_stmts, first_return_stmt = statements_ast_to_ir2(ast_node.body,
                                                          body_context,
                                                          previous_return_stmt=None,
                                                          check_block_always_returns=True,
                                                          stmts_are_toplevel_in_function=True,
                                                          next_stmt_line=-ast_node.lineno)

    if first_return_stmt:
        return_type, first_return_stmt_ast_node = first_return_stmt
    else:
        return_type = None
        first_return_stmt_ast_node = None

    if ast_node.returns:
        declared_return_type = type_declaration_ast_to_ir2_expression_type(ast_node.returns, compilation_context)

        # first_return_stmt can be None if the function raises an exception instead of returning in all branches.
        if first_return_stmt and declared_return_type != return_type:
            raise CompilationError(compilation_context, ast_node.returns,
                                   '%s declared %s as return type, but the actual return type was %s.' % (
                                       ast_node.name, str(declared_return_type), str(return_type)),
                                   notes=[(first_return_stmt_ast_node, 'A %s was returned here' % str(return_type))])

        return_type = declared_return_type
    elif not first_return_stmt:
        return_type = ir2.BottomType()

    first_stmt_line = compute_next_stmt_line_number_by_index([None, *ast_node.body], next_stmt_line)[0]

    # The body starts with a PassStmt carrying the branch from the definition into the first statement.
    entry_stmt = ir2.PassStmt(SourceBranch(compilation_context.filename,
                                           -ast_node.lineno,
                                           first_stmt_line))
    return ir2.FunctionDefn(name=ast_node.name,
                            args=tuple(arg_decls),
                            body=(entry_stmt, *body_stmts),
                            return_type=return_type)

def assert_ast_to_ir2(ast_node: ast.Assert, compilation_context: CompilationContext, next_stmt_line: int):
    """Convert an assert statement into an ir2.Assert.

    The resulting error message contains the user-provided message (if any), followed by the
    location and the source line of the assert. Backslashes, double quotes and newlines in it are
    backslash-escaped.

    Raises a CompilationError if the asserted expression is not a bool or if the message is not a
    string literal.
    """
    expr = expression_ast_to_ir2(ast_node.test,
                                 compilation_context,
                                 in_match_pattern=False,
                                 check_var_reference=lambda ast_node: None,
                                 match_lambda_argument_names=set(),
                                 current_stmt_line=ast_node.lineno)

    if not isinstance(expr.expr_type, ir2.BoolType):
        raise CompilationError(compilation_context, ast_node.test,
                               'The value passed to assert must have type bool, but got a value with type %s.' % expr.expr_type)

    if ast_node.msg:
        # String literals are parsed as ast.Constant (ast.Str is a deprecated alias that newer
        # Python versions no longer provide). Anything else is a user error, not an internal one.
        if not (isinstance(ast_node.msg, ast.Constant) and isinstance(ast_node.msg.value, str)):
            raise CompilationError(compilation_context, ast_node.msg,
                                   'The message of an assert must be a string literal.')
        message = ast_node.msg.value
    else:
        message = ''

    first_line_number = ast_node.lineno
    message = 'TMPPy assertion failed: {message}\n{filename}:{first_line_number}: {line}'.format(
        filename=compilation_context.filename,
        first_line_number=first_line_number,
        message=message,
        line=compilation_context.source_lines[first_line_number - 1])
    # Backslashes must be escaped first, so that the ones added for quotes and newlines are kept.
    # Note that '\\"' is a backslash followed by a quote, while '\"' would be just a quote.
    message = message.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')

    return ir2.Assert(expr=expr,
                      message=message,
                      source_branch=SourceBranch(compilation_context.filename,
                                                 ast_node.lineno,
                                                 next_stmt_line))

def assignment_ast_to_ir2(ast_node: Union[ast.Assign, ast.AnnAssign, ast.AugAssign],
                          compilation_context: CompilationContext,
                          next_stmt_line: int):
    """Convert an assignment into an ir2.Assignment or an ir2.UnpackingAssignment.

    Two forms are supported: `x = <expr>` and the unpacking forms `x, y = <expr>` /
    `[x, y] = <expr>`, where the right-hand side must be a list. A target named `_` is replaced
    with a generated identifier. Every assigned variable is added to `compilation_context`.

    Raises a CompilationError for augmented or annotated assignments, type comments,
    multi-assignments and any other kind of target.
    """
    if isinstance(ast_node, ast.AugAssign):
        raise CompilationError(compilation_context, ast_node, 'Augmented assignments are not supported.')
    if isinstance(ast_node, ast.AnnAssign):
        raise CompilationError(compilation_context, ast_node, 'Assignments with type annotations are not supported.')
    assert isinstance(ast_node, ast.Assign)
    if ast_node.type_comment:
        raise CompilationError(compilation_context, ast_node, 'Type comments in assignments are not supported.')
    if len(ast_node.targets) > 1:
        raise CompilationError(compilation_context, ast_node, 'Multi-assignment is not supported.')
    [target] = ast_node.targets
    if isinstance(target, (ast.List, ast.Tuple)):
        # This is an "unpacking" assignment
        for lhs_elem_ast_node in target.elts:
            if not isinstance(lhs_elem_ast_node, ast.Name):
                raise CompilationError(compilation_context, lhs_elem_ast_node,
                                       'This kind of unpacking assignment is not supported. Only unpacking assignments of the form x,y=... or [x,y]=... are supported.')

        expr = expression_ast_to_ir2(ast_node.value,
                                     compilation_context,
                                     in_match_pattern=False,
                                     check_var_reference=lambda ast_node: None,
                                     match_lambda_argument_names=set(),
                                     current_stmt_line=ast_node.lineno)
        if not isinstance(expr.expr_type, ir2.ListType):
            raise CompilationError(compilation_context, ast_node,
                                   'Unpacking requires a list on the RHS, but the value on the RHS has type %s' % str(expr.expr_type))
        elem_type = expr.expr_type.elem_type

        # Every variable on the LHS gets the element type of the list.
        var_refs = []
        for lhs_elem_ast_node in target.elts:
            lhs_var_name = lhs_elem_ast_node.id
            if lhs_var_name == '_':
                lhs_var_name = next(compilation_context.identifier_generator)
            compilation_context.add_symbol(name=lhs_var_name,
                                           expr_type=elem_type,
                                           definition_ast_node=lhs_elem_ast_node,
                                           is_only_partially_defined=False,
                                           is_function_that_may_throw=isinstance(elem_type, ir2.FunctionType))

            var_ref = ir2.VarReference(expr_type=elem_type,
                                       name=lhs_var_name,
                                       is_global_function=False,
                                       is_function_that_may_throw=isinstance(elem_type, ir2.FunctionType))
            var_refs.append(var_ref)

        first_line_number = ast_node.lineno
        message = 'unexpected number of elements in the TMPPy list unpacking at:\n{filename}:{first_line_number}: {line}'.format(
            filename=compilation_context.filename,
            first_line_number=first_line_number,
            line=compilation_context.source_lines[first_line_number - 1])
        # Backslashes must be escaped first, so that the ones added for quotes and newlines are
        # kept. Note that '\\"' is a backslash followed by a quote, while '\"' would be just a quote.
        message = message.replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')

        return ir2.UnpackingAssignment(lhs_list=tuple(var_refs),
                                       rhs=expr,
                                       error_message=message,
                                       source_branch=SourceBranch(compilation_context.filename,
                                                                  ast_node.lineno,
                                                                  next_stmt_line))

    elif isinstance(target, ast.Name):
        # This is a "normal" assignment
        expr = expression_ast_to_ir2(ast_node.value,
                                     compilation_context,
                                     in_match_pattern=False,
                                     check_var_reference=lambda ast_node: None,
                                     match_lambda_argument_names=set(),
                                     current_stmt_line=ast_node.lineno)

        lhs_var_name = target.id
        if lhs_var_name == '_':
            lhs_var_name = next(compilation_context.identifier_generator)
        compilation_context.add_symbol(name=lhs_var_name,
                                       expr_type=expr.expr_type,
                                       definition_ast_node=target,
                                       is_only_partially_defined=False,
                                       is_function_that_may_throw=isinstance(expr.expr_type, ir2.FunctionType))

        return ir2.Assignment(lhs=ir2.VarReference(expr_type=expr.expr_type,
                                                   name=lhs_var_name,
                                                   is_global_function=False,
                                                   is_function_that_may_throw=isinstance(expr.expr_type, ir2.FunctionType)),
                              rhs=expr,
                              source_branch=SourceBranch(compilation_context.filename,
                                                         ast_node.lineno,
                                                         next_stmt_line))
    else:
        raise CompilationError(compilation_context, ast_node, 'Assignment not supported.')

def expression_stmt_to_ir2(stmt: ast.Expr, compilation_context: CompilationContext, next_stmt_line: int):
    """Converts a bare expression statement into an ir2.Assignment to a fresh variable.

    The expression is converted outside of any match pattern. Its value is bound to a
    name obtained from the context's identifier generator, and that name is registered
    in `compilation_context` as a fully-defined symbol.
    """
    rhs_expr = expression_ast_to_ir2(stmt.value,
                                     compilation_context,
                                     in_match_pattern=False,
                                     check_var_reference=lambda ast_node: None,
                                     match_lambda_argument_names=set(),
                                     current_stmt_line=stmt.lineno)
    result_type = rhs_expr.expr_type
    # Every function-typed value is flagged as a function that may throw.
    may_throw = isinstance(result_type, ir2.FunctionType)

    generated_name = next(compilation_context.identifier_generator)
    compilation_context.add_symbol(name=generated_name,
                                   expr_type=result_type,
                                   definition_ast_node=stmt,
                                   is_only_partially_defined=False,
                                   is_function_that_may_throw=may_throw)

    target_var = ir2.VarReference(expr_type=result_type,
                                  name=generated_name,
                                  is_global_function=False,
                                  is_function_that_may_throw=may_throw)
    branch = SourceBranch(compilation_context.filename,
                          stmt.lineno,
                          next_stmt_line)
    return ir2.Assignment(lhs=target_var, rhs=rhs_expr, source_branch=branch)

def int_comparison_ast_to_ir2(lhs_ast_node: ast.AST,
                              rhs_ast_node: ast.AST,
                              op: str,
                              compilation_context: CompilationContext,
                              in_match_pattern: bool,
                              check_var_reference: Callable[[ast.Name], None],
                              current_stmt_line: int):
    """Converts the int comparison `lhs op rhs` into an ir2.IntComparisonExpr.

    Both operands are converted first (left, then right), each with a fresh empty set
    of match lambda argument names. Afterwards each operand must have type ir2.IntType;
    otherwise a CompilationError pointing at the offending operand is raised.
    """
    operand_ast_nodes = (lhs_ast_node, rhs_ast_node)
    operands = [expression_ast_to_ir2(operand_ast_node,
                                      compilation_context,
                                      in_match_pattern,
                                      check_var_reference,
                                      match_lambda_argument_names=set(),
                                      current_stmt_line=current_stmt_line)
                for operand_ast_node in operand_ast_nodes]

    # Type-check only once both operands are converted, reporting the left one first.
    for operand_ast_node, operand in zip(operand_ast_nodes, operands):
        if operand.expr_type != ir2.IntType():
            raise CompilationError(compilation_context, operand_ast_node,
                                   'The "%s" operator is only supported for ints, but this value has type %s.' % (op, str(operand.expr_type)))

    lhs, rhs = operands
    return ir2.IntComparisonExpr(lhs=lhs, rhs=rhs, op=op)

def compare_ast_to_ir2(ast_node: ast.Compare,
                       compilation_context: CompilationContext,
                       in_match_pattern: bool,
                       check_var_reference: Callable[[ast.Name], None],
                       match_lambda_argument_names: Set[str],
                       current_stmt_line: int):
    """Converts a comparison AST node into the corresponding ir2 expression.

    Only comparisons with exactly one operator and one comparator are handled, and
    comparisons are rejected inside match patterns. `==`, `!=` and `in` are delegated
    to their dedicated converters; `<`, `<=`, `>` and `>=` are handled by
    int_comparison_ast_to_ir2. `match_lambda_argument_names` is accepted but is not
    forwarded to any of those converters.
    """
    is_single_comparison = len(ast_node.ops) == 1 and len(ast_node.comparators) == 1
    if not is_single_comparison:
        raise CompilationError(compilation_context, ast_node, 'Comparison not supported.')  # pragma: no cover

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'Comparisons are not allowed in match patterns')

    lhs = ast_node.left
    op = ast_node.ops[0]
    rhs = ast_node.comparators[0]

    if isinstance(op, ast.Eq):
        return eq_ast_to_ir2(lhs, rhs, compilation_context, in_match_pattern, check_var_reference, current_stmt_line)
    if isinstance(op, ast.NotEq):
        return not_eq_ast_to_ir2(lhs, rhs, compilation_context, in_match_pattern, check_var_reference, current_stmt_line)
    if isinstance(op, ast.In):
        return in_ast_to_ir2(lhs, rhs, compilation_context, in_match_pattern, check_var_reference, current_stmt_line)

    # The ordering operators differ only in the operator string handed to the int comparison converter.
    int_comparison_ops = ((ast.Lt, '<'),
                          (ast.LtE, '<='),
                          (ast.Gt, '>'),
                          (ast.GtE, '>='))
    for op_class, op_symbol in int_comparison_ops:
        if isinstance(op, op_class):
            return int_comparison_ast_to_ir2(lhs, rhs, op_symbol, compilation_context, in_match_pattern, check_var_reference, current_stmt_line)

    raise CompilationError(compilation_context, ast_node, 'Comparison not supported.')  # pragma: no cover

def attribute_expression_ast_to_ir2(ast_node: ast.Attribute,
                                    compilation_context: CompilationContext,
                                    in_match_pattern: bool,
                                    check_var_reference: Callable[[ast.Name], None],
                                    match_lambda_argument_names: Set[str],
                                    current_stmt_line: int):
    """Converts an attribute access (`value.attr`) into an ir2.AttributeAccessExpr.

    Attribute access is rejected inside match patterns. On a value of type ir2.TypeType
    any attribute name is accepted and the result has type ir2.TypeType. On a value of an
    ir2.CustomType the attribute must match the name of one of the type's arguments, and
    the result has that argument's type. Any other value type raises a CompilationError.
    """
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'Attribute access is not allowed in match patterns')

    value_expr = expression_ast_to_ir2(ast_node.value,
                                       compilation_context,
                                       in_match_pattern=in_match_pattern,
                                       check_var_reference=check_var_reference,
                                       match_lambda_argument_names=match_lambda_argument_names,
                                       current_stmt_line=current_stmt_line)
    value_type = value_expr.expr_type
    attribute_name = ast_node.attr

    if isinstance(value_type, ir2.TypeType):
        return ir2.AttributeAccessExpr(expr=value_expr,
                                       attribute_name=attribute_name,
                                       expr_type=ir2.TypeType())

    if not isinstance(value_type, ir2.CustomType):
        raise CompilationError(compilation_context, ast_node.value,
                               'Attribute access is not supported for values of type %s.' % str(value_type))

    # First argument of the custom type whose name matches the attribute, if any.
    matching_arg = next((arg for arg in value_type.arg_types if arg.name == attribute_name), None)
    if matching_arg is not None:
        return ir2.AttributeAccessExpr(expr=value_expr,
                                       attribute_name=attribute_name,
                                       expr_type=matching_arg.expr_type)

    # Unknown attribute: report the available ones and where the type was defined.
    lookup_result = compilation_context.get_type_symbol_definition(value_type.name)
    assert lookup_result
    available_attributes = '", "'.join(sorted(arg.name for arg in value_type.arg_types))
    raise CompilationError(compilation_context, ast_node.value,
                           'Values of type "%s" don\'t have the attribute "%s". The available attributes for this type are: {"%s"}.' % (
                               str(value_type), attribute_name, available_attributes),
                           notes=[(lookup_result.ast_node, '%s was defined here.' % str(value_type))])

def number_literal_expression_ast_to_ir2(ast_node: ast.Num,
                                         compilation_context: CompilationContext,
                                         in_match_pattern: bool,
                                         check_var_reference: Callable[[ast.Name], None],
                                         match_lambda_argument_names: Set[str],
                                         positive: bool):
    """Converts a numeric literal into an ir2.IntLiteral.

    Float and complex literals are rejected with a CompilationError. When `positive`
    is false the literal's value is negated before the range check. The resulting value
    must lie strictly between -2^63 and 2^63, otherwise a CompilationError is raised.
    """
    literal_value = ast_node.n
    unsupported_kinds = ((float, 'Floating-point values are not supported.'),
                         (complex, 'Complex values are not supported.'))
    for unsupported_type, message in unsupported_kinds:
        if isinstance(literal_value, unsupported_type):
            raise CompilationError(compilation_context, ast_node, message)
    assert isinstance(literal_value, int)

    value = literal_value if positive else -literal_value

    bound = 2**63
    if value <= -bound:
        raise CompilationError(compilation_context, ast_node,
                               'int value out of bounds: values lower than -2^63+1 are not supported.')
    if value >= bound:
        raise CompilationError(compilation_context, ast_node,
                               'int value out of bounds: values greater than 2^63-1 are not supported.')
    return ir2.IntLiteral(value=value)

def and_expression_ast_to_ir2(ast_node: ast.BoolOp,
                              compilation_context: CompilationContext,
                              in_match_pattern: bool,
                              check_var_reference: Callable[[ast.Name], None],
                              match_lambda_argument_names: Set[str],
                              current_stmt_line: int):
    """Converts an `and` expression into nested ir2.AndExpr nodes.

    The operator is rejected inside match patterns and when the context has no current
    function name (i.e. at toplevel). Operands are converted left to right, and each must
    have type ir2.BoolType. The result is right-nested: `a and b and c` becomes
    AndExpr(a, AndExpr(b, c)).
    """
    assert isinstance(ast_node.op, ast.And)

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "and" operator is not allowed in match patterns')

    if not compilation_context.current_function_name:
        raise CompilationError(compilation_context, ast_node,
                               'The "and" operator is only supported in functions, not at toplevel.')

    assert len(ast_node.values) >= 2

    operands = []
    for operand_ast_node in ast_node.values:
        operand = expression_ast_to_ir2(operand_ast_node,
                                        compilation_context,
                                        in_match_pattern=in_match_pattern,
                                        check_var_reference=check_var_reference,
                                        match_lambda_argument_names=match_lambda_argument_names,
                                        current_stmt_line=current_stmt_line)
        if operand.expr_type != ir2.BoolType():
            raise CompilationError(compilation_context, operand_ast_node,
                                   'The "and" operator is only supported for booleans, but this value has type %s.' % str(operand.expr_type))
        operands.append(operand)

    # Fold from the right: start from the last operand and wrap it with each preceding one.
    combined = operands.pop()
    while operands:
        combined = ir2.AndExpr(lhs=operands.pop(), rhs=combined)

    return combined

def or_expression_ast_to_ir2(ast_node: ast.BoolOp,
                             compilation_context: CompilationContext,
                             in_match_pattern: bool,
                             check_var_reference: Callable[[ast.Name], None],
                             match_lambda_argument_names: Set[str],
                             current_stmt_line: int):
    """Converts an `or` expression into nested ir2.OrExpr nodes.

    The operator is rejected inside match patterns and when the context has no current
    function name (i.e. at toplevel). Operands are converted left to right, and each must
    have type ir2.BoolType. The result is right-nested: `a or b or c` becomes
    OrExpr(a, OrExpr(b, c)).
    """
    assert isinstance(ast_node.op, ast.Or)

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "or" operator is not allowed in match patterns')

    if not compilation_context.current_function_name:
        raise CompilationError(compilation_context, ast_node,
                               'The "or" operator is only supported in functions, not at toplevel.')

    assert len(ast_node.values) >= 2

    operands = []
    for operand_ast_node in ast_node.values:
        operand = expression_ast_to_ir2(operand_ast_node,
                                        compilation_context,
                                        in_match_pattern=in_match_pattern,
                                        check_var_reference=check_var_reference,
                                        match_lambda_argument_names=match_lambda_argument_names,
                                        current_stmt_line=current_stmt_line)
        if operand.expr_type != ir2.BoolType():
            raise CompilationError(compilation_context, operand_ast_node,
                                   'The "or" operator is only supported for booleans, but this value has type %s.' % str(operand.expr_type))
        operands.append(operand)

    # Fold from the right: start from the last operand and wrap it with each preceding one.
    combined = operands.pop()
    while operands:
        combined = ir2.OrExpr(lhs=operands.pop(), rhs=combined)

    return combined

def not_expression_ast_to_ir2(ast_node: ast.UnaryOp,
                              compilation_context: CompilationContext,
                              in_match_pattern: bool,
                              check_var_reference: Callable[[ast.Name], None],
                              match_lambda_argument_names: Set[str],
                              current_stmt_line: int):
    """Converts a `not <operand>` expression into an ir2.NotExpr.

    The operator is rejected inside match patterns, and the operand must have type
    ir2.BoolType; a CompilationError is raised otherwise.
    """
    assert isinstance(ast_node.op, ast.Not)

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "not" operator is not allowed in match patterns')

    operand_ast_node = ast_node.operand
    operand = expression_ast_to_ir2(operand_ast_node,
                                    compilation_context,
                                    in_match_pattern=in_match_pattern,
                                    check_var_reference=check_var_reference,
                                    match_lambda_argument_names=match_lambda_argument_names,
                                    current_stmt_line=current_stmt_line)

    operand_type = operand.expr_type
    if operand_type != ir2.BoolType():
        raise CompilationError(compilation_context, operand_ast_node,
                               'The "not" operator is only supported for booleans, but this value has type %s.' % str(operand_type))

    return ir2.NotExpr(expr=operand)

def unary_minus_expression_ast_to_ir2(ast_node: ast.UnaryOp,
                                      compilation_context: CompilationContext,
                                      in_match_pattern: bool,
                                      check_var_reference: Callable[[ast.Name], None],
                                      match_lambda_argument_names: Set[str],
                                      current_stmt_line: int):
    """Converts a `-<operand>` expression into an ir2.IntUnaryMinusExpr.

    The operator is rejected inside match patterns, and the operand must have type
    ir2.IntType; a CompilationError is raised otherwise.
    """
    assert isinstance(ast_node.op, ast.USub)

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "-" operator is not allowed in match patterns')

    operand_ast_node = ast_node.operand
    operand = expression_ast_to_ir2(operand_ast_node,
                                    compilation_context,
                                    in_match_pattern=in_match_pattern,
                                    check_var_reference=check_var_reference,
                                    match_lambda_argument_names=match_lambda_argument_names,
                                    current_stmt_line=current_stmt_line)

    operand_type = operand.expr_type
    if operand_type != ir2.IntType():
        raise CompilationError(compilation_context, operand_ast_node,
                               'The "-" operator is only supported for ints, but this value has type %s.' % str(operand_type))

    return ir2.IntUnaryMinusExpr(expr=operand)

def int_binary_op_expression_ast_to_ir2(ast_node: ast.BinOp,
                                        op: str,
                                        compilation_context: CompilationContext,
                                        in_match_pattern: bool,
                                        check_var_reference: Callable[[ast.Name], None],
                                        match_lambda_argument_names: Set[str],
                                        current_stmt_line: int):
    """Converts a binary int operation (`left op right`) into an ir2.IntBinaryOpExpr.

    Both operands are converted (left, then right) before the match-pattern check, so an
    error raised while converting an operand takes precedence over the "not allowed in
    match patterns" error. Each operand must have type ir2.IntType.
    """
    operand_ast_nodes = (ast_node.left, ast_node.right)
    lhs, rhs = [expression_ast_to_ir2(operand_ast_node,
                                      compilation_context,
                                      in_match_pattern=in_match_pattern,
                                      check_var_reference=check_var_reference,
                                      match_lambda_argument_names=match_lambda_argument_names,
                                      current_stmt_line=current_stmt_line)
                for operand_ast_node in operand_ast_nodes]

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "%s" operator is not allowed in match patterns' % op)

    # Report a non-int left operand before a non-int right operand.
    for operand_ast_node, operand in zip(operand_ast_nodes, (lhs, rhs)):
        if operand.expr_type != ir2.IntType():
            raise CompilationError(compilation_context, operand_ast_node,
                                   'The "%s" operator is only supported for ints, but this value has type %s.' % (op, str(operand.expr_type)))

    return ir2.IntBinaryOpExpr(lhs=lhs, rhs=rhs, op=op)

def list_comprehension_ast_to_ir2(ast_node: ast.ListComp,
                                  compilation_context: CompilationContext,
                                  in_match_pattern: bool,
                                  check_var_reference: Callable[[ast.Name], None],
                                  match_lambda_argument_names: Set[str],
                                  current_stmt_line: int):
    """Converts a list comprehension into an ir2.ListComprehension.

    Only the form `[<elt> for <name> in <list>]` is accepted: a single "for" clause, no
    "if" clauses, a plain name as loop variable and an iterated expression of type
    ir2.ListType. List comprehensions are rejected inside match patterns, and the element
    expression must not have a function type. The element expression is converted in a
    child context in which the loop variable is defined with the list's element type.
    """
    generators = ast_node.generators
    assert generators
    if len(generators) > 1:
        raise CompilationError(compilation_context, generators[1].target,
                               'List comprehensions with multiple "for" clauses are not currently supported.')

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'List comprehensions are not allowed in match patterns')

    [generator] = generators
    if generator.ifs:
        raise CompilationError(compilation_context, generator.ifs[0],
                               '"if" clauses in list comprehensions are not currently supported.')
    loop_target = generator.target
    if not isinstance(loop_target, ast.Name):
        raise CompilationError(compilation_context, loop_target,
                               'Only list comprehensions of the form [... for var_name in ...] are supported.')

    list_expr = expression_ast_to_ir2(generator.iter,
                                      compilation_context,
                                      in_match_pattern=in_match_pattern,
                                      check_var_reference=check_var_reference,
                                      match_lambda_argument_names=match_lambda_argument_names,
                                      current_stmt_line=current_stmt_line)
    list_type = list_expr.expr_type
    if not isinstance(list_type, ir2.ListType):
        notes = []
        # When the iterated expression is a plain variable, also point at its definition.
        if isinstance(list_expr, ir2.VarReference):
            lookup_result = compilation_context.get_symbol_definition(list_expr.name)
            assert lookup_result
            notes = [(lookup_result.ast_node, '%s was defined here' % list_expr.name)]
        raise CompilationError(compilation_context, loop_target,
                               'The RHS of a list comprehension should be a list, but this value has type "%s".' % str(list_type),
                               notes=notes)

    loop_var_name = loop_target.id
    elem_type = list_type.elem_type

    # The loop variable is only visible to the element expression, via a child context.
    child_context = compilation_context.create_child_context()
    child_context.add_symbol(name=loop_var_name,
                             expr_type=elem_type,
                             definition_ast_node=loop_target,
                             is_only_partially_defined=False,
                             is_function_that_may_throw=False)
    result_elem_expr = expression_ast_to_ir2(ast_node.elt,
                                             child_context,
                                             in_match_pattern=in_match_pattern,
                                             check_var_reference=check_var_reference,
                                             match_lambda_argument_names=match_lambda_argument_names,
                                             current_stmt_line=current_stmt_line)

    result_elem_type = result_elem_expr.expr_type
    if isinstance(result_elem_type, ir2.FunctionType):
        raise CompilationError(compilation_context, ast_node,
                               'Creating lists of functions is not supported. The elements of this list have type: %s' % str(result_elem_type))

    # The loop body endpoint of both branches is the negated line number of the element expression.
    loop_body_line = -ast_node.elt.lineno
    loop_var = ir2.VarReference(name=loop_var_name,
                                expr_type=elem_type,
                                is_global_function=False,
                                is_function_that_may_throw=False)
    return ir2.ListComprehension(list_expr=list_expr,
                                 loop_var=loop_var,
                                 result_elem_expr=result_elem_expr,
                                 loop_body_start_branch=SourceBranch(compilation_context.filename,
                                                                     current_stmt_line,
                                                                     loop_body_line),
                                 loop_exit_branch=SourceBranch(compilation_context.filename,
                                                               loop_body_line,
                                                               current_stmt_line))

def set_comprehension_ast_to_ir2(ast_node: ast.SetComp,
                                 compilation_context: CompilationContext,
                                 in_match_pattern: bool,
                                 check_var_reference: Callable[[ast.Name], None],
                                 match_lambda_argument_names: Set[str],
                                 current_stmt_line: int):
    """Converts a set comprehension into an ir2.SetComprehension.

    Only the form `{<elt> for <name> in <set>}` is accepted: a single "for" clause, no
    "if" clauses, a plain name as loop variable and an iterated expression of type
    ir2.SetType. Set comprehensions are rejected inside match patterns, and the element
    expression must not have a function type. The element expression is converted in a
    child context in which the loop variable is defined with the set's element type.
    """
    generators = ast_node.generators
    assert generators
    if len(generators) > 1:
        raise CompilationError(compilation_context, generators[1].target,
                               'Set comprehensions with multiple "for" clauses are not currently supported.')

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'Set comprehensions are not allowed in match patterns')

    [generator] = generators
    if generator.ifs:
        raise CompilationError(compilation_context, generator.ifs[0],
                               '"if" clauses in set comprehensions are not currently supported.')
    loop_target = generator.target
    if not isinstance(loop_target, ast.Name):
        raise CompilationError(compilation_context, loop_target,
                               'Only set comprehensions of the form {... for var_name in ...} are supported.')

    set_expr = expression_ast_to_ir2(generator.iter,
                                     compilation_context,
                                     in_match_pattern=in_match_pattern,
                                     check_var_reference=check_var_reference,
                                     match_lambda_argument_names=match_lambda_argument_names,
                                     current_stmt_line=current_stmt_line)
    set_type = set_expr.expr_type
    if not isinstance(set_type, ir2.SetType):
        notes = []
        # When the iterated expression is a plain variable, also point at its definition.
        if isinstance(set_expr, ir2.VarReference):
            lookup_result = compilation_context.get_symbol_definition(set_expr.name)
            assert lookup_result
            notes = [(lookup_result.ast_node, '%s was defined here' % set_expr.name)]
        raise CompilationError(compilation_context, loop_target,
                               'The RHS of a set comprehension should be a set, but this value has type "%s".' % str(set_type),
                               notes=notes)

    loop_var_name = loop_target.id
    elem_type = set_type.elem_type

    # The loop variable is only visible to the element expression, via a child context.
    child_context = compilation_context.create_child_context()
    child_context.add_symbol(name=loop_var_name,
                             expr_type=elem_type,
                             definition_ast_node=loop_target,
                             is_only_partially_defined=False,
                             is_function_that_may_throw=False)
    result_elem_expr = expression_ast_to_ir2(ast_node.elt,
                                             child_context,
                                             in_match_pattern=in_match_pattern,
                                             check_var_reference=check_var_reference,
                                             match_lambda_argument_names=match_lambda_argument_names,
                                             current_stmt_line=current_stmt_line)

    result_elem_type = result_elem_expr.expr_type
    if isinstance(result_elem_type, ir2.FunctionType):
        raise CompilationError(compilation_context, ast_node,
                               'Creating sets of functions is not supported. The elements of this set have type: %s' % str(result_elem_type))

    # The loop body endpoint of both branches is the negated line number of the element expression.
    loop_body_line = -ast_node.elt.lineno
    loop_var = ir2.VarReference(name=loop_var_name,
                                expr_type=elem_type,
                                is_global_function=False,
                                is_function_that_may_throw=False)
    return ir2.SetComprehension(set_expr=set_expr,
                                loop_var=loop_var,
                                result_elem_expr=result_elem_expr,
                                loop_body_start_branch=SourceBranch(compilation_context.filename,
                                                                    current_stmt_line,
                                                                    loop_body_line),
                                loop_exit_branch=SourceBranch(compilation_context.filename,
                                                              loop_body_line,
                                                              current_stmt_line))


def add_expression_ast_to_ir2(ast_node: ast.BinOp,
                              compilation_context: CompilationContext,
                              in_match_pattern: bool,
                              check_var_reference: Callable[[ast.Name], None],
                              match_lambda_argument_names: Set[str],
                              current_stmt_line: int):
    """Converts a `left + right` expression into an int addition or a list concatenation.

    Both operands are converted (left, then right) before the match-pattern check, so an
    error raised while converting an operand takes precedence over the "not allowed in
    match patterns" error. Each operand must be an int or a list and both must have the
    same type. Returns an ir2.IntBinaryOpExpr with op '+' for ints, and an
    ir2.ListConcatExpr otherwise.
    """
    operand_ast_nodes = (ast_node.left, ast_node.right)
    lhs, rhs = [expression_ast_to_ir2(operand_ast_node,
                                      compilation_context,
                                      in_match_pattern=in_match_pattern,
                                      check_var_reference=check_var_reference,
                                      match_lambda_argument_names=match_lambda_argument_names,
                                      current_stmt_line=current_stmt_line)
                for operand_ast_node in operand_ast_nodes]

    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'The "+" operator is not allowed in match patterns')

    # Report an unsupported left operand before an unsupported right operand.
    addable_types = (ir2.IntType, ir2.ListType)
    for operand_ast_node, operand in zip(operand_ast_nodes, (lhs, rhs)):
        if not isinstance(operand.expr_type, addable_types):
            raise CompilationError(compilation_context, operand_ast_node,
                                   'The "+" operator is only supported for ints and lists, but this value has type %s.' % str(operand.expr_type))

    if lhs.expr_type != rhs.expr_type:
        raise CompilationError(compilation_context, ast_node.left,
                               'Type mismatch: the LHS of "+" has type %s but the RHS has type %s.' % (str(lhs.expr_type), str(rhs.expr_type)))

    if lhs.expr_type == ir2.IntType():
        return ir2.IntBinaryOpExpr(lhs=lhs, rhs=rhs, op='+')
    return ir2.ListConcatExpr(lhs=lhs, rhs=rhs)


def expression_ast_to_ir2(ast_node: ast.AST,
                          compilation_context: CompilationContext,
                          in_match_pattern: bool,
                          check_var_reference: Callable[[ast.Name], None],
                          match_lambda_argument_names: Set[str],
                          current_stmt_line: int) -> ir2.Expr:
    """Converts a Python expression AST node into the corresponding IR2 expression.

    This is the central dispatcher: it looks at the kind of node (and, for calls
    and operators, at the callee or the operator) and delegates to the matching
    specialized `*_ast_to_ir2` converter, forwarding the conversion arguments
    unchanged.

    Raises:
        CompilationError: if the node is not one of the expression kinds handled
            here (the delegated converters may raise it as well).
    """
    # Argument bundles forwarded to the specialized converters. Some converters
    # don't take the statement line, hence the two variants.
    base_args = (compilation_context, in_match_pattern, check_var_reference, match_lambda_argument_names)
    full_args = base_args + (current_stmt_line,)

    if isinstance(ast_node, ast.NameConstant):
        return name_constant_ast_to_ir2(ast_node, *base_args)

    if isinstance(ast_node, ast.Call):
        callee = ast_node.func

        if isinstance(callee, ast.Name):
            # Special forms that are recognized by the name of the called function.
            if callee.id == 'Type':
                return atomic_type_literal_ast_to_ir2(ast_node, *base_args)
            if callee.id == 'empty_list':
                return empty_list_literal_ast_to_ir2(ast_node, *base_args)
            if callee.id == 'empty_set':
                return empty_set_literal_ast_to_ir2(ast_node, *base_args)
            if callee.id == 'sum':
                return int_iterable_sum_expr_ast_to_ir2(ast_node, *full_args)
            # Note: the all()/any() converters are not passed match_lambda_argument_names.
            if callee.id == 'all':
                return bool_iterable_all_expr_ast_to_ir2(ast_node,
                                                         compilation_context,
                                                         in_match_pattern,
                                                         check_var_reference,
                                                         current_stmt_line)
            if callee.id == 'any':
                return bool_iterable_any_expr_ast_to_ir2(ast_node,
                                                         compilation_context,
                                                         in_match_pattern,
                                                         check_var_reference,
                                                         current_stmt_line)

        # Type.<factory_method>(...)
        if (isinstance(callee, ast.Attribute)
                and isinstance(callee.value, ast.Name)
                and callee.value.id == 'Type'):
            return type_factory_method_ast_to_ir2(ast_node, *full_args)

        # match(...)(...)
        if (isinstance(callee, ast.Call)
                and isinstance(callee.func, ast.Name)
                and callee.func.id == 'match'):
            return match_expression_ast_to_ir2(ast_node, *full_args)

        # Any other call is handled as a plain function call.
        return function_call_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.Compare):
        return compare_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.Name) and isinstance(ast_node.ctx, ast.Load):
        return var_reference_ast_to_ir2(ast_node, *base_args)

    if isinstance(ast_node, ast.List) and isinstance(ast_node.ctx, ast.Load):
        return list_expression_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.Set):
        return set_expression_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.Attribute) and isinstance(ast_node.ctx, ast.Load):
        return attribute_expression_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.Num):
        return number_literal_expression_ast_to_ir2(ast_node, *base_args, positive=True)

    if isinstance(ast_node, ast.UnaryOp):
        if isinstance(ast_node.op, ast.USub):
            # A minus sign applied directly to a number literal is converted as
            # a negative literal rather than as a unary minus expression.
            if isinstance(ast_node.operand, ast.Num):
                return number_literal_expression_ast_to_ir2(ast_node.operand, *base_args, positive=False)
            return unary_minus_expression_ast_to_ir2(ast_node, *full_args)
        if isinstance(ast_node.op, ast.Not):
            return not_expression_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.BoolOp):
        if isinstance(ast_node.op, ast.And):
            return and_expression_ast_to_ir2(ast_node, *full_args)
        if isinstance(ast_node.op, ast.Or):
            return or_expression_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.BinOp):
        # "+" has a dedicated converter; the other supported binary operators
        # share the int binary-op converter.
        if isinstance(ast_node.op, ast.Add):
            return add_expression_ast_to_ir2(ast_node, *full_args)
        int_operators = ((ast.Sub, '-'),
                         (ast.Mult, '*'),
                         (ast.FloorDiv, '//'),
                         (ast.Mod, '%'))
        for operator_class, operator_symbol in int_operators:
            if isinstance(ast_node.op, operator_class):
                return int_binary_op_expression_ast_to_ir2(ast_node, operator_symbol, *full_args)

    if isinstance(ast_node, ast.ListComp):
        return list_comprehension_ast_to_ir2(ast_node, *full_args)

    if isinstance(ast_node, ast.SetComp):
        return set_comprehension_ast_to_ir2(ast_node, *full_args)

    # Nothing above matched (this includes unary/boolean/binary operators other
    # than the ones handled explicitly).
    raise CompilationError(compilation_context, ast_node, 'This kind of expression is not supported.')  # pragma: no cover

def name_constant_ast_to_ir2(ast_node: ast.NameConstant,
                             compilation_context: CompilationContext,
                             in_match_pattern: bool,
                             check_var_reference: Callable[[ast.Name], None],
                             match_lambda_argument_names: Set[str]):
    """Converts a boolean named constant into an ir2.BoolLiteral.

    Raises:
        CompilationError: if the constant's value is not a bool.
    """
    constant_value = ast_node.value
    if not isinstance(constant_value, bool):
        raise CompilationError(compilation_context, ast_node, 'NameConstant not supported: ' + str(constant_value))  # pragma: no cover
    return ir2.BoolLiteral(value=constant_value)

# Matches an identifier optionally followed by more `::`-separated identifiers,
# i.e. a (possibly namespace-qualified) C++ identifier.
_check_atomic_type_regex = re.compile(r'[A-Za-z_][A-Za-z0-9_]*(::[A-Za-z_][A-Za-z0-9_]*)*')

def _check_atomic_type(ast_node: ast.Str, compilation_context: CompilationContext):
    """Checks that the string in `ast_node` is a valid atomic type name.

    The whole string must match `_check_atomic_type_regex`.

    Raises:
        CompilationError: if the string does not match.
    """
    type_name_match = _check_atomic_type_regex.fullmatch(ast_node.s)
    if type_name_match is None:
        raise CompilationError(compilation_context, ast_node,
                               'Invalid atomic type. Atomic types should be C++ identifiers (possibly namespace-qualified).')

def atomic_type_literal_ast_to_ir2(ast_node: ast.Call,
                                   compilation_context: CompilationContext,
                                   in_match_pattern: bool,
                                   check_var_reference: Callable[[ast.Name], None],
                                   match_lambda_argument_names: Set[str]):
    """Converts a `Type('<name>')` call into an ir2.AtomicTypeLiteral.

    The call must have exactly one positional argument and no keyword
    arguments; the argument must be a string constant that passes
    `_check_atomic_type`.

    Raises:
        CompilationError: if any of the above requirements is violated.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value,
                               'Keyword arguments are not supported in Type()')

    num_args = len(ast_node.args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'Type() takes 1 argument. Got: %s' % num_args)

    type_name_node = ast_node.args[0]
    if not isinstance(type_name_node, ast.Str):
        raise CompilationError(compilation_context, type_name_node, 'The argument passed to Type should be a string constant.')

    _check_atomic_type(type_name_node, compilation_context)
    return ir2.AtomicTypeLiteral(cpp_type=type_name_node.s)

def _extract_single_type_expr_arg(ast_node: ast.Call,
                                  called_fun_name: str,
                                  compilation_context: CompilationContext,
                                  in_match_pattern: bool,
                                  check_var_reference: Callable[[ast.Name], None],
                                  match_lambda_argument_names: Set[str],
                                  current_stmt_line: int):
    """Converts the only argument of a call and checks that it has type Type.

    `called_fun_name` is only used to build the error messages.

    Returns:
        The IR2 expression for the argument.

    Raises:
        CompilationError: if the call has keyword arguments, doesn't have
            exactly one positional argument, or the argument's type is not Type.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value,
                               'Keyword arguments are not supported in %s()' % called_fun_name)

    num_args = len(ast_node.args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, '%s() takes 1 argument. Got: %s' % (called_fun_name, num_args))

    type_arg_node = ast_node.args[0]
    type_arg_ir = expression_ast_to_ir2(type_arg_node,
                                        compilation_context,
                                        in_match_pattern,
                                        check_var_reference,
                                        match_lambda_argument_names,
                                        current_stmt_line)
    if type_arg_ir.expr_type != ir2.TypeType():
        raise CompilationError(compilation_context, type_arg_node, 'The argument passed to %s() should have type Type, but was: %s' % (called_fun_name, str(type_arg_ir.expr_type)))
    return type_arg_ir

def type_pointer_expr_ast_to_ir2(ast_node: ast.Call,
                                 compilation_context: CompilationContext,
                                 in_match_pattern: bool,
                                 check_var_reference: Callable[[ast.Name], None],
                                 match_lambda_argument_names: Set[str],
                                 current_stmt_line: int):
    """Converts a `Type.pointer(<type>)` call into an ir2.PointerTypeExpr.

    Argument validation is delegated to `_extract_single_type_expr_arg`.
    """
    type_arg_ir = _extract_single_type_expr_arg(ast_node,
                                                'Type.pointer',
                                                compilation_context,
                                                in_match_pattern,
                                                check_var_reference,
                                                match_lambda_argument_names,
                                                current_stmt_line)
    return ir2.PointerTypeExpr(type_arg_ir)

def type_reference_expr_ast_to_ir2(ast_node: ast.Call,
                                   compilation_context: CompilationContext,
                                   in_match_pattern: bool,
                                   check_var_reference: Callable[[ast.Name], None],
                                   match_lambda_argument_names: Set[str],
                                   current_stmt_line: int):
    """Converts a `Type.reference(<type>)` call into an ir2.ReferenceTypeExpr.

    Argument validation is delegated to `_extract_single_type_expr_arg`.
    """
    type_arg_ir = _extract_single_type_expr_arg(ast_node,
                                                'Type.reference',
                                                compilation_context,
                                                in_match_pattern,
                                                check_var_reference,
                                                match_lambda_argument_names,
                                                current_stmt_line)
    return ir2.ReferenceTypeExpr(type_arg_ir)

def type_rvalue_reference_expr_ast_to_ir2(ast_node: ast.Call,
                                          compilation_context: CompilationContext,
                                          in_match_pattern: bool,
                                          check_var_reference: Callable[[ast.Name], None],
                                          match_lambda_argument_names: Set[str],
                                          current_stmt_line: int):
    """Converts a `Type.rvalue_reference(<type>)` call into an ir2.RvalueReferenceTypeExpr.

    Argument validation is delegated to `_extract_single_type_expr_arg`.
    """
    type_arg_ir = _extract_single_type_expr_arg(ast_node,
                                                'Type.rvalue_reference',
                                                compilation_context,
                                                in_match_pattern,
                                                check_var_reference,
                                                match_lambda_argument_names,
                                                current_stmt_line)
    return ir2.RvalueReferenceTypeExpr(type_arg_ir)

def const_type_expr_ast_to_ir2(ast_node: ast.Call,
                               compilation_context: CompilationContext,
                               in_match_pattern: bool,
                               check_var_reference: Callable[[ast.Name], None],
                               match_lambda_argument_names: Set[str],
                               current_stmt_line: int):
    """Converts a `Type.const(<type>)` call into an ir2.ConstTypeExpr.

    Argument validation is delegated to `_extract_single_type_expr_arg`.
    """
    type_arg_ir = _extract_single_type_expr_arg(ast_node,
                                                'Type.const',
                                                compilation_context,
                                                in_match_pattern,
                                                check_var_reference,
                                                match_lambda_argument_names,
                                                current_stmt_line)
    return ir2.ConstTypeExpr(type_arg_ir)

def type_array_expr_ast_to_ir2(ast_node: ast.Call,
                               compilation_context: CompilationContext,
                               in_match_pattern: bool,
                               check_var_reference: Callable[[ast.Name], None],
                               match_lambda_argument_names: Set[str],
                               current_stmt_line: int):
    """Converts a `Type.array(<type>)` call into an ir2.ArrayTypeExpr.

    Argument validation is delegated to `_extract_single_type_expr_arg`.
    """
    type_arg_ir = _extract_single_type_expr_arg(ast_node,
                                                'Type.array',
                                                compilation_context,
                                                in_match_pattern,
                                                check_var_reference,
                                                match_lambda_argument_names,
                                                current_stmt_line)
    return ir2.ArrayTypeExpr(type_arg_ir)

def function_type_expr_ast_to_ir2(ast_node: ast.Call,
                                  compilation_context: CompilationContext,
                                  in_match_pattern: bool,
                                  check_var_reference: Callable[[ast.Name], None],
                                  match_lambda_argument_names: Set[str],
                                  current_stmt_line: int):
    """Converts a `Type.function(<return type>, <arg types>)` call into an ir2.FunctionTypeExpr.

    The first argument must have type Type and the second one List[Type].

    Raises:
        CompilationError: if the call has keyword arguments, doesn't have
            exactly two positional arguments, or an argument has the wrong type.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value,
                               'Keyword arguments are not supported in Type.function()')

    num_args = len(ast_node.args)
    if num_args != 2:
        raise CompilationError(compilation_context, ast_node, 'Type.function() takes 2 arguments. Got: %s' % num_args)
    return_type_ast_node, arg_list_ast_node = ast_node.args

    # Arguments shared by both sub-expression conversions below.
    conversion_args = (compilation_context,
                       in_match_pattern,
                       check_var_reference,
                       match_lambda_argument_names,
                       current_stmt_line)

    # The return type is converted and checked before the argument list is
    # looked at, so that errors are reported in source order.
    return_type_ir = expression_ast_to_ir2(return_type_ast_node, *conversion_args)
    if return_type_ir.expr_type != ir2.TypeType():
        raise CompilationError(compilation_context, return_type_ast_node,
                               'The first argument passed to Type.function should have type Type, but was: %s' % str(return_type_ir.expr_type))

    arg_list_ir = expression_ast_to_ir2(arg_list_ast_node, *conversion_args)
    if arg_list_ir.expr_type != ir2.ListType(ir2.TypeType()):
        raise CompilationError(compilation_context, arg_list_ast_node,
                               'The second argument passed to Type.function should have type List[Type], but was: %s' % str(arg_list_ir.expr_type))

    return ir2.FunctionTypeExpr(return_type_expr=return_type_ir,
                                arg_list_expr=arg_list_ir)

def template_instantiation_ast_to_ir2(ast_node: ast.Call,
                                      compilation_context: CompilationContext,
                                      in_match_pattern: bool,
                                      check_var_reference: Callable[[ast.Name], None],
                                      match_lambda_argument_names: Set[str],
                                      current_stmt_line: int):
    """Converts a `Type.template_instantiation('<template>', <arg types>)` call.

    The first argument must be a string constant that passes
    `_check_atomic_type`; the second one must have type List[Type].

    Returns:
        An ir2.TemplateInstantiationExpr.

    Raises:
        CompilationError: if the call has keyword arguments, doesn't have
            exactly two positional arguments, or an argument is invalid.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value,
                               'Keyword arguments are not supported in Type.template_instantiation()')

    num_args = len(ast_node.args)
    if num_args != 2:
        raise CompilationError(compilation_context, ast_node, 'Type.template_instantiation() takes 2 arguments. Got: %s' % num_args)
    template_name_node, arg_list_ast_node = ast_node.args

    # The template name is validated before the argument list is converted.
    if not isinstance(template_name_node, ast.Str):
        raise CompilationError(compilation_context, template_name_node,
                               'The first argument passed to Type.template_instantiation should be a string')
    _check_atomic_type(template_name_node, compilation_context)

    arg_list_ir = expression_ast_to_ir2(arg_list_ast_node,
                                        compilation_context,
                                        in_match_pattern,
                                        check_var_reference,
                                        match_lambda_argument_names,
                                        current_stmt_line)
    expected_arg_list_type = ir2.ListType(ir2.TypeType())
    if arg_list_ir.expr_type != expected_arg_list_type:
        raise CompilationError(compilation_context, arg_list_ast_node,
                               'The second argument passed to Type.template_instantiation should have type List[Type], but was: %s' % str(arg_list_ir.expr_type))

    return ir2.TemplateInstantiationExpr(template_atomic_cpp_type=template_name_node.s,
                                         arg_list_expr=arg_list_ir)

# Matches a single unqualified identifier: a letter or underscore followed by
# any number of letters, digits or underscores.
_cxx_identifier_regex = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

def template_member_access_ast_to_ir2(ast_node: ast.Call,
                                      compilation_context: CompilationContext,
                                      in_match_pattern: bool,
                                      check_var_reference: Callable[[ast.Name], None],
                                      match_lambda_argument_names: Set[str],
                                      current_stmt_line: int):
    """Converts a `Type.template_member(<class type>, '<member>', <arg types>)` call.

    The first argument must have type Type, the second one must be a string
    constant matching `_cxx_identifier_regex` and the third one must have type
    List[Type]. This form is rejected inside match patterns.

    Returns:
        An ir2.TemplateMemberAccessExpr.

    Raises:
        CompilationError: if used in a match pattern, if the call has keyword
            arguments, doesn't have exactly three positional arguments, or an
            argument is invalid.
    """
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'Type.template_member() is not allowed in match patterns')

    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value,
                               'Keyword arguments are not supported in Type.template_member()')

    num_args = len(ast_node.args)
    if num_args != 3:
        raise CompilationError(compilation_context, ast_node, 'Type.template_member() takes 3 arguments. Got: %s' % num_args)
    class_type_ast_node, member_name_ast_node, arg_list_ast_node = ast_node.args

    # Arguments shared by both sub-expression conversions below.
    conversion_args = (compilation_context,
                       in_match_pattern,
                       check_var_reference,
                       match_lambda_argument_names,
                       current_stmt_line)

    # The three arguments are validated strictly in order, so the first
    # offending one is the one that gets reported.
    class_type_expr_ir = expression_ast_to_ir2(class_type_ast_node, *conversion_args)
    if class_type_expr_ir.expr_type != ir2.TypeType():
        raise CompilationError(compilation_context, class_type_ast_node,
                               'The first argument passed to Type.template_member should have type Type, but was: %s' % str(class_type_expr_ir.expr_type))

    if not isinstance(member_name_ast_node, ast.Str):
        raise CompilationError(compilation_context, member_name_ast_node,
                               'The second argument passed to Type.template_member should be a string')
    member_name = member_name_ast_node.s
    if _cxx_identifier_regex.fullmatch(member_name) is None:
        raise CompilationError(compilation_context, member_name_ast_node,
                               'The second argument passed to Type.template_member should be a valid C++ identifier')

    arg_list_ir = expression_ast_to_ir2(arg_list_ast_node, *conversion_args)
    if arg_list_ir.expr_type != ir2.ListType(ir2.TypeType()):
        raise CompilationError(compilation_context, arg_list_ast_node,
                               'The third argument passed to Type.template_member should have type List[Type], but was: %s' % str(arg_list_ir.expr_type))

    return ir2.TemplateMemberAccessExpr(class_type_expr=class_type_expr_ir,
                                        member_name=member_name,
                                        arg_list_expr=arg_list_ir)

def type_factory_method_ast_to_ir2(ast_node: ast.Call,
                                   compilation_context: CompilationContext,
                                   in_match_pattern: bool,
                                   check_var_reference: Callable[[ast.Name], None],
                                   match_lambda_argument_names: Set[str],
                                   current_stmt_line: int):
    """Converts a `Type.<method>(...)` call by dispatching on the method name.

    `ast_node` must be a call whose callee is an attribute of the name `Type`
    (this is asserted).

    Raises:
        CompilationError: if `<method>` is not one of the known factory methods
            (the selected converter may raise it as well).
    """
    assert isinstance(ast_node, ast.Call)
    callee = ast_node.func
    assert isinstance(callee, ast.Attribute)
    assert isinstance(callee.value, ast.Name)
    assert callee.value.id == 'Type'

    # Factory method name -> converter. All converters share the same signature.
    converters_by_method_name = {
        'pointer': type_pointer_expr_ast_to_ir2,
        'reference': type_reference_expr_ast_to_ir2,
        'rvalue_reference': type_rvalue_reference_expr_ast_to_ir2,
        'const': const_type_expr_ast_to_ir2,
        'array': type_array_expr_ast_to_ir2,
        'function': function_type_expr_ast_to_ir2,
        'template_instantiation': template_instantiation_ast_to_ir2,
        'template_member': template_member_access_ast_to_ir2,
    }
    converter = converters_by_method_name.get(callee.attr)
    if converter is None:
        raise CompilationError(compilation_context, ast_node,
                               'Undefined Type factory method')
    return converter(ast_node,
                     compilation_context,
                     in_match_pattern,
                     check_var_reference,
                     match_lambda_argument_names,
                     current_stmt_line)

def empty_list_literal_ast_to_ir2(ast_node: ast.Call,
                                  compilation_context: CompilationContext,
                                  in_match_pattern: bool,
                                  check_var_reference: Callable[[ast.Name], None],
                                  match_lambda_argument_names: Set[str]):
    """Converts an empty_list(<type>) call into an ir2.ListExpr with no elements.

    The single positional argument is interpreted as a type declaration and
    becomes the element type of the resulting list. The in_match_pattern,
    check_var_reference and match_lambda_argument_names parameters are not
    used by this function.

    Raises CompilationError if any keyword argument is passed, or if the number
    of positional arguments is not exactly 1.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value, 'Keyword arguments are not supported.')

    positional_args = ast_node.args
    num_args = len(positional_args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'empty_list() takes 1 argument. Got: %s' % num_args)

    type_decl_node = positional_args[0]
    return ir2.ListExpr(elem_type=type_declaration_ast_to_ir2_expression_type(type_decl_node, compilation_context),
                        elem_exprs=(),
                        list_extraction_expr=None)

def empty_set_literal_ast_to_ir2(ast_node: ast.Call,
                                 compilation_context: CompilationContext,
                                 in_match_pattern: bool,
                                 check_var_reference: Callable[[ast.Name], None],
                                 match_lambda_argument_names: Set[str]):
    """Converts an empty_set(<type>) call into an ir2.SetExpr with no elements.

    The single positional argument is interpreted as a type declaration and
    becomes the element type of the resulting set. The in_match_pattern,
    check_var_reference and match_lambda_argument_names parameters are not
    used by this function.

    Raises CompilationError if any keyword argument is passed, or if the number
    of positional arguments is not exactly 1.
    """
    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value, 'Keyword arguments are not supported.')

    positional_args = ast_node.args
    num_args = len(positional_args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'empty_set() takes 1 argument. Got: %s' % num_args)

    type_decl_node = positional_args[0]
    return ir2.SetExpr(elem_type=type_declaration_ast_to_ir2_expression_type(type_decl_node, compilation_context),
                       elem_exprs=())

def int_iterable_sum_expr_ast_to_ir2(ast_node: ast.Call,
                                     compilation_context: CompilationContext,
                                     in_match_pattern: bool,
                                     check_var_reference: Callable[[ast.Name], None],
                                     match_lambda_argument_names: Set[str],
                                     current_stmt_line: int):
    """Converts a sum(<iterable>) call into an ir2 sum expression.

    Returns ir2.IntListSumExpr when the argument has a list type and
    ir2.IntSetSumExpr when it has a set type; in both cases the element type
    must be int.

    Raises CompilationError if the call appears in a match pattern, uses
    keyword arguments, does not have exactly 1 positional argument, or if the
    argument is not a list/set of ints.
    """
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'sum() is not allowed in match patterns')

    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value, 'Keyword arguments are not supported.')

    num_args = len(ast_node.args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'sum() takes 1 argument. Got: %s' % num_args)

    iterable_node = ast_node.args[0]
    arg_expr = expression_ast_to_ir2(iterable_node,
                                     compilation_context,
                                     in_match_pattern,
                                     check_var_reference,
                                     match_lambda_argument_names,
                                     current_stmt_line)

    # Success path first: a list or a set whose elements are ints.
    arg_type = arg_expr.expr_type
    if isinstance(arg_type, (ir2.ListType, ir2.SetType)) and isinstance(arg_type.elem_type, ir2.IntType):
        if isinstance(arg_type, ir2.ListType):
            return ir2.IntListSumExpr(list_expr=arg_expr)
        return ir2.IntSetSumExpr(set_expr=arg_expr)

    # Wrong argument type. If the argument is a plain variable reference, also
    # point the user at the place where that variable was defined.
    notes = []
    if isinstance(arg_expr, ir2.VarReference):
        definition = compilation_context.get_symbol_definition(arg_expr.name)
        assert definition
        assert not definition.is_only_partially_defined
        notes.append((definition.ast_node, '%s was defined here' % arg_expr.name))
    raise CompilationError(compilation_context, iterable_node,
                           'The argument of sum() must have type List[int] or Set[int]. Got type: %s' % str(arg_type),
                           notes=notes)

def bool_iterable_all_expr_ast_to_ir2(ast_node: ast.Call,
                                      compilation_context: CompilationContext,
                                      in_match_pattern: bool,
                                      check_var_reference: Callable[[ast.Name], None],
                                      current_stmt_line: int):
    """Converts an all(<iterable>) call into an ir2 "all" expression.

    Returns ir2.BoolListAllExpr when the argument has a list type and
    ir2.BoolSetAllExpr when it has a set type; in both cases the element type
    must be bool. The argument is converted with an empty set of match lambda
    argument names.

    Raises CompilationError if the call appears in a match pattern, uses
    keyword arguments, does not have exactly 1 positional argument, or if the
    argument is not a list/set of bools.
    """
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'all() is not allowed in match patterns')

    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value, 'Keyword arguments are not supported.')

    num_args = len(ast_node.args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'all() takes 1 argument. Got: %s' % num_args)

    iterable_node = ast_node.args[0]
    arg_expr = expression_ast_to_ir2(iterable_node,
                                     compilation_context,
                                     in_match_pattern,
                                     check_var_reference,
                                     match_lambda_argument_names=set(),
                                     current_stmt_line=current_stmt_line)

    # Success path first: a list or a set whose elements are bools.
    arg_type = arg_expr.expr_type
    if isinstance(arg_type, (ir2.ListType, ir2.SetType)) and isinstance(arg_type.elem_type, ir2.BoolType):
        if isinstance(arg_type, ir2.ListType):
            return ir2.BoolListAllExpr(list_expr=arg_expr)
        return ir2.BoolSetAllExpr(set_expr=arg_expr)

    # Wrong argument type. If the argument is a plain variable reference, also
    # point the user at the place where that variable was defined.
    notes = []
    if isinstance(arg_expr, ir2.VarReference):
        definition = compilation_context.get_symbol_definition(arg_expr.name)
        assert definition
        assert not definition.is_only_partially_defined
        notes.append((definition.ast_node, '%s was defined here' % arg_expr.name))
    raise CompilationError(compilation_context, iterable_node,
                           'The argument of all() must have type List[bool] or Set[bool]. Got type: %s' % str(arg_type),
                           notes=notes)

def bool_iterable_any_expr_ast_to_ir2(ast_node: ast.Call,
                                      compilation_context: CompilationContext,
                                      in_match_pattern: bool,
                                      check_var_reference: Callable[[ast.Name], None],
                                      current_stmt_line: int):
    """Converts an any(<iterable>) call into an ir2 "any" expression.

    Returns ir2.BoolListAnyExpr when the argument has a list type and
    ir2.BoolSetAnyExpr when it has a set type; in both cases the element type
    must be bool. The argument is converted with an empty set of match lambda
    argument names.

    Raises CompilationError if the call appears in a match pattern, uses
    keyword arguments, does not have exactly 1 positional argument, or if the
    argument is not a list/set of bools.
    """
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'any() is not allowed in match patterns')

    keyword_args = ast_node.keywords
    if keyword_args:
        raise CompilationError(compilation_context, keyword_args[0].value, 'Keyword arguments are not supported.')

    num_args = len(ast_node.args)
    if num_args != 1:
        raise CompilationError(compilation_context, ast_node, 'any() takes 1 argument. Got: %s' % num_args)

    iterable_node = ast_node.args[0]
    arg_expr = expression_ast_to_ir2(iterable_node,
                                     compilation_context,
                                     in_match_pattern,
                                     check_var_reference,
                                     match_lambda_argument_names=set(),
                                     current_stmt_line=current_stmt_line)

    # Success path first: a list or a set whose elements are bools.
    arg_type = arg_expr.expr_type
    if isinstance(arg_type, (ir2.ListType, ir2.SetType)) and isinstance(arg_type.elem_type, ir2.BoolType):
        if isinstance(arg_type, ir2.ListType):
            return ir2.BoolListAnyExpr(list_expr=arg_expr)
        return ir2.BoolSetAnyExpr(set_expr=arg_expr)

    # Wrong argument type. If the argument is a plain variable reference, also
    # point the user at the place where that variable was defined.
    notes = []
    if isinstance(arg_expr, ir2.VarReference):
        definition = compilation_context.get_symbol_definition(arg_expr.name)
        assert definition
        assert not definition.is_only_partially_defined
        notes.append((definition.ast_node, '%s was defined here' % arg_expr.name))
    raise CompilationError(compilation_context, iterable_node,
                           'The argument of any() must have type List[bool] or Set[bool]. Got type: %s' % str(arg_type),
                           notes=notes)

def _is_structural_equality_check_supported_for_type(expr_type: ir2.ExprType):
    """Returns whether values of expr_type can be compared structurally.

    bool, int and Type are supported; functions and sets are not. A list is
    supported iff its element type is, and a custom type is supported iff the
    types of all of its arguments are.

    Raises NotImplementedError for any other kind of type.
    """
    if isinstance(expr_type, (ir2.BoolType, ir2.IntType, ir2.TypeType)):
        return True
    if isinstance(expr_type, ir2.FunctionType):
        return False
    if isinstance(expr_type, ir2.ListType):
        return _is_structural_equality_check_supported_for_type(expr_type.elem_type)
    if isinstance(expr_type, ir2.SetType):
        # Sets nested inside another type are rejected here; a set at the top
        # level is unwrapped by _is_equality_check_supported_for_type() first.
        return False
    if isinstance(expr_type, ir2.CustomType):
        for arg_type in expr_type.arg_types:
            if not _is_structural_equality_check_supported_for_type(arg_type.expr_type):
                return False
        return True
    raise NotImplementedError('Unexpected type: %s' % expr_type.__class__.__name__)

def _is_equality_check_supported_for_type(expr_type: ir2.ExprType):
    """Returns whether values of expr_type can be compared for equality.

    A set at the top level is allowed as long as its element type supports
    structural equality; every other type must support structural equality
    itself.
    """
    type_to_check = expr_type.elem_type if isinstance(expr_type, ir2.SetType) else expr_type
    return _is_structural_equality_check_supported_for_type(type_to_check)

def eq_ast_to_ir2(lhs_node: ast.AST,
                  rhs_node: ast.AST,
                  compilation_context: CompilationContext,
                  in_match_pattern: bool,
                  check_var_reference: Callable[[ast.Name], None],
                  current_stmt_line: int):
    """Converts an `lhs == rhs` comparison into an ir2.EqualityComparison.

    Both operands are converted (left first) with an empty set of match lambda
    argument names. Must not be called for match patterns.

    Raises CompilationError if the two operands have different types, or if
    their type does not support equality comparison.
    """
    assert not in_match_pattern

    operands = []
    for operand_node in (lhs_node, rhs_node):
        operands.append(expression_ast_to_ir2(operand_node,
                                              compilation_context,
                                              in_match_pattern,
                                              check_var_reference,
                                              match_lambda_argument_names=set(),
                                              current_stmt_line=current_stmt_line))
    lhs, rhs = operands

    lhs_type = lhs.expr_type
    rhs_type = rhs.expr_type
    if lhs_type != rhs_type:
        raise CompilationError(compilation_context, lhs_node,
                               'Type mismatch in ==: %s vs %s' % (str(lhs_type), str(rhs_type)))
    if not _is_equality_check_supported_for_type(lhs_type):
        raise CompilationError(compilation_context, lhs_node,
                               'Type not supported in equality comparison: ' + str(lhs_type))
    return ir2.EqualityComparison(lhs=lhs, rhs=rhs)

def not_eq_ast_to_ir2(lhs_node: ast.AST,
                      rhs_node: ast.AST,
                      compilation_context: CompilationContext,
                      in_match_pattern: bool,
                      check_var_reference: Callable[[ast.Name], None],
                      current_stmt_line: int):
    """Converts an `lhs != rhs` comparison into a negated ir2.EqualityComparison.

    Both operands are converted (left first) with an empty set of match lambda
    argument names. The result is an ir2.NotExpr wrapping the equality
    comparison. Must not be called for match patterns.

    Raises CompilationError if the two operands have different types, or if
    their type does not support equality comparison.
    """
    assert not in_match_pattern

    operands = []
    for operand_node in (lhs_node, rhs_node):
        operands.append(expression_ast_to_ir2(operand_node,
                                              compilation_context,
                                              in_match_pattern,
                                              check_var_reference,
                                              match_lambda_argument_names=set(),
                                              current_stmt_line=current_stmt_line))
    lhs, rhs = operands

    lhs_type = lhs.expr_type
    rhs_type = rhs.expr_type
    if lhs_type != rhs_type:
        raise CompilationError(compilation_context, lhs_node,
                               'Type mismatch in !=: %s vs %s' % (str(lhs_type), str(rhs_type)))
    if not _is_equality_check_supported_for_type(lhs_type):
        raise CompilationError(compilation_context, lhs_node,
                               'Type not supported in equality comparison: ' + str(lhs_type))

    equality = ir2.EqualityComparison(lhs=lhs, rhs=rhs)
    return ir2.NotExpr(expr=equality)

def in_ast_to_ir2(lhs_node: ast.AST,
                  rhs_node: ast.AST,
                  compilation_context: CompilationContext,
                  in_match_pattern: bool,
                  check_var_reference: Callable[[ast.Name], None],
                  current_stmt_line: int):
    """Converts an `lhs in rhs` membership test into an ir2.InExpr.

    Both operands are converted (left first) with an empty set of match lambda
    argument names. Must not be called for match patterns.

    Raises CompilationError if the right-hand side is neither a list nor a
    set, if the left-hand side's type differs from the container's element
    type, or if that type does not support equality comparison.
    """
    assert not in_match_pattern

    operands = []
    for operand_node in (lhs_node, rhs_node):
        operands.append(expression_ast_to_ir2(operand_node,
                                              compilation_context,
                                              in_match_pattern,
                                              check_var_reference,
                                              match_lambda_argument_names=set(),
                                              current_stmt_line=current_stmt_line))
    lhs, rhs = operands

    container_type = rhs.expr_type
    if not isinstance(container_type, (ir2.ListType, ir2.SetType)):
        raise CompilationError(compilation_context, rhs_node,
                               'The object on the RHS of "in" must be a list or a set, but found type: ' + str(container_type))
    rhs_elem_type = container_type.elem_type

    lhs_type = lhs.expr_type
    if lhs_type != rhs_elem_type:
        # The message reports the full type of the container, not its element type.
        raise CompilationError(compilation_context, lhs_node,
                               'Type mismatch in in: %s vs %s' % (str(lhs_type), str(container_type)))
    if not _is_equality_check_supported_for_type(lhs_type):
        raise CompilationError(compilation_context, lhs_node,
                               'Type not supported in equality comparison (for the "in" operator): ' + str(lhs_type))
    return ir2.InExpr(lhs=lhs, rhs=rhs)

def _construct_note_diagnostic_for_function_signature(function_lookup_result: SymbolLookupResult):
    """Builds an (ast_node, message) note pointing at the definition of the looked-up symbol."""
    definition_node = function_lookup_result.ast_node
    message = 'The definition of %s was here' % function_lookup_result.symbol.name
    return definition_node, message

def function_call_ast_to_ir2(ast_node: ast.Call,
                             compilation_context: CompilationContext,
                             in_match_pattern: bool,
                             check_var_reference: Callable[[ast.Name], None],
                             match_lambda_argument_names: Set[str],
                             current_stmt_line: int):
    """Converts a function call AST node into an ir2.FunctionCall.

    The callee expression is converted first and must have a function type.
    Arguments can be passed either all positionally or all by keyword:

    * Keyword arguments are only accepted when the callee is a plain variable
      reference whose symbol's function type has argument names. The keyword
      set must match the formal argument names exactly; the resulting
      arguments are reordered to follow the formal argument order.
    * Positional arguments must match the formal arguments in number.

    In both styles every argument's type must be equal to the corresponding
    formal argument type.

    The resulting call is marked as may_throw unless the callee is a variable
    reference that is not flagged as a function that may throw.

    Raises CompilationError if the call appears in a match pattern, the callee
    is not a function, positional and keyword arguments are mixed, or the
    arguments don't match the callee's signature.
    """
    # TODO: allow calls to custom types' constructors.
    if in_match_pattern:
        raise CompilationError(compilation_context, ast_node,
                               'Function calls are not allowed in match patterns')

    fun_expr = expression_ast_to_ir2(ast_node.func,
                                     compilation_context,
                                     in_match_pattern,
                                     check_var_reference,
                                     match_lambda_argument_names,
                                     current_stmt_line)
    if not isinstance(fun_expr.expr_type, ir2.FunctionType):
        raise CompilationError(compilation_context, ast_node,
                               'Attempting to call an object that is not a function. It has type: %s' % str(fun_expr.expr_type))

    if ast_node.keywords and ast_node.args:
        raise CompilationError(compilation_context, ast_node, 'Function calls with a mix of keyword and non-keyword arguments are not supported. Please choose either style.')

    if ast_node.keywords:
        if not isinstance(fun_expr, ir2.VarReference):
            raise CompilationError(compilation_context, ast_node.keywords[0].value,
                                   'Keyword arguments can only be used when calling a specific function or constructing a specific type, not when calling other callable objects. Please switch to non-keyword arguments.')
        lookup_result = compilation_context.get_symbol_definition(fun_expr.name)
        assert lookup_result
        assert not lookup_result.is_only_partially_defined
        if not lookup_result.symbol.expr_type.argnames:
            raise CompilationError(compilation_context, ast_node.keywords[0].value,
                                   'Keyword arguments can only be used when calling a specific function or constructing a specific type, not when calling other callable objects. Please switch to non-keyword arguments.')

        arg_expr_by_name = {keyword_arg.arg: expression_ast_to_ir2(keyword_arg.value,
                                                                   compilation_context,
                                                                   in_match_pattern,
                                                                   check_var_reference,
                                                                   match_lambda_argument_names,
                                                                   current_stmt_line)
                            for keyword_arg in ast_node.keywords}
        # Keyword AST nodes indexed by name, so that diagnostics for a formal
        # argument can point at the keyword that actually supplied it.
        keyword_arg_by_name = {keyword_arg.arg: keyword_arg
                               for keyword_arg in ast_node.keywords}

        formal_arg_names = lookup_result.symbol.expr_type.argnames
        formal_arg_name_set = set(formal_arg_names)
        specified_nonexisting_args = arg_expr_by_name.keys() - formal_arg_name_set
        missing_args = formal_arg_name_set - arg_expr_by_name.keys()
        if specified_nonexisting_args and missing_args:
            raise CompilationError(compilation_context, ast_node,
                                   'Incorrect arguments in call to %s. Missing arguments: {%s}. Specified arguments that don\'t exist: {%s}' % (
                                       fun_expr.name, ', '.join(sorted(missing_args)), ', '.join(sorted(specified_nonexisting_args))),
                                   notes=[_construct_note_diagnostic_for_function_signature(lookup_result)])
        elif specified_nonexisting_args:
            raise CompilationError(compilation_context, ast_node,
                                   'Incorrect arguments in call to %s. Specified arguments that don\'t exist: {%s}' % (
                                       fun_expr.name, ', '.join(sorted(specified_nonexisting_args))),
                                   notes=[_construct_note_diagnostic_for_function_signature(lookup_result)])
        elif missing_args:
            raise CompilationError(compilation_context, ast_node,
                                   'Incorrect arguments in call to %s. Missing arguments: {%s}' % (
                                       fun_expr.name, ', '.join(sorted(missing_args))),
                                   notes=[_construct_note_diagnostic_for_function_signature(lookup_result)])

        # The arguments of the resulting call follow the formal argument order,
        # regardless of the order in which the keywords were written.
        args = tuple(arg_expr_by_name[arg_name]
                     for arg_name in formal_arg_names)

        # `args` is in formal-argument order while ast_node.keywords is in
        # call-site order, so the keyword node must be looked up by name
        # (pairing them positionally would blame the wrong keyword whenever the
        # caller wrote the keywords in a different order than the declaration).
        for arg_name, expr, arg_type in zip(formal_arg_names, args, fun_expr.expr_type.argtypes):
            if expr.expr_type != arg_type:
                keyword_arg = keyword_arg_by_name[arg_name]
                notes = [_construct_note_diagnostic_for_function_signature(lookup_result)]
                if isinstance(keyword_arg.value, ast.Name):
                    arg_lookup_result = compilation_context.get_symbol_definition(keyword_arg.value.id)
                    assert not arg_lookup_result.is_only_partially_defined
                    notes.append((arg_lookup_result.ast_node, 'The definition of %s was here' % keyword_arg.value.id))

                raise CompilationError(compilation_context, keyword_arg.value,
                                       'Type mismatch for argument %s: expected type %s but was: %s' % (
                                           keyword_arg.arg, str(arg_type), str(expr.expr_type)),
                                       notes=notes)
    else:
        ast_node_args = ast_node.args or []
        args = tuple(expression_ast_to_ir2(arg_node,
                                           compilation_context,
                                           in_match_pattern,
                                           check_var_reference,
                                           match_lambda_argument_names,
                                           current_stmt_line)
                     for arg_node in ast_node_args)
        if len(args) != len(fun_expr.expr_type.argtypes):
            if isinstance(ast_node.func, ast.Name):
                # The callee is named directly, so the error can mention its
                # name and point at its definition.
                lookup_result = compilation_context.get_symbol_definition(ast_node.func.id)
                assert lookup_result
                assert not lookup_result.is_only_partially_defined
                raise CompilationError(compilation_context, ast_node,
                                       'Argument number mismatch in function call to %s: got %s arguments, expected %s' % (
                                           ast_node.func.id, len(args), len(fun_expr.expr_type.argtypes)),
                                       notes=[_construct_note_diagnostic_for_function_signature(lookup_result)])
            else:
                raise CompilationError(compilation_context, ast_node,
                                       'Argument number mismatch in function call: got %s arguments, expected %s' % (
                                           len(args), len(fun_expr.expr_type.argtypes)))

        for arg_index, (expr, expr_ast_node, arg_type) in enumerate(zip(args, ast_node_args, fun_expr.expr_type.argtypes)):
            if expr.expr_type != arg_type:
                notes = []

                if isinstance(ast_node.func, ast.Name):
                    lookup_result = compilation_context.get_symbol_definition(ast_node.func.id)
                    assert lookup_result
                    notes.append(_construct_note_diagnostic_for_function_signature(lookup_result))

                if isinstance(expr_ast_node, ast.Name):
                    lookup_result = compilation_context.get_symbol_definition(expr_ast_node.id)
                    assert lookup_result
                    notes.append((lookup_result.ast_node, 'The definition of %s was here' % expr_ast_node.id))

                raise CompilationError(compilation_context, expr_ast_node,
                                       'Type mismatch for argument %s: expected type %s but was: %s' % (
                                           arg_index, str(arg_type), str(expr.expr_type)),
                                       notes=notes)

    return ir2.FunctionCall(fun_expr=fun_expr,
                            args=args,
                            may_throw=(not isinstance(fun_expr, ir2.VarReference)
                                       or fun_expr.is_function_that_may_throw))

def var_reference_ast_to_ir2(ast_node: ast.Name,
                             compilation_context: CompilationContext,
                             in_match_pattern: bool,
                             check_var_reference: Callable[[ast.Name], None],
                             match_lambda_argument_names: Set[str]):
    """Converts a name in load context into an ir2.VarReference.

    check_var_reference is invoked on the node before the symbol is looked up.
    Inside a match pattern, a name that is one of the match lambda arguments
    and has no symbol yet gets defined here with type Type.

    Raises CompilationError if:
    * the name is a match lambda argument that was already bound to something
      other than a Type;
    * the symbol is only partially defined;
    * the name refers to a function that is only partially defined so far;
    * the name is not defined at all.
    """
    assert isinstance(ast_node.ctx, ast.Load)
    check_var_reference(ast_node)

    var_name = ast_node.id
    lookup_result = compilation_context.get_symbol_definition(var_name)

    # In match patterns, variables get defined at the first point of use, either here or in list_expression_ast_to_ir2().
    if in_match_pattern and var_name in match_lambda_argument_names:
        if not lookup_result:
            compilation_context.add_symbol(name=var_name,
                                           expr_type=ir2.TypeType(),
                                           definition_ast_node=ast_node,
                                           is_only_partially_defined=False,
                                           is_function_that_may_throw=False)
            lookup_result = compilation_context.get_symbol_definition(var_name)
        elif lookup_result.symbol.expr_type != ir2.TypeType():
            raise CompilationError(compilation_context, ast_node,
                                   'Can\'t match %s as a Type because it was already used to match a List[Type]' % var_name,
                                   notes=[(lookup_result.ast_node, 'A previous match as a List[Type] was here')])

    if not lookup_result:
        # No symbol with this name; tailor the error to whether this is a
        # function whose definition is still being processed.
        definition_ast_node = compilation_context.get_partial_function_definition(var_name)
        if not definition_ast_node:
            raise CompilationError(compilation_context, ast_node, 'Reference to undefined variable/function')
        definition_note = (definition_ast_node, '%s was defined here' % var_name)
        if compilation_context.current_function_name == var_name:
            raise CompilationError(compilation_context, ast_node, 'Recursive function references are only allowed if the return type is declared explicitly.',
                                   notes=[definition_note])
        raise CompilationError(compilation_context, ast_node, 'Reference to a function whose return type hasn\'t been determined yet. Please add a return type declaration in %s or move its declaration before its use.' % var_name,
                               notes=[definition_note])

    if lookup_result.is_only_partially_defined:
        raise CompilationError(compilation_context, ast_node,
                               'Reference to a variable that may or may not have been initialized (depending on which branch was taken)',
                               notes=[(lookup_result.ast_node, '%s might have been initialized here' % var_name)])

    symbol = lookup_result.symbol
    may_throw = isinstance(symbol.expr_type, ir2.FunctionType) and symbol.is_function_that_may_throw
    return ir2.VarReference(expr_type=symbol.expr_type,
                            name=symbol.name,
                            # True when the symbol was found in a symbol table that has no parent.
                            is_global_function=lookup_result.symbol_table.parent is None,
                            is_function_that_may_throw=may_throw,
                            source_module=symbol.source_module)

def list_expression_ast_to_ir2(ast_node: ast.List,
                               compilation_context: CompilationContext,
                               in_match_pattern: bool,
                               check_var_reference: Callable[[ast.Name], None],
                               match_lambda_argument_names: Set[str],
                               current_stmt_line: int):
    """Converts a list literal into an ir2.ListExpr.

    Inside match patterns, a trailing starred element (e.g. [..., *Ts]) is a
    list extraction: it must be a plain identifier that is one of the match
    lambda arguments, and it is bound to a List[Type] symbol (which is defined
    here if it doesn't exist yet). All other elements are converted as regular
    expressions and must share the same type.

    Raises CompilationError if a list extraction is malformed or conflicts
    with a previous binding of the same name, if the list is empty and has no
    list extraction, if the elements have different types, or if the element
    type is a function type.
    """
    elem_exprs = []
    list_extraction_expr = None
    last_index = len(ast_node.elts) - 1
    for index, elem_expr_node in enumerate(ast_node.elts):
        if not (isinstance(elem_expr_node, ast.Starred) and in_match_pattern):
            elem_exprs.append(expression_ast_to_ir2(elem_expr_node,
                                                    compilation_context,
                                                    in_match_pattern,
                                                    check_var_reference,
                                                    match_lambda_argument_names,
                                                    current_stmt_line))
            continue

        # [..., *Ts]
        extraction_target = elem_expr_node.value
        if not isinstance(extraction_target, ast.Name):
            raise CompilationError(compilation_context, elem_expr_node,
                                   'List extraction is only allowed with an identifier, e.g. [*Ts]')

        if index != last_index:
            raise CompilationError(compilation_context, elem_expr_node,
                                   'List extraction is only allowed at the end of the list')

        list_extraction_var_name = extraction_target.id
        if list_extraction_var_name not in match_lambda_argument_names:
            raise CompilationError(compilation_context, extraction_target,
                                   'List extraction can only be used with type variables that are lambda arguments of this match()')

        # The variable is defined at its first point of use; later uses must
        # agree on the List[Type] type.
        existing_symbol = compilation_context.get_symbol_definition(list_extraction_var_name)
        if not existing_symbol:
            compilation_context.add_symbol(name=list_extraction_var_name,
                                           expr_type=ir2.ListType(ir2.TypeType()),
                                           definition_ast_node=extraction_target,
                                           is_only_partially_defined=False,
                                           is_function_that_may_throw=False)
        elif existing_symbol.symbol.expr_type != ir2.ListType(ir2.TypeType()):
            raise CompilationError(compilation_context, extraction_target,
                                   'List extraction can\'t be used on %s because it was already used to match a Type' % list_extraction_var_name,
                                   notes=[(existing_symbol.ast_node, 'A previous match as a Type was here')])

        list_extraction_expr = ir2.VarReference(expr_type=ir2.ListType(ir2.TypeType()),
                                                name=list_extraction_var_name,
                                                is_global_function=False,
                                                is_function_that_may_throw=False)

    # The element type comes from the first regular element if there is one,
    # otherwise from the list extraction.
    if elem_exprs:
        elem_type = elem_exprs[0].expr_type
    elif list_extraction_expr:
        assert isinstance(list_extraction_expr.expr_type, ir2.ListType)
        elem_type = list_extraction_expr.expr_type.elem_type
    else:
        raise CompilationError(compilation_context, ast_node, 'Untyped empty lists are not supported. Please import empty_list from pytmp and then write e.g. empty_list(int) to create an empty list of ints.')

    for converted_elem, converted_elem_node in zip(elem_exprs, ast_node.elts):
        if converted_elem.expr_type != elem_type:
            raise CompilationError(compilation_context, converted_elem_node,
                                   'Found different types in list elements, this is not supported. The type of this element was %s instead of %s' % (
                                       str(converted_elem.expr_type), str(elem_type)),
                                   notes=[(ast_node.elts[0], 'A previous list element with type %s was here.' % str(elem_type))])

    if isinstance(elem_type, ir2.FunctionType):
        raise CompilationError(compilation_context, ast_node,
                               'Creating lists of functions is not supported. The elements of this list have type: %s' % str(elem_type))

    return ir2.ListExpr(elem_type=elem_type,
                        elem_exprs=tuple(elem_exprs),
                        list_extraction_expr=list_extraction_expr)

def set_expression_ast_to_ir2(ast_node: ast.Set,
                              compilation_context: CompilationContext,
                              in_match_pattern: bool,
                              check_var_reference: Callable[[ast.Name], None],
                              match_lambda_argument_names: Set[str],
                              current_stmt_line: int):
    """Converts a set literal into an ir2.SetExpr.

    Every element is converted as a regular expression; the element type of
    the set is the type of the first element, and all other elements must have
    that same type.

    Raises CompilationError if the elements have different types, or if the
    element type is a function type.
    """
    converted_elems = []
    for elem_expr_node in ast_node.elts:
        converted_elems.append(expression_ast_to_ir2(elem_expr_node,
                                                     compilation_context,
                                                     in_match_pattern,
                                                     check_var_reference,
                                                     match_lambda_argument_names,
                                                     current_stmt_line))
    elem_exprs = tuple(converted_elems)
    assert elem_exprs

    elem_type = elem_exprs[0].expr_type
    for converted_elem, converted_elem_node in zip(elem_exprs, ast_node.elts):
        if converted_elem.expr_type != elem_type:
            raise CompilationError(compilation_context, converted_elem_node,
                                   'Found different types in set elements, this is not supported. The type of this element was %s instead of %s' % (
                                       str(converted_elem.expr_type), str(elem_type)),
                                   notes=[(ast_node.elts[0], 'A previous set element with type %s was here.' % str(elem_type))])

    if isinstance(elem_type, ir2.FunctionType):
        raise CompilationError(compilation_context, ast_node,
                               'Creating sets of functions is not supported. The elements of this set have type: %s' % str(elem_type))

    return ir2.SetExpr(elem_type=elem_type, elem_exprs=elem_exprs)

def type_declaration_ast_to_ir2_expression_type(ast_node: ast.AST, compilation_context: CompilationContext):
    """
    Converts the AST of a type annotation into the corresponding IR2 expression type.

    Supported declarations:
    * `bool`, `int` and `Type`;
    * any other plain name, if compilation_context.get_type_symbol_definition()
      finds a definition for it (the symbol's expr_type is returned);
    * `List[T]` and `Set[T]`, where T is itself a supported declaration;
    * `Callable[[A1, ..., An], R]`, where each Ai is a plain name and R is a
      supported declaration.

    Raises CompilationError for an unknown type name and for any other form of
    declaration.
    """
    if isinstance(ast_node, ast.Name) and isinstance(ast_node.ctx, ast.Load):
        if ast_node.id == 'bool':
            return ir2.BoolType()
        elif ast_node.id == 'int':
            return ir2.IntType()
        elif ast_node.id == 'Type':
            return ir2.TypeType()
        else:
            lookup_result = compilation_context.get_type_symbol_definition(ast_node.id)
            if lookup_result:
                return lookup_result.symbol.expr_type
            else:
                raise CompilationError(compilation_context, ast_node, 'Unsupported (or undefined) type: ' + ast_node.id)

    if (isinstance(ast_node, ast.Subscript)
        and isinstance(ast_node.value, ast.Name)
        and isinstance(ast_node.value.ctx, ast.Load)
        and isinstance(ast_node.ctx, ast.Load)):
        # Up to Python 3.8 the parser wraps a simple subscript in an ast.Index
        # node (while ast.Slice/ast.ExtSlice are not expressions). Starting
        # with Python 3.9 the subscript expression is stored directly in
        # `slice`. Accept both shapes, so that the same declarations are
        # supported regardless of the Python version.
        subscript = ast_node.slice
        if isinstance(subscript, ast.expr):
            index_value = subscript
        elif isinstance(subscript, getattr(ast, 'Index', ())):
            index_value = subscript.value
        else:
            # A Slice/ExtSlice node from Python <=3.8: not a type declaration.
            index_value = None

        if index_value is not None:
            if ast_node.value.id == 'List':
                return ir2.ListType(type_declaration_ast_to_ir2_expression_type(index_value, compilation_context))
            elif ast_node.value.id == 'Set':
                return ir2.SetType(type_declaration_ast_to_ir2_expression_type(index_value, compilation_context))
            elif (ast_node.value.id == 'Callable'
                  and isinstance(index_value, ast.Tuple)
                  and len(index_value.elts) == 2
                  and isinstance(index_value.elts[0], ast.List)
                  and isinstance(index_value.elts[0].ctx, ast.Load)
                  and all(isinstance(elem, ast.Name) and isinstance(elem.ctx, ast.Load)
                          for elem in index_value.elts[0].elts)):
                return ir2.FunctionType(
                    argtypes=tuple(type_declaration_ast_to_ir2_expression_type(arg_type_decl, compilation_context)
                                   for arg_type_decl in index_value.elts[0].elts),
                    argnames=None,
                    returns=type_declaration_ast_to_ir2_expression_type(index_value.elts[1], compilation_context))

    raise CompilationError(compilation_context, ast_node, 'Unsupported type declaration.')

# Returns True only for a plain (single-target, no type comment) assignment
# to an attribute of `self`, i.e. a statement of the form:
# self.some_field = <expr>
def _is_class_field_initialization(ast_node: ast.AST):
    if not isinstance(ast_node, ast.Assign):
        return False
    if ast_node.type_comment or len(ast_node.targets) != 1:
        return False

    target = ast_node.targets[0]
    if not isinstance(target, ast.Attribute) or not isinstance(target.ctx, ast.Store):
        return False

    # The object whose attribute is assigned must be exactly the name `self`.
    receiver = target.value
    if not isinstance(receiver, ast.Name):
        return False
    return receiver.id == 'self' and isinstance(receiver.ctx, ast.Load)

def class_definition_ast_to_ir2(ast_node: ast.ClassDef, compilation_context: CompilationContext, next_stmt_line: int):
    """
    Converts a class definition into IR2.

    Only two kinds of classes are accepted:
    * classes whose single base class is `Exception` (and that have no
      decorator), handled by custom_exception_class_to_ir2();
    * classes without base classes whose single decorator is `@dataclass`,
      handled by custom_dataclass_definition_to_ir2().

    Raises CompilationError for any other combination of bases/decorators.
    """
    base_nodes = ast_node.bases
    if len(base_nodes) > 1:
        raise CompilationError(compilation_context, base_nodes[1],
                               'Multiple base classes are not supported.')
    # At this point there is at most 1 base class.
    for base in base_nodes:
        if (not isinstance(base, ast.Name)
                or not isinstance(base.ctx, ast.Load)
                or base.id != 'Exception'):
            raise CompilationError(compilation_context, base,
                                   '"Exception" is the only supported base class.')
    inherits_from_exception = bool(base_nodes)

    decorator_nodes = ast_node.decorator_list
    if len(decorator_nodes) > 1:
        raise CompilationError(compilation_context, decorator_nodes[1],
                               'Classes with multiple decorators are not supported.')
    # At this point there is at most 1 decorator.
    for decorator in decorator_nodes:
        if (not isinstance(decorator, ast.Name)
                or not isinstance(decorator.ctx, ast.Load)
                or decorator.id != 'dataclass'):
            raise CompilationError(compilation_context, decorator,
                                   '"@dataclass" is the only supported class decorator.')
    has_dataclass_decorator = bool(decorator_nodes)

    if inherits_from_exception:
        if has_dataclass_decorator:
            raise CompilationError(compilation_context, ast_node,
                                   'Custom Exception classes should not have the @dataclass decorator.')
        return custom_exception_class_to_ir2(ast_node, compilation_context, next_stmt_line)

    if not has_dataclass_decorator:
        raise CompilationError(compilation_context, ast_node,
                               'Custom classes must either inherit from Exception or be decorated with @dataclass.')
    return custom_dataclass_definition_to_ir2(ast_node, compilation_context, next_stmt_line)

def custom_exception_class_to_ir2(ast_node: ast.ClassDef,
                                  compilation_context: CompilationContext,
                                  next_stmt_line: int):
    """
    Converts the definition of a custom exception class into IR2.

    The class body must consist of a single __init__ method, shaped like:

        def __init__(self, x: T1, y: T2):
            self.message = '...'
            self.x = x
            self.y = y

    That is: `self` (unannotated) followed by at least 1 annotated positional
    argument without default; a first statement assigning a string literal to
    `self.message`; then one `self.<arg> = <arg>` assignment per argument
    (each argument assigned exactly once, to the field with the same name).

    Returns a pair (custom_type_ir, pass_stmts): the ir2.CustomType for the
    class and a tuple of ir2.PassStmt carrying the source branches of the
    class definition.

    Raises CompilationError if the class does not have the expected shape.
    """
    if ast_node.keywords:
        raise CompilationError(compilation_context, ast_node.keywords[0].value,
                               'Keyword class arguments are not supported.')

    if len(ast_node.body) != 1 or not isinstance(ast_node.body[0], ast.FunctionDef) or ast_node.body[0].name != '__init__':
        raise CompilationError(compilation_context, ast_node,
                               'Custom classes must contain an __init__ method (and nothing else).')

    [init_defn_ast_node] = ast_node.body

    init_args_ast_node = init_defn_ast_node.args

    # Only plain positional arguments without defaults are supported.
    if init_args_ast_node.vararg:
        raise CompilationError(compilation_context, init_args_ast_node.vararg,
                               'Vararg arguments are not supported in __init__.')
    if init_args_ast_node.kwonlyargs:
        raise CompilationError(compilation_context, init_args_ast_node.kwonlyargs[0],
                               'Keyword-only arguments are not supported in __init__.')
    if init_args_ast_node.kw_defaults or init_args_ast_node.defaults:
        raise CompilationError(compilation_context, init_defn_ast_node,
                               'Default arguments are not supported in __init__.')
    if init_args_ast_node.kwarg:
        raise CompilationError(compilation_context, init_defn_ast_node,
                               'Keyword arguments are not supported in __init__.')

    init_args_ast_nodes = init_args_ast_node.args

    if not init_args_ast_nodes or init_args_ast_nodes[0].arg != 'self':
        raise CompilationError(compilation_context, init_defn_ast_node,
                               'Expected "self" as first argument of __init__.')

    if init_args_ast_nodes[0].annotation:
        raise CompilationError(compilation_context, init_args_ast_nodes[0].annotation,
                               'Type annotations on the "self" argument are not supported.')

    for arg in init_args_ast_nodes:
        if arg.type_comment:
            raise CompilationError(compilation_context, arg,
                                   'Type comments on arguments are not supported.')

    # From here on, only consider the arguments after "self".
    init_args_ast_nodes = init_args_ast_nodes[1:]

    if not init_args_ast_nodes:
        raise CompilationError(compilation_context, init_defn_ast_node,
                               'Custom types must have at least 1 constructor argument (and field).')

    arg_decl_nodes_by_name = dict()
    arg_types = []
    for arg in init_args_ast_nodes:
        # Check the annotation of *this* argument (not just the first one), so
        # that an unannotated argument in any position is reported here
        # instead of reaching the type declaration conversion below.
        if not arg.annotation:
            raise CompilationError(compilation_context, arg,
                                   'All arguments of __init__ (except "self") must have a type annotation.')
        if arg.arg in arg_decl_nodes_by_name:
            previous_arg_node = arg_decl_nodes_by_name[arg.arg]
            raise CompilationError(compilation_context, arg,
                                   'Found multiple arguments with name "%s".' % arg.arg,
                                   notes=[(previous_arg_node, 'A previous argument with name "%s" was declared here.' % arg.arg)])

        if arg.arg == 'type':
            raise CompilationError(compilation_context, arg,
                                   'Arguments of a custom type cannot be called "type", it\'s a reserved identifier')

        arg_decl_nodes_by_name[arg.arg] = arg
        arg_types.append(ir2.CustomTypeArgDecl(name=arg.arg,
                                               expr_type=type_declaration_ast_to_ir2_expression_type(arg.annotation, compilation_context)))

    init_body_ast_nodes = init_defn_ast_node.body

    # The first statement must set the exception message to a string literal.
    first_stmt = init_body_ast_nodes[0]
    if not (_is_class_field_initialization(first_stmt)
            and isinstance(first_stmt.value, ast.Str)
            and first_stmt.targets[0].attr == 'message'):
        raise CompilationError(compilation_context, first_stmt,
                               'Unexpected statement. The first statement in the constructor of an exception class must be of the form: self.message = \'...\'.')
    exception_message = first_stmt.value.s
    init_body_ast_nodes = init_body_ast_nodes[1:]

    # All the remaining statements must be of the form `self.<arg> = <arg>`.
    arg_assign_nodes_by_name = dict()
    for stmt_ast_node in init_body_ast_nodes:
        if not (_is_class_field_initialization(stmt_ast_node)
                and isinstance(stmt_ast_node.value, ast.Name)
                and isinstance(stmt_ast_node.value.ctx, ast.Load)):
            raise CompilationError(compilation_context, stmt_ast_node,
                                   'Unexpected statement. All statements in __init__ methods must be of the form "self.some_var = some_var".')

        if stmt_ast_node.value.id in arg_assign_nodes_by_name:
            previous_assign_node = arg_assign_nodes_by_name[stmt_ast_node.value.id]
            raise CompilationError(compilation_context, stmt_ast_node,
                                   'Found multiple assignments to the field "%s".' % stmt_ast_node.value.id,
                                   notes=[(previous_assign_node, 'A previous assignment to "self.%s" was here.' % stmt_ast_node.value.id)])

        if stmt_ast_node.targets[0].attr != stmt_ast_node.value.id:
            raise CompilationError(compilation_context, stmt_ast_node,
                                   '__init__ arguments must be assigned to a field of the same name, but "%s" was assigned to "%s".' % (stmt_ast_node.value.id, stmt_ast_node.targets[0].attr))

        if stmt_ast_node.value.id not in arg_decl_nodes_by_name:
            raise CompilationError(compilation_context, stmt_ast_node,
                                   'Unsupported assignment. All assigments in __init__ methods must assign a parameter to a field with the same name.')

        arg_assign_nodes_by_name[stmt_ast_node.value.id] = stmt_ast_node

    for arg_name, decl_ast_node in arg_decl_nodes_by_name.items():
        if arg_name not in arg_assign_nodes_by_name:
            raise CompilationError(compilation_context, decl_ast_node,
                                   'All __init__ arguments must be assigned to fields, but "%s" was never assigned.' % arg_name)

    # Source branches as (source line, destination line) pairs; a negative
    # line number is used for the entry to/exit from the scope defined at
    # that line.
    definition_branches = (
        (-ast_node.lineno, ast_node.lineno),
        (ast_node.lineno, init_defn_ast_node.lineno),
        (ast_node.lineno, next_stmt_line),
        (init_defn_ast_node.lineno, -ast_node.lineno),
    )
    # Entry into __init__, then each statement to the next one, then the exit
    # from __init__ after the last statement.
    constructor_branches = [(-init_defn_ast_node.lineno, init_defn_ast_node.body[0].lineno)]
    for index, stmt_ast_node in enumerate(init_defn_ast_node.body):
        if index + 1 < len(init_defn_ast_node.body):
            constructor_branches.append((stmt_ast_node.lineno, init_defn_ast_node.body[index + 1].lineno))
        else:
            constructor_branches.append((stmt_ast_node.lineno, -init_defn_ast_node.lineno))

    custom_type_ir = ir2.CustomType(name=ast_node.name,
                                    arg_types=tuple(arg_types),
                                    is_exception_class=True,
                                    exception_message=exception_message,
                                    constructor_source_branches=tuple(
                                        SourceBranch(compilation_context.filename, start, end)
                                        for start, end in constructor_branches))

    pass_stmts = tuple(ir2.PassStmt(SourceBranch(compilation_context.filename, start, end))
                       for start, end in definition_branches)

    return custom_type_ir, pass_stmts

def custom_dataclass_definition_to_ir2(ast_node: ast.ClassDef,
                                       compilation_context: CompilationContext,
                                       next_stmt_line: int):
    """
    Converts the definition of a @dataclass-decorated class into IR2.

    The class body must consist only of annotated field declarations without
    a default value (e.g. `x: bool`), with distinct names, none of which is
    called "type".

    Returns a pair (custom_type_ir, pass_stmts): the ir2.CustomType for the
    class and a tuple of ir2.PassStmt carrying the source branches of the
    class definition.

    Raises CompilationError if the class does not have the expected shape.
    """
    if ast_node.keywords:
        raise CompilationError(compilation_context, ast_node.keywords[0].value,
                               'Keyword class arguments are not supported.')

    field_nodes = ast_node.body
    if not field_nodes:
        raise CompilationError(compilation_context, ast_node,
                               'Dataclasses must have at least 1 field.')

    seen_field_nodes: Dict[str, ast.AnnAssign] = {}
    field_decls = []
    for field_node in field_nodes:
        is_typed_field = isinstance(field_node, ast.AnnAssign) and isinstance(field_node.target, ast.Name)
        if not is_typed_field:
            raise CompilationError(compilation_context, field_node,
                                   'Dataclasses can contain only typed field assignments (and no other statements).')
        if field_node.value:
            raise CompilationError(compilation_context, field_node,
                                   'Dataclass field defaults are not supported.')

        name = field_node.target.id

        earlier_field_node = seen_field_nodes.get(name)
        if earlier_field_node is not None:
            raise CompilationError(compilation_context, field_node,
                                   'Found multiple dataclass fields with name "%s".' % name,
                                   notes=[(earlier_field_node, 'A previous field with name "%s" was declared here.' % name)])

        if name == 'type':
            raise CompilationError(compilation_context, field_node,
                                   'Dataclass fields cannot be called "type", it\'s a reserved identifier')

        seen_field_nodes[name] = field_node
        field_expr_type = type_declaration_ast_to_ir2_expression_type(field_node.annotation, compilation_context)
        field_decls.append(ir2.CustomTypeArgDecl(name=name, expr_type=field_expr_type))

    custom_type_ir = ir2.CustomType(name=ast_node.name,
                                    arg_types=tuple(field_decls),
                                    is_exception_class=False,
                                    exception_message=None,
                                    constructor_source_branches=())

    # Source branches of the definition, as (source line, destination line)
    # pairs. As an example, for this code:
    #
    # 1: from dataclasses import dataclass
    # 2: @dataclass
    # 3: class MyType:
    # 4:     x: bool
    # 5:     y: int
    # 6: assert MyType(True, 15).x
    #
    # (with next_stmt_line == 6) the branches below are, in order:
    # (-2, 2), (2, 3), (3, 6), (2, 4), (4, 5), (5, -2).
    # Note that the branches into/out of the class body are relative to the
    # line of the decorator, not to the line of the `class` keyword.
    decorator_line = ast_node.decorator_list[0].lineno
    class_line = ast_node.lineno
    field_lines = [field_node.lineno for field_node in field_nodes]

    branches = [
        (-decorator_line, decorator_line),
        (decorator_line, class_line),
        (class_line, next_stmt_line),
        (decorator_line, field_lines[0]),
    ]
    # Each field "jumps" to the following one...
    branches.extend(zip(field_lines, field_lines[1:]))
    # ...and the last one exits the class body.
    branches.append((field_lines[-1], -decorator_line))

    pass_stmts = tuple(ir2.PassStmt(SourceBranch(compilation_context.filename, source_line, dest_line))
                       for source_line, dest_line in branches)

    return custom_type_ir, pass_stmts