r"""
Comment module (internal)

This module contains the functionality for encoding and decoding
arbitrary python objects into IDA's comments and is thus a major
component of the way the tag system in this plugin works.

These encoded python objects are constrained in that they must also
double as being human readable in IDA's disassembly view. This then
allows one to essentially store a dictionary within each comment that
can be used to read/write primitive python objects.

Each comment stored at an address is a list of key-value pairs
separated by newlines. Each key represents a tag and so as a comment,
it has the following appearance.

    [key1] value1
    [key2] value2
    ...

If a comment does not follow this format, then it assumed that the
comment's value is a simple unencoded tag and can be accessed via
the '' key (empty). This means that if your comment is simply the
string "HeapHandle", then its dictionary will look similar to
`{'' : 'HeapHandle'}`.

Unfortunately IDA has a limitation on the length of comments and so
arbitrarily long data will not be able to be stored. This includes
things such as more complex types or custom objects. If one truly
needs to store a more complex type then it is suggested that the user
marshal or pickle their object, compress it in some way, and then
base64 encode it at a given address with their key. Still this is not
recommended, but who am I to stop a user from wanting to be crazy.

Some of the values are encoded in order to retain any specific
information about the value as well as still allowing the user to read
the tag's immediate contents. By default all integers are always
encoded as hexadecimal. If you can't read hex, then you should save
your database and quit your job. Or you could also practice it I guess.
It's the next power of 2 from decimal and is hence pretty important.
Strings and other types that contain any line-breaking characters also
get encoded whereas both lists and most iterable types are stored as-is.

Some examples of how tag values are encoded are as follows:

    4021 -> 0xfb5
    -100 -> -0x64
    0.5 -> float(0.5)
    'hello world\nhow\'re ya' -> hello world\nhow're ya
    {10,20,30} -> set([10, 20, 30])

With regards to string, only certain characters are escaped because
I don't know if IDA's comments support unicode. So, the characters
that have special meaning are '\n', '\r', and '\t'. Similarly if a
backslash is escaped, it will result in a single '\'. If any other
character is prefixed with a backslash, it will decode into a character
that is prefixed with a backslash.
"""

import functools, operator, itertools, types
import collections, heapq, string
import six, logging

import internal, idaapi
import codecs

### cheap data structure for doing pattern matching with
class pattern(object):
    class star(tuple): pass
    class maybe(tuple): pass

class node(dict):
    id = 0
    def __missing__(self, token):
        cls = self.__class__

        res = cls()
        res.id = self.id + len(self) + 1
        return self.setdefault(token, res)

class trie(node):
    def assign(self, symbols, value):
        res = list(symbols)
        head = res.pop(0)
        symbols = tuple(res)

        head = head if hasattr(head, '__iter__') or len(head) > 1 else (head,)
        if isinstance(head, pattern.star):
            if len(symbols) <= 0:
                raise ValueError('Refusing to register the STAR(*) pattern as the last symbol')

            [ self.__setitem__(item, self) for item in head ]
            self.assign(symbols, value)
            return

        elif isinstance(head, pattern.maybe):
            if len(symbols) <= 0:
                raise ValueError('Refusing to register the MAYBE(?) pattern as the last symbol')
            head, res = tuple(head), trie()
            res.assign(symbols, value)
            [ self.__setitem__(item, res) for item in head ]
            self.assign(symbols, value)
            return

        elif len(symbols) > 0:
            [ self[item].assign(symbols, value) for item in head ]

        else:
            [ self.__setitem__(item, value) for item in head ]
        return

    def descend(self, symbols):
        yield self
        for i, symbol in enumerate(symbols):
            if not isinstance(self, node) or not operator.contains(self, symbol):
                raise KeyError(i, symbol)    # raise KeyError for .get() and .find()
            self = self[symbol]
            yield self
        return

    def get(self, symbols):
        for i, symbol in enumerate(self.descend(symbols)):
            continue
        if isinstance(symbol, node):
            raise KeyError(i, symbol)    # raise KeyError for our caller if we couldn't find anything
        return symbol

    def find(self, symbols):
        for i, symbol in enumerate(self.descend(symbols)):
            if not isinstance(symbol, node):
                return symbol
            continue
        raise KeyError(i, symbol)    # raise KeyError for our caller if we couldn't find anything

    def dump(self):
        cls = self.__class__
        # FIXME: doesn't support recursion
        def stringify(layer, indent=0, tab='  '):
            data = (k for k, v in six.viewitems(layer) if not isinstance(v, node))
            result = []
            for k in data:
                result.append("{:s}{!r} -> {!r}".format(tab * indent, k, layer[k]))

            branches = [k for k, v in six.viewitems(layer) if isinstance(v, node)]
            for k in branches:
                result.append("{:s}{!r}".format(tab * indent, k))
                branch_data = stringify(layer[k], indent+1, tab=tab)
                result.extend(branch_data)
            return result
        return '\n'.join(("{!r}({:d})".format(cls, self.id), '\n'.join(stringify(self))))

### cache for looking up encoder/decoder types
class cache(object):
    state, tree = collections.defaultdict(set), trie()

    @classmethod
    def register(cls, type, *characters):
        def result(definition):
            # add definition to constant search by specified type
            cls.state[type].add(definition)

            # add definition to symbolic search
            if characters:
                cls.tree.assign(characters, definition)

            return definition
        return result

    @classmethod
    def by(cls, instance):
        type = instance.__class__
        try:
            if type not in cls.state:
                type = next(t for t in cls.state if issubclass(type, t))
            res = next(enc for enc in cls.state[type] if enc.type(instance))
        except StopIteration:
            raise internal.exceptions.SerializationError(u"{:s}.by({!s}) : Unable to find an encoder for the serialization of the specified type ({!s}).".format('.'.join(('internal', __name__, cls.__name__)), type, type))
        return res

    @classmethod
    def match(cls, string):
        return cls.tree.find(string)

class default(object):
    @classmethod
    def type(cls, instance):
        return True
    @classmethod
    def encode(cls, instance):
        return repr(instance)
    @classmethod
    def decode(cls, data):
        return eval(data)

### type encoder/decoder registration
# FIXME: maybe figure out how to parse out an int from a long (which ends in 'L')
@cache.register(object, pattern.star(' \t'), pattern.maybe('-+'), '0123456789')
class _int(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, six.integer_types)

    @classmethod
    def encode(cls, instance):
        return "{:-#x}".format(instance)

    @classmethod
    def decode(cls, data):
        if data.startswith('0x'):
            return int(data, 16)
        elif data.startswith('0o'):
            return int(data, 8)
        elif data.startswith('0o'):
            return int(data, 8)
        elif data.startswith('0b'):
            return int(data, 2)
        elif data.startswith('0y'):
            return int(data[2:], 2)
        return int(data)

@cache.register(object, pattern.star(' \t'), *'float(')
class _float(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, float)
    @classmethod
    def encode(cls, instance):
        return "float({:f})".format(instance)

@cache.register(str)
class _str(default):
    """
    This encoder/decoder actually supports both ``unicode`` and regular
    ``str`` due to the ``type`` method checking both string types. Also,
    we use this class as a superclass for ``_unicode`` so that any kind
    of string will be encoded into a unicode string which will be
    converted into UTF8 when written into IDA.
    """

    @classmethod
    def type(cls, instance):
        return isinstance(instance, six.string_types)

    @classmethod
    def _unescape(cls, iterable):
        '''Invert the utils.character.unescape coroutine into a generator.'''
        state = internal.interface.collect_t(list, lambda agg, ch: agg + [ch])
        unescape = internal.utils.character.unescape(state); next(unescape)

        # iterate through each character in the string
        for ch in iterable:

            # if we encounter a backslash while not having any state then we
            # can just emit the backslash since this is just a regular
            # string w/ no magic
            if ch == '\\' and not state.get():
                yield '\\'
                continue

            # process the character so we can figure out what's being returned
            unescape.send(ch)

            # iterate through the results and yield them to the caller
            for ch in state.get():
                yield ch

            # now we can start over
            state.reset()
        return

    @classmethod
    def _escape(cls, iterable):
        '''Invert the utils.character.escape coroutine into a generator.'''
        state = internal.interface.collect_t(list, lambda agg, ch: agg + [ch])
        escape = internal.utils.character.escape(state); next(escape)

        # iterate through each character in the string
        for ch in iterable:

            # check if we've been asked to escape a backslash while there's no
            # state. this way we can handle backslashes specially since this
            # is just a regular string.
            if ch == '\\' and not state.get():
                yield '\\'
                continue

            # process the character so we can figure out how to escape it
            escape.send(ch)

            # iterate through the results and yield them to the caller
            for ch in state.get():
                yield ch

            # empty our state and start over
            state.reset()
        return

    @classmethod
    def decode(cls, data):
        res = data if isinstance(data, unicode) else data.decode('utf8')
        return unicode().join(cls._unescape(iter(res.lstrip())))

    @classmethod
    def encode(cls, instance):
        res = cls._escape(iter(instance))
        return unicode().join(res)

@cache.register(unicode, pattern.star(' \t'), 'u', "'\"")
class _unicode(_str):
    """
    This encoder/decoder really just a wrapper around the ``_str``
    class. Its encoder will simply escape the string in the exact
    same way as ``_str``. We register a pattern for it so that we
    can decode unicode strings encoded in their older format. Due
    to the older format requiring unicode strings to begin with
    the "u'" prefix, we can simply eval it in order to decode
    back to a unicode string.
    """

    @classmethod
    def type(cls, instance):
        return isinstance(instance, unicode)

    @classmethod
    def decode(cls, data):
        logging.warn(u"{:s}.decode({!s}) : Decoding a unicode string that was encoded using the old format.".format('.'.join(('internal', __name__, cls.__name__)), internal.utils.string.repr(data)))
        return eval(data)

@cache.register(dict, pattern.star(' \t'), '{')
class _dict(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, dict)
    @classmethod
    def encode(cls, instance):
        f = lambda item: "{:-#x}".format(item) if isinstance(item, six.integer_types) else "{!r}".format(item)
        return '{' + ', '.join("{:s} : {!r}".format(f(key), instance[key]) for key in instance) + '}'

@cache.register(list, pattern.star(' \t'), '[')
class _list(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, list)
    @classmethod
    def encode(cls, instance):
        f = lambda n: "{:-#x}".format(n) if isinstance(n, six.integer_types) else "{!r}".format(n)
        return '[' + ', '.join(map(f, instance)) + ']'

@cache.register(tuple, pattern.star(' \t'), '(')
class _tuple(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, tuple)
    @classmethod
    def encode(cls, instance):
        f = lambda n: "{:-#x}".format(n) if isinstance(n, six.integer_types) else "{!r}".format(n)
        return '(' + ', '.join(map(f, instance)) + (', ' if len(instance) == 1 else '') + ')'

@cache.register(set, pattern.star(' \t'), *'set([')
class _set(default):
    @classmethod
    def type(cls, instance):
        return isinstance(instance, set)
    @classmethod
    def encode(cls, instance):
        f = lambda n: "{:-#x}".format(n) if isinstance(n, six.integer_types) else "{!r}".format(n)
        return 'set([' + ', '.join(map(f, instance)) + '])'

### general tag encoding/decoding
class tag(object):
    """
    Namespace for encoding and decoding a tag and it's value.
    """

    ## Tag name
    class name(object):
        """
        Namespace for encoding and decoding a tag's name.
        """

        prefix, suffix = '[]'
        backslash = '\\'

        mappings = {
            prefix : backslash + prefix,
            suffix : backslash + suffix,
            ' ' : r' ',
            backslash : backslash + backslash,
        }

        @classmethod
        def encode(cls, iterable, result):
            '''Given an `iterable` string, send each character in a printable form to `result`.'''

            # construct a transformer that writes characters to result
            escape = internal.utils.character.escape(result); next(escape)

            # send the key prefix
            result.send(cls.prefix)

            # now we can actually process the string
            for ch in iterable:

                # first check if character has an existing key mapping
                if operator.contains(cls.mappings, ch):
                    for ch in operator.getitem(cls.mappings, ch):
                        result.send(ch)

                # otherwise pass it to the regular escape function
                else:
                    escape.send(ch)

                continue

            # submit the suffix and we're good
            result.send(cls.suffix)
            return

        @classmethod
        def decode(cls, iterable, result):
            '''Given an `iterable`, decode it into a unicode string and send it to `result`.'''

            # first create our aggregate type for decoding into
            agg = internal.interface.collect_t(unicode, operator.add)

            # construct a transformer that unescapes characters to result
            unescape = internal.utils.character.unescape(agg); next(unescape)

            # first we'll skip the initial whitespace
            ch = next(iterable)
            while ch != cls.prefix and internal.utils.character.whitespaceQ(ch):
                ch = next(iterable)

            # check if it matches our prefix (which it should)
            if ch != cls.prefix:
                raise internal.exceptions.InvalidFormatError(u"{:s}.decode({!s}, {!s}) : Input for tag name does not begin with the proper character ('{:s}') and instead starts with '{:s}'.".format('.'.join(('internal', __name__, 'tag', cls.__name__)), iterable, result, internal.utils.string.escape(cls.prefix, '\''), internal.utils.string.escape(ch, '\'')))

            # read each character up to the sentinel
            agg.reset()
            ch = next(iterable)

            # loop until we find our suffix
            while ch != cls.suffix:

                # submit our character to the unescape transformer while
                # it's a backslash
                while ch in u'\\':
                    unescape.send(ch)
                    ch = next(iterable)

                # otherwise, continue to unescape it and look for a suffix
                unescape.send(ch)
                ch = next(iterable)

            # the last character read should be our suffix, so fail if otherwise
            if ch != cls.suffix:
                raise internal.exceptions.InvalidFormatError(u"{:s}.decode({!s}, {!s}) : Input for tag name does not terminate with the correct character ('{:s}') and instead is terminated with '{:s}'.".format('.'.join(('internal', __name__, 'tag', cls.__name__)), iterable, result, internal.utils.string.escape(cls.suffix, '\''), internal.utils.string.escape(ch, '\'')))

            # at this point, agg should have our unescaped key that we can submit
            result.send(agg.get())

    ## Tag value
    class value(object):
        """
        Namespace for encoding and decoding a tag's value.
        """

        @classmethod
        def encode(cls, iterable, result):
            '''Read a value from `iterable` and encode it into `result`.'''
            value = next(iterable)
            t = cache.by(value)
            for ch in t.encode(value):
                result.send(ch)
            return

        @classmethod
        def decode(cls, iterable, result):
            '''Given an `iterable`, decode it into a unicode string and send it to `result`.'''
            res = []

            # first we'll skip whitespace
            ch = next(iterable, '\n')
            while ch != '\n' and internal.utils.character.whitespaceQ(ch):
                res.append(ch)
                ch = next(iterable, '\n')

            # if we were just whitespace, then decode it as
            # a unicode string to return
            if ch == '\n':
                value_s = unicode().join(res)
                result.send(_str.decode(value_s))
                return

            # now we'll continue reading until the sentinel character
            value_l = []
            while ch != '\n':
                value_l.append(ch)
                ch = next(iterable, '\n')
            value_s = str().join(value_l)

            # now we'll try to find out what type to decode it as
            try:
                t = cache.match(value_s)
            except KeyError:
                t = _str

            # we have a type and a value. try to decode it
            try:
                value = t.decode(value_s)

            # if we weren't able to, then fall back to a string
            except Exception as E:
                t = _str
                logging.debug(u"{:s}.decode({!s}, {!s}) : Assuming value ({!s}) is of type {!s}.".format('.'.join(('internal', __name__, 'tag', cls.__name__)), iterable, result, internal.utils.string.repr(value_s), t))
                value = t.decode(value_s)

            # now we can submit it
            result.send(value)

    @classmethod
    def encode(cls, key, value):
        '''Encode the provided `key` and `value` into a line fit for a comment.'''
        result = internal.interface.collect_t(unicode, operator.add)

        # first encode the beginning of the name
        tag.name.encode(iter(key), result)

        # store some whitespace in between
        result.send(' ')

        # next encode the value component
        tag.value.encode(iter([value]), result)

        # now we can return the resulting string
        return result.get()

    @classmethod
    def decode(cls, iterable):
        '''Read the line in `iterable` and return the key and its value.'''

        # first decode the key
        key = internal.interface.collect_t(object, lambda agg, key: key)
        tag.name.decode(iterable, key)

        # next decode its value
        value = internal.interface.collect_t(object, lambda agg, value: value)
        tag.value.decode(iterable, value)

        # plain and simple...
        return key.get(), value.get()

### Encoding and decoding of a comment
def decode(data, default=u''):
    """Decode all the `(key, value)` pairs from the string `data` delimited by newlines.

    If unable to decode the key and value from a line in `data`, then use `default` as the key name.
    """

    # if data is empty, then return an empty dict
    if not data:
        return {}

    # initialize some variables to keep our state
    res = {}
    key, value = internal.interface.collect_t(object, lambda _, key: key), internal.interface.collect_t(object, lambda _, value: value)

    # iterate through each line in the data
    for line in data.split(u'\n'):
        iterable = iter(line)

        # try and decode the key and the value from the line
        try:
            k, v = tag.decode(iterable)

        # if the key wasn't terminated properly, or formatted correctly,
        # then append it to the default key separated by newlines
        except (StopIteration, internal.exceptions.InvalidFormatError) as E:
            items = filter(None, res.setdefault(default, u'').split(u'\n')) + [line]
            k, v = default, u'\n'.join(items)

        # warn the user if we're overwriting a key in the tag that
        # already exists.
        else:
            if operator.contains(res, k):
                logging.warn(u"{:s}.decode(..., {!s}) : Overwriting key ({!s}) containing old value ({!r}) with new value ({!r}).".format('.'.join(('internal', __name__)), internal.utils.string.repr(default), internal.utils.string.repr(k), res[k], v))

        # add our item to the result dictionary
        res[k] = v

    # return the dictionary we decoded
    return res

def encode(dict):
    '''Encode a dictionary into a multi-line string encoded as a list of tags.'''
    res = []

    # walk each item in the dictionary
    for k, v in six.iteritems(dict or {}):
        # encode the key and value from the dictionary
        line = tag.encode(k, v)

        # aggregate it into our list
        res.append(line)

    # join them by newlines and return them to the caller
    return '\n'.join(res)

def check(data):
    '''Check that the string `data` has the correct format by trying to decode it.'''
    res = map(iter, (data or '').split('\n'))
    try:
        map(tag.decode, res)
    except Exception as E:
        return False
    return True

### Tag reference counting
class tagging(object):
    """
    This namespace is essentially the configuration of the tagging
    database. This configurations specifies how to marshal and
    compress reference counts that are retained by the tagging
    infrastructure.

    The keys for the dictionaries that store the reference count
    are named according to ``tagging.__tags__`` for the tag names
    and ``tagging.__address__`` for the tag addresses. In order
    to access the tagging database, the netnode is returned by
    the ``tagging.node()`` function.

    When a database has been successfully created, a hook is
    responsible for calling the ``tagging.__init_tagcache__()``
    function. This will then create a netnode with the name
    specified in ``tagging.__node__``.
    """
    __node__ = '$ tagcache'
    __tags__, __address__ = 'name', 'address'

    marshaller = __import__('marshal')
    codec = __import__('codecs').lookup('bz2_codec')

    @classmethod
    def __init_tagcache__(cls, idp_modname):
        cls.node()
        logging.debug(u"{:s}.init_tagcache('{:s}') : Initialized tagcache with netnode \"{:s}\" and node id {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), internal.utils.string.escape(idp_modname, '\''), internal.utils.string.escape(cls.__node__, '"'), cls.__nodeid__))

    @classmethod
    def __nw_init_tagcache__(cls, nw_code, is_old_database):
        idp_modname = idaapi.get_idp_name()
        return cls.__init_tagcache__(idp_modname)

    @classmethod
    def node(cls):
        if hasattr(cls, '__nodeid__'):
            return cls.__nodeid__
        node = internal.netnode.get(cls.__node__)
        if node == idaapi.BADADDR:
            node = internal.netnode.new(cls.__node__)
        cls.__nodeid__ = node
        return node

class contents(tagging):
    '''Tagging for an address within a function (contents)'''
    """
    This namespace is used to update the tag state for any content tags
    associated with a function in the database. The address for the top
    of the function represents a key within the netnode that is used to
    fetch the blob and the supval which contains a marshall'd dictionary
    and a marshall'd set. These are stored within the `tagging.node()`
    netnode withn the tag `contents.btag`.

    The marshall'd dictionary that is stored in the netnode's blob is
    used to retain a dictionary of reference counts for both the tag
    names and the addresses that they reside at. Anytime a tag is
    written or removed, the reference count for both the name and the
    address is adjusted.

    Due to a size limit of a blob, the supval for the tagging node is
    used to store the tag names that are used within a function as a
    marshall'd ``set``. This ``set`` is used to verify that the tag
    names within the contents of the function correspond with the
    reference count that is stored within the marshall'd dictionary
    in the blob.
    """

    ## for each function's content
    # netnode.blob[fn.start_ea, btag] = marshal.dumps({'name', 'address'})
    # netnode.sup[fn.start_ea] = marshal.dumps({tagnames})

    #btag = idaapi.stag         # XXX: apparently 'S' is used for comments
    btag = idaapi.atag

    @classmethod
    def _key(cls, ea):
        '''Converts the address `ea` to a key that's used to store contents data for the specified function.'''
        res = idaapi.get_func(ea)
        return internal.interface.range.start(res) if res else None

    @classmethod
    def _read_header(cls, target, ea):
        """Read the contents dictionary out of the supval belonging to the function at `target`.

        If `target` is ``None``, then use the address of the function containing `ea`.
        """
        node, key = tagging.node(), cls._key(ea) if target is None else target
        if key is None:
            raise internal.exceptions.FunctionNotFoundError(u"{:s}._read_header({!r}, {:#x}) : Unable to find a function for target ({!r}) at {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea))

        encdata = internal.netnode.sup.get(node, key)
        if encdata is None:
            return None

        try:
            data, sz = cls.codec.decode(encdata)
            if len(encdata) != sz:
                raise internal.exceptions.SizeMismatchError(u"{:s}._read_header({!r}, {:#x}) : The number of bytes that was decoded ({:#x}) did not match the expected size ({:+#x}).".format('.'.join(('internal', __name__, cls.__name__)), target, ea, sz, len(encdata)))
        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._read_header({!r}, {:#x}) : Unable to decode contents for {:#x} at {:#x}. The data that failed to be decoded is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea, encdata))

        try:
            result = cls.marshaller.loads(data)
        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._read_header({!r}, {:#x}) : Unable to unmarshal contents for {:#x} at {:#x}. The data that failed to be unmarshalled is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea, data))
        return result

    @classmethod
    def _write_header(cls, target, ea, value):
        """Write the specified `value` into the contents supval belonging to the supval of the function at `target`.

        If `target` is ``None`` then use `ea` to locate the function.
        If `value` is ``None``, then remove the supval at the specified `target`.
        """
        node, key = tagging.node(), cls._key(ea) if target is None else target
        if key is None:
            raise internal.exceptions.FunctionNotFoundError(u"{:s}._write_header({!r}, {:#x}, {!s}) : Unable to find a function for target ({!r}) at {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea))

        if not value:
            ok = internal.netnode.sup.remove(node, key)
            return bool(ok)

        try:
            data = cls.marshaller.dumps(value)

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._write_header({!r}, {:#x}, {!s}) : Unable to marshal contents for {:#x} at {:#x}. The data that failed to be marshalled is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea, value))

        try:
            encdata, sz = cls.codec.encode(data)
            if sz != len(data):
                raise internal.exceptions.SizeMismatchError(u"{:s}._write_header({!r}, {:#x}, {!s}) : The number of bytes that was encoded ({:#x}) did not match the expected size ({:+#x}).".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), sz, len(data)))

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._write_header({!r}, {:#x}, {!s}) : Unable to encode contents for {:#x} at {:#x}. The data that failed to be encoded is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea, data))

        if len(encdata) > internal.netnode.sup.MAX_SIZE:
            logging.warn(u"{:s}._write_header({!r}, {:#x}, {!s}) : Too many tags within function. The size {:#x} must be < {:#x}. Ignoring it.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), len(encdata), internal.netnode.sup.MAX_SIZE))

        ok = internal.netnode.sup.set(node, key, encdata)
        return bool(ok)

    @classmethod
    def _read(cls, target, ea):
        """Reads the value from the contents supval for the specific `target`.

        If `target` is undefined or ``None`` then use `ea` to locate the function.
        """
        node, key = tagging.node(), cls._key(ea) if target is None else target
        if key is None:
            raise internal.exceptions.FunctionNotFoundError(u"{:s}._read({!r}, {:#x}) : Unable to find a function for target ({!r}) at {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea))

        encdata = internal.netnode.blob.get(key, cls.btag)
        if encdata is None:
            return None

        try:
            data, sz = cls.codec.decode(encdata)
            if len(encdata) != sz:
                raise internal.exceptions.SizeMismatchError(u"{:s}._read({!r}, {:#x}) : The number of bytes that was decoded ({:#x}) did not match the expected size ({:+#x}).".format('.'.join(('internal', __name__, cls.__name__)), target, ea, sz, len(encdata)))

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._read({!r}, {:#x}) : Unable to decode contents for {:#x} at {:#x}. The data that failed to decode is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea, encdata))

        try:
            result = cls.marshaller.loads(data)

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._read({!r}, {:#x}) : Unable to unmarshal contents for {:#x} at {:#x}. The data that failed to be unmarshalled is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, key, ea, data))
        return result

    @classmethod
    def _write(cls, target, ea, value):
        """Writes a `value` to the contents supval for the specific `target`.

        If `target` is undefined or ``None`` then use `ea` to locate the function.
        If `value` is ``None``, then erase the value from the supval.
        """
        node, key = tagging.node(), cls._key(ea) if target is None else target
        if key is None:
            raise internal.exceptions.FunctionNotFoundError(u"{:s}._write({!r}, {:#x}, {!r}) : Unable to find a function for target ({!r}) at {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, value, key, ea))

        # erase cache and blob if no data is specified
        if not value:
            try:
                ok = cls._write_header(target, ea, None)
                if not ok:
                    logging.debug(u"{:s}._write({!r}, {:#x}, {!s}) : Unable to remove address from sup cache with the key {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key))
            finally:
                return internal.netnode.blob.remove(key, cls.btag)

        # update blob for given address
        res = value
        try:
            data = cls.marshaller.dumps(res)

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._write({!r}, {:#x}, {!s}) : Unable to marshal contents for {:#x} at {:#x}. The data that failed to be marshalled is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea, res))

        try:
            encdata, sz = cls.codec.encode(data)

        except Exception as E:
            raise internal.exceptions.SerializationError(u"{:s}._write({!r}, {:#x}, {!s}) : Unable to encode contents for {:#x} at {:#x}. The data that failed to be encoded is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea, data))

        if sz != len(data):
            raise internal.exceptions.SizeMismatchError(u"{:s}._write({!r}, {:#x}, {!s}) : The number of bytes that was encoded ({:#x}) did not match the expected size ({:+#x}).".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), sz, len(data)))

        # write blob
        try:
            ok = internal.netnode.blob.set(key, cls.btag, encdata)
            if not ok: raise AssertionError # XXX: use an explicit exception

        except Exception as E:
            raise internal.exceptions.DisassemblerError(u"{:s}._write({!r}, {:#x}, {!s}) : Unable to set contents for {:#x} at {:#x}. The data that failed to be set is {!r}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key, ea, encdata))

        # update sup cache with keys
        res = set(six.viewkeys(value))
        try:
            ok = cls._write_header(target, ea, res)
            if not ok: raise AssertionError # XXX: use an explicit exception

        except Exception as E:
            logging.fatal(u"{:s}._write({!r}, {:#x}, {!s}) : Unable to set address to sup cache with the key {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), target, ea, internal.utils.string.repr(value), key))
        return ok

    @classmethod
    def iterate(cls):
        '''Yield each address and names for all of the contents tags in the database according to what is written into the tagging supval.'''
        node = tagging.node()
        for ea in internal.netnode.sup.fiter(node):
            encdata = internal.netnode.sup.get(node, ea)
            data, sz = cls.codec.decode(encdata)
            if len(encdata) != sz:
                logging.warn(u"{:s}.iterate() : Failed decoding tag names out of sup cache for {:#x} due to the length of encoded data ({:#x}) not matching the expected size ({:#x}).".format('.'.join(('internal', __name__, cls.__name__)), ea, len(encdata), sz))
            res = cls.marshaller.loads(data)
            yield ea, res
        return

    @classmethod
    def inc(cls, address, name, **target):
        """Increase the ref count for the given `address` and `name` belonging to the function `target`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        res = cls._read(target.get('target', None), address) or {}
        state, cache = res.get(cls.__tags__, {}), res.get(cls.__address__, {})

        state[name] = refs = state.get(name, 0) + 1
        cache[address] = cache.get(address, 0) + 1

        if state: res[cls.__tags__] = state
        else: del res[cls.__tags__]

        if cache: res[cls.__address__] = cache
        else: del res[cls.__address__]

        cls._write(target.get('target', None), address, res)
        return refs

    @classmethod
    def dec(cls, address, name, **target):
        """Decreate the ref count for the given `address` and `name` belonging to the function `target`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        res = cls._read(target.get('target', None), address) or {}
        state, cache = res.get(cls.__tags__, {}), res.get(cls.__address__, {})

        refs, count = state.pop(name, 0) - 1, cache.pop(address, 0) - 1
        if refs > 0:
            state[name] = refs

        if count > 0:
            cache[address] = count

        if state: res[cls.__tags__] = state
        else: res.pop(cls.__tags__, None)

        if cache: res[cls.__address__] = cache
        else: res.pop(cls.__address__, None)

        cls._write(target.get('target', None), address, res)
        return refs

    @classmethod
    def name(cls, address, **target):
        """Return all the tag names (``set``) for the contents of the function `target`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        res = cls._read(target.get('target', None), address) or {}
        res = res.get(cls.__tags__, {})
        return set(six.viewkeys(res))

    @classmethod
    def address(cls, address, **target):
        """Return all the addresses (``sorted``) with tags in the contents for the function `target`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        res = cls._read(target.get('target', None), address) or {}
        res = res.get(cls.__address__, {})
        return sorted(six.viewkeys(res))

    @classmethod
    def set_name(cls, address, name, count, **target):
        """Set the contents tag count for the function `target` and `name` to `count`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        state = cls._read(target.get('target', None), address) or {}

        res = state.get(cls.__tags__, {})
        if count > 0:
            res[name] = count
        else:
            res.pop(name, None)

        if res:
            state[cls.__tags__] = res
        else:
            state.pop(cls.__tags__, None)

        ok = cls._write(target.get('target', None), address, state)
        if ok:
            return state
        raise internal.exceptions.ReadOrWriteError(u"{:s}.set_name({:#x}, {!r}, {:d}{:s}) : Unable to write name to address {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), address, name, count, ', {:s}'.format(internal.utils.string.kwargs(target)) if target else '', address))

    @classmethod
    def set_address(cls, address, count, **target):
        """Set the contents tag count for the function `target` and `address` to `count`.

        If `target` is undefined or ``None`` then use `address` to locate the function.
        """
        state = cls._read(target.get('target', None), address) or {}

        res = state.get(cls.__address__, {})
        if count > 0:
            res[address] = count
        else:
            res.pop(address, None)

        if res:
            state[cls.__address__] = res
        else:
            state.pop(cls.__address__, None)

        ok = cls._write(target.get('target', None), address, state)
        if ok:
            return state
        raise internal.exceptions.ReadOrWriteError(u"{:s}.set_address({:#x}, {:d}{:s}) : Unable to write name to address {:#x}.".format('.'.join(('internal', __name__, cls.__name__)), address, count, ', {:s}'.format(internal.utils.string.kwargs(target)) if target else '', address))

class globals(tagging):
    """
    This namespace is used to update the tag state for all the globals in
    the database. Each global tag has its target address and its name and
    is managed by keeping track of a reference count.

    The reference count is stored within a netnode as defined by
    `tagging.node()`. The refcount for each address containing a
    tag is stored in an altval keyed by the address. The refcount
    for each tag name is stored in a hashval keyed by the tags
    name.
    """

    ## FIXME: for each global/function
    # netnode.alt[address] = refcount
    # netnode.hash[name] = refcount

    @classmethod
    def inc(cls, address, name):
        '''Increase the global tag count for the given `address` and `name`.'''
        node, eName = tagging.node(), internal.utils.string.to(name)

        cName = (internal.netnode.hash.get(node, eName, type=int) or 0) + 1
        cAddress = (internal.netnode.alt.get(node, address) or 0) + 1

        internal.netnode.hash.set(node, eName, cName)
        internal.netnode.alt.set(node, address, cAddress)

        return cName

    @classmethod
    def dec(cls, address, name):
        '''Decrease the global tag count for the given `address` and `name`.'''
        node, eName = tagging.node(), internal.utils.string.to(name)

        cName = (internal.netnode.hash.get(node, eName, type=int) or 1) - 1
        cAddress = (internal.netnode.alt.get(node, address) or 1) - 1

        if cName < 1:
            internal.netnode.hash.remove(node, eName)
        else:
            internal.netnode.hash.set(node, eName, cName)

        if cAddress < 1:
            internal.netnode.alt.remove(node, address)
        else:
            internal.netnode.alt.set(node, address, cAddress)

        return cName

    @classmethod
    def name(cls):
        '''Return all the tag names (``set``) in the specified database (globals and func-tags)'''
        node = tagging.node()
        return { internal.utils.string.of(name) for name in internal.netnode.hash.fiter(node) }

    @classmethod
    def address(cls):
        '''Return all the tag addresses (``sorted``) in the specified database (globals and func-tags)'''
        return sorted(ea for ea, _ in internal.netnode.alt.fiter(tagging.node()))

    @classmethod
    def set_name(cls, name, count):
        '''Set the global tag count for `name` in the database to `count`.'''
        node, eName = tagging.node(), internal.utils.string.to(name)
        res = internal.netnode.hash.get(node, eName, type=int)
        internal.netnode.hash.set(node, eName, count)
        return res

    @classmethod
    def set_address(cls, address, count):
        '''Set the global tag count for `address` in the database to `count`.'''
        node = tagging.node()
        res = internal.netnode.alt.get(node, address)
        internal.netnode.alt.set(node, address, count)
        return res