# Copyright 2018 RethinkDB
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file incorporates work covered by the following copyright:
# Copyright 2010-2016 RethinkDB, all rights reserved.

import base64
import binascii
import hashlib
import hmac
import struct
import sys
import threading
from random import SystemRandom

import six

from rethinkdb import ql2_pb2
from rethinkdb.errors import ReqlAuthError, ReqlDriverError
from rethinkdb.helpers import chain_to_bytes, decode_utf8
from rethinkdb.logger import default_logger

# ``xrange`` only exists as a builtin on Python 2; alias it to ``range`` on
# Python 3 so the rest of the module can use a single name.
if sys.version_info[0] >= 3:
    xrange = range


def compare_digest(digest_a, digest_b):
    """
    Compare two digests without stopping early at the first difference.

    Pure-Python fallback for ``hmac.compare_digest`` on interpreters that lack it.
    The loop always walks a whole digest and accumulates differences with
    XOR/OR, instead of returning at the first mismatching element.

    Elements may be ints (from iterating ``bytes`` on Python 3, or ``bytearray``)
    or one-character strings (from iterating ``str``). Both byte and text digests
    are therefore accepted on any Python version.

    :param digest_a: First digest
    :param digest_b: Second digest
    :return: True if the digests are equal, False otherwise
    """

    def as_int(item):
        # Normalise by element type rather than by interpreter version so that
        # bytes, bytearray and str all work everywhere.
        return item if isinstance(item, int) else ord(item)

    if len(digest_a) == len(digest_b):
        left = digest_a
        result = 0
    else:
        # Lengths differ: compare digest_b with itself so the loop still runs
        # over a full digest. The preset 1 guarantees a False result.
        left = digest_b
        result = 1

    for left_item, right_item in zip(left, digest_b):
        result |= as_int(left_item) ^ as_int(right_item)

    return result == 0


def pbkdf2_hmac(hash_name, password, salt, iterations):
    """
    Pure-Python PBKDF2-HMAC-SHA256, used when ``hashlib.pbkdf2_hmac`` is missing.

    Only ``sha256`` is supported, and exactly one 32-byte output block (block
    index 1) is derived. Results are memoised in ``HandshakeV1_0.PBKDF2_CACHE``,
    a per-thread cache keyed by ``(password, salt, iterations)``.

    :param hash_name: Name of the hash; must be ``"sha256"``
    :param password: Password as bytes (used as the HMAC key)
    :param salt: Salt as bytes
    :param iterations: Number of HMAC iterations; must be at least 1
    :raises AssertionError: If ``hash_name`` is not ``"sha256"``
    :raises ValueError: If ``iterations`` is less than 1
    :return: The 32-byte derived key
    """
    if hash_name != "sha256":
        raise AssertionError(
            'Hash name {hash_name} is not equal with "sha256"'.format(
                hash_name=hash_name
            )
        )

    if iterations < 1:
        # Mirror hashlib.pbkdf2_hmac. Otherwise the loop below would run zero
        # times, return U1 as if iterations were 1, and cache that result.
        raise ValueError("iteration value must be greater than 0.")

    def from_bytes(value):
        # Big-endian bytes -> int, written to also work on Python 2.
        return int(binascii.hexlify(value), 16)

    def to_bytes(value):
        # int -> 32 big-endian bytes (64 hex digits); bytes() takes no encoding
        # argument on Python 2, hence the fallback.
        try:
            return binascii.unhexlify(bytes("%064x" % value, "ascii"))
        except TypeError:
            return binascii.unhexlify(bytes("%064x" % value))

    cache_key = (password, salt, iterations)

    cache_result = HandshakeV1_0.PBKDF2_CACHE.get(cache_key)

    if cache_result is not None:
        return cache_result

    mac = hmac.new(password, None, hashlib.sha256)

    def digest(msg):
        # Reuse the keyed HMAC state instead of re-keying on every iteration.
        mac_copy = mac.copy()
        mac_copy.update(msg)
        return mac_copy.digest()

    # U1 = HMAC(P, S || INT(1)); the result is U1 ^ U2 ^ ... ^ Uc.
    t = digest(salt + b"\x00\x00\x00\x01")
    u = from_bytes(t)
    for _ in xrange(iterations - 1):
        t = digest(t)
        u ^= from_bytes(t)

    u = to_bytes(u)
    HandshakeV1_0.PBKDF2_CACHE.set(cache_key, u)
    return u


class LocalThreadCache(threading.local):
    """
    Minimal key/value cache whose contents are private to each thread.

    Because this subclasses ``threading.local``, ``__init__`` runs again the
    first time the instance is used from each new thread, so every thread
    starts with its own empty store.
    """

    def __init__(self):
        self._cache = {}

    def set(self, key, val):
        """Store ``val`` under ``key`` for the calling thread."""
        store = self._cache
        store[key] = val

    def get(self, key):
        """Return the calling thread's value for ``key``, or None if absent."""
        store = self._cache
        return store[key] if key in store else None


class HandshakeV1_0(object):
    """
    RethinkDB client drivers are responsible for serializing queries, sending them to the server using the
    ReQL wire protocol, and receiving responses from the server and returning them to the calling application.

    The client sends the protocol version, authentication method, and authentication as a null-terminated JSON
    message. RethinkDB currently supports only one authentication method, SCRAM-SHA-256, as specified in IETF
    RFC 7677 and RFC 5802. The RFC is followed with the exception of error handling (RethinkDB uses its own
    higher level error reporting rather than the e= field). RethinkDB does not support channel binding and clients
    should not request this. The value of "authentication" is the "client-first-message" specified in RFC 5802
    (the channel binding flag, optional SASL authorization identity, username (n=), and random nonce (r=)).

    More info: https://rethinkdb.com/docs/writing-drivers/
    """

    VERSION = ql2_pb2.VersionDummy.Version.V1_0
    PROTOCOL = ql2_pb2.VersionDummy.Protocol.JSON
    # Per-thread cache used by the module-level pure-Python ``pbkdf2_hmac``.
    PBKDF2_CACHE = LocalThreadCache()

    def __init__(self, json_decoder, json_encoder, host, port, username, password):
        """
        Create the handshake state machine for one connection.

        :param json_decoder: Object whose ``decode`` method parses server responses
        :param json_encoder: Object whose ``encode`` method serializes client messages
        :param host: Server host, passed to ReqlAuthError
        :param port: Server port, passed to ReqlAuthError
        :param username: User name as text; encoded to UTF-8 and escaped for SCRAM
        :param password: Password as text (encoded to UTF-8) or as already-encoded bytes
        """

        self._json_decoder = json_decoder
        self._json_encoder = json_encoder
        self._host = host
        self._port = port
        # RFC 5802 saslname escaping: "=" must be replaced before "," so the
        # "=" introduced by "=2C" is not escaped again.
        self._username = (
            username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
        )

        # Encode text passwords as UTF-8, the same as the username. RFC 5802
        # hashes the UTF-8 form of the password, whereas six.b uses Latin-1
        # on Python 3: that hashes non-ASCII passwords differently, rejects
        # characters outside Latin-1, and fails on bytes input.
        if isinstance(password, six.text_type):
            password = password.encode("utf-8")
        self._password = password

        self._compare_digest = self._get_compare_digest()
        self._pbkdf2_hmac = self._get_pbkdf2_hmac()

        self._protocol_version = 0
        self._random_nonce = None
        self._first_client_message = None
        self._server_signature = None
        self._state = 0

    @staticmethod
    def _get_compare_digest():
        """
        Get the compare_digest function from hmac if the module contains it, else get
        our own function. Please note that hmac contains this function only for
        Python 2.7.7+ and 3.3+.
        """

        return getattr(hmac, "compare_digest", compare_digest)

    @staticmethod
    def _get_pbkdf2_hmac():
        """
        Get the pbkdf2_hmac function from hashlib if the module contains it, else get
        our own function. Please note that hashlib contains this function only for
        Python 2.7.8+ and 3.4+.
        """

        return getattr(hashlib, "pbkdf2_hmac", pbkdf2_hmac)

    @staticmethod
    def _get_authentication_and_first_client_message(response):
        """
        Extract the SCRAM message carried by a server response and parse it.

        Despite the method name, the returned message is the one the server
        sent in its "authentication" field, not a client message.

        :param response: Decoded JSON response dict from the database
        :return: Tuple of (the raw message as ASCII bytes, a dict mapping each
            attribute name to its value as bytes, e.g. b"r", b"s", b"i", b"v")
        """

        server_message = response["authentication"].encode("ascii")
        authentication = dict(
            attribute.split(b"=", 1) for attribute in server_message.split(b",")
        )
        return server_message, authentication

    def _next_state(self):
        """
        Increase the state counter.
        """

        self._state += 1

    def _decode_json_response(self, response, with_utf8=False):
        """
        Get decoded json response from response.

        :param response: Response from the database
        :param with_utf8: UTF-8 decode response before json decoding
        :raises: ReqlDriverError | ReqlAuthError
        :return: Json decoded response of the original response
        """

        if with_utf8:
            response = decode_utf8(response)

        json_response = self._json_decoder.decode(response)

        if not json_response.get("success"):
            # Error codes 10-20 are reported as authentication failures.
            if 10 <= json_response["error_code"] <= 20:
                raise ReqlAuthError(json_response["error"], self._host, self._port)

            raise ReqlDriverError(json_response["error"])

        return json_response

    def _init_connection(self, response):
        """
        Prepare initial connection message. We send the version as well as the initial
        JSON as an optimization.

        :param response: Response from the database; must be None at this stage
        :raises: ReqlDriverError
        :return: Initial message which will be sent to the DB
        """

        if response is not None:
            raise ReqlDriverError("Unexpected response")

        # 18 random bytes -> 24-character base64 client nonce.
        self._random_nonce = base64.standard_b64encode(
            bytes(bytearray(SystemRandom().getrandbits(8) for i in range(18)))
        )

        # client-first-message-bare: "n=<user>,r=<nonce>"
        self._first_client_message = chain_to_bytes(
            "n=", self._username, ",r=", self._random_nonce
        )

        initial_message = chain_to_bytes(
            struct.pack("<L", self.VERSION),
            self._json_encoder.encode(
                {
                    "protocol_version": self._protocol_version,
                    "authentication_method": "SCRAM-SHA-256",
                    # "n,," is the GS2 header: no channel binding, no authzid.
                    "authentication": chain_to_bytes(
                        "n,,", self._first_client_message
                    ).decode("ascii"),
                }
            ).encode("utf-8"),
            b"\0",
        )

        self._next_state()
        return initial_message

    def _read_response(self, response):
        """
        Read the server's first response. Because the initial JSON has already
        been sent and only a single protocol version is supported at the moment,
        we simply validate the protocol version range and return an empty string
        as the message.

        :param response: Response from the database
        :raises: ReqlDriverError | ReqlAuthError
        :return: An empty string
        """

        json_response = self._decode_json_response(response)
        min_protocol_version = json_response["min_protocol_version"]
        max_protocol_version = json_response["max_protocol_version"]

        if not min_protocol_version <= self._protocol_version <= max_protocol_version:
            raise ReqlDriverError(
                "Unsupported protocol version {version}, expected between {min} and {max}".format(
                    version=self._protocol_version,
                    min=min_protocol_version,
                    max=max_protocol_version,
                )
            )

        self._next_state()
        return ""

    def _prepare_auth_request(self, response):
        """
        Put together the authentication request (client-final-message) based
        on the server-first-message in the response of the database.

        :param response: Response from the database
        :raises: ReqlDriverError | ReqlAuthError
        :return: The null-terminated JSON authentication request as bytes
        """

        json_response = self._decode_json_response(response, with_utf8=True)
        (
            server_first_message,
            authentication,
        ) = self._get_authentication_and_first_client_message(json_response)

        server_nonce = authentication[b"r"]
        # The server nonce must extend the client nonce we sent.
        if not server_nonce.startswith(self._random_nonce):
            raise ReqlAuthError("Invalid nonce from server", self._host, self._port)

        salted_password = self._pbkdf2_hmac(
            "sha256",
            self._password,
            base64.standard_b64decode(authentication[b"s"]),
            int(authentication[b"i"]),
        )

        # "biws" is base64("n,,"), the GS2 header sent in the first message.
        message_without_proof = chain_to_bytes("c=biws,r=", server_nonce)
        auth_message = b",".join(
            (self._first_client_message, server_first_message, message_without_proof)
        )

        # Kept so _read_auth_response can verify the server's final message.
        self._server_signature = hmac.new(
            hmac.new(salted_password, b"Server Key", hashlib.sha256).digest(),
            auth_message,
            hashlib.sha256,
        ).digest()

        client_key = hmac.new(salted_password, b"Client Key", hashlib.sha256).digest()
        client_signature = hmac.new(
            hashlib.sha256(client_key).digest(), auth_message, hashlib.sha256
        ).digest()
        # ClientProof = ClientKey XOR ClientSignature, byte by byte (32 bytes).
        client_proof = struct.pack(
            "32B",
            *(
                key_byte ^ signature_byte
                for key_byte, signature_byte in zip(
                    struct.unpack("32B", client_key),
                    struct.unpack("32B", client_signature),
                )
            )
        )

        authentication_request = chain_to_bytes(
            self._json_encoder.encode(
                {
                    "authentication": chain_to_bytes(
                        message_without_proof,
                        ",p=",
                        base64.standard_b64encode(client_proof),
                    ).decode("ascii")
                }
            ),
            b"\0",
        )

        self._next_state()
        return authentication_request

    def _read_auth_response(self, response):
        """
        Read the authentication request's response sent by the database
        and validate the server signature which was returned.

        :param response: Response from the database
        :raises: ReqlDriverError | ReqlAuthError
        :return: None
        """

        json_response = self._decode_json_response(response, with_utf8=True)

        _, authentication = self._get_authentication_and_first_client_message(
            json_response
        )
        server_signature = base64.standard_b64decode(authentication[b"v"])

        if not self._compare_digest(server_signature, self._server_signature):
            raise ReqlAuthError("Invalid server signature", self._host, self._port)

        self._next_state()

    def reset(self):
        """
        Clear per-attempt handshake data and return to the initial state.
        """
        self._random_nonce = None
        self._first_client_message = None
        self._server_signature = None
        self._state = 0

    def next_message(self, response):
        """
        Advance the handshake by one step.

        :param response: Bytes received from the server, or None for the first
            step (nothing has been received yet)
        :raises: ReqlDriverError | ReqlAuthError
        :return: Bytes to send to the server, an empty string when there is
            nothing to send, or None once the server signature has been verified
        """
        if response is not None:
            response = response.decode("utf-8")

        if self._state == 0:
            return self._init_connection(response)

        elif self._state == 1:
            return self._read_response(response)

        elif self._state == 2:
            return self._prepare_auth_request(response)

        elif self._state == 3:
            return self._read_auth_response(response)

        raise ReqlDriverError("Unexpected handshake state")