from Cryptodome.Cipher import AES, ChaCha20, Salsa20
from .twofish import Twofish
from Cryptodome.Util import Padding as CryptoPadding
import hashlib
from construct import (
    Adapter, BitStruct, BitsSwapped, Container, Flag, Padding, ListContainer, Mapping, GreedyBytes, Int32ul, Switch
)
from lxml import etree
from copy import deepcopy
import base64
import unicodedata
import zlib
import re
import codecs
from io import BytesIO
from collections import OrderedDict


class HeaderChecksumError(Exception):
    """Error type for a header checksum mismatch.

    Not raised within this module; presumably raised by the KDBX header
    verification code elsewhere.
    """


class CredentialsError(Exception):
    """Error type for invalid credentials (password and/or keyfile).

    Not raised within this module; presumably raised by callers that verify
    the composite key against the database.
    """


class PayloadChecksumError(Exception):
    """Error type for a payload checksum mismatch.

    Not raised within this module; presumably raised by the payload block
    verification code elsewhere.
    """


class DynamicDict(Adapter):
    """ListContainer <---> Container

    Convenience mapping so we don't have to iterate a ListContainer to find
    the right item: every item is stored under the value of its ``key`` field.

    FIXME: the ``lump`` kwarg was added to get around the fact that
    InnerHeader is not truly a dict.  Items whose key appears in ``lump``
    (e.g. all 'binary' InnerHeaderItems) are collected into a single
    ListContainer instead of overwriting one another.

    Args:
        key: name of the field on each item used as the dict key.
        subcon: construct that parses/builds the underlying list of items.
        lump: keys whose items are accumulated into a list.  Defaults to no
            lumped keys.
    """

    def __init__(self, key, subcon, lump=None):
        super(DynamicDict, self).__init__(subcon)
        self.key = key
        # None sentinel instead of a mutable ``[]`` default shared by every instance
        self.lump = [] if lump is None else lump

    # map ListContainer to Container
    def _decode(self, obj, context, path):
        # lumped keys always exist, even when no item carries them
        d = OrderedDict((name, ListContainer([])) for name in self.lump)
        for item in obj:
            item_key = item[self.key]
            if item_key in self.lump:
                d[item_key].append(item)
            else:
                # non-lumped keys are treated as unique; a later duplicate wins
                d[item_key] = item

        return Container(d)

    # map Container to ListContainer
    def _encode(self, obj, context, path):
        items = []
        for name in obj:
            if name in self.lump:
                # lumped entries are lists: flatten them back into the stream
                items.extend(obj[name])
            else:
                items.append(obj[name])

        return ListContainer(items)


def Reparsed(subcon_out):
    """Return an Adapter class that re-parses raw bytes with ``subcon_out``.

    The returned class wraps a bytes-producing subcon.  On parse, the bytes are
    fed to ``subcon_out.parse``.  On build, ``subcon_out.build`` turns the object
    back into bytes.  In both directions, the current context's entries are
    forwarded as keyword arguments.
    """

    class Reparsed(Adapter):
        """Bytes <---> Parsed subcon result
        Takes in bytes and reparses it with subcon_out"""

        def _decode(self, data, con, path):
            # expose the outer context to the inner parse as keyword args
            parsed = subcon_out.parse(data, **con)
            return parsed

        def _encode(self, obj, con, path):
            raw = subcon_out.build(obj, **con)
            return raw

    return Reparsed


# is the payload compressed?
# A 4-byte field: one "compression" Flag bit followed by 31 padding bits,
# with the bit order swapped by BitsSwapped.
CompressionFlags = BitsSwapped(
    BitStruct(
        "compression" / Flag,
        Padding(31),  # 8 * 4 - 1: the rest of the 32-bit field
    )
)


# -------------------- Key Computation --------------------
def aes_kdf(key, rounds, key_composite):
    """Derive the transformed key using repeated AES-ECB encryption.

    ``key_composite`` is encrypted ``rounds`` times in ECB mode under ``key``.
    The SHA-256 digest of the final ciphertext is returned.  PyCryptodome picks
    the AES variant from ``len(key)`` (16/24/32 bytes -> AES-128/192/256), so
    the cipher is not necessarily AES-128.

    Args:
        key (bytes): AES key (the KDF seed).
        rounds (int): number of encryption rounds to apply.
        key_composite (bytes): data to transform.  ECB requires its length to
            be a multiple of 16.

    Returns:
        bytes: 32-byte SHA-256 digest of the transformed data.
    """
    ecb = AES.new(key, AES.MODE_ECB)

    block = key_composite
    for _round in range(rounds):
        block = ecb.encrypt(block)

    digest = hashlib.sha256(block)
    return digest.digest()


def compute_key_composite(password=None, keyfile=None):
    """Compute composite key.
    Used in header verification and payload decryption.

    Args:
        password (str, optional): UTF-8 encoded and SHA-256 hashed.  ``None``
            or an empty string contributes nothing.
        keyfile (str, optional): path to a keyfile.  ``None`` or an empty value
            contributes nothing.  XML keyfiles (version 1.0 base64 data or
            version 2.0 hex data) are read first.  Otherwise the file is used as
            a 32-byte raw key, a 64-digit hex key, or hashed with SHA-256.

    Returns:
        bytes: SHA-256 of the password hash concatenated with the keyfile key.

    Raises:
        OSError: the keyfile could not be opened or read.
    """
    # hash the password
    if password:
        password_composite = hashlib.sha256(password.encode('utf-8')).digest()
    else:
        password_composite = b''

    # derive the keyfile contribution
    if keyfile:
        keyfile_composite = _keyfile_composite(keyfile)
    else:
        keyfile_composite = b''

    # create composite key from password and keyfile composites
    return hashlib.sha256(password_composite + keyfile_composite).digest()


def _keyfile_composite(keyfile):
    """Read the keyfile at path ``keyfile`` and return its key material."""
    # read once in binary so lxml honours the XML encoding declaration
    # instead of relying on the platform's default text encoding
    with open(keyfile, 'rb') as f:
        data = f.read()

    xml_key = _xml_keyfile_key(data)
    if xml_key is not None:
        return xml_key
    return _raw_keyfile_key(data)


def _xml_keyfile_key(data):
    """Return the key stored in a KeePass XML keyfile, or None if ``data`` is not one."""
    try:
        root = etree.fromstring(data)
    except (etree.XMLSyntaxError, ValueError):
        return None

    data_elem = root.find('Key/Data')
    if data_elem is None or data_elem.text is None:
        # well-formed XML but not a keyfile: treat it as an arbitrary file
        return None

    version = (root.findtext('Meta/Version') or '').strip()
    if version.startswith('2.'):
        # version 2.0 keyfiles store the key as (whitespace-separated) hex
        return bytes.fromhex(''.join(data_elem.text.split()))
    # version 1.0 (and unversioned) keyfiles store the key base64-encoded
    return base64.b64decode(data_elem.text)


def _raw_keyfile_key(data):
    """Interpret non-XML keyfile contents: raw 32-byte key, 64 hex digits, or hashed."""
    # if the length is 32 bytes we assume it is the key
    if len(data) == 32:
        return data
    # if the length is 64 bytes the key may be hex encoded; decode strictly
    # (int(x, 16) would also accept '0x', signs, underscores and whitespace)
    if len(data) == 64:
        try:
            return codecs.decode(data, 'hex')
        except ValueError:  # binascii.Error subclasses ValueError
            pass
    # anything else is hashed to obtain the key
    return hashlib.sha256(data).digest()


def compute_master(context):
    """Derive the master key used for payload decryption.

    The master key is SHA-256(master_seed || transformed_key).  The master seed
    is read from the parent context's header
    (``context._.header.value.dynamic_header.master_seed.data``), and the
    transformed key from ``context.transformed_key``.

    Returns:
        bytes: 32-byte SHA-256 digest.
    """
    master_seed = context._.header.value.dynamic_header.master_seed.data
    hasher = hashlib.sha256()
    hasher.update(master_seed)
    hasher.update(context.transformed_key)
    return hasher.digest()


# -------------------- XML Processing --------------------


class XML(Adapter):
    """Bytes <---> lxml etree

    Parsing turns the raw XML bytes into an lxml ElementTree.  Building
    serializes the tree back to bytes with ``etree.tostring``.
    """

    def _decode(self, data, con, path):
        source = BytesIO(data)
        document = etree.parse(source)
        return document

    def _encode(self, tree, con, path):
        serialized = etree.tostring(tree)
        return serialized


class UnprotectedStream(Adapter):
    """lxml etree <---> unprotected lxml etree
    Iterate etree for Protected elements and decrypt using cipher
    provided by get_cipher

    Subclasses implement ``get_cipher(protected_stream_key)`` and return an
    object with ``encrypt``/``decrypt`` (the subclasses below use stream
    ciphers).  ``protected_stream_key`` is a callable that receives the
    construct context and returns the key bytes.  A fresh cipher is created
    for every parse and every build, so values are processed in document order
    from the start of the cipher's output.
    """

    protected_xpath = '//Value[@Protected=\'True\']'
    unprotected_xpath = '//Value[@Protected=\'False\']'

    # characters not allowed in XML 1.0 - https://stackoverflow.com/questions/8733233
    # (compiled once instead of on every protected value)
    _invalid_xml_chars = re.compile(
        u'[^\u0020-\uD7FF\u0009\u000A\u000D\uE000-\uFFFD\U00010000-\U0010FFFF]+'
    )

    def __init__(self, protected_stream_key, subcon):
        super(UnprotectedStream, self).__init__(subcon)
        self.protected_stream_key = protected_stream_key

    def _decode(self, tree, con, path):
        """Decrypt every Protected='True' value in place and mark it 'False'."""
        cipher = self.get_cipher(self.protected_stream_key(con))
        for elem in tree.xpath(self.protected_xpath):
            if elem.text is not None:
                result = cipher.decrypt(base64.b64decode(elem.text)).decode('utf-8')
                # strip invalid XML characters so lxml accepts the text
                elem.text = self._invalid_xml_chars.sub('', result)
            elem.attrib['Protected'] = 'False'
        return tree

    def _encode(self, tree, con, path):
        """Return a copy of ``tree`` with Protected='False' values encrypted.

        The caller's tree is left untouched.  The encrypted copy is what gets
        serialized.
        """
        tree_copy = deepcopy(tree)
        cipher = self.get_cipher(self.protected_stream_key(con))
        for elem in tree_copy.xpath(self.unprotected_xpath):
            if elem.text is not None:
                elem.text = base64.b64encode(
                    cipher.encrypt(elem.text.encode('utf-8'))
                ).decode('ascii')
            elem.attrib['Protected'] = 'True'
        # return the encrypted copy; returning `tree` would write the values
        # in plaintext and discard the encryption done above
        return tree_copy


class ARCFourVariantStream(UnprotectedStream):
    """Protected-value stream for the ArcFour variant: not supported."""

    def get_cipher(self, protected_stream_key):
        # NotImplementedError subclasses Exception, so existing handlers still match
        raise NotImplementedError("ARCFourVariant not implemented")


# https://github.com/dlech/KeePass2.x/blob/97141c02733cd3abf8d4dce1187fa7959ded58a8/KeePassLib/Cryptography/CryptoRandomStream.cs#L115-L119
class Salsa20Stream(UnprotectedStream):
    """Protected-value stream using Salsa20.

    The cipher key is SHA-256 of the protected stream key, and the nonce is a
    fixed 8-byte constant (see the KeePass reference linked above).
    """

    def get_cipher(self, protected_stream_key):
        fixed_nonce = b'\xE8\x30\x09\x4B\x97\x20\x5D\x2A'
        salsa_key = hashlib.sha256(protected_stream_key).digest()
        return Salsa20.new(key=salsa_key, nonce=fixed_nonce)


# https://github.com/dlech/KeePass2.x/blob/97141c02733cd3abf8d4dce1187fa7959ded58a8/KeePassLib/Cryptography/CryptoRandomStream.cs#L103-L111
class ChaCha20Stream(UnprotectedStream):
    """Protected-value stream using ChaCha20.

    Key and nonce both come from SHA-512 of the protected stream key: bytes
    0-31 are the key and bytes 32-43 are the 12-byte nonce.
    """

    def get_cipher(self, protected_stream_key):
        material = hashlib.sha512(protected_stream_key).digest()
        return ChaCha20.new(key=material[:32], nonce=material[32:44])


def Unprotect(protected_stream_id, protected_stream_key, subcon):
    """Select stream cipher based on protected_stream_id.

    Returns a ``Switch`` on ``protected_stream_id``.  'arcfourvariant',
    'salsa20' and 'chacha20' wrap ``subcon`` in the matching stream adapter.
    Any other id (e.g. 'none') uses ``subcon`` unchanged.
    """
    stream_adapters = {
        'arcfourvariant': ARCFourVariantStream(protected_stream_key, subcon),
        'salsa20': Salsa20Stream(protected_stream_key, subcon),
        'chacha20': ChaCha20Stream(protected_stream_key, subcon),
    }
    return Switch(protected_stream_id, stream_adapters, default=subcon)


# -------------------- Payload Encryption/Decompression --------------------

class Concatenated(Adapter):
    """Data Blocks <---> Bytes

    Parsing joins the ``block_data`` of every block.  Building splits the
    payload into 1 MiB blocks and always appends one empty block (presumably
    the end-of-stream marker expected by the block format).
    """

    def _decode(self, blocks, con, path):
        return b''.join(block.block_data for block in blocks)

    def _encode(self, payload_data, con, path):
        block_size = 2**20  # 1 MB blocks (spec default)
        chunks = [
            Container(block_data=payload_data[start:start + block_size])
            for start in range(0, len(payload_data), block_size)
        ]
        chunks.append(Container(block_data=b''))
        return chunks


class DecryptedPayload(Adapter):
    """Encrypted Bytes <---> Decrypted Bytes

    Subclasses provide ``get_cipher(master_key, encryption_iv)``.  The master
    key comes from ``con.master_key``; the IV comes from the parent context's
    header (``con._.header.value.dynamic_header.encryption_iv.data``).
    """

    def _decode(self, payload_data, con, path):
        iv = con._.header.value.dynamic_header.encryption_iv.data
        decryptor = self.get_cipher(con.master_key, iv)
        return decryptor.decrypt(payload_data)

    def _encode(self, payload_data, con, path):
        # NOTE(review): PKCS#7 padding to 16 bytes is applied for every cipher,
        # including ChaCha20, while _decode never strips padding - presumably
        # handled downstream; confirm.
        padded = CryptoPadding.pad(payload_data, 16)
        iv = con._.header.value.dynamic_header.encryption_iv.data
        encryptor = self.get_cipher(con.master_key, iv)
        return encryptor.encrypt(padded)


class AES256Payload(DecryptedPayload):
    """Payload cipher for 'aes256': AES in CBC mode keyed by the master key."""

    def get_cipher(self, master_key, encryption_iv):
        return AES.new(master_key, AES.MODE_CBC, iv=encryption_iv)


class ChaCha20Payload(DecryptedPayload):
    """Payload cipher for 'chacha20': ChaCha20 keyed by the master key, with the header IV as nonce."""

    def get_cipher(self, master_key, encryption_iv):
        return ChaCha20.new(nonce=encryption_iv, key=master_key)


class TwoFishPayload(DecryptedPayload):
    """Payload cipher for 'twofish': Twofish in CBC mode keyed by the master key."""

    def get_cipher(self, master_key, encryption_iv):
        cbc = Twofish.MODE_CBC
        return Twofish.new(master_key, mode=cbc, IV=encryption_iv)


class Decompressed(Adapter):
    """Compressed Bytes <---> Decompressed Bytes

    Uses the gzip container format (wbits = 16 + 15) in both directions.
    Building uses compression level 6 with the default memory level and
    strategy.
    """

    def _decode(self, data, con, path):
        return zlib.decompress(data, 16 + zlib.MAX_WBITS)

    def _encode(self, data, con, path):
        compressor = zlib.compressobj(
            level=6,
            method=zlib.DEFLATED,
            wbits=16 + zlib.MAX_WBITS,
            memLevel=zlib.DEF_MEM_LEVEL,
            strategy=zlib.Z_DEFAULT_STRATEGY,
        )
        return compressor.compress(data) + compressor.flush()


# -------------------- Cipher Enums --------------------

# payload encryption method, identified in the header by a 16-byte cipher UUID
# https://github.com/keepassxreboot/keepassxc/blob/8324d03f0a015e62b6182843b4478226a5197090/src/format/KeePass2.cpp#L24-L26
# Keys are the cipher names used in this module; values are the raw UUID bytes
# (the UUID string form of each is noted alongside).
CipherId = Mapping(
    GreedyBytes,
    {'aes256': b'1\xc1\xf2\xe6\xbfqCP\xbeX\x05!j\xfcZ\xff',  # 31c1f2e6-bf71-4350-be58-05216afc5aff
     'twofish': b'\xadh\xf2\x9fWoK\xb9\xa3j\xd4z\xf9e4l',  # ad68f29f-576f-4bb9-a36a-d47af965346c
     'chacha20': b'\xd6\x03\x8a+\x8boL\xb5\xa5$3\x9a1\xdb\xb5\x9a'  # d6038a2b-8b6f-4cb5-a524-339a31dbb59a
     }
)

# protected entry encryption method (KeePass CryptoRandomStream algorithm id)
# https://github.com/dlech/KeePass2.x/blob/149ab342338ffade24b44aaa1fd89f14b64fda09/KeePassLib/Cryptography/CryptoRandomStream.cs#L35
# Stored as an unsigned 32-bit little-endian integer.  The names presumably
# feed the Switch keys in Unprotect ('none' has no stream adapter there).
ProtectedStreamId = Mapping(
    Int32ul,
    {'none': 0,
     'arcfourvariant': 1,  # ARCFourVariantStream.get_cipher raises: unsupported
     'salsa20': 2,
     'chacha20': 3,
     }
)