org.apache.spark.sql.SparkSessionExtensions Scala Examples

The following examples show how to use org.apache.spark.sql.SparkSessionExtensions. You can vote examples up or down, and you can open the original project or source file through the links above each example.
Example 1
Source File: TestSparkSession.scala · Project: spark-alchemy · License: Apache License 2.0 · 5 votes
package org.apache.spark.sql.test

import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.internal.{SQLConf, SessionState, SessionStateBuilder, WithTestConf}
import org.apache.spark.{SparkConf, SparkContext}


  /**
   * SQL configuration overrides for test sessions. `TestSQLSessionStateBuilder` in this
   * file returns this map as its `overrideConfs`.
   */
  val overrideConfs: Map[String, String] = {
    // Keep the shuffle partition count small so test jobs finish quickly.
    val fewShufflePartitions = SQLConf.SHUFFLE_PARTITIONS.key -> "3"
    Map(fewShufflePartitions)
  }
}

/**
 * A [[SessionStateBuilder]] for tests that mixes in [[WithTestConf]] and supplies
 * `TestSQLContext.overrideConfs` through its `overrideConfs` hook.
 *
 * @param session the session whose state is being built
 * @param state   optional existing state, passed through unchanged to SessionStateBuilder
 */
private[sql] class TestSQLSessionStateBuilder(
    session: SparkSession,
    state: Option[SessionState])
  extends SessionStateBuilder(session, state)
  with WithTestConf {

  override def overrideConfs: Map[String, String] =
    TestSQLContext.overrideConfs

  // Builders derived from this one are also test builders, so they keep the overrides.
  override def newBuilder: NewBuilder =
    (childSession, childState) => new TestSQLSessionStateBuilder(childSession, childState)
}
Example 2
Source File: ColumnarPlugin.scala · Project: OAP · License: Apache License 2.0 · 5 votes
package com.intel.sparkColumnarPlugin

import com.intel.sparkColumnarPlugin.execution._

import org.apache.spark.internal.Logging
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.{RowToColumnarExec, ColumnarToRowExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}

case class ColumnarPreOverrides(conf: SparkConf) extends Rule[SparkPlan] {
  val columnarConf = ColumnarPluginConfig.getConf(conf)

  def replaceWithColumnarPlan(plan: SparkPlan): SparkPlan = plan match {
    case plan: BatchScanExec =>
      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
      new ColumnarBatchScanExec(plan.output, plan.scan)
    case plan: ProjectExec =>
      //new ColumnarProjectExec(plan.projectList, replaceWithColumnarPlan(plan.child))
      val columnarPlan = replaceWithColumnarPlan(plan.child)
      val res = if (!columnarPlan.isInstanceOf[ColumnarConditionProjectExec]) {
        new ColumnarConditionProjectExec(null, plan.projectList, columnarPlan)
      } else {
        val cur_plan = columnarPlan.asInstanceOf[ColumnarConditionProjectExec]
        new ColumnarConditionProjectExec(cur_plan.condition, plan.projectList, cur_plan.child)
      }
      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
      res
    case plan: FilterExec =>
      val child = replaceWithColumnarPlan(plan.child)
      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
      new ColumnarConditionProjectExec(plan.condition, null, child)
    case plan: HashAggregateExec =>
      val child = replaceWithColumnarPlan(plan.child)
      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
      new ColumnarHashAggregateExec(
        plan.requiredChildDistributionExpressions,
        plan.groupingExpressions,
        plan.aggregateExpressions,
        plan.aggregateAttributes,
        plan.initialInputBufferOffset,
        plan.resultExpressions,
        child)
    case plan: SortExec =>
      if (columnarConf.enableColumnarSort) {
        val child = replaceWithColumnarPlan(plan.child)
        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
        new ColumnarSortExec(plan.sortOrder, plan.global, child, plan.testSpillFrequency)
      } else {
        val children = plan.children.map(replaceWithColumnarPlan)
        logDebug(s"Columnar Processing for ${plan.getClass} is not currently supported.")
        plan.withNewChildren(children)
      }
    
/**
 * Spark session-extension entry point that logs a notice and registers
 * `ColumnarOverrideRules` through `injectColumnar`.
 */
class ColumnarPlugin extends (SparkSessionExtensions => Unit) with Logging {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    val notice = "Installing extensions to enable columnar CPU support." +
      " To disable this set `org.apache.spark.example.columnar.enabled` to false"
    logWarning(notice)
    extensions.injectColumnar(session => ColumnarOverrideRules(session))
  }
}
Example 3
Source File: SparkExtension.scala · Project: spark-atlas-connector · License: Apache License 2.0 · 5 votes
package com.hortonworks.spark.atlas.sql

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}


/**
 * Session extension that wraps the session's SQL parser in a
 * [[SparkAtlasConnectorParser]].
 */
class SparkExtension extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit =
    e.injectParser((session, delegate) => SparkAtlasConnectorParser(session, delegate))
}

/**
 * A [[ParserInterface]] that forwards every call to `delegate`. Before parsing a plan it
 * also records the SQL text via [[SQLQuery.set]] (a thread-local holder).
 *
 * @param spark    the owning session (not referenced by the methods below)
 * @param delegate the parser that performs the actual parsing
 */
case class SparkAtlasConnectorParser(spark: SparkSession, delegate: ParserInterface)
  extends ParserInterface {

  /** Records `sqlText` for the calling thread, then parses it with `delegate`. */
  override def parsePlan(sqlText: String): LogicalPlan = {
    SQLQuery.set(sqlText)
    val plan = delegate.parsePlan(sqlText)
    plan
  }

  override def parseExpression(sqlText: String): Expression = {
    delegate.parseExpression(sqlText)
  }

  override def parseTableIdentifier(sqlText: String): TableIdentifier = {
    delegate.parseTableIdentifier(sqlText)
  }

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
    delegate.parseFunctionIdentifier(sqlText)
  }

  override def parseTableSchema(sqlText: String): StructType = {
    delegate.parseTableSchema(sqlText)
  }

  override def parseDataType(sqlText: String): DataType = {
    delegate.parseDataType(sqlText)
  }
}

/**
 * Per-thread holder for a SQL text string.
 * `get()` returns `null` on any thread where `set` has not been called.
 */
object SQLQuery {
  private[this] val current: ThreadLocal[String] = new ThreadLocal[String]()

  /** Returns the text last recorded on the calling thread, or `null` if none was. */
  def get(): String = current.get()

  /** Records `s` as the text for the calling thread. */
  def set(s: String): Unit = current.set(s)
}
Example 4
Source File: DeltaSparkSessionExtension.scala · Project: delta · License: Apache License 2.0 · 5 votes
package io.delta.sql

import org.apache.spark.sql.delta._
import io.delta.sql.parser.DeltaSqlParser

import org.apache.spark.sql.SparkSessionExtensions
import org.apache.spark.sql.internal.SQLConf


/**
 * Installs Delta's SQL parser and analysis rules into a Spark session.
 *
 * Injection order is: the `DeltaSqlParser` wrapper, the `DeltaAnalysis` resolution rule,
 * the `DeltaUnsupportedOperationsCheck` check rule, and finally the post-hoc resolution
 * rules `PreprocessTableUpdate`, `PreprocessTableMerge` and `PreprocessTableDelete`.
 */
class DeltaSparkSessionExtension extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    // Wrap whatever parser the session already has.
    extensions.injectParser((_, delegate) => new DeltaSqlParser(delegate))

    extensions.injectResolutionRule(
      spark => new DeltaAnalysis(spark, spark.sessionState.conf))

    extensions.injectCheckRule(
      spark => new DeltaUnsupportedOperationsCheck(spark))

    // Post-hoc rules, registered in the same order as before.
    extensions.injectPostHocResolutionRule(
      spark => new PreprocessTableUpdate(spark.sessionState.conf))
    extensions.injectPostHocResolutionRule(
      spark => new PreprocessTableMerge(spark.sessionState.conf))
    extensions.injectPostHocResolutionRule(
      spark => new PreprocessTableDelete(spark.sessionState.conf))
  }
}
Example 5
Source File: Spark.scala · Project: starry · License: Apache License 2.0 · 5 votes
package org.apache.spark

import com.github.passionke.starry.StarrySparkContext
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.execution.LocalBasedStrategies


/**
 * Process-wide local Spark setup built on [[StarrySparkContext]].
 *
 * Initializing this object builds the configuration, creates the context and the
 * session, and registers [[LocalBasedStrategies]] on that session.
 */
object Spark {

  val sparkConf: SparkConf = new SparkConf()
    .setMaster("local[*]")
    .setAppName("aloha")
    .set("spark.default.parallelism", "1")
    .set("spark.sql.shuffle.partitions", "1")
    .set("spark.broadcast.manager", "rotary")
    .set("rotary.shuffer", "true")
    .set("spark.sql.codegen.wholeStage", "false")
    .set("spark.sql.extensions", "org.apache.spark.sql.StarrySparkSessionExtension")
    .set("spark.driver.allowMultipleContexts", "true") // for test only

  val sparkContext: StarrySparkContext = new StarrySparkContext(sparkConf)

  val sparkSession: SparkSession = SparkSession.builder
    .sparkContext(sparkContext)
    .getOrCreate

  LocalBasedStrategies.register(sparkSession)
}
Example 6
Source File: HiveAcidAutoConvert.scala · Project: spark-acid · License: Apache License 2.0 · 5 votes
package com.qubole.spark.hiveacid

import java.util.Locale

import com.qubole.spark.datasources.hiveacid.sql.execution.SparkAcidSqlParser
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, InsertIntoTable, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.LogicalRelation
import com.qubole.spark.hiveacid.datasource.HiveAcidDataSource



/**
 * Analyzer rule that rewrites reads from, and inserts into, Hive tables whose
 * `transactional` table property is "true". The table relation is replaced with a
 * [[LogicalRelation]] backed by [[HiveAcidDataSource]].
 *
 * @param spark session whose `sqlContext` is used to create the replacement relation
 */
case class HiveAcidAutoConvert(spark: SparkSession) extends Rule[LogicalPlan] {

  /**
   * A table is convertible when its `transactional` property equals "true", ignoring
   * case (the same check Hive applies). A missing or non-boolean value means "not
   * transactional". The previous `toBoolean` threw `IllegalArgumentException` on such
   * values and aborted analysis.
   */
  private def isConvertible(relation: HiveTableRelation): Boolean =
    relation.tableMeta.properties.get("transactional").exists(_.equalsIgnoreCase("true"))

  /** Builds a [[LogicalRelation]] over [[HiveAcidDataSource]] for `relation`. */
  private def convert(relation: HiveTableRelation): LogicalRelation = {
    // Table properties, then storage properties, then the qualified name under "table".
    // With `++`, later maps win on key collisions.
    val options = relation.tableMeta.properties ++
      relation.tableMeta.storage.properties ++ Map("table" -> relation.tableMeta.qualifiedName)

    val newRelation = new HiveAcidDataSource().createRelation(spark.sqlContext, options)
    LogicalRelation(newRelation, isStreaming = false)
  }

  override def apply(plan: LogicalPlan): LogicalPlan = {
    plan resolveOperators {
      // Write path: only rewrite once the inserted query is resolved.
      case InsertIntoTable(r: HiveTableRelation, partition, query, overwrite, ifPartitionNotExists)
        if query.resolved && DDLUtils.isHiveTable(r.tableMeta) && isConvertible(r) =>
        InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists)

      // Read path
      case relation: HiveTableRelation
        if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>
        convert(relation)
    }
  }
}

/**
 * Session extension that registers [[HiveAcidAutoConvert]] as a resolution rule and
 * wraps the session's SQL parser in `SparkAcidSqlParser`.
 */
class HiveAcidAutoConvertExtension extends (SparkSessionExtensions => Unit) {
  def apply(extension: SparkSessionExtensions): Unit = {
    extension.injectResolutionRule(session => HiveAcidAutoConvert(session))
    extension.injectParser((_, delegate) => SparkAcidSqlParser(delegate))
  }
}
Example 7
Source File: SQLServerEnv.scala · Project: spark-sql-server · License: Apache License 2.0 · 5 votes
package org.apache.spark.sql.server

import scala.util.control.NonFatal

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, SQLContext}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.server.ui.SQLServerTab
import org.apache.spark.util.Utils

/**
 * Lazily initialized environment for the SQL server. It holds the Spark and SQL
 * configuration, the [[SQLContext]], the server listener and the optional UI tab.
 *
 * Tests may inject an existing context with [[withSQLContext]].
 */
object SQLServerEnv extends Logging {

  // For test use: a context injected via withSQLContext, used instead of building one.
  private var _sqlContext: Option[SQLContext] = None

  /**
   * Installs `sqlContext` as this environment's context, then forces the listener and
   * the UI tab to initialize.
   *
   * NOTE(review): the lazy vals read `_sqlContext` only on first access. Values that
   * were already initialized before this call are therefore not affected.
   *
   * @param sqlContext the context to use; must not be null
   */
  @DeveloperApi
  def withSQLContext(sqlContext: SQLContext): Unit = {
    require(sqlContext != null)
    _sqlContext = Option(sqlContext)
    sqlServListener
    uiTab
  }

  /** Copies every entry of `sparkConf` into `sqlConf`. */
  private def mergeSparkConf(sqlConf: SQLConf, sparkConf: SparkConf): Unit = {
    sparkConf.getAll.foreach { case (k, v) =>
      sqlConf.setConfString(k, v)
    }
  }

  lazy val sparkConf: SparkConf = _sqlContext.map(_.sparkContext.conf).getOrElse {
    val sparkConf = new SparkConf(loadDefaults = true)

    // If user doesn't specify the appName, we want to get [SparkSQL::localHostName]
    // instead of the default appName [SQLServer].
    val maybeAppName = sparkConf
      .getOption("spark.app.name")
      .filterNot(_ == classOf[SQLServer].getName)
    sparkConf
      .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}"))
      .set("spark.sql.crossJoin.enabled", "true")
  }

  lazy val sqlConf: SQLConf = _sqlContext.map(_.conf).getOrElse {
    val newSqlConf = new SQLConf()
    mergeSparkConf(newSqlConf, sparkConf)
    newSqlConf
  }

  lazy val sqlContext: SQLContext = _sqlContext.getOrElse(newSQLContext(sparkConf))
  lazy val sparkContext: SparkContext = sqlContext.sparkContext
  lazy val sqlServListener: Option[SQLServerListener] = Some(newSQLServerListener(sqlContext))
  lazy val uiTab: Option[SQLServerTab] = newUiTab(sqlContext, sqlServListener.get)

  /**
   * Builds a Hive-enabled [[SQLContext]] from `conf`.
   *
   * If `spark.sql.server.extensions.builder` names a Scala object of type
   * `SparkSessionExtensions => Unit`, that object is passed to `withExtensions`. If
   * loading it or building the session with it fails non-fatally, a warning is logged
   * and the context is built without extensions.
   */
  private[sql] def newSQLContext(conf: SparkConf): SQLContext = {
    def buildSQLContext(f: SparkSessionExtensions => Unit = _ => {}): SQLContext = {
      SparkSession.builder.config(conf).withExtensions(f).enableHiveSupport()
        .getOrCreate().sqlContext
    }
    val builderClassName = conf.get("spark.sql.server.extensions.builder", "")
    if (builderClassName.nonEmpty) {
      // Tries to install user-defined extensions
      try {
        // A Scala `object` compiles to a class named with a trailing `$` whose static
        // `MODULE$` field holds the singleton instance.
        val objName = builderClassName + (if (!builderClassName.endsWith("$")) "$" else "")
        val clazz = Utils.classForName(objName)
        val builder = clazz.getDeclaredField("MODULE$").get(null)
          .asInstanceOf[SparkSessionExtensions => Unit]
        val sqlContext = buildSQLContext(builder)
        logInfo(s"Successfully installed extensions from $builderClassName")
        sqlContext
      } catch {
        case NonFatal(e) =>
          // Pass the exception itself so its stack trace is logged, not only the message.
          logWarning(
            s"Failed to install extensions from $builderClassName: " + e.getMessage, e)
          buildSQLContext()
      }
    } else {
      buildSQLContext()
    }
  }

  /** Creates a [[SQLServerListener]] and registers it on `sqlContext`'s SparkContext. */
  def newSQLServerListener(sqlContext: SQLContext): SQLServerListener = {
    val listener = new SQLServerListener(sqlContext.conf)
    sqlContext.sparkContext.addSparkListener(listener)
    listener
  }

  /**
   * Returns a [[SQLServerTab]] on `sqlContext`'s SparkContext when `spark.ui.enabled`
   * (default `true`) is set, and `None` otherwise.
   */
  def newUiTab(sqlContext: SQLContext, listener: SQLServerListener): Option[SQLServerTab] = {
    if (sqlContext.sparkContext.conf.getBoolean("spark.ui.enabled", true)) {
      // Use the context that was passed in. Reading the global `SQLServerEnv.sqlContext`
      // ignored this parameter and could force the global lazy context to initialize.
      Some(SQLServerTab(sqlContext.sparkContext, listener))
    } else {
      None
    }
  }
}