import pandas as pd
import pyspark as ps
import regex as re
import toolz
from pkg_resources import parse_version

import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.lineage as lin
import ibis.expr.operations as ops
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.client import Database, Query, SQLClient
from ibis.spark import compiler as comp
from ibis.spark import ddl
from ibis.spark.compiler import SparkDialect, build_ast
from ibis.spark.datatypes import spark_dtype
from ibis.util import log


@sch.infer.register(ps.sql.dataframe.DataFrame)
def spark_dataframe_schema(df):
    """Infer the schema of a Spark SQL `DataFrame` object."""
    # df.schema is a pt.StructType
    schema_struct = dt.dtype(df.schema)
    return sch.schema(schema_struct.names, schema_struct.types)


class SparkCursor:
    """Spark cursor.

    This allows the Spark client to reuse machinery in
    :file:`ibis/client.py`.
    """

    def __init__(self, query):
        """
        Construct a SparkCursor with query `query`.

        Parameters
        ----------
        query : pyspark.sql.DataFrame
          Contains result of query.
        """
        self.query = query

    def fetchall(self):
        """Fetch all rows."""
        result = self.query.collect()  # blocks until finished
        return result

    @property
    def columns(self):
        """Return the columns of the result set."""
        return self.query.columns

    @property
    def description(self):
        """Get the fields of the result set's schema."""
        return self.query.schema

    def __enter__(self):
        # For compatibility when constructed from Query.execute()
        """No-op for compatibility.

        See Also
        --------
        ibis.client.Query.execute
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """No-op for compatibility.

        See Also
        --------
        ibis.client.Query.execute
        """


def find_spark_udf(expr):
    result = expr.op()
    if not isinstance(result, (comp.SparkUDFNode, comp.SparkUDAFNode)):
        result = None
    return lin.proceed, result


class SparkQuery(Query):
    def execute(self):
        udf_nodes = lin.traverse(find_spark_udf, self.expr)

        # UDFs are uniquely identified by the name of the Node subclass we
        # generate.
        udf_nodes_unique = list(
            toolz.unique(udf_nodes, key=lambda node: type(node).__name__)
        )

        # register UDFs in pyspark
        for node in udf_nodes_unique:
            self.client._session.udf.register(
                type(node).__name__, node.udf_func
            )

        result = super().execute()

        for node in udf_nodes_unique:
            stmt = ddl.DropFunction(type(node).__name__, must_exist=True)
            self.client._execute(stmt.compile())

        return result

    def _fetch(self, cursor):
        df = cursor.query.toPandas()  # blocks until finished
        schema = self.schema()
        return schema.apply_to(df)


class SparkDatabase(Database):
    pass


class SparkDatabaseTable(ops.DatabaseTable):
    pass


class SparkTable(ir.TableExpr):
    @property
    def _qualified_name(self):
        return self.op().args[0]

    def _match_name(self):
        m = ddl.fully_qualified_re.match(self._qualified_name)
        if not m:
            return None, self._qualified_name
        db, quoted, unquoted = m.groups()
        return db, quoted or unquoted

    @property
    def _database(self):
        return self._match_name()[0]

    @property
    def _unqualified_name(self):
        return self._match_name()[1]

    @property
    def name(self):
        return self.op().name

    @property
    def _client(self):
        return self.op().source

    def _execute(self, stmt):
        return self._client._execute(stmt)

    def compute_stats(self, noscan=False):
        """
        Invoke Spark ANALYZE TABLE <tbl> COMPUTE STATISTICS command to
        compute column, table, and partition statistics.

        See also SparkClient.compute_stats
        """
        return self._client.compute_stats(
            self._qualified_name, noscan=noscan
        )

    def drop(self):
        """
        Drop the table from the database.
        """
        self._client.drop_table_or_view(self._qualified_name)

    def truncate(self):
        self._client.truncate_table(self._qualified_name)

    def insert(self, obj=None, overwrite=False, values=None, validate=True):
        """
        Insert into Spark table.

        Parameters
        ----------
        obj : TableExpr or pandas DataFrame
        overwrite : boolean, default False
          If True, will replace existing contents of table
        values : optional
          Not currently supported; raises NotImplementedError if provided
        validate : boolean, default True
          If True, do more rigorous validation that schema of table being
          inserted is compatible with the existing table

        Examples
        --------
        >>> t.insert(table_expr)  # doctest: +SKIP

        # Completely overwrite contents
        >>> t.insert(table_expr, overwrite=True)  # doctest: +SKIP
        """
        if isinstance(obj, pd.DataFrame):
            # the SparkSession is held by the client, not the table expression
            spark_df = self._client._session.createDataFrame(obj)
            spark_df.insertInto(self.name, overwrite=overwrite)
            return

        expr = obj

        if values is not None:
            raise NotImplementedError

        if validate:
            existing_schema = self.schema()
            insert_schema = expr.schema()
            if not insert_schema.equals(existing_schema):
                _validate_compatible(insert_schema, existing_schema)

        ast = build_ast(expr, SparkDialect.make_context())
        select = ast.queries[0]
        statement = ddl.InsertSelect(
            self._qualified_name, select, overwrite=overwrite
        )
        return self._execute(statement.compile())

    def rename(self, new_name):
        """
        Rename table inside Spark.

        References to the old table are no longer valid. Spark does not
        support moving tables across databases using rename.

        Parameters
        ----------
        new_name : string

        Returns
        -------
        renamed : SparkTable
        """
        new_qualified_name = _fully_qualified_name(new_name, self._database)

        statement = ddl.RenameTable(self._qualified_name, new_name)
        self._client._execute(statement.compile())

        op = self.op().change_name(new_qualified_name)
        return type(self)(op)

    def alter(self, tbl_properties=None):
        """
        Change settings and parameters of the table.

        Parameters
        ----------
        tbl_properties : dict, optional

        Returns
        -------
        None (for now)
        """
        stmt = ddl.AlterTable(
            self._qualified_name, tbl_properties=tbl_properties
        )
        return self._execute(stmt.compile())


class SparkClient(SQLClient):
    """
    An Ibis client interface that uses Spark SQL.
    """

    dialect = comp.SparkDialect
    database_class = SparkDatabase
    query_class = SparkQuery
    table_class = SparkDatabaseTable
    table_expr_class = SparkTable

    def __init__(self, session):
        self._context = session.sparkContext
        self._session = session
        self._catalog = session.catalog

    def close(self):
        """
        Close Spark connection and drop any temporary objects.
        """
        self._context.stop()

    def _build_ast(self, expr, context):
        result = comp.build_ast(expr, context)
        return result

    def _execute(self, stmt, results=False):
        query = self._session.sql(stmt)
        if results:
            return SparkCursor(query)

    def database(self, name=None):
        return self.database_class(name or self.current_database, self)

    @property
    def current_database(self):
        """
        String name of the current database.
""" return self._catalog.currentDatabase() def _get_table_schema(self, table_name): return self.get_schema(table_name) def _get_schema_using_query(self, query): cur = self._execute(query, results=True) return spark_dataframe_schema(cur.query) def log(self, msg): log(msg) def _get_jtable(self, name, database=None): try: jtable = self._catalog._jcatalog.getTable( _fully_qualified_name(name, database) ) except ps.sql.utils.AnalysisException as e: raise com.IbisInputError(str(e)) from e return jtable def table(self, name, database=None): """ Create a table expression that references a particular table or view in the database. Parameters ---------- name : string database : string, optional Returns ------- table : TableExpr """ jtable = self._get_jtable(name, database) name, database = jtable.name(), jtable.database() qualified_name = _fully_qualified_name(name, database) schema = self._get_table_schema(qualified_name) node = self.table_class(qualified_name, schema, self) return self.table_expr_class(node) def list_tables(self, like=None, database=None): """ List tables in the current (or indicated) database. Like the SHOW TABLES command. Parameters ---------- like : string, default None e.g. 'foo*' to match all tables starting with 'foo' database : string, default None If not passed, uses the current/default database Returns ------- results : list of strings """ results = [t.name for t in self._catalog.listTables(dbName=database)] if like: results = [ table_name for table_name in results if re.match(like, table_name) is not None ] return results def exists_table(self, name, database=None): """ Determine if the indicated table or view exists Parameters ---------- name : string database : string, default None Returns ------- if_exists : boolean """ try: self._get_jtable(name, database) return True except com.IbisInputError: return False def set_database(self, name): """ Set the default database scope for client """ self._catalog.setCurrentDatabase(name) def list_databases(self, like=None): """ List databases in the Spark SQL cluster. Parameters ---------- like : string, default None e.g. 'foo*' to match all tables starting with 'foo' Returns ------- results : list of strings """ results = [db.name for db in self._catalog.listDatabases()] if like: results = [ database_name for database_name in results if re.match(like, database_name) is not None ] return results def exists_database(self, name): """ Checks if a given database exists Parameters ---------- name : string Database name Returns ------- if_exists : boolean """ return bool(self.list_databases(like=name)) def create_database(self, name, path=None, force=False): """ Create a new Spark database Parameters ---------- name : string Database name path : string, default None Path where to store the database data; otherwise uses Spark default """ statement = ddl.CreateDatabase(name, path=path, can_exist=force) return self._execute(statement.compile()) def drop_database(self, name, force=False): """Drop a Spark database. 

        Parameters
        ----------
        name : string
          Database name
        force : bool, default False
          If False, Spark throws an exception if the database is not empty
          or does not exist
        """
        statement = ddl.DropDatabase(
            name, must_exist=not force, cascade=force
        )
        return self._execute(statement.compile())

    def get_schema(self, table_name, database=None):
        """
        Return a Schema object for the indicated table and database.

        Parameters
        ----------
        table_name : string
          May be fully qualified
        database : string
          Spark does not have a database argument for its table() method,
          so this must be None

        Returns
        -------
        schema : ibis Schema
        """
        if database is not None:
            raise com.UnsupportedArgumentError(
                'Spark does not support database param for table'
            )

        df = self._session.table(table_name)

        return sch.infer(df)

    def _schema_from_csv(self, path, **kwargs):
        """
        Return a Schema object for the indicated csv file. Spark goes through
        the file once to determine the schema. See documentation for
        `pyspark.sql.DataFrameReader` for kwargs.

        Parameters
        ----------
        path : string

        Returns
        -------
        schema : ibis Schema
        """
        options = _read_csv_defaults.copy()
        options.update(kwargs)
        options['inferSchema'] = True

        df = self._session.read.csv(path, **options)
        return spark_dataframe_schema(df)

    def _create_table_or_temp_view_from_csv(
        self,
        name,
        path,
        schema=None,
        database=None,
        force=False,
        temp_view=False,
        format='parquet',
        **kwargs,
    ):
        options = _read_csv_defaults.copy()
        options.update(kwargs)

        if schema:
            assert ('inferSchema', True) not in options.items()
            schema = spark_dtype(schema)
            options['schema'] = schema
        else:
            options['inferSchema'] = True

        df = self._session.read.csv(path, **options)

        if temp_view:
            if force:
                df.createOrReplaceTempView(name)
            else:
                df.createTempView(name)
        else:
            qualified_name = _fully_qualified_name(
                name, database or self.current_database
            )
            mode = 'error'
            if force:
                mode = 'overwrite'
            df.write.saveAsTable(qualified_name, format=format, mode=mode)

    @property
    def version(self):
        return parse_version(ps.__version__)

    def create_table(
        self,
        table_name,
        obj=None,
        schema=None,
        database=None,
        force=False,
        # HDFS options
        format='parquet',
    ):
        """
        Create a new table in Spark using an Ibis table expression.

        Parameters
        ----------
        table_name : string
        obj : TableExpr or pandas.DataFrame, optional
          If passed, creates table from select statement results
        schema : ibis.Schema, optional
          Mutually exclusive with obj, creates an empty table with a
          particular schema
        database : string, default None (optional)
        force : boolean, default False
          If True, do not fail if a table with the given name already
          exists; when creating from a pandas DataFrame any existing table
          is overwritten
        format : {'parquet'}

        Examples
        --------
        >>> con.create_table('new_table_name', table_expr)  # doctest: +SKIP
        """
        if obj is not None:
            if isinstance(obj, pd.DataFrame):
                spark_df = self._session.createDataFrame(obj)
                mode = 'error'
                if force:
                    mode = 'overwrite'
                spark_df.write.saveAsTable(
                    table_name, format=format, mode=mode
                )
                return

            ast = self._build_ast(obj, SparkDialect.make_context())
            select = ast.queries[0]

            statement = ddl.CTAS(
                table_name,
                select,
                database=database,
                can_exist=force,
                format=format,
            )
        elif schema is not None:
            statement = ddl.CreateTableWithSchema(
                table_name,
                schema,
                database=database,
                format=format,
                can_exist=force,
            )
        else:
            raise com.IbisError('Must pass expr or schema')

        return self._execute(statement.compile())

    def create_view(
        self, name, expr, database=None, can_exist=False, temporary=False
    ):
        """
        Create a Spark view from a table expression.

        Parameters
        ----------
        name : string
        expr : ibis TableExpr
        database : string, default None
        can_exist : boolean, default False
          Replace an existing view of the same name if it exists
        temporary : boolean, default False
        """
        ast = self._build_ast(expr, SparkDialect.make_context())
        select = ast.queries[0]
        statement = ddl.CreateView(
            name,
            select,
            database=database,
            can_exist=can_exist,
            temporary=temporary,
        )
        return self._execute(statement.compile())

    def drop_table(self, name, database=None, force=False):
        self.drop_table_or_view(name, database, force)

    def drop_view(self, name, database=None, force=False):
        self.drop_table_or_view(name, database, force)

    def drop_table_or_view(self, name, database=None, force=False):
        """
        Drop a Spark table or view.

        Parameters
        ----------
        name : string
        database : string, default None (optional)
        force : boolean, default False
          If False, an exception is raised if the table or view does not
          exist

        Examples
        --------
        >>> table = 'my_table'
        >>> db = 'operations'
        >>> con.drop_table_or_view(table, db, force=True)  # doctest: +SKIP
        """
        statement = ddl.DropTable(
            name, database=database, must_exist=not force
        )
        self._execute(statement.compile())

    def truncate_table(self, table_name, database=None):
        """
        Delete all rows from, but do not drop, an existing table.

        Parameters
        ----------
        table_name : string
        database : string, default None (optional)
        """
        statement = ddl.TruncateTable(table_name, database=database)
        self._execute(statement.compile())

    def insert(
        self,
        table_name,
        obj=None,
        database=None,
        overwrite=False,
        values=None,
        validate=True,
    ):
        """
        Insert into existing table.

        See SparkTable.insert for other parameters.

        Parameters
        ----------
        table_name : string
        database : string, default None

        Examples
        --------
        >>> table = 'my_table'
        >>> con.insert(table, table_expr)  # doctest: +SKIP

        # Completely overwrite contents
        >>> con.insert(table, table_expr, overwrite=True)  # doctest: +SKIP
        """
        table = self.table(table_name, database=database)
        return table.insert(
            obj=obj, overwrite=overwrite, values=values, validate=validate
        )

    def compute_stats(self, name, database=None, noscan=False):
        """
        Issue COMPUTE STATISTICS command for a given table.

        Parameters
        ----------
        name : string
          Can be fully qualified (with database name)
        database : string, optional
        noscan : boolean, default False
          If True, collect only basic statistics for the table (number of
          rows, size in bytes).
        """
        maybe_noscan = ' NOSCAN' if noscan else ''
        stmt = 'ANALYZE TABLE {0} COMPUTE STATISTICS{1}'.format(
            _fully_qualified_name(name, database), maybe_noscan
        )
        return self._execute(stmt)


def _fully_qualified_name(name, database):
    if ddl._is_fully_qualified(name):
        return name
    if database:
        return '{0}.`{1}`'.format(database, name)
    return name


def _validate_compatible(from_schema, to_schema):
    if set(from_schema.names) != set(to_schema.names):
        raise com.IbisInputError('Schemas have different names')

    for name in from_schema:
        lt = from_schema[name]
        rt = to_schema[name]
        if not lt.castable(rt):
            raise com.IbisInputError(
                'Cannot safely cast {0!r} to {1!r}'.format(lt, rt)
            )


_read_csv_defaults = {
    'header': True,
    'multiLine': True,
    'mode': 'FAILFAST',
    'escape': '"',
}
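

# Usage sketch (illustrative only; not part of the client implementation).
# A minimal example of wiring a local SparkSession into this client,
# assuming pyspark is installed and this backend is importable; the table
# name 'example_t' and the 'local[1]' master are arbitrary choices.
if __name__ == '__main__':
    from pyspark.sql import SparkSession

    # Build (or reuse) a local SparkSession and wrap it in the Ibis client.
    _session = SparkSession.builder.master('local[1]').getOrCreate()
    _client = SparkClient(_session)

    # Create a table from a pandas DataFrame, then query it through Ibis.
    _df = pd.DataFrame({'key': ['a', 'b'], 'value': [1, 2]})
    _client.create_table('example_t', obj=_df, force=True)

    _t = _client.table('example_t')
    print(_t.value.sum().execute())  # runs through SparkQuery / SparkCursor

    _client.drop_table('example_t', force=True)
    _client.close()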