#!/usr/bin/env python # -*- coding: utf-8 -*- # Copyright 1999-2020 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import atexit import multiprocessing import os import signal import sys import time from ...actors import create_actor_pool from ...cluster_info import StaticSchedulerDiscoverer from ...config import options, option_context from ...resource import cpu_count from ...scheduler.service import SchedulerService from ...session import new_session from ...utils import get_next_port, kill_process_tree from ...worker.service import WorkerService from .distributor import gen_distributor _mp_spawn_context = multiprocessing.get_context('spawn') _local_cluster_clients = dict() atexit.register(lambda: [v.stop() for v in list(_local_cluster_clients.values())]) class LocalDistributedCluster(object): # at least 2 process are required by scheduler and worker MIN_SCHEDULER_N_PROCESS = 2 MIN_WORKER_N_PROCESS = 2 def __init__(self, endpoint, n_process=None, scheduler_n_process=None, worker_n_process=None, cuda_device=None, ignore_avail_mem=True, shared_memory=None): self._endpoint = endpoint self._started = False self._stopped = False self._pool = None self._scheduler_service = SchedulerService() cuda_devices = [cuda_device] if cuda_device is not None else None self._worker_service = WorkerService(ignore_avail_mem=ignore_avail_mem, cache_mem_limit=shared_memory, cuda_devices=cuda_devices) self._scheduler_n_process, self._worker_n_process = \ self._calc_scheduler_worker_n_process(n_process, scheduler_n_process, worker_n_process) @property def pool(self): return self._pool @classmethod def _calc_scheduler_worker_n_process(cls, n_process, scheduler_n_process, worker_n_process, calc_cpu_count=cpu_count): n_scheduler, n_worker = scheduler_n_process, worker_n_process if n_scheduler is None and n_worker is None: n_scheduler = cls.MIN_SCHEDULER_N_PROCESS n_process = n_process if n_process is not None else calc_cpu_count() + n_scheduler n_worker = max(n_process - n_scheduler, cls.MIN_WORKER_N_PROCESS) elif n_scheduler is None or n_worker is None: # one of scheduler and worker n_process provided if n_scheduler is None: n_process = n_process if n_process is not None else calc_cpu_count() n_scheduler = max(n_process - n_worker, cls.MIN_SCHEDULER_N_PROCESS) else: assert n_worker is None n_process = n_process if n_process is not None else calc_cpu_count() + n_scheduler n_worker = max(n_process - n_scheduler, cls.MIN_WORKER_N_PROCESS) return n_scheduler, n_worker def _make_sure_scheduler_ready(self, timeout=120): check_start_time = time.time() while True: workers_meta = self._scheduler_service._resource_ref.get_workers_meta() if not workers_meta: # wait for worker to report status self._pool.sleep(.5) if time.time() - check_start_time > timeout: # pragma: no cover raise TimeoutError('Check worker ready timed out.') else: break def start_service(self): if self._started: return self._started = True # start plasma self._worker_service.start_plasma() # start actor pool n_process = self._scheduler_n_process + self._worker_n_process distributor = gen_distributor(self._scheduler_n_process, self._worker_n_process) self._pool = create_actor_pool(self._endpoint, n_process, distributor=distributor) discoverer = StaticSchedulerDiscoverer([self._endpoint]) # start scheduler first self._scheduler_service.start(self._endpoint, discoverer, self._pool, distributed=False) # start worker next self._worker_service.start(self._endpoint, self._pool, distributed=False, discoverer=discoverer, process_start_index=self._scheduler_n_process) # make sure scheduler is ready self._make_sure_scheduler_ready() def stop_service(self): if self._stopped: return self._stopped = True try: self._scheduler_service.stop(self._pool) self._worker_service.stop() finally: self._pool.stop() def serve_forever(self): try: self._pool.join() except KeyboardInterrupt: pass finally: self.stop_service() def __enter__(self): self.start_service() return self def __exit__(self, *_): self.stop_service() def gen_endpoint(address): port = None tries = 5 # retry for 5 times for i in range(tries): try: port = get_next_port() break except SystemError: if i < tries - 1: continue raise return '{0}:{1}'.format(address, port) def _start_cluster(endpoint, event, n_process=None, shared_memory=None, **kw): modules = kw.pop('modules', None) or [] for m in modules: __import__(m, globals(), locals(), []) options_dict = kw.pop('options', None) or {} with option_context(options_dict): cluster = LocalDistributedCluster(endpoint, n_process=n_process, shared_memory=shared_memory, **kw) cluster.start_service() event.set() try: cluster.serve_forever() finally: cluster.stop_service() def _start_cluster_process(endpoint, n_process, shared_memory, **kw): event = _mp_spawn_context.Event() kw = kw.copy() kw['n_process'] = n_process kw['shared_memory'] = shared_memory or '20%' process = _mp_spawn_context.Process( target=_start_cluster, args=(endpoint, event), kwargs=kw) process.start() while True: event.wait(5) if not event.is_set(): # service not started yet continue if not process.is_alive(): raise SystemError('New local cluster failed') else: break return process def _start_web(scheduler_address, ui_port, event): import gevent.monkey gevent.monkey.patch_all(thread=False) from ...web import MarsWeb web = MarsWeb(None, ui_port, scheduler_address) try: web.start(event=event, block=True) finally: web.stop() def _start_web_process(scheduler_endpoint, web_endpoint): ui_port = int(web_endpoint.rsplit(':', 1)[1]) web_event = _mp_spawn_context.Event() web_process = _mp_spawn_context.Process( target=_start_web, args=(scheduler_endpoint, ui_port, web_event), daemon=True) web_process.start() while True: web_event.wait(5) if not web_event.is_set(): # web not started yet continue if not web_process.is_alive(): raise SystemError('New web interface failed') else: break return web_process class LocalDistributedClusterClient(object): def __init__(self, endpoint, web_endpoint, cluster_process, web_process): self._cluster_process = cluster_process self._web_process = web_process self._endpoint = endpoint self._web_endpoint = web_endpoint self._session = new_session(endpoint).as_default() @property def endpoint(self): return self._endpoint @property def web_endpoint(self): return self._web_endpoint @property def session(self): return self._session def __enter__(self): return self def __exit__(self, *_): self.stop() @staticmethod def _ensure_process_finish(proc): if proc is None or not proc.is_alive(): return proc.join(3) kill_process_tree(proc.pid) def stop(self): try: del _local_cluster_clients[id(self)] except KeyError: # pragma: no cover pass if self._cluster_process.is_alive(): os.kill(self._cluster_process.pid, signal.SIGINT) if self._web_process is not None and self._web_process.is_alive(): os.kill(self._web_process.pid, signal.SIGINT) self._ensure_process_finish(self._cluster_process) self._ensure_process_finish(self._web_process) def new_cluster(address='0.0.0.0', web=False, n_process=None, shared_memory=None, open_browser=None, **kw): open_browser = open_browser if open_browser is not None else options.deploy.open_browser endpoint = gen_endpoint(address) web_endpoint = None if web is True: web_endpoint = gen_endpoint('0.0.0.0') elif isinstance(web, str): if ':' in web: web_endpoint = web else: web_endpoint = gen_endpoint(web) process = _start_cluster_process(endpoint, n_process, shared_memory, **kw) web_process = None if web_endpoint: web_process = _start_web_process(endpoint, web_endpoint) print('Web endpoint started at http://%s' % web_endpoint, file=sys.stderr) if open_browser: import webbrowser webbrowser.open_new_tab('http://%s' % web_endpoint) client = LocalDistributedClusterClient(endpoint, web_endpoint, process, web_process) _local_cluster_clients[id(client)] = client return client