package org.apache.spark.sql.hbase

import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLConf, _}

/**
 * Advanced SQL query tests run through the `TestHbase` context against the `ta`
 * test table (presumably created by `TestBaseWithSplitData`, which is not visible
 * here — TODO confirm).
 */
class HBaseAdvancedSQLQuerySuite extends TestBaseWithSplitData {

  import org.apache.spark.sql.hbase.TestHbase._
  import org.apache.spark.sql.hbase.TestHbase.implicits._

  // Enables codegen, checks that `GROUP BY col1` over `ta` yields 14 rows, then
  // restores the previous codegen setting.
  // NOTE(review): the restore is not in a `finally`, so a failing assertion leaves
  // codegen enabled for the tests that run afterwards.
  test("aggregation with codegen") {
    val originalValue = TestHbase.conf.codegenEnabled
    setConf(SQLConf.CODEGEN_ENABLED, "true")
    val result = sql("SELECT col1 FROM ta GROUP BY col1").collect()
    assert(result.length == 14, s"aggregation with codegen test failed on size")
    setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)

  // Filters, orders and projects `ta` with the DataFrame DSL; each query is paired
  // with what looks like its expected single-row result.
  // NOTE(review): the enclosing `checkAnswer(` calls for the two argument lists
  // below appear to be missing from this copy — confirm against upstream.
  test("dsl simple select 0") {
    val tableA = sql("SELECT * FROM ta")
      tableA.where('col7 === 1).orderBy('col2.asc).select('col4),
      Row(1) :: Nil)
      tableA.where('col2 === 6).orderBy('col2.asc).select('col7),
      Row(-31) :: Nil)

  // Attaches a "doc" metadata entry to `col1` and checks that it survives DSL
  // projection, SQL selection and joins.
  // NOTE(review): several lines appear to be missing here.
  //   - `metadata` is a `MetadataBuilder` that is never `.build()`-ed before being
  //     passed to `StructField.copy`.
  //   - The closing brace of `validateMetadata` is absent.
  //   - The first `validateMetadata(...)` call has lost its receiver (presumably
  //     `personWithMeta.select(`).
  //   - No registration of the `personWithMeta` or `salary` tables is visible.
  test("metadata is propagated correctly") {
    val tableA = sql("SELECT col7, col1, col3 FROM ta")
    val schema = tableA.schema
    val docKey = "doc"
    val docValue = "first name"
    val metadata = new MetadataBuilder()
      .putString(docKey, docValue)
    val schemaWithMeta = new StructType(Array(
      schema("col7"), schema("col1").copy(metadata = metadata), schema("col3")))
    val personWithMeta = createDataFrame(tableA.rdd, schemaWithMeta)
    // Asserts that `col1` of the given DataFrame still carries the doc metadata.
    def validateMetadata(rdd: DataFrame): Unit = {
      assert(rdd.schema("col1").metadata.getString(docKey) == docValue)
    validateMetadata($"col7", $"col1"))
    validateMetadata(sql("SELECT * FROM personWithMeta"))
    validateMetadata(sql("SELECT col7, col1 FROM personWithMeta"))
    validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON col7 = personId"))
    validateMetadata(sql("SELECT col1, salary FROM personWithMeta JOIN salary ON col7 = personId"))
package org.apache.spark.sql.extensions

import com.pingcap.tispark.statistics.StatisticsManager
import com.pingcap.tispark.utils.ReflectionUtil._
import com.pingcap.tispark.{MetaManager, TiDBRelation, TiTableReference}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.TiSessionCatalog
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.{AnalysisException, _}

/**
 * Analyzer rule that resolves table references to TiDB-backed relations.
 *
 * `UnresolvedRelation`s are replaced only while `tiCatalog` holds a
 * `TiSessionCatalog`. This includes the target of an `InsertIntoTable`.
 *
 * @param getOrCreateTiContext factory used to obtain the `TiContext` for `sparkSession`
 * @param sparkSession         the session this rule is installed in
 */
case class TiResolutionRule(getOrCreateTiContext: SparkSession => TiContext)(
    sparkSession: SparkSession)
    extends Rule[LogicalPlan] {
  // These must stay lazy: `tiContext` is declared (and therefore initialized)
  // after them, so strict vals here would read a still-null `tiContext`.
  protected lazy val meta: MetaManager = tiContext.meta
  private lazy val autoLoad = tiContext.autoLoad
  private lazy val tiCatalog = tiContext.tiCatalog
  private lazy val tiSession = tiContext.tiSession
  private lazy val sqlContext = tiContext.sqlContext
  protected val tiContext: TiContext = getOrCreateTiContext(sparkSession)
  // Builds the logical plan for a TiDB table:
  //   - looks the table up in TiDB metadata, failing with AnalysisException if it is absent;
  //   - estimates its size;
  //   - wraps the relation in a subquery alias named after the table.
  // NOTE(review): these pieces appear to be missing from this copy — confirm
  // against upstream:
  //   - the closing brace of the `isEmpty` check;
  //   - the body of `if (autoLoad)`;
  //   - the second argument list of `TiDBRelation(...)(`.
  protected val resolveTiDBRelation: TableIdentifier => LogicalPlan =
    tableIdentifier => {
      val dbName = getDatabaseFromIdentifier(tableIdentifier)
      val tableName = tableIdentifier.table
      val table = meta.getTable(dbName, tableName)
      if (table.isEmpty) {
        throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'")
      if (autoLoad) {
      val sizeInBytes = StatisticsManager.estimateTableSize(table.get)
      val tiDBRelation =
        TiDBRelation(tiSession, TiTableReference(dbName, tableName, sizeInBytes), meta)(
      // Use SubqueryAlias so that projects and joins can correctly resolve
      // UnresolvedAttributes in JoinConditions, Projects, Filters, etc.
      newSubqueryAlias(tableName, LogicalRelation(tiDBRelation))

  // Applies `resolveTiDBRelations` bottom-up across the whole plan.
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan transformUp resolveTiDBRelations

  // Matches INSERT targets and plain table references that are still unresolved,
  // but only while the TiSpark session catalog is in use.
  // NOTE(review): the right-hand side of the second case is missing from this copy
  // (presumably `resolveTiDBRelation(tableIdentifier)`) — confirm.
  protected def resolveTiDBRelations: PartialFunction[LogicalPlan, LogicalPlan] = {
    case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier), _, _, _, _)
        if tiCatalog
          .exists(_.isInstanceOf[TiSessionCatalog]) =>
      i.copy(table = EliminateSubqueryAliases(resolveTiDBRelation(tableIdentifier)))
    case UnresolvedRelation(tableIdentifier)
        if tiCatalog
          .exists(_.isInstanceOf[TiSessionCatalog]) =>

  // Returns the database name used to look up `tableIdentifier`.
  // NOTE(review): the body is missing from this copy. It presumably falls back to
  // the current database when `tableIdentifier.database` is empty — confirm.
  private def getDatabaseFromIdentifier(tableIdentifier: TableIdentifier): String =

/**
 * Rule that swaps Spark catalog commands for TiSpark-aware wrappers.
 *
 * Handled commands: ShowDatabasesCommand, SetDatabaseCommand, ShowTablesCommand,
 * ShowColumnsCommand, DescribeTableCommand, DescribeColumnCommand and
 * CreateTableLikeCommand. Plan nodes that match none of these cases are left
 * unchanged.
 *
 * @param getOrCreateTiContext factory used to obtain the `TiContext` for `sparkSession`
 * @param sparkSession         the session this rule is installed in
 */
case class TiDDLRule(getOrCreateTiContext: SparkSession => TiContext)(sparkSession: SparkSession)
    extends Rule[LogicalPlan] {
  // Lazy so the TiContext is only obtained the first time a command is rewritten.
  protected lazy val tiContext: TiContext = getOrCreateTiContext(sparkSession)

  // Rewrites matching command nodes bottom-up, wrapping each one together with `tiContext`.
  override def apply(plan: LogicalPlan): LogicalPlan =
    plan transformUp {
      // TODO: support other commands that may concern TiSpark catalog.
      case sd: ShowDatabasesCommand =>
        TiShowDatabasesCommand(tiContext, sd)
      case sd: SetDatabaseCommand =>
        TiSetDatabaseCommand(tiContext, sd)
      case st: ShowTablesCommand =>
        TiShowTablesCommand(tiContext, st)
      case st: ShowColumnsCommand =>
        TiShowColumnsCommand(tiContext, st)
      case dt: DescribeTableCommand =>
        TiDescribeTablesCommand(tiContext, dt)
      case dc: DescribeColumnCommand =>
        TiDescribeColumnCommand(tiContext, dc)
      case ct: CreateTableLikeCommand =>
        TiCreateTableLikeCommand(tiContext, ct)
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Streaming source and sink provider with a fixed one-column schema (`a: Int`).
 *
 * The source always reports offset 0, and the sink ignores every batch it
 * receives. The companion's `latch` is presumably awaited inside `getBatch`, but
 * that body is truncated in this copy — confirm against upstream.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // The only schema this provider exposes: a single IntegerType column "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Ignores any user-supplied schema and reports `fakeSchema` under the name "dummySource".
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source whose offset never advances past 0.
  // NOTE(review): `getBatch` is truncated after the implicits import, and the
  // closing braces of these definitions are missing from this copy.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns a Sink that discards every batch it receives.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding a process-wide, mutable `CountDownLatch`. */
object BlockingSource {
  // Starts as null; no assignment to it is visible in this file.
  var latch: CountDownLatch = null
import{KustoSinkOptions, SparkIngestionProperties}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.{DataFrameWriter, _}

/**
 * Implicit extensions that add a `kusto(...)` shortcut to `DataFrameReader`,
 * `DataFrameWriter[Row]` and `DataStreamWriter[Row]`.
 *
 * Each shortcut sets the Kusto cluster and database options, plus either the
 * query (reader) or the table (writers).
 */
object SparkExtension {

  /** Adds `kusto(...)` to `DataFrameReader` for reading the result of a Kusto query. */
  implicit class DataFrameReaderExtension(df: DataFrameReader) {

    // Sets the cluster, database and query options. When `cpr` is defined, also
    // passes `cpr.get.toString` under KUSTO_CLIENT_REQUEST_PROPERTIES_JSON.
    // NOTE(review): several pieces are not visible in this copy — confirm against
    // upstream:
    //   - `conf` is not applied in the visible lines;
    //   - the `.option` chain below has lost its receiver / else-branch;
    //   - the lines that produce the returned DataFrame are missing.
    def kusto(kustoCluster: String, database: String, query: String, conf: Map[String, String] = Map.empty[String, String], cpr: Option[ClientRequestProperties] = None): DataFrame = {
      if (cpr.isDefined) {
        df.option(KustoSourceOptions.KUSTO_CLIENT_REQUEST_PROPERTIES_JSON, cpr.get.toString)

        .option(KustoSourceOptions.KUSTO_CLUSTER, kustoCluster)
        .option(KustoSourceOptions.KUSTO_DATABASE, database)
        .option(KustoSourceOptions.KUSTO_QUERY, query)

  /** Adds `kusto(...)` to batch `DataFrameWriter[Row]` for writing into a Kusto table. */
  implicit class DataFrameWriterExtension(df: DataFrameWriter[Row]) {
    // Sets the cluster, database and table options. When `sparkIngestionProperties`
    // is defined, also passes its `toString` under KUSTO_SPARK_INGESTION_PROPERTIES_JSON.
    // NOTE(review): several pieces are not visible in this copy:
    //   - `conf` is not applied in the visible lines;
    //   - the `.option` chain has lost its receiver;
    //   - no call that triggers the write is visible.
    def kusto(kustoCluster: String, database: String, table: String, conf: Map[String, String] = Map.empty[String, String], sparkIngestionProperties: Option[SparkIngestionProperties] = None): Unit = {
      if (sparkIngestionProperties.isDefined) {
        df.option(KustoSinkOptions.KUSTO_SPARK_INGESTION_PROPERTIES_JSON, sparkIngestionProperties.get.toString)

      .option(KustoSinkOptions.KUSTO_CLUSTER, kustoCluster)
      .option(KustoSinkOptions.KUSTO_DATABASE, database)
      .option(KustoSinkOptions.KUSTO_TABLE, table)

  /** Adds `kusto(...)` to streaming `DataStreamWriter[Row]` for writing into a Kusto table. */
  implicit class DataStreamWriterExtension(df: DataStreamWriter[Row]) {
    // Same options as the batch writer above.
    // NOTE(review): the same pieces as in the batch writer are not visible here:
    //   - `conf` is not applied in the visible lines;
    //   - the `.option` chain has lost its receiver;
    //   - no call that starts the query is visible.
    def kusto(kustoCluster: String, database: String, table: String, conf: Map[String, String] = Map.empty[String, String], sparkIngestionProperties: Option[SparkIngestionProperties] = None): Unit = {
      if (sparkIngestionProperties.isDefined) {
        df.option(KustoSinkOptions.KUSTO_SPARK_INGESTION_PROPERTIES_JSON, sparkIngestionProperties.get.toString)

        .option(KustoSinkOptions.KUSTO_CLUSTER, kustoCluster)
        .option(KustoSinkOptions.KUSTO_DATABASE, database)
        .option(KustoSinkOptions.KUSTO_TABLE, table)

package org.sparksamples.regression.bikesharing

import org.apache.log4j.Logger
import{VectorAssembler, VectorIndexer}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{SparkSession, _}

/** Bike-sharing regression examples built on Spark ML's `GeneralizedLinearRegression`. */
object GeneralizedLinearRegressionPipeline {

  // Not referenced in the visible code. `@transient lazy` keeps it out of
  // serialized state and creates it on first use.
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  /**
   * Trains a VectorAssembler -> VectorIndexer -> GeneralizedLinearRegression
   * pipeline on a random 80% split of `dataFrame` (seed 12345). It then prints the
   * RMSE of the model on the remaining 20%.
   *
   * NOTE(review): several expressions are truncated in this copy:
   *   - the right-hand side of `val model =` (presumably `pipeline.fit(training)`);
   *   - the receivers of the two `...select(...)` calls;
   *   - the `RegressionMetrics` arguments.
   * As written, `$RMSE` interpolates the `RegressionMetrics` object itself rather
   * than a number — confirm against upstream.
   */
  def genLinearRegressionWithVectorFormat(vectorAssembler: VectorAssembler, vectorIndexer: VectorIndexer, dataFrame: DataFrame) = {
    val lr = new GeneralizedLinearRegression()

    val pipeline = new Pipeline().setStages(Array(vectorAssembler, vectorIndexer, lr))

    val Array(training, test) = dataFrame.randomSplit(Array(0.8, 0.2), seed = 12345)

    val model =

    val fullPredictions = model.transform(test).cache()
    val predictions ="prediction")
    val labels ="label")
    val RMSE = new RegressionMetrics(
    println(s"  Root mean squared error (RMSE): $RMSE")

  /**
   * Loads libsvm-format training data and fits a GeneralizedLinearRegression. It
   * then prints the model's coefficients, intercept and training-summary
   * statistics.
   *
   * NOTE(review): the reader call for `training` and the right-hand side of
   * `val model =` are truncated in this copy — confirm against upstream.
   */
  def genLinearRegressionWithSVMFormat(spark: SparkSession) = {
    // Load training data
    val training ="libsvm")

    val lr = new GeneralizedLinearRegression()

    // Fit the model
    val model =

    // Print the coefficients and intercept for generalized linear regression model
    println(s"Coefficients: ${model.coefficients}")
    println(s"Intercept: ${model.intercept}")

    // Summarize the model over the training set and print out some metrics
    val summary = model.summary
    println(s"Coefficient Standard Errors: ${summary.coefficientStandardErrors.mkString(",")}")
    println(s"T Values: ${summary.tValues.mkString(",")}")
    println(s"P Values: ${summary.pValues.mkString(",")}")
    println(s"Dispersion: ${summary.dispersion}")
    println(s"Null Deviance: ${summary.nullDeviance}")
    println(s"Residual Degree Of Freedom Null: ${summary.residualDegreeOfFreedomNull}")
    println(s"Deviance: ${summary.deviance}")
    println(s"Residual Degree Of Freedom: ${summary.residualDegreeOfFreedom}")
    println(s"AIC: ${summary.aic}")
    println("Deviance Residuals: ")
    summary.residuals().show()  }

package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Streaming source and sink provider with a fixed one-column schema (`a: Int`).
 *
 * The source always reports offset 0, and the sink ignores every batch it
 * receives. The companion's `latch` is presumably awaited inside `getBatch`, but
 * that body is truncated in this copy — confirm against upstream.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // The only schema this provider exposes: a single IntegerType column "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Ignores any user-supplied schema and reports `fakeSchema` under the name "dummySource".
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source whose offset never advances past 0.
  // NOTE(review): `getBatch` is truncated after the implicits import, and the
  // closing braces of these definitions are missing from this copy.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns a Sink that discards every batch it receives.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding a process-wide, mutable `CountDownLatch`. */
object BlockingSource {
  // Starts as null; no assignment to it is visible in this file.
  var latch: CountDownLatch = null
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Streaming source and sink provider with a fixed one-column schema (`a: Int`).
 *
 * The source always reports offset 0, and the sink ignores every batch it
 * receives. The companion's `latch` is presumably awaited inside `getBatch`, but
 * that body is truncated in this copy — confirm against upstream.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // The only schema this provider exposes: a single IntegerType column "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Ignores any user-supplied schema and reports `fakeSchema` under the name "dummySource".
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source whose offset never advances past 0.
  // NOTE(review): `getBatch` is truncated after the implicits import, and the
  // closing braces of these definitions are missing from this copy.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns a Sink that discards every batch it receives.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding a process-wide, mutable `CountDownLatch`. */
object BlockingSource {
  // Starts as null; no assignment to it is visible in this file.
  var latch: CountDownLatch = null
package org.apache.spark.sql.hive


import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
import org.apache.spark.util.Utils

/**
 * Regression test for SPARK-5068: a partitioned Hive table must still be
 * queryable after one partition's directory has been deleted from disk.
 */
class QueryPartitionSuite extends QueryTest {
  import org.apache.spark.sql.hive.test.TestHive.implicits._

  test("SPARK-5068: query data when path doesn't exist"){
    // Ten (key, value) rows: (1, "1") ... (10, "10").
    val testData = TestHive.sparkContext.parallelize(
      (1 to 10).map(i => TestData(i, i.toString))).toDF()

    // NOTE(review): two pieces of setup are not visible in this copy — confirm:
    //   - no import for `Files` is visible;
    //   - the SQL below reads a `testData` table whose registration is not shown.
    val tmpDir = Files.createTempDir()
    // create the test table, partitioned by `ds` and stored under `tmpDir`
    sql(s"CREATE TABLE table_with_partition(key int,value string) " +
      s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
    // fill four partitions with the same ten rows each
    sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
      "SELECT key,value FROM testData")
    sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
      "SELECT key,value FROM testData")
    sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
      "SELECT key,value FROM testData")
    sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
      "SELECT key,value FROM testData")

    // while every partition path exists, all four copies of the data are returned
    checkAnswer(sql("select key,value from table_with_partition"),
      testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
        ++ testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect)

    // delete the directory of one partition (the first `ds=` directory found)
    // NOTE(review): the receiver of `.find` (presumably `tmpDir.listFiles`) is
    // missing from this copy.
      .find { f => f.isDirectory && f.getName().startsWith("ds=") }
      .foreach { f => Utils.deleteRecursively(f) }

    // after deleting that path, only the three remaining partitions' rows are expected
    checkAnswer(sql("select key,value from table_with_partition"),
      testData.toSchemaRDD.collect ++ testData.toSchemaRDD.collect
        ++ testData.toSchemaRDD.collect)

    // NOTE(review): cleanup has two oddities:
    //   - `createAndInsertTest` is not created in this test;
    //   - `tmpDir` itself is never removed.
    sql("DROP TABLE table_with_partition")
    sql("DROP TABLE createAndInsertTest")
package org.apache.spark.sql.hive

import org.apache.spark.sql.test.SQLTestUtils

import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.util.Utils

/**
 * Regression test for SPARK-5068: with partition-path verification enabled, a
 * partitioned Hive table must still be queryable after one partition's directory
 * has been deleted from disk.
 */
class QueryPartitionSuite extends QueryTest with SQLTestUtils {

  private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
  import ctx.implicits._

  // Exposes the Hive test context (presumably the hook SQLTestUtils reads — confirm).
  protected def _sqlContext = ctx
  test("SPARK-5068: query data when path doesn't exist"){
    withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) {
      // Ten (key, value) rows: (1, "1") ... (10, "10").
      val testData = ctx.sparkContext.parallelize(
        (1 to 10).map(i => TestData(i, i.toString))).toDF()

      // NOTE(review): two pieces of setup are not visible in this copy — confirm:
      //   - no import for `Files` is visible;
      //   - the SQL below reads a `testData` table whose registration is not shown.
      val tmpDir = Files.createTempDir()
      // create the table for the test
      sql(s"CREATE TABLE table_with_partition(key int,value string) " +
        s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
        "SELECT key,value FROM testData")

      // test while every partition path exists: four copies of the data
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect
          ++ testData.toDF.collect ++ testData.toDF.collect)

      // delete the path of one partition (the first `ds=` directory found)
      // NOTE(review): the receiver of `.find` (presumably `tmpDir.listFiles`) is
      // missing from this copy.
        .find { f => f.isDirectory && f.getName().startsWith("ds=") }
        .foreach { f => Utils.deleteRecursively(f) }

      // test after deleting the path: three copies of the data remain
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)

      // NOTE(review): `createAndInsertTest` is not created in this test.
      sql("DROP TABLE table_with_partition")
      sql("DROP TABLE createAndInsertTest")
package org.apache.spark.sql.hbase

import org.apache.spark.sql.types._
import org.apache.spark.sql.{SQLConf, _}

/**
 * Advanced SQL query tests run through the `TestHbase` context against the `ta`
 * test table (presumably created by `TestBaseWithSplitData`, which is not visible
 * here — TODO confirm).
 */
class HBaseAdvancedSQLQuerySuite extends TestBaseWithSplitData {

  import org.apache.spark.sql.hbase.TestHbase._
  import org.apache.spark.sql.hbase.TestHbase.implicits._

  // Enables codegen, checks that `GROUP BY col1` over `ta` yields 14 rows, then
  // restores the previous codegen setting.
  // NOTE(review): the restore is not in a `finally`, so a failing assertion leaves
  // codegen enabled for the tests that run afterwards.
  test("aggregation with codegen") {
    val originalValue = TestHbase.conf.codegenEnabled
    setConf(SQLConf.CODEGEN_ENABLED, "true")
    val result = sql("SELECT col1 FROM ta GROUP BY col1").collect()
    assert(result.length == 14, s"aggregation with codegen test failed on size")
    setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)

  // Filters, orders and projects `ta` with the DataFrame DSL; each query is paired
  // with what looks like its expected single-row result.
  // NOTE(review): the enclosing `checkAnswer(` calls for the two argument lists
  // below appear to be missing from this copy — confirm against upstream.
  test("dsl simple select 0") {
    val tableA = sql("SELECT * FROM ta")
      tableA.where('col7 === 1).orderBy('col2.asc).select('col4),
      Row(1) :: Nil)
      tableA.where('col2 === 6).orderBy('col2.asc).select('col7),
      Row(-31) :: Nil)

  // Attaches a "doc" metadata entry to `col1` and checks that it survives DSL
  // projection, SQL selection and joins.
  // NOTE(review): several lines appear to be missing here.
  //   - `metadata` is a `MetadataBuilder` that is never `.build()`-ed before being
  //     passed to `StructField.copy`.
  //   - The closing brace of `validateMetadata` is absent.
  //   - The first `validateMetadata(...)` call has lost its receiver (presumably
  //     `personWithMeta.select(`).
  //   - No registration of the `personWithMeta` or `salary` tables is visible.
  test("metadata is propagated correctly") {
    val tableA = sql("SELECT col7, col1, col3 FROM ta")
    val schema = tableA.schema
    val docKey = "doc"
    val docValue = "first name"
    val metadata = new MetadataBuilder()
      .putString(docKey, docValue)
    val schemaWithMeta = new StructType(Array(
      schema("col7"), schema("col1").copy(metadata = metadata), schema("col3")))
    val personWithMeta = createDataFrame(tableA.rdd, schemaWithMeta)
    // Asserts that `col1` of the given DataFrame still carries the doc metadata.
    def validateMetadata(rdd: DataFrame): Unit = {
      assert(rdd.schema("col1").metadata.getString(docKey) == docValue)
    validateMetadata($"col7", $"col1"))
    validateMetadata(sql("SELECT * FROM personWithMeta"))
    validateMetadata(sql("SELECT col7, col1 FROM personWithMeta"))
    validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON col7 = personId"))
    validateMetadata(sql("SELECT col1, salary FROM personWithMeta JOIN salary ON col7 = personId"))
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Streaming source and sink provider with a fixed one-column schema (`a: Int`).
 *
 * The source always reports offset 0, and the sink ignores every batch it
 * receives. The companion's `latch` is presumably awaited inside `getBatch`, but
 * that body is truncated in this copy — confirm against upstream.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // The only schema this provider exposes: a single IntegerType column "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Ignores any user-supplied schema and reports `fakeSchema` under the name "dummySource".
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  // Returns an anonymous Source whose offset never advances past 0.
  // NOTE(review): `getBatch` is truncated after the implicits import, and the
  // closing braces of these definitions are missing from this copy.
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  // Returns a Sink that discards every batch it receives.
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding a process-wide, mutable `CountDownLatch`. */
object BlockingSource {
  // Starts as null; no assignment to it is visible in this file.
  var latch: CountDownLatch = null
package org.apache.spark.sql.hive


import org.apache.spark.util.Utils
import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils

/**
 * Regression test for SPARK-5068: with partition-path verification enabled, a
 * partitioned Hive table must still be queryable after one partition's directory
 * has been deleted from disk.
 */
class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  import hiveContext.implicits._

  test("SPARK-5068: query data when path doesn't exist") {
    withSQLConf((SQLConf.HIVE_VERIFY_PARTITION_PATH.key, "true")) {
      // Ten (key, value) rows: (1, "1") ... (10, "10").
      val testData = sparkContext.parallelize(
        (1 to 10).map(i => TestData(i, i.toString))).toDF()

      // NOTE(review): two pieces of setup are not visible in this copy — confirm:
      //   - no import for `Files` is visible;
      //   - the SQL below reads a `testData` table whose registration is not shown.
      val tmpDir = Files.createTempDir()
      // create the table for test
      sql(s"CREATE TABLE table_with_partition(key int,value string) " +
        s"PARTITIONED by (ds string) location '${tmpDir.toURI.toString}' ")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='1') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='2') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='3') " +
        "SELECT key,value FROM testData")
      sql("INSERT OVERWRITE TABLE table_with_partition  partition (ds='4') " +
        "SELECT key,value FROM testData")

      // test while every partition path exists: four copies of the data
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect
          ++ testData.toDF.collect ++ testData.toDF.collect)

      // delete the path of one partition (the first `ds=` directory found)
      // NOTE(review): the receiver of `.find` (presumably `tmpDir.listFiles`) is
      // missing from this copy.
        .find { f => f.isDirectory && f.getName().startsWith("ds=") }
        .foreach { f => Utils.deleteRecursively(f) }

      // test after deleting the path: three copies of the data remain
      checkAnswer(sql("select key,value from table_with_partition"),
        testData.toDF.collect ++ testData.toDF.collect ++ testData.toDF.collect)

      // NOTE(review): `createAndInsertTest` is not created in this test.
      sql("DROP TABLE table_with_partition")
      sql("DROP TABLE createAndInsertTest")
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{ SQLContext, _ }

/** Companion holding the name of the current-events-by-persistence-id source. */
object CurrentEventsByPersistenceIdQuerySourceProvider {
  // Presumably what the provider's `shortName()` returns (its body is missing in this copy).
  val name = "current-events-by-persistence-id"

/**
 * Structured Streaming source provider for the current events of a single
 * persistence id.
 *
 * Options read by `createSource`:
 *   - `event-mapper` (required);
 *   - `pid` or `persistence-id` (one of them is required; `pid` wins if both are set);
 *   - `path`, which is passed on as the read-journal plugin id.
 */
class CurrentEventsByPersistenceIdQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {
  // Requires a user-supplied schema: `schema.get` throws when none was given.
  // NOTE(review): as written, the left side of `->` is the result of `println(...)`
  // (Unit), which does not match the declared `(String, StructType)`. The line that
  // produced the source name appears to have been lost — confirm.
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) = {
    println(s"[CurrentEventsByPersistenceIdQuerySourceProvider.sourceSchema]: schema: $schema, providerName: $providerName, parameters: $parameters") -> schema.get

  // Validates the options and builds the source.
  // Throws RuntimeException when `event-mapper` or the persistence id is missing.
  // `parameters("path")` and `schema.get` also throw when the path or schema is absent.
  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {

    val eventMapperFQCN: String = parameters.get("event-mapper") match {
      case Some(_eventMapper) => _eventMapper
      case _                  => throw new RuntimeException("No event mapper FQCN")

    // `pid` takes precedence over `persistence-id` when both are present.
    val pid = (parameters.get("pid"), parameters.get("persistence-id")) match {
      case (Some(pid), _) => pid
      case (_, Some(pid)) => pid
      case _              => throw new RuntimeException("No persistence_id")

    new CurrentEventsByPersistenceIdQuerySourceImpl(sqlContext, parameters("path"), eventMapperFQCN, pid, schema.get)
  // NOTE(review): body missing in this copy; presumably returns the companion's `name`.
  override def shortName(): String =

/**
 * Source serving the events of `persistenceId` from the read journal identified
 * by `readJournalPluginId`. `eventMapperFQCN` is passed through to
 * `eventsByPersistenceId`. The offset and batch helpers come from
 * `ReadJournalSource`, which is not visible here.
 */
class CurrentEventsByPersistenceIdQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String, eventMapperFQCN: String, persistenceId: String, override val schema: StructType) extends Source with ReadJournalSource {
  // Looks up the highest available offset for `persistenceId` and logs it.
  // NOTE(review): the returned expression is missing in this copy.
  override def getOffset: Option[Offset] = {
    val offset = maxEventsByPersistenceId(persistenceId)
    println("[CurrentEventsByPersistenceIdQuery]: Returning maximum offset: " + offset)

  // Loads the events between the resolved start and end offsets.
  // NOTE(review): three issues, confirm against upstream:
  //   - `df.count` in the log line runs an extra Spark job purely for logging;
  //   - the message still says "currentPersistenceIds";
  //   - the `df` return is missing from this copy.
  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    val (start, end) = getStartEnd(_start, _end)
    val df: DataFrame = eventsByPersistenceId(persistenceId, start, end, eventMapperFQCN)
    println(s"[CurrentEventsByPersistenceIdQuery]: Getting currentPersistenceIds from start: $start, end: $end, DataFrame.count: ${df.count}")
Example 14
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ SQLContext, _ }

/** Name and fixed schema of the current-persistence-ids source. */
object CurrentPersistenceIdsQuerySourceProvider {
  // Presumably what the provider's `shortName()` returns (its body is missing in this copy).
  val name = "current-persistence-id"
  // A single non-nullable string column holding each persistence id.
  val schema: StructType = StructType(Array(
    StructField("persistence_id", StringType, nullable = false)

/**
 * Structured Streaming source provider that lists current persistence ids.
 *
 * The `path` option is passed on as the read-journal plugin id. The source is
 * built without the user-supplied schema; it always uses the companion's fixed
 * schema.
 */
class CurrentPersistenceIdsQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {
  // Reports the companion's fixed schema.
  // NOTE(review): the left operand of `->` (presumably the source name) is missing
  // in this copy — confirm.
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) = { -> CurrentPersistenceIdsQuerySourceProvider.schema

  // Builds the source; `parameters("path")` throws when the `path` option is absent.
  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {
    new CurrentPersistenceIdsQuerySourceImpl(sqlContext, parameters("path"))
  // NOTE(review): body missing in this copy; presumably returns the companion's `name`.
  override def shortName(): String =

/**
 * Source serving persistence ids from the read journal identified by
 * `readJournalPluginId`. The offset and batch helpers come from
 * `ReadJournalSource`, which is not visible here.
 */
class CurrentPersistenceIdsQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String) extends Source with ReadJournalSource {
  // Always the companion's fixed single-column schema.
  override def schema: StructType = CurrentPersistenceIdsQuerySourceProvider.schema

  // Looks up the highest available offset and logs it.
  // NOTE(review): the returned expression is missing in this copy.
  override def getOffset: Option[Offset] = {
    val offset = maxPersistenceIds
    println("[CurrentPersistenceIdsQuery]: Returning maximum offset: " + offset)

  // Converts the persistence ids between the resolved offsets into a DataFrame.
  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    val (start, end) = getStartEnd(_start, _end)
    println(s"[CurrentPersistenceIdsQuery]: Getting currentPersistenceIds from start: $start, end: $end")
    import sqlContext.implicits._
    persistenceIds(start, end).toDF()