# Copyright 2014-2020 Chris Cummins <chrisc.101@gmail.com>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility code for working with sqlalchemy."""
import contextlib
import os
import pathlib
import queue
import sqlite3
import sys
import threading
import time
import typing
from typing import Callable
from typing import List
from typing import Optional

import sqlalchemy as sql
from absl import flags as absl_flags
from sqlalchemy import func
from sqlalchemy import orm
from sqlalchemy.dialects import mysql
from sqlalchemy.ext import declarative

from labm8.py import humanize
from labm8.py import labdate
from labm8.py import pbutil
from labm8.py import progress
from labm8.py import text
from labm8.py.internal import labm8_logging as logging

FLAGS = absl_flags.FLAGS

# Command-line flags controlling engine construction in CreateEngine() and URL
# resolution in ResolveUrl().
absl_flags.DEFINE_boolean(
  "sqlutil_echo",
  False,
  "If True, the Engine will log all statements as well as a repr() of their "
  "parameter lists to the engines logger, which defaults to sys.stdout.",
)
absl_flags.DEFINE_boolean(
  "sqlutil_pool_pre_ping",
  True,
  "Enable pessimistic pre-ping to check that database connections are "
  "alive. This adds some overhead, but reduces the risk of "
  '"server has gone away" errors. See:'
  "<https://docs.sqlalchemy.org/en/13/core/pooling.html#disconnect-handling-pessimistic>",
)
absl_flags.DEFINE_integer(
  "mysql_engine_pool_size",
  5,
  "The number of connections to keep open inside the connection pool. A "
  "--mysql_engine_pool_size of 0 indicates no limit",
)
absl_flags.DEFINE_integer(
  "mysql_engine_max_overflow",
  10,
  "The number of connections to allow in connection pool “overflow”, that "
  "is connections that can be opened above and beyond the "
  "--mysql_engine_pool_size setting",
)
absl_flags.DEFINE_boolean(
  "mysql_assume_utf8_charset",
  True,
  "Default to adding the '?charset=utf8' suffix to MySQL database URLs.",
)
absl_flags.DEFINE_boolean(
  "sqlite_enable_foreign_keys",
  True,
  "Enable foreign key support for SQLite. This enforces foreign key "
  "constraints, and enables cascaded update/delete statements. See: "
  "https://docs.sqlalchemy.org/en/13/dialects/sqlite.html#foreign-key-support",
)

# The Query type is returned by Session.query(). This is a convenience for type
# annotations.
Query = orm.query.Query


class DatabaseNotFound(FileNotFoundError):
  """An error that is raised if the requested database cannot be found."""

  def __init__(self, url: str) -> None:
    # The URL of the database that could not be found. Exposed read-only
    # through the `url` property below.
    self._url = url

  @property
  def url(self) -> str:
    """Return the URL of the database that could not be found."""
    return self._url

  def __repr__(self) -> str:
    return f"Database not found: '{self.url}'"

  def __str__(self) -> str:
    return repr(self)


def Base(*args, **kwargs) -> sql.ext.declarative.DeclarativeMeta:
  """Construct a base class for declarative class definitions."""
  return sql.ext.declarative.declarative_base(*args, **kwargs)


def GetOrAdd(
  session: sql.orm.session.Session,
  model,
  defaults: typing.Dict[str, object] = None,
  **kwargs,
):
  """Instantiate a mapped database object. If the object is not in the database,
  add it. Note that no change is written to disk until commit() is called on the
  session.

  Args:
    session: The database session.
    model: The database table class.
    defaults: Default values for mapped objects.
    kwargs: The values for the table row.

  Returns:
    An instance of the model class, with the values specified.
""" instance = session.query(model).filter_by(**kwargs).first() if not instance: params = { k: v for k, v in kwargs.items() if not isinstance(v, sql.sql.expression.ClauseElement) } params.update(defaults or {}) instance = model(**params) session.add(instance) logging.Log( logging.GetCallingModuleName(), 5, "New record: %s(%s)", model.__name__, params, ) return instance def Get( session: sql.orm.session.Session, model, defaults: typing.Dict[str, object] = None, **kwargs, ): """Determine if a database object exists. Args: session: The database session. model: The database table class. defaults: Default values for mapped objects. kwargs: The values for the table row. Returns: An instance of the model class with the values specified, or None if the object is not in the database. """ del defaults return session.query(model).filter_by(**kwargs).first() def CreateEngine(url: str, must_exist: bool = False) -> sql.engine.Engine: """Create an sqlalchemy database engine. This is a convenience wrapper for creating an sqlalchemy engine, that also creates the database if required, and checks that the database exists. This means that it is less flexible than SqlAlchemy's create_engine() - only three combination of dialects and drivers are supported: sqlite, mysql, and postgresql. See https://docs.sqlalchemy.org/en/latest/core/engines.html for details. Additionally, this implements a custom 'file://' handler, which reads a URL from a local file, and returns a connection to the database addressed by the URL. Use this if you would like to keep sensitive information such as a MySQL database password out of your .bash_history. 
Examples: Create in-memory SQLite database: >>> engine = CreateEngine('sqlite://') Connect to an SQLite database at relative.db: >>> engine = CreateEngine('sqlite:///relative.db') Connect to an SQLite database at /absolute/path/to/db: >>> engine = CreateEngine('sqlite:////absolute/path/to/db') Connect to MySQL database: >>> engine = CreateEngine( 'mysql://bob:password@localhost:1234/database?charset=utf8') Connect to PostgreSQL database: >>> engine.CreateEngine( 'postgresql://bob:password@localhost:1234/database') Connect to a URL specified in the file /tmp/url.txt: >>> engine.CreateEngine('file:///tmp/url.txt') Connect to a URL specified in the file /tmp/url.txt, with the suffix '/database?charset=utf8': >>> engine.CreateEngine('file:///tmp/url.txt?/database?charset=utf8') Args: url: The URL of the database to connect to. must_exist: If True, raise DatabaseNotFound if it doesn't exist. Else, database is created if it doesn't exist. Returns: An SQLalchemy Engine instance. Raises: DatabaseNotFound: If the database does not exist and must_exist is set. ValueError: If the datastore backend is not supported. """ engine_args = {} # Read and expand a `file://` prefixed URL. url = ResolveUrl(url) if url.startswith("mysql://"): # Support for MySQL dialect. # We create a throwaway engine that we use to check if the requested # database exists. engine = sql.create_engine("/".join(url.split("/")[:-1])) database = url.split("/")[-1].split("?")[0] query = engine.execute( sql.text( "SELECT SCHEMA_NAME FROM " "INFORMATION_SCHEMA.SCHEMATA WHERE " "SCHEMA_NAME = :database", ), database=database, ) # Engine-specific options. engine_args["pool_size"] = FLAGS.mysql_engine_pool_size engine_args["max_overflow"] = FLAGS.mysql_engine_max_overflow if not query.first(): if must_exist: raise DatabaseNotFound(url) else: # We can't use sql.text() escaping here because it uses single quotes # for escaping. MySQL only accepts backticks for quoting database # names. 
engine.execute(f"CREATE DATABASE `{database}`") engine.dispose() elif url.startswith("sqlite://"): # Support for SQLite dialect. # This project (phd) deliberately disallows relative paths due to Bazel # sandboxing. if url != "sqlite://" and not url.startswith("sqlite:////"): raise ValueError("Relative path to SQLite database is not allowed") if url == "sqlite://": if must_exist: raise ValueError( "must_exist=True not valid for in-memory SQLite database", ) else: path = pathlib.Path(url[len("sqlite:///") :]) if must_exist: if not path.is_file(): raise DatabaseNotFound(url) else: # Make the parent directory for SQLite database if creating a new # database. path.parent.mkdir(parents=True, exist_ok=True) elif url.startswith("postgresql://"): # Support for PostgreSQL dialect. engine = sql.create_engine("/".join(url.split("/")[:-1] + ["postgres"])) conn = engine.connect() database = url.split("/")[-1] query = conn.execute( sql.text("SELECT 1 FROM pg_database WHERE datname = :database"), database=database, ) if not query.first(): if must_exist: raise DatabaseNotFound(url) else: # PostgreSQL does not let you create databases within a transaction, so # manually complete the transaction before creating the database. conn.execute(sql.text("COMMIT")) # PostgreSQL does not allow single quoting of database names. conn.execute(f"CREATE DATABASE {database}") conn.close() engine.dispose() else: raise ValueError(f"Unsupported database URL='{url}'") # Create the engine. engine = sql.create_engine( url, encoding="utf-8", echo=FLAGS.sqlutil_echo, pool_pre_ping=FLAGS.sqlutil_pool_pre_ping, **engine_args, ) # Create and immediately close a connection. This is because SQLAlchemy engine # is lazily instantiated, so for connections such as SQLite, this line # actually creates the file. 
engine.connect().close() return engine @sql.event.listens_for(sql.engine.Engine, "connect") def EnableSqliteForeignKeysCallback(dbapi_connection, connection_record): """Enable foreign key constraints for SQLite databases. See --sqlite_enable_foreign_keys for details. """ del connection_record # This callback listens for *all* database connections, not just SQLite. Check # the type before trying to run an SQLite-specific pragma. if FLAGS.sqlite_enable_foreign_keys and isinstance( dbapi_connection, sqlite3.Connection ): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() def ResolveUrl(url: str, use_flags: bool = True): """Resolve the URL of a database. The following modifications are supported: * If the url begins with 'file://', the URL is substituted with the contents of the file. * If --mysql_assume_utf8_charset is set, then '?charset=utf8' suffix is appended to URLs which begin with mysql://. * Shell variables are expanded. Args: url: The URL to expand, e.g. `file://path/to/file.txt?arg' use_flags: Determine whether behaviour is dictated by the FLAGS variables. Set this to False only when resolving database URLs before flags parsing, e.g. in enumerating test fixtures. Returns: The URL as interpreted by reading any URL file. Raises: ValueError: If the file path is invalid. FileNotFoundError: IF the file path does not exist. """ # Substitute shell variables. url = os.path.expandvars(url) if url.startswith("file://"): # Split the URL into the file path, and the optional suffix. components = url.split("?") path, suffix = components[0], "?".join(components[1:]) # Strip the file:// prefix from the path. path = pathlib.Path(path[len("file://") :]) if not path.is_absolute(): raise ValueError("Relative path to file:// is not allowed") if not path.is_file(): raise FileNotFoundError(f"File '{path}' not found") # Read the contents of the file, ignoring lines starting with '#'. 
    with open(path) as f:
      url = "\n".join(
        x for x in f.read().split("\n") if not x.lstrip().startswith("#")
      ).strip()

    # Append the suffix.
    url += suffix

  if (
    use_flags
    and url.startswith("mysql://")
    and FLAGS.mysql_assume_utf8_charset
  ):
    url += "?charset=utf8"

  return url


def ColumnNames(model) -> typing.List[str]:
  """Return the names of all columns in a mapped object.

  Args:
    model: A mapped class.

  Returns:
    A list of string column names in the order that they are declared.

  Raises:
    TypeError: If the argument is not a mapped class.
  """
  try:
    inst = sql.inspect(model)
    return [c_attr.key for c_attr in inst.mapper.column_attrs]
  except sql.exc.NoInspectionAvailable as e:
    # Re-raise as a TypeError so callers need not depend on sqlalchemy's
    # exception hierarchy.
    raise TypeError(str(e))


class Session(orm.session.Session):
  """A subclass of the default SQLAlchemy Session with added functionality.

  An instance of this class is returned by Database.Session().
  """

  def GetOrAdd(
    self, model, defaults: typing.Dict[str, object] = None, **kwargs
  ):
    """Instantiate a mapped database object.

    If the object is not in the database, add it. Note that no change is
    written to disk until commit() is called on the session.

    Args:
      model: The database table class.
      defaults: Default values for mapped objects.
      kwargs: The values for the table row.

    Returns:
      An instance of the model class, with the values specified.
    """
    # Delegate to the module-level helper, using this session.
    return GetOrAdd(self, model, defaults, **kwargs)


class Database(object):
  """A base class for implementing databases."""

  # The session class returned by Session(); subclasses may override.
  SessionType = Session

  def __init__(self, url: str, declarative_base, must_exist: bool = False):
    """Instantiate a database object.

    Example:
      >>> db = Database('sqlite:////tmp/foo.db',
                        sqlalchemy.ext.declarative.declarative_base())

    Args:
      url: The URL of the database to connect to.
      declarative_base: The SQLAlchemy declarative base instance.
      must_exist: If True, raise DatabaseNotFound if it doesn't exist. Else,
        database is created if it doesn't exist.

    Raises:
      DatabaseNotFound: If the database does not exist and must_exist is set.
      ValueError: If the datastore backend is not supported.
    """
    self._url = url
    self.engine = CreateEngine(url, must_exist=must_exist)
    # Create any missing tables declared on the base, and bind the metadata
    # to this engine.
    declarative_base.metadata.create_all(self.engine)
    declarative_base.metadata.bind = self.engine

    # Bind the Engine to a session maker, which instantiates our own Session
    # class, which is a subclass of the default SQLAlchemy Session with added
    # functionality.
    self.MakeSession = orm.sessionmaker(bind=self.engine, class_=Session)

  def Close(self) -> None:
    """Close the connection to the database.

    Use this to free up the connection to a database, while keeping the
    database instance around. After calling this method, attempting to run
    operations on this database will raise an error (like a
    sqlalchemy.exc.OperationalError).

    Usage of this method is generally discouraged - connections are
    automatically closed up when a database instance is garbage collected, so
    there are rarely cases for leaving a database instance around with the
    connection closed. Use at your peril!
    """
    self.engine.dispose()

  def Drop(self, are_you_sure_about_this_flag: bool = False):
    """Drop the database, irreversibly destroying it.

    Be careful with this! After calling this method on a Database instance, no
    further operations can be made on it, and any Sessions should be discarded.

    Args:
      are_you_sure_about_this_flag: You should be sure.

    Raises:
      ValueError: In case you're not 100% sure.
    """
    if not are_you_sure_about_this_flag:
      raise ValueError("Let's take a minute to think things over")

    if self.url.startswith("mysql://"):
      engine = sql.create_engine("/".join(self.url.split("/")[:-1]))
      database = self.url.split("/")[-1].split("?")[0]
      logging.Log(logging.GetCallingModuleName(), 1, "database %s", database)
      engine.execute(f"DROP DATABASE IF EXISTS `{database}`")
    elif self.url == "sqlite://":
      # In-memory databases do not require dropping.
      pass
    elif self.url.startswith("sqlite:///"):
      path = pathlib.Path(self.url[len("sqlite:///") :])
      # NOTE(review): `assert` is stripped when running under -O; consider
      # raising an explicit error if the file is missing.
      assert path.is_file()
      path.unlink()
    else:
      raise NotImplementedError(
        f"Unsupported operation DROP for database: '{self.url}'",
      )

  @property
  def url(self) -> str:
    """Return the URL of the database."""
    return self._url

  @contextlib.contextmanager
  def Session(
    self, commit: bool = False, session: Optional[Session] = None
  ) -> Session:
    """Provide a transactional scope around a session.

    The optional session argument may be used for cases where you want to
    optionally re-use an existing session, rather than always creating a new
    session, e.g.:

      class MyDatabase(sqlutil.Database):
        def DoAThing(self, session=None):
          with self.Session(session=session, commit=True):
            # go nuts ...

    Args:
      commit: If true, commit session at the end of scope.
      session: An existing session object to re-use.

    Returns:
      A database session.
    """
    session = session or self.MakeSession()
    try:
      yield session
      if commit:
        session.commit()
    except:
      # Bare except is deliberate: roll back before re-raising for *any*
      # exception, including BaseExceptions such as KeyboardInterrupt.
      session.rollback()
      raise
    finally:
      session.close()

  @property
  def Random(self):
    """Get the backend-specific random function.

    This can be used to select a random row from a table, e.g.
        session.query(Table).order_by(db.Random()).first()
    """
    if self.url.startswith("mysql"):
      return func.rand
    else:
      return func.random  # for PostgreSQL, SQLite

  def __repr__(self) -> str:
    return self.url


class TablenameFromClassNameMixin(object):
  """A class mixin which derives __tablename__ from the class name.

  Add this mixin to a mapped table class to automatically set the set the
  __tablename__ property of a class to the lowercase name of the Python class.
  """

  @declarative.declared_attr
  def __tablename__(self):
    return self.__name__.lower()


class TablenameFromCamelCapsClassNameMixin(object):
  """A class mixin which derives __tablename__ from the class name.
Add this mixin to a mapped table class to automatically set the set the __tablename__ property of a class to the name of the Python class with camel caps converted to underscores, e.g. class FooBar -> table "foo_bar". """ @declarative.declared_attr def __tablename__(self): return text.CamelCapsToUnderscoreSeparated(self.__name__) class PluralTablenameFromCamelCapsClassNameMixin(object): """A class mixin which derives __tablename__ from the class name. Add this mixin to a mapped table class to automatically set the set the __tablename__ property of a class to the pluralized name of the Python class with camel caps converted to underscores, e.g. class FooBar -> table "foo_bars". """ @declarative.declared_attr def __tablename__(self): pluralised = humanize.Plural(2, self.__name__) pluralised = " ".join(pluralised.split()[1:]) return text.CamelCapsToUnderscoreSeparated(pluralised) class ProtoBackedMixin(object): """A database table backed by protocol buffers. This class provides the abstract interface for sqlalchemy table classes which support serialization to and from protocol buffers. This is only an interface - inheriting classes must still inherit from sqlalchemy.ext.declarative.declarative_base(). """ proto_t = None def SetProto(self, proto: pbutil.ProtocolBuffer) -> None: """Set the fields of a protocol buffer with the values from the instance. Args: proto: A protocol buffer. """ raise NotImplementedError( f"{type(self).__name__}.SetProto() not implemented", ) def ToProto(self) -> pbutil.ProtocolBuffer: """Serialize the instance to protocol buffer. Returns: A protocol buffer. """ proto = self.proto_t() self.SetProto(proto) return proto @classmethod def FromProto( cls, proto: pbutil.ProtocolBuffer, ) -> typing.Dict[str, typing.Any]: """Return a dictionary of instance constructor args from proto. 
Examples: Construct a table instance from proto: >>> table = Table(**Table.FromProto(proto)) Construct a table instance and add to session: >>> session.GetOrAdd(Table, **Table.FromProto(proto)) Args: proto: A protocol buffer. Returns: A dictionary of constructor arguments. """ raise NotImplementedError( f"{type(self).__name__}.FromProto() not implemented", ) @classmethod def FromFile(cls, path: pathlib.Path) -> typing.Dict[str, typing.Any]: """Return a dictionary of instance constructor args from proto file. Examples: Construct a table instance from proto file: >>> table = Table(**Table.FromFile(path)) Construct a table instance and add to session: >>> session.GetOrAdd(Table, **Table.FromFile(path)) Args: path: Path to a proto file. Returns: An instance. """ proto = pbutil.FromFile(path, cls.proto_t()) return cls.FromProto(proto) class OffsetLimitQueryResultsBatch(typing.NamedTuple): """The results of an offset-limit batched query.""" # The current batch number. batch_num: int # Offset into the results set. offset: int # Limit is the last row in the results set. limit: int # The total number of rows in the query if compute_max_rows=True, else None. max_rows: int # The results of the query. rows: typing.List[typing.Any] def OffsetLimitBatchedQuery( query: Query, batch_size: int = 1000, start_at: int = 0, compute_max_rows: bool = False, ) -> typing.Iterator[OffsetLimitQueryResultsBatch]: """Split and return the rows resulting from a query in to batches. This iteratively runs the query `SELECT * FROM * OFFSET i LIMIT batch_size;` with `i` initialized to `start_at` and increasing by `batch_size` per iteration. Iteration terminates when the query returns no rows. This function is useful for returning row sets from enormous tables, where loading the full query results in to memory would take prohibitive time or resources. Args: query: The query to run. batch_size: The number of rows to return per batch. start_at: The initial offset into the table. 
    compute_max_rows: If True, run an additional COUNT query up front and
      report the total number of rows in the `max_rows` field of every batch.
      Else, `max_rows` is None.

  Returns:
    A generator of OffsetLimitQueryResultsBatch tuples, where each tuple
    contains between 1 <= x <= `batch_size` rows.
  """
  max_rows = None
  if compute_max_rows:
    max_rows = query.count()

  batch_num = 0
  i = start_at
  while True:
    batch_num += 1
    batch = query.offset(i).limit(batch_size).all()
    if batch:
      yield OffsetLimitQueryResultsBatch(
        batch_num=batch_num,
        offset=i,
        limit=i + batch_size,
        max_rows=max_rows,
        rows=batch,
      )
      # Advance by the number of rows actually returned; the final batch may
      # be shorter than batch_size.
      i += len(batch)
    else:
      break


class ColumnTypes(object):
  """Abstract class containing methods for generating column types."""

  def __init__(self):
    raise TypeError("abstract class")

  @staticmethod
  def BinaryArray(length: int):
    """Return a fixed size binary array column type.

    Args:
      length: The length of the column.

    Returns:
      A column type.
    """
    # NOTE(review): sql.Binary is deprecated in newer SQLAlchemy releases in
    # favour of sql.LargeBinary — confirm before upgrading the dependency.
    return sql.Binary(length).with_variant(mysql.BINARY(length), "mysql")

  @staticmethod
  def LargeBinary():
    """Return a large binary column type (up to 2^31 bytes on MySQL).

    Returns:
      A column type.
    """
    return sql.LargeBinary().with_variant(sql.LargeBinary(2 ** 31), "mysql")

  @staticmethod
  def UnboundedUnicodeText():
    """Return an unbounded unicode text column type.

    This isn't truly unbounded, but 2^31 chars should be enough!

    Returns:
      A column type.
    """
    return sql.UnicodeText().with_variant(sql.UnicodeText(2 ** 31), "mysql")

  @staticmethod
  def IndexableString(length: int = None):
    """Return a string that is short enough that it can be used as an index.

    Args:
      length: An optional requested length; must not exceed the MySQL InnoDB
        index key prefix limit.

    Returns:
      A column type.

    Raises:
      ValueError: If the requested length exceeds the maximum allowed.
    """
    # MySQL InnoDB tables use a default index key prefix length limit of 767.
    # https://dev.mysql.com/doc/refman/5.6/en/innodb-restrictions.html
    MAX_LENGTH = 767
    if length and length > MAX_LENGTH:
      raise ValueError(
        f"IndexableString requested length {length} is greater "
        f"than maximum allowed {MAX_LENGTH}",
      )
    return sql.String(MAX_LENGTH)

  @staticmethod
  def MillisecondDatetime():
    """Return a datetime type with millisecond precision.

    Returns:
      A column type.
    """
    return sql.DateTime().with_variant(mysql.DATETIME(fsp=3), "mysql")


class ColumnFactory(object):
  """Abstract class containing methods for generating columns."""

  @staticmethod
  def MillisecondDatetime(
    nullable: bool = False,
    default=labdate.GetUtcMillisecondsNow,
  ):
    """Return a datetime column with millisecond precision.

    Args:
      nullable: If True, the column may be NULL.
      default: A callable producing the default value; defaults to UTC now.

    Returns:
      A column which defaults to UTC now.
    """
    return sql.Column(
      sql.DateTime().with_variant(mysql.DATETIME(fsp=3), "mysql",),
      nullable=nullable,
      default=default,
    )


def ResilientAddManyAndCommit(db: Database, mapped: typing.Iterable[Base]):
  """Attempt to commit all mapped objects and return those that fail.

  This method creates a session and commits the given mapped objects.
  In case of error, this method will recurse up to O(log(n)) times, committing
  as many objects that can be as possible.

  Args:
    db: The database to add the objects to.
    mapped: A sequence of objects to commit.

  Returns:
    Any items in `mapped` which could not be committed, if any. Relative order
    of items is preserved.
  """
  failures = []

  if not mapped:
    return failures

  mapped = list(mapped)
  try:
    with db.Session(commit=True) as session:
      session.add_all(mapped)
  except sql.exc.SQLAlchemyError as e:
    logging.Log(
      logging.GetCallingModuleName(),
      1,
      "Caught error while committing %d mapped objects: %s",
      len(mapped),
      e,
    )

    # Divide and conquer. If we're committing only a single object, then a
    # failure to commit it means that we can do nothing other than return it.
    # Else, divide the mapped objects in half and attempt to commit as many of
    # them as possible.
    if len(mapped) == 1:
      return mapped
    else:
      mid = int(len(mapped) / 2)
      left = mapped[:mid]
      right = mapped[mid:]
      failures += ResilientAddManyAndCommit(db, left)
      failures += ResilientAddManyAndCommit(db, right)

  return failures


def QueryToString(query) -> str:
  """Compile the query to inline literals in place of '?' placeholders.
  See: https://stackoverflow.com/a/23835766
  """
  return str(query.statement.compile(compile_kwargs={"literal_binds": True}))


class BufferedDatabaseWriter(threading.Thread):
  """A buffered writer for adding objects to a database.

  Use this class for cases when you are producing lots of mapped objects that
  you would like to commit to a database, but don't require them to be
  committed immediately. By buffering objects and committing them in batches,
  this class minimises the number of SQL statements that are executed, and is
  faster than creating and committing a session for every object.

  This object spawns a separate thread for asynchronously performing database
  writes. Use AddOne() and AddMany() methods to add objects to the write
  buffer. Note that because this is a multithreaded implementation, in-memory
  SQLite databases are not supported.

  The user is responsible for calling Close() to flush the contents of the
  buffer and terminate the thread. Alternatively, use this class as a context
  manager to automatically flush the buffer and terminate the thread:

    with BufferedDatabaseWriter(db, max_buffer_length=128) as writer:
      for chunk in chunks_to_process:
        objs = ProcessChunk(chunk)
        writer.AddMany(objs)
  """

  def __init__(
    self,
    db: Database,
    max_buffer_size: Optional[int] = None,
    max_buffer_length: Optional[int] = None,
    max_seconds_since_flush: Optional[float] = None,
    log_level: int = 2,
    ctx: progress.ProgressContext = progress.NullContext,
  ):
    """Constructor.

    Args:
      db: The database to write to.
      max_buffer_size: The maximum size of the buffer before flushing, in
        bytes. The buffer size is the sum of the elements in the write buffer.
        The size of elements is determined using sys.getsizeof(), and has all
        the caveats of this method.
      max_buffer_length: The maximum number of items in the write buffer before
        flushing.
      max_seconds_since_flush: The maximum number of elapsed seconds between
        flushes.
      log_level: The logging level for logging output.
      ctx: A progress context that flush profiling and errors are reported to.
    """
    super(BufferedDatabaseWriter, self).__init__()
    self.db = db
    self.ctx = ctx
    self.log_level = log_level

    self.max_seconds_since_flush = max_seconds_since_flush
    self.max_buffer_size = max_buffer_size
    self.max_buffer_length = max_buffer_length

    # Counters.
    self.flush_count = 0
    self.error_count = 0

    # The write buffer and its running byte size; owned by the writer thread
    # once start() is called.
    self._buffer = []
    self.buffer_size = 0
    self._last_flush = time.time()

    # Limit the size of the queue so that calls to AddOne() or AddMany() will
    # block if the calling code is too far ahead of the writer.
    queue_size = self.max_buffer_length * 2 if self.max_buffer_length else 1000
    self._queue = queue.Queue(maxsize=queue_size)

    self.start()

  def __enter__(self) -> "BufferedDatabaseWriter":
    """Enter a scoped writer context closes at the end."""
    return self

  def __exit__(self, exc_type, exc_val, exc_tb):
    """Exit a scoped writer context closes at the end."""
    del exc_type
    del exc_val
    del exc_tb
    self.Close()

  def AddOne(self, mapped, size: Optional[int] = None) -> None:
    """Add a mapped object.

    Args:
      mapped: The mapped object to write to the database.
      size: The object sizes to use to update the total buffer size. If not
        provided, sys.getsizeof() is used to determine the size.
    """
    size = size or sys.getsizeof(mapped)
    self._queue.put((mapped, size))

  def AddMany(self, mappeds, sizes: Optional[List[int]] = None) -> None:
    """Add many mapped objects.

    Args:
      mappeds: The mapped objects to write to the database.
      sizes: A list of mapped object sizes to use to calculate the buffer size.
        If not provided, sys.getsizeof() is used to determine the size.
    """
    sizes = sizes or [sys.getsizeof(item) for item in mappeds]

    for mapped, size in zip(mappeds, sizes):
      self._queue.put((mapped, size))

  def AddLambdaOp(self, callback: Callable[[Database.SessionType], None]):
    """Enqueue a callback to be invoked with the session at the next flush."""
    self._queue.put(BufferedDatabaseWriter.LambdaOp(callback))

  def Flush(self) -> None:
    """Flush the buffer.

    This method blocks until the flush has completed.

    In normal use, you can rely on the automated flushing mechanisms to flush
    the write buffer, rather than calling this by hand.
    """
    self._queue.put(BufferedDatabaseWriter.FlushMarker())
    self._queue.join()

  def Close(self):
    """Close the writer thread.

    This method blocks until the buffer has been flushed and the thread
    terminates.

    Raises:
      TypeError: If the writer thread has already terminated.
    """
    if not self.is_alive():
      raise TypeError("Close() called on dead BufferedDatabaseWriter")
    self._queue.put(BufferedDatabaseWriter.CloseMarker())
    self._queue.join()
    self.join()

  @property
  def buffer_length(self) -> int:
    """Get the current length of the buffer, in range [0, max_buffer_length]."""
    return len(self._buffer)

  @property
  def seconds_since_last_flush(self) -> float:
    """Get the number of seconds since the buffer was last flushed."""
    return time.time() - self._last_flush

  ##############################################################################
  # Private methods.
  ##############################################################################

  class CloseMarker(object):
    """An object to append to _queue to close the thread."""

    pass

  class FlushMarker(object):
    """An object to append to _queue to flush the buffer."""

    pass

  class LambdaOp(object):
    # A deferred callback that is executed with a session during a flush.

    def __init__(self, callback):
      self.callback = callback

    def __call__(self, session: Database.SessionType):
      self.callback(session)

  def run(self):
    """The thread loop."""
    while True:
      # Block until there is something on the queue. Use max_seconds_since_flush
      # as a timeout to ensure that flushes still occur when the writer is not
      # being used.
      try:
        item = self._queue.get(timeout=self.max_seconds_since_flush)
      except queue.Empty:
        self._Flush()
        continue

      if isinstance(item, BufferedDatabaseWriter.CloseMarker):
        # End of queue. Break out of the loop.
        break
      elif isinstance(item, BufferedDatabaseWriter.FlushMarker):
        # Force a flush.
        self._Flush()
      elif isinstance(item, BufferedDatabaseWriter.LambdaOp):
        # Handle lambda op: buffer it for execution during the next flush.
        self._buffer.append(item)
        self._MaybeFlush()
      else:
        # Add the object to the buffer.
        mapped, size = item
        self._buffer.append(mapped)
        self.buffer_size += size
        self._MaybeFlush()

      # Register that the item has been processed. This is used by join() to
      # signal to stop blocking.
      self._queue.task_done()

    # Register that the end-of-queue marker has been processed.
    self._Flush()
    self._queue.task_done()

  def _MaybeFlush(self) -> None:
    # Flush if any of the configured thresholds (bytes, item count, or elapsed
    # time) has been reached.
    if (
      (self.max_buffer_size and self.buffer_size >= self.max_buffer_size)
      or (
        self.max_buffer_length and self.buffer_length >= self.max_buffer_length
      )
      or (
        self.max_seconds_since_flush
        and self.seconds_since_last_flush >= self.max_seconds_since_flush
      )
    ):
      self._Flush()

  def _AddMapped(self, mapped) -> None:
    """Add and commit a list of mapped objects."""
    if not mapped:
      return

    failures = ResilientAddManyAndCommit(self.db, mapped)
    if failures:
      self.ctx.Error("Logger failed to commit %d objects", len(failures))
    self.error_count += len(failures)

  def _Flush(self):
    """Flush the buffer."""
    if not self._buffer:
      return

    with self.ctx.Profile(
      self.log_level,
      f"Committed {self.buffer_length} rows "
      f"({humanize.BinaryPrefix(self.buffer_size, 'B')}) to {self.db.url}",
    ), self.db.Session() as session:
      # Iterate through the buffer and handle any lambda ops.
      start_i, end_i = 0, 0
      for end_i, item in enumerate(self._buffer):
        if isinstance(item, BufferedDatabaseWriter.LambdaOp):
          # If we have a lambda op, we flush the contents of the current buffer,
          # then execute the op and continue.
          self._AddMapped(self._buffer[start_i:end_i])
          self._buffer[end_i](session)
          session.commit()
          start_i = end_i + 1
      # Add any remaining mapped objects from the buffer.
      self._AddMapped(self._buffer[start_i:])

    self._buffer = []
    self._last_flush = time.time()
    self.buffer_size = 0
    self.flush_count += 1