"""Cluster launcher: spawns one bot subprocess per shard cluster and supervises them over Redis pub/sub."""
import asyncio
import json
import os
import signal
import sys
import time
from pathlib import Path

import aiohttp
import aioredis

import config

# Authentication headers sent with every request to the Discord API.
headers = {
    "Authorization": f"Bot {config.token}",
    "User-Agent": f"DiscordBot (custom, {config.__version__})",
}


async def get_shard_count():
    """Ask the Discord gateway how many shards the bot should run."""
    async with aiohttp.ClientSession() as session, session.get(
        "https://discordapp.com/api/gateway/bot",
        headers=headers,
    ) as req:
        response = await req.json()
    return response["shards"]


def get_cluster_list(shards):
    """Split the shard IDs into chunks of config.shards_per_cluster."""
    return [
        list(range(shards)[i : i + config.shards_per_cluster])
        for i in range(0, shards, config.shards_per_cluster)
    ]


class Instance:
    def __init__(self, instance_id, shard_list, shard_count, loop, main, cluster_count):
        self.id = instance_id
        self.shard_list = shard_list
        self.shard_count = shard_count
        self.loop = loop
        self.main = main
        self.cluster_count = cluster_count
        self.command = (
            f"{sys.executable} \"{Path.cwd() / 'main.py'}\""
            f' "{shard_list}" {shard_count} {self.id} {cluster_count}'
        )
        self._process = None
        self.status = "initialized"
        self.started_at = 0.0
        self.task = self.loop.create_task(self.start())
        self.task.add_done_callback(self.main.dead_process_handler)

    @property
    def is_active(self):
        # returncode is None while the subprocess is still running; checking
        # `not returncode` would also be true for a clean exit code of 0.
        return self._process is not None and self._process.returncode is None

    async def read_stream(self, stream):
        """Forward one of the subprocess's output streams to our stdout, line by line."""
        while self.is_active:
            try:
                line = await stream.readline()
            except (asyncio.LimitOverrunError, ValueError):
                continue
            if not line:
                break
            print(f"[Cluster {self.id}] {line.decode('utf-8').rstrip()}")

    async def start(self):
        if self.is_active:
            print(f"[Cluster {self.id}] Already active.")
            return self
        self._process = await asyncio.create_subprocess_shell(
            self.command,
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            preexec_fn=os.setsid,  # new process group, so killpg() only hits this cluster
            limit=1024 * 256,
        )
        self.status = "running"
        self.started_at = time.time()
        print(f"[Cluster {self.id}] The cluster is starting.")
        # Pump both streams until the subprocess exits.
        await asyncio.gather(
            self.read_stream(self._process.stdout),
            self.read_stream(self._process.stderr),
        )
        return self

    async def stop(self):
        self.status = "stopped"
        os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)
        print(f"[Cluster {self.id}] The cluster is stopping.")
        await asyncio.sleep(5)

    def kill(self):
        self.status = "stopped"
        os.killpg(os.getpgid(self._process.pid), signal.SIGTERM)

    async def restart(self):
        if self.is_active:
            await self.stop()
        await self.start()


class Main:
    def __init__(self, loop):
        self.loop = loop
        self.instances = []
        self.redis = None

    def dead_process_handler(self, result):
        instance = result.result()
        returncode = instance._process.returncode
        print(f"[Cluster {instance.id}] The cluster exited with code {returncode}.")
        if returncode in (0, -signal.SIGTERM):
            print(f"[Cluster {instance.id}] The cluster stopped gracefully.")
        else:
            print(f"[Cluster {instance.id}] The cluster is restarting.")
            instance.loop.create_task(instance.start())

    def get_instance(self, iterable, instance_id):
        for element in iterable:
            if element.id == instance_id:
                return element
        return None

    async def event_handler(self):
        # aioredis 1.x pub/sub API: listen for launcher commands on the IPC channel.
        self.redis = await aioredis.create_pool("redis://localhost", minsize=1, maxsize=2)
        await self.redis.execute_pubsub("SUBSCRIBE", config.ipc_channel)
        channel = self.redis.pubsub_channels[bytes(config.ipc_channel, "utf-8")]
        while await channel.wait_message():
            payload = await channel.get_json(encoding="utf-8")
            if payload.get("scope") != "launcher" or not payload.get("action"):
                continue
            action = payload["action"]
            if action == "restart":
                print(f"[Cluster Manager] Received signal to restart cluster {payload.get('id')}.")
                self.loop.create_task(self.get_instance(self.instances, payload.get("id")).restart())
            elif action == "stop":
                print(f"[Cluster Manager] Received signal to stop cluster {payload.get('id')}.")
                self.loop.create_task(self.get_instance(self.instances, payload.get("id")).stop())
            elif action == "start":
                print(f"[Cluster Manager] Received signal to start cluster {payload.get('id')}.")
                self.loop.create_task(self.get_instance(self.instances, payload.get("id")).start())
            elif action == "statuses" and payload.get("command_id"):
                statuses = {
                    str(instance.id): {
                        "active": instance.is_active,
                        "status": instance.status,
                        "started_at": instance.started_at,
                    }
                    for instance in self.instances
                }
                await self.redis.execute(
                    "PUBLISH",
                    config.ipc_channel,
                    json.dumps({"command_id": payload["command_id"], "output": statuses}),
                )
            elif action == "roll_restart":
                print("[Cluster Manager] Received signal to perform a rolling restart.")
                for instance in self.instances:
                    self.loop.create_task(instance.restart())
                    # Stagger restarts so the clusters never go down all at once.
                    await asyncio.sleep(config.shards_per_cluster * 10)

    async def close(self):
        await self.redis.execute_pubsub("UNSUBSCRIBE", config.ipc_channel)
        self.redis.close()

    def write_targets(self, clusters):
        # Prometheus file-based service discovery: one scrape target per cluster.
        data = [
            {"labels": {"cluster": f"{i}"}, "targets": [f"localhost:{6000 + i}"]}
            for i, shard_list in enumerate(clusters, 1)
            if shard_list
        ]
        with open("targets.json", "w") as f:
            json.dump(data, f, indent=4)

    async def launch(self):
        self.loop.create_task(self.event_handler())
        shard_count = await get_shard_count() + config.additional_shards
        clusters = get_cluster_list(shard_count)
        if config.testing is False:
            self.write_targets(clusters)
        print(f"[Cluster Manager] Starting a total of {len(clusters)} clusters.")
        for i, shard_list in enumerate(clusters, 1):
            if not shard_list:
                continue
            self.instances.append(
                Instance(i, shard_list, shard_count, self.loop, main=self, cluster_count=len(clusters))
            )
            # Stagger cluster startup to respect the gateway identify rate limit.
            await asyncio.sleep(config.shards_per_cluster * 10)


loop = asyncio.get_event_loop()
main = Main(loop=loop)
loop.create_task(main.launch())

try:
    loop.run_forever()
except KeyboardInterrupt:

    def shutdown_handler(_loop, context):
        if "exception" not in context or not isinstance(context["exception"], asyncio.CancelledError):
            _loop.default_exception_handler(context)

    loop.set_exception_handler(shutdown_handler)
    for instance in main.instances:
        # Detach the supervisor callback so killed clusters are not restarted.
        instance.task.remove_done_callback(main.dead_process_handler)
        instance.kill()
    loop.run_until_complete(main.close())
    tasks = asyncio.gather(*asyncio.all_tasks(loop=loop), return_exceptions=True)
    tasks.add_done_callback(lambda t: loop.stop())
    tasks.cancel()
    # Keep the loop running until all cancelled tasks have finished.
    while not tasks.done() and not loop.is_closed():
        loop.run_forever()
finally:
    loop.run_until_complete(loop.shutdown_asyncgens())
    loop.close()