# Copyright 2013 Cloudbase Solutions Srl
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.


import ctypes
from ctypes import wintypes
import uuid

import six

from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import cryptoapi
from cloudbaseinit.utils import x509constants


malloc = ctypes.cdll.msvcrt.malloc
malloc.restype = ctypes.c_void_p
malloc.argtypes = [ctypes.c_size_t]

free = ctypes.cdll.msvcrt.free
free.restype = None
free.argtypes = [ctypes.c_void_p]

STORE_NAME_MY = "My"
STORE_NAME_ROOT = "Root"
STORE_NAME_TRUSTED_PEOPLE = "TrustedPeople"

X509_START_DATE_INTERVAL = -24 * 60 * 60 * 10000000
X509_END_DATE_INTERVAL = 10 * 365 * 24 * 60 * 60 * 10000000


class CryptoAPICertManager(object):
    @staticmethod
    def _get_thumprint_str(thumbprint, size):
        thumbprint_ar = ctypes.cast(
            thumbprint,
            ctypes.POINTER(ctypes.c_ubyte *
                           size)).contents

        thumbprint_str = ""
        for b in thumbprint_ar:
            thumbprint_str += "%02x" % b
        return thumbprint_str

    @staticmethod
    def _get_thumbprint_buffer(thumbprint_str):
        thumbprint_bytes = encoding.hex_to_bytes(thumbprint_str)
        return ctypes.cast(
            ctypes.create_string_buffer(thumbprint_bytes),
            ctypes.POINTER(wintypes.BYTE *
                           len(thumbprint_bytes))).contents

    def _get_cert_thumprint(self, cert_context_p):
        thumbprint = None

        try:
            thumprint_len = wintypes.DWORD()

            if not cryptoapi.CertGetCertificateContextProperty(
                    cert_context_p,
                    cryptoapi.CERT_SHA1_HASH_PROP_ID,
                    None, ctypes.byref(thumprint_len)):
                raise cryptoapi.CryptoAPIException()

            size = ctypes.c_size_t(thumprint_len.value)
            thumbprint = malloc(size)

            if not cryptoapi.CertGetCertificateContextProperty(
                    cert_context_p,
                    cryptoapi.CERT_SHA1_HASH_PROP_ID,
                    thumbprint, ctypes.byref(thumprint_len)):
                raise cryptoapi.CryptoAPIException()

            return self._get_thumprint_str(thumbprint, thumprint_len.value)
        finally:
            if thumbprint:
                free(thumbprint)

    def _generate_key(self, container_name, machine_keyset):
        crypt_prov_handle = wintypes.HANDLE()
        key_handle = wintypes.HANDLE()

        try:
            flags = 0
            if machine_keyset:
                flags |= cryptoapi.CRYPT_MACHINE_KEYSET

            if not cryptoapi.CryptAcquireContext(
                    ctypes.byref(crypt_prov_handle),
                    container_name,
                    None,
                    cryptoapi.PROV_RSA_FULL,
                    flags):
                flags |= cryptoapi.CRYPT_NEWKEYSET
                if not cryptoapi.CryptAcquireContext(
                        ctypes.byref(crypt_prov_handle),
                        container_name,
                        None,
                        cryptoapi.PROV_RSA_FULL,
                        flags):
                    raise cryptoapi.CryptoAPIException()

            # RSA 2048 bits
            if not cryptoapi.CryptGenKey(crypt_prov_handle,
                                         cryptoapi.AT_KEYEXCHANGE,
                                         0x08000000,
                                         ctypes.byref(key_handle)):
                raise cryptoapi.CryptoAPIException()

            return key_handle
        finally:
            if key_handle:
                cryptoapi.CryptDestroyKey(key_handle)
            if crypt_prov_handle:
                cryptoapi.CryptReleaseContext(crypt_prov_handle, 0)

    @staticmethod
    def _add_system_time_interval(system_time, increment):
        '''increment's unit: 10ns'''
        file_time = cryptoapi.FILETIME()
        if not cryptoapi.SystemTimeToFileTime(ctypes.byref(system_time),
                                              ctypes.byref(file_time)):
            raise cryptoapi.CryptoAPIException()

        t = file_time.dwLowDateTime + (file_time.dwHighDateTime << 32)
        t += increment

        file_time.dwLowDateTime = t & 0xFFFFFFFF
        file_time.dwHighDateTime = t >> 32 & 0xFFFFFFFF

        new_system_time = cryptoapi.SYSTEMTIME()
        if not cryptoapi.FileTimeToSystemTime(ctypes.byref(file_time),
                                              ctypes.byref(new_system_time)):
            raise cryptoapi.CryptoAPIException()
        return new_system_time

    def create_self_signed_cert(self, subject, validity_years=10,
                                machine_keyset=True, store_name=STORE_NAME_MY):
        subject_encoded = None
        cert_context_p = None
        store_handle = None

        container_name = str(uuid.uuid4())
        self._generate_key(container_name, machine_keyset)

        try:
            subject_encoded_len = wintypes.DWORD()

            if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
                                           subject,
                                           cryptoapi.CERT_X500_NAME_STR, None,
                                           None,
                                           ctypes.byref(subject_encoded_len),
                                           None):
                raise cryptoapi.CryptoAPIException()

            size = ctypes.c_size_t(subject_encoded_len.value)
            subject_encoded = ctypes.cast(malloc(size),
                                          ctypes.POINTER(wintypes.BYTE))

            if not cryptoapi.CertStrToName(cryptoapi.X509_ASN_ENCODING,
                                           subject,
                                           cryptoapi.CERT_X500_NAME_STR, None,
                                           subject_encoded,
                                           ctypes.byref(subject_encoded_len),
                                           None):
                raise cryptoapi.CryptoAPIException()

            subject_blob = cryptoapi.CRYPTOAPI_BLOB()
            subject_blob.cbData = subject_encoded_len
            subject_blob.pbData = subject_encoded

            key_prov_info = cryptoapi.CRYPT_KEY_PROV_INFO()
            key_prov_info.pwszContainerName = container_name
            key_prov_info.pwszProvName = None
            key_prov_info.dwProvType = cryptoapi.PROV_RSA_FULL
            key_prov_info.cProvParam = None
            key_prov_info.rgProvParam = None
            key_prov_info.dwKeySpec = cryptoapi.AT_KEYEXCHANGE

            if machine_keyset:
                key_prov_info.dwFlags = cryptoapi.CRYPT_MACHINE_KEYSET
            else:
                key_prov_info.dwFlags = 0

            sign_alg = cryptoapi.CRYPT_ALGORITHM_IDENTIFIER()
            sign_alg.pszObjId = cryptoapi.szOID_RSA_SHA1RSA

            start_time = cryptoapi.SYSTEMTIME()
            cryptoapi.GetSystemTime(ctypes.byref(start_time))

            end_time = self._add_system_time_interval(
                start_time, X509_END_DATE_INTERVAL)

            # Needed in case of time sync issues as PowerShell remoting
            # enforces a valid time interval even for self signed certificates
            start_time = self._add_system_time_interval(
                start_time, X509_START_DATE_INTERVAL)

            cert_context_p = cryptoapi.CertCreateSelfSignCertificate(
                None, ctypes.byref(subject_blob), 0,
                ctypes.byref(key_prov_info),
                ctypes.byref(sign_alg), ctypes.byref(start_time),
                ctypes.byref(end_time), None)
            if not cert_context_p:
                raise cryptoapi.CryptoAPIException()

            if not cryptoapi.CertAddEnhancedKeyUsageIdentifier(
                    cert_context_p, cryptoapi.szOID_PKIX_KP_SERVER_AUTH):
                raise cryptoapi.CryptoAPIException()

            if machine_keyset:
                flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
            else:
                flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER

            store_handle = cryptoapi.CertOpenStore(
                cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
                six.text_type(store_name))
            if not store_handle:
                raise cryptoapi.CryptoAPIException()

            if not cryptoapi.CertAddCertificateContextToStore(
                    store_handle, cert_context_p,
                    cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None):
                raise cryptoapi.CryptoAPIException()

            return (self._get_cert_thumprint(cert_context_p),
                    self._get_cert_str(cert_context_p))

        finally:
            if store_handle:
                cryptoapi.CertCloseStore(store_handle, 0)
            if cert_context_p:
                cryptoapi.CertFreeCertificateContext(cert_context_p)
            if subject_encoded:
                free(subject_encoded)

    def _get_cert_str(self, cert_context_p):
        ch_cer_str = wintypes.DWORD(0)
        if not cryptoapi.CryptBinaryToString(
                cert_context_p.contents.pbCertEncoded,
                cert_context_p.contents.cbCertEncoded,
                cryptoapi.CRYPT_STRING_BASE64,
                None, ctypes.byref(ch_cer_str)):
            raise cryptoapi.CryptoAPIException()

        cer_str = ctypes.create_unicode_buffer(ch_cer_str.value)
        if not cryptoapi.CryptBinaryToString(
                cert_context_p.contents.pbCertEncoded,
                cert_context_p.contents.cbCertEncoded,
                cryptoapi.CRYPT_STRING_BASE64,
                cer_str,
                ctypes.byref(ch_cer_str)):
            raise cryptoapi.CryptoAPIException()

        return cer_str.value

    def _get_cert_base64(self, cert_data):
        """Remove certificate header and footer and also new lines."""
        # It's assured that the certificate is already a string.
        removal = [
            x509constants.PEM_HEADER,
            x509constants.PEM_FOOTER,
            "\r",
            "\n"
        ]
        for remove in removal:
            cert_data = cert_data.replace(remove, "")
        return cert_data

    def _find_certificate_in_store(self, thumbprint_str, machine_keyset=True,
                                   store_name=STORE_NAME_MY):
        store_handle = None

        thumbprint = self._get_thumbprint_buffer(thumbprint_str)
        hash_blob = cryptoapi.CRYPTOAPI_BLOB()
        hash_blob.cbData = len(thumbprint)
        hash_blob.pbData = thumbprint

        try:
            flags = cryptoapi.CERT_STORE_OPEN_EXISTING_FLAG
            if machine_keyset:
                flags |= cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
            else:
                flags |= cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER

            store_handle = cryptoapi.CertOpenStore(
                cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
                six.text_type(store_name))
            if not store_handle:
                raise cryptoapi.CryptoAPIException()

            cert_context_p = cryptoapi.CertFindCertificateInStore(
                store_handle,
                cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
                0,
                cryptoapi.CERT_FIND_SHA1_HASH,
                ctypes.pointer(hash_blob),
                None)
            if not cert_context_p:
                raise cryptoapi.CryptoAPIException()

            return cert_context_p
        finally:
            if store_handle:
                cryptoapi.CertCloseStore(store_handle, 0)

    def delete_certificate_from_store(self, thumbprint_str,
                                      machine_keyset=True,
                                      store_name=STORE_NAME_MY):
        cert_context_p = None

        try:
            cert_context_p = self._find_certificate_in_store(
                thumbprint_str, machine_keyset, store_name)
            if not cert_context_p:
                raise cryptoapi.CryptoAPIException()

            if not cryptoapi.CertDeleteCertificateFromStore(cert_context_p):
                raise cryptoapi.CryptoAPIException()
        finally:
            if cert_context_p:
                cryptoapi.CertFreeCertificateContext(cert_context_p)

    def import_pfx_certificate(self, pfx_data, pfx_password=None,
                               machine_keyset=True, store_name=STORE_NAME_MY):
        cert_context_p = None
        import_store_handle = None
        store_handle = None

        try:
            pfx_blob = cryptoapi.CRYPTOAPI_BLOB()
            pfx_blob.cbData = len(pfx_data)
            pfx_blob.pbData = ctypes.cast(
                pfx_data, ctypes.POINTER(wintypes.BYTE))

            import_store_handle = cryptoapi.PFXImportCertStore(
                ctypes.pointer(pfx_blob), pfx_password, 0)
            if not import_store_handle:
                raise cryptoapi.CryptoAPIException()

            cert_context_p = cryptoapi.CertFindCertificateInStore(
                import_store_handle,
                cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
                0, cryptoapi.CERT_FIND_ANY, None, None)
            if not cert_context_p:
                raise cryptoapi.CryptoAPIException()

            if machine_keyset:
                flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
            else:
                flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER

            store_handle = cryptoapi.CertOpenStore(
                cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
                six.text_type(store_name))
            if not store_handle:
                raise cryptoapi.CryptoAPIException()

            if not cryptoapi.CertAddCertificateContextToStore(
                    store_handle, cert_context_p,
                    cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING, None):
                raise cryptoapi.CryptoAPIException()

        finally:
            if import_store_handle:
                cryptoapi.CertCloseStore(import_store_handle, 0)
            if cert_context_p:
                cryptoapi.CertFreeCertificateContext(cert_context_p)
            if store_handle:
                cryptoapi.CertCloseStore(store_handle, 0)

    def decode_pkcs7_base64_blob(self, data, thumbprint_str,
                                 machine_keyset=True,
                                 store_name=STORE_NAME_MY):
        base64_data = data.replace('\r', '').replace('\n', '')
        store_handle = None
        cert_context_p = None

        try:
            data_encoded_len = wintypes.DWORD()

            if not cryptoapi.CryptStringToBinaryW(
                    base64_data, len(base64_data),
                    cryptoapi.CRYPT_STRING_BASE64,
                    None, ctypes.byref(data_encoded_len),
                    None, None):
                raise cryptoapi.CryptoAPIException()

            data_encoded = ctypes.cast(
                ctypes.create_string_buffer(data_encoded_len.value),
                ctypes.POINTER(wintypes.BYTE))

            if not cryptoapi.CryptStringToBinaryW(
                    base64_data, len(base64_data),
                    cryptoapi.CRYPT_STRING_BASE64,
                    data_encoded, ctypes.byref(data_encoded_len),
                    None, None):
                raise cryptoapi.CryptoAPIException()

            store_handle = cryptoapi.CertOpenStore(
                cryptoapi.CERT_STORE_PROV_MEMORY,
                cryptoapi.X509_ASN_ENCODING | cryptoapi.PKCS_7_ASN_ENCODING,
                None, cryptoapi.CERT_STORE_CREATE_NEW_FLAG, None)
            if not store_handle:
                raise cryptoapi.CryptoAPIException()

            cert_context_p = self._find_certificate_in_store(
                thumbprint_str, machine_keyset, store_name)

            if not cryptoapi.CertAddCertificateLinkToStore(
                    store_handle, cert_context_p,
                    cryptoapi.CERT_STORE_ADD_NEW, None):
                raise cryptoapi.CryptoAPIException()

            para = cryptoapi.CRYPT_DECRYPT_MESSAGE_PARA()
            para.cbSize = ctypes.sizeof(cryptoapi.CRYPT_DECRYPT_MESSAGE_PARA)
            para.dwMsgAndCertEncodingType = (cryptoapi.X509_ASN_ENCODING |
                                             cryptoapi.PKCS_7_ASN_ENCODING)
            para.cCertStore = 1
            para.rghCertStore = ctypes.pointer(wintypes.HANDLE(store_handle))
            para.dwFlags = cryptoapi.CRYPT_SILENT

            data_decoded_len = wintypes.DWORD()
            if not cryptoapi.CryptDecryptMessage(
                    ctypes.byref(para),
                    data_encoded,
                    data_encoded_len,
                    None,
                    ctypes.byref(data_decoded_len),
                    None):
                raise cryptoapi.CryptoAPIException()

            data_decoded_buf = ctypes.create_string_buffer(
                data_decoded_len.value)
            data_decoded = ctypes.cast(
                data_decoded_buf, ctypes.POINTER(wintypes.BYTE))

            if not cryptoapi.CryptDecryptMessage(
                    ctypes.pointer(para),
                    data_encoded,
                    data_encoded_len,
                    data_decoded,
                    ctypes.byref(data_decoded_len),
                    None):
                raise cryptoapi.CryptoAPIException()

            return bytes(data_decoded_buf)
        finally:
            if cert_context_p:
                cryptoapi.CertFreeCertificateContext(cert_context_p)
            if store_handle:
                cryptoapi.CertCloseStore(store_handle, 0)

    def import_cert(self, cert_data, machine_keyset=True,
                    store_name=STORE_NAME_MY):

        base64_cert_data = self._get_cert_base64(cert_data)

        cert_encoded = None
        store_handle = None
        cert_context_p = None

        try:
            cert_encoded_len = wintypes.DWORD()

            if not cryptoapi.CryptStringToBinaryW(
                    base64_cert_data, len(base64_cert_data),
                    cryptoapi.CRYPT_STRING_BASE64,
                    None, ctypes.byref(cert_encoded_len),
                    None, None):
                raise cryptoapi.CryptoAPIException()

            size = ctypes.c_size_t(cert_encoded_len.value)
            cert_encoded = ctypes.cast(malloc(size),
                                       ctypes.POINTER(wintypes.BYTE))

            if not cryptoapi.CryptStringToBinaryW(
                    base64_cert_data, len(base64_cert_data),
                    cryptoapi.CRYPT_STRING_BASE64,
                    cert_encoded, ctypes.byref(cert_encoded_len),
                    None, None):
                raise cryptoapi.CryptoAPIException()

            if machine_keyset:
                flags = cryptoapi.CERT_SYSTEM_STORE_LOCAL_MACHINE
            else:
                flags = cryptoapi.CERT_SYSTEM_STORE_CURRENT_USER

            store_handle = cryptoapi.CertOpenStore(
                cryptoapi.CERT_STORE_PROV_SYSTEM, 0, 0, flags,
                six.text_type(store_name))
            if not store_handle:
                raise cryptoapi.CryptoAPIException()

            cert_context_p = ctypes.POINTER(cryptoapi.CERT_CONTEXT)()

            if not cryptoapi.CertAddEncodedCertificateToStore(
                    store_handle,
                    cryptoapi.X509_ASN_ENCODING |
                    cryptoapi.PKCS_7_ASN_ENCODING,
                    cert_encoded, cert_encoded_len,
                    cryptoapi.CERT_STORE_ADD_REPLACE_EXISTING,
                    ctypes.byref(cert_context_p)):
                raise cryptoapi.CryptoAPIException()

            # Get the UPN (1.3.6.1.4.1.311.20.2.3 OID) from the
            # certificate subject alt name
            upn = None
            upn_len = cryptoapi.CertGetNameString(
                cert_context_p,
                cryptoapi.CERT_NAME_UPN_TYPE, 0,
                None, None, 0)
            if upn_len > 1:
                upn_ar = ctypes.create_unicode_buffer(upn_len)
                if cryptoapi.CertGetNameString(
                        cert_context_p,
                        cryptoapi.CERT_NAME_UPN_TYPE,
                        0, None, upn_ar, upn_len) != upn_len:
                    raise cryptoapi.CryptoAPIException()
                upn = upn_ar.value

            thumbprint = self._get_cert_thumprint(cert_context_p)
            return thumbprint, upn
        finally:
            if cert_context_p:
                cryptoapi.CertFreeCertificateContext(cert_context_p)
            if store_handle:
                cryptoapi.CertCloseStore(store_handle, 0)
            if cert_encoded:
                free(cert_encoded)