org.apache.spark.sql.catalyst.plans.PlanTest Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.plans.PlanTest. Each example is taken from an open source project; the source file and license are noted above the code.
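All of these suites follow the same basic pattern: extend PlanTest, define a RuleExecutor[LogicalPlan] whose batches contain the rules under test, run it over an analyzed input plan built with the Catalyst DSL, and check the result against a hand-built expected plan with comparePlans, which normalizes expression IDs before comparing. The minimal sketch below illustrates that pattern; the suite name is hypothetical, and CollapseProject (the Spark 2.x name of the rule called ProjectCollapsing in older releases) merely stands in for whatever rule a real suite exercises.

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// Minimal PlanTest sketch (hypothetical suite); CollapseProject is only a
// stand-in for the rule a real suite would exercise.
class MyRuleSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("MyRule", Once, CollapseProject) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("adjacent projects are collapsed") {
    // Build the input plan with the Catalyst DSL and analyze it first.
    val query = testRelation.select('a, 'b).select('a)
    val optimized = Optimize.execute(query.analyze)

    // The expected plan is built by hand and analyzed the same way.
    val correctAnswer = testRelation.select('a).analyze

    // comparePlans (from PlanTest) normalizes expression IDs before comparing.
    comparePlans(optimized, correctAnswer)
  }
}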
Example 1
Source File: ComputeCurrentTimeSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
} 
Example 2
Source File: ParameterBinderSuite.scala    From spark-sql-server   with Apache License 2.0
package org.apache.spark.sql.server.service.postgresql.protocol.v3

import java.sql.SQLException

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.server.catalyst.expressions.ParameterPlaceHolder
import org.apache.spark.sql.server.service.ParamBinder
import org.apache.spark.sql.types._

class ParameterBinderSuite extends PlanTest {

  test("bind parameters") {
    val c0 = 'a.int
    val c1 = 'b.int
    val r1 = LocalRelation(c0, c1)

    val param1 = Literal(18, IntegerType)
    val lp1 = Filter(EqualTo(c0, ParameterPlaceHolder(1)), r1)
    val expected1 = Filter(EqualTo(c0, param1), r1)
    comparePlans(expected1, ParamBinder.bind(lp1, Map(1 -> param1)))

    val param2 = Literal(42, IntegerType)
    val lp2 = Filter(EqualTo(c0, ParameterPlaceHolder(300)), r1)
    val expected2 = Filter(EqualTo(c0, param2), r1)
    comparePlans(expected2, ParamBinder.bind(lp2, Map(300 -> param2)))

    val param3 = Literal(-1, IntegerType)
    val param4 = Literal(48, IntegerType)
    val lp3 = Filter(
      And(
        EqualTo(c0, ParameterPlaceHolder(1)),
        EqualTo(c1, ParameterPlaceHolder(2))
      ), r1)
    val expected3 = Filter(
      And(
        EqualTo(c0, param3),
        EqualTo(c1, param4)
      ), r1)
    comparePlans(expected3, ParamBinder.bind(lp3, Map(1 -> param3, 2 -> param4)))

    val errMsg1 = intercept[SQLException] {
      ParamBinder.bind(lp1, Map.empty)
    }.getMessage
    assert(errMsg1 == "Unresolved parameters found: $1")
    val errMsg2 = intercept[SQLException] {
      ParamBinder.bind(lp2, Map.empty)
    }.getMessage
    assert(errMsg2 == "Unresolved parameters found: $300")
    val errMsg3 = intercept[SQLException] {
      ParamBinder.bind(lp3, Map.empty)
    }.getMessage
    assert(errMsg3 == "Unresolved parameters found: $1, $2")
  }
} 
Example 3
Source File: AggregateOptimizeSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      ReplaceDistinctWithAggregate,
      RemoveLiteralFromGroupExpressions) :: Nil
  }

  test("replace distinct with aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }

  test("remove literals in grouping expression") {
    val input = LocalRelation('a.int, 'b.int)

    val query =
      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(query)

    val correctAnswer = input.groupBy('a)(sum('b))

    comparePlans(optimized, correctAnswer)
  }
} 
Example 4
Source File: ProjectCollapsingSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation.select(
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
      'b).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two nondeterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(Rand(20).as('rand2))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(Rand(20).as('rand2)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand), 'a)
      .select(('a + 1).as('a_plus_1))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(('a + 1).as('a_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 5
Source File: SetOperationPushDownSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class SetOperationPushDownSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubQueries) ::
      Batch("Union Pushdown", Once,
        SetOperationPushDown,
        SimplifyFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
  val testUnion = Union(testRelation, testRelation2)
  val testIntersect = Intersect(testRelation, testRelation2)
  val testExcept = Except(testRelation, testRelation2)

  test("union/intersect/except: filter to each side") {
    val unionQuery = testUnion.where('a === 1)
    val intersectQuery = testIntersect.where('b < 10)
    val exceptQuery = testExcept.where('c >= 5)

    val unionOptimized = Optimize.execute(unionQuery.analyze)
    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
    val exceptOptimized = Optimize.execute(exceptQuery.analyze)

    val unionCorrectAnswer =
      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
    val intersectCorrectAnswer =
      Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
    val exceptCorrectAnswer =
      Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze

    comparePlans(unionOptimized, unionCorrectAnswer)
    comparePlans(intersectOptimized, intersectCorrectAnswer)
    comparePlans(exceptOptimized, exceptCorrectAnswer)
  }

  test("union: project to each side") {
    val unionQuery = testUnion.select('a)
    val unionOptimized = Optimize.execute(unionQuery.analyze)
    val unionCorrectAnswer =
      Union(testRelation.select('a), testRelation2.select('d)).analyze
    comparePlans(unionOptimized, unionCorrectAnswer)
  }

  test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
    val intersectQuery = testIntersect.select('b, 'c)
    val exceptQuery = testExcept.select('a, 'b, 'c)

    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
    val exceptOptimized = Optimize.execute(exceptQuery.analyze)

    comparePlans(intersectOptimized, intersectQuery.analyze)
    comparePlans(exceptOptimized, exceptQuery.analyze)
  }
} 
Example 6
Source File: ColumnPruningSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

class ColumnPruningSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Column pruning", FixedPoint(100),
      ColumnPruning) :: Nil
  }

  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation('a.int, 'b.array(StringType))

    val query = input.generate(Explode('b), join = false).analyze

    val optimized = Optimize.execute(query)

    val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))

    val query =
      input
        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
        .select('a, 'explode)
        .analyze

    val optimized = Optimize.execute(query)

    val correctAnswer =
      input
        .select('a, 'c)
        .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
        .select('a, 'explode)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val query =
      input
        .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
        .select(('explode + 1).as("result"))
        .analyze

    val optimized = Optimize.execute(query)

    val correctAnswer =
      input
        .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
        .select(('explode + 1).as("result"))
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("Column pruning for Project on Sort") {
    val input = LocalRelation('a.int, 'b.string, 'c.double)

    val query = input.orderBy('b.asc).select('a).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer = input.select('a, 'b).orderBy('b.asc).select('a).analyze

    comparePlans(optimized, correctAnswer)
  }

  // todo: add more tests for column pruning
} 
Example 7
Source File: CombiningLimitsSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 8
Source File: LikeSimplificationSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 9
Source File: ConvertToLocalRelationSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 10
Source File: SimplifyCaseConversionExpressionsSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class SimplifyCaseConversionExpressionsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 11
Source File: AnalysisTest.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf}

trait AnalysisTest extends PlanTest {

  val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
    val caseSensitiveConf = new SimpleCatalystConf(true)
    val caseInsensitiveConf = new SimpleCatalystConf(false)

    val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
    val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)

    caseSensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation)
    caseInsensitiveCatalog.registerTable(TableIdentifier("TaBlE"), TestRelations.testRelation)

    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
      override val extendedResolutionRules = EliminateSubQueries :: Nil
    } ->
    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
      override val extendedResolutionRules = EliminateSubQueries :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.execute(inputPlan)
    analyzer.checkAnalysis(actualPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    analyzer.checkAnalysis(analyzer.execute(inputPlan))
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }
    assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
      s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
        s"actually we get ${e.getMessage}")
  }
} 
Example 12
Source File: AggregateOptimizeSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

class AggregateOptimizeSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      FoldablePropagation,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 13
Source File: ReorderAssociativeOperatorSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReorderAssociativeOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 14
Source File: SimplifyCastsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class SimplifyCastsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  test("non-nullable element array to nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('a.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable element to non-nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    // Though cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not
    // allowed, here we just ensure that `SimplifyCasts` rule respect the plan.
    comparePlans(optimized, plan, checkAnalysis = false)
  }

  test("non-nullable value map to nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
    val plan = input.select('m.cast(MapType(StringType, StringType, true))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('m.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable value map to non-nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
    val plan = input.select('m.cast(MapType(StringType, StringType, false))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    // Though cast from `MapType(StringType, StringType, true)` to
    // `MapType(StringType, StringType, false)` is not allowed, here we just ensure that
    // `SimplifyCasts` rule respect the plan.
    comparePlans(optimized, plan, checkAnalysis = false)
  }
} 
Example 15
Source File: CombineConcatsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class CombineConcatsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  def str(s: String): Literal = Literal(s)
  def binary(s: String): Literal = Literal(s.getBytes)

  test("combine nested Concat exprs") {
    assertEquivalent(
      Concat(
        Concat(str("a") :: str("b") :: Nil) ::
        str("c") ::
        str("d") ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        str("a") ::
        Concat(str("b") :: str("c") :: Nil) ::
        str("d") ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        str("a") ::
        str("b") ::
        Concat(str("c") :: str("d") :: Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        Concat(
          str("a") ::
          Concat(
            str("b") ::
            Concat(str("c") :: str("d") :: Nil) ::
            Nil) ::
          Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
  }

  test("combine string and binary exprs") {
    assertEquivalent(
      Concat(
        Concat(str("a") :: str("b") :: Nil) ::
        Concat(binary("c") :: binary("d") :: Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
  }
} 
Example 16
Source File: SimplifyConditionalSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: Nil, None))
  }

  test("simplify CaseWhen, prune branches following a definite true") {
    assertEquivalent(
      CaseWhen(normalBranch :: unreachableBranch ::
        unreachableBranch :: nullBranch ::
        trueBranch :: normalBranch ::
        Nil,
        None),
      CaseWhen(normalBranch :: trueBranch :: Nil, None))
  }
} 
Example 17
Source File: FiltersSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.client

import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
  private val shim = new Shim_v0_13

  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(Collections.singletonList(varCharCol))

  filterTest("string filter",
    (a("stringcol", StringType) > Literal("test")) :: Nil,
    "stringcol > \"test\"")

  filterTest("string filter backwards",
    (Literal("test") > a("stringcol", StringType)) :: Nil,
    "\"test\" > stringcol")

  filterTest("int filter",
    (a("intcol", IntegerType) === Literal(1)) :: Nil,
    "intcol = 1")

  filterTest("int filter backwards",
    (Literal(1) === a("intcol", IntegerType)) :: Nil,
    "1 = intcol")

  filterTest("int and string filter",
    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
    "1 = intcol and \"a\" = strcol")

  filterTest("skip varchar",
    (Literal("") === a("varchar", StringType)) :: Nil,
    "")

  filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning",
    (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) ::
      (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil,
    """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")

  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
    test(name) {
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
        val converted = shim.convertFilters(testTable, filters)
        if (converted != result) {
          fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'")
        }
      }
    }
  }

  test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") {
    import org.apache.spark.sql.catalyst.dsl.expressions._
    Seq(true, false).foreach { enabled =>
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) {
        val filters =
          (Literal(1) === a("intcol", IntegerType) ||
            Literal(2) === a("intcol", IntegerType)) :: Nil
        val converted = shim.convertFilters(testTable, filters)
        if (enabled) {
          assert(converted == "(1 = intcol or 2 = intcol)")
        } else {
          assert(converted.isEmpty)
        }
      }
    }
  }

  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
} 
Example 18
Source File: EliminateSerializationSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

case class OtherTuple(_1: Int, _2: Int)

class EliminateSerializationSuite extends PlanTest {
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Serialization", FixedPoint(100),
        EliminateSerialization) :: Nil
  }

  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
  implicit private def intEncoder = ExpressionEncoder[Int]()

  test("back to back serialization") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    comparePlans(optimized, expected)
  }

  test("back to back serialization with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("back to back serialization in AppendColumns") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze

    val optimized = Optimize.execute(plan)

    val expected = AppendColumnsWithObject(
      func.asInstanceOf[Any => Any],
      productEncoder[(Int, Int)].namedExpressions,
      intEncoder.namedExpressions,
      input).analyze

    comparePlans(optimized, expected)
  }

  test("back to back serialization in AppendColumns with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze

    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 19
Source File: EliminateSubqueryAliasesSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil
  }

  private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
    Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
  }

  test("eliminate top level subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = SubqueryAlias("a", input)
    comparePlans(afterOptimization(query), input)
  }

  test("eliminate mid-tree subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral, SubqueryAlias("a", input))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }

  test("eliminate multiple subqueries") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral,
      SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input))))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }
} 
Example 20
Source File: RewriteSubquerySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class RewriteSubquerySuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
      Batch("Rewrite Subquery", FixedPoint(1),
        RewritePredicateSubquery,
        ColumnPruning,
        CollapseProject,
        RemoveRedundantProject) :: Nil
  }

  test("Column pruning after rewriting predicate subquery") {
    val relation = LocalRelation('a.int, 'b.int)
    val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)

    val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = relation
      .select('a)
      .join(relInSubquery.select('x), LeftSemi, Some('a === 'x))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

} 
Example 21
Source File: SimplifyStringCaseConversionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class SimplifyStringCaseConversionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 22
Source File: EliminateMapObjectsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class EliminateMapObjectsSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = {
      Batch("EliminateMapObjects", FixedPoint(50),
        NullPropagation,
        SimplifyCasts,
        EliminateMapObjects) :: Nil
    }
  }

  implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
  implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]()

  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
    val intObjType = ObjectType(classOf[Array[Int]])
    val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val intQuery = intInput.deserialize[Array[Int]].analyze
    val intOptimized = Optimize.execute(intQuery)
    val intExpected = DeserializeToObject(
      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
      AttributeReference("obj", intObjType, true)(), intInput)
    comparePlans(intOptimized, intExpected)

    val doubleObjType = ObjectType(classOf[Array[Double]])
    val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
    val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
    val doubleOptimized = Optimize.execute(doubleQuery)
    val doubleExpected = DeserializeToObject(
      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
      AttributeReference("obj", doubleObjType, true)(), doubleInput)
    comparePlans(doubleOptimized, doubleExpected)
  }
} 
Example 23
Source File: CombiningLimitsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 24
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 25
Source File: CollapseWindowSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    assert(analyzed.output === optimized.output)

    val correctAnswer = testRelation.window(Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }

  test("Don't collapse adjacent windows with dependent columns") {
    val query = testRelation
      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1)
      .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1)
      .analyze

    val expected = query.analyze
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected)
  }
} 
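As an extra sanity check (a sketch only, not part of the original suite; it assumes the suite's testRelation, a, partitionSpec1, orderSpec1 and Optimize are in scope, plus an import of the logical Window node), one can count the Window operators that survive the rewrite:

import org.apache.spark.sql.catalyst.plans.logical.Window

// Two adjacent windows over the same partition/order spec collapse into one.
val collapsed = Optimize.execute(
  testRelation
    .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
    .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
    .analyze)
assert(collapsed.collect { case w: Window => w }.size == 1)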
Example 26
Source File: BinaryComparisonSimplificationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubqueryAliases) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyBinaryComparison,
        PruneFilters) :: Nil
  }

  val nullableRelation = LocalRelation('a.int.withNullability(true))
  val nonNullableRelation = LocalRelation('a.int.withNullability(false))

  test("Preserve nullable exprs in general") {
    for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
      val plan = nullableRelation.where(e).analyze
      val actual = Optimize.execute(plan)
      val correctAnswer = plan
      comparePlans(actual, correctAnswer)
    }
  }

  test("Preserve non-deterministic exprs") {
    val plan = nonNullableRelation
      .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = plan
    comparePlans(actual, correctAnswer)
  }

  test("Nullable Simplification Primitive: <=>") {
    val plan = nullableRelation.select('a <=> 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
    comparePlans(actual, correctAnswer)
  }

  test("Non-Nullable Simplification Primitive") {
    val plan = nonNullableRelation
      .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation
      .select(
        Alias(TrueLiteral, "(a = a)")(),
        Alias(TrueLiteral, "(a <=> a)")(),
        Alias(TrueLiteral, "(a <= a)")(),
        Alias(TrueLiteral, "(a >= a)")(),
        Alias(FalseLiteral, "(a < a)")(),
        Alias(FalseLiteral, "(a > a)")())
      .analyze
    comparePlans(actual, correctAnswer)
  }

  test("Expression Normalization") {
    val plan = nonNullableRelation.where(
      'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
      DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
      .analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation.analyze
    comparePlans(actual, correctAnswer)
  }
} 
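Why the nullable cases above must be preserved: for a nullable column, a = a evaluates to NULL when a is NULL, so folding it to true would change query results, while the null-safe comparison <=> never returns NULL. A small, hypothetical demonstration using bare literals (not taken from the suite):

import org.apache.spark.sql.types.IntegerType

// Ordinary equality propagates NULL, so it cannot be folded for nullable inputs...
assert((Literal(null, IntegerType) === Literal(null, IntegerType)).eval() == null)
// ...whereas the null-safe comparison always yields a definite boolean.
assert((Literal(null, IntegerType) <=> Literal(null, IntegerType)).eval() == true)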
Example 27
Source File: OptimizerStructuralIntegrityCheckerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf


class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val newAttr = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(newAttr), child)
    }
  }

  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def batches: Seq[Batch] = Seq(newBatch) ++ super.batches
  }

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    assert(analyzed.resolved)
    val message = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
} 
Example 28
Source File: EliminateDistinctSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class EliminateDistinctSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", Once,
        EliminateDistinct) :: Nil
  }

  val testRelation = LocalRelation('a.int)

  test("Eliminate Distinct in Max") {
    val query = testRelation
      .select(maxDistinct('a).as('result))
      .analyze
    val answer = testRelation
      .select(max('a).as('result))
      .analyze
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)
  }

  test("Eliminate Distinct in Min") {
    val query = testRelation
      .select(minDistinct('a).as('result))
      .analyze
    val answer = testRelation
      .select(min('a).as('result))
      .analyze
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)
  }
} 
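The tests above compare whole plans; the sketch below (not part of the original suite; it assumes the suite's testRelation, Optimize and DSL imports) instead inspects the isDistinct flag that EliminateDistinct clears, since DISTINCT is a no-op for MAX and MIN:

import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression

val optimizedPlan = Optimize.execute(testRelation.select(maxDistinct('a).as('result)).analyze)

// The surviving aggregate expression should no longer be marked as distinct.
val distinctFlags = optimizedPlan.expressions.flatMap(_.collect {
  case ae: AggregateExpression => ae.isDistinct
})
assert(distinctFlags == Seq(false))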
Example 29
Source File: LikeSimplificationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, StringType}

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("null pattern") {
    val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
  }
} 
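A complementary negative check (a sketch, assuming the suite's testRelation and Optimize): once the prefix-only pattern has been rewritten to StartsWith, no Like expression should remain in the plan.

val simplified = Optimize.execute(testRelation.where('a like "abc%").analyze)
assert(simplified.expressions.flatMap(_.collect { case l: Like => l }).isEmpty)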
Example 30
Source File: PullupCorrelatedPredicatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class PullupCorrelatedPredicatesSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("PullupCorrelatedPredicates", Once,
        PullupCorrelatedPredicates) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.double)
  val testRelation2 = LocalRelation('c.int, 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
    val correlatedSubquery =
      testRelation2
        .where('b < 'd)
        .select('c)
    val outerQuery =
      testRelation
        .where(In('a, Seq(ListQuery(correlatedSubquery))))
        .select('a).analyze
    assert(outerQuery.resolved)

    val optimized = Optimize.execute(outerQuery)
    assert(optimized.resolved)
  }
} 
Example 31
Source File: ConvertToLocalRelationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
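A hypothetical extra test body in the same spirit (assuming the suite's Optimize and the DSL/InternalRow imports shown above): the rule should leave a bare LocalRelation whose rows already reflect the projection.

val relation = LocalRelation(LocalRelation('a.int).output, InternalRow(1) :: Nil)
val folded = Optimize.execute(
  relation.select((UnresolvedAttribute("a") + 2).as("a2")).analyze)

// No Project survives; the stored row has been rewritten from 1 to 1 + 2 = 3.
assert(folded.isInstanceOf[LocalRelation])
assert(folded.asInstanceOf[LocalRelation].data.head.getInt(0) == 3)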
Example 32
Source File: AnalysisTest.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
    val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf)
    catalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)
    catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
    catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
    catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.executeAndCheck(inputPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected override def comparePlans(
      plan1: LogicalPlan,
      plan2: LogicalPlan,
      checkAnalysis: Boolean = false): Unit = {
    // Plans built by analysis tests may not be fully resolved, so skip checkAnalysis here.
    super.comparePlans(plan1, plan2, checkAnalysis)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try analyzer.checkAnalysis(analysisAttempt) catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall(
        e.getMessage.toLowerCase(Locale.ROOT).contains)) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
} 
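AnalysisTest is a helper trait rather than a test suite of its own. A hypothetical example of a concrete suite built on top of it (MyAnalysisSuite and its test name are illustrative only; TestRelations.testRelation is assumed to expose an integer column 'a, as it does in Spark's own analysis tests):

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class MyAnalysisSuite extends AnalysisTest {

  test("temp view lookup ignores case under the case-insensitive analyzer") {
    // "TaBlE" was registered in makeAnalyzer above, so a differently-cased
    // reference should still resolve when caseSensitive = false.
    assertAnalysisSuccess(table("tAbLe").select('a), caseSensitive = false)
  }
}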
Example 33
Source File: AnalysisTest.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val conf = new SimpleCatalystConf(caseSensitive)
    val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
    catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.execute(inputPlan)
    analyzer.checkAnalysis(actualPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try analyzer.checkAnalysis(analysisAttempt) catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
} 
Example 34
Source File: ResolveInlineTablesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but the rows themselves have inconsistent lengths
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 35
Source File: ConvertToLocalRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 36
Source File: LikeSimplificationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 37
Source File: BinaryComparisonSimplificationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubqueryAliases) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyBinaryComparison,
        PruneFilters) :: Nil
  }

  val nullableRelation = LocalRelation('a.int.withNullability(true))
  val nonNullableRelation = LocalRelation('a.int.withNullability(false))

  test("Preserve nullable exprs in general") {
    for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
      val plan = nullableRelation.where(e).analyze
      val actual = Optimize.execute(plan)
      val correctAnswer = plan
      comparePlans(actual, correctAnswer)
    }
  }

  test("Preserve non-deterministic exprs") {
    val plan = nonNullableRelation
      .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = plan
    comparePlans(actual, correctAnswer)
  }

  test("Nullable Simplification Primitive: <=>") {
    val plan = nullableRelation.select('a <=> 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
    comparePlans(actual, correctAnswer)
  }

  test("Non-Nullable Simplification Primitive") {
    val plan = nonNullableRelation
      .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation
      .select(
        Alias(TrueLiteral, "(a = a)")(),
        Alias(TrueLiteral, "(a <=> a)")(),
        Alias(TrueLiteral, "(a <= a)")(),
        Alias(TrueLiteral, "(a >= a)")(),
        Alias(FalseLiteral, "(a < a)")(),
        Alias(FalseLiteral, "(a > a)")())
      .analyze
    comparePlans(actual, correctAnswer)
  }

  test("Expression Normalization") {
    val plan = nonNullableRelation.where(
      'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
      DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
      .analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation.analyze
    comparePlans(actual, correctAnswer)
  }
} 
Example 38
Source File: CollapseWindowSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.window(Seq(
        avg(b).as('avg_b),
        sum(b).as('sum_b),
        max(a).as('max_a),
        min(a).as('min_a)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }
} 
Example 39
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 40
Source File: CombiningLimitsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
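The folded value is always the smallest of the stacked limits, whichever position it appears in. A minimal sketch (assuming the suite's testRelation and Optimize) with the smallest limit in the middle:

val folded = Optimize.execute(testRelation.select('a).limit(8).limit(3).limit(12).analyze)

// min(min(8, 3), 12) = 3, so a single limit(3) remains after constant folding.
comparePlans(folded, testRelation.select('a).limit(3).analyze)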
Example 41
Source File: RemoveAliasOnlyProjectSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.MetadataBuilder

class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("RemoveAliasOnlyProject", FixedPoint(50), RemoveAliasOnlyProject) :: Nil
  }

  test("all expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b as 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("all expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a as 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("some expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are not Alias or Attribute") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b + 1).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output but with metadata") {
    val relation = LocalRelation('a.int, 'b.int)
    val metadata = new MetadataBuilder().putString("x", "y").build()
    val aliasWithMeta = Alias('a, "a")(explicitMetadata = Some(metadata))
    val query = relation.select(aliasWithMeta, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }
} 
Example 42
Source File: SimplifyStringCaseConversionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class SimplifyStringCaseConversionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 43
Source File: CollapseRepartitionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseRepartitionSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseRepartition", FixedPoint(10),
        CollapseRepartition) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two adjacent repartitions into one") {
    val query = testRelation
      .repartition(10)
      .repartition(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.repartition(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartition and repartitionBy into one") {
    val query = testRelation
      .repartition(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartitionBy and repartition into one") {
    val query = testRelation
      .distribute('a)(20)
      .repartition(10)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(10).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two adjacent repartitionBys into one") {
    val query = testRelation
      .distribute('b)(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
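A compact structural check (a sketch, assuming the suite's testRelation and Optimize, plus an import of the Repartition node): the two adjacent repartitions collapse into a single node carrying the outer partition count.

import org.apache.spark.sql.catalyst.plans.logical.Repartition

val collapsed = Optimize.execute(testRelation.repartition(10).repartition(20).analyze)
assert(collapsed.collect { case r: Repartition => r.numPartitions } == Seq(20))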
Example 44
Source File: EliminateSubqueryAliasesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil
  }

  private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
    Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
  }

  test("eliminate top level subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = SubqueryAlias("a", input, None)
    comparePlans(afterOptimization(query), input)
  }

  test("eliminate mid-tree subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral, SubqueryAlias("a", input, None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }

  test("eliminate multiple subqueries") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral,
      SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }
} 
Example 45
Source File: ReplaceOperatorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReplaceOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Replace Operators", FixedPoint(100),
        ReplaceDistinctWithAggregate,
        ReplaceExceptWithAntiJoin,
        ReplaceIntersectWithSemiJoin) :: Nil
  }

  test("replace Intersect with Left-semi Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Intersect(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Except with Left-anti Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Except(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Distinct with Aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }
} 
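A quick negative check to go with the plan comparisons above (a sketch, assuming the suite's Optimize and the DSL imports): after the rewrite there is no Intersect operator left in the plan.

val left = LocalRelation('a.int, 'b.int)
val right = LocalRelation('c.int, 'd.int)

val rewritten = Optimize.execute(Intersect(left, right).analyze)
assert(rewritten.collect { case i: Intersect => i }.isEmpty)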
Example 46
Source File: EliminateSerializationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor

case class OtherTuple(_1: Int, _2: Int)

class EliminateSerializationSuite extends PlanTest {
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Serialization", FixedPoint(100),
        EliminateSerialization) :: Nil
  }

  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
  implicit private def intEncoder = ExpressionEncoder[Int]()

  test("back to back serialization") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    comparePlans(optimized, expected)
  }

  test("back to back serialization with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("back to back serialization in AppendColumns") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze

    val optimized = Optimize.execute(plan)

    val expected = AppendColumnsWithObject(
      func.asInstanceOf[Any => Any],
      productEncoder[(Int, Int)].namedExpressions,
      intEncoder.namedExpressions,
      input).analyze

    comparePlans(optimized, expected)
  }

  test("back to back serialization in AppendColumns with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze

    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 47
Source File: OptimizeCodegenSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class OptimizeCodegenSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  test("Codegen only when the number of branches is small.") {
    assertEquivalent(
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())

    assertEquivalent(
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)))
  }

  test("Nested CaseWhen Codegen.") {
    assertEquivalent(
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
  }

  test("Multiple CaseWhen in one operator.") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }

  test("Multiple CaseWhen in different operators") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }
} 
Example 48
Source File: ComputeCurrentTimeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
} 
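A complementary check (a sketch, assuming the suite's Optimize object): equality of the two literals aside, no CurrentTimestamp expression should survive anywhere in the rewritten plan.

val replaced = Optimize.execute(
  Project(Seq(Alias(CurrentTimestamp(), "now")()), LocalRelation()).analyze)

// Every current_timestamp call has been evaluated into a Literal by the rule.
assert(replaced.expressions.flatMap(_.collect { case c: CurrentTimestamp => c }).isEmpty)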
Example 49
Source File: SimplifyConditionalSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
Example 50
Source File: SimplifyCastsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class SimplifyCastsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  test("non-nullable element array to nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('a.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable element to non-nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
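    // A cast that adds non-nullability is not removed, so the plan should come back unchanged.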
    comparePlans(optimized, plan)
  }

  test("non-nullable value map to nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
    val plan = input.select('m.cast(MapType(StringType, StringType, true))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('m.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable value map to non-nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
    val plan = input.select('m.cast(MapType(StringType, StringType, false))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 51
Source File: ReorderAssociativeOperatorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReorderAssociativeOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 52
Source File: AggregateOptimizeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      FoldablePropagation,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
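    // When every grouping key is a literal, the rule keeps a single Literal(0) instead of dropping
    // them all, so the query does not silently turn into a global aggregate.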
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val input = LocalRelation('a.int, 'b.int, 'c.int)
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 53
Source File: FiltersSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import java.util.Collections

import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.SparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._


class FiltersSuite extends SparkFunSuite with Logging with PlanTest {
  private val shim = new Shim_v0_13

  private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
  private val varCharCol = new FieldSchema()
  varCharCol.setName("varchar")
  varCharCol.setType(serdeConstants.VARCHAR_TYPE_NAME)
  testTable.setPartCols(Collections.singletonList(varCharCol))

  filterTest("string filter",
    (a("stringcol", StringType) > Literal("test")) :: Nil,
    "stringcol > \"test\"")

  filterTest("string filter backwards",
    (Literal("test") > a("stringcol", StringType)) :: Nil,
    "\"test\" > stringcol")

  filterTest("int filter",
    (a("intcol", IntegerType) === Literal(1)) :: Nil,
    "intcol = 1")

  filterTest("int filter backwards",
    (Literal(1) === a("intcol", IntegerType)) :: Nil,
    "1 = intcol")

  filterTest("int and string filter",
    (Literal(1) === a("intcol", IntegerType)) :: (Literal("a") === a("strcol", IntegerType)) :: Nil,
    "1 = intcol and \"a\" = strcol")

  filterTest("skip varchar",
    (Literal("") === a("varchar", StringType)) :: Nil,
    "")

  filterTest("SPARK-19912 String literals should be escaped for Hive metastore partition pruning",
    (a("stringcol", StringType) === Literal("p1\" and q=\"q1")) ::
      (Literal("p2\" and q=\"q2") === a("stringcol", StringType)) :: Nil,
    """stringcol = 'p1" and q="q1' and 'p2" and q="q2' = stringcol""")

  filterTest("SPARK-24879 null literals should be ignored for IN constructs",
    (a("intcol", IntegerType) in (Literal(1), Literal(null))) :: Nil,
    "(intcol = 1)")

  // Applying the predicate `x IN (NULL)` should return an empty set, but since this optimization
  // will be applied by Catalyst, this filter converter does not need to account for this.
  filterTest("SPARK-24879 IN predicates with only NULLs will not cause a NPE",
    (a("intcol", IntegerType) in Literal(null)) :: Nil,
    "")

  filterTest("typecast null literals should not be pushed down in simple predicates",
    (a("intcol", IntegerType) === Literal(null, IntegerType)) :: Nil,
    "")

  private def filterTest(name: String, filters: Seq[Expression], result: String) = {
    test(name) {
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") {
        val converted = shim.convertFilters(testTable, filters)
        if (converted != result) {
          fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'")
        }
      }
    }
  }

  test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") {
    import org.apache.spark.sql.catalyst.dsl.expressions._
    Seq(true, false).foreach { enabled =>
      withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) {
        val filters =
          (Literal(1) === a("intcol", IntegerType) ||
            Literal(2) === a("intcol", IntegerType)) :: Nil
        val converted = shim.convertFilters(testTable, filters)
        if (enabled) {
          assert(converted == "(1 = intcol or 2 = intcol)")
        } else {
          assert(converted.isEmpty)
        }
      }
    }
  }

  private def a(name: String, dataType: DataType) = AttributeReference(name, dataType)()
} 
Example 54
Source File: SchemaPruningTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.internal.SQLConf.NESTED_SCHEMA_PRUNING_ENABLED


private[sql] trait SchemaPruningTest extends PlanTest with BeforeAndAfterAll {
  private var originalConfSchemaPruningEnabled = false

  override protected def beforeAll(): Unit = {
    originalConfSchemaPruningEnabled = conf.nestedSchemaPruningEnabled
    conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, true)
    super.beforeAll()
  }

  override protected def afterAll(): Unit = {
    try {
      super.afterAll()
    } finally {
      conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, originalConfSchemaPruningEnabled)
    }
  }
} 
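A concrete suite picks this behavior up just by mixing the trait in. A hypothetical minimal example (the suite name and test are illustrative):

package org.apache.spark.sql.catalyst

// Hypothetical suite: SchemaPruningTest enables NESTED_SCHEMA_PRUNING_ENABLED in beforeAll()
// and restores the previous value in afterAll(), so every test can assume pruning is on.
class NestedSchemaPruningEnabledSuite extends SchemaPruningTest {
  test("nested schema pruning is enabled for the duration of the suite") {
    assert(conf.nestedSchemaPruningEnabled)
  }
}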
Example 55
Source File: ResolveLambdaVariablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{ArrayType, IntegerType}


class ResolveLambdaVariablesSuite extends PlanTest {
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._

  object Analyzer extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil
  }

  private val key = 'key.int
  private val values1 = 'values1.array(IntegerType)
  private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType)))
  private val data = LocalRelation(Seq(key, values1, values2))
  private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true)
  private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true)
  private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true)

  private def plan(e: Expression): LogicalPlan = data.select(e.as("res"))

  private def checkExpression(e1: Expression, e2: Expression): Unit = {
    comparePlans(Analyzer.execute(plan(e1)), plan(e2))
  }

  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))

  test("resolution - no op") {
    checkExpression(key, key)
  }

  test("resolution - simple") {
    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
    val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
    checkExpression(in, out)
  }

  test("resolution - nested") {
    val in = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
    val out = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
    checkExpression(in, out)
  }

  test("resolution - hidden") {
    val in = ArrayTransform(values1, key)
    val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true))
    checkExpression(in, out)
  }

  test("fail - name collisions") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("arguments should not have names that are semantically the same"))
  }

  test("fail - lambda arguments") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("does not match the number of arguments expected"))
  }
} 
Example 56
Source File: AnalysisTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI
import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)
    val catalog = new SessionCatalog(new InMemoryCatalog, FunctionRegistry.builtin, conf)
    catalog.createDatabase(
      CatalogDatabase("default", "", new URI("loc"), Map.empty),
      ignoreIfExists = false)
    catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
    catalog.createTempView("TaBlE2", TestRelations.testRelation2, overrideIfExists = true)
    catalog.createTempView("TaBlE3", TestRelations.testRelation3, overrideIfExists = true)
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.executeAndCheck(inputPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected override def comparePlans(
      plan1: LogicalPlan,
      plan2: LogicalPlan,
      checkAnalysis: Boolean = false): Unit = {
    // Analysis tests may have not been fully resolved, so skip checkAnalysis.
    super.comparePlans(plan1, plan2, checkAnalysis)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try analyzer.checkAnalysis(analysisAttempt) catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    if (!expectedErrors.map(_.toLowerCase(Locale.ROOT)).forall(
        e.getMessage.toLowerCase(Locale.ROOT).contains)) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
} 
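Concrete analysis suites are built on top of these helpers. A hypothetical minimal example (the suite name and test are illustrative; "TaBlE" is one of the temp views registered in makeAnalyzer above):

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.TableIdentifier

// Hypothetical suite: analysis of a registered temp view should succeed.
class MinimalAnalysisSuite extends AnalysisTest {
  test("a registered temp view resolves successfully") {
    assertAnalysisSuccess(UnresolvedRelation(TableIdentifier("TaBlE")))
  }
}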
Example 57
Source File: LookupFunctionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.net.URI

import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.internal.SQLConf

class LookupFunctionsSuite extends PlanTest {

  test("SPARK-23486: the functionExists for the Persistent function check") {
    val externalCatalog = new CustomInMemoryCatalog
    val conf = new SQLConf()
    val catalog = new SessionCatalog(externalCatalog, FunctionRegistry.builtin, conf)
    val analyzer = {
      catalog.createDatabase(
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)
    }

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedPersistentFunc = UnresolvedFunction("func", Seq.empty, false)
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedPersistentFunc, "call1")(), Alias(unresolvedPersistentFunc, "call2")(),
        Alias(unresolvedPersistentFunc, "call3")(), Alias(unresolvedRegisteredFunc, "call4")(),
        Alias(unresolvedRegisteredFunc, "call5")()),
      table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)
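    // The persistent function is referenced three times, but SPARK-23486 makes LookupFunctions
    // de-duplicate the lookups, so the external catalog should be consulted only once.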

    assert(externalCatalog.getFunctionExistsCalledTimes == 1)
    assert(analyzer.LookupFunctions.normalizeFuncName
      (unresolvedPersistentFunc.name).database == Some("default"))
  }

  test("SPARK-23486: the functionExists for the Registered function check") {
    val externalCatalog = new InMemoryCatalog
    val conf = new SQLConf()
    val customerFunctionReg = new CustomerFunctionRegistry
    val catalog = new SessionCatalog(externalCatalog, customerFunctionReg, conf)
    val analyzer = {
      catalog.createDatabase(
        CatalogDatabase("default", "", new URI("loc"), Map.empty),
        ignoreIfExists = false)
      new Analyzer(catalog, conf)
    }

    def table(ref: String): LogicalPlan = UnresolvedRelation(TableIdentifier(ref))
    val unresolvedRegisteredFunc = UnresolvedFunction("max", Seq.empty, false)
    val plan = Project(
      Seq(Alias(unresolvedRegisteredFunc, "call1")(), Alias(unresolvedRegisteredFunc, "call2")()),
      table("TaBlE"))
    analyzer.LookupFunctions.apply(plan)

    assert(customerFunctionReg.getIsRegisteredFunctionCalledTimes == 2)
    assert(analyzer.LookupFunctions.normalizeFuncName
      (unresolvedRegisteredFunc.name).database == Some("default"))
  }
}

class CustomerFunctionRegistry extends SimpleFunctionRegistry {

  private var isRegisteredFunctionCalledTimes: Int = 0

  override def functionExists(funcN: FunctionIdentifier): Boolean = synchronized {
    isRegisteredFunctionCalledTimes = isRegisteredFunctionCalledTimes + 1
    true
  }

  def getIsRegisteredFunctionCalledTimes: Int = isRegisteredFunctionCalledTimes
}

class CustomInMemoryCatalog extends InMemoryCatalog {

  private var functionExistsCalledTimes: Int = 0

  override def functionExists(db: String, funcName: String): Boolean = synchronized {
    functionExistsCalledTimes = functionExistsCalledTimes + 1
    true
  }

  def getFunctionExistsCalledTimes: Int = functionExistsCalledTimes
} 
Example 58
Source File: ConvertToLocalRelationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

  test("Filter on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: Nil)

    val filterAndProjectOnLocal = testRelation
      .select(UnresolvedAttribute("a").as("a1"), (UnresolvedAttribute("b") + 1).as("b1"))
      .where(LessThan(UnresolvedAttribute("b1"), Literal.create(6)))

    val optimized = Optimize.execute(filterAndProjectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }
} 
Example 59
Source File: PullupCorrelatedPredicatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class PullupCorrelatedPredicatesSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("PullupCorrelatedPredicates", Once,
        PullupCorrelatedPredicates) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.double)
  val testRelation2 = LocalRelation('c.int, 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
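    // 'b is not an output of testRelation2, so it resolves against the outer query,
    // making 'b < 'd the correlated predicate that the rule has to pull up.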
    val correlatedSubquery =
      testRelation2
        .where('b < 'd)
        .select('c)
    val outerQuery =
      testRelation
        .where(InSubquery(Seq('a), ListQuery(correlatedSubquery)))
        .select('a).analyze
    assert(outerQuery.resolved)

    val optimized = Optimize.execute(outerQuery)
    assert(optimized.resolved)
  }
} 
Example 60
Source File: LikeSimplificationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{BooleanType, StringType}

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("null pattern") {
    val originalQuery = testRelation.where('a like Literal(null, StringType)).analyze
    val optimized = Optimize.execute(originalQuery)
    comparePlans(optimized, testRelation.where(Literal(null, BooleanType)).analyze)
  }
} 
Example 61
Source File: EliminateDistinctSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class EliminateDistinctSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", Once,
        EliminateDistinct) :: Nil
  }

  val testRelation = LocalRelation('a.int)

  test("Eliminate Distinct in Max") {
    val query = testRelation
      .select(maxDistinct('a).as('result))
      .analyze
    val answer = testRelation
      .select(max('a).as('result))
      .analyze
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)
  }

  test("Eliminate Distinct in Min") {
    val query = testRelation
      .select(minDistinct('a).as('result))
      .analyze
    val answer = testRelation
      .select(min('a).as('result))
      .analyze
    assert(query != answer)
    comparePlans(Optimize.execute(query), answer)
  }
} 
Example 62
Source File: OptimizerStructuralIntegrityCheckerSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{EmptyFunctionRegistry, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf


class OptimizerStructuralIntegrityCheckerSuite extends PlanTest {

  object OptimizeRuleBreakSI extends Rule[LogicalPlan] {
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Project(projectList, child) =>
        val newAttr = UnresolvedAttribute("unresolvedAttr")
        Project(projectList ++ Seq(newAttr), child)
    }
  }

  object Optimize extends Optimizer(
    new SessionCatalog(
      new InMemoryCatalog,
      EmptyFunctionRegistry,
      new SQLConf())) {
    val newBatch = Batch("OptimizeRuleBreakSI", Once, OptimizeRuleBreakSI)
    override def defaultBatches: Seq[Batch] = Seq(newBatch) ++ super.defaultBatches
  }

  test("check for invalid plan after execution of rule") {
    val analyzed = Project(Alias(Literal(10), "attr")() :: Nil, OneRowRelation()).analyze
    assert(analyzed.resolved)
    val message = intercept[TreeNodeException[LogicalPlan]] {
      Optimize.execute(analyzed)
    }.getMessage
    val ruleName = OptimizeRuleBreakSI.ruleName
    assert(message.contains(s"After applying rule $ruleName in batch OptimizeRuleBreakSI"))
    assert(message.contains("the structural integrity of the plan is broken"))
  }
} 
Example 63
Source File: CollapseWindowSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    assert(analyzed.output === optimized.output)

    val correctAnswer = testRelation.window(Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }

  test("Don't collapse adjacent windows with dependent columns") {
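    // The second window consumes 'sum_a, which is produced by the first one,
    // so the two windows must not be merged.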
    val query = testRelation
      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1)
      .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1)
      .analyze

    val expected = query.analyze
    val optimized = Optimize.execute(query.analyze)
    comparePlans(optimized, expected)
  }
} 
Example 64
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
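    // A rewritten plan groups the distinct aggregates through an Expand feeding two Aggregates.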
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 65
Source File: CombiningLimitsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 66
Source File: EliminateMapObjectsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class EliminateMapObjectsSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = {
      Batch("EliminateMapObjects", FixedPoint(50),
        NullPropagation,
        SimplifyCasts,
        EliminateMapObjects) :: Nil
    }
  }

  implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]()
  implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]()

  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
    val intObjType = ObjectType(classOf[Array[Int]])
    val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val intQuery = intInput.deserialize[Array[Int]].analyze
    val intOptimized = Optimize.execute(intQuery)
    val intExpected = DeserializeToObject(
      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
      AttributeReference("obj", intObjType, true)(), intInput)
    comparePlans(intOptimized, intExpected)

    val doubleObjType = ObjectType(classOf[Array[Double]])
    val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
    val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
    val doubleOptimized = Optimize.execute(doubleQuery)
    val doubleExpected = DeserializeToObject(
      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
      AttributeReference("obj", doubleObjType, true)(), doubleInput)
    comparePlans(doubleOptimized, doubleExpected)
  }
} 
Example 67
Source File: SimplifyStringCaseConversionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class SimplifyStringCaseConversionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 68
Source File: RewriteSubquerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class RewriteSubquerySuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
      Batch("Rewrite Subquery", FixedPoint(1),
        RewritePredicateSubquery,
        ColumnPruning,
        CollapseProject,
        RemoveRedundantProject) :: Nil
  }

  test("Column pruning after rewriting predicate subquery") {
    val relation = LocalRelation('a.int, 'b.int)
    val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)

    val query = relation.where('a.in(ListQuery(relInSubquery.select('x)))).select('a)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = relation
      .select('a)
      .join(relInSubquery.select('x), LeftSemi, Some('a === 'x))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

} 
Example 69
Source File: EliminateSubqueryAliasesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil
  }

  private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
    Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
  }

  test("eliminate top level subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = SubqueryAlias("a", input)
    comparePlans(afterOptimization(query), input)
  }

  test("eliminate mid-tree subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral, SubqueryAlias("a", input))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }

  test("eliminate multiple subqueries") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral,
      SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input))))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }
} 
Example 70
Source File: PushProjectThroughUnionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class PushProjectThroughUnionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Optimizer Batch", FixedPoint(100),
      PushProjectionThroughUnion,
      FoldablePropagation) :: Nil
  }

  test("SPARK-25450 PushProjectThroughUnion rule uses the same exprId for project expressions " +
    "in each Union child, causing mistakes in constant propagation") {
    val testRelation1 = LocalRelation('a.string, 'b.int, 'c.string)
    val testRelation2 = LocalRelation('d.string, 'e.int, 'f.string)
    val query = testRelation1
      .union(testRelation2.select("bar".as("d"), 'e, 'f))
      .select('a.as("n"))
      .select('n, "dummy").analyze
    val optimized = Optimize.execute(query)
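    // Because the pushed-down projections get fresh exprIds per Union child, FoldablePropagation
    // can safely fold the second child's constant "bar", as the expected plan below shows.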

    val expected = testRelation1
      .select('a.as("n"))
      .select('n, "dummy")
      .union(testRelation2
        .select("bar".as("d"), 'e, 'f)
        .select("bar".as("n"))
        .select("bar".as("n"), "dummy")).analyze

    comparePlans(optimized, expected)
  }
} 
Example 71
Source File: EliminateSerializationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

case class OtherTuple(_1: Int, _2: Int)

class EliminateSerializationSuite extends PlanTest {
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Serialization", FixedPoint(100),
        EliminateSerialization) :: Nil
  }

  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
  implicit private def intEncoder = ExpressionEncoder[Int]()

  test("back to back serialization") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    comparePlans(optimized, expected)
  }

  test("back to back serialization with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("back to back serialization in AppendColumns") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze

    val optimized = Optimize.execute(plan)

    val expected = AppendColumnsWithObject(
      func.asInstanceOf[Any => Any],
      productEncoder[(Int, Int)].namedExpressions,
      intEncoder.namedExpressions,
      input).analyze

    comparePlans(optimized, expected)
  }

  test("back to back serialization in AppendColumns with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze

    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 72
Source File: ComputeCurrentTimeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
} 
Example 73
Source File: CombineConcatsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class CombineConcatsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  def str(s: String): Literal = Literal(s)
  def binary(s: String): Literal = Literal(s.getBytes)

  test("combine nested Concat exprs") {
    assertEquivalent(
      Concat(
        Concat(str("a") :: str("b") :: Nil) ::
        str("c") ::
        str("d") ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        str("a") ::
        Concat(str("b") :: str("c") :: Nil) ::
        str("d") ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        str("a") ::
        str("b") ::
        Concat(str("c") :: str("d") :: Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
    assertEquivalent(
      Concat(
        Concat(
          str("a") ::
          Concat(
            str("b") ::
            Concat(str("c") :: str("d") :: Nil) ::
            Nil) ::
          Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
  }

  test("combine string and binary exprs") {
    assertEquivalent(
      Concat(
        Concat(str("a") :: str("b") :: Nil) ::
        Concat(binary("c") :: binary("d") :: Nil) ::
        Nil),
      Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
  }
} 
Example 74
Source File: UpdateNullabilityInAttributeReferencesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class UpdateNullabilityInAttributeReferencesSuite extends PlanTest {

  object Optimizer extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Constant Folding", FixedPoint(10),
          NullPropagation,
          ConstantFolding,
          BooleanSimplification,
          SimplifyConditionals,
          SimplifyBinaryComparison,
          SimplifyExtractValueOps) ::
      Batch("UpdateAttributeReferences", Once,
        UpdateNullabilityInAttributeReferences) :: Nil
  }

  test("update nullability in AttributeReference")  {
    val rel = LocalRelation('a.long.notNull)
    // In the 'original' plan below, the Aggregate node produced by groupBy() holds a
    // nullable AttributeReference to `b`, because array indexing is a nullable expression.
    // After optimization the projected value becomes non-nullable, but the AttributeReference
    // is not updated to reflect this, so nullability has to be fixed up by the
    // `UpdateNullabilityInAttributeReferences` rule.
    val original = rel
      .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b")
      .groupBy($"b")("1")
    val expected = rel.select('a as "b").groupBy($"b")("1").analyze
    val optimized = Optimizer.execute(original.analyze)
    comparePlans(optimized, expected)
  }
} 
Example 75
Source File: SimplifyCastsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class SimplifyCastsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  test("non-nullable element array to nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('a.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable element to non-nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    // A cast from `ArrayType(IntegerType, true)` to `ArrayType(IntegerType, false)` is not
    // allowed; here we just ensure that the `SimplifyCasts` rule leaves such a plan unchanged.
    comparePlans(optimized, plan, checkAnalysis = false)
  }

  test("non-nullable value map to nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
    val plan = input.select('m.cast(MapType(StringType, StringType, true))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('m.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable value map to non-nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
    val plan = input.select('m.cast(MapType(StringType, StringType, false))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    // A cast from `MapType(StringType, StringType, true)` to
    // `MapType(StringType, StringType, false)` is not allowed; here we just ensure that the
    // `SimplifyCasts` rule leaves such a plan unchanged.
    comparePlans(optimized, plan, checkAnalysis = false)
  }
} 
Example 76
Source File: ReorderAssociativeOperatorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReorderAssociativeOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 77
Source File: AggregateOptimizeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

class AggregateOptimizeSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      FoldablePropagation,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val query = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 78
Source File: AnalysisTest.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val conf = new SimpleCatalystConf(caseSensitive)
    val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
    catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.execute(inputPlan)
    analyzer.checkAnalysis(actualPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try analyzer.checkAnalysis(analysisAttempt) catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
} 
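A trait like this is meant to be mixed into concrete analysis suites. A hypothetical usage sketch follows; the suite name and queries are illustrative and not part of the original project, and it assumes the suite lives in the same org.apache.spark.sql.catalyst.analysis package so that TestRelations is visible.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

// Illustrative only: a concrete suite built on the AnalysisTest helpers above.
class ExampleAnalysisSuite extends AnalysisTest {

  test("resolved projection passes analysis") {
    assertAnalysisSuccess(TestRelations.testRelation.select('a))
  }

  test("unknown column is reported") {
    assertAnalysisError(
      TestRelations.testRelation.select('nonExistingColumn),
      Seq("cannot resolve"))
  }
}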
Example 79
Source File: ResolveInlineTablesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but the rows themselves have different numbers of columns
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 80
Source File: ConvertToLocalRelationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 81
Source File: LikeSimplificationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into startsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 82
Source File: BinaryComparisonSimplificationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubqueryAliases) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyBinaryComparison,
        PruneFilters) :: Nil
  }

  val nullableRelation = LocalRelation('a.int.withNullability(true))
  val nonNullableRelation = LocalRelation('a.int.withNullability(false))

  test("Preserve nullable exprs in general") {
    for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
      val plan = nullableRelation.where(e).analyze
      val actual = Optimize.execute(plan)
      val correctAnswer = plan
      comparePlans(actual, correctAnswer)
    }
  }

  test("Preserve non-deterministic exprs") {
    val plan = nonNullableRelation
      .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = plan
    comparePlans(actual, correctAnswer)
  }

  test("Nullable Simplification Primitive: <=>") {
    val plan = nullableRelation.select('a <=> 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
    comparePlans(actual, correctAnswer)
  }

  test("Non-Nullable Simplification Primitive") {
    val plan = nonNullableRelation
      .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation
      .select(
        Alias(TrueLiteral, "(a = a)")(),
        Alias(TrueLiteral, "(a <=> a)")(),
        Alias(TrueLiteral, "(a <= a)")(),
        Alias(TrueLiteral, "(a >= a)")(),
        Alias(FalseLiteral, "(a < a)")(),
        Alias(FalseLiteral, "(a > a)")())
      .analyze
    comparePlans(actual, correctAnswer)
  }

  test("Expression Normalization") {
    val plan = nonNullableRelation.where(
      'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
      DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
      .analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation.analyze
    comparePlans(actual, correctAnswer)
  }
} 
Example 83
Source File: CollapseWindowSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    assert(analyzed.output === optimized.output)

    val correctAnswer = testRelation.window(Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }
} 
Example 84
Source File: RewriteDistinctAggregatesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 85
Source File: CombiningLimitsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 86
Source File: RemoveAliasOnlyProjectSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.MetadataBuilder

class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("RemoveAliasOnlyProject", FixedPoint(50), RemoveAliasOnlyProject) :: Nil
  }

  test("all expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b as 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("all expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a as 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("some expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are not Alias or Attribute") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b + 1).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output but with metadata") {
    val relation = LocalRelation('a.int, 'b.int)
    val metadata = new MetadataBuilder().putString("x", "y").build()
    val aliasWithMeta = Alias('a, "a")(explicitMetadata = Some(metadata))
    val query = relation.select(aliasWithMeta, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }
} 
Example 87
Source File: SimplifyStringCaseConversionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class SimplifyStringCaseConversionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 88
Source File: CollapseRepartitionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseRepartitionSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseRepartition", FixedPoint(10),
        CollapseRepartition) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two adjacent repartitions into one") {
    val query = testRelation
      .repartition(10)
      .repartition(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.repartition(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartition and repartitionBy into one") {
    val query = testRelation
      .repartition(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartitionBy and repartition into one") {
    val query = testRelation
      .distribute('a)(20)
      .repartition(10)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(10).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two adjacent repartitionBys into one") {
    val query = testRelation
      .distribute('b)(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 89
Source File: EliminateSubqueryAliasesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil
  }

  private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
    Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
  }

  test("eliminate top level subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = SubqueryAlias("a", input, None)
    comparePlans(afterOptimization(query), input)
  }

  test("eliminate mid-tree subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral, SubqueryAlias("a", input, None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }

  test("eliminate multiple subqueries") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral,
      SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }
} 
Example 90
Source File: ReplaceOperatorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReplaceOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Replace Operators", FixedPoint(100),
        ReplaceDistinctWithAggregate,
        ReplaceExceptWithAntiJoin,
        ReplaceIntersectWithSemiJoin) :: Nil
  }

  test("replace Intersect with Left-semi Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Intersect(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Except with Left-anti Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Except(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Distinct with Aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }
} 
Example 91
Source File: EliminateSerializationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor

case class OtherTuple(_1: Int, _2: Int)

class EliminateSerializationSuite extends PlanTest {
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Serialization", FixedPoint(100),
        EliminateSerialization) :: Nil
  }

  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
  implicit private def intEncoder = ExpressionEncoder[Int]()

  test("back to back serialization") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    comparePlans(optimized, expected)
  }

  test("back to back serialization with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("back to back serialization in AppendColumns") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze

    val optimized = Optimize.execute(plan)

    val expected = AppendColumnsWithObject(
      func.asInstanceOf[Any => Any],
      productEncoder[(Int, Int)].namedExpressions,
      intEncoder.namedExpressions,
      input).analyze

    comparePlans(optimized, expected)
  }

  test("back to back serialization in AppendColumns with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze

    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 92
Source File: OptimizeCodegenSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class OptimizeCodegenSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  test("Codegen only when the number of branches is small.") {
    assertEquivalent(
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())

    // With 100 branches (well above the small-branch case handled above) the expression is
    // expected to stay un-codegened, so both sides of the equivalence are identical.
    assertEquivalent(
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(List.fill(100)((TrueLiteral, Literal(1))), Literal(2)))
  }

  test("Nested CaseWhen Codegen.") {
    assertEquivalent(
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
  }

  test("Multiple CaseWhen in one operator.") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }

  test("Multiple CaseWhen in different operators") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }
} 
Example 93
Source File: ComputeCurrentTimeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
} 
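Both tests above repeat the same literal-collection loop over the optimized plan. A small helper along these lines (illustrative only, not part of the original file) keeps that logic in one place:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Collect every Literal value of type T from a plan, as the tests above do inline.
def literalValues[T](plan: LogicalPlan): Seq[T] = {
  val values = scala.collection.mutable.ArrayBuffer.empty[T]
  plan.transformAllExpressions { case e: Literal =>
    values += e.value.asInstanceOf[T]
    e
  }
  values.toSeq
}

// Usage, for example: val lits = literalValues[Long](plan); assert(lits.distinct.size == 1)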
Example 94
Source File: SimplifyConditionalSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
Example 95
Source File: SimplifyCastsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class SimplifyCastsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  test("non-nullable element array to nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('a.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable element to non-nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("non-nullable value map to nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
    val plan = input.select('m.cast(MapType(StringType, StringType, true))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('m.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable value map to non-nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
    val plan = input.select('m.cast(MapType(StringType, StringType, false))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 96
Source File: ReorderAssociativeOperatorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReorderAssociativeOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 97
Source File: AggregateOptimizeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      FoldablePropagation,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val query = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val input = LocalRelation('a.int, 'b.int, 'c.int)
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 98
Source File: AnalysisTest.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

trait AnalysisTest extends PlanTest {

  protected val caseSensitiveAnalyzer = makeAnalyzer(caseSensitive = true)
  protected val caseInsensitiveAnalyzer = makeAnalyzer(caseSensitive = false)

  private def makeAnalyzer(caseSensitive: Boolean): Analyzer = {
    val conf = new SimpleCatalystConf(caseSensitive)
    val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
    catalog.createTempView("TaBlE", TestRelations.testRelation, overrideIfExists = true)
    new Analyzer(catalog, conf) {
      override val extendedResolutionRules = EliminateSubqueryAliases :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.execute(inputPlan)
    analyzer.checkAnalysis(actualPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val analysisAttempt = analyzer.execute(inputPlan)
    try analyzer.checkAnalysis(analysisAttempt) catch {
      case a: AnalysisException =>
        fail(
          s"""
            |Failed to Analyze Plan
            |$inputPlan
            |
            |Partial Analysis
            |$analysisAttempt
          """.stripMargin, a)
    }
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val e = intercept[AnalysisException] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }

    if (!expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) {
      fail(
        s"""Exception message should contain the following substrings:
           |
           |  ${expectedErrors.mkString("\n  ")}
           |
           |Actual exception message:
           |
           |  ${e.getMessage}
         """.stripMargin)
    }
  }
} 
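AnalysisTest above defines only helpers and no tests of its own. The snippet below is a hedged sketch of how a concrete suite might drive those helpers; the suite name, the relation, and the expected "cannot resolve" substring are illustrative assumptions rather than code from the project above.

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// Illustrative sketch only: exercises the helpers defined in the AnalysisTest trait above.
class ExampleAnalysisSuite extends AnalysisTest {

  private val testRelation = LocalRelation('a.int, 'b.string)

  test("projection over known columns analyzes successfully") {
    assertAnalysisSuccess(testRelation.select('a, 'b))
  }

  test("projection over an unknown column fails analysis") {
    // The expected substring is an assumption about the analyzer's error message.
    assertAnalysisError(testRelation.select('c), Seq("cannot resolve"))
  }
}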
Example 99
Source File: ResolveInlineTablesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but the rows have inconsistent numbers of values
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 100
Source File: ConvertToLocalRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 101
Source File: LikeSimplificationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into StartsWith and EndsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc\\%def") || ('a like "abc%def"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a like "abc\\%def") ||
        (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def"))))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 102
Source File: BinaryComparisonSimplificationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._

class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubqueryAliases) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyBinaryComparison,
        PruneFilters) :: Nil
  }

  val nullableRelation = LocalRelation('a.int.withNullability(true))
  val nonNullableRelation = LocalRelation('a.int.withNullability(false))

  test("Preserve nullable exprs in general") {
    for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) {
      val plan = nullableRelation.where(e).analyze
      val actual = Optimize.execute(plan)
      val correctAnswer = plan
      comparePlans(actual, correctAnswer)
    }
  }

  test("Preserve non-deterministic exprs") {
    val plan = nonNullableRelation
      .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = plan
    comparePlans(actual, correctAnswer)
  }

  test("Nullable Simplification Primitive: <=>") {
    val plan = nullableRelation.select('a <=> 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze
    comparePlans(actual, correctAnswer)
  }

  test("Non-Nullable Simplification Primitive") {
    val plan = nonNullableRelation
      .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation
      .select(
        Alias(TrueLiteral, "(a = a)")(),
        Alias(TrueLiteral, "(a <=> a)")(),
        Alias(TrueLiteral, "(a <= a)")(),
        Alias(TrueLiteral, "(a >= a)")(),
        Alias(FalseLiteral, "(a < a)")(),
        Alias(FalseLiteral, "(a > a)")())
      .analyze
    comparePlans(actual, correctAnswer)
  }

  test("Expression Normalization") {
    val plan = nonNullableRelation.where(
      'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a &&
      DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a))
      .analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = nonNullableRelation.analyze
    comparePlans(actual, correctAnswer)
  }
} 
Example 103
Source File: CollapseWindowSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10),
        CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val query = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = query.analyze
    val optimized = Optimize.execute(analyzed)
    assert(analyzed.output === optimized.output)

    val correctAnswer = testRelation.window(Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    val query1 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimized1 = Optimize.execute(query1.analyze)
    val correctAnswer1 = query1.analyze

    comparePlans(optimized1, correctAnswer1)

    val query2 = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimized2 = Optimize.execute(query2.analyze)
    val correctAnswer2 = query2.analyze

    comparePlans(optimized2, correctAnswer2)
  }
} 
Example 104
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 105
Source File: CombiningLimitsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
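      // Note (not in the original source): the constant-folding batch below is presumably
      // needed because CombineLimits can leave a foldable expression over the two limit
      // values that still has to be reduced to a single literal.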
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyConditionals) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 106
Source File: RemoveAliasOnlyProjectSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.MetadataBuilder

class RemoveAliasOnlyProjectSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("RemoveAliasOnlyProject", FixedPoint(50), RemoveAliasOnlyProject) :: Nil
  }

  test("all expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b as 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("all expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a as 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, relation)
  }

  test("some expressions in project list are aliased child output but with different order") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('b as 'b, 'a).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are not Alias or Attribute") {
    val relation = LocalRelation('a.int, 'b.int)
    val query = relation.select('a as 'a, 'b + 1).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }

  test("some expressions in project list are aliased child output but with metadata") {
    val relation = LocalRelation('a.int, 'b.int)
    val metadata = new MetadataBuilder().putString("x", "y").build()
    val aliasWithMeta = Alias('a, "a")(explicitMetadata = Some(metadata))
    val query = relation.select(aliasWithMeta, 'b).analyze
    val optimized = Optimize.execute(query)
    comparePlans(optimized, query)
  }
} 
Example 107
Source File: SimplifyStringCaseConversionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._

class SimplifyStringCaseConversionSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 108
Source File: CollapseRepartitionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class CollapseRepartitionSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseRepartition", FixedPoint(10),
        CollapseRepartition) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two adjacent repartitions into one") {
    val query = testRelation
      .repartition(10)
      .repartition(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.repartition(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartition and repartitionBy into one") {
    val query = testRelation
      .repartition(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse repartitionBy and repartition into one") {
    val query = testRelation
      .distribute('a)(20)
      .repartition(10)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(10).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two adjacent repartitionBys into one") {
    val query = testRelation
      .distribute('b)(10)
      .distribute('a)(20)

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.distribute('a)(20).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 109
Source File: EliminateSubqueryAliasesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) :: Nil
  }

  private def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private def afterOptimization(plan: LogicalPlan): LogicalPlan = {
    Optimize.execute(analysis.SimpleAnalyzer.execute(plan))
  }

  test("eliminate top level subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = SubqueryAlias("a", input, None)
    comparePlans(afterOptimization(query), input)
  }

  test("eliminate mid-tree subquery") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral, SubqueryAlias("a", input, None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }

  test("eliminate multiple subqueries") {
    val input = LocalRelation('a.int, 'b.int)
    val query = Filter(TrueLiteral,
      SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None))
    comparePlans(
      afterOptimization(query),
      Filter(TrueLiteral, LocalRelation('a.int, 'b.int)))
  }
} 
Example 110
Source File: ReplaceOperatorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReplaceOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Replace Operators", FixedPoint(100),
        ReplaceDistinctWithAggregate,
        ReplaceExceptWithAntiJoin,
        ReplaceIntersectWithSemiJoin) :: Nil
  }

  test("replace Intersect with Left-semi Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Intersect(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftSemi, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Except with Left-anti Join") {
    val table1 = LocalRelation('a.int, 'b.int)
    val table2 = LocalRelation('c.int, 'd.int)

    val query = Except(table1, table2)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Aggregate(table1.output, table1.output,
        Join(table1, table2, LeftAnti, Option('a <=> 'c && 'b <=> 'd))).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("replace Distinct with Aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }
} 
Example 111
Source File: EliminateSerializationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor

case class OtherTuple(_1: Int, _2: Int)

class EliminateSerializationSuite extends PlanTest {
  private object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Serialization", FixedPoint(100),
        EliminateSerialization) :: Nil
  }

  implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
  implicit private def intEncoder = ExpressionEncoder[Int]()

  test("back to back serialization") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('obj.as("obj")).analyze
    comparePlans(optimized, expected)
  }

  test("back to back serialization with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("back to back serialization in AppendColumns") {
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze

    val optimized = Optimize.execute(plan)

    val expected = AppendColumnsWithObject(
      func.asInstanceOf[Any => Any],
      productEncoder[(Int, Int)].namedExpressions,
      intEncoder.namedExpressions,
      input).analyze

    comparePlans(optimized, expected)
  }

  test("back to back serialization in AppendColumns with object change") {
    val input = LocalRelation('obj.obj(classOf[OtherTuple]))
    val func = (item: (Int, Int)) => item._1
    val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze

    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 112
Source File: OptimizeCodegenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._


class OptimizeCodegenSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  test("Codegen only when the number of branches is small.") {
    assertEquivalent(
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
      CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())

    assertEquivalent(
      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)),
      CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)))
  }

  test("Nested CaseWhen Codegen.") {
    assertEquivalent(
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
      CaseWhen(
        Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
        CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
  }

  test("Multiple CaseWhen in one operator.") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
        CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }

  test("Multiple CaseWhen in different operators") {
    val plan = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val correctAnswer = OneRowRelation
      .select(
        CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
        CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
        CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      .where(
        LessThan(
          CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
          CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
      ).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, correctAnswer)
  }
} 
Example 113
Source File: ComputeCurrentTimeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

class ComputeCurrentTimeSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    assert(lits(0) == lits(1))
  }
} 
Example 114
Source File: SimplifyConditionalSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
Example 115
Source File: SimplifyCastsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

class SimplifyCastsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyCasts", FixedPoint(50), SimplifyCasts) :: Nil
  }

  test("non-nullable element array to nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val plan = input.select('a.cast(ArrayType(IntegerType, true)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('a.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable element to non-nullable element array cast") {
    val input = LocalRelation('a.array(ArrayType(IntegerType, true)))
    val plan = input.select('a.cast(ArrayType(IntegerType, false)).as("casted")).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }

  test("non-nullable value map to nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, false)))
    val plan = input.select('m.cast(MapType(StringType, StringType, true))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    val expected = input.select('m.as("casted")).analyze
    comparePlans(optimized, expected)
  }

  test("nullable value map to non-nullable value map cast") {
    val input = LocalRelation('m.map(MapType(StringType, StringType, true)))
    val plan = input.select('m.cast(MapType(StringType, StringType, false))
      .as("casted")).analyze
    val optimized = Optimize.execute(plan)
    comparePlans(optimized, plan)
  }
} 
Example 116
Source File: ReorderAssociativeOperatorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class ReorderAssociativeOperatorSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 117
Source File: AggregateOptimizeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      FoldablePropagation,
      RemoveLiteralFromGroupExpressions,
      RemoveRepetitionFromGroupExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val query = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val query = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val query = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val input = LocalRelation('a.int, 'b.int, 'c.int)
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(query))
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 118
Source File: SimplifyCaseConversionExpressionsSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class SimplifyCaseConversionExpressionsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 119
Source File: ConvertToLocalRelationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      Row(1, 2) ::
      Row(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      Row(1, 3) ::
      Row(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 120
Source File: LikeSimplificationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 121
Source File: CombiningLimitsSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 122
Source File: OptimizeInSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class OptimizeInSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("ConstantFolding", Once,
        ConstantFolding,
        BooleanSimplification,
        OptimizeIn) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("OptimizedIn test: In clause optimized to InSet") {
    val originalQuery =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
        .analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    val originalQuery =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
        .analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
        .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 123
Source File: UnionPushdownSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class UnionPushdownSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubQueries) ::
      Batch("Union Pushdown", Once,
        UnionPushdown) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
  val testUnion = Union(testRelation, testRelation2)

  test("union: filter to each side") {
    val query = testUnion.where('a === 1)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("union: project to each side") {
    val query = testUnion.select('b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer =
      Union(testRelation.select('b), testRelation2.select('e)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
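The "filter to each side" test depends on rewriting the predicate's column references from the union's output (the left child's schema) to the right child's schema, so 'a === 1 becomes 'd === 1 on the right. A small self-contained sketch of that idea over in-memory rows (illustrative only, not the Spark UnionPushdown rule):

object UnionPushdownSketch {
  type Row = Map[String, Int]

  // Apply an equality filter to both sides of a union, translating the column
  // name for the right side through leftToRight (e.g. Map("a" -> "d")).
  def filterEachSide(
      left: Seq[Row],
      right: Seq[Row],
      leftToRight: Map[String, String])(column: String, value: Int): Seq[Row] = {
    val leftKept  = left.filter(_.get(column).contains(value))
    val rightKept = right.filter(_.get(leftToRight(column)).contains(value))
    leftKept ++ rightKept
  }
}

// filterEachSide(leftRows, rightRows, Map("a" -> "d"))("a", 1) keeps left rows
// where a == 1 and right rows where d == 1, matching the test above.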
Example 124
Source File: ProjectCollapsingSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation.select(
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
      'b).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 125
Source File: BooleanSimplificationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

  // The `foldLeft` is required to handle cases like comparing `a && (b && c)` and `(a && b) && c`
  def compareConditions(e1: Expression, e2: Expression): Boolean = (e1, e2) match {
    case (lhs: And, rhs: And) =>
      val lhsSet = splitConjunctivePredicates(lhs).toSet
      val rhsSet = splitConjunctivePredicates(rhs).toSet
      lhsSet.foldLeft(rhsSet) { (set, e) =>
        set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
      }.isEmpty

    case (lhs: Or, rhs: Or) =>
      val lhsSet = splitDisjunctivePredicates(lhs).toSet
      val rhsSet = splitDisjunctivePredicates(rhs).toSet
      lhsSet.foldLeft(rhsSet) { (set, e) =>
        set.find(compareConditions(_, e)).map(set - _).getOrElse(set)
      }.isEmpty

    case (l, r) => l == r
  }

  def checkCondition(input: Expression, expected: Expression): Unit = {
    val plan = testRelation.where(input).analyze
    val actual = Optimize.execute(plan).expressions.head
    assert(compareConditions(actual, expected),
      s"Optimized condition $actual did not match expected $expected")
  }

  test("a && a => a") {
    checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a)
    checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a)
  }

  test("a || a => a") {
    checkCondition(Literal(1) < 'a || Literal(1) < 'a, Literal(1) < 'a)
    checkCondition(Literal(1) < 'a || Literal(1) < 'a || Literal(1) < 'a, Literal(1) < 'a)
  }

  test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") {
    checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5)

    checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2)

    checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2)

    val input = ('a === 'b && 'b > 3 && 'c > 2) ||
      ('a === 'b && 'c < 1 && 'a === 5) ||
      ('a === 'b && 'b < 5 && 'a > 1)

    val expected =
      (((('b > 3) && ('c > 2)) ||
        (('c < 1) && ('a === 5))) ||
        (('b < 5) && ('a > 1))) && ('a === 'b)

    checkCondition(input, expected)
  }

  test("(a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ...") {
    checkCondition('b > 3 && 'c > 5, 'b > 3 && 'c > 5)

    checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)

    checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)

    checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), ('b > 3 && 'c > 5) || 'a < 2)

    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
      ('b > 3 && 'a > 3 && 'a < 5) || 'a === 'b)
  }
} 
Example 126
Source File: AnalysisTest.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.SimpleCatalystConf

trait AnalysisTest extends PlanTest {

  val (caseSensitiveAnalyzer, caseInsensitiveAnalyzer) = {
    val caseSensitiveConf = new SimpleCatalystConf(true)
    val caseInsensitiveConf = new SimpleCatalystConf(false)

    val caseSensitiveCatalog = new SimpleCatalog(caseSensitiveConf)
    val caseInsensitiveCatalog = new SimpleCatalog(caseInsensitiveConf)

    caseSensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)
    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), TestRelations.testRelation)

    new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitiveConf) {
      override val extendedResolutionRules = EliminateSubQueries :: Nil
    } ->
    new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseInsensitiveConf) {
      override val extendedResolutionRules = EliminateSubQueries :: Nil
    }
  }

  protected def getAnalyzer(caseSensitive: Boolean) = {
    if (caseSensitive) caseSensitiveAnalyzer else caseInsensitiveAnalyzer
  }

  protected def checkAnalysis(
      inputPlan: LogicalPlan,
      expectedPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    val actualPlan = analyzer.execute(inputPlan)
    analyzer.checkAnalysis(actualPlan)
    comparePlans(actualPlan, expectedPlan)
  }

  protected def assertAnalysisSuccess(
      inputPlan: LogicalPlan,
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    analyzer.checkAnalysis(analyzer.execute(inputPlan))
  }

  protected def assertAnalysisError(
      inputPlan: LogicalPlan,
      expectedErrors: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    val analyzer = getAnalyzer(caseSensitive)
    // todo: make sure we throw AnalysisException during analysis
    val e = intercept[Exception] {
      analyzer.checkAnalysis(analyzer.execute(inputPlan))
    }
    assert(expectedErrors.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains),
      s"Expected to throw Exception contains: ${expectedErrors.mkString(", ")}, " +
        s"actually we get ${e.getMessage}")
  }
} 
Example 127
Source File: SimplifyCaseConversionExpressionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class SimplifyCaseConversionExpressionsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  test("simplify UPPER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Upper('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Upper(Lower('a)) as 'u)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select(Upper('a) as 'u)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(UPPER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Upper('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify LOWER(LOWER(str))") {
    val originalQuery =
      testRelation
        .select(Lower(Lower('a)) as 'l)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .select(Lower('a) as 'l)
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 128
Source File: ConvertToLocalRelationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100),
        ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      InternalRow(1, 2) :: InternalRow(4, 5) :: Nil)

    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      InternalRow(1, 3) :: InternalRow(4, 6) :: Nil)

    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 129
Source File: LikeSimplificationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._


import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

class LikeSimplificationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Like Simplification", Once,
        LikeSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.string)
  // Simplify Like into StartsWith
  test("simplify Like into StartsWith") {
    val originalQuery =
      testRelation
        .where(('a like "abc%") || ('a like "abc\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(StartsWith('a, "abc") || ('a like "abc\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
  // Simplify Like into EndsWith
  test("simplify Like into EndsWith") {
    val originalQuery =
      testRelation
        .where('a like "%xyz")

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(EndsWith('a, "xyz"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into Contains") {
    val originalQuery =
      testRelation
        .where(('a like "%mn%") || ('a like "%mn\\%"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(Contains('a, "mn") || ('a like "%mn\\%"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("simplify Like into EqualTo") {
    val originalQuery =
      testRelation
        .where(('a like "") || ('a like "abc"))

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer = testRelation
      .where(('a === "") || ('a === "abc"))
      .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
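The four Like tests above rely on a purely syntactic look at the pattern: a single '%' at the end, the start, or both ends (and no other wildcards) maps to StartsWith, EndsWith, Contains, or EqualTo, while patterns ending in an escaped '%' are left alone. A self-contained sketch of that classification (illustrative only, not the Spark LikeSimplification rule):

object LikePatternSketch {
  sealed trait Simplified
  case class StartsWith(prefix: String) extends Simplified
  case class EndsWith(suffix: String) extends Simplified
  case class Contains(infix: String) extends Simplified
  case class EqualTo(value: String) extends Simplified
  case object NotSimplifiable extends Simplified

  private val startsWith = "([^_%]+)%".r
  private val endsWith   = "%([^_%]+)".r
  private val contains   = "%([^_%]+)%".r
  private val equalTo    = "([^_%]*)".r

  def classify(pattern: String): Simplified = pattern match {
    case startsWith(prefix) if !prefix.endsWith("\\") => StartsWith(prefix)
    case endsWith(suffix)                             => EndsWith(suffix)
    case contains(infix) if !infix.endsWith("\\")     => Contains(infix)
    case equalTo(value)                               => EqualTo(value)
    case _                                            => NotSimplifiable
  }
}

// classify("abc%") == StartsWith("abc")   classify("%xyz") == EndsWith("xyz")
// classify("%mn%") == Contains("mn")      classify("abc")  == EqualTo("abc")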
Example 130
Source File: CombiningLimitsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class CombiningLimitsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Filter Pushdown", FixedPoint(100),
        ColumnPruning) ::
      Batch("Combine Limit", FixedPoint(10),
        CombineLimits) ::
      Batch("Constant Folding", FixedPoint(10),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  // Limits: combine two limits
  test("limits: combines two limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(10)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(5).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Limits: combine three limits
  test("limits: combines three limits") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .limit(7)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Limits: combine two limits after ColumnPruning
  test("limits: combines two limits after ColumnPruning") {
    val originalQuery =
      testRelation
        .select('a)
        .limit(2)
        .select('a)
        .limit(5)

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .select('a)
        .limit(2).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 131
Source File: ColumnPruningSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

class ColumnPruningSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Column pruning", FixedPoint(100),
      ColumnPruning) :: Nil
  }

  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation('a.int, 'b.array(StringType))

    val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Generate(Explode('b), false, false, None, 's.string :: Nil,
        Project('b.attr :: Nil, input)).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Column pruning for Generate when Generate.join = true
  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))

    val query =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          input)).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          Project(Seq('a, 'c),
            input))).analyze

    comparePlans(optimized, correctAnswer)
  }
  // Turn Generate.join to false if possible
  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val query =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), true, false, None, 's.string :: Nil,
          input)).analyze
    val optimized = Optimize.execute(query)

    val correctAnswer =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), false, false, None, 's.string :: Nil,
          input)).analyze

    comparePlans(optimized, correctAnswer)
  }

  // todo: add more tests for column pruning
} 
Example 132
Source File: OptimizeInSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class OptimizeInSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("ConstantFolding", Once,
        ConstantFolding,
        BooleanSimplification,
        OptimizeIn) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  // OptimizedIn test: In clause not optimized to InSet when there are fewer than 10 items
  test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
    val originalQuery =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
        .analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    comparePlans(optimized, originalQuery)
  }
  // OptimizedIn test: In clause optimized to InSet when there are more than 10 items
  test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
    val originalQuery =
      testRelation
        .where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
        .analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
        .analyze

    comparePlans(optimized, correctAnswer)
  }
  // OptimizedIn test: In clause not optimized when the filter contains attributes
  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    val originalQuery =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
        .analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation
        .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))))
        .analyze

    comparePlans(optimized, correctAnswer)
  }
} 
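The three tests above pin down the threshold behaviour: the IN list is rewritten into a set lookup only when every element is a literal and there are more than 10 of them. A tiny self-contained sketch of that decision (illustrative types, not the Spark OptimizeIn rule):

object OptimizeInSketch {
  sealed trait Expr
  case class Lit(value: Int) extends Expr
  case class Attr(name: String) extends Expr
  case class In(value: Expr, list: Seq[Expr]) extends Expr
  case class InSet(value: Expr, hset: Set[Int]) extends Expr

  // Rewrite only when the list is all literals and larger than the threshold.
  def optimizeIn(e: Expr, threshold: Int = 10): Expr = e match {
    case In(v, list) if list.forall(_.isInstanceOf[Lit]) && list.size > threshold =>
      InSet(v, list.collect { case Lit(x) => x }.toSet)
    case other => other
  }
}

// optimizeIn(In(Attr("a"), (1 to 11).map(Lit(_))))  => InSet(Attr("a"), (1 to 11).toSet)
// optimizeIn(In(Attr("a"), Seq(Lit(1), Lit(2))))    => unchanged (only 2 items)
// optimizeIn(In(Attr("a"), Seq(Lit(1), Attr("b")))) => unchanged (non-literal element)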
Example 133
Source File: SetOperationPushDownSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class SetOperationPushDownSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", Once,
        EliminateSubQueries) ::
      Batch("Union Pushdown", Once,
        SetOperationPushDown,
        // Nil is an empty List; :: prepends an element to the head, creating a new list
        SimplifyFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
  val testUnion = Union(testRelation, testRelation2)
  val testIntersect = Intersect(testRelation, testRelation2)
  val testExcept = Except(testRelation, testRelation2)

  test("union/intersect/except: filter to each side") {
    val unionQuery = testUnion.where('a === 1)
    val intersectQuery = testIntersect.where('b < 10)
    val exceptQuery = testExcept.where('c >= 5)

    val unionOptimized = Optimize.execute(unionQuery.analyze)
    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
    val exceptOptimized = Optimize.execute(exceptQuery.analyze)

    val unionCorrectAnswer =
      Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze
    val intersectCorrectAnswer =
      Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze
    val exceptCorrectAnswer =
      Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze

    comparePlans(unionOptimized, unionCorrectAnswer)
    comparePlans(intersectOptimized, intersectCorrectAnswer)
    comparePlans(exceptOptimized, exceptCorrectAnswer)
  }

  test("union: project to each side") {
    val unionQuery = testUnion.select('a)
    val unionOptimized = Optimize.execute(unionQuery.analyze)
    val unionCorrectAnswer =
      Union(testRelation.select('a), testRelation2.select('d)).analyze
    comparePlans(unionOptimized, unionCorrectAnswer)
  }

  test("SPARK-10539: Project should not be pushed down through Intersect or Except") {
    val intersectQuery = testIntersect.select('b, 'c)
    val exceptQuery = testExcept.select('a, 'b, 'c)

    val intersectOptimized = Optimize.execute(intersectQuery.analyze)
    val exceptOptimized = Optimize.execute(exceptQuery.analyze)

    comparePlans(intersectOptimized, intersectQuery.analyze)
    comparePlans(exceptOptimized, exceptQuery.analyze)
  }
} 
Example 134
Source File: ProjectCollapsingSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries) ::
        Batch("ProjectCollapsing", Once, ProjectCollapsing) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int)
  test("collapse two deterministic, independent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two deterministic, dependent projects into one") {
    val query = testRelation
      .select(('a + 1).as('a_plus_1), 'b)
      .select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation.select(
      (('a + 1).as('a_plus_1) + 1).as('a_plus_2),
      'b).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not collapse nondeterministic projects") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = query.analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two nondeterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand))
      .select(Rand(20).as('rand2))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(Rand(20).as('rand2)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    val query = testRelation
      .select(Rand(10).as('rand), 'a)
      .select(('a + 1).as('a_plus_1))

    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = testRelation
      .select(('a + 1).as('a_plus_1)).analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 135
Source File: BooleanSimplificationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class BooleanSimplificationSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("AnalysisNodes", Once,
        EliminateSubQueries) ::
      Batch("Constant Folding", FixedPoint(50),
        NullPropagation,
        ConstantFolding,
        BooleanSimplification,
        SimplifyFilters) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string)

  private def checkCondition(input: Expression, expected: Expression): Unit = {
    val plan = testRelation.where(input).analyze
    val actual = Optimize.execute(plan)
    val correctAnswer = testRelation.where(expected).analyze
    comparePlans(actual, correctAnswer)
  }

  test("a && a => a") {
    checkCondition(Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a)
    checkCondition(Literal(1) < 'a && Literal(1) < 'a && Literal(1) < 'a, Literal(1) < 'a)
  }

  test("a || a => a") {
    checkCondition(Literal(1) < 'a || Literal(1) < 'a, Literal(1) < 'a)
    checkCondition(Literal(1) < 'a || Literal(1) < 'a || Literal(1) < 'a, Literal(1) < 'a)
  }

  test("(a && b && c && ...) || (a && b && d && ...) || (a && b && e && ...) ...") {
    checkCondition('b > 3 || 'c > 5, 'b > 3 || 'c > 5)

    checkCondition(('a < 2 && 'a > 3 && 'b > 5) || 'a < 2, 'a < 2)

    checkCondition('a < 2 || ('a < 2 && 'a > 3 && 'b > 5), 'a < 2)

    val input = ('a === 'b && 'b > 3 && 'c > 2) ||
      ('a === 'b && 'c < 1 && 'a === 5) ||
      ('a === 'b && 'b < 5 && 'a > 1)

    val expected = 'a === 'b && (
      ('b > 3 && 'c > 2) || ('c < 1 && 'a === 5) || ('b < 5 && 'a > 1))

    checkCondition(input, expected)
  }

  test("(a || b || c || ...) && (a || b || d || ...) && (a || b || e || ...) ...") {
    checkCondition('b > 3 && 'c > 5, 'b > 3 && 'c > 5)

    checkCondition(('a < 2 || 'a > 3 || 'b > 5) && 'a < 2, 'a < 2)

    checkCondition('a < 2 && ('a < 2 || 'a > 3 || 'b > 5) , 'a < 2)

    checkCondition(('a < 2 || 'b > 3) && ('a < 2 || 'c > 5), 'a < 2 || ('b > 3 && 'c > 5))

    checkCondition(
      ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5),
      ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5))
  }

  private val caseInsensitiveAnalyzer =
    new Analyzer(EmptyCatalog, EmptyFunctionRegistry, new SimpleCatalystConf(false))

  test("(a && b) || (a && c) => a && (b || c) when case insensitive") {
    val plan = caseInsensitiveAnalyzer.execute(
      testRelation.where(('a > 2 && 'b > 3) || ('A > 2 && 'b < 5)))
    val actual = Optimize.execute(plan)
    val expected = caseInsensitiveAnalyzer.execute(
      testRelation.where('a > 2 && ('b > 3 || 'b < 5)))
    comparePlans(actual, expected)
  }

  test("(a || b) && (a || c) => a || (b && c) when case insensitive") {
    val plan = caseInsensitiveAnalyzer.execute(
      testRelation.where(('a > 2 || 'b > 3) && ('A > 2 || 'b < 5)))
    val actual = Optimize.execute(plan)
    val expected = caseInsensitiveAnalyzer.execute(
      testRelation.where('a > 2 || ('b > 3 && 'b < 5)))
    comparePlans(actual, expected)
  }
} 
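The distributivity tests above come down to factoring out a conjunct (or disjunct) shared by every branch, e.g. (a && b) || (a && c) => a && (b || c). A compact, self-contained sketch of that factoring step on a DNF-style representation (illustrative only, not the Spark BooleanSimplification rule):

object BooleanFactorSketch {
  // A predicate in disjunctive normal form: each inner Set is one AND-branch.
  type Conjunction = Set[String]
  type Dnf = Seq[Conjunction]

  def factorCommon(dnf: Dnf): (Conjunction, Dnf) = {
    val common = dnf.reduce(_ intersect _)   // conjuncts shared by every branch
    val rests  = dnf.map(_ -- common)        // what is left of each branch
    // If a branch is fully absorbed by the common part, a || (a && x) == a.
    if (rests.exists(_.isEmpty)) (common, Nil) else (common, rests)
  }
}

// factorCommon(Seq(Set("a", "b"), Set("a", "c"))) == (Set("a"), Seq(Set("b"), Set("c")))
// factorCommon(Seq(Set("a", "b"), Set("a")))      == (Set("a"), Nil)   // collapses to a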
Example 136
Source File: AggregateOptimizeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Aggregate", FixedPoint(100),
      ReplaceDistinctWithAggregate,
      RemoveLiteralFromGroupExpressions) :: Nil
  }
  // Replace Distinct with Aggregate
  test("replace distinct with aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val query = Distinct(input)
    val optimized = Optimize.execute(query.analyze)

    val correctAnswer = Aggregate(input.output, input.output, input)

    comparePlans(optimized, correctAnswer)
  }
  // Remove literals from grouping expressions
  test("remove literals in grouping expression") {
    val input = LocalRelation('a.int, 'b.int)

    val query =
      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(query)

    val correctAnswer = input.groupBy('a)(sum('b))

    comparePlans(optimized, correctAnswer)
  }
}
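The two rewrites tested here are simple structural facts about grouping: DISTINCT is the same as grouping by every output column, and constant grouping keys can be dropped because they never change the grouping. A minimal illustrative sketch (not the Spark ReplaceDistinctWithAggregate / RemoveLiteralFromGroupExpressions rules):

object AggregateOptimizeSketch {
  sealed trait GroupKey
  case class Column(name: String) extends GroupKey
  case class Constant(value: Int) extends GroupKey

  // SELECT DISTINCT a, b FROM t  ==  SELECT a, b FROM t GROUP BY a, b
  def distinctAsGroupBy(outputColumns: Seq[String]): Seq[String] = outputColumns

  // GROUP BY a, 1, 1 + 2  ==  GROUP BY a
  def removeLiteralKeys(keys: Seq[GroupKey]): Seq[GroupKey] =
    keys.filterNot(_.isInstanceOf[Constant])
}

// removeLiteralKeys(Seq(Column("a"), Constant(1), Constant(3))) == Seq(Column("a"))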