from test_zmq.authenticator import MultiZapAuthenticator

try:
    import ujson as json
except ImportError:
    import json

import os
import shutil
import sys
import time
from binascii import hexlify, unhexlify
from collections import deque
from typing import Mapping, Tuple, Any, Union, Optional, NamedTuple, Set

import zmq.auth
from zmq.utils import z85
from zmq.utils.monitor import recv_monitor_message
import zmq

from .util import (createEncAndSigKeys, moveKeyFilesToCorrectLocations,
                   createCertsFromKeys)
from test_zmq.crypto.nacl_wrappers import Signer, Verifier
from test_zmq.crypto.util import isHex, ed25519PkToCurve25519

# Receive quota handed to ZStack.service(): a message count and a size
# budget (presumably bytes of payload -- confirm against callers).
Quota = NamedTuple("Quota", [("count", int), ("size", int)])

# Transport used when binding the ZStack listener socket.
ZMQ_NETWORK_PROTOCOL = 'tcp'


class ClientMessageProvider:
    """Send messages to peers through a ZStack's ROUTER listener socket.

    ``listener`` may be None at construction time; ZStack.open() assigns
    the socket once it has been created.
    """

    def __init__(self, name, prepare_to_send, mt_outgoing_size, listener=None):
        """
        :param name: name used in ``repr`` and error messages.
        :param prepare_to_send: callable converting an outgoing message to
            the bytes that are sent.
        :param mt_outgoing_size: stored as-is; not used by this class.
        :param listener: ROUTER socket to send through (may be set later).
        """
        self._name = name
        self.listener = listener
        self._prepare_to_send = prepare_to_send
        self._mt_outgoing_size = mt_outgoing_size

    def transmit_through_listener(self, msg, ident) -> Tuple[bool, Optional[str]]:
        """Send ``msg`` to the peer identified by ``ident``.

        :return: ``(True, None)`` on success, otherwise
            ``(False, <error description>)``.
        """
        result, error_msg, _need_to_resend = \
            self._transmit_one_msg_throughlistener(msg, ident)
        return result, error_msg

    def _transmit_one_msg_throughlistener(self, msg, ident) -> Tuple[bool, Optional[str], bool]:
        """Serialize ``msg`` and send it as ``[ident, msg]`` without blocking.

        :param ident: peer identity; a ``str`` is UTF-8 encoded first.
        :return: ``(success, error description or None, need_to_resend)``.
            ``need_to_resend`` is True when the send would block
            (``zmq.Again``) or the peer is currently unroutable
            (``EHOSTUNREACH``); other failures are not worth retrying.
        """
        def prepare_error_msg(ex):
            err_str = '{} got error {} while sending through listener to {}' \
                .format(self, ex, ident)
            print(err_str)
            return err_str

        need_to_resend = False
        if isinstance(ident, str):
            ident = ident.encode()
        try:
            msg = self._prepare_to_send(msg)
            self.listener.send_multipart([ident, msg], flags=zmq.NOBLOCK)
        except zmq.Again as ex:
            need_to_resend = True
            return False, prepare_error_msg(ex), need_to_resend
        except zmq.ZMQError as ex:
            # The listener has ROUTER_MANDATORY set, so an unknown or
            # disconnected peer is reported as EHOSTUNREACH. Compare against
            # the zmq constant rather than a hard-coded errno value, because
            # the numeric value differs between platforms.
            need_to_resend = (ex.errno == zmq.EHOSTUNREACH)
            return False, prepare_error_msg(ex), need_to_resend
        except Exception as ex:
            # Best effort: report any other failure (e.g. serialization
            # errors, or no listener yet) to the caller instead of raising.
            return False, prepare_error_msg(ex), need_to_resend
        return True, None, need_to_resend

    def __repr__(self):
        return self._name


# TODO: Use Async io
# TODO: There are a number of methods related to keys management; they can
#   be moved to some class like KeysManager.
class ZStack:
    """Message stack built around a single CURVE-secured ZeroMQ ROUTER listener.

    Key material is stored on disk under ``<basedirpath>/<name>/`` in four
    sub-directories (see ``keyDirNames``). Messages received on the listener
    are UTF-8 decoded and queued in ``rxMsgs``. ``processReceived`` then
    hands each one to ``msgHandler``.
    """

    # Assuming only one listener per stack for now.
    PublicKeyDirName = 'public_keys'
    PrivateKeyDirName = 'private_keys'
    VerifKeyDirName = 'verif_keys'
    SigKeyDirName = 'sig_keys'

    # Length of the signature that signedMsg() appends and verify() strips.
    sigLen = 64

    pingMessage = 'pi'
    pongMessage = 'po'
    healthMessages = {pingMessage.encode(), pongMessage.encode()}

    # TODO: This is not implemented, implement this
    messageTimeout = 3

    # Bind retry budget (seconds) used by open() when ``config`` is absent or
    # does not define MAX_WAIT_FOR_BIND_SUCCESS.
    DefaultMaxWaitForBindSuccess = 10

    def __init__(self, name, ha, basedirpath, msgHandler, restricted=True,
                 seed=None, onlyListener=False, config=None,
                 msgRejectHandler=None, queue_size=0,
                 create_listener_monitor=False,
                 mt_incoming_size=None, mt_outgoing_size=None, timer=None):
        """
        :param name: stack name; also the base name of its key files.
        :param ha: ``(ip, port)`` the listener binds to.
        :param basedirpath: directory containing ``<name>/`` key directories.
        :param msgHandler: called as ``msgHandler(stack, (msg, frm))`` for
            every processed message.
        :param restricted: if True, only CURVE clients whose public key is in
            ``publicKeysDir`` are accepted.
        :param seed: seed used to create this stack's keys if none exist.
        :param onlyListener: if True, ``send`` goes through the listener.
        :param config: optional object providing MAX_WAIT_FOR_BIND_SUCCESS.
        :param msgRejectHandler: called as ``(reason, frm)`` when an incoming
            message is discarded; defaults to a no-op.
        :param create_listener_monitor: create a monitor socket on open().
        :param timer: currently unused.
        """
        self._name = name
        self.ha = ha
        self.basedirpath = basedirpath
        self.msgHandler = msgHandler
        self.seed = seed
        self.queue_size = queue_size
        # Previously the argument was dropped, yet open() reads
        # self.config. Keep a value a subclass/class attribute already set.
        self.config = config if config is not None \
            else getattr(self, 'config', None)
        self.msgRejectHandler = msgRejectHandler or self.__defaultMsgRejectHandler
        self._node_mode = None
        self._stashed_unknown_remote_msgs = deque()
        self.mt_incoming_size = mt_incoming_size
        self.mt_outgoing_size = mt_outgoing_size

        self.homeDir = None
        # As of now there would be only one file in secretKeysDir and sigKeyDir
        self.publicKeysDir = None
        self.secretKeysDir = None
        self.verifKeyDir = None
        self.sigKeyDir = None

        self.signer = None
        self.verifiers = {}

        self.setupDirs()
        self.setupOwnKeysIfNeeded()
        self.setupSigning()

        self.restricted = restricted

        self.ctx = None  # type: Optional[zmq.Context]
        self.listener = None
        self.create_listener_monitor = create_listener_monitor
        self.listener_monitor = None
        self.auth = None

        # Each remote is identified uniquely by the name
        self._remotes = {}

        self.remotesByKeys = {}
        self.remote_ping_stats = {}

        # Indicates if this stack will maintain any remotes or will
        # communicate simply to listeners. Used in ClientZStack
        self.onlyListener = onlyListener

        self._conns = set()  # type: Set[str]

        self.rxMsgs = deque()
        self._created = time.perf_counter()

        self.last_heartbeat_at = None

        self._stashed_pongs = set()
        self._received_pings = set()
        self._client_message_provider = ClientMessageProvider(
            self.name, self.prepare_to_send, self.mt_outgoing_size)

    def __defaultMsgRejectHandler(self, reason: str, frm):
        """Default reject handler: ignore the rejection."""
        pass

    @property
    def remotes(self):
        return self._remotes

    @property
    def created(self):
        """``time.perf_counter()`` value captured at construction."""
        return self._created

    @property
    def name(self):
        return self._name

    def set_mode(self, value):
        self._node_mode = value

    @staticmethod
    def isRemoteConnected(r) -> bool:
        return r.isConnected

    @staticmethod
    def initLocalKeys(name, baseDir, sigseed, override=False):
        """Create encryption and signing key pairs for ``name`` from ``sigseed``.

        The keys are generated in temporary directories and moved into the
        key directories under ``baseDir/name``. ``override`` is currently
        unused.

        :return: ``(public key hex, verification key hex)`` as ``str``.
        """
        sDir = os.path.join(baseDir, '__sDir')
        eDir = os.path.join(baseDir, '__eDir')
        os.makedirs(sDir, exist_ok=True)
        os.makedirs(eDir, exist_ok=True)
        (public_key, secret_key), (verif_key, sig_key) = \
            createEncAndSigKeys(eDir, sDir, name, seed=sigseed)

        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        sigDirPath = ZStack.sigDirPath(homeDir)
        secretDirPath = ZStack.secretDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, sigDirPath, secretDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)

        moveKeyFilesToCorrectLocations(sDir, verifDirPath, sigDirPath)
        moveKeyFilesToCorrectLocations(eDir, pubDirPath, secretDirPath)

        shutil.rmtree(sDir)
        shutil.rmtree(eDir)
        return hexlify(public_key).decode(), hexlify(verif_key).decode()

    @staticmethod
    def initRemoteKeys(name, remoteName, baseDir, verkey, override=False):
        """Install ``remoteName``'s keys into the key directories of stack ``name``.

        Writes the verification key and the Curve25519 public key derived
        from it. A hex ``verkey`` is unhexlified first; otherwise it is used
        as-is. ``override`` is currently unused.
        """
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)

        if isHex(verkey):
            verkey = unhexlify(verkey)

        createCertsFromKeys(verifDirPath, remoteName, z85.encode(verkey))
        public_key = ed25519PkToCurve25519(verkey)
        createCertsFromKeys(pubDirPath, remoteName, z85.encode(public_key))

    def onHostAddressChanged(self):
        # we don't store remote data like ip, port, domain name, etc, so
        # nothing to do here
        pass

    @staticmethod
    def areKeysSetup(name, baseDir):
        """Return True if all four of ``name``'s own key files exist."""
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        sigDirPath = ZStack.sigDirPath(homeDir)
        secretDirPath = ZStack.secretDirPath(homeDir)
        for d in (verifDirPath, pubDirPath):
            if not os.path.isfile(os.path.join(d, '{}.key'.format(name))):
                return False
        for d in (sigDirPath, secretDirPath):
            if not os.path.isfile(os.path.join(d, '{}.key_secret'.format(name))):
                return False
        return True

    @staticmethod
    def keyDirNames():
        return ZStack.PublicKeyDirName, ZStack.PrivateKeyDirName, \
            ZStack.VerifKeyDirName, ZStack.SigKeyDirName

    @staticmethod
    def getHaFromLocal(name, basedirpath):
        return None

    def __repr__(self):
        return self.name

    @staticmethod
    def homeDirPath(baseDirPath, name):
        return os.path.join(os.path.expanduser(baseDirPath), name)

    @staticmethod
    def publicDirPath(homeDirPath):
        return os.path.join(homeDirPath, ZStack.PublicKeyDirName)

    @staticmethod
    def secretDirPath(homeDirPath):
        return os.path.join(homeDirPath, ZStack.PrivateKeyDirName)

    @staticmethod
    def verifDirPath(homeDirPath):
        return os.path.join(homeDirPath, ZStack.VerifKeyDirName)

    @staticmethod
    def sigDirPath(homeDirPath):
        return os.path.join(homeDirPath, ZStack.SigKeyDirName)

    @staticmethod
    def learnKeysFromOthers(baseDir, name, others):
        """Write each stack in ``others``' verification/public keys into ``name``'s dirs."""
        homeDir = ZStack.homeDirPath(baseDir, name)
        verifDirPath = ZStack.verifDirPath(homeDir)
        pubDirPath = ZStack.publicDirPath(homeDir)
        for d in (homeDir, verifDirPath, pubDirPath):
            os.makedirs(d, exist_ok=True)
        for other in others:
            createCertsFromKeys(verifDirPath, other.name, other.verKey)
            createCertsFromKeys(pubDirPath, other.name, other.publicKey)

    def tellKeysToOthers(self, others):
        """Write this stack's verification/public keys into each of ``others``' dirs."""
        for other in others:
            createCertsFromKeys(other.verifKeyDir, self.name, self.verKey)
            createCertsFromKeys(other.publicKeysDir, self.name, self.publicKey)

    def setupDirs(self):
        """Compute this stack's key directory paths and create them."""
        self.homeDir = self.homeDirPath(self.basedirpath, self.name)
        self.publicKeysDir = self.publicDirPath(self.homeDir)
        self.secretKeysDir = self.secretDirPath(self.homeDir)
        self.verifKeyDir = self.verifDirPath(self.homeDir)
        self.sigKeyDir = self.sigDirPath(self.homeDir)

        for d in (self.homeDir, self.publicKeysDir, self.secretKeysDir,
                  self.verifKeyDir, self.sigKeyDir):
            os.makedirs(d, exist_ok=True)

    def setupOwnKeysIfNeeded(self):
        """Generate this stack's keys from ``self.seed`` if no signing key exists."""
        if not os.listdir(self.sigKeyDir):
            # If signing keys are not present, secret (private keys) should
            # not be present since they should be converted keys.
            assert not os.listdir(self.secretKeysDir)
            # Seed should be present
            assert self.seed, 'Keys are not setup for {}'.format(self)
            print("Signing and Encryption keys were not found for {}. Creating them now".format(self))
            tdirS = os.path.join(self.homeDir, '__skeys__')
            tdirE = os.path.join(self.homeDir, '__ekeys__')
            os.makedirs(tdirS, exist_ok=True)
            os.makedirs(tdirE, exist_ok=True)
            createEncAndSigKeys(tdirE, tdirS, self.name, self.seed)
            moveKeyFilesToCorrectLocations(tdirE, self.publicKeysDir,
                                           self.secretKeysDir)
            moveKeyFilesToCorrectLocations(tdirS, self.verifKeyDir,
                                           self.sigKeyDir)
            shutil.rmtree(tdirE)
            shutil.rmtree(tdirS)

    def setupAuth(self, restricted=True, force=False):
        """Start a ZAP authenticator on ``self.ctx``.

        When ``restricted``, CURVE clients are accepted only if their public
        key is in ``publicKeysDir``; otherwise any CURVE client is accepted.

        :raises RuntimeError: if an authenticator exists and ``force`` is False.
        """
        if self.auth and not force:
            raise RuntimeError('Listener already setup')
        location = self.publicKeysDir if restricted else zmq.auth.CURVE_ALLOW_ANY
        self.auth = MultiZapAuthenticator(self.ctx)
        self.auth.start()
        self.auth.allow('0.0.0.0')
        self.auth.configure_curve(domain='*', location=location)

    def teardownAuth(self):
        if self.auth:
            self.auth.stop()

    def setupSigning(self):
        # Setup its signer from the signing key stored at disk and for all
        # verification keys stored at disk, add Verifier
        _, sk = self.selfSigKeys
        self.signer = Signer(z85.decode(sk))
        for vk in self.getAllVerKeys():
            self.addVerifier(vk)

    def addVerifier(self, verkey):
        """Register a Verifier for the z85-encoded ``verkey``."""
        self.verifiers[verkey] = Verifier(z85.decode(verkey))

    def start(self, restricted=None, reSetupAuth=True):
        """Create a new zmq Context, set up authentication and open the listener."""
        self.ctx = zmq.Context()
        restricted = self.restricted if restricted is None else restricted
        print('{} starting with restricted as {} and reSetupAuth '
              'as {}'.format(self, restricted, reSetupAuth))
        self.setupAuth(restricted, force=reSetupAuth)
        self.open()

    def stop(self):
        if self.opened:
            print('stack {} closing its listener'.format(self))
            self.close()
        print("stack {} stopped".format(self))

    @property
    def opened(self):
        return self.listener is not None

    def open(self):
        """Create the CURVE-server ROUTER listener and bind it to ``self.ha``.

        Binding is retried every 0.2 s, to ride out "Address already in use"
        right after a restart. Once the accumulated wait exceeds
        ``config.MAX_WAIT_FOR_BIND_SUCCESS`` (or DefaultMaxWaitForBindSuccess
        if unavailable), the last ZMQError is re-raised.
        """
        # noinspection PyUnresolvedReferences
        self.listener = self.ctx.socket(zmq.ROUTER)
        self._client_message_provider.listener = self.listener
        # Report unroutable messages as EHOSTUNREACH instead of dropping them.
        self.listener.setsockopt(zmq.ROUTER_MANDATORY, 1)
        # A reconnecting peer with an already-used identity takes over.
        self.listener.setsockopt(zmq.ROUTER_HANDOVER, 1)
        if self.create_listener_monitor:
            self.listener_monitor = self.listener.get_monitor_socket()
        public, secret = self.selfEncKeys
        self.listener.curve_secretkey = secret
        self.listener.curve_publickey = public
        self.listener.curve_server = True
        self.listener.identity = self.publicKey
        print('{} will bind its listener at {}:{}'.format(self, self.ha[0], self.ha[1]))
        self.listener.setsockopt(zmq.TCP_KEEPALIVE, 1)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_INTVL, 1)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_IDLE, 20)
        self.listener.setsockopt(zmq.TCP_KEEPALIVE_CNT, 10)
        # A high-water mark of 0 means no limit on queued messages.
        self.listener.set_hwm(0)

        max_wait = getattr(self.config, 'MAX_WAIT_FOR_BIND_SUCCESS',
                           self.DefaultMaxWaitForBindSuccess)
        # Cycle to deal with "Address already in use" in case of immediate stack restart.
        bound = False
        sleep_between_bind_retries = 0.2
        bind_retry_time = 0
        while not bound:
            try:
                self.listener.bind(
                    '{protocol}://{ip}:{port}'.format(ip=self.ha[0], port=self.ha[1],
                                                      protocol=ZMQ_NETWORK_PROTOCOL)
                )
                bound = True
            except zmq.error.ZMQError:
                print("{} can not bind to {}:{}. Will try in {} secs.".
                      format(self, self.ha[0], self.ha[1], sleep_between_bind_retries))
                bind_retry_time += sleep_between_bind_retries
                if bind_retry_time > max_wait:
                    raise
                time.sleep(sleep_between_bind_retries)

    def close(self):
        """Close the listener, disconnect every remote and stop authentication."""
        if self.listener_monitor is not None:
            self.listener.disable_monitor()
            # disable_monitor() stops monitoring but does not close the PAIR
            # socket returned by get_monitor_socket(); close it explicitly.
            self.listener_monitor.close(linger=0)
            self.listener_monitor = None
        self.listener.unbind(self.listener.LAST_ENDPOINT)
        self.listener.close(linger=0)
        self.listener = None
        print('{} starting to disconnect remotes'.format(self))
        for r in self.remotes.values():
            r.disconnect()
            self.remotesByKeys.pop(r.publicKey, None)

        self._remotes = {}
        if self.remotesByKeys:
            print('{} found remotes that were only in remotesByKeys and '
                  'not in remotes. This is suspicious'.format(self))
            for r in self.remotesByKeys.values():
                r.disconnect()
            self.remotesByKeys = {}
        self._conns = set()
        self.teardownAuth()

    @property
    def selfEncKeys(self):
        """``(public, secret)`` encryption keys loaded from the secret key file."""
        serverSecretFile = os.path.join(self.secretKeysDir,
                                        "{}.key_secret".format(self.name))
        return zmq.auth.load_certificate(serverSecretFile)

    @property
    def selfSigKeys(self):
        """``(verification, signing)`` keys loaded from the signing key file."""
        serverSecretFile = os.path.join(self.sigKeyDir,
                                        "{}.key_secret".format(self.name))
        return zmq.auth.load_certificate(serverSecretFile)

    @property
    def isRestricted(self):
        """Whether the active authenticator restricts clients.

        Falls back to ``self.restricted`` while no authenticator exists.
        """
        return not self.auth.allow_any if self.auth is not None \
            else self.restricted

    @property
    def isKeySharing(self):
        # TODO: Change name after removing test
        return not self.isRestricted

    def getHa(self, name):
        # Return HA as None when its a `peersWithoutRemote`
        if self.onlyListener:
            return None
        # Relies on a cooperating class later in the MRO defining getHa;
        # ZStack's only explicit base is object, which does not.
        return super().getHa(name)

    async def service(self, limit=None, quota: Optional[Quota] = None) -> int:
        """
        Service `limit` number of received messages in this stack.

        :param limit: the maximum number of messages to be processed. If None
            (or 0), processes all of the messages in rxMsgs.
        :param quota: receive budget forwarded to _receiveFromListener; None
            attempts a single receive.
        :return: the number of messages processed.
        """
        if self.listener:
            await self._serviceStack(quota)
        else:
            print("{} is stopped".format(self))

        r = len(self.rxMsgs)
        if r > 0:
            pracLimit = limit if limit else sys.maxsize
            return self.processReceived(pracLimit)
        return 0

    def _receiveFromListener(self, quota: Optional[Quota]) -> int:
        """
        Receives messages from listener without blocking.

        :param quota: receive budget. When given, receiving continues until
            ``quota.count`` messages were received, the summed message
            lengths reach ``quota.size``, or nothing more is pending. When
            None, a single receive is attempted (the previous behaviour).
        :return: number of (non-empty) messages received
        """
        received = 0
        incoming_size = 0
        while quota is None or (received < quota.count and
                                incoming_size < quota.size):
            try:
                ident, msg = self.listener.recv_multipart(flags=zmq.NOBLOCK)
            except zmq.Again:
                # Nothing (more) to read right now.
                break
            except zmq.ZMQError as e:
                print("Strange ZMQ behaviour during node-to-node message receiving, experienced {}".format(e))
                break
            if msg:  # Router probing sends empty message on connection
                incoming_size += len(msg)
                received += 1
                self._verifyAndAppend(msg, ident)
            if quota is None:
                break
        if received > 0:
            print('{} got {} messages through listener'.
                  format(self, received))
        return received

    def _verifyAndAppend(self, msg, ident):
        """Queue ``(msg decoded as UTF-8, ident)`` in rxMsgs.

        :return: False (nothing queued) if ``ident`` or ``msg`` is not valid
            UTF-8; undecodable messages are also passed to msgRejectHandler.
        """
        try:
            ident.decode()
        except ValueError:
            print("Identifier {} is not decoded into UTF-8 string. "
                  "Request will not be processed".format(ident))
            return False
        try:
            decoded = msg.decode()
        except UnicodeDecodeError as ex:
            errstr = 'Message will be discarded due to {}'.format(ex)
            frm = self.remotesByKeys[ident].name if ident in self.remotesByKeys else ident
            print("Got from {} {}".format(frm, errstr))
            self.msgRejectHandler(errstr, frm)
            return False
        self.rxMsgs.append((decoded, ident))
        return True

    async def _serviceStack(self, quota: Optional[Quota] = None):
        self._receiveFromListener(quota)
        return len(self.rxMsgs)

    def processReceived(self, limit):
        """Hand up to ``limit`` queued messages to ``msgHandler``.

        The sender is reported by remote name when ``ident`` is a known
        remote key, otherwise by the raw identity.

        :return: number of messages processed.
        """
        if limit <= 0:
            return 0
        num_processed = 0
        for num_processed in range(limit):
            if len(self.rxMsgs) == 0:
                return num_processed

            msg, ident = self.rxMsgs.popleft()
            frm = self.remotesByKeys[ident].name \
                if ident in self.remotesByKeys else ident

            self.msgHandler(self, (msg, frm))

        return num_processed + 1

    def doProcessReceived(self, msg, frm, ident):
        return msg

    def send(self, msg: Any, remoteName: str = None, ha=None):
        """Send ``msg`` through the listener to ``remoteName``.

        Only implemented for listener-only stacks; otherwise returns None.
        """
        if self.onlyListener:
            return self._client_message_provider.transmit_through_listener(msg, remoteName)

    def transmit(self, msg, uid, timeout=None, serialized=False, is_batch=False):
        """Send ``msg`` on the socket of remote ``uid`` without blocking.

        :return: ``(True, None)`` if sent; ``(False, None)`` if ``uid`` has
            no remote/socket or the send would block.
        """
        remote = self.remotes.get(uid)
        err_str = None
        if not remote:
            return False, err_str
        socket = remote.socket
        if not socket:
            return False, err_str
        try:
            if not serialized:
                msg = self.prepare_to_send(msg)
            print('{} transmitting message {} to {} by socket {} {}'
                  .format(self, msg, uid, socket.FD, socket.underlying))
            socket.send(msg, flags=zmq.NOBLOCK)
            return True, err_str
        except zmq.Again:
            print('{} could not transmit message to {}'.format(self, uid))
        return False, err_str

    @staticmethod
    def serializeMsg(msg):
        """Mapping -> JSON str -> UTF-8 bytes; bytes pass through unchanged."""
        if isinstance(msg, Mapping):
            msg = json.dumps(msg)
        if isinstance(msg, str):
            msg = msg.encode()
        assert isinstance(msg, bytes)
        return msg

    @staticmethod
    def deserializeMsg(msg):
        if isinstance(msg, bytes):
            msg = msg.decode()
        msg = json.loads(msg)
        return msg

    def signedMsg(self, msg: bytes, signer: Signer = None):
        """Return ``msg`` followed by its signature.

        :param signer: signer to use; defaults to this stack's own signer.
        """
        if signer is None:
            signer = self.signer
        sig = signer.signature(msg)
        return msg + sig

    def verify(self, msg, by):
        """Check the trailing ``sigLen``-long signature of ``msg`` from remote key ``by``.

        Always True when key sharing is enabled; False for unknown senders.
        """
        if self.isKeySharing:
            return True
        if by not in self.remotesByKeys:
            return False
        verKey = self.remotesByKeys[by].verKey
        r = self.verifiers[verKey].verify(
            msg[-self.sigLen:], msg[:-self.sigLen])
        return r

    @staticmethod
    def loadPubKeyFromDisk(directory, name):
        """:raises KeyError: if ``<name>.key`` is missing or unreadable."""
        filePath = os.path.join(directory, "{}.key".format(name))
        try:
            public, _ = zmq.auth.load_certificate(filePath)
            return public
        except (ValueError, IOError) as ex:
            raise KeyError from ex

    @staticmethod
    def loadSecKeyFromDisk(directory, name):
        """:raises KeyError: if ``<name>.key_secret`` is missing or unreadable."""
        filePath = os.path.join(directory, "{}.key_secret".format(name))
        try:
            _, secret = zmq.auth.load_certificate(filePath)
            return secret
        except (ValueError, IOError) as ex:
            raise KeyError from ex

    @property
    def publicKey(self):
        return self.getPublicKey(self.name)

    @property
    def publicKeyRaw(self):
        return z85.decode(self.publicKey)

    @property
    def pubhex(self):
        return hexlify(z85.decode(self.publicKey))

    def getPublicKey(self, name):
        return self.loadPubKeyFromDisk(self.publicKeysDir, name)

    @property
    def verKey(self):
        return self.getVerKey(self.name)

    @property
    def verKeyRaw(self):
        if self.verKey:
            return z85.decode(self.verKey)
        return None

    @property
    def verhex(self):
        if self.verKey:
            return hexlify(z85.decode(self.verKey))
        return None

    def getVerKey(self, name):
        return self.loadPubKeyFromDisk(self.verifKeyDir, name)

    @property
    def sigKey(self):
        return self.loadSecKeyFromDisk(self.sigKeyDir, self.name)

    # TODO: Change name to sighex after removing test
    @property
    def keyhex(self):
        return hexlify(z85.decode(self.sigKey))

    @property
    def priKey(self):
        return self.loadSecKeyFromDisk(self.secretKeysDir, self.name)

    @property
    def prihex(self):
        return hexlify(z85.decode(self.priKey))

    def getAllVerKeys(self):
        """Return the public key stored in every ``*.key`` file in verifKeyDir."""
        keys = []
        for key_file in os.listdir(self.verifKeyDir):
            if key_file.endswith(".key"):
                serverVerifFile = os.path.join(self.verifKeyDir, key_file)
                serverPublic, _ = zmq.auth.load_certificate(serverVerifFile)
                keys.append(serverPublic)
        return keys

    def setRestricted(self, restricted: bool):
        """Restart the stack if its restriction mode differs from ``restricted``."""
        if self.isRestricted != restricted:
            print('{} setting restricted to {}'.
                  format(self, restricted))
            self.stop()

            # TODO: REMOVE, it will make code slow, only doing to allow the
            # socket to become available again
            time.sleep(1)

            self.start(restricted, reSetupAuth=True)

    def _safeRemove(self, filePath):
        """Delete ``filePath``, logging (not raising) on failure."""
        try:
            os.remove(filePath)
        except Exception as ex:
            print('{} could not delete file {} due to {}'.format(self, filePath, ex))

    def clearLocalRoleKeep(self):
        """Delete this stack's own four key files."""
        for d in (self.secretKeysDir, self.sigKeyDir):
            filePath = os.path.join(d, "{}.key_secret".format(self.name))
            self._safeRemove(filePath)

        for d in (self.publicKeysDir, self.verifKeyDir):
            filePath = os.path.join(d, "{}.key".format(self.name))
            self._safeRemove(filePath)

    def clearRemoteRoleKeeps(self):
        """Delete every key file in the key dirs except this stack's own."""
        for d in (self.secretKeysDir, self.sigKeyDir):
            for key_file in os.listdir(d):
                if key_file != '{}.key_secret'.format(self.name):
                    self._safeRemove(os.path.join(d, key_file))

        for d in (self.publicKeysDir, self.verifKeyDir):
            for key_file in os.listdir(d):
                if key_file != '{}.key'.format(self.name):
                    self._safeRemove(os.path.join(d, key_file))

    def clearAllDir(self):
        shutil.rmtree(self.homeDir)

    def prepare_to_send(self, msg: Any):
        msg_bytes = self.serializeMsg(msg)
        return msg_bytes

    @staticmethod
    def get_monitor_events(monitor_socket, non_block=True):
        """Collect monitor events from ``monitor_socket`` until zmq.Again.

        With ``non_block=False`` each receive blocks, so the loop only ends
        if the socket raises zmq.Again (e.g. a receive timeout is set).
        """
        events = []
        # noinspection PyUnresolvedReferences
        flags = zmq.NOBLOCK if non_block else 0
        while True:
            try:
                # noinspection PyUnresolvedReferences
                message = recv_monitor_message(monitor_socket, flags)
                events.append(message)
            except zmq.Again:
                break
        return events