org.scalatest.BeforeAndAfter Scala Examples

The following examples show how to use org.scalatest.BeforeAndAfter. Each example is taken from an open-source project; the line above each listing gives the original source file, project, and license.
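Before the project-specific examples, here is a minimal, self-contained sketch of the basic pattern: mix BeforeAndAfter into a suite and register before and after blocks that run around each test. The suite and fixture names below are illustrative only and do not come from any of the projects listed on this page.

import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite

import scala.collection.mutable.ListBuffer

// Minimal illustration: the `before` block runs before each test and the `after`
// block runs after each test, even when the test fails.
class ScratchBufferSuite extends AnyFunSuite with BeforeAndAfter {

  private val buffer = ListBuffer.empty[String]

  before {
    buffer += "setup"
  }

  after {
    buffer.clear()
  }

  test("the before block prepares the fixture") {
    assert(buffer == ListBuffer("setup"))
  }

  test("state does not leak between tests") {
    buffer += "extra"
    assert(buffer.size == 2) // this test's mutation plus the single setup entry
  }
}

The project examples below use the same life-cycle hooks, often combined with BeforeAndAfterAll for suite-level setup and teardown.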
Example 1
Source File: ExperimentVariantEventElasticServiceTest.scala    From izanami   with Apache License 2.0
package specs.elastic.abtesting

import domains.abtesting.events.impl.ExperimentVariantEventElasticService
import domains.abtesting.AbstractExperimentServiceTest
import domains.abtesting.events.ExperimentVariantEventService
import elastic.api.Elastic
import env.{DbDomainConfig, DbDomainConfigDetails, ElasticConfig}
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import play.api.libs.json.JsValue
import store.elastic.ElasticClient

class ExperimentVariantEventElasticServiceTest
    extends AbstractExperimentServiceTest("Elastic")
    with BeforeAndAfter
    with BeforeAndAfterAll {

  private val config            = ElasticConfig("localhost", 9210, "http", None, None, true)
  val elastic: Elastic[JsValue] = ElasticClient(config, system)

  override def dataStore(name: String): ExperimentVariantEventService.Service = ExperimentVariantEventElasticService(
    elastic,
    config,
    DbDomainConfig(env.Elastic, DbDomainConfigDetails(name, None), None)
  )

  override protected def before(fun: => Any)(implicit pos: Position): Unit = {
    cleanUpElastic
    super.before(fun)
  }

  override protected def afterAll(): Unit = {
    cleanUpElastic
    super.afterAll()
  }

  private def cleanUpElastic = {
    import _root_.elastic.codec.PlayJson._
    elastic.deleteIndex("*").futureValue
  }

} 
Example 2
Source File: controllers.scala    From izanami   with Apache License 2.0
package specs.elastic.controllers

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.elastic._
import controllers._
import scala.util.Random

class ElasticApikeyControllerSpec
    extends ApikeyControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticConfigControllerSpec
    extends ConfigControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticExperimentControllerSpec
    extends ExperimentControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticFeatureControllerWildcardAccessSpec
    extends FeatureControllerWildcardAccessSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticFeatureControllerSpec
    extends FeatureControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticGlobalScriptControllerSpec
    extends GlobalScriptControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticUserControllerSpec
    extends UserControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
}
class ElasticWebhookControllerSpec
    extends WebhookControllerSpec("Elastic", Configs.elasticConfiguration)
    with BeforeAndAfterAll {
  override protected def beforeAll(): Unit = Configs.initEs
} 
Example 3
Source File: ElasticJsonDataStoreTest.scala    From izanami   with Apache License 2.0
package specs.elastic.store

import elastic.api.Elastic
import env.{DbDomainConfig, DbDomainConfigDetails, ElasticConfig}
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import play.api.libs.json.JsValue
import store.AbstractJsonDataStoreTest
import store.elastic._

class ElasticJsonDataStoreTest extends AbstractJsonDataStoreTest("Elastic") with BeforeAndAfter with BeforeAndAfterAll {


  private val config = ElasticConfig("localhost", 9210, "http", None, None, true)
  val elastic: Elastic[JsValue] = ElasticClient(config, system)

  override def dataStore(dataStore: String): ElasticJsonDataStore = ElasticJsonDataStore(
    elastic, config, DbDomainConfig(env.Elastic, DbDomainConfigDetails(dataStore, None), None)
  )

  override protected def before(fun: => Any)(implicit pos: Position): Unit = {
    super.before(fun)
    cleanUpElastic
  }

  private def cleanUpElastic = {
    import _root_.elastic.implicits._
    import _root_.elastic.codec.PlayJson._
    elastic.deleteIndex("*").futureValue
  }

  override protected def afterAll(): Unit = {
    super.afterAll()
    cleanUpElastic
  }
} 
Example 4
Source File: ExperimentVariantEventRedisServiceTest.scala    From izanami   with Apache License 2.0
package specs.redis.abtesting

import java.time.Duration

import domains.abtesting.events.impl.ExperimentVariantEventRedisService
import domains.abtesting.AbstractExperimentServiceTest
import domains.abtesting.events.ExperimentVariantEventService
import env.{DbDomainConfig, DbDomainConfigDetails, Master}
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.redis.{RedisClientBuilder, RedisWrapper}
import test.FakeApplicationLifecycle
import zio.{Exit, Reservation}

import scala.jdk.CollectionConverters._

class ExperimentVariantEventRedisServiceTest
    extends AbstractExperimentServiceTest("Redis")
    with BeforeAndAfter
    with BeforeAndAfterAll {

  import zio.interop.catz._

  val redisWrapper: Reservation[Any, Throwable, Option[RedisWrapper]] = runtime.unsafeRun(
    RedisClientBuilder
      .redisClient(
        Some(Master("localhost", 6380, 5)),
        system
      )
      .reserve
  )
  private val maybeRedisWrapper: Option[RedisWrapper] = runtime.unsafeRun(redisWrapper.acquire)

  override def dataStore(name: String): ExperimentVariantEventService.Service =
    ExperimentVariantEventRedisService(DbDomainConfig(env.Redis, DbDomainConfigDetails(name, None), None),
                                       maybeRedisWrapper)

  override protected def before(fun: => Any)(implicit pos: Position): Unit = {
    super.before(fun)
    deleteAllData
  }

  override protected def afterAll(): Unit = {
    super.afterAll()

    deleteAllData
    runtime.unsafeRun(redisWrapper.release(Exit.unit))
  }

  private def deleteAllData =
    maybeRedisWrapper.get.connection
      .sync()
      .del(maybeRedisWrapper.get.connection.sync().keys("*").asScala.toSeq: _*)

} 
Example 5
Source File: RedisJsonDataStoreTest.scala    From izanami   with Apache License 2.0
package specs.redis.store

import java.time.Duration

import env.Master
import org.scalactic.source.Position
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.AbstractJsonDataStoreTest
import test.FakeApplicationLifecycle

import scala.jdk.CollectionConverters._
import store.redis.RedisWrapper
import store.redis.RedisClientBuilder
import store.redis.RedisJsonDataStore
import zio.{Exit, Reservation}

class RedisJsonDataStoreTest extends AbstractJsonDataStoreTest("Redis") with BeforeAndAfter with BeforeAndAfterAll {

  val redisWrapper: Reservation[Any, Throwable, Option[RedisWrapper]] = runtime.unsafeRun(
    RedisClientBuilder
      .redisClient(
        Some(Master("localhost", 6380, 5)),
        system
      )
      .reserve
  )
  private val maybeRedisWrapper: Option[RedisWrapper] = runtime.unsafeRun(redisWrapper.acquire)

  override def dataStore(name: String): RedisJsonDataStore =
    RedisJsonDataStore(maybeRedisWrapper.get, name)

  override protected def before(fun: => Any)(implicit pos: Position): Unit = {
    super.before(fun)
    deleteAllData
  }

  override protected def afterAll(): Unit = {
    super.afterAll()

    deleteAllData
    runtime.unsafeRun(redisWrapper.release(Exit.unit))
  }

  private def deleteAllData =
    maybeRedisWrapper.get.connection.sync().del(maybeRedisWrapper.get.connection.sync().keys("*").asScala.toSeq: _*)

} 
Example 6
Source File: ExperimentVariantEventPostgresqlServiceTest.scala    From izanami   with Apache License 2.0
package specs.postgresql.abtesting

import cats.effect.{ContextShift, IO}
import domains.abtesting.events.impl.ExperimentVariantEventPostgresqlService
import domains.abtesting.AbstractExperimentServiceTest
import domains.abtesting.events.ExperimentVariantEventService
import env.{DbDomainConfig, DbDomainConfigDetails, PostgresqlConfig}
import libs.logs.ZLogger
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.postgresql.PostgresqlClient
import zio.{Exit, Reservation}

class ExperimentVariantEventPostgresqlServiceTest
    extends AbstractExperimentServiceTest("Postgresql")
    with BeforeAndAfter
    with BeforeAndAfterAll {

  implicit val cs: ContextShift[IO] = IO.contextShift(scala.concurrent.ExecutionContext.global)
  import zio.interop.catz._

  private val pgConfig = PostgresqlConfig(
    "org.postgresql.Driver",
    "jdbc:postgresql://localhost:5555/izanami",
    "izanami",
    "izanami",
    32,
    None
  )

  val rPgClient: Reservation[ZLogger, Throwable, Option[PostgresqlClient]] = runtime.unsafeRun(
    PostgresqlClient
      .postgresqlClient(
        system,
        Some(pgConfig)
      )
      .reserve
      .provideLayer(ZLogger.live)
  )

  private val client: Option[PostgresqlClient] = runtime.unsafeRun(rPgClient.acquire.provideLayer(ZLogger.live))

  override def dataStore(name: String): ExperimentVariantEventService.Service = ExperimentVariantEventPostgresqlService(
    client.get,
    DbDomainConfig(env.Postgresql, DbDomainConfigDetails(name, None), None)
  )

  override protected def afterAll(): Unit = {
    super.afterAll()
    runtime.unsafeRun(rPgClient.release(Exit.unit).provideLayer(ZLogger.live))
  }
} 
Example 7
Source File: PostgresqlJsonDataStoreTest.scala    From izanami   with Apache License 2.0
package specs.postgresql.store

import env.{DbDomainConfig, DbDomainConfigDetails, PostgresqlConfig}
import libs.logs.ZLogger
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.AbstractJsonDataStoreTest
import test.FakeApplicationLifecycle
import store.postgresql.PostgresqlClient
import store.postgresql.PostgresqlJsonDataStore

class PostgresqlJsonDataStoreTest
    extends AbstractJsonDataStoreTest("Postgresql")
    with BeforeAndAfter
    with BeforeAndAfterAll {
  import zio._
  import zio.interop.catz._

  private val pgConfig = PostgresqlConfig(
    "org.postgresql.Driver",
    "jdbc:postgresql://localhost:5555/izanami",
    "izanami",
    "izanami",
    32,
    None
  )

  val rPgClient: Reservation[ZLogger, Throwable, Option[PostgresqlClient]] = runtime.unsafeRun(
    PostgresqlClient
      .postgresqlClient(
        system,
        Some(pgConfig)
      )
      .reserve
      .provideLayer(ZLogger.live)
  )

  private val client: Option[PostgresqlClient] = runtime.unsafeRun(rPgClient.acquire.provideLayer(ZLogger.live))

  override def dataStore(name: String): PostgresqlJsonDataStore = {
    val store =
      PostgresqlJsonDataStore(client.get, DbDomainConfig(env.Postgresql, DbDomainConfigDetails(name, None), None))
    store
  }

  override protected def afterAll(): Unit = {
    super.afterAll()
    runtime.unsafeRun(rPgClient.release(Exit.unit).provideLayer(ZLogger.live))
  }
} 
Example 8
Source File: CassandraJsonDataStoreTest.scala    From izanami   with Apache License 2.0
package specs.cassandra.store

import com.datastax.driver.core.{Cluster, Session}
import env.{CassandraConfig, DbDomainConfig, DbDomainConfigDetails}
import libs.logs.ZLogger
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import store.AbstractJsonDataStoreTest
import store.cassandra.CassandraClient
import store.cassandra.CassandraJsonDataStore
import zio.{Exit, Reservation}

class CassandraJsonDataStoreTest
    extends AbstractJsonDataStoreTest("Cassandra")
    with BeforeAndAfter
    with BeforeAndAfterAll {

  val cassandraConfig = CassandraConfig(Seq("127.0.0.1:9042"), None, 1, "izanami_test")

  private val rDriver: Reservation[ZLogger, Throwable, Option[(Cluster, Session)]] =
    runtime.unsafeRun(CassandraClient.cassandraClient(Some(cassandraConfig)).reserve.provideLayer(ZLogger.live))

  val Some((_, session)) = runtime.unsafeRun(rDriver.acquire.provideLayer(ZLogger.live))

  override protected def afterAll(): Unit = {
    super.afterAll()
    runtime.unsafeRun(rDriver.release(Exit.unit).provideLayer(ZLogger.live))
  }

  override def dataStore(name: String): CassandraJsonDataStore =
    CassandraJsonDataStore(
      session,
      cassandraConfig,
      DbDomainConfig(env.Cassandra, DbDomainConfigDetails(name, None), None)
    )

} 
Example 9
Source File: TestSSLConfigContext.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect.config

import javax.net.ssl.{KeyManager, SSLContext, TrustManager}
import org.scalatest.wordspec.AnyWordSpec
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers


class TestSSLConfigContext extends AnyWordSpec with Matchers with BeforeAndAfter {
  var sslConfig : SSLConfig = null
  var sslConfigNoClient : SSLConfig = null

  before {
    val trustStorePath = System.getProperty("truststore")
    val trustStorePassword ="erZHDS9Eo0CcNo"
    val keystorePath = System.getProperty("keystore")
    val keystorePassword ="8yJQLUnGkwZxOw"
    sslConfig = SSLConfig(trustStorePath, trustStorePassword , Some(keystorePath), Some(keystorePassword), true)
    sslConfigNoClient = SSLConfig(trustStorePath, trustStorePassword , Some(keystorePath), Some(keystorePassword), false)
  }

  "SSLConfigContext" should {
    "should return an Array of KeyManagers" in {
      val keyManagers = SSLConfigContext.getKeyManagers(sslConfig)
      keyManagers.length shouldBe 1
      val entry = keyManagers.head
      entry shouldBe a [KeyManager]
    }

    "should return an Array of TrustManagers" in {
      val trustManager = SSLConfigContext.getTrustManagers(sslConfig)
      trustManager.length shouldBe 1
      val entry = trustManager.head
      entry shouldBe a [TrustManager]
    }

    "should return a SSLContext" in {
      val context = SSLConfigContext(sslConfig)
      context.getProtocol shouldBe "SSL"
      context shouldBe a [SSLContext]
    }
  }
} 
Example 10
Source File: TestUtilsBase.scala    From kafka-connect-common   with Apache License 2.0
package com.datamountaineer.streamreactor.connect

import java.util
import java.util.Collections

import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.kafka.connect.data.{Schema, SchemaBuilder, Struct}
import org.apache.kafka.connect.sink.SinkRecord
import org.apache.kafka.connect.source.SourceTaskContext
import org.apache.kafka.connect.storage.OffsetStorageReader
import org.mockito.Mockito._
import org.mockito.MockitoSugar
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.collection.JavaConverters._

trait TestUtilsBase extends AnyWordSpec with Matchers with BeforeAndAfter with MockitoSugar {

  // Builds a mocked SourceTaskContext whose OffsetStorageReader returns the supplied
  // offset for the given table partition. The trait/method header was truncated from this
  // excerpt, so the signature below is reconstructed from the body and may differ from
  // the original source.
  def getSourceTaskContext(lookupPartitionKey: String,
                           offsetColumn: String,
                           offsetValue: String,
                           table: String): SourceTaskContext = {
    //set up partition
    val partition: util.Map[String, String] = Collections.singletonMap(lookupPartitionKey, table)
    //as a list to search for
    val partitionList: util.List[util.Map[String, String]] = List(partition).asJava
    //set up the offset
    val offset: util.Map[String, Object] = (Collections.singletonMap(offsetColumn,offsetValue ))
    //create offsets to initialize from
    val offsets :util.Map[util.Map[String, String],util.Map[String, Object]] = Map(partition -> offset).asJava

    //mock out reader and task context
    val taskContext = mock[SourceTaskContext]
    val reader = mock[OffsetStorageReader]
    when(reader.offsets(partitionList)).thenReturn(offsets)
    when(taskContext.offsetStorageReader()).thenReturn(reader)

    taskContext
  }
} 
Example 11
Source File: FlumeStreamSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.streaming.flume

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps

import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
  var ssc: StreamingContext = null

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
} 
Example 12
Source File: ResolveInlineTablesSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 13
Source File: RowDataSourceStrategySuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 14
Source File: AggregateHashMapSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
    super.beforeAll()
  }

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    super.beforeAll()
  }

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
    super.beforeAll()
  }

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
      "configuration parameter changed in test body")
  }
} 
Example 15
Source File: ExtensionServiceIntegrationSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.scheduler.cluster

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging

class ExtensionServiceIntegrationSuite extends SparkFunSuite
  with LocalSparkContext
  with BeforeAndAfter
  with Logging {

  // The class header and this fixture were truncated from the excerpt; the application id
  // below is reconstructed from the Spark sources (a stub YARN application id used only
  // by the test binding).
  val applicationId = new StubApplicationId(0, 1111L)

  before {
    val sparkConf = new SparkConf()
    sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName()))
    sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite")
    sc = new SparkContext(sparkConf)
  }

  test("Instantiate") {
    val services = new SchedulerExtensionServices()
    assertResult(Nil, "non-nil service list") {
      services.getServices
    }
    services.start(SchedulerExtensionServiceBinding(sc, applicationId))
    services.stop()
  }

  test("Contains SimpleExtensionService Service") {
    val services = new SchedulerExtensionServices()
    try {
      services.start(SchedulerExtensionServiceBinding(sc, applicationId))
      val serviceList = services.getServices
      assert(serviceList.nonEmpty, "empty service list")
      val (service :: Nil) = serviceList
      val simpleService = service.asInstanceOf[SimpleExtensionService]
      assert(simpleService.started.get, "service not started")
      services.stop()
      assert(!simpleService.started.get, "service not stopped")
    } finally {
      services.stop()
    }
  }
} 
Example 16
Source File: FailureSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    if (directory != null) {
      Utils.deleteRecursively(directory)
    }
    StreamingContext.getActive().foreach { _.stop() }

    // Stop SparkContext if active
    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
  }

  test("multiple failures with map") {
    MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration)
  }
} 
Example 17
Source File: InputInfoTrackerSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.streaming.scheduler

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Duration, StreamingContext, Time}

class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {

  private var ssc: StreamingContext = _

  before {
    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker")
    if (ssc == null) {
      ssc = new StreamingContext(conf, Duration(1000))
    }
  }

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
  }

  test("test report and get InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val streamId2 = 1
    val time = Time(0L)
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId2, 300L)
    inputInfoTracker.reportInfo(time, inputInfo1)
    inputInfoTracker.reportInfo(time, inputInfo2)

    val batchTimeToInputInfos = inputInfoTracker.getInfo(time)
    assert(batchTimeToInputInfos.size == 2)
    assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2))
    assert(batchTimeToInputInfos(streamId1) === inputInfo1)
    assert(batchTimeToInputInfos(streamId2) === inputInfo2)
    assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1)
  }

  test("test cleanup InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId1, 300L)
    inputInfoTracker.reportInfo(Time(0), inputInfo1)
    inputInfoTracker.reportInfo(Time(1), inputInfo2)

    inputInfoTracker.cleanup(Time(0))
    assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)

    inputInfoTracker.cleanup(Time(1))
    assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
  }
} 
Example 18
Source File: SparkListenerWithClusterSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.scheduler

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.cluster.ExecutorInfo

// The class header was truncated from this excerpt; it is reconstructed here from the
// file name and the imports above.
class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
  with BeforeAndAfter with BeforeAndAfterAll {

  val WAIT_TIMEOUT_MILLIS = 10000

  before {
    sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
  }

  test("SparkListener sends executor added message") {
    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    // This test will check if the number of executors received by "SparkListener" is same as the
    // number of all executors, so we need to wait until all executors are up
    sc.jobProgressListener.waitUntilExecutorsUp(2, 60000)

    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(listener.addedExecutorInfo.size == 2)
    assert(listener.addedExecutorInfo("0").totalCores == 1)
    assert(listener.addedExecutorInfo("1").totalCores == 1)
  }

  private class SaveExecutorInfo extends SparkListener {
    val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()

    override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
      addedExecutorInfo(executor.executorId) = executor.executorInfo
    }
  }
} 
Example 19
Source File: BlockReplicationPolicySuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.storage

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.{LocalSparkContext, SparkFunSuite}

class BlockReplicationPolicySuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  // Implicitly convert strings to BlockIds for test clarity.
  private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)

  
  test(s"block replication - random block replication policy") {
    val numBlockManagers = 10
    val storeSize = 1000
    val blockManagers = (1 to numBlockManagers).map { i =>
      BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
    }
    val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
    val replicationPolicy = new RandomBlockReplicationPolicy
    val blockId = "test-block"

    (1 to 10).foreach {numReplicas =>
      logDebug(s"Num replicas : $numReplicas")
      val randomPeers = replicationPolicy.prioritize(
        candidateBlockManager,
        blockManagers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
      assert(randomPeers.toSet.size === numReplicas)

      // choosing n peers out of n
      val secondPass = replicationPolicy.prioritize(
        candidateBlockManager,
        randomPeers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${secondPass.mkString(", ")}")
      assert(secondPass.toSet.size === numReplicas)
    }

  }

} 
Example 20
Source File: TopologyMapperSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.storage

import java.io.{File, FileOutputStream}

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark._
import org.apache.spark.util.Utils

class TopologyMapperSuite  extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  test("File based Topology Mapper") {
    val numHosts = 100
    val numRacks = 4
    val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
    val propsFile = createPropertiesFile(props)

    val sparkConf = (new SparkConf(false))
    sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
    val topologyMapper = new FileBasedTopologyMapper(sparkConf)

    props.foreach {case (host, topology) =>
      val obtainedTopology = topologyMapper.getTopologyForHost(host)
      assert(obtainedTopology.isDefined)
      assert(obtainedTopology.get === topology)
    }

    // we get None for hosts not in the file
    assert(topologyMapper.getTopologyForHost("host").isEmpty)

    cleanup(propsFile)
  }

  def createPropertiesFile(props: Map[String, String]): File = {
    val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
    val fileOS = new FileOutputStream(testFile)
    props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
    fileOS.close
    testFile
  }

  def cleanup(testFile: File): Unit = {
    testFile.getParentFile.listFiles.filter { file =>
      file.getName.startsWith(testFile.getName)
    }.foreach { _.delete() }
  }

} 
Example 21
Source File: LocalDirsSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.storage

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.{SparkConfWithEnv, Utils}


class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  before {
    Utils.clearLocalRootDirs()
  }

  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    assert(!new File("/NONEXISTENT_DIR").exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {
    // Regression test for SPARK-2975
    assert(!new File("/NONEXISTENT_DIR").exists())
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
      .set("spark.local.dir", "/NONEXISTENT_PATH")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

} 
Example 22
Source File: JdbcRDDSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {

      try {
        val create = conn.createStatement
        create.execute("""
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
        (1 to 100).foreach { i =>
          insert.setInt(1, i * 2)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

      try {
        val create = conn.createStatement
        create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)")
        (1 to 100).foreach { i =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
Example 23
Source File: FutureActionSuite.scala    From drizzle-spark   with Apache License 2.0
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val rdd = sc.parallelize(1 to 10, 2)
    val job = rdd.countAsync()
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (10)
    job.jobIds.size should be (1)
  }

  test("complex async action") {
    val rdd = sc.parallelize(1 to 15, 3)
    val job = rdd.takeAsync(10)
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (1 to 10)
    job.jobIds.size should be (2)
  }

} 
Example 24
Source File: COCODatasetSpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.dataset.segmentation

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.transform.vision.image.{ImageFeature, RoiImageInfo}
import com.intel.analytics.bigdl.transform.vision.image.label.roi.RoiLabel
import java.awt.image.DataBufferByte
import java.io.{File, FileInputStream}
import javax.imageio.ImageIO
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class COCODatasetSpec extends FlatSpec with Matchers with BeforeAndAfter {

  private def processPath(path: String): String = {
    if (path.contains(":")) {
      path.substring(1)
    } else {
      path
    }
  }

  val resourcePath: String = processPath(getClass.getClassLoader.getResource("coco").getPath)
  val dataSet: COCODataset = COCODataset.load(resourcePath
      + File.separator + "cocomini.json", resourcePath)

  "COCODataset" should "correctly be loaded" in {
    dataSet.images.length should be (5)
    dataSet.annotations.length should be (6)
    val cateIdx = Array(53, 53, 53, 1, 19, 1).toIterator
    val sizes = Array((428, 640), (480, 640), (427, 640), (480, 640), (427, 640)).toIterator
    for (anno <- dataSet.annotations) {
      anno.image.id should be (anno.imageId)
      dataSet.categoryId2Idx(anno.categoryId) should be (cateIdx.next())
      anno.categoryIdx should be (dataSet.categoryId2Idx(anno.categoryId))
      if (anno.isCrowd) {
        anno.segmentation.isInstanceOf[COCORLE] should be (true)
      } else {
        anno.segmentation.isInstanceOf[COCOPoly] should be (true)
        val poly = anno.segmentation.asInstanceOf[COCOPoly]
        poly.height should be (anno.image.height)
        poly.width should be (anno.image.width)
      }
    }
    for (img <- dataSet.images) {
      val size = sizes.next()
      img.height should be (size._1)
      img.width should be (size._2)
    }
    for (i <- 1 to dataSet.categories.length) {
      val cate = dataSet.getCategoryByIdx(i)
      dataSet.categoryId2Idx(cate.id) should be (i)
    }
  }

  "COCODataset.toImageFeatures" should "correctly work" in {
    val cateIdx = Array(1, 19, 53, 53, 53, 1).toIterator
    val sizes = Array((428, 640, 3), (480, 640, 3), (427, 640, 3), (480, 640, 3),
      (427, 640, 3)).toIterator
    val uri = Array("COCO_val2014_000000153344.jpg", "COCO_val2014_000000091136.jpg",
      "COCO_val2014_000000558840.jpg", "COCO_val2014_000000200365.jpg",
      "COCO_val2014_000000374530.jpg"
    ).toIterator
    val isCrowd = Array(1f, 1f, 0f, 0f, 0f, 1f).toIterator
    dataSet.toImageFeatures.foreach(imf => {
      imf.getOriginalSize should be (sizes.next())
      val iscr = imf[Tensor[Float]](RoiImageInfo.ISCROWD)

      val roilabel = imf.getLabel[RoiLabel]
      roilabel.classes.size() should be (iscr.size())
      for(i <- 1 to iscr.nElement()) {
        iscr.valueAt(i) should be (isCrowd.next())
        roilabel.classes.valueAt(i) should be (cateIdx.next())
      }
      roilabel.bboxes.size() should be (Array(roilabel.classes.size(1), 4))

      val inputStream = new FileInputStream(resourcePath + File.separator + uri.next())
      val image = ImageIO.read(inputStream)
      val rawdata = image.getRaster.getDataBuffer.asInstanceOf[DataBufferByte].getData()
      require(java.util.Arrays.equals(rawdata, imf[Array[Byte]](ImageFeature.bytes)))
    })
  }

  "COCOImage.toTable" should "correctly work" in {
    val cateIdx = Array(1, 19, 53, 53, 53, 1).toIterator
    val sizes = Array((428, 640, 3), (480, 640, 3), (427, 640, 3), (480, 640, 3),
      (427, 640, 3)).toIterator
    val isCrowd = Array(1f, 1f, 0f, 0f, 0f, 1f).toIterator
    dataSet.images.map(_.toTable).foreach(tab => {
      RoiImageInfo.getOrigSize(tab) should be (sizes.next())
      val iscr = RoiImageInfo.getIsCrowd(tab)
      val classes = RoiImageInfo.getClasses(tab)
      classes.size() should be (iscr.size())
      for(i <- 1 to iscr.nElement()) {
        iscr.valueAt(i) should be (isCrowd.next())
        classes.valueAt(i) should be (cateIdx.next())
      }
      RoiImageInfo.getBBoxes(tab).size() should be (Array(classes.size(1), 4))

    })
  }

} 
Example 25
Source File: TextToLabeledSentenceSpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.dataset.text

import java.io.PrintWriter

import com.intel.analytics.bigdl.dataset.DataSet
import com.intel.analytics.bigdl.utils.{Engine, SparkContextLifeCycle}
import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.io.Source

@com.intel.analytics.bigdl.tags.Serial
class TextToLabeledSentenceSpec extends SparkContextLifeCycle with Matchers {
  override def nodeNumber: Int = 1
  override def coreNumber: Int = 1
  override def appName: String = "TextToLabeledSentence"

  "TextToLabeledSentenceSpec" should "indexes sentences correctly on Spark" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DocumentTokenizerSpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"
    val sentence4 = "The Dr. lives in a blue-painted box."

    val sentences = Array(sentence1, sentence2, sentence3, sentence4)

    new PrintWriter(tmpFile) {
      write(sentences.mkString("\n")); close
    }

    val tokens = DataSet.rdd(sc.textFile(tmpFile)
      .filter(!_.isEmpty))
      .transform(SentenceTokenizer())
    val output = tokens.toDistributed().data(train = false)
    val dictionary = Dictionary(output, 100)
    val textToLabeledSentence = TextToLabeledSentence[Float](dictionary)
    val labeledSentences = tokens.transform(textToLabeledSentence)
      .toDistributed().data(false).collect()
    labeledSentences.foreach(x => {
      println("input = " + x.data().mkString(","))
      println("target = " + x.label().mkString(","))
      var i = 1
      while (i < x.dataLength()) {
        x.getData(i) should be (x.getLabel(i - 1))
        i += 1
      }
    })
  }

  "TextToLabeledSentenceSpec" should "indexes sentences correctly on Local" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DocumentTokenizerSpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"
    val sentence4 = "The Dr. lives in a blue-painted box."

    val sentences = Array(sentence1, sentence2, sentence3, sentence4)

    new PrintWriter(tmpFile) {
      write(sentences.mkString("\n")); close
    }

    val logData = Source.fromFile(tmpFile).getLines().toArray
    val tokens = DataSet.array(logData
      .filter(!_.isEmpty))
      .transform(SentenceTokenizer())
    val output = tokens.toLocal().data(train = false)

    val dictionary = Dictionary(output, 100)
    val textToLabeledSentence = TextToLabeledSentence[Float](dictionary)
    val labeledSentences = tokens.transform(textToLabeledSentence)
      .toLocal().data(false)
    labeledSentences.foreach(x => {
      println("input = " + x.data().mkString(","))
      println("target = " + x.label().mkString(","))
      var i = 1
      while (i < x.dataLength()) {
        x.getData(i) should be (x.getLabel(i - 1))
        i += 1
      }
    })

  }
} 
Example 26
Source File: DictionarySpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.dataset.text

import java.io.PrintWriter

import com.intel.analytics.bigdl.dataset.DataSet
import com.intel.analytics.bigdl.utils.Engine
import com.intel.analytics.bigdl.utils.SparkContextLifeCycle
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.io.Source

class DictionarySpec extends SparkContextLifeCycle with Matchers {
  override def nodeNumber: Int = 1
  override def coreNumber: Int = 1
  override def appName: String = "DictionarySpec"

  "DictionarySpec" should "creates dictionary correctly on Spark" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DictionarySpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"

    val sentences = Array(sentence1, sentence2, sentence3)

    new PrintWriter(tmpFile, "UTF-8") {
      write(sentences.mkString("\n")); close
    }

    val tokens = DataSet.rdd(sc.textFile(tmpFile)
      .filter(!_.isEmpty)).transform(SentenceTokenizer())
    val output = tokens.toDistributed().data(train = false)

    val numOfWords = 21

    val dictionary = Dictionary(output, 100)

    dictionary.getVocabSize() should be (numOfWords)
    dictionary.getDiscardSize() should be (0)
    dictionary.print()
    dictionary.printDiscard()
    dictionary.getVocabSize() should be (numOfWords)
    sc.stop()
  }

  "DictionarySpec" should "creates dictionary correctly on local" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DictionarySpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"

    val sentences = Array(sentence1, sentence2, sentence3)

    new PrintWriter(tmpFile, "UTF-8") {
      write(sentences.mkString("\n")); close
    }

    val logData = Source.fromFile(tmpFile, "UTF-8").getLines().toArray
    val tokens = DataSet.array(logData
      .filter(!_.isEmpty)).transform(SentenceTokenizer())
    val output = tokens.toLocal().data(train = false)

    val numOfWords = 21

    val dictionary = Dictionary(output, 100)

    dictionary.getVocabSize() should be (numOfWords)
    dictionary.getDiscardSize() should be (0)
    dictionary.print()
    dictionary.printDiscard()
    dictionary.getVocabSize() should be (numOfWords)
  }
} 
Example 27
Source File: SentenceBiPaddingSpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.dataset.text

import java.io.PrintWriter

import com.intel.analytics.bigdl.dataset.DataSet
import com.intel.analytics.bigdl.dataset.text.utils.SentenceToken
import com.intel.analytics.bigdl.utils.{Engine, SparkContextLifeCycle}
import org.apache.spark.SparkContext
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.io.Source

@com.intel.analytics.bigdl.tags.Serial
class SentenceBiPaddingSpec extends SparkContextLifeCycle with Matchers {
  override def nodeNumber: Int = 1
  override def coreNumber: Int = 1
  override def appName: String = "DocumentTokenizer"

  "SentenceBiPaddingSpec" should "pads articles correctly on Spark" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DocumentTokenizerSpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"
    val sentence4 = "The Dr. lives in a blue-painted box."

    val sentences = Array(sentence1, sentence2, sentence3, sentence4)
    new PrintWriter(tmpFile) {
      write(sentences.mkString("\n")); close
    }

    val sents = DataSet.rdd(sc.textFile(tmpFile)
      .filter(!_.isEmpty)).transform(SentenceSplitter())
      .toDistributed().data(train = false).flatMap(item => item.iterator).collect()
      .asInstanceOf[Array[String]]
    val tokens = DataSet.rdd(sc.parallelize(sents))
      .transform(SentenceBiPadding())
    val output = tokens.toDistributed().data(train = false).collect()

    var count = 0
    println("padding sentences:")
    output.foreach(x => {
      count += x.length
      println(x)
      val words = x.split(" ")
      val startToken = words(0)
      val endToken = words(words.length - 1)
      startToken should be (SentenceToken.start)
      endToken should be (SentenceToken.end)
    })
    sc.stop()
  }

  "SentenceBiPaddingSpec" should "pads articles correctly on local" in {
    val tmpFile = java.io.File
      .createTempFile("UnitTest", "DocumentTokenizerSpec").getPath

    val sentence1 = "Enter Barnardo and Francisco, two sentinels."
    val sentence2 = "Who’s there?"
    val sentence3 = "I think I hear them. Stand ho! Who is there?"
    val sentence4 = "The Dr. lives in a blue-painted box."

    val sentences = Array(sentence1, sentence2, sentence3, sentence4)

    new PrintWriter(tmpFile) {
      write(sentences.mkString("\n")); close
    }

    val logData = Source.fromFile(tmpFile).getLines().toArray
    val sents = DataSet.array(logData
      .filter(!_.isEmpty)).transform(SentenceSplitter())
      .toLocal().data(train = false).flatMap(item => item.iterator)
    val tokens = DataSet.array(sents.toArray)
      .transform(SentenceBiPadding())
    val output = tokens.toLocal().data(train = false).toArray

    var count_word = 0
    println("padding sentences:")
    output.foreach(x => {
      count_word += x.length
      println(x)
      val words = x.split(" ")
      val startToken = words(0)
      val endToken = words(words.length - 1)
      startToken should be (SentenceToken.start)
      endToken should be (SentenceToken.end)
    })
  }
} 
Example 28
Source File: TrainingSpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.keras.nn

import com.intel.analytics.bigdl.dataset.Sample
import com.intel.analytics.bigdl.nn.MSECriterion
import com.intel.analytics.bigdl.nn.keras._
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{Engine, Shape}
import com.intel.analytics.bigdl.optim.{DummyDataSet, SGD, Top1Accuracy}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class TrainingSpec extends FlatSpec with Matchers with BeforeAndAfter {
  private var sc: SparkContext = _
  private val nodeNumber = 1
  private val coreNumber = 4
  var data: RDD[Sample[Float]] = null

  before {
    Engine.setNodeAndCore(nodeNumber, coreNumber)
    sc = new SparkContext(s"local[$coreNumber]", "TrainingSpec")

    data = sc.range(0, 16, 1).map { _ =>
      val featureTensor = Tensor[Float](10)
      featureTensor.apply1(_ => scala.util.Random.nextFloat())
      val labelTensor = Tensor[Float](1)
      labelTensor(Array(1)) = Math.round(scala.util.Random.nextFloat())
      Sample[Float](featureTensor, labelTensor)
    }

  }

  after {
    if (sc != null) {
      sc.stop()
    }
  }

  "sequential compile and fit" should "work properly" in {
    val model = Sequential[Float]()
    model.add(Dense(8, inputShape = Shape(10)))
    model.compile(optimizer = "sgd", loss = "mse", metrics = null)
    model.fit(data, batchSize = 8)
  }

  "graph compile and fit" should "work properly" in {
    val input = Input[Float](inputShape = Shape(10))
    val output = Dense[Float](8, activation = "relu").inputs(input)
    val model = Model[Float](input, output)
    model.compile(optimizer = "adam", loss = "mse", metrics = null)
    model.fit(data, batchSize = 8)
  }

  "sequential compile multiple times" should "use the last compile" in {
    val model = Sequential[Float]()
    model.add(Dense(3, inputShape = Shape(10)))
    model.compile(optimizer = "sgd", loss = "sparse_categorical_crossentropy", metrics = null)
    model.compile(optimizer = "adam", loss = "mse", metrics = null)
    model.fit(data, batchSize = 8)
  }

  "compile, fit with validation, evaluate and predict" should "work properly" in {
    val testData = sc.range(0, 8, 1).map { _ =>
      val featureTensor = Tensor[Float](10)
      featureTensor.apply1(_ => scala.util.Random.nextFloat())
      val labelTensor = Tensor[Float](1)
      labelTensor(Array(1)) = Math.round(scala.util.Random.nextFloat())
      Sample[Float](featureTensor, labelTensor)
    }
    val model = Sequential[Float]()
    model.add(Dense(8, activation = "relu", inputShape = Shape(10)))
    model.compile(optimizer = "sgd", loss = "mse", metrics = Array("accuracy"))
    model.fit(data, batchSize = 8, validationData = testData)
    val accuracy = model.evaluate(testData, batchSize = 8)
    val predictResults = model.predict(testData, batchSize = 8)
  }

  "compile, fit, evaluate and predict in local mode" should "work properly" in {
    val localData = DummyDataSet.mseDataSet
    val model = Sequential[Float]()
    model.add(Dense(8, activation = "relu", inputShape = Shape(4)))
    model.compile(optimizer = "sgd", loss = "mse", metrics = Array("accuracy"))
    model.fit(localData, nbEpoch = 5, validationData = null)
    val accuracy = model.evaluate(localData)
    val predictResults = model.predict(localData)
  }

} 
Example 29
Source File: SparkModeSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.integration

import com.intel.analytics.bigdl.models.lenet
import com.intel.analytics.bigdl.models.vgg
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@com.intel.analytics.bigdl.tags.Integration
class SparkModeSpec extends FlatSpec with Matchers with BeforeAndAfter {

  val mnistFolder = System.getProperty("mnist")
  val cifarFolder = System.getProperty("cifar")

  "Lenet model train and validate" should "be correct" in {
    val batchSize = 8
    val args = Array("--folder", mnistFolder, "-b", batchSize.toString, "-e", "1")
    lenet.Train.main(args)
  }

  "Vgg model train and validate" should "be correct" in {
    val batchSize = 8
    val args = Array("--folder", cifarFolder, "-b", batchSize.toString, "-e", "1")
    vgg.Train.main(args)
  }
} 
Example 30
Source File: DnnTensorSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.tensor

import com.intel.analytics.bigdl.mkl.MklDnn
import com.intel.analytics.bigdl.nn.mkldnn.MemoryOwner
import com.intel.analytics.bigdl.utils.{BigDLSpecHelper, T}
import org.apache.commons.lang3.SerializationUtils
import org.scalatest.BeforeAndAfter

class DnnTensorSpec extends BigDLSpecHelper {
  implicit object Owner extends MemoryOwner {
  }
  "nElement" should "be correct" in {
    val tensor = DnnTensor[Float](3, 4, 5)
    tensor.nElement() should be(3 * 4 * 5)
  }

  "DnnTensor" should "does not support double" in {
    intercept[UnsupportedOperationException] {
      val t = DnnTensor[Double](3, 4, 5)
    }
  }

  "Copy" should "be correct" in {
    val heapTensor = Tensor[Float](T(1, 2, 3, 4))
    val dnnTensor1 = DnnTensor[Float](4)
    dnnTensor1.copy(heapTensor)
    val dnnTensor2 = DnnTensor[Float](4)
    dnnTensor2.copy(dnnTensor1)
    val heapTensor2 = Tensor[Float](4)
    heapTensor2.copy(dnnTensor2)
    heapTensor2 should be(heapTensor)
  }

  "release" should "be correct" in {
    val tensor = DnnTensor[Float](3, 4, 5)
    tensor.isReleased() should be(false)
    tensor.release()
    tensor.isReleased() should be(true)
  }

  "resize" should "be correct" in {
    val tensor = DnnTensor[Float](3, 4)
    tensor.size() should be(Array(3, 4))
    tensor.resize(Array(2, 3))
    tensor.size() should be(Array(2, 3))
    tensor.resize(2)
    tensor.size(1) should be(2)
    tensor.resize(Array(5, 6, 7))
    tensor.size() should be(Array(5, 6, 7))
    tensor.size(2) should be(6)
  }

  "add" should "be correct" in {
    val heapTensor1 = Tensor[Float](T(1, 2, 3, 4))
    val heapTensor2 = Tensor[Float](T(2, 5, 1, 7))
    val dnnTensor1 = DnnTensor[Float](4).copy(heapTensor1)
    val dnnTensor2 = DnnTensor[Float](4).copy(heapTensor2)
    dnnTensor1.add(dnnTensor2)
    val heapTensor3 = Tensor[Float](4).copy(dnnTensor1)
    heapTensor3 should be(Tensor[Float](T(3, 7, 4, 11)))
  }

  "tensor clone with java serialization" should "work correctly" in {
    val heapTensor = Tensor[Float](T(1, 2, 3, 4)).rand(-1, 1)
    val dnnTensor = DnnTensor[Float](4).copy(heapTensor)

    val cloned = SerializationUtils.clone(dnnTensor).asInstanceOf[DnnTensor[Float]]
    val heapCloned = Tensor[Float](4).copy(cloned)

    println(heapTensor)
    println("=" * 80)
    println(heapCloned)

    heapCloned should be (heapTensor)
  }
} 
Example 31
Source File: DistributedSynchronizerSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import com.intel.analytics.bigdl.tensor.Tensor
import org.apache.spark.{SparkContext, TaskContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class DistributedSynchronizerSpec extends FlatSpec with Matchers with BeforeAndAfter {

  var sc: SparkContext = null

  before {
    val conf = Engine.createSparkConf().setAppName("test synchronizer").setMaster("local[4]")
      .set("spark.rpc.message.maxSize", "200")
    sc = new SparkContext(conf)
    Engine.init
  }

  "DistributedSynchronizer" should "work properly" in {
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](10).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 10, weights = null, grads = tensor)
      var res : Iterator[_] = null
      sync.put(s"testPara")
      res = Iterator.single(sync.get(s"testPara"))
      sync.clear
      res
    }).collect
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
  }

  "DistributedSynchronizer with parameter size less than partition" should "work properly" in {
    val cores1 = Runtime.getRuntime().availableProcessors
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](2).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 2, weights = null, grads = tensor)
      var res : Iterator[_] = null
      sync.put(s"testPara")
      res = Iterator.single(sync.get(s"testPara"))
      sync.clear
      res
    }).collect
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](2).fill(2.5f))
  }

  "DistributedSynchronizer with parameter offset > 1" should "work properly" in {
    val partition = 4
    val cores = 4
    val res = sc.parallelize((0 until partition), partition).mapPartitions(p => {
      Engine.setNodeAndCore(partition, cores)
      val partitionID = TaskContext.getPartitionId
      val sync = new BlockManagerParameterSynchronizer[Float](partitionID, partition)
      val tensor = Tensor[Float](20)
      val parameter = tensor.narrow(1, 10, 10).fill(partitionID.toFloat + 1.0f)
      sync.init(s"testPara", 10, weights = null, grads = parameter)
      var res : Iterator[_] = null
      sync.put(s"testPara")
      res = Iterator.single(sync.get(s"testPara"))
      sync.clear
      res
    }).collect
    res.length should be  (4)
    res(0).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(1).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(2).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
    res(3).asInstanceOf[Tuple2[_, _]]._2 should be (Tensor[Float](10).fill(2.5f))
  }

  after {
    sc.stop
  }
} 
Example 32
Source File: TFUtilsSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils.tf

import java.io.File
import java.nio.ByteOrder

import com.google.protobuf.ByteString
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import org.tensorflow.framework.TensorProto

import scala.collection.JavaConverters._

class TFUtilsSpec extends FlatSpec with Matchers with BeforeAndAfter {

  private var constTensors: Map[String, TensorProto] = null
  before {
    constTensors = getConstTensorProto()
  }

  private def getConstTensorProto(): Map[String, TensorProto] = {
    val resource = getClass.getClassLoader.getResource("tf")
    val path = resource.getPath + File.separator + "consts.pbtxt"
    val nodes = TensorflowLoader.parseTxt(path)
    nodes.asScala.map(node => node.getName -> node.getAttrMap.get("value").getTensor).toMap
  }

  "parseTensor " should "work with bool TensorProto" in {
    val tensorProto = constTensors("bool_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Boolean](T(true, false, true, false)))
  }

  "parseTensor " should "work with float TensorProto" in {
    val tensorProto = constTensors("float_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Float](T(1.0f, 2.0f, 3.0f, 4.0f)))
  }

  "parseTensor " should "work with double TensorProto" in {
    val tensorProto = constTensors("double_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Double](T(1.0, 2.0, 3.0, 4.0)))
  }

  "parseTensor " should "work with int TensorProto" in {
    val tensorProto = constTensors("int_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with long TensorProto" in {
    val tensorProto = constTensors("long_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Long](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with int8 TensorProto" in {
    val tensorProto = constTensors("int8_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with uint8 TensorProto" in {
    val tensorProto = constTensors("uint8_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with int16 TensorProto" in {
    val tensorProto = constTensors("int16_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with uint16 TensorProto" in {
    val tensorProto = constTensors("uint16_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    bigdlTensor should be (Tensor[Int](T(1, 2, 3, 4)))
  }

  "parseTensor " should "work with string TensorProto" in {
    import TFTensorNumeric.NumericByteString
    val tensorProto = constTensors("string_const")
    val bigdlTensor = TFUtils.parseTensor(tensorProto, ByteOrder.LITTLE_ENDIAN)
    val data = Array(
      ByteString.copyFromUtf8("a"),
      ByteString.copyFromUtf8("b"),
      ByteString.copyFromUtf8("c"),
      ByteString.copyFromUtf8("d")
    )
    bigdlTensor should be (Tensor[ByteString](data, Array[Int](4)))
  }
} 
Example 33
Source File: ShapeSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class ShapeSpec extends FlatSpec with Matchers with BeforeAndAfter {

  "update of SingleShape" should "be test" in {
    assert(Shape(1, 2, 3).copyAndUpdate(-1, 20) == Shape(1, 2, 20))
  }

  "update of MultiShape" should "be test" in {
    val multiShape = Shape(List(Shape(1, 2, 3), Shape(4, 5, 6)))
    assert(multiShape.copyAndUpdate(-1, Shape(5, 5, 5)) ==
      Shape(List(Shape(1, 2, 3), Shape(5, 5, 5))))
  }

  "multiShape not equal" should "be test" in {
    intercept[RuntimeException] {
      assert(Shape(List(Shape(1, 2, 3), Shape(5, 5, 5))) ==
        Shape(List(Shape(1, 2, 3), Shape(5, 6, 5))))
    }}

  "singleShape not equal" should "be test" in {
    intercept[RuntimeException] {
      assert(Shape(1, 2, 3) == Shape(1, 2, 4))
    }}
} 
Example 34
Source File: BigDLSpecHelper.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import java.io.{File => JFile}

import org.apache.log4j.Logger
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer

abstract class BigDLSpecHelper extends FlatSpec with Matchers with BeforeAndAfter {
  protected val logger = Logger.getLogger(getClass)

  private val tmpFiles : ArrayBuffer[JFile] = new ArrayBuffer[JFile]()

  protected def createTmpFile(): JFile = {
    val file = java.io.File.createTempFile("UnitTest", "BigDLSpecBase")
    logger.info(s"created file $file")
    tmpFiles.append(file)
    file
  }

  protected def getFileFolder(path: String): String = {
    path.substring(0, path.lastIndexOf(JFile.separator))
  }

  protected def getFileName(path: String): String = {
    path.substring(path.lastIndexOf(JFile.separator) + 1)
  }

  def doAfter(): Unit = {}

  def doBefore(): Unit = {}

  before {
    doBefore()
  }

  after {
    doAfter()
    tmpFiles.foreach(f => {
      if (f.exists()) {
        require(f.isFile, "cannot clean folder")
        f.delete()
        logger.info(s"deleted file $f")
      }
    })
  }
} 
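A minimal usage sketch (not taken from the BigDL sources; the spec name and its contents are illustrative) of how a concrete spec builds on BigDLSpecHelper: doBefore/doAfter replace the usual before/after blocks, and any file obtained from createTmpFile is cleaned up automatically by the helper's after hook.

class TmpFileUsageSpec extends BigDLSpecHelper {
  private var workFile: java.io.File = _

  // Runs inside the helper's `before` block.
  override def doBefore(): Unit = {
    workFile = createTmpFile()
  }

  "a temp file registered via createTmpFile" should "exist while the test runs" in {
    workFile.exists() should be(true)
  }
}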
Example 35
Source File: ZippedPartitionsWithLocalityRDDSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.ZippedPartitionsWithLocalityRDD
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@com.intel.analytics.bigdl.tags.Serial
class ZippedPartitionsWithLocalityRDDSpec extends SparkContextLifeCycle with Matchers {
  override def coreNumber: Int = 4
  override def appName: String = "ZippedPartitionsWithLocalityRDDSpec"

  "two uncached rdd zip partition" should "not throw exception" in {
    val rdd1 = sc.parallelize((1 to 100), 4)
    val rdd2 = sc.parallelize((1 to 100), 4)
      ZippedPartitionsWithLocalityRDD(rdd1, rdd2)((iter1, iter2) => {
        iter1.zip(iter2)
      }).count()
  }

  "one uncached rdd zip partition" should "not throw exception" in {
    val rdd1 = sc.parallelize((1 to 100), 4).cache()
    val rdd2 = sc.parallelize((1 to 100), 4)
      ZippedPartitionsWithLocalityRDD(rdd1, rdd2)((iter1, iter2) => {
        iter1.zip(iter2)
      }).count()
  }

  "two cached rdd zip partition" should "should be zip" in {
    val rdd1 = sc.parallelize((1 to 100), 4).repartition(4).cache()
    val rdd2 = sc.parallelize((1 to 100), 4).repartition(4).cache()

    rdd1.count()
    rdd2.count()
    rdd2.count() // need to count twice

    ZippedPartitionsWithLocalityRDD(rdd1, rdd2)((iter1, iter2) => {
      iter1.zip(iter2)
    }).count()
  }
} 
Example 36
Source File: RandomGeneratorSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@com.intel.analytics.bigdl.tags.Parallel
class RandomGeneratorSpec extends FlatSpec with BeforeAndAfter with Matchers {
  "uniform" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.uniform(0, 1) should be(0.543404 +- 1e-6)
    a.uniform(0, 1) should be(0.671155 +- 1e-6)
    a.uniform(0, 1) should be(0.278369 +- 1e-6)
  }

  "normal" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.normal(0, 1) should be(-1.436301 +- 1e-6)
    a.normal(0, 1) should be(-0.401719 +- 1e-6)
    a.normal(0, 1) should be(-0.182739 +- 1e-6)
  }

  "exponential" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.exponential(1) should be(0.783958 +- 1e-6)
    a.exponential(1) should be(1.112170 +- 1e-6)
    a.exponential(1) should be(0.326241 +- 1e-6)
  }

  "cauchy" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.cauchy(1, 1) should be(1.137212 +- 1e-6)
    a.cauchy(1, 1) should be(1.596309 +- 1e-6)
    a.cauchy(1, 1) should be(0.164062 +- 1e-6)
  }

  "logNormal" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.logNormal(1, 1) should be(0.213872 +- 1e-6)
    a.logNormal(1, 1) should be(0.506097 +- 1e-6)
    a.logNormal(1, 1) should be(0.607310 +- 1e-6)
  }

  "geometric" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.geometric(0.5) should be(2)
    a.geometric(0.5) should be(2)
    a.geometric(0.5) should be(1)
  }

  "bernoulli" should "return correct value" in {
    val a = new RandomGenerator(100)
    a.bernoulli(0.5) should be(false)
    a.bernoulli(0.5) should be(false)
    a.bernoulli(0.5) should be(true)
  }
} 
Example 37
Source File: SparkContextLifeCycle.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.utils

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}


// Reusable Spark lifecycle for BigDL specs: a SparkContext is created before each
// test and stopped afterwards, with beforeTest/afterTest hooks for extra setup and teardown.
trait SparkContextLifeCycle extends FlatSpec with BeforeAndAfter {

  var sc: SparkContext = null

  def nodeNumber: Int = 1
  def coreNumber: Int = 1
  def appName: String

  def beforeTest: Any = {}

  def afterTest: Any = {}

  before {
    Engine.init(nodeNumber, coreNumber, true)
    val conf = Engine.createSparkConf().setMaster(s"local[$coreNumber]").setAppName(appName)
    sc = SparkContext.getOrCreate(conf)
    beforeTest
  }

  after {
    if (sc != null) {
      sc.stop()
    }
    afterTest
  }
} 
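The specs further down (EvaluatorSpec, ValidatorSpec, ZippedPartitionsWithLocalityRDDSpec) only override nodeNumber, coreNumber and appName. As a quick illustration of the beforeTest/afterTest hooks, here is a minimal sketch that is not part of the BigDL sources (the spec name and its body are illustrative):

class CachedRddSpec extends SparkContextLifeCycle with Matchers {
  override def appName: String = "CachedRddSpec"
  override def coreNumber: Int = 2

  private var cached: org.apache.spark.rdd.RDD[Int] = _

  // Runs at the end of the trait's `before` block, once sc is available.
  override def beforeTest: Any = {
    cached = sc.parallelize(1 to 10).cache()
  }

  // Runs in the trait's `after` block, after the SparkContext has been stopped.
  override def afterTest: Any = {
    cached = null
  }

  "an RDD cached in beforeTest" should "be visible to the tests" in {
    cached.count() should be(10)
  }
}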
Example 38
Source File: BifurcateSplitTableSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.util.Random

@com.intel.analytics.bigdl.tags.Serial
class SplitTableSpec extends FlatSpec with BeforeAndAfter with Matchers {

  "A BifurcateSplitTable " should "generate correct output and grad" in {
    val seed = 100
    Random.setSeed(seed)

    val dim = 2
    val module = new BifurcateSplitTable[Double](dim)
    val input = Tensor[Double](3, 4).randn()
    val expectedGradInput = Tensor[Double]().resizeAs(input).randn()
    val gradOutput = T(expectedGradInput.narrow(dim, 1, 2), expectedGradInput.narrow(dim, 3, 2))

    val output = module.forward(input)
    val gradInput = module.backward(input, gradOutput)

    output.length() should be (2)
    val left = output(1).asInstanceOf[Tensor[Double]]
    val right = output(2).asInstanceOf[Tensor[Double]]
    left should be (input.narrow(dim, 1, 2))
    right should be (input.narrow(dim, 3, 2))

    gradInput should be (expectedGradInput)
  }
}

class BifurcateSplitTableSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val batchNorm = BifurcateSplitTable[Float](1).setName("batchNorm")
    val input = Tensor[Float](2, 5).apply1(_ => Random.nextFloat())
    runSerializationTest(batchNorm, input)
  }
}

class SplitTableSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val splitTable = SplitTable[Float](2).setName("splitTable")
    val input = Tensor[Float](2, 10).apply1( e => Random.nextFloat())
    runSerializationTest(splitTable, input)
  }
} 
Example 39
Source File: EvaluatorSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.dataset.{DataSet, MiniBatch, Sample}
import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.nn.CrossEntropyCriterion
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{Engine, MklBlas, MklDnn, SparkContextLifeCycle}
import com.intel.analytics.bigdl.utils.RandomGenerator._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import com.intel.analytics.bigdl._

class EvaluatorSpec extends SparkContextLifeCycle with Matchers {

  override def nodeNumber: Int = 1
  override def coreNumber: Int = 1
  override def appName: String = "evaluator"

  private def processPath(path: String): String = {
    if (path.contains(":")) {
      path.substring(1)
    } else {
      path
    }
  }

  "Evaluator" should "be correct" in {
    RNG.setSeed(100)
    val tmp = new Array[Sample[Float]](100)
    var i = 0
    while (i < tmp.length) {
      val input = Tensor[Float](28, 28).fill(0.8f)
      val label = Tensor[Float](1).fill(1.0f)
      tmp(i) = Sample(input, label)
      i += 1
    }
    val model = LeNet5(classNum = 10)
    val dataSet = DataSet.array(tmp, sc).toDistributed().data(train = false)

    val result = model.evaluate(dataSet, Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
      new Loss[Float](CrossEntropyCriterion[Float]())))

    result(0)._1 should be (new AccuracyResult(0, 100))
    result(1)._1 should be (new AccuracyResult(100, 100))
    result(2)._1 should be (new LossResult(230.44278f, 100))
    result(0)._1.result()._1 should be (0f)
    result(1)._1.result()._1 should be (1f)
    result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
  }

  "Evaluator MiniBatch" should "be correct" in {
    RNG.setSeed(100)
    val tmp = new Array[MiniBatch[Float]](25)
    var i = 0
    while (i < tmp.length) {
      val input = Tensor[Float](4, 28, 28).fill(0.8f)
      val label = Tensor[Float](4).fill(1.0f)
      tmp(i) = MiniBatch(input, label)
      i += 1
    }
    val model = LeNet5(classNum = 10)
    val dataSet = DataSet.array(tmp, sc).toDistributed().data(train = false)

    val result = model.evaluate(dataSet, Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
      new Loss[Float](CrossEntropyCriterion[Float]())))

    result(0)._1 should be (new AccuracyResult(0, 100))
    result(1)._1 should be (new AccuracyResult(100, 100))
    result(2)._1 should be (new LossResult(230.44278f, 100))
    result(0)._1.result()._1 should be (0f)
    result(1)._1.result()._1 should be (1f)
    result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
  }

  "Evaluator different MiniBatch" should "be correct" in {
    RNG.setSeed(100)
    val tmp = new Array[MiniBatch[Float]](25)
    var i = 1
    while (i <= tmp.length) {
      val input = Tensor[Float](i, 28, 28).fill(0.8f)
      val label = Tensor[Float](i).fill(1.0f)
      tmp(i - 1) = MiniBatch(input, label)
      i += 1
    }
    val model = LeNet5(classNum = 10)
    val dataSet = DataSet.array(tmp, sc).toDistributed().data(train = false)

    val result = model.evaluate(dataSet, Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
      new Loss[Float](CrossEntropyCriterion[Float]())))

    result(0)._1 should be (new AccuracyResult(0, 325))
    result(1)._1 should be (new AccuracyResult(325, 325))
    result(2)._1 should be (new LossResult(748.93896f, 325))
    result(0)._1.result()._1 should be (0f)
    result(1)._1.result()._1 should be (1f)
    result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
  }
} 
Example 40
Source File: ParallelOptimizerSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.dataset.{DataSet, MiniBatch}
import com.intel.analytics.bigdl.nn.{ClassNLLCriterion, Linear, MSECriterion}
import com.intel.analytics.bigdl.optim.DistriOptimizerSpecModel.mse
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{Engine, T}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@com.intel.analytics.bigdl.tags.Serial
class ParallelOptimizerSpec extends FlatSpec with Matchers with BeforeAndAfter {

  Logger.getLogger("org").setLevel(Level.WARN)
  Logger.getLogger("akka").setLevel(Level.WARN)

  private var sc: SparkContext = _

  before {
    val conf = Engine.createSparkConf()
      .setMaster("local[1]").setAppName("ParallelOptimizerSpec")
    sc = new SparkContext(conf)
    Engine.init
    Engine.setCoreNumber(1)
  }

  after {
    if (sc != null) {
      sc.stop()
    }
  }

  "Train with parallel" should "work properly" in {
    val input = Tensor[Float](1, 10).fill(1.0f)
    val target = Tensor[Float](1).fill(1.0f)
    val miniBatch = MiniBatch(input, target)
    val model = Linear[Float](10, 2)
    model.getParameters()._1.fill(1.0f)
    val optimMethod = new SGD[Float]()

    val dataSet = DataSet.array(Array(miniBatch), sc)

    val optimizer = new DistriOptimizer[Float](model, dataSet, new ClassNLLCriterion[Float]())
      .setState(T("learningRate" -> 1.0))
      .setEndWhen(Trigger.maxIteration(10))

    optimizer.optimize()

  }

  "Train with parallel" should "have same results as DistriOptimizer" in {

    val input = Tensor[Float](1, 10).fill(1.0f)
    val target = Tensor[Float](1).fill(1.0f)
    val miniBatch = MiniBatch(input, target)
    val model1 = Linear[Float](10, 2)
    model1.getParameters()._1.fill(1.0f)

    val model2 = Linear[Float](10, 2)
    model2.getParameters()._1.fill(1.0f)

    val dataSet = DataSet.array(Array(miniBatch), sc)

    val parallelOptimizer = new DistriOptimizer[Float](model1,
      dataSet, new ClassNLLCriterion[Float]())
      .setState(T("learningRate" -> 1.0))
      .setEndWhen(Trigger.maxIteration(10))

    parallelOptimizer.optimize

    val distriOptimizer = new DistriOptimizer[Float](model2,
      dataSet, new ClassNLLCriterion[Float]())
      .setState(T("learningRate" -> 1.0))
      .setEndWhen(Trigger.maxIteration(10))

    distriOptimizer.optimize

    model1.getParameters()._1 should be (model2.getParameters()._1)

  }

} 
Example 41
Source File: AdamSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.nn.{CrossEntropyCriterion, Linear, Sequential}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.{Engine, RandomGenerator, T, TestUtils}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

@com.intel.analytics.bigdl.tags.Parallel
class AdamSpec extends FlatSpec with Matchers with BeforeAndAfter {

  before {
    System.setProperty("bigdl.localMode", "true")
    System.setProperty("spark.master", "local[2]")
    Engine.init
  }

  after {
    System.clearProperty("bigdl.localMode")
    System.clearProperty("spark.master")
  }


  val start = System.currentTimeMillis()
  "adam" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val config = T("learningRate" -> 0.002)
    val optm = new Adam[Double]
    var fx = new ArrayBuffer[Double]
    for (i <- 1 to 10001) {
      val result = optm.optimize(TestUtils.rosenBrock, x, config)
      if ((i - 1) % 1000 == 0) {
        fx += result._2(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    for (i <- 1 to fx.length) {
      println(s"${(i - 1) * 1000 + 1}, ${fx(i - 1)}")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }

  "ParallelAdam" should "perform well on rosenbrock function" in {
    val x = Tensor[Double](2).fill(0)
    val optm = new ParallelAdam[Double](learningRate = 0.002, parallelNum = 2)
    var fx = new ArrayBuffer[Double]
    for (i <- 1 to 10001) {
      val result = optm.optimize(TestUtils.rosenBrock, x)
      if ((i - 1) % 1000 == 0) {
        fx += result._2(0)
      }
    }

    println(s"x is \n$x")
    println("fx is")
    for (i <- 1 to fx.length) {
      println(s"${(i - 1) * 1000 + 1}, ${fx(i - 1)}")
    }

    val spend = System.currentTimeMillis() - start
    println("Time Cost: " + spend + "ms")

    (fx.last < 1e-9) should be(true)
    x(Array(1)) should be(1.0 +- 0.01)
    x(Array(2)) should be(1.0 +- 0.01)
  }

} 
Example 42
Source File: ValidatorSpec.scala    From BigDL   with Apache License 2.0 5 votes vote down vote up
package com.intel.analytics.bigdl.optim

import com.intel.analytics.bigdl.dataset.{DataSet, Sample, SampleToMiniBatch}
import com.intel.analytics.bigdl.models.lenet.LeNet5
import com.intel.analytics.bigdl.nn.CrossEntropyCriterion
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.utils.Engine
import com.intel.analytics.bigdl.utils.RandomGenerator._
import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.utils.SparkContextLifeCycle
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}


class ValidatorSpec extends SparkContextLifeCycle with Matchers {

  override def nodeNumber: Int = 1
  override def coreNumber: Int = 1
  override def appName: String = "validator"

  private def processPath(path: String): String = {
    if (path.contains(":")) {
      path.substring(1)
    } else {
      path
    }
  }

  "DistriValidator" should "be correct" in {
    RNG.setSeed(100)
    val tmp = new Array[Sample[Float]](100)
    var i = 0
    while (i < tmp.length) {
      val input = Tensor[Float](28, 28).fill(0.8f)
      val label = Tensor[Float](1).fill(1.0f)
      tmp(i) = Sample(input, label)
      i += 1
    }
    val model = LeNet5(classNum = 10)
    val dataSet = DataSet.array(tmp, sc).transform(SampleToMiniBatch(1))
    val validator = Validator(model, dataSet)

    val result = validator.test(Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
      new Loss[Float](CrossEntropyCriterion[Float]())))

    result(0)._1 should be (new AccuracyResult(0, 100))
    result(1)._1 should be (new AccuracyResult(100, 100))
    result(2)._1 should be (new LossResult(230.4428f, 100))
    result(0)._1.result()._1 should be (0f)
    result(1)._1.result()._1 should be (1f)
    result(2)._1.result()._1 should be (2.3044279f+-0.000001f)
  }

  "LocalValidator" should "be correct" in {
    RNG.setSeed(100)
    val tmp = new Array[Sample[Float]](100)
    var i = 0
    while (i < tmp.length) {
      val input = Tensor[Float](28, 28).fill(0.8f)
      val label = Tensor[Float](1).fill(1.0f)
      tmp(i) = Sample(input, label)
      i += 1
    }
    val model = LeNet5(classNum = 10)
    val dataSet = DataSet.array(tmp).transform(SampleToMiniBatch(1))
    val validator = Validator(model, dataSet)

    val result = validator.test(Array(new Top1Accuracy[Float](), new Top5Accuracy[Float](),
      new Loss[Float](CrossEntropyCriterion[Float]())))

    result(0)._1 should be (new AccuracyResult(0, 100))
    result(1)._1 should be (new AccuracyResult(100, 100))
    result(2)._1 should be (new LossResult(230.4428f, 100))
  }
} 
Example 43
Source File: ClockProviderSpec.scala    From chronoscala   with MIT License 5 votes vote down vote up
package jp.ne.opt.chronoscala

import java.time.{Clock, ZoneId}

import jp.ne.opt.chronoscala.Imports._
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec

class ClockProviderSpec extends AnyFlatSpec with BeforeAndAfter {

  after {
    ClockProvider.setCurrentClockSystem()
  }

  it should "set current clock" in {

    ClockProvider.setCurrentClock(Clock.fixed(Instant.ofEpochMilli(0L), ZoneId.of("UTC")))

    assert(Instant.now() == Instant.ofEpochMilli(0L))
    assert(LocalDate.now() == LocalDate.parse("1970-01-01"))
    assert(LocalDateTime.now() == LocalDateTime.parse("1970-01-01T00:00:00.000"))
    assert(LocalTime.now() == LocalTime.parse("00:00:00.000"))
    assert(ZonedDateTime.now() == ZonedDateTime.parse("1970-01-01T00:00:00.000+00:00[UTC]"))
    assert(OffsetDateTime.now() == OffsetDateTime.parse("1970-01-01T00:00:00.000+00:00"))
  }
} 
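The spec above only resets the clock in after and fixes it inside the test body. A common variant, sketched below with the same chronoscala API (the spec name is illustrative), fixes the clock in before so that every test in the suite observes the same instant:

package jp.ne.opt.chronoscala

import java.time.{Clock, ZoneId}

import jp.ne.opt.chronoscala.Imports._
import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec

class FixedClockSpec extends AnyFlatSpec with BeforeAndAfter {

  before {
    ClockProvider.setCurrentClock(Clock.fixed(Instant.ofEpochMilli(0L), ZoneId.of("UTC")))
  }

  after {
    ClockProvider.setCurrentClockSystem()
  }

  it should "observe the fixed instant" in {
    assert(Instant.now() == Instant.ofEpochMilli(0L))
  }
}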
Example 44
Source File: SparkFunSuite.scala    From spark-alchemy   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

// scalastyle:off
import java.io.File

import scala.annotation.tailrec
import org.apache.log4j.{Appender, Level, Logger}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, BeforeAndAfterEach, FunSuite, Outcome, Suite}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.util.{AccumulatorContext, Utils}


abstract class SparkFunSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach with Logging {

  // Temporarily attaches `appender` (optionally adjusting the log level) around `f`,
  // restoring the logger's previous state afterwards.
  protected def withLogAppender(
    appender: Appender,
    loggerName: Option[String] = None,
    level: Option[Level] = None)(
    f: => Unit): Unit = {
    val logger = loggerName.map(Logger.getLogger).getOrElse(Logger.getRootLogger)
    val restoreLevel = logger.getLevel
    logger.addAppender(appender)
    if (level.isDefined) {
      logger.setLevel(level.get)
    }
    try f finally {
      logger.removeAppender(appender)
      if (level.isDefined) {
        logger.setLevel(restoreLevel)
      }
    }
  }
} 
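A minimal sketch of how a suite extending the helper above might use withLogAppender; CapturingAppender, LogCaptureSuite and the log message are illustrative, and only the log4j 1.x Appender API already imported above is assumed.

import org.apache.log4j.{AppenderSkeleton, Logger}
import org.apache.log4j.spi.LoggingEvent

import scala.collection.mutable.ArrayBuffer

// Collects every event routed to it so the test can assert on log output.
class CapturingAppender extends AppenderSkeleton {
  val events = ArrayBuffer.empty[LoggingEvent]
  override def append(event: LoggingEvent): Unit = events += event
  override def close(): Unit = {}
  override def requiresLayout(): Boolean = false
}

class LogCaptureSuite extends SparkFunSuite {
  test("log output is captured while the appender is attached") {
    val appender = new CapturingAppender
    withLogAppender(appender) {
      Logger.getRootLogger.info("hello from the test")
    }
    assert(appender.events.exists(_.getRenderedMessage.contains("hello from the test")))
  }
}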
Example 45
Source File: MyJournalSpec.scala    From akka-tools   with MIT License 5 votes vote down vote up
package no.nextgentel.oss.akkatools.persistence.jdbcjournal

import akka.persistence.CapabilityFlag
import akka.persistence.journal.JournalSpec
import akka.persistence.snapshot.SnapshotStoreSpec
import com.typesafe.config.ConfigFactory
import org.scalatest.BeforeAndAfter
import org.slf4j.LoggerFactory

class MyJournalSpec extends JournalSpec (
  config = ConfigFactory.parseString(
    s"""
       |akka.persistence.query.jdbc-read-journal.configName = MyJournalSpec
       |jdbc-journal.configName = MyJournalSpec
       |jdbc-snapshot-store.configName = MyJournalSpec
     """.stripMargin).withFallback(ConfigFactory.load("application-test.conf"))) {

  val log = LoggerFactory.getLogger(getClass)

  val errorHandler = new JdbcJournalErrorHandler {
    override def onError(e: Exception): Unit = log.error("JdbcJournalErrorHandler.onError", e)
  }

  JdbcJournalConfig.setConfig("MyJournalSpec", JdbcJournalConfig(DataSourceUtil.createDataSource("MyJournalSpec"), Some(errorHandler), StorageRepoConfig(), new PersistenceIdParserImpl('-')))

  override protected def supportsRejectingNonSerializableObjects: CapabilityFlag = false
}

class MySnapshotStoreSpec extends SnapshotStoreSpec (
  config = ConfigFactory.parseString(
    s"""
       |akka.persistence.query.jdbc-read-journal.configName = MySnapshotStoreSpec
       |jdbc-journal.configName = MySnapshotStoreSpec
       |jdbc-snapshot-store.configName = MySnapshotStoreSpec
     """.stripMargin).withFallback(ConfigFactory.load("application-test.conf"))) with BeforeAndAfter {

  val log = LoggerFactory.getLogger(getClass)

  val errorHandler = new JdbcJournalErrorHandler {
    override def onError(e: Exception): Unit = log.error("JdbcJournalErrorHandler.onError", e)
  }

  JdbcJournalConfig.setConfig("MySnapshotStoreSpec", JdbcJournalConfig(DataSourceUtil.createDataSource("MySnapshotStoreSpec"), None, StorageRepoConfig(), new PersistenceIdParserImpl('-')))

} 
Example 46
Source File: ClusterSingletonHelperTest.scala    From akka-tools   with MIT License 5 votes vote down vote up
package no.nextgentel.oss.akkatools.cluster

import akka.actor.{Actor, ActorRef, ActorSystem, Props}
import akka.testkit.{TestKit, TestProbe}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuiteLike, Matchers}
import org.slf4j.LoggerFactory

import scala.util.Random

object ClusterSingletonHelperTest {
  val port = 20000 + Random.nextInt(20000)
}

class ClusterSingletonHelperTest (_system:ActorSystem) extends TestKit(_system) with FunSuiteLike with Matchers with BeforeAndAfterAll with BeforeAndAfter {

  def this() = this(ActorSystem("test-actor-system", ConfigFactory.parseString(
      s"""akka.actor.provider = "akka.cluster.ClusterActorRefProvider"
          |akka.remote.enabled-transports = ["akka.remote.netty.tcp"]
          |akka.remote.netty.tcp.hostname="localhost"
          |akka.remote.netty.tcp.port=${ClusterSingletonHelperTest.port}
          |akka.cluster.seed-nodes = ["akka.tcp://test-actor-system@localhost:${ClusterSingletonHelperTest.port}"]
    """.stripMargin
    ).withFallback(ConfigFactory.load("application-test.conf"))))

  override def afterAll {
    TestKit.shutdownActorSystem(system)
  }

  val log = LoggerFactory.getLogger(getClass)


  test("start and communicate with cluster-singleton") {


    val started = TestProbe()
    val proxy = ClusterSingletonHelper.startClusterSingleton(system, Props(new OurClusterSingleton(started.ref)), "ocl")
    started.expectMsg("started")
    val sender = TestProbe()
    sender.send(proxy, "ping")
    sender.expectMsg("pong")

  }
}

class OurClusterSingleton(started:ActorRef) extends Actor {

  started ! "started"
  def receive = {
    case "ping" => sender ! "pong"
  }
} 
Example 47
package no.nextgentel.oss.akkatools.aggregate.aggregateTest_usingAggregateStateBase

import java.util.UUID

import akka.actor.{ActorPath, ActorSystem, Props}
import akka.persistence.{DeleteMessagesFailure, DeleteMessagesSuccess, SaveSnapshotFailure, SaveSnapshotSuccess, SnapshotMetadata, SnapshotOffer}
import akka.testkit.{TestKit, TestProbe}
import com.typesafe.config.ConfigFactory
import no.nextgentel.oss.akkatools.aggregate._
import no.nextgentel.oss.akkatools.testing.AggregateTesting
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuiteLike, Matchers}
import org.slf4j.LoggerFactory



  override def onSnapshotOffer(offer: SnapshotOffer): Unit = {
    state = offer.snapshot.asInstanceOf[StringState]
  }

  override def acceptSnapshotRequest(req: SaveSnapshotOfCurrentState): Boolean = {
    if (state == StringState("WAT")) {
      state = StringState("SAVED")
      true
    }
    else {
      state = StringState("WAT") //So it works second time
      false
    }
  }

  override def onSnapshotSuccess(success: SaveSnapshotSuccess): Unit = {
    state = StringState("SUCCESS_SNAP")
  }

  override def onSnapshotFailure(failure: SaveSnapshotFailure): Unit = {
    state = StringState("FAIL_SNAP")
  }

  override def onDeleteMessagesSuccess(success: DeleteMessagesSuccess): Unit = {
    state = StringState("SUCCESS_MSG")
  }

  override def onDeleteMessagesFailure(failure: DeleteMessagesFailure): Unit = {
    state = StringState("FAIL_MSG")
  }

  // Used as the prefix/base when constructing the persistenceId - the unique ID is extracted at runtime from the actorPath, which is constructed by the Sharding coordinator
  override def persistenceIdBase(): String = "/x/"
}

case class StringEv(data: String)

case class StringState(data:String) extends AggregateStateBase[StringEv, StringState] {
  override def transitionState(event: StringEv): StateTransition[StringEv, StringState] =
    StateTransition(StringState(event.data))
} 
Example 48
Source File: ActorWithDMSupportTest.scala    From akka-tools   with MIT License 5 votes vote down vote up
package no.nextgentel.oss.akkatools.persistence

import java.util.concurrent.TimeUnit

import akka.actor.{Props, ActorSystem}
import akka.testkit.{TestProbe, TestKit}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, FunSuiteLike}

import scala.concurrent.duration.FiniteDuration

class ActorWithDMSupportTest(_system:ActorSystem) extends TestKit(_system) with FunSuiteLike with Matchers with BeforeAndAfterAll with BeforeAndAfter {
  def this() = this(ActorSystem("ActorWithDMSupportTest", ConfigFactory.load("application-test.conf")))

  test("success with dm") {
    val a = system.actorOf(Props(new TestActorWithDMSupport()))
    val s = TestProbe()

    // send raw
    s.send(a, "sendok")
    s.expectMsg("ok")

    // send via dm and withNewPayload
    val dm = DurableMessage(1L, "sendok", s.ref.path)
    s.send(a, dm)
    s.expectMsg(dm.withNewPayload("ok"))

    // send raw - do nothing
    s.send(a, "silent")


    // send silent - wait for confirm
    s.send(a, DurableMessage(1L, "silent", s.ref.path))
    s.expectMsg( DurableMessageReceived(1,None) )


    // send noconfirm - with dm
    s.send(a, DurableMessage(1L, "no-confirm", s.ref.path))
    s.expectNoMessage(FiniteDuration(500, TimeUnit.MILLISECONDS))

    // send noconfirm - with dm
    s.send(a, DurableMessage(1L, "no-confirm-custom", s.ref.path))
    s.expectNoMessage(FiniteDuration(500, TimeUnit.MILLISECONDS))

    // send noconfirm - without dm
    s.send(a, "no-confirm")
    s.expectNoMessage(FiniteDuration(500, TimeUnit.MILLISECONDS))

    // send noconfirm - without dm
    s.send(a, "no-confirm-custom")
    s.expectNoMessage(FiniteDuration(500, TimeUnit.MILLISECONDS))

  }


}

class TestActorWithDMSupport extends ActorWithDMSupport {
  // All raw messages or payloads in DMs are passed to this function.
  override def receivePayload = {
    case "sendok" =>
      send(sender.path, "ok")
    case "silent" =>
      () // deliberately do nothing
    case "no-confirm" =>
      throw new LogWarningAndSkipDMConfirmException("something went wrong")
    case "no-confirm-custom" =>
      throw new CustomLogWarningAndSkipDMConfirm()
  }
}

class CustomLogWarningAndSkipDMConfirm extends Exception("") with LogWarningAndSkipDMConfirm 
Example 49
Source File: ReadWriteTests.scala    From spark-cdm   with MIT License 5 votes vote down vote up
package com.microsoft.cdm.test

import com.microsoft.cdm.utils.{AADProvider, ADLGen2Provider, CDMModel}
import org.apache.spark.sql.{Row, SparkSession}
import org.scalatest.{BeforeAndAfter, FunSuite}

import scala.util.{Random, Try}

class ReadWriteTests extends FunSuite with BeforeAndAfter {

  private val appId = sys.env("APP_ID")
  private val appKey = sys.env("APP_KEY")
  private val tenantId = sys.env("TENANT_ID")
  private val inputModel = sys.env("DEMO_INPUT_MODEL")
  private val outputTestDir = sys.env("OUTPUT_TEST_DIR")
  private val spark = SparkSession.builder().master("local").appName("DemoTest").getOrCreate()
  spark.sparkContext.setLogLevel("ERROR")

  test("read and write basic CDM folders") {
    val entities = getEntities(inputModel)

    val outputDir = outputTestDir + "output" + Random.alphanumeric.take(5).mkString("") + "/"
    val outputModelName = Random.alphanumeric.take(5).mkString("")

    val collections: Map[String, Int] = entities map (t => t ->
      readWrite(inputModel, t, write=true, outputDir, outputModelName).length) toMap

    val outputModel = outputDir + "model.json"
    val outputEntities = getEntities(outputModel)

    assert(outputEntities.size == entities.size)

    collections.foreach{case (entity, size) =>
        assert(size == readWrite(outputModel, entity, write=false).length)
    }

    println("Done!")
  }

  private def getEntities(modelUri: String): Iterable[String] = {
    val aadProvider = new AADProvider(appId, appKey, tenantId)
    val adls = new ADLGen2Provider(aadProvider)
    val modelJson = new CDMModel(adls.getFullFile(modelUri))
    modelJson.listEntities()
  }

  private def readWrite(modelUri: String,
                        entity: String,
                        write: Boolean = true,
                        outputDirectory: String = "",
                        outputModelName: String = ""): Array[Row] = {
    println("%s : %s".format(entity, modelUri))

    val df = spark.read.format("com.microsoft.cdm")
      .option("cdmModel", modelUri)
      .option("entity", entity)
      .option("appId", appId)
      .option("appKey", appKey)
      .option("tenantId", tenantId)
      .load()
    val collection = Try(df.collect()).getOrElse(null)

    if(write) {
      df.write.format("com.microsoft.cdm")
        .option("entity", entity)
        .option("appId", appId)
        .option("appKey", appKey)
        .option("tenantId", tenantId)
        .option("cdmFolder", outputDirectory)
        .option("cdmModelName", outputModelName)
        .save()
    }

    collection
  }

} 
Example 50
Source File: CitizenDetailsServiceSpec.scala    From nisp-frontend   with Apache License 2.0 5 votes vote down vote up
package uk.gov.hmrc.nisp.services

import org.joda.time.LocalDate
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.mock.MockitoSugar
import org.scalatestplus.play.OneAppPerSuite
import uk.gov.hmrc.domain.Nino
import uk.gov.hmrc.http.{HeaderCarrier, Upstream5xxResponse}
import uk.gov.hmrc.nisp.helpers.{MockCitizenDetailsService, TestAccountBuilder}
import uk.gov.hmrc.nisp.models.citizen.{Address, Citizen, CitizenDetailsError, CitizenDetailsResponse}
import uk.gov.hmrc.play.test.UnitSpec

import scala.concurrent.Future

class CitizenDetailsServiceSpec extends UnitSpec with MockitoSugar with BeforeAndAfter with ScalaFutures with OneAppPerSuite {
  val nino: Nino = TestAccountBuilder.regularNino
  val noNameNino: Nino = TestAccountBuilder.noNameNino
  val nonExistentNino: Nino = TestAccountBuilder.nonExistentNino
  val badRequestNino: Nino = TestAccountBuilder.blankNino

  "CitizenDetailsService" should {
    "return something for valid NINO" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(nino)(new HeaderCarrier())
      whenReady(person) {p =>
        p.isRight shouldBe true
      }
    }

    "return None for bad NINO" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(nonExistentNino)(new HeaderCarrier())
      whenReady(person) {p =>
        p.isLeft shouldBe true
      }
    }

    "return None for bad request" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(badRequestNino)(new HeaderCarrier())
      whenReady(person) {p =>
        p.isLeft shouldBe true
      }
    }

    "return a Failed Future for a 5XX error" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(TestAccountBuilder.internalServerError)(new HeaderCarrier())
      whenReady(person.failed) { ex =>
        ex shouldBe a [Upstream5xxResponse]
      }
    }

    "return correct name and Date of Birth for NINO" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(nino)(new HeaderCarrier())
      whenReady(person) {p =>
        p.right.map(_.person.copy(nino = nino)) shouldBe Right(Citizen(nino, Some("AHMED"), Some("BRENNAN"), new LocalDate(1954, 3, 9)))
        p.right.get.person.getNameFormatted shouldBe Some("AHMED BRENNAN")
      }
    }

    "return formatted name of None if Citizen returns without a name" in {
      val person: Future[Either[CitizenDetailsError, CitizenDetailsResponse]] = MockCitizenDetailsService.retrievePerson(noNameNino)(new HeaderCarrier())
      whenReady(person) {p =>
        p shouldBe Right(CitizenDetailsResponse(Citizen(noNameNino, None, None, new LocalDate(1954, 3, 9)), Some(Address(Some("GREAT BRITAIN")))))
        p.right.get.person.getNameFormatted shouldBe None
      }
    }
  }
} 
Example 51
Source File: ParquetWriterItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import java.nio.file.Files

import org.apache.parquet.hadoop.ParquetFileWriter
import org.scalatest.{BeforeAndAfter, FreeSpec, Matchers}

import scala.util.Random

class ParquetWriterItSpec
  extends FreeSpec
    with Matchers
    with BeforeAndAfter {

  case class Record(i: Int, d: Double, s: String)
  object Record {
    def random(n: Int): Seq[Record] =
      (1 to n).map(_ =>
        Record(Random.nextInt(), Random.nextDouble(), Random.nextString(10)))
  }

  private val tempDir = com.google.common.io.Files.createTempDir().toPath.toAbsolutePath
  private val writePath = tempDir.resolve("file.parquet")

  // Generate the records once; each test performs its own write.
  private val records = Record.random(5000)

  private def readRecords: Seq[Record] = {
    val iter = ParquetReader.read[Record](writePath.toString)
    try iter.toSeq
    finally iter.close()
  }

  after { // Delete written files
    Files.deleteIfExists(writePath)
  }

  "Batch write should result in proper number of records in the file" in {
    ParquetWriter.writeAndClose(writePath.toString, records)
    readRecords should be(records)
  }

  "Multiple incremental writes produce same result as a single batch write" in {
    val w = ParquetWriter.writer[Record](writePath.toString)
    try records.grouped(5).foreach(w.write)
    finally w.close()
    readRecords shouldBe records
  }

  "Writing record by record works as well" in {
    val w = ParquetWriter.writer[Record](writePath.toString)
    try records.foreach(record => w.write(record))
    finally w.close()
    readRecords shouldBe records
  }

  "Incremental writes work with write mode OVERWRITE" in {
    val w = ParquetWriter.writer[Record](
      writePath.toString,
      ParquetWriter.Options(ParquetFileWriter.Mode.OVERWRITE))
    try records.grouped(5).foreach(w.write)
    finally w.close()
    readRecords shouldBe records
  }

  "Writing to closed writer throws an exception" in {
    val w = ParquetWriter.writer[Record](writePath.toString)
    w.close()
    an[IllegalStateException] should be thrownBy records
      .grouped(2)
      .foreach(w.write)
  }

  "Closing writer without writing anything to it throws no exception" in {
    val w = ParquetWriter.writer[Record](writePath.toString)
    noException should be thrownBy w.close()
  }

  "Closing writer twice throws no exception" in {
    val w = ParquetWriter.writer[Record](writePath.toString)
    noException should be thrownBy w.close()
    noException should be thrownBy w.close()
  }

} 
Example 52
package com.github.mjakubowski84.parquet4s

import com.github.mjakubowski84.parquet4s.Case.CaseDef
import com.github.mjakubowski84.parquet4s.CompatibilityParty._
import org.scalatest.{BeforeAndAfter, FreeSpec, Matchers}

class ParquetWriterAndParquetReaderCompatibilityItSpec extends
  FreeSpec
    with Matchers
    with BeforeAndAfter
    with TestUtils {

  before {
    clearTemp()
  }

  private def runTestCase(testCase: CaseDef): Unit = {
    testCase.description in {
      ParquetWriter.writeAndClose(tempPathString, testCase.data)(testCase.writerFactory)
      val parquetIterable = ParquetReader.read(tempPathString)(testCase.reader)
      try {
        parquetIterable should contain theSameElementsAs testCase.data
      } finally {
        parquetIterable.close()
      }
    }
  }

  "Spark should be able to read file saved by ParquetWriter if the file contains" - {
    CompatibilityTestCases.cases(Writer, Reader).foreach(runTestCase)
  }

} 
Example 53
Source File: ParquetWriterAndSparkCompatibilityItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import com.github.mjakubowski84.parquet4s.Case.CaseDef
import com.github.mjakubowski84.parquet4s.CompatibilityParty._
import org.scalatest.{BeforeAndAfter, FreeSpec, Matchers}

class ParquetWriterAndSparkCompatibilityItSpec extends
  FreeSpec
    with Matchers
    with BeforeAndAfter
    with SparkHelper {

  before {
    clearTemp()
  }

  private def runTestCase(testCase: CaseDef): Unit =
    testCase.description in {
      ParquetWriter.writeAndClose(tempPathString, testCase.data)(testCase.writerFactory)
      readFromTemp(testCase.typeTag) should contain theSameElementsAs testCase.data
    }

  "Spark should be able to read file saved by ParquetWriter if the file contains" - {
    CompatibilityTestCases.cases(Writer, Spark).foreach(runTestCase)
  }

} 
Example 54
Source File: TimeEncodingCompatibilityItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import java.util.TimeZone

import com.github.mjakubowski84.parquet4s.CompatibilityTestCases.TimePrimitives
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class TimeEncodingCompatibilityItSpec extends
  FlatSpec
    with Matchers
    with BeforeAndAfter
    with SparkHelper {

  private val localTimeZone = TimeZone.getDefault
  private val utcTimeZone = TimeZone.getTimeZone("UTC")
  private lazy val newYearEveEvening = java.sql.Timestamp.valueOf(java.time.LocalDateTime.of(2018, 12, 31, 23, 0, 0))
  private lazy val newYearMidnight = java.sql.Timestamp.valueOf(java.time.LocalDateTime.of(2019, 1, 1, 0, 0, 0))
  private lazy val newYear = java.sql.Date.valueOf(java.time.LocalDate.of(2019, 1, 1))

  override def beforeAll(): Unit = {
    super.beforeAll()
    TimeZone.setDefault(utcTimeZone)
  }

  before {
    clearTemp()
  }

  private def writeWithSpark(data: TimePrimitives): Unit = writeToTemp(Seq(data))
  private def readWithSpark: TimePrimitives = readFromTemp[TimePrimitives].head
  private def writeWithParquet4S(data: TimePrimitives, timeZone: TimeZone): Unit =
    ParquetWriter.writeAndClose(tempPathString, Seq(data), ParquetWriter.Options(timeZone = timeZone))
  private def readWithParquet4S(timeZone: TimeZone): TimePrimitives = {
    val parquetIterable = ParquetReader.read[TimePrimitives](tempPathString, ParquetReader.Options(timeZone = timeZone))
    try {
      parquetIterable.head
    } finally {
      parquetIterable.close()
    }
  }

  "Spark" should "read properly time written with time zone one hour east" in {
    val input = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    writeWithParquet4S(input, TimeZone.getTimeZone("GMT+1"))
    readWithSpark should be(expectedOutput)
  }

  it should "read properly written with time zone one hour west" in {
    val input = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    writeWithParquet4S(input, TimeZone.getTimeZone("GMT-1"))
    readWithSpark should be(expectedOutput)
  }

  "Parquet4S" should "read properly time written with time zone one hour east" in {
    val input = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    writeWithSpark(input)
    readWithParquet4S(TimeZone.getTimeZone("GMT-1")) should be(expectedOutput)
  }

  it should "read properly time written with time zone one hour west" in {
    val input = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    writeWithSpark(input)
    readWithParquet4S(TimeZone.getTimeZone("GMT+1")) should be(expectedOutput)
  }

  override def afterAll(): Unit = {
    TimeZone.setDefault(localTimeZone)
    super.afterAll()
  }

} 
Example 55
Source File: SparkAndParquetReaderCompatibilityItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import com.github.mjakubowski84.parquet4s.CompatibilityParty._
import org.scalatest.{BeforeAndAfter, FreeSpec, Matchers}

class SparkAndParquetReaderCompatibilityItSpec extends
  FreeSpec
    with Matchers
    with BeforeAndAfter
    with SparkHelper {

  before {
    clearTemp()
  }

  private def runTestCase(testCase: Case.CaseDef): Unit =
    testCase.description in {
      writeToTemp(testCase.data)(testCase.typeTag)
      val parquetIterable = ParquetReader.read(tempPathString)(testCase.reader)
      try {
        parquetIterable should contain theSameElementsAs testCase.data
      } finally {
        parquetIterable.close()
      }
    }

  "ParquetReader should be able to read file saved by Spark if the file contains" - {
    CompatibilityTestCases.cases(Spark, Reader).foreach(runTestCase)
  }

} 
Example 56
Source File: ParquetWriterAndSparkCompatibilityItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import com.github.mjakubowski84.parquet4s.Case.CaseDef
import com.github.mjakubowski84.parquet4s.CompatibilityParty._
import org.scalatest.{BeforeAndAfter, FreeSpec, Matchers}

class ParquetWriterAndSparkCompatibilityItSpec extends
  FreeSpec
    with Matchers
    with BeforeAndAfter
    with SparkHelper {

  before {
    clearTemp()
  }

  private def runTestCase(testCase: CaseDef): Unit =
    testCase.description in {
      ParquetWriter.writeAndClose(tempPathString, testCase.data)(testCase.writerFactory)
      readFromTemp(testCase.typeTag) should contain theSameElementsAs testCase.data
    }

  "Spark should be able to read file saved by ParquetWriter if the file contains" - {
    CompatibilityTestCases.cases(Writer, Spark).foreach(runTestCase)
  }

} 
Example 59
Source File: SonarLogTester.scala    From sonar-scala   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package org.sonar.api.utils.log

import scala.jdk.CollectionConverters._

import org.scalatest.{BeforeAndAfter, Suite}

trait SonarLogTester extends BeforeAndAfter { this: Suite =>
  before {
    LogInterceptors.set(new ListInterceptor())
    Loggers.getFactory.setLevel(LoggerLevel.DEBUG)
  }

  after {
    LogInterceptors.set(NullInterceptor.NULL_INSTANCE)
    Loggers.getFactory.setLevel(LoggerLevel.DEBUG)
  }

  def logs: Seq[String] =
    LogInterceptors.get().asInstanceOf[ListInterceptor].logs.asScala.toSeq

  def getLogs: Seq[LogAndArguments] =
    LogInterceptors.get().asInstanceOf[ListInterceptor].getLogs().asScala.toSeq

  def logsFor(level: LoggerLevel): Seq[String] =
    LogInterceptors.get().asInstanceOf[ListInterceptor].logs(level).asScala.toSeq

  def getLogsFor(level: LoggerLevel): Seq[LogAndArguments] =
    LogInterceptors.get().asInstanceOf[ListInterceptor].getLogs(level).asScala.toSeq
} 
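SonarLogTester shows how BeforeAndAfter can live in a reusable trait with a `this: Suite =>` self-type, so any suite that mixes it in gets the setup and teardown for free. A minimal sketch under the same assumptions (ScalaTest 3.x, invented names):

import org.scalatest.{BeforeAndAfter, Suite}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import scala.collection.mutable.ListBuffer

// Hypothetical in-memory "log interceptor" fixture, reusable across suites.
// Note: BeforeAndAfter lets before/after be registered only once per suite
// instance, so the mixing suite should not register its own blocks as well.
trait InMemoryLogTester extends BeforeAndAfter { this: Suite =>
  private val buffer = ListBuffer.empty[String]

  before { buffer.clear() } // fresh log for every test
  after { buffer.clear() }  // and nothing leaks to the next test

  def log(msg: String): Unit = buffer += msg
  def logs: Seq[String] = buffer.toList
}

class UsesLogTesterSpec extends AnyFlatSpec with Matchers with InMemoryLogTester {
  "A suite mixing in the tester" should "start every test with an empty log" in {
    logs shouldBe empty
    log("hello")
    logs should contain("hello")
  }
}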
Example 60
Source File: PointRDDExtensionsSpec.scala    From reactiveinflux-spark   with Apache License 2.0 5 votes vote down vote up
package com.pygmalios.reactiveinflux.extensions

import com.holdenkarau.spark.testing.SharedSparkContext
import com.pygmalios.reactiveinflux.Point.Measurement
import com.pygmalios.reactiveinflux._
import com.pygmalios.reactiveinflux.extensions.PointRDDExtensionsSpec._
import com.pygmalios.reactiveinflux.spark._
import com.pygmalios.reactiveinflux.spark.extensions.PointRDDExtensions
import org.joda.time.{DateTime, DateTimeZone}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec}

import scala.concurrent.duration._

@RunWith(classOf[JUnitRunner])
class PointRDDExtensionsSpec extends FlatSpec with SharedSparkContext
  with BeforeAndAfter {

  before {
    withInflux(_.create())
  }

  after {
    withInflux(_.drop())
  }

  behavior of "saveToInflux"

  it should "write single point to Influx" in {
    val points = List(point1)
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 1)
    assert(PointRDDExtensions.totalPointCount == 1)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
      .result
      .singleSeries)

    assert(result.rows.size == 1)

    val row = result.rows.head
    assert(row.time == point1.time)
    assert(row.values.size == 5)
  }

  it should "write 1000 points to Influx" in {
    val points = (1 to 1000).map { i =>
      Point(
        time = point1.time.plusMinutes(i),
        measurement = point1.measurement,
        tags = point1.tags,
        fields = point1.fields
      )
    }
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 8)
    assert(PointRDDExtensions.totalPointCount == 1000)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
        .result
        .singleSeries)

    assert(result.rows.size == 1000)
  }
}

object PointRDDExtensionsSpec {
  implicit val params: ReactiveInfluxDbName = ReactiveInfluxDbName("test")
  implicit val awaitAtMost: Duration = 1.second

  val measurement1: Measurement = "measurement1"
  val point1 = Point(
    time        = new DateTime(1983, 1, 10, 7, 43, 10, 3, DateTimeZone.UTC),
    measurement = measurement1,
    tags        = Map("tagKey1" -> "tagValue1", "tagKey2" -> "tagValue2"),
    fields      = Map("fieldKey1" -> StringFieldValue("fieldValue1"), "fieldKey2" -> BigDecimalFieldValue(10.7)))
} 
Example 61
Source File: TokenizerSuite.scala    From spark-nkp   with Apache License 2.0 5 votes vote down vote up
package com.github.uosdmlab.nkp

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, IDF}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfter, FunSuite}


class TokenizerSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {

  private var tokenizer: Tokenizer = _

  private val spark: SparkSession =
    SparkSession.builder()
      .master("local[2]")
      .appName("Tokenizer Suite")
      .getOrCreate

  spark.sparkContext.setLogLevel("WARN")

  import spark.implicits._

  override protected def afterAll(): Unit = {
    try {
      spark.stop
    } finally {
      super.afterAll()
    }
  }

  before {
    tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
  }

  private val df = spark.createDataset(
    Seq(
      "아버지가방에들어가신다.",
      "사랑해요 제플린!",
      "스파크는 재밌어",
      "나는야 데이터과학자",
      "데이터야~ 놀자~"
    )
  ).toDF("text")

  test("Default parameters") {
    assert(tokenizer.getFilter sameElements Array.empty[String])
  }

  test("Basic operation") {
    val words = tokenizer.transform(df)

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(tokenizer.getOutputCol))
  }

  test("POS filter") {
    val nvTokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("nvWords")
      .setFilter("N", "V")

    val words = tokenizer.transform(df).join(nvTokenizer.transform(df), "text")

    assert(df.count == words.count)
    assert(words.schema.fieldNames.contains(nvTokenizer.getOutputCol))
    assert(words.where(s"SIZE(${tokenizer.getOutputCol}) < SIZE(${nvTokenizer.getOutputCol})").count == 0)
  }

  test("TF-IDF pipeline") {
    tokenizer.setFilter("N")

    val cntVec = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("tf")

    val idf = new IDF()
      .setInputCol("tf")
      .setOutputCol("tfidf")

    val pipe = new Pipeline()
      .setStages(Array(tokenizer, cntVec, idf))

    val pipeModel = pipe.fit(df)

    val result = pipeModel.transform(df)

    assert(result.count == df.count)

    val fields = result.schema.fieldNames
    assert(fields.contains(tokenizer.getOutputCol))
    assert(fields.contains(cntVec.getOutputCol))
    assert(fields.contains(idf.getOutputCol))

    result.show
  }
} 
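TokenizerSuite rebuilds its mutable tokenizer in before and releases the shared SparkSession in afterAll. A sketch of that split between per-test and per-suite responsibilities, assuming ScalaTest 3.x and invented names:

import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.funsuite.AnyFunSuite

// Hypothetical suite: a cheap fixture is rebuilt before every test, while a
// shared counter standing in for an expensive resource is cleaned in afterAll.
class RebuiltFixtureSuite extends AnyFunSuite with BeforeAndAfterAll with BeforeAndAfter {

  private var builder: StringBuilder = _    // per-test fixture
  private val shared = new AtomicInteger(0) // suite-wide "resource"

  before {
    builder = new StringBuilder("seed")     // every test gets a fresh instance
  }

  override protected def afterAll(): Unit = {
    try shared.set(0)                       // release the shared resource once
    finally super.afterAll()
  }

  test("the fixture starts from the seed") {
    assert(builder.toString == "seed")
    builder.append("-changed")
    shared.incrementAndGet()
  }

  test("mutations from the previous test are not visible") {
    assert(builder.toString == "seed")
  }
}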
Example 62
Source File: MetadataTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.operator.{MetadataTransformUtils, VectorCartesian}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class MetadataTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_vector_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", "123")
      .load("data/a9a/a9a_123d_train_trans.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("cartesian_features")

    val assembler = new VectorAssembler()
      .setInputCols(Array("features", "cartesian_features"))
      .setOutputCol("assemble_features")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, assembler))

    val featureModel = pipeline.fit(data)
    val crossDF = featureModel.transform(data)

    crossDF.schema.fields.foreach { field =>
      println("name: " + field.name)
      println("metadata: " + field.metadata.toString())
    }
  }

  test("test_three_order_cartesian") {
    val data = spark.read.format("libsvm")
      .option("numFeatures", 8)
      .load("data/abalone/abalone_8d_train.libsvm")
      .persist()

    val cartesian = new VectorCartesian()
      .setInputCols(Array("features", "features"))
      .setOutputCol("f_f")

    val cartesian2 = new VectorCartesian()
      .setInputCols(Array("features", "f_f"))
      .setOutputCol("f_f_f")

    val pipeline = new Pipeline()
      .setStages(Array(cartesian, cartesian2))

    val crossDF = pipeline.fit(data).transform(data).persist()

    // first cartesian, the number of dimensions is 64
    println("first cartesian dimension = " + crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))

    println()

    // second cartesian, the number of dimensions is 512
    println("second cartesian dimension = " + crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).length)
    println(crossDF.select("f_f_f").schema.fields.last.metadata.getStringArray(MetadataTransformUtils.DERIVATION).mkString(","))
  }
} 
Example 63
Source File: PipelineTest.scala    From automl   with Apache License 2.0 5 votes vote down vote up
package com.tencent.angel.spark.automl

import com.tencent.angel.spark.automl.feature.preprocess.{HashingTFWrapper, IDFWrapper, TokenizerWrapper}
import com.tencent.angel.spark.automl.feature.{PipelineBuilder, PipelineWrapper, TransformerWrapper}
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FunSuite}

class PipelineTest extends FunSuite with BeforeAndAfter {

  var spark: SparkSession = _

  before {
    spark = SparkSession.builder().master("local").getOrCreate()
  }

  after {
    spark.close()
  }

  test("test_tfidf") {
    val sentenceData = spark.createDataFrame(Seq(
      (0.0, "Hi I heard about Spark"),
      (0.0, "I wish Java could use case classes"),
      (1.0, "Logistic regression models are neat")
    )).toDF("label", "sentence")

    val pipelineWrapper = new PipelineWrapper()

    val transformers = Array[TransformerWrapper](
      new TokenizerWrapper(),
      new HashingTFWrapper(20),
      new IDFWrapper()
    )

    val stages = PipelineBuilder.build(transformers)

    transformers.foreach { transformer =>
      val inputCols = transformer.getInputCols
      val outputCols = transformer.getOutputCols
      inputCols.foreach(print)
      print("    ")
      outputCols.foreach(print)
      println()
    }

    pipelineWrapper.setStages(stages)

    val model = pipelineWrapper.fit(sentenceData)

    val outputDF = model.transform(sentenceData)
    outputDF.select("outIDF").show()
    outputDF.select("outIDF").foreach { row =>
      println(row.get(0).getClass.getSimpleName)
      val arr = row.get(0)
      println(arr.toString)
    }
    outputDF.rdd.map(row => row.toString()).repartition(1)
      .saveAsTextFile("tmp/output/tfidf")
  }
} 
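MetadataTest and PipelineTest acquire a fresh SparkSession in before and close it in after, so each test pays for its own session. The same acquire/release shape with a plain temporary directory, as a sketch (ScalaTest 3.x and Scala 2.13 collection converters assumed):

import java.nio.file.{Files, Path}

import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite

import scala.jdk.CollectionConverters._

// Hypothetical per-test resource: acquired in before, released in after,
// mirroring the "build the SparkSession per test" pattern above.
class PerTestTempDirSuite extends AnyFunSuite with BeforeAndAfter {

  private var tempDir: Path = _

  before {
    tempDir = Files.createTempDirectory("per-test-")
  }

  after {
    // flat directory in this sketch, so delete children first, then the directory
    Files.list(tempDir).iterator().asScala.foreach(Files.deleteIfExists(_))
    Files.deleteIfExists(tempDir)
  }

  test("each test sees its own empty directory") {
    assert(Files.list(tempDir).count() == 0)
    Files.createFile(tempDir.resolve("data.txt"))
  }

  test("files written by the previous test are gone") {
    assert(Files.list(tempDir).count() == 0)
  }
}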
Example 64
Source File: CouchbasePluginSpec.scala    From akka-persistence-couchbase   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.couchbase.support

import akka.actor.ActorSystem
import akka.persistence.couchbase.{CouchbaseExtension, LoggingConfig}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Suite}

import scala.concurrent.Await
import scala.concurrent.duration._

object CouchbasePluginSpec {

  val config = ConfigFactory.parseString(
    """
      |akka {
      |  persistence {
      |    journal {
      |      plugin = "couchbase-journal"
      |    }
      |
      |    snapshot-store {
      |      plugin =  "couchbase-snapshot-store"
      |    }
      |
      |    journal-plugin-fallback {
      |      replay-filter {
      |        mode = warn
      |      }
      |    }
      |  }
      |
      |  test.single-expect-default = 10s
      |  loglevel = WARNING
      |  log-dead-letters = 0
      |  log-dead-letters-during-shutdown = off
      |  test.single-expect-default = 10s
      |}
      |
      |couchbase-replay {
      |
      |  batchSize = "4"
      |}
    """.stripMargin)
}

trait CouchbasePluginSpec
  extends Suite
    with BeforeAndAfter
    with BeforeAndAfterAll {

  System.setProperty("java.util.logging.config.class", classOf[LoggingConfig].getName)

  def system: ActorSystem

  def couchbase = CouchbaseExtension(system)

  before {
    assert(couchbase.journalBucket.bucketManager.flush())
    assert(couchbase.snapshotStoreBucket.bucketManager.flush())
  }

  override protected def afterAll(): Unit = {
    Await.result(system.terminate(), 10.seconds)
    super.afterAll()
  }
} 
Example 65
Source File: JdbcExampleSuite.scala    From gihyo-spark-book-example   with Apache License 2.0 5 votes vote down vote up
package jp.gihyo.spark.ch05

import java.sql.DriverManager
import java.util.Properties

import jp.gihyo.spark.{SparkFunSuite, TestSparkContext}
import org.scalatest.BeforeAndAfter

class JdbcExampleSuite extends SparkFunSuite with TestSparkContext with BeforeAndAfter {

  val user = "testUser"
  val pass = "testPass"
  val url = "jdbc:h2:mem:testdb;MODE=MySQL"
  val urlWithUserAndPass = s"jdbc:h2:mem:testdb;user=${user}};password=${pass}"
  var conn: java.sql.Connection = null

  override def beforeAll(): Unit = {
    super.beforeAll()

    Class.forName("org.h2.Driver")
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("CREATE SCHEMA gihyo_spark").executeUpdate()
    conn.prepareStatement(
      """
        |CREATE TABLE gihyo_spark.person (
        |  id INTEGER NOT NULL,
        |  name TEXT(32) NOT NULL,
        |  age INTEGER NOT NULL
        |)
      """.stripMargin.replaceAll("\n", " ")
    ).executeUpdate()
    conn.prepareStatement("INSERT INTO gihyo_spark.person VALUES (1, 'fred', 23)").executeUpdate()
    conn.prepareStatement("INSERT INTO gihyo_spark.person VALUES (2, 'mary', 22)").executeUpdate()
    conn.prepareStatement("INSERT INTO gihyo_spark.person VALUES (3, 'bob', 23)").executeUpdate()
    conn.prepareStatement("INSERT INTO gihyo_spark.person VALUES (4, 'ann', 22)").executeUpdate()
    conn.commit()
  }

  override def afterAll(): Unit = {
    super.afterAll()
    conn.close()
  }

  test("run") {
    JdbcExample.run(sc, sqlContext, url, user, pass)
  }
} 
Example 66
Source File: LoggerHandlerWithIdSpec.scala    From rokku   with Apache License 2.0 5 votes vote down vote up
package com.ing.wbaa.rokku.proxy.provider

import ch.qos.logback.classic.{ Level, Logger }
import com.ing.wbaa.rokku.proxy.data.RequestId
import com.ing.wbaa.rokku.proxy.handler.LoggerHandlerWithId
import org.scalatest.BeforeAndAfter
import org.scalatest.diagrams.Diagrams
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

class LoggerHandlerWithIdSpec extends AnyWordSpec with Matchers with Diagrams with BeforeAndAfter {

  private val logger = new LoggerHandlerWithId
  implicit val id: RequestId = RequestId("1")

  private val logRoot: Logger = org.slf4j.LoggerFactory.getLogger(org.slf4j.Logger.ROOT_LOGGER_NAME).asInstanceOf[Logger]
  private val currentLogLevel = logRoot.getLevel
  private val val1 = 1
  private val val2 = 2
  before(logRoot.setLevel(Level.DEBUG))
  after(logRoot.setLevel(currentLogLevel))

  "Logger" should {
    "work" in {

      noException should be thrownBy {

        logger.debug("test debug {}", val1)
        logger.debug("test debug {} {}", val1, val2)
        logger.debug("test debug {}", new RuntimeException("RTE").getMessage)

        logger.info("test info {}", val1)
        logger.info("test info {} {}", val1, val2)
        logger.info("test info {}", new RuntimeException("RTE").getMessage)

        logger.warn("test warn {}", val1)
        logger.warn("test warn {} {}", val1, val2)
        logger.warn("test warn {}", new RuntimeException("RTE").getMessage)

        logger.error("test error {}", val1)
        logger.error("test error {} {}", val1, val2)
        logger.error("test error {}", new RuntimeException("RTE").getMessage)
      }
    }
  }
} 
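LoggerHandlerWithIdSpec registers its hooks with the single-expression form, before(...) and after(...), instead of a block; both take a by-name argument. A small sketch of that form, assuming ScalaTest 3.x and an invented mutable setting:

import java.util.concurrent.atomic.AtomicReference

import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

// Hypothetical spec: a "log level" is raised for every test and put back
// afterwards, using the expression form of before/after.
class LogLevelSpec extends AnyWordSpec with Matchers with BeforeAndAfter {

  private val level = new AtomicReference[String]("INFO")
  private val savedLevel = level.get()

  before(level.set("DEBUG"))   // by-name expression, no braces needed
  after(level.set(savedLevel)) // restore the original level after each test

  "A test" should {
    "run with the raised level" in {
      level.get() shouldBe "DEBUG"
    }
  }
}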
Example 67
Source File: DarwinConcurrentHashMapSpec.scala    From darwin   with Apache License 2.0 5 votes vote down vote up
package it.agilelab.darwin.common

import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

class DarwinConcurrentHashMapSpec extends AnyFlatSpec with Matchers with BeforeAndAfter {
  private val realJavaVersion = System.getProperty("java.version")

  after {
    System.setProperty("java.version", realJavaVersion)
  }

  def test(): Unit = {
    val threadNumber = 1000
    val map = DarwinConcurrentHashMap.empty[String, Int]
    var counter = 0
    val threadCounter = new AtomicInteger(0)
    val runnables = for (_ <- 1 to threadNumber) yield {
      new Runnable {
        override def run(): Unit = {
          threadCounter.incrementAndGet()
          val res = map.getOrElseUpdate("A", {
            counter += 1
            counter
          })
          res should be(1)
        }
      }
    }
    val threads = for (r <- runnables) yield {
      val t = new Thread(r)
      t
    }
    for (t <- threads) {
      t.start()
    }
    for (t <- threads) {
      t.join()
    }
    threadCounter.get() should be(threadNumber)
  }


  it should "not evaluate the value if the key is present JAVA 8" in {
    test()
  }

  it should "not evaluate the value if the key is present JAVA 7" in {
    if (JavaVersion.parseJavaVersion(realJavaVersion) >= 8) {
      System.setProperty("java.version", "1.7")
      test()
    } else {
      assert(true)
    }
  }

} 
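DarwinConcurrentHashMapSpec captures the real java.version property up front and restores it in after, so a test that fakes the JVM version cannot affect later suites. A sketch of the same save-and-restore discipline for an arbitrary property, assuming ScalaTest 3.x; the property key is invented:

import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

// Hypothetical spec: whatever a test does to the property, `after` puts the
// original value (or absence) back.
class SystemPropertyRestoreSpec extends AnyFlatSpec with Matchers with BeforeAndAfter {

  private val propertyKey = "example.feature.flag" // invented key, illustration only
  private val originalValue = Option(System.getProperty(propertyKey))

  after {
    originalValue match {
      case Some(value) => System.setProperty(propertyKey, value)
      case None        => System.clearProperty(propertyKey)
    }
  }

  "A test overriding the property" should "see its own value" in {
    System.setProperty(propertyKey, "on")
    System.getProperty(propertyKey) shouldBe "on"
  }

  it should "not leak the override into the next test" in {
    Option(System.getProperty(propertyKey)) shouldBe originalValue
  }
}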
Example 68
Source File: OptimizeHiveMetadataOnlyQuerySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive

import org.scalatest.BeforeAndAfter

import org.apache.spark.metrics.source.HiveCatalogMetrics
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleton
    with BeforeAndAfter with SQLTestUtils {

  import spark.implicits._

  override def beforeAll(): Unit = {
    super.beforeAll()
    sql("CREATE TABLE metadata_only (id bigint, data string) PARTITIONED BY (part int)")
    (0 to 10).foreach(p => sql(s"ALTER TABLE metadata_only ADD PARTITION (part=$p)"))
  }

  override protected def afterAll(): Unit = {
    try {
      sql("DROP TABLE IF EXISTS metadata_only")
    } finally {
      super.afterAll()
    }
  }

  test("SPARK-23877: validate metadata-only query pushes filters to metastore") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the number of matching partitions
      assert(sql("SELECT DISTINCT part FROM metadata_only WHERE part < 5").collect().length === 5)

      // verify that the partition predicate was pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount === 5)
    }
  }

  test("SPARK-23877: filter on projected expression") {
    withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") {
      val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount

      // verify the matching partitions
      val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr,
        Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]),
          spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child)))
          .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType))))

      checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x"))

      // verify that the partition predicate was not pushed down to the metastore
      assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11)
    }
  }
} 
Example 69
Source File: ResolveInlineTablesSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.{LongType, NullType, TimestampType}


class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables(conf).validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables(conf).validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables(conf).validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(conf)(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables(conf).convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("convert TimeZoneAwareExpression") {
    val table = UnresolvedInlineTable(Seq("c1"),
      Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType))))
    val withTimeZone = ResolveTimeZone(conf).apply(table)
    val LocalRelation(output, data, _) = ResolveInlineTables(conf).apply(withTimeZone)
    val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType)
      .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long]
    assert(output.map(_.dataType) == Seq(TimestampType))
    assert(data.size == 1)
    assert(data.head.getLong(0) == correct)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables(conf).convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables(conf).convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 70
Source File: RowDataSourceStrategySuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE OR REPLACE TEMPORARY VIEW inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
       """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 71
Source File: MemorySinkV2Suite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.streaming.sources._
import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.types.StructType

class MemorySinkV2Suite extends StreamTest with BeforeAndAfter {
  test("data writer") {
    val partition = 1234
    val writer = new MemoryDataWriter(
      partition, OutputMode.Append(), new StructType().add("i", "int"))
    writer.write(InternalRow(1))
    writer.write(InternalRow(2))
    writer.write(InternalRow(44))
    val msg = writer.commit()
    assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44))
    assert(msg.partition == partition)

    // Buffer should be cleared, so repeated commits should give empty.
    assert(writer.commit().data.isEmpty)
  }

  test("streaming writer") {
    val sink = new MemorySinkV2
    val writeSupport = new MemoryStreamWriter(
      sink, OutputMode.Append(), new StructType().add("i", "int"))
    writeSupport.commit(0,
      Array(
        MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))),
        MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))),
        MemoryWriterCommitMessage(2, Seq(Row(6), Row(7)))
      ))
    assert(sink.latestBatchId.contains(0))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7))
    writeSupport.commit(19,
      Array(
        MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))),
        MemoryWriterCommitMessage(0, Seq(Row(33)))
      ))
    assert(sink.latestBatchId.contains(19))
    assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33))

    assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33))
  }
} 
Example 72
Source File: MicroBatchExecutionSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.streaming

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.functions.{count, window}
import org.apache.spark.sql.streaming.StreamTest

class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter {

  import testImplicits._

  after {
    sqlContext.streams.active.foreach(_.stop())
  }

  test("SPARK-24156: do not plan a no-data batch again after it has already been planned") {
    val inputData = MemoryStream[Int]
    val df = inputData.toDF()
      .withColumn("eventTime", $"value".cast("timestamp"))
      .withWatermark("eventTime", "10 seconds")
      .groupBy(window($"eventTime", "5 seconds") as 'window)
      .agg(count("*") as 'count)
      .select($"window".getField("start").cast("long").as[Long], $"count".as[Long])

    testStream(df)(
      AddData(inputData, 10, 11, 12, 13, 14, 15), // Set watermark to 5
      CheckAnswer(),
      AddData(inputData, 25), // Set watermark to 15 to make MicroBatchExecution run no-data batch
      CheckAnswer((10, 5)),   // Last batch should be a no-data batch
      StopStream,
      Execute { q =>
        // Delete the last committed batch from the commit log to signify that the last batch
        // (a no-data batch) never completed
        val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L)
        q.commitLog.purgeAfter(commit - 1)
      },
      // Add data before start so that MicroBatchExecution can plan a batch. It should not,
      // it should first re-run the incomplete no-data batch and then run a new batch to process
      // new data.
      AddData(inputData, 30),
      StartStream(),
      CheckNewAnswer((15, 1)),   // This should not throw the error reported in SPARK-24156
      StopStream,
      Execute { q =>
        // Delete the entire commit log
        val commit = q.commitLog.getLatest().map(_._1).getOrElse(-1L)
        q.commitLog.purge(commit + 1)
      },
      AddData(inputData, 50),
      StartStream(),
      CheckNewAnswer((25, 1), (30, 1))   // This should not throw the error reported in SPARK-24156
    )
  }
} 
Example 73
Source File: AggregateHashMapSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkConf

class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.sql.codegen.fallback", "false")
    .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false")

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.sql.codegen.fallback", "false")
    .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true")

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapWithVectorizedMapSuite
  extends DataFrameAggregateSuite
  with BeforeAndAfter {

  override protected def sparkConf: SparkConf = super.sparkConf
    .set("spark.sql.codegen.fallback", "false")
    .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true")
    .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")

  // adding some checking after each test is run, assuring that the configs are not changed
  // in test code
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
      "configuration parameter changed in test body")
  }
} 
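The aggregate hash map suites use after not for cleanup but as a post-condition: every test is followed by assertions that the Spark configuration was left untouched. A minimal sketch of `after` as an invariant check, assuming ScalaTest 3.x and an invented settings map:

import org.scalatest.BeforeAndAfter
import org.scalatest.funsuite.AnyFunSuite

import scala.collection.mutable

// Hypothetical suite: the `after` block fails the run if a test body mutated
// shared configuration it was supposed to leave alone.
class ConfigInvariantSuite extends AnyFunSuite with BeforeAndAfter {

  private val settings = mutable.Map("codegen.fallback" -> "false")

  after {
    assert(settings("codegen.fallback") == "false",
      "configuration parameter changed in test body")
  }

  test("a well-behaved test reads but does not modify the configuration") {
    assert(settings("codegen.fallback") == "false")
  }
}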
Example 74
Source File: AnalyzeHistoryRepositoryTest.scala    From CodeAnalyzerTutorial   with Apache License 2.0 5 votes vote down vote up
package tutor.repo

import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
import tutor.CodebaseInfo

import scala.concurrent.Await
import scala.concurrent.duration._

class AnalyzeHistoryRepositoryTest extends FunSpec with Matchers with Schemas with H2DB
  with AnalyzeHistoryRepository with BeforeAndAfter {

  before {
    Await.result(setupDB(), 5 seconds)
  }

  after {
    Await.result(dropDB(), 5 seconds)
  }

  describe("AnalyzeHistoryRecorder"){
//    it("should create tables in h2"){
//      AnalyzeHistoryRecorder.setupDB()
//    }
    it("can insert analyzeHistory"){
      val c = Await.result(
        record("some path",CodebaseInfo(1, Map("java" -> 1), 1, 10,None,Nil))
        , 10 seconds)
      c shouldBe 1
    }
  }
} 
Example 75
Source File: OozieWFTest.scala    From schedoscope   with Apache License 2.0 5 votes vote down vote up
package org.schedoscope.dsl.transformations

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import org.schedoscope.dsl.Parameter.p
import org.schedoscope.dsl.transformations.OozieTransformation.configurationFromResource
import org.schedoscope.dsl.{Parameter, View}

case class Productfeed(ecShopCode: Parameter[String],
                       year: Parameter[String],
                       month: Parameter[String],
                       day: Parameter[String]) extends View {

  val artNumber = fieldOf[String]
  val artName = fieldOf[String]
  val imageUrl = fieldOf[String]
  val category = fieldOf[String]
  val gender = fieldOf[String]
  val brand = fieldOf[String]

  transformVia(() =>
    OozieTransformation(
      "products_processed-oozie.bundle",
      "workflow-processed_productfeed",
      s"/hdp/${env}/applications/eci/scripts/oozie/products_processed-oozie.bundle/workflow-processed_productfeed/")
      .configureWith(
        configurationFromResource("ooziewftest.properties") ++
          Map(
            "env" -> env,
            "envDir" -> env,
            "env_dir" -> env,
            "success_flag" -> "_SUCCESS",
            "app" -> "eci",
            "output_folder" -> tablePath,
            "wtEcnr" -> ecShopCode.v.get,
            "day" -> day.v.get,
            "month" -> month.v.get,
            "year" -> year.v.get)))
}

class OozieWFTest extends FlatSpec with BeforeAndAfter with Matchers {

  "OozieWF" should "load configuration correctly" in {
    val view = Productfeed(p("ec0101"), p("2014"), p("10"), p("11"))

    val t = view.transformation().asInstanceOf[OozieTransformation]

    t.workflowAppPath shouldEqual "/hdp/dev/applications/eci/scripts/oozie/products_processed-oozie.bundle/workflow-processed_productfeed/"

    val expectedConfiguration = Map(
      "oozie.bundle.application.path" -> "${nameNode}${bundlePath}",
      "oozie.use.system.libpath" -> true,
      "preprocMrIndir" -> "${stageDir}/preproc-in/${wtEcnr}/",
      "datahubBaseDir" -> "${nameNode}/hdp/${envDir}/applications/eci/datahub",
      "output_folder" -> "/hdp/dev/org/schedoscope/dsl/transformations/productfeed",
      "preprocMrOutdir" -> "${stageDir}/preproc-out/${wtEcnr}/",
      "preprocOrigDir" -> "${incomingDir}/",
      "year" -> "2014",
      "preprocTmpDir" -> "${stageDir}/preprocessed/webtrends_log_${wtEcnr}",
      "success_flag" -> "_SUCCESS",
      "sessionDir" -> "${datahubBaseDir}/sessions/${wtEcnr}",
      "envDir" -> "dev",
      "wtEcnr" -> "ec0101",
      "env_dir" -> "dev",
      "incomingDir" -> "${stageDir}/${wtEcnr}",
      "app" -> "eci",
      "preprocOutfilePrefix" -> "${loop_datum}-preprocessed",
      "oozieLauncherQueue" -> "root.webtrends-oozie-launcher",
      "throttle" -> "200",
      "processedDir" -> "${nameNode}/hdp/${envDir}/applications/eci/processed",
      "timeout" -> "10080",
      "stageDir" -> "${nameNode}/hdp/${envDir}/applications/eci/stage",
      "day" -> "11",
      "env" -> "dev",
      "month" -> "10")

    t.configuration.foreach { case (key, value) => expectedConfiguration(key).toString shouldEqual value.toString }
  }
} 
Example 76
Source File: ClientTestBase.scala    From endpoints4s   with MIT License 5 votes vote down vote up
package endpoints4s.algebra.client

import java.net.ServerSocket

import com.github.tomakehurst.wiremock.WireMockServer
import com.github.tomakehurst.wiremock.core.WireMockConfiguration.options
import endpoints4s.algebra
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import scala.concurrent.Future
import scala.concurrent.duration._
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

trait ClientTestBase[T <: algebra.Endpoints]
    extends AnyWordSpec
    with Matchers
    with ScalaFutures
    with BeforeAndAfterAll
    with BeforeAndAfter {

  override implicit def patienceConfig: PatienceConfig =
    PatienceConfig(15.seconds, 10.millisecond)

  val wiremockPort = findOpenPort
  val wireMockServer = new WireMockServer(options().port(wiremockPort))

  override def beforeAll(): Unit = wireMockServer.start()

  override def afterAll(): Unit = wireMockServer.stop()

  before {
    wireMockServer.resetAll()
  }

  def findOpenPort: Int = {
    val socket = new ServerSocket(0)
    try socket.getLocalPort
    finally if (socket != null) socket.close()
  }

  val client: T

  def call[Req, Resp](
      endpoint: client.Endpoint[Req, Resp],
      args: Req
  ): Future[Resp]

  def encodeUrl[A](url: client.Url[A])(a: A): String

} 
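ClientTestBase starts the WireMock server once in beforeAll, stops it in afterAll, and only resets its recorded state in before, keeping the expensive work per suite and the cheap work per test. A sketch with an invented in-memory stub instead of a real server (ScalaTest 3.x assumed):

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.collection.mutable.ListBuffer

// Hypothetical stand-in for an external stub server such as WireMock.
final class StubServer {
  val recordedRequests: ListBuffer[String] = ListBuffer.empty
  @volatile var running: Boolean = false
  def start(): Unit = running = true
  def stop(): Unit = running = false
  def resetAll(): Unit = recordedRequests.clear()
}

class StubbedClientSpec extends AnyWordSpec with Matchers with BeforeAndAfterAll with BeforeAndAfter {

  private val server = new StubServer

  override def beforeAll(): Unit = server.start() // expensive: once per suite
  override def afterAll(): Unit = server.stop()

  before {
    server.resetAll()                             // cheap: once per test
  }

  "A client test" should {
    "not see requests recorded by earlier tests" in {
      server.recordedRequests shouldBe empty
      server.recordedRequests += "GET /users"
      server.running shouldBe true
    }
  }
}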
Example 77
Source File: ServerTestBase.scala    From endpoints4s   with MIT License 5 votes vote down vote up
package endpoints4s.algebra.server

import java.nio.charset.StandardCharsets

import akka.actor.ActorSystem
import akka.http.scaladsl.Http
import akka.http.scaladsl.model.headers.`Content-Type`
import akka.http.scaladsl.model.{HttpRequest, HttpResponse}
import akka.util.ByteString
import endpoints4s.algebra
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import scala.concurrent.duration._
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

import scala.concurrent.{ExecutionContext, Future}

trait ServerTestBase[T <: algebra.Endpoints]
    extends AnyWordSpec
    with Matchers
    with ScalaFutures
    with BeforeAndAfterAll
    with BeforeAndAfter {

  override implicit def patienceConfig: PatienceConfig =
    PatienceConfig(10.seconds, 10.millisecond)

  val serverApi: T

  
  case class Malformed(errors: Seq[String]) extends DecodedUrl[Nothing]
} 
Example 78
Source File: RerunnableBenchmarkSpec.scala    From catbird   with Apache License 2.0 5 votes vote down vote up
package io.catbird.benchmark

import org.scalatest.BeforeAndAfter
import org.scalatest.flatspec.AnyFlatSpec

class RerunnableBenchmarkSpec extends AnyFlatSpec with BeforeAndAfter {
  val benchmark: RerunnableBenchmark = new RerunnableBenchmark
  val sum = benchmark.numbers.sum

  before(benchmark.initPool())
  after(benchmark.shutdownPool())

  "The benchmark" should "correctly calculate the sum using futures" in {
    assert(benchmark.sumIntsF === sum)
  }

  it should "correctly calculate the sum using futures and future pools" in {
    assert(benchmark.sumIntsPF === sum)
  }

  it should "correctly calculate the sum using rerunnables" in {
    assert(benchmark.sumIntsR === sum)
  }

  it should "correctly calculate the sum using rerunnables and future pools" in {
    assert(benchmark.sumIntsPR === sum)
  }
} 
Example 79
Source File: HierarchyFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.hierarchy

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class HierarchyFieldTest extends WordSpecLike
with Matchers
with BeforeAndAfter
with BeforeAndAfterAll
with TableDrivenPropertyChecks {

  var hbs: Option[HierarchyField] = _

  before {
    hbs = Some(new HierarchyField())
  }

  after {
    hbs = None
  }

  "A HierarchyDimension" should {
    "In default implementation, get 4 precisions for all precision sizes" in {
      val precisionLeftToRight = hbs.get.precisionValue(HierarchyField.LeftToRightName, "")
      val precisionRightToLeft = hbs.get.precisionValue(HierarchyField.RightToLeftName, "")
      val precisionLeftToRightWithWildCard = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, "")
      val precisionRightToLeftWithWildCard = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, "")

      precisionLeftToRight._1.id should be(HierarchyField.LeftToRightName)
      precisionRightToLeft._1.id should be(HierarchyField.RightToLeftName)
      precisionLeftToRightWithWildCard._1.id should be(HierarchyField.LeftToRightWithWildCardName)
      precisionRightToLeftWithWildCard._1.id should be(HierarchyField.RightToLeftWithWildCardName)
    }

    "In default implementation, every proposed combination should be ok" in {
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "*.com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, i)
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio.*", "com.*", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In non-reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
  }
} 
Example 80
Source File: RabbitIntegrationSpec.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.rabbitmq

import akka.actor.ActorSystem
import akka.event.slf4j.SLF4JLogging
import akka.util.Timeout
import com.typesafe.config.ConfigFactory
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.concurrent.TimeLimitedTests
import org.scalatest.time.{Minute, Span}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpec}

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Try


abstract class RabbitIntegrationSpec extends WordSpec with Matchers with SLF4JLogging with TimeLimitedTests
  with BeforeAndAfter with BeforeAndAfterAll {
  private lazy val config = ConfigFactory.load()


  implicit val system = ActorSystem("ActorRabbitMQSystem")
  implicit val timeout = Timeout(10 seconds)
  val timeLimit = Span(1, Minute)
  
  val RabbitTimeOut = 3 second
  val configQueueName = Try(config.getString("rabbitmq.queueName")).getOrElse("rabbitmq-queue")
  val configExchangeName = Try(config.getString("rabbitmq.exchangeName")).getOrElse("rabbitmq-exchange")
  val exchangeType = Try(config.getString("rabbitmq.exchangeType")).getOrElse("topic")
  val routingKey = Try(config.getString("rabbitmq.routingKey")).getOrElse("")
  val vHost = Try(config.getString("rabbitmq.vHost")).getOrElse("/")
  val hosts = Try(config.getString("rabbitmq.hosts")).getOrElse("127.0.0.1")
  val userName = Try(config.getString("rabbitmq.userName")).getOrElse("guest")
  val password = Try(config.getString("rabbitmq.password")).getOrElse("guest")
  val RabbitConnectionURI = s"amqp://$userName:$password@$hosts/%2F"
  var sc: Option[SparkContext] = None
  var ssc: Option[StreamingContext] = None

  def initSpark(): Unit = {
    sc = Some(new SparkContext(conf))
    ssc = Some(new StreamingContext(sc.get, Seconds(1)))
  }

  def stopSpark(): Unit = {
    ssc.foreach(_.stop())
    sc.foreach(_.stop())

    System.gc()
  }

  def initRabbitMQ(): Unit

  def closeRabbitMQ(): Unit

  before {
    log.info("Init spark")
    initSpark()
    log.info("Sending messages to queue..")
    initRabbitMQ()
    log.info("Messages in queue.")
  }

  after {
    log.info("Stop spark")
    stopSpark()
    log.info("Clean rabbitmq")
    closeRabbitMQ()
  }
} 
Example 81
Source File: MorphlinesParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.morphline

import java.io.Serializable

import com.stratio.sparta.sdk.pipeline.input.Input
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}


@RunWith(classOf[JUnitRunner])
class MorphlinesParserTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll {

  val morphlineConfig = """
          id : test1
          importCommands : ["org.kitesdk.**"]
          commands: [
          {
              readJson {},
          }
          {
              extractJsonPaths {
                  paths : {
                      col1 : /col1
                      col2 : /col2
                  }
              }
          }
          {
            java {
              code : "return child.process(record);"
            }
          }
          {
              removeFields {
                  blacklist:["literal:_attachment_body"]
              }
          }
          ]
                        """
  val inputField = Some(Input.RawDataKey)
  val outputsFields = Seq("col1", "col2")
  val props: Map[String, Serializable] = Map("morphline" -> morphlineConfig)

  val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType)))

  val parser = new MorphlinesParser(1, inputField, outputsFields, schema, props)

  "A MorphlinesParser" should {

    "parse a simple json" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }

    "parse a simple json removing raw" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row("hello", "world"))

      result should be eq(expected)
    }

    "exclude not configured fields" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word",
            "col3":"!"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }
  }
} 
Example 82
Source File: TemporalSparkContext.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin

import org.apache.spark._
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FlatSpec}


private[plugin] trait TemporalSparkContext extends FlatSpec with BeforeAndAfterAll with BeforeAndAfter {

  val conf = new SparkConf()
    .setAppName("simulator-test")
    .setIfMissing("spark.master", "local[*]")

  @transient private var _sc: SparkContext = _
  @transient private var _ssc: StreamingContext = _

  def sc: SparkContext = _sc
  def ssc: StreamingContext = _ssc

  override def beforeAll()  {
    _sc = new SparkContext(conf)
    _ssc = new StreamingContext(sc, Seconds(2))
  }

  override def afterAll() : Unit = {
    if(ssc != null){
      ssc.stop(stopSparkContext =  false, stopGracefully = false)
      _ssc = null
    }
    if (sc != null){
      sc.stop()
      _sc = null
    }

    System.gc()
  }


} 
Example 83
Source File: VersionGeneratorTest.scala    From slick-repo   with MIT License 5 votes vote down vote up
package com.byteslounge.slickrepo.version

import java.time.{Instant, LocalDateTime}

import com.byteslounge.slickrepo.datetime.{DateTimeHelper, MockDateTimeHelper}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class VersionGeneratorTest extends FlatSpec with Matchers with BeforeAndAfter {

  before {
    MockDateTimeHelper.start()
    MockDateTimeHelper.mock(
      Instant.parse("2016-01-03T01:01:02Z")
    )
  }

  "The Integer Version Generator" should "generate the integer initial value" in {
    VersionGenerator.intVersionGenerator.initialVersion() should equal(1)
  }

  it should "generate the next integer value" in {
    VersionGenerator.intVersionGenerator.nextVersion(1) should equal(2)
  }

  "The Long Version Generator" should "generate the long initial value" in {
    VersionGenerator.longVersionGenerator.initialVersion() should equal(1L)
  }

  it should "generate the next long value" in {
    VersionGenerator.longVersionGenerator.nextVersion(1L) should equal(2L)
  }

  "The Instant Version Generator" should "generate the instant initial value" in {
    VersionGenerator.instantVersionGenerator.initialVersion() should equal(InstantVersion(Instant.parse("2016-01-03T01:01:02Z")))
  }

  it should "generate the next instant value" in {
    VersionGenerator.instantVersionGenerator.nextVersion(InstantVersion(Instant.parse("2016-01-01T01:00:02.112Z"))) should equal(InstantVersion(Instant.parse("2016-01-03T01:01:02Z")))
  }

  "The LongInstant Version Generator" should "generate the LongInstant initial value" in {
    VersionGenerator.longInstantVersionGenerator.initialVersion() should equal(LongInstantVersion(Instant.parse("2016-01-03T01:01:02Z")))
  }

  it should "generate the next LongInstant value" in {
    VersionGenerator.longInstantVersionGenerator.nextVersion(LongInstantVersion(Instant.parse("2016-01-01T01:00:02.112Z"))) should equal(LongInstantVersion(Instant.parse("2016-01-03T01:01:02Z")))
  }

  "The LocalDateTime Version Generator" should "generate the LocalDateTime initial value" in {
    VersionGenerator.localDateTimeVersionGenerator.initialVersion() should equal(LocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-03T01:01:02Z"))))
  }

  it should "generate the next LocalDateTime value" in {
    VersionGenerator.localDateTimeVersionGenerator.nextVersion(LocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-01T01:00:02.112Z")))) should equal(LocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-03T01:01:02Z"))))
  }

  "The LongLocalDateTime Version Generator" should "generate the LongLocalDateTime initial value" in {
    VersionGenerator.longLocalDateTimeVersionGenerator.initialVersion() should equal(LongLocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-03T01:01:02Z"))))
  }

  it should "generate the next LocalDateTime value" in {
    VersionGenerator.longLocalDateTimeVersionGenerator.nextVersion(LongLocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-01T01:00:02.112Z")))) should equal(LongLocalDateTimeVersion(instantToLocalDateTime(Instant.parse("2016-01-03T01:01:02Z"))))
  }

  private def instantToLocalDateTime(instant: Instant): LocalDateTime = {
    LocalDateTime.ofInstant(instant, DateTimeHelper.localDateTimeZone)
  }
} 
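The before block above pins the clock once per test; that is the whole contract of BeforeAndAfter. A stripped-down sketch of the same idea with an ordinary mutable fixture (the FixtureSpec class and its fields are illustrative, not part of slick-repo):

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.collection.mutable.ListBuffer

class FixtureSpec extends FlatSpec with Matchers with BeforeAndAfter {

  // Shared mutable state, re-initialized before and cleared after every test
  private val builder = new StringBuilder
  private val buffer = new ListBuffer[String]

  before {
    builder.append("ScalaTest is ")
  }

  after {
    builder.clear()
    buffer.clear()
  }

  "A fixture" should "be reset before each test" in {
    builder.append("easy!")
    builder.toString should equal("ScalaTest is easy!")
    buffer should be(empty)
  }

  it should "not leak state between tests" in {
    builder.append("fun!")
    builder.toString should equal("ScalaTest is fun!")
  }
}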
Example 84
Source File: LifecycleHelperTest.scala    From slick-repo   with MIT License 5 votes vote down vote up
package com.byteslounge.slickrepo.repository

import com.byteslounge.slickrepo.annotation.postLoad
import com.byteslounge.slickrepo.test.{H2Config, LifecycleEntityRepositoryPostLoad}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class LifecycleHelperTest extends FlatSpec with BeforeAndAfter with Matchers {

  "The LifecycleHelper" should "detect that an entity does not define a handler" in {
    new PersonRepository(H2Config.config.driver)
    LifecycleHelper.isLifecycleHandlerDefined(classOf[PersonRepository], classOf[postLoad]) should equal(false)
  }

  it should "detect that an entity defines a handler" in {
    new LifecycleEntityRepositoryPostLoad(H2Config.config.driver)
    LifecycleHelper.isLifecycleHandlerDefined(classOf[LifecycleEntityRepositoryPostLoad], classOf[postLoad]) should equal(true)
  }
} 
Example 85
Source File: DateTimeHelperTest.scala    From slick-repo   with MIT License 5 votes vote down vote up
package com.byteslounge.slickrepo.datetime

import java.time.{Instant, LocalDateTime}

import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class DateTimeHelperTest extends FlatSpec with Matchers with BeforeAndAfter {

  before {
    MockDateTimeHelper.restore()
  }

  "The DateTimeHelper" should "return the current instant" in {
    val now: Instant = Instant.now()
    val currentInstant: Instant = DateTimeHelper.currentInstant
    currentInstant.toEpochMilli should be >= now.toEpochMilli
  }

  it should "return the current LocalDateTime" in {
    val now: Instant = Instant.now()
    val currentLocalDateTime: LocalDateTime = DateTimeHelper.currentLocalDateTime
    currentLocalDateTime.atZone(DateTimeHelper.localDateTimeZone).toInstant.toEpochMilli should be >= now.toEpochMilli
  }
} 
Example 86
Source File: OapQuerySuite.scala    From OAP   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.hive.execution

import java.util.{Locale, TimeZone}

import org.scalatest.{BeforeAndAfter, Ignore}

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.HiveUtils
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.internal.SQLConf

// Ignored because running this suite from a separate package runs into problems with the shaded Spark source.
@Ignore
class OapQuerySuite extends HiveComparisonTest with BeforeAndAfter {
  private lazy val originalTimeZone = TimeZone.getDefault
  private lazy val originalLocale = Locale.getDefault
  import org.apache.spark.sql.hive.test.TestHive._

  // Note: invoking TestHive creates a SparkContext that we cannot configure ourselves.
  // Be careful: this may affect a SparkContext that is already in use and cause strange problems.
  private lazy val originalCrossJoinEnabled = TestHive.conf.crossJoinEnabled

  override def beforeAll() {
    super.beforeAll()
    TestHive.setCacheTables(true)
    // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
    TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
    // Add Locale setting
    Locale.setDefault(Locale.US)
    // Ensures that cross joins are enabled so that we can test them
    TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, true)
    TestHive.setConf(HiveUtils.CONVERT_METASTORE_PARQUET, true)
  }

  override def afterAll() {
    try {
      TestHive.setCacheTables(false)
      TimeZone.setDefault(originalTimeZone)
      Locale.setDefault(originalLocale)
      sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2")
      TestHive.setConf(SQLConf.CROSS_JOINS_ENABLED, originalCrossJoinEnabled)
    } finally {
      super.afterAll()
    }
  }
  private def assertDupIndex(body: => Unit): Unit = {
    val e = intercept[AnalysisException] { body }
    assert(e.getMessage.toLowerCase.contains("exists"))
  }

  test("create hive table in parquet format") {
    try {
      sql("create table p_table (key int, val string) stored as parquet")
      sql("insert overwrite table p_table select * from src")
      sql("create oindex if not exists p_index on p_table(key)")
      assert(sql("select val from p_table where key = 238")
        .collect().head.getString(0) == "val_238")
    } finally {
      sql("drop oindex p_index on p_table")
      sql("drop table p_table")
    }
  }

  test("create duplicate hive table in parquet format") {
    try {
      sql("create table p_table1 (key int, val string) stored as parquet")
      sql("insert overwrite table p_table1 select * from src")
      sql("create oindex p_index on p_table1(key)")
      assertDupIndex { sql("create oindex p_index on p_table1(key)") }
    } finally {
      sql("drop oindex p_index on p_table1")
    }
  }
} 
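OapQuerySuite captures JVM-wide defaults and restores them in afterAll, wrapping the restore in try/finally so the parent trait's own cleanup always runs. A minimal sketch of just that save-and-restore pattern; here the defaults are captured eagerly at construction so the restore sees the pre-test values (GlobalSettingsSpec is an illustrative name and UTC an arbitrary choice):

import java.util.{Locale, TimeZone}

import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

class GlobalSettingsSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  // Captured at construction time, before beforeAll overrides them
  private val originalTimeZone = TimeZone.getDefault
  private val originalLocale = Locale.getDefault

  override def beforeAll(): Unit = {
    super.beforeAll()
    TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
    Locale.setDefault(Locale.US)
  }

  override def afterAll(): Unit = {
    try {
      // Restore the process-wide defaults no matter what the tests did
      TimeZone.setDefault(originalTimeZone)
      Locale.setDefault(originalLocale)
    } finally {
      super.afterAll()
    }
  }

  "A timezone-sensitive test" should "run under the pinned defaults" in {
    TimeZone.getDefault.getID should equal("UTC")
    Locale.getDefault should equal(Locale.US)
  }
}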
Example 87
Source File: SparkExecutionPlanProcessForRdbmsQuerySuite.scala    From spark-atlas-connector   with Apache License 2.0 5 votes vote down vote up
package com.hortonworks.spark.atlas.sql

import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}
import java.sql.DriverManager

import com.hortonworks.spark.atlas.{AtlasClientConf, AtlasUtils, WithHiveSupport}
import com.hortonworks.spark.atlas.AtlasEntityReadHelper._
import com.hortonworks.spark.atlas.sql.testhelper.{AtlasQueryExecutionListener, CreateEntitiesTrackingAtlasClient, DirectProcessSparkExecutionPlanProcessor, ProcessEntityValidator}
import com.hortonworks.spark.atlas.types.{external, metadata}
import org.apache.atlas.model.instance.AtlasEntity

class SparkExecutionPlanProcessForRdbmsQuerySuite
  extends FunSuite
  with Matchers
  with BeforeAndAfter
  with WithHiveSupport
  with ProcessEntityValidator {

  val sinkTableName = "sink_table"
  val sourceTableName = "source_table"
  val databaseName = "testdb"
  val jdbcDriver = "org.apache.derby.jdbc.EmbeddedDriver"

  val atlasClientConf: AtlasClientConf = new AtlasClientConf()
  var atlasClient: CreateEntitiesTrackingAtlasClient = _
  val testHelperQueryListener = new AtlasQueryExecutionListener()

  before {
    // set up the Derby database and the necessary tables
    val connectionURL = s"jdbc:derby:memory:$databaseName;create=true"
    Class.forName(jdbcDriver)
    val connection = DriverManager.getConnection(connectionURL)

    val createSinkTableQuery = s"CREATE TABLE $sinkTableName (NAME VARCHAR(20))"
    val createSourceTableQuery = s"CREATE TABLE $sourceTableName (NAME VARCHAR(20))"
    val insertQuery = s"INSERT INTO $sourceTableName (Name) VALUES ('A'), ('B'), ('C')"
    val statement = connection.createStatement
    statement.executeUpdate(createSinkTableQuery)
    statement.executeUpdate(createSourceTableQuery)
    statement.executeUpdate(insertQuery)

    // setup Atlas client
    atlasClient = new CreateEntitiesTrackingAtlasClient()
    sparkSession.listenerManager.register(testHelperQueryListener)
  }

  test("read from derby table and insert into a different derby table") {
    val planProcessor = new DirectProcessSparkExecutionPlanProcessor(atlasClient, atlasClientConf)

    val jdbcProperties = new java.util.Properties
    jdbcProperties.setProperty("driver", jdbcDriver)
    val url = s"jdbc:derby:memory:$databaseName;create=false"

    val readDataFrame = sparkSession.read.jdbc(url, sourceTableName, jdbcProperties)
    readDataFrame.write.mode("append").jdbc(url, sinkTableName, jdbcProperties)

    val queryDetail = testHelperQueryListener.queryDetails.last
    planProcessor.process(queryDetail)
    val entities = atlasClient.createdEntities

    // we're expecting two table entities:
    // one from the source table and another from the sink table
    val tableEntities = listAtlasEntitiesAsType(entities, external.RDBMS_TABLE)
    assert(tableEntities.size === 2)

    val inputEntity = getOnlyOneEntityOnAttribute(tableEntities, "name", sourceTableName)
    val outputEntity = getOnlyOneEntityOnAttribute(tableEntities, "name", sinkTableName)
    assertTableEntity(inputEntity, sourceTableName)
    assertTableEntity(outputEntity, sinkTableName)

    // check for 'spark_process'
    validateProcessEntityWithAtlasEntities(entities, _ => {},
      AtlasUtils.entitiesToReferences(Seq(inputEntity)),
      AtlasUtils.entitiesToReferences(Seq(outputEntity)))
  }

  private def assertTableEntity(entity: AtlasEntity, tableName: String): Unit = {
    val tableQualifiedName = getStringAttribute(entity, "qualifiedName")
    assert(tableQualifiedName.equals(s"$databaseName.$tableName"))
  }

} 
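The before block above provisions an in-memory Derby database over plain JDBC so every test can read from it. A reduced sketch of that fixture on its own, assuming the Derby embedded driver is on the test classpath (database, table, and class names here are illustrative, not the connector's):

import java.sql.{Connection, DriverManager}

import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}

class DerbyFixtureSuite extends FunSuite with Matchers with BeforeAndAfter {

  private val databaseName = "testdb" // illustrative in-memory database name
  private var connection: Connection = _

  before {
    // Load the embedded driver and create (or reopen) the in-memory database
    Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
    connection = DriverManager.getConnection(s"jdbc:derby:memory:$databaseName;create=true")
    val statement = connection.createStatement()
    statement.executeUpdate("CREATE TABLE source_table (NAME VARCHAR(20))")
    statement.executeUpdate("INSERT INTO source_table (NAME) VALUES ('A'), ('B'), ('C')")
  }

  after {
    // Drop the table and close the connection so the next test starts from a clean slate
    connection.createStatement().executeUpdate("DROP TABLE source_table")
    connection.close()
  }

  test("the table created in before is visible to the test body") {
    val resultSet = connection.createStatement().executeQuery("SELECT COUNT(*) FROM source_table")
    resultSet.next() should be(true)
    resultSet.getInt(1) should be(3)
  }
}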
Example 88
Source File: DistributedShellClientSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.examples.distributedshell

import scala.concurrent.Future
import scala.util.{Success, Try}

import akka.testkit.TestProbe
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.ResolveAppId
import org.apache.gearpump.cluster.MasterToClient.ResolveAppIdResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}
import org.apache.gearpump.examples.distributedshell.DistShellAppMaster.ShellCommand
import org.apache.gearpump.util.LogUtil

class DistributedShellClientSpec
  extends PropSpec with Matchers with BeforeAndAfter with MasterHarness {

  private val LOG = LogUtil.getLogger(getClass)

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  protected override def config = TestUtil.DEFAULT_CONFIG

  property("DistributedShellClient should succeed to submit application with required arguments") {
    val command = "ls /"
    val requiredArgs = Array("-appid", "0", "-command", command)
    val masterReceiver = createMockMaster()

    assert(Try(DistributedShellClient.main(Array.empty[String])).isFailure,
      "missing required arguments, print usage")

    Future {
      DistributedShellClient.main(masterConfig, requiredArgs)
    }

    masterReceiver.expectMsg(PROCESS_BOOT_TIME, ResolveAppId(0))
    val mockAppMaster = TestProbe()(getActorSystem)
    masterReceiver.reply(ResolveAppIdResult(Success(mockAppMaster.ref)))
    LOG.info(s"Reply back ResolveAppIdResult, current actorRef: ${mockAppMaster.ref.path.toString}")
    mockAppMaster.expectMsg(PROCESS_BOOT_TIME, ShellCommand(command))
    mockAppMaster.reply("result")
  }
} 
Example 89
Source File: ShellCommandResultAggregatorSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.examples.distributedshell

import org.scalatest.{BeforeAndAfter, Matchers, WordSpec}

import org.apache.gearpump.examples.distributedshell.DistShellAppMaster.{ShellCommandResult, ShellCommandResultAggregator}

class ShellCommandResultAggregatorSpec extends WordSpec with Matchers with BeforeAndAfter {
  "ShellCommandResultAggregator" should {
    "aggregate ShellCommandResult" in {
      val executorId1 = 1
      val executorId2 = 2
      val responseBuilder = new ShellCommandResultAggregator
      val response1 = ShellCommandResult(executorId1, "task1")
      val response2 = ShellCommandResult(executorId2, "task2")
      val result = responseBuilder.aggregate(response1).aggregate(response2).toString()
      val expected = s"Execute results from executor $executorId1 : \ntask1\n" +
        s"Execute results from executor $executorId2 : \ntask2\n"
      assert(result == expected)
    }
  }
} 
Example 90
Source File: DistributedShellSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.examples.distributedshell

import scala.concurrent.Future
import scala.util.Success

import com.typesafe.config.Config
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}

class DistributedShellSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  override def config: Config = TestUtil.DEFAULT_CONFIG

  property("DistributedShell should succeed to submit application with required arguments") {
    val requiredArgs = Array.empty[String]

    val masterReceiver = createMockMaster()

    Future {
      DistributedShell.main(masterConfig, requiredArgs)
    }

    masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
    masterReceiver.reply(SubmitApplicationResult(Success(0)))
  }
} 
Example 91
Source File: DistShellAppMasterSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.examples.distributedshell

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import akka.actor.ActorSystem
import akka.testkit.{TestActorRef, TestProbe}
import org.scalatest.{BeforeAndAfter, Matchers, WordSpec}

import org.apache.gearpump.cluster.AppMasterToMaster.{GetAllWorkers, RegisterAppMaster, RequestResource}
import org.apache.gearpump.cluster.AppMasterToWorker.LaunchExecutor
import org.apache.gearpump.cluster.MasterToAppMaster.{AppMasterRegistered, ResourceAllocated, WorkerList}
import org.apache.gearpump.cluster._
import org.apache.gearpump.cluster.appmaster.{AppMasterRuntimeEnvironment, ApplicationRuntimeInfo}
import org.apache.gearpump.cluster.scheduler.{Relaxation, Resource, ResourceAllocation, ResourceRequest}
import org.apache.gearpump.cluster.worker.WorkerId
import org.apache.gearpump.util.ActorSystemBooter.RegisterActorSystem
import org.apache.gearpump.util.ActorUtil

class DistShellAppMasterSpec extends WordSpec with Matchers with BeforeAndAfter {
  implicit val system = ActorSystem("AppMasterSpec", TestUtil.DEFAULT_CONFIG)
  val mockMaster = TestProbe()(system)
  val mockWorker1 = TestProbe()(system)
  val masterProxy = mockMaster.ref
  val appId = 0
  val userName = "test"
  val masterExecutorId = 0
  val workerList = List(WorkerId(1, 0L), WorkerId(2, 0L), WorkerId(3, 0L))
  val resource = Resource(1)
  val appJar = None
  val appDescription = AppDescription("app0", classOf[DistShellAppMaster].getName, UserConfig.empty)

  "DistributedShell AppMaster" should {
    "launch one ShellTask on each worker" in {
      val appMasterInfo = ApplicationRuntimeInfo(appId, appName = appId.toString)
      val appMasterContext = AppMasterContext(appId, userName, resource, null, appJar, masterProxy)
      TestActorRef[DistShellAppMaster](
        AppMasterRuntimeEnvironment.props(List(masterProxy.path), appDescription,
          appMasterContext))
      mockMaster.expectMsgType[RegisterAppMaster]
      mockMaster.reply(AppMasterRegistered(appId))
      // The DistributedShell AppMaster asks the Master for the worker list.
      mockMaster.expectMsg(GetAllWorkers)
      mockMaster.reply(WorkerList(workerList))
      // After the worker list is ready, the DistributedShell AppMaster requests a resource on each worker
      workerList.foreach { workerId =>
        mockMaster.expectMsg(RequestResource(appId, ResourceRequest(Resource(1), workerId,
          relaxation = Relaxation.SPECIFICWORKER)))
      }
      mockMaster.reply(ResourceAllocated(
        Array(ResourceAllocation(resource, mockWorker1.ref, WorkerId(1, 0L)))))
      mockWorker1.expectMsgClass(classOf[LaunchExecutor])
      mockWorker1.reply(RegisterActorSystem(ActorUtil.getSystemAddress(system).toString))
    }
  }

  after {
    system.terminate()
    Await.result(system.whenTerminated, Duration.Inf)
  }
} 
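The spec above builds its ActorSystem when the class is constructed and shuts it down in an after block; many of the gearpump specs in this series do the same through a MasterHarness. A framework-free sketch of that lifecycle, with the system recreated in before so every test gets a fresh one (ActorFixtureSpec is an illustrative name):

import akka.actor.ActorSystem
import akka.testkit.TestProbe
import org.scalatest.{BeforeAndAfter, Matchers, WordSpec}

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class ActorFixtureSpec extends WordSpec with Matchers with BeforeAndAfter {

  // Recreated before every test and terminated after it, so leftover actors
  // cannot interfere with later tests
  private var system: ActorSystem = _

  before {
    system = ActorSystem("ActorFixtureSpec")
  }

  after {
    system.terminate()
    Await.result(system.whenTerminated, Duration.Inf)
  }

  "A per-test actor system" should {
    "be usable inside the test body" in {
      val probe = TestProbe()(system)
      probe.ref ! "ping"
      probe.expectMsg("ping")
    }
  }
}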
Example 92
Source File: SeqFileStreamProcessorSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.fsio

import java.io.File
import java.time.Instant
import scala.collection.mutable.ArrayBuffer

import akka.actor.ActorSystem
import akka.testkit.TestProbe
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.Reader
import org.apache.hadoop.io.{SequenceFile, Text}
import org.mockito.Mockito._
import org.scalacheck.Gen
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.{TestUtil, UserConfig}
import org.apache.gearpump.streaming.task.TaskId
import org.apache.gearpump.streaming.{MockUtil, Processor}
class SeqFileStreamProcessorSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter {

  val kvPairs = new ArrayBuffer[(String, String)]
  val outputDirectory = "SeqFileStreamProcessor_Test"
  val sequenceFilePath = new Path(outputDirectory + File.separator + TaskId(0, 0))
  val hadoopConf = new Configuration()
  val fs = FileSystem.get(hadoopConf)
  val textClass = new Text().getClass
  val _key = new Text()
  val _value = new Text()

  val kvGenerator = for {
    key <- Gen.alphaStr
    value <- Gen.alphaStr
  } yield (key, value)

  before {
    implicit val system1 = ActorSystem("SeqFileStreamProcessor", TestUtil.DEFAULT_CONFIG)
    val system2 = ActorSystem("Reporter", TestUtil.DEFAULT_CONFIG)
    val watcher = TestProbe()(system1)
    val conf = HadoopConfig(UserConfig.empty.withString(SeqFileStreamProcessor.OUTPUT_PATH,
      outputDirectory)).withHadoopConf(new Configuration())
    val context = MockUtil.mockTaskContext

    val processorDescription =
      Processor.ProcessorToProcessorDescription(id = 0, Processor[SeqFileStreamProcessor](1))

    val taskId = TaskId(0, 0)
    when(context.taskId).thenReturn(taskId)

    val processor = new SeqFileStreamProcessor(context, conf)
    processor.onStart(Instant.EPOCH)

    forAll(kvGenerator) { kv =>
      val (key, value) = kv
      kvPairs.append((key, value))
      processor.onNext(Message(key + "++" + value))
    }
    processor.onStop()
  }

  property("SeqFileStreamProcessor should write the key-value pairs to a sequence file") {
    val reader = new SequenceFile.Reader(hadoopConf, Reader.file(sequenceFilePath))
    kvPairs.foreach { kv =>
      val (key, value) = kv
      if (value.length > 0 && reader.next(_key, _value)) {
        assert(_key.toString == key && _value.toString == value)
      }
    }
    reader.close()
  }

  after {
    fs.deleteOnExit(new Path(outputDirectory))
  }
} 
Example 93
Source File: SeqFileStreamProducerSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.fsio

import java.time.Instant

import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.Writer
import org.apache.hadoop.io.{SequenceFile, Text}
import org.mockito.Mockito._
import org.scalacheck.Gen
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.streaming.MockUtil
import org.apache.gearpump.streaming.MockUtil._

class SeqFileStreamProducerSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter {

  val kvPairs = new ArrayBuffer[(String, String)]
  val inputFile = "SeqFileStreamProducer_Test"
  val sequenceFilePath = new Path(inputFile)
  val hadoopConf = new Configuration()
  val fs = FileSystem.get(hadoopConf)
  val textClass = new Text().getClass
  val _key = new Text()
  val _value = new Text()

  val kvGenerator = for {
    key <- Gen.alphaStr
    value <- Gen.alphaStr
  } yield (key, value)

  before {
    fs.deleteOnExit(sequenceFilePath)
    val writer = SequenceFile.createWriter(hadoopConf, Writer.file(sequenceFilePath),
      Writer.keyClass(textClass), Writer.valueClass(textClass))
    forAll(kvGenerator) { kv =>
      _key.set(kv._1)
      _value.set(kv._2)
      kvPairs.append((kv._1, kv._2))
      writer.append(_key, _value)
    }
    writer.close()
  }

  property("SeqFileStreamProducer should read the key-value pairs from " +
    "a sequence file and deliver them") {

    val conf = HadoopConfig(UserConfig.empty.withString(SeqFileStreamProducer.INPUT_PATH,
      inputFile)).withHadoopConf(new Configuration())

    val context = MockUtil.mockTaskContext

    val producer = new SeqFileStreamProducer(context, conf)
    producer.onStart(Instant.EPOCH)
    producer.onNext(Message("start"))

    val expected = kvPairs.map(kv => kv._1 + "++" + kv._2).toSet
    verify(context).output(argMatch[Message](msg =>
      expected.contains(msg.value.asInstanceOf[String])))
  }

  after {
    fs.deleteOnExit(sequenceFilePath)
  }
} 
Example 94
Source File: NodeSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.complexdag

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.streaming.MockUtil
import org.apache.gearpump.streaming.MockUtil._
import org.mockito.Mockito._
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

class NodeSpec extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter {

  val context = MockUtil.mockTaskContext

  val node = new Node(context, UserConfig.empty)

  property("Node should send a Vector[String](classOf[Node].getCanonicalName, " +
    "classOf[Node].getCanonicalName") {
    val list = Vector(classOf[Node].getCanonicalName)
    val expected = Vector(classOf[Node].getCanonicalName, classOf[Node].getCanonicalName)
    node.onNext(Message(list))
    verify(context).output(argMatch[Message](_.value == expected))
  }
} 
Example 95
Source File: SinkSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.complexdag

import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.Message
import org.apache.gearpump.cluster.UserConfig
import org.apache.gearpump.streaming.MockUtil

class SinkSpec extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter {

  val context = MockUtil.mockTaskContext

  val sink = new Sink(context, UserConfig.empty)

  property("Sink should send a Vector[String](classOf[Sink].getCanonicalName, " +
    "classOf[Sink].getCanonicalName") {
    val list = Vector(classOf[Sink].getCanonicalName)
    val expected = Vector(classOf[Sink].getCanonicalName, classOf[Sink].getCanonicalName)
    sink.onNext(Message(list))

    (0 until sink.list.size).map(i => {
      assert(sink.list(i).equals(expected(i)))
    })
  }
} 
Example 96
Source File: WindowAverageAppSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.state

import scala.concurrent.Future
import scala.util.Success

import com.typesafe.config.Config
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}

class WindowAverageAppSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  override def config: Config = TestUtil.DEFAULT_CONFIG

  property("WindowAverage should succeed to submit application with required arguments") {
    val requiredArgs = Array.empty[String]
    val optionalArgs = Array(
      "-gen", "2",
      "-window", "2",
      "-window_size", "5000",
      "-window_step", "5000"
    )

    val args = {
      Table(
        ("requiredArgs", "optionalArgs"),
        (requiredArgs, optionalArgs.take(0)),
        (requiredArgs, optionalArgs.take(2)),
        (requiredArgs, optionalArgs.take(4)),
        (requiredArgs, optionalArgs.take(6)),
        (requiredArgs, optionalArgs)
      )
    }
    val masterReceiver = createMockMaster()
    forAll(args) { (requiredArgs: Array[String], optionalArgs: Array[String]) =>
      val args = requiredArgs ++ optionalArgs

      Future {
        WindowAverageApp.main(masterConfig, args)
      }

      masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
      masterReceiver.reply(SubmitApplicationResult(Success(0)))
    }
  }
} 
Example 97
Source File: DefaultMessageCountAppSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.state

import scala.concurrent.Future
import scala.util.Success

import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}
import org.apache.gearpump.streaming.examples.state.MessageCountApp._

class DefaultMessageCountAppSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  protected override def config = TestUtil.DEFAULT_CONFIG

  property("MessageCount should succeed to submit application with required arguments") {
    val requiredArgs = Array(
      s"-$SOURCE_TOPIC", "source",
      s"-$SINK_TOPIC", "sink",
      s"-$ZOOKEEPER_CONNECT", "localhost:2181",
      s"-$BROKER_LIST", "localhost:9092",
      s"-$DEFAULT_FS", "hdfs://localhost:9000"
    )
    val optionalArgs = Array(
      s"-$SOURCE_TASK", "2",
      s"-$COUNT_TASK", "2",
      s"-$SINK_TASK", "2"
    )

    val args = {
      Table(
        ("requiredArgs", "optionalArgs"),
        (requiredArgs, optionalArgs.take(0)),
        (requiredArgs, optionalArgs.take(2)),
        (requiredArgs, optionalArgs.take(4)),
        (requiredArgs, optionalArgs)
      )
    }

    val masterReceiver = createMockMaster()
    forAll(args) { (requiredArgs: Array[String], optionalArgs: Array[String]) =>
      val args = requiredArgs ++ optionalArgs
      Future {
        MessageCountApp.main(masterConfig, args)
      }
      masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
      masterReceiver.reply(SubmitApplicationResult(Success(0)))
    }
  }
} 
Example 98
Source File: WordCountSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.wordcount

import scala.concurrent.Future
import scala.util.Success

import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}

class WordCountSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  protected override def config = TestUtil.DEFAULT_CONFIG

  property("WordCount should succeed to submit application with required arguments") {
    val requiredArgs = Array.empty[String]
    val optionalArgs = Array(
      "-split", "1",
      "-sum", "1")

    val args = {
      Table(
        ("requiredArgs", "optionalArgs"),
        (requiredArgs, optionalArgs)
      )
    }
    val masterReceiver = createMockMaster()
    forAll(args) { (requiredArgs: Array[String], optionalArgs: Array[String]) =>

      val args = requiredArgs ++ optionalArgs

      Future {
        WordCount.main(masterConfig, args)
      }

      masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
      masterReceiver.reply(SubmitApplicationResult(Success(0)))
    }
  }
} 
Example 99
Source File: WordCountSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.wordcountjava

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}
import org.apache.gearpump.streaming.examples.wordcountjava.dsl.WordCount
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import scala.concurrent.Future
import scala.util.Success

class WordCountSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  protected override def config = TestUtil.DEFAULT_CONFIG

  property("WordCount should succeed to submit application with required arguments") {
    val requiredArgs = Array.empty[String]

    val masterReceiver = createMockMaster()

    val args = requiredArgs

    Future {
      WordCount.main(masterConfig, args)
    }

    masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
    masterReceiver.reply(SubmitApplicationResult(Success(0)))
  }
} 
Example 100
Source File: KafkaWordCountSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.streaming.examples.kafka.wordcount

import scala.concurrent.Future
import scala.util.Success

import com.typesafe.config.Config
import org.scalatest.prop.PropertyChecks
import org.scalatest.{BeforeAndAfter, Matchers, PropSpec}

import org.apache.gearpump.cluster.ClientToMaster.SubmitApplication
import org.apache.gearpump.cluster.MasterToClient.SubmitApplicationResult
import org.apache.gearpump.cluster.{MasterHarness, TestUtil}

class KafkaWordCountSpec
  extends PropSpec with PropertyChecks with Matchers with BeforeAndAfter with MasterHarness {

  before {
    startActorSystem()
  }

  after {
    shutdownActorSystem()
  }

  override def config: Config = TestUtil.DEFAULT_CONFIG

  property("KafkaWordCount should succeed to submit application with required arguments") {
    val requiredArgs = Array.empty[String]
    val optionalArgs = Array(
      "-source", "1",
      "-split", "1",
      "-sum", "1",
      "-sink", "1")

    val args = {
      Table(
        ("requiredArgs", "optionalArgs"),
        (requiredArgs, optionalArgs)
      )
    }
    val masterReceiver = createMockMaster()
    forAll(args) { (requiredArgs: Array[String], optionalArgs: Array[String]) =>
      val args = requiredArgs ++ optionalArgs

      Future {
        KafkaWordCount.main(masterConfig, args)
      }

      masterReceiver.expectMsgType[SubmitApplication](PROCESS_BOOT_TIME)
      masterReceiver.reply(SubmitApplicationResult(Success(0)))
    }
  }
} 
Example 101
Source File: DistServiceAppMasterSpec.scala    From incubator-retired-gearpump   with Apache License 2.0 5 votes vote down vote up
package org.apache.gearpump.experiments.distributeservice

import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor.ActorSystem
import akka.testkit.{TestActorRef, TestProbe}
import org.scalatest.{BeforeAndAfter, Matchers, WordSpec}

import org.apache.gearpump.cluster.AppMasterToMaster.{GetAllWorkers, RegisterAppMaster, RequestResource}
import org.apache.gearpump.cluster.AppMasterToWorker.LaunchExecutor
import org.apache.gearpump.cluster.MasterToAppMaster.{AppMasterRegistered, ResourceAllocated, WorkerList}
import org.apache.gearpump.cluster.appmaster.AppMasterRuntimeEnvironment
import org.apache.gearpump.cluster.scheduler.{Relaxation, Resource, ResourceAllocation, ResourceRequest}
import org.apache.gearpump.cluster.worker.WorkerId
import org.apache.gearpump.cluster.{AppDescription, AppMasterContext, TestUtil, UserConfig}
import org.apache.gearpump.experiments.distributeservice.DistServiceAppMaster.{FileContainer, GetFileContainer}
import org.apache.gearpump.util.ActorSystemBooter.RegisterActorSystem
import org.apache.gearpump.util.ActorUtil

class DistServiceAppMasterSpec extends WordSpec with Matchers with BeforeAndAfter {
  implicit val system = ActorSystem("AppMasterSpec", TestUtil.DEFAULT_CONFIG)
  val mockMaster = TestProbe()(system)
  val mockWorker1 = TestProbe()(system)
  val client = TestProbe()(system)
  val masterProxy = mockMaster.ref
  val appId = 0
  val userName = "test"
  val masterExecutorId = 0
  val workerList = List(WorkerId(1, 0L), WorkerId(2, 0L), WorkerId(3, 0L))
  val resource = Resource(1)
  val appJar = None
  val appDescription = AppDescription("app0", classOf[DistServiceAppMaster].getName,
    UserConfig.empty)

  "DistService AppMaster" should {
    "responsable for service distributing" in {
      val appMasterContext = AppMasterContext(appId, userName, resource, null, appJar, masterProxy)
      TestActorRef[DistServiceAppMaster](
        AppMasterRuntimeEnvironment.props(List(masterProxy.path), appDescription,
          appMasterContext))
      val registerAppMaster = mockMaster.receiveOne(15.seconds)
      assert(registerAppMaster.isInstanceOf[RegisterAppMaster])

      val appMaster = registerAppMaster.asInstanceOf[RegisterAppMaster].appMaster
      mockMaster.reply(AppMasterRegistered(appId))
      // The DistService AppMaster will ask the Master for the worker list
      mockMaster.expectMsg(GetAllWorkers)
      mockMaster.reply(WorkerList(workerList))
      // After the worker list is ready, the DistService AppMaster will request a resource on each worker
      workerList.foreach { workerId =>
        mockMaster.expectMsg(RequestResource(appId, ResourceRequest(Resource(1), workerId,
          relaxation = Relaxation.SPECIFICWORKER)))
      }
      mockMaster.reply(ResourceAllocated(Array(ResourceAllocation(resource, mockWorker1.ref,
        WorkerId(1, 0L)))))
      mockWorker1.expectMsgClass(classOf[LaunchExecutor])
      mockWorker1.reply(RegisterActorSystem(ActorUtil.getSystemAddress(system).toString))

      appMaster.tell(GetFileContainer, client.ref)
      client.expectMsgClass(15.seconds, classOf[FileContainer])
    }
  }

  after {
    system.terminate()
    Await.result(system.whenTerminated, Duration.Inf)
  }
} 
Example 102
Source File: SignatureProducerActorSpecForIntegration.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package integration.security

import akka.actor.{ActorRef, ActorSystem, Props}
import akka.testkit.{ImplicitSender, TestKit}
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.communication.security.{Hmac, SignatureProducerActor}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}

object SignatureProducerActorSpecForIntegration {
  val config = """
    akka {
      loglevel = "WARNING"
    }"""
}

class SignatureProducerActorSpecForIntegration extends TestKit(
  ActorSystem(
    "SignatureProducerActorSpec",
    ConfigFactory.parseString(SignatureProducerActorSpecForIntegration.config)
  )
) with ImplicitSender with FunSpecLike with Matchers with BeforeAndAfter
{

  private val sigKey = "12345"

  private var signatureProducer: ActorRef = _

  before {
    val hmac = Hmac(sigKey)
    signatureProducer =
      system.actorOf(Props(classOf[SignatureProducerActor], hmac))

  }

  after {
    signatureProducer = null
  }

  describe("SignatureProducerActor") {
    describe("#receive") {
      it("should return the correct signature for a kernel message") {
        val expectedSignature =
          "1c4859a7606fd93eb5f73c3d9642f9bc860453ba42063961a00d02ed820147b5"
        val message =
          KernelMessage(
            null, "",
            Header("a", "b", "c", "d", "e"),
            ParentHeader("f", "g", "h", "i", "j"),
            Metadata(),
            "<STRING>"
          )

        signatureProducer ! message
        expectMsg(expectedSignature)
      }
    }
  }
} 
Example 103
Source File: SignatureCheckerActorSpecForIntegration.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package integration.security

import akka.actor.{ActorRef, ActorSystem, Props}
import akka.testkit.{ImplicitSender, TestKit}
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.communication.security.{Hmac, SignatureCheckerActor}
import com.typesafe.config.ConfigFactory
import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}
import play.api.libs.json.Json

object SignatureCheckerActorSpecForIntegration {
  val config = """
    akka {
      loglevel = "WARNING"
    }"""
}

class SignatureCheckerActorSpecForIntegration extends TestKit(
  ActorSystem(
    "SignatureCheckerActorSpec",
    ConfigFactory.parseString(SignatureCheckerActorSpecForIntegration.config)
  )
) with ImplicitSender with FunSpecLike with Matchers with BeforeAndAfter
{

  private val sigKey = "12345"
  private val signature =
    "1c4859a7606fd93eb5f73c3d9642f9bc860453ba42063961a00d02ed820147b5"
  private val goodMessage =
    KernelMessage(
      null, signature,
      Header("a", "b", "c", "d", "e"),
      ParentHeader("f", "g", "h", "i", "j"),
      Metadata(),
      "<STRING>"
    )
  private val badMessage =
    KernelMessage(
      null, "wrong signature",
      Header("a", "b", "c", "d", "e"),
      ParentHeader("f", "g", "h", "i", "j"),
      Metadata(),
      "<STRING>"
    )

  private var signatureChecker: ActorRef = _

  before {
    val hmac = Hmac(sigKey)
    signatureChecker =
      system.actorOf(Props(classOf[SignatureCheckerActor], hmac))
  }

  after {
    signatureChecker = null
  }

  describe("SignatureCheckerActor") {
    describe("#receive") {
      it("should return true if the kernel message is valid") {
        val blob =
          Json.stringify(Json.toJson(goodMessage.header)) ::
          Json.stringify(Json.toJson(goodMessage.parentHeader)) ::
          Json.stringify(Json.toJson(goodMessage.metadata)) ::
          goodMessage.contentString ::
          Nil
        signatureChecker ! ((goodMessage.signature, blob))
        expectMsg(true)
      }

      it("should return false if the kernel message is invalid") {
        val blob =
          Json.stringify(Json.toJson(badMessage.header)) ::
          Json.stringify(Json.toJson(badMessage.parentHeader)) ::
          Json.stringify(Json.toJson(badMessage.metadata)) ::
          badMessage.contentString ::
          Nil
        signatureChecker ! ((badMessage.signature, blob))
        expectMsg(false)
      }
    }
  }
} 
Example 104
Source File: JeroMQSocketSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.communication.socket

import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.{Matchers, BeforeAndAfter, OneInstancePerTest, FunSpec}
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito._
import org.zeromq.ZMsg

class JeroMQSocketSpec extends FunSpec with MockitoSugar
  with OneInstancePerTest with BeforeAndAfter with Matchers
{
  private val runnable = mock[ZeroMQSocketRunnable]
  @volatile private var running = true
  //  Mock the running of the runnable for the tests
  doAnswer(new Answer[Unit] {
    override def answer(invocation: InvocationOnMock): Unit = while (running) {
      Thread.sleep(1)
    }
  }).when(runnable).run()


  //  Mock the close of the runnable to shut down
  doAnswer(new Answer[Unit] {
    override def answer(invocation: InvocationOnMock): Unit = running = false
  }).when(runnable).close()

  private val socket: JeroMQSocket = new JeroMQSocket(runnable)

  after {
    running = false
  }

  describe("JeroMQSocket") {
    describe("#send") {
      it("should offer a message to the runnable") {
        val message: String = "Some Message"
        val expected = ZMsg.newStringMsg(message)

        socket.send(message.getBytes)
        verify(runnable).offer(expected)
      }

      it("should thrown and AssertionError when socket is no longer alive") {
        socket.close()

        intercept[AssertionError] {
          socket.send("".getBytes)
        }
      }
    }

    describe("#close") {
      it("should close the runnable") {
        socket.close()

        verify(runnable).close()
      }

      it("should close the socket thread") {
        socket.close()

        socket.isAlive should be (false)
      }
    }

    describe("#isAlive") {
      it("should evaluate to true when the socket thread is alive") {
        socket.isAlive should be (true)
      }

      it("should evaluate to false when the socket thread is dead") {
        socket.close()

        socket.isAlive should be (false)
      }
    }
  }
} 
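JeroMQSocketSpec mixes in OneInstancePerTest, so the mocked runnable, its background thread, and the socket are rebuilt for every test, and the after block only has to flip the running flag. A small sketch of that shape with a plain thread standing in for the ZeroMQ runnable (BackgroundWorkerSpec is an illustrative class, no mocking involved):

import org.scalatest.{BeforeAndAfter, FunSpec, Matchers, OneInstancePerTest}

class BackgroundWorkerSpec extends FunSpec with Matchers
  with OneInstancePerTest with BeforeAndAfter {

  // Rebuilt for every test thanks to OneInstancePerTest; the flag stands in for a real resource
  @volatile private var running = true
  private val worker = new Thread(new Runnable {
    override def run(): Unit = while (running) Thread.sleep(1)
  })
  worker.setDaemon(true) // a daemon thread cannot keep the JVM alive if an instance is never cleaned up
  worker.start()

  // Only an after block is needed: construction does the setup, after guarantees shutdown
  // even when the test body throws
  after {
    running = false
    worker.join()
  }

  describe("a background worker") {
    it("is alive while a test is running") {
      worker.isAlive should be(true)
    }

    it("is rebuilt for each test, so an earlier shutdown does not leak in") {
      running should be(true)
      worker.isAlive should be(true)
    }
  }
}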
Example 105
Source File: JVMReprSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package integration.interpreter.scala

import java.util
import java.io.ByteArrayOutputStream
import jupyter.{Displayer, Displayers, MIMETypes}
import org.apache.toree.global.StreamState
import org.apache.toree.interpreter.Interpreter
import org.apache.toree.interpreter.Results.Success
import org.apache.toree.kernel.api.{DisplayMethodsLike, KernelLike}
import org.apache.toree.kernel.interpreter.scala.ScalaInterpreter
import org.mockito.Mockito.doReturn
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}
import org.scalatest.mock.MockitoSugar
import scala.util.Random

class JVMReprSpec extends FunSpec with Matchers with MockitoSugar with BeforeAndAfter {

  private val outputResult = new ByteArrayOutputStream()
  private var interpreter: Interpreter = _

  before {
    val mockKernel = mock[KernelLike]
    val mockDisplayMethods = mock[DisplayMethodsLike]
    doReturn(mockDisplayMethods).when(mockKernel).display

    interpreter = new ScalaInterpreter().init(mockKernel)

    StreamState.setStreams(outputStream = outputResult)
  }

  after {
    interpreter.stop()
    outputResult.reset()
  }

  describe("ScalaInterpreter") {
    describe("#interpret") {
      it("should display Scala int as a text representation") {
        val (result, outputOrError) = interpreter.interpret("val a = 12")

        result should be(Success)
        outputOrError.isLeft should be(true)
        outputOrError.left.get should be(Map(MIMETypes.TEXT -> "12"))
      }

      it("should display Scala Some(str) as a text representation") {
        val (result, outputOrError) = interpreter.interpret("""val a = Some("str")""")

        result should be(Success)
        outputOrError.isLeft should be(true)
        outputOrError.left.get should be(Map(MIMETypes.TEXT -> "Some(str)"))
      }

      ignore("should use the Jupyter REPR API for display representation") {
        Displayers.register(classOf[DisplayerTest], new Displayer[DisplayerTest] {
          override def display(t: DisplayerTest): util.Map[String, String] = {
            val output = new util.HashMap[String, String]()
            output.put("text/plain", s"test object: ${t.id}")
            output.put("application/json", s"""{"id": ${t.id}""")
            output
          }
        })

        val inst = DisplayerTest()
        interpreter.bind("inst", classOf[DisplayerTest].getName, inst, List())

        val (result, outputOrError) = interpreter.interpret("""inst""")

        result should be(Success)
        outputOrError.isLeft should be(true)
        outputOrError.left.get should be(Map(
          MIMETypes.TEXT -> s"test object: ${inst.id}",
          "application/json" -> s"""{"id": ${inst.id}"""
        ))
      }
    }
  }
}

case class DisplayerTest(id: Long = new Random().nextLong()) 
Example 106
Source File: ClientCommManagerSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.comm

import org.apache.toree.kernel.protocol.v5
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.kernel.protocol.v5.client.ActorLoader
import org.apache.toree.kernel.protocol.v5.content.CommContent
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito._
import org.mockito.Matchers._
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}

class ClientCommManagerSpec extends FunSpec with Matchers with BeforeAndAfter
  with MockitoSugar
{
  private val TestTargetName = "some target"

  private var mockActorLoader: ActorLoader = _
  private var mockKMBuilder: KMBuilder = _
  private var mockCommRegistrar: CommRegistrar = _
  private var clientCommManager: ClientCommManager = _

  private var generatedCommWriter: CommWriter = _

  before {
    mockActorLoader = mock[ActorLoader]
    mockKMBuilder = mock[KMBuilder]
    mockCommRegistrar = mock[CommRegistrar]

    clientCommManager = new ClientCommManager(
      mockActorLoader,
      mockKMBuilder,
      mockCommRegistrar
    ) {
      override protected def newCommWriter(commId: UUID): CommWriter = {
        val commWriter = super.newCommWriter(commId)

        generatedCommWriter = commWriter

        val spyCommWriter = spy(commWriter)
        doNothing().when(spyCommWriter)
          .sendCommKernelMessage(any[KernelMessageContent with CommContent])

        spyCommWriter
      }
    }
  }

  describe("ClientCommManager") {
    describe("#open") {
      it("should return a wrapped instance of ClientCommWriter") {
        clientCommManager.open(TestTargetName, v5.MsgData.Empty)

        // Exposed hackishly for testing
        generatedCommWriter shouldBe a [ClientCommWriter]
      }
    }
  }
} 
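The comm and kernel specs in this series follow the same discipline: every Mockito mock is recreated inside before, so stubbing and recorded invocations never leak from one test into the next. A compact sketch of just that discipline, using a hypothetical Clock trait as the collaborator:

import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}

// Hypothetical collaborator used only to illustrate the mock lifecycle
trait Clock { def now(): Long }

class MockLifecycleSpec extends FunSpec with Matchers with BeforeAndAfter with MockitoSugar {

  private var mockClock: Clock = _

  before {
    // A fresh mock per test: stubbing below never survives into the next test
    mockClock = mock[Clock]
    when(mockClock.now()).thenReturn(42L)
  }

  describe("a mock recreated in before") {
    it("starts from the stubbed value") {
      mockClock.now() should be(42L)
      verify(mockClock, times(1)).now()
    }

    it("carries no recorded invocations over from other tests") {
      verify(mockClock, never()).now()
    }
  }
}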
Example 107
Source File: SparkKernelClientSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.kernel.protocol.v5.client

import akka.actor.ActorSystem
import akka.testkit.{TestKit, TestProbe}
import org.apache.toree.comm.{CommCallbacks, CommStorage, CommRegistrar}
import org.apache.toree.kernel.protocol.v5
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.kernel.protocol.v5.client.execution.ExecuteRequestTuple
import scala.concurrent.duration._
import org.mockito.Mockito._
import org.mockito.Matchers.{eq => mockEq, _}
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}

class SparkKernelClientSpec
  extends TestKit(ActorSystem("SparkKernelClientActorSystem"))
  with Matchers with MockitoSugar with FunSpecLike with BeforeAndAfter
{
  private val TestTargetName = "some target"

  private var mockActorLoader: ActorLoader = _
  private var mockCommRegistrar: CommRegistrar = _
  private var sparkKernelClient: SparkKernelClient = _
  private var executeRequestProbe: TestProbe = _
  private var shellClientProbe: TestProbe = _

  before {
    mockActorLoader = mock[ActorLoader]
    mockCommRegistrar = mock[CommRegistrar]

    executeRequestProbe = TestProbe()
    when(mockActorLoader.load(MessageType.Incoming.ExecuteRequest))
      .thenReturn(system.actorSelection(executeRequestProbe.ref.path.toString))

    shellClientProbe = TestProbe()
    when(mockActorLoader.load(SocketType.ShellClient))
      .thenReturn(system.actorSelection(shellClientProbe.ref.path.toString))

    sparkKernelClient = new SparkKernelClient(
      mockActorLoader, system, mockCommRegistrar)
  }

  describe("SparkKernelClient") {
    describe("#execute") {
      it("should send an ExecuteRequest message") {
        val func = (x: Any) => println(x)
        sparkKernelClient.execute("val foo = 2")
        executeRequestProbe.expectMsgClass(classOf[ExecuteRequestTuple])
      }
    }
  }
} 
Example 108
Source File: KernelCommManagerSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.comm

import org.apache.toree.kernel.protocol.v5
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.kernel.protocol.v5.content.CommContent
import org.apache.toree.kernel.protocol.v5.kernel.ActorLoader
import org.scalatest.mock.MockitoSugar
import org.mockito.Mockito._
import org.mockito.Matchers._
import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}

class KernelCommManagerSpec extends FunSpec with Matchers with BeforeAndAfter
  with MockitoSugar
{
  private val TestTargetName = "some target"

  private var mockActorLoader: ActorLoader = _
  private var mockKMBuilder: KMBuilder = _
  private var mockCommRegistrar: CommRegistrar = _
  private var kernelCommManager: KernelCommManager = _

  private var generatedCommWriter: CommWriter = _

  before {
    mockActorLoader = mock[ActorLoader]
    mockKMBuilder = mock[KMBuilder]
    mockCommRegistrar = mock[CommRegistrar]

    kernelCommManager = new KernelCommManager(
      mockActorLoader,
      mockKMBuilder,
      mockCommRegistrar
    ) {
      override protected def newCommWriter(commId: UUID): CommWriter = {
        val commWriter = super.newCommWriter(commId)

        generatedCommWriter = commWriter

        val spyCommWriter = spy(commWriter)
        doNothing().when(spyCommWriter)
          .sendCommKernelMessage(any[KernelMessageContent with CommContent])

        spyCommWriter
      }
    }
  }

  describe("KernelCommManager") {
    describe("#open") {
      it("should return a wrapped instance of KernelCommWriter") {
        kernelCommManager.open(TestTargetName, v5.MsgData.Empty)

        // Exposed hackishly for testing
        generatedCommWriter shouldBe a [KernelCommWriter]
      }
    }
  }
} 
Example 109
Source File: CodeCompleteHandlerSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.kernel.protocol.v5.handler

import akka.actor._
import akka.testkit.{TestProbe, ImplicitSender, TestKit}
import org.apache.toree.Main
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.kernel.protocol.v5.content.CompleteRequest
import org.apache.toree.kernel.protocol.v5.kernel.ActorLoader
import org.apache.toree.kernel.protocol.v5Test._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FunSpecLike, BeforeAndAfter, Matchers}
import org.mockito.Mockito._
import test.utils.MaxAkkaTestTimeout

class CodeCompleteHandlerSpec extends TestKit(
  ActorSystem("CodeCompleteHandlerSpec", None, Some(Main.getClass.getClassLoader))
) with ImplicitSender with FunSpecLike with Matchers with MockitoSugar
  with BeforeAndAfter {

  var actorLoader: ActorLoader = _
  var handlerActor: ActorRef = _
  var kernelMessageRelayProbe: TestProbe = _
  var interpreterProbe: TestProbe = _
  var statusDispatchProbe: TestProbe = _

  before {
    actorLoader = mock[ActorLoader]

    handlerActor = system.actorOf(Props(classOf[CodeCompleteHandler], actorLoader))

    kernelMessageRelayProbe = TestProbe()
    when(actorLoader.load(SystemActorType.KernelMessageRelay))
      .thenReturn(system.actorSelection(kernelMessageRelayProbe.ref.path.toString))

    interpreterProbe = new TestProbe(system)
    when(actorLoader.load(SystemActorType.Interpreter))
      .thenReturn(system.actorSelection(interpreterProbe.ref.path.toString))

    statusDispatchProbe = new TestProbe(system)
    when(actorLoader.load(SystemActorType.StatusDispatch))
      .thenReturn(system.actorSelection(statusDispatchProbe.ref.path.toString))
  }

  def replyToHandlerWithOkAndResult() = {
    val expectedClass = classOf[CompleteRequest]
    interpreterProbe.expectMsgClass(expectedClass)
    interpreterProbe.reply((0, List[String]()))
  }

  def replyToHandlerWithOkAndBadResult() = {
    val expectedClass = classOf[CompleteRequest]
    interpreterProbe.expectMsgClass(expectedClass)
    interpreterProbe.reply("hello")
  }

  describe("CodeCompleteHandler (ActorLoader)") {
    it("should send a CompleteRequest") {
      handlerActor ! MockCompleteRequestKernelMessage
      replyToHandlerWithOkAndResult()
      kernelMessageRelayProbe.fishForMessage(MaxAkkaTestTimeout) {
        case KernelMessage(_, _, header, _, _, _) =>
          header.msg_type == MessageType.Outgoing.CompleteReply.toString
      }
    }

    it("should throw an error for bad JSON") {
      handlerActor ! MockKernelMessageWithBadJSON
      var result = false
      try {
        replyToHandlerWithOkAndResult()
      }
      catch {
        case t: Throwable => result = true
      }
      result should be (true)
    }

    it("should throw an error for bad code completion") {
      handlerActor ! MockCompleteRequestKernelMessage
      try {
        replyToHandlerWithOkAndBadResult()
      }
      catch {
        case error: Exception => error.getMessage should be ("Parse error in CodeCompleteHandler")
      }
    }

    it("should send an idle message") {
      handlerActor ! MockCompleteRequestKernelMessage
      replyToHandlerWithOkAndResult()
      statusDispatchProbe.fishForMessage(MaxAkkaTestTimeout) {
        case Tuple2(status, _) =>
          status == KernelStatusType.Idle
      }
    }
  }
} 
Example 110
Source File: StatusDispatchSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.kernel.protocol.v5.dispatch

import akka.actor.{ActorRef, ActorSystem, Props}
import akka.testkit.{TestKit, TestProbe}
import org.apache.toree.kernel.protocol.v5._
import org.apache.toree.kernel.protocol.v5.content.KernelStatus
import org.apache.toree.kernel.protocol.v5.kernel.ActorLoader
import org.mockito.Mockito._
import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, FunSpecLike, Matchers}
import play.api.libs.json.Json
import test.utils.MaxAkkaTestTimeout

class StatusDispatchSpec extends TestKit(
  ActorSystem(
    "StatusDispatchSystem",
    None,
    Some(org.apache.toree.Main.getClass.getClassLoader)
  )
)
with FunSpecLike with Matchers with MockitoSugar with BeforeAndAfter {
  var statusDispatchRef: ActorRef = _
  var relayProbe: TestProbe = _
  before {
    //  Mock the relay with a probe
    relayProbe = TestProbe()
    //  Mock the ActorLoader
    val mockActorLoader: ActorLoader = mock[ActorLoader]
    when(mockActorLoader.load(SystemActorType.KernelMessageRelay))
      .thenReturn(system.actorSelection(relayProbe.ref.path.toString))

    statusDispatchRef = system.actorOf(Props(classOf[StatusDispatch],mockActorLoader))
  }


  describe("StatusDispatch") {
    describe("#receive( KernelStatusType )") {
      it("should send a status message to the relay") {
        statusDispatchRef ! KernelStatusType.Busy
        //  Check the kernel message is the correct type
        val statusMessage: KernelMessage = relayProbe.receiveOne(MaxAkkaTestTimeout).asInstanceOf[KernelMessage]
        statusMessage.header.msg_type should be (MessageType.Outgoing.Status.toString)
        //  Check the status is what we sent
        val status: KernelStatus = Json.parse(statusMessage.contentString).as[KernelStatus]
        status.execution_state should be (KernelStatusType.Busy.toString)
      }
    }

    describe("#receive( KernelStatusType, Header )") {
      it("should send a status message to the relay") {
        val tuple = Tuple2(KernelStatusType.Busy, mock[Header])
        statusDispatchRef ! tuple
        //  Check the kernel message is the correct type
        val statusMessage: KernelMessage = relayProbe.receiveOne(MaxAkkaTestTimeout).asInstanceOf[KernelMessage]
        statusMessage.header.msg_type should be (MessageType.Outgoing.Status.toString)
        //  Check the status is what we sent
        val status: KernelStatus = Json.parse(statusMessage.contentString).as[KernelStatus]
        status.execution_state should be (KernelStatusType.Busy.toString)
      }
    }
  }
} 
Example 111
Source File: StreamMethodsSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.kernel.api

import akka.actor.ActorSystem
import akka.testkit.{ImplicitSender, TestKit, TestProbe}
import org.apache.toree.kernel.protocol.v5
import org.apache.toree.kernel.protocol.v5.KernelMessage
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FunSpecLike, BeforeAndAfter, Matchers}
import play.api.libs.json.Json
import test.utils.MaxAkkaTestTimeout
import org.mockito.Mockito._

class StreamMethodsSpec extends TestKit(
  ActorSystem(
    "StreamMethodsSpec",
    None,
    Some(org.apache.toree.Main.getClass.getClassLoader)
  )
) with ImplicitSender with FunSpecLike with Matchers with MockitoSugar
  with BeforeAndAfter
{

  private var kernelMessageRelayProbe: TestProbe = _
  private var mockParentHeader: v5.ParentHeader = _
  private var mockActorLoader: v5.kernel.ActorLoader = _
  private var mockKernelMessage: v5.KernelMessage = _
  private var streamMethods: StreamMethods = _

  before {
    kernelMessageRelayProbe = TestProbe()

    mockParentHeader = mock[v5.ParentHeader]

    mockActorLoader = mock[v5.kernel.ActorLoader]
    doReturn(system.actorSelection(kernelMessageRelayProbe.ref.path))
      .when(mockActorLoader).load(v5.SystemActorType.KernelMessageRelay)

    mockKernelMessage = mock[v5.KernelMessage]
    doReturn(mockParentHeader).when(mockKernelMessage).header

    streamMethods = new StreamMethods(mockActorLoader, mockKernelMessage)
  }

  describe("StreamMethods") {
    describe("#()") {
      it("should put the header of the given message as the parent header") {
        val expected = mockKernelMessage.header
        val actual = streamMethods.kmBuilder.build.parentHeader

        actual should be (expected)
      }
    }

    describe("#sendAll") {
      it("should send a message containing all of the given text") {
        val expected = "some text"

        streamMethods.sendAll(expected)

        val outgoingMessage = kernelMessageRelayProbe.receiveOne(MaxAkkaTestTimeout)
        val kernelMessage = outgoingMessage.asInstanceOf[KernelMessage]

        val actual = Json.parse(kernelMessage.contentString)
          .as[v5.content.StreamContent].text

        actual should be (expected)
      }
    }
  }

} 
Example 112
Source File: DownloadSupportSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import java.io.FileNotFoundException
import java.net.URL

import org.scalatest.{BeforeAndAfter, Matchers, FunSpec}
import scala.io.Source
import scala.tools.nsc.io.File

class DownloadSupportSpec extends FunSpec with Matchers with BeforeAndAfter {
  val downloadDestinationUrl = new URL("file:///tmp/testfile2.ext")

  val testFileContent = "This is a test"
  val testFileName = "/tmp/testfile.txt"

  //  Create a test file for downloading
  before {
    File(testFileName).writeAll(testFileContent)
  }

  //  Cleanup what we made
  after {
    if (File(testFileName).exists) File(testFileName).delete()
    if (File(downloadDestinationUrl.getPath).exists) File(downloadDestinationUrl.getPath).delete()
  }

  describe("DownloadSupport"){
    describe("#downloadFile( String, String )"){
      it("should download a file to the download directory"){
        val testFileUrl = "file:///tmp/testfile.txt"

        //  Create our utility and download the file
        val downloader = new Object with DownloadSupport
        downloader.downloadFile(
          testFileUrl,
          downloadDestinationUrl.getProtocol + "://" +
            downloadDestinationUrl.getPath)

        //  Verify the file contents are what was in the original file
        val downloadedFileContent: String =
          Source.fromFile(downloadDestinationUrl.getPath).mkString

        downloadedFileContent should be (testFileContent)
      }

    }

    describe("#downloadFile( URL, URL )"){
      it("should download a file to the download directory"){
        val testFileUrl = new URL("file:///tmp/testfile.txt")

        val downloader = new Object with DownloadSupport
        downloader.downloadFile(testFileUrl, downloadDestinationUrl)

        //  Verify the file contents are what was in the original file
        val downloadedFileContent: String =
          Source.fromFile(downloadDestinationUrl.getPath).mkString

        downloadedFileContent should be (testFileContent)
      }

      it("should throw FileNotFoundException if the download URL is bad"){
        val badFilename = "file:///tmp/testbadfile.txt"
        if (File(badFilename).exists) File(badFilename).delete()

        val badFileUrl = new URL(badFilename)

        val downloader = new Object with DownloadSupport
        intercept[FileNotFoundException] {
          downloader.downloadFile(badFileUrl, downloadDestinationUrl)
        }
      }

      it("should throw FileNotFoundException if the download ") {
        val testFileUrl = new URL("file:///tmp/testfile.txt")
        val badDestinationUrl =
          new URL("file:///tmp/badloc/that/doesnt/exist.txt")

        val downloader = new Object with DownloadSupport
        intercept[FileNotFoundException] {
          downloader.downloadFile(testFileUrl, badDestinationUrl)
        }
      }
    }
  }

} 
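The same create-in-before, delete-in-after idea works for any file fixture; a minimal sketch using java.nio temp files (independent of Toree's DownloadSupport) looks like this.

import java.nio.file.{Files, Path}

import org.scalatest.{BeforeAndAfter, FunSpec, Matchers}

class TempFileFixtureSpec extends FunSpec with Matchers with BeforeAndAfter {
  var tempFile: Path = _

  before {
    // Recreate the fixture file before every test.
    tempFile = Files.createTempFile("fixture", ".txt")
    Files.write(tempFile, "This is a test".getBytes("UTF-8"))
  }

  after {
    // Clean up whatever the test left behind.
    Files.deleteIfExists(tempFile)
  }

  describe("a file-backed fixture") {
    it("sees the content written in before") {
      new String(Files.readAllBytes(tempFile), "UTF-8") should be ("This is a test")
    }
  }
}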
Example 113
Source File: MultiOutputStreamSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import java.io.OutputStream

import org.scalatest.mock.MockitoSugar
import org.scalatest.{BeforeAndAfter, Matchers, FunSpec}
import org.mockito.Matchers._
import org.mockito.Mockito._

class MultiOutputStreamSpec
  extends FunSpec with Matchers with MockitoSugar with BeforeAndAfter {

  describe("MultiOutputStream") {
    val listOfMockOutputStreams = List(mock[OutputStream], mock[OutputStream])
    val multiOutputStream = MultiOutputStream(listOfMockOutputStreams)

    describe("#close") {
      it("should call #close on all internal output streams") {
        multiOutputStream.close()

        listOfMockOutputStreams.foreach(mockOutputStream => verify(mockOutputStream).close())
      }
    }

    describe("#flush") {
      it("should call #flush on all internal output streams") {
        multiOutputStream.flush()

        listOfMockOutputStreams.foreach(mockOutputStream => verify(mockOutputStream).flush())
      }
    }

    describe("#write(int)") {
      it("should call #write(int) on all internal output streams") {
        multiOutputStream.write(anyInt())

        listOfMockOutputStreams.foreach(
          mockOutputStream => verify(mockOutputStream).write(anyInt()))
      }
    }
    describe("#write(byte[])") {
      it("should call #write(byte[]) on all internal output streams") {
        multiOutputStream.write(any[Array[Byte]])

        listOfMockOutputStreams.foreach(
          mockOutputStream => verify(mockOutputStream).write(any[Array[Byte]]))
      }
    }

    describe("#write(byte[], int, int)") {
      it("should call #write(byte[], int, int) on all internal output streams") {
        multiOutputStream.write(any[Array[Byte]], anyInt(), anyInt())

        listOfMockOutputStreams.foreach(
          mockOutputStream =>
            verify(mockOutputStream).write(any[Array[Byte]], anyInt(), anyInt()))
      }
    }
  }
} 
Example 114
Source File: ArgumentParsingSupportSpec.scala    From incubator-toree   with Apache License 2.0 5 votes vote down vote up
package org.apache.toree.utils

import org.scalatest.{BeforeAndAfter, Matchers, FunSpec}
import joptsimple.{OptionSet, OptionSpec, OptionParser}
import org.scalatest.mock.MockitoSugar

import org.mockito.Mockito._
import org.mockito.Matchers._

import collection.JavaConverters._

class ArgumentParsingSupportSpec extends FunSpec with Matchers
  with BeforeAndAfter with MockitoSugar
{
  private var mockOptions: OptionSet = _
  private var mockParser: OptionParser = _
  private var argumentParsingInstance: ArgumentParsingSupport = _

  before {
    mockOptions = mock[OptionSet]
    mockParser = mock[OptionParser]
    doReturn(mockOptions).when(mockParser).parse(anyVararg[String]())

    argumentParsingInstance = new Object() with ArgumentParsingSupport {
      override protected lazy val parser: OptionParser = mockParser
    }
  }

  describe("ArgumentParsingSupport") {
    describe("#parseArgs") {
      it("should invoke the underlying parser's parse method") {
        doReturn(Nil.asJava).when(mockOptions).nonOptionArguments()
        argumentParsingInstance.parseArgs("")

        verify(mockParser).parse(anyString())
      }

      it("should return an empty list if there are no non-option arguments") {
        val expected = Nil
        doReturn(expected.asJava).when(mockOptions).nonOptionArguments()
        val actual = argumentParsingInstance.parseArgs((
          "--transitive" :: expected
        ).mkString(" "))

        actual should be (expected)
      }

      it("should return a list containing non-option arguments") {
        val expected = "non-option" :: Nil
        doReturn(expected.asJava).when(mockOptions).nonOptionArguments()
        val actual = argumentParsingInstance.parseArgs((
          "--transitive" :: expected
          ).mkString(" "))

        actual should be (expected)
      }
    }
  }
} 
Example 115
Source File: ClientTest.scala    From bitcoin-s-spv-node   with MIT License 5 votes vote down vote up
package org.bitcoins.spvnode.networking

import java.net.{InetSocketAddress, ServerSocket}

import akka.actor.ActorSystem
import akka.io.{Inet, Tcp}
import akka.testkit.{ImplicitSender, TestActorRef, TestKit, TestProbe}
import org.bitcoins.core.config.TestNet3
import org.bitcoins.core.util.{BitcoinSLogger, BitcoinSUtil}
import org.bitcoins.spvnode.messages.control.VersionMessage
import org.bitcoins.spvnode.messages.{NetworkPayload, VersionMessage}
import org.bitcoins.spvnode.util.BitcoinSpvNodeUtil
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FlatSpecLike, MustMatchers}

import scala.concurrent.duration._
import scala.util.Try

class ClientTest extends TestKit(ActorSystem("ClientTest")) with FlatSpecLike
  with MustMatchers with ImplicitSender
  with BeforeAndAfter with BeforeAndAfterAll with BitcoinSLogger {

  "Client" must "connect to a node on the bitcoin network, " +
    "send a version message to a peer on the network and receive a version message back, then close that connection" in {
    val probe = TestProbe()

    val client = TestActorRef(Client.props,probe.ref)

    val remote = new InetSocketAddress(TestNet3.dnsSeeds(0), TestNet3.port)
    val randomPort = 23521
    //random port
    client ! Tcp.Connect(remote, Some(new InetSocketAddress(randomPort)))

    //val bound : Tcp.Bound = probe.expectMsgType[Tcp.Bound]
    val conn : Tcp.Connected = probe.expectMsgType[Tcp.Connected]

    //make sure the socket is currently bound
    Try(new ServerSocket(randomPort)).isSuccess must be (false)
    client ! Tcp.Abort
    val confirmedClosed = probe.expectMsg(Tcp.Aborted)

    //make sure the port is now available
    val boundSocket = Try(new ServerSocket(randomPort))
    boundSocket.isSuccess must be (true)

    boundSocket.get.close()

  }

  it must "bind connect to two nodes on one port" in {
    //NOTE if this test case fails it is more than likely because one of the two dns seeds
    //below is offline
    val remote1 = new InetSocketAddress(TestNet3.dnsSeeds(0), TestNet3.port)
    val remote2 = new InetSocketAddress(TestNet3.dnsSeeds(2), TestNet3.port)

    val probe1 = TestProbe()
    val probe2 = TestProbe()


    val client1 = TestActorRef(Client.props, probe1.ref)
    val client2 = TestActorRef(Client.props, probe2.ref)

    val local1 = new InetSocketAddress(TestNet3.port)
    val options = List(Inet.SO.ReuseAddress(true))
    client1 ! Tcp.Connect(remote1,Some(local1),options)


    probe1.expectMsgType[Tcp.Connected]
    client1 ! Tcp.Abort

    val local2 = new InetSocketAddress(TestNet3.port)
    client2 ! Tcp.Connect(remote2,Some(local2),options)
    probe2.expectMsgType[Tcp.Connected](5.seconds)
    client2 ! Tcp.Abort
  }

  override def afterAll: Unit = {
    TestKit.shutdownActorSystem(system)
  }


} 
Example 116
Source File: BlockActorTest.scala    From bitcoin-s-spv-node   with MIT License 5 votes vote down vote up
package org.bitcoins.spvnode.networking

import akka.actor.ActorSystem
import akka.testkit.{ImplicitSender, TestActorRef, TestKit, TestProbe}
import org.bitcoins.core.crypto.DoubleSha256Digest
import org.bitcoins.core.protocol.blockchain.BlockHeader
import org.bitcoins.core.util.{BitcoinSLogger, BitcoinSUtil}
import org.bitcoins.spvnode.messages.BlockMessage
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FlatSpecLike, MustMatchers}

import scala.concurrent.duration.DurationInt


class BlockActorTest extends TestKit(ActorSystem("BlockActorTest")) with FlatSpecLike
  with MustMatchers with ImplicitSender
  with BeforeAndAfter with BeforeAndAfterAll with BitcoinSLogger  {

  def blockActor = TestActorRef(BlockActor.props,self)
  val blockHash = DoubleSha256Digest(BitcoinSUtil.flipEndianness("00000000b873e79784647a6c82962c70d228557d24a747ea4d1b8bbe878e1206"))

  "BlockActor" must "be able to send a GetBlocksMessage then receive that block back" in {
    blockActor ! blockHash
    val blockMsg = expectMsgType[BlockMessage](10.seconds)
    blockMsg.block.blockHeader.hash must be (blockHash)

  }


  it must "be able to request a block from it's block header" in {
    val blockHeader = BlockHeader("0100000043497fd7f826957108f4a30fd9cec3aeba79972084e90ead01ea330900000000bac8b0fa927c0ac8234287e33c5f74d38d354820e24756ad709d7038fc5f31f020e7494dffff001d03e4b672")
    blockActor ! blockHeader
    val blockMsg = expectMsgType[BlockMessage](10.seconds)
    blockMsg.block.blockHeader.hash must be (blockHash)
  }


  override def afterAll = {
    TestKit.shutdownActorSystem(system)
  }
} 
Example 117
Source File: BlockHeaderStoreTest.scala    From bitcoin-s-spv-node   with MIT License 5 votes vote down vote up
package org.bitcoins.spvnode.store

import org.bitcoins.core.gen.BlockchainElementsGenerator
import org.scalatest.{BeforeAndAfter, FlatSpec, MustMatchers}


class BlockHeaderStoreTest extends FlatSpec with MustMatchers with BeforeAndAfter {
  val testFile = new java.io.File("src/test/resources/block_header.dat")
  "BlockHeaderStore" must "write and then read a block header from the database" in {
    val blockHeader = BlockchainElementsGenerator.blockHeader.sample.get
    BlockHeaderStore.append(Seq(blockHeader),testFile)
    val headersFromFile = BlockHeaderStore.read(testFile)

    headersFromFile must be (Seq(blockHeader))
  }


  it must "write one blockheader to the file, then append another header to the file, then read them both" in {
    val blockHeader1 = BlockchainElementsGenerator.blockHeader.sample.get
    val blockHeader2 = BlockchainElementsGenerator.blockHeader.sample.get
    BlockHeaderStore.append(Seq(blockHeader1),testFile)
    val headersFromFile1 = BlockHeaderStore.read(testFile)
    headersFromFile1 must be (Seq(blockHeader1))

    BlockHeaderStore.append(Seq(blockHeader2),testFile)
    val headersFromFile2 = BlockHeaderStore.read(testFile)
    headersFromFile2 must be (Seq(blockHeader1, blockHeader2))
  }


  after {
    testFile.delete()
  }

} 
Example 118
Source File: UpdateBloomFilterTest.scala    From bitcoin-s   with MIT License 5 votes vote down vote up
package org.bitcoins.node

import org.bitcoins.core.currency._
import org.bitcoins.server.BitcoinSAppConfig
import org.bitcoins.testkit.BitcoinSTestAppConfig
import org.bitcoins.testkit.node.{NodeUnitTest, SpvNodeFundedWalletBitcoind}
import org.scalatest.{BeforeAndAfter, FutureOutcome}

class UpdateBloomFilterTest extends NodeUnitTest with BeforeAndAfter {

  
  implicit override protected def config: BitcoinSAppConfig =
    BitcoinSTestAppConfig.getSpvWithEmbeddedDbTestConfig(pgUrl)

  override type FixtureParam = SpvNodeFundedWalletBitcoind

  def withFixture(test: OneArgAsyncTest): FutureOutcome = {
    withSpvNodeFundedWalletBitcoind(test, NodeCallbacks.empty, None)
  }

  it must "update the bloom filter with a TX" in { param =>
    val SpvNodeFundedWalletBitcoind(spv, wallet, rpc, _) = param

    for {
      _ <- wallet.getBloomFilter()
      tx <- wallet.sendToAddress(junkAddress, 5.bitcoin, None)
      updatedBloom <- spv.updateBloomFilter(tx).map(_.bloomFilter)
      _ = assert(updatedBloom.contains(tx.txId))
      _ <- rpc.broadcastTransaction(tx)

      // this should confirm our TX
      // since we updated the bloom filter
      hash <- rpc.generateToAddress(1, junkAddress).map(_.head)

      merkleBlock <- rpc.getTxOutProof(Vector(tx.txIdBE), hash)
      txs <- rpc.verifyTxOutProof(merkleBlock)

    } yield assert(txs.contains(tx.txIdBE))
  }

  it must "update the bloom filter with an address" in { param =>
    val SpvNodeFundedWalletBitcoind(spv, wallet, rpc, _) = param

    for {
      _ <- wallet.getBloomFilter()

      address <- wallet.getNewAddress()
      updatedBloom <- spv.updateBloomFilter(address).map(_.bloomFilter)
      hash <- rpc.sendToAddress(address, 1.bitcoin)
      tx <- rpc.getRawTransactionRaw(hash)
    } yield assert(updatedBloom.isRelevant(tx))
  }
} 
Example 119
Source File: StateStoreSpec.scala    From incubator-livy   with Apache License 2.0 5 votes vote down vote up
package org.apache.livy.server.recovery

import scala.reflect.classTag

import org.scalatest.{BeforeAndAfter, FunSpec}
import org.scalatest.Matchers._

import org.apache.livy.{LivyBaseUnitTestSuite, LivyConf}
import org.apache.livy.sessions.SessionManager

class StateStoreSpec extends FunSpec with BeforeAndAfter with LivyBaseUnitTestSuite {
  describe("StateStore") {
    after {
      StateStore.cleanup()
    }

    def createConf(stateStore: String): LivyConf = {
      val conf = new LivyConf()
      conf.set(LivyConf.RECOVERY_MODE.key, SessionManager.SESSION_RECOVERY_MODE_RECOVERY)
      conf.set(LivyConf.RECOVERY_STATE_STORE.key, stateStore)
      conf
    }

    it("should throw an error on get if it's not initialized") {
      intercept[AssertionError] { StateStore.get }
    }

    it("should initialize blackhole state store if recovery is disabled") {
      StateStore.init(new LivyConf())
      StateStore.get shouldBe a[BlackholeStateStore]
    }

    it("should pick the correct store according to state store config") {
      StateStore.pickStateStore(createConf("filesystem")) shouldBe classOf[FileSystemStateStore]
      StateStore.pickStateStore(createConf("zookeeper")) shouldBe classOf[ZooKeeperStateStore]
    }

    it("should return error if an unknown recovery mode is set") {
      val conf = new LivyConf()
      conf.set(LivyConf.RECOVERY_MODE.key, "unknown")
      intercept[IllegalArgumentException] { StateStore.init(conf) }
    }

    it("should return error if an unknown state store is set") {
      intercept[IllegalArgumentException] { StateStore.init(createConf("unknown")) }
    }
  }
} 
Example 120
Source File: SessionSpec.scala    From incubator-livy   with Apache License 2.0 5 votes vote down vote up
package org.apache.livy.repl

import java.util.Properties
import java.util.concurrent.{ConcurrentLinkedQueue, CountDownLatch, TimeUnit}

import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfter, FunSpec}
import org.scalatest.Matchers._
import org.scalatest.concurrent.Eventually
import org.scalatest.time._

import org.apache.livy.LivyBaseUnitTestSuite
import org.apache.livy.repl.Interpreter.ExecuteResponse
import org.apache.livy.rsc.RSCConf
import org.apache.livy.sessions._

class SessionSpec extends FunSpec with Eventually with LivyBaseUnitTestSuite with BeforeAndAfter {
  override implicit val patienceConfig =
    PatienceConfig(timeout = scaled(Span(30, Seconds)), interval = scaled(Span(100, Millis)))

  private val rscConf = new RSCConf(new Properties()).set(RSCConf.Entry.SESSION_KIND, "spark")

  describe("Session") {
    var session: Session = null

    after {
      if (session != null) {
        session.close()
        session = null
      }
    }

    it("should call state changed callbacks in happy path") {
      val expectedStateTransitions =
        Array("not_started", "starting", "idle", "busy", "idle", "busy", "idle")
      val actualStateTransitions = new ConcurrentLinkedQueue[String]()

      session = new Session(rscConf, new SparkConf(), None,
        { s => actualStateTransitions.add(s.toString) })
      session.start()
      session.execute("")

      eventually {
        actualStateTransitions.toArray shouldBe expectedStateTransitions
      }
    }

    it("should not transit to idle if there're any pending statements.") {
      val expectedStateTransitions =
        Array("not_started", "starting", "idle", "busy", "busy", "busy", "idle", "busy", "idle")
      val actualStateTransitions = new ConcurrentLinkedQueue[String]()

      val blockFirstExecuteCall = new CountDownLatch(1)
      val interpreter = new SparkInterpreter(new SparkConf()) {
        override def execute(code: String): ExecuteResponse = {
          blockFirstExecuteCall.await(10, TimeUnit.SECONDS)
          super.execute(code)
        }
      }
      session = new Session(rscConf, new SparkConf(), Some(interpreter),
        { s => actualStateTransitions.add(s.toString) })
      session.start()

      for (_ <- 1 to 2) {
        session.execute("")
      }

      blockFirstExecuteCall.countDown()
      eventually {
        actualStateTransitions.toArray shouldBe expectedStateTransitions
      }
    }

    it("should remove old statements when reaching threshold") {
      rscConf.set(RSCConf.Entry.RETAINED_STATEMENTS, 2)
      session = new Session(rscConf, new SparkConf())
      session.start()

      session.statements.size should be (0)
      session.execute("")
      session.statements.size should be (1)
      session.statements.map(_._1).toSet should be (Set(0))
      session.execute("")
      session.statements.size should be (2)
      session.statements.map(_._1).toSet should be (Set(0, 1))
      session.execute("")
      eventually {
        session.statements.size should be (2)
        session.statements.map(_._1).toSet should be (Set(1, 2))
      }

      // Continue submitting statements, total statements in memory should be 2.
      session.execute("")
      eventually {
        session.statements.size should be (2)
        session.statements.map(_._1).toSet should be (Set(2, 3))
      }
    }
  }
} 
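Setting the session back to null in after is the usual guard for resources that only some tests create; a minimal standalone sketch of that pattern (using a plain java.io.Closeable rather than a Livy Session) follows.

import org.scalatest.{BeforeAndAfter, FunSpec}

class GuardedCleanupSpec extends FunSpec with BeforeAndAfter {
  var resource: java.io.Closeable = null

  after {
    // Only tests that actually created the resource pay the cost of closing it,
    // and the var is reset so the next test starts clean.
    if (resource != null) {
      resource.close()
      resource = null
    }
  }

  describe("a lazily created resource") {
    it("is closed only when a test opened it") {
      resource = new java.io.ByteArrayInputStream(Array[Byte](1, 2, 3))
      assert(resource != null)
    }

    it("may also never be created") {
      assert(resource == null)
    }
  }
}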
Example 121
Source File: SparkJobConfTest.scala    From spark-bench   with Apache License 2.0 5 votes vote down vote up
package com.ibm.sparktc.sparkbench.sparklaunch

import java.io.File

import com.ibm.sparktc.sparkbench.sparklaunch.confparse.{ConfigWrangler, SparkJobConf}
import com.ibm.sparktc.sparkbench.utils.SparkBenchException
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class SparkJobConfTest extends FlatSpec with Matchers with BeforeAndAfter {

  // These helpers mutate the JVM's cached environment map via reflection so that
  // SPARK_HOME and SPARK_MASTER_HOST can be faked and restored around each test.
  private def setEnv(key: String, value: String) = {
    val field = System.getenv().getClass.getDeclaredField("m")
    field.setAccessible(true)
    val map = field.get(System.getenv()).asInstanceOf[java.util.Map[java.lang.String, java.lang.String]]
    map.put(key, value)
  }

  private def unsetEnv(key: String) = {
    val field = System.getenv().getClass.getDeclaredField("m")
    field.setAccessible(true)
    val map = field.get(System.getenv()).asInstanceOf[java.util.Map[java.lang.String, java.lang.String]]
    map.remove(key)
  }

  val sparkHome = sys.env.get("SPARK_HOME")
  val masterHost = "local[2]"

  before {
    SparkSession.clearDefaultSession()
    SparkSession.clearActiveSession()
    if(sparkHome.isEmpty) throw SparkBenchException("SPARK_HOME needs to be set")
  }

  after {
    if(sparkHome.nonEmpty) setEnv("SPARK_HOME", sparkHome.get)
    setEnv("SPARK_MASTER_HOST", masterHost)
  }

  "SparkLaunchConf" should "turn into arguments properly" in {

    val relativePath = "/etc/sparkConfTest.conf"
    val resource = new File(getClass.getResource(relativePath).toURI)
//    val source = scala.io.Source.fromFile(resource)
    val (sparkContextConfs, _) = SparkLaunch.mkConfs(resource)
    val conf1 = sparkContextConfs.head

    val expectedSparkConfs = Map(
      "spark.shuffle.service.enabled" -> "false",
      "spark.fake" -> "yes",
      "spark.dynamicAllocation.enabled" -> "false"
    )

    conf1.sparkConfs shouldBe expectedSparkConfs
    conf1.sparkArgs.contains("master") shouldBe true

  }

  it should "not blow up when spark context confs are left out" in {
    val relativePath = "/etc/noMasterConf.conf"
    val oldValue = unsetEnv("SPARK_MASTER_HOST")
    setEnv("SPARK_MASTER_HOST", "local[2]")
    val resource = new File(getClass.getResource(relativePath).toURI)
    val (sparkContextConfs, _) = SparkLaunch.mkConfs(resource)
    val conf2 = sparkContextConfs.head

    conf2.sparkConfs.isEmpty shouldBe true
    conf2.sparkArgs.contains("master") shouldBe true
    setEnv("SPARK_MASTER_HOST", masterHost)

  }

  it should "pick up spark-home as set in the config file" in {
    val oldSparkHome = unsetEnv("SPARK_HOME")
    val relativePath = "/etc/specific-spark-home.conf"
    val resource = new File(getClass.getResource(relativePath).toURI)
    val (sparkContextConfs, _) = SparkLaunch.mkConfs(resource)
    val conf2 = sparkContextConfs.head

    ConfigWrangler.isSparkSubmit(conf2.submissionParams) shouldBe true
    conf2.submissionParams("spark-home") shouldBe "/usr/iop/current/spark2-client/"

    if(sparkHome.nonEmpty) setEnv("SPARK_HOME", sparkHome.get)
  }

  it should "pick up the livy submission parameters" in {
    val oldSparkHome = unsetEnv("SPARK_HOME")
    val relativePath = "/etc/livy-example.conf"
    val resource = new File(getClass.getResource(relativePath).toURI)
    val (sparkContextConfs, _) = SparkLaunch.mkConfs(resource)
    val conf2: SparkJobConf = sparkContextConfs.head

    ConfigWrangler.isLivySubmit(conf2.submissionParams) shouldBe true
    conf2.sparkBenchJar shouldBe "hdfs:///opt/spark-bench.jar"

    if(sparkHome.nonEmpty) setEnv("SPARK_HOME", sparkHome.get)
  }

} 
Example 122
Source File: ProtocolSpec.scala    From spark-summit-2018   with GNU General Public License v3.0 5 votes vote down vote up
package com.twilio.open.streaming.trend.discovery

import com.googlecode.protobuf.format.JsonFormat
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}
import com.twilio.open.protocol.Calls.{CallEventType, CallState, CallStateEvent, Carrier, Country, Direction, PddEvent, SignalingEvent, SignalingEventType, CallEvent => CallEventProto, Dimensions => DimensionsProto}

class ProtocolSpec extends FlatSpec with Matchers with BeforeAndAfter {

  // example using protobuf ser/deser with Message and sub-Message
  "CallEvent with Signaling Event" should " serialize to json " in {

    val signalingEvent = SignalingEvent.newBuilder()
      .setEventType(SignalingEventType.call_state)
      .setName("progress")
      .setCallState(CallStateEvent.newBuilder()
        .setState(CallState.initialized).build())
      .build()

    val ce = CallEventProto.newBuilder()
      .setEventId("CE1234")
      .setEventTime(1527966284972L)
      .setLoggedEventTime(1527966304972L)
      .setEventType(CallEventType.signaling_event)
      .setSignalingEvent(signalingEvent)
      .setEventDimensions(DimensionsProto.newBuilder()
        .setCarrier(Carrier.telco_a)
        .setCountry(Country.us)
        .setDirection(Direction.outbound)
        .build()
      ).build()

    val jsonOutput = new JsonFormat().printToString(ce)
    jsonOutput shouldEqual "{\"event_time\": 1527966284972,\"logged_event_time\": 1527966304972,\"event_id\": \"CE1234\",\"event_type\": \"signaling_event\",\"signaling_event\": {\"name\": \"progress\",\"event_type\": \"call_state\",\"call_state\": {\"state\": \"initialized\"}},\"event_dimensions\": {\"country\": \"us\",\"direction\": \"outbound\",\"carrier\": \"telco_a\"}}"

  }

  "CallEvent with Signaling Event" should " for pdd type and serialize to json " in {

    val signalingEvent = SignalingEvent.newBuilder()
      .setEventType(SignalingEventType.pdd)
      .setName("preflight").setPdd(PddEvent.newBuilder().setPdd(2.8f).build())
      .build()

    val ce = CallEventProto.newBuilder()
      .setEventId("CE1234")
      .setEventTime(1527966284972L)
      .setLoggedEventTime(1527966304972L)
      .setEventType(CallEventType.signaling_event)
      .setSignalingEvent(signalingEvent)
      .setRouteId("RI123")
      .setEventDimensions(DimensionsProto.newBuilder()
        .setCarrier(Carrier.telco_a)
        .setCountry(Country.us)
        .setDirection(Direction.outbound)
        .build()
      ).build()

    val jsonOutput = new JsonFormat().printToString(ce)
    jsonOutput shouldEqual "{\"event_time\": 1527966284972,\"logged_event_time\": 1527966304972,\"event_id\": \"CE1234\",\"route_id\": \"RI123\",\"event_type\": \"signaling_event\",\"signaling_event\": {\"name\": \"preflight\",\"event_type\": \"pdd\",\"pdd\": {\"pdd\": 2.8}},\"event_dimensions\": {\"country\": \"us\",\"direction\": \"outbound\",\"carrier\": \"telco_a\"}}"

  }

} 
Example 123
Source File: FlumeStreamSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps

import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
  var ssc: StreamingContext = null

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
} 
Example 124
Source File: ResolveInlineTablesSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 125
Source File: RowDataSourceStrategySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 126
Source File: AggregateHashMapSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
    super.beforeAll()
  }

  // check after each test that the configs were not changed in the test body
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    super.beforeAll()
  }

  // check after each test that the configs were not changed in the test body
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
    super.beforeAll()
  }

  // check after each test that the configs were not changed in the test body
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
      "configuration parameter changed in test body")
  }
} 
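The combination used here, a beforeAll override that fixes configuration once plus an after check that no test body changed it, can be sketched in isolation with an ordinary mutable map standing in for SparkConf.

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}

class FrozenConfigSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll {
  // Stand-in for SparkConf in this sketch.
  val conf = mutable.Map.empty[String, String]

  override protected def beforeAll(): Unit = {
    conf("codegen.fallback") = "false"
    super.beforeAll()
  }

  // Fail fast if any test body tampered with the frozen setting.
  after {
    assert(conf("codegen.fallback") == "false",
      "configuration parameter changed in test body")
  }

  test("reads the frozen setting") {
    assert(conf("codegen.fallback") == "false")
  }
}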
Example 127
Source File: ExtensionServiceIntegrationSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging


class ExtensionServiceIntegrationSuite extends SparkFunSuite
  with LocalSparkContext with BeforeAndAfter
  with Logging {

  val applicationId = new StubApplicationId(0, 1111L)

  before {
    val sparkConf = new SparkConf()
    sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName()))
    sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite")
    sc = new SparkContext(sparkConf)
  }

  test("Instantiate") {
    val services = new SchedulerExtensionServices()
    assertResult(Nil, "non-nil service list") {
      services.getServices
    }
    services.start(SchedulerExtensionServiceBinding(sc, applicationId))
    services.stop()
  }

  test("Contains SimpleExtensionService Service") {
    val services = new SchedulerExtensionServices()
    try {
      services.start(SchedulerExtensionServiceBinding(sc, applicationId))
      val serviceList = services.getServices
      assert(serviceList.nonEmpty, "empty service list")
      val (service :: Nil) = serviceList
      val simpleService = service.asInstanceOf[SimpleExtensionService]
      assert(simpleService.started.get, "service not started")
      services.stop()
      assert(!simpleService.started.get, "service not stopped")
    } finally {
      services.stop()
    }
  }
} 
Example 128
Source File: FailureSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    if (directory != null) {
      Utils.deleteRecursively(directory)
    }
    StreamingContext.getActive().foreach { _.stop() }

    // Stop SparkContext if active
    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
  }

  test("multiple failures with map") {
    MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration)
  }
} 
Example 129
Source File: InputInfoTrackerSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Duration, StreamingContext, Time}

class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {

  private var ssc: StreamingContext = _

  before {
    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker")
    if (ssc == null) {
      ssc = new StreamingContext(conf, Duration(1000))
    }
  }

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
  }

  test("test report and get InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val streamId2 = 1
    val time = Time(0L)
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId2, 300L)
    inputInfoTracker.reportInfo(time, inputInfo1)
    inputInfoTracker.reportInfo(time, inputInfo2)

    val batchTimeToInputInfos = inputInfoTracker.getInfo(time)
    assert(batchTimeToInputInfos.size == 2)
    assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2))
    assert(batchTimeToInputInfos(streamId1) === inputInfo1)
    assert(batchTimeToInputInfos(streamId2) === inputInfo2)
    assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1)
  }

  test("test cleanup InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId1, 300L)
    inputInfoTracker.reportInfo(Time(0), inputInfo1)
    inputInfoTracker.reportInfo(Time(1), inputInfo2)

    inputInfoTracker.cleanup(Time(0))
    assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)

    inputInfoTracker.cleanup(Time(1))
    assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
  }
} 
Example 130
Source File: SparkListenerWithClusterSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.cluster.ExecutorInfo


class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
  with BeforeAndAfter with BeforeAndAfterAll {

  val WAIT_TIMEOUT_MILLIS = 10000

  before {
    sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
  }

  test("SparkListener sends executor added message") {
    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    // This test will check if the number of executors received by "SparkListener" is same as the
    // number of all executors, so we need to wait until all executors are up
    sc.jobProgressListener.waitUntilExecutorsUp(2, 60000)

    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(listener.addedExecutorInfo.size == 2)
    assert(listener.addedExecutorInfo("0").totalCores == 1)
    assert(listener.addedExecutorInfo("1").totalCores == 1)
  }

  private class SaveExecutorInfo extends SparkListener {
    val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()

    override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
      addedExecutorInfo(executor.executorId) = executor.executorInfo
    }
  }
} 
Example 131
Source File: BlockReplicationPolicySuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.{LocalSparkContext, SparkFunSuite}

class BlockReplicationPolicySuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  // Implicitly convert strings to BlockIds for test clarity.
  private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)

  
  test(s"block replication - random block replication policy") {
    val numBlockManagers = 10
    val storeSize = 1000
    val blockManagers = (1 to numBlockManagers).map { i =>
      BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
    }
    val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
    val replicationPolicy = new RandomBlockReplicationPolicy
    val blockId = "test-block"

    (1 to 10).foreach {numReplicas =>
      logDebug(s"Num replicas : $numReplicas")
      val randomPeers = replicationPolicy.prioritize(
        candidateBlockManager,
        blockManagers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
      assert(randomPeers.toSet.size === numReplicas)

      // choosing n peers out of n
      val secondPass = replicationPolicy.prioritize(
        candidateBlockManager,
        randomPeers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${secondPass.mkString(", ")}")
      assert(secondPass.toSet.size === numReplicas)
    }

  }

} 
Example 132
Source File: TopologyMapperSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileOutputStream}

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark._
import org.apache.spark.util.Utils

class TopologyMapperSuite  extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  test("File based Topology Mapper") {
    val numHosts = 100
    val numRacks = 4
    val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
    val propsFile = createPropertiesFile(props)

    val sparkConf = (new SparkConf(false))
    sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
    val topologyMapper = new FileBasedTopologyMapper(sparkConf)

    props.foreach {case (host, topology) =>
      val obtainedTopology = topologyMapper.getTopologyForHost(host)
      assert(obtainedTopology.isDefined)
      assert(obtainedTopology.get === topology)
    }

    // we get None for hosts not in the file
    assert(topologyMapper.getTopologyForHost("host").isEmpty)

    cleanup(propsFile)
  }

  def createPropertiesFile(props: Map[String, String]): File = {
    val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
    val fileOS = new FileOutputStream(testFile)
    props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
    fileOS.close
    testFile
  }

  def cleanup(testFile: File): Unit = {
    testFile.getParentFile.listFiles.filter { file =>
      file.getName.startsWith(testFile.getName)
    }.foreach { _.delete() }
  }

} 
Example 133
Source File: LocalDirsSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.{SparkConfWithEnv, Utils}


class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  before {
    Utils.clearLocalRootDirs()
  }

  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    assert(!new File("/NONEXISTENT_DIR").exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {
    // Regression test for SPARK-2975
    assert(!new File("/NONEXISTENT_DIR").exists())
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
      .set("spark.local.dir", "/NONEXISTENT_PATH")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

} 
Example 134
Source File: JdbcRDDSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {

      try {
        val create = conn.createStatement
        create.execute("""
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
        (1 to 100).foreach { i =>
          insert.setInt(1, i * 2)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

      try {
        val create = conn.createStatement
        create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)")
        (1 to 100).foreach { i =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
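For database-backed tests the same lifecycle applies; a minimal sketch against an anonymous in-memory H2 database (assuming the H2 driver is on the test classpath, unlike the Derby driver used above) follows.

import java.sql.{Connection, DriverManager}

import org.scalatest.{BeforeAndAfter, FunSuite}

class InMemoryDbSuite extends FunSuite with BeforeAndAfter {
  var conn: Connection = _

  before {
    // Anonymous in-memory database: it lives only as long as this connection.
    conn = DriverManager.getConnection("jdbc:h2:mem:")
    conn.prepareStatement("CREATE TABLE FOO(ID INT, DATA INT)").executeUpdate()
    conn.prepareStatement("INSERT INTO FOO VALUES(1, 2)").executeUpdate()
  }

  after {
    // Closing the connection discards the anonymous in-memory database.
    conn.close()
  }

  test("reads the seeded row") {
    val rs = conn.prepareStatement("SELECT DATA FROM FOO WHERE ID = 1").executeQuery()
    assert(rs.next() && rs.getInt(1) == 2)
  }
}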
Example 135
Source File: FutureActionSuite.scala    From sparkoscope   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val rdd = sc.parallelize(1 to 10, 2)
    val job = rdd.countAsync()
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (10)
    job.jobIds.size should be (1)
  }

  test("complex async action") {
    val rdd = sc.parallelize(1 to 15, 3)
    val job = rdd.takeAsync(10)
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (1 to 10)
    job.jobIds.size should be (2)
  }

} 
Example 136
Source File: MetricsConfigSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.metrics

import org.scalatest.{BeforeAndAfter, FunSuite}

class MetricsConfigSuite extends FunSuite with BeforeAndAfter {
  var filePath: String = _

  before {
    filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile()
  }

  test("MetricsConfig with default properties") {
    val conf = new MetricsConfig(None)
    conf.initialize()

    assert(conf.properties.size() === 4)
    assert(conf.properties.getProperty("test-for-dummy") === null)

    val property = conf.getInstance("random")
    assert(property.size() === 2)
    assert(property.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
    assert(property.getProperty("sink.servlet.path") === "/metrics/json")
  }

  test("MetricsConfig with properties set") {
    val conf = new MetricsConfig(Option(filePath))
    conf.initialize()

    val masterProp = conf.getInstance("master")
    assert(masterProp.size() === 5)
    assert(masterProp.getProperty("sink.console.period") === "20")
    assert(masterProp.getProperty("sink.console.unit") === "minutes")
    assert(masterProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource")
    assert(masterProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
    assert(masterProp.getProperty("sink.servlet.path") === "/metrics/master/json")

    val workerProp = conf.getInstance("worker")
    assert(workerProp.size() === 5)
    assert(workerProp.getProperty("sink.console.period") === "10")
    assert(workerProp.getProperty("sink.console.unit") === "seconds")
    assert(workerProp.getProperty("source.jvm.class") === "org.apache.spark.metrics.source.JvmSource")
    assert(workerProp.getProperty("sink.servlet.class") === "org.apache.spark.metrics.sink.MetricsServlet")
    assert(workerProp.getProperty("sink.servlet.path") === "/metrics/json")
  }

  test("MetricsConfig with subProperties") {
    val conf = new MetricsConfig(Option(filePath))
    conf.initialize()

    val propCategories = conf.propertyCategories
    assert(propCategories.size === 3)

    val masterProp = conf.getInstance("master")
    val sourceProps = conf.subProperties(masterProp, MetricsSystem.SOURCE_REGEX)
    assert(sourceProps.size === 1)
    assert(sourceProps("jvm").getProperty("class") === "org.apache.spark.metrics.source.JvmSource")

    val sinkProps = conf.subProperties(masterProp, MetricsSystem.SINK_REGEX)
    assert(sinkProps.size === 2)
    assert(sinkProps.contains("console"))
    assert(sinkProps.contains("servlet"))

    val consoleProps = sinkProps("console")
    assert(consoleProps.size() === 2)

    val servletProps = sinkProps("servlet")
    assert(servletProps.size() === 2)
  }
} 
Example 137
Source File: CacheManagerSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.mock.MockitoSugar

import org.apache.spark.executor.DataReadMethod
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._

// TODO: Test the CacheManager's thread-safety aspects
class CacheManagerSuite extends FunSuite with LocalSparkContext with BeforeAndAfter
  with MockitoSugar {

  var blockManager: BlockManager = _
  var cacheManager: CacheManager = _
  var split: Partition = _
  
  var rdd: RDD[Int] = _
  var rdd2: RDD[Int] = _
  var rdd3: RDD[Int] = _

  before {
    sc = new SparkContext("local", "test")
    blockManager = mock[BlockManager]
    cacheManager = new CacheManager(blockManager)
    split = new Partition { override def index: Int = 0 }
    rdd = new RDD[Int](sc, Nil) {
      override def getPartitions: Array[Partition] = Array(split)
      override val getDependencies = List[Dependency[_]]()
      override def compute(split: Partition, context: TaskContext) = Array(1, 2, 3, 4).iterator
    }
    rdd2 = new RDD[Int](sc, List(new OneToOneDependency(rdd))) {
      override def getPartitions: Array[Partition] = firstParent[Int].partitions
      override def compute(split: Partition, context: TaskContext) =
        firstParent[Int].iterator(split, context)
    }.cache()
    rdd3 = new RDD[Int](sc, List(new OneToOneDependency(rdd2))) {
      override def getPartitions: Array[Partition] = firstParent[Int].partitions
      override def compute(split: Partition, context: TaskContext) =
        firstParent[Int].iterator(split, context)
    }.cache()
  }

  test("get uncached rdd") {
    // Do not mock this test, because attempting to match Array[Any], which is not covariant,
    // in blockManager.put is a losing battle. You have been warned.
    blockManager = sc.env.blockManager
    cacheManager = sc.env.cacheManager
    val context = new TaskContextImpl(0, 0, 0, 0)
    val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
    val getValue = blockManager.get(RDDBlockId(rdd.id, split.index))
    assert(computeValue.toList === List(1, 2, 3, 4))
    assert(getValue.isDefined, "Block cached from getOrCompute is not found!")
    assert(getValue.get.data.toList === List(1, 2, 3, 4))
  }

  test("get cached rdd") {
    val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12)
    when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result))

    val context = new TaskContextImpl(0, 0, 0, 0)
    val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
    assert(value.toList === List(5, 6, 7))
  }

  test("get uncached local rdd") {
    // Local computation should not persist the resulting value, so don't expect a put().
    when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None)

    val context = new TaskContextImpl(0, 0, 0, 0, true)
    val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY)
    assert(value.toList === List(1, 2, 3, 4))
  }

  test("verify task metrics updated correctly") {
    cacheManager = sc.env.cacheManager
    val context = new TaskContextImpl(0, 0, 0, 0)
    cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY)
    assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2)
  }
} 
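The suites in this listing share the same basic shape: mutable fixture fields that are rebuilt in a before block and cleaned up in an after block, so every test starts from a known state. A minimal, self-contained sketch of that pattern (the suite and fixture names below are illustrative and not taken from any of the projects above):

import scala.collection.mutable.ListBuffer

import org.scalatest.{BeforeAndAfter, FunSuite}

class FixtureResetSuite extends FunSuite with BeforeAndAfter {

  // shared mutable fixture, rebuilt before every test
  var buffer: ListBuffer[String] = _

  before {
    buffer = ListBuffer("initial")
  }

  after {
    buffer.clear()
  }

  test("each test starts from a fresh buffer") {
    buffer += "one"
    assert(buffer === ListBuffer("initial", "one"))
  }

  test("mutations made by other tests are not visible") {
    assert(buffer === ListBuffer("initial"))
  }
}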
Example 138
Source File: TaskContextSuite.scala    From SparkCore   with Apache License 2.0
package org.apache.spark.scheduler

import org.mockito.Mockito._
import org.mockito.Matchers.any

import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener}


class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

  test("calls TaskCompletionListener after failure") {
    TaskContextSuite.completed = false
    sc = new SparkContext("local", "test")
    val rdd = new RDD[String](sc, List()) {
      override def getPartitions = Array[Partition](StubPartition(0))
      override def compute(split: Partition, context: TaskContext) = {
        context.addTaskCompletionListener(context => TaskContextSuite.completed = true)
        sys.error("failed")
      }
    }
    val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
    val func = (c: TaskContext, i: Iterator[String]) => i.next()
    val task = new ResultTask[String, String](
      0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0)
    intercept[RuntimeException] {
      task.run(0, 0)
    }
    assert(TaskContextSuite.completed === true)
  }

  test("all TaskCompletionListeners should be called even if some fail") {
    val context = new TaskContextImpl(0, 0, 0, 0)
    val listener = mock(classOf[TaskCompletionListener])
    context.addTaskCompletionListener(_ => throw new Exception("blah"))
    context.addTaskCompletionListener(listener)
    context.addTaskCompletionListener(_ => throw new Exception("blah"))

    intercept[TaskCompletionListenerException] {
      context.markTaskCompleted()
    }

    verify(listener, times(1)).onTaskCompletion(any())
  }

  test("TaskContext.attemptNumber should return attempt number, not task id (SPARK-4014)") {
    sc = new SparkContext("local[1,2]", "test")  // use maxRetries = 2 because we test failed tasks
    // Check that attemptIds are 0 for all tasks' initial attempts
    val attemptIds = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      Seq(TaskContext.get().attemptNumber).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0))

    // Test a job with failed tasks
    val attemptIdsWithFailedTask = sc.parallelize(Seq(1, 2), 2).mapPartitions { iter =>
      val attemptId = TaskContext.get().attemptNumber
      if (iter.next() == 1 && attemptId == 0) {
        throw new Exception("First execution of task failed")
      }
      Seq(attemptId).iterator
    }.collect()
    assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
  }

  test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") {
    sc = new SparkContext("local", "test")
    val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter =>
      Seq(TaskContext.get().attemptId).iterator
    }.collect()
    assert(attemptIds.toSet === Set(0, 1, 2, 3))
  }
}

private object TaskContextSuite {
  @volatile var completed = false
}

private case class StubPartition(index: Int) extends Partition 
Example 139
Source File: SparkListenerWithClusterSuite.scala    From SparkCore   with Apache License 2.0
package org.apache.spark.scheduler

import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.{SparkContext, LocalSparkContext}

import org.scalatest.{FunSuite, BeforeAndAfter, BeforeAndAfterAll}

import scala.collection.mutable


class SparkListenerWithClusterSuite extends FunSuite with LocalSparkContext
  with BeforeAndAfter with BeforeAndAfterAll {

  val WAIT_TIMEOUT_MILLIS = 10000

  before {
    sc = new SparkContext("local-cluster[2,1,512]", "SparkListenerSuite")
  }

  test("SparkListener sends executor added message") {
    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
    assert(listener.addedExecutorInfo.size == 2)
    assert(listener.addedExecutorInfo("0").totalCores == 1)
    assert(listener.addedExecutorInfo("1").totalCores == 1)
  }

  private class SaveExecutorInfo extends SparkListener {
    val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()

    override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
      addedExecutorInfo(executor.executorId) = executor.executorInfo
    }
  }
} 
Example 140
Source File: JdbcRDDSuite.scala    From SparkCore   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.{BeforeAndAfter, FunSuite}

import org.apache.spark.{LocalSparkContext, SparkContext}

class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Class.forName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {
      val create = conn.createStatement
      create.execute("""
        CREATE TABLE FOO(
          ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
          DATA INTEGER
        )""")
      create.close()
      val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
      (1 to 100).foreach { i =>
        insert.setInt(1, i * 2)
        insert.executeUpdate
      }
      insert.close()
    } catch {
      case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_+_) === 10100)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
Example 141
Source File: FutureActionSuite.scala    From SparkCore   with Apache License 2.0
package org.apache.spark

import scala.concurrent.Await
import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, FunSuite, Matchers}


class FutureActionSuite extends FunSuite with BeforeAndAfter with Matchers with LocalSparkContext {

  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val rdd = sc.parallelize(1 to 10, 2)
    val job = rdd.countAsync()
    val res = Await.result(job, Duration.Inf)
    res should be (10)
    job.jobIds.size should be (1)
  }

  test("complex async action") {
    val rdd = sc.parallelize(1 to 15, 3)
    val job = rdd.takeAsync(10)
    val res = Await.result(job, Duration.Inf)
    res should be (1 to 10)
    job.jobIds.size should be (2)
  }

} 
Example 142
Source File: TestSQLSuite.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.view.rewrite

import org.apache.carbondata.view.MVCatalogInSpark
import org.apache.carbondata.view.testutil.ModularPlanTest
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.optimizer.MVRewrite
import org.scalatest.BeforeAndAfter

class TestSQLSuite extends ModularPlanTest with BeforeAndAfter { 
  import org.apache.carbondata.view.rewrite.matching.TestSQLBatch._

  val spark = sqlContext
  val testHive = sqlContext.sparkSession

  ignore("prototypical mqo rewrite test") {
    
    hiveClient.runSqlHive(
        s"""
           |CREATE TABLE if not exists Fact (
           |  `tid`     int,
           |  `fpgid`   int,
           |  `flid`    int,
           |  `date`    timestamp,
           |  `faid`    int,
           |  `price`   double,
           |  `qty`     int,
           |  `disc`    string
           |)
           |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
           |STORED AS TEXTFILE       
        """.stripMargin.trim
        )
        
    hiveClient.runSqlHive(
        s"""
           |CREATE TABLE if not exists Dim (
           |  `lid`     int,
           |  `city`    string,
           |  `state`   string,
           |  `country` string
           |)
           |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
           |STORED AS TEXTFILE   
        """.stripMargin.trim
        )    
        
    hiveClient.runSqlHive(
        s"""
           |CREATE TABLE if not exists Item (
           |  `i_item_id`     int,
           |  `i_item_sk`     int
           |)
           |ROW FORMAT DELIMITED FIELDS TERMINATED BY ','
           |STORED AS TEXTFILE   
        """.stripMargin.trim
        )

    val dest = "case_11"
        
    sampleTestCases.foreach { testcase =>
      if (testcase._1 == dest) {
        val mvSession = new MVCatalogInSpark(testHive)
        val summary = testHive.sql(testcase._2)
        mvSession.registerSchema(summary)
        val rewrittenSQL =
          new MVRewrite(mvSession, mvSession.session.sql(
            testcase._3).queryExecution.optimizedPlan, mvSession.session).toCompactSQL.trim

        if (!rewrittenSQL.trim.equals(testcase._4)) {
          fail(
              s"""
              |=== FAIL: SQLs do not match ===
              |${sideBySide(rewrittenSQL, testcase._4).mkString("\n")}
              """.stripMargin)
        }
      }
    }
  }

} 
Example 143
Source File: Tpcds_1_4_Suite.scala    From carbondata   with Apache License 2.0
package org.apache.carbondata.view.rewrite

import org.apache.carbondata.view.MVCatalogInSpark
import org.apache.carbondata.view.testutil.ModularPlanTest
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.optimizer.MVRewrite
import org.scalatest.BeforeAndAfter
//import org.apache.spark.sql.catalyst.SQLBuilder
import java.io.{File, PrintWriter}

class Tpcds_1_4_Suite extends ModularPlanTest with BeforeAndAfter {
  import org.apache.carbondata.view.rewrite.matching.TestTPCDS_1_4_Batch._
  import org.apache.carbondata.view.testutil.Tpcds_1_4_Tables._

  val spark = sqlContext
  val testHive = sqlContext.sparkSession

  test("test using tpc-ds queries") {

    tpcds1_4Tables.foreach { create_table =>
      hiveClient.runSqlHive(create_table)
    }

    val writer = new PrintWriter(new File("batch.txt"))
//    val dest = "case_30"
//    val dest = "case_32"
//    val dest = "case_33"
// case_15 and case_16 need revisit

    val dest = "case_39"

    writer.close()
  }
} 
Example 144
Source File: ReservationViewEndpointSpec.scala    From ddd-leaven-akka-v2   with MIT License
package ecommerce.sales.app

import java.sql.Date

import akka.http.scaladsl.model.StatusCodes.NotFound
import akka.http.scaladsl.server._
import akka.http.scaladsl.testkit.ScalatestRouteTest
import com.typesafe.config.ConfigFactory
import ecommerce.sales.view.{ReservationDao, ReservationView, ViewTestSupport}
import ecommerce.sales.{ReservationStatus, SalesSerializationHintsProvider}
import org.joda.time.DateTime._
import org.json4s.Formats
import org.scalatest.{BeforeAndAfter, Matchers, WordSpecLike}
import pl.newicom.dddd.serialization.JsonSerHints._
import pl.newicom.dddd.utils.UUIDSupport.uuid7

class ReservationViewEndpointSpec extends WordSpecLike with Matchers with ScalatestRouteTest with ViewTestSupport with BeforeAndAfter {

  override lazy val config = ConfigFactory.load
  implicit val formats: Formats = new SalesSerializationHintsProvider().hints()

  lazy val dao = new ReservationDao
  val reservationId = uuid7

  before {
    viewStore.run {
      dao.createOrUpdate(ReservationView(reservationId, "client-1", ReservationStatus.Opened, new Date(now.getMillis)))
    }.futureValue
  }

  after {
    viewStore.run {
      dao.remove(reservationId)
    }.futureValue
  }

  "Reservation view endpoint" should {

    def response = responseAs[String]

    val route: Route = new ReservationViewEndpoint().route(viewStore)

    "respond to /reservation/all with all reservations" in {
      Get("/reservation/all") ~> route ~> check {
        response should include (reservationId)
      }
    }

    "respond to /reservation/{reservationId} with requested reservation" in {
      Get(s"/reservation/$reservationId") ~> route ~> check {
        response should include (reservationId)
      }
    }

    "respond to /reservation/{reservationId} with NotFound if reservation unknown" in {
      Get(s"/reservation/invalid") ~> route ~> check {
        status shouldBe NotFound
      }
    }

  }

  def ensureSchemaDropped = dao.ensureSchemaDropped
  def ensureSchemaCreated = dao.ensureSchemaCreated

} 
Example 145
Source File: ShipmentViewEndpointSpec.scala    From ddd-leaven-akka-v2   with MIT License
package ecommerce.shipping.app

import akka.http.scaladsl.model.StatusCodes.NotFound
import akka.http.scaladsl.server._
import akka.http.scaladsl.testkit.{RouteTestTimeout, ScalatestRouteTest}
import akka.testkit.TestDuration
import com.typesafe.config.ConfigFactory
import ecommerce.sales.view.ViewTestSupport
import ecommerce.shipping.view.{ShipmentDao, ShipmentView}
import ecommerce.shipping.{ShippingSerializationHintsProvider, ShippingStatus}
import org.json4s.Formats
import org.scalatest.{BeforeAndAfter, Matchers, WordSpecLike}
import pl.newicom.dddd.serialization.JsonSerHints._
import pl.newicom.dddd.utils.UUIDSupport.uuid7

import scala.concurrent.duration.DurationInt

class ShipmentViewEndpointSpec extends WordSpecLike with Matchers with ScalatestRouteTest
  with ViewTestSupport with BeforeAndAfter {

  override lazy val config = ConfigFactory.load
  implicit val formats: Formats = new ShippingSerializationHintsProvider().hints()

  implicit val routeTimeout = RouteTestTimeout(3.seconds dilated)

  lazy val dao = new ShipmentDao
  val shipmentId = uuid7

  before {
    viewStore.run {
      dao.createOrUpdate(ShipmentView(shipmentId, "order-1", ShippingStatus.Delivered))
    }.futureValue
  }

  after {
    viewStore.run {
      dao.remove(shipmentId)
    }.futureValue
  }

  "Shipment view endpoint" should {

    def response = responseAs[String]

    val route: Route = new ShipmentViewEndpoint().route(viewStore)

    "respond to /shipment/all with all shipments" in {
      Get("/shipment/all") ~> route ~> check {
        response should include (shipmentId)
      }
    }

    "respond to /shipment/{shipmentId} with requested shipment" in {
      Get(s"/shipment/$shipmentId") ~> route ~> check {
        response should include (shipmentId)
      }
    }

    "respond to /shipment/{shipmentId} with NotFound if shipment unknown" in {
      Get(s"/shipment/invalid") ~> route ~> check {
        status shouldBe NotFound
      }
    }

  }

  def ensureSchemaDropped = dao.ensureSchemaDropped
  def ensureSchemaCreated = dao.ensureSchemaCreated

} 
Example 146
Source File: AMQPServerStreamSuite.scala    From streaming-amqp   with Apache License 2.0
package io.radanalytics.streaming.amqp

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.amqp.AMQPUtils
import org.apache.spark.streaming.{Duration, Seconds, StreamingContext}
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually

import scala.concurrent.duration._


class AMQPServerStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {
  
  private val batchDuration: Duration = Seconds(1)
  private val master: String = "local[2]"
  private val appName: String = this.getClass().getSimpleName()
  private val address: String = "my_address"
  private val checkpointDir: String = "/tmp/spark-streaming-amqp-tests"
  
  private var conf: SparkConf = _
  private var ssc: StreamingContext = _
  private var amqpTestUtils: AMQPTestUtils = _

  before {
    
    conf = new SparkConf().setMaster(master).setAppName(appName)
    conf.set("spark.streaming.receiver.writeAheadLog.enable", "true")
    ssc = new StreamingContext(conf, batchDuration)
    ssc.checkpoint(checkpointDir)
    
    amqpTestUtils = new AMQPTestUtils()
    amqpTestUtils.setup()
  }
  
  after {

    if (ssc != null) {
      ssc.stop()
    }

    if (amqpTestUtils != null) {
      amqpTestUtils.teardown()
    }
  }

  test("AMQP receive server") {

    val sendMessage = "Spark Streaming & AMQP"
    val max = 10
    val delay = 100L

    amqpTestUtils.startAMQPServer(sendMessage, max, delay)

    val converter = new AMQPBodyFunction[String]

    val receiveStream =
      AMQPUtils.createStream(ssc, amqpTestUtils.host, amqpTestUtils.port,
        amqpTestUtils.username, amqpTestUtils.password, address, converter, StorageLevel.MEMORY_ONLY)

    var receivedMessage: List[String] = List()
    receiveStream.foreachRDD(rdd => {
      if (!rdd.isEmpty()) {
        receivedMessage = receivedMessage ::: rdd.collect().toList
      }
    })

    ssc.start()

    eventually(timeout(10000 milliseconds), interval(1000 milliseconds)) {

      assert(receivedMessage.length == max)
    }
    ssc.stop()

    amqpTestUtils.stopAMQPServer()
  }
} 
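Several of the suites in this listing combine BeforeAndAfter, which registers per-test hooks, with BeforeAndAfterAll, which wraps the whole suite. A minimal sketch of how the two traits interact, with println standing in for real resource management (the class name and messages are illustrative):

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}

class LifecycleSuite extends FunSuite with BeforeAndAfter with BeforeAndAfterAll {

  // runs once, before the first test in the suite
  override def beforeAll(): Unit = println("open shared resource")

  // runs once, after the last test in the suite
  override def afterAll(): Unit = println("close shared resource")

  // run around every individual test
  before { println("per-test setup") }
  after { println("per-test teardown") }

  test("first") { assert(1 + 1 === 2) }
  test("second") { assert("ab".length === 2) }
}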
Example 147
Source File: DaoServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License
package com.ivan.nikolov.scheduler.dao

import com.ivan.nikolov.scheduler.TestEnvironment
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class DaoServiceTest extends FlatSpec with Matchers with BeforeAndAfter with TestEnvironment {

  override val databaseService = new H2DatabaseService
  override val migrationService = new MigrationService
  override val daoService = new DaoServiceImpl
  
  before {
    // we run this here. Generally migrations will only
    // be dealing with data layout and we will be able to have
    // test classes that insert test data.
    migrationService.runMigrations()
  }
  
  after {
    migrationService.cleanupDatabase()
  }
  
  "readResultSet" should "properly iterate over a result set and apply a function to it." in {
    val connection = daoService.getConnection()
    try {
      val result = daoService.executeSelect(
        connection.prepareStatement(
          "SELECT name FROM people"
        )
      ) {
        case rs =>
          daoService.readResultSet(rs) {
            case row =>
              row.getString("name")
          }
      }
      result should have size(3)
      result should contain("Ivan")
      result should contain("Maria")
      result should contain("John")
    } finally {
      connection.close()
      
    }
  }
} 
Example 148
Source File: DaoServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License
package com.ivan.nikolov.scheduler.dao

import com.ivan.nikolov.scheduler.TestEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class DaoServiceTest extends FlatSpec with Matchers with BeforeAndAfter with TestEnvironment {

  override val databaseService = new H2DatabaseService
  override val migrationService = new MigrationService
  override val daoService = new DaoServiceImpl
  
  before {
    // we run this here. Generally migrations will only
    // be dealing with data layout and we will be able to have
    // test classes that insert test data.
    migrationService.runMigrations()
  }
  
  after {
    migrationService.cleanupDatabase()
  }
  
  "readResultSet" should "properly iterate over a result set and apply a function to it." in {
    val connection = daoService.getConnection()
    try {
      val result = daoService.executeSelect(
        connection.prepareStatement(
          "SELECT name FROM people"
        )
      ) {
        case rs =>
          daoService.readResultSet(rs) {
            case row =>
              row.getString("name")
          }
      }
      result should have size(3)
      result should contain("Ivan")
      result should contain("Maria")
      result should contain("John")
    } finally {
      connection.close()
      
    }
  }
} 
Example 149
Source File: InstanceStoppingSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.{ DockerComposePluginLocal, RunningInstanceInfo }
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.{ BeforeAndAfter, FunSuite, OneInstancePerTest }

class InstanceStoppingSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest with MockHelpers {
  test("Validate the proper stopping of a single instance when only one instance is running and no specific instances are passed in as arguments") {
    val instanceId = "instanceId"
    val composePath = "path"
    val serviceName = "service"
    val composeMock = spy(new DockerComposePluginLocal)
    val instance = RunningInstanceInfo(instanceId, serviceName, composePath, List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instance)))

    composeMock.stopRunningInstances(null, Seq.empty)

    //Validate that the instance was stopped and cleaned up
    verify(composeMock, times(1)).dockerComposeStopInstance(instanceId, composePath)
    verify(composeMock, times(1)).dockerComposeRemoveContainers(instanceId, composePath)
  }

  test("Validate the proper stopping of multiple instances when no specific instances are passed in as arguments") {
    val instanceId = "instanceId"
    val composePath = "path"
    val serviceName = "service"
    val composeMock = spy(new DockerComposePluginLocal)
    val instance = RunningInstanceInfo(instanceId, serviceName, composePath, List.empty)
    val instance2 = RunningInstanceInfo("instanceId2", serviceName, composePath, List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instance, instance2)))

    composeMock.stopRunningInstances(null, Seq.empty)

    //Validate that both instances were stopped and cleaned up
    verify(composeMock, times(2)).dockerComposeStopInstance(anyString, anyString)
    verify(composeMock, times(2)).dockerComposeRemoveContainers(anyString, anyString)
  }

  test("Validate the proper stopping of a single instance when multiple instances are running") {
    val instanceIdStop = "instanceIdStop"
    val instanceIdKeep = "instanceIdKeep"
    val serviceName = "service"
    val composePath = "path"
    val composeMock = spy(new DockerComposePluginLocal)
    val instanceStop = RunningInstanceInfo(instanceIdStop, serviceName, composePath, List.empty)
    val instanceKeep = RunningInstanceInfo(instanceIdKeep, serviceName, composePath, List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instanceStop, instanceKeep)))

    composeMock.stopRunningInstances(null, Seq(instanceIdStop))

    //Validate that only one instance was Stopped and Removed
    verify(composeMock, times(1)).setAttribute(any, any)(any[sbt.State])
    verify(composeMock, times(1)).dockerComposeStopInstance(anyString, anyString)
    verify(composeMock, times(1)).dockerComposeRemoveContainers(anyString, anyString)
  }

  test("Validate that only instances from the current SBT project are stopped when no arguments are supplied to DockerComposeStop") {
    val composeMock = spy(new DockerComposePluginLocal)
    val serviceName = "matchingservice"
    val instance1 = RunningInstanceInfo("instanceName1", serviceName, "path", List.empty)
    val instance2 = RunningInstanceInfo("instanceName2", serviceName, "path", List.empty)
    val instance3 = RunningInstanceInfo("instanceName3", "nonSbtProjectService", "path", List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instance1, instance2, instance3)))

    composeMock.stopRunningInstances(null, Seq.empty)

    //Validate that only the instances from the current SBT project were Stopped and Removed
    verify(composeMock, times(1)).setAttribute(any, any)(any[sbt.State])
    verify(composeMock, times(2)).dockerComposeStopInstance(anyString, anyString)
    verify(composeMock, times(2)).dockerComposeRemoveContainers(anyString, anyString)
  }

  test("Validate that instances from any SBT project can be stopped when explicitly passed to DockerComposeStop") {
    val composeMock = spy(new DockerComposePluginLocal)
    val serviceName = "matchingservice"
    val instance1 = RunningInstanceInfo("instanceName1", serviceName, "path", List.empty)
    val instance2 = RunningInstanceInfo("instanceName2", serviceName, "path", List.empty)
    val instance3 = RunningInstanceInfo("instanceName3", "nonSbtProjectService", "path", List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instance1, instance2, instance3)))

    composeMock.stopRunningInstances(null, Seq("instanceName3"))

    //Validate that only one instance was Stopped and Removed
    verify(composeMock, times(1)).setAttribute(any, any)(any[sbt.State])
    verify(composeMock, times(1)).dockerComposeStopInstance(anyString, anyString)
    verify(composeMock, times(1)).dockerComposeRemoveContainers(anyString, anyString)
  }
} 
Example 150
Source File: VersionSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.Version
import org.scalatest.{ BeforeAndAfter, FunSuite, OneInstancePerTest }

class VersionSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest with MockHelpers {
  test("Validate version information is parsed correctly") {
    assert(Version.parseVersion("1.0.0") == Version(1, 0, 0))
    assert(Version.parseVersion("11.1.1") == Version(11, 1, 1))
    assert(Version.parseVersion("1.0.0-SNAPSHOT") == Version(1, 0, 0))
    assert(Version.parseVersion("1.2.3") == Version(1, 2, 3))
    assert(Version.parseVersion("1.2.3-rc3") == Version(1, 2, 3))
    assert(Version.parseVersion("1.2.3rc3") == Version(1, 2, 3))
  }

  test("Validate invalid version information reports an exception") {
    intercept[RuntimeException] {
      Version.parseVersion("")
    }

    intercept[RuntimeException] {
      Version.parseVersion("1.0")
    }

    intercept[RuntimeException] {
      Version.parseVersion("-1.0")
    }

    intercept[RuntimeException] {
      Version.parseVersion("version")
    }
  }
} 
Example 151
Source File: ComposeInstancesSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import sbt._
import com.tapad.docker.{ DockerComposePluginLocal, RunningInstanceInfo, Version }
import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.{ BeforeAndAfter, FunSuite, OneInstancePerTest }

class ComposeInstancesSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest with MockHelpers {
  test("Validate that no instances are printed when none are running") {
    val composeMock = spy(new DockerComposePluginLocal)
    val serviceName = "matchingservice"

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, None)

    composeMock.printDockerComposeInstances(null, null)

    verify(composeMock, times(0)).printMappedPortInformation(any[State], any[RunningInstanceInfo], any[Version])
  }

  test("Validate that multiple instances across sbt projects are printed when they are running") {
    val composeMock = spy(new DockerComposePluginLocal)
    val serviceName = "matchingservice"
    val instance1 = RunningInstanceInfo("instanceName1", serviceName, "path", List.empty)
    val instance2 = RunningInstanceInfo("instanceName2", serviceName, "path", List.empty)
    val instance3 = RunningInstanceInfo("instanceName3", "nonSbtProjectService", "path", List.empty)

    mockDockerCommandCalls(composeMock)
    mockSystemSettings(composeMock, serviceName, Some(List(instance1, instance2, instance3)))

    composeMock.printDockerComposeInstances(null, null)

    verify(composeMock, times(3)).printMappedPortInformation(any[State], any[RunningInstanceInfo], any[Version])
  }
} 
Example 152
Source File: TagProcessingSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.DockerComposePlugin._
import org.scalatest.{ BeforeAndAfter, FunSuite, OneInstancePerTest }

class TagProcessingSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest {

  val imageNoTag = "testImage"
  val imageLatestTag = "testImage:latest"
  val imageWithTag = "testImage:tag"
  val imagePrivateRegistryNoTag = "registry/testImage"
  val imagePrivateRegistryWithLatest = "registry/testImage:latest"
  val imagePrivateRegistryWithTag = "registry/testImage:tag"
  val imagePrivateRegistryWithOrgNoTag = "registry/org/testImage"
  val imagePrivateRegistryWithOrgWithTag = "registry/org/testImage:tag"
  val imageCustomTag = "testImage<localbuild>"
  val imageTagAndCustomTag = "testImage:latest<localbuild>"

  // Boundary
  val badImageWithColon = "testImage:"
  val badImageWithMultipleColon = "testImage:fooImage:latest"
  val badImageWithOnlyColon = ":::::::"

  test("Validate various image tag formats are properly replaced") {
    val replacementTag = "replaceTag"
    assert(replaceDefinedVersionTag(imageNoTag, replacementTag) == imageNoTag)

    assert(replaceDefinedVersionTag(imageLatestTag, replacementTag) == imageLatestTag)

    assert(replaceDefinedVersionTag(imageWithTag, replacementTag) == s"testImage:$replacementTag")

    assert(replaceDefinedVersionTag(imagePrivateRegistryNoTag, replacementTag) == imagePrivateRegistryNoTag)

    assert(replaceDefinedVersionTag(imagePrivateRegistryWithLatest, replacementTag) == imagePrivateRegistryWithLatest)

    assert(replaceDefinedVersionTag(imagePrivateRegistryWithTag, replacementTag) == s"registry/testImage:$replacementTag")
  }

  test("Validate image tag retrieval from various formats") {
    assert(getTagFromImage(imageNoTag) == "latest")

    assert(getTagFromImage(imageLatestTag) == "latest")

    assert(getTagFromImage(imageWithTag) == "tag")

    assert(getTagFromImage(imagePrivateRegistryNoTag) == "latest")

    assert(getTagFromImage(imagePrivateRegistryWithLatest) == "latest")

    assert(getTagFromImage(imagePrivateRegistryWithTag) == "tag")
  }

  test("Validate custom tags get removed") {
    assert(processImageTag(null, null, imageCustomTag) == "testImage")
    assert(processImageTag(null, null, imageTagAndCustomTag) == "testImage:latest")
  }

  test("Validate the removal of a tag from various image formats") {
    assert(getImageNameOnly(imageNoTag) == imageNoTag)
    assert(getImageNameOnly(imageLatestTag) == "testImage")
    assert(getImageNameOnly(imagePrivateRegistryNoTag) == "testImage")
    assert(getImageNameOnly(imagePrivateRegistryWithLatest) == "testImage")
    assert(getImageNameOnly(imagePrivateRegistryWithTag) == "testImage")
    assert(getImageNameOnly(imagePrivateRegistryWithOrgWithTag) == "testImage")
    assert(getImageNameOnly(imagePrivateRegistryWithOrgWithTag, removeOrganization = false) == "org/testImage")
  }

  test("Validate getting image name with no tag") {
    assert(getImageNoTag("") == "")
    assert(getImageNoTag(imageNoTag) == imageNoTag)
    assert(getImageNoTag(imageLatestTag) == imageNoTag)
    assert(getImageNoTag(imagePrivateRegistryNoTag) == imagePrivateRegistryNoTag)
    assert(getImageNoTag(imagePrivateRegistryWithLatest) == imagePrivateRegistryNoTag)
    assert(getImageNoTag(imagePrivateRegistryWithTag) == imagePrivateRegistryNoTag)
    assert(getImageNoTag(imagePrivateRegistryWithOrgWithTag) == imagePrivateRegistryWithOrgNoTag)
    assert(getImageNoTag(badImageWithColon) == imageNoTag)
    assert(getImageNoTag(badImageWithMultipleColon) == badImageWithMultipleColon.split(":").dropRight(1).mkString(":"))
    assert(getImageNoTag(badImageWithOnlyColon) == badImageWithOnlyColon.dropRight(1))
  }
} 
Example 153
Source File: ImagePullingSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.DockerComposePlugin._
import com.tapad.docker.{ ServiceInfo, DockerComposePluginLocal }
import org.mockito.Mockito._
import org.scalatest.{ OneInstancePerTest, BeforeAndAfter, FunSuite }

class ImagePullingSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest {
  test("Validate that when the 'skipPull' argument is passed in no imaged are pull from the Docker registry") {
    val instanceMock = new DockerComposePluginLocal with MockOutput

    instanceMock.pullDockerImages(Seq(skipPullArg), null, suppressColor = false)
    assert(instanceMock.messages.exists(_.contains("Skipping Docker Repository Pull for all images.")))
  }

  test("Validate that images with a 'build' source not pulled from the Docker registry") {
    val instanceMock = new DockerComposePluginLocal with MockOutput
    val imageName = "buildImageName"
    val serviceInfo = ServiceInfo("serviceName", imageName, buildImageSource, null)

    instanceMock.pullDockerImages(null, List(serviceInfo), suppressColor = false)
    assert(instanceMock.messages.contains(s"Skipping Pull of image: $imageName"))
  }

  test("Validate that images with a 'defined' source are pulled from the Docker registry") {
    val instanceMock = spy(new DockerComposePluginLocal)
    val imageName = "buildImageName"
    val serviceInfo = ServiceInfo("serviceName", imageName, definedImageSource, null)

    doNothing().when(instanceMock).dockerPull(imageName)

    instanceMock.pullDockerImages(null, List(serviceInfo), suppressColor = false)

    verify(instanceMock, times(1)).dockerPull(imageName)
  }

  test("Validate that images with a 'cache' source are not pulled from the Docker registry") {
    val instanceMock = new DockerComposePluginLocal with MockOutput
    val imageName = "cacheImageName"
    val serviceInfo = ServiceInfo("serviceName", imageName, cachedImageSource, null)

    instanceMock.pullDockerImages(null, List(serviceInfo), suppressColor = false)
    assert(instanceMock.messages.contains(s"Skipping Pull of image: $imageName"))
  }
} 
Example 154
Source File: InstancePersistenceSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.{ RunningInstanceInfo, DockerComposePluginLocal }
import com.tapad.docker.DockerComposeKeys._
import org.mockito.Mockito._
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{ BeforeAndAfter, FunSuite, OneInstancePerTest }

class InstancePersistenceSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest with MockitoSugar {

  test("Validate that only running instances from this sbt session are returned") {
    val instanceMock = spy(new DockerComposePluginLocal)

    val runningInstanceMatch = RunningInstanceInfo("instanceNameMatch", "matchingservice", "composePath", List.empty)
    val runningInstanceNoMatch = RunningInstanceInfo("instanceNameNoMatch", "nomatchingservice", "composePath", List.empty)

    doReturn("matchingservice").when(instanceMock).getSetting(composeServiceName)(null)
    doReturn(Option(List(runningInstanceMatch, runningInstanceNoMatch))).when(instanceMock).getAttribute(runningInstances)(null)

    val instanceIds = instanceMock.getServiceRunningInstanceIds(null)

    assert(instanceIds.size == 1)
    assert(instanceIds.contains("instanceNameMatch"))
  }

  test("Validate that only matching instance ids are returned") {
    val instanceMock = spy(new DockerComposePluginLocal)

    val runningInstanceMatch = RunningInstanceInfo("instanceNameMatch", "matchingservice", "composePath", List.empty)
    val runningInstanceNoMatch = RunningInstanceInfo("instanceNameNoMatch", "nomatchingservice", "composePath", List.empty)

    doReturn("matchingservice").when(instanceMock).getSetting(composeServiceName)(null)
    doReturn(Option(List(runningInstanceMatch, runningInstanceNoMatch))).when(instanceMock).getAttribute(runningInstances)(null)

    val instance = instanceMock.getMatchingRunningInstance(null, Seq("instanceNameMatch"))

    assert(instance.isDefined)
    assert(instance.get.instanceName == "instanceNameMatch")
  }
} 
Example 155
Source File: ImageBuildingSpec.scala    From sbt-docker-compose   with BSD 3-Clause "New" or "Revised" License
import com.tapad.docker.DockerComposeKeys._
import com.tapad.docker.DockerComposePlugin._
import com.tapad.docker.DockerComposePluginLocal
import org.mockito.Mockito._
import org.scalatest.{ OneInstancePerTest, BeforeAndAfter, FunSuite }

class ImageBuildingSpec extends FunSuite with BeforeAndAfter with OneInstancePerTest {
  test("Validate that a Docker image is built when 'skipBuild' and 'noBuild' are not set") {
    val composeMock = spy(new DockerComposePluginLocal)

    doReturn(false).when(composeMock).getSetting(suppressColorFormatting)(null)
    doReturn(false).when(composeMock).getSetting(composeNoBuild)(null)
    doNothing().when(composeMock).buildDockerImageTask(null)

    composeMock.buildDockerImage(null, null)

    verify(composeMock, times(1)).buildDockerImageTask(null)
  }

  test("Validate that a Docker image is not built when 'skipBuild' is passed as an argument") {
    val composeMock = spy(new DockerComposePluginLocal)

    doReturn(false).when(composeMock).getSetting(suppressColorFormatting)(null)
    doReturn(false).when(composeMock).getSetting(composeNoBuild)(null)
    doNothing().when(composeMock).buildDockerImageTask(null)

    composeMock.buildDockerImage(null, Seq(skipBuildArg))

    verify(composeMock, times(0)).buildDockerImageTask(null)
  }

  test("Validate that a Docker image is not built when the 'noBuild' setting is true") {
    val composeMock = spy(new DockerComposePluginLocal)

    doReturn(false).when(composeMock).getSetting(suppressColorFormatting)(null)
    doReturn(true).when(composeMock).getSetting(composeNoBuild)(null)
    doNothing().when(composeMock).buildDockerImageTask(null)

    composeMock.buildDockerImage(null, null)

    verify(composeMock, times(0)).buildDockerImageTask(null)
  }
} 
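The sbt-docker-compose specs above also mix in OneInstancePerTest, which creates a fresh suite instance for every test, so mutable fields cannot leak state between tests even without an explicit reset in before. A small sketch of that effect (the suite name and counter are illustrative):

import org.scalatest.{BeforeAndAfter, FunSuite, OneInstancePerTest}

class FreshInstanceSuite extends FunSuite with BeforeAndAfter with OneInstancePerTest {

  // a new suite instance is created for every test, so this field starts at 0 each time
  var counter = 0

  before { counter += 1 }

  test("the first test sees counter == 1") { assert(counter === 1) }

  test("the second test also sees counter == 1, not 2") { assert(counter === 1) }
}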
Example 156
Source File: DockerTmpDB.scala    From akka-stream-extensions   with Apache License 2.0
package com.mfglabs.stream
package extensions.postgres

import java.sql.{DriverManager, Connection}
import org.postgresql.util.PSQLException
import org.scalatest.{Suite, BeforeAndAfter}
import scala.sys.process._
import scala.util.{Failure, Success, Try}
import com.typesafe.config.ConfigFactory

trait DockerTmpDB extends BeforeAndAfter { self: Suite =>

  import Debug._

  val version: PostgresVersion = PostgresVersion(ConfigFactory.load().getString("postgres.version"))

  Class.forName("org.postgresql.Driver")
  implicit var conn : Connection = _

  val dockerInstances = collection.mutable.Buffer.empty[String]

  def newPGDB(): Int = {
    val port: Int = 5432 + (math.random * (10000 - 5432)).toInt
    Try {
      s"docker pull postgres:${version.value}".pp.!!.trim
      val containerId =
        s"""docker run -p $port:5432 -e POSTGRES_PASSWORD=pwd -d postgres:${version.value}""".pp.!!.trim
      dockerInstances += containerId.pp("New docker instance with id")
      port
    } match {
      case Success(p) => p
      case Failure(err) =>
        throw  new IllegalStateException(s"Error while trying to run docker container", err)
    }
  }

  lazy val dockerIp: String =
    Try("docker-machine ip default".!!.trim).toOption
      .orElse {
        val conf = ConfigFactory.load()
        if (conf.hasPath("docker.ip")) Some(conf.getString("docker.ip")) else None
      }
      .getOrElse("127.0.0.1") // platform dependent

  //ugly solution to wait for the connection to be ready
  def waitsForConnection(port : Int) : Connection = {
    try {
      DriverManager.getConnection(s"jdbc:postgresql://$dockerIp:$port/postgres", "postgres", "pwd")
    } catch {
      case _: PSQLException =>
        println("Retrying DB connection...")
        Thread.sleep(1000)
        waitsForConnection(port)
    }
  }

  before {
    val port = newPGDB()
    println(s"New postgres ${version.value} instance at port $port")
    Thread.sleep(5000)
    conn = waitsForConnection(port)
  }

  after {
    conn.close()
    dockerInstances.toSeq.foreach { dockerId =>
      s"docker stop $dockerId".pp.!!
      s"docker rm $dockerId".pp.!!
    }
  }

}

object Debug {

  implicit class RichString(s:String){
    def pp :String = pp(None)
    def pp(p:String) :String = pp(Some(p))

    private def pp(p:Option[String]) = {
      println(p.map(_ + " ").getOrElse("") + s)
      s
    }
  }
} 
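Because DockerTmpDB above is written as a BeforeAndAfter trait with a Suite self-type, a spec picks up the whole container lifecycle simply by mixing it in; the before/after hooks then wrap every test and expose the conn connection. A hypothetical usage sketch (the spec class and query are illustrative):

import com.mfglabs.stream.extensions.postgres.DockerTmpDB
import org.scalatest.FlatSpec

class PostgresQuerySpec extends FlatSpec with DockerTmpDB {

  "A temporary Postgres container" should "answer a trivial query" in {
    val statement = conn.createStatement()
    val resultSet = statement.executeQuery("SELECT 1")
    resultSet.next()
    assert(resultSet.getInt(1) == 1)
    statement.close()
  }
}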
Example 157
Source File: WSProxyConfigurationSpec.scala    From http-verbs   with Apache License 2.0
package uk.gov.hmrc.play.http.ws

import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike
import play.api.libs.ws.DefaultWSProxyServer
import play.api.test.{FakeApplication, WithApplication}
import uk.gov.hmrc.play.http.ws.WSProxyConfiguration.ProxyConfigurationException

class WSProxyConfigurationSpec extends AnyWordSpecLike with Matchers with BeforeAndAfter {

  def proxyFlagConfiguredTo(value: Boolean): Map[String, Any] =
    Map("Dev.httpProxy.proxyRequiredForThisEnvironment" -> value)

  def proxyConfigWithFlagSetTo(flag: Option[Boolean] = None): Map[String, Any] =
    Map(
      "Dev.httpProxy.protocol" -> "https",
      "Dev.httpProxy.host"     -> "localhost",
      "Dev.httpProxy.port"     -> 7979,
      "Dev.httpProxy.username" -> "user",
      "Dev.httpProxy.password" -> "secret"
    ) ++ flag.fold(Map.empty[String, Any])(flag => proxyFlagConfiguredTo(flag))

  val proxy = DefaultWSProxyServer(
    protocol  = Some("https"),
    host      = "localhost",
    port      = 7979,
    principal = Some("user"),
    password  = Some("secret")
  )

  "If the proxyRequiredForThisEnvironment flag is not present, the WSProxyConfiguration apply method" should {

    "fail if no proxy is defined" in new WithApplication(FakeApplication()) {
      a[ProxyConfigurationException] should be thrownBy WSProxyConfiguration("Dev.httpProxy")
    }

    "return the proxy configuration if the proxy is defined" in new WithApplication(
      FakeApplication(additionalConfiguration = proxyConfigWithFlagSetTo(None))) {
      WSProxyConfiguration("Dev.httpProxy") shouldBe Some(proxy)
    }
  }

  "If the proxyRequiredForThisEnvironment flag is set to true, the WSProxyConfiguration apply method" should {

    "fail if no proxy is defined" in new WithApplication(
      FakeApplication(additionalConfiguration = proxyFlagConfiguredTo(value = true))) {
      a[ProxyConfigurationException] should be thrownBy WSProxyConfiguration("Dev.httpProxy")
    }

    "return the proxy configuration if the proxy is defined" in new WithApplication(
      FakeApplication(additionalConfiguration = proxyConfigWithFlagSetTo(Some(true)))) {
      WSProxyConfiguration("Dev.httpProxy") shouldBe Some(proxy)
    }
  }

  "If the proxyRequiredForThisEnvironment flag is set to false, the WSProxyConfiguration apply method" should {
    "return None if no proxy is defined" in new WithApplication(
      FakeApplication(additionalConfiguration = proxyFlagConfiguredTo(value = false))) {
      WSProxyConfiguration("Dev.httpProxy") shouldBe None
    }

    "return None if the proxy is defined" in new WithApplication(
      FakeApplication(additionalConfiguration = proxyConfigWithFlagSetTo(Some(false)))) {
      WSProxyConfiguration("Dev.httpProxy") shouldBe None
    }
  }
} 
Example 158
Source File: TemporalDataSuite.scala    From datasource-receiver   with Apache License 2.0
package org.apache.spark.streaming.datasource

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.datasource.config.ConfigParameters._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.BeforeAndAfter

private[datasource] trait TemporalDataSuite extends DatasourceSuite
  with BeforeAndAfter {

  val conf = new SparkConf()
    .setAppName("datasource-receiver-example")
    .setIfMissing("spark.master", "local[*]")
  var sc: SparkContext = null
  var ssc: StreamingContext = null
  val tableName = "tableName"
  val datasourceParams = Map(
    StopGracefully -> "true",
    StopSparkContext -> "false",
    StorageLevelKey -> "MEMORY_ONLY",
    RememberDuration -> "15s"
  )
  val schema = new StructType(Array(
    StructField("id", StringType, nullable = true),
    StructField("idInt", IntegerType, nullable = true)
  ))
  val totalRegisters = 10000
  val registers = for (a <- 1 to totalRegisters) yield Row(a.toString, a)

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }
} 
Example 159
Source File: UDFTest.scala    From SparkGIS   with Apache License 2.0
package org.betterers.spark.gis

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.{SQLContext, Row}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.betterers.spark.gis.udf.Functions


class UDFTest extends FunSuite with BeforeAndAfter {
  import Geometry.WGS84

  val point = Geometry.point((2.0, 2.0))
  val multiPoint = Geometry.multiPoint((1.0, 1.0), (2.0, 2.0), (3.0, 3.0))
  var line = Geometry.line((11.0, 11.0), (12.0, 12.0))
  var multiLine = Geometry.multiLine(
    Seq((11.0, 1.0), (23.0, 23.0)),
    Seq((31.0, 3.0), (42.0, 42.0)))
  var polygon = Geometry.polygon((1.0, 1.0), (2.0, 2.0), (3.0, 1.0))
  var multiPolygon = Geometry.multiPolygon(
    Seq((1.0, 1.0), (2.0, 2.0), (3.0, 1.0)),
    Seq((1.1, 1.1), (2.0, 1.9), (2.5, 1.1))
  )
  val collection = Geometry.collection(point, multiPoint, line)
  val all: Seq[Geometry] = Seq(point, multiPoint, line, multiLine, polygon, multiPolygon, collection)

  var sc: SparkContext = _
  var sql: SQLContext = _

  before {
    sc = new SparkContext(new SparkConf().setMaster("local[4]").setAppName("SparkGIS"))
    sql = new SQLContext(sc)
  }

  after {
    sc.stop()
  }

  test("ST_Boundary") {
    // all.foreach(g => println(Functions.ST_Boundary(g).toString))

    assertResult(true) {
      Functions.ST_Boundary(point).isEmpty
    }
    assertResult(true) {
      Functions.ST_Boundary(multiPoint).isEmpty
    }
    assertResult("Some(MULTIPOINT ((11 11), (12 12)))") {
      Functions.ST_Boundary(line).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiLine)
    }
    assertResult("Some(LINEARRING (1 1, 2 2, 3 1, 1 1))") {
      Functions.ST_Boundary(polygon).toString
    }
    assertResult(None) {
      Functions.ST_Boundary(multiPolygon)
    }
    assertResult(None) {
      Functions.ST_Boundary(collection)
    }
  }

  test("ST_CoordDim") {
    all.foreach(g => {
      assertResult(3) {
        Functions.ST_CoordDim(g)
      }
    })
  }

  test("UDF in SQL") {
    val schema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("geo", GeometryType.Instance)
    ))
    val jsons = Map(
      (1, "{\"type\":\"Point\",\"coordinates\":[1,1]}}"),
      (2, "{\"type\":\"LineString\",\"coordinates\":[[12,13],[15,20]]}}")
    )
    val rdd = sc.parallelize(Seq(
      "{\"id\":1,\"geo\":" + jsons(1) + "}",
      "{\"id\":2,\"geo\":" + jsons(2) + "}"
    ))
    rdd.name = "TEST"
    val df = sql.read.schema(schema).json(rdd)
    df.registerTempTable("TEST")
    Functions.register(sql)
    assertResult(Array(3,3)) {
      sql.sql("SELECT ST_CoordDim(geo) FROM TEST").collect().map(_.get(0))
    }
  }
} 
Example 160
Source File: LocalMessageStoreSuite.scala    From bahir   with Apache License 2.0
package org.apache.bahir.sql.streaming.mqtt

import java.io.File

import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite

import org.apache.bahir.utils.FileHelper


class LocalMessageStoreSuite extends SparkFunSuite with BeforeAndAfter {

  private val testData = Seq(1, 2, 3, 4, 5, 6)
  private val javaSerializer: JavaSerializer = new JavaSerializer()

  private val serializerInstance = javaSerializer
  private val tempDir: File = new File(System.getProperty("java.io.tmpdir") + "/mqtt-test2/")
  private val persistence: MqttDefaultFilePersistence =
    new MqttDefaultFilePersistence(tempDir.getAbsolutePath)

  private val store = new LocalMessageStore(persistence, javaSerializer)

  before {
    tempDir.mkdirs()
    tempDir.deleteOnExit()
    persistence.open("temp", "tcp://dummy-url:0000")
  }

  after {
    persistence.clear()
    persistence.close()
    FileHelper.deleteFileQuietly(tempDir)
  }

  test("serialize and deserialize") {
    val serialized = serializerInstance.serialize(testData)
    val deserialized: Seq[Int] = serializerInstance
      .deserialize(serialized).asInstanceOf[Seq[Int]]
    assert(testData === deserialized)
  }

  test("Store and retrieve") {
    store.store(1, testData)
    val result: Seq[Int] = store.retrieve(1)
    assert(testData === result)
  }

  test("Max offset stored") {
    store.store(1, testData)
    store.store(10, testData)
    val offset = store.maxProcessedOffset
    assert(offset == 10)
  }

} 
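The temp-directory handling above is another recurring fixture shape: create a scratch location in before and delete it again in after so runs never interfere with each other. A minimal sketch using only java.nio.file (names are illustrative):

import java.nio.file.{Files, Path}

import org.scalatest.{BeforeAndAfter, FunSuite}

class TempDirSuite extends FunSuite with BeforeAndAfter {

  var tempDir: Path = _
  var dataFile: Path = _

  before {
    tempDir = Files.createTempDirectory("scalatest-fixture-")
    dataFile = Files.createFile(tempDir.resolve("data.txt"))
  }

  after {
    // delete the file first, then the (now empty) directory
    Files.deleteIfExists(dataFile)
    Files.deleteIfExists(tempDir)
  }

  test("the fixture file exists while the test runs") {
    assert(Files.exists(dataFile))
  }
}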
Example 161
Source File: AkkaStreamSuite.scala    From bahir   with Apache License 2.0
package org.apache.spark.streaming.akka

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor._
import com.typesafe.config.ConfigFactory
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

class AkkaStreamSuite extends SparkFunSuite with Eventually with BeforeAndAfter {

  private var ssc: StreamingContext = _

  private var actorSystem: ActorSystem = _

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
    if (actorSystem != null) {
      Await.ready(actorSystem.terminate(), 30.seconds)
      actorSystem = null
    }
  }

  test("actor input stream") {
    val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
    ssc = new StreamingContext(sparkConf, Milliseconds(500))

    // we set the TCP port to "0" to have the port chosen automatically for the Feeder actor and
    // the Receiver actor will "pick it up" from the Feeder URI when it subscribes to the Feeder
    // actor (http://doc.akka.io/docs/akka/2.3.11/scala/remoting.html)
    val akkaConf = ConfigFactory.parseMap(
      Map(
        "akka.actor.provider" -> "akka.remote.RemoteActorRefProvider",
        "akka.remote.netty.tcp.transport-class" -> "akka.remote.transport.netty.NettyTransport",
        "akka.remote.netty.tcp.port" -> "0").
        asJava)
    actorSystem = ActorSystem("test", akkaConf)
    actorSystem.actorOf(Props(classOf[FeederActor]), "FeederActor")
    val feederUri =
      actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + "/user/FeederActor"

    val actorStream =
      AkkaUtils.createStream[String](ssc, Props(classOf[TestActorReceiver], feederUri),
        "TestActorReceiver")
    val result = new ConcurrentLinkedQueue[String]
    actorStream.foreachRDD { rdd =>
      rdd.collect().foreach(result.add)
    }
    ssc.start()

    eventually(timeout(10.seconds), interval(10.milliseconds)) {
      assert((1 to 10).map(_.toString) === result.asScala.toList)
    }
  }
}

case class SubscribeReceiver(receiverActor: ActorRef)

class FeederActor extends Actor {

  def receive: Receive = {
    case SubscribeReceiver(receiverActor: ActorRef) =>
      (1 to 10).foreach(i => receiverActor ! i.toString())
  }
}

class TestActorReceiver(uriOfPublisher: String) extends ActorReceiver {

  lazy private val remotePublisher = context.actorSelection(uriOfPublisher)

  override def preStart(): Unit = {
    remotePublisher ! SubscribeReceiver(self)
  }

  def receive: PartialFunction[Any, Unit] = {
    case msg: String => store(msg)
  }

} 
Example 162
Source File: TwitterStreamSuite.scala    From bahir   with Apache License 2.0
package org.apache.spark.streaming.twitter

import java.util.UUID

import scala.collection.mutable

import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.Eventually
import org.scalatest.time
import org.scalatest.time.Span
import twitter4j.{FilterQuery, Status, TwitterFactory}
import twitter4j.auth.{Authorization, NullAuthorization}

import org.apache.spark.ConditionalSparkFunSuite
import org.apache.spark.internal.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class TwitterStreamSuite extends ConditionalSparkFunSuite
    with Eventually with BeforeAndAfter with Logging {
  def shouldRunTest(): Boolean = sys.env.get("ENABLE_TWITTER_TESTS").contains("1")

  var ssc: StreamingContext = _

  before {
    ssc = new StreamingContext("local[2]", this.getClass.getSimpleName, Seconds(1))
  }

  after {
    if (ssc != null) {
      ssc.stop()
    }
  }

  test("twitter input stream") {
    val filters = Seq("filter1", "filter2")
    val query = new FilterQuery().language("fr,es")
    val authorization: Authorization = NullAuthorization.getInstance()

    // tests the API, does not actually test data receiving
    val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None)
    val test2: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, None, filters)
    val test3: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2)
    val test4: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, Some(authorization))
    val test5: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, Some(authorization), filters)
    val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream(
      ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2)
    val test7: ReceiverInputDStream[Status] = TwitterUtils.createFilteredStream(
      ssc, Some(authorization), Some(query), StorageLevel.MEMORY_AND_DISK_SER_2)
  }

  testIf("messages received", () => TwitterStreamSuite.this.shouldRunTest()) {
    val userId = TwitterFactory.getSingleton.updateStatus(
      UUID.randomUUID().toString
    ).getUser.getId

    val receiveStream = TwitterUtils.createFilteredStream(
      ssc, None, Some(new FilterQuery().follow(userId))
    )
    @volatile var receivedMessages: mutable.Set[Status] = mutable.Set()
    receiveStream.foreachRDD { rdd =>
      for (element <- rdd.collect()) {
        receivedMessages += element
      }
      receivedMessages
    }
    ssc.start()

    val nbOfMsg = 2
    var publishedMessages: List[String] = List()

    (1 to nbOfMsg).foreach(
      _ => {
        publishedMessages = UUID.randomUUID().toString :: publishedMessages
      }
    )

    eventually(timeout(Span(15, time.Seconds)), interval(Span(1000, time.Millis))) {
      publishedMessages.foreach(
        m => if (!receivedMessages.map(m => m.getText).contains(m.toString)) {
          TwitterFactory.getSingleton.updateStatus(m)
        }
      )
      assert(
        publishedMessages.map(m => m.toString).toSet
          .subsetOf(receivedMessages.map(m => m.getText))
      )
    }
  }
} 
Example 163
Source File: DiskStorageOperationsSpec.scala    From nexus   with Apache License 2.0 5 votes vote down vote up
package ch.epfl.bluebrain.nexus.kg.storage

import java.nio.file.Paths

import akka.http.scaladsl.model.{ContentTypes, Uri}
import cats.effect.IO
import ch.epfl.bluebrain.nexus.commons.test._
import ch.epfl.bluebrain.nexus.commons.test.io.IOEitherValues
import ch.epfl.bluebrain.nexus.kg.config.KgConfig._
import ch.epfl.bluebrain.nexus.kg.resources.file.File.FileDescription
import ch.epfl.bluebrain.nexus.kg.resources.Id
import ch.epfl.bluebrain.nexus.kg.resources.ProjectIdentifier.ProjectRef
import ch.epfl.bluebrain.nexus.kg.{KgError, TestHelper}
import ch.epfl.bluebrain.nexus.service.config.Settings
import ch.epfl.bluebrain.nexus.sourcing.RetryStrategyConfig
import org.mockito.IdiomaticMockito
import org.scalatest.{BeforeAndAfter, OptionValues}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

import scala.concurrent.duration._

class DiskStorageOperationsSpec
    extends ActorSystemFixture("DiskStorageOperationsSpec")
    with AnyWordSpecLike
    with Matchers
    with BeforeAndAfter
    with IdiomaticMockito
    with IOEitherValues
    with Resources
    with TestHelper
    with OptionValues {

  implicit private val appConfig = Settings(system).serviceConfig

  implicit private val sc: StorageConfig = appConfig.kg.storage.copy(
    DiskStorageConfig(Paths.get("/tmp"), "SHA-256", read, write, false, 1024L),
    RemoteDiskStorageConfig("http://example.com", "v1", None, "SHA-256", read, write, true, 1024L),
    S3StorageConfig("MD5", read, write, true, 1024L),
    "password",
    "salt",
    RetryStrategyConfig("linear", 300.millis, 5.minutes, 100, 1.second)
  )

  private val project  = ProjectRef(genUUID)
  private val storage  = Storage.DiskStorage.default(project)
  private val resId    = Id(storage.ref, genIri)
  private val fileDesc = FileDescription("my file.txt", ContentTypes.`text/plain(UTF-8)`)

  "DiskStorageOperations" should {

    "verify when the storage exists" in {
      val verify = new DiskStorageOperations.VerifyDiskStorage[IO](storage)
      verify.apply.accepted
    }

    "save and fetch files" in {
      val save   = new DiskStorageOperations.SaveDiskFile[IO](storage)
      val fetch  = new DiskStorageOperations.FetchDiskFile[IO]()
      val source = genSource

      val attr    = save.apply(resId, fileDesc, source).ioValue
      attr.bytes shouldEqual 16L
      attr.filename shouldEqual fileDesc.filename
      attr.mediaType shouldEqual fileDesc.mediaType.value
      attr.location shouldEqual Uri(s"file:///tmp/${mangle(project, attr.uuid, "my%20file.txt")}")
      attr.path shouldEqual attr.location.path.tail.tail.tail
      val fetched = fetch.apply(attr).ioValue

      consume(source) shouldEqual consume(fetched)
    }

    "not link files" in {
      val link = new DiskStorageOperations.LinkDiskFile[IO]()
      link.apply(resId, fileDesc, Uri.Path("/foo")).failed[KgError] shouldEqual KgError.UnsupportedOperation
    }
  }

} 
Example 164
Source File: RemoteDiskStorageOperationsSpec.scala    From nexus   with Apache License 2.0 5 votes vote down vote up
package ch.epfl.bluebrain.nexus.kg.storage

import akka.http.scaladsl.model.ContentTypes._
import akka.http.scaladsl.model.Uri
import cats.effect.IO
import ch.epfl.bluebrain.nexus.commons.test.io.IOEitherValues
import ch.epfl.bluebrain.nexus.commons.test.{ActorSystemFixture, Resources}
import ch.epfl.bluebrain.nexus.iam.auth.AccessToken
import ch.epfl.bluebrain.nexus.iam.client.types.AuthToken
import ch.epfl.bluebrain.nexus.iam.types.Permission
import ch.epfl.bluebrain.nexus.kg.TestHelper
import ch.epfl.bluebrain.nexus.kg.resources.file.File.{Digest, FileAttributes, FileDescription}
import ch.epfl.bluebrain.nexus.kg.resources.Id
import ch.epfl.bluebrain.nexus.kg.resources.ProjectIdentifier.ProjectRef
import ch.epfl.bluebrain.nexus.kg.storage.Storage.RemoteDiskStorage
import ch.epfl.bluebrain.nexus.storage.client.StorageClient
import ch.epfl.bluebrain.nexus.storage.client.types.FileAttributes.{Digest => StorageDigest}
import ch.epfl.bluebrain.nexus.storage.client.types.{FileAttributes => StorageFileAttributes}
import org.mockito.{IdiomaticMockito, Mockito}
import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

class RemoteDiskStorageOperationsSpec
    extends ActorSystemFixture("RemoteDiskStorageOperationsSpec")
    with AnyWordSpecLike
    with Matchers
    with BeforeAndAfter
    with IdiomaticMockito
    with IOEitherValues
    with Resources
    with TestHelper {

  private val endpoint = "http://nexus.example.com/v1"

  // TODO: Remove when migrating ADMIN client
  implicit private def oldTokenConversion(implicit token: Option[AccessToken]): Option[AuthToken] =
    token.map(t => AuthToken(t.value))

  sealed trait Ctx {
    val cred                                = genString()
    implicit val token: Option[AccessToken] = Some(AccessToken(cred))
    val path                                = Uri.Path(s"${genString()}/${genString()}")
    // format: off
    val storage = RemoteDiskStorage(ProjectRef(genUUID), genIri, 1L, false, false, "SHA-256", endpoint, Some(cred), genString(), Permission.unsafe(genString()), Permission.unsafe(genString()), 1024L)
    val attributes = FileAttributes(s"$endpoint/${storage.folder}/$path", path, s"${genString()}.json", `application/json`, 12L, Digest("SHA-256", genString()))
    // format: on
  }

  private val client = mock[StorageClient[IO]]

  before {
    Mockito.reset(client)
  }

  "RemoteDiskStorageOperations" should {

    "verify when storage exists" in new Ctx {
      client.exists(storage.folder) shouldReturn IO(true)
      val verify = new RemoteDiskStorageOperations.Verify[IO](storage, client)
      verify.apply.accepted
    }

    "verify when storage does not exists" in new Ctx {
      client.exists(storage.folder) shouldReturn IO(false)
      val verify = new RemoteDiskStorageOperations.Verify[IO](storage, client)
      verify.apply
        .rejected[
          String
        ] shouldEqual s"Folder '${storage.folder}' does not exists on the endpoint '${storage.endpoint}'"
    }

    "fetch file" in new Ctx {
      val source       = genSource
      client.getFile(storage.folder, path) shouldReturn IO(source)
      val fetch        = new RemoteDiskStorageOperations.Fetch[IO](storage, client)
      val resultSource = fetch.apply(attributes).ioValue
      consume(resultSource) shouldEqual consume(source)
    }

    "link file" in new Ctx {
      val id               = Id(storage.ref, genIri)
      val sourcePath       = Uri.Path(s"${genString()}/${genString()}")
      val destRelativePath = Uri.Path(mangle(storage.ref, attributes.uuid, attributes.filename))
      client.moveFile(storage.folder, sourcePath, destRelativePath) shouldReturn
        IO(
          StorageFileAttributes(
            attributes.location,
            attributes.bytes,
            StorageDigest(attributes.digest.algorithm, attributes.digest.value),
            attributes.mediaType
          )
        )
      val link             = new RemoteDiskStorageOperations.Link[IO](storage, client)
      link
        .apply(id, FileDescription(attributes.uuid, attributes.filename, Some(attributes.mediaType)), sourcePath)
        .ioValue shouldEqual attributes.copy(path = destRelativePath)
    }
  }
} 
Example 165
Source File: IdentitiesRoutesSpec.scala    From nexus   with Apache License 2.0 5 votes vote down vote up
package ch.epfl.bluebrain.nexus.iam.routes

import akka.http.scaladsl.model.StatusCodes
import akka.http.scaladsl.model.headers.OAuth2BearerToken
import akka.http.scaladsl.testkit.ScalatestRouteTest
import ch.epfl.bluebrain.nexus.iam.acls.Acls
import ch.epfl.bluebrain.nexus.iam.auth.{AccessToken, TokenRejection}
import ch.epfl.bluebrain.nexus.iam.realms._
import ch.epfl.bluebrain.nexus.iam.testsyntax._
import ch.epfl.bluebrain.nexus.iam.types.Caller
import ch.epfl.bluebrain.nexus.iam.types.IamError.InvalidAccessToken
import ch.epfl.bluebrain.nexus.iam.types.Identity.{Anonymous, Authenticated, User}
import ch.epfl.bluebrain.nexus.service.config.Settings
import ch.epfl.bluebrain.nexus.service.marshallers.instances._
import ch.epfl.bluebrain.nexus.service.routes.Routes
import ch.epfl.bluebrain.nexus.util.Resources
import com.typesafe.config.{Config, ConfigFactory}
import io.circe.Json
import monix.eval.Task
import org.mockito.matchers.MacroBasedMatchers
import org.mockito.{IdiomaticMockito, Mockito}
import org.scalatest.BeforeAndAfter
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpecLike

import scala.concurrent.duration._

//noinspection TypeAnnotation
class IdentitiesRoutesSpec
    extends AnyWordSpecLike
    with Matchers
    with ScalatestRouteTest
    with BeforeAndAfter
    with MacroBasedMatchers
    with Resources
    with ScalaFutures
    with IdiomaticMockito {

  implicit override def patienceConfig: PatienceConfig = PatienceConfig(3.seconds, 100.milliseconds)

  override def testConfig: Config = ConfigFactory.load("test.conf")

  private val config        = Settings(system).serviceConfig
  implicit private val http = config.http

  private val realms: Realms[Task] = mock[Realms[Task]]
  private val acls: Acls[Task]     = mock[Acls[Task]]

  before {
    Mockito.reset(realms, acls)
  }

  "The IdentitiesRoutes" should {
    val routes = Routes.wrap(new IdentitiesRoutes(acls, realms).routes)
    "return forbidden" in {
      val err = InvalidAccessToken(TokenRejection.InvalidAccessToken)
      realms.caller(any[AccessToken]) shouldReturn Task.raiseError(err)
      Get("/identities").addCredentials(OAuth2BearerToken("token")) ~> routes ~> check {
        status shouldEqual StatusCodes.Unauthorized
      }
    }
    "return anonymous" in {
      realms.caller(any[AccessToken]) shouldReturn Task.pure(Caller.anonymous)
      Get("/identities") ~> routes ~> check {
        status shouldEqual StatusCodes.OK
        responseAs[Json].sort shouldEqual jsonContentOf("/identities/anonymous.json")
      }
    }
    "return all identities" in {
      val user   = User("theuser", "therealm")
      val auth   = Authenticated("therealm")
      val caller = Caller(user, Set(user, Anonymous, auth))
      realms.caller(any[AccessToken]) shouldReturn Task.pure(caller)
      Get("/identities").addCredentials(OAuth2BearerToken("token")) ~> routes ~> check {
        status shouldEqual StatusCodes.OK
        responseAs[Json].sort shouldEqual jsonContentOf("/identities/identities.json")
      }
    }
  }
} 
Example 166
Source File: FlinkTestBase.scala    From flink-tensorflow   with Apache License 2.0 5 votes vote down vote up
package org.apache.flink.contrib.tensorflow.util

import org.apache.flink.runtime.minicluster.LocalFlinkMiniCluster
import org.apache.flink.streaming.util.TestStreamEnvironment
import org.apache.flink.test.util.TestBaseUtils
import org.junit.rules.TemporaryFolder
import org.scalatest.{BeforeAndAfter, Suite}

// Copied from Apache Flink.


trait FlinkTestBase extends BeforeAndAfter {
  that: Suite =>

  var cluster: Option[LocalFlinkMiniCluster] = None
  val parallelism = 4

  protected val tempFolder = new TemporaryFolder()

  before {
    tempFolder.create()
    val cl = TestBaseUtils.startCluster(
      1,
      parallelism,
      false,
      false,
      true)

    TestStreamEnvironment.setAsContext(cl, parallelism)

    cluster = Some(cl)
  }

  after {
    TestStreamEnvironment.unsetAsContext()
    cluster.foreach(c => TestBaseUtils.stopCluster(c, TestBaseUtils.DEFAULT_TIMEOUT))
    tempFolder.delete()
  }

} 
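
FlinkTestBase only provisions and tears down the local mini cluster; the suite that mixes it in supplies the actual job. Below is a minimal usage sketch, assuming a trivial pipeline; the suite name and the job itself are illustrative and not part of the flink-tensorflow repository.

package org.apache.flink.contrib.tensorflow.util

import org.apache.flink.streaming.api.scala._
import org.scalatest.{Matchers, WordSpecLike}

class ExampleFlinkJobSpec extends WordSpecLike with Matchers with FlinkTestBase {

  "a simple pipeline" should {
    "run against the mini cluster started by FlinkTestBase" in {
      // TestStreamEnvironment.setAsContext (called in the trait's `before`) makes this call
      // return an environment backed by the LocalFlinkMiniCluster.
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      env.fromElements(1, 2, 3).map(_ * 2).print()
      env.execute("example-job")
    }
  }
}
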
Example 167
Source File: TestTypeReduction.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.types

import at.forsyte.apalache.tla.lir.TestingPredefs
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.junit.JUnitRunner

@RunWith( classOf[JUnitRunner] )
class TestTypeReduction extends FunSuite with TestingPredefs with BeforeAndAfter {

  var gen = new SmtVarGenerator
  var tr  = new TypeReduction( gen )

  before {
    gen = new SmtVarGenerator
    tr = new TypeReduction( gen )
  }

  test( "Test nesting" ) {
    val tau = FunT( IntT, SetT( IntT ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr( tau, m )
    assert( rr.t == fun( int, set( int ) ) )
  }

  test("Test tuples"){
    val tau = SetT( FunT( TupT( IntT, StrT ), SetT( IntT ) ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr(tau, m)
    val idx = SmtIntVariable( 0 )
    assert( rr.t == set( fun( tup( idx ), set( int ) ) ) )
    assert( rr.phi.contains( hasIndex( idx, 0, int ) ) )
    assert( rr.phi.contains( hasIndex( idx, 1, str ) ) )
  }
} 
Example 168
Source File: RetentionPolicyManagementSuite.scala    From scala-influxdb-client   with MIT License 5 votes vote down vote up
package com.paulgoldbaum.influxdbclient

import org.scalatest.BeforeAndAfter

class RetentionPolicyManagementSuite extends CustomTestSuite with BeforeAndAfter {

  val database = new Database("_test_database_rp", new HttpClient("localhost", 8086, false, databaseUsername, databasePassword))
  val retentionPolicyName = "test_retention_policy"

  before {
    await(database.create())
  }

  after {
    await(database.drop())
  }

  test("A retention policy can be created") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = true))
    val policies = await(database.showRetentionPolicies())
    assert(policies.series.head.records.length == 2)
    val policy = policies.series.head.records(1)
    assert(policy("name") == retentionPolicyName)
    assert(policy("duration") == "168h0m0s")
    assert(policy("replicaN") == 1)
    assert(policy("default") == true)
  }

  test("A retention policy can be created and deleted") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = false))
    await(database.dropRetentionPolicy(retentionPolicyName))

    val policiesAfterDeleting = await(database.showRetentionPolicies())
    assert(policiesAfterDeleting.series.head.records.length == 1)
  }

  test("A retention policy's duration can be altered") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = false))
    await(database.alterRetentionPolicy(retentionPolicyName, "2w"))
    val policies = await(database.showRetentionPolicies())
    val policy = policies.series.head.records(1)
    assert(policy("name") == retentionPolicyName)
    assert(policy("duration") == "336h0m0s")
  }

  test("A retention policy's replication can be altered") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = false))
    await(database.alterRetentionPolicy(retentionPolicyName, replication = 2))
    val policies = await(database.showRetentionPolicies())
    val policy = policies.series.head.records(1)
    assert(policy("name") == retentionPolicyName)
    assert(policy("replicaN") == 2)
  }

  test("A retention policy's defaultness can be altered") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = false))
    await(database.alterRetentionPolicy(retentionPolicyName, default = true))
    val policies = await(database.showRetentionPolicies())
    val policy = policies.series.head.records(1)
    assert(policy("name") == retentionPolicyName)
    assert(policy("default") == true)
  }

  test("At least one parameter has to be altered") {
    await(database.createRetentionPolicy(retentionPolicyName, "1w", 1, default = false))
    try {
      await(database.alterRetentionPolicy(retentionPolicyName))
      fail("Exception was not thrown")
    } catch {
      case e: InvalidRetentionPolicyParametersException => // expected
    }
  }

} 
Example 169
Source File: UdpClientSuite.scala    From scala-influxdb-client   with MIT License 5 votes vote down vote up
package com.paulgoldbaum.influxdbclient

import org.scalatest.BeforeAndAfter

class UdpClientSuite extends CustomTestSuite with BeforeAndAfter {

  val databaseName = "_test_database_udp"
  val database = influxdb.selectDatabase(databaseName)

  before {
    await(database.create())
  }

  after {
    await(database.drop())
  }

  test("Points can be written") {
    val udpClient = InfluxDB.udpConnect("localhost", 8086)
    udpClient.write(Point("test_measurement").addField("value", 123).addTag("tag_key", "tag_value"))
    udpClient.close()
    Thread.sleep(1000) // to allow flushing to happen inside influx

    val database = influxdb.selectDatabase(databaseName)
    val result = await(database.query("SELECT * FROM test_measurement"))
    assert(result.series.head.records.length == 1)
    assert(result.series.head.records.head("value") == 123)
  }

  test("Points can be written in bulk") {
    val udpClient = InfluxDB.udpConnect("localhost", 8086)
    val timestamp = System.currentTimeMillis()
    udpClient.bulkWrite(List(
      Point("test_measurement", timestamp).addField("value", 1).addTag("tag_key", "tag_value"),
      Point("test_measurement", timestamp + 1).addField("value", 2).addTag("tag_key", "tag_value"),
      Point("test_measurement", timestamp + 2).addField("value", 3).addTag("tag_key", "tag_value")
    ))
    udpClient.close()
    Thread.sleep(1000) // to allow flushing to happen inside influx

    val database = influxdb.selectDatabase(databaseName)
    val result = await(database.query("SELECT * FROM test_measurement"))
    assert(result.series.head.records.length == 3)
  }

} 
Example 170
Source File: DatabaseManagementSuite.scala    From scala-influxdb-client   with MIT License 5 votes vote down vote up
package com.paulgoldbaum.influxdbclient

import org.scalatest.BeforeAndAfter

class DatabaseManagementSuite extends CustomTestSuite with BeforeAndAfter {

  val databaseName = "_test_database_mgmnt"
  val database = influxdb.selectDatabase(databaseName)

  before {
    val exists = await(database.exists())
    if (exists) {
      await(database.drop())
    }
  }

  test("A database can be created and dropped") {
    await(database.create())
    assert(await(database.exists()))

    await(database.drop())
    assert(!await(database.exists()))
  }
} 
Example 171
Source File: SkinnySpecSupport.scala    From scala-ddd-base   with MIT License 5 votes vote down vote up
package com.github.j5ik2o.dddbase.example.repository.util

import org.scalatest.{ BeforeAndAfter, BeforeAndAfterAll, Suite }
import scalikejdbc.config.DBs
import scalikejdbc.{ ConnectionPool, GlobalSettings, LoggingSQLAndTimeSettings }

trait SkinnySpecSupport extends BeforeAndAfter with BeforeAndAfterAll with JdbcSpecSupport {
  self: Suite with FlywayWithMySQLSpecSupport =>

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    Class.forName("com.mysql.jdbc.Driver")
    ConnectionPool.singleton(s"jdbc:mysql://localhost:${jdbcPort}/dddbase?useSSL=false", "dddbase", "dddbase")
    GlobalSettings.loggingSQLAndTime = LoggingSQLAndTimeSettings(
      enabled = true,
      logLevel = 'DEBUG,
      warningEnabled = true,
      warningThresholdMillis = 1000L,
      warningLogLevel = 'WARN
    )
  }

  override protected def afterAll(): Unit = {
    DBs.closeAll()
    super.afterAll()
  }

} 
Example 172
Source File: Slick3SpecSupport.scala    From scala-ddd-base   with MIT License 5 votes vote down vote up
package com.github.j5ik2o.dddbase.example.repository.util

import com.typesafe.config.ConfigFactory
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{ BeforeAndAfter, BeforeAndAfterAll, Suite }
import slick.basic.DatabaseConfig
import slick.jdbc.SetParameter.SetUnit
import slick.jdbc.{ JdbcProfile, SQLActionBuilder }

import scala.concurrent.Future

trait Slick3SpecSupport extends BeforeAndAfter with BeforeAndAfterAll with ScalaFutures with JdbcSpecSupport {
  self: Suite with FlywayWithMySQLSpecSupport =>

  private var _dbConfig: DatabaseConfig[JdbcProfile] = _

  private var _profile: JdbcProfile = _

  protected def dbConfig = _dbConfig

  protected def profile = _profile

  after {
    implicit val ec = dbConfig.db.executor.executionContext
    val futures = tables.map { table =>
      val q = SQLActionBuilder(List(s"TRUNCATE TABLE $table"), SetUnit).asUpdate
      dbConfig.db.run(q)
    }
    Future.sequence(futures).futureValue
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    val config = ConfigFactory.parseString(s"""
         |dddbase {
         |  profile = "slick.jdbc.MySQLProfile$$"
         |  db {
         |    connectionPool = disabled
         |    driver = "com.mysql.jdbc.Driver"
         |    url = "jdbc:mysql://localhost:$jdbcPort/dddbase?useSSL=false"
         |    user = "dddbase"
         |    password = "dddbase"
         |  }
         |}
      """.stripMargin)
    _dbConfig = DatabaseConfig.forConfig[JdbcProfile]("dddbase", config)
    _profile = dbConfig.profile
  }

  override protected def afterAll(): Unit = {
    dbConfig.db.shutdown
    super.afterAll()
  }

} 
Example 173
Source File: BaseAwsClientTest.scala    From aws-spi-akka-http   with Apache License 2.0 5 votes vote down vote up
package com.github.matsluni.akkahttpspi

import java.net.URI

import com.dimafeng.testcontainers.{ForAllTestContainer, GenericContainer}
import com.github.matsluni.akkahttpspi.testcontainers.LocalStackReadyLogWaitStrategy
import org.scalatest.concurrent.{Eventually, Futures, IntegrationPatience}
import org.scalatest.BeforeAndAfter
import software.amazon.awssdk.core.SdkClient
import software.amazon.awssdk.regions.Region

import scala.util.Random
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec

trait BaseAwsClientTest[C <: SdkClient]
  extends AnyWordSpec
    with Matchers
    with Futures
    with Eventually
    with BeforeAndAfter
    with IntegrationPatience
    with ForAllTestContainer {

  lazy val defaultRegion: Region = Region.EU_WEST_1

  def client: C
  def exposedServicePort: Int
  val container: GenericContainer

  def endpoint = new URI(s"http://localhost:${container.mappedPort(exposedServicePort)}")
  def randomIdentifier(length: Int): String = Random.alphanumeric.take(length).mkString
}

trait LocalstackBaseAwsClientTest[C <: SdkClient] extends BaseAwsClientTest[C] {
  def service: String

  lazy val exposedServicePort: Int = LocalstackServicePorts.services(service)

  override lazy val container: GenericContainer =
    new GenericContainer(
      dockerImage = "localstack/localstack",
      exposedPorts = Seq(exposedServicePort),
      env = Map("SERVICES" -> service),
      waitStrategy = Some(LocalStackReadyLogWaitStrategy)
    )
}

object LocalstackServicePorts {
  //services and ports based on https://github.com/localstack/localstack
  val services: Map[String, Int] = Map(
    "s3" -> 4572,
    "sqs" -> 4576,
    "sns" -> 4575,
    "dynamodb" -> 4569
  )
} 
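
LocalstackBaseAwsClientTest leaves `service` and `client` abstract. The sketch below is an assumption, not code from the aws-spi-akka-http project: it shows one way a concrete S3 suite could fill them in, using a plain synchronous S3 client with dummy static credentials, whereas the real project would plug in its Akka HTTP based async client.

package com.github.matsluni.akkahttpspi

import software.amazon.awssdk.auth.credentials.{AwsBasicCredentials, StaticCredentialsProvider}
import software.amazon.awssdk.services.s3.S3Client
import software.amazon.awssdk.services.s3.model.CreateBucketRequest

class S3ClientSmokeTest extends LocalstackBaseAwsClientTest[S3Client] {

  override def service: String = "s3"

  // `endpoint` and `defaultRegion` come from BaseAwsClientTest; the credentials are dummies
  // accepted by Localstack.
  override lazy val client: S3Client = S3Client.builder()
    .endpointOverride(endpoint)
    .region(defaultRegion)
    .credentialsProvider(StaticCredentialsProvider.create(AwsBasicCredentials.create("x", "x")))
    .build()

  "S3 against Localstack" should {
    "create a bucket and see it in the listing" in {
      val bucket = randomIdentifier(10).toLowerCase
      client.createBucket(CreateBucketRequest.builder().bucket(bucket).build())
      client.listBuckets().buckets().size() should be >= 1
    }
  }
}
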
Example 174
Source File: WordCountSpec.scala    From CSYE7200_Old   with MIT License 5 votes vote down vote up
package edu.neu.coe.csye7200.asstswc

import org.apache.spark.sql.SparkSession
import org.scalatest.tagobjects.Slow
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class WordCountSpec extends FlatSpec with Matchers with BeforeAndAfter  {

  implicit var spark: SparkSession = _

  before {
    spark = SparkSession
      .builder()
      .appName("WordCount")
      .master("local[*]")
      .getOrCreate()
  }

  after {
    if (spark != null) {
      spark.stop()
    }
  }

  behavior of "Spark"

  ignore should "work for wordCount" taggedAs Slow in {
    WordCount.wordCount(spark.read.textFile(getClass.getResource("WordCount.txt").getPath).rdd," ").collect() should matchPattern {
      case Array(("Hello",3),("World",3),("Hi",1)) =>
    }
  }

} 
Example 175
Source File: WordCountSpark2ItSpec.scala    From CSYE7200_Old   with MIT License 5 votes vote down vote up
package edu.neu.coe.csye7200

import org.apache.spark.sql.SparkSession
import org.scalatest.tagobjects.Slow
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class WordCountSpark2ItSpec extends FlatSpec with Matchers with BeforeAndAfter  {

  implicit var spark: SparkSession = _

  before {
    spark = SparkSession
      .builder()
      .appName("WordCount")
      .master("local[*]")
      .getOrCreate()
  }

  after {
    if (spark != null) {
      spark.stop()
    }
  }

  behavior of "Spark"

  it should "work for wordCount" taggedAs Slow in {
    WordCount.wordCount(spark.read.textFile(getClass.getResource("/WordCount.txt").getPath).rdd," ").collect() should matchPattern {
      case Array(("Hello",3),("World",3),("Hi",1)) =>
    }
  }

  it should "work for wordCount2" taggedAs Slow in {
    WordCount.wordCount2(spark.read.textFile(getClass.getResource("/WordCount2.txt").getPath).rdd," ").collect() should matchPattern {
      case Array(("hi",2), ("hello",1), ("and",1), ("world",2)) =>
    }
  }

  it should "work for wordCount3" taggedAs Slow in {
    WordCount.wordCount3(spark.read.textFile(getClass.getResource("/WordCount2.txt").getPath).rdd," ").collect() should matchPattern {
      case Array(("hi",2), ("hello",1), ("and",1), ("world",2)) =>
    }
  }

  it should "work for Dataset and Spark SQL" taggedAs Slow in {
    val ds = spark.read.textFile(getClass.getResource("/WordCount.txt").getPath)
    val words = WordCount.createWordDS(ds," ")
    words.createTempView("words")
    words.cache()
    spark.sql("select count(*) from words").head().getLong(0) shouldBe 7
    spark.sql("select word, count(*) from words group by word").collect().map(_.toString()) shouldBe
      Array("[World,3]","[Hi,1]","[Hello,3]")
  }

} 
Example 176
Source File: WordCountItSpec.scala    From CSYE7200_Old   with MIT License 5 votes vote down vote up
package edu.neu.coe.csye7200

import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.tagobjects.Slow
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class WordCountItSpec extends FlatSpec with Matchers with BeforeAndAfter  {

  private var sc: SparkContext = _

  before {
    sc = new SparkContext(new SparkConf().setAppName("WordCount").setMaster("local[*]"))
  }

  after {
    if (sc != null) {
      sc.stop()
    }
  }

  "result" should "right for wordCount" taggedAs Slow in {
    WordCount.wordCount(sc.textFile(getClass.getResource("/WordCount.txt").getPath)," ").collect() should matchPattern {
      case Array(("Hello",3),("World",3),("Hi",1)) =>
    }
  }
} 
Example 177
Source File: EmbeddedKafkaServer.scala    From KafkaPlayground   with GNU General Public License v3.0 5 votes vote down vote up
package com.github.pedrovgs.kafkaplayground.utils

import cakesolutions.kafka.KafkaProducerRecord
import cakesolutions.kafka.testkit.KafkaServer
import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
import org.scalatest.{BeforeAndAfter, Suite}

import scala.concurrent.duration._

trait EmbeddedKafkaServer extends BeforeAndAfter {
  this: Suite =>

  private var kafkaServer: KafkaServer = _

  before {
    kafkaServer = new KafkaServer
    startKafkaServer()
  }

  after {
    stopKafkaServer()
  }

  def startKafkaServer(): Unit = kafkaServer.startup()

  def stopKafkaServer(): Unit = kafkaServer.close()

  def kafkaServerAddress(): String = s"localhost:${kafkaServer.kafkaPort}"

  def zookeeperServerAddress(): String = s"localhost:${kafkaServer.zookeeperPort}"

  def recordsForTopic(topic: String, expectedNumberOfRecords: Int = 1): Iterable[String] =
    kafkaServer
      .consume[String, String](
        topic = topic,
        keyDeserializer = new StringDeserializer,
        valueDeserializer = new StringDeserializer,
        expectedNumOfRecords = expectedNumberOfRecords,
        timeout = 10.seconds.toMillis
      )
      .map(_._2)

  def produceMessage(topic: String, content: String): Unit =
    kafkaServer.produce(
      topic = topic,
      records = Seq(KafkaProducerRecord[String, String](topic = topic, value = content)),
      keySerializer = new StringSerializer(),
      valueSerializer = new StringSerializer()
    )

} 
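
A minimal sketch of a spec built on top of this trait; the topic name and payload are illustrative assumptions rather than code from the KafkaPlayground project.

package com.github.pedrovgs.kafkaplayground.utils

import org.scalatest.{FlatSpec, Matchers}

class EmbeddedKafkaServerSpec extends FlatSpec with Matchers with EmbeddedKafkaServer {

  "the embedded Kafka server" should "expose produced records to consumers" in {
    val topic = "example-topic"

    produceMessage(topic, "hello kafka")

    // recordsForTopic consumes until the expected number of records arrives or a timeout hits.
    recordsForTopic(topic, expectedNumberOfRecords = 1).toSeq shouldBe Seq("hello kafka")
  }
}
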
Example 178
Source File: TheFlashTweetsProducerSpec.scala    From KafkaPlayground   with GNU General Public License v3.0 5 votes vote down vote up
package com.github.pedrovgs.kafkaplayground.flash

import java.util.Date

import com.danielasfregola.twitter4s.entities.{Geo, Tweet}
import com.github.pedrovgs.kafkaplayground.utils.EmbeddedKafkaServer
import org.scalatest.concurrent.{PatienceConfiguration, ScalaFutures}
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

import scala.concurrent.duration._

object TheFlashTweetsProducerSpec {
  private val unknownLocationFlashTopic = "the-flash-tweets"
  private val locatedFlashTopic         = "the-flash-tweets-with-location"
  private val anyNotGeoLocatedTweet = Tweet(
    created_at = new Date(),
    id = 1L,
    id_str = "1",
    source = "source",
    text = "I've seen the fastest man alive!"
  )

  private val anyGeoLocatedTweet = anyNotGeoLocatedTweet.copy(
    geo = Some(Geo(Seq(12.0, 11.0), "lat-long"))
  )
}

class TheFlashTweetsProducerSpec
    extends FlatSpec
    with Matchers
    with EmbeddedKafkaServer
    with ScalaFutures
    with BeforeAndAfter {

  import TheFlashTweetsProducerSpec._

  "TheFlashTweetsProducer" should "return the tweet passed as param if the tweet has no geo location info" in {
    val result = produceTweet(anyNotGeoLocatedTweet)

    result shouldBe anyNotGeoLocatedTweet
  }

  it should "send a record with just the text of the tweet to the the-flash-tweets topic if the tweet has no geo location info" in {
    produceTweet(anyNotGeoLocatedTweet)

    val records = recordsForTopic(unknownLocationFlashTopic)

    val expectedMessage =
      s"""
         |{
         |  "message": "I've seen the fastest man alive!"
         |}
        """.stripMargin
    records.size shouldBe 1
    records.head shouldBe expectedMessage
  }

  it should "return the tweet passed as param if the tweet has geo location info" in {
    val result = produceTweet(anyGeoLocatedTweet)

    result shouldBe anyGeoLocatedTweet
  }

  it should "send a record with just the text of the tweet to the the-flash-tweets-with-location topic if the tweet has geo location info" in {
    produceTweet(anyGeoLocatedTweet)

    val records = recordsForTopic(locatedFlashTopic)

    val expectedMessage =
      s"""
         |{
         |  "latitude": 12.0,
         |  "longitude": 11.0,
         |  "id": "1",
         |  "message": "I've seen the fastest man alive!"
         |}
       """.stripMargin
    records.size shouldBe 1
    records.head shouldBe expectedMessage
  }

  it should "send a not geo-located tweet to a topic and another geo-located to the other topic configured" in {
    produceTweet(anyNotGeoLocatedTweet)
    produceTweet(anyGeoLocatedTweet)

    val locatedTopicRecords         = recordsForTopic(locatedFlashTopic)
    val unknownLocationTopicRecords = recordsForTopic(unknownLocationFlashTopic)

    locatedTopicRecords.size shouldBe 1
    unknownLocationTopicRecords.size shouldBe 1
  }

  private def produceTweet(tweet: Tweet) =
    new TheFlashTweetsProducer(kafkaServerAddress())(tweet)
      .futureValue(timeout = PatienceConfiguration.Timeout(1.seconds))

} 
Example 179
Source File: EntitySupport.scala    From akka-cqrs   with Apache License 2.0 5 votes vote down vote up
package com.productfoundry.akka.cqrs

import akka.actor.{ActorRef, ActorSystem, PoisonPill, Terminated}
import akka.testkit.{ImplicitSender, TestKit}
import akka.util.Timeout
import org.scalatest.concurrent.Eventually
import org.scalatest.time.{Millis, Second, Span}
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}

import scala.concurrent.duration._

abstract class EntitySupport(_system: ActorSystem)
  extends TestKit(_system)
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with BeforeAndAfter
  with Eventually {

  
  override def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }
} 
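
EntitySupport is abstract and expects the concrete spec to pass in the ActorSystem under test. A short sketch of such a subclass follows; the echo actor and the suite name are assumptions made up for illustration.

package com.productfoundry.akka.cqrs

import akka.actor.{Actor, ActorSystem, Props}

class EchoEntitySpec extends EntitySupport(ActorSystem("EchoEntitySpec")) {

  // A trivial actor standing in for an aggregate; it simply echoes whatever it receives.
  class EchoActor extends Actor {
    def receive: Receive = { case msg => sender() ! msg }
  }

  "An echo actor" must {
    "reply with the message it receives" in {
      val echo = system.actorOf(Props(new EchoActor), "echo")
      echo ! "ping"
      expectMsg("ping")
    }
  }
}
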
Example 180
Source File: MetadataTest.scala    From spark-pagerank   with MIT License 5 votes vote down vote up
package com.soundcloud.spark.pagerank

import java.io.File

import org.apache.commons.io.FileUtils
import org.scalatest.{ BeforeAndAfter, FunSuite, Matchers }

class MetadataTest
  extends FunSuite
  with BeforeAndAfter
  with Matchers
  with SparkTesting {

  val path = "target/test/MetadataTest"
  val metadata = Metadata(numVertices=1)

  before {
    FileUtils.deleteDirectory(new File(path))
  }

  test("save and load") {
    Metadata.save(spark, metadata, path)
    Metadata.load(spark, path) shouldBe (metadata)
  }
} 
Example 181
Source File: GraphBuilderAppTest.scala    From spark-pagerank   with MIT License 5 votes vote down vote up
package com.soundcloud.spark.pagerank

import java.io.File

import org.apache.commons.io.FileUtils
import org.scalatest.{ BeforeAndAfter, FunSuite, Matchers }

class GraphBuilderAppTest
  extends FunSuite
  with BeforeAndAfter
  with Matchers
  with GraphTesting
  with SparkTesting {

  val path = "target/test/GraphBuilderAppTest"

  before {
    FileUtils.deleteDirectory(new File(path))
  }

  // TODO(jd): design a better integration test as this just runs the app without assertions
  test("integration test") {
    val options = new GraphBuilderApp.Options()
    options.output = path
    options.numPartitions = 1

    val input = spark.sparkContext.parallelize(Seq(
      (1, 5, 1.0),
      (2, 1, 1.0),
      (3, 1, 1.0),
      (4, 2, 1.0),
      (4, 3, 1.0),
      (5, 3, 1.0),
      (5, 4, 1.0)
    ).map(_.productIterator.toSeq.mkString("\t")))

    GraphBuilderApp.runFromInputs(options, spark, input)
  }
} 
Example 182
Source File: PageRankAppTest.scala    From spark-pagerank   with MIT License 5 votes vote down vote up
package com.soundcloud.spark.pagerank

import java.io.File

import org.apache.commons.io.FileUtils
import org.apache.spark.storage.StorageLevel
import org.scalatest.{ BeforeAndAfter, Matchers, FunSuite }

class PageRankAppTest
  extends FunSuite
  with BeforeAndAfter
  with Matchers
  with GraphTesting
  with SparkTesting {

  val path = "target/test/PageRankAppTest"

  before {
    FileUtils.deleteDirectory(new File(path))
  }

  // TODO(jd): design a better integration test as this just runs the app without assertions
  test("integration test") {
    val options = new PageRankApp.Options()
    options.output = path

    val numVertices = 5
    val prior = 1.0 / numVertices
    val stats = Seq(s"numVertices,$numVertices")

    val edges = spark.sparkContext.parallelize(Seq[OutEdgePair](
      // node 1 is dangling
      (2, OutEdge(1, 1.0)),
      (3, OutEdge(1, 1.0)),
      (4, OutEdge(2, 0.5)),
      (4, OutEdge(3, 0.5)),
      (5, OutEdge(3, 0.5)),
      (5, OutEdge(4, 0.5))
    ))
    val vertices = spark.sparkContext.parallelize(Seq[RichVertexPair](
      (1, VertexMetadata(prior, true)),
      (2, VertexMetadata(prior, false)),
      (3, VertexMetadata(prior, false)),
      (4, VertexMetadata(prior, false)),
      (5, VertexMetadata(prior, false))
    ))
    val graph = PageRankGraph(
      numVertices,
      edges.persist(StorageLevel.MEMORY_ONLY),
      vertices.persist(StorageLevel.MEMORY_ONLY)
    )

    PageRankApp.runFromInputs(
      spark,
      options,
      graph,
      priorsOpt = None
    )
  }
} 
Example 183
Source File: PatternElementWriterTest.scala    From morpheus   with Apache License 2.0 5 votes vote down vote up
package org.opencypher.okapi.neo4j.io

import org.neo4j.driver.v1.Values
import org.opencypher.okapi.api.value.CypherValue.CypherMap
import org.opencypher.okapi.neo4j.io.Neo4jHelpers.Neo4jDefaults.metaPropertyKey
import org.opencypher.okapi.neo4j.io.Neo4jHelpers._
import org.opencypher.okapi.neo4j.io.testing.Neo4jServerFixture
import org.opencypher.okapi.testing.Bag._
import org.opencypher.okapi.testing.BaseTestSuite
import org.scalatest.BeforeAndAfter

import scala.collection.immutable

class PatternElementWriterTest extends BaseTestSuite with Neo4jServerFixture with BeforeAndAfter {

  it("can write nodes") {
    ElementWriter.createNodes(
      inputNodes.toIterator,
      Array(metaPropertyKey, "val1", "val2", "val3", null),
      neo4jConfig,
      Set("Foo", "Bar", "Baz")
    )(rowToListValue)

    val expected = inputNodes.map { node =>
      CypherMap(
        s"n.$metaPropertyKey" -> node(0),
        "n.val1" -> node(1),
        "n.val2" -> node(2),
        "n.val3" -> node(3)
      )
    }.toBag

    val result = neo4jConfig.cypherWithNewSession(s"MATCH (n) RETURN n.$metaPropertyKey, n.val1, n.val2, n.val3").map(CypherMap).toBag
    result should equal(expected)
  }

  it("can write relationships") {
    ElementWriter.createRelationships(
      inputRels.toIterator,
      1,
      2,
      Array(metaPropertyKey, null, null, "val3"),
      neo4jConfig,
      "REL",
      None
    )(rowToListValue)

    val expected = inputRels.map { rel =>
      CypherMap(
        s"r.$metaPropertyKey" -> rel(0),
        "r.val3" -> rel(3)
      )
    }.toBag

    val result = neo4jConfig.cypherWithNewSession(s"MATCH ()-[r]->() RETURN r.$metaPropertyKey, r.val3").map(CypherMap).toBag
    result should equal(expected)
  }

  override def dataFixture: String = ""

  private def rowToListValue(data: Array[AnyRef]) = Values.value(data.map(Values.value): _*)

  private val numberOfNodes = 10
  val inputNodes: immutable.IndexedSeq[Array[AnyRef]] = (1 to numberOfNodes).map { i =>
    Array[AnyRef](
      i.asInstanceOf[AnyRef],
      i.asInstanceOf[AnyRef],
      i.toString.asInstanceOf[AnyRef],
      (i % 2 == 0).asInstanceOf[AnyRef],
      (i+1).asInstanceOf[AnyRef]
    )
  }

  val inputRels: immutable.IndexedSeq[Array[AnyRef]] = (2 to numberOfNodes).map { i =>
    Array[AnyRef](
      i.asInstanceOf[AnyRef],
      (i - 1).asInstanceOf[AnyRef],
      i.asInstanceOf[AnyRef],
      (i % 2 == 0).asInstanceOf[AnyRef]
    )
  }
} 
Example 184
Source File: FlumeStreamSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.util.concurrent.ConcurrentLinkedQueue

import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.postfixOps

import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}

class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")
  var ssc: StreamingContext = null

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  // NOTE: the private `testFlumeStream` helper invoked by the two tests above was elided
  // from this listing; see the full FlumeStreamSuite in the Spark sources for its body.

  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
} 
Example 185
Source File: ResolveInlineTablesSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Literal, Rand}
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.{LongType, NullType}


class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter {

  private def lit(v: Any): Literal = Literal(v)

  test("validate inputs are foldable") {
    ResolveInlineTables.validateInputEvaluable(
      UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)))))

    // nondeterministic (rand) should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1)))))
    }

    // aggregate should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1))))))
    }

    // unresolved attribute should not work
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputEvaluable(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A")))))
    }
  }

  test("validate input dimensions") {
    ResolveInlineTables.validateInputDimension(
      UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2)))))

    // num alias != data dimension
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2)))))
    }

    // num alias == data dimension, but data themselves are inconsistent
    intercept[AnalysisException] {
      ResolveInlineTables.validateInputDimension(
        UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22)))))
    }
  }

  test("do not fire the rule if not all expressions are resolved") {
    val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A"))))
    assert(ResolveInlineTables(table) == table)
  }

  test("convert") {
    val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted = ResolveInlineTables.convert(table)

    assert(converted.output.map(_.dataType) == Seq(LongType))
    assert(converted.data.size == 2)
    assert(converted.data(0).getLong(0) == 1L)
    assert(converted.data(1).getLong(0) == 2L)
  }

  test("nullability inference in convert") {
    val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L))))
    val converted1 = ResolveInlineTables.convert(table1)
    assert(!converted1.schema.fields(0).nullable)

    val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType))))
    val converted2 = ResolveInlineTables.convert(table2)
    assert(converted2.schema.fields(0).nullable)
  }
} 
Example 186
Source File: RowDataSourceStrategySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.execution.datasources

import java.sql.DriverManager
import java.util.Properties

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

class RowDataSourceStrategySuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
  import testImplicits._

  val url = "jdbc:h2:mem:testdb0"
  val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
  var conn: java.sql.Connection = null

  before {
    Utils.classForName("org.h2.Driver")
    // Extra properties that will be specified for our database. We need these to test
    // usage of parameters from OPTIONS clause in queries.
    val properties = new Properties()
    properties.setProperty("user", "testUser")
    properties.setProperty("password", "testPass")
    properties.setProperty("rowId", "false")

    conn = DriverManager.getConnection(url, properties)
    conn.prepareStatement("create schema test").executeUpdate()
    conn.prepareStatement("create table test.inttypes (a INT, b INT, c INT)").executeUpdate()
    conn.prepareStatement("insert into test.inttypes values (1, 2, 3)").executeUpdate()
    conn.commit()
    sql(
      s"""
        |CREATE TEMPORARY TABLE inttypes
        |USING org.apache.spark.sql.jdbc
        |OPTIONS (url '$url', dbtable 'TEST.INTTYPES', user 'testUser', password 'testPass')
      """.stripMargin.replaceAll("\n", " "))
  }

  after {
    conn.close()
  }

  test("SPARK-17673: Exchange reuse respects differences in output schema") {
    val df = sql("SELECT * FROM inttypes")
    val df1 = df.groupBy("a").agg("b" -> "min")
    val df2 = df.groupBy("a").agg("c" -> "min")
    val res = df1.union(df2)
    assert(res.distinct().count() == 2)  // would be 1 if the exchange was incorrectly reused
  }
} 
Example 187
Source File: AggregateHashMapSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
    super.beforeAll()
  }

  // Check after each test that the configuration was not changed in the test body.
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    super.beforeAll()
  }

  // Check after each test that the configuration was not changed in the test body.
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
  }
}

class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
BeforeAndAfter {

  protected override def beforeAll(): Unit = {
    sparkConf.set("spark.sql.codegen.fallback", "false")
    sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
    sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
    super.beforeAll()
  }

  // Check after each test that the configuration was not changed in the test body.
  after {
    assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
      "configuration parameter changed in test body")
    assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
      "configuration parameter changed in test body")
  }
} 
Example 188
Source File: ExtensionServiceIntegrationSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler.cluster

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.deploy.yarn.config._
import org.apache.spark.internal.Logging

// NOTE: the original class declaration (and its `applicationId` fixture, referenced in the
// tests below) was elided in this listing; the header here is an assumption inferred from
// the imports and the file name.
class ExtensionServiceIntegrationSuite extends SparkFunSuite
  with LocalSparkContext with BeforeAndAfter
  with Logging {

  before {
    val sparkConf = new SparkConf()
    sparkConf.set(SCHEDULER_SERVICES, Seq(classOf[SimpleExtensionService].getName()))
    sparkConf.setMaster("local").setAppName("ExtensionServiceIntegrationSuite")
    sc = new SparkContext(sparkConf)
  }

  test("Instantiate") {
    val services = new SchedulerExtensionServices()
    assertResult(Nil, "non-nil service list") {
      services.getServices
    }
    services.start(SchedulerExtensionServiceBinding(sc, applicationId))
    services.stop()
  }

  test("Contains SimpleExtensionService Service") {
    val services = new SchedulerExtensionServices()
    try {
      services.start(SchedulerExtensionServiceBinding(sc, applicationId))
      val serviceList = services.getServices
      assert(serviceList.nonEmpty, "empty service list")
      val (service :: Nil) = serviceList
      val simpleService = service.asInstanceOf[SimpleExtensionService]
      assert(simpleService.started.get, "service not started")
      services.stop()
      assert(!simpleService.started.get, "service not stopped")
    } finally {
      services.stop()
    }
  }
} 
Example 189
Source File: FailureSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils


class FailureSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  private val batchDuration: Duration = Milliseconds(1000)
  private val numBatches = 30
  private var directory: File = null

  before {
    directory = Utils.createTempDir()
  }

  after {
    if (directory != null) {
      Utils.deleteRecursively(directory)
    }
    StreamingContext.getActive().foreach { _.stop() }

    // Stop SparkContext if active
    SparkContext.getOrCreate(new SparkConf().setMaster("local").setAppName("bla")).stop()
  }

  test("multiple failures with map") {
    MasterFailureTest.testMap(directory.getAbsolutePath, numBatches, batchDuration)
  }

  test("multiple failures with updateStateByKey") {
    MasterFailureTest.testUpdateStateByKey(directory.getAbsolutePath, numBatches, batchDuration)
  }
} 
Example 190
Source File: InputInfoTrackerSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.scheduler

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.streaming.{Duration, StreamingContext, Time}

class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter {

  private var ssc: StreamingContext = _

  before {
    val conf = new SparkConf().setMaster("local[2]").setAppName("DirectStreamTacker")
    if (ssc == null) {
      ssc = new StreamingContext(conf, Duration(1000))
    }
  }

  after {
    if (ssc != null) {
      ssc.stop()
      ssc = null
    }
  }

  test("test report and get InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val streamId2 = 1
    val time = Time(0L)
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId2, 300L)
    inputInfoTracker.reportInfo(time, inputInfo1)
    inputInfoTracker.reportInfo(time, inputInfo2)

    val batchTimeToInputInfos = inputInfoTracker.getInfo(time)
    assert(batchTimeToInputInfos.size == 2)
    assert(batchTimeToInputInfos.keys === Set(streamId1, streamId2))
    assert(batchTimeToInputInfos(streamId1) === inputInfo1)
    assert(batchTimeToInputInfos(streamId2) === inputInfo2)
    assert(inputInfoTracker.getInfo(time)(streamId1) === inputInfo1)
  }

  test("test cleanup InputInfo from InputInfoTracker") {
    val inputInfoTracker = new InputInfoTracker(ssc)

    val streamId1 = 0
    val inputInfo1 = StreamInputInfo(streamId1, 100L)
    val inputInfo2 = StreamInputInfo(streamId1, 300L)
    inputInfoTracker.reportInfo(Time(0), inputInfo1)
    inputInfoTracker.reportInfo(Time(1), inputInfo2)

    inputInfoTracker.cleanup(Time(0))
    assert(inputInfoTracker.getInfo(Time(0))(streamId1) === inputInfo1)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)

    inputInfoTracker.cleanup(Time(1))
    assert(inputInfoTracker.getInfo(Time(0)).get(streamId1) === None)
    assert(inputInfoTracker.getInfo(Time(1))(streamId1) === inputInfo2)
  }
} 
Example 191
Source File: SparkListenerWithClusterSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.scheduler

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.scheduler.cluster.ExecutorInfo

// NOTE: the original class declaration was elided in this listing; the header below is an
// assumption inferred from the file name and the imports above.
class SparkListenerWithClusterSuite extends SparkFunSuite with LocalSparkContext
  with BeforeAndAfter with BeforeAndAfterAll {

  val WAIT_TIMEOUT_MILLIS = 10000

  before {
    sc = new SparkContext("local-cluster[2,1,1024]", "SparkListenerSuite")
  }

  test("SparkListener sends executor added message") {
    val listener = new SaveExecutorInfo
    sc.addSparkListener(listener)

    // This test checks that the number of executors received by the SparkListener is the same
    // as the total number of executors, so we need to wait until all executors are up.
    sc.jobProgressListener.waitUntilExecutorsUp(2, 60000)

    val rdd1 = sc.parallelize(1 to 100, 4)
    val rdd2 = rdd1.map(_.toString)
    rdd2.setName("Target RDD")
    rdd2.count()

    sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
    assert(listener.addedExecutorInfo.size == 2)
    assert(listener.addedExecutorInfo("0").totalCores == 1)
    assert(listener.addedExecutorInfo("1").totalCores == 1)
  }

  private class SaveExecutorInfo extends SparkListener {
    val addedExecutorInfo = mutable.Map[String, ExecutorInfo]()

    override def onExecutorAdded(executor: SparkListenerExecutorAdded) {
      addedExecutorInfo(executor.executorId) = executor.executorInfo
    }
  }
} 
Example 192
Source File: BlockReplicationPolicySuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import scala.collection.mutable

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.{LocalSparkContext, SparkFunSuite}

class BlockReplicationPolicySuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  // Implicitly convert strings to BlockIds for test clarity.
  private implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)

  
  test(s"block replication - random block replication policy") {
    val numBlockManagers = 10
    val storeSize = 1000
    val blockManagers = (1 to numBlockManagers).map { i =>
      BlockManagerId(s"store-$i", "localhost", 1000 + i, None)
    }
    val candidateBlockManager = BlockManagerId("test-store", "localhost", 1000, None)
    val replicationPolicy = new RandomBlockReplicationPolicy
    val blockId = "test-block"

    (1 to 10).foreach {numReplicas =>
      logDebug(s"Num replicas : $numReplicas")
      val randomPeers = replicationPolicy.prioritize(
        candidateBlockManager,
        blockManagers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${randomPeers.mkString(", ")}")
      assert(randomPeers.toSet.size === numReplicas)

      // choosing n peers out of n
      val secondPass = replicationPolicy.prioritize(
        candidateBlockManager,
        randomPeers,
        mutable.HashSet.empty[BlockManagerId],
        blockId,
        numReplicas
      )
      logDebug(s"Random peers : ${secondPass.mkString(", ")}")
      assert(secondPass.toSet.size === numReplicas)
    }

  }

} 
Example 193
Source File: TopologyMapperSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.{File, FileOutputStream}

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark._
import org.apache.spark.util.Utils

class TopologyMapperSuite extends SparkFunSuite
  with Matchers
  with BeforeAndAfter
  with LocalSparkContext {

  test("File based Topology Mapper") {
    val numHosts = 100
    val numRacks = 4
    val props = (1 to numHosts).map{i => s"host-$i" -> s"rack-${i % numRacks}"}.toMap
    val propsFile = createPropertiesFile(props)

    val sparkConf = (new SparkConf(false))
    sparkConf.set("spark.storage.replication.topologyFile", propsFile.getAbsolutePath)
    val topologyMapper = new FileBasedTopologyMapper(sparkConf)

    props.foreach {case (host, topology) =>
      val obtainedTopology = topologyMapper.getTopologyForHost(host)
      assert(obtainedTopology.isDefined)
      assert(obtainedTopology.get === topology)
    }

    // we get None for hosts not in the file
    assert(topologyMapper.getTopologyForHost("host").isEmpty)

    cleanup(propsFile)
  }

  def createPropertiesFile(props: Map[String, String]): File = {
    val testFile = new File(Utils.createTempDir(), "TopologyMapperSuite-test").getAbsoluteFile
    val fileOS = new FileOutputStream(testFile)
    props.foreach{case (k, v) => fileOS.write(s"$k=$v\n".getBytes)}
    fileOS.close
    testFile
  }

  def cleanup(testFile: File): Unit = {
    testFile.getParentFile.listFiles.filter { file =>
      file.getName.startsWith(testFile.getName)
    }.foreach { _.delete() }
  }

} 
Example 194
Source File: LocalDirsSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.storage

import java.io.File

import org.scalatest.BeforeAndAfter

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.{SparkConfWithEnv, Utils}


class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter {

  before {
    Utils.clearLocalRootDirs()
  }

  test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") {
    // Regression test for SPARK-2974
    assert(!new File("/NONEXISTENT_DIR").exists())
    val conf = new SparkConf(false)
      .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

  test("SPARK_LOCAL_DIRS override also affects driver") {
    // Regression test for SPARK-2975
    assert(!new File("/NONEXISTENT_DIR").exists())
    // spark.local.dir only contains invalid directories, but that's not a problem since
    // SPARK_LOCAL_DIRS will override it on both the driver and workers:
    val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir")))
      .set("spark.local.dir", "/NONEXISTENT_PATH")
    assert(new File(Utils.getLocalDir(conf)).exists())
  }

} 
Example 195
Source File: JdbcRDDSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.rdd

import java.sql._

import org.scalatest.BeforeAndAfter

import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.util.Utils

class JdbcRDDSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {

  before {
    Utils.classForName("org.apache.derby.jdbc.EmbeddedDriver")
    val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true")
    try {

      try {
        val create = conn.createStatement
        create.execute("""
          CREATE TABLE FOO(
            ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1),
            DATA INTEGER
          )""")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)")
        (1 to 100).foreach { i =>
          insert.setInt(1, i * 2)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

      try {
        val create = conn.createStatement
        create.execute("CREATE TABLE BIGINT_TEST(ID BIGINT NOT NULL, DATA INTEGER)")
        create.close()
        val insert = conn.prepareStatement("INSERT INTO BIGINT_TEST VALUES(?,?)")
        (1 to 100).foreach { i =>
          insert.setLong(1, 100000000000000000L +  4000000000000000L * i)
          insert.setInt(2, i)
          insert.executeUpdate
        }
        insert.close()
      } catch {
        case e: SQLException if e.getSQLState == "X0Y32" =>
        // table exists
      }

    } finally {
      conn.close()
    }
  }

  test("basic functionality") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1, 100, 3,
      (r: ResultSet) => { r.getInt(1) } ).cache()

    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 10100)
  }

  test("large id overflow") {
    sc = new SparkContext("local", "test")
    val rdd = new JdbcRDD(
      sc,
      () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") },
      "SELECT DATA FROM BIGINT_TEST WHERE ? <= ID AND ID <= ?",
      1131544775L, 567279358897692673L, 20,
      (r: ResultSet) => { r.getInt(1) } ).cache()
    assert(rdd.count === 100)
    assert(rdd.reduce(_ + _) === 5050)
  }

  after {
    try {
      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
      case se: SQLException if se.getSQLState == "08006" =>
        // Normal single database shutdown
        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
} 
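The two '?' placeholders in the queries above are bound per partition: JdbcRDD splits the inclusive range between the lower and upper bound into numPartitions sub-ranges and runs the query once per sub-range. A rough, self-contained sketch of that bound-splitting arithmetic (illustrative only; the BigInt arithmetic keeps the huge bounds used by the "large id overflow" test from wrapping around):

object JdbcBoundsSketch {
  // Split [lowerBound, upperBound] into numPartitions inclusive ranges,
  // one (start, end) pair per partition, bound to the two '?' placeholders.
  def partitionBounds(lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[(Long, Long)] = {
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((BigInt(i) * length) / numPartitions)
      val end   = lowerBound + ((BigInt(i + 1) * length) / numPartitions) - 1
      (start.toLong, end.toLong)
    }
  }

  def main(args: Array[String]): Unit = {
    // For the "basic functionality" test above (bounds 1 and 100, 3 partitions)
    // this prints Vector((1,33), (34,66), (67,100)).
    println(partitionBounds(1, 100, 3))
  }
}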
Example 196
Source File: FutureActionSuite.scala    From multi-tenancy-spark   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark

import scala.concurrent.duration.Duration

import org.scalatest.{BeforeAndAfter, Matchers}

import org.apache.spark.util.ThreadUtils


class FutureActionSuite
  extends SparkFunSuite
  with BeforeAndAfter
  with Matchers
  with LocalSparkContext {

  before {
    sc = new SparkContext("local", "FutureActionSuite")
  }

  test("simple async action") {
    val rdd = sc.parallelize(1 to 10, 2)
    val job = rdd.countAsync()
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (10)
    job.jobIds.size should be (1)
  }

  test("complex async action") {
    val rdd = sc.parallelize(1 to 15, 3)
    val job = rdd.takeAsync(10)
    val res = ThreadUtils.awaitResult(job, Duration.Inf)
    res should be (1 to 10)
    job.jobIds.size should be (2)
  }

} 
Example 197
Source File: TwitterStreamSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.twitter


import org.scalatest.BeforeAndAfter
import twitter4j.Status
import twitter4j.auth.{NullAuthorization, Authorization}

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream

class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging {

  val batchDuration = Seconds(1)

  private val master: String = "local[2]"

  private val framework: String = this.getClass.getSimpleName

  test("twitter input stream") {
    val ssc = new StreamingContext(master, framework, batchDuration)
    val filters = Seq("filter1", "filter2")
    val authorization: Authorization = NullAuthorization.getInstance()

    // tests the API, does not actually test data receiving
    val test1: ReceiverInputDStream[Status] = TwitterUtils.createStream(ssc, None)
    val test2: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, None, filters)
    val test3: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, None, filters, StorageLevel.MEMORY_AND_DISK_SER_2)
    val test4: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, Some(authorization))
    val test5: ReceiverInputDStream[Status] =
      TwitterUtils.createStream(ssc, Some(authorization), filters)
    val test6: ReceiverInputDStream[Status] = TwitterUtils.createStream(
      ssc, Some(authorization), filters, StorageLevel.MEMORY_AND_DISK_SER_2)

    // Note that actually testing data receiving is hard, as authentication keys are
    // necessary for accessing the Twitter live stream
    ssc.stop()
  }
} 
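As the comment in the test points out, receiving real data requires Twitter credentials. A rough sketch of wiring those in (the credential strings are placeholders, and the twitter4j ConfigurationBuilder / OAuthAuthorization usage is assumed from that library's standard OAuth API rather than taken from this project):

import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.twitter.TwitterUtils
import twitter4j.auth.OAuthAuthorization
import twitter4j.conf.ConfigurationBuilder

object AuthenticatedTwitterStreamSketch {
  def main(args: Array[String]): Unit = {
    val ssc = new StreamingContext("local[2]", "twitter-auth-sketch", Seconds(1))

    // Placeholder credentials: replace with real OAuth keys from the Twitter developer console.
    val conf = new ConfigurationBuilder()
      .setOAuthConsumerKey("CONSUMER_KEY")
      .setOAuthConsumerSecret("CONSUMER_SECRET")
      .setOAuthAccessToken("ACCESS_TOKEN")
      .setOAuthAccessTokenSecret("ACCESS_TOKEN_SECRET")
      .build()

    // With real credentials the stream receives live statuses instead of only exercising the API.
    val stream = TwitterUtils.createStream(
      ssc, Some(new OAuthAuthorization(conf)), Seq("spark"), StorageLevel.MEMORY_AND_DISK_SER_2)
    stream.map(_.getText).print()

    ssc.start()
    ssc.awaitTerminationOrTimeout(30 * 1000)
    ssc.stop()
  }
}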
Example 198
Source File: FlumeStreamSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.flume

import java.net.{InetSocketAddress, ServerSocket}
import java.nio.ByteBuffer

import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.concurrent.duration._
import scala.language.postfixOps

import com.google.common.base.Charsets
import org.apache.avro.ipc.NettyTransceiver
import org.apache.avro.ipc.specific.SpecificRequestor
import org.apache.commons.lang3.RandomUtils
import org.apache.flume.source.avro
import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol}
import org.jboss.netty.channel.ChannelPipeline
import org.jboss.netty.channel.socket.SocketChannel
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory
import org.jboss.netty.handler.codec.compression._
import org.scalatest.{BeforeAndAfter, Matchers}
import org.scalatest.concurrent.Eventually._

import org.apache.spark.{Logging, SparkConf, SparkFunSuite}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream}
import org.apache.spark.util.Utils

class FlumeStreamSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging {
  val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite")

  var ssc: StreamingContext = null
  var transceiver: NettyTransceiver = null

  after {
    if (ssc != null) {
      ssc.stop()
    }
    if (transceiver != null) {
      transceiver.close()
    }
  }

  test("flume input stream") {
    testFlumeStream(testCompression = false)
  }

  test("flume input compressed stream") {
    testFlumeStream(testCompression = true)
  }

  
  private class CompressionChannelFactory(compressionLevel: Int)
    extends NioClientSocketChannelFactory {

    override def newChannel(pipeline: ChannelPipeline): SocketChannel = {
      val encoder = new ZlibEncoder(compressionLevel)
      pipeline.addFirst("deflater", encoder)
      pipeline.addFirst("inflater", new ZlibDecoder())
      super.newChannel(pipeline)
    }
  }
} 
Example 199
Source File: ListTablesSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql

import org.scalatest.BeforeAndAfter

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}

class ListTablesSuite extends QueryTest with BeforeAndAfter {

  import org.apache.spark.sql.test.TestSQLContext.implicits._

  val df =
    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")

  before {
    df.registerTempTable("ListTablesSuiteTable")
  }

  after {
    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
  }

  test("get all tables") {
    checkAnswer(
      tables().filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("getting all Tables with a database name has no impact on returned table names") {
    checkAnswer(
      tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    checkAnswer(
      sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
      Row("ListTablesSuiteTable", true))

    catalog.unregisterTable(Seq("ListTablesSuiteTable"))
    assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
  }

  test("query the returned DataFrame of tables") {
    val expectedSchema = StructType(
      StructField("tableName", StringType, false) ::
      StructField("isTemporary", BooleanType, false) :: Nil)

    Seq(tables(), sql("SHOW TABLes")).foreach {
      case tableDF =>
        assert(expectedSchema === tableDF.schema)

        tableDF.registerTempTable("tables")
        checkAnswer(
          sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
          Row(true, "ListTablesSuiteTable")
        )
        checkAnswer(
          tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
          Row("tables", true))
        dropTempTable("tables")
    }
  }
} 
Example 200
Source File: BagelSuite.scala    From iolap   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.bagel

import org.scalatest.{BeforeAndAfter, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._

import org.apache.spark._
import org.apache.spark.storage.StorageLevel

class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable

class BagelSuite extends SparkFunSuite with Assertions with BeforeAndAfter with Timeouts {

  var sc: SparkContext = _

  after {
    if (sc != null) {
      sc.stop()
      sc = null
    }
  }

  test("halting by voting") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
    val msgs = sc.parallelize(Array[(String, TestMessage)]())
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("halting by message silence") {
    sc = new SparkContext("local", "test")
    val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
    val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
    val numSupersteps = 5
    val result =
      Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
        (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
          val msgsOut =
            msgs match {
              case Some(ms) if (superstep < numSupersteps - 1) =>
                ms
              case _ =>
                Array[TestMessage]()
            }
          (new TestVertex(self.active, self.age + 1), msgsOut)
      }
    for ((id, vert) <- result.collect) {
      assert(vert.age === numSupersteps)
    }
  }

  test("large number of iterations") {
    // This tests whether jobs with a large number of iterations finish in a reasonable time,
    // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
    failAfter(30 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 50
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }

  test("using non-default persistence level") {
    failAfter(10 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 20
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }
}