org.apache.spark.sql.SQLContext Scala Examples

The following examples show how to use org.apache.spark.sql.SQLContext. Each example is excerpted from the open-source project and source file named above its code.
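Before Spark 2.0, SQLContext was the main entry point for DataFrame and SQL functionality; in Spark 2.x it is wrapped by SparkSession, which is how most of the examples below obtain it. A minimal, self-contained sketch (object and column names are illustrative):

import org.apache.spark.sql.{SparkSession, SQLContext}

object SQLContextQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SQLContextQuickStart")
      .getOrCreate()

    // SparkSession keeps a SQLContext around for backward compatibility.
    val sqlContext: SQLContext = spark.sqlContext
    // The older companion also works against a bare SparkContext:
    // val legacy = SQLContext.getOrCreate(spark.sparkContext)

    import spark.implicits._
    val df = Seq(("a", 1), ("b", 2)).toDF("key", "value")
    df.createOrReplaceTempView("kv")

    sqlContext.sql("SELECT key, value FROM kv WHERE value > 1").show()
    spark.stop()
  }
}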
Example 1
Source File: DefaultSource.scala    From spark-snowflake   with Apache License 2.0
package net.snowflake.spark.snowflake

import net.snowflake.spark.snowflake.streaming.SnowflakeSink
import net.snowflake.spark.snowflake.Utils.SNOWFLAKE_SOURCE_SHORT_NAME
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.slf4j.LoggerFactory


  override def createRelation(sqlContext: SQLContext,
                              saveMode: SaveMode,
                              parameters: Map[String, String],
                              data: DataFrame): BaseRelation = {

    val params = Parameters.mergeParameters(parameters)
    // check spark version for push down
    if (params.autoPushdown) {
      SnowflakeConnectorUtils.checkVersionAndEnablePushdown(
        sqlContext.sparkSession
      )
    }
    // pass parameters to pushdown functions
    pushdowns.setGlobalParameter(params)
    val table = params.table.getOrElse {
      throw new IllegalArgumentException(
        "For save operations you must specify a Snowfake table name with the 'dbtable' parameter"
      )
    }

    def tableExists: Boolean = {
      val conn = jdbcWrapper.getConnector(params)
      try {
        jdbcWrapper.tableExists(conn, table.toString)
      } finally {
        conn.close()
      }
    }

    val (doSave, dropExisting) = saveMode match {
      case SaveMode.Append => (true, false)
      case SaveMode.Overwrite => (true, true)
      case SaveMode.ErrorIfExists =>
        if (tableExists) {
          sys.error(
            s"Table $table already exists! (SaveMode is set to ErrorIfExists)"
          )
        } else {
          (true, false)
        }
      case SaveMode.Ignore =>
        if (tableExists) {
          log.info(s"Table $table already exists -- ignoring save request.")
          (false, false)
        } else {
          (true, false)
        }
    }

    if (doSave) {
      val updatedParams = parameters.updated("overwrite", dropExisting.toString)
      new SnowflakeWriter(jdbcWrapper)
        .save(
          sqlContext,
          data,
          saveMode,
          Parameters.mergeParameters(updatedParams)
        )

    }

    createRelation(sqlContext, parameters)
  }

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink =
    new SnowflakeSink(sqlContext, parameters, partitionColumns, outputMode)
} 
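For context, the createRelation overload above is what a DataFrame write dispatches to. A hedged usage sketch, assuming the connector's full format name matches its package and the usual sfURL/sfUser/sfPassword/sfDatabase/sfSchema option keys:

import org.apache.spark.sql.{SaveMode, SparkSession}

object SnowflakeWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("SnowflakeWriteSketch").getOrCreate()
    import spark.implicits._

    // Connection values are placeholders; consult the connector documentation for your account.
    val sfOptions = Map(
      "sfURL" -> "myaccount.snowflakecomputing.com",
      "sfUser" -> "user",
      "sfPassword" -> "***",
      "sfDatabase" -> "db",
      "sfSchema" -> "public"
    )

    val df = Seq(("a", 1), ("b", 2)).toDF("key", "value")

    // Without 'dbtable' the provider above throws IllegalArgumentException.
    df.write
      .format("net.snowflake.spark.snowflake")
      .options(sfOptions)
      .option("dbtable", "my_table")
      .mode(SaveMode.Overwrite)
      .save()
  }
}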
Example 2
Source File: JdbcRelationProvider.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val partitionColumn = jdbcOptions.partitionColumn
    val lowerBound = jdbcOptions.lowerBound
    val upperBound = jdbcOptions.upperBound
    val numPartitions = jdbcOptions.numPartitions

    val partitionInfo = if (partitionColumn == null) {
      null
    } else {
      JDBCPartitioningInfo(
        partitionColumn, lowerBound.toLong, upperBound.toLong, numPartitions.toInt)
    }
    val parts = JDBCRelation.columnPartition(partitionInfo)
    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
  }

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val url = jdbcOptions.url
    val table = jdbcOptions.table
    val createTableOptions = jdbcOptions.createTableOptions
    val isTruncate = jdbcOptions.isTruncate

    val conn = JdbcUtils.createConnectionFactory(jdbcOptions)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, url, table)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            if (isTruncate && isCascadingTruncateTable(url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, table)
              saveTable(df, url, table, jdbcOptions)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, table)
              createTable(df.schema, url, table, createTableOptions, conn)
              saveTable(df, url, table, jdbcOptions)
            }

          case SaveMode.Append =>
            saveTable(df, url, table, jdbcOptions)

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '$table' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(df.schema, url, table, createTableOptions, conn)
        saveTable(df, url, table, jdbcOptions)
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
} 
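The SaveMode.Overwrite branch above truncates and reloads only when the truncate option is set and the dialect does not cascade truncation; otherwise it drops and recreates the table. A sketch of a write that exercises it (connection details are placeholders):

import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("JdbcOverwriteSketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")

    df.write
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost:5432/testdb") // placeholder connection string
      .option("dbtable", "public.people")
      .option("user", "test")
      .option("password", "test")
      .option("truncate", "true") // keep the existing table: TRUNCATE + reload on Overwrite
      .mode(SaveMode.Overwrite)
      .save()
  }
}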
Example 3
Source File: MLlibTestSparkContext.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.mllib.util

import java.io.File

import org.scalatest.Suite

import org.apache.spark.SparkContext
import org.apache.spark.ml.util.TempDirectory
import org.apache.spark.sql.{SparkSession, SQLContext, SQLImplicits}
import org.apache.spark.util.Utils

trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
  @transient var spark: SparkSession = _
  @transient var sc: SparkContext = _
  @transient var checkpointDir: String = _

  override def beforeAll() {
    super.beforeAll()
    spark = SparkSession.builder
      .master("local[2]")
      .appName("MLlibUnitTest")
      .getOrCreate()
    sc = spark.sparkContext

    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
    sc.setCheckpointDir(checkpointDir)
  }

  override def afterAll() {
    try {
      Utils.deleteRecursively(new File(checkpointDir))
      SparkSession.clearActiveSession()
      if (spark != null) {
        spark.stop()
      }
      spark = null
    } finally {
      super.afterAll()
    }
  }

  
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self.spark.sqlContext
  }
} 
Example 4
Source File: SparkSQLCLIService.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.io.IOException
import java.util.{List => JList}
import javax.security.auth.login.LoginException

import scala.collection.JavaConverters._

import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.shims.Utils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hive.service.{AbstractService, Service, ServiceException}
import org.apache.hive.service.Service.STATE
import org.apache.hive.service.auth.HiveAuthFactory
import org.apache.hive.service.cli._
import org.apache.hive.service.server.HiveServer2

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._

private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLContext)
  extends CLIService(hiveServer)
  with ReflectedCompositeService {

  override def init(hiveConf: HiveConf) {
    setSuperField(this, "hiveConf", hiveConf)

    val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, sqlContext)
    setSuperField(this, "sessionManager", sparkSqlSessionManager)
    addService(sparkSqlSessionManager)
    var sparkServiceUGI: UserGroupInformation = null

    if (UserGroupInformation.isSecurityEnabled) {
      try {
        HiveAuthFactory.loginFromKeytab(hiveConf)
        sparkServiceUGI = Utils.getUGI()
        setSuperField(this, "serviceUGI", sparkServiceUGI)
      } catch {
        case e @ (_: IOException | _: LoginException) =>
          throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
      }
    }

    initCompositeService(hiveConf)
  }

  override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
    getInfoType match {
      case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
      case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
      case GetInfoType.CLI_DBMS_VER => new GetInfoValue(sqlContext.sparkContext.version)
      case _ => super.getInfo(sessionHandle, getInfoType)
    }
  }
}

private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
  def initCompositeService(hiveConf: HiveConf) {
    // Emulating `CompositeService.init(hiveConf)`
    val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
    serviceList.asScala.foreach(_.init(hiveConf))

    // Emulating `AbstractService.init(hiveConf)`
    invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
    setAncestorField(this, 3, "hiveConf", hiveConf)
    invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
    getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
  }
} 
Example 5
Source File: SparkSQLDriver.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.QueryExecution


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = execution.hiveResultString()
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 6
Source File: SparkSQLSessionManager.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.concurrent.Executors

import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.SessionHandle
import org.apache.hive.service.cli.session.SessionManager
import org.apache.hive.service.cli.thrift.TProtocolVersion
import org.apache.hive.service.server.HiveServer2

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils}
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager


private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: SQLContext)
  extends SessionManager(hiveServer)
  with ReflectedCompositeService {

  private lazy val sparkSqlOperationManager = new SparkSQLOperationManager()

  override def init(hiveConf: HiveConf) {
    setSuperField(this, "hiveConf", hiveConf)

    // Create operation log root directory, if operation logging is enabled
    if (hiveConf.getBoolVar(ConfVars.HIVE_SERVER2_LOGGING_OPERATION_ENABLED)) {
      invoke(classOf[SessionManager], this, "initOperationLogRootDir")
    }

    val backgroundPoolSize = hiveConf.getIntVar(ConfVars.HIVE_SERVER2_ASYNC_EXEC_THREADS)
    setSuperField(this, "backgroundOperationPool", Executors.newFixedThreadPool(backgroundPoolSize))
    getAncestorField[Log](this, 3, "LOG").info(
      s"HiveServer2: Async execution pool size $backgroundPoolSize")

    setSuperField(this, "operationManager", sparkSqlOperationManager)
    addService(sparkSqlOperationManager)

    initCompositeService(hiveConf)
  }

  override def openSession(
      protocol: TProtocolVersion,
      username: String,
      passwd: String,
      ipAddress: String,
      sessionConf: java.util.Map[String, String],
      withImpersonation: Boolean,
      delegationToken: String): SessionHandle = {
    val sessionHandle =
      super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation,
          delegationToken)
    val session = super.getSession(sessionHandle)
    HiveThriftServer2.listener.onSessionCreated(
      session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
    val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState]
    val ctx = if (sessionState.hiveThriftServerSingleSession) {
      sqlContext
    } else {
      sqlContext.newSession()
    }
    ctx.setConf("spark.sql.hive.version", HiveUtils.hiveExecutionVersion)
    if (sessionConf != null && sessionConf.containsKey("use:database")) {
      ctx.sql(s"use ${sessionConf.get("use:database")}")
    }
    sparkSqlOperationManager.sessionToContexts.put(sessionHandle, ctx)
    sessionHandle
  }

  override def closeSession(sessionHandle: SessionHandle) {
    HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
    super.closeSession(sessionHandle)
    sparkSqlOperationManager.sessionToActivePool.remove(sessionHandle)
    sparkSqlOperationManager.sessionToContexts.remove(sessionHandle)
  }
} 
Example 7
Source File: SparkSQLOperationManager.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver.server

import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveSessionState
import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}


private[thriftserver] class SparkSQLOperationManager()
  extends OperationManager with Logging {

  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]()
  val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]()

  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {
    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
      s" initialized or has already been closed.")
    val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState]
    val runInBackground = async && sessionState.hiveThriftServerAsync
    val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
      runInBackground)(sqlContext, sessionToActivePool)
    handleToOperation.put(operation.getHandle, operation)
    logDebug(s"Created Operation for $statement with session=$parentSession, " +
      s"runInBackground=$runInBackground")
    operation
  }
} 
Example 8
Source File: SparkSQLEnv.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveSessionState, HiveUtils}
import org.apache.spark.util.Utils


  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 9
Source File: HadoopFsRelation.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.StructType



case class HadoopFsRelation(
    location: FileCatalog,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  val schema: StructType = {
    val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet
    StructType(dataSchema ++ partitionSchema.filterNot { column =>
      dataSchemaColumnNames.contains(column.name.toLowerCase)
    })
  }

  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = location.sizeInBytes

  override def inputFiles: Array[String] = location.inputFiles
} 
Example 10
Source File: console.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

class ConsoleSink(options: Map[String, String]) extends Sink with Logging {
  // Number of rows to display, by default 20 rows
  private val numRowsToShow = options.get("numRows").map(_.toInt).getOrElse(20)

  // Truncate the displayed data if it is too long, by default it is true
  private val isTruncated = options.get("truncate").map(_.toBoolean).getOrElse(true)

  // Track the batch id
  private var lastBatchId = -1L

  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
    val batchIdStr = if (batchId <= lastBatchId) {
      s"Rerun batch: $batchId"
    } else {
      lastBatchId = batchId
      s"Batch: $batchId"
    }

    // scalastyle:off println
    println("-------------------------------------------")
    println(batchIdStr)
    println("-------------------------------------------")
    // scalastyle:on println
    data.sparkSession.createDataFrame(
      data.sparkSession.sparkContext.parallelize(data.collect()), data.schema)
      .show(numRowsToShow, isTruncated)
  }
}

class ConsoleSinkProvider extends StreamSinkProvider with DataSourceRegister {
  def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new ConsoleSink(parameters)
  }

  def shortName(): String = "console"
} 
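The numRows and truncate options read by ConsoleSink above map directly onto DataFrame.show(). A usage sketch with a placeholder socket source:

import org.apache.spark.sql.SparkSession

object ConsoleSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("ConsoleSinkSketch")
      .getOrCreate()

    // Placeholder source: lines from a local socket (feed it with `nc -lk 9999`).
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost")
      .option("port", 9999)
      .load()

    // numRows and truncate are the options parsed in ConsoleSink above.
    val query = lines.writeStream
      .format("console")
      .option("numRows", "50")
      .option("truncate", "false")
      .start()

    query.awaitTermination()
  }
}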
Example 11
Source File: package.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.streaming

import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

package object state {

  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {

    
    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
        checkpointLocation: String,
        operatorId: Long,
        storeVersion: Long,
        keySchema: StructType,
        valueSchema: StructType,
        sessionState: SessionState,
        storeCoordinator: Option[StateStoreCoordinatorRef])(
        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {
      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
      new StateStoreRDD(
        dataRDD,
        cleanedF,
        checkpointLocation,
        operatorId,
        storeVersion,
        keySchema,
        valueSchema,
        sessionState,
        storeCoordinator)
    }
  }
} 
Example 12
Source File: DDLSourceLoadSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Note: the test resources' META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file
// has to list these fake sources so that Spark's ServiceLoader-based lookup can find them.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
Example 13
Source File: SharedSQLContext.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.test

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}



  protected override def afterAll(): Unit = {
    try {
      if (_spark != null) {
        _spark.stop()
        _spark = null
      }
    } finally {
      super.afterAll()
    }
  }

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    DebugFilesystem.assertNoOpenStreams()
  }
} 
Example 14
Source File: DLClassifierLeNet.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.example.MLPipeline

import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.dataset.image.{BytesToGreyImg, GreyImgNormalizer, GreyImgToBatch}
import com.intel.analytics.bigdl.dataset.{DataSet, DistributedDataSet, MiniBatch, _}
import com.intel.analytics.bigdl.dlframes.DLClassifier
import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.models.lenet.Utils._
import com.intel.analytics.bigdl.nn.ClassNLLCriterion
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext


object DLClassifierLeNet {

  LoggerFilter.redirectSparkInfoLogs()

  def main(args: Array[String]): Unit = {
    val inputs = Array[String]("Feature data", "Label data")
    trainParser.parse(args, new TrainParams()).foreach(param => {
      val conf = Engine.createSparkConf()
        .setAppName("MLPipeline Example")
        .set("spark.task.maxFailures", "1")
      val sc = new SparkContext(conf)
      val sqLContext = SQLContext.getOrCreate(sc)
      Engine.init

      val trainData = param.folder + "/train-images-idx3-ubyte"
      val trainLabel = param.folder + "/train-labels-idx1-ubyte"
      val validationData = param.folder + "/t10k-images-idx3-ubyte"
      val validationLabel = param.folder + "/t10k-labels-idx1-ubyte"

      val trainSet = DataSet.array(load(trainData, trainLabel), sc) ->
        BytesToGreyImg(28, 28) -> GreyImgNormalizer(trainMean, trainStd) -> GreyImgToBatch(1)

      val trainingRDD : RDD[Data[Float]] = trainSet.
        asInstanceOf[DistributedDataSet[MiniBatch[Float]]].data(false).map(batch => {
          val feature = batch.getInput().asInstanceOf[Tensor[Float]]
          val label = batch.getTarget().asInstanceOf[Tensor[Float]]
          Data[Float](feature.storage().array(), label.storage().array())
        })
      val trainingDF = sqLContext.createDataFrame(trainingRDD).toDF(inputs: _*)

      val model = LeNet5(classNum = 10)
      val criterion = ClassNLLCriterion[Float]()
      val featureSize = Array(28, 28)
      val estimator = new DLClassifier[Float](model, criterion, featureSize)
        .setFeaturesCol(inputs(0))
        .setLabelCol(inputs(1))
        .setBatchSize(param.batchSize)
        .setMaxEpoch(param.maxEpoch)
      val transformer = estimator.fit(trainingDF)

      val validationSet = DataSet.array(load(validationData, validationLabel), sc) ->
        BytesToGreyImg(28, 28) -> GreyImgNormalizer(testMean, testStd) -> GreyImgToBatch(1)

      val validationRDD: RDD[Data[Float]] = validationSet.
        asInstanceOf[DistributedDataSet[MiniBatch[Float]]].data(false).map{batch =>
          val feature = batch.getInput().asInstanceOf[Tensor[Float]]
          val label = batch.getTarget().asInstanceOf[Tensor[Float]]
          Data[Float](feature.storage().array(), label.storage().array())
        }
      val validationDF = sqLContext.createDataFrame(validationRDD).toDF(inputs: _*)
      val transformed = transformer.transform(validationDF)
      transformed.show()
      sc.stop()
    })
  }
}

private case class Data[T](featureData : Array[T], labelData : Array[T]) 
Example 15
Source File: DLEstimatorMultiLabelLR.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.example.MLPipeline

import com.intel.analytics.bigdl.dlframes.DLEstimator
import com.intel.analytics.bigdl.nn._
import com.intel.analytics.bigdl.optim.LBFGS
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericDouble
import com.intel.analytics.bigdl.utils.Engine
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext


object DLEstimatorMultiLabelLR {

  def main(args: Array[String]): Unit = {
    val conf = Engine.createSparkConf()
      .setAppName("DLEstimatorMultiLabelLR")
      .setMaster("local[1]")
    val sc = new SparkContext(conf)
    val sqlContext = SQLContext.getOrCreate(sc)
    Engine.init

    val model = Sequential().add(Linear(2, 2))
    val criterion = MSECriterion()
    val estimator = new DLEstimator(model, criterion, Array(2), Array(2))
      .setOptimMethod(new LBFGS[Double]())
      .setLearningRate(1.0)
      .setBatchSize(4)
      .setMaxEpoch(10)
    val data = sc.parallelize(Seq(
      (Array(2.0, 1.0), Array(1.0, 2.0)),
      (Array(1.0, 2.0), Array(2.0, 1.0)),
      (Array(2.0, 1.0), Array(1.0, 2.0)),
      (Array(1.0, 2.0), Array(2.0, 1.0))))
    val df = sqlContext.createDataFrame(data).toDF("features", "label")
    val dlModel = estimator.fit(df)
    dlModel.transform(df).show(false)
  }
} 
Example 16
Source File: DLClassifierLogisticRegression.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.example.MLPipeline

import com.intel.analytics.bigdl.dlframes.DLClassifier
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, Linear, LogSoftMax, Sequential}
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import com.intel.analytics.bigdl.utils.Engine
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext


object DLClassifierLogisticRegression {

  def main(args: Array[String]): Unit = {
    val conf = Engine.createSparkConf()
      .setAppName("DLClassifierLogisticRegression")
      .setMaster("local[1]")
    val sc = new SparkContext(conf)
    val sqlContext = SQLContext.getOrCreate(sc)
    Engine.init

    val model = Sequential().add(Linear(2, 2)).add(LogSoftMax())
    val criterion = ClassNLLCriterion()
    val estimator = new DLClassifier(model, criterion, Array(2))
      .setBatchSize(4)
      .setMaxEpoch(10)
    val data = sc.parallelize(Seq(
      (Array(0.0, 1.0), 1.0),
      (Array(1.0, 0.0), 2.0),
      (Array(0.0, 1.0), 1.0),
      (Array(1.0, 0.0), 2.0)))
    val df = sqlContext.createDataFrame(data).toDF("features", "label")
    val dlModel = estimator.fit(df)
    dlModel.transform(df).show(false)
  }
} 
Example 17
Source File: ImagePredictor.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.example.imageclassification

import java.nio.file.Paths

import com.intel.analytics.bigdl.dataset.image._
import com.intel.analytics.bigdl.dlframes.DLClassifierModel
import com.intel.analytics.bigdl.example.imageclassification.MlUtils._
import com.intel.analytics.bigdl.numeric.NumericFloat
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext


object ImagePredictor {
  LoggerFilter.redirectSparkInfoLogs()
  Logger.getLogger("com.intel.analytics.bigdl.example").setLevel(Level.INFO)

  def main(args: Array[String]): Unit = {
    predictParser.parse(args, new PredictParams()).map(param => {
      val conf = Engine.createSparkConf()
      conf.setAppName("Predict with trained model")
      val sc = new SparkContext(conf)
      Engine.init
      val sqlContext = new SQLContext(sc)

      val partitionNum = Engine.nodeNumber() * Engine.coreNumber()
      val model = loadModel(param)
      val valTrans = new DLClassifierModel(model, Array(3, imageSize, imageSize))
        .setBatchSize(param.batchSize)
        .setFeaturesCol("features")
        .setPredictionCol("predict")

      val valRDD = if (param.isHdfs) {
        // load image set from hdfs
        imagesLoadSeq(param.folder, sc, param.classNum).coalesce(partitionNum, true)
      } else {
        // load image set from local
        val paths = LocalImageFiles.readPaths(Paths.get(param.folder), hasLabel = false)
        sc.parallelize(imagesLoad(paths, 256), partitionNum)
      }

      val transf = RowToByteRecords() ->
          BytesToBGRImg() ->
          BGRImgCropper(imageSize, imageSize) ->
          BGRImgNormalizer(testMean, testStd) ->
          BGRImgToImageVector()

      val valDF = transformDF(sqlContext.createDataFrame(valRDD), transf)

      valTrans.transform(valDF)
          .select("imageName", "predict")
          .collect()
          .take(param.showNum)
          .foreach(println)
      sc.stop()
    })
  }
} 
Example 18
Source File: ImageInference.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.example.dlframes.imageInference

import com.intel.analytics.bigdl.dlframes.{DLClassifierModel, DLModel}
import org.apache.spark.sql.DataFrame
import scopt.OptionParser
import com.intel.analytics.bigdl.dataset.Sample
import com.intel.analytics.bigdl.nn.Module
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric.NumericFloat
import com.intel.analytics.bigdl.transform.vision.image.augmentation._
import com.intel.analytics.bigdl.transform.vision.image._
import com.intel.analytics.bigdl.utils.Engine
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

object ImageInference {

  def main(args: Array[String]): Unit = {

    val defaultParams = Utils.LocalParams()
    Utils.parser.parse(args, defaultParams).map { params =>

      val conf = Engine.createSparkConf().setAppName("ModelInference")
      val sc = SparkContext.getOrCreate(conf)
      val sqlContext = new SQLContext(sc)
      Engine.init

      val imagesDF = Utils.loadImages(params.folder, params.batchSize, sqlContext).cache()

      imagesDF.show(10)
      imagesDF.printSchema()

      val model = Module.loadCaffeModel[Float](params.caffeDefPath, params.modelPath)
      val dlmodel: DLModel[Float] = new DLClassifierModel[Float](
        model, Array(3, 224, 224))
        .setBatchSize(params.batchSize)
        .setFeaturesCol("features")
        .setPredictionCol("prediction")

      val count = imagesDF.count().toInt
      val tranDF = dlmodel.transform(imagesDF.limit(count))

      tranDF.select("imageName", "prediction").show(100, false)
    }
  }
}

object Utils {

  case class LocalParams(caffeDefPath: String = " ",
                         modelPath: String = " ",
                         folder: String = " ",
                         batchSize: Int = 16,
                         nEpochs: Int = 10
                        )

  val defaultParams = LocalParams()

  val parser = new OptionParser[LocalParams]("BigDL Example") {
    opt[String]("caffeDefPath")
      .text(s"caffeDefPath")
      .action((x, c) => c.copy(caffeDefPath = x))
    opt[String]("modelPath")
      .text(s"modelPath")
      .action((x, c) => c.copy(modelPath = x))
    opt[String]("folder")
      .text(s"folder")
      .action((x, c) => c.copy(folder = x))
    opt[Int]('b', "batchSize")
      .text(s"batchSize")
      .action((x, c) => c.copy(batchSize = x.toInt))
    opt[Int]('e', "nEpochs")
      .text("epoch numbers")
      .action((x, c) => c.copy(nEpochs = x))
  }

  def loadImages(path: String, partitionNum: Int, sqlContext: SQLContext): DataFrame = {

    val imageFrame: ImageFrame = ImageFrame.read(path, sqlContext.sparkContext)
    val transformer = Resize(256, 256) -> CenterCrop(224, 224) ->
      ChannelNormalize(123, 117, 104, 1, 1, 1) -> MatToTensor() -> ImageFrameToSample()
    val transformed: ImageFrame = transformer(imageFrame)
    val imageRDD = transformed.toDistributed().rdd.map { im =>
      (im.uri, im[Sample[Float]](ImageFeature.sample).getData())
    }
    val imageDF = sqlContext.createDataFrame(imageRDD)
      .withColumnRenamed("_1", "imageName")
      .withColumnRenamed("_2", "features")
    imageDF
  }

} 
Example 19
Source File: MLlibTestSparkContext.scala    From spark-lp   with Apache License 2.0
package org.apache.spark.mllib.util

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext


trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    sc.setLogLevel("WARN")
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }
} 
Example 20
Source File: TestSFObjectWriter.scala    From spark-salesforce   with Apache License 2.0
package com.springml.spark.salesforce

import org.mockito.Mockito._
import org.mockito.Matchers._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{ FunSuite, BeforeAndAfterEach}
import com.springml.salesforce.wave.api.BulkAPI
import org.apache.spark.{ SparkConf, SparkContext}
import com.springml.salesforce.wave.model.{ JobInfo, BatchInfo}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ Row, DataFrame, SQLContext}
import org.apache.spark.sql.types.{ StructType, StringType, StructField}

class TestSFObjectWriter extends FunSuite with MockitoSugar with BeforeAndAfterEach {
  val contact = "Contact";
  val jobId = "750B0000000WlhtIAC";
  val batchId = "751B0000000scSHIAY";
  val data = "Id,Description\n003B00000067Rnx,123456\n003B00000067Rnw,7890";

  val bulkAPI = mock[BulkAPI](withSettings().serializable())
  val writer = mock[SFObjectWriter]

  var sparkConf: SparkConf = _
  var sc: SparkContext = _

  override def beforeEach() {
    val jobInfo = new JobInfo
    jobInfo.setId(jobId)
    when(bulkAPI.createJob(contact)).thenReturn(jobInfo)

    val batchInfo = new BatchInfo
    batchInfo.setId(batchId)
    batchInfo.setJobId(jobId)
    when(bulkAPI.addBatch(jobId, data)).thenReturn(batchInfo)

    when(bulkAPI.closeJob(jobId)).thenReturn(jobInfo)
    when(bulkAPI.isCompleted(jobId)).thenReturn(true)

    sparkConf = new SparkConf().setMaster("local").setAppName("Test SF Object Update")
    sc = new SparkContext(sparkConf)
  }

  private def sampleDF() : DataFrame = {
    val rowArray = new Array[Row](2)
    val fieldArray = new Array[String](2)

    fieldArray(0) = "003B00000067Rnx"
    fieldArray(1) = "Desc1"
    rowArray(0) = Row.fromSeq(fieldArray)

    val fieldArray1 = new Array[String](2)
    fieldArray1(0) = "001B00000067Rnx"
    fieldArray1(1) = "Desc2"
    rowArray(1) = Row.fromSeq(fieldArray1)

    val rdd = sc.parallelize(rowArray)
    val schema = StructType(
      StructField("id", StringType, true) ::
      StructField("desc", StringType, true) :: Nil)

    val sqlContext = new SQLContext(sc)
    sqlContext.createDataFrame(rdd, schema)
  }

  test ("Write Object to Salesforce") {
    val df = sampleDF();
    val csvHeader = Utils.csvHeadder(df.schema)
    writer.writeData(df.rdd)
    sc.stop()
  }
} 
Example 21
Source File: SharedSparkSessionBase.scala    From spark-alchemy   with Apache License 2.0
package org.apache.spark.sql.test

import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.scalatest.Suite
import org.scalatest.concurrent.Eventually

import scala.concurrent.duration._



  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        if (_spark != null) {
          try {
            _spark.sessionState.catalog.reset()
          } finally {
            try {
              waitForTasksToFinish()
            } finally {
              _spark.stop()
              _spark = null
            }
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(30.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 22
Source File: SparkSessionSpec.scala    From spark-alchemy   with Apache License 2.0
package com.swoop.test_utils

import org.apache.logging.log4j.Level
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.{SharedSparkSessionBase, TestSparkSession}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.TestSuite

import scala.util.Try


trait SparkSessionSpec extends SharedSparkSessionBase {
  this: TestSuite =>

  override protected def createSparkSession: TestSparkSession = {
    val spark = super.createSparkSession
    configureLoggers(sparkLogLevel)
    spark
  }

  def sparkSession = spark

  def sqlc: SQLContext = sparkSession.sqlContext

  def sc: SparkContext = sparkSession.sparkContext

  protected def sparkLogLevel =
    Try(sys.env("SPARK_LOG_LEVEL")).getOrElse("WARN").toUpperCase match {
      case "DEBUG" => Level.DEBUG
      case "INFO" => Level.INFO
      case "WARN" => Level.WARN
      case _ => Level.ERROR
    }

  protected def configureLoggers(): Unit =
    configureLoggers(sparkLogLevel)

  protected def configureLoggers(logLevel: Level): Unit = {
    // Set logging through log4j v1 APIs also as v2 APIs are too tricky to manage
    org.apache.log4j.Logger.getRootLogger.setLevel(logLevel match {
      case Level.DEBUG => org.apache.log4j.Level.DEBUG
      case Level.INFO => org.apache.log4j.Level.INFO
      case Level.WARN => org.apache.log4j.Level.WARN
      case Level.ERROR => org.apache.log4j.Level.ERROR
    })
  }

  override protected def sparkConf: SparkConf =
    super.sparkConf
      .set("spark.driver.bindAddress", "127.0.0.1")

} 
Example 23
Source File: S2SinkProvider.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.spark.sql.streaming

import com.typesafe.config.{Config, ConfigFactory, ConfigRenderOptions}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

import scala.collection.JavaConversions._

class S2SinkProvider extends StreamSinkProvider with DataSourceRegister with Logger {
  override def createSink(
                  sqlContext: SQLContext,
                  parameters: Map[String, String],
                  partitionColumns: Seq[String],
                  outputMode: OutputMode): Sink = {

    logger.info(s"S2SinkProvider options : ${parameters}")
    val jobConf:Config = ConfigFactory.parseMap(parameters).withFallback(ConfigFactory.load())
    logger.info(s"S2SinkProvider Configuration : ${jobConf.root().render(ConfigRenderOptions.concise())}")

    new S2SparkSqlStreamingSink(sqlContext.sparkSession, jobConf)
  }

  override def shortName(): String = "s2graph"
} 
Example 24
Source File: JsonUtil.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.util

import org.apache.spark.sql.functions.explode
import org.apache.spark.sql.{Column, DataFrame, SQLContext, SparkSession}

import scala.collection.mutable.ArrayBuffer

object JsonUtil extends Serializable{


//  The tags to extract, comma-separated. To expand an array field, write it as MasterField_ChildField (e.g. links_name).
  def ParserJsonDF(df:DataFrame,tag:String): DataFrame = {

    var openArrField:String=""
    var ArrSchame:String=""

    var tagARR: Array[String] = tag.split(",")
    var tagNew:String=""


    for(tt<-tagARR){

      if(tt.indexOf("_")> -1){
        // the tag contains "_", i.e. an array field to expand
        val openField: Array[String] = tt.split("_")
        openArrField=openField(0)

        ArrSchame+=(openField(1)+",")
      }else{
        tagNew+=(tt+",")
      }
    }
    tagNew+=openArrField
    ArrSchame=ArrSchame.substring(0,ArrSchame.length-1)

    tagARR = tagNew.split(",")
    var FinalDF:DataFrame=df

    // if the user specified which fields to return
    var strings: Seq[Column] =tagNew.split(",").toSeq.map(p => new Column(p))

    if(tag.length>0){
      val df00 = FinalDF.select(strings : _*)
      FinalDF=df00
    }

    // if the user chose an array field to expand and provided its schema
    if(openArrField.length>0&&ArrSchame.length>0){

      val schames: Array[String] = ArrSchame.split(",")

      var selARR: ArrayBuffer[String] = ArrayBuffer() // fields to select once the array field is expanded
      // iterate over the tags and wrap each one in a Column object
      var coARR: ArrayBuffer[Column] = ArrayBuffer() // columns passed to select when expanding the field
      val sss = tagNew.split(",") // column names passed to toDF after the field is expanded
      var co: Column =null
      for(each<-tagARR){
        if(each==openArrField){
          co = explode(FinalDF(openArrField))
          for(x<-schames){

            selARR+=(openArrField+"."+x)
          }
        }else{
          selARR+=each
          co=FinalDF(each)
        }
        coARR+=co
      }
      println("###################")
      selARR.foreach(println(_))
      var selSEQ: Seq[Column] = selARR.toSeq.map(q => new Column(q))

      var df01: DataFrame = FinalDF.select(coARR : _*).toDF(sss:_*)
      FinalDF = df01.select(selSEQ : _*)

    }

FinalDF

  }
} 
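A usage sketch for ParserJsonDF, assuming input JSON whose records have a top-level name field and a links array whose elements carry href and title:

import cn.piflow.bundle.util.JsonUtil
import org.apache.spark.sql.SparkSession

object JsonUtilSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("JsonUtilSketch")
      .getOrCreate()

    // Assumed layout: {"name": "...", "links": [{"href": "...", "title": "..."}, ...]}
    val df = spark.read.json("/tmp/records.json")

    // Keep name, explode the links array, and pull out links.href and links.title.
    val parsed = JsonUtil.ParserJsonDF(df, "name,links_href,links_title")
    parsed.show(false)
  }
}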
Example 25
Source File: ScalaRiakParquetExample.scala    From spark-riak-connector   with Apache License 2.0
package com.basho.riak.spark.examples.parquet
import org.apache.spark.sql.{SaveMode, SQLContext}
import org.apache.spark.{SparkContext, SparkConf}


object ScalaRiakParquetExample {
  case class TSData(site: String, species: String, measurementDate: Long, latitude: Double, longitude: Double, value: Double )
  val startDate = System.currentTimeMillis()
  val endDate = startDate + 100
  val tableName = "parquet_demo"
  val parquetFileName = "riak-ts-data.parquet"

  val testData = Seq(
    TSData("MY7", "PM10", startDate, 51.52254, -0.15459, 41.4),
    TSData("MY7", "PM10", startDate + 10, 51.52254, -0.15459, 41.2),
    TSData("MY7", "PM10", startDate + 20, 51.52254, -0.15459, 39.1),
    TSData("MY7", "PM10", startDate + 30, 51.52254, -0.15459, 39.5),
    TSData("MY7", "PM10", startDate + 40, 51.52254, -0.15459, 29.9),
    TSData("MY7", "PM10", startDate + 50, 51.52254, -0.15459, 34.2),
    TSData("MY7", "PM10", startDate + 60, 51.52254, -0.15459, 28.5),
    TSData("MY7", "PM10", startDate + 70, 51.52254, -0.15459, 39.6),
    TSData("MY7", "PM10", startDate + 80, 51.52254, -0.15459, 29.2),
    TSData("MY7", "PM10", startDate + 90, 51.52254, -0.15459, 31.3)
  )

  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf()
      .setAppName("Simple Scala Riak TS Demo")

    setSparkOpt(sparkConf, "spark.master", "local")
    setSparkOpt(sparkConf, "spark.riak.connection.host", "127.0.0.1:8087")
    println(s"Test data start time: $startDate")

    val sc = new SparkContext(sparkConf)
    val sqlCtx = SQLContext.getOrCreate(sc)

    import sqlCtx.implicits._

    val rdd = sc.parallelize(testData)
    rdd.toDF().write.format("org.apache.spark.sql.riak")
      .mode(SaveMode.Append).save(tableName)

    val df = sqlCtx.read.format("org.apache.spark.sql.riak")
      .load(tableName).registerTempTable(tableName)

    val from = (startDate / 1000).toInt
    val query = s"select * from $tableName where measurementDate >= CAST($from AS TIMESTAMP) " +
      s"AND measurementDate <= CAST(${from + 1} AS TIMESTAMP) AND site = 'MY7' AND species = 'PM10'"

    println(s"Query: $query")
    val rows = sqlCtx.sql(query)
    rows.show()
    val schema = rows.schema

    rows.write.mode("overwrite").parquet(parquetFileName)
    println(s"Data was successfully saved to Parquet file: $parquetFileName")

    val parquetFile = sqlCtx.read.parquet(parquetFileName)
    parquetFile.registerTempTable("parquetFile")
    val data = sqlCtx.sql("SELECT MAX(value) max_value FROM parquetFile ")

    println("Maximum value retrieved from Parquet file:")
    data.show()

  }

  private def setSparkOpt(sparkConf: SparkConf, option: String, defaultOptVal: String): SparkConf = {
    val optval = sparkConf.getOption(option).getOrElse(defaultOptVal)
    sparkConf.set(option, optval)
  }
} 
Example 26
Source File: SparkDataframesTest.scala    From spark-riak-connector   with Apache License 2.0
package com.basho.riak.spark.rdd

import scala.reflect.runtime.universe
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.junit.Assert._
import org.junit.{ Before, Test }
import com.basho.riak.spark.toSparkContextFunctions
import org.junit.experimental.categories.Category

case class TestData(id: String, name: String, age: Int, category: String)

@Category(Array(classOf[RiakTSTests]))
class SparkDataframesTest extends AbstractRiakSparkTest {

  private val indexName = "creationNo"

  protected override val jsonData = Some(
    """[
      |   {key: 'key1', value: {id: 'u1', name: 'Ben', age: 20, category: 'CategoryA'}},
      |   {key: 'key2', value: {id: 'u2', name: 'Clair', age: 30, category: 'CategoryB'}},
      |   {key: 'key3', value: {id: 'u3', name: 'John', age: 70}},
      |   {key: 'key4', value: {id: 'u4', name: 'Chris', age: 10, category: 'CategoryC'}},
      |   {key: 'key5', value: {id: 'u5', name: 'Mary', age: 40, category: 'CategoryB'}},
      |   {key: 'key6', value: {id: 'u6', name: 'George', age: 50, category: 'CategoryC'}}
      |]""".stripMargin)

  protected override def initSparkConf() = super.initSparkConf().setAppName("Dataframes Test")

  var sqlContextHolder: SQLContext = _
  var df: DataFrame = _

  @Before
  def initializeDF(): Unit = {
    val sqlContext = new org.apache.spark.sql.SQLContext(sc)
    import sqlContext.implicits._
    sqlContextHolder = sqlContext
    df = sc.riakBucket[TestData](DEFAULT_NAMESPACE.getBucketNameAsString)
      .queryAll().toDF
    df.registerTempTable("test")
  }

  @Test
  def schemaTest(): Unit = {
    df.printSchema()
    val schema = df.schema.map(_.name).toList
    val fields = universe.typeOf[TestData].members.withFilter(!_.isMethod).map(_.name.toString.trim).toList
    assertEquals(schema.sorted, fields.sorted)
  }

  @Test
  def sqlQueryTest(): Unit = {
    val sqlResult = sqlContextHolder.sql("select * from test where category >= 'CategoryC'").toJSON.collect
    val expected =
      """ [
        |   {id:'u4',name:'Chris',age:10,category:'CategoryC'},
        |   {id:'u6',name:'George',age:50,category:'CategoryC'}
        | ]""".stripMargin
    assertEqualsUsingJSONIgnoreOrder(expected, stringify(sqlResult))
  }

  @Test
  def udfTest(): Unit = {
    sqlContextHolder.udf.register("stringLength", (s: String) => s.length)
    val udf = sqlContextHolder.sql("select name, stringLength(name) strLgth from test order by strLgth, name").toJSON.collect
    val expected =
      """ [
        |   {name:'Ben',strLgth:3},
        |   {name:'John',strLgth:4},
        |   {name:'Mary',strLgth:4},
        |   {name:'Chris',strLgth:5},
        |   {name:'Clair',strLgth:5},
        |   {name:'George',strLgth:6}
        | ]""".stripMargin
    assertEqualsUsingJSON(expected, stringify(udf))
  }

  @Test
  def grouppingTest(): Unit = {
    val groupped = df.groupBy("category").count.toJSON.collect
    val expected =
      """ [
        |   {category:'CategoryA',count:1},
        |   {category:'CategoryB',count:2},
        |   {category:'CategoryC',count:2},
        |   {count:1}
        | ]""".stripMargin
    assertEqualsUsingJSONIgnoreOrder(expected, stringify(groupped))
  }

  @Test
  def sqlVsFilterTest(): Unit = {
    val sql = sqlContextHolder.sql("select id, name from test where age >= 50").toJSON.collect
    val filtered = df.where(df("age") >= 50).select("id", "name").toJSON.collect
    assertEqualsUsingJSONIgnoreOrder(stringify(sql), stringify(filtered))
  }

} 
Example 27
Source File: DefaultSource.scala    From spark-solr   with Apache License 2.0
package solr

import com.lucidworks.spark.{SolrRelation, SolrStreamWriter}
import com.lucidworks.spark.util.Constants
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.streaming.OutputMode

class DefaultSource extends RelationProvider with CreatableRelationProvider with StreamSinkProvider with DataSourceRegister {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    try {
      new SolrRelation(parameters, sqlContext.sparkSession)
    } catch {
      case re: RuntimeException => throw re
      case e: Exception => throw new RuntimeException(e)
    }
  }

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    try {
      // TODO: What to do with the saveMode?
      val solrRelation: SolrRelation = new SolrRelation(parameters, Some(df), sqlContext.sparkSession)
      solrRelation.insert(df, overwrite = true)
      solrRelation
    } catch {
      case re: RuntimeException => throw re
      case e: Exception => throw new RuntimeException(e)
    }
  }

  override def shortName(): String = Constants.SOLR_FORMAT

  override def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new SolrStreamWriter(sqlContext.sparkSession, parameters, partitionColumns, outputMode)
  }
} 
Example 28
Source File: MLlibTestSparkContext.scala    From bisecting-kmeans   with Apache License 2.0
package org.apache.spark.mllib.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}

trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }
} 
Example 29
Source File: MLlibTestSparkContext.scala    From spark-tfocs   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.util

import org.scalatest.{ BeforeAndAfterAll, Suite }

import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.SQLContext


trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    sc.setLogLevel("WARN")
    sqlContext = new SQLContext(sc)
  }

  override def afterAll() {
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }
} 
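
A suite only needs to mix the trait in to get a SparkContext and SQLContext per test run; the suite below is an illustrative sketch, not part of either project.

import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.scalatest.FunSuite

class MLlibTestSparkContextUsageSketch extends FunSuite with MLlibTestSparkContext {
  test("sqlContext is available inside tests") {
    // beforeAll() has already created sc and sqlContext at this point
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "value")
    assert(df.count() === 2)
  }
}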
Example 30
Source File: SparkSqlExtension.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.extension

import java.util.concurrent._

import com.webank.wedatasphere.linkis.common.conf.CommonVars
import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

abstract class SparkSqlExtension extends Logging{

  private val maxPoolSize = CommonVars("wds.linkis.dws.ujes.spark.extension.max.pool",5).getValue

  private  val executor = new ThreadPoolExecutor(2, maxPoolSize, 2, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable](), new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val thread = new Thread(r)
      thread.setDaemon(true)
      thread
    }
  })

  final def afterExecutingSQL(sqlContext: SQLContext, command: String, dataFrame: DataFrame, timeout: Long, sqlStartTime: Long): Unit = {
    try {
      val thread = new Runnable {
        override def run(): Unit = extensionRule(sqlContext,command,dataFrame.queryExecution,sqlStartTime)
      }
      val future = executor.submit(thread)
      Utils.waitUntil(future.isDone,timeout milliseconds)
    } catch {
      case e: Throwable => info("Failed to execute SparkSqlExtension: ", e)
    }
  }

  protected def extensionRule(sqlContext: SQLContext, command: String, queryExecution: QueryExecution, sqlStartTime: Long): Unit


}

object SparkSqlExtension extends Logging {

  private val extensions = ArrayBuffer[SparkSqlExtension]()

  def register(sqlExtension: SparkSqlExtension):Unit = {
    info("Get a sqlExtension register")
    extensions.append(sqlExtension)
  }

  def getSparkSqlExtensions():Array[SparkSqlExtension] = {
    extensions.toArray
  }
} 
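
A concrete extension only needs to implement extensionRule and register itself; the class below is a hypothetical logging extension, not part of Linkis.

import com.webank.wedatasphere.linkis.engine.extension.SparkSqlExtension
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.QueryExecution

class LoggingSqlExtension extends SparkSqlExtension {
  // Runs on the extension thread pool after each SQL statement finishes.
  override protected def extensionRule(sqlContext: SQLContext,
                                       command: String,
                                       queryExecution: QueryExecution,
                                       sqlStartTime: Long): Unit =
    println(s"SQL finished: $command (submitted at $sqlStartTime)") // println to avoid assuming the Logging trait's API
}

object LoggingSqlExtensionInit {
  // Typically invoked once during engine startup (illustrative).
  def init(): Unit = SparkSqlExtension.register(new LoggingSqlExtension)
}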
Example 31
Source File: SparkSqlExecutor.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.linkis.engine.executors

import java.lang.reflect.InvocationTargetException
import java.util.concurrent.atomic.AtomicLong

import com.webank.wedatasphere.linkis.common.conf.CommonVars
import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import com.webank.wedatasphere.linkis.engine.configuration.SparkConfiguration
import com.webank.wedatasphere.linkis.engine.execute.EngineExecutorContext
import com.webank.wedatasphere.linkis.engine.extension.SparkSqlExtension
import com.webank.wedatasphere.linkis.engine.spark.common.{Kind, SparkSQL}
import com.webank.wedatasphere.linkis.engine.spark.utils.EngineUtils
import com.webank.wedatasphere.linkis.scheduler.executer.{ErrorExecuteResponse, ExecuteResponse, SuccessExecuteResponse}
import org.apache.commons.lang.exception.ExceptionUtils
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

      SQLSession.showDF(sc, jobGroup, rdd, null, SparkConfiguration.SHOW_DF_MAX_RES.getValue,engineExecutorContext)
      SuccessExecuteResponse()
    } catch {
      case e: InvocationTargetException =>
        var cause = ExceptionUtils.getCause(e)
        if(cause == null) cause = e
        error("execute sparkSQL failed!", cause)
        ErrorExecuteResponse(ExceptionUtils.getRootCauseMessage(e), cause)
      case ite: Exception =>
        error("execute sparkSQL failed!", ite)
        ErrorExecuteResponse(ExceptionUtils.getRootCauseMessage(ite), ite)
    } finally sc.clearJobGroup()
  }



  override def kind: Kind = SparkSQL()

  override def open: Unit = {}

  override def close: Unit = {}

} 
Example 32
Source File: DefaultSource.scala    From Linkis   with Apache License 2.0 5 votes vote down vote up
package com.webank.wedatasphere.spark.excel

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

  
  override def createRelation(
    sqlContext: SQLContext,
    parameters: Map[String, String],
    schema: StructType
  ): ExcelRelation = {
    ExcelRelation(
      location = checkParameter(parameters, "path"),
      sheetName = parameters.get("sheetName"),
      useHeader = checkParameter(parameters, "useHeader").toBoolean,
      treatEmptyValuesAsNulls = parameters.get("treatEmptyValuesAsNulls").fold(true)(_.toBoolean),
      userSchema = Option(schema),
      inferSheetSchema = parameters.get("inferSchema").fold(false)(_.toBoolean),
      addColorColumns = parameters.get("addColorColumns").fold(false)(_.toBoolean),
      startColumn = parameters.get("startColumn").fold(0)(_.toInt),
      endColumn = parameters.get("endColumn").fold(Int.MaxValue)(_.toInt),
      timestampFormat = parameters.get("timestampFormat"),
      maxRowsInMemory = parameters.get("maxRowsInMemory").map(_.toInt),
      excerptSize = parameters.get("excerptSize").fold(10)(_.toInt),
      parameters = parameters,
      dateFormat = parameters.get("dateFormats").getOrElse("yyyy-MM-dd").split(";").toList
    )(sqlContext)
  }

  override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    data: DataFrame
  ): BaseRelation = {
    val path = checkParameter(parameters, "path")
    val sheetName = parameters.getOrElse("sheetName", "Sheet1")
    val useHeader = checkParameter(parameters, "useHeader").toBoolean
    val dateFormat = parameters.getOrElse("dateFormat", ExcelFileSaver.DEFAULT_DATE_FORMAT)
    val timestampFormat = parameters.getOrElse("timestampFormat", ExcelFileSaver.DEFAULT_TIMESTAMP_FORMAT)
    val filesystemPath = new Path(path)
    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
    fs.setWriteChecksum(false)
    val doSave = if (fs.exists(filesystemPath)) {
      mode match {
        case SaveMode.Append =>
          sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
        case SaveMode.Overwrite =>
          fs.delete(filesystemPath, true)
          true
        case SaveMode.ErrorIfExists =>
          sys.error(s"path $path already exists.")
        case SaveMode.Ignore => false
      }
    } else {
      true
    }
    if (doSave) {
      // Only save data when the save mode is not ignore.
      (new ExcelFileSaver(fs)).save(
        filesystemPath,
        data,
        sheetName = sheetName,
        useHeader = useHeader,
        dateFormat = dateFormat,
        timestampFormat = timestampFormat
      )
    }

    createRelation(sqlContext, parameters, data.schema)
  }

  // Forces a Parameter to exist, otherwise an exception is thrown.
  private def checkParameter(map: Map[String, String], param: String): String = {
    if (!map.contains(param)) {
      throw new IllegalArgumentException(s"Parameter ${'"'}$param${'"'} is missing in options.")
    } else {
      map.apply(param)
    }
  }

  // Gets the Parameter if it exists, otherwise returns the default argument
  private def parameterOrDefault(map: Map[String, String], param: String, default: String) =
    map.getOrElse(param, default)
} 
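
Reading and writing through the source above comes down to supplying the required "path" and "useHeader" parameters; a sketch with illustrative paths:

import org.apache.spark.sql.{SQLContext, SaveMode}

object ExcelUsageSketch {
  def roundTrip(sqlContext: SQLContext): Unit = {
    val df = sqlContext.read
      .format("com.webank.wedatasphere.spark.excel")
      .option("useHeader", "true")   // required by createRelation above
      .option("sheetName", "Sheet1") // optional; defaults are shown above
      .load("/tmp/input.xlsx")       // becomes the required "path" parameter

    df.write
      .format("com.webank.wedatasphere.spark.excel")
      .option("useHeader", "true")
      .mode(SaveMode.Overwrite)
      .save("/tmp/output.xlsx")
  }
}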
Example 33
Source File: SelectJSONSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

class SelectJSONSource
  extends SchemaRelationProvider
  with DataSourceRegister {

  private def checkPath(parameters: Map[String, String]): String = {
    parameters.getOrElse("path", sys.error("'path' must be specified for JSON data."))
  }

  
  override def shortName(): String = "minioSelectJSON"

  override def createRelation(sqlContext: SQLContext, params: Map[String, String], schema: StructType): SelectJSONRelation = {
    val path = checkPath(params)
    SelectJSONRelation(Some(path), params, schema)(sqlContext)
  }
} 
Example 34
Source File: SelectCSVSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

class SelectCSVSource
  extends SchemaRelationProvider
  with DataSourceRegister {

  private def checkPath(parameters: Map[String, String]): String = {
    parameters.getOrElse("path", sys.error("'path' must be specified for CSV data."))
  }

  
  override def shortName(): String = "minioSelectCSV"

  override def createRelation(sqlContext: SQLContext, params: Map[String, String], schema: StructType): SelectCSVRelation = {
    val path = checkPath(params)
    SelectCSVRelation(Some(path), params, schema)(sqlContext)
  }
} 
Example 35
Source File: SelectParquetSource.scala    From spark-select   with Apache License 2.0 5 votes vote down vote up
package io.minio.spark.select

// Java standard libraries
import java.io.File

// Spark internal libraries
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

import org.apache.spark.sql.sources.DataSourceRegister

class SelectParquetSource
  extends SchemaRelationProvider
  with DataSourceRegister {

  private def checkPath(parameters: Map[String, String]): String = {
    parameters.getOrElse("path", sys.error("'path' must be specified for Parquet data."))
  }

  
  override def shortName(): String = "minioSelectParquet"

  override def createRelation(sqlContext: SQLContext, params: Map[String, String], schema: StructType): SelectParquetRelation = {
    val path = checkPath(params)
    SelectParquetRelation(Some(path), params, schema)(sqlContext)
  }
} 
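
Examples 33-35 implement only SchemaRelationProvider, so a schema has to be supplied explicitly when reading; a sketch for the JSON variant (bucket path and columns are illustrative):

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SelectSourceUsageSketch {
  def read(sqlContext: SQLContext): Unit = {
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))

    sqlContext.read
      .format("minioSelectJSON")       // shortName() shown above
      .schema(schema)                  // required: no schema inference is implemented
      .load("s3://bucket/people.json") // becomes the required "path" parameter
      .show()
  }
}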
Example 36
Source File: SparkCassBulkWriterSpec.scala    From Spark2Cassandra   with Apache License 2.0 5 votes vote down vote up
package com.github.jparkie.spark.cassandra

import com.datastax.driver.core.querybuilder.QueryBuilder
import com.datastax.spark.connector.AllColumns
import com.datastax.spark.connector.writer.{ RowWriterFactory, SqlRowWriter }
import com.github.jparkie.spark.cassandra.client.SparkCassSSTableLoaderClientManager
import com.github.jparkie.spark.cassandra.conf.{ SparkCassServerConf, SparkCassWriteConf }
import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.{ Row, SQLContext }
import org.scalatest.{ MustMatchers, WordSpec }

import scala.collection.JavaConverters._

class SparkCassBulkWriterSpec extends WordSpec with MustMatchers with CassandraServerSpecLike with SharedSparkContext {
  val testKeyspace = "test_keyspace"
  val testTable = "test_table"

  override def beforeAll(): Unit = {
    super.beforeAll()

    getCassandraConnector.withSessionDo { currentSession =>
      createKeyspace(currentSession, testKeyspace)

      currentSession.execute(
        s"""CREATE TABLE $testKeyspace.$testTable (
            |  test_key BIGINT PRIMARY KEY,
            |  test_value VARCHAR
            |);
         """.stripMargin
      )
    }
  }

  "SparkCassBulkWriter" must {
    "write() successfully" in {
      val sqlContext = new SQLContext(sc)

      import sqlContext.implicits._

      implicit val testRowWriterFactory: RowWriterFactory[Row] = SqlRowWriter.Factory

      val testCassandraConnector = getCassandraConnector
      val testSparkCassWriteConf = SparkCassWriteConf()
      val testSparkCassServerConf = SparkCassServerConf(
        // See https://github.com/jsevellec/cassandra-unit/blob/master/cassandra-unit/src/main/resources/cu-cassandra.yaml
        storagePort = 7010
      )

      val testSparkCassBulkWriter = SparkCassBulkWriter(
        testCassandraConnector,
        testKeyspace,
        testTable,
        AllColumns,
        testSparkCassWriteConf,
        testSparkCassServerConf
      )

      val testRDD = sc.parallelize(1 to 25)
        .map(currentNumber => (currentNumber.toLong, s"Hello World: $currentNumber!"))
      val testDataFrame = testRDD.toDF("test_key", "test_value")

      sc.runJob(testDataFrame.rdd, testSparkCassBulkWriter.write _)

      getCassandraConnector.withSessionDo { currentSession =>
        val queryStatement = QueryBuilder.select("test_key", "test_value")
          .from(testKeyspace, testTable)
          .limit(25)

        val resultSet = currentSession.execute(queryStatement)

        val outputSet = resultSet.all.asScala
          .map(currentRow => (currentRow.getLong("test_key"), currentRow.getString("test_value")))
          .toMap

        for (currentNumber <- 1 to 25) {
          val currentKey = currentNumber.toLong

          outputSet(currentKey) mustEqual s"Hello World: $currentNumber!"
        }
      }

      SparkCassSSTableLoaderClientManager.evictAll()
    }
  }
} 
Example 37
Source File: SparkCassDataFrameFunctionsSpec.scala    From Spark2Cassandra   with Apache License 2.0 5 votes vote down vote up
package com.github.jparkie.spark.cassandra.sql

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.sql.SQLContext
import org.scalatest.{ MustMatchers, WordSpec }

class SparkCassDataFrameFunctionsSpec extends WordSpec with MustMatchers with SharedSparkContext {
  "Package com.github.jparkie.spark.cassandra.sql" must {
    "lift DataFrame into SparkCassDataFrameFunctions" in {
      val sqlContext = new SQLContext(sc)

      import sqlContext.implicits._

      val testRDD = sc.parallelize(1 to 25)
        .map(currentNumber => (currentNumber.toLong, s"Hello World: $currentNumber!"))
      val testDataFrame = testRDD.toDF("test_key", "test_value")

      // If internalSparkContext is available, RDD was lifted.
      testDataFrame.internalSparkContext
    }
  }
} 
Example 38
Source File: PointCloudRelation.scala    From geotrellis-pointcloud   with Apache License 2.0 5 votes vote down vote up
package geotrellis.pointcloud.spark.datasource

import geotrellis.pointcloud.spark.store.hadoop._
import geotrellis.pointcloud.spark.store.hadoop.HadoopPointCloudRDD.{Options => HadoopOptions}
import geotrellis.pointcloud.util.Filesystem
import geotrellis.proj4.CRS
import geotrellis.store.hadoop.util.HdfsUtils
import geotrellis.vector.Extent

import cats.implicits._
import io.pdal._
import io.circe.syntax._
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}

import java.io.File

import scala.collection.JavaConverters._

// This class has to be serializable since it is shipped over the network.
class PointCloudRelation(
  val sqlContext: SQLContext,
  path: String,
  options: HadoopOptions
) extends BaseRelation with TableScan with Serializable {

  @transient implicit lazy val sc: SparkContext = sqlContext.sparkContext

  // TODO: switch between HadoopPointCloudRDD and S3PointcCloudRDD
  lazy val isS3: Boolean = path.startsWith("s3")

  override def schema: StructType = {
    lazy val (local, fixedPath) =
      if(path.startsWith("s3") || path.startsWith("hdfs")) {
        val tmpDir = Filesystem.createDirectory()
        val remotePath = new Path(path)
        // copy remote file into local tmp dir
        val localPath = new File(tmpDir, remotePath.getName)
        HdfsUtils.copyPath(remotePath, new Path(s"file:///${localPath.getAbsolutePath}"), sc.hadoopConfiguration)
        (true, localPath.toString)
      } else (false, path)

    val localPipeline =
      options.pipeline
        .hcursor
        .downField("pipeline").downArray
        .downField("filename").withFocus(_ => fixedPath.asJson)
        .top.fold(options.pipeline)(identity)

    val pl = Pipeline(localPipeline.noSpaces)
    if (pl.validate()) pl.execute()
    val pointCloud = try {
      pl.getPointViews().next().getPointCloud(0)
    } finally {
      pl.close()
      if(local) println(new File(fixedPath).delete)
    }

    val rdd = HadoopPointCloudRDD(new Path(path), options)

    val md: (Option[Extent], Option[CRS]) =
      rdd
        .map { case (header, _) => (header.projectedExtent3D.map(_.extent3d.toExtent), header.crs) }
        .reduce { case ((e1, c), (e2, _)) => ((e1, e2).mapN(_ combine _), c) }

    val metadata = new MetadataBuilder().putString("metadata", md.asJson.noSpaces).build

    pointCloud.deriveSchema(metadata)
  }

  override def buildScan(): RDD[Row] = {
    val rdd = HadoopPointCloudRDD(new Path(path), options)
    rdd.flatMap { _._2.flatMap { pc => pc.readAll.toList.map { k => Row(k: _*) } } }
  }
} 
Example 39
Source File: MlLibOnKudu.scala    From Taxi360   with Apache License 2.0 5 votes vote down vote up
package com.hadooparchitecturebook.taxi360.etl.machinelearning.kudu

import com.hadooparchitecturebook.taxi360.model.{NyTaxiYellowTrip, NyTaxiYellowTripBuilder}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object MlLibOnKudu {
  def main(args: Array[String]): Unit = {

    if (args.length == 0) {
      println("Args: <runLocal> " +
        "<kuduMaster> " +
        "<taxiTable> " +
        "<numOfCenters> " +
        "<numOfIterations> ")
      return
    }

    val runLocal = args(0).equalsIgnoreCase("l")
    val kuduMaster = args(1)
    val taxiTable = args(2)
    val numOfCenters = args(3).toInt
    val numOfIterations = args(4).toInt

    val sc: SparkContext = if (runLocal) {
      val sparkConfig = new SparkConf()
      sparkConfig.set("spark.broadcast.compress", "false")
      sparkConfig.set("spark.shuffle.compress", "false")
      sparkConfig.set("spark.shuffle.spill.compress", "false")
      new SparkContext("local", "TableStatsSinglePathMain", sparkConfig)
    } else {
      val sparkConfig = new SparkConf().setAppName("TableStatsSinglePathMain")
      new SparkContext(sparkConfig)
    }

    val sqlContext = new SQLContext(sc)

    val kuduOptions = Map(
      "kudu.table" -> taxiTable,
      "kudu.master" -> kuduMaster)

    sqlContext.read.options(kuduOptions).format("org.apache.kudu.spark.kudu").load.
      registerTempTable("ny_taxi_trip_tmp")

    //Vector
    val vectorRDD:RDD[Vector] = sqlContext.sql("select * from ny_taxi_trip_tmp").map(r => {
      val taxiTrip = NyTaxiYellowTripBuilder.build(r)
      generateVectorOnly(taxiTrip)
    })

    println("--Running KMeans")
    val clusters = KMeans.train(vectorRDD, numOfCenters, numOfIterations)
    println(" > vector centers:")
    clusters.clusterCenters.foreach(v => println(" >> " + v))

    println("--Running corr")
    val correlMatrix: Matrix = Statistics.corr(vectorRDD, "pearson")
    println(" > corr: " + correlMatrix.toString)

    println("--Running colStats")
    val colStats = Statistics.colStats(vectorRDD)
    println(" > max: " + colStats.max)
    println(" > count: " + colStats.count)
    println(" > mean: " + colStats.mean)
    println(" > min: " + colStats.min)
    println(" > normL1: " + colStats.normL1)
    println(" > normL2: " + colStats.normL2)
    println(" > numNonZeros: " + colStats.numNonzeros)
    println(" > variance: " + colStats.variance)

    //Labeled Points
  }
}
Example 40
Source File: TimeType.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.time.types

import com.twosigma.flint.FlintConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.{ SQLContext, SparkSession, types }

trait TimeType {
  // Conversions between nanoseconds since the epoch and the internal representation.
  def internalToNanos(value: Long): Long
  def nanosToInternal(nanos: Long): Long
  def roundDownPrecision(nanosSinceEpoch: Long): Long
}

object TimeType {
  case object LongType extends TimeType {
    override def internalToNanos(value: Long): Long = value
    override def nanosToInternal(nanos: Long): Long = nanos
    override def roundDownPrecision(nanos: Long): Long = nanos
  }

  // Spark sql represents timestamp as microseconds internally
  case object TimestampType extends TimeType {
    override def internalToNanos(value: Long): Long = value * 1000
    override def nanosToInternal(nanos: Long): Long = nanos / 1000
    override def roundDownPrecision(nanos: Long): Long = nanos - nanos % 1000
  }

  def apply(timeType: String): TimeType = {
    timeType match {
      case "long" => LongType
      case "timestamp" => TimestampType
      case _ => throw new IllegalAccessException(s"Unsupported time type: ${timeType}. " +
        s"Only `long` and `timestamp` are supported.")
    }
  }

  def apply(sqlType: types.DataType): TimeType = {
    sqlType match {
      case types.LongType => LongType
      case types.TimestampType => TimestampType
      case _ => throw new IllegalArgumentException(s"Unsupported time type: ${sqlType}")
    }
  }

  def get(sparkSession: SparkSession): TimeType = {
    TimeType(sparkSession.conf.get(
      FlintConf.TIME_TYPE_CONF, FlintConf.TIME_TYPE_DEFAULT
    ))
  }
} 
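
A quick illustration of the two implementations above (a standalone sketch, not taken from the project's test suite):

import com.twosigma.flint.timeseries.time.types.TimeType

object TimeTypeDemo {
  def main(args: Array[String]): Unit = {
    val nanos = 1234567890123L

    // "timestamp": the internal representation is microseconds since the epoch.
    val ts = TimeType("timestamp")
    assert(ts.nanosToInternal(nanos) == nanos / 1000)
    assert(ts.internalToNanos(nanos / 1000) == (nanos / 1000) * 1000)
    assert(ts.roundDownPrecision(nanos) == nanos - nanos % 1000)

    // "long": nanoseconds pass through unchanged.
    val lt = TimeType("long")
    assert(lt.roundDownPrecision(nanos) == nanos)
  }
}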
Example 41
Source File: TimeSeriesRDDConversionSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import java.util.concurrent.TimeUnit

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ SQLContext, DataFrame, Row }
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions.{ GenericRowWithSchema => ExternalRow }
import org.scalatest.tagobjects.Slow

class TimeSeriesRDDConversionSpec extends TimeSeriesSuite {

  // The largest prime < 100
  override val defaultPartitionParallelism = 97

  // The 10000-th prime.
  private val defaultNumRows = 104729

  private def createDataFrame(isSorted: Boolean = true)(implicit sqlContext: SQLContext): DataFrame = {
    val n = defaultNumRows
    val schema = Schema("value" -> DoubleType)
    val rdd: RDD[Row] = sqlContext.sparkContext.parallelize(1 to n, defaultPartitionParallelism).map { i =>
      val data: Array[Any] = if (isSorted) {
        Array((i / 100).toLong, i.toDouble)
      } else {
        Array(((i + 1 - n) / 100).toLong, i.toDouble)
      }
      new ExternalRow(data, schema)
    }
    sqlContext.createDataFrame(rdd, schema)
  }

  "TimeSeriesRDD" should "convert from a sorted DataFrame correctly" taggedAs (Slow) in {
    implicit val _sqlContext = sqlContext
    (1 to 10).foreach {
      i =>
        val tsRdd = TimeSeriesRDD.fromDF(createDataFrame(isSorted = true))(isSorted = true, TimeUnit.NANOSECONDS)
        assert(tsRdd.count() == defaultNumRows)
    }
    (1 to 10).foreach {
      i =>
        val tsRdd = TimeSeriesRDD.fromDF(createDataFrame(isSorted = true))(isSorted = false, TimeUnit.NANOSECONDS)
        assert(tsRdd.count() == defaultNumRows)
    }
    (1 to 10).foreach {
      i =>
        val tsRdd = TimeSeriesRDD.fromDF(createDataFrame(isSorted = false))(isSorted = false, TimeUnit.NANOSECONDS)
        assert(tsRdd.count() == defaultNumRows)
    }
    (1 to 10).foreach {
      i =>
        val tsRdd = TimeSeriesRDD.fromDF(
          createDataFrame(isSorted = false).sort("time")
        )(
            isSorted = true, TimeUnit.NANOSECONDS
          )
        assert(tsRdd.count() == defaultNumRows)
    }
  }
} 
Example 42
Source File: Utils.scala    From Mastering-Machine-Learning-with-Spark-2.x   with MIT License 5 votes vote down vote up
package com.packtpub.mmlwspark.utils

import org.apache.spark.h2o.H2OContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedFunction
import water.fvec.H2OFrame


object Utils {
  def colTransform(hf: H2OFrame, udf: UserDefinedFunction, colName: String)(implicit h2oContext: H2OContext, sqlContext: SQLContext): H2OFrame = {
    import sqlContext.implicits._
    val name = hf.key.toString
    val colHf = hf(Array(colName))
    val df = h2oContext.asDataFrame(colHf)
    val result = h2oContext.asH2OFrame(df.withColumn(colName, udf($"${colName}")), s"${name}_${colName}")
    colHf.delete()
    result
  }

  def let[A](in: A)(body: A => Unit) = {
    body(in)
    in
  }
} 
Example 43
Source File: RedisStreamProvider.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import com.redislabs.provider.redis.util.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSourceProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}


class RedisStreamProvider extends DataSourceRegister with StreamSourceProvider with Logging {

  override def shortName(): String = "redis"

  override def sourceSchema(sqlContext: SQLContext, schema: Option[StructType],
                            providerName: String, parameters: Map[String, String]): (String, StructType) = {
    providerName -> schema.getOrElse {
      StructType(Seq(StructField("_id", StringType)))
    }
  }

  override def createSource(sqlContext: SQLContext, metadataPath: String,
                            schema: Option[StructType], providerName: String,
                            parameters: Map[String, String]): Source = {
    val (_, ss) = sourceSchema(sqlContext, schema, providerName, parameters)
    val source = new RedisSource(sqlContext, metadataPath, Some(ss), parameters)
    source.start()
    source
  }
} 
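
The source above registers under the short name "redis"; wiring it into Structured Streaming might look like the sketch below (the "stream.keys" option name is an assumption about spark-redis, not shown in this excerpt).

import org.apache.spark.sql.SparkSession

object RedisStreamUsageSketch {
  def readStream(spark: SparkSession): Unit = {
    val stream = spark.readStream
      .format("redis")                  // shortName() shown above
      .option("stream.keys", "sensors") // assumed option naming the Redis stream key(s)
      .load()

    stream.writeStream
      .format("console")
      .start()
      .awaitTermination()
  }
}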
Example 44
Source File: DefaultSource.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis

import org.apache.spark.sql.SaveMode.{Append, ErrorIfExists, Ignore, Overwrite}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

class DefaultSource extends RelationProvider with SchemaRelationProvider
  with CreatableRelationProvider {

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String]): BaseRelation = {
    new RedisSourceRelation(sqlContext, parameters, userSpecifiedSchema = None)
  }

  
  override def createRelation(sqlContext: SQLContext, mode: SaveMode,
                              parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val relation = new RedisSourceRelation(sqlContext, parameters, userSpecifiedSchema = None)
    mode match {
      case Append => relation.insert(data, overwrite = false)
      case Overwrite => relation.insert(data, overwrite = true)
      case ErrorIfExists =>
        if (relation.nonEmpty) {
          throw new IllegalStateException("SaveMode is set to ErrorIfExists and dataframe " +
            "already exists in Redis and contains data.")
        }
        relation.insert(data, overwrite = false)
      case Ignore =>
        if (relation.isEmpty) {
          relation.insert(data, overwrite = false)
        }
    }

    relation
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
                              schema: StructType): BaseRelation =
    new RedisSourceRelation(sqlContext, parameters, userSpecifiedSchema = Some(schema))
} 
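
A batch round trip through the relation above; the "table" option name is an assumption about spark-redis, so treat this purely as a sketch.

import org.apache.spark.sql.{SQLContext, SaveMode}

object RedisBatchUsageSketch {
  def roundTrip(sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    val df = Seq(("u1", "John"), ("u2", "Mary")).toDF("id", "name")

    df.write
      .format("org.apache.spark.sql.redis")
      .option("table", "person") // assumed option for the Redis key prefix
      .mode(SaveMode.Overwrite)  // maps to insert(data, overwrite = true) above
      .save()

    sqlContext.read
      .format("org.apache.spark.sql.redis")
      .option("table", "person")
      .load()
      .show()
  }
}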
Example 45
Source File: JdbcExample.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

// scalastyle:off println
import java.util.Properties

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

object JdbcExample {

  
  def main(args: Array[String]): Unit = {
    if (args.length != 3) {
      throw new IllegalArgumentException("Invalid arguments")
    }
    val url = args(0)
    val user = args(1)
    val pass = args(2)

    val conf = new SparkConf().setAppName("JdbcExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    run(sc, sqlContext, url, user, pass)

    sc.stop()
  }

  def run(sc: SparkContext, sqlContext: SQLContext,
      url: String, user: String, pass: String): Unit = {
    val prop = new Properties()
    prop.setProperty("user", user)
    prop.setProperty("password", pass)

    val df: DataFrame = sqlContext.read.jdbc(url, "gihyo_spark.person", prop)
    df.printSchema()
    println("# Rows: " + df.count())
  }
}
// scalastyle:on println 
Example 46
Source File: DataFrameNaFunctionExample.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

// scalastyle:off println
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}


object DataFrameNaFunctionExample {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("BasicDataFrameExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    run(sc, sqlContext)
    sc.stop()
  }

  def run(
      sc: SparkContext,
      sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    val nullDF = Seq[(String, java.lang.Integer, java.lang.Double)](
      ("Bob", 16, 176.5),
      ("Alice", null, 164.3),
      ("", 60, null),
      ("UNKNOWN", 25, Double.NaN),
      ("Amy", null, null),
      (null, null, Double.NaN)
    ).toDF("name", "age", "height")

    // drop
    nullDF.na.drop("any").show()
    nullDF.na.drop("all").show()
    nullDF.na.drop(Array("age")).show()
    nullDF.na.drop(Seq("age", "height")).show()
    nullDF.na.drop("any", Array("name", "age")).show()
    nullDF.na.drop("all", Array("age", "height")).show()

    // fill
    nullDF.na.fill(0.0, Array("name", "height")).show()
    nullDF.na.fill(Map(
      "name" -> "UNKNOWN",
      "height" -> 0.0
    )).show()

    // replace
    nullDF.na.replace("name", Map("" -> "UNKNOWN")).show()
  }
}

// scalastyle:on println 
Example 47
Source File: DatasetExample.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.sql.{Dataset, SQLContext}
import org.apache.spark.sql.functions._

private case class Person(id: Int, name: String, age: Int)

object DatasetExample {

  
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("DatasetExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    run(sc, sqlContext)
    sc.stop()
  }

  def run(sc: SparkContext, sqlContext: SQLContext): Unit = {
    import sqlContext.implicits._

    // Creates a Dataset from a `Seq`
    val seq = Seq((1, "Bob", 23), (2, "Tom", 23), (3, "John", 22))
    val ds1: Dataset[(Int, String, Int)] = sqlContext.createDataset(seq)
    val ds2: Dataset[(Int, String, Int)] = seq.toDS()

    // Creates a Dataset from a `RDD`
    val rdd = sc.parallelize(seq)
    val ds3: Dataset[(Int, String, Int)] = sqlContext.createDataset(rdd)
    val ds4: Dataset[(Int, String, Int)] = rdd.toDS()

    // Creates a Dataset from a `DataFrame`
    val df = rdd.toDF("id", "name", "age")
    val ds5: Dataset[Person] = df.as[Person]

    // Selects a column
    ds5.select(expr("name").as[String]).show()

    // Filtering
    ds5.filter(_.name == "Bob").show()
    ds5.filter(person => person.age == 23).show()

    // Groups and counts the number of rows
    ds5.groupBy(_.age).count().show()
  }
} 
Example 48
Source File: TestSparkContext.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

private[spark]
trait TestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("SparkUnitTest")
      .set("spark.sql.shuffle.partitions", "2")
    sc = new SparkContext(conf)
    SQLContext.clearActive()
    sqlContext = new SQLContext(sc)
    SQLContext.setActive(sqlContext)
  }

  override def afterAll() {
    try {
      sqlContext = null
      SQLContext.clearActive()
      if (sc != null) {
        sc.stop()
      }
      sc = null
    } finally {
      super.afterAll()
    }
  }
} 
Example 49
Source File: MovieRecommendation.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.MovieRecommendation

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.SQLImplicits
import org.apache.spark.sql._
import org.apache.spark.sql.Dataset
import org.apache.spark.mllib.recommendation.ALS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.Rating
import scala.Tuple2
import org.apache.spark.rdd.RDD

object MovieRecommendation {  
  // Compute the RMSE to evaluate the model. The lower the RMSE, the better the model and its prediction capability.
  def computeRmse(model: MatrixFactorizationModel, data: RDD[Rating], implicitPrefs: Boolean): Double = {
    val predictions: RDD[Rating] = model.predict(data.map(x => (x.user, x.product)))
    val predictionsAndRatings = predictions.map { x => ((x.user, x.product), x.rating)
    }.join(data.map(x => ((x.user, x.product), x.rating))).values
    if (implicitPrefs) {
      println("(Prediction, Rating)")
      println(predictionsAndRatings.take(5).mkString("\n"))
    }
    math.sqrt(predictionsAndRatings.map(x => (x._1 - x._2) * (x._1 - x._2)).mean())
  }

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("JavaLDAExample")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/").
      getOrCreate()

    val ratingsFile = "data/ratings.csv"
    val df1 = spark.read.format("com.databricks.spark.csv").option("header", true).load(ratingsFile)

    val ratingsDF = df1.select(df1.col("userId"), df1.col("movieId"), df1.col("rating"), df1.col("timestamp"))
    ratingsDF.show(false)

    val moviesFile = "data/movies.csv"
    val df2 = spark.read.format("com.databricks.spark.csv").option("header", "true").load(moviesFile)

    val moviesDF = df2.select(df2.col("movieId"), df2.col("title"), df2.col("genres"))
    moviesDF.show(false)

    ratingsDF.createOrReplaceTempView("ratings")
    moviesDF.createOrReplaceTempView("movies")

    

    var rmseTest = computeRmse(model, testRDD, true)
    println("Test RMSE: = " + rmseTest) //Less is better

    //Movie recommendation for a specific user. Get the top 6 movie predictions for user 668
    println("Recommendations: (MovieId => Rating)")
    println("----------------------------------")
    val recommendationsUser = model.recommendProducts(668, 6)
    recommendationsUser.map(rating => (rating.product, rating.rating)).foreach(println)
    println("----------------------------------")

    spark.stop()
  }
} 
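
The listing above elides how model and testRDD are produced; with Spark MLlib's ALS a typical construction looks roughly like the sketch below (the split ratio, rank, iteration count and regularization are placeholder values, not the book's original code).

import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD

object AlsTrainingSketch {
  def train(ratings: RDD[Rating]): (MatrixFactorizationModel, RDD[Rating]) = {
    val Array(trainRDD, testRDD) = ratings.randomSplit(Array(0.75, 0.25), seed = 12345L)
    val model = ALS.train(trainRDD, 10, 10, 0.01) // rank, iterations, lambda
    (model, testRDD)
  }
}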
Example 50
Source File: XmlReader.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, SQLContext, SparkSession}
import org.apache.spark.sql.types.StructType
import com.databricks.spark.xml.util.XmlFile
import com.databricks.spark.xml.util.FailFastMode


  @deprecated("Use xmlFile(SparkSession, ...)", "0.5.0")
  def xmlFile(sqlContext: SQLContext, path: String): DataFrame = {
    // We need the `charset` and `rowTag` before creating the relation.
    val (charset, rowTag) = {
      val options = XmlOptions(parameters.toMap)
      (options.charset, options.rowTag)
    }
    val relation = XmlRelation(
      () => XmlFile.withCharset(sqlContext.sparkContext, path, charset, rowTag),
      Some(path),
      parameters.toMap,
      schema)(sqlContext)
    sqlContext.baseRelationToDataFrame(relation)
  }

  @deprecated("Use xmlRdd(SparkSession, ...)", "0.5.0")
  def xmlRdd(sqlContext: SQLContext, xmlRDD: RDD[String]): DataFrame = {
    val relation = XmlRelation(
      () => xmlRDD,
      None,
      parameters.toMap,
      schema)(sqlContext)
    sqlContext.baseRelationToDataFrame(relation)
  }

} 
Example 51
Source File: DefaultSource.scala    From spark-xml   with Apache License 2.0 5 votes vote down vote up
package com.databricks.spark.xml

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import com.databricks.spark.xml.util.XmlFile


  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): XmlRelation = {
    val path = checkPath(parameters)
    // We need the `charset` and `rowTag` before creating the relation.
    val (charset, rowTag) = {
      val options = XmlOptions(parameters)
      (options.charset, options.rowTag)
    }

    XmlRelation(
      () => XmlFile.withCharset(sqlContext.sparkContext, path, charset, rowTag),
      Some(path),
      parameters,
      schema)(sqlContext)
  }

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val path = checkPath(parameters)
    val filesystemPath = new Path(path)
    val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
    val doSave = if (fs.exists(filesystemPath)) {
      mode match {
        case SaveMode.Append =>
          throw new IllegalArgumentException(
            s"Append mode is not supported by ${this.getClass.getCanonicalName}")
        case SaveMode.Overwrite =>
          fs.delete(filesystemPath, true)
          true
        case SaveMode.ErrorIfExists =>
          throw new IllegalArgumentException(s"path $path already exists.")
        case SaveMode.Ignore => false
      }
    } else {
      true
    }
    if (doSave) {
      // Only save data when the save mode is not ignore.
      XmlFile.saveAsXmlFile(data, filesystemPath.toString, parameters)
    }
    createRelation(sqlContext, parameters, data.schema)
  }
} 
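
Loading and saving through the relation above follows the usual data source pattern; "rowTag" and "rootTag" are standard spark-xml options, and the paths are illustrative.

import org.apache.spark.sql.{SQLContext, SaveMode}

object XmlUsageSketch {
  def roundTrip(sqlContext: SQLContext): Unit = {
    val df = sqlContext.read
      .format("com.databricks.spark.xml")
      .option("rowTag", "book")   // XML element treated as a row
      .load("books.xml")

    df.write
      .format("com.databricks.spark.xml")
      .option("rootTag", "books") // wrapping element on write
      .option("rowTag", "book")
      .mode(SaveMode.Overwrite)
      .save("books-copy.xml")
  }
}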
Example 52
Source File: HyperLogLog.scala    From spark-hyperloglog   with Apache License 2.0 5 votes vote down vote up
package com.mozilla.spark.sql.hyperloglog.test

import com.mozilla.spark.sql.hyperloglog.aggregates._
import com.mozilla.spark.sql.hyperloglog.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{FlatSpec, Matchers}

class HyperLogLogTest extends FlatSpec with Matchers{
 "Algebird's HyperLogLog" can "be used from Spark" in {
  val sparkConf = new SparkConf().setAppName("HyperLogLog")
  sparkConf.setMaster(sparkConf.get("spark.master", "local[1]"))

  val sc = new SparkContext(sparkConf)
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  val hllMerge = new HyperLogLogMerge
  sqlContext.udf.register("hll_merge", hllMerge)
  sqlContext.udf.register("hll_create", hllCreate _)
  sqlContext.udf.register("hll_cardinality", hllCardinality _)

  val frame = sc.parallelize(List("a", "b", "c", "c"), 4).toDF("id")
  val count = frame
    .select(expr("hll_create(id, 12) as hll"))
    .groupBy()
    .agg(expr("hll_cardinality(hll_merge(hll)) as count"))
    .collect()
  count(0)(0) should be (3)
 }
} 
Example 53
Source File: HiSpeedRead.scala    From spark-db2   with Apache License 2.0 5 votes vote down vote up
import com.ibm.spark.ibmdataserver.Constants
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkContext, SparkConf}

object HiSpeedRead {

  def main(args: Array[String]) {
    val DB2_CONNECTION_URL = "jdbc:db2://localhost:50700/sample:traceFile=C:\\1.txt;"

    val conf = new SparkConf().setMaster("local[2]").setAppName("read test")

    val sparkContext = new SparkContext(conf)

    val sqlContext = new SQLContext(sparkContext)

    Class.forName("com.ibm.db2.jcc.DB2Driver")

    val jdbcRdr = sqlContext.read.format("com.ibm.spark.ibmdataserver")
      .option("url", DB2_CONNECTION_URL)
      // .option(Constants.TABLE, tableName)
      .option("user", "pallavipr")
      .option("password", "9manjari")
      .option("dbtable", "employee")
      .load()

    jdbcRdr.show()
  }
} 
Example 54
Source File: DefaultSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.hbase

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class CustomedDefaultSource
  extends DefaultSource
  with DataSourceRegister
  with SchemaRelationProvider {

  override def shortName(): String = "hbase"

  
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation = {
    new CustomedHBaseRelation(parameters, Option(schema))(sqlContext)
  }
} 
Example 55
Source File: SparkSQLCLIService.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.IOException
import java.util.{List => JList}
import javax.security.auth.login.LoginException

import scala.collection.JavaConverters._

import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hadoop.hive.shims.Utils
import org.apache.hadoop.security.{SecurityUtil, UserGroupInformation}
import org.apache.hive.service.{AbstractService, Service, ServiceException}
import org.apache.hive.service.Service.STATE
import org.apache.hive.service.auth.HiveAuthFactory
import org.apache.hive.service.cli._
import org.apache.hive.service.server.HiveServer2

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._

private[hive] class SparkSQLCLIService(hiveServer: HiveServer2, sqlContext: SQLContext)
  extends CLIService(hiveServer)
  with ReflectedCompositeService {

  override def init(hiveConf: HiveConf) {
    setSuperField(this, "hiveConf", hiveConf)

    val sparkSqlSessionManager = new SparkSQLSessionManager(hiveServer, sqlContext)
    setSuperField(this, "sessionManager", sparkSqlSessionManager)
    addService(sparkSqlSessionManager)
    var sparkServiceUGI: UserGroupInformation = null
    var httpUGI: UserGroupInformation = null

    if (UserGroupInformation.isSecurityEnabled) {
      try {
        val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_PRINCIPAL)
        val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_KERBEROS_KEYTAB)
        if (principal.isEmpty || keyTabFile.isEmpty) {
          throw new IOException(
            "HiveServer2 Kerberos principal or keytab is not correctly configured")
        }

        val originalUgi = UserGroupInformation.getCurrentUser
        sparkServiceUGI = if (HiveAuthFactory.needUgiLogin(originalUgi,
          SecurityUtil.getServerPrincipal(principal, "0.0.0.0"), keyTabFile)) {
          HiveAuthFactory.loginFromKeytab(hiveConf)
          Utils.getUGI()
        } else {
          originalUgi
        }

        setSuperField(this, "serviceUGI", sparkServiceUGI)
      } catch {
        case e @ (_: IOException | _: LoginException) =>
          throw new ServiceException("Unable to login to kerberos with given principal/keytab", e)
      }

      // Try creating spnego UGI if it is configured.
      val principal = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_PRINCIPAL).trim
      val keyTabFile = hiveConf.getVar(ConfVars.HIVE_SERVER2_SPNEGO_KEYTAB).trim
      if (principal.nonEmpty && keyTabFile.nonEmpty) {
        try {
          httpUGI = HiveAuthFactory.loginFromSpnegoKeytabAndReturnUGI(hiveConf)
          setSuperField(this, "httpUGI", httpUGI)
        } catch {
          case e: IOException =>
            throw new ServiceException("Unable to login to spnego with given principal " +
              s"$principal and keytab $keyTabFile: $e", e)
        }
      }
    }

    initCompositeService(hiveConf)
  }

  override def getInfo(sessionHandle: SessionHandle, getInfoType: GetInfoType): GetInfoValue = {
    getInfoType match {
      case GetInfoType.CLI_SERVER_NAME => new GetInfoValue("Spark SQL")
      case GetInfoType.CLI_DBMS_NAME => new GetInfoValue("Spark SQL")
      case GetInfoType.CLI_DBMS_VER => new GetInfoValue(sqlContext.sparkContext.version)
      case _ => super.getInfo(sessionHandle, getInfoType)
    }
  }
}

private[thriftserver] trait ReflectedCompositeService { this: AbstractService =>
  def initCompositeService(hiveConf: HiveConf) {
    // Emulating `CompositeService.init(hiveConf)`
    val serviceList = getAncestorField[JList[Service]](this, 2, "serviceList")
    serviceList.asScala.foreach(_.init(hiveConf))

    // Emulating `AbstractService.init(hiveConf)`
    invoke(classOf[AbstractService], this, "ensureCurrentState", classOf[STATE] -> STATE.NOTINITED)
    setAncestorField(this, 3, "hiveConf", hiveConf)
    invoke(classOf[AbstractService], this, "changeState", classOf[STATE] -> STATE.INITED)
    getAncestorField[Log](this, 3, "LOG").info(s"Service: $getName is inited.")
  }
} 
Example 56
Source File: SparkSQLDriver.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
        execution.hiveResultString()
      }
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 57
Source File: SparkSQLSessionManager.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.util.concurrent.Executors

import org.apache.commons.logging.Log
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.SessionHandle
import org.apache.hive.service.cli.session.SessionManager
import org.apache.hive.service.cli.thrift.TProtocolVersion
import org.apache.hive.service.server.HiveServer2

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._
import org.apache.spark.sql.hive.thriftserver.server.SparkSQLOperationManager


private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, sqlContext: SQLContext)
  extends SessionManager(hiveServer)
  with ReflectedCompositeService {

  private lazy val sparkSqlOperationManager = new SparkSQLOperationManager()

  override def init(hiveConf: HiveConf) {
    setSuperField(this, "operationManager", sparkSqlOperationManager)
    super.init(hiveConf)
  }

  override def openSession(
      protocol: TProtocolVersion,
      username: String,
      passwd: String,
      ipAddress: String,
      sessionConf: java.util.Map[String, String],
      withImpersonation: Boolean,
      delegationToken: String): SessionHandle = {
    val sessionHandle =
      super.openSession(protocol, username, passwd, ipAddress, sessionConf, withImpersonation,
          delegationToken)
    val session = super.getSession(sessionHandle)
    HiveThriftServer2.listener.onSessionCreated(
      session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername)
    val ctx = if (sqlContext.conf.hiveThriftServerSingleSession) {
      sqlContext
    } else {
      sqlContext.newSession()
    }
    ctx.setConf(HiveUtils.FAKE_HIVE_VERSION.key, HiveUtils.builtinHiveVersion)
    if (sessionConf != null && sessionConf.containsKey("use:database")) {
      ctx.sql(s"use ${sessionConf.get("use:database")}")
    }
    sparkSqlOperationManager.sessionToContexts.put(sessionHandle, ctx)
    sessionHandle
  }

  override def closeSession(sessionHandle: SessionHandle) {
    HiveThriftServer2.listener.onSessionClosed(sessionHandle.getSessionId.toString)
    super.closeSession(sessionHandle)
    sparkSqlOperationManager.sessionToActivePool.remove(sessionHandle)
    sparkSqlOperationManager.sessionToContexts.remove(sessionHandle)
  }
} 
Example 58
Source File: SparkSQLOperationManager.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver.server

import java.util.{Map => JMap}
import java.util.concurrent.ConcurrentHashMap

import org.apache.hive.service.cli._
import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager}
import org.apache.hive.service.cli.session.HiveSession

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation}
import org.apache.spark.sql.internal.SQLConf


private[thriftserver] class SparkSQLOperationManager()
  extends OperationManager with Logging {

  val handleToOperation = ReflectionUtils
    .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation")

  val sessionToActivePool = new ConcurrentHashMap[SessionHandle, String]()
  val sessionToContexts = new ConcurrentHashMap[SessionHandle, SQLContext]()

  override def newExecuteStatementOperation(
      parentSession: HiveSession,
      statement: String,
      confOverlay: JMap[String, String],
      async: Boolean): ExecuteStatementOperation = synchronized {
    val sqlContext = sessionToContexts.get(parentSession.getSessionHandle)
    require(sqlContext != null, s"Session handle: ${parentSession.getSessionHandle} has not been" +
      s" initialized or had already closed.")
    val conf = sqlContext.sessionState.conf
    val hiveSessionState = parentSession.getSessionState
    setConfMap(conf, hiveSessionState.getOverriddenConfigurations)
    setConfMap(conf, hiveSessionState.getHiveVariables)
    val runInBackground = async && conf.getConf(HiveUtils.HIVE_THRIFT_SERVER_ASYNC)
    val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay,
      runInBackground)(sqlContext, sessionToActivePool)
    handleToOperation.put(operation.getHandle, operation)
    logDebug(s"Created Operation for $statement with session=$parentSession, " +
      s"runInBackground=$runInBackground")
    operation
  }

  def setConfMap(conf: SQLConf, confMap: java.util.Map[String, String]): Unit = {
    val iterator = confMap.entrySet().iterator()
    while (iterator.hasNext) {
      val kv = iterator.next()
      conf.setConfString(kv.getKey, kv.getValue)
    }
  }
} 
Example 59
Source File: SparkSQLEnv.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import java.io.PrintStream

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION
import org.apache.spark.util.Utils


  def stop() {
    logDebug("Shutting down Spark SQL Environment")
    // Stop the SparkContext
    if (SparkSQLEnv.sparkContext != null) {
      sparkContext.stop()
      sparkContext = null
      sqlContext = null
    }
  }
} 
Example 60
Source File: RedisRelation.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.redis

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.xsql.DataSourceManager._
import org.apache.spark.sql.xsql.execution.datasources.redis.RedisSpecialStrategy

trait RedisRelationTrait {
  val parameters: Map[String, String]
  val schema: StructType
  lazy val redisConfig: RedisConfig = new RedisConfig(new RedisEndpoint(parameters.get(URL).get))
}
case class RedisRelationImpl(val parameters: Map[String, String], val schema: StructType)
  extends RedisRelationTrait

case class RedisRelation(
    parameters: Map[String, String],
    schema: StructType,
    filter: Seq[Expression] = Nil)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedScan
  with RedisRelationTrait {

  override def toString: String = s"RedisRelation(${filter.mkString(",")})"

  val partitionNum: Int = parameters.getOrElse("partitionNum", "1").toInt

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val filters = filter
      .map(RedisSpecialStrategy.getAttr)
      .groupBy(_._1)
      .map(tup => (tup._1, tup._2.map(_._2)))
    new RedisRDD(sqlContext.sparkContext, this, filters, requiredColumns, partitionNum)
  }
} 
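
RedisRelation above implements PrunedScan, so Spark asks it only for the columns a query actually needs. As a rough illustration of that contract (not tied to Redis, and with invented class and field names), here is a minimal in-memory relation that honors column pruning:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, PrunedScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class KeyValueRelation(data: Seq[(String, String)])(@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedScan {

  override def schema: StructType =
    StructType(StructField("key", StringType) :: StructField("value", StringType) :: Nil)

  // Only materialize the columns the query asked for.
  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val indices = requiredColumns.map(schema.fieldIndex)
    val rows = data.map { case (k, v) =>
      val full = Array[Any](k, v)
      Row.fromSeq(indices.map(i => full(i)))
    }
    sqlContext.sparkContext.parallelize(rows)
  }
}
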
Example 61
Source File: PythonSQLUtils.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.api.python

import java.io.InputStream
import java.nio.channels.Channels

import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.PythonRDDServer
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.types.DataType

private[sql] object PythonSQLUtils {
  def parseDataType(typeText: String): DataType = CatalystSqlParser.parseDataType(typeText)

  // This is needed when generating SQL documentation for built-in functions.
  def listBuiltinFunctionInfos(): Array[ExpressionInfo] = {
    FunctionRegistry.functionSet.flatMap(f => FunctionRegistry.builtin.lookupFunction(f)).toArray
  }
}

private[sql] class ArrowRDDServer(sqlContext: SQLContext) extends PythonRDDServer {

  override protected def streamToRDD(input: InputStream): RDD[Array[Byte]] = {
    // Create array to consume iterator so that we can safely close the inputStream
    val batches = ArrowConverters.getBatchesFromStream(Channels.newChannel(input)).toArray
    // Parallelize the record batches to create an RDD
    JavaRDD.fromRDD(sqlContext.sparkContext.parallelize(batches, batches.length))
  }

} 
Example 62
Source File: JdbcRelationProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.jdbc

import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}

class JdbcRelationProvider extends CreatableRelationProvider
  with RelationProvider with DataSourceRegister {

  override def shortName(): String = "jdbc"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val resolver = sqlContext.conf.resolver
    val timeZoneId = sqlContext.conf.sessionLocalTimeZone
    val schema = JDBCRelation.getSchema(resolver, jdbcOptions)
    val parts = JDBCRelation.columnPartition(schema, resolver, timeZoneId, jdbcOptions)
    JDBCRelation(schema, parts, jdbcOptions)(sqlContext.sparkSession)
  }

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val options = new JdbcOptionsInWrite(parameters)
    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      val tableExists = JdbcUtils.tableExists(conn, options)
      if (tableExists) {
        mode match {
          case SaveMode.Overwrite =>
            if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
              // In this case, we should truncate table and then load.
              truncateTable(conn, options)
              val tableSchema = JdbcUtils.getSchemaOption(conn, options)
              saveTable(df, tableSchema, isCaseSensitive, options)
            } else {
              // Otherwise, do not truncate the table, instead drop and recreate it
              dropTable(conn, options.table, options)
              createTable(conn, df, options)
              saveTable(df, Some(df.schema), isCaseSensitive, options)
            }

          case SaveMode.Append =>
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, tableSchema, isCaseSensitive, options)

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '${options.table}' already exists. " +
                s"SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
            // to not save the contents of the DataFrame and to not change the existing data.
            // Therefore, it is okay to do nothing here and then just return the relation below.
        }
      } else {
        createTable(conn, df, options)
        saveTable(df, Some(df.schema), isCaseSensitive, options)
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }
} 
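
For context, the provider above is normally reached through the public DataFrameWriter API rather than called directly. A hedged sketch of such a round trip follows; the JDBC URL, table name and credentials are placeholders, and a matching JDBC driver would have to be on the classpath:

import org.apache.spark.sql.{SaveMode, SparkSession}

object JdbcWriteExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("jdbc-write").getOrCreate()
    import spark.implicits._
    val df = Seq((1, "alice"), (2, "bob")).toDF("id", "name")

    df.write
      .format("jdbc")                                            // resolved via shortName() above
      .option("url", "jdbc:postgresql://localhost:5432/testdb")  // placeholder URL
      .option("dbtable", "public.people")                        // placeholder table
      .option("user", "user")
      .option("password", "secret")
      .mode(SaveMode.Append)                                     // exercises the Append branch of createRelation
      .save()

    spark.stop()
  }
}
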
Example 63
Source File: HadoopFsRelation.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
import org.apache.spark.sql.types.{StructField, StructType}



case class HadoopFsRelation(
    location: FileIndex,
    partitionSchema: StructType,
    dataSchema: StructType,
    bucketSpec: Option[BucketSpec],
    fileFormat: FileFormat,
    options: Map[String, String])(val sparkSession: SparkSession)
  extends BaseRelation with FileRelation {

  override def sqlContext: SQLContext = sparkSession.sqlContext

  private def getColName(f: StructField): String = {
    if (sparkSession.sessionState.conf.caseSensitiveAnalysis) {
      f.name
    } else {
      f.name.toLowerCase(Locale.ROOT)
    }
  }

  val overlappedPartCols = mutable.Map.empty[String, StructField]
  partitionSchema.foreach { partitionField =>
    if (dataSchema.exists(getColName(_) == getColName(partitionField))) {
      overlappedPartCols += getColName(partitionField) -> partitionField
    }
  }

  // When data and partition schemas have overlapping columns, the output
  // schema respects the order of the data schema for the overlapping columns, and it
  // respects the data types of the partition schema.
  val schema: StructType = {
    StructType(dataSchema.map(f => overlappedPartCols.getOrElse(getColName(f), f)) ++
      partitionSchema.filterNot(f => overlappedPartCols.contains(getColName(f))))
  }

  def partitionSchemaOption: Option[StructType] =
    if (partitionSchema.isEmpty) None else Some(partitionSchema)

  override def toString: String = {
    fileFormat match {
      case source: DataSourceRegister => source.shortName()
      case _ => "HadoopFiles"
    }
  }

  override def sizeInBytes: Long = {
    val compressionFactor = sqlContext.conf.fileCompressionFactor
    (location.sizeInBytes * compressionFactor).toLong
  }


  override def inputFiles: Array[String] = location.inputFiles
} 
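
The schema field above applies a specific merging rule: a column present in both the data and partition schemas keeps its position from the data schema but takes its data type from the partition schema. A standalone sketch of that rule (the object and method names are invented for illustration):

import java.util.Locale

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object SchemaMergeSketch {
  def merge(dataSchema: StructType, partitionSchema: StructType, caseSensitive: Boolean): StructType = {
    def colName(f: StructField): String =
      if (caseSensitive) f.name else f.name.toLowerCase(Locale.ROOT)
    val overlapped = partitionSchema
      .filter(pf => dataSchema.exists(df => colName(df) == colName(pf)))
      .map(pf => colName(pf) -> pf)
      .toMap
    StructType(dataSchema.map(f => overlapped.getOrElse(colName(f), f)) ++
      partitionSchema.filterNot(f => overlapped.contains(colName(f))))
  }

  def main(args: Array[String]): Unit = {
    val data = StructType(Seq(StructField("id", IntegerType), StructField("year", StringType)))
    val parts = StructType(Seq(StructField("year", IntegerType)))
    // "year" keeps its position from the data schema but takes IntegerType from the partition schema.
    println(merge(data, parts, caseSensitive = false))
  }
}
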
Example 64
Source File: package.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

package object state {

  implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) {

    
    private[streaming] def mapPartitionsWithStateStore[U: ClassTag](
        stateInfo: StatefulOperatorStateInfo,
        keySchema: StructType,
        valueSchema: StructType,
        indexOrdinal: Option[Int],
        sessionState: SessionState,
        storeCoordinator: Option[StateStoreCoordinatorRef])(
        storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = {

      val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction)
      val wrappedF = (store: StateStore, iter: Iterator[T]) => {
        // Abort the state store in case of error
        TaskContext.get().addTaskCompletionListener[Unit](_ => {
          if (!store.hasCommitted) store.abort()
        })
        cleanedF(store, iter)
      }

      new StateStoreRDD(
        dataRDD,
        wrappedF,
        stateInfo.checkpointLocation,
        stateInfo.queryRunId,
        stateInfo.operatorId,
        stateInfo.storeVersion,
        keySchema,
        valueSchema,
        indexOrdinal,
        sessionState,
        storeCoordinator)
    }
  }
} 
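
The wrapper above guarantees that a state store left uncommitted is aborted when the task finishes. The same register-a-completion-listener pattern can be sketched with a plain stand-in resource; everything here apart from the TaskContext call is illustrative:

import org.apache.spark.TaskContext
import org.apache.spark.sql.SparkSession

object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener").getOrCreate()
    val rdd = spark.sparkContext.parallelize(1 to 10, 2).mapPartitions { iter =>
      val resource = new java.io.StringWriter()   // stand-in for a per-partition state store
      TaskContext.get().addTaskCompletionListener[Unit] { _ =>
        resource.close()                          // runs when the task completes, success or failure
      }
      iter.map { i => resource.write(i.toString); i * 2 }
    }
    println(rdd.collect().mkString(","))
    spark.stop()
  }
}
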
Example 65
Source File: DDLSourceLoadSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


// Note that the META-INF/services file had to be modified in the test directory for this to work.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name - internal data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fluet da Bomb").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb"))
  }

  test("data sources with the same name - internal data source/external data source") {
    assert(spark.read.format("datasource").load().schema ==
      StructType(Seq(StructField("longType", LongType, nullable = false))))
  }

  test("data sources with the same name - external data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fake external source").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fake external source"))
  }

  test("load data source from format alias") {
    assert(spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }

  test("specify full classname with duplicate formats") {
    assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceFour extends RelationProvider with DataSourceRegister {

  def shortName(): String = "datasource"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("longType", LongType, nullable = false)))
    }
} 
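
The fake sources above are discovered by format alias because DataSourceRegister implementations are looked up via Java's ServiceLoader. A minimal registrable source along the same lines (the class name and alias are invented for this sketch) would look like the following; to make spark.read.format("demo-source") resolve it, its fully qualified class name has to be listed in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister on the classpath, which is the modification the comment above refers to:

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class DemoSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "demo-source"

  override def createRelation(ctx: SQLContext, parameters: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = ctx
      override def schema: StructType =
        StructType(Seq(StructField("value", StringType, nullable = false)))
    }
}
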
Example 66
Source File: SharedSparkSession.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import scala.concurrent.duration._

import org.scalatest.{BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.internal.SQLConf


  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        if (_spark != null) {
          try {
            _spark.sessionState.catalog.reset()
          } finally {
            _spark.stop()
            _spark = null
          }
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // Clear all persistent datasets after each test
    spark.sharedState.cacheManager.clearCache()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(10.seconds), interval(2.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 67
Source File: BlockingSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
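
A hedged usage sketch for the test utility above: the latch keeps source creation blocked until the test releases it, so a query can be started and inspected before any data flows. This assumes the BlockingSource class above is compiled on the classpath under its original package; the checkpoint directory is a temporary one created here.

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.util.BlockingSource

object BlockingSourceUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("blocking-source").getOrCreate()
    BlockingSource.latch = new CountDownLatch(1)

    val query = spark.readStream
      .format(classOf[BlockingSource].getName)
      .load()
      .writeStream
      .format(classOf[BlockingSource].getName)
      .option("checkpointLocation", java.nio.file.Files.createTempDirectory("chk").toString)
      .start()

    // ... exercise whatever should happen while source creation is still blocked ...

    BlockingSource.latch.countDown()   // let the source come up
    query.stop()
    spark.stop()
  }
}
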
Example 68
Source File: MockSourceProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class MockSourceProvider extends StreamSourceProvider {
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    MockSourceProvider.sourceProviderFunction()
  }
}

object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
      source
    }
    try {
      f
    } finally {
      sourceProviderFunction = null
    }
  }
} 
Example 69
Source File: SapHiveContext.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.SparkContext
import org.apache.spark.sql.{CommonSapSQLContext, SQLContext}
import org.apache.spark.sql.execution.CacheManager
import org.apache.spark.sql.execution.ui.SQLListener
import org.apache.spark.sql.hive.client.{ClientInterface, ClientWrapper}


class SapHiveContext(
    @transient sparkContext: SparkContext,
    cacheManager: CacheManager,
    listener: SQLListener,
    @transient execHive: ClientWrapper,
    @transient metaHive: ClientInterface,
    isRootContext: Boolean)
  extends ExtendableHiveContext(
    sparkContext,
    cacheManager,
    listener,
    execHive,
    metaHive,
    isRootContext)
  with CommonSapSQLContext {

  def this(sc: SparkContext) =
    this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), null, null, true)

  override def newSession(): HiveContext =
    new SapHiveContext(
      sparkContext = this.sparkContext,
      cacheManager = this.cacheManager,
      listener = this.listener,
      executionHive.newSession(),
      metadataHive.newSession(),
      isRootContext = false)
} 
Example 70
Source File: BasicCurrencyConversionFunction.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.currency.basic

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.currency._
import org.apache.spark.sql.util.ValidatingPropertyMap._

import scala.util.Try

protected object BasicCurrencyConversionConfig {

  
  private def updateRatesMapByTable(ratesTable: String, sqlContext: SQLContext): Unit = {
    val ratesTableData = sqlContext.sql(s"SELECT * FROM $ratesTable").collect()
    ratesTableData.foreach { row =>
      val from = row.getString(0)
      val to = row.getString(1)
      val date = row.getString(2).replaceAll("-", "").toInt
      val rate =
        Try(row.getDecimal(3)).recover {
          case ex: ClassCastException => new java.math.BigDecimal(row.getDouble(3))
        }.get
      ratesMap.put((from, to), date, rate)
    }
  }
} 
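
The helper above reads a rates table with SQL and folds each row into an in-memory map keyed by currency pair and date. A rough sketch of that loading step, with an in-memory temp view standing in for the external rates table (the column layout from/to/date/rate and the sample rate follow the snippet and are otherwise placeholders):

import org.apache.spark.sql.SparkSession

object RatesMapSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("rates").getOrCreate()
    import spark.implicits._
    Seq(("EUR", "USD", "2016-01-04", BigDecimal("1.0898")))
      .toDF("from", "to", "date", "rate")
      .createOrReplaceTempView("rates")

    val ratesMap = spark.sql("SELECT * FROM rates").collect().map { row =>
      val date = row.getString(2).replaceAll("-", "").toInt   // e.g. 20160104
      ((row.getString(0), row.getString(1)), date) -> row.getDecimal(3)
    }.toMap

    println(ratesMap)
    spark.stop()
  }
}
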
Example 71
Source File: CreateTableUsingTemporaryAwareCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.TableIdentifierUtils._
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._


  // scalastyle:off method.length
  private def resolveDataSource(sqlContext: SQLContext,
                                dataSource: Any,
                                tableId: TableIdentifier): ResolvedDataSource = {
    dataSource match {
      case drp: PartitionedRelationProvider =>
        if (userSpecifiedSchema.isEmpty) {
          new ResolvedDataSource(drp.getClass,
            drp.createRelation(
              sqlContext,
              tableId.toSeq,
              new CaseInsensitiveMap(options),
              partitioningFunction,
              partitioningColumns,
              isTemporary,
              allowExisting))
        } else {
          new ResolvedDataSource(drp.getClass,
            drp.createRelation(
              sqlContext,
              tableId.toSeq,
              new CaseInsensitiveMap(options),
              userSpecifiedSchema.get,
              partitioningFunction,
              partitioningColumns,
              isTemporary,
              allowExisting))
        }
      case drp: TemporaryAndPersistentSchemaRelationProvider if userSpecifiedSchema.nonEmpty =>
            new ResolvedDataSource(drp.getClass,
              drp.createRelation(
                sqlContext,
                tableId.toSeq,
                new CaseInsensitiveMap(options),
                userSpecifiedSchema.get,
                isTemporary,
                allowExisting))
      case drp: TemporaryAndPersistentRelationProvider =>
        new ResolvedDataSource(drp.getClass,
          drp.createRelation(
            sqlContext,
            tableId.toSeq,
            new CaseInsensitiveMap(options),
            isTemporary,
            allowExisting))
      case _ => ResolvedDataSource(sqlContext, userSpecifiedSchema,
        partitionColumns, provider, options)
    }
  }
  // scalastyle:on method.length
} 
Example 72
Source File: SqlContextAccessor.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.language.implicitConversions


object SqlContextAccessor {
  implicit def sqlContextToCatalogAccessable(sqlContext: SQLContext): SqlContextCatalogAccessor =
    new SqlContextCatalogAccessor(sqlContext)

  class SqlContextCatalogAccessor(sqlContext: SQLContext)
    extends SQLContext(sqlContext.sparkContext) {

    def registerRawPlan(lp: LogicalPlan, tableName: String): Unit = {
      sqlContext.catalog.registerTable(TableIdentifier(tableName), lp)
    }
  }
} 
Example 73
Source File: DropRunnableCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.Logging
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.DropRelation
import org.apache.spark.sql.{Row, SQLContext}

import scala.util.Try


private[sql] case class DropRunnableCommand(toDrop: Map[String, Option[DropRelation]])
  extends RunnableCommand
  with Logging {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    toDrop.foreach {
      case (name, dropOption) =>
        sqlContext.dropTempTable(name)
        dropOption.foreach { dropRelation =>
          Try {
            dropRelation.dropTable()
          }.recover {
            // When the provider indicates an exception while dropping, we still have to continue
            // dropping all the referencing tables, otherwise there could be integrity issues
            case ex =>
              logWarning(
                s"""Error occurred when dropping table '$name':${ex.getMessage}, however
                   |table '$name' will still be dropped from Spark catalog.
                 """.stripMargin)
          }.get
        }
    }
    Seq.empty
  }
} 
Example 74
Source File: ShowTablesUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.sources.DatasourceCatalog
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types.{StringType, StructField, StructType}


private[sql]
case class ShowTablesUsingCommand(provider: String, options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("IS_TEMPORARY", StringType, nullable = false) ::
    StructField("KIND", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val dataSource: Any = DatasourceResolver.resolverFor(sqlContext).newInstanceOf(provider)

    dataSource match {
      case describableRelation: DatasourceCatalog =>
        describableRelation
          .getRelations(sqlContext, new CaseInsensitiveMap(options))
          .map(relationInfo => Row(
            relationInfo.name,
            relationInfo.isTemporary.toString.toUpperCase,
            relationInfo.kind.toUpperCase))
      case _ =>
        throw new RuntimeException(s"The provided data source $provider does not support " +
        "showing its relations.")
    }
  }
} 
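
The command above asks an external DatasourceCatalog provider for its relations and renders one row per relation. Against plain Spark, without such a provider, the closest public equivalent is the session catalog; a small sketch:

import org.apache.spark.sql.SparkSession

object ListTablesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("list-tables").getOrCreate()
    spark.range(3).createOrReplaceTempView("numbers")
    // name / tableType / isTemporary roughly correspond to TABLE_NAME / KIND / IS_TEMPORARY above.
    spark.catalog.listTables().select("name", "tableType", "isTemporary").show()
    spark.stop()
  }
}
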
Example 75
Source File: DeepDescribeCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.describable.Describable
import org.apache.spark.sql.sources.describable.FieldLike.StructFieldLike
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}


private[sql] case class DeepDescribeCommand(
    relation: Describable)
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val description = relation.describe()
    Seq(description match {
      case r: Row => r
      case default => Row(default)
    })
  }

  override def output: Seq[Attribute] = {
    relation.describeOutput match {
      case StructType(fields) =>
        fields.map(StructFieldLike.toAttribute)
      case other =>
        AttributeReference("value", other)() :: Nil
    }
  }
} 
Example 76
Source File: DescribeTableUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.TableIdentifierUtils._
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.{DatasourceCatalog, RelationInfo}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}


private[sql]
case class DescribeTableUsingCommand(
    name: TableIdentifier,
    provider: String,
    options: Map[String, String])
  extends LogicalPlan
  with RunnableCommand {

  override def output: Seq[Attribute] = StructType(
    StructField("TABLE_NAME", StringType, nullable = false) ::
    StructField("DDL_STMT", StringType, nullable = false) ::
    Nil
  ).toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    // Convert the table name according to the case-sensitivity settings
    val tableId = name.toSeq
    val resolver = DatasourceResolver.resolverFor(sqlContext)
    val catalog = resolver.newInstanceOfTyped[DatasourceCatalog](provider)

    Seq(catalog
      .getRelation(sqlContext, tableId, new CaseInsensitiveMap(options)) match {
        case None => Row("", "")
        case Some(RelationInfo(relName, _, _, ddl, _)) => Row(
          relName, ddl.getOrElse(""))
    })
  }
} 
Example 77
Source File: RawDDLCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.sources.RawDDLObjectType.RawDDLObjectType
import org.apache.spark.sql.sources.RawDDLStatementType.RawDDLStatementType
import org.apache.spark.sql.sources.RawSqlSourceProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.execution.RunnableCommand


private[sql] case class RawDDLCommand(
    identifier: String,
    objectType: RawDDLObjectType,
    statementType: RawDDLStatementType,
    sparkSchema: Option[StructType],
    ddlStatement: String,
    provider: String,
    options: Map[String, String])
  extends RunnableCommand {

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val dataSource: Any = ResolvedDataSource.lookupDataSource(provider).newInstance()

    dataSource match {
      case rsp: RawSqlSourceProvider =>
        rsp.executeDDL(identifier, objectType, statementType, sparkSchema, ddlStatement, options)
      case _ => throw new RuntimeException("The provided datasource does not support " +
        "executing raw DDL statements.")
    }
    Seq.empty[Row]
  }
} 
Example 78
Source File: RegisterAllTablesCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.CaseSensitivityUtils._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.execution.datasources.SqlContextAccessor._
import org.apache.spark.sql.sources.{LogicalPlanSource, RegisterAllTableRelations}
import org.apache.spark.sql.util.CollectionUtils._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}


    relations.map {
      case (name, source) =>
        val lp = source.logicalPlan(sqlContext)
        if (lp.resolved) {
          sqlContext.validatedSchema(lp.schema).recover {
            case d: DuplicateFieldsException =>
              throw new RuntimeException(
                s"Provider '$provider' returned a relation that has duplicate fields.",
                d)
          }.get
        } else {
          // TODO(AC): With the new view interface, this can be checked
          logWarning(s"Adding relation $name with potentially unreachable fields.")
        }
        name -> lp
    }.foreach {
      case (name, plan) =>
        sqlContext.registerRawPlan(plan, name)
    }
  }
} 
Example 79
Source File: AbstractViewCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources.{AbstractViewProvider, ViewKind}
import org.apache.spark.sql.{DatasourceResolver, SQLContext}


  def withValidProvider[B](sqlContext: SQLContext)(b: AbstractViewProvider[_] => B): B = {
    val resolver = DatasourceResolver.resolverFor(sqlContext)
    AbstractViewProvider.matcherFor(kind)(resolver.newInstanceOf(provider)) match {
      case Some(viewProvider) =>
        b(viewProvider)
      case _ =>
        throw new ProviderException(provider, "Does not support the " +
          s"execution of ${this.getClass.getSimpleName}")
    }
  }
}

class ProviderException(val provider: String, val reason: String)
  extends Exception(s"Exception using provider $provider: $reason") 
Example 80
Source File: RegisterTableCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.catalyst.CaseSensitivityUtils._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.execution.datasources.SqlContextAccessor._
import org.apache.spark.sql.sources.RegisterAllTableRelations
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}


      val relation = resolvedProvider.getTableRelation(tableName, sqlContext, options)

      relation match {
        case None =>
          sys.error(s"Relation $tableName was not found in the catalog.")
        case Some(r) =>
          val lp = r.logicalPlan(sqlContext)
          if (lp.resolved) {
            sqlContext.validatedSchema(lp.schema).recover {
              case d: DuplicateFieldsException =>
                throw new RuntimeException(
                  s"Provider '$provider' returned a relation that has duplicate fields.",
                  d)
            }.get
          } else {
            // TODO(AC): With the new view interface, this can be checked
            logWarning(s"Adding relation $tableName with potentially unreachable fields.")
          }
          sqlContext.registerRawPlan(lp, tableName)
      }
    }
    Seq.empty
  }
} 
Example 81
Source File: CreateTableStrategy.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{DatasourceResolver, SQLContext, Strategy}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{ExecutedCommand, SparkPlan}
import org.apache.spark.sql.sources.TemporaryAndPersistentNature


private[sql] case class CreateTableStrategy(sqlContext: SQLContext) extends Strategy {

  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    // Currently we only handle cases where the user wants to instantiate a
    // persistent relation any other cases has to be handled by the datasource itself
    case CreateTableUsing(tableName,
        userSpecifiedSchema, provider, temporary, options, allowExisting, _) =>
      DatasourceResolver.resolverFor(sqlContext).newInstanceOf(provider) match {
        case _: TemporaryAndPersistentNature =>
          ExecutedCommand(CreateTableUsingTemporaryAwareCommand(tableName,
            userSpecifiedSchema,
            Array.empty[String],
            None,
            None,
            provider,
            options,
            temporary,
            allowExisting)) :: Nil
        case _ => Nil
      }

    case CreateTablePartitionedByUsing(tableId, userSpecifiedSchema, provider,
    partitioningFunction, partitioningColumns, temporary, options, allowExisting, _) =>
      ResolvedDataSource.lookupDataSource(provider).newInstance() match {
        case _: TemporaryAndPersistentNature =>
          ExecutedCommand(CreateTableUsingTemporaryAwareCommand(
            tableId,
            userSpecifiedSchema,
            Array.empty[String],
            Some(partitioningFunction),
            Some(partitioningColumns),
            provider,
            options,
            isTemporary = false,
            allowExisting)) :: Nil
        case _ => Nil
      }
    case _ => Nil
  }
} 
Example 82
Source File: DescCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.commands.hive

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{Row, SQLContext}


case class DescCommand(ident: TableIdentifier) extends HiveRunnableCommand {

  override protected val commandName: String = s"DESC $ident"

  override def execute(sqlContext: SQLContext): Seq[Row] = {
    val plan = sqlContext.catalog.lookupRelation(ident)
    if (plan.resolved) {
      plan.schema.map { field =>
        Row(field.name, field.dataType.simpleString, None)
      }
    } else {
      Seq.empty
    }
  }

  override lazy val output: Seq[Attribute] =
    AttributeReference("col_name", StringType)() ::
    AttributeReference("data_type", StringType)() ::
    AttributeReference("comment", StringType)() :: Nil
} 
Example 83
Source File: inferSchemaCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.commands

import org.apache.spark.sql.catalyst.analysis.systables.SchemaEnumeration
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.execution.tablefunctions.DataTypeExtractor
import org.apache.spark.sql.hive.orc.OrcRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}


case class InferSchemaCommand(path: String, fileType: FileType) extends RunnableCommand {
  override lazy val output: Seq[Attribute] = InferSchemaCommand.schema.toAttributes

  override def run(sqlContext: SQLContext): Seq[Row] = {
    val fileSchema = fileType.readSchema(sqlContext, path)
    fileSchema.zipWithIndex.map {
      case (StructField(name, dataType, nullable, _), idx) =>
        val dataTypeExtractor = DataTypeExtractor(dataType)
        Row(
          name,
          idx + 1, // idx + 1 since the ordinal position has to start at 1
          nullable,
          dataTypeExtractor.inferredSqlType,
          dataTypeExtractor.numericPrecision.orNull,
          dataTypeExtractor.numericPrecisionRadix.orNull,
          dataTypeExtractor.numericScale.orNull)
    }
  }
}

object InferSchemaCommand extends SchemaEnumeration {
  val name = Field("COLUMN_NAME", StringType, nullable = false)
  val ordinalPosition = Field("ORDINAL_POSITION", IntegerType, nullable = false)
  val isNullable = Field("IS_NULLABLE", BooleanType, nullable = false)
  val dataType = Field("DATA_TYPE", StringType, nullable = false)
  val numericPrecision = Field("NUMERIC_PRECISION", IntegerType, nullable = true)
  val numericPrecisionRadix = Field("NUMERIC_PRECISION_RADIX", IntegerType, nullable = true)
  val numericScale = Field("NUMERIC_SCALE", IntegerType, nullable = true)
} 
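
The command above emits one row per column with a 1-based ordinal position, nullability and an inferred SQL type. A rough sketch of the same per-column output using only public API, with a DataFrame's own schema standing in for a schema read from a Parquet or ORC footer:

import org.apache.spark.sql.SparkSession

object DescribeSchemaSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("infer").getOrCreate()
    import spark.implicits._
    val schema = Seq((1, "a", Option(2.0))).toDF("id", "name", "score").schema
    schema.fields.zipWithIndex.foreach { case (f, idx) =>
      // ordinal positions are 1-based, as in the command above
      println((f.name, idx + 1, f.nullable, f.dataType.sql))
    }
    spark.stop()
  }
}
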
Example 84
Source File: ShowPartitionFunctionsUsingCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources.commands

import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DatasourceResolver, DefaultDatasourceResolver, Row, SQLContext}

case class ShowPartitionFunctionsUsingCommand(
    provider: String,
    options: Map[String, String])
  extends RunnableCommand {

  def run(sqlContext: SQLContext): Seq[Row] = {
    val resolver = DatasourceResolver.resolverFor(sqlContext)
    val pFunProvider = resolver.newInstanceOfTyped[PartitioningFunctionProvider](provider)
    val pFuns = pFunProvider.getAllPartitioningFunctions(sqlContext, options)

    pFuns.map { fun =>
      val (splittersOpt, rightClosedOpt) = fun match {
        case RangeSplitPartitioningFunction(_, _, splitters, rightClosed) =>
          (Some(splitters), Some(rightClosed))
        case _ =>
          (None, None)
      }
      val (startOpt, endOpt, intervalTypeOpt, intervalValueOpt) = fun match {
        case RangeIntervalPartitioningFunction(_, _, start, end, strideParts) =>
          (Some(start), Some(end), Some(strideParts.productPrefix), Some(strideParts.n))
        case _ =>
          (None, None, None, None)
      }
      val partitionsNoOpt = fun match {
        case HashPartitioningFunction(_, _, partitionsNo) =>
          partitionsNo
        case s: SimpleDataType =>
          None
      }
      Row(fun.name, fun.productPrefix, fun.dataTypes.map(_.toString).mkString(","),
        splittersOpt.map(_.mkString(",")).orNull, rightClosedOpt.orNull, startOpt.orNull,
        endOpt.orNull, intervalTypeOpt.orNull, intervalValueOpt.orNull, partitionsNoOpt.orNull)
    }
  }

  override lazy val output: Seq[Attribute] = StructType(
    StructField("name", StringType, nullable = false) ::
      StructField("kind", StringType, nullable = false) ::
      StructField("dataTypes", StringType, nullable = false) ::
      StructField("splitters", StringType, nullable = true) ::
      StructField("rightClosed", BooleanType, nullable = true) ::
      StructField("start", IntegerType, nullable = true) ::
      StructField("end", IntegerType, nullable = true) ::
      StructField("intervalType", StringType, nullable = true) ::
      StructField("intervalValue", IntegerType, nullable = true) ::
      StructField("partitionsNo", IntegerType, nullable = true) :: Nil
  ).toAttributes
} 
Example 85
Source File: PartitionedRelationProvider.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StructType


trait PartitionedRelationProvider
  extends SchemaRelationProvider
  with TemporaryAndPersistentNature {

  def createRelation(sqlContext: SQLContext,
                     tableName: Seq[String],
                     parameters: Map[String, String],
                     partitioningFunction: Option[String],
                     partitioningColumns: Option[Seq[String]],
                     isTemporary: Boolean,
                     allowExisting: Boolean): BaseRelation

  def createRelation(sqlContext: SQLContext,
                     tableName: Seq[String],
                     parameters: Map[String, String],
                     schema: StructType,
                     partitioningFunction: Option[String],
                     partitioningColumns: Option[Seq[String]],
                     isTemporary: Boolean,
                     allowExisting: Boolean): BaseRelation

} 
Example 86
Source File: AbstractViewProvider.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.logical.view.{AbstractView, Persisted}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import scala.reflect._


  def name: String
}

abstract class BaseAbstractViewProvider[A <: AbstractView with Persisted: ClassTag]
  extends AbstractViewProvider[A] {
  val tag = implicitly[ClassTag[A]]
}

object AbstractViewProvider {
  def matcherFor(kind: ViewKind)(any: Any): Option[AbstractViewProvider[_]] = {
    val multiProvider = MultiAbstractViewProvider.matcherFor(kind)
    any match {
      case provider: AbstractViewProvider[_] if tagMatches(provider.tag) =>
        Some(provider)
      case multiProvider(provider) =>
        Some(provider)
      case _ => None
    }
  }

  private def tagMatches[A: ClassTag](tag: ClassTag[_]): Boolean = {
    classTag[A].runtimeClass.isAssignableFrom(tag.runtimeClass)
  }
}

case class CreateViewInput(
    sqlContext: SQLContext,
    plan: LogicalPlan,
    viewSql: String,
    options: Map[String, String],
    identifier: TableIdentifier,
    allowExisting: Boolean)

case class DropViewInput(
    sqlContext: SQLContext,
    options: Map[String, String],
    identifier: TableIdentifier,
    allowNotExisting: Boolean) 
Example 87
Source File: RawSqlSourceProvider.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.execution.{PhysicalRDD, RDDConversions, SparkPlan}
import org.apache.spark.sql.sources.RawDDLObjectType.RawDDLObjectType
import org.apache.spark.sql.sources.RawDDLStatementType.RawDDLStatementType
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext}

case object RawDDLObjectType {

  sealed trait RawDDLObjectType {
    val name: String
    override def toString: String = name
  }

  sealed abstract class BaseRawDDLObjectType(val name: String) extends RawDDLObjectType
  sealed trait RawData

  case object PartitionFunction extends BaseRawDDLObjectType("partition function")
  case object PartitionScheme   extends BaseRawDDLObjectType("partition scheme")
  case object Collection        extends BaseRawDDLObjectType("collection") with RawData
  case object Series            extends BaseRawDDLObjectType("table") with RawData
  case object Graph             extends BaseRawDDLObjectType("graph") with RawData
}

case object RawDDLStatementType {

  sealed trait RawDDLStatementType

  case object Create extends RawDDLStatementType
  case object Drop   extends RawDDLStatementType
  case object Append extends RawDDLStatementType
  case object Load   extends RawDDLStatementType
}


  protected def calculateSchema(): StructType
} 
Example 88
Source File: LogicalPlanSource.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.{DataFrame, SQLContext}


case class CreatePersistentViewSource(createViewStatement: String, handle: ViewHandle)
  extends LogicalPlanSource {

  def logicalPlan(sqlContext: SQLContext): LogicalPlan = {
    sqlContext.parseSql(createViewStatement) match {
      // This might seem repetitive but in the future the commands might drastically differ
      case CreatePersistentViewCommand(kind, _, plan, _, provider, _, _) =>
        kind.createPersisted(plan, handle, provider)

      case unknown =>
        throw new RuntimeException(s"Could not extract view query from $unknown")
    }
  }
} 
Example 89
Source File: TemporaryAndPersistentSchemaRelationProvider.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._


trait TemporaryAndPersistentSchemaRelationProvider
  extends SchemaRelationProvider
  with TemporaryAndPersistentNature {

  def createRelation(sqlContext: SQLContext,
                     tableName: Seq[String],
                     parameters: Map[String, String],
                     schema: StructType,
                     isTemporary: Boolean,
                     allowExisting: Boolean): BaseRelation
} 
Example 90
Source File: ResolveSelectUsing.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SelectUsing, UnresolvedSelectUsing}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sources.RawSqlSourceProvider
import org.apache.spark.sql.{DatasourceResolver, SQLContext}


private[sql] case class ResolveSelectUsing(sqlContext: SQLContext) extends Rule[LogicalPlan] {

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case UnresolvedSelectUsing(sqlCommand, provider, expectedSchema, options) => {
      val resolver = DatasourceResolver.resolverFor(sqlContext)
      val rawSqlProvider = resolver.newInstanceOfTyped[RawSqlSourceProvider](provider)
      val execution = rawSqlProvider.executionOf(sqlContext, options, sqlCommand, expectedSchema)
      SelectUsing(execution)
    }
  }

} 
Example 91
Source File: ResolveInferSchemaCommand.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sources.commands.{InferSchemaCommand, Orc, Parquet, UnresolvedInferSchemaCommand}

import scala.util.Try


case class ResolveInferSchemaCommand(sqlContext: SQLContext) extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case UnresolvedInferSchemaCommand(path, explicitFileType) =>
      val fileType = explicitFileType.getOrElse(path.toLowerCase match {
        case p if p.endsWith(".orc") => Orc
        case p if p.endsWith(".parquet") => Parquet
        case invalid =>
          throw new AnalysisException(s"Could not determine file format of '$path'")
      })
      InferSchemaCommand(path, fileType)
  }
} 
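
The rule above decides the file type purely from the path extension and fails for anything it does not recognize. That dispatch can be shown in isolation; the sealed trait and case objects below are stand-ins for the command's own Orc and Parquet types:

object FileTypeSketch {
  sealed trait FileType
  case object Orc extends FileType
  case object Parquet extends FileType

  def detect(path: String): FileType = path.toLowerCase match {
    case p if p.endsWith(".orc") => Orc
    case p if p.endsWith(".parquet") => Parquet
    case _ => throw new IllegalArgumentException(s"Could not determine file format of '$path'")
  }

  def main(args: Array[String]): Unit =
    println(detect("/data/events.parquet"))   // prints Parquet
}
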
Example 92
Source File: dependenciesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableDependencyCalculator
import org.apache.spark.sql.sources.{RelationKind, Table}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object DependenciesSystemTableProvider extends SystemTableProvider with LocalSpark {
  
  override def execute(): Seq[Row] = {
    val tables = getTables(sqlContext.catalog)
    val dependentsMap = buildDependentsMap(tables)

    def kindOf(tableIdentifier: TableIdentifier): String =
      tables
        .get(tableIdentifier)
        .map(plan => RelationKind.kindOf(plan).getOrElse(Table).name)
        .getOrElse(DependenciesSystemTable.UnknownType)
        .toUpperCase

    dependentsMap.flatMap {
      case (tableIdent, dependents) =>
        val curKind = kindOf(tableIdent)
        dependents.map { dependent =>
          val dependentKind = kindOf(dependent)
          Row(
            tableIdent.database.orNull,
            tableIdent.table,
            curKind,
            dependent.database.orNull,
            dependent.table,
            dependentKind,
            ReferenceDependency.id)
        }
    }.toSeq
  }

  override val schema: StructType = DependenciesSystemTable.schema
}

object DependenciesSystemTable extends SchemaEnumeration {
  val baseSchemaName = Field("BASE_SCHEMA_NAME", StringType, nullable = true)
  val baseObjectName = Field("BASE_OBJECT_NAME", StringType, nullable = false)
  val baseObjectType = Field("BASE_OBJECT_TYPE", StringType, nullable = false)
  val dependentSchemaName = Field("DEPENDENT_SCHEMA_NAME", StringType, nullable = true)
  val dependentObjectName = Field("DEPENDENT_OBJECT_NAME", StringType, nullable = false)
  val dependentObjectType = Field("DEPENDENT_OBJECT_TYPE", StringType, nullable = false)
  val dependencyType = Field("DEPENDENCY_TYPE", IntegerType, nullable = false)

  private[DependenciesSystemTable] val UnknownType = "UNKNOWN"
} 
Example 93
Source File: partitionFunctionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.GenericUtil._


  private def typeNameOf(f: PartitionFunction): String = f match {
    case _: RangePartitionFunction => "RANGE"
    case _: BlockPartitionFunction => "BLOCK"
    case _: HashPartitionFunction => "HASH"
  }
}

object PartitionFunctionSystemTable extends SchemaEnumeration {
  val id = Field("ID", StringType, nullable = false)
  val functionType = Field("TYPE", StringType, nullable = false)
  val columnName = Field("COLUMN_NAME", StringType, nullable = false)
  val columnType = Field("COLUMN_TYPE", StringType, nullable = false)
  val boundaries = Field("BOUNDARIES", StringType, nullable = true)
  val block = Field("BLOCK_SIZE", IntegerType, nullable = true)
  val partitions = Field("PARTITIONS", IntegerType, nullable = true)
  val minP = Field("MIN_PARTITIONS", IntegerType, nullable = true)
  val maxP = Field("MAX_PARTITIONS", IntegerType, nullable = true)
} 
Example 94
Source File: sessionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLConf, SQLContext}


  private def allSettingsOf(conf: SQLConf): Map[String, String] = {
    val setConfs = conf.getAllConfs
    val defaultConfs = conf.getAllDefinedConfs.collect {
      case (key, default, _) if !setConfs.contains(key) => key -> default
    }
    setConfs ++ defaultConfs
  }

  override def schema: StructType = SessionSystemTable.schema
}

object SessionSystemTable extends SchemaEnumeration {
  val section = Field("SECTION", StringType, nullable = false)
  val key = Field("KEY", StringType, nullable = false)
  val value = Field("VALUE", StringType, nullable = true)
} 
Example 95
Source File: systemTableRegistry.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.ResolveSystemTables
import org.apache.spark.sql.catalyst.plans.logical.{UnresolvedProviderBoundSystemTable, UnresolvedSparkLocalSystemTable, UnresolvedSystemTable}


object SystemTableRegistry extends SimpleSystemTableRegistry {
  register("tables", TablesSystemTableProvider)
  register("object_dependencies", DependenciesSystemTableProvider)
  register("table_metadata", MetadataSystemTableProvider)
  register("schemas", SchemaSystemTableProvider)
  register("session_context", SessionSystemTableProvider)
  register("partition_functions", PartitionFunctionSystemTableProvider)
  register("partition_schemes", PartitionSchemeSystemTableProvider)
  register("relation_sql_name", RelationMappingSystemTableProvider)
} 
Example 96
Source File: tablesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.commands.WithOrigin
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.util.CollectionUtils.CaseInsensitiveMap
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._

object TablesSystemTableProvider extends SystemTableProvider with LocalSpark with ProviderBound {
  
  override def buildScan(requiredColumns: Array[String],
                         filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[DatasourceCatalog](provider) match {
      case catalog: DatasourceCatalog with DatasourceCatalogPushDown =>
        catalog.getRelations(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog: DatasourceCatalog =>
        val values =
          catalog
            .getRelations(sqlContext, new CaseInsensitiveMap(options))
            .map(relationInfo => Row(
              relationInfo.name,
              relationInfo.isTemporary.toString.toUpperCase,
              relationInfo.kind.toUpperCase,
              relationInfo.provider))
        val rows = schema.buildPrunedFilteredScan(requiredColumns, filters)(values)
        sparkContext.parallelize(rows)
    }
}

sealed trait TablesSystemTable extends SystemTable {
  override def schema: StructType = TablesSystemTable.schema
}

object TablesSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val isTemporary = Field("IS_TEMPORARY", StringType, nullable = false)
  val kind = Field("KIND", StringType, nullable = false)
  val provider = Field("PROVIDER", StringType, nullable = true)
} 
Example 97
Source File: metadataSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.catalyst.CaseSensitivityUtils._


  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] =
    DatasourceResolver
      .resolverFor(sqlContext)
      .newInstanceOfTyped[MetadataCatalog](provider) match {
      case catalog: MetadataCatalog with MetadataCatalogPushDown =>
        catalog.getTableMetadata(sqlContext, options, requiredColumns, filters.toSeq.merge)
      case catalog =>
        val rows = catalog.getTableMetadata(sqlContext, options).flatMap { tableMetadata =>
          val formatter = new OutputFormatter(tableMetadata.tableName, tableMetadata.metadata)
          formatter.format().map(Row.fromSeq)
        }
        sparkContext.parallelize(schema.buildPrunedFilteredScan(requiredColumns, filters)(rows))
    }

  override def schema: StructType = MetadataSystemTable.schema
}

object MetadataSystemTable extends SchemaEnumeration {
  val tableName = Field("TABLE_NAME", StringType, nullable = false)
  val metadataKey = Field("METADATA_KEY", StringType, nullable = true)
  val metadataValue = Field("METADATA_VALUE", StringType, nullable = true)
} 
Example 98
Source File: relationMappingSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object RelationMappingSystemTableProvider extends SystemTableProvider with LocalSpark {

  
  override def execute(): Seq[Row] = {
    sqlContext.tableNames().map { tableName =>
      val plan = sqlContext.catalog.lookupRelation(TableIdentifier(tableName))
      val sqlName = plan.collectFirst {
        case s: SqlLikeRelation =>
          s.relationName
        case LogicalRelation(s: SqlLikeRelation, _) =>
          s.relationName
      }
      Row(tableName, sqlName)
    }
  }
}

object RelationMappingSystemTable extends SchemaEnumeration {
  val sparkName = Field("RELATION_NAME", StringType, nullable = false)
  val providerName = Field("SQL_NAME", StringType, nullable = true)
} 
Example 99
Source File: CaseSensitivityUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.analysis.{Analyzer, Catalog}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.util.CollectionUtils._

import scala.util.{Failure, Success, Try}

object CaseSensitivityUtils { // enclosing object reconstructed; its other members are omitted in this excerpt
  case class DuplicateFieldsException(
      originalSchema: StructType,
      schema: StructType,
      duplicateFields: Set[String])
    extends RuntimeException(
      s"""Given schema contains duplicate fields after applying case sensitivity rules:
         |${duplicateFields.mkString(", ")}
         |Given schema:
         |$originalSchema
         |After applying case sensitivity rules:
         |$schema
       """.stripMargin)
} 
Example 100
Source File: SQLRunner.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.cli

import java.io._

import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.{Logging, SparkContext}

import scala.annotation.tailrec

protected[cli] case class CLIOptions(
    sqlFiles: List[String] = Nil, output: Option[String] = None)


  def main(args: Array[String]): Unit = {
    def fail(msg: String = USAGE): Unit = {
      logError(msg)
      System.exit(1)
    }

    val opts = parseOpts(args.toList)

    val outputStream: OutputStream = opts.output match {
      case Some(filename) => new FileOutputStream(new File(filename))
      case None => System.out
    }

    opts.sqlFiles
      .map((string: String) => new FileInputStream(new File(string)))
      .foreach(sql(_, outputStream))
  }
} 
Example 101
Source File: MockedDefaultSourceSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.util.concurrent.{Callable, Executors}

import com.sap.spark.dsmock.DefaultSource
import org.apache.spark.sql.sources.HashPartitioningFunction
import org.apache.spark.sql.{GlobalSapSQLContext, Row, SQLContext}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.FunSuite

import scala.concurrent.duration._


class MockedDefaultSourceSuite
  extends FunSuite
  with GlobalSapSQLContext {

  val testTimeout = 10 // seconds

  private def numberOfThreads: Int = {
    val noOfCores = Runtime.getRuntime.availableProcessors()
    assert(noOfCores > 0)

    if (noOfCores == 1) 2 // It should always be multithreaded although only
                          // one processor is available (pseudo-multithreading)
    else noOfCores
  }

  def runMultiThreaded[A](op: Int => A): Seq[A] = {
    info(s"Running with $numberOfThreads threads")
    val pool = Executors.newFixedThreadPool(numberOfThreads)

    val futures = 1 to numberOfThreads map { i =>
      val task = new Callable[A] {
        override def call(): A = op(i)
      }
      pool.submit(task)
    }

    futures.map(_.get(testTimeout, SECONDS))
  }

  test("Underlying mocks of multiple threads are distinct") {
    val dataSources = runMultiThreaded { _ =>
      DefaultSource.withMock(identity)
    }

    dataSources foreach { current =>
      val sourcesWithoutCurrent = dataSources.filter(_.ne(current))
      assert(sourcesWithoutCurrent.forall(_.underlying ne current))
    }
  }

  test("Mocking works as expected") {
    runMultiThreaded { i =>
      DefaultSource.withMock { defaultSource =>
        when(defaultSource.getAllPartitioningFunctions(
          anyObject[SQLContext],
          anyObject[Map[String, String]]))
          .thenReturn(Seq(HashPartitioningFunction(s"foo$i", Seq.empty, None)))

        val Array(Row(name)) = sqlc
          .sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dsmock")
          .select("name")
          .collect()

        assertResult(s"foo$i")(name)
      }
    }
  }
} 
Example 102
Source File: ERPCurrencyConversionTestUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.currency.erp

import org.apache.spark.sql.SQLContext

object ERPCurrencyConversionTestUtils {
  private val DecimalOne = java.math.BigDecimal.ONE
  private val DecimalRateA = new java.math.BigDecimal("0.7")
  private val DecimalRateB = new java.math.BigDecimal("1.2")

  // Declaration reconstructed for readability; the method name and parameter names are
  // assumptions based on the body below and on dropERPTables.
  def createERPTables(sqlContext: SQLContext,
                      tables: Map[String, String],
                      parallelism: Int): Unit = {
    val tcurx = List(("USD", 2), ("EUR", 2))
    val tcurv = List(("000", "M", "1", "", "0", "", "", "0", "0"))
    val tcurf = List(("000", "M", "USD", "EUR", "79839898", DecimalOne, DecimalOne, "", ""),
                     ("000", "M", "EUR", "USD", "79839898", DecimalOne, DecimalOne, "", ""))
    val tcurr = List(("000", "M", "USD", "EUR", "79839898", DecimalRateA, DecimalOne, DecimalOne),
                     ("000", "M", "EUR", "USD", "79839898", DecimalRateB, DecimalOne, DecimalOne))
    val tcurn = List(("000", "M", "USD", "EUR", "79839898", ""),
                     ("000", "M", "EUR", "USD", "79839898", ""))

    val tcurxRDD = sqlContext.sparkContext.parallelize(tcurx, parallelism)
    val tcurvRDD = sqlContext.sparkContext.parallelize(tcurv, parallelism)
    val tcurfRDD = sqlContext.sparkContext.parallelize(tcurf, parallelism)
    val tcurrRDD = sqlContext.sparkContext.parallelize(tcurr, parallelism)
    val tcurnRDD = sqlContext.sparkContext.parallelize(tcurn, parallelism)

    sqlContext.createDataFrame(tcurxRDD).registerTempTable(tables("tcurx"))
    sqlContext.createDataFrame(tcurvRDD).registerTempTable(tables("tcurv"))
    sqlContext.createDataFrame(tcurfRDD).registerTempTable(tables("tcurf"))
    sqlContext.createDataFrame(tcurrRDD).registerTempTable(tables("tcurr"))
    sqlContext.createDataFrame(tcurnRDD).registerTempTable(tables("tcurn"))
  }

  def dropERPTables(sqlContext: SQLContext, tables: Map[String, String]): Unit = {
    tables.foreach { case (tableID, tableName) =>
      sqlContext.sql(s"DROP TABLE IF EXISTS $tableName")
    }
  }

} 
Example 103
Source File: testCatalystRelations.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._

class DummyCatalystRelation(
                             override val schema: StructType,
                             @transient override val sqlContext: SQLContext)
  extends BaseRelation
  with CatalystSource
  with Serializable {

  @transient
  var isMultiplePartitionExecutionFunc: Seq[CatalystSource] => Boolean = (r) => false
  override def isMultiplePartitionExecution(relations: Seq[CatalystSource]): Boolean =
    isMultiplePartitionExecutionFunc(relations)

  @transient
  var supportsLogicalPlanFunc: LogicalPlan => Boolean = (plan) => true
  override def supportsLogicalPlan(plan: LogicalPlan): Boolean =
    supportsLogicalPlanFunc(plan)

  @transient
  var supportsExpressionFunc: Expression => Boolean = (expr) => true
  override def supportsExpression(expr: Expression): Boolean =
    supportsExpressionFunc(expr)

  @transient
  var logicalPlanToRDDFunc: LogicalPlan => RDD[Row] =
    (plan) => new LogicalPlanRDD(plan, sqlContext.sparkContext)
  override def logicalPlanToRDD(plan: LogicalPlan): RDD[Row] =
    logicalPlanToRDDFunc(plan)

} 
Example 104
Source File: UseAliasesForAggregationsInGroupingsSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class UseAliasesForAggregationsInGroupingsSuite extends FunSuite with MockitoSugar {

  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))
  }

  val lr1 = LogicalRelation(br1)
  val nameAtt = lr1.output.find(_.name == "name").get
  val ageAtt = lr1.output.find(_.name == "age").get

  test("replace functions in group by") {
    val avgExpr = avg(ageAtt)
    val avgAlias = avgExpr as 'avgAlias
    assertResult(
      lr1.groupBy(avgAlias.toAttribute)(avgAlias)
    )(UseAliasesForFunctionsInGroupings(
      lr1.groupBy(avgExpr)(avgAlias))
    )
    assertResult(
      lr1.select(ageAtt)
    )(UseAliasesForFunctionsInGroupings(
      lr1.select(ageAtt))
      )
    intercept[RuntimeException](
      UseAliasesForFunctionsInGroupings(Aggregate(Seq(avgExpr), Seq(ageAtt), lr1))
    )
  }

} 
Example 105
Source File: RemoveNestedAliasesSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import com.sap.spark.PlanTest
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class RemoveNestedAliasesSuite extends FunSuite with MockitoSugar with PlanTest {

  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]

    override def schema: StructType = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)
    ))
  }

  val lr1 = LogicalRelation(br1)
  val nameAtt = lr1.output.find(_.name == "name").get
  val ageAtt = lr1.output.find(_.name == "age").get

  test("Replace alias into aliases") {
    val avgExpr = avg(ageAtt)
    val avgAlias = avgExpr as 'avgAlias
    val aliasAlias = avgAlias as 'aliasAlias
    val aliasAliasAlias = aliasAlias as 'aliasAliasAlias
    val copiedAlias = Alias(avgExpr, aliasAlias.name)(
      exprId = aliasAlias.exprId
    )
    val copiedAlias2 = Alias(avgExpr, aliasAliasAlias.name)(
      exprId = aliasAliasAlias.exprId
    )

    assertResult(
      lr1.groupBy(avgAlias.toAttribute)(avgAlias)
    )(RemoveNestedAliases(lr1.groupBy(avgAlias.toAttribute)(avgAlias)))

    assertResult(
      lr1.groupBy(copiedAlias.toAttribute)(copiedAlias)
    )(RemoveNestedAliases(lr1.groupBy(aliasAlias.toAttribute)(aliasAlias)))

    assertResult(
      lr1.groupBy(copiedAlias2.toAttribute)(copiedAlias2)
    )(RemoveNestedAliases(lr1.groupBy(aliasAliasAlias.toAttribute)(aliasAliasAlias)))
  }

  test("Replace alias into expressions") {
    val ageAlias = ageAtt as 'ageAlias
    val avgExpr = avg(ageAlias) as 'avgAlias
    val correctedAvgExpr = avg(ageAtt) as 'avgAlias
    comparePlans(
      lr1.groupBy(correctedAvgExpr.toAttribute)(correctedAvgExpr),
      RemoveNestedAliases(lr1.groupBy(avgExpr.toAttribute)(avgExpr))
    )
  }

} 
Example 106
Source File: ResolveHierarchySuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, EqualTo}
import org.apache.spark.sql.catalyst.plans.logical.{AdjacencyListHierarchySpec, Hierarchy}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar

class ResolveHierarchySuite extends FunSuite with MockitoSugar {

  val br1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("id", IntegerType),
      StructField("parent", IntegerType)
    ))
  }

  val lr1 = LogicalRelation(br1)
  val idAtt = lr1.output.find(_.name == "id").get
  val parentAtt = lr1.output.find(_.name == "parent").get

  test("Check parenthood expression has no conflicting expression IDs and qualifiers") {
    val source = SimpleAnalyzer.execute(lr1.select('id, 'parent).subquery('u))
    assert(source.resolved)

    val hierarchy = Hierarchy(
      AdjacencyListHierarchySpec(source, "v",
        
        UnresolvedAttribute("u" :: "id" :: Nil) === UnresolvedAttribute("v" :: "id" :: Nil),
        Some('id.isNull), Nil),
      'node
    )

    val resolveHierarchy = ResolveHierarchy(SimpleAnalyzer)
    val resolveReferences = ResolveReferencesWithHierarchies(SimpleAnalyzer)

    val resolvedHierarchy = (0 to 10).foldLeft(hierarchy: Hierarchy) { (h, _) =>
      SimpleAnalyzer.ResolveReferences(
        resolveReferences(resolveHierarchy(h))
      ).asInstanceOf[Hierarchy]
    }

    assert(resolvedHierarchy.node.resolved)
    val resolvedSpec = resolvedHierarchy.spec.asInstanceOf[AdjacencyListHierarchySpec]
    assert(resolvedSpec.parenthoodExp.resolved)
    assert(resolvedSpec.startWhere.forall(_.resolved))
    assert(resolvedHierarchy.childrenResolved)
    assert(resolvedHierarchy.resolved)

    val parenthoodExpression = resolvedSpec.parenthoodExp.asInstanceOf[EqualTo]

    assertResult("u" :: Nil)(parenthoodExpression.left.asInstanceOf[Attribute].qualifiers)
    assertResult("v" :: Nil)(parenthoodExpression.right.asInstanceOf[Attribute].qualifiers)
    assert(parenthoodExpression.right.asInstanceOf[Attribute].exprId !=
      source.output.find(_.name == "id").get.exprId)
  }

} 
Example 107
Source File: ResolveAnnotationsSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans._


class ResolveAnnotationsSuite extends FunSuite with MockitoSugar {

  // scalastyle:off magic.number
  val annotatedRel1 = new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("id1.1", IntegerType, metadata =
        new MetadataBuilder().putLong("key1.1", 11L).build()),
      StructField("id1.2", IntegerType, metadata =
        new MetadataBuilder()
          .putLong("key1.2", 12L)
            .putLong("key1.3", 13).build()))
    )
  }
  val lr1 = LogicalRelation(annotatedRel1)
  val id11Att = lr1.output.find(_.name == "id1.1").get
  val id12Att = lr1.output.find(_.name == "id1.2").get

  val id11AnnotatedAtt = AnnotatedAttribute(id11Att)(
    Map("key1.1" -> Literal.create(100L, LongType), // override the old key
    "newkey" -> Literal.create(200L, LongType))) // define a new key

  val simpleAnnotatedSelect = lr1.select(id11AnnotatedAtt)
} 
Example 108
Source File: ResolveCountDistinctStarSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)
    ))
  })

  test("Count distinct star is resolved correctly") {
    val projection = persons.select(UnresolvedAlias(
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
                              .apply(stillNotCompletelyResolvedAggregate)
    inside(resolvedAggregate) {
      case Aggregate(Nil,
      ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect {
          case a:AttributeReference => a.name
        }.toSet == Set("name", "age"))
    }
    assert(resolvedAggregate.resolved)
  }
} 
Example 109
Source File: SqlContextConfigurationUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.sql.SQLContext

// Enclosing declaration reconstructed for readability; the trait name and the abstract
// sqlContext member are assumptions based on the file name and the body below.
trait SqlContextConfigurationUtils {
  def sqlContext: SQLContext
  def withConf[A](settings: Map[String, String])(a: => A): A = {
    val temps: Map[String, String] =
      settings.keys.map(key => key -> sqlContext.getConf(key))(scala.collection.breakOut)

    try {
      settings.foreach {
        case (key, value) => sqlContext.setConf(key, value)
      }
      a
    } finally {
      temps.foreach {
        case (key, value) => sqlContext.setConf(key, value)
      }
    }
  }
} 
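
withConf applies the given SQLContext settings, runs the block, and restores the previous values in a finally block. A REPL-style standalone variant of the same idea, taking the SQLContext explicitly rather than via a surrounding trait (the spark.sql.shuffle.partitions key is used purely as an example):

import org.apache.spark.sql.SQLContext

// Standalone variant of the helper above: apply `settings`, run `body`,
// then restore whatever values were set before.
def withTempConf[A](sqlContext: SQLContext, settings: Map[String, String])(body: => A): A = {
  val previous = settings.keys.map(key => key -> sqlContext.getConf(key)).toMap
  try {
    settings.foreach { case (key, value) => sqlContext.setConf(key, value) }
    body
  } finally {
    previous.foreach { case (key, value) => sqlContext.setConf(key, value) }
  }
}

// Example usage (key and value chosen only for illustration):
// withTempConf(sqlContext, Map("spark.sql.shuffle.partitions" -> "4")) {
//   sqlContext.sql("SELECT 1").collect()
// }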
Example 110
Source File: DummyRelationUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{ColumnName, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.types.{StructField, StructType}


  case class DummyCatalystSourceRelation(
      schema: StructType,
      isMultiplePartitionExecutionFunc: Option[Seq[CatalystSource] => Boolean] = None,
      supportsLogicalPlanFunc: Option[LogicalPlan => Boolean] = None,
      supportsExpressionFunc: Option[Expression => Boolean] = None,
      logicalPlanToRDDFunc: Option[LogicalPlan => RDD[Row]] = None)
     (@transient implicit val sqlContext: SQLContext)
    extends BaseRelation
    with CatalystSource {

    override def isMultiplePartitionExecution(relations: Seq[CatalystSource]): Boolean =
      isMultiplePartitionExecutionFunc.forall(_.apply(relations))

    override def supportsLogicalPlan(plan: LogicalPlan): Boolean =
      supportsLogicalPlanFunc.forall(_.apply(plan))

    override def supportsExpression(expr: Expression): Boolean =
      supportsExpressionFunc.forall(_.apply(expr))

    override def logicalPlanToRDD(plan: LogicalPlan): RDD[Row] =
      logicalPlanToRDDFunc.getOrElse(
        (plan: LogicalPlan) => new LogicalPlanRDD(plan, sqlContext.sparkContext)).apply(plan)
  }
} 
Example 111
Source File: PartitioningFunctionUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.Dissector


  def dropPartitioningFunction(name: String,
                               allowExisting: Boolean = false,
                               dataSource: String): Unit = {
    sqlc.sql(
      s"""
        |DROP PARTITION FUNCTION
        |${if (allowExisting) "IF EXISTS" else ""}
        |$name
        |USING $dataSource
      """.stripMargin
    )
  }

  private def intervalToSql(interval: Dissector) =
    s"${interval.productPrefix.toUpperCase()} ${interval.n}"
} 
Example 112
Source File: testRelations.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.dstest

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.commands.WithOrigin
import org.apache.spark.sql.sources.{BaseRelation, DropRelation, Table}
import org.apache.spark.sql.types._


case class DummyRelationWithTempFlag(
    sqlContext: SQLContext,
    tableName: Seq[String],
    schema: StructType,
    temporary: Boolean)
  extends BaseRelation
  with Table
  with DropRelation
  with WithOrigin {

  override val provider: String = "com.sap.spark.dstest"

  override def isTemporary: Boolean = temporary

  override def dropTable(): Unit = {}
}

case class DummyRelationWithoutTempFlag(
    sqlContext: SQLContext,
    schema: StructType)
  extends BaseRelation
  with DropRelation
  with Table
  with WithOrigin {

  override def isTemporary: Boolean = false

  override val provider: String = "com.sap.spark.dstest"

  override def dropTable(): Unit = {}
} 
Example 113
Source File: WithSQLContext.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark

import java.util.Locale

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.HiveContext
import org.scalatest.{BeforeAndAfterEach, Suite}

trait WithSQLContext extends BeforeAndAfterEach {
  self: Suite with WithSparkContext =>

  override def beforeEach(): Unit = {
    try {
      super.beforeEach()
      setUpSQLContext()
    } catch {
      case ex: Throwable =>
        tearDownSQLContext()
        throw ex
    }
  }

  override def afterEach(): Unit = {
    try {
      super.afterEach()
    } finally {
      tearDownSQLContext()
    }
  }

  implicit def sqlContext: SQLContext = _sqlContext
  def sqlc: SQLContext = sqlContext

  var _sqlContext: SQLContext = _

  protected def setUpSQLContext(): Unit =
    _sqlContext = SQLContext.getOrCreate(sc).newSession()


  protected def tearDownSQLContext(): Unit =
    _sqlContext = null

  protected def tableName(name: String): String =
    sqlc match {
      
      case _: HiveContext => name.toLowerCase(Locale.ENGLISH)
      case _ => name
    }

} 
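
Suites that need a fresh SQLContext per test mix this trait in next to WithSparkContext, the companion test-support trait named in the self type. A hedged sketch, assuming WithSparkContext supplies the SparkContext `sc`:

import org.scalatest.FunSuite

// Hypothetical suite: WithSparkContext (from the same test-support package) provides `sc`;
// WithSQLContext then hands each test a fresh session via `sqlContext` / `sqlc`.
class SelectOneSuite extends FunSuite with WithSparkContext with WithSQLContext {
  test("a fresh SQLContext session is available") {
    assert(sqlc.sql("SELECT 1").collect().head.getInt(0) === 1)
  }
}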
Example 114
Source File: TestUtils.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.util

import java.util.Locale

import scala.io.Source
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.{Row, SQLContext, SapSQLContext}
import org.apache.spark.sql.hive.SapHiveContext
import org.apache.spark.sql.sources.sql.SqlLikeRelation
import org.apache.spark.sql.sources.{BaseRelation, CatalystSource, Table}
import org.apache.spark.sql.types.StructType
import org.mockito.Matchers._
import org.mockito.Mockito._

import scala.tools.nsc.io.Directory
import scala.util.{Failure, Success}


  def parsePTestFile(fileName: String): List[(String, String, String)] = {
    val filePath = getFileFromClassPath(fileName)
    val fileContents = Source.fromFile(filePath).getLines
      .map(p => p.stripMargin.trim)
      .filter(p => !p.isEmpty && !p.startsWith("//")) // filter empty rows and comments
      .mkString("\n")
    val p = new PTestFileParser

    // strip semicolons from query and parsed
    p(fileContents) match {
      case Success(lines) =>
        lines.map {
          case (query, parsed, expect) =>
            (stripSemicolon(query).trim, stripSemicolon(parsed).trim, expect.trim)
        }
      case Failure(ex) => throw ex
    }
  }

  private def stripSemicolon(sql: String): String =
    if (sql.endsWith(";")) {
      sql.substring(0, sql.length-1)
    } else {
      sql
    }

  def withTempDirectory[A](f: Directory => A): A = {
    val dir = Directory.makeTemp()
    try {
      f(dir)
    } finally {
      dir.deleteIfExists()
    }
  }
} 
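
withTempDirectory hands the caller a scratch Directory and deletes it afterwards, whatever happens in between. A REPL-style usage sketch (the helper is repeated standalone so the sketch stands on its own; the file name is made up):

import java.io.{File, PrintWriter}
import scala.tools.nsc.io.Directory

// Same helper as above, shown standalone.
def withTempDirectory[A](f: Directory => A): A = {
  val dir = Directory.makeTemp()
  try f(dir) finally dir.deleteIfExists()
}

// Create a throwaway file inside the managed directory; the directory and the file
// are removed once the block returns.
val firstLine = withTempDirectory { dir =>
  val file = new File(dir.jfile, "example.txt") // hypothetical file name
  val writer = new PrintWriter(file)
  try writer.println("hello") finally writer.close()
  scala.io.Source.fromFile(file).getLines().next()
}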
Example 115
Source File: SQLRunnerSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package com.sap.spark.cli

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}

import org.apache.spark.SparkContext
import org.apache.spark.sql.{GlobalSapSQLContext, SQLContext}
import org.scalatest.{BeforeAndAfterEach, FunSuite, ShouldMatchers}



    // good call
    val goodOpts =
      SQLRunner.parseOpts(List("a.sql", "b.sql", "-o", "output.csv"))

    goodOpts.sqlFiles should be(List("a.sql", "b.sql"))
    goodOpts.output should be(Some("output.csv"))

    // bad call
    val badOpts = SQLRunner.parseOpts(List())

    badOpts.sqlFiles should be(List())
    badOpts.output should be(None)

    // ugly call
    val uglyOpts =
      SQLRunner.parseOpts(List("a.sql", "-o", "output.csv", "b.sql"))

    uglyOpts.sqlFiles should be(List("a.sql", "b.sql"))
    uglyOpts.output should be(Some("output.csv"))
  }

  def runSQLTest(input: String, expectedOutput: String): Unit = {
    val inputStream: InputStream = new ByteArrayInputStream(input.getBytes())
    val outputStream = new ByteArrayOutputStream()

    SQLRunner.sql(inputStream, outputStream)

    val output = outputStream.toString
    output should be(expectedOutput)
  }

  test("can run dummy query") {
    val input = "SELECT 1;"
    val output = "1\n"

    runSQLTest(input, output)
  }

  test("can run multiple dummy queries") {
    val input = """
        |SELECT 1;SELECT 2;
        |SELECT 3;
      """.stripMargin

    val output = "1\n2\n3\n"

    runSQLTest(input, output)
  }

  test("can run a basic example with tables") {
    val input = """
                  |SELECT * FROM DEMO_TABLE;
                  |SELECT * FROM DEMO_TABLE LIMIT 1;
                  |DROP TABLE DEMO_TABLE;
                """.stripMargin

    val output = "1,a\n2,b\n3,c\n1,a\n"

    runSQLTest(input, output)
  }

  test("can run an example with comments") {
    val input = """
                  |SELECT * FROM DEMO_TABLE; -- this is the first query
                  |SELECT * FROM DEMO_TABLE LIMIT 1;
                  |-- now let's drop a table
                  |DROP TABLE DEMO_TABLE;
                """.stripMargin

    val output = "1,a\n2,b\n3,c\n1,a\n"

    runSQLTest(input, output)
  }
} 
Example 116
Source File: Preparator.scala    From pio-template-sr   with Apache License 2.0 5 votes vote down vote up
package org.template.sr



import org.apache.predictionio.controller.PPreparator
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.sql.DataFrame
import org.apache.spark.ml.feature.StandardScalerModel
import org.apache.spark.sql.SQLContext
import org.apache.spark.mllib.linalg.Vectors

class PreparedData(
  val rows: DataFrame,
  val dsp: DataSourceParams,
  val ssModel: org.apache.spark.mllib.feature.StandardScalerModel
) extends Serializable

class Preparator
  extends PPreparator[TrainingData, PreparedData] {

  def prepare(sc: SparkContext, trainingData: TrainingData): PreparedData = {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    if (trainingData.dsp.useStandardScaler) {
      val training = trainingData.rows.map(x=>(x._1,x._2,Vectors.dense(x._3))).toDF("label", "censor", "features")
      val scaler = new StandardScaler().setInputCol("features").setOutputCol("scaledFeatures").setWithStd(trainingData.dsp.standardScalerWithStd).setWithMean(trainingData.dsp.standardScalerWithMean)
      val scalerModel = scaler.fit(training)
      val scaledData = scalerModel.transform(training)
      val s1 = scaledData.select("label","censor","scaledFeatures").withColumnRenamed("scaledFeatures","features")

      //Prepare old StandardScaler
      val oldScaler = new org.apache.spark.mllib.feature.StandardScaler(withMean = trainingData.dsp.standardScalerWithMean, withStd = trainingData.dsp.standardScalerWithStd)
      val oldSSModel = oldScaler.fit(trainingData.rows.map(x=>(Vectors.dense(x._3))))
            
      new PreparedData(rows = s1, dsp = trainingData.dsp, ssModel = oldSSModel)
    }
    else {
      new PreparedData(rows = trainingData.rows.map(x=>(x._1,x._2,Vectors.dense(x._3))).toDF("label", "censor", "features"), dsp = trainingData.dsp, ssModel = null)
    }
  }
} 
Example 117
Source File: AvroTransformer.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.avro

import com.memsql.spark.etl.api.{UserTransformConfig, Transformer, PhaseConfig}
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types.StructType

import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.io.DecoderFactory
import org.apache.avro.specific.SpecificDatumReader

// Takes DataFrames of byte arrays, where each row is a serialized Avro record.
// Returns DataFrames of deserialized data, where each field has its own column.
class AvroTransformer extends Transformer {
  var avroSchemaStr: String = null
  var sparkSqlSchema: StructType = null

  def AvroRDDToDataFrame(sqlContext: SQLContext, rdd: RDD[Row]): DataFrame = {

    val rowRDD: RDD[Row] = rdd.mapPartitions({ partition => {
      // Create per-partition copies of non-serializable objects
      val parser: Schema.Parser = new Schema.Parser()
      val avroSchema = parser.parse(avroSchemaStr)
      val reader = new SpecificDatumReader[GenericData.Record](avroSchema)

      partition.map({ rowOfBytes =>
        val bytes = rowOfBytes(0).asInstanceOf[Array[Byte]]
        val decoder = DecoderFactory.get().binaryDecoder(bytes, null)
        val record = reader.read(null, decoder)
        val avroToRow = new AvroToRow()

        avroToRow.getRow(record)
      })
    }})
    sqlContext.createDataFrame(rowRDD, sparkSqlSchema)
  }

  override def initialize(sqlContext: SQLContext, config: PhaseConfig, logger: PhaseLogger): Unit = {
    val userConfig = config.asInstanceOf[UserTransformConfig]

    val avroSchemaJson = userConfig.getConfigJsValue("avroSchema") match {
      case Some(s) => s
      case None => throw new IllegalArgumentException("avroSchema must be set in the config")
    }
    avroSchemaStr = avroSchemaJson.toString

    val parser = new Schema.Parser()
    val avroSchema = parser.parse(avroSchemaJson.toString)
    sparkSqlSchema = AvroToSchema.getSchema(avroSchema)
  }

  override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = {
    AvroRDDToDataFrame(sqlContext, df.rdd)
  }
} 
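
AvroTransformer parses the schema once per partition and decodes each byte-array row back into a GenericData.Record, the mirror image of what the extractor in the next example writes. A minimal round trip over one record, with a made-up single-field schema, shows the writer/reader pairing both sides rely on:

import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericDatumWriter}
import org.apache.avro.io.{DecoderFactory, EncoderFactory}
import org.apache.avro.specific.SpecificDatumReader

object AvroRoundTripDemo {
  // Hypothetical schema with a single string field, used only for illustration.
  val schemaJson =
    """{"type": "record", "name": "Example", "fields": [{"name": "word", "type": "string"}]}"""

  def main(args: Array[String]): Unit = {
    val schema = new Schema.Parser().parse(schemaJson)

    // Serialize one record to bytes, as the extractor side does.
    val record = new GenericData.Record(schema)
    record.put("word", "hello")
    val out = new ByteArrayOutputStream()
    val encoder = EncoderFactory.get().binaryEncoder(out, null)
    new GenericDatumWriter[GenericData.Record](schema).write(record, encoder)
    encoder.flush()

    // Deserialize the bytes again, as AvroRDDToDataFrame does per row.
    val decoder = DecoderFactory.get().binaryDecoder(out.toByteArray, null)
    val readBack = new SpecificDatumReader[GenericData.Record](schema).read(null, decoder)
    println(readBack.get("word")) // prints "hello" (as an Avro Utf8)
  }
}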
Example 118
Source File: AvroRandomExtractor.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.avro

import com.memsql.spark.etl.api._
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types._
import org.apache.avro.Schema
import org.apache.avro.generic.GenericData
import org.apache.avro.io.{DatumWriter, EncoderFactory}
import org.apache.avro.specific.SpecificDatumWriter

import java.io.ByteArrayOutputStream

// Generates an RDD of byte arrays, where each is a serialized Avro record.
class AvroRandomExtractor extends Extractor {
  var count: Int = 1
  var generator: AvroRandomGenerator = null
  var writer: DatumWriter[GenericData.Record] = null
  var avroSchema: Schema = null
  
  def schema: StructType = StructType(StructField("bytes", BinaryType, false) :: Nil)

  val parser: Schema.Parser = new Schema.Parser()

  override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = {
    val userConfig = config.asInstanceOf[UserExtractConfig]
    val avroSchemaJson = userConfig.getConfigJsValue("avroSchema") match {
      case Some(s) => s
      case None => throw new IllegalArgumentException("avroSchema must be set in the config")
    }
    count = userConfig.getConfigInt("count").getOrElse(1)
    avroSchema = parser.parse(avroSchemaJson.toString)

    writer = new SpecificDatumWriter(avroSchema)
    generator = new AvroRandomGenerator(avroSchema)
  }

  override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Option[DataFrame] = {
    val rdd = sqlContext.sparkContext.parallelize((1 to count).map(_ => Row({
      val out = new ByteArrayOutputStream
      val encoder = EncoderFactory.get().binaryEncoder(out, null)
      val avroRecord: GenericData.Record = generator.next().asInstanceOf[GenericData.Record]

      writer.write(avroRecord, encoder)
      encoder.flush
      out.close
      out.toByteArray
    })))

    Some(sqlContext.createDataFrame(rdd, schema))
  }
} 
Example 119
Source File: AvroExtractorSpec.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.avro

import com.memsql.spark.etl.api.UserExtractConfig
import org.apache.spark.streaming._
import org.apache.spark.sql.SQLContext
import test.util.{Fixtures, UnitSpec, TestLogger, LocalSparkContext}
import spray.json._

class ExtractorsSpec extends UnitSpec with LocalSparkContext {
  var ssc: StreamingContext = _
  var sqlContext: SQLContext = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    ssc = new StreamingContext(sc, Seconds(1))
    sqlContext = new SQLContext(sc)
  }

  val avroConfig = Fixtures.avroConfig.parseJson
  val extractConfig = UserExtractConfig(class_name = "Test", value = avroConfig)
  val logger = new TestLogger("test")

  "AvroRandomExtractor" should "emit a random DF" in {
    val extract = new AvroRandomExtractor
    extract.initialize(ssc, sqlContext, extractConfig, 1, logger)

    val maybeDf = extract.next(ssc, 1, sqlContext, extractConfig, 1, logger)
    assert(maybeDf.isDefined)
    assert(maybeDf.get.count == 5)
  }
} 
Example 120
Source File: ThriftTransformer.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.thrift

import com.memsql.spark.etl.api._
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types._
import org.apache.thrift.{TBase, TDeserializer, TFieldIdEnum}

class ThriftTransformer extends Transformer {
  private var classObj: Class[_] = null
  private var thriftToRow: ThriftToRow = null
  private var deserializer: TDeserializer = null
  private var schema: StructType = null

  def thriftRDDToDataFrame(sqlContext: SQLContext, rdd: RDD[Row]): DataFrame = {
    val rowRDD: RDD[Row] = rdd.map({ record =>
      val recordAsBytes = record(0).asInstanceOf[Array[Byte]]
      val i = classObj.newInstance().asInstanceOf[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]]
      deserializer.deserialize(i, recordAsBytes)
      thriftToRow.getRow(i)
    })
    sqlContext.createDataFrame(rowRDD, schema)
  }

  override def initialize(sqlContext: SQLContext, config: PhaseConfig, logger: PhaseLogger): Unit = {
    val userConfig = config.asInstanceOf[UserTransformConfig]
    val className = userConfig.getConfigString("className") match {
      case Some(s) => s
      case None => throw new IllegalArgumentException("className must be set in the config")
    }

    classObj = Class.forName(className)
    thriftToRow = new ThriftToRow(classObj)
    deserializer = new TDeserializer()

    schema = ThriftToSchema.getSchema(classObj)
  }

  override def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = {
    thriftRDDToDataFrame(sqlContext, df.rdd)
  }
} 
Example 121
Source File: ThriftRandomExtractor.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.thrift

import com.memsql.spark.etl.api._
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.SparkContext
import org.apache.spark.sql.{SQLContext, DataFrame, Row}
import org.apache.spark.sql.types._
import org.apache.spark.streaming.StreamingContext
import org.apache.thrift.protocol.TBinaryProtocol
import org.apache.thrift.{TBase, TFieldIdEnum, TSerializer}

class ThriftRandomExtractor extends Extractor {
  var count: Int = 1
  var thriftType: Class[_] = null
  var serializer: TSerializer = null

  def schema: StructType = StructType(StructField("bytes", BinaryType, false) :: Nil)

  override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = {
    val userConfig = config.asInstanceOf[UserExtractConfig]
    val className = userConfig.getConfigString("className") match {
      case Some(s) => s
      case None => throw new IllegalArgumentException("className must be set in the config")
    }
    thriftType = Class.forName(className)
    serializer = new TSerializer(new TBinaryProtocol.Factory())
    count = userConfig.getConfigInt("count").getOrElse(1)
  }

  override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Option[DataFrame] = {
    val rdd = sqlContext.sparkContext.parallelize((1 to count).map(_ => Row({
      val thriftObject = ThriftRandomGenerator.next(thriftType).asInstanceOf[TBase[_ <: TBase[_, _], _ <: TFieldIdEnum]]
      serializer.serialize(thriftObject)
    })))
    Some(sqlContext.createDataFrame(rdd, schema))
  }
} 
Example 122
Source File: CheckpointingKafkaExtractor.scala    From streamliner-examples   with Apache License 2.0 5 votes vote down vote up
package com.memsql.spark.examples.kafka

import com.memsql.spark.etl.api.{UserExtractConfig, PhaseConfig, ByteArrayExtractor}
import com.memsql.spark.etl.utils.PhaseLogger
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.StreamingContext

import kafka.serializer.{DefaultDecoder, StringDecoder}
import org.apache.spark.streaming.kafka.{CheckpointedDirectKafkaInputDStream, CheckpointedKafkaUtils}
import org.apache.spark.streaming.dstream.InputDStream


class CheckpointingKafkaExtractor extends ByteArrayExtractor {
  var CHECKPOINT_DATA_VERSION = 1

  var dstream: CheckpointedDirectKafkaInputDStream[String, Array[Byte], StringDecoder, DefaultDecoder, Array[Byte]] = null

  var zkQuorum: String = null
  var topic: String = null

  override def initialize(ssc: StreamingContext, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long, logger: PhaseLogger): Unit = {
    val kafkaConfig  = config.asInstanceOf[UserExtractConfig]
    zkQuorum = kafkaConfig.getConfigString("zk_quorum").getOrElse {
      throw new IllegalArgumentException("\"zk_quorum\" must be set in the config")
    }
    topic = kafkaConfig.getConfigString("topic").getOrElse {
      throw new IllegalArgumentException("\"topic\" must be set in the config")
    }
  }

  def extract(ssc: StreamingContext, extractConfig: PhaseConfig, batchDuration: Long, logger: PhaseLogger): InputDStream[Array[Byte]] = {
    val kafkaParams = Map[String, String](
      "memsql.zookeeper.connect" -> zkQuorum
    )
    val topics = Set(topic)

    dstream = CheckpointedKafkaUtils.createDirectStreamFromZookeeper[String, Array[Byte], StringDecoder, DefaultDecoder](
      ssc, kafkaParams, topics, batchDuration, lastCheckpoint)
    dstream
  }

  override def batchCheckpoint: Option[Map[String, Any]] = {
    dstream match {
      case null => None
      case default => {
        val currentOffsets = dstream.getCurrentOffsets.map { case (tp, offset) =>
          Map("topic" -> tp.topic, "partition" -> tp.partition, "offset" -> offset)
        }
        Some(Map("offsets" -> currentOffsets, "zookeeper" -> zkQuorum, "version" -> CHECKPOINT_DATA_VERSION))
      }
    }
  }

  override def batchRetry: Unit = {
    if (dstream.prevOffsets != null) {
      dstream.setCurrentOffsets(dstream.prevOffsets)
    }
  }
} 
Example 123
Source File: DefaultSource.scala    From spark-power-bi   with Apache License 2.0 5 votes vote down vote up
package com.granturing.spark.powerbi

import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import scala.concurrent._
import scala.concurrent.ExecutionContext.Implicits._
import scala.concurrent.duration.Duration

class DefaultSource extends CreatableRelationProvider with PowerBISink {

  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {

    val conf = ClientConf.fromSparkConf(sqlContext.sparkContext.getConf)
    implicit val client = new Client(conf)

    val dataset = parameters.getOrElse("dataset", sys.error("'dataset' must be specified"))
    val table = parameters.getOrElse("table", sys.error("'table' must be specified"))
    val batchSize = parameters.getOrElse("batchSize", conf.batchSize.toString).toInt
    val group = parameters.get("group")

    val step = for {
      groupId <- getGroupId(group)
      ds <- getOrCreateDataset(mode, groupId, dataset, table, data.schema)
    } yield (groupId, ds)

    val result = step map { case (groupId, ds) =>
      val fields = data.schema.fieldNames.zipWithIndex
      val _conf = conf
      val _token = Some(client.currentToken)
      val _table = table
      val _batchSize = batchSize

      val coalesced = data.rdd.partitions.size > _conf.maxPartitions match {
        case true => data.coalesce(_conf.maxPartitions)
        case false => data
      }

      coalesced foreachPartition { p =>
        val rows = p map { r =>
          fields map { case(name, index) => (name -> r(index)) } toMap
        } toSeq

        val _client = new Client(_conf, _token)

        val submit = rows.
          sliding(_batchSize, _batchSize).
          foldLeft(future()) { (fAccum, batch) =>
          fAccum flatMap { _ => _client.addRows(ds.id, _table, batch, groupId) } }

        submit.onComplete { _ => _client.shutdown() }

        Await.result(submit, _conf.timeout)
      }
    }

    result.onComplete { _ => client.shutdown() }

    Await.result(result, Duration.Inf)

    new BaseRelation {
      val sqlContext = data.sqlContext

      val schema = data.schema
    }
  }

} 
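
Since DefaultSource is a CreatableRelationProvider, it can be driven through the standard DataFrameWriter path. A hedged usage sketch relying only on the option keys parsed in createRelation above; the dataset and table names are placeholders, and "group" stays optional:

import org.apache.spark.sql.{DataFrame, SaveMode}

// Push any DataFrame to Power BI via the relation provider above.
def saveToPowerBI(df: DataFrame): Unit =
  df.write
    .format("com.granturing.spark.powerbi")
    .mode(SaveMode.Append)
    .option("dataset", "MyDataset")
    .option("table", "MyTable")
    // .option("group", "MyWorkspace")  // optional, as read by createRelation
    .save()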
Example 124
Source File: ExcelRelation.scala    From spark-hadoopoffice-ds   with Apache License 2.0 5 votes vote down vote up
package org.zuinnote.spark.office.excel

import scala.collection.JavaConversions._

import org.apache.spark.sql.sources.{ BaseRelation, TableScan }
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.SQLContext

import org.apache.spark.sql._
import org.apache.spark.rdd.RDD

import org.apache.hadoop.conf._
import org.apache.hadoop.mapreduce._

import org.apache.commons.logging.LogFactory
import org.apache.commons.logging.Log

import org.zuinnote.hadoop.office.format.common.dao._
import org.zuinnote.hadoop.office.format.mapreduce._

import org.zuinnote.spark.office.excel.util.ExcelFile


  override def buildScan: RDD[Row] = {
    // read ExcelRows
    val excelRowsRDD = ExcelFile.load(sqlContext, location, hadoopParams)
    // map to schema
    val schemaFields = schema.fields
    excelRowsRDD.flatMap(excelKeyValueTuple => {
      // map the Excel row data structure to a Spark SQL schema
      val rowArray = new Array[Any](excelKeyValueTuple._2.get.length)
      var i = 0;
      for (x <- excelKeyValueTuple._2.get) { // parse through the SpreadSheetCellDAO
        val spreadSheetCellDAOStructArray = new Array[String](schemaFields.length)
        val currentSpreadSheetCellDAO: Array[SpreadSheetCellDAO] = excelKeyValueTuple._2.get.asInstanceOf[Array[SpreadSheetCellDAO]]
        spreadSheetCellDAOStructArray(0) = currentSpreadSheetCellDAO(i).getFormattedValue
        spreadSheetCellDAOStructArray(1) = currentSpreadSheetCellDAO(i).getComment
        spreadSheetCellDAOStructArray(2) = currentSpreadSheetCellDAO(i).getFormula
        spreadSheetCellDAOStructArray(3) = currentSpreadSheetCellDAO(i).getAddress
        spreadSheetCellDAOStructArray(4) = currentSpreadSheetCellDAO(i).getSheetName
        // add row representing one Excel row
        rowArray(i) = spreadSheetCellDAOStructArray
        i += 1
      }
      Some(Row.fromSeq(rowArray))
    })

  }

} 
Example 125
Source File: HttpStreamSink.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.http

import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.sources.StreamSinkProvider
import org.apache.spark.sql.streaming.OutputMode

import Params.map2Params

class HttpStreamSinkProvider
		extends StreamSinkProvider with DataSourceRegister {
	def createSink(
		sqlContext: SQLContext,
		parameters: Map[String, String],
		partitionColumns: Seq[String],
		outputMode: OutputMode): Sink = {
		new HttpStreamSink(parameters.getRequiredString("httpServletUrl"),
			parameters.getRequiredString("topic"),
			parameters.getInt("maxPacketSize", 10 * 1024 * 1024));
	}

	def shortName(): String = "httpStream"
}

class HttpStreamSink(httpPostURL: String, topic: String, maxPacketSize: Int)
		extends Sink with Logging {
	val producer = HttpStreamClient.connect(httpPostURL);
	val RETRY_TIMES = 5;
	val SLEEP_TIME = 100;

	override def addBatch(batchId: Long, data: DataFrame) {
		//send data to the HTTP server
		var success = false;
		var retried = 0;
		while (!success && retried < RETRY_TIMES) {
			try {
				retried += 1;
				producer.sendDataFrame(topic, batchId, data, maxPacketSize);
				success = true;
			}
			catch {
				case e: Throwable ⇒ {
					success = false;
					super.logWarning(s"failed to send", e);
					if (retried < RETRY_TIMES) {
						val sleepTime = SLEEP_TIME * retried;
						super.logWarning(s"will retry to send after ${sleepTime}ms");
						Thread.sleep(sleepTime);
					}
					else {
						throw e;
					}
				}
			}
		}
	}
} 
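
addBatch retries the HTTP post up to RETRY_TIMES, sleeping a little longer after each failure. The same pattern extracted into a small reusable helper (the helper name is mine, not part of spark-http-stream):

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}

// Retry `op` up to `maxRetries` times, sleeping baseSleepMs * attempt between tries,
// mirroring the loop in HttpStreamSink.addBatch.
def retryWithBackoff[A](maxRetries: Int, baseSleepMs: Long)(op: => A): A = {
  @tailrec
  def attempt(n: Int): A = Try(op) match {
    case Success(result) => result
    case Failure(_) if n < maxRetries =>
      Thread.sleep(baseSleepMs * n)
      attempt(n + 1)
    case Failure(e) => throw e
  }
  attempt(1)
}

// e.g. retryWithBackoff(5, 100) { producer.sendDataFrame(topic, batchId, data, maxPacketSize) }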
Example 126
Source File: SparkLeapFrame.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package ml.combust.mleap.spark

import ml.combust.mleap.core.types.{StructField, StructType}
import ml.combust.mleap.runtime.frame.{FrameBuilder, Row, RowUtil}
import ml.combust.mleap.runtime.function.{Selector, UserDefinedFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql
import org.apache.spark.sql.mleap.TypeConverters
import org.apache.spark.sql.{DataFrame, SQLContext, types}

import scala.util.Try


case class SparkLeapFrame(schema: StructType,
                          dataset: RDD[Row],
                          sqlContext: SQLContext) extends FrameBuilder[SparkLeapFrame] {
  override def withColumn(output: String, inputs: Selector *)
                         (udf: UserDefinedFunction): Try[SparkLeapFrame] = {
    RowUtil.createRowSelectors(schema, inputs: _*)(udf).flatMap {
      rowSelectors =>
        val field = StructField(output, udf.outputTypes.head)

        schema.withField(field).map {
          schema2 =>
            val dataset2 = dataset.map {
              row => row.withValue(rowSelectors: _*)(udf)
            }
            copy(schema = schema2, dataset = dataset2)
        }
    }
  }

  override def withColumns(outputs: Seq[String], inputs: Selector*)
                          (udf: UserDefinedFunction): Try[SparkLeapFrame] = {
    RowUtil.createRowSelectors(schema, inputs: _*)(udf).flatMap {
      rowSelectors =>
        val fields = outputs.zip(udf.outputTypes).map {
          case (name, dt) => StructField(name, dt)
        }

        schema.withFields(fields).map {
          schema2 =>
            val dataset2 = dataset.map {
              row => row.withValues(rowSelectors: _*)(udf)
            }
            copy(schema = schema2, dataset = dataset2)
        }
    }
  }

  override def select(fieldNames: String *): Try[SparkLeapFrame] = {
    for(indices <- schema.indicesOf(fieldNames: _*);
      schema2 <- schema.selectIndices(indices: _*)) yield {
      val dataset2 = dataset.map(row => row.selectIndices(indices: _*))

      copy(schema = schema2, dataset = dataset2)
    }
  }

  override def drop(names: String*): Try[SparkLeapFrame] = {
    for(indices <- schema.indicesOf(names: _*);
        schema2 <- schema.dropIndices(indices: _*)) yield {
      val dataset2 = dataset.map(row => row.dropIndices(indices: _*))

      copy(schema = schema2, dataset = dataset2)
    }
  }

  override def filter(selectors: Selector*)
                     (udf: UserDefinedFunction): Try[SparkLeapFrame] = {
    RowUtil.createRowSelectors(schema, selectors: _*)(udf).map {
      rowSelectors =>
        val dataset2 = dataset.filter(row => row.shouldFilter(rowSelectors: _*)(udf))
        copy(schema = schema, dataset = dataset2)
    }
  }

  def toSpark: DataFrame = {
    val spec = schema.fields.map(TypeConverters.mleapToSparkConverter)
    val fields = spec.map(_._1)
    val converters = spec.map(_._2)
    val sparkSchema = new types.StructType(fields.toArray)
    val data = dataset.map {
      r =>
        val values = r.zip(converters).map {
          case (v, c) => c(v)
        }
        sql.Row(values.toSeq: _*)
    }

    sqlContext.createDataFrame(data, sparkSchema)
  }
} 
Example 127
Source File: package.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb

import org.apache.spark.sql.{DataFrame, Encoder, SQLContext, SparkSession}

import scala.reflect.ClassTag

package object spark {
  implicit class ProtoSQLContext(val sqlContext: SQLContext) extends AnyVal {
    def protoToDataFrame[T <: GeneratedMessage: Encoder](
        protoRdd: org.apache.spark.rdd.RDD[T]
    ) = {
      ProtoSQL.protoToDataFrame(sqlContext, protoRdd)
    }
  }

  implicit class ProtoRDD[T <: GeneratedMessage](
      val protoRdd: org.apache.spark.rdd.RDD[T]
  ) extends AnyVal {
    def toDataFrame(
        sqlContext: SQLContext
    )(implicit encoder: Encoder[T]): DataFrame = {
      ProtoSQL.protoToDataFrame(sqlContext, protoRdd)
    }

    def toDataFrame(
        sparkSession: SparkSession
    )(implicit encoder: Encoder[T]): DataFrame = {
      ProtoSQL.protoToDataFrame(sparkSession, protoRdd)
    }
  }
} 
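
With these implicit classes in scope, any RDD of ScalaPB-generated messages with an Encoder available converts straight to a DataFrame. A small generic helper showing the call site (the Encoder instance itself is assumed to come from sparksql-scalapb's implicits, whose exact import varies by version):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}
import scalapb.GeneratedMessage
import scalapb.spark._ // the implicit classes defined above

// For any ScalaPB message type with an Encoder in scope, convert an RDD of
// messages to a DataFrame via the ProtoRDD implicit class.
def messagesToDataFrame[T <: GeneratedMessage: Encoder](
    spark: SparkSession,
    rdd: RDD[T]): DataFrame =
  rdd.toDataFrame(spark)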
Example 128
Source File: RegressionDatagen.scala    From hivemall-spark   with Apache License 2.0 5 votes vote down vote up
package hivemall.tools

import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.hive.HivemallOps._
import org.apache.spark.sql.hive.HivemallUtils._
import org.apache.spark.sql.types._

object RegressionDatagen {

  
  def exec(sc: SQLContext,
           n_partitions: Int = 2,
           min_examples: Int = 1000,
           n_features: Int = 10,
           n_dims: Int = 200,
           seed: Int = 43,
           dense: Boolean = false,
           prob_one: Float = 0.6f,
           sort: Boolean = false,
           cl: Boolean = false): DataFrame = {

    require(n_partitions > 0, "Non-negative #n_partitions required.")
    require(min_examples > 0, "Non-negative #min_examples required.")
    require(n_features > 0, "Non-negative #n_features required.")
    require(n_dims > 0, "Non-negative #n_dims required.")

    // Calculate #examples to generate in each partition
    val n_examples = (min_examples + n_partitions - 1) / n_partitions

    val df = sc.createDataFrame(
        sc.sparkContext.parallelize((0 until n_partitions).map(Row(_)), n_partitions),
        StructType(
          StructField("data", IntegerType, true) ::
          Nil)
      )
    import sc.implicits._
    df.lr_datagen(
      s"-n_examples $n_examples -n_features $n_features -n_dims $n_dims -prob_one $prob_one"
        + (if (dense) " -dense" else "")
        + (if (sort) " -sort" else "")
        + (if (cl) " -cl" else ""))
      .select($"label".cast(DoubleType).as("label"), $"features")
  }
} 
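
exec needs only a SQLContext; every other parameter has a default. A usage sketch (local master and sizes chosen arbitrarily):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import hivemall.tools.RegressionDatagen

object RegressionDatagenDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("datagen-demo").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    // Generate roughly 1000 synthetic regression examples with 20 features each;
    // all other parameters keep the defaults shown above.
    val df = RegressionDatagen.exec(sqlContext, min_examples = 1000, n_features = 20)
    df.show(5)
    sc.stop()
  }
}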
Example 129
Source File: HivemallStreamingOps.scala    From hivemall-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import scala.reflect.ClassTag

import org.apache.spark.ml.feature.HmLabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, DataFrame, SQLContext}
import org.apache.spark.streaming.dstream.DStream

final class HivemallStreamingOps(ds: DStream[HmLabeledPoint]) {

  def predict[U: ClassTag](f: DataFrame => DataFrame)(implicit sqlContext: SQLContext)
      : DStream[Row] = {
    ds.transform[Row] { rdd: RDD[HmLabeledPoint] =>
      f(sqlContext.createDataFrame(rdd)).rdd
    }
  }
}

object HivemallStreamingOps {

  
  implicit def dataFrameToHivemallStreamingOps(ds: DStream[HmLabeledPoint])
      : HivemallStreamingOps = {
    new HivemallStreamingOps(ds)
  }
} 
Example 130
Source File: L8-1DataFrameAPI.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.desc
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object CdrDataframeApp {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: CdrDataframeApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF()

        cdrs.groupBy("countryCode").count().orderBy(desc("count")).show(5)
      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 131
Source File: L8-3-6-7DataFrameCreation.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.desc
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.json4s.native.Serialization.write
import org.json4s.DefaultFormats

object DataframeCreationApp {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: CdrDataframeApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        //val cdrs = sqlC.createDataFrame(seqToCdr(rdd))
        //val cdrs = sqlC.createDataFrame(seqToCdr(rdd).collect())
        //val cdrs = seqToCdr(rdd).toDF()
        val cdrsJson = seqToCdr(rdd).map(r => {
          implicit val formats = DefaultFormats
          write(r)
        })
        val cdrs = sqlC.read.json(cdrsJson)

        cdrs.groupBy("countryCode").count().orderBy(desc("count")).show(5)
      })

    ssc.start()
    ssc.awaitTermination()

  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 132
Source File: L8-29DataFrameExamplesJoin.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.json4s.DefaultFormats
import org.json4s.JDouble
import org.json4s.JObject
import org.json4s.jvalue2extractable
import org.json4s.jvalue2monadic
import org.json4s.native.JsonMethods.compact
import org.json4s.native.JsonMethods.parse
import org.json4s.native.JsonMethods.render
import org.json4s.string2JsonInput

object CdrDataframeExamples3App {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 5) {
      System.err.println(
        "Usage: CdrDataframeExamples3App <appname> <batchInterval> <hostname> <port> <gridJsonPath>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port, gridJsonPath) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._
    implicit val formats = DefaultFormats

    val gridFile = scala.io.Source.fromFile(gridJsonPath).mkString
    val gridGeo = (parse(gridFile) \ "features")
    val gridStr = gridGeo.children.map(r => {
      val c = (r \ "geometry" \ "coordinates").extract[List[List[List[Float]]]].flatten.flatten.map(r => JDouble(r))
      val l = List(("id", r \ "id"), ("x1", c(0)), ("y1", c(1)), ("x2", c(2)), ("y2", c(3)),
        ("x3", c(4)), ("y3", c(5)), ("x4", c(6)), ("y4", c(7)))
      compact(render(JObject(l)))
    })

    val gridDF = sqlC.read.json(ssc.sparkContext.makeRDD(gridStr))

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF()
        cdrs.join(gridDF, $"squareId" === $"id").show()
      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 133
Source File: L8-10-11UDF.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.io.Source
import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.json4s.jackson.JsonMethods.parse
import org.json4s.jvalue2extractable
import org.json4s.string2JsonInput

object CdrUDFApp {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: CdrUDFApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    def getCountryCodeMapping() = {
      implicit val formats = org.json4s.DefaultFormats
      parse(Source.fromURL("http://country.io/phone.json").mkString).extract[Map[String, String]].map(_.swap)
    }

    def getCountryNameMapping() = {
      implicit val formats = org.json4s.DefaultFormats
      parse(Source.fromURL("http://country.io/names.json").mkString).extract[Map[String, String]]
    }

    def getCountryName(mappingPhone: Map[String, String], mappingName: Map[String, String], code: Int) = {
      mappingName.getOrElse(mappingPhone.getOrElse(code.toString, "NotFound"), "NotFound")
    }

    val getCountryNamePartial = getCountryName(getCountryCodeMapping(), getCountryNameMapping(), _: Int)

    sqlC.udf.register("getCountryNamePartial", getCountryNamePartial)

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF()
        cdrs.registerTempTable("cdrs")

        sqlC.sql("SELECT getCountryNamePartial(countryCode) AS countryName, COUNT(countryCode) AS cCount FROM cdrs GROUP BY countryCode ORDER BY cCount DESC LIMIT 5").show()

      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }

} 
Example 134
Source File: L8-4DataFrameCreationSchema.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions.desc
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.StructType
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object DataframeCreationApp2 {

  def main(args: Array[String]) {
    if (args.length != 5) {
      System.err.println(
        "Usage: CdrDataframeApp2 <appname> <batchInterval> <hostname> <port> <schemaPath>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port, schemaFile) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)

    val schemaJson = scala.io.Source.fromFile(schemaFile).mkString
    val schema = DataType.fromJson(schemaJson).asInstanceOf[StructType]

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = sqlC.createDataFrame(rdd.map(c => Row(c: _*)), schema)
        
        cdrs.groupBy("countryCode").count().orderBy(desc("count")).show(5)
      })

    ssc.start()
    ssc.awaitTermination()

  }
} 
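For context, a hedged sketch (not from the book's code) of how the schema file read above can be produced: StructType.json emits the JSON form that DataType.fromJson parses back, so the CDR schema used by the other examples can be serialized once and passed as <schemaPath>. The output file name is illustrative.

import org.apache.spark.sql.types._

val cdrSchema = StructType(Seq(
  StructField("squareId", IntegerType),
  StructField("timeInterval", LongType),
  StructField("countryCode", IntegerType),
  StructField("smsInActivity", FloatType),
  StructField("smsOutActivity", FloatType),
  StructField("callInActivity", FloatType),
  StructField("callOutActivity", FloatType),
  StructField("internetTrafficActivity", FloatType)))

// Write the JSON representation to disk; pass this path as <schemaPath>.
new java.io.PrintWriter("cdr-schema.json") { write(cdrSchema.json); close() }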
Example 135
Source File: L8-14-27DataFrameExamples.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object CdrDataframeExamplesApp {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: CdrDataframeExamplesApp <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF()

        cdrs.select("squareId", "timeInterval", "countryCode").show()
        cdrs.select($"squareId", $"timeInterval", $"countryCode").show()
        cdrs.filter("squareId = 5").show()
        cdrs.drop("countryCode").show()
        cdrs.select($"squareId", $"timeInterval", $"countryCode").where($"squareId" === 5).show()
        cdrs.limit(5).show()
        cdrs.groupBy("squareId").count().show()
        cdrs.groupBy("countryCode").avg("internetTrafficActivity").show()
        cdrs.groupBy("countryCode").max("callOutActivity").show()
        cdrs.groupBy("countryCode").min("callOutActivity").show()
        cdrs.groupBy("squareId").sum("internetTrafficActivity").show()
        cdrs.groupBy("squareId").agg(sum("callOutActivity"), sum("callInActivity"), sum("smsOutActivity"), sum("smsInActivity"), sum("internetTrafficActivity")).show()
        cdrs.groupBy("countryCode").sum("internetTrafficActivity").orderBy(desc("SUM(internetTrafficActivity)")).show()
        cdrs.agg(sum("callOutActivity"), sum("callInActivity"), sum("smsOutActivity"), sum("smsInActivity"), sum("internetTrafficActivity")).show()
        cdrs.rollup("squareId", "countryCode").count().orderBy(desc("squareId"), desc("countryCode")).rdd.saveAsTextFile("/tmp/rollup" + rdd.hashCode())
        cdrs.cube("squareId", "countryCode").count().orderBy(desc("squareId"), desc("countryCode")).rdd.saveAsTextFile("/tmp/cube" + rdd.hashCode())
        cdrs.dropDuplicates(Array("callOutActivity", "callInActivity")).show()
        cdrs.select("squareId", "countryCode", "internetTrafficActivity").distinct.show()
        cdrs.withColumn("endTime", cdrs("timeInterval") + 600000).show()
        cdrs.sample(true, 0.01).show()
      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 136
Source File: L8-28DataFrameExamplesOps.scala    From prosparkstreaming   with Apache License 2.0 5 votes vote down vote up
package org.apress.prospark

import scala.reflect.runtime.universe

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext

object CdrDataframeExamples2App {

  case class Cdr(squareId: Int, timeInterval: Long, countryCode: Int,
    smsInActivity: Float, smsOutActivity: Float, callInActivity: Float,
    callOutActivity: Float, internetTrafficActivity: Float)

  def main(args: Array[String]) {
    if (args.length != 4) {
      System.err.println(
        "Usage: CdrDataframeExamples2App <appname> <batchInterval> <hostname> <port>")
      System.exit(1)
    }
    val Seq(appName, batchInterval, hostname, port) = args.toSeq

    val conf = new SparkConf()
      .setAppName(appName)
      .setJars(SparkContext.jarOfClass(this.getClass).toSeq)

    val ssc = new StreamingContext(conf, Seconds(batchInterval.toInt))

    val sqlC = new SQLContext(ssc.sparkContext)
    import sqlC.implicits._

    var previousCdrs: Option[DataFrame] = None

    val cdrStream = ssc.socketTextStream(hostname, port.toInt)
      .map(_.split("\\t", -1))
      .foreachRDD(rdd => {
        val cdrs = seqToCdr(rdd).toDF().select("squareId", "countryCode").dropDuplicates()
        previousCdrs match {
          case Some(prevCdrs) => cdrs.unionAll(prevCdrs).show()
          //case Some(prevCdrs) => cdrs.intersect(prevCdrs).show()
          //case Some(prevCdrs) => cdrs.except(prevCdrs).show()
        case None => ()
        }
        previousCdrs = Some(cdrs)
      })

    ssc.start()
    ssc.awaitTermination()
  }

  def seqToCdr(rdd: RDD[Array[String]]): RDD[Cdr] = {
    rdd.map(c => c.map(f => f match {
      case x if x.isEmpty() => "0"
      case x => x
    })).map(c => Cdr(c(0).toInt, c(1).toLong, c(2).toInt, c(3).toFloat,
      c(4).toFloat, c(5).toFloat, c(6).toFloat, c(7).toFloat))
  }
} 
Example 137
Source File: HelloWorldDataSource.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package com.github.dnvriend.spark.datasources.helloworld

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{ BaseRelation, DataSourceRegister, RelationProvider, TableScan }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ Row, SQLContext }

class HelloWorldDataSource extends RelationProvider with DataSourceRegister with Serializable {
  override def shortName(): String = "helloworld"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: scala.Any): Boolean = other.isInstanceOf[HelloWorldDataSource]

  override def toString: String = "HelloWorldDataSource"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new HelloWorldRelationProvider(sqlContext, p, parameters)
      case _       => throw new IllegalArgumentException("Path is required for HelloWorld datasets")
    }
  }
}

class HelloWorldRelationProvider(val sqlContext: SQLContext, path: String, parameters: Map[String, String]) extends BaseRelation with TableScan {
  import sqlContext.implicits._

  override def schema: StructType = StructType(Array(
    StructField("key", StringType, nullable = false),
    StructField("value", StringType, nullable = true)
  ))

  override def buildScan(): RDD[Row] =
    Seq(
      "path" -> path,
      "message" -> parameters.getOrElse("message", ""),
      "name" -> s"Hello ${parameters.getOrElse("name", "")}",
      "hello_world" -> "Hello World!"
    ).toDF.rdd
} 
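A hedged usage sketch (not in the original repository): because the provider implements DataSourceRegister, it can be addressed by its short name when the jar registers it via META-INF services; otherwise the package name com.github.dnvriend.spark.datasources.helloworld works as the format. The option values and path are illustrative.

val helloDf = sqlContext.read
  .format("helloworld")
  .option("message", "from the data source")
  .option("name", "Spark")
  .load("/tmp/ignored-path")

// Expect the four key/value rows produced by buildScan above.
helloDf.show()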
Example 138
Source File: BigQueryReader.scala    From sope   with Apache License 2.0 5 votes vote down vote up
package com.sope.spark.utils.google

import com.google.cloud.hadoop.io.bigquery.{BigQueryConfiguration, GsonBigQueryInputFormat}
import com.google.gson.JsonObject
import com.sope.utils.Logging
import org.apache.hadoop.io.LongWritable
import org.apache.spark.sql.{DataFrame, SQLContext}


  def load(): DataFrame = {
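    // `sqlContext`, `sc` and `conf` are assumed to be members of the enclosing reader
    // class defined in BigQueryReader.scala, which this snippet omits.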
    import sqlContext.implicits._
    // Load data from BigQuery.
    val tableData = sc.newAPIHadoopRDD(
      conf,
      classOf[GsonBigQueryInputFormat],
      classOf[LongWritable],
      classOf[JsonObject])
      .map(_._2.toString)
    sqlContext.read.json(tableData.toDS)
  }
} 
Example 139
Source File: TestContext.scala    From sope   with Apache License 2.0 5 votes vote down vote up
package com.sope.etl

import org.apache.spark.sql.{SQLContext, SparkSession}


object TestContext {

  def getSQlContext: SQLContext = {
    System.setProperty(UDFRegistrationClassProperty, "com.sope.etl.custom.CustomUDF")
    System.setProperty(TransformationRegistrationClassProperty, "com.sope.etl.custom.CustomTransformation")
    SparkSession.builder()
      .master("local[*]")
      .appName("SopeUnitTest")
      .getOrCreate().sqlContext
  }
} 
Example 140
Source File: UDFRegistration.scala    From sope   with Apache License 2.0 5 votes vote down vote up
package com.sope.etl.register

import com.sope.etl.{SopeETLConfig, getClassInstance}
import com.sope.utils.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.expressions.UserDefinedFunction


  def registerCustomUDFs(sqlContext: SQLContext): Unit = {
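    // `logInfo`/`logError` come from the imported Logging trait; the enclosing
    // object/trait declaration is omitted from this snippet.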
    SopeETLConfig.UDFRegistrationConfig match {
      case Some(classStr) =>
        logInfo(s"Registering custom UDFs from $classStr")
        getClassInstance[UDFRegistration](classStr) match {
          case Some(udfClass) =>
            udfClass.performRegistration(sqlContext)
            logInfo("Successfully registered custom UDFs")
          case _ => logError(s"UDF Registration failed")
        }
      case None => logInfo("No class defined for registering Custom udfs")
    }
  }
} 
Example 141
Source File: TensorflowRelation.scala    From ecosystem   with Apache License 2.0 5 votes vote down vote up
package org.tensorflow.spark.datasources.tfrecords

import org.apache.hadoop.io.{BytesWritable, NullWritable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{Row, SQLContext, SparkSession}
import org.tensorflow.example.{SequenceExample, Example}
import org.tensorflow.hadoop.io.TFRecordFileInputFormat
import org.tensorflow.spark.datasources.tfrecords.serde.DefaultTfRecordRowDecoder


case class TensorflowRelation(options: Map[String, String], customSchema: Option[StructType]=None)
                             (@transient val session: SparkSession) extends BaseRelation with TableScan {

  //Import TFRecords as DataFrame happens here
  lazy val (tfRdd, tfSchema) = {
    val rdd = session.sparkContext.newAPIHadoopFile(options("path"), classOf[TFRecordFileInputFormat], classOf[BytesWritable], classOf[NullWritable])

    val recordType = options.getOrElse("recordType", "Example")

    recordType match {
      case "Example" =>
        val exampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
          Example.parseFrom(bytesWritable.getBytes)
        }
        val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(exampleRdd))
        val rowRdd = exampleRdd.map(example => DefaultTfRecordRowDecoder.decodeExample(example, finalSchema))
        (rowRdd, finalSchema)
      case "SequenceExample" =>
        val sequenceExampleRdd = rdd.map{case (bytesWritable, nullWritable) =>
          SequenceExample.parseFrom(bytesWritable.getBytes)
        }
        val finalSchema = customSchema.getOrElse(TensorFlowInferSchema(sequenceExampleRdd))
        val rowRdd = sequenceExampleRdd.map(example => DefaultTfRecordRowDecoder.decodeSequenceExample(example, finalSchema))
        (rowRdd, finalSchema)
      case _ =>
        throw new IllegalArgumentException(s"Unsupported recordType ${recordType}: recordType can be Example or SequenceExample")
    }
  }

  override def sqlContext: SQLContext = session.sqlContext

  override def schema: StructType = tfSchema

  override def buildScan(): RDD[Row] = tfRdd
} 
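A hedged usage sketch (paths and session name illustrative), assuming the connector exposes a DefaultSource in the package shown above so that the package name can be used as the format; recordType is the option the relation reads.

import org.apache.spark.sql.SparkSession

val session = SparkSession.builder().master("local[*]").appName("tfrecords-read").getOrCreate()
val tfDf = session.read
  .format("org.tensorflow.spark.datasources.tfrecords")
  .option("recordType", "Example")
  .load("hdfs:///data/part-00000.tfrecord")

tfDf.printSchema()
tfDf.show(5)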
Example 142
Source File: TestUtils.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import java.nio.ByteBuffer
import java.io.{IOException, File}
import java.util

import org.apache.avro.generic.GenericData

import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import com.google.common.io.Files
import org.apache.spark.sql.SQLContext

object TestUtils {

  def generateRandomByteBuffer(rand: Random, size: Int): ByteBuffer = {
    val bb = ByteBuffer.allocate(size)
    val arrayOfBytes = new Array[Byte](size)
    rand.nextBytes(arrayOfBytes)
    bb.put(arrayOfBytes)
  }

  def generateRandomMap(rand: Random, size: Int): java.util.Map[String, Int] = {
    val jMap = new util.HashMap[String, Int]()
    for (i <- 0 until size) {
      jMap.put(rand.nextString(5), i)
    }
    jMap
  }

  def generateRandomArray(rand: Random, size: Int): util.ArrayList[Boolean] = {
    val vec = new util.ArrayList[Boolean]()
    for (i <- 0 until size) {
      vec.add(rand.nextBoolean())
    }
    vec
  }
} 
Example 143
Source File: CurrentPersistenceIdsQuerySourceProvider.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.{ StringType, StructField, StructType }
import org.apache.spark.sql.{ SQLContext, _ }

object CurrentPersistenceIdsQuerySourceProvider {
  val name = "current-persistence-id"
  val schema: StructType = StructType(Array(
    StructField("persistence_id", StringType, nullable = false)
  ))
}

class CurrentPersistenceIdsQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) = {
    CurrentPersistenceIdsQuerySourceProvider.name -> CurrentPersistenceIdsQuerySourceProvider.schema
  }

  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {
    new CurrentPersistenceIdsQuerySourceImpl(sqlContext, parameters("path"))
  }
  override def shortName(): String = CurrentPersistenceIdsQuerySourceProvider.name
}

class CurrentPersistenceIdsQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String) extends Source with ReadJournalSource {
  override def schema: StructType = CurrentPersistenceIdsQuerySourceProvider.schema

  override def getOffset: Option[Offset] = {
    val offset = maxPersistenceIds
    println("[CurrentPersistenceIdsQuery]: Returning maximum offset: " + offset)
    Some(LongOffset(offset))
  }

  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    val (start, end) = getStartEnd(_start, _end)
    println(s"[CurrentPersistenceIdsQuery]: Getting currentPersistenceIds from start: $start, end: $end")
    import sqlContext.implicits._
    persistenceIds(start, end).toDF()
  }
} 
Example 144
Source File: CurrentEventsByPersistenceIdQuerySourceProvider.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.jdbc.spark.sql.execution.streaming

import org.apache.spark.sql.execution.streaming.{ LongOffset, Offset, Source }
import org.apache.spark.sql.sources.{ DataSourceRegister, StreamSourceProvider }
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{ SQLContext, _ }

object CurrentEventsByPersistenceIdQuerySourceProvider {
  val name = "current-events-by-persistence-id"
}

class CurrentEventsByPersistenceIdQuerySourceProvider extends StreamSourceProvider with DataSourceRegister with Serializable {
  override def sourceSchema(
    sqlContext: SQLContext,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): (String, StructType) = {
    println(s"[CurrentEventsByPersistenceIdQuerySourceProvider.sourceSchema]: schema: $schema, providerName: $providerName, parameters: $parameters")
    CurrentEventsByPersistenceIdQuerySourceProvider.name -> schema.get
  }

  override def createSource(
    sqlContext: SQLContext,
    metadataPath: String,
    schema: Option[StructType],
    providerName: String,
    parameters: Map[String, String]
  ): Source = {

    val eventMapperFQCN: String = parameters.get("event-mapper") match {
      case Some(_eventMapper) => _eventMapper
      case _                  => throw new RuntimeException("No event mapper FQCN")
    }

    val pid = (parameters.get("pid"), parameters.get("persistence-id")) match {
      case (Some(pid), _) => pid
      case (_, Some(pid)) => pid
      case _              => throw new RuntimeException("No persistence_id")
    }

    new CurrentEventsByPersistenceIdQuerySourceImpl(sqlContext, parameters("path"), eventMapperFQCN, pid, schema.get)
  }
  override def shortName(): String = CurrentEventsByPersistenceIdQuerySourceProvider.name
}

class CurrentEventsByPersistenceIdQuerySourceImpl(val sqlContext: SQLContext, val readJournalPluginId: String, eventMapperFQCN: String, persistenceId: String, override val schema: StructType) extends Source with ReadJournalSource {
  override def getOffset: Option[Offset] = {
    val offset = maxEventsByPersistenceId(persistenceId)
    println("[CurrentEventsByPersistenceIdQuery]: Returning maximum offset: " + offset)
    Some(LongOffset(offset))
  }

  override def getBatch(_start: Option[Offset], _end: Offset): DataFrame = {
    val (start, end) = getStartEnd(_start, _end)
    val df: DataFrame = eventsByPersistenceId(persistenceId, start, end, eventMapperFQCN)
    println(s"[CurrentEventsByPersistenceIdQuery]: Getting currentPersistenceIds from start: $start, end: $end, DataFrame.count: ${df.count}")
    df
  }
} 
Example 145
Source File: PersonEventMapper.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package com.github.dnvriend.spark.mapper

import akka.persistence.jdbc.spark.sql.execution.streaming.EventMapper
import akka.persistence.query.EventEnvelope
import com.github.dnvriend.spark.datasources.person.Person
import org.apache.spark.sql.{ Row, SQLContext }
import org.apache.spark.sql.types._

class PersonEventMapper extends EventMapper {
  override def row(envelope: EventEnvelope, sqlContext: SQLContext): Row = envelope match {
    case EventEnvelope(offset, persistenceId, sequenceNr, Person(id, name, age)) =>
      Row(offset, persistenceId, sequenceNr, id, name, age)
  }

  override def schema: StructType =
    PersonEventMapper.schema
}

object PersonEventMapper {
  val schema = StructType(Array(
    StructField("offset", LongType, nullable = false),
    StructField("persistence_id", StringType, nullable = false),
    StructField("sequence_number", LongType, nullable = false),
    StructField("id", LongType, nullable = false),
    StructField("name", StringType, nullable = true),
    StructField("age", IntegerType, nullable = true)
  ))
} 
Example 146
Source File: PersonDataSource.scala    From apache-spark-test   with Apache License 2.0 5 votes vote down vote up
package com.github.dnvriend.spark.datasources.person

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{ Row, SQLContext }

final case class Person(id: Long, name: String, age: Int)

class PersonDataSource extends RelationProvider with DataSourceRegister with Serializable {
  override def shortName(): String = "person"

  override def hashCode(): Int = getClass.hashCode()

  override def equals(other: scala.Any): Boolean = other.isInstanceOf[PersonDataSource]

  override def toString: String = "PersonDataSource"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new PersonRelationProvider(sqlContext, p)
      case _       => throw new IllegalArgumentException("Path is required for Person datasets")
    }
  }
}

object PersonRelationProvider {
  val regex = """(id="[\d]+)|(name="[\s\w]+)|(age="[\d]+)""".r
  val schema = StructType(Array(
    StructField("id", LongType, nullable = false),
    StructField("name", StringType, nullable = true),
    StructField("age", IntegerType, nullable = true)
  ))
}

class PersonRelationProvider(val sqlContext: SQLContext, path: String) extends BaseRelation with TableScan with Serializable {
  override def schema: StructType = PersonRelationProvider.schema

  override def buildScan(): RDD[Row] =
    sqlContext.sparkContext.textFile(path)
      .filter(_.contains("person"))
      .map(line => PersonRelationProvider.regex.findAllIn(line).toList)
      .map { xs =>
        val id = xs.head.replace("id=\"", "").toLong
        val name = xs.drop(1).map(str => str.replace("name=\"", "")).headOption
        val age = xs.drop(2).map(str => str.replace("age=\"", "")).headOption.map(_.toInt)
        Row(id, name, age)
      }
} 
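A hedged usage sketch (file path illustrative): the relation above keeps lines containing "person" and extracts the id/name/age attributes, so it can be read through its short name where the data source is registered, or through the package name com.github.dnvriend.spark.datasources.person otherwise.

val people = sqlContext.read
  .format("person")
  .load("src/test/resources/people.log")

people.filter("age > 21").select("name").show()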
Example 147
Source File: VectorTempTable.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.loader.command

import org.apache.spark.sql.SQLContext

import com.actian.spark_vector.loader.options.UserOptions
import com.actian.spark_vector.sql.{ SparkSqlTable, TableRef, TempTable, VectorRelation }

object VectorTempTable {
  private def generateSQLOption(key: String, queries: Seq[String]): Seq[(String, String)] = for { i <- 0 until queries.size } yield {
    (s"${key}${i}", s"${queries(i)}")
  }

  private def parseOptions(config: UserOptions): Map[String, String] = {
    val base = Seq("host" -> config.vector.host,
      "instance" -> config.vector.instance,
      "database" -> config.vector.database,
      "port" -> config.vector.port,
      "table" -> config.vector.targetTable)
    val optional = Seq(config.vector.user.map("user" -> _),
      config.vector.password.map("password" -> _)).flatten ++
      config.vector.preSQL.map(generateSQLOption("loadpresql", _)).getOrElse(Nil) ++
      config.vector.postSQL.map(generateSQLOption("loadpostsql", _)).getOrElse(Nil)
    (base ++ optional).toMap
  }

  
  def register(config: UserOptions, sqlContext: SQLContext): SparkSqlTable = {
    val params = parseOptions(config)
    val tableName = params("table")
    val df = sqlContext.baseRelationToDataFrame(VectorRelation(TableRef(params), sqlContext, params))
    TempTable(tableName, df)
  }
} 
Example 148
Source File: SparkOperationTestPimpers.scala    From sparkplug   with MIT License 5 votes vote down vote up
package springnz.sparkplug.testkit

import com.typesafe.scalalogging.{ LazyLogging, Logger }
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{ DataFrame, SQLContext }
import springnz.sparkplug.core.SparkOperation
import springnz.sparkplug.util.Logging

import scala.reflect.ClassTag

object SparkOperationTestPimpers extends LazyLogging {

  private def persistTestResource[A: ClassTag](rdd: RDD[A], rddName: String, overwrite: Boolean = false)(
    implicit projectName: ProjectName): RDD[A] = {
    val path = RDDPersister.getPath(projectName.name, rddName)
    if (overwrite || (!overwrite && !path.exists)) {
      if (path.exists) {
        logger.info(s"deleting existing RDD at ${path.pathAsString}")
        path.delete()
      }
      RDDPersister.persistRDD(path.pathAsString, rdd)
    } else { // (!overwrite && path.exists)
      logger.info(s"Not persisting RDD that already exists at path [${path.pathAsString}]")
      rdd
    }
  }

  class RDDExtensions[A: ClassTag](operation: SparkOperation[RDD[A]]) {
    import RDDSamplers._

    def saveTo(rddName: String, sampler: RDD[A] ⇒ RDD[A] = identitySampler)(
      implicit projectName: ProjectName): SparkOperation[RDD[A]] =
      operation.map {
        rdd ⇒
          val sampled = sampler(rdd)
          persistTestResource(sampled, rddName, overwrite = false)
          sampled
      }

    def sourceFrom(rddName: String, sampler: RDD[A] ⇒ RDD[A] = identitySampler)(
      implicit projectName: ProjectName): SparkOperation[RDD[A]] =
      SparkOperation { ctx ⇒
        val path = RDDPersister.getPath(projectName.name, rddName)
        if (path.exists)
          ctx.objectFile[A](path.pathAsString)
        else {
          val rdd = operation.run(ctx)
          val sampled = sampler(rdd)
          persistTestResource(sampled, rddName, overwrite = false)
          sampled
        }
      }
  }

  class DataFrameExtensions(operation: SparkOperation[DataFrame]) {
    import RDDSamplers._

    def saveTo(rddName: String,
      overwrite: Boolean = false,
      sampler: RDD[String] ⇒ RDD[String] = identitySampler)(
        implicit projectName: ProjectName): SparkOperation[DataFrame] =
      operation.map {
        df ⇒
          val rdd: RDD[String] = df.toJSON
          val sampled = sampler(rdd)
          persistTestResource(sampled, rddName, overwrite)
          val sqlContext = new SQLContext(sampled.sparkContext)
          sqlContext.read.json(sampled)
      }

    def sourceFrom(dataFrameName: String,
      overwrite: Boolean = false,
      sampler: RDD[String] ⇒ RDD[String] = rdd ⇒ rdd)(
        implicit projectName: ProjectName, log: Logger): SparkOperation[DataFrame] =
      SparkOperation { ctx ⇒
        val path = RDDPersister.getPath(projectName.name, dataFrameName)
        val sampledRDD = if (path.exists)
          ctx.objectFile[String](path.pathAsString)
        else {
          val df = operation.run(ctx)
          val rdd: RDD[String] = df.toJSON
          val sampled = sampler(rdd)
          persistTestResource(sampled, dataFrameName, overwrite)
          sampled
        }
        val sqlContext = new SQLContext(ctx)
        sqlContext.read.json(sampledRDD)
      }

  }
} 
Example 149
Source File: BigFileDatasource.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql

import java.net.URI
import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

import io.projectglow.common.{GlowLogging, WithUtils}


  def write(rdd: RDD[Array[Byte]], path: String) {
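    // `uploaders` (presumably discovered via the imported ServiceLoader) and `logger`
    // (from GlowLogging) are assumed to be defined on the enclosing object, which this
    // snippet omits.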
    val uri = new URI(path)
    uploaders.find(_.canUpload(rdd.sparkContext.hadoopConfiguration, path)) match {
      case Some(uploader) => uploader.upload(rdd, path)
      case None =>
        logger.info(s"Could not find a parallel uploader for $path, uploading from the driver")
        writeFileFromDriver(new Path(uri), rdd)
    }
  }

  private def writeFileFromDriver(path: Path, byteRdd: RDD[Array[Byte]]): Unit = {
    val sc = byteRdd.sparkContext
    val fs = path.getFileSystem(sc.hadoopConfiguration)
    WithUtils.withCloseable(fs.create(path)) { stream =>
      WithUtils.withCachedRDD(byteRdd) { cachedRdd =>
        cachedRdd.count()
        cachedRdd.toLocalIterator.foreach { chunk =>
          stream.write(chunk)
        }
      }
    }
  }
} 
Example 150
Source File: MLlibTestSparkContext.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.ml.util

import java.io.File

import org.apache.spark.SparkContext
import org.apache.spark.sql.types.UDTRegistration
import org.apache.spark.sql.{SQLContext, SQLImplicits, SparkSession}
import org.apache.spark.util.{SparkUtil, Utils}
import org.scalatest.Suite

trait MLlibTestSparkContext extends TempDirectory { self: Suite =>
  @transient var spark: SparkSession = _
  @transient var sc: SparkContext = _
  @transient var checkpointDir: String = _

  override def beforeAll() {
    super.beforeAll()

    SparkUtil.UDTRegister("org.apache.spark.linalg.Vector", "org.apache.spark.linalg.VectorUDT")
    SparkUtil.UDTRegister("org.apache.spark.linalg.DenseVector", "org.apache.spark.linalg.VectorUDT")
    SparkUtil.UDTRegister("org.apache.spark.linalg.SparseVector", "org.apache.spark.linalg.VectorUDT")
    SparkUtil.UDTRegister("org.apache.spark.linalg.Matrix", "org.apache.spark.linalg.MatrixUDT")
    SparkUtil.UDTRegister("org.apache.spark.linalg.DenseMatrix", "org.apache.spark.linalg.MatrixUDT")
    SparkUtil.UDTRegister("org.apache.spark.linalg.SparseMatrix", "org.apache.spark.linalg.MatrixUDT")

    spark = SparkSession.builder
      .master("local[2]")
      .appName("MLlibUnitTest")
      .getOrCreate()
    sc = spark.sparkContext

    checkpointDir = SparkUtil.createDirectory(tempDir.getCanonicalPath, "checkpoints").toString
    sc.setCheckpointDir(checkpointDir)
  }

  override def afterAll() {
    try {
      SparkUtil.deleteRecursively(new File(checkpointDir))
      SparkSession.clearActiveSession()
      if (spark != null) {
        spark.stop()
      }
      spark = null
    } finally {
      super.afterAll()
    }
  }

  /**
   * A helper object for importing SQL implicits.
   *
   * Note that the alternative of importing `spark.implicits._` is not possible here.
   * This is because we create the `SQLContext` immediately before the first test is run,
   * but the implicits import is needed in the constructor.
   */
  protected object testImplicits extends SQLImplicits {
    protected override def _sqlContext: SQLContext = self.spark.sqlContext
  }
} 
Example 151
Source File: SharedSQLContext.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import scala.concurrent.duration._

import org.scalatest.BeforeAndAfterEach
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.sql.{SparkSession, SQLContext}


  protected override def afterAll(): Unit = {
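    // `_spark` is the shared SparkSession field of the enclosing SharedSQLContext
    // trait, whose declaration is omitted from this snippet.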
    super.afterAll()
    if (_spark != null) {
      _spark.sessionState.catalog.reset()
      _spark.stop()
      _spark = null
    }
  }

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    DebugFilesystem.clearOpenStreams()
  }

  protected override def afterEach(): Unit = {
    super.afterEach()
    // files can be closed from other threads, so wait a bit
    // normally this doesn't take more than 1s
    eventually(timeout(10.seconds)) {
      DebugFilesystem.assertNoOpenStreams()
    }
  }
} 
Example 152
Source File: DefaultSource.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import java.io.File

import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
  final val DEFAULT_CREDENTIAL_PATH = "/etc/gdata/credential.p12"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]) = {
    createRelation(sqlContext, parameters, null)
  }

  private[spreadsheets] def pathToSheetNames(parameters: Map[String, String]): (String, String) = {
    val path = parameters.getOrElse("path", sys.error("'path' must be specified for spreadsheets."))
    val elems = path.split('/')
    if (elems.length < 2)
      throw new Exception("'path' must be formed like '<spreadsheet>/<worksheet>'")

    (elems(0), elems(1))
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType) = {
    val (spreadsheetName, worksheetName) = pathToSheetNames(parameters)
    val context = createSpreadsheetContext(parameters)
    createRelation(sqlContext, context, spreadsheetName, worksheetName, schema)
  }


  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val (spreadsheetName, worksheetName) = pathToSheetNames(parameters)
    implicit val context = createSpreadsheetContext(parameters)
    val spreadsheet = SparkSpreadsheetService.findSpreadsheet(spreadsheetName)
    if(!spreadsheet.isDefined)
      throw new RuntimeException(s"no such a spreadsheet: $spreadsheetName")

    spreadsheet.get.addWorksheet(worksheetName, data.schema, data.collect().toList, Util.toRowData)
    createRelation(sqlContext, context, spreadsheetName, worksheetName, data.schema)
  }

  private[spreadsheets] def createSpreadsheetContext(parameters: Map[String, String]) = {
    val serviceAccountIdOption = parameters.get("serviceAccountId")
    val credentialPath = parameters.getOrElse("credentialPath", DEFAULT_CREDENTIAL_PATH)
    SparkSpreadsheetService(serviceAccountIdOption, new File(credentialPath))
  }

  private[spreadsheets] def createRelation(sqlContext: SQLContext,
                                           context: SparkSpreadsheetService.SparkSpreadsheetContext,
                                           spreadsheetName: String,
                                           worksheetName: String,
                                           schema: StructType): SpreadsheetRelation =
    if (schema == null) {
      createRelation(sqlContext, context, spreadsheetName, worksheetName, None)
    }
    else {
      createRelation(sqlContext, context, spreadsheetName, worksheetName, Some(schema))
    }

  private[spreadsheets] def createRelation(sqlContext: SQLContext,
                                           context: SparkSpreadsheetService.SparkSpreadsheetContext,
                                           spreadsheetName: String,
                                           worksheetName: String,
                                           schema: Option[StructType]): SpreadsheetRelation =
    SpreadsheetRelation(context, spreadsheetName, worksheetName, schema)(sqlContext)
} 
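A hedged usage sketch (spreadsheet, worksheet and credential values are illustrative): reading a worksheet through the '<spreadsheet>/<worksheet>' path convention enforced above, with the serviceAccountId and credentialPath options the provider consumes.

val sheetDf = sqlContext.read
  .format("com.github.potix2.spark.google.spreadsheets")
  .option("serviceAccountId", "my-service-account@developer.gserviceaccount.com")
  .option("credentialPath", "/etc/gdata/credential.p12")
  .load("MySpreadsheet/Sheet1")

sheetDf.show()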
Example 153
Source File: SpreadsheetRelation.scala    From spark-google-spreadsheets   with Apache License 2.0 5 votes vote down vote up
package com.github.potix2.spark.google.spreadsheets

import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService.SparkSpreadsheetContext
import com.github.potix2.spark.google.spreadsheets.util._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation, TableScan}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

case class SpreadsheetRelation protected[spark] (
                                                  context:SparkSpreadsheetContext,
                                                  spreadsheetName: String,
                                                  worksheetName: String,
                                                  userSchema: Option[StructType] = None)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan with InsertableRelation {

  import com.github.potix2.spark.google.spreadsheets.SparkSpreadsheetService._

  override def schema: StructType = userSchema.getOrElse(inferSchema())

  private lazy val aWorksheet: SparkWorksheet =
    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(aWorksheet) => aWorksheet
      case Left(e) => throw e
    }

  private lazy val rows: Seq[Map[String, String]] = aWorksheet.rows

  private[spreadsheets] def findWorksheet(spreadsheetName: String, worksheetName: String)(implicit ctx: SparkSpreadsheetContext): Either[Throwable, SparkWorksheet] =
    for {
      sheet <- findSpreadsheet(spreadsheetName).toRight(new RuntimeException(s"no such spreadsheet: $spreadsheetName")).right
      worksheet <- sheet.findWorksheet(worksheetName).toRight(new RuntimeException(s"no such worksheet: $worksheetName")).right
    } yield worksheet

  override def buildScan(): RDD[Row] = {
    val aSchema = schema
    sqlContext.sparkContext.makeRDD(rows).mapPartitions { iter =>
      iter.map { m =>
        var index = 0
        val rowArray = new Array[Any](aSchema.fields.length)
        while(index < aSchema.fields.length) {
          val field = aSchema.fields(index)
          rowArray(index) = if (m.contains(field.name)) {
            TypeCast.castTo(m(field.name), field.dataType, field.nullable)
          } else {
            null
          }
          index += 1
        }
        Row.fromSeq(rowArray)
      }
    }
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    if(!overwrite) {
      sys.error("Spreadsheet tables only support INSERT OVERWRITE for now.")
    }

    findWorksheet(spreadsheetName, worksheetName)(context) match {
      case Right(w) =>
        w.updateCells(data.schema, data.collect().toList, Util.toRowData)
      case Left(e) =>
        throw e
    }
  }

  private def inferSchema(): StructType =
    StructType(aWorksheet.headers.toList.map { fieldName =>
      StructField(fieldName, StringType, nullable = true)
    })

} 
Example 154
Source File: TestSparkContext.scala    From spark-images   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.image

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types._
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, DataFrame, SQLContext, SparkSession}

import scala.reflect.runtime.universe._
import org.scalatest.{FunSuite, BeforeAndAfterAll}

// This context is used for all tests in this project
trait TestSparkContext extends BeforeAndAfterAll { self: FunSuite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _
  @transient lazy val spark: SparkSession = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("Spark-Image-Test")
      .set("spark.ui.port", "4079")
      .set("spark.sql.shuffle.partitions", "4")  // makes small tests much faster

    val sess = SparkSession.builder().config(conf).getOrCreate()
    sess.sparkContext.setLogLevel("WARN")
    sess
  }

  override def beforeAll() {
    super.beforeAll()
    sc = spark.sparkContext
    sqlContext = spark.sqlContext
    import spark.implicits._
  }

  override def afterAll() {
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }

  def makeDF[T: TypeTag](xs: Seq[T], col: String): DataFrame = {
    sqlContext.createDataFrame(xs.map(Tuple1.apply)).toDF(col)
  }

  def compareRows(r1: Array[Row], r2: Seq[Row]): Unit = {
    val a = r1.sortBy(_.toString())
    val b = r2.sortBy(_.toString())
    assert(a === b)
  }
} 
Example 155
Source File: KuduSinkProvider.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode


class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    require(outputMode == OutputMode.Update, "only 'update' OutputMode is supported")
    KuduSink.withDefaultContext(sqlContext, parameters)
  }

  override def shortName(): String = "kudu"
} 
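A hedged sketch of wiring the provider into a Structured Streaming query; eventsStream stands in for any streaming DataFrame, and the master address, table name and paths are illustrative. The sink above only accepts update output mode, and kudu.master, kudu.table and retries are the options it reads.

val query = eventsStream.writeStream   // eventsStream: a placeholder streaming DataFrame
  .format("kudu")
  .option("kudu.master", "kudu-master:7051")
  .option("kudu.table", "impala::default.events")
  .option("retries", "3")
  .option("checkpointLocation", "/tmp/kudu-sink-checkpoint")
  .outputMode("update")
  .start()

query.awaitTermination()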
Example 156
Source File: KuduSink.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp.kudu

import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

object KuduSink {
  def withDefaultContext(sqlContext: SQLContext, parameters: Map[String, String]) =
    new KuduSink(new KuduContext(parameters("kudu.master"), sqlContext.sparkContext), parameters)
}


class KuduSink(initKuduContext: => KuduContext, parameters: Map[String, String]) extends Sink {

  private val logger = LoggerFactory.getLogger(getClass)

  private var kuduContext = initKuduContext

  private val tablename = parameters("kudu.table")

  private val retries = parameters.getOrElse("retries", "1").toInt
  require(retries >= 0, "retries must be non-negative")

  logger.info(s"Created Kudu sink writing to table $tablename")

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    for (attempt <- 0 to retries) {
      try {
        kuduContext.upsertRows(data, tablename)
        return
      } catch {
        case NonFatal(e) =>
          if (attempt < retries) {
            logger.warn("Kudu upsert error, retrying...", e)
            kuduContext = initKuduContext
          }
          else {
            logger.error("Kudu upsert error, exhausted", e)
            throw e
          }
      }
    }
  }
} 
Example 157
Source File: ExtractorsSpec.scala    From streamliner-starter   with Apache License 2.0 5 votes vote down vote up
package test

import com.memsql.spark.etl.api.UserExtractConfig
import com.memsql.spark.etl.utils.ByteUtils
import spray.json.JsString
import com.memsql.streamliner.starter.BasicExtractor
import org.apache.spark.streaming._
import org.apache.spark.sql.SQLContext

class ExtractorsSpec extends UnitSpec with LocalSparkContext {
  val emptyConfig = UserExtractConfig(class_name = "Test", value = new JsString("empty"))
  val logger = new TestLogger("test")

  var ssc: StreamingContext = _
  var sqlContext: SQLContext = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    ssc = new StreamingContext(sc, Seconds(1))
    sqlContext = new SQLContext(sc)
  }

  "BasicExtractor" should "emit a constant DataFrame" in {
    val extract = new BasicExtractor
    
    val maybeDf = extract.next(ssc, 1, sqlContext, emptyConfig, 1, logger)
    assert(maybeDf.isDefined)

    val total = maybeDf.get.select("number").rdd.map(r => r(0).asInstanceOf[Int]).sum()
    assert(total == 15)
  }

} 
Example 158
Source File: QueryLshTest.scala    From cosine-lsh-join-spark   with MIT License 5 votes vote down vote up
package com.soundcloud.lsh

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed._
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.{FunSuite, Matchers}

class QueryLshTest extends FunSuite with SparkLocalContext with Matchers {

  lazy val lsh = new QueryLsh(
    minCosineSimilarity = -1.0,
    dimensions = 100,
    numNeighbours = 10,
    maxMatches = 1000,
    rounds = 10)(new SQLContext(sc).sparkSession)

  test("join bitmap") {
    val bitseq = Seq(
      new BitSet(2),
      new BitSet(2)
    )
    bitseq(0).set(0)
    bitseq(1).set(1)
    val rdd = sc.parallelize(bitseq.map(bitSetToString(_)).zipWithIndex)
    val joined = rdd.join(rdd).values.collect
    joined shouldBe Seq((0,0), (1,1))
  }

  test("join") {
    val rows = Seq(
      IndexedRow(0, Vectors.dense(1, 1, 0, 0)),
      IndexedRow(1, Vectors.dense(1, 2, 0, 0)),
      IndexedRow(2, Vectors.dense(0, 1, 4, 2))
    )
    val inputMatrix = new IndexedRowMatrix(sc.parallelize(rows))
    val got = lsh.join(inputMatrix, inputMatrix)
    val expected = Seq(
      (0, 0),
      (1, 1),
      (2, 2)
    )
    val gotIndex = got.entries.collect.map {
      entry: MatrixEntry =>
        (entry.i, entry.j)
    }
    gotIndex.sorted should be(expected.sorted)
  }

  test("distinct") {
    val matrix = Seq(
      new MatrixEntry(1, 2, 3.4),
      new MatrixEntry(1, 2, 3.5),
      new MatrixEntry(1, 3, 3.4)
    )
    val got = distinct(sc.parallelize(matrix)).collect
    val expected = Seq(
      matrix(0), matrix(2)
    )
    got should be(expected)
  }

} 
Example 159
Source File: DefaultSource.scala    From spark-google-adwords   with Apache License 2.0 5 votes vote down vote up
package com.crealytics.google.adwords

import com.google.api.ads.common.lib.auth.OfflineCredentials
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources._

class DefaultSource extends RelationProvider {

  // The default User Agent
  private final val DEFAULT_USER_AGENT = "Spark"
  // Default During Clause: 30 days
  private final val DEFAULT_DURING = "LAST_30_DAYS"

  
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): AdWordsRelation = {

    // gather parameters
    val clientId = checkParameter(parameters, "clientId")
    val clientSecret = checkParameter(parameters, "clientSecret")
    val developerToken = checkParameter(parameters, "developerToken")
    val refreshToken = checkParameter(parameters, "refreshToken")
    val reportType = checkParameter(parameters, "reportType")
    val clientCustomerId = checkParameter(parameters, "clientCustomerId")

    val userAgent =
      parameterOrDefault(parameters, "userAgent", DEFAULT_USER_AGENT)
    val duringStmt = parameterOrDefault(parameters, "during", DEFAULT_DURING)
    // Our OAuth2 Credential
    val credential = new OfflineCredentials.Builder()
      .forApi(OfflineCredentials.Api.ADWORDS)
      .withClientSecrets(clientId, clientSecret)
      .withRefreshToken(refreshToken)
      .build
      .generateCredential
    // create relation
    AdWordsRelation(credential, developerToken, clientCustomerId, userAgent, reportType, duringStmt)(sqlContext)
  }

  // Forces a Parameter to exist, otherwise an exception is thrown.
  private def checkParameter(map: Map[String, String], param: String) = {
    if (!map.contains(param)) {
      throw new IllegalArgumentException(s"Parameter ${'"'}$param${'"'} is missing in options.")
    } else {
      map.apply(param)
    }
  }

  // Gets the Parameter if it exists, otherwise returns the default argument
  private def parameterOrDefault(map: Map[String, String], param: String, default: String) =
    map.getOrElse(param, default)
} 
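A hedged usage sketch (every credential value is a placeholder): supplying the parameters the provider validates above, plus the optional during clause, and reading a report; the report type shown is illustrative.

val reportDf = sqlContext.read
  .format("com.crealytics.google.adwords")
  .option("clientId", "<client-id>")
  .option("clientSecret", "<client-secret>")
  .option("developerToken", "<developer-token>")
  .option("refreshToken", "<refresh-token>")
  .option("clientCustomerId", "<customer-id>")
  .option("reportType", "KEYWORDS_PERFORMANCE_REPORT")
  .option("during", "LAST_7_DAYS")
  .load()

reportDf.printSchema()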
Example 160
Source File: RddToDataFrame.scala    From spark-sframe   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
package org.apache.spark.turi

import org.graphlab.create.GraphLabUtil
import org.apache.spark.sql.{SQLContext, Row, DataFrame}
import org.apache.spark.rdd.RDD
import scala.collection.JavaConversions._
import org.apache.spark.sql.types._
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.immutable.Map
import java.util.HashMap
import java.util.ArrayList
import java.util.GregorianCalendar
import java.sql.Date

object EvaluateRDD {
  
  def inferSchema(obj: Any): DataType = {
    if(obj.isInstanceOf[Int]) { 
      IntegerType
    } else if(obj.isInstanceOf[String]) { 
      StringType
    } else if(obj.isInstanceOf[Double]) { 
      DoubleType
    } else if(obj.isInstanceOf[Long]) { 
      LongType
    } else if(obj.isInstanceOf[Float]) { 
      FloatType
    } else if(obj.isInstanceOf[Map[_,_]]) {
      MapType(inferSchema(obj.asInstanceOf[Map[_,_]].head._1),inferSchema(obj.asInstanceOf[Map[_,_]].head._2))
    } else if(obj.isInstanceOf[java.util.HashMap[_,_]]) {
      MapType(inferSchema(obj.asInstanceOf[java.util.HashMap[_,_]].head._1),inferSchema(obj.asInstanceOf[java.util.HashMap[_,_]].head._2))
    } else if(obj.isInstanceOf[Array[_]]) {
      ArrayType(inferSchema(obj.asInstanceOf[Array[_]](0)))
    } else if(obj.isInstanceOf[java.util.ArrayList[_]]) {
      ArrayType(inferSchema(obj.asInstanceOf[java.util.ArrayList[_]](0)))
    } else if(obj.isInstanceOf[java.util.GregorianCalendar]) {
      TimestampType
    } else if(obj.isInstanceOf[java.util.Date] || obj.isInstanceOf[java.sql.Date]) {
      DateType
    } else { 
      StringType
    }
  }

  def toScala(obj: Any): Any = {
    if (obj.isInstanceOf[java.util.HashMap[_,_]]) {
      val jmap = obj.asInstanceOf[java.util.HashMap[_,_]]
      jmap.map { case (k,v) => toScala(k) -> toScala(v) }.toMap
    }
    else if(obj.isInstanceOf[java.util.ArrayList[_]]) {
      val buf = ArrayBuffer[Any]()
      val jArray = obj.asInstanceOf[java.util.ArrayList[_]]
      for(item <- jArray) {
        buf += toScala(item)
      }
      buf.toArray
    } else if(obj.isInstanceOf[java.util.GregorianCalendar]) {
      new java.sql.Timestamp(obj.asInstanceOf[java.util.GregorianCalendar].getTime().getTime())
    } else {
      obj
    }
  }
  def toSparkDataFrame(sqlContext: SQLContext, rdd: RDD[java.util.HashMap[String,_]]): DataFrame = { 
    val scalaRDD = rdd.map(l => toScala(l))
    val rowRDD = scalaRDD.map(l => Row.fromSeq(l.asInstanceOf[Map[_,_]].values.toList))
    
    var sample_data: java.util.HashMap[String,_] = rdd.take(1)(0)
    
    var schema_list: ListBuffer[StructField] = new ListBuffer[StructField]()
    for ((name,v) <- sample_data) { 
      schema_list.append(StructField(name,inferSchema(v)))
    }
    sqlContext.createDataFrame(rowRDD,StructType(schema_list))
  }
} 
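A brief usage sketch of the converter above, assuming an existing SparkContext named sc; the column names and values are illustrative only.

// Usage sketch: build an RDD of java.util.HashMap rows and convert it to a DataFrame.
import java.util.{HashMap => JHashMap}
import org.apache.spark.sql.SQLContext
import org.apache.spark.turi.EvaluateRDD

val sqlContext = new SQLContext(sc)
val data: Seq[JHashMap[String, _]] = (1 to 3).map { i =>
  val m = new JHashMap[String, Any]()
  m.put("id", i)            // inferred as IntegerType
  m.put("name", s"row-$i")  // inferred as StringType
  m
}
val df = EvaluateRDD.toSparkDataFrame(sqlContext, sc.parallelize(data))
df.printSchema()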
Example 161
Source File: SparkFunSuite.scala    From spark-corenlp   with GNU General Public License v3.0 5 votes
package com.databricks.spark.corenlp

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FunSuite}

trait SparkFunSuite extends FunSuite with BeforeAndAfterAll {
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    sc = SparkContext.getOrCreate(
      new SparkConf()
        .setMaster("local[2]")
        .setAppName(this.getClass.getSimpleName)
    )
    sqlContext = SQLContext.getOrCreate(sc)
  }

  override def afterAll(): Unit = {
    sc.stop()
    sc = null
    sqlContext = null
  }
} 
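A minimal sketch of a suite that mixes in the trait above; the suite name and test body are illustrative only.

// Illustrative test suite reusing the shared SparkContext/SQLContext from SparkFunSuite.
class SqlContextSmokeSuite extends SparkFunSuite {
  test("sqlContext can create a DataFrame") {
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "value")
    assert(df.count() === 2L)
  }
}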
Example 162
Source File: DefaultSource.scala    From spark-athena   with Apache License 2.0 5 votes
package io.github.tmheo.spark.athena

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources._

class DefaultSource extends RelationProvider {

  override def createRelation(
                               sqlContext: SQLContext,
                               parameters: Map[String, String]): BaseRelation = {
    val jdbcOptions = new JDBCOptions(parameters)
    val partitionColumn = jdbcOptions.partitionColumn
    val lowerBound = jdbcOptions.lowerBound
    val upperBound = jdbcOptions.upperBound
    val numPartitions = jdbcOptions.numPartitions

    val partitionInfo = if (partitionColumn.isEmpty) {
      assert(lowerBound.isEmpty && upperBound.isEmpty)
      null
    } else {
      assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty)
      JDBCPartitioningInfo(
        partitionColumn.get, lowerBound.get, upperBound.get, numPartitions.get)
    }
    val parts = JDBCRelation.columnPartition(partitionInfo)
    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
  }

} 
Example 163
Source File: MlLibOnKudu.scala    From Taxi360   with Apache License 2.0 5 votes
package com.cloudera.sa.taxi360.etl.machinelearning.kudu

import com.cloudera.sa.taxi360.model.{NyTaxiYellowTrip, NyTaxiYellowTripBuilder}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

object MlLibOnKudu {
  def main(args: Array[String]): Unit = {

    if (args.length < 5) {
      println("Args: <runLocal> " +
        "<kuduMaster> " +
        "<taxiTable> " +
        "<numOfCenters> " +
        "<numOfIterations> ")
      return
    }

    val runLocal = args(0).equalsIgnoreCase("l")
    val kuduMaster = args(1)
    val taxiTable = args(2)
    val numOfCenters = args(3).toInt
    val numOfIterations = args(4).toInt

    val sc: SparkContext = if (runLocal) {
      val sparkConfig = new SparkConf()
      sparkConfig.set("spark.broadcast.compress", "false")
      sparkConfig.set("spark.shuffle.compress", "false")
      sparkConfig.set("spark.shuffle.spill.compress", "false")
      new SparkContext("local", "TableStatsSinglePathMain", sparkConfig)
    } else {
      val sparkConfig = new SparkConf().setAppName("TableStatsSinglePathMain")
      new SparkContext(sparkConfig)
    }

    val sqlContext = new SQLContext(sc)

    val kuduOptions = Map(
      "kudu.table" -> taxiTable,
      "kudu.master" -> kuduMaster)

    sqlContext.read.options(kuduOptions).format("org.apache.kudu.spark.kudu").load.
      registerTempTable("ny_taxi_trip_tmp")

    //Vector
    val vectorRDD:RDD[Vector] = sqlContext.sql("select * from ny_taxi_trip_tmp").map(r => {
      val taxiTrip = NyTaxiYellowTripBuilder.build(r)
      generateVectorOnly(taxiTrip)
    })

    println("--Running KMeans")
    val clusters = KMeans.train(vectorRDD, numOfCenters, numOfIterations)
    println(" > vector centers:")
    clusters.clusterCenters.foreach(v => println(" >> " + v))

    println("--Running corr")
    val correlMatrix: Matrix = Statistics.corr(vectorRDD, "pearson")
    println(" > corr: " + correlMatrix.toString)

    println("--Running colStats")
    val colStats = Statistics.colStats(vectorRDD)
    println(" > max: " + colStats.max)
    println(" > count: " + colStats.count)
    println(" > mean: " + colStats.mean)
    println(" > min: " + colStats.min)
    println(" > normL1: " + colStats.normL1)
    println(" > normL2: " + colStats.normL2)
    println(" > numNonZeros: " + colStats.numNonzeros)
    println(" > variance: " + colStats.variance)

    //Labeled Points
    // (the labeled-points portion of this example is truncated in the source)
  }
} 
Example 164
Source File: SparkSqlRunner.scala    From amaterasu   with Apache License 2.0 5 votes
package org.apache.amaterasu.executor.execution.actions.runners.spark.SparkSql

import java.io.File

import org.apache.amaterasu.common.execution.actions.Notifier
import org.apache.amaterasu.common.logging.Logging
import org.apache.amaterasu.common.runtime.Environment
import org.apache.commons.io.FilenameUtils
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode, SparkSession}

// Note: the original class declaration is truncated in the source; a minimal reconstruction is
// assumed here, based on the fields assigned by the companion object below.
class SparkSqlRunner extends Logging {

  var env: Environment = _
  var jobId: String = _
  var actionName: String = _
  var notifier: Notifier = _
  var sc: SparkContext = _
  var spark: SparkSession = _

  def findFileType(folderName: File): Array[String] = {
    // get all the files from a directory
    val files: Array[File] = folderName.listFiles()
    val extensions: Array[String] = files.map(file => FilenameUtils.getExtension(file.toString))
    extensions
  }

}

object SparkSqlRunner {

  def apply(env: Environment,
            jobId: String,
            actionName: String,
            notifier: Notifier,
            sc: SparkContext): SparkSqlRunner = {

    val sparkSqlRunnerObj = new SparkSqlRunner

    sparkSqlRunnerObj.env = env
    sparkSqlRunnerObj.jobId = jobId
    sparkSqlRunnerObj.actionName = actionName
    sparkSqlRunnerObj.notifier = notifier
    sparkSqlRunnerObj.sc = sc
    sparkSqlRunnerObj.spark = SparkSession.builder().config(sc.getConf).enableHiveSupport().getOrCreate()
    sparkSqlRunnerObj
  }
} 
Example 165
Source File: ModelSparkContext.scala    From flamy   with Apache License 2.0 5 votes
package com.flaminem.flamy.conf.spark

import java.io.File

import com.flaminem.flamy.conf.{Flamy, FlamyContext, FlamyGlobalContext}
import com.flaminem.flamy.model.files.PresetsFile
import com.flaminem.flamy.parsing.hive.QueryUtils
import org.apache.spark.SparkConf
import org.apache.spark.sql.{SQLContext, SparkSession}


object ModelSparkContext {

  val LOCAL_METASTORE: String = "local_spark_metastore"
  val LOCAL_WAREHOUSE: String = "local_spark_warehouse"
  val LOCAL_DERBY_LOG: String = "derby.log"

  private var sparkConf: SparkConf = _

  def localMetastore(context: FlamyContext): String = {
    FlamyGlobalContext.getUniqueRunDir + "/" + LOCAL_METASTORE
  }

  def localWarehouse(context: FlamyContext): String = {
    FlamyGlobalContext.getUniqueRunDir + "/" + LOCAL_WAREHOUSE
  }


  private def init(context: FlamyContext) = {
    sparkConf = new SparkConf()
      .setAppName(Flamy.name)
      .setMaster("local")
      .set("spark.sql.crossJoin.enabled", "true")
      .set("derby.stream.error.file", FlamyGlobalContext.getUniqueRunDir + "/" + LOCAL_DERBY_LOG)
      .set("javax.jdo.option.ConnectionURL", "jdbc:derby:;databaseName=" + localMetastore(context) + ";create=true")
      .set("hive.metastore.warehouse.dir", localWarehouse(context))
  }

  private lazy val _spark = {
    SparkSession.builder()
      .enableHiveSupport()
      .config(sparkConf)
      .getOrCreate()
  }

  def spark(context: FlamyContext): SparkSession = {
    init(context)
    _spark
  }

  def getSparkSQLContext(context: FlamyContext): SQLContext = {
    val sqlContext: SQLContext = ModelSparkContext.spark(context).sqlContext.newSession()
    runPresets(sqlContext, context)
    sqlContext
  }

  private def addUDFJars(sqlContext: SQLContext, context: FlamyContext): Unit = {
    context.getUdfJarPaths.foreach{
      path => sqlContext.sql(s"ADD JAR $path")
    }
  }

  private def runPresets(sqlContext: SQLContext, context: FlamyContext): Unit = {
    addUDFJars(sqlContext, context)
    context.HIVE_PRESETS_PATH.getProperty match {
      case Some(path) =>
        val file = new PresetsFile(new File(path))
        QueryUtils.cleanAndSplitQuery(file.text).foreach{
          query =>
            val replacedQuery = context.getVariables.replaceInText(query)
            sqlContext.sql(replacedQuery)
        }
      case _ =>
        System.out.println("No presets to run")
    }
  }

} 
Example 166
Source File: Check.scala    From flamy   with Apache License 2.0 5 votes
package com.flaminem.flamy.commands

import com.flaminem.flamy.commands.utils.FlamySubcommand
import com.flaminem.flamy.conf.spark.ModelSparkContext
import com.flaminem.flamy.conf.{Environment, FlamyContext, FlamyGlobalOptions}
import com.flaminem.flamy.exec.FlamyRunner
import com.flaminem.flamy.exec.files.{FileRunner, ItemFileAction}
import com.flaminem.flamy.exec.hive.{HivePartitionFetcher, ModelHivePartitionFetcher}
import com.flaminem.flamy.exec.utils._
import com.flaminem.flamy.exec.utils.io.FlamyOutput
import com.flaminem.flamy.graph.TableGraph
import com.flaminem.flamy.model._
import com.flaminem.flamy.model.core.Model
import com.flaminem.flamy.model.files.FilePath
import com.flaminem.flamy.model.names.ItemName
import org.apache.spark.sql.SQLContext
import org.rogach.scallop.{ScallopConf, ScallopOption, Subcommand}

import scala.language.reflectiveCalls


      val runGraph: TableGraph = baseGraph.subGraph(items())

      val dryRunner: FlamyRunner = FlamyRunner(context)
      println("Creating schemas and tables ...")
      try {
        dryRunner.checkAll(baseGraph)
      }
      finally{
        //TODO: For some strange reason, closing the connection here will result in ClassNotFoundErrors for udfs in the RunActions...
        //      dryRunner.close()
      }
      FlamyOutput.out.info("Running Populates ...")
      dryRunner.populateAll(runGraph.model, context)
      dryRunner.close()
      ReturnStatus(success = dryRunner.getStats.getFailCount==0)
    }

  }

  override def doCommand(globalOptions: FlamyGlobalOptions, subCommands: List[ScallopConf]): ReturnStatus = {
    subCommands match {
      case  (command: FlamySubcommand)::Nil => command.doCommand(globalOptions, Nil)
      case Nil => throw new IllegalArgumentException("A subcommand is expected")
      case _ =>
        printHelp()
        ReturnFailure
    }
  }


} 
Example 167
Source File: KuduSinkProvider.scala    From kafka-examples   with Apache License 2.0 5 votes
package com.cloudera.streaming.refapp.kudu

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode


class KuduSinkProvider extends StreamSinkProvider with DataSourceRegister {

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {
    require(outputMode == OutputMode.Update, "only 'update' OutputMode is supported")
    KuduSink.withDefaultContext(sqlContext, parameters)
  }

  override def shortName(): String = "kudu"
} 
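Because the provider registers the short name "kudu" and accepts only the update output mode, a structured streaming write would look roughly like the sketch below; countsDf stands in for some aggregated streaming DataFrame, and the Kudu option keys and checkpoint path shown are assumptions.

// Hypothetical streaming write; kudu.master/kudu.table keys and the checkpoint path are placeholders.
val query = countsDf.writeStream
  .format("kudu")
  .outputMode("update")
  .option("kudu.master", "kudu-master:7051")
  .option("kudu.table", "impressions_by_minute")
  .option("checkpointLocation", "/tmp/kudu-sink-checkpoint")
  .start()
query.awaitTermination()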
Example 168
Source File: TransformersSpec.scala    From streamliner-starter   with Apache License 2.0 5 votes
package test

import com.memsql.spark.etl.api.UserTransformConfig
import com.memsql.spark.etl.utils.ByteUtils
import com.memsql.streamliner.starter.BasicTransformer
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.types._
import spray.json.JsString

class TransformersSpec extends UnitSpec with LocalSparkContext {
  val emptyConfig = UserTransformConfig(class_name = "Test", value = new JsString("empty"))
  val logger = new TestLogger("test")

  var sqlContext: SQLContext = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    sqlContext = new SQLContext(sc)
  }

  "BasicTransformer" should "only emit even numbers" in {
    val transform = new BasicTransformer

    val schema = StructType(StructField("number", IntegerType, false) :: Nil)
    val sampleData = List(1,2,3)
    val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_))
    val dfIn = sqlContext.createDataFrame(rowRDD, schema)

    val df = transform.transform(sqlContext, dfIn, emptyConfig, logger)
    assert(df.schema == schema)
    assert(df.first == Row(2))
    assert(df.count == 1)
  }

  "BasicTransformer" should "only accept IntegerType fields" in {
    val transform = new BasicTransformer

    val schema = StructType(StructField("column", StringType, false) :: Nil)
    val sampleData = List(1,2,3)
    val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_))
    val dfIn = sqlContext.createDataFrame(rowRDD, schema)

    val e = intercept[IllegalArgumentException] {
      transform.transform(sqlContext, dfIn, emptyConfig, logger)
    }
    assert(e.getMessage() == "The first column of the input DataFrame should be IntegerType")
  }

} 
Example 169
Source File: Transformers.scala    From streamliner-starter   with Apache License 2.0 5 votes
package com.memsql.streamliner.starter

import org.apache.spark.sql.{Row, DataFrame, SQLContext}
import org.apache.spark.sql.types._
import com.memsql.spark.etl.api.{Transformer, PhaseConfig}
import com.memsql.spark.etl.utils.PhaseLogger

// A helper object to extract the first column of a schema
object ExtractFirstStructField {
  def unapply(schema: StructType): Option[(String, DataType, Boolean, Metadata)] = schema.fields match {
    case Array(first: StructField, _*) => Some((first.name, first.dataType, first.nullable, first.metadata))
  }
}

// This transformer expects an input DataFrame whose first column is IntegerType and returns only the even-valued rows
class BasicTransformer extends Transformer {
  def transform(sqlContext: SQLContext, df: DataFrame, config: PhaseConfig, logger: PhaseLogger): DataFrame = {
    logger.info("transforming the DataFrame")

    // check that the first column is of type IntegerType and return its name
    val column = df.schema match {
      case ExtractFirstStructField(name: String, dataType: IntegerType, _, _) => name
      case _ => throw new IllegalArgumentException("The first column of the input DataFrame should be IntegerType")
    }

    // filter the dataframe, returning only even numbers
    df.filter(s"$column % 2 = 0")
  }
} 
Example 170
Source File: Extractors.scala    From streamliner-starter   with Apache License 2.0 5 votes
package com.memsql.streamliner.starter

import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.types._
import org.apache.spark.streaming.StreamingContext
import com.memsql.spark.etl.api.{Extractor, PhaseConfig}
import com.memsql.spark.etl.utils.PhaseLogger

// This extractor just returns a static range of 5 integers each batch interval
class BasicExtractor extends Extractor {
  override def next(ssc: StreamingContext, time: Long, sqlContext: SQLContext, config: PhaseConfig, batchInterval: Long,
   logger: PhaseLogger): Option[DataFrame] = {
    logger.info("extracting a constant sequence DataFrame")

    val schema = StructType(StructField("number", IntegerType, false) :: Nil)

    val sampleData = List(1,2,3,4,5)
    val rowRDD = sqlContext.sparkContext.parallelize(sampleData).map(Row(_))

    val df = sqlContext.createDataFrame(rowRDD, schema)
    Some(df)
  }
} 
Example 171
Source File: ExtensionBuilderSuite.scala    From spark-sql-server   with Apache License 2.0 5 votes
package org.apache.spark.sql.server

import java.net.URL

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.util.Utils

class ExtensionBuilderSuite extends SparkFunSuite with BeforeAndAfterAll {

  var _sqlContext: SQLContext = _

  // TODO: This method works only in Java8
  private def addJarInClassPath(jarURLString: String): Unit = {
    // val cl = ClassLoader.getSystemClassLoader
    val cl = Utils.getSparkClassLoader
    val clazz = cl.getClass
    val method = clazz.getSuperclass.getDeclaredMethod("addURL", Seq(classOf[URL]): _*)
    method.setAccessible(true)
    method.invoke(cl, Seq[Object](new URL(jarURLString)): _*)
  }

  protected override def beforeAll(): Unit = {
    super.beforeAll()

    // Adds a jar for an extension builder
    val jarPath = "src/test/resources/extensions_2.12_3.0.0-preview2_0.1.7-spark3.0-SNAPSHOT.jar"
    val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath"
    // sqlContext.sparkContext.addJar(jarURL)
    addJarInClassPath(jarURL)

    val conf = new SparkConf(loadDefaults = true)
      .setMaster("local[1]")
      .setAppName("spark-sql-server-test")
      .set("spark.sql.server.extensions.builder", "org.apache.spark.ExtensionBuilderExample")
    _sqlContext = SQLServerEnv.newSQLContext(conf)
  }

  protected override def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      try {
        if (_sqlContext != null) {
          _sqlContext.sparkContext.stop()
          _sqlContext = null
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }

  test("user-defined optimizer rules") {
    val rules = Seq("org.apache.spark.catalyst.EmptyRule1", "org.apache.spark.catalyst.EmptyRule2")
    val optimizerRuleNames = _sqlContext.sessionState.optimizer
      .extendedOperatorOptimizationRules.map(_.ruleName)
    rules.foreach { expectedRuleName =>
      assert(optimizerRuleNames.contains(expectedRuleName))
    }
  }
} 
Example 172
Source File: SQLServerEnv.scala    From spark-sql-server   with Apache License 2.0 5 votes
package org.apache.spark.sql.server

import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.server.ui.SQLServerTab
import org.apache.spark.util.Utils

object SQLServerEnv extends Logging {

  // For test use
  private var _sqlContext: Option[SQLContext] = None

  @DeveloperApi
  def withSQLContext(sqlContext: SQLContext): Unit = {
    require(sqlContext != null)
    _sqlContext = Option(sqlContext)
    sqlServListener
    uiTab
  }

  private def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
    sparkConf.getAll.foreach { case (k, v) =>
      sqlConf.setConfString(k, v)
    }
  }

  lazy val sparkConf: SparkConf = _sqlContext.map(_.sparkContext.conf).getOrElse {
    val sparkConf = new SparkConf(loadDefaults = true)

    // If user doesn't specify the appName, we want to get [SparkSQL::localHostName]
    // instead of the default appName [SQLServer].
    val maybeAppName = sparkConf
      .getOption("spark.app.name")
      .filterNot(_ == classOf[SQLServer].getName)
    sparkConf
      .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}"))
      .set("spark.sql.crossJoin.enabled", "true")
  }

  lazy val sqlConf: SQLConf = _sqlContext.map(_.conf).getOrElse {
    val newSqlConf = new SQLConf()
    mergeSparkConf(newSqlConf, sparkConf)
    newSqlConf
  }

  lazy val sqlContext: SQLContext = _sqlContext.getOrElse(newSQLContext(sparkConf))
  lazy val sparkContext: SparkContext = sqlContext.sparkContext
  lazy val sqlServListener: Option[SQLServerListener] = Some(newSQLServerListener(sqlContext))
  lazy val uiTab: Option[SQLServerTab] = newUiTab(sqlContext, sqlServListener.get)

  private[sql] def newSQLContext(conf: SparkConf): SQLContext = {
    def buildSQLContext(f: SparkSessionExtensions => Unit = _ => {}): SQLContext = {
      SparkSession.builder.config(conf).withExtensions(f).enableHiveSupport()
        .getOrCreate().sqlContext
    }
    val builderClassName = conf.get("spark.sql.server.extensions.builder", "")
    if (builderClassName.nonEmpty) {
      // Tries to install user-defined extensions
      try {
        val objName = builderClassName + (if (!builderClassName.endsWith("$")) "$" else "")
        val clazz = Utils.classForName(objName)
        val builder = clazz.getDeclaredField("MODULE$").get(null)
          .asInstanceOf[SparkSessionExtensions => Unit]
        val sqlContext = buildSQLContext(builder)
        logInfo(s"Successfully installed extensions from $builderClassName")
        sqlContext
      } catch {
        case NonFatal(e) =>
          logWarning(s"Failed to install extensions from $builderClassName: " + e.getMessage)
          buildSQLContext()
      }
    } else {
      buildSQLContext()
    }
  }
  def newSQLServerListener(sqlContext: SQLContext): SQLServerListener = {
    val listener = new SQLServerListener(sqlContext.conf)
    sqlContext.sparkContext.addSparkListener(listener)
    listener
  }
  def newUiTab(sqlContext: SQLContext, listener: SQLServerListener): Option[SQLServerTab] = {
    sqlContext.sparkContext.conf.getBoolean("spark.ui.enabled", true) match {
      case true => Some(SQLServerTab(SQLServerEnv.sqlContext.sparkContext, listener))
      case _ => None
    }
  }
} 
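To exercise the extension hook above, a caller points spark.sql.server.extensions.builder at a Scala object of type SparkSessionExtensions => Unit and asks for a new SQLContext. The sketch below is a hypothetical builder that injects a no-op optimizer rule; the object name is an assumption.

// Hypothetical extension builder; injects a rule that leaves plans unchanged.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

object MyExtensionBuilder extends (SparkSessionExtensions => Unit) {
  private object NoopRule extends Rule[LogicalPlan] {
    override def apply(plan: LogicalPlan): LogicalPlan = plan
  }
  override def apply(extensions: SparkSessionExtensions): Unit =
    extensions.injectOptimizerRule(_ => NoopRule)
}

// newSQLContext is private[sql], so this call has to live under the org.apache.spark.sql package
// tree, as the test suite above does.
val conf = new SparkConf(loadDefaults = true)
  .set("spark.sql.server.extensions.builder", "MyExtensionBuilder")
val sqlContext = SQLServerEnv.newSQLContext(conf)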
Example 173
Source File: StreamingJob.scala    From confluent-platform-spark-streaming   with Apache License 2.0 5 votes
package example

import com.typesafe.config.ConfigFactory
import io.confluent.kafka.serializers.KafkaAvroDecoder
import kafka.serializer.StringDecoder
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.kafka.KafkaUtils
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{SparkContext, SparkConf}


object StreamingJob extends App {

  // Get job configuration
  val config = ConfigFactory.load()

  Logger.getLogger("example").setLevel(Level.toLevel(config.getString("loglevel")))
  private val logger = Logger.getLogger(getClass)

  // Spark config and contexts
  val sparkMaster = config.getString("spark.master")
  val sparkConf = new SparkConf()
    .setMaster(sparkMaster)
    .setAppName("StreamingExample")
    .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")

  val sc = new SparkContext(sparkConf)
  val batchInterval = config.getInt("spark.batch.interval")
  val ssc = new StreamingContext(sc, Seconds(batchInterval))

  // Create Kafka stream
  val groupId = config.getString("kafka.group.id")
  val topic = config.getString("topic")
  val kafkaParams = Map(
    "bootstrap.servers" -> config.getString("kafka.bootstrap.servers"),
    "schema.registry.url" -> config.getString("kafka.schema.registry.url"),
    "group.id" -> groupId
  )

  @transient val kafkaStream: DStream[(String, Object)] =
      KafkaUtils.createDirectStream[String, Object, StringDecoder, KafkaAvroDecoder](
        ssc, kafkaParams, Set(topic)
      )

  // Load JSON strings into DataFrame
  kafkaStream.foreachRDD { rdd =>
    // Get the singleton instance of SQLContext
    val sqlContext = SQLContext.getOrCreate(rdd.sparkContext)
    import sqlContext.implicits._

    val topicValueStrings = rdd.map(_._2.toString)
    val df = sqlContext.read.json(topicValueStrings)

    df.printSchema()
    println("DataFrame count: " + df.count())
    df.take(1).foreach(println)
  }

  ssc.start()
  ssc.awaitTermination()

} 
Example 174
Source File: DefaultSource.scala    From spark-vector   with Apache License 2.0 5 votes
package com.actian.spark_vector.sql

import org.apache.spark.sql.{ DataFrame, SQLContext, SaveMode }
import org.apache.spark.sql.sources.{ BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider, SchemaRelationProvider }
import org.apache.spark.sql.types.StructType

import com.actian.spark_vector.util.Logging
import com.actian.spark_vector.vector.VectorJDBC

class DefaultSource extends DataSourceRegister with RelationProvider with SchemaRelationProvider with CreatableRelationProvider with Logging {
  override def shortName(): String = "vector"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    VectorRelation(TableRef(parameters), sqlContext, parameters)

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation =
    VectorRelation(TableRef(parameters), Some(schema), sqlContext, parameters)

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val tableRef = TableRef(parameters)
    val table = VectorRelation(tableRef, sqlContext, parameters)

    mode match {
      case SaveMode.Overwrite =>
        table.insert(data, true)
      case SaveMode.ErrorIfExists =>
        val isEmpty = VectorJDBC.withJDBC(tableRef.toConnectionProps) { _.isTableEmpty(tableRef.table) }
        if (isEmpty) {
          table.insert(data, false)
        } else {
          throw new UnsupportedOperationException("Writing to a non-empty Vector table is not allowed with mode ErrorIfExists.")
        }
      case SaveMode.Append =>
        table.insert(data, false)
      case SaveMode.Ignore =>
        val isEmpty = VectorJDBC.withJDBC(tableRef.toConnectionProps) { _.isTableEmpty(tableRef.table) }
        if (isEmpty) {
          table.insert(data, false)
        }
    }

    table
  }
} 
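A rough write-path sketch through the registered short name; connectionParams stands in for whatever Vector connection and table options TableRef expects, since those keys are not shown here, and df is an existing DataFrame.

// Hypothetical write via the "vector" short name; connectionParams holds connector-specific options.
import org.apache.spark.sql.SaveMode

val connectionParams: Map[String, String] = Map( /* Vector host/database/table options */ )
df.write
  .format("vector")
  .options(connectionParams)
  .mode(SaveMode.Append)
  .save()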
Example 175
Source File: CSVRead.scala    From spark-vector   with Apache License 2.0 5 votes
package com.actian.spark_vector.loader.command

import org.apache.spark.sql.SQLContext

import com.actian.spark_vector.loader.options.UserOptions
import com.actian.spark_vector.sql.sparkQuote
import com.actian.spark_vector.util.Logging

object CSVRead extends Logging {
  private def parseOptions(options: UserOptions): String = {
    Seq(
      Some(s"""sep "${options.csv.separatorChar.getOrElse(",")}""""),
      options.csv.headerRow.filter(identity).map(_ => """header "true""""),
      options.csv.inferSchema.filter(identity).map(_ => s"""inferSchema "true""""),
      options.csv.encoding.map(v => s"""encoding "${v}""""),
      options.csv.quoteChar.map(c => if (c != '\'') s"quote '$c'" else s"""quote "$c""""),
      options.csv.escapeChar.map(c => if (c != '\'') s"""escape '$c'""" else s"""escape "$c""""),
      options.csv.commentChar.map(c => if (c != '\'') s"comment '$c'" else s"""comment "$c""""),
      options.csv.ignoreLeading.filter(identity).map(_ => """ignoreLeadingWhiteSpace "true""""),
      options.csv.ignoreTrailing.filter(identity).map(_ => """ignoreTrailingWhiteSpace "true""""),
      options.csv.nullValue.map(v => s"""nullValue "${v}""""),
      options.csv.nanValue.map(v => s"""nanValue "${v}""""),
      options.csv.positiveInf.map(v => s"""positiveInf "${v}""""),
      options.csv.negativeInf.map(v => s"""negativeInf "${v}""""),
      options.csv.dateFormat.map(v => s"""dateFormat "${v}""""),
      options.csv.timestampFormat.map(v => s"""timestampFormat "${v}""""),
      options.csv.parseMode.map(_.toUpperCase()).collect {
        case "PERMISSIVE" => """mode "PERMISSIVE" """
        case "DROPMALFORMED" => """mode "DROPMALFORMED""""
        case "FAILFAST" => """mode "FAILFAST""""
      }
      ).flatten.mkString(",", ",", "")
  }

  
  def registerTempTable(options: UserOptions, sqlContext: SQLContext): String = {
    val table = s"csv_${options.vector.targetTable}_${System.currentTimeMillis}"
    val quotedTable = sparkQuote(table)
    val baseQuery = s"""CREATE TEMPORARY VIEW $quotedTable${options.csv.header.map(_.mkString("(", ",", ")")).getOrElse("")}
      USING csv
      OPTIONS (path "${options.general.sourceFile}"${parseOptions(options)})"""
    logDebug(s"CSV query to be executed for registering temporary table:\n$baseQuery")
    val df = sqlContext.sql(baseQuery)
    val cols = options.general.colsToLoad.getOrElse(sqlContext.sql(s"select * from $quotedTable where 1=0").columns.toSeq)
    s"select ${cols.map(c => s"""`${c.trim}`""").mkString(",")} from $quotedTable"
  }
} 
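The helper above effectively issues a CREATE TEMPORARY VIEW ... USING csv statement and then selects the requested columns from that view. Written by hand against a SQLContext it looks roughly like this; the view name, path and option values are placeholders.

// Hand-written equivalent of the generated statements, with placeholder values.
sqlContext.sql(
  """CREATE TEMPORARY VIEW csv_staging
    |USING csv
    |OPTIONS (path "/data/input.csv", sep ",", header "true", inferSchema "true")""".stripMargin)
val selectStmt = "select `col_a`, `col_b` from csv_staging"
val stagedDf = sqlContext.sql(selectStmt)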
Example 176
Source File: JSONRead.scala    From spark-vector   with Apache License 2.0 5 votes
package com.actian.spark_vector.loader.command

import org.apache.spark.sql.SQLContext

import com.actian.spark_vector.loader.options.UserOptions
import com.actian.spark_vector.sql.sparkQuote
import com.actian.spark_vector.util.Logging
import com.actian.spark_vector.sql.colsSelectStatement

object JSONRead extends Logging { 
 
  private def parseOptions(options: UserOptions): String = {
    Seq(
      options.json.primitivesAsString.filter(identity).map(_ => """primitivesAsString "true""""),
      options.json.allowComments.filter(identity).map(_ => """allowComments "true""""),
      options.json.allowUnquoted.filter(identity).map(_ => """allowUnquotedFieldNames "true""""),
      options.json.allowSingleQuotes.filter(identity).map(_ => """allowSingleQuotes "true""""),
      options.json.allowLeadingZeros.filter(identity).map(_ => """allowNumericLeadingZeros "true""""),
      options.json.allowEscapingAny.filter(identity).map(_ => """allowBackslashEscapingAnyCharacter "true""""),
      options.json.allowUnquotedControlChars.filter(identity).map(_ => """allowUnquotedControlChars "true""""),
      options.json.multiline.filter(identity).map(_ => """multiline "true""""),
      options.json.parseMode.map(_.toUpperCase()).collect {
        case "PERMISSIVE" => """mode "PERMISSIVE" """
        case "DROPMALFORMED" => """mode "DROPMALFORMED""""
        case "FAILFAST" => """mode "FAILFAST""""
      }
    ).flatten.mkString(",", ",", "")
  }  
  
  
  def registerTempTable(options: UserOptions, sqlContext: SQLContext): String = {
      val table = s"json_${options.vector.targetTable}_${System.currentTimeMillis}"
      val quotedTable = sparkQuote(table)
      val baseQuery = s"""CREATE TEMPORARY VIEW $quotedTable${options.json.header.map(_.mkString("(", ",", ")")).getOrElse("")}
      USING json
      OPTIONS (path "${options.general.sourceFile}"${parseOptions(options)})"""
      logDebug(s"JSON query to be executed for registering temporary table:\n$baseQuery")
      val df = sqlContext.sql(baseQuery)
      val cols = options.general.colsToLoad.getOrElse(sqlContext.sql(s"select * from $quotedTable where 1=0").columns.toSeq)
      s"select ${cols.map(c => s"""`${c.trim}`""").mkString(",")} from $quotedTable"    
   }
} 
Example 177
Source File: StringIndexerDemo.scala    From Scala-and-Spark-for-Big-Data-Analytics   with MIT License 5 votes
package com.chapter11.SparkMachineLearning

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.feature.{ OneHotEncoder, StringIndexer }
import org.apache.spark.sql.types._
import org.apache.spark.sql._
import org.apache.spark.sql.functions.year
import org.apache.spark.ml.{ Pipeline, PipelineStage }
import org.apache.spark.ml.classification.{ LogisticRegression, LogisticRegressionModel }
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.{ DataFrame, SparkSession }
import scala.collection.mutable
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql._
import org.apache.spark.sql.SQLContext

object StringIndexerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName(s"OneVsRestExample")
      .getOrCreate()

    val df = spark.createDataFrame(
      Seq((0, "Jason", "Germany"),
        (1, "David", "France"),
        (2, "Martin", "Spain"),
        (3, "Jason", "USA"),
        (4, "Daiel", "UK"),
        (5, "Moahmed", "Bangladesh"),
        (6, "David", "Ireland"),
        (7, "Jason", "Netherlands"))).toDF("id", "name", "address")

    df.show(false)

    val indexer = new StringIndexer()
      .setInputCol("name")
      .setOutputCol("label")
      .fit(df)

    val indexed = indexer.transform(df)
    indexed.show(false)

    spark.stop()
  }
} 
Example 178
Source File: PlyOutputWriter.scala    From spark-iqmulus   with Apache License 2.0 5 votes
package fr.ign.spark.iqmulus.ply

import org.apache.spark.sql.types._
import org.apache.hadoop.mapreduce.{ TaskAttemptID, RecordWriter, TaskAttemptContext, JobContext }
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
import java.io.DataOutputStream
import org.apache.spark.sql.sources.OutputWriter
import org.apache.hadoop.io.{ NullWritable, BytesWritable }
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.fs.Path
import java.text.NumberFormat
import org.apache.spark.sql.{ Row, SQLContext, sources }
import fr.ign.spark.iqmulus.RowOutputStream

class PlyOutputWriter(
  name: String,
  context: TaskAttemptContext,
  dataSchema: StructType,
  element: String,
  littleEndian: Boolean
)
    extends OutputWriter {

  private val file = {
    val path = getDefaultWorkFile(s".ply.$element")
    val fs = path.getFileSystem(context.getConfiguration)
    fs.create(path)
  }

  private var count = 0L

  // strip out ids
  private val schema = StructType(dataSchema.filterNot { Seq("fid", "pid") contains _.name })

  private val recordWriter = new RowOutputStream(new DataOutputStream(file), littleEndian, schema, dataSchema)

  def getDefaultWorkFile(extension: String): Path = {
    val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
    val taskAttemptId: TaskAttemptID = context.getTaskAttemptID
    val split = taskAttemptId.getTaskID.getId
    new Path(name, f"$split%05d-$uniqueWriteJobId$extension")
  }

  override def write(row: Row): Unit = {
    recordWriter.write(row)
    count += 1
  }

  override def close(): Unit = {
    recordWriter.close

    // write header
    val path = getDefaultWorkFile(".ply.header")
    val fs = path.getFileSystem(context.getConfiguration)
    val dos = new java.io.DataOutputStream(fs.create(path))
    val header = new PlyHeader(path.toString, littleEndian, Map(element -> ((count, schema))))
    header.write(dos)
    dos.close
  }
} 
Example 179
Source File: PerTestSparkSession.scala    From Spark-RSVD   with Apache License 2.0 5 votes
package com.criteo.rsvd

import java.io.File
import java.nio.file.{Files, Path}
import java.util.concurrent.locks.ReentrantLock

import org.apache.commons.io.FileUtils
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.scalatest.{BeforeAndAfterEach, Suite}

import scala.reflect.ClassTag
import scala.util.control.NonFatal

object LocalSparkSession {
  private[this] val lock = new ReentrantLock()

  def acquire(): Unit = lock.lock()

  def release(): Unit = lock.unlock()

  def builder: SparkSession.Builder = {
    SparkSession
      .builder()
      .master("local[*]")
      .appName("test")
      .config("spark.ui.enabled", false)
  }
}

// Note: the original trait declaration and its session/checkpoint state are truncated in the source;
// a minimal reconstruction is assumed here, based on the members used below.
trait PerTestSparkSession extends BeforeAndAfterEach { self: Suite =>

  private var currentSession: Option[SparkSession] = None
  private var checkpointDir: Option[Path] = None

  def sc: SparkContext = getOrCreateSession.sparkContext

  def sqlContext: SQLContext = getOrCreateSession.sqlContext

  def sparkConf: Map[String, Any] = Map()

  def toRDD[T: ClassTag](input: Seq[T]): RDD[T] = sc.parallelize(input)

  def toArray[T](input: RDD[T]): Array[T] = input.collect()

  protected def closeSession() = {
    currentSession.foreach(_.stop())
    currentSession = None
    try {
      checkpointDir.foreach(path =>
        FileUtils.deleteDirectory(new File(path.toString)))
    } catch {
      case NonFatal(_) =>
    }
    checkpointDir = None
    LocalSparkSession.release()
  }

  private def getOrCreateSession = synchronized {
    if (currentSession.isEmpty) {
      val builder = LocalSparkSession.builder
      for ((key, value) <- sparkConf) {
        builder.config(key, value.toString)
      }
      currentSession = Some(builder.getOrCreate())
      checkpointDir =
        Some(Files.createTempDirectory("spark-unit-test-checkpoint-"))
      currentSession.get.sparkContext
        .setCheckpointDir(checkpointDir.get.toString)
        currentSession.get.sparkContext.setLogLevel("WARN")
    }
    currentSession.get
  }

  override def beforeEach(): Unit = {
    LocalSparkSession.acquire()
    super.beforeEach()
  }

  override def afterEach(): Unit = {
    try {
      super.afterEach()
    } finally {
      closeSession()
    }
  }
} 
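A short sketch of a suite mixing in the trait above; the suite name, configuration value and test body are illustrative only.

// Illustrative suite: each test gets a fresh local SparkSession with the extra config applied.
import org.scalatest.FunSuite

class RoundTripSpec extends FunSuite with PerTestSparkSession {
  override def sparkConf: Map[String, Any] = Map("spark.sql.shuffle.partitions" -> 2)

  test("toRDD/toArray round-trips a sequence") {
    assert(toArray(toRDD(Seq(1, 2, 3))).toSeq == Seq(1, 2, 3))
  }
}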
Example 180
Source File: Source.scala    From modelmatrix   with Apache License 2.0 5 votes
package com.collective.modelmatrix.cli

import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.util.{Failure, Success, Try}

sealed trait Source {
  def asDataFrame(implicit sqlContext: SQLContext): DataFrame
}

object Source {
  private val hive = "hive://(.*)".r
  private val parquet = "parquet://(.*)".r

  def validate(source: String): Either[String, Unit] = {
    Try(apply(source)) match {
      case Success(s) => Right(())
      case Failure(err) => Left(s"Unsupported source type: $source")
    }
  }

  def apply(source: String): Source = source match {
    case hive(table) => HiveSource(table)
    case parquet(path) => ParquetSource(path)
  }
}

object NoSource extends Source {
  def asDataFrame(implicit sqlContext: SQLContext): DataFrame = {
    sys.error(s"Source is not defined")
  }

  override def toString: String = "Source is not defined"
}

case class HiveSource(
  tableName: String
) extends Source {

  def asDataFrame(implicit sqlContext: SQLContext): DataFrame = {
    sqlContext.sql(s"SELECT * FROM $tableName")
  }

  override def toString: String = {
    s"Hive table: $tableName"
  }

}

case class ParquetSource(
  path: String
) extends Source {

  def asDataFrame(implicit sqlContext: SQLContext): DataFrame = {
    sqlContext.parquetFile(path)
  }

  override def toString: String = {
    s"Parquet: $path"
  }

} 
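Usage sketch for the URI-style factory above, assuming an implicit SQLContext is in scope; the table and path values are placeholders.

// Hypothetical usage; the hive:// and parquet:// prefixes select the source type.
implicit val sqlContext: SQLContext = new SQLContext(sc) // assumes an existing SparkContext `sc`
val features = Source("hive://mm.features").asDataFrame
val sample = Source("parquet:///data/model-matrix/sample").asDataFrame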
Example 181
Source File: Sink.scala    From modelmatrix   with Apache License 2.0 5 votes
package com.collective.modelmatrix.cli

import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}

import scala.util.{Failure, Success, Try}

sealed trait Sink {
  def saveDataFrame(df: DataFrame)(implicit sqlContext: SQLContext): Unit
}

object Sink {
  private val hive = "hive://(.*)".r
  private val parquet = "parquet://(.*)".r

  def validate(sink: String): Either[String, Unit] = {
    Try(apply(sink)) match {
      case Success(s) => Right(())
      case Failure(err) => Left(s"Unsupported sink type: $sink")
    }
  }

  def apply(sink: String): Sink = sink match {
    case hive(table) => HiveSink(table)
    case parquet(path) => ParquetSink(path)
  }
}

object NoSink extends Sink {
  def saveDataFrame(df: DataFrame)(implicit sqlContext: SQLContext): Unit = {
    sys.error(s"Sink is not defined")
  }

  override def toString: String = "Sink is not defined"
}

case class HiveSink(
  tableName: String
) extends Sink {

  def saveDataFrame(df: DataFrame)(implicit sqlContext: SQLContext): Unit = {
    df.saveAsTable(tableName, SaveMode.Overwrite)
  }

  override def toString: String =
    s"Hive table: $tableName"
}

case class ParquetSink(
  path: String
) extends Sink {

  def saveDataFrame(df: DataFrame)(implicit sqlContext: SQLContext): Unit = {
    df.saveAsParquetFile(path)
  }

  override def toString: String =
    s"Parquet: $path"
} 
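The mirror-image usage for sinks, again with placeholder URIs and assuming a DataFrame named scoredDf has already been computed.

// Hypothetical usage of the Sink factory above.
implicit val sqlContext: SQLContext = new SQLContext(sc) // assumes an existing SparkContext `sc`
Sink("hive://mm.scored_events").saveDataFrame(scoredDf)
Sink("parquet:///data/model-matrix/scored").saveDataFrame(scoredDf)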
Example 182
Source File: Transformers.scala    From modelmatrix   with Apache License 2.0 5 votes
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.ModelFeature
import org.apache.spark.sql.{SQLContext, DataFrame}

import scalaz._

trait Transformers {

  protected class Transformers(input: DataFrame @@ Transformer.Features)(implicit sqlContext: SQLContext) {

    val identity = new IdentityTransformer(input)
    val top = new TopTransformer(input)
    val index = new IndexTransformer(input)
    val bins = new BinsTransformer(input)

    private val unknownFeature: PartialFunction[ModelFeature, FeatureTransformationError \/ TypedModelFeature] = {
      case feature => sys.error(s"Feature can't be validated by any of transformers: $feature")
    }

    def validate(feature: ModelFeature): FeatureTransformationError \/ TypedModelFeature =
      (identity.validate orElse
        top.validate orElse
        index.validate orElse
        bins.validate orElse
        unknownFeature
        )(feature)

  }
} 
Example 183
Source File: Main.scala    From spark-gdb   with Apache License 2.0 5 votes
package com.esri.app

import com.esri.core.geometry.Polyline
import com.esri.udt.{PointType, PolylineType}
import org.apache.spark.sql.{SQLContext, SaveMode}
import org.apache.spark.{Logging, SparkConf, SparkContext}

// Note: the original object/main declaration, argument handling and SparkContext creation are
// truncated in the source; a minimal reconstruction is assumed here.
object Main extends App with Logging {

  val path = args(0)
  val name = args(1)
  val sc = new SparkContext(new SparkConf().setAppName("spark-gdb").setMaster("local[*]"))

  try {
    val sqlContext = new SQLContext(sc)
    val df = sqlContext.read.format("com.esri.gdb")
      .option("path", path)
      .option("name", name)
      .option("numPartitions", "1")
      .load()
    df.printSchema()
    df.registerTempTable(name)
    sqlContext.udf.register("getX", (point: PointType) => point.x)
    sqlContext.udf.register("getY", (point: PointType) => point.y)
    sqlContext.udf.register("line", (point: PointType) => PolylineType({
      val polyline = new Polyline()
      polyline.startPath(point.x - 2, point.y - 2)
      polyline.lineTo(point.x + 2, point.y + 2)
      polyline
    }
    ))
    sqlContext.sql(s"select line(Shape),getX(Shape)-2 as x from $name")
      .write
      .mode(SaveMode.Overwrite)
      .format("json")
      .save(s"/tmp/$name.json")
  } finally {
    sc.stop()
  }

} 
Example 184
Source File: package.scala    From spark-gdb   with Apache License 2.0 5 votes
package com.esri

import java.nio.ByteBuffer

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext


package object gdb {

  implicit class ByteBufferImplicits(byteBuffer: ByteBuffer) {

    implicit def getVarUInt() = {
      var shift = 7
      var b: Long = byteBuffer.get
      var ret = b & 0x7FL
      var old = ret
      while ((b & 0x80L) != 0) {
        b = byteBuffer.get
        ret = ((b & 0x7FL) << shift) | old
        old = ret
        shift += 7
      }
      ret
    }

    implicit def getVarInt() = {
      var shift = 7
      var b: Long = byteBuffer.get
      val isNeg = (b & 0x40L) != 0
      var ret = b & 0x3FL
      var old = ret
      while ((b & 0x80L) != 0) {
        b = byteBuffer.get
        ret = ((b & 0x7FL) << (shift - 1)) | old
        old = ret
        shift += 7
      }
      if (isNeg) -ret else ret
    }
  }

  implicit class SparkContextImplicits(sc: SparkContext) {
    implicit def gdbFile(path: String, name: String, numPartitions: Int = 8) = {
      GDBRDD(sc, path, name, numPartitions)
    }
  }

  implicit class SQLContextImplicits(sqlContext: SQLContext) extends Serializable {
    implicit def gdbFile(path: String, name: String, numPartitions: Int = 8) = {
      sqlContext.baseRelationToDataFrame(GDBRelation(path, name, numPartitions)(sqlContext))
    }
  }

} 
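With these implicits imported, a file geodatabase table can be pulled straight into a DataFrame; a short sketch follows, where the .gdb path and table name are placeholders.

// Usage sketch for the SQLContext implicit above.
import com.esri.gdb._
import org.apache.spark.sql.SQLContext

val sqlContext = new SQLContext(sc) // assumes an existing SparkContext `sc`
val df = sqlContext.gdbFile("/data/Sample.gdb", "Points", numPartitions = 2)
df.registerTempTable("points")
sqlContext.sql("select count(*) from points").show()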
Example 185
Source File: GDBRelation.scala    From spark-gdb   with Apache License 2.0 5 votes
package com.esri.gdb

import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SQLContext}


case class GDBRelation(gdbPath: String, gdbName: String, numPartition: Int)
                      (@transient val sqlContext: SQLContext)
  extends BaseRelation with Logging with TableScan {

  override val schema = inferSchema()

  private def inferSchema() = {
    val sc = sqlContext.sparkContext
    GDBTable.findTable(gdbPath, gdbName, sc.hadoopConfiguration) match {
      case Some(catTab) => {
        val table = GDBTable(gdbPath, catTab.hexName, sc.hadoopConfiguration)
        try {
          table.schema()
        } finally {
          table.close()
        }
      }
      case _ => {
        log.error(s"Cannot find '$gdbName' in $gdbPath, creating an empty schema !")
        StructType(Seq.empty[StructField])
      }
    }
  }

  override def buildScan(): RDD[Row] = {
    GDBRDD(sqlContext.sparkContext, gdbPath, gdbName, numPartition)
  }
} 
Example 186
Source File: DefaultSource.scala    From spark-gdb   with Apache License 2.0 5 votes
package com.esri.gdb

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

// Note: the original class header (and a schema-less createRelation overload) is truncated in the
// source; a minimal declaration is assumed here.
class DefaultSource extends SchemaRelationProvider {

  override def createRelation(sqlContext: SQLContext,
                              parameters: Map[String, String],
                              schema: StructType
                             ): BaseRelation = {
    val path = parameters.getOrElse("path", sys.error("Parameter 'path' must be defined."))
    val name = parameters.getOrElse("name", sys.error("Parameter 'name' must be defined."))
    val numPartitions = parameters.getOrElse("numPartitions", "8").toInt
    GDBRelation(path, name, numPartitions)(sqlContext)
  }
} 
Example 187
Source File: DefaultSource.scala    From spark-iqmulus   with Apache License 2.0 5 votes
package fr.ign.spark.iqmulus.ply

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{ HadoopFsRelation, HadoopFsRelationProvider }
import org.apache.spark.sql.types.StructType

class DefaultSource extends HadoopFsRelationProvider {

  // override def shortName(): String = "ply"

  override def createRelation(
    sqlContext: SQLContext,
    paths: Array[String],
    dataSchema: Option[StructType],
    partitionColumns: Option[StructType],
    parameters: Map[String, String]
  ): HadoopFsRelation = {
    new PlyRelation(paths, dataSchema, partitionColumns, parameters)(sqlContext)
  }
} 
Example 188
package com.example.ZeepleinAndSpark

import org.apache.spark.sql.SparkSession
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

object BankDatawithZepplein {

  case class Bank(age: Integer, job: String, marital: String, education: String, balance: Integer)

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession
      .builder()
      .appName("TrainSplitOCR")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "C:/Exp/").
      getOrCreate();
    val bankText = spark.sparkContext.textFile("data/bank-full.csv")

    // split each line, filter out header (starts with "age"), and map it into Bank case class
    val bank = bankText.map(s => s.split(";")).filter(s => (s.size) > 5).filter(s => s(0) != "\"age\"").map(
      s => Bank(s(0).toInt,
        s(1).replaceAll("\"", ""),
        s(2).replaceAll("\"", ""),
        s(3).replaceAll("\"", ""),
        s(5).replaceAll("\"", "").toInt))

    val sqlContext = new SQLContext(spark.sparkContext)

    import sqlContext.implicits._
    import sqlContext._
    // convert to DataFrame and create temporal table
    val newDF = bank.toDF()
    newDF.show()
    newDF.createOrReplaceTempView("bank")

    spark.sql("select age, count(1) from bank where age <= 50 group by age order by age").show()

    spark.sql("select age, count(1) from bank where age <= 30 group by age order by age").show()
    spark.sql("select max(balance) as MaxBalance from bank where age <= 30 group by age order by MaxBalance DESC").show()

  }
} 
Example 189
Source File: TwitterBatchTimely.scala    From Mastering-Spark-for-Data-Science   with MIT License 5 votes
package io.gzet.timeseries

import java.sql.Timestamp

import com.cloudera.sparkts.{DateTimeIndex, TimeSeriesRDD}
import io.gzet.timeseries.timely.MetricImplicits._
import io.gzet.timeseries.timely.TimelyImplicits._
import io.gzet.timeseries.twitter.Twitter._
import io.gzet.utils.spark.accumulo.AccumuloConfig
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import org.joda.time.{DateTime, Minutes, Period}

object TwitterBatchTimely extends SimpleConfig {

  case class Observation(
                          hashtag: String,
                          time: Timestamp,
                          count: Double
                        )

  def main(args: Array[String]) = {

    val sparkConf = new SparkConf().setAppName("Twitter Extractor")
    val sc = new SparkContext(sparkConf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val twitterJsonRDD = sc.textFile("file:///Users/antoine/CHAPTER/twitter-trump", 500)
    val tweetRDD = twitterJsonRDD mapPartitions analyzeJson cache()

    // Publish metrics to Timely
    tweetRDD.count()
    tweetRDD.countByState.publish()
    tweetRDD.sentimentByState.publish()

    // Read metrics from Timely
    val conf = AccumuloConfig("GZET", "alice", "alice", "localhost:2181")
    val metricsRDD = sc.timely(conf, Some("io.gzet.count"))

    val minDate = metricsRDD.map(_.time).min()
    val maxDate = metricsRDD.map(_.time).max()

    class TwitterFrequency(val minutes: Int) extends com.cloudera.sparkts.PeriodFrequency(Period.minutes(minutes)) {
      def difference(dt1: DateTime, dt2: DateTime): Int = Minutes.minutesBetween(dt1, dt2).getMinutes / minutes
      override def toString: String = s"minutes $minutes"
    }

    val dtIndex = DateTimeIndex.uniform(minDate, maxDate, new TwitterFrequency(1))

    val metricsDF = metricsRDD.filter({
      metric =>
        metric.tags.keys.toSet.contains("tag")
    }).flatMap({
      metric =>
        metric.tags map {
          case (k, v) =>
            ((v, roundFloorMinute(metric.time, 1)), metric.value)
        }
    }).reduceByKey(_+_).map({
      case ((metric, time), sentiment) =>
        Observation(metric, new Timestamp(time), sentiment)
    }).toDF()

    val tsRDD = TimeSeriesRDD.timeSeriesRDDFromObservations(dtIndex, metricsDF, "time", "hashtag", "count").filter(_._2.toArray.exists(!_.isNaN))

  }

  def roundFloorMinute(time: Long, windowMinutes: Int) = {
    val dt = new DateTime(time)
    dt.withMinuteOfHour((dt.getMinuteOfHour / windowMinutes) * windowMinutes).minuteOfDay().roundFloorCopy().toDate.getTime
  }

} 
Example 190
Source File: StackBootstraping.scala    From Mastering-Spark-for-Data-Science   with MIT License 5 votes
package io.gzet.tagging.stackoverflow


import io.gzet.tagging.classifier.Classifier
import io.gzet.tagging.html.HtmlHandler
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{SparkSession, DataFrame, SQLContext}

import scala.collection.mutable
import scala.xml.{Elem, XML}

object StackBootstraping {

  def parse(spark: SparkSession, posts: RDD[String]): DataFrame = {

    import spark.sqlContext.implicits._
    posts filter { line =>
      line.contains("row Id")
    } map { line =>
      val xml = XML.loadString(line)
      (getBody(xml), getTags(xml))
    } filter { case (body, tags) =>
      body.isDefined && tags.isDefined
    } flatMap  { case (body, tags) =>
      tags.get.map(tag => (body.get, tag))
    } toDF("body", "tag")
  }

  private def getBody(xml: Elem): Option[String] = {
    val bodyAttr = xml.attribute("Body")
    if (bodyAttr.isDefined) {
      val html = bodyAttr.get.head.text
      val htmlHandler = new HtmlHandler()
      val content = htmlHandler.parseHtml(html)
      if (content.isDefined) {
        return content.get.body
      }
    }
    None: Option[String]
  }

  private def getTags(xml: Elem): Option[Array[String]] = {
    val tagsAttr = xml.attribute("Tags")
    if (tagsAttr.isDefined) {
      val tagsText = tagsAttr.get.head.text
      val tags = tagsText
        .replaceAll("<", "")
        .replaceAll(">", ",")
        .split(",")
      return Some(tags)
    }
    None: Option[Array[String]]
  }

  def bootstrapNaiveBayes(df: DataFrame, vectorSize: Option[Int]) = {
    val labeledText = df.rdd map { row =>
      val body = row.getString(0)
      val labels = row.getAs[mutable.WrappedArray[String]](1)
      (body, labels.toArray)
    }
    Classifier.train(labeledText)
  }

} 
Example 191
Source File: Neo4jJavaIntegration.scala    From neo4j-spark-connector   with Apache License 2.0 5 votes
package org.neo4j.spark

import java.util

import org.apache.spark.SparkContext
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.neo4j.spark.dataframe.Neo4jDataFrame
import org.neo4j.spark.rdd.{Neo4jRowRDD, Neo4jTupleRDD}

import scala.collection.JavaConverters._


object Neo4jJavaIntegration {
  def rowRDD(sc: SparkContext, query: String, parameters: java.util.Map[String, AnyRef]) =
    new Neo4jRowRDD(sc, query, if (parameters == null) Seq.empty else parameters.asScala.toSeq).toJavaRDD()

  def tupleRDD(sc: SparkContext, query: String, parameters: java.util.Map[String, AnyRef]): JavaRDD[util.Map[String, AnyRef]] = {
    val params = if (parameters == null) Seq.empty else parameters.asScala.toSeq
    Neo4jTupleRDD(sc, query, params)
      .map((t) => new util.LinkedHashMap[String, AnyRef](t.toMap.asJava).asInstanceOf[util.Map[String, AnyRef]])
      .toJavaRDD()
  }

  def dataFrame(sqlContext: SQLContext, query: String, parameters: java.util.Map[String, AnyRef], schemaInfo: util.Map[String, String]) = {
    Neo4jDataFrame(sqlContext, query, parameters.asScala.toSeq, schemaInfo.asScala.toSeq: _*)
  }
} 
Example 192
Source File: Neo4jGraphFrame.scala    From neo4j-spark-connector   with Apache License 2.0 5 votes
package org.neo4j.spark.dataframe

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.sql.SQLContext
import org.neo4j.spark.Neo4jGraph

import org.neo4j.spark.cypher.CypherHelpers._


object Neo4jGraphFrame {

  def apply(sqlContext: SQLContext, src: (String, String), edge: (String, String), dst: (String, String)) = {
    def nodeStmt(s: (String, String)) = s"MATCH (n:${s._1.quote}) RETURN id(n) as id, n.${s._2.quote} as prop"

    val edgeProp = if (edge._2 == null) "" else s", r.${edge._2.quote} as prop"
    val edgeStmt = s"MATCH (n:${src._1.quote})-[r:${edge._1.quote}]->(m:${dst._1.quote}) RETURN id(n) as src, id(m) as dst" + edgeProp

    val vertices1 = Neo4jDataFrame(sqlContext, nodeStmt(src), Seq.empty, ("id", "integer"), ("prop", "string"))
    val vertices2 = Neo4jDataFrame(sqlContext, nodeStmt(dst), Seq.empty, ("id", "integer"), ("prop", "string"))
    val schema = Seq(("src", "integer"), ("dst", "integer")) ++ (if (edge._2 != null) Some(("prop", "string")) else None)
    val edges = Neo4jDataFrame(sqlContext, edgeStmt, Seq.empty, schema: _*)

    org.graphframes.GraphFrame(vertices1.union(vertices2).distinct(), edges)
  }

  def fromGraphX(sc: SparkContext, label1: String, rels: Seq[String], label2: String) = {
    val g: Graph[Any, Int] = Neo4jGraph.loadGraph(sc, label1, rels, label2)
    org.graphframes.GraphFrame.fromGraphX(g)
  }

  def fromEdges(sqlContext: SQLContext, label1: String, rels: Seq[String], label2: String) = {
    val relTypes = rels.map(_.quote).mkString("|")
    val edgeStmt = s"MATCH (n:${label1.quote})-[r:$relTypes]->(m:${label2.quote}) RETURN id(n) as src, id(m) as dst"
    val edges = Neo4jDataFrame(sqlContext, edgeStmt, Seq.empty, ("src", "integer"), ("dst", "integer"))
    org.graphframes.GraphFrame.fromEdges(edges)
  }
} 
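
A hedged usage sketch follows; the node labels, relationship type, property names, and bolt settings are assumptions, and the graphframes package must be on the classpath.

// A hedged usage sketch (labels, relationship type, and property names are assumptions).
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import org.neo4j.spark.dataframe.Neo4jGraphFrame

object Neo4jGraphFrameSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("Neo4jGraphFrameSketch")
      .set("spark.neo4j.bolt.url", "bolt://localhost:7687") // assumed connector setting
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Vertices carry the "name" property; edges carry the "since" property of KNOWS.
    val gf = Neo4jGraphFrame(sqlContext, ("Person", "name"), ("KNOWS", "since"), ("Person", "name"))
    println(s"vertices=${gf.vertices.count()}, edges=${gf.edges.count()}")

    sc.stop()
  }
}
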
Example 193
Source File: HiveAcidRelation.scala    From spark-acid   with Apache License 2.0 5 votes vote down vote up
package com.qubole.spark.hiveacid.datasource

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Column, DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.sources.{BaseRelation, Filter, InsertableRelation, PrunedFilteredScan}
import org.apache.spark.sql.types._
import com.qubole.spark.hiveacid.{HiveAcidErrors, HiveAcidTable, SparkAcidConf}
import com.qubole.spark.hiveacid.hive.HiveAcidMetadata
import com.qubole.spark.hiveacid.merge.{MergeWhenClause, MergeWhenNotInsert}
import org.apache.spark.sql.catalyst.AliasIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import collection.JavaConversions._


case class HiveAcidRelation(sparkSession: SparkSession,
                            fullyQualifiedTableName: String,
                            parameters: Map[String, String])
    extends BaseRelation
    with InsertableRelation
    with PrunedFilteredScan
    with Logging {

  private val hiveAcidMetadata: HiveAcidMetadata = HiveAcidMetadata.fromSparkSession(
    sparkSession,
    fullyQualifiedTableName
  )
  private val hiveAcidTable: HiveAcidTable = new HiveAcidTable(sparkSession,
    hiveAcidMetadata, parameters)

  private val readOptions = SparkAcidConf(sparkSession, parameters)

  override def sqlContext: SQLContext = sparkSession.sqlContext

  override val schema: StructType = if (readOptions.includeRowIds) {
    hiveAcidMetadata.tableSchemaWithRowId
  } else {
    hiveAcidMetadata.tableSchema
  }

  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // Handles both SQL INSERT INTO and INSERT OVERWRITE.
    if (overwrite) {
      hiveAcidTable.insertOverwrite(data)
    } else {
      hiveAcidTable.insertInto(data)
    }
  }

  def update(condition: Option[Column], newValues: Map[String, Column]): Unit = {
    hiveAcidTable.update(condition, newValues)
  }

  def delete(condition: Column): Unit = {
    hiveAcidTable.delete(condition)
  }

  override def sizeInBytes: Long = {
    val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor
    (sparkSession.sessionState.conf.defaultSizeInBytes * compressionFactor).toLong
  }

  def merge(sourceDf: DataFrame,
            mergeExpression: Expression,
            matchedClause: Seq[MergeWhenClause],
            notMatched: Option[MergeWhenNotInsert],
            sourceAlias: Option[AliasIdentifier],
            targetAlias: Option[AliasIdentifier]): Unit = {
    hiveAcidTable.merge(sourceDf, mergeExpression, matchedClause,
      notMatched, sourceAlias, targetAlias)
  }

  def getHiveAcidTable(): HiveAcidTable = {
    hiveAcidTable
  }

  // FIXME: should this be true or false? The recommendation
  //  seems to be to leave it as true.
  override val needConversion: Boolean = false

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    val readOptions = SparkAcidConf(sparkSession, parameters)
    // sql "select *"
    hiveAcidTable.getRdd(requiredColumns, filters, readOptions)
  }
} 
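
Since HiveAcidRelation is a plain case class, it can also be exercised directly. The sketch below is a minimal, hedged example under assumed names (a transactional table default.acid_events with id and status columns) and assumes a Hive metastore with that ACID table already set up.

// A hedged sketch of driving HiveAcidRelation directly
// (table name, column names, and metastore setup are assumptions).
import org.apache.spark.sql.{functions => F, SparkSession}
import com.qubole.spark.hiveacid.datasource.HiveAcidRelation

object HiveAcidRelationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("HiveAcidRelationSketch")
      .enableHiveSupport()
      .getOrCreate()

    val relation = HiveAcidRelation(spark, "default.acid_events", Map.empty)

    // Append new rows (INSERT INTO); passing overwrite = true would be INSERT OVERWRITE.
    val newRows = spark.sql("SELECT 1 AS id, 'open' AS status")
    relation.insert(newRows, overwrite = false)

    // UPDATE ... SET status = 'closed' WHERE id = 1
    relation.update(Some(F.col("id") === 1), Map("status" -> F.lit("closed")))

    // DELETE WHERE id = 1
    relation.delete(F.col("id") === 1)

    spark.stop()
  }
}
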
Example 194
Source File: MLlibTestSparkContext.scala    From spark-ranking-algorithms   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.mllib.util

import org.scalatest.{BeforeAndAfterAll, Suite}

import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}

trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _

  override def beforeAll() {
    super.beforeAll()
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("MLlibUnitTest")
    sc = new SparkContext(conf)
    SQLContext.clearActive()
    sqlContext = new SQLContext(sc)
    SQLContext.setActive(sqlContext)
  }

  override def afterAll() {
    try {
      sqlContext = null
      SQLContext.clearActive()
      if (sc != null) {
        sc.stop()
      }
      sc = null
    } finally {
      super.afterAll()
    }
  }
} 
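
A hedged sketch of a suite that mixes in the trait follows; the suite name, test data, and the use of FunSuite are illustrative assumptions.

// A hedged sketch of a test mixing in the trait (suite name and data are assumptions).
package org.apache.spark.mllib.util

import org.scalatest.FunSuite

class MLlibTestSparkContextExampleSuite extends FunSuite with MLlibTestSparkContext {

  test("sqlContext is available inside tests") {
    // sc and sqlContext are initialized by beforeAll() of the mixed-in trait.
    val df = sqlContext.createDataFrame(Seq((1, "a"), (2, "b"))).toDF("id", "value")
    assert(df.count() == 2)
  }
}
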
Example 195
Source File: PulsarRelation.scala    From pulsar-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.pulsar

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType


private[pulsar] class PulsarRelation(
    override val sqlContext: SQLContext,
    override val schema: StructType,
    schemaInfo: SchemaInfoSerializable,
    adminUrl: String,
    clientConf: ju.Map[String, Object],
    readerConf: ju.Map[String, Object],
    startingOffset: SpecificPulsarOffset,
    endingOffset: SpecificPulsarOffset,
    pollTimeoutMs: Int,
    failOnDataLoss: Boolean,
    subscriptionNamePrefix: String,
    jsonOptions: JSONOptionsInRead)
    extends BaseRelation
    with TableScan
    with Logging {

  import PulsarSourceUtils._

  val reportDataLoss = reportDataLossFunc(failOnDataLoss)

  override def buildScan(): RDD[Row] = {
    val fromTopicOffsets = startingOffset.topicOffsets
    val endTopicOffsets = endingOffset.topicOffsets

    if (fromTopicOffsets.keySet != endTopicOffsets.keySet) {
      val fromTopics = fromTopicOffsets.keySet.toList.sorted.mkString(",")
      val endTopics = endTopicOffsets.keySet.toList.sorted.mkString(",")
      throw new IllegalStateException(
        "different topics " +
          s"for starting offsets topics[${fromTopics}] and " +
          s"ending offsets topics[${endTopics}]")
    }

    val offsetRanges = endTopicOffsets.keySet
      .map { tp =>
        val fromOffset = fromTopicOffsets.getOrElse(tp, {
          // this shouldn't happen since we had checked it
          throw new IllegalStateException(s"$tp doesn't have a from offset")
        })
        val untilOffset = endTopicOffsets(tp)
        PulsarOffsetRange(tp, fromOffset, untilOffset, None)
      }
      .filter { range =>
        if (range.untilOffset.compareTo(range.fromOffset) < 0) {
          reportDataLoss(
            s"${range.topic}'s offset was changed " +
              s"from ${range.fromOffset} to ${range.untilOffset}, " +
              "some data might has been missed")
          false
        } else {
          true
        }
      }
      .toSeq

    val rdd = new PulsarSourceRDD4Batch(
      sqlContext.sparkContext,
      schemaInfo,
      adminUrl,
      clientConf,
      readerConf,
      offsetRanges,
      pollTimeoutMs,
      failOnDataLoss,
      subscriptionNamePrefix,
      jsonOptions
    )
    sqlContext.internalCreateDataFrame(rdd.setName("pulsar"), schema).rdd
  }
} 
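
PulsarRelation itself is private[pulsar]; batch reads normally go through the DataFrame reader, which resolves to this relation. The sketch below is hedged: the service/admin URLs, topic name, and option keys follow the connector's documented usage and are assumptions here, and a running Pulsar cluster is required.

// A hedged sketch of a batch read that resolves to PulsarRelation under the hood
// (URLs, topic name, and option keys are assumptions based on the connector's docs).
import org.apache.spark.sql.SparkSession

object PulsarBatchReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("PulsarBatchReadSketch")
      .getOrCreate()

    val df = spark.read
      .format("pulsar")
      .option("service.url", "pulsar://localhost:6650")
      .option("admin.url", "http://localhost:8080")
      .option("topic", "persistent://public/default/events")
      .option("startingOffsets", "earliest")
      .option("endingOffsets", "latest")
      .load()

    df.printSchema()
    println(s"rows: ${df.count()}")

    spark.stop()
  }
}
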
Example 196
Source File: InsertMysqlDemo.scala    From spark_mysql   with Apache License 2.0 5 votes vote down vote up
import java.sql.{Date, Timestamp}

import InsertMysqlDemo.CardMember
import org.apache.spark.sql.SQLContext
import org.apache.spark.{SparkConf, SparkContext}
import utils.MySQLUtils

/**
  * Created with IntelliJ IDEA.
  * Author: [email protected]
  * Description: Writes data from a DataFrame into MySQL.
  * Date: Created in 2018-11-17 12:39
  */
object InsertMysqlDemo {

  case class CardMember(m_id: String, card_type: String, expire: Timestamp, duration: Int, is_sale: Boolean, date: Date, user: Long, salary: Float)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName(getClass.getSimpleName).set("spark.testing.memory", "3147480000")
    val sparkContext = new SparkContext(conf)
    val sqlContext = new SQLContext(sparkContext)
    import sqlContext.implicits._
    val memberSeq = Seq(
      CardMember("member_2", "月卡", new Timestamp(System.currentTimeMillis()), 31, false, new Date(System.currentTimeMillis()), 123223, 0.32f),
      CardMember("member_1", "季卡", new Timestamp(System.currentTimeMillis()), 93, false, new Date(System.currentTimeMillis()), 124224, 0.362f)
    )
    val memberDF = memberSeq.toDF()
    MySQLUtils.saveDFtoDBCreateTableIfNotExist("member_test", memberDF)
    MySQLUtils.insertOrUpdateDFtoDBUsePool("member_test", memberDF, Array("user", "salary"))
    MySQLUtils.getDFFromMysql(sqlContext, "", null)


    sparkContext.stop()
  }
} 
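
The MySQLUtils helpers above are project-specific. For comparison, a hedged sketch using Spark's built-in JDBC writer follows; the JDBC URL, credentials, driver class, and table name are assumptions.

// A hedged sketch of writing a DataFrame with Spark's built-in JDBC writer
// (URL, credentials, and table name are assumptions; MySQLUtils is project-specific).
import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

object InsertMysqlJdbcSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("InsertMysqlJdbcSketch")
      .getOrCreate()
    import spark.implicits._

    val memberDF = Seq(
      ("member_1", 31, 0.32f),
      ("member_2", 93, 0.362f)
    ).toDF("m_id", "duration", "salary")

    val props = new Properties()
    props.put("user", "root")
    props.put("password", "secret")
    props.put("driver", "com.mysql.jdbc.Driver")

    // Appends rows, creating the table on first write if it does not exist.
    memberDF.write
      .mode(SaveMode.Append)
      .jdbc("jdbc:mysql://localhost:3306/test?useSSL=false", "member_test", props)

    spark.stop()
  }
}
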
Example 197
Source File: HBaseTestSource.scala    From hbase-connectors   with Apache License 2.0 5 votes vote down vote up
package org.apache.hadoop.hbase.spark

import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.spark.SparkEnv
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

class HBaseTestSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    DummyScan(
      parameters("cacheSize").toInt,
      parameters("batchNum").toInt,
      parameters("blockCacheingEnable").toBoolean,
      parameters("rowNum").toInt)(sqlContext)
  }
}

case class DummyScan(
     cacheSize: Int,
     batchNum: Int,
     blockCachingEnable: Boolean,
     rowNum: Int)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {
  private def sparkConf = SparkEnv.get.conf
  override def schema: StructType =
    StructType(StructField("i", IntegerType, nullable = false) :: Nil)

  override def buildScan(): RDD[Row] = sqlContext.sparkContext.parallelize(0 until rowNum)
    .map(Row(_))
    .map { x =>
      if (sparkConf.getInt(HBaseSparkConf.QUERY_BATCHSIZE,
          -1) != batchNum ||
        sparkConf.getInt(HBaseSparkConf.QUERY_CACHEDROWS,
          -1) != cacheSize ||
        sparkConf.getBoolean(HBaseSparkConf.QUERY_CACHEBLOCKS,
          false) != blockCachingEnable) {
        throw new Exception("HBase Spark configuration cannot be set properly")
      }
      x
  }
} 
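
A hedged sketch of reading through HBaseTestSource follows. The option values are illustrative; they must match the corresponding HBaseSparkConf settings on the SparkConf, otherwise DummyScan.buildScan throws.

// A hedged usage sketch (the option values are assumptions and must agree with
// the HBaseSparkConf settings placed on the SparkConf below).
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object HBaseTestSourceSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("HBaseTestSourceSketch")
      .set(HBaseSparkConf.QUERY_BATCHSIZE, "100")
      .set(HBaseSparkConf.QUERY_CACHEDROWS, "1000")
      .set(HBaseSparkConf.QUERY_CACHEBLOCKS, "true")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    val df = sqlContext.read
      .format("org.apache.hadoop.hbase.spark.HBaseTestSource")
      .option("cacheSize", "1000")
      .option("batchNum", "100")
      .option("blockCacheingEnable", "true")
      .option("rowNum", "10")
      .load()

    println(s"rows: ${df.count()}")
    sc.stop()
  }
}
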
Example 198
Source File: KuduSink.scala    From kafka-examples   with Apache License 2.0 5 votes vote down vote up
package com.cloudera.streaming.refapp.kudu

import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.slf4j.LoggerFactory

import scala.util.control.NonFatal

object KuduSink {
  def withDefaultContext(sqlContext: SQLContext, parameters: Map[String, String]) =
    new KuduSink(new KuduContext(parameters("kudu.master"), sqlContext.sparkContext), parameters)
}


class KuduSink(initKuduContext: => KuduContext, parameters: Map[String, String]) extends Sink {

  private val logger = LoggerFactory.getLogger(getClass)

  private var kuduContext = initKuduContext

  private val tablename = parameters("kudu.table")

  private val retries = parameters.getOrElse("retries", "1").toInt
  require(retries >= 0, "retries must be non-negative")

  logger.info(s"Created Kudu sink writing to table $tablename")

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    for (attempt <- 0 to retries) {
      try {
        kuduContext.upsertRows(data, tablename)
        return
      } catch {
        case NonFatal(e) =>
          if (attempt < retries) {
            logger.warn("Kudu upsert error, retrying...", e)
            kuduContext = initKuduContext
          }
          else {
            logger.error("Kudu upsert error, exhausted", e)
            throw e
          }
      }
    }
  }
} 
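
A hedged sketch of building the sink and pushing one micro-batch by hand follows; the Kudu master address, table name, and columns are assumptions, and the target Kudu table must already exist with a matching schema.

// A hedged usage sketch (Kudu master address, table name, and columns are assumptions).
import org.apache.spark.sql.SparkSession
import com.cloudera.streaming.refapp.kudu.KuduSink

object KuduSinkSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("KuduSinkSketch")
      .getOrCreate()
    import spark.implicits._

    val sink = KuduSink.withDefaultContext(
      spark.sqlContext,
      Map(
        "kudu.master" -> "kudu-master:7051",          // assumed address
        "kudu.table"  -> "impala::default.events",    // assumed existing table
        "retries"     -> "2"))

    // The target Kudu table must already exist with a matching schema.
    val batch = Seq((1L, "open"), (2L, "closed")).toDF("id", "status")
    sink.addBatch(batchId = 0L, data = batch)

    spark.stop()
  }
}
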
Example 199
Source File: SparkNRedshiftUtil.scala    From SqlShift   with MIT License 5 votes vote down vote up
package com.goibibo.sqlshift

import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.databricks.spark.redshift.RedshiftReaderM
import com.typesafe.config.Config
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, Suite}
import org.slf4j.{Logger, LoggerFactory}


trait SparkNRedshiftUtil extends BeforeAndAfterAll {
    self: Suite =>
    private val logger: Logger = LoggerFactory.getLogger(this.getClass)
    @transient private var _sc: SparkContext = _
    @transient private var _sqlContext: SQLContext = _

    def sc: SparkContext = _sc
    def sqlContext: SQLContext = _sqlContext

    private def getRedshiftConnection(config: Config): Connection = {
        val redshiftConf = config.getConfig("redshift")
        val connectionProps = new Properties()
        connectionProps.put("user", redshiftConf.getString("username"))
        connectionProps.put("password", redshiftConf.getString("password"))
        val jdbcUrl = s"jdbc:redshift://${redshiftConf.getString("hostname")}:${redshiftConf.getInt("portno")}/${redshiftConf.getString("database")}?useSSL=false"
        Class.forName("com.amazon.redshift.jdbc4.Driver")
        DriverManager.getConnection(jdbcUrl, connectionProps)
    }

    val getSparkContext: (SparkContext, SQLContext) = {
        val sparkConf: SparkConf = new SparkConf().setAppName("Full Dump Testing").setMaster("local")
        val sc: SparkContext = new SparkContext(sparkConf)
        val sqlContext: SQLContext = new SQLContext(sc)

        System.setProperty("com.amazonaws.services.s3.enableV4", "true")
        sc.hadoopConfiguration.set("fs.s3a.endpoint", "s3.ap-south-1.amazonaws.com")
        sc.hadoopConfiguration.set("fs.s3a.fast.upload", "true")
        (sc, sqlContext)
    }

    def readTableFromRedshift(config: Config, tableName: String): DataFrame = {
        val redshift: Config = config.getConfig("redshift")
        val options = Map("dbtable" -> tableName,
            "user" -> redshift.getString("username"),
            "password" -> redshift.getString("password"),
            "url" -> s"jdbc:redshift://${redshift.getString("hostname")}:${redshift.getInt("portno")}/${redshift.getString("database")}",
            "tempdir" -> config.getString("s3.location"),
            "aws_iam_role" -> config.getString("redshift.iamRole")
        )
        RedshiftReaderM.getDataFrameForConfig(options, sc, sqlContext)
    }

    def dropTableRedshift(config: Config, tables: String*): Unit = {
        logger.info("Droping table: {}", tables)
        val conn = getRedshiftConnection(config)
        val statement = conn.createStatement()
        try {
            val dropTableQuery = s"""DROP TABLE ${tables.mkString(",")}"""
            logger.info("Running query: {}", dropTableQuery)
            statement.executeUpdate(dropTableQuery)
        } finally {
            statement.close()
            conn.close()
        }
    }

    override protected def beforeAll(): Unit = {
        super.beforeAll()
        val (sc, sqlContext) = getSparkContext
        _sc = sc
        _sqlContext = sqlContext
    }

    override protected def afterAll(): Unit = {
        super.afterAll()
        _sc.stop()
    }
} 
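
A hedged sketch of a spec that mixes in the trait follows; the HOCON values are assumptions, the keys mirror what getRedshiftConnection and readTableFromRedshift look up, and a reachable Redshift cluster plus S3 credentials are required for the test to pass.

// A hedged sketch (config values are assumptions; the keys mirror the lookups above).
package com.goibibo.sqlshift

import com.typesafe.config.ConfigFactory
import org.scalatest.FlatSpec

class SparkNRedshiftUtilExampleSpec extends FlatSpec with SparkNRedshiftUtil {

  private val config = ConfigFactory.parseString(
    """
      |redshift {
      |  username = "test_user"
      |  password = "test_password"
      |  hostname = "redshift.example.com"
      |  portno = 5439
      |  database = "dev"
      |  iamRole = "arn:aws:iam::123456789012:role/redshift-copy"
      |}
      |s3.location = "s3a://test-bucket/temp/"
      |""".stripMargin)

  "readTableFromRedshift" should "load a table into a DataFrame" in {
    val df = readTableFromRedshift(config, "public.test_table")
    assert(df.columns.nonEmpty)
  }
}
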
Example 200
Source File: RedshiftReaderM.scala    From SqlShift   with MIT License 5 votes vote down vote up
package com.databricks.spark.redshift

import com.amazonaws.auth.AWSCredentials
import com.amazonaws.services.s3.AmazonS3Client
import org.apache.spark.SparkContext
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.{DataFrame, SQLContext}

object RedshiftReaderM {

    val endpoint = "s3.ap-south-1.amazonaws.com"

    def getS3Client(provider: AWSCredentials): AmazonS3Client = {
        val client = new AmazonS3Client(provider)
        client.setEndpoint(endpoint)
        client
    }

    def getDataFrameForConfig(configs: Map[String, String], sparkContext: SparkContext, sqlContext: SQLContext): DataFrame = {
        val source: DefaultSource = new DefaultSource(new JDBCWrapper(), getS3Client)
        val br: BaseRelation = source.createRelation(sqlContext, configs)
        sqlContext.baseRelationToDataFrame(br)
    }
}
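
A hedged usage sketch follows; the credentials, URLs, and table name are assumptions, and the option keys mirror those built in Example 199's readTableFromRedshift.

// A hedged usage sketch (credentials, URLs, and table names are assumptions).
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
import com.databricks.spark.redshift.RedshiftReaderM

object RedshiftReaderSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[*]").setAppName("RedshiftReaderSketch"))
    val sqlContext = new SQLContext(sc)

    System.setProperty("com.amazonaws.services.s3.enableV4", "true")
    sc.hadoopConfiguration.set("fs.s3a.endpoint", "s3.ap-south-1.amazonaws.com")

    val options = Map(
      "dbtable"      -> "public.test_table",
      "user"         -> "test_user",
      "password"     -> "test_password",
      "url"          -> "jdbc:redshift://redshift.example.com:5439/dev",
      "tempdir"      -> "s3a://test-bucket/temp/",
      "aws_iam_role" -> "arn:aws:iam::123456789012:role/redshift-copy")

    val df = RedshiftReaderM.getDataFrameForConfig(options, sc, sqlContext)
    df.printSchema()

    sc.stop()
  }
}
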