import asyncio
import logging
import logging.config
import signal
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple

try:
    import uvloop
except ImportError:
    uvloop = None

from src.protocols import introducer_protocol
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer, start_server
from src.types.peer_info import PeerInfo
from src.util.logging import initialize_logging
from src.util.config import load_config_cli, load_config
from src.util.setproctitle import setproctitle

from .reconnect_task import start_reconnect_task

OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]


def create_periodic_introducer_poll_task(
    server,
    peer_info,
    global_connections,
    introducer_connect_interval,
    target_peer_count,
):
    """
    Start a background task connecting periodically to the introducer and
    requesting the peer list.

    On every iteration the task first closes any open connection of type
    ``NodeType.INTRODUCER``. If fewer than ``target_peer_count`` full-node
    connections exist, it then opens a client connection to ``peer_info``
    whose on-connect handler sends a ``request_peers`` message. A failed
    connection attempt is retried after 5 seconds; otherwise the task
    sleeps ``introducer_connect_interval`` seconds between iterations.

    Returns:
        The ``asyncio.Task`` running the loop; cancel it to stop polling.
    """

    def _num_needed_peers() -> int:
        # Number of additional full-node connections wanted (never negative).
        connected = len(global_connections.get_full_node_connections())
        return max(target_peer_count - connected, 0)

    async def introducer_client():
        async def on_connect() -> OutboundMessageGenerator:
            msg = Message("request_peers", introducer_protocol.RequestPeers())
            yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)

        while True:
            # If we are still connected to the introducer, disconnect.
            # Iterate over a snapshot: closing a connection may remove it
            # from the collection returned by get_connections().
            for connection in list(global_connections.get_connections()):
                if connection.connection_type == NodeType.INTRODUCER:
                    global_connections.close(connection)
            # While peers are still needed, a failed connection attempt is
            # retried after a short 5 s pause instead of a full interval.
            if _num_needed_peers():
                if not await server.start_client(peer_info, on_connect):
                    await asyncio.sleep(5)
                    continue
            await asyncio.sleep(introducer_connect_interval)

    return asyncio.create_task(introducer_client())


class Service:
    """
    Wire up and run a single node service.

    The service owns a ``ChiaServer`` and, once started, optionally runs:
    a periodic introducer poll task, an RPC server task, one reconnect task
    per configured peer and one listener per entry in
    ``server_listen_ports``. ``stop()`` (also installed as the
    SIGINT/SIGTERM handler where supported) tears these down, and
    ``wait_closed()`` waits for everything to finish.
    """

    def __init__(
        self,
        root_path,
        api: Any,
        node_type: NodeType,
        advertised_port: int,
        service_name: str,
        server_listen_ports: Optional[List[int]] = None,
        connect_peers: Optional[List[PeerInfo]] = None,
        on_connect_callback: Optional[Callable] = None,
        rpc_start_callback_port: Optional[Tuple[Callable, int]] = None,
        start_callback: Optional[Callable] = None,
        stop_callback: Optional[Callable] = None,
        await_closed_callback: Optional[Callable] = None,
        periodic_introducer_poll: Optional[Tuple[PeerInfo, int, int]] = None,
    ):
        """
        Load configuration, set up logging and create the ``ChiaServer``.

        Args:
            root_path: Root directory passed to the config loaders, to
                ``initialize_logging`` and to ``ChiaServer``.
            api: Node API object. If it has ``set_server`` and/or
                ``_set_server``, each is called with the new server. Its
                ``_shut_down`` attribute is set to True by ``stop()``.
            node_type: Node type handed to ``ChiaServer``.
            advertised_port: Accepted for interface compatibility; not
                used by this class.
            service_name: Config section name; also used for the process
                title (``chia_<service_name>``) and the logger name.
            server_listen_ports: One ``start_server`` listener is started
                per entry. NOTE(review): the port values themselves are not
                passed to ``start_server`` — confirm it uses the server's
                configured port. Defaults to no listeners.
            connect_peers: Peers to keep reconnecting to, one reconnect
                task each. Defaults to none.
            on_connect_callback: Passed to ``start_server`` as the
                per-connection callback.
            rpc_start_callback_port: ``(rpc_f, port)``; ``rpc_f(api, stop,
                port)`` is scheduled as a task on start.
            start_callback: Awaited first when the service starts.
            stop_callback: Called synchronously from ``stop()``.
            await_closed_callback: Awaited after the server has closed.
            periodic_introducer_poll: ``(introducer_peer_info,
                connect_interval, target_peer_count)``; enables the
                introducer poll task.

        Raises:
            AssertionError: If ``config.yaml`` lacks ``ping_interval`` or
                ``network_id``.
        """
        net_config = load_config(root_path, "config.yaml")
        ping_interval = net_config.get("ping_interval")
        network_id = net_config.get("network_id")
        assert ping_interval is not None
        assert network_id is not None

        self._node_type = node_type

        proctitle_name = f"chia_{service_name}"
        setproctitle(proctitle_name)
        self._log = logging.getLogger(service_name)

        config = load_config_cli(root_path, "config.yaml", service_name)
        initialize_logging(f"{service_name:<30s}", config["logging"], root_path)

        self._rpc_start_callback_port = rpc_start_callback_port

        self._server = ChiaServer(
            config["port"],
            api,
            node_type,
            ping_interval,
            network_id,
            root_path,
            config,
        )
        # Let the API hold a reference to the server if it supports either
        # setter name.
        for setter_name in ("set_server", "_set_server"):
            setter = getattr(api, setter_name, None)
            if setter:
                setter(self._server)

        # Avoid shared mutable defaults: build fresh lists when omitted.
        self._connect_peers = connect_peers if connect_peers is not None else []
        self._server_listen_ports = (
            server_listen_ports if server_listen_ports is not None else []
        )

        self._api = api
        self._task: Optional[asyncio.Future] = None
        self._is_stopping = False

        self._periodic_introducer_poll = periodic_introducer_poll
        self._on_connect_callback = on_connect_callback
        self._start_callback = start_callback
        self._stop_callback = stop_callback
        self._await_closed_callback = await_closed_callback

        # Filled in by start(). Initialised here so that stop() and
        # wait_closed() never hit a missing attribute, even when stop() is
        # triggered (e.g. via the RPC server) before start-up has finished.
        self._introducer_poll_task: Optional[asyncio.Task] = None
        self._rpc_task: Optional[asyncio.Future] = None
        self._reconnect_tasks: List[Any] = []
        self._server_sockets: List[Any] = []

    def start(self):
        """Schedule the service's main coroutine; a no-op if already started."""
        if self._task is not None:
            return

        async def _run():
            if self._start_callback:
                await self._start_callback()

            if self._periodic_introducer_poll:
                (
                    peer_info,
                    introducer_connect_interval,
                    target_peer_count,
                ) = self._periodic_introducer_poll
                self._introducer_poll_task = create_periodic_introducer_poll_task(
                    self._server,
                    peer_info,
                    self._server.global_connections,
                    introducer_connect_interval,
                    target_peer_count,
                )

            if self._rpc_start_callback_port:
                rpc_f, rpc_port = self._rpc_start_callback_port
                self._rpc_task = asyncio.ensure_future(
                    rpc_f(self._api, self.stop, rpc_port)
                )

            self._reconnect_tasks = [
                start_reconnect_task(self._server, peer, self._log)
                for peer in self._connect_peers
            ]
            self._server_sockets = [
                await start_server(self._server, self._on_connect_callback)
                for _ in self._server_listen_ports
            ]

            if self._is_stopping:
                # stop() ran during one of the awaits above, before these
                # listeners/tasks existed; shut them down now so the waits
                # below do not block forever.
                self._close_listeners_and_tasks()

            loop = asyncio.get_running_loop()
            try:
                loop.add_signal_handler(signal.SIGINT, self.stop)
                loop.add_signal_handler(signal.SIGTERM, self.stop)
            except NotImplementedError:
                self._log.info("signal handlers unsupported")

            for server_socket in self._server_sockets:
                await server_socket.wait_closed()

            await self._server.await_closed()
            if self._await_closed_callback:
                await self._await_closed_callback()

        self._task = asyncio.ensure_future(_run())

    async def run(self):
        """Start the service, wait until it has fully closed, then return 0."""
        self.start()
        await self.wait_closed()
        self._log.info("Closed all node servers.")
        return 0

    def _close_listeners_and_tasks(self):
        # Close listeners and cancel the background tasks created by start().
        # Safe to call more than once; cancelling a finished task is a no-op.
        for server_socket in self._server_sockets:
            server_socket.close()
        for reconnect_task in self._reconnect_tasks:
            reconnect_task.cancel()
        if self._introducer_poll_task:
            self._introducer_poll_task.cancel()

    def stop(self):
        """
        Request shutdown. Only the first call has any effect.

        Closes listeners, cancels reconnect and introducer tasks, closes all
        server connections, sets ``api._shut_down`` and invokes
        ``stop_callback``. Does not wait; use ``wait_closed()`` for that.
        """
        if self._is_stopping:
            return
        self._is_stopping = True
        self._close_listeners_and_tasks()
        self._server.close_all()
        self._api._shut_down = True
        if self._stop_callback:
            self._stop_callback()

    async def wait_closed(self):
        """Wait for the main coroutine and, if one was started, the RPC task."""
        await self._task
        if self._rpc_task:
            await self._rpc_task
            self._log.info("Closed RPC server.")
        self._log.info("%s fully closed", self._node_type)


async def async_run_service(*args, **kwargs):
    """Build a ``Service`` from the given arguments and run it to completion."""
    service = Service(*args, **kwargs)
    return await service.run()


def run_service(*args, **kwargs):
    """
    Synchronous entry point: run a ``Service`` on a fresh event loop.

    Installs uvloop first when it is available and returns the value of
    ``Service.run()``.
    """
    if uvloop is not None:
        uvloop.install()
    # TODO: use asyncio.run instead
    # for now, we use `run_until_complete` as `asyncio.run` blocks on RPC server not exiting
    # Create the loop explicitly: relying on asyncio.get_event_loop() to
    # create one implicitly is deprecated and raises on newer Pythons.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    return loop.run_until_complete(async_run_service(*args, **kwargs))