org.apache.spark.sql.catalyst.expressions.And Scala Examples

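The following examples show how to use org.apache.spark.sql.catalyst.expressions.And. You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to the original project or source file. Note that the excerpts below are incomplete: many lines (including most closing braces) are missing, so they do not compile exactly as shown.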
Example 1
Source file: DeltaPushFilter.scala — from the connectors project (Apache License 2.0); 5 votes

import scala.collection.immutable.HashSet
import scala.collection.JavaConverters._

import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, SerializationUtilities}
import org.apache.hadoop.hive.ql.lib._
import org.apache.hadoop.hive.ql.parse.SemanticException
import org.apache.hadoop.hive.ql.plan.{ExprNodeColumnDesc, ExprNodeConstantDesc, ExprNodeGenericFuncDesc}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, EqualNullSafe, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, InSet, LessThan, LessThanOrEqual, Like, Literal, Not}

/**
 * Converts a serialized Hive partition-filter expression into Spark Catalyst expressions so
 * that it can be pushed down to Delta.
 *
 * NOTE(review): this copy of the source is incomplete — closing braces, the UDF list below,
 * some argument lists, and the final return value were lost. The NOTE(review) comments below
 * mark the visible gaps; verify against the upstream file before relying on this code.
 */
object DeltaPushFilter extends Logging {
  // NOTE(review): the elements of this array are missing from this copy, and the array is not
  // referenced anywhere in the visible code.
  lazy val supportedPushDownUDFs = Array(

  /**
   * Deserializes `hiveFilterExprSeriablized` with Hive's SerializationUtilities and walks the
   * resulting expression tree, mapping each supported comparison UDF to the equivalent
   * Catalyst predicate.
   *
   * @param hiveFilterExprSeriablized serialized Hive filter expression; may be null
   * @return an empty Seq when the input is null (the non-null branch's result expression is
   *         missing from this copy)
   * @throws RuntimeException for an unsupported function, or wrapping any exception raised
   *                          while walking the expression tree
   */
  def partitionFilterConverter(hiveFilterExprSeriablized: String): Seq[Expression] = {
    if (hiveFilterExprSeriablized != null) {
      val filterExpr = SerializationUtilities.deserializeExpression(hiveFilterExprSeriablized)
      // No rules are ever added to this map, so (per Hive's DefaultRuleDispatcher) every node
      // falls through to the default processor passed to the dispatcher below.
      val opRules = new java.util.LinkedHashMap[Rule, NodeProcessor]()
      // Called by the graph walker for each node: converts one Hive expression node into a
      // Catalyst Expression, or returns null for nodes it does not handle.
      val nodeProcessor = new NodeProcessor() {
        def process(nd: Node, stack: java.util.Stack[Node],
            procCtx: NodeProcessorCtx, nodeOutputs: Object*): Object = {
          nd match {
            // NOTE(review): the AND case below has no body in this copy; presumably it combined
            // the already-converted child results (nodeOutputs) with And — verify upstream.
            case e: ExprNodeGenericFuncDesc if FunctionRegistry.isOpAnd(e) =>
            case e: ExprNodeGenericFuncDesc =>
              // Comparisons may be written `col op const` or `const op col`; normalize so the
              // column descriptor comes first. NOTE(review): assumes the other operand is a
              // constant — otherwise the ExprNodeConstantDesc cast below throws.
              val (columnDesc, constantDesc) =
                if (nd.getChildren.get(0).isInstanceOf[ExprNodeColumnDesc]) {
                  (nd.getChildren.get(0), nd.getChildren.get(1))
                } else { (nd.getChildren.get(1), nd.getChildren.get(0)) }

              // NOTE(review): the argument to UnresolvedAttribute is missing in this copy;
              // presumably it is built from columnDesc's column name.
              val columnAttr = UnresolvedAttribute(
              val constantVal = Literal(constantDesc.asInstanceOf[ExprNodeConstantDesc].getValue)
              // Map the Hive UDF to the equivalent Catalyst predicate; the "not equal"
              // variants are expressed as Not(...) around the equality forms.
              nd.asInstanceOf[ExprNodeGenericFuncDesc].getGenericUDF match {
                case f: GenericUDFOPNotEqualNS =>
                  Not(EqualNullSafe(columnAttr, constantVal))
                case f: GenericUDFOPNotEqual =>
                  Not(EqualTo(columnAttr, constantVal))
                case f: GenericUDFOPEqualNS =>
                  EqualNullSafe(columnAttr, constantVal)
                case f: GenericUDFOPEqual =>
                  EqualTo(columnAttr, constantVal)
                case f: GenericUDFOPGreaterThan =>
                  GreaterThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrGreaterThan =>
                  GreaterThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFOPLessThan =>
                  LessThan(columnAttr, constantVal)
                case f: GenericUDFOPEqualOrLessThan =>
                  LessThanOrEqual(columnAttr, constantVal)
                case f: GenericUDFBridge if f.getUdfName.equals("like") =>
                  Like(columnAttr, constantVal)
                case f: GenericUDFIn =>
                  // NOTE(review): getChildren also contains the column node itself, and its
                  // elements are Hive descriptors rather than values; presumably a step that
                  // keeps only the constants' values was lost here — verify upstream.
                  val inConstantVals = nd.getChildren.asScala
                  InSet(columnAttr, HashSet() ++ inConstantVals)
                case _ =>
                  throw new RuntimeException(s"Unsupported func(${nd.getName}) " +
                    s"which can not be pushed down to delta")
            // Nodes that are not generic function calls (e.g. column or constant nodes) yield
            // null as their output.
            case _ => null

      // Route every visited node to nodeProcessor and walk the graph.
      val disp = new DefaultRuleDispatcher(nodeProcessor, opRules, null)
      val ogw = new DefaultGraphWalker(disp)
      // NOTE(review): topNodes is never populated in this copy, so the walk has no start
      // nodes; presumably a line adding filterExpr was lost — verify upstream.
      val topNodes = new java.util.ArrayList[Node]()
      // Presumably filled by the walker with each node's processed result, keyed by node.
      val nodeOutput = new java.util.HashMap[Node, Object]()
      try {
        ogw.startWalking(topNodes, nodeOutput)
      } catch {
        // Any failure during the walk (including the unsupported-function error above) is
        // rethrown wrapped in a RuntimeException.
        case ex: Exception =>
          throw new RuntimeException(ex)
      // NOTE(review): the rest of this log message and the Seq returned for the non-null case
      // are missing from this copy.
      logInfo(s"converted partition filter expr:" +
    } else Seq.empty[org.apache.spark.sql.catalyst.expressions.Expression]
Example 2
Source file: DataSourceV2Strategy.scala — from the XSQL project (Apache License 2.0); 5 votes
package org.apache.spark.sql.execution.datasources.v2

import scala.collection.mutable

import org.apache.spark.sql.{sources, Strategy}
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition}
import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader

/**
 * Planner strategy that turns DataSourceV2 logical operators (batch relations, streaming
 * relations, writes, and a single-partition repartition) into physical plans.
 *
 * NOTE(review): this copy is incomplete — closing braces, several expressions, and the
 * `pushFilters` helper called from `apply` are missing. Verify against upstream.
 */
object DataSourceV2Strategy extends Strategy {

  /**
   * Computes the output attributes of a scan, given the expressions that must be able to
   * reference the relation's columns. Readers that implement SupportsPushDownRequiredColumns
   * get special handling when only a subset of the columns is referenced; any other reader
   * keeps the full relation output.
   */
  // TODO: nested column pruning.
  private def pruneColumns(
      reader: DataSourceReader,
      relation: DataSourceV2Relation,
      exprs: Seq[Expression]): Seq[AttributeReference] = {
    reader match {
      case r: SupportsPushDownRequiredColumns =>
        val requiredColumns = AttributeSet(exprs.flatMap(_.references))
        val neededOutput = relation.output.filter(requiredColumns.contains)
        if (neededOutput != relation.output) {
          // NOTE(review): the right-hand side of nameToAttr, and the tail of the attribute
          // mapping below, are missing in this copy; presumably the reader is also told which
          // columns are needed here — verify upstream.
          val nameToAttr =
          r.readSchema() {
            // We have to keep the attribute id during transformation.
            a => a.withExprId(nameToAttr(
        } else {
          // NOTE(review): the result of this else-branch is missing in this copy.

      case _ => relation.output

  /** Plans the supported DataSourceV2 logical operators; any other plan yields Nil. */
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
      val reader = relation.newReader()
      // `pushedFilters` will be pushed down and evaluated in the underlying data sources.
      // `postScanFilters` need to be evaluated after the scan.
      // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
      val (pushedFilters, postScanFilters) = pushFilters(reader, filters)
      // NOTE(review): the `|...` lines after the next statement are the body of a multi-line
      // log message whose opening and closing lines are missing in this copy.
      val output = pruneColumns(reader, relation, project ++ postScanFilters)
           |Pushing operators to ${relation.source.getClass}
           |Pushed Filters: ${pushedFilters.mkString(", ")}
           |Post-Scan Filters: ${postScanFilters.mkString(",")}
           |Output: ${output.mkString(", ")}

      val scan = DataSourceV2ScanExec(
        output, relation.source, relation.options, pushedFilters, reader)

      // Filters the source could not fully handle are AND-ed into a single condition.
      val filterCondition = postScanFilters.reduceLeftOption(And)
      // NOTE(review): the start of this expression is missing in this copy; presumably it
      // wraps `scan` in a FilterExec when filterCondition is defined.
      val withFilter =, scan)).getOrElse(scan)

      // always add the projection, which will produce unsafe rows required by some operators
      ProjectExec(project, withFilter) :: Nil

    case r: StreamingDataSourceV2Relation =>
      // ensure there is a projection, which will produce unsafe rows required by some operators
      // NOTE(review): the opening of the enclosing projection call is missing in this copy.
        DataSourceV2ScanExec(r.output, r.source, r.options, r.pushedFilters, r.reader)) :: Nil

    case WriteToDataSourceV2(writer, query) =>
      WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil

    case AppendData(r: DataSourceV2Relation, query, _) =>
      WriteToDataSourceV2Exec(r.newWriter(), planLater(query)) :: Nil

    case WriteToContinuousDataSource(writer, query) =>
      WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil

    // A repartition to one partition over a continuous stream becomes a continuous coalesce.
    case Repartition(1, false, child) =>
      // NOTE(review): the close of this collectFirst (presumably converting the Option to a
      // Boolean) is missing in this copy.
      val isContinuous = child.collectFirst {
        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r

      if (isContinuous) {
        ContinuousCoalesceExec(1, planLater(child)) :: Nil
      } else {
        // NOTE(review): the result of this else-branch is missing in this copy.

    case _ => Nil
Example 3
Source file: SemiJoinSuite.scala — from the spark1.52 project (Apache License 2.0); 5 votes
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
/**
 * Runs the same left-semi-join cases against three physical operators: LeftSemiJoinHash,
 * BroadcastLeftSemiJoinHash and LeftSemiJoinBNL.
 *
 * NOTE(review): this copy is incomplete — the data passed to createDataFrame, several
 * closing braces/parentheses, some checkAnswer2 arguments, and the opening of the final
 * testLeftSemiJoin invocation are missing.
 */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input: columns (a: Int, b: Double), with duplicate rows and nulls in each column.
  // NOTE(review): the wrapper that turns these Rows into the first createDataFrame argument
  // is missing in this copy.
  private lazy val left = ctx.createDataFrame(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input: columns (c: Int, d: Double), also with duplicates and nulls.
  private lazy val right = ctx.createDataFrame(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition `a = c AND b < d`: one equi-join key plus one non-equi predicate.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  /**
   * Registers one test per left-semi-join operator for the given inputs and condition.
   *
   * @param testName       prefix for the registered test names
   * @param leftRows       left input (by-name, evaluated inside the test)
   * @param rightRows      right input (by-name, evaluated inside the test)
   * @param condition      join condition (by-name, evaluated inside the test)
   * @param expectedAnswer expected rows of the semi join
   */
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Builds an inner Join on `condition`, presumably to extract its equi-join keys and the
    // remaining condition. NOTE(review): the rest of this helper is missing in this copy.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))

    // NOTE(review): in the two hash-join tests, if extractJoinParts() returns None the
    // foreach body never runs and the test passes without checking anything.
    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          // NOTE(review): part of this call (between the plan builder and sortAnswers) is
          // missing in this copy.
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            sortAnswers = true)

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          // NOTE(review): the expected-answer argument appears to be missing in this copy.
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            sortAnswers = true)

    // The nested-loop variant takes the whole condition rather than extracted keys.
    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        // NOTE(review): the expected-answer argument appears to be missing in this copy.
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          sortAnswers = true)
    // NOTE(review): the opening of a testLeftSemiJoin(...) invocation and its other arguments
    // are missing; only the test name and the expected rows remain.
    "basic test",
      (2, 1.0),
      (2, 1.0)
Example 4
Source file: BatchEvalPythonExecSuite.scala — from the Spark-2.3.1 project (Apache License 2.0); 5 votes
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

/**
 * Checks where FilterExec predicates end up relative to BatchEvalPythonExec in the executed
 * plan. The tests only inspect `queryExecution.executedPlan`, so the dummy Python UDF is
 * planned but never executed.
 */
class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  /** Starts the shared session, then registers the dummy UDF used by every test. */
  override def beforeAll(): Unit = {
    // SharedSQLContext creates the session in its own beforeAll; `spark` is unusable until
    // that has run.
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  /** Unregisters the dummy UDF, then lets SharedSQLContext tear the session down. */
  override def afterAll(): Unit = {
    try {
      spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF"))
    } finally {
      // Always release the shared session, even if unregistering fails.
      super.afterAll()
    }
  }

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")

    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    }
    assert(qualifiedPlanNodes.size == 1)
  }
}

/**
 * Placeholder Python function for tests only. It carries no command bytes, interpreter path or
 * version, so it can be planned but never actually executed.
 */
class DummyUDF extends PythonFunction(
  pythonExec = "",
  pythonVer = "",
  command = new Array[Byte](0),
  envVars = Map(("", "")).asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  accumulator = null,
  broadcastVars = null)

/** Deterministic, batched SQL UDF returning a boolean, backed by a [[DummyUDF]]. */
class MyDummyPythonUDF extends UserDefinedPythonFunction(
  func = new DummyUDF,
  name = "dummyUDF",
  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
  dataType = BooleanType,
  udfDeterministic = true)
Example 5
Source file: GroupAnd.scala — from the mimir project (Apache License 2.0); 5 votes
package mimir.exec.spark.udf

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{ DataType, BooleanType }
import org.apache.spark.sql.catalyst.expressions.{ AttributeReference, Literal, And }

/**
 * Aggregate that folds a boolean column with logical AND.
 *
 * The buffer starts at `true` (so an empty group evaluates to true), and each input row is
 * combined with [[And]]. Spark's three-valued logic therefore applies: the result is false if
 * any row is false, otherwise null if any row is null, otherwise true.
 */
case class GroupAnd(child: org.apache.spark.sql.catalyst.expressions.Expression) extends DeclarativeAggregate {
  override def children: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = child :: Nil
  // And(true, null) is null, so a null result is possible exactly when the input is nullable.
  override def nullable: Boolean = child.nullable
  // Return data type.
  override def dataType: DataType = BooleanType
  // And only accepts boolean operands; reject other input types up front with a clear error
  // instead of accepting any orderable type.
  override def checkInputDataTypes(): TypeCheckResult =
    if (child.dataType == BooleanType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"function group_and requires boolean input, not ${child.dataType.simpleString}")
    }
  private lazy val group_and = AttributeReference("group_and", BooleanType)()
  override lazy val aggBufferAttributes: Seq[AttributeReference] = group_and :: Nil
  // `true` is AND's identity element.
  override lazy val initialValues: Seq[Literal] = Seq(
    Literal.create(true, BooleanType)
  )
  override lazy val updateExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    And(group_and, child)
  )
  // `.left` / `.right` are DeclarativeAggregate's helpers for the two buffers being merged.
  override lazy val mergeExpressions: Seq[org.apache.spark.sql.catalyst.expressions.Expression] = Seq(
    And(group_and.left, group_and.right)
  )
  override lazy val evaluateExpression: AttributeReference = group_and
}
Example 6
Source file: SemiJoinSuite.scala — from the BigDatalog project (Apache License 2.0); 5 votes
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

/**
 * Runs the same left-semi-join cases against three physical operators: LeftSemiJoinHash,
 * BroadcastLeftSemiJoinHash and LeftSemiJoinBNL.
 *
 * NOTE(review): this copy is incomplete — the data passed to createDataFrame, several
 * closing braces/parentheses, some checkAnswer2 arguments, and the opening of the final
 * testLeftSemiJoin invocation are missing.
 */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  // Left input: columns (a: Int, b: Double), with duplicate rows and nulls in each column.
  // NOTE(review): the wrapper that turns these Rows into the first createDataFrame argument
  // is missing in this copy.
  private lazy val left = sqlContext.createDataFrame(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  // Right input: columns (c: Int, d: Double), also with duplicates and nulls.
  private lazy val right = sqlContext.createDataFrame(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Join condition `a = c AND b < d`: one equi-join key plus one non-equi predicate.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  /**
   * Registers one test per left-semi-join operator for the given inputs and condition.
   *
   * @param testName       prefix for the registered test names
   * @param leftRows       left input (by-name, evaluated inside the test)
   * @param rightRows      right input (by-name, evaluated inside the test)
   * @param condition      join condition (by-name, evaluated inside the test)
   * @param expectedAnswer expected rows of the semi join
   */
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Builds an inner Join on `condition`, presumably to extract its equi-join keys and the
    // remaining condition. NOTE(review): the rest of this helper is missing in this copy.
    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))

    // NOTE(review): in the two hash-join tests, if extractJoinParts() returns None the
    // foreach body never runs and the test passes without checking anything.
    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          // NOTE(review): part of this call (between the plan builder and sortAnswers) is
          // missing in this copy.
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            sortAnswers = true)

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          // NOTE(review): the expected-answer argument appears to be missing in this copy.
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            sortAnswers = true)

    // The nested-loop variant takes the whole condition rather than extracted keys.
    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        // NOTE(review): the expected-answer argument appears to be missing in this copy.
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          sortAnswers = true)

    // NOTE(review): the opening of a testLeftSemiJoin(...) invocation and its other arguments
    // are missing; only the test name and the expected rows remain.
    "basic test",
      (2, 1.0),
      (2, 1.0)
Example 7
Source file: SimbaOptimizer.scala — from the Simba project (Apache License 2.0); 5 votes
package org.apache.spark.sql.simba

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{And, Expression, PredicateHelper}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkOptimizer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.simba.plans.SpatialJoin

/**
 * Spark's optimizer with one extra batch, "SpatialJoinPushDown", appended after all of
 * SparkOptimizer's own batches. It applies PushPredicateThroughSpatialJoin until the plan
 * stops changing, for at most 100 iterations.
 */
class SimbaOptimizer(catalog: SessionCatalog,
                     conf: SQLConf,
                     experimentalMethods: ExperimentalMethods)
  extends SparkOptimizer(catalog, conf, experimentalMethods) {
  override def batches: Seq[Batch] = super.batches :+
    Batch("SpatialJoinPushDown", FixedPoint(100), PushPredicateThroughSpatialJoin)
}

/**
 * Optimizer rule that pushes predicates through [[SpatialJoin]] nodes.
 *
 *  - A Filter directly above a SpatialJoin is split into conjuncts. Those that reference only
 *    the left (or only the right) input become a Filter on that input; the rest are merged
 *    into the join condition.
 *  - Conjuncts of a SpatialJoin's own condition that reference only one input are moved into a
 *    Filter on that input.
 *
 * As in Spark's PushPredicateThroughJoin, only the deterministic prefix of a conjunct list is
 * ever pushed to one side.
 *
 * NOTE(review): filtering an input before a nearest-neighbour style join changes which rows
 * are candidate neighbours; confirm this rewrite is valid for every SpatialJoin type.
 */
object PushPredicateThroughSpatialJoin extends Rule[LogicalPlan] with PredicateHelper {
  /**
   * Splits `condition` into (left-only, right-only, common) conjuncts.
   *
   * Conjuncts are tested against the left output first, so one that references no attributes
   * is assigned to the left. Everything from the first non-deterministic conjunct onwards stays
   * in the common part: moving a later conjunct ahead of it would change the rows that the
   * non-deterministic predicate is evaluated on.
   */
  private def split(
      condition: Seq[Expression],
      left: LogicalPlan,
      right: LogicalPlan): (Seq[Expression], Seq[Expression], Seq[Expression]) = {
    val (pushDownCandidates, nonDeterministic) = condition.span(_.deterministic)
    val (leftEvaluateCondition, rest) =
      pushDownCandidates.partition(_.references subsetOf left.outputSet)
    val (rightEvaluateCondition, commonCondition) =
      rest.partition(_.references subsetOf right.outputSet)

    (leftEvaluateCondition, rightEvaluateCondition, commonCondition ++ nonDeterministic)
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // push the where condition down into join filter
    case Filter(filterCondition, SpatialJoin(left, right, joinType, joinCondition)) =>
      val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
        split(splitConjunctivePredicates(filterCondition), left, right)

      val newLeft = leftFilterConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight =
        rightFilterConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
      SpatialJoin(newLeft, newRight, joinType, newJoinCond)

    // push down the join filter into sub query scanning if applicable
    case SpatialJoin(left, right, joinType, joinCondition) =>
      val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
        split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)

      val newLeft = leftJoinConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
      val newRight =
        rightJoinConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
      val newJoinCond = commonJoinCondition.reduceLeftOption(And)

      SpatialJoin(newLeft, newRight, joinType, newJoinCond)
  }
}
Source file: PredicateUtil.scala — from the Simba project (Apache License 2.0); 5 votes
package org.apache.spark.sql.simba.util

import org.apache.spark.sql.catalyst.expressions.{Expression, And, Or}

/**
 * Helpers that rewrite Catalyst boolean predicates into disjunctive/conjunctive normal form
 * and split them into their top-level terms.
 *
 * Only And and Or are rewritten; every other expression (including Not) is treated as an
 * atom, so no De Morgan rewriting takes place. Normal-form conversion can grow the
 * expression exponentially.
 */
object PredicateUtil {
  /**
   * Rewrites `condition` into disjunctive normal form (an Or of Ands) by distributing And over
   * Or.
   */
  def toDNF(condition: Expression): Expression = condition match {
    case Or(left, right) =>
      Or(toDNF(left), toDNF(right))
    case And(left, right) =>
      val dnfLeft = toDNF(left)
      val dnfRight = toDNF(right)
      // Distribute over an Or on the left operand first (otherwise over one on the right),
      // then normalize again, since each new conjunction may still contain an Or.
      (dnfLeft, dnfRight) match {
        case (Or(l, r), _) => toDNF(Or(And(l, dnfRight), And(r, dnfRight)))
        case (_, Or(l, r)) => toDNF(Or(And(dnfLeft, l), And(dnfLeft, r)))
        case _ => And(dnfLeft, dnfRight)
      }
    case atom => atom
  }

  /**
   * Rewrites `condition` into conjunctive normal form (an And of Ors) by distributing Or over
   * And.
   */
  def toCNF(condition: Expression): Expression = condition match {
    case And(left, right) =>
      And(toCNF(left), toCNF(right))
    case Or(left, right) =>
      val cnfLeft = toCNF(left)
      val cnfRight = toCNF(right)
      // Mirror image of toDNF: the left operand's And is distributed over first.
      (cnfLeft, cnfRight) match {
        case (And(l, r), _) => toCNF(And(Or(l, cnfRight), Or(r, cnfRight)))
        case (_, And(l, r)) => toCNF(And(Or(cnfLeft, l), Or(cnfLeft, r)))
        case _ => Or(cnfLeft, cnfRight)
      }
    case atom => atom
  }

  /**
   * Splits a DNF expression into its disjuncts. A conjunction whose left operand is itself a
   * conjunction is rotated to the right until its left operand is not a conjunction.
   */
  def dnfExtract(expression: Expression): Seq[Expression] = expression match {
    case Or(left, right) =>
      dnfExtract(left) ++ dnfExtract(right)
    case And(And(l2, r2), right) =>
      dnfExtract(And(l2, And(r2, right)))
    case other =>
      other :: Nil
  }

  /**
   * Splits a CNF expression into its conjuncts. A disjunction whose left operand is itself a
   * disjunction is rotated to the right until its left operand is not a disjunction.
   */
  def cnfExtract(expression: Expression): Seq[Expression] = expression match {
    case And(left, right) =>
      cnfExtract(left) ++ cnfExtract(right)
    case Or(Or(l2, r2), right) =>
      cnfExtract(Or(l2, Or(r2, right)))
    case other =>
      other :: Nil
  }

  /** Converts `condition` to DNF and returns its disjuncts. */
  def splitDNFPredicates(condition: Expression): Seq[Expression] = dnfExtract(toDNF(condition))

  /** Converts `condition` to CNF and returns its conjuncts. */
  def splitCNFPredicates(condition: Expression): Seq[Expression] = cnfExtract(toCNF(condition))

  /** Flattens nested Ands into their operands, without any normal-form conversion. */
  def splitConjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
    case And(cond1, cond2) =>
      splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
    case other => other :: Nil
  }

  /** Flattens nested Ors into their operands, without any normal-form conversion. */
  def splitDisjunctivePredicates(condition: Expression): Seq[Expression] = condition match {
    case Or(cond1, cond2) =>
      splitDisjunctivePredicates(cond1) ++ splitDisjunctivePredicates(cond2)
    case other => other :: Nil
  }
}
Example 9
Source file: ParameterBinderSuite.scala — from the spark-sql-server project (Apache License 2.0); 5 votes
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.sql.SQLException

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder
import org.apache.spark.sql.server.service.ParamBinder
import org.apache.spark.sql.types._

class ParameterBinderSuite extends PlanTest {

  test("bind parameters") {
    val c0 = '
    val c1 = '
    val r1 = LocalRelation(c0, c1)

    val param1 = Literal(18, IntegerType)
    val lp1 = Filter(EqualTo(c0, ParameterPlaceHolder(1)), r1)
    val expected1 = Filter(EqualTo(c0, param1), r1)
    comparePlans(expected1, ParamBinder.bind(lp1, Map(1 -> param1)))

    val param2 = Literal(42, IntegerType)
    val lp2 = Filter(EqualTo(c0, ParameterPlaceHolder(300)), r1)
    val expected2 = Filter(EqualTo(c0, param2), r1)
    comparePlans(expected2, ParamBinder.bind(lp2, Map(300 -> param2)))

    val param3 = Literal(-1, IntegerType)
    val param4 = Literal(48, IntegerType)
    val lp3 = Filter(
        EqualTo(c0, ParameterPlaceHolder(1)),
        EqualTo(c1, ParameterPlaceHolder(2))
      ), r1)
    val expected3 = Filter(
        EqualTo(c0, param3),
        EqualTo(c1, param4)
      ), r1)
    comparePlans(expected3, ParamBinder.bind(lp3, Map(1 -> param3, 2 -> param4)))

    val errMsg1 = intercept[SQLException] {
      ParamBinder.bind(lp1, Map.empty)
    assert(errMsg1 == "Unresolved parameters found: $1")
    val errMsg2 = intercept[SQLException] {
      ParamBinder.bind(lp2, Map.empty)
    assert(errMsg2 == "Unresolved parameters found: $300")
    val errMsg3 = intercept[SQLException] {
      ParamBinder.bind(lp3, Map.empty)
    assert(errMsg3 == "Unresolved parameters found: $1, $2")