org.scalatest.FunSuite Scala Examples

The following examples show how to use org.scalatest.FunSuite. They are drawn from open-source projects; the source file, originating project, and license are noted above each example.
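For orientation, a minimal FunSuite follows the pattern sketched below; the suite and test names are illustrative and not taken from any of the projects listed here.

import org.scalatest.FunSuite

class ArithmeticSuite extends FunSuite {

  test("addition is commutative for two sample values") {
    assert(2 + 3 === 3 + 2)
  }

  test("division by zero throws") {
    assertThrows[ArithmeticException](1 / 0)
  }
}

The examples below build on this same pattern, adding Spark sessions, Matchers, Mockito mocks, and ScalaCheck property checks.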
Example 1
Source File: SparkNarrowTest.scala    From spark-tools   with Apache License 2.0
package io.univalence

import java.net.URLClassLoader
import java.sql.Date

import io.univalence.centrifuge.Sparknarrow
import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
import org.scalatest.FunSuite

case class Person(name: String, age: Int, date: Date)

class SparknarrowTest extends FunSuite {

  val conf: SparkConf = new SparkConf()
  conf.setAppName("yo")
  conf.set("spark.sql.caseSensitive", "true")
  conf.setMaster("local[2]")

  implicit val ss: SparkSession = SparkSession.builder.config(conf).getOrCreate
  import ss.implicits._

  test("testBasicCC") {

    val classDef = Sparknarrow.basicCC(Encoders.product[Person].schema).classDef
    checkDefinition(classDef)

  }

  def checkDefinition(scalaCode: String): Unit = {
    //TODO do a version for 2.11 and 2.12
    
  }

  test("play with scala eval") {

    val code =
      """
        case class Tata(str: String)
        case class Toto(age: Int, tata: Tata)
      """

    checkDefinition(code)
    checkDefinition(code)

  }

  ignore("printSchema StructType") {
    val yo = StructType(
      Seq(
        StructField("name", StringType),
        StructField("tel", ArrayType(StringType))
      )
    )

    yo.printTreeString()
  }

} 
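In Example 1, checkDefinition is left as a TODO. One way such a check could be implemented is to push the generated source through the Scala runtime compiler; the sketch below is only an illustration (not the project's actual code) and assumes scala-compiler is available on the test classpath.

import scala.reflect.runtime.currentMirror
import scala.tools.reflect.ToolBox

def checkDefinition(scalaCode: String): Unit = {
  val toolbox = currentMirror.mkToolBox()
  // Parse and compile the generated definitions; invalid code surfaces as a
  // scala.tools.reflect.ToolBoxError and fails the test.
  toolbox.compile(toolbox.parse(scalaCode))
}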
Example 2
Source File: TransformerTest.scala    From incubator-s2graph   with Apache License 2.0
package org.apache.s2graph.s2jobs.wal

import org.apache.s2graph.s2jobs.task.TaskConf
import org.apache.s2graph.s2jobs.wal.transformer._
import org.scalatest.{FunSuite, Matchers}
import play.api.libs.json.Json

class TransformerTest extends FunSuite with Matchers {
  val walLog = WalLog(1L, "insert", "edge", "a", "b", "s2graph", "friends", """{"name": 1, "url": "www.google.com"}""")

  test("test default transformer") {
    val taskConf = TaskConf.Empty
    val transformer = new DefaultTransformer(taskConf)
    val dimVals = transformer.toDimValLs(walLog, "name", "1")

    dimVals shouldBe Seq(DimVal("friends:name", "1"))
  }

  test("test ExtractDomain from URL") {
    val taskConf = TaskConf.Empty.copy(options =
      Map("urlDimensions" -> Json.toJson(Seq("url")).toString())
    )
    val transformer = new ExtractDomain(taskConf)
    val dimVals = transformer.toDimValLs(walLog, "url", "http://www.google.com/abc")

    dimVals shouldBe Seq(
      DimVal("host", "www.google.com"),
      DimVal("domain", "www.google.com"),
      DimVal("domain", "www.google.com/abc")
    )
  }
} 
Example 3
Source File: SqlUnitTest.scala    From SparkUnitTestingExamples   with Apache License 2.0
package com.cloudera.sa.spark.unittest.sql

import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

import scala.collection.mutable

class SqlUnitTest extends FunSuite with
BeforeAndAfterEach with BeforeAndAfterAll{

  @transient var sc: SparkContext = null
  @transient var hiveContext: HiveContext = null

  override def beforeAll(): Unit = {

    val envMap = Map[String,String](("Xmx", "512m"))

    val sparkConfig = new SparkConf()
    sparkConfig.set("spark.broadcast.compress", "false")
    sparkConfig.set("spark.shuffle.compress", "false")
    sparkConfig.set("spark.shuffle.spill.compress", "false")
    sparkConfig.set("spark.io.compression.codec", "lzf")
    sc = new SparkContext("local[2]", "unit test", sparkConfig)
    hiveContext = new HiveContext(sc)
  }

  override def afterAll(): Unit = {
    sc.stop()
  }

  test("Test table creation and summing of counts") {
    val personRDD = sc.parallelize(Seq(Row("ted", 42, "blue"),
      Row("tj", 11, "green"),
      Row("andrew", 9, "green")))

    hiveContext.sql("create table person (name string, age int, color string)")

    val emptyDataFrame = hiveContext.sql("select * from person limit 0")

    val personDataFrame = hiveContext.createDataFrame(personRDD, emptyDataFrame.schema)
    personDataFrame.registerTempTable("tempPerson")

    val ageSumDataFrame = hiveContext.sql("select sum(age) from tempPerson")

    val localAgeSum = ageSumDataFrame.take(10)

    assert(localAgeSum(0).get(0) == 62, "The sum of age should equal 62 but it equaled " + localAgeSum(0).get(0))
  }
} 
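Example 3 uses HiveContext and registerTempTable, both of which are deprecated in Spark 2.x. A rough sketch of the same age-sum check written against SparkSession is shown below; it assumes only spark-sql on the test classpath, since a temporary view does not need Hive support.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("sql unit test")
  .getOrCreate()
import spark.implicits._

Seq(("ted", 42, "blue"), ("tj", 11, "green"), ("andrew", 9, "green"))
  .toDF("name", "age", "color")
  .createOrReplaceTempView("tempPerson")

// sum over an integer column comes back as a Long
val ageSum = spark.sql("select sum(age) from tempPerson").collect()(0).getLong(0)
assert(ageSum == 62, s"The sum of age should equal 62 but it equaled $ageSum")
spark.stop()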
Example 4
Source File: OffsetLoaderTest.scala    From toketi-iothubreact   with MIT License
package com.microsoft.azure.iot.iothubreact.checkpointing

import com.microsoft.azure.iot.iothubreact.config.{IConfiguration, IConnectConfiguration}
import org.scalatest.FunSuite
import org.scalatest.mockito.MockitoSugar
import org.mockito.Mockito.when
import org.scalatest.Matchers._

class OffsetLoaderTest extends FunSuite with MockitoSugar {

  test("test GetSavedOffsets handles None appropriately") {

    val config = mock[IConfiguration]
    val cnConfig = mock[IConnectConfiguration]
    when(config.connect) thenReturn(cnConfig)
    when(cnConfig.iotHubPartitions) thenReturn(10)
    val loader = StubbedLoader(config)
    loader.GetSavedOffsets should be(Map(0 → "Offset 0", 1 → "Offset 1", 3 → "Offset 3"))
  }

  case class StubbedLoader(config: IConfiguration) extends OffsetLoader(config) {

    override private[iothubreact] def GetSavedOffset(partition: Int) = {
      partition match {
        case 0 ⇒ Some("Offset 0")
        case 1 ⇒ Some("Offset 1")
        case 3 ⇒ Some("Offset 3")
        case _ ⇒ None
      }
    }
  }

} 
Example 5
Source File: SparkPFASuiteBase.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.pfa

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.SparkConf
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.scalactic.Equality
import org.scalatest.FunSuite

abstract class SparkPFASuiteBase extends FunSuite with DataFrameSuiteBase with PFATestUtils {

  val sparkTransformer: Transformer
  val input: Array[String]
  val expectedOutput: Array[String]

  val sparkConf =  new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID).
    set("spark.driver.host", "localhost")
  override lazy val spark = SparkSession.builder().config(sparkConf).getOrCreate()
  override val reuseContextIfPossible = true

  // Converts column containing a vector to an array
  def withColumnAsArray(df: DataFrame, colName: String) = {
    val vecToArray = udf { v: Vector => v.toArray }
    df.withColumn(colName, vecToArray(df(colName)))
  }

  def withColumnAsArray(df: DataFrame, first: String, others: String*) = {
    val vecToArray = udf { v: Vector => v.toArray }
    var result = df.withColumn(first, vecToArray(df(first)))
    others.foreach(c => result = result.withColumn(c, vecToArray(df(c))))
    result
  }

  // Converts column containing a vector to a sparse vector represented as a map
  def getColumnAsSparseVectorMap(df: DataFrame, colName: String) = {
    val vecToMap = udf { v: Vector => v.toSparse.indices.map(i => (i.toString, v(i))).toMap }
    df.withColumn(colName, vecToMap(df(colName)))
  }

}

abstract class Result

object ApproxEquality extends ApproxEquality

trait ApproxEquality {

  import org.scalactic.Tolerance._
  import org.scalactic.TripleEquals._

  implicit val seqApproxEq: Equality[Seq[Double]] = new Equality[Seq[Double]] {
    override def areEqual(a: Seq[Double], b: Any): Boolean = {
      b match {
        case d: Seq[Double] =>
          a.zip(d).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }

  implicit val vectorApproxEq: Equality[Vector] = new Equality[Vector] {
    override def areEqual(a: Vector, b: Any): Boolean = {
      b match {
        case v: Vector =>
          a.toArray.zip(v.toArray).forall { case (l, r) => l === r +- 0.001 }
        case _ =>
          false
      }
    }
  }
} 
Example 6
Source File: ContigNormalizationTest.scala    From bdg-sequila   with Apache License 2.0
package org.biodatageeks.sequila.tests.dataquality

import org.biodatageeks.sequila.utils.DataQualityFuncs
import org.scalatest.FunSuite

class ContigNormalizationTest extends FunSuite{

  test("Test contig") {
    val chrInTest1 = "chr1"
    val chrInTest2 = "chrM"
    val chrInTest3 = "M"
    assert(
      DataQualityFuncs.cleanContig(chrInTest1) == "1"
    )
    assert(
      DataQualityFuncs.cleanContig(chrInTest2) === "MT"
    )
    assert(
      DataQualityFuncs.cleanContig(chrInTest3) === "MT"
    )
  }

} 
Example 7
Source File: TestUtils.scala    From wow-spark   with MIT License
package com.sev7e0.wow.utils

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FunSuite}

class TestUtils extends FunSuite with BeforeAndAfter {
	
	var conf: SparkConf = _
	var context: SparkContext = _
	
	before {
		conf = new SparkConf()
			.setMaster("local")
			.setAppName("wow-spark")
		context = new SparkContext(conf)
		context.setLogLevel("WARN")
	}
	
	after(println("have a good time"))
} 
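The after block in Example 7 only prints a message and leaves the SparkContext running. When several suites run in the same JVM it is usually safer to also stop the context there; a small illustrative adjustment:

	after {
		if (context != null) context.stop()
		println("have a good time")
	}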
Example 8
Source File: ScalaTest.scala    From wow-spark   with MIT License
package com.sev7e0.wow

import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, FunSuite, Matchers}

class ScalaTest extends FunSuite
	with BeforeAndAfter
	with BeforeAndAfterEach
	with Matchers{

	var str:String = _

	before{
		str = "result"
		println("当前str:"+str)
	}

	after{
		println("测试结束")
	}

	test("scala test demo") {
		println(11111)
		// simple condition check
		assert(1 == 1)
		// checks that an exception is thrown
		assertThrows[Exception](1/0)
		// type-checks the given code string
		assertCompiles("""val a: String = "a" """)
		// checks that the result equals the expected value
		assertResult(str)(result)
		
		str shouldEqual "result"
	}
	def result:String ={
		"result"
	}
} 
Example 9
Source File: BasicOpsTests.scala    From sigmastate-interpreter   with MIT License
package sigmastate.eval

import java.math.BigInteger

import org.bouncycastle.crypto.ec.CustomNamedCurves
import org.scalatest.{FunSuite, Matchers}
import special.sigma.{Box, Context, ContractsTestkit, MockSigma, SigmaContract, SigmaDslBuilder, SigmaProp, TestSigmaDslBuilder}

import scala.language.implicitConversions

class BasicOpsTests extends FunSuite with ContractsTestkit with Matchers {
  override val SigmaDsl: SigmaDslBuilder = CostingSigmaDslBuilder

  implicit def boolToSigma(b: Boolean): SigmaProp = MockSigma(b)

  test("atLeast") {
    val props = Colls.fromArray(Array[SigmaProp](false, true, true, false))

    // border cases
    SigmaDsl.atLeast(0, props).isValid shouldBe true
    SigmaDsl.atLeast(5, props).isValid shouldBe false

    // normal cases
    SigmaDsl.atLeast(1, props).isValid shouldBe true
    SigmaDsl.atLeast(2, props).isValid shouldBe true
    SigmaDsl.atLeast(3, props).isValid shouldBe false
  }

  // TODO this is valid for BigIntModQ type (https://github.com/ScorexFoundation/sigmastate-interpreter/issues/554)
  ignore("ByteArrayToBigInt should always produce a positive big int") {
    SigmaDsl.byteArrayToBigInt(collection[Byte](-1)).signum shouldBe 1
  }

  // TODO this is valid for BigIntModQ type (https://github.com/ScorexFoundation/sigmastate-interpreter/issues/554)
  ignore("ByteArrayToBigInt should always produce big int less than dlog group order") {
    val groupOrder = CustomNamedCurves.getByName("secp256k1").getN

    SigmaDsl.byteArrayToBigInt(
      Colls.fromArray(groupOrder.subtract(BigInteger.ONE).toByteArray)
    ).compareTo(SigmaDsl.BigInt(BigInteger.ONE)) shouldBe 1

    SigmaDsl.byteArrayToBigInt(
      Colls.fromArray(groupOrder.toByteArray)
    ).compareTo(SigmaDsl.BigInt(BigInteger.ONE)) shouldBe 1

    an [RuntimeException] should be thrownBy
      SigmaDsl.byteArrayToBigInt(Colls.fromArray(groupOrder.add(BigInteger.ONE).toByteArray))

    an [RuntimeException] should be thrownBy
      SigmaDsl.byteArrayToBigInt(Colls.fromArray(Array.fill[Byte](500)(1)))
  }

  test("Coll.append")  {
    val c1 = collection[Byte](1, 2)
    val c2 = collection[Byte](3, 4)
    c1.append(c2).toArray shouldBe Array[Byte](1, 2, 3, 4)
  }

  test("box.creationInfo._1 is Int") {
    val box = newAliceBox(1, 100)
    box.creationInfo._1 shouldBe a [Integer]
  }

} 
Example 10
Source File: SparkFunSuite.scala    From yggdrasil   with Apache License 2.0
package org.apache.spark.mllib.util

import org.apache.spark.Logging
import org.scalatest.{FunSuite, Outcome}


abstract class SparkFunSuite extends FunSuite with Logging {

  // Wraps each test so its start and finish are clearly marked in the test log.
  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }
} 
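Because withFixture is final and wraps every test, any suite extending this base gets the start/finish log markers for free. A tiny hypothetical usage sketch:

class VectorOpsSuite extends SparkFunSuite {
  test("element-wise addition") {
    val sums = Array(1, 2).zip(Array(3, 4)).map { case (a, b) => a + b }
    assert(sums.sameElements(Array(4, 6)))
  }
}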
Example 11
Source File: HBaseLocalClient.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.hbase.utilities

import java.io.File

import scala.collection.mutable.ArrayBuffer

import com.google.common.io.Files
import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.datasources.hbase.SparkHBaseConf
import org.apache.spark.sql.util._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}

import com.paypal.gimel.common.catalog.Field
import com.paypal.gimel.hbase.DataSet

class HBaseLocalClient extends FunSuite with Matchers with BeforeAndAfterAll {

  var sparkSession : SparkSession = _
  var dataSet: DataSet = _
  val hbaseTestingUtility = new HBaseTestingUtility()
  val tableName = "test_table"
  val cfs = Array("personal", "professional")
  val columns = Array("id", "name", "age", "address", "company", "designation", "salary")
  val fields = columns.map(col => new Field(col))

  val metrics = ArrayBuffer.empty[(String, QueryExecution, Long)]

  protected override def beforeAll(): Unit = {
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit
    hbaseTestingUtility.startMiniCluster()
    SparkHBaseConf.conf = hbaseTestingUtility.getConfiguration
    createTable(tableName, cfs)
    val conf = new SparkConf
    conf.set(SparkHBaseConf.testConf, "true")
    sparkSession = SparkSession.builder()
      .master("local")
      .appName("HBase Test")
      .config(conf)
      .getOrCreate()

    val listener = new QueryExecutionListener {
      // Only test successful case here, so no need to implement `onFailure`
      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {
        metrics += ((funcName, qe, duration))
      }
    }
    sparkSession.listenerManager.register(listener)
    sparkSession.sparkContext.setLogLevel("ERROR")
    dataSet = new DataSet(sparkSession)
  }

  protected override def afterAll(): Unit = {
    hbaseTestingUtility.shutdownMiniCluster()
    sparkSession.close()
  }

  def createTable(name: String, cfs: Array[String]) {
    val tName = Bytes.toBytes(name)
    val bcfs = cfs.map(Bytes.toBytes(_))
    try {
      hbaseTestingUtility.deleteTable(TableName.valueOf(tName))
    } catch {
      case _ : Throwable =>
        println("No table = " + name + " found")
    }
    hbaseTestingUtility.createMultiRegionTable(TableName.valueOf(tName), bcfs)
  }

  // Mocks data for testing
  def mockDataInDataFrame(numberOfRows: Int): DataFrame = {
    def stringed(n: Int) = s"""{"id": "$n","name": "MAC-$n", "address": "MAC-${n + 1}", "age": "${n + 1}", "company": "MAC-$n", "designation": "MAC-$n", "salary": "${n * 10000}" }"""
    val texts: Seq[String] = (1 to numberOfRows).map { x => stringed(x) }
    val rdd: RDD[String] = sparkSession.sparkContext.parallelize(texts)
    val dataFrame: DataFrame = sparkSession.read.json(rdd)
    dataFrame
  }
} 
Example 12
Source File: QueryGuardTest.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.common.query.guard

import scala.collection.immutable
import scala.concurrent.Future
import scala.util.{Failure, Try}

import org.scalatest.FunSuite

import com.paypal.gimel.common.conf.QueryGuardConfigs
import com.paypal.gimel.common.utilities.spark.SharedSparkSession
import com.paypal.gimel.logger.Logger

class QueryGuardTest extends FunSuite with SharedSparkSession {
  override protected val additionalConfig: Map[String, String] = Map(
    QueryGuardConfigs.JOB_TTL -> "60000"
  )

  import ConcurrentContext._
  private val logger = Logger()

  def startAppAsync(jobSleepTimeoutConfig: Map[Int, Long] = Map.empty,
                    eachRunLength: Int = 10): Unit = {
    val scheduledJobs: immutable.Seq[Future[Unit]] =
      for (jobId <- 0 until jobSleepTimeoutConfig.size) yield {
        executeAsync(
          performAction(spark, jobSleepTimeoutConfig(jobId), eachRunLength)
        )
      }
    awaitAll(scheduledJobs.toIterator)
  }

  def startAppSync(jobSleepTimeoutConfig: Map[Int, Long] = Map.empty,
                   eachRunLength: Int = 10): Unit = {
    for (jobId <- 0 until jobSleepTimeoutConfig.size) {
      startSparkjob(spark, jobSleepTimeoutConfig(jobId), eachRunLength)
    }
  }

  test(
    "Query guard eviction with all the tasks completing within the scheduled time interval"
  ) {
    spark.conf.set(QueryGuardConfigs.DELAY_TTL, "1000")
    val jobSleepTimeoutConfig: Map[Int, Long] =
      Map(0 -> 5000, 1 -> 4000, 2 -> 500, 3 -> 2500)
    logger.setLogLevel("CONSOLE")
    val queryGuard: QueryGuard = new QueryGuard(spark)
    queryGuard.start()
    startAppAsync(jobSleepTimeoutConfig)
    queryGuard.stop()
  }

  test("Query guard eviction with synchronous timed task execution") {
    spark.conf.set(QueryGuardConfigs.JOB_TTL, "3000")
    spark.conf.set(QueryGuardConfigs.DELAY_TTL, "1000")
    val jobSleepTimeoutConfig: Map[Int, Long] =
      Map(0 -> 500, 1 -> 2500, 2 -> 1500, 3 -> 2800)
    logger.setLogLevel("CONSOLE")
    val queryGuard: QueryGuard = new QueryGuard(spark)
    queryGuard.start()
    startAppSync(jobSleepTimeoutConfig, 1)
    queryGuard.stop()
  }

  ignore("Ignoring this test") {
    test("Query guard eviction with app fail criteria") {
      spark.conf.set(QueryGuardConfigs.JOB_TTL, "3000")
      spark.conf.set(QueryGuardConfigs.DELAY_TTL, "1000")
      val jobSleepTimeoutConfig: Map[Int, Long] =
        Map(0 -> 500, 1 -> 2500, 2 -> 1500, 3 -> 4000)
      logger.setLogLevel("CONSOLE")
      val queryGuard: QueryGuard = new QueryGuard(spark)
      queryGuard.start()
      Try {
        startAppSync(jobSleepTimeoutConfig, 1)
      } match {
        case Failure(exception) =>
          logger.error(exception.getMessage)
          assert(
            exception.getMessage
              .contains(
                "cancelled as it reached the max TTL: 3 seconds, with Job start time "
              )
          )
        case _ =>
          throw new AssertionError("Expected an exception with TTL breach")
      }
      queryGuard.stop()
    }
  }

  test("looping") {
    for {
      cntr <- 1 until 23
      hr = "%02d".format(cntr)
    } {
      println(s" query where hr =$hr")
    }
  }
} 
Example 13
Source File: GimelServiceUtilitiesTest.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.common.utilities

import org.scalatest.FunSuite

import com.paypal.gimel.common.gimelservices.GimelServiceUtilities

class GimelServiceUtilitiesTest extends FunSuite {

  val gimelServices = new GimelServiceUtilities()

  test("getObjectPropertiesForSystem") {
    println("1. storageTypeName=HIVE, dataset=Hive.Cluster.default.flights")
    assert(gimelServices.getObjectPropertiesForSystem("HIVE", "Hive.Cluster.default.flights")
      .sameElements(scala.collection.mutable.Map[String, String]("gimel.hive.db.name" -> "default",
        "gimel.hive.table.name" -> "flights")))

    println("2. storageTypeName=TERADATA, dataset=Teradata.Cluster.flights_db.flights")
    assert(gimelServices.getObjectPropertiesForSystem("TERADATA", "Teradata.Cluster.flights_db.flights")
      .sameElements(scala.collection.mutable.Map[String, String]("gimel.jdbc.input.table.name" -> "flights_db.flights")))

    println("3. storageTypeName=MYSQL, dataset=MySql.gimelmysql.gimeldb.gimeltable")
    assert(gimelServices.getObjectPropertiesForSystem("MYSQL", "MySql.gimelmysql.gimeldb.gimeltable")
      .sameElements(scala.collection.mutable.Map[String, String]("gimel.jdbc.input.table.name" -> "gimeldb.gimeltable")))

    println("4. storageTypeName=ELASTIC, dataset=Elastic.Gimel_Dev.default.gimel_tau_flights")
    assert(gimelServices.getObjectPropertiesForSystem("ELASTIC", "Elastic.Gimel_Dev.default.gimel_tau_flights")
      .sameElements(scala.collection.mutable.Map[String, String]("es.resource" -> "default/gimel_tau_flights",
        "es.index.auto.create" -> "true")))

    println("5. storageTypeName=HBASE, dataset=Hbase.test_cluster.default.test_table")
    assert(gimelServices.getObjectPropertiesForSystem("HBASE", "Hbase.test_cluster.default.test_table")
      .sameElements(scala.collection.mutable.Map[String, String]("gimel.hbase.namespace.name" -> "default",
        "gimel.hbase.table.name" -> "test_table")))

    println("6. storageTypeName=S3, dataset=S3.Dev.default.test_object")
    val exception = intercept[Exception] {
      gimelServices.getObjectPropertiesForSystem("S3", "S3.Dev.default.test_object")
    }
    assert(exception.getMessage.contains(s"""does not exist. Please check if the dataset name is correct."""))
  }
} 
Example 14
Source File: QueryParserUtilsTest.scala    From gimel   with Apache License 2.0
package com.paypal.gimel.parser.utilities

import org.scalatest.FunSuite

class QueryParserUtilsTest extends FunSuite {

  private val query1: String =
    """
      |insert into ${targetDb}.enriched_data
      |select
      |   review.review_id,
      |   review.review_text,
      |   review.user_id,
      |   review.review_date,
      |   review.business_id,
      |   business_details.name as business_name,
      |   postal_geo_map.latitude as business_latitude,
      |   postal_geo_map.longitude as business_longitude,
      |   yelp_user.name as user_name,
      |   yelp_user.review_count as user_review_count,
      |   yelp_user.yelping_since as user_yelping_since
      |from
      |   pcatalog.teradata.tau.yelp.review review
      |inner join
      |   pcatalog.teradata.tau.yelp.business_details business_details
      |on
      |   review.business_id = business_details.business_id
      |join
      |   pcatalog.teradata.tau.yelp.business_address business_address
      |on
      |   review.business_id = business_address.business_id
      |join
      |   pcatalog.teradata.tau.yelp.user yelp_user
      |on
      |   yelp_user.user_id = review.user_id
      |join
      |   pcatalog.teradata.tau.yelp.postal_geo_map
      |on
      |   business_address.postal_code = postal_geo_map.postal_code
      |where
      |   review.review_date > current_date -150
      |and
      |   review.business_id = 'ogpiys3gnfZNZBTEJw5-1Q'
      |""".stripMargin

  test("Teradata ANSI getSourceTables ") {
    val result = Seq(
      "pcatalog.teradata.tau.yelp.review",
      "pcatalog.teradata.tau.yelp.business_details",
      "pcatalog.teradata.tau.yelp.business_address",
      "pcatalog.teradata.tau.yelp.user",
      "pcatalog.teradata.tau.yelp.postal_geo_map",
      "${targetdb}.enriched_data"
    )

    assert(
      QueryParserUtils.getAllSourceTables(query1).forall(result.contains(_))
    )
  }

  test("getTargetTables") {
    assert(
      QueryParserUtils
        .getTargetTables(query1)
        .head === "${targetDb}.enriched_data"
    )
  }

  test("isHavingLimit") {
    val sql = "select * from udc.hbase.cluster_name.default_test limit 10"
    assert(
      QueryParserUtils.isHavingLimit(sql) == true
    )
  }

  test("getLimit") {
    val sql1 = "select * from udc.hbase.cluster_name.default_test limit 10"
    println("Checking for SQL -> " + sql1)
    assert(
      QueryParserUtils.getLimit(sql1) == 10
    )

    val sql2 = "select * from udc.hbase.cluster_name.default_test limit"
    println("Checking for SQL -> " + sql2)
    val exception = intercept[Exception] {
      QueryParserUtils.getLimit(sql2)
    }
    assert(exception.getMessage.contains("Invalid SQL"))
  }
} 
Example 15
Source File: BatchModeTest.scala    From neuroflow   with Apache License 2.0
import org.scalatest.FunSuite
import breeze.linalg.DenseVector
import neuroflow.core.Activators.Double._
import neuroflow.core._
import neuroflow.dsl._


class BatchModeTest extends FunSuite {

  test("Batch Mode for Dense Net CPU") {

    import neuroflow.nets.cpu.DenseNetwork._
    implicit val weights = WeightBreeder[Double].random(-1, 1)
    val f = Sigmoid
    val net = Network(layout = Vector(2) :: Dense(3, f) :: Dense(10, f) :: SquaredError())
    val batch = (1 to 100).map { _ => DenseVector.rand[Double](size = 2) }
    val res = net.batchApply(batch)

    assert(res.size == batch.size)

  }

  test("Batch Mode for Conv Net CPU") {

    import neuroflow.nets.cpu.ConvNetwork._
    implicit val weights = WeightBreeder[Double].random(-1, 1)
    val f = Sigmoid
    val net = Network(layout =
      Convolution((1, 2, 1), (0, 0), (1, 2), (1, 1), 3, f) :: Dense(10, f) :: SquaredError()
    )
    val batch = (1 to 100).map { _ => Tensor3D.fromVector(DenseVector.rand[Double](size = 2)) }
    val res = net.batchApply(batch)

    assert(res.size == batch.size)

  }

  test("Batch Mode for Dense Net GPU") {

    import neuroflow.nets.gpu.DenseNetwork._
    implicit val weights = WeightBreeder[Double].random(-1, 1)
    val f = Sigmoid
    val net = Network(layout = Vector(2) :: Dense(3, f) :: Dense(10, f) :: SquaredError())
    val batch = (1 to 100).map { _ => DenseVector.rand[Double](size = 2) }
    val res = net.batchApply(batch)

    assert(res.size == batch.size)

  }

  test("Batch Mode for Conv Net GPU") {

    import neuroflow.nets.gpu.ConvNetwork._
    implicit val weights = WeightBreeder[Double].random(-1, 1)
    val f = Sigmoid
    val net = Network(layout =
      Convolution((1, 2, 1), (0, 0), (1, 2), (1, 1), 3, f) :: Dense(10, f) :: SquaredError()
    )
    val batch = (1 to 100).map { _ => Tensor3D.fromVector(DenseVector.rand[Double](size = 2)) }
    val res = net.batchApply(batch)

    assert(res.size == batch.size)

  }

} 
Example 16
Source File: SparkFunSuite.scala    From spark-ranking-algorithms   with Apache License 2.0
package org.apache.spark

// scalastyle:off
import org.scalatest.{Outcome, FunSuite}
import org.apache.log4j.{Level, Logger}


abstract class SparkFunSuite extends FunSuite with Logging {

  // Logging resolves from the enclosing org.apache.spark package; withFixture silences
  // the org/akka loggers and marks each test's start and end in the log output.
  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
    try {
      Logger.getLogger("org").setLevel(Level.OFF)
      Logger.getLogger("akka").setLevel(Level.OFF)

      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }

} 
Example 17
Source File: LightningStreamingSuite.scala    From lightning-scala   with MIT License
package org.viz.lightning

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.viz.lightning.types.Make

class LightningStreamingSuite extends FunSuite with BeforeAndAfterAll {

  var lgn: Lightning = _

  override def beforeAll() {
    lgn = Lightning("http://public.lightning-viz.org")
    lgn.createSession("test-streaming")
  }

  test("line streaming") {
    val viz = lgn.lineStreaming(series = Make.series(n = 5, t = 20))
    lgn.lineStreaming(series = Make.series(n = 5, t = 20), viz=viz)
  }
  test("scatter streaming") {
    val viz = lgn.scatterStreaming(x = Make.gaussian(n = 50, scale = 5),
      y = Make.gaussian(n = 50, scale = 5),
      label = Make.labels(n = 50),
      size = Make.sizes(n = 50),
      alpha = Make.alphas(n = 50))

    lgn.scatterStreaming(x = Make.gaussian(n = 50, scale = 5),
      y = Make.gaussian(n = 50, scale = 5), viz=viz)
  }

} 
Example 18
Source File: LightningThreeSuite.scala    From lightning-scala   with MIT License
package org.viz.lightning

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.viz.lightning.types.Make

class LightningThreeSuite extends FunSuite with BeforeAndAfterAll {

  var lgn: Lightning = _

  override def beforeAll() {
    lgn = Lightning("http://public.lightning-viz.org")
    lgn.createSession("test-three")
  }

  test("scatter3") {
    lgn.scatter3(x = Make.values(n = 20),
                 y = Make.values(n = 20),
                 z = Make.values(n = 20),
                 label = Make.labels(n = 20),
                 size = Make.sizes(n = 20))
  }

} 
Example 19
Source File: LightningPlotsSuite.scala    From lightning-scala   with MIT License
package org.viz.lightning

import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.viz.lightning.types.Make

class LightningPlotsSuite extends FunSuite with BeforeAndAfterAll {

  var lgn: Lightning = _

  override def beforeAll() {
    lgn = Lightning("http://public.lightning-viz.org")
    lgn.createSession("test-plots")
  }

  test("line") {
    lgn.line(series = Make.series(n = 5, t = 20),
      label = Make.labels(n = 5),
      size = Make.sizes(n = 5, scale = 7, min = 3))
  }

  test("line (single)") {
    lgn.line(series = Make.line(t = 20))
  }

  test("force") {
    lgn.force(conn = Make.sparseMatrix(n = 30, threshold = 0.95),
              label = Make.labels(n = 30))
  }

  test("force (links)") {
    lgn.force(conn = Make.sparseLinks(n = 30, threshold = 0.95),
      label = Make.labels(n = 30))
  }

  test("force (links and value)") {
    lgn.force(conn = Make.sparseLinks(n = 30, threshold = 0.95),
      value = Make.values(n = 30), colormap="Purples")
  }

  test("matrix") {
    lgn.matrix(matrix = Make.matrix(n = 10))
  }

  test("adjacency") {
    lgn.adjacency(conn = Make.sparseMatrix(n = 10),
      label = Make.labels(n = 10))
  }

  test("map (states)") {
    lgn.map(regions = Array("NY", "CA", "VA"),
      values = Make.values(n = 3))
  }

  test("map (countries)") {
    lgn.map(regions = Array("USA", "ENG", "IND"),
      values = Make.values(n = 3))
  }

  test("scatter") {
    lgn.scatter(x = Make.gaussian(n = 50, scale = 5),
      y = Make.gaussian(n = 50, scale = 5),
      label = Make.labels(n = 50),
      size = Make.sizes(n = 50),
      alpha = Make.alphas(n = 50))
  }

  test("graph") {
    lgn.graph(x = Make.gaussian(n = 50),
      y = Make.gaussian(n = 50),
      conn = Make.sparseMatrix(n = 50),
      label = Make.labels(n = 50))
  }

  test("graph bundled") {
    lgn.graphBundled(x = Make.gaussian(n = 50),
      y = Make.gaussian(n = 50),
      conn = Make.sparseMatrix(n = 50),
      label = Make.labels(n = 50))
  }

} 
Example 20
Source File: SparkFunSuite.scala    From Mastering-Spark-for-Data-Science   with MIT License
package io.gzet.test

import org.apache.spark.sql.SparkSession
import org.scalatest.{Matchers, FunSuite}

class SparkFunSuite extends FunSuite with Matchers {

  def localTest(name : String)(f : SparkSession => Unit) : Unit = {

    this.test(name) {

      val spark = SparkSession
        .builder()
        .appName(name)
        .master("local")
        .config("spark.default.parallelism", "1")
        .getOrCreate()

      try {
        f(spark)
      } finally {
        spark.stop()
      }
    }
  }
} 
Example 21
Source File: CryptoTest.scala    From Mastering-Spark-for-Data-Science   with MIT License
package io.gzet

import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.hadoop.io.compress.CryptoCodec
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{Matchers, FunSuite}

class CryptoTest extends FunSuite with Matchers {

  val cryptoDir = System.getProperty("java.io.tmpdir") + "cryptTestDir"

  test("Crypto encrypt then decrypt file") {
    val conf = new SparkConf()
      .setAppName("Test Crypto")
      .setMaster("local")
      .set("spark.default.parallelism", "1")
      .set("spark.hadoop.io.compression.codecs", "org.apache.hadoop.io.compress.CryptoCodec")
    val sc = new SparkContext(conf)

    val testFile = getClass.getResource("/gdeltTestFile.csv")
    val rdd = sc.textFile(testFile.getPath)

    rdd.saveAsTextFile(cryptoDir, classOf[CryptoCodec])
    val read = sc.textFile(cryptoDir)

    val allLines = read.collect
    allLines.size should be(20)
    allLines(0).startsWith("331150686") should be (true)
    allLines(allLines.length - 1).endsWith("polytrack/") should be (true)

    FileUtils.deleteDirectory(new File(cryptoDir))
    sc.stop
  }
} 
Example 22
package com.chapter16.SparkTesting

import org.scalatest.Assertions._
import org.apache.spark.rdd.RDD
import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.FunSuite

class TransformationTestWithSparkTestingBase extends FunSuite with SharedSparkContext {
  def tokenize(line: RDD[String]) = {
    line.map(x => x.split(' ')).collect()
  }

  test("works, obviously!") {
    assert(1 == 1)
  }

  test("Words counting") {
    assert(sc.parallelize("Hello world My name is Reza".split("\\W")).map(_ + 1).count == 6)
  }

  test("Testing RDD transformations using a shared Spark Context") {
    val input = List("Testing", "RDD transformations", "using a shared", "Spark Context")
    val expected = Array(Array("Testing"), Array("RDD", "transformations"), Array("using", "a", "shared"), Array("Spark", "Context"))
    val transformed = tokenize(sc.parallelize(input))
    assert(transformed === expected)
  }
} 
Example 23
Source File: UtilsTest.scala    From Spark-RSVD   with Apache License 2.0
package com.criteo.rsvd

import breeze.linalg.DenseMatrix
import org.scalatest.FunSuite

class UtilsTest extends FunSuite with PerTestSparkSession {

  test("isAPartition should work") {

    val incorrect1 = Array(Set(2, 3), Set(3, 4, 5))
    val incorrect2 = Array(Set(2), Set(4, 5))
    val incorrect3 = Array(Set(2, 3, 4, 5), Set.empty[Int])
    val incorrect4 = Array.empty[Set[Int]]

    assert(!Utils.isAPartition(incorrect1, 2, 5))
    assert(!Utils.isAPartition(incorrect2, 2, 5))
    assert(!Utils.isAPartition(incorrect3, 2, 5))
    assert(!Utils.isAPartition(incorrect4, 2, 5))

    val correct1 = Array(Set(2), Set(3), Set(4), Set(5))
    val correct2 = Array(Set(2), Set(3, 4), Set(5))
    val correct3 = Array(Set(2, 3, 4, 5))

    assert(Utils.isAPartition(correct1, 2, 5))
    assert(Utils.isAPartition(correct2, 2, 5))
    assert(Utils.isAPartition(correct3, 2, 5))

    //incorrect bounds:
    assert(!Utils.isAPartition(correct1, 2, 6))
    assert(!Utils.isAPartition(correct1, 1, 5))
    assert(!Utils.isAPartition(correct1, 3, 5))
    assert(!Utils.isAPartition(correct1, 2, 4))
  }

  test("IterRows should work") {
    val matrix = DenseMatrix.zeros[Double](2, 3)

    matrix(0, 0) = 1.0
    matrix(0, 1) = 2.0
    matrix(0, 2) = 3.0

    matrix(1, 0) = 10.0
    matrix(1, 1) = 20.0
    matrix(1, 2) = 30.0

    val it = Utils.rowIter(matrix)

    val firstLine = it.next()
    val secondLine = it.next()
    assert(!it.hasNext)
    assert(firstLine.data === Array(1.0, 2.0, 3.0))
    assert(secondLine.data === Array(10.0, 20.0, 30.0))
  }

} 
Example 24
Source File: CSRMatrixTest.scala    From Spark-RSVD   with Apache License 2.0
package com.criteo.rsvd

import breeze.linalg.{DenseMatrix => BDM}
import org.scalatest.FunSuite

class CSRMatrixTest extends FunSuite {

  test("CSR matrix activeIterator should list the elements row-wise") {
    val builder = new CSRMatrix.Builder(3, 6)

    builder.add(0, 0, 1.0)
    builder.add(2, 0, 2.0)
    builder.add(0, 5, 3.0)
    builder.add(1, 4, 4.0)

    val CSRMat = builder.result

    val activeValues = List[((Int, Int), Double)](
      ((0, 0), 1.0),
      ((0, 5), 3.0),
      ((1, 4), 4.0),
      ((2, 0), 2.0)
    )
    assert(CSRMat.activeIterator.toList === activeValues)
  }

  test("CSR matrix transpose should work as expected") {
    val builder = new CSRMatrix.Builder(3, 6)

    builder.add(0, 0, 1.0)
    builder.add(2, 0, 2.0)
    builder.add(0, 5, 3.0)
    builder.add(1, 4, 4.0)

    val CSRMat = builder.result

    val localMatrix = BDM.zeros[Double](3, 6)
    CSRMat.activeIterator.foreach({ case ((i, j), v) => localMatrix(i, j) = v })

    val CSRMatTransposed = CSRMat.t

    val localMatrixTransposed = BDM.zeros[Double](6, 3)

    CSRMatTransposed.activeIterator.foreach({
      case ((i, j), v) => localMatrixTransposed(i, j) = v
    })

    assert(localMatrixTransposed === localMatrix.t)
  }

} 
Example 25
Source File: SingleDimensionPartitionerSpec.scala    From Spark-RSVD   with Apache License 2.0
package com.criteo.rsvd

import org.scalatest.FunSuite
import org.scalatest.prop.TableDrivenPropertyChecks

class SingleDimensionPartitionerSpec
    extends FunSuite
    with TableDrivenPropertyChecks
    with PerTestSparkSession {
  test(
    "Partitioner should partition square and skinny matrices with the same " +
      "number of columns / rows per partition") {

    val numRowBlocks = 5
    val numBlocksPerPartition = 2

    val indices = Table(
      ("SkinnyBlockMatrixIndex", "BlockMatrixIndex", "ExpectedPartitionId"),
      (0, (0, 0), 0),
      (0, (1, 0), 0),
      (0, (0, 1), 0),
      (1, (4, 0), 0),
      (2, (3, 2), 1)
    )

    val partitioner =
      SingleDimensionPartitioner(numRowBlocks, numBlocksPerPartition)
    forAll(indices) {
      (skinnyIndex: Int,
       squareIndex: (Int, Int),
       expectedPartitionIndex: Int) =>
        assert(
          partitioner.getPartition(skinnyIndex) === partitioner.getPartition(
            squareIndex))
        assert(partitioner.getPartition(skinnyIndex) === expectedPartitionIndex)
    }
  }

  test("createCompatibleIndicesRDD works") {
    val numRowBlocks = 5
    val numBlocksPerPartition = 2
    val partitioner =
      SingleDimensionPartitioner(numRowBlocks, numBlocksPerPartition)

    val rdd = partitioner.createCompatibleIndicesRDD(sc)

    assert(rdd.partitions.length == 3)

    val data = rdd
      .mapPartitionsWithIndex {
        case (idx, it) => Iterator((idx, it.map(_._1).toList))
      }
      .collect()
      .sortBy(_._1)

    assert(
      data ===
        Array(
          (0, List(0, 1)),
          (1, List(2, 3)),
          (2, List(4))
        )
    )

  }
} 
Example 26
Source File: MultilayerPerceptronClassifierSuite.scala    From scalable-deeplearning   with Apache License 2.0
package org.apache.spark.ml.scaladl

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Row
import org.scalatest.FunSuite

import scaladl.util.SparkTestContext

class MultilayerPerceptronClassifierSuite extends FunSuite with SparkTestContext {

  test("XOR function learning as binary classification problem with two outputs.") {
    val dataFrame = spark.createDataFrame(Seq(
      (Vectors.dense(0.0, 0.0), 0.0),
      (Vectors.dense(0.0, 1.0), 1.0),
      (Vectors.dense(1.0, 0.0), 1.0),
      (Vectors.dense(1.0, 1.0), 0.0))
    ).toDF("features", "label")
    val layers = Array[Int](2, 5, 2)
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(1)
      .setSeed(123L)
      .setMaxIter(100)
    val model = trainer.fit(dataFrame)
    val result = model.transform(dataFrame)
    val predictionAndLabels = result.select("prediction", "label").collect()
    predictionAndLabels.foreach { case Row(p: Double, l: Double) =>
      assert(p == l)
    }
  }

  test("Test setWeights by training restart") {
    val dataFrame = spark.createDataFrame(Seq(
      (Vectors.dense(0.0, 0.0), 0.0),
      (Vectors.dense(0.0, 1.0), 1.0),
      (Vectors.dense(1.0, 0.0), 1.0),
      (Vectors.dense(1.0, 1.0), 0.0))
    ).toDF("features", "label")
    val layers = Array[Int](2, 5, 2)
    val trainer = new MultilayerPerceptronClassifier()
      .setLayers(layers)
      .setBlockSize(1)
      .setSeed(123456L)
      .setMaxIter(1)
      .setTol(1e-6)
    val initialWeights = trainer.fit(dataFrame).weights
    trainer.setInitialWeights(initialWeights.copy)
    val weights1 = trainer.fit(dataFrame).weights
    trainer.setInitialWeights(initialWeights.copy)
    val weights2 = trainer.fit(dataFrame).weights
    weights1.toArray.zip(weights2.toArray).foreach { x =>
      assert(math.abs(x._1 - x._2) <= 10e-5,
        "Training should produce the same weights given equal initial weights and number of steps")
    }
  }
} 
Example 27
Source File: LayerSuite.scala    From scalable-deeplearning   with Apache License 2.0
package scaladl.layers

import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite

import scaladl.util.SparkTestContext

class LayerSuite extends FunSuite with SparkTestContext {

  // TODO: test for weights comparison with Weka MLP
  test("ANN with Sigmoid learns XOR function with LBFGS optimizer") {
    val inputs = Array(
      Array(0.0, 0.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0),
      Array(1.0, 1.0)
    )
    val outputs = Array(0.0, 1.0, 1.0, 0.0)
    val data = inputs.zip(outputs).map { case (features, label) =>
      (Vectors.dense(features), Vectors.dense(label))
    }
    val rddData = sc.parallelize(data, 1)
    val hiddenLayersTopology = Array(5)
    val dataSample = rddData.first()
    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
    val initialWeights = FeedForwardModel(topology, 23124).weights
    val trainer = new FeedForwardTrainer(topology, 2, 1)
    trainer.setWeights(initialWeights)
    trainer.LBFGSOptimizer.setNumIterations(20)
    val model = trainer.train(rddData)
    val predictionAndLabels = rddData.map { case (input, label) =>
      (model.predict(input)(0), label(0))
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      assert(math.round(p) === l)
    }
  }

  test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer") {
    val inputs = Array(
      Array(0.0, 0.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0),
      Array(1.0, 1.0)
    )
    val outputs = Array(
      Array(1.0, 0.0),
      Array(0.0, 1.0),
      Array(0.0, 1.0),
      Array(1.0, 0.0)
    )
    val data = inputs.zip(outputs).map { case (features, label) =>
      (Vectors.dense(features), Vectors.dense(label))
    }
    val rddData = sc.parallelize(data, 1)
    val hiddenLayersTopology = Array(5)
    val dataSample = rddData.first()
    val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
    val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
    val initialWeights = FeedForwardModel(topology, 23124).weights
    val trainer = new FeedForwardTrainer(topology, 2, 2)
    // TODO: add a test for SGD
    trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20)
    trainer.setWeights(initialWeights).setStackSize(1)
    val model = trainer.train(rddData)
    val predictionAndLabels = rddData.map { case (input, label) =>
      (model.predict(input), label)
    }.collect()
    predictionAndLabels.foreach { case (p, l) =>
      p.toArray.zip(l.toArray).foreach(pair => assert(math.abs(pair._1 - pair._2) < 0.5))
    }
  }
} 
Example 28
Source File: GradientSuite.scala    From scalable-deeplearning   with Apache License 2.0
package scaladl.layers

import org.apache.spark.ml.linalg.Vectors
import org.scalatest.FunSuite

import scaladl.layers.AnnTypes._
import scaladl.tensor.DenseTensor

class GradientSuite extends FunSuite {

  test("Gradient computation against numerical differentiation") {
    val x = DenseTensor[Double](Array(1.0, 1.0, 1.0), Array(3, 1))
    val input = new Tensor(Array(1.0, 1.0, 1.0), Array(3, 1))
    // output must contain zeros and one 1 for SoftMax
    val target = new Tensor(Array(0.0, 1.0), Array(2, 1))
    val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false)
    val layersWithErrors = Seq(
      new SigmoidLayerWithSquaredError(),
      new SoftmaxLayerWithCrossEntropyLoss(),
      new SigmoidLayerWithCrossEntropyLoss(),
      new EmptyLayerWithSquaredError()
    )
    // check all layers that provide loss computation
    // 1) compute loss and gradient given the model and initial weights
    // 2) modify weights with small number epsilon (per dimension i)
    // 3) compute new loss
    // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient
    for (layerWithError <- layersWithErrors) {
      topology.layers(topology.layers.length - 1) = layerWithError
      val model = topology.model(seed = 12L)
      val weights = model.weights.toArray
      val numWeights = weights.size
      val gradient = new Tensor(Array(numWeights))
      val loss = model.computeGradient(input, target, gradient, 1)
      val eps = 1e-4
      var i = 0
      val tol = 1e-4
      while (i < numWeights) {
        val originalValue = weights(i)
        weights(i) += eps
        val newModel = topology.model(Vectors.dense(weights))
        val newLoss = computeLoss(input, target, newModel)
        val derivativeEstimate = (newLoss - loss) / eps
        assert(math.abs(gradient.value(i) - derivativeEstimate) < tol,
          "Layer failed gradient check: " + layerWithError.getClass)
        weights(i) = originalValue
        i += 1
      }
    }
  }

  private def computeLoss(input: Tensor, target: Tensor, model: TopologyModel): Double = {
    val outputs = model.forward(input)
    model.layerModels.last match {
      case layerWithLoss: LossFunction =>
        layerWithLoss.loss(outputs.last, target, new Tensor(target.shape))
      case _ =>
        throw new UnsupportedOperationException("Top layer is required to have loss." +
          " Failed layer:" + model.layerModels.last.getClass)
    }
  }
} 
Example 29
Source File: ORSetSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.op

import org.scalatest.{FunSuite, Matchers}
import org.scalatest.prop.PropertyChecks
import org.scalacheck.Gen
import com.github.nscala_time.time.Imports._
import com.machinomy.crdt.state.TombStone

class ORSetSuite extends FunSuite with PropertyChecks with Matchers {
  test("fresh is empty") {
    val set = ORSet[Int, DateTime]()
    assert(set.value.isEmpty)
  }

  test("add") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val set = ORSet[Int, DateTime]()
      val (nextSet, operation) = set.add(i)
      nextSet.value should be(Set(i))
      operation.isDefined should be(true)
      operation.get.isInstanceOf[ORSet.Add[_, _]] should be(true)
    }
  }

  test("remove if present") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val initial = ORSet[Int, DateTime]()
      val (set, _) = initial.add(i)
      val (finalSet, operation) = set.remove(i)
      finalSet.value should be(empty)
      operation.isDefined should be(true)
      operation.get.isInstanceOf[ORSet.Remove[_, _]] should be(true)
    }
  }

  test("remove if absent") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val initial = ORSet[Int, DateTime]()
      val (set, operation) = initial.remove(i)
      set.value should be(empty)
      operation shouldNot be(empty)
      operation.get.isInstanceOf[ORSet.Remove[_, _]] should be(true)
    }
  }

  test("add operation") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val set = ORSet[Int, DateTime]()
      val addOp = ORSet.Add(i, implicitly[TombStone[DateTime]].next)
      val (nextSet, operation) = set.run(addOp)
      nextSet.value should be(Set(i))
      operation shouldNot be(empty)
      operation.get.isInstanceOf[ORSet.Add[_, _]] should be(true)
    }
  }

  test("remove operation if present") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val initial = ORSet[Int, DateTime]()
      val (set, _) = initial.add(i)
      val removeOp = ORSet.Remove(i, set.state(i))
      val (finalSet, operation) = set.run(removeOp)
      finalSet.value should be(empty)
      operation.isDefined should be(true)
      operation.get.isInstanceOf[ORSet.Remove[_, _]] should be(true)
    }
  }

  test("remove operation if absent") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val initial = ORSet[Int, DateTime]()
      val removeOp = ORSet.Remove(i, Set(implicitly[TombStone[DateTime]].next))
      val (set, operation) = initial.run(removeOp)
      set.value should be(empty)
      operation should be(empty)
    }
  }
} 
Example 30
Source File: CounterSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.op

import org.scalatest.{FunSuite, Matchers}
import org.scalatest.prop.PropertyChecks
import org.scalacheck.Gen

class CounterSuite extends FunSuite with PropertyChecks with Matchers {
  test("fresh value is zero") {
    val counter = Counter[Int]()
    assert(counter.value == 0)
  }

  test("increment") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val increment = Counter.Increment(i)
      val counter = Counter.update(Counter[Int](), increment)
      counter.value should be(i)
    }
  }

  test("decrement") {
    forAll(Gen.posNum[Int]) { (i: Int) =>
      val decrement = Counter.Decrement(i)
      val counter = Counter.update(Counter[Int](), decrement)
      counter.value should be(-1 * i)
    }
  }
} 
Example 31
Source File: PartialOrderDagSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.op

import com.machinomy.crdt.state.TPSet
import org.scalatest.{FunSuite, Matchers}

import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._

class PartialOrderDagSuite extends FunSuite with Matchers {
  test("add vertex") {
    val g: Graph[Int, DiEdge] = Graph[Int, DiEdge]() + 1 ~> 100
    val edges: Set[DiEdge[Int]] = g.edges.toOuter
    val vertices = TPSet(g.nodes.toOuter)
    val dag = PartialOrderDag[Int, DiEdge](vertices, edges)
    val (dag2, op) = dag.add(2, 1, 100)
    dag2.value.edges shouldNot be(empty)
  }

  // @todo Actually, remove the vertex as a payload, but leave it as a chain link
  test("remove vertex - does nothing") {
    val g: Graph[Int, DiEdge] = Graph[Int, DiEdge]() + 1 ~> 100 + 1 ~> 2 + 2 ~> 100
    val edges = g.edges.toOuter
    val vertices = TPSet(g.nodes.toOuter)
    val dag = PartialOrderDag[Int, DiEdge](vertices, edges)
    val (dag2, op) = dag.remove(2)
    dag2.value.edges shouldNot be(empty)
  }
} 
Example 32
Source File: MonotonicDagSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.op

import org.scalatest.{FunSuite, Matchers}

import scalax.collection.Graph
import scalax.collection.GraphPredef._
import scalax.collection.GraphEdge._

class MonotonicDagSuite extends FunSuite with Matchers {
  test("add, edge") {
    val g: Graph[Int, DiEdge] = Graph[Int, DiEdge]() + 1 + 2
    val dag = MonotonicDag[Int, DiEdge, Graph[Int, DiEdge]](g)
    val edge = 1 ~> 2
    val (dag2, op) = dag.add(edge)
    dag2.value.edges shouldNot be(empty)
  }

  test("add, vertex") {
    val g: Graph[Int, DiEdge] = Graph[Int, DiEdge]() + 1 + 100
    val (dag, _)= MonotonicDag[Int, DiEdge, Graph[Int, DiEdge]](g).add(1 ~> 100)
    val (dag2, op) = dag.add(2, 1, 100)
    dag2.value.edges shouldNot be(empty)
  }
} 
Example 33
Source File: ORSetSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.state

import com.github.nscala_time.time.Imports._
import org.scalatest.FunSuite

class ORSetSuite extends FunSuite {
  test("Fresh ORSet is empty") {
    val fresh = ORSet[Int, DateTime]()
    assert(fresh.value.isEmpty)
  }

  test("ORSet could be updated") {
    val a = ORSet[Int, DateTime]() + 3
    assert(a.value === Set(3))

    val b = ORSet[Int, DateTime]() + 3 - 3
    assert(b.value === Set.empty)

    val now = DateTime.now()
    val c = ORSet[Int, DateTime]() + 3 - 3
    assert(c.value === Set.empty)

    val d = c + (3, now + 10.minutes)
    assert(d.value === Set(3))
  }

} 
Example 34
Source File: TPSetSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.state

import org.scalatest.FunSuite
import cats.syntax.all._

class TPSetSuite extends FunSuite {
  test("Fresh TPSet is empty") {
    val fresh = TPSet[Int]()
    assert(fresh.value.isEmpty)
  }

  test("TPSet could be updated") {
    val a = TPSet[Int]() + 3 - 3
    assert(a.value === Set.empty[Int])
    val b = TPSet[Int]() + 3 - 1
    assert(b.value === Set(3))
  }

  test("TPSet could be merged") {
    val a = TPSet[Int]() + 3 - 3
    val b = TPSet[Int]() + 1 - 1 + 2
    val c = a |+| b
    assert(c.value === Set(2))
  }
} 
Example 35
Source File: GSetSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.state

import cats.kernel.Eq
import org.scalatest.FunSuite
import cats.syntax.all._

class GSetSuite extends FunSuite {
  val eq = implicitly[Eq[GSet[Int]]]

  test("Just created GSet is empty") {
    val gSet = GSet[Int]()
    assert(gSet.value.isEmpty)
  }

  test("GSet calculates value") {
    val a = GSet[Int]()
    val b = a + 3
    val c = b + 1
    val d = c + 3
    assert(d.value === Set(1, 3))
  }

  test("GSets can be merged") {
    val a = GSet[Int](Set(1, 2, 3))
    val b = GSet[Int](Set(2, 3, 4))
    val result = a |+| b
    assert(result.value === Set(1, 2, 3, 4))
  }

  test("equality") {
    val a = GSet[Int]()
    val b = GSet[Int]()
    assert(eq.eqv(a, b))

    val a1 = a + 1
    assert(a1 !== b)

    val b1 = b + 1
    assert(eq.eqv(a1, b1))
  }
} 
Example 36
Source File: GCounterSuite.scala    From crdt   with Apache License 2.0
package com.machinomy.crdt.state

import cats.kernel.{Eq, Monoid}
import cats.syntax.all._
import org.scalacheck.Gen
import org.scalatest.FunSuite
import org.scalatest.prop.PropertyChecks

class GCounterSuite extends FunSuite with PropertyChecks {
  val replicaGen = Gen.posNum[Int]
  val valueGen = Gen.posNum[Int]
  val counterGen = for {
    id <- replicaGen
    value <- valueGen
  } yield GCounter[Int, Int]() + (id -> value)
  val eq = implicitly[Eq[GCounter[Int, Int]]]

  test("Fresh GCounter is empty") {
    val fresh = GCounter[Int, Int]()
    assert(Monoid[GCounter[Int, Int]].isEmpty(fresh))
    assert(fresh.value === 0)
  }

  test("could be calculated") {
    val counter = GCounter[Int, Int]().increment(1, 2).increment(2, 3)
    assert(counter.value === 5)
  }

  test("could be combined") {
    val a = GCounter[Int, Int]() + (1, 2) + (2, 3)
    val b = GCounter[Int, Int]() + (1, 2) + (3, 4)
    val c = a |+| b
    assert(c.value === 2 + 3 + 4)
  }

  test("could present replica value") {
    val a = GCounter[Int, Int]()
    assert(a.get(0) === 0)
    val b = a.increment(1, 2).increment(2, 3)
    assert(b.get(1) === 2)
    assert(b.get(2) === 3)
  }

  test("equality") {
    val a = GCounter[Int, Int]()
    val b = GCounter[Int, Int]()
    assert(eq.eqv(a, b))

    val a1 = a + (1 -> 1)
    assert(eq.neqv(a1, b))
    val b1 = b + (1 -> 2)
    assert(a1 !== b1)
    assert(eq.neqv(a1, b1))
    val a2 = a1 + (1 -> 1)
    assert(eq.eqv(a2, b1))
  }

  test("associativity") {
    forAll(Gen.listOfN(3, counterGen)) {
      case x :: y :: z :: Nil =>
        val left = x |+| (y |+| z)
        val right = (x |+| y) |+| z
        assert(eq.eqv(left, right))
      case _ => throw new RuntimeException("This is unexpected, really")
    }
  }

  test("commutativity") {
    forAll(Gen.listOfN(2, counterGen)) {
      case x :: y :: Nil =>
        val left = x |+| y
        val right = y |+| x
        assert(eq.eqv(left, right))
      case _ => throw new RuntimeException("This is unexpected, really")
    }
  }

  test("idempotency") {
    forAll(Gen.listOf(counterGen)) { list =>
      whenever(list.nonEmpty) {
        val counter = list.reduce(_ |+| _)
        assert(eq.eqv(counter, counter |+| counter))
      }
    }
  }
} 
Example 37
Source File: PNCounterSuite.scala    From crdt   with Apache License 2.0 5 votes vote down vote up
package com.machinomy.crdt.state

import cats._
import org.scalatest.FunSuite
import cats.syntax.all._

class PNCounterSuite extends FunSuite {
  test("Empty value is zero") {
    val fresh = Monoid[PNCounter[Int, Int]].empty
    assert(fresh.value === 0)
  }

  test("Could be calculated") {
    val counter = Monoid[PNCounter[Int, Int]].empty + (1 -> 1) + (2 -> 3)
    assert(counter.value === 1 + 3)
  }

  test("Could be merged") {
    val a = Monoid[PNCounter[Int, Int]].empty + (1 -> 1) + (2 -> 3)
    val b = Monoid[PNCounter[Int, Int]].empty + (1 -> 2) + (2 -> -3)
    val c = a |+| b
    assert(c.value === 2)
  }

  test("Could get replica value") {
    val a = Monoid[PNCounter[Int, Int]].empty + (1 -> 2) + (2 -> 3) + (1 -> -1)
    assert(a.get(1) === 1)
  }

  test("Could get table") {
    val a = Monoid[PNCounter[Int, Int]].empty + (1 -> 2) + (2 -> 3) + (1 -> -1)
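    // replica 1 nets 2 - 1 = 1, replica 2 keeps its single increment of 3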
    assert(a.table === Map(1 -> 1, 2 -> 3))
  }

  test("Can update table") {
    val a = Monoid[PNCounter[Int, Int]].empty + (1 -> 2) + (2 -> 3) + (1 -> -1)
    val b = a + (2 -> 5)
    assert(b.table === Map(1 -> 1, 2 -> 8))
  }
} 
Example 38
Source File: MCSetSuite.scala    From crdt   with Apache License 2.0 5 votes vote down vote up
package com.machinomy.crdt.state

import cats._
import cats.syntax.all._
import org.scalatest.FunSuite

class MCSetSuite extends FunSuite {
  test("Just created MCSet is empty") {
    val gSet = Monoid[MCSet[Int, Int]].empty
    assert(gSet.value.isEmpty)
  }

  test("MCSet calculates value") {
    val a = Monoid[MCSet[Int, Int]].empty + 3 + 1 + 3
    assert(a.value === Set(1, 3))
  }

  test("MCSet additions can be merged") {
    val a = Monoid[MCSet[Int, Int]].empty + 1 + 2 + 3
    val b = Monoid[MCSet[Int, Int]].empty + 2 + 3 + 4
    val result = a |+| b
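    // with only additions on both sides the merge behaves like a plain set union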
    assert(result.value === Set(1, 2, 3, 4))
  }

  test("MCSet additions and removals can be merged") {
    val a = Monoid[MCSet[Int, Int]].empty + 1 + 2 + 3
    val b = Monoid[MCSet[Int, Int]].empty - 2 - 3 - 4
    val c = a |+| b
    assert(c.value === Set(1, 2, 3))

    val d = a |+| (a - 2 - 3)
    assert(d.value === Set(1))
  }
} 
Example 39
Source File: LWWElementSetSuite.scala    From crdt   with Apache License 2.0 5 votes vote down vote up
package com.machinomy.crdt.state

import cats._
import cats.syntax.all._
import com.github.nscala_time.time.Imports._
import org.scalatest.FunSuite

class LWWElementSetSuite extends FunSuite {
  test("fresh is empty") {
    val set = Monoid[LWWElementSet[Int, DateTime, Bias.AdditionWins]].empty
    assert(set.value.isEmpty)
  }

  test("calculates value") {
    val a = Monoid[LWWElementSet[Int, DateTime, Bias.AdditionWins]].empty + 3 + 1 + 3
    assert(a.value === Set(1, 3))
  }

  test("can be combined, addition bias") {
    val now = DateTime.now
    val a = Monoid[LWWElementSet[Int, DateTime, Bias.AdditionWins]].empty + (1, now) + (2, now) + (3, now)
    val b = Monoid[LWWElementSet[Int, DateTime, Bias.AdditionWins]].empty - (2, now) - (3, now) - (4, now)
    val result = a |+| b
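    // the adds and removes of 2 and 3 carry the same timestamp, so the AdditionWins bias keeps them; 4 was only ever removed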
    assert(result.value === Set(1, 2, 3))
  }

  test("can be combined, removal bias") {
    val now = DateTime.now
    val a = Monoid[LWWElementSet[Int, DateTime, Bias.RemovalWins]].empty + (1, now) + (2, now) + (3, now)
    val b = Monoid[LWWElementSet[Int, DateTime, Bias.RemovalWins]].empty - (2, now) - (3, now) - (4, now)
    val result = a |+| b
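    // on the same timestamp tie the RemovalWins bias drops 2 and 3, leaving only 1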
    assert(result.value === Set(1))
  }
} 
Example 40
Source File: TDatabaseFactory.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.db.factory

import java.util.concurrent.TimeUnit

import com.gabry.job.db.slicks.{SlickDependencyAccess, SlickJobAccess, SlickScheduleAccess, SlickTaskAccess}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfterAll, FunSuite}

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{Await, ExecutionContextExecutor}


class TDatabaseFactory extends FunSuite with BeforeAndAfterAll{
  private implicit lazy val executionContext: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
  private val config = ConfigFactory.load()
  private val duration = FiniteDuration(3,TimeUnit.SECONDS)
  private val dataAccessFactory = DatabaseFactory.getDataAccessFactory(config).get
  override def beforeAll(): Unit = {
    super.beforeAll()
    dataAccessFactory.init()
  }
  override def afterAll(): Unit = {
    super.afterAll()
    dataAccessFactory.destroy()
  }
  test("TDatabaseFactory default jobAccess type"){
    val access = dataAccessFactory.getJobAccess
    assert(access.isInstanceOf[SlickJobAccess])
  }
  test("TDatabaseFactory jobAccess select"){
    val access = dataAccessFactory.getJobAccess
    assert(access.isInstanceOf[SlickJobAccess])

    val select = Await.result(access.selectOne("test"),duration)
    assert(select.isDefined)
    assert(select.get.name == "test")

  }
  test("TDatabaseFactory dependencyAccess type"){
    val access = dataAccessFactory.getDependencyAccess
    assert(access.isInstanceOf[SlickDependencyAccess])
  }
  test("TDatabaseFactory scheduleAccess type"){
    val access = dataAccessFactory.getScheduleAccess
    assert(access.isInstanceOf[SlickScheduleAccess])
  }
  test("TDatabaseFactory taskAccess type"){
    val access = dataAccessFactory.getTaskAccess
    assert(access.isInstanceOf[SlickTaskAccess])
  }
} 
Example 41
Source File: TScheduleAccess.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.db.slicks

import java.util.concurrent.TimeUnit

import com.gabry.job.core.domain.UID
import com.gabry.job.core.po.SchedulePo
import com.gabry.job.db.slicks.schema.Tables
import com.gabry.job.utils.Utils
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import slick.jdbc.MySQLProfile.api._

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{Await, ExecutionContextExecutor}

class TScheduleAccess extends FunSuite with BeforeAndAfterAll{
  implicit lazy val executionContext: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
  val db = Database.forConfig("",ConfigFactory.load().getConfig("db.mysql"))
  val scheduleAccess = new SlickScheduleAccess(db)
  val scheduleNode = "3958164162305738376-node"
  var jobIdAndTriggerTime: (UID, Long) = ("999",1523497644627L)

  val schedulePo:Tables.SchedulesRow = SchedulePo("0",jobIdAndTriggerTime._1,2,3,false,
    jobIdAndTriggerTime._2,scheduleNode,123,false,
    Utils.calcPostOffsetTime(jobIdAndTriggerTime._2,0,TimeUnit.MINUTES),null)
  val duration = FiniteDuration(3,TimeUnit.SECONDS)
  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    Await.result(scheduleAccess.delete(schedulePo) ,duration)
    db.close()
  }
  test("ScheduleAccess insert"){
    val insert = Await.result(scheduleAccess.insert(schedulePo.copy(jobUid = schedulePo.jobUid)) ,duration)
    assert(insert != null )
  }
  test("ScheduleAccess insertOnDuplicateUpdate"){
    val insert1 = Await.result(scheduleAccess.insertOnDuplicateUpdate(schedulePo),duration)
    val insert2 = Await.result(scheduleAccess.insertOnDuplicateUpdate(schedulePo),duration)
    assert(insert1 > 0 )
    assert(insert2 > 0 )
  }

  test("ScheduleAccess select setDispatched"){
    val select = Await.result(scheduleAccess.selectOne(jobIdAndTriggerTime),duration)
    assert(select.isDefined)
    assert(select.get.jobUid == jobIdAndTriggerTime._1 && select.get.triggerTime == jobIdAndTriggerTime._2)
    val update = Await.result(scheduleAccess.setDispatched(select.get.uid,true),duration)
    assert(update > 0 )

    val select1 = Await.result(scheduleAccess.selectOne(jobIdAndTriggerTime),duration)
    assert(select1.isDefined)
    assert(select1.get.dispatched)
  }
  test("ScheduleAccess update"){
    val updateScheduleNode = "updateScheduleNode"
    val old = Await.result(scheduleAccess.selectOne(jobIdAndTriggerTime),duration)
    assert(old.isDefined)
    assert(old.get.scheduleNode!=updateScheduleNode)

    val update = Await.result(scheduleAccess.update(schedulePo.copy(scheduleNode = updateScheduleNode)),duration)
    assert(update > 0 )
    val newJob = Await.result(scheduleAccess.selectOne(jobIdAndTriggerTime),duration)
    assert(newJob.isDefined)
    assert(newJob.get.scheduleNode == updateScheduleNode)
  }
  test("ScheduleAccess selectUnDispatchSchedule"){
    scheduleAccess.selectUnDispatchSchedule("1",scheduleNode,jobIdAndTriggerTime._2+30,2){ r=>
      assert(!r.dispatched)
    }
  }
} 
Example 42
Source File: TTaskAccess.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.db.slicks

import java.util.concurrent.TimeUnit

import com.gabry.job.core.domain.UID
import com.gabry.job.db.slicks.schema.Tables
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import slick.jdbc.MySQLProfile.api._

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{Await, ExecutionContextExecutor}

class TTaskAccess extends FunSuite with BeforeAndAfterAll{
  implicit lazy val executionContext: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
  val db = Database.forConfig("",ConfigFactory.load().getConfig("db.mysql"))
  val taskAccess = new SlickTaskAccess(db)
  var jobIdAndTriggerTime: (UID, Long) = ("999",1523497644627L)
  val taskTrackerNode = "3958164162305738376-node"
  val taskPo:Tables.TasksRow = Tables.TasksRow(-1,jobIdAndTriggerTime._1,jobIdAndTriggerTime._1,"-1",1,taskTrackerNode,"TEST",jobIdAndTriggerTime._2,Some("test"),null)
  val duration = FiniteDuration(3,TimeUnit.SECONDS)

  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    db.close()
  }
  test("TaskAccess insert, select,delete"){
    val insert = Await.result(taskAccess.insert(taskPo) ,duration)
    assert(insert!=null)
    assert(insert.state==taskPo.state)
    val select = Await.result(taskAccess.selectOne(insert.uid),duration)
    assert(select.isDefined)
    assert(select.get.state==insert.state)
    val delete = Await.result(taskAccess.delete(insert),duration)
    assert(delete>0)
    val select1 = Await.result(taskAccess.selectOne(insert.uid),duration)
    assert(select1.isEmpty)
  }
  test("TaskAccess insertOnDuplicateUpdate"){
    val insert = Await.result(taskAccess.insertOnDuplicateUpdate(taskPo) ,duration)
    assert(insert==0)
  }
} 
Example 43
Source File: TJobAccess.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.db.slicks

import java.util.concurrent.TimeUnit

import com.gabry.job.core.domain.Job
import com.gabry.job.db.slicks
import com.gabry.job.db.slicks.schema.Tables
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import slick.jdbc.MySQLProfile.api._

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{Await, ExecutionContextExecutor}


class TJobAccess extends FunSuite with BeforeAndAfterAll{
  implicit lazy val executionContext: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
  val db = Database.forConfig("",ConfigFactory.load().getConfig("db.mysql"))
  val jobAccess = new SlickJobAccess(db)
  val scheduleNode = "3958164162305738376-node"
  val job:Tables.JobsRow = slicks.jobPo2Row(Job("0", "3958164162305738376-test","com.gabry.job.examples.TestTask","",0,TimeUnit.MINUTES))
    .copy(schedulerNode = Some(scheduleNode))
  val duration = FiniteDuration(3,TimeUnit.SECONDS)
  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    jobAccess.delete(job)
    db.close()
  }
  test("JobAccess insert"){
    val insert = Await.result(jobAccess.insert(job),duration)
    assert(insert != null )
    assert(insert.name == job.name)
  }
  test("JobAccess select"){
    val select = Await.result(jobAccess.selectOne(job.name),duration)
    assert(select.isDefined)
    assert(select.get.name == job.name)
  }
  test("JobAccess update"){
    val updateClassName = "updateClassName"
    val old = Await.result(jobAccess.selectOne(job.name),duration)
    assert(old.isDefined)
    assert(old.get.className!=updateClassName)
    val update = Await.result(jobAccess.update(job.copy(className = updateClassName)),duration)
    assert(update > 0 )
    val newJob = Await.result(jobAccess.selectOne(job.name),duration)
    assert(newJob.isDefined)
    assert(newJob.get.className==updateClassName)
  }
  test("JobAccess selectJobsByScheduleNode"){
    jobAccess.selectJobsByScheduleNode(scheduleNode){ r =>
      assert(r.schedulerNode.isDefined && r.schedulerNode.get == scheduleNode)
    }
  }
  test("JobAccess insertOnDuplicateUpdate"){
    val insert1 = Await.result(jobAccess.insertOnDuplicateUpdate(job),duration)
    val insert2 = Await.result(jobAccess.insertOnDuplicateUpdate(job),duration)
    assert(insert1>0)
    assert(insert2>0)
  }
  test("JobAccess delete"){
    val delete = Await.result(jobAccess.delete(job),duration)
    assert(delete > 0 )
    val select = Await.result(jobAccess.selectOne(job.name),duration)
    assert(select.isEmpty)
  }
} 
Example 44
Source File: TInsertTime.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.db.slicks

import java.util.concurrent.TimeUnit

import com.gabry.job.core.builder.JobBuilder
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import slick.jdbc.MySQLProfile.api._

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{Await, ExecutionContextExecutor, Future}

class TInsertTime extends FunSuite with BeforeAndAfterAll{
  implicit lazy val executionContext: ExecutionContextExecutor = scala.concurrent.ExecutionContext.global
  val db = Database.forConfig("",ConfigFactory.load().getConfig("db.mysql"))
  val jobAccess = new SlickJobAccess(db)
  val duration = FiniteDuration(3,TimeUnit.DAYS)

  override def beforeAll(): Unit = {
    super.beforeAll()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    db.close()
  }
  test("InsertTime"){
    val recordNum = 10000
    val futures = 1 to recordNum map{ i =>

      val job = JobBuilder().withName(i.toString)
        .withClass("com.gabry.job.examples.TestTask")
        .withDataTimeOffset(0)
        .withDataTimeOffsetUnit(TimeUnit.MINUTES)
        .build()

      jobAccess.insert(job)
    }
    val start = System.currentTimeMillis()
    val all = Future.sequence(futures)
    Await.result(all,duration)
    val end = System.currentTimeMillis()
    println(s"插入 $recordNum 条数据,总耗时 ${end-start} 毫秒,平均 ${(end-start)/recordNum} 毫秒/条")
  }
} 
Example 45
Source File: TestTaskExample.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.examples

import com.gabry.job.core.task.Task
import com.gabry.job.utils.TaskClassLoader
import org.scalatest.{BeforeAndAfterAll, FunSuite}


class TestTaskExample extends FunSuite with BeforeAndAfterAll{
  val classLoader = new TaskClassLoader("target/lemon-schedule-examples-1.0-SNAPSHOT.jar")

  override def beforeAll(): Unit = {
    super.beforeAll()
    classLoader.init()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    classLoader.destroy()
  }
  test("TestTaskExample task example"){
    val taskClaz = classLoader.loadInstance[Task]("com.gabry.job.examples.TaskExample")
    assert(taskClaz.isSuccess)
    val task = taskClaz.get
    assert(task!=null)
    task.initialize()
    task.destroy()
  }
} 
Example 46
Source File: TZookeeperRegistry.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.test

import com.gabry.job.core.domain.Node
import com.gabry.job.core.registry.RegistryEvent.RegistryEvent
import com.gabry.job.core.registry.{AbstractRegistry, RegistryFactory, RegistryListener}
import com.typesafe.config.{Config, ConfigFactory}
import org.scalatest.{BeforeAndAfterAll, FunSuite}


// NOTE: this listing is truncated; the class header and the construction of `registry`
// and `regNode` below are a hedged reconstruction, and the RegistryFactory/subscribe
// calls and Node arguments are assumptions rather than the original code.
class TZookeeperRegistry extends FunSuite with BeforeAndAfterAll {
  private val config: Config = ConfigFactory.parseString(zkConfigStr)
  private val registry = RegistryFactory.getRegistry(config).get // assumed factory API
  private val regNode: Node = Node("test-node", "127.0.0.1")     // assumed Node fields

  test("zookeeper registry notifies a listener when a node is registered") {
    registry.subscribe(new RegistryListener { // assumed subscription API
      override def onEvent(node: Node, event: RegistryEvent): Unit = {
        assert(regNode == node)
      }
    })
    registry.registerNode(regNode)
    Thread.sleep(1000)
  }
  def zkConfigStr = """registry{
                   |  type = "zookeeper"
                   |  zookeeper{
                   |    hosts = "dn1:2181,dn3:2181,dn4:2181"
                   |    exponential-backoff-retry {
                   |      base-sleep-timeMs = 1000
                   |      max-retries = 3
                   |    }
                   |    root-path = "/lemon-schedule"
                   |  }
                   |}
                   """.stripMargin
} 
Example 47
Source File: TTask.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.test

import com.gabry.job.core.builder.CommandJobBuilder
import com.gabry.job.core.domain.JobContext
import com.gabry.job.core.task.impl.CommandTask
import org.scalatest.FunSuite


class TTask extends FunSuite {
  test("TaskExample commandTask"){
    val job = CommandJobBuilder()
      .withEnv("testEnv","envValue")
      .withCommand("ls -l")
      .build()
    val jobContext = JobContext(job,null,0)
    val task = new CommandTask
    task.run(jobContext,0,0)
  }
} 
Example 48
Source File: TCronGenerator.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.utils

import org.scalatest.FunSuite


class TCronGenerator extends FunSuite{
  test("CronGenerator isValid"){
    assert(CronGenerator.isValid("0/1 * * * *"))
    assert(!CronGenerator.isValid("abcdef"))
  }
  test("CronGenerator getDescribe"){
    assert(CronGenerator.getDescribe("0/1 * * * *")!="")
    intercept[IllegalArgumentException]{
      CronGenerator.getDescribe("abcdef")
    }
  }
  test("CronGenerator getNextTriggerTime"){
    val minutesCron = "0/1 * * * *"
    val start = System.currentTimeMillis()
    val next1 = CronGenerator.getNextTriggerTime(minutesCron,start)
    val next2 = CronGenerator.getNextTriggerTime(minutesCron,next1.get)
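    // "0/1 * * * *" fires every minute, so next1 is less than a minute after start and next2 is exactly one minute after next1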
    assert(next2.get-next1.get==60*1000)
    assert(next1.get-start<60*1000)
  }
  test("CronGenerator getPreviousTriggerTime"){
    val minutesCron = "0/1 * * * *"
    val start = System.currentTimeMillis()
    val next1 = CronGenerator.getNextTriggerTime(minutesCron,start)
    val next2 = CronGenerator.getNextTriggerTime(minutesCron,next1.get)
    val prev = CronGenerator.getPreviousTriggerTime(minutesCron,next2.get)
    assert(prev.get==next1.get)
  }
  test("CronGenerator getNextTriggerTime for now"){
    val minutesCron = "0/1 * * * *"
    val start = System.currentTimeMillis()
    val next = CronGenerator.getNextTriggerTime(minutesCron)
    assert(next.get-start<60*1000)
  }
  test("CronGenerator None"){
    val minutesCron = "0/1 *"
    val start = System.currentTimeMillis()
    val next1 = CronGenerator.getNextTriggerTime(minutesCron)
    val next2 = CronGenerator.getNextTriggerTime(minutesCron,start)
    val prev = CronGenerator.getPreviousTriggerTime(minutesCron,start)
    assert(next1.isEmpty && next2.isEmpty && prev.isEmpty)
  }
} 
Example 49
Source File: TTaskClassLoader.scala    From lemon-schedule   with GNU General Public License v2.0 5 votes vote down vote up
package com.gabry.job.utils

import org.scalatest.{BeforeAndAfterAll, FunSuite}


class TTaskClassLoader extends FunSuite with BeforeAndAfterAll{
  val taskClassLoader = new TaskClassLoader("../lemon-schedule-examples/target/lemon-schedule-examples-1.0-SNAPSHOT.jar")
  override def beforeAll(): Unit = {
    super.beforeAll()
    taskClassLoader.init()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    taskClassLoader.destroy()

  }
  test("taskClassLoader"){
//    intercept[NoClassDefFoundError]{
      taskClassLoader.load("com.gabry.job.examples.TestTask")
      // this is not convenient to test here, so no separate test is written
  //  }
  }
} 
Example 50
Source File: PlySuite.scala    From spark-iqmulus   with Apache License 2.0 5 votes vote down vote up
package fr.ign.spark.iqmulus.ply

import org.scalatest.FunSuite
import org.scalatest.ShouldMatchers
import org.apache.spark.sql.types._

class PlySuite extends FunSuite with ShouldMatchers {

  val id = Array("fid" -> IntegerType, "pid" -> LongType)
  val xyz = Array("x" -> FloatType, "y" -> FloatType, "z" -> FloatType)
  val rgb = Array("r" -> ByteType, "g" -> ByteType, "b" -> ByteType)

  val files = Seq(
    ("trepied_xyz.ply", 5995, id ++ xyz) // ,
  //   ("trepied_dim.ply", 5995, id ++ xyz ++ rgb),
  //   ("trepied_dim2.ply", 5995, id ++ xyz ++ rgb),
  //   ("213-232-7.ply", 71651, id ++ xyz ++ rgb)
  )

  val resources = "src/test/resources"

  files foreach {
    case (file, count, fields) =>
      if (new java.io.File(s"$resources/$file").exists) {
        test(s"$file should read the correct header metadata") {
          val header = PlyHeader.read(s"$resources/$file");
          header.section("vertex").count should equal(count)
        }

        test(s"$file should have the correct schema") {
          val header = PlyHeader.read(s"$resources/$file");
          header.section("vertex").schema should equal(StructType(fields map {
            case (name, dataType) => StructField(name, dataType, nullable = false)
          }))
        }
      }
  }
} 
Example 51
Source File: DisplayDataSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.kernel.protocol.v5.content

import org.scalatest.FunSuite

import org.scalatest.{Matchers, FunSpec}
import play.api.data.validation.ValidationError
import play.api.libs.json._
import org.apache.toree.kernel.protocol.v5._

class DisplayDataSpec extends FunSpec with Matchers {
  val displayDataJson: JsValue = Json.parse("""
  {
    "source": "<STRING>",
    "data": {},
    "metadata": {}
  }
  """)

  val displayData: DisplayData = DisplayData(
    "<STRING>", Map(), Map()
  )

  describe("DisplayData") {
    describe("#toTypeString") {
      it("should return correct type") {
        DisplayData.toTypeString should be ("display_data")
      }
    }

    describe("implicit conversions") {
      it("should implicitly convert from valid json to a displayData instance") {
        // This is the least safe way to convert as an error is thrown if it fails
        displayDataJson.as[DisplayData] should be (displayData)
      }

      it("should also work with asOpt") {
        // This is safer, but we lose the error information as it returns
        // None if the conversion fails
        val newDisplayData = displayDataJson.asOpt[DisplayData]

        newDisplayData.get should be (displayData)
      }

      it("should also work with validate") {
        // This is the safest as it collects all error information (not just first error) and reports it
        val displayDataResults = displayDataJson.validate[DisplayData]

        displayDataResults.fold(
          (invalid: Seq[(JsPath, Seq[ValidationError])]) => println("Failed!"),
          (valid: DisplayData) => valid
        ) should be (displayData)
      }

      it("should implicitly convert from a displayData instance to valid json") {
        Json.toJson(displayData) should be (displayDataJson)
      }
    }
  }
} 
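The comments in the spec above contrast Play JSON's three read styles: as throws on failure, asOpt returns None and loses the error details, and validate keeps every validation error. A minimal standalone sketch of the same trade-off, using a hypothetical Person case class instead of the kernel's DisplayData type:

import play.api.libs.json._

case class Person(name: String, age: Int)

object PersonJsonSketch extends App {
  implicit val personFormat: OFormat[Person] = Json.format[Person]

  val js = Json.parse("""{"name": "Ada", "age": 36}""")

  val strict: Person = js.as[Person]              // throws JsResultException if the shape is wrong
  val optional: Option[Person] = js.asOpt[Person] // None on failure, error details are lost
  val validated: JsResult[Person] = js.validate[Person] // keeps every validation error

  validated.fold(
    errors => println(s"Failed: $errors"),
    person => println(s"Parsed: $person")
  )
}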
Example 52
Source File: HBaseCatalogSuite.scala    From hbase-connectors   with Apache License 2.0 5 votes vote down vote up
package org.apache.hadoop.hbase.spark

import org.apache.hadoop.hbase.spark.datasources.{DataTypeParserWrapper, DoubleSerDes, HBaseTableCatalog}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

class HBaseCatalogSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll  with Logging {

  val map = s"""MAP<int, struct<varchar:string>>"""
  val array = s"""array<struct<tinYint:tinyint>>"""
  val arrayMap = s"""MAp<int, ARRAY<double>>"""
  val catalog = s"""{
                    |"table":{"namespace":"default", "name":"htable"},
                    |"rowkey":"key1:key2",
                    |"columns":{
                    |"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
                    |"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
                    |"col3":{"cf":"cf1", "col":"col2", "type":"binary"},
                    |"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"},
                    |"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"},
                    |"col6":{"cf":"cf1", "col":"col5", "type":"$map"},
                    |"col7":{"cf":"cf1", "col":"col6", "type":"$array"},
                    |"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"},
                    |"col9":{"cf":"cf1", "col":"col8", "type":"date"},
                    |"col10":{"cf":"cf1", "col":"col9", "type":"timestamp"}
                    |}
                    |}""".stripMargin
  val parameters = Map(HBaseTableCatalog.tableCatalog->catalog)
  val t = HBaseTableCatalog(parameters)

  def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
    test(s"parse ${dataTypeString.replace("\n", "")}") {
      assert(DataTypeParserWrapper.parse(dataTypeString) === expectedDataType)
    }
  }
  test("basic") {
    assert(t.getField("col1").isRowKey == true)
    assert(t.getPrimaryKey == "key1")
    assert(t.getField("col3").dt == BinaryType)
    assert(t.getField("col4").dt == TimestampType)
    assert(t.getField("col5").dt == DoubleType)
    assert(t.getField("col5").serdes != None)
    assert(t.getField("col4").serdes == None)
    assert(t.getField("col1").isRowKey)
    assert(t.getField("col2").isRowKey)
    assert(!t.getField("col3").isRowKey)
    assert(t.getField("col2").length == Bytes.SIZEOF_DOUBLE)
    assert(t.getField("col1").length == -1)
    assert(t.getField("col8").length == -1)
    assert(t.getField("col9").dt == DateType)
    assert(t.getField("col10").dt == TimestampType)
  }

  checkDataType(
    map,
    t.getField("col6").dt
  )

  checkDataType(
    array,
    t.getField("col7").dt
  )

  checkDataType(
    arrayMap,
    t.getField("col8").dt
  )

  test("convert") {
    val m = Map("hbase.columns.mapping" ->
      "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
      "hbase.table" -> "t1")
    val map = HBaseTableCatalog.convert(m)
    val json = map.get(HBaseTableCatalog.tableCatalog).get
    val parameters = Map(HBaseTableCatalog.tableCatalog->json)
    val t = HBaseTableCatalog(parameters)
    assert(t.getField("KEY_FIELD").isRowKey)
    assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt)
    assert(!t.getField("A_FIELD").isRowKey)
    assert(DataTypeParserWrapper.parse("DOUBLE") === t.getField("B_FIELD").dt)
    assert(DataTypeParserWrapper.parse("BINARY") === t.getField("C_FIELD").dt)
  }

  test("compatibility") {
    val m = Map("hbase.columns.mapping" ->
      "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
      "hbase.table" -> "t1")
    val t = HBaseTableCatalog(m)
    assert(t.getField("KEY_FIELD").isRowKey)
    assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt)
    assert(!t.getField("A_FIELD").isRowKey)
    assert(DataTypeParserWrapper.parse("DOUBLE") === t.getField("B_FIELD").dt)
    assert(DataTypeParserWrapper.parse("BINARY") === t.getField("C_FIELD").dt)
  }
} 
Example 53
Source File: KeytabSettingsTest.scala    From stream-reactor   with Apache License 2.0 5 votes vote down vote up
package com.datamountaineer.streamreactor.connect.hbase.kerberos

import com.datamountaineer.streamreactor.connect.hbase.config.{HBaseConfig, HBaseConfigConstants}
import org.apache.kafka.common.config.ConfigException
import org.scalatest.{FunSuite, Matchers}

import scala.collection.JavaConverters._

class KeytabSettingsTest extends FunSuite with Matchers with FileCreation {
  test("validate a keytab setting") {
    val file = createFile("keytab1.keytab")
    try {
      val principal = "[email protected]"
      val config = HBaseConfig(
        Map(
          HBaseConfigConstants.KCQL_QUERY->s"INSERT INTO someTable SELECT * FROM someTable",
          HBaseConfigConstants.COLUMN_FAMILY->"someColumnFamily",
          HBaseConfigConstants.KerberosKey -> "true",
          HBaseConfigConstants.PrincipalKey -> principal,
          HBaseConfigConstants.KerberosKeyTabKey -> file.getAbsolutePath
        ).asJava
      )

      val actualSettings = KeytabSettings.from(config, HBaseConfigConstants)
      actualSettings shouldBe KeytabSettings(principal, file.getAbsolutePath, None)
    }
    finally {
      file.delete()
    }
  }

  test("throws an exception when principal is not set") {
    val file = createFile("keytab2.keytab")
    try {
      val principal = "[email protected]"
      val config = HBaseConfig(
        Map(
          HBaseConfigConstants.KCQL_QUERY->s"INSERT INTO someTable SELECT * FROM someTable",
          HBaseConfigConstants.COLUMN_FAMILY->"someColumnFamily",
          HBaseConfigConstants.KerberosKey -> "true",
          HBaseConfigConstants.KerberosKeyTabKey -> file.getAbsolutePath
        ).asJava
      )

      intercept[ConfigException] {
        KeytabSettings.from(config, HBaseConfigConstants)
      }
    }
    finally {
      file.delete()
    }
  }

  test("throws an exception when the keytab is not present") {
    val principal = "[email protected]"
    val config = HBaseConfig(
      Map(
        HBaseConfigConstants.KCQL_QUERY->s"INSERT INTO someTable SELECT * FROM someTable",
        HBaseConfigConstants.COLUMN_FAMILY->"someColumnFamily",
        HBaseConfigConstants.KerberosKey -> "true",
        HBaseConfigConstants.PrincipalKey -> principal,
        HBaseConfigConstants.KerberosKeyTabKey -> "does_not_exists.keytab"
      ).asJava
    )

    intercept[ConfigException] {
      KeytabSettings.from(config, HBaseConfigConstants)
    }
  }
} 
Example 54
Source File: JwtTokenAuthMiddlewareSpec.scala    From core   with Apache License 2.0 5 votes vote down vote up
package com.smartbackpackerapp.http.auth

import cats.effect.IO
import com.smartbackpackerapp.common.IOAssertion
import org.http4s.server.AuthMiddleware
import org.scalatest.{FunSuite, Matchers}

class JwtTokenAuthMiddlewareSpec extends FunSuite with Matchers {

  val token = "insert_here_your_long_long_access_token"

  test("it fails to create an auth middleware") {
    IOAssertion {
      new Middleware[IO](None).middleware.attempt.map { result =>
        assert(result.isLeft)
      }
    }
  }

  test("it create an auth middleware") {
    IOAssertion {
      new Middleware[IO](Some(token)).middleware.map { result =>
        result shouldBe an [AuthMiddleware[IO, String]]
      }
    }
  }

} 
Example 55
Source File: ConversionsSpec.scala    From core   with Apache License 2.0 5 votes vote down vote up
package com.smartbackpackerapp.scraper.sql

import com.smartbackpackerapp.model._
import com.smartbackpackerapp.scraper.model.VisaRequirementsFor
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck._
import org.scalatest.FunSuite
import org.scalatest.prop.PropertyChecks

class ConversionsSpec extends FunSuite with ConversionsArbitraries with PropertyChecks {

  forAll { (vr: VisaRequirementsFor) =>
    test(s"convert a $vr into a VisaRequirementsDTO") {
      val dto = (vr.from.value, vr.to.value, vr.visaCategory.toString, vr.description)
      assert(vr.toVisaRequirementsDTO == dto)
    }
  }

  forAll { (v: Vaccine) =>
    test(s"convert a $v into a VaccineDTO # ${v.hashCode()}") {
      val dto = (v.disease.value, v.description, v.diseaseCategories.map(_.toString).mkString(","))
      assert(v.toVaccineDTO == dto)
    }
  }

}

trait ConversionsArbitraries {

  implicit val visaCategory: Arbitrary[VisaCategory] = Arbitrary[VisaCategory] {
    val list = List(
      VisaNotRequired, VisaWaiverProgram, AdmissionRefused, TravelBanned,
      VisaRequired, VisaDeFactoRequired, ElectronicVisa, ElectronicVisitor,
      ElectronicTravelAuthority, FreeVisaOnArrival, VisaOnArrival,
      ElectronicVisaPlusVisaOnArrival, OnlineReciprocityFee,
      MainlandTravelPermit, HomeReturnPermitOnly, UnknownVisaCategory
    )
    Gen.oneOf(list)
  }

  implicit val visaRequirementsFor: Arbitrary[VisaRequirementsFor] = Arbitrary[VisaRequirementsFor] {
    for {
      f <- Gen.alphaUpperStr
      t <- Gen.alphaUpperStr
      c <- arbitrary[VisaCategory]
      d <- Gen.alphaStr
    } yield VisaRequirementsFor(CountryCode(f), CountryCode(t), c, d)
  }

  implicit val diseaseCategory: Arbitrary[DiseaseCategory] = Arbitrary[DiseaseCategory] {
    val list = List(
      AvoidNonSterileEquipment, TakeAntimalarialMeds, GetVaccinated,
      AvoidSharingBodyFluids, ReduceExposureToGerms, EatAndDrinkSafely,
      PreventBugBites, KeepAwayFromAnimals, UnknownDiseaseCategory
    )
    Gen.oneOf(list)
  }

  implicit val vaccine: Arbitrary[Vaccine] = Arbitrary[Vaccine] {
    for {
      d <- Gen.alphaStr
      x <- Gen.alphaStr
      c <- Gen.listOf(arbitrary[DiseaseCategory])
    } yield Vaccine(Disease(d), x, c)
  }

} 
Example 56
Source File: AwsClusterAdvanceTest.scala    From berilia   with Apache License 2.0 5 votes vote down vote up
package com.criteo.dev.cluster

import com.criteo.dev.cluster.aws._
import com.criteo.dev.cluster.command.{SshHiveAction, SshMultiAction}
import com.criteo.dev.cluster.docker.{CreateGatewayCliAction, DestroyGatewayCliAction}
import com.criteo.dev.cluster.utils.test.LoadConfig
import org.scalatest.{BeforeAndAfter, FunSuite}


class AwsClusterAdvanceTest extends FunSuite with BeforeAndAfter with LoadConfig {


  def testDbName = "testdb"
  def testTableName = "testtable"
  def testFileName = "testfile"
  def currentUser = System.getenv("USER")

  var clusterId: String = null

  test("Create a cluster, populate cluster") {

    //Create a docker cluster
    val cluster = CreateAwsCliAction(List("3"), config)
    clusterId = cluster.master.id

    assertResult(2)(cluster.slaves.size)
    assertResult(AwsRunning)(cluster.master.status)
    assertResult(currentUser)(cluster.user)

    //create database, table, and 2 partitions.
    val master = NodeFactory.getAwsNode(config.target.aws, cluster.master)
    SshHiveAction(master, List(s"create database $testDbName"))
    SshHiveAction(master, List(s"create table $testDbName.$testTableName (name string) partitioned by (month int)"))
    SshHiveAction(master, List(
      s"alter table $testDbName.$testTableName add partition (month=1)",
      s"alter table $testDbName.$testTableName add partition (month=1)"))

    //load data into the partitions (2 rows per partition)
    SshMultiAction(master, List(
      s"echo a | tee --append $testFileName",
      s"echo b | tee --append $testFileName"))
    SshHiveAction(master, List(
      s"load data local inpath '$testFileName' into table $testDbName.$testTableName partition (month=1)",
      s"load data local inpath '$testFileName' into table $testDbName.$testTableName partition (month=2)"))
  }


  test("Reconfigure a cluster and test the query.") {
    var clusters = ListAwsCliAction(List(), config)
    val cluster = getCluster(clusterId, clusters)
    val master = NodeFactory.getAwsNode(config.target.aws, cluster.master)
    ConfigureAwsCliAction(List(clusterId), config)
    RestartServicesCliAction(List(clusterId), config)

    //Run Hive Query, verify count is correct
    val results = SshHiveAction(master, List(s"select count(*) from $testDbName.$testTableName"))
    assertResult("4") (results.stripLineEnd)
  }

  test("Create a gateway") {
    CreateGatewayCliAction(List(clusterId), config)
    //TODO- Docker Gateway is in interactive mode as part of the CLI, so doesn't actually run as part of the program.
    //No way to unit test for now.. should we add a background mode like local-cluster?
    DestroyGatewayCliAction(List(), config)
  }

  def getCluster(clusterId: String, clusters: List[AwsCluster]) : AwsCluster = {
    val results = clusters.filter(_.master.id.equals(clusterId))
    if (results.length > 1) {
      fail("Invalid state, more than one cluster returned.")
    }

    if (results.length == 0) {
      fail(s"AWS cluster $clusterId not found")
    }
    results.last
  }

} 
Example 57
Source File: NavigationTest.scala    From metabrowse   with Apache License 2.0 5 votes vote down vote up
package metabrowse

import org.scalatest.FunSuite
import monaco.Range

class NavigationTest extends FunSuite {
  test("Navigation.parseState") {
    val state = Navigation.parseState("/path")
    assert(state.isDefined)
    assert(state.get.path == "/path")
    assert(state.get.selection.isEmpty)

    val stateWithSelection = Navigation.parseState("/path2#L11")
    assert(stateWithSelection.isDefined)
    assert(stateWithSelection.get.path == "/path2")
    assert(
      stateWithSelection.get.selection == Some(
        Navigation.Selection(11, 1, 11, 1)
      )
    )

    assert(Navigation.parseState("").isEmpty)
    assert(Navigation.parseState("#/path2#L11").isEmpty)
  }

  test("Navigation.fromHistoryState") {
    val noState = Navigation.fromHistoryState(null)
    assert(noState.isEmpty)

    val stateFromHash = Navigation.fromHistoryState("/path")
    assert(stateFromHash.nonEmpty)
    assert(stateFromHash.get.path == "/path")
    assert(stateFromHash.get.selection.isEmpty)

    val stateFromHashWithSelection = Navigation.fromHistoryState("/path2#L11")
    assert(stateFromHashWithSelection.isDefined)
    assert(stateFromHashWithSelection.get.path == "/path2")
    assert(
      stateFromHashWithSelection.get.selection == Some(
        Navigation.Selection(11, 1, 11, 1)
      )
    )
  }

  test("Navigation.Selection.toString") {
    assert(Navigation.Selection(11, 1, 11, 1).toString == "L11")
    assert(Navigation.Selection(11, 4, 11, 4).toString == "L11C4")
    assert(Navigation.Selection(11, 1, 12, 4).toString == "L11-L12C4")
    assert(Navigation.Selection(11, 2, 12, 4).toString == "L11C2-L12C4")
    assert(Navigation.Selection(11, 2, 12, 1).toString == "L11C2-L12")
  }

  test("Navigation.parseSelection") {
    val str = "L10C4-L14C20"
    val Some(parsed) = Navigation.parseSelection(str)

    assert(parsed == Navigation.Selection(10, 4, 14, 20))
    assert(parsed.toString == str)

    assert(Navigation.parseSelection("L10-C1") == None)
  }

  test("Navigation.parseSelection normalization") {
    val selection = Navigation.Selection(1, 1, 2, 1)
    assert(Navigation.parseSelection("L1-L2") == Some(selection))
    assert(Navigation.parseSelection("L1C1-L2") == Some(selection))
    assert(Navigation.parseSelection("L1-L2C1") == Some(selection))
    assert(Navigation.parseSelection("L1C1-L2C1") == Some(selection))
  }

  test("Navigation.parseSelection roundtrip") {
    val samples =
      """L1
        |L1C2
        |L1-L3
        |L1C2-L2
        |L1-L1C2
        |L110-L112C25
        |L110C42-L112C25
        |
        |L10-L2
        |"""

    samples.stripMargin.split('\n').filter(_.nonEmpty).foreach { selection =>
      assert(
        Some(selection) == Navigation.parseSelection(selection).map(_.toString)
      )
    }
  }
} 
Example 58
Source File: PakoSuite.scala    From metabrowse   with Apache License 2.0 5 votes vote down vote up
package metabrowse

import metabrowse.schema.Workspace
import org.scalatest.FunSuite
import scala.meta.internal.io.PathIO
import scala.scalajs.js
import scala.scalajs.js.annotation.JSGlobal
import scala.scalajs.js.annotation.JSImport
import scala.scalajs.js.typedarray.ArrayBuffer
import scala.scalajs.js.typedarray.TypedArrayBuffer
import scala.scalajs.js.typedarray.Uint8Array

class PakoSuite extends FunSuite {
  test("deflate") {
    val path = PathIO.workingDirectory
      .resolve("target")
      .resolve("metabrowse")
      .resolve("index.workspace.gz")
    val in = path.readAllBytes
    val input = new ArrayBuffer(in.length)
    val bbuf = TypedArrayBuffer.wrap(input)
    bbuf.put(in)
    val output = Pako.inflate(input)
    val out = Array.ofDim[Byte](output.byteLength)
    TypedArrayBuffer.wrap(output).get(out)
    val workspace = Workspace.parseFrom(out)
    val obtained =
      workspace.toProtoString.linesIterator.toList.sorted.mkString("\n").trim
    val expected =
      """
        |filenames: "paiges/core/src/main/scala/org/typelevel/paiges/Chunk.scala"
        |filenames: "paiges/core/src/main/scala/org/typelevel/paiges/Doc.scala"
        |filenames: "paiges/core/src/main/scala/org/typelevel/paiges/Document.scala"
        |filenames: "paiges/core/src/main/scala/org/typelevel/paiges/package.scala"
        |filenames: "paiges/core/src/test/scala/org/typelevel/paiges/DocumentTests.scala"
        |filenames: "paiges/core/src/test/scala/org/typelevel/paiges/Generators.scala"
        |filenames: "paiges/core/src/test/scala/org/typelevel/paiges/JsonTest.scala"
        |filenames: "paiges/core/src/test/scala/org/typelevel/paiges/PaigesTest.scala"
      """.stripMargin.trim
    assert(obtained == expected)
  }
} 
Example 59
Source File: MetabrowseServerSuite.scala    From metabrowse   with Apache License 2.0 5 votes vote down vote up
package metabrowse.tests

import java.io.File
import java.util.concurrent.TimeUnit
import metabrowse.server.MetabrowseServer
import metabrowse.server.Sourcepath
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.selenium.Chrome
import scala.meta.interactive.InteractiveSemanticdb

class MetabrowseServerSuite
    extends FunSuite
    with Chrome
    with BeforeAndAfterAll {
  val server = new MetabrowseServer()
  val sourcepath = Sourcepath("org.typelevel:paiges-core_2.12:0.2.1")
  override def beforeAll(): Unit = {
    server.start(sourcepath)
  }
  override def afterAll(): Unit = {
    server.stop()
    quit()
  }

  val host = "http://localhost:4000/#"
  val DocScala = s"$host/org/typelevel/paiges/Doc.scala#L209C14-L209C19"
  val ChunkScala = s"$host/org/typelevel/paiges/Chunk.scala#L5C24"

  // See: https://github.com/SeleniumHQ/selenium/blob/master/rb/lib/selenium/webdriver/common/keys.rb
  val F12 = "\ue03C"
  val Command = "\ue03D"
  def sleep(seconds: Int): Unit = {
    Thread.sleep(TimeUnit.SECONDS.toMillis(seconds))
  }

  // NOTE(olafur): This is a first selenium test I ever write so it's quite hacky.
  // This test likely fails in a CI environment (needs chrome installed) and also fails
  // on non-macOS since it relies on the Mac-specific "Cmd" keyboard modifier.
  test("goto definition") {
    go to DocScala
    assert(pageTitle == "Metabrowse")
    className("mtk1")
    sleep(10)
    pressKeys(Command + F12)
    sleep(5)
    assert(currentUrl == ChunkScala)
    goBack()
    sleep(5)
    assert(currentUrl == DocScala)
  }

  test("urlForSymbol") {
    val g = InteractiveSemanticdb.newCompiler(
      sourcepath.classpath.mkString(File.pathSeparator),
      Nil
    )
    val some =
      g.rootMirror.staticClass("scala.Some").info.member(g.TermName("isEmpty"))
    val obtained = server.urlForSymbol(g)(some).get
    g.askShutdown()
    assert(obtained == "#/scala/Option.scala#L333C6")
  }
} 
Example 60
Source File: BaseMetabrowseCliSuite.scala    From metabrowse   with Apache License 2.0 5 votes vote down vote up
package metabrowse.tests

import caseapp.RemainingArgs
import java.nio.file.Files
import metabrowse.cli.MetabrowseCli
import metabrowse.cli.MetabrowseOptions
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import scala.meta.io.AbsolutePath
import scala.meta.testkit.DiffAssertions
import metabrowse.{schema => d}
import metabrowse.MetabrowseEnrichments._
import GeneratedSiteEnrichments._

abstract class BaseMetabrowseCliSuite
    extends FunSuite
    with BeforeAndAfterAll
    with DiffAssertions {
  var out: AbsolutePath = _
  def options = MetabrowseOptions(
    out.toString(),
    cleanTargetFirst = true,
    nonInteractive = true
  )
  def files: Seq[String] =
    BuildInfo.exampleClassDirectory.map(_.getAbsolutePath).toSeq

  def runCli(): Unit = MetabrowseCli.run(options, RemainingArgs(files, Nil))

  override def beforeAll(): Unit = {
    out = AbsolutePath(Files.createTempDirectory("metabrowse"))
    out.toFile.deleteOnExit()
    runCli()
  }

  def checkSymbolIndex(id: String, expected: String) = {
    test(id) {
      val indexes = d.SymbolIndexes.parseFromCompressedPath(
        out.resolve("symbol").resolve(id.symbolIndexPath)
      )
      val index = indexes.indexes.find(_.symbol == id).get
      // Sort ranges to ensure we assert against deterministic input.
      val indexNormalized = index.copy(
        references = index.references
          .mapValues { ranges =>
            ranges.copy(ranges = ranges.ranges.sortBy(_.startLine))
          }
          .iterator
          .toMap
      )
      val obtained = indexNormalized.toProtoString
      assertNoDiffOrPrintExpected(obtained, expected)
    }
  }

} 
Example 61
Source File: BaseTransactionSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.transactions

import java.io.File

import com.typesafe.config.{Config, ConfigFactory}
import com.wavesplatform.it._
import monix.eval.Coeval
import org.scalatest.{BeforeAndAfterAll, FunSuite, Suite}

import scala.jdk.CollectionConverters._
import scala.concurrent.ExecutionContext

trait BaseTransactionSuiteLike extends WaitForHeight2 with IntegrationSuiteWithThreeAddresses with BeforeAndAfterAll with NodesFromDocker {
  this: Suite =>

  protected implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

  protected def nodeConfigs: Seq[Config] =
    NodeConfigs.newBuilder
      .overrideBase(_.quorum(0))
      .withDefault(1)
      .withSpecial(_.nonMiner)
      .buildNonConflicting()

  override def miner: Node = nodes.head

  // protected because https://github.com/sbt/zinc/issues/292
  protected val theNodes: Coeval[Seq[Node]] = Coeval.evalOnce {
    Option(System.getProperty("waves.it.config.file")) match {
      case None => dockerNodes()
      case Some(filePath) =>
        val defaultConfig = ConfigFactory.load()
        ConfigFactory
          .parseFile(new File(filePath))
          .getConfigList("nodes")
          .asScala
          .toSeq
          .map(cfg => new ExternalNode(cfg.withFallback(defaultConfig).resolve()))
    }
  }

  override protected def nodes: Seq[Node] = theNodes()

  protected override def beforeAll(): Unit = {
    theNodes.run
    super.beforeAll()
  }
}

abstract class BaseTransactionSuite extends FunSuite with BaseTransactionSuiteLike 
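The Coeval.evalOnce value above is what lets the suite decide lazily between docker-managed and external nodes: the block runs at most once and its result is memoized, which is why beforeAll() forces it with theNodes.run before any test reads nodes. A minimal sketch of that pattern, with illustrative names that are not from the project:

import monix.eval.Coeval

object EvalOnceSketch extends App {
  val nodes: Coeval[Seq[String]] = Coeval.evalOnce {
    println("resolving nodes (runs only once)")
    Seq("node-0", "node-1")
  }

  nodes.run        // first evaluation: the message above is printed
  println(nodes()) // memoized value is returned, nothing is re-evaluated
}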
Example 62
Source File: FairPoSTestSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync

import com.typesafe.config.{Config, ConfigFactory}
import org.scalatest.{CancelAfterFailure, FunSuite}
import com.wavesplatform.it.api.SyncHttpApi._
import com.wavesplatform.it.transactions.NodesFromDocker
import scala.concurrent.duration._

class FairPoSTestSuite extends FunSuite with CancelAfterFailure with NodesFromDocker {
  import FairPoSTestSuite._

  override protected def nodeConfigs: Seq[Config] = Configs

  test("blockchain grows with FairPoS activated") {
    nodes.waitForSameBlockHeadersAt(height = 10, conditionAwaitTime = 11.minutes)

    val txId = nodes.head.transfer(nodes.head.address, nodes.last.address, transferAmount, minFee).id
    nodes.last.waitForTransaction(txId)

    val heightAfterTransfer = nodes.head.height

    nodes.waitForSameBlockHeadersAt(heightAfterTransfer + 10, conditionAwaitTime = 11.minutes)
  }
}

object FairPoSTestSuite {
  import com.wavesplatform.it.NodeConfigs._
  private val microblockActivationHeight = 0
  private val fairPoSActivationHeight    = 10
  private val vrfActivationHeight        = 14

  private val config =
    ConfigFactory.parseString(s"""
    |waves {
    |   blockchain.custom {
    |      functionality {
    |        pre-activated-features {1 = $microblockActivationHeight, 8 = $fairPoSActivationHeight, 17 = $vrfActivationHeight}
    |        generation-balance-depth-from-50-to-1000-after-height = 1000
    |      }
    |   }
    |   miner.quorum = 1
    |}""".stripMargin)

  val Configs: Seq[Config] = Default.map(config.withFallback(_)).take(3)
} 
Example 63
Source File: DebugPortfoliosSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync.debug

import com.typesafe.config.Config
import com.wavesplatform.it.{Node, NodeConfigs}
import com.wavesplatform.it.api.SyncHttpApi._
import com.wavesplatform.it.transactions.NodesFromDocker
import com.wavesplatform.it.util._
import com.wavesplatform.it.sync._
import org.scalatest.FunSuite

class DebugPortfoliosSuite extends FunSuite with NodesFromDocker {
  override protected def nodeConfigs: Seq[Config] =
    NodeConfigs.newBuilder
      .overrideBase(_.quorum(0))
      .withDefault(entitiesNumber = 1)
      .buildNonConflicting()

  private def sender: Node = nodes.head

  private val firstAddress  = sender.createAddress()
  private val secondAddress = sender.createAddress()

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    sender.transfer(sender.address, firstAddress, 20.waves, minFee, waitForTx = true)
    sender.transfer(sender.address, secondAddress, 20.waves, minFee, waitForTx = true)
  }

  test("getting a balance considering pessimistic transactions from UTX pool - changed after UTX") {
    val portfolioBefore = sender.debugPortfoliosFor(firstAddress, considerUnspent = true)
    val utxSizeBefore   = sender.utxSize

    sender.transfer(firstAddress, secondAddress, 5.waves, 5.waves)
    sender.transfer(secondAddress, firstAddress, 7.waves, 5.waves)

    sender.waitForUtxIncreased(utxSizeBefore)

    val portfolioAfter = sender.debugPortfoliosFor(firstAddress, considerUnspent = true)

    val expectedBalance = portfolioBefore.balance - 10.waves // withdraw + fee
    assert(portfolioAfter.balance == expectedBalance)

  }

  test("getting a balance without pessimistic transactions from UTX pool - not changed after UTX") {
    nodes.waitForHeightArise()

    val portfolioBefore = sender.debugPortfoliosFor(firstAddress, considerUnspent = false)
    val utxSizeBefore   = sender.utxSize

    sender.transfer(firstAddress, secondAddress, 5.waves, fee = 5.waves)
    sender.waitForUtxIncreased(utxSizeBefore)

    val portfolioAfter = sender.debugPortfoliosFor(firstAddress, considerUnspent = false)
    assert(portfolioAfter.balance == portfolioBefore.balance)
  }
} 
Example 64
Source File: ForgeReturnedToUTXSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync

import com.typesafe.config.{Config, ConfigFactory}
import com.wavesplatform.it.api.SyncHttpApi._
import com.wavesplatform.it.api.TransactionInfo
import com.wavesplatform.it.transactions.NodesFromDocker
import org.scalatest.{CancelAfterFailure, FunSuite, Matchers}

class ForgeReturnedToUTXSuite extends FunSuite with CancelAfterFailure with NodesFromDocker with Matchers {

  import ForgeReturnedToUTXSuite._
  override protected def nodeConfigs: Seq[Config] = Configs

  private def miner = nodes.head
  private def last  = nodes.last

  test("dependent trasactions can be added to UTX if first mined and returned to UTX") {

    // the asset tx should be mined in the first microblock as soon as a new keyblock is mined; other microblocks should not be applied due to the big microblockInterval
    val assetId                      = last.issue(last.address, "asset", "descr", issueAmount, 0, reissuable = false, issueFee, waitForTx = true).id
    val issueAssetInitialHeight: Int = last.transactionInfo[TransactionInfo](assetId).height

    // all microblocks should be returned to UTX, so assetId goes back to UTX and no further microblocks will be mined at this height;
    // the transfer tx will therefore stay in UTX until a new keyblock is mined
    val transferTx = last.transfer(last.address, miner.address, 1L, minFee, Some(assetId), None, waitForTx = true).id

    val issueAssetHeight = last.transactionInfo[TransactionInfo](assetId).height
    val transferTxHeight = last.transactionInfo[TransactionInfo](transferTx).height

    // the transfer tx and the issue asset tx should be placed in the same microblock
    transferTxHeight shouldBe issueAssetHeight
    transferTxHeight shouldNot be(issueAssetInitialHeight)

  }

}

object ForgeReturnedToUTXSuite {
  import com.wavesplatform.it.NodeConfigs._

  // the microblock interval should be greater than the average block interval
  val microblockInterval  = 60
  private val minerConfig = ConfigFactory.parseString(s"""
                                                         |waves {
                                                         |  miner {
                                                         |    micro-block-interval = ${microblockInterval}s
                                                         |    min-micro-block-age = 60s
                                                         |  }
                                                         |  blockchain.custom.genesis {
                                                         |     average-block-delay = 20s
                                                         |  }
                                                         |  miner.quorum = 1
                                                         |}""".stripMargin)

  val Configs: Seq[Config] = Seq(
    minerConfig.withFallback(Default.head),
    minerConfig.withFallback(Default(1))
  )

} 
Example 65
Source File: GrpcBaseTransactionSuiteLike.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync.grpc

import java.io.File

import com.typesafe.config.{Config, ConfigFactory}
import com.wavesplatform.it.transactions.NodesFromDocker
import com.wavesplatform.it.{ExternalNode, GrpcIntegrationSuiteWithThreeAddress, GrpcWaitForHeight, Node, NodeConfigs}
import monix.eval.Coeval
import org.scalatest.{BeforeAndAfterAll, FunSuite, Suite}

import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters._

trait GrpcBaseTransactionSuiteLike
  extends GrpcWaitForHeight
  with GrpcIntegrationSuiteWithThreeAddress
  with BeforeAndAfterAll
  with NodesFromDocker { this: Suite =>

  protected implicit val ec: ExecutionContext = ExecutionContext.Implicits.global

  protected def nodeConfigs: Seq[Config] =
    NodeConfigs.newBuilder
      .overrideBase(_.quorum(0))
      .withDefault(1)
      .withSpecial(_.nonMiner)
      .buildNonConflicting()

  // protected because https://github.com/sbt/zinc/issues/292
  protected val theNodes: Coeval[Seq[Node]] = Coeval.evalOnce {
    Option(System.getProperty("waves.it.config.file")) match {
      case None => dockerNodes()
      case Some(filePath) =>
        val defaultConfig = ConfigFactory.load()
        ConfigFactory
          .parseFile(new File(filePath))
          .getConfigList("nodes")
          .asScala
          .toSeq
          .map(cfg => new ExternalNode(cfg.withFallback(defaultConfig).resolve()))
    }
  }

  protected override def beforeAll(): Unit = {
    theNodes.run
    super.beforeAll()
  }
}

abstract class GrpcBaseTransactionSuite extends FunSuite with GrpcBaseTransactionSuiteLike 
Example 66
Source File: RideCreateMerkleRootTestSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync

import com.typesafe.config.Config
import com.wavesplatform.account._
import com.wavesplatform.common.merkle.Merkle
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils.{Base58, EitherExt2}
import com.wavesplatform.features.BlockchainFeatures
import com.wavesplatform.it.api.SyncHttpApi._
import com.wavesplatform.it.api.Transaction
import com.wavesplatform.it.transactions.NodesFromDocker
import com.wavesplatform.it.{Node, NodeConfigs, ReportingTestName, TransferSending}
import com.wavesplatform.lang.v1.compiler.Terms._
import com.wavesplatform.lang.v1.estimator.v3.ScriptEstimatorV3
import com.wavesplatform.state._
import com.wavesplatform.transaction.Asset._
import com.wavesplatform.transaction.{Proofs, TxVersion}
import com.wavesplatform.transaction.smart.script.ScriptCompiler
import com.wavesplatform.transaction.transfer.TransferTransaction
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.{CancelAfterFailure, FunSuite, Matchers}


class RideCreateMerkleRootTestSuite
    extends FunSuite
    with CancelAfterFailure
    with TransferSending
    with NodesFromDocker
    with ReportingTestName
    with Matchers
    with TableDrivenPropertyChecks {
  override def nodeConfigs: Seq[Config] =
    NodeConfigs.newBuilder
      .overrideBase(_.quorum(0))
      .overrideBase(
        _.preactivatedFeatures(
          (14, 1000000),
          BlockchainFeatures.NG.id.toInt         -> 0,
          BlockchainFeatures.FairPoS.id.toInt    -> 0,
          BlockchainFeatures.Ride4DApps.id.toInt -> 0,
          BlockchainFeatures.BlockV5.id.toInt    -> 0
        )
      )
      .withDefault(1)
      .buildNonConflicting()

  private def sender: Node         = nodes.last

  test("Ride createMerkleRoot") {
    val script = """
        |{-# STDLIB_VERSION 4 #-}
        |{-# CONTENT_TYPE DAPP #-}
        |
        | @Callable(inv)
        |func foo(proof: List[ByteVector], id: ByteVector, index: Int) = [
        | BinaryEntry("root", createMerkleRoot(proof, id, index))
        |]
        """.stripMargin
    val cscript = ScriptCompiler.compile(script, ScriptEstimatorV3).explicitGet()._1.bytes().base64
    val node = nodes.head
    nodes.waitForHeightArise()
    val tx1 = node.broadcastTransfer(node.keyPair, sender.address, setScriptFee, minFee, None, None, version = TxVersion.V3, waitForTx = false)
    val txId1 = tx1.id
    val tx2 = node.broadcastTransfer(node.keyPair, node.address, 1, minFee, None, None, version = TxVersion.V3, waitForTx = false)
    val txId2 = tx2.id
    val tx3 = node.broadcastTransfer(node.keyPair, node.address, 1, minFee, None, None, version = TxVersion.V3, waitForTx = false)
    val txId3 = tx3.id
    val tx4 = node.broadcastTransfer(node.keyPair, node.address, 1, minFee, None, None, version = TxVersion.V3, waitForTx = false)
    val txId4 = tx4.id
    val tx5 = node.broadcastTransfer(node.keyPair, node.address, 1, minFee, None, None, version = TxVersion.V3, waitForTx = false)
    val txId5 = tx5.id

    val height = node.height

    nodes.waitForHeightArise()

    def tt(tx: Transaction) = TransferTransaction
      .create(
        tx.version.get,
        PublicKey(Base58.decode(tx.senderPublicKey.get)),
        Address.fromString(tx.recipient.get).explicitGet(),
        Waves,
        tx.fee,
        ByteStr.empty, // attachment
        tx.timestamp,
        Proofs(tx.proofs.get.map(v => ByteStr(Base58.decode(v))))
      )
      .explicitGet()
    val natives = Seq(tx1, tx2, tx3, tx4, tx5).map(tt).map(t => Base58.encode(t.id().arr) -> t).toMap

    val root = Base58.decode(node.blockAt(height).transactionsRoot.get)

    val proofs = nodes.head.getMerkleProof(txId1, txId2, txId3, txId4, txId5)

    sender.setScript(sender.address, Some(cscript), setScriptFee, waitForTx = true).id

    for(p <- proofs) {
      node.invokeScript(
        node.address,
        sender.address,
        func = Some("foo"),
        args = List(ARR(p.merkleProof.map(v => CONST_BYTESTR(ByteStr(Base58.decode(v))).explicitGet()).toIndexedSeq, false).explicitGet(),
                    CONST_BYTESTR(ByteStr(Merkle.hash(natives(p.id).bytes()))).explicitGet(),
                    CONST_LONG(p.transactionIndex.toLong)),
        payment = Seq(),
        fee = 2 * smartFee + minFee,
        waitForTx = true
      )
      node.getDataByKey(sender.address, "root") shouldBe BinaryDataEntry("root", ByteStr(root))
    }
  }
} 
Example 67
Source File: MinerStateTestSuite.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.it.sync

import com.typesafe.config.{Config, ConfigFactory}
import com.wavesplatform.it.api.State
import com.wavesplatform.it.api.SyncHttpApi._
import com.wavesplatform.it.transactions.NodesFromDocker
import com.wavesplatform.it.util._
import org.scalatest.{CancelAfterFailure, FunSuite, Matchers}
import scala.concurrent.duration._

class MinerStateTestSuite extends FunSuite with CancelAfterFailure with NodesFromDocker with Matchers {
  import MinerStateTestSuite._

  override protected def nodeConfigs: Seq[Config] = Configs

  private val transferAmount = 1000.waves

  private def miner = nodes.head
  private def last  = nodes.last

  test("node w/o balance can forge blocks after effective balance increase") {
    val newAddress = last.createAddress()

    val (balance1, eff1)        = miner.accountBalances(miner.address)
    val minerFullBalanceDetails = miner.balanceDetails(miner.address)
    assert(balance1 == minerFullBalanceDetails.available)
    assert(eff1 == minerFullBalanceDetails.effective)

    val (balance2, eff2)     = last.accountBalances(newAddress)
    val newAccBalanceDetails = last.balanceDetails(newAddress)
    assert(balance2 == newAccBalanceDetails.available)
    assert(eff2 == newAccBalanceDetails.effective)

    val minerInfoBefore = last.debugMinerInfo()
    all(minerInfoBefore) shouldNot matchPattern { case State(`newAddress`, _, ts) if ts > 0 => }

    miner.waitForPeers(1)
    val txId = miner.transfer(miner.address, newAddress, transferAmount, minFee).id
    nodes.waitForHeightAriseAndTxPresent(txId)

    val heightAfterTransfer = miner.height

    last.assertBalances(newAddress, balance2 + transferAmount, eff2 + transferAmount)

    last.waitForHeight(heightAfterTransfer + 51, 6.minutes) // if you know how to reduce waiting time, please ping @monroid

    assert(last.balanceDetails(newAddress).generating == balance2 + transferAmount)

    val minerInfoAfter = last.debugMinerInfo()
    atMost(1, minerInfoAfter) should matchPattern { case State(`newAddress`, _, ts) if ts > 0 => }

    last.waitForPeers(1)
    val leaseBack = last.lease(newAddress, miner.address, (transferAmount - minFee), minFee).id
    nodes.waitForHeightAriseAndTxPresent(leaseBack)

    assert(last.balanceDetails(newAddress).generating == balance2)

    all(miner.debugMinerInfo()) shouldNot matchPattern { case State(`newAddress`, _, ts) if ts > 0 => }

    all(last.debugMinerInfo()) shouldNot matchPattern { case State(`newAddress`, _, ts) if ts > 0 => }

  }
}

object MinerStateTestSuite {
  import com.wavesplatform.it.NodeConfigs._
  private val minerConfig = ConfigFactory.parseString(s"""
    |waves {
    |  synchronization.synchronization-timeout = 10s
    |  blockchain.custom.functionality {
    |    pre-activated-features.1 = 0
    |    generation-balance-depth-from-50-to-1000-after-height = 100
    |  }
    |  blockchain.custom.genesis {
    |     average-block-delay = 5s
    |  }
    |  miner.quorum = 1
    |}""".stripMargin)

  val Configs: Seq[Config] = Seq(
    minerConfig.withFallback(Default.head),
    minerConfig.withFallback(Default(1))
  )

} 
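The companion object above builds each node's configuration by layering a HOCON override on top of the default node configs with withFallback. As a rough, self-contained illustration of that layering only (the keys below are arbitrary and not the real node settings), consider:

import com.typesafe.config.ConfigFactory

object ConfigLayeringSketch extends App {
  // withFallback: keys present in the override win, missing keys fall back to the default.
  val overrideCfg = ConfigFactory.parseString("waves.miner.quorum = 1")
  val defaultCfg  = ConfigFactory.parseString("waves.miner { quorum = 0, enable = yes }")
  val merged      = overrideCfg.withFallback(defaultCfg)

  println(merged.getInt("waves.miner.quorum"))     // 1
  println(merged.getBoolean("waves.miner.enable")) // true
}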
Example 68
Source File: WalletSpecification.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.lagonaki.unit

import java.io.File
import java.nio.file.Files

import cats.syntax.option._
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.settings.WalletSettings
import com.wavesplatform.wallet.Wallet
import org.scalatest.{FunSuite, Matchers}

class WalletSpecification extends FunSuite with Matchers {

  private val walletSize = 10
  val w                  = Wallet(WalletSettings(None, "cookies".some, ByteStr.decodeBase58("FQgbSAm6swGbtqA3NE8PttijPhT4N3Ufh4bHFAkyVnQz").toOption))

  test("wallet - acc creation") {
    w.generateNewAccounts(walletSize)

    w.privateKeyAccounts.size shouldBe walletSize
    w.privateKeyAccounts.map(_.toAddress.toString) shouldBe Seq(
      "3MqMwwHW4v2nSEDHVWoh8RCQL8QrsWLkkeB",
      "3MuwVgJA8EXHukxo6rcakT5tD6FpvACtitG",
      "3MuAvUG4EAsG9RP9jaWjewCVmggaQD2t39B",
      "3MqoX4A3UGBYU7cX2JPs6BCzntNC8K8FBR4",
      "3N1Q9VVVQtY3GqhwHtJDEyHb3oWBcerZL8X",
      "3NARifVFHthMDnCwBacXijPB2szAgNTeBCz",
      "3N6dsnfD88j5yKgpnEavaaJDzAVSRBRVbMY",
      "3MufvXKZxLuNn5SHcEgGc2Vo7nLWnKVskfJ",
      "3Myt4tocZmj7o3d1gnuWRrnQWcoxvx5G7Ac",
      "3N3keodUiS8WLEw9W4BKDNxgNdUpwSnpb3K"
    )
  }

  test("wallet - acc deletion") {

    val head = w.privateKeyAccounts.head
    w.deleteAccount(head)
    assert(w.privateKeyAccounts.lengthCompare(walletSize - 1) == 0)

    w.deleteAccount(w.privateKeyAccounts.head)
    assert(w.privateKeyAccounts.lengthCompare(walletSize - 2) == 0)

    w.privateKeyAccounts.foreach(w.deleteAccount)

    assert(w.privateKeyAccounts.isEmpty)
  }

  test("reopening") {
    val walletFile = Some(createTestTemporaryFile("wallet", ".dat"))

    val w1 = Wallet(WalletSettings(walletFile, "cookies".some, ByteStr.decodeBase58("FQgbSAm6swGbtqA3NE8PttijPhT4N3Ufh4bHFAkyVnQz").toOption))
    w1.generateNewAccounts(10)
    val w1PrivateKeys = w1.privateKeyAccounts
    val w1nonce       = w1.nonce

    val w2 = Wallet(WalletSettings(walletFile, "cookies".some, None))
    w2.privateKeyAccounts.nonEmpty shouldBe true
    w2.privateKeyAccounts shouldEqual w1PrivateKeys
    w2.nonce shouldBe w1nonce
  }

  test("reopen with incorrect password") {
    val file = Some(createTestTemporaryFile("wallet", ".dat"))
    val w1   = Wallet(WalletSettings(file, "password".some, ByteStr.decodeBase58("FQgbSAm6swGbtqA3NE8PttijPhT4N3Ufh4bHFAkyVnQz").toOption))
    w1.generateNewAccounts(3)

    assertThrows[IllegalArgumentException] {
      Wallet(WalletSettings(file, "incorrect password".some, None))
    }
  }

  def createTestTemporaryFile(name: String, ext: String): File = {
    val file = Files.createTempFile(name, ext).toFile
    file.deleteOnExit()

    file
  }
} 
Example 69
Source File: MicroBlockSpecification.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.lagonaki.unit

import com.wavesplatform.account.KeyPair
import com.wavesplatform.block.{Block, MicroBlock}
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.common.utils.EitherExt2
import com.wavesplatform.mining.Miner
import com.wavesplatform.state.diffs.produce
import com.wavesplatform.transaction.Asset.{IssuedAsset, Waves}
import com.wavesplatform.transaction._
import com.wavesplatform.transaction.transfer._
import org.scalamock.scalatest.MockFactory
import org.scalatest.words.ShouldVerb
import org.scalatest.{FunSuite, Matchers}

import scala.util.Random

class MicroBlockSpecification extends FunSuite with Matchers with MockFactory with ShouldVerb {

  val prevResBlockSig  = ByteStr(Array.fill(Block.BlockIdLength)(Random.nextInt(100).toByte))
  val totalResBlockSig = ByteStr(Array.fill(Block.BlockIdLength)(Random.nextInt(100).toByte))
  val reference        = Array.fill(Block.BlockIdLength)(Random.nextInt(100).toByte)
  val sender           = KeyPair(reference.dropRight(2))
  val gen              = KeyPair(reference)

  test("MicroBlock with txs bytes/parse roundtrip") {

    val ts                       = System.currentTimeMillis() - 5000
    val tr: TransferTransaction  = TransferTransaction.selfSigned(1.toByte, sender, gen.toAddress, Waves, 5, Waves, 2, ByteStr.empty,  ts + 1).explicitGet()
    val assetId                  = IssuedAsset(ByteStr(Array.fill(AssetIdLength)(Random.nextInt(100).toByte)))
    val tr2: TransferTransaction = TransferTransaction.selfSigned(1.toByte, sender, gen.toAddress, assetId, 5, Waves, 2, ByteStr.empty,  ts + 2).explicitGet()

    val transactions = Seq(tr, tr2)

    val microBlock  = MicroBlock.buildAndSign(3.toByte, sender, transactions, prevResBlockSig, totalResBlockSig).explicitGet()
    val parsedBlock = MicroBlock.parseBytes(microBlock.bytes()).get

    assert(microBlock.signaturesValid().isRight)
    assert(parsedBlock.signaturesValid().isRight)

    assert(microBlock.signature == parsedBlock.signature)
    assert(microBlock.sender == parsedBlock.sender)
    assert(microBlock.totalResBlockSig == parsedBlock.totalResBlockSig)
    assert(microBlock.reference == parsedBlock.reference)
    assert(microBlock.transactionData == parsedBlock.transactionData)
    assert(microBlock == parsedBlock)
  }

  test("MicroBlock cannot be created with zero transactions") {

    val transactions       = Seq.empty[TransferTransaction]
    val eitherBlockOrError = MicroBlock.buildAndSign(3.toByte, sender, transactions, prevResBlockSig, totalResBlockSig)

    eitherBlockOrError should produce("cannot create empty MicroBlock")
  }

  test("MicroBlock cannot contain more than Miner.MaxTransactionsPerMicroblock") {

    val transaction =
      TransferTransaction.selfSigned(1.toByte, sender, gen.toAddress, Waves, 5, Waves, 1000, ByteStr.empty,  System.currentTimeMillis()).explicitGet()
    val transactions = Seq.fill(Miner.MaxTransactionsPerMicroblock + 1)(transaction)

    val eitherBlockOrError = MicroBlock.buildAndSign(3.toByte, sender, transactions, prevResBlockSig, totalResBlockSig)
    eitherBlockOrError should produce("too many txs in MicroBlock")
  }
} 
Example 70
Source File: AliasRequestTests.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.transaction.api.http.alias

import com.wavesplatform.api.http.requests.{CreateAliasV1Request, SignedCreateAliasV1Request}
import org.scalatest.{FunSuite, Matchers}
import play.api.libs.json.Json

class AliasRequestTests extends FunSuite with Matchers {
  test("CreateAliasRequest") {
    val json =
      """
        {
          "sender": "3Myss6gmMckKYtka3cKCM563TBJofnxvfD7",
           "fee": 10000000,
           "alias": "ALIAS"
        }
      """

    val req = Json.parse(json).validate[CreateAliasV1Request].get

    req shouldBe CreateAliasV1Request("3Myss6gmMckKYtka3cKCM563TBJofnxvfD7", "ALIAS", 10000000)
  }

  test("SignedCreateAliasRequest") {
    val json =
      """
         {
           "senderPublicKey": "CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
           "fee": 100000,
           "alias": "ALIAS",
           "timestamp": 1488807184731,
           "signature": "3aB6cL1osRNopWyqBYpJQCVCXNLibkwM58dvK85PaTK5sLV4voMhe5E8zEARM6YDHnQP5YE3WX8mxdFp3ciGwVfy"
          }
       """

    val req = Json.parse(json).validate[SignedCreateAliasV1Request].get

    req shouldBe SignedCreateAliasV1Request(
      "CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
      100000,
      "ALIAS",
      1488807184731L,
      "3aB6cL1osRNopWyqBYpJQCVCXNLibkwM58dvK85PaTK5sLV4voMhe5E8zEARM6YDHnQP5YE3WX8mxdFp3ciGwVfy"
    )
  }
} 
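The request-parsing tests above (and the leasing ones that follow) rely on play-json Reads instances defined for each request case class in the Waves codebase. A minimal, self-contained sketch of the same validate pattern, using a hypothetical case class rather than the real request types, looks like this:

// Sketch of the play-json pattern the request tests exercise.
// AliasRequestSketch and its fields are illustrative, not the Waves definitions.
import play.api.libs.json.{Json, Reads}

final case class AliasRequestSketch(sender: String, alias: String, fee: Long)

object AliasRequestSketch {
  // Derive a Reads instance from the case class fields.
  implicit val reads: Reads[AliasRequestSketch] = Json.reads[AliasRequestSketch]
}

// Usage, mirroring the assertions above:
// Json.parse("""{"sender":"abc","alias":"ALIAS","fee":10000000}""")
//   .validate[AliasRequestSketch]
//   .get == AliasRequestSketch("abc", "ALIAS", 10000000L)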
Example 71
Source File: LeaseV1RequestsTests.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.transaction.api.http.leasing

import com.wavesplatform.api.http.requests.{LeaseCancelV1Request, LeaseV1Request, SignedLeaseCancelV1Request, SignedLeaseV1Request}
import org.scalatest.{FunSuite, Matchers}
import play.api.libs.json.Json

class LeaseV1RequestsTests extends FunSuite with Matchers {

  test("LeaseRequest") {
    val json =
      """
        {
          "amount": 100000,
          "recipient": "3Myss6gmMckKYtka3cKCM563TBJofnxvfD7",
          "sender": "3MwKzMxUKaDaS4CXM8KNowCJJUnTSHDFGMb",
          "fee": 1000
        }
      """

    val req = Json.parse(json).validate[LeaseV1Request].get

    req shouldBe LeaseV1Request("3MwKzMxUKaDaS4CXM8KNowCJJUnTSHDFGMb", 100000, 1000, "3Myss6gmMckKYtka3cKCM563TBJofnxvfD7")
  }

  test("LeaseCancelRequest") {
    val json =
      """
        {
          "sender": "3Myss6gmMckKYtka3cKCM563TBJofnxvfD7",
          "txId": "ABMZDPY4MyQz7kKNAevw5P9eNmRErMutJoV9UNeCtqRV",
          "fee": 10000000
        }
      """

    val req = Json.parse(json).validate[LeaseCancelV1Request].get

    req shouldBe LeaseCancelV1Request("3Myss6gmMckKYtka3cKCM563TBJofnxvfD7", "ABMZDPY4MyQz7kKNAevw5P9eNmRErMutJoV9UNeCtqRV", 10000000)
  }

  test("SignedLeaseRequest") {
    val json =
      """
        {
         "senderPublicKey":"CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
         "recipient":"3MwKzMxUKaDaS4CXM8KNowCJJUnTSHDFGMb",
         "fee":1000000,
         "timestamp":0,
         "amount":100000,
         "signature":"4VPg4piLZGQz3vBqCPbjTfAR4cDErMi57rDvyith5XrQJDLryU2w2JsL3p4ejEqTPpctZ5YekpQwZPTtYiGo5yPC"
         }
      """

    val req = Json.parse(json).validate[SignedLeaseV1Request].get

    req shouldBe SignedLeaseV1Request(
      "CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
      100000L,
      1000000L,
      "3MwKzMxUKaDaS4CXM8KNowCJJUnTSHDFGMb",
      0L,
      "4VPg4piLZGQz3vBqCPbjTfAR4cDErMi57rDvyith5XrQJDLryU2w2JsL3p4ejEqTPpctZ5YekpQwZPTtYiGo5yPC"
    )
  }

  test("SignedLeaseCancelRequest") {
    val json =
      """
        {
         "senderPublicKey":"CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
         "txId":"D6HmGZqpXCyAqpz8mCAfWijYDWsPKncKe5v3jq1nTpf5",
         "timestamp":0,
         "fee": 1000000,
         "signature":"4VPg4piLZGQz3vBqCPbjTfAR4cDErMi57rDvyith5XrQJDLryU2w2JsL3p4ejEqTPpctZ5YekpQwZPTtYiGo5yPC"
         }
      """

    val req = Json.parse(json).validate[SignedLeaseCancelV1Request].get

    req shouldBe SignedLeaseCancelV1Request(
      "CRxqEuxhdZBEHX42MU4FfyJxuHmbDBTaHMhM3Uki7pLw",
      "D6HmGZqpXCyAqpz8mCAfWijYDWsPKncKe5v3jq1nTpf5",
      0L,
      "4VPg4piLZGQz3vBqCPbjTfAR4cDErMi57rDvyith5XrQJDLryU2w2JsL3p4ejEqTPpctZ5YekpQwZPTtYiGo5yPC",
      1000000L
    )
  }
} 
Example 72
Source File: PortfolioTest.scala    From Waves   with MIT License 5 votes vote down vote up
package com.wavesplatform.state

import java.nio.charset.StandardCharsets

import cats._
import com.wavesplatform.TestValues
import com.wavesplatform.common.state.ByteStr
import com.wavesplatform.transaction.Asset.IssuedAsset
import org.scalatest.{FunSuite, Matchers}

class PortfolioTest extends FunSuite with Matchers {
  test("pessimistic - should return only withdraws") {
    val Seq(fooKey, barKey, bazKey) = Seq("foo", "bar", "baz").map(x => IssuedAsset(ByteStr(x.getBytes(StandardCharsets.UTF_8))))

    val orig = Portfolio(
      balance = -10,
      lease = LeaseBalance(
        in = 11,
        out = 12
      ),
      assets = Map(
        fooKey -> -13,
        barKey -> 14,
        bazKey -> 0
      )
    )

    val p = orig.pessimistic
    p.balance shouldBe orig.balance
    p.lease.in shouldBe 0
    p.lease.out shouldBe orig.lease.out
    p.assets(fooKey) shouldBe orig.assets(fooKey)
    p.assets shouldNot contain(barKey)
    p.assets shouldNot contain(bazKey)
  }

  test("pessimistic - positive balance is turned into zero") {
    val orig = Portfolio(
      balance = 10,
      lease = LeaseBalance(0, 0),
      assets = Map.empty
    )

    val p = orig.pessimistic
    p.balance shouldBe 0
  }

  test("prevents overflow of assets") {
    val assetId = TestValues.asset
    val arg1    = Portfolio(0L, LeaseBalance.empty, Map(assetId -> (Long.MaxValue - 1L)))
    val arg2    = Portfolio(0L, LeaseBalance.empty, Map(assetId -> (Long.MaxValue - 2L)))
    Monoid.combine(arg1, arg2).assets(assetId) shouldBe Long.MinValue
  }
} 
Example 73
Source File: TimeSeriesUtilsSpec.scala    From cuttle   with Apache License 2.0 5 votes vote down vote up
package com.criteo.cuttle.timeseries.intervals

import org.scalatest.FunSuite
import de.sciss.fingertree.FingerTree

import com.criteo.cuttle.timeseries._
import com.criteo.cuttle.timeseries.JobState.{Done, Todo}
import com.criteo.cuttle.timeseries.TimeSeriesUtils.State
import com.criteo.cuttle.{Job, TestScheduling}

class TimeSeriesUtilsSpec extends FunSuite with TestScheduling {
  private val scheduling: TimeSeries = hourly(date"2017-03-25T02:00:00Z")
  private val jobA = Job("job_a", scheduling)(completed)
  private val jobB = Job("job_b", scheduling)(completed)

  test("clean overlapping state intervals") {
    val state: State = Map(
      jobA -> new IntervalMap.Impl(
        FingerTree(
          Interval(date"2017-03-25T01:00:00Z", date"2017-03-25T02:00:00Z") -> Done("v1"),
          Interval(date"2017-03-25T02:00:00Z", date"2017-03-25T04:00:00Z") -> Done("v2"),
          Interval(date"2017-03-25T04:00:00Z", date"2017-03-25T06:00:00Z") -> Todo(None),
          // Interval overlapping with previously defined interval
          Interval(date"2017-03-25T04:00:00Z", date"2017-03-25T05:00:00Z") -> Todo(None)
        )
      ),
      jobB -> new IntervalMap.Impl(
        FingerTree(
          Interval(date"2017-03-25T02:00:00Z", date"2017-03-25T03:00:00Z") -> Todo(None),
          Interval(date"2017-03-25T03:00:00Z", date"2017-03-25T04:00:00Z") -> Done("v2"),
          Interval(date"2017-03-25T04:30:00Z", date"2017-03-25T05:00:00Z") -> Todo(None),
          // Interval contiguous to the previous Done interval for jobB
          Interval(date"2017-03-25T04:00:00Z", date"2017-03-25T04:30:00Z") -> Done("v2")
        )
      )
    )

    TimeSeriesUtils.cleanTimeseriesState(state).foreach {
      case (job, intervalMap) =>
        job.id match {
          case jobA.id =>
            assert(
              intervalMap.toList.toSet.equals(
                Set(
                  (Interval(date"2017-03-25T01:00:00Z", date"2017-03-25T02:00:00Z"), Done("v1")),
                  (Interval(date"2017-03-25T02:00:00Z", date"2017-03-25T04:00:00Z"), Done("v2")),
                  (Interval(date"2017-03-25T04:00:00Z", date"2017-03-25T06:00:00Z"), Todo(None))
                )
              )
            )
          case jobB.id =>
            assert(
              intervalMap.toList.toSet.equals(
                Set(
                  (Interval(date"2017-03-25T02:00:00Z", date"2017-03-25T03:00:00Z"), Todo(None)),
                  (Interval(date"2017-03-25T03:00:00Z", date"2017-03-25T04:30:00Z"), Done("v2")),
                  (Interval(date"2017-03-25T04:30:00Z", date"2017-03-25T05:00:00Z"), Todo(None))
                )
              )
            )
        }
    }
  }
} 
Example 74
Source File: WorkflowSpec.scala    From cuttle   with Apache License 2.0 5 votes vote down vote up
package com.criteo.cuttle.timeseries

import com.criteo.cuttle._
import scala.concurrent.Future

import org.scalatest.FunSuite

class WorkflowSpec extends FunSuite {

  val testScheduling = hourly(start = date"2018-01-01T00:00:00Z")
  val void = (_: Execution[_]) => Future.successful(Completed)

  test("We should build a valid DAG for our workflow") {
    val jobs = Vector.tabulate(4)(i => Job(i.toString, testScheduling)(void))
    val graph = (jobs(1) and jobs(2)) dependsOn jobs(0) dependsOn jobs(3)

    assert(graph.vertices.size == 4)
    assert(graph.edges.size == 3)
  }

  test("Serialize workflow DAG in linear representation should throw an exception when DAG has a cycle") {
    val job1 = Job("job1", testScheduling)(void)
    val job2 = Job("job2", testScheduling)(void)
    val job3 = Job("job3", testScheduling)(void)

    val job = Job("job", testScheduling)(void)

    val workflow = job dependsOn (job1 and job2) dependsOn job3 dependsOn job1

    intercept[IllegalArgumentException] {
      workflow.jobsInOrder
    }
  }

  test("Serialize workflow DAG in linear representation should be ok without cycles") {
    val job1 = Job("job1", testScheduling)(void)
    val job2 = Job("job2", testScheduling)(void)
    val job3 = Job("job3", testScheduling)(void)

    val job = Job("job", testScheduling)(void)

    val workflow = job dependsOn (job1 and job2) dependsOn job3

    assert(workflow.jobsInOrder.size === 4)
  }

  test("Strongly connected component identification") {
    val job1 = Job("job1", testScheduling)(void)
    val job2 = Job("job2", testScheduling)(void)
    val job3 = Job("job3", testScheduling)(void)
    val job4 = Job("job4", testScheduling)(void)
    val job5 = Job("job5", testScheduling)(void)

    val job = Job("job", testScheduling)(void)
    val singletonJob = Job("singleton_job", testScheduling)(void)

    val cycle1 = job dependsOn job2 dependsOn job4 dependsOn job
    val cycle2 = job1 dependsOn job5 dependsOn job1
    var workflow = cycle1 and cycle2 and singletonJob
    workflow = new Workflow {
      val vertices = workflow.vertices
      val edges = workflow.edges ++ Set((job, job1, defaultDependencyDescriptor))
    }
    workflow = new Workflow {
      val vertices = workflow.vertices + job3
      val edges = workflow.edges ++ Set((job2, job3, defaultDependencyDescriptor))
    }
    workflow = new Workflow {
      val vertices = workflow.vertices
      val edges = workflow.edges ++ Set((job1, job3, defaultDependencyDescriptor))
    }

    //   ┌───────→ job3
    //   │           ↑
    //   │           |
    //   ├─→ job ← job2 ← job4     singletonJob
    //   │    |             ↑
    //   │    └─────────────┘
    //   │
    //   └── job1 ← job5
    //        |       ↑
    //        └───────┘

    val SCCs = graph.findStronglyConnectedComponents[Job[TimeSeries]](
      workflow.vertices,
      workflow.edges.map { case (child, parent, _) => parent -> child }
    )
    assert(SCCs.size === 4)
    assert(SCCs.filter(_.size == 1).toSet === Set(List(singletonJob), List(job3)))
    assert(SCCs.find(_.size == 2).get.toSet === Set(job1, job5))
    assert(SCCs.find(_.size == 3).get.toSet === Set(job, job2, job4))
  }
} 
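The last test above drives a strongly-connected-components helper (graph.findStronglyConnectedComponents) whose implementation is not part of this listing. As a rough sketch of the underlying idea only, and not the cuttle implementation, a generic Kosaraju-style component finder over directed (from, to) edges could look like this:

// Kosaraju sketch: two depth-first passes, the second over the reversed graph.
object SccSketch {
  def findSccs[A](vertices: Set[A], edges: Set[(A, A)]): List[Set[A]] = {
    val succ: Map[A, Set[A]] =
      edges.groupBy(_._1).map { case (v, es) => v -> es.map(_._2) }.withDefaultValue(Set.empty)
    val pred: Map[A, Set[A]] =
      edges.groupBy(_._2).map { case (v, es) => v -> es.map(_._1) }.withDefaultValue(Set.empty)

    // First pass: record vertices in decreasing order of DFS finish time.
    val visited = scala.collection.mutable.Set.empty[A]
    var order   = List.empty[A]
    def visit(v: A): Unit =
      if (visited.add(v)) {
        succ(v).foreach(visit)
        order = v :: order // last finished ends up first
      }
    vertices.foreach(visit)

    // Second pass: walk the reversed graph in that order; each tree is one component.
    val assigned = scala.collection.mutable.Set.empty[A]
    def collect(v: A, acc: scala.collection.mutable.Set[A]): Unit =
      if (assigned.add(v)) {
        acc += v
        pred(v).foreach(collect(_, acc))
      }
    order.flatMap { v =>
      if (assigned(v)) None
      else {
        val component = scala.collection.mutable.Set.empty[A]
        collect(v, component)
        Some(component.toSet)
      }
    }
  }
}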
Example 75
Source File: IntervalMapSpec.scala    From cuttle   with Apache License 2.0 5 votes vote down vote up
package com.criteo.cuttle.timeseries
package intervals

import cats.implicits._

import org.scalatest.FunSuite

class IntervalMapSpec extends FunSuite {

  implicit def measureBuilder[A: Ordering, B]: MeasureKey[Interval[A], B] = measure

  test("intervals") {
    assert(
      IntervalMap(Interval(0, 3) -> 42, Interval(3, 5) -> 12) ==
        IntervalMap(Interval(0, 3) -> 42, Interval(3, 5) -> 12)
    )
  }

  test("broken intervals") {
    intercept[Exception](
      IntervalMap(Interval(0, 3) -> 42, Interval(3, 5) -> 42)
    )
  }

  test("broken intervals 2") {
    intercept[Exception](
      IntervalMap(Interval(0, 3) -> 42, Interval(2, 5) -> 13)
    )
  }

  test("merge intervals") {
    val intervals = IntervalMap(Interval(0, 3) -> 42, Interval(3, 5) -> 12)
    assert(
      intervals.map(_ => 1) ==
        IntervalMap(Interval(0, 5) -> 1)
    )
  }

  test("whenUndef") {
    assert(
      IntervalMap(Interval(0, 3) -> 42)
        .whenIsUndef(IntervalMap(Interval(1, 2) -> "foo", Interval(2, 3) -> "bar")) ==
        IntervalMap(Interval(0, 1) -> 42)
    )
  }
} 
Example 76
Source File: DatabaseSuite.scala    From cuttle   with Apache License 2.0 5 votes vote down vote up
package com.criteo.cuttle

import cats.effect.IO
import doobie.implicits._
import doobie.util.log
import org.scalatest.{BeforeAndAfter, FunSuite}

class DatabaseSuite extends FunSuite with BeforeAndAfter {
  val dbName = "cuttle_it_test"

  implicit val logger: Logger = new Logger {
    override def debug(message: => String): Unit = ()
    override def info(message: => String): Unit = ()
    override def warn(message: => String): Unit = ()
    override def error(message: => String): Unit = ()
    override def trace(message: => String): Unit = ()
  }

  val queries: Queries = Queries(logger)

  private val dbConfig = DatabaseConfig(
    Seq(DBLocation("localhost", 3388)),
    "sys",
    "root",
    ""
  )

  // service transactor is used for schema creation
  private val serviceTransactor: doobie.Transactor[IO] =
    Database.newHikariTransactor(dbConfig).allocated.unsafeRunSync()._1

  private implicit val logHandler: log.LogHandler = DoobieLogsHandler(logger).handler

  private def createDatabaseIfNotExists(): Unit =
    sql"CREATE DATABASE IF NOT EXISTS cuttle_it_test".update.run.transact(serviceTransactor).unsafeRunSync()

  private def clean(): Unit =
    sql"DROP DATABASE IF EXISTS cuttle_it_test".update.run.transact(serviceTransactor).unsafeRunSync()

  before {
    clean()
    createDatabaseIfNotExists()
  }
} 
Example 77
Source File: RabbitmqSpec.scala    From ez-framework   with Apache License 2.0 5 votes vote down vote up
package com.ecfront.ez.framework.cluster.rabbitmq

import java.util.concurrent.CountDownLatch
import java.util.concurrent.atomic.AtomicLong

import com.ecfront.ez.framework.core.logger.Logging
import com.rabbitmq.client.AMQP.BasicProperties
import com.rabbitmq.client.{ConnectionFactory, QueueingConsumer}
import org.scalatest.{BeforeAndAfter, FunSuite}


class RabbitmqSpec extends FunSuite with BeforeAndAfter with Logging {

  test("rabbitmq test") {
    val p = new AtomicLong(0)
    val c = new AtomicLong(0)

    val factory = new ConnectionFactory()
    factory.setUsername("user")
    factory.setPassword("password")
    factory.setHost("127.0.0.1")
    val connection = factory.newConnection()
    // produce
    val produceThreads = for (i <- 0 until 50)
      yield new Thread(new Runnable {
        override def run(): Unit = {
          val channel = connection.createChannel()
          val replyQueueName = channel.queueDeclare().getQueue
          val replyConsumer = new QueueingConsumer(channel)
          channel.basicConsume(replyQueueName, true, replyConsumer)
          val corrId = java.util.UUID.randomUUID().toString
          val opt = new BasicProperties.Builder().correlationId(corrId).replyTo(replyQueueName).build()
          channel.basicPublish("", "a", opt, s"test${p.incrementAndGet()}".getBytes())
          var delivery = replyConsumer.nextDelivery()
          while (true) {
            if (delivery.getProperties.getCorrelationId.equals(corrId)) {
              logger.info(s"reply " + new String(delivery.getBody))
            }
            delivery = replyConsumer.nextDelivery()
          }
          channel.close()
        }
      })
    produceThreads.foreach(_.start())

    // consumer
    new Thread(new Runnable {
      override def run(): Unit = {
        val channel = connection.createChannel()
        channel.queueDeclare("a", false, false, false, null)
        val consumer = new QueueingConsumer(channel)
        channel.basicConsume("a", true, consumer)
        while (true) {
          val delivery = consumer.nextDelivery()
          val props = delivery.getProperties()
          val message = new String(delivery.getBody())
          new Thread(new Runnable {
            override def run(): Unit = {
              Thread.sleep(10000)
              logger.info(s"receive 1 [${c.incrementAndGet()}] " + message)
              channel.basicPublish("", props.getReplyTo(), new BasicProperties.Builder().correlationId(props.getCorrelationId()).build(), message.getBytes)
            }
          }).start()
        }
      }
    }).start()
   

    new CountDownLatch(1).await()
  }
} 
Example 78
Source File: KafkaConfigSpec.scala    From awesome-recommendation-engine   with Apache License 2.0 5 votes vote down vote up
import example.utils.KafkaConfig
import org.scalatest.FunSuite

class KafkaConfigSpec extends FunSuite {
  val config = new KafkaConfig {}

  test("Consumer config should be read") {
    assert(config.getProperty("group.id") == "1234")
    assert(config.getProperty("zookeeper.connect") == "localhost:2821")
  }

  test("example.producer.Producer config should be read") {
    assert(config.getProperty("metadata.broker.list") == "broker1:9092,broker2:9092")
    assert(config.getProperty("serializer.class") == "kafka.serializer.StringEncoder")
    assert(config.getProperty("partitioner.class") == "example.producer.SimplePartitioner")
    assert(config.getProperty("request.required.acks") == "1")
  }

  test("Missing keys should be null") {
    assert(config.getProperty("some.other.key") == null)
  }
} 
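The KafkaConfig trait under test is project-specific and not shown here. A minimal sketch of a properties-file-backed trait with the same getProperty surface (the resource name kafka.properties is an assumption for illustration) could be:

import java.util.Properties

trait KafkaConfigSketch {
  // Loads key/value pairs from a classpath resource once, at trait initialisation.
  private val props: Properties = {
    val p  = new Properties()
    val in = getClass.getResourceAsStream("/kafka.properties")
    if (in != null) {
      try p.load(in)
      finally in.close()
    }
    p
  }

  // Returns null for missing keys, matching the "Missing keys should be null" test above.
  def getProperty(key: String): String = props.getProperty(key)
}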
Example 79
Source File: lmPredict$Test.scala    From sparkGLM   with Apache License 2.0 5 votes vote down vote up
package com.Alteryx.sparkGLM

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.functions._
import org.scalatest.FunSuite
import com.Alteryx.testUtils.data.testData._

class lmPredict$Test extends FunSuite {
  val sqlCtx = TestSQLContext

  test("lmPredict with a single partition") {
    val testDF = testDFSinglePart
    val x = testDF.select("intercept", "x")
    val y = testDF.select("y")
    val lmTest = LM.fit(x, y)
    val predicted = lmTest.predict(x)

    assert(predicted.getClass.getName == "org.apache.spark.sql.DataFrame")
    assert(predicted.rdd.partitions.size == 1)
    assert(predicted.columns.size == 2)
    assert(predicted.agg(max("index")).collect.apply(0).get(0) == 49)
  }

  test("lmPredict with multiple partitions") {
    val testDF = testDFMultiPart
    val x = testDF.select("intercept", "x")
    val y = testDF.select("y")
    val lmTest = LM.fit(x, y)
    val predicted = lmTest.predict(x)

    assert(predicted.getClass.getName == "org.apache.spark.sql.DataFrame")
    assert(predicted.rdd.partitions.length == 4)
    assert(predicted.columns.length == 2)
    assert(predicted.agg(max("index")).collect.apply(0).get(0) == 49)
  }
} 
Example 80
Source File: modelMatrix$Test.scala    From sparkGLM   with Apache License 2.0 5 votes vote down vote up
package com.Alteryx.sparkGLM

import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.FunSuite
import com.Alteryx.testUtils.data.testData._

class modelMatrix$Test extends FunSuite {
  val sqlCtx = TestSQLContext

  test("modelMatrixWithMixedTypes") {
    val testDF = modelMatrix(dummyDF)
    assert(testDF.columns.length == 4)
    val expectedCols = Array("intField", "strField_b", "strField_c", "numField")
    assert(expectedCols.forall { elem =>
      testDF.columns.contains(elem)
    })
    assert(testDF.dtypes.forall(_._2 == "DoubleType"))
  }

  test("modelMatrixWithNumOnly") {
    val testDF = modelMatrix(dummyDF.select("numField", "intField"))
    assert(testDF.columns.length == 2)
    val expectedCols = Array("numField", "intField")
    assert(expectedCols.forall { elem =>
      testDF.columns.contains(elem)
    })
    assert(testDF.dtypes.forall(_._2 == "DoubleType"))
  }

  test("modelMatrixWithStrOnly") {
    val testDF = modelMatrix(dummyDF.select("strField"))
    assert(testDF.columns.length == 2)
    val expectedCols = Array("strField_b", "strField_c")
    assert(expectedCols.forall { elem =>
      testDF.columns.contains(elem)
    })
    assert(testDF.dtypes.forall(_._2 == "DoubleType"))
  }

  test("modelMatrixLinearRegData") {
    val rawDF = mixedDF.select("intercept", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "y")
    val testDF = modelMatrix(rawDF)
    assert(testDF.columns.length == 10)
    val expectedCols = Array("intercept", "x1", "x2", "x3", "x4", "x5", "x6", "x7_b", "x7_c", "y")
    assert(expectedCols.forall { elem =>
      testDF.columns.contains(elem)
    })
    assert(testDF.dtypes.forall(_._2 == "DoubleType"))
  }
} 
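The column names asserted above (strField_b, strField_c) follow the usual dummy-coding convention: a categorical column expands into indicator columns for every level except the first, which serves as the baseline. A small, framework-free sketch of that idea (the naming scheme is inferred from the assertions, not taken from the sparkGLM implementation):

// Dummy-coding sketch: a string column with levels a/b/c becomes indicator
// columns for b and c, with a as the implicit baseline level.
object DummyCodingSketch {
  def dummyColumns(column: String, values: Seq[String]): (Seq[String], Seq[Seq[Double]]) = {
    val kept  = values.distinct.sorted.drop(1) // drop the first level as the baseline
    val names = kept.map(level => s"${column}_$level")
    val rows  = values.map(v => kept.map(level => if (v == level) 1.0 else 0.0))
    (names, rows)
  }
}

// DummyCodingSketch.dummyColumns("strField", Seq("a", "b", "c", "a"))
// -> (List("strField_b", "strField_c"),
//     List(List(0.0, 0.0), List(1.0, 0.0), List(0.0, 1.0), List(0.0, 0.0)))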
Example 81
Source File: utils$Test.scala    From sparkGLM   with Apache License 2.0 5 votes vote down vote up
package com.Alteryx.sparkGLM

import org.scalatest.FunSuite
import org.apache.spark.sql.test.TestSQLContext
import com.Alteryx.testUtils.data.testData._

class utils$Test extends FunSuite {
  val sqlCtx = TestSQLContext

  test("matchCols") {
    val df = modelMatrix(dummyDF)
    val dfWithMissingCategory = modelMatrix(oneLessCategoryDF)

    val testDF = utils.matchCols(df, dfWithMissingCategory)
    assert(testDF.getClass.getName == "org.apache.spark.sql.DataFrame")
    assert(testDF.columns.length == 4)
    assert(testDF.dtypes.forall(_._2 == "DoubleType"))
    val expectedCols = Array("intField", "strField_b", "strField_c", "numField")
    assert(expectedCols.forall { elem =>
      testDF.columns.contains(elem)
    })
    assert(testDF.select("strField_c").distinct.count == 1)
    assert(testDF.select("strField_c").distinct.collect().apply(0).get(0) === 0)
  }
} 
Example 82
Source File: SparkFunSuite.scala    From spark-gbtlr   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

// scalastyle:off
import java.io.File

import org.apache.spark.internal.Logging
import org.apache.spark.util.AccumulatorContext
import org.scalatest.{BeforeAndAfterAll, FunSuite, Outcome}


// NOTE: the class header below is reconstructed from the imports above (an assumption);
// the original declaration is not shown in this listing.
abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll with Logging {

  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }

} 
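A concrete suite extending this base class inherits the per-test log banners produced by withFixture. A minimal, hypothetical usage sketch (the suite and test names are made up for illustration):

import org.apache.spark.SparkFunSuite

class ExampleMathSuite extends SparkFunSuite {
  // Each test run is framed by the "TEST OUTPUT FOR ..." / "FINISHED ..." log lines.
  test("integer addition") {
    assert(1 + 1 === 2)
  }
}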
Example 83
Source File: LagSmpFactorySuite.scala    From lagraph   with Apache License 2.0 5 votes vote down vote up
package com.ibm.lagraph.impl

import org.scalatest.FunSuite
import org.scalatest.Matchers
import scala.reflect.ClassTag
import scala.collection.mutable.{Map => MMap}
import com.ibm.lagraph._

class LagSmpFactorySuite extends FunSuite with Matchers {

  test("LagSmpContext.vIndices") {
    val nv = 8
    val hc: LagContext = LagContext.getLagSmpContext(nv)
    val start = 2
    val end = start + nv
    val v = hc.vIndices(start)
    val vRes = hc.vToVector(v)
//    assert(vRes.isInstanceOf[Vector[Int]])
    assert(vRes.size == (end - start))
    (start until end).map { r =>
      assert(vRes(r - start) == r)
    }
  }
  test("LagSmpContext.vReplicate") {
    val nv = 10
    val hc: LagContext = LagContext.getLagSmpContext(nv)
    val singleValue: Double = 99.0
    val v = hc.vReplicate(singleValue)
    val vRes = hc.vToVector(v)
    assert(vRes.isInstanceOf[Vector[Double]])
    assert(vRes.size == nv)
    (0 until nv).map { r =>
      assert(vRes(r) == singleValue)
    }
  }
  test("LagSmpContext.mIndices") {
    val nv = 8
    val hc: LagContext = LagContext.getLagSmpContext(nv)
    val start = (2L, 2L)
    val end = (start._1 + nv, start._2 + nv)
    val m = hc.mIndices(start)
    val (vvm, vvs) = hc.mToMap(m)
    val mRes = LagContext.vectorOfVectorFromMap(vvm, vvs, (nv, nv))
    assert(mRes.size == (end._1 - start._1))
    mRes.zipWithIndex.map {
      case (vr, r) => {
        assert(vr.size == (end._2 - start._2))
        vr.zipWithIndex.map {
          case (vc, c) => assert(vc == (start._1 + r, start._2 + c))
        }
      }
    }
  }
  test("LagSmpContext.mReplicate") {
    val nv = 10
    val hc: LagContext = LagContext.getLagSmpContext(nv)
    val size = (nv, nv)
    val singleValue: Double = 99.0
    val m = hc.mReplicate(singleValue)
    val (vvm, vvs) = hc.mToMap(m)
    val mRes = LagContext.vectorOfVectorFromMap(vvm, vvs, (nv, nv))
    assert(mRes.size == size._1)
    mRes.zipWithIndex.map {
      case (vr, r) => {
        assert(vr.size == size._2)
        vr.zipWithIndex.map {
          case (vc, c) => assert(vc == singleValue)
        }
      }
    }
  }
} 
Example 84
Source File: LagDstrFactorySuite.scala    From lagraph   with Apache License 2.0 5 votes vote down vote up
package com.ibm.lagraph.impl
// TODO get rid of printlns
// scalastyle:off println

import com.holdenkarau.spark.testing.SharedSparkContext
import org.scalatest.FunSuite
import org.scalatest.Matchers
import scala.reflect.ClassTag
import scala.collection.mutable.{Map => MMap}
import com.ibm.lagraph._

class LagDstrFactorySuite extends FunSuite with Matchers with SharedSparkContext {
  val DEBUG = false

  val denseGraphSizes = List(1 << 4, 1 << 5)
  //  val sparseGraphSizes = List(1 << 16, 1 << 17, 1 << 29, 1 << 30)
  val sparseGraphSizes = List(1 << 16, 1 << 17, 1 << 26, 1 << 27)
  val nblocks = List(1 << 0, 1 << 1, 1 << 2, 1 << 3)

  test("test initializing spark context") {
    val hc: LagContext = LagContext.getLagDstrContext(sc, 1 << 3, 1)
    val list = nblocks
    val rdd = sc.parallelize(list)
    assert(rdd.count === list.length)
  }

  test("LagDstrContext.vIndices") {
    for (graphSize <- denseGraphSizes) {
      for (nblock <- nblocks) {
        if (DEBUG) println("LagDstrContext.vIndices", graphSize, nblock)
        val hc: LagContext = LagContext.getLagDstrContext(sc, graphSize, nblock)
        val start = 2
        val end = start + hc.graphSize
        val v = hc.vIndices(start)
        val vRes = hc.vToVector(v)
        assert(v.size == hc.graphSize)
        assert(vRes.size == (end - start))
        (start until end.toInt).map { r =>
          assert(vRes(r - start) == r)
        }
      }
    }
  }

  test("LagDstrContext.mIndices") {
    for (graphSize <- denseGraphSizes) {
      for (nblock <- nblocks) {
        if (DEBUG) println("LagDstrContext.mIndices", graphSize, nblock)
        val hc: LagContext = LagContext.getLagDstrContext(sc, graphSize, nblock)
        val start = (2L, 2L)
        val m = hc.mIndices(start)
        val (mResMap, sparseValue) = hc.mToMap(m)
        val mRes =
          LagContext.vectorOfVectorFromMap(mResMap, sparseValue, m.size)
        val end = (start._1 + graphSize, start._2 + graphSize)
        assert(mRes.size == (end._1 - start._1))
        mRes.zipWithIndex.map {
          case (vr, r) => {
            assert(vr.size == (end._2 - start._2))
            vr.zipWithIndex.map {
              case (vc, c) => assert(vc == (start._1 + r, start._2 + c))
            }
          }
        }
      }
    }
  }
  test("LagDstrContext.mReplicate") {
    for (graphSize <- denseGraphSizes) {
      for (nblock <- nblocks) {
        if (DEBUG) println("LagDstrContext.mReplicate", graphSize, nblock)
        val hc: LagContext = LagContext.getLagDstrContext(sc, graphSize, nblock)
        val singleValue: Double = 99.0
        val m = hc.mReplicate(singleValue)
        val (mResMap, sparseValue) = hc.mToMap(m)
        val mRes =
          LagContext.vectorOfVectorFromMap(mResMap, sparseValue, m.size)
        mRes.zipWithIndex.map {
          case (vr, r) => {
            assert(vr.size == graphSize)
            vr.zipWithIndex.map {
              case (vc, c) => assert(vc == singleValue)
            }
          }
        }
      }
    }
  }
}
// scalastyle:on println 
Example 85
Source File: reexporttests.scala    From export-hook   with Apache License 2.0 5 votes vote down vote up
package tcuser

import org.scalatest.FunSuite

import adtdefns._, autodefns._, tca._, tcb._

class ReexportTests extends FunSuite {
  test("Single reexport") {
    import single._

    assert(TcA[Int].describe === "TcA[Int]")
    assert(TcA[Foo].describe === "TcA[Foo]")
    assert(TcA[Boolean].describe === "TcA[Boolean]")
    assert(TcA[Bar].describe === "TcA[Bar]")
    assert(TcA[Quux].describe === "Default TcA[T]")

    assert(TcB[Int].describe === "TcB[Int]")
    assert(TcB[Foo].describe === "Default TcB[T]")
    assert(TcB[Boolean].describe === "Default TcB[T]")
    assert(TcB[Bar].describe === "Default TcB[T]")
    assert(TcB[Quux].describe === "Default TcB[T]")
  }

  test("Multi class reexport") {
    import twoclasses._

    assert(TcA[Int].describe === "TcA[Int]")
    assert(TcA[Foo].describe === "TcA[Foo]")
    assert(TcA[Boolean].describe === "Default TcA[T]")
    assert(TcA[Bar].describe === "gen(TcA[Int] :: Default TcA[T] :: HNil)")
    assert(TcA[Quux].describe ===
      "gen(TcA[Foo] :: gen(TcA[Int] :: Default TcA[T] :: HNil) :: TcA[Baz] :: HNil)")

    assert(TcB[Int].describe === "TcB[Int]")
    assert(TcB[Foo].describe === "gen(TcB[Int] :: HNil)")
    assert(TcB[Boolean].describe === "Default TcB[T]")
    assert(TcB[Bar].describe === "gen(TcB[Int] :: Default TcB[T] :: HNil)")
    assert(TcB[Quux].describe ===
      "gen(gen(TcB[Int] :: HNil) :: gen(TcB[Int] :: Default TcB[T] :: HNil) :: gen(Default TcB[T] :: HNil) :: HNil)")
  }

  test("Multi priority reexport") {
    import twopriorities._

    assert(TcA[Int].describe === "TcA[Int]")
    assert(TcA[Foo].describe === "TcA[Foo]")
    assert(TcA[Boolean].describe === "TcA[Boolean]")
    assert(TcA[Bar].describe === "TcA[Bar]")
    assert(TcA[Quux].describe ===
      "gen(TcA[Foo] :: TcA[Bar] :: TcA[Baz] :: HNil)")

    assert(TcB[Int].describe === "TcB[Int]")
    assert(TcB[Foo].describe === "Default TcB[T]")
    assert(TcB[Boolean].describe === "Default TcB[T]")
    assert(TcB[Bar].describe === "Default TcB[T]")
    assert(TcB[Quux].describe === "Default TcB[T]")
  }
} 
Example 86
Source File: SparkFunSuite.scala    From click-through-rate-prediction   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

// scalastyle:off
import org.scalatest.{FunSuite, Outcome}


// NOTE: the class header below is reconstructed (an assumption); in Spark 1.x, Logging is
// declared in the org.apache.spark package, so it is visible here without an import.
abstract class SparkFunSuite extends FunSuite with Logging {

  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
    try {
      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }

} 
Example 87
Source File: SHC.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.execution.datasources.hbase.Logging

import java.io.File

import com.google.common.io.Files
import org.apache.hadoop.hbase.client.Table
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.{HBaseTestingUtility, TableName}
import org.apache.spark.sql.execution.datasources.hbase.SparkHBaseConf
import org.apache.spark.{SparkContext, SparkConf}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

class SHC  extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll  with Logging {
  implicit class StringToColumn(val sc: StringContext) {
    def $(args: Any*): ColumnName = {
      new ColumnName(sc.s(args: _*))
    }
  }

  var spark: SparkSession = null
  var sc: SparkContext = null
  var sqlContext: SQLContext = null
  var df: DataFrame = null

  private[spark] var htu = new HBaseTestingUtility
  private[spark] def tableName = "table1"

  private[spark] def columnFamilies: Array[String] = Array.tabulate(9){ x=> s"cf$x"}
  var table: Table = null
  val conf = new SparkConf
  conf.set(SparkHBaseConf.testConf, "true")
  // private[spark] var columnFamilyStr = Bytes.toString(columnFamily)

  def defineCatalog(tName: String) = s"""{
                                         |"table":{"namespace":"default", "name":"$tName"},
                                         |"rowkey":"key",
                                         |"columns":{
                                              |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
                                              |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
                                              |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
                                              |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
                                              |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
                                              |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
                                              |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
                                              |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
                                              |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
                                            |}
                                         |}""".stripMargin

  @deprecated(since = "04.12.2017(dd/mm/year)", message = "use `defineCatalog` instead")
  def catalog = defineCatalog(tableName)

  override def beforeAll() {
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit
    htu.startMiniCluster
    SparkHBaseConf.conf = htu.getConfiguration
    logInfo(" - minicluster started")
    println(" - minicluster started")

    spark = SparkSession.builder()
      .master("local")
      .appName("HBaseTest")
      .config(conf)
      .getOrCreate()

    sqlContext = spark.sqlContext
    sc = spark.sparkContext
  }

  override def afterAll() {
    htu.shutdownMiniCluster()
    spark.stop()
  }

  def createTable(name: String, cfs: Array[String]) {
    val tName = Bytes.toBytes(name)
    val bcfs = cfs.map(Bytes.toBytes(_))
    try {
      htu.deleteTable(TableName.valueOf(tName))
    } catch {
      case _ : Throwable =>
        logInfo(" - no table " + name + " found")
    }
    htu.createMultiRegionTable(TableName.valueOf(tName), bcfs)
  }


  def createTable(name: Array[Byte], cfs: Array[Array[Byte]]) {
    try {
      htu.deleteTable(TableName.valueOf(name))
    } catch {
      case _ : Throwable =>
        logInfo(" - no table " + Bytes.toString(name) + " found")
    }
    htu.createMultiRegionTable(TableName.valueOf(name), cfs)
  }
} 
Example 88
Source File: HBaseTestSuite.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import java.io.File

import scala.collection.JavaConverters._

import com.google.common.io.Files
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
import org.apache.spark.sql.execution.datasources.hbase.Logging
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

class HBaseTestSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll  with Logging {
  private[spark] var htu = HBaseTestingUtility.createLocalHTU()
  private[spark] var tableName: Array[Byte] = Bytes.toBytes("t1")
  private[spark] var columnFamily: Array[Byte] = Bytes.toBytes("cf0")
  private[spark] var columnFamilies: Array[Array[Byte]] =
    Array(Bytes.toBytes("cf0"), Bytes.toBytes("cf1"), Bytes.toBytes("cf2"), Bytes.toBytes("cf3"), Bytes.toBytes("cf4"))
  var table: Table = null
  // private[spark] var columnFamilyStr = Bytes.toString(columnFamily)

  override def beforeAll() {
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit
    htu.cleanupTestDir
    htu.startMiniZKCluster
    htu.startMiniHBaseCluster(1, 4)
    logInfo(" - minicluster started")
    println(" - minicluster started")
    try {
      htu.deleteTable(TableName.valueOf(tableName))

      //htu.createTable(TableName.valueOf(tableName), columnFamily, 2, Bytes.toBytes("abc"), Bytes.toBytes("xyz"), 2)
    } catch {
      case _ : Throwable =>
        logInfo(" - no table " + Bytes.toString(tableName) + " found")
    }
    setupTable()
  }



  override def afterAll() {
    try {
      table.close()
      println("shutdown")
      htu.deleteTable(TableName.valueOf(tableName))
      logInfo("shuting down minicluster")
      htu.shutdownMiniHBaseCluster
      htu.shutdownMiniZKCluster
      logInfo(" - minicluster shut down")
      htu.cleanupTestDir
    } catch {
      case _ : Throwable => logError("teardown error")
    }
  }

  def setupTable() {
    val config = htu.getConfiguration
    htu.createMultiRegionTable(TableName.valueOf(tableName), columnFamilies)
    println("create htable t1")
    val connection = ConnectionFactory.createConnection(config)
    val r = connection.getRegionLocator(TableName.valueOf("t1"))
    table = connection.getTable(TableName.valueOf("t1"))

    val regionLocations = r.getAllRegionLocations.asScala.toSeq
    println(s"$regionLocations size: ${regionLocations.size}")
    (0 until 100).foreach { x =>
      var put = new Put(Bytes.toBytes(s"row$x"))
      (0 until 5).foreach { y =>
        put.addColumn(columnFamilies(y), Bytes.toBytes(s"c$y"), Bytes.toBytes(s"value $x $y"))
      }
      table.put(put)
    }
  }
} 
Example 89
Source File: CatalogSuite.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.sql.execution.datasources.hbase.Logging
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
import org.apache.spark.sql.execution.datasources.hbase.HBaseTableCatalog

class CatalogSuite  extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll  with Logging{
  def catalog = s"""{
            |"table":{"namespace":"default", "name":"table1", "tableCoder":"PrimitiveType"},
            |"rowkey":"key1:key2",
            |"columns":{
              |"col00":{"cf":"rowkey", "col":"key1", "type":"string", "length":"6"},
              |"col01":{"cf":"rowkey", "col":"key2", "type":"int"},
              |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
              |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
              |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
              |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
              |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
              |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
              |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"},
              |"col7":{"cf":"cf7", "col":"col7", "type":"string"}
            |}
          |}""".stripMargin

  test("Catalog meta data check") {
    val m = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog->catalog))
    assert(m.row.fields.filter(_.length == -1).isEmpty)
    assert(m.row.length == 10)
  }

  test("Catalog should preserve the columns order") {
    val m = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog->catalog))
    assert(m.toDataType.fields.map(_.name).sameElements(
      Array("col00", "col01", "col1", "col2", "col3", "col4", "col5", "col6", "col8", "col7")))
  }
} 
Example 90
Source File: ToxKilledExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.exceptions

import im.tox.tox4j.impl.jni.ToxCoreImplFactory.withToxUnit
import org.scalatest.FunSuite

final class ToxKilledExceptionTest extends FunSuite {

  test("UseAfterCloseInOrder") {
    intercept[ToxKilledException] {
      withToxUnit { tox1 =>
        withToxUnit { tox2 =>
          tox1.close()
          tox1.iterationInterval
        }
      }
    }
  }

  test("UseAfterCloseReverseOrder") {
    intercept[ToxKilledException] {
      withToxUnit { tox1 =>
        withToxUnit { tox2 =>
          tox2.close()
          tox2.iterationInterval
        }
      }
    }
  }

} 
Example 91
Source File: BadInstanceNumberTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.impl.jni

import im.tox.tox4j.core.options.ToxOptions
import im.tox.tox4j.exceptions.ToxKilledException
import org.scalacheck.Gen
import org.scalatest.FunSuite
import org.scalatest.prop.PropertyChecks


@SuppressWarnings(Array("org.wartremover.warts.Equals"))
final class BadInstanceNumberTest extends FunSuite with PropertyChecks {

  private def callWithInstanceNumber(instanceNumber: Int): Unit = {
    val tox = new ToxCoreImpl(ToxOptions())

    val field = tox.getClass.getDeclaredField("instanceNumber")
    field.setAccessible(true)
    val oldInstanceNumber = field.get(tox).asInstanceOf[Int]
    field.set(tox, instanceNumber)

    val exception =
      try {
        tox.iterationInterval
        null
      } catch {
        case e: Throwable => e
      }

    // Set it back to the good one, so close() works.
    field.set(tox, oldInstanceNumber)
    tox.close()

    if (exception != null) {
      throw exception
    }
  }

  test("negative or zero instance numbers") {
    forAll(Gen.choose(Int.MinValue, 0)) { instanceNumber =>
      intercept[IllegalStateException] {
        callWithInstanceNumber(instanceNumber)
      }
    }
  }

  test("very large instance numbers") {
    forAll(Gen.choose(0xffff, Int.MaxValue)) { instanceNumber =>
      intercept[IllegalStateException] {
        callWithInstanceNumber(instanceNumber)
      }
    }
  }

  test("any invalid instance numbers") {
    // This could be fine if there is another Tox instance lingering around, but we assume there isn't.
    // So, it's either killed (ToxKilledException) or never existed (IllegalStateException).
    System.gc() // After this, there should be no lingering instances.

    forAll { (instanceNumber: Int) =>
      whenever(instanceNumber != 1) {
        try {
          callWithInstanceNumber(instanceNumber)
          fail("No exception thrown. Expected IllegalStateException or ToxKilledException.")
        } catch {
          case _: IllegalStateException =>
          case _: ToxKilledException    => // Both fine.
        }
      }
    }
  }

} 
Example 92
Source File: NamingConventionsTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.impl.jni

import org.scalatest.FunSuite

abstract class NamingConventionsTest(jniClass: Class[_], traitClass: Class[_]) extends FunSuite {

  private val exemptions = Seq("callback", "load", "close", "create", "getFriendNumbers")

  test("Java method names should be derivable from JNI method names") {
    val jniMethods = MethodMap(jniClass)

    traitClass
      .getDeclaredMethods.toSeq
      .map(_.getName)
      .filterNot(exemptions.contains)
      .foreach { name =>
        assert(jniMethods.contains(name))
      }
  }

} 
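MethodMap is a tox4j-internal helper that indexes the declared methods of the JNI binding class; its name-derivation rules are not shown here. As a rough, generic sketch of the reflection involved (a plain name lookup, which is an assumption for illustration), something like the following would support the contains checks used above:

// Reflection sketch: collect the declared method names of a class into a Set.
object MethodNamesSketch {
  def apply(clazz: Class[_]): Set[String] =
    clazz.getDeclaredMethods.map(_.getName).toSet
}

// MethodNamesSketch(classOf[String]).contains("substring") // true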
Example 93
Source File: ToxJniLogTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.impl.jni

import im.tox.tox4j.core.data.ToxFriendNumber
import im.tox.tox4j.impl.jni.proto.JniLog
import org.scalacheck.Gen
import org.scalatest.FunSuite
import org.scalatest.prop.PropertyChecks

@SuppressWarnings(Array("org.wartremover.warts.Equals"))
final class ToxJniLogTest extends FunSuite with PropertyChecks {

  private val TestMaxSize = 100

  private val friendNumber = ToxFriendNumber.fromInt(0).get

  test("constructing and destroying a Tox instance with logging enabled should result in a non-empty log") {
    ToxJniLog() // clear

    ToxJniLog.maxSize = TestMaxSize
    assert(ToxJniLog.maxSize == TestMaxSize)
    assert(ToxJniLog().entries.isEmpty)
    // Construct and destroy a Tox instance so that something (tox_new) gets logged and the
    // log becomes non-empty.
    ToxCoreImplFactory.withToxUnit { tox => }
    assert(ToxJniLog().entries.nonEmpty)
  }

  test("constructing and destroying a Tox instance with logging disabled should result in an empty log") {
    ToxJniLog() // clear

    ToxJniLog.maxSize = 0
    assert(ToxJniLog.maxSize == 0)
    assert(ToxJniLog().entries.isEmpty)
    ToxCoreImplFactory.withToxUnit { tox => }
    assert(ToxJniLog().entries.isEmpty)
  }

  test("one log entry per native call") {
    ToxJniLog() // clear

    ToxJniLog.maxSize = TestMaxSize
    assert(ToxJniLog().entries.isEmpty)

    ToxCoreImplFactory.withToxUnit { tox => }
    val count1 = ToxJniLog().entries.size

    ToxCoreImplFactory.withToxUnit { tox => tox.friendExists(friendNumber) }
    val count2 = ToxJniLog().entries.size

    assert(count2 == count1 + 1)
  }

  test("null protobufs are ignored") {
    assert(ToxJniLog.fromBytes(null) == JniLog.defaultInstance)
  }

  test("invalid protobufs are ignored") {
    forAll { (bytes: Array[Byte]) =>
      assert(ToxJniLog.fromBytes(bytes) == JniLog.defaultInstance)
    }
  }

  test("concurrent logging works") {
    ToxJniLog() // clear
    ToxJniLog.maxSize = 10000

    forAll(Gen.choose(1, 99), Gen.choose(1, 100)) { (threadCount, iterations) =>
      val threads = for (_ <- 1 to threadCount) yield {
        new Thread {
          override def run(): Unit = {
            ToxCoreImplFactory.withToxUnit { tox =>
              for (_ <- 0 until iterations) {
                tox.friendExists(friendNumber)
              }
            }
          }
        }
      }

      threads.foreach(_.start())
      threads.foreach(_.join())

      val log = ToxJniLog()
      assert(log.entries.size < 10000)
      assert(log.entries.size == threadCount + threadCount * iterations)
      assert(ToxJniLog.toString(log).count(_ == '\n') == log.entries.size)
    }

    assert(ToxJniLog().entries.isEmpty)
    ToxJniLog.maxSize = 0
  }

} 
Example 94
Source File: AliceBobTestBase.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.testing.autotest

import com.typesafe.scalalogging.Logger
import im.tox.tox4j.OptimisedIdOps._
import im.tox.tox4j.av.ToxAv
import im.tox.tox4j.core.ToxCore
import im.tox.tox4j.core.data.ToxFriendNumber
import im.tox.tox4j.testing.ToxTestMixin
import im.tox.tox4j.testing.autotest.AliceBobTestBase.Chatter
import org.scalatest.FunSuite
import org.slf4j.LoggerFactory

import scala.annotation.tailrec

object AliceBobTestBase {
  val FriendNumber: ToxFriendNumber = ToxFriendNumber.fromInt(10).get

  final case class Chatter[T](
      tox: ToxCore,
      av: ToxAv,
      client: ChatClientT[T],
      state: ChatStateT[T]
  )
}

abstract class AliceBobTestBase extends FunSuite with ToxTestMixin {

  protected val logger = Logger(LoggerFactory.getLogger(classOf[AliceBobTestBase]))

  protected type State
  protected type ChatState = ChatStateT[State]
  protected type ChatClient = ChatClientT[State]

  protected def initialState: State

  protected def newChatClient(name: String, expectedFriendName: String): ChatClient

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  private def getTopLevelMethod(stackTrace: Seq[StackTraceElement]): String = {
    stackTrace
      .filter(_.getClassName == classOf[AliceBobTest].getName)
      .lastOption
      .fold("<unknown>")(_.getMethodName)
  }

  @tailrec
  private def mainLoop(clients: Seq[Chatter[State]]): Unit = {
    val nextState = clients.map {
      case Chatter(tox, av, client, state) =>
        Chatter[State](tox, av, client, state |> tox.iterate(client) |> (_.runTasks(tox, av)))
    }

    val interval = (nextState.map(_.tox.iterationInterval) ++ nextState.map(_.av.iterationInterval)).min
    Thread.sleep(interval)

    if (nextState.exists(_.state.chatting)) {
      mainLoop(nextState)
    }
  }

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  protected def runAliceBobTest(
    withTox: (ToxCore => Unit) => Unit,
    withToxAv: ToxCore => (ToxAv => Unit) => Unit
  ): Unit = {
    val method = getTopLevelMethod(Thread.currentThread.getStackTrace)
    logger.info(s"[${Thread.currentThread.getId}] --- ${getClass.getSimpleName}.$method")

    val aliceChat = newChatClient("Alice", "Bob")
    val bobChat = newChatClient("Bob", "Alice")

    withTox { alice =>
      withTox { bob =>
        withToxAv(alice) { aliceAv =>
          withToxAv(bob) { bobAv =>
            assert(alice ne bob)

            addFriends(alice, AliceBobTestBase.FriendNumber.value)
            addFriends(bob, AliceBobTestBase.FriendNumber.value)

            alice.addFriendNorequest(bob.getPublicKey)
            bob.addFriendNorequest(alice.getPublicKey)

            aliceChat.expectedFriendAddress = bob.getAddress
            bobChat.expectedFriendAddress = alice.getAddress

            val aliceState = aliceChat.setup(alice)(ChatStateT[State](initialState))
            val bobState = bobChat.setup(bob)(ChatStateT[State](initialState))

            mainLoop(Seq(
              Chatter(alice, aliceAv, aliceChat, aliceState),
              Chatter(bob, bobAv, bobChat, bobState)
            ))
          }
        }
      }
    }
  }
} 
Example 95
Source File: AutoTestSuite.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.testing.autotest

import com.typesafe.scalalogging.Logger
import im.tox.tox4j.TestConstants
import im.tox.tox4j.core.data.ToxFriendNumber
import im.tox.tox4j.core.enums.ToxConnection
import im.tox.tox4j.core.options.ToxOptions
import im.tox.tox4j.impl.jni.{ ToxAvImplFactory, ToxCoreImplFactory }
import im.tox.tox4j.testing.autotest.AutoTest.ClientState
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.slf4j.LoggerFactory
import shapeless.<:!<

import scala.util.Random

object AutoTestSuite {

  sealed abstract class Timed[A, R] {

    protected def wrap(time: Int, result: A): R

    def timed(block: => A): R = {
      val start = System.currentTimeMillis()
      val result = block
      val end = System.currentTimeMillis()
      wrap((end - start).toInt, result)
    }

  }

  implicit def otherTimed[A](implicit notUnit: A <:!< Unit): Timed[A, (Int, A)] = new Timed[A, (Int, A)] {
    protected def wrap(time: Int, result: A): (Int, A) = (time, result)
  }
  implicit val unitTimed: Timed[Unit, Int] = new Timed[Unit, Int] {
    protected def wrap(time: Int, result: Unit): Int = time
  }

  def timed[A, R](block: => A)(implicit timed: Timed[A, R]): R = timed.timed(block)

}

abstract class AutoTestSuite extends FunSuite with Timeouts {

  private val logger = Logger(LoggerFactory.getLogger(getClass))

  protected def maxParticipantCount: Int = 2

  type S

  abstract class EventListener(val initial: S) extends AutoTest.EventListener[S] {

    override def selfConnectionStatus(
      connectionStatus: ToxConnection
    )(state: State): State = {
      debug(state, s"Our connection: $connectionStatus")
      state
    }

    override def friendConnectionStatus(
      friendNumber: ToxFriendNumber,
      connectionStatus: ToxConnection
    )(state: State): State = {
      debug(state, s"Friend ${state.id(friendNumber)}'s connection: $connectionStatus")
      state
    }

  }

  def Handler: EventListener // scalastyle:ignore method.name

  protected def debug(state: ClientState[S], message: String): Unit = {
    logger.debug(s"[${state.id}] $message")
  }

  @SuppressWarnings(Array("org.wartremover.warts.Equals"))
  def run(ipv6Enabled: Boolean = true, udpEnabled: Boolean = true): Unit = {
    failAfter(TestConstants.Timeout) {
      val participantCount =
        if (maxParticipantCount == 2) {
          maxParticipantCount
        } else {
          new Random().nextInt(maxParticipantCount - 2) + 2
        }
      AutoTest(ToxCoreImplFactory, ToxAvImplFactory).run(participantCount, ToxOptions(ipv6Enabled, udpEnabled), Handler)
    }
  }

  test("UDP")(run(ipv6Enabled = true, udpEnabled = true))

} 
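The timed helper in the companion object above selects its return type through the two implicit Timed instances: a Unit block yields only the elapsed milliseconds, while any other block yields a (time, result) pair. A minimal usage sketch, assuming it is compiled alongside the suite above (so shapeless is on the classpath); the computations themselves are made up for illustration:

import im.tox.tox4j.testing.autotest.AutoTestSuite.timed

object TimedUsageExample {
  def main(args: Array[String]): Unit = {
    // Unit block: unitTimed is selected, so only the elapsed time in milliseconds comes back.
    val elapsedMs: Int = timed { Thread.sleep(5) }

    // Non-Unit block: otherTimed is selected, so we get (elapsed time, result).
    val (sumMs, sum) = timed { (1 to 1000).sum }

    println(s"sleep took ${elapsedMs}ms; summing took ${sumMs}ms and produced $sum")
  }
}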
Example 96
Source File: ToxBootstrapExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.core.network.Port
import im.tox.tox4j.core.ToxCoreConstants
import im.tox.tox4j.core.data.ToxPublicKey
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxBootstrapExceptionTest extends FunSuite with ToxTestMixin {

  private val host = "192.254.75.98"
  private val publicKey = ToxPublicKey.fromValue(Array.ofDim(ToxCoreConstants.PublicKeySize)).toOption.get
  private val port = Port.fromInt(ToxCoreConstants.DefaultStartPort).get

  test("BootstrapBadPort1") {
    interceptWithTox(ToxBootstrapException.Code.BAD_PORT)(
      _.bootstrap(
        host,
        Port.unsafeFromInt(0),
        publicKey
      )
    )
  }

  test("BootstrapBadPort2") {
    interceptWithTox(ToxBootstrapException.Code.BAD_PORT)(
      _.bootstrap(
        host,
        Port.unsafeFromInt(-10),
        publicKey
      )
    )
  }

  test("BootstrapBadPort3") {
    interceptWithTox(ToxBootstrapException.Code.BAD_PORT)(
      _.bootstrap(
        host,
        Port.unsafeFromInt(65536),
        publicKey
      )
    )
  }

  test("BootstrapBadHost") {
    interceptWithTox(ToxBootstrapException.Code.BAD_HOST)(
      _.bootstrap(
        ".",
        port,
        publicKey
      )
    )
  }

  test("BootstrapNullHost") {
    interceptWithTox(ToxBootstrapException.Code.NULL)(
      _.bootstrap(
        null,
        port,
        publicKey
      )
    )
  }

  test("BootstrapNullKey") {
    interceptWithTox(ToxBootstrapException.Code.NULL)(
      _.bootstrap(
        host,
        port,
        ToxPublicKey.unsafeFromValue(null)
      )
    )
  }

  test("BootstrapKeyTooShort") {
    interceptWithTox(ToxBootstrapException.Code.BAD_KEY)(
      _.bootstrap(
        host,
        port,
        ToxPublicKey.unsafeFromValue(Array.ofDim(ToxCoreConstants.PublicKeySize - 1))
      )
    )
  }

  test("BootstrapKeyTooLong") {
    interceptWithTox(ToxBootstrapException.Code.BAD_KEY)(
      _.bootstrap(
        host,
        port,
        ToxPublicKey.unsafeFromValue(Array.ofDim(ToxCoreConstants.PublicKeySize + 1))
      )
    )
  }

} 
Example 97
Source File: ToxSetInfoExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.ToxCoreTestBase
import im.tox.tox4j.core.ToxCoreConstants
import im.tox.tox4j.core.data.{ ToxNickname, ToxStatusMessage }
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxSetInfoExceptionTest extends FunSuite with ToxTestMixin {

  test("SetNameTooLong") {
    val array = ToxCoreTestBase.randomBytes(ToxCoreConstants.MaxNameLength + 1)

    interceptWithTox(ToxSetInfoException.Code.TOO_LONG)(
      _.setName(ToxNickname.unsafeFromValue(array))
    )
  }

  test("SetStatusMessageTooLong") {
    val array = ToxCoreTestBase.randomBytes(ToxCoreConstants.MaxStatusMessageLength + 1)

    interceptWithTox(ToxSetInfoException.Code.TOO_LONG)(
      _.setStatusMessage(ToxStatusMessage.unsafeFromValue(array))
    )
  }

  test("SetNameNull") {
    interceptWithTox(ToxSetInfoException.Code.NULL)(
      _.setName(ToxNickname.unsafeFromValue(null))
    )
  }

  test("SetStatusMessageNull") {
    interceptWithTox(ToxSetInfoException.Code.NULL)(
      _.setStatusMessage(ToxStatusMessage.unsafeFromValue(null))
    )
  }

} 
Example 98
Source File: ToxNewExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.options.{ ProxyOptions, SaveDataOptions }
import im.tox.tox4j.impl.jni.ToxCoreImplFactory.{ withToxUnit, withToxes }
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxNewExceptionTest extends FunSuite with ToxTestMixin {

  test("ToxNewProxyNull") {
    intercept(ToxNewException.Code.PROXY_BAD_HOST) {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5(null, 1)) { _ => }
    }
  }

  test("ToxNewProxyEmpty") {
    intercept(ToxNewException.Code.PROXY_BAD_HOST) {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5("", 1)) { _ => }
    }
  }

  test("ToxNewProxyBadPort0") {
    intercept(ToxNewException.Code.PROXY_BAD_PORT) {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5("localhost", 0)) { _ => }
    }
  }

  test("ToxNewProxyBadPortNegative") {
    intercept[IllegalArgumentException] {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5("localhost", -10)) { _ => }
    }
  }

  test("ToxNewProxyBadPortTooLarge") {
    intercept[IllegalArgumentException] {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5("localhost", 0x10000)) { _ => }
    }
  }

  test("ToxNewProxyBadAddress1") {
    intercept(ToxNewException.Code.PROXY_BAD_HOST) {
      val host = "\u2639" // scalastyle:ignore non.ascii.character.disallowed
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5(host, 1)) { _ => }
    }
  }

  test("ToxNewProxyBadAddress2") {
    intercept(ToxNewException.Code.PROXY_BAD_HOST) {
      withToxUnit(ipv6Enabled = true, udpEnabled = true, new ProxyOptions.Socks5(".", 1)) { _ => }
    }
  }

  test("TooManyToxCreations") {
    intercept(ToxNewException.Code.PORT_ALLOC) {
      withToxes(102) { _ => }
    }
  }

  test("LoadEncrypted") {
    intercept(ToxNewException.Code.LOAD_ENCRYPTED) {
      withToxUnit(SaveDataOptions.ToxSave("toxEsave blah blah blah".getBytes)) { _ => }
    }
  }

  test("LoadBadFormat") {
    intercept(ToxNewException.Code.LOAD_BAD_FORMAT) {
      withToxUnit(SaveDataOptions.ToxSave("blah blah blah".getBytes)) { _ => }
    }
  }

} 
Example 99
Source File: ToxFriendSendMessageExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.data.{ ToxFriendNumber, ToxFriendMessage }
import im.tox.tox4j.core.enums.ToxMessageType
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxFriendSendMessageExceptionTest extends FunSuite with ToxTestMixin {

  private val friendNumber = ToxFriendNumber.fromInt(0).get
  private val badFriendNumber = ToxFriendNumber.fromInt(1).get

  test("SendMessageNotFound") {
    interceptWithTox(ToxFriendSendMessageException.Code.FRIEND_NOT_FOUND)(
      _.friendSendMessage(badFriendNumber, ToxMessageType.NORMAL, 0, ToxFriendMessage.fromString("hello").toOption.get)
    )
  }

  test("SendMessageNotConnected") {
    interceptWithTox(ToxFriendSendMessageException.Code.FRIEND_NOT_CONNECTED)(
      _.friendSendMessage(friendNumber, ToxMessageType.NORMAL, 0, ToxFriendMessage.fromString("hello").toOption.get)
    )
  }

  test("SendMessageNull") {
    interceptWithTox(ToxFriendSendMessageException.Code.NULL)(
      _.friendSendMessage(friendNumber, ToxMessageType.NORMAL, 0, ToxFriendMessage.unsafeFromValue(null))
    )
  }

  test("SendMessageEmpty") {
    interceptWithTox(ToxFriendSendMessageException.Code.EMPTY)(
      _.friendSendMessage(friendNumber, ToxMessageType.NORMAL, 0, ToxFriendMessage.unsafeFromValue("".getBytes))
    )
  }

} 
Example 100
Source File: ToxFriendDeleteExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.data.ToxFriendNumber
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxFriendDeleteExceptionTest extends FunSuite with ToxTestMixin {

  test("DeleteFriendTwice") {
    interceptWithTox(ToxFriendDeleteException.Code.FRIEND_NOT_FOUND) { tox =>
      addFriends(tox, 4)
      assert(tox.getFriendList sameElements Array(0, 1, 2, 3, 4))
      tox.deleteFriend(ToxFriendNumber.fromInt(2).get)
      assert(tox.getFriendList sameElements Array(0, 1, 3, 4))
      tox.deleteFriend(ToxFriendNumber.fromInt(2).get)
    }
  }

  test("DeleteNonExistentFriend") {
    interceptWithTox(ToxFriendDeleteException.Code.FRIEND_NOT_FOUND)(
      _.deleteFriend(ToxFriendNumber.fromInt(1).get)
    )
  }

} 
Example 101
Source File: ToxFriendCustomPacketExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.ToxCoreConstants
import im.tox.tox4j.core.data.{ ToxFriendNumber, ToxLosslessPacket, ToxLossyPacket }
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxFriendCustomPacketExceptionTest extends FunSuite with ToxTestMixin {

  private val friendNumber = ToxFriendNumber.fromInt(0).get
  private val badFriendNumber = ToxFriendNumber.fromInt(1).get

  test("SendLosslessPacketNotConnected") {
    interceptWithTox(ToxFriendCustomPacketException.Code.FRIEND_NOT_CONNECTED)(
      _.friendSendLosslessPacket(friendNumber, ToxLosslessPacket.fromValue(Array[Byte](160.toByte, 0, 1, 2, 3)).toOption.get)
    )
  }

  test("SendLossyPacketNotConnected") {
    interceptWithTox(ToxFriendCustomPacketException.Code.FRIEND_NOT_CONNECTED)(
      _.friendSendLossyPacket(friendNumber, ToxLossyPacket.fromValue(200.toByte +: Array.ofDim[Byte](4)).toOption.get)
    )
  }

  test("SendLosslessPacketNotFound") {
    interceptWithTox(ToxFriendCustomPacketException.Code.FRIEND_NOT_FOUND)(
      _.friendSendLosslessPacket(badFriendNumber, ToxLosslessPacket.fromValue(Array[Byte](160.toByte, 0, 1, 2, 3)).toOption.get)
    )
  }

  test("SendLossyPacketNotFound") {
    interceptWithTox(ToxFriendCustomPacketException.Code.FRIEND_NOT_FOUND)(
      _.friendSendLossyPacket(badFriendNumber, ToxLossyPacket.fromValue(Array[Byte](200.toByte, 0, 1, 2, 3)).toOption.get)
    )
  }

  test("SendLosslessPacketInvalid") {
    interceptWithTox(ToxFriendCustomPacketException.Code.INVALID)(
      _.friendSendLosslessPacket(friendNumber, ToxLosslessPacket.unsafeFromValue(Array[Byte](100.toByte)))
    )
  }

  test("SendLossyPacketInvalid") {
    interceptWithTox(ToxFriendCustomPacketException.Code.INVALID)(
      _.friendSendLossyPacket(friendNumber, ToxLossyPacket.unsafeFromValue(Array[Byte](100.toByte)))
    )
  }

  test("SendLosslessPacketEmpty") {
    interceptWithTox(ToxFriendCustomPacketException.Code.EMPTY)(
      _.friendSendLosslessPacket(friendNumber, ToxLosslessPacket.unsafeFromValue(Array[Byte]()))
    )
  }

  test("SendLossyPacketEmpty") {
    interceptWithTox(ToxFriendCustomPacketException.Code.EMPTY)(
      _.friendSendLossyPacket(friendNumber, ToxLossyPacket.unsafeFromValue(Array[Byte]()))
    )
  }

  test("SendLosslessPacketNull") {
    interceptWithTox(ToxFriendCustomPacketException.Code.NULL)(
      _.friendSendLosslessPacket(friendNumber, ToxLosslessPacket.unsafeFromValue(null))
    )
  }

  test("SendLossyPacketNull") {
    interceptWithTox(ToxFriendCustomPacketException.Code.NULL)(
      _.friendSendLossyPacket(friendNumber, ToxLossyPacket.unsafeFromValue(null))
    )
  }

  test("SendLosslessPacketTooLong") {
    interceptWithTox(ToxFriendCustomPacketException.Code.TOO_LONG)(
      _.friendSendLosslessPacket(
        friendNumber,
        ToxLosslessPacket.unsafeFromValue(160.toByte +: Array.ofDim[Byte](ToxCoreConstants.MaxCustomPacketSize))
      )
    )
  }

  test("SendLossyPacketTooLong") {
    interceptWithTox(ToxFriendCustomPacketException.Code.TOO_LONG)(
      _.friendSendLossyPacket(
        friendNumber,
        ToxLossyPacket.unsafeFromValue(200.toByte +: Array.ofDim[Byte](ToxCoreConstants.MaxCustomPacketSize))
      )
    )
  }

} 
Example 102
Source File: ToxFriendByPublicKeyExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.data.ToxPublicKey
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxFriendByPublicKeyExceptionTest extends FunSuite with ToxTestMixin {

  test("Null") {
    interceptWithTox(ToxFriendByPublicKeyException.Code.NULL)(
      _.friendByPublicKey(ToxPublicKey.unsafeFromValue(null))
    )
  }

  test("NotFound") {
    interceptWithTox(ToxFriendByPublicKeyException.Code.NOT_FOUND) { tox =>
      tox.friendByPublicKey(tox.getPublicKey)
    }
  }

} 
Example 103
Source File: ToxFileSendExceptionTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.exceptions

import im.tox.tox4j.core.ToxCoreConstants
import im.tox.tox4j.core.data.{ ToxFriendNumber, ToxFileId, ToxFilename }
import im.tox.tox4j.core.enums.ToxFileKind
import im.tox.tox4j.testing.ToxTestMixin
import org.scalatest.FunSuite

final class ToxFileSendExceptionTest extends FunSuite with ToxTestMixin {

  private val friendNumber = ToxFriendNumber.fromInt(0).get
  private val badFriendNumber = ToxFriendNumber.fromInt(1).get

  test("FileSendNotConnected") {
    interceptWithTox(ToxFileSendException.Code.FRIEND_NOT_CONNECTED)(
      _.fileSend(friendNumber, ToxFileKind.DATA, 123, ToxFileId.empty,
        ToxFilename.fromString("filename").toOption.get)
    )
  }

  test("FileSendNotFound") {
    interceptWithTox(ToxFileSendException.Code.FRIEND_NOT_FOUND)(
      _.fileSend(badFriendNumber, ToxFileKind.DATA, 123, ToxFileId.empty,
        ToxFilename.fromString("filename").toOption.get)
    )
  }

  test("FileSendNameTooLong") {
    interceptWithTox(ToxFileSendException.Code.NAME_TOO_LONG)(
      _.fileSend(friendNumber, ToxFileKind.DATA, 123, ToxFileId.empty,
        ToxFilename.unsafeFromValue(Array.ofDim(ToxCoreConstants.MaxFilenameLength + 1)))
    )
  }

} 
Example 104
Source File: ToxCoreEventAdapterTest.scala    From jvm-toxcore-c   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.callbacks

import im.tox.tox4j.core.data._
import im.tox.tox4j.core.enums._
import im.tox.tox4j.core.proto._
import im.tox.tox4j.testing.GetDisjunction._
import org.scalatest.FunSuite

final class ToxCoreEventAdapterTest extends FunSuite {

  private val listener = new ToxCoreEventAdapter[Unit]
  private val friendNumber = ToxFriendNumber.fromInt(0).get

  def test[T](f: => Unit)(implicit evidence: Manifest[T]): Unit = {
    test(evidence.runtimeClass.getSimpleName)(f)
  }

  test[SelfConnectionStatus] {
    listener.selfConnectionStatus(ToxConnection.NONE)(())
  }

  test[FileRecvControl] {
    listener.fileRecvControl(friendNumber, 0, ToxFileControl.RESUME)(())
  }

  test[FileRecv] {
    listener.fileRecv(friendNumber, 0, ToxFileKind.DATA, 0, ToxFilename.fromString("").toOption.get)(())
  }

  test[FileRecvChunk] {
    listener.fileRecvChunk(friendNumber, 0, 0, Array.empty)(())
  }

  test[FileChunkRequest] {
    listener.fileChunkRequest(friendNumber, 0, 0, 0)(())
  }

  test[FriendConnectionStatus] {
    listener.friendConnectionStatus(friendNumber, ToxConnection.NONE)(())
  }

  test[FriendMessage] {
    listener.friendMessage(friendNumber, ToxMessageType.NORMAL, 0, ToxFriendMessage.fromString("hello").toOption.get)(())
  }

  test[FriendName] {
    listener.friendName(friendNumber, ToxNickname.fromString("").toOption.get)(())
  }

  test[FriendRequest] {
    listener.friendRequest(
      ToxPublicKey.unsafeFromValue(null),
      0,
      ToxFriendRequestMessage.fromString("").toOption.get
    )(())
  }

  test[FriendStatus] {
    listener.friendStatus(friendNumber, ToxUserStatus.NONE)(())
  }

  test[FriendStatusMessage] {
    listener.friendStatusMessage(friendNumber, ToxStatusMessage.fromString("").toOption.get)(())
  }

  test[FriendTyping] {
    listener.friendTyping(friendNumber, isTyping = false)(())
  }

  test[FriendLosslessPacket] {
    listener.friendLosslessPacket(friendNumber, ToxLosslessPacket.fromByteArray(160, Array.empty).toOption.get)(())
  }

  test[FriendLossyPacket] {
    listener.friendLossyPacket(friendNumber, ToxLossyPacket.fromByteArray(200, Array.empty).toOption.get)(())
  }

  test[FriendReadReceipt] {
    listener.friendReadReceipt(friendNumber, 0)(())
  }

} 
Example 105
Source File: TestSparkContext.scala    From spark-images   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.image

import org.apache.spark.sql.{DataFrame, Row, SQLContext, SparkSession}
import org.apache.spark.sql.types._
import org.apache.spark.{SparkConf, SparkContext}

import scala.reflect.runtime.universe._
import org.scalatest.{FunSuite, BeforeAndAfterAll}

// This context is used for all tests in this project
trait TestSparkContext extends BeforeAndAfterAll { self: FunSuite =>
  @transient var sc: SparkContext = _
  @transient var sqlContext: SQLContext = _
  @transient lazy val spark: SparkSession = {
    val conf = new SparkConf()
      .setMaster("local[*]")
      .setAppName("Spark-Image-Test")
      .set("spark.ui.port", "4079")
      .set("spark.sql.shuffle.partitions", "4")  // makes small tests much faster

    val sess = SparkSession.builder().config(conf).getOrCreate()
    sess.sparkContext.setLogLevel("WARN")
    sess
  }

  override def beforeAll() {
    super.beforeAll()
    sc = spark.sparkContext
    sqlContext = spark.sqlContext
    import spark.implicits._
  }

  override def afterAll() {
    sqlContext = null
    if (sc != null) {
      sc.stop()
    }
    sc = null
    super.afterAll()
  }

  def makeDF[T: TypeTag](xs: Seq[T], col: String): DataFrame = {
    sqlContext.createDataFrame(xs.map(Tuple1.apply)).toDF(col)
  }

  def compareRows(r1: Array[Row], r2: Seq[Row]): Unit = {
    val a = r1.sortBy(_.toString())
    val b = r2.sortBy(_.toString())
    assert(a === b)
  }
} 
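A suite mixing in this trait gets the shared SparkSession plus the makeDF and compareRows helpers. The sketch below is illustrative only (the class name and test data are made up) and shows how those helpers fit together:

package org.apache.spark.image

import org.apache.spark.sql.Row
import org.scalatest.FunSuite

// Illustrative usage of TestSparkContext: build a one-column DataFrame and compare rows
// ignoring their order.
class TestSparkContextUsageSuite extends FunSuite with TestSparkContext {
  test("makeDF builds a single-column DataFrame") {
    val df = makeDF(Seq(1, 2, 3), "value")
    assert(df.columns === Array("value"))
    compareRows(df.collect(), Seq(Row(3), Row(1), Row(2)))
  }
}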
Example 106
Source File: TestTwitterBase.scala    From kafka-tweet-producer   with Apache License 2.0 5 votes vote down vote up
package com.eneco.trading.kafka.connect.twitter

import org.scalatest.{FunSuite, Matchers, BeforeAndAfter}


trait TestTwitterBase extends FunSuite with Matchers with BeforeAndAfter {
  def getConfig = {
    Map(TwitterSourceConfig.CONSUMER_KEY_CONFIG->"test",
      TwitterSourceConfig.CONSUMER_SECRET_CONFIG->"c-secret",
      TwitterSourceConfig.SECRET_CONFIG->"secret",
      TwitterSourceConfig.TOKEN_CONFIG->"token",
      TwitterSourceConfig.TRACK_TERMS->"term1",
      TwitterSourceConfig.TWITTER_APP_NAME->"myApp",
      TwitterSourceConfig.BATCH_SIZE->"1337",
      TwitterSourceConfig.TOPIC->"just-a-topic"
    )
  }
  def getSinkConfig = {
    Map(TwitterSinkConfig.CONSUMER_KEY_CONFIG->"test",
      TwitterSinkConfig.CONSUMER_SECRET_CONFIG->"c-secret",
      TwitterSinkConfig.SECRET_CONFIG->"secret",
      TwitterSinkConfig.TOKEN_CONFIG->"token",
      TwitterSinkConfig.TOPICS->"just-a-sink-topic"
    )
  }
} 
Example 107
Source File: StringUtilsSuite.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.transformers.util

import org.scalatest.FunSuite

class StringUtilsSuite extends FunSuite {
  private def testSnakeConversion(name: String, input: String, expected: String): Unit = {
    test(name) {
      assert(StringUtils.toSnakeCase(input) == expected)
    }
  }

  testSnakeConversion(
    "doesn't change lower case string",
    "monkey",
    "monkey"
  )

  testSnakeConversion(
    "doesn't change lower case string with underscores",
    "mon_key",
    "mon_key"
  )

  testSnakeConversion(
    "simple camel to snake case",
    "monKey",
    "mon_key"
  )

  testSnakeConversion(
    "upper camel to snake",
    "MonKey",
    "mon_key"
  )

  testSnakeConversion(
    "mixed",
    "MonKe_y",
    "mon_ke_y"
  )

  test("SnakeCaseMap") {
    val m = new SnakeCaseMap(
      Map(
        "AniMal" -> "MonKey"
      ))
    assert(m("AniMal") == "MonKey")
    assert(m("aniMal") == "MonKey")
    assert(m("ani_mal") == "MonKey")
    assert(!m.contains("animal"))
  }

  test("SnakeCaseMap (add / subtract)") {
    val base = new SnakeCaseMap(
      Map(
        "AniMal" -> "MonKey",
        "vegeTable" -> "carrot"
      ))
    val added = base + ("kEy" -> "value")
    assert(added("ani_mal") == "MonKey")
    assert(added("k_ey") == "value")

    val subtracted = base - "vege_table"
    assert(subtracted("ani_mal") == "MonKey")
    assert(subtracted.size == 1)
  }
} 
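The tests above pin down the expected camel-to-snake behaviour without showing the conversion itself. Below is a minimal sketch of a function with that behaviour; it is an assumed stand-in for illustration, not the project's actual StringUtils.toSnakeCase:

object SnakeCaseSketch {
  // Lower-cases every character and inserts an underscore before each upper-case letter
  // that is not at the start of the string, matching the suite's expectations:
  // "monkey" -> "monkey", "monKey" -> "mon_key", "MonKey" -> "mon_key", "MonKe_y" -> "mon_ke_y".
  def toSnakeCase(s: String): String =
    s.zipWithIndex.map {
      case (c, i) if c.isUpper && i > 0 => "_" + c.toLower
      case (c, _)                       => c.toLower.toString
    }.mkString
}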
Example 108
Source File: GlowBaseTest.scala    From glow   with Apache License 2.0 5 votes vote down vote up
package io.projectglow.sql

import htsjdk.samtools.util.Log
import org.apache.spark.sql.SparkSession
import org.apache.spark.{DebugFilesystem, SparkConf}
import org.scalatest.concurrent.{AbstractPatienceConfiguration, Eventually}
import org.scalatest.time.{Milliseconds, Seconds, Span}
import org.scalatest.{Args, FunSuite, Status, Tag}

import io.projectglow.Glow
import io.projectglow.SparkTestShim.SharedSparkSessionBase
import io.projectglow.common.{GlowLogging, TestUtils}
import io.projectglow.sql.util.BGZFCodec

abstract class GlowBaseTest
    extends FunSuite
    with SharedSparkSessionBase
    with GlowLogging
    with GlowTestData
    with TestUtils
    with JenkinsTestPatience {

  override protected def sparkConf: SparkConf = {
    super
      .sparkConf
      .set("spark.hadoop.io.compression.codecs", classOf[BGZFCodec].getCanonicalName)
  }

  override def initializeSession(): Unit = ()

  override protected implicit def spark: SparkSession = {
    val sess = SparkSession.builder().config(sparkConf).master("local[2]").getOrCreate()
    Glow.register(sess)
    SparkSession.setActiveSession(sess)
    Log.setGlobalLogLevel(Log.LogLevel.ERROR)
    sess
  }

  protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])(
      testFun: A => Unit): Unit = {
    for (param <- params) {
      test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param))
    }
  }

  override def afterEach(): Unit = {
    DebugFilesystem.assertNoOpenStreams()
    eventually {
      assert(spark.sparkContext.getPersistentRDDs.isEmpty)
      assert(spark.sharedState.cacheManager.isEmpty, "Cache not empty.")
    }
    super.afterEach()
  }

  override def runTest(testName: String, args: Args): Status = {
    logger.info(s"Running test '$testName'")
    val res = super.runTest(testName, args)
    if (res.succeeds()) {
      logger.info(s"Done running test '$testName'")
    } else {
      logger.info(s"Done running test '$testName' with a failure")
    }
    res
  }

  protected def withSparkConf[T](configs: Map[String, String])(f: => T): T = {
    val initialConfigValues = configs.keys.map(k => (k, spark.conf.getOption(k)))
    try {
      configs.foreach { case (k, v) => spark.conf.set(k, v) }
      f
    } finally {
      initialConfigValues.foreach {
        case (k, Some(v)) => spark.conf.set(k, v)
        case (k, None) => spark.conf.unset(k)
      }
    }
  }
}

// Note: the declaration of this trait was missing from the listing; it is reconstructed here
// from context (GlowBaseTest mixes in JenkinsTestPatience, and the imports above bring in
// AbstractPatienceConfiguration and Eventually).
trait JenkinsTestPatience extends AbstractPatienceConfiguration with Eventually {

  final override implicit val patienceConfig: PatienceConfig =
    if (sys.env.get("JENKINS_HOST").nonEmpty) {
      // increase the timeout on Jenkins, where parallelizing causes things to be very slow
      PatienceConfig(Span(10, Seconds), Span(50, Milliseconds))
    } else {
      // use the default timeout on local machines so failures don't hang for a long time
      PatienceConfig(Span(5, Seconds), Span(15, Milliseconds))
    }
}
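Two helpers in GlowBaseTest are easy to miss: gridTest registers one FunSuite test per parameter, and withSparkConf temporarily overrides Spark SQL settings around a block, restoring the previous values afterwards. A hedged usage sketch from a subclass (the class name, parameter values and config key are made up for the example):

package io.projectglow.sql

// Illustrative subclass of GlowBaseTest; not part of the project.
class GlowBaseTestUsageSuite extends GlowBaseTest {

  // Registers three tests: "handles codec (none)", "handles codec (gzip)", "handles codec (bgzf)".
  gridTest("handles codec")(Seq("none", "gzip", "bgzf")) { codec =>
    assert(codec.nonEmpty)
  }

  test("runs with a temporary Spark conf override") {
    withSparkConf(Map("spark.sql.shuffle.partitions" -> "1")) {
      assert(spark.conf.get("spark.sql.shuffle.partitions") == "1")
    }
  }
}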
Example 109
Source File: SHC.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import java.io.File

import com.google.common.io.Files
import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, TableName, HBaseTestingUtility}
import org.apache.hadoop.hbase.client.{Scan, Put, ConnectionFactory, Table}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.execution.datasources.hbase.SparkHBaseConf
import org.apache.spark.sql.types.UTF8String
import org.apache.spark.{SparkContext, SparkConf, Logging}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
import scala.collection.JavaConverters._

class SHC extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging {
  implicit class StringToColumn(val sc: StringContext) {
    def $(args: Any*): ColumnName = {
      new ColumnName(sc.s(args: _*))
    }
  }


  private[spark] var htu = HBaseTestingUtility.createLocalHTU()
  private[spark] def tableName = "table1"

  private[spark] def columnFamilies: Array[String] = Array.tabulate(9){ x=> s"cf$x"}
  var table: Table = null
  val conf = new SparkConf
  conf.set(SparkHBaseConf.testConf, "true")
  SparkHBaseConf.conf = htu.getConfiguration
  // private[spark] var columnFamilyStr = Bytes.toString(columnFamily)

  def catalog = s"""{
            |"table":{"namespace":"default", "name":"table1"},
            |"rowkey":"key",
            |"columns":{
              |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
              |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
              |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
              |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
              |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
              |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
              |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
              |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
              |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
            |}
          |}""".stripMargin

  override def beforeAll() {
    val tempDir: File = Files.createTempDir
    tempDir.deleteOnExit
    htu.cleanupTestDir
    htu.startMiniZKCluster
    htu.startMiniHBaseCluster(1, 4)
    logInfo(" - minicluster started")
    println(" - minicluster started")

  }

  override def afterAll() {
    try {
      table.close()
      println("shutdown")
      htu.deleteTable(TableName.valueOf(tableName))
      logInfo("shuting down minicluster")
      htu.shutdownMiniHBaseCluster
      htu.shutdownMiniZKCluster
      logInfo(" - minicluster shut down")
      htu.cleanupTestDir
    } catch {
      case _: Throwable => logError("teardown error")
    }
  }

  def createTable(name: String, cfs: Array[String]) {
    val tName = Bytes.toBytes(name)
    val bcfs = cfs.map(Bytes.toBytes(_))
    try {
      htu.deleteTable(TableName.valueOf(tName))
    } catch {
      case _: Throwable =>
        logInfo(" - no table " + name + " found")
    }
    htu.createMultiRegionTable(TableName.valueOf(tName), bcfs)
  }


  def createTable(name: Array[Byte], cfs: Array[Array[Byte]]) {
    try {
      htu.deleteTable(TableName.valueOf(name))
    } catch {
      case _: Throwable =>
        logInfo(" - no table " + Bytes.toString(name) + " found")
    }
    htu.createMultiRegionTable(TableName.valueOf(name), cfs)
  }
} 
Example 110
Source File: CatalogSuite.scala    From shc   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.apache.spark.Logging
import org.apache.spark.sql.execution.datasources.hbase.HBaseTableCatalog
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}

class CatalogSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging {
  def catalog = s"""{
            |"table":{"namespace":"default", "name":"table1"},
            |"rowkey":"key1:key2",
            |"columns":{
              |"col00":{"cf":"rowkey", "col":"key1", "type":"string", "length":"6"},
              |"col01":{"cf":"rowkey", "col":"key2", "type":"int"},
              |"col1":{"cf":"cf1", "col":"col1", "type":"boolean"},
              |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
              |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
              |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
              |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"},
              |"col6":{"cf":"cf6", "col":"col6", "type":"smallint"},
              |"col7":{"cf":"cf7", "col":"col7", "type":"string"},
              |"col8":{"cf":"cf8", "col":"col8", "type":"tinyint"}
            |}
          |}""".stripMargin

  test("Catalog meta data check") {
    val m = HBaseTableCatalog(Map(HBaseTableCatalog.tableCatalog -> catalog))
    assert(m.row.varLength == false)
    assert(m.row.length == 10)
  }

} 
Example 111
Source File: MVMSuite.scala    From zen   with Apache License 2.0 5 votes vote down vote up
package com.github.cloudml.zen.ml.recommendation

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, sum => brzSum}
import com.github.cloudml.zen.ml.util._
import com.google.common.io.Files
import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.scalatest.{FunSuite, Matchers}

class MVMSuite extends FunSuite with SharedSparkContext with Matchers {
  test("binary classification") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val dataSetFile = classOf[MVMSuite].getClassLoader().getResource("binary_classification_data.txt").toString()
    val checkpoint = s"$sparkHome/target/tmp"
    sc.setCheckpointDir(checkpoint)
    val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map {
      case (LabeledPoint(label, features), id) =>
        val newLabel = if (label > 0.0) 1.0 else 0.0
        (id, LabeledPoint(newLabel, features))
    }
    val stepSize = 0.1
    val regParam = 1e-2
    val l2 = (regParam, regParam, regParam)
    val rank = 20
    val useAdaGrad = true
    val trainSet = dataSet.cache()
    val fm = new FMClassification(trainSet, stepSize, l2, rank, useAdaGrad)

    val maxIter = 10
    val pps = new Array[Double](maxIter)
    var i = 0
    val startedAt = System.currentTimeMillis()
    while (i < maxIter) {
      fm.run(1)
      pps(i) = fm.saveModel().loss(trainSet)
      i += 1
    }
    println((System.currentTimeMillis() - startedAt) / 1e3)
    pps.foreach(println)

    val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs }
    assert(ppsDiff.count(_ < 0).toDouble / ppsDiff.size > 0.05)

    val fmModel = fm.saveModel()
    val tempDir = Files.createTempDir()
    tempDir.deleteOnExit()
    val path = tempDir.toURI.toString
    fmModel.save(sc, path)
    val sameModel = FMModel.load(sc, path)
    assert(sameModel.k === fmModel.k)
    assert(sameModel.classification === fmModel.classification)
    assert(sameModel.factors.sortByKey().map(_._2).collect() ===
      fmModel.factors.sortByKey().map(_._2).collect())
  }

  ignore("url_combined classification") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val dataSetFile = classOf[MVMSuite].getClassLoader().getResource("binary_classification_data.txt").toString()
    val checkpointDir = s"$sparkHome/target/tmp"
    sc.setCheckpointDir(checkpointDir)
    val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile).zipWithIndex().map {
      case (LabeledPoint(label, features), id) =>
        val newLabel = if (label > 0.0) 1.0 else 0.0
        (id, LabeledPoint(newLabel, features))
    }.cache()
    val numFeatures = dataSet.first()._2.features.size
    val stepSize = 0.1
    val numIterations = 500
    val regParam = 1e-3
    val rank = 20
    val views = Array(20, numFeatures / 2, numFeatures).map(_.toLong)
    val useAdaGrad = true
    val useWeightedLambda = true
    val miniBatchFraction = 1
    val Array(trainSet, testSet) = dataSet.randomSplit(Array(0.8, 0.2))
    trainSet.cache()
    testSet.cache()

    val fm = new MVMClassification(trainSet, stepSize, views, regParam, 0.0, rank,
      useAdaGrad, useWeightedLambda, miniBatchFraction)
    fm.run(numIterations)
    val model = fm.saveModel()
    println(f"Test loss: ${model.loss(testSet.cache())}%1.4f")

  }

} 
Example 112
Source File: MLPSuite.scala    From zen   with Apache License 2.0 5 votes vote down vote up
package com.github.cloudml.zen.ml.neuralNetwork


import com.github.cloudml.zen.ml.util.{Utils, SparkUtils, MnistDatasetSuite}
import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.linalg.{Vector => SV}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.scalatest.{FunSuite, Matchers}

class MLPSuite extends FunSuite with MnistDatasetSuite with Matchers {
  ignore("MLP") {
    val (data, numVisible) = mnistTrainDataset(5000)
    val topology = Array(numVisible, 500, 10)
    val nn = MLP.train(data, 20, 1000, topology, fraction = 0.02,
      learningRate = 0.1, weightCost = 0.0)

    // val nn = MLP.runLBFGS(data, topology, 100, 4000, 1e-5, 0.001)
    // MLP.runSGD(data, nn, 37, 6000, 0.1, 0.5, 0.0)

    val (dataTest, _) = mnistTrainDataset(10000, 5000)
    println("Error: " + MLP.error(dataTest, nn, 100))
  }

  ignore("binary classification") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val dataSetFile = s"$sparkHome/data/a5a"
    val checkpoint = s"$sparkHome/target/tmp"
    sc.setCheckpointDir(checkpoint)
    val data = MLUtils.loadLibSVMFile(sc, dataSetFile).map {
      case LabeledPoint(label, features) =>
        val y = BDV.zeros[Double](2)
        y := 0.04 / y.length
        y(if (label > 0) 0 else 1) += 0.96
        (features, SparkUtils.fromBreeze(y))
    }.persist()
    val trainSet = data.filter(_._1.hashCode().abs % 5 == 3).persist()
    val testSet = data.filter(_._1.hashCode().abs % 5 != 3).persist()

    val numVisible = trainSet.first()._1.size
    val topology = Array(numVisible, 30, 2)
    var nn = MLP.train(trainSet, 100, 1000, topology, fraction = 0.02,
      learningRate = 0.05, weightCost = 0.0)

    val modelPath = s"$checkpoint/model"
    nn.save(sc, modelPath)
    nn = MLP.load(sc, modelPath)
    val scoreAndLabels = testSet.map { case (features, label) =>
      val out = nn.predict(SparkUtils.toBreeze(features).toDenseVector.asDenseMatrix.t)
      // Utils.random.nextInt(2).toDouble
      (out(0, 0), if (label(0) > 0.5) 1.0 else 0.0)
    }.persist()
    scoreAndLabels.repartition(1).map(t => s"${t._1}\t${t._2}").
      saveAsTextFile(s"$checkpoint/mlp/${System.currentTimeMillis()}")
    val testAccuracy = new BinaryClassificationMetrics(scoreAndLabels).areaUnderROC()
    println(f"Test AUC = $testAccuracy%1.6f")

  }

} 
Example 113
Source File: DBNSuite.scala    From zen   with Apache License 2.0 5 votes vote down vote up
package com.github.cloudml.zen.ml.neuralNetwork

import com.github.cloudml.zen.ml.util.MnistDatasetSuite
import org.scalatest.{FunSuite, Matchers}

class DBNSuite extends FunSuite with MnistDatasetSuite with Matchers {

  ignore("DBN") {
    val (data, numVisible) = mnistTrainDataset(2500)
    val dbn = new DBN(Array(numVisible, 500, 10))
    DBN.pretrain(data, 100, 1000, dbn, 0.1, 0.05, 0.0)
    DBN.finetune(data, 100, 1000, dbn, 0.02, 0.05, 0.0)
    val (dataTest, _) = mnistTrainDataset(5000, 2500)
    println("Error: " + MLP.error(dataTest, dbn.mlp, 100))
  }

} 
Example 114
Source File: RBMSuite.scala    From zen   with Apache License 2.0 5 votes vote down vote up
package com.github.cloudml.zen.ml.neuralNetwork

import com.github.cloudml.zen.ml.util.MnistDatasetSuite
import org.scalatest.{FunSuite, Matchers}

class RBMSuite extends FunSuite with MnistDatasetSuite with Matchers {

  ignore("RBM") {
    val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
    val checkpoint = s"$sparkHome/target/tmp/rmb/${System.currentTimeMillis()}"
    sc.setCheckpointDir(checkpoint)
    val (data, numVisible) = mnistTrainDataset(2500)
    val rbm = RBM.train(data.map(_._1), 100, 1000, numVisible, 256, 0.1, 0.05, 0.0)
    val modelPath = s"$checkpoint/model"
    rbm.save(sc, modelPath)
    val newRBM = RBM.load(sc, modelPath)
    assert(rbm.equals(newRBM))
  }

} 
Example 115
Source File: LogisticRegressionSuite.scala    From zen   with Apache License 2.0 5 votes vote down vote up
package com.github.cloudml.zen.ml.regression

import com.github.cloudml.zen.ml.util._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.scalatest.{Matchers, FunSuite}
import com.github.cloudml.zen.ml.util.SparkUtils._


class LogisticRegressionSuite extends FunSuite with SharedSparkContext with Matchers {

  test("LogisticRegression MIS") {
    val zenHome = sys.props.getOrElse("zen.test.home", fail("zen.test.home is not set!"))
    val dataSetFile = classOf[LogisticRegressionSuite].getClassLoader().getResource("binary_classification_data.txt").toString()
    val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile)
    val max = dataSet.map(_.features.activeValuesIterator.map(_.abs).sum + 1L).max

    val maxIter = 10
    val stepSize = 1 / (2 * max)
    val trainDataSet = dataSet.zipWithUniqueId().map { case (LabeledPoint(label, features), id) =>
      val newLabel = if (label > 0.0) 1.0 else -1.0
      (id, LabeledPoint(newLabel, features))
    }
    val lr = new LogisticRegressionMIS(trainDataSet, stepSize)
    val pps = new Array[Double](maxIter)
    var i = 0
    val startedAt = System.currentTimeMillis()
    while (i < maxIter) {
      lr.run(1)
      val q = lr.forward(i)
      pps(i) = lr.loss(q)
      i += 1
    }
    println((System.currentTimeMillis() - startedAt) / 1e3)
    pps.foreach(println)

    val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs }
    assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.05)
    assert(pps.head - pps.last > 0)
  }

  test("LogisticRegression SGD") {
    val zenHome = sys.props.getOrElse("zen.test.home", fail("zen.test.home is not set!"))
    val dataSetFile = classOf[LogisticRegressionSuite].getClassLoader().getResource("binary_classification_data.txt").toString()
    val dataSet = MLUtils.loadLibSVMFile(sc, dataSetFile)
    val maxIter = 10
    val stepSize = 1
    val trainDataSet = dataSet.zipWithIndex().map { case (LabeledPoint(label, features), id) =>
      val newLabel = if (label > 0.0) 1.0 else 0
      (id, LabeledPoint(newLabel, features))
    }
    val lr = new LogisticRegressionSGD(trainDataSet, stepSize)
    val pps = new Array[Double](maxIter)
    var i = 0
    val startedAt = System.currentTimeMillis()
    while (i < maxIter) {
      lr.run(1)
      val margin = lr.forward(i)
      pps(i) = lr.loss(margin)
      i += 1
    }
    println((System.currentTimeMillis() - startedAt) / 1e3)
    pps.foreach(println)

    val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs }
    assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.05)
    assert(pps.head - pps.last > 0)
  }
} 
Example 116
Source File: NewtsSuite.scala    From newts   with Apache License 2.0 5 votes vote down vote up
package newts

import cats.instances.AllInstances
import newts.syntax.AllSyntax
import org.scalacheck.{Arbitrary, Cogen}
import org.scalacheck.Arbitrary.arbitrary
import org.scalatest.prop.GeneratorDrivenPropertyChecks
import org.scalatest.{FunSuite, Matchers}
import org.typelevel.discipline.scalatest.Discipline

trait NewtsSuite extends FunSuite
  with Matchers
  with GeneratorDrivenPropertyChecks
  with Discipline
  with AllSyntax
  with AllInstances
  with cats.syntax.AllSyntax
  with ArbitraryInstances

trait ArbitraryInstances {
  def arbNewtype[S, A: Arbitrary](implicit newtype: Newtype.Aux[S, A]): Arbitrary[S] =
    Arbitrary(arbitrary[A].map(newtype.wrap))

  def cogenNewtype[S, A: Cogen](implicit newtype: Newtype.Aux[S, A]): Cogen[S] =
    Cogen[A].contramap(newtype.unwrap)

  implicit val allArbitrary: Arbitrary[All] = arbNewtype[All, Boolean]
  implicit val anyArbitrary: Arbitrary[Any] = arbNewtype[Any, Boolean]
  implicit def multArbitrary[A:Arbitrary]: Arbitrary[Mult[A]] = arbNewtype[Mult[A], A]
  implicit def dualArbitrary[A: Arbitrary]: Arbitrary[Dual[A]] = arbNewtype[Dual[A], A]
  implicit def firstArbitrary[A: Arbitrary]: Arbitrary[First[A]] = arbNewtype[First[A], A]
  implicit def lastArbitrary[A: Arbitrary]: Arbitrary[Last[A]] = arbNewtype[Last[A], A]
  implicit def firstOptionArbitrary[A: Arbitrary]: Arbitrary[FirstOption[A]] = arbNewtype[FirstOption[A], Option[A]]
  implicit def lastOptionArbitrary[A: Arbitrary]: Arbitrary[LastOption[A]]  = arbNewtype[LastOption[A], Option[A]]
  implicit def minArbitrary[A: Arbitrary]: Arbitrary[Min[A]]  = arbNewtype[Min[A], A]
  implicit def maxArbitrary[A: Arbitrary]: Arbitrary[Max[A]]  = arbNewtype[Max[A], A]
  implicit def zipListArbitrary[A: Arbitrary]: Arbitrary[ZipList[A]] = arbNewtype[ZipList[A], List[A]]
  implicit def backwardsArbitrary[F[_], A](implicit ev: Arbitrary[F[A]]): Arbitrary[Backwards[F, A]] = arbNewtype[Backwards[F, A], F[A]]
  implicit def reverseArbitrary[F[_], A](implicit ev: Arbitrary[F[A]]): Arbitrary[Reverse[F, A]] = arbNewtype[Reverse[F, A], F[A]]

  implicit val allCogen: Cogen[All] = cogenNewtype[All, Boolean]
  implicit val anyCogen: Cogen[Any] = cogenNewtype[Any, Boolean]
  implicit def multCogen[A: Cogen]: Cogen[Mult[A]] = cogenNewtype[Mult[A], A]
  implicit def dualCogen[A: Cogen]: Cogen[Dual[A]] = cogenNewtype[Dual[A], A]
  implicit def firstCogen[A: Cogen]: Cogen[First[A]] = cogenNewtype[First[A], A]
  implicit def lastCogen[A: Cogen]: Cogen[Last[A]] = cogenNewtype[Last[A], A]
  implicit def firstOptionCogen[A: Cogen]: Cogen[FirstOption[A]] = cogenNewtype[FirstOption[A], Option[A]]
  implicit def lastOptionCogen[A: Cogen] : Cogen[LastOption[A]]  = cogenNewtype[LastOption[A], Option[A]]
  implicit def minOptionCogen[A: Cogen] : Cogen[Min[A]]  = cogenNewtype[Min[A], A]
  implicit def maxOptionCogen[A: Cogen] : Cogen[Max[A]]  = cogenNewtype[Max[A], A]
  implicit def zipListCogen[A: Cogen]: Cogen[ZipList[A]] = cogenNewtype[ZipList[A], List[A]]
  implicit def backwardsCogen[F[_], A](implicit ev: Cogen[F[A]]): Cogen[Backwards[F, A]] = cogenNewtype[Backwards[F, A], F[A]]
  implicit def reverseCogen[F[_], A](implicit ev: Cogen[F[A]]): Cogen[Reverse[F, A]] = cogenNewtype[Reverse[F, A], F[A]]
} 
Example 117
Source File: ImmutableCollectionsUnitTest.scala    From scala-tutorials   with MIT License 5 votes vote down vote up
package com.baeldung.scala.mutability

import org.scalatest.FunSuite

class ImmutableCollectionsUnitTest extends FunSuite {
  test("Immutable collections will create new instance if we add or update the elements") {
    val pets = Seq("Cat", "Dog")
    val myPets = pets :+ "Hamster"
    val notPets = pets ++ List("Giraffe", "Elephant")
    val yourPets = pets.updated(0, "Mice")

    assert(pets == Seq("Cat", "Dog"))
    assert(myPets == Seq("Cat", "Dog", "Hamster"))
    assert(notPets == Seq("Cat", "Dog", "Giraffe", "Elephant"))
    assert(yourPets == Seq("Mice", "Dog"))
  }
} 
Example 118
Source File: ImmutabilityCarUnitTest.scala    From scala-tutorials   with MIT License 5 votes vote down vote up
package com.baeldung.scala.mutability

import org.scalatest.FunSuite

class ImmutabilityCarUnitTest extends FunSuite {
  test("Mutable vs Immutable variables") {
    val pi = 3.14
    // pi = 4 // Compile error: Reassignment to val

    var myWeight = 60
    assert(myWeight == 60)
    myWeight = 65
    assert(myWeight == 65)
  }

  test("Immutable car cannot be changed") {
    val myCar = new ImmutabilityCar("blue", 4, "diesel")
    myCar.call()

    myCar.engine = "electric"
    assert(myCar.engine == "electric")
    myCar.call()
  }
} 
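The ImmutabilityCar class exercised above is defined elsewhere in the tutorial. A sketch consistent with how the test uses it (a constructor taking colour, wheel count and engine, a reassignable engine field, and a call() that reports the state); the parameter names and the printed message are assumptions:

package com.baeldung.scala.mutability

// Assumed shape of the class under test: color and wheels stay immutable,
// while engine is deliberately a var so the test can reassign it.
class ImmutabilityCar(val color: String, val wheels: Int, var engine: String) {
  def call(): Unit =
    println(s"I am a $color car with $wheels wheels and a $engine engine.")
}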
Example 119
Source File: MutableCollectionsUnitTest.scala    From scala-tutorials   with MIT License 5 votes vote down vote up
package com.baeldung.scala.mutability

import org.scalatest.FunSuite
import scala.collection.mutable.ArrayBuffer

class MutableCollectionsUnitTest extends FunSuite {
  test("Mutable collection can be added with new elements") {
    val breakfasts = ArrayBuffer("Sandwich", "Salad")

    breakfasts += "Bagels"
    assert(breakfasts == ArrayBuffer("Sandwich", "Salad", "Bagels"))

    breakfasts ++= Seq("PB & J", "Pancake")
    assert(breakfasts == ArrayBuffer("Sandwich", "Salad", "Bagels", "PB & J", "Pancake"))
  }

  test("Mutable collection's elements can be updated") {
    val breakfasts = ArrayBuffer("Sandwich", "Salad", "Bagels", "PB & J", "Pancake")
    breakfasts.update(2, "Steak")
    assert(breakfasts == ArrayBuffer("Sandwich", "Salad", "Steak", "PB & J", "Pancake"))
  }

  test("Mutable collection elements can be removed") {
    val breakfasts = ArrayBuffer("Sandwich", "Salad", "Steak", "PB & J", "Pancake")

    breakfasts -= "PB & J"
    assert(breakfasts == ArrayBuffer("Sandwich", "Salad", "Steak", "Pancake"))

    breakfasts -= "Fried rice"
    assert(breakfasts == ArrayBuffer("Sandwich", "Salad", "Steak", "Pancake"))
  }

  test("Array can be updated but not added") {
    val lunches = Array("Pasta", "Rice", "Hamburger")

    lunches.update(0, "Noodles")
    assert(lunches sameElements Array("Noodles", "Rice", "Hamburger"))
  }
} 
Example 120
Source File: ListFunSuite.scala    From scala-tutorials   with MIT License 5 votes vote down vote up
package com.baeldung.scala.scalatest

import org.scalatest.FunSuite

class ListFunSuite extends FunSuite {

  test("An empty List should have size 0") {
    assert(List.empty.size == 0)
  }

  test("Accessing invalid index should throw IndexOutOfBoundsException") {
    val fruit = List("Banana", "Pineapple", "Apple")
    assert(fruit.head == "Banana")
    assertThrows[IndexOutOfBoundsException] {
      fruit(5)
    }
  }

} 
Example 121
Source File: LazyValUnitTest.scala    From scala-tutorials   with MIT License 5 votes vote down vote up
package com.baeldung.scala.lazyval

import org.scalatest.FunSuite
import org.scalatest.Matchers._

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.{Await, Future}

class LazyValUnitTest extends FunSuite {

  test("lazy val is computed only once") {
    //given
    val lazyVal = new LazyVal
    lazyVal.getMemberNo //initialize the lazy val
    lazyVal.age shouldBe 28

    //when
    lazyVal.getMemberNo

    //then
    lazyVal.age shouldBe 28
  }

  test("lazy vals should execute sequentially in an instance ") {
    //given
    val futures = Future.sequence(Seq(
      Future {
        LazyValStore.squareOf5
      },
      Future {
        LazyValStore.squareOf6
      }))

    //when
    val result = Await.result(futures, 5.second)

    //then
    result should contain(25)
    result should contain(36)
  }
} 
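LazyVal and LazyValStore are defined elsewhere in the tutorial. A minimal sketch consistent with what the test asserts (accessing getMemberNo runs a lazy initializer once, which also sets age to 28, and the store exposes two lazily computed squares); the member number and the sleep are assumptions made for illustration:

package com.baeldung.scala.lazyval

// Assumed sketch: the lazy initializer runs exactly once, setting age as a side effect.
class LazyVal {
  var age: Int = 0
  lazy val getMemberNo: Int = {
    age = 28         // side effect the test observes
    Thread.sleep(10) // stand-in for an expensive computation
    1234             // assumed member number
  }
}

object LazyValStore {
  lazy val squareOf5: Int = 5 * 5 // 25
  lazy val squareOf6: Int = 6 * 6 // 36
}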
Example 122
Source File: ColumnMetadataTest.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.vector

import java.util.regex.Pattern

import org.apache.spark.sql.types.DecimalType

import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Gen
import org.scalacheck.Gen.{choose, identifier}
import org.scalacheck.Prop.{forAll, propBoolean}
import org.scalatest.{FunSuite, Matchers}

import com.actian.spark_vector.test.tags.RandomizedTest

class ColumnMetadataTest extends FunSuite with Matchers {
  // Generate random column metadata and ensure the resultant StructField's are valid
  test("generated", RandomizedTest) {
    forAll(columnMetadataGen)(colMD => {
      assertColumnMetadata(colMD)
    }).check
  }

  val milliSecsPattern = Pattern.compile(".*\\.(S*)")

  def assertColumnMetadata(columnMD: ColumnMetadata): Boolean = {
    val structField = columnMD.structField
    structField.dataType match {
      // For decimal type, ensure the scale and precision match
      case decType: DecimalType =>
        decType.precision should be(columnMD.precision)
        decType.scale should be(columnMD.scale)
      case _ =>
    }
    true
  }

  val columnMetadataGen: Gen[ColumnMetadata] =
    for {
      name <- identifier
      typeName <- VectorTypeGen.vectorJdbcTypeGen
      nullable <- arbitrary[Boolean]
      precision <- choose(0, 20)
      scale <- choose(0, Math.min(20, precision))
    } yield ColumnMetadata(name, typeName, nullable, precision, scale)
} 
Example 123
Source File: VectorConnectionPropertiesTest.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.vector

import org.scalatest.{FunSuite, Matchers}
import org.scalatest.prop.PropertyChecks

class VectorConnectionPropertiesTest extends FunSuite with Matchers with PropertyChecks {
  val validCombos = Table(
    ("host", "instance", "database", "user", "password", "expectedURL"),
    ("host.com", "VH", "db", Some("user"), Some("pw"), "jdbc:ingres://host.com:VH7/db"),
    ("some.host.com", "VW", "db", None, None, "jdbc:ingres://some.host.com:VW7/db"),
    ("justhost", "99", "mydatabase", None, None, "jdbc:ingres://justhost:997/mydatabase"))

  val invalidCombos = Table(
    ("host", "instance", "database"),
    (null, "VW", "database"),
    ("", "VW", "database"),
    ("host.com", null, "db"),
    ("host.com", "", "db"),
    ("host.com", "VW", null),
    ("host.com", "VW", ""))

  test("valid URL and values") {
    forAll(validCombos) { (host: String, instance: String, database: String, user: Option[String], password: Option[String], expectedURL: String) =>
      // With user & password
      val props = VectorConnectionProperties(host, instance, database, user, password)
      validate(props, host, instance, database, user, password, expectedURL)

      // Without user & password
      val props2 = VectorConnectionProperties(host, instance, database)
      validate(props2, host, instance, database, None, None, expectedURL)
    }
  }

  test("invalid vector properties") {
    forAll(invalidCombos) { (host: String, instance: String, database: String) =>
      a[IllegalArgumentException] should be thrownBy {
        VectorConnectionProperties(host, instance, database)
      }
    }
  }

  private def validate(props: VectorConnectionProperties, host: String, instance: String, database: String, user: Option[String], password: Option[String], expectedURL: String): Unit = {
    props.toJdbcUrl should be(expectedURL)
    props.host should be(host)
    props.instance should be(instance)
    props.database should be(database)
    props.user should be(user)
    props.password should be(password)
  }
} 
Example 124
Source File: TableSchemaGeneratorTest.scala    From spark-vector   with Apache License 2.0 5 votes vote down vote up
package com.actian.spark_vector.vector

import org.apache.spark.sql.types._
import org.scalacheck.Gen.identifier
import org.scalacheck.Shrink
import org.scalatest.{ FunSuite, Inspectors, Matchers }
import org.scalatest.prop.PropertyChecks

import com.actian.spark_vector.vector.VectorJDBC.withJDBC;
import com.actian.spark_vector.DataTypeGens.schemaGen
import com.actian.spark_vector.test.IntegrationTest
import com.actian.spark_vector.test.tags.RandomizedTest

@IntegrationTest
class TableSchemaGeneratorTest extends FunSuite with Matchers with PropertyChecks with VectorFixture {
  import com.actian.spark_vector.DataTypeGens._
  import com.actian.spark_vector.vector.TableSchemaGenerator._
  import org.scalacheck.Gen._

  val defaultFields: Seq[StructField] = Seq(
    StructField("a", BooleanType, true),
    StructField("b", ByteType, false),
    StructField("c", ShortType, true),
    StructField("d", IntegerType, false),
    StructField("e", LongType, true),
    StructField("f", FloatType, false),
    StructField("g", DoubleType, true),
    StructField("h", DecimalType(10, 2), false),
    StructField("i", DateType, true),
    StructField("j", TimestampType, false),
    StructField("k", StringType, true))

  val defaultSchema = StructType(defaultFields)

  test("table schema") {
    withJDBC(connectionProps)(cxn => {
      cxn.autoCommit(false)
      assertSchemaGeneration(cxn, "testtable", defaultSchema)
    })
  }

  test("table schema/gen", RandomizedTest) {
    withJDBC(connectionProps)(cxn => {
      cxn.autoCommit(false)
      forAll(identifier, schemaGen)((name, schema) => {
        assertSchemaGeneration(cxn, name, schema)
      })(PropertyCheckConfig(minSuccessful = 5), Shrink.shrinkAny[String], Shrink.shrinkAny[StructType])
    })
  }

  private def assertSchemaGeneration(cxn: VectorJDBC, name: String, schema: StructType): Unit = {
    val sql = generateTableSQL(name, schema)
    try {
      cxn.executeStatement(sql)
      val columnsAsFields = cxn.columnMetadata(name).map(_.structField)
      columnsAsFields.size should be(schema.fields.length)
      Inspectors.forAll(columnsAsFields.zip(schema.fields)) {
        case (columnField, origField) => {
          columnField.name should be(origField.name.toLowerCase)
          columnField.dataType should be(origField.dataType)
          columnField.nullable should be(origField.nullable)
          // TODO ensure field metadata consistency
        }
      }
      cxn.dropTable(name)
    } finally {
      cxn.rollback()
    }
  }
} 
Example 125
Source File: CircleSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite


class CircleSuite extends FunSuite{
  test("Circle: intersects and minDist with other Circle"){
    val c1 = Circle(Point(Array(-2.0, 0.0)), 2.0)
    val c2 = Circle(Point(Array(-1.5, 0.0)), 0.5)
    val c3 = Circle(Point(Array(0.0, 1.0)), 1.0)
    val c4 = Circle(Point(Array(1.0, 0.0)), 1.0)
    val c5 = Circle(Point(Array(2.0, 1.0)), 1.0)

    assert(c1.intersects(c2))
    assert(c1.intersects(c3))
    assert(c1.intersects(c4))
    assert(!c1.intersects(c5))

    assert(Math.abs(c1.minDist(c2)) < 1e-8)
    assert(Math.abs(c1.minDist(c3)) < 1e-8)
    assert(Math.abs(c1.minDist(c4)) < 1e-8)
    assert(Math.abs(c1.minDist(c5) - (Math.sqrt(17.0) - 3.0)) < 1e-8)
  }

  val c = Circle(Point(Array(0.0, 0.0)), 1.0)
  test("Circle: intersects and minDist with Point"){
    val p1 = Point(Array(0.5, 0.0))
    val p2 = Point(Array(1.0, 0.0))
    val p3 = Point(Array(1.5, 0.0))

    assert(c.intersects(p1))
    assert(c.intersects(p2))
    assert(!c.intersects(p3))

    assert(Math.abs(c.minDist(p1)) < 1e-8)
    assert(Math.abs(c.minDist(p2)) < 1e-8)
    assert(Math.abs(c.minDist(p3) - 0.5) < 1e-8)
  }
  test("Circle: intersects and minDist with LineSegment"){
    val l1 = LineSegment(Point(Array(0.0, 0.0)), Point(Array(1.0, 1.0)))
    val l2 = LineSegment(Point(Array(1.0, 0.0)), Point(Array(1.0, 1.0)))
    val l3 = LineSegment(Point(Array(2.0, 0.0)), Point(Array(1.0, 1.0)))

    assert(c.intersects(l1))
    assert(c.intersects(l2))
    assert(!c.intersects(l3))

    assert(Math.abs(c.minDist(l1)) < 1e-8)
    assert(Math.abs(c.minDist(l2)) < 1e-8)
    assert(Math.abs(c.minDist(l3) - (Math.sqrt(2.0) - 1.0)) < 1e-8)
  }
  test("Circle: intersects and minDist with MBR"){
    val m1 = MBR(Point(Array(0.0, 0.0)), Point(Array(1.0, 1.0)))
    val m2 = MBR(Point(Array(1.0, 0.0)), Point(Array(2.0, 1.0)))
    val m3 = MBR(Point(Array(2.0, 0.0)), Point(Array(3.0, 1.0)))

    assert(c.intersects(m1))
    assert(c.intersects(m2))
    assert(!c.intersects(m3))

    assert(Math.abs(c.minDist(m1)) < 1e-8)
    assert(Math.abs(c.minDist(m2)) < 1e-8)
    assert(Math.abs(c.minDist(m3) - 1.0) < 1e-8)
  }
  test("Circle: intersects and minDist with Polygon"){
    val ply1 = Polygon.apply(Array(Point(Array(-1.0, -1.0)), Point(Array(1.0, -1.0)),
      Point(Array(0.0, 1.0)), Point(Array(-1.0, -1.0))))
    val ply2 = Polygon.apply(Array(Point(Array(1.0, 0.0)), Point(Array(2.0, 0.0)),
      Point(Array(2.0, 1.0)), Point(Array(1.0, 0.0))))
    val ply3 = Polygon.apply(Array(Point(Array(2.0, 0.0)), Point(Array(3.0, 0.0)),
      Point(Array(3.0, 1.0)), Point(Array(2.0, 0.0))))

    assert(c.intersects(ply1))
    assert(c.intersects(ply2))
    assert(!c.intersects(ply3))

    assert(Math.abs(c.minDist(ply1)) < 1e-8)
    assert(Math.abs(c.minDist(ply2)) < 1e-8)
    assert(Math.abs(c.minDist(ply3) - 1.0) < 1e-8)
  }

  test("Circle: Construct MBR"){
    assert(c.getMBR == MBR(Point(Array(-1.0, -1.0)), Point(Array(1.0, 1.0))))
  }
} 
Example 126
Source File: DistSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite


class DistSuite extends FunSuite{
  test("Dist: furthest distance from a point to an MBR"){
    val m = MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0)))
    val p1 = Point(Array(1.0, 1.0))
    val p2 = Point(Array(0.0, 0.0))
    val p3 = Point(Array(1.0, 3.0))

    assert(Math.abs(Dist.furthest(p1, m) - Math.sqrt(2.0)) < 1e-8)
    assert(Math.abs(Dist.furthest(p2, m) - Math.sqrt(8.0)) < 1e-8)
    assert(Math.abs(Dist.furthest(p3, m) - Math.sqrt(10.0)) < 1e-8)
  }
} 
Example 127
Source File: MBRSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite


class MBRSuite extends FunSuite{
  test("MBR: intersects and minDist with other MBR"){
    val m1 = MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0)))
    val m2 = MBR(Point(Array(1.0, 1.0)), Point(Array(3.0, 3.0)))
    val m3 = MBR(Point(Array(2.0, 0.0)), Point(Array(3.0, 1.0)))
    val m4 = MBR(Point(Array(1.0, 3.0)), Point(Array(2.0, 4.0)))

    assert(m1.intersects(m2))
    assert(m1.intersects(m3))
    assert(!m1.intersects(m4))

    assert(Math.abs(m1.minDist(m2)) < 1e-8)
    assert(Math.abs(m1.minDist(m3)) < 1e-8)
    assert(Math.abs(m1.minDist(m4) - 1.0) < 1e-8)
  }

  val m = MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0)))
  test("MBR: intersects and minDist with Point"){
    val p1 = Point(Array(1.0,1.0))
    val p2 = Point(Array(2.0,  1.0))
    val p3 = Point(Array(3.0, 0.0))

    assert(m.intersects(p1))
    assert(m.intersects(p2))
    assert(!m.intersects(p3))

    assert(Math.abs(m.minDist(p1)) < 1e-8)
    assert(Math.abs(m.minDist(p2)) < 1e-8)
    assert(Math.abs(m.minDist(p3) - 1.0) < 1e-8)
  }
  test("MBR: intersects and minDist with LineSegment"){
    val l1 = LineSegment(Point(Array(1.0, 1.0)), Point(Array(3.0, 2.0)))
    val l2 = LineSegment(Point(Array(1.0, 3.0)), Point(Array(3.0, 1.0)))
    val l3 = LineSegment(Point(Array(3.0, 3.0)), Point(Array(4.0, 2.0)))

    assert(m.intersects(l1))
    assert(m.intersects(l2))
    assert(!m.intersects(l3))

    assert(Math.abs(m.minDist(l1)) < 1e-8)
    assert(Math.abs(m.minDist(l2)) < 1e-8)
    assert(Math.abs(m.minDist(l3) - Math.sqrt(2.0)) < 1e-8)
  }
  test("MBR: intersects and minDist with Circle"){
    val c1 = Circle(Point(Array(2.0, 1.0)), 1.0)
    val c2 = Circle(Point(Array(3.0, 3.0)), Math.sqrt(2.0))
    val c3 = Circle(Point(Array(4.0, 1.0)), 1.0)

    assert(m.intersects(c1))
    assert(m.intersects(c2))
    assert(!m.intersects(c3))

    assert(Math.abs(m.minDist(c1)) < 1e-8)
    assert(Math.abs(m.minDist(c2)) < 1e-8)
    assert(Math.abs(m.minDist(c3) - 1.0) < 1e-8)
  }
  test("MBR: intersects and minDist with Polygon"){
    val ply1 = Polygon.apply(Array(Point(Array(-1.0, -1.0)), Point(Array(1.0, -1.0)),
      Point(Array(0.0, 1.0)), Point(Array(-1.0, -1.0))))
    val ply2 = Polygon.apply(Array(Point(Array(-2.0, -1.0)), Point(Array(0.0, -1.0)),
      Point(Array(0.0, 1.0)), Point(Array(-2.0, -1.0))))
    val ply3 = Polygon.apply(Array(Point(Array(2.0, -1.0)), Point(Array(3.0, -1.0)),
      Point(Array(3.0, 0.0)), Point(Array(2.0, -1.0))))

    assert(m.intersects(ply1))
    assert(m.intersects(ply2))
    assert(!m.intersects(ply3))

    assert(Math.abs(m.minDist(ply1)) < 1e-8)
    assert(Math.abs(m.minDist(ply2)) < 1e-8)
    assert(Math.abs(m.minDist(ply3) - Math.sqrt(2.0) / 2.0) < 1e-8)
  }

  test("MBR: area"){
    assert(Math.abs(m.area - 4.0) < 1e-8)
  }

  test("MBR: ratio"){
    val m1 = MBR(Point(Array(1.0, 1.0)), Point(Array(3.0,  3.0)))
    assert(Math.abs(m.calcRatio(m1) - 0.25) < 1e-8)
  }

  test("MBR: getMBR"){
    assert(m.getMBR == MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0))))
  }
} 
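A quick arithmetic check of the ratio test above: m spans (0,0)-(2,2) and m1 spans (1,1)-(3,3), so their overlap is the unit square (1,1)-(2,2) with area 1, while each MBR has area 4; whichever of the two areas calcRatio normalises by, the expected value is 1/4 = 0.25, matching the assertion.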
Example 128
Source File: PointSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite


class PointSuite extends FunSuite{
  val p1 = Point(Array(0.0, 0.0))
  val p2 = Point(Array(1.0, 1.0))
  val p3 = Point(Array(1.0, 1.0))

  test("Point: intersects with other Point"){
    assert(!p1.intersects(p2))
    assert(p2.intersects(p3))
  }
  test("Point: minDist with other point"){
    assert(Math.abs(p1.minDist(p2) - Math.sqrt(2.0)) < 1e-8)
    assert(p2.minDist(p3) == 0.0)
  }
  test("Point: equals another"){
    assert(!(p1 == p2))
    assert(p2 == p3)
  }
  test("Point: less than another"){
    assert(p1 <= p2)
    assert(p2 <= p3)
    assert(!(p2  <= p1))
  }
  test("Shift one point"){
    assert(p1.shift(1.0) == p2)
  }


  val p = Point(Array(0.0, 0.0))

  test("Point: intersects and minDist with MBR"){
    val m1 = MBR(Point(Array(-2.0, 1.0)), Point(Array(-1.0, 2.0)))
    val m2 = MBR(Point(Array(-1.0, -1.0)), Point(Array(0.0, 0.0)))
    val m3 = MBR(Point(Array(-1.0, -1.0)), Point(Array(1.0, 1.0)))

    assert(!p.intersects(m1))
    assert(p.intersects(m2))
    assert(p.intersects(m3))

    assert(Math.abs(p.minDist(m1) - Math.sqrt(2)) < 1e-8)
    assert(Math.abs(p.minDist(m2)) < 1e-8)
    assert(Math.abs(p.minDist(m3)) < 1e-8)
  }

  test("Point: intersects and minDist with Circle"){
    val c1 = Circle(Point(Array(2.0, 0.0)), 1.0)
    val c2 = Circle(Point(Array(1.0, 0.0)), 1.0)
    val c3 = Circle(Point(Array(0.0, 0.0)), 1.0)

    assert(!p.intersects(c1))
    assert(p.intersects(c2))
    assert(p.intersects(c3))

    assert(Math.abs(p.minDist(c1) - 1.0) < 1e-8)
    assert(Math.abs(p.minDist(c2)) < 1e-8)
    assert(Math.abs(p.minDist(c3)) < 1e-8)
  }
  test("Point: intersects and minDist with LineSegment"){
    val s1 = LineSegment(Point(Array(1.0, 1.0)), Point(Array(2.0, 1.0)))
    val s2 = LineSegment(Point(Array(0.0, 0.0)), Point(Array(1.0, 0.0)))
    val s3 = LineSegment(Point(Array(-1.0, 0.0)), Point(Array(1.0, 0.0)))

    assert(!p.intersects(s1))
    assert(p.intersects(s2))
    assert(p.intersects(s3))

    assert(Math.abs(p.minDist(s1) - Math.sqrt(2.0)) < 1e-8)
    assert(Math.abs(p.minDist(s2)) < 1e-8)
    assert(Math.abs(p.minDist(s3)) < 1e-8)
  }

  test("Point: intersects and minDist with Polygon"){
    val ply1 = Polygon.apply(Array(Point(Array(-1.0, -1.0)), Point(Array(1.0, -1.0)),
      Point(Array(0.0, 1.0)), Point(Array(-1.0, -1.0))))
    val ply2 = Polygon.apply(Array(Point(Array(0.0, 0.0)), Point(Array(4.0, 0.0)),
      Point(Array(3.0, 2.0)), Point(Array(0.0, 0.0))))
    val ply3 = Polygon.apply(Array(Point(Array(1.0, -1.0)), Point(Array(2.0, 1.0)),
      Point(Array(1.0, 1.0)), Point(Array(1.0, -1.0))))

    assert(p.intersects(ply1))
//    assert(p.intersects(ply2))
    assert(!p.intersects(ply3))

    assert(Math.abs(p.minDist(ply1)) < 1e-8)
    assert(Math.abs(p.minDist(ply2)) < 1e-8)
    assert(Math.abs(p.minDist(ply3) - 1.0) < 1e-8)
  }

  test("Point: Construct MBR"){
    assert(p.getMBR == MBR(Point(Array(0.0, 0.0)), Point(Array(0.0, 0.0))))
  }
} 
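A note on the commented-out assertion in the polygon test above: p = (0.0, 0.0) is exactly a vertex of ply2, so whether intersects should hold is a boundary case, which is presumably why that assertion is left commented out; the corresponding minDist(ply2) check still passes because the distance from a point to a polygon it touches is 0.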
Example 129
Source File: PolygonSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite


class PolygonSuite extends FunSuite{
  test("Polygon: intersects and minDist with other Polygon"){
    val ply1 = Polygon.apply(Array(Point(Array(-1.0, -1.0)), Point(Array(1.0, -1.0)),
      Point(Array(0.0, 1.0)), Point(Array(-1.0, -1.0))))
    val ply2 = Polygon.apply(Array(Point(Array(0.0, 0.0)), Point(Array(2.0, 0.0)),
      Point(Array(1.0, 2.0)), Point(Array(0.0, 0.0))))
    val ply3 = Polygon.apply(Array(Point(Array(-1.0, 1.0)), Point(Array(1.0, 1.0)),
      Point(Array(0.0, 3.0)), Point(Array(-1.0, 1.0))))
    val ply4 = Polygon.apply(Array(Point(Array(1.0, 0.0)), Point(Array(2.0, 0.0)),
      Point(Array(2.0, 1.0)), Point(Array(1.0, 0.0))))

    assert(ply1.intersects(ply2))
    assert(ply1.intersects(ply3))
    assert(!ply1.intersects(ply4))

    assert(Math.abs(ply1.minDist(ply2)) < 1e-8)
    assert(Math.abs(ply1.minDist(ply3)) < 1e-8)
    assert(Math.abs(ply1.minDist(ply4) - 1.0 / Math.sqrt(5.0)) < 1e-8)
  }

  val ply = Polygon.apply(Array(Point(Array(0.0, 0.0)), Point(Array(2.0, 0.0)),
    Point(Array(1.0, 2.0)), Point(Array(0.0, 0.0))))

  test("Polygon: intersects and minDist with Point"){
    val p1 = Point(Array(1.0, 1.0))
    val p2 = Point(Array(1.5, 1.0))
    val p3 = Point(Array(2.0, 1.0))

    assert(ply.intersects(p1))
//    assert(ply.intersects(p2))
    assert(!ply.intersects(p3))

    assert(Math.abs(ply.minDist(p1)) < 1e-8)
    assert(Math.abs(ply.minDist(p2)) < 1e-8)
    assert(Math.abs(ply.minDist(p3) - 1.0 / Math.sqrt(5.0)) < 1e-8)
  }
  test("Polygon: intersects and minDist with LineSegment"){
    val s1 = LineSegment(Point(Array(1.0, 1.0)), Point(Array(2.0, 1.0)))
    val s2 = LineSegment(Point(Array(0.0, 2.0)), Point(Array(2.0, 2.0)))
    val s3 = LineSegment(Point(Array(3.0, 0.0)), Point(Array(3.0, 2.0)))

    assert(ply.intersects(s1))
    assert(ply.intersects(s2))
    assert(!ply.intersects(s3))

    assert(Math.abs(ply.minDist(s1)) < 1e-8)
    assert(Math.abs(ply.minDist(s2)) < 1e-8)
    assert(Math.abs(ply.minDist(s3) - 1.0) < 1e-8)
  }
  test("Polygon: intersects and minDist with Circle"){
    val c1 = Circle(Point(Array(1.0, 1.0)), 1.0)
    val c2 = Circle(Point(Array(1.0, -1.0)), 1.0)
    val c3 = Circle(Point(Array(4.0, 0.0)), 1.0)

    assert(ply.intersects(c1))
    assert(ply.intersects(c2))
    assert(!ply.intersects(c3))

    assert(Math.abs(ply.minDist(c1)) < 1e-8)
    assert(Math.abs(ply.minDist(c2)) < 1e-8)
    assert(Math.abs(ply.minDist(c3) - 1.0) < 1e-8)
  }

  test("Polygon: intersects and minDist with MBR"){
    val m1 = MBR(Point(Array(1.0, 1.0)), Point(Array(2.0, 2.0)))
    val m2 = MBR(Point(Array(0.0, 2.0)), Point(Array(2.0, 4.0)))
    val m3 = MBR(Point(Array(3.0, 0.0)), Point(Array(4.0, 1.0)))

    assert(ply.intersects(m1))
    assert(ply.intersects(m2))
    assert(!ply.intersects(m3))

    assert(Math.abs(ply.minDist(m1)) < 1e-8)
    assert(Math.abs(ply.minDist(m2)) < 1e-8)
    assert(Math.abs(ply.minDist(m3) - 1.0) < 1e-8)
  }

  test("Polygon: construct MBR"){
    assert(ply.getMBR == MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0))))
  }
} 
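A note on the commented-out assertion in the point test above: p2 = (1.5, 1.0) lies exactly on the edge of ply joining (2.0, 0.0) and (1.0, 2.0) (that edge passes through x = 1.5 at y = 1.0), so whether intersects should hold is a boundary case, which is presumably why the assertion is commented out; the minDist(p2) check still passes because the distance to a point on the boundary is 0.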
Example 130
Source File: ShapeSuite.scala    From Simba   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.simba.spatial

import org.scalatest.FunSuite
import com.vividsolutions.jts.geom.{GeometryFactory, Polygon => JTSPolygon}
import com.vividsolutions.jts.io.WKTReader


class ShapeSuite extends FunSuite{
  def intersects(s1: Shape, s2: Shape): Boolean = {
    s1.intersects(s2)
  }

  def minDist(s1: Shape,  s2: Shape): Double = {
    s1.minDist(s2)
  }

  test("Shape: Abstract Functions"){
    val p = Point(Array(-1.0, 0.0))
    val s = LineSegment(Point(Array(-1.0, 0.0)), Point(Array(1.0, 1.0)))
    val m = MBR(Point(Array(0.0, 0.0)), Point(Array(2.0, 2.0)))
    val c = Circle(Point(Array(0.0, -1.0)), 1.0)

    assert(!intersects(p, c))
    assert(intersects(s, m))

    assert(Math.abs(minDist(p, c) - (Math.sqrt(2.0) - 1.0)) < 1e-8)
    assert(Math.abs(minDist(s, m)) < 1e-8)
  }

  test("Shape: apply Geometry"){
    val gf = new GeometryFactory()
    val reader = new WKTReader( gf )

    val point = reader.read("POINT (0.0 0.0)")
    assert(Shape.apply(point) == null)

    val ply = Shape.apply(reader.read("POLYGON((2.0 1.0, 3.0 0.0, 4.0 1.0, 3.0 2.0, 2.0 1.0))"))
    assert(ply != null)
    val p = Point(Array(2.0, 0.0))
    assert(Math.abs(ply.minDist(p) - Math.sqrt(2.0) / 2.0) < 1e-8)
  }
} 
Example 131
Source File: OrcSchemaCompatibilityTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.orc

import io.eels.schema._
import org.apache.orc.TypeDescription
import org.scalatest.{FunSuite, Matchers}

// tests that the eel <-> orc schemas are compatible
class OrcSchemaCompatibilityTest extends FunSuite with Matchers {

  test("orc schemas should be cross compatible with eel structs") {

    val schema = TypeDescription.createStruct()
      .addField("binary", TypeDescription.createBinary())
      .addField("boolean", TypeDescription.createBoolean())
      .addField("byte", TypeDescription.createByte())
      .addField("char", TypeDescription.createChar().withMaxLength(8))
      .addField("date", TypeDescription.createDate())
      .addField("decimal", TypeDescription.createDecimal().withScale(2).withPrecision(4))
      .addField("double", TypeDescription.createDouble())
      .addField("float", TypeDescription.createFloat())
      .addField("int", TypeDescription.createInt())
      .addField("long", TypeDescription.createLong())
      .addField("timestamp", TypeDescription.createTimestamp())
      .addField("varchar", TypeDescription.createVarchar().withMaxLength(222))
      .addField("map", TypeDescription.createMap(TypeDescription.createString(), TypeDescription.createBoolean()))
      .addField("array", TypeDescription.createList(TypeDescription.createString()))
      .addField("struct", TypeDescription.createStruct()
        .addField("a", TypeDescription.createString)
        .addField("b", TypeDescription.createBoolean()))

    val structType = StructType(
      Field("binary", BinaryType, true),
      Field("boolean", BooleanType, true),
      Field("byte", ByteType.Signed, true),
      Field("char", CharType(8), true),
      Field("date", DateType, true),
      Field("decimal", DecimalType(4, 2), true),
      Field("double", DoubleType, true),
      Field("float", FloatType, true),
      Field("int", IntType.Signed, true),
      Field("long", LongType.Signed, true),
      Field("timestamp", TimestampMillisType, true),
      Field("varchar", VarcharType(222), true),
      Field("map", MapType(StringType, BooleanType), true),
      Field("array", ArrayType(StringType), true),
      Field("struct", StructType(Field("a", StringType), Field("b", BooleanType)), true)
    )

    OrcSchemaFns.fromOrcType(schema) shouldBe structType
    OrcSchemaFns.toOrcSchema(structType) shouldBe schema
  }
} 
Example 132
Source File: TestTabulizer.scala    From Mastering-Machine-Learning-with-Spark-2.x   with MIT License 5 votes vote down vote up
package com.packtpub.mmlwspark.utils

import com.packtpub.mmlwspark.utils.Tabulizer.table
import org.scalatest.FunSuite


class TestTabulizer extends FunSuite {
  test("table sort") {
     println(
       s"""GBM Model: Grid results:
          ~${table(Seq("iterations, depth, learningRate", "AUC", "error"), gbmResults.sortBy(-_._2).take(10), format = Map(1 -> "%.3f", 2 -> "%.3f"))}
        """.stripMargin('~'))

  }

  val gbmResults = Seq(
    ((5,2,0.1),0.635,0.363),
    ((5,2,0.01),0.631,0.370),
    ((5,2,0.001),0.631,0.370),
    ((5,3,0.1),0.662,0.338),
    ((5,3,0.01),0.660,0.343),
    ((5,3,0.001),0.640,0.367),
    ((5,5,0.1),0.686,0.312),
    ((5,5,0.01),0.673,0.326),
    ((5,5,0.001),0.662,0.335),
    ((5,7,0.1),0.694,0.304),
    ((5,7,0.01),0.683,0.314),
    ((5,7,0.001),0.681,0.316),
    ((10,2,0.1),0.641,0.356),
    ((10,2,0.01),0.631,0.370),
    ((10,2,0.001),0.631,0.370),
    ((10,3,0.1),0.672,0.326),
    ((10,3,0.01),0.661,0.341),
    ((10,3,0.001),0.640,0.367),
    ((10,5,0.1),0.695,0.303),
    ((10,5,0.01),0.676,0.323),
    ((10,5,0.001),0.662,0.335),
    ((10,7,0.1),0.702,0.297),
    ((10,7,0.01),0.684,0.313),
    ((10,7,0.001),0.681,0.316),
    ((50,2,0.1),0.684,0.313),
    ((50,2,0.01),0.635,0.363),
    ((50,2,0.001),0.631,0.370),
    ((50,3,0.1),0.700,0.298),
    ((50,3,0.01),0.663,0.336),
    ((50,3,0.001),0.661,0.342),
    ((50,5,0.1),0.714,0.285),
    ((50,5,0.01),0.688,0.310),
    ((50,5,0.001),0.674,0.324),
    ((50,7,0.1),0.716,0.283),
    ((50,7,0.01),0.694,0.304),
    ((50,7,0.001),0.684,0.314),
    ((100,2,0.1),0.701,0.297),
    ((100,2,0.01),0.641,0.356),
    ((100,2,0.001),0.631,0.370),
    ((100,3,0.1),0.709,0.289),
    ((100,3,0.01),0.671,0.327),
    ((100,3,0.001),0.660,0.343),
    ((100,5,0.1),0.721,0.277),
    ((100,5,0.01),0.698,0.300),
    ((100,5,0.001),0.677,0.322),
    ((100,7,0.1),0.720,0.278),
    ((100,7,0.01),0.704,0.294),
    ((100,7,0.001),0.685,0.312)
  )
} 
Example 133
Source File: QuickCheckSuite.scala    From Principles-of-Reactive-Programming   with GNU General Public License v3.0 5 votes vote down vote up
package quickcheck

import org.scalatest.FunSuite

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop
import org.scalacheck.Prop._

import org.scalatest.exceptions.TestFailedException

object QuickCheckBinomialHeap extends QuickCheckHeap with BinomialHeap

@RunWith(classOf[JUnitRunner])
class QuickCheckSuite extends FunSuite with Checkers {
  def checkBogus(p: Prop): Unit = {
    var ok = false
    try {
      check(p)
    } catch {
      case e: TestFailedException =>
        ok = true
    }
    assert(ok, "A bogus heap should NOT satisfy all properties. Try to find the bug!")
  }

  test("Binomial heap satisfies properties.") {
    check(new QuickCheckHeap with BinomialHeap)
  }

  test("Bogus (1) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus1BinomialHeap)
  }

  test("Bogus (2) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus2BinomialHeap)
  }

  test("Bogus (3) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus3BinomialHeap)
  }

  test("Bogus (4) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus4BinomialHeap)
  }

  test("Bogus (5) binomial heap does not satisfy properties.") {
    checkBogus(new QuickCheckHeap with Bogus5BinomialHeap)
  }
} 
Example 134
Source File: RootNodesStorageTest.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.storage

import java.io.File

import encry.view.state.avlTree.utils.implicits.Instances._
import encry.modifiers.InstanceFactory
import encry.storage.VersionalStorage.{StorageKey, StorageValue, StorageVersion}
import encry.storage.levelDb.versionalLevelDB.{LevelDbFactory, VLDBWrapper, VersionalLevelDBCompanion}
import encry.utils.{EncryGenerator, FileHelper}
import encry.view.state.avlTree.AvlTree
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.common.utils.TaggedTypes.Height
import org.iq80.leveldb.{DB, Options, ReadOptions}
import org.scalatest.{FunSuite, Matchers, PropSpec}
import scorex.utils.Random

import scala.util.{Random => SRandom}

class RootNodesStorageTest extends PropSpec with InstanceFactory with EncryGenerator with Matchers {

  def createAvl: AvlTree[StorageKey, StorageValue] = {
    val firstDir: File = FileHelper.getRandomTempDir
    val firstStorage: VLDBWrapper = {
      val levelDBInit = LevelDbFactory.factory.open(firstDir, new Options)
      VLDBWrapper(VersionalLevelDBCompanion(levelDBInit, settings.levelDB.copy(keySize = 33), keySize = 33))
    }
    val dir: File = FileHelper.getRandomTempDir
    val levelDb: DB = LevelDbFactory.factory.open(dir, new Options)
    AvlTree[StorageKey, StorageValue](firstStorage, RootNodesStorage.emptyRootStorage[StorageKey, StorageValue])
  }

  property("testRollback") {
    val avl: AvlTree[StorageKey, StorageValue] = createAvl
    val dir: File = FileHelper.getRandomTempDir
    val levelDb: DB = LevelDbFactory.factory.open(dir, new Options)
    val batch1 = levelDb.createWriteBatch()
    val readOptions1 = new ReadOptions()
    val rootNodesStorage = RootNodesStorage[StorageKey, StorageValue](levelDb, 10, dir)
    val (_, avlAfterInsertions, insertList) =
      (0 to SRandom.nextInt(1000) + 10).foldLeft(rootNodesStorage, avl, List.empty[(Height, (List[(StorageKey, StorageValue)], List[StorageKey]))]) {
      case ((rootStorage, previousAvl, insertionList), height) =>
        val version = StorageVersion @@ Random.randomBytes()
        val toInsert = (0 to SRandom.nextInt(100)).foldLeft(List.empty[(StorageKey, StorageValue)]) {
          case (list, _) => (StorageKey @@ Random.randomBytes() -> StorageValue @@ Random.randomBytes()) :: list
        }
        val previousInsertions = insertionList.lastOption.map(_._2._1).getOrElse(List.empty[(StorageKey, StorageValue)])
        val deletions = previousInsertions.take(1).map(_._1)
        val newAvl = previousAvl.insertAndDeleteMany(
          version,
          toInsert,
          deletions
        )
        val newRootStorage = rootStorage.insert(
          version,
          newAvl.rootNode,
          Height @@ height
        )
        (newRootStorage, newAvl, insertionList :+ (Height @@ height -> (toInsert -> deletions)))
    }
    val (_, rootNodeRestored) = rootNodesStorage.rollbackToSafePoint(insertList.dropWhile(_._1 != rootNodesStorage.safePointHeight).drop(1))
    (avlAfterInsertions.rootNode.hash sameElements rootNodeRestored.hash) shouldBe true
  }
} 
Example 135
Source File: BlockSerializerTest.scala    From EncryCore   with GNU General Public License v3.0 5 votes vote down vote up
package encry.modifiers.history

import encry.modifiers.mempool.TransactionFactory
import encry.settings.Settings
import encry.utils.{EncryGenerator, TestHelper}
import org.encryfoundation.common.crypto.equihash.EquihashSolution
import org.encryfoundation.common.modifiers.history._
import org.encryfoundation.common.utils.Algos
import org.encryfoundation.common.utils.TaggedTypes.ModifierId
import org.scalatest.FunSuite
import scorex.crypto.hash.Digest32
import scorex.utils.Random

class BlockSerializerTest extends FunSuite with EncryGenerator with Settings {

  test("testToBytes $ testFromBytes") {

    val blockHeader = Header(
      99: Byte,
      ModifierId @@ Random.randomBytes(),
      Digest32 @@ Random.randomBytes(),
      99999L,
      199,
      999L,
      settings.constants.InitialDifficulty,
      EquihashSolution(Seq(1, 2, 3)),
      Random.randomBytes()
    )

    val factory = TestHelper
    val keys = factory.genKeys(10)

    val fee = factory.Props.txFee
    val timestamp = 12345678L

    val txs = keys.map { k =>
      val useBoxes = IndexedSeq(factory.genAssetBox(k.publicImage.address.address))
      TransactionFactory.defaultPaymentTransactionScratch(k, fee,
        timestamp, useBoxes, randomAddress, factory.Props.boxValue)
    }

    val blockPayload = Payload(ModifierId @@ Array.fill(32)(19: Byte), txs)

    val block = Block(blockHeader,blockPayload)

    val blockSerialized = BlockSerializer.toBytes(block)

    val blockDeserialized = BlockSerializer.parseBytes(blockSerialized).get

    assert(Algos.hash(block.bytes) sameElements Algos.hash(blockDeserialized.bytes), "Block bytes mismatch.")
  }
} 
Example 136
Source File: TokenizerSuite.scala    From spark-nkp   with Apache License 2.0 5 votes vote down vote up
package com.github.uosdmlab.nkp

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, IDF}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfter, FunSuite}


class TokenizerSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {

  private var tokenizer: Tokenizer = _

  private val spark: SparkSession =
    SparkSession.builder()
      .master("local[2]")
      .appName("Tokenizer Suite")
      .getOrCreate

  spark.sparkContext.setLogLevel("WARN")

  import spark.implicits._

  override protected def afterAll(): Unit = {
    try {
      spark.stop
    } finally {
      super.afterAll()
    }
  }

  before {
    tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
  }

  private val df = spark.createDataset(
    Seq(
      "아버지가방에들어가신다.",
      "사랑해요 제플린!",
      "스파크는 재밌어",
      "나는야 데이터과학자",
      "데이터야~ 놀자~"
    )
  ).toDF("text")

  test("Default parameters") {
    assert(tokenizer.getFilter sameElements Array.empty[String])
  }

  test("Basic operation") {
    val words = tokenizer.transform(df)

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(tokenizer.getOutputCol))
  }

  test("POS filter") {
    val nvTokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("nvWords")
      .setFilter("N", "V")

    val words = tokenizer.transform(df).join(nvTokenizer.transform(df), "text")

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(nvTokenizer.getOutputCol))
    assert(words.where(s"SIZE(${tokenizer.getOutputCol}) < SIZE(${nvTokenizer.getOutputCol})").count == 0)
  }

  test("TF-IDF pipeline") {
    tokenizer.setFilter("N")

    val cntVec = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("tf")

    val idf = new IDF()
      .setInputCol("tf")
      .setOutputCol("tfidf")

    val pipe = new Pipeline()
      .setStages(Array(tokenizer, cntVec, idf))

    val pipeModel = pipe.fit(df)

    val result = pipeModel.transform(df)

    assert(result.count == df.count)

    val fields = result.schema.fieldNames
    assert(fields.contains(tokenizer.getOutputCol))
    assert(fields.contains(cntVec.getOutputCol))
    assert(fields.contains(idf.getOutputCol))

    result.show
  }
} 
Example 137
Source File: MetadataTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.operator.{MetadataTransformUtils, VectorCartesian}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class MetadataTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_vector_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", "123")
      .load("data/a9a/a9a_123d_train_trans.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("cartesian_features")

    val assembler = new VectorAssembler()
      .setInputCols(Array("features", "cartesian_features"))
      .setOutputCol("assemble_features")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, assembler))

    val featureModel = pipeline.fit(data)
    val crossDF = featureModel.transform(data)

    crossDF.schema.fields.foreach { field =>
      println("name: " + field.name)
      println("metadata: " + field.metadata.toString())
    }
  }

  test("test_three_order_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", 8)
      .load("data/abalone/abalone_8d_train.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("f_f")

    val cartesian2 = new VectorCartesian()
      .setInputCols(Array("features", "f_f"))
      .setOutputCol("f_f_f")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, cartesian2))

    val crossDF = pipeline.fit(data).transform(data).persist()

    // first cartesian, the number of dimensions is 64
    println("first cartesian dimension = " + crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))

    println()

    // second cartesian, the number of dimensions is 512
    println("second cartesian dimension = " + crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))
  }
} 
Example 138
Source File: GPModelTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import breeze.linalg.{DenseMatrix, DenseVector}
import breeze.numerics.{cos, pow}
import com.tencent.angel.spark.automl.tuner.kernel.Matern5Iso
import com.tencent.angel.spark.automl.tuner.model.GPModel
import org.scalatest.FunSuite

class GPModelTest extends FunSuite {

  test("test_linear") {
    // Test linear: y=2*x
    val X = DenseMatrix((1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)).t
    val y = 2.0 * DenseVector(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)
    val z = DenseMatrix((2.5, 4.5, 6.5, 8.5, 10.0, 12.0)).t
    val truePredZ = 2.0 * DenseVector(2.5, 4.5, 6.5, 8.5, 10.0, 12.0)

    val covFunc = Matern5Iso()
    val initCovParams = DenseVector(1.0, 1.0)
    val initNoiseStdDev = 0.01

    val gpModel = GPModel(covFunc, initCovParams, initNoiseStdDev)
    gpModel.fit(X, y)

    println("Fitted covariance function params:")
    println(gpModel.covParams)
    println("Fitted noiseStdDev:")
    println(gpModel.noiseStdDev)
    println("\n")

    val prediction = gpModel.predict(z)
    println("Mean and Var:")
    println(prediction)
    println("True value:")
    println(truePredZ)
  }

  test("test_cosine") {
    // Test non-linear: y = cos(x) + 1
    val X = DenseMatrix((1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)).t
    val y = cos(DenseVector(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)) + 1.0
    val z = DenseMatrix((2.5, 4.5, 6.5, 8.5, 10.0, 12.0)).t
    val truePredZ = cos(DenseVector(2.5, 4.5, 6.5, 8.5, 10.0, 12.0)) + 1.0

    val covFunc = Matern5Iso()
    val initCovParams = DenseVector(1.0, 1.0)
    val initNoiseStdDev = 0.01

    val gpModel = GPModel(covFunc, initCovParams, initNoiseStdDev)
    gpModel.fit(X, y)

    println("Fitted covariance function params:")
    println(gpModel.covParams)
    println("Fitted noiseStdDev:")
    println(gpModel.noiseStdDev)
    println("\n")

    val prediction = gpModel.predict(z)
    println("Mean and Var:")
    println(prediction)
    println("True value:")
    println(truePredZ)
  }

  test("testSquare") {
    // Test non-linear: y = x^2
    val X = DenseMatrix((1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0)).t
    val y = DenseVector(1.0, 4.0, 9.0, 16.0, 25.0, 36.0, 49.0, 64.0, 81.0)
    val z = DenseMatrix((2.5, 4.5, 6.5, 8.5, 10.0, 12.0)).t
    val truePredZ = pow(z, 2)

    val covFunc = Matern5Iso()
    val initCovParams = DenseVector(1.0, 1.0)
    val initNoiseStdDev = 0.01

    val gpModel = GPModel(covFunc, initCovParams, initNoiseStdDev)
    gpModel.fit(X, y)

    println("Fitted covariance function params:")
    println(gpModel.covParams)
    println("Fitted noiseStdDev:")
    println(gpModel.noiseStdDev)
    println("\n")

    val prediction = gpModel.predict(z)
    println("Mean and Var:")
    println(prediction)
    println("True value:")
    println(truePredZ)
  }
} 
Example 139
Source File: SquareDistTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import breeze.linalg.{DenseMatrix, DenseVector}
import com.tencent.angel.spark.automl.tuner.math.SquareDist
import org.junit.Assert._
import org.scalatest.FunSuite

class SquareDistTest extends FunSuite {

  test("test_XX_1D") {

    val x = DenseVector(1.0, 2.0, 3.0).toDenseMatrix.t
    val expected = DenseMatrix((0.0, 1.0, 4.0), (1.0, 0.0, 1.0), (4.0, 1.0, 0.0))
    assertEquals(expected, SquareDist(x, x))
  }

  test("test_XX_2D") {

    val x = DenseMatrix((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)).t
    val expected = DenseMatrix((0.0, 2.0, 8.0), (2.0, 0.0, 2.0), (8.0, 2.0, 0.0))
    assertEquals(expected, SquareDist(x, x))
  }

  test("test_XY_1D") {

    val x1 = DenseVector(1.0, 2.0, 3.0).toDenseMatrix.t
    val x2 = DenseVector(4.0, 5.0).toDenseMatrix.t

    val expected = DenseMatrix((9.0, 16.0), (4.0, 9.0), (1.0, 4.0))
    assertEquals(expected, SquareDist(x1, x2))
  }

  test("test_XY_2D") {

    val x1 = DenseMatrix((1.0, 2.0, 3.0), (4.0, 5.0, 6.0)).t
    val x2 = DenseMatrix((7.0, 8.0), (9.0, 10.0)).t

    val expected = DenseMatrix((61.0, 85.0), (41.0, 61.0), (25.0, 41.0))
    assertEquals(expected, SquareDist(x1, x2))
  }
} 
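As a quick check of the expected matrices: SquareDist appears to treat each row of the (transposed) input as a point, so in test_XY_2D the first points are (1.0, 4.0) and (7.0, 9.0), giving a squared distance of (7-1)^2 + (9-4)^2 = 36 + 25 = 61, which matches expected(0, 0); likewise expected(0, 1) = 85 comes from (1.0, 4.0) against (8.0, 10.0): 49 + 36 = 85.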
Example 140
Source File: PipelineTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import com.tencent.angel.spark.automl.feature.preprocess.{HashingTFWrapper, IDFWrapper, TokenizerWrapper}
import com.tencent.angel.spark.automl.feature.{PipelineBuilder, PipelineWrapper, TransformerWrapper}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class PipelineTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_tfidf") {
    val sentenceData = spark.createDataFrame(Seq(
      (0.0, "Hi I heard about Spark"),
      (0.0, "I wish Java could use case classes"),
      (1.0, "Logistic regression models are neat")
    )).toDF("label", "sentence")

    val pipelineWrapper = new PipelineWrapper()

    val transformers = Array[TransformerWrapper](
      new TokenizerWrapper(),
      new HashingTFWrapper(20),
      new IDFWrapper()
    )

    val stages = PipelineBuilder.build(transformers)

    transformers.foreach { transformer =>
      val inputCols = transformer.getInputCols
      val outputCols = transformer.getOutputCols
      inputCols.foreach(print)
      print("    ")
      outputCols.foreach(print)
      println()
    }

    pipelineWrapper.setStages(stages)

    val model = pipelineWrapper.fit(sentenceData)

    val outputDF = model.transform(sentenceData)
    outputDF.select("outIDF").show()
    outputDF.select("outIDF").foreach { row =>
      println(row.get(0).getClass.getSimpleName)
      val arr = row.get(0)
      println(arr.toString)
    }
    outputDF.rdd.map(row => row.toString()).repartition(1)
      .saveAsTextFile("tmp/output/tfidf")
  }
} 
Example 141
Source File: BreezeOpTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import com.tencent.angel.spark.automl.tuner.math.BreezeOp._
import org.junit.Assert._
import org.scalatest.FunSuite

class BreezeOpTest extends FunSuite {

  test("test cartesian") {

    val a: Array[Double] = Array(1.0, 2.0)
    val b: Array[Double] = Array(3.0, 4.0)
    val c: Array[Array[Double]] = cartesian(a, b)
    val expected: Array[Array[Double]] = Array(Array(1.0, 3.0), Array(1.0, 4.0), Array(2.0, 3.0), Array(2.0, 4.0))

    println(c.deep.mkString("\n"))
    assertEquals(expected.deep.mkString("\n"), c.deep.mkString("\n"))
  }

  test("test_higher_cartesian") {

    val a: Array[Double] = Array(1.0, 2.0)
    val b: Array[Double] = Array(3.0, 4.0)
    val c: Array[Double] = Array(5.0, 6.0)
    val d: Array[Array[Double]] = cartesian(a, b)
    val e: Array[Array[Double]] = cartesian(d, c)
    val expected = Array(Array(1.0, 3.0, 5.0),
      Array(1.0, 3.0, 6.0),
      Array(1.0, 4.0, 5.0),
      Array(1.0, 4.0, 6.0),
      Array(2.0, 3.0, 5.0),
      Array(2.0, 3.0, 6.0),
      Array(2.0, 4.0, 5.0),
      Array(2.0, 4.0, 6.0))

    println(e.deep.mkString("\n"))
    assertEquals(expected.deep.mkString("\n"), e.deep.mkString("\n"))
  }

  test("test_cartesian_array") {

    val a: Array[Double] = Array(1.0, 2.0)
    val b: Array[Double] = Array(3.0, 4.0)
    val c: Array[Double] = Array(5.0, 6.0)
    val d: Array[Double] = Array(7.0, 8.0)
    val allArray = Array(a, b, c, d)
    var tmp: Array[Array[Double]] = cartesian(allArray(0), allArray(1))
    allArray.foreach { case a =>
      if (a != allArray(0) && a != allArray(1)) {
        tmp = cartesian(tmp, a)
      }
    }
    val expected = Array(Array(1.0, 3.0, 5.0, 7.0),
      Array(1.0, 3.0, 5.0, 8.0),
      Array(1.0, 3.0, 6.0, 7.0),
      Array(1.0, 3.0, 6.0, 8.0),
      Array(1.0, 4.0, 5.0, 7.0),
      Array(1.0, 4.0, 5.0, 8.0),
      Array(1.0, 4.0, 6.0, 7.0),
      Array(1.0, 4.0, 6.0, 8.0),
      Array(2.0, 3.0, 5.0, 7.0),
      Array(2.0, 3.0, 5.0, 8.0),
      Array(2.0, 3.0, 6.0, 7.0),
      Array(2.0, 3.0, 6.0, 8.0),
      Array(2.0, 4.0, 5.0, 7.0),
      Array(2.0, 4.0, 5.0, 8.0),
      Array(2.0, 4.0, 6.0, 7.0),
      Array(2.0, 4.0, 6.0, 8.0))

    println(tmp.deep.mkString("\n"))
    assertEquals(expected.deep.mkString("\n"), tmp.deep.mkString("\n"))
  }
} 
Example 142
Source File: TunerTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import com.tencent.angel.spark.automl.tuner.config.Configuration
import com.tencent.angel.spark.automl.tuner.parameter.ParamSpace
import com.tencent.angel.spark.automl.tuner.solver.Solver
import com.tencent.angel.spark.automl.tuner.trail.{TestTrail, Trail}
import org.apache.spark.ml.linalg.Vector
import org.scalatest.FunSuite

class TunerTest extends FunSuite {

  test("test_random") {
    val param1 = ParamSpace.fromConfigString("param1", "{2.0,3.0,4.0,5.0,6.0}")
    val param2 = ParamSpace.fromConfigString("param2", "{3:10:1}")
    val solver: Solver = Solver(Array(param1, param2), true, surrogate = "Random")
    val trail: Trail = new TestTrail()
    (0 until 10).foreach { iter =>
      println(s"------iteration $iter starts------")
      val configs: Array[Configuration] = solver.suggest()
      val results: Array[Double] = trail.evaluate(configs)
      solver.feed(configs, results)
    }
    val result: (Vector, Double) = solver.optimal
    solver.stop
    println(s"Best configuration ${result._1.toArray.mkString(",")}, best performance: ${result._2}")
  }

  test("test_grid") {
    val param1 = ParamSpace.fromConfigString("param1", "[1,10]")
    val param2 = ParamSpace.fromConfigString("param2", "[-5:5:10]")
    val solver: Solver = Solver(Array(param1, param2), true, surrogate = "Grid")
    val trail: Trail = new TestTrail()
    (0 until 10).foreach { iter =>
      println(s"------iteration $iter starts------")
      val configs: Array[Configuration] = solver.suggest()
      val results: Array[Double] = trail.evaluate(configs)
      solver.feed(configs, results)
    }
    val result: (Vector, Double) = solver.optimal
    solver.stop
    println(s"Best configuration ${result._1.toArray.mkString(",")}, best performance: ${result._2}")
  }

  test("test_gp") {
    val param1 = ParamSpace.fromConfigString("param1", "[1,10]")
    val param2 = ParamSpace.fromConfigString("param2", "[-5:5:10]")
    val param3 = ParamSpace.fromConfigString("param3", "{0.0,1.0,3.0,5.0}")
    val param4 = ParamSpace.fromConfigString("param4", "{-5:5:1}")
    val solver: Solver = Solver(Array(param1, param2, param3, param4), true, surrogate = "GaussianProcess")
    val trail: Trail = new TestTrail()
    (0 until 10).foreach { iter =>
      println(s"------iteration $iter starts------")
      val configs: Array[Configuration] = solver.suggest
      val results: Array[Double] = trail.evaluate(configs)
      solver.feed(configs, results)
    }
    val result: (Vector, Double) = solver.optimal
    solver.stop
    println(s"Best configuration ${result._1.toArray.mkString(",")}, best performance: ${result._2}")
  }

  test("test_rf") {
    val param1 = ParamSpace.fromConfigString("param1", "[1,10]")
    val param2 = ParamSpace.fromConfigString("param2", "[-5:5:10]")
    val param3 = ParamSpace.fromConfigString("param3", "{0.0,1.0,3.0,5.0}")
    val param4 = ParamSpace.fromConfigString("param4", "{-5:5:1}")
    val solver: Solver = Solver(Array(param1, param2, param3, param4), true, "RandomForest")
    val trail: Trail = new TestTrail()
    (0 until 10).foreach { iter =>
      println(s"------iteration $iter starts------")
      val configs: Array[Configuration] = solver.suggest
      val results: Array[Double] = trail.evaluate(configs)
      solver.feed(configs, results)
    }
    val result: (Vector, Double) = solver.optimal
    solver.stop
    println(s"Best configuration ${result._1.toArray.mkString(",")}, best performance: ${result._2}")
  }
} 
Example 143
Source File: X2PSuite.scala    From spark-tsne   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.spark.tsne

import org.apache.spark.SharedSparkContext
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.scalatest.{FunSuite, Matchers}


class X2PSuite extends FunSuite with SharedSparkContext with Matchers {

  test("Test X2P against tsne.jl implementation") {
    val input = new RowMatrix(
      sc.parallelize(Seq(1 to 3, 4 to 6, 7 to 9, 10 to 12))
        .map(x => Vectors.dense(x.map(_.toDouble).toArray))
    )
    val output = X2P(input, 1e-5, 2).toRowMatrix().rows.collect().map(_.toArray.toList)
    println(output.toList)
    //output shouldBe List(List(0, .5, .5), List(.5, 0, .5), List(.5, .5, .0))
  }
} 
Example 144
Source File: BugDemonstrationTest.scala    From spark-tsne   with Apache License 2.0 5 votes vote down vote up
package com.github.saurfang.spark.tsne

import org.apache.spark.mllib.linalg.{Vectors, Vector}
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}


class BugDemonstrationTest extends FunSuite with Matchers with BeforeAndAfterAll {
  private var sparkSession : SparkSession = _
  override def beforeAll(): Unit = {
    super.beforeAll()
    sparkSession = SparkSession.builder().appName("BugTests").master("local[2]").getOrCreate()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    sparkSession.stop()
  }

  test("This demonstrates a bug was fixed in tsne-spark 2.1") {
    val sc = sparkSession.sparkContext

    val observations = sc.parallelize(
      Seq(
        Vectors.dense(1.0, 10.0, 100.0),
        Vectors.dense(2.0, 20.0, 200.0),
        Vectors.dense(3.0, 30.0, 300.0)
      )
    )

    // Compute column summary statistics.
    val summary: MultivariateStatisticalSummary = Statistics.colStats(observations)
    val expectedMean = Vectors.dense(2.0,20.0,200.0)
    val resultMean = summary.mean
    assertEqualEnough(resultMean, expectedMean)
    val expectedVariance = Vectors.dense(1.0,100.0,10000.0)
    assertEqualEnough(summary.variance, expectedVariance)
    val expectedNumNonZeros = Vectors.dense(3.0, 3.0, 3.0)
    assertEqualEnough(summary.numNonzeros, expectedNumNonZeros)
  }

  private def assertEqualEnough(sample: Vector, expected: Vector): Unit = {
    expected.toArray.zipWithIndex.foreach{ case(d: Double, i: Int) =>
      sample(i) should be (d +- 1E-12)
    }
  }
} 
Example 145
Source File: RedisSourceConfigSuite.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import org.apache.spark.sql.redis._
import org.scalatest.{FunSuite, Matchers}


class RedisSourceConfigSuite extends FunSuite with Matchers {

  val group: String = "group55"

  test("testFromMap") {
    val config = RedisSourceConfig.fromMap(Map(
      StreamOptionStreamKeys -> "mystream1,mystream2,mystream3",
      StreamOptionStreamOffsets ->
        s"""
          |{
          |  "offsets":{
          |    "mystream1": {
          |      "groupName": "$group",
          |      "offset": "0-10"
          |    },
          |    "mystream2": {
          |       "groupName": "$group",
          |       "offset": "0-7"
          |    }
          |  }
          |}
        """.stripMargin,
      StreamOptionParallelism -> "2",
      StreamOptionGroupName -> group,
      StreamOptionConsumerPrefix -> "consumer"
    ))
    config shouldBe RedisSourceConfig(
      Seq(
        RedisConsumerConfig("mystream1", group, "consumer-1", 100, 500),
        RedisConsumerConfig("mystream1", group, "consumer-2", 100, 500),
        RedisConsumerConfig("mystream2", group, "consumer-1", 100, 500),
        RedisConsumerConfig("mystream2", group, "consumer-2", 100, 500),
        RedisConsumerConfig("mystream3", group, "consumer-1", 100, 500),
        RedisConsumerConfig("mystream3", group, "consumer-2", 100, 500)
      ),
      Some(RedisSourceOffset(Map(
        "mystream1" -> RedisConsumerOffset(group, "0-10"),
        "mystream2" -> RedisConsumerOffset(group, "0-7")
      )))
    )
  }
} 
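The six consumer configs in the expected value follow directly from the options: three stream keys combined with StreamOptionParallelism = 2 yield two consumers per stream, named with the configured "consumer" prefix; the two trailing numeric fields (100 and 500) are whatever defaults RedisSourceConfig applies when the corresponding options are not supplied.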
Example 146
Source File: RedisSourceTest.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import org.scalatest.{FunSuite, Matchers, OptionValues}


class RedisSourceTest extends FunSuite with Matchers with OptionValues {

  test("testGetOffsetRanges") {
    val startOffsets = RedisSourceOffset(Map("mystream" -> RedisConsumerOffset("group55", "0-0")))
    val endOffsets = RedisSourceOffset(Map("mystream" -> RedisConsumerOffset("group55", "0-1")))
    val consumerConfig = RedisConsumerConfig("mystream", "group55", "consumer", 1000, 100)
    val consumerConfigs = Seq(consumerConfig)
    val offsetRanges = RedisSource.getOffsetRanges(Some(startOffsets), endOffsets, consumerConfigs)
    offsetRanges.head shouldBe RedisSourceOffsetRange(Some("0-0"), "0-1", consumerConfig)
  }
} 
Example 147
Source File: RedisConsumerOffsetTest.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package org.apache.spark.sql.redis.stream

import org.scalatest.{FunSuite, Matchers}


class RedisConsumerOffsetTest extends FunSuite with Matchers {

  test("testFromJson") {
    val offset = RedisSourceOffset.fromJson(
      """
        |{
        |  "offsets":{
        |    "mystream": {
        |      "groupName": "group55",
        |      "offset": "1543674099961-0"
        |    }
        |  }
        |}
        |""".stripMargin)
    offset shouldBe RedisSourceOffset(Map("mystream" ->
      RedisConsumerOffset("group55", "1543674099961-0")))
  }
} 
Example 148
Source File: RedisConfigSuite.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis

import org.scalatest.{FunSuite, Matchers}
import redis.clients.jedis.util.JedisClusterCRC16

class RedisConfigSuite extends FunSuite with Matchers {

  val redisStandaloneConfig = new RedisConfig(RedisEndpoint("127.0.0.1", 6379, "passwd"))
  val redisClusterConfig = new RedisConfig(RedisEndpoint("127.0.0.1", 7379))

  test("getNodesBySlots") {
    redisStandaloneConfig.getNodesBySlots(0, 16383).length shouldBe 1
    redisClusterConfig.getNodesBySlots(0, 16383).length shouldBe 7
  }

  test("getHost") {
    val key = "getHost"
    val slot = JedisClusterCRC16.getSlot(key)
    val standaloneHost = redisStandaloneConfig.getHost(key)
    assert(standaloneHost.startSlot <= slot && standaloneHost.endSlot >= slot)
    val clusterHost = redisClusterConfig.getHost(key)
    assert(clusterHost.startSlot <= slot && clusterHost.endSlot >= slot)
  }

  test("getNodes") {
    redisStandaloneConfig.getNodes(RedisEndpoint("127.0.0.1", 6379, "passwd")).length shouldBe 1
    redisClusterConfig.getNodes(RedisEndpoint("127.0.0.1", 7379)).length shouldBe 7
  }
} 
Example 149
Source File: SparkStreamingRedisSuite.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis

import com.redislabs.provider.redis.env.Env
import com.redislabs.provider.redis.util.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.scalatest.{BeforeAndAfterEach, FunSuite}


trait SparkStreamingRedisSuite extends FunSuite with Env with BeforeAndAfterEach with Logging {

  override protected def beforeEach(): Unit = {
    super.beforeEach()
    spark = SparkSession.builder().config(conf).getOrCreate()
    sc = spark.sparkContext
    ssc = new StreamingContext(sc, Seconds(1))
  }

  override protected def afterEach(): Unit = {
    ssc.stop()
    spark.stop
    System.clearProperty("spark.driver.port")
    super.afterEach()
  }

} 
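Since the trait above only fixes the test lifecycle and leaves the environment abstract, a concrete suite has to mix in one of the Env implementations used in this project (RedisStandaloneEnv appears in the later examples). A hypothetical usage sketch:

package com.redislabs.provider.redis

import com.redislabs.provider.redis.env.RedisStandaloneEnv

// hypothetical suite: combines the streaming trait with a concrete environment,
// so beforeEach/afterEach create and tear down the SparkSession and StreamingContext
class ExampleStreamingSuite extends SparkStreamingRedisSuite with RedisStandaloneEnv {

  test("a fresh streaming context is available in each test") {
    assert(ssc != null)
    assert(!sc.isStopped)
  }
}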
Example 150
Source File: ConnectionSSLUtilsTest.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis.util

import com.redislabs.provider.redis.env.RedisStandaloneSSLEnv
import com.redislabs.provider.redis.util.ConnectionUtils.{JedisExt, XINFO}
import org.scalatest.{FunSuite, Matchers}
import redis.clients.jedis.StreamEntryID

import scala.collection.JavaConverters._


class ConnectionSSLUtilsTest extends FunSuite with Matchers with RedisStandaloneSSLEnv {

  test("xinfo") {
    val streamKey = TestUtils.generateRandomKey()
    val conn = redisConfig.connectionForKey(streamKey)
    val data = Map("key" -> "value").asJava
    val entryId = conn.xadd(streamKey, new StreamEntryID(0, 1), data)
    val info = conn.xinfo(XINFO.SubCommandStream, streamKey)
    info.get(XINFO.LastGeneratedId) shouldBe Some(entryId.toString)
  }
} 
Example 151
Source File: ConnectionUtilsTest.scala    From spark-redis   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.redislabs.provider.redis.util

import com.redislabs.provider.redis.env.RedisStandaloneEnv
import com.redislabs.provider.redis.util.ConnectionUtils.{JedisExt, XINFO}
import org.scalatest.{FunSuite, Matchers}
import redis.clients.jedis.StreamEntryID

import scala.collection.JavaConverters._


class ConnectionUtilsTest extends FunSuite with Matchers with RedisStandaloneEnv {

  test("xinfo") {
    val streamKey = TestUtils.generateRandomKey()
    val conn = redisConfig.connectionForKey(streamKey)
    val data = Map("key" -> "value").asJava
    val entryId = conn.xadd(streamKey, new StreamEntryID(0, 1), data)
    val info = conn.xinfo(XINFO.SubCommandStream, streamKey)
    info.get(XINFO.LastGeneratedId) shouldBe Some(entryId.toString)
  }
} 
Example 152
Source File: Tests.scala    From spark-es   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.elasticsearch

import org.elasticsearch.common.settings.Settings
import org.scalatest.FunSuite

class Tests extends FunSuite with SparkSuite with ElasticSearchSuite {
  test("Reads documents from multiple shards") {
    val client = es.client

    val indexName = "index-with-multiple-shards"

    client.admin().indices().prepareCreate(indexName)
      .setSettings(Settings.settingsBuilder()
        .put("index.number_of_replicas", 0)
        .put("index.number_of_shards", 2)
        .build()
      )
      .get()

    for (i <- 1 to 1000) {
      client.prepareIndex(indexName, "foo", i.toString).setSource("{}").get()
    }

    client.admin().cluster().prepareHealth(indexName).setWaitForGreenStatus().get()
    client.admin().indices().prepareRefresh(indexName).get()

    val rdd = sparkContext.esRDD(Seq("localhost"), es.clusterName, Seq(indexName), Seq("foo"), "*")

    assert(rdd.partitions.length == 2)
    assert(rdd.collect().map(_.metadata.id).sorted.toList == (1 to 1000).map(_.toString).sorted.toList)
  }

  test("Writes documents to ElasticSearch") {
    val client = es.client

    val indexName = "index1"

    sparkContext.parallelize(Seq(1, 2, 3, 4))
      .map(id => ESDocument(ESMetadata(id.toString, "foo", indexName), "{}"))
      .saveToES(Seq("localhost"), es.clusterName)

    client.admin().cluster().prepareHealth(indexName).setWaitForGreenStatus().get()
    client.admin().indices().prepareRefresh(indexName).get()

    assert(client.prepareGet(indexName, "foo", "1").get().isExists)
    assert(client.prepareGet(indexName, "foo", "2").get().isExists)
    assert(client.prepareGet(indexName, "foo", "3").get().isExists)
    assert(client.prepareGet(indexName, "foo", "4").get().isExists)
  }
} 
Example 153
Source File: SparkFunSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark

// scalastyle:off
import org.apache.log4j.{Level, Logger}
import org.scalatest.{FunSuite, Outcome}

import org.apache.spark.Logging

// Note: the enclosing declaration was dropped from this snippet; SparkFunSuite is a FunSuite mixin with Spark's Logging.
trait SparkFunSuite extends FunSuite with Logging {
  final protected override def withFixture(test: NoArgTest): Outcome = {
    val testName = test.text
    val suiteName = this.getClass.getName
    val shortSuiteName = suiteName.replaceAll("org.apache.spark", "o.a.s")
    try {
      Logger.getLogger("org").setLevel(Level.OFF)
      Logger.getLogger("akka").setLevel(Level.OFF)

      logInfo(s"\n\n===== TEST OUTPUT FOR $shortSuiteName: '$testName' =====\n")
      test()
    } finally {
      logInfo(s"\n\n===== FINISHED $shortSuiteName: '$testName' =====\n")
    }
  }

} 
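A minimal usage sketch (not part of gihyo-spark-book-example; the suite name and test body are hypothetical): a concrete suite mixes in the trait above, and withFixture then brackets every test with the START/FINISHED log banners while the "org" and "akka" loggers stay silenced.

class SampleSuite extends SparkFunSuite {
  test("a simple arithmetic check") {
    // The body runs inside withFixture, so the log banners appear around it.
    assert(1 + 1 === 2)
  }
}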
Example 154
Source File: PurchaseLogGeneratorSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch08

import org.scalatest.FunSuite

class PurchaseLogGeneratorSuite extends FunSuite {

  test("generate products, users, purchaseLog") {
    val numProducts = 10
    val numUsers = 10
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions = RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    val products = PurchaseLogGenerator.genProductList
    assert(products.size === numProducts)
  }

  test("generate user list") {
    val numProducts = 10
    val numUsers = 10
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions = RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    val users = PurchaseLogGenerator.genUserList
    assert(users.size === numUsers)
  }

  test("generate purchaseLog with RandomSelection") {
    val numProducts = 10
    val numUsers = 10
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions = RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    implicit val pidGenerator = ProductIdGenerator.fromString("RandomSelection")

    val users = PurchaseLogGenerator.genUserList
    val purchaseLog = PurchaseLogGenerator.genPurchaseLog(users)

    assert(purchaseLog.size === numUsers * numProductsPerUser)
    assert(purchaseLog.groupBy(_.uid).size === numUsers)
  }

  test("generate purchaseLog with PreferentialAttachment") {
    val numProducts = 10
    val numUsers = 10
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions = RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    implicit val pidGenerator = ProductIdGenerator.fromString("PreferentialAttachment")

    val users = PurchaseLogGenerator.genUserList
    val purchaseLog = PurchaseLogGenerator.genPurchaseLog(users)

    assert(purchaseLog.size === numUsers * numProductsPerUser)
    assert(purchaseLog.groupBy(_.uid).size === numUsers)
  }
} 
Example 155
Source File: ProductIdGeneratorSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch08

import org.scalatest.FunSuite

class ProductIdGeneratorSuite extends FunSuite {

  test("get next productId by RandomSelection") {
    val numProducts = 5
    val numUsers = 5
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions =
      RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    val purchaseLog = List(
      Purchase(6L, 1L),
      Purchase(6L, 2L),
      Purchase(7L, 3L),
      Purchase(7L, 4L),
      Purchase(8L, 5L),
      Purchase(8L, 1L),
      Purchase(9L, 2L),
      Purchase(9L, 3L),
      Purchase(10L, 4L),
      Purchase(10L, 5L)
    )
    (1 to 10).foreach( i => {
      val pid = ProductIdGenerator.RandomSelection.getNextPid(recOpts, purchaseLog)
      assert(0 <= pid && pid <= numProducts)
    })
  }

  test("get next productId by PreferentialAttachment") {
    val numProducts = 5
    val numUsers = 5
    val numProductsPerUser = 2
    implicit val recOpts: RecommendLogOptions = RecommendLogOptions(numProducts, numUsers, numProductsPerUser)
    val purchaseLog = List(
      Purchase(6L, 1L),
      Purchase(6L, 2L),
      Purchase(7L, 3L),
      Purchase(7L, 4L),
      Purchase(8L, 5L),
      Purchase(8L, 1L),
      Purchase(9L, 2L),
      Purchase(9L, 3L),
      Purchase(10L, 4L),
      Purchase(10L, 5L)
    )
    (1 to 10).foreach( i => {
      val pid = ProductIdGenerator.PreferentialAttachment.getNextPid(recOpts, purchaseLog)
      assert(0 <= pid && pid <= numProducts)
    })
  }

  test("get ProductIdGenerator from string") {
    assert(ProductIdGenerator.RandomSelection === ProductIdGenerator.fromString("RandomSelection"))
    assert(ProductIdGenerator.PreferentialAttachment === ProductIdGenerator.fromString("PreferentialAttachment"))
    assert(ProductIdGenerator.RandomSelection === ProductIdGenerator.fromString("hoge"))
  }

} 
Example 156
Source File: TripSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

import org.scalatest.FunSuite

class TripSuite extends FunSuite {

  test("should be parsed") {
    val line = "911926,566,8/31/2015 8:20,Harry Bridges Plaza (Ferry Building)," +
      "50,8/31/2015 8:30,Post at Kearny,47,566,Subscriber,95442"
    val trip = Trip.parse(line)
    assert(trip.id === 911926)
    assert(trip.duration === 566)
    assert(trip.zipcode === "95442")
  }
} 
Example 157
Source File: StationSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

import java.sql.Timestamp
import java.text.SimpleDateFormat

import org.scalatest.FunSuite

class StationSuite extends FunSuite {

  test("should be parse") {
    val line = "2,San Jose Diridon Caltrain Station,37.329732,-121.901782,27,San Jose,8/6/2013"
    val station = Station.parse(line)

    val dateFormat = new SimpleDateFormat("MM/dd/yyy")
    assert(station.id === 2)
    assert(station.name === "San Jose Diridon Caltrain Station")
    assert(station.lat === 37.329732)
    assert(station.lon === -121.901782)
    assert(station.dockcount === 27)
    assert(station.landmark === "San Jose")
    assert(station.installation === new Timestamp(dateFormat.parse("8/6/2013").getTime))
  }
} 
Example 158
Source File: MeanAveragePrecisionSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.evaluation

import breeze.linalg.DenseVector
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import keystoneml.utils.Stats
import keystoneml.workflow.PipelineContext

class MeanAveragePrecisionSuite extends FunSuite with PipelineContext {

  test("random map test") {
    sc = new SparkContext("local", "test")

    // Build some random test data with 4 classes 0,1,2,3
    val actual = List(Array(0, 3), Array(2), Array(1, 2), Array(0))
    val actualRdd = sc.parallelize(actual)

    val predicted = List(
      DenseVector(0.1, -0.05, 0.12, 0.5),
      DenseVector(-0.23, -0.45, 0.23, 0.1),
      DenseVector(-0.34, -0.32, -0.66, 1.52),
      DenseVector(-0.1, -0.2, 0.5, 0.8))

    val predictedRdd = sc.parallelize(predicted)

    val map = new MeanAveragePrecisionEvaluator(4).evaluate(predictedRdd, actualRdd)

    // Expected values from running this in MATLAB
    val expected = DenseVector(1.0, 0.3333, 0.5, 0.3333)

    assert(Stats.aboutEq(map, expected, 1e-4))
  }
} 
Example 159
Source File: MulticlassClassifierEvaluatorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.evaluation

import breeze.linalg.DenseMatrix
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.workflow.PipelineContext

class MulticlassClassifierEvaluatorSuite extends FunSuite with PipelineContext {
  test("Multiclass keystoneml.evaluation metrics") {
    
    sc = new SparkContext("local", "test")
    val confusionMatrix = new DenseMatrix(3, 3, Array(2, 1, 0, 1, 3, 0, 1, 0, 1))
    val labels = Array(0.0, 1.0, 2.0)
    val predictionAndLabels = sc.parallelize(
      Seq((0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0), (1.0, 1.0),
        (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)), 2)
    val evaluator = new MulticlassClassifierEvaluator(3)
    val metrics = evaluator.evaluate(predictionAndLabels.map(_._1.toInt), predictionAndLabels.map(_._2.toInt)
    )
    val delta = 0.0000001
    val precision0 = 2.0 / (2 + 1)
    val precision1 = 3.0 / (3 + 1)
    val precision2 = 1.0 / (1 + 1)
    val recall0 = 2.0 / (2 + 2)
    val recall1 = 3.0 / (3 + 1)
    val recall2 = 1.0 / (1 + 0)
    val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
    val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
    val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
    val f2measure0 = (1 + 2 * 2) * precision0 * recall0 / (2 * 2 * precision0 + recall0)
    val f2measure1 = (1 + 2 * 2) * precision1 * recall1 / (2 * 2 * precision1 + recall1)
    val f2measure2 = (1 + 2 * 2) * precision2 * recall2 / (2 * 2 * precision2 + recall2)

    assert(metrics.confusionMatrix.toArray.sameElements(confusionMatrix.toArray))
    assert(math.abs(metrics.classMetrics(0).precision - precision0) < delta)
    assert(math.abs(metrics.classMetrics(1).precision - precision1) < delta)
    assert(math.abs(metrics.classMetrics(2).precision - precision2) < delta)
    assert(math.abs(metrics.classMetrics(0).recall - recall0) < delta)
    assert(math.abs(metrics.classMetrics(1).recall - recall1) < delta)
    assert(math.abs(metrics.classMetrics(2).recall - recall2) < delta)
    assert(math.abs(metrics.classMetrics(0).fScore() - f1measure0) < delta)
    assert(math.abs(metrics.classMetrics(1).fScore() - f1measure1) < delta)
    assert(math.abs(metrics.classMetrics(2).fScore() - f1measure2) < delta)
    assert(math.abs(metrics.classMetrics(0).fScore(2.0) - f2measure0) < delta)
    assert(math.abs(metrics.classMetrics(1).fScore(2.0) - f2measure1) < delta)
    assert(math.abs(metrics.classMetrics(2).fScore(2.0) - f2measure2) < delta)

    assert(math.abs(metrics.microRecall -
        (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
    assert(math.abs(metrics.microRecall - metrics.microPrecision) < delta)
    assert(math.abs(metrics.microRecall - metrics.microFScore()) < delta)
    assert(math.abs(metrics.macroPrecision -
        (precision0 + precision1 + precision2) / 3.0) < delta)
    assert(math.abs(metrics.macroRecall -
        (recall0 + recall1 + recall2) / 3.0) < delta)
    assert(math.abs(metrics.macroFScore() -
        (f1measure0 + f1measure1 + f1measure2) / 3.0) < delta)
    assert(math.abs(metrics.macroFScore(2.0) -
        (f2measure0 + f2measure1 + f2measure2) / 3.0) < delta)
  }
} 
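The expected values above are all instances of the same F-beta formula; the following stand-alone helper (an illustrative sketch, not the KeystoneML API) makes the arithmetic explicit.

// F-beta score from precision and recall: (1 + beta^2) * p * r / (beta^2 * p + r).
def fBeta(p: Double, r: Double, beta: Double = 1.0): Double = {
  val b2 = beta * beta
  (1 + b2) * p * r / (b2 * p + r)
}

// For example, fBeta(precision0, recall0) equals f1measure0 and
// fBeta(precision0, recall0, beta = 2.0) equals f2measure0 above.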
Example 160
Source File: BinaryClassifierEvaluatorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.evaluation

import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.utils.Stats
import keystoneml.workflow.PipelineContext

class BinaryClassifierEvaluatorSuite extends FunSuite with PipelineContext {
  test("Multiclass keystoneml.evaluation metrics") {
    
    sc = new SparkContext("local", "test")

    val predictionAndLabels = sc.parallelize( Seq.fill(6)((true, true)) ++ Seq.fill(2)((false, true))
        ++ Seq.fill(1)((true, false)) ++ Seq.fill(3)((false, false)), 2)
    val metrics = BinaryClassifierEvaluator.evaluate(predictionAndLabels.map(_._1), predictionAndLabels.map(_._2))

    assert(metrics.tp === 6)
    assert(metrics.fp === 1)
    assert(metrics.tn === 3)
    assert(metrics.fn === 2)

    assert(Stats.aboutEq(metrics.precision, 6.0/7.0))
    assert(Stats.aboutEq(metrics.recall, 6.0/8.0))
    assert(Stats.aboutEq(metrics.accuracy, 9.0/12.0))
    assert(Stats.aboutEq(metrics.specificity, 3.0/4.0))
    assert(Stats.aboutEq(metrics.fScore(), 2.0 * 6.0 / (2.0 * 6.0 + 2.0 + 1.0)))
  }
} 
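For reference, the fractions asserted above come straight from the four counts; a small sketch of those formulas (illustrative only, not a KeystoneML type):

case class BinaryCounts(tp: Int, fp: Int, tn: Int, fn: Int) {
  def precision: Double = tp.toDouble / (tp + fp)                  // 6.0 / 7.0 above
  def recall: Double = tp.toDouble / (tp + fn)                     // 6.0 / 8.0 above
  def accuracy: Double = (tp + tn).toDouble / (tp + fp + tn + fn)  // 9.0 / 12.0 above
  def specificity: Double = tn.toDouble / (tn + fp)                // 3.0 / 4.0 above
}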
Example 161
Source File: MLlibUtilsSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils

import org.apache.spark.mllib.linalg._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
import org.scalatest.FunSuite

class MLlibUtilsSuite extends FunSuite {
  val arr = Array(0.1, 0.2, 0.3, 0.4)
  val n = 20
  val indices = Array(0, 3, 5, 10, 13)
  val values = Array(0.1, 0.5, 0.3, -0.8, -1.0)

  test("dense vector to breeze dense") {
    val vec = Vectors.dense(arr)
    assert(MLlibUtils.mllibVectorToDenseBreeze(vec) === new BDV[Double](arr))
  }

  test("sparse vector to breeze dense") {
    val vec = Vectors.sparse(n, indices, values)
    val breeze = new BDV[Double](n)
    indices.zip(values).foreach { case (x, y) =>
      breeze(x) = y
    }
    assert(MLlibUtils.mllibVectorToDenseBreeze(vec) === breeze)
  }

  test("dense breeze to vector") {
    val breeze = new BDV[Double](arr)
    val vec = MLlibUtils.breezeVectorToMLlib(breeze).asInstanceOf[DenseVector]
    assert(vec.size === arr.length)
    assert(vec.values.eq(arr), "should not copy data")
  }

  test("sparse breeze to vector") {
    val breeze = new BSV[Double](indices, values, n)
    val vec = MLlibUtils.breezeVectorToMLlib(breeze).asInstanceOf[SparseVector]
    assert(vec.size === n)
    assert(vec.indices.eq(indices), "should not copy data")
    assert(vec.values.eq(values), "should not copy data")
  }

  test("sparse breeze with partially-used arrays to vector") {
    val activeSize = 3
    val breeze = new BSV[Double](indices, values, activeSize, n)
    val vec = MLlibUtils.breezeVectorToMLlib(breeze).asInstanceOf[SparseVector]
    assert(vec.size === n)
    assert(vec.indices === indices.slice(0, activeSize))
    assert(vec.values === values.slice(0, activeSize))
  }

  test("dense matrix to breeze dense") {
    val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0))
    val breeze = MLlibUtils.mllibMatrixToDenseBreeze(mat)
    assert(breeze.rows === mat.numRows)
    assert(breeze.cols === mat.numCols)
    assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data")
  }

  test("sparse matrix to breeze dense") {
    val values = Array(1.0, 2.0, 4.0, 5.0)
    val colPtrs = Array(0, 2, 4)
    val rowIndices = Array(1, 2, 1, 2)
    val mat = Matrices.sparse(3, 2, colPtrs, rowIndices, values)
    val breeze = MLlibUtils.mllibMatrixToDenseBreeze(mat)
    assert(breeze.rows === mat.numRows)
    assert(breeze.cols === mat.numCols)
    assert(breeze.toArray === mat.toArray)
  }
} 
Example 162
Source File: MatrixUtilsSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils

import org.scalatest.FunSuite

import breeze.linalg._
import breeze.stats._

import org.apache.spark.SparkContext

import keystoneml.pipelines._
import keystoneml.workflow.PipelineContext

class MatrixUtilsSuite extends FunSuite with PipelineContext {

  test("computeMean works correctly") {
    val numRows = 1000
    val numCols = 32
    val numParts = 4
    sc = new SparkContext("local", "test")
    val in = DenseMatrix.rand(numRows, numCols)
    val inArr = MatrixUtils.matrixToRowArray(in)
    val rdd = sc.parallelize(inArr, numParts).mapPartitions { iter => 
      Iterator.single(MatrixUtils.rowsToMatrix(iter))
    }
    val expected = mean(in(::, *)).t
    val actual = MatrixUtils.computeMean(rdd)
    assert(Stats.aboutEq(expected, actual, 1e-6))
  }

} 
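The test above expects MatrixUtils.computeMean to return the column-wise mean over all rows of the partitioned matrices, i.e. the same quantity as the `expected` value it builds with Breeze. A local, Spark-free reference computation of that quantity (a sketch, not the KeystoneML implementation) is:

import breeze.linalg._

// Column-wise mean across row-partitioned matrices: total column sums / total row count.
def columnMean(parts: Seq[DenseMatrix[Double]]): DenseVector[Double] = {
  val totalRows = parts.map(_.rows).sum
  val columnSums = parts.map(m => sum(m(::, *)).t).reduce(_ + _)
  columnSums / totalRows.toDouble
}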
Example 163
Source File: ImageUtilsSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils

import org.scalatest.FunSuite

class ImageUtilsSuite extends FunSuite {

  test("crop") {
    val imgArr =
      (0 until 4).flatMap { x =>
        (0 until 4).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1).toDouble
          }
        }
      }.toArray

    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(4, 4, 1))
    val cropped = ImageUtils.crop(image, 1, 1, 3, 3)

    assert(cropped.metadata.xDim == 2)
    assert(cropped.metadata.yDim == 2)
    assert(cropped.metadata.numChannels == 1)

    assert(cropped.get(0, 0, 0) == 5.0)
    assert(cropped.get(0, 1, 0) == 6.0)
    assert(cropped.get(1, 0, 0) == 9.0)
    assert(cropped.get(1, 1, 0) == 10.0)
  }

  test("flipHorizontal") {
    val imgArr =
      (0 until 4).flatMap { x =>
        (0 until 4).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1).toDouble
          }
        }
      }.toArray

    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(4, 4, 1))

    val flipped = ImageUtils.flipHorizontal(image)

    assert(flipped.metadata.xDim == 4)
    assert(flipped.metadata.yDim == 4)
    assert(flipped.metadata.numChannels == 1)

    (0 until 4).foreach { x =>
      assert(flipped.get(x, 0, 0) == image.get(x, 3, 0))
      assert(flipped.get(x, 1, 0) == image.get(x, 2, 0))
      assert(flipped.get(x, 2, 0) == image.get(x, 1, 0))
      assert(flipped.get(x, 3, 0) == image.get(x, 0, 0))
    }
  }

} 
Example 164
Source File: ImageSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils.images

import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.VectorizedImage
import keystoneml.utils.TestUtils._

class ImageSuite extends FunSuite with Logging {
  test("Vectorized Image Coordinates Should be Correct") {
    val (x,y,z) = (100,100,3)

    val images = Array[VectorizedImage](
      genChannelMajorArrayVectorizedImage(x,y,z),
      genColumnMajorArrayVectorizedImage(x,y,z),
      genRowMajorArrayVectorizedImage(x,y,z),
      genRowColumnMajorByteArrayVectorizedImage(x,y,z)
    )

    for (
      img <- images;
      idx <- 0 until x*y*z
    ) {
      val coord = img.vectorToImageCoords(idx)
      assert(img.imageToVectorCoords(coord.x,coord.y,coord.channelIdx) == idx,
        s"imageToVectorCoords(vectorToImageCoords(idx)) should be equivalent to identity(idx) for img $img")
    }

    for (
      img <- images;
      xi <- 0 until x;
      yi <- 0 until y;
      zi <- 0 until z
    ) {
      val coord = img.vectorToImageCoords(img.imageToVectorCoords(xi,yi,zi))
      assert((coord.x, coord.y, coord.channelIdx) == (xi,yi,zi),
        s"vectorToImageCoords(imageToVectorCoords(x,y,z)) should be equivalent to identity(x,y,z) for img $img")
    }
  }
} 
Example 165
Source File: VLFeatSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils.external

import java.io.File

import breeze.linalg._
import breeze.numerics.abs
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{ImageUtils, MatrixUtils, TestUtils}

class VLFeatSuite extends FunSuite with Logging {
  test("Load an Image and compute SIFT Features") {
    val testImage = TestUtils.loadTestImage("images/000012.jpg")
    val singleImage = ImageUtils.mapPixels(testImage, _/255.0)
    val grayImage = ImageUtils.toGrayScale(singleImage)

    val extLib = new VLFeat

    val stepSize = 3
    val binSize = 4
    val scales = 4
    val descriptorLength = 128
    val scaleStep = 0

    val rawDescDataShort = extLib.getSIFTs(grayImage.metadata.xDim, grayImage.metadata.yDim,
      stepSize, binSize, scales, scaleStep, grayImage.getSingleChannelAsFloatArray())

    assert(rawDescDataShort.length % descriptorLength == 0, "Resulting SIFTs must be 128-dimensional.")

    val numCols = rawDescDataShort.length/descriptorLength
    val result = new DenseMatrix(descriptorLength, numCols, rawDescDataShort.map(_.toDouble))

    // Compare with the output of running this image through vl_phow with matlab from the enceval package:
    // featpipem_addpaths;
    // im = im2single(imread('images/000012.jpg'));
    // featextr = featpipem.features.PhowExtractor();
    // featextr.step = 3;
    // [frames feats] = featextr.compute(im);
    // csvwrite('images/feats128.csv', feats)

    val testFeatures = csvread(new File(TestUtils.getTestResourceFileName("images/feats128.csv")))

    val diff = result - testFeatures

    // Because of subtle differences in the way image smoothing works in the VLFeat C library and the VLFeat matlab
    // library (vl_imsmooth_f vs. _vl_imsmooth_f), these two matrices will not be exactly the same.
    // Instead, we check that 99.5% of the matrix entries are off by at most 1.
    val absdiff = abs(diff).toDenseVector

    assert(absdiff.findAll(_ > 1.0).length.toDouble < 0.005*absdiff.length,
      "Fewer than 0.05% of entries may be different by more than 1.")
  }
} 
Example 166
Source File: EncEvalSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.utils.external

import java.io.File

import breeze.linalg._
import breeze.stats.distributions.Gaussian
import keystoneml.nodes.learning.GaussianMixtureModel
import keystoneml.nodes.learning.external.GaussianMixtureModelEstimator
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{Stats, TestUtils}

class EncEvalSuite extends FunSuite with Logging {

  test("Load SIFT Descriptors and compute Fisher Vector Features") {

    val siftDescriptor = csvread(new File(TestUtils.getTestResourceFileName("images/feats.csv")))

    val gmmMeans = TestUtils.getTestResourceFileName("images/voc_codebook/means.csv")
    val gmmVars = TestUtils.getTestResourceFileName("images/voc_codebook/variances.csv")
    val gmmWeights = TestUtils.getTestResourceFileName("images/voc_codebook/priors")

    val gmm = GaussianMixtureModel.load(gmmMeans, gmmVars, gmmWeights)

    val nCenters = gmm.means.cols
    val nDim = gmm.means.rows

    val extLib = new EncEval

    val fisherVector = extLib.calcAndGetFVs(
      gmm.means.toArray.map(_.toFloat),
      nCenters,
      nDim,
      gmm.variances.toArray.map(_.toFloat),
      gmm.weights.toArray.map(_.toFloat),
      siftDescriptor.toArray.map(_.toFloat))

    log.info(s"Fisher Vector is ${fisherVector.sum}")
    assert(Stats.aboutEq(fisherVector.sum, 40.109097, 1e-4), "SUM of Fisher Vectors must match expected sum.")

  }

  test("Compute a GMM from scala") {
    val nsamps = 10000

    // Generate two gaussians.
    val x = Gaussian(-1.0, 0.5).samples.take(nsamps).toArray
    val y = Gaussian(5.0, 1.0).samples.take(nsamps).toArray

    val z = shuffle(x ++ y).map(x => DenseVector(x))

    // Compute a 1-d GMM.
    val extLib = new EncEval
    val gmm = new GaussianMixtureModelEstimator(2).fit(z)

    logInfo(s"GMM means: ${gmm.means.toArray.mkString(",")}")
    logInfo(s"GMM vars: ${gmm.variances.toArray.mkString(",")}")
    logInfo(s"GMM weights: ${gmm.weights.toArray.mkString(",")}")

    // The results should be close to the distribution we set up.
    assert(Stats.aboutEq(min(gmm.means), -1.0, 1e-1), "Smallest mean should be close to -1.0")
    assert(Stats.aboutEq(max(gmm.means), 5.0, 1e-1), "Largest mean should be close to 5.0")
    assert(Stats.aboutEq(math.sqrt(min(gmm.variances)), 0.5, 1e-1), "Smallest SD should be close to 0.5")
    assert(Stats.aboutEq(math.sqrt(max(gmm.variances)), 1.0, 1e-1), "Largest SD should be close to 1.0")
  }
} 
Example 167
Source File: EstimatorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.workflow

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging

class EstimatorSuite extends FunSuite with PipelineContext with Logging {
  test("Estimator fit RDD") {
    sc = new SparkContext("local", "test")

    val intEstimator = new Estimator[Int, Int] {
      def fit(data: RDD[Int]): Transformer[Int, Int] = {
        val first = data.first()
        Transformer(x => x + first)
      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(trainData)
    assert(pipeline.apply(testData).get().collect().toSeq === Seq(42 + 32, 58 + 32, 61 + 32))
  }

  test("Estimator fit Pipeline Data") {
    sc = new SparkContext("local", "test")

    val transformer = Transformer[Int, Int](_ * 2)

    val intEstimator = new Estimator[Int, Int] {
      def fit(data: RDD[Int]): Transformer[Int, Int] = {
        val first = data.first()
        Transformer(x => x + first)
      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(transformer(trainData))
    assert(pipeline.apply(testData).get().collect().toSeq === Seq(42 + 64, 58 + 64, 61 + 64))
  }

} 
Example 168
Source File: LabelEstimatorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.workflow

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging

class LabelEstimatorSuite extends FunSuite with PipelineContext with Logging {
  test("LabelEstimator fit RDD") {
    sc = new SparkContext("local", "test")

    val intEstimator = new LabelEstimator[Int, Int, String] {
      def fit(data: RDD[Int], labels: RDD[String]): Transformer[Int, Int] = {
        val first = data.first()
        val label = labels.first().hashCode
        Transformer(x => x + first + label)

      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val trainLabels = sc.parallelize(Seq("sjkfdl", "iw", "432"))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(trainData, trainLabels)
    val offset = 32 + "sjkfdl".hashCode
    assert(pipeline.apply(testData).get().collect().toSeq === Seq(42 + offset, 58 + offset, 61 + offset))
  }

  test("LabelEstimator fit pipeline data") {
    sc = new SparkContext("local", "test")

    val dataTransformer = Transformer[Int, Int](_ * 2)
    val labelTransformer = Transformer[String, String](_ + "hi")

    val intEstimator = new LabelEstimator[Int, Int, String] {
      def fit(data: RDD[Int], labels: RDD[String]): Transformer[Int, Int] = {
        val first = data.first()
        val label = labels.first().hashCode
        Transformer(x => x + first + label)

      }
    }

    val trainData = sc.parallelize(Seq(32, 94, 12))
    val trainLabels = sc.parallelize(Seq("sjkfdl", "iw", "432"))
    val testData = sc.parallelize(Seq(42, 58, 61))

    val pipeline = intEstimator.withData(dataTransformer(trainData), labelTransformer(trainLabels))
    val offset = 64 + "sjkfdlhi".hashCode
    assert(pipeline.apply(testData).get().collect().toSeq === Seq(42 + offset, 58 + offset, 61 + offset))
  }
} 
Example 169
Source File: KMeansPlusPlusSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg._
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines._
import keystoneml.utils.{MatrixUtils, Stats}
import keystoneml.workflow.PipelineContext

class KMeansPlusPlusSuite extends FunSuite with PipelineContext with Logging {

  test("K-Means++ Single Center") {
    sc = new SparkContext("local", "test")

    val k = 1

    val data = sc.parallelize(Array(
      DenseVector[Double](1.0, 2.0, 6.0),
      DenseVector[Double](1.0, 3.0, 0.0),
      DenseVector[Double](1.0, 4.0, 6.0)
    ))

    val center = DenseVector[Double](1.0, 3.0, 4.0).asDenseMatrix

    val kMeans = KMeansPlusPlusEstimator(k, maxIterations = 1).fit(data)
    assert(Stats.aboutEq(kMeans.means, center))

    val kMeans10 = KMeansPlusPlusEstimator(k, maxIterations = 10).fit(data)
    assert(Stats.aboutEq(kMeans10.means, center))

    val out = kMeans.apply(data).collect()
  }

  test("K-Means++ Two Centers") {
    sc = new SparkContext("local", "test")

    val k = 2

    val data = sc.parallelize(Array(
      DenseVector[Double](1.0, 2.0, 6.0),
      DenseVector[Double](1.0, 3.0, 0.0),
      DenseVector[Double](1.0, 4.0, 6.0),
      DenseVector[Double](1.0, 1.0, 0.0)
    ))

    val centers = Set(
      DenseVector[Double](1.0, 2.0, 0.0),
      DenseVector[Double](1.0, 3.0, 6.0)
    )

    val kMeans = KMeansPlusPlusEstimator(k, maxIterations = 10).fit(data)
    val fitCenters = MatrixUtils.matrixToRowArray(kMeans.means).toSet
    assert(fitCenters === centers )

    val kMeans5 = KMeansPlusPlusEstimator(k, maxIterations = 5).fit(data)
    val fitCenters5 = MatrixUtils.matrixToRowArray(kMeans5.means).toSet
    assert(fitCenters5 === centers )

    val out = kMeans.apply(data).collect()
  }

  test("K-Means Transformer") {
    sc = new SparkContext("local", "test")

    val data = Array(
      DenseVector[Double](1.0, 2.0, 6.0),
      DenseVector[Double](1.0, 3.0, 0.0),
      DenseVector[Double](1.0, 4.0, 6.0),
      DenseVector[Double](1.0, 1.0, 0.0)
    )

    val centers = MatrixUtils.rowsToMatrix(Array(
      DenseVector[Double](1.0, 2.0, 0.0),
      DenseVector[Double](1.0, 3.0, 6.0)
    ))

    val clusterOne = DenseVector[Double](1.0, 0.0)
    val clusterTwo = DenseVector[Double](0.0, 1.0)

    val assignments = Seq(clusterTwo, clusterOne, clusterTwo, clusterOne)
    val kMeans = KMeansModel(centers)

    // Test Single Apply
    assert(kMeans.apply(DenseVector[Double](1.0, 3.0, 0.0)) === clusterOne)
    assert(kMeans.apply(DenseVector[Double](1.0, 1.0, 0.0)) === clusterOne)
    assert(kMeans.apply(DenseVector[Double](1.0, 2.0, 6.0)) === clusterTwo)
    assert(kMeans.apply(DenseVector[Double](1.0, 4.0, 6.0)) === clusterTwo)

    // Test Matrix Apply
    assert(kMeans.apply(MatrixUtils.rowsToMatrix(data)) === MatrixUtils.rowsToMatrix(assignments))

    // Test RDD Apply
    assert(kMeans.apply(sc.parallelize(data)).collect().toSeq === assignments)
  }
} 
Example 170
Source File: KernelModelSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg._

import org.apache.spark.SparkContext
import org.scalatest.FunSuite

import keystoneml.workflow.PipelineContext
import keystoneml.utils.{MatrixUtils, Stats}

class KernelModelSuite extends FunSuite with PipelineContext {

  test("KernelModel XOR test") {
    sc = new SparkContext("local", "test")

    val x = Array(DenseVector(-1.0, -1.0), DenseVector(1.0, 1.0), DenseVector(-1.0, 1.0),DenseVector(1.0, -1.0))
    val xTest = Array(DenseVector(-1.0, -1.0), DenseVector(1.0, 1.0), DenseVector(-1.0, 1.0))
    val y = Array(DenseVector(0.0, 1.0), DenseVector(0.0, 1.0), DenseVector(1.0, 0.0), DenseVector(1.0, 0.0))
    val yTest = Array(DenseVector(0.0, 1.0), DenseVector(0.0, 1.0), DenseVector(1.0, 0.0))

    val xRDD = sc.parallelize(x, 2)
    val yRDD = sc.parallelize(y, 2)
    val xTestRDD = sc.parallelize(xTest, 2)

    val gaussian = new GaussianKernelGenerator(10)
    // Set block size to number of data points so no blocking happens
    val clf = new KernelRidgeRegression(gaussian, 0, 4, 2)

    val kernelModel = clf.fit(xRDD, yRDD)
    val yHat = kernelModel(xTestRDD).collect()
    // Fit should be good
    val delta = MatrixUtils.rowsToMatrix(yHat) - MatrixUtils.rowsToMatrix(yTest)

    delta :*= delta
    println("SUM OF DELTA1 " + sum(delta))
    assert(Stats.aboutEq(sum(delta), 0, 1e-4))
  }

  test("KernelModel XOR blocked test") {
    sc = new SparkContext("local", "test")

    val x = Array(DenseVector(-1.0, -1.0), DenseVector(1.0, 1.0), DenseVector(-1.0, 1.0),DenseVector(1.0, -1.0))
    val xTest = Array(DenseVector(-1.0, -1.0), DenseVector(1.0, 1.0), DenseVector(-1.0, 1.0))
    val y = Array(DenseVector(0.0, 1.0), DenseVector(0.0, 1.0), DenseVector(1.0, 0.0), DenseVector(1.0, 0.0))
    val yTest = Array(DenseVector(0.0, 1.0), DenseVector(0.0, 1.0), DenseVector(1.0, 0.0))

    val xRDD = sc.parallelize(x, 2)
    val yRDD = sc.parallelize(y, 2)
    val xTestRDD = sc.parallelize(xTest, 2)

    val gaussian = new GaussianKernelGenerator(10)

    // Set block size to half number of data points so blocking happens
    val clf = new KernelRidgeRegression(gaussian, 0, 2, 2)

    val kernelModel = clf.fit(xRDD, yRDD)
    val yHat = kernelModel(xTestRDD).collect()
    // Fit should be good
    val delta = MatrixUtils.rowsToMatrix(yHat) - MatrixUtils.rowsToMatrix(yTest)

    delta :*= delta
    assert(Stats.aboutEq(sum(delta), 0, 1e-4))
  }
} 
Example 171
Source File: BlockLinearMapperSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg.{DenseVector, DenseMatrix}
import breeze.stats.distributions.Rand
import keystoneml.workflow.PipelineContext
import scala.collection.mutable.ArrayBuffer

import org.scalatest.FunSuite

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import keystoneml.pipelines._
import keystoneml.utils.Stats

class BlockLinearMapperSuite extends FunSuite with PipelineContext with Logging {

  test("BlockLinearMapper transformation") {
    sc = new SparkContext("local", "test")

    val inDims = 1000
    val outDims = 100
    val numChunks = 5
    val numPerChunk = inDims/numChunks

    val mat = DenseMatrix.rand(inDims, outDims, Rand.gaussian)
    val vec = DenseVector.rand(inDims, Rand.gaussian)
    val intercept = DenseVector.rand(outDims, Rand.gaussian)

    val splitVec = (0 until numChunks).map(i => vec((numPerChunk*i) until (numPerChunk*i + numPerChunk)))
    val splitMat = (0 until numChunks).map(i => mat((numPerChunk*i) until (numPerChunk*i + numPerChunk), ::))

    val linearMapper = new LinearMapper[DenseVector[Double]](mat, Some(intercept))
    val blockLinearMapper = new BlockLinearMapper(splitMat, numPerChunk, Some(intercept))

    val linearOut = linearMapper(vec)

    // Test with intercept
    assert(Stats.aboutEq(blockLinearMapper(vec), linearOut, 1e-4))

    // Test the apply and evaluate call
    val blmOuts = new ArrayBuffer[RDD[DenseVector[Double]]]
    val splitVecRDDs = splitVec.map { vec =>
      sc.parallelize(Seq(vec), 1)
    }
    blockLinearMapper.applyAndEvaluate(splitVecRDDs,
      (predictedValues: RDD[DenseVector[Double]]) => {
        blmOuts += predictedValues
        ()
      }
    )

    // The last blmOut should match the linear mapper's output
    assert(Stats.aboutEq(blmOuts.last.collect()(0), linearOut, 1e-4))
  }
} 
Example 172
Source File: LinearMapperSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg._
import edu.berkeley.cs.amplab.mlmatrix.RowPartitionedMatrix
import keystoneml.nodes.stats.StandardScaler
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{TestUtils, MatrixUtils, Stats}
import keystoneml.workflow.PipelineContext

class LinearMapperSuite extends FunSuite with PipelineContext with Logging {
  test("Solve and apply a linear system") {
    sc = new SparkContext("local", "test")

    // Create the data.
    val A = TestUtils.createRandomMatrix(sc, 128, 5, 4)
    val x = DenseVector(5.0, 4.0, 3.0, 2.0, -1.0).toDenseMatrix
    val b = A.mapPartitions(part => part * x.t)

    val Aary = A.rdd.flatMap(part => MatrixUtils.matrixToRowArray(part.mat).toIterator)
    val bary = b.rdd.flatMap(part => MatrixUtils.matrixToRowArray(part.mat).toIterator)

    val mapper = new LinearMapEstimator().fit(Aary, bary)

    assert(Stats.aboutEq(mapper.x, x.t), "Coefficients from the solve must match the hand-created model.")

    val point = DenseVector(2.0, -3.0, 2.0, 3.0, 5.0)

    assert(Stats.aboutEq(mapper(sc.parallelize(Seq(point))).first()(0), 5.0),
        "Linear model applied to a point should be 5.0")

    val bt = mapper(Aary)
    assert(Stats.aboutEq(bt.collect()(0), bary.collect()(0)),
        "Linear model applied to input should be the same as training points.")
  }

  test("LocalLeastSquaresEstimator doesn't crash") {
    sc = new SparkContext("local", "test")

    // Create the data.
    val A = TestUtils.createRandomMatrix(sc, 50, 400, 4)
    val x = DenseVector(5.0, 4.0, 3.0, 2.0, -1.0).toDenseMatrix
    val b = A.mapPartitions(part => DenseMatrix.rand(part.rows, 3))

    val Aary = A.rdd.flatMap(part => MatrixUtils.matrixToRowArray(part.mat).toIterator)
    val bary = b.rdd.flatMap(part => MatrixUtils.matrixToRowArray(part.mat).toIterator)

    val mapper = new LocalLeastSquaresEstimator(1e-2).fit(Aary, bary)
    assert(mapper.x.rows === 400)
    assert(mapper.x.cols === 3)
  }

  test("Solve a dense linear system (fit intercept) using local least squares") {
    sc = new SparkContext("local", "test")

    // Create the data.
    val A = TestUtils.createRandomMatrix(sc, 128, 5, 4)
    val x = DenseMatrix((5.0, 4.0, 3.0, 2.0, -1.0), (3.0, -1.0, 2.0, -2.0, 1.0))
    val dataMean = DenseVector(1.0, 0.0, 1.0, 2.0, 0.0)
    val extraBias = DenseVector(3.0, 4.0)

    val initialAary = A.rdd.flatMap(part => MatrixUtils.matrixToRowArray(part.mat).toIterator)
    val meanScaler = new StandardScaler(normalizeStdDev = false).fit(initialAary)
    val Aary = meanScaler.apply(initialAary).map(_ + dataMean)
    val bary = Aary.map(a => (x * (a - dataMean)) + extraBias)

    val mapper = new LocalLeastSquaresEstimator(0).fit(Aary, bary)

    val trueResult = MatrixUtils.rowsToMatrix(bary.collect())
    val solverResult = MatrixUtils.rowsToMatrix(mapper(Aary).collect())

    assert(Stats.aboutEq(trueResult, solverResult, 1e-5), "Results from the solve must match the hand-created model.")
    assert(Stats.aboutEq(mapper.x, x.t, 1e-6), "Model weights from the solve must match the hand-created model.")
    assert(Stats.aboutEq(mapper.bOpt.get, extraBias, 1e-6), "Learned intercept must match the hand-created model.")
    assert(Stats.aboutEq(mapper.featureScaler.get.mean, dataMean, 1e-6),
      "Learned intercept must match the hand-created model.")

  }

} 
Example 173
Source File: ZCAWhiteningSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg._
import breeze.numerics._
import breeze.stats.distributions._
import org.scalatest.FunSuite
import keystoneml.pipelines._
import keystoneml.workflow.PipelineContext

class ZCAWhiteningSuite extends FunSuite with PipelineContext with Logging {

  val nrows = 10000
  val ndim = 10

  val x = DenseMatrix.rand[Double](nrows, ndim, Gaussian(0.0, 1.0))

  def fitAndCompare(x: DenseMatrix[Double], eps: Double, thresh: Double): Boolean = {
    val whitener = new ZCAWhitenerEstimator(eps).fitSingle(x)

    val wx = whitener(x)

    // Checks max(abs(cov(whiten(x)) - eye(ndim))) < thresh
    max(abs(cov(convert(wx, Double)) - DenseMatrix.eye[Double](ndim))) < thresh
  }

  test("whitening with small epsilon") {
    assert(fitAndCompare(x, 1e-12, 1e-4),
      "Whitening the base matrix should produce unit variance and zero covariance.")
  }

  test("whitening with large epsilon") {
    assert(fitAndCompare(x, 0.1, 0.1),
      "Whitening the base matrix should produce unit variance and zero covariance.")

    assert(!fitAndCompare(x, 0.1, 1e-4),
      "Whitening the base matrix with a large epsilon should be somewhat noisy.")
  }
} 
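For context, ZCA whitening rescales the data along its principal axes and rotates back, which is why the whitened covariance should be close to the identity. A local Breeze sketch of that computation (an assumption about ZCAWhitenerEstimator's internals, shown only to motivate the covariance check above):

import breeze.linalg._
import breeze.stats._

// Center the data, eigendecompose its covariance, rescale each eigen-direction by 1/sqrt(lambda + eps),
// and rotate back; the covariance of the result is then approximately the identity.
def zcaWhiten(x: DenseMatrix[Double], eps: Double): DenseMatrix[Double] = {
  val centered = x(*, ::) - mean(x(::, *)).t
  val covariance = (centered.t * centered) / (x.rows - 1).toDouble
  val es = eigSym(covariance)
  val scaling = diag(es.eigenvalues.map(ev => 1.0 / math.sqrt(ev + eps)))
  centered * (es.eigenvectors * scaling * es.eigenvectors.t)
}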
Example 174
Source File: LinearDiscriminantAnalysisSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.learning

import breeze.linalg._
import breeze.stats.distributions.{Multinomial, Uniform, Gaussian}
import keystoneml.nodes.stats.StandardScaler
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{TestUtils, MatrixUtils, Stats}
import keystoneml.workflow.PipelineContext

class LinearDiscriminantAnalysisSuite extends FunSuite with PipelineContext with Logging {
  test("Solve Linear Discriminant Analysis on the Iris Dataset") {
    sc = new SparkContext("local", "test")

    // Uses the Iris flower dataset
    val irisData = sc.parallelize(TestUtils.loadFile("iris.data"))
    val trainData = irisData.map(_.split(",").dropRight(1).map(_.toDouble)).map(new DenseVector(_))
    val features = new StandardScaler().fit(trainData).apply(trainData)
    val labels = irisData.map(_ match {
      case x if x.endsWith("Iris-setosa") => 1
      case x if x.endsWith("Iris-versicolor") => 2
      case x if x.endsWith("Iris-virginica") => 3
    })

    val lda = new LinearDiscriminantAnalysis(2)
    val out = lda.fit(features, labels)

    // Correct output taken from http://sebastianraschka.com/Articles/2014_python_lda.html#introduction
    logInfo(s"\n${out.x}")
    val majorVector = DenseVector(-0.1498, -0.1482, 0.8511, 0.4808)
    val minorVector = DenseVector(0.0095, 0.3272, -0.5748, 0.75)

    // Note that because eigenvectors can be reversed and still valid, we allow either direction
    assert(Stats.aboutEq(out.x(::, 0), majorVector, 1E-4) || Stats.aboutEq(out.x(::, 0), majorVector * -1.0, 1E-4))
    assert(Stats.aboutEq(out.x(::, 1), minorVector, 1E-4) || Stats.aboutEq(out.x(::, 1), minorVector * -1.0, 1E-4))
  }

  test("Check LDA output for a diagonal covariance") {
    sc = new SparkContext("local", "test")

    val matRows = 1000
    val matCols = 10
    val dimRed = 5

    // Generate a random Gaussian matrix.
    val gau = new Gaussian(0.0, 1.0)
    val randMatrix = new DenseMatrix(matRows, matCols, gau.sample(matRows*matCols).toArray)

    // Parallelize and estimate the LDA.
    val data = sc.parallelize(MatrixUtils.matrixToRowArray(randMatrix))
    val labels = data.map(x => Multinomial(DenseVector(0.2, 0.2, 0.2, 0.2, 0.2)).draw(): Int)
    val lda = new LinearDiscriminantAnalysis(dimRed).fit(data, labels)

    // Apply LDA to the input data.
    val redData = lda(data)
    val redMat = MatrixUtils.rowsToMatrix(redData.collect)

    // Compute its covariance.
    val redCov = cov(redMat)
    log.info(s"Covar\n$redCov")

    // The covariance of the dimensionality reduced matrix should be diagonal.
    for (
      x <- 0 until dimRed;
      y <- 0 until dimRed if x != y
    ) {
      assert(Stats.aboutEq(redCov(x,y), 0.0, 1e-6), s"LDA Matrix should be 0 off-diagonal. $x,$y = ${redCov(x,y)}")
    }
  }

} 
Example 175
Source File: HogExtractorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import breeze.linalg._
import org.scalatest.FunSuite

import keystoneml.pipelines.Logging
import keystoneml.utils.{ImageUtils, Stats, TestUtils}

class HogExtractorSuite extends FunSuite with Logging {
  test("Load an Image and compute Hog Features") {
    val testImage = TestUtils.loadTestImage("images/gantrycrane.png")

    // NOTE: The MATLAB implementation from voc-release5 uses
    // images in double range -- So convert our image by rescaling
    val testImageScaled = ImageUtils.mapPixels(testImage, x => x/255.0)

    val binSize = 50
    val hog = new HogExtractor(binSize)
    val descriptors = hog.apply(testImageScaled)

    val ourSum = sum(descriptors)
    val matlabSum = 59.2162514

    assert(Stats.aboutEq((ourSum - matlabSum) / ourSum, 0, 1e-8),
      "Hog features sum should match")

    // With a smaller bin size
    val hog1 = new HogExtractor(binSize=8)
    val descriptors1 = hog1.apply(testImageScaled)

    val matlabSum1 = 4.5775269e+03
    val ourSum1 = sum(descriptors1)

    // TODO: Figure out why error is a bit higher here ?
    assert(Stats.aboutEq((ourSum1 - matlabSum1) / ourSum1, 0, 1e-4),
      "Hog features sum should match")
  }
} 
Example 176
Source File: DaisyExtractorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import breeze.linalg._
import keystoneml.nodes.images.external.SIFTExtractor
import org.scalatest.FunSuite

import keystoneml.pipelines.Logging
import keystoneml.utils.{ImageUtils, Stats, TestUtils}

class DaisyExtractorSuite extends FunSuite with Logging {
  test("Load an Image and compute Daisy Features") {
    val testImage = TestUtils.loadTestImage("images/gantrycrane.png")
    val grayImage = ImageUtils.toGrayScale(testImage)

    val df = new DaisyExtractor()
    val daisyDescriptors = convert(df.apply(grayImage), Double)

    val firstKeyPointSum = sum(daisyDescriptors(::, 0))
    val fullFeatureSum = sum(daisyDescriptors)

    // Values found from running matlab code on same input file.
    val matlabFirstKeyPointSum = 55.127217737738533
    val matlabFullFeatureSum = 3.240635661296463E5

    // TODO: This should be at most 1e-8 as we are using Floats. But its 1e-5, 1e-7 right now ?
    assert(Stats.aboutEq(
      (firstKeyPointSum - matlabFirstKeyPointSum)/matlabFirstKeyPointSum, 0, 1e-5),
      "First keypoint sum must match for Daisy")
    assert(Stats.aboutEq((fullFeatureSum - matlabFullFeatureSum)/matlabFullFeatureSum, 0, 1e-7),
      "Sum of Daisys must match expected sum")
  }

  test("Daisy and SIFT extractors should have same row/column ordering.") {
    val testImage = TestUtils.loadTestImage("images/gantrycrane.png")
    val grayImage = ImageUtils.toGrayScale(testImage)

    val df = new DaisyExtractor()
    val daisyDescriptors = convert(df.apply(grayImage), Double)

    val se = SIFTExtractor(scaleStep = 2)
    val siftDescriptors = se.apply(grayImage)

    assert(daisyDescriptors.rows == df.daisyFeatureSize && siftDescriptors.rows == se.descriptorSize)

  }
} 
Example 177
Source File: CenterCornerPatcherSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{ChannelMajorArrayVectorizedImage, ImageMetadata, TestUtils}

class CenterCornerPatcherSuite extends FunSuite with Logging {

  test("check number and dimension of patches") {
    val image = TestUtils.loadTestImage("images/000012.jpg")
    val xDim = image.metadata.xDim
    val yDim = image.metadata.yDim
    val patchSizeX = xDim / 2 
    val patchSizeY = yDim / 2

    val withFlipPatcher = CenterCornerPatcher(patchSizeX, patchSizeY, true)
    val withFlipPatches = withFlipPatcher.centerCornerPatchImage(image).toSeq

    assert(withFlipPatches.map(_.metadata.xDim).forall(_ == patchSizeX) &&
      withFlipPatches.map(_.metadata.yDim).forall(_ == patchSizeY) &&
      withFlipPatches.map(_.metadata.numChannels).forall(_ == image.metadata.numChannels),
      "All patches must have right dimensions")

    // Presumably 5 base patches (the four corners plus the center), doubled to 10 when flipping is enabled.
    assert(withFlipPatches.size === 10, "Number of patches must match")

    val noFlipPatcher = CenterCornerPatcher(patchSizeX, patchSizeY, false) 
    val noFlipPatches = noFlipPatcher.centerCornerPatchImage(image).toSeq

    assert(noFlipPatches.map(_.metadata.xDim).forall(_ == patchSizeX) &&
      noFlipPatches.map(_.metadata.yDim).forall(_ == patchSizeY) &&
      noFlipPatches.map(_.metadata.numChannels).forall(_ == image.metadata.numChannels),
      "All patches must have right dimensions")

    assert(noFlipPatches.size === 5, "Number of patches must match")
  }

  test("1x1 image patches") {
    val imgArr =
      (0 until 5).flatMap { x =>
        (0 until 5).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 5 * 1).toDouble
          }
        }
      }.toArray

    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(5, 5, 1))
    val patchSizeX = 1
    val patchSizeY = 1

    val noFlipPatcher = CenterCornerPatcher(patchSizeX, patchSizeY, false)
    val noFlipPatches = noFlipPatcher.centerCornerPatchImage(image).toSeq

    assert(noFlipPatches.length === 5)
    // NOTE(shivaram): This assumes order of patches returned stays the same. 
    assert(noFlipPatches(0).get(0, 0, 0) === 0.0)
    assert(noFlipPatches(1).get(0, 0, 0) === 20.0)
    assert(noFlipPatches(2).get(0, 0, 0) === 4.0)
    assert(noFlipPatches(3).get(0, 0, 0) === 24.0)
    assert(noFlipPatches(4).get(0, 0, 0) === 12.0)
  }
} 
Example 178
Source File: RandomPatcherSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{ChannelMajorArrayVectorizedImage, ImageMetadata, TestUtils}

class RandomPatcherSuite extends FunSuite with Logging {

  test("patch dimensions, number") {
    val image = TestUtils.loadTestImage("images/000012.jpg")
    val xDim = image.metadata.xDim
    val yDim = image.metadata.yDim
    val patchSizeX = xDim / 2 
    val patchSizeY = yDim / 2
    val numPatches = 5

    val patcher = RandomPatcher(numPatches, patchSizeX, patchSizeY)

    val patches = patcher.randomPatchImage(image).toSeq

    assert(patches.map(_.metadata.xDim).forall(_ == patchSizeX) &&
      patches.map(_.metadata.yDim).forall(_ == patchSizeY) &&
      patches.map(_.metadata.numChannels).forall(_ == image.metadata.numChannels),
      "All patches must have right dimensions")

    assert(patches.size === numPatches,
      "Number of patches must match argument passed in")
  }
} 
Example 179
Source File: PoolingSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import breeze.linalg.{DenseVector, sum}
import keystoneml.nodes._
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{ChannelMajorArrayVectorizedImage, ImageMetadata}

class PoolingSuite extends FunSuite with Logging {

  test("pooling") {
    val imgArr =
      (0 until 4).flatMap { x =>
        (0 until 4).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1).toDouble
          }
        }
      }.toArray

    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(4, 4, 1))
    val pooling = new Pooler(2, 2, x => x, x => x.max)

    val poolImage = pooling(image)

    assert(poolImage.get(0, 0, 0) === 5.0)
    assert(poolImage.get(0, 1, 0) === 7.0)
    assert(poolImage.get(1, 0, 0) === 13.0)
    assert(poolImage.get(1, 1, 0) === 15.0)
  }

  test("pooling odd") {
    val hogImgSize = 14
    val convSizes = List(1, 2, 3, 4, 6, 8)
    convSizes.foreach { convSize =>
      val convResSize = hogImgSize - convSize + 1

      val imgArr =
        (0 until convResSize).flatMap { x =>
          (0 until convResSize).flatMap { y =>
            (0 until 1000).map { c =>
              (c + x * 1 + y * 4 * 1).toDouble
            }
          }
        }.toArray

      val image = new ChannelMajorArrayVectorizedImage(
        imgArr, ImageMetadata(convResSize, convResSize, 1000))

      val poolSizeReqd = math.ceil(convResSize / 2.0).toInt

      // We want poolSize to be even !!
      val poolSize = (math.ceil(poolSizeReqd / 2.0) * 2).toInt
      // overlap as little as possible
      val poolStride = convResSize - poolSize


      println(s"VALUES: $convSize $convResSize $poolSizeReqd $poolSize $poolStride")

      def summ(x: DenseVector[Double]): Double = sum(x)

      val pooling = new Pooler(poolStride, poolSize, identity, summ)
      val poolImage = pooling(image)
    }
  }
} 
Example 180
Source File: WindowingSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.{ChannelMajorArrayVectorizedImage, ImageMetadata, TestUtils}

class WindowingSuite extends FunSuite with Logging {

  test("windowing") {
    val image = TestUtils.loadTestImage("images/000012.jpg")
    val stride = 100
    val size = 50

    val windowing = new Windower(stride, size)

    val windows = windowing.getImageWindow(image)

    assert(windows.map(_.metadata.xDim).forall(_ == size) &&
      windows.map(_.metadata.yDim).forall(_ == size),
      "All windows must be 100x100")

    assert(windows.size == (image.metadata.xDim/stride) * (image.metadata.yDim/stride),
      "Must have number of windows matching xDims and yDims given the stride.")
  }

  test("1x1 windowing") {
    val imgArr =
      (0 until 4).flatMap { x =>
        (0 until 4).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1).toDouble
          }
        }
      }.toArray


    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(4, 4, 1))

    val windower = new Windower(1, 1)
    val windowImages = windower.getImageWindow(image)

    assert(windowImages.length === 16)
    assert(windowImages(0).get(0, 0, 0) === 0)
    assert(windowImages(1).get(0, 0, 0) === 1.0)
    assert(windowImages(2).get(0, 0, 0) === 2.0)
    assert(windowImages(3).get(0, 0, 0) === 3.0)
  }

  test("2x2 windowing") {
    val imgArr =
      (0 until 4).flatMap { x =>
        (0 until 4).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1).toDouble
          }
        }
      }.toArray


    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(4, 4, 1))

    val windower = new Windower(2, 2)

    val windowImages = windower.getImageWindow(image)

    assert(windowImages.length === 4)

    assert(windowImages(0).get(0, 0, 0) === 0)
    assert(windowImages(1).get(0, 0, 0) === 2.0)
    assert(windowImages(2).get(0, 0, 0) === 8.0)
    assert(windowImages(3).get(0, 0, 0) === 10.0)
  }

  test("nxn windowing with step=1") {
    val dim = 30
    val imgArr =
      (0 until dim).flatMap { x =>
        (0 until dim).flatMap { y =>
          (0 until 1).map { c =>
            (c + x * 1 + y * 4 * 1 + 10).toDouble
          }
        }
      }.toArray


    val image = new ChannelMajorArrayVectorizedImage(imgArr, ImageMetadata(dim, dim, 1))
    val sizes = List(1, 2, 3, 4, 6, 8)

    sizes.foreach { w =>
      val windower = new Windower(1, w)
      val windowImages = windower.getImageWindow(image)
      assert(windowImages.length === (dim-w+1) * (dim-w+1))
      assert(windowImages.forall(x => !x.toArray.contains(0.0)))
    }
  }
} 
Example 181
Source File: LCSExtractorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.images

import breeze.linalg._
import org.scalatest.FunSuite

import keystoneml.pipelines.Logging
import keystoneml.utils.{ImageUtils, Stats, TestUtils}

class LCSExtractorSuite extends FunSuite with Logging {
  test("Load an Image and compute LCS Features") {
    val testImage = TestUtils.loadTestImage("images/gantrycrane.png")

    val lf = new LCSExtractor(stride=4, subPatchSize=6, strideStart=16)
    val lcsDescriptors = convert(lf.apply(testImage), Double)

    val firstKeyPointSum = sum(lcsDescriptors(::, 0))
    val fullFeatureSum = sum(lcsDescriptors)

    // Values found from running matlab code on same input file.
    val matlabFirstKeyPointSum = 3.786557667540610e+03
    val matlabFullFeatureSum = 3.171963632855949e+07
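    // The assertions below check relative error against these MATLAB reference values,
    // requiring agreement to within 1e-8.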

    assert(
      Stats.aboutEq((firstKeyPointSum - matlabFirstKeyPointSum)/matlabFirstKeyPointSum, 0, 1e-8),
      "First keypoint sum must match for LCS")
    assert(Stats.aboutEq((fullFeatureSum - matlabFullFeatureSum)/matlabFullFeatureSum, 0, 1e-8),
      "Sum of LCS must match expected sum")
  }
} 
Example 182
Source File: TermFrequencySuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.misc

import keystoneml.nodes.stats.TermFrequency
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.workflow.PipelineContext

class TermFrequencySuite extends FunSuite with PipelineContext {
  test("term frequency of simple strings") {
    sc = new SparkContext("local", "test")
    val in = Seq(Seq[Any]("b", "a", "c", "b", "b", "a", "b"))
    val out = TermFrequency().apply(sc.parallelize(in)).first().toMap
    assert(out === Map("a" -> 2, "b" -> 4, "c" -> 1))
  }

  test("term frequency of varying types") {
    sc = new SparkContext("local", "test")
    val in = Seq(Seq("b", "a", "c", ("b", "b"), ("b", "b"), 12, 12, "a", "b", 12))
    val out = TermFrequency().apply(sc.parallelize(in)).first().toMap
    assert(out === Map("a" -> 2, "b" -> 2, "c" -> 1, ("b", "b") -> 2, 12 -> 3))
  }

  test("log term frequency") {
    sc = new SparkContext("local", "test")
    val in = Seq(Seq[Any]("b", "a", "c", "b", "b", "a", "b"))
    val out = TermFrequency(x => math.log(x + 1)).apply(sc.parallelize(in)).first().toMap
    assert(out === Map("a" -> math.log(3), "b" -> math.log(5), "c" -> math.log(2)))
  }
} 
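The expected maps above come from plain occurrence counting per sequence. As a rough illustration only (plain Scala collections, not the keystoneml TermFrequency API), the computation amounts to:

    val tokens = Seq("b", "a", "c", "b", "b", "a", "b")
    val counts = tokens.groupBy(identity).map { case (t, occ) => t -> occ.size }
    // Map(b -> 4, a -> 2, c -> 1)
    val logCounts = counts.map { case (t, c) => t -> math.log(c + 1) }
    // matches the "log term frequency" expectations above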
Example 183
Source File: SparseFeatureVectorizerSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.misc

import keystoneml.nodes.util.{SparseFeatureVectorizer, AllSparseFeatures, CommonSparseFeatures}
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.workflow.PipelineContext

class SparseFeatureVectorizerSuite extends FunSuite with PipelineContext with Logging {
  test("sparse feature vectorization") {
    sc = new SparkContext("local", "test")

    val featureVectorizer = new SparseFeatureVectorizer(Map("First" -> 0, "Second" -> 1, "Third" -> 2))
    val test = Seq(("Third", 4.0), ("Fourth", 6.0), ("First", 1.0))
    val vector = featureVectorizer.apply(sc.parallelize(Seq(test))).first()

    assert(vector.size == 3)
    assert(vector(0) == 1)
    assert(vector(1) == 0)
    assert(vector(2) == 4)
  }

  test("all sparse feature selection") {
    sc = new SparkContext("local", "test")
    val train = sc.parallelize(List(Seq(("First", 0.0), ("Second", 6.0)), Seq(("Third", 3.0), ("Second", 4.0))))

    val featureVectorizer = AllSparseFeatures().fit(train.map(x => x))
    // The selected features should now be "First", "Second", and "Third"

    val test = Seq(("Third", 4.0), ("Fourth", 6.0), ("First", 1.0))
    val out = featureVectorizer.apply(sc.parallelize(Seq(test))).first().toArray

    assert(out === Array(1.0, 0.0, 4.0))
  }

  test("common sparse feature selection") {
    sc = new SparkContext("local", "test")
    val train = sc.parallelize(List(
      Seq(("First", 0.0), ("Second", 6.0)),
      Seq(("Third", 3.0), ("Second", 4.8)),
      Seq(("Third", 7.0), ("Fourth", 5.0)),
      Seq(("Fifth", 5.0), ("Second", 7.3))
    ))

    val featureVectorizer = CommonSparseFeatures(2).fit(train.map(x => x))
    // The selected features should now be "Second", and "Third"

    val test = Seq(("Third", 4.0), ("Seventh", 8.0), ("Second", 1.3), ("Fourth", 6.0), ("First", 1.0))
    val out = featureVectorizer.apply(sc.parallelize(Seq(test))).first().toArray

    assert(out === Array(1.3, 4.0))
  }
} 
Example 184
Source File: LinearRectifierSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.stats

import breeze.linalg.DenseMatrix
import breeze.stats.distributions.Rand
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines._
import keystoneml.utils.{TestUtils, MatrixUtils}
import keystoneml.workflow.PipelineContext

class LinearRectifierSuite extends FunSuite with PipelineContext with Logging {

  test("Test MaxVal") {
    sc = new SparkContext("local", "test")
    val matrixParts = TestUtils.createRandomMatrix(sc, 128, 16, 4).rdd.map(_.mat)

    val x = matrixParts.flatMap(y => MatrixUtils.matrixToRowArray(y))
    val y = x.map(r => r.forall(_ >= 0.0))

    val valmaxNode = LinearRectifier()
    val maxy = valmaxNode.apply(x).map(r => r.forall(_ >= 0.0))

    // No row of the random matrix should be entirely >= 0
    assert(!y.reduce {(a,b) => a | b})

    // After rectification, every row *should* be entirely >= 0.
    assert(maxy.reduce {(a,b) => a & b})
  }
} 
Example 185
Source File: RandomSignNodeSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.stats

import breeze.linalg._
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import keystoneml.pipelines.Logging

class RandomSignNodeSuite extends FunSuite with Logging with ShouldMatchers {

  test("RandomSignNode") {
    val signs = DenseVector(1.0, -1.0, 1.0)
    val node = RandomSignNode(signs)
    val data: DenseVector[Double] = DenseVector(1.0, 2.0, 3.0)
    val result = node(data)
    Seq(result) should equal (Seq(DenseVector(1.0, -2.0, 3.0)))
  }

  test("RandomSignNode.create") {
    val node = RandomSignNode(1000)
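    // A node created from a size alone should contain only +1.0 or -1.0 entries.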
    
    node.signs.foreach(elt => assert(elt == -1.0 || elt == 1.0))
  }
} 
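The expected output in the first test is consistent with an element-wise multiplication of the input by a fixed vector of +/-1 signs. A minimal sketch of that check in plain Scala (not the keystoneml RandomSignNode API):

    val signs = Array(1.0, -1.0, 1.0)
    val data = Array(1.0, 2.0, 3.0)
    val flipped = data.zip(signs).map { case (x, s) => x * s }
    // Array(1.0, -2.0, 3.0), matching the suite's expected result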
Example 186
Source File: CosineRandomFeaturesSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.stats

import breeze.linalg._
import breeze.numerics.cos
import breeze.stats._
import breeze.stats.distributions.{CauchyDistribution, Rand}
import org.scalatest.FunSuite
import keystoneml.utils.Stats


class CosineRandomFeaturesSuite extends FunSuite {
  val gamma = 1.34
  val numInputFeatures = 400
  val numOutputFeatures = 1000

  test("Guassian cosine random features") {
    val rf = CosineRandomFeatures(numInputFeatures, numOutputFeatures, gamma)

    // Check that b is uniform
    assert(max(rf.b) <= 2*math.Pi)
    assert(min(rf.b) >= 0)
    assert(rf.b.size == numOutputFeatures)

    // Check that W is Gaussian
    assert(rf.W.rows == numOutputFeatures)
    assert(rf.W.cols == numInputFeatures)
    assert(Stats.aboutEq(mean(rf.W),0, 10e-3 * gamma))
    assert(Stats.aboutEq(variance(rf.W), gamma * gamma, 10e-3 * gamma * gamma))

    //check the mapping
    val in = DenseVector.rand(numInputFeatures, Rand.uniform)
    val out = cos((in.t * rf.W.t).t + rf.b)
    assert(Stats.aboutEq(rf(in), out, 10e-3))
  }

  test("Cauchy cosine random features") {
    val rf = CosineRandomFeatures(
      numInputFeatures,
      numOutputFeatures,
      gamma,
      new CauchyDistribution(0, 1))

    // Check that b is uniform
    assert(max(rf.b) <= 2*math.Pi)
    assert(min(rf.b) >= 0)
    assert(rf.b.size == numOutputFeatures)

    // Check that W is Cauchy
    assert(rf.W.rows == numOutputFeatures)
    assert(rf.W.cols == numInputFeatures)
    assert(Stats.aboutEq(median(rf.W),0,10e-3 * gamma))

    //check the mapping
    val in = DenseVector.rand(numInputFeatures, Rand.uniform)
    val out = cos((in.t * rf.W.t).t + rf.b)
    assert(Stats.aboutEq(rf(in), out, 10e-3))
  }
} 
Example 187
Source File: PaddedFFTSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.stats

import breeze.linalg._
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.utils.Stats
import keystoneml.workflow.PipelineContext


class PaddedFFTSuite extends FunSuite with PipelineContext with Logging {
  test("Test PaddedFFT node") {
    sc = new SparkContext("local", "test")

    // Set up test vectors.
    val ones = DenseVector.zeros[Double](100)
    val twos = DenseVector.zeros[Double](100)
    ones(0) = 1.0
    twos(2) = 1.0

    val x = sc.parallelize(Seq(twos, ones))
    val fftd = PaddedFFT().apply(x).collect()
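    // Judging by the R references below, the 100-sample inputs appear to be zero-padded
    // to 128 samples, with only the real parts of the first 64 FFT coefficients returned.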

    val twosout = fftd(0)
    val onesout = fftd(1)

    // Proof by agreement w/ R: Re(fft(c(0, 0, 1, rep(0, 125))))
    assert(twosout.length === 64)
    assert(Stats.aboutEq(twosout(0), 1.0))
    assert(Stats.aboutEq(twosout(16), 0.0))
    assert(Stats.aboutEq(twosout(32), -1.0))
    assert(Stats.aboutEq(twosout(48), 0.0))

    // Proof by agreement w/ R: Re(fft(c(1, rep(0, 127))))
    assert(Stats.aboutEq(onesout, DenseVector.ones[Double](64)))
  }
} 
Example 188
Source File: CoreNLPFeatureExtractorSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.nlp

import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.pipelines.Logging
import keystoneml.workflow.PipelineContext

class CoreNLPFeatureExtractorSuite extends FunSuite with PipelineContext with Logging {
  test("lemmatization") {
    sc = new SparkContext("local", "test")

    val text = "jumping snakes lakes oceans hunted"
    val tokens = CoreNLPFeatureExtractor(1 to 3).apply(sc.parallelize(Seq(text))).first().toSet

    // Make sure at least very simple cases were lemmatized
    assert(tokens.contains("jump"))
    assert(tokens.contains("snake"))
    assert(tokens.contains("lake"))
    assert(tokens.contains("ocean"))
    assert(tokens.contains("hunt"))

    // Assert the unlemmatized tokens are no longer there
    assert(!tokens.contains("jumping"))
    assert(!tokens.contains("snakes"))
    assert(!tokens.contains("oceans"))
    assert(!tokens.contains("lakes"))
    assert(!tokens.contains("hunted"))
  }

  test("entity extraction") {
    sc = new SparkContext("local", "test")

    val text = "John likes cake and he lives in Florida"
    val tokens = CoreNLPFeatureExtractor(1 to 3).apply(sc.parallelize(Seq(text))).first().toSet

    // Make sure at least very simple entities were identified and extracted
    assert(tokens.contains("PERSON"))
    assert(tokens.contains("LOCATION"))

    // Assert the original tokens are no longer there
    assert(!tokens.contains("John"))
    assert(!tokens.contains("Florida"))
  }

  test("1-2-3-grams") {
    sc = new SparkContext("local", "test")

    val text = "a b c d"
    val tokens = CoreNLPFeatureExtractor(1 to 3).apply(sc.parallelize(Seq(text))).first().toSet

    // Make sure expected unigrams appear
    assert(tokens.contains("a"))
    assert(tokens.contains("b"))
    assert(tokens.contains("c"))
    assert(tokens.contains("d"))

    // Make sure expected bigrams appear
    assert(tokens.contains("a b"))
    assert(tokens.contains("b c"))
    assert(tokens.contains("c d"))

    // Make sure expected 3-grams appear
    assert(tokens.contains("a b c"))
    assert(tokens.contains("b c d"))
  }
} 
Example 189
Source File: NGramIndexerSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.nlp

import org.scalatest.FunSuite

class NGramIndexerSuite extends FunSuite {

  test("pack()") {
    require(NaiveBitPackIndexer.pack(Seq(1)) == math.pow(2, 40).toLong)

    require(NaiveBitPackIndexer.pack(Seq(1, 1)) ==
      math.pow(2, 40).toLong + math.pow(2, 20).toLong + math.pow(2, 60).toLong)

    require(NaiveBitPackIndexer.pack(Seq(1, 1, 1)) ==
      1 + math.pow(2, 40).toLong + math.pow(2, 20).toLong + math.pow(2, 61).toLong)

    val ngramIndexer = new NGramIndexerImpl[Int]
    val seq = ngramIndexer.minNgramOrder to ngramIndexer.maxNgramOrder
    require(ngramIndexer.pack(seq).equals(new NGram(seq)))
  }

  test("removeFarthestWord()") {
    def testWith[Word >: Int, Ngram](indexer: BackoffIndexer[Word, Ngram]) = {
      var ngramId = indexer.pack(Seq(1, 2, 3))
      var context = indexer.removeFarthestWord(ngramId)
      var expected = indexer.pack(Seq(2, 3))
      require(context == expected, s"actual $context, expected $expected")

      ngramId = indexer.pack(Seq(1, 2))
      context = indexer.removeFarthestWord(ngramId)
      expected = indexer.pack(Seq(2))
      require(context == expected, s"actual $context, expected $expected")
    }

    testWith(new NGramIndexerImpl[Int])
    testWith(NaiveBitPackIndexer)
  }

  test("removeCurrentWord()") {
    def testWith[Word >: Int, Ngram](indexer: BackoffIndexer[Word, Ngram]) = {
      var ngramId = indexer.pack(Seq(1, 2, 3))
      var context = indexer.removeCurrentWord(ngramId)
      var expected = indexer.pack(Seq(1, 2))
      require(context == expected, s"actual $context, expected $expected")

      ngramId = indexer.pack(Seq(1, 2))
      context = indexer.removeCurrentWord(ngramId)
      expected = indexer.pack(Seq(1))
      require(context == expected, s"actual $context, expected $expected")
    }

    testWith(new NGramIndexerImpl[Int])
    testWith(NaiveBitPackIndexer)
  }

} 
Example 190
Source File: StringUtilsSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.nlp

import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.workflow.PipelineContext

class StringUtilsSuite extends FunSuite with PipelineContext {
  val stringToManip = Array("  The quick BROWN fo.X ", " ! !.,)JumpeD. ovER the LAZy DOG.. ! ")
  test("trim") {
    sc = new SparkContext("local", "test")
    val out = Trim.apply(sc.parallelize(stringToManip, 1)).collect().toSeq
    assert(out === Seq("The quick BROWN fo.X", "! !.,)JumpeD. ovER the LAZy DOG.. !"))
  }

  test("lower case") {
    sc = new SparkContext("local", "test")
    val out = LowerCase().apply(sc.parallelize(stringToManip, 1)).collect().toSeq
    assert(out === Seq("  the quick brown fo.x ", " ! !.,)jumped. over the lazy dog.. ! "))
  }

  test("tokenizer") {
    sc = new SparkContext("local", "test")
    val out = Tokenizer().apply(sc.parallelize(stringToManip, 1)).collect().toSeq
    assert(out === Seq(Seq("", "The", "quick", "BROWN", "fo", "X"), Seq("", "JumpeD", "ovER", "the", "LAZy", "DOG")))
  }
} 
Example 191
Source File: ClassLabelIndicatorsSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.util

import breeze.linalg.DenseVector
import org.scalatest.FunSuite

class ClassLabelIndicatorsSuite extends FunSuite {
  test("single label indicators") {
    intercept[AssertionError] {
      val zerolabels = ClassLabelIndicatorsFromIntLabels(0)
    }

    intercept[AssertionError] {
      val onelabel = ClassLabelIndicatorsFromIntLabels(1)
    }


    val fivelabel = ClassLabelIndicatorsFromIntLabels(5)
    assert(fivelabel(2) === DenseVector(-1.0,-1.0,1.0,-1.0,-1.0))

    intercept[RuntimeException] {
      fivelabel(5)
    }
  }

  test("multiple label indicators without validation") {
    intercept[AssertionError] {
      val zerolabels = ClassLabelIndicatorsFromIntArrayLabels(0)
    }

    intercept[AssertionError] {
      val onelabel = ClassLabelIndicatorsFromIntArrayLabels(1)
    }

    val fivelabel = ClassLabelIndicatorsFromIntArrayLabels(5)

    assert(fivelabel(Array(2,1)) === DenseVector(-1.0,1.0,1.0,-1.0,-1.0))

    intercept[IndexOutOfBoundsException] {
      fivelabel(Array(4,6))
    }

    assert(fivelabel(Array(-1,2)) === DenseVector(-1.0,-1.0,1.0,-1.0,1.0),
      "In the unchecked case, we should get weird behavior.")

  }

  test("multiple label indicators with validation") {
    intercept[AssertionError] {
      val zerolabels = ClassLabelIndicatorsFromIntArrayLabels(0, true)
    }

    intercept[AssertionError] {
      val onelabel = ClassLabelIndicatorsFromIntArrayLabels(1, true)
    }

    val fivelabel = ClassLabelIndicatorsFromIntArrayLabels(5, true)

    assert(fivelabel(Array(2,1)) === DenseVector(-1.0,1.0,1.0,-1.0,-1.0))

    intercept[RuntimeException] {
      fivelabel(Array(4,6))
    }

    intercept[RuntimeException] {
      fivelabel(Array(-1,2))
    }
  }
} 
Example 192
Source File: VectorSplitterSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.util

import breeze.linalg._
import org.scalatest.FunSuite

class VectorSplitterSuite extends FunSuite {
  test("vector splitter") {
    for (
      bs <- Array(128, 256, 512, 1024, 2048);
      mul <- 0 to 2;
      off <- 0 to 20 by 5;
      feats <- Array(Some(bs*mul + off), None)
    ) {
      val sp = new VectorSplitter(bs, feats)
      val vec = DenseVector.zeros[Double](bs*mul + off)

      val expectedSplits = (bs*mul + off)/bs + (if ((bs*mul + off) % bs == 0) 0 else 1)
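      // i.e. ceil((bs*mul + off) / bs): a trailing partial block still counts as one split.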

      assert(sp.splitVector(vec).length === expectedSplits,
        s"True length is ${sp.splitVector(vec).length}, expected length is ${expectedSplits}")
    }
  }

  test("vector splitter maintains order") {
    for (
      bs <- Array(128, 256, 512, 1024, 2048);
      mul <- 0 to 2;
      off <- 0 to 20 by 5;
      feats <- Array(Some(bs*mul + off), None)
    ) {
      val sp = new VectorSplitter(bs, feats)
      val vec = rand(bs*mul + off)

      assert(DenseVector.vertcat(sp.splitVector(vec):_*) === vec,
        s"Recombinded split vector of length ${bs*mul + off} with block size $bs did not match its input")
    }
  }
} 
Example 193
Source File: TopKClassifierSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.nodes.util

import breeze.linalg.DenseVector
import org.apache.spark.SparkContext
import org.scalatest.FunSuite
import keystoneml.workflow.PipelineContext

class TopKClassifierSuite extends FunSuite with PipelineContext {
  test("top k classifier, k <= vector size") {
    sc = new SparkContext("local", "test")

    assert(TopKClassifier(2).apply(DenseVector(-10.0, 42.4, -43.0, 23.0)) === Array(1, 3))
    assert(TopKClassifier(4).apply(DenseVector(Double.MinValue, Double.MaxValue, 12.0, 11.0, 10.0)) === Array(1, 2, 3, 4))
    assert(TopKClassifier(3).apply(DenseVector(3.0, -23.2, 2.99)) === Array(0, 2, 1))
  }

  test("top k classifier, k > vector size") {
    sc = new SparkContext("local", "test")

    assert(TopKClassifier(5).apply(DenseVector(-10.0, 42.4, -43.0, 23.0)) === Array(1, 3, 0, 2))
    assert(TopKClassifier(2).apply(DenseVector(Double.MinValue)) === Array(0))
    assert(TopKClassifier(20).apply(DenseVector(3.0, -23.2, 2.99)) === Array(0, 2, 1))
  }

} 
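The assertions above are consistent with a classifier that orders indices by descending score and keeps at most k of them. A rough reference sketch in plain Scala (hypothetical topK helper, not the keystoneml implementation):

    def topK(scores: Seq[Double], k: Int): Seq[Int] =
      scores.zipWithIndex.sortBy { case (score, _) => -score }.take(k).map(_._2)

    topK(Seq(-10.0, 42.4, -43.0, 23.0), 2)   // Seq(1, 3)
    topK(Seq(3.0, -23.2, 2.99), 20)          // Seq(0, 2, 1)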
Example 194
Source File: VOCLoaderSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.loaders

import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import keystoneml.utils.TestUtils
import keystoneml.workflow.PipelineContext

class VOCLoaderSuite extends FunSuite with PipelineContext {
  test("load a sample of VOC data") {
    sc = new SparkContext("local", "test")
    val dataPath = TestUtils.getTestResourceFileName("images/voc")
    val labelsPath = TestUtils.getTestResourceFileName("images/voclabels.csv")

    val imgs = VOCLoader(sc,
      VOCDataPath(dataPath, "VOCdevkit/VOC2007/JPEGImages/", Some(1)),
      VOCLabelPath(labelsPath)).collect()

    // We should have 10 images
    assert(imgs.length === 10)

    // There should be one file whose name ends with "000104.jpg"
    val personMonitor = imgs.filter(_.filename.get.endsWith("000104.jpg"))
    assert(personMonitor.length === 1)

    // It should have two labels, 14 and 19.
    assert(personMonitor(0).label.contains(14) && personMonitor(0).label.contains(19))

    // There should be 13 labels in total, 9 of them distinct.
    assert(imgs.map(_.label).flatten.length === 13)
    assert(imgs.map(_.label).flatten.distinct.length === 9)
  }
} 
Example 195
Source File: ImageNetLoaderSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.loaders

import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import keystoneml.utils.TestUtils
import keystoneml.workflow.PipelineContext

class ImageNetLoaderSuite extends FunSuite with PipelineContext {
  test("load a sample of imagenet data") {
    sc = new SparkContext("local", "test")
    val dataPath = TestUtils.getTestResourceFileName("images/imagenet")
    val labelsPath = TestUtils.getTestResourceFileName("images/imagenet-test-labels")

    val imgs = ImageNetLoader.apply(sc, dataPath, labelsPath).collect()
    // We should have 5 images
    assert(imgs.length === 5)

    // The images should all have label 12
    assert(imgs.map(_.label).distinct.length === 1)
    assert(imgs.map(_.label).distinct.head === 12)

    // The image filenames should begin with n15075141
    assert(imgs.forall(_.filename.get.startsWith("n15075141")), "Image filenames should be correct")
  }
} 
Example 196
Source File: StupidBackoffSuite.scala    From keystone   with Apache License 2.0 5 votes vote down vote up
package keystoneml.pipelines.nlp

import keystoneml.nodes.nlp._

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

import org.scalatest.FunSuite
import keystoneml.workflow.PipelineContext

import scala.collection.JavaConverters._

class StupidBackoffSuite extends FunSuite with PipelineContext {

  val data = Seq("Winter is coming",
    "Finals are coming",
    "Summer is coming really soon")

  def featurizer(orders: Seq[Int], mode: NGramsCountsMode.Value = NGramsCountsMode.Default) = {
    def feat(data: RDD[String]) = {
      NGramsCounts[String](mode).apply(
        (Tokenizer() andThen NGramsFeaturizer[String](orders)).apply(data).get)
    }
    feat _
  }

  def requireNGramColocation[T, V](
      ngrams: RDD[(NGram[T], V)],
      indexer: BackoffIndexer[T, NGram[T]]) = {

    ngrams.mapPartitions { part =>
      val map = new java.util.HashMap[NGram[T], V]().asScala
      part.foreach { case (ngramId, count) => map.put(ngramId, count) }

      map.keySet.foreach { ngramId =>
        var currNGram = ngramId
        while (indexer.ngramOrder(currNGram) > 2) {
          val context = indexer.removeCurrentWord(currNGram)
          require(map.contains(context),
            s"ngram $currNGram is not co-located with its context $context within same partition")
          currNGram = context
        }
      }
      Iterator.empty
    }.count()
  }

  test("end-to-end InitialBigramPartitioner") {
    sc = new SparkContext("local[4]", "StupidBackoffSuite")
    val corpus = sc.parallelize(data, 3)
    val ngrams = featurizer(2 to 5, NGramsCountsMode.NoAdd)(corpus)
    val unigrams = featurizer(1 to 1)(corpus)
      .collectAsMap()
      .map { case (key, value) => key.words(0) -> value }

    val stupidBackoff = StupidBackoffEstimator[String](unigrams).fit(ngrams)
    requireNGramColocation[String, Double](stupidBackoff.scoresRDD, new NGramIndexerImpl)
  }

  test("Stupid Backoff calculates correct scores") {
    sc = new SparkContext("local[4]", "StupidBackoffSuite")
    val corpus = sc.parallelize(data, 3)
    val ngrams = featurizer(2 to 5, NGramsCountsMode.NoAdd)(corpus)
    val unigrams = featurizer(1 to 1)(corpus)
      .collectAsMap()
      .map { case (key, value) => key.words(0) -> value }
    val lm = StupidBackoffEstimator[String](unigrams).fit(ngrams)

    assert(lm.score(new NGram(Seq("is", "coming"))) === 2.0 / 2.0)
    assert(lm.score(new NGram(Seq("is", "coming", "really"))) === 1.0 / 2.0)

    assert(lm.score(new NGram(Seq("is", "unseen-coming"))) === 0,
      "not equal to expected: bacoffed once & curr word unseen, so should be zero")
    assert(lm.score(new NGram(Seq("is-unseen", "coming"))) === lm.alpha * 3.0 / lm.numTokens,
      "not equal to expected: backoffed once, should be alpha * currWordCount / numTokens")
  }

} 
Example 197
Source File: MockedDefaultSourceSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import java.util.concurrent.{Callable, Executors}

import com.sap.spark.dsmock.DefaultSource
import org.apache.spark.sql.sources.HashPartitioningFunction
import org.apache.spark.sql.{GlobalSapSQLContext, Row, SQLContext}
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.FunSuite

import scala.concurrent.duration._


class MockedDefaultSourceSuite
  extends FunSuite
  with GlobalSapSQLContext {

  val testTimeout = 10 // seconds

  private def numberOfThreads: Int = {
    val noOfCores = Runtime.getRuntime.availableProcessors()
    assert(noOfCores > 0)

    if (noOfCores == 1) 2 // It should always be multithreaded although only
                          // one processor is available (pseudo-multithreading)
    else noOfCores
  }

  def runMultiThreaded[A](op: Int => A): Seq[A] = {
    info(s"Running with $numberOfThreads threads")
    val pool = Executors.newFixedThreadPool(numberOfThreads)

    val futures = 1 to numberOfThreads map { i =>
      val task = new Callable[A] {
        override def call(): A = op(i)
      }
      pool.submit(task)
    }

    futures.map(_.get(testTimeout, SECONDS))
  }

  test("Underlying mocks of multiple threads are distinct") {
    val dataSources = runMultiThreaded { _ =>
      DefaultSource.withMock(identity)
    }

    dataSources foreach { current =>
      val sourcesWithoutCurrent = dataSources.filter(_.ne(current))
      assert(sourcesWithoutCurrent.forall(_.underlying ne current))
    }
  }

  test("Mocking works as expected") {
    runMultiThreaded { i =>
      DefaultSource.withMock { defaultSource =>
        when(defaultSource.getAllPartitioningFunctions(
          anyObject[SQLContext],
          anyObject[Map[String, String]]))
          .thenReturn(Seq(HashPartitioningFunction(s"foo$i", Seq.empty, None)))

        val Array(Row(name)) = sqlc
          .sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dsmock")
          .select("name")
          .collect()

        assertResult(s"foo$i")(name)
      }
    }
  }
} 
Example 198
Source File: HierarchyBuilderSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hierarchy

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.Node
import org.scalatest.FunSuite

class HierarchyBuilderSuite extends FunSuite {

  val N = 5
  val rowFunctions = HierarchyRowFunctions(Seq.fill(N)(StringType))

  test("HierarchyRowFunctions.rowGet") {
    for (i <- 0 to 5) {
      val row = Row((0 to 5).map(_.toString): _*)
      assertResult(i.toString)(rowFunctions.rowGet(i)(row))
    }
  }

  test("HierarchyRowFunctions.rowInit") {
    for (i <- 0 to 5) {
      val row = Row((0 to 5).map(_.toString): _*)

      val result = rowFunctions.rowInit(rowFunctions.rowGet(i), StringType)(row, None)
      val expected = Row(row.toSeq :+ Node(List(i.toString), StringType): _*)
      assertResult(expected)(result)
    }
  }

  // scalastyle:off magic.number
  test("HierarchyRowFunctions.rowInitWithOrder") {
    for (i <- 0 to 5) {
      val row = Row((0 to 5).map(_.toString): _*)
      val result = rowFunctions.rowInit(rowFunctions.rowGet(i), StringType)(row, Some(42L))
      val expected = Row(row.toSeq :+ Node(List(i.toString),StringType, ordPath = List(42L)): _*)
      assertResult(expected)(result)
    }
  }
  // scalastyle:on magic.number

  test("HierarchyRowFunctions.rowModify") {
    for (i <- 0 to 5) {
      val rightRow = Row(0 to 5: _*)
      val leftRow = Row("foo", 0, "bar", Node(List(0),StringType))
      val result = rowFunctions.rowModify(
        rowFunctions.rowGet(i),StringType
      )(leftRow, rightRow)
      val expected = Row((0 to 5) :+ Node(List(0, i), StringType): _*)
      assertResult(expected)(result)
    }
  }

  // scalastyle:off magic.number
  test("HierarchyRowFunctions.rowModifyAndOrder") {
    for (i <- 0 to 5) {
      val rightRow = Row(0 to 5: _*)
      val leftRow = Row("foo", 0, "bar", Node(List(0),StringType))
      val result = rowFunctions.rowModifyAndOrder(
        rowFunctions.rowGet(i), StringType
      )(leftRow, rightRow, Some(42L))
      val expected = Row((0 to 5) :+ Node(List(0, i), StringType, ordPath = List(42L)): _*)
      assertResult(expected)(result)
    }
  }
  // scalastyle:on magic.number

  test("HierarchyBuilder closure is serializable") {
    val closureSerializer = new JavaSerializer(new SparkConf(loadDefaults = false)).newInstance()
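    // The test passes as long as serializing the closure does not throw.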
    val serialized = closureSerializer.serialize(() =>
      HierarchyJoinBuilder(null, null, null, null, null, null))
  }

  test("HierarchyRowFunctions closure is serializable") {
    val closureSerializer = new JavaSerializer(new SparkConf(loadDefaults = false)).newInstance()
    val serialized = closureSerializer.serialize(() =>
      HierarchyRowJoinBuilder(null, null, null, null))
  }

} 
Example 199
Source File: ExtendableOptimizerSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.extension

import org.apache.spark.sql.catalyst.optimizer.{FiltersReduction, Optimizer}
import org.apache.spark.sql.extension.OptimizerFactory.ExtendableOptimizerBatch
import org.scalatest.{FunSuite, PrivateMethodTester}

class ExtendableOptimizerSuite extends FunSuite with PrivateMethodTester {

  implicit class OptimizerOps(opt: Optimizer) {
    private val nameMethod = PrivateMethod[String]('name)
    private def batches: Seq[AnyRef] = {
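      // The optimizer's batches member is not publicly accessible here, so it is read via reflection.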
      
      val clazz = opt.getClass
      val batchesMethod = clazz.getMethods.find(_.getName == "batches").get
      batchesMethod.setAccessible(true)
      batchesMethod.invoke(opt).asInstanceOf[Seq[AnyRef]]
    }
    def batchNames: Seq[String] =
      batches map { b => b invokePrivate nameMethod() }
  }

  test("No rules is equivalent to DefaultOptimizer") {
    val extOpt = OptimizerFactory.produce()
    val defOpt = OptimizerFactoryForTests.default()
    assert(extOpt.batchNames == defOpt.batchNames)
  }

  test("One early batch is added before the main optimizer batch") {
    val extOpt = OptimizerFactory.produce(
      earlyBatches = ExtendableOptimizerBatch("FOO", 1, FiltersReduction :: Nil) :: Nil
    )

    assert(extOpt.batchNames match {
      case subQueries :: early :: other => early.equals("FOO")
    })
  }

  test("Several early batches are added before the main optimizer batch") {
    val extOpt = OptimizerFactory.produce(
      earlyBatches = ExtendableOptimizerBatch("FOO", 1, FiltersReduction :: Nil) ::
        ExtendableOptimizerBatch("BAR", 1, FiltersReduction :: Nil) ::
        Nil
    )

    assert(extOpt.batchNames match {
      case subQueries :: firstEarly :: secondEarly :: other =>
        firstEarly.equals("FOO") && secondEarly.equals("BAR")
    })
  }

  test("Expression rules are added") {
    val extOpt = OptimizerFactory.produce(
      mainBatchRules = FiltersReduction :: Nil
    )
    val defOpt = OptimizerFactoryForTests.default()
    assert(extOpt.batchNames == defOpt.batchNames)
  }

  test("Both rules are added") {
    val extOpt = OptimizerFactory.produce(
      earlyBatches = ExtendableOptimizerBatch("FOO", 1, FiltersReduction :: Nil) :: Nil,
      mainBatchRules = FiltersReduction :: Nil
    )
    val defOpt = OptimizerFactoryForTests.default()
    assert(extOpt.batchNames.toSet ==
      defOpt.batchNames.toSet ++ Seq("FOO"))
  }
} 
Example 200
Source File: ShowPartitionFunctionsSuite.scala    From HANAVora-Extensions   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import com.sap.spark.dstest.DefaultSource
import org.apache.spark.sql.sources.Stride
import org.apache.spark.util.PartitioningFunctionUtils
import org.scalatest.FunSuite

class ShowPartitionFunctionsSuite
  extends FunSuite
  with GlobalSapSQLContext
  with PartitioningFunctionUtils {

  override def beforeEach(): Unit = {
    super.beforeEach()
    DefaultSource.reset()
  }

  test("Show partition functions shows no partitioning functions if none are there") {
    val funs = sqlc.sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dstest").collect()

    assert(funs.isEmpty)
  }

  // scalastyle:off magic.number
  test("Show partition functions shows the previously registered partitioning functions") {
    createHashPartitioningFunction("foo", Seq("string", "float"), Some(10), "com.sap.spark.dstest")
    createRangePartitioningFunction("bar", "int", 0, 10, Stride(10), "com.sap.spark.dstest")
    createRangeSplitPartitioningFunction("baz", "float", Seq(1, 2, 3),
      rightClosed = true, "com.sap.spark.dstest")

    val funs = sqlc.sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dstest").collect()

    assertResult(Set(
      Row("baz", "RangeSplitPartitioningFunction", "FloatType", "1,2,3",
        true, null, null, null, null, null),
      Row("foo", "HashPartitioningFunction", "StringType,FloatType", null,
        null, null, null, null, null, 10),
      Row("bar", "RangeIntervalPartitioningFunction", "IntegerType", null, null,
        0, 10, "Stride", 10, null)))(funs.toSet)
  }
  // scalastyle:on magic.number

  // scalastyle:off magic.number
  test("Show partition functions does not show deleted functions") {
    createHashPartitioningFunction("foo", Seq("string", "float"), Some(10), "com.sap.spark.dstest")
    createRangePartitioningFunction("bar", "int", 0, 10, Stride(10), "com.sap.spark.dstest")
    createRangeSplitPartitioningFunction("baz", "float", Seq(1, 2, 3),
      rightClosed = true, "com.sap.spark.dstest")

    val f1 = sqlc.sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dstest").collect()

    assertResult(Set(
      Row("baz", "RangeSplitPartitioningFunction", "FloatType", "1,2,3",
        true, null, null, null, null, null),
      Row("foo", "HashPartitioningFunction", "StringType,FloatType", null,
        null, null, null, null, null, 10),
      Row("bar", "RangeIntervalPartitioningFunction", "IntegerType", null, null,
        0, 10, "Stride", 10, null)))(f1.toSet)

    dropPartitioningFunction("bar", dataSource = "com.sap.spark.dstest")

    val f2 = sqlc.sql("SHOW PARTITION FUNCTIONS USING com.sap.spark.dstest").collect()

    assertResult(Set(
      Row("baz", "RangeSplitPartitioningFunction", "FloatType", "1,2,3",
        true, null, null, null, null, null),
      Row("foo", "HashPartitioningFunction", "StringType,FloatType", null,
        null, null, null, null, null, 10)))(f2.toSet)
  }
  // scalastyle:on magic.number
}