org.apache.spark.sql.catalyst.plans.logical.LocalRelation Scala Examples

The following examples show how to use org.apache.spark.sql.catalyst.plans.logical.LocalRelation. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: ResolveInlineTables.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * The schema is derived column by column: a column's type is whatever
   * `TypeCoercion.findWiderTypeWithoutStringPromotion` returns for the types of its values
   * (analysis fails when it returns nothing), and the column is nullable when any of its
   * values is nullable. Every cell is then evaluated eagerly, wrapped in a [[Cast]] to the
   * column type when its own type differs; an evaluation error is reported through
   * `table.failAnalysis`.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    val columns = table.rows.transpose.zip(table.names)
    val fields = columns.map { case (values, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(values.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val cells = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Only insert a cast when the cell's type differs from the column type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else Cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(cells)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 2
Source File: SameResultSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._


/**
 * Tests for `LogicalPlan.sameResult`, comparing plans built on two separately constructed
 * relations that declare the same three int columns.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test, printing the two plans side by side, unless
   * `sameResult` on the analyzed plans equals `result`.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val actual = left.sameResult(right)

    if (actual != result) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projections on both sides.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set, or the same columns in a different order, must not match.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    val filtered2 = testRelation2.where('a === 'b)
    assertSameResult(filtered, filtered2)
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    val sorted2 = testRelation2.orderBy('a.asc)
    assertSameResult(sorted, sorted2)
  }

  test("union") {
    val forward = Union(Seq(testRelation, testRelation2))
    val backward = Union(Seq(testRelation2, testRelation))
    assertSameResult(forward, backward)
  }
}
Example 3
Source File: TestRelations.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

/** Shared [[LocalRelation]] fixtures covering a variety of column names and data types. */
object TestRelations {
  /** Builds a fresh attribute with the given name and type, using the remaining defaults. */
  private def attr(name: String, dataType: DataType): AttributeReference =
    AttributeReference(name, dataType)()

  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  val testRelation2 = LocalRelation(
    attr("a", StringType),
    attr("b", StringType),
    attr("c", DoubleType),
    attr("d", DecimalType(10, 2)),
    attr("e", ShortType))

  val testRelation3 = LocalRelation(
    attr("e", ShortType),
    attr("f", StringType),
    attr("g", DoubleType),
    attr("h", DecimalType(10, 2)))

  // A single struct column whose fields include an exact duplicate name and two names that
  // differ only by case.
  val nestedRelation = LocalRelation(
    attr("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)))))

  // A single struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    attr("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)))))

  val listRelation = LocalRelation(attr("list", ArrayType(IntegerType)))

  val mapRelation = LocalRelation(attr("map", MapType(IntegerType, IntegerType)))
}
Example 4
Source File: ConvertToLocalRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/** Tests for the `ConvertToLocalRelation` optimizer rule. */
class ConvertToLocalRelationSuite extends PlanTest {

  // Executor that applies only the rule under test, up to 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val inputAttributes = LocalRelation('a.int, 'b.int).output
    val testRelation = LocalRelation(inputAttributes, Seq(InternalRow(1, 2), InternalRow(4, 5)))

    // Expected result: column `a` passes through unchanged, column `b` is incremented by one.
    val expectedAttributes = LocalRelation('a1.int, 'b1.int).output
    val correctAnswer =
      LocalRelation(expectedAttributes, Seq(InternalRow(1, 3), InternalRow(4, 6)))

    val renamedA = UnresolvedAttribute("a").as("a1")
    val incrementedB = (UnresolvedAttribute("b") + 1).as("b1")
    val projectOnLocal = testRelation.select(renamedA, incrementedB)

    comparePlans(Optimize.execute(projectOnLocal.analyze), correctAnswer)
  }

}
Example 5
Source File: CollapseWindowSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/** Tests for the `CollapseWindow` optimizer rule. */
class CollapseWindowSuite extends PlanTest {
  // Executor that applies only the rule under test, up to 10 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseWindow", FixedPoint(10), CollapseWindow))
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  /** Asserts that optimizing the analyzed `query` leaves it equal to the analyzed `query`. */
  private def assertNotCollapsed(query: LogicalPlan): Unit = {
    val optimized = Optimize.execute(query.analyze)
    val expected = query.analyze
    comparePlans(optimized, expected)
  }

  test("collapse two adjacent windows with the same partition/order") {
    val withMin = testRelation.window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
    val withMax = withMin.window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
    val withSum = withMax.window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
    val query = withSum.window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val optimized = Optimize.execute(query.analyze)

    // The expected single window lists the expressions outermost-first.
    val collapsedExpressions = Seq(
      avg(b).as('avg_b),
      sum(b).as('sum_b),
      max(a).as('max_a),
      min(a).as('min_a))
    val correctAnswer = testRelation.window(collapsedExpressions, partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partitioning, different ordering.
    val differentOrder = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)
    assertNotCollapsed(differentOrder)

    // Same ordering, different partitioning.
    val differentPartition = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)
    assertNotCollapsed(differentPartition)
  }
}
Example 6
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/** Tests for the `RewriteDistinctAggregates` rule. */
class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /**
   * Fails the test unless `rewrite` has the shape Aggregate -> Aggregate -> Expand, which is
   * the shape this suite treats as "rewritten".
   */
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val input = testRelation.groupBy('a)(countDistinct('e)).analyze
    // The rule is expected to return this plan unchanged.
    comparePlans(input, RewriteDistinctAggregates(input))
  }

  test("single distinct group with partial aggregates") {
    val distinctCount = countDistinct('e, 'c).as('agg1)
    val input = testRelation.groupBy('a, 'd)(distinctCount, max('b).as('agg2)).analyze
    // The rule is expected to return this plan unchanged.
    comparePlans(input, RewriteDistinctAggregates(input))
  }

  test("single distinct group with non-partial aggregates") {
    val distinctCount = countDistinct('e, 'c).as('agg1)
    val collected = CollectSet('b).toAggregateExpression().as('agg2)
    val input = testRelation.groupBy('a, 'd)(distinctCount, collected).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d)).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input =
      testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val collected = CollectSet('b).toAggregateExpression()
    val input =
      testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d), collected).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
}
Example 7
Source File: CollapseRepartitionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/** Tests for the `CollapseRepartition` optimizer rule. */
class CollapseRepartitionSuite extends PlanTest {
  // Executor that applies only the rule under test, up to 10 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseRepartition", FixedPoint(10), CollapseRepartition))
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  /** Asserts that the optimized, analyzed `query` equals the analyzed `expected` plan. */
  private def assertCollapsesTo(query: LogicalPlan, expected: LogicalPlan): Unit = {
    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = expected.analyze
    comparePlans(optimized, correctAnswer)
  }

  test("collapse two adjacent repartitions into one") {
    assertCollapsesTo(
      query = testRelation.repartition(10).repartition(20),
      expected = testRelation.repartition(20))
  }

  test("collapse repartition and repartitionBy into one") {
    assertCollapsesTo(
      query = testRelation.repartition(10).distribute('a)(20),
      expected = testRelation.distribute('a)(20))
  }

  test("collapse repartitionBy and repartition into one") {
    // The expected plan keeps the distribution expression but takes the later partition count.
    assertCollapsesTo(
      query = testRelation.distribute('a)(20).repartition(10),
      expected = testRelation.distribute('a)(10))
  }

  test("collapse two adjacent repartitionBys into one") {
    assertCollapsesTo(
      query = testRelation.distribute('b)(10).distribute('a)(20),
      expected = testRelation.distribute('a)(20))
  }
}
Example 8
Source File: ComputeCurrentTimeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/** Tests for the `ComputeCurrentTime` optimizer rule. */
class ComputeCurrentTimeSuite extends PlanTest {
  // Executor that applies only the rule under test, a single time.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val projections = Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")())
    val in = Project(projections, LocalRelation())

    // Bracket the optimizer run with clock readings (milliseconds scaled by 1000) so the
    // literal values can be range-checked.
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    // Gather the value of every literal in the optimized plan.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Long]
    plan.transformAllExpressions { case literal: Literal =>
      lits += literal.value.asInstanceOf[Long]
      literal
    }
    assert(lits.size == 2)
    val first = lits(0)
    val second = lits(1)
    assert(first >= min && first <= max)
    assert(second >= min && second <= max)
    // Both occurrences must have been replaced by the same value.
    assert(first == second)
  }

  test("analyzer should replace current_date with literals") {
    val projections = Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")())
    val in = Project(projections, LocalRelation())

    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // Gather the value of every literal in the optimized plan.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Int]
    plan.transformAllExpressions { case literal: Literal =>
      lits += literal.value.asInstanceOf[Int]
      literal
    }
    assert(lits.size == 2)
    val first = lits(0)
    val second = lits(1)
    assert(first >= min && first <= max)
    assert(second >= min && second <= max)
    // Both occurrences must have been replaced by the same value.
    assert(first == second)
  }
}
Example 9
Source File: ReorderAssociativeOperatorSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `ReorderAssociativeOperator` optimizer rule, applied once to plans built on a
 * local relation with three int columns.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Executor that applies only the rule under test, a single time.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    // Chains of + and * that mix literals with column references, plus one chain on Rand(0).
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    // Expected plan: within each chain the literal operands are combined into one constant
    // (e.g. 3 + 1 + 2 + 4 becomes 10) and the result is aliased with the text of the original
    // expression. The Rand(0) chain is expected to stay exactly as written.
    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    // Self-join whose grouping expressions ("t1.a" + 1, "t2.a" + 1) reappear as operands of
    // the aggregate expression.
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    // The rule is expected to leave this plan unchanged.
    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
} 
Example 10
Source File: AggregateOptimizeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the optimizer rules that simplify grouping expressions: `FoldablePropagation`,
 * `RemoveLiteralFromGroupExpressions` and `RemoveRepetitionFromGroupExpressions`.
 */
class AggregateOptimizeSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Executor that applies the three rules under test, up to 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  /** Analyzes `query` with this suite's analyzer and then runs the optimizer batch on it. */
  private def analyzeAndOptimize(query: LogicalPlan): LogicalPlan = {
    Optimize.execute(analyzer.execute(query))
  }

  test("remove literals in grouping expression") {
    val literalSum = Literal(1) + Literal(2)
    val query = testRelation.groupBy('a, Literal("1"), literalSum)(sum('b))
    val optimized = analyzeAndOptimize(query)
    val correctAnswer = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val literalSum = Literal(1) + Literal(2)
    val query = testRelation.groupBy(Literal("1"), literalSum)(sum('b))
    val optimized = analyzeAndOptimize(query)
    // The expected plan still groups by a single literal, Literal(0).
    val correctAnswer = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, correctAnswer)
  }

  test("Remove aliased literals") {
    val withAliasedLiteral = testRelation.select('a, Literal(1).as('y))
    val query = withAliasedLiteral.groupBy('a, 'y)(sum('b))
    val optimized = analyzeAndOptimize(query)
    val correctAnswer = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("remove repetition in grouping expression") {
    val input = LocalRelation('a.int, 'b.int, 'c.int)
    // The last two grouping expressions repeat the first two with the operands swapped and
    // the column names upper-cased (the suite's conf sets caseSensitiveAnalysis = false).
    val query = input.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = analyzeAndOptimize(query)
    val correctAnswer = input.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, correctAnswer)
  }
}
Example 11
Source File: FrequentItems.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  /**
   * Scans `df` once and returns a one-row DataFrame with one array column per entry of
   * `cols`, named `<column>_freqItems`, holding the keys retained by that column's
   * `FreqItemCounter` (the counter itself is not shown in this excerpt).
   *
   * @param df the DataFrame to scan
   * @param cols names of the columns to collect items for; each must exist in `df.schema`
   * @param support value in [1e-4, 1]; every per-column counter is created with size
   *                `(1 / support).toInt`
   * @return a single-row DataFrame created on `df.sparkSession`
   */
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val emptyCounters = Seq.fill(numCols)(new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    // Resolve every requested column's data type up front, before any data is scanned.
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val field = originalSchema.fields(originalSchema.fieldIndex(name))
      name -> field.dataType
    }.toArray

    val selectedRows = df.select(cols.map(Column(_)) : _*).rdd
    val freqItems = selectedRows.aggregate(emptyCounters)(
      // Within a partition: feed each selected column's value into that column's counter.
      (counters, row) => {
        (0 until numCols).foreach { i =>
          counters(i).add(row.get(i), 1L)
        }
        counters
      },
      // Across partitions: merge the counters column by column into the left-hand side.
      (merged, other) => {
        (0 until numCols).foreach { i =>
          merged(i).merge(other(i))
        }
        merged
      }
    )
    val justItems = freqItems.map(_.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { case (name, dataType) =>
      StructField(name + "_freqItems", ArrayType(dataType, false))
    }
    val schema = StructType(outputCols).toAttributes
    val relation = LocalRelation.fromExternalRows(schema, Seq(resultRow))
    Dataset.ofRows(df.sparkSession, relation)
  }
}
Example 12
Source File: SparkPlannerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Checks that the query planner only follows the first physical plan a strategy returns and
 * never plans the alternatives that come after it.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Marker leaf node with no output; the strategy below fails the test if it is ever asked
    // to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    // Counts how many times the strategy below matched one of its real cases.
    var planned = 0
    // For every node it handles, this strategy returns the real plan first and a
    // `planLater(NeverPlanned)` alternative second.
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // NOTE(review): 4 presumably corresponds to one ReturnAnswer, one Union and the two
      // LocalRelations of the union — confirm against the planner.
      assert(planned === 4)
    } finally {
      // Always remove the test strategy again, even when an assertion above fails.
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 13
Source File: ResolveInlineTables.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * The schema is derived column by column: a column's type is whatever
   * `TypeCoercion.findWiderTypeWithoutStringPromotion` returns for the types of its values
   * (analysis fails when it returns nothing), and the column is nullable when any of its
   * values is nullable. Every cell is then evaluated eagerly, going through `cast` to the
   * column type when its own type differs; an evaluation error is reported through
   * `table.failAnalysis` together with the original exception.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    val columns = table.rows.transpose.zip(table.names)
    val fields = columns.map { case (values, columnName) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(values.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $columnName for inline table")
      }
      StructField(columnName, columnType, nullable = values.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val cells = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Only cast when the cell's type differs from the column type.
          if (expr.dataType.sameType(targetType)) {
            expr.eval()
          } else {
            cast(expr, targetType).eval()
          }
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(
              s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}", ex)
        }
      }
      InternalRow.fromSeq(cells)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 14
Source File: SameResultSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union}
import org.apache.spark.sql.catalyst.util._


/**
 * Tests for `LogicalPlan.sameResult`, comparing plans built on two separately constructed
 * relations that declare the same three int columns.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test, printing the two plans side by side, unless
   * `sameResult` on the analyzed plans equals `result`.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val actual = left.sameResult(right)

    if (actual != result) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projections on both sides.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set, or the same columns in a different order, must not match.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    val filtered2 = testRelation2.where('a === 'b)
    assertSameResult(filtered, filtered2)
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    val sorted2 = testRelation2.orderBy('a.asc)
    assertSameResult(sorted, sorted2)
  }

  test("union") {
    val forward = Union(Seq(testRelation, testRelation2))
    val backward = Union(Seq(testRelation2, testRelation))
    assertSameResult(forward, backward)
  }

  test("hint") {
    // A join whose right side is wrapped in ResolvedHint is expected to give the same
    // result as the plain join.
    val hintedJoin = testRelation.join(ResolvedHint(testRelation))
    val plainJoin = testRelation.join(testRelation)
    assertSameResult(hintedJoin, plainJoin)
  }
}
Example 15
Source File: TestRelations.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

/** Shared [[LocalRelation]] fixtures covering a variety of column names and data types. */
object TestRelations {
  /** Builds a fresh attribute with the given name and type, using the remaining defaults. */
  private def attr(name: String, dataType: DataType): AttributeReference =
    AttributeReference(name, dataType)()

  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  val testRelation2 = LocalRelation(
    attr("a", StringType),
    attr("b", StringType),
    attr("c", DoubleType),
    attr("d", DecimalType(10, 2)),
    attr("e", ShortType))

  val testRelation3 = LocalRelation(
    attr("e", ShortType),
    attr("f", StringType),
    attr("g", DoubleType),
    attr("h", DecimalType(10, 2)))

  // Same column names as `testRelation3`, but `e`, `f` and `g` are strings and `h` is a map;
  // `h` is intended to be the only column whose type is incompatible with `testRelation3`.
  val testRelation4 = LocalRelation(
    attr("e", StringType),
    attr("f", StringType),
    attr("g", StringType),
    attr("h", MapType(IntegerType, IntegerType)))

  // A single struct column whose fields include an exact duplicate name and two names that
  // differ only by case.
  val nestedRelation = LocalRelation(
    attr("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)))))

  // A single struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    attr("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)))))

  val listRelation = LocalRelation(attr("list", ArrayType(IntegerType)))

  val mapRelation = LocalRelation(attr("map", MapType(IntegerType, IntegerType)))
}
Example 16
Source File: ResolveLambdaVariablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{ArrayType, IntegerType}


/**
 * Exercises `ResolveLambdaVariables` on lambda functions passed to `ArrayTransform`,
 * covering successful resolution as well as the two failure modes asserted below.
 */
class ResolveLambdaVariablesSuite extends PlanTest {
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._

  // Executor that applies only the rule under test.
  object Analyzer extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil
  }

  private val key = 'key.int
  private val values1 = 'values1.array(IntegerType)
  private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType)))
  private val data = LocalRelation(Seq(key, values1, values2))
  private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true)
  private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true)
  private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true)

  // Projects `e` (aliased as "res") on top of the shared test relation.
  private def plan(e: Expression): LogicalPlan = data.select(e.as("res"))

  // Analyzes a plan built from `e1` and compares it with the plan built from `e2`.
  private def checkExpression(e1: Expression, e2: Expression): Unit = {
    val analyzed = Analyzer.execute(plan(e1))
    val expected = plan(e2)
    comparePlans(analyzed, expected)
  }

  // Unresolved lambda variable named after the given symbol.
  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))

  test("resolution - no op") {
    checkExpression(key, key)
  }

  test("resolution - simple") {
    val unresolved = ArrayTransform(values1, LambdaFunction(lv('x) + 1, Seq(lv('x))))
    val resolved = ArrayTransform(values1, LambdaFunction(lvInt + 1, Seq(lvInt)))
    checkExpression(unresolved, resolved)
  }

  test("resolution - nested") {
    val unresolvedInner = ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, Seq(lv('x))))
    val unresolved = ArrayTransform(values2, LambdaFunction(unresolvedInner, Seq(lv('x))))

    val resolvedInner = ArrayTransform(lvArray, LambdaFunction(lvInt + 1, Seq(lvInt)))
    val resolved = ArrayTransform(values2, LambdaFunction(resolvedInner, Seq(lvArray)))

    checkExpression(unresolved, resolved)
  }

  test("resolution - hidden") {
    val unresolved = ArrayTransform(values1, key)
    val hiddenLambda = LambdaFunction(key, Seq(lvHiddenInt), hidden = true)
    checkExpression(unresolved, ArrayTransform(values1, hiddenLambda))
  }

  test("fail - name collisions") {
    val collidingLambda = LambdaFunction(lv('x) + lv('X), Seq(lv('x), lv('X)))
    val badPlan = plan(ArrayTransform(values1, collidingLambda))
    val error = intercept[AnalysisException](Analyzer.execute(badPlan))
    assert(error.getMessage.contains(
      "arguments should not have names that are semantically the same"))
  }

  test("fail - lambda arguments") {
    val threeArgLambda =
      LambdaFunction(lv('x) + lv('y) + lv('z), Seq(lv('x), lv('y), lv('z)))
    val badPlan = plan(ArrayTransform(values1, threeArgLambda))
    val error = intercept[AnalysisException](Analyzer.execute(badPlan))
    assert(error.getMessage.contains("does not match the number of arguments expected"))
  }
}
Example 17
Source File: PullOutNondeterministicSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation


/**
 * Compares analyzed plans containing `Rand` in filter, sort and aggregate positions with the
 * plans expected after analysis.
 */
class PullOutNondeterministicSuite extends AnalysisTest {

  private lazy val a = 'a.int
  private lazy val b = 'b.int
  private lazy val r = LocalRelation(a, b)
  private lazy val rnd = Rand(10).as('_nondeterministic)
  private lazy val rndref = rnd.toAttribute

  test("no-op on filter") {
    // The expected plan is built exactly like the input plan.
    val input = r.where(Rand(10) > Literal(1.0))
    val expected = r.where(Rand(10) > Literal(1.0))
    checkAnalysis(input, expected)
  }

  test("sort") {
    val input = r.sortBy(SortOrder(Rand(10), Ascending))
    // Expected shape: project the aliased Rand, sort by its attribute, then project a and b.
    val expected = r.select(a, b, rnd).sortBy(SortOrder(rndref, Ascending)).select(a, b)
    checkAnalysis(input, expected)
  }

  test("aggregate") {
    val input = r.groupBy(Rand(10))(Rand(10).as("rnd"))
    // Expected shape: project the aliased Rand, then group by (and output) its attribute.
    val expected = r.select(a, b, rnd).groupBy(rndref)(rndref.as("rnd"))
    checkAnalysis(input, expected)
  }
}
Example 18
Source File: ResolvedUuidExpressionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}


/**
 * Checks the random seeds carried by `Uuid` expressions after a plan has gone through
 * `analyzer.executeAndCheck`.
 */
class ResolvedUuidExpressionsSuite extends AnalysisTest {

  private lazy val a = 'a.int
  private lazy val r = LocalRelation(a)
  private lazy val uuid1 = Uuid().as('_uuid1)
  private lazy val uuid2 = Uuid().as('_uuid2)
  private lazy val uuid3 = Uuid().as('_uuid3)
  private lazy val uuid1Ref = uuid1.toAttribute

  private val analyzer = getAnalyzer(caseSensitive = true)

  // Gathers every `Uuid` found in the expressions of every node of `plan`.
  private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
    plan.flatMap { node =>
      node.expressions.flatMap(_.collect { case u: Uuid => u })
    }
  }

  test("analyzed plan sets random seed for Uuid expression") {
    val analyzed = analyzer.executeAndCheck(r.select(a, uuid1))
    for (u <- getUuidExpressions(analyzed)) {
      assert(u.resolved)
      assert(u.randomSeed.isDefined)
    }
  }

  test("Uuid expressions should have different random seeds") {
    val query = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
    val analyzed = analyzer.executeAndCheck(query)
    val seeds = getUuidExpressions(analyzed).map(_.randomSeed.get)
    assert(seeds.distinct.length == 3)
  }

  test("Different analyzed plans should have different random seeds in Uuids") {
    val query = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)
    // Analyze the same plan twice and compare the Uuid expressions of the two results.
    val firstRun = analyzer.executeAndCheck(query)
    val secondRun = analyzer.executeAndCheck(query)
    val firstUuids = getUuidExpressions(firstRun)
    val secondUuids = getUuidExpressions(secondRun)
    assert(firstUuids.distinct.length == 3)
    assert(secondUuids.distinct.length == 3)
    assert(firstUuids.intersect(secondUuids).isEmpty)
  }
}
Example 19
Source File: ResolveInlineTablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}


/**
 * Tests the validation helpers and the `convert` step of `ResolveInlineTables`, plus the rule
 * as a whole for unresolved input and time-zone-aware casts.
 */
class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  // Creates a new rule instance from the suite's `conf` each time it is referenced.
  private def rule = ResolveInlineTables(conf)

  test("validate inputs are foldable") {
    rule.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      rule.validateInputEvaluable(UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      rule.validateInputEvaluable(UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      rule.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    rule.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      rule.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      rule.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val unresolved = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    val result = rule(unresolved)
    assert(result == unresolved)
  }

  test("convert") {
    // An int and a long in the same column: the column type is expected to be LongType.
    val mixed = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val relation = rule.convert(mixed)

    assert(relation.output.map(_.dataType) == Seq(LongType))
    assert(relation.data.size == 2)
    assert(relation.data(0).getLong(0) == 1L)
    assert(relation.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val timestampText = "1991-12-06 00:00:00.0"
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit(timestampText), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = rule.apply(withTimeZone)
    // Evaluate the same cast directly with the session-local time zone for comparison.
    val correct = Cast(lit(timestampText), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    val nonNullTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val nonNullRelation = rule.convert(nonNullTable)
    assert(!nonNullRelation.schema.fields(0).nullable)

    val withNullTable =
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val withNullRelation = rule.convert(withNullTable)
    assert(withNullRelation.schema.fields(0).nullable)
  }
}
Example 20
Source File: ResolveSubquerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}


/**
 * Verifies the analysis error reported for an IN-subquery whose projection refers to a
 * column of the outer relation.
 */
class ResolveSubquerySuite extends AnalysisTest {

  val a = 'a.int
  val b = 'b.int
  val t1 = LocalRelation(a)
  val t2 = LocalRelation(b)

  test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
    // The subquery projects "a", which is a column of t1 rather than of t2.
    val subquery = ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))
    val filtered = Filter(InSubquery(Seq(a), subquery), t1)
    val error = intercept[AnalysisException] {
      SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(filtered))
    }
    assert(error.getMessage.contains(
      "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses"))
  }
}
Example 21
Source File: ConvertToLocalRelationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{LessThan, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/**
 * Checks that `ConvertToLocalRelation` turns a Project (and a Filter over a Project) on top
 * of a `LocalRelation` into one `LocalRelation` holding the computed rows.
 */
class ConvertToLocalRelationSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Nil
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val inputRows = Seq(InternalRow(1, 2), InternalRow(4, 5))
    val source = LocalRelation(LocalRelation('a.int, 'b.int).output, inputRows)

    // b1 = b + 1, so (1, 2) -> (1, 3) and (4, 5) -> (4, 6).
    val expectedRows = Seq(InternalRow(1, 3), InternalRow(4, 6))
    val expected = LocalRelation(LocalRelation('a1.int, 'b1.int).output, expectedRows)

    val projected = source.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    comparePlans(Optimize.execute(projected.analyze), expected)
  }

  test("Filter on LocalRelation should be turned into a single LocalRelation") {
    val inputRows = Seq(InternalRow(1, 2), InternalRow(4, 5))
    val source = LocalRelation(LocalRelation('a.int, 'b.int).output, inputRows)

    // Only the first row satisfies b1 < 6 after the projection.
    val expectedRows = Seq(InternalRow(1, 3))
    val expected = LocalRelation(LocalRelation('a1.int, 'b1.int).output, expectedRows)

    val projected = source.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))
    val filtered = projected.where(LessThan(UnresolvedAttribute("b1"), Literal.create(6)))

    comparePlans(Optimize.execute(filtered.analyze), expected)
  }
}
Example 22
Source File: PullupCorrelatedPredicatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{InSubquery, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Ensures that running `PullupCorrelatedPredicates` over a resolved plan with a correlated
 * IN-subquery yields a plan that is still resolved.
 */
class PullupCorrelatedPredicatesSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("PullupCorrelatedPredicates", Once, PullupCorrelatedPredicates) :: Nil
  }

  val testRelation = LocalRelation('a.int, 'b.double)
  val testRelation2 = LocalRelation('c.int, 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
    // The subquery filter compares 'b (from testRelation) with 'd (from testRelation2).
    val subquery = testRelation2.where('b < 'd).select('c)
    val analyzedOuter =
      testRelation.where(InSubquery(Seq('a), ListQuery(subquery))).select('a).analyze
    assert(analyzedOuter.resolved)

    val rewritten = Optimize.execute(analyzedOuter)
    assert(rewritten.resolved)
  }
}
Example 23
Source File: CheckCartesianProductsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.scalatest.Matchers._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf.CROSS_JOINS_ENABLED

/**
 * Tests for the `CheckCartesianProducts` optimizer rule, run with the
 * `CROSS_JOINS_ENABLED` setting switched on and off for several join types.
 */
class CheckCartesianProductsSuite extends PlanTest {

  // Executor that applies only the rule under test, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Check Cartesian Products", Once, CheckCartesianProducts) :: Nil
  }

  val testRelation1 = LocalRelation('a.int, 'b.int)
  val testRelation2 = LocalRelation('c.int, 'd.int)

  // Join types for which the tests below expect an exception when no condition is given
  // and cross joins are disabled.
  val joinTypesWithRequiredCondition = Seq(Inner, LeftOuter, RightOuter, FullOuter)
  // Join types for which the tests below expect no exception even without a condition.
  val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin('exists))

  // The test name previously ended with a stray ")"; it has been removed.
  test("CheckCartesianProducts doesn't throw an exception if cross joins are enabled") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      noException should be thrownBy {
        for (joinType <- joinTypesWithRequiredCondition ++ joinTypesWithoutRequiredCondition) {
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  test("CheckCartesianProducts throws an exception for join types that require a join condition") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        val thrownException = the [AnalysisException] thrownBy {
          performCartesianProductCheck(joinType)
        }
        assert(thrownException.message.contains("Detected implicit cartesian product"))
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if a join condition is present") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithRequiredCondition) {
        noException should be thrownBy {
          performCartesianProductCheck(joinType, Some('a === 'd))
        }
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if join types don't require conditions") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      for (joinType <- joinTypesWithoutRequiredCondition) {
        noException should be thrownBy {
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  /**
   * Joins the two test relations, analyzes the join, runs the rule and checks that the
   * optimized plan equals the analyzed plan.
   *
   * @param joinType  join type used for the join
   * @param condition optional join condition; `None` builds a join without a condition
   */
  private def performCartesianProductCheck(
      joinType: JoinType,
      condition: Option[Expression] = None): Unit = {
    val analyzedPlan = testRelation1.join(testRelation2, joinType, condition).analyze
    val optimizedPlan = Optimize.execute(analyzedPlan)
    comparePlans(analyzedPlan, optimizedPlan)
  }
}
Example 24
Source File: EliminateDistinctSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Checks that `EliminateDistinct` rewrites `maxDistinct` / `minDistinct` aggregates into
 * their plain `max` / `min` counterparts.
 */
class EliminateDistinctSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Operator Optimizations", Once, EliminateDistinct) :: Nil
  }

  val testRelation = LocalRelation('a.int)

  test("Eliminate Distinct in Max") {
    val withDistinct = testRelation.select(maxDistinct('a).as('result)).analyze
    val withoutDistinct = testRelation.select(max('a).as('result)).analyze
    // The two plans must differ before optimization for the comparison to be meaningful.
    assert(withDistinct != withoutDistinct)
    comparePlans(Optimize.execute(withDistinct), withoutDistinct)
  }

  test("Eliminate Distinct in Min") {
    val withDistinct = testRelation.select(minDistinct('a).as('result)).analyze
    val withoutDistinct = testRelation.select(min('a).as('result)).analyze
    // The two plans must differ before optimization for the comparison to be meaningful.
    assert(withDistinct != withoutDistinct)
    comparePlans(Optimize.execute(withDistinct), withoutDistinct)
  }
}
Example 25
Source File: CollapseWindowSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests when `CollapseWindow` merges adjacent window operators and when it leaves them as
 * separate operators.
 */
class CollapseWindowSuite extends PlanTest {
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("CollapseWindow", FixedPoint(10), CollapseWindow) :: Nil
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    // Four stacked windows that all share partitionSpec1 / orderSpec1.
    val stacked = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = stacked.analyze
    val collapsed = Optimize.execute(analyzed)
    assert(analyzed.output === collapsed.output)

    // Expected: one window operator carrying all four window expressions.
    val expected = testRelation.window(
      Seq(
        min(a).as('min_a),
        max(a).as('max_a),
        sum(b).as('sum_b),
        avg(b).as('avg_b)),
      partitionSpec1,
      orderSpec1)

    comparePlans(collapsed, expected)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partitioning, different ordering.
    val differentOrder = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimizedOrder = Optimize.execute(differentOrder.analyze)
    val expectedOrder = differentOrder.analyze
    comparePlans(optimizedOrder, expectedOrder)

    // Same ordering, different partitioning.
    val differentPartition = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimizedPartition = Optimize.execute(differentPartition.analyze)
    val expectedPartition = differentPartition.analyze
    comparePlans(optimizedPartition, expectedPartition)
  }

  test("Don't collapse adjacent windows with dependent columns") {
    // The second window references 'sum_a, which the first window defines.
    val dependent = testRelation
      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1)
      .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1)
      .analyze

    val untouched = dependent.analyze
    val optimized = Optimize.execute(dependent.analyze)
    comparePlans(optimized, untouched)
  }
}
Example 26
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests which aggregate plans `RewriteDistinctAggregates` leaves unchanged and which ones it
 * rewrites into the Aggregate-over-Aggregate-over-Expand shape checked by `checkRewrite`.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  // Fails the test unless `rewrite` is an Aggregate over an Aggregate over an Expand.
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val plan = testRelation.groupBy('a)(countDistinct('e)).analyze
    // A single distinct group is expected to come back unchanged.
    comparePlans(plan, RewriteDistinctAggregates(plan))
  }

  test("single distinct group with partial aggregates") {
    val plan = testRelation
      .groupBy('a, 'd)(countDistinct('e, 'c).as('agg1), max('b).as('agg2))
      .analyze
    // Still a single distinct group, so the plan is expected to come back unchanged.
    comparePlans(plan, RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups") {
    val plan = testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d)).analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups with partial aggregates") {
    val plan = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val plan = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }
}
Example 27
Source File: EliminateMapObjectsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

/**
 * Checks that deserializing a primitive array column (`Array[Int]`, `Array[Double]`) is
 * optimized into a single `Invoke` of `toIntArray` / `toDoubleArray` on the input attribute.
 */
class EliminateMapObjectsSuite extends PlanTest {
  // Runs the rule under test together with NullPropagation and SimplifyCasts.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = {
      Batch("EliminateMapObjects", FixedPoint(50),
        NullPropagation,
        SimplifyCasts,
        EliminateMapObjects) :: Nil
    }
  }

  // Implicit definitions carry an explicit result type so that implicit resolution does not
  // depend on type inference of the right-hand side.
  implicit private def intArrayEncoder: ExpressionEncoder[Array[Int]] =
    ExpressionEncoder[Array[Int]]()
  implicit private def doubleArrayEncoder: ExpressionEncoder[Array[Double]] =
    ExpressionEncoder[Array[Double]]()

  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
    // Int arrays: the expected plan invokes `toIntArray` directly on the input column.
    val intObjType = ObjectType(classOf[Array[Int]])
    val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val intQuery = intInput.deserialize[Array[Int]].analyze
    val intOptimized = Optimize.execute(intQuery)
    val intExpected = DeserializeToObject(
      Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false),
      AttributeReference("obj", intObjType, true)(), intInput)
    comparePlans(intOptimized, intExpected)

    // Double arrays: same check, with `toDoubleArray`.
    val doubleObjType = ObjectType(classOf[Array[Double]])
    val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
    val doubleQuery = doubleInput.deserialize[Array[Double]].analyze
    val doubleOptimized = Optimize.execute(doubleQuery)
    val doubleExpected = DeserializeToObject(
      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false),
      AttributeReference("obj", doubleObjType, true)(), doubleInput)
    comparePlans(doubleOptimized, doubleExpected)
  }
}
Example 28
Source File: RewriteSubquerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/**
 * Checks the plan produced for an IN-subquery filter once predicate-subquery rewriting and
 * column pruning have both run.
 */
class RewriteSubquerySuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Column Pruning", FixedPoint(100), ColumnPruning) ::
      Batch("Rewrite Subquery", FixedPoint(1),
        RewritePredicateSubquery,
        ColumnPruning,
        CollapseProject,
        RemoveRedundantProject) :: Nil
  }

  test("Column pruning after rewriting predicate subquery") {
    val outer = LocalRelation('a.int, 'b.int)
    val inner = LocalRelation('x.int, 'y.int, 'z.int)

    val subquery = ListQuery(inner.select('x))
    val filtered = outer.where('a.in(subquery)).select('a)
    val optimized = Optimize.execute(filtered.analyze)

    // Expected: a left-semi join on a = x where each side keeps only the column it needs.
    val expected = outer
      .select('a)
      .join(inner.select('x), LeftSemi, Some('a === 'x))
      .analyze

    comparePlans(optimized, expected)
  }

}
Example 29
Source File: PushProjectThroughUnionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Regression test (SPARK-25450) that runs `PushProjectionThroughUnion` together with
 * `FoldablePropagation` and compares the result with a hand-built expected plan.
 */
class PushProjectThroughUnionSuite extends PlanTest {

  // Both rules run in the same fixed-point batch so that they can interact.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Optimizer Batch", FixedPoint(100),
      PushProjectionThroughUnion,
      FoldablePropagation) :: Nil
  }

  test("SPARK-25450 PushProjectThroughUnion rule uses the same exprId for project expressions " +
    "in each Union child, causing mistakes in constant propagation") {
    val testRelation1 = LocalRelation('a.string, 'b.int, 'c.string)
    val testRelation2 = LocalRelation('d.string, 'e.int, 'f.string)
    // The second union child replaces its first column with the constant "bar"; the
    // projections on `n` are applied on top of the union.
    val query = testRelation1
      .union(testRelation2.select("bar".as("d"), 'e, 'f))
      .select('a.as("n"))
      .select('n, "dummy").analyze
    val optimized = Optimize.execute(query)

    // Expected: the projections are pushed into each union child. The first child still
    // reads column `a`, while in the second child `n` is the constant "bar".
    val expected = testRelation1
      .select('a.as("n"))
      .select('n, "dummy")
      .union(testRelation2
        .select("bar".as("d"), 'e, 'f)
        .select("bar".as("n"))
        .select("bar".as("n"), "dummy")).analyze

    comparePlans(optimized, expected)
  }
}
Example 30
Source File: ComputeCurrentTimeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Checks that `ComputeCurrentTime` leaves literals in place of `CurrentTimestamp` /
 * `CurrentDate`, and that two occurrences in one plan end up with the same value.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // Executor that applies only the rule under test, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()),
      LocalRelation())

    // Bracket the optimizer run with two clock readings, scaled from milliseconds by 1000.
    // The statement order matters: `min` is taken before and `max` after the run.
    val min = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = (System.currentTimeMillis() + 1) * 1000

    // transformAllExpressions is used only to visit every Literal and record its value;
    // each literal is returned unchanged.
    val lits = new scala.collection.mutable.ArrayBuffer[Long]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Long]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    // Both occurrences must have been replaced by the same value.
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation())

    // Bracket the optimizer run with two clock readings converted by
    // DateTimeUtils.millisToDays; `min` is taken before and `max` after the run.
    val min = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val max = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // Visit every Literal and record its value; each literal is returned unchanged.
    val lits = new scala.collection.mutable.ArrayBuffer[Int]
    plan.transformAllExpressions { case e: Literal =>
      lits += e.value.asInstanceOf[Int]
      e
    }
    assert(lits.size == 2)
    assert(lits(0) >= min && lits(0) <= max)
    assert(lits(1) >= min && lits(1) <= max)
    // Both occurrences must have been replaced by the same value.
    assert(lits(0) == lits(1))
  }
}
Example 31
Source File: UpdateNullabilityInAttributeReferencesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{CreateArray, GetArrayItem}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/**
 * Checks that, after the simplification batch has run,
 * `UpdateNullabilityInAttributeReferences` makes the optimized plan match a plan that was
 * built directly on the non-nullable column.
 */
class UpdateNullabilityInAttributeReferencesSuite extends PlanTest {

  // First batch simplifies expressions; the second batch applies the rule under test once.
  object Optimizer extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Constant Folding", FixedPoint(10),
          NullPropagation,
          ConstantFolding,
          BooleanSimplification,
          SimplifyConditionals,
          SimplifyBinaryComparison,
          SimplifyExtractValueOps) ::
      Batch("UpdateAttributeReferences", Once,
        UpdateNullabilityInAttributeReferences) :: Nil
  }

  test("update nullability in AttributeReference")  {
    val rel = LocalRelation('a.long.notNull)
    // In the `original` plan below, the Aggregate node produced by groupBy() has a
    // nullable AttributeReference to `b`, because array indexing is a nullable expression
    // (the original note also mentions map lookup, which this test does not exercise).
    // After optimization, the same attribute is non-nullable, but the AttributeReference is
    // not updated to reflect this. So, we need to update nullability by the
    // `UpdateNullabilityInAttributeReferences` rule.
    val original = rel
      .select(GetArrayItem(CreateArray(Seq('a, 'a + 1L)), 0) as "b")
      .groupBy($"b")("1")
    // Expected plan: `b` is simply the non-nullable column `a`.
    val expected = rel.select('a as "b").groupBy($"b")("1").analyze
    val optimized = Optimizer.execute(original.analyze)
    comparePlans(optimized, expected)
  }
}
Example 32
Source File: ReorderAssociativeOperatorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `ReorderAssociativeOperator` optimizer rule: the expected plans below have the
 * literal operands of `+` and `*` chains folded into a single constant.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Runs only the rule under test, in a single pass.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val unoptimized = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(unoptimized.analyze)

    // Each rewritten column is aliased with the text of the expression it replaces; the
    // `Rand(0)` product is expected to come back exactly as written.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    val joinCondition = Some("t1.a".attr === "t2.a".attr)
    val unoptimized = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, joinCondition)
      .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
        (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val actual = Optimize.execute(unoptimized.analyze)

    // The expected plan is simply the analyzed input, i.e. the rule must not touch it.
    val expected = unoptimized.analyze

    comparePlans(actual, expected)
  }
}
Example 33
Source File: AggregateOptimizeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

/**
 * Tests for the optimizer rules that clean up grouping expressions: literals and repeated
 * expressions should disappear from the GROUP BY list.
 */
class AggregateOptimizeSuite extends PlanTest {
  // Analysis here is case-insensitive and does not treat integer literals as ordinals.
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Runs the three aggregate-related rules to a fixed point (at most 100 iterations).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val grouped = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(actual, expected)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val grouped = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    // A single literal 0 is expected to remain as the grouping key.
    val expected = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(actual, expected)
  }

  test("Remove aliased literals") {
    val grouped = testRelation
      .select('a, 'b, Literal(1).as('y))
      .groupBy('a, 'y)(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation
      .select('a, 'b, Literal(1).as('y))
      .groupBy('a)(sum('b))
      .analyze

    comparePlans(actual, expected)
  }

  test("remove repetition in grouping expression") {
    // 'A and 'B differ from 'a and 'b only by case, and analysis is case-insensitive here.
    val grouped = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(actual, expected)
  }
}
Example 34
Source File: FrequentItems.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  /**
   * Collects candidate frequent items for each of the given columns in one pass over `df`.
   *
   * One `FreqItemCounter` per column is created with capacity `(1 / support).toInt`; every row
   * adds its value for that column to the matching counter, and partial counters are merged
   * through `treeAggregate`.
   *
   * @param df the DataFrame to scan
   * @param cols names of the columns to examine; each must exist in `df.schema`
   * @param support value in [1e-4, 1] that determines the counter capacity
   * @return a single-row DataFrame with one non-null-element array column per input column,
   *         named `<column>_freqItems`
   */
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // Capacity handed to every per-column counter.
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.fill(numCols)(new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val field = originalSchema.fields(originalSchema.fieldIndex(name))
      (name, field.dataType)
    }.toArray

    // Feeds one row into the counters, column by column, and returns the same counters.
    val addRow = (counters: Seq[FreqItemCounter], row: Row) => {
      var col = 0
      while (col < numCols) {
        counters(col).add(row.get(col), 1L)
        col += 1
      }
      counters
    }

    // Merges the right-hand counters into the left-hand ones, column by column.
    val mergeCounters = (left: Seq[FreqItemCounter], right: Seq[FreqItemCounter]) => {
      var col = 0
      while (col < numCols) {
        left(col).merge(right(col))
        col += 1
      }
      left
    }

    val selected = df.select(cols.map(Column(_)) : _*)
    val freqItems = selected.rdd.treeAggregate(countMaps)(seqOp = addRow, combOp = mergeCounters)

    val justItems = freqItems.map(_.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // Suffix each output column so its origin is obvious when debugging.
    val outputCols = colInfo.map { case (colName, dataType) =>
      StructField(colName + "_freqItems", ArrayType(dataType, containsNull = false))
    }
    val schema = StructType(outputCols).toAttributes
    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
  }
}
Example 35
Source File: ConsoleWriter.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming.sources

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Dataset, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.sources.v2.DataSourceOptions
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
import org.apache.spark.sql.types.StructType


/**
 * Streaming writer that prints the rows of every committed epoch to standard output.
 *
 * @param schema schema applied to the collected rows when they are displayed
 * @param options writer options; `numRows` and `truncate` are read from it
 */
class ConsoleWriter(schema: StructType, options: DataSourceOptions)
    extends StreamWriter with Logging {

  // Number of rows to display, by default 20 rows
  protected val numRowsToShow = options.getInt("numRows", 20)

  // Truncate the displayed data if it is too long, by default it is true
  protected val isTruncated = options.getBoolean("truncate", true)

  // A session is needed later to build the Dataset that gets displayed, so one has to be
  // active when the writer is constructed.
  assert(SparkSession.getActiveSession.isDefined)
  protected val spark = SparkSession.getActiveSession.get

  /** Returns the factory used to create the per-partition data writers. */
  def createWriterFactory(): DataWriterFactory[InternalRow] = PackedRowWriterFactory

  /** Prints the rows carried by `messages` under a "Batch: <epochId>" banner. */
  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
    // We have to print a "Batch" label for the epoch for compatibility with the pre-data source V2
    // behavior.
    printRows(messages, schema, s"Batch: $epochId")
  }

  /** Nothing to undo here, so aborting an epoch is a no-op. */
  def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}

  /**
   * Prints a banner followed by the rows packed into the commit messages.
   *
   * @param commitMessages messages to read rows from; only `PackedRowCommitMessage`s are used
   * @param schema schema used to interpret the rows
   * @param printMessage text shown between the two separator lines
   */
  protected def printRows(
      commitMessages: Array[WriterCommitMessage],
      schema: StructType,
      printMessage: String): Unit = {
    val rows = commitMessages.collect {
      case PackedRowCommitMessage(rs) => rs
    }.flatten

    // scalastyle:off println
    println("-------------------------------------------")
    println(printMessage)
    println("-------------------------------------------")
    // scalastyle:on println
    Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows))
      .show(numRowsToShow, isTruncated)
  }

  override def toString(): String = {
    s"ConsoleWriter[numRows=$numRowsToShow, truncate=$isTruncated]"
  }
}
Example 36
Source File: SameResultSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType


/**
 * Tests that `sameResult` recognises plans that differ only in filter ordering, plan instance
 * or attribute-name case.
 */
class SameResultSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("FileSourceScanExec: different orders of data filters and partition filters") {
    withTempPath { path =>
      val parquetDir = path.getCanonicalPath
      val source = spark.range(10)
        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
      source.write.partitionBy("a", "b").parquet(parquetDir)

      val loaded = spark.read.parquet(parquetDir)
      // `a` and `b` are the partition columns; `c` and `d` are data columns. Both predicates
      // below contain the same four conjuncts, only in a different order.
      val scan1 = getFileSourceScanExec(loaded.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
      val scan2 = getFileSourceScanExec(loaded.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))
      assert(scan1.sameResult(scan2))
    }
  }

  /** Returns the first `FileSourceScanExec` in the physical plan of `df`; fails if absent. */
  private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {
    df.queryExecution.sparkPlan.collectFirst { case scan: FileSourceScanExec => scan }.get
  }

  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
    val sumLeft = spark.range(10).agg(sum($"id"))
    val sumRight = spark.range(10).agg(sum($"id"))
    assert(sumLeft.queryExecution.executedPlan.sameResult(sumRight.queryExecution.executedPlan))

    val distinctLeft = spark.range(10).agg(sumDistinct($"id"))
    val distinctRight = spark.range(10).agg(sumDistinct($"id"))
    assert(
      distinctLeft.queryExecution.executedPlan.sameResult(
        distinctRight.queryExecution.executedPlan))
  }

  test("Canonicalized result is case-insensitive") {
    val upperA = AttributeReference("A", IntegerType)()
    val upperB = AttributeReference("B", IntegerType)()
    val planUppercase = Project(Seq(upperA), LocalRelation(upperA, upperB))

    val lowerA = AttributeReference("a", IntegerType)()
    val lowerB = AttributeReference("b", IntegerType)()
    val planLowercase = Project(Seq(lowerA), LocalRelation(lowerA, lowerB))

    assert(planUppercase.sameResult(planLowercase))
  }
}
Example 37
Source File: SparkPlannerSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Verifies that the planner only follows the first candidate a strategy returns for each
 * logical node.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf: the strategy below fails the test if it is ever asked to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    // Counts how many logical nodes the strategy below has been asked to plan.
    var plannedNodes = 0

    // For every node it handles, the strategy returns the real plan first and a placeholder
    // for the sentinel second.
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          plannedNodes += 1
          Seq(planLater(child), planLater(NeverPlanned))
        case Union(children) =>
          plannedNodes += 1
          Seq(UnionExec(children.map(planLater)), planLater(NeverPlanned))
        case LocalRelation(output, data, _) =>
          plannedNodes += 1
          Seq(LocalTableScanExec(output, data), planLater(NeverPlanned))
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = Seq(TestStrategy)

      val left = Seq("a", "b", "c").toDS()
      val right = Seq("d", "e", "f").toDS()
      val combined = left.union(right)

      assert(combined.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // One ReturnAnswer, one Union and two LocalRelations.
      assert(plannedNodes === 4)
    } finally {
      // Always restore the session so later tests are unaffected.
      spark.experimental.extraStrategies = Nil
    }
  }
}
Example 38
Source File: ResolveInlineTables.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an inline table into a [[LocalRelation]].
   *
   * Each column gets the type returned by `TypeCoercion.findWiderTypeWithoutStringPromotion`
   * for its values and is nullable if any of its values is. Every cell is then evaluated,
   * through a `Cast` when its type differs from the column type. Analysis fails if a column has
   * no common type or a cell cannot be evaluated.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Walk the table column by column to derive one StructField per column.
    val columns = table.rows.transpose
    val fields = columns.zip(table.names).map { case (column, name) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(column.map(_.dataType))
      val tpe = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      val anyNullable = column.exists(_.nullable)
      StructField(name, tpe, nullable = anyNullable)
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    // Evaluate every cell eagerly; any non-fatal failure is reported as an analysis error.
    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val values = row.zipWithIndex.map { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          val toEvaluate = if (e.dataType.sameType(targetType)) e else Cast(e, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(values)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 39
Source File: SameResultSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._


/**
 * Tests for `LogicalPlan.sameResult`, using two separately constructed relations with the same
 * column names and types.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test, printing them side by side, unless
   * `sameResult` between them equals `result`.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val observed = left.sameResult(right)

    if (observed != result) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projections over the two relations match.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set or a different column order does not.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
  }

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
  }

  test("union") {
    val forward = Union(Seq(testRelation, testRelation2))
    val reversed = Union(Seq(testRelation2, testRelation))
    assertSameResult(forward, reversed)
  }
}
Example 40
Source File: TestRelations.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

/** Empty relations with fixed schemas, shared by analysis tests. */
object TestRelations {
  // Single nullable integer column.
  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  // Five columns of mixed primitive types.
  val testRelation2 = LocalRelation(
    AttributeReference("a", StringType)(),
    AttributeReference("b", StringType)(),
    AttributeReference("c", DoubleType)(),
    AttributeReference("d", DecimalType(10, 2))(),
    AttributeReference("e", ShortType)())

  // Shares the column name `e` with testRelation2.
  val testRelation3 = LocalRelation(
    AttributeReference("e", ShortType)(),
    AttributeReference("f", StringType)(),
    AttributeReference("g", DoubleType)(),
    AttributeReference("h", DecimalType(10, 2))())

  // Struct column whose fields include an exact duplicate name and a pair of names that differ
  // only by case.
  val nestedRelation = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType))))())

  // Struct column with three distinct string fields.
  val nestedRelation2 = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType))))())

  // Single array-of-int column.
  val listRelation = LocalRelation(
    AttributeReference("list", ArrayType(IntegerType))())

  // Single int-to-int map column.
  val mapRelation = LocalRelation(
    AttributeReference("map", MapType(IntegerType, IntegerType))())
}
Example 41
Source File: ResolveSubquerySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}


/** Tests the error reported by `ResolveSubquery` for a misplaced outer reference. */
class ResolveSubquerySuite extends AnalysisTest {

  val a = 'a.int
  val b = 'b.int
  val t1 = LocalRelation(a)
  val t2 = LocalRelation(b)

  test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
    // The subquery projects an outer reference to `a`, i.e. uses it outside a filter.
    val subquery = ListQuery(Project(Seq(OuterReference(a)), t2))
    val expr = Filter(In(a, Seq(subquery)), t1)

    val error = intercept[AnalysisException] {
      SimpleAnalyzer.ResolveSubquery(expr)
    }
    val message = error.getMessage
    assert(message.contains(
      "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses"))
  }
}
Example 42
Source File: ConvertToLocalRelationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/** Tests that `ConvertToLocalRelation` evaluates a projection over in-memory rows. */
class ConvertToLocalRelationSuite extends PlanTest {

  // Runs only the rule under test, to a fixed point (at most 100 iterations).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val inputRows = Seq(InternalRow(1, 2), InternalRow(4, 5))
    val testRelation = LocalRelation(LocalRelation('a.int, 'b.int).output, inputRows)

    // Same rows after renaming `a` to `a1` and computing `b + 1` as `b1`.
    val expectedRows = Seq(InternalRow(1, 3), InternalRow(4, 6))
    val expected = LocalRelation(LocalRelation('a1.int, 'b1.int).output, expectedRows)

    val projected = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val actual = Optimize.execute(projected.analyze)

    comparePlans(actual, expected)
  }

}
Example 43
Source File: CollapseWindowSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `CollapseWindow` rule: adjacent window operators merge only when their
 * partition and order specifications are the same.
 */
class CollapseWindowSuite extends PlanTest {
  // Runs only the rule under test, to a fixed point (at most 10 iterations).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseWindow", FixedPoint(10), CollapseWindow))
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    // Four stacked windows, all with the same partitioning and ordering.
    val stacked = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzedPlan = stacked.analyze
    val actual = Optimize.execute(analyzedPlan)
    // Collapsing must not change the output attributes.
    assert(analyzedPlan.output === actual.output)

    val mergedExpressions = Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b))
    val expected = testRelation.window(mergedExpressions, partitionSpec1, orderSpec1)

    comparePlans(actual, expected)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Optimizes the plan and checks that it is the same as the plain analyzed plan.
    def assertNotCollapsed(query: LogicalPlan): Unit = {
      val actual = Optimize.execute(query.analyze)
      val expected = query.analyze
      comparePlans(actual, expected)
    }

    // Same partitioning, different ordering.
    assertNotCollapsed(testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2))

    // Same ordering, different partitioning.
    assertNotCollapsed(testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1))
  }
}
Example 44
Source File: RewriteDistinctAggregatesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `RewriteDistinctAggregates`: some aggregates are expected to come back unchanged,
 * others to be rewritten into two stacked Aggregates over an Expand.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /** Fails unless the plan has the shape Aggregate -> Aggregate -> Expand. */
  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) => ()
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val analyzedInput = testRelation.groupBy('a)(countDistinct('e)).analyze
    val rewritten = RewriteDistinctAggregates(analyzedInput)
    // The rule is expected to leave this plan as it is.
    comparePlans(analyzedInput, rewritten)
  }

  test("single distinct group with partial aggregates") {
    val aggregated = testRelation.groupBy('a, 'd)(
      countDistinct('e, 'c).as('agg1),
      max('b).as('agg2))
    val analyzedInput = aggregated.analyze
    val rewritten = RewriteDistinctAggregates(analyzedInput)
    // The rule is expected to leave this plan as it is.
    comparePlans(analyzedInput, rewritten)
  }

  test("single distinct group with non-partial aggregates") {
    val aggregated = testRelation.groupBy('a, 'd)(
      countDistinct('e, 'c).as('agg1),
      CollectSet('b).toAggregateExpression().as('agg2))
    val analyzedInput = aggregated.analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups") {
    val analyzedInput = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups with partial aggregates") {
    val analyzedInput = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val aggregated = testRelation.groupBy('a)(
      countDistinct('b, 'c),
      countDistinct('d),
      CollectSet('b).toAggregateExpression())
    val analyzedInput = aggregated.analyze
    checkRewrite(RewriteDistinctAggregates(analyzedInput))
  }
}
Example 45
Source File: CollapseRepartitionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `CollapseRepartition` rule: two stacked repartition operators are expected to
 * become one.
 */
class CollapseRepartitionSuite extends PlanTest {
  // Runs only the rule under test, to a fixed point (at most 10 iterations).
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseRepartition", FixedPoint(10), CollapseRepartition))
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  /** Optimizes the analyzed `query` and compares it with the analyzed `expected` plan. */
  private def checkCollapse(query: LogicalPlan, expected: LogicalPlan): Unit = {
    val optimized = Optimize.execute(query.analyze)
    val correctAnswer = expected.analyze

    comparePlans(optimized, correctAnswer)
  }

  test("collapse two adjacent repartitions into one") {
    // The outer partition count (20) is the one that survives.
    checkCollapse(
      query = testRelation.repartition(10).repartition(20),
      expected = testRelation.repartition(20))
  }

  test("collapse repartition and repartitionBy into one") {
    checkCollapse(
      query = testRelation.repartition(10).distribute('a)(20),
      expected = testRelation.distribute('a)(20))
  }

  test("collapse repartitionBy and repartition into one") {
    // The distribution expression is kept, with the outer partition count (10).
    checkCollapse(
      query = testRelation.distribute('a)(20).repartition(10),
      expected = testRelation.distribute('a)(10))
  }

  test("collapse two adjacent repartitionBys into one") {
    checkCollapse(
      query = testRelation.distribute('b)(10).distribute('a)(20),
      expected = testRelation.distribute('a)(20))
  }
}
Example 46
Source File: ComputeCurrentTimeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Tests that `ComputeCurrentTime` replaces every current-time expression in a plan with the
 * same literal value.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // Runs only the rule under test, in a single pass.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val projections = Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")())
    val in = Project(projections, LocalRelation())

    // Bracket the optimizer run with wall-clock readings scaled by 1000
    // (presumably milliseconds to microseconds; confirm against CurrentTimestamp).
    val lowerBound = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val upperBound = (System.currentTimeMillis() + 1) * 1000

    // The transform is used only to visit every literal; each one is returned unchanged.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Long]
    plan.transformAllExpressions { case literal: Literal =>
      lits += literal.value.asInstanceOf[Long]
      literal
    }
    assert(lits.size == 2)
    lits.foreach { value =>
      assert(value >= lowerBound && value <= upperBound)
    }
    // Both occurrences must have been replaced by the same instant.
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val projections = Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")())
    val in = Project(projections, LocalRelation())

    // Bracket the optimizer run with the day numbers before and after it.
    val lowerBound = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val upperBound = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // The transform is used only to visit every literal; each one is returned unchanged.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Int]
    plan.transformAllExpressions { case literal: Literal =>
      lits += literal.value.asInstanceOf[Int]
      literal
    }
    assert(lits.size == 2)
    lits.foreach { value =>
      assert(value >= lowerBound && value <= upperBound)
    }
    // Both occurrences must have been replaced by the same day.
    assert(lits(0) == lits(1))
  }
}
Example 47
Source File: ReorderAssociativeOperatorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `ReorderAssociativeOperator` optimizer rule: the expected plans below have the
 * literal operands of `+` and `*` chains folded into a single constant.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Runs only the rule under test, in a single pass.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val unoptimized = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(unoptimized.analyze)

    // Each rewritten column is aliased with the text of the expression it replaces; the
    // `Rand(0)` product is expected to come back exactly as written.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    val joinCondition = Some("t1.a".attr === "t2.a".attr)
    val unoptimized = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, joinCondition)
      .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
        (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val actual = Optimize.execute(unoptimized.analyze)

    // The expected plan is simply the analyzed input, i.e. the rule must not touch it.
    val expected = unoptimized.analyze

    comparePlans(actual, expected)
  }
}
Example 48
Source File: AggregateOptimizeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the optimizer rules that simplify the grouping expressions of an aggregate:
 * literals and repeated expressions are expected to be dropped from the GROUP BY list.
 */
class AggregateOptimizeSuite extends PlanTest {
  // Analysis is case-insensitive and GROUP BY ordinals are disabled for every test below.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  /** Runs the three aggregate-related rules together, for up to 100 iterations. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val plan = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(plan))
    val expected = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(actual, expected)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val plan = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(plan))
    // A single literal grouping key (0) is expected to remain.
    val expected = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(actual, expected)
  }

  test("Remove aliased literals") {
    val plan = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val actual = Optimize.execute(analyzer.execute(plan))
    val expected = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(actual, expected)
  }

  test("remove repetition in grouping expression") {
    val relation = LocalRelation('a.int, 'b.int, 'c.int)
    // 'A and 'B differ from 'a and 'b only by case and operand order; the expected plan
    // keeps just the first two grouping expressions.
    val plan = relation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val actual = Optimize.execute(analyzer.execute(plan))
    val expected = relation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(actual, expected)
  }
}
Example 49
Source File: FrequentItems.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  /**
   * Collects, in a single pass over `df`, candidate frequent items for each column in `cols`
   * and returns them as a one-row DataFrame.
   *
   * NOTE(review): the counting itself is delegated to `FreqItemCounter`, which is not
   * visible in this excerpt; how it uses its size argument should be confirmed there.
   *
   * @param df the DataFrame to scan
   * @param cols names of the columns to inspect; each must exist in `df.schema`
   * @param support value in [1e-4, 1]; `(1 / support).toInt` is the size handed to each counter
   * @return a DataFrame with one row and one array column named `<col>_freqItems` per input
   *         column
   */
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // size passed to every per-column counter
    val counterSize = (1 / support).toInt
    val emptyCounters = Seq.fill(numCols)(new FreqItemCounter(counterSize))
    val inputSchema = df.schema
    // (column name, column data type) for each requested column, in request order
    val colInfo: Array[(String, DataType)] = cols.map { colName =>
      val position = inputSchema.fieldIndex(colName)
      (colName, inputSchema.fields(position).dataType)
    }.toArray

    val selected = df.select(cols.map(Column(_)) : _*)
    val mergedCounters = selected.rdd.aggregate(emptyCounters)(
      seqOp = (counters, row) => {
        // Feed the value of every selected column into that column's counter.
        var col = 0
        while (col < numCols) {
          counters(col).add(row.get(col), 1L)
          col += 1
        }
        counters
      },
      combOp = (left, right) => {
        // Merge the counters column by column into the left-hand side.
        var col = 0
        while (col < numCols) {
          left(col).merge(right(col))
          col += 1
        }
        left
      }
    )
    val itemsPerColumn = mergedCounters.map(_.baseMap.keys.toArray)
    val resultRow = Row(itemsPerColumn : _*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { case (colName, colType) =>
      StructField(colName + "_freqItems", ArrayType(colType, false))
    }
    val outputAttributes = StructType(outputCols).toAttributes
    val relation = LocalRelation.fromExternalRows(outputAttributes, Seq(resultRow))
    Dataset.ofRows(df.sparkSession, relation)
  }
}
Example 50
Source File: SparkPlannerSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

/**
 * Checks that the query planner only follows the first candidate returned by a strategy
 * and never plans the alternatives that come after it.
 */
class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf: the test fails if the strategy is ever asked to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    // Number of logical nodes handled by the non-sentinel cases of TestStrategy.
    var planned = 0

    // Every handled node yields two candidates: the real plan first, the sentinel second.
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          Seq(planLater(child), planLater(NeverPlanned))
        case Union(children) =>
          planned += 1
          Seq(UnionExec(children.map(planLater)), planLater(NeverPlanned))
        case LocalRelation(attributes, rows) =>
          planned += 1
          Seq(LocalTableScanExec(attributes, rows), planLater(NeverPlanned))
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = Seq(TestStrategy)

      val firstHalf = Seq("a", "b", "c").toDS()
      val secondHalf = Seq("d", "e", "f").toDS()
      val ds = firstHalf.union(secondHalf)

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // NOTE(review): presumably ReturnAnswer + Union + the two LocalRelations -- confirm.
      assert(planned === 4)
    } finally {
      // Always remove the test strategy from the session again.
      spark.experimental.extraStrategies = Nil
    }
  }
}
Example 51
Source File: ExtractJoinConditionsSuite.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.view.plans

import org.apache.carbondata.mv.dsl.Plans._
import org.apache.carbondata.view.testutil.ModularPlanTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.{Inner, _}

/**
 * Tests `extractJoinConditions` on modular plans: the conditions extracted between the two
 * children of a modularized join are compared with the join condition of a logical plan.
 */
class ExtractJoinConditionsSuite extends ModularPlanTest {
  val testRelation0 = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation1 = LocalRelation('d.int)
  val testRelation2 = LocalRelation('b.int, 'c.int, 'e.int)

  test("join only") {
    val leftSide = testRelation0.where('a === 1)
    val rightSide = testRelation1
    val analyzedJoin = leftSide
      .join(rightSide, condition = Some("d".attr === "b".attr || "d".attr === "c".attr))
      .analyze
    val modular = analyzedJoin.modularize
    val actual = modular.extractJoinConditions(modular.children(0), modular.children(1))

    // The expected result is the join condition of the analyzed logical plan itself.
    val expected = analyzedJoin match {
      case logical.Join(
          logical.Filter(_, MatchLocalRelation(_, _)),
          MatchLocalRelation(_, _),
          Inner,
          Some(joinCondition)) =>
        Seq(joinCondition)
    }

    compareExpressions(expected, actual)
  }

  test("join and filter") {
    val leftSide = testRelation0.where('b === 2).subquery('l)
    val rightSide = testRelation2.where('b === 2).subquery('r)
    val analyzedJoin = leftSide
      .join(rightSide, condition = Some("r.b".attr === 2 && "l.c".attr === "r.c".attr))
      .analyze
    val modular = analyzedJoin.modularize
    val actual = modular.extractJoinConditions(modular.children(0), modular.children(1))

    // Reference plan whose join condition holds only the cross-table equality.
    val referenceJoin = leftSide
      .join(rightSide, condition = Some("l.c".attr === "r.c".attr))
      .analyze

    val expected = referenceJoin match {
      case logical.Join(
          logical.Filter(_, MatchLocalRelation(_, _)),
          logical.Filter(_, MatchLocalRelation(_, _)),
          Inner,
          Some(joinCondition)) =>
        Seq(joinCondition)
    }

    compareExpressions(expected, actual)
  }
}
Example 52
Source File: IsSPJGHSuite.scala    From carbondata   with Apache License 2.0 5 votes vote down vote up
package org.apache.carbondata.view.plans

import org.apache.carbondata.mv.dsl.Plans._
import org.apache.carbondata.mv.plans.modular.ModularPlan
import org.apache.carbondata.view.testutil.ModularPlanTest
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation


/**
 * Checks the `isSPJGH` flag of modular plans built from several logical-plan shapes
 * (projection, group-by, filter and join combinations).
 */
class IsSPJGHSuite extends ModularPlanTest {
  val testRelation0 = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation1 = LocalRelation('d.int, 'e.int)

  /**
   * Fails the running test when `plan.isSPJGH` differs from the expected value.
   *
   * The mismatch used to be only printed, so a wrong `isSPJGH` value could never make a
   * test fail; it is now reported through `fail`, with the plan included in the message.
   *
   * @param plan the modular plan to inspect
   * @param result the expected value of `plan.isSPJGH` (defaults to `true`)
   */
  def assertIsSPJGH(plan: ModularPlan, result: Boolean = true): Unit = {
    if (plan.isSPJGH != result) {
      val ps = plan.toString
      fail(s"Plan should return isSPJGH = $result\n$ps")
    }
  }

  test("project only") {
    assertIsSPJGH(testRelation0.select('a).analyze.modularize)
    assertIsSPJGH(testRelation0.select('a, 'b).analyze.modularize)
  }

  test("groupby-project") {
    assertIsSPJGH(testRelation0.select('a).groupBy('a)('a).select('a).analyze.modularize)
    assertIsSPJGH(
      testRelation0.select('a, 'b).groupBy('a, 'b)('a, 'b).select('a).analyze.modularize)
  }

  test("groupby-project-filter") {
    assertIsSPJGH(
      testRelation0.where('a === 1)
        .select('a, 'b)
        .groupBy('a, 'b)('a, 'b)
        .select('a)
        .analyze
        .modularize)
  }

  test("groupby-project-filter-join") {
    assertIsSPJGH(
      testRelation0.where('b === 1)
        .join(
          testRelation1.where('d === 1),
          condition = Some("d".attr === "b".attr || "d".attr === "c".attr))
        .groupBy('b, 'c)('b, 'c)
        .select('b)
        .analyze
        .modularize)
  }
}
Example 53
Source File: ResolveInlineTables.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Turns an [[UnresolvedInlineTable]] into a [[LocalRelation]] by choosing one common type
   * per column and eagerly evaluating every cell expression to a value of that type.
   *
   * Analysis fails (via `table.failAnalysis`) when a column has no common type or when
   * evaluating a cell throws a non-fatal exception.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // One StructField per column: the widened type of all its cells, nullable as soon as
    // any cell is nullable.
    val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(column.map(_.dataType))
      val columnType = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      StructField(name, columnType, nullable = column.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val values = row.zipWithIndex.map { case (expr, columnIndex) =>
        val targetType = fields(columnIndex).dataType
        try {
          // Cast only when the cell's own type differs from the column's common type.
          val toEvaluate = if (expr.dataType.sameType(targetType)) expr else Cast(expr, targetType)
          toEvaluate.eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${expr.sql}: ${ex.getMessage}")
        }
      }
      InternalRow.fromSeq(values)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 54
Source File: SameResultSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union}
import org.apache.spark.sql.catalyst.util._


/**
 * Tests `sameResult` on logical plans built over two separately created relations that
 * declare the same three columns.
 */
class SameResultSuite extends SparkFunSuite {
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test unless `sameResult` between them equals `result`.
   * On failure the two analyzed plans are rendered side by side in the message.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val observed = left.sameResult(right)

    if (observed != result) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projection lists over the two relations.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set or a different column order must not count as the same result.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    assertSameResult(testRelation.where('a === 'b), testRelation2.where('a === 'b))
  }

  test("sorts") {
    assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc))
  }

  test("union") {
    val forward = Union(Seq(testRelation, testRelation2))
    val swapped = Union(Seq(testRelation2, testRelation))
    assertSameResult(forward, swapped)
  }
}
Example 55
Source File: TestRelations.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

/** Predefined empty relations with assorted schemas, shared by analysis tests. */
object TestRelations {
  /** Builds an attribute from a name and a type, leaving every other setting at its default. */
  private def attr(name: String, dataType: DataType): AttributeReference =
    AttributeReference(name, dataType)()

  // Single nullable integer column.
  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  // Mixed primitive columns a..e.
  val testRelation2 = LocalRelation(
    attr("a", StringType),
    attr("b", StringType),
    attr("c", DoubleType),
    attr("d", DecimalType(10, 2)),
    attr("e", ShortType))

  // Columns e..h; shares the column name "e" with testRelation2.
  val testRelation3 = LocalRelation(
    attr("e", ShortType),
    attr("f", StringType),
    attr("g", DoubleType),
    attr("h", DecimalType(10, 2)))

  // Struct column whose fields include an exact duplicate name and a pair of names that
  // differ only by case.
  val nestedRelation = LocalRelation(
    attr("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)))))

  // Struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    attr("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)))))

  val listRelation = LocalRelation(attr("list", ArrayType(IntegerType)))

  val mapRelation = LocalRelation(attr("map", MapType(IntegerType, IntegerType)))
}
Example 56
Source File: ResolveSubquerySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}


/** Tests for the analyzer's `ResolveSubquery` rule. */
class ResolveSubquerySuite extends AnalysisTest {

  val a = 'a.int
  val b = 'b.int
  val t1 = LocalRelation(a)
  val t2 = LocalRelation(b)

  test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
    // The subquery projects an outer reference directly, i.e. in its SELECT list.
    val subquery = Project(Seq(OuterReference(a)), t2)
    val plan = Filter(In(a, Seq(ListQuery(subquery))), t1)

    val message = intercept[AnalysisException] {
      SimpleAnalyzer.ResolveSubquery(plan)
    }.getMessage

    assert(message.contains(
      "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses"))
  }
}
Example 57
Source File: ConvertToLocalRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/** Tests for the `ConvertToLocalRelation` optimizer rule. */
class ConvertToLocalRelationSuite extends PlanTest {

  /** Runs only `ConvertToLocalRelation`, for up to 100 iterations. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val inputAttributes = LocalRelation('a.int, 'b.int).output
    val testRelation = LocalRelation(inputAttributes, Seq(InternalRow(1, 2), InternalRow(4, 5)))

    // Expected rows: first column unchanged, second column incremented by one.
    val outputAttributes = LocalRelation('a1.int, 'b1.int).output
    val expected = LocalRelation(outputAttributes, Seq(InternalRow(1, 3), InternalRow(4, 6)))

    val projection = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val actual = Optimize.execute(projection.analyze)

    comparePlans(actual, expected)
  }

}
Example 58
Source File: CollapseWindowSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/** Tests for the `CollapseWindow` optimizer rule. */
class CollapseWindowSuite extends PlanTest {
  /** Runs only `CollapseWindow`, for up to 10 iterations. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseWindow", FixedPoint(10), CollapseWindow))
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  val a = testRelation.output(0)
  val b = testRelation.output(1)
  val c = testRelation.output(2)
  // Two partition specs and two order specs, used to build matching and non-matching windows.
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val stacked = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = stacked.analyze
    val actual = Optimize.execute(analyzed)
    // The rewrite must keep the output attributes unchanged.
    assert(analyzed.output === actual.output)

    // All four window expressions are expected in one window operator.
    val expected = testRelation.window(
      Seq(
        min(a).as('min_a),
        max(a).as('max_a),
        sum(b).as('sum_b),
        avg(b).as('avg_b)),
      partitionSpec1,
      orderSpec1)

    comparePlans(actual, expected)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partitioning, different ordering.
    val differentOrder = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    comparePlans(Optimize.execute(differentOrder.analyze), differentOrder.analyze)

    // Same ordering, different partitioning.
    val differentPartition = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    comparePlans(Optimize.execute(differentPartition.analyze), differentPartition.analyze)
  }
}
Example 59
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for `RewriteDistinctAggregates`: some aggregate shapes are expected to be left
 * unchanged, others to be rewritten into two stacked aggregates over an `Expand`.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /** Fails the test unless `rewrite` has the shape Aggregate(Aggregate(Expand)). */
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val plan = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    // Expected to come back unchanged.
    comparePlans(plan, RewriteDistinctAggregates(plan))
  }

  test("single distinct group with partial aggregates") {
    val plan = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    // Expected to come back unchanged.
    comparePlans(plan, RewriteDistinctAggregates(plan))
  }

  test("single distinct group with non-partial aggregates") {
    val plan = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups") {
    val plan = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups with partial aggregates") {
    val plan = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val plan = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(plan))
  }
}
Example 60
Source File: CollapseRepartitionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `CollapseRepartition` optimizer rule: two adjacent repartitioning
 * operators are expected to be reduced to one.
 */
class CollapseRepartitionSuite extends PlanTest {
  /** Runs only `CollapseRepartition`, for up to 10 iterations. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("CollapseRepartition", FixedPoint(10), CollapseRepartition))
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two adjacent repartitions into one") {
    val plan = testRelation.repartition(10).repartition(20)

    val actual = Optimize.execute(plan.analyze)
    // Only the outer partition count is expected to remain.
    val expected = testRelation.repartition(20).analyze

    comparePlans(actual, expected)
  }

  test("collapse repartition and repartitionBy into one") {
    val plan = testRelation.repartition(10).distribute('a)(20)

    val actual = Optimize.execute(plan.analyze)
    val expected = testRelation.distribute('a)(20).analyze

    comparePlans(actual, expected)
  }

  test("collapse repartitionBy and repartition into one") {
    val plan = testRelation.distribute('a)(20).repartition(10)

    val actual = Optimize.execute(plan.analyze)
    // Expected: the inner distribution expression with the outer partition count.
    val expected = testRelation.distribute('a)(10).analyze

    comparePlans(actual, expected)
  }

  test("collapse two adjacent repartitionBys into one") {
    val plan = testRelation.distribute('b)(10).distribute('a)(20)

    val actual = Optimize.execute(plan.analyze)
    val expected = testRelation.distribute('a)(20).analyze

    comparePlans(actual, expected)
  }
}
Example 61
Source File: ComputeCurrentTimeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Tests for the `ComputeCurrentTime` rule: current-time expressions in a plan are expected
 * to become literals that all carry the same value.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  /** Runs only `ComputeCurrentTime`, in a single `Once` batch. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val projectList = Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")())
    val input = Project(projectList, LocalRelation())

    // Bracket the rule execution with wall-clock readings scaled by 1000; the statement
    // order here matters.
    val lowerBound = System.currentTimeMillis() * 1000
    val optimized = Optimize.execute(input.analyze).asInstanceOf[Project]
    val upperBound = (System.currentTimeMillis() + 1) * 1000

    // Collect the value of every literal found in the optimized plan.
    val literals = new scala.collection.mutable.ArrayBuffer[Long]
    optimized.transformAllExpressions { case lit: Literal =>
      literals += lit.value.asInstanceOf[Long]
      lit
    }
    assert(literals.size == 2)
    assert(literals(0) >= lowerBound && literals(0) <= upperBound)
    assert(literals(1) >= lowerBound && literals(1) <= upperBound)
    assert(literals(0) == literals(1))
  }

  test("analyzer should replace current_date with literals") {
    val projectList = Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")())
    val input = Project(projectList, LocalRelation())

    // Day numbers taken immediately before and after running the rule.
    val lowerBound = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val optimized = Optimize.execute(input.analyze).asInstanceOf[Project]
    val upperBound = DateTimeUtils.millisToDays(System.currentTimeMillis())

    val literals = new scala.collection.mutable.ArrayBuffer[Int]
    optimized.transformAllExpressions { case lit: Literal =>
      literals += lit.value.asInstanceOf[Int]
      lit
    }
    assert(literals.size == 2)
    assert(literals(0) >= lowerBound && literals(0) <= upperBound)
    assert(literals(1) >= lowerBound && literals(1) <= upperBound)
    assert(literals(0) == literals(1))
  }
}
Example 62
Source File: ReorderAssociativeOperatorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the `ReorderAssociativeOperator` optimizer rule, executed in isolation through a
 * rule executor that contains only that rule.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  /** Runs `ReorderAssociativeOperator` alone, in a single `Once` batch. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ReorderAssociativeOperator", Once, ReorderAssociativeOperator))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    val projection = testRelation.select(
      (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
      'b * 1 * 2 * 3 * 4,
      ('b + 1) * 2 * 3 * 4,
      'a + 1 + 'b + 2 + 'c + 3,
      'a + 1 + 'b * 2 + 'c + 3,
      Rand(0) * 1 * 2 * 3 * 4)

    val actual = Optimize.execute(projection.analyze)

    // In the expected projection the literal operands of each +/* chain are combined into
    // one literal, and each rewritten column carries an alias that mirrors the shape of
    // the original expression. The Rand(0) product is expected exactly as written above.
    val expected = testRelation.select(
      ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
      ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
      (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
      ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
      ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
      Rand(0) * 1 * 2 * 3 * 4).analyze

    comparePlans(actual, expected)
  }

  test("nested expression with aggregate operator") {
    val selfJoin = testRelation.as("t1")
      .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
    val grouped = selfJoin.groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
      (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val actual = Optimize.execute(grouped.analyze)

    // The optimized plan is expected to equal the analyzed input, i.e. no rewrite happens.
    val expected = grouped.analyze

    comparePlans(actual, expected)
  }
}
Example 63
Source File: AggregateOptimizeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the optimizer rules that simplify the grouping expressions of an aggregate:
 * literals and repeated expressions are expected to be dropped from the GROUP BY list.
 */
class AggregateOptimizeSuite extends PlanTest {
  // Analysis is case-insensitive and GROUP BY ordinals are disabled for every test below.
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  /** Runs the three aggregate-related rules together, for up to 100 iterations. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val grouped = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(actual, expected)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val grouped = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    // A single literal grouping key (0) is expected to remain.
    val expected = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(actual, expected)
  }

  test("Remove aliased literals") {
    val grouped = testRelation.select('a, Literal(1).as('y)).groupBy('a, 'y)(sum('b))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = testRelation.select('a, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(actual, expected)
  }

  test("remove repetition in grouping expression") {
    val relation = LocalRelation('a.int, 'b.int, 'c.int)
    // 'A and 'B differ from 'a and 'b only by case and operand order; the expected plan
    // keeps just the first two grouping expressions.
    val grouped = relation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val actual = Optimize.execute(analyzer.execute(grouped))
    val expected = relation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(actual, expected)
  }
}
Example 64
Source File: FrequentItems.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  /**
   * Collects frequent items for each of the given columns in a single pass over `df`.
   *
   * One `FreqItemCounter` per column is updated for every row and the per-partition counters are
   * merged; the keys left in each counter form the result for that column.
   *
   * @param df the input DataFrame
   * @param cols names of the columns to inspect; each must exist in the schema of `df`
   * @param support value in [1e-4, 1]; `(1 / support).toInt` is the size handed to each counter
   * @return a one-row DataFrame with one array column per input column, named
   *         `<column>_freqItems`
   * @throws IllegalArgumentException if `support` is outside [1e-4, 1]
   */
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // size passed to every per-column counter
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.fill(numCols)(new FreqItemCounter(sizeOfMap))

    // Resolve each requested column to its data type up front, before any job is run.
    val inputSchema = df.schema
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      name -> inputSchema.fields(inputSchema.fieldIndex(name)).dataType
    }.toArray

    val selectedRows = df.select(cols.map(name => Column(name)) : _*).rdd
    val freqItems = selectedRows.aggregate(countMaps)(
      // Fold one row into the counters: column `c` of the row goes to counter `c`.
      seqOp = (counters, row) => {
        var c = 0
        while (c < numCols) {
          counters(c).add(row.get(c), 1L)
          c += 1
        }
        counters
      },
      // Merge the counters of two partitions column by column, into the left-hand side.
      combOp = (merged, other) => {
        var c = 0
        while (c < numCols) {
          merged(c).merge(other(c))
          c += 1
        }
        merged
      }
    )

    val resultRow = Row(freqItems.map(_.baseMap.keys.toArray) : _*)
    // "_freqItems" is appended to each column name for easy debugging
    val outputCols = colInfo.map { case (name, dataType) =>
      StructField(name + "_freqItems", ArrayType(dataType, containsNull = false))
    }
    val schema = StructType(outputCols).toAttributes
    Dataset.ofRows(df.sparkSession, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
  }
}
Example 65
Source File: SparkPlannerSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf node: every strategy result below offers it as a second alternative, and the
    // strategy fails the test if it is ever asked to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Seq.empty
    }

    // Number of times the strategy matched a ReturnAnswer, Union or LocalRelation node.
    var plannedCount = 0
    object TestStrategy extends Strategy {
      def user: String = sparkContext.sparkUser
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(root) =>
          plannedCount += 1
          Seq(planLater(root, user), planLater(NeverPlanned, user))
        case Union(branches) =>
          plannedCount += 1
          Seq(UnionExec(branches.map(b => planLater(b, user))), planLater(NeverPlanned, user))
        case LocalRelation(attrs, rows) =>
          plannedCount += 1
          Seq(LocalTableScanExec(attrs, rows, user), planLater(NeverPlanned, user))
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val firstHalf = Seq("a", "b", "c").toDS()
      val ds = firstHalf.union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      assert(plannedCount === 4)
    } finally {
      // Always restore the default strategies so other tests are unaffected.
      spark.experimental.extraStrategies = Nil
    }
  }
}
Example 66
Source File: SameResultSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two separately constructed relations with the same column names and types.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test unless `sameResult` between them equals `result`.
   * On failure the two analyzed plans are printed side by side.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val matchesExpectation = left.sameResult(right) == result

    if (!matchesExpectation) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Same columns in the same order on both sides.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set, or a different column order, must not compare as the same result.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    assertSameResult(filtered, testRelation2.where('a === 'b))
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    assertSameResult(sorted, testRelation2.orderBy('a.asc))
  }
}
Example 67
Source File: ConvertToLocalRelationSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  // Optimizer under test: only ConvertToLocalRelation, in one FixedPoint(100) batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input: two rows over the columns (a, b).
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      Seq(Row(1, 2), Row(4, 5)))

    // Expected: the projection already applied to the data (a1 = a, b1 = b + 1).
    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      Seq(Row(1, 3), Row(4, 6)))

    val renamedA = UnresolvedAttribute("a").as("a1")
    val incrementedB = (UnresolvedAttribute("b") + 1).as("b1")
    val projectOnLocal = testRelation.select(renamedA, incrementedB)

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

}
Example 68
Source File: OptimizeInSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class OptimizeInSuite extends PlanTest {

  // Optimizer under test: subquery elimination followed by one pass of constant folding,
  // boolean simplification and OptimizeIn.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("AnalysisNodes", Once, EliminateSubQueries),
      Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification, OptimizeIn))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("OptimizedIn test: In clause optimized to InSet") {
    // IN list made only of literals.
    val literalList = Seq(Literal(1), Literal(2))
    val originalQuery =
      testRelation.where(In(UnresolvedAttribute("a"), literalList)).analyze

    val optimized = Optimize.execute(originalQuery.analyze)

    val expectedSet = HashSet[Any](1, 2)
    val correctAnswer =
      testRelation.where(InSet(UnresolvedAttribute("a"), expectedSet)).analyze

    comparePlans(optimized, correctAnswer)
  }

  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    // IN list containing an attribute: the expected plan is the unchanged In filter.
    val originalQuery = {
      val mixedList = Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))
      testRelation.where(In(UnresolvedAttribute("a"), mixedList)).analyze
    }

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = {
      val mixedList = Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))
      testRelation.where(In(UnresolvedAttribute("a"), mixedList)).analyze
    }

    comparePlans(optimized, correctAnswer)
  }
}
Example 69
Source File: ProjectCollapsingSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  // Optimizer under test: subquery elimination, then a single pass of ProjectCollapsing.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries),
      Batch("ProjectCollapsing", Once, ProjectCollapsing))
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two deterministic, independent projects into one") {
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val query = innerProject.select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val expected =
      testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, expected)
  }

  test("collapse two deterministic, dependent projects into one") {
    // The outer projection refers to the alias introduced by the inner one.
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val query = innerProject.select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val expected =
      testRelation.select((('a + 1).as('a_plus_1) + 1).as('a_plus_2), 'b).analyze

    comparePlans(optimized, expected)
  }

  test("do not collapse nondeterministic projects") {
    // The outer projection references the nondeterministic 'rand twice; the expected plan is
    // the analyzed query itself.
    val innerProject = testRelation.select(Rand(10).as('rand))
    val query = innerProject.select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val expected = query.analyze

    comparePlans(optimized, expected)
  }
}
Example 70
Source File: FrequentItems.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{ArrayType, StructField, StructType}

private[sql] object FrequentItems extends Logging {

  /**
   * Collects frequent items for each of the given columns in a single pass over `df`.
   *
   * One `FreqItemCounter` per column is updated for every row and the per-partition counters are
   * merged; the keys left in each counter form the result for that column.
   *
   * @param df the input DataFrame
   * @param cols names of the columns to inspect; each must exist in the schema of `df`
   * @param support value in [1e-4, 1]; `(1 / support).toInt` is the size handed to each counter
   * @return a one-row DataFrame with one array column per input column, named
   *         `<column>_freqItems`
   * @throws IllegalArgumentException if `support` is outside [1e-4, 1]
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    // Reject support > 1 as well: it would make (1 / support).toInt equal to 0, i.e. counters
    // of size zero. The lower bound is inclusive, which the message now states correctly.
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.fill(numCols)(new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    // Resolve each requested column to its data type before running the aggregation.
    val colInfo = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }

    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
      // Fold one row into the counters: column i of the row goes to counter i.
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      // Merge the counters of two partitions column by column, into the left-hand side.
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toSeq)
    val resultRow = Row(justItems : _*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { case (name, dataType) =>
      StructField(name + "_freqItems", ArrayType(dataType, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation(schema, Seq(resultRow)))
  }
}
Example 71
Source File: SameResultSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two separately constructed relations with the same column names and types.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test unless `sameResult` between them equals `result`.
   * On failure the two analyzed plans are printed side by side.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val matchesExpectation = left.sameResult(right) == result

    if (!matchesExpectation) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Same columns in the same order on both sides.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set, or a different column order, must not compare as the same result.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    assertSameResult(filtered, testRelation2.where('a === 'b))
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    assertSameResult(sorted, testRelation2.orderBy('a.asc))
  }
}
Example 72
Source File: TestRelations.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object TestRelations {
  // Single nullable integer column `a`.
  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  // Five columns of mixed primitive types.
  val testRelation2 = LocalRelation(
    AttributeReference("a", StringType)(),
    AttributeReference("b", StringType)(),
    AttributeReference("c", DoubleType)(),
    AttributeReference("d", DecimalType(10, 2))(),
    AttributeReference("e", ShortType)())

  // One struct column whose fields include an exact duplicate name and two names that differ
  // only in letter case.
  val nestedRelation = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)
    )))())

  // One struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)
    )))())

  // One column holding an array of integers.
  val listRelation = LocalRelation(
    AttributeReference("list", ArrayType(IntegerType))())
}
Example 73
Source File: ConvertToLocalRelationSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  // Optimizer under test: only ConvertToLocalRelation, in one FixedPoint(100) batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input: two rows over the columns (a, b).
    val testRelation = LocalRelation(
      LocalRelation('a.int, 'b.int).output,
      Seq(InternalRow(1, 2), InternalRow(4, 5)))

    // Expected: the projection already applied to the data (a1 = a, b1 = b + 1).
    val correctAnswer = LocalRelation(
      LocalRelation('a1.int, 'b1.int).output,
      Seq(InternalRow(1, 3), InternalRow(4, 6)))

    val renamedA = UnresolvedAttribute("a").as("a1")
    val incrementedB = (UnresolvedAttribute("b") + 1).as("b1")
    val projectOnLocal = testRelation.select(renamedA, incrementedB)

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

}
Example 74
Source File: ColumnPruningSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

class ColumnPruningSuite extends PlanTest {

  // Optimizer under test: only ColumnPruning, in one FixedPoint(100) batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("Column pruning", FixedPoint(100), ColumnPruning))
  }

  // Expected: a Project keeping only 'b is inserted below the Generate.
  test("Column pruning for Generate when Generate.join = false") {
    val input = LocalRelation('a.int, 'b.array(StringType))

    val unprunedPlan = Generate(Explode('b), false, false, None, 's.string :: Nil, input)
    val query = unprunedPlan.analyze
    val optimized = Optimize.execute(query)

    val expectedPlan =
      Generate(Explode('b), false, false, None, 's.string :: Nil,
        Project('b.attr :: Nil, input))
    val correctAnswer = expectedPlan.analyze

    comparePlans(optimized, correctAnswer)
  }

  // Column pruning for Generate when Generate.join = true: the expected plan projects the
  // input down to 'a and 'c below the Generate.
  test("Column pruning for Generate when Generate.join = true") {
    val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))

    val unprunedPlan =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil, input))
    val query = unprunedPlan.analyze
    val optimized = Optimize.execute(query)

    val expectedPlan =
      Project(Seq('a, 's),
        Generate(Explode('c), true, false, None, 's.string :: Nil,
          Project(Seq('a, 'c), input)))
    val correctAnswer = expectedPlan.analyze

    comparePlans(optimized, correctAnswer)
  }

  // Turn Generate.join to false if possible: here the outer Project only uses the generated
  // column 's, and the expected plan has join = false.
  test("Turn Generate.join to false if possible") {
    val input = LocalRelation('b.array(StringType))

    val joinedPlan =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), true, false, None, 's.string :: Nil, input))
    val query = joinedPlan.analyze
    val optimized = Optimize.execute(query)

    val expectedPlan =
      Project(('s + 1).as("s+1") :: Nil,
        Generate(Explode('b), false, false, None, 's.string :: Nil, input))
    val correctAnswer = expectedPlan.analyze

    comparePlans(optimized, correctAnswer)
  }

  // todo: add more tests for column pruning
}
Example 75
Source File: OptimizeInSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import scala.collection.immutable.HashSet
import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._

class OptimizeInSuite extends PlanTest {

  // Optimizer under test: subquery elimination followed by one pass of constant folding,
  // boolean simplification and OptimizeIn.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("AnalysisNodes", Once, EliminateSubQueries),
      Batch("ConstantFolding", Once, ConstantFolding, BooleanSimplification, OptimizeIn))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  // OptimizedIn test: the In clause is not rewritten to InSet when it has fewer than 10 items.
  test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
    val shortList = Seq(Literal(1), Literal(2))
    val originalQuery =
      testRelation.where(In(UnresolvedAttribute("a"), shortList)).analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    comparePlans(optimized, originalQuery)
  }

  // OptimizedIn test: the In clause is rewritten to InSet when it has more than 10 items.
  test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
    val longList = (1 to 11).map(Literal(_))
    val originalQuery =
      testRelation.where(In(UnresolvedAttribute("a"), longList)).analyze

    val optimized = Optimize.execute(originalQuery.analyze)
    val correctAnswer =
      testRelation.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet)).analyze

    comparePlans(optimized, correctAnswer)
  }

  // OptimizedIn test: the In clause is left alone when its list contains an attribute.
  test("OptimizedIn test: In clause not optimized in case filter has attributes") {
    val originalQuery = {
      val mixedList = Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))
      testRelation.where(In(UnresolvedAttribute("a"), mixedList)).analyze
    }

    val optimized = Optimize.execute(originalQuery.analyze)

    val correctAnswer = {
      val mixedList = Seq(Literal(1), Literal(2), UnresolvedAttribute("b"))
      testRelation.where(In(UnresolvedAttribute("a"), mixedList)).analyze
    }

    comparePlans(optimized, correctAnswer)
  }
}
Example 76
Source File: ProjectCollapsingSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  // Optimizer under test: subquery elimination, then a single pass of ProjectCollapsing.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries),
      Batch("ProjectCollapsing", Once, ProjectCollapsing))
  }

  val testRelation = LocalRelation('a.int, 'b.int)

  // Neither projection depends on an alias defined by the other.
  test("collapse two deterministic, independent projects into one") {
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val query = innerProject.select('a_plus_1, ('b + 1).as('b_plus_1))

    val optimized = Optimize.execute(query.analyze)
    val expected =
      testRelation.select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1)).analyze

    comparePlans(optimized, expected)
  }

  test("collapse two deterministic, dependent projects into one") {
    // The outer projection refers to the alias introduced by the inner one.
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val query = innerProject.select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val optimized = Optimize.execute(query.analyze)

    val expected =
      testRelation.select((('a + 1).as('a_plus_1) + 1).as('a_plus_2), 'b).analyze

    comparePlans(optimized, expected)
  }

  test("do not collapse nondeterministic projects") {
    // The outer projection references the nondeterministic 'rand twice; the expected plan is
    // the analyzed query itself.
    val innerProject = testRelation.select(Rand(10).as('rand))
    val query = innerProject.select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val optimized = Optimize.execute(query.analyze)
    val expected = query.analyze

    comparePlans(optimized, expected)
  }

  test("collapse two nondeterministic, independent projects into one") {
    // The outer projection does not use 'rand, so only the outer Rand(20) is expected to remain.
    val innerProject = testRelation.select(Rand(10).as('rand))
    val query = innerProject.select(Rand(20).as('rand2))

    val optimized = Optimize.execute(query.analyze)

    val expected = testRelation.select(Rand(20).as('rand2)).analyze

    comparePlans(optimized, expected)
  }

  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    // The outer projection only uses 'a, so the inner 'rand is expected to disappear.
    val innerProject = testRelation.select(Rand(10).as('rand), 'a)
    val query = innerProject.select(('a + 1).as('a_plus_1))

    val optimized = Optimize.execute(query.analyze)

    val expected = testRelation.select(('a + 1).as('a_plus_1)).analyze

    comparePlans(optimized, expected)
  }
}
Example 77
Source File: AggregateOptimizeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {

  // Optimizer under test: one FixedPoint(100) batch with the two aggregate-related rules.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch(
        "Aggregate",
        FixedPoint(100),
        ReplaceDistinctWithAggregate,
        RemoveLiteralFromGroupExpressions))
  }

  // Replace Distinct with an Aggregate that groups by, and outputs, every input column.
  test("replace distinct with aggregate") {
    val input = LocalRelation('a.int, 'b.int)

    val distinctQuery = Distinct(input)
    val optimized = Optimize.execute(distinctQuery.analyze)

    val expected = Aggregate(input.output, input.output, input)

    comparePlans(optimized, expected)
  }

  // Remove literals from the grouping expressions, leaving only 'a.
  test("remove literals in grouping expression") {
    val input = LocalRelation('a.int, 'b.int)

    val groupedWithLiterals =
      input.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(groupedWithLiterals)

    val expected = input.groupBy('a)(sum('b))

    comparePlans(optimized, expected)
  }
}
Example 78
Source File: ResolveInlineTables.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import scala.util.control.NonFatal

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructField, StructType}


  /**
   * Converts an [[UnresolvedInlineTable]] into a [[LocalRelation]].
   *
   * A field is derived for every column: its type is the wider type common to all values of the
   * column and it is nullable if any value is nullable. Every cell is then evaluated, after being
   * cast to the column type when its own type differs. Analysis fails if a column has no common
   * type or if a cell cannot be evaluated.
   */
  private[analysis] def convert(table: UnresolvedInlineTable): LocalRelation = {
    // Walk the table column by column, pairing each column with its declared name.
    val fields = table.rows.transpose.zip(table.names).map { case (column, name) =>
      val widened = TypeCoercion.findWiderTypeWithoutStringPromotion(column.map(_.dataType))
      val tpe = widened.getOrElse {
        table.failAnalysis(s"incompatible types found in column $name for inline table")
      }
      StructField(name, tpe, nullable = column.exists(_.nullable))
    }
    val attributes = StructType(fields).toAttributes
    assert(fields.size == table.names.size)

    val newRows: Seq[InternalRow] = table.rows.map { row =>
      val values = row.zipWithIndex.map { case (e, ci) =>
        val targetType = fields(ci).dataType
        try {
          // Only wrap the expression in a cast when its type differs from the column type.
          val needsCast = !e.dataType.sameType(targetType)
          (if (needsCast) cast(e, targetType) else e).eval()
        } catch {
          case NonFatal(ex) =>
            table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}", ex)
        }
      }
      InternalRow.fromSeq(values)
    }

    LocalRelation(attributes, newRows)
  }
} 
Example 79
Source File: SameResultSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, ResolvedHint, Union}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two separately constructed relations with the same column names and types.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test unless `sameResult` between them equals `result`.
   * On failure the two analyzed plans are printed side by side.
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val matchesExpectation = left.sameResult(right) == result

    if (!matchesExpectation) {
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Same columns in the same order on both sides.
    assertSameResult(testRelation.select('a), testRelation2.select('a))
    assertSameResult(testRelation.select('b), testRelation2.select('b))
    assertSameResult(testRelation.select('a, 'b), testRelation2.select('a, 'b))
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('b, 'a))

    // A different column set, or a different column order, must not compare as the same result.
    assertSameResult(testRelation, testRelation2.select('a), result = false)
    assertSameResult(testRelation.select('b, 'a), testRelation2.select('a, 'b), result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    assertSameResult(filtered, testRelation2.where('a === 'b))
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    assertSameResult(sorted, testRelation2.orderBy('a.asc))
  }

  test("union") {
    // The two unions list the same children in opposite order.
    val forward = Union(Seq(testRelation, testRelation2))
    val backward = Union(Seq(testRelation2, testRelation))
    assertSameResult(forward, backward)
  }

  test("hint") {
    // A join whose right side is wrapped in ResolvedHint versus the same join without the hint.
    val hintedJoin = testRelation.join(ResolvedHint(testRelation))
    val plainJoin = testRelation.join(testRelation)
    assertSameResult(hintedJoin, plainJoin)
  }
}
Example 80
Source File: TestRelations.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object TestRelations {
  // Single nullable integer column `a`.
  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  // Five columns (a-e) of mixed primitive types.
  val testRelation2 = LocalRelation(
    AttributeReference("a", StringType)(),
    AttributeReference("b", StringType)(),
    AttributeReference("c", DoubleType)(),
    AttributeReference("d", DecimalType(10, 2))(),
    AttributeReference("e", ShortType)())

  // Four columns (e-h) of mixed primitive types.
  val testRelation3 = LocalRelation(
    AttributeReference("e", ShortType)(),
    AttributeReference("f", StringType)(),
    AttributeReference("g", DoubleType)(),
    AttributeReference("h", DecimalType(10, 2))())

  // Same column names as `testRelation3`; per the original note, `h` (a map here) is the only
  // column whose type is incompatible with its counterpart.
  val testRelation4 = LocalRelation(
    AttributeReference("e", StringType)(),
    AttributeReference("f", StringType)(),
    AttributeReference("g", StringType)(),
    AttributeReference("h", MapType(IntegerType, IntegerType))())

  // One struct column whose fields include an exact duplicate name and two names that differ
  // only in letter case.
  val nestedRelation = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)
    )))())

  // One struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)
    )))())

  // One column holding an array of integers.
  val listRelation = LocalRelation(
    AttributeReference("list", ArrayType(IntegerType))())

  // One column holding an integer-to-integer map.
  val mapRelation = LocalRelation(
    AttributeReference("map", MapType(IntegerType, IntegerType))())
}
Example 81
Source File: PullOutNondeterministicSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation


/**
 * Checks how analysis treats `Rand(10)` when it appears in a filter, a sort and an aggregate
 * over a two-column relation.
 */
class PullOutNondeterministicSuite extends AnalysisTest {

  // Two integer columns and the relation built from them; shared by every test.
  private lazy val colA = 'a.int
  private lazy val colB = 'b.int
  private lazy val relation = LocalRelation(colA, colB)
  // Aliased Rand(10) that the expected plans project, and the attribute referring to it.
  private lazy val aliasedRand = Rand(10).as('_nondeterministic)
  private lazy val aliasedRandRef = aliasedRand.toAttribute

  test("no-op on filter") {
    // The expected plan is structurally the same filter as the input.
    val input = relation.where(Rand(10) > Literal(1.0))
    val expected = relation.where(Rand(10) > Literal(1.0))
    checkAnalysis(input, expected)
  }

  test("sort") {
    val input = relation.sortBy(SortOrder(Rand(10), Ascending))
    // Expected: project the aliased Rand, sort on its attribute, then project it away again.
    val expected = relation
      .select(colA, colB, aliasedRand)
      .sortBy(SortOrder(aliasedRandRef, Ascending))
      .select(colA, colB)
    checkAnalysis(input, expected)
  }

  test("aggregate") {
    val input = relation.groupBy(Rand(10))(Rand(10).as("rnd"))
    // Expected: both the grouping key and the aggregate output refer to the projected attribute.
    val expected = relation
      .select(colA, colB, aliasedRand)
      .groupBy(aliasedRandRef)(aliasedRandRef.as("rnd"))
    checkAnalysis(input, expected)
  }
}
Example 82
Source File: ResolvedUuidExpressionsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}


/**
 * Verifies that analysis resolves `Uuid` expressions and gives each one a random seed, and
 * that the seeds differ both within a plan and between separately analyzed plans.
 */
class ResolvedUuidExpressionsSuite extends AnalysisTest {

  private lazy val a = 'a.int
  private lazy val r = LocalRelation(a)
  private lazy val uuid1 = Uuid().as('_uuid1)
  private lazy val uuid2 = Uuid().as('_uuid2)
  private lazy val uuid3 = Uuid().as('_uuid3)
  private lazy val uuid1Ref = uuid1.toAttribute

  private val analyzer = getAnalyzer(caseSensitive = true)

  /** Plan holding three Uuid expressions: one in the projection and two in the aggregate. */
  private def threeUuidPlan = r.select(a, uuid1).groupBy(uuid1Ref)(uuid2, uuid3)

  /** Gathers every `Uuid` found in the expressions of every node of `plan`. */
  private def getUuidExpressions(plan: LogicalPlan): Seq[Uuid] = {
    plan.flatMap { node =>
      node.expressions.flatMap { expr =>
        expr.collect { case u: Uuid => u }
      }
    }
  }

  test("analyzed plan sets random seed for Uuid expression") {
    val analyzed = analyzer.executeAndCheck(r.select(a, uuid1))
    getUuidExpressions(analyzed).foreach { u =>
      assert(u.resolved)
      assert(u.randomSeed.isDefined)
    }
  }

  test("Uuid expressions should have different random seeds") {
    val analyzed = analyzer.executeAndCheck(threeUuidPlan)
    val seeds = getUuidExpressions(analyzed).map(_.randomSeed.get)
    assert(seeds.distinct.length == 3)
  }

  test("Different analyzed plans should have different random seeds in Uuids") {
    // Analyze the very same unresolved plan twice.
    val plan = threeUuidPlan
    val firstRun = getUuidExpressions(analyzer.executeAndCheck(plan))
    val secondRun = getUuidExpressions(analyzer.executeAndCheck(plan))
    assert(firstRun.distinct.length == 3)
    assert(secondRun.distinct.length == 3)
    // No Uuid expression from one run may equal one from the other.
    assert(firstRun.intersect(secondRun).isEmpty)
  }
}
Example 83
Source File: ResolveInlineTablesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}


/**
 * Tests for `ResolveInlineTables`: input validation, the no-op case for unresolved input,
 * and conversion of an `UnresolvedInlineTable` into a [[LocalRelation]].
 */
class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  /** Builds a new rule instance from the current `conf` on every call. */
  private def resolveRule: ResolveInlineTables = ResolveInlineTables(conf)

  test("validate inputs are foldable") {
    val literalOnly = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))
    resolveRule.validateInputEvaluable(literalOnly)

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      val withRand = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))
      resolveRule.validateInputEvaluable(withRand)
    }

    // aggregate should not work
    intercept[AnalysisException] {
      val withAggregate = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))
      resolveRule.validateInputEvaluable(withAggregate)
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      val withUnresolved = UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))
      resolveRule.validateInputEvaluable(withUnresolved)
    }
  }

  test("validate input dimensions") {
    val wellFormed = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))
    resolveRule.validateInputDimension(wellFormed)

    // num alias != data dimension
    intercept[AnalysisException] {
      val tooManyNames = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))
      resolveRule.validateInputDimension(tooManyNames)
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      val raggedRows = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))
      resolveRule.validateInputDimension(raggedRows)
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val unresolvedTable =
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    // Applying the rule must hand back an equal plan.
    assert(resolveRule(unresolvedTable) == unresolvedTable)
  }

  test("convert") {
    // One int literal and one long literal in the same column.
    val mixedIntLong = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val relation = resolveRule.convert(mixedIntLong)

    assert(relation.output.map(_.dataType) == Seq(LongType))
    assert(relation.data.size == 2)
    assert(relation.data(0).getLong(0) == 1L)
    assert(relation.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val timestampText = "1991-12-06 00:00:00.0"
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit(timestampText), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = resolveRule.apply(withTimeZone)
    // Reference value: the same cast evaluated with the session-local time zone.
    val correct = Cast(lit(timestampText), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    // Only non-null literals: the column must not be nullable.
    val nonNullTable = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val nonNullRelation = resolveRule.convert(nonNullTable)
    assert(!nonNullRelation.schema.fields(0).nullable)

    // One null literal: the column must be nullable.
    val withNullTable =
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val withNullRelation = resolveRule.convert(withNullTable)
    assert(withNullRelation.schema.fields(0).nullable)
  }
}
Example 84
Source File: ResolveSubquerySuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery, OuterReference}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LocalRelation, Project}


/**
 * Checks the analysis error raised for an IN-subquery whose projection refers to a column
 * of the outer relation.
 */
class ResolveSubquerySuite extends AnalysisTest {

  val a = 'a.int
  val b = 'b.int
  val t1 = LocalRelation(a)
  val t2 = LocalRelation(b)

  test("SPARK-17251 Improve `OuterReference` to be `NamedExpression`") {
    // The subquery projects "a", which is a column of t1 (the outer relation), over t2.
    val subquery = ListQuery(Project(Seq(UnresolvedAttribute("a")), t2))
    val expr = Filter(In(a, Seq(subquery)), t1)
    val failure = intercept[AnalysisException] {
      SimpleAnalyzer.checkAnalysis(SimpleAnalyzer.ResolveSubquery(expr))
    }
    val expectedFragment =
      "Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses"
    assert(failure.getMessage.contains(expectedFragment))
  }
}
Example 85
Source File: ConvertToLocalRelationSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/** Verifies that a projection over a [[LocalRelation]] is folded into one local relation. */
class ConvertToLocalRelationSuite extends PlanTest {

  // Runs only ConvertToLocalRelation, up to 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    val inputAttributes = LocalRelation('a.int, 'b.int).output
    val testRelation = LocalRelation(
      inputAttributes,
      Seq(InternalRow(1, 2), InternalRow(4, 5)))

    // Column `a` is only renamed; column `b` is incremented by one and renamed.
    val projectOnLocal = testRelation.select(
      UnresolvedAttribute("a").as("a1"),
      (UnresolvedAttribute("b") + 1).as("b1"))

    val expectedAttributes = LocalRelation('a1.int, 'b1.int).output
    val correctAnswer = LocalRelation(
      expectedAttributes,
      Seq(InternalRow(1, 3), InternalRow(4, 6)))

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

}
Example 86
Source File: PullupCorrelatedPredicatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{In, ListQuery}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/** Checks that PullupCorrelatedPredicates leaves a resolved plan resolved. */
class PullupCorrelatedPredicatesSuite extends PlanTest {

  // Applies PullupCorrelatedPredicates exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("PullupCorrelatedPredicates", Once, PullupCorrelatedPredicates))
  }

  val testRelation = LocalRelation('a.int, 'b.double)
  val testRelation2 = LocalRelation('c.int, 'd.double)

  test("PullupCorrelatedPredicates should not produce unresolved plan") {
    // The subquery filter compares `b` (a column of testRelation) with its own column `d`.
    val correlatedSubquery = testRelation2.where('b < 'd).select('c)
    val inSubquery = In('a, Seq(ListQuery(correlatedSubquery)))
    val outerQuery = testRelation.where(inSubquery).select('a).analyze
    assert(outerQuery.resolved)

    val optimized = Optimize.execute(outerQuery)
    assert(optimized.resolved)
  }
}
Example 87
Source File: CheckCartesianProductsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.scalatest.Matchers._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf.CROSS_JOINS_ENABLED

/**
 * Exercises CheckCartesianProducts for every join type, with the cross-join setting switched
 * on and off and with and without a join condition.
 */
class CheckCartesianProductsSuite extends PlanTest {

  // Applies CheckCartesianProducts exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("Check Cartesian Products", Once, CheckCartesianProducts))
  }

  val testRelation1 = LocalRelation('a.int, 'b.int)
  val testRelation2 = LocalRelation('c.int, 'd.int)

  val joinTypesWithRequiredCondition = Seq(Inner, LeftOuter, RightOuter, FullOuter)
  val joinTypesWithoutRequiredCondition = Seq(LeftSemi, LeftAnti, ExistenceJoin('exists))

  test("CheckCartesianProducts doesn't throw an exception if cross joins are enabled)") {
    val everyJoinType = joinTypesWithRequiredCondition ++ joinTypesWithoutRequiredCondition
    withSQLConf(CROSS_JOINS_ENABLED.key -> "true") {
      noException should be thrownBy {
        everyJoinType.foreach { joinType =>
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  test("CheckCartesianProducts throws an exception for join types that require a join condition") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      joinTypesWithRequiredCondition.foreach { joinType =>
        val thrownException = the [AnalysisException] thrownBy {
          performCartesianProductCheck(joinType)
        }
        assert(thrownException.message.contains("Detected implicit cartesian product"))
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if a join condition is present") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      joinTypesWithRequiredCondition.foreach { joinType =>
        noException should be thrownBy {
          performCartesianProductCheck(joinType, Some('a === 'd))
        }
      }
    }
  }

  test("CheckCartesianProducts doesn't throw an exception if join types don't require conditions") {
    withSQLConf(CROSS_JOINS_ENABLED.key -> "false") {
      joinTypesWithoutRequiredCondition.foreach { joinType =>
        noException should be thrownBy {
          performCartesianProductCheck(joinType)
        }
      }
    }
  }

  /**
   * Joins the two test relations with the given join type and optional condition, runs the
   * rule over the analyzed plan and compares the result with that analyzed plan.
   */
  private def performCartesianProductCheck(
      joinType: JoinType,
      condition: Option[Expression] = None): Unit = {
    val analyzed = testRelation1.join(testRelation2, joinType, condition).analyze
    comparePlans(analyzed, Optimize.execute(analyzed))
  }
}
Example 88
Source File: EliminateDistinctSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/** Checks that EliminateDistinct rewrites distinct max/min into their plain forms. */
class EliminateDistinctSuite extends PlanTest {

  // Applies EliminateDistinct exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Operator Optimizations", Once, EliminateDistinct))
  }

  val testRelation = LocalRelation('a.int)

  test("Eliminate Distinct in Max") {
    val withDistinct = testRelation.select(maxDistinct('a).as('result)).analyze
    val withoutDistinct = testRelation.select(max('a).as('result)).analyze
    // The two plans must differ before optimization and match afterwards.
    assert(withDistinct != withoutDistinct)
    comparePlans(Optimize.execute(withDistinct), withoutDistinct)
  }

  test("Eliminate Distinct in Min") {
    val withDistinct = testRelation.select(minDistinct('a).as('result)).analyze
    val withoutDistinct = testRelation.select(min('a).as('result)).analyze
    // The two plans must differ before optimization and match afterwards.
    assert(withDistinct != withoutDistinct)
    comparePlans(Optimize.execute(withDistinct), withoutDistinct)
  }
}
Example 89
Source File: CollapseWindowSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for CollapseWindow: adjacent windows sharing a partition and order spec are merged,
 * while windows with differing specs or dependent columns are left alone.
 */
class CollapseWindowSuite extends PlanTest {
  // Runs only CollapseWindow, up to 10 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("CollapseWindow", FixedPoint(10), CollapseWindow))
  }

  val testRelation = LocalRelation('a.double, 'b.double, 'c.string)
  // The relation's three output attributes, in declaration order.
  val Seq(a, b, c) = testRelation.output
  // Two different partition specs and two different order specs, all derived from `c`.
  val partitionSpec1 = Seq(c)
  val partitionSpec2 = Seq(c + 1)
  val orderSpec1 = Seq(c.asc)
  val orderSpec2 = Seq(c.desc)

  test("collapse two adjacent windows with the same partition/order") {
    val stacked = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec1)
      .window(Seq(sum(b).as('sum_b)), partitionSpec1, orderSpec1)
      .window(Seq(avg(b).as('avg_b)), partitionSpec1, orderSpec1)

    val analyzed = stacked.analyze
    val optimized = Optimize.execute(analyzed)
    // Collapsing must keep the output attributes unchanged.
    assert(analyzed.output === optimized.output)

    // Expected: one window computing all four expressions.
    val allWindowExpressions = Seq(
      min(a).as('min_a),
      max(a).as('max_a),
      sum(b).as('sum_b),
      avg(b).as('avg_b))
    val correctAnswer = testRelation.window(allWindowExpressions, partitionSpec1, orderSpec1)

    comparePlans(optimized, correctAnswer)
  }

  test("Don't collapse adjacent windows with different partitions or orders") {
    // Same partition spec, different order specs.
    val differentOrders = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec1, orderSpec2)

    val optimizedOrders = Optimize.execute(differentOrders.analyze)
    val expectedOrders = differentOrders.analyze

    comparePlans(optimizedOrders, expectedOrders)

    // Same order spec, different partition specs.
    val differentPartitions = testRelation
      .window(Seq(min(a).as('min_a)), partitionSpec1, orderSpec1)
      .window(Seq(max(a).as('max_a)), partitionSpec2, orderSpec1)

    val optimizedPartitions = Optimize.execute(differentPartitions.analyze)
    val expectedPartitions = differentPartitions.analyze

    comparePlans(optimizedPartitions, expectedPartitions)
  }

  test("Don't collapse adjacent windows with dependent columns") {
    // The second window reads `sum_a`, which the first window produces.
    val dependent = testRelation
      .window(Seq(sum(a).as('sum_a)), partitionSpec1, orderSpec1)
      .window(Seq(max('sum_a).as('max_sum_a)), partitionSpec1, orderSpec1)
      .analyze

    val expected = dependent.analyze
    val optimized = Optimize.execute(dependent.analyze)
    comparePlans(optimized, expected)
  }
}
Example 90
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

/**
 * Tests for RewriteDistinctAggregates: plans with a single distinct group must come back
 * unchanged, plans with several distinct groups must be rewritten.
 */
class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  /** Fails unless `rewrite` is an Aggregate over an Aggregate over an Expand. */
  private def checkRewrite(rewrite: LogicalPlan): Unit = {
    val hasRewrittenShape = rewrite match {
      case Aggregate(_, _, Aggregate(_, _, _: Expand)) => true
      case _ => false
    }
    if (!hasRewrittenShape) {
      fail(s"Plan is not rewritten:\n$rewrite")
    }
  }

  test("single distinct group") {
    val input = testRelation.groupBy('a)(countDistinct('e)).analyze
    // The rule's output is compared with its own input.
    comparePlans(input, RewriteDistinctAggregates(input))
  }

  test("single distinct group with partial aggregates") {
    val aggregates = Seq(
      countDistinct('e, 'c).as('agg1),
      max('b).as('agg2))
    val input = testRelation.groupBy('a, 'd)(aggregates: _*).analyze
    // The rule's output is compared with its own input.
    comparePlans(input, RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d)).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input =
      testRelation.groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e)).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val aggregates = Seq(
      countDistinct('b, 'c),
      countDistinct('d),
      CollectSet('b).toAggregateExpression())
    val input = testRelation.groupBy('a)(aggregates: _*).analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
}
Example 91
Source File: EliminateMapObjectsSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types._

/**
 * Checks that deserializing primitive `Array[Int]` / `Array[Double]` columns is optimized
 * into a direct `toIntArray` / `toDoubleArray` invocation on the input column.
 */
class EliminateMapObjectsSuite extends PlanTest {
  // NullPropagation, SimplifyCasts and EliminateMapObjects, up to 50 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("EliminateMapObjects", FixedPoint(50),
        NullPropagation,
        SimplifyCasts,
        EliminateMapObjects))
  }

  // Encoders picked up implicitly by the `deserialize[...]` calls below.
  implicit private def intArrayEncoder: ExpressionEncoder[Array[Int]] =
    ExpressionEncoder[Array[Int]]()
  implicit private def doubleArrayEncoder: ExpressionEncoder[Array[Double]] =
    ExpressionEncoder[Array[Double]]()

  test("SPARK-20254: Remove unnecessary data conversion for primitive array") {
    // Int array column: expect a single Invoke of `toIntArray` on the column.
    val intObjType = ObjectType(classOf[Array[Int]])
    val intInput = LocalRelation('a.array(ArrayType(IntegerType, false)))
    val intOptimized = Optimize.execute(intInput.deserialize[Array[Int]].analyze)
    val intInvoke = Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false)
    val intExpected = DeserializeToObject(
      intInvoke,
      AttributeReference("obj", intObjType, nullable = true)(),
      intInput)
    comparePlans(intOptimized, intExpected)

    // Double array column: expect a single Invoke of `toDoubleArray` on the column.
    val doubleObjType = ObjectType(classOf[Array[Double]])
    val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false)))
    val doubleOptimized = Optimize.execute(doubleInput.deserialize[Array[Double]].analyze)
    val doubleInvoke =
      Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false)
    val doubleExpected = DeserializeToObject(
      doubleInvoke,
      AttributeReference("obj", doubleObjType, nullable = true)(),
      doubleInput)
    comparePlans(doubleOptimized, doubleExpected)
  }
}
Example 92
Source File: RewriteSubquerySuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.ListQuery
import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


/** Checks that columns are pruned after an IN-subquery predicate is rewritten into a join. */
class RewriteSubquerySuite extends PlanTest {

  // Column pruning first, then one pass of subquery rewriting followed by plan clean-up rules.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Column Pruning", FixedPoint(100), ColumnPruning),
      Batch("Rewrite Subquery", FixedPoint(1),
        RewritePredicateSubquery,
        ColumnPruning,
        CollapseProject,
        RemoveRedundantProject))
  }

  test("Column pruning after rewriting predicate subquery") {
    val relation = LocalRelation('a.int, 'b.int)
    val relInSubquery = LocalRelation('x.int, 'y.int, 'z.int)

    val inSubquery = 'a.in(ListQuery(relInSubquery.select('x)))
    val input = relation.where(inSubquery).select('a)

    val optimized = Optimize.execute(input.analyze)

    // Expected: a left-semi join on a = x in which each side keeps only the column it needs.
    val prunedSubquery = relInSubquery.select('x)
    val correctAnswer = relation
      .select('a)
      .join(prunedSubquery, LeftSemi, Some('a === 'x))
      .analyze

    comparePlans(optimized, correctAnswer)
  }

}
Example 93
Source File: ComputeCurrentTimeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

/**
 * Checks that ComputeCurrentTime replaces every current-timestamp / current-date expression
 * in a plan with the same literal, and that the literal lies within the time span measured
 * around the optimizer run.
 */
class ComputeCurrentTimeSuite extends PlanTest {
  // Applies ComputeCurrentTime exactly once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime))
  }

  test("analyzer should replace current_timestamp with literals") {
    val projections = Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")())
    val in = Project(projections, LocalRelation())

    // Bounds in milliseconds * 1000, taken immediately before and after optimization.
    val lowerBound = System.currentTimeMillis() * 1000
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val upperBound = (System.currentTimeMillis() + 1) * 1000

    // The transform is used only to visit every literal; each one is returned unchanged.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Long]
    plan.transformAllExpressions {
      case literal: Literal =>
        lits += literal.value.asInstanceOf[Long]
        literal
    }
    assert(lits.size == 2)
    lits.foreach { value =>
      assert(value >= lowerBound && value <= upperBound)
    }
    assert(lits(0) == lits(1))
  }

  test("analyzer should replace current_date with literals") {
    val projections = Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")())
    val in = Project(projections, LocalRelation())

    // Day numbers taken immediately before and after optimization.
    val lowerBound = DateTimeUtils.millisToDays(System.currentTimeMillis())
    val plan = Optimize.execute(in.analyze).asInstanceOf[Project]
    val upperBound = DateTimeUtils.millisToDays(System.currentTimeMillis())

    // The transform is used only to visit every literal; each one is returned unchanged.
    val lits = scala.collection.mutable.ArrayBuffer.empty[Int]
    plan.transformAllExpressions {
      case literal: Literal =>
        lits += literal.value.asInstanceOf[Int]
        literal
    }
    assert(lits.size == 2)
    lits.foreach { value =>
      assert(value >= lowerBound && value <= upperBound)
    }
    assert(lits(0) == lits(1))
  }
}
Example 94
Source File: ReorderAssociativeOperatorSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

/**
 * Tests for the ReorderAssociativeOperator rule, run on its own in a single `Once` batch
 * over a three-column integer relation.
 */
class ReorderAssociativeOperatorSuite extends PlanTest {

  // Rule executor that applies only ReorderAssociativeOperator, once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("ReorderAssociativeOperator", Once,
        ReorderAssociativeOperator) :: Nil
  }

  // Shared input relation with integer columns a, b and c.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("Reorder associative operators") {
    // Chains of + and * mixing integer literals with column references, plus one chain
    // that starts from Rand(0).
    val originalQuery =
      testRelation
        .select(
          (Literal(3) + ((Literal(1) + 'a) + 2)) + 4,
          'b * 1 * 2 * 3 * 4,
          ('b + 1) * 2 * 3 * 4,
          'a + 1 + 'b + 2 + 'c + 3,
          'a + 1 + 'b * 2 + 'c + 3,
          Rand(0) * 1 * 2 * 3 * 4)

    val optimized = Optimize.execute(originalQuery.analyze)

    // Expected plan: in each chain the literals are combined into a single constant, while
    // the alias keeps the textual form of the original expression. The Rand(0) chain is
    // expected exactly as written in the input.
    val correctAnswer =
      testRelation
        .select(
          ('a + 10).as("((3 + ((1 + a) + 2)) + 4)"),
          ('b * 24).as("((((b * 1) * 2) * 3) * 4)"),
          (('b + 1) * 24).as("((((b + 1) * 2) * 3) * 4)"),
          ('a + 'b + 'c + 6).as("(((((a + 1) + b) + 2) + c) + 3)"),
          ('a + 'b * 2 + 'c + 4).as("((((a + 1) + (b * 2)) + c) + 3)"),
          Rand(0) * 1 * 2 * 3 * 4)
        .analyze

    comparePlans(optimized, correctAnswer)
  }

  test("nested expression with aggregate operator") {
    // Self-join of the relation, grouped by `t1.a + 1` and `t2.a + 1`; the aggregate output
    // adds those two grouping expressions together.
    val originalQuery =
      testRelation.as("t1")
        .join(testRelation.as("t2"), Inner, Some("t1.a".attr === "t2.a".attr))
        .groupBy("t1.a".attr + 1, "t2.a".attr + 1)(
          (("t1.a".attr + 1) + ("t2.a".attr + 1)).as("col"))

    val optimized = Optimize.execute(originalQuery.analyze)

    // The expected plan is the analyzed input itself, i.e. the rule must leave it unchanged.
    val correctAnswer = originalQuery.analyze

    comparePlans(optimized, correctAnswer)
  }
}
Example 95
Source File: AggregateOptimizeSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}

/**
 * Tests how the optimizer simplifies grouping expressions: literal and repeated grouping
 * keys are expected to be dropped. The suite runs with case sensitivity and group-by-ordinal
 * both switched off.
 */
class AggregateOptimizeSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  // Three rules in one batch, up to 100 iterations.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Seq(
      Batch("Aggregate", FixedPoint(100),
        FoldablePropagation,
        RemoveLiteralFromGroupExpressions,
        RemoveRepetitionFromGroupExpressions))
  }

  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)

  test("remove literals in grouping expression") {
    val input = testRelation.groupBy('a, Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(input))
    // Only the column key is expected to remain.
    val expected = testRelation.groupBy('a)(sum('b)).analyze

    comparePlans(optimized, expected)
  }

  test("do not remove all grouping expressions if they are all literals") {
    val input = testRelation.groupBy(Literal("1"), Literal(1) + Literal(2))(sum('b))
    val optimized = Optimize.execute(analyzer.execute(input))
    // A single literal key (0) is expected in place of the original literal keys.
    val expected = analyzer.execute(testRelation.groupBy(Literal(0))(sum('b)))

    comparePlans(optimized, expected)
  }

  test("Remove aliased literals") {
    // `y` is an alias for the literal 1, introduced by the projection underneath.
    val projected = testRelation.select('a, 'b, Literal(1).as('y))
    val input = projected.groupBy('a, 'y)(sum('b))
    val optimized = Optimize.execute(analyzer.execute(input))
    val expected = testRelation.select('a, 'b, Literal(1).as('y)).groupBy('a)(sum('b)).analyze

    comparePlans(optimized, expected)
  }

  test("remove repetition in grouping expression") {
    // The last two keys repeat the first two with swapped operands and upper-case names.
    val input = testRelation.groupBy('a + 1, 'b + 2, Literal(1) + 'A, Literal(2) + 'B)(sum('c))
    val optimized = Optimize.execute(analyzer.execute(input))
    val expected = testRelation.groupBy('a + 1, 'b + 2)(sum('c)).analyze

    comparePlans(optimized, expected)
  }
}
Example 96
Source File: FrequentItems.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

object FrequentItems extends Logging {

  /**
   * Aggregates the selected columns of `df` in a single pass, feeding every value into one
   * `FreqItemCounter` per column, and returns a one-row DataFrame whose columns are named
   * `<col>_freqItems` and hold the keys retained by the corresponding counter.
   *
   * @param df      the DataFrame to scan
   * @param cols    names of the columns to analyse; each must exist in `df.schema`
   * @param support value in [1e-4, 1]; each counter is created with size `(1 / support).toInt`
   * @return a DataFrame with a single row containing one array column per requested column
   */
  def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // Size handed to every per-column counter.
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(_ => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    // Pair each requested column with its data type (fieldIndex fails for unknown names).
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val field = originalSchema.fields(originalSchema.fieldIndex(name))
      (name, field.dataType)
    }.toArray

    // Folds one row into the per-column counters, column by column.
    val addRow = (counters: Seq[FreqItemCounter], row: Row) => {
      var col = 0
      while (col < numCols) {
        counters(col).add(row.get(col), 1L)
        col += 1
      }
      counters
    }

    // Merges the counters of the right-hand partial result into the left-hand one.
    val mergeCounters = (acc: Seq[FreqItemCounter], other: Seq[FreqItemCounter]) => {
      var col = 0
      while (col < numCols) {
        acc(col).merge(other(col))
        col += 1
      }
      acc
    }

    val selected = df.select(cols.map(Column(_)) : _*)
    val freqItems = selected.rdd.treeAggregate(countMaps)(seqOp = addRow, combOp = mergeCounters)

    val justItems = freqItems.map(counter => counter.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // Suffix the column names with "_freqItems" for easy debugging.
    val outputCols = colInfo.map { case (name, dataType) =>
      StructField(name + "_freqItems", ArrayType(dataType, false))
    }
    val schema = StructType(outputCols).toAttributes
    val relation = LocalRelation.fromExternalRows(schema, Seq(resultRow))
    Dataset.ofRows(df.sparkSession, relation)
  }
} 
Example 97
Source File: SparkPlannerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    // Sentinel leaf node: the strategy below fails the test if it is ever asked to plan it.
    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    // Number of logical nodes the test strategy has produced candidates for.
    var handledNodes = 0

    // For every node it recognises, offers the real physical plan first and a second
    // candidate that would require planning NeverPlanned.
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          handledNodes += 1
          Seq(planLater(child), planLater(NeverPlanned))
        case Union(children) =>
          handledNodes += 1
          Seq(UnionExec(children.map(planLater)), planLater(NeverPlanned))
        case LocalRelation(output, data, _) =>
          handledNodes += 1
          Seq(LocalTableScanExec(output, data), planLater(NeverPlanned))
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = Seq(TestStrategy)

      val firstHalf = Seq("a", "b", "c").toDS()
      val secondHalf = Seq("d", "e", "f").toDS()
      val ds = firstHalf.union(secondHalf)

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      // One ReturnAnswer, one Union and two LocalRelations.
      assert(handledNodes === 4)
    } finally {
      // Always restore the session so other tests are not affected.
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 98
Source File: StarryLocalRelation.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.logical

import org.apache.spark.sql.catalyst.{InternalRow, analysis}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}


  // Builds a LocalRelation over the same rows and streaming flag, with every output attribute
  // replaced by the result of its own newInstance().
  // NOTE(review): the result is cast to this.type without a check — confirm the enclosing
  // class (not visible here) makes that cast valid.
  override final def newInstance(): this.type = {
    val freshOutput = output.map(_.newInstance())
    val copy = LocalRelation(freshOutput, data, isStreaming)
    copy.asInstanceOf[this.type]
  }

  // Arguments used when rendering this node as a string: the output attributes, preceded by
  // an "<empty>" marker when the relation has no rows.
  override protected def stringArgs: Iterator[Any] =
    if (data.nonEmpty) Iterator(output) else Iterator("<empty>", output)

  // Size estimate: the sum of the default sizes of all output column types, multiplied by
  // the number of rows held in `data`.
  override def computeStats(): Statistics = {
    val bytesPerRow = output.map(attr => BigInt(attr.dataType.defaultSize)).sum
    Statistics(sizeInBytes = bytesPerRow * data.length)
  }

  /**
   * Renders this relation as a SQL inline table of the form
   * `VALUES (..), (..) AS name(col1, col2, ...)`.
   *
   * @param inlineTableName alias placed after `AS`
   * @return the SQL text; fails the `require` check when the relation has no rows
   */
  def toSQL(inlineTableName: String): String = {
    require(data.nonEmpty)
    val types = output.map(_.dataType)
    // One parenthesised tuple of SQL literals per row.
    val tuples = data.map { row =>
      row.toSeq(types).zip(types)
        .map { case (value, dataType) => Literal(value, dataType).sql }
        .mkString("(", ", ", ")")
    }
    val columnList = output.map(_.name).mkString("(", ", ", ")")
    s"VALUES ${tuples.mkString(", ")} AS $inlineTableName$columnList"
  }

} 
Example 99
Source File: ReplaceGroup.scala    From starry   with Apache License 2.0 5 votes vote down vote up
package com.github.passionke.replace

import com.github.passionke.starry.SparkPlanExecutor
import com.github.passionke.baseline.Dumy
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, SubqueryAlias}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.Spark
import org.scalatest.FunSuite


class ReplaceGroup extends FunSuite {

  // Runs a group-by query over a temp view, then swaps the view's relation for a
  // LocalRelation holding three rows and checks the count computed over the replacement.
  test("group by") {
    val sparkSession = Spark.sparkSession
    sparkSession.sparkContext.setLogLevel("WARN")
    import sparkSession.implicits._

    // Register two rows as temp view "a".
    val initialRows = Seq(Dumy("a", 10, "abc"), Dumy("a", 20, "ass"))
    initialRows.toDF().createOrReplaceTempView("a")

    val df = sparkSession.sql(
      """
        |select name, count(1) as cnt
        |from a
        |group by name
      """.stripMargin)

    df.show()
    // NOTE(review): this value is never read afterwards; kept as in the original.
    val sparkPlan = df.queryExecution.sparkPlan
    val analyzedPlan = df.queryExecution.analyzed

    // Replacement data: three rows that all share the name "a".
    val replacementRows = Seq(Dumy("a", 1, "abc"), Dumy("a", 1, "ass"), Dumy("a", 2, "sf"))
    val data = replacementRows.toDF().queryExecution.executedPlan.execute().collect()

    // Swap the child of the alias "a" for a LocalRelation over the replacement rows.
    val rewrittenPlan = analyzedPlan.transform {
      case SubqueryAlias(a, localRelation) if a.equals("a") =>
        val replacement = LocalRelation(localRelation.output, data)
        SubqueryAlias(a, replacement)
    }

    val freshSession = sparkSession.newSession()
    val qe = new QueryExecution(freshSession, rewrittenPlan)
    val start = System.currentTimeMillis()
    val list = SparkPlanExecutor.exec(qe.sparkPlan, freshSession)
    assert(list.head.getLong(1).equals(3L))
    val end = System.currentTimeMillis()
    // Elapsed milliseconds; this is the value of the test body and is not otherwise used.
    val elapsed = end - start
    elapsed
  }

} 
Example 100
Source File: SameResultSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.util._


class SameResultSuite extends SparkFunSuite {
  // Two separately constructed relations declaring the same three int columns.
  val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
  val testRelation2 = LocalRelation('a.int, 'b.int, 'c.int)

  /**
   * Analyzes both plans and fails the test unless `sameResult` between them equals `result`.
   *
   * @param a      first plan
   * @param b      second plan
   * @param result the expected outcome of `a.sameResult(b)` after analysis (default true)
   */
  def assertSameResult(a: LogicalPlan, b: LogicalPlan, result: Boolean = true): Unit = {
    val left = a.analyze
    val right = b.analyze
    val asExpected = left.sameResult(right) == result

    if (!asExpected) {
      // Show both analyzed plans next to each other in the failure message.
      val comparison = sideBySide(left.toString, right.toString).mkString("\n")
      fail(s"Plans should return sameResult = $result\n$comparison")
    }
  }

  test("relations") {
    assertSameResult(testRelation, testRelation2)
  }

  test("projections") {
    // Identical projections over the two relations.
    assertSameResult(
      testRelation.select('a),
      testRelation2.select('a))
    assertSameResult(
      testRelation.select('b),
      testRelation2.select('b))
    assertSameResult(
      testRelation.select('a, 'b),
      testRelation2.select('a, 'b))
    assertSameResult(
      testRelation.select('b, 'a),
      testRelation2.select('b, 'a))

    // A different column set, or a different column order, must not compare as the same.
    assertSameResult(
      testRelation,
      testRelation2.select('a),
      result = false)
    assertSameResult(
      testRelation.select('b, 'a),
      testRelation2.select('a, 'b),
      result = false)
  }

  test("filters") {
    val filtered = testRelation.where('a === 'b)
    val filtered2 = testRelation2.where('a === 'b)
    assertSameResult(filtered, filtered2)
  }

  test("sorts") {
    val sorted = testRelation.orderBy('a.asc)
    val sorted2 = testRelation2.orderBy('a.asc)
    assertSameResult(sorted, sorted2)
  }
} 
Example 101
Source File: TestRelations.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._

// Pre-built LocalRelations (schema only, no rows) with a range of column shapes.
object TestRelations {
  // Single nullable integer column "a".
  val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())

  // Five columns covering string, double, decimal(10, 2) and short types.
  val testRelation2 = LocalRelation(
    AttributeReference("a", StringType)(),
    AttributeReference("b", StringType)(),
    AttributeReference("c", DoubleType)(),
    AttributeReference("d", DecimalType(10, 2))(),
    AttributeReference("e", ShortType)())

  // One struct column whose fields include an exact duplicate name and two names that
  // differ only by case.
  val nestedRelation = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("duplicateField", StringType),
      StructField("duplicateField", StringType),
      StructField("differentCase", StringType),
      StructField("differentcase", StringType)
    )))())

  // One struct column with three distinctly named string fields.
  val nestedRelation2 = LocalRelation(
    AttributeReference("top", StructType(Seq(
      StructField("aField", StringType),
      StructField("bField", StringType),
      StructField("cField", StringType)
    )))())

  // One array-of-int column.
  val listRelation = LocalRelation(
    AttributeReference("list", ArrayType(IntegerType))())

  // One int-to-int map column.
  val mapRelation = LocalRelation(
    AttributeReference("map", MapType(IntegerType, IntegerType))())
} 
Example 102
Source File: ConvertToLocalRelationSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ConvertToLocalRelationSuite extends PlanTest {

  // Executor that applies only ConvertToLocalRelation, with a FixedPoint(100) strategy.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation))
  }

  test("Project on LocalRelation should be turned into a single LocalRelation") {
    // Input: columns (a, b) with rows (1, 2) and (4, 5).
    val inputRows = Seq(InternalRow(1, 2), InternalRow(4, 5))
    val testRelation = LocalRelation(LocalRelation('a.int, 'b.int).output, inputRows)

    // Expected: columns (a1, b1) where a1 = a and b1 = b + 1.
    val expectedRows = Seq(InternalRow(1, 3), InternalRow(4, 6))
    val correctAnswer = LocalRelation(LocalRelation('a1.int, 'b1.int).output, expectedRows)

    val renamedA = UnresolvedAttribute("a").as("a1")
    val incrementedB = (UnresolvedAttribute("b") + 1).as("b1")
    val projectOnLocal = testRelation.select(renamedA, incrementedB)

    val optimized = Optimize.execute(projectOnLocal.analyze)

    comparePlans(optimized, correctAnswer)
  }

} 
Example 103
Source File: ColumnPruningSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types.StringType

class ColumnPruningSuite extends PlanTest {

  // Executor that applies only the ColumnPruning rule, with a FixedPoint(100) strategy.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(Batch("Column pruning", FixedPoint(100), ColumnPruning))
  }

  test("Column pruning for Generate when Generate.join = false") {
    val relation = LocalRelation('a.int, 'b.array(StringType))

    val plan = relation.generate(Explode('b), join = false).analyze
    val actual = Optimize.execute(plan)

    // Expected: only 'b is projected below the Generate.
    val expected = relation
      .select('b)
      .generate(Explode('b), join = false)
      .analyze

    comparePlans(actual, expected)
  }

  test("Column pruning for Generate when Generate.join = true") {
    val relation = LocalRelation('a.int, 'b.int, 'c.array(StringType))

    val plan = relation
      .generate(Explode('c), join = true, outputNames = Seq("explode"))
      .select('a, 'explode)
      .analyze
    val actual = Optimize.execute(plan)

    // Expected: 'b is dropped below the Generate; 'a and 'c are kept.
    val expected = relation
      .select('a, 'c)
      .generate(Explode('c), join = true, outputNames = Seq("explode"))
      .select('a, 'explode)
      .analyze

    comparePlans(actual, expected)
  }

  test("Turn Generate.join to false if possible") {
    val relation = LocalRelation('b.array(StringType))

    val plan = relation
      .generate(Explode('b), join = true, outputNames = Seq("explode"))
      .select(('explode + 1).as("result"))
      .analyze
    val actual = Optimize.execute(plan)

    // Expected: the same plan with join switched to false.
    val expected = relation
      .generate(Explode('b), join = false, outputNames = Seq("explode"))
      .select(('explode + 1).as("result"))
      .analyze

    comparePlans(actual, expected)
  }

  test("Column pruning for Project on Sort") {
    val relation = LocalRelation('a.int, 'b.string, 'c.double)

    val plan = relation.orderBy('b.asc).select('a).analyze
    val actual = Optimize.execute(plan)

    // Expected: 'c is dropped below the Sort; 'b stays because the sort order uses it.
    val expected = relation
      .select('a, 'b)
      .orderBy('b.asc)
      .select('a)
      .analyze

    comparePlans(actual, expected)
  }

  // todo: add more tests for column pruning
} 
Example 104
Source File: ProjectCollapsingSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Rand
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor


class ProjectCollapsingSuite extends PlanTest {
  // Two batches: EliminateSubQueries with FixedPoint(10), then ProjectCollapsing applied Once.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch("Subqueries", FixedPoint(10), EliminateSubQueries),
      Batch("ProjectCollapsing", Once, ProjectCollapsing))
  }

  // Shared input: a relation declaring two int columns, a and b.
  val testRelation = LocalRelation('a.int, 'b.int)

  test("collapse two deterministic, independent projects into one") {
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val plan = innerProject.select('a_plus_1, ('b + 1).as('b_plus_1))

    val actual = Optimize.execute(plan.analyze)
    val expected = testRelation
      .select(('a + 1).as('a_plus_1), ('b + 1).as('b_plus_1))
      .analyze

    comparePlans(actual, expected)
  }

  test("collapse two deterministic, dependent projects into one") {
    val innerProject = testRelation.select(('a + 1).as('a_plus_1), 'b)
    val plan = innerProject.select(('a_plus_1 + 1).as('a_plus_2), 'b)

    val actual = Optimize.execute(plan.analyze)

    // The inner alias ends up nested inside the outer expression.
    val nestedAlias = (('a + 1).as('a_plus_1) + 1).as('a_plus_2)
    val expected = testRelation.select(nestedAlias, 'b).analyze

    comparePlans(actual, expected)
  }

  test("do not collapse nondeterministic projects") {
    // The outer project references the nondeterministic 'rand twice.
    val plan = testRelation
      .select(Rand(10).as('rand))
      .select(('rand + 1).as('rand1), ('rand + 2).as('rand2))

    val actual = Optimize.execute(plan.analyze)
    // Expected: the analyzed plan, unchanged.
    val expected = plan.analyze

    comparePlans(actual, expected)
  }

  test("collapse two nondeterministic, independent projects into one") {
    val plan = testRelation
      .select(Rand(10).as('rand))
      .select(Rand(20).as('rand2))

    val actual = Optimize.execute(plan.analyze)

    // Only the outer projection is expected to remain.
    val expected = testRelation.select(Rand(20).as('rand2)).analyze

    comparePlans(actual, expected)
  }

  test("collapse one nondeterministic, one deterministic, independent projects into one") {
    val plan = testRelation
      .select(Rand(10).as('rand), 'a)
      .select(('a + 1).as('a_plus_1))

    val actual = Optimize.execute(plan.analyze)

    // The outer project never reads 'rand, so only the 'a + 1 projection is expected.
    val expected = testRelation.select(('a + 1).as('a_plus_1)).analyze

    comparePlans(actual, expected)
  }
} 
Example 105
Source File: AggregateOptimizeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

class AggregateOptimizeSuite extends PlanTest {

  // One "Aggregate" batch with a FixedPoint(100) strategy, applying the two listed rules.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = List(
      Batch(
        "Aggregate",
        FixedPoint(100),
        ReplaceDistinctWithAggregate,
        RemoveLiteralFromGroupExpressions))
  }

  test("replace distinct with aggregate") {
    val relation = LocalRelation('a.int, 'b.int)

    val plan = Distinct(relation)
    val actual = Optimize.execute(plan.analyze)

    // Expected: an Aggregate that groups by, and outputs, every column of the input.
    val allColumns = relation.output
    val expected = Aggregate(allColumns, allColumns, relation)

    comparePlans(actual, expected)
  }

  test("remove literals in grouping expression") {
    val relation = LocalRelation('a.int, 'b.int)

    // Note: neither the query nor the expected plan is analyzed in this test.
    val plan = relation.groupBy('a, Literal(1), Literal(1) + Literal(2))(sum('b))
    val actual = Optimize.execute(plan)

    val expected = relation.groupBy('a)(sum('b))

    comparePlans(actual, expected)
  }
} 
Example 106
Source File: FrequentItems.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.stat

import scala.collection.mutable.{Map => MutableMap}

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, Column, DataFrame}

private[sql] object FrequentItems extends Logging {

  /**
   * Aggregates the selected columns of `df` in a single pass, feeding every value into one
   * `FreqItemCounter` per column, and returns a one-row DataFrame whose columns are named
   * `<col>_freqItems` and hold the keys retained by the corresponding counter.
   *
   * @param df      the DataFrame to scan
   * @param cols    names of the columns to analyse; each must exist in `df.schema`
   * @param support value in [1e-4, 1]; each counter is created with size `(1 / support).toInt`
   * @return a DataFrame with a single row containing one array column per requested column
   * @throws IllegalArgumentException if `support` is outside [1e-4, 1]
   */
  private[sql] def singlePassFreqItems(
      df: DataFrame,
      cols: Seq[String],
      support: Double): DataFrame = {
    // The upper bound matters too: for support > 1, (1 / support).toInt is 0 and every
    // FreqItemCounter would be created with a size of 0.
    require(support >= 1e-4 && support <= 1.0, s"Support must be in [1e-4, 1], but got $support.")
    val numCols = cols.length
    // number of max items to keep counts for
    val sizeOfMap = (1 / support).toInt
    val countMaps = Seq.tabulate(numCols)(i => new FreqItemCounter(sizeOfMap))
    val originalSchema = df.schema
    // Pair each requested column with its data type (fieldIndex fails for unknown names).
    val colInfo: Array[(String, DataType)] = cols.map { name =>
      val index = originalSchema.fieldIndex(name)
      (name, originalSchema.fields(index).dataType)
    }.toArray

    val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)(
      // Fold one row into the per-column counters.
      seqOp = (counts, row) => {
        var i = 0
        while (i < numCols) {
          val thisMap = counts(i)
          val key = row.get(i)
          thisMap.add(key, 1L)
          i += 1
        }
        counts
      },
      // Merge the counters of two partial results, column by column.
      combOp = (baseCounts, counts) => {
        var i = 0
        while (i < numCols) {
          baseCounts(i).merge(counts(i))
          i += 1
        }
        baseCounts
      }
    )
    val justItems = freqItems.map(m => m.baseMap.keys.toArray)
    val resultRow = Row(justItems : _*)
    // append frequent Items to the column name for easy debugging
    val outputCols = colInfo.map { v =>
      StructField(v._1 + "_freqItems", ArrayType(v._2, false))
    }
    val schema = StructType(outputCols).toAttributes
    new DataFrame(df.sqlContext, LocalRelation.fromExternalRows(schema, Seq(resultRow)))
  }
} 
Example 107
Source File: DummyNode.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation


/**
 * A LocalNode that serves the rows of a LocalRelation one at a time through a cursor index.
 *
 * @param output   attributes exposed by this node
 * @param relation relation whose `data` rows are served
 * @param conf     configuration passed to the LocalNode base class
 */
private[local] case class DummyNode(
    output: Seq[Attribute],
    relation: LocalRelation,
    conf: SQLConf)
  extends LocalNode(conf) {

  import DummyNode._

  // Rows to serve, and the cursor into them; the cursor holds CLOSED while the node is not open.
  private val input: Seq[InternalRow] = relation.data
  private var index: Int = CLOSED

  /** Convenience constructor that builds the relation from `Product` rows. */
  def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) =
    this(output, LocalRelation.fromProduct(output, data), conf)

  /** True whenever the cursor differs from the CLOSED sentinel. */
  def isOpen: Boolean = !(index == CLOSED)

  override def children: Seq[LocalNode] = Nil

  /** Positions the cursor just before the first row. */
  override def open(): Unit = {
    index = -1
  }

  /**
   * Advances the cursor and reports whether it is still below the number of rows.
   * NOTE(review): if called before open(), the cursor becomes Int.MinValue + 1 and this
   * returns true — confirm callers always open the node first.
   */
  override def next(): Boolean = {
    index += 1
    val hasRow = index < input.size
    hasRow
  }

  /** Returns the row under the cursor; asserts that the cursor is within bounds. */
  override def fetch(): InternalRow = {
    assert(0 <= index && index < input.size)
    input(index)
  }

  /** Resets the cursor to the CLOSED sentinel. */
  override def close(): Unit = {
    index = CLOSED
  }
}

private object DummyNode {
  // Sentinel cursor value meaning "not open".
  val CLOSED: Int = Int.MinValue
}