org.apache.spark.sql.types.IntegerType Scala Examples

The following examples show how to use org.apache.spark.sql.types.IntegerType. Each example is taken from an open-source project; the source file and project it comes from are listed above it.
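Before the project examples, here is a minimal, self-contained sketch (not taken from any project below; names such as IntegerTypeQuickStart, id, and age are illustrative) showing the two most common uses of IntegerType: declaring it in a schema and casting a column to it.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object IntegerTypeQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("IntegerTypeQuickStart").getOrCreate()

    // Declare IntegerType explicitly in a schema.
    val schema = StructType(Seq(
      StructField("id", IntegerType, nullable = false),
      StructField("age", StringType, nullable = true)))
    val df = spark.createDataFrame(
      spark.sparkContext.parallelize(Seq(Row(1, "42"), Row(2, "17"))),
      schema)

    // Cast a string column to IntegerType and check the resulting schema.
    val withIntAge = df.withColumn("age", col("age").cast(IntegerType))
    assert(withIntAge.schema("age").dataType == IntegerType)

    withIntAge.show()
    spark.stop()
  }
}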
Example 1
Source File: ScalaUDFSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
Example 2
Source File: HashSetManager.scala    From BigDatalog   with Apache License 2.0
package edu.ucla.cs.wis.bigdatalog.spark.storage

import edu.ucla.cs.wis.bigdatalog.spark.SchemaInfo
import edu.ucla.cs.wis.bigdatalog.spark.storage.set.hashset._
import org.apache.spark.TaskContext
import org.apache.spark.sql.types.{IntegerType, LongType}

object HashSetManager {
  // Key-type codes: 1 = single int key, 2 = single long key (or two int keys
  // packed into one long), 3 = arbitrary object keys.
  def determineKeyType(schemaInfo: SchemaInfo): Int = {
    schemaInfo.arity match {
      case 1 => {
        schemaInfo.schema(0).dataType match {
          case IntegerType => 1
          case LongType => 2
          case other => 3
        }
      }
      case 2 => {
        // Two 4-byte keys fit into a single 8-byte long; anything wider uses object keys.
        val bytesPerKey = schemaInfo.schema.map(_.dataType.defaultSize).sum
        if (bytesPerKey == 8) 2 else 3
      }
      case other => 3
    }
  }

  def create(schemaInfo: SchemaInfo): HashSet = {
    determineKeyType(schemaInfo) match {
      case 1 => new IntKeysHashSet()
      case 2 => new LongKeysHashSet(schemaInfo)
      case _ => new ObjectHashSet()
    }
  }
} 
Example 3
Source File: HashSetRowIterator.scala    From BigDatalog   with Apache License 2.0
package edu.ucla.cs.wis.bigdatalog.spark.storage.set.hashset

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class ObjectHashSetRowIterator(set: ObjectHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    rawIter.next()
  }
}

class IntKeysHashSetRowIterator(set: IntKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    bufferHolder.reset()
    rowWriter.initialize(bufferHolder, 1)
    rowWriter.write(0, rawIter.next())

    uRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize())
    uRow
  }
}

class LongKeysHashSetRowIterator(set: LongKeysHashSet) extends Iterator[InternalRow] {
  val rawIter = set.iterator()
  val numFields = set.schemaInfo.arity
  val uRow = new UnsafeRow()
  val bufferHolder = new BufferHolder()
  val rowWriter = new UnsafeRowWriter()

  override final def hasNext(): Boolean = {
    rawIter.hasNext
  }

  override final def next(): InternalRow = {
    bufferHolder.reset()
    rowWriter.initialize(bufferHolder, numFields)

    val value = rawIter.nextLong()
    if (numFields == 2) {
      // Two int keys were packed into one long: high 32 bits first, then low 32 bits.
      rowWriter.write(0, (value >> 32).toInt)
      rowWriter.write(1, value.toInt)
    } else {
      rowWriter.write(0, value)
    }
    uRow.pointTo(bufferHolder.buffer, numFields, bufferHolder.totalSize())
    uRow
  }
}

object HashSetRowIterator {
  def create(set: HashSet): Iterator[InternalRow] = {
    set match {
      //case set: UnsafeFixedWidthSet => set.iterator().asScala
      case set: IntKeysHashSet => new IntKeysHashSetRowIterator(set)
      case set: LongKeysHashSet => new LongKeysHashSetRowIterator(set)
      case set: ObjectHashSet => new ObjectHashSetRowIterator(set)
    }
  }
} 
Example 4
Source File: MockSourceProvider.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class MockSourceProvider extends StreamSourceProvider {
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    MockSourceProvider.sourceProviderFunction()
  }
}

object MockSourceProvider {
  // Function to generate sources. May provide multiple sources if the user implements such a
  // function.
  private var sourceProviderFunction: () => Source = _

  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
      source
    }
    try {
      f
    } finally {
      sourceProviderFunction = null
    }
  }
} 
Example 5
Source File: BlockingSource.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
Example 6
Source File: GroupedIteratorSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 7
Source File: resources.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 8
Source File: SimplifyConditionalSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation()).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation()).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: Nil, None))
  }

  test("simplify CaseWhen, prune branches following a definite true") {
    assertEquivalent(
      CaseWhen(normalBranch :: unreachableBranch ::
        unreachableBranch :: nullBranch ::
        trueBranch :: normalBranch ::
        Nil,
        None),
      CaseWhen(normalBranch :: trueBranch :: Nil, None))
  }
} 
Example 9
Source File: RewriteDistinctAggregatesSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 10
Source File: ComplexDataSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.util

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class ComplexDataSuite extends SparkFunSuite {
  def utf8(str: String): UTF8String = UTF8String.fromString(str)

  test("inequality tests for MapData") {
    // test data
    val testMap1 = Map(utf8("key1") -> 1)
    val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
    val testMap3 = Map(utf8("key1") -> 1)
    val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }

  test("GenericInternalRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericRow = genericRow.copy()
    assert(copiedGenericRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedGenericRow.getString(0) == "a")
  }

  test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val mutableRow = new SpecificInternalRow(Seq(StringType))
    mutableRow(0) = unsafeRow.getUTF8String(0)
    val copiedMutableRow = mutableRow.copy()
    assert(copiedMutableRow.getString(0) == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied internal row should not be changed externally.
    assert(copiedMutableRow.getString(0) == "a")
  }

  test("GenericArrayData.copy return a new instance that is independent from the old one") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
    val copiedGenericArray = genericArray.copy()
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied array data should not be changed externally.
    assert(copiedGenericArray.getUTF8String(0).toString == "a")
  }

  test("copy on nested complex type") {
    val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
    val unsafeRow = project.apply(InternalRow(utf8("a")))

    val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
    val copied = arrayOfRow.copy()
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
    project.apply(InternalRow(UTF8String.fromString("b")))
    // The copied data should not be changed externally.
    assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
  }
} 
Example 11
Source File: RandomSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    checkDoubleEvaluation(
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
    checkDoubleEvaluation(
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
  }

  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
  }
} 
Example 12
Source File: ObjectExpressionsSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}


class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
  }

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
    checkEvalutionWithUnsafeProjection(
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
    checkEvalutionWithUnsafeProjection(
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
    checkEvalutionWithUnsafeProjection(
      mapEncoder.serializer.head, mapExpected, mapInputRow)
  }
} 
Example 13
Source File: ExpressionEvalHelperSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 14
Source File: MiscFunctionsSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.codec.digest.DigestUtils

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}

class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("md5") {
    checkEvaluation(Md5(Literal("ABC".getBytes)), "902fbdd2b1df0c4f70b4a5d23525e932")
    checkEvaluation(Md5(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "6ac1e56bc78f031059be7be854522c4c")
    checkEvaluation(Md5(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Md5, BinaryType)
  }

  test("sha1") {
    checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")
    checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      "5d211bad8f4ee70e16c7d343a838fc344a1ed961")
    checkEvaluation(Sha1(Literal.create(null, BinaryType)), null)
    checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709")
    checkConsistencyBetweenInterpretedAndCodegen(Sha1, BinaryType)
  }

  test("sha2") {
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC"))
    checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)),
      DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6)))
    // unsupported bit length
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null)
    checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null)
    checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null)
  }

  test("crc32") {
    checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L)
    checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)),
      2180413220L)
    checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
    checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
  }
} 
Example 15
Source File: QueryPlanSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }

} 
Example 16
Source File: LogicalPlanSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)
  }

  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function

    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)
  }

  test("transformUp skips all ready resolved plans wrapped in analysis barrier") {
    invocationCount = 0
    val plan = AnalysisBarrier(Project(Nil, Project(Nil, testRelation)))
    plan transformUp function

    assert(invocationCount === 0)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 0)
  }

  test("transformUp skips partially resolved plans wrapped in analysis barrier") {
    invocationCount = 0
    val plan1 = AnalysisBarrier(Project(Nil, testRelation))
    val plan2 = Project(Nil, plan1)
    plan2 transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan2 transformDown function
    assert(invocationCount === 1)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }
} 
Example 17
Source File: StatsEstimationTestBase.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
  }

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
    super.afterAll()
  }

  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen + 8 + 4
    case _ => colStat.avgLen
  }

  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
}

case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // sizeInBytes is not exercised by these tests, so fall back to a placeholder value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
} 
Example 18
Source File: SubstituteUnresolvedOrdinals.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType


class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
} 
Example 19
Source File: SparkExecuteStatementOperationSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

class SparkExecuteStatementOperationSuite extends SparkFunSuite {
  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
  }

  test("SPARK-20146 Comment should be preserved") {
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
    assert(columns.get(1).getComment() == "")
  }
} 
Example 20
Source File: CustomSchemaTest.scala    From spark-sftp   with Apache License 2.0
package com.springml.spark.sftp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, _}
import org.scalatest.{BeforeAndAfterEach, FunSuite}


class CustomSchemaTest extends FunSuite with BeforeAndAfterEach {
  var ss: SparkSession = _

  val csvTypesMap = Map("ProposalId" -> IntegerType,
    "OpportunityId" -> StringType,
    "Clicks" -> LongType,
    "Impressions" -> LongType
  )

  val jsonTypesMap = Map("name" -> StringType,
    "age" -> IntegerType
  )

  override def beforeEach() {
    ss = SparkSession.builder().master("local").appName("Custom Schema Test").getOrCreate()
  }

  private def validateTypes(field : StructField, typeMap : Map[String, DataType]) = {
    val expectedType = typeMap(field.name)
    assert(expectedType == field.dataType)
  }

  private def columnArray(typeMap : Map[String, DataType]) : Array[StructField] = {
    // Convert the name -> type map into an array of StructFields.
    typeMap.map(x => StructField(x._1, x._2, true)).toArray
  }

  test ("Read CSV with custom schema") {
    val columnStruct = columnArray(csvTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/sample.csv").getPath
    val dsr = DatasetRelation(fileLocation, "csv", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, csvTypesMap))
  }

  test ("Read Json with custom schema") {
    val columnStruct = columnArray(jsonTypesMap)
    val expectedSchema = StructType(columnStruct)

    val fileLocation = getClass.getResource("/people.json").getPath
    val dsr = DatasetRelation(fileLocation, "json", "false", "true", ",", "\"", "\\", "false", null, expectedSchema, ss.sqlContext)
    val rdd = dsr.buildScan()

    assert(dsr.schema.fields.length == columnStruct.length)
    dsr.schema.fields.foreach(s => validateTypes(s, jsonTypesMap))
  }

} 
Example 21
Source File: SQLOperationsSpec.scala    From pravda-ml   with Apache License 2.0
package odkl.analysis.spark.util

import odkl.analysis.spark.TestEnv
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.IntegerType
import org.scalatest.{FlatSpec, Matchers}


class SQLOperationsSpec extends FlatSpec with TestEnv with Matchers with SQLOperations {

  import sqlc.implicits._

  "A CollectAsList" should "aggregate items" in {
    val df = sc.parallelize(Seq(
      "A" -> 10,
      "A" -> 11,
      "A" -> 12,
      "B" -> 20,
      "B" -> 21,
      "C" -> 30
    )).toDF("C1", "C2")
    val res = df.groupBy("C1").agg(collectAsList(IntegerType)(col("C2"))).collect()
    assertResult(3)(res.length)
    val c1 = res.find(_.getString(0) == "A").getOrElse(fail("No row for 'A'")).getAs[Seq[Int]](1)
    c1 should contain theSameElementsAs Seq(10, 11, 12)
    val c2 = res.find(_.getString(0) == "B").getOrElse(fail("No row for 'B'")).getAs[Seq[Int]](1)
    c2 should contain theSameElementsAs Seq(20, 21)
    val c3 = res.find(_.getString(0) == "C").getOrElse(fail("No row for 'C'")).getAs[Seq[Int]](1)
    c3 should contain theSameElementsAs Seq(30)
  }
} 
Example 22
Source File: SemiJoinSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
// Semi-join test suite
class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = ctx.createDataFrame(
    ctx.sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }
  // Test the left semi join
  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 23
Source File: AttributeSetSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }
  // Check by id, not by name
  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
Example 24
Source File: ArrangePostprocessor.scala    From DataQuality   with GNU Lesser General Public License v3.0
package it.agilelab.bigdata.DataQuality.postprocessors

import com.typesafe.config.Config
import it.agilelab.bigdata.DataQuality.checks.CheckResult
import it.agilelab.bigdata.DataQuality.metrics.MetricResult
import it.agilelab.bigdata.DataQuality.sources.HdfsFile
import it.agilelab.bigdata.DataQuality.targets.HdfsTargetConfig
import it.agilelab.bigdata.DataQuality.utils
import it.agilelab.bigdata.DataQuality.utils.DQSettings
import it.agilelab.bigdata.DataQuality.utils.io.{HdfsReader, HdfsWriter}
import org.apache.hadoop.fs.FileSystem
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, NumericType}
import org.apache.spark.sql.{Column, DataFrame, SQLContext}

import scala.collection.JavaConversions._

final class ArrangePostprocessor(config: Config, settings: DQSettings)
    extends BasicPostprocessor(config, settings) {

  private case class ColumnSelector(name: String, tipo: Option[String] = None, format: Option[String] = None, precision: Option[Integer] = None) {
    def toColumn()(implicit df: DataFrame): Column = {

      val dataType: Option[NumericType with Product with Serializable] =
        tipo.getOrElse("").toUpperCase match {
          case "DOUBLE" => Some(DoubleType)
          case "INT"    => Some(IntegerType)
          case "LONG"   => Some(LongType)
          case _        => None
        }

      import org.apache.spark.sql.functions.format_number
      import org.apache.spark.sql.functions.format_string

      (dataType, precision, format) match {
        case (Some(dt), None, None) => df(name).cast(dt)
        case(Some(dt), None, Some(f)) => format_string(f, df(name).cast(dt)).alias(name)
        case (Some(dt), Some(p),None) => format_number(df(name).cast(dt), p).alias(name)
        case (None, Some(p), None) => format_number(df(name), p).alias(name)
        case (None, None, Some(f)) => format_string(f, df(name)).alias(name)
        case _ => df(name)
      }
    }
  }

  private val vs = config.getString("source")
  private val target: HdfsTargetConfig = {
    val conf = config.getConfig("saveTo")
    utils.parseTargetConfig(conf)(settings).get
  }

  private val columns: Seq[ColumnSelector] =
    config.getAnyRefList("columnOrder").map {
      case x: String => ColumnSelector(x)
      case x: java.util.HashMap[_, String] => {
        val (name, v) = x.head.asInstanceOf[String Tuple2 _]

        v match {
          case v: String =>
            ColumnSelector(name, Option(v))
          case v: java.util.HashMap[String, _] => {
            val k = v.head._1
            val f = v.head._2

            f match {
              case f: Integer =>
                ColumnSelector(name, Option(k), None, Option(f))
              case f: String =>
                ColumnSelector(name, Option(k), Option(f))
            }
          }
        }
      }
    }

  override def process(vsRef: Set[HdfsFile],
                       metRes: Seq[MetricResult],
                       chkRes: Seq[CheckResult])(
      implicit fs: FileSystem,
      sqlContext: SQLContext,
      settings: DQSettings): HdfsFile = {

    val reqVS: HdfsFile = vsRef.filter(vr => vr.id == vs).head
    implicit val df: DataFrame = HdfsReader.load(reqVS, settings.ref_date).head

    val arrangeDF = df.select(columns.map(_.toColumn): _*)

    HdfsWriter.saveVirtualSource(arrangeDF, target, settings.refDateString)(
      fs,
      sqlContext.sparkContext)

    new HdfsFile(target)
  }
} 
Example 25
Source File: DataFrameNaFunctionsSpec.scala    From spark-spec   with MIT License
package com.github.mrpowers.spark.spec.sql

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.Row
import org.scalatest._
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._

import com.github.mrpowers.spark.spec.SparkSessionTestWrapper

import com.github.mrpowers.spark.fast.tests.DatasetComparer

class DataFrameNaFunctionsSpec extends FunSpec with SparkSessionTestWrapper with DatasetComparer {

  import spark.implicits._

  describe("#drop") {

    it("drops rows that contains null values") {

      val sourceData = List(
        Row(1, null),
        Row(null, null),
        Row(3, 30),
        Row(10, 20)
      )

      val sourceSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)
      )

      val sourceDF = spark.createDataFrame(
        spark.sparkContext.parallelize(sourceData),
        StructType(sourceSchema)
      )

      val actualDF = sourceDF.na.drop()

      val expectedData = List(
        Row(3, 30),
        Row(10, 20)
      )

      val expectedSchema = List(
        StructField("num1", IntegerType, true),
        StructField("num2", IntegerType, true)
      )

      val expectedDF = spark.createDataFrame(
        spark.sparkContext.parallelize(expectedData),
        StructType(expectedSchema)
      )

      assertSmallDatasetEquality(actualDF, expectedDF)

    }

  }

  describe("#fill") {

    it("Returns a new DataFrame that replaces null or NaN values in numeric columns with value") {

      val sourceDF = spark.createDF(
        List(
          (1, null),
          (null, null),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, true),
          ("num2", IntegerType, true)
        )
      )

      val actualDF = sourceDF.na.fill(77)

      val expectedDF = spark.createDF(
        List(
          (1, 77),
          (77, 77),
          (3, 30),
          (10, 20)
        ), List(
          ("num1", IntegerType, false),
          ("num2", IntegerType, false)
        )
      )

      assertSmallDatasetEquality(actualDF, expectedDF)

    }

  }

  describe("#replace") {
    pending
  }

} 
Example 26
Source File: ReportContentTestFactory.scala    From seahorse-workflow-executor   with Apache License 2.0
package io.deepsense.reportlib.model.factory

import io.deepsense.reportlib.model.{ReportType, ReportContent}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

trait ReportContentTestFactory {

  import ReportContentTestFactory._

  def testReport: ReportContent = ReportContent(
    reportName,
    reportType,
    Seq(TableTestFactory.testEmptyTable),
    Map(ReportContentTestFactory.categoricalDistName ->
      DistributionTestFactory.testCategoricalDistribution(
        ReportContentTestFactory.categoricalDistName),
      ReportContentTestFactory.continuousDistName ->
      DistributionTestFactory.testContinuousDistribution(
        ReportContentTestFactory.continuousDistName)
    )
  )

}

object ReportContentTestFactory extends ReportContentTestFactory {
  val continuousDistName = "continuousDistributionName"
  val categoricalDistName = "categoricalDistributionName"
  val reportName = "TestReportContentName"
  val reportType = ReportType.Empty

  val someReport: ReportContent = ReportContent("empty", ReportType.Empty)
} 
Example 27
Source File: EstimatorModelWrapperIntegSpec.scala    From seahorse-workflow-executor   with Apache License 2.0
package io.deepsense.deeplang.doperables.spark.wrappers.estimators

import io.deepsense.deeplang.DeeplangIntegTestSupport
import io.deepsense.deeplang.doperables.dataframe.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructType, StructField}

class EstimatorModelWrapperIntegSpec extends DeeplangIntegTestSupport {

  import io.deepsense.deeplang.doperables.spark.wrappers.estimators.EstimatorModelWrapperFixtures._

  val inputDF = {
    val rowSeq = Seq(Row(1), Row(2), Row(3))
    val schema = StructType(Seq(StructField("x", IntegerType, nullable = false)))
    createDataFrame(rowSeq, schema)
  }

  val estimatorPredictionParamValue = "estimatorPrediction"

  val expectedSchema = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(estimatorPredictionParamValue, IntegerType, nullable = false)
  ))

  val transformerPredictionParamValue = "modelPrediction"

  val expectedSchemaForTransformerParams = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(transformerPredictionParamValue, IntegerType, nullable = false)
  ))

  "EstimatorWrapper" should {
    "_fit() and transform() + transformSchema() with parameters inherited" in {

      val transformer = createEstimatorAndFit()

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchema

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchema)
    }

    "_fit() and transform() + transformSchema() with parameters overwritten" in {

      val transformer = createEstimatorAndFit().setPredictionColumn(transformerPredictionParamValue)

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchemaForTransformerParams

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchemaForTransformerParams)
    }

    "_fit_infer().transformSchema() with parameters inherited" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
        .setPredictionColumn(estimatorPredictionParamValue)

      estimatorWrapper._fit_infer(inputDF.schema)
        ._transformSchema(inputDF.sparkDataFrame.schema) shouldBe Some(expectedSchema)
    }

    "_fit_infer().transformSchema() with parameters overwritten" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
        .setPredictionColumn(estimatorPredictionParamValue)
      val transformer =
        estimatorWrapper._fit_infer(inputDF.schema).asInstanceOf[SimpleSparkModelWrapper]
      val transformerWithParams = transformer.setPredictionColumn(transformerPredictionParamValue)

      val outputSchema = transformerWithParams._transformSchema(inputDF.sparkDataFrame.schema)
      outputSchema shouldBe Some(expectedSchemaForTransformerParams)
    }
  }

  private def createEstimatorAndFit(): SimpleSparkModelWrapper = {

    val estimatorWrapper = new SimpleSparkEstimatorWrapper()
      .setPredictionColumn(estimatorPredictionParamValue)

    val transformer =
      estimatorWrapper._fit(executionContext, inputDF).asInstanceOf[SimpleSparkModelWrapper]
    transformer.getPredictionColumn() shouldBe estimatorPredictionParamValue

    transformer
  }
} 
Example 29
Source File: EstimatorModelWrapperFixtures.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.deeplang.doperables.spark.wrappers.estimators

import scala.language.reflectiveCalls

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml
import org.apache.spark.ml.param.{ParamMap, Param => SparkParam}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import io.deepsense.deeplang.ExecutionContext
import io.deepsense.deeplang.doperables.report.Report
import io.deepsense.deeplang.doperables.serialization.SerializableSparkModel
import io.deepsense.deeplang.doperables.{SparkEstimatorWrapper, SparkModelWrapper}
import io.deepsense.deeplang.params.wrappers.spark.SingleColumnCreatorParamWrapper
import io.deepsense.deeplang.params.{Param, Params}
import io.deepsense.sparkutils.ML

object EstimatorModelWrapperFixtures {

  class SimpleSparkModel private[EstimatorModelWrapperFixtures]()
    extends ML.Model[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "modelId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    def setPredictionCol(value: String): this.type = set(predictionCol, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)

    override def transformDF(dataset: DataFrame): DataFrame = {
      dataset.selectExpr("*", "1 as " + $(predictionCol))
    }

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = ???
  }

  class SimpleSparkEstimator extends ML.Estimator[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "estimatorId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    override def fitDF(dataset: DataFrame): SimpleSparkModel =
      new SimpleSparkModel().setPredictionCol($(predictionCol))

    override def copy(extra: ParamMap): ML.Estimator[SimpleSparkModel] = defaultCopy(extra)

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = {
      schema.add(StructField($(predictionCol), IntegerType, nullable = false))
    }
  }

  trait HasPredictionColumn extends Params {
    val predictionColumn = new SingleColumnCreatorParamWrapper[
        ml.param.Params { val predictionCol: SparkParam[String] }](
      "prediction column",
      None,
      _.predictionCol)
    setDefault(predictionColumn, "abcdefg")

    def getPredictionColumn(): String = $(predictionColumn)
    def setPredictionColumn(value: String): this.type = set(predictionColumn, value)
  }

  class SimpleSparkModelWrapper
    extends SparkModelWrapper[SimpleSparkModel, SimpleSparkEstimator]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???

    override protected def loadModel(
      ctx: ExecutionContext,
      path: String): SerializableSparkModel[SimpleSparkModel] = ???
  }

  class SimpleSparkEstimatorWrapper
    extends SparkEstimatorWrapper[SimpleSparkModel, SimpleSparkEstimator, SimpleSparkModelWrapper]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report: Report = ???
  }
} 
Example 30
Source File: DataFrameSplitterIntegSpec.scala    From seahorse-workflow-executor   with Apache License 2.0 5 votes vote down vote up
package io.deepsense.deeplang.doperations

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.scalatest.Matchers
import org.scalatest.prop.GeneratorDrivenPropertyChecks

import io.deepsense.deeplang._
import io.deepsense.deeplang.doperables.dataframe.DataFrame

class DataFrameSplitterIntegSpec
  extends DeeplangIntegTestSupport
  with GeneratorDrivenPropertyChecks
  with Matchers {

  "SplitDataFrame" should {
    "split randomly one df into two df in given range" in {

      val input = Range(1, 100)

      val parameterPairs = List(
        (0.0, 0),
        (0.3, 1),
        (0.5, 2),
        (0.8, 3),
        (1.0, 4))

      for((splitRatio, seed) <- parameterPairs) {
        val rdd = createData(input)
        val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
        val (df1, df2) = executeOperation(
          executionContext,
          new Split()
            .setSplitMode(
              SplitModeChoice.Random()
                .setSplitRatio(splitRatio)
                .setSeed(seed / 2)))(df)
        validateSplitProperties(df, df1, df2)
      }
    }

    "split conditionally one df into two df in given range" in {

      val input = Range(1, 100)

      val condition = "value > 20"
      val predicate: Int => Boolean = _ > 20

      val (expectedDF1, expectedDF2) =
        (input.filter(predicate), input.filter(!predicate(_)))

      val rdd = createData(input)
      val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
      val (df1, df2) = executeOperation(
        executionContext,
        new Split()
          .setSplitMode(
            SplitModeChoice.Conditional()
              .setCondition(condition)))(df)
      df1.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF1
      df2.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF2
      validateSplitProperties(df, df1, df2)
    }
  }

  private def createSchema: StructType = {
    StructType(List(
      StructField("value", IntegerType, nullable = false)
    ))
  }

  private def createData(data: Seq[Int]): RDD[Row] = {
    sparkContext.parallelize(data.map(Row(_)))
  }

  private def executeOperation(context: ExecutionContext, operation: DOperation)
                              (dataFrame: DataFrame): (DataFrame, DataFrame) = {
    val operationResult = operation.executeUntyped(Vector[DOperable](dataFrame))(context)
    val df1 = operationResult.head.asInstanceOf[DataFrame]
    val df2 = operationResult.last.asInstanceOf[DataFrame]
    (df1, df2)
  }

  def validateSplitProperties(inputDF: DataFrame, outputDF1: DataFrame, outputDF2: DataFrame)
      : Unit = {
    val dfCount = inputDF.sparkDataFrame.count()
    val df1Count = outputDF1.sparkDataFrame.count()
    val df2Count = outputDF2.sparkDataFrame.count()
    val rowsDf = inputDF.sparkDataFrame.collectAsList().asScala
    val rowsDf1 = outputDF1.sparkDataFrame.collectAsList().asScala
    val rowsDf2 = outputDF2.sparkDataFrame.collectAsList().asScala
    val intersect = rowsDf1.intersect(rowsDf2)
    intersect.size shouldBe 0
    (df1Count + df2Count) shouldBe dfCount
    rowsDf.toSet shouldBe rowsDf1.toSet.union(rowsDf2.toSet)
  }
} 
Example 31
Source File: QuadTreeIndexedRelation.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.index

import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.apache.spark.storage.StorageLevel

import org.apache.spark.sql.simba.partitioner.QuadTreePartitioner
import org.apache.spark.sql.simba.spatial.Point


private[simba] case class QuadTreeIndexedRelation(
    output: Seq[Attribute],
    child: SparkPlan,
    table_name: Option[String],
    column_keys: List[Attribute],
    index_name: String)(
    var _indexedRDD: IndexedRDD = null,
    var global_index: QuadTree = null)
  extends IndexedRelation with MultiInstanceRelation {
  private def checkKeys: Boolean = {
    for (i <- column_keys.indices)
      if (!(column_keys(i).dataType.isInstanceOf[DoubleType] ||
        column_keys(i).dataType.isInstanceOf[IntegerType])) {
        return false
      }
    true
  }
  require(checkKeys)

  if (_indexedRDD == null) {
    buildIndex()
  }

  private[simba] def buildIndex(): Unit = {
    val numShufflePartitions = simbaSession.sessionState.simbaConf.indexPartitions
    val sampleRate = simbaSession.sessionState.simbaConf.sampleRate
    val transferThreshold = simbaSession.sessionState.simbaConf.transferThreshold

    val dataRDD = child.execute().map(row => {
      val now = column_keys.map(x =>
        BindReferences.bindReference(x, child.output).eval(row).asInstanceOf[Number].doubleValue()
      ).toArray
      (new Point(now), row)
    })

    val dimension = column_keys.length
    val (partitionedRDD, _, global_qtree) = QuadTreePartitioner(dataRDD, dimension,
      numShufflePartitions, sampleRate, transferThreshold)

    val indexed = partitionedRDD.mapPartitions { iter =>
      val data = iter.toArray
      val index: QuadTree =
        if (data.length > 0) QuadTree(data.map(_._1).zipWithIndex)
        else null
      Array(IPartition(data.map(_._2), index)).iterator
    }.persist(StorageLevel.MEMORY_AND_DISK_SER)

    indexed.setName(table_name.map(name => s"$name $index_name").getOrElse(child.toString))
    _indexedRDD = indexed
    global_index = global_qtree
  }

  override def newInstance(): IndexedRelation = {
    new QuadTreeIndexedRelation(output.map(_.newInstance()), child, table_name,
      column_keys, index_name)(_indexedRDD)
      .asInstanceOf[this.type]
  }

  override def withOutput(new_output: Seq[Attribute]): IndexedRelation = {
    new QuadTreeIndexedRelation(new_output, child, table_name,
      column_keys, index_name)(_indexedRDD, global_index)
  }
} 
Example 32
Source File: RDDFixtures.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector

import java.util.Date

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DateType, IntegerType, StringType, StructField, StructType }

import com.actian.spark_vector.test.util.StructTypeUtil.createSchema


trait RDDFixtures {
  // poor man's fixture, for other approaches see:
  // http://www.scalatest.org/user_guide/sharing_fixtures
  def createRecordRdd(sc: SparkContext): (RDD[Seq[Any]], StructType) = {
    val input = Seq(
      Seq(42, "a"),
      Seq(43, "b"))
    val inputRdd = sc.parallelize(input, 2)
    val inputSchema = createSchema("id" -> IntegerType, "name" -> StringType)
    (inputRdd, inputSchema)
  }

  def createRowRDD(sc: SparkContext): (RDD[Seq[Any]], StructType) = {
    val input = Seq(
      Seq[Any](42, "a", new Date(), new Date()),
      Seq[Any](43, "b", new Date(), new Date()))
    val inputRdd = sc.parallelize(input, 2)
    val inputSchema = createSchema("id" -> IntegerType, "name" -> StringType, "date" -> DateType)
    (inputRdd, inputSchema)
  }

  def wideRDD(sc: SparkContext, columnCount: Int, rowCount: Int = 2): (RDD[Row], StructType) = {
    val data: Row = Row.fromSeq(1 to columnCount)

    val fields = for (i <- 1 to columnCount) yield {
      StructField("field_" + i, IntegerType, nullable = true)
    }

    val inputSchema = StructType(fields.toSeq)

    val input = for (i <- 1 to rowCount) yield {
      data
    }

    val inputRDD = sc.parallelize(input, 2)
    (inputRDD, inputSchema)
  }
} 
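A hypothetical usage sketch for the wideRDD fixture above, assuming the fixture trait is on the classpath; the object name, session setup, and argument values are assumptions made only for illustration.

import org.apache.spark.sql.SparkSession

object WideRddSketch extends com.actian.spark_vector.RDDFixtures {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("wide-rdd-sketch").getOrCreate()

    // 5 IntegerType columns named field_1 .. field_5, and 3 identical rows
    val (rdd, schema) = wideRDD(spark.sparkContext, columnCount = 5, rowCount = 3)
    val df = spark.createDataFrame(rdd, schema)
    df.printSchema()
    df.show()
  }
}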
Example 33
Source File: FunctionTest.scala    From sope   with Apache License 2.0 5 votes vote down vote up
package com.sope

import com.sope.model.{Class, Person, Student}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import com.sope.spark.sql._
import com.sope.TestContext.getSQlContext
import org.apache.spark.sql.types.{StringType, IntegerType}
import org.scalatest.{FlatSpec, Matchers}


class FunctionTest extends FlatSpec with Matchers {

  private val sqlContext = getSQlContext

  import sqlContext.implicits._

  private val testSData = Seq(
    Person("Sherlock", "Holmes", "baker street", "[email protected]", "999999"),
    Person("John", "Watson", "east street", "[email protected]", "55555")
  ).toDF

  private val studentDF = Seq(
    Student("A", "B", 1, 10),
    Student("B", "C", 2, 10),
    Student("C", "E", 4, 9),
    Student("E", "F", 5, 9),
    Student("F", "G", 6, 10),
    Student("G", "H", 7, 10),
    Student("H", "I", 9, 8),
    Student("H", "I", 9, 7)
  ).toDF

  private val classDF = Seq(
    Class(1, 10, "Tenth"),
    Class(2, 9, "Ninth"),
    Class(3, 8, "Eighth")
  ).toDF

  "Dataframe Function transformations" should "generate the transformations correctly" in {
    val nameUpperFunc = (df: DataFrame) => df.withColumn("first_name", upper(col("first_name")))
    val nameConcatFunc = (df: DataFrame) => df.withColumn("name", concat(col("first_name"), col("last_name")))
    val addressUpperFunc = (df: DataFrame) => df.withColumn("address", upper(col("address")))
    val transformed = testSData.applyDFTransformations(Seq(nameUpperFunc, nameConcatFunc, addressUpperFunc))
    transformed.show(false)
    transformed.schema.fields.map(_.name) should contain("name")
  }

  "Group by as list Function Transformation" should  "generate the transformations correctly" in {
    val grouped = studentDF.groupByAsList(Seq("cls"))
        .withColumn("grouped_data", explode($"grouped_data"))
        .unstruct("grouped_data", keepStructColumn = false)
    grouped.show(false)
    grouped.filter("cls = 10").head.getAs[Long]("grouped_count") should be(4)
  }

  "Cast Transformation" should  "generate the transformations correctly" in {
    val casted = studentDF.castColumns(IntegerType, StringType)
    casted.dtypes.count(_._2 == "StringType") should be(4)
  }


  "Update Keys Transformation" should  "generate the transformations correctly" in {
    val updatedWithKey = studentDF
      .updateKeys(Seq("cls"), classDF.renameColumns(Map("cls" -> "class")), "class", "key")
      .dropColumns(Seq("last_name", "roll_no"))
    updatedWithKey.show(false)
    updatedWithKey.filter("first_name = 'A'").head.getAs[Long]("cls_key") should be(1)
  }

} 
Example 34
Source File: SqlExtensionProviderSuite.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{BinaryExpression, Expression, Literal, UnaryExpression}
import org.apache.spark.sql.types.{DataType, IntegerType}

import io.projectglow.GlowSuite

class SqlExtensionProviderSuite extends GlowSuite {
  override def beforeAll(): Unit = {
    super.beforeAll()
    SqlExtensionProvider.registerFunctions(
      spark.sessionState.conf,
      spark.sessionState.functionRegistry,
      "test-functions.yml")
  }

  private lazy val sess = spark
  test("one arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("one_arg_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test()").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("one_arg_test(id, id)").collect()
    }
  }

  test("two arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("two_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id)").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("two_arg_test(id, id, id)").collect()
    }
  }

  test("var args function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("var_args_test(id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id, id, id, id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("var_args_test(id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("var_args_test()").collect()
    }
  }

  test("can call optional arg function") {
    import sess.implicits._
    assert(spark.range(1).selectExpr("optional_arg_test(id)").as[Int].head() == 1)
    assert(spark.range(1).selectExpr("optional_arg_test(id, id)").as[Int].head() == 1)

    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test()").collect()
    }

    intercept[AnalysisException] {
      spark.range(1).selectExpr("optional_arg_test(id, id, id)").collect()
    }
  }
}

trait TestExpr extends Expression with CodegenFallback {
  override def dataType: DataType = IntegerType

  override def nullable: Boolean = true

  override def eval(input: InternalRow): Any = 1
}

case class OneArgExpr(child: Expression) extends UnaryExpression with TestExpr
case class TwoArgExpr(left: Expression, right: Expression) extends BinaryExpression with TestExpr
case class VarArgsExpr(arg: Expression, varArgs: Seq[Expression]) extends TestExpr {
  override def children: Seq[Expression] = arg +: varArgs
}
case class OptionalArgExpr(required: Expression, optional: Expression) extends TestExpr {
  def this(required: Expression) = this(required, Literal(1))
  override def children: Seq[Expression] = Seq(required, optional)
} 
Example 35
Source File: AnnotationUtils.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.vcf

import org.apache.spark.sql.types.{ArrayType, DataType, IntegerType, StringType, StructField, StructType}

// Unified VCF annotation representation, used by SnpEff and VEP
object AnnotationUtils {

  // Delimiter between annotation fields
  val annotationDelimiter = "|"
  val annotationDelimiterRegex = "\\|"

  // Delimiter for struct subfields
  val structDelimiter = "/"
  val structDelimiterRegex = "\\/"

  // Delimiter for array subfields
  val arrayDelimiter = "&"

  // Struct subfield schemas
  private val rankTotalStruct = StructType(
    Seq(StructField("rank", IntegerType), StructField("total", IntegerType)))
  private val posLengthStruct = StructType(
    Seq(StructField("pos", IntegerType), StructField("length", IntegerType)))
  private val referenceVariantStruct = StructType(
    Seq(StructField("reference", StringType), StructField("variant", StringType)))

  // Special schemas for SnpEff subfields
  private val snpEffFieldsToSchema: Map[String, DataType] = Map(
    "Annotation" -> ArrayType(StringType),
    "Rank" -> rankTotalStruct,
    "cDNA_pos/cDNA_length" -> posLengthStruct,
    "CDS_pos/CDS_length" -> posLengthStruct,
    "AA_pos/AA_length" -> posLengthStruct,
    "Distance" -> IntegerType
  )

  // Special schemas for VEP subfields
  private val vepFieldsToSchema: Map[String, DataType] = Map(
    "Consequence" -> ArrayType(StringType),
    "EXON" -> rankTotalStruct,
    "INTRON" -> rankTotalStruct,
    "cDNA_position" -> IntegerType,
    "CDS_position" -> IntegerType,
    "Protein_position" -> IntegerType,
    "Amino_acids" -> referenceVariantStruct,
    "Codons" -> referenceVariantStruct,
    "Existing_variation" -> ArrayType(StringType),
    "DISTANCE" -> IntegerType,
    "STRAND" -> IntegerType,
    "FLAGS" -> ArrayType(StringType)
  )

  // Special schemas for LOFTEE (as VEP plugin) subfields
  private val lofteeFieldsToSchema: Map[String, DataType] = Map(
    "LoF_filter" -> ArrayType(StringType),
    "LoF_flags" -> ArrayType(StringType),
    "LoF_info" -> ArrayType(StringType)
  )

  // Default string schema for annotation subfield
  val allFieldsToSchema: Map[String, DataType] =
    (snpEffFieldsToSchema ++ vepFieldsToSchema ++ lofteeFieldsToSchema).withDefaultValue(StringType)
} 
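Because allFieldsToSchema is built with withDefaultValue(StringType), any annotation subfield not listed above resolves to StringType. A minimal illustration (the second field name is made up):

import io.projectglow.vcf.AnnotationUtils
import org.apache.spark.sql.types.{IntegerType, StringType}

// Explicitly mapped VEP subfield
assert(AnnotationUtils.allFieldsToSchema("DISTANCE") == IntegerType)

// Hypothetical, unmapped subfield name: falls back to the StringType default
assert(AnnotationUtils.allFieldsToSchema("SOME_UNLISTED_FIELD") == StringType)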
Example 36
Source File: KCore.scala    From sona   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.sona.graph.kcore
import com.tencent.angel.sona.context.PSContext
import org.apache.spark.SparkContext
import com.tencent.angel.sona.graph.params._
import com.tencent.angel.sona.ml.Transformer
import com.tencent.angel.sona.ml.param.ParamMap
import com.tencent.angel.sona.ml.util.Identifiable
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.storage.StorageLevel

class KCore(override val uid: String) extends Transformer
  with HasSrcNodeIdCol with HasDstNodeIdCol with HasOutputNodeIdCol with HasOutputCoreIdCol
  with HasStorageLevel with HasPartitionNum with HasPSPartitionNum with HasUseBalancePartition {

  def this() = this(Identifiable.randomUID("KCore"))

  override def transform(dataset: Dataset[_]): DataFrame = {
    val edges = dataset.select($(srcNodeIdCol), $(dstNodeIdCol)).rdd
      .map(row => (row.getLong(0), row.getLong(1)))
      .filter(e => e._1 != e._2)

    edges.persist(StorageLevel.DISK_ONLY)

    val maxId = edges.map(e => math.max(e._1, e._2)).max() + 1
    val minId = edges.map(e => math.min(e._1, e._2)).min()
    val nodes = edges.flatMap(e => Iterator(e._1, e._2))
    val numEdges = edges.count()

    println(s"minId=$minId maxId=$maxId numEdges=$numEdges level=${$(storageLevel)}")

    // Start PS and init the model
    println("start to run ps")
    PSContext.getOrCreate(SparkContext.getOrCreate())

    val model = KCorePSModel.fromMinMax(minId, maxId, nodes, $(psPartitionNum), $(useBalancePartition))
    var graph = edges.flatMap(e => Iterator((e._1, e._2), (e._2, e._1)))
      .groupByKey($(partitionNum))
      .mapPartitionsWithIndex((index, edgeIter) =>
        Iterator(KCoreGraphPartition.apply(index, edgeIter)))

    graph.persist($(storageLevel))
    graph.foreachPartition(_ => Unit)
    graph.foreach(_.initMsgs(model))

    var curIteration = 0
    var numMsgs = model.numMsgs()
    var prev = graph
    println(s"numMsgs=$numMsgs")

    do {
      curIteration += 1
      graph = prev.map(_.process(model, numMsgs, curIteration == 1))
      graph.persist($(storageLevel))
      graph.count()
      prev.unpersist(true)
      prev = graph
      model.resetMsgs()
      numMsgs = model.numMsgs()
      println(s"curIteration=$curIteration numMsgs=$numMsgs")
    } while (numMsgs > 0)

    val retRDD = graph.map(_.save()).flatMap{case (nodes,cores) => nodes.zip(cores)}
      .map(r => Row.fromSeq(Seq[Any](r._1, r._2)))

    dataset.sparkSession.createDataFrame(retRDD, transformSchema(dataset.schema))
  }

  override def transformSchema(schema: StructType): StructType = {
    StructType(Seq(
      StructField(s"${$(outputNodeIdCol)}", LongType, nullable = false),
      StructField(s"${$(outputCoreIdCol)}", IntegerType, nullable = false)
    ))
  }

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

} 
Example 37
Source File: TestData.scala    From drunken-data-quality   with Apache License 2.0 5 votes vote down vote up
package de.frosner.ddq.testutils

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object TestData {

  def makeIntegerDf(spark: SparkSession, numbers: Seq[Int]): DataFrame =
    spark.createDataFrame(
      spark.sparkContext.makeRDD(numbers.map(Row(_))),
      StructType(List(StructField("column", IntegerType, nullable = false)))
    )

  def makeNullableStringDf(spark: SparkSession, strings: Seq[String]): DataFrame =
    spark.createDataFrame(spark.sparkContext.makeRDD(strings.map(Row(_))), StructType(List(StructField("column", StringType, nullable = true))))

  def makeIntegersDf(spark: SparkSession, row1: Seq[Int], rowN: Seq[Int]*): DataFrame = {
    val rows = row1 :: rowN.toList
    val numCols = row1.size
    val rdd = spark.sparkContext.makeRDD(rows.map(Row(_:_*)))
    val schema = StructType((1 to numCols).map(idx => StructField("column" + idx, IntegerType, nullable = false)))
    spark.createDataFrame(rdd, schema)
  }

} 
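A hypothetical usage sketch for the TestData helpers above; the session setup and the input values are assumptions made for illustration.

import de.frosner.ddq.testutils.TestData
import org.apache.spark.sql.SparkSession

object TestDataSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("testdata-sketch").getOrCreate()

    // Single non-nullable IntegerType column named "column"
    val integers = TestData.makeIntegerDf(spark, Seq(1, 2, 3))
    integers.printSchema()

    // Columns "column1" and "column2", both IntegerType, one row per Seq argument
    val pairs = TestData.makeIntegersDf(spark, Seq(1, 2), Seq(3, 4))
    pairs.show()
  }
}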
Example 38
Source File: ReportContentTestFactory.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.reportlib.model.factory

import ai.deepsense.reportlib.model.{ReportType, ReportContent}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

trait ReportContentTestFactory {

  import ReportContentTestFactory._

  def testReport: ReportContent = ReportContent(
    reportName,
    reportType,
    Seq(TableTestFactory.testEmptyTable),
    Map(ReportContentTestFactory.categoricalDistName ->
      DistributionTestFactory.testCategoricalDistribution(
        ReportContentTestFactory.categoricalDistName),
      ReportContentTestFactory.continuousDistName ->
      DistributionTestFactory.testContinuousDistribution(
        ReportContentTestFactory.continuousDistName)
    )
  )

}

object ReportContentTestFactory extends ReportContentTestFactory {
  val continuousDistName = "continuousDistributionName"
  val categoricalDistName = "categoricalDistributionName"
  val reportName = "TestReportContentName"
  val reportType = ReportType.Empty

  val someReport: ReportContent = ReportContent("empty", ReportType.Empty)
} 
Example 39
Source File: IdentityTransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

import scalaz.syntax.either._
import scalaz.{-\/, \/-}

class IdentityTransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)
  ))

  val input = Seq(
      Row("cnn.com", 1),
      Row("bbc.com", 2),
      Row("hbo.com", 1),
      Row("mashable.com", 3)
  )

  val isActive = true
  val withAllOther = true

  val adSite = ModelFeature(isActive, "Ad", "ad_site", "adv_site", Identity)
  val adId = ModelFeature(isActive, "Ad", "ad_id", "adv_id", Identity)

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)
  val transformer = new IdentityTransformer(Transformer.extractFeatures(df, Seq(adSite, adId)) match {
    case -\/(err) => sys.error(s"Can't extract features: $err")
    case \/-(suc) => suc
  })

  "Identity Transformer" should "support integer typed model feature" in {
    val valid = transformer.validate(adId)
    assert(valid == TypedModelFeature(adId, IntegerType).right)
  }

  it should "fail if feature column doesn't exist" in {
    val failed = transformer.validate(adSite.copy(feature = "adv_site"))
    assert(failed == FeatureTransformationError.FeatureColumnNotFound("adv_site").left)
  }

  it should "fail if column type is not supported" in {
    val failed = transformer.validate(adSite)
    assert(failed == FeatureTransformationError.UnsupportedTransformDataType("ad_site", StringType, Identity).left)
  }

} 
Example 40
Source File: TransformerSpec.scala    From modelmatrix   with Apache License 2.0 5 votes vote down vote up
package com.collective.modelmatrix.transform

import com.collective.modelmatrix.{ModelFeature, ModelMatrix, TestSparkContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FlatSpec

class TransformerSpec extends FlatSpec with TestSparkContext {

  val sqlContext = ModelMatrix.sqlContext(sc)

  val schema = StructType(Seq(
    StructField("adv_site", StringType),
    StructField("adv_id", IntegerType)
  ))

  val input = Seq(
    Row("cnn.com", 1),
    Row("bbc.com", 2),
    Row("hbo.com", 1),
    Row("mashable.com", 3)
  )

  val isActive = true
  val withAllOther = true

  // Can't call 'day_of_week' with String
  val badFunctionType = ModelFeature(isActive, "advertisement", "f1", "day_of_week(adv_site, 'UTC')", Top(95.0, allOther = false))

  // Not enough parameters for 'concat'
  val wrongParametersCount = ModelFeature(isActive, "advertisement", "f2", "concat(adv_site)", Top(95.0, allOther = false))

  val df = sqlContext.createDataFrame(sc.parallelize(input), schema)

  "Transformer" should "report failed feature extraction" in {
    val features = Transformer.extractFeatures(df, Seq(badFunctionType, wrongParametersCount))
    assert(features.isLeft)
    val errors = features.fold(identity, _ => sys.error("Should not be here"))

    assert(errors.length == 2)
    assert(errors(0).feature == badFunctionType)
    assert(errors(1).feature == wrongParametersCount)
  }

} 
Example 41
Source File: PileupTestBase.scala    From bdg-sequila   with Apache License 2.0 5 votes vote down vote up
package org.biodatageeks.sequila.tests.pileup

import com.holdenkarau.spark.testing.{DataFrameSuiteBase, SharedSparkContext}
import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{IntegerType, ShortType, StringType, StructField, StructType}
import org.scalatest.{BeforeAndAfter, FunSuite}

class PileupTestBase extends FunSuite
  with DataFrameSuiteBase
  with BeforeAndAfter
  with SharedSparkContext{

  val sampleId = "NA12878.multichrom.md"
  val samResPath: String = getClass.getResource("/multichrom/mdbam/samtools.pileup").getPath
  val referencePath: String = getClass.getResource("/reference/Homo_sapiens_assembly18_chr1_chrM.small.fasta").getPath
  val bamPath: String = getClass.getResource(s"/multichrom/mdbam/${sampleId}.bam").getPath
  val cramPath : String = getClass.getResource(s"/multichrom/mdcram/${sampleId}.cram").getPath
  val tableName = "reads_bam"
  val tableNameCRAM = "reads_cram"

  val schema: StructType = StructType(
    List(
      StructField("contig", StringType, nullable = true),
      StructField("position", IntegerType, nullable = true),
      StructField("reference", StringType, nullable = true),
      StructField("coverage", ShortType, nullable = true),
      StructField("pileup", StringType, nullable = true),
      StructField("quality", StringType, nullable = true)
    )
  )
  before {
    System.setProperty("spark.kryo.registrator", "org.biodatageeks.sequila.pileup.serializers.CustomKryoRegistrator")
    spark
      .conf.set("spark.sql.shuffle.partitions",1) //FIXME: In order to get orderBy in Samtools tests working - related to exchange partitions stage
    spark.sql(s"DROP TABLE IF EXISTS $tableName")
    spark.sql(
      s"""
         |CREATE TABLE $tableName
         |USING org.biodatageeks.sequila.datasources.BAM.BAMDataSource
         |OPTIONS(path "$bamPath")
         |
      """.stripMargin)

    spark.sql(s"DROP TABLE IF EXISTS $tableNameCRAM")
    spark.sql(
      s"""
         |CREATE TABLE $tableNameCRAM
         |USING org.biodatageeks.sequila.datasources.BAM.CRAMDataSource
         |OPTIONS(path "$cramPath", refPath "$referencePath" )
         |
      """.stripMargin)

    val mapToString = (map: Map[Byte, Short]) => {
      if (map == null)
        "null"
      else
        map.map({
          case (k, v) => k.toChar -> v}).mkString.replace(" -> ", ":")
    }

    val byteToString = ((byte: Byte) => byte.toString)

    spark.udf.register("mapToString", mapToString)
    spark.udf.register("byteToString", byteToString)
  }

} 
Example 42
Source File: CloudPartitionTest.scala    From cloud-integration   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.sources

import org.apache.hadoop.fs.Path

import org.apache.spark.sql._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


abstract class CloudPartitionTest extends AbstractCloudRelationTest {

  import testImplicits._

  ctest(
    "save-findClass-partitioned-part-columns-in-data",
    "Save sets of files in explicitly set up partition tree; read") {
    withTempPathDir("part-columns", None) { path =>
      for (p1 <- 1 to 2; p2 <- Seq("foo", "bar")) {
        val partitionDir = new Path(path, s"p1=$p1/p2=$p2")
        val df = sparkContext
          .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1))
          .toDF("a", "b", "p1")

        df.write
          .format(dataSourceName)
          .mode(SaveMode.ErrorIfExists)
          .save(partitionDir.toString)
        // each of these directories has its own success file; there is
        // none at the root
        resolveSuccessFile(partitionDir, true)
      }

      val dataSchemaWithPartition =
        StructType(
          dataSchema.fields :+ StructField("p1", IntegerType, nullable = true))

      checkQueries(
        spark.read.options(Map(
          "path" -> path.toString,
          "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName)
          .load())
    }
  }
} 
Example 43
Source File: similarityFunctions.scala    From spark-stringmetric   with MIT License 5 votes vote down vote up
package com.github.mrpowers.spark.stringmetric.expressions

import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions
import org.apache.commons.text.similarity.CosineDistance
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{
  CodegenContext,
  ExprCode
}
import org.apache.spark.sql.types.{ DataType, IntegerType, StringType }


trait UTF8StringFunctionsHelper {
  val stringFuncs: String = "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions"
}

trait StringString2IntegerExpression
extends ImplicitCastInputTypes
with NullIntolerant
with UTF8StringFunctionsHelper { self: BinaryExpression =>
  override def dataType: DataType = IntegerType
  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

  protected override def nullSafeEval(left: Any, right: Any): Any = -1
}

case class HammingDistance(left: Expression, right: Expression)
extends BinaryExpression with StringString2IntegerExpression {
  override def prettyName: String = "hamming"

  override def nullSafeEval(leftVal: Any, rightVal: Any): Any = {
    val leftStr = leftVal.asInstanceOf[UTF8String]
    val rightStr = rightVal.asInstanceOf[UTF8String]
    UTF8StringFunctions.hammingDistance(leftStr, rightStr)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)")
  }
} 
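The expression above can be surfaced in the DataFrame API by wrapping it in a Column. The helper below is only a sketch of that wiring, not necessarily the entry point spark-stringmetric itself exposes:

import com.github.mrpowers.spark.stringmetric.expressions.HammingDistance
import org.apache.spark.sql.Column

object HammingSketch {
  // Wraps the catalyst expression so it can be used in select/withColumn
  def hamming(left: Column, right: Column): Column =
    new Column(HammingDistance(left.expr, right.expr))
}

// Usage, assuming a DataFrame df with string columns "w1" and "w2"
// (col comes from org.apache.spark.sql.functions):
//   df.withColumn("hamming_distance", HammingSketch.hamming(col("w1"), col("w2")))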
Example 44
Source File: GroupedIteratorSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema)

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema)
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 45
Source File: ExpandSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType

class ExpandSuite extends SparkPlanTest with SharedSQLContext {
  import testImplicits.localSeqToDataFrameHolder

  private def testExpand(f: SparkPlan => SparkPlan): Unit = {
    val input = (1 to 1000).map(Tuple1.apply)
    val projections = Seq.tabulate(2) { i =>
      Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil
    }
    val attributes = projections.head.map(_.toAttribute)
    checkAnswer(
      input.toDF(),
      plan => Expand(projections, attributes, f(plan)),
      input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j)))
    )
  }

  test("inheriting child row type") {
    val exprs = AttributeReference("a", IntegerType, false)() :: Nil
    val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty)))
    assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.")
  }

  test("expanding UnsafeRows") {
    testExpand(ConvertToUnsafe)
  }

  test("expanding SafeRows") {
    testExpand(identity)
  }
} 
Example 46
Source File: SemiJoinSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.joins

import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {

  private lazy val left = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(1, 2.0),
      Row(1, 2.0),
      Row(2, 1.0),
      Row(2, 1.0),
      Row(3, 3.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("a", IntegerType).add("b", DoubleType))

  private lazy val right = sqlContext.createDataFrame(
    sparkContext.parallelize(Seq(
      Row(2, 3.0),
      Row(2, 3.0),
      Row(3, 2.0),
      Row(4, 1.0),
      Row(null, null),
      Row(null, 5.0),
      Row(6, null)
    )), new StructType().add("c", IntegerType).add("d", DoubleType))

  private lazy val condition = {
    And((left.col("a") === right.col("c")).expr,
      LessThan(left.col("b").expr, right.col("d").expr))
  }

  // Note: the input dataframes and expression must be evaluated lazily because
  // the SQLContext should be used only within a test to keep SQL tests stable
  private def testLeftSemiJoin(
      testName: String,
      leftRows: => DataFrame,
      rightRows: => DataFrame,
      condition: => Expression,
      expectedAnswer: Seq[Product]): Unit = {

    def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
      val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
      ExtractEquiJoinKeys.unapply(join)
    }

    test(s"$testName using LeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            EnsureRequirements(left.sqlContext).apply(
              LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using BroadcastLeftSemiJoinHash") {
      extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
            BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
            expectedAnswer.map(Row.fromTuple),
            sortAnswers = true)
        }
      }
    }

    test(s"$testName using LeftSemiJoinBNL") {
      withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
        checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
          LeftSemiJoinBNL(left, right, Some(condition)),
          expectedAnswer.map(Row.fromTuple),
          sortAnswers = true)
      }
    }
  }

  testLeftSemiJoin(
    "basic test",
    left,
    right,
    condition,
    Seq(
      (2, 1.0),
      (2, 1.0)
    )
  )
} 
Example 47
Source File: LocalNodeTest.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}


class LocalNodeTest extends SparkFunSuite {

  protected val conf: SQLConf = new SQLConf
  protected val kvIntAttributes = Seq(
    AttributeReference("k", IntegerType)(),
    AttributeReference("v", IntegerType)())
  protected val joinNameAttributes = Seq(
    AttributeReference("id1", IntegerType)(),
    AttributeReference("name", StringType)())
  protected val joinNicknameAttributes = Seq(
    AttributeReference("id2", IntegerType)(),
    AttributeReference("nickname", StringType)())

  
  protected def resolveExpressions(
      expressions: Seq[Expression],
      localNode: LocalNode): Seq[Expression] = {
    require(localNode.expressions.forall(_.resolved))
    val inputMap = localNode.output.map { a => (a.name, a) }.toMap
    expressions.map { expression =>
      expression.transformUp {
        case UnresolvedAttribute(Seq(u)) =>
          inputMap.getOrElse(u,
            sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
      }
    }
  }

} 
Example 48
Source File: ProjectNodeSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.local

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, StringType}


class ProjectNodeSuite extends LocalNodeTest {
  private val pieAttributes = Seq(
    AttributeReference("id", IntegerType)(),
    AttributeReference("age", IntegerType)(),
    AttributeReference("name", StringType)())

  private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
    val inputNode = new DummyNode(pieAttributes, inputData)
    val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
    val projectNode = new ProjectNode(conf, columns, inputNode)
    val expectedOutput = inputData.map { case (id, age, name) => (id, name) }
    val actualOutput = projectNode.collect().map { case row =>
      (row.getInt(0), row.getString(1))
    }
    assert(actualOutput === expectedOutput)
  }

  test("empty") {
    testProject()
  }

  test("basic") {
    testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
  }

} 
Example 49
Source File: AttributeSetSuite.scala    From BigDatalog   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
Example 50
Source File: GroupedIteratorSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 51
Source File: NullValuesTest.scala    From spark-dynamodb   with Apache License 2.0 5 votes vote down vote up
package com.audienceproject.spark.dynamodb

import com.amazonaws.services.dynamodbv2.model.{AttributeDefinition, CreateTableRequest, KeySchemaElement, ProvisionedThroughput}
import com.audienceproject.spark.dynamodb.implicits._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class NullValuesTest extends AbstractInMemoryTest {

    test("Insert nested StructType with null values") {
        dynamoDB.createTable(new CreateTableRequest()
            .withTableName("NullTest")
            .withAttributeDefinitions(new AttributeDefinition("name", "S"))
            .withKeySchema(new KeySchemaElement("name", "HASH"))
            .withProvisionedThroughput(new ProvisionedThroughput(5L, 5L)))

        val schema = StructType(
            Seq(
                StructField("name", StringType, nullable = false),
                StructField("info", StructType(
                    Seq(
                        StructField("age", IntegerType, nullable = true),
                        StructField("address", StringType, nullable = true)
                    )
                ), nullable = true)
            )
        )

        val rows = spark.sparkContext.parallelize(Seq(
            Row("one", Row(30, "Somewhere")),
            Row("two", null),
            Row("three", Row(null, null))
        ))

        val newItemsDs = spark.createDataFrame(rows, schema)

        newItemsDs.write.dynamodb("NullTest")

        val validationDs = spark.read.dynamodb("NullTest")

        validationDs.show(false)
    }

} 
Example 52
Source File: UDFTest.scala    From SparkGIS   with Apache License 2.0 5 votes vote down vote up
package org.betterers.spark.gis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, Row}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.betterers.spark.gis.udf.Functions


class UDFTest extends FunSuite with BeforeAndAfter {
  import Geometry.WGS84

  val point = Geometry.point((2.0, 2.0))
  val multiPoint = Geometry.multiPoint((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
  var line = Geometry.line((11.0, 11.0), (12.0, 12.0))
  var multiLine = Geometry.multiLine(
    Seq((11.0, 1.0), (23.0, 23.0)),
    Seq((31.0, 3.0), (42.0, 42.0)))
  var polygon = Geometry.polygon((1.0, 1.0), (2.0, 2.0), (3.0, 1.0))
  var multiPolygon = Geometry.multiPolygon(
    Seq((1.0, 1.0), (2.0, 2.0), (3.0, 1.0)),
    Seq((1.1, 1.1), (2.0, 1.9), (2.5, 1.1))
  )
  val collection = Geometry.collection(point, multiPoint, line)
  val all: Seq[Geometry] = Seq(point, multiPoint, line, multiLine, polygon, multiPolygon, collection)

  var sc: SparkContext = _
  var sql: SQLContext = _

  before {
    sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SparkGIS"))
    sql = new SQLContext(sc)
  }

  after {
    sc.stop()
  }

  test("ST_Boundary") {
    // all.foreach(g => println(Functions.ST_Boundary(g).toString))

    assertResult(true) {
      Functions.ST_Boundary(point).isEmpty
    }
    assertResult(true) {
      Functions.ST_Boundary(multiPoint).isEmpty
    }
    assertResult("Some(MULTIPOINT ((11 11), (12 12)))") {
      Functions.ST_Boundary(line).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiLine)
    }
    assertResult("Some(LINEARRING (1 1, 2 2, 3 1, 1 1))") {
      Functions.ST_Boundary(polygon).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiPolygon)
    }
    assertResult(None) {
      Functions.ST_Boundary(collection)
    }
  }

  test("ST_CoordDim") {
    all.foreach(g => {
      assertResult(3) {
        Functions.ST_CoordDim(g)
      }
    })
  }

  test("UDF in SQL") {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("geo", GeometryType.Instance)
    ))
    val jsons = Map(
      (1, "{\"type\":\"Point\",\"coordinates\":[1,1]}"),
      (2, "{\"type\":\"LineString\",\"coordinates\":[[12,13],[15,20]]}")
    )
    val rdd = sc.parallelize(Seq(
      "{\"id\":1,\"geo\":" + jsons(1) + "}",
      "{\"id\":2,\"geo\":" + jsons(2) + "}"
    ))
    rdd.name = "TEST"
    val df = sql.read.schema(schema).json(rdd)
    df.registerTempTable("TEST")
    Functions.register(sql)
    assertResult(Array(3,3)) {
      sql.sql("SELECT ST_CoordDim(geo) FROM TEST").collect().map(_.get(0))
    }
  }
} 
Example 53
Source File: TemporalDataSuite.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.datasource.config.ConfigParameters._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.BeforeAndAfter

private[datasource] trait TemporalDataSuite extends DatasourceSuite
  with BeforeAndAfter {

  val conf = new SparkConf()
    .setAppName("datasource-receiver-example")
    .setIfMissing("spark.master", "local[*]")
  var sc: SparkContext = null
  var ssc: StreamingContext = null
  val tableName = "tableName"
  val datasourceParams = Map(
    StopGracefully -> "true",
    StopSparkContext -> "false",
    StorageLevelKey -> "MEMORY_ONLY",
    RememberDuration -> "15s"
  )
  val schema = new StructType(Array(
    StructField("id", StringType, nullable = true),
    StructField("idInt", IntegerType, nullable = true)
  ))
  val totalRegisters = 10000
  val registers = for (a <- 1 to totalRegisters) yield Row(a.toString, a)

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }
} 
Example 54
Source File: RecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0 5 votes vote down vote up
package com.adidas.analytics.unit

import com.adidas.analytics.util.RecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class RecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll {

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = RecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
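The expected statements in the last test follow a fixed pattern: one ALTER TABLE ... ADD IF NOT EXISTS PARTITION clause per row, with string values single-quoted. The private method under test is not shown on this page; a rough sketch of such generation logic over a Dataset[Row] could look like this (the name addPartitionStatements is illustrative, not the actual private method):

import org.apache.spark.sql.{Dataset, Row}

def addPartitionStatements(table: String,
                           partitionCols: Seq[String],
                           ds: Dataset[Row]): Dataset[String] = {
  import ds.sparkSession.implicits._
  ds.map { row =>
    val spec = partitionCols.map { col =>
      row.get(row.fieldIndex(col)) match {
        case s: String => s"$col='$s'"   // strings are quoted, e.g. country='portugal'
        case other     => s"$col=$other" // numeric values are emitted as-is
      }
    }.mkString(",")
    s"ALTER TABLE $table ADD IF NOT EXISTS PARTITION($spec)"
  }
}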
Example 55
Source File: SparkRecoverPartitionsCustomTest.scala    From m3d-engine   with Apache License 2.0
package com.adidas.analytics.unit

import com.adidas.analytics.util.SparkRecoverPartitionsCustom
import com.adidas.utils.SparkSessionWrapper
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Dataset, Row}
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, PrivateMethodTester}

import scala.collection.JavaConverters._

class SparkRecoverPartitionsCustomTest extends FunSuite
  with SparkSessionWrapper
  with PrivateMethodTester
  with Matchers
  with BeforeAndAfterAll {

  test("test conversion of String Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue("theValue")

    result should be("'theValue'")
  }

  test("test conversion of Short Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Short.valueOf("2"))

    result should be("2")
  }

  test("test conversion of Integer Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    val result = customSparkRecoverPartitions invokePrivate createParameterValue(java.lang.Integer.valueOf("4"))

    result should be("4")
  }

  test("test conversion of null Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(null)
    }
  }

  test("test conversion of not supported Value to HiveQL Partition Parameter") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(tableName="", targetPartitions = Seq())
    val createParameterValue = PrivateMethod[String]('createParameterValue)
    an [Exception] should be thrownBy {
      customSparkRecoverPartitions invokePrivate createParameterValue(false)
    }
  }

  test("test HiveQL statements Generation") {
    val customSparkRecoverPartitions = SparkRecoverPartitionsCustom(
      tableName="test",
      targetPartitions = Seq("country","district")
    )

    val rowsInput = Seq(
      Row(1, "portugal", "porto"),
      Row(2, "germany", "herzogenaurach"),
      Row(3, "portugal", "coimbra")
    )

    val inputSchema = StructType(
      List(
        StructField("number", IntegerType, nullable = true),
        StructField("country", StringType, nullable = true),
        StructField("district", StringType, nullable = true)
      )
    )

    val expectedStatements: Seq[String] = Seq(
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='porto')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='germany',district='herzogenaurach')",
      "ALTER TABLE test ADD IF NOT EXISTS PARTITION(country='portugal',district='coimbra')"
    )

    val testDataset: Dataset[Row] = spark.createDataset(rowsInput)(RowEncoder(inputSchema))

    val generateAddPartitionStatements = PrivateMethod[Dataset[String]]('generateAddPartitionStatements)

    val producedStatements: Seq[String] = (customSparkRecoverPartitions invokePrivate generateAddPartitionStatements(testDataset))
      .collectAsList()
      .asScala

    expectedStatements.sorted.toSet should equal(producedStatements.sorted.toSet)
  }

  override def afterAll(): Unit = {
    spark.stop()
  }

} 
Example 56
Source File: ReferenceTableTest.scala    From memsql-spark-connector   with Apache License 2.0
package com.memsql.spark

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.{DataFrame, SaveMode}

import scala.util.Try

class ReferenceTableTest extends IntegrationSuiteBase {

  val childAggregatorHost = "localhost"
  val childAggregatorPort = "5508"

  val dbName                  = "testdb"
  val commonCollectionName    = "test_table"
  val referenceCollectionName = "reference_table"

  override def beforeEach(): Unit = {
    super.beforeEach()

    // Set child aggregator as a dmlEndpoint
    spark.conf
      .set("spark.datasource.memsql.dmlEndpoints", s"${childAggregatorHost}:${childAggregatorPort}")
  }

  def writeToTable(tableName: String): Unit = {
    val df = spark.createDF(
      List(4, 5, 6),
      List(("id", IntegerType, true))
    )
    df.write
      .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
      .mode(SaveMode.Append)
      .save(s"${dbName}.${tableName}")
  }

  def readFromTable(tableName: String): DataFrame = {
    spark.read
      .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
      .load(s"${dbName}.${tableName}")
  }

  def writeAndReadFromTable(tableName: String): Unit = {
    writeToTable(tableName)
    val dataFrame = readFromTable(tableName)
    val sqlRows   = dataFrame.collect();
    assert(sqlRows.length == 3)
  }

  def dropTable(tableName: String): Unit = executeQuery(s"drop table if exists $dbName.$tableName")

  describe("Success during write operations") {

    it("to common table") {
      dropTable(commonCollectionName)
      executeQuery(
        s"create table if not exists $dbName.$commonCollectionName (id INT NOT NULL, PRIMARY KEY (id))")
      writeAndReadFromTable(commonCollectionName)
    }

    it("to reference table") {
      dropTable(referenceCollectionName)
      executeQuery(
        s"create reference table if not exists $dbName.$referenceCollectionName (id INT NOT NULL, PRIMARY KEY (id))")
      writeAndReadFromTable(referenceCollectionName)
    }
  }

  describe("Success during creating") {

    it("common table") {
      dropTable(commonCollectionName)
      writeAndReadFromTable(commonCollectionName)
    }
  }

  describe("Failure because of") {

    it("database name not specified") {
      spark.conf.set("spark.datasource.memsql.database", "")
      val df = spark.createDF(
        List(4, 5, 6),
        List(("id", IntegerType, true))
      )
      val result = Try {
        df.write
          .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
          .mode(SaveMode.Append)
          .save(s"${commonCollectionName}")
      }
      
      assert(SQLHelper.isSQLExceptionWithCode(result.failed.get, List(1046)))
    }
  }
} 
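createDF in the tests above comes from spark-daria's SparkSessionExt import; with plain Spark APIs the equivalent construction is roughly as follows (a sketch, not part of the connector tests):

import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

def idDataFrame(spark: SparkSession, ids: Seq[Int]): DataFrame = {
  val schema = StructType(Seq(StructField("id", IntegerType, nullable = true)))
  spark.createDataFrame(spark.sparkContext.parallelize(ids.map(Row(_))), schema)
}

// idDataFrame(spark, Seq(4, 5, 6)) mirrors spark.createDF(List(4, 5, 6), List(("id", IntegerType, true)))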
Example 57
Source File: OutputMetricsTest.scala    From memsql-spark-connector   with Apache License 2.0
package com.memsql.spark

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
import org.apache.spark.sql.types.{IntegerType, StringType}

class OutputMetricsTest extends IntegrationSuiteBase {
  it("records written") {
    var outputWritten = 0L
    spark.sparkContext.addSparkListener(new SparkListener() {
      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
        val metrics = taskEnd.taskMetrics
        outputWritten += metrics.outputMetrics.recordsWritten
      }
    })

    val numRows = 100000
    val df1 = spark.createDF(
      List.range(0, numRows),
      List(("id", IntegerType, true))
    )

    // repartition returns a new DataFrame, so chain it into the write
    df1.repartition(30)
      .write
      .format("memsql")
      .save("metricsInts")

    assert(outputWritten == numRows)
    outputWritten = 0

    val df2 = spark.createDF(
      List("st1", "", null),
      List(("st", StringType, true))
    )

    df2.write
      .format("memsql")
      .save("metricsStrings")

    assert(outputWritten == 3)
  }
} 
Example 58
Source File: BinaryTypeBenchmark.scala    From memsql-spark-connector   with Apache License 2.0
package com.memsql.spark

import java.sql.{Connection, DriverManager}
import java.util.Properties

import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import com.memsql.spark.BatchInsertBenchmark.{df, executeQuery}
import org.apache.spark.sql.types.{BinaryType, IntegerType}
import org.apache.spark.sql.{SaveMode, SparkSession}

import scala.util.Random

// BinaryTypeBenchmark profiles writing of BinaryType data with a CPU profiler.
// The profiler is available in the Ultimate edition of IntelliJ IDEA;
// see https://www.jetbrains.com/help/idea/async-profiler.html#profile for more details.
object BinaryTypeBenchmark extends App {
  final val masterHost: String = sys.props.getOrElse("memsql.host", "localhost")
  final val masterPort: String = sys.props.getOrElse("memsql.port", "5506")

  val spark: SparkSession = SparkSession
    .builder()
    .master("local")
    .config("spark.sql.shuffle.partitions", "1")
    .config("spark.driver.bindAddress", "localhost")
    .config("spark.datasource.memsql.ddlEndpoint", s"${masterHost}:${masterPort}")
    .config("spark.datasource.memsql.database", "testdb")
    .getOrCreate()

  def jdbcConnection: Loan[Connection] = {
    val connProperties = new Properties()
    connProperties.put("user", "root")

    Loan(
      DriverManager.getConnection(
        s"jdbc:mysql://$masterHost:$masterPort",
        connProperties
      ))
  }

  def executeQuery(sql: String): Unit = {
    jdbcConnection.to(conn => Loan(conn.createStatement).to(_.execute(sql)))
  }

  executeQuery("set global default_partitions_per_leaf = 2")
  executeQuery("drop database if exists testdb")
  executeQuery("create database testdb")

  def genRandomByte(): Byte = (Random.nextInt(256) - 128).toByte
  def genRandomRow(): Array[Byte] =
    Array.fill(1000)(genRandomByte())

  val df = spark.createDF(
    List.fill(100000)(genRandomRow()).zipWithIndex,
    List(("data", BinaryType, true), ("id", IntegerType, true))
  )

  val start1 = System.nanoTime()
  df.write
    .format("memsql")
    .mode(SaveMode.Overwrite)
    .save("testdb.LoadData")

  println("Elapsed time: " + (System.nanoTime() - start1) + "ns [LoadData CSV]")

  val start2 = System.nanoTime()
  df.write
    .format("memsql")
    .option("tableKey.primary", "id")
    .option("onDuplicateKeySQL", "id = id")
    .mode(SaveMode.Overwrite)
    .save("testdb.BatchInsert")

  println("Elapsed time: " + (System.nanoTime() - start2) + "ns [BatchInsert]")

  val avroStart = System.nanoTime()
  df.write
    .format(DefaultSource.MEMSQL_SOURCE_NAME_SHORT)
    .mode(SaveMode.Overwrite)
    .option(MemsqlOptions.LOAD_DATA_FORMAT, "Avro")
    .save("testdb.AvroSerialization")
  println("Elapsed time: " + (System.nanoTime() - avroStart) + "ns [LoadData Avro] ")
} 
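The three timed writes above repeat the same System.nanoTime bookkeeping; a small helper (a sketch, not part of the benchmark) keeps that in one place:

// Runs a block, prints the elapsed time with a label, and returns the block's result.
def timed[T](label: String)(block: => T): T = {
  val start = System.nanoTime()
  val result = block
  println(s"Elapsed time: ${System.nanoTime() - start}ns [$label]")
  result
}

// usage:
//   timed("LoadData CSV") {
//     df.write.format("memsql").mode(SaveMode.Overwrite).save("testdb.LoadData")
//   }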
Example 59
Source File: PopulateHiveTable.scala    From spark_training   with Apache License 2.0
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}


object PopulateHiveTable {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .config("spark.sql.parquet.compression.codec", "gzip")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    populated1Df.repartition(2).write.saveAsTable("nested_populated")

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)

    println("----")
    populated2Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
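Once nested_populated exists, the ARRAY<STRUCT<...>> column can be flattened in SQL with LATERAL VIEW explode. A short sketch, assuming the SparkSession from the example is still in scope:

// One output row per inner struct element, alongside the top-level columns.
val flattened = spark.sql(
  """SELECT A, B, n.nested_C, n.nested_D
    |FROM nested_populated
    |LATERAL VIEW explode(nested) t AS n""".stripMargin)
flattened.show()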
Example 60
Source File: NestedTableExample.scala    From spark_training   with Apache License 2.0
package com.malaska.spark.training.nested

import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructType}
import org.apache.spark.sql.{Row, SparkSession}

object NestedTableExample {
  Logger.getLogger("org").setLevel(Level.OFF)
  Logger.getLogger("akka").setLevel(Level.OFF)

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder
      .master("local")
      .appName("my-spark-app")
      .config("spark.some.config.option", "config-value")
      .config("spark.driver.host","127.0.0.1")
      .enableHiveSupport()
      .getOrCreate()


    spark.sql("create table IF NOT EXISTS nested_empty " +
      "( A int, " +
      "  B string, " +
      "  nested ARRAY<STRUCT< " +
      "     nested_C: int," +
      "     nested_D: string" +
      "  >>" +
      ") ")

    val rowRDD = spark.sparkContext.
      parallelize(Array(
        Row(1, "foo", Seq(Row(1, "barA"),Row(2, "bar"))),
        Row(2, "foo", Seq(Row(1, "barB"),Row(2, "bar"))),
        Row(3, "foo", Seq(Row(1, "barC"),Row(2, "bar")))))

    val emptyDf = spark.sql("select * from nested_empty limit 0")

    val tableSchema = emptyDf.schema

    val populated1Df = spark.sqlContext.createDataFrame(rowRDD, tableSchema)

    println("----")
    populated1Df.collect().foreach(r => println(" emptySchemaExample:" + r))

    val nestedSchema = new StructType()
      .add("nested_C", IntegerType)
      .add("nested_D", StringType)

    val definedSchema = new StructType()
      .add("A", IntegerType)
      .add("B", StringType)
      .add("nested", ArrayType(nestedSchema))

    val populated2Df = spark.sqlContext.createDataFrame(rowRDD, definedSchema)
    println("----")
    populated2Df.collect().foreach(r => println(" BuiltExample:" + r))

    spark.stop()
  }
} 
Example 61
Source File: CarbonDataFrameExample.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.examples

import org.apache.spark.sql.{SaveMode, SparkSession}

import org.apache.carbondata.examples.util.ExampleUtils

object CarbonDataFrameExample {

  def main(args: Array[String]) {
    val spark = ExampleUtils.createSparkSession("CarbonDataFrameExample")
    exampleBody(spark)
    spark.close()
  }

  def exampleBody(spark : SparkSession): Unit = {
    // Writes a DataFrame to a CarbonData file:
    import spark.implicits._
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
    df.write
      .format("carbondata")
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names
      .mode(SaveMode.Overwrite)
      .save()

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf = spark.read
      .format("carbondata")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "carbon_df_table")
      .load()

    // Dataframe operations
    carbondf.printSchema()
    carbondf.select($"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
  }
} 
Example 62
Source File: BlockingSource.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
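BlockingSource.latch gates createSource, so a test can hold a streaming query at source creation until the latch is released. A usage sketch (the helper name withBlockedSource is illustrative, not part of the test utilities):

import java.util.concurrent.CountDownLatch

// Arms the latch before running `body`, then releases it so that any query
// reading from BlockingSource can finish creating its source.
def withBlockedSource(body: => Unit): Unit = {
  BlockingSource.latch = new CountDownLatch(1)
  try body
  finally BlockingSource.latch.countDown()
}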
Example 63
Source File: StratifiedRepartitionSuite.scala    From mmlspark   with MIT License
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.stages

import com.microsoft.ml.spark.core.test.base.TestBase
import com.microsoft.ml.spark.core.test.fuzzing.{TestObject, TransformerFuzzing}
import org.apache.spark.TaskContext
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class StratifiedRepartitionSuite extends TestBase with TransformerFuzzing[StratifiedRepartition] {

  import session.implicits._

  val values = "values"
  val colors = "colors"
  val const = "const"

  lazy val input = Seq(
    (0, "Blue", 2),
    (0, "Red", 2),
    (0, "Green", 2),
    (1, "Purple", 2),
    (1, "Orange", 2),
    (1, "Indigo", 2),
    (2, "Violet", 2),
    (2, "Black", 2),
    (2, "White", 2),
    (3, "Gray", 2),
    (3, "Yellow", 2),
    (3, "Cerulean", 2)
  ).toDF(values, colors, const)

  test("Assert doing a stratified repartition will ensure all keys exist across all partitions") {
    val inputSchema = new StructType()
      .add(values, IntegerType).add(colors, StringType).add(const, IntegerType)
    val inputEnc = RowEncoder(inputSchema)
    val valuesFieldIndex = inputSchema.fieldIndex(values)
    val numPartitions = 3
    val trainData = input.repartition(numPartitions).select(values, colors, const)
      .mapPartitions(iter => {
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        // Remove all instances of 0 class on partition 1
        if (partId == 1) {
          iter.flatMap(row => {
            if (row.getInt(valuesFieldIndex) <= 0)
              None
            else Some(row)
          })
        } else {
          // Add back at least 3 instances on other partitions
          val oneOfEachExample = List(Row(0, "Blue", 2), Row(1, "Purple", 2), Row(2, "Black", 2), Row(3, "Gray", 2))
          (iter.toList.union(oneOfEachExample).union(oneOfEachExample).union(oneOfEachExample)).toIterator
        }
      })(inputEnc).cache()
    // Some debug to understand what data is on which partition
    trainData.foreachPartition { rows =>
      rows.foreach { row =>
        val ctx = TaskContext.get
        val partId = ctx.partitionId
        println(s"Row: $row partition id: $partId")
      }
    }
    val stratifiedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Equal).transform(trainData)
    // Assert stratified data contains all keys across all partitions, with extra count
    // for it to be evaluated
    stratifiedInputData
      .mapPartitions(iter => {
        val actualLabels = iter.map(row => row.getInt(valuesFieldIndex))
          .toArray.distinct.sorted.toList
        val expectedLabels = (0 to 3).toList
        if (actualLabels != expectedLabels)
          throw new Exception(s"Missing labels, actual: $actualLabels, expected: $expectedLabels")
        iter
      })(inputEnc).count()
    val stratifiedMixedInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Mixed).transform(trainData)
    assert(stratifiedMixedInputData.count() >= trainData.count())
    val stratifiedOriginalInputData = new StratifiedRepartition().setLabelCol(values)
      .setMode(SPConstants.Original).transform(trainData)
    assert(stratifiedOriginalInputData.count() == trainData.count())
  }

  def testObjects(): Seq[TestObject[StratifiedRepartition]] = List(new TestObject(
    new StratifiedRepartition().setLabelCol(values).setMode(SPConstants.Equal), input))

  def reader: MLReadable[_] = StratifiedRepartition
} 
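The debug loop in the test prints one line per row; the same information can be summarized per partition with mapPartitionsWithIndex. A sketch, assuming an integer label column at a known index (the helper name is illustrative):

import org.apache.spark.sql.DataFrame

// Counts how many rows of each label value land in each partition.
def labelCountsPerPartition(df: DataFrame, labelIndex: Int): Array[(Int, Map[Int, Int])] =
  df.rdd.mapPartitionsWithIndex { (partId, rows) =>
    val counts = rows.map(_.getInt(labelIndex)).toSeq
      .groupBy(identity)
      .map { case (label, occurrences) => label -> occurrences.size }
    Iterator((partId, counts))
  }.collect()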
Example 64
Source File: WholeStageCodegenSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
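Each assertion above repeats the same plan.find pattern; a small helper (a sketch, not part of the suite) expresses "some WholeStageCodegenExec wraps a child of type T" once:

import scala.reflect.ClassTag
import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec}

// True if the plan contains a WholeStageCodegenExec whose direct child has type T.
def codegensChildOfType[T <: SparkPlan](plan: SparkPlan)(implicit ct: ClassTag[T]): Boolean =
  plan.find {
    case w: WholeStageCodegenExec => ct.runtimeClass.isInstance(w.child)
    case _ => false
  }.isDefined

// usage: assert(codegensChildOfType[HashAggregateExec](df.queryExecution.executedPlan))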
Example 65
Source File: resources.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
Example 66
Source File: SimplifyConditionalSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
Example 67
Source File: RewriteDistinctAggregatesSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 68
Source File: AttributeSetSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
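The suite shows that AttributeSet membership is decided by ExprId rather than by attribute name or case. A standalone illustration of the same point (constructed here, not taken from the suite):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, ExprId}
import org.apache.spark.sql.types.IntegerType

val x1 = AttributeReference("x", IntegerType)(exprId = ExprId(10))
val x2 = AttributeReference("X", IntegerType)(exprId = ExprId(10)) // same id, different case
val y  = AttributeReference("x", IntegerType)(exprId = ExprId(11)) // same name, different id

val attrs = AttributeSet(x1 :: Nil)
assert(attrs.contains(x2))  // matched by id
assert(!attrs.contains(y))  // not matched, despite the identical name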
Example 69
Source File: RandomSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    checkDoubleEvaluation(
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
    checkDoubleEvaluation(
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
  }

  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
  }
} 
Example 70
Source File: ObjectExpressionsSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}


class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
  }

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
    checkEvalutionWithUnsafeProjection(
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
    checkEvalutionWithUnsafeProjection(
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
    checkEvalutionWithUnsafeProjection(
      mapEncoder.serializer.head, mapExpected, mapInputRow)
  }
} 
Example 71
Source File: ExpressionEvalHelperSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 72
Source File: ScalaUDFSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
Example 73
Source File: MapDataSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 74
Source File: LogicalPlanSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)
  }

  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)
  }

  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan.foreach(_.setAnalyzed())
    plan resolveOperators function

    assert(invocationCount === 0)
  }

  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan1.foreach(_.setAnalyzed())
    plan2 resolveOperators function

    assert(invocationCount === 1)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true
    }

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }
} 
Example 75
Source File: AttributeSetSuite.scala    From multi-tenancy-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
Example 76
Source File: EstimatorModelWrapperIntegSpec.scala    From seahorse   with Apache License 2.0
package ai.deepsense.deeplang.doperables.spark.wrappers.estimators

import ai.deepsense.deeplang.DeeplangIntegTestSupport
import ai.deepsense.deeplang.doperables.dataframe.DataFrame
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructType, StructField}

class EstimatorModelWrapperIntegSpec extends DeeplangIntegTestSupport {

  import ai.deepsense.deeplang.doperables.spark.wrappers.estimators.EstimatorModelWrapperFixtures._

  val inputDF = {
    val rowSeq = Seq(Row(1), Row(2), Row(3))
    val schema = StructType(Seq(StructField("x", IntegerType, nullable = false)))
    createDataFrame(rowSeq, schema)
  }

  val estimatorPredictionParamValue = "estimatorPrediction"

  val expectedSchema = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(estimatorPredictionParamValue, IntegerType, nullable = false)
  ))

  val transformerPredictionParamValue = "modelPrediction"

  val expectedSchemaForTransformerParams = StructType(Seq(
    StructField("x", IntegerType, nullable = false),
    StructField(transformerPredictionParamValue, IntegerType, nullable = false)
  ))

  "EstimatorWrapper" should {
    "_fit() and transform() + transformSchema() with parameters inherited" in {

      val transformer = createEstimatorAndFit()

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchema

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchema)
    }

    "_fit() and transform() + transformSchema() with parameters overwritten" in {

      val transformer = createEstimatorAndFit().setPredictionColumn(transformerPredictionParamValue)

      val transformOutputSchema =
        transformer._transform(executionContext, inputDF).sparkDataFrame.schema
      transformOutputSchema shouldBe expectedSchemaForTransformerParams

      val inferenceOutputSchema = transformer._transformSchema(inputDF.sparkDataFrame.schema)
      inferenceOutputSchema shouldBe Some(expectedSchemaForTransformerParams)
    }

    "_fit_infer().transformSchema() with parameters inherited" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
        .setPredictionColumn(estimatorPredictionParamValue)

      estimatorWrapper._fit_infer(inputDF.schema)
        ._transformSchema(inputDF.sparkDataFrame.schema) shouldBe Some(expectedSchema)
    }

    "_fit_infer().transformSchema() with parameters overwritten" in {

      val estimatorWrapper = new SimpleSparkEstimatorWrapper()
        .setPredictionColumn(estimatorPredictionParamValue)
      val transformer =
        estimatorWrapper._fit_infer(inputDF.schema).asInstanceOf[SimpleSparkModelWrapper]
      val transformerWithParams = transformer.setPredictionColumn(transformerPredictionParamValue)

      val outputSchema = transformerWithParams._transformSchema(inputDF.sparkDataFrame.schema)
      outputSchema shouldBe Some(expectedSchemaForTransformerParams)
    }
  }

  private def createEstimatorAndFit(): SimpleSparkModelWrapper = {

    val estimatorWrapper = new SimpleSparkEstimatorWrapper()
      .setPredictionColumn(estimatorPredictionParamValue)

    val transformer =
      estimatorWrapper._fit(executionContext, inputDF).asInstanceOf[SimpleSparkModelWrapper]
    transformer.getPredictionColumn() shouldBe estimatorPredictionParamValue

    transformer
  }
} 
Example 77
Source File: EstimatorModelWrapperFixtures.scala    From seahorse   with Apache License 2.0
package ai.deepsense.deeplang.doperables.spark.wrappers.estimators

import scala.language.reflectiveCalls

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml
import org.apache.spark.ml.param.{ParamMap, Param => SparkParam}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import ai.deepsense.deeplang.ExecutionContext
import ai.deepsense.deeplang.doperables.report.Report
import ai.deepsense.deeplang.doperables.serialization.SerializableSparkModel
import ai.deepsense.deeplang.doperables.{SparkEstimatorWrapper, SparkModelWrapper}
import ai.deepsense.deeplang.params.wrappers.spark.SingleColumnCreatorParamWrapper
import ai.deepsense.deeplang.params.{Param, Params}
import ai.deepsense.sparkutils.ML

object EstimatorModelWrapperFixtures {

  class SimpleSparkModel private[EstimatorModelWrapperFixtures]()
    extends ML.Model[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "modelId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    def setPredictionCol(value: String): this.type = set(predictionCol, value)

    override def copy(extra: ParamMap): this.type = defaultCopy(extra)

    override def transformDF(dataset: DataFrame): DataFrame = {
      dataset.selectExpr("*", "1 as " + $(predictionCol))
    }

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = ???
  }

  class SimpleSparkEstimator extends ML.Estimator[SimpleSparkModel] {

    def this(x: String) = this()

    override val uid: String = "estimatorId"

    val predictionCol = new SparkParam[String](uid, "name", "description")

    override def fitDF(dataset: DataFrame): SimpleSparkModel =
      new SimpleSparkModel().setPredictionCol($(predictionCol))

    override def copy(extra: ParamMap): ML.Estimator[SimpleSparkModel] = defaultCopy(extra)

    @DeveloperApi
    override def transformSchema(schema: StructType): StructType = {
      schema.add(StructField($(predictionCol), IntegerType, nullable = false))
    }
  }

  trait HasPredictionColumn extends Params {
    val predictionColumn = new SingleColumnCreatorParamWrapper[
        ml.param.Params { val predictionCol: SparkParam[String] }](
      "prediction column",
      None,
      _.predictionCol)
    setDefault(predictionColumn, "abcdefg")

    def getPredictionColumn(): String = $(predictionColumn)
    def setPredictionColumn(value: String): this.type = set(predictionColumn, value)
  }

  class SimpleSparkModelWrapper
    extends SparkModelWrapper[SimpleSparkModel, SimpleSparkEstimator]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report(extended: Boolean = true): Report = ???

    override protected def loadModel(
      ctx: ExecutionContext,
      path: String): SerializableSparkModel[SimpleSparkModel] = ???
  }

  class SimpleSparkEstimatorWrapper
    extends SparkEstimatorWrapper[SimpleSparkModel, SimpleSparkEstimator, SimpleSparkModelWrapper]
    with HasPredictionColumn {

    override val params: Array[Param[_]] = Array(predictionColumn)
    override def report(extended: Boolean = true): Report = ???
  }
} 
Example 78
Source File: DataFrameSplitterIntegSpec.scala    From seahorse   with Apache License 2.0 5 votes vote down vote up
package ai.deepsense.deeplang.doperations

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.scalatest.Matchers
import org.scalatest.prop.GeneratorDrivenPropertyChecks

import ai.deepsense.deeplang._
import ai.deepsense.deeplang.doperables.dataframe.DataFrame

class DataFrameSplitterIntegSpec
  extends DeeplangIntegTestSupport
  with GeneratorDrivenPropertyChecks
  with Matchers {

  "SplitDataFrame" should {
    "split randomly one df into two df in given range" in {

      val input = Range(1, 100)

      val parameterPairs = List(
        (0.0, 0),
        (0.3, 1),
        (0.5, 2),
        (0.8, 3),
        (1.0, 4))

      for((splitRatio, seed) <- parameterPairs) {
        val rdd = createData(input)
        val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
        val (df1, df2) = executeOperation(
          executionContext,
          new Split()
            .setSplitMode(
              SplitModeChoice.Random()
                .setSplitRatio(splitRatio)
                .setSeed(seed / 2)))(df)
        validateSplitProperties(df, df1, df2)
      }
    }

    "split conditionally one df into two df in given range" in {

      val input = Range(1, 100)

      val condition = "value > 20"
      val predicate: Int => Boolean = _ > 20

      val (expectedDF1, expectedDF2) =
        (input.filter(predicate), input.filter(!predicate(_)))

      val rdd = createData(input)
      val df = executionContext.dataFrameBuilder.buildDataFrame(createSchema, rdd)
      val (df1, df2) = executeOperation(
        executionContext,
        new Split()
          .setSplitMode(
            SplitModeChoice.Conditional()
              .setCondition(condition)))(df)
      df1.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF1
      df2.sparkDataFrame.collect().map(_.get(0)) should contain theSameElementsAs expectedDF2
      validateSplitProperties(df, df1, df2)
    }
  }

  private def createSchema: StructType = {
    StructType(List(
      StructField("value", IntegerType, nullable = false)
    ))
  }

  private def createData(data: Seq[Int]): RDD[Row] = {
    sparkContext.parallelize(data.map(Row(_)))
  }

  private def executeOperation(context: ExecutionContext, operation: DOperation)
                              (dataFrame: DataFrame): (DataFrame, DataFrame) = {
    val operationResult = operation.executeUntyped(Vector[DOperable](dataFrame))(context)
    val df1 = operationResult.head.asInstanceOf[DataFrame]
    val df2 = operationResult.last.asInstanceOf[DataFrame]
    (df1, df2)
  }

  def validateSplitProperties(inputDF: DataFrame, outputDF1: DataFrame, outputDF2: DataFrame)
      : Unit = {
    val dfCount = inputDF.sparkDataFrame.count()
    val df1Count = outputDF1.sparkDataFrame.count()
    val df2Count = outputDF2.sparkDataFrame.count()
    val rowsDf = inputDF.sparkDataFrame.collectAsList().asScala
    val rowsDf1 = outputDF1.sparkDataFrame.collectAsList().asScala
    val rowsDf2 = outputDF2.sparkDataFrame.collectAsList().asScala
    val intersect = rowsDf1.intersect(rowsDf2)
    intersect.size shouldBe 0
    (df1Count + df2Count) shouldBe dfCount
    rowsDf.toSet shouldBe rowsDf1.toSet.union(rowsDf2.toSet)
  }
} 
Example 79
Source File: AttributeSetSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
Example 80
Source File: SchemaColumnRandom.scala    From data-faker   with MIT License 5 votes vote down vote up
package com.dunnhumby.datafaker.schema.table.columns

import java.sql.{Date, Timestamp}
import com.dunnhumby.datafaker.YamlParser.YamlParserProtocol
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{to_utc_timestamp, round, rand, from_unixtime, to_date}
import org.apache.spark.sql.types.{IntegerType, LongType}

trait SchemaColumnRandom[T] extends SchemaColumn

object SchemaColumnRandom {
  val FloatDP = 3
  val DoubleDP = 3

  def apply(name: String, min: Int, max: Int): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Long, max: Long): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Float, max: Float): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Double, max: Double): SchemaColumn = SchemaColumnRandomNumeric(name, min, max)
  def apply(name: String, min: Date, max: Date): SchemaColumn = SchemaColumnRandomDate(name, min, max)
  def apply(name: String, min: Timestamp, max: Timestamp): SchemaColumn = SchemaColumnRandomTimestamp(name, min, max)
  def apply(name: String): SchemaColumn = SchemaColumnRandomBoolean(name)
}

private case class SchemaColumnRandomNumeric[T: Numeric](override val name: String, min: T, max: T) extends SchemaColumnRandom[T] {
  override def column(rowID: Option[Column] = None): Column = {
    import Numeric.Implicits._

    (min, max) match {
      case (_: Int, _: Int) => round(rand() * (max - min) + min, 0).cast(IntegerType)
      case (_: Long, _: Long) => round(rand() * (max - min) + min, 0).cast(LongType)
      case (_: Float, _: Float) => round(rand() * (max - min) + min, SchemaColumnRandom.FloatDP)
      case (_: Double, _: Double) => round(rand() * (max - min) + min, SchemaColumnRandom.DoubleDP)
    }
  }
}

private case class SchemaColumnRandomTimestamp(override val name: String, min: Timestamp, max: Timestamp) extends SchemaColumnRandom[Timestamp] {
  override def column(rowID: Option[Column] = None): Column = {
    val minTime = min.getTime / 1000
    val maxTime = max.getTime / 1000
    to_utc_timestamp(from_unixtime(rand() * (maxTime - minTime) + minTime), "UTC")
  }
}

private case class SchemaColumnRandomDate(override val name: String, min: Date, max: Date) extends SchemaColumnRandom[Date] {
  val timestamp = SchemaColumnRandomTimestamp(name, new Timestamp(min.getTime), new Timestamp(max.getTime + 86400000))

  override def column(rowID: Option[Column] = None): Column = to_date(timestamp.column())
}

private case class SchemaColumnRandomBoolean(override val name: String) extends SchemaColumnRandom[Boolean] {
  override def column(rowID: Option[Column] = None): Column = rand() < 0.5f
}

object SchemaColumnRandomProtocol extends SchemaColumnRandomProtocol
trait SchemaColumnRandomProtocol extends YamlParserProtocol {

  import net.jcazevedo.moultingyaml._

  implicit object SchemaColumnRandomFormat extends YamlFormat[SchemaColumnRandom[_]] {

    override def read(yaml: YamlValue): SchemaColumnRandom[_] = {
      val fields = yaml.asYamlObject.fields
      val YamlString(name) = fields.getOrElse(YamlString("name"), deserializationError("name not set"))
      val YamlString(dataType) = fields.getOrElse(YamlString("data_type"), deserializationError(s"data_type not set for $name"))

      if (dataType == SchemaColumnDataType.Boolean) {
        SchemaColumnRandomBoolean(name)
      }
      else {
        val min = fields.getOrElse(YamlString("min"), deserializationError(s"min not set for $name"))
        val max = fields.getOrElse(YamlString("max"), deserializationError(s"max not set for $name"))

        dataType match {
          case SchemaColumnDataType.Int => SchemaColumnRandomNumeric(name, min.convertTo[Int], max.convertTo[Int])
          case SchemaColumnDataType.Long => SchemaColumnRandomNumeric(name, min.convertTo[Long], max.convertTo[Long])
          case SchemaColumnDataType.Float => SchemaColumnRandomNumeric(name, min.convertTo[Float], max.convertTo[Float])
          case SchemaColumnDataType.Double => SchemaColumnRandomNumeric(name, min.convertTo[Double], max.convertTo[Double])
          case SchemaColumnDataType.Date => SchemaColumnRandomDate(name, min.convertTo[Date], max.convertTo[Date])
          case SchemaColumnDataType.Timestamp => SchemaColumnRandomTimestamp(name, min.convertTo[Timestamp], max.convertTo[Timestamp])
          case _ => deserializationError(s"unsupported data_type: $dataType for ${SchemaColumnType.Random}")
        }
      }

    }

    override def write(obj: SchemaColumnRandom[_]): YamlValue = ???

  }

} 
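
A minimal usage sketch for the generators above, assuming a local SparkSession and that the SchemaColumn trait exposes the column(rowID: Option[Column]) method overridden by the case classes; the object name, column names, and value ranges are illustrative only.

import org.apache.spark.sql.SparkSession
import com.dunnhumby.datafaker.schema.table.columns.SchemaColumnRandom

object SchemaColumnRandomSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("random-columns").getOrCreate()

    // Uniformly distributed Int and Boolean columns; apply() picks the matching
    // private implementation based on the argument types.
    val age = SchemaColumnRandom("age", 18, 65).column(None)
    val optedIn = SchemaColumnRandom("opted_in").column(None)

    spark.range(5).toDF("id")
      .withColumn("age", age)
      .withColumn("opted_in", optedIn)
      .show()

    spark.stop()
  }
}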
Example 81
Source File: BlockingSource.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
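
A sketch of how a test might drive BlockingSource, assuming a SparkSession with structured streaming available; the object name and checkpoint directory are placeholders. The source blocks in createSource until the latch is released, which lets a test observe a query while it is still initializing.

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.util.BlockingSource

object BlockingSourceSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("blocking-source").getOrCreate()

    BlockingSource.latch = new CountDownLatch(1)
    val query = spark.readStream
      .format(classOf[BlockingSource].getName)
      .load()
      .writeStream
      .format(classOf[BlockingSource].getName)
      .option("checkpointLocation", "/tmp/blocking-source-checkpoint") // placeholder path
      .start()

    // ... assertions against the still-initializing query would go here ...

    BlockingSource.latch.countDown() // unblock createSource so the query can make progress
    query.stop()
    spark.stop()
  }
}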
Example 82
Source File: GroupedIteratorSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 83
Source File: WholeStageCodegenSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
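
The assertions above inspect the physical plan programmatically; for ad-hoc inspection, a short sketch like the following (assuming Spark's execution.debug helpers are on the classpath) prints the generated code of each whole-stage subtree.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.debug._

object CodegenInspectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("codegen-debug").getOrCreate()

    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")

    // Prints the WholeStageCodegen subtrees together with their generated Java source.
    df.debugCodegen()

    spark.stop()
  }
}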
Example 84
Source File: resources.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
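
ListJarsCommand backs the LIST JARS SQL statement. A sketch of invoking it through SQL follows; the jar path is a placeholder and must point at a real file for ADD JAR to succeed.

import org.apache.spark.sql.SparkSession

object ListJarsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("list-jars").getOrCreate()

    spark.sql("ADD JAR /tmp/example-udfs.jar") // placeholder jar path
    // Each returned row holds one jar path in the single "Results" column defined above.
    spark.sql("LIST JARS").show(truncate = false)

    spark.stop()
  }
}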
Example 85
Source File: SimplifyConditionalSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
Example 86
Source File: RewriteDistinctAggregatesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
Example 87
Source File: SubstituteUnresolvedOrdinals.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType


class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
} 
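
A sketch of applying this rule directly to a plan whose ORDER BY clause uses an integer ordinal, assuming the Catalyst test DSLs and SimpleCatalystConf that appear elsewhere in these examples; the object name is illustrative.

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.SubstituteUnresolvedOrdinals
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object OrdinalSubstitutionSketch {
  def main(args: Array[String]): Unit = {
    val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = true)
    val relation = LocalRelation('a.int, 'b.string)

    // SELECT a, b FROM relation ORDER BY 1, written with a literal ordinal.
    val plan = relation.select('a, 'b).orderBy(Literal(1).asc)

    // Literal(1) in the sort order is rewritten to UnresolvedOrdinal(1),
    // which the analyzer later resolves against the select list.
    val substituted = new SubstituteUnresolvedOrdinals(conf).apply(plan)
    println(substituted.treeString)
  }
}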
Example 88
Source File: RandomSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    checkDoubleEvaluation(
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
    checkDoubleEvaluation(
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
  }

  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
  }
} 
Example 89
Source File: ObjectExpressionsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}


class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
  }

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
    checkEvalutionWithUnsafeProjection(
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
    checkEvalutionWithUnsafeProjection(
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
    checkEvalutionWithUnsafeProjection(
      mapEncoder.serializer.head, mapExpected, mapInputRow)
  }
} 
Example 90
Source File: ExpressionEvalHelperSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 91
Source File: ScalaUDFSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
Example 92
Source File: MapDataSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 93
Source File: LogicalPlanSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)
  }

  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)
  }

  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan.foreach(_.setAnalyzed())
    plan resolveOperators function

    assert(invocationCount === 0)
  }

  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan1.foreach(_.setAnalyzed())
    plan2 resolveOperators function

    assert(invocationCount === 1)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true
    }

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }
} 
Example 94
Source File: SubstituteUnresolvedOrdinals.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType


class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
} 
Example 95
Source File: MetastoreRelationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown
    relation.toJSON
  }

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
      }
    }
  }
} 
Example 96
Source File: SparkSQLExprMapperTest.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.impl

import java.util.Collections

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.opencypher.morpheus.api.value.MorpheusElement._
import org.opencypher.morpheus.impl.ExprEval._
import org.opencypher.morpheus.impl.SparkSQLExprMapper._
import org.opencypher.morpheus.testing.fixture.SparkSessionFixture
import org.opencypher.okapi.api.types.CTInteger
import org.opencypher.okapi.api.value.CypherValue.CypherMap
import org.opencypher.okapi.ir.api.expr._
import org.opencypher.okapi.relational.impl.table.RecordHeader
import org.opencypher.okapi.testing.BaseTestSuite

import scala.language.implicitConversions

class SparkSQLExprMapperTest extends BaseTestSuite with SparkSessionFixture {

  val vA: Var = Var("a")(CTInteger)
  val vB: Var = Var("b")(CTInteger)
  val header: RecordHeader = RecordHeader.from(vA, vB)

  it("converts prefix id expressions") {
    val id = 257L
    val prefix = 2.toByte
    val expr = PrefixId(ToId(IntegerLit(id)), prefix)
    expr.eval.asInstanceOf[Array[_]].toList should equal(prefix :: id.encodeAsMorpheusId.toList)
  }

  it("converts a CypherInteger to an ID") {
    val id = 257L
    val expr = ToId(IntegerLit(id))
    expr.eval.asInstanceOf[Array[_]].toList should equal(id.encodeAsMorpheusId.toList)
  }

  it("converts a CypherInteger to an ID and prefixes it") {
    val id = 257L
    val prefix = 2.toByte
    val expr = PrefixId(ToId(IntegerLit(id)), prefix)
    expr.eval.asInstanceOf[Array[_]].toList should equal(prefix :: id.encodeAsMorpheusId.toList)
  }

  it("converts a CypherInteger literal") {
    val id = 257L
    val expr = IntegerLit(id)
    expr.eval.asInstanceOf[Long] should equal(id)
  }

  private def convert(expr: Expr, header: RecordHeader = header): Column = {
    expr.asSparkSQLExpr(header, df, CypherMap.empty)
  }
  val df: DataFrame = sparkSession.createDataFrame(
    Collections.emptyList[Row](),
    StructType(Seq(StructField(header.column(vA), IntegerType), StructField(header.column(vB), IntegerType))))

  implicit def extractRecordHeaderFromResult[T](tuple: (RecordHeader, T)): RecordHeader = tuple._1
}

object ExprEval {

  implicit class ExprOps(val expr: Expr) extends AnyVal {

    def eval(implicit spark: SparkSession): Any = {
      val df = spark.createDataFrame(
        Collections.emptyList[Row](),
        StructType(Seq.empty))

      expr.asSparkSQLExpr(RecordHeader.empty, df, CypherMap.empty).expr.eval(InternalRow.empty)
    }
  }
} 
Example 97
Source File: YelpHelpers.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.morpheus.integration.yelp

import org.apache.spark.sql.types.{ArrayType, DateType, IntegerType, LongType}
import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions}
import org.opencypher.morpheus.api.io.GraphElement.sourceIdKey
import org.opencypher.morpheus.api.io.Relationship.{sourceEndNodeKey, sourceStartNodeKey}
import org.opencypher.morpheus.impl.table.SparkTable._
import org.opencypher.morpheus.integration.yelp.YelpConstants._

object YelpHelpers {

  case class YelpTables(
    userDf: DataFrame,
    businessDf: DataFrame,
    reviewDf: DataFrame
  )

  def loadYelpTables(inputPath: String)(implicit spark: SparkSession): YelpTables = {
    import spark.implicits._

    log("read business.json", 2)
    val rawBusinessDf = spark.read.json(s"$inputPath/business.json")
    log("read review.json", 2)
    val rawReviewDf = spark.read.json(s"$inputPath/review.json")
    log("read user.json", 2)
    val rawUserDf = spark.read.json(s"$inputPath/user.json")

    val businessDf = rawBusinessDf.select($"business_id".as(sourceIdKey), $"business_id", $"name", $"address", $"city", $"state")
    val reviewDf = rawReviewDf.select($"review_id".as(sourceIdKey), $"user_id".as(sourceStartNodeKey), $"business_id".as(sourceEndNodeKey), $"stars", $"date".cast(DateType))
    val userDf = rawUserDf.select(
      $"user_id".as(sourceIdKey),
      $"name",
      $"yelping_since".cast(DateType),
      functions.split($"elite", ",").cast(ArrayType(LongType)).as("elite"))

    YelpTables(userDf, businessDf, reviewDf)
  }

  def printYelpStats(inputPath: String)(implicit spark: SparkSession): Unit = {
    val rawBusinessDf = spark.read.json(s"$inputPath/business.json")
    val rawReviewDf = spark.read.json(s"$inputPath/review.json")

    import spark.implicits._

    rawBusinessDf.select($"city", $"state").distinct().show()
    rawBusinessDf.withColumnRenamed("business_id", "id")
      .join(rawReviewDf, $"id" === $"business_id")
      .groupBy($"city", $"state")
      .count().as("count")
      .orderBy($"count".desc, $"state".asc)
      .show(100)
  }

  def extractYelpCitySubset(inputPath: String, outputPath: String, city: String)(implicit spark: SparkSession): Unit = {
    import spark.implicits._

    def emailColumn(userId: String): Column = functions.concat($"$userId", functions.lit("@yelp.com"))

    val rawUserDf = spark.read.json(s"$inputPath/user.json")
    val rawReviewDf = spark.read.json(s"$inputPath/review.json")
    val rawBusinessDf = spark.read.json(s"$inputPath/business.json")

    val businessDf = rawBusinessDf.filter($"city" === city)
    val reviewDf = rawReviewDf
      .join(businessDf, Seq("business_id"), "left_semi")
      .withColumn("user_email", emailColumn("user_id"))
      .withColumnRenamed("stars", "stars_tmp")
      .withColumn("stars", $"stars_tmp".cast(IntegerType))
      .drop("stars_tmp")
    val userDf = rawUserDf
      .join(reviewDf, Seq("user_id"), "left_semi")
      .withColumn("email", emailColumn("user_id"))
    val friendDf = userDf
      .select($"email".as("user1_email"), functions.explode(functions.split($"friends", ", ")).as("user2_id"))
      .withColumn("user2_email", emailColumn("user2_id"))
      .select(s"user1_email", s"user2_email")

    businessDf.write.json(s"$outputPath/$cityGraphName/$yelpDB/business.json")
    reviewDf.write.json(s"$outputPath/$cityGraphName/$yelpDB/review.json")
    userDf.write.json(s"$outputPath/$cityGraphName/$yelpDB/user.json")
    friendDf.write.json(s"$outputPath/$cityGraphName/$yelpBookDB/friend.json")
  }

  implicit class DataFrameOps(df: DataFrame) {
    def prependIdColumn(idColumn: String, prefix: String): DataFrame =
      df.transformColumns(idColumn)(column => functions.concat(functions.lit(prefix), column).as(idColumn))
  }
} 
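
A sketch of driving the helpers above, assuming a local SparkSession and a directory containing the Yelp JSON dump in the layout the loader expects; the object name and input path are placeholders.

import org.apache.spark.sql.SparkSession
import org.opencypher.morpheus.integration.yelp.YelpHelpers

object YelpHelpersSketch {
  def main(args: Array[String]): Unit = {
    implicit val spark: SparkSession =
      SparkSession.builder().master("local[*]").appName("yelp-sketch").getOrCreate()

    val yelpPath = "/data/yelp" // placeholder: directory with business.json, review.json, user.json

    val tables = YelpHelpers.loadYelpTables(yelpPath)
    tables.reviewDf.printSchema()

    YelpHelpers.printYelpStats(yelpPath)

    spark.stop()
  }
}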
Example 98
Source File: Locus.scala    From hail   with MIT License 5 votes vote down vote up
package is.hail.variant

import is.hail.annotations.Annotation
import is.hail.check.Gen
import is.hail.expr.Parser
import is.hail.utils._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.json4s._

import scala.collection.JavaConverters._
import scala.language.implicitConversions

object Locus {
  val simpleContigs: Seq[String] = (1 to 22).map(_.toString) ++ Seq("X", "Y", "MT")

  def apply(contig: String, position: Int, rg: ReferenceGenome): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)
  }

  def annotation(contig: String, position: Int, rg: Option[ReferenceGenome]): Annotation = {
    rg match {
      case Some(ref) => Locus(contig, position, ref)
      case None => Annotation(contig, position)
    }
  }

  def sparkSchema: StructType =
    StructType(Array(
      StructField("contig", StringType, nullable = false),
      StructField("position", IntegerType, nullable = false)))

  def fromRow(r: Row): Locus = {
    Locus(r.getAs[String](0), r.getInt(1))
  }

  def gen(rg: ReferenceGenome): Gen[Locus] = for {
    (contig, length) <- Contig.gen(rg)
    pos <- Gen.choose(1, length)
  } yield Locus(contig, pos)

  def parse(str: String, rg: ReferenceGenome): Locus = {
    val elts = str.split(":")
    val size = elts.length
    if (size < 2)
      fatal(s"Invalid string for Locus. Expecting contig:pos -- found '$str'.")

    val contig = elts.take(size - 1).mkString(":")
    Locus(contig, elts(size - 1).toInt, rg)
  }

  def parseInterval(str: String, rg: ReferenceGenome, invalidMissing: Boolean = false): Interval =
    Parser.parseLocusInterval(str, rg, invalidMissing)

  def parseIntervals(arr: Array[String], rg: ReferenceGenome, invalidMissing: Boolean): Array[Interval] = arr.map(parseInterval(_, rg, invalidMissing))

  def parseIntervals(arr: java.util.List[String], rg: ReferenceGenome, invalidMissing: Boolean = false): Array[Interval] = parseIntervals(arr.asScala.toArray, rg, invalidMissing)

  def makeInterval(contig: String, start: Int, end: Int, includesStart: Boolean, includesEnd: Boolean,
    rgBase: ReferenceGenome, invalidMissing: Boolean = false): Interval = {
    val rg = rgBase.asInstanceOf[ReferenceGenome]
    rg.toLocusInterval(Interval(Locus(contig, start), Locus(contig, end), includesStart, includesEnd), invalidMissing)
  }
}

case class Locus(contig: String, position: Int) {
  def toRow: Row = Row(contig, position)

  def toJSON: JValue = JObject(
    ("contig", JString(contig)),
    ("position", JInt(position)))

  def copyChecked(rg: ReferenceGenome, contig: String = contig, position: Int = position): Locus = {
    rg.checkLocus(contig, position)
    Locus(contig, position)
  }

  def isAutosomalOrPseudoAutosomal(rg: ReferenceGenome): Boolean = isAutosomal(rg) || inXPar(rg) || inYPar(rg)

  def isAutosomal(rg: ReferenceGenome): Boolean = !(inX(rg) || inY(rg) || isMitochondrial(rg))

  def isMitochondrial(rg: ReferenceGenome): Boolean = rg.isMitochondrial(contig)

  def inXPar(rg: ReferenceGenome): Boolean = rg.inXPar(this)

  def inYPar(rg: ReferenceGenome): Boolean = rg.inYPar(this)

  def inXNonPar(rg: ReferenceGenome): Boolean = inX(rg) && !inXPar(rg)

  def inYNonPar(rg: ReferenceGenome): Boolean = inY(rg) && !inYPar(rg)

  private def inX(rg: ReferenceGenome): Boolean = rg.inX(contig)

  private def inY(rg: ReferenceGenome): Boolean = rg.inY(contig)

  override def toString: String = s"$contig:$position"
} 
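
A small sketch of the reference-genome-free parts of Locus shown above: construction without bounds checking, plus the Row and JSON conversions. The object name is illustrative.

import is.hail.variant.Locus

object LocusSketch {
  def main(args: Array[String]): Unit = {
    // The two-argument constructor skips ReferenceGenome bounds checking.
    val locus = Locus("1", 12345)

    println(locus)        // 1:12345
    println(locus.toRow)  // [1,12345]
    println(locus.toJSON) // json4s JValue with contig and position fields
  }
}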
Example 99
Source File: DataFrameComparisonTest.scala    From spark-tools   with Apache License 2.0 5 votes vote down vote up
package io.univalence.sparktest

import io.univalence.schema.SchemaComparator.SchemaError
import org.apache.spark.SparkContext
import org.apache.spark.sql.{ Row, SparkSession }
import org.apache.spark.sql.types.{ IntegerType, StructField, StructType }
import org.scalatest.FunSuite

class DataFrameComparisonTest extends FunSuite with SparkTest {

  val sharedSparkSession: SparkSession = ss
  val sc: SparkContext                 = ss.sparkContext

  // TODO : unordered
  ignore("should assertEquals unordered between equal DF") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(3, 2, 1).toDF("id")

    dfUT.assertEquals(dfExpected)
  }

  // TODO : unordered
  ignore("should not assertEquals unordered between DF with different contents") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(2, 1, 4).toDF("id")

    assertThrows[SparkTestError] {
      dfUT.assertEquals(dfExpected)
    }
  }

  test("should assertEquals ordered between equal DF") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 2, 3).toDF("id")

    dfUT.assertEquals(dfExpected)
  }

  test("should not assertEquals ordered between DF with different contents") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 3, 4).toDF("id")

    assertThrows[SparkTestError] {
      dfUT.assertEquals(dfExpected)
    }
  }

  test("should not assertEquals between DF with different schema") {
    val dfUT       = Seq(1, 2, 3).toDF("id")
    val dfExpected = Seq(1, 2, 3).toDF("di")

    assertThrows[SchemaError] {
      dfUT.assertEquals(dfExpected)
    }
  }

  test("assertEquals (DF & Seq) : a DF and a Seq with the same content are equal") {
    val seq = Seq(1, 2, 3)
    val df = ss.createDataFrame(
      sc.parallelize(seq.map(Row(_))),
      StructType(List(StructField("number", IntegerType, nullable = true)))
    )

    df.assertEquals(seq)
  }

  test("assertEquals (DF & Seq) : a DF and a Seq with different content are not equal") {
    val df    = Seq(1, 3, 3).toDF("number")
    val seqEx = Seq(1, 2, 3)

    assertThrows[SparkTestError] {
      df.assertEquals(seqEx)
    }
  }

  test("should assertEquals ordered between equal DF with columns containing special character") {
    val dfUT       = Seq(1, 2, 3).toDF("id.a")
    val dfExpected = Seq(2, 1, 4).toDF("id.a")

    assertThrows[SparkTestError] {
      dfUT.assertEquals(dfExpected)
    }
  }

  

} 
Example 100
Source File: MetastoreRelationSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown
    relation.toJSON
  }

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
      }
    }
  }
} 
Example 101
Source File: SubstituteUnresolvedOrdinals.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.types.IntegerType


class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
} 
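
This rule replaces integer literals in GROUP BY and ORDER BY clauses with UnresolvedOrdinal placeholders, gated by the orderByOrdinal and groupByOrdinal settings. A hedged end-user sketch of the behavior it enables (illustrative only; assumes the default values of spark.sql.orderByOrdinal and spark.sql.groupByOrdinal, which are true):

import org.apache.spark.sql.SparkSession

object OrdinalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ordinals").getOrCreate()
    import spark.implicits._

    Seq(("a", 2), ("b", 1), ("a", 3)).toDF("k", "v").createOrReplaceTempView("t")

    // GROUP BY 1 and ORDER BY 2 refer to select-list positions; the analyzer
    // resolves them through UnresolvedOrdinal when the ordinal settings are enabled.
    spark.sql("SELECT k, sum(v) AS total FROM t GROUP BY 1 ORDER BY 2 DESC").show()

    spark.stop()
  }
}
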
Example 102
Source File: ResolveTableValuedFunctions.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}


      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))
      })
  )

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      builtinFunctions.get(u.functionName) match {
        case Some(tvf) =>
          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
                Some(resolver(casted.map(_.eval())))
              case _ =>
                None
            }
          }
          resolved.headOption.getOrElse {
            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
            u.failAnalysis(
              s"""error: table-valued function ${u.functionName} with alternatives:
                |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
                |cannot be applied to: (${argTypes})""".stripMargin)
          }
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
      }
  }
} 
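
The resolver above backs the built-in range table-valued function. A short usage sketch (illustrative only; assumes a Spark build that ships this rule):

import org.apache.spark.sql.SparkSession

object RangeTvfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("range-tvf").getOrCreate()
    // Resolved by the rule above into the logical Range(2, 10, 2, ...) operator;
    // produces the values 2, 4, 6, 8 in a single LongType column named `id`.
    spark.sql("SELECT * FROM range(2, 10, 2)").show()
    spark.stop()
  }
}
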
Example 103
Source File: LogicalPlanSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("resolveOperator runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan resolveOperators function

    assert(invocationCount === 1)
  }

  test("resolveOperator runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan resolveOperators function

    assert(invocationCount === 2)
  }

  test("resolveOperator skips all ready resolved plans") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan.foreach(_.setAnalyzed())
    plan resolveOperators function

    assert(invocationCount === 0)
  }

  test("resolveOperator skips partially resolved plans") {
    invocationCount = 0
    val plan1 = Project(Nil, testRelation)
    val plan2 = Project(Nil, plan1)
    plan1.foreach(_.setAnalyzed())
    plan2 resolveOperators function

    assert(invocationCount === 1)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = new LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)())) {
      override def isStreaming(): Boolean = true
    }

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }
} 
Example 104
Source File: MapDataSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

  test("inequality tests") {
    def u(str: String): UTF8String = UTF8String.fromString(str)

    // test data
    val testMap1 = Map(u("key1") -> 1)
    val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
    val testMap3 = Map(u("key1") -> 1)
    val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

    // ArrayBasedMapData
    val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
    val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
    val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
    val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
    assert(testArrayMap1 !== testArrayMap3)
    assert(testArrayMap2 !== testArrayMap4)

    // UnsafeMapData
    val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
    val row = new GenericInternalRow(1)
    def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
      row.update(0, map)
      val unsafeRow = unsafeConverter.apply(row)
      unsafeRow.getMap(0).copy
    }
    assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
    assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
  }
} 
Example 105
Source File: ScalaUDFSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase,
      StringType,
      Literal.create(null, StringType) :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvalutionWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

} 
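
ScalaUDF is Catalyst's internal expression wrapper; application code normally reaches it through the public UDF API instead. A hedged sketch of that path (not part of the original suite):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object PublicUdfSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("public-udf").getOrCreate()
    import spark.implicits._

    // Registered for SQL use; the analyzer wraps the function in a ScalaUDF expression.
    spark.udf.register("plusOne", (i: Int) => i + 1)
    spark.sql("SELECT plusOne(1) AS result").show()

    // Column-based variant of the same function.
    val plusOneCol = udf((i: Int) => i + 1)
    Seq(1, 2, 3).toDF("id").select(plusOneCol($"id").as("plus_one")).show()

    spark.stop()
  }
}
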
Example 106
Source File: ExpressionEvalHelperSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      s"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 107
Source File: ObjectExpressionsSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types.{IntegerType, ObjectType}


class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("SPARK-16622: The returned value of the called method in Invoke can be null") {
    val inputRow = InternalRow.fromSeq(Seq((false, null)))
    val cls = classOf[Tuple2[Boolean, java.lang.Integer]]
    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
    val invoke = Invoke(inputObject, "_2", IntegerType)
    checkEvaluationWithGeneratedMutableProjection(invoke, null, inputRow)
  }

  test("MapObjects should make copies of unsafe-backed data") {
    // test UnsafeRow-backed data
    val structEncoder = ExpressionEncoder[Array[Tuple2[java.lang.Integer, java.lang.Integer]]]
    val structInputRow = InternalRow.fromSeq(Seq(Array((1, 2), (3, 4))))
    val structExpected = new GenericArrayData(
      Array(InternalRow.fromSeq(Seq(1, 2)), InternalRow.fromSeq(Seq(3, 4))))
    checkEvalutionWithUnsafeProjection(
      structEncoder.serializer.head, structExpected, structInputRow)

    // test UnsafeArray-backed data
    val arrayEncoder = ExpressionEncoder[Array[Array[Int]]]
    val arrayInputRow = InternalRow.fromSeq(Seq(Array(Array(1, 2), Array(3, 4))))
    val arrayExpected = new GenericArrayData(
      Array(new GenericArrayData(Array(1, 2)), new GenericArrayData(Array(3, 4))))
    checkEvalutionWithUnsafeProjection(
      arrayEncoder.serializer.head, arrayExpected, arrayInputRow)

    // test UnsafeMap-backed data
    val mapEncoder = ExpressionEncoder[Array[Map[Int, Int]]]
    val mapInputRow = InternalRow.fromSeq(Seq(Array(
      Map(1 -> 100, 2 -> 200), Map(3 -> 300, 4 -> 400))))
    val mapExpected = new GenericArrayData(Seq(
      new ArrayBasedMapData(
        new GenericArrayData(Array(1, 2)),
        new GenericArrayData(Array(100, 200))),
      new ArrayBasedMapData(
        new GenericArrayData(Array(3, 4)),
        new GenericArrayData(Array(300, 400)))))
    checkEvalutionWithUnsafeProjection(
      mapEncoder.serializer.head, mapExpected, mapInputRow)
  }
} 
Example 108
Source File: CallMethodViaReflectionSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}


class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)
  }

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)
  }

  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))
  }

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("input type checking") {
    assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)
  }

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
  }

  private def createExpr(className: String, methodName: String, args: Any*) = {
    CallMethodViaReflection(
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
      args.map(Literal.apply)
    )
  }
} 
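
CallMethodViaReflection is the expression behind the SQL functions reflect and java_method. A brief usage sketch (illustrative only):

import org.apache.spark.sql.SparkSession

object ReflectSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("reflect").getOrCreate()
    // Invokes the static method java.util.UUID.randomUUID() and returns its string form.
    spark.sql("SELECT reflect('java.util.UUID', 'randomUUID') AS uuid").show(truncate = false)
    spark.stop()
  }
}
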
Example 109
Source File: AttributeSetSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.IntegerType

class AttributeSetSuite extends SparkFunSuite {

  val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
  val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
  val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
  val aSet = AttributeSet(aLower :: Nil)

  val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
  val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
  val bSet = AttributeSet(bUpper :: Nil)

  val aAndBSet = AttributeSet(aUpper :: bUpper :: Nil)

  test("sanity check") {
    assert(aUpper != aLower)
    assert(bUpper != bLower)
  }

  test("checks by id not name") {
    assert(aSet.contains(aUpper) === true)
    assert(aSet.contains(aLower) === true)
    assert(aSet.contains(fakeA) === false)

    assert(aSet.contains(bUpper) === false)
    assert(aSet.contains(bLower) === false)
  }

  test("++ preserves AttributeSet")  {
    assert((aSet ++ bSet).contains(aUpper) === true)
    assert((aSet ++ bSet).contains(aLower) === true)
  }

  test("extracts all references references") {
    val addSet = AttributeSet(Add(aUpper, Alias(bUpper, "test")()):: Nil)
    assert(addSet.contains(aUpper))
    assert(addSet.contains(aLower))
    assert(addSet.contains(bUpper))
    assert(addSet.contains(bLower))
  }

  test("dedups attributes") {
    assert(AttributeSet(aUpper :: aLower :: Nil).size === 1)
  }

  test("subset") {
    assert(aSet.subsetOf(aAndBSet) === true)
    assert(aAndBSet.subsetOf(aSet) === false)
  }

  test("equality") {
    assert(aSet != aAndBSet)
    assert(aAndBSet != aSet)
    assert(aSet != bSet)
    assert(bSet != aSet)

    assert(aSet == aSet)
    assert(aSet == AttributeSet(aUpper :: Nil))
  }
} 
Example 110
Source File: RewriteDistinctAggregatesSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.SimpleCatalystConf
import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{If, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectSet, Count}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  val conf = SimpleCatalystConf(caseSensitiveAnalysis = false, groupByOrdinal = false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with non-partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        CollectSet('b).toAggregateExpression().as('agg2))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
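
RewriteDistinctAggregates rewrites plans with more than one DISTINCT aggregate group into a single aggregation over an Expand node. A hedged end-user sketch of a query that triggers the rewrite (illustrative only):

import org.apache.spark.sql.SparkSession

object MultiDistinctSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("multi-distinct").getOrCreate()
    import spark.implicits._

    Seq(("a", 1, 10), ("a", 2, 10), ("b", 2, 20)).toDF("k", "x", "y").createOrReplaceTempView("t")

    // Two DISTINCT aggregates over different columns form two distinct groups,
    // so the optimized plan should contain an Expand operator under the aggregates.
    val df = spark.sql("SELECT k, count(DISTINCT x) AS dx, count(DISTINCT y) AS dy FROM t GROUP BY k")
    df.explain(true)
    df.show()

    spark.stop()
  }
}
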
Example 111
Source File: SimplifyConditionalSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{IntegerType, NullType}


class SimplifyConditionalSuite extends PlanTest with PredicateHelper {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("SimplifyConditionals", FixedPoint(50), SimplifyConditionals) :: Nil
  }

  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(actual, correctAnswer)
  }

  private val trueBranch = (TrueLiteral, Literal(5))
  private val normalBranch = (NonFoldableLiteral(true), Literal(10))
  private val unreachableBranch = (FalseLiteral, Literal(20))
  private val nullBranch = (Literal.create(null, NullType), Literal(30))

  test("simplify if") {
    assertEquivalent(
      If(TrueLiteral, Literal(10), Literal(20)),
      Literal(10))

    assertEquivalent(
      If(FalseLiteral, Literal(10), Literal(20)),
      Literal(20))

    assertEquivalent(
      If(Literal.create(null, NullType), Literal(10), Literal(20)),
      Literal(20))
  }

  test("remove unreachable branches") {
    // i.e. removing branches whose conditions are always false
    assertEquivalent(
      CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None),
      CaseWhen(normalBranch :: Nil, None))
  }

  test("remove entire CaseWhen if only the else branch is reachable") {
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))),
      Literal(30))

    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: Nil, None),
      Literal.create(null, IntegerType))
  }

  test("remove entire CaseWhen if the first branch is always true") {
    assertEquivalent(
      CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None),
      Literal(5))

    // Test branch elimination and simplification in combination
    assertEquivalent(
      CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch
        :: Nil, None),
      Literal(5))

    // Make sure this doesn't trigger if there is a non-foldable branch before the true branch
    assertEquivalent(
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None),
      CaseWhen(normalBranch :: trueBranch :: normalBranch :: Nil, None))
  }
} 
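
SimplifyConditionals removes branches whose predicates fold to constants. A small sketch of the observable effect (illustrative only; constant folding may also participate):

import org.apache.spark.sql.SparkSession

object SimplifyConditionalSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("simplify-conditionals").getOrCreate()
    // The first branch is always true, so the optimized plan should carry the literal 10
    // rather than the full CASE WHEN expression.
    val df = spark.sql("SELECT CASE WHEN true THEN 10 ELSE 20 END AS c")
    df.explain(true)
    df.show()   // prints a single row with c = 10
    spark.stop()
  }
}
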
Example 112
Source File: resources.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
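
ListJarsCommand serves the SQL statement LIST JARS. A minimal sketch (illustrative only; the list may be empty in a fresh local session):

import org.apache.spark.sql.SparkSession

object ListJarsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("list-jars").getOrCreate()
    // Shows jars registered via ADD JAR or --jars, one path per row in column `Results`.
    spark.sql("LIST JARS").show(truncate = false)
    spark.stop()
  }
}
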
Example 113
Source File: WholeStageCodegenSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {

  test("range/filter should be combined") {
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)
    assert(df.collect() === Array(Row(2)))
  }

  test("Aggregate should be included in WholeStageCodegen") {
    val df = spark.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(9, 4.5)))
  }

  test("Aggregate with grouping keys should be included in WholeStageCodegen") {
    val df = spark.range(3).groupBy("id").count().orderBy("id")
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
  }

  test("BroadcastHashJoin should be included in WholeStageCodegen") {
    val rdd = spark.sparkContext.makeRDD(Seq(Row(1, "1"), Row(1, "1"), Row(2, "2")))
    val schema = new StructType().add("k", IntegerType).add("v", StringType)
    val smallDF = spark.createDataFrame(rdd, schema)
    val df = spark.range(10).join(broadcast(smallDF), col("k") === col("id"))
    assert(df.queryExecution.executedPlan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[BroadcastHashJoinExec]).isDefined)
    assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
  }

  test("Sort should be included in WholeStageCodegen") {
    val df = spark.range(3, 0, -1).toDF().sort(col("id"))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]).isDefined)
    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
  }

  test("MapElements should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = spark.range(10).map(_.toString)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
    assert(ds.collect() === 0.until(10).map(_.toString).toArray)
  }

  test("typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 2, 4, 6, 8))
  }

  test("back-to-back typed filter should be included in WholeStageCodegen") {
    val ds = spark.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
      p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
    assert(ds.collect() === Array(0, 6))
  }

  test("simple typed UDAF should be included in WholeStageCodegen") {
    import testImplicits._

    val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS()
      .groupByKey(_._1).agg(typed.sum(_._2))

    val plan = ds.queryExecution.executedPlan
    assert(plan.find(p =>
      p.isInstanceOf[WholeStageCodegenExec] &&
        p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[HashAggregateExec]).isDefined)
    assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0)))
  }
} 
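
These tests look for WholeStageCodegenExec in the executed physical plan. In a shell the same check can be approximated with explain(), where operators fused into a codegen stage are prefixed with '*' in Spark 2.x plans. A hedged sketch (illustrative only):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.WholeStageCodegenExec

object CodegenInspectionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("codegen-inspection").getOrCreate()
    val df = spark.range(10).filter("id = 1").selectExpr("id + 1")

    // Human-readable plan; fused operators appear under a WholeStageCodegen stage.
    df.explain()

    // Programmatic check, mirroring the assertions in the suite above.
    val fused = df.queryExecution.executedPlan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined
    println(s"whole-stage codegen used: $fused")

    spark.stop()
  }
}
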
Example 114
Source File: GroupedIteratorSuite.scala    From drizzle-spark   with Apache License 2.0    5 votes
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
Example 115
Source File: CarbonDataFrameExample.scala    From CarbonDataLearning   with GNU General Public License v3.0    5 votes
package org.github.xubo245.carbonDataLearning.example

import org.apache.carbondata.examples.util.ExampleUtils
import org.apache.spark.sql.{SaveMode, SparkSession}

object CarbonDataFrameExample {

  def main(args: Array[String]) {
    val spark = ExampleUtils.createCarbonSession("CarbonDataFrameExample")
    exampleBody(spark)
    spark.close()
  }

  def exampleBody(spark : SparkSession): Unit = {
    // Writes Dataframe to CarbonData file:
    import spark.implicits._
    val df = spark.sparkContext.parallelize(1 to 100)
      .map(x => ("a" + x % 10, "b", x))
      .toDF("c1", "c2", "number")

    // Saves dataframe to carbondata file
    df.write
      .format("carbondata")
      .option("tableName", "carbon_df_table")
      .option("partitionColumns", "c1")  // a list of column names
      .mode(SaveMode.Overwrite)
      .save()

    spark.sql(""" SELECT * FROM carbon_df_table """).show()

    spark.sql("SHOW PARTITIONS carbon_df_table").show()

    // Specify schema
    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
    val customSchema = StructType(Array(
      StructField("c1", StringType),
      StructField("c2", StringType),
      StructField("number", IntegerType)))

    // Reads carbondata to dataframe
    val carbondf = spark.read
      .format("carbondata")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "carbon_df_table")
      .load()

    df.write
      .format("csv")
      .option("tableName", "csv_df_table")
      .option("partitionColumns", "c1") // a list of column names
      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .mode(SaveMode.Overwrite)
      .csv("/Users/xubo/Desktop/xubo/git/carbondata3/examples/spark2/target/csv/1.csv")

    // Reads the csv files back into a dataframe
    val carbondf2 = spark.read
      .format("csv")
      .schema(customSchema)
      // .option("dbname", "db_name") the system will use "default" as dbname if not set this option
      .option("tableName", "csv_df_table")

      //      .option("timestampFormat", "yyyy/MM/dd HH:mm:ss ZZ")
      .load("/Users/xubo/Desktop/xubo/git/carbondata3/examples/spark2/target/csv")

    carbondf2.show()


    // Dataframe operations
    carbondf.printSchema()
    carbondf.select($"c1", $"number" + 10).show()
    carbondf.filter($"number" > 31).show()

    spark.sql("DROP TABLE IF EXISTS carbon_df_table")
  }
} 
Example 116
Source File: TestMetadataConstructor.scala    From spark-salesforce   with Apache License 2.0    5 votes
package com.springml.spark.salesforce.metadata

import org.apache.spark.sql.types.{StructType, StringType, IntegerType, LongType,
  FloatType, DateType, TimestampType, BooleanType, StructField}
import org.scalatest.FunSuite
import com.springml.spark.salesforce.Utils


class TestMetadataConstructor extends FunSuite {

  test("Test Metadata generation") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val columnStruct = columnNames.map(colName => StructField(colName, StringType, true))
    val schema = StructType(columnStruct)

    val schemaString = MetadataConstructor.generateMetaString(schema,"sampleDataSet", Utils.metadataConfig(null))
    assert(schemaString.length > 0)
    assert(schemaString.contains("sampleDataSet"))
  }

  test("Test Metadata generation With Custom MetadataConfig") {
    val columnNames = List("c1", "c2", "c3", "c4")
    val intField = StructField("intCol", IntegerType, true)
    val longField = StructField("longCol", LongType, true)
    val floatField = StructField("floatCol", FloatType, true)
    val dateField = StructField("dateCol", DateType, true)
    val timestampField = StructField("timestampCol", TimestampType, true)
    val stringField = StructField("stringCol", StringType, true)
    val someTypeField = StructField("someTypeCol", BooleanType, true)

    val columnStruct = Array[StructField] (intField, longField, floatField, dateField, timestampField, stringField, someTypeField)

    val schema = StructType(columnStruct)

    var metadataConfig = Map("string" -> Map("wave_type" -> "Text"))
    metadataConfig += ("integer" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "0", "defaultValue" -> "100"))
    metadataConfig += ("float" -> Map("wave_type" -> "Numeric", "precision" -> "10", "scale" -> "2"))
    metadataConfig += ("long" -> Map("wave_type" -> "Numeric", "precision" -> "18", "scale" -> "0"))
    metadataConfig += ("date" -> Map("wave_type" -> "Date", "format" -> "yyyy/MM/dd"))
    metadataConfig += ("timestamp" -> Map("wave_type" -> "Date", "format" -> "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))


    val schemaString = MetadataConstructor.generateMetaString(schema, "sampleDataSet", metadataConfig)
    assert(schemaString.length > 0)
    assert(schemaString.contains("sampleDataSet"))
    assert(schemaString.contains("Numeric"))
    assert(schemaString.contains("precision"))
    assert(schemaString.contains("scale"))
    assert(schemaString.contains("18"))
    assert(schemaString.contains("Text"))
    assert(schemaString.contains("Date"))
    assert(schemaString.contains("format"))
    assert(schemaString.contains("defaultValue"))
    assert(schemaString.contains("100"))
    assert(schemaString.contains("yyyy/MM/dd"))
    assert(schemaString.contains("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"))
  }
} 
Example 117
Source File: SparkScoreDoc.scala    From spark-lucenerdd   with Apache License 2.0    5 votes
package org.zouzias.spark.lucenerdd.models

import org.apache.lucene.document.Document
import org.apache.lucene.index.IndexableField
import org.apache.lucene.search.{IndexSearcher, ScoreDoc}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.inferNumericType
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.{DocIdField, ScoreField, ShardField}

import scala.collection.JavaConverters._

sealed trait FieldType extends Serializable
object TextType extends FieldType
object IntType extends FieldType
object DoubleType extends FieldType
object LongType extends FieldType
object FloatType extends FieldType



  private def inferNumericType(num: Number): FieldType = {
    num match {
      case _: java.lang.Double => DoubleType
      case _: java.lang.Long => LongType
      case _: java.lang.Integer => IntType
      case _: java.lang.Float => FloatType
      case _ => TextType
    }
  }
} 
Example 118
Source File: SummarizeSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ LongType, IntegerType, DoubleType }

class SummarizeSpec extends MultiPartitionSuite {

  override val defaultResourceDir: String = "/timeseries/summarize"

  it should "`summarize` correctly" in {
    val expectedSchema = Schema("volume_sum" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 7800.0), expectedSchema))

    def test(rdd: TimeSeriesRDD): Unit = {
      val results = rdd.summarize(Summarizers.sum("volume"))
      assert(results.schema == expectedSchema)
      assert(results.collect().deep == expectedResults.deep)
    }

    {
      val volumeRdd = fromCSV("Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType))
      withPartitionStrategy(volumeRdd)(DEFAULT)(test)
    }

  }

  it should "`summarize` per key correctly" in {
    val expectedSchema = Schema("id" -> IntegerType, "volume_sum" -> DoubleType)
    val expectedResults = Array[Row](
      new GenericRowWithSchema(Array(0L, 7, 4100.0), expectedSchema),
      new GenericRowWithSchema(Array(0L, 3, 3700.0), expectedSchema)
    )

    def test(rdd: TimeSeriesRDD): Unit = {
      val results = rdd.summarize(Summarizers.sum("volume"), Seq("id"))
      assert(results.schema == expectedSchema)
      assert(results.collect().sortBy(_.getAs[Int]("id")).deep == expectedResults.sortBy(_.getAs[Int]("id")).deep)
    }

    {
      val volumeTSRdd = fromCSV("Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType))
      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }
} 
Example 119
Source File: SummarizeCyclesSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, IntegerType, LongType }

class SummarizeCyclesSpec extends MultiPartitionSuite with TimeSeriesTestData with TimeTypeSuite {

  override val defaultResourceDir: String = "/timeseries/summarizecycles"
  private val volumeSchema = Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)
  private val volume2Schema = Schema("id" -> IntegerType, "volume" -> LongType)
  private val volumeWithGroupSchema = Schema(
    "id" -> IntegerType, "group" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType
  )

  "SummarizeCycles" should "pass `SummarizeSingleColumn` test." in {
    withAllTimeType {
      val resultTSRdd = fromCSV("SummarizeSingleColumn.results", Schema("volume_sum" -> DoubleType))

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
      }

      val volumeTSRdd = fromCSV("Volume.csv", volumeSchema)
      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }

  it should "pass `SummarizeSingleColumnPerKey` test, i.e. with additional a single key." in {
    withAllTimeType {
      val resultTSRdd = fromCSV(
        "SummarizeSingleColumnPerKey.results",
        Schema("id" -> IntegerType, "volume_sum" -> DoubleType)
      )

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"), Seq("id"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
      }

      val volumeTSRdd = fromCSV("Volume2.csv", volume2Schema)
      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }

  it should "pass `SummarizeSingleColumnPerSeqOfKeys` test, i.e. with additional a sequence of keys." in {
    withAllTimeType {
      val resultTSRdd = fromCSV(
        "SummarizeSingleColumnPerSeqOfKeys.results",
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume_sum" -> DoubleType)
      )

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeCycles(Summarizers.sum("volume"), Seq("id", "group"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
      }

      val volumeTSRdd = fromCSV("VolumeWithIndustryGroup.csv", volumeWithGroupSchema)
      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }

  it should "pass generated cycle data test" in {
    // TODO: The way cycleData works now doesn't support changing time type.
    val testData = cycleData1

    def sum(rdd: TimeSeriesRDD): TimeSeriesRDD = {
      rdd.summarizeCycles(Summarizers.compose(Summarizers.count(), Summarizers.sum("v1")))
    }

    withPartitionStrategyCompare(testData)(DEFAULT)(sum)
  }
} 
Example 120
Source File: TimeSeriesRDDCacheSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }
import org.scalatest.concurrent.Timeouts
import org.scalatest.tagobjects.Slow
import org.scalatest.time.{ Second, Span }

class TimeSeriesRDDCacheSpec extends TimeSeriesSuite with Timeouts {

  "TimeSeriesRDD" should "correctly cache data" taggedAs Slow in {
    withResource("/timeseries/csv/Price.csv") { source =>
      val priceSchema = Schema("id" -> IntegerType, "price" -> DoubleType)
      val timeSeriesRdd = CSV.from(sqlContext, "file://" + source, sorted = true, schema = priceSchema)

      val slowTimeSeriesRdd = timeSeriesRdd.addColumns("new_column" -> DoubleType -> {
        row: Row =>
          Thread.sleep(500L)
          row.getAs[Double]("price") + 1.0
      })

      // run a dummy addColumns() to initialize TSRDD's internal state
      slowTimeSeriesRdd.addColumns("foo_column" -> DoubleType -> { _ => 1.0 })

      slowTimeSeriesRdd.cache()
      assert(slowTimeSeriesRdd.count() == 12)

      // this test succeeds only if all representations are correctly cached
      failAfter(Span(1, Second)) {
        assert(slowTimeSeriesRdd.toDF.collect().length == 12)
        assert(slowTimeSeriesRdd.orderedRdd.count() == 12)
        assert(slowTimeSeriesRdd.asInstanceOf[TimeSeriesRDDImpl].unsafeOrderedRdd.count == 12)
      }
    }
  }
} 
Example 121
Source File: CompositeSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.{ CSV, Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructType }

class CompositeSummarizerSpec extends SummarizerSuite {
  // Reuse mean summarizer data
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  var priceTSRdd: TimeSeriesRDD = _

  lazy val init: Unit = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
  }

  "CompositeSummarizer" should "compute `mean` and `stddev` correctly" in {
    init
    val result = priceTSRdd.summarize(
      Summarizers.compose(Summarizers.mean("price"), Summarizers.stddev("price"))
    )
    val row = result.first()

    assert(row.getAs[Double]("price_mean") === 3.25)
    assert(row.getAs[Double]("price_stddev") === 1.8027756377319946)
  }

  it should "throw exception for conflicting output columns" in {
    init
    intercept[Exception] {
      priceTSRdd.summarize(Summarizers.compose(Summarizers.mean("price"), Summarizers.mean("price")))
    }
  }

  it should "handle conflicting output columns using prefix" in {
    init
    val result = priceTSRdd.summarize(
      Summarizers.compose(Summarizers.mean("price"), Summarizers.mean("price").prefix("prefix"))
    )

    val row = result.first()

    assert(row.getAs[Double]("price_mean") === 3.25)
    assert(row.getAs[Double]("prefix_price_mean") === 3.25)
  }

  it should "handle null values" in {
    init
    val inputWithNull = insertNullRows(priceTSRdd, "price")
    val row = inputWithNull.summarize(
      Summarizers.compose(
        Summarizers.count(),
        Summarizers.count("id"),
        Summarizers.count("price")
      )
    ).first()

    val count = priceTSRdd.count()
    assert(row.getAs[Long]("count") == 2 * count)
    assert(row.getAs[Long]("id_count") == 2 * count)
    assert(row.getAs[Long]("price_count") == count)
  }
} 
Example 122
Source File: MeanSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesSuite }
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class MeanSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  "MeanSummarizer" should "compute `mean` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val result = priceTSRdd.summarize(Summarizers.mean("price")).first()
    assert(result.getAs[Double]("price_mean") === 3.25)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.mean("price")),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.mean("price"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.mean("x1"))
    summarizerPropertyTest(AllProperties)(Summarizers.mean("x2"))
  }
} 
Example 123
Source File: ExtremeSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.rdd.function.summarize.summarizer.Summarizer
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.{ SummarizerFactory, SummarizerSuite }
import com.twosigma.flint.timeseries.{ CSV, Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import org.apache.spark.sql.types.{ DataType, DoubleType, FloatType, IntegerType, LongType, StructType }
import java.util.Random

import org.apache.spark.sql.Row

class ExtremeSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  private def test[T](
    dataType: DataType,
    randValue: Row => Any,
    summarizer: String => SummarizerFactory,
    reduceFn: (T, T) => T,
    inputColumn: String,
    outputColumn: String
  ): Unit = {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      inputColumn -> dataType -> randValue
    )

    val data = priceTSRdd.collect().map{ row => row.getAs[T](inputColumn) }

    val trueExtreme = data.reduceLeft[T]{ case (x, y) => reduceFn(x, y) }

    val result = priceTSRdd.summarize(summarizer(inputColumn))

    val extreme = result.first().getAs[T](outputColumn)
    val outputType = result.schema(outputColumn).dataType

    assert(outputType == dataType, s"$outputType")
    assert(trueExtreme === extreme, s"extreme: $extreme, trueExtreme: $trueExtreme, data: ${data.toSeq}")
  }

  "MaxSummarizer" should "compute double max correctly" in {
    val rand = new Random()
    test[Double](DoubleType, { _: Row => rand.nextDouble() }, Summarizers.max, math.max, "x", "x_max")
  }

  it should "compute long max correctly" in {
    val rand = new Random()
    test[Long](LongType, { _: Row => rand.nextLong() }, Summarizers.max, math.max, "x", "x_max")
  }

  it should "compute float max correctly" in {
    val rand = new Random()
    test[Float](FloatType, { _: Row => rand.nextFloat() }, Summarizers.max, math.max, "x", "x_max")
  }

  it should "compute int max correctly" in {
    val rand = new Random()
    test[Int](IntegerType, { _: Row => rand.nextInt() }, Summarizers.max, math.max, "x", "x_max")
  }

  "MinSummarizer" should "compute double min correctly" in {
    val rand = new Random()
    test[Double](DoubleType, { _: Row => rand.nextDouble() }, Summarizers.min, math.min, "x", "x_min")
  }

  it should "compute long min correctly" in {
    val rand = new Random()
    test[Long](LongType, { _: Row => rand.nextLong() }, Summarizers.min, math.min, "x", "x_min")
  }

  it should "compute float min correctly" in {
    val rand = new Random()
    test[Float](FloatType, { _: Row => rand.nextFloat() }, Summarizers.min, math.min, "x", "x_min")
  }

  it should "compute int min correctly" in {
    val rand = new Random()
    test[Int](IntegerType, { _: Row => rand.nextInt() }, Summarizers.min, math.min, "x", "x_min")
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.max("x1"))
    summarizerPropertyTest(AllProperties)(Summarizers.min("x2"))
  }

  it should "ignore null values" in {
    val input = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val inputWithNull = insertNullRows(input, "price")

    assertEquals(
      input.summarize(Summarizers.min("price")),
      inputWithNull.summarize(Summarizers.min("price"))
    )
  }
} 
Example 124
Source File: GeometricMeanSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.{ Summarizers, Windows }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class GeometricMeanSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/geometricmeansummarizer"

  "GeometricMeanSummarizer" should "compute geometric mean correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    val results = priceTSRdd.summarize(Summarizers.geometricMean("price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_geometricMean") === 2.621877636494)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_geometricMean") === 2.667168275340)
  }

  it should "compute geometric mean with a zero correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    var results = priceTSRdd.summarize(Summarizers.geometricMean("priceWithZero")).collect()
    assert(results.head.getAs[Double]("priceWithZero_geometricMean") === 0.0)

    // Test that having a zero exit the window still computes correctly.
    results = priceTSRdd.coalesce(1).summarizeWindows(
      Windows.pastAbsoluteTime("50 ns"),
      Summarizers.geometricMean("priceWithZero")
    ).collect()
    assert(results.head.getAs[Double]("priceWithZero_geometricMean") === 0.0)
    assert(results.last.getAs[Double]("priceWithZero_geometricMean") === 5.220043408524)
  }

  it should "compute geometric mean with negative values correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    val results = priceTSRdd.summarize(Summarizers.geometricMean("priceWithNegatives"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("priceWithNegatives_geometricMean")
      === -2.621877636494)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("priceWithNegatives_geometricMean")
      === 2.667168275340)
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.geometricMean("x1"))
  }
} 
Example 125
Source File: DotProductSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.Summarizers
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class DotProductSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/dotproductsummarizer"

  "DotProductSummarizer" should "compute dot product correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.dotProduct("price", "price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price_dotProduct") === 72.25)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price_dotProduct") === 90.25)
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.dotProduct("x1", "x2"))
  }
} 
Example 126
Source File: ProductSummarizerSpec.scala    From flint   with Apache License 2.0    5 votes
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.{ Summarizers, Windows }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class ProductSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/productsummarizer"

  "ProductSummarizer" should "compute product correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    val results = priceTSRdd.summarize(Summarizers.product("price"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_product") === 324.84375)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_product") === 360.0)
  }

  it should "compute product with a zero correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    var results = priceTSRdd.summarize(Summarizers.product("priceWithZero")).collect()
    assert(results.head.getAs[Double]("priceWithZero_product") === 0.0)

    // Test that having a zero exit the window still computes correctly.
    results = priceTSRdd.coalesce(1).summarizeWindows(
      Windows.pastAbsoluteTime("50 ns"),
      Summarizers.product("priceWithZero")
    ).collect()
    assert(results.head.getAs[Double]("priceWithZero_product") === 0.0)
    assert(results.last.getAs[Double]("priceWithZero_product") === 742.5)
  }

  it should "compute product with negative values correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema(
      "id" -> IntegerType,
      "price" -> DoubleType,
      "priceWithZero" -> DoubleType,
      "priceWithNegatives" -> DoubleType
    ))
    val results = priceTSRdd.summarize(Summarizers.product("priceWithNegatives"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("priceWithNegatives_product") === -324.84375)
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("priceWithNegatives_product") === 360.0)
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.product("x1"))
  }
} 
Example 127
Source File: StandardizedMomentSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class StandardizedMomentSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/standardizedmomentsummarizer"

  "SkewnessSummarizer" should "compute skewness correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.skewness("price"))
    assert(results.collect().head.getAs[Double]("price_skewness") === 0.0)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.skewness("price")),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.skewness("price"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.skewness("x1"))
  }

  "KurtosisSummarizer" should "compute kurtosis correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val results = priceTSRdd.summarize(Summarizers.kurtosis("price"))
    assert(results.collect().head.getAs[Double]("price_kurtosis") === -1.2167832167832167)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.kurtosis("price")),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.kurtosis("price"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.kurtosis("x1"))
  }
} 
Example 128
Source File: ZScoreSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer.subtractable

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class ZScoreSummarizerSpec extends SummarizerSuite {

  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/zscoresummarizer"

  "ZScoreSummarizer" should "compute in-sample `zScore` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val expectedSchema = Schema("price_zScore" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 1.5254255396193801), expectedSchema))
    val results = priceTSRdd.summarize(Summarizers.zScore("price", true))
    assert(results.schema == expectedSchema)
    assert(results.collect().deep == expectedResults.deep)
  }

  it should "compute out-of-sample `zScore` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    val expectedSchema = Schema("price_zScore" -> DoubleType)
    val expectedResults = Array[Row](new GenericRowWithSchema(Array(0L, 1.8090680674665818), expectedSchema))
    val results = priceTSRdd.summarize(Summarizers.zScore("price", false))
    assert(results.schema == expectedSchema)
    assert(results.collect().deep == expectedResults.deep)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.zScore("price", true)),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.zScore("price", true))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.zScore("x1", true))
    summarizerPropertyTest(AllPropertiesAndSubtractable)(Summarizers.zScore("x2", false))
  }
} 
Example 129
Source File: StandardDeviationSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class StandardDeviationSummarizerSpec extends SummarizerSuite {
  // The files from the mean summarizer are intentionally reused.
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  "StandardDeviationSummarizer" should "compute `stddev` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }
    )

    val result = priceTSRdd.summarize(Summarizers.stddev("price")).first()
    assert(result.getAs[Double]("price_stddev") === 1.802775638)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.stddev("price")),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.stddev("price"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.stddev("x1"))
  }
} 
Example 130
Source File: PredicateSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesRDD }
import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class PredicateSummarizerSpec extends SummarizerSuite {
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"
  var priceTSRdd: TimeSeriesRDD = _

  private lazy val init = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
  }

  "PredicateSummarizer" should "return the same results as filtering TSRDD first" in {
    init
    val summarizer = Summarizers.compose(Summarizers.mean("price"), Summarizers.stddev("price"))

    val predicate: Int => Boolean = id => id == 3
    val resultWithPredicate = priceTSRdd.summarize(summarizer.where(predicate)("id")).first()

    val filteredTSRDD = priceTSRdd.keepRows {
      row: Row => row.getAs[Int]("id") == 3
    }
    val filteredResults = filteredTSRDD.summarize(summarizer).first()

    assert(resultWithPredicate.getAs[Double]("price_mean") === filteredResults.getAs[Double]("price_mean"))
    assert(resultWithPredicate.getAs[Double]("price_stddev") === filteredResults.getAs[Double]("price_stddev"))

    assertEquals(
      priceTSRdd.summarize(summarizer.where(predicate)("id")),
      insertNullRows(priceTSRdd, "price").summarize(summarizer.where(predicate)("id"))
    )
  }

  it should "pass summarizer property test" in {
    val predicate: Double => Boolean = num => num > 0
    summarizerPropertyTest(AllProperties)(Summarizers.sum("x1").where(predicate)("x2"))
  }
} 
Example 131
Source File: VarianceSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.Summarizers
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class VarianceSummarizerSpec extends SummarizerSuite {
  // The files from the mean summarizer are intentionally reused.
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/meansummarizer"

  "StandardDeviationSummarizer" should "compute `stddev` correctly" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType)).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }
    )

    val result = priceTSRdd.summarize(Summarizers.variance("price")).first()
    assert(result.getAs[Double]("price_variance") === 3.250000000)
  }

  it should "ignore null values" in {
    val priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    assertEquals(
      priceTSRdd.summarize(Summarizers.variance("price")),
      insertNullRows(priceTSRdd, "price").summarize(Summarizers.variance("price"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.variance("x1"))
  }
} 
Example 132
Source File: CovarianceSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries.summarize.summarizer

import com.twosigma.flint.timeseries.row.Schema
import com.twosigma.flint.timeseries.summarize.SummarizerSuite
import com.twosigma.flint.timeseries.{ Summarizers, TimeSeriesRDD, TimeSeriesSuite }
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class CovarianceSummarizerSpec extends SummarizerSuite {

  // The files from the correlation summarizer are intentionally reused.
  override val defaultResourceDir: String = "/timeseries/summarize/summarizer/correlationsummarizer"

  private var priceTSRdd: TimeSeriesRDD = null
  private var forecastTSRdd: TimeSeriesRDD = null
  private var input: TimeSeriesRDD = null

  private lazy val init: Unit = {
    priceTSRdd = fromCSV("Price.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
    forecastTSRdd = fromCSV("Forecast.csv", Schema("id" -> IntegerType, "forecast" -> DoubleType))
    input = priceTSRdd.leftJoin(forecastTSRdd, key = Seq("id")).addColumns(
      "price2" -> DoubleType -> { r: Row => r.getAs[Double]("price") },
      "price3" -> DoubleType -> { r: Row => -r.getAs[Double]("price") },
      "price4" -> DoubleType -> { r: Row => r.getAs[Double]("price") * 2 },
      "price5" -> DoubleType -> { r: Row => 0d }
    )
  }

  "CovarianceSummarizer" should "`computeCovariance` correctly" in {
    init
    var results = input.summarize(Summarizers.covariance("price", "price2"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price2_covariance") === 3.368055556)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price2_covariance") === 2.534722222)

    results = input.summarize(Summarizers.covariance("price", "price3"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price3_covariance") === -3.368055556)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price3_covariance") === -2.534722222)

    results = input.summarize(Summarizers.covariance("price", "price4"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price4_covariance") === 6.736111111)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price4_covariance") === 5.069444444)

    results = input.summarize(Summarizers.covariance("price", "price5"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_price5_covariance") === 0d)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_price5_covariance") === 0d)

    results = input.summarize(Summarizers.covariance("price", "forecast"), Seq("id")).collect()
    assert(results.find(_.getAs[Int]("id") == 7).head.getAs[Double]("price_forecast_covariance") === -0.190277778)
    assert(results.find(_.getAs[Int]("id") == 3).head.getAs[Double]("price_forecast_covariance") === -3.783333333)
  }

  it should "ignore null values" in {
    init
    val inputWithNull = insertNullRows(input, "price", "forecast")

    assertEquals(
      inputWithNull.summarize(Summarizers.covariance("price", "forecast")),
      input.summarize(Summarizers.covariance("price", "forecast"))
    )

    assertEquals(
      inputWithNull.summarize(Summarizers.covariance("price", "forecast"), Seq("id")),
      input.summarize(Summarizers.covariance("price", "forecast"), Seq("id"))
    )
  }

  it should "pass summarizer property test" in {
    summarizerPropertyTest(AllProperties)(Summarizers.covariance("x1", "x2"))
    summarizerPropertyTest(AllProperties)(Summarizers.covariance("x0", "x3"))
  }
} 
Example 133
Source File: SummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class SummarizerSpec extends TimeSeriesSuite {

  "SummarizerFactory" should "support alias." in {
    withResource("/timeseries/csv/Price.csv") { source =>
      val expectedSchema = Schema("C1" -> IntegerType, "C2" -> DoubleType)
      val timeseriesRdd = CSV.from(sqlContext, "file://" + source, sorted = true, schema = expectedSchema)
      assert(timeseriesRdd.schema == expectedSchema)
      val result: Row = timeseriesRdd.summarize(Summarizers.count().prefix("alias")).first()
      assert(result.getAs[Long]("alias_count") == timeseriesRdd.count())
    }
  }
} 
Example 134
Source File: SummarizeIntervalsSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, LongType, IntegerType }

class SummarizeIntervalsSpec extends MultiPartitionSuite with TimeSeriesTestData with TimeTypeSuite {

  override val defaultResourceDir: String = "/timeseries/summarizeintervals"

  "SummarizeInterval" should "pass `SummarizeSingleColumn` test." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        "Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)
      )

      volumeTSRdd.toDF.show()

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV("SummarizeSingleColumn.results", Schema("volume_sum" -> DoubleType))

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("volume"))
        summarizedVolumeTSRdd.toDF.show()
        assert(summarizedVolumeTSRdd.collect().deep == resultTSRdd.collect().deep)
      }

      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }

  it should "pass `SummarizeSingleColumnPerKey` test, i.e. with additional a single key." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        "Volume.csv", Schema("id" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)
      )

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV(
        "SummarizeSingleColumnPerKey.results",
        Schema("id" -> IntegerType, "volume_sum" -> DoubleType)
      )

      val result2TSRdd = fromCSV(
        "SummarizeV2PerKey.results",
        Schema("id" -> IntegerType, "v2_sum" -> DoubleType)
      )

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("volume"), Seq("id"))
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
        val summarizedV2TSRdd = rdd.summarizeIntervals(clockTSRdd, Summarizers.sum("v2"), Seq("id"))
        assertEquals(summarizedV2TSRdd, result2TSRdd)
      }

      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }

  it should "pass `SummarizeSingleColumnPerSeqOfKeys` test, i.e. with additional a sequence of keys." in {
    withAllTimeType {
      val volumeTSRdd = fromCSV(
        "VolumeWithIndustryGroup.csv",
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume" -> LongType, "v2" -> DoubleType)
      )

      val clockTSRdd = fromCSV("Clock.csv", Schema())
      val resultTSRdd = fromCSV(
        "SummarizeSingleColumnPerSeqOfKeys.results",
        Schema("id" -> IntegerType, "group" -> IntegerType, "volume_sum" -> DoubleType)
      )

      def test(rdd: TimeSeriesRDD): Unit = {
        val summarizedVolumeTSRdd = rdd.summarizeIntervals(
          clockTSRdd,
          Summarizers.sum("volume"),
          Seq("id", "group")
        )
        assertEquals(summarizedVolumeTSRdd, resultTSRdd)
      }

      withPartitionStrategy(volumeTSRdd)(DEFAULT)(test)
    }
  }
} 
Example 135
Source File: MergeSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

class MergeSpec extends MultiPartitionSuite with TimeSeriesTestData {
  override val defaultResourceDir: String = "/timeseries/merge"

  "Merge" should "pass `Merge` test." in {
    val resultsTSRdd = fromCSV("Merge.results", Schema("id" -> IntegerType, "price" -> DoubleType))

    def test(rdd1: TimeSeriesRDD, rdd2: TimeSeriesRDD): Unit = {
      val mergedTSRdd = rdd1.merge(rdd2)
      assert(resultsTSRdd.schema == mergedTSRdd.schema)
      assert(resultsTSRdd.collect().deep == mergedTSRdd.collect().deep)
    }

    {
      val priceTSRdd1 = fromCSV("Price1.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      val priceTSRdd2 = fromCSV("Price2.csv", Schema("id" -> IntegerType, "price" -> DoubleType))
      withPartitionStrategy(priceTSRdd1, priceTSRdd2)(DEFAULT)(test)
    }
  }

  it should "pass generated cycle data test" in {
    val testData1 = cycleData1
    val testData2 = cycleData2

    def merge(rdd1: TimeSeriesRDD, rdd2: TimeSeriesRDD): TimeSeriesRDD = {
      rdd1.merge(rdd2)
    }
    withPartitionStrategyCompare(testData1, testData2)(ALL)(merge)
  }
} 
Example 136
Source File: BasicDataSourceSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class BasicDataSourceSuite extends BaseDataSourceTest("test_datasource_basic") {
  private val row1 = Row(null, "Hello")
  private val row2 = Row(2, "TiDB")
  private val row3 = Row(3, "Spark")
  private val row4 = Row(4, null)

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  override def beforeAll(): Unit = {
    super.beforeAll()

    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello'), (2, 'TiDB')")
  }

  test("Test Select") {
    if (!supportBatchWrite) {
      cancel
    }

    testTiDBSelect(Seq(row1, row2))
  }

  test("Test Write Append") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    df.write
      .format("tidb")
      .options(tidbOptions)
      .option("database", database)
      .option("table", table)
      .mode("append")
      .save()

    testTiDBSelect(Seq(row1, row2, row3, row4))
  }

  test("Test Write Overwrite") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    val caught = intercept[TiBatchWriteException] {
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("database", database)
        .option("table", table)
        .mode("overwrite")
        .save()
    }

    assert(
      caught.getMessage
        .equals("SaveMode: Overwrite is not supported. TiSpark only support SaveMode.Append."))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 137
Source File: UpperCaseColumnNameSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class UpperCaseColumnNameSuite
    extends BaseDataSourceTest("test_datasource_uppser_case_column_name") {

  private val row1 = Row(1, 2)

  private val schema = StructType(
    List(StructField("O_ORDERKEY", IntegerType), StructField("O_CUSTKEY", IntegerType)))

  override def beforeAll(): Unit = {
    super.beforeAll()

    dropTable()
    jdbcUpdate(s"""
                  |CREATE TABLE $dbtable (O_ORDERKEY INTEGER NOT NULL,
                  |                       O_CUSTKEY INTEGER NOT NULL);
       """.stripMargin)
  }

  test("Test insert upper case column name") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row1))
    val df = sqlContext.createDataFrame(data, schema)
    df.write
      .format("tidb")
      .options(tidbOptions)
      .option("database", database)
      .option("table", table)
      .mode("append")
      .save()
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 138
Source File: CheckUnsupportedSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.exception.TiBatchWriteException
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class CheckUnsupportedSuite extends BaseDataSourceTest("test_datasource_check_unsupported") {

  override def beforeAll(): Unit =
    super.beforeAll()

  test("Test write to partition table") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()

    tidbStmt.execute("set @@tidb_enable_table_partition = 1")

    jdbcUpdate(
      s"create table $dbtable(i int, s varchar(128)) partition by range(i) (partition p0 values less than maxvalue)")
    jdbcUpdate(s"insert into $dbtable values(null, 'Hello')")

    val row1 = Row(null, "Hello")
    val row2 = Row(2, "TiDB")
    val row3 = Row(3, "Spark")

    val schema = StructType(List(StructField("i", IntegerType), StructField("s", StringType)))

    {
      val caught = intercept[TiBatchWriteException] {
        tidbWrite(List(row2, row3), schema)
      }
      assert(
        caught.getMessage
          .equals("tispark currently does not support write data to partition table!"))
    }

    testTiDBSelect(Seq(row1))
  }

  test("Check Virtual Generated Column") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2))")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
      List(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))

    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
    }
    assert(
      caught.getMessage
        .equals("tispark currently does not support write data to table with generated column!"))

  }

  test("Check Stored Generated Column") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i INT, c1 INT, c2 INT,  c3 INT AS (c1 + c2) STORED)")

    val row1 = Row(1, 2, 3)
    val schema = StructType(
      List(
        StructField("i", IntegerType),
        StructField("c1", IntegerType),
        StructField("c2", IntegerType)))
    val caught = intercept[TiBatchWriteException] {
      tidbWrite(List(row1), schema)
    }
    assert(
      caught.getMessage
        .equals("tispark currently does not support write data to table with generated column!"))

  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 139
Source File: RegionSplitSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import com.pingcap.tikv.TiBatchWriteUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class RegionSplitSuite extends BaseDataSourceTest("region_split_test") {
  private val row1 = Row(1)
  private val row2 = Row(2)
  private val row3 = Row(3)
  private val schema = StructType(List(StructField("a", IntegerType)))

  test("index region split test") {
    if (!supportBatchWrite) {
      cancel
    }

    // skip this case on TiDB builds that do not support region split
    if (!isEnableSplitRegion) {
      cancel
    }

    dropTable()
    jdbcUpdate(
      s"CREATE TABLE  $dbtable ( `a` int(11), unique index(a)) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin")

    val options = Some(Map("enableRegionSplit" -> "true", "regionSplitNum" -> "3"))

    tidbWrite(List(row1, row2, row3), schema, options)

    val tiTableInfo =
      ti.tiSession.getCatalog.getTable(dbPrefix + database, table)
    val regionsNum = TiBatchWriteUtils
      .getRegionByIndex(ti.tiSession, tiTableInfo, tiTableInfo.getIndices.get(0))
      .size()
    assert(regionsNum == 3)
  }

  test("table region split test") {
    if (!supportBatchWrite) {
      cancel
    }

    // skip this case on TiDB builds that do not support region split
    if (!isEnableSplitRegion) {
      cancel
    }

    dropTable()
    jdbcUpdate(
      s"CREATE TABLE  $dbtable ( `a` int(11) DEFAULT NULL) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_bin")

    val options = Some(Map("enableRegionSplit" -> "true", "regionSplitNum" -> "3"))

    tidbWrite(List(row1, row2, row3), schema, options)

    val tiTableInfo =
      ti.tiSession.getCatalog.getTable(dbPrefix + database, table)
    val regionsNum = TiBatchWriteUtils.getRecordRegions(ti.tiSession, tiTableInfo).size()
    assert(regionsNum == 3)
  }
} 
Example 140
Source File: MissingParameterSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class MissingParameterSuite extends BaseDataSourceTest("test_datasource_missing_parameter") {
  private val row1 = Row(null, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  test("Missing parameter: database") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")

    val caught = intercept[IllegalArgumentException] {
      val rows = row1 :: Nil
      val data: RDD[Row] = sc.makeRDD(rows)
      val df = sqlContext.createDataFrame(data, schema)
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("table", table)
        .mode("append")
        .save()
    }
    assert(
      caught.getMessage
        .equals("requirement failed: Option 'database' is required."))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 141
Source File: ShardRowIDBitsSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class ShardRowIDBitsSuite extends BaseDataSourceTest("test_shard_row_id_bits") {
  private val row1 = Row(1)
  private val row2 = Row(2)
  private val row3 = Row(3)
  private val schema = StructType(List(StructField("a", IntegerType)))

  test("reading and writing a table with shard_row_id_bits") {
    if (!supportBatchWrite) {
      cancel
    }

    dropTable()
    jdbcUpdate(s"CREATE TABLE  $dbtable ( `a` int(11))  SHARD_ROW_ID_BITS = 4")

    jdbcUpdate(s"insert into $dbtable values(null)")

    tidbWrite(List(row1, row2, row3), schema)
    testTiDBSelect(List(Row(null), row1, row2, row3), sortCol = "a")
  }
} 
Example 142
Source File: OnlyOnePkSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.datasource

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class OnlyOnePkSuite extends BaseDataSourceTest("test_datasource_only_one_pk") {
  private val row3 = Row(3)
  private val row4 = Row(4)

  private val schema = StructType(List(StructField("i", IntegerType)))

  override def beforeAll(): Unit = {
    super.beforeAll()

    dropTable()
    jdbcUpdate(s"create table $dbtable(i int primary key)")
  }

  test("Test Write Append") {
    if (!supportBatchWrite) {
      cancel
    }

    val data: RDD[Row] = sc.makeRDD(List(row3, row4))
    val df = sqlContext.createDataFrame(data, schema)

    df.write
      .format("tidb")
      .options(tidbOptions)
      .option("database", database)
      .option("table", table)
      .mode("append")
      .save()

    testTiDBSelect(Seq(row3, row4))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 143
Source File: BatchWriteIssueSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark

import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class BatchWriteIssueSuite extends BaseDataSourceTest("test_batchwrite_issue") {
  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  test("Combine unique index with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), CONSTRAINT ab UNIQUE (a, b))")
  }

  test("Combine primary key with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a, b))")
  }

  test("PK is handler with null value test") {
    doTestNullValues(s"create table $dbtable(a int, b varchar(64), PRIMARY KEY (a))")
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }

  private def doTestNullValues(createTableSQL: String): Unit = {
    if (!supportBatchWrite) {
      cancel
    }
    val schema = StructType(
      List(
        StructField("a", IntegerType),
        StructField("b", StringType),
        StructField("c", StringType)))

    val options = Some(Map("replace" -> "true"))

    dropTable()
    jdbcUpdate(createTableSQL)
    jdbcUpdate(s"alter table $dbtable add column to_delete int")
    jdbcUpdate(s"alter table $dbtable add column c varchar(64) default 'c33'")
    jdbcUpdate(s"alter table $dbtable drop column to_delete")
    jdbcUpdate(s"""
                  |insert into $dbtable values(11, 'c12', null);
                  |insert into $dbtable values(21, 'c22', null);
                  |insert into $dbtable (a, b) values(31, 'c32');
                  |insert into $dbtable values(41, 'c42', 'c43');
                  |
      """.stripMargin)

    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head == null)
    assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
    assert(
      queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("c33"))
    assert(
      queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))

    {
      val row1 = Row(11, "c12", "c13")
      val row3 = Row(31, "c32", null)

      tidbWrite(List(row1, row3), schema, options)

      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c13"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head == null)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))
    }

    {
      val row1 = Row(11, "c12", "c213")
      val row3 = Row(31, "c32", "tt")
      tidbWrite(List(row1, row3), schema, options)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=11").head.head.toString.equals("c213"))
      assert(queryTiDBViaJDBC(s"select c from $dbtable where a=21").head.head == null)
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=31").head.head.toString.equals("tt"))
      assert(
        queryTiDBViaJDBC(s"select c from $dbtable where a=41").head.head.toString.equals("c43"))
    }
  }
} 
Example 144
Source File: LockTimeoutSuite.scala    From tispark   with Apache License 2.0 5 votes vote down vote up
package com.pingcap.tispark.ttl

import com.pingcap.tikv.TTLManager
import com.pingcap.tikv.exception.GrpcException
import com.pingcap.tispark.datasource.BaseDataSourceTest
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

class LockTimeoutSuite extends BaseDataSourceTest("test_lock_timeout") {
  private val row1 = Row(1, "Hello")

  private val schema = StructType(
    List(StructField("i", IntegerType), StructField("s", StringType)))

  override def beforeAll(): Unit = {
    super.beforeAll()
    dropTable()
    jdbcUpdate(s"create table $dbtable(i int, s varchar(128))")
  }

  test("Test Lock TTL Timeout") {
    if (!supportTTLUpdate) {
      cancel
    }

    val seconds = 1000
    val sleep1 = TTLManager.MANAGED_LOCK_TTL + 10 * seconds
    val sleep2 = TTLManager.MANAGED_LOCK_TTL + 15 * seconds

    val data: RDD[Row] = sc.makeRDD(List(row1))
    val df = sqlContext.createDataFrame(data, schema)

    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(sleep1)
        queryTiDBViaJDBC(s"select * from $dbtable")
      }
    }).start()

    val grpcException = intercept[GrpcException] {
      df.write
        .format("tidb")
        .options(tidbOptions)
        .option("database", database)
        .option("table", table)
        .option("sleepAfterPrewritePrimaryKey", sleep2)
        .mode("append")
        .save()
    }

    assert(grpcException.getMessage.equals("retry is exhausted."))
    assert(grpcException.getCause.getMessage.startsWith("Txn commit primary key failed"))
    assert(
      grpcException.getCause.getCause.getMessage.startsWith(
        "Key exception occurred and the reason is retryable: \"Txn(Mvcc(TxnLockNotFound"))
  }

  override def afterAll(): Unit =
    try {
      dropTable()
    } finally {
      super.afterAll()
    }
} 
Example 145
Source File: Preprocess.scala    From Scala-Machine-Learning-Projects   with MIT License 5 votes vote down vote up
package com.packt.ScalaML.BitCoin

import java.io.{ BufferedWriter, File, FileWriter }
import org.apache.spark.sql.types.{ DoubleType, IntegerType, StructField, StructType }
import org.apache.spark.sql.{ DataFrame, Row, SparkSession }
import scala.collection.mutable.ListBuffer

object Preprocess {
  // how many of the first rows are omitted
    val dropFirstCount: Int = 612000

    def rollingWindow(data: DataFrame, window: Int, xFilename: String, yFilename: String): Unit = {
      var i = 0
      val xWriter = new BufferedWriter(new FileWriter(new File(xFilename)))
      val yWriter = new BufferedWriter(new FileWriter(new File(yFilename)))

      val zippedData = data.rdd.zipWithIndex().collect()
      System.gc()
      val dataStratified = zippedData.drop(dropFirstCount) // TODO: slice the first 614K rows
      while (i < (dataStratified.length - window)) {
        val x = dataStratified
          .slice(i, i + window)
          .map(r => r._1.getAs[Double]("Delta")).toList
        val y = dataStratified.apply(i + window)._1.getAs[Integer]("label")
        val stringToWrite = x.mkString(",")
        xWriter.write(stringToWrite + "\n")
        yWriter.write(y + "\n")

        i += 1
        if (i % 10 == 0) {
          xWriter.flush()
          yWriter.flush()
        }
      }

      xWriter.close()
      yWriter.close()
    }
    
  def main(args: Array[String]): Unit = {
    // TODO: modify these paths to point to the desired input/output files
    val priceDataFileName: String = "C:/Users/admin-karim/Desktop/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv/bitstampUSD_1-min_data_2012-01-01_to_2017-10-20.csv"
    val outputDataFilePath: String = "output/scala_test_x.csv"
    val outputLabelFilePath: String = "output/scala_test_y.csv"

    val spark = SparkSession
      .builder()
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName("Bitcoin Preprocessing")
      .getOrCreate()

    val data = spark.read.format("com.databricks.spark.csv").option("header", "true").load(priceDataFileName)
    data.show(10)
    println((data.count(), data.columns.size))

    val dataWithDelta = data.withColumn("Delta", data("Close") - data("Open"))

    import org.apache.spark.sql.functions._
    import spark.sqlContext.implicits._

    val dataWithLabels = dataWithDelta.withColumn("label", when($"Close" - $"Open" > 0, 1).otherwise(0))
    rollingWindow(dataWithLabels, 22, outputDataFilePath, outputLabelFilePath)    
    spark.stop()
  }
} 
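The CSV above is loaded without schema inference, so `Open` and `Close` arrive as strings and are only coerced implicitly when `Close - Open` is evaluated. Below is a minimal sketch of making the numeric types explicit with `cast`, assuming a local `SparkSession`; the small in-memory DataFrame merely stands in for the real CSV columns.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ DoubleType, IntegerType }

val spark = SparkSession.builder().master("local[*]").appName("ExplicitCastSketch").getOrCreate()
import spark.implicits._

// Stand-in for the CSV rows; loading with only option("header", "true") yields string columns.
val raw = Seq(("10.0", "12.5"), ("12.5", "11.0")).toDF("Open", "Close")

val typed = raw
  .withColumn("Open", col("Open").cast(DoubleType))
  .withColumn("Close", col("Close").cast(DoubleType))

// Delta stays DoubleType; casting the boolean comparison to IntegerType yields a 0/1 label.
val labeled = typed
  .withColumn("Delta", col("Close") - col("Open"))
  .withColumn("label", (col("Delta") > 0).cast(IntegerType))

labeled.printSchema()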
Example 146
Source File: OptimizeHiveMetadataOnlyQuerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfter

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
    with BeforeAndAfter with SQLTestUtils {

  import spark.implicits._

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
    (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))
  }

  override protected def afterAll(): Unit = {
    try {
      sql("DROP TABLE IF EXISTS metadata_only")
    } finally {
      super.afterAll()
    }
  }

  test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the number of matching partitions
      assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)

      // verify that the partition predicate was pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)
    }
  }

  test("SPARK-23877: filter on projected expression") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the matching partitions
      val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
        Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
          spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child)))
          .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))

      checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))

      // verify that the partition predicate was not pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
    }
  }
} 
Example 147
Source File: SparkExecuteStatementOperationSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.thriftserver

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, NullType, StringType, StructField, StructType}

class SparkExecuteStatementOperationSuite extends SparkFunSuite {
  test("SPARK-17112 `select null` via JDBC triggers IllegalArgumentException in ThriftServer") {
    val field1 = StructField("NULL", NullType)
    val field2 = StructField("(IF(true, NULL, NULL))", NullType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.NULL_TYPE)
  }

  test("SPARK-20146 Comment should be preserved") {
    val field1 = StructField("column1", StringType).withComment("comment 1")
    val field2 = StructField("column2", IntegerType)
    val tableSchema = StructType(Seq(field1, field2))
    val columns = SparkExecuteStatementOperation.getTableSchema(tableSchema).getColumnDescriptors()
    assert(columns.size() == 2)
    assert(columns.get(0).getType() == org.apache.hive.service.cli.Type.STRING_TYPE)
    assert(columns.get(0).getComment() == "comment 1")
    assert(columns.get(1).getType() == org.apache.hive.service.cli.Type.INT_TYPE)
    assert(columns.get(1).getComment() == "")
  }
} 
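The comment-preservation test above relies only on public schema APIs: `StructField.withComment` stores the comment in the field metadata and `getComment()` reads it back, while IntegerType is the field type that ends up mapped to the Thrift server's INT_TYPE. A standalone sketch of the schema-building part:

import org.apache.spark.sql.types.{ IntegerType, StringType, StructField, StructType }

val tableSchema = StructType(Seq(
  StructField("column1", StringType).withComment("comment 1"),
  StructField("column2", IntegerType)
))

// getComment() returns Option[String]; fields without a comment yield None.
assert(tableSchema("column1").getComment() == Some("comment 1"))
assert(tableSchema("column2").getComment().isEmpty)
assert(tableSchema("column2").dataType == IntegerType)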
Example 148
Source File: SubstituteUnresolvedOrdinals.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType


class SubstituteUnresolvedOrdinals(conf: SQLConf) extends Rule[LogicalPlan] {
  private def isIntLiteral(e: Expression) = e match {
    case Literal(_, IntegerType) => true
    case _ => false
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) =>
      val newOrders = s.order.map {
        case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) =>
          val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
          withOrigin(order.origin)(order.copy(child = newOrdinal))
        case other => other
      }
      withOrigin(s.origin)(s.copy(order = newOrders))

    case a: Aggregate if conf.groupByOrdinal && a.groupingExpressions.exists(isIntLiteral) =>
      val newGroups = a.groupingExpressions.map {
        case ordinal @ Literal(index: Int, IntegerType) =>
          withOrigin(ordinal.origin)(UnresolvedOrdinal(index))
        case other => other
      }
      withOrigin(a.origin)(a.copy(groupingExpressions = newGroups))
  }
} 
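At the user level, the effect of this rule is that an IntegerType literal in ORDER BY or GROUP BY is treated as a column ordinal rather than a constant. A minimal sketch of that behaviour, assuming a local `SparkSession` named `spark` and a temporary view `t` created just for illustration:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("OrdinalSketch").getOrCreate()
import spark.implicits._

Seq((1, "x"), (2, "y")).toDF("a", "b").createOrReplaceTempView("t")

// Both flags default to true; they back the conf.orderByOrdinal / conf.groupByOrdinal checks above.
spark.conf.set("spark.sql.orderByOrdinal", "true")
spark.conf.set("spark.sql.groupByOrdinal", "true")

// The literal 2 is parsed as Literal(2, IntegerType) and rewritten to UnresolvedOrdinal(2),
// so the query sorts by the second output column (b) instead of by the constant 2.
spark.sql("SELECT a, b FROM t ORDER BY 2").show()
spark.sql("SELECT b, count(*) FROM t GROUP BY 1").show()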
Example 149
Source File: ResolveTableValuedFunctions.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.{DataType, IntegerType, LongType}


      tvf("start" -> LongType, "end" -> LongType, "step" -> LongType,
          "numPartitions" -> IntegerType) {
        case Seq(start: Long, end: Long, step: Long, numPartitions: Int) =>
          Range(start, end, step, Some(numPartitions))
      })
  )

  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
    case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
      // The whole resolution is somewhat difficult to understand here due to too much abstractions.
      // We should probably rewrite the following at some point. Reynold was just here to improve
      // error messages and didn't have time to do a proper rewrite.
      val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
        case Some(tvf) =>

          def failAnalysis(): Nothing = {
            val argTypes = u.functionArgs.map(_.dataType.typeName).mkString(", ")
            u.failAnalysis(
              s"""error: table-valued function ${u.functionName} with alternatives:
                 |${tvf.keys.map(_.toString).toSeq.sorted.map(x => s" ($x)").mkString("\n")}
                 |cannot be applied to: ($argTypes)""".stripMargin)
          }

          val resolved = tvf.flatMap { case (argList, resolver) =>
            argList.implicitCast(u.functionArgs) match {
              case Some(casted) =>
                try {
                  Some(resolver(casted.map(_.eval())))
                } catch {
                  case e: AnalysisException =>
                    failAnalysis()
                }
              case _ =>
                None
            }
          }
          resolved.headOption.getOrElse {
            failAnalysis()
          }
        case _ =>
          u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
      }

      // If alias names assigned, add `Project` with the aliases
      if (u.outputNames.nonEmpty) {
        val outputAttrs = resolvedFunc.output
        // Checks if the number of the aliases is equal to expected one
        if (u.outputNames.size != outputAttrs.size) {
          u.failAnalysis(s"Number of given aliases does not match number of output columns. " +
            s"Function name: ${u.functionName}; number of aliases: " +
            s"${u.outputNames.size}; number of output columns: ${outputAttrs.size}.")
        }
        val aliases = outputAttrs.zip(u.outputNames).map {
          case (attr, name) => Alias(attr, name)()
        }
        Project(aliases, resolvedFunc)
      } else {
        resolvedFunc
      }
  }
} 
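In this version of the analyzer, `range` is the table-valued function registered via `builtinFunctions`, and the alternative shown above is its four-argument form whose `numPartitions` argument is declared as IntegerType. A usage sketch, assuming a local `SparkSession` named `spark`; the column alias in the second query exercises the `outputNames` branch of the rule above.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("RangeTvfSketch").getOrCreate()

// range(start, end, step, numPartitions): numPartitions is the IntegerType argument above.
spark.sql("SELECT * FROM range(0, 10, 2, 4)").show()

// Aliasing the output column goes through the Project-with-aliases branch of the rule.
spark.sql("SELECT i FROM range(3) t(i)").show()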
Example 150
Source File: StatsEstimationTestBase.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.statsEstimation

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType}


trait StatsEstimationTestBase extends SparkFunSuite {

  var originalValue: Boolean = false

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Enable stats estimation based on CBO.
    originalValue = SQLConf.get.getConf(SQLConf.CBO_ENABLED)
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, true)
  }

  override def afterAll(): Unit = {
    SQLConf.get.setConf(SQLConf.CBO_ENABLED, originalValue)
    super.afterAll()
  }

  def getColSize(attribute: Attribute, colStat: ColumnStat): Long = attribute.dataType match {
    // For UTF8String: base + offset + numBytes
    case StringType => colStat.avgLen.getOrElse(attribute.dataType.defaultSize.toLong) + 8 + 4
    case _ => colStat.avgLen.getOrElse(attribute.dataType.defaultSize)
  }

  def attr(colName: String): AttributeReference = AttributeReference(colName, IntegerType)()
}

  
case class StatsTestPlan(
    outputList: Seq[Attribute],
    rowCount: BigInt,
    attributeStats: AttributeMap[ColumnStat],
    size: Option[BigInt] = None) extends LeafNode {
  override def output: Seq[Attribute] = outputList
  override def computeStats(): Statistics = Statistics(
    // If sizeInBytes is useless in testing, we just use a fake value
    sizeInBytes = size.getOrElse(Int.MaxValue),
    rowCount = Some(rowCount),
    attributeStats = attributeStats)
} 
Example 151
Source File: LogicalPlanSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType


class LogicalPlanSuite extends SparkFunSuite {
  private var invocationCount = 0
  private val function: PartialFunction[LogicalPlan, LogicalPlan] = {
    case p: Project =>
      invocationCount += 1
      p
  }

  private val testRelation = LocalRelation()

  test("transformUp runs on operators") {
    invocationCount = 0
    val plan = Project(Nil, testRelation)
    plan transformUp function

    assert(invocationCount === 1)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 1)
  }

  test("transformUp runs on operators recursively") {
    invocationCount = 0
    val plan = Project(Nil, Project(Nil, testRelation))
    plan transformUp function

    assert(invocationCount === 2)

    invocationCount = 0
    plan transformDown function
    assert(invocationCount === 2)
  }

  test("isStreaming") {
    val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
    val incrementalRelation = LocalRelation(
      Seq(AttributeReference("a", IntegerType, nullable = true)()), isStreaming = true)

    case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
      override def output: Seq[Attribute] = left.output ++ right.output
    }

    require(relation.isStreaming === false)
    require(incrementalRelation.isStreaming === true)
    assert(TestBinaryRelation(relation, relation).isStreaming === false)
    assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true)
    assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
    assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
  }

  test("transformExpressions works with a Stream") {
    val id1 = NamedExpression.newExprId
    val id2 = NamedExpression.newExprId
    val plan = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(2), "b")(exprId = id2)),
      OneRowRelation())
    val result = plan.transformExpressions {
      case Literal(v: Int, IntegerType) if v != 1 =>
        Literal(v + 1, IntegerType)
    }
    val expected = Project(Stream(
      Alias(Literal(1), "a")(exprId = id1),
      Alias(Literal(3), "b")(exprId = id2)),
      OneRowRelation())
    assert(result.sameResult(expected))
  }
} 
Example 152
Source File: QueryPlanSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.plans

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Literal, NamedExpression}
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.types.IntegerType

class QueryPlanSuite extends SparkFunSuite {

  test("origin remains the same after mapExpressions (SPARK-23823)") {
    CurrentOrigin.setPosition(0, 0)
    val column = AttributeReference("column", IntegerType)(NamedExpression.newExprId)
    val query = plans.DslLogicalPlan(plans.table("table")).select(column)
    CurrentOrigin.reset()

    val mappedQuery = query mapExpressions {
      case _: Expression => Literal(1)
    }

    val mappedOrigin = mappedQuery.expressions.apply(0).origin
    assert(mappedOrigin == Origin.apply(Some(0), Some(0)))
  }

} 
Example 153
Source File: ScalaUDFSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("basic") {
    val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil)
    checkEvaluation(intUdf, 2)

    val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil)
    checkEvaluation(stringUdf, "ax")
  }

  test("better error message for NPE") {
    val udf = ScalaUDF(
      (s: String) => s.toLowerCase(Locale.ROOT),
      StringType,
      Literal.create(null, StringType) :: Nil,
      true :: Nil)

    val e1 = intercept[SparkException](udf.eval())
    assert(e1.getMessage.contains("Failed to execute user defined function"))

    val e2 = intercept[SparkException] {
      checkEvaluationWithUnsafeProjection(udf, null)
    }
    assert(e2.getMessage.contains("Failed to execute user defined function"))
  }

  test("SPARK-22695: ScalaUDF should not use global variables") {
    val ctx = new CodegenContext
    ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx)
    assert(ctx.inlinedMutableStates.isEmpty)
  }
} 
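`ScalaUDF` is a Catalyst-internal expression; in user code the same `(i: Int) => i + 1` closure would normally go through the public `udf` helper, where IntegerType is inferred from the closure's return type. A minimal sketch, assuming a local `SparkSession`:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{ col, udf }

val spark = SparkSession.builder().master("local[*]").appName("UdfSketch").getOrCreate()
import spark.implicits._

// The closure returns Int, so the UDF's result column gets IntegerType.
val plusOne = udf((i: Int) => i + 1)

val df = Seq(1, 2, 3).toDF("i")
df.select(plusOne(col("i")).as("i_plus_one")).printSchema()   // i_plus_one: integer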
Example 154
Source File: ExpressionEvalHelperSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.types.{DataType, IntegerType}


case class BadCodegenExpression() extends LeafExpression {
  override def nullable: Boolean = false
  override def eval(input: InternalRow): Any = 10
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    ev.copy(code =
      code"""
        |int some_variable = 11;
        |int ${ev.value} = 10;
      """.stripMargin)
  }
  override def dataType: DataType = IntegerType
} 
Example 155
Source File: CanonicalizeSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class CanonicalizeSuite extends SparkFunSuite {

  test("SPARK-24276: IN expression with different order are semantically equal") {
    val range = Range(1, 1, 1, 1)
    val idAttr = range.output.head

    val in1 = In(idAttr, Seq(Literal(1), Literal(2)))
    val in2 = In(idAttr, Seq(Literal(2), Literal(1)))
    val in3 = In(idAttr, Seq(Literal(1), Literal(2), Literal(3)))

    assert(in1.canonicalized.semanticHash() == in2.canonicalized.semanticHash())
    assert(in1.canonicalized.semanticHash() != in3.canonicalized.semanticHash())

    assert(range.where(in1).sameResult(range.where(in2)))
    assert(!range.where(in1).sameResult(range.where(in3)))

    val arrays1 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
      CreateArray(Seq(Literal(2), Literal(1)))))
    val arrays2 = In(idAttr, Seq(CreateArray(Seq(Literal(2), Literal(1))),
      CreateArray(Seq(Literal(1), Literal(2)))))
    val arrays3 = In(idAttr, Seq(CreateArray(Seq(Literal(1), Literal(2))),
      CreateArray(Seq(Literal(3), Literal(1)))))

    assert(arrays1.canonicalized.semanticHash() == arrays2.canonicalized.semanticHash())
    assert(arrays1.canonicalized.semanticHash() != arrays3.canonicalized.semanticHash())

    assert(range.where(arrays1).sameResult(range.where(arrays2)))
    assert(!range.where(arrays1).sameResult(range.where(arrays3)))
  }

  test("SPARK-26402: accessing nested fields with different cases in case insensitive mode") {
    val expId = NamedExpression.newExprId
    val qualifier = Seq.empty[String]
    val structType = StructType(
      StructField("a", StructType(StructField("b", IntegerType, false) :: Nil), false) :: Nil)

    // GetStructField with different names are semantically equal
    val fieldA1 = GetStructField(
      AttributeReference("data1", structType, false)(expId, qualifier),
      0, Some("a1"))
    val fieldA2 = GetStructField(
      AttributeReference("data2", structType, false)(expId, qualifier),
      0, Some("a2"))
    assert(fieldA1.semanticEquals(fieldA2))

    val fieldB1 = GetStructField(
      GetStructField(
        AttributeReference("data1", structType, false)(expId, qualifier),
        0, Some("a1")),
      0, Some("b1"))
    val fieldB2 = GetStructField(
      GetStructField(
        AttributeReference("data2", structType, false)(expId, qualifier),
        0, Some("a2")),
      0, Some("b2"))
    assert(fieldB1.semanticEquals(fieldB2))
  }
} 
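A quick way to see the same canonicalization outside the suite is to compare commuted arithmetic over an IntegerType attribute. A minimal sketch, assuming (as the suite does) that catalyst puts commutative children into a stable order:

import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
// Commutative children are reordered during canonicalization,
// so the two Add expressions compare as semantically equal.
assert(Add(a, Literal(1)).semanticEquals(Add(Literal(1), a)))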
Example 156
Source File: CallMethodViaReflectionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
import org.apache.spark.sql.types.{IntegerType, StringType}


class CallMethodViaReflectionSuite extends SparkFunSuite with ExpressionEvalHelper {

  import CallMethodViaReflection._

  // Get rid of the $ so we are getting the companion object's name.
  private val staticClassName = ReflectStaticClass.getClass.getName.stripSuffix("$")
  private val dynamicClassName = classOf[ReflectDynamicClass].getName

  test("findMethod via reflection for static methods") {
    assert(findMethod(staticClassName, "method1", Seq.empty).exists(_.getName == "method1"))
    assert(findMethod(staticClassName, "method2", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method3", Seq(IntegerType)).isDefined)
    assert(findMethod(staticClassName, "method4", Seq(IntegerType, StringType)).isDefined)
  }

  test("findMethod for a JDK library") {
    assert(findMethod(classOf[java.util.UUID].getName, "randomUUID", Seq.empty).isDefined)
  }

  test("class not found") {
    val ret = createExpr("some-random-class", "method").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("not found") && errorMsg.contains("class"))
  }

  test("method not found because name does not match") {
    val ret = createExpr(staticClassName, "notfoundmethod").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("method not found because there is no static method") {
    val ret = createExpr(dynamicClassName, "method1").checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("cannot find a static method"))
  }

  test("input type checking") {
    assert(CallMethodViaReflection(Seq.empty).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(Seq(Literal(staticClassName))).checkInputDataTypes().isFailure)
    assert(CallMethodViaReflection(
      Seq(Literal(staticClassName), Literal(1))).checkInputDataTypes().isFailure)
    assert(createExpr(staticClassName, "method1").checkInputDataTypes().isSuccess)
  }

  test("unsupported type checking") {
    val ret = createExpr(staticClassName, "method1", new Timestamp(1)).checkInputDataTypes()
    assert(ret.isFailure)
    val errorMsg = ret.asInstanceOf[TypeCheckFailure].message
    assert(errorMsg.contains("arguments from the third require boolean, byte, short"))
  }

  test("invoking methods using acceptable types") {
    checkEvaluation(createExpr(staticClassName, "method1"), "m1")
    checkEvaluation(createExpr(staticClassName, "method2", 2), "m2")
    checkEvaluation(createExpr(staticClassName, "method3", 3), "m3")
    checkEvaluation(createExpr(staticClassName, "method4", 4, "four"), "m4four")
  }

  private def createExpr(className: String, methodName: String, args: Any*) = {
    CallMethodViaReflection(
      Literal.create(className, StringType) +:
      Literal.create(methodName, StringType) +:
      args.map(Literal.apply)
    )
  }
} 
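The expression under test backs Spark SQL's reflect (alias java_method) function. A minimal usage sketch, assuming a SparkSession named spark:

// Calls java.util.UUID.randomUUID() via reflection and returns its string form.
spark.sql("SELECT reflect('java.util.UUID', 'randomUUID') AS uuid").show(false)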
Example 157
Source File: RandomSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import org.scalatest.Matchers._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, LongType}

class RandomSuite extends SparkFunSuite with ExpressionEvalHelper {

  test("random") {
    checkDoubleEvaluation(Rand(30), 0.31429268272540556 +- 0.001)
    checkDoubleEvaluation(Randn(30), -0.4798519469521663 +- 0.001)

    checkDoubleEvaluation(
      new Rand(Literal.create(null, LongType)), 0.8446490682263027 +- 0.001)
    checkDoubleEvaluation(
      new Randn(Literal.create(null, IntegerType)), 1.1164209726833079 +- 0.001)
  }

  test("SPARK-9127 codegen with long seed") {
    checkDoubleEvaluation(Rand(5419823303878592871L), 0.2304755080444375 +- 0.001)
    checkDoubleEvaluation(Randn(5419823303878592871L), -1.2824262718225607 +- 0.001)
  }
} 
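The same Rand/Randn expressions are reachable from the DataFrame API, where the seed is an integer or long literal. A minimal sketch, assuming a SparkSession named spark:

import org.apache.spark.sql.functions.{rand, randn}

// Seeded uniform and Gaussian columns; the produced values are doubles.
spark.range(5).select(rand(30).as("u"), randn(30).as("n")).show()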
Example 158
Source File: CodeGeneratorWithInterpretedFallbackSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.util.concurrent.ExecutionException

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.IntegerType

class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {

  object FailedCodegenProjection
      extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {

    override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
      val invalidCode = new CodeAndComment("invalid code", Map.empty)
      // We assume this compilation throws an exception
      CodeGenerator.compile(invalidCode)
      null
    }

    override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
      InterpretedUnsafeProjection.createProjection(in)
    }
  }

  test("UnsafeProjection with codegen factory mode") {
    val input = Seq(BoundReference(0, IntegerType, nullable = true))
    val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
      val obj = UnsafeProjection.createObject(input)
      assert(obj.getClass.getName.contains("GeneratedClass$SpecificUnsafeProjection"))
    }

    val noCodegen = CodegenObjectFactoryMode.NO_CODEGEN.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> noCodegen) {
      val obj = UnsafeProjection.createObject(input)
      assert(obj.isInstanceOf[InterpretedUnsafeProjection])
    }
  }

  test("fallback to the interpreter mode") {
    val input = Seq(BoundReference(0, IntegerType, nullable = true))
    val fallback = CodegenObjectFactoryMode.FALLBACK.toString
    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) {
      val obj = FailedCodegenProjection.createObject(input)
      assert(obj.isInstanceOf[InterpretedUnsafeProjection])
    }
  }

  test("codegen failures in the CODEGEN_ONLY mode") {
    val errMsg = intercept[ExecutionException] {
      val input = Seq(BoundReference(0, IntegerType, nullable = true))
      val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
        FailedCodegenProjection.createObject(input)
      }
    }.getMessage
    assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:"))
  }
} 
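As a point of reference, the interpreted path used as the fallback can also be exercised directly. A minimal sketch, assuming it compiles inside the org.apache.spark.sql.catalyst.expressions package, since these factory methods are internal APIs:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.IntegerType

// Project a single nullable int column without generating any code.
val proj = InterpretedUnsafeProjection.createProjection(Seq(BoundReference(0, IntegerType, nullable = true)))
val row = proj(InternalRow(42))
assert(row.getInt(0) == 42)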
Example 159
Source File: ResolveLambdaVariablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{ArrayType, IntegerType}


class ResolveLambdaVariablesSuite extends PlanTest {
  import org.apache.spark.sql.catalyst.dsl.expressions._
  import org.apache.spark.sql.catalyst.dsl.plans._

  object Analyzer extends RuleExecutor[LogicalPlan] {
    val batches = Batch("Resolution", FixedPoint(4), ResolveLambdaVariables(conf)) :: Nil
  }

  private val key = 'key.int
  private val values1 = 'values1.array(IntegerType)
  private val values2 = 'values2.array(ArrayType(ArrayType(IntegerType)))
  private val data = LocalRelation(Seq(key, values1, values2))
  private val lvInt = NamedLambdaVariable("x", IntegerType, nullable = true)
  private val lvHiddenInt = NamedLambdaVariable("col0", IntegerType, nullable = true)
  private val lvArray = NamedLambdaVariable("x", ArrayType(IntegerType), nullable = true)

  private def plan(e: Expression): LogicalPlan = data.select(e.as("res"))

  private def checkExpression(e1: Expression, e2: Expression): Unit = {
    comparePlans(Analyzer.execute(plan(e1)), plan(e2))
  }

  private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))

  test("resolution - no op") {
    checkExpression(key, key)
  }

  test("resolution - simple") {
    val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
    val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
    checkExpression(in, out)
  }

  test("resolution - nested") {
    val in = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
    val out = ArrayTransform(values2, LambdaFunction(
      ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
    checkExpression(in, out)
  }

  test("resolution - hidden") {
    val in = ArrayTransform(values1, key)
    val out = ArrayTransform(values1, LambdaFunction(key, lvHiddenInt :: Nil, hidden = true))
    checkExpression(in, out)
  }

  test("fail - name collisions") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("arguments should not have names that are semantically the same"))
  }

  test("fail - lambda arguments") {
    val p = plan(ArrayTransform(values1,
      LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
    val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
    assert(msg.contains("does not match the number of arguments expected"))
  }
} 
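Lambda variable resolution is what makes SQL higher-order functions work end to end. A minimal sketch, assuming a SparkSession named spark on Spark 2.4 or later:

// The lambda variable x is bound to each IntegerType element of the array.
spark.sql("SELECT transform(array(1, 2, 3), x -> x + 1) AS plus_one").show(false)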
Example 160
Source File: RewriteDistinctAggregatesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Expand, LocalRelation, LogicalPlan}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, GROUP_BY_ORDINAL}
import org.apache.spark.sql.types.{IntegerType, StringType}

class RewriteDistinctAggregatesSuite extends PlanTest {
  override val conf = new SQLConf().copy(CASE_SENSITIVE -> false, GROUP_BY_ORDINAL -> false)
  val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf)
  val analyzer = new Analyzer(catalog, conf)

  val nullInt = Literal(null, IntegerType)
  val nullString = Literal(null, StringType)
  val testRelation = LocalRelation('a.string, 'b.string, 'c.string, 'd.string, 'e.int)

  private def checkRewrite(rewrite: LogicalPlan): Unit = rewrite match {
    case Aggregate(_, _, Aggregate(_, _, _: Expand)) =>
    case _ => fail(s"Plan is not rewritten:\n$rewrite")
  }

  test("single distinct group") {
    val input = testRelation
      .groupBy('a)(countDistinct('e))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("single distinct group with partial aggregates") {
    val input = testRelation
      .groupBy('a, 'd)(
        countDistinct('e, 'c).as('agg1),
        max('b).as('agg2))
      .analyze
    val rewrite = RewriteDistinctAggregates(input)
    comparePlans(input, rewrite)
  }

  test("multiple distinct groups") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with partial aggregates") {
    val input = testRelation
      .groupBy('a)(countDistinct('b, 'c), countDistinct('d), sum('e))
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }

  test("multiple distinct groups with non-partial aggregates") {
    val input = testRelation
      .groupBy('a)(
        countDistinct('b, 'c),
        countDistinct('d),
        CollectSet('b).toAggregateExpression())
      .analyze
    checkRewrite(RewriteDistinctAggregates(input))
  }
} 
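The rewrite is triggered by ordinary queries that contain more than one distinct aggregate group. A minimal sketch, assuming a SparkSession named spark:

import org.apache.spark.sql.functions.{countDistinct, sum}
import spark.implicits._

val df = Seq((1, 2, 3, 4, 5), (1, 2, 4, 4, 6)).toDF("a", "b", "c", "d", "e")
// Two distinct groups plus a regular aggregate: the optimized plan contains an Expand node.
df.groupBy("a")
  .agg(countDistinct("b", "c"), countDistinct("d"), sum("e"))
  .explain(true)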
Example 161
Source File: resources.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.command

import java.io.File
import java.net.URI

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


case class ListJarsCommand(jars: Seq[String] = Seq.empty[String]) extends RunnableCommand {
  override val output: Seq[Attribute] = {
    AttributeReference("Results", StringType, nullable = false)() :: Nil
  }
  override def run(sparkSession: SparkSession): Seq[Row] = {
    val jarList = sparkSession.sparkContext.listJars()
    if (jars.nonEmpty) {
      for {
        jarName <- jars.map(f => new Path(f).getName)
        jarPath <- jarList if jarPath.contains(jarName)
      } yield Row(jarPath)
    } else {
      jarList.map(Row(_))
    }
  }
} 
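ListJarsCommand is what runs behind the corresponding SQL statement. A minimal sketch, assuming a SparkSession named spark:

// Shows the jars previously registered with ADD JAR; each row is a single string column.
spark.sql("LIST JARS").show(false)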
Example 162
Source File: SameResultSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.sql.{DataFrame, QueryTest}
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.IntegerType


class SameResultSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("FileSourceScanExec: different orders of data filters and partition filters") {
    withTempPath { path =>
      val tmpDir = path.getCanonicalPath
      spark.range(10)
        .selectExpr("id as a", "id + 1 as b", "id + 2 as c", "id + 3 as d")
        .write
        .partitionBy("a", "b")
        .parquet(tmpDir)
      val df = spark.read.parquet(tmpDir)
      // partition filters: a > 1 AND b < 9
      // data filters: c > 1 AND d < 9
      val plan1 = getFileSourceScanExec(df.where("a > 1 AND b < 9 AND c > 1 AND d < 9"))
      val plan2 = getFileSourceScanExec(df.where("b < 9 AND a > 1 AND d < 9 AND c > 1"))
      assert(plan1.sameResult(plan2))
    }
  }

  private def getFileSourceScanExec(df: DataFrame): FileSourceScanExec = {
    df.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get
      .asInstanceOf[FileSourceScanExec]
  }

  test("SPARK-20725: partial aggregate should behave correctly for sameResult") {
    val df1 = spark.range(10).agg(sum($"id"))
    val df2 = spark.range(10).agg(sum($"id"))
    assert(df1.queryExecution.executedPlan.sameResult(df2.queryExecution.executedPlan))

    val df3 = spark.range(10).agg(sumDistinct($"id"))
    val df4 = spark.range(10).agg(sumDistinct($"id"))
    assert(df3.queryExecution.executedPlan.sameResult(df4.queryExecution.executedPlan))
  }

  test("Canonicalized result is case-insensitive") {
    val a = AttributeReference("A", IntegerType)()
    val b = AttributeReference("B", IntegerType)()
    val planUppercase = Project(Seq(a), LocalRelation(a, b))

    val c = AttributeReference("a", IntegerType)()
    val d = AttributeReference("b", IntegerType)()
    val planLowercase = Project(Seq(c), LocalRelation(c, d))

    assert(planUppercase.sameResult(planLowercase))
  }
} 
Example 163
Source File: GroupedIteratorSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}

class GroupedIteratorSuite extends SparkFunSuite {

  test("basic") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 1)
        key.getInt(0) -> data.map(encoder.fromRow).toSeq
    }.toSeq

    assert(result ==
      1 -> Seq(input(0), input(1)) ::
      2 -> Seq(input(2)) :: Nil)
  }

  test("group by 2 columns") {
    val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()

    val input = Seq(
      Row(1, 2L, "a"),
      Row(1, 2L, "b"),
      Row(1, 3L, "c"),
      Row(2, 1L, "d"),
      Row(3, 2L, "e"))

    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0), 'l.long.at(1)), schema.toAttributes)

    val result = grouped.map {
      case (key, data) =>
        assert(key.numFields == 2)
        (key.getInt(0), key.getLong(1), data.map(encoder.fromRow).toSeq)
    }.toSeq

    assert(result ==
      (1, 2L, Seq(input(0), input(1))) ::
      (1, 3L, Seq(input(2))) ::
      (2, 1L, Seq(input(3))) ::
      (3, 2L, Seq(input(4))) :: Nil)
  }

  test("do nothing to the value iterator") {
    val schema = new StructType().add("i", IntegerType).add("s", StringType)
    val encoder = RowEncoder(schema).resolveAndBind()
    val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
    val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
      Seq('i.int.at(0)), schema.toAttributes)

    assert(grouped.length == 2)
  }
} 
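GroupedIterator consumes InternalRows, which is why the tests convert through RowEncoder. A minimal round-trip sketch using the same encoder API (as it exists in Spark 2.x):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("i", IntegerType).add("s", StringType)
val encoder = RowEncoder(schema).resolveAndBind()
// toRow produces the InternalRow consumed by GroupedIterator; fromRow converts back.
val internal = encoder.toRow(Row(1, "a"))
assert(encoder.fromRow(internal) == Row(1, "a"))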
Example 164
Source File: WholeStageCodegenSparkSubmitSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution

import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.TimeLimits

import org.apache.spark.{SparkFunSuite, TestUtils}
import org.apache.spark.deploy.SparkSubmitSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
import org.apache.spark.sql.functions.{array, col, count, lit}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.ResetSystemProperties

// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfterEach
  with ResetSystemProperties {

  test("Generated code on driver should not embed platform-specific constant") {
    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)

    // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
    // settings of UseCompressedOops JVM option.
    val argsForSparkSubmit = Seq(
      "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
      "--master", "local-cluster[1,1,1024]",
      "--driver-memory", "1g",
      "--conf", "spark.ui.enabled=false",
      "--conf", "spark.master.rest.enabled=false",
      "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
      "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
      unusedJar.toString)
    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
  }
}

object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {

  var spark: SparkSession = _

  def main(args: Array[String]): Unit = {
    TestUtils.configTestLog4j("INFO")

    spark = SparkSession.builder().getOrCreate()

    // Make sure the test is run where the driver and the executors use different object layouts
    val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
    val executorArrayHeaderSize =
      spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
    assert(driverArrayHeaderSize > executorArrayHeaderSize)

    val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
      .groupBy(array(col("v"))).agg(count(col("*")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)

    val expectedAnswer =
      Row(Array(0), 7178) ::
        Row(Array(1), 7178) ::
        Row(Array(2), 7178) ::
        Row(Array(3), 7177) ::
        Row(Array(4), 7177) ::
        Row(Array(5), 7177) ::
        Row(Array(6), 7177) ::
        Row(Array(7), 7177) ::
        Row(Array(8), 7177) ::
        Row(Array(9), 7177) :: Nil
    val result = df.collect
    QueryTest.sameRows(result.toSeq, expectedAnswer) match {
      case Some(errMsg) => fail(errMsg)
      case _ =>
    }
  }
} 
Example 165
Source File: BlockingSource.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import java.util.concurrent.CountDownLatch

import org.apache.spark.sql.{SQLContext, _}
import org.apache.spark.sql.execution.streaming.{LongOffset, Offset, Sink, Source}
import org.apache.spark.sql.sources.{StreamSinkProvider, StreamSourceProvider}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class BlockingSource extends StreamSourceProvider with StreamSinkProvider {

  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    BlockingSource.latch.await()
    new Source {
      override def schema: StructType = fakeSchema
      override def getOffset: Option[Offset] = Some(new LongOffset(0))
      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
        import spark.implicits._
        Seq[Int]().toDS().toDF()
      }
      override def stop() {}
    }
  }

  override def createSink(
      spark: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink = {
    new Sink {
      override def addBatch(batchId: Long, data: DataFrame): Unit = {}
    }
  }
}

object BlockingSource {
  var latch: CountDownLatch = null
} 
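A provider like this is addressed by its fully qualified class name when building a streaming query. A minimal sketch, assuming a SparkSession named spark and a pre-released latch so the source does not block:

import java.util.concurrent.CountDownLatch

BlockingSource.latch = new CountDownLatch(0)  // already released, so createSource returns immediately
val stream = spark.readStream
  .format("org.apache.spark.sql.streaming.util.BlockingSource")
  .load()
stream.printSchema()  // single IntegerType column "a"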
Example 166
Source File: MockSourceProvider.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.streaming.util

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Source
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}


class MockSourceProvider extends StreamSourceProvider {
  override def sourceSchema(
      spark: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType) = {
    ("dummySource", MockSourceProvider.fakeSchema)
  }

  override def createSource(
      spark: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source = {
    MockSourceProvider.sourceProviderFunction()
  }
}

object MockSourceProvider {
  // Function used to generate sources. When multiple sources are supplied via withMockSources,
  // successive calls cycle through them.
  private var sourceProviderFunction: () => Source = _

  final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)

  def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = {
    var i = 0
    val sources = source +: otherSources
    sourceProviderFunction = () => {
      val source = sources(i % sources.length)
      i += 1
      source
    }
    try {
      f
    } finally {
      sourceProviderFunction = null
    }
  }
} 
Example 167
Source File: dependenciesSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.TableDependencyCalculator
import org.apache.spark.sql.sources.{RelationKind, Table}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}

object DependenciesSystemTableProvider extends SystemTableProvider with LocalSpark {
  
  override def execute(): Seq[Row] = {
    val tables = getTables(sqlContext.catalog)
    val dependentsMap = buildDependentsMap(tables)

    def kindOf(tableIdentifier: TableIdentifier): String =
      tables
        .get(tableIdentifier)
        .map(plan => RelationKind.kindOf(plan).getOrElse(Table).name)
        .getOrElse(DependenciesSystemTable.UnknownType)
        .toUpperCase

    dependentsMap.flatMap {
      case (tableIdent, dependents) =>
        val curKind = kindOf(tableIdent)
        dependents.map { dependent =>
          val dependentKind = kindOf(dependent)
          Row(
            tableIdent.database.orNull,
            tableIdent.table,
            curKind,
            dependent.database.orNull,
            dependent.table,
            dependentKind,
            ReferenceDependency.id)
        }
    }.toSeq
  }

  override val schema: StructType = DependenciesSystemTable.schema
}

object DependenciesSystemTable extends SchemaEnumeration {
  val baseSchemaName = Field("BASE_SCHEMA_NAME", StringType, nullable = true)
  val baseObjectName = Field("BASE_OBJECT_NAME", StringType, nullable = false)
  val baseObjectType = Field("BASE_OBJECT_TYPE", StringType, nullable = false)
  val dependentSchemaName = Field("DEPENDENT_SCHEMA_NAME", StringType, nullable = true)
  val dependentObjectName = Field("DEPENDENT_OBJECT_NAME", StringType, nullable = false)
  val dependentObjectType = Field("DEPENDENT_OBJECT_TYPE", StringType, nullable = false)
  val dependencyType = Field("DEPENDENCY_TYPE", IntegerType, nullable = false)

  private[DependenciesSystemTable] val UnknownType = "UNKNOWN"
} 
Example 168
Source File: partitionFunctionSystemTable.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis.systables

import org.apache.spark.sql.execution.tablefunctions.OutputFormatter
import org.apache.spark.sql.sources._
import org.apache.spark.sql.{DatasourceResolver, Row, SQLContext}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.apache.spark.sql.util.GenericUtil._


  private def typeNameOf(f: PartitionFunction): String = f match {
    case _: RangePartitionFunction => "RANGE"
    case _: BlockPartitionFunction => "BLOCK"
    case _: HashPartitionFunction => "HASH"
  }
}

object PartitionFunctionSystemTable extends SchemaEnumeration {
  val id = Field("ID", StringType, nullable = false)
  val functionType = Field("TYPE", StringType, nullable = false)
  val columnName = Field("COLUMN_NAME", StringType, nullable = false)
  val columnType = Field("COLUMN_TYPE", StringType, nullable = false)
  val boundaries = Field("BOUNDARIES", StringType, nullable = true)
  val block = Field("BLOCK_SIZE", IntegerType, nullable = true)
  val partitions = Field("PARTITIONS", IntegerType, nullable = true)
  val minP = Field("MIN_PARTITIONS", IntegerType, nullable = true)
  val maxP = Field("MAX_PARTITIONS", IntegerType, nullable = true)
} 
Example 169
Source File: HierarchyJoinBuilderUnitTests.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hierarchy

import org.apache.spark.sql.types.{IntegerType, Node}
import org.apache.spark.Logging
import org.apache.spark.sql.Row

// scalastyle:off magic.number
class HierarchyJoinBuilderUnitTests extends NodeUnitTestSpec with Logging {
  var jb = new HierarchyJoinBuilder[Row, Row, Long](null, null, null, null, null, null)

  log.info("Testing function 'extractNodeFromRow'\n")

  val x = Node(List(1,2,3), IntegerType, List(1L,1L,2L))
  Some(x) should equal {
    jb.extractNodeFromRow(Row.fromSeq(Seq(1,2,3, x)))
  }

  None should equal {
    jb.extractNodeFromRow(Row.fromSeq(Seq(1,2,3)))
  }

  None should equal {
    jb.extractNodeFromRow(Row.fromSeq(Seq()))
  }

  log.info("Testing function 'getOrd'\n")
  None should equal {
    jb.getOrd(Row.fromSeq(Seq(1,2,3)))
  }
  val testValues = List((42L, Some(42L)), (13, Some(13L)), ("hello", None), (1234.56, None))
  testValues.foreach(
    testVal => {
      val jbWithOrd = new HierarchyJoinBuilder[Row, Row, Long](null, null, null, null,
        x => testVal._1
        , null)
    testVal._2 should equal {
      jbWithOrd.getOrd(Row.fromSeq(Seq(x)))
    }
    }
  )
}
// scalastyle:on magic.number 
Example 170
Source File: TimestampExpressionSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.expressions

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types.{DateType, IntegerType}
import org.scalatest.FunSuite

class TimestampExpressionSuite extends FunSuite with ExpressionEvalHelper {

  test("add_seconds") {
    // scalastyle:off magic.number
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-01 00:11:33")), Literal(28)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-01-01 00:12:01")))
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-02 00:00:00")), Literal(-1)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-01-01 23:59:59")))
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-01 00:00:00")), Literal(-1)),
      DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2014-12-31 23:59:59")))
    checkEvaluation(AddSeconds(Literal(Timestamp.valueOf("2015-01-02 00:00:00")),
      Literal.create(null, IntegerType)), null)
    checkEvaluation(AddSeconds(Literal.create(null, DateType), Literal(1)), null)
    checkEvaluation(AddSeconds(Literal.create(null, DateType), Literal.create(null, IntegerType)),
      null)
  }
} 
Example 171
Source File: ResolveCountDistinctStarSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalatest.FunSuite
import org.scalatest.Inside._
import org.scalatest.mock.MockitoSugar
import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}

import scala.collection.mutable.ArrayBuffer

class ResolveCountDistinctStarSuite extends FunSuite with MockitoSugar {
  val persons = new LogicalRelation(new BaseRelation {
    override def sqlContext: SQLContext = mock[SQLContext]
    override def schema: StructType = StructType(Seq(
      StructField("age", IntegerType),
      StructField("name", StringType)
    ))
  })

  test("Count distinct star is resolved correctly") {
    val projection = persons.select(UnresolvedAlias(
      AggregateExpression(Count(UnresolvedStar(None) :: Nil), Complete, true)))
    val stillNotCompletelyResolvedAggregate = SimpleAnalyzer.execute(projection)
    val resolvedAggregate = ResolveCountDistinctStar(SimpleAnalyzer)
                              .apply(stillNotCompletelyResolvedAggregate)
    inside(resolvedAggregate) {
      case Aggregate(Nil,
      ArrayBuffer(Alias(AggregateExpression(Count(expressions), Complete, true), _)), _) =>
        assert(expressions.collect {
          case a:AttributeReference => a.name
        }.toSet == Set("name", "age"))
    }
    assert(resolvedAggregate.resolved)
  }
} 
Example 172
Source File: HttpStreamServerClientTest.scala    From spark-http-stream   with BSD 2-Clause "Simplified" License 5 votes vote down vote up
import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.streaming.http.HttpStreamClient
import org.junit.Assert
import org.junit.Test
import org.apache.spark.sql.types.LongType
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.FloatType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ByteType
import org.apache.spark.sql.execution.streaming.http.HttpStreamServer
import org.apache.spark.sql.execution.streaming.http.StreamPrinter
import org.apache.spark.sql.execution.streaming.http.HttpStreamServerSideException


class HttpStreamServerClientTest {
	val ROWS1 = Array(Row("hello1", 1, true, 0.1f, 0.1d, 1L, '1'.toByte),
		Row("hello2", 2, false, 0.2f, 0.2d, 2L, '2'.toByte),
		Row("hello3", 3, true, 0.3f, 0.3d, 3L, '3'.toByte));

	val ROWS2 = Array(Row("hello"),
		Row("world"),
		Row("bye"),
		Row("world"));

	@Test
	def testHttpStreamIO() {
		// starts an HTTP server
		val kryoSerializer = new KryoSerializer(new SparkConf());
		val server = HttpStreamServer.start("/xxxx", 8080);

		val spark = SparkSession.builder.appName("testHttpTextSink").master("local[4]")
			.getOrCreate();
		spark.conf.set("spark.sql.streaming.checkpointLocation", "/tmp/");

		val sqlContext = spark.sqlContext;
		import spark.implicits._
		// add a local message buffer to the server, with 2 topics registered
		server.withBuffer()
			.addListener(new StreamPrinter())
			.createTopic[(String, Int, Boolean, Float, Double, Long, Byte)]("topic-1")
			.createTopic[String]("topic-2");

		val client = HttpStreamClient.connect("http://localhost:8080/xxxx");
		//tests schema of topics
		val schema1 = client.fetchSchema("topic-1");
		Assert.assertArrayEquals(Array[Object](StringType, IntegerType, BooleanType, FloatType, DoubleType, LongType, ByteType),
			schema1.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		val schema2 = client.fetchSchema("topic-2");
		Assert.assertArrayEquals(Array[Object](StringType),
			schema2.fields.map(_.dataType).asInstanceOf[Array[Object]]);

		//prepare to consume messages
		val sid1 = client.subscribe("topic-1")._1;
		val sid2 = client.subscribe("topic-2")._1;

		//produces some data
		client.sendRows("topic-1", 1, ROWS1);

		val sid4 = client.subscribe("topic-1")._1;
		val sid5 = client.subscribe("topic-2")._1;

		client.sendRows("topic-2", 1, ROWS2);

		//consumes data
		val fetched = client.fetchStream(sid1).map(_.originalRow);
		Assert.assertArrayEquals(ROWS1.asInstanceOf[Array[Object]], fetched.asInstanceOf[Array[Object]]);
		//it is empty now
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid1).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid2).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid4).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(ROWS2.asInstanceOf[Array[Object]], client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);
		Assert.assertArrayEquals(Array[Object](), client.fetchStream(sid5).map(_.originalRow).asInstanceOf[Array[Object]]);

		client.unsubscribe(sid4);
		try {
			client.fetchStream(sid4);
			//exception should be thrown, because subscriber id is invalidated
			Assert.assertTrue(false);
		}
		catch {
			case e: Throwable ⇒
				e.printStackTrace();
				Assert.assertEquals(classOf[HttpStreamServerSideException], e.getClass);
		}

		server.stop();
	}
} 
Example 173
Source File: MultinomialLogisticRegressionParitySpec.scala    From mleap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.parity.classification

import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.{Matrices, Vectors}
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

class MultinomialLogisticRegressionParitySpec extends SparkParityBase {

  val labels = Seq(0.0, 1.0, 2.0, 0.0, 1.0, 2.0)
  val ages = Seq(15, 30, 40, 50, 15, 80)
  val heights = Seq(175, 190, 155, 160, 170, 180)
  val weights = Seq(67, 100, 57, 56, 56, 88)

  val rows = spark.sparkContext.parallelize(Seq.tabulate(6) { i => Row(labels(i), ages(i), heights(i), weights(i)) })
  val schema = new StructType().add("label", DoubleType, nullable = false)
    .add("age", IntegerType, nullable = false)
    .add("height", IntegerType, nullable = false)
    .add("weight", IntegerType, nullable = false)

  override val dataset: DataFrame = spark.sqlContext.createDataFrame(rows, schema)

  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(
    new VectorAssembler().
      setInputCols(Array("age", "height", "weight")).
      setOutputCol("features"),
    new LogisticRegressionModel(uid = "logr", 
      coefficientMatrix = Matrices.dense(3, 3, Array(-1.3920551604166562, -0.13119545493644366, 1.5232506153530998, 0.3129112131192873, -0.21959056436528473, -0.09332064875400257, -0.24696506013528507, 0.6122879917796569, -0.36532293164437174)),
      interceptVector = Vectors.dense(0.4965574044951358, -2.1486146169780063, 1.6520572124828703),
      numClasses = 3, isMultinomial = true))).fit(dataset)
} 
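The pattern of pairing Rows with an explicit schema generalizes beyond this parity test. A minimal sketch, assuming a SparkSession named spark:

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}

val schema = new StructType()
  .add("label", DoubleType, nullable = false)
  .add("age", IntegerType, nullable = false)
// age is declared as IntegerType, so the Row values must be Ints, not Longs.
val rows = spark.sparkContext.parallelize(Seq(Row(0.0, 15), Row(1.0, 30)))
val df = spark.createDataFrame(rows, schema)
df.show()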
Example 174
Source File: WrappersSpec.scala    From sparksql-scalapb   with Apache License 2.0 5 votes vote down vote up
package scalapb.spark

import com.example.protos.wrappers._
import org.apache.spark.sql.SparkSession
import org.apache.hadoop.io.ArrayPrimitiveWritable
import scalapb.GeneratedMessageCompanion
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.Row

import org.scalatest.BeforeAndAfterAll
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.must.Matchers

class WrappersSpec extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
  val spark: SparkSession = SparkSession
    .builder()
    .appName("ScalaPB Demo")
    .master("local[2]")
    .getOrCreate()

  import spark.implicits.StringToColumn

  val data = Seq(
    PrimitiveWrappers(
      intValue = Option(45),
      stringValue = Option("boo"),
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    ),
    PrimitiveWrappers(
      intValue = None,
      stringValue = None,
      ints = Seq(17, 19, 25),
      strings = Seq("foo", "bar")
    )
  )

  "converting df with primitive wrappers" should "work with primitive implicits" in {
    import ProtoSQL.withPrimitiveWrappers.implicits._
    val df = ProtoSQL.withPrimitiveWrappers.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        IntegerType,
        StringType,
        ArrayType(IntegerType, false),
        ArrayType(StringType, false)
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(45, "boo", Seq(17, 19, 25), Seq("foo", "bar")),
        Row(null, null, Seq(17, 19, 25), Seq("foo", "bar"))
      )
    )
  }

  "converting df with primitive wrappers" should "work with default implicits" in {
    import ProtoSQL.implicits._
    val df = ProtoSQL.createDataFrame(spark, data)
    df.schema.fields.map(_.dataType).toSeq must be(
      Seq(
        StructType(Seq(StructField("value", IntegerType, true))),
        StructType(Seq(StructField("value", StringType, true))),
        ArrayType(
          StructType(Seq(StructField("value", IntegerType, true))),
          false
        ),
        ArrayType(
          StructType(Seq(StructField("value", StringType, true))),
          false
        )
      )
    )
    df.collect must contain theSameElementsAs (
      Seq(
        Row(
          Row(45),
          Row("boo"),
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        ),
        Row(
          null,
          null,
          Seq(Row(17), Row(19), Row(25)),
          Seq(Row("foo"), Row("bar"))
        )
      )
    )
  }
} 
Example 175
Source File: KustoSourceTests.scala    From azure-kusto-spark   with Apache License 2.0 5 votes vote down vote up
package com.microsoft.kusto.spark

import com.microsoft.kusto.spark.datasource.KustoSourceOptions
import com.microsoft.kusto.spark.utils.{KustoDataSourceUtils => KDSU}
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, SparkSession}
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class KustoSourceTests extends FlatSpec with MockFactory with Matchers with BeforeAndAfterAll {
  private val loggingLevel: Option[String] = Option(System.getProperty("logLevel"))
  if (loggingLevel.isDefined) KDSU.setLoggingLevel(loggingLevel.get)

  private val nofExecutors = 4
  private val spark: SparkSession = SparkSession.builder()
    .appName("KustoSource")
    .master(f"local[$nofExecutors]")
    .getOrCreate()

  private var sc: SparkContext = _
  private var sqlContext: SQLContext = _
  private val cluster: String = "KustoCluster"
  private val database: String = "KustoDatabase"
  private val query: String = "KustoTable"
  private val appId: String = "KustoSinkTestApplication"
  private val appKey: String = "KustoSinkTestKey"
  private val appAuthorityId: String = "KustoSinkAuthorityId"

  override def beforeAll(): Unit = {
    super.beforeAll()

    sc = spark.sparkContext
    sqlContext = spark.sqlContext
  }

  override def afterAll(): Unit = {
    super.afterAll()

    sc.stop()
  }

  "KustoDataSource" should "recognize Kusto and get the correct schema" in {
    val spark: SparkSession = SparkSession.builder()
      .appName("KustoSource")
      .master(f"local[$nofExecutors]")
      .getOrCreate()

    val customSchema = "colA STRING, colB INT"

    val df = spark.sqlContext
      .read
      .format("com.microsoft.kusto.spark.datasource")
      .option(KustoSourceOptions.KUSTO_CLUSTER, cluster)
      .option(KustoSourceOptions.KUSTO_DATABASE, database)
      .option(KustoSourceOptions.KUSTO_QUERY, query)
      .option(KustoSourceOptions.KUSTO_AAD_APP_ID, appId)
      .option(KustoSourceOptions.KUSTO_AAD_APP_SECRET, appKey)
      .option(KustoSourceOptions.KUSTO_AAD_AUTHORITY_ID, appAuthorityId)
      .option(KustoSourceOptions.KUSTO_CUSTOM_DATAFRAME_COLUMN_TYPES, customSchema)
      .load("src/test/resources/")

    val expected = StructType(Array(StructField("colA", StringType, nullable = true),StructField("colB", IntegerType, nullable = true)))
    assert(df.schema.equals(expected))
  }
} 
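The DDL string used for the custom column types maps onto Spark types in the usual way. A minimal sketch showing the equivalent StructType, where INT becomes IntegerType:

import org.apache.spark.sql.types.StructType

val parsed = StructType.fromDDL("colA STRING, colB INT")
// StructType(StructField(colA,StringType,true), StructField(colB,IntegerType,true))
println(parsed)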
Example 176
Source File: ShortestPaths.scala    From graphframes   with Apache License 2.0 5 votes vote down vote up
package org.graphframes.lib

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, MapType}

import org.graphframes.GraphFrame


  def landmarks(value: util.ArrayList[Any]): this.type = {
    landmarks(value.asScala)
  }

  def run(): DataFrame = {
    ShortestPaths.run(graph, check(lmarks, "landmarks"))
  }
}

private object ShortestPaths {

  private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
    val idType = graph.vertices.schema(GraphFrame.ID).dataType
    val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap
    val gx = graphxlib.ShortestPaths.run(
      graph.cachedTopologyGraphX,
      longIdToLandmark.keys.toSeq.sorted).mapVertices { case (_, m) => m.toSeq }
    val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID))
    val distanceCol: Column = if (graph.hasIntegralIdType) {
      // It seems there is no easy way to convert a sequence of pairs into a map
      val mapToLandmark = udf { distances: Seq[Row] =>
        distances.map { case Row(k: Long, v: Int) =>
          k -> v
        }.toMap
      }
      mapToLandmark(g.vertices(DISTANCE_ID))
    } else {
      val func = new UDF1[Seq[Row], Map[Any, Int]] {
        override def call(t1: Seq[Row]): Map[Any, Int] = {
          t1.map { case Row(k: Long, v: Int) =>
              longIdToLandmark(k) -> v
          }.toMap
        }
      }
      val mapToLandmark = udf(func, MapType(idType, IntegerType, false))
      mapToLandmark(col(DISTANCE_ID))
    }
    val cols = graph.vertices.columns.map(col) :+ distanceCol.as(DISTANCE_ID)
    g.vertices.select(cols: _*)
  }

  private val DISTANCE_ID = "distances"

} 
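From the caller's side the algorithm is invoked through the GraphFrame API, and the returned distances column maps landmark ids to IntegerType hop counts. A minimal sketch, assuming an existing GraphFrame named graph with vertex ids "a" and "d":

// distances is a MapType(idType, IntegerType) column.
val results = graph.shortestPaths.landmarks(Seq("a", "d")).run()
results.select("id", "distances").show(false)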
Example 177
Source File: LastValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class LastValueOperatorTest extends WordSpec with Matchers {

  "LastValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new LastValueOperator("lastValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b"))
    }

    "associative process must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))

      val inputFields4 = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new LastValueOperator("lastValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 178
Source File: StddevOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(
        Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(
        Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }

    "processReduce distinct must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(
        Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(
        Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }
  }
} 
Example 179
Source File: MedianOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.median

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MedianOperatorTest extends WordSpec with Matchers {

  "Median operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MedianOperator("median", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d))

      val inputFields3 = new MedianOperator("median", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5))

      val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5"))
    }
  }
} 
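
For reference, the median expected above is the middle element of the sorted inputs, or the mean of the two middle elements when the count is even; that is why the distinct case over 1, 1, 2, 3, 7, 7 reduces to 1, 2, 3, 7 and yields 2.5. The sketch below is plain Scala for checking those numbers, not the Sparta implementation.

// Median of a numeric sequence, shown only to verify the expected test values.
def median(xs: Seq[Double]): Double = {
  val sorted = xs.sorted
  val n = sorted.size
  if (n % 2 == 1) sorted(n / 2) else (sorted(n / 2 - 1) + sorted(n / 2)) / 2
}

median(Seq(1.0, 2.0, 3.0, 7.0, 7.0))                // 3.0
median(Seq(1.0, 1.0, 2.0, 3.0, 7.0, 7.0).distinct)  // (2 + 3) / 2 = 2.5
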
Example 180
Source File: ModeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mode

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ModeOperatorTest extends WordSpec with Matchers {

  "Mode operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new ModeOperator("mode", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new ModeOperator("mode", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new ModeOperator("mode", initSchema, Map())
      inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey")))

      val inputFields3 = new ModeOperator("mode", initSchema, Map())
      inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1")))

      val inputFields4 = new ModeOperator("mode", initSchema, Map())
      inputFields4.processReduce(Seq(
        Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4")))

      val inputFields5 = new ModeOperator("mode", initSchema, Map())
      inputFields5.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4")))

      val inputFields6 = new ModeOperator("mode", initSchema, Map())
      inputFields6.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5"))
      ) should be(Some(List("1", "2", "4")))
    }
  }
} 
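
The mode assertions above return every value whose frequency ties for the maximum, which is why a three-way tie produces List("1", "2", "4") and an extra singleton ("5") does not change the result. A rough, hedged equivalent in plain Scala (not the Sparta ModeOperator itself):

// All values sharing the maximum frequency; ties are preserved, matching the tests above.
def modes[T](xs: Seq[T]): Seq[T] =
  if (xs.isEmpty) Seq.empty
  else {
    val counts = xs.groupBy(identity).map { case (value, group) => value -> group.size }
    val max = counts.values.max
    counts.collect { case (value, count) if count == max => value }.toSeq
  }

modes(Seq("1", "1", "2", "2", "4", "4", "5"))   // List("1", "2", "4"), in map order
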
Example 181
Source File: RangeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
} 
Example 182
Source File: AccumulatorOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.accumulator

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class AccumulatorOperatorTest extends WordSpec with Matchers {

  "Accumulator operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new AccumulatorOperator("accumulator", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1")))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b")))
    }

    "associative process must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))),
        (Operator.NewValuesKey, Some(Seq(2L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Seq("1", "2")))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(3))))
      inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d)))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(1))))
      inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1")))
    }
  }
} 
Example 183
Source File: FirstValueOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FirstValueOperatorTest extends WordSpec with Matchers {

  "FirstValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FirstValueOperator("firstValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a"))
    }

    "associative process must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(1)),
        (Operator.NewValuesKey, None))
      inputFields3.associativity(resultInput3) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 184
Source File: MeanAssociativeOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanAssociativeOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanAssociativeOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be (Some(List(1.0, 1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "associative process must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput =
        Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0)))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))),
        (Operator.NewValuesKey, Some(Seq(1d))))
      inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5)))
    }
  }
} 
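
In the associative test above, the operator's running state is a map of count, sum and mean; merging a batch of new values adds to the count and the sum and recomputes mean = sum / count, which is how {count: 1, sum: 2} plus the value 1.0 becomes {count: 2.0, sum: 3.0, mean: 1.5}. The sketch below only reproduces that arithmetic and is not the Sparta SDK implementation.

// Fold a batch of new values into the running (count, sum, mean) aggregate; illustrative only.
def mergeMean(old: Map[String, Double], newValues: Seq[Double]): Map[String, Double] = {
  val count = old("count") + newValues.size
  val sum = old("sum") + newValues.sum
  Map("count" -> count, "sum" -> sum, "mean" -> sum / count)
}

mergeMean(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d), Seq(1d))
// Map(count -> 2.0, sum -> 3.0, mean -> 1.5)
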
Example 185
Source File: MeanOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.mean

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }
  }
} 
Example 186
Source File: EntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.entityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class EntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new EntityCountOperator("entityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo")))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola")))
    }

    "associative process must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, Some(Seq("hola"))))
      inputFields2.associativity(resultInput2) should be(Some(Map()))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))))
      inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))
    }
  }
} 
Example 187
Source File: SumOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.sum

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class SumOperatorTest extends WordSpec with Matchers {

  "Sum operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new SumOperator("sum", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new SumOperator("sum", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(20d))

      val inputFields3 = new SumOperator("sum", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(20d))

      val inputFields4 = new SumOperator("sum", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))
    }

    "processReduce distinct must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(1))) should be(Some(3d))
    }

    "associative process must be " in {
      val inputFields = new SumOperator("count", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2d))

      val inputFields2 = new SumOperator("count", initSchema, Map("typeOp" -> "string"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields2.associativity(resultInput2) should be(Some("2.0"))
    }
  }
} 
Example 188
Source File: FullTextOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.fullText

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FullTextOperatorTest extends WordSpec with Matchers {

  "FullText operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FullTextOperator("fullText", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FullTextOperator("fullText", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(""))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(s"1${Operator.SpaceSeparator}1"))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(s"a${Operator.SpaceSeparator}b"))
    }

    "associative process must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some("2"))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> "arraystring"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(Seq(s"2${Operator.SpaceSeparator}1")))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)), (Operator.OldValuesKey, Some(3)))
      inputFields3.associativity(resultInput3) should be(Some(s"2${Operator.SpaceSeparator}3"))
    }
  }
} 
Example 189
Source File: TotalEntityCountOperatorTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.operator.totalEntityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class TotalEntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be
      (Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(2L))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(3L))

      val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0L))
    }

    "processReduce distinct must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(2L))
    }

    "associative process must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(3))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))
    }
  }
} 
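
The reduce phase above flattens the token lists produced by processMap and counts them, and the distinct variant counts unique tokens, which is why Seq("hola", "holo", "hola") gives 3L in the plain case and 2L with "distinct" -> "true". Sketched without the operator:

// Plain vs. distinct token counting, mirroring the expected reduce values above.
val tokens = Seq("hola", "holo", "hola")
tokens.size            // 3
tokens.distinct.size   // 2
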
Example 190
Source File: StatisticsTest.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources.oap.statistics

import java.io.ByteArrayOutputStream

import scala.collection.mutable.ArrayBuffer

import org.scalatest.BeforeAndAfterEach

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.BaseOrdering
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
import org.apache.spark.sql.execution.datasources.oap.filecache.FiberCache
import org.apache.spark.sql.execution.datasources.oap.index.RangeInterval
import org.apache.spark.sql.execution.datasources.oap.utils.{NonNullKeyReader, NonNullKeyWriter}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.memory.MemoryBlock
import org.apache.spark.unsafe.types.UTF8String

abstract class StatisticsTest extends SparkFunSuite with BeforeAndAfterEach {

  protected def rowGen(i: Int): InternalRow = InternalRow(i, UTF8String.fromString(s"test#$i"))

  protected lazy val schema: StructType = StructType(StructField("a", IntegerType)
    :: StructField("b", StringType) :: Nil)
  @transient
  protected lazy val nnkw: NonNullKeyWriter = new NonNullKeyWriter(schema)
  @transient
  protected lazy val nnkr: NonNullKeyReader = new NonNullKeyReader(schema)
  @transient
  protected lazy val ordering: BaseOrdering = GenerateOrdering.create(schema)
  @transient
  protected lazy val partialOrdering: BaseOrdering =
    GenerateOrdering.create(StructType(schema.dropRight(1)))
  protected var out: ByteArrayOutputStream = _

  protected var intervalArray: ArrayBuffer[RangeInterval] = new ArrayBuffer[RangeInterval]()

  override def beforeEach(): Unit = {
    out = new ByteArrayOutputStream(8000)
  }

  override def afterEach(): Unit = {
    out.close()
    intervalArray.clear()
  }

  protected def generateInterval(
      start: InternalRow, end: InternalRow,
      startInclude: Boolean, endInclude: Boolean): Unit = {
    intervalArray.clear()
    intervalArray.append(new RangeInterval(start, end, startInclude, endInclude))
  }

  protected def checkInternalRow(row1: InternalRow, row2: InternalRow): Unit = {
    val res = row1 == row2 // it works..
    assert(res, s"row1: $row1 does not match $row2")
  }

  protected def wrapToFiberCache(out: ByteArrayOutputStream): FiberCache = {
    val bytes = out.toByteArray
    FiberCache(bytes)
  }
} 
Example 191
Source File: MLPipelineTrackerIT.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.ml

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.Matchers
import com.hortonworks.spark.atlas._
import com.hortonworks.spark.atlas.types._
import com.hortonworks.spark.atlas.TestUtils._

class MLPipelineTrackerIT extends BaseResourceIT with Matchers with WithHiveSupport {
  private val atlasClient = new RestAtlasClient(atlasClientConf)

  def clusterName: String = atlasClientConf.get(AtlasClientConf.CLUSTER_NAME)

  def getTableEntity(tableName: String): SACAtlasEntityWithDependencies = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)
    internal.sparkTableToEntity(tableDefinition, clusterName, Some(dbDefinition))
  }

  // Enable this to run the integration test.
  it("pipeline and pipeline model") {
    val uri = "hdfs://"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, modelDirEntity))

    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("features_scaled")
      .setMin(0.0)
      .setMax(3.0)
    val pipeline = new Pipeline().setStages(Array(scaler))

    val model = pipeline.fit(df)

    pipeline.write.overwrite().save(pipelineDir)

    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(pipelineDirEntity, pipelineEntity))

    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)

    atlasClient.createEntitiesWithDependencies(Seq(modelDirEntity, modelEntity))

    val tableEntities1 = getTableEntity("chris1")
    val tableEntities2 = getTableEntity("chris2")

    atlasClient.createEntitiesWithDependencies(tableEntities1)
    atlasClient.createEntitiesWithDependencies(tableEntities2)

  }
} 
Example 192
Source File: MLAtlasEntityUtilsSuite.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.types

import java.io.File

import org.apache.atlas.{AtlasClient, AtlasConstants}
import org.apache.atlas.model.instance.AtlasEntity
import org.apache.commons.io.FileUtils
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
import org.scalatest.{FunSuite, Matchers}
import com.hortonworks.spark.atlas.TestUtils._
import com.hortonworks.spark.atlas.{AtlasUtils, WithHiveSupport}

class MLAtlasEntityUtilsSuite extends FunSuite with Matchers with WithHiveSupport {

  def getTableEntity(tableName: String): AtlasEntity = {
    val dbDefinition = createDB("db1", "hdfs:///test/db/db1")
    val sd = createStorageFormat()
    val schema = new StructType()
      .add("user", StringType, false)
      .add("age", IntegerType, true)
    val tableDefinition = createTable("db1", s"$tableName", schema, sd)

    val tableEntities = internal.sparkTableToEntity(
      tableDefinition, AtlasConstants.DEFAULT_CLUSTER_NAME, Some(dbDefinition))
    val tableEntity = tableEntities.entity

    tableEntity
  }

  test("pipeline, pipeline model, fit and transform") {
    val uri = "/"
    val pipelineDir = "tmp/pipeline"
    val modelDir = "tmp/model"

    val pipelineDirEntity = internal.mlDirectoryToEntity(uri, pipelineDir)
    pipelineDirEntity.entity.getAttribute("uri") should be (uri)
    pipelineDirEntity.entity.getAttribute("directory") should be (pipelineDir)
    pipelineDirEntity.dependencies.length should be (0)

    val modelDirEntity = internal.mlDirectoryToEntity(uri, modelDir)
    modelDirEntity.entity.getAttribute("uri") should be (uri)
    modelDirEntity.entity.getAttribute("directory") should be (modelDir)
    modelDirEntity.dependencies.length should be (0)

    val df = sparkSession.createDataFrame(Seq(
      (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
      (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
      (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
      (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)
    )).toDF("id", "features", "label")

    val scaler = new MinMaxScaler()
      .setInputCol("features")
      .setOutputCol("features_scaled")
      .setMin(0.0)
      .setMax(3.0)
    val pipeline = new Pipeline().setStages(Array(scaler))

    val model = pipeline.fit(df)

    pipeline.write.overwrite().save(pipelineDir)

    val pipelineEntity = internal.mlPipelineToEntity(pipeline.uid, pipelineDirEntity)
    pipelineEntity.entity.getTypeName should be (metadata.ML_PIPELINE_TYPE_STRING)
    pipelineEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (
      pipeline.uid)
    pipelineEntity.entity.getAttribute("name") should be (pipeline.uid)
    pipelineEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(pipelineDirEntity.entity, useGuid = false))
    pipelineEntity.dependencies should be (Seq(pipelineDirEntity))

    val modelEntity = internal.mlModelToEntity(model.uid, modelDirEntity)
    val modelUid = model.uid.replaceAll("pipeline", "model")
    modelEntity.entity.getTypeName should be (metadata.ML_MODEL_TYPE_STRING)
    modelEntity.entity.getAttribute(AtlasClient.REFERENCEABLE_ATTRIBUTE_NAME) should be (modelUid)
    modelEntity.entity.getAttribute("name") should be (modelUid)
    modelEntity.entity.getRelationshipAttribute("directory") should be (
      AtlasUtils.entityToReference(modelDirEntity.entity, useGuid = false))

    modelEntity.dependencies should be (Seq(modelDirEntity))

    FileUtils.deleteDirectory(new File("tmp"))
  }
} 
Example 193
Source File: ExtraOperationsSuite.scala    From tensorframes   with Apache License 2.0 5 votes vote down vote up
package org.tensorframes

import org.apache.spark.sql.types.{DoubleType, IntegerType}
import org.scalatest.FunSuite
import org.tensorframes.impl.{ScalarDoubleType, ScalarIntType}


class ExtraOperationsSuite
    extends FunSuite with TensorFramesTestSparkContext with Logging {
  lazy val sql = sqlContext
  import ExtraOperations._
  import sql.implicits._
  import Shape.Unknown

  test("simple test for doubles") {
    val df = Seq(Tuple1(0.0)).toDF("a")
    val di = ExtraOperations.explainDetailed(df)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    assert(s.shape === Shape(Unknown))
    logDebug(df.toString() + "->" + di.toString)
  }

  test("simple test for integers") {
    val df = Seq(Tuple1(0)).toDF("a")
    val di = explainDetailed(df)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarIntType)
    assert(s.shape === Shape(Unknown))
    logDebug(df.toString() + "->" + di.toString)
  }

  test("test for arrays") {
    val df = Seq((0.0, Seq(1.0), Seq(Seq(1.0)))).toDF("a", "b", "c")
    val di = explainDetailed(df)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2, c3) = di.cols
    val Some(s1) = c1.stf
    assert(s1.dataType === ScalarDoubleType)
    assert(s1.shape === Shape(Unknown))
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(Unknown, Unknown))
    val Some(s3) = c3.stf
    assert(s3.dataType === ScalarDoubleType)
    assert(s3.shape === Shape(Unknown, Unknown, Unknown))
  }

  test("simple analysis") {
    val df = Seq(Tuple1(0.0)).toDF("a")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    assert(s.shape === Shape(1)) // There is only one partition
  }

  test("simple analysis with multiple partitions of different sizes") {
    val df = Seq.fill(10)(0.0).map(Tuple1.apply).toDF("a").repartition(3)
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1) = di.cols
    val Some(s) = c1.stf
    assert(s.dataType === ScalarDoubleType)
    assert(s.shape === Shape(Unknown)) // partitions have different sizes, so the leading dimension is unknown
  }

  test("simple analysis with variable sizes") {
    val df = Seq(
      (0.0, Seq(0.0)),
      (1.0, Seq(1.0, 1.0))).toDF("a", "b")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2) = di.cols
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(2, Unknown)) // Two rows; the inner sequence length varies, so it stays unknown
  }

  test("2nd order analysis") {
    val df = Seq(
      (0.0, Seq(0.0, 0.0)),
      (1.0, Seq(1.0, 1.0)),
      (2.0, Seq(2.0, 2.0))).toDF("a", "b")
    val df2 = analyze(df)
    val di = explainDetailed(df2)
    logDebug(df.toString() + "->" + di.toString)
    val Seq(c1, c2) = di.cols
    val Some(s2) = c2.stf
    assert(s2.dataType === ScalarDoubleType)
    assert(s2.shape === Shape(3, 2)) // Three rows, each with two elements
  }
} 
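The ScalarIntType assertion in the integer test above is tensorframes' view of a plain Spark SQL IntegerType column. As a cross-check written only against Spark itself (the object name and local-mode session are assumptions for illustration):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.IntegerType

// Sketch only: the plain Spark SQL type of the column that the test above inspects through
// tensorframes. Seq(Tuple1(0)).toDF("a") yields a single IntegerType column named "a".
object IntColumnSchemaCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("IntColumnSchemaCheck").getOrCreate()
    import spark.implicits._
    val df = Seq(Tuple1(0)).toDF("a")
    assert(df.schema("a").dataType == IntegerType)
    spark.stop()
  }
}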
Example 194
Source File: SlicingSuite.scala    From tensorframes   with Apache License 2.0 5 votes vote down vote up
package org.tensorframes

import org.scalatest.FunSuite
import org.tensorframes.dsl.GraphScoping
import org.tensorframes.impl.DebugRowOps
import org.tensorframes.{dsl => tf}
import org.tensorframes.dsl.Implicits._

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{DoubleType, IntegerType}


class SlicingSuite
  extends FunSuite with TensorFramesTestSparkContext with Logging with GraphScoping {
  lazy val sql = sqlContext

  import Shape.Unknown

  val ops = new DebugRowOps

  test("2D - 1") {
    val df = make1(Seq(Seq(1.0, 2.0), Seq(3.0, 4.0)), "x")
    val x = df.block("x")
//    val y =
  }
} 
Example 195
package org.sparksamples

import java.awt.Font

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.jfree.chart.axis.CategoryLabelPositions

import scalax.chart.module.ChartFactories

// NOTE: the enclosing object and main declaration were missing from this listing;
// the names below are assumed so that the example compiles as a standalone program.
object RatingCountChart {

  def main(args: Array[String]) {
    val customSchema = StructType(Array(
      StructField("user_id", IntegerType, true),
      StructField("movie_id", IntegerType, true),
      StructField("rating", IntegerType, true),
      StructField("timestamp", IntegerType, true)))

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession
      .builder()
      .appName("SparkRatingData").config(spConfig)
      .getOrCreate()

    val rating_df = spark.read.format("com.databricks.spark.csv")
      .option("delimiter", "\t").schema(customSchema)
      .load("../../data/ml-100k/u.data")

    val rating_df_count = rating_df.groupBy("rating").count().sort("rating")
    //val rating_df_count_sorted = rating_df_count.sort("count")
    rating_df_count.show()
    val rating_df_count_collection = rating_df_count.collect()

    val ds = new org.jfree.data.category.DefaultCategoryDataset
    val mx = scala.collection.immutable.ListMap()

    for (x <- 0 until rating_df_count_collection.length) {
      val rating = rating_df_count_collection(x)(0)
      val count = Integer.parseInt(rating_df_count_collection(x)(1).toString)
      ds.addValue(count, "Rating Values", rating.toString)
    }


    //val sorted =  ListMap(ratings_count.toSeq.sortBy(_._1):_*)
    //val ds = new org.jfree.data.category.DefaultCategoryDataset
    //sorted.foreach{ case (k,v) => ds.addValue(v,"Rating Values", k)}

    val chart = ChartFactories.BarChart(ds)
    val font = new Font("Dialog", Font.PLAIN, 5)

    chart.peer.getCategoryPlot.getDomainAxis()
      .setCategoryLabelPositions(CategoryLabelPositions.UP_90)
    chart.peer.getCategoryPlot.getDomainAxis.setLabelFont(font)
    chart.show()
    Util.sc.stop()
  }
} 
Example 196
package org.sparksamples

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

import scalax.chart.module.ChartFactories


object UserRatingsChart {

  def main(args: Array[String]) {

    val customSchema = StructType(Array(
      StructField("user_id", IntegerType, true),
      StructField("movie_id", IntegerType, true),
      StructField("rating", IntegerType, true),
      StructField("timestamp", IntegerType, true)))

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession
      .builder()
      .appName("SparkRatingData").config(spConfig)
      .getOrCreate()

    val rating_df = spark.read.format("com.databricks.spark.csv")
      .option("delimiter", "\t").schema(customSchema)
      .load("../../data/ml-100k/u.data")

    val rating_nos_by_user = rating_df.groupBy("user_id").count().sort("count")
    val ds = new org.jfree.data.category.DefaultCategoryDataset
    rating_nos_by_user.show(rating_nos_by_user.collect().length)
    val rating_nos_by_user_collect = rating_nos_by_user.collect()

    var mx = Map(0 -> 0)
    val min = 1
    val max = 1000
    val bins = 100

    val step = max / bins
    for (i <- step until (max + step) by step) {
      mx += (i -> 0)
    }
    for (x <- 0 until rating_nos_by_user_collect.length) {
      val user_id = Integer.parseInt(rating_nos_by_user_collect(x)(0).toString)
      val count = Integer.parseInt(rating_nos_by_user_collect(x)(1).toString)
      ds.addValue(count, "Ratings", user_id)
    }
   // ------------------------------------------------------------------

    val chart = ChartFactories.BarChart(ds)
    chart.peer.getCategoryPlot.getDomainAxis().setVisible(false)

    chart.show()
    Util.sc.stop()
  }
} 
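Both charting examples above read values out of collected Rows with Integer.parseInt(row(i).toString). A hedged alternative, assuming the schemas shown (an IntegerType key plus the LongType column that count() produces), is to use Row's typed accessors; a minimal standalone sketch with a hand-built Row:

import org.apache.spark.sql.Row

// Sketch only: a Row shaped like one element of groupBy("rating").count().collect(),
// i.e. an IntegerType key followed by the LongType count. The values here are made up.
object TypedRowAccess {
  def main(args: Array[String]): Unit = {
    val row: Row = Row(5, 21201L)
    val rating: Int = row.getInt(0)  // the key column was declared IntegerType
    val count: Long = row.getLong(1) // Spark's count() column is LongType
    println(s"rating=$rating count=$count")
  }
}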
Example 197
Source File: UserData.scala    From Machine-Learning-with-Spark-Second-Edition   with MIT License 5 votes vote down vote up
package org.sparksamples.df
//import org.apache.spark.sql.SQLContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

object UserData {
  def main(args: Array[String]): Unit = {
    val customSchema = StructType(Array(
      StructField("no", IntegerType, true),
      StructField("age", StringType, true),
      StructField("gender", StringType, true),
      StructField("occupation", StringType, true),
      StructField("zipCode", StringType, true)));
    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession
      .builder()
      .appName("SparkUserData").config(spConfig)
      .getOrCreate()

    val user_df = spark.read.format("com.databricks.spark.csv")
      .option("delimiter", "|").schema(customSchema)
      .load("/home/ubuntu/work/ml-resources/spark-ml/data/ml-100k/u.user")
    val first = user_df.first()
    println("First Record : " + first)

    val num_genders = user_df.groupBy("gender").count().count()
    val num_occupations = user_df.groupBy("occupation").count().count()
    val num_zipcodes = user_df.groupBy("zipCode").count().count()

    println("num_users : " + user_df.count())
    println("num_genders : "+ num_genders)
    println("num_occupations : "+ num_occupations)
    println("num_zipcodes: " + num_zipcodes)
    println("Distribution by Occupation")
    println(user_df.groupBy("occupation").count().show())

  }
} 
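UserData reads the file through the external com.databricks.spark.csv package. Because it already uses a Spark 2.x SparkSession, the built-in CSV source can do the same read; a minimal sketch in which the object name and input path are placeholders:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

// Sketch only: the same pipe-delimited read as UserData above, via spark.read.csv.
object UserDataBuiltinCsv {
  def main(args: Array[String]): Unit = {
    val customSchema = StructType(Array(
      StructField("no", IntegerType, true),
      StructField("age", StringType, true),
      StructField("gender", StringType, true),
      StructField("occupation", StringType, true),
      StructField("zipCode", StringType, true)))

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession.builder()
      .appName("SparkUserDataBuiltinCsv").config(spConfig)
      .getOrCreate()

    val user_df = spark.read
      .option("delimiter", "|")        // the built-in reader also accepts "sep"
      .schema(customSchema)
      .csv("/path/to/ml-100k/u.user")  // placeholder path
    user_df.groupBy("occupation").count().show()

    spark.stop()
  }
}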
Example 198
Source File: PrettifyTest.scala    From spark-testing-base   with Apache License 2.0 5 votes vote down vote up
package com.holdenkarau.spark.testing

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalacheck.Gen
import org.scalacheck.Prop._
import org.scalacheck.util.Pretty
import org.scalatest.FunSuite
import org.scalatest.exceptions.GeneratorDrivenPropertyCheckFailedException
import org.scalatest.prop.Checkers

class PrettifyTest extends FunSuite with SharedSparkContext with Checkers with Prettify {
  implicit val propertyCheckConfig = PropertyCheckConfig(minSize = 2, maxSize = 2)

  test("pretty output of DataFrame's check") {
    val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
    val sqlContext = new SQLContext(sc)
    val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
    val ageGenerator = new Column("age", Gen.const(20))

    val dataframeGen = DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)

    val actual = runFailingCheck(dataframeGen.arbitrary)
    val expected =
      Some("arg0 = <DataFrame: schema = [name: string, age: int], size = 2, values = ([Holden Hanafy,20], [Holden Hanafy,20])>")
    assert(actual == expected)
  }

  test("pretty output of RDD's check") {
    val rddGen = RDDGenerator.genRDD[(String, Int)](sc) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age
    }

    val actual = runFailingCheck(rddGen)
    val expected =
      Some("""arg0 = <RDD: size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)
  }

  test("pretty output of Dataset's check") {
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val datasetGen = DatasetGenerator.genDataset[(String, Int)](sqlContext) {
      for {
        name <- Gen.const("Holden Hanafy")
        age <- Gen.const(20)
      } yield name -> age
    }

    val actual = runFailingCheck(datasetGen)
    val expected =
      Some("""arg0 = <Dataset: schema = [_1: string, _2: int], size = 2, values = ((Holden Hanafy,20), (Holden Hanafy,20))>""")
    assert(actual == expected)
  }

  private def runFailingCheck[T](genUnderTest: Gen[T])(implicit p: T => Pretty) = {
    val property = forAll(genUnderTest)(_ => false)
    val e = intercept[GeneratorDrivenPropertyCheckFailedException] {
      check(property)
    }
    takeSecondToLastLine(e.message)
  }

  private def takeSecondToLastLine(msg: Option[String]) =
    msg.flatMap(_.split("\n").toList.reverse.tail.headOption.map(_.trim))

} 
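The age column above is generated with Gen.const(20). If a spread of IntegerType values is wanted instead, ScalaCheck's Gen.choose can be plugged into the same Column helper; a sketch that assumes the DataframeGenerator and Column APIs behave exactly as used in the test above:

import com.holdenkarau.spark.testing.{Column, DataframeGenerator}
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.scalacheck.{Gen, Prop}
import org.scalacheck.Prop._

// Sketch only: same generator wiring as PrettifyTest, but with a ranged IntegerType column.
// An SQLContext is assumed to be supplied by the caller, much as SharedSparkContext does above.
object RangedAgeColumn {
  def ageProperty(sqlContext: SQLContext): Prop = {
    val schema = StructType(List(StructField("name", StringType), StructField("age", IntegerType)))
    val nameGenerator = new Column("name", Gen.const("Holden Hanafy"))
    val ageGenerator = new Column("age", Gen.choose(18, 99)) // ranged ints instead of Gen.const(20)
    val dataframeGen =
      DataframeGenerator.arbitraryDataFrameWithCustomFields(sqlContext, schema)(nameGenerator, ageGenerator)
    forAll(dataframeGen.arbitrary) { df => df.schema("age").dataType == IntegerType }
  }
}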
Example 199
Source File: MetastoreRelationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class MetastoreRelationSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
  test("makeCopy and toJSON should work") {
    val table = CatalogTable(
      identifier = TableIdentifier("test", Some("db")),
      tableType = CatalogTableType.VIEW,
      storage = CatalogStorageFormat.empty,
      schema = StructType(StructField("a", IntegerType, true) :: Nil))
    val relation = MetastoreRelation("db", "test")(table, null)

    // No exception should be thrown
    relation.makeCopy(Array("db", "test"))
    // No exception should be thrown
    relation.toJSON
  }

  test("SPARK-17409: Do Not Optimize Query in CTAS (Hive Serde Table) More Than Once") {
    withTable("bar") {
      withTempView("foo") {
        sql("select 0 as id").createOrReplaceTempView("foo")
        // If we optimize the query in CTAS more than once, the following saveAsTable will fail
        // with the error: `GROUP BY position 0 is not in select list (valid range is [1, 1])`
        sql("CREATE TABLE bar AS SELECT * FROM foo group by id")
        checkAnswer(spark.table("bar"), Row(0) :: Nil)
        val tableMetadata = spark.sessionState.catalog.getTableMetadata(TableIdentifier("bar"))
        assert(tableMetadata.provider == Some("hive"), "the expected table is a Hive serde table")
      }
    }
  }
} 
Example 200
Source File: HiveClientSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.client

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.conf.HiveConf

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.types.IntegerType

class HiveClientSuite extends SparkFunSuite {
  private val clientBuilder = new HiveClientBuilder

  private val tryDirectSqlKey = HiveConf.ConfVars.METASTORE_TRY_DIRECT_SQL.varname

  test(s"getPartitionsByFilter returns all partitions when $tryDirectSqlKey=false") {
    val testPartitionCount = 5

    val storageFormat = CatalogStorageFormat(
      locationUri = None,
      inputFormat = None,
      outputFormat = None,
      serde = None,
      compressed = false,
      properties = Map.empty)

    val hadoopConf = new Configuration()
    hadoopConf.setBoolean(tryDirectSqlKey, false)
    val client = clientBuilder.buildClient(HiveUtils.hiveExecutionVersion, hadoopConf)
    client.runSqlHive("CREATE TABLE test (value INT) PARTITIONED BY (part INT)")

    val partitions = (1 to testPartitionCount).map { part =>
      CatalogTablePartition(Map("part" -> part.toString), storageFormat)
    }
    client.createPartitions(
      "default", "test", partitions, ignoreIfExists = false)

    val filteredPartitions = client.getPartitionsByFilter(client.getTable("default", "test"),
      Seq(EqualTo(AttributeReference("part", IntegerType)(), Literal(3))))

    assert(filteredPartitions.size == testPartitionCount)
  }
}
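For reference, the partition filter handed to getPartitionsByFilter above is an ordinary Catalyst expression over an IntegerType column; a minimal standalone sketch of building the same predicate (object name assumed):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.{BooleanType, IntegerType}

// Sketch only: the `part = 3` predicate from the test above, constructed on its own.
// AttributeReference(...)() allocates a fresh expression ID for the "part" column.
object PartitionFilterSketch {
  def main(args: Array[String]): Unit = {
    val partCol = AttributeReference("part", IntegerType)()
    val predicate = EqualTo(partCol, Literal(3))
    assert(predicate.dataType == BooleanType) // comparisons are boolean-valued
    println(predicate.sql)
  }
}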