Source File: MyUDF.scala (from spark-tools, Apache License 2.0; 5 votes)
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.TimestampType

object MyUDF {

  private def myTimestampCast(xs: Seq[Expression]): Expression = {
    val expSource = xs.head
    expSource.dataType match {
      case LongType =>
        new Column(expSource).divide(Literal(1000)).cast(TimestampType).expr
      case TimestampType =>

  def register(sparkSession: SparkSession): Unit =
      .registerFunction(FunctionIdentifier("toTs",None), myTimestampCast)

Source File: NativeFunctionRegistration.scala (from spark-alchemy, Apache License 2.0; 5 votes)
package com.swoop.alchemy.spark.expressions

import org.apache.spark.sql.EncapsulationViolator.createAnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpressionInfo, RuntimeReplaceable}

import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

// based on Spark's FunctionRegistry @ossSpark
trait NativeFunctionRegistration extends FunctionRegistration {

  type FunctionBuilder = Seq[Expression] => Expression

  def expressions: Map[String, (ExpressionInfo, FunctionBuilder)]

  def registerFunctions(fr: FunctionRegistry): Unit = {
    expressions.foreach { case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder) }

  def registerFunctions(spark: SparkSession): Unit = {

  protected def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
    val clazz = scala.reflect.classTag[T].runtimeClass
    val df = clazz.getAnnotation(classOf[ExpressionDescription])
    if (df != null) {
      new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
    } else {
      new ExpressionInfo(clazz.getCanonicalName, name)

Source File: parser.scala (from tispark, Apache License 2.0; 5 votes)
package org.apache.spark.sql.extensions

import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.execution.SparkSqlParser
import org.apache.spark.sql.execution.command.{
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{SparkSession, TiContext}

case class TiParser(getOrCreateTiContext: SparkSession => TiContext)(
    sparkSession: SparkSession,
    delegate: ParserInterface)
    extends ParserInterface {
  private lazy val tiContext = getOrCreateTiContext(sparkSession)
  private lazy val internal = new SparkSqlParser(sparkSession.sqlContext.conf)

  private def needQualify(tableIdentifier: TableIdentifier) =
    tableIdentifier.database.isEmpty && tiContext.sessionCatalog
Source File: LookupFunctionsSuite.scala (from XSQL, Apache License 2.0; 5 votes)
package org.apache.spark.sql.catalyst.analysis


import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

  test("SPARK-23486: the functionExists for the Persistent function check") {
    val externalCatalog = new CustomInMemoryCatalog
    val conf = new SQLConf()
    val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
    val analyzer = {
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
        Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
        Alias(unresolvedRegisteredFunc, "call5")()),

    assert(externalCatalog.getFunctionExistsCalledTimes == 1)
      ( == Some("default"))

  test("SPARK-23486: the functionExists for the Registered function check") {
    val externalCatalog = new InMemoryCatalog
    val conf = new SQLConf()
    val customerFunctionReg = new CustomerFunctionRegistry
    val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf)
    val analyzer = {
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()),

    assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
      ( == Some("default"))

class CustomerFunctionRegistry extends SimpleFunctionRegistry {

  private var isRegisteredFunctionCalledTimes: Int = 0;

  override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
    isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1

  def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes

class CustomInMemoryCatalog extends InMemoryCatalog {

  private var functionExistsCalledTimes: Int = 0

  override def functionExists(db: String, funcName: String): Boolean = synchronized {
    functionExistsCalledTimes = functionExistsCalledTimes + 1

  def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes
Source File: SparkExtension.scala (from spark-atlas-connector, Apache License 2.0; 5 votes)
package com.hortonworks.spark.atlas.sql

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}

class SparkExtension extends (SparkSessionExtensions => Unit) {
  def apply(e: SparkSessionExtensions): Unit = {

case class SparkAtlasConnectorParser(spark: SparkSession, delegate: ParserInterface)
  extends ParserInterface {
  override def parsePlan(sqlText: String): LogicalPlan = {

  override def parseExpression(sqlText: String): Expression =

  override def parseTableIdentifier(sqlText: String): TableIdentifier =

  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =

  override def parseTableSchema(sqlText: String): StructType =

  override def parseDataType(sqlText: String): DataType =

object SQLQuery {
  private[this] val sqlQuery = new ThreadLocal[String]
  /**
   * Reads the SQL text recorded for the calling thread.
   *
   * The backing `ThreadLocal` has no initial value, so this yields `null` on any thread
   * that has not yet called `set`.
   */
  def get(): String = sqlQuery.get()
  /**
   * Records `s` as the SQL text for the calling thread, overwriting any earlier value.
   *
   * @param s the SQL text to remember for this thread
   */
  def set(s: String): Unit = {
    sqlQuery.set(s)
  }
Source File: BatchEvalPythonExecSuite.scala (from Spark-2.3.1, Apache License 2.0; 5 votes)
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  override def beforeAll(): Unit = {
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)

  override def afterAll(): Unit = {

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    assert(qualifiedPlanNodes.size == 2)

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    assert(qualifiedPlanNodes.size == 2)

  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    assert(qualifiedPlanNodes.size == 2)

  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")

    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    assert(qualifiedPlanNodes.size == 2)

  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    assert(qualifiedPlanNodes.size == 1)

// A dummy Python UDF used only for testing; it cannot actually be executed.
/**
 * Placeholder [[PythonFunction]] for planner-level tests.
 *
 * All of its fields are empty: an empty command, a single empty-string environment
 * entry and include path, empty interpreter path and version, and `null` broadcast
 * variables and accumulator. It is therefore not meant to be executed by a Python worker.
 */
class DummyUDF
  extends PythonFunction(
    command = new Array[Byte](0),
    envVars = Map(("", "")).asJava,
    pythonIncludes = ArrayBuffer[String]("").asJava,
    pythonVer = "",
    pythonExec = "",
    accumulator = null,
    broadcastVars = null)

/**
 * Deterministic, Boolean-returning Python UDF backed by [[DummyUDF]].
 *
 * It is evaluated as a batched UDF (`PythonEvalType.SQL_BATCHED_UDF`), and its internal
 * name is `dummyUDF`. Because the wrapped function is a placeholder, it is only suitable
 * for exercising query planning, not for execution.
 */
class MyDummyPythonUDF
  extends UserDefinedPythonFunction(
    name = "dummyUDF",
    dataType = BooleanType,
    func = new DummyUDF,
    udfDeterministic = true,
    pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)