org.apache.spark.sql.execution.streaming.Sink Scala Examples

package net.snowflake.spark.snowflake

import net.snowflake.spark.snowflake.streaming.SnowflakeSink
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_SHORT_NAME
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.slf4j.LoggerFactory

  // Write path: decides from `saveMode` (and whether the target table exists) if `data`
  // should be saved, then returns the relation produced by the two-argument
  // createRelation overload.
  // NOTE(review): this listing has lost lines -- closing braces/parens, the body of the
  // `if (params.autoPushdown)` check, the `finally` body and the call made on the
  // SnowflakeWriter are not shown. Restore them from the upstream file before compiling.
  override def createRelation(sqlContext: SQLContext,
                              saveMode: SaveMode,
                              parameters: Map[String, String],
                              data: DataFrame): BaseRelation = {

    val params = Parameters.mergeParameters(parameters)
    // check spark version for push down
    if (params.autoPushdown) {
    // pass parameters to pushdown functions
    // The target table is mandatory for saves; its absence raises IllegalArgumentException.
    val table = params.table.getOrElse {
      throw new IllegalArgumentException(
        "For save operations you must specify a Snowfake table name with the 'dbtable' parameter"

    // Asks jdbcWrapper whether `table` exists, using a connection obtained for `params`.
    def tableExists: Boolean = {
      val conn = jdbcWrapper.getConnector(params)
      try {
        jdbcWrapper.tableExists(conn, table.toString)
      } finally {

    // doSave: whether anything is written; dropExisting: forwarded below as "overwrite".
    val (doSave, dropExisting) = saveMode match {
      case SaveMode.Append => (true, false)
      case SaveMode.Overwrite => (true, true)
      case SaveMode.ErrorIfExists =>
        if (tableExists) {
            // NOTE(review): presumably the argument of an error-raising call whose
            // opening line is missing here -- confirm.
            s"Table $table already exists! (SaveMode is set to ErrorIfExists)"
        } else {
          (true, false)
      case SaveMode.Ignore =>
        if (tableExists) {
// NOTE(review): the call prefix (and the `s` interpolator) for this message is missing.
"Table $table already exists -- ignoring save request.")
          (false, false)
        } else {
          (true, false)

    if (doSave) {
      // The raw (unmerged) parameters are extended with the overwrite flag.
      val updatedParams = parameters.updated("overwrite", dropExisting.toString)
      new SnowflakeWriter(jdbcWrapper)


    createRelation(sqlContext, parameters)

  /** Builds the streaming sink; all four arguments are forwarded unchanged to SnowflakeSink. */
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    val snowflakeSink = new SnowflakeSink(sqlContext, parameters, partitionColumns, outputMode)
    snowflakeSink
  }
package com.samelamin.spark.bigquery.streaming

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.streaming.Sink
import com.samelamin.spark.bigquery._
import org.slf4j.LoggerFactory
import scala.util.Try
import org.apache.hadoop.fs.Path

/**
 * Streaming sink that saves each micro-batch into a BigQuery table.
 *
 * Recognised options:
 *  - `tableReferenceSink` (required): fully qualified id of the output table.
 *  - `partitionByDay` (optional, default `true`): passed to `saveAsBigQueryTable`.
 *
 * A batch whose id is not greater than the latest id reported by the sink log under
 * `path` is logged and skipped.
 *
 * @param sparkSession session handed to the sink log
 * @param path         base directory under which the sink log file lives
 * @param options      sink options, see above
 */
class BigQuerySink(sparkSession: SparkSession, path: String, options: Map[String, String]) extends Sink {
  private val logger = LoggerFactory.getLogger(classOf[BigQuerySink])
  private val basePath = new Path(path)
  private val logPath = new Path(basePath, new Path(BigQuerySink.metadataDir, "transaction.json"))

  private val fileLog = new BigQuerySinkLog(sparkSession, logPath.toUri.toString)

  /**
   * Writes `data` to the configured table unless `batchId` was already committed.
   *
   * @throws NoSuchElementException if the `tableReferenceSink` option is absent
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= fileLog.getLatest().getOrElse(-1L)) {
      logger.info(s"Skipping already committed batch $batchId")
    } else {
      // Same exception type as the former `Option.get`, but the message names the option.
      val fullyQualifiedOutputTableId = options.getOrElse(
        "tableReferenceSink",
        throw new NoSuchElementException("Missing required option 'tableReferenceSink'"))
      // An absent or unparsable value falls back to `true`.
      val isPartitionByDay =
        options.get("partitionByDay").flatMap(raw => Try(raw.toBoolean).toOption).getOrElse(true)

      val bqDF = new BigQueryDataFrame(data)
      bqDF.saveAsBigQueryTable(fullyQualifiedOutputTableId, isPartitionByDay)
      // NOTE(review): nothing in this listing records batchId in fileLog after the write,
      // so the skip check above may never fire -- confirm against the upstream source.
    }
  }
}

/** Constants shared by the BigQuery streaming sink. */
object BigQuerySink {
  // The name of the subdirectory that is used to store metadata about which files are valid.
  val metadataDir: String = "_spark_metadata"
}
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/** Stream sink provider registered under the short name `kudu`. */
class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  /**
   * Creates the Kudu sink for a streaming query.
   *
   * `partitionColumns` is not used. The sink itself is built by
   * `KuduSink.withDefaultContext` from `sqlContext` and `parameters`.
   *
   * @throws IllegalArgumentException if `outputMode` is not the update mode
   */
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    require(outputMode == OutputMode.Update, "only 'update' OutputMode is supported")
    KuduSink.withDefaultContext(sqlContext, parameters)
  }

  override def shortName(): String = "kudu"
}
package com.cloudera.streaming.refapp.kudu

import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

object KuduSink {
  /**
   * Builds a [[KuduSink]] whose context is created from the `kudu.master` parameter and
   * the Spark context of `sqlContext`.
   *
   * The sink's first constructor parameter is by-name, so the `new KuduContext(...)`
   * expression (including the `kudu.master` lookup, which throws
   * `NoSuchElementException` when the key is absent) is evaluated by the sink each time
   * it needs a context rather than once here.
   */
  def withDefaultContext(sqlContext: SQLContext, parameters: Map[String, String]): KuduSink =
    new KuduSink(new KuduContext(parameters("kudu.master"), sqlContext.sparkContext), parameters)
}

/**
 * Streaming sink that upserts every micro-batch into a Kudu table.
 *
 * Parameters read from `parameters`:
 *  - `kudu.table` (required): destination table name.
 *  - `retries` (optional, default `"1"`): number of additional attempts after a failed
 *    upsert; must parse as a non-negative integer.
 *
 * @param initKuduContext by-name factory; re-evaluated to obtain a new context before
 *                        each retry
 * @param parameters      sink options, see above
 */
class KuduSink(initKuduContext: => KuduContext, parameters: Map[String, String]) extends Sink {

  private val logger = LoggerFactory.getLogger(getClass)

  private var kuduContext = initKuduContext

  private val tablename = parameters("kudu.table")

  private val retries = parameters.getOrElse("retries", "1").toInt
  require(retries >= 0, "retries must be non-negative")

  logger.info(s"Created Kudu sink writing to table $tablename")

  /**
   * Upserts `data` into the table, making at most `retries + 1` attempts.
   *
   * The loop ends as soon as one attempt succeeds. After a non-fatal failure the context
   * is rebuilt and the upsert is tried again; once the attempts are exhausted the last
   * error is logged and rethrown. Fatal errors are never caught.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    var attempt = 0
    var written = false
    while (!written) {
      try {
        kuduContext.upsertRows(data, tablename)
        // Stop here: without this the batch would be upserted again on every iteration.
        written = true
      } catch {
        case NonFatal(e) if attempt < retries =>
          logger.warn("Kudu upsert error, retrying...", e)
          kuduContext = initKuduContext
          attempt += 1
        case NonFatal(e) =>
          logger.error("Kudu upsert error, exhausted", e)
          throw e
      }
    }
  }
}
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/** Stream sink provider registered under the short name `kudu`. */
class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  /**
   * Creates the Kudu sink for a streaming query.
   *
   * `partitionColumns` is not used. The sink itself is built by
   * `KuduSink.withDefaultContext` from `sqlContext` and `parameters`.
   *
   * @throws IllegalArgumentException if `outputMode` is not the update mode
   */
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    require(outputMode == OutputMode.Update, "only 'update' OutputMode is supported")
    KuduSink.withDefaultContext(sqlContext, parameters)
  }

  override def shortName(): String = "kudu"
}
package com.cloudera.streaming.refapp.kudu

import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

object KuduSink {
  /**
   * Builds a [[KuduSink]] whose context is created from the `kudu.master` parameter and
   * the Spark context of `sqlContext`.
   *
   * The sink's first constructor parameter is by-name, so the `new KuduContext(...)`
   * expression (including the `kudu.master` lookup, which throws
   * `NoSuchElementException` when the key is absent) is evaluated by the sink each time
   * it needs a context rather than once here.
   */
  def withDefaultContext(sqlContext: SQLContext, parameters: Map[String, String]): KuduSink =
    new KuduSink(new KuduContext(parameters("kudu.master"), sqlContext.sparkContext), parameters)
}

/**
 * Streaming sink that upserts every micro-batch into a Kudu table.
 *
 * Parameters read from `parameters`:
 *  - `kudu.table` (required): destination table name.
 *  - `retries` (optional, default `"1"`): number of additional attempts after a failed
 *    upsert; must parse as a non-negative integer.
 *
 * @param initKuduContext by-name factory; re-evaluated to obtain a new context before
 *                        each retry
 * @param parameters      sink options, see above
 */
class KuduSink(initKuduContext: => KuduContext, parameters: Map[String, String]) extends Sink {

  private val logger = LoggerFactory.getLogger(getClass)

  private var kuduContext = initKuduContext

  private val tablename = parameters("kudu.table")

  private val retries = parameters.getOrElse("retries", "1").toInt
  require(retries >= 0, "retries must be non-negative")

  logger.info(s"Created Kudu sink writing to table $tablename")

  /**
   * Upserts `data` into the table, making at most `retries + 1` attempts.
   *
   * The loop ends as soon as one attempt succeeds. After a non-fatal failure the context
   * is rebuilt and the upsert is tried again; once the attempts are exhausted the last
   * error is logged and rethrown. Fatal errors are never caught.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    var attempt = 0
    var written = false
    while (!written) {
      try {
        kuduContext.upsertRows(data, tablename)
        // Stop here: without this the batch would be upserted again on every iteration.
        written = true
      } catch {
        case NonFatal(e) if attempt < retries =>
          logger.warn("Kudu upsert error, retrying...", e)
          kuduContext = initKuduContext
          attempt += 1
        case NonFatal(e) =>
          logger.error("Kudu upsert error, exhausted", e)
          throw e
      }
    }
  }
}
package com.couchbase.spark.sql.streaming

import com.couchbase.spark.Logging
import org.apache.spark.sql.{DataFrame, SaveMode}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types.StringType
import com.couchbase.spark.sql._
import com.couchbase.spark._
import com.couchbase.client.core.CouchbaseException
import scala.concurrent.duration._

// Streaming sink that converts JSON strings into JsonDocuments and upserts them into a
// Couchbase bucket (see the final saveToCouchbase call with StoreMode.UPSERT).
// NOTE(review): this listing has lost lines -- the receiver of the `.map` chain, the bodies
// and closing braces of the `if` blocks, and all closing braces of the lambda, method and
// class are not shown. Restore them from the upstream file before compiling.
class CouchbaseSink(options: Map[String, String]) extends Sink with Logging {

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    // null when the "bucket" option is absent.
    val bucketName = options.get("bucket").orNull
    val idFieldName = options.getOrElse("idField", DefaultSource.DEFAULT_DOCUMENT_ID_FIELD)
    val removeIdField = options.getOrElse("removeIdField", "true").toBoolean
    // The "timeout" option, when present, is read as a number of milliseconds.
    val timeout = options.get("timeout").map(v => Duration(v.toLong, MILLISECONDS))

    // Document factory: passes the "expiry" option to JsonDocument.create when it is set,
    // otherwise uses the two-argument create.
    val createDocument = options.get("expiry").map(_.toInt)
      .map(expiry => (id: String, content: JsonObject) => JsonDocument.create(id, expiry, content))
      .getOrElse((id: String, content: JsonObject) => JsonDocument.create(id, content))

      // NOTE(review): the expression this chain is applied to is missing here.
      .map(_.get(0, StringType).asInstanceOf[UTF8String].toString())
      .map { rawJson =>
          val encoded = JsonObject.fromJson(rawJson)
          val id = encoded.get(idFieldName)

          // Every document must carry the configured id field.
          if (id == null) {
              throw new Exception(s"Could not find ID field $idFieldName in $encoded")

          // NOTE(review): body missing -- presumably removes idFieldName from `encoded`; confirm.
          if (removeIdField) {

          createDocument(id.toString, encoded)
      .saveToCouchbase(bucketName, StoreMode.UPSERT, timeout)

package com.qubole.spark.hiveacid.streaming

import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidTable}
import com.qubole.spark.hiveacid.hive.HiveAcidMetadata
import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.streaming.Sink

// Streaming sink that adds each micro-batch to a Hive ACID table and records the
// (batchId, txnId) commit through HiveAcidStreamingCommitProtocol backed by a sink log.
// NOTE(review): this listing has lost lines -- the arguments of
// HiveAcidTable.fromSparkSession and most closing braces are not shown, and nothing
// visible calls assertNonBucketedTable(). Restore from the upstream file before compiling.
class HiveAcidSink(sparkSession: SparkSession,
                   parameters: Map[String, String]) extends Sink with Logging {

  import HiveAcidSink._

  private val acidSinkOptions = new HiveAcidSinkOptions(parameters)

  private val fullyQualifiedTableName = acidSinkOptions.tableName

  private val hiveAcidTable: HiveAcidTable = HiveAcidTable.fromSparkSession(


  private val logPath = getMetaDataPath()
  private val fileLog = new HiveAcidSinkLog(
    HiveAcidSinkLog.VERSION, sparkSession, logPath.toUri.toString, acidSinkOptions)

  // Rejects bucketed tables for streaming writes.
  private def assertNonBucketedTable(): Unit = {
    if(hiveAcidTable.isBucketed) {
      throw HiveAcidErrors.unsupportedOperationTypeBucketedTable("Streaming Write", fullyQualifiedTableName)

  // Resolves the metadata directory: the configured one when present, otherwise
  // "<metadataDirPrefix>/_query_default" under the table's root path.
  private def getMetaDataPath(): Path = {
    acidSinkOptions.metadataDir match {
      case Some(dir) =>
        new Path(dir)
      case None =>
        logInfo(s"Metadata dir not specified. Using " +
          s"$metadataDirPrefix/_query_default as metadata dir")
        // The default directory is the same for every query on this table, hence the warning.
        logWarning(s"Please make sure that multiple streaming writes to " +
          s"$fullyQualifiedTableName are not running")
        val tableLocation = HiveAcidMetadata.fromSparkSession(
          sparkSession, fullyQualifiedTableName).rootPath
        new Path(tableLocation, s"$metadataDirPrefix/_query_default")

  override def addBatch(batchId: Long, df: DataFrame): Unit = {

    // Skip batches whose id is not greater than the latest id found in the sink log.
    if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {

      val commitProtocol = new HiveAcidStreamingCommitProtocol(fileLog)
      // The id returned by the table write is committed together with the batch id.
      val txnId = hiveAcidTable.addBatch(df)
      commitProtocol.commitJob(batchId, txnId)


  override def toString: String = s"HiveAcidSinkV1[$fullyQualifiedTableName]"


/** Constants shared by the Hive ACID streaming sink. */
object HiveAcidSink {

  // Directory name used as the prefix of the sink's default metadata location.
  val metadataDirPrefix: String = "_acid_streaming"
}
package com.qubole.spark.hiveacid.datasource

import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidTable}
import com.qubole.spark.hiveacid.streaming.HiveAcidSink

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode

// Data source entry point for Hive ACID tables: provides read relations, batch writes
// and a streaming sink.
// NOTE(review): this listing has lost lines -- the arguments of
// HiveAcidTable.fromSparkSession, the bodies of the SaveMode cases, the body of
// shortName() and most closing braces are not shown. Restore from the upstream file
// before compiling.
class HiveAcidDataSource
  extends RelationProvider          // USING HiveAcid
    with CreatableRelationProvider  // Insert into/overwrite
    with DataSourceRegister         // FORMAT("HiveAcid")
    with StreamSinkProvider
    with Logging {

  // returns relation for passed in table name
  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    HiveAcidRelation(sqlContext.sparkSession, getFullyQualifiedTableName(parameters), parameters)

  // returns relation after writing passed in data frame. Table name is part of parameter
  override def createRelation(sqlContext: SQLContext,
                              mode: SaveMode,
                              parameters: Map[String, String],
                              df: DataFrame): BaseRelation = {

    val hiveAcidTable: HiveAcidTable = HiveAcidTable.fromSparkSession(

    // NOTE(review): the action taken for each mode is missing from this listing.
    mode match {
      case SaveMode.Overwrite =>
      case SaveMode.Append =>
      // TODO: Add support for these
      case SaveMode.ErrorIfExists | SaveMode.Ignore =>
    createRelation(sqlContext, parameters)

  override def shortName(): String = {

  // Validates the requested sink configuration, then builds the streaming sink.
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {

    tableSinkAssertions(partitionColumns, outputMode)

    new HiveAcidSink(sqlContext.sparkSession, parameters)

  // Rejects partitionBy columns and every output mode other than Append.
  private def tableSinkAssertions(partitionColumns: Seq[String], outputMode: OutputMode): Unit = {

    if (partitionColumns.nonEmpty) {
      throw HiveAcidErrors.unsupportedFunction("partitionBy", "HiveAcidSink")
    if (outputMode != OutputMode.Append) {
      throw HiveAcidErrors.unsupportedStreamingOutputMode(s"$outputMode")


  // Reads the "table" parameter; its absence raises tableNotSpecifiedException.
  private def getFullyQualifiedTableName(parameters: Map[String, String]): String = {
    parameters.getOrElse("table", {
      throw HiveAcidErrors.tableNotSpecifiedException()

/** Constants for the Hive ACID data source. */
object HiveAcidDataSource {
  val NAME: String = "HiveAcid"
}
package cassandra.StreamSinkProvider

import cassandra.{CassandraDriver, CassandraKafkaMetadata}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.functions.max
import spark.SparkHelper
import cassandra.CassandraDriver
import com.datastax.spark.connector._
import kafka.KafkaMetadata
import log.LazyLogger
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.types.LongType
import radio.SimpleSongAggregation

  // Judging by its name and log message, stores Kafka partition/offset metadata taken
  // from `df`.
  // NOTE(review): fragment -- the transformation applied to `df`, the call that receives
  // SomeColumns("partition", "offset") and the closing brace are not shown.
  private def saveKafkaMetaData(df: DataFrame) = {
    val kafkaMetadata = df

    log.warn("Saving Kafka Metadata (partition and offset per topic (only one in our example)")

      SomeColumns("partition", "offset")

    // Another way to save the offset inside Cassandra
package org.apache.spark.sql.kinesis

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.OutputMode

/**
 * Streaming sink that hands each micro-batch to `KinesisWriter`.
 *
 * A batch whose id is not greater than the id of the last batch written by this
 * instance is logged and skipped. `outputMode` is accepted but not used here.
 *
 * @param sqlContext  supplies the Spark session passed to the writer
 * @param sinkOptions options forwarded to the writer unchanged
 * @param outputMode  output mode of the query
 */
private[kinesis] class KinesisSink(sqlContext: SQLContext,
                                   sinkOptions: Map[String, String],
                                   outputMode: OutputMode)
  extends Sink with Logging {

  // Held in this instance only; advanced after the writer returns normally.
  @volatile private var latestBatchId = -1L

  override def toString: String = "KinesisSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {
      KinesisWriter.write(sqlContext.sparkSession, data.queryExecution, sinkOptions)
      latestBatchId = batchId
    }
  }
}
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Provider of a dummy source (fixed one-column integer schema, constant offset 0) and a
// sink whose addBatch does nothing.
// NOTE(review): this listing has lost lines -- the result expression of getBatch and the
// closing braces of the methods, anonymous classes, class and object are not shown.
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Single nullable-by-default integer column named "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the name "dummySource" with fakeSchema, ignoring the arguments.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        // NOTE(review): only the import survives; the DataFrame expression is missing.
        import spark.implicits._
      override def stop() {}

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      // Batches are accepted and discarded.
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

object BlockingSource {
  // Not referenced by the code shown above; presumably set and awaited by callers -- verify.
  var latch: CountDownLatch = null
package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.streaming.Sink

// Streaming sink for Kafka. A batch whose id is not greater than the last id recorded by
// this instance is logged and skipped.
// NOTE(review): this listing has lost lines -- the start of the call whose arguments end
// with `data.queryExecution, executorKafkaParams, topic)` and the closing braces are not
// shown. Restore from the upstream file before compiling.
private[kafka010] class KafkaSink(
    sqlContext: SQLContext,
    executorKafkaParams: ju.Map[String, Object],
    topic: Option[String]) extends Sink with Logging {
  // Held in this instance only; assigned after the (truncated) write call below.
  @volatile private var latestBatchId = -1L

  override def toString(): String = "KafkaSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {
        data.queryExecution, executorKafkaParams, topic)
      latestBatchId = batchId
package com.knockdata.spark.highcharts

import com.knockdata.spark.highcharts.model.Highcharts
import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.StreamSinkProvider
import org.apache.spark.sql.streaming.OutputMode

/**
 * Sink provider whose sink re-plots a Highcharts chart for every micro-batch.
 *
 * Required parameters: `chartId` and `chartParagraphId`. The objects registered under
 * `<chartId>-z`, `<chartId>-seriesHolder` and `<chartId>-outputMode` are looked up in
 * `Registry` on every batch.
 */
class CustomSinkProvider extends StreamSinkProvider {
  def createSink(
                  sqlContext: SQLContext,
                  parameters: Map[String, String],
                  partitionColumns: Seq[String],
                  outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {

        val chartId = parameters("chartId")
        val chartParagraphId = parameters("chartParagraphId")

        println(s"batchId: $batchId, chartId: $chartId, chartParagraphId: $chartParagraphId")

        val z = Registry.get(s"$chartId-z").asInstanceOf[ZeppelinContextHolder]
        val seriesHolder = Registry.get(s"$chartId-seriesHolder").asInstanceOf[SeriesHolder]
        // Named differently from createSink's `outputMode` parameter, which it used to shadow.
        val customOutputMode = Registry.get(s"$chartId-outputMode").asInstanceOf[CustomOutputMode]

        seriesHolder.dataFrame = data

        val result = seriesHolder.result
        // Only the normal series are plotted; the drilldown series are not used here.
        val (normalSeriesList, _) = customOutputMode.result(result._1, result._2)

        val chart = new Highcharts(normalSeriesList, seriesHolder.chartId)

        val plotData = chart.plotData
//        val escaped = plotData.replace("%angular", "")
//        println(s" put $chartParagraphId $escaped")
        z.put(chartParagraphId, plotData)
        println(s"run $chartParagraphId")
      }
    }
  }
}
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Provider of a dummy source (fixed one-column integer schema, constant offset 0) and a
// sink whose addBatch does nothing.
// NOTE(review): this listing has lost lines -- the result expression of getBatch and the
// closing braces of the methods, anonymous classes, class and object are not shown.
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Single nullable-by-default integer column named "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the name "dummySource" with fakeSchema, ignoring the arguments.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        // NOTE(review): only the import survives; the DataFrame expression is missing.
        import spark.implicits._
      override def stop() {}

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      // Batches are accepted and discarded.
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

object BlockingSource {
  // Not referenced by the code shown above; presumably set and awaited by callers -- verify.
  var latch: CountDownLatch = null
import{ImplicitMetadataOperation, SchemaUtils}
import org.apache.hadoop.fs.Path

import org.apache.spark.SparkContext
import org.apache.spark.sql._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.metric.SQLMetrics.createMetric
import org.apache.spark.sql.execution.streaming.{Sink, StreamExecution}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.NullType

// Streaming sink that commits each micro-batch to a Delta table inside one transaction,
// together with a SetTransaction action carrying the query id and batch id.
// NOTE(review): this listing has lost lines -- several closing braces/parens, the body of
// `if (selfScan)`, the start of the call ending in `outputMode == OutputMode.Complete())`,
// the early exit after "Skipping already complete epoch", the Complete-mode expression
// for deletedFiles and the start of the final metrics call are not shown. Restore from
// the upstream file before compiling.
class DeltaSink(
    sqlContext: SQLContext,
    path: Path,
    partitionColumns: Seq[String],
    outputMode: OutputMode,
    options: DeltaOptions)
  extends Sink with ImplicitMetadataOperation with DeltaLogging {

  private val deltaLog = DeltaLog.forTable(sqlContext.sparkSession, path)

  private val sqlConf = sqlContext.sparkSession.sessionState.conf

  // Schema overwrite requires both Complete output mode and the corresponding option.
  override protected val canOverwriteSchema: Boolean =
    outputMode == OutputMode.Complete() && options.canOverwriteSchema

  override protected val canMergeSchema: Boolean = options.canMergeSchema

  override def addBatch(batchId: Long, data: DataFrame): Unit = deltaLog.withNewTransaction { txn =>
    val sc = data.sparkSession.sparkContext
    val metrics = Map[String, SQLMetric](
      "numAddedFiles" -> createMetric(sc, "number of files added"),
      "numRemovedFiles" -> createMetric(sc, "number of files removed")
    // The query id is read from the Spark context's local properties and must be set.
    val queryId = sqlContext.sparkContext.getLocalProperty(StreamExecution.QUERY_ID_KEY)
    assert(queryId != null)

    // Batches containing NullType anywhere in their schema are rejected.
    if (SchemaUtils.typeExistsRecursively(data.schema)(_.isInstanceOf[NullType])) {
      throw DeltaErrors.streamWriteNullTypeException

    // If the batch reads the same Delta table as this sink is going to write to, then this
    // write has dependencies. Then make sure that this commit set hasDependencies to true
    // by injecting a read on the whole table. This needs to be done explicitly because
    // MicroBatchExecution has already enforced all the data skipping (by forcing the generation
    // of the executed plan) even before the transaction was started.
    val selfScan = data.queryExecution.analyzed.collectFirst {
      case DeltaTable(index) if index.deltaLog.isSameLogAs(txn.deltaLog) => true
    if (selfScan) {

    // Streaming sinks can't blindly overwrite schema. See Schema Management design doc for details
      configuration = Map.empty,
      outputMode == OutputMode.Complete())

    // A recorded version at or beyond batchId means this batch was already committed.
    val currentVersion = txn.txnVersion(queryId)
    if (currentVersion >= batchId) {
      logInfo(s"Skipping already complete epoch $batchId, in query $queryId")

    val deletedFiles = outputMode match {
      case o if o == OutputMode.Complete() =>
      case _ => Nil
    val newFiles = txn.writeFiles(data, Some(options))
    val setTxn = SetTransaction(queryId, batchId, Some(deltaLog.clock.getTimeMillis())) :: Nil
    val info = DeltaOperations.StreamingUpdate(outputMode, queryId, batchId, options.userMetadata)
    txn.registerSQLMetrics(sqlContext.sparkSession, metrics)
    txn.commit(setTxn ++ newFiles ++ deletedFiles, info)
    // This is needed to make the SQL metrics visible in the Spark UI
    val executionId = sqlContext.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
      sqlContext.sparkContext, executionId, metrics.values.toSeq)

  override def toString(): String = s"DeltaSink[$path]"
package com.samelamin.spark.bigquery

import com.samelamin.spark.bigquery.converters.SchemaConverters
import com.samelamin.spark.bigquery.streaming.{BigQuerySink, BigQuerySource}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.{Sink, Source}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.sources.RelationProvider

// BigQuery data source entry point: creates the streaming sink, the streaming source and
// the batch relation.
// NOTE(review): this listing has lost lines -- the remainder of getConvertedSchema (it
// shows no StructType result) and the closing braces of every method and of the class
// are not shown. Restore from the upstream file before compiling.
class DefaultSource
  extends StreamSinkProvider
    with StreamSourceProvider with RelationProvider{
  override def createSink(sqlContext: SQLContext, parameters: Map[String, String],
                          partitionColumns: Seq[String], outputMode: OutputMode): Sink = {

    // Location passed to the sink as its path; defaults to the literal "transaction_log".
    val path = parameters.get("transaction_log").getOrElse("transaction_log")
    new BigQuerySink(sqlContext.sparkSession, path, parameters)


  // Requires the "tableReferenceSource" option (Option.get throws when it is absent).
  def getConvertedSchema(sqlContext: SQLContext,options: Map[String, String]): StructType = {
    val bigqueryClient = BigQueryClient.getInstance(sqlContext)
    val tableReference = BigQueryStrings.parseTableReference(options.get("tableReferenceSource").get)

  // Reports the name "bigquery" with the caller's schema when given, otherwise the
  // converted one. The converted schema is computed in either case.
  override def sourceSchema(sqlContext: SQLContext,
                            schema: Option[StructType],
                            providerName: String,
                            options: Map[String, String]): (String, StructType) = {
    val convertedSchema = getConvertedSchema(sqlContext,options)
    ("bigquery", schema.getOrElse(convertedSchema))

  override def createSource(sqlContext: SQLContext, metadataPath: String,
                            schema: Option[StructType], providerName: String, parameters: Map[String, String]): Source = {
    new BigQuerySource(sqlContext, schema, parameters)
// Requires the "tableReferenceSource" option (Option.get throws when it is absent).
override def createRelation(
  sqlContext: SQLContext,
  parameters: Map[String, String]): BigQueryRelation = {
    val tableName = parameters.get("tableReferenceSource").get
    new BigQueryRelation(tableName)(sqlContext)
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Provider of a dummy source (fixed one-column integer schema, constant offset 0) and a
// sink whose addBatch does nothing.
// NOTE(review): this listing has lost lines -- the result expression of getBatch and the
// closing braces of the methods, anonymous classes, class and object are not shown.
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Single nullable-by-default integer column named "a".
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  // Always reports the name "dummySource" with fakeSchema, ignoring the arguments.
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        // NOTE(review): only the import survives; the DataFrame expression is missing.
        import spark.implicits._
      override def stop() {}

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      // Batches are accepted and discarded.
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

object BlockingSource {
  // Not referenced by the code shown above; presumably set and awaited by callers -- verify.
  var latch: CountDownLatch = null
package com.highperformancespark.examples.structuredstreaming

import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql._
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.execution.streaming.Sink


  /**
   * Counts the distinct rows of the batch and prints the count to standard output.
   *
   * @param batchId id of the micro-batch, used only in the printed message
   * @param data    rows of the micro-batch
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val batchDistinctCount = data.rdd.distinct.count()
    println(s"Batch ${batchId}'s distinct count is ${batchDistinctCount}")
  }
// NOTE(review): fragment -- the expression that uses `ds`, the rest of the concatenated
// class-name string and the closing braces are not shown.
object CustomSinkDemo {
  def write(ds: Dataset[_]) = {
      "com.highperformancespark.examples.structuredstreaming." +
package org.apache.spark.sql.streaming

import scala.collection.mutable

import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.Sink

// Wrapper around a StreamingQueryManager exposing a startQuery that accepts a Sink
// instance directly.
// NOTE(review): fragment -- the body of startQuery and the closing braces are not shown.
case class EvilStreamingQueryManager(streamingQueryManager: StreamingQueryManager) {
  def startQuery(
    userSpecifiedName: Option[String],
    userSpecifiedCheckpointLocation: Option[String],
    df: DataFrame,
    sink: Sink,
    outputMode: OutputMode): StreamingQuery = {
import{KeyVaultUtils, KustoDataSourceUtils}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

// Stream sink provider registered under the short name "KustoSink".
// NOTE(review): this listing has lost lines -- the first KustoSink constructor
// argument(s), the `if` condition guarding the Key Vault branch, the remaining
// arguments and the closing braces are not shown. Restore from the upstream file before
// compiling.
class KustoSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def shortName(): String = "KustoSink"

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    val sinkParameters = KustoDataSourceUtils.parseSinkParameters(parameters)

    new KustoSink(
        // Key Vault branch: parameters fetched from Key Vault are merged with the
        // authentication parameters given in the options.
        val paramsFromKeyVault = KeyVaultUtils.getAadAppParametersFromKeyVault(sinkParameters.sourceParametersResults.keyVaultAuth.get)
        KustoDataSourceUtils.mergeKeyVaultAndOptionsAuthentication(paramsFromKeyVault, Some(sinkParameters.sourceParametersResults.authenticationParameters))
      } else sinkParameters.sourceParametersResults.authenticationParameters,
Example 22
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}

/**
 * Streaming sink that hands each micro-batch to `KustoWriter`.
 *
 * A batch whose id is not greater than the id of the last batch written by this
 * instance is logged and skipped. `sqlContext` is accepted but not used here.
 *
 * @param sqlContext       SQL context of the query
 * @param tableCoordinates forwarded to the writer unchanged
 * @param authentication   forwarded to the writer unchanged
 * @param writeOptions     forwarded to the writer unchanged
 */
class KustoSink(sqlContext: SQLContext,
                tableCoordinates: KustoCoordinates,
                authentication: KustoAuthentication,
                writeOptions: WriteOptions) extends Sink with Serializable {

  private val myName = this.getClass.getSimpleName
  val MessageSource = "KustoSink"
  // Held in this instance only; advanced after the writer returns normally.
  @volatile private var latestBatchId = -1L

  override def toString: String = "KustoSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      KDSU.logInfo(myName, s"Skipping already committed batch $batchId")
    } else {
      KustoWriter.write(Option(batchId), data, tableCoordinates, authentication, writeOptions)
      latestBatchId = batchId
    }
  }
}
Example 23
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{ DataFrame, SQLContext }
import org.apache.spark.sql.execution.streaming.Sink

/**
 * Streaming sink that hands each micro-batch to `EventHubsWriter`.
 *
 * A batch whose id is not greater than the id of the last batch written by this
 * instance is logged and skipped.
 *
 * @param sqlContext supplies the Spark session passed to the writer
 * @param parameters options forwarded to the writer unchanged
 */
private[eventhubs] class EventHubsSink(sqlContext: SQLContext, parameters: Map[String, String])
    extends Sink
    with Logging {

  // Held in this instance only; advanced after the writer returns normally.
  @volatile private var latestBatchId = -1L

  override def toString: String = "EventHubsSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= latestBatchId) {
      logInfo(s"Skipping already committed batch $batchId")
    } else {
      EventHubsWriter.write(sqlContext.sparkSession, data.queryExecution, parameters)
      latestBatchId = batchId
    }
  }
}
Example 24
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.StreamSinkProvider
import org.apache.spark.sql.streaming.OutputMode

import Params.map2Params

/**
 * Stream sink provider registered under the short name "httpStream".
 *
 * Builds an `HttpStreamSink` from the writer options. Option access goes through
 * the `Params.map2Params` implicit imported by this file.
 */
class HttpStreamSinkProvider
		extends StreamSinkProvider with DataSourceRegister {
	/**
	 * Creates the sink for a streaming query.
	 *
	 * @param sqlContext       unused
	 * @param parameters       writer options; "httpServletUrl" is required, "maxPacketSize"
	 *                         falls back to 10 * 1024 * 1024
	 * @param partitionColumns unused
	 * @param outputMode       unused
	 * @return a new `HttpStreamSink`
	 */
	def createSink(
		sqlContext: SQLContext,
		parameters: Map[String, String],
		partitionColumns: Seq[String],
		outputMode: OutputMode): Sink = {
		// HttpStreamSink takes (httpPostURL, topic, maxPacketSize); the topic argument
		// was absent from the call, which cannot type-check.
		// NOTE(review): the option key "topic" is assumed from the constructor
		// parameter name -- confirm against the project's documented options.
		new HttpStreamSink(
			parameters.getRequiredString("httpServletUrl"),
			parameters.getRequiredString("topic"),
			parameters.getInt("maxPacketSize", 10 * 1024 * 1024))
	}

	def shortName(): String = "httpStream"
}

/**
 * Sink that posts every micro-batch to an HTTP endpoint through `HttpStreamClient`.
 *
 * @param httpPostURL   URL the client connects to
 * @param topic         topic name sent along with every batch
 * @param maxPacketSize size limit forwarded to `sendDataFrame`
 */
class HttpStreamSink(httpPostURL: String, topic: String, maxPacketSize: Int)
		extends Sink with Logging {
	val producer = HttpStreamClient.connect(httpPostURL)
	// Maximum number of send attempts per batch.
	val RETRY_TIMES = 5
	// Base delay in milliseconds; the wait before attempt k+1 is SLEEP_TIME * k.
	val SLEEP_TIME = 100

	/**
	 * Sends one batch, retrying up to `RETRY_TIMES` attempts with a linearly growing
	 * pause between attempts. The failure of the last attempt is rethrown.
	 *
	 * @param batchId id of the micro-batch
	 * @param data    rows of the micro-batch
	 */
	override def addBatch(batchId: Long, data: DataFrame): Unit = {
		import scala.util.control.NonFatal

		var success = false
		var retried = 0
		while (!success && retried < RETRY_TIMES) {
			try {
				retried += 1
				producer.sendDataFrame(topic, batchId, data, maxPacketSize)
				success = true
			}
			catch {
				// NonFatal lets OutOfMemoryError, InterruptedException and similar
				// propagate immediately instead of being retried.
				case NonFatal(e) =>
					success = false
					logWarning(s"failed to send", e)
					if (retried < RETRY_TIMES) {
						val sleepTime = SLEEP_TIME * retried
						logWarning(s"will retry to send after ${sleepTime}ms")
						// Actually wait for the announced delay before the next attempt.
						Thread.sleep(sleepTime)
					}
					else {
						throw e
					}
			}
		}
	}
}
Example 25
import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

/**
 * Provider implementing both `StreamSourceProvider` and `StreamSinkProvider` with
 * stub objects: a source that reports a fixed one-column schema and a constant
 * offset, and a sink whose `addBatch` does nothing.
 *
 * NOTE(review): this listing appears to have lost lines during extraction -- the
 * closing braces are absent, `getBatch` has no result expression, and the
 * companion's `latch` is never referenced in the visible lines. Presumably a
 * blocking call on the latch was part of the original; confirm against the
 * original project before using this code.
 */
class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  // Schema with a single IntegerType column named "a", shared by every source created here.
  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  /** Returns the name "dummySource" and `fakeSchema`; all arguments are ignored. */
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)

  /**
   * Builds an anonymous `Source` whose schema is `fakeSchema` and whose
   * `getOffset` always returns `Some(LongOffset(0))`.
   */
  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      // NOTE(review): only the implicits import is visible; the expression that
      // produces the returned DataFrame is missing from this listing.
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
      override def stop() {}

  /** Builds an anonymous `Sink` whose `addBatch` has an empty body; all arguments are ignored. */
  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}

/** Companion holding the shared latch for `BlockingSource`. */
object BlockingSource {
  // Deliberately a mutable, initially-null field: callers are expected to assign a
  // fresh CountDownLatch before use. Kept as-is so existing assignments keep working.
  var latch: CountDownLatch = null
}
Example 26
import org.apache.spark.api.python.PythonException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.DataStreamWriter

/**
 * Sink that converts every micro-batch into a typed `Dataset[T]` and hands it,
 * together with the batch id, to a user-supplied function.
 *
 * @param batchWriter function invoked once per micro-batch
 * @param encoder     encoder used to turn the batch's internal rows into `T`
 * @tparam T element type of the dataset passed to `batchWriter`
 */
class ForeachBatchSink[T](batchWriter: (Dataset[T], Long) => Unit, encoder: ExpressionEncoder[T])
  extends Sink {

  /**
   * Decodes the rows of `data` with `encoder` and calls `batchWriter`.
   *
   * NOTE(review): the arguments of `resolveAndBind` and the receiver of `map[T]`
   * were truncated in this listing; they are restored here following the upstream
   * Spark 2.4 implementation -- confirm against the Spark version in use.
   *
   * @param batchId id of the micro-batch
   * @param data    rows of the micro-batch
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    // Bind the encoder to the concrete output attributes of this batch's plan.
    val resolvedEncoder = encoder.resolveAndBind(
      data.logicalPlan.output,
      data.sparkSession.sessionState.analyzer)
    // Decode the internal rows into T; the class tag comes from the encoder.
    val rdd = data.queryExecution.toRdd.map[T](resolvedEncoder.fromRow)(encoder.clsTag)
    val ds = data.sparkSession.createDataset(rdd)(encoder)
    batchWriter(ds, batchId)
  }

  override def toString(): String = "ForeachBatchSink"
}

/**
 * Callback interface with a single method taking a batch DataFrame and its id.
 * The trait header was missing from this listing; the name is the type that
 * `PythonForeachBatchHelper.callForeachBatch` declares for its `pythonFunc` parameter.
 */
trait PythonForeachBatchFunction {
  /** Invoked once per micro-batch with the batch data and its id. */
  def call(batchDF: DataFrame, batchId: Long): Unit
}

/** Helper that registers a `PythonForeachBatchFunction` on a stream writer. */
object PythonForeachBatchHelper {
  /**
   * Installs `pythonFunc.call` as the `foreachBatch` callback of `dsw`.
   *
   * @param dsw        writer to configure
   * @param pythonFunc callback whose `call` method receives every micro-batch
   */
  def callForeachBatch(dsw: DataStreamWriter[Row], pythonFunc: PythonForeachBatchFunction): Unit = {
    dsw.foreachBatch(pythonFunc.call _)
  }
}
Example 27
import com.lucidworks.spark.{SolrRelation, SolrStreamWriter}
import com.lucidworks.spark.util.Constants
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode

/**
 * Data source entry point for Solr: creates `SolrRelation`s for batch reads and
 * writes and a `SolrStreamWriter` for streaming writes. Its short name is
 * `Constants.SOLR_FORMAT`.
 */
class DefaultSource extends RelationProvider with CreatableRelationProvider with StreamSinkProvider with DataSourceRegister {

  /**
   * Creates a relation for reading.
   *
   * @param sqlContext context providing the Spark session
   * @param parameters data source options
   * @return a new `SolrRelation`
   * @throws RuntimeException runtime exceptions are rethrown unchanged; any other
   *                          `Exception` is wrapped in a `RuntimeException`
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    try {
      new SolrRelation(parameters, sqlContext.sparkSession)
    } catch {
      case re: RuntimeException => throw re
      case e: Exception => throw new RuntimeException(e)
    }
  }

  /**
   * Writes `df` through a new `SolrRelation` and returns that relation.
   *
   * The `mode` argument is currently not consulted: the insert is always issued
   * with `overwrite = true`.
   *
   * @param sqlContext context providing the Spark session
   * @param mode       requested save mode (ignored, see above)
   * @param parameters data source options
   * @param df         data to write
   * @return the relation the data was inserted through
   * @throws RuntimeException runtime exceptions are rethrown unchanged; any other
   *                          `Exception` is wrapped in a `RuntimeException`
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    try {
      // TODO: What to do with the saveMode?
      val solrRelation: SolrRelation = new SolrRelation(parameters, Some(df), sqlContext.sparkSession)
      solrRelation.insert(df, overwrite = true)
      // The method is declared to yield a BaseRelation, so hand back the relation
      // itself rather than ending on the insert call.
      solrRelation
    } catch {
      case re: RuntimeException => throw re
      case e: Exception => throw new RuntimeException(e)
    }
  }

  override def shortName(): String = Constants.SOLR_FORMAT

  /**
   * Creates the streaming sink.
   *
   * @param sqlContext       context providing the Spark session
   * @param parameters       data source options
   * @param partitionColumns forwarded to the writer
   * @param outputMode       forwarded to the writer
   * @return a new `SolrStreamWriter`
   */
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new SolrStreamWriter(sqlContext.sparkSession, parameters, partitionColumns, outputMode)
  }
}
Example 28
import com.lucidworks.spark.util.{SolrQuerySupport, SolrSupport}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.streaming.OutputMode
import com.lucidworks.spark.util.ConfigurationConstants._
import org.apache.spark.sql.types.StructType

import scala.collection.mutable

/**
 * Streaming sink that writes each micro-batch to a Solr collection.
 *
 * For every new batch the rows are collected, any fields reported by
 * `SolrRelation.getFieldsToAdd` are added to the collection, the rows are
 * converted to Solr input documents and sent with
 * `SolrSupport.sendBatchToSolrWithRetry`. The id of the last batch written is
 * kept only in this instance's `latestBatchId` field.
 *
 * NOTE(review): several lines of this listing appear truncated by extraction
 * (closing braces, the arguments of `getFieldTypes`, the receivers of the log
 * calls and of the row mapping). The affected lines are flagged below; restore
 * them from the original project before compiling.
 */
class SolrStreamWriter(
    val sparkSession: SparkSession,
    parameters: Map[String, String],
    val partitionColumns: Seq[String],
    val outputMode: OutputMode)(
  implicit val solrConf : SolrConf = new SolrConf(parameters))
  extends Sink with LazyLogging {

  // Fail construction early when the ZooKeeper host or the collection is not configured.
  require(solrConf.getZkHost.isDefined, s"Parameter ${SOLR_ZK_HOST_PARAM} not defined")
  require(solrConf.getCollection.isDefined, s"Parameter ${SOLR_COLLECTION_PARAM} not defined")

  // Safe to call .get here: both options were checked by the require calls above.
  val collection : String = solrConf.getCollection.get
  val zkhost: String = solrConf.getZkHost.get

  lazy val solrVersion : String = SolrSupport.getSolrVersion(solrConf.getZkHost.get)
  // The unique key is looked up for the first entry of the comma-separated collection value.
  lazy val uniqueKey: String = SolrQuerySupport.getUniqueKey(zkhost, collection.split(",")(0))

  // Keeps the names that start with "*_" or end with "_*" and strips the asterisk,
  // leaving the bare suffix or prefix.
  // NOTE(review): the argument list of getFieldTypes looks truncated -- only the
  // last named argument is visible.
  lazy val dynamicSuffixes: Set[String] = SolrQuerySupport.getFieldTypes(
      skipDynamicExtensions = false)
    .filter(f => f.startsWith("*_") || f.endsWith("_*"))
    .map(f => if (f.startsWith("*_")) f.substring(1) else f.substring(0, f.length-1))

  @volatile private var latestBatchId: Long = -1L
  // Accumulator registered under the configured name, or "Records Written" when none is set.
  val acc: SparkSolrAccumulator = new SparkSolrAccumulator
  val accName = if (solrConf.getAccumulatorName.isDefined) solrConf.getAccumulatorName.get else "Records Written"
  sparkSession.sparkContext.register(acc, accName)

  /**
   * Writes `df` to Solr unless `batchId` is not greater than the last batch
   * written by this instance. Batches with no rows are not sent.
   *
   * @param batchId id of the micro-batch
   * @param df      rows of the micro-batch
   */
  override def addBatch(batchId: Long, df: DataFrame): Unit = {
    // NOTE(review): the call that should receive this message (presumably a logger
    // call with an s-interpolated string) is missing from the line below.
    if (batchId <= latestBatchId) {"Skipping already processed batch $batchId")
    } else {
      // Brings the whole batch into this JVM before converting it.
      val rows = df.collect()
      if (rows.nonEmpty) {
        val schema: StructType = df.schema
        val solrClient = SolrSupport.getCachedCloudClient(zkhost)

        // build up a list of updates to send to the Solr Schema API
        val fieldsToAddToSolr = SolrRelation.getFieldsToAdd(schema, solrConf, solrVersion, dynamicSuffixes)

        if (fieldsToAddToSolr.nonEmpty) {
          SolrRelation.addFieldsForInsert(fieldsToAddToSolr, collection, solrClient)

        // NOTE(review): the left side of the lambda (presumably a map over `rows`
        // binding `row`) is missing, and the next line fuses the send call with a
        // log message whose receiver is also missing.
        val solrDocs = => SolrRelation.convertRowToSolrInputDocument(row, solrConf, uniqueKey))
        SolrSupport.sendBatchToSolrWithRetry(zkhost, solrClient, collection, solrDocs, solrConf.commitWithin)"Written ${solrDocs.length} documents to Solr collection $collection from batch $batchId")
        // Recorded only after the send call, inside the non-empty branch.
        latestBatchId = batchId
Example 29
import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

import scala.collection.JavaConversions._

/**
 * Stream sink provider registered under the short name "s2graph".
 *
 * Turns the writer options into a Typesafe `Config` (falling back to
 * `ConfigFactory.load()`) and builds an `S2SparkSqlStreamingSink` from it.
 */
class S2SinkProvider extends StreamSinkProvider with DataSourceRegister with Logger {
  /**
   * Creates the sink for a streaming query.
   *
   * @param sqlContext       context providing the Spark session
   * @param parameters       writer options, parsed into the job configuration
   * @param partitionColumns unused
   * @param outputMode       unused
   * @return a new `S2SparkSqlStreamingSink`
   */
  override def createSink(
                  sqlContext: SQLContext,
                  parameters: Map[String, String],
                  partitionColumns: Seq[String],
                  outputMode: OutputMode): Sink = {
    // Explicit conversion instead of relying on the deprecated implicit JavaConversions.
    import scala.collection.JavaConverters._

    // NOTE(review): the receivers of the two log messages were missing from the
    // listing; `logger.info` is assumed from the `Logger` mixin -- confirm the level.
    logger.info(s"S2SinkProvider options : ${parameters}")
    val jobConf: Config = ConfigFactory.parseMap(parameters.asJava).withFallback(ConfigFactory.load())
    logger.info(s"S2SinkProvider Configuration : ${jobConf.root().render(ConfigRenderOptions.concise())}")

    new S2SparkSqlStreamingSink(sqlContext.sparkSession, jobConf)
  }

  override def shortName(): String = "s2graph"
}
Example 30
import java.util.UUID

import com.typesafe.config.{Config, ConfigRenderOptions}
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.streaming.{MetadataLog, Sink}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Streaming sink that runs one Spark job per micro-batch, writing each partition
 * with `S2StreamQueryWriter` and committing the batch through `S2CommitProtocol`.
 *
 * Already-committed batches are detected through a metadata log: a batch whose id
 * is not greater than the latest logged id is skipped.
 *
 * @param sparkSession session used to run the write job and to create the metadata log
 * @param config       job configuration; read for the log path, the query name and
 *                     serialized for the per-partition writers
 */
class S2SparkSqlStreamingSink(
                               sparkSession: SparkSession,
                               config: Config
                             ) extends Sink with Logger {
  import S2SinkConfigs._

  private val APP_NAME = "s2graph"

  // Metadata log consulted in addBatch to find the latest committed batch id.
  private val writeLog: MetadataLog[Array[S2SinkStatus]] = {
    val logPath = getCommitLogPath(config)
    // NOTE(review): the receiver of this message was missing from the listing;
    // `logger.info` is assumed -- confirm the level.
    logger.info(s"MetaDataLogPath: $logPath")

    new S2SinkMetadataLog(sparkSession, config, logPath)
  }

  /**
   * Writes and commits `data` unless `batchId` is already present in the metadata log.
   *
   * @param batchId id of the micro-batch
   * @param data    rows of the micro-batch
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    logger.debug(s"addBatch : $batchId, getLatest : ${writeLog.getLatest()}")

    // An empty log counts as "latest id -1", so the first batch is always written.
    if (batchId <= writeLog.getLatest().map(_._1).getOrElse(-1L)) {
      logger.info(s"Skipping already committed batch [$batchId]")
    } else {
      // Falls back to a random UUID when no "queryname" is configured.
      val queryName = getConfigStringOpt(config, "queryname").getOrElse(UUID.randomUUID().toString)
      val commitProtocol = new S2CommitProtocol(writeLog)
      val jobState = JobState(queryName, batchId)
      // The config is rendered to a string before being captured by the task closure.
      val serializedConfig = config.root().render(ConfigRenderOptions.concise())
      val queryExecution = data.queryExecution
      val schema = data.schema

      SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
        try {
          val taskCommits = sparkSession.sparkContext.runJob(queryExecution.toRdd,
            (taskContext: TaskContext, iter: Iterator[InternalRow]) => {
              new S2StreamQueryWriter(serializedConfig, schema, commitProtocol).run(taskContext, iter)
            }
          )
          commitProtocol.commitJob(jobState, taskCommits)
        } catch {
          // NOTE(review): as listed, this handler only rethrows. A cleanup step
          // (for example aborting the job on the commit protocol) may have been
          // lost from the listing -- confirm against the original project.
          case t: Throwable =>
            throw t
        }
      }
    }
  }

  /**
   * Resolves where the commit log lives: the explicit log path if configured,
   * otherwise `<checkpoint>/sinks/s2graph` under the configured checkpoint location.
   *
   * @throws IllegalArgumentException when neither setting is present
   */
  private def getCommitLogPath(config: Config): String = {
    val logPathOpt = getConfigStringOpt(config, S2_SINK_LOG_PATH)
    val userCheckpointLocationOpt = getConfigStringOpt(config, S2_SINK_CHECKPOINT_LOCATION)

    (logPathOpt, userCheckpointLocationOpt) match {
      case (Some(logPath), _) => logPath
      case (None, Some(userCheckpoint)) => s"$userCheckpoint/sinks/$APP_NAME"
      case _ => throw new IllegalArgumentException(s"failed to get commit log path")
    }
  }

  override def toString(): String = "S2GraphSink"
}