from __future__ import division

import math

import six

try:
    from ecdsa import SigningKey as EcdsaSigningKey, VerifyingKey as EcdsaVerifyingKey
except ImportError:
    EcdsaSigningKey = EcdsaVerifyingKey = None

from jose.backends.base import Key
from jose.utils import base64_to_long, long_to_base64
from jose.constants import ALGORITHMS
from jose.exceptions import JWKError

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec, rsa, padding
from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature, encode_dss_signature
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from cryptography.utils import int_from_bytes, int_to_bytes
from cryptography.x509 import load_pem_x509_certificate


class CryptographyECKey(Key):
    """Elliptic-curve JWK implemented on top of the ``cryptography`` package.

    The wrapped ``cryptography`` key object is stored in ``prepared_key``.
    Signatures produced by :meth:`sign` and consumed by :meth:`verify` use the
    raw ``r || s`` fixed-width encoding, not DER.
    """

    # Hash constructors selectable through the ES256 / ES384 / ES512 algorithms.
    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Build an EC key from one of several input representations.

        Args:
            key: One of
                * an object exposing ``public_bytes`` or ``private_bytes``
                  (used as-is),
                * an ``ecdsa`` ``SigningKey`` / ``VerifyingKey`` (converted
                  to PEM first, when the ``ecdsa`` package is importable),
                * a JWK ``dict`` (see :meth:`_process_jwk`),
                * a PEM document as text or bytes.
            algorithm: A member of ``ALGORITHMS.EC``.
            cryptography_backend: Zero-argument callable returning the
                backend handed to ``cryptography`` loaders.

        Raises:
            JWKError: If ``algorithm`` is not an EC algorithm, if PEM loading
                fails, or if ``key`` is of an unsupported type.
        """
        if algorithm not in ALGORITHMS.EC:
            raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)

        self.hash_alg = {
            ALGORITHMS.ES256: self.SHA256,
            ALGORITHMS.ES384: self.SHA384,
            ALGORITHMS.ES512: self.SHA512
        }.get(algorithm)
        self._algorithm = algorithm

        self.cryptography_backend = cryptography_backend

        # Duck-typed check: anything that can serialise itself as a public or
        # private key is taken to be an already-loaded key object.
        if hasattr(key, 'public_bytes') or hasattr(key, 'private_bytes'):
            self.prepared_key = key
            return

        if None not in (EcdsaSigningKey, EcdsaVerifyingKey) and isinstance(key, (EcdsaSigningKey, EcdsaVerifyingKey)):
            # convert to PEM and let cryptography below load it as PEM
            key = key.to_pem().decode('utf-8')

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, six.string_types):
            key = key.encode('utf-8')

        if isinstance(key, six.binary_type):
            # Attempt to load key. We don't know if it's
            # a Public Key or a Private Key, so we try
            # the Public Key first.
            try:
                try:
                    key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                raise JWKError(e)

            self.prepared_key = key
            return

        raise JWKError('Unable to parse an ECKey from key: %s' % key)

    def _process_jwk(self, jwk_dict):
        """Turn a JWK dictionary into a ``cryptography`` EC key object.

        Args:
            jwk_dict: Mapping with ``kty == 'EC'`` and the ``x``, ``y`` and
                ``crv`` members; an optional ``d`` member makes it a private
                key.

        Returns:
            A private key when ``d`` is present, otherwise a public key.

        Raises:
            JWKError: If ``kty`` is not ``'EC'`` or a mandatory member is
                missing.
            KeyError: If ``crv`` is not one of ``P-256``, ``P-384``, ``P-521``.
        """
        if not jwk_dict.get('kty') == 'EC':
            raise JWKError("Incorrect key type. Expected: 'EC', Received: %s" % jwk_dict.get('kty'))

        if not all(k in jwk_dict for k in ['x', 'y', 'crv']):
            raise JWKError('Mandatory parameters are missing')

        x = base64_to_long(jwk_dict.get('x'))
        y = base64_to_long(jwk_dict.get('y'))
        # NOTE(review): an unsupported 'crv' surfaces as a bare KeyError here
        # rather than a JWKError; kept as-is so callers see the same exception.
        curve = {
            'P-256': ec.SECP256R1,
            'P-384': ec.SECP384R1,
            'P-521': ec.SECP521R1,
        }[jwk_dict['crv']]

        public = ec.EllipticCurvePublicNumbers(x, y, curve())

        if 'd' in jwk_dict:
            d = base64_to_long(jwk_dict.get('d'))
            private = ec.EllipticCurvePrivateNumbers(d, public)

            return private.private_key(self.cryptography_backend())
        else:
            return public.public_key(self.cryptography_backend())

    def _sig_component_length(self):
        """Determine the correct serialization length for an encoded signature component.

        This is the number of bytes required to encode the maximum key value.
        """
        return int(math.ceil(self.prepared_key.key_size / 8.0))

    def _der_to_raw(self, der_signature):
        """Convert signature from DER encoding to RAW encoding.

        The raw form is ``r`` followed by ``s``, each padded to
        :meth:`_sig_component_length` bytes.
        """
        r, s = decode_dss_signature(der_signature)
        component_length = self._sig_component_length()
        return int_to_bytes(r, component_length) + int_to_bytes(s, component_length)

    def _raw_to_der(self, raw_signature):
        """Convert signature from RAW encoding to DER encoding.

        Raises:
            ValueError: If ``raw_signature`` is not exactly two components
                long.
        """
        component_length = self._sig_component_length()
        if len(raw_signature) != 2 * component_length:
            raise ValueError("Invalid signature")

        r_bytes = raw_signature[:component_length]
        s_bytes = raw_signature[component_length:]
        r = int_from_bytes(r_bytes, "big")
        s = int_from_bytes(s_bytes, "big")
        return encode_dss_signature(r, s)

    def sign(self, msg):
        """Sign ``msg`` with ECDSA and return the raw ``r || s`` signature.

        Raises:
            TypeError: If the digest is longer (in bits) than the curve's
                key size.
        """
        if self.hash_alg.digest_size * 8 > self.prepared_key.curve.key_size:
            raise TypeError("this curve (%s) is too short "
                            "for your digest (%d)" % (self.prepared_key.curve.name,
                                                      8 * self.hash_alg.digest_size))
        signature = self.prepared_key.sign(msg, ec.ECDSA(self.hash_alg()))
        return self._der_to_raw(signature)

    def verify(self, msg, sig):
        """Check a raw ``r || s`` signature over ``msg``.

        Works for both public and private keys: a private key is verified
        through its public half.

        Returns:
            ``True`` if the signature is valid, ``False`` on any failure
            (malformed signature included).
        """
        try:
            signature = self._raw_to_der(sig)
            verifier = self.prepared_key
            if not self.is_public():
                # cryptography's private-key objects can sign but have no
                # verify(); calling it directly raised AttributeError, which
                # the handler below turned into a silent False even for
                # valid signatures.
                verifier = verifier.public_key()
            verifier.verify(signature, msg, ec.ECDSA(self.hash_alg()))
            return True
        except Exception:
            # Deliberately broad: any failure means "not verified".
            return False

    def is_public(self):
        """Return whether the wrapped key exposes ``public_bytes``."""
        return hasattr(self.prepared_key, 'public_bytes')

    def public_key(self):
        """Return this key if it is public, else a new wrapper around its public half."""
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.public_key(), self._algorithm)

    def to_pem(self):
        """Serialise the key as PEM.

        Public keys use SubjectPublicKeyInfo; private keys use the
        unencrypted TraditionalOpenSSL format.
        """
        if self.is_public():
            pem = self.prepared_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
            return pem
        pem = self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption()
        )
        return pem

    def to_dict(self):
        """Export the key as a JWK dictionary.

        Returns:
            A dict with ``alg``, ``kty``, ``crv``, ``x`` and ``y``; private
            keys additionally carry ``d``.

        Raises:
            KeyError: If the key's curve is not secp256r1, secp384r1 or
                secp521r1.
        """
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        crv = {
            'secp256r1': 'P-256',
            'secp384r1': 'P-384',
            'secp521r1': 'P-521',
        }[self.prepared_key.curve.name]

        # Calculate the key size in bytes. Section 6.2.1.2 and 6.2.1.3 of
        # RFC7518 prescribes that the 'x', 'y' and 'd' parameters of the curve
        # points must be encoded as octet strings of this length.
        key_size = (self.prepared_key.curve.key_size + 7) // 8

        # Fetch the public numbers once and reuse them for both coordinates.
        public_numbers = public_key.public_numbers()

        data = {
            'alg': self._algorithm,
            'kty': 'EC',
            'crv': crv,
            'x': long_to_base64(public_numbers.x, size=key_size).decode('ASCII'),
            'y': long_to_base64(public_numbers.y, size=key_size).decode('ASCII'),
        }

        if not self.is_public():
            data['d'] = long_to_base64(
                self.prepared_key.private_numbers().private_value,
                size=key_size
            ).decode('ASCII')

        return data


class CryptographyRSAKey(Key):
    """RSA JWK implemented on top of the ``cryptography`` package.

    The wrapped ``cryptography`` key object is stored in ``prepared_key``.
    Signing and verification use PKCS#1 v1.5 padding with the hash selected
    by the algorithm.
    """

    # Hash constructors selectable through the RS256 / RS384 / RS512 algorithms.
    SHA256 = hashes.SHA256
    SHA384 = hashes.SHA384
    SHA512 = hashes.SHA512

    def __init__(self, key, algorithm, cryptography_backend=default_backend):
        """Build an RSA key from one of several input representations.

        Args:
            key: One of
                * an object exposing ``public_bytes`` and ``public_numbers``
                  (a public key object, used as-is),
                * an object exposing ``private_bytes`` and
                  ``private_numbers`` (a private key object, used as-is),
                * a JWK ``dict`` (see :meth:`_process_jwk`),
                * a PEM key or PEM X.509 certificate as text or bytes.
            algorithm: A member of ``ALGORITHMS.RSA``.
            cryptography_backend: Zero-argument callable returning the
                backend handed to ``cryptography`` loaders.

        Raises:
            JWKError: If ``algorithm`` is not an RSA algorithm, if PEM or
                certificate loading fails, or if ``key`` is of an
                unsupported type.
        """
        if algorithm not in ALGORITHMS.RSA:
            raise JWKError('hash_alg: %s is not a valid hash algorithm' % algorithm)

        self.hash_alg = {
            ALGORITHMS.RS256: self.SHA256,
            ALGORITHMS.RS384: self.SHA384,
            ALGORITHMS.RS512: self.SHA512
        }.get(algorithm)
        self._algorithm = algorithm

        self.cryptography_backend = cryptography_backend

        # if it conforms to RSAPublicKey interface
        if hasattr(key, 'public_bytes') and hasattr(key, 'public_numbers'):
            self.prepared_key = key
            return

        # if it conforms to RSAPrivateKey interface; without this branch an
        # already-loaded private key object was rejected as unparseable,
        # unlike in CryptographyECKey.
        if hasattr(key, 'private_bytes') and hasattr(key, 'private_numbers'):
            self.prepared_key = key
            return

        if isinstance(key, dict):
            self.prepared_key = self._process_jwk(key)
            return

        if isinstance(key, six.string_types):
            key = key.encode('utf-8')

        if isinstance(key, six.binary_type):
            try:
                if key.startswith(b'-----BEGIN CERTIFICATE-----'):
                    self._process_cert(key)
                    return

                # Public or private is unknown up front: try public first.
                try:
                    self.prepared_key = load_pem_public_key(key, self.cryptography_backend())
                except ValueError:
                    self.prepared_key = load_pem_private_key(key, password=None, backend=self.cryptography_backend())
            except Exception as e:
                raise JWKError(e)
            return

        raise JWKError('Unable to parse an RSA_JWK from key: %s' % key)

    def _process_jwk(self, jwk_dict):
        """Turn a JWK dictionary into a ``cryptography`` RSA key object.

        Args:
            jwk_dict: Mapping with ``kty == 'RSA'`` plus ``n`` and ``e``.
                A ``d`` member makes it a private key; the CRT members
                ``p``, ``q``, ``dp``, ``dq``, ``qi`` must then be either all
                present or all absent (they are derived when absent).

        Returns:
            A private key when ``d`` is present, otherwise a public key.

        Raises:
            JWKError: If ``kty`` is not ``'RSA'`` or only some of the CRT
                members are supplied.
        """
        if not jwk_dict.get('kty') == 'RSA':
            raise JWKError("Incorrect key type. Expected: 'RSA', Received: %s" % jwk_dict.get('kty'))

        # NOTE(review): the fallback for a missing 'e' is the int 256, which
        # is passed to a base64 decoder -- looks unintended; confirm against
        # base64_to_long before relying on it.
        e = base64_to_long(jwk_dict.get('e', 256))
        n = base64_to_long(jwk_dict.get('n'))
        public = rsa.RSAPublicNumbers(e, n)

        if 'd' not in jwk_dict:
            return public.public_key(self.cryptography_backend())
        else:
            # This is a private key.
            d = base64_to_long(jwk_dict.get('d'))

            extra_params = ['p', 'q', 'dp', 'dq', 'qi']

            if any(k in jwk_dict for k in extra_params):
                # Precomputed private key parameters are available.
                if not all(k in jwk_dict for k in extra_params):
                    # These values must be present when 'p' is according to
                    # Section 6.3.2 of RFC7518, so if they are not we raise
                    # an error.
                    raise JWKError('Precomputed private key parameters are incomplete.')

                p = base64_to_long(jwk_dict['p'])
                q = base64_to_long(jwk_dict['q'])
                dp = base64_to_long(jwk_dict['dp'])
                dq = base64_to_long(jwk_dict['dq'])
                qi = base64_to_long(jwk_dict['qi'])
            else:
                # The precomputed private key parameters are not available,
                # so we use cryptography's API to fill them in.
                p, q = rsa.rsa_recover_prime_factors(n, e, d)
                dp = rsa.rsa_crt_dmp1(d, p)
                dq = rsa.rsa_crt_dmq1(d, q)
                qi = rsa.rsa_crt_iqmp(p, q)

            private = rsa.RSAPrivateNumbers(p, q, d, dp, dq, qi, public)

            return private.private_key(self.cryptography_backend())

    def _process_cert(self, key):
        """Load a PEM X.509 certificate and keep its public key as ``prepared_key``."""
        key = load_pem_x509_certificate(key, self.cryptography_backend())
        self.prepared_key = key.public_key()

    def sign(self, msg):
        """Sign ``msg`` with PKCS#1 v1.5 padding and the configured hash.

        Raises:
            JWKError: Wrapping any exception raised while signing.
        """
        try:
            signature = self.prepared_key.sign(
                msg,
                padding.PKCS1v15(),
                self.hash_alg()
            )
        except Exception as e:
            raise JWKError(e)
        return signature

    def verify(self, msg, sig):
        """Check a PKCS#1 v1.5 signature over ``msg``.

        Works for both public and private keys: a private key is verified
        through its public half.

        Returns:
            ``True`` if the signature is valid, ``False`` if
            ``InvalidSignature`` is raised.
        """
        verifier = self.prepared_key
        if not self.is_public():
            # cryptography's private-key objects can sign but have no
            # verify(); calling it directly raised AttributeError.
            verifier = verifier.public_key()

        try:
            verifier.verify(
                sig,
                msg,
                padding.PKCS1v15(),
                self.hash_alg()
            )
            return True
        except InvalidSignature:
            return False

    def is_public(self):
        """Return whether the wrapped key exposes ``public_bytes``."""
        return hasattr(self.prepared_key, 'public_bytes')

    def public_key(self):
        """Return this key if it is public, else a new wrapper around its public half."""
        if self.is_public():
            return self
        return self.__class__(self.prepared_key.public_key(), self._algorithm)

    def to_pem(self, pem_format='PKCS8'):
        """Serialise the key as PEM.

        Args:
            pem_format: ``'PKCS8'`` (SubjectPublicKeyInfo for public keys,
                PKCS8 for private keys) or ``'PKCS1'`` (PKCS1 for public
                keys, TraditionalOpenSSL for private keys). Private keys are
                written unencrypted.

        Raises:
            ValueError: If ``pem_format`` is neither of the two names.
        """
        if self.is_public():
            if pem_format == 'PKCS8':
                fmt = serialization.PublicFormat.SubjectPublicKeyInfo
            elif pem_format == 'PKCS1':
                fmt = serialization.PublicFormat.PKCS1
            else:
                raise ValueError("Invalid format specified: %r" % pem_format)
            pem = self.prepared_key.public_bytes(
                encoding=serialization.Encoding.PEM,
                format=fmt
            )
            return pem

        if pem_format == 'PKCS8':
            fmt = serialization.PrivateFormat.PKCS8
        elif pem_format == 'PKCS1':
            fmt = serialization.PrivateFormat.TraditionalOpenSSL
        else:
            raise ValueError("Invalid format specified: %r" % pem_format)

        return self.prepared_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=fmt,
            encryption_algorithm=serialization.NoEncryption()
        )

    def to_dict(self):
        """Export the key as a JWK dictionary.

        Returns:
            A dict with ``alg``, ``kty``, ``n`` and ``e``; private keys
            additionally carry ``d``, ``p``, ``q``, ``dp``, ``dq`` and ``qi``.
        """
        if not self.is_public():
            public_key = self.prepared_key.public_key()
        else:
            public_key = self.prepared_key

        # Fetch the number bundles once instead of once per exported field.
        public_numbers = public_key.public_numbers()

        data = {
            'alg': self._algorithm,
            'kty': 'RSA',
            'n': long_to_base64(public_numbers.n).decode('ASCII'),
            'e': long_to_base64(public_numbers.e).decode('ASCII'),
        }

        if not self.is_public():
            private_numbers = self.prepared_key.private_numbers()
            data.update({
                'd': long_to_base64(private_numbers.d).decode('ASCII'),
                'p': long_to_base64(private_numbers.p).decode('ASCII'),
                'q': long_to_base64(private_numbers.q).decode('ASCII'),
                'dp': long_to_base64(private_numbers.dmp1).decode('ASCII'),
                'dq': long_to_base64(private_numbers.dmq1).decode('ASCII'),
                'qi': long_to_base64(private_numbers.iqmp).decode('ASCII'),
            })

        return data