from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import logging import es from es import exceptions from es.const import DEFAULT_SCHEMA from sqlalchemy import types from sqlalchemy.engine import default from sqlalchemy.sql import compiler logger = logging.getLogger(__name__) def parse_bool_argument(value: str) -> bool: if value in ("True", "true"): return True elif value in ("False", "false"): return False else: raise ValueError(f"Expected boolean found {value}") class BaseESCompiler(compiler.SQLCompiler): def visit_fromclause(self, fromclause, **kwargs): return fromclause.replace("default.", "") def visit_label(self, *args, **kwargs): if len(kwargs) == 0 or len(kwargs) == 1: kwargs["render_label_as_label"] = args[0] result = super().visit_label(*args, **kwargs) return result class BaseESTypeCompiler(compiler.GenericTypeCompiler): def visit_REAL(self, type_, **kwargs): return "DOUBLE" def visit_NUMERIC(self, type_, **kwargs): return "LONG" visit_DECIMAL = visit_NUMERIC visit_INTEGER = visit_NUMERIC visit_SMALLINT = visit_NUMERIC visit_BIGINT = visit_NUMERIC visit_BOOLEAN = visit_NUMERIC visit_TIMESTAMP = visit_NUMERIC visit_DATE = visit_NUMERIC def visit_CHAR(self, type_, **kwargs): return "STRING" visit_NCHAR = visit_CHAR visit_VARCHAR = visit_CHAR visit_NVARCHAR = visit_CHAR visit_TEXT = visit_CHAR def visit_DATETIME(self, type_, **kwargs): return "DATETIME" def visit_TIME(self, type_, **kwargs): raise exceptions.NotSupportedError("Type TIME is not supported") def visit_BINARY(self, type_, **kwargs): raise exceptions.NotSupportedError("Type BINARY is not supported") def visit_VARBINARY(self, type_, **kwargs): raise exceptions.NotSupportedError("Type VARBINARY is not supported") def visit_BLOB(self, type_, **kwargs): raise exceptions.NotSupportedError("Type BLOB is not supported") def visit_CLOB(self, type_, **kwargs): raise exceptions.NotSupportedError("Type CBLOB is not supported") def visit_NCLOB(self, type_, **kwargs): raise exceptions.NotSupportedError("Type NCBLOB is not supported") class BaseESDialect(default.DefaultDialect): name = "SET" scheme = "SET" driver = "SET" statement_compiler = None type_compiler = None preparer = compiler.IdentifierPreparer supports_alter = False supports_pk_autoincrement = False supports_default_values = False supports_empty_insert = False supports_unicode_statements = True supports_unicode_binds = True returns_unicode_strings = True description_encoding = None supports_native_boolean = True supports_simple_order_by_label = True _not_supported_column_types = ["object", "nested"] _map_parse_connection_parameters = { "verify_certs": parse_bool_argument, "use_ssl": parse_bool_argument, "http_compress": parse_bool_argument, "sniff_on_start": parse_bool_argument, "sniff_on_connection_fail": parse_bool_argument, "retry_on_timeout": parse_bool_argument, "sniffer_timeout": int, "sniff_timeout": int, "max_retries": int, "maxsize": int, "timeout": int, } @classmethod def dbapi(cls): return es def create_connect_args(self, url): kwargs = { "host": url.host, "port": url.port or 9200, "path": url.database, "scheme": self.scheme, "user": url.username or None, "password": url.password or None, } if url.query: kwargs.update(url.query) for name, parse_func in self._map_parse_connection_parameters.items(): if name in kwargs: kwargs[name] = parse_func(url.query[name]) return ([], kwargs) def get_schema_names(self, connection, **kwargs): # ES does not have the concept of a schema return [DEFAULT_SCHEMA] def has_table(self, connection, table_name, schema=None): return table_name in self.get_table_names(connection, schema) def get_table_names(self, connection, schema=None, **kwargs): query = "SHOW TABLES" result = connection.execute(query) return [row.name for row in result if row.name[0] != "."] def get_view_names(self, connection, schema=None, **kwargs): return [] def get_table_options(self, connection, table_name, schema=None, **kwargs): return {} def get_columns(self, connection, table_name, schema=None, **kwargs): query = f'SHOW COLUMNS FROM "{table_name}"' # A bit of an hack this cmd does not exist on ES array_columns_ = connection.execute( f"SHOW ARRAY_COLUMNS FROM {table_name}" ).fetchall() if len(array_columns_[0]) == 0: array_columns = [] else: array_columns = [col_name[0] for col_name in array_columns_] result = connection.execute(query) return [ { "name": row.column, "type": get_type(row.mapping), "nullable": True, "default": None, } for row in result if row.mapping not in self._not_supported_column_types and row.column not in array_columns ] def get_pk_constraint(self, connection, table_name, schema=None, **kwargs): return {"constrained_columns": [], "name": None} def get_foreign_keys(self, connection, table_name, schema=None, **kwargs): return [] def get_check_constraints(self, connection, table_name, schema=None, **kwargs): return [] def get_table_comment(self, connection, table_name, schema=None, **kwargs): return {"text": ""} def get_indexes(self, connection, table_name, schema=None, **kwargs): return [] def get_unique_constraints(self, connection, table_name, schema=None, **kwargs): return [] def get_view_definition(self, connection, view_name, schema=None, **kwargs): pass # pragma: no cover def do_rollback(self, dbapi_connection): pass def _check_unicode_returns(self, connection, additional_tests=None): return True def _check_unicode_description(self, connection): return True def get_type(data_type): type_map = { "bytes": types.LargeBinary, "boolean": types.Boolean, "date": types.Date, "datetime": types.DateTime, "double": types.Numeric, "text": types.String, "keyword": types.String, "integer": types.Integer, "half_float": types.Float, "geo_point": types.String, # TODO get a solution for nested type "nested": types.String, # TODO get a solution for object "object": types.BLOB, "long": types.BigInteger, "float": types.Float, "ip": types.String, } type_ = type_map.get(data_type) if not type_: logger.warning(f"Unknown type found {data_type} reverting to string") type_ = types.String return type_