import os
import shutil
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque
from typing import Mapping, Tuple, Any, Union, Optional, NamedTuple

# Prefer the faster ujson when it is installed; fall back to the stdlib json.
try:
    import ujson as json
except ImportError:
    import json

import zmq
import zmq.auth
from zmq.utils import z85
from zmq.utils.monitor import recv_monitor_message

from test_zmq.authenticator import MultiZapAuthenticator
from test_zmq.crypto.nacl_wrappers import Signer, Verifier
from test_zmq.crypto.util import isHex, ed25519PkToCurve25519
from .util import createEncAndSigKeys, \
    moveKeyFilesToCorrectLocations, createCertsFromKeys

class Quota(NamedTuple):
    # Receive budget for one polling pass.
    # count: presumably the maximum number of messages to accept - TODO confirm
    count: int
    # size: presumably the maximum cumulative payload size in bytes - TODO confirm
    size: int

class ClientMessageProvider:
    """Sends messages to clients through a stack's ROUTER listener socket.

    The listener may be attached after construction by assigning to
    ``listener``.
    """

    def __init__(self, name, prepare_to_send, mt_outgoing_size, listener=None):
        """
        :param name: label returned by ``repr()`` and used in error messages
        :param prepare_to_send: callable that converts a message into the
            payload handed to the socket
        :param mt_outgoing_size: stored as-is; not read by this class
        :param listener: socket exposing ``send_multipart``; may be None until
            the owning stack opens its listener
        """
        self._name = name
        self.listener = listener
        self._prepare_to_send = prepare_to_send
        self._mt_outgoing_size = mt_outgoing_size

    def transmit_through_listener(self, msg, ident) -> Tuple[bool, Optional[str]]:
        """Send ``msg`` to the peer identified by ``ident``.

        :return: ``(True, None)`` on success, ``(False, error_text)`` otherwise
        """
        # The resend hint is computed by the helper but is not part of this
        # method's public result.
        result, error_msg, _need_to_resend = \
            self._transmit_one_msg_throughlistener(msg, ident)
        return result, error_msg

    def _transmit_one_msg_throughlistener(self, msg, ident) -> Tuple[bool, Optional[str], bool]:
        """Try a single non-blocking send of ``msg`` to ``ident``.

        :return: ``(sent, error_text, need_to_resend)``; ``need_to_resend`` is
            True when the send would block (``zmq.Again``) or when the ZMQ
            error carries errno 113.
        """

        def prepare_error_msg(ex):
            err_str = '{} got error {} while sending through listener to {}' \
                .format(self, ex, ident)
            return err_str

        need_to_resend = False
        # ROUTER identities are bytes; accept str for convenience.
        if isinstance(ident, str):
            ident = ident.encode()
        try:
            msg = self._prepare_to_send(msg)
            self.listener.send_multipart([ident, msg], flags=zmq.NOBLOCK)
        except zmq.Again as ex:
            need_to_resend = True
            return False, prepare_error_msg(ex), need_to_resend
        except zmq.ZMQError as ex:
            # 113 is EHOSTUNREACH on Linux.
            need_to_resend = (ex.errno == 113)
            return False, prepare_error_msg(ex), need_to_resend
        except Exception as ex:
            return False, prepare_error_msg(ex), need_to_resend
        return True, None, need_to_resend

    def __repr__(self):
        return self._name

# TODO: Use asyncio
# TODO: There are a number of methods related to key management;
# they can be moved to a class such as KeysManager
class ZStack:
    """Network stack built around one CURVE-secured ZeroMQ ROUTER listener.

    Key material lives on disk under ``<basedirpath>/<name>/`` in four
    directories (public, private, verification and signing keys). Received
    messages are queued in ``rxMsgs`` and handed to ``msgHandler`` by
    ``processReceived``.
    """
    # Assuming only one listener per stack for now.

    PublicKeyDirName = 'public_keys'
    PrivateKeyDirName = 'private_keys'
    VerifKeyDirName = 'verif_keys'
    SigKeyDirName = 'sig_keys'

    # Number of trailing bytes verify() treats as the signature.
    sigLen = 64
    pingMessage = 'pi'
    pongMessage = 'po'
    healthMessages = {pingMessage.encode(), pongMessage.encode()}

    # TODO: This is not implemented, implement this
    messageTimeout = 3

    def __init__(self, name, ha, basedirpath, msgHandler, restricted=True,
                 seed=None, onlyListener=False, config=None, msgRejectHandler=None, queue_size=0,
                 mt_incoming_size=None, mt_outgoing_size=None, timer=None,
                 create_listener_monitor=False):
        """
        :param name: stack name; also the base name of its key files
        :param ha: ``(ip, port)`` pair the listener binds to
        :param basedirpath: directory under which the stack's home dir lives
        :param msgHandler: callable invoked by ``processReceived``
        :param restricted: when True, only peers whose public keys are on
            disk are authenticated
        :param seed: seed used to generate keys when none exist on disk
        :param onlyListener: True when the stack keeps no remotes and talks
            only through its listener
        :param config: configuration object; ``open()`` reads
            ``MAX_WAIT_FOR_BIND_SUCCESS`` from it
        :param msgRejectHandler: callable ``(reason, frm)`` invoked for
            rejected messages; defaults to a no-op
        :param queue_size: stored as-is; not read by this class
        :param mt_incoming_size: stored as-is; not read by this class
        :param mt_outgoing_size: forwarded to the client message provider
        :param timer: accepted for interface compatibility; currently unused
        :param create_listener_monitor: when True, ``open()`` attaches a
            monitor socket to the listener
        """
        self._name = name
        self.ha = ha
        self.basedirpath = basedirpath
        self.msgHandler = msgHandler
        self.seed = seed
        self.queue_size = queue_size
        # open() needs the bind retry budget from config.
        self.config = config
        self.msgRejectHandler = msgRejectHandler or self.__defaultMsgRejectHandler

        self._node_mode = None
        self._stashed_unknown_remote_msgs = deque()

        self.mt_incoming_size = mt_incoming_size
        self.mt_outgoing_size = mt_outgoing_size

        self.homeDir = None
        # As of now there would be only one file in secretKeysDir and sigKeyDir
        self.publicKeysDir = None
        self.secretKeysDir = None
        self.verifKeyDir = None
        self.sigKeyDir = None

        self.signer = None
        self.verifiers = {}

        # NOTE(review): reinstated. Nothing else in this class calls these,
        # and open() cannot work while the key directories are still None.
        self.setupDirs()
        self.setupOwnKeysIfNeeded()
        self.setupSigning()

        # self.poller = test.asyncio.Poller()

        self.restricted = restricted

        self.ctx = None  # zmq.Context, created by start()
        self.listener = None
        self.create_listener_monitor = create_listener_monitor
        self.listener_monitor = None
        self.auth = None

        # Each remote is identified uniquely by the name
        self._remotes = {}

        self.remotesByKeys = {}

        self.remote_ping_stats = {}

        # Indicates if this stack will maintain any remotes or will
        # communicate simply to listeners. Used in ClientZStack
        self.onlyListener = onlyListener

        self._conns = set()  # set of str

        self.rxMsgs = deque()
        self._created = time.perf_counter()

        self.last_heartbeat_at = None

        self._stashed_pongs = set()
        self._received_pings = set()
        # The listener is still None here; open() attaches the real socket.
        self._client_message_provider = ClientMessageProvider(
            self.name, self.prepare_to_send, self.mt_outgoing_size,
            self.listener)

    def __defaultMsgRejectHandler(self, reason: str, frm):
        """Default reject handler: deliberately ignores the rejection."""
        pass

    @property
    def remotes(self):
        """Mapping of remote name to remote object."""
        return self._remotes

    @property
    def created(self):
        """``time.perf_counter()`` value captured at construction."""
        return self._created

    @property
    def name(self):
        """Name of this stack."""
        return self._name

    def set_mode(self, value):
        """Record the node mode; this class only stores it."""
        self._node_mode = value

    @staticmethod
    def isRemoteConnected(r) -> bool:
        """Return the ``isConnected`` attribute of remote ``r``."""
        return r.isConnected

    @staticmethod
    def initLocalKeys(name, baseDir, sigseed, override=False):
        """Generate this stack's key pairs and file them under its home dir.

        :param name: stack name
        :param baseDir: base directory containing the home dir
        :param sigseed: seed passed to key generation
        :param override: unused
        :return: ``(public_key_hex, verif_key_hex)`` as str
        """
        # Keys are generated into scratch directories, then moved into place.
        sDir = os.path.join(baseDir, '__sDir')
        eDir = os.path.join(baseDir, '__eDir')
        os.makedirs(sDir, exist_ok=True)
        os.makedirs(eDir, exist_ok=True)
        (public_key, secret_key), (verif_key, sig_key) = \
            createEncAndSigKeys(eDir, sDir, name, seed=sigseed)

        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        sigDirPath = ZStack.sigDirPath(homeDir)
        secretDirPath = ZStack.secretDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, sigDirPath, secretDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)

        moveKeyFilesToCorrectLocations(sDir, verifDirPath, sigDirPath)
        moveKeyFilesToCorrectLocations(eDir, pubDirPath, secretDirPath)

        return hexlify(public_key).decode(), hexlify(verif_key).decode()

    @staticmethod
    def initRemoteKeys(name, remoteName, baseDir, verkey, override=False):
        """Store a remote's verification key and derived public key on disk.

        :param name: name of the local stack whose home dir receives the certs
        :param remoteName: name the certificates are filed under
        :param baseDir: base directory containing the home dir
        :param verkey: remote verification key, raw or hex-encoded
        :param override: unused
        """
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)

        if isHex(verkey):
            verkey = unhexlify(verkey)

        createCertsFromKeys(verifDirPath, remoteName, z85.encode(verkey))
        # The encryption public key is derived from the verification key.
        public_key = ed25519PkToCurve25519(verkey)
        createCertsFromKeys(pubDirPath, remoteName, z85.encode(public_key))

    def onHostAddressChanged(self):
        """Hook for host-address changes; nothing to update here."""
        # we don't store remote data like ip, port, domain name, etc, so
        # nothing to do here
        pass

    @staticmethod
    def areKeysSetup(name, baseDir):
        """Return True when all four key files for ``name`` exist on disk."""
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        sigDirPath = ZStack.sigDirPath(homeDir)
        secretDirPath = ZStack.secretDirPath(homeDir)
        for d in (verifDirPath, pubDirPath):
            if not os.path.isfile(os.path.join(d, '{}.key'.format(name))):
                return False
        for d in (sigDirPath, secretDirPath):
            if not os.path.isfile(os.path.join(d, '{}.key_secret'.format(name))):
                return False
        return True

    @staticmethod
    def keyDirNames():
        """Return the four key directory names: public, private, verif, sig."""
        return ZStack.PublicKeyDirName, ZStack.PrivateKeyDirName, \
            ZStack.VerifKeyDirName, ZStack.SigKeyDirName

    @staticmethod
    def getHaFromLocal(name, basedirpath):
        """Always returns None; no host address is stored locally."""
        return None

    def __repr__(self):
        return self.name

    @staticmethod
    def homeDirPath(baseDirPath, name):
        """Return ``<baseDirPath>/<name>`` with ``~`` expanded."""
        return os.path.join(os.path.expanduser(baseDirPath), name)

    @staticmethod
    def publicDirPath(homeDirPath):
        """Return the public-keys directory inside ``homeDirPath``."""
        return os.path.join(homeDirPath, ZStack.PublicKeyDirName)

    @staticmethod
    def secretDirPath(homeDirPath):
        """Return the private-keys directory inside ``homeDirPath``."""
        return os.path.join(homeDirPath, ZStack.PrivateKeyDirName)

    @staticmethod
    def verifDirPath(homeDirPath):
        """Return the verification-keys directory inside ``homeDirPath``."""
        return os.path.join(homeDirPath, ZStack.VerifKeyDirName)

    @staticmethod
    def sigDirPath(homeDirPath):
        """Return the signing-keys directory inside ``homeDirPath``."""
        return os.path.join(homeDirPath, ZStack.SigKeyDirName)

    @staticmethod
    def learnKeysFromOthers(baseDir, name, others):
        """Write the ver/public keys of each stack in ``others`` into the
        home dir of the stack called ``name``."""
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)

        for other in others:
            createCertsFromKeys(verifDirPath, other.name, other.verKey)
            createCertsFromKeys(pubDirPath, other.name, other.publicKey)

    def tellKeysToOthers(self, others):
        """Write this stack's ver/public keys into each other stack's dirs."""
        for other in others:
            createCertsFromKeys(other.verifKeyDir, self.name, self.verKey)
            createCertsFromKeys(other.publicKeysDir, self.name, self.publicKey)

    def setupDirs(self):
        """Compute the home and key directory paths and create them."""
        self.homeDir = self.homeDirPath(self.basedirpath, self.name)
        self.publicKeysDir = self.publicDirPath(self.homeDir)
        self.secretKeysDir = self.secretDirPath(self.homeDir)
        self.verifKeyDir = self.verifDirPath(self.homeDir)
        self.sigKeyDir = self.sigDirPath(self.homeDir)

        for d in (self.homeDir, self.publicKeysDir, self.secretKeysDir,
                  self.verifKeyDir, self.sigKeyDir):
            os.makedirs(d, exist_ok=True)

    def setupOwnKeysIfNeeded(self):
        """Generate this stack's keys from ``seed`` if none exist on disk.

        :raises AssertionError: when private keys exist without signing keys,
            or when no seed is available
        """
        if not os.listdir(self.sigKeyDir):
            # If signing keys are not present, secret (private keys) should
            # not be present since they should be converted keys.
            assert not os.listdir(self.secretKeysDir)
            # Seed should be present
            assert self.seed, 'Keys are not setup for {}'.format(self)
            print("Signing and Encryption keys were not found for {}. Creating them now".format(self))
            tdirS = os.path.join(self.homeDir, '__skeys__')
            tdirE = os.path.join(self.homeDir, '__ekeys__')
            os.makedirs(tdirS, exist_ok=True)
            os.makedirs(tdirE, exist_ok=True)
            createEncAndSigKeys(tdirE, tdirS, self.name, self.seed)
            moveKeyFilesToCorrectLocations(tdirE, self.publicKeysDir,
                                           self.secretKeysDir)
            moveKeyFilesToCorrectLocations(tdirS, self.verifKeyDir,
                                           self.sigKeyDir)

    def setupAuth(self, restricted=True, force=False):
        """Create and configure the CURVE authenticator.

        :param restricted: True to accept only keys found in the public-keys
            directory, False to accept any key
        :param force: replace an existing authenticator instead of raising
        :raises RuntimeError: when an authenticator exists and ``force`` is False
        """
        if self.auth and not force:
            raise RuntimeError('Listener already setup')
        location = self.publicKeysDir if restricted else zmq.auth.CURVE_ALLOW_ANY
        # self.auth = AsyncioAuthenticator(self.ctx)
        self.auth = MultiZapAuthenticator(self.ctx)
        # NOTE(review): reinstated. pyzmq authenticators are started before
        # being configured; confirm MultiZapAuthenticator does not auto-start.
        self.auth.start()
        self.auth.configure_curve(domain='*', location=location)

    def teardownAuth(self):
        """Stop the authenticator if one was set up."""
        if self.auth:
            self.auth.stop()

    def setupSigning(self):
        """Load the signer and one verifier per verification key on disk."""
        # Setup its signer from the signing key stored at disk and for all
        # verification keys stored at disk, add Verifier
        _, sk = self.selfSigKeys
        self.signer = Signer(z85.decode(sk))
        for vk in self.getAllVerKeys():
            self.addVerifier(vk)

    def addVerifier(self, verkey):
        """Register a Verifier for the z85-encoded ``verkey``."""
        self.verifiers[verkey] = Verifier(z85.decode(verkey))

    def start(self, restricted=None, reSetupAuth=True):
        """Create a ZMQ context, set up authentication and open the listener.

        :param restricted: overrides ``self.restricted`` when not None
        :param reSetupAuth: passed to ``setupAuth`` as ``force``
        """
        self.ctx = zmq.Context()
        restricted = self.restricted if restricted is None else restricted
        print('{} starting with restricted as {} and reSetupAuth '
              'as {}'.format(self, restricted, reSetupAuth))
        self.setupAuth(restricted, force=reSetupAuth)
        # NOTE(review): reinstated; nothing else in this class calls open().
        self.open()

    def stop(self):
        """Close the listener if it is open and stop the authenticator."""
        if self.opened:
            print('stack {} closing its listener'.format(self))
            self.close()
        self.teardownAuth()
        print("stack {} stopped".format(self))

    @property
    def opened(self):
        """True while a listener socket exists."""
        return self.listener is not None

    def open(self):
        """Create, secure and bind the ROUTER listener socket.

        Binding is retried every 0.2 s until it succeeds or the accumulated
        wait exceeds ``config.MAX_WAIT_FOR_BIND_SUCCESS``.

        :raises zmq.error.ZMQError: when the retry budget is exhausted
        """
        # noinspection PyUnresolvedReferences
        self.listener = self.ctx.socket(zmq.ROUTER)
        self._client_message_provider.listener = self.listener
        self.listener.setsockopt(zmq.ROUTER_MANDATORY, 1)
        self.listener.setsockopt(zmq.ROUTER_HANDOVER, 1)
        if self.create_listener_monitor:
            self.listener_monitor = self.listener.get_monitor_socket()
        # noinspection PyUnresolvedReferences
        # self.poller.register(self.listener, test.POLLIN)
        public, secret = self.selfEncKeys
        self.listener.curve_secretkey = secret
        self.listener.curve_publickey = public
        self.listener.curve_server = True
        self.listener.identity = self.publicKey
        print('{} will bind its listener at {}:{}'.format(self, self.ha[0], self.ha[1]))
        self.listener.setsockopt(zmq.TCP_KEEPALIVE, 1)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_INTVL, 1)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_IDLE, 20)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_CNT, 10)
        # Cycle to deal with "Address already in use" in case of immediate stack restart.
        bound = False

        sleep_between_bind_retries = 0.2
        bind_retry_time = 0
        # With no config (or no such setting) there is no retry budget.
        max_wait = getattr(self.config, 'MAX_WAIT_FOR_BIND_SUCCESS', 0)
        # NOTE(review): protocol assumed to be tcp, matching the TCP
        # keepalive options set above.
        endpoint = '{protocol}://{ip}:{port}'.format(ip=self.ha[0], port=self.ha[1],
                                                     protocol='tcp')
        while not bound:
            try:
                self.listener.bind(endpoint)
                bound = True
            except zmq.error.ZMQError:
                print("{} can not bind to {}:{}. Will try in {} secs.".
                      format(self, self.ha[0], self.ha[1], sleep_between_bind_retries))
                bind_retry_time += sleep_between_bind_retries
                if bind_retry_time > max_wait:
                    raise
                time.sleep(sleep_between_bind_retries)

    def close(self):
        """Close the listener and forget every remote and connection."""
        if self.listener is not None:
            if self.listener_monitor is not None:
                self.listener.disable_monitor()
            # linger=0 drops unsent messages so the port is freed promptly.
            self.listener.close(linger=0)
        self.listener_monitor = None
        self.listener = None
        print('{} starting to disconnect remotes'.format(self))
        for r in self.remotes.values():
            # NOTE(review): assumes remotes expose disconnect() - confirm.
            r.disconnect()
            self.remotesByKeys.pop(r.publicKey, None)

        self._remotes = {}
        if self.remotesByKeys:
            print('{} found remotes that were only in remotesByKeys and '
                  'not in remotes. This is suspicious'.format(self))
            for r in self.remotesByKeys.values():
                r.disconnect()
            self.remotesByKeys = {}
        self._conns = set()

    @property
    def selfEncKeys(self):
        """``(public, secret)`` encryption keys loaded from the secret cert."""
        serverSecretFile = os.path.join(self.secretKeysDir,
                                        "{}.key_secret".format(self.name))
        return zmq.auth.load_certificate(serverSecretFile)

    @property
    def selfSigKeys(self):
        """``(verif, signing)`` keys loaded from the signing cert."""
        serverSecretFile = os.path.join(self.sigKeyDir,
                                        "{}.key_secret".format(self.name))
        return zmq.auth.load_certificate(serverSecretFile)

    @property
    def isRestricted(self):
        """Whether only known keys are accepted.

        Reads the authenticator's ``allow_any`` flag when an authenticator
        exists, otherwise the configured ``restricted`` value.
        """
        return not self.auth.allow_any if self.auth is not None \
            else self.restricted

    @property
    def isKeySharing(self):
        """Inverse of ``isRestricted``."""
        # TODO: Change name after removing test
        return not self.isRestricted

    def getHa(self, name):
        """Return the host address for ``name``; None in listener-only mode."""
        # Return HA as None when its a `peersWithoutRemote`
        if self.onlyListener:
            return None
        # NOTE(review): this class declares no base class providing getHa,
        # so this call raises AttributeError unless a subclass/mixin adds one.
        return super().getHa(name)

    async def service(self, limit=None, quota: Optional['Quota'] = None) -> int:
        """
        Service `limit` number of received messages in this stack.

        :param limit: the maximum number of messages to be processed. If None,
        processes all of the messages in rxMsgs.
        :param quota: receive budget handed to the listener polling pass
        :return: the number of messages processed.
        """
        if self.listener:
            await self._serviceStack(quota)
        else:
            print("{} is stopped".format(self))

        r = len(self.rxMsgs)
        if r > 0:
            pracLimit = limit if limit else sys.maxsize
            return self.processReceived(pracLimit)
        return 0

    def _receiveFromListener(self, quota: 'Quota') -> int:
        """
        Receives messages from listener
        :param quota: number of messages to receive
        :return: number of received messages
        """
        i = 0
        incoming_size = 0
        try:
            # Keep reading until either budget in the quota is used up; the
            # non-blocking recv raises zmq.Again once the socket is drained.
            while i < quota.count and incoming_size < quota.size:
                ident, msg = self.listener.recv_multipart(flags=zmq.NOBLOCK)
                if not msg:
                    # Router probing sends empty message on connection
                    continue
                incoming_size += len(msg)
                i += 1
                self._verifyAndAppend(msg, ident)
        except zmq.Again:
            return i
        except zmq.ZMQError as e:
            print("Strange ZMQ behaviour during node-to-node message receiving, experienced {}".format(e))
        if i > 0:
            print('{} got {} messages through listener'.
                  format(self, i))
        return i

    def _verifyAndAppend(self, msg, ident):
        """Queue ``msg`` from ``ident`` if both decode as UTF-8.

        :return: True when the message was appended to ``rxMsgs``
        """
        try:
            # Only checks decodability; the identity is kept as bytes.
            ident.decode()
        except ValueError:
            print("Identifier {} is not decoded into UTF-8 string. "
                  "Request will not be processed".format(ident))
            return False
        try:
            decoded = msg.decode()
        except (UnicodeDecodeError) as ex:
            errstr = 'Message will be discarded due to {}'.format(ex)
            frm = self.remotesByKeys[ident].name if ident in self.remotesByKeys else ident
            print("Got from {} {}".format(frm, errstr))
            self.msgRejectHandler(errstr, frm)
            return False
        self.rxMsgs.append((decoded, ident))
        return True

    async def _serviceStack(self, quota: Optional['Quota'] = None):
        """Poll the listener once and return the number of queued messages."""
        if quota is None:
            # NOTE(review): no default quota is configured in this class, so
            # an absent quota is treated as unlimited - confirm.
            quota = Quota(count=sys.maxsize, size=sys.maxsize)
        self._receiveFromListener(quota)
        return len(self.rxMsgs)

    def processReceived(self, limit):
        """Hand up to ``limit`` queued messages to ``msgHandler``.

        :return: the number of messages processed
        """
        if limit <= 0:
            return 0
        num_processed = 0
        for num_processed in range(limit):
            if len(self.rxMsgs) == 0:
                return num_processed

            msg, ident = self.rxMsgs.popleft()
            # Report the remote's name when the key is known, else the raw key.
            frm = self.remotesByKeys[ident].name \
                if ident in self.remotesByKeys else ident
            self.msgHandler(self, (msg, frm))

        return num_processed + 1

    def doProcessReceived(self, msg, frm, ident):
        """Return ``msg`` unchanged."""
        return msg

    def send(self, msg: Any, remoteName: str = None, ha=None):
        """Send ``msg`` to one peer or, when ``remoteName`` is None, to all
        remotes.

        :param msg: message to send
        :param remoteName: target identity/uid; None broadcasts to every
            remote (not applicable in listener-only mode)
        :param ha: unused
        :return: ``(success, error_text_or_None)``
        """
        if self.onlyListener:
            return self._client_message_provider.transmit_through_listener(
                msg, remoteName)
        # NOTE(review): the non-listener path was missing and has been
        # reconstructed around transmit() - confirm the intended behaviour.
        if remoteName is not None:
            return self.transmit(msg, remoteName)
        # Serialize once instead of once per remote.
        payload = self.prepare_to_send(msg)
        results = []
        errors = []
        for uid in self.remotes:
            ok, err = self.transmit(payload, uid, serialized=True)
            results.append(ok)
            if err is not None:
                errors.append(err)
        return all(results), ("\n".join(errors) if errors else None)

    def transmit(self, msg, uid, timeout=None, serialized=False, is_batch=False):
        """Send ``msg`` to remote ``uid`` without blocking.

        :param serialized: True when ``msg`` is already bytes ready to send
        :param timeout: unused
        :param is_batch: unused
        :return: ``(True, None)`` when sent; ``(False, None)`` when the remote
            or its socket is missing or the send would block
        """
        remote = self.remotes.get(uid)
        err_str = None
        if not remote:
            return False, err_str
        socket = remote.socket
        if not socket:
            return False, err_str
        try:
            if not serialized:
                msg = self.prepare_to_send(msg)

            print('{} transmitting message {} to {} by socket {} {}'
                  .format(self, msg, uid, socket.FD, socket.underlying))
            socket.send(msg, flags=zmq.NOBLOCK)

            return True, err_str
        except zmq.Again:
            print('{} could not transmit message to {}'.format(self, uid))
        return False, err_str

    @staticmethod
    def serializeMsg(msg):
        """Convert a mapping, str or bytes message into bytes.

        :raises AssertionError: when the result is not bytes
        """
        if isinstance(msg, Mapping):
            msg = json.dumps(msg)
        if isinstance(msg, str):
            msg = msg.encode()
        assert isinstance(msg, bytes)
        return msg

    @staticmethod
    def deserializeMsg(msg):
        """Decode bytes if needed and parse the result as JSON."""
        if isinstance(msg, bytes):
            msg = msg.decode()
        msg = json.loads(msg)
        return msg

    def signedMsg(self, msg: bytes, signer: 'Signer' = None):
        """Return ``msg`` with its signature appended.

        :param signer: signer to use; defaults to this stack's own signer
        """
        sig = (signer or self.signer).signature(msg)
        return msg + sig

    def verify(self, msg, by):
        """Verify the trailing signature of ``msg`` against remote ``by``.

        Always True when key sharing is on; False for unknown remotes.
        """
        if self.isKeySharing:
            return True
        if by not in self.remotesByKeys:
            return False
        verKey = self.remotesByKeys[by].verKey
        r = self.verifiers[verKey].verify(
            msg[-self.sigLen:], msg[:-self.sigLen])
        return r

    @staticmethod
    def loadPubKeyFromDisk(directory, name):
        """Load the public part of ``<directory>/<name>.key``.

        :raises KeyError: when the certificate cannot be read or parsed
        """
        filePath = os.path.join(directory, "{}.key".format(name))
        try:
            public, _ = zmq.auth.load_certificate(filePath)
            return public
        except (ValueError, IOError) as ex:
            raise KeyError from ex

    @staticmethod
    def loadSecKeyFromDisk(directory, name):
        """Load the secret part of ``<directory>/<name>.key_secret``.

        :raises KeyError: when the certificate cannot be read or parsed
        """
        filePath = os.path.join(directory, "{}.key_secret".format(name))
        try:
            _, secret = zmq.auth.load_certificate(filePath)
            return secret
        except (ValueError, IOError) as ex:
            raise KeyError from ex

    @property
    def publicKey(self):
        """This stack's z85-encoded public key, read from disk."""
        return self.getPublicKey(self.name)

    @property
    def publicKeyRaw(self):
        """This stack's public key as raw bytes."""
        return z85.decode(self.publicKey)

    @property
    def pubhex(self):
        """This stack's public key, hex-encoded."""
        return hexlify(z85.decode(self.publicKey))

    def getPublicKey(self, name):
        """Load the public key filed under ``name``."""
        return self.loadPubKeyFromDisk(self.publicKeysDir, name)

    @property
    def verKey(self):
        """This stack's z85-encoded verification key, read from disk."""
        return self.getVerKey(self.name)

    @property
    def verKeyRaw(self):
        """This stack's verification key as raw bytes, or None if empty."""
        vk = self.verKey
        if vk:
            return z85.decode(vk)
        return None

    @property
    def verhex(self):
        """This stack's verification key hex-encoded, or None if empty."""
        vk = self.verKey
        if vk:
            return hexlify(z85.decode(vk))
        return None

    def getVerKey(self, name):
        """Load the verification key filed under ``name``."""
        return self.loadPubKeyFromDisk(self.verifKeyDir, name)

    @property
    def sigKey(self):
        """This stack's z85-encoded signing key, read from disk."""
        return self.loadSecKeyFromDisk(self.sigKeyDir, self.name)

    # TODO: Change name to sighex after removing test
    @property
    def keyhex(self):
        """This stack's signing key, hex-encoded."""
        return hexlify(z85.decode(self.sigKey))

    @property
    def priKey(self):
        """This stack's z85-encoded private encryption key, read from disk."""
        return self.loadSecKeyFromDisk(self.secretKeysDir, self.name)

    @property
    def prihex(self):
        """This stack's private encryption key, hex-encoded."""
        return hexlify(z85.decode(self.priKey))

    def getAllVerKeys(self):
        """Return the public part of every ``.key`` cert in the verif dir."""
        keys = []
        for key_file in os.listdir(self.verifKeyDir):
            if key_file.endswith(".key"):
                serverVerifFile = os.path.join(self.verifKeyDir, key_file)
                serverPublic, _ = zmq.auth.load_certificate(serverVerifFile)
                keys.append(serverPublic)
        return keys

    def setRestricted(self, restricted: bool):
        """Restart the stack when the restriction mode actually changes."""
        if self.isRestricted != restricted:
            print('{} setting restricted to {}'.
                  format(self, restricted))
            self.stop()

            # TODO: REMOVE, it will make code slow, only doing to allow the
            # socket to become available again
            time.sleep(1)

            self.start(restricted, reSetupAuth=True)

    def _safeRemove(self, filePath):
        """Delete ``filePath``; report failures instead of raising."""
        try:
            os.remove(filePath)
        except Exception as ex:
            print('{} could not delete file {} due to {}'.format(self, filePath, ex))

    def clearLocalRoleKeep(self):
        """Delete this stack's own key files from all four key dirs."""
        for d in (self.secretKeysDir, self.sigKeyDir):
            filePath = os.path.join(d, "{}.key_secret".format(self.name))
            self._safeRemove(filePath)

        for d in (self.publicKeysDir, self.verifKeyDir):
            filePath = os.path.join(d, "{}.key".format(self.name))
            self._safeRemove(filePath)

    def clearRemoteRoleKeeps(self):
        """Delete every key file that does not belong to this stack."""
        for d in (self.secretKeysDir, self.sigKeyDir):
            for key_file in os.listdir(d):
                if key_file != '{}.key_secret'.format(self.name):
                    self._safeRemove(os.path.join(d, key_file))

        for d in (self.publicKeysDir, self.verifKeyDir):
            for key_file in os.listdir(d):
                if key_file != '{}.key'.format(self.name):
                    self._safeRemove(os.path.join(d, key_file))

    def clearAllDir(self):
        """Remove the stack's home directory and everything in it."""
        shutil.rmtree(self.homeDir)

    def prepare_to_send(self, msg: Any):
        """Serialize ``msg`` to bytes for sending."""
        msg_bytes = self.serializeMsg(msg)
        return msg_bytes

    @staticmethod
    def get_monitor_events(monitor_socket, non_block=True):
        """Drain and return the events pending on ``monitor_socket``.

        :param non_block: when True, stop at the first read that would block
        :return: list of monitor messages in arrival order
        """
        events = []
        # noinspection PyUnresolvedReferences
        flags = zmq.NOBLOCK if non_block else 0
        while True:
            try:
                # noinspection PyUnresolvedReferences
                message = recv_monitor_message(monitor_socket, flags)
                events.append(message)
            except zmq.Again:
                break
        return events