import asyncio import logging import platform import signal from typing import Union import discord import discordhealthcheck from . import apis from . import commands from . import config from . import notifications from . import embeds from . import storage class SpaceXLaunchBotClient(discord.Client): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) logging.info("Client initialised") self.ds = storage.DataStore(config.PICKLE_DUMP_LOCATION) logging.info("Data storage initialised") if platform.system() == "Linux": self.loop.add_signal_handler( signal.SIGTERM, lambda: self.loop.create_task(self.shutdown()) ) logging.info("Signal handler for SIGTERM registered") self.loop.create_task(notifications.notification_task(self)) discordhealthcheck.start(self) async def on_ready(self) -> None: logging.info("Connected to Discord API") await self.set_playing(config.BOT_GAME) await self.update_website_metrics() async def shutdown(self) -> None: """Saves data to disk, cancels asyncio tasks, and disconnects from Discord""" logging.info("Shutting down") self.ds.save() for task in asyncio.Task.all_tasks(): task.cancel() await self.close() async def update_website_metrics(self) -> None: """Update Discord bot websites with guild count""" guild_count = len(self.guilds) logging.info(f"Updating bot lists with a guild_count of {guild_count}") await apis.bot_lists.post_all_bot_lists(guild_count) async def on_guild_join(self, guild: discord.guild) -> None: logging.info(f"Joined guild, ID: {guild.id}") await self.update_website_metrics() async def on_guild_remove(self, guild: discord.guild) -> None: logging.info(f"Removed from guild, ID: {guild.id}") await self.update_website_metrics() if self.ds.remove_guild_options(guild.id) is True: logging.info(f"Removed guild settings for {guild.id}") async def set_playing(self, title: str) -> None: await self.change_presence(activity=discord.Game(name=title)) async def on_message(self, message: discord.message) -> None: if message.author.bot or not message.guild: return message_parts = message.content.lower().split(" ") # ToDo: Temporary, remove after n months if message_parts[0].startswith(config.BOT_COMMAND_PREFIX_LEGACY): if message_parts[0][1:] in commands.CMD_LOOKUP: await self._send_s(message.channel, embeds.LEGACY_PREFIX_WARNING_EMBED) return if message_parts[0] != config.BOT_COMMAND_PREFIX: return to_send = None try: command_used = message_parts[1] run_command = commands.CMD_LOOKUP[command_used] to_send = await run_command(client=self, message=message) except (KeyError, IndexError): # Message contained wrong or no command pass except TypeError: logging.exception(f"run_command TypeError: {message.content=}") if to_send is None: return await self._send_s(message.channel, to_send) @staticmethod async def _send_s( channel: discord.TextChannel, to_send: Union[str, discord.Embed] ) -> None: """Safely send a text / embed message to a channel. Logs any errors that occur. Args: channel: A discord.Channel object. to_send: A String or discord.Embed object. """ try: if isinstance(to_send, discord.Embed): if embeds.embed_is_valid(to_send): await channel.send(embed=to_send) else: logging.warning("Embed is too large to send") else: await channel.send(to_send) except discord.errors.Forbidden as ex: logging.warning(f"Forbidden: {ex}") except discord.errors.HTTPException as ex: # Length/size is most likely cause, # see https://discord.com/developers/docs/resources/channel#embed-limits logging.warning(f"HTTPException: {ex}") async def send_all_subscribed( self, to_send: Union[str, discord.Embed], send_mentions: bool = False ) -> None: """Send a message to all subscribed channels. Args: to_send: A String or discord.Embed object. send_mentions: If True, get mentions from db and send as well. """ channel_ids = self.ds.get_subbed_channels() guild_opts = self.ds.get_all_guilds_options() invalid_ids = set() for channel_id in channel_ids: channel = self.get_channel(channel_id) if channel is None: invalid_ids.add(channel_id) continue await self._send_s(channel, to_send) if send_mentions: if (opts := guild_opts.get(channel.guild.id)) is not None: if (mentions := opts.get("mentions")) is not None: await self._send_s(channel, mentions) # Remove any channels from db that are picked up as invalid self.ds.remove_subbed_channels(invalid_ids)