org.apache.spark.sql.catalyst.plans.Inner Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.plans.Inner. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: ReorderAssociativeOperatorSuite.scala — from drizzle-spark (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[ReorderAssociativeOperator]] rule: the constant operands of chains of
 * `+` and `*` are expected to be folded into a single literal, while the nondeterministic
 * `Rand(0)` chain is expected to be left untouched.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Executor that runs only the rule under test, a single time. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val input = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(input.analyze)

    // Each folded expression keeps the auto-generated name of the original expression.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    // Grouping on `t1.a + 1` / `t2.a + 1` and summing them: the rule must not change the plan.
    val t1Plus1 = "t1.a".attr + 1
    val t2Plus1 = "t2.a".attr + 1
    val input = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
      .groupBy(t1Plus1, t2Plus1)((t1Plus1 + t2Plus1).as("col"))

    comparePlans(Optimize.execute(input.analyze), input.analyze)
  }
}
Example 2
Source File: ReorderAssociativeOperatorSuite.scala — from XSQL (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[ReorderAssociativeOperator]] rule: the constant operands of chains of
 * `+` and `*` are expected to be folded into a single literal, while the nondeterministic
 * `Rand(0)` chain is expected to be left untouched.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Executor that runs only the rule under test, a single time. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val input = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(input.analyze)

    // Each folded expression keeps the auto-generated name of the original expression.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    // Grouping on `t1.a + 1` / `t2.a + 1` and summing them: the rule must not change the plan.
    val t1Plus1 = "t1.a".attr + 1
    val t2Plus1 = "t2.a".attr + 1
    val input = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
      .groupBy(t1Plus1, t2Plus1)((t1Plus1 + t2Plus1).as("col"))

    comparePlans(Optimize.execute(input.analyze), input.analyze)
  }
}
Example 3
Source File: ReorderAssociativeOperatorSuite.scala — from sparkoscope (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[ReorderAssociativeOperator]] rule: the constant operands of chains of
 * `+` and `*` are expected to be folded into a single literal, while the nondeterministic
 * `Rand(0)` chain is expected to be left untouched.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Executor that runs only the rule under test, a single time. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val input = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(input.analyze)

    // Each folded expression keeps the auto-generated name of the original expression.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    // Grouping on `t1.a + 1` / `t2.a + 1` and summing them: the rule must not change the plan.
    val t1Plus1 = "t1.a".attr + 1
    val t2Plus1 = "t2.a".attr + 1
    val input = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
      .groupBy(t1Plus1, t2Plus1)((t1Plus1 + t2Plus1).as("col"))

    comparePlans(Optimize.execute(input.analyze), input.analyze)
  }
}
Example 4
Source File: ExtractJoinConditionsSuite.scala — from carbondata (Apache License 2.0), 5 votes
package org.apache.carbondata.view.plans

import org.apache.carbondata.mv.dsl.Plans._
import org.apache.carbondata.view.testutil.ModularPlanTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.{Inner, _}

/**
 * Tests for `extractJoinConditions` on modularized plans. The expected conditions are read
 * from the `Join` node of an equivalent analyzed logical plan. If that plan does not have the
 * expected shape, the test fails with the plan in the message rather than a bare MatchError.
 */
class ExtractJoinConditionsSuite extends ModularPlanTest {
  val testRelation0 = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation1 = LocalRelation('d.int)
  val testRelation2 = LocalRelation('b.int, 'c.int, 'e.int)

  test("join only") {
    val left = testRelation0.where('a === 1)
    val right = testRelation1
    val originalQuery =
      left.join(right, condition = Some("d".attr === "b".attr || "d".attr === "c".attr)).analyze
    val modularPlan = originalQuery.modularize
    val extracted =
      modularPlan.extractJoinConditions(modularPlan.children(0), modularPlan.children(1))

    val correctAnswer = originalQuery match {
      case logical.Join(
          logical.Filter(_, MatchLocalRelation(_, _)),
          MatchLocalRelation(_, _),
          Inner,
          Some(joinCond)) =>
        Seq(joinCond)
      case other =>
        fail(s"Unexpected analyzed plan shape:\n$other")
    }

    compareExpressions(correctAnswer, extracted)
  }

  test("join and filter") {
    val left = testRelation0.where('b === 2).subquery('l)
    val right = testRelation2.where('b === 2).subquery('r)
    val originalQuery =
      left.join(right, condition = Some("r.b".attr === 2 && "l.c".attr === "r.c".attr)).analyze
    val modularPlan = originalQuery.modularize
    val extracted =
      modularPlan.extractJoinConditions(modularPlan.children(0), modularPlan.children(1))

    // Reference plan joins on `l.c = r.c` only: the single-relation predicate `r.b = 2`
    // is not expected among the extracted join conditions.
    val originalQuery1 =
      left.join(right, condition = Some("l.c".attr === "r.c".attr)).analyze

    val correctAnswer = originalQuery1 match {
      case logical.Join(
          logical.Filter(_, MatchLocalRelation(_, _)),
          logical.Filter(_, MatchLocalRelation(_, _)),
          Inner,
          Some(joinCond)) =>
        Seq(joinCond)
      case other =>
        fail(s"Unexpected analyzed plan shape:\n$other")
    }

    compareExpressions(correctAnswer, extracted)
  }
}
Example 5
Source File: ReorderAssociativeOperatorSuite.scala — from multi-tenancy-spark (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[ReorderAssociativeOperator]] rule: the constant operands of chains of
 * `+` and `*` are expected to be folded into a single literal, while the nondeterministic
 * `Rand(0)` chain is expected to be left untouched.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Executor that runs only the rule under test, a single time. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val input = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(input.analyze)

    // Each folded expression keeps the auto-generated name of the original expression.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    // Grouping on `t1.a + 1` / `t2.a + 1` and summing them: the rule must not change the plan.
    val t1Plus1 = "t1.a".attr + 1
    val t2Plus1 = "t2.a".attr + 1
    val input = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
      .groupBy(t1Plus1, t2Plus1)((t1Plus1 + t2Plus1).as("col"))

    comparePlans(Optimize.execute(input.analyze), input.analyze)
  }
}
Example 6
Source File: SemiJoinSuite.scala — from spark1.52 (Apache License 2.0), 5 votes
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
// Semi-join test suite.
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Equality on a = c combined with the non-equality predicate b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  /**
   * Registers three tests named after `testName`. Each runs the semi join of `leftRows` and
   * `rightRows` on `condition` with LeftSemiJoinHash, BroadcastLeftSemiJoinHash or
   * LeftSemiJoinBNL, and compares the result (sorted) with `expectedAnswer`.
   */
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Extracts equi-join keys from an inner join on `condition`. If none are found, the
    // enclosing test fails. Previously it was silently skipped, so the hash-join tests
    // would pass without running anything.
    def extractJoinParts(): ExtractEquiJoinKeys.ReturnType = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join).getOrElse(
        fail(s"$testName: could not extract equi-join keys from condition $condition"))
    }

    test(s"$testName using LeftSemiJoinHash") {
      val (_, leftKeys, rightKeys, boundCondition, _, _) = extractJoinParts()
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          EnsureRequirements(left.sqlContext).apply(
            LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      val (_, leftKeys, rightKeys, boundCondition, _, _) = extractJoinParts()
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  // Test the left semi join.
  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
}
Example 7
Source File: ResolveHintsSuite.scala — from Spark-2.3.1 (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._

/**
 * Analyzer tests for resolving [[UnresolvedHint]] nodes (here, the MAPJOIN hint) into
 * [[ResolvedHint]] broadcast markers.
 */
class ResolveHintsSuite extends AnalysisTest {
  import org.apache.spark.sql.catalyst.analysis.TestRelations._

  /** `plan` wrapped in a broadcast `ResolvedHint`. */
  private def broadcastHinted(plan: LogicalPlan): ResolvedHint =
    ResolvedHint(plan, HintInfo(broadcast = true))

  test("invalid hints should be ignored") {
    checkAnalysis(
      UnresolvedHint("some_random_hint_that_does_not_exist", Seq("TaBlE"), table("TaBlE")),
      testRelation,
      caseSensitive = false)
  }

  test("case-sensitive or insensitive parameters") {
    // (hint parameter, caseSensitive, whether the hint is expected to apply to "TaBlE")
    val cases = Seq(
      ("TaBlE", false, true),
      ("table", false, true),
      ("TaBlE", true, true),
      ("table", true, false))

    cases.foreach { case (hintParam, sensitive, applies) =>
      val expected: LogicalPlan =
        if (applies) broadcastHinted(testRelation) else testRelation
      checkAnalysis(
        UnresolvedHint("MAPJOIN", Seq(hintParam), table("TaBlE")),
        expected,
        caseSensitive = sensitive)
    }
  }

  test("multiple broadcast hint aliases") {
    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("table", "table2"), table("table").join(table("table2"))),
      Join(broadcastHinted(testRelation), broadcastHinted(testRelation2), Inner, None),
      caseSensitive = false)
  }

  test("do not traverse past existing broadcast hints") {
    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("table"), broadcastHinted(table("table").where('a > 1))),
      broadcastHinted(testRelation.where('a > 1)).analyze,
      caseSensitive = false)
  }

  test("should work for subqueries") {
    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").as("tableAlias")),
      broadcastHinted(testRelation),
      caseSensitive = false)

    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("tableAlias"), table("table").subquery('tableAlias)),
      broadcastHinted(testRelation),
      caseSensitive = false)

    // Negative case: a hint naming the underlying table must not match once it is aliased.
    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("table"), table("table").as("tableAlias")),
      testRelation,
      caseSensitive = false)
  }

  test("do not traverse past subquery alias") {
    val aliased = table("table").where('a > 1).subquery('tableAlias)
    checkAnalysis(
      UnresolvedHint("MAPJOIN", Seq("table"), aliased),
      testRelation.where('a > 1).analyze,
      caseSensitive = false)
  }

  test("should work for CTE") {
    val cteQuery =
      """
        |WITH ctetable AS (SELECT * FROM table WHERE a > 1)
        |SELECT  * FROM ctetable
      """.stripMargin
    checkAnalysis(
      CatalystSqlParser.parsePlan(cteQuery),
      testRelation.where('a > 1).select('a).select('a).analyze,
      caseSensitive = false)
  }
}
Example 8
Source File: ReorderAssociativeOperatorSuite.scala — from Spark-2.3.1 (Apache License 2.0), 5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the [[ReorderAssociativeOperator]] rule: the constant operands of chains of
 * `+` and `*` are expected to be folded into a single literal, while the nondeterministic
 * `Rand(0)` chain is expected to be left untouched.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Executor that runs only the rule under test, a single time. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val input = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(input.analyze)

    // Each folded expression keeps the auto-generated name of the original expression.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    // Grouping on `t1.a + 1` / `t2.a + 1` and summing them: the rule must not change the plan.
    val t1Plus1 = "t1.a".attr + 1
    val t2Plus1 = "t2.a".attr + 1
    val input = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
      .groupBy(t1Plus1, t2Plus1)((t1Plus1 + t2Plus1).as("col"))

    comparePlans(Optimize.execute(input.analyze), input.analyze)
  }
}
Example 9
Source File: SemiJoinSuite.scala — from BigDatalog (Apache License 2.0), 5 votes
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

/** Checks the left semi join operators against a small hand-computed answer. */
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  // Equality on a = c combined with the non-equality predicate b < d.
  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  /**
   * Registers three tests named after `testName`. Each runs the semi join of `leftRows` and
   * `rightRows` on `condition` with LeftSemiJoinHash, BroadcastLeftSemiJoinHash or
   * LeftSemiJoinBNL, and compares the result (sorted) with `expectedAnswer`.
   */
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    // Extracts equi-join keys from an inner join on `condition`. If none are found, the
    // enclosing test fails. Previously it was silently skipped, so the hash-join tests
    // would pass without running anything.
    def extractJoinParts(): ExtractEquiJoinKeys.ReturnType = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join).getOrElse(
        fail(s"$testName: could not extract equi-join keys from condition $condition"))
    }

    test(s"$testName using LeftSemiJoinHash") {
      val (_, leftKeys, rightKeys, boundCondition, _, _) = extractJoinParts()
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          EnsureRequirements(left.sqlContext).apply(
            LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      val (_, leftKeys, rightKeys, boundCondition, _, _) = extractJoinParts()
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
}
Example 10
Source File: ExtractRangeJoinKeysWithEquality.scala — from bdg-sequila (Apache License 2.0), 5 votes
package org.biodatageeks.sequila.rangejoins.common

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner


      //println(condition.head)
      if (condition.size!=0 && joinType == Inner) {
        condition.head match {
          case And(And(EqualTo(l3, r3), LessThanOrEqual(l1, g1)), LessThanOrEqual(l2, g2)) =>
            Some((joinType,
              getKeys(l1, l2, g1, g2, l3, r3, left, right),
              left, right))
          case And(And(EqualTo(l3, r3), GreaterThanOrEqual(g1, l1)), LessThanOrEqual(l2, g2)) =>
            Some((joinType,
              getKeys(l1, l2, g1, g2, l3, r3, left, right),
              left, right))
          case And(And(EqualTo(l3, r3), LessThanOrEqual(l1, g1)), GreaterThanOrEqual(g2, l2)) =>
            Some((joinType,
              getKeys(l1, l2, g1, g2, l3, r3, left, right),
              left, right))
          case And(And(EqualTo(l3, r3), GreaterThanOrEqual(g1, l1)), GreaterThanOrEqual(g2, l2)) =>
            Some((joinType,
              getKeys(l1, l2, g1, g2, l3, r3, left, right),
              left, right))
          case _ => None
        }
      } else {
        None
      }
    case _ =>
      None
  }

  /**
   * Assigns the operands of a range join condition with an extra equality to the left or
   * right join side.
   *
   * The operands come from a condition shaped like `l3 = r3 AND l1 <= g1 AND l2 <= g2`. The
   * `>=` forms are passed in with their arguments swapped so they fit this shape.
   *
   * Only `g1`, `l1` and `l3` are checked with `canEvaluate`. `g2`, `l2` and `r3` are assumed
   * to belong to the opposite side; this is not verified here.
   *
   * @return `List(leftStart, leftEnd, rightStart, rightEnd, leftEquality, rightEquality)`
   */
  def getKeys(
      l1: Expression,
      l2: Expression,
      g1: Expression,
      g2: Expression,
      l3: Expression,
      r3: Expression,
      left: LogicalPlan,
      right: LogicalPlan): Seq[Expression] = {
    // Upper bounds: the side g1 can be evaluated on decides which side each of g1/g2 ends.
    val (leftEnd, rightEnd) = if (canEvaluate(g1, right)) (g2, g1) else (g1, g2)
    // Lower bounds: the side l1 can be evaluated on decides which side each of l1/l2 starts.
    val (leftStart, rightStart) = if (canEvaluate(l1, left)) (l1, l2) else (l2, l1)
    // Equality operands: l3 goes left if it can be evaluated on the left side.
    val (leftEquality, rightEquality) = if (canEvaluate(l3, left)) (l3, r3) else (r3, l3)

    List(leftStart, leftEnd, rightStart, rightEnd, leftEquality, rightEquality)
  }
}