import binascii
import collections
import collections.abc
import inspect
import struct
import sys

import msgpack

from tattle import crypto
from tattle import logging

# Names exported by ``from <this module> import *``.
__all__ = [
    'MESSAGE_HEADER_LENGTH',
    'MESSAGE_HEADER_FORMAT',
    'MESSAGE_FLAG_ENCRYPT',
    'Message',
    'MessageError',
    'MessageEncodeError',
    'MessageDecodeError',
    'MessageChecksumError',
    'MessageSerializer',
    'InternetAddress',
    'PingMessage',
    'PingRequestMessage',
    'AckMessage',
    'NackMessage',
    'SuspectMessage',
    'DeadMessage',
    'AliveMessage',
    'RemoteNodeState',
    'SyncMessage',
    'UserMessage',
]

# Module-level logger.
LOG = logging.get_logger(__name__)

# Size in bytes of the fixed header that MessageSerializer.encode() puts in front of every message body.
MESSAGE_HEADER_LENGTH = 7  # 2 for length, 1 for flags, 4 for CRC32
# struct layout of that header: '!' = network byte order with standard sizes,
# 'H' = 2-byte unsigned length, 'B' = 1-byte flags, 'L' = 4-byte unsigned CRC32.
MESSAGE_HEADER_FORMAT = '!HBL'  # network byte order is big-endian

# Bit of the header flags byte that is set when the message body is encrypted.
MESSAGE_FLAG_ENCRYPT = 0x80


class MessageError(Exception):
    """Root of the exception hierarchy defined by this module."""


class MessageEncodeError(MessageError):
    """MessageError subclass for failures while encoding a message."""


class MessageDecodeError(MessageError):
    """MessageError subclass for failures while decoding a message (e.g. a buffer that is too short)."""


class MessageChecksumError(MessageDecodeError):
    """MessageDecodeError subclass raised when a message body does not match the CRC32 in its header."""


class _BaseMessage(object):
    """
    Base class for objects described by a list of named, optionally typed fields.

    Subclasses declare ``_fields_`` as a list whose entries are either a field name or a
    ``(name, type)`` tuple. Every declared field becomes an instance attribute; a typed field only
    accepts ``None`` or an instance of its declared type. Fields declared by base classes come first,
    see :meth:`get_fields`.
    """
    _fields_ = []

    def __init__(self, *args, **kwargs):
        """
        Set every declared field to None, then assign the positional and keyword values.

        A keyword value overrides a positional value given for the same field.

        :param args: field values in the order returned by get_fields()
        :param kwargs: field values keyed by field name
        :raises TypeError: if a typed field is given a value that is neither None nor of its type
        :raises KeyError: if a keyword does not name a declared field
        :raises IndexError: if more positional values are given than there are fields
        """
        fields = self.__class__.get_fields()

        # initialize fields so every declared attribute exists even when no value is supplied
        for name, _ in fields:
            setattr(self, name, None)

        # assign values from args
        for position, value in enumerate(args):
            # indexing (rather than zip) keeps the IndexError raised for surplus positional values
            name, field_type = fields[position]
            self._check_field_type(name, field_type, value)
            setattr(self, name, value)

        # assign values from kwargs; the lookup table is built in reverse so that the first
        # declaration of a repeated field name is the one that applies
        field_types = {name: field_type for name, field_type in reversed(fields)}
        for name, value in kwargs.items():
            if name not in field_types:
                raise KeyError("Invalid field: %s" % name)
            self._check_field_type(name, field_types[name], value)
            setattr(self, name, value)

    @staticmethod
    def _check_field_type(name, field_type, value):
        """Raise TypeError unless ``value`` is None, the field is untyped, or the value matches the type."""
        if field_type is None or value is None or isinstance(value, field_type):
            return
        raise TypeError("Field %s must be of type: %s (is %s)"
                        % (name, field_type.__name__, value.__class__.__name__))

    def __repr__(self):
        """Return ``<ClassName {field: value, ...}>`` with the fields in declaration order."""
        values = collections.OrderedDict(
            (name, getattr(self, name)) for name, _ in self.__class__.get_fields())
        return "<%s %s>" % (self.__class__.__name__, dict(values))

    def __eq__(self, other):
        """Override the default equals behavior: equal when the instance attributes are equal"""
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        """Define a non-equality test"""
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __hash__(self):
        """
        Override the default hash behavior: hash the sorted (attribute, value) pairs.

        :raises TypeError: if any attribute value is unhashable (e.g. a list)
        """
        return hash(tuple(sorted(self.__dict__.items())))

    @classmethod
    def get_fields(cls):
        """
        Return the declared fields of this class as a list of ``(name, type)`` tuples.

        Fields are collected from the least derived class to the most derived one; an entry declared
        as a bare name is returned with type ``None``.

        :return: list of (field name, field type or None)
        """
        fields = []
        for base in reversed(inspect.getmro(cls)):
            if not issubclass(base, _BaseMessage):
                continue
            # Read only what this class declares itself: going through attribute lookup would repeat
            # the parent's fields for every subclass that does not define its own _fields_.
            for f in vars(base).get('_fields_', ()):
                fields.append(f if isinstance(f, tuple) else (f, None))
        return fields


class Message(_BaseMessage):
    """
    Common base class of the message types in this module; declares no fields of its own.
    """

    def __init__(self, *args, **kwargs):
        # plain pass-through to the field-assigning base constructor
        super().__init__(*args, **kwargs)


class MessageSerializer(object):
    """
    Utility class for serializing and deserializing Messages

    An encoded message is a MESSAGE_HEADER_LENGTH byte header (packed with MESSAGE_HEADER_FORMAT:
    total length, flags, CRC32 of the body) followed by the body. The body is the msgpack encoding of
    a flat list ``[class name, field, field, ...]`` and is encrypted when MESSAGE_FLAG_ENCRYPT is set.
    """

    @classmethod
    def _resolve_message_type(cls, name):
        """
        Map a class name read from the wire to a class defined in this module.

        The name comes from untrusted input, so only _BaseMessage subclasses are accepted; any other
        module attribute (imports, constants, this serializer, ...) is rejected.

        :param name: class name read from the decoded message
        :return: the matching _BaseMessage subclass
        :raises MessageDecodeError: if the name does not identify a _BaseMessage subclass
        """
        message_type = getattr(sys.modules[__name__], name, None) if isinstance(name, str) else None
        if not (isinstance(message_type, type) and issubclass(message_type, _BaseMessage)):
            raise MessageDecodeError("Unknown message type: %r" % (name,))
        return message_type

    @classmethod
    def _deserialize_internal(cls, data):
        """
        Rebuild one object from the front of ``data``, removing the items it consumes.

        :param data: list produced by _serialize_internal (class name followed by field values)
        :return: the rebuilt object, or None if the leading item is None
        :raises MessageDecodeError: if the class name is not a known message type
        """
        # get message type
        message_type_name = data.pop(0)
        if message_type_name is None:
            return None
        message_type = cls._resolve_message_type(message_type_name)

        message_args = []

        # deserialize all message fields
        for field_name, field_type in message_type.get_fields():
            if field_type is not None:
                # A typed field is either a single None or a nested object that was flattened into
                # this same list, so the recursion keeps consuming from ``data``.
                if data[0] is None:
                    message_args.append(data.pop(0))
                else:
                    message_args.append(cls._deserialize_internal(data))
            else:
                attr = data.pop(0)
                if isinstance(attr, (str, bytes)):
                    message_args.append(attr)
                elif isinstance(attr, collections.abc.Sequence):
                    # every element of a list field is itself a serialized object
                    message_args.append([cls._deserialize_internal(item) for item in attr])
                else:
                    # NOTE(review): mappings are passed through unchanged here, although
                    # _serialize_internal serializes mapping values - confirm this is intended.
                    message_args.append(attr)

        # Bypass the subclass constructor: its parameter order may differ from the field order
        # (e.g. InternetAddress), whereas the wire format follows the field order.
        message = message_type.__new__(message_type)
        _BaseMessage.__init__(message, *message_args)
        return message

    @classmethod
    def _deserialize_message(cls, raw):
        """Unpack msgpack bytes and rebuild the message they describe."""
        # NOTE(review): msgpack 1.0 removed the ``encoding`` keyword (``raw=False`` replaces it);
        # confirm the msgpack version this project pins before upgrading.
        return cls._deserialize_internal(msgpack.unpackb(raw, encoding='utf-8', use_list=True))

    @classmethod
    def _decrypt_message(cls, raw, keys):
        """Decrypt a message body with crypto.decrypt_data."""
        return crypto.decrypt_data(raw, keys)

    @classmethod
    def _verify_checksum(cls, raw, crc):
        """
        Check the body against the CRC32 carried in the header.

        :raises MessageChecksumError: if the checksums differ
        """
        expected = cls._compute_checksum(raw)
        if crc != expected:
            raise MessageChecksumError("Message checksum mismatch: 0x%X != 0x%X" % (crc, expected))

    @classmethod
    def decode(cls, buf, encryption=None) -> Message:
        """
        Decode a message from bytes

        :param buf: header followed by the message body, as produced by encode()
        :param encryption: list of encryption keys
        :return: deserialized message
        :raises MessageDecodeError: if the buffer holds no body or names an unknown message type
        :raises MessageChecksumError: if the body does not match the checksum in the header
        """

        # unpack message header
        if len(buf) <= MESSAGE_HEADER_LENGTH:
            raise MessageDecodeError("Message is too short")
        # the length field is unpacked but not checked against the size of the buffer
        (length, flags, crc) = struct.unpack(MESSAGE_HEADER_FORMAT, buf[0:MESSAGE_HEADER_LENGTH])

        # unpack message body
        raw = buf[MESSAGE_HEADER_LENGTH:]

        # verify message checksum (computed over the body as transmitted, i.e. before decryption)
        cls._verify_checksum(raw, crc)

        # handle encryption
        if flags & MESSAGE_FLAG_ENCRYPT == MESSAGE_FLAG_ENCRYPT:
            raw = cls._decrypt_message(raw, keys=encryption)

        # return the deserialized message
        return cls._deserialize_message(raw)

    @classmethod
    def _serialize_internal(cls, msg):
        """
        Flatten an object into a list: its class name followed by its field values.

        Typed fields holding an object are flattened in place, list elements and mapping values are
        serialized recursively, everything else is stored as is.

        :param msg: object to serialize (must provide get_fields())
        :return: list suitable for msgpack
        """
        # insert the name of the class
        data = [msg.__class__.__name__]

        for field_name, field_type in msg.__class__.get_fields():
            attr = getattr(msg, field_name)
            if field_type is not None and attr is not None:
                # if attr has a field type defined serialize that field in place
                data.extend(cls._serialize_internal(attr))
            elif isinstance(attr, (str, bytes)):
                # checked first: str and bytes are sequences too but must be stored whole
                data.append(attr)
            elif isinstance(attr, collections.abc.Sequence):
                data.append([cls._serialize_internal(item) for item in attr])
            elif isinstance(attr, collections.abc.Mapping):
                data.append({k: cls._serialize_internal(v) for k, v in attr.items()})
            else:
                data.append(attr)
        return data

    @classmethod
    def _serialize_message(cls, msg):
        """Flatten the message and pack it with msgpack."""
        # NOTE(review): msgpack 1.0 removed the ``encoding`` keyword; confirm the msgpack version
        # this project pins before upgrading.
        return msgpack.packb(cls._serialize_internal(msg), use_bin_type=True, encoding='utf-8')

    @classmethod
    def _encrypt_message(cls, raw, key):
        """Encrypt a message body with crypto.encrypt_data."""
        return crypto.encrypt_data(raw, key)

    @classmethod
    def _compute_checksum(cls, raw):
        """Return the CRC32 of ``raw`` as an unsigned 32-bit integer."""
        crc = binascii.crc32(raw) & 0xffffffff  # https://docs.python.org/3/library/binascii.html#binascii.crc32
        return crc

    @classmethod
    def encode(cls, msg: Message, encryption=None) -> bytes:
        """
        Encode a message to bytes

        :param msg: message to encode
        :param encryption: encryption key
        :return: header followed by the (optionally encrypted) message body
        """

        # serialize the message
        raw = cls._serialize_message(msg)

        flags = 0

        # encrypt message
        if encryption is not None:
            flags |= MESSAGE_FLAG_ENCRYPT
            raw = cls._encrypt_message(raw, key=encryption)

        # calculate message checksum (over the body as transmitted, i.e. after encryption)
        crc = cls._compute_checksum(raw)

        # calculate message length (header included)
        length = len(raw) + MESSAGE_HEADER_LENGTH

        # pack message header
        header = struct.pack(MESSAGE_HEADER_FORMAT, length, flags, crc)
        return header + raw


class InternetAddress(_BaseMessage):
    """
    Address record with the fields ``addr_v4``, ``addr_v6`` and ``port``.
    """
    _fields_ = [
        "addr_v4",
        "addr_v6",
        "port"
    ]

    def __init__(self, addr_v4, port, addr_v6=None):
        """
        Create a new instance of the InternetAddress class

        :param addr_v4: value of the ``addr_v4`` field
        :param port: value of the ``port`` field
        :param addr_v6: value of the ``addr_v6`` field
        """
        # the fields are declared as (addr_v4, addr_v6, port), unlike the parameter order
        super().__init__(addr_v4, addr_v6, port)

    @property
    def address(self):
        """Return ``addr_v6`` when it is set, otherwise ``addr_v4``."""
        return self.addr_v4 if self.addr_v6 is None else self.addr_v6

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s,%d>" % (type_name, self.address, self.port)


class PingMessage(Message):
    """
    Message with the fields ``seq``, ``node``, ``sender`` and ``sender_addr`` (an InternetAddress).
    """
    _fields_ = [
        "seq",
        "node",
        "sender",
        ("sender_addr", InternetAddress),
    ]

    def __init__(self, seq, target, sender=None, sender_addr=None):
        """
        Create new instance of the PingMessage class

        :param seq: sequence number
        :param target: target node name (stored in the ``node`` field)
        :param sender: sender node name
        :param sender_addr: sender node address
        """
        super().__init__(seq, target, sender, sender_addr)

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s>" % (type_name, self.node)


class PingRequestMessage(Message):
    """
    Message with the fields ``seq``, ``node``, ``node_addr``, ``sender`` and ``sender_addr``
    (both addresses are InternetAddress instances).
    """
    _fields_ = [
        "seq",
        "node",
        ("node_addr", InternetAddress),
        "sender",
        ("sender_addr", InternetAddress),
    ]

    def __init__(self, seq, target, target_addr, sender=None, sender_addr=None):
        """
        Create new instance of the PingRequestMessage class

        :param seq: sequence number
        :param target: target node name (stored in the ``node`` field)
        :param target_addr: target node address (stored in the ``node_addr`` field)
        :param sender: sender node name
        :param sender_addr: sender node address
        """
        super().__init__(seq, target, target_addr, sender, sender_addr)

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s>" % (type_name, self.node)


class AckMessage(Message):
    """
    Message with the fields ``seq`` and ``sender``.

    Presumably the positive reply to a ping carrying the same ``seq`` - confirm against the protocol code.
    """
    _fields_ = [
        "seq",
        "sender"
    ]

    def __str__(self):
        type_name = type(self).__name__
        return "<%s seq=%s>" % (type_name, self.seq)


class NackMessage(Message):
    """
    Message with the fields ``seq`` and ``sender``.

    Presumably the negative reply to a ping request carrying the same ``seq`` - confirm against the
    protocol code.
    """
    _fields_ = [
        "seq",
        "sender"
    ]

    def __str__(self):
        type_name = type(self).__name__
        return "<%s seq=%s>" % (type_name, self.seq)


class SuspectMessage(Message):
    """
    Message with the fields ``node``, ``incarnation`` and ``sender``; its string form shows the node.
    """
    _fields_ = [
        "node",  # node name
        "incarnation",
        "sender",
    ]

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s>" % (type_name, self.node)


class DeadMessage(Message):
    """
    Message with the fields ``node``, ``incarnation`` and ``sender``; its string form shows the node.
    """
    _fields_ = [
        "node",  # node name
        "incarnation",
        "sender"
    ]

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s>" % (type_name, self.node)


class AliveMessage(Message):
    """
    Message with the fields ``node``, ``addr`` (an InternetAddress) and ``incarnation``.
    """
    _fields_ = [
        "node",  # node name
        ("addr", InternetAddress),
        "incarnation"
    ]

    def __init__(self, node, addr, incarnation):
        """
        Create a new instance of the AliveMessage class

        :param node: node name
        :param addr: value of the ``addr`` field (InternetAddress or None)
        :param incarnation: value of the ``incarnation`` field
        """
        super().__init__(node, addr, incarnation)

    def __str__(self):
        type_name = type(self).__name__
        return "<%s %s>" % (type_name, self.node)


class RemoteNodeState(_BaseMessage):
    """
    Record with the fields ``node``, ``addr`` (an InternetAddress), ``version``, ``incarnation``,
    ``status`` and ``metadata``.

    Presumably the element type of ``SyncMessage.nodes`` - confirm against the callers.
    """
    _fields_ = [
        "node",
        ("addr", InternetAddress),
        "version",
        "incarnation",
        "status",
        "metadata",
    ]

    def __init__(self, node, addr, version, incarnation, status, metadata):
        """
        Create a new instance of the RemoteNodeState class

        Each parameter is stored in the field of the same name; ``addr`` must be an
        InternetAddress or None.
        """
        field_values = (node, addr, version, incarnation, status, metadata)
        super().__init__(*field_values)


class SyncMessage(Message):
    """
    Message with the single field ``nodes``.
    """
    _fields_ = [
        "nodes"
    ]

    def __init__(self, remote_state):
        """
        Create a new instance of the SyncMessage class

        :param remote_state: value stored in the ``nodes`` field
        """
        super().__init__(remote_state)


class UserMessage(Message):
    """
    Message with the fields ``data`` and ``sender``.
    """
    # The comma between the two names matters: without it the adjacent string literals are
    # concatenated into a single field named "datasender".
    _fields_ = [
        "data",
        "sender",
    ]

    def __init__(self, data, sender):
        """
        Create a new instance of the UserMessage class

        :param data: user message
        :param sender: sender node name
        """
        super(UserMessage, self).__init__(data, sender)