org.apache.spark.sql.execution.QueryExecution Scala Examples

The following examples show how to use org.apache.spark.sql.execution.QueryExecution. Each example is taken from an open-source project; the source file, project name, and license are given in the heading above the code.
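Before the examples, here is a minimal, self-contained sketch of how a QueryExecution is usually obtained and which phases it exposes; the object name QueryExecutionDemo and the temporary view t are illustrative and not taken from any of the projects below.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution

object QueryExecutionDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("QueryExecutionDemo").getOrCreate()
    import spark.implicits._

    Seq(("a", 10), ("a", 20), ("b", 5)).toDF("name", "value").createOrReplaceTempView("t")
    val df = spark.sql("select name, count(1) as cnt from t group by name")

    // Every Dataset/DataFrame carries a QueryExecution describing how it will run.
    val qe: QueryExecution = df.queryExecution

    // A QueryExecution can also be constructed directly from a logical plan,
    // as Examples 5, 6, 15 and 16 do.
    val qe2 = new QueryExecution(spark, qe.logical)

    // The phases used throughout the examples below:
    println(qe.analyzed.output)  // resolved output attributes (used by the SparkSQLDriver and writer examples)
    println(qe.optimizedPlan)    // logical plan after optimizer rules
    println(qe.sparkPlan)        // physical plan before preparation rules
    println(qe.executedPlan)     // final executable physical plan
    println(qe.toRdd.count())    // RDD[InternalRow] backing the query (used by the Kafka/Kinesis/EventHubs writers)

    spark.stop()
  }
}

Note that some entry points used in the examples, such as SparkSession.sessionState and Dataset.logicalPlan, are private[sql] in Spark 2.x, which is part of why several of the files below are placed under org.apache.spark.sql packages.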
Example 1
Source File: SparkSQLDriver.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{Arrays, ArrayList => JArrayList, List => JList}
import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext, SparkSession}
import org.apache.spark.sql.execution.QueryExecution


private[hive] class SparkSQLDriver(val sparkSession: SparkSession = SparkSQLEnv.sparkSession)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      sparkSession.sparkContext.setJobDescription(command)
      val execution = sparkSession.sessionState.executePlan(sparkSession.sql(command).logicalPlan)
      hiveResponse = execution.hiveResultString()
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 2
Source File: HBaseLocalClient.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.hbase.utilities

import java.io.File

import scala.collection.mutable.ArrayBuffer

import com.google.common.io.Files
import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.hbase.SparkHBaseConf
import org.apache.spark.sql.util._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

import com.paypal.gimel.common.catalog.Field
import com.paypal.gimel.hbase.DataSet

class HBaseLocalClient extends FunSuite with Matchers with BeforeAndAfterAll {

  var sparkSession : SparkSession = _
  var dataSet: DataSet = _
  val hbaseTestingUtility = new HBaseTestingUtility()
  val tableName = "test_table"
  val cfs = Array("personal", "professional")
  val columns = Array("id", "name", "age", "address", "company", "designation", "salary")
  val fields = columns.map(col => new Field(col))

  val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]

  protected override def beforeAll(): Unit = {
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit
    hbaseTestingUtility.startMiniCluster()
    SparkHBaseConf.conf = hbaseTestingUtility.getConfiguration
    createTable(tableName, cfs)
    val conf = new SparkConf
    conf.set(SparkHBaseConf.testConf, "true")
    sparkSession = SparkSession.builder()
      .master("local")
      .appName("HBase Test")
      .config(conf)
      .getOrCreate()

    val listener = new QueryExecutionListener {
      // Only test successful case here, so no need to implement `onFailure`
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        metrics += ((funcName, qe, duration))
      }
    }
    sparkSession.listenerManager.register(listener)
    sparkSession.sparkContext.setLogLevel("ERROR")
    dataSet = new DataSet(sparkSession)
  }

  protected override def afterAll(): Unit = {
    hbaseTestingUtility.shutdownMiniCluster()
    sparkSession.close()
  }

  def createTable(name: String, cfs: Array[String]) {
    val tName = Bytes.toBytes(name)
    val bcfs = cfs.map(Bytes.toBytes(_))
    try {
      hbaseTestingUtility.deleteTable(TableName.valueOf(tName))
    } catch {
      case _ : Throwable =>
        println("No table = " + name + " found")
    }
    hbaseTestingUtility.createMultiRegionTable(TableName.valueOf(tName), bcfs)
  }

  // Mocks data for testing
  def mockDataInDataFrame(numberOfRows: Int): DataFrame = {
    def stringed(n: Int) = s"""{"id": "$n","name": "MAC-$n", "address": "MAC-${n + 1}", "age": "${n + 1}", "company": "MAC-$n", "designation": "MAC-$n", "salary": "${n * 10000}" }"""
    val texts: Seq[String] = (1 to numberOfRows).map { x => stringed(x) }
    val rdd: RDD[String] = sparkSession.sparkContext.parallelize(texts)
    val dataFrame: DataFrame = sparkSession.read.json(rdd)
    dataFrame
  }
} 
Example 3
Source File: BigDatalogProgram.scala    From BigDatalog   with Apache License 2.0
package edu.ucla.cs.wis.bigdatalog.spark

import edu.ucla.cs.wis.bigdatalog.interpreter.OperatorProgram
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, Row}

class BigDatalogProgram(var bigDatalogContext: BigDatalogContext,
                        plan: LogicalPlan,
                        operatorProgram: OperatorProgram) {

  def toDF(): DataFrame = {
    new DataFrame(bigDatalogContext, plan)
  }
  
  def count(): Long = {
    toDF().count()
  }

  // use this method to produce an rdd containing the results for the program (i.e., it evaluates the program)
  def execute(): RDD[Row] = {
    toDF().rdd
  }

  override def toString(): String = {
    new QueryExecution(bigDatalogContext, plan).toString
  }
} 
Example 4
Source File: KinesisWriter.scala    From kinesis-sql   with Apache License 2.0
package org.apache.spark.sql.kinesis

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.util.Utils

private[kinesis] object KinesisWriter extends Logging {

  val DATA_ATTRIBUTE_NAME: String = "data"
  val PARTITION_KEY_ATTRIBUTE_NAME: String = "partitionKey"

  override def toString: String = "KinesisWriter"

  def write(sparkSession: SparkSession,
            queryExecution: QueryExecution,
            kinesisParameters: Map[String, String]): Unit = {
    val schema = queryExecution.analyzed.output

    SQLExecution.withNewExecutionId(sparkSession, queryExecution) {
      queryExecution.toRdd.foreachPartition { iter =>
        val writeTask = new KinesisWriteTask(kinesisParameters, schema)
        Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
          finallyBlock = writeTask.close())
      }
    }
  }
} 
Example 5
Source File: ReplaceGroup.scala    From starry   with Apache License 2.0
package com.github.passionke.replace

import com.github.passionke.starry.SparkPlanExecutor
import com.github.passionke.baseline.Dumy
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.Spark
import org.scalatest.FunSuite


class ReplaceGroup extends FunSuite {

  test("group by") {
    val sparkSession = Spark.sparkSession
    sparkSession.sparkContext.setLogLevel("WARN")
    import sparkSession.implicits._
    val dumys = Seq(Dumy("a", 10, "abc"), Dumy("a", 20, "ass"))
    dumys.toDF().createOrReplaceTempView("a")

    val df = sparkSession.sql(
      """
        |select name, count(1) as cnt
        |from a
        |group by name
      """.stripMargin)

    df.show()
    val sparkPlan = df.queryExecution.sparkPlan
    val logicalPlan = df.queryExecution.analyzed


    val dumy1 = Seq(Dumy("a", 1, "abc"), Dumy("a", 1, "ass"), Dumy("a", 2, "sf"))
    val data = dumy1.toDF().queryExecution.executedPlan.execute().collect()

    val newL = logicalPlan.transform({
      case SubqueryAlias(a, localRelation) if a.equals("a") =>
        SubqueryAlias(a, LocalRelation(localRelation.output, data))
    })

    val ns = sparkSession.newSession()
    val qe = new QueryExecution(ns, newL)
    val start = System.currentTimeMillis()
    val list = SparkPlanExecutor.exec(qe.sparkPlan, ns)
    assert(list.head.getLong(1).equals(3L))
    val end = System.currentTimeMillis()
    end - start
  }

} 
Example 6
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, here we do
      // the computation by constructing logical plan.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
} 
Example 7
Source File: SparkSQLDriver.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
        execution.hiveResultString()
      }
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 8
Source File: KafkaWriter.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.util.Utils


private[kafka010] object KafkaWriter extends Logging {
  val TOPIC_ATTRIBUTE_NAME: String = "topic"
  val KEY_ATTRIBUTE_NAME: String = "key"
  val VALUE_ATTRIBUTE_NAME: String = "value"

  override def toString: String = "KafkaWriter"

  def validateQuery(
      schema: Seq[Attribute],
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse(
      if (topic.isEmpty) {
        throw new AnalysisException(s"topic option required when no " +
          s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " +
          s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
      } else {
        Literal(topic.get, StringType)
      }
    ).dataType match {
      case StringType => // good
      case _ =>
        throw new AnalysisException(s"Topic type must be a String")
    }
    schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse(
      Literal(null, StringType)
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
    schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse(
      throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found")
    ).dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
          s"must be a String or BinaryType")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, kafkaParameters, topic)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close())
    }
  }
} 
Example 9
Source File: FramelessInternals.scala    From frameless   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.ObjectType
import scala.reflect.ClassTag

object FramelessInternals {
  def objectTypeFor[A](implicit classTag: ClassTag[A]): ObjectType = ObjectType(classTag.runtimeClass)

  def resolveExpr(ds: Dataset[_], colNames: Seq[String]): NamedExpression = {
    ds.toDF.queryExecution.analyzed.resolve(colNames, ds.sparkSession.sessionState.analyzer.resolver).getOrElse {
      throw new AnalysisException(
        s"""Cannot resolve column name "$colNames" among (${ds.schema.fieldNames.mkString(", ")})""")
    }
  }

  def expr(column: Column): Expression = column.expr

  def column(column: Column): Expression = column.expr

  def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan

  def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
    ds.sparkSession.sessionState.executePlan(plan)

  def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = {
    val joined = executePlan(ds, plan)
    val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
    val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

    Project(List(
      Alias(CreateStruct(leftOutput), "_1")(),
      Alias(CreateStruct(rightOutput), "_2")()
    ), joined.analyzed)
  }

  def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
    new Dataset(sqlContext, plan, encoder)

  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
    Dataset.ofRows(sparkSession, logicalPlan)

  // because org.apache.spark.sql.types.UserDefinedType is private[spark]
  type UserDefinedType[A >: Null] =  org.apache.spark.sql.types.UserDefinedType[A]

  
  case class DisambiguateRight[T](tagged: Expression) extends Expression with NonSQLExpression {
    def eval(input: InternalRow): Any = tagged.eval(input)
    def nullable: Boolean = false
    def children: Seq[Expression] = tagged :: Nil
    def dataType: DataType = tagged.dataType
    protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ???
    override def genCode(ctx: CodegenContext): ExprCode = tagged.genCode(ctx)
  }
} 
Example 10
Source File: SparkSQLDriver.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.QueryExecution


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = execution.hiveResultString()
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 11
Source File: SparkSQLDriver.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.QueryExecution


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = execution.hiveResultString()
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 12
Source File: AtlasQueryExecutionListener.scala    From spark-atlas-connector   with Apache License 2.0
package com.hortonworks.spark.atlas.sql.testhelper

import com.hortonworks.spark.atlas.sql.QueryDetail
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

import scala.collection.mutable

class AtlasQueryExecutionListener extends QueryExecutionListener {
  val queryDetails = new mutable.MutableList[QueryDetail]()

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    if (qe.logical.isStreaming) {
      // streaming query will be tracked via SparkAtlasStreamingQueryEventTracker
      return
    }
    queryDetails += QueryDetail.fromQueryExecutionListener(qe, durationNs)
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    throw exception
  }

  def clear(): Unit = {
    queryDetails.clear()
  }
} 
Example 13
Source File: SparkAtlasEventTracker.scala    From spark-atlas-connector   with Apache License 2.0
package com.hortonworks.spark.atlas

import com.google.common.annotations.VisibleForTesting
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogEvent
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener
import com.hortonworks.spark.atlas.sql._
import com.hortonworks.spark.atlas.ml.MLPipelineEventProcessor
import com.hortonworks.spark.atlas.utils.Logging

class SparkAtlasEventTracker(atlasClient: AtlasClient, atlasClientConf: AtlasClientConf)
    extends SparkListener with QueryExecutionListener with Logging {

  def this(atlasClientConf: AtlasClientConf) = {
    this(AtlasClient.atlasClient(atlasClientConf), atlasClientConf)
  }

  def this() {
    this(new AtlasClientConf)
  }

  private val enabled: Boolean = AtlasUtils.isSacEnabled(atlasClientConf)

  // Processor to handle DDL related events
  @VisibleForTesting
  private[atlas] val catalogEventTracker =
    new SparkCatalogEventProcessor(atlasClient, atlasClientConf)
  catalogEventTracker.startThread()

  // Processor to handle DML related events
  private val executionPlanTracker = new SparkExecutionPlanProcessor(atlasClient, atlasClientConf)
  executionPlanTracker.startThread()

  private val mlEventTracker = new MLPipelineEventProcessor(atlasClient, atlasClientConf)
  mlEventTracker.startThread()

  override def onOtherEvent(event: SparkListenerEvent): Unit = {
    if (!enabled) {
      // No op if SAC is disabled
      return
    }

    // We only care about SQL related events.
    event match {
      case e: ExternalCatalogEvent => catalogEventTracker.pushEvent(e)
      case e: SparkListenerEvent if e.getClass.getName.contains("org.apache.spark.ml") =>
        mlEventTracker.pushEvent(e)
      case _ => // Ignore other events
    }
  }

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
    if (!enabled) {
      // No op if SAC is disabled
      return
    }

    if (qe.logical.isStreaming) {
      // streaming query will be tracked via SparkAtlasStreamingQueryEventTracker
      return
    }

    val qd = QueryDetail.fromQueryExecutionListener(qe, durationNs)
    executionPlanTracker.pushEvent(qd)
  }

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    // No-op: SAC is one of the listener.
  }

} 
Example 14
Source File: EventHubsWriter.scala    From azure-event-hubs-spark   with Apache License 2.0
package org.apache.spark.sql.eventhubs

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{ AnalysisException, SparkSession }
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{ BinaryType, StringType }
import org.apache.spark.util.Utils


private[eventhubs] object EventHubsWriter extends Logging {

  val BodyAttributeName = "body"
  val PartitionKeyAttributeName = "partitionKey"
  val PartitionIdAttributeName = "partition"
  val PropertiesAttributeName = "properties"

  override def toString: String = "EventHubsWriter"

  private def validateQuery(schema: Seq[Attribute], parameters: Map[String, String]): Unit = {
    schema
      .find(_.name == BodyAttributeName)
      .getOrElse(
        throw new AnalysisException(s"Required attribute '$BodyAttributeName' not found.")
      )
      .dataType match {
      case StringType | BinaryType => // good
      case _ =>
        throw new AnalysisException(
          s"$BodyAttributeName attribute type " +
            s"must be a String or BinaryType.")
    }
  }

  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      parameters: Map[String, String]
  ): Unit = {
    val schema = queryExecution.analyzed.output
    validateQuery(schema, parameters)
    queryExecution.toRdd.foreachPartition { iter =>
      val writeTask = new EventHubsWriteTask(parameters, schema)
      Utils.tryWithSafeFinally(block = writeTask.execute(iter))(
        finallyBlock = writeTask.close()
      )
    }
  }
} 
Example 15
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, here we do
      // the computation by constructing logical plan.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
} 
Example 16
Source File: StreamingIncrementCommand.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.xsql.execution.command

import java.util.Locale

import org.apache.spark.SparkException
import org.apache.spark.sql.{Dataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.execution.streaming.StreamingRelationV2
import org.apache.spark.sql.sources.v2.StreamWriteSupport
import org.apache.spark.sql.streaming.{OutputMode, Trigger}
import org.apache.spark.sql.xsql.DataSourceManager._
import org.apache.spark.sql.xsql.StreamingSinkType


case class StreamingIncrementCommand(plan: LogicalPlan) extends RunnableCommand {

  private var outputMode: OutputMode = OutputMode.Append
  // dummy
  override def output: Seq[AttributeReference] = Seq.empty
  // dummy
  override def producedAttributes: AttributeSet = plan.producedAttributes

  override def run(sparkSession: SparkSession): Seq[Row] = {
    import StreamingSinkType._
    val qe = new QueryExecution(sparkSession, new ConstructedStreaming(plan))
    val df = new Dataset(sparkSession, qe, RowEncoder(qe.analyzed.schema))
    plan.collectLeaves.head match {
      case StreamingRelationV2(_, _, extraOptions, _, _) =>
        val source = extraOptions.getOrElse(STREAMING_SINK_TYPE, DEFAULT_STREAMING_SINK)
        val sinkOptions = extraOptions.filter(_._1.startsWith(STREAMING_SINK_PREFIX)).map { kv =>
          val key = kv._1.substring(STREAMING_SINK_PREFIX.length)
          (key, kv._2)
        }
        StreamingSinkType.withName(source.toUpperCase(Locale.ROOT)) match {
          case CONSOLE =>
          case TEXT | PARQUET | ORC | JSON | CSV =>
            if (sinkOptions.get(STREAMING_SINK_PATH) == None) {
              throw new SparkException("Sink type is file, must config path")
            }
          case KAFKA =>
            if (sinkOptions.get(STREAMING_SINK_BOOTSTRAP_SERVERS) == None) {
              throw new SparkException("Sink type is kafka, must config bootstrap servers")
            }
            if (sinkOptions.get(STREAMING_SINK_TOPIC) == None) {
              throw new SparkException("Sink type is kafka, must config kafka topic")
            }
          case _ =>
            throw new SparkException(
              "Sink type is invalid, " +
                s"select from ${StreamingSinkType.values}")
        }
        val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
        val disabledSources = sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",")
        val sink = ds.newInstance() match {
          case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) =>
            w
          case _ =>
            val ds = DataSource(
              sparkSession,
              className = source,
              options = sinkOptions.toMap,
              partitionColumns = Nil)
            ds.createSink(InternalOutputModes.Append)
        }
        val outputMode = InternalOutputModes(
          extraOptions.getOrElse(STREAMING_OUTPUT_MODE, DEFAULT_STREAMING_OUTPUT_MODE))
        val duration =
          extraOptions.getOrElse(STREAMING_TRIGGER_DURATION, DEFAULT_STREAMING_TRIGGER_DURATION)
        val trigger =
          extraOptions.getOrElse(STREAMING_TRIGGER_TYPE, DEFAULT_STREAMING_TRIGGER_TYPE) match {
            case STREAMING_MICRO_BATCH_TRIGGER => Trigger.ProcessingTime(duration)
            case STREAMING_ONCE_TRIGGER => Trigger.Once()
            case STREAMING_CONTINUOUS_TRIGGER => Trigger.Continuous(duration)
          }
        val query = sparkSession.sessionState.streamingQueryManager.startQuery(
          extraOptions.get("queryName"),
          extraOptions.get(STREAMING_CHECKPOINT_LOCATION),
          df,
          sinkOptions.toMap,
          sink,
          outputMode,
          useTempCheckpointLocation = source == DEFAULT_STREAMING_SINK,
          recoverFromCheckpointLocation = true,
          trigger = trigger)
        query.awaitTermination()
    }
    // dummy
    Seq.empty
  }
}

case class ConstructedStreaming(child: LogicalPlan) extends UnaryNode {
  override def output: Seq[Attribute] = child.output
} 
Example 17
Source File: SparkSQLDriver.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import java.util.{ArrayList => JArrayList, Arrays, List => JList}

import scala.collection.JavaConverters._

import org.apache.commons.lang3.exception.ExceptionUtils
import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema}
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}


private[hive] class SparkSQLDriver(val context: SQLContext = SparkSQLEnv.sqlContext)
  extends Driver
  with Logging {

  private[hive] var tableSchema: Schema = _
  private[hive] var hiveResponse: Seq[String] = _

  override def init(): Unit = {
  }

  private def getResultSetSchema(query: QueryExecution): Schema = {
    val analyzed = query.analyzed
    logDebug(s"Result Schema: ${analyzed.output}")
    if (analyzed.output.isEmpty) {
      new Schema(Arrays.asList(new FieldSchema("Response code", "string", "")), null)
    } else {
      val fieldSchemas = analyzed.output.map { attr =>
        new FieldSchema(attr.name, attr.dataType.catalogString, "")
      }

      new Schema(fieldSchemas.asJava, null)
    }
  }

  override def run(command: String): CommandProcessorResponse = {
    // TODO unify the error code
    try {
      context.sparkContext.setJobDescription(command)
      val execution = context.sessionState.executePlan(context.sql(command).logicalPlan)
      hiveResponse = SQLExecution.withNewExecutionId(context.sparkSession, execution) {
        execution.hiveResultString()
      }
      tableSchema = getResultSetSchema(execution)
      new CommandProcessorResponse(0)
    } catch {
        case ae: AnalysisException =>
          logDebug(s"Failed in [$command]", ae)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(ae), null, ae)
        case cause: Throwable =>
          logError(s"Failed in [$command]", cause)
          new CommandProcessorResponse(1, ExceptionUtils.getStackTrace(cause), null, cause)
    }
  }

  override def close(): Int = {
    hiveResponse = null
    tableSchema = null
    0
  }

  override def getResults(res: JList[_]): Boolean = {
    if (hiveResponse == null) {
      false
    } else {
      res.asInstanceOf[JArrayList[String]].addAll(hiveResponse.asJava)
      hiveResponse = null
      true
    }
  }

  override def getSchema: Schema = tableSchema

  override def destroy() {
    super.destroy()
    hiveResponse = null
    tableSchema = null
  }
} 
Example 18
Source File: SparkSqlExtension.scala    From Linkis   with Apache License 2.0
package com.webank.wedatasphere.linkis.engine.extension

import java.util.concurrent._

import com.webank.wedatasphere.linkis.common.conf.CommonVars
import com.webank.wedatasphere.linkis.common.utils.{Logging, Utils}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, SQLContext}

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.duration._

abstract class SparkSqlExtension extends Logging{

  private val maxPoolSize = CommonVars("wds.linkis.dws.ujes.spark.extension.max.pool",5).getValue

  private  val executor = new ThreadPoolExecutor(2, maxPoolSize, 2, TimeUnit.SECONDS, new LinkedBlockingQueue[Runnable](), new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val thread = new Thread(r)
      thread.setDaemon(true)
      thread
    }
  })

  final def afterExecutingSQL(sqlContext: SQLContext,command: String,dataFrame: DataFrame,timeout:Long,sqlStartTime:Long):Unit = {
    try {
      val thread = new Runnable {
        override def run(): Unit = extensionRule(sqlContext,command,dataFrame.queryExecution,sqlStartTime)
      }
      val future = executor.submit(thread)
      Utils.waitUntil(future.isDone,timeout milliseconds)
    } catch {
      case e: Throwable => info("Failed to execute SparkSqlExtension: ", e)
    }
  }

  protected def extensionRule(sqlContext: SQLContext,command: String,queryExecution: QueryExecution,sqlStartTime:Long):Unit


}

object SparkSqlExtension extends Logging {

  private val extensions = ArrayBuffer[SparkSqlExtension]()

  def register(sqlExtension: SparkSqlExtension):Unit = {
    info("Get a sqlExtension register")
    extensions.append(sqlExtension)
  }

  def getSparkSqlExtensions():Array[SparkSqlExtension] = {
    extensions.toArray
  }
}