import json
import os
from datetime import datetime
from typing import Dict, Iterator, List, Optional, Tuple
from uuid import UUID

import structlog
from eth_utils import to_canonical_address, to_checksum_address

from pathfinding_service.model import IOU
from pathfinding_service.model.channel import Channel
from pathfinding_service.model.feedback import FeedbackToken
from pathfinding_service.model.token_network import TokenNetwork
from pathfinding_service.typing import DeferableMessage
from raiden.messages.path_finding_service import PFSCapacityUpdate
from raiden.storage.serialization.serializer import JSONSerializer
from raiden.utils.typing import (
    Address,
    BlockNumber,
    ChainID,
    ChannelID,
    FeeAmount,
    TokenAmount,
    TokenNetworkAddress,
)
from raiden_libs.database import BaseDatabase, hex256

log = structlog.get_logger(__name__)


class PFSDatabase(BaseDatabase):
    """Store data that needs to persist between PFS restarts.

    All addresses are stored in checksum form; integers that may exceed
    SQLite's 64-bit range (token amounts, channel ids, block numbers used
    in comparisons) are stored via ``hex256`` so that lexicographic
    comparison matches numeric comparison.
    """

    schema_filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), "schema.sql")

    def __init__(
        self,
        filename: str,
        chain_id: ChainID,
        pfs_address: Address,
        sync_start_block: BlockNumber = BlockNumber(0),
        allow_create: bool = False,
        **contract_addresses: Address,
    ):
        """Open (or create, if ``allow_create``) the PFS database.

        Args:
            filename: path of the SQLite database file
            chain_id: chain this PFS operates on, persisted into the
                ``blockchain`` table by ``_setup``
            pfs_address: address of this PFS; used as IOU receiver
            sync_start_block: first block to sync from on a fresh database
            allow_create: create the database file if it does not exist
            contract_addresses: additional contract addresses forwarded
                to ``_setup``
        """
        super().__init__(filename, allow_create=allow_create)
        self.pfs_address = pfs_address

        # Keep the journal around and skip inode updates.
        # References:
        # https://sqlite.org/atomiccommit.html#_persistent_rollback_journals
        # https://sqlite.org/pragma.html#pragma_journal_mode
        self.conn.execute("PRAGMA journal_mode=PERSIST")

        self._setup(
            chain_id=chain_id,
            receiver=pfs_address,
            sync_start_block=sync_start_block,
            **contract_addresses,
        )

    def upsert_capacity_update(self, message: PFSCapacityUpdate) -> None:
        """Insert or update the stored capacities for one channel participant."""
        capacity_update_dict = dict(
            updating_participant=to_checksum_address(message.updating_participant),
            token_network_address=to_checksum_address(
                message.canonical_identifier.token_network_address
            ),
            channel_id=hex256(message.canonical_identifier.channel_identifier),
            updating_capacity=hex256(message.updating_capacity),
            other_capacity=hex256(message.other_capacity),
        )
        self.upsert("capacity_update", capacity_update_dict)

    def get_capacity_updates(
        self,
        updating_participant: Address,
        token_network_address: TokenNetworkAddress,
        channel_id: int,
    ) -> Tuple[TokenAmount, TokenAmount]:
        """Return ``(updating_capacity, other_capacity)`` for the given
        participant/channel, or ``(0, 0)`` when no update is stored."""
        capacity_list = self.conn.execute(
            """
            SELECT updating_capacity, other_capacity
            FROM capacity_update WHERE updating_participant=?
            AND token_network_address=?
            AND channel_id=?
        """,
            [
                to_checksum_address(updating_participant),
                to_checksum_address(token_network_address),
                hex256(channel_id),
            ],
        )
        try:
            return next(capacity_list)
        except StopIteration:
            return TokenAmount(0), TokenAmount(0)

    def get_latest_committed_block(self) -> BlockNumber:
        """Return the highest block number that has been fully processed."""
        return self.conn.execute("SELECT latest_committed_block FROM blockchain").fetchone()[0]

    def update_latest_committed_block(self, latest_committed_block: BlockNumber) -> None:
        """Persist the latest fully processed block number."""
        log.info("Updating latest_committed_block", latest_committed_block=latest_committed_block)
        self.conn.execute(
            "UPDATE blockchain SET latest_committed_block = ?", [latest_committed_block]
        )

    # Backward-compatible alias: the method was historically misspelled
    # ("lastest"); keep the old name working for existing callers.
    update_lastest_committed_block = update_latest_committed_block

    def upsert_iou(self, iou: IOU) -> None:
        """Insert or update an IOU.

        ``receiver`` and ``chain_id`` are not stored per row — they are
        constant for this PFS and reattached on load in ``get_ious``.
        """
        iou_dict = IOU.Schema(exclude=["receiver", "chain_id"]).dump(iou)
        iou_dict["one_to_n_address"] = to_checksum_address(iou_dict["one_to_n_address"])
        for key in ("amount", "expiration_block"):
            # hex256 keeps lexicographic ordering consistent with numeric
            # ordering for the range/comparison queries in get_ious.
            iou_dict[key] = hex256(int(iou_dict[key]))
        self.upsert("iou", iou_dict)

    def get_ious(
        self,
        sender: Optional[Address] = None,
        expiration_block: Optional[BlockNumber] = None,
        claimed: Optional[bool] = None,
        expires_after: Optional[BlockNumber] = None,
        expires_before: Optional[BlockNumber] = None,
        amount_at_least: Optional[TokenAmount] = None,
    ) -> Iterator[IOU]:
        """Yield stored IOUs matching all given (optional) filters.

        Args:
            sender: only IOUs from this sender
            expiration_block: only IOUs expiring exactly at this block
            claimed: only (un)claimed IOUs
            expires_after: only IOUs with expiration_block strictly greater
            expires_before: only IOUs with expiration_block strictly smaller
            amount_at_least: only IOUs with at least this amount
        """
        query = """
            SELECT *, (SELECT chain_id FROM blockchain) AS chain_id
            FROM iou
            WHERE 1=1
        """
        args: list = []
        # The WHERE clause is extended with fixed SQL fragments only; all
        # user-supplied values go through bound parameters.
        if sender is not None:
            query += " AND sender = ?"
            args.append(to_checksum_address(sender))
        if expiration_block is not None:
            query += " AND expiration_block = ?"
            args.append(hex256(expiration_block))
        if claimed is not None:
            query += " AND claimed = ?"
            args.append(claimed)
        if expires_before is not None:
            query += " AND expiration_block < ?"
            args.append(hex256(expires_before))
        if expires_after is not None:
            query += " AND expiration_block > ?"
            args.append(hex256(expires_after))
        if amount_at_least is not None:
            query += " AND amount >= ?"
            args.append(hex256(amount_at_least))

        for row in self.conn.execute(query, args):
            iou_dict = dict(zip(row.keys(), row))
            # receiver is not stored per row (see upsert_iou); it is always
            # this PFS.
            iou_dict["receiver"] = to_checksum_address(self.pfs_address)
            yield IOU.Schema().load(iou_dict)

    def get_iou(
        self,
        sender: Address,
        expiration_block: Optional[BlockNumber] = None,
        claimed: Optional[bool] = None,
    ) -> Optional[IOU]:
        """Return the first IOU matching the filters, or ``None``."""
        try:
            return next(self.get_ious(sender, expiration_block, claimed))
        except StopIteration:
            return None

    def upsert_channel(self, channel: Channel) -> None:
        """Insert or update a channel row, serializing fee schedules as JSON."""
        channel_dict = Channel.Schema().dump(channel)
        for key in (
            "channel_id",
            "settle_timeout",
            "capacity1",
            "reveal_timeout1",
            "update_nonce1",
            "capacity2",
            "reveal_timeout2",
            "update_nonce2",
        ):
            channel_dict[key] = hex256(int(channel_dict[key]))
        channel_dict["fee_schedule1"] = json.dumps(channel_dict["fee_schedule1"])
        channel_dict["fee_schedule2"] = json.dumps(channel_dict["fee_schedule2"])
        self.upsert("channel", channel_dict)

    def get_channels(self) -> Iterator[Channel]:
        """Yield all stored channels, deserializing their fee schedules."""
        for row in self.conn.execute("SELECT * FROM channel"):
            channel_dict = dict(zip(row.keys(), row))
            channel_dict["fee_schedule1"] = json.loads(channel_dict["fee_schedule1"])
            channel_dict["fee_schedule2"] = json.loads(channel_dict["fee_schedule2"])
            yield Channel.Schema().load(channel_dict)

    def delete_channel(
        self, token_network_address: TokenNetworkAddress, channel_id: ChannelID
    ) -> bool:
        """Tries to delete a channel from the database

        Args:
            token_network_address: The address of the token network of the channel
            channel_id: The id of the channel

        Returns:
            `True` if the channel was deleted, `False` if it did not exist
        """
        cursor = self.conn.execute(
            "DELETE FROM channel WHERE token_network_address = ? AND channel_id = ?",
            [to_checksum_address(token_network_address), hex256(channel_id)],
        )
        # (token_network_address, channel_id) must be unique; more than one
        # deleted row would indicate schema corruption.
        assert cursor.rowcount <= 1, "Did delete more than one channel"
        return cursor.rowcount == 1

    def get_token_networks(self) -> Iterator[TokenNetwork]:
        """Yield a fresh ``TokenNetwork`` object for every known network address."""
        for row in self.conn.execute("SELECT address FROM token_network"):
            yield TokenNetwork(
                token_network_address=TokenNetworkAddress(to_canonical_address(row[0]))
            )

    def prepare_feedback(
        self, token: FeedbackToken, route: List[Address], estimated_fee: FeeAmount
    ) -> None:
        """Store a feedback token for a route handed out to a client."""
        hexed_route = [to_checksum_address(e) for e in route]
        token_dict = dict(
            token_id=token.uuid.hex,
            creation_time=token.creation_time,
            token_network_address=to_checksum_address(token.token_network_address),
            route=json.dumps(hexed_route),
            estimated_fee=hex256(estimated_fee),
            source_address=hexed_route[0],
            target_address=hexed_route[-1],
        )
        self.insert("feedback", token_dict)

    def update_feedback(self, token: FeedbackToken, route: List[Address], successful: bool) -> int:
        """Record client feedback for a previously handed-out route.

        Only rows without existing feedback (``successful IS NULL``) are
        updated, so feedback cannot be overwritten.

        Returns:
            number of updated rows (0 or 1)
        """
        hexed_route = [to_checksum_address(e) for e in route]
        token_dict = dict(
            token_id=token.uuid.hex,
            token_network_address=to_checksum_address(token.token_network_address),
            route=json.dumps(hexed_route),
            successful=successful,
            feedback_time=datetime.utcnow(),
        )
        updated_rows = self.conn.execute(
            """
            UPDATE feedback
            SET
                successful = :successful,
                feedback_time = :feedback_time
            WHERE
                token_id = :token_id AND
                token_network_address = :token_network_address AND
                route = :route AND
                successful IS NULL;
        """,
            token_dict,
        ).rowcount

        return updated_rows

    def get_feedback_routes(
        self,
        token_network_address: TokenNetworkAddress,
        source_address: Address,
        target_address: Optional[Address] = None,
    ) -> Iterator[Dict]:
        """Yield stored feedback rows for routes from ``source_address``,
        optionally restricted to a specific target."""
        filters = {
            "token_network_address": to_checksum_address(token_network_address),
            "source_address": to_checksum_address(source_address),
        }

        where_clause = ""
        if target_address:
            where_clause = " AND target_address = :target_address"
            filters["target_address"] = to_checksum_address(target_address)

        # where_clause is a fixed fragment; target_address is bound as a
        # named parameter, so this f-string is injection-safe.
        sql = f"""
            SELECT source_address, target_address, route, estimated_fee, token_id
            FROM feedback
            WHERE token_network_address = :token_network_address
              AND source_address = :source_address
              {where_clause}
        """

        for row in self.conn.execute(sql, filters):
            route = dict(zip(row.keys(), row))
            route["route"] = json.loads(route["route"])
            yield route

    def get_feedback_token(
        self, token_id: UUID, token_network_address: TokenNetworkAddress, route: List[Address]
    ) -> Optional[FeedbackToken]:
        """Look up the feedback token for an exact (id, network, route)
        combination; return ``None`` when unknown."""
        hexed_route = [to_checksum_address(e) for e in route]
        token = self.conn.execute(
            """SELECT * FROM feedback WHERE
                token_id = ? AND
                token_network_address = ? AND
                route = ?;
            """,
            [token_id.hex, to_checksum_address(token_network_address), json.dumps(hexed_route)],
        ).fetchone()

        if token:
            return FeedbackToken(
                token_network_address=TokenNetworkAddress(
                    to_canonical_address(token["token_network_address"])
                ),
                uuid=UUID(token["token_id"]),
                creation_time=token["creation_time"],
            )

        return None

    def get_num_routes_feedback(
        self, only_with_feedback: bool = False, only_successful: bool = False
    ) -> int:
        """Count feedback rows.

        Args:
            only_with_feedback: count only rows where feedback was given
            only_successful: count only rows with positive feedback
                (ignored when ``only_with_feedback`` is set)
        """
        where_clause = ""
        if only_with_feedback:
            where_clause = "WHERE successful IS NOT NULL"
        elif only_successful:
            where_clause = "WHERE successful"

        # where_clause is one of three fixed literals — no injection risk.
        return self.conn.execute(f"SELECT COUNT(*) FROM feedback {where_clause};").fetchone()[0]

    def insert_waiting_message(self, message: DeferableMessage) -> None:
        """Persist a message that cannot be processed yet, keyed by channel."""
        self.insert(
            "waiting_message",
            dict(
                token_network_address=to_checksum_address(
                    message.canonical_identifier.token_network_address
                ),
                channel_id=hex256(message.canonical_identifier.channel_identifier),
                message=JSONSerializer.serialize(message),
            ),
        )

    def pop_waiting_messages(
        self, token_network_address: TokenNetworkAddress, channel_id: ChannelID
    ) -> Iterator[DeferableMessage]:
        """Return all waiting messages for the given channel and delete them from the db

        NOTE(review): this is a generator — the DELETE only runs after the
        caller has fully exhausted it; abandoning the iterator early leaves
        the messages in the database.
        """
        # Return messages
        for row in self.conn.execute(
            """
            SELECT message FROM waiting_message
            WHERE token_network_address = ? AND channel_id = ?
        """,
            [to_checksum_address(token_network_address), hex256(channel_id)],
        ):
            yield JSONSerializer.deserialize(row["message"])

        # Delete returned messages
        self.conn.execute(
            "DELETE FROM waiting_message WHERE token_network_address = ? AND channel_id = ?",
            [to_checksum_address(token_network_address), hex256(channel_id)],
        )