org.apache.spark.sql.test.SharedSQLContext Scala Examples

The following examples show how to use org.apache.spark.sql.test.SharedSQLContext. Each example is drawn from an open-source project; the source file, project, and license are noted above it.
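All of these suites follow the same pattern: mix SharedSQLContext into a test base class such as QueryTest, SparkPlanTest, or SparkFunSuite, then import testImplicits for the Seq-to-DataFrame/Dataset conversions. The minimal sketch below illustrates that pattern; the suite name and test body are illustrative, not taken from any of the projects cited here.

package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// SharedSQLContext provides a single SQLContext (and SparkContext) shared by
// every test in the suite; testImplicits supplies toDF/toDS and the $"col" syntax.
class MinimalSharedContextSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("toDF works against the shared context") {
    val df = Seq((1, "a"), (2, "b")).toDF("id", "name")
    checkAnswer(df.filter($"id" === 1), Row(1, "a"))
  }
}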
Example 1
Source File: ExchangeSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
      input.map(Row.fromTuple)
    )
  }
} 
Example 2
Source File: MiscFunctionsSuite.scala (from Spark-2.3.1, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reflect and java_method") {
    val df = Seq((1, "one")).toDF("a", "b")
    val className = ReflectClass.getClass.getName.stripSuffix("$")
    checkAnswer(
      df.selectExpr(
        s"reflect('$className', 'method1', a, b)",
        s"java_method('$className', 'method1', a, b)"),
      Row("m1one", "m1one"))
  }
}

object ReflectClass {
  def method1(v1: Int, v2: String): String = "m" + v1 + v2
} 
Example 3
Source File: DataFrameHintSuite.scala (from Spark-2.3.1, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.test.SharedSQLContext

class DataFrameHintSuite extends AnalysisTest with SharedSQLContext {
  import testImplicits._
  lazy val df = spark.range(10)

  private def check(df: Dataset[_], expected: LogicalPlan) = {
    comparePlans(
      df.queryExecution.logical,
      expected
    )
  }

  test("various hint parameters") {
    check(
      df.hint("hint1"),
      UnresolvedHint("hint1", Seq(),
        df.logicalPlan
      )
    )

    check(
      df.hint("hint1", 1, "a"),
      UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan)
    )

    check(
      df.hint("hint1", 1, $"a"),
      UnresolvedHint("hint1", Seq(1, $"a"),
        df.logicalPlan
      )
    )

    check(
      df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
      UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")),
        df.logicalPlan
      )
    )
  }
} 
Example 4
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala (from Spark-2.3.1, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with a large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, we do the
      // computation here by constructing the logical plan directly.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
} 
Example 5
Source File: KinesisSourceOffsetSuite.scala (from kinesis-sql, Apache License 2.0)
package org.apache.spark.sql.kinesis

import java.io.File

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext


class KinesisSourceOffsetSuite extends OffsetSuite with SharedSQLContext {


  compare(
    one = KinesisSourceOffset(new ShardOffsets(-1L, "dummy", Array.empty[ShardInfo])),
    two = KinesisSourceOffset(new ShardOffsets(1L, "dummy", Array.empty[ShardInfo])))

  compare(
    one = KinesisSourceOffset(new ShardOffsets(1L, "foo", Array.empty[ShardInfo])),
    two = KinesisSourceOffset(new ShardOffsets(1L, "bar", Array.empty[ShardInfo]))
  )

  compare(
    one = KinesisSourceOffset(new ShardOffsets(1L, "foo", Array(
      new ShardInfo("shard-001", new TrimHorizon())))),
    two = KinesisSourceOffset(new ShardOffsets(1L, "foo",
      Array(new ShardInfo("shard-001", new TrimHorizon()),
        new ShardInfo("shard-002", new TrimHorizon()) )))
  )
  val shardInfo1 = Array(ShardInfo("shard-001", "AFTER_SEQUENCE_NUMBER", "1234"))

  val kso1 = KinesisSourceOffset(
    new ShardOffsets(1L, "foo", shardInfo1))

  val shardInfo2 = shardInfo1 ++ Array(ShardInfo("shard-002", "TRIM_HORIZON", ""))
  val kso2 = KinesisSourceOffset(
    new ShardOffsets(1L, "bar", shardInfo2))

  val shardInfo3 = shardInfo2 ++ Array(ShardInfo("shard-003", "AFTER_SEQUENCE_NUMBER", "2342"))
  val kso3 = KinesisSourceOffset(
    new ShardOffsets(1L, "bar", shardInfo3)
  )

  compare(KinesisSourceOffset(SerializedOffset(kso1.json)), kso2)

  test("basic serialization - deserialization") {
    assert(KinesisSourceOffset.getShardOffsets(kso1) ==
      KinesisSourceOffset.getShardOffsets(SerializedOffset(kso1.json)))
  }

  test("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // Use a non-existent directory to test whether the log creates the dir.
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(kso1)
      val batch1 = OffsetSeq.fill(kso2, kso3)

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }


} 
Example 6
Source File: ShardSyncerSuite.scala (from kinesis-sql, Apache License 2.0)
package org.apache.spark.sql.kinesis

import com.amazonaws.services.kinesis.model.{SequenceNumberRange, Shard}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

class ShardSyncerSuite extends SparkFunSuite with SharedSQLContext {

  val latestShards = Seq(createShard("Shard1", "1"))
  val prevShardInfo = Seq(new ShardInfo("shard0", new AfterSequenceNumber("0")))

  test("Should error out when failondataloss is true and a shard is deleted") {
    intercept[IllegalStateException] {
      ShardSyncer.getLatestShardInfo(latestShards, prevShardInfo,
        InitialKinesisPosition.fromPredefPosition(new TrimHorizon), true)
    }
  }

  test("Should error out when failondataloss is false and a shard is deleted") {
    val expectedShardInfo = Seq(new ShardInfo("Shard1", new TrimHorizon))
    val latest: Seq[ShardInfo] = ShardSyncer.getLatestShardInfo(
      latestShards, prevShardInfo, InitialKinesisPosition.fromPredefPosition(new TrimHorizon),
      false)
    assert(latest.nonEmpty)
    assert(latest(0).shardId === "Shard1")
    assert(latest(0).iteratorType === new TrimHorizon().iteratorType )
  }

  private def createShard(shardId: String, seqNum: String): Shard = {
    new Shard()
      .withShardId(shardId)
      .withSequenceNumberRange(
        new SequenceNumberRange().withStartingSequenceNumber(seqNum)
      )
  }

} 
Example 7
Source File: HDFSMetaDataCommiterSuite.scala (from kinesis-sql, Apache License 2.0)
package org.apache.spark.sql.kinesis

import java.io.File

import org.apache.hadoop.conf.Configuration
import scala.language.implicitConversions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.SerializableConfiguration


class HDFSMetaDataCommiterSuite extends SparkFunSuite with SharedSQLContext {

  val testConf: Configuration = new Configuration()
  val serializedConf = new SerializableConfiguration(testConf)

  test("Add and Get operation") {
    withTempDir { temp =>
      val dir = new File(temp, "commit")
      val metadataCommitter = new HDFSMetadataCommitter[String](dir.getAbsolutePath, serializedConf)
      assert(metadataCommitter.add(0, "Shard-000001", "foo"))
      assert(metadataCommitter.get(0) === Seq("foo"))

      assert(metadataCommitter.add(1, "Shard-000001", "one"))
      assert(metadataCommitter.add(1, "Shard-000002", "two"))
      assert(metadataCommitter.get(1).toSet === Set("one", "two"))

      // Adding the same batch overwrites the previous entry. This is required
      // since a re-attempt of a failed task will update the same location.
      assert(metadataCommitter.add(1, "Shard-000001", "updated-one"))
      assert(metadataCommitter.get(1).toSet === Set("updated-one", "two"))
    }
  }

  test("Purge operation") {
    withTempDir { temp =>
      val metadataCommitter = new HDFSMetadataCommitter[String](
        temp.getAbsolutePath, serializedConf)

      assert(metadataCommitter.add(0, "Shard-000001", "one"))
      assert(metadataCommitter.add(1, "Shard-000001", "two"))
      assert(metadataCommitter.add(2, "Shard-000001", "three"))

      assert(metadataCommitter.get(0).nonEmpty)
      assert(metadataCommitter.get(1).nonEmpty)
      assert(metadataCommitter.get(2).nonEmpty)

      metadataCommitter.purge(2)
      assertThrows[IllegalStateException](metadataCommitter.get(0))
      assertThrows[IllegalStateException](metadataCommitter.get(1))
      assert(metadataCommitter.get(2).nonEmpty)

      // There should be exactly one file, called "2", in the metadata directory.
      val allFiles = new File(metadataCommitter.metadataPath.toString).listFiles().toSeq
      assert(allFiles.size == 1)
      assert(allFiles.head.getName == "2")
    }
  }
} 
Example 8
Source File: JsonFunctionsSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("function get_json_object") {
    val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")
    checkAnswer(
      df.selectExpr("get_json_object(a, '$.name')", "get_json_object(a, '$.age')"),
      Row("alice", "5"))
  }


  val tuples: Seq[(String, String)] =
    ("1", """{"f1": "value1", "f2": "value2", "f3": 3, "f5": 5.23}""") ::
    ("2", """{"f1": "value12", "f3": "value3", "f2": 2, "f4": 4.01}""") ::
    ("3", """{"f1": "value13", "f4": "value44", "f3": "value33", "f2": 2, "f5": 5.01}""") ::
    ("4", null) ::
    ("5", """{"f1": "", "f5": null}""") ::
    ("6", "[invalid JSON string]") ::
    Nil

  test("function get_json_object - null") {
    val df: DataFrame = tuples.toDF("key", "jstring")
    val expected =
      Row("1", "value1", "value2", "3", null, "5.23") ::
        Row("2", "value12", "2", "value3", "4.01", null) ::
        Row("3", "value13", "2", "value33", "value44", "5.01") ::
        Row("4", null, null, null, null, null) ::
        Row("5", "", null, null, null, null) ::
        Row("6", null, null, null, null, null) ::
        Nil

    checkAnswer(
      df.select($"key", functions.get_json_object($"jstring", "$.f1"),
        functions.get_json_object($"jstring", "$.f2"),
        functions.get_json_object($"jstring", "$.f3"),
        functions.get_json_object($"jstring", "$.f4"),
        functions.get_json_object($"jstring", "$.f5")),
      expected)
  }

  test("json_tuple select") {
    val df: DataFrame = tuples.toDF("key", "jstring")
    val expected =
      Row("1", "value1", "value2", "3", null, "5.23") ::
      Row("2", "value12", "2", "value3", "4.01", null) ::
      Row("3", "value13", "2", "value33", "value44", "5.01") ::
      Row("4", null, null, null, null, null) ::
      Row("5", "", null, null, null, null) ::
      Row("6", null, null, null, null, null) ::
      Nil

    checkAnswer(
      df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")),
      expected)
  }

  test("json_tuple filter and group") {
    val df: DataFrame = tuples.toDF("key", "jstring")
    val expr = df
      .select(functions.json_tuple($"jstring", "f1", "f2"))
      .where($"c0".isNotNull)
      .groupBy($"c1")
      .count()

    val expected = Row(null, 1) ::
      Row("2", 2) ::
      Row("value2", 1) ::
      Nil

    checkAnswer(expr, expected)
  }
} 
Example 9
Source File: SQLUtilsSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
Example 10
Source File: DataFramePivotSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

class DataFramePivotSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("pivot courses with literals") {
    checkAnswer(
      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
        .agg(sum($"earnings")),
      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
    )
  }

  test("pivot year with literals") {
    checkAnswer(
      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }

  test("pivot courses with literals and multiple aggregations") {
    checkAnswer(
      courseSales.groupBy($"year")
        .pivot("course", Seq("dotNET", "Java"))
        .agg(sum($"earnings"), avg($"earnings")),
      Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
        Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
    )
  }

  test("pivot year with string values (cast)") {
    checkAnswer(
      courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }

  test("pivot year with int values") {
    checkAnswer(
      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }

  test("pivot courses with no values") {
    // Note: Java comes before dotNET in sorted order
    checkAnswer(
      courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
      Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
    )
  }

  test("pivot year with no values") {
    checkAnswer(
      courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
      Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
    )
  }

  test("pivot max values enforced") {
    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
    intercept[AnalysisException](
      courseSales.groupBy("year").pivot("course")
    )
    sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
      SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
  }

  test("pivot with UnresolvedFunction") {
    checkAnswer(
      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
        .agg("earnings" -> "sum"),
      Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
    )
  }
} 
Example 11
Source File: DataFrameImplicitsSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }
} 
Example 12
Source File: SemiJoinSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 13
Source File: JsonParsingOptionsSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution.datasources.json

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

 'Reynold Xin'}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.option("allowComments", "true").json(rdd)

    assert(df.schema.head.name == "name")
    assert(df.first().getString(0) == "Reynold Xin")
  }

  test("allowSingleQuotes off") {
    val str = """{'name': 'Reynold Xin'}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.option("allowSingleQuotes", "false").json(rdd)

    assert(df.schema.head.name == "_corrupt_record")
  }

  test("allowSingleQuotes on") {
    val str = """{'name': 'Reynold Xin'}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.json(rdd)

    assert(df.schema.head.name == "name")
    assert(df.first().getString(0) == "Reynold Xin")
  }

  test("allowUnquotedFieldNames off") {
    val str = """{name: 'Reynold Xin'}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.json(rdd)

    assert(df.schema.head.name == "_corrupt_record")
  }

  test("allowUnquotedFieldNames on") {
    val str = """{name: 'Reynold Xin'}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.option("allowUnquotedFieldNames", "true").json(rdd)

    assert(df.schema.head.name == "name")
    assert(df.first().getString(0) == "Reynold Xin")
  }

  test("allowNumericLeadingZeros off") {
    val str = """{"age": 0018}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.json(rdd)

    assert(df.schema.head.name == "_corrupt_record")
  }

  test("allowNumericLeadingZeros on") {
    val str = """{"age": 0018}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.option("allowNumericLeadingZeros", "true").json(rdd)

    assert(df.schema.head.name == "age")
    assert(df.first().getLong(0) == 18)
  }

  // The following two tests are not really working - need to look into Jackson's
  // JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS.
  ignore("allowNonNumericNumbers off") {
    val str = """{"age": NaN}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.json(rdd)

    assert(df.schema.head.name == "_corrupt_record")
  }

  ignore("allowNonNumericNumbers on") {
    val str = """{"age": NaN}"""
    val rdd = sqlContext.sparkContext.parallelize(Seq(str))
    val df = sqlContext.read.option("allowNonNumericNumbers", "true").json(rdd)

    assert(df.schema.head.name == "age")
    assert(df.first().getDouble(0).isNaN)
  }
} 
Example 14
Source File: TextSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution.datasources.text

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.util.Utils


class TextSuite extends QueryTest with SharedSQLContext {

  test("reading text file") {
    verifyFrame(sqlContext.read.format("text").load(testFile))
  }

  test("SQLContext.read.text() API") {
    verifyFrame(sqlContext.read.text(testFile))
  }

  test("SPARK-12562 verify write.text() can handle column name beyond `value`") {
    val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf")

    val tempFile = Utils.createTempDir()
    tempFile.delete()
    df.write.text(tempFile.getCanonicalPath)
    verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath))

    Utils.deleteRecursively(tempFile)
  }

  test("error handling for invalid schema") {
    val tempFile = Utils.createTempDir()
    tempFile.delete()

    val df = sqlContext.range(2)
    intercept[AnalysisException] {
      df.write.text(tempFile.getCanonicalPath)
    }

    intercept[AnalysisException] {
      sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath)
    }
  }

  private def testFile: String = {
    Thread.currentThread().getContextClassLoader.getResource("text-suite.txt").toString
  }

  
  private def verifyFrame(df: DataFrame): Unit = {
    // schema
    assert(df.schema == new StructType().add("value", StringType))

    // verify content
    val data = df.collect()
    assert(data(0) == Row("This is a test file for the text data source"))
    assert(data(1) == Row("1+1"))
    // Non-ASCII characters are not allowed in the code, so we disable scalastyle here.
    // scalastyle:off
    assert(data(2) == Row("数据砖头"))
    // scalastyle:on
    assert(data(3) == Row("\"doh\""))
    assert(data.length == 4)
  }
} 
Example 15
Source File: ParquetInteroperabilitySuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    import ParquetCompatibilityTest._

    // This test case writes two Parquet files, both representing the following Catalyst schema
    //
    //   StructType(
    //     StructField(
    //       "f",
    //       ArrayType(IntegerType, containsNull = false),
    //       nullable = false))
    //
    // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the
    // other one comes with parquet-protobuf style 1-level unannotated primitive field.
    withTempDir { dir =>
      val avroStylePath = new File(dir, "avro-style").getCanonicalPath
      val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath

      val avroStyleSchema =
        """message avro_style {
          |  required group f (LIST) {
          |    repeated int32 array;
          |  }
          |}
        """.stripMargin

      writeDirect(avroStylePath, avroStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.group {
              rc.field("array", 0) {
                rc.addInteger(0)
                rc.addInteger(1)
              }
            }
          }
        }
      })

      logParquetSchema(avroStylePath)

      val protobufStyleSchema =
        """message protobuf_style {
          |  repeated int32 f;
          |}
        """.stripMargin

      writeDirect(protobufStylePath, protobufStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.addInteger(2)
            rc.addInteger(3)
          }
        }
      })

      logParquetSchema(protobufStylePath)

      checkAnswer(
        sqlContext.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3))))
    }
  }
} 
Example 16
Source File: ParquetProtobufCompatibilitySuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 17
Source File: GroupedDatasetSuite.scala (from Spark-2.3.1, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.plans.logical.AnalysisBarrier
import org.apache.spark.sql.execution.python.PythonUDF
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{LongType, StructField, StructType}

class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  private val scalaUDF = udf((x: Long) => { x + 1 })
  private lazy val datasetWithUDF = spark.range(1).toDF("s").select($"s", scalaUDF($"s"))

  private def assertContainsAnalysisBarrier(ds: Dataset[_], atLevel: Int = 1): Unit = {
    assert(atLevel >= 0)
    var children = Seq(ds.queryExecution.logical)
    (1 to atLevel).foreach { _ =>
      children = children.flatMap(_.children)
    }
    val barriers = children.collect {
      case ab: AnalysisBarrier => ab
    }
    assert(barriers.nonEmpty, s"Plan does not contain AnalysisBarrier at level $atLevel:\n" +
      ds.queryExecution.logical)
  }

  test("SPARK-24373: avoid running Analyzer rules twice on RelationalGroupedDataset") {
    val groupByDataset = datasetWithUDF.groupBy()
    val rollupDataset = datasetWithUDF.rollup("s")
    val cubeDataset = datasetWithUDF.cube("s")
    val pivotDataset = datasetWithUDF.groupBy().pivot("s", Seq(1, 2))
    datasetWithUDF.cache()
    Seq(groupByDataset, rollupDataset, cubeDataset, pivotDataset).foreach { rgDS =>
      val df = rgDS.count()
      assertContainsAnalysisBarrier(df)
      assertCached(df)
    }

    val flatMapGroupsInRDF = datasetWithUDF.groupBy().flatMapGroupsInR(
      Array.emptyByteArray,
      Array.emptyByteArray,
      Array.empty,
      StructType(Seq(StructField("s", LongType))))
    val flatMapGroupsInPandasDF = datasetWithUDF.groupBy().flatMapGroupsInPandas(PythonUDF(
      "pyUDF",
      null,
      StructType(Seq(StructField("s", LongType))),
      Seq.empty,
      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
      true))
    Seq(flatMapGroupsInRDF, flatMapGroupsInPandasDF).foreach { df =>
      assertContainsAnalysisBarrier(df, 2)
      assertCached(df)
    }
    datasetWithUDF.unpersist(true)
  }

  test("SPARK-24373: avoid running Analyzer rules twice on KeyValueGroupedDataset") {
    val kvDataset = datasetWithUDF.groupByKey(_.getLong(0))
    datasetWithUDF.cache()
    val mapValuesKVDataset = kvDataset.mapValues(_.getLong(0)).reduceGroups(_ + _)
    val keysKVDataset = kvDataset.keys
    val flatMapGroupsKVDataset = kvDataset.flatMapGroups((k, _) => Seq(k))
    val aggKVDataset = kvDataset.count()
    val otherKVDataset = spark.range(1).groupByKey(_ + 1)
    val cogroupKVDataset = kvDataset.cogroup(otherKVDataset)((k, _, _) => Seq(k))
    Seq((mapValuesKVDataset, 1),
        (keysKVDataset, 2),
        (flatMapGroupsKVDataset, 2),
        (aggKVDataset, 1),
        (cogroupKVDataset, 2)).foreach { case (df, analysisBarrierDepth) =>
      assertContainsAnalysisBarrier(df, analysisBarrierDepth)
      assertCached(df)
    }
    datasetWithUDF.unpersist(true)
  }
} 
Example 18
Source File: ExpandSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    }
    val attributes = projections.head.map(_.toAttribute)
    checkAnswer(
      input.toDF(),
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
    )
  }

  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.")
  }

  test("expanding UnsafeRows") {
    testExpand(ConvertToUnsafe)
  }

  test("expanding SafeRows") {
    testExpand(identity)
  }
} 
Example 19
Source File: SortSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.sql.{RandomDataGenerator, Row}



class SortSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  test("basic sorting using ExternalSort") {

    val input = Seq(
      ("Hello", 4, 2.0),
      ("Hello", 1, 1.0),
      ("World", 8, 3.0)
    )

    checkAnswer(
      input.toDF("a", "b", "c"),
      (child: SparkPlan) => Sort('a.asc :: 'b.asc :: Nil, global = true, child = child),
      input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
      sortAnswers = false)

    checkAnswer(
      input.toDF("a", "b", "c"),
      (child: SparkPlan) => Sort('b.asc :: 'a.asc :: Nil, global = true, child = child),
      input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
      sortAnswers = false)
  }

  test("sort followed by limit") {
    checkThatPlansAgree(
      (1 to 100).map(v => Tuple1(v)).toDF("a"),
      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child = child)),
      (child: SparkPlan) => Limit(10, ReferenceSort('a.asc :: Nil, global = true, child)),
      sortAnswers = false
    )
  }

  test("sorting does not crash for large inputs") {
    val sortOrder = 'a.asc :: Nil
    val stringLength = 1024 * 1024 * 2
    checkThatPlansAgree(
      Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
      Sort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
      ReferenceSort(sortOrder, global = true, _: SparkPlan),
      sortAnswers = false
    )
  }

  test("sorting updates peak execution memory") {
    AccumulatorSuite.verifyPeakExecutionMemorySet(sparkContext, "unsafe external sort") {
      checkThatPlansAgree(
        (1 to 100).map(v => Tuple1(v)).toDF("a"),
        (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child = child),
        (child: SparkPlan) => ReferenceSort('a.asc :: Nil, global = true, child),
        sortAnswers = false)
    }
  }

  // Test sorting on different data types
  for (
    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
    nullable <- Seq(true, false);
    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
  ) {
    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
      val inputData = Seq.fill(1000)(randomDataGenerator())
      val inputDf = sqlContext.createDataFrame(
        sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
        StructType(StructField("a", dataType, nullable = true) :: Nil)
      )
      checkThatPlansAgree(
        inputDf,
        p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)),
        ReferenceSort(sortOrder, global = true, _: SparkPlan),
        sortAnswers = false
      )
    }
  }
} 
Example 20
Source File: DataFrameTungstenSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = sqlContext.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = sqlContext.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }
} 
Example 21
Source File: RowSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericMutableRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificMutableRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("serialize w/ kryo") {
    val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first()
    val serializer = new SparkSqlSerializer(sparkContext.getConf)
    val instance = serializer.newInstance()
    val ser = instance.serialize(row)
    val de = instance.deserialize(ser).asInstanceOf[Row]
    assert(de === row)
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 22
Source File: ExtraStrategiesSuite.scala (from BigDatalog, Apache License 2.0)
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }

  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
} 
Example 23
Source File: DatasetCacheSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import scala.language.postfixOps

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("persist and unpersist") {
    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure, the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkAnswer(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(!sqlContext.isCached(cached), "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkAnswer(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(!sqlContext.isCached(ds1), "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(!sqlContext.isCached(ds2), "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupBy($"_1").keyAs[String]
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkAnswer(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(!sqlContext.isCached(ds), "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(!sqlContext.isCached(agged), "The Dataset agged should not be cached.")
  }
} 
Example 24
Source File: ListTablesSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.catalyst.TableIdentifier

class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
  }

  test("get all tables") {
    checkAnswer(
      sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("getting all Tables with a database name has no impact on returned table names") {
    checkAnswer(
      sqlContext.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    sqlContext.catalog.unregisterTable(TableIdentifier("ListTablesSuiteTable"))
    assert(sqlContext.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(sqlContext.tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql(
            "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          sqlContext.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        sqlContext.dropTempTable("tables")
    }
  }
} 
Example 25
Source File: DataFrameComplexTypeSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("UDF on struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(struct($"a").as("s")).select(f($"s.a")).collect()
  }

  test("UDF on named_struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
  }

  test("UDF on array") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
  }

  test("SPARK-12477 accessing null element in array field") {
    val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"),
      Seq(Some(1), None, Some(2))))).toDF("s", "i")
    val nullStringRow = df.selectExpr("s[1]").collect()(0)
    assert(nullStringRow == org.apache.spark.sql.Row(null))
    val nullIntRow = df.selectExpr("i[1]").collect()(0)
    assert(nullIntRow == org.apache.spark.sql.Row(null))
  }
} 
Example 26
Source File: DDLSourceLoadSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Please note that META-INF/services had to be modified in the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      caseInsensitiveContext.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    caseInsensitiveContext.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without HiveContext") {
    intercept[ClassNotFoundException] {
      caseInsensitiveContext.read.format("orc").load()
    }
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
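The comment at the top of this suite refers to Java's ServiceLoader mechanism: Spark discovers DataSourceRegister implementations through a services file on the classpath. A sketch of the registration file for the three fake sources above (the exact resource path is an assumption based on the package names in this example):

# src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree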
Example 27
Source File: DDLTestSuite.scala (from BigDatalog, Apache License 2.0)
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
  }
}

case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(Seq(
      StructField("intType", IntegerType, nullable = false,
        new MetadataBuilder().putString("comment", s"test comment $table").build()),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))

  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sqlContext.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
}

class DDLTestSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
      |CREATE TEMPORARY TABLE ddlPeople
      |USING org.apache.spark.sql.sources.DDLScanSource
      |OPTIONS (
      |  From '1',
      |  To '10',
      |  Table 'test1'
      |)
      """.stripMargin)
  }

  sqlTest(
      "describe ddlPeople",
      Seq(
        Row("intType", "int", "test comment test1"),
        Row("stringType", "string", ""),
        Row("dateType", "date", ""),
        Row("timestampType", "timestamp", ""),
        Row("doubleType", "double", ""),
        Row("bigintType", "bigint", ""),
        Row("tinyintType", "tinyint", ""),
        Row("decimalType", "decimal(10,0)", ""),
        Row("fixedDecimalType", "decimal(5,1)", ""),
        Row("binaryType", "binary", ""),
        Row("booleanType", "boolean", ""),
        Row("smallIntType", "smallint", ""),
        Row("floatType", "float", ""),
        Row("mapType", "map<string,string>", ""),
        Row("arrayType", "array<string>", ""),
        Row("structType", "struct<f1:string,f2:int>", "")
      ))

  test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
    val attributes = sql("describe ddlPeople")
      .queryExecution.executedPlan.output
    assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
    assert(attributes.map(_.dataType).toSet === Set(StringType))
  }
} 
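
Note that the OPTIONS keys above are written as From/To/Table while DDLScanSource reads parameters("from"), parameters("TO") and parameters("Table"): the test relies on data source options being passed in a case-insensitive map. A plain-Scala stand-in for that lookup behavior (a hypothetical sketch, not Spark's internal class):

// Normalize keys to lower case, as a case-insensitive map effectively does.
val opts = Map("From" -> "1", "To" -> "10", "Table" -> "test1")
  .map { case (k, v) => (k.toLowerCase, v) }
assert(opts("from") == "1" && opts("to") == "10" && opts("table") == "test1")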
Example 28
Source File: PartitionedWriteSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = sqlContext.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      sqlContext.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = sqlContext.range(100)
    val df = base.unionAll(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      sqlContext.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
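
A minimal standalone sketch of what these tests exercise, assuming Spark 2.x on the classpath: partitionBy("i") writes one directory per distinct value, and on read the partition column is appended after the data columns, which is exactly what the last test asserts.

import org.apache.spark.sql.SparkSession

object PartitionedWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("sketch").getOrCreate()
    import spark.implicits._
    val path = java.nio.file.Files.createTempDirectory("pw").toString
    // Produces <path>/i=1/part-*.parquet and <path>/i=2/part-*.parquet on disk.
    Seq(1 -> "a", 2 -> "b").toDF("i", "j").write.mode("overwrite").partitionBy("i").parquet(path)
    // The partition column comes back at the end of the schema.
    println(spark.read.parquet(path).schema.map(_.name)) // expected: List(j, i)
    spark.stop()
  }
}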
Example 29
Source File: SQLConfSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext}


class SQLConfSuite extends QueryTest with SharedSQLContext {
  private val testKey = "test.key.0"
  private val testVal = "test.val.0"

  test("propagate from spark conf") {
    // We create a new context here to avoid order dependence with other tests that might call
    // clear().
    val newContext = new SQLContext(sparkContext)
    assert(newContext.getConf("spark.sql.testkey", "false") === "true")
  }

  test("programmatic ways of basic setting and getting") {
    // Set a conf first.
    sqlContext.setConf(testKey, testVal)
    // Clear the conf.
    sqlContext.conf.clear()
    // After clear, only overrideConfs used by unit test should be in the SQLConf.
    assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs)

    sqlContext.setConf(testKey, testVal)
    assert(sqlContext.getConf(testKey) === testVal)
    assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
    assert(sqlContext.getAllConfs.contains(testKey))

    // Tests that SQLConf, as accessed from a SQLContext, is mutable after
    // the latter is initialized, unlike SparkConf inside a SparkContext.
    assert(sqlContext.getConf(testKey) === testVal)
    assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
    assert(sqlContext.getAllConfs.contains(testKey))

    sqlContext.conf.clear()
  }

  test("parse SQL set commands") {
    sqlContext.conf.clear()
    sql(s"set $testKey=$testVal")
    assert(sqlContext.getConf(testKey, testVal + "_") === testVal)
    assert(sqlContext.getConf(testKey, testVal + "_") === testVal)

    sql("set some.property=20")
    assert(sqlContext.getConf("some.property", "0") === "20")
    sql("set some.property = 40")
    assert(sqlContext.getConf("some.property", "0") === "40")

    val key = "spark.sql.key"
    val vs = "val0,val_1,val2.3,my_table"
    sql(s"set $key=$vs")
    assert(sqlContext.getConf(key, "0") === vs)

    sql(s"set $key=")
    assert(sqlContext.getConf(key, "0") === "")

    sqlContext.conf.clear()
  }

  test("deprecated property") {
    sqlContext.conf.clear()
    val original = sqlContext.conf.numShufflePartitions
    try {
      sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
      assert(sqlContext.conf.numShufflePartitions === 10)
    } finally {
      sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original")
    }
  }

  test("invalid conf value") {
    sqlContext.conf.clear()
    val e = intercept[IllegalArgumentException] {
      sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
    }
    assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
  }
} 
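
A quick sketch of the equivalence these tests cover, written as spark-shell input (assuming a Spark 2.x shell, which provides spark; a codebase like this one exposes the same store through sqlContext):

// SQL SET commands and the programmatic conf API read and write the same store.
spark.sql("SET spark.sql.shuffle.partitions=10")
assert(spark.conf.get("spark.sql.shuffle.partitions") == "10")
spark.conf.set("spark.sql.shuffle.partitions", "4")
// A SET query returns (key, value) rows; column 1 holds the value.
assert(spark.sql("SET spark.sql.shuffle.partitions").first().getString(1) == "4")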
Example 30
Source File: PulsarSourceOffsetSuite.scala    From pulsar-spark   with Apache License 2.0
package org.apache.spark.sql.pulsar

import java.io.File

import org.apache.pulsar.client.impl.MessageIdImpl

import org.apache.spark.sql.execution.streaming.{LongOffset, OffsetSeq, OffsetSeqLog, SerializedOffset}
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext

class PulsarSourceOffsetSuite extends OffsetSuite with SharedSQLContext {

  compare(
    one = SpecificPulsarOffset(("t", new MessageIdImpl(1, 1, -1))),
    two = SpecificPulsarOffset(("t", new MessageIdImpl(1, 2, -1))))

  compare(
    one = SpecificPulsarOffset(
      ("t", new MessageIdImpl(1, 1, -1)),
      ("t1", new MessageIdImpl(1, 1, -1))),
    two = SpecificPulsarOffset(
      ("t", new MessageIdImpl(1, 2, -1)),
      ("t1", new MessageIdImpl(1, 2, -1)))
  )

  compare(
    one = SpecificPulsarOffset(("t", new MessageIdImpl(1, 1, -1))),
    two = SpecificPulsarOffset(
      ("t", new MessageIdImpl(1, 2, -1)),
      ("t1", new MessageIdImpl(1, 1, -1))))

  val kso1 = SpecificPulsarOffset(("t", new MessageIdImpl(1, 1, -1)))
  val kso2 =
    SpecificPulsarOffset(("t", new MessageIdImpl(1, 2, -1)), ("t1", new MessageIdImpl(1, 3, -1)))
  val kso3 = SpecificPulsarOffset(
    ("t", new MessageIdImpl(1, 2, -1)),
    ("t1", new MessageIdImpl(1, 3, -1)),
    ("t2", new MessageIdImpl(1, 4, -1)))

  compare(
    SpecificPulsarOffset(SerializedOffset(kso1.json)),
    SpecificPulsarOffset(SerializedOffset(kso2.json)))

  test("basic serialization - deserialization") {
    assert(
      SpecificPulsarOffset.getTopicOffsets(kso1) ==
        SpecificPulsarOffset.getTopicOffsets(SerializedOffset(kso1.json)))
  }

  test("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // Use a non-existent directory to test whether the log creates it.
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(kso1)
      val batch1 = OffsetSeq.fill(kso2, kso3)

      val batch0Serialized =
        OffsetSeq.fill(batch0.offsets.flatMap(_.map(o => SerializedOffset(o.json))): _*)

      val batch1Serialized =
        OffsetSeq.fill(batch1.offsets.flatMap(_.map(o => SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(
        metadataLog.get(None, Some(1)) ===
          Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(
        metadataLog.get(None, Some(1)) ===
          Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

} 
Example 31
Source File: CachedPulsarClientSuite.scala    From pulsar-spark   with Apache License 2.0
package org.apache.spark.sql.pulsar

import java.util.concurrent.ConcurrentMap
import java.{util => ju}

import org.scalatest.PrivateMethodTester

import org.apache.pulsar.client.api.PulsarClient
import org.apache.spark.sql.test.SharedSQLContext

class CachedPulsarClientSuite extends SharedSQLContext with PrivateMethodTester with PulsarTest {

  import PulsarOptions._

  type KP = PulsarClient

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    CachedPulsarClient.clear()
  }

  test("Should return the cached instance on calling getOrCreate with same params.") {
    val pulsarParams = new ju.HashMap[String, Object]()
    // Only the host needs to be resolvable here; a running Pulsar server is not required.
    pulsarParams.put(SERVICE_URL_OPTION_KEY, "pulsar://127.0.0.1:6650")
    pulsarParams.put("concurrentLookupRequest", "10000")
    val producer = CachedPulsarClient.getOrCreate(pulsarParams)
    val producer2 = CachedPulsarClient.getOrCreate(pulsarParams)
    assert(producer == producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedPulsarClient.invokePrivate(cacheMap())
    assert(map.size == 1)
  }

  test("Should close the correct pulsar producer for the given pulsarPrams.") {
    val pulsarParams = new ju.HashMap[String, Object]()
    pulsarParams.put(SERVICE_URL_OPTION_KEY, "pulsar://127.0.0.1:6650")
    pulsarParams.put("concurrentLookupRequest", "10000")
    val producer: KP = CachedPulsarClient.getOrCreate(pulsarParams)
    pulsarParams.put("concurrentLookupRequest", "20000")
    val producer2: KP = CachedPulsarClient.getOrCreate(pulsarParams)
    // With updated conf, a new producer instance should be created.
    assert(producer != producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedPulsarClient.invokePrivate(cacheMap())
    assert(map.size == 2)

    CachedPulsarClient.close(pulsarParams)
    val map2 = CachedPulsarClient.invokePrivate(cacheMap())
    assert(map2.size == 1)
    import scala.collection.JavaConverters._
    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
    assert(_producer == producer)
  }
} 
Example 32
Source File: ParquetEncodingSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.test.SharedSQLContext

// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
// writer abstractions. Revisit.
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
  import testImplicits._

  val ROW = ((1).toByte, 2, 3L, "abc")
  val NULL_ROW = (
    null.asInstanceOf[java.lang.Byte],
    null.asInstanceOf[Integer],
    null.asInstanceOf[java.lang.Long],
    null.asInstanceOf[String])

  test("All Types Dictionary") {
    (1 :: 1000 :: Nil).foreach { n => {
      withTempPath { dir =>
        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled)
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).getByte(i) == 1)
          assert(batch.column(1).getInt(i) == 2)
          assert(batch.column(2).getLong(i) == 3)
          assert(batch.column(3).getUTF8String(i).toString == "abc")
          i += 1
        }
        reader.close()
      }
    }}
  }

  test("All Types Null") {
    (1 :: 100 :: Nil).foreach { n => {
      withTempPath { dir =>
        val data = List.fill(n)(NULL_ROW).toDF
        data.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled)
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).isNullAt(i))
          assert(batch.column(1).isNullAt(i))
          assert(batch.column(2).isNullAt(i))
          assert(batch.column(3).isNullAt(i))
          i += 1
        }
        reader.close()
      }}
    }
  }

  test("Read row group containing both dictionary and plain encoded pages") {
    withSQLConf(ParquetOutputFormat.DICTIONARY_PAGE_SIZE -> "2048",
      ParquetOutputFormat.PAGE_SIZE -> "4096") {
      withTempPath { dir =>
        // In order to explicitly test for SPARK-14217, we set the parquet dictionary and page size
        // such that the following data spans across 3 pages (within a single row group) where the
        // first page is dictionary encoded and the remaining two are plain encoded.
        val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString))
        data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head

        val reader = new VectorizedParquetRecordReader(sqlContext.conf.offHeapColumnVectorEnabled)
        reader.initialize(file, null)
        val column = reader.resultBatch().column(0)
        assert(reader.nextBatch())

        (0 until 512).foreach { i =>
          assert(column.getUTF8String(3 * i).toString == i.toString)
          assert(column.getUTF8String(3 * i + 1).toString == i.toString)
          assert(column.getUTF8String(3 * i + 2).toString == i.toString)
        }
        reader.close()
      }
    }
  }
} 
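
The last test forces a row group that starts out dictionary-encoded and falls back to plain encoding by shrinking the dictionary page. A sketch of setting the same writer knobs outside a test, as spark-shell input (spark.implicits._ is in scope there; /tmp/mixed is just an example path):

// parquet.dictionary.page.size and parquet.page.size are Hadoop-side writer settings.
spark.sparkContext.hadoopConfiguration.set("parquet.dictionary.page.size", "2048")
spark.sparkContext.hadoopConfiguration.set("parquet.page.size", "4096")
// 512 distinct strings, three copies each, in a single file: enough to overflow the
// tiny dictionary page so that later pages are plain-encoded (see SPARK-14217).
spark.range(0, 512).flatMap(i => Seq.fill(3)(i.toString)).toDF("f")
  .coalesce(1).write.mode("overwrite").parquet("/tmp/mixed")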
Example 33
Source File: PartitionedWriteSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
// Partitioned write test suite
class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {//写入更多分区
    val path = Utils.createTempDir()
    path.delete()
    val df = ctx.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      ctx.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {//用重复写多个分区
    val path = Utils.createTempDir()
    path.delete()

    val base = ctx.range(100)
    val df = base.unionAll(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      ctx.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }
} 
Example 34
Source File: SQLConfSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// SQL configuration test suite
class SQLConfSuite extends QueryTest with SharedSQLContext {
  private val testKey = "test.key.0"
  private val testVal = "test.val.0"

  test("propagate from spark conf") {//传播Spark配置文件
    // We create a new context here to avoid order dependence with other tests that might call
    // clear().
    val newContext = new SQLContext(ctx.sparkContext)
    assert(newContext.getConf("spark.sql.testkey", "false") === "true")
  }

  test("programmatic ways of basic setting and getting") {//编程方式的基本设置和获取
    ctx.conf.clear()
    assert(ctx.getAllConfs.size === 0)

    ctx.setConf(testKey, testVal)
    assert(ctx.getConf(testKey) === testVal)
    assert(ctx.getConf(testKey, testVal + "_") === testVal)
    assert(ctx.getAllConfs.contains(testKey))

    // Tests that SQLConf, as accessed from a SQLContext, is mutable after
    // the latter is initialized, unlike SparkConf inside a SparkContext.
    assert(ctx.getConf(testKey) == testVal)
    assert(ctx.getConf(testKey, testVal + "_") === testVal)
    assert(ctx.getAllConfs.contains(testKey))

    ctx.conf.clear()
  }

  test("parse SQL set commands") {//解析SQL命令集
    ctx.conf.clear()
    sql(s"set $testKey=$testVal")
    assert(ctx.getConf(testKey, testVal + "_") === testVal)
    assert(ctx.getConf(testKey, testVal + "_") === testVal)

    sql("set some.property=20")
    assert(ctx.getConf("some.property", "0") === "20")
    sql("set some.property = 40")
    assert(ctx.getConf("some.property", "0") === "40")

    val key = "spark.sql.key"
    val vs = "val0,val_1,val2.3,my_table"
    sql(s"set $key=$vs")
    assert(ctx.getConf(key, "0") === vs)

    sql(s"set $key=")
    assert(ctx.getConf(key, "0") === "")

    ctx.conf.clear()
  }

  test("deprecated property") {//不赞成的属性
    ctx.conf.clear()
    sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
    assert(ctx.conf.numShufflePartitions === 10)
  }

  test("invalid conf value") {//无效配置值
    ctx.conf.clear()
    val e = intercept[IllegalArgumentException] {
      sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
    }
    assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
  }
} 
Example 35
Source File: KafkaSourceOffsetSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.kafka010

import java.io.File

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext

class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext {

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))


  val kso1 = KafkaSourceOffset(("t", 0, 1L))
  val kso2 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L))
  val kso3 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L))

  compare(KafkaSourceOffset(SerializedOffset(kso1.json)),
    KafkaSourceOffset(SerializedOffset(kso2.json)))

  test("basic serialization - deserialization") {
    assert(KafkaSourceOffset.getPartitionOffsets(kso1) ==
      KafkaSourceOffset.getPartitionOffsets(SerializedOffset(kso1.json)))
  }


  test("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // Use a non-existent directory to test whether the log creates it.
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(kso1)
      val batch1 = OffsetSeq.fill(kso2, kso3)

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 offset format") {
    val offset = readFromResource("kafka-source-offset-version-2.1.0.txt")
    assert(KafkaSourceOffset(offset) ===
      KafkaSourceOffset(("topic1", 0, 456L), ("topic1", 1, 789L), ("topic2", 0, 0L)))
  }

  private def readFromResource(file: String): SerializedOffset = {
    import scala.io.Source
    val input = getClass.getResource(s"/$file").toURI
    val str = Source.fromFile(input).mkString
    SerializedOffset(str)
  }
} 
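
The offsets being compared serialize to JSON keyed by topic and then partition, which is what SerializedOffset(kso1.json) round-trips. A sketch of the shape; note that KafkaSourceOffset is package-private, so this only compiles inside the org.apache.spark.sql.kafka010 package, and the expected literal reflects my reading of Spark 2.3's JSON encoding rather than a documented format:

package org.apache.spark.sql.kafka010

object OffsetJsonSketch {
  def main(args: Array[String]): Unit = {
    val off = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 3L))
    println(off.json) // expected shape: {"t":{"0":1,"1":3}}
  }
}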
Example 36
Source File: CachedKafkaProducerSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.kafka010

import java.{util => ju}
import java.util.concurrent.ConcurrentMap

import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.ByteArraySerializer
import org.scalatest.PrivateMethodTester

import org.apache.spark.sql.test.SharedSQLContext

class CachedKafkaProducerSuite extends SharedSQLContext with PrivateMethodTester {

  type KP = KafkaProducer[Array[Byte], Array[Byte]]

  protected override def beforeEach(): Unit = {
    super.beforeEach()
    val clear = PrivateMethod[Unit]('clear)
    CachedKafkaProducer.invokePrivate(clear())
  }

  test("Should return the cached instance on calling getOrCreate with same params.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    // Only the host needs to be resolvable here; a running Kafka server is not required.
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val producer = CachedKafkaProducer.getOrCreate(kafkaParams)
    val producer2 = CachedKafkaProducer.getOrCreate(kafkaParams)
    assert(producer == producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 1)
  }

  test("Should close the correct kafka producer for the given kafkaPrams.") {
    val kafkaParams = new ju.HashMap[String, Object]()
    kafkaParams.put("acks", "0")
    kafkaParams.put("bootstrap.servers", "127.0.0.1:9022")
    kafkaParams.put("key.serializer", classOf[ByteArraySerializer].getName)
    kafkaParams.put("value.serializer", classOf[ByteArraySerializer].getName)
    val producer: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
    kafkaParams.put("acks", "1")
    val producer2: KP = CachedKafkaProducer.getOrCreate(kafkaParams)
    // With updated conf, a new producer instance should be created.
    assert(producer != producer2)

    val cacheMap = PrivateMethod[ConcurrentMap[Seq[(String, Object)], KP]]('getAsMap)
    val map = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map.size == 2)

    CachedKafkaProducer.close(kafkaParams)
    val map2 = CachedKafkaProducer.invokePrivate(cacheMap())
    assert(map2.size == 1)
    import scala.collection.JavaConverters._
    val (seq: Seq[(String, Object)], _producer: KP) = map2.asScala.toArray.apply(0)
    assert(_producer == producer)
  }
} 
Example 37
Source File: KafkaContinuousSourceSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.kafka010

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.time.SpanSugar._
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}

// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest

class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
  import testImplicits._

  override val brokerProps = Map("auto.create.topics.enable" -> "false")

  test("subscribing topic by pattern with topic deletions") {
    val topicPrefix = newTopic()
    val topic = topicPrefix + "-seems"
    val topic2 = topicPrefix + "-bad"
    testUtils.createTopic(topic, partitions = 5)
    testUtils.sendMessages(topic, Array("-1"))
    require(testUtils.getLatestOffsets(Set(topic)).size === 5)

    val reader = spark
      .readStream
      .format("kafka")
      .option("kafka.bootstrap.servers", testUtils.brokerAddress)
      .option("kafka.metadata.max.age.ms", "1")
      .option("subscribePattern", s"$topicPrefix-.*")
      .option("failOnDataLoss", "false")

    val kafka = reader.load()
      .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
      .as[(String, String)]
    val mapped = kafka.map(kv => kv._2.toInt + 1)

    testStream(mapped)(
      makeSureGetOffsetCalled,
      AddKafkaData(Set(topic), 1, 2, 3),
      CheckAnswer(2, 3, 4),
      Execute { query =>
        testUtils.deleteTopic(topic)
        testUtils.createTopic(topic2, partitions = 5)
        eventually(timeout(streamingTimeout)) {
          assert(
            query.lastExecution.logical.collectFirst {
              case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
            }.exists { r =>
              // Ensure the new topic is present and the old topic is gone.
              r.knownPartitions.exists(_.topic == topic2)
            },
            s"query never reconfigured to new topic $topic2")
        }
      },
      AddKafkaData(Set(topic2), 4, 5, 6),
      CheckAnswer(2, 3, 4, 5, 6, 7)
    )
  }
}

class KafkaContinuousSourceStressForDontFailOnDataLossSuite
    extends KafkaSourceStressForDontFailOnDataLossSuite {
  override protected def startStream(ds: Dataset[Int]) = {
    ds.writeStream
      .format("memory")
      .queryName("memory")
      .trigger(Trigger.Continuous("1 second"))
      .start()
  }
} 
Example 38
Source File: BenchmarkQueryTest.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with BeforeAndAfterAll {

  // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting
  // the max iteration of analyzer/optimizer batches.
  assert(Utils.isTesting, "spark.testing is not set to true")

  
  protected override def afterAll(): Unit = {
    try {
      // For debugging, dump some statistics about how much time was spent in various optimizer rules.
      logWarning(RuleExecutor.dumpTimeSpent())
      spark.sessionState.catalog.reset()
    } finally {
      super.afterAll()
    }
  }

  override def beforeAll() {
    super.beforeAll()
    RuleExecutor.resetMetrics()
  }

  protected def checkGeneratedCode(plan: SparkPlan): Unit = {
    val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
    plan foreach {
      case s: WholeStageCodegenExec =>
        codegenSubtrees += s
      case s => s
    }
    codegenSubtrees.toSeq.foreach { subtree =>
      val code = subtree.doCodeGen()._2
      try {
        // Just check the generated code can be properly compiled
        CodeGenerator.compile(code)
      } catch {
        case e: Exception =>
          val msg =
            s"""
               |failed to compile:
               |Subtree:
               |$subtree
               |Generated code:
               |${CodeFormatter.format(code)}
             """.stripMargin
          throw new Exception(msg, e)
      }
    }
  }
} 
Example 39
Source File: SQLUtilsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
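
dfToCols is essentially collect-then-transpose; the same shape change in plain Scala:

val rows = Seq(Seq(1, 2, 3), Seq(4, 5, 6))                     // two rows, three columns
assert(rows.transpose == Seq(Seq(1, 4), Seq(2, 5), Seq(3, 6))) // three columns of two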
Example 40
Source File: ConfigBehaviorSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql

import org.apache.commons.math3.stat.inference.ChiSquareTest

import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


class ConfigBehaviorSuite extends QueryTest with SharedSQLContext {

  import testImplicits._

  test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") {
    // In this test, we run a sort and compute the histogram for partition size post shuffle.
    // With a high sample count, the partition sizes should be more evenly distributed and the
    // chi-sq test value should be low.
    // Also the whole code path for range partitioning as implemented should be deterministic
    // (it uses the partition id as the seed), so this test shouldn't be flaky.

    val numPartitions = 4

    def computeChiSquareTest(): Double = {
      val n = 10000
      // Trigger a sort
      val data = spark.range(0, n, 1, 1).sort('id)
        .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect()

      // Compute histogram for the number of records per partition post sort
      val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray
      assert(dist.length == 4)

      new ChiSquareTest().chiSquare(
        Array.fill(numPartitions) { n.toDouble / numPartitions },
        dist)
    }

    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) {
      // The default chi-sq value should be low
      assert(computeChiSquareTest() < 100)

      withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") {
        // If we only sample one point, the range boundaries will be pretty bad and the
        // chi-sq value would be very high.
        assert(computeChiSquareTest() > 300)
      }
    }
  }

} 
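
For reference, the statistic ChiSquareTest computes here is the sum over partitions of (observed - expected)^2 / expected, with expected = n / numPartitions. A self-contained plain-Scala sketch with worked numbers:

object ChiSquareSketch {
  // Pearson's chi-square statistic: sum of (observed - expected)^2 / expected.
  def chiSq(observed: Array[Long], expected: Array[Double]): Double =
    observed.zip(expected).map { case (o, e) => val d = o - e; d * d / e }.sum

  def main(args: Array[String]): Unit = {
    val expected = Array.fill(4)(2500.0) // 10000 rows spread over 4 partitions
    println(chiSq(Array(2500L, 2500L, 2500L, 2500L), expected)) // 0.0: perfectly even
    println(chiSq(Array(2400L, 2600L, 2450L, 2550L), expected)) // 10.0: mildly skewed
  }
}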
Example 41
Source File: DataFrameImplicitsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }

  test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") {
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF,
      Seq(Row(0), Row(null), Row(2)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF,
      Seq(Row(0L), Row(null), Row(2L)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF,
      Seq(Row(0.0F), Row(null), Row(2.0F)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF,
      Seq(Row(0.0D), Row(null), Row(2.0D)))
  }
} 
Example 42
Source File: DebuggingSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.TestData

class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

  test("DataFrame.debug()") {
    testData.debug()
  }

  test("Dataset.debug()") {
    import testImplicits._
    testData.as[TestData].debug()
  }

  test("debugCodegen") {
    val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
    assert(res.contains("Subtree 1 / 2"))
    assert(res.contains("Subtree 2 / 2"))
    assert(res.contains("Object[]"))
  }

  test("debugCodegenStringSeq") {
    val res = codegenStringSeq(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
    assert(res.length == 2)
    assert(res.forall{ case (subtree, code) =>
      subtree.contains("Range") && code.contains("Object[]")})
  }
} 
Example 43
Source File: RowDataSourceStrategySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from the OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE OR REPLACE TEMPORARY VIEW inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 44
Source File: SaveIntoDataSourceCommandSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkConf
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.test.SharedSQLContext

class SaveIntoDataSourceCommandSuite extends SharedSQLContext {

  test("simpleString is redacted") {
    val URL = "connection.url"
    val PASS = "123"
    val DRIVER = "mydriver"

    val dataSource = DataSource(
      sparkSession = spark,
      className = "jdbc",
      partitionColumns = Nil,
      options = Map("password" -> PASS, "url" -> URL, "driver" -> DRIVER))

    val logicalPlanString = dataSource
      .planForWriting(SaveMode.ErrorIfExists, spark.range(1).logicalPlan)
      .treeString(true)

    assert(!logicalPlanString.contains(URL))
    assert(!logicalPlanString.contains(PASS))
    assert(logicalPlanString.contains(DRIVER))
  }
} 
Example 45
Source File: FileFormatWriterSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("empty file should be skipped while write to file") {
    withTempPath { path =>
      spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
      val partFiles = path.listFiles()
        .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
      assert(partFiles.length === 2)
    }
  }

  test("SPARK-22252: FileFormatWriter should respect the input query schema") {
    withTable("t1", "t2", "t3", "t4") {
      spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
      spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
      checkAnswer(spark.table("t2"), Row(0, 0))

      // Test picking part of the columns when writing.
      spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
      spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
      checkAnswer(spark.table("t4"), Row(0, 0))
    }
  }
} 
Example 46
Source File: HadoopFsRelationSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.test.SharedSQLContext

class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {

  test("sizeInBytes should be the total size of all files") {
    withTempDir { dir =>
      dir.delete()
      spark.range(1000).write.parquet(dir.toString)
      // ignore hidden files
      val allFiles = dir.listFiles(new FilenameFilter {
        override def accept(dir: File, name: String): Boolean = {
          !name.startsWith(".") && !name.startsWith("_")
        }
      })
      val totalSize = allFiles.map(_.length()).sum
      val df = spark.read.parquet(dir.toString)
      assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
    }
  }

  test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") {
    import testImplicits._
    Seq(1.0, 0.5).foreach { compressionFactor =>
      withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString,
        "spark.sql.autoBroadcastJoinThreshold" -> "400") {
        withTempPath { workDir =>
          // the file size is 740 bytes
          val workDirPath = workDir.getAbsolutePath
          val data1 = Seq(100, 200, 300, 400).toDF("count")
          data1.write.parquet(workDirPath + "/data1")
          val df1FromFile = spark.read.parquet(workDirPath + "/data1")
          val data2 = Seq(100, 200, 300, 400).toDF("count")
          data2.write.parquet(workDirPath + "/data2")
          val df2FromFile = spark.read.parquet(workDirPath + "/data2")
          val joinedDF = df1FromFile.join(df2FromFile, Seq("count"))
          if (compressionFactor == 0.5) {
            val bJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case bJoin: BroadcastHashJoinExec => bJoin
            }
            assert(bJoinExec.nonEmpty)
            val smJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case smJoin: SortMergeJoinExec => smJoin
            }
            assert(smJoinExec.isEmpty)
          } else {
            // compressionFactor is 1.0
            val bJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case bJoin: BroadcastHashJoinExec => bJoin
            }
            assert(bJoinExec.isEmpty)
            val smJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case smJoin: SortMergeJoinExec => smJoin
            }
            assert(smJoinExec.nonEmpty)
          }
        }
      }
    }
  }
} 
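
The SPARK-22790 test hinges on simple arithmetic: the planner's size estimate is roughly fileSize * fileCompressionFactor, compared against autoBroadcastJoinThreshold. With the numbers used above (assuming the ~740-byte file the comment mentions):

val fileSize = 740.0  // bytes on disk, per the comment in the test
val threshold = 400.0 // spark.sql.autoBroadcastJoinThreshold set by the test
assert(fileSize * 0.5 < threshold) // 370 < 400 -> small enough to broadcast (BroadcastHashJoin)
assert(fileSize * 1.0 > threshold) // 740 > 400 -> too big, falls back to SortMergeJoin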
Example 47
Source File: ParquetFileFormatSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLContext {

  test("read parquet footers in parallel") {
    def testReadFooters(ignoreCorruptFiles: Boolean): Unit = {
      withTempDir { dir =>
        val fs = FileSystem.get(sparkContext.hadoopConfiguration)
        val basePath = dir.getCanonicalPath

        val path1 = new Path(basePath, "first")
        val path2 = new Path(basePath, "second")
        val path3 = new Path(basePath, "third")

        spark.range(1).toDF("a").coalesce(1).write.parquet(path1.toString)
        spark.range(1, 2).toDF("a").coalesce(1).write.parquet(path2.toString)
        spark.range(2, 3).toDF("a").coalesce(1).write.json(path3.toString)

        val fileStatuses =
          Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten

        val footers = ParquetFileFormat.readParquetFootersInParallel(
          sparkContext.hadoopConfiguration, fileStatuses, ignoreCorruptFiles)

        assert(footers.size == 2)
      }
    }

    testReadFooters(true)
    val exception = intercept[java.io.IOException] {
      testReadFooters(false)
    }
    assert(exception.getMessage().contains("Could not read footer for file"))
  }
} 
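
readParquetFootersInParallel is internal, but the same toggle is exposed to users as a SQL conf. A sketch as spark-shell input, assuming a directory that mixes parquet files with one non-parquet file as in the test; the path is hypothetical:

// With the flag on, files whose footers cannot be read are skipped instead of failing the job.
spark.conf.set("spark.sql.files.ignoreCorruptFiles", "true")
spark.read.parquet("/path/with/mixed/files").count()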
Example 48
Source File: DDLTestSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.sources

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class DDLScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimpleDDLScan(parameters("from").toInt, parameters("TO").toInt, parameters("Table"))(sqlContext)
  }
}

case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlContext: SQLContext)
  extends BaseRelation with TableScan {

  override def schema: StructType =
    StructType(Seq( // StructType represents a table; StructField represents a column
      StructField("intType", IntegerType, nullable = false,
        new MetadataBuilder().putString("comment", s"test comment $table").build()),
      StructField("stringType", StringType, nullable = false),
      StructField("dateType", DateType, nullable = false),
      StructField("timestampType", TimestampType, nullable = false),
      StructField("doubleType", DoubleType, nullable = false),
      StructField("bigintType", LongType, nullable = false),
      StructField("tinyintType", ByteType, nullable = false),
      StructField("decimalType", DecimalType.USER_DEFAULT, nullable = false),
      StructField("fixedDecimalType", DecimalType(5, 1), nullable = false),
      StructField("binaryType", BinaryType, nullable = false),
      StructField("booleanType", BooleanType, nullable = false),
      StructField("smallIntType", ShortType, nullable = false),
      StructField("floatType", FloatType, nullable = false),
      StructField("mapType", MapType(StringType, StringType)),
      StructField("arrayType", ArrayType(StringType)),
      StructField("structType",//StructType代表一张表,StructField代表一个字段
        StructType(StructField("f1", StringType) :: StructField("f2", IntegerType) :: Nil
        )
      )
    ))
  // Rows are already produced as InternalRow, so no conversion is needed.
  override def needConversion: Boolean = false

  override def buildScan(): RDD[Row] = {
    // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
    sqlContext.sparkContext.parallelize(from to to).map { e =>
      InternalRow(UTF8String.fromString(s"people$e"), e * 2)
    }.asInstanceOf[RDD[Row]]
  }
}

class DDLTestSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
      |CREATE TEMPORARY TABLE ddlPeople
      |USING org.apache.spark.sql.sources.DDLScanSource
      |OPTIONS (
      |  From '1',
      |  To '10',
      |  Table 'test1'
      |)
      """.stripMargin)
  }

  sqlTest(
      "describe ddlPeople",
      Seq(
        Row("intType", "int", "test comment test1"),
        Row("stringType", "string", ""),
        Row("dateType", "date", ""),
        Row("timestampType", "timestamp", ""),
        Row("doubleType", "double", ""),
        Row("bigintType", "bigint", ""),
        Row("tinyintType", "tinyint", ""),
        Row("decimalType", "decimal(10,0)", ""),
        Row("fixedDecimalType", "decimal(5,1)", ""),
        Row("binaryType", "binary", ""),
        Row("booleanType", "boolean", ""),
        Row("smallIntType", "smallint", ""),
        Row("floatType", "float", ""),
        Row("mapType", "map<string,string>", ""),
        Row("arrayType", "array<string>", ""),
        Row("structType", "struct<f1:string,f2:int>", "")
      ))
  test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
    val attributes = sql("describe ddlPeople")
      .queryExecution.executedPlan.output
    assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
    assert(attributes.map(_.dataType).toSet === Set(StringType))
  }
} 
Example 49
Source File: ParquetProtobufCompatibilitySuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 50
Source File: DataSourceScanExecRedactionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {

  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.redaction.string.regex", "file:/[\\w_]+")

  test("treeString is redacted") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)

      val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
        .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head
      assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/")))

      assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName))
      assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName))
      assert(!df.queryExecution.toString.contains(rootPath.getName))
      assert(!df.queryExecution.simpleString.contains(rootPath.getName))

      val replacement = "*********"
      assert(df.queryExecution.sparkPlan.treeString(verbose = true).contains(replacement))
      assert(df.queryExecution.executedPlan.treeString(verbose = true).contains(replacement))
      assert(df.queryExecution.toString.contains(replacement))
      assert(df.queryExecution.simpleString.contains(replacement))
    }
  }

  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
    queryExecution.toString.contains(msg) ||
    queryExecution.simpleString.contains(msg) ||
    queryExecution.stringWithStats.contains(msg)
  }

  test("explain is redacted using SQLConf") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)
      val replacement = "*********"

      // Respect SparkConf and replace file:/
      assert(isIncluded(df.queryExecution, replacement))

      assert(isIncluded(df.queryExecution, "FileScan"))
      assert(!isIncluded(df.queryExecution, "file:/"))

      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
        // Respect SQLConf and replace FileScan
        assert(isIncluded(df.queryExecution, replacement))

        assert(!isIncluded(df.queryExecution, "FileScan"))
        assert(isIncluded(df.queryExecution, "file:/"))
      }
    }
  }

} 
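
Both redaction layers seen above are plain configuration; a sketch, assuming Spark 2.3's key names:

// Core-level redaction must be set on the SparkConf before the session starts,
// as this suite does via sparkConf:
//   --conf spark.redaction.string.regex=file:/[\\w_]+
// The SQL-level pattern can be changed at runtime:
spark.conf.set("spark.sql.redaction.string.regex", "(?i)FileScan")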
Example 51
Source File: QueryExecutionSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    spark.experimental.extraStrategies = Seq(
        new SparkStrategy {
          override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
        })

    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation())

    // Nothing!
    assert(qe.toString.contains("OneRowRelation"))

    // Throw an AnalysisException - this should be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new AnalysisException("exception")
      })
    assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))

    // Throw an Error - this should not be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new Error("error")
      })
    val error = intercept[Error](qe.toString)
    assert(error.getMessage.contains("error"))
  }
} 
Example 52
Source File: BatchEvalPythonExecSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.python

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.api.python.{PythonEvalType, PythonFunction}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.BooleanType

class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  override def beforeAll(): Unit = {
    super.beforeAll()
    spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
  }

  override def afterAll(): Unit = {
    spark.sessionState.functionRegistry.dropFunction(FunctionIdentifier("dummyPythonUDF"))
    super.afterAll()
  }

  test("Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: AttributeReference),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Nested Python UDF: push down deterministic FilterExec predicates") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: no push down on non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("b > 4 and dummyPythonUDF(a) and rand() > 0.3")
    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF: push down on deterministic predicates after the first non-deterministic") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
      .where("dummyPythonUDF(a) and rand() > 0.3 and b > 4")

    val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
      case f @ FilterExec(
          And(_: AttributeReference, _: GreaterThan),
          InputAdapter(_: BatchEvalPythonExec)) => f
      case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
    }
    assert(qualifiedPlanNodes.size == 2)
  }

  test("Python UDF refers to the attributes from more than one child") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = Seq(("Hello", 4)).toDF("c", "d")
    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
      case b: BatchEvalPythonExec => b
    }
    assert(qualifiedPlanNodes.size == 1)
  }
}

// This Python UDF is a dummy used only for testing; it cannot actually be executed.
class DummyUDF extends PythonFunction(
  command = Array[Byte](),
  envVars = Map("" -> "").asJava,
  pythonIncludes = ArrayBuffer("").asJava,
  pythonExec = "",
  pythonVer = "",
  broadcastVars = null,
  accumulator = null)

class MyDummyPythonUDF extends UserDefinedPythonFunction(
  name = "dummyUDF",
  func = new DummyUDF,
  dataType = BooleanType,
  pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
  udfDeterministic = true) 
Example 53
Source File: TakeOrderedAndProjectSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
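  // Adds a no-op filter on top of the plan so that executeCollect() is not called directly
  // on the TakeOrderedAndProjectExec operator under test.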
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
Example 54
Source File: SparkPlanSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlanSuite extends QueryTest with SharedSQLContext {

  test("SPARK-21619 execution of a canonicalized plan should fail") {
    val plan = spark.range(10).queryExecution.executedPlan.canonicalized

    intercept[IllegalStateException] { plan.execute() }
    intercept[IllegalStateException] { plan.executeCollect() }
    intercept[IllegalStateException] { plan.executeCollectPublic() }
    intercept[IllegalStateException] { plan.executeToIterator() }
    intercept[IllegalStateException] { plan.executeBroadcast() }
    intercept[IllegalStateException] { plan.executeTake(1) }
  }

} 
Example 55
Source File: SameResultSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class SameResultSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("FileSourceScanExec: different orders of data filters and partition filters") {
    withTempPath { path =>
      val tmpDir = path.getCanonicalPath
      spark.range(10)
        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
        .write
        .partitionBy("a", "b")
        .parquet(tmpDir)
      val df = spark.read.parquet(tmpDir)
      // partition filters: a > 1 AND b < 9
      // data filters: c > 1 AND d < 9
      val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
      val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))
      assert(plan1.sameResult(plan2))
    }
  }

  private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {
    df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
      .asInstanceOf[FileSourceScanExec]
  }

  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
    val df1 = spark.range(10).agg(sum($"id"))
    val df2 = spark.range(10).agg(sum($"id"))
    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))

    val df3 = spark.range(10).agg(sumDistinct($"id"))
    val df4 = spark.range(10).agg(sumDistinct($"id"))
    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
  }
} 
Example 56
Source File: SparkPlannerSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data, _) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
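      // ReturnAnswer, Union, and the two LocalRelations each bump `planned`, hence 4.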
      assert(planned === 4)
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 57
Source File: DataFrameTungstenSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }

  test("primitive data type accesses in persist data") {
    val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
      31.25.toFloat, 63.75, null)
    val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
      FloatType, DoubleType, IntegerType)
    val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
      StructField(s"col$index", dataType, true)
    }
    val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
    val df = spark.createDataFrame(rdd, StructType(schemas))
    val row = df.persist.take(1).apply(0)
    checkAnswer(df, row)
  }

  test("access cache multiple times") {
    val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache
    df0.count
    val df1 = df0.filter("x > 1")
    checkAnswer(df1, Seq(Row(2), Row(3)))
    val df2 = df0.filter("x > 2")
    checkAnswer(df2, Row(3))

    val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache
    for (_ <- 0 to 2) {
      val df11 = df10.filter("x > 5")
      checkAnswer(df11, Row(6))
    }
  }

  test("access only some columns of all of the columns") {
    val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d")
    df.cache
    df.count
    assert(df.filter("d < 3").count == 1)
  }
} 
Example 58
Source File: RowSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 59
Source File: ExtraStrategiesSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    sparkContext.parallelize(Seq(unsafeRow))
  }

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 60
Source File: DatasetCacheSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    // default storage level
    ds1.persist()
    ds2.cache()
    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
    // unpersist
    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    // non-default storage level
    ds1.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
    // joined Dataset should not be persisted
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    assert(joined.storageLevel == StorageLevel.NONE)
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkDataset(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkDataset(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkDataset(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
} 
Example 61
Source File: ResolvedDataSourceSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.test.SharedSQLContext

class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
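  // Resolves a data source name (short name, alias, or fully-qualified class) to its providing class.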
  private def getProvidingClass(name: String): Class[_] =
    DataSource(
      sparkSession = spark,
      className = name,
      options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID)
    ).providingClass

  test("jdbc") {
    assert(
      getProvidingClass("jdbc") ===
      classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") ===
      classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
    assert(
      getProvidingClass("org.apache.spark.sql.jdbc") ===
        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
  }

  test("json") {
    assert(
      getProvidingClass("json") ===
      classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
  }

  test("parquet") {
    assert(
      getProvidingClass("parquet") ===
      classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
  }

  test("csv") {
    assert(
      getProvidingClass("csv") ===
        classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat])
    assert(
      getProvidingClass("com.databricks.spark.csv") ===
        classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat])
  }

  test("error message for unknown data sources") {
    val error1 = intercept[AnalysisException] {
      getProvidingClass("avro")
    }
    assert(error1.getMessage.contains("Failed to find data source: avro."))

    val error2 = intercept[AnalysisException] {
      getProvidingClass("com.databricks.spark.avro")
    }
    assert(error2.getMessage.contains("Failed to find data source: com.databricks.spark.avro."))

    val error3 = intercept[ClassNotFoundException] {
      getProvidingClass("asfdwefasdfasdf")
    }
    assert(error3.getMessage.contains("Failed to find data source: asfdwefasdfasdf."))
  }
} 
Example 62
Source File: DDLSourceLoadSuite.scala    From Spark-2.3.1   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


// Please note that META-INF/services had to be modified in the test directory for this to work.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name - internal data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fluet da Bomb").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb"))
  }

  test("data sources with the same name - internal data source/external data source") {
    assert(spark.read.format("datasource").load().schema ==
      StructType(Seq(StructField("longType", LongType, nullable = false))))
  }

  test("data sources with the same name - external data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fake external source").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fake external source"))
  }

  test("load data source from format alias") {
    assert(spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }

  test("specify full classname with duplicate formats") {
    assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }
}


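// FakeSourceOne and FakeSourceTwo register the same short name ("Fluet da Bomb"), which is what
// triggers the "Multiple sources found" error asserted above.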
class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceFour extends RelationProvider with DataSourceRegister {

  def shortName(): String = "datasource"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("longType", LongType, nullable = false)))
    }
} 
Example 63
Source File: SQLUtilsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
Example 64
Source File: DataFrameImplicitsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }
} 
Example 65
Source File: SQLCompatibilityFunctionSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import java.math.BigDecimal
import java.sql.Timestamp

import org.apache.spark.sql.test.SharedSQLContext


class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {

  test("ifnull") {
    checkAnswer(
      sql("SELECT ifnull(null, 'x'), ifnull('y', 'x'), ifnull(null, null)"),
      Row("x", "y", null))

    // Type coercion
    checkAnswer(
      sql("SELECT ifnull(1, 2.1d), ifnull(null, 2.1d)"),
      Row(1.0, 2.1))
  }

  test("nullif") {
    checkAnswer(
      sql("SELECT nullif('x', 'x'), nullif('x', 'y')"),
      Row(null, "x"))

    // Type coercion
    checkAnswer(
      sql("SELECT nullif(1, 2.1d), nullif(1, 1.0d)"),
      Row(1.0, null))
  }

  test("nvl") {
    checkAnswer(
      sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
      Row("x", "y", null))

    // Type coercion
    checkAnswer(
      sql("SELECT nvl(1, 2.1d), nvl(null, 2.1d)"),
      Row(1.0, 2.1))
  }

  test("nvl2") {
    checkAnswer(
      sql("SELECT nvl2(null, 'x', 'y'), nvl2('n', 'x', 'y'), nvl2(null, null, null)"),
      Row("y", "x", null))

    // Type coercion
    checkAnswer(
      sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
      Row(2.1, 1.0))
  }

  test("SPARK-16730 cast alias functions for Hive compatibility") {
    checkAnswer(
      sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"),
      Row(true, 1.toByte, 1.toShort, 1, 1L))

    checkAnswer(
      sql("SELECT float(1), double(1), decimal(1)"),
      Row(1.toFloat, 1.0, new BigDecimal(1)))

    checkAnswer(
      sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"),
      Row(new java.util.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0)))

    checkAnswer(
      sql("SELECT string(1)"),
      Row("1"))

    // Error handling: only one argument
    val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage
    assert(errorMsg.contains("Function string accepts only one argument"))
  }
} 
Example 66
Source File: DebuggingSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.TestData

class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

  test("DataFrame.debug()") {
    testData.debug()
  }

  test("Dataset.debug()") {
    import testImplicits._
    testData.as[TestData].debug()
  }

  test("debugCodegen") {
    val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
    assert(res.contains("Subtree 1 / 2"))
    assert(res.contains("Subtree 2 / 2"))
    assert(res.contains("Object[]"))
  }
} 
Example 67
Source File: RowDataSourceStrategySuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 68
Source File: HadoopFsRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {

  test("sizeInBytes should be the total size of all files") {
    withTempDir { dir =>
      dir.delete()
      spark.range(1000).write.parquet(dir.toString)
      // ignore hidden files
      val allFiles = dir.listFiles(new FilenameFilter {
        override def accept(dir: File, name: String): Boolean = {
          !name.startsWith(".")
        }
      })
      val totalSize = allFiles.map(_.length()).sum
      val df = spark.read.parquet(dir.toString)
      assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize))
    }
  }
} 
Example 69
Source File: ParquetInteroperabilitySuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    import ParquetCompatibilityTest._

    // This test case writes two Parquet files, both representing the following Catalyst schema
    //
    //   StructType(
    //     StructField(
    //       "f",
    //       ArrayType(IntegerType, containsNull = false),
    //       nullable = false))
    //
    // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the
    // other one comes with parquet-protobuf style 1-level unannotated primitive field.
    withTempDir { dir =>
      val avroStylePath = new File(dir, "avro-style").getCanonicalPath
      val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath

      val avroStyleSchema =
        """message avro_style {
          |  required group f (LIST) {
          |    repeated int32 array;
          |  }
          |}
        """.stripMargin

      writeDirect(avroStylePath, avroStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.group {
              rc.field("array", 0) {
                rc.addInteger(0)
                rc.addInteger(1)
              }
            }
          }
        }
      })

      logParquetSchema(avroStylePath)

      val protobufStyleSchema =
        """message protobuf_style {
          |  repeated int32 f;
          |}
        """.stripMargin

      writeDirect(protobufStylePath, protobufStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.addInteger(2)
            rc.addInteger(3)
          }
        }
      })

      logParquetSchema(protobufStylePath)

      checkAnswer(
        spark.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3))))
    }
  }
} 
Example 70
Source File: ParquetEncodingSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.test.SharedSQLContext

// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
// writer abstractions. Revisit.
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
  import testImplicits._

  val ROW = ((1).toByte, 2, 3L, "abc")
  val NULL_ROW = (
    null.asInstanceOf[java.lang.Byte],
    null.asInstanceOf[Integer],
    null.asInstanceOf[java.lang.Long],
    null.asInstanceOf[String])

  test("All Types Dictionary") {
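    // Write n copies of the same row so the columns are dictionary-encoded, then verify the
    // values through the vectorized Parquet reader.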
    (1 :: 1000 :: Nil).foreach { n => {
      withTempPath { dir =>
        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).getByte(i) == 1)
          assert(batch.column(1).getInt(i) == 2)
          assert(batch.column(2).getLong(i) == 3)
          assert(batch.column(3).getUTF8String(i).toString == "abc")
          i += 1
        }
        reader.close()
      }
    }}
  }

  test("All Types Null") {
    (1 :: 100 :: Nil).foreach { n => {
      withTempPath { dir =>
        val data = List.fill(n)(NULL_ROW).toDF
        data.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).isNullAt(i))
          assert(batch.column(1).isNullAt(i))
          assert(batch.column(2).isNullAt(i))
          assert(batch.column(3).isNullAt(i))
          i += 1
        }
        reader.close()
      }}
    }
  }

  test("Read row group containing both dictionary and plain encoded pages") {
    withSQLConf(ParquetOutputFormat.DICTIONARY_PAGE_SIZE -> "2048",
      ParquetOutputFormat.PAGE_SIZE -> "4096") {
      withTempPath { dir =>
        // In order to explicitly test for SPARK-14217, we set the parquet dictionary and page size
        // such that the following data spans across 3 pages (within a single row group) where the
        // first page is dictionary encoded and the remaining two are plain encoded.
        val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString))
        data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file, null)
        val column = reader.resultBatch().column(0)
        assert(reader.nextBatch())

        (0 until 512).foreach { i =>
          assert(column.getUTF8String(3 * i).toString == i.toString)
          assert(column.getUTF8String(3 * i + 1).toString == i.toString)
          assert(column.getUTF8String(3 * i + 2).toString == i.toString)
        }
        reader.close()
      }
    }
  }
} 
Example 71
Source File: ParquetProtobufCompatibilitySuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 72
Source File: ExchangeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

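    // IdentityBroadcastMode is only compatible with itself; HashedRelationBroadcastMode instances
    // are compatible only when they are built over the same keys.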
    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 =
      HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
} 
Example 73
Source File: TakeOrderedAndProjectSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
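  // Adds a no-op filter on top of the plan so that executeCollect() is not called directly
  // on the TakeOrderedAndProjectExec operator under test.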
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
Example 74
Source File: SparkPlannerSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
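      // ReturnAnswer, Union, and the two LocalRelations each bump `planned`, hence 4.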
      assert(planned === 4)
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 75
Source File: WholeStageCodegenSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
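    // The whole range/filter/project pipeline should be fused into a single WholeStageCodegenExec node.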
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
Example 76
Source File: FileStreamSourceSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.io.File
import java.net.URI

import scala.util.Random

import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext {

  import FileStreamSource._

  test("SeenFilesMap") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 5)
    assert(map.size == 1)
    map.purge()
    assert(map.size == 1)

    // Add a new entry; purge should be a no-op, since the gap is exactly 10 ms.
    map.add("b", 15)
    assert(map.size == 2)
    map.purge()
    assert(map.size == 2)

    // Add a new entry that's more than 10 ms newer than the first entry. We should be able to purge now.
    map.add("c", 16)
    assert(map.size == 3)
    map.purge()
    assert(map.size == 2)

    // Overriding an existing entry shouldn't change the size
    map.add("c", 25)
    assert(map.size == 2)

    // Not a new file because we have seen c before
    assert(!map.isNewFile("c", 20))

    // Not a new file because timestamp is too old
    assert(!map.isNewFile("d", 5))

    // Finally a new file: never seen and not too old
    assert(map.isNewFile("e", 20))
  }

  test("SeenFilesMap should only consider a file old if it is earlier than last purge time") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 20)
    assert(map.size == 1)

    // Timestamp 5 should still be considered a new file because the purge time should be 0
    assert(map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))

    // Once purged, the purge time should be 10, and b would then be an old file if its timestamp is less than 10.
    map.purge()
    assert(!map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))
  }

  testWithUninterruptibleThread("do not recheck that files exist during getBatch") {
    withTempDir { temp =>
      spark.conf.set(
        s"fs.$scheme.impl",
        classOf[ExistsThrowsExceptionFileSystem].getName)
      // add the metadata entries as a pre-req
      val dir = new File(temp, "dir") // use a non-existent directory to test whether the log creates the dir
      val metadataLog =
        new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath)
      assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0))))

      val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil,
        dir.getAbsolutePath, Map.empty)
      // this method should throw an exception if `fs.exists` is called during resolveRelation
      newSource.getBatch(None, LongOffset(1))
    }
  }
}


// Fake FileSystem for the test above: `exists` throws, so getBatch must not re-check that
// files exist during DataSource.resolveRelation.
class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem {
  override def getScheme: String = scheme

  override def exists(f: Path): Boolean = {
    throw new IllegalArgumentException("Exists shouldn't have been called!")
  }

  // Simply return an empty file status to avoid a NotFoundException.
  override def listStatus(file: Path): Array[FileStatus] = {
    val emptyFile = new FileStatus()
    emptyFile.setPath(file)
    Array(emptyFile)
  }
}

object ExistsThrowsExceptionFileSystem {
  val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs"
} 
Example 77
Source File: DataFrameTungstenSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }
} 
Example 78
Source File: ExtraStrategiesSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    sparkContext.parallelize(Seq(unsafeRow))
  }

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 79
Source File: DatasetCacheSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    // default storage level
    ds1.persist()
    ds2.cache()
    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
    // unpersist
    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    // non-default storage level
    ds1.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
    // joined Dataset should not be persisted
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    assert(joined.storageLevel == StorageLevel.NONE)
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure, the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkDataset(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkDataset(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkDataset(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
} 
Example 80
Source File: DataFrameComplexTypeSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("UDF on struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(struct($"a").as("s")).select(f($"s.a")).collect()
  }

  test("UDF on named_struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
  }

  test("UDF on array") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
  }

  test("UDF on map") {
    val f = udf((a: String) => a)
    val df = Seq("a" -> 1).toDF("a", "b")
    df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect()
  }

  test("SPARK-12477 accessing null element in array field") {
    val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"),
      Seq(Some(1), None, Some(2))))).toDF("s", "i")
    val nullStringRow = df.selectExpr("s[1]").collect()(0)
    assert(nullStringRow == org.apache.spark.sql.Row(null))
    val nullIntRow = df.selectExpr("i[1]").collect()(0)
    assert(nullIntRow == org.apache.spark.sql.Row(null))
  }

  test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") {
    val ds100_5 = Seq(S100_5()).toDS()
    ds100_5.rdd.count
  }
}

class S100(
  val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4",
  val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8",
  val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12",
  val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16",
  val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20",
  val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24",
  val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28",
  val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32",
  val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36",
  val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40",
  val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44",
  val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48",
  val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52",
  val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56",
  val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60",
  val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64",
  val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68",
  val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72",
  val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76",
  val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80",
  val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84",
  val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88",
  val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92",
  val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96",
  val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100")
extends DefinedByConstructorParams

case class S100_5(
  s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(),
  s4: S100 = new S100(), s5: S100 = new S100()) 
Example 81
Source File: DDLSourceLoadSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Note: META-INF/services under the test resources had to be modified to register these fake data sources for this suite to work.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    assert(spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }

  test("specify full classname with duplicate formats") {
    assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
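The note at the top of DDLSourceLoadSuite refers to Java's ServiceLoader mechanism: short names such as "Fluet da Bomb" and "gathering quorum" only resolve because the fake sources implement DataSourceRegister and are listed in a provider-configuration file under the test resources. A sketch of what that registration looks like (the exact resource path is an assumption for illustration):

# src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
org.apache.spark.sql.sources.FakeSourceOne
org.apache.spark.sql.sources.FakeSourceTwo
org.apache.spark.sql.sources.FakeSourceThree

With two providers registered under the same short name ("Fluet da Bomb"), lookup by that alias is ambiguous, which is why the first test expects a RuntimeException.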
Example 82
Source File: PartitionedWriteSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = spark.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = spark.range(100)
    val df = base.union(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
Example 83
Source File: MiscFunctionsSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reflect and java_method") {
    val df = Seq((1, "one")).toDF("a", "b")
    val className = ReflectClass.getClass.getName.stripSuffix("$")
    checkAnswer(
      df.selectExpr(
        s"reflect('$className', 'method1', a, b)",
        s"java_method('$className', 'method1', a, b)"),
      Row("m1one", "m1one"))
  }
}

object ReflectClass {
  def method1(v1: Int, v2: String): String = "m" + v1 + v2
} 
Example 84
Source File: BenchmarkQueryTest.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

abstract class BenchmarkQueryTest extends QueryTest with SharedSQLContext with BeforeAndAfterAll {

  // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting
  // the max iteration of analyzer/optimizer batches.
  assert(Utils.isTesting, "spark.testing is not set to true")

  
  protected override def afterAll(): Unit = {
    try {
      // For debugging, dump some statistics about how much time was spent in various optimizer rules
      logWarning(RuleExecutor.dumpTimeSpent())
      spark.sessionState.catalog.reset()
    } finally {
      super.afterAll()
    }
  }

  override def beforeAll() {
    super.beforeAll()
    RuleExecutor.resetMetrics()
  }

  protected def checkGeneratedCode(plan: SparkPlan): Unit = {
    val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]()
    plan foreach {
      case s: WholeStageCodegenExec =>
        codegenSubtrees += s
      case _ =>
    }
    codegenSubtrees.toSeq.foreach { subtree =>
      val code = subtree.doCodeGen()._2
      try {
        // Just check the generated code can be properly compiled
        CodeGenerator.compile(code)
      } catch {
        case e: Exception =>
          val msg =
            s"""
               |failed to compile:
               |Subtree:
               |$subtree
               |Generated code:
               |${CodeFormatter.format(code)}
             """.stripMargin
          throw new Exception(msg, e)
      }
    }
  }
} 
Example 85
Source File: SQLUtilsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
Example 86
Source File: DataFrameImplicitsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }

  test("SPARK-19959: df[java.lang.Long].collect includes null throws NullPointerException") {
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Integer](0, null, 2), 1).toDF,
      Seq(Row(0), Row(null), Row(2)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Long](0L, null, 2L), 1).toDF,
      Seq(Row(0L), Row(null), Row(2L)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Float](0.0F, null, 2.0F), 1).toDF,
      Seq(Row(0.0F), Row(null), Row(2.0F)))
    checkAnswer(sparkContext.parallelize(Seq[java.lang.Double](0.0D, null, 2.0D), 1).toDF,
      Seq(Row(0.0D), Row(null), Row(2.0D)))
  }
} 
Example 87
Source File: ExplainSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class ExplainSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  
  private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = {
    val output = new java.io.ByteArrayOutputStream()
    Console.withOut(output) {
      df.explain(extended = false)
    }
    for (key <- keywords) {
      assert(output.toString.contains(key))
    }
  }

  test("SPARK-23034 show rdd names in RDD scan nodes (Dataset)") {
    val rddWithName = spark.sparkContext.parallelize(Row(1, "abc") :: Nil).setName("testRdd")
    val df = spark.createDataFrame(rddWithName, StructType.fromDDL("c0 int, c1 string"))
    checkKeywordsExistsInExplain(df, keywords = "Scan ExistingRDD testRdd")
  }

  test("SPARK-23034 show rdd names in RDD scan nodes (DataFrame)") {
    val rddWithName = spark.sparkContext.parallelize(ExplainSingleData(1) :: Nil).setName("testRdd")
    val df = spark.createDataFrame(rddWithName)
    checkKeywordsExistsInExplain(df, keywords = "Scan testRdd")
  }

  test("SPARK-24850 InMemoryRelation string representation does not include cached plan") {
    val df = Seq(1).toDF("a").cache()
    checkKeywordsExistsInExplain(df,
      keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)")
  }
}

case class ExplainSingleData(id: Int) 
Example 88
Source File: DebuggingSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.TestData

class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

  test("DataFrame.debug()") {
    testData.debug()
  }

  test("Dataset.debug()") {
    import testImplicits._
    testData.as[TestData].debug()
  }

  test("debugCodegen") {
    val res = codegenString(spark.range(10).groupBy(col("id") * 2).count()
      .queryExecution.executedPlan)
    assert(res.contains("Subtree 1 / 2"))
    assert(res.contains("Subtree 2 / 2"))
    assert(res.contains("Object[]"))
  }

  test("debugCodegenStringSeq") {
    val res = codegenStringSeq(spark.range(10).groupBy(col("id") * 2).count()
      .queryExecution.executedPlan)
    assert(res.length == 2)
    assert(res.forall{ case (subtree, code) =>
      subtree.contains("Range") && code.contains("Object[]")})
  }
} 
Example 89
Source File: RowDataSourceStrategySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE OR REPLACE TEMPORARY VIEW inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 90
Source File: SaveIntoDataSourceCommandSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.SparkConf
import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.test.SharedSQLContext

class SaveIntoDataSourceCommandSuite extends SharedSQLContext {

  test("simpleString is redacted") {
    val URL = "connection.url"
    val PASS = "mypassword"
    val DRIVER = "mydriver"

    val dataSource = DataSource(
      sparkSession = spark,
      className = "jdbc",
      partitionColumns = Nil,
      options = Map("password" -> PASS, "url" -> URL, "driver" -> DRIVER))

    val logicalPlanString = dataSource
      .planForWriting(SaveMode.ErrorIfExists, spark.range(1).logicalPlan)
      .treeString(true)

    assert(!logicalPlanString.contains(URL))
    assert(!logicalPlanString.contains(PASS))
    assert(logicalPlanString.contains(DRIVER))
  }
} 
Example 91
Source File: FileFormatWriterSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

class FileFormatWriterSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("empty file should be skipped while write to file") {
    withTempPath { path =>
      spark.range(100).repartition(10).where("id = 50").write.parquet(path.toString)
      val partFiles = path.listFiles()
        .filter(f => f.isFile && !f.getName.startsWith(".") && !f.getName.startsWith("_"))
      assert(partFiles.length === 2)
    }
  }

  test("SPARK-22252: FileFormatWriter should respect the input query schema") {
    withTable("t1", "t2", "t3", "t4") {
      spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1")
      spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2")
      checkAnswer(spark.table("t2"), Row(0, 0))

      // Test picking part of the columns when writing.
      spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3")
      spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4")
      checkAnswer(spark.table("t4"), Row(0, 0))
    }
  }
} 
Example 92
Source File: HadoopFsRelationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.test.SharedSQLContext

class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {

  test("sizeInBytes should be the total size of all files") {
    withTempDir{ dir =>
      dir.delete()
      spark.range(1000).write.parquet(dir.toString)
      // ignore hidden files
      val allFiles = dir.listFiles(new FilenameFilter {
        override def accept(dir: File, name: String): Boolean = {
          !name.startsWith(".") && !name.startsWith("_")
        }
      })
      val totalSize = allFiles.map(_.length()).sum
      val df = spark.read.parquet(dir.toString)
      assert(df.queryExecution.logical.stats.sizeInBytes === BigInt(totalSize))
    }
  }

  test("SPARK-22790: spark.sql.sources.compressionFactor takes effect") {
    import testImplicits._
    Seq(1.0, 0.5).foreach { compressionFactor =>
      withSQLConf("spark.sql.sources.fileCompressionFactor" -> compressionFactor.toString,
        "spark.sql.autoBroadcastJoinThreshold" -> "400") {
        withTempPath { workDir =>
          // the file size is 740 bytes
          val workDirPath = workDir.getAbsolutePath
          val data1 = Seq(100, 200, 300, 400).toDF("count")
          data1.write.parquet(workDirPath + "/data1")
          val df1FromFile = spark.read.parquet(workDirPath + "/data1")
          val data2 = Seq(100, 200, 300, 400).toDF("count")
          data2.write.parquet(workDirPath + "/data2")
          val df2FromFile = spark.read.parquet(workDirPath + "/data2")
          val joinedDF = df1FromFile.join(df2FromFile, Seq("count"))
          if (compressionFactor == 0.5) {
            val bJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case bJoin: BroadcastHashJoinExec => bJoin
            }
            assert(bJoinExec.nonEmpty)
            val smJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case smJoin: SortMergeJoinExec => smJoin
            }
            assert(smJoinExec.isEmpty)
          } else {
            // compressionFactor is 1.0
            val bJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case bJoin: BroadcastHashJoinExec => bJoin
            }
            assert(bJoinExec.isEmpty)
            val smJoinExec = joinedDF.queryExecution.executedPlan.collect {
              case smJoin: SortMergeJoinExec => smJoin
            }
            assert(smJoinExec.nonEmpty)
          }
        }
      }
    }
  }
} 
Example 93
Source File: ParquetFileFormatSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkException
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext

class ParquetFileFormatSuite extends QueryTest with ParquetTest with SharedSQLContext {

  test("read parquet footers in parallel") {
    def testReadFooters(ignoreCorruptFiles: Boolean): Unit = {
      withTempDir { dir =>
        val fs = FileSystem.get(spark.sessionState.newHadoopConf())
        val basePath = dir.getCanonicalPath

        val path1 = new Path(basePath, "first")
        val path2 = new Path(basePath, "second")
        val path3 = new Path(basePath, "third")

        spark.range(1).toDF("a").coalesce(1).write.parquet(path1.toString)
        spark.range(1, 2).toDF("a").coalesce(1).write.parquet(path2.toString)
        spark.range(2, 3).toDF("a").coalesce(1).write.json(path3.toString)

        val fileStatuses =
          Seq(fs.listStatus(path1), fs.listStatus(path2), fs.listStatus(path3)).flatten

        val footers = ParquetFileFormat.readParquetFootersInParallel(
          spark.sessionState.newHadoopConf(), fileStatuses, ignoreCorruptFiles)

        assert(footers.size == 2)
      }
    }

    testReadFooters(true)
    val exception = intercept[SparkException] {
      testReadFooters(false)
    }.getCause
    assert(exception.getMessage().contains("Could not read footer for file"))
  }
} 
Example 94
Source File: ParquetEncodingSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.test.SharedSQLContext

// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
// writer abstractions. Revisit.
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
  import testImplicits._

  val ROW = ((1).toByte, 2, 3L, "abc")
  val NULL_ROW = (
    null.asInstanceOf[java.lang.Byte],
    null.asInstanceOf[Integer],
    null.asInstanceOf[java.lang.Long],
    null.asInstanceOf[String])

  test("All Types Dictionary") {
    (1 :: 1000 :: Nil).foreach { n => {
      withTempPath { dir =>
        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val conf = sqlContext.conf
        val reader = new VectorizedParquetRecordReader(
          null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize)
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).getByte(i) == 1)
          assert(batch.column(1).getInt(i) == 2)
          assert(batch.column(2).getLong(i) == 3)
          assert(batch.column(3).getUTF8String(i).toString == "abc")
          i += 1
        }
        reader.close()
      }
    }}
  }

  test("All Types Null") {
    (1 :: 100 :: Nil).foreach { n => {
      withTempPath { dir =>
        val data = List.fill(n)(NULL_ROW).toDF
        data.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val conf = sqlContext.conf
        val reader = new VectorizedParquetRecordReader(
          null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize)
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).isNullAt(i))
          assert(batch.column(1).isNullAt(i))
          assert(batch.column(2).isNullAt(i))
          assert(batch.column(3).isNullAt(i))
          i += 1
        }
        reader.close()
      }}
    }
  }

  test("Read row group containing both dictionary and plain encoded pages") {
    withSQLConf(ParquetOutputFormat.DICTIONARY_PAGE_SIZE -> "2048",
      ParquetOutputFormat.PAGE_SIZE -> "4096") {
      withTempPath { dir =>
        // In order to explicitly test for SPARK-14217, we set the parquet dictionary and page size
        // such that the following data spans across 3 pages (within a single row group) where the
        // first page is dictionary encoded and the remaining two are plain encoded.
        val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString))
        data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head

        val conf = sqlContext.conf
        val reader = new VectorizedParquetRecordReader(
          null, conf.offHeapColumnVectorEnabled, conf.parquetVectorizedReaderBatchSize)
        reader.initialize(file, null)
        val column = reader.resultBatch().column(0)
        assert(reader.nextBatch())

        (0 until 512).foreach { i =>
          assert(column.getUTF8String(3 * i).toString == i.toString)
          assert(column.getUTF8String(3 * i + 1).toString == i.toString)
          assert(column.getUTF8String(3 * i + 2).toString == i.toString)
        }
        reader.close()
      }
    }
  }
} 
Example 95
Source File: ParquetProtobufCompatibilitySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 96
Source File: DataSourceScanExecRedactionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.hadoop.fs.Path

import org.apache.spark.SparkConf
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext


class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext {

  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.redaction.string.regex", "file:/[\\w_]+")

  test("treeString is redacted") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)

      val rootPath = df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
        .asInstanceOf[FileSourceScanExec].relation.location.rootPaths.head
      assert(rootPath.toString.contains(dir.toURI.getPath.stripSuffix("/")))

      assert(!df.queryExecution.sparkPlan.treeString(verbose = true).contains(rootPath.getName))
      assert(!df.queryExecution.executedPlan.treeString(verbose = true).contains(rootPath.getName))
      assert(!df.queryExecution.toString.contains(rootPath.getName))
      assert(!df.queryExecution.simpleString.contains(rootPath.getName))

      val replacement = "*********"
      assert(df.queryExecution.sparkPlan.treeString(verbose = true).contains(replacement))
      assert(df.queryExecution.executedPlan.treeString(verbose = true).contains(replacement))
      assert(df.queryExecution.toString.contains(replacement))
      assert(df.queryExecution.simpleString.contains(replacement))
    }
  }

  private def isIncluded(queryExecution: QueryExecution, msg: String): Boolean = {
    queryExecution.toString.contains(msg) ||
    queryExecution.simpleString.contains(msg) ||
    queryExecution.stringWithStats.contains(msg)
  }

  test("explain is redacted using SQLConf") {
    withTempDir { dir =>
      val basePath = dir.getCanonicalPath
      spark.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
      val df = spark.read.parquet(basePath)
      val replacement = "*********"

      // Respect SparkConf and replace file:/
      assert(isIncluded(df.queryExecution, replacement))

      assert(isIncluded(df.queryExecution, "FileScan"))
      assert(!isIncluded(df.queryExecution, "file:/"))

      withSQLConf(SQLConf.SQL_STRING_REDACTION_PATTERN.key -> "(?i)FileScan") {
        // Respect SQLConf and replace FileScan
        assert(isIncluded(df.queryExecution, replacement))

        assert(!isIncluded(df.queryExecution, "FileScan"))
        assert(isIncluded(df.queryExecution, "file:/"))
      }
    }
  }

} 
Example 97
Source File: QueryExecutionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    spark.experimental.extraStrategies = Seq(
        new SparkStrategy {
          override def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
        })

    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation())

    // Nothing!
    assert(qe.toString.contains("OneRowRelation"))

    // Throw an AnalysisException - this should be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new AnalysisException("exception")
      })
    assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))

    // Throw an Error - this should not be captured.
    spark.experimental.extraStrategies = Seq(
      new SparkStrategy {
        override def apply(plan: LogicalPlan): Seq[SparkPlan] =
          throw new Error("error")
      })
    val error = intercept[Error](qe.toString)
    assert(error.getMessage.contains("error"))
  }
} 
Example 98
Source File: ExtractPythonUDFsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.python

import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSQLContext

class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.newProductEncoder
  import testImplicits.localSeqToDatasetHolder

  val batchedPythonUDF = new MyDummyPythonUDF
  val scalarPandasUDF = new MyDummyScalarPandasUDF

  private def collectBatchExec(plan: SparkPlan): Seq[BatchEvalPythonExec] = plan.collect {
    case b: BatchEvalPythonExec => b
  }

  private def collectArrowExec(plan: SparkPlan): Seq[ArrowEvalPythonExec] = plan.collect {
    case b: ArrowEvalPythonExec => b
  }

  test("Chained Batched Python UDFs should be combined to a single physical node") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
      .withColumn("d", batchedPythonUDF(col("c")))
    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
    assert(pythonEvalNodes.size == 1)
  }

  test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = df.withColumn("c", scalarPandasUDF(col("a")))
      .withColumn("d", scalarPandasUDF(col("c")))
    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
    assert(arrowEvalNodes.size == 1)
  }

  test("Mixed Batched Python UDFs and Pandas UDF should be separate physical node") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = df.withColumn("c", batchedPythonUDF(col("a")))
      .withColumn("d", scalarPandasUDF(col("b")))

    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
    assert(pythonEvalNodes.size == 1)
    assert(arrowEvalNodes.size == 1)
  }

  test("Independent Batched Python UDFs and Scalar Pandas UDFs should be combined separately") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
      .withColumn("c2", batchedPythonUDF(col("c1")))
      .withColumn("d1", scalarPandasUDF(col("a")))
      .withColumn("d2", scalarPandasUDF(col("d1")))

    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
    assert(pythonEvalNodes.size == 1)
    assert(arrowEvalNodes.size == 1)
  }

  test("Dependent Batched Python UDFs and Scalar Pandas UDFs should not be combined") {
    val df = Seq(("Hello", 4)).toDF("a", "b")
    val df2 = df.withColumn("c1", batchedPythonUDF(col("a")))
      .withColumn("d1", scalarPandasUDF(col("c1")))
      .withColumn("c2", batchedPythonUDF(col("d1")))
      .withColumn("d2", scalarPandasUDF(col("c2")))

    val pythonEvalNodes = collectBatchExec(df2.queryExecution.executedPlan)
    val arrowEvalNodes = collectArrowExec(df2.queryExecution.executedPlan)
    assert(pythonEvalNodes.size == 2)
    assert(arrowEvalNodes.size == 2)
  }
} 
Example 99
Source File: TakeOrderedAndProjectSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
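TakeOrderedAndProjectExec is checked above against the equivalent global-limit-over-sort plan. At the DataFrame level, the operator normally shows up for queries of the shape ORDER BY ... LIMIT n with an optional projection. The snippet below is a minimal, hypothetical driver program (object name, master URL and data are assumptions, not part of the suite) that demonstrates the pattern and prints the physical plan:

import org.apache.spark.sql.SparkSession

object TakeOrderedAndProjectDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("take-ordered-and-project-demo")
      .getOrCreate()
    import spark.implicits._

    val df = Seq((3, "c"), (1, "a"), (2, "b")).toDF("a", "b")

    // ORDER BY + LIMIT (+ optional projection) is the query shape that typically
    // plans into a TakeOrderedAndProject physical node.
    val topTwo = df.orderBy($"a".desc).limit(2).select("b")

    topTwo.explain()  // inspect the physical plan
    topTwo.show()

    spark.stop()
  }
}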
Example 100
Source File: SparkPlanSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkEnv
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlanSuite extends QueryTest with SharedSQLContext {

  test("SPARK-21619 execution of a canonicalized plan should fail") {
    val plan = spark.range(10).queryExecution.executedPlan.canonicalized

    intercept[IllegalStateException] { plan.execute() }
    intercept[IllegalStateException] { plan.executeCollect() }
    intercept[IllegalStateException] { plan.executeCollectPublic() }
    intercept[IllegalStateException] { plan.executeToIterator() }
    intercept[IllegalStateException] { plan.executeBroadcast() }
    intercept[IllegalStateException] { plan.executeTake(1) }
  }

  test("SPARK-23731 plans should be canonicalizable after being (de)serialized") {
    withTempPath { path =>
      spark.range(1).write.parquet(path.getAbsolutePath)
      val df = spark.read.parquet(path.getAbsolutePath)
      val fileSourceScanExec =
        df.queryExecution.sparkPlan.collectFirst { case p: FileSourceScanExec => p }.get
      val serializer = SparkEnv.get.serializer.newInstance()
      val readback =
        serializer.deserialize[FileSourceScanExec](serializer.serialize(fileSourceScanExec))
      try {
        readback.canonicalized
      } catch {
        case e: Throwable => fail("FileSourceScanExec was not canonicalizable", e)
      }
    }
  }

  test("SPARK-25357 SparkPlanInfo of FileScan contains nonEmpty metadata") {
    withTempPath { path =>
      spark.range(5).write.parquet(path.getAbsolutePath)
      val f = spark.read.parquet(path.getAbsolutePath)
      assert(SparkPlanInfo.fromSparkPlan(f.queryExecution.sparkPlan).metadata.nonEmpty)
    }
  }
} 
Example 101
Source File: SameResultSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType


class SameResultSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("FileSourceScanExec: different orders of data filters and partition filters") {
    withTempPath { path =>
      val tmpDir = path.getCanonicalPath
      spark.range(10)
        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
        .write
        .partitionBy("a", "b")
        .parquet(tmpDir)
      val df = spark.read.parquet(tmpDir)
      // partition filters: a > 1 AND b < 9
      // data filters: c > 1 AND d < 9
      val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
      val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))
      assert(plan1.sameResult(plan2))
    }
  }

  private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {
    df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
      .asInstanceOf[FileSourceScanExec]
  }

  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
    val df1 = spark.range(10).agg(sum($"id"))
    val df2 = spark.range(10).agg(sum($"id"))
    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))

    val df3 = spark.range(10).agg(sumDistinct($"id"))
    val df4 = spark.range(10).agg(sumDistinct($"id"))
    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
  }

  test("Canonicalized result is case-insensitive") {
    val a = AttributeReference("A", IntegerType)()
    val b = AttributeReference("B", IntegerType)()
    val planUppercase = Project(Seq(a), LocalRelation(a, b))

    val c = AttributeReference("a", IntegerType)()
    val d = AttributeReference("b", IntegerType)()
    val planLowercase = Project(Seq(c), LocalRelation(c, d))

    assert(planUppercase.sameResult(planLowercase))
  }
} 
Example 102
Source File: SparkPlannerSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data, _) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      assert(planned === 4)
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 103
Source File: DataFrameTungstenSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }

  test("primitive data type accesses in persist data") {
    val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong,
      31.25.toFloat, 63.75, null)
    val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType,
      FloatType, DoubleType, IntegerType)
    val schemas = dataTypes.zipWithIndex.map { case (dataType, index) =>
      StructField(s"col$index", dataType, true)
    }
    val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data)))
    val df = spark.createDataFrame(rdd, StructType(schemas))
    val row = df.persist.take(1).apply(0)
    checkAnswer(df, row)
  }

  test("access cache multiple times") {
    val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache
    df0.count
    val df1 = df0.filter("x > 1")
    checkAnswer(df1, Seq(Row(2), Row(3)))
    val df2 = df0.filter("x > 2")
    checkAnswer(df2, Row(3))

    val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache
    for (_ <- 0 to 2) {
      val df11 = df10.filter("x > 5")
      checkAnswer(df11, Row(6))
    }
  }

  test("access only some column of the all of columns") {
    val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d")
    df.cache
    df.count
    assert(df.filter("d < 3").count == 1)
  }
} 
Example 104
Source File: ExtraStrategiesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    sparkContext.parallelize(Seq(unsafeRow))
  }

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 105
Source File: DataFrameComplexTypeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.DefinedByConstructorParams
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext


class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("UDF on struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(struct($"a").as("s")).select(f($"s.a")).collect()
  }

  test("UDF on named_struct") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
  }

  test("UDF on array") {
    val f = udf((a: String) => a)
    val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
  }

  test("UDF on map") {
    val f = udf((a: String) => a)
    val df = Seq("a" -> 1).toDF("a", "b")
    df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect()
  }

  test("SPARK-12477 accessing null element in array field") {
    val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"),
      Seq(Some(1), None, Some(2))))).toDF("s", "i")
    val nullStringRow = df.selectExpr("s[1]").collect()(0)
    assert(nullStringRow == org.apache.spark.sql.Row(null))
    val nullIntRow = df.selectExpr("i[1]").collect()(0)
    assert(nullIntRow == org.apache.spark.sql.Row(null))
  }

  test("SPARK-15285 Generated SpecificSafeProjection.apply method grows beyond 64KB") {
    val ds100_5 = Seq(S100_5()).toDS()
    ds100_5.rdd.count
  }
}

class S100(
  val s1: String = "1", val s2: String = "2", val s3: String = "3", val s4: String = "4",
  val s5: String = "5", val s6: String = "6", val s7: String = "7", val s8: String = "8",
  val s9: String = "9", val s10: String = "10", val s11: String = "11", val s12: String = "12",
  val s13: String = "13", val s14: String = "14", val s15: String = "15", val s16: String = "16",
  val s17: String = "17", val s18: String = "18", val s19: String = "19", val s20: String = "20",
  val s21: String = "21", val s22: String = "22", val s23: String = "23", val s24: String = "24",
  val s25: String = "25", val s26: String = "26", val s27: String = "27", val s28: String = "28",
  val s29: String = "29", val s30: String = "30", val s31: String = "31", val s32: String = "32",
  val s33: String = "33", val s34: String = "34", val s35: String = "35", val s36: String = "36",
  val s37: String = "37", val s38: String = "38", val s39: String = "39", val s40: String = "40",
  val s41: String = "41", val s42: String = "42", val s43: String = "43", val s44: String = "44",
  val s45: String = "45", val s46: String = "46", val s47: String = "47", val s48: String = "48",
  val s49: String = "49", val s50: String = "50", val s51: String = "51", val s52: String = "52",
  val s53: String = "53", val s54: String = "54", val s55: String = "55", val s56: String = "56",
  val s57: String = "57", val s58: String = "58", val s59: String = "59", val s60: String = "60",
  val s61: String = "61", val s62: String = "62", val s63: String = "63", val s64: String = "64",
  val s65: String = "65", val s66: String = "66", val s67: String = "67", val s68: String = "68",
  val s69: String = "69", val s70: String = "70", val s71: String = "71", val s72: String = "72",
  val s73: String = "73", val s74: String = "74", val s75: String = "75", val s76: String = "76",
  val s77: String = "77", val s78: String = "78", val s79: String = "79", val s80: String = "80",
  val s81: String = "81", val s82: String = "82", val s83: String = "83", val s84: String = "84",
  val s85: String = "85", val s86: String = "86", val s87: String = "87", val s88: String = "88",
  val s89: String = "89", val s90: String = "90", val s91: String = "91", val s92: String = "92",
  val s93: String = "93", val s94: String = "94", val s95: String = "95", val s96: String = "96",
  val s97: String = "97", val s98: String = "98", val s99: String = "99", val s100: String = "100")
extends DefinedByConstructorParams

case class S100_5(
  s1: S100 = new S100(), s2: S100 = new S100(), s3: S100 = new S100(),
  s4: S100 = new S100(), s5: S100 = new S100()) 
Example 106
Source File: ResolvedDataSourceSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.test.SharedSQLContext

class ResolvedDataSourceSuite extends SparkFunSuite with SharedSQLContext {
  private def getProvidingClass(name: String): Class[_] =
    DataSource(
      sparkSession = spark,
      className = name,
      options = Map(DateTimeUtils.TIMEZONE_OPTION -> DateTimeUtils.defaultTimeZone().getID)
    ).providingClass

  test("jdbc") {
    assert(
      getProvidingClass("jdbc") ===
      classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.jdbc") ===
      classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
    assert(
      getProvidingClass("org.apache.spark.sql.jdbc") ===
        classOf[org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider])
  }

  test("json") {
    assert(
      getProvidingClass("json") ===
      classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.json") ===
        classOf[org.apache.spark.sql.execution.datasources.json.JsonFileFormat])
  }

  test("parquet") {
    assert(
      getProvidingClass("parquet") ===
      classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.execution.datasources.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
    assert(
      getProvidingClass("org.apache.spark.sql.parquet") ===
        classOf[org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat])
  }

  test("csv") {
    assert(
      getProvidingClass("csv") ===
        classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat])
    assert(
      getProvidingClass("com.databricks.spark.csv") ===
        classOf[org.apache.spark.sql.execution.datasources.csv.CSVFileFormat])
  }

  test("avro: show deploy guide for loading the external avro module") {
    Seq("avro", "org.apache.spark.sql.avro").foreach { provider =>
      val message = intercept[AnalysisException] {
        getProvidingClass(provider)
      }.getMessage
      assert(message.contains(s"Failed to find data source: $provider"))
      assert(message.contains("Please deploy the application as per the deployment section of"))
    }
  }

  test("kafka: show deploy guide for loading the external kafka module") {
    val message = intercept[AnalysisException] {
      getProvidingClass("kafka")
    }.getMessage
    assert(message.contains("Failed to find data source: kafka"))
    assert(message.contains("Please deploy the application as per the deployment section of"))
  }

  test("error message for unknown data sources") {
    val error = intercept[ClassNotFoundException] {
      getProvidingClass("asfdwefasdfasdf")
    }
    assert(error.getMessage.contains("Failed to find data source: asfdwefasdfasdf."))
  }
} 
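The provider names resolved above are the same strings users pass to DataFrameReader.format and DataFrameWriter.format. A minimal usage sketch (assuming a SparkSession named spark and illustrative /tmp paths):

val people = spark.read.format("json").load("/tmp/people.json")
people.write.format("parquet").save("/tmp/people.parquet")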
Example 107
Source File: DDLSourceLoadSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


// please note that the META-INF/services had to be modified for the test directory for this to work
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name - internal data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fluet da Bomb").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fluet da Bomb"))
  }

  test("data sources with the same name - internal data source/external data source") {
    assert(spark.read.format("datasource").load().schema ==
      StructType(Seq(StructField("longType", LongType, nullable = false))))
  }

  test("data sources with the same name - external data sources") {
    val e = intercept[AnalysisException] {
      spark.read.format("Fake external source").load()
    }
    assert(e.getMessage.contains("Multiple sources found for Fake external source"))
  }

  test("load data source from format alias") {
    assert(spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }

  test("specify full classname with duplicate formats") {
    assert(spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false))))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("integerType", IntegerType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceFour extends RelationProvider with DataSourceRegister {

  def shortName(): String = "datasource"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("longType", LongType, nullable = false)))
    }
} 
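The short names used above ("Fluet da Bomb", "gathering quorum", and so on) are only discoverable because the provider classes are registered for Java's ServiceLoader in a META-INF/services/org.apache.spark.sql.sources.DataSourceRegister file on the test classpath, which is what the comment at the top of the suite refers to. A small sketch of how such registrations can be inspected at runtime (the class name shown in the comment mirrors the fake sources above and is illustrative):

import java.util.ServiceLoader
import scala.collection.JavaConverters._
import org.apache.spark.sql.sources.DataSourceRegister

// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister lists one
// fully-qualified provider class per line, e.g. org.apache.spark.sql.sources.FakeSourceOne.
val shortNames = ServiceLoader.load(classOf[DataSourceRegister]).asScala.map(_.shortName())
shortNames.foreach(println)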
Example 108
Source File: SerializationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.test.SharedSQLContext

class SerializationSuite extends SparkFunSuite with SharedSQLContext {

  test("[SPARK-5235] SQLContext should be serializable") {
    val spark = SparkSession.builder.getOrCreate()
    new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
  }

  test("[SPARK-26409] SQLConf should be serializable") {
    val spark = SparkSession.builder.getOrCreate()
    new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sessionState.conf)
  }
} 
Example 109
Source File: MiscFunctionsSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reflect and java_method") {
    val df = Seq((1, "one")).toDF("a", "b")
    val className = ReflectClass.getClass.getName.stripSuffix("$")
    checkAnswer(
      df.selectExpr(
        s"reflect('$className', 'method1', a, b)",
        s"java_method('$className', 'method1', a, b)"),
      Row("m1one", "m1one"))
  }
}

object ReflectClass {
  def method1(v1: Int, v2: String): String = "m" + v1 + v2
} 
Example 110
Source File: DataFrameHintSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.test.SharedSQLContext

class DataFrameHintSuite extends AnalysisTest with SharedSQLContext {
  import testImplicits._
  lazy val df = spark.range(10)

  private def check(df: Dataset[_], expected: LogicalPlan) = {
    comparePlans(
      df.queryExecution.logical,
      expected
    )
  }

  test("various hint parameters") {
    check(
      df.hint("hint1"),
      UnresolvedHint("hint1", Seq(),
        df.logicalPlan
      )
    )

    check(
      df.hint("hint1", 1, "a"),
      UnresolvedHint("hint1", Seq(1, "a"), df.logicalPlan)
    )

    check(
      df.hint("hint1", 1, $"a"),
      UnresolvedHint("hint1", Seq(1, $"a"),
        df.logicalPlan
      )
    )

    check(
      df.hint("hint1", Seq(1, 2, 3), Seq($"a", $"b", $"c")),
      UnresolvedHint("hint1", Seq(Seq(1, 2, 3), Seq($"a", $"b", $"c")),
        df.logicalPlan
      )
    )
  }

  test("coalesce and repartition hint") {
    check(
      df.hint("COALESCE", 10),
      UnresolvedHint("COALESCE", Seq(10), df.logicalPlan))

    check(
      df.hint("REPARTITION", 100),
      UnresolvedHint("REPARTITION", Seq(100), df.logicalPlan))
  }
} 
Example 111
Source File: ApproxCountDistinctForIntervalsQuerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.test.SharedSQLContext

class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height
  // histogram usually contains hundreds of buckets. So we need to test
  // ApproxCountDistinctForIntervals with large number of endpoints
  // (the number of endpoints == the number of buckets + 1).
  test("test ApproxCountDistinctForIntervals with large number of endpoints") {
    val table = "approx_count_distinct_for_intervals_tbl"
    withTable(table) {
      (1 to 100000).toDF("col").createOrReplaceTempView(table)
      // percentiles of 0, 0.001, 0.002 ... 0.999, 1
      val endpoints = (0 to 1000).map(_ * 100000 / 1000)

      // Since approx_count_distinct_for_intervals is not a public function, here we do
      // the computation by constructing logical plan.
      val relation = spark.table(table).logicalPlan
      val attr = relation.output.find(_.name == "col").get
      val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_))))
      val aggExpr = aggFunc.toAggregateExpression()
      val namedExpr = Alias(aggExpr, aggExpr.toString)()
      val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation))
        .executedPlan.executeTake(1).head
      val ndvArray = ndvsRow.getArray(0).toLongArray()
      assert(endpoints.length == ndvArray.length + 1)

      // Each bucket has 100 distinct values.
      val expectedNdv = 100
      for (i <- ndvArray.indices) {
        val ndv = ndvArray(i)
        val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d)
        assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. errors.")
      }
    }
  }
} 
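The expected per-bucket NDV in the assertion is plain arithmetic: the 1001 endpoints split 1..100000 into 1000 equal-width buckets of 100 consecutive integers each, and the HLL-based estimate is then allowed to deviate by three relative standard deviations. A small worked check of that bucket arithmetic (illustrative only):

val endpoints = (0 to 1000).map(_ * 100000 / 1000)                 // 0, 100, 200, ..., 100000
val bucketWidths = endpoints.sliding(2).map { case Seq(lo, hi) => hi - lo }.toSeq
assert(bucketWidths.forall(_ == 100))                              // 100 distinct values per bucket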
Example 112
Source File: ComplexTypesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.test.SharedSQLContext

class ComplexTypesSuite extends QueryTest with SharedSQLContext {

  override def beforeAll() {
    super.beforeAll()
    spark.range(10).selectExpr(
      "id + 1 as i1", "id + 2 as i2", "id + 3 as i3", "id + 4 as i4", "id + 5 as i5")
      .write.saveAsTable("tab")
  }

  override def afterAll() {
    try {
      spark.sql("DROP TABLE IF EXISTS tab")
    } finally {
      super.afterAll()
    }
  }

  def checkNamedStruct(plan: LogicalPlan, expectedCount: Int): Unit = {
    var count = 0
    plan.foreach { operator =>
      operator.transformExpressions {
        case c: CreateNamedStruct =>
          count += 1
          c
      }
    }

    if (expectedCount != count) {
      fail(s"expect $expectedCount CreateNamedStruct but got $count.")
    }
  }

  test("simple case") {
    val df = spark.table("tab").selectExpr(
      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4) as col2")
      .filter("col2.c > 11").selectExpr("col1.a")
    checkAnswer(df, Row(9) :: Row(10) :: Nil)
    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
  }

  test("named_struct is used in the top Project") {
    val df = spark.table("tab").selectExpr(
      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
      .selectExpr("col1.a", "col1")
      .filter("col1.a > 8")
    checkAnswer(df, Row(9, Row(9, 10)) :: Row(10, Row(10, 11)) :: Nil)
    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 1)

    val df1 = spark.table("tab").selectExpr(
      "i5", "named_struct('a', i1, 'b', i2) as col1", "named_struct('a', i3, 'c', i4)")
      .sort("col1")
      .selectExpr("col1.a")
      .filter("col1.a > 8")
    checkAnswer(df1, Row(9) :: Row(10) :: Nil)
    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 1)
  }

  test("expression in named_struct") {
    val df = spark.table("tab")
      .selectExpr("i5", "struct(i1 as exp, i2, i3) as cola")
      .selectExpr("cola.exp", "cola.i3").filter("cola.i3 > 10")
    checkAnswer(df, Row(9, 11) :: Row(10, 12) :: Nil)
    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)

    val df1 = spark.table("tab")
      .selectExpr("i5", "struct(i1 + 1 as exp, i2, i3) as cola")
      .selectExpr("cola.i3").filter("cola.exp > 10")
    checkAnswer(df1, Row(12) :: Nil)
    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
  }

  test("nested case") {
    val df = spark.table("tab")
      .selectExpr("struct(struct(i2, i3) as exp, i4) as cola")
      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.exp.i2 > 10")
    checkAnswer(df, Row(11, 13) :: Nil)
    checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)

    val df1 = spark.table("tab")
      .selectExpr("struct(i2, i3) as exp", "i4")
      .selectExpr("struct(exp, i4) as cola")
      .selectExpr("cola.exp.i2", "cola.i4").filter("cola.i4 > 11")
    checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
    checkNamedStruct(df1.queryExecution.optimizedPlan, expectedCount = 0)
  }
} 
Example 113
Source File: EventHubsSourceOffsetSuite.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.eventhubs

import java.io.File

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext

class EventHubsSourceOffsetSuite extends OffsetSuite with SharedSQLContext {

  compare(one = EventHubsSourceOffset(("t", 0, 1L)), two = EventHubsSourceOffset(("t", 0, 2L)))

  compare(one = EventHubsSourceOffset(("t", 0, 1L), ("t", 1, 0L)),
          two = EventHubsSourceOffset(("t", 0, 2L), ("t", 1, 1L)))

  compare(one = EventHubsSourceOffset(("t", 0, 1L), ("T", 0, 0L)),
          two = EventHubsSourceOffset(("t", 0, 2L), ("T", 0, 1L)))

  compare(one = EventHubsSourceOffset(("t", 0, 1L)),
          two = EventHubsSourceOffset(("t", 0, 2L), ("t", 1, 1L)))

  val ehso1 = EventHubsSourceOffset(("t", 0, 1L))
  val ehso2 = EventHubsSourceOffset(("t", 0, 2L), ("t", 1, 3L))
  val ehso3 = EventHubsSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L))

  compare(EventHubsSourceOffset(SerializedOffset(ehso1.json)),
          EventHubsSourceOffset(SerializedOffset(ehso2.json)))

  test("basic serialization - deserialization") {
    assert(
      EventHubsSourceOffset.getPartitionSeqNos(ehso1) ==
        EventHubsSourceOffset.getPartitionSeqNos(SerializedOffset(ehso1.json)))
  }

  test("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // use a non-existent directory to test whether the log creates it
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(ehso1)
      val batch1 = OffsetSeq.fill(ehso2, ehso3)

      val batch0Serialized =
        OffsetSeq.fill(batch0.offsets.flatMap(_.map(o => SerializedOffset(o.json))): _*)

      val batch1Serialized =
        OffsetSeq.fill(batch1.offsets.flatMap(_.map(o => SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(
        metadataLog.get(None, Some(1)) ===
          Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(
        metadataLog.get(None, Some(1)) ===
          Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 offset format") {
    val offset = readFromResource("eventhubs-source-offset-version-2.1.0.txt")
    assert(
      EventHubsSourceOffset(offset) ===
        EventHubsSourceOffset(("ehName1", 0, 456L), ("ehName1", 1, 789L), ("ehName2", 0, 0L)))
  }

  private def readFromResource(file: String): SerializedOffset = {
    import scala.io.Source
    val input = getClass.getResource(s"/$file").toURI
    val str = Source.fromFile(input).mkString
    SerializedOffset(str)
  }
} 
Example 114
Source File: JsonUtilsSuite.scala    From azure-event-hubs-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.eventhubs

import org.apache.spark.SparkFunSuite
import org.apache.spark.eventhubs.NameAndPartition
import org.apache.spark.sql.test.SharedSQLContext

class JsonUtilsSuite extends SparkFunSuite with SharedSQLContext {

  test("parsing partitions") {
    val parsed = JsonUtils.partitions("""{"nameA":[0,1],"nameB":[4,6]}""")
    val expected = Array(
      new NameAndPartition("nameA", 0),
      new NameAndPartition("nameA", 1),
      new NameAndPartition("nameB", 4),
      new NameAndPartition("nameB", 6)
    )
    assert(parsed.toSeq === expected.toSeq)
  }

  test("parsing partitionSeqNos") {
    val parsed = JsonUtils.partitionSeqNos("""{"nameA":{"0":23,"1":-1},"nameB":{"0":-2}}""")

    assert(parsed(new NameAndPartition("nameA", 0)) === 23)
    assert(parsed(new NameAndPartition("nameA", 1)) === -1)
    assert(parsed(new NameAndPartition("nameB", 0)) === -2)
  }
} 
Example 115
Source File: KafkaSourceOffsetSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kafka010

import java.io.File

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext

class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext {

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))


  val kso1 = KafkaSourceOffset(("t", 0, 1L))
  val kso2 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L))
  val kso3 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L))

  compare(KafkaSourceOffset(SerializedOffset(kso1.json)),
    KafkaSourceOffset(SerializedOffset(kso2.json)))

  test("basic serialization - deserialization") {
    assert(KafkaSourceOffset.getPartitionOffsets(kso1) ==
      KafkaSourceOffset.getPartitionOffsets(SerializedOffset(kso1.json)))
  }


  testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // use a non-existent directory to test whether the log creates it
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(kso1)
      val batch1 = OffsetSeq.fill(kso2, kso3)

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 offset format") {
    val offset = readFromResource("kafka-source-offset-version-2.1.0.txt")
    assert(KafkaSourceOffset(offset) ===
      KafkaSourceOffset(("topic1", 0, 456L), ("topic1", 1, 789L), ("topic2", 0, 0L)))
  }

  private def readFromResource(file: String): SerializedOffset = {
    import scala.io.Source
    val input = getClass.getResource(s"/$file").toURI
    val str = Source.fromFile(input).mkString
    SerializedOffset(str)
  }
} 
Example 116
Source File: SQLUtilsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
Example 117
Source File: DataFrameImplicitsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }
} 
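Outside of these test suites, the same toDF conversions come from the session's implicits rather than testImplicits. A minimal stand-alone sketch (the master URL and app name are illustrative):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("implicits-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "one"), (2, "two")).toDF("intCol", "strCol")
df.show()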
Example 118
Source File: DebuggingSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.TestData

class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

  test("DataFrame.debug()") {
    testData.debug()
  }

  test("Dataset.debug()") {
    import testImplicits._
    testData.as[TestData].debug()
  }

  test("debugCodegen") {
    val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
    assert(res.contains("Subtree 1 / 2"))
    assert(res.contains("Subtree 2 / 2"))
    assert(res.contains("Object[]"))
  }
} 
Example 119
Source File: RowDataSourceStrategySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 120
Source File: HadoopFsRelationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {

  test("sizeInBytes should be the total size of all files") {
    withTempDir{ dir =>
      dir.delete()
      spark.range(1000).write.parquet(dir.toString)
      // ignore hidden files
      val allFiles = dir.listFiles(new FilenameFilter {
        override def accept(dir: File, name: String): Boolean = {
          !name.startsWith(".")
        }
      })
      val totalSize = allFiles.map(_.length()).sum
      val df = spark.read.parquet(dir.toString)
      assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize))
    }
  }
} 
Example 121
Source File: ParquetInteroperabilitySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    import ParquetCompatibilityTest._

    // This test case writes two Parquet files, both representing the following Catalyst schema
    //
    //   StructType(
    //     StructField(
    //       "f",
    //       ArrayType(IntegerType, containsNull = false),
    //       nullable = false))
    //
    // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the
    // other one comes with parquet-protobuf style 1-level unannotated primitive field.
    withTempDir { dir =>
      val avroStylePath = new File(dir, "avro-style").getCanonicalPath
      val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath

      val avroStyleSchema =
        """message avro_style {
          |  required group f (LIST) {
          |    repeated int32 array;
          |  }
          |}
        """.stripMargin

      writeDirect(avroStylePath, avroStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.group {
              rc.field("array", 0) {
                rc.addInteger(0)
                rc.addInteger(1)
              }
            }
          }
        }
      })

      logParquetSchema(avroStylePath)

      val protobufStyleSchema =
        """message protobuf_style {
          |  repeated int32 f;
          |}
        """.stripMargin

      writeDirect(protobufStylePath, protobufStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.addInteger(2)
            rc.addInteger(3)
          }
        }
      })

      logParquetSchema(protobufStylePath)

      checkAnswer(
        spark.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3))))
    }
  }
} 
Example 122
Source File: ParquetEncodingSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.test.SharedSQLContext

// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
// writer abstractions. Revisit.
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
  import testImplicits._

  val ROW = ((1).toByte, 2, 3L, "abc")
  val NULL_ROW = (
    null.asInstanceOf[java.lang.Byte],
    null.asInstanceOf[Integer],
    null.asInstanceOf[java.lang.Long],
    null.asInstanceOf[String])

  test("All Types Dictionary") {
    (1 :: 1000 :: Nil).foreach { n => {
      withTempPath { dir =>
        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).getByte(i) == 1)
          assert(batch.column(1).getInt(i) == 2)
          assert(batch.column(2).getLong(i) == 3)
          assert(batch.column(3).getUTF8String(i).toString == "abc")
          i += 1
        }
        reader.close()
      }
    }}
  }

  test("All Types Null") {
    (1 :: 100 :: Nil).foreach { n => {
      withTempPath { dir =>
        val data = List.fill(n)(NULL_ROW).toDF
        data.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).isNullAt(i))
          assert(batch.column(1).isNullAt(i))
          assert(batch.column(2).isNullAt(i))
          assert(batch.column(3).isNullAt(i))
          i += 1
        }
        reader.close()
      }}
    }
  }

  test("Read row group containing both dictionary and plain encoded pages") {
    withSQLConf(ParquetOutputFormat.DICTIONARY_PAGE_SIZE -> "2048",
      ParquetOutputFormat.PAGE_SIZE -> "4096") {
      withTempPath { dir =>
        // In order to explicitly test for SPARK-14217, we set the parquet dictionary and page size
        // such that the following data spans across 3 pages (within a single row group) where the
        // first page is dictionary encoded and the remaining two are plain encoded.
        val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString))
        data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file, null)
        val column = reader.resultBatch().column(0)
        assert(reader.nextBatch())

        (0 until 512).foreach { i =>
          assert(column.getUTF8String(3 * i).toString == i.toString)
          assert(column.getUTF8String(3 * i + 1).toString == i.toString)
          assert(column.getUTF8String(3 * i + 2).toString == i.toString)
        }
        reader.close()
      }
    }
  }
} 
Example 123
Source File: ParquetProtobufCompatibilitySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 124
Source File: QueryExecutionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    val badRule = new SparkStrategy {
      var mode: String = ""
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match {
        case "exception" => throw new AnalysisException(mode)
        case "error" => throw new Error(mode)
        case _ => Nil
      }
    }
    spark.experimental.extraStrategies = badRule :: Nil

    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation)

    // Nothing!
    badRule.mode = ""
    assert(qe.toString.contains("OneRowRelation"))

    // Throw an AnalysisException - this should be captured.
    badRule.mode = "exception"
    assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))

    // Throw an Error - this should not be captured.
    badRule.mode = "error"
    val error = intercept[Error](qe.toString)
    assert(error.getMessage.contains("error"))
  }
} 
Example 125
Source File: ExchangeSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 =
      HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
} 
Example 126
Source File: TakeOrderedAndProjectSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
Example 127
Source File: SparkPlannerSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    var planned = 0
    object TestStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child) :: planLater(NeverPlanned) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(planLater)) :: planLater(NeverPlanned) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data) :: planLater(NeverPlanned) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
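      // 4 = ReturnAnswer + Union + the two LocalRelation children, each planned once by TestStrategy.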
      assert(planned === 4)
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 128
Source File: WholeStageCodegenSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
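To see which operators were actually fused into a single whole-stage subtree, the generated Java source can be printed with the debug helpers exercised in Example 118. A short sketch (assuming the implicits from org.apache.spark.sql.execution.debug and a SparkSession named spark):

import org.apache.spark.sql.execution.debug._

val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
df.debugCodegen()   // prints each WholeStageCodegen subtree together with its generated code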
Example 129
Source File: OffsetSeqLogSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {

  
  case class StringOffset(override val json: String) extends Offset

  test("OffsetSeqMetadata - deserialization") {
    assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}"""))
    assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
    assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
    assert(
      OffsetSeqMetadata(1, 2) ===
        OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
  }

  testWithUninterruptibleThread("OffsetSeqLog - serialization - deserialization") {
    withTempDir { temp =>
      val dir = new File(temp, "dir") // use a non-existent directory to test whether the log creates it
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(LongOffset(0), LongOffset(1), LongOffset(2))
      val batch1 = OffsetSeq.fill(StringOffset("one"), StringOffset("two"), StringOffset("three"))

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 log format") {
    val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0")
    assert(batchId === 0)
    assert(offsetSeq.offsets === Seq(
      Some(SerializedOffset("""{"logOffset":345}""")),
      Some(SerializedOffset("""{"topic-0":{"0":1}}"""))
    ))
    assert(offsetSeq.metadata === Some(OffsetSeqMetadata(0L, 1480981499528L)))
  }

  private def readFromResource(dir: String): (Long, OffsetSeq) = {
    val input = getClass.getResource(s"/structured-streaming/$dir")
    val log = new OffsetSeqLog(spark, input.toString)
    log.getLatest().get
  }
} 
Example 130
Source File: FileStreamSourceSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.io.File
import java.net.URI

import scala.util.Random

import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext {

  import FileStreamSource._

  test("SeenFilesMap") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 5)
    assert(map.size == 1)
    map.purge()
    assert(map.size == 1)

    // Add a new entry; purge should be a no-op, since the gap is exactly 10 ms.
    map.add("b", 15)
    assert(map.size == 2)
    map.purge()
    assert(map.size == 2)

    // Add a new entry that is more than 10 ms newer than the first entry. We should be able to purge now.
    map.add("c", 16)
    assert(map.size == 3)
    map.purge()
    assert(map.size == 2)

    // Overriding an existing entry shouldn't change the size
    map.add("c", 25)
    assert(map.size == 2)

    // Not a new file because we have seen c before
    assert(!map.isNewFile("c", 20))

    // Not a new file because timestamp is too old
    assert(!map.isNewFile("d", 5))

    // Finally a new file: never seen and not too old
    assert(map.isNewFile("e", 20))
  }

  test("SeenFilesMap should only consider a file old if it is earlier than last purge time") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 20)
    assert(map.size == 1)

    // Timestamps 9 and 10 should still be considered new files because the purge time is still 0
    assert(map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))

    // Once purged, the purge time becomes 10, so b is an old file if its timestamp is less than 10.
    map.purge()
    assert(!map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))
  }

  testWithUninterruptibleThread("do not recheck that files exist during getBatch") {
    withTempDir { temp =>
      spark.conf.set(
        s"fs.$scheme.impl",
        classOf[ExistsThrowsExceptionFileSystem].getName)
      // add the metadata entries as a prerequisite
      val dir = new File(temp, "dir") // use a non-existent directory to test whether the log makes the dir
      val metadataLog =
        new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath)
      assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0))))

      val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil,
        dir.getAbsolutePath, Map.empty)
      // this method should throw an exception if `fs.exists` is called during resolveRelation
      newSource.getBatch(None, FileStreamSourceOffset(1))
    }
  }
}


// A fake FileSystem that throws if `exists` is ever called, so the test above can
// verify that getBatch does not re-check file existence during resolveRelation.
class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem {
  override def getScheme: String = scheme

  override def exists(f: Path): Boolean = {
    throw new IllegalArgumentException("Exists shouldn't have been called!")
  }

  // Simply return an empty file status for each file in the directory.
  override def listStatus(file: Path): Array[FileStatus] = {
    val emptyFile = new FileStatus()
    emptyFile.setPath(file)
    Array(emptyFile)
  }
}

object ExistsThrowsExceptionFileSystem {
  val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs"
} 
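The test above relies on Hadoop's pluggable FileSystem mechanism: a configuration key of the form fs.<scheme>.impl makes every path with that scheme resolve to the named class. The sketch below shows the same wiring in isolation; the scheme "myfs" and the class MyListingOnlyFileSystem are illustrative assumptions, not part of the suite.

import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}
import org.apache.spark.sql.SparkSession

// Hypothetical FileSystem that answers every listing with a single empty status.
class MyListingOnlyFileSystem extends RawLocalFileSystem {
  override def getScheme: String = "myfs"
  override def listStatus(f: Path): Array[FileStatus] = {
    val status = new FileStatus()
    status.setPath(f)
    Array(status)
  }
}

object FileSystemRegistrationSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    // Session conf entries starting with "fs." are copied into the Hadoop conf that
    // this session creates, so "myfs:///..." paths resolve to the class above.
    spark.conf.set("fs.myfs.impl", classOf[MyListingOnlyFileSystem].getName)
    spark.stop()
  }
}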
Example 131
Source File: DataFrameTungstenSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }
} 
Example 132
Source File: RowSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 133
Source File: ExtraStrategiesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    sparkContext.parallelize(Seq(unsafeRow))
  }

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 134
Source File: DatasetCacheSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    // default storage level
    ds1.persist()
    ds2.cache()
    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
    // unpersist
    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    // non-default storage level
    ds1.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
    // joined Dataset should not be persisted
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    assert(joined.storageLevel == StorageLevel.NONE)
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkDataset(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkDataset(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkDataset(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
} 
Example 135
Source File: DDLSourceLoadSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Note: the META-INF/services resource in the test directory had to be modified for this to work
// (see the ServiceLoader sketch after this example).
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
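The comment at the top of this suite refers to Java's ServiceLoader mechanism: spark.read.format(...) can only resolve short names such as "Fluet da Bomb" or "gathering quorum" if the implementing classes are listed, one fully qualified class name per line, in a resource file named META-INF/services/org.apache.spark.sql.sources.DataSourceRegister on the test classpath. A minimal sketch of enumerating what is registered, assuming that standard layout:

import java.util.ServiceLoader

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.DataSourceRegister

object DataSourceRegisterSketch {
  def main(args: Array[String]): Unit = {
    // ServiceLoader reads META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
    // from the classpath and instantiates every class listed there.
    val registered = ServiceLoader.load(classOf[DataSourceRegister]).asScala
    registered.foreach(r => println(s"${r.getClass.getName} -> ${r.shortName()}"))
  }
}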
Example 136
Source File: PartitionedWriteSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = spark.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = spark.range(100)
    val df = base.union(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
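The last test checks that partition columns come after data columns in the read-back schema. That follows from the Hive-style layout partitionBy produces: partition values live in directory names rather than in the data files, and Spark appends the columns it infers from those directories when reading. A rough sketch of inspecting that layout; the file name in the comment is only illustrative.

import java.io.File

object PartitionLayoutSketch {
  // For Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path), the writer
  // produces roughly:
  //   path/_SUCCESS
  //   path/i=1/part-00000-<uuid>.snappy.parquet   (the data file holds only column "j")
  // Reading `path` back infers "i" from the directory names and appends it after "j".
  def partitionDirs(path: String): Seq[String] =
    new File(path).listFiles().filter(_.isDirectory).map(_.getName).sorted.toSeq
}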
Example 137
Source File: MiscFunctionsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reflect and java_method") {
    val df = Seq((1, "one")).toDF("a", "b")
    val className = ReflectClass.getClass.getName.stripSuffix("$")
    checkAnswer(
      df.selectExpr(
        s"reflect('$className', 'method1', a, b)",
        s"java_method('$className', 'method1', a, b)"),
      Row("m1one", "m1one"))
  }
}

object ReflectClass {
  def method1(v1: Int, v2: String): String = "m" + v1 + v2
} 
Example 138
Source File: KafkaSourceOffsetSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.kafka010

import java.io.File

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.streaming.OffsetSuite
import org.apache.spark.sql.test.SharedSQLContext

class KafkaSourceOffsetSuite extends OffsetSuite with SharedSQLContext {

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("t", 1, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L), ("T", 0, 0L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("T", 0, 1L)))

  compare(
    one = KafkaSourceOffset(("t", 0, 1L)),
    two = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 1L)))


  val kso1 = KafkaSourceOffset(("t", 0, 1L))
  val kso2 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L))
  val kso3 = KafkaSourceOffset(("t", 0, 2L), ("t", 1, 3L), ("t", 1, 4L))

  compare(KafkaSourceOffset(SerializedOffset(kso1.json)),
    KafkaSourceOffset(SerializedOffset(kso2.json)))

  test("basic serialization - deserialization") {
    assert(KafkaSourceOffset.getPartitionOffsets(kso1) ==
      KafkaSourceOffset.getPartitionOffsets(SerializedOffset(kso1.json)))
  }


  testWithUninterruptibleThread("OffsetSeqLog serialization - deserialization") {
    withTempDir { temp =>
      // use a non-existent directory to test whether the log makes the dir
      val dir = new File(temp, "dir")
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(kso1)
      val batch1 = OffsetSeq.fill(kso2, kso3)

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 offset format") {
    val offset = readFromResource("kafka-source-offset-version-2.1.0.txt")
    assert(KafkaSourceOffset(offset) ===
      KafkaSourceOffset(("topic1", 0, 456L), ("topic1", 1, 789L), ("topic2", 0, 0L)))
  }

  private def readFromResource(file: String): SerializedOffset = {
    import scala.io.Source
    val input = getClass.getResource(s"/$file").toURI
    val str = Source.fromFile(input).mkString
    SerializedOffset(str)
  }
} 
Example 139
Source File: SQLUtilsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.api.r

import org.apache.spark.sql.test.SharedSQLContext

class SQLUtilsSuite extends SharedSQLContext {

  import testImplicits._

  test("dfToCols should collect and transpose a data frame") {
    val df = Seq(
      (1, 2, 3),
      (4, 5, 6)
    ).toDF
    assert(SQLUtils.dfToCols(df) === Array(
      Array(1, 4),
      Array(2, 5),
      Array(3, 6)
    ))
  }

} 
Example 140
Source File: DataFrameImplicitsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {
    checkAnswer(
      sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {
    checkAnswer(
      sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }
} 
Example 141
Source File: DebuggingSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.debug

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData.TestData

class DebuggingSuite extends SparkFunSuite with SharedSQLContext {

  test("DataFrame.debug()") {
    testData.debug()
  }

  test("Dataset.debug()") {
    import testImplicits._
    testData.as[TestData].debug()
  }

  test("debugCodegen") {
    val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan)
    assert(res.contains("Subtree 1 / 2"))
    assert(res.contains("Subtree 2 / 2"))
    assert(res.contains("Object[]"))
  }
} 
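codegenString is one entry point into the debug helpers; the same package also adds debug() and debugCodegen() to Datasets once its implicits are imported. A small sketch of that usage, assuming those implicits behave as in this Spark version:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

object DebugSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    val df = spark.range(10).groupBy("id").count()
    // Runs the query and prints per-operator tuple counts and column statistics.
    df.debug()
    // Prints the generated Java source for each whole-stage-codegen subtree.
    df.debugCodegen()
    spark.stop()
  }
}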
Example 142
Source File: RowDataSourceStrategySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
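The CREATE TEMPORARY TABLE ... USING ... OPTIONS statement in the before block is the SQL route for passing JDBC parameters; the DataFrameReader API accepts the same keys programmatically. A minimal sketch using the same illustrative H2 connection details as above:

import org.apache.spark.sql.SparkSession

object JdbcOptionsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").getOrCreate()
    // Equivalent to OPTIONS (url ..., dbtable ..., user ..., password ...) above.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:h2:mem:testdb0")
      .option("dbtable", "TEST.INTTYPES")
      .option("user", "testUser")
      .option("password", "testPass")
      .load()
    df.groupBy("a").agg("b" -> "min").show()
    spark.stop()
  }
}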
Example 143
Source File: HadoopFsRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.io.{File, FilenameFilter}

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSQLContext

class HadoopFsRelationSuite extends QueryTest with SharedSQLContext {

  test("sizeInBytes should be the total size of all files") {
    withTempDir { dir =>
      dir.delete()
      spark.range(1000).write.parquet(dir.toString)
      // ignore hidden files
      val allFiles = dir.listFiles(new FilenameFilter {
        override def accept(dir: File, name: String): Boolean = {
          !name.startsWith(".")
        }
      })
      val totalSize = allFiles.map(_.length()).sum
      val df = spark.read.parquet(dir.toString)
      assert(df.queryExecution.logical.statistics.sizeInBytes === BigInt(totalSize))
    }
  }
} 
Example 144
Source File: ParquetInteroperabilitySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    import ParquetCompatibilityTest._

    // This test case writes two Parquet files, both representing the following Catalyst schema
    //
    //   StructType(
    //     StructField(
    //       "f",
    //       ArrayType(IntegerType, containsNull = false),
    //       nullable = false))
    //
    // The first Parquet file comes with parquet-avro style 2-level LIST-annotated group, while the
    // other one comes with parquet-protobuf style 1-level unannotated primitive field.
    withTempDir { dir =>
      val avroStylePath = new File(dir, "avro-style").getCanonicalPath
      val protobufStylePath = new File(dir, "protobuf-style").getCanonicalPath

      val avroStyleSchema =
        """message avro_style {
          |  required group f (LIST) {
          |    repeated int32 array;
          |  }
          |}
        """.stripMargin

      writeDirect(avroStylePath, avroStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.group {
              rc.field("array", 0) {
                rc.addInteger(0)
                rc.addInteger(1)
              }
            }
          }
        }
      })

      logParquetSchema(avroStylePath)

      val protobufStyleSchema =
        """message protobuf_style {
          |  repeated int32 f;
          |}
        """.stripMargin

      writeDirect(protobufStylePath, protobufStyleSchema, { rc =>
        rc.message {
          rc.field("f", 0) {
            rc.addInteger(2)
            rc.addInteger(3)
          }
        }
      })

      logParquetSchema(protobufStylePath)

      checkAnswer(
        spark.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3))))
    }
  }
} 
Example 145
Source File: ParquetEncodingSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._

import org.apache.parquet.hadoop.ParquetOutputFormat

import org.apache.spark.sql.test.SharedSQLContext

// TODO: this needs a lot more testing but it's currently not easy to test with the parquet
// writer abstractions. Revisit.
class ParquetEncodingSuite extends ParquetCompatibilityTest with SharedSQLContext {
  import testImplicits._

  val ROW = ((1).toByte, 2, 3L, "abc")
  val NULL_ROW = (
    null.asInstanceOf[java.lang.Byte],
    null.asInstanceOf[Integer],
    null.asInstanceOf[java.lang.Long],
    null.asInstanceOf[String])

  test("All Types Dictionary") {
    (1 :: 1000 :: Nil).foreach { n => {
      withTempPath { dir =>
        List.fill(n)(ROW).toDF.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).getByte(i) == 1)
          assert(batch.column(1).getInt(i) == 2)
          assert(batch.column(2).getLong(i) == 3)
          assert(batch.column(3).getUTF8String(i).toString == "abc")
          i += 1
        }
        reader.close()
      }
    }}
  }

  test("All Types Null") {
    (1 :: 100 :: Nil).foreach { n => {
      withTempPath { dir =>
        val data = List.fill(n)(NULL_ROW).toDF
        data.repartition(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).toArray.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file.asInstanceOf[String], null)
        val batch = reader.resultBatch()
        assert(reader.nextBatch())
        assert(batch.numRows() == n)
        var i = 0
        while (i < n) {
          assert(batch.column(0).isNullAt(i))
          assert(batch.column(1).isNullAt(i))
          assert(batch.column(2).isNullAt(i))
          assert(batch.column(3).isNullAt(i))
          i += 1
        }
        reader.close()
      }}
    }
  }

  test("Read row group containing both dictionary and plain encoded pages") {
    withSQLConf(ParquetOutputFormat.DICTIONARY_PAGE_SIZE -> "2048",
      ParquetOutputFormat.PAGE_SIZE -> "4096") {
      withTempPath { dir =>
        // In order to explicitly test for SPARK-14217, we set the parquet dictionary and page size
        // such that the following data spans across 3 pages (within a single row group) where the
        // first page is dictionary encoded and the remaining two are plain encoded.
        val data = (0 until 512).flatMap(i => Seq.fill(3)(i.toString))
        data.toDF("f").coalesce(1).write.parquet(dir.getCanonicalPath)
        val file = SpecificParquetRecordReaderBase.listDirectory(dir).asScala.head

        val reader = new VectorizedParquetRecordReader
        reader.initialize(file, null)
        val column = reader.resultBatch().column(0)
        assert(reader.nextBatch())

        (0 until 512).foreach { i =>
          assert(column.getUTF8String(3 * i).toString == i.toString)
          assert(column.getUTF8String(3 * i + 1).toString == i.toString)
          assert(column.getUTF8String(3 * i + 2).toString == i.toString)
        }
        reader.close()
      }
    }
  }
} 
Example 146
Source File: ParquetProtobufCompatibilitySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("unannotated array of primitive type") {
    checkAnswer(readResourceParquetFile("test-data/old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }

  test("unannotated array of struct") {
    checkAnswer(
      readResourceParquetFile("test-data/old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }

  test("struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }

  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readResourceParquetFile("test-data/nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }

  test("unannotated array of string") {
    checkAnswer(
      readResourceParquetFile("test-data/proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 147
Source File: QueryExecutionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation}
import org.apache.spark.sql.test.SharedSQLContext

class QueryExecutionSuite extends SharedSQLContext {
  test("toString() exception/error handling") {
    val badRule = new SparkStrategy {
      var mode: String = ""
      override def apply(plan: LogicalPlan): Seq[SparkPlan] = mode.toLowerCase match {
        case "exception" => throw new AnalysisException(mode)
        case "error" => throw new Error(mode)
        case _ => Nil
      }
    }
    spark.experimental.extraStrategies = badRule :: Nil

    def qe: QueryExecution = new QueryExecution(spark, OneRowRelation)

    // Nothing!
    badRule.mode = ""
    assert(qe.toString.contains("OneRowRelation"))

    // Throw an AnalysisException - this should be captured.
    badRule.mode = "exception"
    assert(qe.toString.contains("org.apache.spark.sql.AnalysisException"))

    // Throw an Error - this should not be captured.
    badRule.mode = "error"
    val error = intercept[Error](qe.toString)
    assert(error.getMessage.contains("error"))
  }
} 
Example 148
Source File: ExchangeSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits._

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ShuffleExchange(SinglePartition, plan),
      input.map(Row.fromTuple)
    )
  }

  test("compatible BroadcastMode") {
    val mode1 = IdentityBroadcastMode
    val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
    val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

    assert(mode1.compatibleWith(mode1))
    assert(!mode1.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode1))
    assert(mode2.compatibleWith(mode2))
    assert(!mode2.compatibleWith(mode3))
    assert(mode3.compatibleWith(mode3))
  }

  test("BroadcastExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan)
    val hashMode = HashedRelationBroadcastMode(output)
    val exchange2 = BroadcastExchangeExec(hashMode, plan)
    val hashMode2 =
      HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil)
    val exchange3 = BroadcastExchangeExec(hashMode2, plan)
    val exchange4 = ReusedExchangeExec(output, exchange3, sparkContext.sparkUser)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)

    assert(!exchange1.sameResult(exchange2))
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(exchange4 sameResult exchange3)
  }

  test("ShuffleExchange same result") {
    val df = spark.range(10)
    val plan = df.queryExecution.executedPlan
    val output = plan.output
    assert(plan sameResult plan)

    val part1 = HashPartitioning(output, 1)
    val exchange1 = ShuffleExchange(part1, plan)
    val exchange2 = ShuffleExchange(part1, plan)
    val part2 = HashPartitioning(output, 2)
    val exchange3 = ShuffleExchange(part2, plan)
    val part3 = HashPartitioning(output ++ output, 2)
    val exchange4 = ShuffleExchange(part3, plan)
    val exchange5 = ReusedExchangeExec(output, exchange4, sparkContext.sparkUser)

    assert(exchange1 sameResult exchange1)
    assert(exchange2 sameResult exchange2)
    assert(exchange3 sameResult exchange3)
    assert(exchange4 sameResult exchange4)
    assert(exchange5 sameResult exchange5)

    assert(exchange1 sameResult exchange2)
    assert(!exchange2.sameResult(exchange3))
    assert(!exchange3.sameResult(exchange4))
    assert(!exchange4.sameResult(exchange5))
    assert(exchange5 sameResult exchange4)
  }
} 
Example 149
Source File: TakeOrderedAndProjectSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {

  private var rand: Random = _
  private var seed: Long = 0

  protected override def beforeAll(): Unit = {
    super.beforeAll()
    seed = System.currentTimeMillis()
    rand = new Random(seed)
  }

  private def generateRandomInputData(): DataFrame = {
    val schema = new StructType()
      .add("a", IntegerType, nullable = false)
      .add("b", IntegerType, nullable = false)
    val inputData = Seq.fill(10000)(Row(rand.nextInt(), rand.nextInt()))
    spark.createDataFrame(sparkContext.parallelize(Random.shuffle(inputData), 10), schema)
  }

  
  private def noOpFilter(plan: SparkPlan): SparkPlan = FilterExec(Literal(true), plan)

  val limit = 250
  val sortOrder = 'a.desc :: 'b.desc :: Nil

  test("TakeOrderedAndProject.doExecute without project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              SortExec(sortOrder, true, input))),
        sortAnswers = false)
    }
  }

  test("TakeOrderedAndProject.doExecute with project") {
    withClue(s"seed = $seed") {
      checkThatPlansAgree(
        generateRandomInputData(),
        input =>
          noOpFilter(
            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
        input =>
          GlobalLimitExec(limit,
            LocalLimitExec(limit,
              ProjectExec(Seq(input.output.last),
                SortExec(sortOrder, true, input)))),
        sortAnswers = false)
    }
  }
} 
Example 150
Source File: SparkPlannerSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, ReturnAnswer, Union}
import org.apache.spark.sql.test.SharedSQLContext

class SparkPlannerSuite extends SharedSQLContext {
  import testImplicits._

  test("Ensure to go down only the first branch, not any other possible branches") {

    case object NeverPlanned extends LeafNode {
      override def output: Seq[Attribute] = Nil
    }

    var planned = 0
    object TestStrategy extends Strategy {
      def user: String = sparkContext.sparkUser
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case ReturnAnswer(child) =>
          planned += 1
          planLater(child, user) :: planLater(NeverPlanned, user) :: Nil
        case Union(children) =>
          planned += 1
          UnionExec(children.map(p => planLater(p, user))) :: planLater(NeverPlanned, user) :: Nil
        case LocalRelation(output, data) =>
          planned += 1
          LocalTableScanExec(output, data, user) :: planLater(NeverPlanned, user) :: Nil
        case NeverPlanned =>
          fail("QueryPlanner should not go down to this branch.")
        case _ => Nil
      }
    }

    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val ds = Seq("a", "b", "c").toDS().union(Seq("d", "e", "f").toDS())

      assert(ds.collect().toSeq === Seq("a", "b", "c", "d", "e", "f"))
      assert(planned === 4)
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
Example 151
Source File: WholeStageCodegenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
Example 152
Source File: OffsetSeqLogSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.io.File

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

class OffsetSeqLogSuite extends SparkFunSuite with SharedSQLContext {

  
  case class StringOffset(override val json: String) extends Offset

  test("OffsetSeqMetadata - deserialization") {
    assert(OffsetSeqMetadata(0, 0) === OffsetSeqMetadata("""{}"""))
    assert(OffsetSeqMetadata(1, 0) === OffsetSeqMetadata("""{"batchWatermarkMs":1}"""))
    assert(OffsetSeqMetadata(0, 2) === OffsetSeqMetadata("""{"batchTimestampMs":2}"""))
    assert(
      OffsetSeqMetadata(1, 2) ===
        OffsetSeqMetadata("""{"batchWatermarkMs":1,"batchTimestampMs":2}"""))
  }

  testWithUninterruptibleThread("OffsetSeqLog - serialization - deserialization") {
    withTempDir { temp =>
      val dir = new File(temp, "dir") // use a non-existent directory to test whether the log makes the dir
      val metadataLog = new OffsetSeqLog(spark, dir.getAbsolutePath)
      val batch0 = OffsetSeq.fill(LongOffset(0), LongOffset(1), LongOffset(2))
      val batch1 = OffsetSeq.fill(StringOffset("one"), StringOffset("two"), StringOffset("three"))

      val batch0Serialized = OffsetSeq.fill(batch0.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      val batch1Serialized = OffsetSeq.fill(batch1.offsets.flatMap(_.map(o =>
        SerializedOffset(o.json))): _*)

      assert(metadataLog.add(0, batch0))
      assert(metadataLog.getLatest() === Some(0 -> batch0Serialized))
      assert(metadataLog.get(0) === Some(batch0Serialized))

      assert(metadataLog.add(1, batch1))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))

      // Adding the same batch does nothing
      metadataLog.add(1, OffsetSeq.fill(LongOffset(3)))
      assert(metadataLog.get(0) === Some(batch0Serialized))
      assert(metadataLog.get(1) === Some(batch1Serialized))
      assert(metadataLog.getLatest() === Some(1 -> batch1Serialized))
      assert(metadataLog.get(None, Some(1)) ===
        Array(0 -> batch0Serialized, 1 -> batch1Serialized))
    }
  }

  test("read Spark 2.1.0 log format") {
    val (batchId, offsetSeq) = readFromResource("offset-log-version-2.1.0")
    assert(batchId === 0)
    assert(offsetSeq.offsets === Seq(
      Some(SerializedOffset("""{"logOffset":345}""")),
      Some(SerializedOffset("""{"topic-0":{"0":1}}"""))
    ))
    assert(offsetSeq.metadata === Some(OffsetSeqMetadata(0L, 1480981499528L)))
  }

  private def readFromResource(dir: String): (Long, OffsetSeq) = {
    val input = getClass.getResource(s"/structured-streaming/$dir")
    val log = new OffsetSeqLog(spark, input.toString)
    log.getLatest().get
  }
} 
Example 153
Source File: FileStreamSourceSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import java.io.File
import java.net.URI

import scala.util.Random

import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.streaming.ExistsThrowsExceptionFileSystem._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType

class FileStreamSourceSuite extends SparkFunSuite with SharedSQLContext {

  import FileStreamSource._

  test("SeenFilesMap") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 5)
    assert(map.size == 1)
    map.purge()
    assert(map.size == 1)

    // Add a new entry; purge should be a no-op, since the gap is exactly 10 ms.
    map.add("b", 15)
    assert(map.size == 2)
    map.purge()
    assert(map.size == 2)

    // Add a new entry that's more than 10 ms newer than the first entry. We should be able to purge now.
    map.add("c", 16)
    assert(map.size == 3)
    map.purge()
    assert(map.size == 2)

    // Overriding an existing entry shouldn't change the size
    map.add("c", 25)
    assert(map.size == 2)

    // Not a new file because we have seen c before
    assert(!map.isNewFile("c", 20))

    // Not a new file because timestamp is too old
    assert(!map.isNewFile("d", 5))

    // Finally a new file: never seen and not too old
    assert(map.isNewFile("e", 20))
  }

  test("SeenFilesMap should only consider a file old if it is earlier than last purge time") {
    val map = new SeenFilesMap(maxAgeMs = 10)

    map.add("a", 20)
    assert(map.size == 1)

    // Timestamps 9 and 10 should still be considered new files because the purge time is still 0
    assert(map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))

    // Once purged, the purge time becomes 10, so b is an old file if its timestamp is less than 10.
    map.purge()
    assert(!map.isNewFile("b", 9))
    assert(map.isNewFile("b", 10))
  }

  testWithUninterruptibleThread("do not recheck that files exist during getBatch") {
    withTempDir { temp =>
      spark.conf.set(
        s"fs.$scheme.impl",
        classOf[ExistsThrowsExceptionFileSystem].getName)
      // add the metadata entries as a prerequisite
      val dir = new File(temp, "dir") // use a non-existent directory to test whether the log makes the dir
      val metadataLog =
        new FileStreamSourceLog(FileStreamSourceLog.VERSION, spark, dir.getAbsolutePath)
      assert(metadataLog.add(0, Array(FileEntry(s"$scheme:///file1", 100L, 0))))

      val newSource = new FileStreamSource(spark, s"$scheme:///", "parquet", StructType(Nil), Nil,
        dir.getAbsolutePath, Map.empty)
      // this method should throw an exception if `fs.exists` is called during resolveRelation
      newSource.getBatch(None, FileStreamSourceOffset(1))
    }
  }
}


// A fake FileSystem that throws if `exists` is ever called, so the test above can
// verify that getBatch does not re-check file existence during resolveRelation.
class ExistsThrowsExceptionFileSystem extends RawLocalFileSystem {
  override def getScheme: String = scheme

  override def exists(f: Path): Boolean = {
    throw new IllegalArgumentException("Exists shouldn't have been called!")
  }

  // Simply return an empty file status for each file in the directory.
  override def listStatus(file: Path): Array[FileStatus] = {
    val emptyFile = new FileStatus()
    emptyFile.setPath(file)
    Array(emptyFile)
  }
}

object ExistsThrowsExceptionFileSystem {
  val scheme = s"FileStreamSourceSuite${math.abs(Random.nextInt)}fs"
} 
Example 154
Source File: DataFrameTungstenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("test simple types") {
    val df = sparkContext.parallelize(Seq((1, 2))).toDF("a", "b")
    assert(df.select(struct("a", "b")).first().getStruct(0) === Row(1, 2))
  }

  test("test struct type") {
    val struct = Row(1, 2L, 3.0F, 3.0)
    val data = sparkContext.parallelize(Seq(Row(1, struct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(struct))
  }

  test("test nested struct type") {
    val innerStruct = Row(1, "abcd")
    val outerStruct = Row(1, 2L, 3.0F, 3.0, innerStruct, "efg")
    val data = sparkContext.parallelize(Seq(Row(1, outerStruct)))

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b",
        new StructType()
          .add("b1", IntegerType)
          .add("b2", LongType)
          .add("b3", FloatType)
          .add("b4", DoubleType)
          .add("b5", new StructType()
          .add("b5a", IntegerType)
          .add("b5b", StringType))
          .add("b6", StringType))

    val df = spark.createDataFrame(data, schema)
    assert(df.select("b").first() === Row(outerStruct))
  }
} 
Example 155
Source File: RowSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 156
Source File: ExtraStrategiesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.test.SharedSQLContext

case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    val unsafeProj = UnsafeProjection.create(schema)
    val unsafeRow = unsafeProj(row).copy()
    sparkContext.parallelize(Seq(unsafeRow))
  }

  override def producedAttributes: AttributeSet = outputSet
  override def children: Seq[SparkPlan] = Nil
}

object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}

class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {
    try {
      spark.experimental.extraStrategies = TestStrategy :: Nil

      val df = sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      spark.experimental.extraStrategies = Nil
    }
  }
} 
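A minimal sketch of the same hook outside a test, assuming a SparkSession `spark`: registering TestStrategy makes any single-column projection of "a" plan as FastOperator, and clearing the list restores the default planner.

spark.experimental.extraStrategies = TestStrategy :: Nil
try {
  spark.range(5).toDF("a").select("a").show() // planned by FastOperator, prints "so fast"
} finally {
  spark.experimental.extraStrategies = Nil    // always restore the planner
}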
Example 157
Source File: DatasetCacheSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.StorageLevel


class DatasetCacheSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("get storage level") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    // default storage level
    ds1.persist()
    ds2.cache()
    assert(ds1.storageLevel == StorageLevel.MEMORY_AND_DISK)
    assert(ds2.storageLevel == StorageLevel.MEMORY_AND_DISK)
    // unpersist
    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE)
    // non-default storage level
    ds1.persist(StorageLevel.MEMORY_ONLY_2)
    assert(ds1.storageLevel == StorageLevel.MEMORY_ONLY_2)
    // joined Dataset should not be persisted
    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    assert(joined.storageLevel == StorageLevel.NONE)
  }

  test("persist and unpersist") {
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS().select(expr("_2 + 1").as[Int])
    val cached = ds.cache()
    // count triggers the caching action. It should not throw.
    cached.count()
    // Make sure the Dataset is indeed cached.
    assertCached(cached)
    // Check result.
    checkDataset(
      cached,
      2, 3, 4)
    // Drop the cache.
    cached.unpersist()
    assert(cached.storageLevel == StorageLevel.NONE, "The Dataset should not be cached.")
  }

  test("persist and then rebind right encoder when join 2 datasets") {
    val ds1 = Seq("1", "2").toDS().as("a")
    val ds2 = Seq(2, 3).toDS().as("b")

    ds1.persist()
    assertCached(ds1)
    ds2.persist()
    assertCached(ds2)

    val joined = ds1.joinWith(ds2, $"a.value" === $"b.value")
    checkDataset(joined, ("2", 2))
    assertCached(joined, 2)

    ds1.unpersist()
    assert(ds1.storageLevel == StorageLevel.NONE, "The Dataset ds1 should not be cached.")
    ds2.unpersist()
    assert(ds2.storageLevel == StorageLevel.NONE, "The Dataset ds2 should not be cached.")
  }

  test("persist and then groupBy columns asKey, map") {
    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
    val grouped = ds.groupByKey(_._1)
    val agged = grouped.mapGroups { case (g, iter) => (g, iter.map(_._2).sum) }
    agged.persist()

    checkDataset(
      agged.filter(_._1 == "b"),
      ("b", 3))
    assertCached(agged.filter(_._1 == "b"))

    ds.unpersist()
    assert(ds.storageLevel == StorageLevel.NONE, "The Dataset ds should not be cached.")
    agged.unpersist()
    assert(agged.storageLevel == StorageLevel.NONE, "The Dataset agged should not be cached.")
  }
} 
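The same API accepts an explicit storage level when the MEMORY_AND_DISK default is not what you want; a minimal sketch, assuming a SparkSession `spark`:

import org.apache.spark.storage.StorageLevel
import spark.implicits._

val ds = Seq(1, 2, 3).toDS()
ds.persist(StorageLevel.MEMORY_ONLY) // explicit level instead of the default
ds.count()                           // an action materializes the cache
ds.unpersist()                       // storageLevel drops back to StorageLevel.NONE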
Example 158
Source File: DDLSourceLoadSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{AnalysisException, SQLContext}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Note: META-INF/services in the test directory had to be modified for this to work.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {
    intercept[RuntimeException] {
      spark.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {
    spark.read.format("gathering quorum").load().schema ==
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {
    spark.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without Hive Support") {
    val e = intercept[AnalysisException] {
      spark.read.format("orc").load()
    }
    assert(e.message.contains("The ORC data source must be used with Hive support enabled"))
  }
}


class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}

class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
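The short names above ("Fluet da Bomb", "gathering quorum") are discovered through Java's ServiceLoader, which is what the META-INF/services note refers to. A hedged sketch of the registration file and its effect (the file path is the standard ServiceLoader convention for DataSourceRegister):

// File: META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
//   org.apache.spark.sql.sources.FakeSourceOne
//   org.apache.spark.sql.sources.FakeSourceTwo
//   org.apache.spark.sql.sources.FakeSourceThree

// With the providers registered, an alias resolves without the full class name:
val schema = spark.read.format("gathering quorum").load().schema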
Example 159
Source File: PartitionedWriteSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils

class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("write many partitions") {
    val path = Utils.createTempDir()
    path.delete()

    val df = spark.range(100).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("write many partitions with repeats") {
    val path = Utils.createTempDir()
    path.delete()

    val base = spark.range(100)
    val df = base.union(base).select($"id", lit(1).as("data"))
    df.write.partitionBy("id").save(path.getCanonicalPath)

    checkAnswer(
      spark.read.load(path.getCanonicalPath),
      (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)

    Utils.deleteRecursively(path)
  }

  test("partitioned columns should appear at the end of schema") {
    withTempPath { f =>
      val path = f.getAbsolutePath
      Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path)
      assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i"))
    }
  }
} 
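partitionBy("id") lays the output out as one directory per key (id=0/, id=1/, ...), which is also why the partition column moves to the end of the read schema. A minimal sketch, assuming a SparkSession `spark` and a hypothetical scratch path:

import org.apache.spark.sql.functions.lit
import spark.implicits._

val out = "/tmp/partitioned-demo" // hypothetical path
spark.range(10).select($"id", lit(1).as("data")).write.partitionBy("id").parquet(out)
spark.read.parquet(out).where($"id" === 3).show() // prunes to the id=3/ directory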
Example 160
Source File: MiscFunctionsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

class MiscFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("reflect and java_method") {
    val df = Seq((1, "one")).toDF("a", "b")
    val className = ReflectClass.getClass.getName.stripSuffix("$")
    checkAnswer(
      df.selectExpr(
        s"reflect('$className', 'method1', a, b)",
        s"java_method('$className', 'method1', a, b)"),
      Row("m1one", "m1one"))
  }
}

object ReflectClass {
  def method1(v1: Int, v2: String): String = "m" + v1 + v2
} 
Example 161
Source File: SQLContextSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext
// SQLContext test suite
class SQLContextSuite extends SparkFunSuite with SharedSQLContext {

  override def afterAll(): Unit = {
    try {
      SQLContext.setLastInstantiatedContext(ctx)
    } finally {
      super.afterAll()
    }
  }

  test("getOrCreate instantiates SQLContext") {//获取或创建实例化SQL上下文
    SQLContext.clearLastInstantiatedContext()
    val sqlContext = SQLContext.getOrCreate(ctx.sparkContext)
    assert(sqlContext != null, "SQLContext.getOrCreate returned null")
    assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
      "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate")
  }

  test("getOrCreate gets last explicitly instantiated SQLContext") {//获得或创造获取最后的显式实例化SQL上下文
    SQLContext.clearLastInstantiatedContext()
    val sqlContext = new SQLContext(ctx.sparkContext)
    assert(SQLContext.getOrCreate(ctx.sparkContext) != null,
      "SQLContext.getOrCreate after explicitly created SQLContext returned null")
    assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext),
      "SQLContext.getOrCreate after explicitly created SQLContext did not return the context")
  }
} 
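A minimal sketch of the getOrCreate pattern this suite verifies, in Spark 1.x application code (names illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local[*]"))
val sqlContext = SQLContext.getOrCreate(sc)      // creates the context on first call
assert(SQLContext.getOrCreate(sc) eq sqlContext) // later calls return the same instance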
Example 162
Source File: JsonFunctionsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext
// JSON functions test suite
class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("function get_json_object") {//使用get_json_object函数得JSON对象
    val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")
    df.registerTempTable("df")   
    
    df.show()  
    checkAnswer(
      df.selectExpr("get_json_object(a, '$.name')", "get_json_object(a, '$.age')"),
      Row("alice", "5"))
  }

} 
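get_json_object takes a JSONPath-style expression, so nested fields and array elements can be extracted the same way; a hedged sketch in the suite's own style, with hypothetical data:

val df = Seq(("""{"user": {"name": "alice", "tags": ["a", "b"]}}""", "")).toDF("a", "b")
checkAnswer(
  df.selectExpr(
    "get_json_object(a, '$.user.name')",    // nested field
    "get_json_object(a, '$.user.tags[0]')"  // array element
  ),
  Row("alice", "a"))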
Example 163
Source File: DataFrameImplicitsSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext
// DataFrame implicit conversion tests
class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("RDD of tuples") {//RDD元组转换DF
    checkAnswer(
      ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("Seq of tuples") {//序列元组转换DF
    checkAnswer(
      (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
      (1 to 10).map(i => Row(i, i.toString)))
  }

  test("RDD[Int]") {//RDD[Int]转换DF
    checkAnswer(
      ctx.sparkContext.parallelize(1 to 10).toDF("intCol"),
      (1 to 10).map(i => Row(i)))
  }

  test("RDD[Long]") {//RDD[Long]转换DF
    checkAnswer(
      ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"),
      (1L to 10L).map(i => Row(i)))
  }

  test("RDD[String]") {//RDD[String]转换DF
    checkAnswer(
      ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
      (1 to 10).map(i => Row(i.toString)))
  }
} 
Example 164
Source File: TungstenAggregationIteratorSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.aggregate

import org.apache.spark._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.memory.TaskMemoryManager
// TungstenAggregationIterator test suite
class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {

  test("memory acquired on construction") {//在内存上构建
    val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
    val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
    TaskContext.setTaskContext(taskContext)

    // Assert that a page is allocated before processing starts
    var iter: TungstenAggregationIterator = null
    try {
      val newMutableProjection = (expr: Seq[Expression], schema: Seq[Attribute]) => {
        () => new InterpretedMutableProjection(expr, schema)
      }
      val dummyAccum = SQLMetrics.createLongMetric(ctx.sparkContext, "dummy")
      iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, 0,
        Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum)
      val numPages = iter.getHashMap.getNumDataPages
      assert(numPages === 1)
    } finally {
      // Clean up
      if (iter != null) {
        iter.free()
      }
      TaskContext.unset()
    }
  }
} 
Example 165
Source File: SemiJoinSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
// Semi-join test suite
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }
  // Run the left semi join tests over the frames above.
  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
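The plans exercised above implement SQL's LEFT SEMI JOIN: each left row with at least one match is kept once, and no right-side columns are emitted. A minimal sketch of the same query through SQL, assuming the lazy left/right frames above (Spark 1.x API):

left.registerTempTable("l")
right.registerTempTable("r")
// Returns the two (2, 1.0) rows, matching the basic test's expected answer.
ctx.sql("SELECT * FROM l LEFT SEMI JOIN r ON l.a = r.c AND l.b < r.d").show()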
Example 166
Source File: ParquetInteroperabilitySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import java.io.File

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

class ParquetInteroperabilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  test("parquet files with different physical schemas but share the same logical schema") {
    withTempDir { dir =>
      // ... setup that writes the int-array Parquet files under `dir` (elided in this listing) ...
      sqlContext.read.parquet(dir.getCanonicalPath).show()
      checkAnswer(
        sqlContext.read.parquet(dir.getCanonicalPath),
        Seq(
          Row(Seq(0, 1)),
          Row(Seq(2, 3)),
          Row(Seq(4, 5)),
          Row(Seq(6, 7))))
    }
  }
} 
Example 167
Source File: ParquetThriftCompatibilitySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext
// Parquet-Thrift compatibility test suite
class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  import ParquetCompatibilityTest._

  private val parquetFilePath =
    // Locate the test resource on the classpath via the thread's context class loader.
    Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")

  test("Read Parquet file generated by parquet-thrift") {
    logInfo(
      s"""Schema of the Parquet file written by parquet-thrift:
         |${readParquetSchema(parquetFilePath.toString)}
       """.stripMargin)

    checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
      def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)

      val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")

      Row(
        i % 2 == 0,
        i.toByte,
        (i + 1).toShort,
        i + 2,
        i.toLong * 10,
        i.toDouble + 0.2d,
        // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always
        // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume
        // Thrift `STRING`s are encoded using UTF-8.
        s"val_$i",
        s"val_$i",
        // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
        suits(i % 4),

        nullable(i % 2 == 0: java.lang.Boolean),
        nullable(i.toByte: java.lang.Byte),
        nullable((i + 1).toShort: java.lang.Short),
        nullable(i + 2: Integer),
        nullable((i * 10).toLong: java.lang.Long),
        nullable(i.toDouble + 0.2d: java.lang.Double),
        nullable(s"val_$i"),
        nullable(s"val_$i"),
        nullable(suits(i % 4)),

        Seq.tabulate(3)(n => s"arr_${i + n}"),
        // Thrift `SET`s are converted to Parquet `LIST`s
        Seq(i),
        Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap,
        Seq.tabulate(3) { n =>
          (i + n) -> Seq.tabulate(3) { m =>
            Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
          }
        }.toMap)
    })
  }
} 
Example 168
Source File: ParquetProtobufCompatibilitySuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.parquet

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.test.SharedSQLContext

class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
  // Read a parquet-protobuf test file from the classpath.
  private def readParquetProtobufFile(name: String): DataFrame = {
    val url = Thread.currentThread().getContextClassLoader.getResource(name)
    sqlContext.read.parquet(url.toString)
  }
  test("unannotated array of primitive type") {
    checkAnswer(readParquetProtobufFile("old-repeated-int.parquet"), Row(Seq(1, 2, 3)))
  }
  test("unannotated array of struct") {
    checkAnswer(
      readParquetProtobufFile("old-repeated-message.parquet"),
      Row(
        Seq(
          Row("First inner", null, null),
          Row(null, "Second inner", null),
          Row(null, null, "Third inner"))))

    checkAnswer(
      readParquetProtobufFile("proto-repeated-struct.parquet"),
      Row(
        Seq(
          Row("0 - 1", "0 - 2", "0 - 3"),
          Row("1 - 1", "1 - 2", "1 - 3"))))

    checkAnswer(
      readParquetProtobufFile("proto-struct-with-array-many.parquet"),
      Seq(
        Row(
          Seq(
            Row("0 - 0 - 1", "0 - 0 - 2", "0 - 0 - 3"),
            Row("0 - 1 - 1", "0 - 1 - 2", "0 - 1 - 3"))),
        Row(
          Seq(
            Row("1 - 0 - 1", "1 - 0 - 2", "1 - 0 - 3"),
            Row("1 - 1 - 1", "1 - 1 - 2", "1 - 1 - 3"))),
        Row(
          Seq(
            Row("2 - 0 - 1", "2 - 0 - 2", "2 - 0 - 3"),
            Row("2 - 1 - 1", "2 - 1 - 2", "2 - 1 - 3")))))
  }
  test("struct with unannotated array") {
    checkAnswer(
      readParquetProtobufFile("proto-struct-with-array.parquet"),
      Row(10, 9, Seq.empty, null, Row(9), Seq(Row(9), Row(10))))
  }
  test("unannotated array of struct with unannotated array") {
    checkAnswer(
      readParquetProtobufFile("nested-array-struct.parquet"),
      Seq(
        Row(2, Seq(Row(1, Seq(Row(3))))),
        Row(5, Seq(Row(4, Seq(Row(6))))),
        Row(8, Seq(Row(7, Seq(Row(9)))))))
  }
  test("unannotated array of string") {
    checkAnswer(
      readParquetProtobufFile("proto-repeated-string.parquet"),
      Seq(
        Row(Seq("hello", "world")),
        Row(Seq("good", "bye")),
        Row(Seq("one", "two", "three"))))
  }
} 
Example 169
Source File: ExchangeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.test.SharedSQLContext

class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  test("shuffling UnsafeRows in exchange") {
    val input = (1 to 1000).map(Tuple1.apply)
    checkAnswer(
      input.toDF(),
      plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))),
      input.map(Row.fromTuple)
    )
  }
} 
Example 170
Source File: TungstenSortSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import scala.util.Random

import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._


class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {

  override def beforeAll(): Unit = {
    super.beforeAll()
    ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
  }

  override def afterAll(): Unit = {
    try {
      ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
    } finally {
      super.afterAll()
    }
  }

  test("sort followed by limit") {//排序下列限制
    checkThatPlansAgree(
      (1 to 100).map(v => Tuple1(v)).toDF("a"),
      (child: SparkPlan) => Limit(10, TungstenSort('a.asc :: Nil, true, child)),
      (child: SparkPlan) => Limit(10, Sort('a.asc :: Nil, global = true, child)),
      sortAnswers = false
    )
  }

  test("sorting does not crash for large inputs") {//大输入的排序不崩溃
  //Nil是一个空的List,::向队列的头部追加数据,创造新的列表
    val sortOrder = 'a.asc :: Nil
    val stringLength = 1024 * 1024 * 2
    checkThatPlansAgree(
      Seq(Tuple1("a" * stringLength), Tuple1("b" * stringLength)).toDF("a").repartition(1),
      TungstenSort(sortOrder, global = true, _: SparkPlan, testSpillFrequency = 1),
      Sort(sortOrder, global = true, _: SparkPlan),
      sortAnswers = false
    )
  }

  test("sorting updates peak execution memory") {//排序更新执行内存值
    val sc = ctx.sparkContext
    AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
      checkThatPlansAgree(
        (1 to 100).map(v => Tuple1(v)).toDF("a"),
        (child: SparkPlan) => TungstenSort('a.asc :: Nil, true, child),
        (child: SparkPlan) => Sort('a.asc :: Nil, global = true, child),
        sortAnswers = false)
    }
  }

  // Test sorting on different data types
  for (
    dataType <- DataTypeTestUtils.atomicTypes ++ Set(NullType);
    nullable <- Seq(true, false);
    sortOrder <- Seq('a.asc :: Nil, 'a.desc :: Nil);
    randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
  ) {
    test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
      val inputData = Seq.fill(1000)(randomDataGenerator())
      val inputDf = ctx.createDataFrame(
        ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
        StructType(StructField("a", dataType, nullable = true) :: Nil)
      )
      assert(TungstenSort.supportsSchema(inputDf.schema))
      checkThatPlansAgree(
        inputDf,
        plan => ConvertToSafe(
          TungstenSort(sortOrder, global = true, plan: SparkPlan, testSpillFrequency = 23)),
        Sort(sortOrder, global = true, _: SparkPlan),
        sortAnswers = false
      )
    }
  }
} 
Example 171
Source File: SortSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.SharedSQLContext

class SortSuite extends SparkPlanTest with SharedSQLContext {

  // This test was originally added as an example of how to use [[SparkPlanTest]];
  // it's not designed to be a comprehensive test of ExternalSort.
  test("basic sorting using ExternalSort") {

    val input = Seq(
      ("Hello", 4, 2.0),
      ("Hello", 1, 1.0),
      ("World", 8, 3.0)
    )

    checkAnswer(
      input.toDF("a", "b", "c"),
      ExternalSort('a.asc :: 'b.asc :: Nil, global = true, _: SparkPlan),
      input.sortBy(t => (t._1, t._2)).map(Row.fromTuple),
      sortAnswers = false)

    checkAnswer(
      input.toDF("a", "b", "c"),
      ExternalSort('b.asc :: 'a.asc :: Nil, global = true, _: SparkPlan),
      input.sortBy(t => (t._2, t._1)).map(Row.fromTuple),
      sortAnswers = false)
  }
} 
Example 172
Source File: ExtraStrategiesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package test.org.apache.spark.sql

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.UTF8String
// A trivial physical operator that always returns the single row "so fast".
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {

  override protected def doExecute(): RDD[InternalRow] = {
    val str = Literal("so fast").value
    val row = new GenericInternalRow(Array[Any](str))
    sparkContext.parallelize(Seq(row))
  }
  override def children: Seq[SparkPlan] = Nil
}
// Planner strategy that matches a single-column projection of "a" and plans it as FastOperator.
object TestStrategy extends Strategy {
  def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case Project(Seq(attr), _) if attr.name == "a" =>
      FastOperator(attr.toAttribute :: Nil) :: Nil
    case _ => Nil
  }
}
class ExtraStrategiesSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("insert an extraStrategy") {//插入一个额外的策略
    try {
      sqlContext.experimental.extraStrategies = TestStrategy :: Nil

      val df = sqlContext.sparkContext.parallelize(Seq(("so slow", 1))).toDF("a", "b")
      checkAnswer(
        df.select("a"),
        Row("so fast"))

      checkAnswer(
        df.select("a", "b"),
        Row("so slow", 1))
    } finally {
      sqlContext.experimental.extraStrategies = Nil
    }
  }
} 
Example 173
Source File: ListTablesSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
// ListTables test suite
class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
  }

  test("get all tables") {//获得所有的表
      
      
    ctx.tables("DB").show()
    
    checkAnswer(
      //使用数据库,查找表名
      ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      // look up the table via a SQL command
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {//查询返回的数据集的表名
  //StructType代表一张表,StructField代表一个字段
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql(
            "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        ctx.dropTempTable("tables")
    }
  }
} 
Example 174
Source File: DataFrameComplexTypeSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext

class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("UDF on struct") {
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(struct($"a").as("s")).select(f($"s.a")).collect()
  }

  test("UDF on named_struct") {
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.selectExpr("named_struct('a', a) s").select(f($"s.a")).collect()
  }

  test("UDF on array") {//数组
    val f = udf((a: String) => a)
    val df = sqlContext.sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
    df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
  }

  test("SPARK-12477 accessing null element in array field") {//数组字段中的空元素
    val df = sqlContext.sparkContext.parallelize(Seq((Seq("val1", null, "val2"),
      Seq(Some(1), None, Some(2))))).toDF("s", "i")
    val nullStringRow = df.selectExpr("s[1]").collect()(0)
    assert(nullStringRow == org.apache.spark.sql.Row(null))
    val nullIntRow = df.selectExpr("i[1]").collect()(0)
    assert(nullIntRow == org.apache.spark.sql.Row(null))
  }
} 
Example 175
Source File: DDLSourceLoadSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}


// Note: META-INF/services in the test directory had to be modified for this to work.
class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {

  test("data sources with the same name") {//相同的名字的数据源
    intercept[RuntimeException] {
      caseInsensitiveContext.read.format("Fluet da Bomb").load()
    }
  }

  test("load data source from format alias") {//从格式化别名加载数据源
    caseInsensitiveContext.read.format("gathering quorum").load().schema ==
    //StructType代表一张表,StructField代表一个字段
      StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("specify full classname with duplicate formats") {//重复的格式指定完整的类名
    caseInsensitiveContext.read.format("org.apache.spark.sql.sources.FakeSourceOne")
      .load().schema == StructType(Seq(StructField("stringType", StringType, nullable = false)))
  }

  test("should fail to load ORC without HiveContext") {//
    intercept[ClassNotFoundException] {
      caseInsensitiveContext.read.format("orc").load()
    }
  }
}

// Fake data source one
class FakeSourceOne extends RelationProvider with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}
// Fake data source two
class FakeSourceTwo extends RelationProvider  with DataSourceRegister {

  def shortName(): String = "Fluet da Bomb"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
}
// Fake data source three
class FakeSourceThree extends RelationProvider with DataSourceRegister {

  def shortName(): String = "gathering quorum"

  override def createRelation(cont: SQLContext, param: Map[String, String]): BaseRelation =
    new BaseRelation {
      override def sqlContext: SQLContext = cont

      override def schema: StructType =
        StructType(Seq(StructField("stringType", StringType, nullable = false)))
    }
} 
Example 176
Source File: PrunedScanSuite.scala    From spark1.52   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import scala.language.existentials

import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
// With PrunedScan, the planner passes the required columns down to the source,
// so the data source does not have to return the other columns.
class PrunedScanSource extends RelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    SimplePrunedScan(parameters("from").toInt, parameters("to").toInt)(sqlContext)
  }
}

case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
  extends BaseRelation
  with PrunedScan {

  override def schema: StructType =
    StructType(
      StructField("a", IntegerType, nullable = false) ::
      StructField("b", IntegerType, nullable = false) :: Nil)

  override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
    val rowBuilders = requiredColumns.map {
      case "a" => (i: Int) => Seq(i)
      case "b" => (i: Int) => {
        //println(">>>>>>>"+i * 2)
        Seq(i * 2)
      }
    }
    sqlContext.sparkContext.parallelize(from to to).map(i =>
      Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
  }
}

class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql(
      """
        |CREATE TEMPORARY TABLE oneToTenPruned
        |USING org.apache.spark.sql.sources.PrunedScanSource
        |OPTIONS (
        |  from '1',
        |  to '10'
        |)
      """.stripMargin)
     
  }

  def testPruning(sqlString: String, expectedColumns: String*): Unit = {
    test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
      val queryExecution = sql(sqlString).queryExecution
      val rawPlan = queryExecution.executedPlan.collect {
        case p: execution.PhysicalRDD => p
      } match {
        case Seq(p) => p
        case _ => fail(s"More than one PhysicalRDD found\n$queryExecution")
      }
      val rawColumns = rawPlan.output.map(_.name)
      val rawOutput = rawPlan.execute().first()

      if (rawColumns != expectedColumns) {
        fail(
          s"Wrong column names. Got $rawColumns, Expected $expectedColumns\n" +
          s"Filters pushed: ${FiltersPushed.list.mkString(",")}\n" +
            queryExecution)
      }

      if (rawOutput.numFields != expectedColumns.size) {
        fail(s"Wrong output row. Got $rawOutput\n$queryExecution")
      }
    }
  }

} 
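The listing stops at the testPruning helper; a hedged sketch of how it is meant to be driven (calls in the style of the original suite, not reproduced from it):

testPruning("SELECT * FROM oneToTenPruned", "a", "b")
testPruning("SELECT a FROM oneToTenPruned", "a")
testPruning("SELECT b FROM oneToTenPruned", "b")
// Each call asserts that buildScan received exactly the expected columns,
// so the last query never materializes column a.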
Example 177
Source File: RowSuite.scala    From drizzle-spark   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
} 
Example 178
Source File: RowSuite.scala    From XSQL   with Apache License 2.0 4 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, SpecificInternalRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class RowSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("create row") {
    val expected = new GenericInternalRow(4)
    expected.setInt(0, 2147483647)
    expected.update(1, UTF8String.fromString("this is a string"))
    expected.setBoolean(2, false)
    expected.setNullAt(3)

    val actual1 = Row(2147483647, "this is a string", false, null)
    assert(expected.numFields === actual1.size)
    assert(expected.getInt(0) === actual1.getInt(0))
    assert(expected.getString(1) === actual1.getString(1))
    assert(expected.getBoolean(2) === actual1.getBoolean(2))
    assert(expected.isNullAt(3) === actual1.isNullAt(3))

    val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null))
    assert(expected.numFields === actual2.size)
    assert(expected.getInt(0) === actual2.getInt(0))
    assert(expected.getString(1) === actual2.getString(1))
    assert(expected.getBoolean(2) === actual2.getBoolean(2))
    assert(expected.isNullAt(3) === actual2.isNullAt(3))
  }

  test("SpecificMutableRow.update with null") {
    val row = new SpecificInternalRow(Seq(IntegerType))
    row(0) = null
    assert(row.isNullAt(0))
  }

  test("get values by field name on Row created via .toDF") {
    val row = Seq((1, Seq(1))).toDF("a", "b").first()
    assert(row.getAs[Int]("a") === 1)
    assert(row.getAs[Seq[Int]]("b") === Seq(1))

    intercept[IllegalArgumentException]{
      row.getAs[Int]("c")
    }
  }

  test("float NaN == NaN") {
    val r1 = Row(Float.NaN)
    val r2 = Row(Float.NaN)
    assert(r1 === r2)
  }

  test("double NaN == NaN") {
    val r1 = Row(Double.NaN)
    val r2 = Row(Double.NaN)
    assert(r1 === r2)
  }

  test("equals and hashCode") {
    val r1 = Row("Hello")
    val r2 = Row("Hello")
    assert(r1 === r2)
    assert(r1.hashCode() === r2.hashCode())
    val r3 = Row("World")
    assert(r3.hashCode() != r1.hashCode())
  }
}