""" Implementation of proxy-related features. """ import os import sys import time import Pyro4 from Pyro4.errors import ConnectionClosedError from Pyro4.errors import NamingError from Pyro4.message import FLAGS_ONEWAY from . import config from .address import SocketAddress from .address import address_to_host_port def locate_ns(nsaddr, timeout=3.0): """ Locate a name server to ensure it actually exists. Parameters ---------- nsaddr : SocketAddress The address where the name server should be up and running. timeout : float Timeout in seconds before aborting location. Returns ------- nsaddr The address where the name server was located. Raises ------ NamingError If the name server could not be located. """ host, port = address_to_host_port(nsaddr) time0 = time.time() while True: try: Pyro4.locateNS(host, port) return nsaddr except NamingError: if time.time() - time0 < timeout: time.sleep(0.1) continue raise TimeoutError('Could not locate the name server!') class Proxy(Pyro4.core.Proxy): """ A proxy to access remote agents. Parameters ---------- name : str Proxy name, as registered in the name server. nsaddr : SocketAddress, str Name server address. timeout : float Timeout, in seconds, to wait until the agent is discovered. safe : bool, default is None Use safe calls by default. When not set, osbrain default's :py:data:`osbrain.config['SAFE']` is used. """ def __init__(self, name, nsaddr=None, timeout=3.0, safe=None): if not nsaddr: nsaddr = os.environ.get('OSBRAIN_NAMESERVER_ADDRESS') nshost, nsport = address_to_host_port(nsaddr) # Make sure name server exists locate_ns(nsaddr) time0 = time.time() super().__init__('PYRONAME:%s@%s:%s' % (name, nshost, nsport)) if safe is not None: self._default_safe = safe else: self._default_safe = config['SAFE'] self._safe = self._default_safe self._next_oneway = False while not self._ready_or_timeout(time0, timeout): continue def _ready_or_timeout(self, time0, timeout): """ Check if the proxy is ready or raise after a timeout. Parameters ---------- time0 : float Timestamp (in seconds) to take as the initial time. timeout : float Time (in seconds) allowed after `time0` before raising an exception. """ try: self.unsafe.ping() except Exception: time.sleep(0.1) if time.time() - time0 < timeout: return False raise return True def wait_for_running(self, timeout=3.0): """ Wait until the agent is running. Parameters ---------- timeout : float Raise and exception if the agent is not running after this number of seconds. Use a negative value to wait forever. Raises ------ TimeoutError If the agent is not running after the given timeout. Returns ------- Proxy The object itself. """ time0 = time.time() while not self.is_running(): if timeout >= 0 and time.time() - time0 > timeout: msg = 'Timed out while waiting for the agent to be running' raise TimeoutError(msg) time.sleep(0.01) return self def __getstate__(self): return super().__getstate__() + ( self._next_oneway, self._default_safe, self._safe, ) def __setstate__(self, state): super().__setstate__(state[:-3]) self._next_oneway = state[-3] self._default_safe = state[-2] self._safe = state[-1] def __setattr__(self, name, value): if name in ('_safe', '_default_safe', '_next_oneway'): return super(Pyro4.core.Proxy, self).__setattr__(name, value) if name.startswith('_'): return super().__setattr__(name, value) kwargs = {name: value} return self.set_attr(**kwargs) def __getattr__(self, name): if name in self._pyroAttrs: return self.get_attr(name) return super().__getattr__(name) def release(self): """ Release the connection to the Pyro daemon. """ self._pyroRelease() def nsaddr(self): """ Get the socket address of the name server. Returns ------- SocketAddress The socket address. """ return SocketAddress(self._pyroUri.host, self._pyroUri.port) @property def safe(self): """ Make the next remote method call be safe. Returns ------- The proxy itself. """ self._safe = True return self @property def unsafe(self): """ Make the next remote method call be unsafe. Returns ------- The proxy itself. """ self._safe = False return self @property def oneway(self): """ Make the next remote method call be one way. Returns ------- The proxy itself. """ self._next_oneway = True return self def _pyroInvoke( # noqa: N802 self, methodname, args, kwargs, flags=0, objectId=None # noqa: N803 ): """ Wrapper around `_remote_call` to safely execute methods on remote objects. """ try: result = self._remote_call( methodname, args, kwargs, flags, objectId ) except Exception: sys.stdout.write(''.join(Pyro4.util.getPyroTraceback())) sys.stdout.flush() raise finally: self._safe = self._default_safe self._next_oneway = False self._post_invoke(methodname, args, kwargs) return result def _is_safe_method(self, methodname): """ Check if a remote method can be called safely. Parameters ---------- methodname : str The name of the method to evaluate. Returns ------- bool Whether the method can be safely called. """ return ( methodname in self._pyroMethods and not methodname.startswith('_') and methodname not in ( 'run', 'get_attr', 'kill', 'safe_call', 'concurrent', 'is_running', ) ) def _remote_call( self, methodname, args, kwargs, flags, objectId # noqa: N803 ): """ Call a remote method from the proxy. """ if self._next_oneway: flags |= FLAGS_ONEWAY result = super()._pyroInvoke( methodname, args, kwargs, flags=flags, objectId=objectId ) return result if self._safe and self._is_safe_method(methodname): safe_args = [methodname] + list(args) result = super()._pyroInvoke( 'safe_call', safe_args, kwargs, flags=flags, objectId=objectId ) if isinstance(result, Exception): raise result else: result = super()._pyroInvoke( methodname, args, kwargs, flags=flags, objectId=objectId ) return result def _post_invoke(self, methodname, args, kwargs): """ After invoking a call, check if the proxy must be modified. This could happen if the `set_method` or `set_attr` have been invoked. In that case, the method(s) or attribute(s) are added to the proxy's available method(s)/attributes(s). """ if methodname == 'set_method': self._set_new_available_methods(args, kwargs) elif methodname == 'set_attr': self._set_new_available_attributes(kwargs) def _set_new_available_methods(self, args, kwargs): """ Set new methods available from the proxy. Parameters ---------- args : list A list of new methods to be made available from the proxy. kwargs : dict A dictionary with the methods' names and their values. """ for method in args: self._pyroMethods.add(method.__name__) for name, _ in kwargs.items(): self._pyroMethods.add(name) def _set_new_available_attributes(self, kwargs): """ Set new attributes available from the proxy. Parameters ---------- kwargs : dict A dictionary with the attributes' names and their values. """ for name in kwargs: self._pyroAttrs.add(name) class NSProxy(Pyro4.core.Proxy): """ A proxy to access a name server. Parameters ---------- nsaddr : SocketAddress, str Name server address. timeout : float Timeout, in seconds, to wait until the name server is discovered. """ def __init__(self, nsaddr=None, timeout=3): if not nsaddr: nsaddr = os.environ.get('OSBRAIN_NAMESERVER_ADDRESS') nshost, nsport = address_to_host_port(nsaddr) # Make sure name server exists locate_ns(nsaddr, timeout) ns_name = Pyro4.constants.NAMESERVER_NAME super().__init__('PYRONAME:%s@%s:%d' % (ns_name, nshost, nsport)) def release(self): """ Release the connection to the Pyro daemon. """ self._pyroRelease() def proxy(self, name, timeout=3.0): """ Get a proxy to access an agent registered in the name server. Parameters ---------- name : str Proxy name, as registered in the name server. timeout : float Timeout, in seconds, to wait until the agent is discovered. Returns ------- Proxy A proxy to access an agent registered in the name server. """ return Proxy(name, nsaddr=self.addr(), timeout=timeout) def addr(self, agent_alias=None, address_alias=None): """ Return the name server address or the address of an agent's socket. Parameters ---------- agent_alias : str, default is None The alias of the agent to retrieve its socket address. address_alias : str, default is None The alias of the socket address to retrieve from the agent. Returns ------- SocketAddress or AgentAddress The name server or agent's socket address. """ if not agent_alias and not address_alias: return SocketAddress(self._pyroUri.host, self._pyroUri.port) agent = self.proxy(agent_alias) addr = agent.addr(address_alias) agent.release() return addr def shutdown_agents(self, timeout=10.0): """ Shutdown all agents registered in the name server. Parameters ---------- timeout : float, default is 10. Timeout, in seconds, to wait for the agents to shutdown. """ # Wait for all agents to be shutdown (unregistered) time0 = time.time() super()._pyroInvoke('async_shutdown_agents', (self.addr(),), {}) while time.time() - time0 <= timeout / 2.0: if not len(self.agents()): return time.sleep(0.1) super()._pyroInvoke('async_kill_agents', (self.addr(),), {}) while time.time() - time0 <= timeout: if not len(self.agents()): return time.sleep(0.1) raise TimeoutError( 'Chances are {} were not shutdown after {} s!'.format( self.agents(), timeout ) ) def shutdown(self, timeout=10.0): """ Shutdown the name server. All agents will be shutdown as well. Parameters ---------- timeout : float, default is 10. Timeout, in seconds, to wait for the agents to shutdown. """ self.shutdown_agents(timeout) try: super()._pyroInvoke('daemon_shutdown', (), {}, flags=0) except ConnectionClosedError: pass