org.apache.spark.sql.catalyst.expressions.And Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.expressions.And. The originating project and source file are noted above each example.
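Before the project examples, here is a minimal, self-contained sketch (the column names age and active are invented for illustration): And is an ordinary binary Expression case class, so it can be constructed directly or folded over a sequence of predicates, the same reduce/reduceLeftOption idiom several of the examples below rely on.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, Literal}

object AndSketch {
  def main(args: Array[String]): Unit = {
    // Two simple predicates over made-up (unresolved) columns.
    val p1: Expression = GreaterThan(UnresolvedAttribute("age"), Literal(18))
    val p2: Expression = EqualTo(UnresolvedAttribute("active"), Literal(true))

    // Direct construction of a conjunction.
    val both = And(p1, p2)

    // Folding a list of predicates into a single conjunction.
    val combined = Seq(p1, p2).reduceLeft(And)

    println(both)     // prints roughly (('age > 18) && ('active = true))
    println(combined)
  }
}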
Example 1
Source File: DeltaPushFilter.scala    From connectors   with Apache License 2.0
package org.apache.spark.sql.delta

import scala.collection.immutable.HashSet
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, SerializationUtilities}
import org.apache.hadoop.hive.ql.lib._
import org.apache.hadoop.hive.ql.parse.SemanticException
import org.apache.hadoop.hive.ql.plan.{ExprNodeColumnDesc, ExprNodeConstantDesc, ExprNodeGenericFuncDesc}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, InSet, LessThan, LessThanOrEqual, Like, Literal, Not}

object DeltaPushFilter extends Logging {
  lazy val supportedPushDownUDFs = Array(
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualOrLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPLessThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPGreaterThan",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPEqualNS",
    "org.apache.hadoop.hive.ql.udf.UDFLike",
    "org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn"
  )

  def partitionFilterConverter(hiveFilterExprSerialized: String): Seq[Expression] = {
    if (hiveFilterExprSerialized != null) {
      val filterExpr = SerializationUtilities.deserializeExpression(hiveFilterExprSerialized)
      val opRules = new java.util.LinkedHashMap[Rule, NodeProcessor]()
      val nodeProcessor = new NodeProcessor() {
        @throws[SemanticException]
        def process(nd: Node, stack: java.util.Stack[Node],
            procCtx: NodeProcessorCtx, nodeOutputs: Object*): Object = {
          nd match {
            case e: ExprNodeGenericFuncDesc if FunctionRegistry.isOpAnd(e) =>
              nodeOutputs.map(_.asInstanceOf[Expression]).reduce(And)
            case e: ExprNodeGenericFuncDesc =>
              val (columnDesc, constantDesc) =
                if (nd.getChildren.get(0).isInstanceOf[ExprNodeColumnDesc]) {
                  (nd.getChildren.get(0), nd.getChildren.get(1))
                } else { (nd.getChildren.get(1), nd.getChildren.get(0)) }

              val columnAttr = UnresolvedAttribute(
                columnDesc.asInstanceOf[ExprNodeColumnDesc].getColumn)
              val constantVal = Literal(constantDesc.asInstanceOf[ExprNodeConstantDesc].getValue)
              nd.asInstanceOf[ExprNodeGenericFuncDesc].getGenericUDF match {
                case f: GenericUDFOPNotEqualNS =>
                  Not(EqualNullSafe(columnAttr, constantVal))
                case f: GenericUDFOPNotEqual =>
                  Not(EqualTo(columnAttr, constantVal))
                case f: GenericUDFOPEqualNS =>
                  EqualNullSafe(columnAttr, constantVal)
                case f: GenericUDFOPEqual =>
                  EqualTo(columnAttr, constantVal)
                case f: GenericUDFOPGreaterThan =>
                  GreaterThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrGreaterThan =>
                  GreaterThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFOPLessThan =>
                  LessThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrLessThan =>
                  LessThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFBridge if f.getUdfName.equals("like") =>
                  Like(columnAttr, constantVal)
                case f: GenericUDFIn =>
                  val inConstantVals = nd.getChildren.asScala
                    .filter(_.isInstanceOf[ExprNodeConstantDesc])
                    .map(_.asInstanceOf[ExprNodeConstantDesc].getValue)
                    .map(Literal(_)).toSet
                  InSet(columnAttr, HashSet() ++ inConstantVals)
                case _ =>
                  throw new RuntimeException(s"Unsupported func(${nd.getName}) " +
                    s"which can not be pushed down to delta")
              }
            case _ => null
          }
        }
      }

      val disp = new DefaultRuleDispatcher(nodeProcessor, opRules, null)
      val ogw = new DefaultGraphWalker(disp)
      val topNodes = new java.util.ArrayList[Node]()
      topNodes.add(filterExpr)
      val nodeOutput = new java.util.HashMap[Node, Object]()
      try {
        ogw.startWalking(topNodes, nodeOutput)
      } catch {
        case ex: Exception =>
          throw new RuntimeException(ex)
      }
      logInfo(s"converted partition filter expr:" +
        s"${nodeOutput.get(filterExpr).asInstanceOf[Expression].toJSON}")
      Seq(nodeOutput.get(filterExpr).asInstanceOf[Expression])
    } else Seq.empty[org.apache.spark.sql.catalyst.expressions.Expression]
  }
} 
Example 2
Source File: DataSourceV2Strategy.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.{sources, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader

object DataSourceV2Strategy extends Strategy {

  
  // TODO: nested column pruning.
  private def pruneColumns(
      reader: DataSourceReader,
      relation: DataSourceV2Relation,
      exprs: Seq[Expression]): Seq[AttributeReference] = {
    reader match {
      case r: SupportsPushDownRequiredColumns =>
        val requiredColumns = AttributeSet(exprs.flatMap(_.references))
        val neededOutput = relation.output.filter(requiredColumns.contains)
        if (neededOutput != relation.output) {
          r.pruneColumns(neededOutput.toStructType)
          val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap
          r.readSchema().toAttributes.map {
            // We have to keep the attribute id during transformation.
            a => a.withExprId(nameToAttr(a.name).exprId)
          }
        } else {
          relation.output
        }

      case _ => relation.output
    }
  }


  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
      val reader = relation.newReader()
      // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
      // `postScanFilters` need to be evaluated after the scan.
      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
      val (pushedFilters, postScanFilters) = pushFilters(reader, filters)
      val output = pruneColumns(reader, relation, project ++ postScanFilters)
      logInfo(
        s"""
           |Pushing operators to ${relation.source.getClass}
           |Pushed Filters: ${pushedFilters.mkString(", ")}
           |Post-Scan Filters: ${postScanFilters.mkString(",")}
           |Output: ${output.mkString(", ")}
         """.stripMargin)

      val scan = DataSourceV2ScanExec(
        output, relation.source, relation.options, pushedFilters, reader)

      val filterCondition = postScanFilters.reduceLeftOption(And)
      val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

      // always add the projection, which will produce unsafe rows required by some operators
      ProjectExec(project, withFilter) :: Nil

    case r: StreamingDataSourceV2Relation =>
      // ensure there is a projection, which will produce unsafe rows required by some operators
      ProjectExec(r.output,
        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil

    case WriteToDataSourceV2(writer, query) =>
      WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil

    case AppendData(r: DataSourceV2Relation, query, _) =>
      WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil

    case WriteToContinuousDataSource(writer, query) =>
      WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil

    case Repartition(1, false, child) =>
      val isContinuous = child.collectFirst {
        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
      }.isDefined

      if (isContinuous) {
        ContinuousCoalesceExec(1, planLater(child)) :: Nil
      } else {
        Nil
      }

    case _ => Nil
  }
} 
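The excerpt above calls a private pushFilters helper that the snippet omits. The following is a sketch of that helper modeled on the upstream Spark 2.4 DataSourceV2Strategy this fork derives from; treat it as an approximation of the omitted code rather than the XSQL version verbatim. It relies on the imports already present in the excerpt (scala.collection.mutable, sources, DataSourceStrategy, SupportsPushDownFilters).

  // Sketch: splits `filters` into filters pushed into the reader and filters
  // that must still be evaluated after the scan.
  private def pushFilters(
      reader: DataSourceReader,
      filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
    reader match {
      case r: SupportsPushDownFilters =>
        // Map each translated data source filter back to its catalyst expression.
        val translatedFilterToExpr = mutable.HashMap.empty[sources.Filter, Expression]
        // Catalyst expressions that cannot be translated to data source filters.
        val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]

        for (filterExpr <- filters) {
          DataSourceStrategy.translateFilter(filterExpr) match {
            case Some(translated) => translatedFilterToExpr(translated) = filterExpr
            case None => untranslatableExprs += filterExpr
          }
        }

        // Filters the reader cannot fully guarantee must be re-evaluated after the scan.
        val postScanFilters =
          r.pushFilters(translatedFilterToExpr.keys.toArray).map(translatedFilterToExpr)
        // Filters the reader reports as actually pushed down.
        val pushedFilters = r.pushedFilters().map(translatedFilterToExpr)
        (pushedFilters.toSeq, untranslatableExprs ++ postScanFilters)

      case _ => (Nil, filters)
    }
  }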
Example 3
Source File: SemiJoinSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
// Semi-join test suite
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }
  // Test the left semi join
  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 4
Source File: BatchEvalPythonExecSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF"))
    super.afterAll()
  }

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")

    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    }
    assert(qualifiedPlanNodes.size == 1)
  }
}

// This Python UDF is a dummy used only for testing; it cannot actually be executed.
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

class MyDummyPythonUDF extends UserDefinedPythonFunction(
  name = "dummyUDF",
  func = new DummyUDF,
  dataType = BooleanType,
  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true) 
Example 5
Source File: GroupAnd.scala    From mimir   with Apache License 2.0
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, BooleanType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, And }

case class GroupAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  override def nullable: Boolean = false
  // Return data type.
  override def dataType: DataType = BooleanType
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForOrderingExpr(child.dataType, "function group_and")
  private lazy val group_and = AttributeReference("group_and", BooleanType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_and :: Nil
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(true, BooleanType)
  )
  override lazy val updateExpressions: Seq[ org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    And(group_and, child)
  )
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = {
    Seq(
      And(group_and.left, group_and.right)
    )
  }
  override lazy val evaluateExpression: AttributeReference = group_and
} 
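A small usage sketch, not part of the mimir source, showing how an aggregate like GroupAnd can be wrapped in a Column and passed to agg(); the DataFrame and column names are invented for illustration.

import org.apache.spark.sql.{Column, SparkSession}
import org.apache.spark.sql.functions.col
import mimir.exec.spark.udf.GroupAnd

object GroupAndSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("group-and-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", true), ("a", false), ("b", true)).toDF("grp", "flag")

    // Wrap the DeclarativeAggregate in an AggregateExpression, then in a Column,
    // so it can be used with agg() like any built-in aggregate.
    val allTrue = new Column(GroupAnd(col("flag").expr).toAggregateExpression())

    df.groupBy("grp").agg(allTrue.as("all_true")).show()
    // Expected under these assumptions: group "a" -> false, group "b" -> true

    spark.stop()
  }
}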
Example 6
Source File: SemiJoinSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 7
Source File: SimbaOptimizer.scala    From Simba   with Apache License 2.0
package org.apache.spark.sql.simba

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.simba.plans.SpatialJoin


class SimbaOptimizer(catalog: SessionCatalog,
                     conf: SQLConf,
                     experimentalMethods: ExperimentalMethods)
 extends SparkOptimizer(catalog, conf, experimentalMethods) {
  override def batches: Seq[Batch] = super.batches :+
    Batch("SpatialJoinPushDown", FixedPoint(100), PushPredicateThroughSpatialJoin)
}

object PushPredicateThroughSpatialJoin extends Rule[LogicalPlan] with PredicateHelper {
  private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
    val (leftEvaluateCondition, rest) =
      condition.partition(_.references subsetOf left.outputSet)
    val (rightEvaluateCondition, commonCondition) =
      rest.partition(_.references subsetOf right.outputSet)

    (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // push the where condition down into join filter
    case f @ Filter(filterCondition, SpatialJoin(left, right, joinType, joinCondition)) =>
      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
        split(splitConjunctivePredicates(filterCondition), left, right)

      val newLeft = leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight = rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
      SpatialJoin(newLeft, newRight, joinType, newJoinCond)

    // push down the join filter into sub query scanning if applicable
    case f @ SpatialJoin(left, right, joinType, joinCondition) =>
      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)

      val newLeft = leftJoinConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight = rightJoinConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      val newJoinCond = commonJoinCondition.reduceLeftOption(And)

      SpatialJoin(newLeft, newRight, joinType, newJoinCond)
  }
} 
Example 8
Source File: PredicateUtil.scala    From Simba   with Apache License 2.0
package org.apache.spark.sql.simba.util

import org.apache.spark.sql.catalyst.expressions.{Expression, And, Or}


object PredicateUtil {
  def toDNF(condition: Expression): Expression = {
    condition match {
      case Or(left, right) =>
        Or(toDNF(left), toDNF(right))
      case And(left, right) =>
        var ans: Expression = null
        val tmp_left = toDNF(left)
        val tmp_right = toDNF(right)
        tmp_left match {
          case Or(l, r) =>
            ans = Or(And(l, tmp_right), And(r, tmp_right))
          case _ =>
        }
        tmp_right match {
          case Or(l, r) =>
            if (ans == null) ans = Or(And(tmp_left, l), And(tmp_left, r))
          case _ =>
        }
        if (ans == null) And(tmp_left, tmp_right)
        else toDNF(ans)
      case exp => exp
    }
  }

  def toCNF(condition: Expression): Expression = {
    condition match {
      case And(left, right) =>
        And(toCNF(left), toCNF(right))
      case Or(left, right) =>
        var ans: Expression = null
        val tmp_left = toCNF(left)
        val tmp_right = toCNF(right)
        tmp_left match {
          case And(l, r) =>
            ans = And(Or(l, tmp_right), Or(r, tmp_right))
          case _ =>
        }
        tmp_right match {
          case And(l, r) =>
            if (ans == null) ans = And(Or(tmp_left, l), Or(tmp_left, r))
          case _ =>
        }
        if (ans == null) Or(tmp_left, tmp_right)
        else toCNF(ans)
      case exp => exp
    }
  }
  def dnfExtract(expression: Expression): Seq[Expression] = {
    expression match {
      case Or(left, right) =>
        dnfExtract(left) ++ dnfExtract(right)
      case And(left @ And(l2, r2), right) =>
        dnfExtract(And(l2, And(r2, right)))
      case other =>
        other :: Nil
    }
  }

  def cnfExtract(expression: Expression): Seq[Expression] = {
    expression match {
      case And(left, right) =>
        cnfExtract(left) ++ cnfExtract(right)
      case Or(left @ Or(l2, r2), right) =>
        cnfExtract(Or(l2, Or(r2, right)))
      case other =>
        other :: Nil
    }
  }

  def splitDNFPredicates(condition: Expression) = dnfExtract(toDNF(condition))

  def splitCNFPredicates(condition: Expression) = cnfExtract(toCNF(condition))


  def splitConjunctivePredicates(condition: Expression): Seq[Expression] = {
    condition match {
      case And(cond1, cond2) =>
        splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
      case other => other :: Nil
    }
  }

  def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = {
    condition match {
      case Or(cond1, cond2) =>
        splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
      case other => other :: Nil
    }
  }

} 
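A brief, hypothetical usage sketch of PredicateUtil (the attribute names a, b, c are invented), showing how toDNF distributes And over Or and how splitConjunctivePredicates can be undone by folding the conjuncts back with And.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, Expression, Or}
import org.apache.spark.sql.simba.util.PredicateUtil

object PredicateUtilSketch {
  def main(args: Array[String]): Unit = {
    val a = UnresolvedAttribute("a")
    val b = UnresolvedAttribute("b")
    val c = UnresolvedAttribute("c")

    // (a OR b) AND c  becomes  (a AND c) OR (b AND c)
    println(PredicateUtil.toDNF(And(Or(a, b), c)))

    // Split a conjunction into its conjuncts, then fold them back with And.
    val conjuncts: Seq[Expression] = PredicateUtil.splitConjunctivePredicates(And(And(a, b), c))
    println(conjuncts)                 // roughly List('a, 'b, 'c)
    println(conjuncts.reduceLeft(And)) // roughly (('a && 'b) && 'c)
  }
}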
Example 9
Source File: ParameterBinderSuite.scala    From spark-sql-server   with Apache License 2.0
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.sql.SQLException

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder
import org.apache.spark.sql.server.service.ParamBinder
import org.apache.spark.sql.types._

class ParameterBinderSuite extends PlanTest {

  test("bind parameters") {
    val c0 = 'a.int
    val c1 = 'b.int
    val r1 = LocalRelation(c0, c1)

    val param1 = Literal(18, IntegerType)
    val lp1 = Filter(EqualTo(c0, ParameterPlaceHolder(1)), r1)
    val expected1 = Filter(EqualTo(c0, param1), r1)
    comparePlans(expected1, ParamBinder.bind(lp1, Map(1 -> param1)))

    val param2 = Literal(42, IntegerType)
    val lp2 = Filter(EqualTo(c0, ParameterPlaceHolder(300)), r1)
    val expected2 = Filter(EqualTo(c0, param2), r1)
    comparePlans(expected2, ParamBinder.bind(lp2, Map(300 -> param2)))

    val param3 = Literal(-1, IntegerType)
    val param4 = Literal(48, IntegerType)
    val lp3 = Filter(
      And(
        EqualTo(c0, ParameterPlaceHolder(1)),
        EqualTo(c1, ParameterPlaceHolder(2))
      ), r1)
    val expected3 = Filter(
      And(
        EqualTo(c0, param3),
        EqualTo(c1, param4)
      ), r1)
    comparePlans(expected3, ParamBinder.bind(lp3, Map(1 -> param3, 2 -> param4)))

    val errMsg1 = intercept[SQLException] {
      ParamBinder.bind(lp1, Map.empty)
    }.getMessage
    assert(errMsg1 == "Unresolved parameters found: $1")
    val errMsg2 = intercept[SQLException] {
      ParamBinder.bind(lp2, Map.empty)
    }.getMessage
    assert(errMsg2 == "Unresolved parameters found: $300")
    val errMsg3 = intercept[SQLException] {
      ParamBinder.bind(lp3, Map.empty)
    }.getMessage
    assert(errMsg3 == "Unresolved parameters found: $1, $2")
  }
}