import uuid
import json as json_mod
from datetime import datetime

import requests
import base64
import tempfile
import shutil

from .action import ConnectionAction
from .database import Database, DBHandle
from .theExceptions import CreationError, ConnectionError
from .users import Users

from .ca_certificate import CA_Certificate

class JsonHook(object):
    """This one replaces requests' original json() function. If a call to json() fails, it will print a message with the request content"""
    def __init__(self, ret):
        self.ret = ret
        self.ret.json_originalFct = self.ret.json

    def __call__(self, *args, **kwargs):
        try:
            return self.ret.json_originalFct(*args, **kwargs)
        except Exception as e:
            print( "Unable to get json for request: %s. Content: %s" % (self.ret.url, self.ret.content) )
            raise e

class AikidoSession(object):
    """Magical Aikido being that you probably do not need to access directly that deflects every http request to requests in the most graceful way.
    It will also save basic stats on requests in it's attribute '.log'.
    """

    class Holder(object):
        def __init__(self, fct, auth, verify=True):
            self.fct = fct
            self.auth = auth
            if not isinstance(verify, bool) and not isinstance(verify, CA_Certificate) and not not isinstance(verify, str) :
                raise ValueError("'verify' argument can only be of type: bool, CA_Certificate or str ")
            self.verify = verify

        def __call__(self, *args, **kwargs):
            if self.auth:
                kwargs["auth"] = self.auth
            if isinstance(self.verify, CA_Certificate):
                kwargs["verify"] = self.verify.get_file_path()
            else :
                kwargs["verify"] = self.verify

            try:
                ret = self.fct(*args, **kwargs)
            except:
                print ("===\nUnable to establish connection, perhaps arango is not running.\n===")
                raise

            if len(ret.content) < 1:
                raise ConnectionError("Empty server response", ret.url, ret.status_code, ret.content)
            elif ret.status_code == 401:
                raise ConnectionError("Unauthorized access, you must supply a (username, password) with the correct credentials", ret.url, ret.status_code, ret.content)

            ret.json = JsonHook(ret)
            return ret

    def __init__(self, username, password, verify=True, max_retries=5, log_requests=False):
        if username:
            self.auth = (username, password)
        else:
            self.auth = None

        self.verify = verify
        self.max_retries = max_retries
        self.log_requests = log_requests

        if log_requests:
            self.log = {}
            self.log["nb_request"] = 0
            self.log["requests"] = {}

    def __getattr__(self, request_function_name):
        try:
            session = requests.Session()
            http = requests.adapters.HTTPAdapter(max_retries=self.max_retries)
            https = requests.adapters.HTTPAdapter(max_retries=self.max_retries)
            session.mount('http://', http)
            session.mount('https://', https)
            request_function = getattr(session, request_function_name)
        except AttributeError:
            raise AttributeError("Attribute '%s' not found (no Aikido move available)" % request_function_name)

        auth = object.__getattribute__(self, "auth")
        verify = object.__getattribute__(self, "verify")
        if self.log_requests:
            log = object.__getattribute__(self, "log")
            log["nb_request"] += 1
            log["requests"][request_function.__name__] += 1

        return AikidoSession.Holder(request_function, auth, verify)

    def disconnect(self):
        pass

class Connection(object):
    """This is the entry point in pyArango and directly handles databases.
    @param arangoURL: can be either a string url or a list of string urls to different coordinators
    @param use_grequests: allows for running concurent requets."""

    LOAD_BLANCING_METHODS = {'round-robin', 'random'}

    def __init__(self,
        arangoURL = 'http://127.0.0.1:8529',
        username = None,
        password = None,
        verify = True,
        verbose = False,
        statsdClient = None,
        reportFileName = None,
        loadBalancing = "round-robin",
        use_grequests = False,
        use_jwt_authentication=False,
        use_lock_for_reseting_jwt=True,
        max_retries=5,
    ):

        if loadBalancing not in Connection.LOAD_BLANCING_METHODS:
            raise ValueError("loadBalancing should be one of : %s, got %s" % (Connection.LOAD_BLANCING_METHODS, loadBalancing) )

        self.loadBalancing = loadBalancing
        self.currentURLId = 0
        self.username = username
        self.use_grequests = use_grequests
        self.use_jwt_authentication = use_jwt_authentication
        self.use_lock_for_reseting_jwt = use_lock_for_reseting_jwt
        self.max_retries = max_retries
        self.action = ConnectionAction(self)

        self.databases = {}
        self.verbose = verbose

        if isinstance(arangoURL, str):
            self.arangoURL = [arangoURL]
        else:
            self.arangoURL = arangoURL

        for i, url in enumerate(self.arangoURL):
            if url[-1] == "/":
                self.arangoURL[i] = url[:-1]

        self.identifier = None
        self.startTime = None
        self.session = None
        self.resetSession(username, password, verify)

        self.users = Users(self)

        if reportFileName != None:
            self.reportFile = open(reportFileName, 'a')
        else:
            self.reportFile = None

        self.statsdc = statsdClient
        self.reload()

    def getEndpointURL(self):
        """return an endpoint url applying load balacing strategy"""
        if self.loadBalancing == "round-robin":
            url = self.arangoURL[self.currentURLId]
            self.currentURLId = (self.currentURLId + 1) % len(self.arangoURL)
            return url
        elif self.loadBalancing == "random":
            import random
            return random.choice(self.arangoURL)

    def getURL(self):
        """return an URL for the connection"""
        return '%s/_api' % self.getEndpointURL()

    def getDatabasesURL(self):
        """return an URL to the databases"""
        if not self.session.auth:
            return '%s/database/user' % self.getURL()
        else:
            return '%s/user/%s/database' % (self.getURL(), self.username)

    def updateEndpoints(self, coordinatorURL = None):
        """udpdates the list of available endpoints from the server"""
        raise NotImplementedError("Not done yet.")

    def disconnectSession(self):
        if self.session:
            self.session.disconnect()

    def getVersion(self):
        """fetches the arangodb server version"""
        r = self.session.get(self.getURL() + "/version")
        data = r.json()
        if r.status_code == 200 and not "error" in data:
            return data
        else:
            raise CreationError(data["errorMessage"], data)

    def resetSession(self, username=None, password=None, verify=True):
        """resets the session"""
        self.disconnectSession()
        if self.use_grequests:
            from .gevent_session import AikidoSession_GRequests
            self.session = AikidoSession_GRequests(
                username, password, self.arangoURL,
                self.use_jwt_authentication,
                self.use_lock_for_reseting_jwt, self.max_retries,
                verify
            )
        else:
            self.session = AikidoSession(username, password, verify, self.max_retries)

    def reload(self):
        """Reloads the database list.
        Because loading a database triggers the loading of all collections and graphs within,
        only handles are loaded when this function is called. The full databases are loaded on demand when accessed
        """

        r = self.session.get(self.getDatabasesURL())

        data = r.json()
        if r.status_code == 200 and not data["error"]:
            self.databases = {}
            for dbName in data["result"]:
                if dbName not in self.databases:
                    self.databases[dbName] = DBHandle(self, dbName)
        else:
            raise ConnectionError(data["errorMessage"], self.getDatabasesURL(), r.status_code, r.content)

    def createDatabase(self, name, **dbArgs):
        "use dbArgs for arguments other than name. for a full list of arguments please have a look at arangoDB's doc"
        dbArgs['name'] = name
        payload = json_mod.dumps(dbArgs, default=str)
        url = self.getURL() + "/database"
        r = self.session.post(url, data = payload)
        data = r.json()
        if r.status_code == 201 and not data["error"]:
            db = Database(self, name)
            self.databases[name] = db
            return self.databases[name]
        else:
            raise CreationError(data["errorMessage"], r.content)

    def hasDatabase(self, name):
        """returns true/false wether the connection has a database by the name of 'name'"""
        return name in self.databases

    def __getitem__(self, dbName):
        """Collection[dbName] returns a database by the name of 'dbName', raises a KeyError if not found"""
        try:
            return self.databases[dbName]
        except KeyError:
            self.reload()
            try:
                return self.databases[dbName]
            except KeyError:
                raise KeyError("Can't find any database named : %s" % dbName)

    def reportStart(self, name):
        if self.statsdc != None:
            self.identifier = str(uuid.uuid5(uuid.NAMESPACE_DNS, name))[-6:]
            if self.reportFile != None:
                self.reportFile.write("[%s]: %s\n" % (self.identifier, name))
                self.reportFile.flush()
            self.startTime = datetime.now()

    def reportItem(self):
        if self.statsdc != None:
            diff = datetime.now() - self.startTime
            microsecs = (diff.total_seconds() * (1000 ** 2) ) + diff.microseconds
            self.statsdc.timing("pyArango_" + self.identifier, int(microsecs))