"""
rohmu - content encryption

Copyright (c) 2016 Ohmu Ltd
See LICENSE for details
"""

from . import IO_BLOCK_SIZE
from .filewrap import FileWrap, Sink, Stream
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives.hashes import SHA1, SHA256
from cryptography.hazmat.primitives.hmac import HMAC
from cryptography.hazmat.primitives import serialization
import cryptography
import cryptography.hazmat.backends.openssl.backend
import hmac
import io
import logging
import os
import struct


def _version_prefix(version, parts=2):
    """Return the first ``parts`` numeric components of a dotted version string as ints.

    Any non-digit suffix inside a component (e.g. "0rc1", "dev1") is ignored and missing
    components count as 0, so "1.10.2" -> (1, 10) and "2" -> (2, 0).
    """
    numbers = []
    for component in version.split(".")[:parts]:
        digits = ""
        for char in component:
            if not "0" <= char <= "9":
                break
            digits += char
        numbers.append(int(digits) if digits else 0)
    numbers.extend([0] * (parts - len(numbers)))
    return tuple(numbers)


# Compare numerically: as plain strings, "1.10" < "1.6" would wrongly be True.
if _version_prefix(cryptography.__version__) < (1, 6):
    # workaround for deadlock https://github.com/pyca/cryptography/issues/2911
    cryptography.hazmat.backends.openssl.backend.activate_builtin_random()

# Magic prefix written at the very start of every encrypted stream and verified on decryption
FILEMAGIC = b"pghoa1"
# AES block length in bytes (128 bits)
AES_BLOCK_SIZE = 128 // 8


class EncryptorError(Exception):
    """Raised for malformed or truncated encrypted data, key decryption failures and failed integrity checks."""


class Encryptor:
    """Incrementally encrypt a byte stream into the format read by Decryptor.

    The first update() after construction (or after finalize()) generates fresh random AES-CTR
    and HMAC-SHA256 keys. It prefixes its output with a header made of FILEMAGIC, the big-endian
    2-byte length of the RSA-OAEP-encrypted key material, and that encrypted key material.
    finalize() appends the HMAC of every ciphertext byte produced since the header.
    """

    def __init__(self, rsa_public_key_pem):
        pem_bytes = rsa_public_key_pem if isinstance(rsa_public_key_pem, bytes) else rsa_public_key_pem.encode("ascii")
        self.rsa_public_key = serialization.load_pem_public_key(pem_bytes, backend=default_backend())
        self.cipher = None
        self.authenticator = None

    def _begin_stream(self):
        """Create a fresh cipher/authenticator pair and return the header to emit before any ciphertext."""
        aes_key, ctr_nonce, mac_key = os.urandom(16), os.urandom(16), os.urandom(32)
        backend = default_backend()
        self.cipher = Cipher(algorithms.AES(aes_key), modes.CTR(ctr_nonce), backend=backend).encryptor()
        self.authenticator = HMAC(mac_key, SHA256(), backend=backend)
        oaep = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()), algorithm=SHA1(), label=None)
        wrapped_keys = self.rsa_public_key.encrypt(aes_key + ctr_nonce + mac_key, oaep)
        return b"".join((FILEMAGIC, struct.pack(">H", len(wrapped_keys)), wrapped_keys))

    def update(self, data):
        """Encrypt ``data``; the first call of a stream also returns the header in front of the ciphertext."""
        header = self._begin_stream() if self.cipher is None else b""
        ciphertext = self.cipher.update(data)
        self.authenticator.update(ciphertext)
        return header + ciphertext if header else ciphertext

    def finalize(self):
        """Finish the stream and return the remaining ciphertext followed by the HMAC digest."""
        if self.cipher is None:
            return b""  # nothing was encrypted: no header and no MAC, i.e. empty output
        tail = self.cipher.finalize()
        self.authenticator.update(tail)
        mac = self.authenticator.finalize()
        self.cipher = self.authenticator = None
        return tail + mac


class EncryptorFile(FileWrap):
    """Write-only file wrapper that encrypts everything written to it into ``next_fp``."""

    def __init__(self, next_fp, rsa_public_key_pem):
        super().__init__(next_fp)
        self.key = rsa_public_key_pem
        self.encryptor = Encryptor(self.key)
        # Number of plaintext bytes accepted so far
        self.offset = 0
        self.state = "OPEN"

    def flush(self):
        """Flush the wrapped file object."""
        self._check_not_closed()
        self.next_fp.flush()

    def close(self):
        """Emit the encryption trailer to ``next_fp`` and close this wrapper; repeated calls do nothing."""
        if self.state != "CLOSED":
            trailer = self.encryptor.finalize()
            self.encryptor = None
            self.next_fp.write(trailer)
            super().close()

    def writable(self):
        """True if this stream supports writing"""
        self._check_not_closed()
        return True

    def write(self, data):
        """Encrypt ``data``, write the ciphertext to ``next_fp`` and return the plaintext length consumed."""
        self._check_not_closed()
        if not data:
            return 0
        self.next_fp.write(self.encryptor.update(data))
        consumed = len(data)
        self.offset += consumed
        return consumed


class EncryptorStream(Stream):
    """Non-seekable stream of data that adds encryption on top of given source stream"""

    def __init__(self, src_fp, rsa_public_key_pem):
        super().__init__(src_fp)
        self._encryptor = Encryptor(rsa_public_key_pem)

    def _process_chunk(self, data):
        """Return the ciphertext for one chunk of source data (header included on the first chunk)."""
        encrypted = self._encryptor.update(data)
        return encrypted

    def _finalize(self):
        """Return the trailing ciphertext and HMAC once the source is exhausted."""
        trailer = self._encryptor.finalize()
        return trailer


class Decryptor:
    """Incremental decryptor for the format written by Encryptor.

    Layout: FILEMAGIC (6 bytes), length of the RSA-encrypted key block (2 bytes, big-endian),
    the RSA-OAEP-encrypted key block (AES key, CTR nonce, HMAC key), AES-CTR ciphertext and
    finally a 32-byte HMAC-SHA256 of the ciphertext.

    Call process_header() with exactly expected_header_bytes() bytes until that returns 0,
    then feed ciphertext to process_data() and pass the footer to finalize().
    """

    def __init__(self, rsa_private_key_pem):
        if not isinstance(rsa_private_key_pem, bytes):
            rsa_private_key_pem = rsa_private_key_pem.encode("ascii")
        self.rsa_private_key = serialization.load_pem_private_key(
            data=rsa_private_key_pem,
            password=None,
            backend=default_backend())
        self.cipher = None
        self.authenticator = None
        # Length of the RSA-encrypted key block; known once the 8-byte preamble has been parsed
        self._cipher_key_len = None
        # Total header size (preamble + encrypted key block); known once the key block is decrypted
        self._header_size = None
        # HMAC-SHA256 digest length
        self._footer_size = 32

    def expected_header_bytes(self):
        """Return the number of bytes the next process_header() call needs, or 0 once the header is done."""
        if self._header_size is not None:
            return 0
        return self._cipher_key_len or 8

    def header_size(self):
        """Return the full header size in bytes, or None until the header has been processed."""
        return self._header_size

    def footer_size(self):
        """Return the size in bytes of the trailing HMAC."""
        return self._footer_size

    def process_header(self, data):
        """Consume one header chunk of exactly expected_header_bytes() bytes.

        The first call parses the magic and the encrypted-key length; the second decrypts the key
        block and sets up the cipher and authenticator.

        Raises:
            EncryptorError: bad magic, zero key length, undecryptable key block or key material of
                unexpected length.
        """
        if self._cipher_key_len is None:
            if data[0:6] != FILEMAGIC:
                raise EncryptorError("Invalid magic bytes")
            cipher_key_len = struct.unpack(">H", data[6:8])[0]
            if cipher_key_len == 0:
                # A zero length would make expected_header_bytes() fall back to 8 and misparse the stream
                raise EncryptorError("Invalid encrypted key length")
            self._cipher_key_len = cipher_key_len
        else:
            pad = padding.OAEP(mgf=padding.MGF1(algorithm=SHA1()),
                               algorithm=SHA1(),
                               label=None)
            try:
                plainkey = self.rsa_private_key.decrypt(data, pad)
            except (AssertionError, ValueError) as ex:
                # cryptography reports a wrong key or malformed ciphertext as ValueError;
                # AssertionError is still caught as the previous implementation did
                raise EncryptorError("Decrypting key data failed") from ex
            if len(plainkey) != 64:
                raise EncryptorError("Integrity check failed")
            key = plainkey[0:16]
            nonce = plainkey[16:32]
            auth_key = plainkey[32:64]
            self._header_size = 8 + len(data)

            self.cipher = Cipher(algorithms.AES(key), modes.CTR(nonce), backend=default_backend()).decryptor()
            self.authenticator = HMAC(auth_key, SHA256(), backend=default_backend())

    def process_data(self, data):
        """Authenticate and decrypt a chunk of ciphertext, returning the plaintext."""
        if not data:
            return b""
        self.authenticator.update(data)
        return self.cipher.update(data)

    def finalize(self, footer):
        """Verify ``footer`` against the HMAC of all processed ciphertext and return any remaining plaintext.

        Raises:
            EncryptorError: if the footer does not match.
        """
        expected = self.authenticator.finalize()
        # Constant-time comparison so the check does not reveal how many leading bytes matched
        if not hmac.compare_digest(footer, expected):
            raise EncryptorError("Integrity check failed")
        result = self.cipher.finalize()
        self.cipher = None
        self.authenticator = None
        return result


class DecryptorFile(FileWrap):
    """Seekable, read-only file wrapper that decrypts data produced by Encryptor from ``next_fp``.

    ``next_fp`` must be seekable: its total size is measured with SEEK_END. Seeking backwards
    restarts decryption from the start of the underlying file.
    """

    def __init__(self, next_fp, rsa_private_key_pem):
        super().__init__(next_fp)
        self._key = rsa_private_key_pem
        self.log = logging.getLogger(self.__class__.__name__)
        self._decryptor = None
        self._crypted_size = None
        self._boundary_block = None
        self._plaintext_size = None
        # Our actual plain-text read offset. seek may change self.offset to something
        # else temporarily but we keep _decrypt_offset intact until we actually do a
        # read in case the caller just called seek in order to then immediately seek back
        self._decrypt_offset = None
        self.offset = None
        self._reset()

    def _reset(self):
        """Restart decryption from scratch with a new Decryptor and the underlying file at offset 0."""
        self._decryptor = Decryptor(self._key)
        # _file_size() also rewinds next_fp to offset 0, where the header is read from
        self._crypted_size = self._file_size(self.next_fp)
        self._boundary_block = None
        self._plaintext_size = None
        self._decrypt_offset = 0
        # Plaintext offset
        self.offset = 0
        self.state = "OPEN"

    @classmethod
    def _file_size(cls, file):
        """Return the size of ``file`` in bytes, leaving it positioned at offset 0.

        _reset() relies on this rewind: decryption always starts by reading the header from the start.
        """
        start_offset = file.seek(0, os.SEEK_SET)
        file_end_offset = file.seek(0, os.SEEK_END)
        file.seek(start_offset, os.SEEK_SET)
        return file_end_offset

    def _initialize_decryptor(self):
        """Parse the header once and compute the plaintext size.

        Raises:
            EncryptorError: if the header is invalid or the data is too short for header and footer.
        """
        if self._plaintext_size is not None:
            return
        if self._crypted_size == 0:
            # The encryptor produces no output at all for empty input: no header, no footer
            self._plaintext_size = 0
            return
        while True:
            required_bytes = self._decryptor.expected_header_bytes()
            if not required_bytes:
                break
            self._decryptor.process_header(self._read_raw_exactly(required_bytes))
        plaintext_size = self._crypted_size - self._decryptor.header_size() - self._decryptor.footer_size()
        if plaintext_size < 0:
            raise EncryptorError(
                "Encrypted data is truncated: {} bytes is too short for header and footer".format(self._crypted_size))
        self._plaintext_size = plaintext_size

    def _read_raw_exactly(self, required_bytes):
        """Read exactly ``required_bytes`` from ``next_fp`` or raise EncryptorError."""
        data = self.next_fp.read(required_bytes)
        while data and len(data) < required_bytes:
            next_chunk = self.next_fp.read(required_bytes - len(data))
            if not next_chunk:
                break
            data += next_chunk
        if not data or len(data) != required_bytes:
            raise EncryptorError("Failed to read {} bytes of header or footer data".format(required_bytes))
        return data

    def _move_decrypt_offset_to_plaintext_offset(self):
        """Advance decryption to ``self.offset``, restarting from the beginning if it lies behind."""
        if self._decrypt_offset == self.offset:
            return
        seek_to = self.offset
        if self._decrypt_offset > self.offset:
            self.log.warning("Negative seek from %d to %d, must re-initialize decryptor",
                             self._decrypt_offset, self.offset)
            self._reset()
            self._initialize_decryptor()
        discard_bytes = seek_to - self._decrypt_offset
        self.offset = self._decrypt_offset
        while discard_bytes > 0:
            data = self._read_block(discard_bytes)
            discard_bytes -= len(data)

    def _read_all(self):
        """Read and return all remaining plaintext."""
        full_data = bytearray()
        while True:
            data = self._read_block(IO_BLOCK_SIZE)
            if not data:
                return bytes(full_data)
            full_data.extend(data)

    def _read_block(self, size):
        """Decrypt and return up to ``size`` plaintext bytes starting at ``self.offset``."""
        # For an empty encrypted file this sets the plaintext size to 0, so the EOF check below applies
        self._initialize_decryptor()

        if self.offset == self._plaintext_size:
            return b""

        self._move_decrypt_offset_to_plaintext_offset()

        # If we have an existing boundary block, fulfil the read entirely from that.
        # Indexing with offset % AES_BLOCK_SIZE assumes the block starts at a block-aligned offset.
        if self._boundary_block:
            size = min(size, len(self._boundary_block) - self.offset % AES_BLOCK_SIZE)
            data = self._boundary_block[self.offset % AES_BLOCK_SIZE:self.offset % AES_BLOCK_SIZE + size]
            if self.offset % AES_BLOCK_SIZE + size == len(self._boundary_block):
                self._boundary_block = None
            data_len = len(data)
            self.offset += data_len
            self._decrypt_offset += data_len
            return data

        # Only serve multiples of AES_BLOCK_SIZE whenever possible to keep things simpler
        read_size = size
        if self.offset + max(AES_BLOCK_SIZE, size) >= self._plaintext_size:
            read_size = self._plaintext_size - self.offset
        elif size > AES_BLOCK_SIZE and size % AES_BLOCK_SIZE != 0 and self.offset + size < self._plaintext_size:
            read_size = size - size % AES_BLOCK_SIZE
        elif size < AES_BLOCK_SIZE:
            read_size = AES_BLOCK_SIZE

        encrypted = self._read_raw_exactly(read_size)
        decrypted = self._decryptor.process_data(encrypted)
        if self.offset + read_size == self._plaintext_size:
            # All ciphertext consumed: the HMAC footer follows and must verify
            footer = self._read_raw_exactly(self._decryptor.footer_size())
            last_part = self._decryptor.finalize(footer)
            if last_part:
                decrypted += last_part

        if size < AES_BLOCK_SIZE:
            # Stash the decrypted chunk (one AES block, or the shorter tail at EOF) and serve the small read from it
            self._boundary_block = decrypted
            return self._read_block(size)
        decrypted_len = len(decrypted)
        self.offset += decrypted_len
        self._decrypt_offset += decrypted_len
        return decrypted

    def close(self):
        """Close the wrapper and drop the decryptor."""
        super().close()
        self._decryptor = None

    def read(self, size=-1):
        """Read up to size decrypted bytes; a negative size or None reads until EOF"""
        self._check_not_closed()
        if self.state == "EOF" or size == 0:
            return b""
        elif size is None or size < 0:
            return self._read_all()
        else:
            return self._read_block(size)

    def readable(self):
        """True if this stream supports reading"""
        self._check_not_closed()
        return self.state in ["OPEN", "EOF"]

    def seek(self, offset, whence=0):
        """Set the plaintext position; only absolute seeks within [0, size] and SEEK_CUR/SEEK_END with 0.

        The actual decryption work is deferred until the next read.
        """
        self._check_not_closed()
        self._initialize_decryptor()
        if whence == os.SEEK_SET:
            if offset != self.offset:
                if offset > self._plaintext_size:
                    raise io.UnsupportedOperation("DecryptorFile does not support seeking beyond EOF")
                if offset < 0:
                    raise ValueError("negative seek position")
                self.offset = offset
            return self.offset
        elif whence == os.SEEK_CUR:
            if offset != 0:
                raise io.UnsupportedOperation("can't do nonzero cur-relative seeks")
            return self.offset
        elif whence == os.SEEK_END:
            if offset != 0:
                raise io.UnsupportedOperation("can't do nonzero end-relative seeks")
            self.offset = self._plaintext_size
            return self.offset
        else:
            raise ValueError("Invalid whence value")

    def seekable(self):
        """True if this stream supports random access"""
        self._check_not_closed()
        return True


class DecryptSink(Sink):
    """Sink that decrypts the encrypted bytes written to it and forwards the plaintext to ``next_sink``.

    ``file_size`` is the total size of the encrypted data and is used to tell where the ciphertext
    ends and the HMAC footer begins.
    """

    def __init__(self, next_sink, file_size, encryption_key_data):
        super().__init__(next_sink)
        if file_size <= 0:
            raise ValueError("Invalid file_size: " + str(file_size))
        self.data_bytes_received = 0
        # Replaced by the ciphertext-only size once the header has been parsed
        self.data_size = file_size
        self.decryptor = Decryptor(encryption_key_data)
        self.file_size = file_size
        self.footer = b""  # HMAC bytes collected so far
        self.header = b""  # incomplete header bytes carried over between writes

    def _extract_encryption_footer_bytes(self, data):
        """Move any bytes beyond the end of the ciphertext into self.footer; return the ciphertext part."""
        remaining = self.data_size - self.data_bytes_received
        if len(data) <= remaining:
            return data
        self.footer += data[remaining:]
        return data[:remaining]

    def _process_encryption_header(self, data):
        """Feed header bytes from ``data`` to the decryptor and return whatever follows the header.

        While the header is still incomplete the unused bytes are kept in self.header and b"" is returned.
        """
        if not data or not self.decryptor.expected_header_bytes():
            return data
        if self.header:
            data = self.header + data
            self.header = None
        position = 0
        needed = self.decryptor.expected_header_bytes()
        while needed > 0:
            if position + needed > len(data):
                self.header = data[position:]
                return b""
            self.decryptor.process_header(data[position:position + needed])
            position += needed
            needed = self.decryptor.expected_header_bytes()
        self.data_size = self.file_size - self.decryptor.header_size() - self.decryptor.footer_size()
        return data[position:]

    def write(self, data):
        """Decrypt ``data`` and pass any resulting plaintext on; always reports all input as consumed."""
        written = len(data)
        payload = self._process_encryption_header(data)
        if not payload:
            return written
        payload = self._extract_encryption_footer_bytes(payload)
        self.data_bytes_received += len(payload)
        if payload:
            payload = self.decryptor.process_data(payload)
        if len(self.footer) == self.decryptor.footer_size():
            tail = self.decryptor.finalize(self.footer)
            if tail:
                payload += tail
        if payload:
            self._write_to_next_sink(payload)
        return written