""" Core agent classes. """ import contextlib import errno import inspect import json import multiprocessing import os import pickle import signal import sys import time import types from datetime import datetime from typing import Any from typing import Dict from typing import Union import cloudpickle import dill import Pyro4 import zmq from Pyro4.errors import PyroError from . import config from .address import AgentAddress from .address import AgentAddressKind from .address import AgentAddressSerializer from .address import AgentChannel from .address import address_to_host_port from .address import guess_kind from .common import LogLevel from .common import after from .common import format_exception from .common import format_method_exception from .common import get_linger from .common import repeat from .common import topic_to_bytes from .common import topics_to_bytes from .common import unbound_method from .common import unique_identifier from .common import validate_handler from .proxy import NSProxy from .proxy import Proxy TOPIC_SEPARATOR = b'\x80' _serialize_calls = { 'pickle': lambda message: pickle.dumps(message, -1), 'cloudpickle': lambda message: cloudpickle.dumps(message, -1), 'dill': lambda message: dill.dumps(message, -1), 'json': lambda message: json.dumps(message).encode(), 'raw': lambda message: message, } _deserialize_calls = { 'pickle': lambda message: pickle.loads(message), 'cloudpickle': lambda message: cloudpickle.loads(message), 'dill': lambda message: dill.loads(message), 'json': lambda message: json.loads(bytes(message).decode()), 'raw': lambda message: message, } def serialize_message(message, serializer): """ Check if a message needs to be serialized and do it if that is the case. Parameters ---------- message : anything The message to serialize. serializer : AgentAddressSerializer The type of serializer that should be used. Returns ------- bytes The serialized message, or the same message in case no serialization is needed. 
""" try: return _serialize_calls[serializer](message) except KeyError: raise ValueError('"%s" not supported for serialization' % serializer) def deserialize_message(message, serializer): """ Check if a message needs to be deserialized and do it if that is the case. Parameters ---------- message : bytes, memoryview The serialized message. serializer : AgentAddressSerializer The type of (de)serializer that should be used. Returns ------- anything The deserialized message, or the same message in case no deserialization is needed. """ try: return _deserialize_calls[serializer](message) except KeyError: raise ValueError('"%s" not supported for deserialization' % serializer) def compose_message( message: bytes, topic: bytes, serializer: AgentAddressSerializer ) -> bytes: """ Compose a message and leave it ready to be sent through a socket. This is used in PUB-SUB patterns to combine the topic and the message in a single bytes buffer. Parameters ---------- message Message to be composed. topic Topic to combine the message with. serializer Serialization for the message part. Returns ------- The bytes representation of the final message to be sent. """ if serializer.requires_separator: return topic + TOPIC_SEPARATOR + message return topic + message def execute_code_after_yield(generator): """ Some responses are dispatched with yield (generator handler). In those cases we still want to execute the remaining code in the generator, and also make sure it does not yield any more. Parameters ---------- generator The handler that already yielded one result and is not expected to yield again. Raises ------ ValueError If the generator yielded once more, which is unexpected. """ try: next(generator) except StopIteration: pass else: raise ValueError('Reply handler yielded more than once!') class Agent: """ A base agent class which is to be served by an AgentProcess. An AgentProcess runs a Pyro multiplexed server and serves one Agent object. 
    Parameters
    ----------
    name : str, default is None
        Name of the Agent.
    host : str, default is None
        Host address where the agent will bind to. When not set,
        ``'127.0.0.1'`` (localhost) is used.
    transport : str, AgentAddressTransport, default is None
        Transport protocol.
    attributes : dict, default is None
        A dictionary that defines initial attributes for the agent.

    Attributes
    ----------
    name : str
        Name of the agent.
    _host : str
        Host address where the agent is binding to.
    _uuid : bytes
        Globally unique identifier for the agent.
    _running : bool
        Set to ``True`` if the agent is running (executing the main
        loop).
    _serializer : str
        Default agent serialization format.
    _transport : str, AgentAddressTransport, default is None
        Default agent transport protocol.
    _socket : dict
        A dictionary in which the key is the address or the alias and
        the value is the actual socket.
    _address : dict
        A dictionary in which the key is the address or the alias and
        the value is the actual address.
    _handler : dict
        A dictionary in which the key is the socket and the values are
        the handlers for each socket.
    _context : zmq.Context
        ZMQ context to create ZMQ socket objects.
    _poller : zmq.Poller
        ZMQ poller to wait for incoming data from sockets.
    _poll_timeout : int
        Polling timeout, in milliseconds. After this timeout, if no
        message is received, the agent executes de `idle()` method
        before going back to polling.
    _keep_alive : bool
        When set to ``True``, the agent will continue executing the
        main loop.
    _async_req_uuid : dict
        Stores the UUIDs of the asynchronous request sockets (used in
        communication channels).
    _async_req_handler : dict
        Stores the handler for every asynchronous request sockets (used
        in communication channels).
    _die_now : bool
        During shutdown, this attribute is set for the agent to die.
    _DEBUG : bool
        Whether to print debug level messages.
    _pending_requests : dict
        Stores pending (waiting for reply) asynchronous requests. The
        asynchronous request UUID is used as key and its handler as the
        value.
    _timer : dict
        Stores all the current active timers, using their aliases as
        keys.
    """

    def __init__(
        self,
        name='',
        host=None,
        serializer=None,
        transport=None,
        attributes=None,
    ):
        self._uuid = unique_identifier()
        self.name = name
        if not self.name:
            # Fall back to the unique identifier as the agent name.
            self.name = self._uuid.decode()
        self._host = host
        if not self._host:
            self._host = '127.0.0.1'
        self._serializer = serializer
        self._transport = transport
        self._socket = {}
        self._address = {}
        self._handler = {}
        self._async_req_uuid = {}
        self._async_req_handler = {}
        self._pending_requests = {}
        self._timer = {}
        self._poll_timeout = 1000
        self._keep_alive = True
        self._die_now = False
        self._running = False
        self._DEBUG = False

        self._context = zmq.Context()
        self._poller = zmq.Poller()

        self._set_attributes(attributes)
        self.on_init()

    def _set_attributes(self, attributes):
        """
        Set initial attributes on the agent, refusing to overwrite any
        attribute that already exists.
        """
        if not attributes:
            return
        for key, value in attributes.items():
            if hasattr(self, key):
                raise KeyError('Agent already has "%s" attribute!' % key)
            setattr(self, key, value)

    def on_init(self):
        """
        This user-defined method is to be executed after initialization.
        """
        pass

    def before_loop(self):
        """
        This user-defined method is to be executed right before the main
        loop.
        """
        pass

    def _handle_loopback(self, message):
        """
        Handle incoming messages in the loopback socket.
        """
        header, data = cloudpickle.loads(message)
        if header == 'EXECUTE_METHOD':
            method, args, kwargs = data
            try:
                response = getattr(self, method)(*args, **kwargs)
            except Exception as error:
                yield format_method_exception(error, method, args, kwargs)
                raise
            # Yield something truthy even when the method returned
            # `None`, so the caller always receives a reply.
            yield response or True
        else:
            error = 'Unrecognized loopback message: {} {}'.format(header, data)
            self.log_error(error)
            yield error

    def _handle_loopback_safe(self, data):
        """
        Handle incoming messages in the _loopback_safe socket.
""" method, args, kwargs = cloudpickle.loads(data) try: response = getattr(self, method)(*args, **kwargs) except Exception as error: yield format_method_exception(error, method, args, kwargs) raise yield response def _loopback_reqrep(self, socket, data_to_send): """ Create a temporary connection a loopback socket and send a request. Returns ------- Response obtained from the loopback socket. """ loopback = self._context.socket(zmq.REQ) try: loopback.connect(socket) loopback.send_pyobj(data_to_send) return loopback.recv_pyobj() except zmq.error.ContextTerminated: pass finally: loopback.close(linger=0) def _loopback(self, header, data=None): """ Send a message to the loopback socket. """ if not self._running: raise NotImplementedError() data = cloudpickle.dumps((header, data)) return self._loopback_reqrep('inproc://loopback', data) def safe_call(self, method, *args, **kwargs): """ A safe call to a method. A safe call is simply sent to be executed by the main thread. Parameters ---------- method : str Method name to be executed by the main thread. *args : arguments Method arguments. *kwargs : keyword arguments Method keyword arguments. """ if not self._running: raise RuntimeError( 'Agent must be running to safely execute methods!' ) data = cloudpickle.dumps((method, args, kwargs)) return self._loopback_reqrep('inproc://_loopback_safe', data) def each(self, period, method, *args, alias=None, **kwargs): """ Execute a repeated action with a defined period. Parameters ---------- period : float Repeat the action execution with a delay of ``period`` seconds between executions. method Method (action) to be executed by the agent. alias : str, default is None An alias for the generated timer. *args : tuple Parameters to pass for the method execution. **kwargs : dict Named parameters to pass for the method execution. Returns ------- str The timer alias or identifier. 
""" if not isinstance(method, str): method = self.set_method(method) timer = repeat( period, self._loopback, 'EXECUTE_METHOD', (method, args, kwargs) ) if not alias: alias = unique_identifier() self._timer[alias] = timer return alias def after(self, delay, method, *args, alias=None, **kwargs): """ Execute an action after a delay. Parameters ---------- delay : float Execute the action after ``delay`` seconds. method Method (action) to be executed by the agent. alias : str, default is None An alias for the generated timer. *args : tuple Parameters to pass for the method execution. **kwargs : dict Named parameters to pass for the method execution. Returns ------- str The timer alias or identifier. """ if not isinstance(method, str): method = self.set_method(method) timer = after( delay, self._loopback, 'EXECUTE_METHOD', (method, args, kwargs) ) if not alias: alias = unique_identifier() self._timer[alias] = timer return alias def stop_all_timers(self): """ Stop all currently running timers. """ for alias in self.list_timers(): self.stop_timer(alias) def stop_timer(self, alias): """ Stop a currently running timer. Parameters ---------- alias : str The alias or identifier of the timer. """ self._timer[alias].stop() del self._timer[alias] def list_timers(self): """ Return a list with all timer aliases currently running. Returns ------- list (str) A list with all the timer aliases currently running. """ return list(self._timer.keys()) def raise_exception(self): """ Raise an exception (for testing purposes). """ raise RuntimeError('User raised an exception!') def stop(self): """ Stop the agent. Agent will stop running. """ self.log_debug('Stopping...') self._keep_alive = False return 'OK' def set_logger(self, logger, alias='_logger'): """ Connect the agent to a logger and start logging messages to it. 
""" if isinstance(logger, Proxy): logger = logger.addr('sub') if not isinstance(logger, AgentAddress): raise ValueError('An AgentAddress must be provided for logging!') self.connect(logger, alias=alias) def _log_message(self, level, message, logger='_logger'): """ Log a message. Parameters ---------- level : LogLevel Logging severity level: INFO, WARNING, ERROR, DEBUG. message : str Message to log. logger : str Alias of the logger. """ level = LogLevel(level) message = '[%s] (%s): %s' % (datetime.utcnow(), self.name, message) if self._registered(logger): logger_kind = AgentAddressKind(self._address[logger].kind) assert ( logger_kind == 'PUB' ), 'Logger must use publisher-subscriber pattern!' self.send(logger, message, topic=level) elif level in ('INFO', 'DEBUG'): sys.stdout.write('%s %s\n' % (level, message)) sys.stdout.flush() # When logging an error, always write to stderr if level == 'ERROR': sys.stderr.write('ERROR %s\n' % message) sys.stderr.flush() # When logging a warning, always write to stdout elif level == 'WARNING': sys.stdout.write('WARNING %s\n' % message) sys.stdout.flush() def log_error(self, message, logger='_logger'): """ Log an error message. Parameters ---------- message : str Message to log. logger : str Alias of the logger. """ self._log_message('ERROR', message, logger) def log_warning(self, message, logger='_logger'): """ Log a warning message. Parameters ---------- message : str Message to log. logger : str Alias of the logger. """ self._log_message('WARNING', message, logger) def log_info(self, message, logger='_logger'): """ Log an info message. Parameters ---------- message : str Message to log. logger : str Alias of the logger. """ self._log_message('INFO', message, logger) def log_debug(self, message, logger='_logger'): """ Log a debug message. Parameters ---------- message : str Message to log. logger : str Alias of the logger. 
""" # Ignore DEBUG logs if not `self._DEBUG` if not self._DEBUG: return self._log_message('DEBUG', message, logger) def addr(self, alias): """ Return the address of a socket given by its alias. Parameters ---------- alias : str Alias of the socket whose address is to be retrieved. Returns ------- AgentAddress Address of the agent socket associated with the alias. """ return self._address[alias] def _register(self, socket, address, alias=None, handler=None): """ Internally register a connection as a socket-address pair. Parameters ---------- socket : zmq.Socket The socket object to store. address : str, AgentAddress The address the socket is bound to. alias : str, default is None Optional alias for the connection. handler : function(s) Optional handler(s) for the socket. This can be a list or a dictionary too. """ assert not self._registered(address), 'Socket is already registered!' if not alias: alias = address self._socket[alias] = socket self._socket[address] = socket self._socket[socket] = socket self._address[alias] = address self._address[socket] = address self._address[address] = address if handler is not None: self._poller.register(socket, zmq.POLLIN) if address.kind in ('SUB', 'SYNC_SUB'): self.subscribe(socket, handler) else: self._set_handler(socket, handler) def _set_handler(self, socket, handler, update=False): """ Set the socket handler(s). Parameters ---------- socket : zmq.Socket Socket to set its handler(s). handler : function(s) Handler(s) for the socket. This can be a list or a dictionary too. 
""" if update: try: self._handler[socket].update(self._curated_handlers(handler)) except KeyError: self._handler[socket] = self._curated_handlers(handler) else: self._handler[socket] = self._curated_handlers(handler) def _curated_handlers(self, handler): if isinstance(handler, (list, tuple)): return [self._curate_handler(h) for h in handler] if isinstance(handler, dict): return {k: self._curate_handler(v) for k, v in handler.items()} return self._curate_handler(handler) def _curate_handler(self, handler): if isinstance(handler, str): handler = getattr(self, handler) function_type = (types.FunctionType, types.BuiltinFunctionType) if isinstance(handler, function_type): return handler method_type = (types.MethodType, types.BuiltinMethodType) if isinstance(handler, method_type): return unbound_method(handler) raise TypeError('Unknown handler type "%s"' % type(handler)) def _registered(self, address): """ Check if an address is already registered. """ return address in self._socket def bind( self, kind, alias=None, handler=None, addr=None, transport=None, serializer=None, ): """ Bind to an agent address. Parameters ---------- kind : str, AgentAddressKind The agent address kind: PUB, REQ... alias : str, default is None Optional alias for the socket. handler, default is None If the socket receives input messages, the handler/s is/are to be set with this parameter. addr : str, default is None The address to bind to. transport : str, AgentAddressTransport, default is None Transport protocol. Returns ------- AgentAddress The address where the agent binded to. 
""" kind = guess_kind(kind) transport = transport or self._transport or config['TRANSPORT'] serializer = serializer or self._serializer or config['SERIALIZER'] if isinstance(kind, AgentAddressKind): return self._bind_address( kind, alias, handler, addr, transport, serializer ) else: return self._bind_channel( kind, alias, handler, addr, transport, serializer ) def _bind_address( self, kind, alias=None, handler=None, addr=None, transport=None, serializer=None, ): """ Bind to an agent address. Parameters ---------- kind : str, AgentAddressKind The agent address kind: PUB, REQ... alias : str, default is None Optional alias for the socket. handler, default is None If the socket receives input messages, the handler/s is/are to be set with this parameter. addr : str, default is None The address to bind to. transport : str, AgentAddressTransport, default is None Transport protocol. Returns ------- AgentAddress The address where the agent binded to. """ validate_handler(handler, required=kind.requires_handler()) socket = self._context.socket(kind.zmq()) addr = self._bind_socket(socket, addr=addr, transport=transport) server_address = AgentAddress( transport, addr, kind, 'server', serializer ) self._register(socket, server_address, alias, handler) # SUB sockets are a special case if kind == 'SUB': self.subscribe(server_address, handler) return server_address def _bind_channel( self, kind, alias=None, handler=None, addr=None, transport=None, serializer=None, ): """ Bind process for channels. Parameters ---------- kind : str, AgentAddressKind The agent address kind: PUB, REQ... alias : str, default is None Optional alias for the socket. handler, default is None If the socket receives input messages, the handler/s is/are to be set with this parameter. addr : str, default is None The address to bind to. transport : str, AgentAddressTransport, default is None Transport protocol. Returns ------- AgentChannel The channel where the agent binded to. 
""" if kind == 'ASYNC_REP': validate_handler(handler, required=True) socket = self._context.socket(zmq.PULL) addr = self._bind_socket(socket, addr=addr, transport=transport) server_address = AgentAddress( transport, addr, 'PULL', 'server', serializer ) channel = AgentChannel(kind, receiver=server_address, sender=None) self._register(socket, channel, alias, handler) return channel if kind == 'SYNC_PUB': if addr: raise NotImplementedError() if not addr: addr = (None, None) pull_address = self.bind( 'PULL_SYNC_PUB', addr=addr[0], handler=handler, transport=transport, serializer=serializer, ) pub_socket = self._context.socket(zmq.PUB) aux = self._bind_socket( pub_socket, addr=addr[1], transport=transport ) pub_address = AgentAddress( transport, aux, 'PUB', 'server', serializer ) channel = AgentChannel( kind, receiver=pull_address, sender=pub_address ) self._register(pub_socket, channel, alias=alias) return channel else: raise NotImplementedError('Unsupported channel kind %s!' % kind) def _bind_socket(self, socket, addr=None, transport=None): """ Bind a socket using the corresponding transport and address. Parameters ---------- socket : zmq.Socket Socket to bind. addr : str, default is None The address to bind to. transport : str, AgentAddressTransport, default is None Transport protocol. Returns ------- addr : str The address where the socket binded to. """ if transport == 'tcp': return self._bind_socket_tcp(socket, addr=addr) if not addr: addr = unique_identifier().decode('ascii') if transport == 'ipc': addr = config['IPC_DIR'] / addr socket.bind('%s://%s' % (transport, addr)) return addr def _bind_socket_tcp(self, socket, addr): """ Bind a socket using the TCP transport and corresponding address. Parameters ---------- socket : zmq.Socket Socket to bind. addr : str, default is None The address to bind to. Returns ------- addr : str The address where the socket binded to. 
""" host, port = address_to_host_port(addr) if not host: host = self._host if not port: uri = 'tcp://%s' % host port = socket.bind_to_random_port(uri) addr = host + ':' + str(port) else: socket.bind('tcp://%s' % (addr)) return addr def connect(self, server, alias=None, handler=None): """ Connect to a server agent address. Parameters ---------- server : AgentAddress Agent address to connect to. alias : str, default is None Optional alias for the new address. handler, default is None If the new socket receives input messages, the handler/s is/are to be set with this parameter. """ if isinstance(server, AgentAddress): return self._connect_address(server, alias=alias, handler=handler) else: return self._connect_channel(server, alias=alias, handler=handler) def _connect_address(self, server_address, alias=None, handler=None): """ Connect to a basic ZMQ agent address. Parameters ---------- server_address : AgentAddress Agent address to connect to. alias : str, default is None Optional alias for the new address. handler, default is None If the new socket receives input messages, the handler/s is/are to be set with this parameter. """ assert ( server_address.role == 'server' ), 'Incorrect address! A server address must be provided!' client_address = server_address.twin() validate_handler( handler, required=client_address.kind.requires_handler() ) if self._registered(client_address): self._connect_old(client_address, alias, handler) else: self._connect_and_register(client_address, alias, handler) if client_address.kind == 'SUB': if not alias: alias = client_address self.subscribe(alias, handler) return client_address def _connect_channel(self, channel, alias=None, handler=None): """ Connect to a server agent channel. Parameters ---------- channel : AgentChannel Agent channel to connect to. alias : str, default is None Optional alias for the new channel. handler, default is None If the new socket receives input messages, the handler/s is/are to be set with this parameter. 
""" kind = channel.kind if kind == 'ASYNC_REP': return self._connect_channel_async_rep( channel, handler=handler, alias=alias ) if kind == 'SYNC_PUB': return self._connect_channel_sync_pub( channel, handler=handler, alias=alias ) raise NotImplementedError('Unsupported channel kind %s!' % kind) def _connect_channel_async_rep(self, channel, handler, alias=None): """ Connect to a server agent ASYNC_REP channel. Parameters ---------- channel : AgentChannel Agent channel to connect to. alias : str, default is None Optional alias for the new channel. handler, default is None If the new socket receives input messages, the handler/s is/are to be set with this parameter. """ # Connect PUSH-PULL (asynchronous REQ-REP) pull_address = channel.receiver self._connect_address(pull_address, alias=alias, handler=None) if self._registered(channel): raise NotImplementedError('Tried to (re)connect a channel') self._connect_and_register( pull_address.twin(), alias=alias, register_as=channel ) # Create socket for receiving responses uuid = unique_identifier() addr = self.bind( 'PULL', alias=uuid, handler=self._handle_async_requests ) self._async_req_uuid[pull_address] = uuid self._async_req_uuid[pull_address.twin()] = uuid self._async_req_uuid[addr] = uuid self._async_req_handler[uuid] = handler def _connect_channel_sync_pub(self, channel, handler, alias=None): """ Connect to a server agent SYNC_PUB channel. Parameters ---------- channel : AgentChannel Agent channel to connect to. alias : str, default is None Optional alias for the new channel. handler, default is None If the new socket receives input messages, the handler/s is/are to be set with this parameter. 
""" # Connect PUSH-PULL (synchronous PUB-SUB) client_channel = channel.twin() self._connect_address(channel.receiver, alias=alias, handler=None) if self._registered(channel): raise NotImplementedError('Tried to (re)connect a channel') self._connect_and_register( client_channel.sender, alias=alias, register_as=client_channel ) # Create socket for receiving responses pub_address = channel.sender assert pub_address.kind == 'PUB' uuid = unique_identifier() topic_handlers = {} if isinstance(handler, dict): topic_handlers = topics_to_bytes(handler, uuid=channel.uuid) else: topic_handlers[channel.uuid] = handler topic_handlers[uuid] = self._handle_async_requests addr = self.connect(pub_address, alias=uuid, handler=topic_handlers) assert addr.kind == 'SUB' self._async_req_uuid[channel.receiver] = uuid self._async_req_uuid[client_channel.sender] = uuid self._async_req_uuid[addr] = uuid self._async_req_handler[uuid] = handler return client_channel def _connect_old(self, client_address, alias=None, handler=None): if handler is not None: raise NotImplementedError('Undefined behavior!') self._socket[alias] = self._socket[client_address] self._address[alias] = client_address return client_address def _connect_and_register( self, client_address, alias=None, handler=None, register_as=None ): """ Establish and register a new connection. Parameters ---------- client_address : AgentAddress The address to connect to. alias : str Optional alias for the connection. handler : function(s) Optional handler(s) for the socket. register_as What the socket should be registered as (usually an AgentAddress). """ if not register_as: register_as = client_address socket = self._context.socket(client_address.kind.zmq()) socket.connect( '%s://%s' % (client_address.transport, client_address.address) ) self._register(socket, register_as, alias, handler) return client_address def _handle_async_requests(self, data): """ Receive and process an async request. 
""" address_uuid, uuid, response = data if uuid not in self._pending_requests: error = 'Received response for an unknown request! %s' % uuid self.log_warning(error) return handler = self._pending_requests.pop(uuid) if isinstance(handler, str): handler = getattr(self, handler) handler(response) else: handler(self, response) def subscribe( self, alias: str, handler: Dict[Union[bytes, str], Any] ) -> None: """ Subscribe a SUB/SYNC_SUB socket given by its alias to the given topics, and leave the handlers prepared internally. Parameters ---------- alias Alias of the new subscriber socket. handler A dictionary in which the keys represent the different topics and the values the actual handlers. If, instead of a dictionary, a single handler is given, it will be used to subscribe the agent to any topic. """ if not isinstance(handler, dict): handler = {'': handler} curated_handlers = topics_to_bytes(handler) # Subscribe to topics for topic in curated_handlers: self._subscribe_to_topic(alias, topic) # Reset handlers if isinstance(self._address[alias], AgentChannel): channel = self._address[alias] sub_address = channel.receiver uuid = channel.twin_uuid curated_handlers = topics_to_bytes(handler, uuid=uuid) self._set_handler( self._socket[sub_address], curated_handlers, update=True ) else: self._set_handler( self._socket[alias], curated_handlers, update=True ) def unsubscribe(self, alias: str, topic: Union[bytes, str]) -> None: """ Unsubscribe a SUB/SYNC_SUB socket given by its alias from a given specific topic, and delete its entry from the handlers dictionary. If instead of a single topic, a tuple or a list of topics is passed, the agent will unsubscribe from all the supplied topics. 
""" if isinstance(topic, (tuple, list)): for t in topic: self.unsubscribe(alias, t) return topic = topic_to_bytes(topic) if isinstance(self._address[alias], AgentAddress): self._socket[alias].setsockopt(zmq.UNSUBSCRIBE, topic) del self._handler[self._socket[alias]][topic] elif isinstance(self._address[alias], AgentChannel): channel = self._address[alias] sub_address = channel.receiver treated_topic = channel.twin_uuid + topic self._socket[sub_address].setsockopt( zmq.UNSUBSCRIBE, treated_topic ) del self._handler[self._socket[sub_address]][treated_topic] else: raise NotImplementedError( 'Unsupported address type %s!' % self._address[alias] ) def _subscribe_to_topic(self, alias: str, topic: Union[bytes, str]): """ Do the actual ZeroMQ subscription of a socket given by its alias to a specific topic. This method only makes sense to be called on SUB/SYNC_SUB sockets. Note that the handler is not set within this function. """ topic = topic_to_bytes(topic) if isinstance(self._address[alias], AgentAddress): self._socket[alias].setsockopt(zmq.SUBSCRIBE, topic) elif isinstance(self._address[alias], AgentChannel): channel = self._address[alias] sub_address = channel.receiver treated_topic = channel.uuid + topic self._socket[sub_address].setsockopt(zmq.SUBSCRIBE, treated_topic) else: raise NotImplementedError( 'Unsupported address type %s!' % self._address[alias] ) def idle(self): """ This function is to be executed when the agent is idle. After a timeout occurs when the agent's poller receives no data in any of its sockets, the agent may execute this function. Note ---- The timeout is set by the agent's ``poll_timeout`` attribute. """ pass def set_attr(self, **kwargs): """ Set object attributes. Parameters ---------- kwargs : [name, value] Keyword arguments will be used to set the object attributes. 
""" for name, value in kwargs.items(): setattr(self, name, value) self.log_debug('SET self.%s = %s' % (name, value)) def get_attr(self, name: str): """ Return the specified attribute of the agent. Parameters ---------- name Name of the attribute to be retrieved. """ return getattr(self, name) def set_method(self, *args, **kwargs): """ Set object methods. Parameters ---------- args : [function] New methods will be created for each function, taking the same name as the original function. kwargs : [name, function] New methods will be created for each function, taking the name specified by the parameter. Returns ------- str Name of the registered method in the agent. """ for function in args: method = types.MethodType(function, self) name = method.__name__ setattr(self, name, method) self.log_debug('SET self.%s() = %s' % (name, function)) for name, function in kwargs.items(): method = types.MethodType(function, self) setattr(self, name, method) self.log_debug('SET self.%s() = %s' % (name, function)) return name def execute_as_function(self, function, *args, **kwargs): """ Execute a function passed as parameter. """ return function(*args, **kwargs) def execute_as_method(self, function, *args, **kwargs): """ Execute a function as a method, without adding it to the set of agent methods. """ return function(self, *args, **kwargs) def _loop(self): """ Agent's main loop. This loop is executed until the `_keep_alive` attribute is False or until an error occurs. """ while self._keep_alive: try: if self._iterate(): break except zmq.error.ContextTerminated: self._die_now = True break def _iterate(self): """ Agent's main iteration. This iteration is normally executed inside the main loop. The agent is polling all its sockets for input data. It will wait for `poll_timeout`; after this period, the method `idle` will be executed before polling again. 
Returns ------- int 1 if an error occurred during the iteration (we would expect this to happen if an interruption occurs during polling). 0 otherwise. """ try: events = dict(self._poller.poll(self._poll_timeout)) except zmq.ZMQError as error: # Raise the exception in case it is not due to SIGINT if error.errno == errno.EINTR: return 1 raise if not events: # Agent is idle self.idle() return 0 self._process_events(events) return 0 def _process_sub_message(self, serializer, message): """ Return the received message in a PUBSUB communication. Parameters ---------- message : bytes Received message without any treatment. Note that we do not know whether there is a topic or not. Returns ------- anything The content of the message passed. """ if serializer.requires_separator: sep = message.index(TOPIC_SEPARATOR) + 1 message = memoryview(message)[sep:] return deserialize_message(message=message, serializer=serializer) def _process_events(self, events): """ Process a socket's event. Parameters ---------- events : dict Events to be processed. """ for socket in events: if events[socket] != zmq.POLLIN: continue self._process_single_event(socket) def _process_single_event(self, socket): """ Process a socket's event. Parameters ---------- socket : zmq.Socket Socket that generated the event. """ data = socket.recv() address = self._address[socket] if address.kind == 'SUB': self._process_sub_event(socket, address, data) elif address.kind == 'PULL': self._process_pull_event(socket, address, data) elif address.kind == 'REP': self._process_rep_event(socket, address, data) else: self._process_single_event_complex(address, socket, data) def _process_single_event_complex(self, address, socket, data): """ Process a socket's event for complex sockets (channels). Parameters ---------- address : AgentAddress or AgentChannel Agent address or channel associated to the socket. socket : zmq.Socket Socket that generated the event. data Received in the socket. 
        """
        if address.kind == 'ASYNC_REP':
            self._process_async_rep_event(socket, address, data)
        elif address.kind == 'PULL_SYNC_PUB':
            self._process_sync_pub_event(socket, address.channel, data)
        else:
            raise NotImplementedError('Unsupported kind %s!' % address.kind)

    def _process_rep_event(self, socket, addr, data):
        """
        Process a REP socket's event.

        Parameters
        ----------
        socket : zmq.Socket
            Socket that generated the event.
        addr : AgentAddress
            AgentAddress associated with the socket that generated the
            event.
        data : bytes
            Data received on the socket.
        """
        message = deserialize_message(message=data, serializer=addr.serializer)
        handler = self._handler[socket]
        if inspect.isgeneratorfunction(handler):
            # A generator handler yields the reply first; the code after
            # the yield runs only once the reply has been sent.
            generator = handler(self, message)
            socket.send(serialize_message(next(generator), addr.serializer))
            execute_code_after_yield(generator)
        else:
            reply = handler(self, message)
            socket.send(serialize_message(reply, addr.serializer))

    def _process_async_rep_event(self, socket, channel, data):
        """
        Process a ASYNC_REP socket's event.

        Parameters
        ----------
        socket : zmq.Socket
            Socket that generated the event.
        channel : AgentChannel
            AgentChannel associated with the socket that generated the
            event.
        data : bytes
            Data received on the socket.
        """
        message = deserialize_message(
            message=data, serializer=channel.serializer
        )
        address_uuid, request_uuid, data, address = message
        # Make sure we are connected back to the requester before the
        # reply is pushed to it.
        client_address = address.twin()
        if not self._registered(client_address):
            self.connect(address)
        handler = self._handler[socket]
        is_generator = inspect.isgeneratorfunction(handler)
        if is_generator:
            generator = handler(self, data)
            reply = next(generator)
        else:
            reply = handler(self, data)
        self.send(client_address, (address_uuid, request_uuid, reply))
        # Run the post-yield part of a generator handler only after the
        # reply has been sent back.
        if is_generator:
            execute_code_after_yield(generator)

    def _process_sync_pub_event(self, socket, channel, data):
        """
        Process a SYNC_PUB socket's event.

        Parameters
        ----------
        socket : zmq.Socket
            Socket that generated the event.
        channel : AgentChannel
            AgentChannel associated with the socket that generated the
            event.
        data : bytes
            Data received on the socket.
        """
        message = deserialize_message(
            message=data, serializer=channel.serializer
        )
        address_uuid, request_uuid, data = message
        handler = self._handler[socket]
        is_generator = inspect.isgeneratorfunction(handler)
        if is_generator:
            generator = handler(self, data)
            reply = next(generator)
        else:
            reply = handler(self, data)
        message = (address_uuid, request_uuid, reply)
        # Publish the reply addressed only to the requester: its UUID is
        # used as the topic and the channel UUID prefix is skipped.
        self._send_channel_sync_pub(
            channel=channel, message=message, topic=address_uuid, general=False
        )
        if is_generator:
            execute_code_after_yield(generator)

    def _process_pull_event(self, socket, addr, data):
        """
        Process a PULL socket's event.

        Parameters
        ----------
        socket : zmq.Socket
            Socket that generated the event.
        addr : AgentAddress
            AgentAddress associated with the socket that generated the
            event.
        data : bytes
            Data received on the socket.
        """
        message = deserialize_message(message=data, serializer=addr.serializer)
        handler = self._handler[socket]
        # Either a single handler or a collection of them may have been
        # registered for this socket; normalize to an iterable.
        if not isinstance(handler, (list, dict, tuple)):
            handler = [handler]
        for h in handler:
            h(self, message)

    def _process_sub_event(self, socket, addr, data):
        """
        Process a SUB socket's event.

        Parameters
        ----------
        socket : zmq.Socket
            Socket that generated the event.
        addr : AgentAddress
            AgentAddress associated with the socket that generated the
            event.
        data : bytes
            Data received on the socket.
        """
        handlers = self._handler[socket]
        message = self._process_sub_message(addr.serializer, data)
        for topic in handlers:
            # Topic matching is done against the raw (still serialized)
            # data, which starts with the topic bytes.
            if not data.startswith(topic):
                continue
            # Call the handler (with or without the topic)
            handler = handlers[topic]
            nparams = len(inspect.signature(handler).parameters)
            # NOTE(review): handlers with a parameter count other than 2
            # or 3 are silently ignored here.
            if nparams == 2:
                handler(self, message)
            elif nparams == 3:
                handler(self, message, topic)

    def send(
        self,
        address,
        message,
        topic=None,
        handler=None,
        wait=None,
        on_error=None,
    ):
        """
        Send a message through the specified address.
        Note that replies in a REQREP pattern do not use this function in
        order to be sent.

        Parameters
        ----------
        address : AgentAddress or AgentChannel
            The address to send the message through.
        message
            The message to be sent.
        topic : str
            The topic, in case it is relevant (i.e.: for PUB sockets).
        handler : function, method or string
            Code that will be executed on input messages if relevant
            (i.e.: for asynchronous requests in channels).
        wait : float
            For channel requests, wait at most this number of seconds for
            a response from the server.
        on_error : function, method or string
            Code to be executed if ``wait`` is passed and the response is
            not received.
        """
        # Resolve the alias to the actual address/channel object first.
        address = self._address[address]
        if isinstance(address, AgentChannel):
            return self._send_channel(
                channel=address,
                message=message,
                topic=topic,
                handler=handler,
                wait=wait,
                on_error=on_error,
            )
        if isinstance(address, AgentAddress):
            return self._send_address(
                address=address, message=message, topic=topic
            )
        raise NotImplementedError('Unsupported address type %s!' % address)

    def _send_address(self, address, message, topic=None):
        """
        Send a message through a specific address.
        """
        message = serialize_message(
            message=message, serializer=address.serializer
        )
        if address.kind == 'PUB':
            # PUB messages always carry a topic (possibly empty).
            if topic is None:
                topic = ''
            topic = topic_to_bytes(topic)
            message = compose_message(
                message=message, topic=topic, serializer=address.serializer
            )
        self._socket[address].send(message)

    def _send_channel(self, channel, message, topic, handler, wait, on_error):
        """
        Send a message through a specific channel.
        """
        kind = channel.kind
        if kind == 'ASYNC_REP':
            return self._send_channel_async_rep(
                channel=channel,
                message=message,
                wait=wait,
                on_error=on_error,
                handler=handler,
            )
        if kind == 'SYNC_PUB':
            return self._send_channel_sync_pub(
                channel=channel, message=message, topic=topic
            )
        if kind == 'SYNC_SUB':
            return self._send_channel_sync_sub(
                channel, message, topic, handler, wait, on_error
            )
        raise NotImplementedError('Unsupported channel kind %s!'
                                  % kind)

    def _send_channel_async_rep(
        self, channel, message, wait, on_error, handler=None
    ):
        """
        Send a message through an ASYNC_REP channel.
        """
        address = channel.receiver
        address_uuid = self._async_req_uuid[address]
        request_uuid = unique_identifier()
        # Remember which handler must process this request's asynchronous
        # reply: the explicit one, or the channel's default handler.
        if handler is not None:
            self._pending_requests[request_uuid] = handler
        else:
            self._pending_requests[request_uuid] = self._async_req_handler[
                address_uuid
            ]
        receiver_address = self._address[address_uuid]
        message = (address_uuid, request_uuid, message, receiver_address)
        message = serialize_message(
            message=message, serializer=channel.serializer
        )
        self._socket[channel].send(message)
        # Optionally schedule a timeout check for the reply.
        self._wait_received(wait, uuid=request_uuid, on_error=on_error)

    def _send_channel_sync_pub(
        self, channel, message, topic=None, general=True
    ):
        """
        Send a message through a SYNC_PUB channel.
        """
        message = serialize_message(
            message=message, serializer=channel.serializer
        )
        if topic is None:
            topic = ''
        topic = topic_to_bytes(topic)
        if general:
            # General publications are prefixed with the channel UUID;
            # direct replies are already topic-addressed to a requester.
            topic = channel.uuid + topic
        message = compose_message(
            message=message, topic=topic, serializer=channel.serializer
        )
        self._socket[channel].send(message)

    def _send_channel_sync_sub(
        self, channel, message, topic, handler, wait, on_error
    ):
        """
        Send a message through a SYNC_SUB channel.
        """
        address = channel.receiver
        address_uuid = self._async_req_uuid[address]
        request_uuid = unique_identifier()
        # A handler is mandatory: it will process the server's reply.
        if handler is None:
            raise ValueError('No handler for SYNC_PUB request')
        self._pending_requests[request_uuid] = handler
        message = (address_uuid, request_uuid, message)
        self._send_address(channel.sender, message)
        self._wait_received(wait, uuid=request_uuid, on_error=on_error)
        return

    def _check_received(self, uuid, wait, on_error):
        """
        Check if the requested information has been received.

        Parameters
        ----------
        uuid : str
            Request identifier.
        wait : float
            The total number of seconds since the request was made.
        on_error : function, method or string
            Code to be executed in case a response was not received for
            the request in time. If not provided, it will simply log a
            warning.
        """
        if uuid not in self._pending_requests:
            # The response already arrived (the entry was consumed).
            return
        del self._pending_requests[uuid]
        if not on_error:
            warning = 'Did not receive request {} after {} seconds'.format(
                uuid, wait
            )
            self.log_warning(warning)
            return
        on_error(self)

    def _wait_received(self, wait, uuid, on_error):
        """
        Set up a timer to check a response was received for a given
        request after a defined time lapse.

        Parameters
        ----------
        uuid : str
            Request identifier.
        wait : float
            The total number of seconds to wait for the response.
        on_error : function, method or string
            Code to be executed in case a response was not received for
            the request in time. If not provided, it will simply log a
            warning.
        """
        # No timeout requested: nothing to schedule.
        if not wait:
            return
        return self.after(wait, '_check_received', uuid, wait, on_error)

    def recv(self, address):
        """
        Receive a message from the specified address.

        This method is only used in REQREP communication patterns.

        Parameters
        ----------
        address :
            Alias of the address to receive from (used to look up both
            the socket and its serializer).

        Returns
        -------
        anything
            The content received in the address.
        """
        message = self._socket[address].recv()
        serializer = self._address[address].serializer
        return deserialize_message(message=message, serializer=serializer)

    def send_recv(self, address, message):
        """
        Send a message and return the reply received on the same address.

        This method is only used in REQREP communication patterns.
        """
        self.send(address, message)
        return self.recv(address)

    def is_running(self):
        """
        Returns a boolean indicating whether the agent is running or not.
        """
        return self._running

    @Pyro4.oneway
    def run(self):
        """
        Start the main loop.
        """
        # A loopback socket where, for example, timers are processed
        self.bind(
            'REP',
            alias='loopback',
            addr='loopback',
            handler=self._handle_loopback,
            transport='inproc',
            serializer='pickle',
        )
        # This in-process socket handles safe access to
        # memory from other threads (i.e. when using Pyro proxies).
        self.bind(
            'REP',
            alias='_loopback_safe',
            addr='_loopback_safe',
            handler=self._handle_loopback_safe,
            transport='inproc',
            serializer='pickle',
        )
        self._running = True
        self.before_loop()
        try:
            self._loop()
        except Exception as error:
            # Log the full traceback before re-raising to the caller.
            self._running = False
            msg = 'An exception occurred while running! (%s)\n' % error
            msg += format_exception()
            self.log_error(msg)
            raise
        self._running = False
        if self._die_now:
            self._die()

    def shutdown(self):
        """
        Cleanly stop and shut down the agent assuming the agent is
        running. Will let the main thread do the tear down.
        """
        self.log_debug('Shutting down...')
        self._keep_alive = False
        self._die_now = True

    def kill(self):
        """
        Force shutdown of the agent.

        If the agent is running the ZMQ context is terminated to allow
        the main thread to quit and do the tear down.
        """
        self.stop_all_timers()
        if self._running:
            # Terminating the context makes the main loop raise
            # ContextTerminated, which triggers the tear down there.
            self._context.term()
        else:
            self._die(linger=0)

    def _die(self, linger=None):
        """
        Tear down the agent. Last action before ending existence.

        - Stop timers.
        - Close all external sockets.
        - Shutdown the Pyro daemon.
        """
        self.stop_all_timers()
        self.close_all(linger=linger)
        self._pyroDaemon.shutdown()

    def _get_unique_external_zmq_sockets(self):
        """
        Return an iterable containing all the zmq.Socket objects from
        `self.socket` which are not internal, without repetition.

        Originally, a socket was internal if its alias was one of the
        following:

        - loopback
        - _loopback_safe
        - inproc://loopback
        - inproc://_loopback_safe

        However, since we are storing more than one entry in the
        `self.socket` dictionary per zmq.socket (by storing its
        AgentAddress, for example), we need a way to simply get all
        non-internal zmq.socket objects, and this is precisely what this
        function does.
        """
        reserved = (
            'loopback',
            '_loopback_safe',
            'inproc://loopback',
            'inproc://_loopback_safe',
        )
        external_sockets = []
        for k, v in self._socket.items():
            # Skip entries keyed by the zmq.Socket object itself and any
            # entry whose key (alias) or address is reserved (internal).
            if isinstance(k, zmq.sugar.socket.Socket):
                continue
            if isinstance(k, AgentAddress) and k.address in reserved:
                continue
            if k in reserved:
                continue
            external_sockets.append(v)
        # Deduplicate: several keys may point to the same socket.
        return set(external_sockets)

    def has_socket(self, alias):
        """
        Return whether the agent has the passed socket internally stored.
        """
        return alias in self._socket

    def _delete_socket_entries(self, entries):
        # Remove every given key from the socket registry.
        for entry in entries:
            del self._socket[entry]

    def _cleanup_ipc_socket_files(self, address):
        """
        Make sure no IPC socket files are left in the file system.
        """
        if isinstance(address, AgentChannel):
            candidates = [address.receiver, address.sender]
        else:
            candidates = [address]
        for candidate in candidates:
            # Only the server side owns (and must remove) the IPC file.
            if candidate is None or candidate.role != 'server':
                continue
            with contextlib.suppress(FileNotFoundError):
                candidate.address.unlink()

    def _close_socket(self, socket, linger):
        """
        Close a socket using the provided linger value.
        """
        # Unregister from the poller first, if it was registered.
        if any(reg[0] == socket for reg in self._poller.sockets):
            self._poller.unregister(socket)
        linger = get_linger(linger)
        socket.close(linger=linger)
        address = self._address[socket]
        if address.transport == 'ipc':
            self._cleanup_ipc_socket_files(address)

    def close(self, alias, linger=None):
        """
        Close a socket given its alias and clear its entry from the
        `Agent._socket` dictionary.
        """
        socket = self._socket[alias]
        # Each socket might be pointed by different keys
        entries_to_delete = []
        for k, v in self._socket.items():
            if v == socket:
                entries_to_delete.append(k)
        self._delete_socket_entries(entries_to_delete)
        self._close_socket(socket, linger=linger)

    def close_all(self, linger=None):
        """
        Close all non-internal zmq sockets.
        """
        # Each socket might be pointed by different keys
        sockets_to_delete = []
        for socket in self._get_unique_external_zmq_sockets():
            sockets_to_delete.append(socket)
            self._close_socket(socket, linger=linger)
        entries_to_delete = []
        for k, v in self._socket.items():
            if v in sockets_to_delete:
                entries_to_delete.append(k)
        self._delete_socket_entries(entries_to_delete)

    def ping(self):
        """
        A test method to check the readiness of the agent. Used for
        testing purposes, where timing is very important. Do not remove.
        """
        return 'pong'


class AgentProcess(multiprocessing.Process):
    """
    Agent class. Instances of an Agent are system processes which can be
    run independently.
    """

    def __init__(
        self,
        name='',
        nsaddr=None,
        addr=None,
        serializer=None,
        transport=None,
        base=Agent,
        attributes=None,
    ):
        super().__init__()
        self.name = name
        self._daemon = None
        self._host, self.port = address_to_host_port(addr)
        # A port of None means "bind to a random port".
        if self.port is None:
            self.port = 0
        self.nsaddr = nsaddr
        self._serializer = serializer
        self._transport = transport
        # The base class is pickled here so it can safely cross the
        # process boundary; it is unpickled again inside `run()`.
        self.base = cloudpickle.dumps(base)
        self._shutdown_event = multiprocessing.Event()
        self._queue = multiprocessing.Queue()
        self._sigint = False
        self.attributes = attributes

    def run(self):
        """
        Begin execution of the agent process and start the main loop.
        """
        # Capture SIGINT
        signal.signal(signal.SIGINT, self._sigint_handler)
        try:
            ns = NSProxy(self.nsaddr)
            self._daemon = Pyro4.Daemon(self._host, self.port)
            self.base = cloudpickle.loads(self.base)
            self.agent = self.base(
                name=self.name,
                host=self._host,
                serializer=self._serializer,
                transport=self._transport,
                attributes=self.attributes,
            )
        except Exception:
            # Report the failure (traceback) back to the parent process
            # through the queue so `start()` can raise.
            self._queue.put(format_exception())
            return
        self.name = self.agent.name
        uri = self._daemon.register(self.agent)
        try:
            ns.register(self.name, uri, safe=True)
        except Pyro4.errors.NamingError:
            self._queue.put(format_exception())
            return
        finally:
            ns.release()
        self._queue.put('STARTED:' + self.name)
        # Serve Pyro requests until a shutdown is signaled.
        self._daemon.requestLoop(lambda: not self._shutdown_event.is_set())
        self._daemon.unregister(self.agent)
        self._teardown()

    def _remove_from_nameserver(self):
        """
        Make sure to remove the agent's name from the name server.
        """
        while True:
            try:
                ns = NSProxy(self.nsaddr)
                ns.remove(self.name)
                ns.release()
            except PyroError:
                # The name server might be temporarily unreachable;
                # retry until the removal succeeds.
                time.sleep(0.1)
                continue
            break

    def _teardown(self):
        """
        Remove self from the name server address book, close daemon and
        die.
        """
        if not self._sigint:
            # Clean teardown
            self._remove_from_nameserver()
        self.agent._killed = True
        self._daemon.close()

    def start(self):
        """
        Start the system process.

        Raises
        ------
        RuntimeError
            If an error occurred when initializing the daemon.

        Returns
        -------
        str
            The final name of the started agent.
        """
        super().start()
        # Block until the child reports either success or a traceback.
        status = self._queue.get()
        if not status.startswith('STARTED'):
            raise RuntimeError(
                'An error occurred while creating the daemon!'
                + '\n===============\n'.join(['', status, ''])
            )
        return status.partition(':')[-1]

    def kill(self):
        """
        Force kill the agent process.
        """
        self._shutdown_event.set()
        if self._daemon:
            self._daemon.shutdown()

    def _sigint_handler(self, signal, frame):
        """
        Handle interruption signals.
        """
        self._sigint = True
        self.kill()


def run_agent(
    name='',
    nsaddr=None,
    addr=None,
    base=Agent,
    serializer=None,
    transport=None,
    safe=None,
    attributes=None,
):
    """
    Ease the agent creation process.
    This function will create a new agent, start the process and then
    run its main loop through a proxy.

    Parameters
    ----------
    name : str, default is ''
        Agent name or alias.
    nsaddr : SocketAddress, default is None
        Name server address.
    addr : SocketAddress, default is None
        New agent address, if it is to be fixed.
    base : type, default is Agent
        Class to instantiate the new agent from.
    serializer : str, default is None
        Default serializer for the agent's sockets.
    transport : str, AgentAddressTransport, default is None
        Transport protocol.
    safe : bool, default is None
        Use safe calls by default from the Proxy.
    attributes : dict, default is None
        A dictionary that defines initial attributes for the agent.

    Returns
    -------
    proxy
        A proxy to the new agent.
    """
    if not nsaddr:
        # Fall back to the environment-configured name server address.
        nsaddr = os.environ.get('OSBRAIN_NAMESERVER_ADDRESS')
    agent = AgentProcess(
        name=name,
        nsaddr=nsaddr,
        addr=addr,
        base=base,
        serializer=serializer,
        transport=transport,
        attributes=attributes,
    )
    agent_name = agent.start()
    proxy = Proxy(agent_name, nsaddr, safe=safe)
    # Start the agent's main loop (oneway call) and wait until it is
    # actually running before handing the proxy to the caller.
    proxy.run()
    proxy.wait_for_running()
    return proxy