'''
Created on May 29, 2015

@author: mzwier
'''

from pickle import UnpicklingError

# Every ten seconds the master requests a status report from workers.
# This also notifies workers that the master is still alive.
DEFAULT_STATUS_POLL = 10

# If we haven't heard from the master or a worker (as appropriate) in these
# amounts of time, we assume a crash and shut down.
MASTER_CRASH_TIMEOUT = DEFAULT_STATUS_POLL * 6
WORKER_CRASH_TIMEOUT = DEFAULT_STATUS_POLL * 3

import logging
log = logging.getLogger(__name__)

#import gevent

import sys, uuid, socket, os, tempfile, errno, time, threading, contextlib, traceback, multiprocessing, json, re
from collections import OrderedDict

import signal
signames = {val: name for name, val in reversed(sorted(signal.__dict__.items()))
            if name.startswith('SIG') and not name.startswith('SIG_')}

import zmq
import numpy

DEFAULT_LINGER = 1


def randport(address='127.0.0.1'):
    '''Select a random unused TCP port number on the given address.'''
    s = socket.socket()
    s.bind((address, 0))
    try:
        port = s.getsockname()[1]
    finally:
        s.close()
    return port


class ZMQWMError(RuntimeError):
    '''Base class for errors related to the ZeroMQ work manager itself'''
    pass


class ZMQWorkerMissing(ZMQWMError):
    '''Exception representing that a worker processing a task died or disappeared'''
    pass


class ZMQWMEnvironmentError(ZMQWMError):
    '''Class representing an error in the environment in which the ZeroMQ
    work manager is running. This includes such things as master/worker ID
    mismatches.'''


class ZMQWMTimeout(ZMQWMEnvironmentError):
    '''A timeout of a sort that indicates that a master or worker has failed
    or never started.'''


class Message:
    SHUTDOWN = 'shutdown'
    ACK = 'ok'
    NAK = 'no'
    IDENTIFY = 'identify'  # Two-way identification (a reply must be an IDENTIFY message)
    TASKS_AVAILABLE = 'tasks_available'
    TASK_REQUEST = 'task_request'

    MASTER_BEACON = 'master_alive'
    RECONFIGURE_TIMEOUT = 'reconfigure_timeout'

    TASK = 'task'
    RESULT = 'result'

    idempotent_announcement_messages = {SHUTDOWN, TASKS_AVAILABLE, MASTER_BEACON}

    def __init__(self, message=None, payload=None, master_id=None, src_id=None):
        if isinstance(message, Message):
            self.message = message.message
            self.payload = message.payload
            self.master_id = message.master_id
            self.src_id = message.src_id
        else:
            self.master_id = master_id
            self.src_id = src_id
            self.message = message
            self.payload = payload

    def __repr__(self):
        return ('<{!s} master_id={master_id!s} src_id={src_id!s} message={message!r} payload={payload!r}>'
                .format(self.__class__.__name__, **self.__dict__))

    @classmethod
    def coalesce_announcements(cls, messages):
        '''Coalesce a list of announcement messages, dropping repeats of
        idempotent announcements and of identical (message, payload) pairs,
        preserving arrival order.'''
        d = OrderedDict()
        for msg in messages:
            if msg.message in cls.idempotent_announcement_messages:
                key = msg.message
            else:
                key = (msg.message, msg.payload)
            d[key] = msg
        coalesced = list(d.values())
        log.debug('coalesced {} announcements into {}'.format(len(messages), len(coalesced)))
        return coalesced


TIMEOUT_MASTER_BEACON = 'master_beacon'
TIMEOUT_WORKER_CONTACT = 'worker_contact'


class Task:
    def __init__(self, fn, args, kwargs, task_id=None):
        self.task_id = task_id or uuid.uuid4()
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __repr__(self):
        try:
            return '<{} {task_id!s} {fn!r} {:d} args {:d} kwargs>'\
                   .format(self.__class__.__name__, len(self.args), len(self.kwargs), **self.__dict__)
        except TypeError:
            # args/kwargs have no length
            return '<{} {task_id!s} {fn!r}>'.format(self.__class__.__name__, **self.__dict__)

    def __hash__(self):
        return hash(self.task_id)

    def execute(self):
        '''Run this task, returning a Result object.'''
        rsl = Result(task_id=self.task_id)
        try:
            rsl.result = self.fn(*self.args, **self.kwargs)
        except BaseException as e:
            rsl.exception = e
            rsl.traceback = traceback.format_exc()
        return rsl


class Result:
    def __init__(self, task_id, result=None, exception=None, traceback=None):
        self.task_id = task_id
        self.result = result
        self.exception = exception
        self.traceback = traceback

    def __repr__(self):
        return '<{} {task_id!s} ({})>'\
               .format(self.__class__.__name__, 'result' if self.exception is None else 'exception', **self.__dict__)

    def __hash__(self):
        return hash(self.task_id)
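

# Illustrative sketch (not part of the work manager API): how a Task wraps a
# callable and how exceptions surface in the Result it produces. The helper
# name and the sample callable below are hypothetical.
def _example_task_roundtrip():
    task = Task(divmod, (7, 2), {})
    result = task.execute()
    assert result.task_id == task.task_id
    assert result.result == (3, 1) and result.exception is None

    failing = Task(divmod, (7, 0), {})
    failed = failing.execute()
    # execute() never raises; the exception and its formatted traceback are
    # captured on the Result instead.
    assert isinstance(failed.exception, ZeroDivisionError)
    assert failed.traceback is not None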


class PassiveTimer:
    __slots__ = ('started', 'duration')

    def __init__(self, duration, started=None):
        if started is None:
            started = time.time()
        self.started = started
        self.duration = duration

    @property
    def expired(self):
        return (time.time() - self.started) > self.duration

    @property
    def expires_in(self):
        return self.started + self.duration - time.time()

    def reset(self, at=None):
        self.started = at or time.time()

    start = reset


class PassiveMultiTimer:
    '''Track a set of passive timers in parallel arrays, keyed by arbitrary
    identifiers.'''

    def __init__(self):
        self._identifiers = numpy.empty((0,), numpy.object_)
        self._durations = numpy.empty((0,), float)
        self._started = numpy.empty((0,), float)
        self._indices = {}  # indices into durations/started, keyed by identifier

    def add_timer(self, identifier, duration):
        if identifier in self._indices:
            raise KeyError('timer {!r} already present'.format(identifier))
        new_idx = len(self._identifiers)
        self._durations.resize((new_idx + 1,))
        self._started.resize((new_idx + 1,))
        self._identifiers.resize((new_idx + 1,))
        self._durations[new_idx] = duration
        self._started[new_idx] = time.time()
        self._identifiers[new_idx] = identifier
        self._indices[identifier] = new_idx

    def remove_timer(self, identifier):
        idx = self._indices.pop(identifier)
        self._durations = numpy.delete(self._durations, idx)
        self._started = numpy.delete(self._started, idx)
        self._identifiers = numpy.delete(self._identifiers, idx)
        # entries stored after the removed timer shift down by one
        for key, i in self._indices.items():
            if i > idx:
                self._indices[key] = i - 1

    def change_duration(self, identifier, duration):
        idx = self._indices[identifier]
        self._durations[idx] = duration

    def reset(self, identifier=None, at=None):
        at = at or time.time()
        if identifier is None:
            # reset all timers
            self._started.fill(at)
        else:
            self._started[self._indices[identifier]] = at

    def expired(self, identifier, at=None):
        at = at or time.time()
        idx = self._indices[identifier]
        return (at - self._started[idx]) > self._durations[idx]

    def next_expiration(self):
        at = time.time()
        idx = (self._started + self._durations - at).argmin()
        return self._identifiers[idx]

    def next_expiration_in(self):
        at = time.time()
        idx = (self._started + self._durations - at).argmin()
        next_at = self._started[idx] + self._durations[idx] - at
        return next_at if next_at > 0 else 0

    def which_expired(self, at=None):
        at = at or time.time()
        expired_indices = (at - self._started) > self._durations
        return self._identifiers[expired_indices]
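

# Illustrative sketch (not part of the work manager API): one way a communication
# loop might use PassiveMultiTimer to decide how long to poll and which periodic
# duties are due. The identifiers and the ``poll_for_messages`` callable are
# hypothetical; a real loop would reset the beacon timer whenever the master is
# heard from.
def _example_multitimer_loop(poll_for_messages):
    timers = PassiveMultiTimer()
    timers.add_timer(TIMEOUT_MASTER_BEACON, MASTER_CRASH_TIMEOUT)
    timers.add_timer('status_poll', DEFAULT_STATUS_POLL)

    while True:
        # block at most until the next timer is due
        poll_for_messages(timeout=timers.next_expiration_in())
        for expired in timers.which_expired():
            if expired == TIMEOUT_MASTER_BEACON:
                raise ZMQWMTimeout('no beacon from master; assuming it is gone')
            elif expired == 'status_poll':
                timers.reset('status_poll')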


class ZMQCore:
    # The overall communication topology (socket layout, etc.)
    # Cannot be updated without updating configuration files, command-line parameters,
    # etc. (Changes break user scripts.)
    PROTOCOL_MAJOR = 3

    # The set of messages and replies in use.
    # Cannot be updated without changing existing communications logic. (Changes break
    # the ZMQ WM library.)
    PROTOCOL_MINOR = 0

    # Minor updates and additions to the protocol.
    # Changes do not break the ZMQ WM library, but only add new
    # functionality/code paths without changing existing code paths.
    PROTOCOL_UPDATE = 0

    PROTOCOL_VERSION = (PROTOCOL_MAJOR, PROTOCOL_MINOR, PROTOCOL_UPDATE)

    # The default transport for "internal" (inter-thread/-process) communication.
    # IPC should work except on really odd systems with no local storage.
    internal_transport = 'ipc'

    default_comm_mode = 'ipc'
    default_master_heartbeat = 20.0
    default_worker_heartbeat = 20.0
    default_timeout_factor = 5.0
    default_startup_timeout = 120.0
    default_shutdown_timeout = 5.0

    _ipc_endpoints_to_delete = []

    @classmethod
    def make_ipc_endpoint(cls):
        (fd, socket_path) = tempfile.mkstemp()
        os.close(fd)
        endpoint = 'ipc://{}'.format(socket_path)
        cls._ipc_endpoints_to_delete.append(endpoint)
        return endpoint

    @classmethod
    def remove_ipc_endpoints(cls):
        while cls._ipc_endpoints_to_delete:
            endpoint = cls._ipc_endpoints_to_delete.pop()
            assert endpoint.startswith('ipc://')
            socket_path = endpoint[6:]
            try:
                os.unlink(socket_path)
            except OSError as e:
                if e.errno != errno.ENOENT:
                    log.debug('could not unlink IPC endpoint {!r}: {}'.format(socket_path, e))
            else:
                log.debug('unlinked IPC endpoint {!r}'.format(socket_path))

    @classmethod
    def make_tcp_endpoint(cls, address='127.0.0.1'):
        return 'tcp://{}:{}'.format(address, randport(address))

    @classmethod
    def make_internal_endpoint(cls):
        assert cls.internal_transport in {'ipc', 'tcp'}
        if cls.internal_transport == 'ipc':
            return cls.make_ipc_endpoint()
        else:  # cls.internal_transport == 'tcp'
            return cls.make_tcp_endpoint()
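
    # The endpoint helpers above return plain ZeroMQ endpoint strings; for
    # example (illustrative values only):
    #     make_ipc_endpoint()  -> 'ipc:///tmp/tmpXXXXXXXX'
    #     make_tcp_endpoint()  -> 'tcp://127.0.0.1:54321'
    # make_internal_endpoint() chooses between them based on internal_transport.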

    def __init__(self):
        # Unique identifier of this ZMQ node
        self.node_id = uuid.uuid4()

        # Identifier of the task distribution network (work manager)
        self.network_id = None

        # Beacons:
        # Workers expect to hear from the master at least every master_beacon_period.
        # The master expects to hear from each worker at least every worker_beacon_period.
        # If more than {master,worker}_beacon_period*timeout_factor elapses, the
        # master/worker is considered missing.
        self.worker_beacon_period = self.default_worker_heartbeat
        self.master_beacon_period = self.default_master_heartbeat
        self.timeout_factor = self.default_timeout_factor

        # These should allow for some fuzz, and should ratchet up as more and
        # more workers become available (maybe order 1 s for 100 workers?). This
        # should also account appropriately for startup delay on difficult
        # systems.

        # Number of seconds to allow for first contact between at least one
        # worker and the master.
        self.startup_timeout = self.default_startup_timeout

        # A friendlier description for logging
        self.node_description = '{!s} on {!s} at PID {:d}'.format(self.__class__.__name__,
                                                                  socket.gethostname(),
                                                                  os.getpid())

        self.validation_fail_action = 'exit'  # other options are 'raise' and 'warn'

        self.log = logging.getLogger(__name__ + '.' + self.__class__.__name__ + '.' + str(self.node_id))

        # ZeroMQ context
        self.context = None

        # External communication endpoints
        self.rr_endpoint = None
        self.ann_endpoint = None

        self.inproc_endpoint = 'inproc://{!s}'.format(self.node_id)

        # Sockets
        self.rr_socket = None
        self.ann_socket = None

        # The main-thread end of the inproc signalling channel
        self._inproc_socket = None

        self.master_id = None

        if os.environ.get('WWMGR_ZMQ_DEBUG_MESSAGES', 'n').upper() in {'Y', 'YES', '1', 'T', 'TRUE'}:
            self._super_debug = True
        else:
            self._super_debug = None

    def __repr__(self):
        return '<{!s} {!s}>'.format(self.__class__.__name__, self.node_id)

    def get_identification(self):
        return {'node_id': self.node_id,
                'master_id': self.master_id,
                'class': self.__class__.__name__,
                'description': self.node_description,
                'hostname': socket.gethostname(),
                'pid': os.getpid()}

    def validate_message(self, message):
        '''Validate an incoming message. Raises an exception if the message
        is improperly formatted (TypeError) or does not correspond to the
        appropriate master (ZMQWMEnvironmentError).'''
        try:
            super_validator = super(ZMQCore, self).validate_message
        except AttributeError:
            pass
        else:
            super_validator(message)

        if not isinstance(message, Message):
            raise TypeError('message is not an instance of core.Message')
        if message.src_id is None:
            raise ZMQWMEnvironmentError('message src_id is not set')
        if self.master_id is not None and message.master_id is not None and message.master_id != self.master_id:
            raise ZMQWMEnvironmentError('incoming message associated with another master (this={!s}, incoming={!s})'
                                        .format(self.master_id, message.master_id))

    @contextlib.contextmanager
    def message_validation(self, msg):
        '''A context manager for message validation. The instance variable
        ``validation_fail_action`` controls the behavior of this context manager:
          * 'raise': re-raise the exception that indicated failed validation.
            Useful for development.
          * 'exit' (default): report the error and exit the program.
          * 'warn': report the error and continue.'''
        try:
            yield
        except Exception as e:
            if self.validation_fail_action == 'raise':
                self.log.exception('message validation failed for {!r}'.format(msg))
                raise
            elif self.validation_fail_action == 'exit':
                self.log.error('message validation failed: {!s}'.format(e))
                sys.exit(1)
            elif self.validation_fail_action == 'warn':
                self.log.warning('message validation failed: {!s}'.format(e))
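
    # The usual call pattern (as used by recv_message below) pairs the two:
    # validate_message() does the checking, and message_validation() turns a
    # validation failure into a raise, a warning, or a clean exit depending on
    # validation_fail_action:
    #
    #     with self.message_validation(msg):
    #         self.validate_message(msg)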

    def recv_message(self, socket, flags=0, validate=True, timeout=None):
        '''Receive a message object from the given socket, using the given
        flags. Message validation is performed if ``validate`` is true.
        If ``timeout`` is given, then it is the number of milliseconds to wait
        prior to raising a ZMQWMTimeout exception. ``timeout`` is ignored if
        ``flags`` includes ``zmq.NOBLOCK``.'''

        if timeout is None or flags & zmq.NOBLOCK:
            message = socket.recv_pyobj(flags)
        else:
            poller = zmq.Poller()
            poller.register(socket, zmq.POLLIN)
            try:
                poll_results = dict(poller.poll(timeout=timeout))
                if socket in poll_results:
                    message = socket.recv_pyobj(flags)
                else:
                    raise ZMQWMTimeout('recv timed out')
            finally:
                poller.unregister(socket)

        if self._super_debug:
            self.log.debug('received {!r}'.format(message))

        if validate:
            with self.message_validation(message):
                self.validate_message(message)

        return message

    def recv_all(self, socket, flags=0, validate=True):
        '''Receive all messages currently available from the given socket.'''
        messages = []
        while True:
            try:
                messages.append(self.recv_message(socket, flags | zmq.NOBLOCK, validate))
            except zmq.Again:
                return messages

    def recv_ack(self, socket, flags=0, validate=True, timeout=None):
        '''Receive and return a message, asserting (when validating) that it is
        an ACK or NAK.'''
        msg = self.recv_message(socket, flags, validate, timeout)
        if validate:
            with self.message_validation(msg):
                assert msg.message in (Message.ACK, Message.NAK)
        return msg

    def send_message(self, socket, message, payload=None, flags=0):
        '''Send a message object. Subclasses may override this to decorate the
        message with appropriate IDs, then delegate upward to actually send the
        message. ``message`` may either be a pre-constructed ``Message`` object
        or a message identifier, in which latter case ``payload`` will become
        the message payload. ``payload`` is ignored if ``message`` is a
        ``Message`` object.'''

        message = Message(message, payload)
        if message.master_id is None:
            message.master_id = self.master_id
        message.src_id = self.node_id

        if self._super_debug:
            self.log.debug('sending {!r}'.format(message))

        socket.send_pyobj(message, flags)

    def send_reply(self, socket, original_message, reply=Message.ACK, payload=None, flags=0):
        '''Send a reply to ``original_message`` on ``socket``. The reply message
        is a Message object or a message identifier. The reply master_id and
        worker_id are set from ``original_message``, unless master_id is not set,
        in which case it is set from self.master_id.'''
        reply = Message(reply, payload)
        reply.master_id = original_message.master_id or self.master_id
        assert original_message.worker_id is not None  # should have been caught by validation prior to this
        reply.worker_id = original_message.worker_id
        self.send_message(socket, reply)

    def send_ack(self, socket, original_message):
        '''Send an acknowledgement message, which is mostly just to respect
        REQ/REP recv/send patterns.'''
        self.send_message(socket, Message(Message.ACK,
                                          master_id=original_message.master_id or self.master_id,
                                          src_id=self.node_id))

    def send_nak(self, socket, original_message):
        '''Send a negative acknowledgement message.'''
        self.send_message(socket, Message(Message.NAK,
                                          master_id=original_message.master_id or self.master_id,
                                          src_id=self.node_id))

    def send_inproc_message(self, message, payload=None, flags=0):
        inproc_socket = self.context.socket(zmq.PUB)
        inproc_socket.connect(self.inproc_endpoint)
        # annoying wait for sockets to settle
        time.sleep(0.01)
        self.send_message(inproc_socket, message, payload, flags)
        # used to be a close with linger here, but it was cutting off messages
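
    # Illustrative note: send_inproc_message() opens a short-lived PUB socket to
    # this node's inproc endpoint for each call. A hypothetical caller shutting a
    # node down from the main thread would do something like:
    #
    #     core.send_inproc_message(Message.SHUTDOWN)
    #
    # which is exactly what signal_shutdown() below does.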

    def signal_shutdown(self):
        try:
            self.send_inproc_message(Message.SHUTDOWN)
        except AttributeError:
            # this is expected if self.context has been set to None
            # (i.e. it has already been destroyed)
            pass
        except Exception as e:
            self.log.debug('ignoring exception {!r} in signal_shutdown()'.format(e))

    def shutdown_handler(self, signal=None, frame=None):
        if signal is None:
            self.log.info('shutting down')
        else:
            self.log.info('shutting down on signal {!s}'.format(signames.get(signal, signal)))
        self.signal_shutdown()

    def install_signal_handlers(self, signals=None):
        if not signals:
            signals = {signal.SIGINT, signal.SIGQUIT, signal.SIGTERM}

        for sig in signals:
            signal.signal(sig, self.shutdown_handler)

    def install_sigint_handler(self):
        self.install_signal_handlers()

    def startup(self):
        self.context = zmq.Context()
        self.comm_thread = threading.Thread(target=self.comm_loop)
        self.comm_thread.start()
        #self.install_signal_handlers()

    def shutdown(self):
        self.shutdown_handler()

    def join(self):
        while True:
            self.comm_thread.join(0.1)
            if not self.comm_thread.is_alive():
                break


def shutdown_process(process, timeout=1.0):
    '''Join ``process`` for up to ``timeout`` seconds; if it is still alive,
    escalate from SIGINT to SIGKILL until it terminates.'''
    process.join(timeout)

    if process.is_alive():
        log.debug('sending SIGINT to process {:d}'.format(process.pid))
        os.kill(process.pid, signal.SIGINT)
        process.join(timeout)

        if process.is_alive():
            log.warning('sending SIGKILL to worker process {:d}'.format(process.pid))
            os.kill(process.pid, signal.SIGKILL)
            process.join()

        log.debug('process {:d} terminated with code {:d}'.format(process.pid, process.exitcode))
    else:
        log.debug('worker process {:d} terminated gracefully with code {:d}'.format(process.pid, process.exitcode))

    assert not process.is_alive()


class IsNode:
    '''Mix-in providing local worker process management for a node; relies on
    ``make_internal_endpoint()`` and ``node_id`` supplied by ZMQCore.'''

    def __init__(self, n_local_workers=None):
        from work_managers.zeromq.worker import ZMQWorker

        if n_local_workers is None:
            n_local_workers = multiprocessing.cpu_count()

        self.downstream_rr_endpoint = None
        self.downstream_ann_endpoint = None

        if n_local_workers:
            self.local_ann_endpoint = self.make_internal_endpoint()
            self.local_rr_endpoint = self.make_internal_endpoint()
            self.local_workers = [ZMQWorker(self.local_rr_endpoint, self.local_ann_endpoint)
                                  for _n in range(n_local_workers)]
        else:
            self.local_ann_endpoint = None
            self.local_rr_endpoint = None
            self.local_workers = []

        self.local_worker_processes = [multiprocessing.Process(target=worker.startup, args=(n,))
                                       for (n, worker) in enumerate(self.local_workers)]

        self.host_info_files = []

    def write_host_info(self, filename=None):
        filename = filename or 'zmq_host_info_{}.json'.format(self.node_id.hex)
        hostname = socket.gethostname()
        with open(filename, 'wt') as infofile:
            info = {}
            info['rr_endpoint'] = re.sub(r'\*', hostname, self.downstream_rr_endpoint or '')
            info['ann_endpoint'] = re.sub(r'\*', hostname, self.downstream_ann_endpoint or '')
            json.dump(info, infofile)
        self.host_info_files.append(filename)

    def startup(self):
        for process in self.local_worker_processes:
            process.start()

    def shutdown(self):
        try:
            shutdown_timeout = self.shutdown_timeout
        except AttributeError:
            shutdown_timeout = 1.0

        for process in self.local_worker_processes:
            shutdown_process(process, shutdown_timeout)

        for host_info_file in self.host_info_files:
            try:
                os.unlink(host_info_file)
            except OSError:
                pass
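

# Illustrative sketch (not part of the work manager API): how a remote node might
# consume the JSON file written by IsNode.write_host_info(). The function name and
# endpoint values below are hypothetical; write_host_info() stores the two endpoint
# strings with any '*' wildcard already replaced by the writing host's name.
def _example_read_host_info(filename):
    with open(filename, 'rt') as infofile:
        info = json.load(infofile)
    # e.g. {'rr_endpoint': 'tcp://somehost:23811', 'ann_endpoint': 'tcp://somehost:23812'}
    return info['rr_endpoint'], info['ann_endpoint']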