org.scalatest.Assertions Scala Examples

The following examples show how to use org.scalatest.Assertions. Each example is taken from an open-source project; the source file, originating project, and license are listed above the code.
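Before the project examples, here is a minimal, self-contained sketch of the core Assertions API that these snippets rely on (assert with ===, assertResult, intercept, fail, withClue). It is illustrative only and not taken from any of the projects below; the object and method names are invented. Mixing Assertions into a plain object or trait makes these methods available outside a Suite, which is exactly how several of the examples below use it.

import org.scalatest.Assertions

// Illustrative helper: mixing in Assertions gives access to assert, assertResult,
// intercept, fail and withClue outside of a test suite.
object AssertionBasics extends Assertions {

  def demo(): Unit = {
    val parsed = "42".toInt
    assert(parsed === 42)              // === yields a detailed failure message
    assertResult(42)(parsed)           // expected-value-first style

    // intercept returns the caught exception so its contents can be inspected
    val thrown = intercept[NumberFormatException] {
      "not a number".toInt
    }
    assert(thrown.getMessage.contains("not a number"))

    // withClue prepends extra context to any failure raised in its body
    withClue(s"parsed value was $parsed:") {
      assert(parsed > 0)
    }

    if (parsed < 0) fail("unreachable: parsed value must be non-negative")
  }
}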
Example 1
Source File: SparkContextInfoSuite.scala    From BigDatalog   with Apache License 2.0
package org.apache.spark

import org.scalatest.Assertions
import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 2
Source File: CompareParamGrid.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl

import org.apache.spark.ml.param.ParamMap
import org.scalatest.{Assertions, Matchers}


trait CompareParamGrid extends Matchers with Assertions {

  
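  // Compares two param grids as sets of param bindings, ignoring the ordering of the
  // grids and of the params within each ParamMap; mismatches are reported with the
  // missing and extra values.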
  def gridCompare(g1: Array[ParamMap], g2: Array[ParamMap]): Unit = {
    val g1values = g1.toSet[ParamMap].map(_.toSeq.toSet)
    val g2values = g2.toSet[ParamMap].map(_.toSeq.toSet)
    matchTwoSets(g1values, g2values)
  }

  private def matchTwoSets[T](actual: Set[T], expected: Set[T]): Unit = {
    def stringify(set: Set[T]): String = {
      val list = set.toList
      val chunk = list take 10
      val strings = chunk.map(_.toString).sorted
      // join with a separator and append an ellipsis only when the set was truncated to its first 10 elements
      if (list.size > chunk.size) strings.mkString(", ") + ", ..." else strings.mkString(", ")
    }
    val missing = stringify(expected -- actual)
    val extra = stringify(actual -- expected)
    withClue(s"Missing:\n $missing\nExtra:\n$extra") {
      actual shouldBe expected
    }
  }
} 
Example 3
Source File: ScalaStyleValidationTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class ScalaStyleValidationTest extends FlatSpec with Matchers with Assertions {
  import scala.Throwable
  private def +(x: Int, y: Int) = x + y
  private def -(x: Int, y: Int) = x - y
  private def *(x: Int, y: Int) = x * y
  private def /(x: Int, y: Int) = x / y
  private def +-(x: Int, y: Int) = x + (-y)
  private def xx_=(y: Int) = println(s"setting xx to $y")

  "bad names" should "never happen" in {
    "def _=abc = ???" shouldNot compile
    true shouldBe true
  }

  "non-ascii" should "not be allowed" in {
//    "def ⇒ = ???" shouldNot compile // it does not even compile as a string
  }

} 
Example 4
Source File: RandomListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import com.salesforce.op.testkit.RandomList.{NormalGeolocation, UniformGeolocation}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomListTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPList[D]](
    g: RandomList[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.length < minLen) shouldBe 0
    segment count (_.value.length > maxLen) shouldBe 0
    segment foreach (list => list.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[Text, RandomList[String, TextList]] should "generate lists of strings" in {
    val sut = RandomList.ofTexts(RandomText.countries, 0, 4)
    check[String, TextList](sut, 0, 4, _.length > 0)

    (sut limit 7 map (_.value.toList)) shouldBe
      List(
        List("Madagascar", "Gondal", "Zephyria"),
        List("Holy Alliance"),
        List("North American Union"),
        List("Guatemala", "Estonia", "Kolechia"),
        List(),
        List("Myanmar", "Bhutan"),
        List("Equatorial Guinea")
      )
  }

  Spec[Date, RandomList[Long, DateList]] should "generate lists of dates" in {
    val dates = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDates(dates, 11, 22)
    var d0 = 0L
    check[Long, DateList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[DateTimeList, RandomList[Long, DateTimeList]] should "generate lists of datetimes" in {
    val datetimes = RandomIntegral.datetimes(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDateTimes(datetimes, 11, 22)
    var d0 = 0L
    check[Long, DateTimeList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[UniformGeolocation] should "generate uniformly distributed geolocations" in {
    val sut = RandomList.ofGeolocations
    val segment = sut limit numTries
    segment foreach (_.value.length shouldBe 3)
  }

  Spec[NormalGeolocation] should "generate geolocations around given point" in {
    for {accuracy <- GeolocationAccuracy.values} {
      val geolocation = RandomList.ofGeolocationsNear(37.444136, 122.163160, accuracy)
      val segment = geolocation limit numTries
      segment foreach (_.value.length shouldBe 3)
    }
  }
} 
Example 5
Source File: RandomIntegralTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class RandomIntegralTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142135L

  private def check[T <: Integral](
    g: RandomIntegral[T],
    predicate: Long => Boolean = _ => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    val numberOfEmpties = segment count (_.isEmpty)

    val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries

    withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") {
      math.abs(numberOfEmpties - expectedNumberOfEmpties) < 2 * math.sqrt(numTries) shouldBe true
    }

    val maybeValues = segment filterNot (_.isEmpty) map (_.value)
    val values = maybeValues collect { case Some(s) => s }

    values foreach (x => predicate(x) shouldBe true)

    withClue(s"number of distinct values = ${values.toSet.size}, expected close to ${maybeValues.size}") {
      // the expression below must actually be asserted, otherwise its result is silently discarded
      math.abs(maybeValues.size - values.toSet.size) < maybeValues.size / 20 shouldBe true
    }

  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers" in {
    val sut0 = RandomIntegral.integrals
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers in some range" in {
    val sut0 = RandomIntegral.integrals(100, 200)
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut, i => i >= 100 && i < 200)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Date]] should "generate dates" in {
    val sut = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.01, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }

  Spec[RandomIntegral[DateTime]] should "generate dates with times" in {
    val sut = RandomIntegral.datetimes(df.parse("08/24/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.001, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }
} 
Example 6
Source File: RandomSetTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.testkit

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomSetTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPSet[D]](
    g: RandomSet[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.size < minLen) shouldBe 0
    segment count (_.value.size > maxLen) shouldBe 0
    segment foreach (set => set.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  Spec[MultiPickList] should "generate multipicklists" in {
    val sut = RandomMultiPickList.of(RandomText.countries, maxLen = 5)

    check[String, MultiPickList](sut, 0, 5, _.nonEmpty)

    val expected = List(
      Set(),
      Set("Aldorria", "Palau", "Glubbdubdrib"),
      Set(),
      Set(),
      Set("Sweden", "Wuhu Islands", "Tuvalu")
    )

    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected

    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected
  }

} 
Example 7
Source File: AssertingSyntax.scala    From cats-effect-testing   with Apache License 2.0
package cats.effect.testing.scalatest

import cats.Functor
import cats.effect.Sync
import org.scalatest.{Assertion, Assertions, Succeeded}
import cats.implicits._


// Note: the enclosing trait and implicit class were omitted from this excerpt; they are
// reconstructed here (as an assumption about the original file) so that the F[_] type
// parameter and the `self` value the method calls `.attempt` on are in scope.
trait AssertingSyntax { self: Assertions =>

  implicit class Asserting[F[_], A](private val self: F[A]) {

    // Asserts that evaluating the F[A] fails with an exception of type E.
    def assertThrows[E <: Throwable](implicit F: Sync[F], ct: reflect.ClassTag[E]): F[Assertion] =
      self.attempt.flatMap {
        case Left(t: E) => F.pure(Succeeded: Assertion)
        case Left(t) =>
          F.delay(
            fail(
              s"Expected an exception of type ${ct.runtimeClass.getName} but got an exception: $t"
            )
          )
        case Right(a) =>
          F.delay(
            fail(s"Expected an exception of type ${ct.runtimeClass.getName} but got a result: $a")
          )
      }
  }
} 
Example 8
Source File: BagelSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark.bagel

import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.storage.StorageLevel

class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable

class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {

  var sc: SparkContext = _

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }

  test("halting by voting") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
    val msgs = sc.parallelize(Array[(String, TestMessage)]())
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("halting by message silence") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
    val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          val msgsOut =
            msgs match {
              case Some(ms) if (superstep < numSupersteps - 1) =>
                ms
              case _ =>
                Array[TestMessage]()
            }
        (new TestVertex(self.active, self.age + 1), msgsOut)
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("large number of iterations") {
    // This tests whether jobs with a large number of iterations finish in a reasonable time,
    // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
    failAfter(30 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 50
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }

  test("using non-default persistence level") {
    failAfter(10 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 20
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }
} 
Example 9
Source File: SparkContextInfoSuite.scala    From spark1.52   with Apache License 2.0
package org.apache.spark

import org.scalatest.Assertions
import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  // only returns RDDs that are marked as cached
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // getPersistentRDDs is still empty until the RDD is cached
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache() // mark the RDD for persistent caching
    assert(sc.getPersistentRDDs.size === 1)
    // the first value in the returned map is the RDD itself
    assert(sc.getPersistentRDDs.values.head === rdd)
  }
  // returns an immutable map
  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs // returns the RDDs that have been marked for persistence
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    // get the persistence storage level; the default is memory-only
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }
  // reports RDDInfo only for RDDs that actually persist data
  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1) // RDDInfo
    assert(sc.getRDDStorageInfo.head.isCached) // check that it is cached
    assert(sc.getRDDStorageInfo.head.memSize > 0) // memory size
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") { // report correct locations
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    println("===="+rddCreationSite)
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        println("==line==="+line.toInt )
        //assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 10
Source File: SparkContextInfoSuite.scala    From Spark-2.3.1   with Apache License 2.0
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 11
Source File: OpCountVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType.{IndCol, IndVal}
import com.salesforce.op.test.{TestFeatureBuilder, TestOpVectorMetadataBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpCountVectorizerTest extends FlatSpec with TestSparkContext {

  val data = Seq[(Real, TextList)](
    (Real(0), Seq("a", "b", "c").toTextList),
    (Real(1), Seq("a", "b", "b", "b", "a", "c").toTextList)
  )

  lazy val (ds, f1, f2) = TestFeatureBuilder(data)

  lazy val expected = Array[(Real, OPVector)](
    (Real(0), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)).toOPVector),
    (Real(1), Vectors.sparse(3, Array(0, 1, 2), Array(3.0, 2.0, 1.0)).toOPVector)
  )

  val f2vec = new OpCountVectorizer().setInput(f2).setVocabSize(3).setMinDF(2)

  Spec[OpCountVectorizerTest] should "convert array of strings into count vector" in {
    val transformedData = f2vec.fit(ds).transform(ds)
    val output = f2vec.getOutput()
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }

  it should "return a fitted vectorizer with the correct parameters" in {
    val fitted = f2vec.fit(ds)
    val vectorMetadata = fitted.getMetadata()
    val expectedMeta = TestOpVectorMetadataBuilder(
      f2vec,
      f2 -> List(IndVal(Some("b")), IndVal(Some("a")), IndVal(Some("c")))
    )
    // cannot just do equals because fitting is nondeterministic
    OpVectorMetadata(f2vec.getOutputFeatureName, vectorMetadata).columns should contain theSameElementsAs
      expectedMeta.columns
  }

  it should "convert array of strings into count vector (shortcut version)" in {
    val output = f2.countVec(minDF = 2, vocabSize = 3)
    val f2vec = output.originStage.asInstanceOf[OpCountVectorizer]
    val transformedData = f2vec.fit(ds).transform(ds)
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }
} 
Example 12
Source File: TimerImplTest.scala    From datadog4s   with MIT License
package com.avast.datadog4s.statsd

import java.util.concurrent.TimeUnit

import cats.effect.{ Clock, IO }
import com.avast.datadog4s.api.Tag
import com.avast.datadog4s.statsd.metric.TimerImpl
import com.timgroup.statsd.{ StatsDClient => JStatsDClient }
import org.mockito.scalatest.MockitoSugar
import org.scalatest.{ Assertions, BeforeAndAfter }
import org.scalatest.flatspec.AnyFlatSpec

class TimerImplTest extends AnyFlatSpec with MockitoSugar with BeforeAndAfter with Assertions {

  trait Fixtures {
    val aspect: String = "metric"
    val sampleRate     = 1.0

    val statsD: JStatsDClient = mock[JStatsDClient]
    val clock: Clock[IO]      = mock[Clock[IO]]

    val timer = new TimerImpl[IO](clock, statsD, aspect, sampleRate, Vector.empty)

    when(clock.monotonic(TimeUnit.NANOSECONDS)).thenReturn(IO.pure(10 * 1000 * 1000), IO.pure(30 * 1000 * 1000))
  }

  "time F[A]" should "report success with label success:true" in new Fixtures {
    private val res = timer.time(IO.delay("hello world")).unsafeRunSync()

    verify(statsD, times(1)).recordExecutionTime(aspect, 20, sampleRate, Tag.of("success", "true"))
    assertResult(res)("hello world")
  }

  it should "report failure with label failure:true and exception name" in new Fixtures {
    private val res = timer.time(IO.raiseError(new NoSuchElementException("fail")))

    assertThrows[NoSuchElementException](res.unsafeRunSync())
    verify(statsD, times(1))
      .recordExecutionTime(
        aspect,
        20,
        sampleRate,
        Tag.of("exception", "java.util.NoSuchElementException"),
        Tag.of("success", "false")
      )
  }

} 
Example 13
Source File: VideoDisplay.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.av.callbacks.video

import java.io.Closeable

import im.tox.tox4j.av.data.{ Height, Width }
import im.tox.tox4j.testing.autotest.AutoTestSuite.timed
import org.scalatest.Assertions

import scala.util.Try

abstract class VideoDisplay[Parsed, Canvas] extends Assertions with Closeable {

  def width: Width
  def height: Height

  protected def canvas: Try[Canvas]
  protected def parse(
    y: Array[Byte], u: Array[Byte], v: Array[Byte],
    yStride: Int, uStride: Int, vStride: Int
  ): Parsed
  protected def displaySent(canvas: Canvas, frameNumber: Int, parsed: Parsed): Unit
  protected def displayReceived(canvas: Canvas, frameNumber: Int, parsed: Parsed): Unit

  final def displaySent(frameNumber: Int, y: Array[Byte], u: Array[Byte], v: Array[Byte]): Unit = {
    val width = this.width.value
    canvas.foreach(displaySent(_, frameNumber, parse(y, u, v, width, width / 2, width / 2)))
  }

  
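  // Parses and displays a received frame on the canvas, if one is available, and
  // returns the time spent parsing and the time spent displaying it.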
  final def displayReceived(
    frameNumber: Int,
    y: Array[Byte], u: Array[Byte], v: Array[Byte],
    yStride: Int, uStride: Int, vStride: Int
  ): Option[(Int, Int)] = {
    canvas.toOption.map { canvas =>
      val (parseTime, parsed) = timed(parse(y, u, v, yStride, uStride, vStride))
      val displayTime = timed(displayReceived(canvas, frameNumber, parsed))

      (parseTime, displayTime)
    }
  }

} 
Example 14
Source File: ToxCoreTestBase.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j

import java.io.IOException
import java.net.{ InetAddress, Socket }
import java.util.Random

import org.jetbrains.annotations.NotNull
import org.scalatest.Assertions

object ToxCoreTestBase extends Assertions {

  private[tox4j] val nodeCandidates = Seq(
    new DhtNode("tox.initramfs.io", "tox.initramfs.io", 33445, "3F0A45A268367C1BEA652F258C85F4A66DA76BCAA667A49E770BCC4917AB6A25"),
    new DhtNode("tox.verdict.gg", null, 33445, "1C5293AEF2114717547B39DA8EA6F1E331E5E358B35F9B6B5F19317911C5F976")
  )

  @NotNull def randomBytes(length: Int): Array[Byte] = {
    val array = new Array[Byte](length)
    new Random().nextBytes(array)
    array
  }

  @NotNull
  def readablePublicKey(@NotNull id: Array[Byte]): String = {
    val str = new StringBuilder
    id foreach { c => str.append(f"$c%02X") }
    str.toString()
  }

  @NotNull
  def parsePublicKey(@NotNull id: String): Array[Byte] = {
    val publicKey = new Array[Byte](id.length / 2)
    publicKey.indices foreach { i =>
      publicKey(i) =
        ((fromHexDigit(id.charAt(i * 2)) << 4) +
          fromHexDigit(id.charAt(i * 2 + 1))).toByte
    }
    publicKey
  }

  private def fromHexDigit(c: Char): Byte = {
    val digit =
      if ('0' to '9' contains c) { c - '0' }
      else if ('A' to 'F' contains c) { c - 'A' + 10 }
      else if ('a' to 'f' contains c) { c - 'a' + 10 }
      else { throw new IllegalArgumentException(s"Non-hex digit character: $c") }
    digit.toByte
  }

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  private def hasConnection(ip: String, port: Int): Option[String] = {
    var socket: Socket = null
    try {
      socket = new Socket(InetAddress.getByName(ip), port)
      if (socket.getInputStream == null) {
        Some("Socket input stream is null")
      } else {
        None
      }
    } catch {
      case e: IOException =>
        Some(s"A network connection can't be established to $ip:$port: ${e.getMessage}")
    } finally {
      if (socket != null) {
        socket.close()
      }
    }
  }

  def checkIPv4: Option[String] = {
    hasConnection("8.8.8.8", 53)
  }

  def checkIPv6: Option[String] = {
    hasConnection("2001:4860:4860::8888", 53)
  }

  protected[tox4j] def assumeIPv4(): Unit = {
    assume(checkIPv4.isEmpty)
  }

  protected[tox4j] def assumeIPv6(): Unit = {
    assume(checkIPv6.isEmpty)
  }

} 
Example 15
Source File: VideoGenerator.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.av.callbacks.video

import im.tox.tox4j.av.data.{ Height, Width }
import org.scalatest.Assertions

abstract class VideoGenerator {

  def width: Width
  def height: Height
  def length: Int

  def yuv(t: Int): (Array[Byte], Array[Byte], Array[Byte])
  def resize(width: Width, height: Height): VideoGenerator

  final def size: Int = width.value * height.value

  protected final def w: Int = width.value
  protected final def h: Int = height.value

}

object VideoGenerator extends Assertions {

  private def resizeNearestNeighbour(
    pixels: Array[Byte],
    oldWidth: Int,
    oldHeight: Int,
    newWidth: Int,
    newHeight: Int
  ): Array[Byte] = {
    val result = Array.ofDim[Byte](newWidth * newHeight)

    val xRatio = oldWidth / newWidth.toDouble
    val yRatio = oldHeight / newHeight.toDouble

    for {
      yPos <- 0 until newHeight
      xPos <- 0 until newWidth
    } {
      val px = Math.floor(xPos * xRatio)
      val py = Math.floor(yPos * yRatio)
      result((yPos * newWidth) + xPos) = pixels(((py * oldWidth) + px).toInt)
    }

    result
  }

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def resizeNearestNeighbour(newWidth: Width, newHeight: Height, gen: VideoGenerator): VideoGenerator = {
    if (newWidth == gen.width && newHeight == gen.height) {
      gen
    } else {
      new VideoGenerator {

        override def toString: String = s"resizeNearestNeighbour($width, $height, $gen)"

        override def resize(width: Width, height: Height): VideoGenerator = gen.resize(width, height)

        override def yuv(t: Int): (Array[Byte], Array[Byte], Array[Byte]) = {
          val yuv = gen.yuv(t)
          (
            resizeNearestNeighbour(yuv._1, gen.width.value, gen.height.value, width.value, height.value),
            resizeNearestNeighbour(yuv._2, gen.width.value / 2, gen.height.value / 2, width.value / 2, height.value / 2),
            resizeNearestNeighbour(yuv._3, gen.width.value / 2, gen.height.value / 2, width.value / 2, height.value / 2)
          )
        }

        override def width: Width = newWidth
        override def height: Height = newHeight
        override def length: Int = gen.length

      }
    }
  }

} 
Example 16
Source File: DhtNodeSelector.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j

import java.io.IOException
import java.net.{ InetAddress, Socket }

import com.typesafe.scalalogging.Logger
import im.tox.tox4j.core.ToxCore
import im.tox.tox4j.impl.jni.ToxCoreImplFactory
import org.scalatest.Assertions
import org.slf4j.LoggerFactory

object DhtNodeSelector extends Assertions {

  private val logger = Logger(LoggerFactory.getLogger(this.getClass))
  private var selectedNode: Option[DhtNode] = Some(ToxCoreTestBase.nodeCandidates(0))

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  private def tryConnect(node: DhtNode): Option[DhtNode] = {
    var socket: Socket = null
    try {
      socket = new Socket(InetAddress.getByName(node.ipv4), node.udpPort.value)
      assume(socket.getInputStream != null)
      Some(node)
    } catch {
      case e: IOException =>
        logger.info(s"TCP connection failed (${e.getMessage})")
        None
    } finally {
      if (socket != null) {
        socket.close()
      }
    }
  }

  private def tryBootstrap(
    withTox: (Boolean, Boolean) => (ToxCore => Option[DhtNode]) => Option[DhtNode],
    node: DhtNode,
    udpEnabled: Boolean
  ): Option[DhtNode] = {
    val protocol = if (udpEnabled) "UDP" else "TCP"
    val port = if (udpEnabled) node.udpPort else node.tcpPort
    logger.info(s"Trying to bootstrap with ${node.ipv4}:$port using $protocol")

    withTox(true, udpEnabled) { tox =>
      val status = new ConnectedListener
      if (!udpEnabled) {
        tox.addTcpRelay(node.ipv4, port, node.dhtId)
      }
      tox.bootstrap(node.ipv4, port, node.dhtId)

      // Try bootstrapping for 10 seconds.
      (0 to 10000 / tox.iterationInterval) find { _ =>
        tox.iterate(status)(())
        Thread.sleep(tox.iterationInterval)
        status.isConnected
      } match {
        case Some(time) =>
          logger.info(s"Bootstrapped successfully after ${time * tox.iterationInterval}ms using $protocol")
          Some(node)
        case None =>
          logger.info(s"Unable to bootstrap with $protocol")
          None
      }
    }
  }

  private def findNode(withTox: (Boolean, Boolean) => (ToxCore => Option[DhtNode]) => Option[DhtNode]): DhtNode = {
    DhtNodeSelector.selectedNode match {
      case Some(node) => node
      case None =>
        logger.info("Looking for a working bootstrap node")

        DhtNodeSelector.selectedNode = ToxCoreTestBase.nodeCandidates find { node =>
          logger.info(s"Trying to establish a TCP connection to ${node.ipv4}")

          (for {
            node <- tryConnect(node)
            node <- tryBootstrap(withTox, node, udpEnabled = true)
            node <- tryBootstrap(withTox, node, udpEnabled = false)
          } yield node).isDefined
        }

        assume(DhtNodeSelector.selectedNode.nonEmpty, "No viable nodes for bootstrap found; cannot test")
        DhtNodeSelector.selectedNode.get
    }
  }

  def node: DhtNode = findNode(ToxCoreImplFactory.withToxUnit[Option[DhtNode]])

} 
Example 17
Source File: CheckedOrdering.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.testing

import org.scalatest.Assertions


object CheckedOrdering extends Assertions {
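  // Wraps an Ordering so that whenever compare reports two values as equal (result 0),
  // the wrapper also asserts that they are == to each other.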

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def apply[A](ord: Ordering[A]): Ordering[A] = {
    new Ordering[A] {
      override def compare(x: A, y: A): Int = {
        val result = ord.compare(x, y)
        if (result == 0) {
          assert(x == y)
        }
        result
      }
    }
  }

} 
Example 18
Source File: CheckedOrderingEq.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.testing

import org.scalatest.Assertions


object CheckedOrderingEq extends Assertions {
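  // Same idea as CheckedOrdering, but for reference types: values that compare as
  // equal (result 0) are asserted to be the same instance (eq).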

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def apply[A <: AnyRef](ord: Ordering[A]): Ordering[A] = {
    new Ordering[A] {
      override def compare(x: A, y: A): Int = {
        val result = ord.compare(x, y)
        if (result == 0) {
          assert(x eq y)
        }
        result
      }
    }
  }

} 
Example 19
Source File: GetDisjunction.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.testing

import im.tox.core.error.CoreError
import im.tox.core.typesafe.{ -\/, \/, \/- }
import org.scalatest.Assertions

import scala.language.implicitConversions

final case class GetDisjunction[T] private (disjunction: CoreError \/ T) extends Assertions {
  def get: T = {
    disjunction match {
      case -\/(error)   => fail(error.toString)
      case \/-(success) => success
    }
  }
}

object GetDisjunction {

  @SuppressWarnings(Array("org.wartremover.warts.ImplicitConversion"))
  implicit def toGetDisjunction[T](disjunction: CoreError \/ T): GetDisjunction[T] = GetDisjunction(disjunction)

} 
Example 20
Source File: ToxExceptionChecks.scala    From jvm-toxcore-c   with GNU General Public License v3.0
package im.tox.tox4j.testing

import im.tox.tox4j.exceptions.ToxException
import org.scalatest.Assertions

trait ToxExceptionChecks extends Assertions {
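  // Runs f and asserts that it throws a ToxException whose error code is the expected
  // enum value; fails the test if no exception is thrown.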

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  protected def intercept[E <: Enum[E]](code: E)(f: => Unit) = {
    try {
      f
      fail(s"Expected exception with code ${code.name}")
    } catch {
      case e: ToxException[_] =>
        assert(e.code eq code)
    }
  }

} 
Example 21
Source File: BagelSuite.scala    From iolap   with Apache License 2.0
package org.apache.spark.bagel

import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.storage.StorageLevel

class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable

class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {

  var sc: SparkContext = _

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }

  test("halting by voting") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
    val msgs = sc.parallelize(Array[(String, TestMessage)]())
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("halting by message silence") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
    val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          val msgsOut =
            msgs match {
              case Some(ms) if (superstep < numSupersteps - 1) =>
                ms
              case _ =>
                Array[TestMessage]()
            }
        (new TestVertex(self.active, self.age + 1), msgsOut)
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("large number of iterations") {
    // This tests whether jobs with a large number of iterations finish in a reasonable time,
    // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
    failAfter(30 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 50
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }

  test("using non-default persistence level") {
    failAfter(10 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 20
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }
} 
Example 22
Source File: SerializerPropertiesSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.serializer

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import scala.util.Random

import org.scalatest.Assertions

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.KryoTest.RegistratorWithoutAutoReset


class SerializerPropertiesSuite extends SparkFunSuite {

  import SerializerPropertiesSuite._

  test("JavaSerializer does not support relocation") {
    // Per a comment on the SPARK-4550 JIRA ticket, Java serialization appears to write out the
    // full class name the first time an object is written to an output stream, but subsequent
    // references to the class write a more compact identifier; this prevents relocation.
    val ser = new JavaSerializer(new SparkConf())
    testSupportsRelocationOfSerializedObjects(ser, generateRandomItem)
  }

  test("KryoSerializer supports relocation when auto-reset is enabled") {
    val ser = new KryoSerializer(new SparkConf)
    assert(ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset())
    testSupportsRelocationOfSerializedObjects(ser, generateRandomItem)
  }

  test("KryoSerializer does not support relocation when auto-reset is disabled") {
    val conf = new SparkConf().set("spark.kryo.registrator",
      classOf[RegistratorWithoutAutoReset].getName)
    val ser = new KryoSerializer(conf)
    assert(!ser.newInstance().asInstanceOf[KryoSerializerInstance].getAutoReset())
    testSupportsRelocationOfSerializedObjects(ser, generateRandomItem)
  }

}

object SerializerPropertiesSuite extends Assertions {

  def generateRandomItem(rand: Random): Any = {
    val randomFunctions: Seq[() => Any] = Seq(
      () => rand.nextInt(),
      () => rand.nextString(rand.nextInt(10)),
      () => rand.nextDouble(),
      () => rand.nextBoolean(),
      () => (rand.nextInt(), rand.nextString(rand.nextInt(10))),
      () => MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10))),
      () => {
        val x = MyCaseClass(rand.nextInt(), rand.nextString(rand.nextInt(10)))
        (x, x)
      }
    )
    randomFunctions(rand.nextInt(randomFunctions.size)).apply()
  }

  def testSupportsRelocationOfSerializedObjects(
      serializer: Serializer,
      generateRandomItem: Random => Any): Unit = {
    if (!serializer.supportsRelocationOfSerializedObjects) {
      return
    }
    val NUM_TRIALS = 5
    val rand = new Random(42)
    for (_ <- 1 to NUM_TRIALS) {
      val items = {
        // Make sure that we have duplicate occurrences of the same object in the stream:
        val randomItems = Seq.fill(10)(generateRandomItem(rand))
        randomItems ++ randomItems.take(5)
      }
      val baos = new ByteArrayOutputStream()
      val serStream = serializer.newInstance().serializeStream(baos)
      def serializeItem(item: Any): Array[Byte] = {
        val itemStartOffset = baos.toByteArray.length
        serStream.writeObject(item)
        serStream.flush()
        val itemEndOffset = baos.toByteArray.length
        baos.toByteArray.slice(itemStartOffset, itemEndOffset).clone()
      }
      val itemsAndSerializedItems: Seq[(Any, Array[Byte])] = {
        val serItems = items.map {
          item => (item, serializeItem(item))
        }
        serStream.close()
        rand.shuffle(serItems)
      }
      val reorderedSerializedData: Array[Byte] = itemsAndSerializedItems.flatMap(_._2).toArray
      val deserializedItemsStream = serializer.newInstance().deserializeStream(
        new ByteArrayInputStream(reorderedSerializedData))
      assert(deserializedItemsStream.asIterator.toSeq === itemsAndSerializedItems.map(_._1))
      deserializedItemsStream.close()
    }
  }
}

private case class MyCaseClass(foo: Int, bar: String) 
Example 23
Source File: BufferHolderSparkSubmitSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.catalyst.expressions.codegen

import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}

import org.apache.spark.{SparkFunSuite, TestUtils}
import org.apache.spark.deploy.SparkSubmitSuite
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.util.ResetSystemProperties

// A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark
// unit-test JVM, the actual test code runs as a separate spark-submit job.
class BufferHolderSparkSubmitSuite
  extends SparkFunSuite
    with Matchers
    with BeforeAndAfterEach
    with ResetSystemProperties {

  test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") {
    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)

    val argsForSparkSubmit = Seq(
      "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"),
      "--name", "SPARK-22222",
      "--master", "local-cluster[1,1,4096]",
      "--driver-memory", "4g",
      "--conf", "spark.ui.enabled=false",
      "--conf", "spark.master.rest.enabled=false",
      "--conf", "spark.driver.extraJavaOptions=-ea",
      unusedJar.toString)
    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
  }
}

object BufferHolderSparkSubmitSuite extends Assertions {

  def main(args: Array[String]): Unit = {

    val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

    val unsafeRow = new UnsafeRow(1000)
    val holder = new BufferHolder(unsafeRow)

    holder.reset()

    assert(intercept[IllegalArgumentException] {
      holder.grow(-1)
    }.getMessage.contains("because the size is negative"))

    // while buffer reuse may happen in practice, this test checks whether the buffer can keep growing
    holder.grow(ARRAY_MAX / 2)
    assert(unsafeRow.getSizeInBytes % 8 == 0)

    holder.grow(ARRAY_MAX / 2 + 7)
    assert(unsafeRow.getSizeInBytes % 8 == 0)

    holder.grow(Integer.MAX_VALUE / 2)
    assert(unsafeRow.getSizeInBytes % 8 == 0)

    holder.grow(ARRAY_MAX - holder.totalSize())
    assert(unsafeRow.getSizeInBytes % 8 == 0)

    assert(intercept[IllegalArgumentException] {
      holder.grow(ARRAY_MAX + 1 - holder.totalSize())
    }.getMessage.contains("because the size after growing"))
  }
} 
Example 24
Source File: WholeStageCodegenSparkSubmitSuite.scala    From XSQL   with Apache License 2.0
package org.apache.spark.sql.execution

import org.scalatest.{Assertions, BeforeAndAfterEach, Matchers}
import org.scalatest.concurrent.TimeLimits

import org.apache.spark.{SparkFunSuite, TestUtils}
import org.apache.spark.deploy.SparkSubmitSuite
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{LocalSparkSession, QueryTest, Row, SparkSession}
import org.apache.spark.sql.functions.{array, col, count, lit}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.unsafe.Platform
import org.apache.spark.util.ResetSystemProperties

// Due to the need to set driver's extraJavaOptions, this test needs to use actual SparkSubmit.
class WholeStageCodegenSparkSubmitSuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfterEach
  with ResetSystemProperties {

  test("Generated code on driver should not embed platform-specific constant") {
    val unusedJar = TestUtils.createJarWithClasses(Seq.empty)

    // HotSpot JVM specific: Set up a local cluster with the driver/executor using mismatched
    // settings of UseCompressedOops JVM option.
    val argsForSparkSubmit = Seq(
      "--class", WholeStageCodegenSparkSubmitSuite.getClass.getName.stripSuffix("$"),
      "--master", "local-cluster[1,1,1024]",
      "--driver-memory", "1g",
      "--conf", "spark.ui.enabled=false",
      "--conf", "spark.master.rest.enabled=false",
      "--conf", "spark.driver.extraJavaOptions=-XX:-UseCompressedOops",
      "--conf", "spark.executor.extraJavaOptions=-XX:+UseCompressedOops",
      unusedJar.toString)
    SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..")
  }
}

object WholeStageCodegenSparkSubmitSuite extends Assertions with Logging {

  var spark: SparkSession = _

  def main(args: Array[String]): Unit = {
    TestUtils.configTestLog4j("INFO")

    spark = SparkSession.builder().getOrCreate()

    // Make sure the test is run where the driver and the executors use different object layouts
    val driverArrayHeaderSize = Platform.BYTE_ARRAY_OFFSET
    val executorArrayHeaderSize =
      spark.sparkContext.range(0, 1).map(_ => Platform.BYTE_ARRAY_OFFSET).collect.head.toInt
    assert(driverArrayHeaderSize > executorArrayHeaderSize)

    val df = spark.range(71773).select((col("id") % lit(10)).cast(IntegerType) as "v")
      .groupBy(array(col("v"))).agg(count(col("*")))
    val plan = df.queryExecution.executedPlan
    assert(plan.find(_.isInstanceOf[WholeStageCodegenExec]).isDefined)

    val expectedAnswer =
      Row(Array(0), 7178) ::
        Row(Array(1), 7178) ::
        Row(Array(2), 7178) ::
        Row(Array(3), 7177) ::
        Row(Array(4), 7177) ::
        Row(Array(5), 7177) ::
        Row(Array(6), 7177) ::
        Row(Array(7), 7177) ::
        Row(Array(8), 7177) ::
        Row(Array(9), 7177) :: Nil
    val result = df.collect
    QueryTest.sameRows(result.toSeq, expectedAnswer) match {
      case Some(errMsg) => fail(errMsg)
      case _ =>
    }
  }
} 
Example 25
Source File: SparkContextInfoSuite.scala    From sparkoscope   with Apache License 2.0
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 26
Source File: SparkContextInfoSuite.scala    From SparkCore   with Apache License 2.0
package org.apache.spark

import org.scalatest.{Assertions, FunSuite}
import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends FunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
} 
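These Spark suites consistently write === rather than == inside assert. The === operator comes from Scalactic's TripleEquals, which Assertions mixes in, and on failure the assertion message includes both operands (for example "2 did not equal 1") rather than a bare "assertion failed". A minimal sketch with arbitrary values:

import org.scalatest.Assertions._

object TripleEqualsSketch extends App {
  val persisted = Map(0 -> "rdd1")
  assert(persisted.size === 1)     // on failure the message shows both values
  assert(persisted(0) === "rdd1")
}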
Example 27
Source File: AnalyzerTest.scala    From scala-commons   with MIT License 5 votes vote down vote up
package com.avsystem.commons
package analyzer

import org.scalactic.source.Position
import org.scalatest.Assertions

import scala.reflect.internal.util.BatchSourceFile
import scala.tools.nsc.plugins.Plugin
import scala.tools.nsc.{Global, Settings}

trait AnalyzerTest { this: Assertions =>
  val settings = new Settings
  settings.usejavacp.value = true
  settings.pluginOptions.value ++= List("AVSystemAnalyzer:+_")

  val compiler: Global = new Global(settings) { global =>
    override protected def loadRoughPluginsList(): List[Plugin] =
      new AnalyzerPlugin(global) :: super.loadRoughPluginsList()
  }

  def compile(source: String): Unit = {
    compiler.reporter.reset()
    val run = new compiler.Run
    run.compileSources(List(new BatchSourceFile("test.scala", source)))
  }

  def assertErrors(source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(compiler.reporter.hasErrors)
  }

  def assertErrors(errors: Int, source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(compiler.reporter.errorCount == errors)
  }

  def assertNoErrors(source: String)(implicit pos: Position): Unit = {
    compile(source)
    assert(!compiler.reporter.hasErrors)
  }
} 
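AnalyzerTest threads an implicit org.scalactic.source.Position through its assertion helpers so that a failed assert is reported at the caller's line rather than inside the helper. A minimal, self-contained sketch of that pattern; the helper name and message are illustrative.

import org.scalactic.source.Position
import org.scalatest.Assertions._

object PositionSketch extends App {
  // The implicit Position captured at the call site flows into assert, so failures
  // point at the caller of assertNonEmpty, not at this helper.
  def assertNonEmpty(s: String)(implicit pos: Position): Unit =
    assert(s.nonEmpty, s"expected a non-empty string (called from ${pos.fileName}:${pos.lineNumber})")

  assertNonEmpty("ok")
}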
Example 28
Source File: ResultAssertions.scala    From wartremover-contrib   with Apache License 2.0 5 votes vote down vote up
package org.wartremover
package contrib.test

import org.scalatest.Assertions

import org.wartremover.test.WartTestTraverser

trait ResultAssertions extends Assertions {

  def assertEmpty(result: WartTestTraverser.Result) = {
    assertResult(List.empty, "result.errors")(result.errors)
    assertResult(List.empty, "result.warnings")(result.warnings)
  }

  def assertError(result: WartTestTraverser.Result)(message: String) = assertErrors(result)(message, 1)

  def assertErrors(result: WartTestTraverser.Result)(message: String, times: Int) = {
    assertResult(List.fill(times)(message), "result.errors")(result.errors.map(skipTraverserPrefix))
    assertResult(List.empty, "result.warnings")(result.warnings.map(skipTraverserPrefix))
  }

  def assertWarnings(result: WartTestTraverser.Result)(message: String, times: Int) = {
    assertResult(List.empty, "result.errors")(result.errors.map(skipTraverserPrefix))
    assertResult(List.fill(times)(message), "result.warnings")(result.warnings.map(skipTraverserPrefix))
  }

  private val messageFormat = """\[wartremover:\S+\] ([\s\S]+)""".r

  private def skipTraverserPrefix(msg: String) = msg match {
    case messageFormat(rest) => rest
    case s => s
  }
} 
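ResultAssertions is built entirely on assertResult(expected, clue)(actual), which fails with the clue plus the expected and actual values. A minimal sketch of that call on its own, with arbitrary values:

import org.scalatest.Assertions._

object AssertResultSketch extends App {
  val errors: List[String] = Nil
  assertResult(List.empty, "result.errors")(errors)   // passes: no errors were reported
  assertResult(3, "sum")(1 + 2)                        // passes: actual equals expected
}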
Example 29
Source File: ResponseAccessors.scala    From ScalaWebTest   with Apache License 2.0 5 votes vote down vote up
package org.scalawebtest.core

import org.scalatest.Assertions

// Note: the enclosing declaration was truncated in this listing. The trait below is a
// plausible reconstruction, assuming the response headers are supplied as an abstract member.
trait ResponseAccessors extends Assertions {
  def responseHeaders: Map[String, String]

  def responseHeaderValue(name: String): String = responseHeaders.get(name) match {
    case Some(v) => v
    case None =>
      val headerNames = responseHeaders.keys.mkString("'", "', '", "'")
      fail(
        s"""The current web response
           |did not contain the expected response header with field-name: '$name'
           |It contained the following header field-names: $headerNames""".stripMargin)
  }
} 
Example 30
Source File: SparkContextInfoSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 31
Source File: SparkContextInfoSuite.scala    From drizzle-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import org.scalatest.Assertions

import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) =>
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 32
Source File: SparkContextInfoSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import org.scalatest.Assertions
import org.apache.spark.storage.StorageLevel

class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext {
  test("getPersistentRDDs only returns RDDs that are marked as cached") {
    sc = new SparkContext("local", "test")
    assert(sc.getPersistentRDDs.isEmpty === true)

    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    assert(sc.getPersistentRDDs.isEmpty === true)

    rdd.cache()
    assert(sc.getPersistentRDDs.size === 1)
    assert(sc.getPersistentRDDs.values.head === rdd)
  }

  test("getPersistentRDDs returns an immutable map") {
    sc = new SparkContext("local", "test")
    val rdd1 = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    val myRdds = sc.getPersistentRDDs
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)

    // myRdds2 should have 2 RDDs, but myRdds should not change
    val rdd2 = sc.makeRDD(Array(5, 6, 7, 8), 1).cache()
    val myRdds2 = sc.getPersistentRDDs
    assert(myRdds2.size === 2)
    assert(myRdds2(0) === rdd1)
    assert(myRdds2(1) === rdd2)
    assert(myRdds2(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds2(1).getStorageLevel === StorageLevel.MEMORY_ONLY)
    assert(myRdds.size === 1)
    assert(myRdds(0) === rdd1)
    assert(myRdds(0).getStorageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("getRDDStorageInfo only reports on RDDs that actually persist data") {
    sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
    assert(sc.getRDDStorageInfo.size === 0)
    rdd.collect()
    assert(sc.getRDDStorageInfo.size === 1)
    assert(sc.getRDDStorageInfo.head.isCached)
    assert(sc.getRDDStorageInfo.head.memSize > 0)
    assert(sc.getRDDStorageInfo.head.storageLevel === StorageLevel.MEMORY_ONLY)
  }

  test("call sites report correct locations") {
    sc = new SparkContext("local", "test")
    testPackage.runCallSiteTest(sc)
  }
}


package object testPackage extends Assertions {
  private val CALL_SITE_REGEX = "(.+) at (.+):([0-9]+)".r

  def runCallSiteTest(sc: SparkContext) {
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
    val rddCreationSite = rdd.getCreationSite
    val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd"

    val rddCreationLine = rddCreationSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "makeRDD")
        assert(file === "SparkContextInfoSuite.scala")
        line.toInt
      }
      case _ => fail("Did not match expected call site format")
    }

    curCallSite match {
      case CALL_SITE_REGEX(func, file, line) => {
        assert(func === "getCallSite") // this is correct because we called it from outside of Spark
        assert(file === "SparkContextInfoSuite.scala")
        assert(line.toInt === rddCreationLine.toInt + 2)
      }
      case _ => fail("Did not match expected call site format")
    }
  }
} 
Example 33
Source File: MockHelpers.scala    From guardrail   with MIT License 5 votes vote down vote up
package helpers

import com.fasterxml.jackson.databind.ObjectMapper
import io.netty.handler.codec.http.EmptyHttpHeaders
import java.io.ByteArrayInputStream
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import java.util.Collections
import java.util.concurrent.CompletableFuture
import javax.ws.rs.container.AsyncResponse
import org.asynchttpclient.Response
import org.asynchttpclient.uri.Uri
import org.mockito.{ ArgumentMatchersSugar, MockitoSugar }
import org.scalatest.Assertions
import scala.reflect.ClassTag

object MockHelpers extends Assertions with MockitoSugar with ArgumentMatchersSugar {
  def mockAsyncResponse[T](future: CompletableFuture[T])(implicit cls: ClassTag[T]): AsyncResponse = {
    val asyncResponse = mock[AsyncResponse]

    when(asyncResponse.resume(any[T])) thenAnswer [AnyRef] { response =>
      response match {
        case t: Throwable => future.completeExceptionally(t)
        case other: T     => future.complete(other)
        case other        => fail(s"AsyncResponse.resume expected an object of type ${cls.runtimeClass.getName}, but got ${other.getClass.getName} instead")
      }
    }

    asyncResponse
  }

  def mockAHCResponse[T](uri: String, status: Int, maybeBody: Option[T] = None)(implicit mapper: ObjectMapper): Response = {
    val response = mock[Response]
    when(response.getUri) thenReturn Uri.create(uri)
    when(response.hasResponseStatus) thenReturn true
    when(response.getStatusCode) thenReturn status
    when(response.getStatusText) thenReturn "Some Status"
    when(response.hasResponseHeaders) thenReturn true
    when(response.getHeaders) thenReturn EmptyHttpHeaders.INSTANCE
    when(response.getHeader(any)) thenReturn null
    when(response.getHeaders(any)) thenReturn Collections.emptyList()
    maybeBody match {
      case None =>
        when(response.hasResponseBody) thenReturn true
      case Some(body) =>
        val responseBytes = mapper.writeValueAsBytes(body)
        val responseStr   = new String(responseBytes, StandardCharsets.UTF_8)
        when(response.hasResponseBody) thenReturn true
        when(response.getResponseBody(any)) thenReturn responseStr
        when(response.getResponseBody) thenReturn responseStr
        when(response.getResponseBodyAsStream) thenReturn new ByteArrayInputStream(responseBytes)
        when(response.getResponseBodyAsByteBuffer) thenReturn ByteBuffer.wrap(responseBytes)
        when(response.getResponseBodyAsBytes) thenReturn responseBytes
    }
    response
  }

} 
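MockHelpers calls fail(...) from inside a mock answer, which works because fail throws a TestFailedException that surfaces in whichever test triggers that answer. A minimal, self-contained sketch of fail paired with intercept, without the Mockito and JAX-RS machinery; the helper and message are illustrative.

import org.scalatest.Assertions._
import org.scalatest.exceptions.TestFailedException

object FailSketch extends App {
  // A helper that rejects bad input by failing the current test.
  def parsePort(s: String): Int =
    try s.toInt
    catch { case _: NumberFormatException => fail(s"'$s' is not a valid port number") }

  val e = intercept[TestFailedException] { parsePort("not-a-port") }
  assert(e.getMessage.contains("not a valid port number"))
}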
Example 34
Source File: ListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class ListTest extends FlatSpec with TestCommon {

  
  Spec[DateTimeList] should "extend OPList[Long]" in {
    val myDateTimeList = new DateTimeList(List.empty[Long])
    myDateTimeList shouldBe a[FeatureType]
    myDateTimeList shouldBe a[OPCollection]
    myDateTimeList shouldBe a[OPList[_]]
    myDateTimeList shouldBe a[DateList]
  }
  it should "compare values correctly" in {
    new DateTimeList(List(456L, 13L)) shouldBe new DateTimeList(List(456L, 13L))
    new DateTimeList(List(13L, 456L)) should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList shouldBe DateTimeList(List.empty[Long])

    List(12237834L, 4890489839L).toDateTimeList shouldBe a[DateTimeList]
  }


} 
Example 35
Source File: AvroFieldTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.cli.gen

import com.salesforce.op.cli.gen.AvroField._
import com.salesforce.op.test.TestCommon
import org.apache.avro.Schema
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.collection.JavaConverters._
import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class AvroFieldTest extends FlatSpec with TestCommon with Assertions {

  Spec[AvroField] should "do from" in {
    val types = List(
      Schema.Type.STRING,
      //  Schema.Type.BYTES, // somehow this avro type is not covered (yet)
      Schema.Type.INT,
      Schema.Type.LONG,
      Schema.Type.FLOAT,
      Schema.Type.DOUBLE,
      Schema.Type.BOOLEAN
    )
    val simpleSchemas = types map Schema.create

    val unions = List(
      Schema.createUnion((Schema.Type.NULL::Schema.Type.INT::Nil) map Schema.create asJava),
      Schema.createUnion((Schema.Type.INT::Schema.Type.NULL::Nil) map Schema.create asJava)
    )

    val enum = Schema.createEnum("Aliens", "undocumented", "outer",
      List("Edgar_the_Bug", "Boris_the_Animal", "Laura_Vasquez") asJava)

    val allSchemas = (enum::unions)++simpleSchemas // NULL does not work

    val fields = allSchemas.zipWithIndex map {
      case (s, i) => new Schema.Field("x" + i, s, "Who", null: Object)
    }

    val expected = List(
      AEnum(fields(0), isNullable = false),
      AInt(fields(1), isNullable = true),
      AInt(fields(2), isNullable = true),
      AString(fields(3), isNullable = false),
      AInt(fields(4), isNullable = false),
      ALong(fields(5), isNullable = false),
      AFloat(fields(6), isNullable = false),
      ADouble(fields(7), isNullable = false),
      ABoolean(fields(8), isNullable = false)
    )

    an[IllegalArgumentException] should be thrownBy {
      val nullSchema = Schema.create(Schema.Type.NULL)
      val nullField = new Schema.Field("xxx", null, "Nobody", null: Object)
      AvroField from nullField
    }

    fields.size shouldBe expected.size

    for {
      (field, expected) <- fields zip expected
    } {
      val actual = AvroField from field
      actual shouldBe expected
    }
  }

} 
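The expected-exception check above uses the Matchers syntax (an[IllegalArgumentException] should be thrownBy ...). The Assertions-only equivalent, assuming ScalaTest 3.x, is assertThrows, or intercept when the caught exception needs further inspection. A minimal sketch with an arbitrary failing expression:

import org.scalatest.Assertions._

object AssertThrowsSketch extends App {
  assertThrows[IllegalArgumentException] {
    require(false, "this schema type is not supported here")
  }
}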
Example 36
Source File: OpsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.cli.gen

import java.io.File
import java.nio.file.Paths

import com.salesforce.op.cli.{AvroSchemaFromFile, CliParameters, GeneratorConfig}
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.io.Source


@RunWith(classOf[JUnitRunner])
class OpsTest extends FlatSpec with TestCommon with Assertions {

  val tempFolder = new File(System.getProperty("java.io.tmpdir"))
  val projectFolder = new File(tempFolder, "cli_test")
  projectFolder.deleteOnExit()

  val testParams = CliParameters(
    location = tempFolder,
    projName = "cli_test",
    inputFile = Some(Paths.get("templates", "simple", "src", "main", "resources", "PassengerData.csv").toFile),
    response = Some("survived"),
    idField = Some("passengerId"),
    schemaSource = Some(
      AvroSchemaFromFile(Paths.get("..", "utils", "src", "main", "avro", "PassengerCSV.avsc").toFile)
    ),
    answersFile = Some(new File("passengers.answers")),
    overwrite = true).values

  Spec[Ops] should "generate project files" in {

    testParams match {
      case None =>
        fail("Could not create config, I wonder why")
      case Some(conf: GeneratorConfig) =>
        val ops = Ops(conf)
        ops.run()
        val buildFile = new File(projectFolder, "build.gradle")
        buildFile should exist

        val buildFileContent = Source.fromFile(buildFile).mkString

        buildFileContent should include("mainClassName = 'com.salesforce.app.cli_test'")

        val scalaSourcesFolder = Paths.get(projectFolder.toString, "src", "main", "scala", "com", "salesforce", "app")
        val featuresFile = Source.fromFile(new File(scalaSourcesFolder.toFile, "Features.scala")).getLines
        val heightLine = featuresFile.find(_ contains "description") map (_.trim)
        heightLine shouldBe Some(
          "val description = FB.Text[PassengerCSV].extract(_.getDescription.toText).asPredictor"
        )
    }

  }

} 
Example 37
Source File: UserIOTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.cli.gen

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}


@RunWith(classOf[JUnitRunner])
class UserIOTest extends FlatSpec with TestCommon with Assertions {


  private case class Oracle(answers: String*) extends UserIO {
    private var i = -1
    var question = "---"

    override def readLine(q: String): Option[String] = {
      question = q
      i += 1
      if (i < answers.length) Some(answers(i)) else throw new IllegalStateException(s"Out of answers, q=$q")
    }
  }

  Spec[UserIO] should "do qna" in {
    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme(q: String, answers: String*): Option[String] = {
      Oracle(answers: _*).qna(q, _.length == 1, Map("2*3" -> "6", "3+2" -> "5"))
    }

    aksme("2+2", "11", "22", "?") shouldBe Some("?")
    aksme("2+2", "4", "5", "?") shouldBe Some("4")
    aksme("2+3", "44", "", "?") shouldBe Some("?")
    aksme("2*3", "4", "?") shouldBe Some("6")
    aksme("3+2", "4", "?") shouldBe Some("5")
  }

  it should "ask" in {

    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme[Int](q: String, opts: Map[Int, List[String]], answers: String*): (String, Int) = {
      val console = Oracle(answers: _*)
      val answer = console.ask(q, opts) getOrElse fail(s"A problem answering question $q")
      (console.question, answer)
    }

    an[IllegalStateException] should be thrownBy
      aksme("what is your name?", Map(1 -> List("one", "uno")), "11", "1", "?")

    aksme("what is your name?",
      Map(
        1 -> List("Nessuno", "Nobody"),
        2 -> List("Ishmael", "Gantenbein")),
      "5", "1", "?") shouldBe("what is your name? [0] Nessuno [1] Ishmael: ", 2)
  }

} 
Example 38
Source File: OpLDATest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpLDATest extends FlatSpec with TestSparkContext {

  val inputData = Seq(
    (0.0, Vectors.sparse(11, Array(0, 1, 2, 4, 5, 6, 7, 10), Array(1.0, 2.0, 6.0, 2.0, 3.0, 1.0, 1.0, 3.0))),
    (1.0, Vectors.sparse(11, Array(0, 1, 3, 4, 7, 10), Array(1.0, 3.0, 1.0, 3.0, 2.0, 1.0))),
    (2.0, Vectors.sparse(11, Array(0, 1, 2, 5, 6, 8, 9), Array(1.0, 4.0, 1.0, 4.0, 9.0, 1.0, 2.0))),
    (3.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 3.0, 9.0))),
    (4.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 6, 9, 10), Array(3.0, 1.0, 1.0, 9.0, 3.0, 2.0, 1.0, 3.0))),
    (5.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7, 8, 9), Array(4.0, 2.0, 3.0, 4.0, 5.0, 1.0, 1.0, 1.0, 4.0))),
    (6.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 2.0, 9.0))),
    (7.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 1.0, 2.0, 1.0, 3.0))),
    (8.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7), Array(4.0, 4.0, 3.0, 4.0, 2.0, 1.0, 3.0))),
    (9.0, Vectors.sparse(11, Array(0, 1, 2, 4, 6, 8, 9, 10), Array(2.0, 8.0, 2.0, 3.0, 2.0, 2.0, 7.0, 2.0))),
    (10.0, Vectors.sparse(11, Array(0, 1, 2, 3, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 2.0, 3.0, 3.0))),
    (11.0, Vectors.sparse(11, Array(0, 1, 4, 5, 6, 7, 9), Array(4.0, 1.0, 4.0, 5.0, 1.0, 3.0, 1.0)))
  ).map(v => v._1.toReal -> v._2.toOPVector)

  lazy val (ds, f1, f2) = TestFeatureBuilder(inputData)

  lazy val inputDS = ds.persist()

  val seed = 1234567890L
  val k = 3
  val maxIter = 100

  lazy val expected = new LDA()
    .setFeaturesCol(f2.name)
    .setK(k)
    .setSeed(seed)
    .fit(inputDS)
    .transform(inputDS)
    .select("topicDistribution")
    .collect()
    .toSeq
    .map(_.getAs[Vector](0))

  Spec[OpLDA] should "convert document term vectors into topic vectors" in {
    val f2Vec = new OpLDA().setInput(f2).setK(k).setSeed(seed).setMaxIter(maxIter)
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val output = f2Vec.getOutput()
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  it should "convert document term vectors into topic vectors (shortcut version)" in {
    val output = f2.lda(k = k, seed = seed, maxIter = maxIter)
    val f2Vec = output.originStage.asInstanceOf[OpLDA]
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  private def computeMeanSqError(estimate: Seq[OPVector], expected: Seq[Vector]): Double = {
    val n = estimate.length.toDouble
    estimate.zip(expected).map { case (est, exp) => Vectors.sqdist(est.value, exp) }.sum / n
  }
} 
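The MSE checks above wrap the assertion in withClue so the computed value is attached to any failure message. A minimal sketch of that combination with arbitrary numbers:

import org.scalatest.Assertions._

object WithClueSketch extends App {
  val mse = 0.31
  val expectedMse = 0.5
  withClue(s"Computed mse $mse (expected $expectedMse): ") {
    assert(mse < expectedMse)
  }
}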
Example 39
Source File: OpWord2VecTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpWord2VecTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    "I I I like like Spark".split(" "),
    "Hi I heard about Spark".split(" "),
    "I wish Java could use case classes".split(" "),
    "Logistic regression models are neat".split(" ")
  ).map(_.toSeq.toTextList)

  lazy val (inputData, f1) = TestFeatureBuilder(Seq(data.head))
  lazy val (testData, _) = TestFeatureBuilder(data.tail)

  lazy val expected = data.tail.zip(Seq(
    Vectors.dense(-0.029884086549282075, -0.055613189935684204, 0.04186216294765473).toOPVector,
    Vectors.dense(-0.0026281912411962234, -0.016138136386871338, 0.010740748473576136).toOPVector,
    Vectors.dense(0.0, 0.0, 0.0).toOPVector
  )).toArray

  Spec[OpWord2VecTest] should "convert array of strings into a vector" in {
    val f1Vec = new OpWord2Vec().setInput(f1).setMinCount(0).setVectorSize(3).setSeed(1234567890L)
    val output = f1Vec.getOutput()
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }

  it should "convert array of strings into a vector (shortcut version)" in {
    val output = f1.word2vec(minCount = 0, vectorSize = 3)
    val f1Vec = output.originStage.asInstanceOf[OpWord2Vec].setSeed(1234567890L)
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }

} 
Example 40
Source File: IDFTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.IDF
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.{Estimator, Transformer}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class IDFTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)),
    Vectors.dense(0.0, 1.0, 2.0, 3.0),
    Vectors.sparse(4, Array(1), Array(1.0))
  )

  lazy val (ds, f1) = TestFeatureBuilder(data.map(_.toOPVector))

  Spec[IDF] should "compute inverted document frequency" in {
    val idf = f1.idf()
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)

    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      math.log((data.length + 1.0) / (x + 1.0))
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  it should "compute inverted document frequency when minDocFreq is 1" in {
    val idf = f1.idf(minDocFreq = 1)
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)
    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      if (x > 0) math.log((data.length + 1.0) / (x + 1.0)) else 0
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  private def scaleDataWithIDF(dataSet: Seq[Vector], model: Vector): Seq[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) =>
          (id, value * model(id))
        }
        Vectors.sparse(data.size, res)
    }
  }

}
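IDFTest compares doubles with a hand-written tolerance check, assert(math.abs(x - y) <= 1e-5). Scalactic's Tolerance expresses the same idea more declaratively; a minimal sketch with arbitrary values, assuming Scalactic is on the classpath (it ships with ScalaTest):

import org.scalactic.Tolerance._
import org.scalatest.Assertions._

object ToleranceSketch extends App {
  val computed = 0.6931471
  val expected = math.log(2.0)
  assert(computed === expected +- 1e-5)   // passes: within the requested tolerance
}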