#!/usr/bin/env python import argparse import zmq import os import sys import platform import random import time import pickle import logging import queue import threading import json import daemon import collections from parsl.version import VERSION as PARSL_VERSION from funcx.sdk.client import FuncXClient from funcx.executors.high_throughput.interchange_task_dispatch import naive_interchange_task_dispatch from funcx.serialize import FuncXSerializer LOOP_SLOWDOWN = 0.0 # in seconds HEARTBEAT_CODE = (2 ** 32) - 1 PKL_HEARTBEAT_CODE = pickle.dumps((2 ** 32) - 1) class ShutdownRequest(Exception): ''' Exception raised when any async component receives a ShutdownRequest ''' def __init__(self): self.tstamp = time.time() def __repr__(self): return "Shutdown request received at {}".format(self.tstamp) class ManagerLost(Exception): ''' Task lost due to worker loss. Worker is considered lost when multiple heartbeats have been missed. ''' def __init__(self, worker_id): self.worker_id = worker_id self.tstamp = time.time() def __repr__(self): return "Task failure due to loss of worker {}".format(self.worker_id) class BadRegistration(Exception): ''' A new Manager tried to join the executor with a BadRegistration message ''' def __init__(self, worker_id, critical=False): self.worker_id = worker_id self.tstamp = time.time() self.handled = "critical" if critical else "suppressed" def __repr__(self): return "Manager:{} caused a {} failure".format(self.worker_id, self.handled) class Interchange(object): """ Interchange is a task orchestrator for distributed systems. 1. Asynchronously queue large volume of tasks (>100K) 2. Allow for workers to join and leave the union 3. Detect workers that have failed using heartbeats 4. Service single and batch requests from workers 5. Be aware of requests worker resource capacity, eg. schedule only jobs that fit into walltime. TODO: We most likely need a PUB channel to send out global commands, like shutdown """ def __init__(self, config, client_address="127.0.0.1", interchange_address="127.0.0.1", client_ports=(50055, 50056, 50057), worker_ports=None, worker_port_range=(54000, 55000), cores_per_worker=1.0, worker_debug=False, launch_cmd=None, heartbeat_threshold=60, logdir=".", logging_level=logging.INFO, poll_period=10, endpoint_id=None, suppress_failure=False, max_heartbeats_missed=2 ): """ Parameters ---------- config : funcx.Config object Funcx config object that describes how compute should be provisioned client_address : str The ip address at which the parsl client can be reached. Default: "127.0.0.1" interchange_address : str The ip address at which the workers will be able to reach the Interchange. Default: "127.0.0.1" client_ports : triple(int, int, int) The ports at which the client can be reached launch_cmd : str TODO : update worker_ports : tuple(int, int) The specific two ports at which workers will connect to the Interchange. Default: None worker_port_range : tuple(int, int) The interchange picks ports at random from the range which will be used by workers. This is overridden when the worker_ports option is set. Defauls: (54000, 55000) cores_per_worker : float cores to be assigned to each worker. Oversubscription is possible by setting cores_per_worker < 1.0. Default=1 worker_debug : Bool Enables worker debug logging. heartbeat_threshold : int Number of seconds since the last heartbeat after which worker is considered lost. logdir : str Parsl log directory paths. Logs and temp files go here. Default: '.' logging_level : int Logging level as defined in the logging module. Default: logging.INFO (20) endpoint_id : str Identity string that identifies the endpoint to the broker poll_period : int The main thread polling period, in milliseconds. Default: 10ms suppress_failure : Bool When set to True, the interchange will attempt to suppress failures. Default: False max_heartbeats_missed : int Number of heartbeats missed before setting kill_event """ self.logdir = logdir try: os.makedirs(self.logdir) except FileExistsError: pass start_file_logger("{}/interchange.log".format(self.logdir), level=logging_level) logger.info("logger location {}".format(logger.handlers)) logger.info("Initializing Interchange process with Endpoint ID: {}".format(endpoint_id)) self.config = config logger.info("Got config : {}".format(config)) self.strategy = self.config.strategy self.client_address = client_address self.interchange_address = interchange_address self.suppress_failure = suppress_failure self.poll_period = poll_period self.serializer = FuncXSerializer() logger.info("Attempting connection to client at {} on ports: {},{},{}".format( client_address, client_ports[0], client_ports[1], client_ports[2])) self.context = zmq.Context() self.task_incoming = self.context.socket(zmq.DEALER) self.task_incoming.set_hwm(0) self.task_incoming.RCVTIMEO = 10 # in milliseconds logger.info("Task incoming on tcp://{}:{}".format(client_address, client_ports[0])) self.task_incoming.connect("tcp://{}:{}".format(client_address, client_ports[0])) self.results_outgoing = self.context.socket(zmq.DEALER) self.results_outgoing.set_hwm(0) logger.info("Results outgoing on tcp://{}:{}".format(client_address, client_ports[1])) self.results_outgoing.connect("tcp://{}:{}".format(client_address, client_ports[1])) self.command_channel = self.context.socket(zmq.DEALER) self.command_channel.RCVTIMEO = 1000 # in milliseconds # self.command_channel.set_hwm(0) logger.info("Command channel on tcp://{}:{}".format(client_address, client_ports[2])) self.command_channel.connect("tcp://{}:{}".format(client_address, client_ports[2])) logger.info("Connected to client") self.pending_task_queue = {} self.containers = {} self.total_pending_task_count = 0 self.fxs = FuncXClient() logger.info("Interchange address is {}".format(self.interchange_address)) self.worker_ports = worker_ports self.worker_port_range = worker_port_range self.task_outgoing = self.context.socket(zmq.ROUTER) self.task_outgoing.set_hwm(0) self.results_incoming = self.context.socket(zmq.ROUTER) self.results_incoming.set_hwm(0) # initalize the last heartbeat time to start the loop self.last_heartbeat = time.time() self.max_heartbeats_missed = max_heartbeats_missed self.endpoint_id = endpoint_id if self.worker_ports: self.worker_task_port = self.worker_ports[0] self.worker_result_port = self.worker_ports[1] self.task_outgoing.bind("tcp://*:{}".format(self.worker_task_port)) self.results_incoming.bind("tcp://*:{}".format(self.worker_result_port)) else: self.worker_task_port = self.task_outgoing.bind_to_random_port('tcp://*', min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) self.worker_result_port = self.results_incoming.bind_to_random_port('tcp://*', min_port=worker_port_range[0], max_port=worker_port_range[1], max_tries=100) logger.info("Bound to ports {},{} for incoming worker connections".format( self.worker_task_port, self.worker_result_port)) self._ready_manager_queue = {} self.heartbeat_threshold = heartbeat_threshold self.blocks = {} # type: Dict[str, str] self.block_id_map = {} self.launch_cmd = launch_cmd self.last_core_hr_counter = 0 if not launch_cmd: self.launch_cmd = ("funcx-manager {debug} {max_workers} " "-c {cores_per_worker} " "--poll {poll_period} " "--task_url={task_url} " "--result_url={result_url} " "--logdir={logdir} " "--block_id={{block_id}} " "--hb_period={heartbeat_period} " "--hb_threshold={heartbeat_threshold} " "--worker_mode={worker_mode} " "--scheduler_mode={scheduler_mode} " "--worker_type={{worker_type}} ") self.current_platform = {'parsl_v': PARSL_VERSION, 'python_v': "{}.{}.{}".format(sys.version_info.major, sys.version_info.minor, sys.version_info.micro), 'os': platform.system(), 'hname': platform.node(), 'dir': os.getcwd()} logger.info("Platform info: {}".format(self.current_platform)) self._block_counter = 0 try: self.load_config() except Exception as e: logger.exception("Caught exception") raise def load_config(self): """ Load the config """ logger.info("Loading endpoint local config") working_dir = self.config.working_dir if self.config.working_dir is None: working_dir = "{}/{}".format(self.logdir, "worker_logs") logger.info("Setting working_dir: {}".format(working_dir)) self.config.provider.script_dir = working_dir if hasattr(self.config.provider, 'channel'): self.config.provider.channel.script_dir = os.path.join(working_dir, 'submit_scripts') self.config.provider.channel.makedirs(self.config.provider.channel.script_dir, exist_ok=True) os.makedirs(self.config.provider.script_dir, exist_ok=True) debug_opts = "--debug" if self.config.worker_debug else "" max_workers = "" if self.config.max_workers_per_node == float('inf') \ else "--max_workers={}".format(self.config.max_workers_per_node) worker_task_url = f"tcp://{self.interchange_address}:{self.worker_task_port}" worker_result_url = f"tcp://{self.interchange_address}:{self.worker_result_port}" l_cmd = self.launch_cmd.format(debug=debug_opts, max_workers=max_workers, cores_per_worker=self.config.cores_per_worker, #mem_per_worker=self.config.mem_per_worker, prefetch_capacity=self.config.prefetch_capacity, task_url=worker_task_url, result_url=worker_result_url, nodes_per_block=self.config.provider.nodes_per_block, heartbeat_period=self.config.heartbeat_period, heartbeat_threshold=self.config.heartbeat_threshold, poll_period=self.config.poll_period, worker_mode=self.config.worker_mode, scheduler_mode=self.config.scheduler_mode, logdir=working_dir) self.launch_cmd = l_cmd logger.info("Launch command: {}".format(self.launch_cmd)) if self.config.scaling_enabled: logger.info("Scaling ...") self.scale_out(self.config.provider.init_blocks) def get_tasks(self, count): """ Obtains a batch of tasks from the internal pending_task_queue Parameters ---------- count: int Count of tasks to get from the queue Returns ------- List of upto count tasks. May return fewer than count down to an empty list eg. [{'task_id':<x>, 'buffer':<buf>} ... ] """ tasks = [] for i in range(0, count): try: x = self.pending_task_queue.get(block=False) except queue.Empty: break else: tasks.append(x) return tasks def migrate_tasks_to_internal(self, kill_event, status_request): """Pull tasks from the incoming tasks 0mq pipe onto the internal pending task queue Parameters: ----------- kill_event : threading.Event Event to let the thread know when it is time to die. """ logger.info("[TASK_PULL_THREAD] Starting") task_counter = 0 poller = zmq.Poller() poller.register(self.task_incoming, zmq.POLLIN) while not kill_event.is_set(): # Check when the last heartbeat was. # logger.debug(f"[TASK_PULL_THREAD] Last heartbeat: {self.last_heartbeat}") if int(time.time() - self.last_heartbeat) > (self.heartbeat_threshold * self.max_heartbeats_missed): logger.critical("[TASK_PULL_THREAD] Missed too many heartbeats. Setting kill event.") kill_event.set() break try: msg = self.task_incoming.recv_pyobj() self.last_heartbeat = time.time() except zmq.Again: # We just timed out while attempting to receive logger.debug("[TASK_PULL_THREAD] {} tasks in internal queue".format(self.total_pending_task_count)) continue if msg == 'STOP': kill_event.set() break elif msg == 'STATUS_REQUEST': logger.info("Got STATUS_REQUEST") status_request.set() else: logger.info("[TASK_PULL_THREAD] Received task:{}".format(msg)) task_type = self.get_container(msg['task_id'].split(";")[1]) msg['container'] = task_type if task_type not in self.pending_task_queue: self.pending_task_queue[task_type] = queue.Queue(maxsize=10 ** 6) self.pending_task_queue[task_type].put(msg) self.total_pending_task_count += 1 logger.debug("[TASK_PULL_THREAD] pending task count: {}".format(self.total_pending_task_count)) task_counter += 1 logger.debug("[TASK_PULL_THREAD] Fetched task:{}".format(task_counter)) def get_container(self, container_uuid): """ Get the container image location if it is not known to the interchange""" if container_uuid not in self.containers: if container_uuid == 'RAW' or not container_uuid: self.containers[container_uuid] = 'RAW' else: try: container = self.fxs.get_container(container_uuid, self.config.container_type) except Exception: logger.exception("[FETCH_CONTAINER] Unable to resolve container location") self.containers[container_uuid] = 'RAW' else: logger.info("[FETCH_CONTAINER] Got container info: {}".format(container)) self.containers[container_uuid] = container.get('location', 'RAW') return self.containers[container_uuid] def get_total_tasks_outstanding(self): """ Get the outstanding tasks in total """ outstanding = {} for task_type in self.pending_task_queue: outstanding[task_type] = outstanding.get(task_type, 0) + self.pending_task_queue[task_type].qsize() for manager in self._ready_manager_queue: for task_type in self._ready_manager_queue[manager]['tasks']: outstanding[task_type] = outstanding.get(task_type, 0) + len(self._ready_manager_queue[manager]['tasks'][task_type]) return outstanding def get_total_live_workers(self): """ Get the total active workers """ active = 0 for manager in self._ready_manager_queue: if self._ready_manager_queue[manager]['active']: active += self._ready_manager_queue[manager]['max_worker_count'] return active def get_outstanding_breakdown(self): """ Get outstanding breakdown per manager and in the interchange queues Returns ------- List of status for online elements [ (element, tasks_pending, status) ... ] """ pending_on_interchange = self.total_pending_task_count # Reporting pending on interchange is a deviation from Parsl reply = [('interchange', pending_on_interchange, True)] for manager in self._ready_manager_queue: resp = (manager.decode('utf-8'), sum([len(tids) for tids in self._ready_manager_queue[manager]['tasks'].values()]), self._ready_manager_queue[manager]['active']) reply.append(resp) return reply def _hold_block(self, block_id): """ Sends hold command to all managers which are in a specific block Parameters ---------- block_id : str Block identifier of the block to be put on hold """ for manager in self._ready_manager_queue: if self._ready_manager_queue[manager]['active'] and \ self._ready_manager_queue[manager]['block_id'] == block_id: logger.debug("[HOLD_BLOCK]: Sending hold to manager: {}".format(manager)) self.hold_manager(manager) def hold_manager(self, manager): """ Put manager on hold Parameters ---------- manager : str Manager id to be put on hold while being killed """ if manager in self._ready_manager_queue: self._ready_manager_queue[manager]['active'] = False reply = True else: reply = False def _command_server(self, kill_event): """ Command server to run async command to the interchange """ logger.debug("[COMMAND] Command Server Starting") while not kill_event.is_set(): try: command_req = self.command_channel.recv_pyobj() logger.debug("[COMMAND] Received command request: {}".format(command_req)) if command_req == "OUTSTANDING_C": reply = self.get_total_outstanding() elif command_req == "MANAGERS": reply = self.get_outstanding_breakdown() elif command_req.startswith("HOLD_WORKER"): cmd, s_manager = command_req.split(';') manager = s_manager.encode('utf-8') logger.info("[CMD] Received HOLD_WORKER for {}".format(manager)) if manager in self._ready_manager_queue: self._ready_manager_queue[manager]['active'] = False reply = True else: reply = False elif command_req == "HEARTBEAT": logger.info("[CMD] Received heartbeat message from hub") reply = "HBT,{}".format(self.endpoint_id) elif command_req == "SHUTDOWN": logger.info("[CMD] Received SHUTDOWN command") kill_event.set() reply = True else: reply = None logger.debug("[COMMAND] Reply: {}".format(reply)) self.command_channel.send_pyobj(reply) except zmq.Again: logger.debug("[COMMAND] is alive") continue def stop(self): """Prepare the interchange for shutdown""" self._kill_event.set() self._task_puller_thread.join() self._command_thread.join() def start(self, poll_period=None): """ Start the Interchange Parameters: ---------- poll_period : int poll_period in milliseconds """ logger.info("Incoming ports bound") if poll_period is None: poll_period = self.poll_period start = time.time() count = 0 self._kill_event = threading.Event() self._status_request = threading.Event() self._task_puller_thread = threading.Thread(target=self.migrate_tasks_to_internal, args=(self._kill_event, self._status_request, )) self._task_puller_thread.start() self._command_thread = threading.Thread(target=self._command_server, args=(self._kill_event, )) self._command_thread.start() try: logger.debug("Starting strategy.") self.strategy.start(self) except RuntimeError as e: # This is raised when re-registering an endpoint as strategy already exists logger.debug("Failed to start strategy.") logger.info(e) poller = zmq.Poller() # poller.register(self.task_incoming, zmq.POLLIN) poller.register(self.task_outgoing, zmq.POLLIN) poller.register(self.results_incoming, zmq.POLLIN) # These are managers which we should examine in an iteration # for scheduling a job (or maybe any other attention?). # Anything altering the state of the manager should add it # onto this list. interesting_managers = set() while not self._kill_event.is_set(): self.socks = dict(poller.poll(timeout=poll_period)) # Listen for requests for work if self.task_outgoing in self.socks and self.socks[self.task_outgoing] == zmq.POLLIN: logger.debug("[MAIN] starting task_outgoing section") message = self.task_outgoing.recv_multipart() manager = message[0] if manager not in self._ready_manager_queue: reg_flag = False try: msg = json.loads(message[1].decode('utf-8')) reg_flag = True except Exception: logger.warning("[MAIN] Got a non-json registration message from manager:{}".format( manager)) logger.debug("[MAIN] Message :\n{}\n".format(message)) # By default we set up to ignore bad nodes/registration messages. self._ready_manager_queue[manager] = {'last': time.time(), 'reg_time': time.time(), 'free_capacity': {'total_workers': 0}, 'max_worker_count': 0, 'active': True, 'tasks': collections.defaultdict(set), 'total_tasks': 0} if reg_flag is True: interesting_managers.add(manager) logger.info("[MAIN] Adding manager: {} to ready queue".format(manager)) self._ready_manager_queue[manager].update(msg) logger.info("[MAIN] Registration info for manager {}: {}".format(manager, msg)) if (msg['python_v'].rsplit(".", 1)[0] != self.current_platform['python_v'].rsplit(".", 1)[0] or msg['parsl_v'] != self.current_platform['parsl_v']): logger.warn("[MAIN] Manager {} has incompatible version info with the interchange".format(manager)) if self.suppress_failure is False: logger.debug("Setting kill event") self._kill_event.set() e = ManagerLost(manager) result_package = {'task_id': -1, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) logger.warning("[MAIN] Sent failure reports, unregistering manager") else: logger.debug("[MAIN] Suppressing shutdown due to version incompatibility") else: # Registration has failed. if self.suppress_failure is False: logger.debug("Setting kill event for bad manager") self._kill_event.set() e = BadRegistration(manager, critical=True) result_package = {'task_id': -1, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) else: logger.debug("[MAIN] Suppressing bad registration from manager:{}".format( manager)) else: self._ready_manager_queue[manager]['last'] = time.time() if message[1] == b'HEARTBEAT': logger.debug("[MAIN] Manager {} sends heartbeat".format(manager)) self.task_outgoing.send_multipart([manager, b'', PKL_HEARTBEAT_CODE]) else: manager_adv = pickle.loads(message[1]) logger.debug("[MAIN] Manager {} requested {}".format(manager, manager_adv)) self._ready_manager_queue[manager]['free_capacity'].update(manager_adv) self._ready_manager_queue[manager]['free_capacity']['total_workers'] = sum(manager_adv.values()) interesting_managers.add(manager) # If we had received any requests, check if there are tasks that could be passed logger.debug("[MAIN] Managers count (total/interesting): {}/{}".format( len(self._ready_manager_queue), len(interesting_managers))) task_dispatch, dispatched_task = naive_interchange_task_dispatch(interesting_managers, self.pending_task_queue, self._ready_manager_queue, scheduler_mode=self.config.scheduler_mode) self.total_pending_task_count -= dispatched_task for manager in task_dispatch: tasks = task_dispatch[manager] if tasks: logger.info("[MAIN] Sending task message {} to manager {}".format(tasks, manager)) self.task_outgoing.send_multipart([manager, b'', pickle.dumps(tasks)]) # Receive any results and forward to client if self.results_incoming in self.socks and self.socks[self.results_incoming] == zmq.POLLIN: logger.debug("[MAIN] entering results_incoming section") manager, *b_messages = self.results_incoming.recv_multipart() if manager not in self._ready_manager_queue: logger.warning("[MAIN] Received a result from a un-registered manager: {}".format(manager)) else: logger.info("[MAIN] Got {} result items in batch".format(len(b_messages))) for b_message in b_messages: r = pickle.loads(b_message) # logger.debug("[MAIN] Received result for task {} from {}".format(r['task_id'], manager)) task_type = self.containers[r['task_id'].split(';')[1]] self._ready_manager_queue[manager]['tasks'][task_type].remove(r['task_id']) self._ready_manager_queue[manager]['total_tasks'] -= len(b_messages) self.results_outgoing.send_multipart(b_messages) logger.debug("[MAIN] Current tasks: {}".format(self._ready_manager_queue[manager]['tasks'])) logger.debug("[MAIN] leaving results_incoming section") # logger.debug("[MAIN] entering bad_managers section") bad_managers = [manager for manager in self._ready_manager_queue if time.time() - self._ready_manager_queue[manager]['last'] > self.heartbeat_threshold] for manager in bad_managers: logger.debug("[MAIN] Last: {} Current: {}".format(self._ready_manager_queue[manager]['last'], time.time())) logger.warning("[MAIN] Too many heartbeats missed for manager {}".format(manager)) e = ManagerLost(manager) for task_type in self._ready_manager_queue[manager]['tasks']: for tid in self._ready_manager_queue[manager]['tasks'][task_type]: result_package = {'task_id': tid, 'exception': self.serializer.serialize(e)} pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) logger.warning("[MAIN] Sent failure reports, unregistering manager") self._ready_manager_queue.pop(manager, 'None') if manager in interesting_managers: interesting_managers.remove(manager) logger.debug("[MAIN] ending one main loop iteration") if self._status_request.is_set(): logger.info("status request response") result_package = self.get_status_report() pkl_package = pickle.dumps(result_package) self.results_outgoing.send(pkl_package) logger.info("[MAIN] Sent info response") self._status_request.clear() delta = time.time() - start logger.info("Processed {} tasks in {} seconds".format(count, delta)) logger.warning("Exiting") def get_status_report(self): """ Get utilization numbers """ total_cores = 0 total_mem = 0 core_hrs = 0 active_managers = 0 free_capacity = 0 outstanding_tasks = self.get_total_tasks_outstanding() pending_tasks = self.total_pending_task_count num_managers = len(self._ready_manager_queue) live_workers = self.get_total_live_workers() for manager in self._ready_manager_queue: total_cores += self._ready_manager_queue[manager]['cores'] total_mem += self._ready_manager_queue[manager]['mem'] active_dur = abs(time.time() - self._ready_manager_queue[manager]['reg_time']) core_hrs += (active_dur * total_cores) / 3600 if self._ready_manager_queue[manager]['active']: active_managers += 1 free_capacity += self._ready_manager_queue[manager]['free_capacity']['total_workers'] result_package = {'task_id': -2, 'info': {'total_cores': total_cores, 'total_mem' : total_mem, 'new_core_hrs': core_hrs - self.last_core_hr_counter, 'total_core_hrs': round(core_hrs, 2), 'managers': num_managers, 'active_managers': active_managers, 'total_workers': live_workers, 'idle_workers': free_capacity, 'pending_tasks': pending_tasks, 'outstanding_tasks': outstanding_tasks, 'worker_mode': self.config.worker_mode, 'scheduler_mode': self.config.scheduler_mode, 'scaling_enabled': self.config.scaling_enabled, 'mem_per_worker': self.config.mem_per_worker, 'cores_per_worker': self.config.cores_per_worker, 'prefetch_capacity': self.config.prefetch_capacity, 'max_blocks': self.config.provider.max_blocks, 'min_blocks': self.config.provider.min_blocks, 'max_workers_per_node': self.config.max_workers_per_node, 'nodes_per_block': self.config.provider.nodes_per_block }} self.last_core_hr_counter = core_hrs return result_package def scale_out(self, blocks=1, task_type=None): """Scales out the number of blocks by "blocks" Raises: NotImplementedError """ r = [] for i in range(blocks): if self.config.provider: self._block_counter += 1 external_block_id = str(self._block_counter) if not task_type and self.config.scheduler_mode == 'hard': launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type='RAW') else: launch_cmd = self.launch_cmd.format(block_id=external_block_id, worker_type=task_type) if not task_type: internal_block = self.config.provider.submit(launch_cmd, 1) else: internal_block = self.config.provider.submit(launch_cmd, 1, task_type) logger.debug("Launched block {}->{}".format(external_block_id, internal_block)) if not internal_block: raise(ScalingFailed(self.provider.label, "Attempts to provision nodes via provider has failed")) self.blocks[external_block_id] = internal_block self.block_id_map[internal_block] = external_block_id else: logger.error("No execution provider available") r = None return r def scale_in(self, blocks=None, block_ids=[], task_type=None): """Scale in the number of active blocks by specified amount. Parameters ---------- blocks : int # of blocks to terminate block_ids : [str.. ] List of external block ids to terminate """ if task_type: logger.info("Scaling in blocks of specific task type {}. Let the provider decide which to kill".format(task_type)) if self.config.scaling_enabled and self.config.provider: to_kill, r = self.config.provider.cancel(blocks, task_type) logger.info("Get the killed blocks: {}, and status: {}".format(to_kill, r)) for job in to_kill: logger.info("[scale_in] Getting the block_id map {} for job {}".format(self.block_id_map, job)) block_id = self.block_id_map[job] logger.info("[scale_in] Holding block {}".format(block_id)) self._hold_block(block_id) self.blocks.pop(block_id) return r if block_ids: block_ids_to_kill = block_ids else: block_ids_to_kill = list(self.blocks.keys())[:blocks] # Try a polite terminate # TODO : Missing logic to hold blocks for block_id in block_ids_to_kill: self._hold_block(block_id) # Now kill via provider to_kill = [self.blocks.pop(bid) for bid in block_ids_to_kill] if self.config.scaling_enabled and self.config.provider: r = self.config.provider.cancel(to_kill) return r def provider_status(self): """ Get status of all blocks from the provider """ status = [] if self.config.provider: logger.debug("[MAIN] Getting the status of {} blocks.".format(list(self.blocks.values()))) status = self.config.provider.status(list(self.blocks.values())) logger.debug("[MAIN] The status is {}".format(status)) return status def start_file_logger(filename, name="interchange", level=logging.DEBUG, format_string=None): """Add a stream log handler. Parameters --------- filename: string Name of the file to write logs to. Required. name: string Logger name. Default="parsl.executors.interchange" level: logging.LEVEL Set the logging level. Default=logging.DEBUG - format_string (string): Set the format string format_string: string Format string to use. Returns ------- None. """ if format_string is None: format_string = "%(asctime)s.%(msecs)03d %(name)s:%(lineno)d [%(levelname)s] %(message)s" global logger logger = logging.getLogger(name) logger.setLevel(level) if not len(logger.handlers): handler = logging.FileHandler(filename) handler.setLevel(level) formatter = logging.Formatter(format_string, datefmt='%Y-%m-%d %H:%M:%S') handler.setFormatter(formatter) logger.addHandler(handler) def starter(comm_q, *args, **kwargs): """Start the interchange process The executor is expected to call this function. The args, kwargs match that of the Interchange.__init__ """ # logger = multiprocessing.get_logger() ic = Interchange(*args, **kwargs) comm_q.put((ic.worker_task_port, ic.worker_result_port)) ic.start() def cli_run(): parser = argparse.ArgumentParser() parser.add_argument("-c", "--client_address", required=True, help="Client address") parser.add_argument("--client_ports", required=True, help="client ports as a triple of outgoing,incoming,command") parser.add_argument("--worker_port_range", help="Worker port range as a tuple") parser.add_argument("-l", "--logdir", default="./parsl_worker_logs", help="Parsl worker log directory") parser.add_argument("-p", "--poll_period", help="REQUIRED: poll period used for main thread") parser.add_argument("--worker_ports", default=None, help="OPTIONAL, pair of workers ports to listen on, eg --worker_ports=50001,50005") parser.add_argument("--suppress_failure", action='store_true', help="Enables suppression of failures") parser.add_argument("--endpoint_id", default=None, help="Endpoint ID, used to identify the endpoint to the remote broker") parser.add_argument("--hb_threshold", help="Heartbeat threshold in seconds") parser.add_argument("--config", default=None, help="Configuration object that describes provisioning") parser.add_argument("-d", "--debug", action='store_true', help="Enables debug logging") print("Starting HTEX Intechange") args = parser.parse_args() optionals = {} optionals['suppress_failure'] = args.suppress_failure optionals['logdir'] = os.path.abspath(args.logdir) optionals['client_address'] = args.client_address optionals['client_ports'] = [int(i) for i in args.client_ports.split(',')] optionals['endpoint_id'] = args.endpoint_id optionals['config'] = args.config if args.debug: optionals['logging_level'] = logging.DEBUG if args.worker_ports: optionals['worker_ports'] = [int(i) for i in args.worker_ports.split(',')] if args.worker_port_range: optionals['worker_port_range'] = [int(i) for i in args.worker_port_range.split(',')] with daemon.DaemonContext(): ic = Interchange(**optionals) ic.start()