org.apache.spark.sql.sources.DataSourceRegister Scala Examples

The following examples show how to use org.apache.spark.sql.sources.DataSourceRegister. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: JdbcRelationProvider.scala    From drizzle-spark   with Apache License 2.0 7 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

/**
 * Data source provider registered under the short name "jdbc".
 *
 * Supports building a relation for reading and, through the
 * [[CreatableRelationProvider]] overload, saving a DataFrame to a JDBC table
 * before returning a relation over it.
 */
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Builds a [[JDBCRelation]] from the given options.
   *
   * Partitioning information is only constructed when a partition column is
   * configured; otherwise `null` is handed to `JDBCRelation.columnPartition`.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val opts = new JDBCOptions(parameters)
    val column = opts.partitionColumn
    val lower = opts.lowerBound
    val upper = opts.upperBound
    val partitionCount = opts.numPartitions

    val partitioning =
      if (column != null) {
        JDBCPartitioningInfo(column, lower.toLong, upper.toLong, partitionCount.toInt)
      } else {
        null
      }
    JDBCRelation(JDBCRelation.columnPartition(partitioning), opts)(sqlContext.sparkSession)
  }

  /**
   * Saves `df` to the configured table according to `mode`, then returns a
   * relation over that table.
   *
   * The connection opened for the table checks and DDL is always closed,
   * whether or not the save succeeds.
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val opts = new JDBCOptions(parameters)
    val url = opts.url
    val table = opts.table
    val createTableOptions = opts.createTableOptions
    val isTruncate = opts.isTruncate

    val conn = JdbcUtils.createConnectionFactory(opts)()
    try {
      // Small local helpers so each save mode reads as a single step.
      def load(): Unit = saveTable(df, url, table, opts)
      def createThenLoad(): Unit = {
        createTable(df.schema, url, table, createTableOptions, conn)
        load()
      }
      // Only evaluated for Overwrite, exactly as the nested condition it replaces.
      def canTruncateInPlace: Boolean =
        isTruncate && isCascadingTruncateTable(url) == Some(false)

      if (!JdbcUtils.tableExists(conn, url, table)) {
        createThenLoad()
      } else {
        mode match {
          case SaveMode.Append =>
            load()

          case SaveMode.Overwrite if canTruncateInPlace =>
            // Keep the table definition: truncate, then load.
            truncateTable(conn, table)
            load()

          case SaveMode.Overwrite =>
            // Truncation is not requested or not safe: drop and recreate instead.
            dropTable(conn, table)
            createThenLoad()

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // The table already exists, so neither the DataFrame contents are saved nor the
            // existing data changed; fall through and return the relation below.
        }
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
}
Example 2
Source File: HadoopFsRelation.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.{StructField, StructType}



/**
 * A [[BaseRelation]] over files described by a [[FileIndex]] and read with a
 * [[FileFormat]].
 *
 * The exposed `schema` merges `dataSchema` and `partitionSchema`; column names
 * are compared case-insensitively unless case-sensitive analysis is enabled
 * in the session configuration.
 */
case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  /** Name used to compare columns: as-is when case sensitive, lower-cased otherwise. */
  private def getColName(f: StructField): String =
    if (!sparkSession.sessionState.conf.caseSensitiveAnalysis) {
      f.name.toLowerCase(Locale.ROOT)
    } else {
      f.name
    }

  /** Partition fields that also appear in the data schema, keyed by comparison name. */
  val overlappedPartCols = mutable.Map.empty[String, StructField]
  partitionSchema.foreach { partitionField =>
    val key = getColName(partitionField)
    if (dataSchema.exists(dataField => getColName(dataField) == key)) {
      overlappedPartCols(key) = partitionField
    }
  }

  // For columns present in both schemas, the output keeps the position the column has in the
  // data schema but takes the field (and therefore the data type) from the partition schema.
  // Partition-only columns are appended afterwards.
  val schema: StructType = {
    val dataColumns = dataSchema.map { f => overlappedPartCols.getOrElse(getColName(f), f) }
    val partitionOnlyColumns =
      partitionSchema.filterNot { f => overlappedPartCols.contains(getColName(f)) }
    StructType(dataColumns ++ partitionOnlyColumns)
  }

  /** The partition schema, or `None` when there are no partition columns. */
  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  /** The format's registered short name when it has one, otherwise "HadoopFiles". */
  override def toString: String = fileFormat match {
    case registered: DataSourceRegister => registered.shortName()
    case _ => "HadoopFiles"
  }

  /** On-disk size reported by the file index, scaled by the configured compression factor. */
  override def sizeInBytes: Long = {
    val factor = sqlContext.conf.fileCompressionFactor
    val scaled = location.sizeInBytes * factor
    scaled.toLong
  }

  override def inputFiles: Array[String] = location.inputFiles
}
Example 3
Source File: MQTTStreamSink.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.bahir.sql.streaming.mqtt

import scala.collection.JavaConverters._
import scala.collection.mutable

import org.eclipse.paho.client.mqttv3.MqttException

import org.apache.spark.SparkEnv
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

import org.apache.bahir.utils.Logging
import org.apache.bahir.utils.Retry


/**
 * Stream writer that hands out [[MQTTDataWriterFactory]] instances.
 * Epoch commit and abort are no-ops.
 */
class MQTTStreamWriter (schema: StructType, parameters: DataSourceOptions)
    extends StreamWriter with Logging {

  override def createWriterFactory(): DataWriterFactory[InternalRow] = {
    // The client identifier is dropped because a single batch can be spread over several
    // Spark worker processes, and an MQTT server does not accept two simultaneous
    // connections declaring the same client ID.
    val withoutClientId = parameters.asMap().asScala.filter {
      case (key, _) => !key.equalsIgnoreCase("clientId")
    }
    MQTTDataWriterFactory(withoutClientId)
  }

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}

/** Factory creating one [[MQTTDataWriter]] per call, all sharing the same `config`. */
case class MQTTDataWriterFactory(config: mutable.Map[String, String])
    extends DataWriterFactory[InternalRow] {

  // The partition, task and epoch identifiers are not used when building the writer.
  override def createDataWriter(
    partitionId: Int,
    taskId: Long,
    epochId: Long
  ): DataWriter[InternalRow] = {
    new MQTTDataWriter(config)
  }
}

/** Singleton commit message; it carries no data. */
case object MQTTWriterCommitMessage extends WriterCommitMessage

/**
 * Publishes the binary value in column 0 of each record to the configured MQTT topic.
 *
 * Retry behaviour is read from the Spark configuration keys
 * `spark.mqtt.client.publish.attempts` (default -1) and
 * `spark.mqtt.client.publish.backoff` (default "5s").
 */
class MQTTDataWriter(config: mutable.Map[String, String]) extends DataWriter[InternalRow] {
  private lazy val publishAttempts: Int =
    SparkEnv.get.conf.getInt("spark.mqtt.client.publish.attempts", -1)
  private lazy val publishBackoff: Long =
    SparkEnv.get.conf.getTimeAsMs("spark.mqtt.client.publish.backoff", "5s")

  // Parsed once on first use; only the topic (3rd) and QoS (6th) elements are needed here.
  private lazy val parsedConfig = MQTTUtils.parseConfigParams(config.toMap)
  private lazy val topic = parsedConfig._3
  private lazy val qos = parsedConfig._6

  override def write(record: InternalRow): Unit = {
    val client = CachedMQTTClient.getOrCreate(config.toMap)
    val payload = record.getBinary(0)
    // Publishing is retried on MqttException according to the settings above.
    Retry(publishAttempts, publishBackoff, classOf[MqttException]) {
      client.publish(topic, payload, qos, false)
    }
  }

  override def commit(): WriterCommitMessage = MQTTWriterCommitMessage

  override def abort(): Unit = {}
}

/** Minimal relation wrapping a DataFrame; its schema is the DataFrame's schema. */
case class MQTTRelation(
    override val sqlContext: SQLContext,
    data: DataFrame) extends BaseRelation {

  override def schema: StructType = {
    data.schema
  }
}

/**
 * Sink provider registered under the short name "mqtt".
 *
 * Streaming writes go through [[MQTTStreamWriter]]; the batch `createRelation`
 * overload only wraps the incoming DataFrame in an [[MQTTRelation]].
 */
class MQTTStreamSinkProvider extends DataSourceV2 with StreamWriteSupport
    with DataSourceRegister with CreatableRelationProvider {

  override def shortName(): String = "mqtt"

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter =
    new MQTTStreamWriter(schema, options)

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation =
    MQTTRelation(sqlContext, data)
}
Example 4
Source File: HDFSMQTTSourceProvider.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.mqtt

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.StructType

import org.apache.bahir.sql.streaming.mqtt.{MQTTStreamConstants, MQTTUtils}


/**
 * Stream source provider registered under the short name "hdfs-mqtt".
 * The source always uses `MQTTStreamConstants.SCHEMA_DEFAULT`; a user-supplied
 * schema is ignored.
 */
class HDFSMQTTSourceProvider extends StreamSourceProvider with DataSourceRegister with Logging {

  override def shortName(): String = "hdfs-mqtt"

  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]): (String, StructType) =
    ("hdfs-mqtt", MQTTStreamConstants.SCHEMA_DEFAULT)

  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]): Source = {

    // The fourth element of the parsed configuration is not needed by this source.
    val (
      brokerUrl,
      clientId,
      topic,
      _,
      mqttConnectionOptions,
      qos,
      maxBatchMessageNum,
      maxBatchMessageSize,
      maxRetryNum
    ) = MQTTUtils.parseConfigParams(parameters)

    new HdfsBasedMQTTStreamSource(
      sqlContext,
      metadataPath,
      brokerUrl,
      clientId,
      topic,
      mqttConnectionOptions,
      qos,
      maxBatchMessageNum,
      maxBatchMessageSize,
      maxRetryNum
    )
  }
}

/** Constants shared by the HDFS-backed MQTT source. */
object HDFSMQTTSourceProvider {
  // Separator string; NOTE(review): its consumers are not visible in this file.
  val SEP: String = "##"
}
Example 5
Source File: JdbcSourceProvider.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.bahir.sql.streaming.jdbc

import scala.collection.JavaConverters._

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/**
 * Streaming JDBC sink provider.
 *
 * The batch data source already owns the short name "jdbc", so the streaming
 * variant registers as "streaming-jdbc".
 */
class JdbcSourceProvider extends StreamWriteSupport with DataSourceRegister {

  override def shortName(): String = "streaming-jdbc"

  override def createStreamWriter(
    queryId: String,
    schema: StructType,
    mode: OutputMode,
    options: DataSourceOptions): StreamWriter = {
    val settings = options.asMap().asScala.toMap
    // Instantiated only as a parameter check; the resulting object is discarded.
    new JDBCOptions(settings)
    new JdbcStreamWriter(schema, settings)
  }
}
Example 6
Source File: DefaultSource.scala    From spark-netezza   with Apache License 2.0 5 votes vote down vote up
package com.ibm.spark.netezza

import java.util.Properties
import org.apache.spark.sql.{SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, BaseRelation, RelationProvider}


  /**
   * Builds a Netezza relation from the data source options.
   *
   * Requires `url` and either `dbtable` or `query` (`dbtable` wins when both
   * are present). Every option is also copied into the connection properties.
   * The connection opened to compute the partitions is always closed.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val url = parameters.getOrElse("url", sys.error("Option 'Netezza database url' not specified"))

    // A query is wrapped as a derived table aliased `src`.
    val (table, isQuery) = parameters.get("dbtable") match {
      case Some(name) => (name, false)
      case None =>
        parameters.get("query") match {
          case Some(q) => (s"($q) as src", true)
          case None => sys.error("Option 'dbtable/query' should be specified.")
        }
    }

    // TODO: Have to set it to the system default.
    // The default is 1 for a query and 4 when fetching from a table. Data slices
    // can be used for partitioning when a table is specified.
    val defaultPartitions = if (isQuery) "1" else "4"
    val numPartitions = parameters.getOrElse("numPartitions", defaultPartitions).toInt

    val partitionCol = parameters.get("partitioncol")
    val lowerBound = parameters.get("lowerbound")
    val upperBound = parameters.get("upperbound")

    // Additional properties that are passed on when obtaining connections.
    val properties = new Properties()
    for ((key, value) <- parameters) {
      properties.setProperty(key, value)
    }

    val conn = NetezzaJdbcUtils.getConnector(url, properties)()
    val parts = try {
      if (!isQuery && partitionCol.isEmpty) {
        // Partitions based on the data slices.
        NetezzaInputFormat.getDataSlicePartition(conn, numPartitions)
      } else {
        if (isQuery && numPartitions > 1 && partitionCol.isEmpty) {
          throw new IllegalArgumentException("Partition column should be specified or" +
            " number of partitions should be set to 1 with the query option.")
        }
        val partnInfo = PartitioningInfo(partitionCol, lowerBound, upperBound, numPartitions)
        NetezzaInputFormat.getColumnPartitions(conn, table, partnInfo)
      }
    } finally {
      conn.close()
    }

    NetezzaRelation(url, table, parts, properties, numPartitions)(sqlContext)
  }
} 
Example 7
Source File: JdbcRelationProvider.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

/**
 * Provider for the "jdbc" data source: creates relations for reading and
 * writes DataFrames to JDBC tables for the save path.
 */
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Read path. Partitioning information is built only when a partition column
   * is set in the options; otherwise `null` is passed on.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val partitionColumn = jdbcOptions.partitionColumn
    val lowerBound = jdbcOptions.lowerBound
    val upperBound = jdbcOptions.upperBound
    val numPartitions = jdbcOptions.numPartitions

    val partitionInfo = partitionColumn match {
      case null => null
      case column =>
        JDBCPartitioningInfo(column, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
    }
    val parts = JDBCRelation.columnPartition(partitionInfo)
    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
  }

  /**
   * Write path. Applies `mode` against the target table, closes the
   * connection, and finally returns a relation built by the read path.
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val url = jdbcOptions.url
    val table = jdbcOptions.table
    val createTableOptions = jdbcOptions.createTableOptions
    val isTruncate = jdbcOptions.isTruncate

    val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
    try {
      if (JdbcUtils.tableExists(conn, url, table)) {
        mode match {
          case SaveMode.Ignore =>
            // The table is already there: leave both the DataFrame and the existing data
            // untouched and simply return the relation below.

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Append =>
            saveTable(df, url, table, jdbcOptions)

          case SaveMode.Overwrite =>
            val truncateInPlace = isTruncate && isCascadingTruncateTable(url) == Some(false)
            if (truncateInPlace) {
              // Keep the existing table and just empty it.
              truncateTable(conn, table)
            } else {
              // Replace the table definition altogether.
              dropTable(conn, table)
              createTable(df.schema, url, table, createTableOptions, conn)
            }
            saveTable(df, url, table, jdbcOptions)
        }
      } else {
        createTable(df.schema, url, table, createTableOptions, conn)
        saveTable(df, url, table, jdbcOptions)
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
}
Example 8
Source File: HadoopFsRelation.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.StructType



/**
 * A [[BaseRelation]] over files described by a [[FileIndex]] and read with a
 * [[FileFormat]].
 *
 * The exposed `schema` is the data schema followed by every partition column
 * whose name does not already occur (case-insensitively) in the data schema.
 */
case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  val schema: StructType = {
    // Lower-case with Locale.ROOT: the no-argument toLowerCase uses the JVM default locale,
    // under which (e.g. Turkish) "I" does not map to "i", so equal column names could fail
    // to match and end up duplicated in the schema.
    val dataSchemaColumnNames =
      dataSchema.map(_.name.toLowerCase(java.util.Locale.ROOT)).toSet
    StructType(dataSchema ++ partitionSchema.filterNot { column =>
      dataSchemaColumnNames.contains(column.name.toLowerCase(java.util.Locale.ROOT))
    })
  }

  /** The partition schema, or `None` when there are no partition columns. */
  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  /** The format's registered short name when it has one, otherwise "HadoopFiles". */
  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = location.sizeInBytes

  override def inputFiles: Array[String] = location.inputFiles
}
Example 9
Source File: console.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/**
 * Streaming sink that prints every batch to standard output.
 *
 * Recognised options:
 *  - `numRows`: number of rows to display (default 20)
 *  - `truncate`: whether long values are truncated (default true)
 */
class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
  // Number of rows to display, by default 20 rows
  private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)

  // Truncate the displayed data if it is too long, by default it is true
  private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)

  // Highest batch id seen so far; used to label re-executed batches
  private var lastBatchId = -1L

  /**
   * Prints a header for the batch followed by its rows.
   *
   * A batch whose id is not greater than the last one seen is labelled as a
   * rerun and does not advance `lastBatchId`.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    val batchIdStr = if (batchId <= lastBatchId) {
      s"Rerun batch: $batchId"
    } else {
      lastBatchId = batchId
      s"Batch: $batchId"
    }

    // scalastyle:off println
    println("-------------------------------------------")
    println(batchIdStr)
    println("-------------------------------------------")
    // scalastyle:on println
    // The rows are collected and re-wrapped in a new DataFrame before being shown.
    data.sparkSession.createDataFrame(
      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
      .show(numRowsToShow, isTruncated)
  }
}

/** Sink provider registered under the short name "console". */
class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister {

  def shortName(): String = "console"

  /** Creates a [[ConsoleSink]]; only `parameters` is forwarded to it. */
  def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = new ConsoleSink(parameters)
}
Example 10
Source File: DefaultSource.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister}

/** Relation provider for the "jdbc" data source. */
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Builds a [[JDBCRelation]] from the given options.
   *
   * `url` and `dbtable` are mandatory. When `partitionColumn` is given,
   * `lowerBound`, `upperBound` and `numPartitions` must be given as well.
   * All options are also forwarded as connection properties.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Optional settings are represented as null when absent.
    def optional(key: String): String = parameters.getOrElse(key, null)

    val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
    val driver = optional("driver")
    val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
    val partitionColumn = optional("partitionColumn")
    val lowerBound = optional("lowerBound")
    val upperBound = optional("upperBound")
    val numPartitions = optional("numPartitions")

    if (driver != null) {
      DriverRegistry.register(driver)
    }

    val boundsIncomplete = lowerBound == null || upperBound == null || numPartitions == null
    if (partitionColumn != null && boundsIncomplete) {
      sys.error("Partitioning incompletely specified")
    }

    val partitionInfo =
      if (partitionColumn != null) {
        JDBCPartitioningInfo(
          partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
      } else {
        null
      }
    val parts = JDBCRelation.columnPartition(partitionInfo)

    // Additional properties that we will pass to getConnection
    val properties = new Properties()
    parameters.foreach { case (key, value) => properties.setProperty(key, value) }
    JDBCRelation(url, table, parts, properties)(sqlContext)
  }
}
Example 11
Source File: DefaultSource.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Properties

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCRelation, JDBCPartitioningInfo, DriverRegistry}
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}


/** Registers the "jdbc" short name and creates JDBC-backed relations. */
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Creates a [[JDBCRelation]] for the table named by `dbtable` at `url`.
   *
   * Fails via `sys.error` when `url` or `dbtable` is missing, or when
   * `partitionColumn` is set without all of `lowerBound`, `upperBound` and
   * `numPartitions`.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
    val driver = parameters.getOrElse("driver", null)
    val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))
    val partitionColumn = parameters.getOrElse("partitionColumn", null)
    val lowerBound = parameters.getOrElse("lowerBound", null)
    val upperBound = parameters.getOrElse("upperBound", null)
    val numPartitions = parameters.getOrElse("numPartitions", null)

    // Register the driver class only when one was named.
    Option(driver).foreach(name => DriverRegistry.register(name))

    val partitioned = partitionColumn != null
    if (partitioned && Seq(lowerBound, upperBound, numPartitions).contains(null)) {
      sys.error("Partitioning incompletely specified")
    }

    // null (no partitioning) unless a partition column was supplied.
    val partitionInfo = Option(partitionColumn).map { column =>
      JDBCPartitioningInfo(
        column,
        lowerBound.toLong,
        upperBound.toLong,
        numPartitions.toInt)
    }.orNull
    val parts = JDBCRelation.columnPartition(partitionInfo)

    // Additional properties that we will pass to getConnection
    val properties = new Properties()
    for ((key, value) <- parameters) properties.setProperty(key, value)
    JDBCRelation(url, table, parts, properties)(sqlContext)
  }
}
Example 12
Source File: JdbcRelationProvider.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

/**
 * Provider for the "jdbc" data source, covering both the read path and the
 * DataFrame save path.
 */
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Read path: builds a [[JDBCRelation]] from the options.
   *
   * The bounds must be empty when no partition column is set, and the bounds
   * plus the partition count must all be present when one is set; both
   * conditions are enforced with `assert`.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    import JDBCOptions._

    val jdbcOptions = new JDBCOptions(parameters)
    val partitionColumn = jdbcOptions.partitionColumn
    val lowerBound = jdbcOptions.lowerBound
    val upperBound = jdbcOptions.upperBound
    val numPartitions = jdbcOptions.numPartitions

    val partitionInfo = partitionColumn match {
      case None =>
        assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " +
          s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
        null

      case Some(column) =>
        assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
          s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
            s"'$JDBC_NUM_PARTITIONS' are also required")
        JDBCPartitioningInfo(column, lowerBound.get, upperBound.get, numPartitions.get)
    }
    JDBCRelation(JDBCRelation.columnPartition(partitionInfo), jdbcOptions)(sqlContext.sparkSession)
  }

  /**
   * Save path: writes `df` to the target table as dictated by `mode`, closes
   * the connection, then returns a relation built by the read path.
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val options = new JDBCOptions(parameters)
    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      // Loads into a table that was just (re)created, so the DataFrame schema is used.
      def createThenLoad(): Unit = {
        createTable(conn, df, options)
        saveTable(df, Some(df.schema), isCaseSensitive, options)
      }
      // Loads into a pre-existing table, using the schema read back through the connection.
      def loadIntoExisting(): Unit = {
        val tableSchema = JdbcUtils.getSchemaOption(conn, options)
        saveTable(df, tableSchema, isCaseSensitive, options)
      }
      // Only evaluated for Overwrite.
      def canTruncateInPlace: Boolean =
        options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)

      if (!JdbcUtils.tableExists(conn, options)) {
        createThenLoad()
      } else {
        mode match {
          case SaveMode.Append =>
            loadIntoExisting()

          case SaveMode.Overwrite if canTruncateInPlace =>
            // Keep the table definition: truncate, then load.
            truncateTable(conn, options)
            loadIntoExisting()

          case SaveMode.Overwrite =>
            // Otherwise drop the table and recreate it from the DataFrame.
            dropTable(conn, options.table)
            createThenLoad()

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // The table already exists, so neither the DataFrame contents are saved nor the
            // existing data changed; fall through and return the relation below.
        }
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
}
Example 13
Source File: SqsSourceProvider.scala    From bahir   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.sqs

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.StructType

/**
 * Stream source provider registered under the short name "s3-sqs".
 *
 * A user-specified schema is mandatory; both entry points reject a missing
 * schema with the same `IllegalArgumentException` message.
 */
class SqsSourceProvider extends DataSourceRegister
  with StreamSourceProvider
  with Logging {

  override def shortName(): String = "s3-sqs"

  /**
   * Returns the short name together with the user-specified schema.
   *
   * @throws IllegalArgumentException if `schema` is empty
   */
  override def sourceSchema(sqlContext: SQLContext,
                            schema: Option[StructType],
                            providerName: String,
                            parameters: Map[String, String]): (String, StructType) = {

    require(schema.isDefined, "Sqs source doesn't support empty schema")
    (shortName(), schema.get)
  }

  /**
   * Creates the [[SqsSource]] for the given metadata path and options.
   *
   * @throws IllegalArgumentException if `schema` is empty
   */
  override def createSource(sqlContext: SQLContext,
                            metadataPath: String,
                            schema: Option[StructType],
                            providerName: String,
                            parameters: Map[String, String]): Source = {

    // Validate up front so a missing schema is reported with a clear message instead of a
    // bare NoSuchElementException from Option.get.
    require(schema.isDefined, "Sqs source doesn't support empty schema")

    new SqsSource(
      sqlContext.sparkSession,
      metadataPath,
      parameters,
      schema.get)
  }
}
Example 14
Source File: console.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/** Minimal relation wrapping a DataFrame; its schema is the DataFrame's schema. */
case class ConsoleRelation(
    override val sqlContext: SQLContext,
    data: DataFrame) extends BaseRelation {

  override def schema: StructType = {
    data.schema
  }
}

/**
 * Sink provider registered under the short name "console".
 *
 * Streaming writes are delegated to [[ConsoleWriter]]; the batch save path
 * shows the DataFrame directly and wraps it in a [[ConsoleRelation]].
 */
class ConsoleSinkProvider extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister
  with CreatableRelationProvider {

  def shortName(): String = "console"

  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter =
    new ConsoleWriter(schema, options)

  /**
   * Shows `data` and returns a relation over it.
   *
   * Recognised parameters: `numRows` (rows to display, default 20) and
   * `truncate` (truncate long values, default true).
   */
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val numRowsToShow = parameters.get("numRows").fold(20)(_.toInt)
    val isTruncated = parameters.get("truncate").fold(true)(_.toBoolean)

    data.show(numRowsToShow, isTruncated)
    ConsoleRelation(sqlContext, data)
  }
}
Example 15
Source File: SageMakerProtobufFileFormat.scala    From sagemaker-spark   with Apache License 2.0 5 votes vote down vote up
package com.amazonaws.services.sagemaker.sparksdk.protobuf

import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType

/**
 * File format registered under the short name "sagemaker".
 *
 * No schema is ever inferred from files; writes go through
 * `SageMakerProtobufWriter` and produce files with the ".pbr" extension.
 */
class SageMakerProtobufFileFormat extends FileFormat with DataSourceRegister {

  override def shortName(): String = "sagemaker"

  override def toString: String = "SageMaker"

  /** Always returns `None`: this format does not infer a schema. */
  override def inferSchema(
      sparkSession: SparkSession,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = None

  /** Returns a factory whose writers receive the write `options` passed here. */
  override def prepareWrite(
      sparkSession: SparkSession,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    new OutputWriterFactory {
      override def getFileExtension(context: TaskAttemptContext): String = ".pbr"

      // The schema used is the one passed to newInstance, not the outer dataSchema.
      override def newInstance(
          path: String,
          dataSchema: StructType,
          context: TaskAttemptContext): OutputWriter =
        new SageMakerProtobufWriter(path, context, dataSchema, options)
    }
  }
}
Example 16
Source File: DefaultSource.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import java.util.Properties

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister}

/**
 * Relation provider registered under the short name "jdbc".
 */
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Builds a JDBC relation from the supplied options.
   *
   * "url" and "dbtable" are mandatory. When "partitionColumn" is given,
   * "lowerBound", "upperBound" and "numPartitions" must all be given as well.
   * Every option is also forwarded as a connection property.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Treats both a missing key and a null value as absent.
    def setting(key: String): Option[String] = Option(parameters.getOrElse(key, null))

    val url = parameters.getOrElse("url", sys.error("Option 'url' not specified"))
    val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified"))

    val partitionInfo = setting("partitionColumn") match {
      case None =>
        null
      case Some(column) =>
        (setting("lowerBound"), setting("upperBound"), setting("numPartitions")) match {
          case (Some(lower), Some(upper), Some(count)) =>
            JDBCPartitioningInfo(column, lower.toLong, upper.toLong, count.toInt)
          case _ =>
            sys.error("Partitioning incompletely specified")
        }
    }

    val parts = JDBCRelation.columnPartition(partitionInfo)

    // Additional properties that we will pass to getConnection
    val properties = new Properties()
    for ((key, value) <- parameters) {
      properties.setProperty(key, value)
    }

    JDBCRelation(url, table, parts, properties)(sqlContext)
  }
}
Example 17
Source File: KuduSinkProvider.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode


/**
 * Streaming sink provider registered under the short name "kudu".
 */
class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def shortName(): String = "kudu"

  /**
   * Creates the sink through `KuduSink.withDefaultContext`.
   * Fails the `require` check unless the output mode is Update.
   */
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    val isUpdateMode = outputMode == OutputMode.Update
    require(isUpdateMode, "only 'update' OutputMode is supported")

    KuduSink.withDefaultContext(sqlContext, parameters)
  }
}
Example 18
Source File: KuduSinkProvider.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode


/**
 * Streaming sink provider registered under the short name "kudu".
 */
class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def shortName(): String = "kudu"

  /**
   * Creates the sink through `KuduSink.withDefaultContext`.
   * Fails the `require` check unless the output mode is Update.
   */
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    val isUpdateMode = outputMode == OutputMode.Update
    require(isUpdateMode, "only 'update' OutputMode is supported")

    KuduSink.withDefaultContext(sqlContext, parameters)
  }
}
Example 19
Source File: BigBgenDatasource.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.bgen

import java.io.ByteArrayOutputStream

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLUtils}
import org.apache.spark.sql.sources.DataSourceRegister

import io.projectglow.common.logging.{HlsEventRecorder, HlsTagValues}
import io.projectglow.sql.BigFileDatasource
import io.projectglow.sql.util.ComDatabricksDataSource

/**
 * Data source registered under the short name "bigbgen".
 * Serialization is delegated to the companion object.
 */
class BigBgenDatasource extends BigFileDatasource with DataSourceRegister {

  override def shortName(): String = "bigbgen"

  /** Forwards to `BigBgenDatasource.serializeDataFrame` on the companion object. */
  override def serializeDataFrame(
      options: Map[String, String],
      data: DataFrame): RDD[Array[Byte]] =
    BigBgenDatasource.serializeDataFrame(options, data)
}

// Same data source as BigBgenDatasource with the ComDatabricksDataSource trait mixed in.
// NOTE(review): presumably registers the source under a Databricks-specific name — confirm
// against ComDatabricksDataSource.
class ComDatabricksBigBgenDatasource extends BigBgenDatasource with ComDatabricksDataSource

/**
 * Serialization logic behind the "bigbgen" data source.
 */
object BigBgenDatasource extends HlsEventRecorder {

  import io.projectglow.common.BgenOptions._

  /** Reads the BGEN write settings from `options`, using the declared defaults for absent keys. */
  private def parseOptions(options: Map[String, String]): BigBgenOptions = {
    def setting(key: String, fallback: String): String = options.getOrElse(key, fallback)

    BigBgenOptions(
      setting(BITS_PER_PROB_KEY, BITS_PER_PROB_DEFAULT_VALUE).toInt,
      setting(MAX_PLOIDY_KEY, MAX_PLOIDY_VALUE).toInt,
      setting(DEFAULT_PLOIDY_KEY, DEFAULT_PLOIDY_VALUE).toInt,
      setting(DEFAULT_PHASING_KEY, DEFAULT_PHASING_VALUE).toBoolean
    )
  }

  /** Records a BGEN-write event carrying the effective settings. */
  private def logBgenWrite(parsedOptions: BigBgenOptions): Unit = {
    val logOptions = Map(
      BITS_PER_PROB_KEY -> parsedOptions.bitsPerProb,
      MAX_PLOIDY_KEY -> parsedOptions.maxPloidy,
      DEFAULT_PLOIDY_KEY -> parsedOptions.defaultPloidy,
      DEFAULT_PHASING_KEY -> parsedOptions.defaultPhasing
    )
    recordHlsEvent(HlsTagValues.EVENT_BGEN_WRITE, logOptions)
  }

  /**
   * Turns `data` into one byte array per partition, each produced by a `BgenRecordWriter`.
   *
   * Only the writer for partition 0 is asked to write the header. The row count of
   * `data` is computed up front and handed to every writer.
   */
  def serializeDataFrame(options: Map[String, String], data: DataFrame): RDD[Array[Byte]] = {
    val bgenOptions = parseOptions(options)
    logBgenWrite(bgenOptions)

    val schema = data.schema
    val variantCount = data.count
    val internalRows = data.queryExecution.toRdd

    // With zero partitions, substitute the RDD from SQLUtils.createEmptyRDD so the
    // header can still be emitted (see the warning text below).
    val rowsToWrite =
      if (internalRows.getNumPartitions != 0) {
        internalRows
      } else {
        logger.warn("Writing BGEN header only as the input DataFrame has zero partitions.")
        SQLUtils.createEmptyRDD(data.sparkSession)
      }

    rowsToWrite.mapPartitionsWithIndex { (partitionIndex, rows) =>
      val buffer = new ByteArrayOutputStream()
      val recordWriter = new BgenRecordWriter(
        buffer,
        schema,
        partitionIndex == 0,
        variantCount,
        bgenOptions.bitsPerProb,
        bgenOptions.maxPloidy,
        bgenOptions.defaultPloidy,
        bgenOptions.defaultPhasing
      )

      for (row <- rows) {
        recordWriter.write(row)
      }
      recordWriter.close()

      Iterator(buffer.toByteArray)
    }
  }
}

/**
 * Parsed settings for a BGEN write.
 *
 * @param bitsPerProb    value of the bits-per-probability option
 * @param maxPloidy      value of the maximum-ploidy option
 * @param defaultPloidy  value of the default-ploidy option
 * @param defaultPhasing value of the default-phasing option
 */
case class BigBgenOptions(bitsPerProb: Int, maxPloidy: Int, defaultPloidy: Int, defaultPhasing: Boolean)
Example 20
Source File: HelloWorldDataSource.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package com.github.dnvriend.spark.datasources.helloworld

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{ BaseRelation, DataSourceRegister, RelationProvider, TableScan }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ Row, SQLContext }

/**
 * Relation provider registered under the short name "helloworld".
 *
 * Equality is type-based: any two instances of this class compare equal.
 */
class HelloWorldDataSource extends RelationProvider with DataSourceRegister with Serializable {
  override def shortName(): String = "helloworld"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: scala.Any): Boolean = other.isInstanceOf[HelloWorldDataSource]

  override def toString: String = "HelloWorldDataSource"

  /**
   * Creates the relation for the given options.
   *
   * @param sqlContext context the relation is bound to
   * @param parameters data source options; "path" is mandatory
   * @throws IllegalArgumentException if the "path" option is missing
   */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    // The message used to mention "Tickets datasets", left over from another source.
    val path = parameters.getOrElse(
      "path",
      throw new IllegalArgumentException("Path is required for HelloWorld datasets")
    )
    new HelloWorldRelationProvider(sqlContext, path, parameters)
  }
}

/**
 * Table-scan relation that exposes a fixed set of key/value rows built from
 * the path and the data source options.
 */
class HelloWorldRelationProvider(val sqlContext: SQLContext, path: String, parameters: Map[String, String]) extends BaseRelation with TableScan {
  import sqlContext.implicits._

  /** Two string columns: a non-nullable "key" and a nullable "value". */
  override def schema: StructType = {
    val keyField = StructField("key", StringType, nullable = false)
    val valueField = StructField("value", StringType, nullable = true)
    StructType(Array(keyField, valueField))
  }

  /** Produces four rows: the path, the "message" option, a greeting for the "name" option, and a constant. */
  override def buildScan(): RDD[Row] = {
    val message = parameters.getOrElse("message", "")
    val greeting = s"Hello ${parameters.getOrElse("name", "")}"
    val pairs = Seq(
      ("path", path),
      ("message", message),
      ("name", greeting),
      ("hello_world", "Hello World!")
    )
    pairs.toDF.rdd
  }
}
Example 21
Source File: CurrentEventsByPersistenceIdQuerySourceProvider.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{ SQLContext, _ }

/** Constants shared by the current-events-by-persistence-id source. */
object CurrentEventsByPersistenceIdQuerySourceProvider {

  /** Name used both as the short name and as the reported source name. */
  val name: String = "current-events-by-persistence-id"
}

/**
 * Streaming source provider registered under
 * `CurrentEventsByPersistenceIdQuerySourceProvider.name`.
 *
 * A user-specified schema is mandatory for this source.
 */
class CurrentEventsByPersistenceIdQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {

  // Replaces a bare `schema.get`, which failed with the uninformative "None.get".
  // The exception type (NoSuchElementException) is unchanged.
  private def requiredSchema(schema: Option[StructType]): StructType =
    schema.getOrElse(
      throw new NoSuchElementException(
        s"The '${CurrentEventsByPersistenceIdQuerySourceProvider.name}' source requires a user-specified schema"
      )
    )

  /**
   * Returns the source name together with the user-specified schema.
   *
   * @throws NoSuchElementException if no schema was specified
   */
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) = {
    println(s"[CurrentEventsByPersistenceIdQuerySourceProvider.sourceSchema]: schema: $schema, providerName: $providerName, parameters: $parameters")
    CurrentEventsByPersistenceIdQuerySourceProvider.name -> requiredSchema(schema)
  }

  /**
   * Creates the source.
   *
   * Required options: "event-mapper", "path", and one of "pid" / "persistence-id"
   * ("pid" wins when both are present).
   *
   * @throws RuntimeException       if "event-mapper" or the persistence id is missing
   * @throws NoSuchElementException if "path" or the schema is missing
   */
  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {

    val eventMapperFQCN: String =
      parameters.getOrElse("event-mapper", throw new RuntimeException("No event mapper FQCN"))

    val pid: String = parameters
      .get("pid")
      .orElse(parameters.get("persistence-id"))
      .getOrElse(throw new RuntimeException("No persistence_id"))

    // "path" is looked up before the schema, as in the original argument order.
    val readJournalPluginId = parameters("path")

    new CurrentEventsByPersistenceIdQuerySourceImpl(sqlContext, readJournalPluginId, eventMapperFQCN, pid, requiredSchema(schema))
  }

  override def shortName(): String = CurrentEventsByPersistenceIdQuerySourceProvider.name
}

/**
 * Streaming source that serves the events of a single persistence id.
 * Offsets and batches are obtained from the helpers mixed in via `ReadJournalSource`.
 */
class CurrentEventsByPersistenceIdQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String, eventMapperFQCN: String, persistenceId: String, override val schema: StructType) extends Source with ReadJournalSource {

  /** Reports the value of `maxEventsByPersistenceId` for this persistence id as the latest offset. */
  override def getOffset: Option[Offset] = {
    val highest = maxEventsByPersistenceId(persistenceId)
    println("[CurrentEventsByPersistenceIdQuery]: Returning maximum offset: " + highest)
    val latest = LongOffset(highest)
    Some(latest)
  }

  /** Fetches the events between the two offsets; the log line also evaluates `count` on the batch. */
  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    val (start, end) = getStartEnd(_start, _end)
    val batch: DataFrame = eventsByPersistenceId(persistenceId, start, end, eventMapperFQCN)
    println(s"[CurrentEventsByPersistenceIdQuery]: Getting currentPersistenceIds from start: $start, end: $end, DataFrame.count: ${batch.count}")
    batch
  }
}
Example 22
Source File: CurrentPersistenceIdsQuerySourceProvider.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ SQLContext, _ }

/** Constants shared by the current-persistence-id source. */
object CurrentPersistenceIdsQuerySourceProvider {

  /** Name used both as the short name and as the reported source name. */
  val name: String = "current-persistence-id"

  /** Fixed schema: a single non-nullable string column "persistence_id". */
  val schema: StructType = {
    val persistenceIdField = StructField("persistence_id", StringType, nullable = false)
    StructType(Array(persistenceIdField))
  }
}

/**
 * Streaming source provider registered under
 * `CurrentPersistenceIdsQuerySourceProvider.name`. The schema is fixed, so any
 * user-specified schema is ignored.
 */
class CurrentPersistenceIdsQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {

  override def shortName(): String = CurrentPersistenceIdsQuerySourceProvider.name

  /** Returns the fixed name/schema pair declared on the companion object. */
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) =
    (CurrentPersistenceIdsQuerySourceProvider.name, CurrentPersistenceIdsQuerySourceProvider.schema)

  /** Creates the source; the mandatory "path" option is passed on as the read-journal plugin id. */
  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {
    val readJournalPluginId = parameters("path")
    new CurrentPersistenceIdsQuerySourceImpl(sqlContext, readJournalPluginId)
  }
}

/**
 * Streaming source that serves persistence ids, using the helpers mixed in
 * via `ReadJournalSource`.
 */
class CurrentPersistenceIdsQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String) extends Source with ReadJournalSource {

  override def schema: StructType = CurrentPersistenceIdsQuerySourceProvider.schema

  /** Reports the value of `maxPersistenceIds` as the latest offset. */
  override def getOffset: Option[Offset] = {
    val highest = maxPersistenceIds
    println("[CurrentPersistenceIdsQuery]: Returning maximum offset: " + highest)
    val latest = LongOffset(highest)
    Some(latest)
  }

  /** Returns the persistence ids between the two offsets as a DataFrame. */
  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    import sqlContext.implicits._

    val (start, end) = getStartEnd(_start, _end)
    println(s"[CurrentPersistenceIdsQuery]: Getting currentPersistenceIds from start: $start, end: $end")
    val ids = persistenceIds(start, end)
    ids.toDF()
  }
}
Example 23
Source File: DefaultSource.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.sql

import org.apache.spark.sql.{ DataFrame, SQLContext, SaveMode }
import org.apache.spark.sql.sources.{ BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider, SchemaRelationProvider }
import org.apache.spark.sql.types.StructType

import com.actian.spark_vector.util.Logging
import com.actian.spark_vector.vector.VectorJDBC

/**
 * Data source registered under the short name "vector". Supports reading with
 * or without a user-specified schema, and writing a DataFrame into a table.
 */
class DefaultSource extends DataSourceRegister with RelationProvider with SchemaRelationProvider with CreatableRelationProvider with Logging {
  override def shortName(): String = "vector"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    val tableRef = TableRef(parameters)
    VectorRelation(tableRef, sqlContext, parameters)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val tableRef = TableRef(parameters)
    VectorRelation(tableRef, Some(schema), sqlContext, parameters)
  }

  /**
   * Writes `data` into the target table according to `mode` and returns the relation.
   *
   *  - Overwrite: insert with the overwrite flag set
   *  - Append: insert without overwriting
   *  - ErrorIfExists: insert only if the table is empty, otherwise fail
   *  - Ignore: insert only if the table is empty, otherwise do nothing
   */
  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val tableRef = TableRef(parameters)
    val table = VectorRelation(tableRef, sqlContext, parameters)

    // Evaluated at most once per call, and only for the modes that need it.
    def tableIsEmpty: Boolean =
      VectorJDBC.withJDBC(tableRef.toConnectionProps) { _.isTableEmpty(tableRef.table) }

    mode match {
      case SaveMode.Overwrite =>
        table.insert(data, true)
      case SaveMode.Append =>
        table.insert(data, false)
      case SaveMode.ErrorIfExists if tableIsEmpty =>
        table.insert(data, false)
      case SaveMode.ErrorIfExists =>
        throw new UnsupportedOperationException("Writing to a non-empty Vector table is not allowed with mode ErrorIfExists.")
      case SaveMode.Ignore =>
        if (tableIsEmpty) {
          table.insert(data, false)
        }
    }

    table
  }
}
Example 24
Source File: HadoopFsRelation.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.{StructField, StructType}



/**
 * Relation over a set of files described by a [[FileIndex]].
 *
 * @param location        index of the files that make up the relation
 * @param partitionSchema schema of the partition columns
 * @param dataSchema      schema of the columns stored in the files
 * @param bucketSpec      optional bucketing specification
 * @param fileFormat      format used to read and write the files
 * @param options         data source options
 */
case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  // Column name used for comparisons: as-is under case-sensitive analysis,
  // otherwise lower-cased with the root locale.
  private def getColName(f: StructField): String = {
    val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
    if (caseSensitive) f.name else f.name.toLowerCase(Locale.ROOT)
  }

  // Partition fields whose (normalised) name also occurs in the data schema.
  val overlappedPartCols = mutable.Map.empty[String, StructField]
  for (partitionField <- partitionSchema) {
    val partitionColName = getColName(partitionField)
    if (dataSchema.exists(getColName(_) == partitionColName)) {
      overlappedPartCols += partitionColName -> partitionField
    }
  }

  // When data and partition schemas have overlapping columns, the output
  // schema respects the order of the data schema for the overlapping columns, and it
  // respects the data types of the partition schema.
  val schema: StructType = {
    val dataFields = dataSchema.map { f =>
      overlappedPartCols.getOrElse(getColName(f), f)
    }
    val remainingPartitionFields = partitionSchema.filterNot { f =>
      overlappedPartCols.contains(getColName(f))
    }
    StructType(dataFields ++ remainingPartitionFields)
  }

  /** The partition schema, or `None` when there are no partition columns. */
  def partitionSchemaOption: Option[StructType] =
    Some(partitionSchema).filterNot(_.isEmpty)

  /** Short name of the file format when it is registered, "HadoopFiles" otherwise. */
  override def toString: String = fileFormat match {
    case registered: DataSourceRegister => registered.shortName()
    case _ => "HadoopFiles"
  }

  /** Size reported by the file index, scaled by the configured compression factor. */
  override def sizeInBytes: Long = {
    val scaled = location.sizeInBytes * sqlContext.conf.fileCompressionFactor
    scaled.toLong
  }


  override def inputFiles: Array[String] = location.inputFiles
}
Example 25
Source File: HadoopFsRelation.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.StructType



/**
 * Relation over a set of files described by a [[FileCatalog]].
 *
 * @param location        catalog of the files that make up the relation
 * @param partitionSchema schema of the partition columns
 * @param dataSchema      schema of the columns stored in the files
 * @param bucketSpec      optional bucketing specification
 * @param fileFormat      format used to read and write the files
 * @param options         data source options
 */
case class HadoopFsRelation(
    location: FileCatalog,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  /**
   * Data columns followed by the partition columns whose name does not already
   * occur in the data schema (compared case-insensitively).
   */
  val schema: StructType = {
    // Lower-case with the root locale: the default-locale variant maps "I" to a
    // dotless i under e.g. a Turkish locale, which would break the name comparison.
    def normalized(name: String): String = name.toLowerCase(java.util.Locale.ROOT)

    val dataSchemaColumnNames = dataSchema.map(field => normalized(field.name)).toSet
    StructType(dataSchema ++ partitionSchema.filterNot { column =>
      dataSchemaColumnNames.contains(normalized(column.name))
    })
  }

  /** The partition schema, or `None` when there are no partition columns. */
  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  /** Short name of the file format when it is registered, "HadoopFiles" otherwise. */
  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = location.sizeInBytes

  override def inputFiles: Array[String] = location.inputFiles
}
Example 26
Source File: console.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/**
 * Streaming sink that prints every batch to standard output.
 *
 * Recognised options:
 *  - "numRows": number of rows to display (20 when absent)
 *  - "truncate": whether long cell values are truncated (true when absent)
 */
class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
  // Number of rows to display, by default 20 rows
  private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)

  // Truncate the displayed data if it is too long, by default it is true
  private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)

  // Track the batch id
  private var lastBatchId = -1L

  /**
   * Prints a banner with the batch id, then the batch contents.
   *
   * A batch whose id is not greater than the last one seen is labelled
   * "Rerun batch" and does not advance the tracked id.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    val batchIdStr = if (batchId <= lastBatchId) {
      s"Rerun batch: $batchId"
    } else {
      lastBatchId = batchId
      s"Batch: $batchId"
    }

    // scalastyle:off println
    println("-------------------------------------------")
    println(batchIdStr)
    println("-------------------------------------------")
    // scalastyle:on println
    // The batch is collected and re-parallelized into a fresh DataFrame before being shown.
    data.sparkSession.createDataFrame(
      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
      .show(numRowsToShow, isTruncated)
  }
}

/**
 * Streaming sink provider registered under the short name "console".
 */
class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister {

  def shortName(): String = "console"

  /** Creates a [[ConsoleSink]] configured with `parameters`; the other arguments are not used. */
  def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    val sink = new ConsoleSink(parameters)
    sink
  }
}
Example 27
Source File: S2SinkProvider.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.spark.sql.streaming

import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

import scala.collection.JavaConversions._

/**
 * Streaming sink provider registered under the short name "s2graph".
 */
class S2SinkProvider extends StreamSinkProvider with DataSourceRegister with Logger {

  override def shortName(): String = "s2graph"

  /**
   * Builds the sink configuration from `parameters`, falling back to the
   * configuration returned by `ConfigFactory.load()` for anything not supplied.
   */
  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    logger.info(s"S2SinkProvider options : ${parameters}")

    // Relies on the implicit Scala-to-Java map conversion already in scope.
    val overrides: Config = ConfigFactory.parseMap(parameters)
    val jobConf: Config = overrides.withFallback(ConfigFactory.load())
    logger.info(s"S2SinkProvider Configuration : ${jobConf.root().render(ConfigRenderOptions.concise())}")

    new S2SparkSqlStreamingSink(sqlContext.sparkSession, jobConf)
  }
}
Example 28
Source File: SelectJSONSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

/**
 * Schema-requiring relation provider registered under the short name "minioSelectJSON".
 */
class SelectJSONSource extends SchemaRelationProvider with DataSourceRegister {

  override def shortName(): String = "minioSelectJSON"

  /** Creates the relation; fails when the "path" option is missing. */
  override def createRelation(
      sqlContext: SQLContext,
      params: Map[String, String],
      schema: StructType): SelectJSONRelation = {
    val location = Some(checkPath(params))
    SelectJSONRelation(location, params, schema)(sqlContext)
  }

  // Returns the "path" option or aborts with an error naming the data format.
  private def checkPath(parameters: Map[String, String]): String =
    parameters.get("path") match {
      case Some(path) => path
      case None => sys.error("'path' must be specified for JSON data.")
    }
}
Example 29
Source File: SelectCSVSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

/**
 * Schema-requiring relation provider registered under the short name "minioSelectCSV".
 */
class SelectCSVSource extends SchemaRelationProvider with DataSourceRegister {

  override def shortName(): String = "minioSelectCSV"

  /** Creates the relation; fails when the "path" option is missing. */
  override def createRelation(
      sqlContext: SQLContext,
      params: Map[String, String],
      schema: StructType): SelectCSVRelation = {
    val location = Some(checkPath(params))
    SelectCSVRelation(location, params, schema)(sqlContext)
  }

  // Returns the "path" option or aborts with an error naming the data format.
  private def checkPath(parameters: Map[String, String]): String =
    parameters.get("path") match {
      case Some(path) => path
      case None => sys.error("'path' must be specified for CSV data.")
    }
}
Example 30
Source File: SelectParquetSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

/**
 * Schema-requiring relation provider registered under the short name "minioSelectParquet".
 */
class SelectParquetSource extends SchemaRelationProvider with DataSourceRegister {

  override def shortName(): String = "minioSelectParquet"

  /** Creates the relation; fails when the "path" option is missing. */
  override def createRelation(
      sqlContext: SQLContext,
      params: Map[String, String],
      schema: StructType): SelectParquetRelation = {
    val location = Some(checkPath(params))
    SelectParquetRelation(location, params, schema)(sqlContext)
  }

  // Returns the "path" option or aborts with an error naming the data format.
  private def checkPath(parameters: Map[String, String]): String =
    parameters.get("path") match {
      case Some(path) => path
      case None => sys.error("'path' must be specified for Parquet data.")
    }
}
Example 31
Source File: RedisStreamProvider.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import com.redislabs.provider.redis.util.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}


/**
 * Streaming source provider registered under the short name "redis".
 */
class RedisStreamProvider extends DataSourceRegister with StreamSourceProvider with Logging {

  override def shortName(): String = "redis"

  /**
   * Pairs the provider name with the user-specified schema, or with a
   * single string column "_id" when none was given.
   */
  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
                            providerName: String, parameters: Map[String, String]): (String, StructType) = {
    val effectiveSchema = schema match {
      case Some(userSchema) => userSchema
      case None => StructType(Seq(StructField("_id", StringType)))
    }
    (providerName, effectiveSchema)
  }

  /** Creates a `RedisSource` with the resolved schema and starts it before returning it. */
  override def createSource(sqlContext: SQLContext, metadataPath: String,
                            schema: Option[StructType], providerName: String,
                            parameters: Map[String, String]): Source = {
    val resolvedSchema = sourceSchema(sqlContext, schema, providerName, parameters)._2
    val redisSource = new RedisSource(sqlContext, metadataPath, Some(resolvedSchema), parameters)
    redisSource.start()
    redisSource
  }
}
Example 32
Source File: DefaultSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.hbase

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

/**
 * Extension of `DefaultSource` registered under the short name "hbase" that
 * also accepts a user-specified schema.
 */
class CustomedDefaultSource
  extends DefaultSource
  with DataSourceRegister
  with SchemaRelationProvider {

  override def shortName(): String = "hbase"

  /** Creates the relation; a null `schema` is passed on as `None`. */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    val userSchema = Option(schema)
    new CustomedHBaseRelation(parameters, userSchema)(sqlContext)
  }
}
Example 33
Source File: DataSourceV2StringFormat.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.v2

import org.apache.commons.lang3.StringUtils

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.DataSourceV2
import org.apache.spark.util.Utils


  /** Filter expressions for this source; shown as the "Filters" entry of `metadataString` when non-empty. */
  def pushedFilters: Seq[Expression]

  /** Display name of the source: its registered short name if it has one, else its simple class name. */
  private def sourceName: String = {
    source match {
      case named: DataSourceRegister =>
        named.shortName()
      case _ =>
        // source.getClass.getSimpleName can cause Malformed class name error,
        // call safer `Utils.getSimpleName` instead
        Utils.getSimpleName(source.getClass)
    }
  }

  /**
   * One-line description: source name, output attributes, then the pushed
   * filters and redacted options in parentheses when there are any. Each entry
   * value is abbreviated to 100 characters.
   */
  def metadataString: String = {
    val filterEntries: Seq[(String, String)] =
      if (pushedFilters.nonEmpty) {
        Seq("Filters" -> pushedFilters.mkString("[", ", ", "]"))
      } else {
        Seq.empty
      }

    // TODO: we should only display some standard options like path, table, etc.
    val optionEntries: Seq[(String, String)] =
      if (options.nonEmpty) {
        val rendered = Utils.redact(options).map {
          case (k, v) => s"$k=$v"
        }.mkString("[", ",", "]")
        Seq("Options" -> rendered)
      } else {
        Seq.empty
      }

    val entries = filterEntries ++ optionEntries

    val outputStr = Utils.truncatedString(output, "[", ", ", "]")

    val entriesStr =
      if (entries.isEmpty) {
        ""
      } else {
        val abbreviated = entries.map {
          case (key, value) => key + ": " + StringUtils.abbreviate(value, 100)
        }
        Utils.truncatedString(abbreviated, " (", ", ", ")")
      }

    s"$sourceName$outputStr$entriesStr"
  }
} 
Example 34
Source File: JdbcRelationProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

/**
 * Relation provider registered under the short name "jdbc". Supports both
 * reading a JDBC table as a relation and saving a DataFrame into one.
 */
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Creates a read relation: the schema is obtained through
   * `JDBCRelation.getSchema` and the partitions through
   * `JDBCRelation.columnPartition`, using the session's resolver and time zone.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val resolver = sqlContext.conf.resolver
    val timeZoneId = sqlContext.conf.sessionLocalTimeZone
    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
  }

  /**
   * Saves `df` into the target table according to `mode`, then returns a read
   * relation for the same parameters.
   *
   * If the table does not exist it is created and filled regardless of mode.
   * If it exists: Overwrite truncates or drops-and-recreates it, Append adds
   * rows, ErrorIfExists throws an `AnalysisException`, and Ignore does nothing.
   * The connection is closed in all cases.
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val options = new JdbcOptionsInWrite(parameters)
    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, options)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            // Truncation is used only when requested and when the check for
            // cascading truncation definitely answers "no" for this URL.
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, options)
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              saveTable(df, tableSchema, isCaseSensitive, options)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, options.table, options)
              createTable(conn, df, options)
              saveTable(df, Some(df.schema), isCaseSensitive, options)
            }

          case SaveMode.Append =>
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, tableSchema, isCaseSensitive, options)

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '${options.table}' already exists. " +
                s"SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(conn, df, options)
        saveTable(df, Some(df.schema), isCaseSensitive, options)
      }
    } finally {
      conn.close()
    }

    // Hand back a read relation over the table that was just written.
    createRelation(sqlContext, parameters)
  }
}
Example 35
Source File: DefaultSource.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb.datasource

import java.util.Optional

import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2.reader.DataSourceReader
import org.apache.spark.sql.sources.v2.writer.DataSourceWriter
import org.apache.spark.sql.sources.v2.{DataSourceOptions, ReadSupport, WriteSupport}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.slf4j.LoggerFactory

import scala.collection.JavaConverters._

/**
  * Data source entry point registered under the short name "dynamodb".
  *
  * Builds [[DynamoDataSourceReader]] / [[DynamoDataSourceWriter]] instances from the
  * options handed over by Spark. The parallelism passed to them is taken from the
  * `defaultparallelism` option when present, otherwise from the active session.
  */
class DefaultSource extends ReadSupport with WriteSupport with DataSourceRegister {

    private val logger = LoggerFactory.getLogger(this.getClass)

    override def shortName(): String = "dynamodb"

    /** Creates a reader that is given the caller-supplied schema. */
    override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = {
        val settings = options.asMap().asScala
        new DynamoDataSourceReader(resolveParallelism(settings), settings.toMap, Some(schema))
    }

    /** Creates a reader without passing a schema. */
    override def createReader(options: DataSourceOptions): DataSourceReader = {
        val settings = options.asMap().asScala
        new DynamoDataSourceReader(resolveParallelism(settings), settings.toMap)
    }

    /**
      * Creates a writer for the given schema.
      *
      * @throws IllegalArgumentException when `mode` is `SaveMode.Append` or `SaveMode.Overwrite`
      */
    override def createWriter(writeUUID: String, schema: StructType, mode: SaveMode, options: DataSourceOptions): Optional[DataSourceWriter] = {
        mode match {
            case SaveMode.Append | SaveMode.Overwrite =>
                throw new IllegalArgumentException(s"DynamoDB data source does not support save modes ($mode)." +
                    " Please use option 'update' (true | false) to differentiate between append/overwrite and append/update behavior.")
            case _ => // every other mode is accepted
        }
        val settings = options.asMap().asScala
        Optional.of(new DynamoDataSourceWriter(resolveParallelism(settings), settings.toMap, schema))
    }

    /** Reads the `defaultparallelism` option, falling back to the session value when it is absent. */
    private def resolveParallelism(settings: scala.collection.Map[String, String]): Int =
        settings.get("defaultparallelism").map(_.toInt).getOrElse(getDefaultParallelism)

    /** Default parallelism of the active session's SparkContext, or 1 (with a warning) if there is none. */
    private def getDefaultParallelism: Int =
        SparkSession.getActiveSession
            .map(_.sparkContext.defaultParallelism)
            .getOrElse {
                logger.warn("Unable to read defaultParallelism from SparkSession." +
                    " Parallelism will be 1 unless overwritten with option `defaultParallelism`")
                1
            }

}
Example 36
Source File: console.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql._
import org.apache.spark.sql.execution.streaming.sources.ConsoleWriter
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

/**
 * Relation wrapping an already materialised [[DataFrame]]; its schema is simply the
 * schema of that DataFrame.
 */
case class ConsoleRelation(
    override val sqlContext: SQLContext,
    data: DataFrame) extends BaseRelation {

  /** Delegates to the wrapped DataFrame. */
  override def schema: StructType = {
    data.schema
  }
}

/**
 * Provider registered under the short name "console". It supplies a [[ConsoleWriter]]
 * for streaming writes and, for batch writes, shows the DataFrame and returns a
 * [[ConsoleRelation]] over it.
 */
class ConsoleSinkProvider extends DataSourceV2
  with StreamWriteSupport
  with DataSourceRegister
  with CreatableRelationProvider {

  def shortName(): String = "console"

  /** Builds the streaming writer; `queryId` and `mode` are not used here. */
  override def createStreamWriter(
      queryId: String,
      schema: StructType,
      mode: OutputMode,
      options: DataSourceOptions): StreamWriter = new ConsoleWriter(schema, options)

  /**
   * Shows `data` and wraps it in a relation.
   *
   * Recognised parameters: "numRows" (rows to display, default 20) and "truncate"
   * (whether long values are truncated, default true). `mode` is not consulted.
   */
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val rowLimit = parameters.get("numRows").fold(20)(_.toInt)
    val shouldTruncate = parameters.get("truncate").fold(true)(_.toBoolean)

    data.show(rowLimit, shouldTruncate)
    ConsoleRelation(sqlContext, data)
  }
}
Example 37
Source File: RateStreamProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.sources

import java.util.Optional

import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader}
import org.apache.spark.sql.types._


  /**
   * Total value emitted by the end of second `seconds` for a source that ramps up
   * linearly to `rowsPerSecond` over `rampUpTimeSeconds` and stays constant afterwards.
   *
   * Worked example with rampUpTimeSeconds = 4 and rowsPerSecond = 10, which gives a
   * per-second speed increment of 10 / (4 + 1) = 2:
   *
   *   seconds   = 0 1 2  3  4  5  6
   *   speed     = 0 2 4  6  8 10 10
   *   end value = 0 2 6 12 20 30 40
   *
   * During ramp-up the end value is increment * seconds * (seconds + 1) / 2.
   */
  def valueAtSecond(seconds: Long, rowsPerSecond: Long, rampUpTimeSeconds: Long): Long = {
    val perSecondIncrement = rowsPerSecond / (rampUpTimeSeconds + 1)
    if (seconds > rampUpTimeSeconds) {
      // Past ramp-up: the total reached at the end of ramp-up (same formula evaluated at
      // seconds == rampUpTimeSeconds) plus full-speed output for the remaining seconds.
      val reachedAtRampUpEnd = valueAtSecond(rampUpTimeSeconds, rowsPerSecond, rampUpTimeSeconds)
      reachedAtRampUpEnd + (seconds - rampUpTimeSeconds) * rowsPerSecond
    } else {
      // Exactly one of seconds and (seconds + 1) is even; halve that one *before*
      // multiplying so the intermediate product stays as small as possible.
      val (halved, untouched) =
        if (seconds % 2 == 1) ((seconds + 1) / 2, seconds) else (seconds / 2, seconds + 1)
      halved * perSecondIncrement * untouched
    }
  }
} 
Example 38
Source File: HttpStreamSink.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.http

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.StreamSinkProvider
import org.apache.spark.sql.streaming.OutputMode

import Params.map2Params

/**
 * Sink provider registered under the short name "httpStream".
 *
 * Reads the "httpServletUrl" and "topic" parameters (both fetched with
 * `getRequiredString`) and the "maxPacketSize" parameter (default 10 * 1024 * 1024),
 * then builds an [[HttpStreamSink]] from them.
 */
class HttpStreamSinkProvider
		extends StreamSinkProvider with DataSourceRegister {
	def shortName(): String = "httpStream"

	def createSink(
		sqlContext: SQLContext,
		parameters: Map[String, String],
		partitionColumns: Seq[String],
		outputMode: OutputMode): Sink = {
		val servletUrl = parameters.getRequiredString("httpServletUrl")
		val topicName = parameters.getRequiredString("topic")
		val packetLimit = parameters.getInt("maxPacketSize", 10 * 1024 * 1024)
		new HttpStreamSink(servletUrl, topicName, packetLimit)
	}
}

/**
 * Sink that forwards every batch to an HTTP endpoint through an `HttpStreamClient`
 * producer, retrying failed sends a bounded number of times.
 *
 * @param httpPostURL   URL the producer connects to
 * @param topic         topic passed along with every batch
 * @param maxPacketSize packet size limit passed to the producer on every send
 */
class HttpStreamSink(httpPostURL: String, topic: String, maxPacketSize: Int)
		extends Sink with Logging {
	val producer = HttpStreamClient.connect(httpPostURL);
	/** Maximum number of send attempts per batch. */
	val RETRY_TIMES = 5;
	/** Base back-off in milliseconds; the wait before the next attempt is SLEEP_TIME * attempts so far. */
	val SLEEP_TIME = 100;

	/**
	 * Sends `data` to the HTTP server, making at most RETRY_TIMES attempts.
	 *
	 * Non-fatal failures are logged and retried after a linearly growing pause; the
	 * failure of the last attempt is rethrown. Fatal errors (as classified by
	 * `scala.util.control.NonFatal`, e.g. `InterruptedException` or VM errors) are
	 * not caught, so they propagate immediately instead of being retried.
	 */
	override def addBatch(batchId: Long, data: DataFrame): Unit = {
		import scala.util.control.NonFatal

		var success = false
		var retried = 0
		while (!success && retried < RETRY_TIMES) {
			retried += 1
			try {
				producer.sendDataFrame(topic, batchId, data, maxPacketSize)
				success = true
			} catch {
				case NonFatal(e) =>
					logWarning(s"failed to send", e)
					// Out of attempts: surface the last failure to the caller.
					if (retried >= RETRY_TIMES) {
						throw e
					}
					val sleepTime = SLEEP_TIME * retried
					logWarning(s"will retry to send after ${sleepTime}ms")
					Thread.sleep(sleepTime)
			}
		}
	}
}
Example 39
Source File: KustoSinkProvider.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark.datasink

import com.microsoft.kusto.spark.utils.{KeyVaultUtils, KustoDataSourceUtils}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/**
 * Streaming sink provider registered under the short name "KustoSink".
 */
class KustoSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def shortName(): String = "KustoSink"

  /**
   * Parses the sink parameters and builds a [[KustoSink]].
   *
   * When key-vault authentication is configured, the parameters fetched from the key
   * vault are merged with the authentication parameters given in the options;
   * otherwise the options' authentication parameters are used as they are.
   */
  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    val sinkParameters = KustoDataSourceUtils.parseSinkParameters(parameters)
    val sourceResults = sinkParameters.sourceParametersResults

    val authentication = sourceResults.keyVaultAuth match {
      case Some(keyVaultAuth) =>
        val paramsFromKeyVault = KeyVaultUtils.getAadAppParametersFromKeyVault(keyVaultAuth)
        KustoDataSourceUtils.mergeKeyVaultAndOptionsAuthentication(
          paramsFromKeyVault, Some(sourceResults.authenticationParameters))
      case None =>
        sourceResults.authenticationParameters
    }

    new KustoSink(
      sqlContext,
      sourceResults.kustoCoordinates,
      authentication,
      sinkParameters.writeOptions
    )
  }
}
Example 40
Source File: ArrowFileFormat.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package com.intel.oap.spark.sql.execution.datasources.arrow

import scala.collection.JavaConverters._

import com.intel.oap.spark.sql.execution.datasources.arrow.ArrowFileFormat.UnsafeItr
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.{ArrowFilters, ArrowOptions}
import com.intel.oap.spark.sql.execution.datasources.v2.arrow.ArrowSQLConf._
import org.apache.arrow.dataset.scanner.ScanOptions
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.execution.datasources.v2.arrow.ArrowUtils
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/**
 * Read-only [[FileFormat]] registered under the short name "arrow". Files are scanned
 * through the Arrow dataset API and results are produced in batches (see
 * `supportBatch`); writing is not supported.
 */
class ArrowFileFormat extends FileFormat with DataSourceRegister with Serializable {

  /** Batch size handed to the Arrow `ScanOptions`. */
  val batchSize = 4096

  /** Reads the schema of `files` via `ArrowUtils.readSchema`, treating option keys case-insensitively. */
  def convert(files: Seq[FileStatus], options: Map[String, String]): Option[StructType] = {
    ArrowUtils.readSchema(files, new CaseInsensitiveStringMap(options.asJava))
  }

  override def inferSchema(
                            sparkSession: SparkSession,
                            options: Map[String, String],
                            files: Seq[FileStatus]): Option[StructType] = {
    convert(files, options)
  }

  /** Always fails: this source cannot write. */
  override def prepareWrite(
                             sparkSession: SparkSession,
                             job: Job,
                             options: Map[String, String],
                             dataSchema: StructType): OutputWriterFactory = {
    throw new UnsupportedOperationException("Write is not supported for Arrow source")
  }

  override def supportBatch(sparkSession: SparkSession, dataSchema: StructType): Boolean = true

  /**
   * Returns the function that turns one [[PartitionedFile]] into an iterator over its
   * scanned content.
   *
   * The filter push-down flag is read from the session configuration once, here, rather
   * than inside the returned function: the function is invoked per file, and reading
   * the flag up front keeps `sparkSession` out of the closure (the built-in Spark
   * formats resolve their SQLConf values the same way).
   */
  override def buildReaderWithPartitionValues(sparkSession: SparkSession,
      dataSchema: StructType,
      partitionSchema: StructType,
      requiredSchema: StructType,
      filters: Seq[Filter],
      options: Map[String, String],
      hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
    val sqlConf = sparkSession.sessionState.conf
    val enableFilterPushDown = sqlConf.arrowFilterPushDown

    (file: PartitionedFile) => {
      val factory = ArrowUtils.makeArrowDiscovery(
        file.filePath, new ArrowOptions(
          new CaseInsensitiveStringMap(
            options.asJava).asScala.toMap))

      // todo predicate validation / pushdown
      val dataset = factory.finish()

      val filter = if (enableFilterPushDown) {
        ArrowFilters.translateFilters(filters)
      } else {
        org.apache.arrow.dataset.filter.Filter.EMPTY
      }

      val scanOptions = new ScanOptions(requiredSchema.map(f => f.name).toArray,
        filter, batchSize)
      val scanner = dataset.newScan(scanOptions)
      // The scan tasks are started eagerly (toList) before any of them is consumed.
      val itrList = scanner
        .scan()
        .iterator()
        .asScala
        .map(task => task.scan())
        .toList

      val itr = itrList
        .toIterator
        .flatMap(itr => itr.asScala)
        .map(vsr => ArrowUtils.loadVsr(vsr, file.partitionValues, partitionSchema, dataSchema))
      // The elements are whatever ArrowUtils.loadVsr returns; the cast only satisfies the
      // declared return type. NOTE(review): presumably columnar batches, in line with
      // supportBatch returning true — confirm against ArrowUtils.
      new UnsafeItr(itr).asInstanceOf[Iterator[InternalRow]]
    }
  }

  override def shortName(): String = "arrow"
}

object ArrowFileFormat {

  /**
   * Thin iterator wrapper that forwards `hasNext` and `next()` to `delegate`
   * without adding any behaviour of its own.
   */
  class UnsafeItr[T](delegate: Iterator[T]) extends Iterator[T] {

    override def next(): T = {
      delegate.next()
    }

    override def hasNext: Boolean = {
      delegate.hasNext
    }
  }
}
Example 41
Source File: CustomSink.scala    From spark-structured-streaming-ml   with Apache License 2.0 5 votes vote down vote up
package com.highperformancespark.examples.structuredstreaming

import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql._
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.execution.streaming.Sink


//tag::foreachDatasetSink[]

  /** Counts the distinct rows of the batch (via its RDD) and prints the result to stdout. */
  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    val batchDistinctCount = data.rdd.distinct().count()
    println(s"Batch ${batchId}'s distinct count is ${batchDistinctCount}")
  }
}
//end::basicSink[]
/** Demonstrates starting a streaming query that writes to the custom sink provider. */
object CustomSinkDemo {
  /**
   * Starts a streaming write of `ds` named "customSinkDemo", addressing the sink
   * provider by its fully qualified class name, and returns the started query.
   */
  def write(ds: Dataset[_]) = {
    //tag::customSinkDemo[]
    val sinkProvider =
      "com.highperformancespark.examples.structuredstreaming." +
        "BasicSinkProvider"
    ds.writeStream
      .format(sinkProvider)
      .queryName("customSinkDemo")
      .start()
    //end::customSinkDemo[]
  }
}
Example 42
Source File: JdbcRelationProvider.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

/**
 * Relation provider registered under the short name "jdbc". It supports both reading
 * (building a [[JDBCRelation]]) and writing a DataFrame to a JDBC table.
 */
class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  /**
   * Builds the relation used for reading. Partitioning information is only assembled
   * when a partition column is configured; otherwise `null` is passed to
   * `JDBCRelation.columnPartition`.
   */
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val column = jdbcOptions.partitionColumn
    val lower = jdbcOptions.lowerBound
    val upper = jdbcOptions.upperBound
    val partitionCount = jdbcOptions.numPartitions

    val partitioning =
      if (column != null) {
        JDBCPartitioningInfo(column, lower.toLong, upper.toLong, partitionCount.toInt)
      } else {
        null
      }
    JDBCRelation(JDBCRelation.columnPartition(partitioning), jdbcOptions)(sqlContext.sparkSession)
  }

  /**
   * Writes `df` to the configured table according to `mode`, then returns the relation
   * for reading it back.
   *
   *  - table missing: create it and save the data, whatever the mode;
   *  - `Overwrite`: truncate-and-load when truncation is requested and
   *    `isCascadingTruncateTable(url)` is `Some(false)`, otherwise drop, recreate and load;
   *  - `Append`: save into the existing table;
   *  - `ErrorIfExists`: fail with an `AnalysisException`;
   *  - `Ignore`: leave the existing table untouched.
   *
   * The connection is closed in all cases.
   */
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val url = jdbcOptions.url
    val table = jdbcOptions.table
    val createTableOptions = jdbcOptions.createTableOptions
    val isTruncate = jdbcOptions.isTruncate

    val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()

    // Shared by the "table missing" and "drop and recreate" paths.
    def createAndLoad(): Unit = {
      createTable(df.schema, url, table, createTableOptions, conn)
      saveTable(df, url, table, jdbcOptions)
    }

    try {
      if (!JdbcUtils.tableExists(conn, url, table)) {
        createAndLoad()
      } else {
        mode match {
          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // The table already exists: by contract nothing is saved and the existing
            // data is left unchanged, so only the relation below is returned.

          case SaveMode.Append =>
            saveTable(df, url, table, jdbcOptions)

          case SaveMode.Overwrite if isTruncate && isCascadingTruncateTable(url) == Some(false) =>
            // Truncation requested and known not to cascade: keep the table, replace its rows.
            truncateTable(conn, table)
            saveTable(df, url, table, jdbcOptions)

          case SaveMode.Overwrite =>
            // Otherwise drop the table and rebuild it from the DataFrame's schema.
            dropTable(conn, table)
            createAndLoad()
        }
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
}
Example 43
Source File: HadoopFsRelation.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.StructType



/**
 * A [[BaseRelation]] backed by files.
 *
 * @param location        index of the files; also supplies `sizeInBytes` and `inputFiles`
 * @param partitionSchema schema of the partition columns
 * @param dataSchema      schema of the columns in the data
 * @param bucketSpec      optional bucketing specification
 * @param fileFormat      format of the files; its short name is used by `toString` when
 *                        it is a [[DataSourceRegister]]
 * @param options         data source options
 */
case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  /**
   * Data columns followed by those partition columns whose name does not already occur
   * among the data columns (compared case-insensitively).
   *
   * Names are lower-cased with `Locale.ROOT` so the comparison does not depend on the
   * JVM default locale (e.g. under a Turkish locale "ID".toLowerCase is not "id").
   */
  val schema: StructType = {
    def normalized(name: String): String = name.toLowerCase(java.util.Locale.ROOT)

    val dataColumnNames = dataSchema.map(field => normalized(field.name)).toSet
    val extraPartitionColumns = partitionSchema.filterNot { column =>
      dataColumnNames.contains(normalized(column.name))
    }
    StructType(dataSchema ++ extraPartitionColumns)
  }

  /** The partition schema, or `None` when there are no partition columns. */
  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  /** Short name of the file format when it registers one, "HadoopFiles" otherwise. */
  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = location.sizeInBytes

  override def inputFiles: Array[String] = location.inputFiles
}
Example 44
Source File: console.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

/**
 * Sink that prints every batch to standard output.
 *
 * Recognised options: "numRows" (rows to display, default 20) and "truncate"
 * (whether long values are truncated, default true).
 */
class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
  // Number of rows to display, by default 20 rows
  private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)

  // Truncate the displayed data if it is too long, by default it is true
  private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)

  // Highest batch id printed so far; used to label re-executed batches
  private var lastBatchId = -1L

  /**
   * Prints a header for the batch followed by its rows.
   *
   * A batch whose id is not greater than the last one seen is labelled "Rerun batch".
   * The method is synchronized on this sink.
   */
  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    val batchIdStr = if (batchId <= lastBatchId) {
      s"Rerun batch: $batchId"
    } else {
      lastBatchId = batchId
      s"Batch: $batchId"
    }

    // scalastyle:off println
    println("-------------------------------------------")
    println(batchIdStr)
    println("-------------------------------------------")
    // scalastyle:on println
    // The rows are collected and re-parallelized into a new DataFrame with the same
    // schema, and that new DataFrame is what gets shown.
    data.sparkSession.createDataFrame(
      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
      .show(numRowsToShow, isTruncated)
  }
}

/**
 * Streaming sink provider registered under the short name "console"; every sink it
 * creates is a [[ConsoleSink]] configured from the given parameters.
 */
class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister {

  def shortName(): String = "console"

  /** Only `parameters` is used; the remaining arguments are ignored. */
  def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = new ConsoleSink(parameters)
}