from typing import Iterable

import idaapi
import idautils
import idc
from .base import get_func, demangle
from ..core import set_name, get_ea, fix_addresses, is_same_function, add_func
from .line import Line
from .xref import Xref
from ..ui import updates_ui
from .. import exceptions

class Comments(object):
    """IDA Function Comments

    Provides easy access to all types of comments for an IDA Function.

    Both comment kinds are exposed as read/write properties backed by
    `idaapi.get_func_cmt` / `idaapi.set_func_cmt`.
    """

    def __init__(self, function):
        # `function` is the owning wrapper object; its `_func` attribute is
        # what gets handed to the idaapi comment calls below.
        self._function = function

    def __bool__(self):
        """True if either the regular or the repeatable comment is truthy."""
        return any((self.regular, self.repeat,))

    @property
    def regular(self):
        """Function Comment"""
        return idaapi.get_func_cmt(self._function._func, False)

    @regular.setter
    def regular(self, comment):
        idaapi.set_func_cmt(self._function._func, comment, False)

    @property
    def repeat(self):
        """Repeatable Function Comment"""
        return idaapi.get_func_cmt(self._function._func, True)

    @repeat.setter
    def repeat(self, comment):
        idaapi.set_func_cmt(self._function._func, comment, True)

    def __repr__(self):
        return ("Comments("
                "func={name},"
                " regular={regular},"
                " repeat={repeat})").format(
            name=self._function.name,
            regular=repr(self.regular),
            repeat=repr(self.repeat))

class FunctionFlagsMixin(object):
    """ Mixin to add convenience checks for the function flags provided by IDA.

    Each check tests one `idaapi.FUNC_*` bit against `self.flags` and returns
    a plain `bool`. The host class is expected to provide `flags`.

    IDA SDK documentation for the flags is found under the `FUNC_*`
    constants of the function-flags group.
    """
    # Placeholder only; the class mixing this in must supply the real value.
    flags = None

    @property
    def is_noret(self):
        """ Function doesn't return """
        return bool(self.flags & idaapi.FUNC_NORET)  # 0x00000001

    @property
    def is_far(self):
        """ Is a far function. """
        return bool(self.flags & idaapi.FUNC_FAR)  # 0x00000002

    @property
    def is_library(self):
        """ Is a library function. """
        return bool(self.flags & idaapi.FUNC_LIB)  # 0x00000004

    @property
    def is_static(self):
        """ Is a static function. """
        return bool(self.flags & idaapi.FUNC_STATICDEF)  # 0x00000008

    @property
    def is_frame(self):
        """ Function uses frame pointer (BP) """
        return bool(self.flags & idaapi.FUNC_FRAME)  # 0x00000010

    @property
    def is_user_far(self):
        """ User has specified far-ness of the function. """
        return bool(self.flags & idaapi.FUNC_USERFAR)  # 0x00000020

    @property
    def is_hidden(self):
        """ A hidden function chunk. """
        return bool(self.flags & idaapi.FUNC_HIDDEN)  # 0x00000040

    @property
    def is_thunk(self):
        """ Thunk (jump) function. """
        return bool(self.flags & idaapi.FUNC_THUNK)  # 0x00000080

    @property
    def is_bottom_bp(self):
        """ BP points to the bottom of the stack frame. """
        return bool(self.flags & idaapi.FUNC_BOTTOMBP)  # 0x00000100

    @property
    def is_noret_pending(self):
        """ Function 'non-return' analysis must be performed. """
        return bool(self.flags & idaapi.FUNC_NORET_PENDING)  # 0x00200

    @property
    def is_sp_ready(self):
        """ SP-analysis has been performed. """
        return bool(self.flags & idaapi.FUNC_SP_READY)  # 0x00000400

    @property
    def is_purged_ok(self):
        """ 'argsize' field has been validated. """
        return bool(self.flags & idaapi.FUNC_PURGED_OK)  # 0x00004000

    @property
    def is_tail(self):
        """ This is a function tail. """
        return bool(self.flags & idaapi.FUNC_TAIL)  # 0x00008000

class Function(FunctionFlagsMixin):
    """IDA Function

    Provides easy access to function related APIs in IDA.
    """

    class UseCurrentAddress(object):
        """
        This is a filler object to replace `None` for the EA.
        In many cases, a programmer can accidentally initialize the
        `Function` object with `ea=None`, resulting in the current address.
        Usually, this is not the desired outcome. This object resolves this issue.
        """

    def __init__(self, ea=UseCurrentAddress, name=None):
        """Wrap the function at `ea`, or the function called `name`.

        Args:
            ea: An address, a `Line`, or the `UseCurrentAddress` sentinel
                (the default) to use the current screen address.
            name: Name to resolve to an address. Mutually exclusive with `ea`.

        Raises:
            ValueError: If both `ea` and `name` are given, or `ea` is `None`.
            exceptions.SarkNoFunction: If `name` does not resolve to an address.
        """
        # The sentinel is a class object, so identity is the right comparison.
        if name is not None and ea is not self.UseCurrentAddress:
            raise ValueError(("Either supply a name or an address (ea). "
                              "Not both. (ea={!r}, name={!r})").format(ea, name))

        elif name is not None:
            ea = idc.get_name_ea_simple(name)
            if ea == idc.BADADDR:
                raise exceptions.SarkNoFunction(
                    "The supplied name does not belong to an existing function. "
                    "(name = {!r})".format(name))

        elif ea is self.UseCurrentAddress:
            ea = idaapi.get_screen_ea()

        elif ea is None:
            raise ValueError("`None` is not a valid address. To use the current screen ea, "
                             "use `Function(ea=Function.UseCurrentAddress)` or supply no `ea`.")

        elif isinstance(ea, Line):
            ea = ea.ea
        self._func = get_func(ea)
        self._comments = Comments(self)

    @staticmethod
    def is_function(ea=UseCurrentAddress):
        """Return True if a `Function` can be constructed for `ea`.

        Returns False when construction raises `exceptions.SarkNoFunction`.
        """
        try:
            Function(ea)
        except exceptions.SarkNoFunction:
            return False
        return True

    @staticmethod
    def create(ea=UseCurrentAddress):
        """Create a new function at `ea` and return its `Function` wrapper.

        Args:
            ea: Address of the new function. Defaults to the current screen address.

        Raises:
            exceptions.SarkFunctionExists: If `ea` already belongs to a function.
            exceptions.SarkAddFunctionFailed: If `add_func` reports failure.
        """
        if ea is Function.UseCurrentAddress:
            ea = idaapi.get_screen_ea()

        if Function.is_function(ea):
            raise exceptions.SarkFunctionExists("Function already exists")

        if not add_func(ea):
            raise exceptions.SarkAddFunctionFailed("Failed to add function")

        return Function(ea)

    @property
    def comments(self):
        """The `Comments` accessor created for this function."""
        return self._comments

    def __eq__(self, other):
        """Functions are equal when their start addresses match.

        Objects without a `start_ea` attribute compare unequal.
        """
        try:
            return self.start_ea == other.start_ea
        except AttributeError:
            return False

    def __hash__(self):
        # Kept consistent with `__eq__`: both are keyed on the start address.
        return self.start_ea

    @property
    def lines(self) -> Iterable[Line]:
        """Get all function lines."""
        return iter_function_lines(self._func)

    @property
    def start_ea(self) -> int:
        """Start Address"""
        return self._func.start_ea

    # Alias for `start_ea` for increased guessability and less typing.
    ea = start_ea

    @property
    def end_ea(self):
        """End Address

        Note that taking all the lines between `start_ea` and `end_ea` does not guarantee
        that you get all the lines in the function. To get all the lines, use `.lines`.
        """
        return self._func.end_ea

    @property
    def flags(self):
        """The function flags.

        See `idaapi.FUNC_*` constants.
        """
        return self._func.flags

    @property
    def xrefs_from(self):
        """Xrefs from the function.

        This includes the xrefs from every line in the function, as `Xref` objects.
        Xrefs are filtered to exclude flow xrefs and code references that are internal
        to the function. This means that xrefs to the function's own code will NOT be
        returned (yet, references to the function's data will be returned). To get
        those extra xrefs, you need to iterate the function's lines yourself.
        """
        for line in self.lines:
            for xref in line.xrefs_from:
                # Ordinary fall-through flow is not an interesting reference.
                if xref.type.is_flow:
                    continue

                # Skip code references that stay inside this function.
                if xref.to in self and xref.iscode:
                    continue

                yield xref

    @property
    def calls_from(self):
        """The subset of `xrefs_from` whose type is a call."""
        return (xref for xref in self.xrefs_from if xref.type.is_call)

    @property
    def drefs_from(self):
        """Destination addresses of data xrefs from this function."""
        for line in self.lines:
            for ea in line.drefs_from:
                yield ea

    @property
    def crefs_from(self):
        """Destination addresses of code xrefs from this function."""
        for line in self.lines:
            for ea in line.crefs_from:
                yield ea

    @property
    def xrefs_to(self):
        """Xrefs to the function.

        This only includes references to that function's start address.
        """
        return map(Xref, idautils.XrefsTo(self.start_ea))

    @property
    def drefs_to(self):
        """Source addresses of data xrefs to this function."""
        return idautils.DataRefsTo(self.start_ea)

    @property
    def crefs_to(self):
        """Source addresses of code xrefs to this function."""
        return idautils.CodeRefsTo(self.start_ea, 1)

    @property
    def name(self):
        """Function's Name"""
        return idaapi.get_ea_name(self.start_ea)

    @property
    def demangled(self):
        """Return the demangled name of the function. If none exists, return `.name`"""
        return demangle(self.name)

    @name.setter
    def name(self, name):
        """Set the function name.

        If the name exists, an exception will be raised. To use IDA's name counting use
        `.set_name(desired_name, anyway=True)`.
        """
        self.set_name(name)

    def set_name(self, name, anyway=False):
        """Set Function Name.

        Default behavior throws an exception when setting to a name that already exists in
        the IDB. To make IDA automatically add a counter to the name (like in the GUI),
        use `anyway=True`.

        Args:
            name: Desired name.
            anyway: `True` to set anyway.
        """
        set_name(self.start_ea, name, anyway=anyway)

    def __repr__(self):
        return 'Function(name="{}", addr=0x{:08X})'.format(self.name, self.start_ea)

    def __contains__(self, item):
        """Is an item contained (its EA is in) the function."""
        # If the item has an EA, use it. If not, use the item itself assuming it is an EA.
        ea = getattr(item, "ea", item)

        return is_same_function(ea, self.ea)

    @property
    def frame_size(self):
        """Frame size as reported by `idaapi.get_frame_size`."""
        return idaapi.get_frame_size(self._func)

    @property
    def color(self):
        """Function color in IDA View"""
        color = idc.get_color(self.ea, idc.CIC_FUNC)
        # 0xFFFFFFFF is the "no color" value; surface it as `None`.
        if color == 0xFFFFFFFF:
            return None

        return color

    # NOTE(review): `updates_ui` is imported by this module but is not used in
    # the visible code - confirm whether this setter should be wrapped with it.
    @color.setter
    def color(self, color):
        """Function Color in IDA View.

        Set color to `None` to clear the color.
        """
        if color is None:
            color = 0xFFFFFFFF

        idc.set_color(self.ea, idc.CIC_FUNC, color)

    @property
    def has_name(self):
        """Whether the line at the function's start address has a name."""
        return Line(self.start_ea).has_name

    @property
    def func_t(self):
        """The underlying object returned by `get_func`."""
        return self._func

    @property
    def signature(self):
        """The C signature of the function."""
        return idc.get_type(self.start_ea)

    @signature.setter
    def signature(self, c_signature):
        """Apply a C signature; raises `exceptions.SetTypeFailed` on failure."""
        success = idc.SetType(self.start_ea, c_signature)
        if not success:
            raise exceptions.SetTypeFailed(self.start_ea, c_signature)

    @property
    def tinfo(self):
        """The tinfo of the function type"""
        return idc.get_tinfo(self.start_ea)

    @tinfo.setter
    def tinfo(self, tinfo):
        """Apply a tinfo; raises `exceptions.SetTypeFailed` on failure."""
        success = idc.apply_type(self.start_ea, tinfo)
        if not success:
            raise exceptions.SetTypeFailed(self.start_ea, tinfo)

def iter_function_lines(func_ea) -> Iterable[Line]:
    """Iterate the lines of a function.

    Args:
        func_ea (idaapi.func_t, int): The function to iterate.

    Returns:
        Iterator over all the lines of the function.
    """
    # `get_ea` normalizes the argument before it is handed to
    # `idautils.FuncItems`; each yielded item is wrapped in a `Line`.
    for item_ea in idautils.FuncItems(get_ea(func_ea)):
        yield Line(item_ea)

def functions(start=None, end=None):
    """Get all functions in range.

    Args:
        start: Start address of the range. Defaults to IDB start.
        end: End address of the range. Defaults to IDB end.

    Returns:
        This is a generator that iterates over all the functions in the IDB.
    """
    start, end = fix_addresses(start, end)

    # Each item from `idautils.Functions` is passed to `Function` as its `ea`.
    for func_ea in idautils.Functions(start, end):
        yield Function(func_ea)