"""Loop to execute database queries and collect metrics."""

import asyncio
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from logging import Logger
import time
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    List,
    Mapping,
    Optional,
    Set,
)

from croniter import croniter
from dateutil.tz import gettz
from prometheus_aioexporter import MetricsRegistry
from toolrack.aio import (
    PeriodicCall,
    TimedCall,
)

from .config import (
    Config,
    DB_ERRORS_METRIC_NAME,
    QUERIES_METRIC_NAME,
    QUERY_LATENCY_METRIC_NAME,
)
from .db import (
    DataBase,
    DATABASE_LABEL,
    DataBaseError,
    Query,
)


class QueryLoop:
    """Run database queries and collect metrics.

    Queries from the configuration are split into timed queries (run
    automatically once :meth:`start` has been awaited) and aperiodic ones
    (run only when :meth:`run_aperiodic_queries` is awaited).

    """

    # map each metric type to the name of the method that is called on the
    # metric object to record a new value
    _METRIC_METHODS = {
        "counter": "inc",
        "gauge": "set",
        "histogram": "observe",
        "summary": "observe",
        "enum": "state",
    }

    def __init__(
        self,
        config: Config,
        registry: MetricsRegistry,
        logger: Logger,
    ):
        self._config = config
        self._registry = registry
        self._logger = logger
        self._timed_queries: List[Query] = []
        self._aperiodic_queries: List[Query] = []
        # map query names to their TimedCalls
        self._timed_calls: Dict[str, TimedCall] = {}
        # map query names to the set of database names on which the query
        # has failed with a fatal error
        self._doomed_queries: Dict[str, Set[str]] = defaultdict(set)
        self._loop = asyncio.get_event_loop()
        self._setup()

    async def start(self):
        """Start timed queries execution.

        Connect to every configured database (a failed connection is only
        recorded in the database errors metric), then schedule each timed
        query, either at a fixed interval or following its cron schedule.

        """
        for db in self._databases:
            try:
                await db.connect()
            except DataBaseError:
                self._increment_db_error_count(db)

        for query in self._timed_queries:
            if query.interval:
                call = PeriodicCall(self._run_query, query)
                call.start(query.interval)
            else:
                call = TimedCall(self._run_query, query)
                now = datetime.now().replace(tzinfo=gettz())
                # The croniter is built here, so an invalid schedule fails
                # right away, and it is handed to the generator as an
                # argument so that each query keeps its own iterator.
                cron_iter = croniter(query.schedule, now)
                call.start(self._loop_times_iter(cron_iter))
            self._timed_calls[query.name] = call

    async def stop(self):
        """Stop timed query execution.

        Stop all scheduled calls and close all databases. Errors raised
        while stopping or closing are collected by ``gather`` and not
        propagated.

        """
        stop_coros = (call.stop() for call in self._timed_calls.values())
        await asyncio.gather(*stop_coros, return_exceptions=True)
        self._timed_calls.clear()
        close_coros = (db.close() for db in self._databases)
        await asyncio.gather(*close_coros, return_exceptions=True)

    async def run_aperiodic_queries(self):
        """Run queries on request.

        Every aperiodic query is executed once on each of its databases;
        exceptions from individual executions are not propagated.

        """
        coros = (
            self._execute_query(query, dbname)
            for query in self._aperiodic_queries
            for dbname in query.databases
        )
        await asyncio.gather(*coros, return_exceptions=True)

    @property
    def _databases(self) -> Iterable[DataBase]:
        """Return an iterable with defined Databases."""
        return self._config.databases.values()

    def _setup(self):
        """Initialize instance attributes.

        Give every database the logger and sort the configured queries
        into the timed and aperiodic lists.

        """
        for database in self._databases:
            database.set_logger(self._logger)

        for query in self._config.queries.values():
            if query.timed:
                self._timed_queries.append(query)
            else:
                self._aperiodic_queries.append(query)

    def _loop_times_iter(self, cron_iter: croniter) -> Iterator[float]:
        """Yield event-loop times for the occurrences of a cron iterator.

        ``cron_iter`` yields wall-clock timestamps; each one is converted
        into a time on the event loop clock by adding the remaining delay
        to the current loop time.

        The iterator is received as a parameter instead of being captured
        from the caller's scope: a generator body only runs when it is
        iterated, so a closure over a variable that the caller rebinds in
        a loop would make every generator use the last schedule.

        """
        while True:
            delta = next(cron_iter) - time.time()
            yield self._loop.time() + delta

    def _run_query(self, query: Query):
        """Periodic task to run a query.

        Schedule one execution task for each database of the query.

        """
        for dbname in query.databases:
            self._loop.create_task(self._execute_query(query, dbname))

    async def _execute_query(self, query: Query, dbname: str):
        """Execute a Query on a DataBase.

        Metrics are updated from the query results. If the database
        reports an error, the failure is counted and, when the error is
        fatal, the query is marked as doomed for that database so it is
        skipped on following runs.

        """
        if await self._remove_if_dooomed(query, dbname):
            return

        db = self._config.databases[dbname]
        try:
            metric_results = await db.execute(query)
        except DataBaseError as error:
            self._increment_queries_count(db, "error")
            if error.fatal:
                self._logger.debug(
                    f'removing doomed query "{query.name}" '
                    f'for database "{dbname}"'
                )
                self._doomed_queries[query.name].add(dbname)
            return

        for result in metric_results.results:
            self._update_metric(
                db, result.metric, result.value, labels=result.labels
            )
        if metric_results.latency:
            self._update_query_latency_metric(
                db, query.config_name, metric_results.latency
            )
        self._increment_queries_count(db, "success")

    async def _remove_if_dooomed(self, query: Query, dbname: str) -> bool:
        """Remove a query if it will never work.

        Return whether the query has been removed for the database.

        When the query is doomed on all of its databases it is also dropped
        from the list it belongs to, and its scheduled call (if any) is
        stopped.

        """
        doomed_dbs = self._doomed_queries[query.name]
        if dbname not in doomed_dbs:
            return False

        if set(query.databases) == doomed_dbs:
            # The query has failed on all databases. This point can be
            # reached once per database (one task each), so the removal is
            # guarded to avoid a ValueError from removing the query twice.
            if query.timed:
                if query in self._timed_queries:
                    self._timed_queries.remove(query)
                call = self._timed_calls.pop(query.name, None)
                if call is not None:
                    await call.stop()
            elif query in self._aperiodic_queries:
                self._aperiodic_queries.remove(query)
        return True

    def _update_metric(
        self,
        database: DataBase,
        name: str,
        value: Any,
        labels: Optional[Mapping[str, str]] = None,
    ):
        """Update value for a metric.

        The database name label and the database's own labels are always
        applied; entries in ``labels`` are added last and take precedence.
        A ``None`` value is recorded as ``0.0`` and ``Decimal`` values are
        converted to ``float``.

        """
        if value is None:
            # don't fail if queries that count return NULL
            value = 0.0
        elif isinstance(value, Decimal):
            value = float(value)
        method = self._METRIC_METHODS[self._config.metrics[name].type]
        all_labels = {DATABASE_LABEL: database.name}
        all_labels.update(database.labels)
        if labels:
            all_labels.update(labels)
        labels_string = ",".join(
            f'{label_name}="{label_value}"'
            for label_name, label_value in sorted(all_labels.items())
        )
        self._logger.debug(
            f'updating metric "{name}" {method} {value} {{{labels_string}}}'
        )
        metric = self._registry.get_metric(name, labels=all_labels)
        getattr(metric, method)(value)

    def _increment_queries_count(self, database: DataBase, status: str):
        """Increment count of queries in a status for a database."""
        self._update_metric(
            database, QUERIES_METRIC_NAME, 1, labels={"status": status}
        )

    def _increment_db_error_count(self, database: DataBase):
        """Increment number of errors for a database."""
        self._update_metric(database, DB_ERRORS_METRIC_NAME, 1)

    def _update_query_latency_metric(
        self, database: DataBase, query_name: str, latency: float
    ):
        """Update latency metric for a query on a database."""
        self._update_metric(
            database,
            QUERY_LATENCY_METRIC_NAME,
            latency,
            labels={"query": query_name},
        )