org.scalatest.FlatSpec Scala Examples

The following examples show how to use org.scalatest.FlatSpec. Each example is taken from an open-source project; the source file, originating project, and license are noted above the code.
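Before the project-specific examples, the minimal sketch below (written for this page, not taken from any of the projects that follow) shows the FlatSpec style they all share: the subject is named once in a string, each behaviour is registered with should ... in { ... }, and further behaviours of the same subject are added with it should ... in { ... }.

import org.scalatest.{FlatSpec, Matchers}

// Minimal FlatSpec sketch; the class and assertions are illustrative only.
class ListAsStackSpec extends FlatSpec with Matchers {

  "A List used as a stack" should "return the most recently prepended element first" in {
    val stack = 3 :: 2 :: 1 :: Nil
    stack.head shouldEqual 3
  }

  // `it` refers to the subject registered above ("A List used as a stack").
  it should "be empty once all of its elements are dropped" in {
    val stack = List(1, 2, 3)
    stack.drop(3) shouldBe empty
  }
}

Mixing in Matchers, as most of the examples below do, provides the should/shouldEqual/shouldBe assertion syntax; a bare FlatSpec (as in Examples 24 and 25) falls back on assert and intercept.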
Example 1
Source File: NestedCaseClassesTest.scala    From cleanframes   with Apache License 2.0
package cleanframes

import com.holdenkarau.spark.testing.DataFrameSuiteBase
import org.apache.spark.sql.functions
import org.scalatest.{FlatSpec, Matchers}


class NestedCaseClassesTest
  extends FlatSpec
    with Matchers
    with DataFrameSuiteBase {

  "Cleaner" should "compile and use a custom transformer for a custom type" in {
    import cleanframes.syntax._ // to use `.clean`
    import spark.implicits._

    // define test data for a dataframe
    val input = Seq(
      // @formatter:off
      ("1",           "1",          "1",           "1",           null),
      (null,          "2",          null,          "2",           "corrupted"),
      ("corrupted",   null,         "corrupted",   null,          "true"),
      ("4",           "corrupted",  "4",           "4",           "false"),
      ("5",           "5",          "5",           "corrupted",   "false"),
      ("6",           "6",          "6",           "6",           "true")
      // @formatter:on
    )
      // give column names that are known to you
      .toDF("col1", "col2", "col3", "col4", "col5")

    // import standard functions for conversions shipped with the library
    import cleanframes.instances.all._

    // !important: you need to give a new structure to allow to access sub elements
    val renamed = input.select(
      functions.struct(
        input.col("col1") as "a_col_1",
        input.col("col2") as "a_col_2"
      ) as "a",
      functions.struct(
        input.col("col3") as "b_col_1",
        input.col("col4") as "b_col_2"
      ) as "b",
      input.col("col5") as "c"
    )

    val result = renamed.clean[AB]
      .as[AB]
      .collect

    result should {
      contain theSameElementsAs Seq(
        // @formatter:off
        AB( A(Some(1), Some(1)),  B(Some(1),  Some(1.0)), Some(false)),
        AB( A(None,    Some(2)),  B(None,     Some(2.0)), Some(false)),
        AB( A(None,    None),     B(None,     None),      Some(true)),
        AB( A(Some(4), None),     B(Some(4),  Some(4.0)), Some(false)),
        AB( A(Some(5), Some(5)),  B(Some(5),  None),      Some(false)),
        AB( A(Some(6), Some(6)),  B(Some(6),  Some(6.0)), Some(true))
        // @formatter:on
      )
    }
  }

}

case class A(a_col_1: Option[Int], a_col_2: Option[Float])

case class B(b_col_1: Option[Float], b_col_2: Option[Double])

case class AB(a: A, b: B, c: Option[Boolean]) 
Example 2
Source File: TanhSpec.scala    From BigDL   with Apache License 2.0
package com.intel.analytics.bigdl.nn

import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl._
import com.intel.analytics.bigdl.nn.tf.TanhGrad
import com.intel.analytics.bigdl.utils.T
import com.intel.analytics.bigdl.utils.serializer.ModuleSerializationTest
import org.scalatest.{FlatSpec, Matchers}

import scala.math.abs

@com.intel.analytics.bigdl.tags.Parallel
class TanhSpec extends FlatSpec with Matchers {
  "A Tanh Module " should "generate correct output and grad" in {
    val module = new Tanh[Double]()
    val input = Tensor[Double](2, 2, 2)
    input(Array(1, 1, 1)) = -0.17020166106522
    input(Array(1, 1, 2)) = 0.57785657607019
    input(Array(1, 2, 1)) = -1.3404131438583
    input(Array(1, 2, 2)) = 1.0938102817163
    input(Array(2, 1, 1)) = 1.120370157063
    input(Array(2, 1, 2)) = -1.5014141565189
    input(Array(2, 2, 1)) = 0.3380249235779
    input(Array(2, 2, 2)) = -0.625677742064
    val gradOutput = Tensor[Double](2, 2, 2)
    gradOutput(Array(1, 1, 1)) = 0.79903302760795
    gradOutput(Array(1, 1, 2)) = 0.019753993256018
    gradOutput(Array(1, 2, 1)) = 0.63136631483212
    gradOutput(Array(1, 2, 2)) = 0.29849314852618
    gradOutput(Array(2, 1, 1)) = 0.94380705454387
    gradOutput(Array(2, 1, 2)) = 0.030344664584845
    gradOutput(Array(2, 2, 1)) = 0.33804601291195
    gradOutput(Array(2, 2, 2)) = 0.8807330634445
    val expectedOutput = Tensor[Double](2, 2, 2)
    expectedOutput(Array(1, 1, 1)) = -0.16857698275003
    expectedOutput(Array(1, 1, 2)) = 0.52110579963112
    expectedOutput(Array(1, 2, 1)) = -0.87177144344863
    expectedOutput(Array(1, 2, 2)) = 0.79826462420686
    expectedOutput(Array(2, 1, 1)) = 0.80769763073281
    expectedOutput(Array(2, 1, 2)) = -0.90540347425835
    expectedOutput(Array(2, 2, 1)) = 0.32571298952384
    expectedOutput(Array(2, 2, 2)) = -0.55506882753488
    val expectedGrad = Tensor[Double](2, 2, 2)
    expectedGrad(Array(1, 1, 1)) = 0.77632594793144
    expectedGrad(Array(1, 1, 2)) = 0.014389771607755
    expectedGrad(Array(1, 2, 1)) = 0.15153710218424
    expectedGrad(Array(1, 2, 2)) = 0.1082854310036
    expectedGrad(Array(2, 1, 1)) = 0.32809049064441
    expectedGrad(Array(2, 1, 2)) = 0.0054694603766104
    expectedGrad(Array(2, 2, 1)) = 0.3021830658283
    expectedGrad(Array(2, 2, 2)) = 0.6093779706637
    val inputOrg = input.clone()
    val gradOutputOrg = gradOutput.clone()
    val output = module.forward(input)
    val gradInput = module.backward(input, gradOutput)
    expectedOutput.map(output, (v1, v2) => {
      assert(abs(v1 - v2) < 1e-6);
      v1
    })
    expectedGrad.map(gradInput, (v1, v2) => {
      assert(abs(v1 - v2) < 1e-6);
      v1
    })
    assert(input == inputOrg)
    assert(gradOutput == gradOutputOrg)
  }

  "A Tanh Module " should "be good in gradient check" in {
    val module = new Tanh[Double]()
    val input = Tensor[Double](2, 2, 2).rand()

    val checker = new GradientChecker(1e-4, 1e-2)
    checker.checkLayer[Double](module, input) should be(true)
  }
}

class TanhSerialTest extends ModuleSerializationTest {
  override def test(): Unit = {
    val module = TanhGrad[Float, Float]()

    val input = T(Tensor[Float](1, 5, 3, 4).rand(), Tensor[Float](1, 5, 3, 4).rand())

    runSerializationTest(module, input)
  }
} 
Example 3
Source File: CirceSpec.scala    From featherbed   with Apache License 2.0
package featherbed.circe

import cats.implicits._
import com.twitter.util.Future
import io.circe._
import io.circe.generic.auto._
import io.circe.parser.parse
import io.circe.syntax._
import org.scalatest.FlatSpec
import shapeless.{Coproduct, Witness}
import shapeless.union.Union

case class Foo(someText: String, someInt: Int)

class CirceSpec extends FlatSpec {

  "post request of a case class" should "derive JSON encoder" in {

    import com.twitter.util.{Future, Await}
    import com.twitter.finagle.{Service, Http}
    import com.twitter.finagle.http.{Request, Response}
    import java.net.InetSocketAddress

    val server = Http.serve(new InetSocketAddress(8766), new Service[Request, Response] {
      def apply(request: Request): Future[Response] = Future {
        val rep = Response()
        rep.contentString = s"${request.contentString}"
        rep.setContentTypeJson()
        rep
      }
    })

    import java.net.URL
    val client = new featherbed.Client(new URL("http://localhost:8766/api/"))

    import io.circe.generic.auto._

    val req = client.post("foo/bar")
      .withContent(Foo("Hello world!", 42), "application/json")
        .accept[Coproduct.`"application/json"`.T]

    val result = Await.result {
       req.send[Foo]()
    }

    Foo("test", 42).asJson.toString

    parse("""{"someText": "test", "someInt": 42}""").toValidated.map(_.as[Foo])

    Await.result(server.close())

  }

  "API example" should "compile" in {
    import shapeless.Coproduct
    import java.net.URL
    import com.twitter.util.Await
    case class Post(userId: Int, id: Int, title: String, body: String)

    case class Comment(postId: Int, id: Int, name: String, email: String, body: String)
    class JSONPlaceholderAPI(baseUrl: URL) {

      private val client = new featherbed.Client(baseUrl)
      type JSON = Coproduct.`"application/json"`.T

      object posts {

        private val listRequest = client.get("posts").accept[JSON]
        private val getRequest = (id: Int) => client.get(s"posts/$id").accept[JSON]

        def list(): Future[Seq[Post]] = listRequest.send[Seq[Post]]()
        def get(id: Int): Future[Post] = getRequest(id).send[Post]()
      }

      object comments {
        private val listRequest = client.get("comments").accept[JSON]
        private val getRequest = (id: Int) => client.get(s"comments/$id").accept[JSON]

        def list(): Future[Seq[Comment]] = listRequest.send[Seq[Comment]]()
        def get(id: Int): Future[Comment] = getRequest(id).send[Comment]()
      }
    }

    val apiClient = new JSONPlaceholderAPI(new URL("http://jsonplaceholder.typicode.com/"))

    Await.result(apiClient.posts.list())
  }

} 
Example 4
Source File: ReadOnlyNodeStorageSpec.scala    From mantis   with Apache License 2.0
package io.iohk.ethereum.db.storage

import akka.util.ByteString
import io.iohk.ethereum.db.dataSource.EphemDataSource
import org.scalatest.{FlatSpec, Matchers}

class ReadOnlyNodeStorageSpec extends FlatSpec with Matchers {

  "ReadOnlyNodeStorage" should "not update dataSource" in new TestSetup {
    readOnlyNodeStorage.put(ByteString("key1"), ByteString("Value1").toArray)
    dataSource.storage.size shouldEqual 0
  }

  it should "be able to read from underlying storage but not change it" in new TestSetup {
    val key1 = ByteString("key1")
    val val1 = ByteString("Value1").toArray
    referenceCountNodeStorage.put(key1, val1)

    val previousSize = dataSource.storage.size
    readOnlyNodeStorage.get(key1).get shouldEqual val1

    readOnlyNodeStorage.remove(key1)

    dataSource.storage.size shouldEqual previousSize
    readOnlyNodeStorage.get(key1).get shouldEqual val1
  }

  trait TestSetup {
    val dataSource = EphemDataSource()
    val nodeStorage = new NodeStorage(dataSource)

    val referenceCountNodeStorage = new ReferenceCountNodeStorage(nodeStorage, blockNumber = Some(1))
    val readOnlyNodeStorage = ReadOnlyNodeStorage(referenceCountNodeStorage)
  }
} 
Example 5
Source File: EthashSpec.scala    From mantis   with Apache License 2.0
package io.iohk.ethereum.consensus

import akka.util.ByteString
import io.iohk.ethereum.crypto.kec256
import org.scalacheck.Arbitrary
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, Matchers}
import org.spongycastle.util.encoders.Hex

class EthashSpec extends FlatSpec with Matchers with PropertyChecks {

  import Ethash._

  "Ethash" should "generate correct hash" in {
    forAll(Arbitrary.arbitrary[Long].filter(_ < 15000000)) { blockNumber =>
      seed(epoch(blockNumber)) shouldBe seedForBlockReference(blockNumber)
    }
  }

  it should "calculate cache size" in {
    val cacheSizes = Seq(16776896, 16907456, 17039296, 17170112, 17301056, 17432512, 17563072)
    cacheSizes.zipWithIndex.foreach { case (referenceSize, epoch) =>
      cacheSize(epoch) shouldBe referenceSize
    }
  }

  it should "compute proof of work using cache" in {
    val hash = Array(0xf5, 0x7e, 0x6f, 0x3a, 0xcf, 0xc0, 0xdd, 0x4b, 0x5b, 0xf2, 0xbe, 0xe4, 0x0a, 0xb3, 0x35, 0x8a, 0xa6, 0x87, 0x73, 0xa8, 0xd0, 0x9f, 0x5e, 0x59, 0x5e, 0xab, 0x55, 0x94, 0x05, 0x52, 0x7d, 0x72).map(_.toByte)
    val nonce = Array(0xd7, 0xb3, 0xac, 0x70, 0xa3, 0x01, 0xa2, 0x49).map(_.toByte)

    val mixHash = Array(0x1f, 0xff, 0x04, 0xce, 0xc9, 0x41, 0x73, 0xfd, 0x59, 0x1e, 0x3d, 0x89, 0x60, 0xce, 0x6b, 0xdf, 0x8b, 0x19, 0x71, 0x04, 0x8c, 0x71, 0xff, 0x93, 0x7b, 0xb2, 0xd3, 0x2a, 0x64, 0x31, 0xab, 0x6d).map(_.toByte)
    val boundary = Array(0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x3e, 0x9b, 0x6c, 0x69, 0xbc, 0x2c, 0xe2, 0xa2, 0x4a, 0x8e, 0x95, 0x69, 0xef, 0xc7, 0xd7, 0x1b, 0x33, 0x35, 0xdf, 0x36, 0x8c, 0x9a, 0xe9, 0x7e, 0x53, 0x84).map(_.toByte)

    val blockNumber = 486382
    val cache = makeCache(epoch(blockNumber))
    val proofOfWork = hashimotoLight(hash, nonce, dagSize(epoch(blockNumber)), cache)

    proofOfWork.mixHash shouldBe ByteString(mixHash)
    proofOfWork.difficultyBoundary shouldBe ByteString(boundary)

    val table = Table(
      ("blockNumber", "hashWithoutNonce", "nonce", "mixHash"),
      (3521,"269d13f7ca546dced28ee26071dcb61085b7c54dfc5f93808b94885e136cd616","534ab630b9aa1f68","c6913517d1dc7544febde9f17e65d6ae4fa380d4a2a2d31305b0043caf95e717"),
      (5021,"7bd6c3c49a0627712c51f1abf0a7828bb25ebb8679d2584385a191db955667da","413dc4ec1a0df7c4","35890608f8867402052b2ce55a694b86a44ce87e7fb5412a77a025b184f25883"),
      (5091,"5b27820bfa3a059274ce17db0beea90ba0b6fbe6b49d2a23cbf972e8cde79319","59225875d18ad606","46f72f8b269461078e9d1cf4edf1b608f9d101e0f335ea59568c3436f291d01b"),
      (3091,"c37d980124cf83a4de4d9600f5bb6d3883797b84b7ec472feff6ca855c01d245","745609efa9c4eef3","c647fec06481b9f3f74cd771968d6d630aa11bf75ebd9e3c55ccfbae0fbad4da"),
      (1091,"c1c1efb8fdd4241a55db39e092fedae3df6d4abc13133778810027ade6557bc6","4d9ddadaea6c20b2","53624a7faac2ec82208348f7a11e3b38c880a2fec76dd8b47e434fe641eeacde"),
      (109,"aa234d4bcee14e93d127275dcc83504b6e730a14e9110bd09b68e1964f0daad3","388e6b37c22147b7","df14701b1ad6d3d5639956e463250960de3189a726cb38d71a6f6042f45dea72"),
      (1009,"3259779f9d2c477d29e18ead0ccc829bf2146723563c3e81e5e4886673d93bfb","5faa044b70ccdf6b","a1f1af0c2ca3e1d8e69da59fefbfeb4d0d172ec96bdbdac71b2cde49ddb3a828"),
      (1001,"028cc9a70d6db52c2a2606f04392e9a323d0370291d6c6d78bc8ce54acf1d761","a54b5b31ce3de766","819c26573f1a9cd6c4b9a399b72fbfb0084a104b25b62083533e114ee98a4831"),
      (1000,"15c5729eb017a703c13d00752338f6b55e2d2551b380706f0486f2ccca57ae1e","eb610e766452a801","a369e2fd5c4e357cf9f60ba063ae0baf32075b0d7ed80cd78134bc401db8f1bf"),
      (100,"41944a94a42695180b1ca231720a87825f17d36475112b659c23dea1542e0977","37129c7f29a9364b","5bb43c0772e58084b221c8e0c859a45950c103c712c5b8f11d9566ee078a4501"))

    forAll(table) { (blockNumber, hashWithoutNonce, nonce, mixHash) =>
      val cache = makeCache(epoch(blockNumber))
      val proofOfWork = hashimotoLight(Hex.decode(hashWithoutNonce), Hex.decode(nonce), dagSize(epoch(blockNumber)), cache)
      proofOfWork.mixHash shouldBe ByteString(Hex.decode(mixHash))
    }
  }

  def seedForBlockReference(blockNumber: BigInt): ByteString = {
    if (blockNumber < EPOCH_LENGTH) {
      //wrong version from YP:
      //ByteString(kec256(Hex.decode("00" * 32)))
      //working version:
      ByteString(Hex.decode("00" * 32))
    } else {
      kec256(seedForBlockReference(blockNumber - EPOCH_LENGTH))
    }
  }

} 
Example 6
Source File: BlockchainSpec.scala    From mantis   with Apache License 2.0
package io.iohk.ethereum.domain

import akka.util.ByteString
import io.iohk.ethereum.Fixtures
import io.iohk.ethereum.blockchain.sync.EphemBlockchainTestSetup
import io.iohk.ethereum.db.storage.ArchiveNodeStorage
import io.iohk.ethereum.mpt.MerklePatriciaTrie
import org.scalatest.{FlatSpec, Matchers}

class BlockchainSpec extends FlatSpec with Matchers {

  "Blockchain" should "be able to store a block and return it if queried by hash" in new EphemBlockchainTestSetup {
    val validBlock = Fixtures.Blocks.ValidBlock.block
    blockchain.save(validBlock)
    val block = blockchain.getBlockByHash(validBlock.header.hash)
    assert(block.isDefined)
    assert(validBlock == block.get)
    val blockHeader = blockchain.getBlockHeaderByHash(validBlock.header.hash)
    assert(blockHeader.isDefined)
    assert(validBlock.header == blockHeader.get)
    val blockBody = blockchain.getBlockBodyByHash(validBlock.header.hash)
    assert(blockBody.isDefined)
    assert(validBlock.body == blockBody.get)
  }

  it should "be able to store a block and retrieve it by number" in new EphemBlockchainTestSetup {
    val validBlock = Fixtures.Blocks.ValidBlock.block
    blockchain.save(validBlock)
    val block = blockchain.getBlockByNumber(validBlock.header.number)
    assert(block.isDefined)
    assert(validBlock == block.get)
  }

  it should "be able to query a stored blockHeader by it's number" in new EphemBlockchainTestSetup {
    val validHeader = Fixtures.Blocks.ValidBlock.header
    blockchain.save(validHeader)
    val header = blockchain.getBlockHeaderByNumber(validHeader.number)
    assert(header.isDefined)
    assert(validHeader == header.get)
  }

  it should "not return a value if not stored" in new EphemBlockchainTestSetup {
    assert(blockchain.getBlockByNumber(Fixtures.Blocks.ValidBlock.header.number).isEmpty)
    assert(blockchain.getBlockByHash(Fixtures.Blocks.ValidBlock.header.hash).isEmpty)
  }

  it should "return an account given an address and a block number" in new EphemBlockchainTestSetup {
    val address = Address(42)
    val account = Account.empty(UInt256(7))

    val validHeader = Fixtures.Blocks.ValidBlock.header

    val emptyMpt = MerklePatriciaTrie[Address, Account](
      new ArchiveNodeStorage(storagesInstance.storages.nodeStorage)
    )

    val mptWithAcc = emptyMpt.put(address, account)
    val headerWithAcc = validHeader.copy(stateRoot = ByteString(mptWithAcc.getRootHash))

    blockchain.save(headerWithAcc)

    val retrievedAccount = blockchain.getAccount(address, headerWithAcc.number)
    retrievedAccount shouldEqual Some(account)
  }
} 
Example 7
Source File: SignedTransactionSpec.scala    From mantis   with Apache License 2.0
package io.iohk.ethereum.domain

import io.iohk.ethereum.crypto
import io.iohk.ethereum.crypto.generateKeyPair
import io.iohk.ethereum.domain.SignedTransaction.FirstByteOfAddress
import io.iohk.ethereum.nodebuilder.SecureRandomBuilder
import io.iohk.ethereum.vm.Generators
import org.scalacheck.Arbitrary
import org.scalatest.prop.PropertyChecks
import org.scalatest.{FlatSpec, Matchers}
import org.spongycastle.crypto.params.ECPublicKeyParameters

class SignedTransactionSpec extends FlatSpec with Matchers with PropertyChecks with SecureRandomBuilder {
  "SignedTransaction" should "correctly set pointSign for chainId with chain specific signing schema" in {
    forAll(Generators.transactionGen(), Arbitrary.arbitrary[Unit].map(_ => generateKeyPair(secureRandom))) {
      (tx, key) =>
        val chainId: Byte = 0x3d
        val allowedPointSigns = Set((chainId * 2 + 35).toByte, (chainId * 2 + 36).toByte)
        //byte 0 of encoded ECC point indicates that it is uncompressed point, it is part of spongycastle encoding
        val address = Address(crypto.kec256(key.getPublic.asInstanceOf[ECPublicKeyParameters].getQ.getEncoded(false).tail).drop(FirstByteOfAddress))
        val result = SignedTransaction.sign(tx, key, Some(chainId))

        allowedPointSigns should contain(result.signature.v)
        address shouldEqual result.senderAddress
    }
  }
} 
Example 8
Source File: EncryptedKeySpec.scala    From mantis   with Apache License 2.0
package io.iohk.ethereum.keystore

import org.scalatest.{FlatSpec, Matchers}
import io.iohk.ethereum.crypto
import io.iohk.ethereum.domain.Address
import io.iohk.ethereum.nodebuilder.SecureRandomBuilder

class EncryptedKeySpec extends FlatSpec with Matchers with SecureRandomBuilder {

  val gethKey =
    """{
      |  "id": "033b7a63-30f2-47fc-bbbe-d22925a14ab3",
      |  "address": "932245e1c40ec2026a2c7acc80befb68816cdba4",
      |  "crypto": {
      |    "cipher": "aes-128-ctr",
      |    "ciphertext": "8fb53f8695795d1f0480cad7954bd7a888392bb24c414b9895b4cb288b4897dc",
      |    "cipherparams": {
      |      "iv": "7a754cfd548a351aed270f6b1bfd306d"
      |    },
      |    "kdf": "scrypt",
      |    "kdfparams": {
      |      "dklen": 32,
      |      "n": 262144,
      |      "p": 1,
      |      "r": 8,
      |      "salt": "2321125eff8c3172a05a5947726004075b30e0a01534061fa5c13fb4e5e32465"
      |    },
      |    "mac": "6383677d3b0f34b1dcb9e1c767f8130daf6233266e35f28e00467af97bf2fbfa"
      |  },
      |  "version": 3
      |}
    """.stripMargin

  val parityKey =
    """{
      |  "id": "20909a42-09c4-0740-02dc-a0b8cbaea688",
      |  "version": 3,
      |  "Crypto": {
      |    "cipher": "aes-128-ctr",
      |    "cipherparams": {
      |      "iv": "68521b4d5fc5ecf83bbe24768a321fe5"
      |    },
      |    "ciphertext": "235ce2efb355a963eb838f42e3c30e59a00ab14030e66202d729f60fc7af57b3",
      |    "kdf": "pbkdf2",
      |    "kdfparams": {
      |      "c": 10240,
      |      "dklen": 32,
      |      "prf": "hmac-sha256",
      |      "salt": "b720278006c39e3ed01903421aad4ca3d3267f40ddc48bf4ec06429eb1c10fc5"
      |    },
      |    "mac": "f52a2c40173dea137695f40f0f6bed67fc7814e16639acdbff11092cc9f563d0"
      |  },
      |  "address": "04fecb5c49ee66fdbda1f196c120225bdd1ac35c",
      |  "name": "20909a42-09c4-0740-02dc-a0b8cbaea688",
      |  "meta": "{}"
      |}""".stripMargin


  "EncryptedKey" should "securely store private keys" in {
    val prvKey = crypto.secureRandomByteString(secureRandom, 32)
    val passphrase = "P4S5W0rd"
    val encKey = EncryptedKey(prvKey, passphrase, secureRandom)

    val json = EncryptedKeyJsonCodec.toJson(encKey)
    val decoded = EncryptedKeyJsonCodec.fromJson(json)

    decoded shouldEqual Right(encKey)
    decoded.flatMap(_.decrypt(passphrase)) shouldEqual Right(prvKey)
  }

  it should "decrypt a key encrypted by Geth" in {
    val encKey = EncryptedKeyJsonCodec.fromJson(gethKey)
    val prvKey = encKey.flatMap(_.decrypt("qwerty"))
    val address = prvKey.map(k => Address(crypto.kec256(crypto.pubKeyFromPrvKey(k))))
    address shouldEqual Right(Address("932245e1c40ec2026a2c7acc80befb68816cdba4"))
  }

  it should "decrypt a key encrypted by Parity" in {
    val encKey = EncryptedKeyJsonCodec.fromJson(parityKey)
    val prvKey = encKey.flatMap(_.decrypt("qwerty"))
    val address = prvKey.map(k => Address(crypto.kec256(crypto.pubKeyFromPrvKey(k))))
    address shouldEqual Right(Address("04fecb5c49ee66fdbda1f196c120225bdd1ac35c"))
  }
} 
Example 9
Source File: UtilsTest.scala    From sparkMeasure   with Apache License 2.0
package ch.cern.sparkmeasure

import java.io.File

import scala.collection.mutable.ListBuffer

import org.scalatest.{FlatSpec, Matchers}

class UtilsTest extends FlatSpec with Matchers {

  val stageVals0 = StageVals(jobId = 1, jobGroup = "test", stageId= 2, name = "stageVal",
    submissionTime = 10, completionTime = 11, stageDuration = 12, numTasks = 13,
    executorRunTime = 14, executorCpuTime = 15,
    executorDeserializeTime = 16, executorDeserializeCpuTime = 17,
    resultSerializationTime = 18, jvmGCTime = 19, resultSize = 20, numUpdatedBlockStatuses = 21,
    diskBytesSpilled = 30, memoryBytesSpilled = 31, peakExecutionMemory = 32, recordsRead = 33,
    bytesRead = 34, recordsWritten = 35, bytesWritten = 36,
    shuffleFetchWaitTime = 40, shuffleTotalBytesRead = 41, shuffleTotalBlocksFetched = 42,
    shuffleLocalBlocksFetched = 43, shuffleRemoteBlocksFetched = 44, shuffleWriteTime = 45,
    shuffleBytesWritten = 46, shuffleRecordsWritten = 47
  )

  val taskVals0 = TaskVals(jobId = 1, jobGroup = "test", stageId = 2, index = 3, launchTime = 4, finishTime = 5,
    duration = 10, schedulerDelay = 11, executorId = "exec0", host = "host0", taskLocality = 12,
    speculative = false, gettingResultTime = 12, successful = true,
    executorRunTime = 14, executorCpuTime = 15,
    executorDeserializeTime = 16, executorDeserializeCpuTime = 17,
    resultSerializationTime = 18, jvmGCTime = 19, resultSize = 20, numUpdatedBlockStatuses = 21,
    diskBytesSpilled = 30, memoryBytesSpilled = 31, peakExecutionMemory = 32, recordsRead = 33,
    bytesRead = 34, recordsWritten = 35, bytesWritten = 36,
    shuffleFetchWaitTime = 40, shuffleTotalBytesRead = 41, shuffleTotalBlocksFetched = 42,
    shuffleLocalBlocksFetched = 43, shuffleRemoteBlocksFetched = 44, shuffleWriteTime = 45,
    shuffleBytesWritten = 46, shuffleRecordsWritten = 47
  )

  it should "write and read back StageVal (Java Serialization)" in {
    val file = File.createTempFile("stageVal", ".tmp")
    try {
      IOUtils.writeSerialized(file.getAbsolutePath, ListBuffer(stageVals0))
      val stageVals = IOUtils.readSerializedStageMetrics(file.getAbsolutePath)
      stageVals should have length 1
      stageVals.head shouldEqual stageVals0
    } finally {
      file.delete()
    }
  }

  it should "write and read back TaskVal (Java Serialization)" in {
    val file = File.createTempFile("taskVal", ".tmp")
    try {
      IOUtils.writeSerialized(file.getAbsolutePath, ListBuffer(taskVals0))
      val taskVals = IOUtils.readSerializedTaskMetrics(file.getAbsolutePath)
      taskVals should have length 1
      taskVals.head shouldEqual taskVals0
    } finally {
      file.delete()
    }
  }

  it should "write and read back StageVal JSON" in {
    val file = File.createTempFile("stageVal", ".json")
    try {
      IOUtils.writeSerializedJSON(file.getAbsolutePath, ListBuffer(stageVals0))
      val stageVals = IOUtils.readSerializedStageMetricsJSON(file.getAbsolutePath)
      stageVals should have length 1
      stageVals.head shouldEqual stageVals0
    } finally {
      file.delete()
    }
  }

  it should "write and read back TaskVal JSON" in {
    val file = File.createTempFile("taskVal", ".json")
    try {
      IOUtils.writeSerializedJSON(file.getAbsolutePath, ListBuffer(taskVals0))
      val taskVals = IOUtils.readSerializedTaskMetricsJSON(file.getAbsolutePath)
      taskVals should have length 1
      taskVals.head shouldEqual taskVals0
    } finally {
      file.delete()
    }
  }

} 
Example 10
Source File: TestSpec.scala    From akka-serialization-test   with Apache License 2.0
package com.github.dnvriend

import akka.actor.{ ActorRef, ActorSystem, PoisonPill }
import akka.event.{ Logging, LoggingAdapter }
import akka.serialization.SerializationExtension
import akka.stream.{ ActorMaterializer, Materializer }
import akka.testkit.TestProbe
import akka.util.Timeout
import org.scalatest.concurrent.{ Eventually, ScalaFutures }
import org.scalatest.prop.PropertyChecks
import org.scalatest.{ BeforeAndAfterAll, FlatSpec, GivenWhenThen, Matchers }

import scala.concurrent.duration._
import scala.concurrent.{ ExecutionContext, Future }
import scala.util.Try

trait TestSpec extends FlatSpec
    with Matchers
    with GivenWhenThen
    with ScalaFutures
    with BeforeAndAfterAll
    with Eventually
    with PropertyChecks
    with AkkaPersistenceQueries
    with AkkaStreamUtils
    with InMemoryCleanup {

  implicit val timeout: Timeout = Timeout(10.seconds)
  implicit val system: ActorSystem = ActorSystem()
  implicit val ec: ExecutionContext = system.dispatcher
  implicit val mat: Materializer = ActorMaterializer()
  implicit val log: LoggingAdapter = Logging(system, this.getClass)
  implicit val pc: PatienceConfig = PatienceConfig(timeout = 50.seconds)
  val serialization = SerializationExtension(system)

  implicit class FutureToTry[T](f: Future[T]) {
    def toTry: Try[T] = Try(f.futureValue)
  }

  def killActors(actors: ActorRef*): Unit = {
    val probe = TestProbe()
    actors.foreach { actor ⇒
      probe watch actor
      actor ! PoisonPill
      probe expectTerminated actor
    }
  }

  override protected def afterAll(): Unit = {
    system.terminate()
    system.whenTerminated.toTry should be a 'success
  }
} 
Example 11
Source File: InterpreterLiteralExpressionTest.scala    From feel-scala   with Apache License 2.0
package org.camunda.feel.interpreter.impl

import org.scalatest.{FlatSpec, Matchers}
import org.camunda.feel.syntaxtree._


class InterpreterLiteralExpressionTest
    extends FlatSpec
    with Matchers
    with FeelIntegrationTest {

  "A literal" should "be a number" in {

    eval("2") should be(ValNumber(2))
    eval("2.4") should be(ValNumber(2.4))
    eval("-3") should be(ValNumber(-3))
  }

  it should "be a string" in {

    eval(""" "a" """) should be(ValString("a"))
  }

  it should "be a boolean" in {

    eval("true") should be(ValBoolean(true))
  }

  it should "be null" in {

    eval("null") should be(ValNull)
  }

  it should "be a context" in {

    eval("{ a : 1 }")
      .asInstanceOf[ValContext]
      .context
      .variableProvider
      .getVariables should be(Map("a" -> ValNumber(1)))

    eval("""{ a:1, b:"foo" }""")
      .asInstanceOf[ValContext]
      .context
      .variableProvider
      .getVariables should be(Map("a" -> ValNumber(1), "b" -> ValString("foo")))

    // nested
    val nestedContext = eval("{ a : { b : 1 } }")
      .asInstanceOf[ValContext]
      .context
      .variableProvider
      .getVariable("a")
      .get

    nestedContext shouldBe a[ValContext]
    nestedContext
      .asInstanceOf[ValContext]
      .context
      .variableProvider
      .getVariables should be(Map("b" -> ValNumber(1)))
  }

  it should "be a list" in {

    eval("[1]") should be(ValList(List(ValNumber(1))))

    eval("[1,2]") should be(ValList(List(ValNumber(1), ValNumber(2))))

    // nested
    eval("[ [1], [2] ]") should be(
      ValList(List(ValList(List(ValNumber(1))), ValList(List(ValNumber(2))))))
  }

} 
Example 12
Source File: FacetedLuceneRDDImplicitsSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd.facets

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.FavoriteCaseClass
import org.zouzias.spark.lucenerdd.{LuceneRDD, LuceneRDDKryoRegistrator}

class FacetedLuceneRDDImplicitsSpec  extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  var luceneRDD: LuceneRDD[_] = _


  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  override def afterEach() {
    luceneRDD.close()
  }


  val elem = Array("fear", "death", "water", "fire", "house")
    .zipWithIndex.map{ case (str, index) =>
    FavoriteCaseClass(str, index, 10L, 12.3F, s"${str}@gmail.com")}


  "FacetedLuceneRDD(case class).count" should "return correct number of elements" in {
    val rdd = sc.parallelize(elem)
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    val df = rdd.toDF()
    luceneRDD = FacetedLuceneRDD(df)
    luceneRDD.count should equal (elem.size)
  }

  "FacetedLuceneRDD(case class).fields" should "return all fields" in {
    val rdd = sc.parallelize(elem)
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    val df = rdd.toDF()
    luceneRDD = FacetedLuceneRDD(df)

    luceneRDD.fields().size should equal(5)
    luceneRDD.fields().contains("name") should equal(true)
    luceneRDD.fields().contains("age") should equal(true)
    luceneRDD.fields().contains("myLong") should equal(true)
    luceneRDD.fields().contains("myFloat") should equal(true)
    luceneRDD.fields().contains("email") should equal(true)
  }

  "FacetedLuceneRDD(case class).termQuery" should "correctly search with TermQueries" in {
    val rdd = sc.parallelize(elem)
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._
    val df = rdd.toDF()
    luceneRDD = FacetedLuceneRDD(df)

    val results = luceneRDD.termQuery("name", "water")
    results.count() should equal(1)
  }
} 
Example 13
Source File: LucenePrimitiveTypesSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}

class LucenePrimitiveTypesSpec extends FlatSpec with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  def randomString(length: Int): String = scala.util.Random.alphanumeric.take(length).mkString
  val array = (1 to 24).map(randomString(_))

  var luceneRDD: LuceneRDD[_] = _

  override def afterEach() {
    luceneRDD.close()
  }

  

  "LuceneRDD" should "work with RDD[Array[String]]" in {
    val array = Array(Array("aaa", "aaa2"), Array("bbb", "bbb2"),
      Array("ccc", "ccc2"), Array("ddd"), Array("eee"))
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.length)
  }

  "LuceneRDD" should "work with RDD[Set[String]]" in {
    val array = Array(Set("aaa", "aaa2"), Set("bbb", "bbb2"),
      Set("ccc", "ccc2"), Set("ddd"), Set("eee"))
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.length)
  }

  "LuceneRDD" should "work with RDD[String]" in {
    val array = Array("aaa", "bbb", "ccc", "ddd", "eee")
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.length)
  }

  "LuceneRDD" should "work with RDD[Int]" in {
    val array = (1 to 22)
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.size)
  }

  "LuceneRDD" should "work with RDD[Float]" in {
    val array: IndexedSeq[Float] = (1 to 22).map(_.toFloat)
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.size)
  }

  "LuceneRDD" should "work with RDD[Double]" in {
    val array: IndexedSeq[Double] = (1 to 22).map(_.toDouble)
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.size)
  }

  "LuceneRDD" should "work with RDD[Long]" in {
    val array: IndexedSeq[Long] = (1 to 22).map(_.toLong)
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should equal (array.size)
  }

  "LuceneRDD" should "work with RDD[Map[String, String]]" in {
    val maps = List(Map( "a" -> "hello"), Map("b" -> "world"), Map("c" -> "how are you"))
    val rdd = sc.parallelize(maps)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should equal (maps.size)
    luceneRDD.termQuery("a", "hello").isEmpty() should equal (false)
    luceneRDD.prefixQuery("b", "wor").isEmpty() should equal (false)
    luceneRDD.prefixQuery("a", "no").isEmpty() should equal (true)
  }

  "LuceneRDD" should "work with RDD[String] and ignore null values" in {
    val array = Array("aaa", null, "ccc", null, "eee")
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should be (array.length)
  }

} 
Example 14
Source File: BlockingLinkageSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.lucene.index.Term
import org.apache.lucene.search.{Query, TermQuery}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.Person

class BlockingLinkageSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  override val conf: SparkConf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  "LuceneRDD.blockEntityLinkage" should "deduplicate elements on unique elements" in {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val peopleLeft: Array[Person] = Array("fear", "death", "water", "fire", "house")
      .zipWithIndex.map { case (str, index) =>
      val email = if (index % 2 == 0) "[email protected]" else "[email protected]"
      Person(str, index, email)
    }

    val peopleRight: Array[Person] = Array("fear", "death", "water", "fire", "house")
      .zipWithIndex.map { case (str, index) =>
      val email = if (index % 2 == 0) "[email protected]" else "[email protected]"
      Person(str, index, email)
    }

    val leftDF = sc.parallelize(peopleLeft).repartition(2).toDF()
    val rightDF = sc.parallelize(peopleRight).repartition(3).toDF()

    // Define a Lucene Term linker
    val linker: Row => Query = { row =>
      val name = row.getString(row.fieldIndex("name"))
      val term = new Term("name", name)

      new TermQuery(term)
    }


    val linked = LuceneRDD.blockEntityLinkage(leftDF, rightDF, linker,
      Array("email"), Array("email"))

    val (linkedCount, dfCount) = (linked.count, leftDF.count())

    linkedCount should equal(dfCount)

    // Check for correctness
    // Age is a unique index
    linked.collect().foreach { case (row, results) =>
      val leftAge = row.getInt(row.fieldIndex("age"))
      val rightAge = results.headOption.map(x => x.getInt(x.fieldIndex("age")))

      rightAge should equal(Some(leftAge))

    }
  }
} 
Example 15
Source File: LuceneRDDCustomCaseClassImplicitsSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.Person

class LuceneRDDCustomCaseClassImplicitsSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  var luceneRDD: LuceneRDD[_] = _

  override def afterEach() {
    luceneRDD.close()
  }

  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  val elem: Array[Person] = Array("fear", "death", "water", "fire", "house")
    .zipWithIndex.map{ case (str, index) => Person(str, index, s"${str}@gmail.com")}

  "LuceneRDD(case class).count" should "handle nulls properly" in {
    val elemsWithNulls = Array("fear", "death", "water", "fire", "house")
      .zipWithIndex.map{ case (str, index) => Person(str, index, null)}
    val rdd = sc.parallelize(elemsWithNulls)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count() should equal (elemsWithNulls.length)
  }

  "LuceneRDD(case class).count" should "return correct number of elements" in {
    val rdd = sc.parallelize(elem)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count() should equal (elem.length)
  }

  "LuceneRDD(case class).fields" should "return all fields" in {
    val rdd = sc.parallelize(elem)
    luceneRDD = LuceneRDD(rdd)

    luceneRDD.fields().size should equal(3)
    luceneRDD.fields().contains("name") should equal(true)
    luceneRDD.fields().contains("age") should equal(true)
    luceneRDD.fields().contains("email") should equal(true)
  }

  "LuceneRDD(case class).termQuery" should "correctly search with TermQueries" in {
    val rdd = sc.parallelize(elem)
    luceneRDD = LuceneRDD(rdd)

    val results = luceneRDD.termQuery("name", "water")
    results.count() should equal(1)
  }
} 
Example 16
Source File: ShapeLuceneRDDImplicitsSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd.spatial.shape.implicits

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.spatial.shape.{ShapeLuceneRDD, _}
import org.zouzias.spark.lucenerdd.testing.LuceneRDDTestUtils
import org.zouzias.spark.lucenerdd._
import org.zouzias.spark.lucenerdd.spatial.shape.context.ContextLoader

class ShapeLuceneRDDImplicitsSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext
  with ContextLoader
  with LuceneRDDTestUtils {

  val Radius: Double = 5D

  override val conf = ShapeLuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  "ShapeLuceneRDDImplicits" should "implicitly convert to point" in {

    val rdd = sc.parallelize(cities)
    val shapeRDD = ShapeLuceneRDD(rdd)

    shapeRDD.count should equal(cities.length)
  }

  "ShapeLuceneRDDImplicits" should "implicitly convert to circle" in {

    val circleCities: Array[(((Double, Double), Double), String)] =
      cities.map(convertToCircle)
    val rdd = sc.parallelize(circleCities)
    val shapeRDD = ShapeLuceneRDD(rdd)

    shapeRDD.count should equal(circleCities.length)
  }

  "ShapeLuceneRDDImplicits" should "implicitly convert to rectangle" in {

    val rectangleCities = cities.map(convertToRectangle)
    val rdd = sc.parallelize(rectangleCities)
    val shapeRDD = ShapeLuceneRDD(rdd)

    shapeRDD.count should equal(rectangleCities.length)
  }

  "ShapeLuceneRDDImplicits" should "implicitly convert POINTS from WKT" in {
    val sparkSession = SparkSession.builder().getOrCreate()
    val citiesDF = sparkSession.read.parquet("data/world-cities-points.parquet")
    import sparkSession.implicits._
    val citiesRDD = citiesDF.map(row =>
      (row.getString(2), (row.getString(0), row.getString(1))))

    val total = citiesDF.count()
    total > 0 should equal(true)

    val shapeRDD = ShapeLuceneRDD(citiesRDD)

    shapeRDD.count > 0 should equal(true)
  }

  "ShapeLuceneRDDImplicits" should "implicitly convert BBOX from WKT" in {
    val sparkSession = SparkSession.builder().getOrCreate()
    import sparkSession.implicits._
    val countriesDF = sparkSession.read.parquet("data/countries-bbox.parquet")
    val citiesRDD = countriesDF.map(row =>
      (row.getString(2), (row.getString(0), row.getString(1))))

    val total = countriesDF.count()
    total > 0 should equal(true)

    val shapeRDD = ShapeLuceneRDD(citiesRDD)

    shapeRDD.count > 0 should equal(true)
  }

  "ShapeLuceneRDDImplicits" should "implicitly convert to polygon" in {

    val polygonCities = cities.map(convertToPolygon(_, Radius))
    val rdd = sc.parallelize(polygonCities)
    val shapeRDD = ShapeLuceneRDD(rdd)

    shapeRDD.count should equal(polygonCities.length)
  }

} 
Example 17
Source File: AnalyzersConfigurableSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd.analyzers

import org.apache.lucene.analysis.en.EnglishAnalyzer
import org.apache.lucene.analysis.el.GreekAnalyzer
import org.apache.lucene.analysis.de.GermanAnalyzer
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}

class AnalyzersConfigurableSpec extends FlatSpec with Matchers
  with BeforeAndAfterEach
  with AnalyzerConfigurable {

  "AnalyzersConfigurable.getAnalyzer" should "return english analyzer with 'en' input" in {
    val englishAnalyzer = getAnalyzer(Some("en"))
    englishAnalyzer shouldNot equal(null)
    englishAnalyzer.isInstanceOf[EnglishAnalyzer] should equal(true)
  }

  "AnalyzersConfigurable.getAnalyzer" should
    "return custom test analyzer with 'org.apache.lucene.analysis.el.GreekAnalyzer'" in {
    val greekAnalyzer = getAnalyzer(Some("org.apache.lucene.analysis.el.GreekAnalyzer"))
    greekAnalyzer shouldNot equal(null)
    greekAnalyzer.isInstanceOf[GreekAnalyzer] should equal(true)
  }

  "AnalyzersConfigurable.getAnalyzer" should
    "return custom test analyzer with 'org.apache.lucene.analysis.de.GermanAnalyzer'" in {
    val deutschAnalyzer = getAnalyzer(Some("org.apache.lucene.analysis.de.GermanAnalyzer"))
    deutschAnalyzer shouldNot equal(null)
    deutschAnalyzer.isInstanceOf[GermanAnalyzer] should equal(true)
  }
} 
Example 18
Source File: LuceneRDDSearchSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.LuceneRDDTestUtils

class LuceneRDDSearchSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with LuceneRDDTestUtils
  with SharedSparkContext {

  var luceneRDD: LuceneRDD[_] = _

  override def Radius: Double = 0

  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  override def afterEach() {
    luceneRDD.close()
  }


  val First = "_1"

  val array = List("fear", "death", " apologies", "romance", "tree", "fashion", "fascism")

  "LuceneRDD.query" should "use phrase query syntax" in {
    val words = Array("aabaa", "aaacaa", "aadaa", "aaaa", "qwerty")
    val rdd = sc.parallelize(words)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.query("_1:aadaa").isEmpty() should equal (false)
    luceneRDD.query("_1:aa*").count() should equal (4)
    luceneRDD.query("_1:q*").count() should equal (1)
  }

  "LuceneRDD.count" should "return correct number of elements" in {
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should equal (array.size)
  }

  "LuceneRDD.termQuery" should "correctly search with TermQueries" in {
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(First, array(1))
    results.count() should equal (1)
  }

  "LuceneRDD.prefixQuery" should "correctly search with PrefixQueries" in {

    val prefices = Array("aaaabcd", "aaadcb", "aaz", "az", "qwerty")
    val rdd = sc.parallelize(prefices)
    luceneRDD = LuceneRDD(rdd)

    luceneRDD.prefixQuery(First, "a").count() should equal (4)
    luceneRDD.prefixQuery(First, "aa").count() should equal(3)
    luceneRDD.prefixQuery(First, "aaa").count() should equal (2)
    luceneRDD.prefixQuery(First, "aaaa").count() should equal (1)
  }

  "LuceneRDD.fuzzyQuery" should "correctly search with FuzzyQuery" in {
    val rdd = sc.parallelize(array)
    luceneRDD = LuceneRDD(rdd)

    luceneRDD.fuzzyQuery(First, "fear", 1).count() should equal (1)
    luceneRDD.fuzzyQuery(First, "fascsm", 1).count() should equal(1)
    luceneRDD.fuzzyQuery(First, "dath", 1).count() should equal (1)
    luceneRDD.fuzzyQuery(First, "tree", 1).count() should equal (1)
  }

  

  "LuceneRDD.phraseQuery" should "correctly search with PhraseQuery" in {
    val phrases = Array("hello world", "the company name was", "highlight lucene")
    val rdd = sc.parallelize(phrases)
    luceneRDD = LuceneRDD(rdd)

    luceneRDD.phraseQuery(First, "company name", 10).count() should equal (1)
    luceneRDD.phraseQuery(First, "hello world", 10).count() should equal (1)
    luceneRDD.phraseQuery(First, "highlight lucene", 10).count() should equal(1)
  }
} 
Example 19
Source File: BlockingDedupSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.lucene.index.Term
import org.apache.lucene.search.{Query, TermQuery}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.Person

class BlockingDedupSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  override val conf: SparkConf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  "LuceneRDD.blockDedup" should "deduplicate elements on unique elements" in {
    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val people: Array[Person] = Array("fear", "death", "water", "fire", "house")
      .zipWithIndex.map { case (str, index) =>
      val email = if (index % 2 == 0) "[email protected]" else "[email protected]"
      Person(str, index, email)
    }
    val df = sc.parallelize(people).repartition(2).toDF()

    val linker: Row => Query = { row =>
      val name = row.getString(row.fieldIndex("name"))
      val term = new Term("name", name)

      new TermQuery(term)
    }


    val linked = LuceneRDD.blockDedup(df, linker, Array("email"))

    val (linkedCount, dfCount) = (linked.count, df.count())

    linkedCount should equal(dfCount)

    // Check for correctness
    // Age is a unique index
    linked.collect().foreach { case (row, results) =>
      val leftAge = row.getInt(row.fieldIndex("age"))
      val rightAge = results.headOption.map(x => x.getInt(x.fieldIndex("age")))

      rightAge should equal(Some(leftAge))

    }
  }
} 
Example 20
Source File: LuceneDocToSparkRowpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import java.io.{Reader, StringReader}

import org.apache.lucene.document.{Document, DoublePoint, Field, FloatPoint, IntPoint, LongPoint, StoredField, TextField}
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc
import org.zouzias.spark.lucenerdd.models.SparkScoreDoc.{DocIdField, ScoreField, ShardField}

import scala.collection.JavaConverters._

class LuceneDocToSparkRowpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach {

  val (score: Float, docId: Int, shardIndex: Int) = (1.0f, 1, 2)
  val float: Float = 20.001f
  val double: Double = 10.1000000001D

  def generate_doc(): Document = {
    val doc = new Document()

    // Add long field
    doc.add(new LongPoint("longField", 10))
    doc.add(new StoredField("longField", 10))

    doc.add(new FloatPoint("floatField", float))
    doc.add(new StoredField("floatField", float))

    doc.add(new IntPoint("intField", 9))
    doc.add(new StoredField("intField", 9))

    doc.add(new DoublePoint("doubleField", double))
    doc.add(new StoredField("doubleField", double))

    doc.add(new TextField("textField", "hello world", Field.Store.NO))
    doc.add(new StoredField("textField", "hello world"))

    doc
  }

  private val doc: Document = generate_doc()

  val sparkScoreDoc = SparkScoreDoc(score, docId, shardIndex, doc)


  "SparkScoreDoc.toRow" should "return correct score" in {
    val row = sparkScoreDoc.toRow()
    row.getFloat(row.fieldIndex(ScoreField)) should equal(score)
  }

  "SparkScoreDoc.toRow" should "return correct docId" in {
    val row = sparkScoreDoc.toRow()
    row.getInt(row.fieldIndex(DocIdField)) should equal(docId)
  }

  "SparkScoreDoc.toRow" should "return correct shard number" in {
    val row = sparkScoreDoc.toRow()
    row.getInt(row.fieldIndex(ShardField)) should equal(shardIndex)
  }

  "SparkScoreDoc.toRow" should "return correct number of fields" in {
    val row = sparkScoreDoc.toRow()
    row.getFields().asScala.count(_.fieldType().stored()) should equal(8)
  }

  "SparkScoreDoc.toRow" should "set correctly DoublePoint" in {
    val row = sparkScoreDoc.toRow()
    row.getDouble(row.fieldIndex("doubleField")) should equal(double)
  }

  "SparkScoreDoc.toRow" should "set correctly FloatPoint" in {
    val row = sparkScoreDoc.toRow()
    row.getFloat(row.fieldIndex("floatField")) should equal(float)
  }
} 
Example 21
Source File: LuceneRDDTermVectorsSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import org.zouzias.spark.lucenerdd.testing.LuceneRDDTestUtils

class LuceneRDDTermVectorsSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with LuceneRDDTestUtils
  with SharedSparkContext {

  var luceneRDD: LuceneRDD[_] = _

  override def Radius: Double = 0

  override val conf: SparkConf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  override def afterEach() {
    luceneRDD.close()
  }

  val First = "_1"

  "LuceneRDD.termVectors" should "return valid terms" in {

    val words = Array("To smile or not to smile smile",
      "Don't cry because it's over, smile because it happened",
      "So many books, so little time",
      "A room without books is like a body without a soul",
      "If you tell the truth, you don't have to remember anything")
    val rdd = sc.parallelize(words)

    luceneRDD = LuceneRDD(rdd)

    val terms = luceneRDD.termVectors(First).collect()

    // These terms should exist
    terms.exists(_.term.compareToIgnoreCase("time") == 0) should equal(true)
    terms.exists(_.term.compareToIgnoreCase("room") == 0) should equal(true)
    terms.exists(_.term.compareToIgnoreCase("soul") == 0) should equal(true)
    terms.exists(_.term.compareToIgnoreCase("smile") == 0) should equal(true)

    terms.exists(t => (t.term.compareToIgnoreCase("smile") == 0)
      && t.count == 3) should equal (true)
    terms.exists(t => (t.term.compareToIgnoreCase("becaus") == 0)
      && t.count == 2) should equal (true)
  }
} 
Example 22
Source File: LuceneRDDTuplesSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import org.scalatest.{FlatSpec, Matchers}

class LuceneRDDTuplesSpec extends FlatSpec with Matchers with SharedSparkContext {

  val First = "_1"
  val Second = "_2"

  val array = List("fear", "death", " apology", "romance", "tree", "fashion", "fascism")


  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  "LuceneRDD" should "work with Tuple2" in {
    val rdd = sc.parallelize(array).map(x => (x, x))
    val luceneRDD = LuceneRDD(rdd)
    luceneRDD.count should equal (array.size)
  }

  "LuceneRDD" should "work with Tuple3" in {
    val rdd = sc.parallelize(array).map(x => (x, x, x))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(Second, array(1))
    results.count should equal (1)
  }

  "LuceneRDD" should "work with Tuple4" in {
    val rdd = sc.parallelize(array).map(x => (x, x, x, x))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(Second, array(1))
    results.count should equal (1)
  }

  "LuceneRDD" should "work with Tuple5" in {
    val rdd = sc.parallelize(array).map(x => (x, x, x, x, x))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(Second, array(1))
    results.count should equal (1)
  }

  "LuceneRDD" should "work with Tuple6" in {
    val rdd = sc.parallelize(array).map(x => (x, x, x, x, x, x))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(Second, array(1))
    results.count should equal (1)
  }

  "LuceneRDD" should "work with Tuple7" in {
    val rdd = sc.parallelize(array).map(x => (x, x, 2.0d, 1.0d, x, 1, x))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(First, array.head)
    results.count should equal (1)
  }

  "LuceneRDD" should "work with Tuple8" in {
    val rdd = sc.parallelize(array).map(x => (x, x, 2.0d, 1.0d, x, 1, x, 3.4))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(First, array(1))
    results.count should equal (1)
  }

  "LuceneRDD" should "work with mixed types in Tuples" in {
    val rdd = sc.parallelize(array).map(x => (x, 1, x, 2L, x, 3.0F))
    val luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD.termQuery(First, array(1))
    results.count should equal (1)
  }
} 
Example 23
Source File: LuceneRDDMoreLikeThisSpec.scala    From spark-lucenerdd   with Apache License 2.0
package org.zouzias.spark.lucenerdd

import com.holdenkarau.spark.testing.SharedSparkContext
import org.apache.spark.SparkConf
import scala.collection.JavaConverters._
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}

import scala.io.Source

class LuceneRDDMoreLikeThisSpec extends FlatSpec
  with Matchers
  with BeforeAndAfterEach
  with SharedSparkContext {

  var luceneRDD: LuceneRDD[_] = _


  override val conf = LuceneRDDKryoRegistrator.registerKryoClasses(new SparkConf().
    setMaster("local[*]").
    setAppName("test").
    set("spark.ui.enabled", "false").
    set("spark.app.id", appID))

  override def afterEach() {
    luceneRDD.close()
  }

  "LuceneRDD.moreLikeThis" should "return relevant documents" in {
    val words: Seq[String] = Source.fromFile("src/test/resources/alice.txt")
      .getLines().map(_.toLowerCase).toSeq
    val rdd = sc.parallelize(words)
    luceneRDD = LuceneRDD(rdd)
    val results = luceneRDD
      .moreLikeThis("_1", "alice adventures wonderland", 1, 1)
      .collect()

    results.length > 0 should equal(true)
    val firstDoc = results.head
    val x = firstDoc.getString(firstDoc.fieldIndex("_1"))

    x.contains("alice") &&
      x.contains("wonderland") &&
      x.contains("adventures") should equal(true)

    val lastDoc = results.last
    val y = lastDoc.getString(lastDoc.fieldIndex("_1"))

    y.contains("alice") &&
      !y.contains("wonderland") &&
      !y.contains("adventures") should equal(true)

  }
} 
Example 24
Source File: ProxyOptionsTest.scala    From jvm-toxcore-api   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.options

import im.tox.tox4j.core.enums.ToxProxyType
import org.scalatest.FlatSpec

@SuppressWarnings(Array("org.wartremover.warts.Equals"))
final class ProxyOptionsTest extends FlatSpec {

  "proxy options" should "not allow negative ports" in {
    intercept[IllegalArgumentException] {
      ProxyOptions.Http("localhost", -1)
    }
    intercept[IllegalArgumentException] {
      ProxyOptions.Socks5("localhost", -1)
    }
  }

  it should "allow the port to be 0" in {
    ProxyOptions.Http("localhost", 0)
    ProxyOptions.Socks5("localhost", 0)
  }

  it should "allow the port to be 65535" in {
    ProxyOptions.Http("localhost", 65535)
    ProxyOptions.Socks5("localhost", 65535)
  }

  it should "not allow the port to be greater than 65535" in {
    intercept[IllegalArgumentException] {
      ProxyOptions.Http("localhost", 65536)
    }
    intercept[IllegalArgumentException] {
      ProxyOptions.Socks5("localhost", 65536)
    }
  }

  it should "produce the right low level enum values" in {
    assert(ProxyOptions.None.proxyType == ToxProxyType.NONE)
    assert(ProxyOptions.Http("localhost", 1).proxyType == ToxProxyType.HTTP)
    assert(ProxyOptions.Socks5("localhost", 1).proxyType == ToxProxyType.SOCKS5)
  }

} 
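
The boundary checks above (reject -1 and 65536, accept 0 and 65535) amount to a simple range guard. A minimal sketch, assuming ProxyOptions enforces it with a require(...) call (the real message and placement may differ):

def requireValidPort(port: Int): Unit =
  require(port >= 0 && port <= 65535, s"port out of range [0, 65535]: $port")

requireValidPort(0)       // accepted
requireValidPort(65535)   // accepted
try { requireValidPort(65536); sys.error("expected an IllegalArgumentException") }
catch { case _: IllegalArgumentException => () }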
Example 25
Source File: ToxOptionsTest.scala    From jvm-toxcore-api   with GNU General Public License v3.0 5 votes vote down vote up
package im.tox.tox4j.core.options

import org.scalatest.FlatSpec

final class ToxOptionsTest extends FlatSpec {

  "tox options" should "not allow negative ports" in {
    intercept[IllegalArgumentException] {
      ToxOptions(startPort = -1)
    }
    intercept[IllegalArgumentException] {
      ToxOptions(endPort = -1)
    }
    intercept[IllegalArgumentException] {
      ToxOptions(tcpPort = -1)
    }
  }

  it should "allow the port to be 0" in {
    ToxOptions(startPort = 0, endPort = 0, tcpPort = 0)
  }

  it should "allow the port to be 65535" in {
    ToxOptions(startPort = 65535, endPort = 65535, tcpPort = 65535)
  }

  it should "not allow the port to be greater than 65535" in {
    intercept[IllegalArgumentException] {
      ToxOptions(startPort = 65536)
    }
    intercept[IllegalArgumentException] {
      ToxOptions(endPort = 65536)
    }
    intercept[IllegalArgumentException] {
      ToxOptions(tcpPort = 65536)
    }
  }

  it should "require startPort <= endPort" in {
    intercept[IllegalArgumentException] {
      ToxOptions(startPort = 2, endPort = 1)
    }
  }

} 
Example 26
Source File: CounterEtlFunctionsSpec.scala    From incubator-s2graph   with Apache License 2.0 5 votes vote down vote up
package org.apache.s2graph.counter.loader.core

import com.typesafe.config.ConfigFactory
import org.apache.s2graph.core.schema.{Label, Service}
import org.apache.s2graph.core.types.HBaseType
import org.apache.s2graph.core.{S2Graph, Management}
import org.apache.s2graph.counter.models.DBModel
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

import scala.concurrent.ExecutionContext.Implicits.global

class CounterEtlFunctionsSpec extends FlatSpec with BeforeAndAfterAll with Matchers {
  val config = ConfigFactory.load()
  val cluster = config.getString("hbase.zookeeper.quorum")
  DBModel.initialize(config)

  val graph = new S2Graph(config)(global)
  val management = new Management(graph)

  override def beforeAll: Unit = {
    management.createService("test", cluster, "test", 1, None, "gz")
    management.createLabel("test_case", "test", "src", "string", "test", "tgt", "string", true, "test", Nil, Nil, "weak", None, None, HBaseType.DEFAULT_VERSION, false, "gz")
  }

  override def afterAll: Unit = {
    Label.delete(Label.findByName("test_case", false).get.id.get)
    Service.delete(Service.findByName("test", false).get.id.get)
  }

  "CounterEtlFunctions" should "parsing log" in {
    val data =
      """
        |1435107139287	insert	e	aaPHfITGUU0B_150212123559509	abcd	test_case	{"cateid":"100110102","shopid":"1","brandid":""}
        |1435106916136	insert	e	Tgc00-wtjp2B_140918153515441	efgh	test_case	{"cateid":"101104107","shopid":"2","brandid":""}
      """.stripMargin.trim.split('\n')
    val items = {
      for {
        line <- data
        item <- CounterEtlFunctions.parseEdgeFormat(line)
      } yield {
        item.action should equal("test_case")
        item
      }
    }

    items should have size 2
  }
} 
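
The fixture lines above are tab-separated: timestamp, operation, element type, source id, target id, label, and a JSON property blob. A minimal split shows why item.action equals "test_case" (the label column); the column layout is read off the fixture itself, and parseEdgeFormat's real parsing is assumed to be richer:

val line = "1435107139287\tinsert\te\taaPHfITGUU0B_150212123559509\tabcd\ttest_case\t{\"cateid\":\"100110102\"}"
val cols = line.split('\t')
assert(cols(5) == "test_case")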
Example 27
Source File: ContractInjectionTests.scala    From dsentric   with Apache License 2.0 5 votes vote down vote up
package dsentric

import dsentric._
import org.scalatest.{FlatSpec, Matchers}

class ContractInjectionTests extends FlatSpec with Matchers {

  import PessimisticCodecs._
  import Dsentric._

  object Flat extends Contract {
    val one = \?[String]
    val two = \![Int](123)
    val _three = \![String]("three", "bob")
  }

  object FlatNoDefaults extends Contract {
    val one = \?[String]
  }

  object Nested extends Contract {
    val nest = new \\ {
      val one = \?[String]
      val two = \![Int](123)
      val _three = \![String]("three", "bob")
    }
  }

  object NestedNoDefaults extends Contract {
    val nest = new \\ {
      val one = \?[String]
    }
  }

  "$applyDefaults" should "add a missing value to a dobject only if it has a default" in {
    val dObject = DObject.empty
    val res = Flat.$applyDefaults(dObject)
    Flat.one.$get(res) shouldBe None
    Flat.two.$get(res) shouldBe 123
  }
  "$applyDefaults" should "does nothing if no defaults in flat structure" in {
    val dObject = DObject.empty
    val res = FlatNoDefaults.$applyDefaults(dObject)
    res shouldBe DObject.empty
  }
  "$applyDefaults" should "not override an existing property in a dobject with a default" in {
    val dObject = DObject("two" -> Data(321))
    val res = Flat.$applyDefaults(dObject)
    Flat.two.$get(res) shouldBe 321
  }
  "$applyDefaults" should "work when there is a name override" in {
    val dObject1 = DObject.empty
    val res1 = Flat.$applyDefaults(dObject1)
    Flat._three.$get(res1) shouldBe "bob"

    val dObject2 = DObject("three" -> Data("sally"))
    val res2 = Flat.$applyDefaults(dObject2)
    Flat._three.$get(res2) shouldBe "sally"
  }

  "$applyDefaults" should "add a missing value in a nested structure" in {
    val dObject = DObject.empty
    val res = Nested.$applyDefaults(dObject)
    Nested.nest.two.$get(res) shouldBe 123
  }
  "$applyDefaults" should "work when there is a name override in a nested structure" in {
    val dObject1 = DObject.empty
    val res1 = Nested.$applyDefaults(dObject1)
    Nested.nest._three.$get(res1) shouldBe "bob"

    val dObject2 = DObject("nest" -> DObject("three" -> Data("sally")))
    val res2 = Nested.$applyDefaults(dObject2)
    Nested.nest._three.$get(res2) shouldBe "sally"
  }
  "$applyDefaults" should "does nothing if no defaults in nested structure" in {
    val dObject = DObject("non" -> Data("bob"))
    val res = NestedNoDefaults.$applyDefaults(dObject)
    res shouldBe dObject
  }


} 
Example 28
Source File: PailDataSourceSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.spark.pail

import java.util

import com.backtype.hadoop.pail.{PailFormatFactory, PailSpec, PailStructure}
import com.backtype.support.{Utils => PailUtils}
import com.google.common.io.Files
import org.apache.commons.io.FileUtils
import org.apache.spark.sql.SparkSession
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import org.scalatest.Matchers._

import scala.collection.JavaConverters._
import scala.util.Random

case class User(name: String, age: Int)

class UserPailStructure extends PailStructure[User] {
  override def isValidTarget(dirs: String*): Boolean = true

  override def getType: Class[_] = classOf[User]

  override def serialize(user: User): Array[Byte] = PailUtils.serialize(user)

  override def getTarget(user: User): util.List[String] = List(user.age % 10).map(_.toString).asJava

  override def deserialize(serialized: Array[Byte]): User = PailUtils.deserialize(serialized).asInstanceOf[User]
}

class PailDataSourceSpec extends FlatSpec with BeforeAndAfterAll with PailDataSource {
  private var spark: SparkSession = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder().master("local[2]").appName("PailDataSource").getOrCreate()
  }

  val userPailSpec = new PailSpec(PailFormatFactory.SEQUENCE_FILE, new UserPailStructure)

  "PailBasedReaderWriter" should "read/write user records from/into pail" in {
    val output = Files.createTempDir()
    val users = (1 to 100).map { index => User(s"foo$index", Random.nextInt(40))}
    spark.sparkContext.parallelize(users)
      .saveAsPail(output.getAbsolutePath, userPailSpec)

    val input = output.getAbsolutePath
    val total = spark.sparkContext.pailFile[User](input)
      .map(u => u.name)
      .count()

    total should be(100)
    FileUtils.deleteDirectory(output)
  }
} 
Example 29
Source File: ParquetAvroDataSourceSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.spark.parquet

import java.io.File

import com.google.common.io.Files
import com.indix.utils.spark.parquet.avro.ParquetAvroDataSource
import org.apache.commons.io.FileUtils
import org.apache.parquet.hadoop.metadata.CompressionCodecName
import org.apache.spark.sql.SparkSession
import org.scalactic.Equality
import org.scalatest.Matchers.{be, convertToAnyShouldWrapper, equal}
import org.scalatest.{BeforeAndAfterAll, FlatSpec}
import java.util.{Arrays => JArrays}

case class SampleAvroRecord(a: Int, b: String, c: Seq[String], d: Boolean, e: Double, f: collection.Map[String, String], g: Array[Byte])

class ParquetAvroDataSourceSpec extends FlatSpec with BeforeAndAfterAll with ParquetAvroDataSource {
  private var spark: SparkSession = _
  implicit val sampleAvroRecordEq = new Equality[SampleAvroRecord] {
    override def areEqual(left: SampleAvroRecord, b: Any): Boolean = b match {
      case right: SampleAvroRecord =>
        left.a == right.a &&
          left.b == right.b &&
          Equality.default[Seq[String]].areEqual(left.c, right.c) &&
          left.d == right.d &&
          left.e == right.e &&
          Equality.default[collection.Map[String, String]].areEqual(left.f, right.f) &&
          JArrays.equals(left.g, right.g)
      case _ => false
    }
  }

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    spark = SparkSession.builder().master("local[2]").appName("ParquetAvroDataSource").getOrCreate()
  }

  override protected def afterAll(): Unit = {
    try {
      spark.sparkContext.stop()
    } finally {
      super.afterAll()
    }
  }

  "AvroBasedParquetDataSource" should "read/write avro records as ParquetData" in {

    val outputLocation = Files.createTempDir().getAbsolutePath + "/output"

    val sampleRecords: Seq[SampleAvroRecord] = Seq(
      SampleAvroRecord(1, "1", List("a1"), true, 1.0d, Map("a1" -> "b1"), "1".getBytes),
      SampleAvroRecord(2, "2", List("a2"), false, 2.0d, Map("a2" -> "b2"), "2".getBytes),
      SampleAvroRecord(3, "3", List("a3"), true, 3.0d, Map("a3" -> "b3"), "3".getBytes),
      SampleAvroRecord(4, "4", List("a4"), true, 4.0d, Map("a4" -> "b4"), "4".getBytes),
      SampleAvroRecord(5, "5", List("a5"), false, 5.0d, Map("a5" -> "b5"), "5".getBytes)
    )

    val sampleDf = spark.createDataFrame(sampleRecords)

    sampleDf.rdd.saveAvroInParquet(outputLocation, sampleDf.schema, CompressionCodecName.GZIP)

    val sparkVal = spark

    import sparkVal.implicits._

    val records: Array[SampleAvroRecord] = spark.read.parquet(outputLocation).as[SampleAvroRecord].collect()

    records.length should be(5)
    // Compare element-wise with the custom Equality defined above (needed for the Array[Byte] field);
    // a bare `===` here would discard its Boolean result and assert nothing.
    // Ref - https://github.com/scalatest/scalatest/issues/491
    records.sortBy(_.a).zip(sampleRecords.sortBy(_.a)).foreach {
      case (actual, expected) => actual shouldEqual expected
    }

    FileUtils.deleteDirectory(new File(outputLocation))
  }

} 
Example 30
Source File: RocksMapTest.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.store

import java.io.Serializable
import java.nio.file.{Paths, Files}

import org.apache.commons.io.FileUtils
import org.scalatest.{Matchers, FlatSpec}


case class TestObject(a: Int, b: String, c: Array[Int], d: Array[String]) extends Serializable {

  def equals(other: TestObject): Boolean = {
    this.a.equals(other.a) && this.b.equals(other.b) && this.c.sameElements(other.c) && this.d.sameElements(other.d)
  }

}

case class ComplexTestObject(a: Int, b: TestObject) extends Serializable {
  def equals(other: ComplexTestObject): Boolean = {
    this.a.equals(other.a) && this.b.equals(other.b)
  }
}

class RocksMapTest extends FlatSpec with Matchers {

  "RocksMap" should "serialize and deserialize the keys and values" in {
    val db = new RocksMap("test")

    val a: Int = 1
    val b: String = "hello"
    val c: Array[Int] = Array(1, 2, 3)

    val d: Array[String] = Array("a", "b", "c")

    val serialized_a = db.serialize(a)
    val serialized_b = db.serialize(b)
    val serialized_c = db.serialize(c)
    val serialized_d = db.serialize(d)
    val serialized_TestObject = db.serialize(TestObject(a, b, c, d))
    val serialized_ComplexObject = db.serialize(ComplexTestObject(a, TestObject(a, b, c, d)))

    db.deserialize[Int](serialized_a) should be(a)
    db.deserialize[String](serialized_b) should be(b)
    db.deserialize[Array[Int]](serialized_c) should be(c)
    db.deserialize[Array[String]](serialized_d) should be(d)
    db.deserialize[TestObject](serialized_TestObject).equals(TestObject(a, b, c, d)) should be(true)
    db.deserialize[ComplexTestObject](serialized_ComplexObject).equals(ComplexTestObject(a, TestObject(a, b, c, d))) should be(true)
    db.drop()
    db.close()
  }

  it should "put and get values" in {
    val db = new RocksMap("test")

    db.put(1, 1.0)
    db.get[Int, Double](1).getOrElse(0) should be(1.0)
    db.clear()
    db.drop()
    db.close()
  }

  it should "remove values" in {
    val db = new RocksMap("test")

    db.put(1, 1L)
    db.get[Int, Long](1).getOrElse(0) should be(1L)
    db.remove(1)
    db.get[Int, Long](1) should be(None)
    db.drop()
    db.close()
  }

  it should "clear all the values" in {
    val db = new RocksMap(name = "test")
    db.put(1, "hello")
    db.put(2, "yello")
    db.get(1) should not be (None)
    db.get(2) should not be (None)
    db.clear()
    db.get(1) should be(None)
    db.get(2) should be(None)
    db.drop()
    db.close()
  }

  it should "clear the data files when drop is called" in {
    val db = new RocksMap(name = "test")
    Files.exists(Paths.get(db.pathString)) should be (true)
    db.drop()
    Files.exists(Paths.get(db.pathString)) should be (false)
    db.close()
  }


} 
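
For the purposes of these tests, RocksMap's serialize/deserialize pair behaves like a plain Java-serialization round trip over byte arrays (an assumption; the store may use a different codec). A self-contained sketch of such a round trip, which also shows why the array fields are compared with sameElements (a deserialized array is a new instance, so reference equality would fail):

import java.io._

def serialize(obj: AnyRef): Array[Byte] = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  try out.writeObject(obj) finally out.close()
  bytes.toByteArray
}

def deserialize[T](data: Array[Byte]): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(data))
  try in.readObject().asInstanceOf[T] finally in.close()
}

val roundTripped = deserialize[Array[Int]](serialize(Array(1, 2, 3)))
assert(roundTripped sameElements Array(1, 2, 3))   // == would be false: it is a new instance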
Example 31
Source File: UrlUtilsSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.core

import org.scalatest.FlatSpec
import org.scalatest.Matchers._
import scala.collection.JavaConverters._

class UrlUtilsSpec extends FlatSpec {

  "UrlUtils#toHostname" should "return the hostname of any given url" in {
    UrlUtils.toHostname("http://www.google.com") should be(Some("www.google.com"))
    UrlUtils.toHostname("https://www.google.com") should be(Some("www.google.com"))
    UrlUtils.toHostname("www.google.com/abc") should be(None)
  }

  "UrlUtils#toQueryMap" should "return the query params from a url as a scala map" in {
    val resMap = UrlUtils.toQueryMap("http://google.com/?query=hello&lang=en&somekey=value&")
    resMap.size should be (3)
    resMap.head should be ("query", "hello")

    val resMap1 = UrlUtils.toQueryMap("http://google.com/???query=hello&lang=en&somekey=value&")
    resMap1.size should be (3)
    resMap1.head should be ("??query", "hello")

    val resMap2 = UrlUtils.toQueryMap("http://google.com/")
    resMap2.size should be (0)

    val resMap3 = UrlUtils.toQueryMap("http://uae.souq.com/ae-en/educational-book/national-park-service/english/a-19-1401/l/?ref=nav?ref=nav&page=23&msg=foo%20bar")
    resMap3.size should be (3)
    resMap3 should contain ("ref" -> "nav?ref=nav")
    resMap3 should contain ("page" -> "23")
    resMap3 should contain ("msg" -> "foo bar")
  }

  "UrlUtils#isValid" should "return true/false given the url is valid" in {
    UrlUtils.isValid("google.coma") should be(false)
    UrlUtils.isValid("https://google.com") should be(true)
  }

  "UrlUtils#resolve" should "resolve relative urls against the base url" in {
    UrlUtils.resolve("http://google.com/", "shopping") should be("http://google.com/shopping")
  }

  "UrlUtils#decode" should "UTF-8 encoded urls to unicode strings" in {
    UrlUtils.decode("http%3A%2F%2Fwww.example.com%2Fd%C3%BCsseldorf%3Fneighbourhood%3DL%C3%B6rick") should be ("http://www.example.com/düsseldorf?neighbourhood=Lörick")
  }

  "UrlUtils#encode" should "UTF-8 decoded urls to unicode strings" in {
    UrlUtils.encode("http://www.example.com/düsseldorf?neighbourhood=Lörick") should be ("http%3A%2F%2Fwww.example.com%2Fd%C3%BCsseldorf%3Fneighbourhood%3DL%C3%B6rick")
  }

  "UrlUtils#encodeSpaces" should "UTF-8 decoded urls to unicode strings" in {
    UrlUtils.encode("word1 abcd") should be ("word1%20abcd")
  }

  "UrlUtils#stripHashes" should "UTF-8 decoded urls to unicode strings" in {
    UrlUtils.stripHashes("http://www.example.com/url#fragment") should be ("http://www.example.com/url")
    UrlUtils.stripHashes("http://www.example.com/url#fragment1#fragment2") should be ("http://www.example.com/url")
  }

  "UrlUtils#addHashFragments" should "add fragments to url" in {
    UrlUtils.addHashFragments("http://www.example.com/url",
      Map[String, String](
        "attr1" -> "fragment2",
        "attr2" -> "fragment 1",
        "attr3" -> "Fragment-of-1",
        "attr4" -> "XL"
      ).asJava) should be ("http://www.example.com/url#Fragment2#Fragment+1#Fragment-Of-1#XL")
  }

  "UrlUtils#convertToUrlFragment" should "convert to url fragments" in {
    UrlUtils.convertToUrlFragment("x-large / red") should be ("X-Large+%2F+Red")
  }

  "UrlUtils#get" should "UTF-8 decoded urls to unicode strings" in {
    UrlUtils.getHashFragments("http://www.example.com/url#fragment1#fragment2") should be (List("fragment1", "fragment2"))
    UrlUtils.getHashFragments("http://www.example.com/url") should be (List.empty)
  }

} 
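
The decode expectation above matches standard UTF-8 percent-decoding; assuming UrlUtils delegates to the JDK codecs (not verified here), it can be reproduced directly:

import java.net.URLDecoder

val decoded = URLDecoder.decode(
  "http%3A%2F%2Fwww.example.com%2Fd%C3%BCsseldorf%3Fneighbourhood%3DL%C3%B6rick",
  "UTF-8")
assert(decoded == "http://www.example.com/düsseldorf?neighbourhood=Lörick")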
Example 32
Source File: UPCSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.core

import org.scalatest.{FlatSpec, Matchers}

class UPCSpec extends FlatSpec with Matchers {

  "UPC" should "convert a UPC to a standardized format" in {
    UPC.standardize("63938200039") should be("00639382000393")
    UPC.standardize("99999999623") should be("00999999996237")
    UPC.standardize("89504500098") should be("00895045000982")
    UPC.standardize("934093502349") should be("09340935023493")
    UPC.standardize("841106172217") should be("08411061722176")
    UPC.standardize("810000439") should be("00008100004393")
    UPC.standardize("931177059140") should be("09311770591409")
    UPC.standardize("9311770591409") should be("09311770591409")
    UPC.standardize("27242860940") should be("00027242860940")
    UPC.standardize("75317405253") should be("00753174052534")
    UPC.standardize("-810000439") should be("00008100004393")
    UPC.standardize("810-000-439") should be("00008100004393")

    // Iphone UPCs
    UPC.standardize("885909950652") should be("00885909950652")
    UPC.standardize("715660702866") should be("00715660702866")
  }

  it should "check if the input UPC is valid as checked by UPCItemDB" in {
    UPC.isValid("0420160002247") should be(false)
    UPC.isValid("000000010060") should be(false)
    // the above same UPCs are validated after standardizing them
    UPC.isValid(UPC.standardize("0420160002247")) should be(true)
    UPC.isValid(UPC.standardize("000000010060")) should be(true)
  }

  it should "work correctly for GTIN UPCs by converting it to a valid EAN-13 with padded zeros" in {
    UPC.standardize("10010942220401") should be ("00010942220404")
    UPC.standardize("47111850104013") should be("07111850104015")
    UPC.standardize("40628043604719") should be("00628043604711")
  }


  it should "work correctly for ISBN numbers" in {
    UPC.standardize("978052549832") should be("9780525498322")
    UPC.standardize("9780500517260") should be("9780500517260")
    UPC.standardize("9780316512787") should be("9780316512787")
    UPC.standardize("9780997355932") should be("9780997355932")
  }

  it should "not replace check-digit if UPC already left padded and check-digit and last-digit same" in {
    UPC.standardize("0753174052534") should be("00753174052534")
  }

  it should "fail for an invalid 14-digit UPC (GTIN)" in {
    intercept[IllegalArgumentException] {
      UPC.standardize("47111850104010")
    }
  }

  it should "fail for all zeroes UPC" in {
    intercept[IllegalArgumentException] {
      UPC.standardize("00000000000")
    }
  }

  it should "fail for invalid UPC" in {
    intercept[IllegalArgumentException] {
      UPC.standardize("12345")
    }
  }

  it should "fail for empty or null UPC" in {
    intercept[IllegalArgumentException] {
      UPC.standardize("")
    }
  }

  it should "fail in case UPC is not a number" in {
    intercept[IllegalArgumentException] {
      UPC.standardize("ABCD")
    }
  }
} 
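
The standardized values above are 13- or 14-digit GTINs whose last digit is the GS1 check digit (weights 3 and 1, alternating from the right). A minimal sketch of that arithmetic, independent of UPC.standardize, using the first expectation as the worked example:

def gs1CheckDigit(dataDigits: String): Int = {
  val sum = dataDigits.reverse.zipWithIndex.map {
    case (c, i) => (c - '0') * (if (i % 2 == 0) 3 else 1)   // rightmost data digit weighs 3
  }.sum
  (10 - sum % 10) % 10
}

// "63938200039" left-padded to 13 data digits, plus check digit 3 -> "00639382000393"
assert(gs1CheckDigit("0063938200039") == 3)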
Example 33
Source File: MPNSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.core

import org.scalatest.{FlatSpec, Matchers}

class MPNSpec extends FlatSpec with Matchers {

  behavior of "MPN"

  it should "check title case" in {
    MPN.isTitleCase("Key Shell") should be(true)
    MPN.isTitleCase("Samsung Galaxy A8") should be(true)
    MPN.isTitleCase("Samsung Galaxy Note 5") should be(true)
    MPN.isTitleCase("Tempered Glass") should be(true)

    MPN.isTitleCase("1442820G1") should be(false)
    MPN.isTitleCase("Macbook") should be(false)
    MPN.isTitleCase("CE 7200") should be(false)
    MPN.isTitleCase("IPHONE") should be(false)
  }

  it should "validate identifier" in {
    MPN.isValidIdentifier(null) should be (false)
    MPN.isValidIdentifier("") should be (false)
    MPN.isValidIdentifier("51") should be (false)
    MPN.isValidIdentifier("  NA   ") should be (false)
    MPN.isValidIdentifier("Does not apply") should be (false)

    MPN.isValidIdentifier("DT.VFGAA.003") should be (true)
    MPN.isValidIdentifier("A55BM-A/USB3") should be (true)
    MPN.isValidIdentifier("cASSP1598345-10") should be (true)
    MPN.isValidIdentifier("016393B119058-Regular-18x30-BE-BK") should be (true)
    MPN.isValidIdentifier("PJS2V") should be (true)
  }

  it should "standardize MPN" in {
    MPN.standardizeMPN(null) should be (None)
    MPN.standardizeMPN("Does not apply") should be (None)
    MPN.standardizeMPN("All Windows %22") should be (None)
    MPN.standardizeMPN("Samsung Galaxy Note 5") should be (None)
    MPN.standardizeMPN("A Square") should be (None)
    MPN.standardizeMPN("{{availableProducts[0].sku}}") should be (None)

    MPN.standardizeMPN("PJS2V") should be (Some("PJS2V"))
    MPN.standardizeMPN("30634190, 30753839, 31253006") should be (Some("30634190"))
    MPN.standardizeMPN("mr16r082gbn1-ck8 ") should be (Some("MR16R082GBN1-CK8"))
    MPN.standardizeMPN("SM-G950FZDAXSA") should be (Some("SM-G950FZDAXSA"))
  }

} 
Example 34
Source File: ISBNSpec.scala    From utils   with Apache License 2.0 5 votes vote down vote up
package com.indix.utils.core

import org.scalatest.{FlatSpec, Matchers}

class ISBNSpec extends FlatSpec with Matchers {
  "ISBN" should "create a valid ISBN object for" in {
    ISBN("0-306-40615-2").get.isbn10 should be (Some("0306406152"))
    ISBN("0-306-40615-2").get.isbn should be ("9780306406157")
    ISBN("978-0-306-40615-7").get.isbn should be ("9780306406157")
    ISBN("978-0-306-40615-7").get.isbn10 should be (None)
    ISBN("9971502100").get.isbn10 should be (Some("9971502100"))
    ISBN("9971502100").get.isbn should be ("9789971502102")
    ISBN("960 425 059 0").get.isbn10 should be (Some("9604250590"))
    ISBN("960 425 059 0").get.isbn should be ("9789604250592")
  }

  it should "not create a valid ISBN object for" in {
    ISBN("abcd") should be (None)
    ISBN("123") should be (None)
    ISBN("") should be (None)
    ISBN("  ") should be (None)
    ISBN("-") should be (None)
    ISBN("1234567890") should be (None)
    ISBN("1234567890111") should be (None)
    ISBN("0-306-40615-1") should be (None)
    ISBN("978-0-306-40615-5") should be (None)
  }
} 
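
The isbn10/isbn pairs above follow the standard ISBN-10 to ISBN-13 conversion: prefix "978", drop the ISBN-10 check digit, and append an EAN-13 check digit. A sketch of that conversion (an assumption about how ISBN normalizes internally; as the tests show, the library also validates checksums and strips hyphens and spaces):

def isbn10ToIsbn13(isbn10: String): String = {
  val core = "978" + isbn10.init   // 12 digits: the old check digit is dropped
  val sum = core.zipWithIndex.map { case (c, i) => (c - '0') * (if (i % 2 == 0) 1 else 3) }.sum
  core + ((10 - sum % 10) % 10)
}

assert(isbn10ToIsbn13("0306406152") == "9780306406157")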
Example 35
Source File: TimeEncodingCompatibilityItSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import java.util.TimeZone

import com.github.mjakubowski84.parquet4s.CompatibilityTestCases.TimePrimitives
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

class TimeEncodingCompatibilityItSpec extends
  FlatSpec
    with Matchers
    with BeforeAndAfter
    with SparkHelper {

  private val localTimeZone = TimeZone.getDefault
  private val utcTimeZone = TimeZone.getTimeZone("UTC")
  private lazy val newYearEveEvening = java.sql.Timestamp.valueOf(java.time.LocalDateTime.of(2018, 12, 31, 23, 0, 0))
  private lazy val newYearMidnight = java.sql.Timestamp.valueOf(java.time.LocalDateTime.of(2019, 1, 1, 0, 0, 0))
  private lazy val newYear = java.sql.Date.valueOf(java.time.LocalDate.of(2019, 1, 1))

  override def beforeAll(): Unit = {
    super.beforeAll()
    TimeZone.setDefault(utcTimeZone)
  }

  before {
    clearTemp()
  }

  private def writeWithSpark(data: TimePrimitives): Unit = writeToTemp(Seq(data))
  private def readWithSpark: TimePrimitives = readFromTemp[TimePrimitives].head
  private def writeWithParquet4S(data: TimePrimitives, timeZone: TimeZone): Unit =
    ParquetWriter.writeAndClose(tempPathString, Seq(data), ParquetWriter.Options(timeZone = timeZone))
  private def readWithParquet4S(timeZone: TimeZone): TimePrimitives = {
    val parquetIterable = ParquetReader.read[TimePrimitives](tempPathString, ParquetReader.Options(timeZone = timeZone))
    try {
      parquetIterable.head
    } finally {
      parquetIterable.close()
    }
  }

  "Spark" should "read properly time written with time zone one hour east" in {
    val input = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    writeWithParquet4S(input, TimeZone.getTimeZone("GMT+1"))
    readWithSpark should be(expectedOutput)
  }

  it should "read properly written with time zone one hour west" in {
    val input = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    writeWithParquet4S(input, TimeZone.getTimeZone("GMT-1"))
    readWithSpark should be(expectedOutput)
  }

  "Parquet4S" should "read properly time written with time zone one hour east" in {
    val input = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    writeWithSpark(input)
    readWithParquet4S(TimeZone.getTimeZone("GMT-1")) should be(expectedOutput)
  }

  it should "read properly time written with time zone one hour west" in {
    val input = TimePrimitives(timestamp = newYearEveEvening, date = newYear)
    val expectedOutput = TimePrimitives(timestamp = newYearMidnight, date = newYear)
    writeWithSpark(input)
    readWithParquet4S(TimeZone.getTimeZone("GMT+1")) should be(expectedOutput)
  }

  override def afterAll(): Unit = {
    TimeZone.setDefault(localTimeZone)
    super.afterAll()
  }

} 
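
The expected shifts above are plain time-zone arithmetic: a local timestamp written as 2019-01-01 00:00 at GMT+1 denotes the same instant as 2018-12-31 23:00 at UTC (the JVM default in these tests). Parquet4S applies that shift through the timeZone option; the java.time illustration below is independent of the library:

import java.time.{LocalDateTime, ZoneId}

val writtenAtGmtPlus1 = LocalDateTime.of(2019, 1, 1, 0, 0).atZone(ZoneId.of("GMT+1"))
val seenAtUtc = writtenAtGmtPlus1.withZoneSameInstant(ZoneId.of("UTC")).toLocalDateTime
assert(seenAtUtc == LocalDateTime.of(2018, 12, 31, 23, 0))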
Example 37
Source File: ValueCodecSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import org.scalatest.{FlatSpec, Matchers}


class ValueCodecSpec extends FlatSpec with Matchers {

  case class TestType(i: Int)

  val requiredValueCodec: RequiredValueCodec[TestType] = new RequiredValueCodec[TestType] {
    override protected def decodeNonNull(value: Value, configuration: ValueCodecConfiguration): TestType = value match {
      case IntValue(i) => TestType(i)
    }
    override protected def encodeNonNull(data: TestType, configuration: ValueCodecConfiguration): Value = IntValue(data.i)
  }
  val optionalValueCodec: OptionalValueCodec[TestType] = new OptionalValueCodec[TestType] {
    override protected def decodeNonNull(value: Value, configuration: ValueCodecConfiguration): TestType = value match {
      case IntValue(i) => TestType(i)
    }
    override protected def encodeNonNull(data: TestType, configuration: ValueCodecConfiguration): Value = IntValue(data.i)
  }

  val testType = TestType(42)
  val testValue = IntValue(testType.i)
  val configuration: ValueCodecConfiguration = ValueCodecConfiguration.default

  "Required value codec" should "encode non-null value" in {
    requiredValueCodec.encode(testType, configuration) should be(testValue)
  }

  it should "decode non-null value" in {
    requiredValueCodec.decode(testValue, configuration) should be(testType)
  }

  it should "throw an exception when decoding null-value" in {
    an[IllegalArgumentException] should be thrownBy requiredValueCodec.decode(NullValue, configuration)
  }

  it should "throw an exception when encoding null" in {
    an[IllegalArgumentException] should be thrownBy requiredValueCodec.encode(null.asInstanceOf[TestType], configuration)
  }

  "Optional value codec" should "encode non-null value" in {
    optionalValueCodec.encode(testType, configuration) should be(testValue)
  }

  it should "decode non-null value" in {
    optionalValueCodec.decode(testValue, configuration) should be(testType)
  }

  it should "throw an exception when decoding null-value" in {
    optionalValueCodec.decode(NullValue, configuration) should be(null)
  }

  it should "throw an exception when encoding null" in {
    optionalValueCodec.encode(null.asInstanceOf[TestType], configuration) should be(NullValue)
  }

} 
Example 38
Source File: DecimalsSpec.scala    From parquet4s   with MIT License 5 votes vote down vote up
package com.github.mjakubowski84.parquet4s

import org.scalatest.{FlatSpec, Matchers}

class DecimalsSpec extends FlatSpec with Matchers {

  "Decimals" should "be able to convert decimal to binary and back" in {
    val decimal = BigDecimal(1001, 2)
    val binary = Decimals.binaryFromDecimal(decimal)
    val revertedDecimal = Decimals.decimalFromBinary(binary)

    revertedDecimal should be(decimal)
  }

  it should "be able to convert negative decimal to binary and back" in {
    val decimal = BigDecimal(-1001, 2)
    val binary = Decimals.binaryFromDecimal(decimal)
    val revertedDecimal = Decimals.decimalFromBinary(binary)

    revertedDecimal should be(decimal)
  }

  it should "be able to convert zero decimal to binary and back" in {
    val decimal = BigDecimal(0, 0)
    val binary = Decimals.binaryFromDecimal(decimal)
    val revertedDecimal = Decimals.decimalFromBinary(binary)

    revertedDecimal should be(decimal)
  }

  it should "round decimal with scale greater than Parquet4S uses" in {
    val decimal = BigDecimal(1L, Decimals.Scale + 1, Decimals.MathContext)
    val binary = Decimals.binaryFromDecimal(decimal)
    val revertedDecimal = Decimals.decimalFromBinary(binary)

    revertedDecimal should be(0)
  }

} 
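
The last test relies on ordinary BigDecimal rescaling: a value one decimal place finer than the library's scale rounds to zero when brought back to that scale. A standalone sketch, assuming a scale of 18 purely for illustration (Decimals.Scale's actual value is not restated here):

val assumedScale = 18                       // assumption: stands in for Decimals.Scale
val tiny = BigDecimal(1L, assumedScale + 1) // 1 * 10^-(assumedScale + 1)
assert(tiny.setScale(assumedScale, BigDecimal.RoundingMode.HALF_UP) == BigDecimal(0))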
Example 39
Source File: FileCompressionFormatSpec.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package daf.filesystem

import org.scalatest.{ FlatSpec, Matchers }

class FileCompressionFormatSpec extends FlatSpec with Matchers {

  "Uncompressed files" must "be correctly categorized" in {
    FileCompressionFormats.fromName { "path/to/data.json" } should be { FileCompressionFormats.none }
  }

  "lzo compressed files" must "be correctly categorized" in {
    FileCompressionFormats.fromName { "path/to/data.lzo.avro" } should be { FileCompressionFormats.lzo }
    FileCompressionFormats.fromName { "path/to/data.lzo" } should be { FileCompressionFormats.lzo }
    FileCompressionFormats.fromName { "path/to/data.parq.lzo" } should be { FileCompressionFormats.lzo }
  }

  "snappy compressed files" must "be correctly categorized" in {
    FileCompressionFormats.fromName { "path/to/data.snappy.json" } should be { FileCompressionFormats.snappy }
    FileCompressionFormats.fromName { "path/to/data.snappy" } should be { FileCompressionFormats.snappy }
    FileCompressionFormats.fromName { "path/to/data.avro.snappy" } should be { FileCompressionFormats.snappy }
  }

  "gzip data files" must "be correctly categorized" in {
    FileCompressionFormats.fromName { "path/to/data.gzip.parquet" } should be { FileCompressionFormats.gzip }
    FileCompressionFormats.fromName { "path/to/data.gzip" } should be { FileCompressionFormats.gzip }
    FileCompressionFormats.fromName { "path/to/data.json.gzip" } should be { FileCompressionFormats.gzip }
  }

} 
Example 40
Source File: HDFSBase.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package daf.util

import better.files.{ File, _ }
import daf.util.DataFrameClasses.{ Address, Person }
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.hdfs.{ HdfsConfiguration, MiniDFSCluster }
import org.apache.hadoop.test.PathUtils
import org.apache.spark.sql.{ SaveMode, SparkSession }
import org.scalatest.{ BeforeAndAfterAll, FlatSpec, Matchers }
import org.slf4j.LoggerFactory

import scala.util.{ Failure, Random, Try }

abstract class HDFSBase extends FlatSpec with Matchers with BeforeAndAfterAll {

  var miniCluster: Try[MiniDFSCluster] = Failure[MiniDFSCluster](new Exception)

  var fileSystem: Try[FileSystem] = Failure[FileSystem](new Exception)

  val sparkSession: SparkSession = SparkSession.builder().master("local").getOrCreate()

  val alogger = LoggerFactory.getLogger(this.getClass)

  val (testDataPath, confPath) = {
    val testDataPath = s"${PathUtils.getTestDir(this.getClass).getCanonicalPath}/MiniCluster"
    val confPath = s"$testDataPath/conf"
    (
      testDataPath.toFile.createIfNotExists(asDirectory = true, createParents = false),
      confPath.toFile.createIfNotExists(asDirectory = true, createParents = false)
    )
  }

  def pathAvro = "opendata/test.avro"
  def pathParquet = "opendata/test.parquet"
  def pathCsv = "opendata/test.csv"

  def getSparkSession = sparkSession

  override def beforeAll(): Unit = {

    val conf = new HdfsConfiguration()
    conf.setBoolean("dfs.permissions", true)
    System.clearProperty(MiniDFSCluster.PROP_TEST_BUILD_DATA)

    conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, testDataPath.pathAsString)
    //FileUtil.fullyDelete(testDataPath.toJava)

    conf.set(s"hadoop.proxyuser.${System.getProperties.get("user.name")}.groups", "*")
    conf.set(s"hadoop.proxyuser.${System.getProperties.get("user.name")}.hosts", "*")

    val builder = new MiniDFSCluster.Builder(conf)
    miniCluster = Try(builder.build())
    fileSystem = miniCluster.map(_.getFileSystem)
    fileSystem.foreach(fs => {
      val confFile: File = confPath / "hdfs-site.xml"
      for { os <- confFile.newOutputStream.autoClosed } fs.getConf.writeXml(os)
    })

    writeDf()
  }

  override def afterAll(): Unit = {
    miniCluster.foreach(_.shutdown(true))
    val _ = testDataPath.parent.parent.delete(true)
    sparkSession.stop()
  }

  
  private def writeDf(): Unit = {
    import sparkSession.implicits._

    alogger.info(s"TestDataPath ${testDataPath.toJava.getAbsolutePath}")
    alogger.info(s"ConfPath ${confPath.toJava.getAbsolutePath}")
    val persons = (1 to 10).map(i => Person(s"Andy$i", Random.nextInt(85), Address("Via Ciccio Cappuccio")))
    val caseClassDS = persons.toDS()
    caseClassDS.write.format("parquet").mode(SaveMode.Overwrite).save(pathParquet)
    caseClassDS.write.format("com.databricks.spark.avro").mode(SaveMode.Overwrite).save(pathAvro)
    // writing the Person dataframe directly generates an exception
    caseClassDS.toDF.select("name", "age").write.format("csv").mode(SaveMode.Overwrite).option("header", "true").save(pathCsv)
  }
}

object DataFrameClasses {

  final case class Address(street: String)

  final case class Person(name: String, age: Int, address: Address)
} 
Example 41
Source File: SftpHandlerTest.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package it.gov.daf.ftp

import org.scalatest.{FlatSpec, Matchers}

class SftpHandlerTest extends FlatSpec with Matchers {

  val host = "edge1.platform.daf.gov.it"
  val user = "put your user"
  val pwd = "put your password"

  "A SftpHandler" should "be able to print the working directory" in {
    val client = new SftpHandler(user, pwd, host)
    val res = client.workingDir()
    println(res)
    res shouldBe 'Success

    client.disconnect()
  }

  it should "create a directory from an absolute path" in {
    val path = s"/home/$user/test"
    val client = new SftpHandler(user, pwd, host)
    val res = client.mkdir(path)

    println(res)
    res shouldBe 'Success

    client.rmdir(path)
    client.disconnect()
  }

  it should "create folders recursively from a relative path" in {
    val path = s"/home/$user/test/subtest"
    val client = new SftpHandler(user, pwd, host)
    val res = client.mkdir(path)

    println(res)
    res shouldBe 'Success

    client.rmdir(path)
    client.disconnect()
  }

  it should "create a directory from a relative path" in {
    val path = "test"
    val client = new SftpHandler(user, pwd, host)
    val res = client.mkdir(path)

    println(res)
    res shouldBe 'Success

    client.rmdir(path)
    client.disconnect()
  }

} 
Example 42
Source File: EventToKuduEventSpec.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package it.teamdigitale.events

import it.gov.daf.iotingestion.event.Event
import it.teamdigitale.EventModel.EventToKuduEvent
import it.teamdigitale.EventModel.EventToStorableEvent
import org.scalatest.{FlatSpec, Matchers}

import scala.util.{Success, Try}

class EventToKuduEventSpec extends FlatSpec with Matchers{

  val metrics = Range(0,100).map(x => Success( Event(
    version = 1L,
    id = x + "metric",
    ts = System.currentTimeMillis(),
    event_type_id = 0,
    location = "41.1260529:16.8692905",
    source = "http://domain/sensor/url",
    body = Option("""{"rowdata": "this json should contain row data"}""".getBytes()),
    event_subtype_id = Some("SPEED_Via_Cernaia_TO"),
    attributes = Map(
      "value" -> x.toString)
  )))

  // this metric doesn't have any value
  val wrongMetric = Success( Event(
    version = 1L,
    id = "wrongmetric1",
    ts = System.currentTimeMillis(),
    event_type_id = 0,
    location = "41.1260529:16.8692905",
    source = "http://domain/sensor/url",
    body = Option("""{"rowdata": "this json should contain row data"}""".getBytes()),
    event_annotation = Some(s"This is a free text for a wrong metric"),
    event_subtype_id = Some("SPEED_Via_Cernaia_TO"),
    attributes = Map()
  ))

  // this metric doesn't have a correct value
  val wrongMetric2 = Success( Event(
    version = 1L,
    id = "wrongmetric2",
    ts = System.currentTimeMillis(),
    event_type_id = 0,
    location = "41.1260529:16.8692905",
    source = "http://domain/sensor/url",
    body = Option("""{"rowdata": "this json should contain row data"}""".getBytes()),
    event_annotation = Some(s"This is a free text ©"),
    event_subtype_id = Some("SPEED_Via_Cernaia_TO"),
    attributes = Map(
      "value" -> "wrongValue"
    )
  ))

  // this metric doesn't have the metric id
  val wrongMetric3 = Success( Event(
    version = 1L,
    id = "wrongmetric3",
    ts = System.currentTimeMillis(),
    event_type_id = 2,
    location = "41.1260529:16.8692905",
    source = "http://domain/sensor/url",
    body = Option("""{"rowdata": "this json should contain row data"}""".getBytes()),
    event_annotation = Some(s"This is a free text for a wrong metric"),
    attributes = Map(
      "value" -> "100"
    )
  ))


  "Correct events" should "be converted" in {
    val res = metrics.map(event => EventToStorableEvent(event)).flatMap(_.toOption).map(event => EventToKuduEvent(event)).filter(_.isSuccess)
    res.length shouldBe 100
    res.head.get.metric shouldBe 0D
    res.head.get.metric_id shouldBe "SPEED_Via_Cernaia_TO"
  }

  "Wrong events" should "be filtered" in {
    val seq = metrics ++ List(wrongMetric, wrongMetric2, wrongMetric3)

    val res = seq.map(event => EventToStorableEvent(event)).flatMap(_.toOption).map(event => EventToKuduEvent(event)).filter(_.isSuccess)
    res.length shouldBe 100
  }
} 
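
The three wrong metrics are dropped because EventToKuduEvent fails for them. The checks they violate appear to be a parseable numeric "value" attribute and a present metric id (event_subtype_id); a plain-Scala sketch of that filtering, with the caveat that the real validation may be stricter:

def isStorable(attributes: Map[String, String], subtypeId: Option[String]): Boolean =
  subtypeId.nonEmpty &&
    attributes.get("value").exists(v => scala.util.Try(v.toDouble).isSuccess)

assert(!isStorable(Map.empty, Some("SPEED_Via_Cernaia_TO")))                     // no value
assert(!isStorable(Map("value" -> "wrongValue"), Some("SPEED_Via_Cernaia_TO")))  // non-numeric value
assert(!isStorable(Map("value" -> "100"), None))                                 // no metric id
assert(isStorable(Map("value" -> "42"), Some("SPEED_Via_Cernaia_TO")))           // kept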
Example 43
Source File: KuduMiniCluster.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package it.teamdigitale.miniclusters

import org.apache.kudu.client.{KuduClient, MiniKuduCluster}
import org.apache.kudu.client.KuduClient.KuduClientBuilder
import org.apache.kudu.client.MiniKuduCluster.MiniKuduClusterBuilder
import org.apache.kudu.spark.kudu.KuduContext
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import org.apache.logging.log4j.LogManager

class KuduMiniCluster extends AutoCloseable {
  val alogger = LogManager.getLogger(this.getClass)

  var kuduMiniCluster: MiniKuduCluster = _
  var kuduClient: KuduClient = _
  var kuduContext: KuduContext = _
  var sparkSession = SparkSession.builder().appName(s"test-${System.currentTimeMillis()}").master("local[*]").getOrCreate()

  def start() {
    alogger.info("Starting KUDU mini cluster")

    System.setProperty(
      "binDir",
      s"${System.getProperty("user.dir")}/src/test/kudu_executables/${sun.awt.OSInfo.getOSType().toString.toLowerCase}"
    )


    kuduMiniCluster = new MiniKuduClusterBuilder()
      .numMasters(1)
      .numTservers(3)
      .build()

    val envMap = Map[String, String](("Xmx", "512m"))

    kuduClient = new KuduClientBuilder(kuduMiniCluster.getMasterAddresses).build()
    assert(kuduMiniCluster.waitForTabletServers(1))

    kuduContext = new KuduContext(kuduMiniCluster.getMasterAddresses, sparkSession.sparkContext)

  }

  override def close() {
    alogger.info("Ending KUDU mini cluster")
    kuduClient.shutdown()
    kuduMiniCluster.shutdown()
    sparkSession.close()
  }
}

object KuduMiniCluster {

  def main(args: Array[String]): Unit = {
    val kudu = new KuduMiniCluster()
    kudu.start()

    println(s"MASTER KUDU ${kudu.kuduMiniCluster.getMasterAddresses}")

    // keep the process alive so the mini cluster stays reachable until it is killed
    while (true) {
      Thread.sleep(1000)
    }
  }
} 
Example 44
Source File: KuduEventsHandlerSpec.scala    From daf   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package it.teamdigitale.storage

import java.io.File
import java.util.concurrent.TimeUnit

import org.apache.kudu.spark.kudu._
import it.teamdigitale.miniclusters.KuduMiniCluster
import it.teamdigitale.config.IotIngestionManagerConfig.KuduConfig
import it.teamdigitale.managers.IotIngestionManager
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import it.gov.daf.iotingestion.event.Event
import it.teamdigitale.EventModel.{EventToKuduEvent, EventToStorableEvent}
import org.apache.logging.log4j.LogManager

import scala.util.{Failure, Success, Try}

class KuduEventsHandlerSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  val logger = LogManager.getLogger(this.getClass)
  val kuduCluster = new KuduMiniCluster()

  val metrics: Seq[Try[Event]] = Range(0,100).map(x => Success( Event(
    version = 1L,
    id = x + "metric",
    ts = System.currentTimeMillis() + x ,
    event_type_id = 0,
    location = "41.1260529:16.8692905",
    source = "http://domain/sensor/url",
    body = Option("""{"rowdata": "this json should contain row data"}""".getBytes()),
    event_subtype_id = Some("Via Cernaia(TO)"),
    attributes = Map("value" -> x.toString)
  )))

  val rdd = kuduCluster.sparkSession.sparkContext.parallelize(metrics)


  "KuduEventsHandler" should "store correctly data" in {

   val metricsRDD = rdd
      .map(event => EventToStorableEvent(event))
      .flatMap(e => e.toOption)
      .map(se => EventToKuduEvent(se)).flatMap(e => e.toOption)

    val metricsDF = kuduCluster.sparkSession.createDataFrame(metricsRDD)

    val kuduConfig = KuduConfig(kuduCluster.kuduMiniCluster.getMasterAddresses, "TestEvents", 2)

    KuduEventsHandler.getOrCreateTable(kuduCluster.kuduContext, kuduConfig)
    KuduEventsHandler.write(metricsDF, kuduCluster.kuduContext, kuduConfig)

    val df = kuduCluster.sparkSession.sqlContext
      .read
      .options(Map("kudu.master" -> kuduConfig.masterAdresses,"kudu.table" -> kuduConfig.eventsTableName))
      .kudu

    df.count shouldBe 100

  }

  "KuduEventsHandler" should "handle redundant data" in {

    val metricsRDD = rdd
      .map(event => EventToStorableEvent(event))
      .flatMap(e => e.toOption)
      .map(se => EventToKuduEvent(se))
      .flatMap(e => e.toOption)

    val metricsDF = kuduCluster.sparkSession.createDataFrame(metricsRDD)

    val kuduConfig = KuduConfig(kuduCluster.kuduMiniCluster.getMasterAddresses, "TestEventsDuplicate", 2)
    KuduEventsHandler.getOrCreateTable(kuduCluster.kuduContext, kuduConfig)

    KuduEventsHandler.write(metricsDF, kuduCluster.kuduContext, kuduConfig)
    KuduEventsHandler.write(metricsDF, kuduCluster.kuduContext, kuduConfig)

    val df = kuduCluster.sparkSession.sqlContext
      .read
      .options(Map("kudu.master" -> kuduConfig.masterAdresses,"kudu.table" -> kuduConfig.eventsTableName))
      .kudu

    df.count shouldBe 100

  }

  override def beforeAll() {
    kuduCluster.start()
  }

  override def afterAll() {
    kuduCluster.close()
  }

} 
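
Writing the same DataFrame twice and still counting 100 rows works because the handler is assumed to upsert on the event's primary key, so the second write overwrites rather than duplicates. A keyed-map illustration of that idempotence:

val firstWrite  = (0 until 100).map(i => s"${i}metric" -> i)
val secondWrite = firstWrite
val table = (firstWrite ++ secondWrite).toMap   // last write wins per key, as an upsert would
assert(table.size == 100)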
Example 45
Source File: GranularBigMatrixSpec.scala    From glint   with MIT License 5 votes vote down vote up
package glint.matrix

import glint.SystemTest
import glint.models.client.granular.GranularBigMatrix
import org.scalatest.{FlatSpec, Matchers}


class GranularBigMatrixSpec extends FlatSpec with SystemTest with Matchers {

  "A GranularBigMatrix" should "handle large push/pull requests" in withMaster { _ =>
    withServers(2) { _ =>
      withClient { client =>
        val model = client.matrix[Double](1000, 1000)
        val granularModel = new GranularBigMatrix[Double](model, 10000)
        val rows = new Array[Long](1000000)
        val cols = new Array[Int](1000000)
        val values = new Array[Double](1000000)
        var i = 0
        while (i < rows.length) {
          rows(i) = i % 1000
          cols(i) = i / 1000
          values(i) = i * 3.14
          i += 1
        }

        whenReady(granularModel.push(rows, cols, values)) {
          identity
        }
        val result = whenReady(granularModel.pull(rows, cols)) {
          identity
        }

        result shouldEqual values
      }
    }
  }

  it should " handle large pull requests for rows"  in withMaster { _ =>
    withServers(3) { _ =>
      withClient { client =>
        val model = client.matrix[Double](1000, 1000)
        val granularModel = new GranularBigMatrix[Double](model, 10000)
        val rows = new Array[Long](1000000)
        val cols = new Array[Int](1000000)
        val values = new Array[Double](1000000)
        var i = 0
        while (i < rows.length) {
          rows(i) = i % 1000
          cols(i) = i / 1000
          values(i) = i * 3.14
          i += 1
        }

        whenReady(granularModel.push(rows, cols, values)) {
          identity
        }
        val result = whenReady(granularModel.pull((0L until 1000L).toArray)) {
          identity
        }

        i = 0
        while (i < rows.length) {
          assert(result(i % 1000)(i / 1000) == values(i))
          i += 1
        }
      }
    }
  }

  it should "have non-zero max message size" in withMaster { _ =>
    withServers(2) { _ =>
      withClient { client =>
        val model = client.matrix[Double](10, 10)
        intercept[IllegalArgumentException] {
          new GranularBigMatrix(model, 0)
        }
        intercept[IllegalArgumentException] {
          new GranularBigMatrix(model, -1)
        }
      }
    }
  }

} 
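
GranularBigMatrix exists to keep individual messages below a maximum size: the tests above push a million (row, col, value) triples through a wrapper capped at 10000 entries per request. A sketch of the chunking arithmetic only (the real wrapper issues asynchronous pushes and pulls per chunk and reassembles the results):

val rows   = Array.tabulate(1000000)(i => (i % 1000).toLong)
val cols   = Array.tabulate(1000000)(i => i / 1000)
val chunks = rows.zip(cols).grouped(10000).toSeq

assert(chunks.size == 100)                 // 1e6 entries / 1e4 per message
assert(chunks.forall(_.length <= 10000))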
Example 46
Source File: GranularBigVectorSpec.scala    From glint   with MIT License 5 votes vote down vote up
package glint.vector

import scala.util.Random

import glint.SystemTest
import glint.models.client.granular.GranularBigVector
import org.scalatest.{FlatSpec, Matchers}


class GranularBigVectorSpec extends FlatSpec with SystemTest with Matchers {

  "A GranularBigVector" should "handle large push/pull requests" in withMaster { _ =>
    withServers(2) { _ =>
      withClient { client =>
        val size = 1e6.toInt
        val rng = new Random()
        rng.setSeed(42)
        val model = client.vector[Double](size)
        val granularModel = new GranularBigVector(model, 1000)
        val keys = (0 until size).map(_.toLong).toArray
        val values = Array.fill(size) { rng.nextDouble() }

        whenReady(granularModel.push(keys, values)) {
          identity
        }
        val result = whenReady(granularModel.pull(keys)) {
          identity
        }

        result shouldEqual values
      }
    }
  }

  it should "have non-zero max message size" in withMaster { _ =>
    withServers(2) { _ =>
      withClient { client =>
        val model = client.vector[Double](10)
        intercept[IllegalArgumentException] {
          new GranularBigVector(model, 0)
        }
        intercept[IllegalArgumentException] {
          new GranularBigVector(model, -1)
        }
      }
    }
  }

} 
Example 47
Source File: ColumnIteratorSpec.scala    From glint   with MIT License 5 votes vote down vote up
package glint.iterators

import akka.util.Timeout
import glint.SystemTest
import glint.mocking.MockBigMatrix
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._


class ColumnIteratorSpec extends FlatSpec with SystemTest with Matchers {

  "A ColumnIterator" should "iterate over all columns in order" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 2
    val nrOfCols = 4
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows   = Array(0L, 1L, 0L, 1L, 0L, 1L, 0L, 1L)
    val cols   = Array( 0,  0,  1,  1,  2,  2,  3,  3)
    val values = Array(0L,  1,  2,  3,  4,  5,  6,  7)

    whenReady(mockMatrix.push(rows, cols, values)) { identity }

    // Check whether elements are in order
    var counter = 0
    val iterator = new ColumnIterator[Long](mockMatrix)
    iterator.foreach {
      case column => column.foreach {
        case value =>
          assert(value == counter)
          counter += 1
      }
    }

  }

  it should "iterate over all columns in order with larger rows" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 4
    val nrOfCols = 2
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows   = Array(0L, 1L, 2L, 3L, 0L, 1L, 2L, 3L)
    val cols   = Array( 0,  0,  0,  0,  1,  1,  1,  1)
    val values = Array(0L,  1,  2,  3,  4,  5,  6,  7)

    whenReady(mockMatrix.push(rows, cols, values)) { identity }

    // Check whether elements are in order
    var counter = 0
    val iterator = new ColumnIterator[Long](mockMatrix)
    iterator.foreach {
      case column => column.foreach {
        case value =>
          assert(value == counter)
          counter += 1
      }
    }

  }

  it should "not iterate over an empty matrix" in {
    val mockMatrix = new MockBigMatrix[Double](0, 2, 0, _ + _)

    val iterator = new ColumnIterator[Double](mockMatrix)
    assert(!iterator.hasNext)
    iterator.foreach {
      case _ => fail("This should never execute")
    }

  }

} 
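
The values pushed above are laid out so that column-major traversal yields 0..7 in order: the element at (row, col) is col * nrOfRows + row. A two-line check of that index arithmetic:

val (nrOfRows, nrOfCols) = (2, 4)
val columnMajor = for (col <- 0 until nrOfCols; row <- 0 until nrOfRows) yield col * nrOfRows + row
assert(columnMajor == (0 until nrOfRows * nrOfCols))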
Example 48
Source File: RowBlockIteratorSpec.scala    From glint   with MIT License 5 votes vote down vote up
package glint.iterators

import akka.util.Timeout
import glint.SystemTest
import glint.mocking.MockBigMatrix
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._


class RowBlockIteratorSpec extends FlatSpec with SystemTest with Matchers {

  "A RowBlockIterator" should "iterate over all blocks of rows in order" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 5
    val nrOfCols = 2
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows   = Array(0L, 0L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L)
    val cols   = Array( 0,  1,  0,  1,  0,  1,  0,  1,  0,  1)
    val values = Array(0L,  1,  2,  3,  4,  5,  6,  7,  8,  9)

    whenReady(mockMatrix.push(rows, cols, values)) { identity }

    // Check whether elements are in order
    var counter = 0
    val iterator = new RowBlockIterator[Long](mockMatrix, 2)
    iterator.foreach {
      case rows => rows.foreach {
        case row => row.foreach {
          case value =>
            assert(value == counter)
            counter += 1
        }
      }
    }

  }

  it should "iterate over a single block" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 5
    val nrOfCols = 2
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows = Array(0L, 0L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L)
    val cols = Array(0, 1, 0, 1, 0, 1, 0, 1, 0, 1)
    val values = Array(0L, 1, 2, 3, 4, 5, 6, 7, 8, 9)

    whenReady(mockMatrix.push(rows, cols, values)) {
      identity
    }

    // Check whether elements are in order
    var counter = 0
    val iterator = new RowBlockIterator[Long](mockMatrix, 7)
    val resultRows = iterator.next()
    assert(!iterator.hasNext)
    resultRows.foreach {
      case row => row.foreach {
        case value =>
          assert(value == counter)
          counter += 1
      }
    }

  }

  it should "not iterate over an empty matrix" in {
    val mockMatrix = new MockBigMatrix[Double](0, 5, 0, _ + _)

    val iterator = new RowBlockIterator[Double](mockMatrix, 3)
    assert(!iterator.hasNext)
    iterator.foreach {
      case _ => fail("This should never execute")
    }

  }

} 
Example 49
Source File: RowIteratorSpec.scala    From glint   with MIT License 5 votes vote down vote up
package glint.iterators

import akka.util.Timeout
import glint.SystemTest
import glint.mocking.MockBigMatrix
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext
import scala.concurrent.duration._


class RowIteratorSpec extends FlatSpec with SystemTest with Matchers {

  "A RowIterator" should "iterate over all rows in order" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 5
    val nrOfCols = 2
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows   = Array(0L, 0L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L)
    val cols   = Array( 0,  1,  0,  1,  0,  1,  0,  1,  0,  1)
    val values = Array(0L,  1,  2,  3,  4,  5,  6,  7,  8,  9)

    whenReady(mockMatrix.push(rows, cols, values)) { identity }

    // Check whether elements are in order
    var counter = 0
    val iterator = new RowIterator[Long](mockMatrix, 2)
    iterator.foreach {
      case row => row.foreach {
        case value =>
          assert(value == counter)
          counter += 1
      }
    }

  }

  it should "iterate over a single block" in {

    // Construct mock matrix and data to push into it
    val nrOfRows = 5
    val nrOfCols = 2
    val mockMatrix = new MockBigMatrix[Long](nrOfRows, nrOfCols, 0, _ + _)

    val rows = Array(0L, 0L, 1L, 1L, 2L, 2L, 3L, 3L, 4L, 4L)
    val cols = Array(0, 1, 0, 1, 0, 1, 0, 1, 0, 1)
    val values = Array(0L, 1, 2, 3, 4, 5, 6, 7, 8, 9)

    whenReady(mockMatrix.push(rows, cols, values)) {
      identity
    }

    // Check whether elements are in order
    var counter = 0
    val iterator = new RowIterator[Long](mockMatrix, 7)
    iterator.foreach {
      case row => row.foreach {
        case value =>
          assert(value == counter)
          counter += 1
      }
    }
    assert(!iterator.hasNext)

  }

  it should "not iterate over an empty matrix" in {
    val mockMatrix = new MockBigMatrix[Double](3, 0, 0, _ + _)

    val iterator = new RowIterator[Double](mockMatrix, 3)
    assert(!iterator.hasNext)
    iterator.foreach {
      case _ => fail("This should never execute")
    }

  }

} 
Example 50
Source File: Test_Flat_03_ScalaTest.scala    From LearningScala   with Apache License 2.0 5 votes vote down vote up
package _100_assertions_and_tests

// BDD
import _100_assertions_and_tests._03_ScalaTest._
import org.scalatest.{FlatSpec, Matchers}


class Test_Flat_03_ScalaTest extends FlatSpec with Matchers {

  "The account" should "approve deposit of a positive amount" in {
    val currentBalance = balance
    val amount = 20
    deposit(amount)
    balance should be(currentBalance + amount)
  }

  it should "approve withdraw of positive amount if amount is sufficient" in {
    balance = 20
    val currentBalance = balance
    val amount = 10
    withdraw(amount)
    balance should be(currentBalance - amount)
  }

  it must "decline negative amount in withdraw" in {
    balance = 20
    val currentBalance = balance
    val amount = -20
    an[Error] must be thrownBy {
      withdraw(amount)
      balance = currentBalance - amount
    }
  }

  it must "decline withdraw if amount > balance" in {
    balance = 20
    val currentBalance = balance
    val amount = 100
    an[Error] must be thrownBy {
      withdraw(amount)
      balance = currentBalance - amount
    }
  }
} 
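
A note on Example 50: the test imports `_100_assertions_and_tests._03_ScalaTest._`, which is not shown on this page. Below is a minimal sketch of what that object could look like, with names and behaviour inferred solely from the assertions above; the real object in the LearningScala project may differ.

package _100_assertions_and_tests

// Hypothetical sketch of the object exercised by Test_Flat_03_ScalaTest.
// Only the behaviour implied by the assertions above is reproduced here.
object _03_ScalaTest {
  var balance: Int = 0

  def deposit(amount: Int): Unit =
    balance += amount

  def withdraw(amount: Int): Unit = {
    if (amount <= 0) throw new Error("withdraw amount must be positive")
    if (amount > balance) throw new Error("insufficient balance")
    balance -= amount
  }
}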
Example 51
Source File: ModelSpec.scala    From scala-spark-cab-rides-predictions   with MIT License 5 votes vote down vote up
import MockedResponse._
import com.lyft.networking.apiObjects.CostEstimate
import com.uber.sdk.rides.client.model.PriceEstimate
import models._
import org.scalatest.{FlatSpec, Matchers}

class ModelSpec extends FlatSpec with Matchers {
  behavior of "Models"
  "UberPriceModel" should "return CabPrice" in {
    import com.google.gson.Gson
    val g = new Gson

    val priceEstimate = g.fromJson(priceEstimateJson, classOf[PriceEstimate])
    val source = Location("source_test", 45f, 34f)
    val destination = Location("destination_test", 55f, 34f)

    UberPriceModel(priceEstimate, source, destination) should matchPattern {
      case CabPrice("Uber", "a1111c8c-c720-46c3-8534-2fcdd730040d", "uberX", _, Some(6.17f), 1.0, _, "source_test", "destination_test", _) =>
    }
  }

  "LyftPriceModel" should "return CabPrice" in {
    import com.google.gson.Gson

    val g = new Gson

    val costEstimate = g.fromJson(costEstimateJson, classOf[CostEstimate])
    val source = Location("source_test", 45f, 34f)
    val destination = Location("destination_test", 55f, 34f)

    LyftPriceModel(costEstimate, source, destination) should matchPattern {
      case CabPrice("Lyft", "lyft_plus", "Lyft Plus", _, Some(3.29f), 1.25, _, "source_test", "destination_test", _) =>
    }
  }

  "WeatherModel" should "return Weather" in {
    val location = Location("source_test", 45f, 34f)
    WeatherModel(weatherResponse, location) should matchPattern {
      case Weather("source_test", Some(40.38f), Some(0.64f), Some(995.66f), Some(0.83f), None, None, Some(11.05f), 1543448090) =>
    }
  }
} 
Example 52
Source File: LocationSpec.scala    From scala-spark-cab-rides-predictions   with MIT License 5 votes vote down vote up
import models.{Location, LocationRepository}
import org.scalatest.{FlatSpec, Matchers}

import scala.collection.Seq

class LocationSpec extends FlatSpec with Matchers {
  behavior of "Location"

  "repository locations" should " return Seq[Location]" in {
    val locations = LocationRepository.getLocations
    locations should matchPattern {
      case _: Seq[Location] =>
    }
  }

  "random pairing of locations" should "return Seq oftype of locations" in {
    val locationsTuples = LocationRepository.getPairedLocations
    locationsTuples should matchPattern {
      case _: Seq[(Location, Location)] =>
    }
  }


} 
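
A note on the `matchPattern` checks in Example 52: due to JVM type erasure, a pattern such as `case _: Seq[(Location, Location)]` only verifies at runtime that the value is a `Seq`; the element types are not checked. A hypothetical, erasure-aware variant could assert on each element instead (a sketch, not part of the original spec):

import models.{Location, LocationRepository}
import org.scalatest.{FlatSpec, Matchers}

// Hypothetical stricter variant: check the runtime shape of every pair element.
class LocationPairsErasureAwareSpec extends FlatSpec with Matchers {
  "random pairing of locations" should "produce pairs of Location instances" in {
    val locationsTuples = LocationRepository.getPairedLocations
    locationsTuples.foreach { pair =>
      pair should matchPattern { case (_: Location, _: Location) => }
    }
  }
}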
Example 53
Source File: TemplateRenderingSpec.scala    From avoin-voitto   with MIT License 5 votes vote down vote up
package liigavoitto.journalist.utils

import liigavoitto.util.Logging
import org.scalatest.{FlatSpec, Matchers}

class TemplateRenderingSpec extends FlatSpec with Matchers with Logging with WeightedRandomizer {

  val sampleSize = 1000000
  val maxDeviationThreshold = 0.1

  "TemplateRendering" should "render a template without variables and preserve weight" in {
    val template = Template("This is a template.")
    val result = TemplateRendering.render(template, Map())
    result shouldBe Some(RenderedTemplate(template.template, template.weight))
  }

  it should "render a template with a variable and a custom weight" in {
    val template = Template("This is a {{template}}.", 0.5)
    val attr = Map("template" -> "text")
    val result = TemplateRendering.render(template, attr)
    result shouldBe Some(RenderedTemplate("This is a text.", 0.5))
  }

  "WeightedRandomizer" should "always pick from a list of 1" in {
    val templateList = List(RenderedTemplate("test", 1.0))
    val result = weightedRandom(templateList)
    result shouldBe Some("test")
  }

  it should "return None with an empty list" in {
    weightedRandom(List()) shouldBe None
  }

  it should "return equal amounts with default weights" in {
    def getRandom = {
      weightedRandom(List(
        RenderedTemplate("first", 1.0),
        RenderedTemplate("second", 1.0),
        RenderedTemplate("third", 1.0)
      ))
    }

    val results = (1 to sampleSize).flatMap(_ => getRandom)
    val expected = sampleSize / 3

    results.count(t => t == "first") should beCloseToExpected(expected)
    results.count(t => t == "second") should beCloseToExpected(expected)
    results.count(t => t == "third") should beCloseToExpected(expected)
  }

  it should "return different amounts with weighted templates" in {
    def getRandom = {
      weightedRandom(List(
        RenderedTemplate("first", 1.0),
        RenderedTemplate("second", 0.5)
      ))
    }

    val results = (1 to sampleSize).flatMap(_ => getRandom)
    val expected = sampleSize / 3
    results.count(t => t == "second") should beCloseToExpected(expected)
  }

  it should "work even if the templates are not sorted by weight" in {
    def getRandom = {
      weightedRandom(List(
        RenderedTemplate("first", 1.0),
        RenderedTemplate("second", 0.5),
        RenderedTemplate("third", 1.0),
        RenderedTemplate("fourth", 0.5),
        RenderedTemplate("fifth", 1.0)
      ))
    }

    val results = (1 to sampleSize).flatMap(_ => getRandom)

    // each 1.0 weight template should have 1/4 of the results
    val expected = sampleSize / 4
    results.count(t => t == "first") should beCloseToExpected(expected)
    results.count(t => t == "third") should beCloseToExpected(expected)
    results.count(t => t == "fifth") should beCloseToExpected(expected)

    // second and fourth should have 1/4 combined
    results.count(t => t == "second" || t == "fourth") should beCloseToExpected(expected)
  }

  def beCloseToExpected(expected: Double) = {
    val lowerBound = (expected - (expected * maxDeviationThreshold)).toInt
    val higherBound = (expected + (expected * maxDeviationThreshold)).toInt
    be >= lowerBound and be <= higherBound
  }
} 
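
The `beCloseToExpected` helper in Example 53 builds a composite matcher with `and`. The same idea can be packaged as a reusable tolerance matcher; the sketch below uses only the standard `Matchers` DSL, and the class and method names are made up for illustration.

import org.scalatest.{FlatSpec, Matchers}

// Minimal sketch: a "within tolerance" matcher composed from `be >=` and `be <=`,
// mirroring beCloseToExpected above.
class ToleranceMatcherSketch extends FlatSpec with Matchers {

  def beWithin(expected: Double, tolerance: Double) = {
    val lower = (expected - expected * tolerance).toInt
    val upper = (expected + expected * tolerance).toInt
    be >= lower and be <= upper
  }

  "A sampled count" should "stay within 10% of the expected value" in {
    val count = 333111
    count should beWithin(expected = 1000000 / 3.0, tolerance = 0.1)
  }
}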
Example 54
Source File: MustacheSpec.scala    From avoin-voitto   with MIT License 5 votes vote down vote up
package robottijournalismi.service.journalist.utils

import liigavoitto.journalist.utils.Mustache
import liigavoitto.journalist.utils.Mustache.ValueNotFoundException
import org.scalatest.{ FlatSpec, Matchers }

class MustacheSpec extends FlatSpec with Matchers {

  "Mustache" should "fail rendering when value is not found" in {
    val ctx = Map("something" -> "yes")
    assertThrows[ValueNotFoundException] {
      Mustache("testing {{something}}, {{eiole}}").apply(ctx)
    }
  }

  it should "render a tag with ':' in it" in {
    val ctx = Map("winner:n" -> "voittajan")
    Mustache("Löysin {{winner:n}}.").apply(ctx) shouldEqual "Löysin voittajan."
  }

} 
Example 55
Source File: TemplateUtilsSpec.scala    From avoin-voitto   with MIT License 5 votes vote down vote up
package liigavoitto.journalist.utils

import liigavoitto.journalist.MockData
import liigavoitto.scores._
import org.scalatest.{FlatSpec, Matchers}

class TemplateUtilsSpec extends FlatSpec with Matchers with TemplateUtils with MockData {

  "TemplateUtilsSpec" should "return correct numeral declensions" in {
    val numberFi = numeralWithDeclensions("number6", 6, "fi")
    val numberSv = numeralWithDeclensions("number6", 6, "sv")

    numberFi("number6:nominative") shouldBe "kuusi"
    numberFi("number6:ordinalGenitive") shouldBe "kuudennen"

    numberSv("number6:nominative") shouldBe "sex"
  }

  it should "return correct declension for name" in {
    val feedPlayerDeclensions = feedPlayerWithDeclensions("testPlayer", FeedPlayer("jkl-29837610", PlayerName("Teemu", "Väyrynen"), None))
    feedPlayerDeclensions("testPlayer") shouldBe "Teemu Väyrynen"
    feedPlayerDeclensions("testPlayer:genitive") shouldBe "Teemu Väyrysen"
    feedPlayerDeclensions("testPlayer:ablative") shouldBe "Teemu Väyryseltä"
    feedPlayerDeclensions("testPlayer:allative") shouldBe "Teemu Väyryselle"
    feedPlayerDeclensions("testPlayer:adessive") shouldBe "Teemu Väyrysellä"
    feedPlayerDeclensions("testPlayer.last") shouldBe "Väyrynen"
    feedPlayerDeclensions("testPlayer.last:genitive") shouldBe "Väyrysen"
    feedPlayerDeclensions("testPlayer.last:ablative") shouldBe "Väyryseltä"
    feedPlayerDeclensions("testPlayer.last:allative") shouldBe "Väyryselle"
    feedPlayerDeclensions("testPlayer.last:adessive") shouldBe "Väyrysellä"
  }

  it should "return correct declension for team" in {
    val teamMeta = Meta(List[Image](), None)
    val teamDeclension = teamWithDeclensions("testTeam", Team("jkl-624554857", "Lukko", "Lukko", teamMeta, None, List[Player]()), "fi")
    teamDeclension("testTeam:genitive") shouldBe "Lukon"
    teamDeclension("testTeam:accusative") shouldBe "Lukon"
    teamDeclension("testTeam:allative") shouldBe "Lukolle"
    teamDeclension("testTeam:elative") shouldBe "Lukosta"
    teamDeclension("testTeam:partitive") shouldBe "Lukkoa"
  }
} 
Example 56
Source File: TemplateLoaderSpec.scala    From avoin-voitto   with MIT License 5 votes vote down vote up
package liigavoitto.journalist.utils

import org.scalatest.{FlatSpec, Matchers}

class TemplateLoaderSpec extends FlatSpec with Matchers with TemplateLoader {

  val testData = load("template/loader-test.edn", "score-over-three-goals", "fi")
  "TemplateLoader" should "load templates from an edn file" in {
    testData.length shouldBe 2
    testData.head shouldBe Template("{{bestPlayer}} tykitti {{bestPlayer.goals:text}} maalia")
  }
  
  it should "parse template weight if set" in {
    testData(1).weight shouldBe 0.5
  }
} 
Example 57
Source File: ScoresFromFilesFetcherSpec.scala    From avoin-voitto   with MIT License 5 votes vote down vote up
package liigavoitto.fetch

import org.scalatest.{FlatSpec, Matchers}

class ScoresFromFilesFetcherSpec extends FlatSpec with Matchers {

  "ScoresFromFilesFetcher" should "fetch match data from files" in {
    val matchId = "jkl-0-2018-3814"
    val fetcher = new ScoresFromFilesFetcher(matchId)
    val data = fetcher.getEnrichedMatchData

    data.mtch.id shouldEqual matchId
    data.seriesId shouldEqual "mestis"
    data.allAwayTeamMatches.length shouldEqual 38
    data.allHomeTeamMatches.length shouldEqual 36
    data.playerStats.length shouldEqual 466
    data.leagueTable.length shouldEqual 12
  }

  it should "filter matches in match list in older games" in {
    val matchId = "jkl-0-2018-3748"
    val fetcher = new ScoresFromFilesFetcher(matchId)
    val data = fetcher.getEnrichedMatchData

    data.mtch.id shouldEqual matchId
    data.seriesId shouldEqual "mestis"
    data.allAwayTeamMatches.length shouldEqual 17
    data.allHomeTeamMatches.length shouldEqual 18
  }
} 
Example 58
Source File: ServerSpec.scala    From seals   with Apache License 2.0 5 votes vote down vote up
package com.example.server

import java.util.concurrent.Executors

import scala.concurrent.ExecutionContext

import cats.effect.{ IO, Blocker, ContextShift }

import org.scalatest.{ FlatSpec, Matchers, BeforeAndAfterAll }

import fs2.{ Stream, Chunk }

import scodec.bits._
import scodec.Codec

import dev.tauri.seals.scodec.Codecs._

import com.example.proto._

class ServerSpec extends FlatSpec with Matchers with BeforeAndAfterAll {

  implicit val cs: ContextShift[IO] = IO.contextShift(ExecutionContext.global)

  val ex = Executors.newCachedThreadPool()
  val ec = ExecutionContext.fromExecutor(ex)
  val bl = Blocker.liftExecutionContext(ec)
  val (sg, closeSg) = fs2.io.tcp.SocketGroup[IO](bl).allocated.unsafeRunSync()

  override def afterAll(): Unit = {
    super.afterAll()
    closeSg.unsafeRunSync()
    ex.shutdown()
  }

  "Server" should "respond to a request" in {
    val responses: Vector[Response] = Stream(
      Server.serve(Server.port, sg).drain,
      client(Server.port)
    ).parJoin(Int.MaxValue).take(1L).compile.toVector.unsafeRunSync()
    responses should === (Vector(Ok))
  }

  def client(port: Int): Stream[IO, Response] = {
    Stream.resource(sg.client[IO](Server.addr(port))).flatMap { socket =>
      val bvs: Stream[IO, BitVector] = Stream(Codec[Request].encode(ReSeed(56)).require)
      val bs: Stream[IO, Byte] = bvs.flatMap { bv =>
        Stream.chunk(Chunk.bytes(bv.bytes.toArray))
      }
      val read = bs.through(socket.writes(Server.timeout)).drain.onFinalize(socket.endOfOutput) ++
        socket.reads(Server.bufferSize, Server.timeout).chunks.map(ch => BitVector.view(ch.toArray))
      read.fold(BitVector.empty)(_ ++ _).map(bv => Codec[Response].decode(bv).require.value)
    }
  }
} 
Example 59
Source File: ExceptionSerializerSpec.scala    From reliable-http-client   with Apache License 2.0 5 votes vote down vote up
package rhttpc.transport.json4s

import org.json4s.{DefaultFormats, TypeHints}
import org.scalatest.{FlatSpec, Matchers, TryValues}

class ExceptionSerializerSpec extends FlatSpec with Matchers with TryValues {

  it should "round-trip serialize case class exception" in {
    roundTrip(CaseClassException(123))
  }

  it should "round-trip serialize exception with message" in {
    roundTrip(new ExceptionWithMessage("foo"))
  }

  it should "round-trip serialize exception with null message" in {
    roundTrip(new ExceptionWithMessage(null))
  }

  it should "round-trip serialize exception with message and cause" in {
    roundTrip(new ExceptionWithMessageAndCause("foo", CaseClassException(123)))
  }

  private def roundTrip(ex: Throwable): Unit = {
    implicit val formats = new DefaultFormats {
      override val typeHints: TypeHints = AllTypeHints
    } + ExceptionSerializer
    val serializer = new Json4sSerializer[Throwable]()
    val deserializer = new Json4sDeserializer[Throwable]()
    val serialized = serializer.serialize(ex)
    val deserialized = deserializer.deserialize(serialized)
    deserialized.success.value shouldEqual ex
  }

}

case class CaseClassException(x: Int) extends Exception(s"x: $x")

class ExceptionWithMessage(msg: String) extends Exception(msg) {
  def canEqual(other: Any): Boolean = other.isInstanceOf[ExceptionWithMessage]

  override def equals(other: Any): Boolean = other match {
    case that: ExceptionWithMessage =>
      (that canEqual this) &&
        getMessage == that.getMessage
    case _ => false
  }

  override def hashCode(): Int = {
    val state = Seq(getMessage)
    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
  }
}

class ExceptionWithMessageAndCause(msg: String, cause: Throwable) extends Exception(msg, cause) {
  def canEqual(other: Any): Boolean = other.isInstanceOf[ExceptionWithMessageAndCause]

  override def equals(other: Any): Boolean = other match {
    case that: ExceptionWithMessageAndCause =>
      (that canEqual this) &&
        getMessage == that.getMessage &&
        getCause == that.getCause
    case _ => false
  }

  override def hashCode(): Int = {
    val state = Seq(getMessage, getCause)
    state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
  }
} 
Example 60
Source File: ConfigParserSpec.scala    From reliable-http-client   with Apache License 2.0 5 votes vote down vote up
package rhttpc.client.config

import com._
import org.scalatest.{FlatSpec, Matchers}
import rhttpc.client.proxy.{BackoffRetry, HandleAll}

import scala.concurrent.duration._

class ConfigParserSpec extends FlatSpec with Matchers {

  it should "parse config with backoff strategy" in {
    val config = typesafe.config.ConfigFactory.parseString(
      """x {
        |  queuesPrefix = "rhttpc"
        |  batchSize = 10
        |  parallelConsumers = 1
        |  retryStrategy {
        |    initialDelay = 5 seconds
        |    multiplier = 1.2
        |    maxRetries = 3
        |  }
        |}
      """.stripMargin)

    ConfigParser.parse(config, "x") shouldEqual RhttpcConfig("rhttpc", 10, 1, BackoffRetry(5.seconds, 1.2, 3, None))
  }

  it should "parse config with publish all strategy" in {
    val config = typesafe.config.ConfigFactory.parseString(
      """x {
        |  queuesPrefix = "rhttpc"
        |  batchSize = 10
        |  parallelConsumers = 1
        |  retryStrategy = handle-all
        |}
      """.stripMargin)

    ConfigParser.parse(config, "x") shouldEqual RhttpcConfig("rhttpc", 10, 1, HandleAll)
  }

  it should "parse config with backoff with deadline strategy" in {
    val config = typesafe.config.ConfigFactory.parseString(
      """x {
        |  queuesPrefix = "rhttpc"
        |  batchSize = 10
        |  parallelConsumers = 1
        |  retryStrategy {
        |    initialDelay = 5 seconds
        |    multiplier = 1.2
        |    maxRetries = 3
        |    deadline = 5 seconds
        |  }
        |}
      """.stripMargin)

    ConfigParser.parse(config, "x") shouldEqual RhttpcConfig("rhttpc", 10, 1, BackoffRetry(5.seconds, 1.2, 3, Some(5.seconds)))
  }
} 
Example 61
Source File: DataStreamExpressionsTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.datastream

import io.eels.Row
import io.eels.schema.{DoubleType, Field, IntType, LongType, StructType}
import org.scalatest.{FlatSpec, Matchers}

class DataStreamExpressionsTest extends FlatSpec with Matchers {

  val schema = StructType(
    Field("artist"),
    Field("year", IntType()),
    Field("album"),
    Field("sales", LongType())
  )
  val ds = DataStream.fromRows(schema,
    Row(schema, Vector("Elton John", 1969, "Empty Sky", 1433)),
    Row(schema, Vector("Elton John", 1971, "Madman Across the Water", 7636)),
    Row(schema, Vector("Elton John", 1972, "Honky Château", 2525)),
    Row(schema, Vector("Elton John", 1973, "Goodbye Yellow Brick Road", 4352)),
    Row(schema, Vector("Elton John", 1975, "Rock of the Westies", 5645)),
    Row(schema, Vector("Kate Bush", 1978, "The Kick Inside", 2577)),
    Row(schema, Vector("Kate Bush", 1978, "Lionheart", 745)),
    Row(schema, Vector("Kate Bush", 1980, "Never for Ever", 7444)),
    Row(schema, Vector("Kate Bush", 1982, "The Dreaming", 8253)),
    Row(schema, Vector("Kate Bush", 1985, "Hounds of Love", 2495))
  )

  "DataStream.filter" should "support expressions" in {
    import io.eels._
    ds.filter(select("album") === "Lionheart")
      .collectValues shouldBe Vector(Vector("Kate Bush", 1978, "Lionheart", 745))
  }

  "DataStream.addField" should "support multiply expressions" in {
    import io.eels._
    ds.filter(select("album") === "Lionheart")
      .addField(Field("woo", DoubleType), select("year") * 1.2)
      .collectValues shouldBe Vector(Vector("Kate Bush", 1978, "Lionheart", 745, BigDecimal(2373.6)))
  }

  "DataStream.addField" should "support addition expressions" in {
    import io.eels._
    ds.filter(select("album") === "Lionheart")
      .addField(Field("woo", DoubleType), select("year") + 1.2)
      .collectValues shouldBe Vector(Vector("Kate Bush", 1978, "Lionheart", 745, BigDecimal(1979.2)))
  }

  "DataStream.addField" should "support subtraction expressions" in {
    import io.eels._
    ds.filter(select("album") === "Lionheart")
      .addField(Field("woo", DoubleType), select("year") - 1.2)
      .collectValues shouldBe Vector(Vector("Kate Bush", 1978, "Lionheart", 745, BigDecimal(1976.8)))
  }

  "DataStream.addField" should "support division expressions" in {
    import io.eels._
    ds.filter(select("album") === "Lionheart")
      .addField(Field("woo", DoubleType), select("year") / 2)
      .collectValues shouldBe Vector(Vector("Kate Bush", 1978, "Lionheart", 745, 989))
  }
} 
Example 62
Source File: ParquetProjectionTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.parquet

import java.io.{File, FilenameFilter}

import io.eels.datastream.DataStream
import io.eels.schema.{Field, StringType, StructType}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.scalatest.{FlatSpec, Matchers}

class ParquetProjectionTest extends FlatSpec with Matchers {

  cleanUpResidualParquetTestFiles

  private val schema = StructType(
    Field("name", StringType, nullable = false),
    Field("job", StringType, nullable = false),
    Field("location", StringType, nullable = false)
  )
  private val ds = DataStream.fromValues(
    schema,
    Seq(
      Vector("clint eastwood", "actor", "carmel"),
      Vector("elton john", "musician", "pinner")
    )
  )

  private implicit val conf = new Configuration()
  private implicit val fs = FileSystem.get(new Configuration())
  private val file = new File(s"test_${System.currentTimeMillis()}.pq")
  file.deleteOnExit()
  private val path = new Path(file.toURI)

  if (fs.exists(path))
    fs.delete(path, false)

  ds.to(ParquetSink(path).withOverwrite(true))

  "ParquetSource" should "support projections" in {
    val rows = ParquetSource(path).withProjection("name").toDataStream().collect
    rows.map(_.values) shouldBe Vector(Vector("clint eastwood"), Vector("elton john"))
  }

  it should "return all data when no projection is set" in {
    val rows = ParquetSource(path).toDataStream().collect
    rows.map(_.values) shouldBe Vector(Vector("clint eastwood", "actor", "carmel"), Vector("elton john", "musician", "pinner"))
  }

  private def cleanUpResidualParquetTestFiles = {
    new File(".").listFiles(new FilenameFilter {
      override def accept(dir: File, name: String): Boolean = {
        (name.startsWith("test_") && name.endsWith(".pq")) || (name.startsWith(".test_") && name.endsWith(".pq.crc"))
      }
    }).foreach(_.delete())
  }

} 
Example 63
Source File: AvroAndParquetCrossCompatibilityTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.parquet

import io.eels.component.parquet.avro.{AvroParquetSink, AvroParquetSource}
import io.eels.datastream.DataStream
import io.eels.schema.{Field, StringType, StructType}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.scalatest.{FlatSpec, Matchers}

// tests that avro source/sink and avro parquet source/sink can write/read each others files
class AvroAndParquetCrossCompatibilityTest extends FlatSpec with Matchers {

  private implicit val conf = new Configuration()
  private implicit val fs = FileSystem.get(new Configuration())

  "AvroParquetSource and ParquetSource" should "be compatible" in {

    val path = new Path("cross.pq")
    if (fs.exists(path))
      fs.delete(path, false)

    val structType = StructType(
      Field("name", StringType, nullable = false),
      Field("location", StringType, nullable = false)
    )

    val ds = DataStream.fromValues(
      structType,
      Seq(
        Vector("clint eastwood", "carmel"),
        Vector("elton john", "pinner")
      )
    )

    ds.to(ParquetSink(path))
    AvroParquetSource(path).toDataStream().collect shouldBe ds.collect
    fs.delete(path, false)

    ds.to(AvroParquetSink(path))
    ParquetSource(path).toDataStream().collect shouldBe ds.collect
    fs.delete(path, false)
  }
} 
Example 64
Source File: KuduComponentTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.kudu

import io.eels.Row
import io.eels.datastream.DataStream
import io.eels.schema._
import org.scalatest.{FlatSpec, Matchers, Tag}

object Kudu extends Tag("kudu")

class KuduComponentTest extends FlatSpec with Matchers {

  "kudu" should "support end to end sink to source" taggedAs Kudu in {

    val schema = StructType(
      Field("planet", StringType, nullable = false, key = true),
      Field("position", StringType, nullable = true)
    )

    val ds = DataStream.fromValues(
      schema,
      Seq(
        Vector("earth", 3),
        Vector("saturn", 6)
      )
    )

    val master = "localhost:7051"
    ds.to(KuduSink(master, "mytable"))

    val rows = KuduSource(master, "mytable").toDataStream().collect
    rows shouldBe Seq(
      Row(schema, Vector("earth", "3")),
      Row(schema, Vector("saturn", "6"))
    )
  }

  it should "support all basic types" taggedAs Kudu in {

    val schema = StructType(
      Field("planet", StringType, nullable = false, key = true),
      Field("position", ByteType.Signed, nullable = false),
      Field("volume", DoubleType, nullable = false),
      Field("bytes", BinaryType, nullable = false),
      Field("gas", BooleanType, nullable = false),
      Field("distance", LongType.Signed, nullable = false)
    )

    val data = Array("earth", 3: Byte, 4515135988.632, Array[Byte](1, 2, 3), false, 83000000)

    val ds = DataStream.fromValues(schema, Seq(data))

    val master = "localhost:7051"
    ds.to(KuduSink(master, "mytable2"))

    val rows = KuduSource(master, "mytable2").toDataStream().collect
    val values = rows.head.values.toArray
    data(3) = data(3).asInstanceOf[Array[Byte]].toList
    values shouldBe data
  }
} 
Example 65
Source File: HiveTableFilesFnTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.hive

import java.nio.file.Paths

import com.sksamuel.exts.Logging
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hdfs.MiniDFSCluster
import org.apache.hadoop.hive.metastore.IMetaStoreClient
import org.apache.hadoop.hive.metastore.api.Table
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

class HiveTableFilesFnTest extends FlatSpec with Matchers with Logging with MockitoSugar {

  System.clearProperty(MiniDFSCluster.PROP_TEST_BUILD_DATA)
  val clusterPath = Paths.get("miniclusters", "cluster")
  val conf = new Configuration()
  conf.set(MiniDFSCluster.HDFS_MINIDFS_BASEDIR, clusterPath.toAbsolutePath.toString)
  val cluster = new MiniDFSCluster.Builder(conf).build()
  implicit val fs = cluster.getFileSystem

  "HiveTableFilesFn" should "detect all files in root when no partitions" in {

    implicit val client = mock[IMetaStoreClient]
    org.mockito.Mockito.when(client.getTable("default", "mytable")).thenReturn(new Table)

    val root = new Path("tab1")
    fs.mkdirs(root)

    // table scanner will skip 0 length files
    val a = fs.create(new Path(root, "a"))
    a.write(1)
    a.close()

    val b = fs.create(new Path(root, "b"))
    b.write(1)
    b.close()

    HiveTableFilesFn("default", "mytable", fs.resolvePath(root), Nil).values.flatten.map(_.getPath.getName).toSet shouldBe Set("a", "b")
  }

  it should "ignore hidden files in root when no partitions" in {
    implicit val client = mock[IMetaStoreClient]
    org.mockito.Mockito.when(client.getTable("default", "mytable")).thenReturn(new Table)

    val root = new Path("tab2")
    fs.mkdirs(root)

    // table scanner will skip 0 length files
    val a = fs.create(new Path(root, "a"))
    a.write(1)
    a.close()

    val b = fs.create(new Path(root, "_b"))
    b.write(1)
    b.close()

    HiveTableFilesFn("default", "mytable", fs.resolvePath(root), Nil).values.flatten.map(_.getPath.getName).toSet shouldBe Set("a")
  }
} 
Example 66
Source File: OrcPredicateTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.orc

import java.io.{File, FilenameFilter}

import io.eels.Predicate
import io.eels.datastream.DataStream
import io.eels.schema.{Field, LongType, StringType, StructType}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

class OrcPredicateTest extends FlatSpec with Matchers with BeforeAndAfterAll {
  cleanUpResidualOrcTestFiles

  val schema = StructType(
    Field("name", StringType, nullable = true),
    Field("city", StringType, nullable = true),
    Field("age", LongType.Signed, nullable = true)
  )

  val values = Vector.fill(1000) {
    Vector("sam", "middlesbrough", 37)
  } ++ Vector.fill(1000) {
    Vector("laura", "iowa city", 24)
  }

  val ds = DataStream.fromValues(schema, values)

  implicit val conf = new Configuration()
  implicit val fs = FileSystem.get(new Configuration())
  val path = new Path("test.orc")

  if (fs.exists(path))
    fs.delete(path, false)

  new File(path.toString).deleteOnExit()

  ds.to(OrcSink(path).withRowIndexStride(1000))

  override protected def afterAll(): Unit = fs.delete(path, false)

  "OrcSource" should "support string equals predicates" in {
    conf.set("eel.orc.predicate.row.filter", "false")
    val rows = OrcSource(path).withPredicate(Predicate.equals("name", "sam")).toDataStream().collect
    rows.map(_.values).toSet shouldBe Set(Vector("sam", "middlesbrough", 37L))
  }

  it should "support gt predicates" in {
    conf.set("eel.orc.predicate.row.filter", "false")
    val rows = OrcSource(path).withPredicate(Predicate.gt("age", 30L)).toDataStream().collect
    rows.map(_.values).toSet shouldBe Set(Vector("sam", "middlesbrough", 37L))
  }

  it should "support lt predicates" in {
    conf.set("eel.orc.predicate.row.filter", "false")
    val rows = OrcSource(path).withPredicate(Predicate.lt("age", 30)).toDataStream().collect
    rows.map(_.values).toSet shouldBe Set(Vector("laura", "iowa city", 24L))
  }

  it should "enable row level filtering with predicates by default" in {
    conf.set("eel.orc.predicate.row.filter", "true")
    val rows = OrcSource(path).withPredicate(Predicate.equals("name", "sam")).toDataStream().collect
    rows.head.schema shouldBe schema
    rows.head.values shouldBe Vector("sam", "middlesbrough", 37L)
  }

  private def cleanUpResidualOrcTestFiles = {
    new File(".").listFiles(new FilenameFilter {
      override def accept(dir: File, name: String): Boolean = {
        (name.startsWith("test_") && name.endsWith(".orc")) || (name.startsWith(".test_") && name.endsWith(".orc.crc"))
      }
    }).foreach(_.delete())
  }
} 
Example 67
Source File: KafkaSinkTest.scala    From eel-sdk   with Apache License 2.0 5 votes vote down vote up
package io.eels.component.kafka

import java.util
import java.util.{Properties, UUID}

import io.eels.Row
import io.eels.datastream.DataStream
import io.eels.schema.{Field, StringType, StructType}
import net.manub.embeddedkafka.{EmbeddedKafka, EmbeddedKafkaConfig}
import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer.KafkaProducer
import org.apache.kafka.common.serialization.{Deserializer, Serializer}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

import scala.collection.JavaConverters._
import scala.util.Try

class KafkaSinkTest extends FlatSpec with Matchers with BeforeAndAfterAll {

  implicit val kafkaConfig = EmbeddedKafkaConfig(
    kafkaPort = 6001,
    zooKeeperPort = 6000
  )
  Try {
    EmbeddedKafka.start()
  }

  val schema = StructType(
    Field("name", StringType, nullable = true),
    Field("location", StringType, nullable = true)
  )

  val ds = DataStream.fromValues(
    schema,
    Seq(
      Vector("clint eastwood", UUID.randomUUID().toString),
      Vector("elton john", UUID.randomUUID().toString)
    )
  )

  "KafkaSink" should "support default implicits" ignore {

    val topic = "mytopic-" + System.currentTimeMillis()

    val properties = new Properties()
    properties.put("bootstrap.servers", s"localhost:${kafkaConfig.kafkaPort}")
    properties.put("group.id", "test")
    properties.put("auto.offset.reset", "earliest")

    val producer = new KafkaProducer[String, Row](properties, StringSerializer, RowSerializer)
    val sink = KafkaSink(topic, producer)

    val consumer = new KafkaConsumer[String, String](properties, StringDeserializer, StringDeserializer)
    consumer.subscribe(util.Arrays.asList(topic))

    ds.to(sink)
    producer.close()

    val records = consumer.poll(4000)
    records.iterator().asScala.map(_.value).toList shouldBe ds.collect.map {
      case Row(_, values) => values.mkString(",")
    }.toList
  }
}

object RowSerializer extends Serializer[Row] {
  override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = ()
  override def serialize(topic: String, data: Row): Array[Byte] = data.values.mkString(",").getBytes
  override def close(): Unit = ()
}

object StringSerializer extends Serializer[String] {
  override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = ()
  override def close(): Unit = ()
  override def serialize(topic: String, data: String): Array[Byte] = data.getBytes
}

object StringDeserializer extends Deserializer[String] {
  override def configure(configs: util.Map[String, _], isKey: Boolean): Unit = ()
  override def close(): Unit = ()
  override def deserialize(topic: String, data: Array[Byte]): String = new String(data)
} 
Example 68
Source File: KahanSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.math

import java.math.BigDecimal

import com.twosigma.flint.util.Timer
import org.scalatest.FlatSpec

import scala.util.Random

class KahanSpec extends FlatSpec {

  "Kahan" should "sum correctly in wiki example" in {
    val kahan = new Kahan()
    var i = 0
    while (i < 1000) {
      kahan.add(1.0)
      kahan.add(1.0e100)
      kahan.add(1.0)
      kahan.add(-1.0e100)
      i += 1
    }

    assert(kahan.value === 2000.0)
  }

  it should "sum correctly for constants of Double(s)" in {
    val kahan = new Kahan()
    val x = 1000.0002
    var sum = 0.0
    val bigDecimal = new BigDecimal(x)
    var bigDecimalSum = new BigDecimal(0.0)
    var i = 0
    while (i < (Int.MaxValue >> 5)) {
      sum += x
      kahan.add(x)
      bigDecimalSum = bigDecimalSum.add(bigDecimal)
      i += 1
    }
    assert(
      Math.abs(
        bigDecimalSum
          .subtract(new BigDecimal(kahan.value))
          .doubleValue()
      ) < 1.0e-5
    )

    assert(
      Math.abs(
        bigDecimalSum
          .subtract(new BigDecimal(sum))
          .doubleValue()
      ) > 1.0
    )
  }

  it should "subtract correctly" in {
    val kahan1 = new Kahan()
    val kahan2 = new Kahan()
    val x = 1000.0002
    var i = 0
    while (i < (Int.MaxValue >> 5)) {
      kahan1.add(x)
      kahan2.add(x)
      kahan2.add(x)
      i += 1
    }
    kahan2.subtract(kahan1)
    assert(kahan2.value === kahan1.value)
  }
} 
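
Example 68 exercises compensated summation. For reference, the textbook Kahan accumulator fits in a few lines; the sketch below is the standard algorithm, not necessarily how flint's `Kahan` class is implemented internally.

// Textbook Kahan (compensated) summation sketch; flint's Kahan may differ internally.
final class KahanSketch {
  private var sum = 0.0
  private var compensation = 0.0   // running low-order error term

  def add(x: Double): Unit = {
    val y = x - compensation       // correct the incoming value
    val t = sum + y                // low-order digits of y may be lost here
    compensation = (t - sum) - y   // recover what was lost
    sum = t
  }

  def value: Double = sum
}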
Example 69
Source File: SchemaSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.timeseries

import com.twosigma.flint.timeseries.row.Schema
import org.scalatest.FlatSpec
import org.apache.spark.sql.types._

class SchemaSpec extends FlatSpec {
  "Schema" should "create a schema correctly" in {
    val schema = StructType(Array(
      StructField("time", LongType),
      StructField("price", DoubleType)
    ))
    assert(Schema("time" -> LongType, "price" -> DoubleType) == schema)
    assert(Schema("price" -> DoubleType) == schema)
  }

  it should "`of` correctly" in {
    val schema = StructType(Array(
      StructField("foo", LongType),
      StructField("price", DoubleType)
    ))
    assert(Schema.of("foo" -> LongType, "price" -> DoubleType) == schema)
  }

  it should "create a time schema correctly without specifying any column" in {
    val schema = StructType(Array(
      StructField("time", LongType)
    ))
    assert(Schema() == schema)
  }

  it should "throw exception if `time` is not a LongType" in {
    intercept[IllegalArgumentException] {
      Schema("time" -> DoubleType)
    }
  }

  it should "throw an exception if you try to cast the `time` column" in {
    val schema = Schema()

    intercept[IllegalArgumentException] {
      Schema.cast(schema, "time" -> DoubleType)
    }
  }

  it should "throw an exception if you try to cast non-existing columns" in {
    val schema = Schema()

    intercept[IllegalArgumentException] {
      Schema.cast(schema, "foo" -> DoubleType)
    }
  }

  it should "throw an exception if you try to cast non-numeric columns" in {
    val schema = Schema("foo" -> StringType)

    intercept[IllegalArgumentException] {
      Schema.cast(schema, "foo" -> DoubleType)
    }
  }

  it should "cast numeric columns" in {
    val currentSchema = Schema("foo" -> IntegerType, "bar" -> ShortType)
    val newSchema = Schema.cast(currentSchema, "foo" -> DoubleType, "bar" -> IntegerType)

    assert(Schema.of("time" -> LongType, "foo" -> DoubleType, "bar" -> IntegerType) == newSchema)
  }
} 
Example 70
Source File: LinkedListHolderSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.util.collection

import org.scalatest.FlatSpec
import java.util.{ LinkedList => JLinkedList }
import scala.collection.JavaConverters._

class LinkedListHolderSpec extends FlatSpec {

  "LinkedListHolder" should "dropWhile correctly" in {
    import Implicits._
    val l = new JLinkedList[Int]()
    val n = 10
    val p = 3
    l.addAll((1 to n).asJava)

    val (dropped1, rest1) = l.dropWhile { i => i > p }
    assert(dropped1.size() == 0)
    assert(rest1.toArray.deep == l.toArray.deep)

    val (dropped2, rest2) = l.dropWhile { i => i < p }
    assert(dropped2.toArray.deep == (1 until p).toArray.deep)
    assert(rest2.toArray.deep == (p to n).toArray.deep)
  }

  it should "foldLeft correctly" in {
    import com.twosigma.flint.util.collection.Implicits._
    val l = new JLinkedList[Char]()
    assert(l.foldLeft("")(_ + _) == "")
    l.add('a')
    l.add('b')
    l.add('c')
    l.add('d')
    assert(l.foldLeft("")(_ + _) == "abcd")
  }

  it should "foldRight correctly" in {
    import com.twosigma.flint.util.collection.Implicits._
    val l = new JLinkedList[Char]()
    assert(l.foldRight("")(_ + _) == "")
    l.add('a')
    l.add('b')
    l.add('c')
    l.add('d')
    assert(l.foldRight("")(_ + _) == "dcba")
  }
} 
Example 71
Source File: RangeSplitSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import org.scalatest.FlatSpec

class RangeSplitSpec extends FlatSpec {

  val rangeSplits = IndexedSeq(
    RangeSplit(Split(0), CloseOpen(2, Option(3))),
    RangeSplit(Split(1), CloseOpen(3, Option(4))),
    RangeSplit(Split(2), CloseOpen(5, Option(7))),
    RangeSplit(Split(3), CloseOpen(7, Option(10)))
  )

  "The RangeSplit" should "getNextBegin correctly" in {
    val begins = rangeSplits.map(_.range.begin)
    assert(RangeSplit.getNextBegin(2, begins) == Some(3))
    assert(RangeSplit.getNextBegin(4, begins) == Some(5))
    assert(RangeSplit.getNextBegin(6, begins) == Some(7))
    assert(RangeSplit.getNextBegin(8, begins).isEmpty)
    assert(RangeSplit.getNextBegin(1, Vector[Int]()).isEmpty)
  }

  it should "getSplitsWithinRange correctly" in {
    assert(RangeSplit.getIntersectingSplits(CloseOpen(4, Some(6)), rangeSplits) ==
      List(RangeSplit(Split(2), CloseOpen(5, Some(7)))))

    assert(RangeSplit.getIntersectingSplits(CloseOpen(4, None), rangeSplits) ==
      List(RangeSplit(Split(2), CloseOpen(5, Some(7))), RangeSplit(Split(3), CloseOpen(7, Some(10)))))

  }
} 
Example 72
Source File: OrderedIteratorSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import org.scalatest.FlatSpec

class OrderedIteratorSpec extends FlatSpec {

  "OrderedKeyValueIterator" should "iterate through an ordered collection correctly" in {
    val data = (1 to 10).map { x => (x, x) }.toArray
    assert(data.deep == OrderedIterator(data.iterator).toArray.deep)
  }

  it should "filter on a range correctly" in {
    val data = (1 to 10).map { x => (x, x) }.toArray
    var range = CloseOpen(0, Some(1))
    assert(OrderedIterator(data.iterator).filterByRange(range).length == 0)

    range = CloseOpen(0, Some(2))
    assert(OrderedIterator(data.iterator).filterByRange(range).toArray.deep == Array((1, 1)).deep)

    range = CloseOpen(3, Some(5))
    assert(OrderedIterator(data.iterator).filterByRange(range).toArray.deep == Array((3, 3), (4, 4)).deep)

    range = CloseOpen(9, Some(20))
    assert(OrderedIterator(data.iterator).filterByRange(range).toArray.deep == Array((9, 9), (10, 10)).deep)

    range = CloseOpen(11, Some(20))
    assert(OrderedIterator(data.iterator).filterByRange(range).length == 0)
  }
} 
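
The assertions in Example 72 follow half-open interval semantics: `CloseOpen(b, Some(e))` keeps keys k with b <= k < e, and an open upper bound keeps everything from b onwards. Below is a plain-Scala sketch of that filter, for illustration only; it is not flint's `OrderedIterator`.

// Sketch of the [begin, end) filtering the assertions above rely on.
object CloseOpenFilterSketch {

  def filterByCloseOpen[K, V](it: Iterator[(K, V)], begin: K, end: Option[K])
                             (implicit ord: Ordering[K]): Iterator[(K, V)] =
    it.filter { case (k, _) => ord.gteq(k, begin) && end.forall(e => ord.lt(k, e)) }

  def main(args: Array[String]): Unit = {
    val data = (1 to 10).map(x => (x, x))
    // Keeps (3, 3) and (4, 4), matching the third assertion in the spec above.
    println(filterByCloseOpen(data.iterator, 3, Some(5)).toList)
  }
}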
Example 73
Source File: TreeAggregateSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd.function.summarize

import com.twosigma.flint.SharedSparkContext
import org.scalatest.FlatSpec
import org.scalatest.tagobjects.Slow

class TreeAggregateSpec extends FlatSpec with SharedSparkContext {

  "TreeAggregate" should "aggregate as RDD.aggregate in order for max op" taggedAs (Slow) in {
    val numOfPartitions = 1001
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => idx * scale + x }.toIterator
    }

    // Use -1 as a "bad" state and propagate through the aggregation, otherwise
    // it is just simply a Math.max() operator.
    val seqOp = (u: Int, t: Int) => if (u > t || u < 0) {
      -1
    } else {
      t
    }

    val combOp = (u1: Int, u2: Int) => if (u1 >= u2 || u1 < 0 || u2 < 0) {
      -1
    } else {
      u2
    }

    val expectedAggregatedResult = rdd.max()

    (1 to maxDepth).foreach {
      depth => assert(TreeAggregate(rdd)(0, seqOp, combOp, depth) == expectedAggregatedResult)
    }
  }

  it should "aggregate as RDD.aggregate in order for string concat" taggedAs (Slow) in {
    val numOfPartitions = 1001
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => s"${idx * scale + x}" }.toIterator
    }

    val seqOp = (u: String, t: String) => u + t
    val combOp = (u1: String, u2: String) => u1 + u2
    val expectedAggregatedResult = rdd.collect().mkString("")

    (1 to maxDepth).foreach {
      depth => assert(TreeAggregate(rdd)("", seqOp, combOp, depth) == expectedAggregatedResult)
    }
  }

  it should "aggregate as RDD.aggregate in order for sum" taggedAs (Slow) in {
    val numOfPartitions = 1001
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => idx * scale + x }.toIterator
    }

    val seqOp = (u: Int, t: Int) => u + t
    val combOp = (u1: Int, u2: Int) => u1 + u2
    val expectedAggregatedResult = rdd.sum()

    (1 to maxDepth).foreach {
      depth => assert(TreeAggregate(rdd)(0, seqOp, combOp, depth) == expectedAggregatedResult)
    }
  }

} 
Example 74
Source File: TreeReduceSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd.function.summarize

import com.twosigma.flint.SharedSparkContext
import org.scalatest.FlatSpec
import org.scalatest.tagobjects.Slow

class TreeReduceSpec extends FlatSpec with SharedSparkContext {

  "TreeReduce" should "reduce in order for max op" taggedAs (Slow) in {
    val numOfPartitions = 1023
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => idx * scale + x }.toIterator
    }

    // Use -1 as a "bad" state and propagate through the aggregation, otherwise
    // it is just simply a Math.max() operator.
    val op = (u1: Int, u2: Int) => if (u1 >= u2 || u1 < 0 || u2 < 0) {
      -1
    } else {
      u2
    }

    val expectedReducedResult = rdd.max()

    (1 to maxDepth).foreach {
      depth => assert(TreeReduce(rdd)(op, depth) == expectedReducedResult)
    }
  }

  it should "reduce in order for string concat" taggedAs (Slow) in {
    val numOfPartitions = 1111
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => s"${idx * scale + x}" }.toIterator
    }

    val f = (u1: String, u2: String) => u1 + u2
    val expectedReducedResult = rdd.collect().mkString("")

    (1 to maxDepth).foreach {
      depth => assert(TreeReduce(rdd)(f, depth) == expectedReducedResult)
    }
  }

  it should "reduce in order for sum op" taggedAs (Slow) in {
    val numOfPartitions = 1023
    val scale = 5
    val maxDepth = 5
    val rdd = sc.parallelize(1 to numOfPartitions, numOfPartitions).mapPartitionsWithIndex {
      (idx, _) => (1 to scale).map { x => idx * scale + x }.toIterator
    }

    val expectedReducedResult = rdd.sum()

    val f = (u1: Int, u2: Int) => u1 + u2

    (1 to maxDepth).foreach {
      depth => assert(TreeReduce(rdd)(f, depth) == expectedReducedResult)
    }
  }

} 
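
Examples 73 and 74 only assert that tree-shaped aggregation matches a flat, in-order reduce, for commutative and non-commutative operators alike. The level-wise idea can be sketched on a plain collection; this is an illustration of the general technique, not flint's `TreeReduce`.

// Sketch: reduce level by level, combining up to `fanIn` neighbouring partial
// results per level until one value remains. Order is preserved, so
// non-commutative ops such as string concatenation match a flat reduceLeft.
object TreeReduceSketch {

  def treeReduce[A](xs: Seq[A], fanIn: Int)(op: (A, A) => A): A = {
    require(xs.nonEmpty && fanIn >= 2)
    var level = xs
    while (level.size > 1) {
      level = level.grouped(fanIn).map(_.reduceLeft(op)).toSeq
    }
    level.head
  }

  def main(args: Array[String]): Unit = {
    val xs = (1 to 1000).map(_.toString)
    assert(treeReduce(xs, 4)(_ + _) == xs.reduceLeft(_ + _))
  }
}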
Example 75
Source File: RegressionSummarizerSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd.function.summarize.summarizer

import breeze.linalg.{ DenseVector, DenseMatrix }
import org.scalatest.FlatSpec

class RegressionSummarizerSpec extends FlatSpec {

  "RegressionSummarizer" should "transform from RegressRow correctly" in {
    val x: Array[RegressionRow] = Array(
      RegressionRow(time = 0L, x = Array(1d, 2d), y = 3d, weight = 4d),
      RegressionRow(time = 0L, x = Array(4d, 5d), y = 6d, weight = 16d)
    )

    val (response1, predictor1, yw1) = RegressionSummarizer.transform(x, shouldIntercept = true, isWeighted = true)
    assert(response1.equals(DenseMatrix(Array(2d, 2d, 4d), Array(4d, 16d, 20d))))
    assert(predictor1.equals(DenseVector(Array(6d, 24d))))
    assert(yw1.deep == Array((3d, 4d), (6d, 16d)).deep)

    val (response2, predictor2, yw2) = RegressionSummarizer.transform(x, shouldIntercept = true, isWeighted = false)
    assert(response2.equals(DenseMatrix(Array(1d, 1d, 2d), Array(1d, 4d, 5d))))
    assert(predictor2.equals(DenseVector(Array(3d, 6d))))
    assert(yw2.deep == Array((3d, 1d), (6d, 1d)).deep)

    val (response3, predictor3, yw3) = RegressionSummarizer.transform(x, shouldIntercept = false, isWeighted = true)
    assert(response3.equals(DenseMatrix(Array(2d, 4d), Array(16d, 20d))))
    assert(predictor3.equals(DenseVector(Array(6d, 24d))))
    assert(yw3.deep == Array((3d, 4d), (6d, 16d)).deep)

    val (response4, predictor4, yw4) = RegressionSummarizer.transform(x, shouldIntercept = false, isWeighted = false)
    assert(response4.equals(DenseMatrix(Array(1d, 2d), Array(4d, 5d))))
    assert(predictor4.equals(DenseVector(Array(3d, 6d))))
    assert(yw4.deep == Array((3d, 1d), (6d, 1d)).deep)
  }
} 
Example 76
Source File: LagWindowSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd.function.summarize.summarizer.regression

import org.scalatest.FlatSpec

class LagWindowSpec extends FlatSpec {

  "LagWindow" should "give an `AbsoluteTimeLagWindow` as expect" in {
    val lagWindow = LagWindow.absolute(10L)
    assert(!lagWindow.shouldKeep(20L, 10L, 0))
    assert(!lagWindow.shouldKeep(20L, 10L, Int.MaxValue))
    assert(lagWindow.shouldKeep(15L, 10L, 0))
    assert(lagWindow.shouldKeep(15L, 10L, Int.MaxValue))
    assert(!lagWindow.shouldKeep(25L, 10L, 0))
    assert(!lagWindow.shouldKeep(25L, 10L, Int.MaxValue))
  }

  it should "give a `CountLagWindow` as expect" in {
    val countLagWindow = LagWindow.count(3, 10)
    assert(countLagWindow.shouldKeep(20L, 10L, 0))
    assert(countLagWindow.shouldKeep(20L, 10L, 1))
    assert(countLagWindow.shouldKeep(20L, 10L, 2))
    assert(countLagWindow.shouldKeep(20L, 10L, 3))
    assert(countLagWindow.shouldKeep(20L, 10L, 4))
    assert(!countLagWindow.shouldKeep(20L, 10L, 5))

    assert(countLagWindow.shouldKeep(15L, 10L, 0))
    assert(countLagWindow.shouldKeep(15L, 10L, 1))
    assert(countLagWindow.shouldKeep(15L, 10L, 2))
    assert(countLagWindow.shouldKeep(15L, 10L, 3))
    assert(countLagWindow.shouldKeep(15L, 10L, 4))
    assert(!countLagWindow.shouldKeep(15L, 10L, 5))

    assert(!countLagWindow.shouldKeep(25L, 10L, 0))
    assert(!countLagWindow.shouldKeep(25L, 10L, 1))
    assert(!countLagWindow.shouldKeep(25L, 10L, 2))
    assert(!countLagWindow.shouldKeep(25L, 10L, 3))
    assert(!countLagWindow.shouldKeep(25L, 10L, 4))
    assert(!countLagWindow.shouldKeep(25L, 10L, 5))
  }

  "LagWindowQueue" should "work with an `AbsoluteTimeLagWindow` as expect" in {
    val window = LagWindow.absolute(10L)
    val lagWindowQueue = new LagWindowQueue[Int](window)
    lagWindowQueue.enqueue(0, 0)
    lagWindowQueue.enqueue(5, 5)
    lagWindowQueue.enqueue(10, 10)
    assert(lagWindowQueue.length == 2)
    assert(lagWindowQueue.head.timestamp == 5L)
    assert(lagWindowQueue.last.timestamp == 10L)
    lagWindowQueue.enqueue(100, 100)
    assert(lagWindowQueue.length == 1)
    assert(lagWindowQueue.last.timestamp == 100L)
  }

  it should "work with a `CountLagWindow` as expect" in {
    val window = LagWindow.count(3, 10L)
    val lagWindowQueue = new LagWindowQueue[Int](window)
    lagWindowQueue.enqueue(0, 0)
    lagWindowQueue.enqueue(5, 5)
    lagWindowQueue.enqueue(7, 7)
    lagWindowQueue.enqueue(10, 10)
    lagWindowQueue.enqueue(10, 10)

    assert(lagWindowQueue.length == 4)
    assert(lagWindowQueue.head.timestamp == 0L)
    assert(lagWindowQueue.last.timestamp == 10L)

    lagWindowQueue.enqueue(11, 11)
    assert(lagWindowQueue.length == 4)
    assert(lagWindowQueue.head.timestamp == 5L)

    lagWindowQueue.enqueue(100, 100)
    assert(lagWindowQueue.length == 1)
    assert(lagWindowQueue.head.timestamp == 100)

    lagWindowQueue.clear()
    assert(lagWindowQueue.length == 0)
  }
} 
Example 77
Source File: OverlappedOrderedRDDSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import com.twosigma.flint.SharedSparkContext
import org.apache.spark.rdd.RDD
import org.scalatest.FlatSpec

class OverlappedOrderedRDDSpec extends FlatSpec with SharedSparkContext {

  val numSlices: Int = 3

  val sliceLength: Int = 4

  var rdd: RDD[(Int, Int)] = _

  var orderedRdd: OrderedRDD[Int, Int] = _

  var overlappedOrderedRdd: OverlappedOrderedRDD[Int, Int] = _

  private def window(t: Int): (Int, Int) = (t - 2, t)

  override def beforeAll() {
    super.beforeAll()
    val s = sliceLength
    rdd = sc.parallelize(0 until numSlices, numSlices).flatMap {
      i => (1 to s).map { j => i * s + j }
    }.map { x => (x, x) }
    orderedRdd = OrderedRDD.fromRDD(rdd, KeyPartitioningType.Sorted)
    overlappedOrderedRdd = OverlappedOrderedRDD(orderedRdd, window)
  }

  "The OverlappedOrderedRDD" should "be constructed from `OrderedRDD` correctly" in {
    assert(overlappedOrderedRdd.rangeSplits.deep == orderedRdd.rangeSplits.deep)
    val benchmark = Array(1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 8, 9, 10, 11, 12).map { x => (x, x) }
    assert(overlappedOrderedRdd.collect().deep == benchmark.deep)
  }

  it should "be able to remove overlapped rows to get an `OrderedRDD` correctly" in {
    assert(overlappedOrderedRdd.rangeSplits.deep == orderedRdd.rangeSplits.deep)
    assert(overlappedOrderedRdd.nonOverlapped().collect().deep == orderedRdd.collect().deep)
  }

  it should "`mapPartitionsWithIndexOverlapped` correctly" in {
    val mapped = overlappedOrderedRdd.mapPartitionsWithIndexOverlapped(
      (index, iterator) => iterator.map { case (k, v) => (k, v * 2) }
    )
    val benchmark = Array(1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 8, 9, 10, 11, 12).map { x => (x, 2 * x) }
    assert(mapped.collect().deep == benchmark.deep)
  }
} 
Example 78
Source File: RangeMergeJoinSpec.scala    From flint   with Apache License 2.0 5 votes vote down vote up
package com.twosigma.flint.rdd

import com.twosigma.flint.rdd.function.join.RangeMergeJoin
import org.apache.spark.Partition
import org.scalatest.FlatSpec

class RangeMergeJoinSpec extends FlatSpec {
  val thisSplits = IndexedSeq(
    RangeSplit(Split(0), CloseOpen(1, Some(2))),
    RangeSplit(Split(1), CloseOpen(2, Some(3))),
    RangeSplit(Split(2), CloseOpen(3, Some(4))),
    RangeSplit(Split(3), CloseOpen(4, Some(5))),
    RangeSplit(Split(4), CloseOpen(5, None))
  )

  val thatSplits = IndexedSeq(
    RangeSplit(Split(0), CloseOpen(1, Some(3))),
    RangeSplit(Split(1), CloseOpen(3, Some(7))),
    RangeSplit(Split(2), CloseOpen(7, None))
  )

  "The RangeMergeJoin" should "`mergeSplits` with no tolerance correctly" in {
    val benchmark = List(
      RangeMergeJoin(
        CloseOpen(1, Some(2)),
        List(RangeSplit(Split(0), CloseOpen(1, Some(2)))),
        List(RangeSplit(Split(0), CloseOpen(1, Some(3))))
      ),
      RangeMergeJoin(
        CloseOpen(2, Some(3)),
        List(RangeSplit(Split(1), CloseOpen(2, Some(3)))),
        List(RangeSplit(Split(0), CloseOpen(1, Some(3))))
      ),
      RangeMergeJoin(
        CloseOpen(3, Some(4)),
        List(RangeSplit(Split(2), CloseOpen(3, Some(4)))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(4, Some(5)),
        List(RangeSplit(Split(3), CloseOpen(4, Some(5)))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(5, Some(7)),
        List(RangeSplit(Split(4), CloseOpen(5, None))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(7, None),
        List(RangeSplit(Split(4), CloseOpen(5, None))),
        List(RangeSplit(Split(2), CloseOpen(7, None)))
      )
    )
    assertResult(benchmark) { RangeMergeJoin.mergeSplits(thisSplits, thatSplits) }
  }

  it should "`mergeSplits` with some tolerance correctly" in {
    val benchmark = List(
      RangeMergeJoin(
        CloseOpen(1, Some(2)),
        List(RangeSplit(Split(0), CloseOpen(1, Some(2)))),
        List(RangeSplit(Split(0), CloseOpen(1, Some(3))))
      ),
      RangeMergeJoin(
        CloseOpen(2, Some(3)),
        List(RangeSplit(Split(0), CloseOpen(1, Some(2))), RangeSplit(Split(1), CloseOpen(2, Some(3)))),
        List(RangeSplit(Split(0), CloseOpen(1, Some(3))))
      ),
      RangeMergeJoin(
        CloseOpen(3, Some(4)),
        List(RangeSplit(Split(1), CloseOpen(2, Some(3))), RangeSplit(Split(2), CloseOpen(3, Some(4)))),
        List(RangeSplit(Split(0), CloseOpen(1, Some(3))), RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(4, Some(5)),
        List(RangeSplit(Split(2), CloseOpen(3, Some(4))), RangeSplit(Split(3), CloseOpen(4, Some(5)))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(5, Some(7)),
        List(RangeSplit(Split(3), CloseOpen(4, Some(5))), RangeSplit(Split(4), CloseOpen(5, None))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))))
      ),
      RangeMergeJoin(
        CloseOpen(7, None),
        List(RangeSplit(Split(4), CloseOpen(5, None))),
        List(RangeSplit(Split(1), CloseOpen(3, Some(7))), RangeSplit(Split(2), CloseOpen(7, None)))
      )
    )

    assertResult(benchmark) { RangeMergeJoin.mergeSplits(thisSplits, thatSplits, { x: Int => x - 1 }) }
  }
} 
Example 79
Source File: PointRDDExtensionsSpec.scala    From reactiveinflux-spark   with Apache License 2.0 5 votes vote down vote up
package com.pygmalios.reactiveinflux.extensions

import com.holdenkarau.spark.testing.SharedSparkContext
import com.pygmalios.reactiveinflux.Point.Measurement
import com.pygmalios.reactiveinflux._
import com.pygmalios.reactiveinflux.extensions.PointRDDExtensionsSpec._
import com.pygmalios.reactiveinflux.spark._
import com.pygmalios.reactiveinflux.spark.extensions.PointRDDExtensions
import org.joda.time.{DateTime, DateTimeZone}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec}

import scala.concurrent.duration._

@RunWith(classOf[JUnitRunner])
class PointRDDExtensionsSpec extends FlatSpec with SharedSparkContext
  with BeforeAndAfter {

  before {
    withInflux(_.create())
  }

  after {
    withInflux(_.drop())
  }

  behavior of "saveToInflux"

  it should "write single point to Influx" in {
    val points = List(point1)
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 1)
    assert(PointRDDExtensions.totalPointCount == 1)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
      .result
      .singleSeries)

    assert(result.rows.size == 1)

    val row = result.rows.head
    assert(row.time == point1.time)
    assert(row.values.size == 5)
  }

  it should "write 1000 points to Influx" in {
    val points = (1 to 1000).map { i =>
      Point(
        time = point1.time.plusMinutes(i),
        measurement = point1.measurement,
        tags = point1.tags,
        fields = point1.fields
      )
    }
    val rdd = sc.parallelize(points)

    // Execute
    rdd.saveToInflux()

    // Assert
    assert(PointRDDExtensions.totalBatchCount == 8)
    assert(PointRDDExtensions.totalPointCount == 1000)
    val result = withInflux(
      _.query(Query(s"SELECT * FROM $measurement1"))
        .result
        .singleSeries)

    assert(result.rows.size == 1000)
  }
}

object PointRDDExtensionsSpec {
  implicit val params: ReactiveInfluxDbName = ReactiveInfluxDbName("test")
  implicit val awaitAtMost: Duration = 1.second

  val measurement1: Measurement = "measurement1"
  val point1 = Point(
    time        = new DateTime(1983, 1, 10, 7, 43, 10, 3, DateTimeZone.UTC),
    measurement = measurement1,
    tags        = Map("tagKey1" -> "tagValue1", "tagKey2" -> "tagValue2"),
    fields      = Map("fieldKey1" -> StringFieldValue("fieldValue1"), "fieldKey2" -> BigDecimalFieldValue(10.7)))
} 
Example 80
Source File: IntegrationSpec.scala    From Principles-of-Reactive-Programming   with GNU General Public License v3.0 5 votes vote down vote up
package kvstore

import akka.actor.{ Actor, Props, ActorRef, ActorSystem }
import akka.testkit.{ TestProbe, ImplicitSender, TestKit }
import org.scalatest.{ BeforeAndAfterAll, FlatSpec, Matchers }
import scala.concurrent.duration._
import org.scalatest.FunSuiteLike
import org.scalactic.ConversionCheckedTripleEquals

class IntegrationSpec(_system: ActorSystem) extends TestKit(_system)
    with FunSuiteLike
    with Matchers
    with BeforeAndAfterAll
    with ConversionCheckedTripleEquals
    with ImplicitSender
    with Tools {

  import Replica._
  import Replicator._
  import Arbiter._

  def this() = this(ActorSystem("ReplicatorSpec"))

  override def afterAll: Unit = system.shutdown()

}
Example 81
Source File: NestedTest.scala    From TwoTails   with Apache License 2.0 5 votes vote down vote up
package twotails

import org.scalatest.{ FlatSpec, Matchers }

class Nested{
  @mutualrec final def nest(z: Int): Int ={
    @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 0

    if(0 < z) one(z) else aha(z)
  }

  @mutualrec final def aha(z: Int): Int = nest(z+1)

  def thing(y: Int) ={
    class Yo{
      @mutualrec final def one(x: Int): Int = if(0 < x) two(x-1) else 0
      @mutualrec final def two(x: Int): Int = if(0 < x) one(x-2) else 0
    }

    (new Yo).two(y)
  }

  def other(y: Int) ={
    object Yo{
      @mutualrec final def one(x: Int): Int = if(0 < x) two(x-1) else 0
      @mutualrec final def two(x: Int): Int = if(0 < x) one(x-2) else 0
    }

    Yo.one(y)
  }

  class Foo{
    @mutualrec final def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec final def two(x: Int): Int = if(0 < x) one(x-2) else 0
  }

  def something(y: Int): Int ={
    @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 0

    one(y)
  }

  { //just a block which will be discarded
    @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 0
  }

  def dis(y: Int) ={
    { //another block which will be discarded but with a name clash
      @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 1
      @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 1
    }

    @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 0

    one(y)
  }

  val that ={ xy: Int =>
    @mutualrec def one(x: Int): Int = if(0 < x) two(x-1) else 0
    @mutualrec def two(x: Int): Int = if(0 < x) one(x-2) else 0

    one(xy)
  }
}

class NestedTest extends FlatSpec with Matchers{
  val fourK = 400000

  "A nested class within a def which has annotated methods" should "not throw a StackOverflow" in{
    val nest = new Nested

    nest.thing(fourK) should equal (0)
  }

  "A nested object within a def which has annotated methods" should "not throw a StackOverflow" in{
    val nest = new Nested

    nest.other(fourK) should equal (0)
  }

  "A nested class within a class which has annotated methods" should "not throw a StackOverflow" in{
    val nest = new Nested
    val foo = new nest.Foo

    foo.one(fourK) should equal (0)
  }

  "A nested set of annotated methods with a name clashing nested block" should "not throw a StackOverflow" in{
    val nest = new Nested

    nest.dis(fourK) should equal(0)
  }
} 
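
The @mutualrec annotation exercised above makes these mutually recursive pairs stack-safe, which is what the "should not throw a StackOverflow" assertions check. For comparison only, the standard library's scala.util.control.TailCalls achieves the same effect by trampolining the calls on the heap; this sketch is an alternative technique, not what TwoTails generates.

import scala.util.control.TailCalls.{TailRec, done, tailcall}

// Trampolined mutual recursion: each step returns a TailRec, so deep
// call chains do not grow the JVM stack.
object Trampolined {
  def one(x: Int): TailRec[Int] =
    if (0 < x) tailcall(two(x - 1)) else done(0)

  def two(x: Int): TailRec[Int] =
    if (0 < x) tailcall(one(x - 2)) else done(0)
}

// Trampolined.one(400000).result == 0, even for inputs that would
// overflow the stack with plain mutual recursion.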
Example 82
Source File: MultiRecTest.scala    From TwoTails   with Apache License 2.0 5 votes vote down vote up
package twotails

import org.scalatest.{ FlatSpec, Matchers }
import java.lang.StackOverflowError

class Cup{
  @mutualrec final def one(x: Int, y: Int = 0): Int = if(0 < x) two(x-1, y) else y
  @mutualrec final def two(u: Int, v: Int = 0): Int = if(0 < u) one(u-1, v) else v

  @mutualrec final def three(x: Int, y: Int = 1): Int = if(0 < x) four(x-1, y) else y
  @mutualrec final def four(u: Int, v: Int = 1): Int = if(0 < u) three(u-1, v) else v
}

class Bowl{
  @mutualrec final def one[A](x: Int, y: A): A = if(0 < x) two(x-1, y) else y
  @mutualrec final def two[A](u: Int, v: A): A = if(0 < u) one(u-1, v) else v

  @mutualrec final def three(x: Int, y: Int = 1): Int = if(0 < x) four(x-1, y) else y
  @mutualrec final def four(u: Int, v: Int = 1): Int = if(0 < u) three(u-1, v) else v
}

class Plate{
  @mutualrec final def zero(x: Int): Int = if(x < 0) x else one(x-1)
  @mutualrec final def one(x: Int): Int = if(x < 0) two(x, x) else zero(x-1)

  @mutualrec final def two(x: Int, y: Int): Int = if(x < 0) y else three(x-1, y+1)
  @mutualrec final def three(x: Int, y: Int): Int = if(x < 0) one(x) else two(x-1, y+1)
}

class Saucer{
  @mutualrec final def zero(x: Int): Int = if(x < 0) x else one(x-1)
  @mutualrec final def one(x: Int): Int = if(x < 0) zero(x-1) else two(x, x)

  @mutualrec final def two(x: Int, y: Int): Int = if(x < 0) y else three(x-1, y+1)
  @mutualrec final def three(x: Int, y: Int): Int = if(x < 0) two(x-1, y+1) else one(x)
}

class MultiRecTest extends FlatSpec with Matchers{
  val fourK = 400000

  "a class with two sets of mutually recursive functions" should "just work" in{
    val cup = new Cup
    cup.one(fourK) should equal(0)
    cup.three(fourK) should equal(1)
  }

  "a class with two sets of mutually recursive functions but different types" should "just work" in{
    val bowl = new Bowl
    bowl.one(fourK, "a") should equal("a")
    bowl.three(fourK) should equal(1)
  }

  "a class with recursive functions of different types" should "not be transformed together" in{
    val saucer = new Saucer

    intercept[StackOverflowError]{
      saucer.three(fourK, fourK)
    }
  }
} 
Example 83
Source File: DefaultArgumentTest.scala    From TwoTails   with Apache License 2.0 5 votes vote down vote up
package twotails

import org.scalatest.{ FlatSpec, Matchers }

class SpeedBump{
  @mutualrec final def one(x: Int, y: Int = 1): Int = if(0 < x) two(x-1) else y
  @mutualrec final def two(x: Int, y: Int = 2): Int = if(0 < x) one(x-1) else y
}

class Pothole{
  @mutualrec final def one(x: Int)(y: Int, z: Int = 1): Int = if(0 < x) two(x-1)(y) else z
  @mutualrec final def two(x: Int)(y: Int, z: Int = 2): Int = if(0 < x) one(x-1)(y) else z
}

class Ditch{
  @mutualrec final def one(x: Int, y: Int, z: Int = 1): Int = if(0 < x) two(x = x-1, y = y) else z
  @mutualrec final def two(x: Int, y: Int, z: Int = 2): Int = if(0 < x) one(y = y, x = x-1) else z
}

class GuardRail{
  @mutualrec final def one(v: Int, x: Int, y: Int = 1, z: Int = 1): Int = if(0 < v) two(v-1, x, z=z) else y
  @mutualrec final def two(v: Int, x: Int, y: Int = 2, z: Int = 2): Int = if(0 < v) one(v-1, x, z=z) else y
}

class DefaultArgumentTest extends FlatSpec with Matchers{
  val fourK = 400000

  "mutually recursive functions with default args" should "use the default args" in{
    val sb = new SpeedBump
    sb.one(fourK) should not equal sb.two(fourK)
  }

  "mutually recursive functions with default args and multi-param lists" should "use the default args" in{
    val pt = new Pothole
    pt.one(fourK)(fourK) should not equal pt.two(fourK)(fourK)
  }

  "mutually recursive function with default args called by name" should "use the default args" in{
    val dt = new Ditch
    dt.one(fourK, fourK) should not equal dt.two(fourK, fourK)
  }

  "mutually recursive function with default args with mixed calls" should "use the default args" in{
    val gr = new GuardRail
    gr.one(fourK, fourK) should not equal gr.two(fourK, fourK)
  }
} 
Example 84
Source File: ObjectTest.scala    From TwoTails   with Apache License 2.0 5 votes vote down vote up
package twotails

import org.scalatest.{ FlatSpec, Matchers }
import java.lang.StackOverflowError

object Blappy{
  @mutualrec final def one(x: Int): Int = if(0 < x) two(x-1) else 0
  @mutualrec final def two(x: Int): Int = if(0 < x) one(x-1) else 0
}

class ObjectTest extends FlatSpec with Matchers{
  val fourK = 400000

  "Two mutually recursive, single argument, annotated methods on an object" should "not throw a StackOverflow" in{
    Blappy.one(fourK) should equal (0)
  }
} 
Example 85
Source File: ArgumentListTest.scala    From TwoTails   with Apache License 2.0 5 votes vote down vote up
package twotails

import org.scalatest.{ FlatSpec, Matchers }

class Bippy{
  @mutualrec final def one(x: Int, y: Int = 1): Int = if(0 < x) two(y, x) else 0
  @mutualrec final def two(x: Int, y: Int = 1): Int = if(0 < x) one(x-1, y-1) else 0
}

class Baz{
  @mutualrec final def one(x: Int)(y: Int): Int = if(0 < x) two(y)(x) else 0
  @mutualrec final def two(x: Int)(y: Int): Int = if(0 < x) one(x-1)(y-1) else 0
}

class Bazooka{
  @mutualrec final def one(x: Int)(y: Int)(z: Int): Int = if(0 < x) two(y)(x)(z) else z
  @mutualrec final def two(x: Int)(y: Int)(z: Int): Int = if(0 < x) one(x-1)(y-1)(z+1) else z
}

class ArgumentListTest extends FlatSpec with Matchers{
  val fourK = 400000

  "Two mutually recursive, double-argument list, annotated methods" should "not throw a StackOverflow" in{
    val c = new Baz

    c.one(fourK)(fourK) should equal (0)
  }

  "Two mutually recursive, multi-argument list, annotated methods" should "not throw a StackOverflow" in{
    val baz = new Bazooka

    baz.one(fourK)(fourK)(0) should equal (fourK)
  }
} 
Example 86
Source File: RetimeSpec.scala    From barstools   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
// See LICENSE for license details.

package barstools.tapeout.transforms.retime.test

import chisel3._
import firrtl._
import org.scalatest.{FlatSpec, Matchers}
import chisel3.experimental._
import chisel3.util.HasBlackBoxInline
import chisel3.iotesters._
import barstools.tapeout.transforms.retime._

class RetimeSpec extends FlatSpec with Matchers {
  def normalized(s: String): String = {
    require(!s.contains("\n"))
    s.replaceAll("\\s+", " ").trim
  }
  def uniqueDirName[T](gen: => T, name: String): String = {
    val genClassName = gen.getClass.getName
    name + genClassName.hashCode.abs
  }

  behavior of "retime library"

  it should "pass simple retime module annotation" in {
    val gen = () => new RetimeModule()
    val dir = uniqueDirName(gen, "RetimeModule")
    chisel3.Driver.execute(Array("-td", s"test_run_dir/$dir", "-foaf", s"test_run_dir/$dir/final"), gen) shouldBe a [ChiselExecutionSuccess]

    val lines = io.Source.fromFile(s"test_run_dir/$dir/test_run_dir/$dir/final.anno.json").getLines().map(normalized).mkString("\n")
    lines should include("barstools.tapeout.transforms.retime.RetimeTransform")
  }

  // TODO(azidar): need to fix/add instance annotations
  ignore should "pass simple retime instance annotation" in {
    val gen = () => new RetimeInstance()
    val dir = uniqueDirName(gen, "RetimeInstance")
    chisel3.Driver.execute(Array("-td", s"test_run_dir/$dir", "-foaf", s"test_run_dir/$dir/final.anno"), gen) shouldBe a [ChiselExecutionSuccess]

    val lines = io.Source.fromFile(s"test_run_dir/$dir/final.anno").getLines().map(normalized).toSeq
    lines should contain ("Annotation(ComponentName(instance, ModuleName(RetimeInstance,CircuitName(RetimeInstance))),class barstools.tapeout.transforms.retime.RetimeTransform,retime)")
  }
}

class RetimeModule extends Module with RetimeLib {
  val io = IO(new Bundle {
    val in = Input(UInt(15.W))
    val out = Output(UInt(15.W))
  })
  io.out := io.in
  retime(this)
}

class MyModule extends Module with RetimeLib {
  val io = IO(new Bundle {
    val in = Input(UInt(15.W))
    val out = Output(UInt(15.W))
  })
  io.out := io.in
}

class RetimeInstance extends Module with RetimeLib {
  val io = IO(new Bundle {
    val in = Input(UInt(15.W))
    val out = Output(UInt(15.W))
  })
  val instance = Module(new MyModule)
  retime(instance)
  instance.io.in := io.in
  io.out := instance.io.out
} 
Example 87
Source File: MainServiceSpec.scala    From akka-api-gateway-example   with MIT License 5 votes vote down vote up
package jp.co.dzl.example.akka.api

import akka.http.scaladsl.Http
import com.typesafe.config.ConfigFactory
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{ BeforeAndAfterAll, Matchers, FlatSpec }

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class MainServiceSpec extends FlatSpec with Matchers with BeforeAndAfterAll with ScalaFutures with MainService {
  override protected def afterAll: Unit = {
    Await.result(system.terminate(), Duration.Inf)
  }

  it should "inject configuration of http" in {
    val config = ConfigFactory.load()

    host shouldEqual config.getString("http.listen.host")
    port shouldEqual config.getInt("http.listen.port")
  }

  it should "bind and handle" in {
    val http = Http().bindAndHandle(handler.routes, host, port)
    http.futureValue.localAddress.getPort shouldEqual port
    http.futureValue.unbind()
  }
} 
Example 88
Source File: HttpClientSpec.scala    From akka-api-gateway-example   with MIT License 5 votes vote down vote up
package jp.co.dzl.example.akka.api.service

import akka.actor.ActorSystem
import akka.stream.scaladsl.Flow
import org.scalatest.{ BeforeAndAfterAll, Matchers, FlatSpec }

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class HttpClientSpec extends FlatSpec with Matchers with BeforeAndAfterAll {
  implicit val system = ActorSystem("http-client-spec")
  implicit val executor = system.dispatcher

  override protected def afterAll: Unit = {
    Await.result(system.terminate(), Duration.Inf)
  }

  "#conectionHttps" should "return outgoing connection flow" in {
    val httpClient = new HttpClientImpl(system)
    val connection = httpClient.connectionHttps("127.0.0.1", 8000, 5)

    connection shouldBe a[Flow[_, _, _]]
  }
} 
Example 89
Source File: GitHubSpec.scala    From akka-api-gateway-example   with MIT License 5 votes vote down vote up
package jp.co.dzl.example.akka.api.service

import akka.actor.ActorSystem
import akka.http.scaladsl.model.headers.RawHeader
import akka.http.scaladsl.model.{ HttpMethods, HttpRequest, HttpResponse }
import akka.stream.ActorMaterializer
import akka.stream.scaladsl.{ Flow, Source }
import akka.stream.testkit.scaladsl.TestSink
import org.scalamock.scalatest.MockFactory
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{ BeforeAndAfterAll, FlatSpec, Matchers }

import scala.concurrent.Await
import scala.concurrent.duration.Duration

class GitHubSpec extends FlatSpec with Matchers with ScalaFutures with BeforeAndAfterAll with MockFactory {
  implicit val system = ActorSystem("github-spec")
  implicit val executor = system.dispatcher
  implicit val materializer = ActorMaterializer()

  override protected def afterAll: Unit = {
    Await.result(system.terminate(), Duration.Inf)
  }

  "#from" should "merge original headers to github request" in {
    val github = new GitHubImpl("127.0.0.1", 8000, 5, mock[HttpClient])
    val request = HttpRequest(HttpMethods.GET, "/")
      .addHeader(RawHeader("host", "dummy"))
      .addHeader(RawHeader("timeout-access", "dummy"))

    val result = Source.single(HttpRequest(HttpMethods.GET, "/v1/github/users/xxxxxx"))
      .via(github.from(request))
      .runWith(TestSink.probe[HttpRequest])
      .request(1)
      .expectNext()

    result.headers.filter(_.lowercaseName() == "host") shouldBe empty
    result.headers.filter(_.lowercaseName() == "timeout-access") shouldBe empty
    result.headers.filter(_.lowercaseName() == "x-forwarded-host") shouldNot be(empty)
  }

  "#send" should "connect using http client" in {
    val httpResponse = HttpResponse()
    val httpClient = mock[HttpClient]
    (httpClient.connectionHttps _).expects(*, *, *).returning(Flow[HttpRequest].map(_ => httpResponse))

    val github = new GitHubImpl("127.0.0.1", 8000, 5, httpClient)
    val result = Source.single(HttpRequest(HttpMethods.GET, "/"))
      .via(github.send)
      .runWith(TestSink.probe[HttpResponse])
      .request(1)
      .expectNext()

    result shouldBe httpResponse
  }
} 
Example 90
Source File: SqsClientSpec.scala    From akka-stream-sqs   with Apache License 2.0 5 votes vote down vote up
package me.snov.akka.sqs.client

import com.amazonaws.ClientConfiguration
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.services.sqs.{AmazonSQS, AmazonSQSAsync}
import com.amazonaws.services.sqs.model.{ReceiveMessageRequest, ReceiveMessageResult}
import org.scalatest.mockito.MockitoSugar.mock
import org.scalatest.{FlatSpec, Matchers}
import org.mockito.Mockito._
import org.mockito.ArgumentMatchers._

class SqsClientSpec extends FlatSpec with Matchers {

  it should "call AWS client" in {

    val awsClient = mock[AmazonSQSAsync]

    val sqsClientSettings = SqsSettings(
      awsCredentialsProvider = Some(mock[AWSCredentialsProvider]),
      awsClientConfiguration = Some(mock[ClientConfiguration]),
      awsClient = Some(awsClient),
      queueUrl = ""
    )
    val sqsClient = SqsClient(sqsClientSettings)
    val receiveMessageResult = mock[ReceiveMessageResult]

    when(awsClient.receiveMessage(any[ReceiveMessageRequest])).thenReturn(receiveMessageResult)

    sqsClient.receiveMessage()

    verify(receiveMessageResult).getMessages
  }

  it should "pass parameters with ReceiveMessageRequest" in {

    val awsClient = mock[AmazonSQSAsync]

    val sqsClientSettings = SqsSettings(
      queueUrl = "",
      maxNumberOfMessages = 9,
      waitTimeSeconds = 7,
      awsCredentialsProvider = Some(mock[AWSCredentialsProvider]),
      awsClientConfiguration = Some(mock[ClientConfiguration]),
      awsClient = Some(awsClient),
      visibilityTimeout = Some(75)
    )
    val sqsClient = SqsClient(sqsClientSettings)
    val receiveMessageResult = mock[ReceiveMessageResult]

    val receiveMessageRequest = new ReceiveMessageRequest()
        .withQueueUrl("")
        .withMaxNumberOfMessages(9)
        .withVisibilityTimeout(75)
        .withWaitTimeSeconds(7)

    when(awsClient.receiveMessage(receiveMessageRequest)).thenReturn(receiveMessageResult)

    sqsClient.receiveMessage()

    verify(receiveMessageResult).getMessages
  }
} 
Example 91
Source File: SqsClientSettingsSpec.scala    From akka-stream-sqs   with Apache License 2.0 5 votes vote down vote up
package me.snov.akka.sqs.client

import com.amazonaws.ClientConfiguration
import com.amazonaws.auth.AWSCredentialsProvider
import com.amazonaws.client.builder.AwsClientBuilder.EndpointConfiguration
import com.typesafe.config.ConfigFactory
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.mockito.MockitoSugar._

class SqsClientSettingsSpec extends FlatSpec with Matchers {

  it should "parse configuration" in {
    val conf = ConfigFactory.parseString(
      """
        reactive-sqs {
          endpoint = "http://localhost:9324/"
          region = "eu-west-1"
          queue-url = "http://localhost:9324/queue/queue1"
          max-number-of-messages = 10
          visibility-timeout = 60
          wait-time-seconds = 5
        }
      """)
      .getConfig("reactive-sqs")

    val settings = SqsSettings(
      conf,
      Some(mock[AWSCredentialsProvider]),
      Some(mock[ClientConfiguration])
    )

    settings.endpoint.get.getServiceEndpoint shouldBe "http://localhost:9324/"
    settings.endpoint.get.getSigningRegion shouldBe "eu-west-1"
    settings.queueUrl shouldBe "http://localhost:9324/queue/queue1"
    settings.maxNumberOfMessages shouldBe 10
    settings.visibilityTimeout shouldBe Some(60)
    settings.waitTimeSeconds shouldBe 5
  }

  it should "support optional parameters" in {
    val conf = ConfigFactory.parseString(
      """
        reactive-sqs {
          queue-url = "http://localhost:9324/queue/queue1"
          wait-time-seconds = 5
        }
      """)
      .getConfig("reactive-sqs")

    val settings = SqsSettings(
      conf,
      Some(mock[AWSCredentialsProvider]),
      Some(mock[ClientConfiguration])
    )

    settings.endpoint shouldBe None
    settings.queueUrl shouldBe "http://localhost:9324/queue/queue1"
    settings.maxNumberOfMessages shouldBe 10
    settings.visibilityTimeout shouldBe None
    settings.waitTimeSeconds shouldBe 5
  }
} 
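
The settings spec above reads required and optional keys from a Typesafe Config block. A minimal, standalone sketch of that pattern follows; `QueueSettings` is an illustrative stand-in rather than the library's `SqsSettings`, and only the key names from the test are reused.

import com.typesafe.config.{Config, ConfigFactory}

// Illustrative settings type: required keys via getString/getInt,
// optional keys via hasPath + Option.
final case class QueueSettings(queueUrl: String, waitTimeSeconds: Int, visibilityTimeout: Option[Int])

object QueueSettings {
  def fromConfig(conf: Config): QueueSettings =
    QueueSettings(
      queueUrl          = conf.getString("queue-url"),
      waitTimeSeconds   = conf.getInt("wait-time-seconds"),
      visibilityTimeout =
        if (conf.hasPath("visibility-timeout")) Some(conf.getInt("visibility-timeout")) else None
    )
}

// Usage:
// val conf = ConfigFactory.parseString(
//   """reactive-sqs { queue-url = "http://localhost:9324/queue/queue1", wait-time-seconds = 5 }"""
// ).getConfig("reactive-sqs")
// QueueSettings.fromConfig(conf)   // visibilityTimeout == None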
Example 92
Source File: SqsAckSinkShapeSpec.scala    From akka-stream-sqs   with Apache License 2.0 5 votes vote down vote up
package me.snov.akka.sqs.shape

import akka.Done
import akka.stream.scaladsl.{Keep, Sink}
import akka.stream.testkit.scaladsl.TestSource
import com.amazonaws.handlers.AsyncHandler
import com.amazonaws.services.sqs.model._
import me.snov.akka.sqs._
import me.snov.akka.sqs.client.SqsClient
import org.mockito.Mockito._
import org.mockito.ArgumentMatchers._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.mockito.MockitoSugar.mock
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.Await
import scala.concurrent.duration._

class SqsAckSinkShapeSpec extends FlatSpec with Matchers with DefaultTestContext {
  it should "delete messages on Ack" in {

    val sqsClient = mock[SqsClient]
    when(sqsClient.deleteAsync(any(), any())).thenAnswer(
      new Answer[Object] {
        override def answer(invocation: InvocationOnMock): Object = {
          val receiptHandle = invocation.getArgument[String](0)
          val callback = invocation.getArgument[AsyncHandler[DeleteMessageRequest, DeleteMessageResult]](1)
          callback.onSuccess(
            new DeleteMessageRequest().withReceiptHandle(receiptHandle),
            new DeleteMessageResult
          )
          None
        }
      }
    )

    val (probe, future) = TestSource.probe[MessageActionPair]
      .toMat(Sink.fromGraph(SqsAckSinkShape(sqsClient)))(Keep.both)
      .run()

    probe
      .sendNext((new Message().withReceiptHandle("123"), Ack()))
      .sendComplete()

    Await.result(future, 1.second) shouldBe Done
    verify(sqsClient, times(1)).deleteAsync(any(), any())
  }

  it should "requeue messages on RequeueWithDelay" in {

    val sqsClient = mock[SqsClient]
    when(sqsClient.sendWithDelayAsync(any[String], any[Int], any())).thenAnswer(
      new Answer[Object] {
        override def answer(invocation: InvocationOnMock): Object = {
          val body = invocation.getArgument[String](0)
          val delay = invocation.getArgument[Int](1)
          val callback = invocation.getArgument[AsyncHandler[SendMessageRequest, SendMessageResult]](2)
          callback.onSuccess(
            new SendMessageRequest().withMessageBody(body).withDelaySeconds(delay),
            new SendMessageResult().withMessageId("12345")
          )
          None
        }
      }
    )

    val (probe, future) = TestSource.probe[MessageActionPair]
      .toMat(Sink.fromGraph(SqsAckSinkShape(sqsClient)))(Keep.both)
      .run()

    probe
      .sendNext((new Message().withBody("foo"), RequeueWithDelay(9)))
      .sendComplete()

    Await.result(future, 100.second) shouldBe Done
    verify(sqsClient, times(1)).sendWithDelayAsync(any(), any(), any())
  }
} 
Example 93
Source File: SearchInputWithRulesSpec.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.sql.Connection

import models.rules.{SynonymRule, SynonymRuleId}
import org.scalatest.{FlatSpec, Matchers}
import utils.WithInMemoryDB

class SearchInputWithRulesSpec extends FlatSpec with Matchers with WithInMemoryDB with TestData {

  private val tag = InputTag.create(None, Some("tenant"), "MO", exported = true)

  "SearchInputWithRules" should "load lists with hundreds of entries successfully" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)
      SolrIndex.insert(indexEn)
      InputTag.insert(tag)

      insertInputs(300, indexDe.id, "term_de")
      insertInputs(200, indexEn.id, "term_en")

      val inputsDe = SearchInputWithRules.loadWithUndirectedSynonymsAndTagsForSolrIndexId(indexDe.id)
      inputsDe.size shouldBe 300
      for (input <- inputsDe) {
        input.term should startWith("term_de_")
        input.tags.size shouldBe 1
        input.tags.head.displayValue shouldBe "tenant:MO"
        input.synonymRules.size shouldBe 1 // Only undirected synonyms should be loaded
        input.synonymRules.head.term should startWith("term_de_synonym_")
      }

      SearchInputWithRules.loadWithUndirectedSynonymsAndTagsForSolrIndexId(indexEn.id).size shouldBe 200
    }
  }

  private def insertInputs(count: Int, indexId: SolrIndexId, termPrefix: String)(implicit conn: Connection): Unit = {
    for (i <- 0 until count) {
      val input = SearchInput.insert(indexId, s"${termPrefix}_$i")
      SynonymRule.updateForSearchInput(input.id, Seq(
        SynonymRule(SynonymRuleId(), SynonymRule.TYPE_UNDIRECTED, s"${termPrefix}_synonym_$i", isActive = true),
        SynonymRule(SynonymRuleId(), SynonymRule.TYPE_DIRECTED, s"${termPrefix}_directedsyn_$i", isActive = true),
      ))
      TagInputAssociation.updateTagsForSearchInput(input.id, Seq(tag.id))
    }
  }

  "SearchInputWithRules" should "be (de)activatable" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)

      val input = SearchInput.insert(indexDe.id, "my input")
      input.isActive shouldBe true

      SearchInput.update(input.id, input.term, false, input.comment)
      SearchInput.loadById(input.id).get.isActive shouldBe false

      SearchInput.update(input.id, input.term, true, input.comment)
      SearchInput.loadById(input.id).get.isActive shouldBe true
    }
  }

  "SearchInputWithRules" should "have a modifiable comment" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)

      val input = SearchInput.insert(indexDe.id, "my input")
      input.comment shouldBe ""

      SearchInput.update(input.id, input.term, input.isActive, "My #magic comment.")
      SearchInput.loadById(input.id).get.comment shouldBe "My #magic comment."

      SearchInput.update(input.id, input.term, input.isActive, "My #magic comment - updated.")
      SearchInput.loadById(input.id).get.comment shouldBe "My #magic comment - updated."
    }
  }

} 
Example 94
Source File: RulesTxtDeploymentServiceSpec.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.zip.ZipInputStream

import org.apache.commons.io.IOUtils
import org.scalatest.{FlatSpec, Matchers}

class RulesTxtDeploymentServiceSpec extends FlatSpec with Matchers with ApplicationTestBase {

  private lazy val service = injector.instanceOf[RulesTxtDeploymentService]
  private var inputIds: Seq[SearchInputId] = Seq.empty

  override protected def beforeAll(): Unit = {
    super.beforeAll()

    createTestCores()
    inputIds = createTestRule()
  }

  private def rulesFileContent(ruleIds: Seq[SearchInputId]): String = s"""aerosmith =>
                           |	SYNONYM: mercury
                           |	DOWN(10): battery
                           |	UP(10): notebook
                           |	FILTER: zz top
                           |	@{
                           |	  "_log" : "${ruleIds.head}"
                           |	}@
                           |
                           |mercury =>
                           |	SYNONYM: aerosmith
                           |	DOWN(10): battery
                           |	UP(10): notebook
                           |	FILTER: zz top
                           |	@{
                           |	  "_log" : "${ruleIds.head}"
                           |	}@
                           |
                           |shipping =>
                           |	DECORATE: REDIRECT http://xyz.com/shipping
                           |	@{
                           |	  "_log" : "${ruleIds.last}"
                           |	}@""".stripMargin

  "RulesTxtDeploymentService" should "generate rules files with correct file names" in {
    val rulesTxt = service.generateRulesTxtContentWithFilenames(core1Id, "LIVE", logDebug = false)
    rulesTxt.solrIndexId shouldBe core1Id
    rulesTxt.decompoundRules shouldBe empty
    rulesTxt.regularRules.content.trim shouldBe rulesFileContent(inputIds)

    rulesTxt.regularRules.sourceFileName shouldBe "/tmp/search-management-ui_rules-txt.tmp"
    rulesTxt.regularRules.destinationFileName shouldBe "/usr/bin/solr/liveCore/conf/rules.txt"
  }

  it should "validate the rules files correctly" in {
    val rulesTxt = service.generateRulesTxtContentWithFilenames(core1Id, "LIVE", logDebug = false)
    service.validateCompleteRulesTxts(rulesTxt, logDebug = false) shouldBe empty

    val badRulesTxt = rulesTxt.copy(regularRules = rulesTxt.regularRules.copy(content = "a very bad rules file"))
    service.validateCompleteRulesTxts(badRulesTxt, logDebug = false) shouldBe List("Line 1: Missing input for instruction")
  }

  it should "provide a zip file with all rules files" in {
    val out = new ByteArrayOutputStream()
    service.writeAllRulesTxtFilesAsZipFileToStream(out)

    val bytes = out.toByteArray
    val zipStream = new ZipInputStream(new ByteArrayInputStream(bytes))
    val firstEntry = zipStream.getNextEntry
    firstEntry.getName shouldBe "rules_core1.txt"
    IOUtils.toString(zipStream, "UTF-8").trim shouldBe rulesFileContent(inputIds)
    val secondEntry = zipStream.getNextEntry
    secondEntry.getName shouldBe "rules_core2.txt"
    IOUtils.toString(zipStream, "UTF-8").trim shouldBe ""
  }

} 
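
The last test unpacks the generated archive with java.util.zip. A self-contained sketch of the same in-memory round trip, independent of the service under test:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.nio.charset.StandardCharsets
import java.util.zip.{ZipEntry, ZipInputStream, ZipOutputStream}

// Write named text entries into an in-memory zip, then read them back.
object ZipRoundTrip {
  def write(entries: Seq[(String, String)]): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val zip = new ZipOutputStream(bos)
    entries.foreach { case (name, content) =>
      zip.putNextEntry(new ZipEntry(name))
      zip.write(content.getBytes(StandardCharsets.UTF_8))
      zip.closeEntry()
    }
    zip.close()
    bos.toByteArray
  }

  def readFirstEntryName(bytes: Array[Byte]): String = {
    val zip = new ZipInputStream(new ByteArrayInputStream(bytes))
    zip.getNextEntry.getName
  }
}

// ZipRoundTrip.readFirstEntryName(ZipRoundTrip.write(Seq("rules_core1.txt" -> "..."))) == "rules_core1.txt"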
Example 95
Source File: DBCompatibilitySpec.scala    From smui   with Apache License 2.0 5 votes vote down vote up
package models

import java.time.LocalDateTime

import models.rules._
import org.scalatest.{FlatSpec, Matchers}
import play.api.db.Database

abstract class DBCompatibilitySpec extends FlatSpec with Matchers with TestData {

  protected def db: Database

  // Set millis/nanos of second to 0 since MySQL does not save them
  // and so comparisons would fail if they were set
  private val now = LocalDateTime.now().withNano(0)

  "Most important DB queries" should "work using this database" in {
    db.withConnection { implicit conn =>
      SolrIndex.insert(indexDe)
      SolrIndex.loadNameById(indexDe.id) shouldBe indexDe.name
      SolrIndex.listAll shouldBe Seq(indexDe)

      val tag = InputTag(InputTagId(), Some(indexDe.id), Some("testProperty"), "testValue",
        exported = true, predefined = false, now)
      InputTag.insert(tag)
      InputTag.loadAll() shouldBe Seq(tag)

      val input = SearchInput.insert(indexDe.id, "test")
      val inputWithRules = SearchInputWithRules(input.id, input.term,
        List(SynonymRule(SynonymRuleId(), SynonymRule.TYPE_UNDIRECTED, "testSynonym", isActive = true)),
        List(UpDownRule(UpDownRuleId(), UpDownRule.TYPE_UP, 5, "upDownTerm", isActive = true)),
        List(FilterRule(FilterRuleId(), "filterTerm", isActive = true)),
        List(DeleteRule(DeleteRuleId(), "deleteTerm", isActive = true)),
        List(RedirectRule(RedirectRuleId(), "/testTarget", isActive = true)),
        List(tag),
        true,
        "Some search input comment."
      )
      SearchInputWithRules.update(inputWithRules)
      SearchInputWithRules.loadById(input.id) shouldBe Some(inputWithRules)

      SearchInputWithRules.loadWithUndirectedSynonymsAndTagsForSolrIndexId(indexDe.id) shouldBe Seq(
        inputWithRules.copy(upDownRules = Nil, filterRules = Nil, deleteRules = Nil, redirectRules = Nil)
      )

      SearchInputWithRules.delete(input.id)
      SearchInputWithRules.loadById(input.id) shouldBe None

      val field1 = SuggestedSolrField.insert(indexDe.id, "title")
      val field2 = SuggestedSolrField.insert(indexDe.id, "description")
      SuggestedSolrField.listAll(indexDe.id).toSet shouldBe Set(field1, field2)

      InputTag.deleteByIds(Seq(tag.id))
      InputTag.loadAll() shouldBe Nil
    }
  }

} 
Example 96
Source File: DistancesSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise1

import org.scalatest.{FlatSpec, Matchers}

class DistancesSpec extends FlatSpec with Matchers {


  it should "output an empty list if empty input" in {
    (new Distances).allDistancesFurther(Array(), 10) shouldEqual Seq()
  }

  it should "output an empty list if only one input" in {
    (new Distances).allDistancesFurther(Array((1,2)), 10) shouldEqual Seq()
  }

  it should "output only one distance for two points" in {
    (new Distances).allDistancesFurther(Array((0,0), (30,40)), 10).map(math.round(_).toInt) shouldEqual Seq(50)
  }

  it should "output only an empty list for two points if min distance is smaller" in {
    (new Distances).allDistancesFurther(Array((0,0), (30,40)), 100) shouldEqual Seq()
  }

  it should "output 10 distance points for 5 points" in {
    val inputs = Array((0,0), (30,40), (25,3), (78,22), (97,12))
    (new Distances).allDistancesFurther(inputs, -1).map(math.round(_).toInt) shouldEqual Seq(50, 25, 81, 98, 37, 51, 73, 56, 73, 21)
  }

  it should "output 9 distance points for 5 points if one gets filtered" in {
    val inputs = Array((0,0), (30,40), (25,34), (78,22), (97,12))
    (new Distances).allDistancesFurther(inputs, 10).map(math.round(_).toInt) shouldEqual Seq(50, 42, 81, 98, 51, 73, 54, 75, 21)
  }

  it should "output empty list for 5 points if all gets filtered" in {
    val inputs = Array((0,0), (30,40), (25,34), (78,22), (97,12))
    (new Distances).allDistancesFurther(inputs, 100).map(math.round(_).toInt) shouldEqual Seq()
  }

} 
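
The project's actual Distances implementation is not shown on this page. One implementation consistent with the assertions above computes pairwise Euclidean distances in combination order and keeps only those strictly greater than the given minimum:

// Illustrative sketch only; the real exercise solution may differ.
class Distances {
  def allDistancesFurther(points: Array[(Int, Int)], minDistance: Double): Seq[Double] =
    points.toSeq
      .combinations(2)
      .map { case Seq((x1, y1), (x2, y2)) => math.hypot((x1 - x2).toDouble, (y1 - y2).toDouble) }
      .filter(_ > minDistance)
      .toSeq
}

// For the five-point input with minDistance = -1 this yields the ten distances
// asserted above (50, 25, 81, 98, 37, 51, 73, 56, 73, 21 after rounding).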
Example 97
Source File: SimpleSolverSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise6

import org.scalatest.{FlatSpec, Matchers}

class SimpleSolverSpec extends FlatSpec with Matchers {
  it should "be able to do 3 plus 2" in {
    (new SimpleSolver).solve(Addition(3,2)) shouldEqual Some(5.0)
  }

  it should "be able to do 3 times 2" in {
    (new SimpleSolver).solve(Multiplication(3,2)) shouldEqual Some(6.0)
  }

  it should "be able to do 3 minus 2" in {
    (new SimpleSolver).solve(Subtraction(3,2)) shouldEqual Some(1.0)
  }

  it should "be able to do 3 div 2" in {
    (new SimpleSolver).solve(Division(3,2)) shouldEqual Some(1.5)
  }

  it should "be able to do 3 div 0 by return None" in {
    (new SimpleSolver).solve(Division(3,0)) shouldEqual None
  }
} 
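
The Operation ADT and SimpleSolver are likewise not shown here. A minimal sketch consistent with the assertions, returning None only for division by zero:

// Illustrative sketch only.
sealed trait Operation
final case class Addition(a: Double, b: Double) extends Operation
final case class Subtraction(a: Double, b: Double) extends Operation
final case class Multiplication(a: Double, b: Double) extends Operation
final case class Division(a: Double, b: Double) extends Operation

class SimpleSolver {
  def solve(op: Operation): Option[Double] = op match {
    case Addition(a, b)           => Some(a + b)
    case Subtraction(a, b)        => Some(a - b)
    case Multiplication(a, b)     => Some(a * b)
    case Division(_, b) if b == 0 => None
    case Division(a, b)           => Some(a / b)
  }
}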
Example 98
Source File: FindingHomeSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise4

import org.scalatest.{FlatSpec, Matchers}

class FindingHomeSpec extends FlatSpec with Matchers {
  it should "output the user.home" in {
    System.setProperty("user.home", "/Users/cutajar")
    System.setProperty("doc.home", "/documents")
    System.setProperty("appdata.home", "/appdata")
    (new FindingHome).findAHome() shouldEqual Some("/Users/cutajar")
  }

  it should "output the doc.home" in {
    System.getProperties.remove("user.home")
    System.setProperty("doc.home", "/documents")
    System.setProperty("appdata.home", "/appdata")
    (new FindingHome).findAHome() shouldEqual Some("/documents")
  }

  it should "output the /appdata" in {
    System.getProperties.remove("user.home")
    System.getProperties.remove("doc.home")
    System.setProperty("appdata.home", "/appdata")
    (new FindingHome).findAHome() shouldEqual Some("/appdata")
  }

  it should "output None if none of the properties are set" in {
    System.getProperties.remove("user.home")
    System.getProperties.remove("doc.home")
    System.getProperties.remove("appdata.home")
    (new FindingHome).findAHome() shouldEqual None
  }

  it should "output None if properties have a whitespace emtpy string" in {
    System.setProperty("user.home", " ")
    System.setProperty("doc.home", "   ")
    System.setProperty("appdata.home", "")
    (new FindingHome).findAHome() shouldEqual None
  }
} 
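
FindingHome is not shown either. A sketch consistent with the assertions reads the system properties in priority order and treats blank values as missing:

// Illustrative sketch only.
class FindingHome {
  private def prop(name: String): Option[String] =
    sys.props.get(name).map(_.trim).filter(_.nonEmpty)

  def findAHome(): Option[String] =
    prop("user.home") orElse prop("doc.home") orElse prop("appdata.home")
}

// user.home wins when set; blank or missing values fall through to the next
// property, and None is returned when nothing usable remains.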
Example 99
Source File: SecretSantaSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise4

import org.scalatest.{FlatSpec, Matchers}

class SecretSantaSpec extends FlatSpec with Matchers {
  it should "output the correct pairing for 3 items" in {
    val pairing = (new SecretSanta).generateGiftPairing(List("James", "Ruth", "Isabel"))
    pairing.keys shouldEqual Set("James", "Ruth", "Isabel")
    pairing.values.size shouldEqual 3
    pairing.values.toSet shouldEqual Set("James", "Ruth", "Isabel")
    pairing.foreach{case (v,k) => v shouldNot be(k)}
  }

  it should "output the correct pairing for 10 items" in {
    val pairing = (new SecretSanta).generateGiftPairing(List("a", "b", "c", "d", "e", "f", "g", "h", "i", "k"))
    pairing.keys shouldEqual Set("a", "b", "c", "d", "e", "f", "g", "h", "i", "k")
    pairing.values.size shouldEqual 10
    pairing.values.toSet shouldEqual Set("a", "b", "c", "d", "e", "f", "g", "h", "i", "k")
    pairing.foreach{case (v,k) => v shouldNot be(k)}
  }

  it should "output self pairing in case of one name" in {
    val pairing = (new SecretSanta).generateGiftPairing(List("James"))
    pairing.size shouldEqual 1
    pairing.get("James") should contain("James")
  }

} 
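
One pairing strategy consistent with the SecretSanta assertions is to shuffle the names and pair each with its successor, wrapping around: with more than one name nobody draws themselves, and a single name draws itself, as the last test expects. This is an illustrative sketch, not the exercise's actual solution.

import scala.util.Random

// Illustrative sketch only.
class SecretSanta {
  def generateGiftPairing(names: List[String]): Map[String, String] = {
    val shuffled = Random.shuffle(names)
    // Pair each name with the next one in the shuffled order, wrapping around.
    shuffled.zip(shuffled.tail :+ shuffled.head).toMap
  }
}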
Example 100
Source File: SwapElementsSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise7

import org.scalatest.{FlatSpec, Matchers}

class SwapElementsSpec extends FlatSpec with Matchers {
  it should "swap the first two elements with list of 4 items" in {
    (new SwapElements).swapFirstAndSecond(List(1, 2, 3, 4)) shouldEqual List(2, 1, 3, 4)
  }

  it should "not swap anything if the list is one element" in {
    (new SwapElements).swapFirstAndSecond(List(9)) shouldEqual List(9)
  }

  it should "not swap anything if the list is empty" in {
    (new SwapElements).swapFirstAndSecond(List()) shouldEqual List()
  }

  it should "swap the first two elements with list of 2 items" in {
    (new SwapElements).swapFirstAndSecond(List(5, 9)) shouldEqual List(9, 5)
  }
} 
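
A minimal SwapElements consistent with the assertions is a two-element pattern match; anything shorter than two elements is returned unchanged. Illustrative sketch only:

class SwapElements {
  def swapFirstAndSecond[A](list: List[A]): List[A] = list match {
    case first :: second :: rest => second :: first :: rest
    case other                   => other
  }
}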
Example 101
Source File: LazyPrimesSpec.scala    From FunctionalProgrammingExercises   with MIT License 5 votes vote down vote up
package com.cutajarjames.exercise3

import org.scalatest.{FlatSpec, Matchers}

class LazyPrimesSpec extends FlatSpec with Matchers {

  it should "output the correct sequence of the first 100 primes" in {
    (new LazyPrimes).allPrimes.take(100) shouldEqual Seq(2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43,
      47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163,
      167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281,
      283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, 419, 421,
      431, 433, 439, 443, 449, 457, 461, 463, 467, 479, 487, 491, 499, 503, 509, 521, 523, 541)
  }

  it should "output the correct sequence of the 1000th prime" in {
    (new LazyPrimes).allPrimes(1000) shouldEqual 7927
  }

} 
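
A minimal LazyPrimes sketch consistent with the assertions uses trial division over a lazily evaluated sequence; note that allPrimes(1000) is the 1001st prime, 7927. This version assumes Scala 2.13's LazyList (earlier Scala would use Stream) and is illustrative only.

// Illustrative sketch only.
class LazyPrimes {
  private def isPrime(n: Int): Boolean =
    n > 1 && (2 to math.sqrt(n.toDouble).toInt).forall(n % _ != 0)

  // Lazily evaluated, indexable sequence of all primes.
  val allPrimes: LazyList[Int] = LazyList.from(2).filter(isPrime)
}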
Example 102
Source File: HyperLogLog.scala    From spark-hyperloglog   with Apache License 2.0 5 votes vote down vote up
package com.mozilla.spark.sql.hyperloglog.test

import com.mozilla.spark.sql.hyperloglog.aggregates._
import com.mozilla.spark.sql.hyperloglog.functions._
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.functions._
import org.apache.spark.{SparkConf, SparkContext}
import org.scalatest.{FlatSpec, Matchers}

class HyperLogLogTest extends FlatSpec with Matchers{
 "Algebird's HyperLogLog" can "be used from Spark" in {
  val sparkConf = new SparkConf().setAppName("HyperLogLog")
  sparkConf.setMaster(sparkConf.get("spark.master", "local[1]"))

  val sc = new SparkContext(sparkConf)
  val sqlContext = new SQLContext(sc)
  import sqlContext.implicits._

  val hllMerge = new HyperLogLogMerge
  sqlContext.udf.register("hll_merge", hllMerge)
  sqlContext.udf.register("hll_create", hllCreate _)
  sqlContext.udf.register("hll_cardinality", hllCardinality _)

  val frame = sc.parallelize(List("a", "b", "c", "c"), 4).toDF("id")
  val count = frame
    .select(expr("hll_create(id, 12) as hll"))
    .groupBy()
    .agg(expr("hll_cardinality(hll_merge(hll)) as count"))
    .collect()
  count(0)(0) should be (3)
 }
} 
Example 103
Source File: GenericFlatSpecSuite.scala    From XSQL   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.sql.test

import org.scalatest.FlatSpec

import org.apache.spark.sql.Dataset


class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession {
  import testImplicits._

  private def ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS

  "A Simple Dataset" should "have the specified number of elements" in {
    assert(8 === ds.count)
  }
  it should "have the specified number of unique elements" in {
      assert(8 === ds.distinct.count)
  }
  it should "have the specified number of elements in each column" in {
    assert(8 === ds.select("_1").count)
    assert(8 === ds.select("_2").count)
  }
  it should "have the correct number of distinct elements in each column" in {
    assert(8 === ds.select("_1").distinct.count)
    assert(4 === ds.select("_2").distinct.count)
  }
} 
Example 104
Source File: AccessTokenSpec.scala    From akka-http-oauth2-client   with Apache License 2.0 5 votes vote down vote up
package com.github.dakatsuka.akka.http.oauth2.client

import akka.actor.ActorSystem
import akka.http.scaladsl.model.{ HttpEntity, HttpResponse, StatusCodes }
import akka.http.scaladsl.model.ContentTypes.`application/json`
import akka.stream.{ ActorMaterializer, Materializer }
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.time.{ Millis, Seconds, Span }
import org.scalatest.{ BeforeAndAfterAll, DiagrammedAssertions, FlatSpec }

import scala.concurrent.{ Await, ExecutionContext }
import scala.concurrent.duration.Duration

class AccessTokenSpec extends FlatSpec with DiagrammedAssertions with ScalaFutures with BeforeAndAfterAll {
  implicit val system: ActorSystem        = ActorSystem()
  implicit val ec: ExecutionContext       = system.dispatcher
  implicit val materializer: Materializer = ActorMaterializer()
  implicit val defaultPatience: PatienceConfig =
    PatienceConfig(timeout = Span(5, Seconds), interval = Span(700, Millis))

  override def afterAll(): Unit = {
    Await.ready(system.terminate(), Duration.Inf)
  }

  behavior of "AccessToken"

  it should "apply from HttpResponse" in {
    val accessToken  = "xxx"
    val tokenType    = "bearer"
    val expiresIn    = 86400
    val refreshToken = "yyy"

    val httpResponse = HttpResponse(
      status = StatusCodes.OK,
      headers = Nil,
      entity = HttpEntity(
        `application/json`,
        s"""
           |{
           |  "access_token": "$accessToken",
           |  "token_type": "$tokenType",
           |  "expires_in": $expiresIn,
           |  "refresh_token": "$refreshToken"
           |}
         """.stripMargin
      )
    )

    val result = AccessToken(httpResponse)

    whenReady(result) { token =>
      assert(token.accessToken == accessToken)
      assert(token.tokenType == tokenType)
      assert(token.expiresIn == expiresIn)
      assert(token.refreshToken.contains(refreshToken))
    }
  }
} 
Example 105
Source File: VectorClockOpsTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.vectorclocks

import org.scalatest.{FlatSpec, Matchers}

class VectorClockOpsTest extends FlatSpec with Matchers {

  behavior of "Vector Clock Ops"

  it should "create Vector Clock instance from plain string" in {
    "A:2".toVectorClock[String]           shouldBe VectorClock(Map("A" -> Counter(2)))
    "A:1, B:2".toVectorClock[String]      shouldBe VectorClock(Map("A" -> Counter(1), "B" -> Counter(2)))
    "A:1, B:1, C:1".toVectorClock[String] shouldBe VectorClock(Map("A" -> Counter(1), "B" -> Counter(1), "C" -> Counter(1)))
  }

  it should "create Vector Clock instance from plain string with numerical ids" in {
    import VectorClockOps.intAsId

    ("1:2": VectorClock[Int])       shouldBe VectorClock(Map(1 -> Counter(2)))
    ("1:2, 2:10": VectorClock[Int]) shouldBe VectorClock(Map(1 -> Counter(2), 2 -> Counter(10)))
  }
} 
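
The string syntax exercised above ("A:1, B:2") could be parsed with something as small as the following; this is an illustrative sketch returning a plain Map, not the library's VectorClockOps.

// Illustrative sketch only.
object VectorClockParsing {
  def parse(s: String): Map[String, Int] =
    s.split(",").toSeq
      .map(_.trim.split(":"))
      .map { case Array(id, counter) => id -> counter.toInt }
      .toMap
}

// VectorClockParsing.parse("A:1, B:2") == Map("A" -> 1, "B" -> 2)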
Example 106
Source File: DataTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db

import java.util.UUID

import justin.db.consistenthashing.NodeId
import justin.db.replica.PreferenceList
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.{FlatSpec, Matchers}

class DataTest extends FlatSpec with Matchers {

  behavior of "Data"

  it should "update its empty inner Vector Clock based on preference list" in {
    // given
    val data           = Data(id = UUID.randomUUID(), value = "some value")
    val preferenceList = PreferenceList(NodeId(1), List(NodeId(5), NodeId(8)))

    // when
    val updatedData = Data.updateVclock(data, preferenceList)

    // then
    val expectedVclock = VectorClock[NodeId](Map(
      NodeId(1) -> Counter(1),
      NodeId(5) -> Counter(1),
      NodeId(8) -> Counter(1))
    )
    updatedData shouldBe Data(data.id, data.value, expectedVclock, updatedData.timestamp)
  }

  it should "increase vector clock's counter of repeated nodeId when updating data" in {
    // given
    val data           = Data(id = UUID.randomUUID(), value = "some value")
    val preferenceList = PreferenceList(NodeId(1), List(NodeId(1), NodeId(1)))

    // when
    val updatedData = Data.updateVclock(data, preferenceList)

    // then
    val expectedVclock = VectorClock[NodeId](Map(
      NodeId(1) -> Counter(3)
    ))
    updatedData shouldBe Data(data.id, data.value, expectedVclock, data.timestamp)
  }

  it should "increase already existed vector clock's counter when updating data" in {
    // given
    val initVClock     = VectorClock[NodeId](Map(NodeId(1) -> Counter(3)))
    val data = Data(id = UUID.randomUUID(), value = "some value", initVClock)
    val preferenceList = PreferenceList(NodeId(1), List(NodeId(5), NodeId(8)))

    // when
    val updatedData = Data.updateVclock(data, preferenceList)

    // then
    val expectedVclock = VectorClock[NodeId](Map(
      NodeId(1) -> Counter(4),
      NodeId(5) -> Counter(1),
      NodeId(8) -> Counter(1))
    )
    updatedData shouldBe Data(data.id, data.value, expectedVclock, data.timestamp)
  }
} 
Example 107
Source File: RegisterNodeSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import justin.db.actors.protocol.RegisterNode
import justin.db.consistenthashing.NodeId
import org.scalatest.{FlatSpec, Matchers}

class RegisterNodeSerializerTest extends FlatSpec with Matchers {

  behavior of "RegisterNode Serializer"

  it should "serialize/deserialize correctly" in {
    // kryo init
    val kryo = new Kryo()
    kryo.register(classOf[RegisterNode], RegisterNodeSerializer)

    // object
    val serializedData = RegisterNode(NodeId(1))

    // serialization
    val bos    = new ByteArrayOutputStream()
    val output = new Output(bos)
    val _      = kryo.writeObject(output, serializedData)
    output.flush()

    // deserialization
    val bis              = new ByteArrayInputStream(bos.toByteArray)
    val input            = new Input(bis)
    val deserializedData = kryo.readObject(input, classOf[RegisterNode])

    serializedData shouldBe deserializedData
  }
} 
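
The serialize-then-deserialize boilerplate is identical across the Kryo tests in this group (Examples 107, 109, 110 and 111). A small helper of the following shape, illustrative rather than part of JustinDB, would factor it out:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

// Write a value with Kryo into an in-memory buffer and read it back.
object KryoRoundTrip {
  def roundTrip[T](kryo: Kryo, value: T, clazz: Class[T]): T = {
    val bos    = new ByteArrayOutputStream()
    val output = new Output(bos)
    kryo.writeObject(output, value)
    output.flush()

    val input = new Input(new ByteArrayInputStream(bos.toByteArray))
    kryo.readObject(input, clazz)
  }
}

// e.g. KryoRoundTrip.roundTrip(kryo, RegisterNode(NodeId(1)), classOf[RegisterNode]) shouldBe RegisterNode(NodeId(1))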
Example 108
Source File: SerializerInitTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import com.esotericsoftware.kryo.Kryo
import org.scalatest.{FlatSpec, Matchers}

class SerializerInitTest extends FlatSpec with Matchers {

  behavior of "SerializerInit"

  it should "init Kryo serializer" in {
    val kryo = new Kryo()
    val serializerInit = new SerializerInit()
    serializerInit.customize(kryo)

    // cluster
    val classId_50 = 50
    kryo.getRegistration(classId_50).getId          shouldBe 50
    kryo.getRegistration(classId_50).getSerializer  shouldBe RegisterNodeSerializer
    kryo.getRegistration(classId_50).getType        shouldBe classOf[justin.db.actors.protocol.RegisterNode]


    // write -- request
    val classId_60 = 60
    kryo.getRegistration(classId_60).getId          shouldBe 60
    kryo.getRegistration(classId_60).getSerializer  shouldBe StorageNodeWriteDataLocalSerializer
    kryo.getRegistration(classId_60).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeWriteDataLocal]

    // write -- responses
    val classId_70 = 70
    kryo.getRegistration(classId_70).getId          shouldBe 70
    kryo.getRegistration(classId_70).getSerializer  shouldBe StorageNodeWriteResponseSerializer
    kryo.getRegistration(classId_70).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeFailedWrite]

    val classId_71 = 71
    kryo.getRegistration(classId_71).getId          shouldBe 71
    kryo.getRegistration(classId_71).getSerializer  shouldBe StorageNodeWriteResponseSerializer
    kryo.getRegistration(classId_71).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeSuccessfulWrite]

    val classId_72 = 72
    kryo.getRegistration(classId_72).getId          shouldBe 72
    kryo.getRegistration(classId_72).getSerializer  shouldBe StorageNodeWriteResponseSerializer
    kryo.getRegistration(classId_72).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeConflictedWrite]

    // read - request
    val classId_80 = 80
    kryo.getRegistration(classId_80).getId          shouldBe 80
    kryo.getRegistration(classId_80).getSerializer  shouldBe StorageNodeLocalReadSerializer
    kryo.getRegistration(classId_80).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeLocalRead]

    // read - responses
    val classId_90 = 90
    kryo.getRegistration(classId_90).getId          shouldBe 90
    kryo.getRegistration(classId_90).getSerializer  shouldBe StorageNodeReadResponseSerializer
    kryo.getRegistration(classId_90).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeFoundRead]

    val classId_91 = 91
    kryo.getRegistration(classId_91).getId          shouldBe 91
    kryo.getRegistration(classId_91).getSerializer  shouldBe StorageNodeReadResponseSerializer
    kryo.getRegistration(classId_91).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeConflictedRead]

    val classId_92 = 92
    kryo.getRegistration(classId_92).getId          shouldBe 92
    kryo.getRegistration(classId_92).getSerializer  shouldBe StorageNodeReadResponseSerializer
    kryo.getRegistration(classId_92).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeNotFoundRead]

    val classId_93 = 93
    kryo.getRegistration(classId_93).getId          shouldBe 93
    kryo.getRegistration(classId_93).getSerializer  shouldBe StorageNodeReadResponseSerializer
    kryo.getRegistration(classId_93).getType        shouldBe classOf[justin.db.actors.protocol.StorageNodeFailedRead]
  }
} 
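Note: the test above only asserts that SerializerInit registers each protocol class under a fixed id. As a minimal sketch of that pattern (an assumption, not JustinDB's actual SerializerInit), the customize method can be written with Kryo's register(Class, Serializer, int) overload so every cluster node agrees on the id-to-class mapping; the sketch assumes it sits in justin.db.kryo next to the serializer objects.

import com.esotericsoftware.kryo.Kryo
import justin.db.actors.protocol.{RegisterNode, StorageNodeWriteDataLocal}

// Sketch only: fixed registration ids keep the wire format stable across nodes.
// Assumes this lives in justin.db.kryo next to the serializer objects.
class SerializerInitSketch {
  def customize(kryo: Kryo): Unit = {
    kryo.register(classOf[RegisterNode], RegisterNodeSerializer, 50)
    kryo.register(classOf[StorageNodeWriteDataLocal], StorageNodeWriteDataLocalSerializer, 60)
    // ids 70-72, 80 and 90-93 would be registered the same way for the
    // remaining write/read response and request messages asserted above.
  }
}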
Example 109
Source File: DataSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.UUID

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import justin.db.Data
import justin.db.consistenthashing.NodeId
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.{FlatSpec, Matchers}

class DataSerializerTest extends FlatSpec with Matchers {

  behavior of "Data Serializer"

  it should "serialize/deserialize correctly" in {
    // kryo init
    val kryo = new Kryo()
    kryo.register(classOf[justin.db.Data], DataSerializer)

    // object
    val vClock         = VectorClock[NodeId](Map(NodeId(1) -> Counter(3)))
    val timestamp      = System.currentTimeMillis()
    val serializedData = Data(id = UUID.randomUUID(), value = "some value", vClock, timestamp)

    // serialization
    val bos    = new ByteArrayOutputStream()
    val output = new Output(bos)
    val _      = kryo.writeObject(output, serializedData)
    output.flush()

    // deserialization
    val bis              = new ByteArrayInputStream(bos.toByteArray)
    val input            = new Input(bis)
    val deserializedData = kryo.readObject(input, classOf[Data])

    serializedData shouldBe deserializedData
  }
} 
Example 110
Source File: StorageNodeWriteDataLocalSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.UUID

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import justin.db.Data
import justin.db.actors.protocol.StorageNodeWriteDataLocal
import justin.db.consistenthashing.NodeId
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.{FlatSpec, Matchers}

class StorageNodeWriteDataLocalSerializerTest extends FlatSpec with Matchers {

  behavior of "StorageNodeWriteDataLocal Serializer"

  it should "serialize/deserialize StorageNodeWriteDataLocal" in {
    // kryo init
    val kryo = new Kryo()
    kryo.register(classOf[StorageNodeWriteDataLocal], StorageNodeWriteDataLocalSerializer)

    // object
    val data = Data(
      id        = UUID.randomUUID(),
      value     = "some value",
      vclock    = VectorClock[NodeId](Map(NodeId(1) -> Counter(3))),
      timestamp = System.currentTimeMillis()
    )
    val serializedData = StorageNodeWriteDataLocal(data)

    // serialization
    val bos    = new ByteArrayOutputStream()
    val output = new Output(bos)
    val _      = kryo.writeObject(output, serializedData)
    output.flush()

    // deserialization
    val bis              = new ByteArrayInputStream(bos.toByteArray)
    val input            = new Input(bis)
    val deserializedData = kryo.readObject(input, classOf[StorageNodeWriteDataLocal])

    serializedData shouldBe deserializedData
  }
} 
Example 111
Source File: StorageNodeLocalReadSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.kryo

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.UUID

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import justin.db.actors.protocol.StorageNodeLocalRead
import org.scalatest.{FlatSpec, Matchers}

class StorageNodeLocalReadSerializerTest extends FlatSpec with Matchers {

  behavior of "StorageNodeLocalReader Serializer"

  it should "serialize/deserialize correctly" in {
    // kryo init
    val kryo = new Kryo()
    kryo.register(classOf[StorageNodeLocalRead], StorageNodeLocalReadSerializer)

    // object
    val serializedData = StorageNodeLocalRead(UUID.randomUUID())

    // serialization
    val bos    = new ByteArrayOutputStream()
    val output = new Output(bos)
    val _      = kryo.writeObject(output, serializedData)
    output.flush()

    // deserialization
    val bis              = new ByteArrayInputStream(bos.toByteArray)
    val input            = new Input(bis)
    val deserializedData = kryo.readObject(input, classOf[StorageNodeLocalRead])

    serializedData shouldBe deserializedData
  }
} 
Example 112
Source File: IsPrimaryOrReplicaTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.replica

import java.util.UUID

import justin.db.consistenthashing.{NodeId, Ring}
import justin.db.storage.PluggableStorageProtocol.DataOriginality
import org.scalatest.{FlatSpec, Matchers}

class IsPrimaryOrReplicaTest extends FlatSpec with Matchers {

  behavior of "Data Originality Resolver"

  it should "reason exemplary data's id as a replica" in {
    // given
    val nodeId   = NodeId(0)
    val ring     = Ring.apply(nodesSize = 3, partitionsSize = 21)
    val resolver = new IsPrimaryOrReplica(nodeId, ring)
    val id       = UUID.fromString("179d6eb0-681d-4277-9caf-3d6d60e9faf9")

    // when
    val originality = resolver.apply(id)

    // then
    originality shouldBe a[DataOriginality.Replica]
  }

  it should "reason exemplary data's id as a primary" in {
    // given
    val nodeId   = NodeId(0)
    val ring     = Ring.apply(nodesSize = 3, partitionsSize = 21)
    val resolver = new IsPrimaryOrReplica(nodeId, ring)
    val id       = UUID.fromString("16ec44cd-5b4e-4b38-a647-206c1dc11b50")

    // when
    val originality = resolver.apply(id)

    // then
    originality shouldBe a[DataOriginality.Primary]
  }
} 
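For orientation, IsPrimaryOrReplica maps a UUID onto one of the ring's partitions and reports whether the current node owns it. The sketch below is only an assumption about that shape (simplified types, made-up partitioning scheme), not JustinDB's implementation.

import java.util.UUID

object IsPrimaryOrReplicaSketch {
  // Sketch only: hash the id onto a partition and check ownership.
  sealed trait Originality
  final case class Primary(ringPartitionId: Int) extends Originality
  final case class Replica(ringPartitionId: Int) extends Originality

  def resolve(partitionsOwnedByNode: Set[Int], partitionsSize: Int)(id: UUID): Originality = {
    val partitionId = math.abs(id.hashCode % partitionsSize) // assumed partitioning scheme
    if (partitionsOwnedByNode.contains(partitionId)) Primary(partitionId) else Replica(partitionId)
  }
}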
Example 113
Source File: ReplicaLocalWriterTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.replica.write

import java.util.UUID

import justin.db.Data
import justin.db.actors.protocol.{StorageNodeConflictedWrite, StorageNodeFailedWrite, StorageNodeSuccessfulWrite}
import justin.db.consistenthashing.NodeId
import justin.db.storage.PluggableStorageProtocol.{Ack, DataOriginality, StorageGetData}
import justin.db.storage.{GetStorageProtocol, JustinData, PutStorageProtocol}
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future
import scala.concurrent.duration._

class ReplicaLocalWriterTest extends FlatSpec with Matchers with ScalaFutures {

  behavior of "Replica Local Writer"

  override implicit def patienceConfig: PatienceConfig = PatienceConfig(10.seconds, 50.millis)

  
  it should "fail to write predecessor to already stored data" in {
    // given
    val id      = UUID.randomUUID()
    val data    = Data(id, "some-value", VectorClock(Map(NodeId(1) -> Counter(2))))
    val newData = Data(id, "some-value-2", VectorClock(Map(NodeId(1) -> Counter(1))))
    val writer = new ReplicaLocalWriter(new GetStorageProtocol with PutStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = Future.successful(StorageGetData.Single(data))
      override def put(data: JustinData)(resolveOriginality: (UUID) => DataOriginality): Future[Ack] = ???
    })

    // when
    val result = writer.apply(newData, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeFailedWrite(id) }
  }

  it should "get conflicted write when trying to save new data with conflicted vector clock comparing to already existed one" in {
    // given
    val id      = UUID.randomUUID()
    val data    = Data(id, "some-value", VectorClock(Map(NodeId(1) -> Counter(1))))
    val newData = Data(id, "some-value-2", VectorClock(Map(NodeId(2) -> Counter(1))))
    val writer = new ReplicaLocalWriter(new GetStorageProtocol with PutStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = Future.successful(StorageGetData.Single(data))
      override def put(data: JustinData)(resolveOriginality: (UUID) => DataOriginality): Future[Ack] = Ack.future
    })

    // when
    val result = writer.apply(newData, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeConflictedWrite(data, newData) }
  }

  it should "get successful write when trying to save new data with consequent vector clock comparing to already existed one" in {
    // given
    val id      = UUID.randomUUID()
    val data    = Data(id, "some-value", VectorClock(Map(NodeId(1) -> Counter(1))))
    val newData = Data(id, "some-value-2", VectorClock(Map(NodeId(1) -> Counter(2))))
    val writer = new ReplicaLocalWriter(new GetStorageProtocol with PutStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = Future.successful(StorageGetData.Single(data))
      override def put(data: JustinData)(resolveOriginality: (UUID) => DataOriginality): Future[Ack] = Ack.future
    })

    // when
    val result = writer.apply(newData, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeSuccessfulWrite(id) }
  }
} 
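The three outcomes above (failed, conflicted, successful) follow from comparing the incoming vector clock with the stored one. Below is a plain-Scala sketch of that comparison, consistent with the three test cases but only an assumption about JustinDB's actual VectorClock logic.

object VectorClockComparisonSketch {
  // Sketch only: clocks modelled as maps from node id to counter.
  type VClock = Map[Int, Int]

  sealed trait Relation
  case object Predecessor extends Relation // incoming is not newer  -> StorageNodeFailedWrite
  case object Consequent  extends Relation // incoming dominates     -> StorageNodeSuccessfulWrite
  case object Conflict    extends Relation // neither dominates      -> StorageNodeConflictedWrite

  def compare(existing: VClock, incoming: VClock): Relation = {
    val keys  = existing.keySet ++ incoming.keySet
    val geAll = keys.forall(k => incoming.getOrElse(k, 0) >= existing.getOrElse(k, 0))
    val leAll = keys.forall(k => incoming.getOrElse(k, 0) <= existing.getOrElse(k, 0))
    if (geAll && !leAll) Consequent
    else if (!geAll && !leAll) Conflict
    else Predecessor // equal or older: nothing newer to store (assumption)
  }
}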
Example 114
Source File: ReplicaWriteAgreementTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.replica.write

import java.util.UUID

import justin.db.actors.protocol.{StorageNodeFailedWrite, StorageNodeSuccessfulWrite}
import justin.db.replica.W
import org.scalatest.{FlatSpec, Matchers}

class ReplicaWriteAgreementTest extends FlatSpec with Matchers {

  behavior of "Reach Consensus of Replicated Writes"

  it should "agreed on \"SuccessfulWrite\" if number of successful write is not less than client expectations" in {
    // given
    val w = W(2)
    val writes = List(StorageNodeSuccessfulWrite(UUID.randomUUID()), StorageNodeSuccessfulWrite(UUID.randomUUID()), StorageNodeFailedWrite(UUID.randomUUID()))

    // when
    val result = new ReplicaWriteAgreement().reach(w)(writes)

    // then
    result shouldBe WriteAgreement.Ok
  }

  it should "agreed on \"NotEnoughWrites\" if number of successful write is less than client expectations" in {
    // given
    val w = W(3)
    val writes = List(StorageNodeSuccessfulWrite(UUID.randomUUID()), StorageNodeSuccessfulWrite(UUID.randomUUID()), StorageNodeFailedWrite(UUID.randomUUID()))

    // when
    val result = new ReplicaWriteAgreement().reach(w)(writes)

    // then
    result shouldBe WriteAgreement.NotEnoughWrites
  }
} 
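The agreement itself reduces to counting: with write factor W(w), the round succeeds once at least w successful acks arrive. A hedged sketch of that rule follows (the real ReplicaWriteAgreement may differ in types and details).

import justin.db.actors.protocol.StorageNodeSuccessfulWrite

object WriteAgreementSketch {
  // Sketch only: true corresponds to WriteAgreement.Ok, false to NotEnoughWrites.
  def enoughWrites(w: Int)(writes: List[Any]): Boolean =
    writes.count(_.isInstanceOf[StorageNodeSuccessfulWrite]) >= w
}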
Example 115
Source File: ReplicaLocalReaderTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.replica.read

import java.util.UUID

import justin.db.Data
import justin.db.actors.protocol.{StorageNodeFailedRead, StorageNodeFoundRead, StorageNodeNotFoundRead}
import justin.db.consistenthashing.NodeId
import justin.db.storage.GetStorageProtocol
import justin.db.storage.PluggableStorageProtocol.{DataOriginality, StorageGetData}
import justin.db.vectorclocks.VectorClock
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.concurrent.Future

class ReplicaLocalReaderTest extends FlatSpec with Matchers with ScalaFutures {

  behavior of "Replica Local Reader"

  override implicit def patienceConfig: PatienceConfig = PatienceConfig(10.seconds, 50.millis)

  it should "found data for existing key" in {
    // given
    val id   = UUID.randomUUID()
    val data = Data(id, "value", VectorClock[NodeId]().increase(NodeId(1)))
    val service = new ReplicaLocalReader(new GetStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = {
        Future.successful(StorageGetData.Single(data))
      }
    })

    // when
    val result = service.apply(id, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeFoundRead(data) }
  }

  it should "not found data for non-existing key" in {
    // given
    val id = UUID.randomUUID()
    val service = new ReplicaLocalReader(new GetStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = {
        Future.successful(StorageGetData.None)
      }
    })

    // when
    val result = service.apply(id, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeNotFoundRead(id) }
  }

  it should "recover failure reading" in {
    // given
    val id = UUID.randomUUID()
    val service = new ReplicaLocalReader(new GetStorageProtocol {
      override def get(id: UUID)(resolveOriginality: (UUID) => DataOriginality): Future[StorageGetData] = Future.failed(new Exception)
    })

    // when
    val result = service.apply(id, null)

    // then
    whenReady(result) { _ shouldBe StorageNodeFailedRead(id) }
  }
} 
Example 116
Source File: UnmarshallersTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.httpapi

import java.util.UUID

import org.scalatest.{FlatSpec, Matchers}
import spray.json.{DeserializationException, JsNumber, JsString}

class UnmarshallersTest extends FlatSpec with Matchers {

  behavior of "Unmarshaller"

  it should "encode JSON into UUID" in {
    val uuid = UUID.randomUUID()
    val jsString = JsString(uuid.toString)

    Unmarshallers.UuidFormat.read(jsString) shouldBe uuid
  }

  it should "decode UUID into JSON" in {
    val uuid = UUID.randomUUID()
    val expectedJSON = Unmarshallers.UuidFormat.write(uuid)

    expectedJSON shouldBe JsString(uuid.toString)
  }

  it should "handle not expected format of JSON" in {
    val jsNumber = JsNumber(1)

    intercept[DeserializationException] {
      Unmarshallers.UuidFormat.read(jsNumber)
    }
  }

  it should "handle wrong format of UUID" in {
    val fakeUUID = "1-2-3-4"
    val jsString = JsString(fakeUUID)

    intercept[DeserializationException] {
      Unmarshallers.UuidFormat.read(jsString)
    }
  }
} 
Example 117
Source File: VectorClockHeaderTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.httpapi

import justin.db.consistenthashing.NodeId
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.{FlatSpec, Matchers}

class VectorClockHeaderTest extends FlatSpec with Matchers {

  behavior of "Vector Clock Header"

  it should "parse string and create Vector Clock instance upon it" in {
    // given
    val encoded = "W1siMSIsMV0sWyIyIiwyXSxbIjMiLDldXQ=="

    // when
    val vClockHeader = VectorClockHeader.parse(encoded).get

    // then
    vClockHeader.vectorClock shouldBe VectorClock(Map(NodeId(1) -> Counter(1), NodeId(2) -> Counter(2), NodeId(3) -> Counter(9)))
  }

  it should "stringify Vector Clock instance" in {
    // given
    val vClock = VectorClock(Map(NodeId(1) -> Counter(1), NodeId(2) -> Counter(2), NodeId(3) -> Counter(9)))

    // when
    val encoded = VectorClockHeader(vClock).value()

    // then
    encoded shouldBe "W1siMSIsMV0sWyIyIiwyXSxbIjMiLDldXQ=="
  }

  it should "throw an Exception for not parsable Vector Clock" in {
    val vClock = null

    intercept[VectorClockHeaderException] {
      val encoded = VectorClockHeader(vClock).value()
    }
  }

  it should "render header in response" in {
    VectorClockHeader(null).renderInResponses() shouldBe true
  }

  it should "render header in request" in {
    VectorClockHeader(null).renderInRequests() shouldBe true
  }
} 
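The header value asserted above is not an opaque token: it is the Base64 encoding of a JSON array of (nodeId, counter) pairs. A short verification snippet using only the JDK (not part of JustinDB):

import java.nio.charset.StandardCharsets.UTF_8
import java.util.Base64

object VectorClockHeaderDecodeCheck extends App {
  // Decodes the header used in the test back into its JSON representation.
  val decoded = new String(Base64.getDecoder.decode("W1siMSIsMV0sWyIyIiwyXSxbIjMiLDldXQ=="), UTF_8)
  assert(decoded == """[["1",1],["2",2],["3",9]]""")
}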
Example 118
Source File: JustinDirectivesTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.httpapi

import akka.http.scaladsl.server.directives._
import akka.http.scaladsl.testkit.ScalatestRouteTest
import justin.db.consistenthashing.NodeId
import justin.db.vectorclocks.{Counter, VectorClock}
import org.scalatest.{FlatSpec, Matchers}

class JustinDirectivesTest extends FlatSpec with Matchers with ScalatestRouteTest
  with RouteDirectives
  with JustinDirectives {

  behavior of "Justin Directives"

  it should "provide empty VectorClock instance when no header is passed" in {
    Get("/") ~> withVectorClockHeader(x => complete(x.vectorClock.toString)) ~> check {
      responseAs[String] shouldBe VectorClockHeader.empty.vectorClock.toString
    }
  }

  it should "provide instance of VectorClock build upon passed header" in {
    val vClock = VectorClock(Map(NodeId(1) -> Counter(1), NodeId(2) -> Counter(2), NodeId(3) -> Counter(9)))
    val header = VectorClockHeader(vClock)

    Get("/").addHeader(header) ~> withVectorClockHeader(x => complete(x.vectorClock.toString)) ~> check {
      responseAs[String] shouldBe vClock.toString
    }
  }
} 
Example 119
Source File: JustinDataSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.storage

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.util.UUID

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import justin.db.storage.RocksDBStorage.JustinDataSerializer
import org.scalatest.{FlatSpec, Matchers}

class JustinDataSerializerTest extends FlatSpec with Matchers {

  behavior of "JustinDataSerializer"

  it should "serialize/deserialize JustinData with Kryo" in {
    val kryo = new Kryo()
    val data = JustinData(
      id        = UUID.randomUUID,
      value     = "to jest przykladowa wartość",
      vclock    = "vclock-value",
      timestamp = 1234124L
    )

    // serialize
    val output = new Output(new ByteArrayOutputStream())
    JustinDataSerializer.write(kryo, output, data)
    val dataBytes = output.getBuffer

    // deserialize
    val input = new Input(new ByteArrayInputStream(dataBytes))
    JustinDataSerializer.read(kryo, input, classOf[JustinData]) shouldBe data
  }
} 
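For reference, a Kryo serializer covering the four JustinData fields can be as small as the sketch below; the field order and the UUID encoding are assumptions, not necessarily what RocksDBStorage.JustinDataSerializer does.

import java.util.UUID

import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Kryo, Serializer}

// Sketch only: writes the UUID as two longs, then the remaining fields in order.
object JustinDataSerializerSketch extends Serializer[JustinData] {
  override def write(kryo: Kryo, output: Output, data: JustinData): Unit = {
    output.writeLong(data.id.getMostSignificantBits)
    output.writeLong(data.id.getLeastSignificantBits)
    output.writeString(data.value)
    output.writeString(data.vclock)
    output.writeLong(data.timestamp)
  }

  override def read(kryo: Kryo, input: Input, `type`: Class[JustinData]): JustinData =
    JustinData(
      id        = new UUID(input.readLong(), input.readLong()),
      value     = input.readString(),
      vclock    = input.readString(),
      timestamp = input.readLong()
    )
}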
Example 120
Source File: RocksDBStorageTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.storage

import java.nio.file.Files
import java.util.UUID

import justin.db.storage.PluggableStorageProtocol.{Ack, DataOriginality, StorageGetData}
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.{FlatSpec, Matchers}

import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

class RocksDBStorageTest extends FlatSpec with Matchers  with ScalaFutures {

  behavior of "RocksDBStorage"

  it should "save 3 payloads and read them" in {
    val journal = Files.createTempDirectory("rocksdb")
    val rocksdb = new RocksDBStorage(journal.toFile)
    val data1 = JustinData(
      id        = UUID.randomUUID,
      value     = "1",
      vclock    = "vclock-value",
      timestamp = 1234124L
    )
    val data2 = JustinData(
      id        = UUID.randomUUID,
      value     = "1",
      vclock    = "vclock-value",
      timestamp = 1234124L
    )
    val data3 = JustinData(
      id        = UUID.randomUUID,
      value     = "3",
      vclock    = "vclock-value",
      timestamp = 1234124L
    )
    val dataOriginality = DataOriginality.Primary(ringPartitionId = 1)

    // PUT
    rocksdb.put(data1)(_ => dataOriginality).futureValue shouldBe Ack
    rocksdb.put(data2)(_ => dataOriginality).futureValue shouldBe Ack
    rocksdb.put(data3)(_ => dataOriginality).futureValue shouldBe Ack

    // GET
    rocksdb.get(data3.id)(_ => dataOriginality).futureValue shouldBe StorageGetData.Single(data3)
    rocksdb.get(data2.id)(_ => dataOriginality).futureValue shouldBe StorageGetData.Single(data2)
    rocksdb.get(data1.id)(_ => dataOriginality).futureValue shouldBe StorageGetData.Single(data1)
  }

  override implicit def patienceConfig: PatienceConfig = PatienceConfig(10.seconds, 50.millis)
} 
Example 121
Source File: UUIDSerializerTest.scala    From JustinDB   with Apache License 2.0 5 votes vote down vote up
package justin.db.storage

import java.io.ByteArrayInputStream
import java.util.UUID

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.Input
import justin.db.storage.RocksDBStorage.UUIDSerializer
import org.scalatest.{FlatSpec, Matchers}

class UUIDSerializerTest extends FlatSpec with Matchers {

  behavior of "UUIDSerializer"

  it should "serialize/deserialize UUID with Kryo" in {
    val uuid = UUID.randomUUID()
    val kryo = new Kryo()

    // serialize
    val bytes = RocksDBStorage.uuid2bytes(kryo, uuid)

    // deserialize
    val input = new Input(new ByteArrayInputStream(bytes))
    val id = UUIDSerializer.read(kryo, input, classOf[UUID])

    uuid shouldBe id
  }
} 
Example 122
Source File: AddClassAnnotationEditorTest.scala    From rug   with GNU General Public License v3.0 5 votes vote down vote up
package com.atomist.project.edit.java

import com.atomist.param.SimpleParameterValues
import com.atomist.parse.java.ParsingTargets
import com.atomist.project.edit.{FailedModificationAttempt, SuccessfulModification}
import com.atomist.rug.kind.java.AddClassAnnotationEditor
import org.scalatest.{FlatSpec, Matchers}

class AddClassAnnotationEditorTest extends FlatSpec with Matchers {

  // Adds a Mysterious annotation to any class
  val addFoobarAnnotationEditor = new AddClassAnnotationEditor(
    coit => true,
    annotationPackageName = Some("com.megacorp"),
    annotationName = "Mysterious",
    javaSourcePath = ""
  )

  val args = SimpleParameterValues.Empty

  it should "apply annotations where needed" in {
    val as = ParsingTargets.SpringIoGuidesRestServiceSource
    assert(addFoobarAnnotationEditor.applicability(as).canApply === true)
    addFoobarAnnotationEditor.modify(as, args) match {
      case sma: SuccessfulModification =>
        val javaFiles = sma.result.files.filter(_.name.endsWith(".java"))
        assert(javaFiles.size === 2)
        javaFiles.foreach(f => {
          f.content contains "@Mysterious" shouldBe true
        })
      case f: FailedModificationAttempt => fail
      case _ => fail
    }
  }

  it should "recognize that annotation is applied using FQN without import" is pending
} 
Example 123
Source File: PackageFinderTest.scala    From rug   with GNU General Public License v3.0 5 votes vote down vote up
package com.atomist.project.edit.java

import com.atomist.rug.kind.java.support.{PackageFinder, PackageInfo}
import com.atomist.source.{EmptyArtifactSource, SimpleFileBasedArtifactSource, StringFileArtifact}
import org.scalatest.{FlatSpec, Matchers}

class PackageFinderTest extends FlatSpec with Matchers {

  it should "find default package in no-package source" in {
    val src = "public class Hello {}"
    PackageFinder.findPackage(src) should equal("")
  }

  it should "find non default package on first line" in {
    val src = "package com.foo.bar;\npublic class Hello {}"
    PackageFinder.findPackage(src) should equal("com.foo.bar")
  }

  it should "find non default package after comment" in {
    val src =
      """
        |
        |
        |package com.foo.bar;
        |
        |public class Hello {
        |}
      """.stripMargin
    PackageFinder.findPackage(src) should equal("com.foo.bar")
  }

  it should "find no packages in empty project" in {
    PackageFinder.packages(new EmptyArtifactSource("")) should be(empty)
  }

  it should "find default package in root" in {
    PackageFinder.packages(
      new SimpleFileBasedArtifactSource("", StringFileArtifact("Hello.java", "public class Hello {}"))
    ) should equal(Seq(PackageInfo("", 1)))
  }

  it should "find explicit package in root" in {
    PackageFinder.packages(
      new SimpleFileBasedArtifactSource("", StringFileArtifact("com/foo/bar/Hello.java", "package com.foo.bar;\npublic class Hello {}"))
    ) should equal(Seq(PackageInfo("com.foo.bar", 1)))
  }

  it should "find and count explicit packages in root" in {
    PackageFinder.packages(
      new SimpleFileBasedArtifactSource("",
        Seq(
          StringFileArtifact("com/foo/bar/Hello.java", "package com.foo.bar;\npublic class Hello {}"),
          StringFileArtifact("com/foo/bar/Hello2.java", "package com.foo.bar;\npublic class Hello2 {}")
        ))
    ) should equal(Seq(PackageInfo("com.foo.bar", 2)))
  }

} 
Example 124
Source File: MustacheMergeToolTest.scala    From rug   with GNU General Public License v3.0 5 votes vote down vote up
package com.atomist.project.common.template

import com.atomist.source.file.ClassPathArtifactSource
import com.atomist.source.{EmptyArtifactSource, SimpleFileBasedArtifactSource, StringFileArtifact}
import org.scalatest.{FlatSpec, Matchers}

class MustacheMergeToolTest extends FlatSpec with Matchers {

  import MustacheSamples._

  it should "fail with empty backing ArtifactSource" in {
    val mmt = new MustacheMergeTool(EmptyArtifactSource(""))
    an[IllegalArgumentException] should be thrownBy mmt.mergeToFile(FirstContext, "any.mustache")
  }

  it should "succeed with first template" in {
    val templateName = "first.mustache"
    val mmt = new MustacheMergeTool(new SimpleFileBasedArtifactSource("foo", StringFileArtifact(templateName, First)))
    for (i <- 1 to 3) {
      val r = mmt.mergeToFile(FirstContext, templateName).content
      r should equal(FirstExpected)
    }
  }

  val templateName = "first.mustache"
  val static1 = StringFileArtifact("static1", "test")
  val doubleDynamic = StringFileArtifact("location_was_{{in_ca}}.txt_.mustache", First)
  val straightTemplate = StringFileArtifact(templateName, First)
  val templateAs = new SimpleFileBasedArtifactSource("",
    Seq(
      straightTemplate,
      static1,
      doubleDynamic
    ))
  val cpTemplateAs = ClassPathArtifactSource.toArtifactSource("mustache/test.mustache")

  // This actually tests MergeHelper, not just MergeTool functionality
  it should "process template files" in {
    val mmt = new MustacheMergeTool(templateAs)
    val files = mmt.processTemplateFiles(FirstContext, templateAs.allFiles)
    assert(files.size === 3)
    val expectedPath = "location_was_true.txt"
    // First.mustache
    assert(files.map(f => f.path).toSet === Set(static1.path, expectedPath, "first.mustache"))
    assert(files.find(_.path.equals(expectedPath)).get.content === FirstExpected)
  }

  it should "process classpath template files" in {
    val mmt = new MustacheMergeTool(cpTemplateAs)
    val files = mmt.processTemplateFiles(FirstContext, cpTemplateAs.allFiles)
    assert(files.size === 1)
    val expectedPath = "G'day Chris. You just scored 10000 dollars. But the ATO has hit you with tax so you'll only get 6000.0"
    assert(files.map(f => f.path).toSet === Set("test.mustache"))
    assert(files.head.content === expectedPath)
  }

  it should "process template ArtifactSource" in {
    val mmt = new MustacheMergeTool(templateAs)
    val files = mmt.processTemplateFiles(FirstContext, templateAs).allFiles
    assert(files.size === 3)
    val expectedPath = "location_was_true.txt"
    // First.mustache
    assert(files.map(f => f.path).toSet === Set(static1.path, expectedPath, "first.mustache"))
    assert(files.find(_.path.equals(expectedPath)).get.content === FirstExpected)
  }

  val mt = new MustacheMergeTool(new EmptyArtifactSource(""))

  it should "strip .mustache extension" in {
    val name = "template_.mustache"
    mt.toInPlaceFilePath(name) should equal ("template")
  }

  it should "strip .scaml extension" in {
    val name = "template_.scaml"
    mt.toInPlaceFilePath(name) should equal ("template")
  }
} 
Example 125
Source File: UnaryEstimatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.base.unary

import com.salesforce.op.UID
import com.salesforce.op.features.Feature
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DoubleType, MetadataBuilder, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class UnaryEstimatorTest extends OpEstimatorSpec[Real, UnaryModel[Real, Real], UnaryEstimator[Real, Real]] {

  
  val expectedResult = Seq(0.0, 0.8, 0.4, 0.2, 1.0).map(_.toReal)

}

class MinMaxNormEstimator(uid: String = UID[MinMaxNormEstimator])
  extends UnaryEstimator[Real, Real](operationName = "minMaxNorm", uid = uid) {

  def fitFn(dataset: Dataset[Real#Value]): UnaryModel[Real, Real] = {
    val grouped = dataset.groupBy()
    val maxVal = grouped.max().first().getDouble(0)
    val minVal = grouped.min().first().getDouble(0)
    new MinMaxNormEstimatorModel(min = minVal, max = maxVal, operationName = operationName, uid = uid)
  }
}

final class MinMaxNormEstimatorModel private[op](val min: Double, val max: Double, operationName: String, uid: String)
  extends UnaryModel[Real, Real](operationName = operationName, uid = uid) {
  def transformFn: Real => Real = _.v.map(v => (v - min) / (max - min)).toReal
} 
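As a quick sanity check on the expected results above: min-max normalization maps each value to (v - min) / (max - min). With hypothetical inputs 1.0, 5.0, 3.0, 2.0, 6.0 (chosen only for illustration; the actual test data is defined in the elided spec setup) the outputs are exactly 0.0, 0.8, 0.4, 0.2, 1.0.

object MinMaxNormArithmetic extends App {
  // Hypothetical inputs, not the spec's real data.
  val xs       = Seq(1.0, 5.0, 3.0, 2.0, 6.0)
  val (mn, mx) = (xs.min, xs.max)
  val scaled   = xs.map(v => (v - mn) / (mx - mn))
  assert(scaled == Seq(0.0, 0.8, 0.4, 0.2, 1.0))
}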
Example 126
Source File: PredictionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class PredictionTest extends FlatSpec with TestCommon {
  import Prediction.Keys._

  Spec[Prediction] should "extend FeatureType" in {
    Prediction(1.0) shouldBe a[FeatureType]
    Prediction(1.0) shouldBe a[OPMap[_]]
    Prediction(1.0) shouldBe a[NumericMap]
    Prediction(1.0) shouldBe a[RealMap]
  }
  it should "error if prediction is missing" in {
    intercept[NonNullableEmptyException](new Prediction(null))
    intercept[NonNullableEmptyException](new Prediction(Map.empty))
    intercept[NonNullableEmptyException](Map.empty[String, Double].toPrediction)
    intercept[NonNullableEmptyException]((null: Map[String, Double]).toPrediction)
    assertPredictionError(new Prediction(Map("a" -> 1.0)))
    assertPredictionError(Map("a" -> 1.0, "b" -> 2.0).toPrediction)
    assertInvalidKeysError(new Prediction(Map(PredictionName -> 2.0, "a" -> 1.0)))
  }
  it should "compare values correctly" in {
    Prediction(1.0).equals(Prediction(1.0)) shouldBe true
    Prediction(1.0).equals(Prediction(0.0)) shouldBe false
    Prediction(1.0, Array(1.0), Array.empty[Double]).equals(Prediction(1.0)) shouldBe false
    Prediction(1.0, Array(1.0), Array(2.0, 3.0)).equals(Prediction(1.0, Array(1.0), Array(2.0, 3.0))) shouldBe true

    Map(PredictionName -> 5.0).toPrediction shouldBe a[Prediction]
  }
  it should "return prediction" in {
    Prediction(2.0).prediction shouldBe 2.0
  }
  it should "return raw prediction" in {
    Prediction(2.0).rawPrediction shouldBe Array()
    Prediction(1.0, Array(1.0, 2.0), Array.empty[Double]).rawPrediction shouldBe Array(1.0, 2.0)
    Prediction(1.0, (1 until 200).map(_.toDouble).toArray, Array.empty[Double]).rawPrediction shouldBe
      (1 until 200).map(_.toDouble).toArray
  }
  it should "return probability" in {
    Prediction(3.0).probability shouldBe Array()
    Prediction(1.0, Array.empty[Double], Array(1.0, 2.0)).probability shouldBe Array(1.0, 2.0)
    Prediction(1.0, Array.empty[Double], (1 until 200).map(_.toDouble).toArray).probability shouldBe
      (1 until 200).map(_.toDouble).toArray
  }
  it should "return score" in {
    Prediction(4.0).score shouldBe Array(4.0)
    Prediction(1.0, Array(2.0, 3.0), Array.empty[Double]).score shouldBe Array(1.0)
    Prediction(1.0, Array.empty[Double], Array(2.0, 3.0)).score shouldBe Array(2.0, 3.0)
  }
  it should "have a nice .toString method implementation" in {
    Prediction(4.0).toString shouldBe
      "Prediction(prediction = 4.0, rawPrediction = Array(), probability = Array())"
    Prediction(1.0, Array(2.0, 3.0), Array.empty[Double]).toString shouldBe
      "Prediction(prediction = 1.0, rawPrediction = Array(2.0, 3.0), probability = Array())"
    Prediction(1.0, Array.empty[Double], Array(2.0, 3.0)).toString shouldBe
      "Prediction(prediction = 1.0, rawPrediction = Array(), probability = Array(2.0, 3.0))"
  }

  private def assertPredictionError(f: => Unit) =
    intercept[NonNullableEmptyException](f).getMessage shouldBe
      s"Prediction cannot be empty: value map must contain '$PredictionName' key"

  private def assertInvalidKeysError(f: => Unit) =
    intercept[IllegalArgumentException](f).getMessage shouldBe
      s"requirement failed: value map must only contain valid keys: '$PredictionName' or " +
        s"starting with '$RawPredictionName' or '$ProbabilityName'"

} 
Example 127
Source File: JavaConversionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import java.util

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class JavaConversionTest extends FlatSpec with TestCommon {

  Spec[JavaConversionTest] should "convert java Map to TextMap" in {
    type T = util.HashMap[String, String]
    null.asInstanceOf[T].toTextMap shouldEqual TextMap(Map())
    val j = new T()
    j.toTextMap shouldEqual TextMap(Map())
    j.put("A", "a")
    j.put("B", null)
    j.toTextMap shouldEqual TextMap(Map("A" -> "a", "B" -> null))
  }

  it should "convert java Map to MultiPickListMap" in {
    type T = util.HashMap[String, java.util.HashSet[String]]
    null.asInstanceOf[T].toMultiPickListMap shouldEqual MultiPickListMap(Map())
    val j = new T()
    j.toMultiPickListMap shouldEqual MultiPickListMap(Map())
    val h = new util.HashSet[String]()
    h.add("X")
    h.add("Y")
    j.put("test", h)
    j.put("test2", null)
    j.toMultiPickListMap shouldEqual MultiPickListMap(Map("test" -> Set("X", "Y"), "test2" -> Set()))
  }

  it should "convert java Map to IntegralMap" in {
    type T = util.HashMap[String, java.lang.Long]
    null.asInstanceOf[T].toIntegralMap shouldEqual IntegralMap(Map())
    val j = new T()
    j.toIntegralMap shouldEqual IntegralMap(Map())
    j.put("test", java.lang.Long.valueOf(17))
    j.put("test2", null)
    j.toIntegralMap.v("test") shouldEqual 17L
    j.toIntegralMap.v("test2") shouldEqual (null: java.lang.Long)
  }

  it should "convert java Map to RealMap" in {
    type T = util.HashMap[String, java.lang.Double]
    null.asInstanceOf[T].toRealMap shouldEqual RealMap(Map())
    val j = new T()
    j.toRealMap shouldEqual RealMap(Map())
    j.put("test", java.lang.Double.valueOf(17.5))
    j.put("test2", null)
    j.toRealMap.v("test") shouldEqual 17.5
    j.toRealMap.v("test2") shouldEqual (null: java.lang.Double)
  }

  it should "convert java Map to BinaryMap" in {
    type T = util.HashMap[String, java.lang.Boolean]
    null.asInstanceOf[T].toBinaryMap shouldEqual RealMap(Map())
    val j = new T()
    j.toBinaryMap shouldEqual RealMap(Map())
    j.put("test1", java.lang.Boolean.TRUE)
    j.put("test0", java.lang.Boolean.FALSE)
    j.put("test2", null)
    j.toBinaryMap.v("test1") shouldEqual true
    j.toBinaryMap.v("test0") shouldEqual false
    j.toBinaryMap.v("test2") shouldEqual (null: java.lang.Boolean)
  }

} 
Example 128
Source File: GeolocationTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class GeolocationTest extends FlatSpec with TestCommon {
  val PaloAlto: (Double, Double) = (37.4419, -122.1430)

  Spec[Geolocation] should "extend OPList[Double]" in {
    val myGeolocation = new Geolocation(List.empty[Double])
    myGeolocation shouldBe a[FeatureType]
    myGeolocation shouldBe a[OPCollection]
    myGeolocation shouldBe a[OPList[_]]
  }

  it should "behave on missing data" in {
    val sut = new Geolocation(List.empty[Double])
    sut.lat.isNaN shouldBe true
    sut.lon.isNaN shouldBe true
    sut.accuracy shouldBe GeolocationAccuracy.Unknown
  }

  it should "not accept missing value" in {
    assertThrows[IllegalArgumentException](new Geolocation(List(PaloAlto._1)))
    assertThrows[IllegalArgumentException](new Geolocation(List(PaloAlto._1, PaloAlto._2)))
    assertThrows[IllegalArgumentException](new Geolocation((PaloAlto._1, PaloAlto._2, 123456.0)))
  }

  it should "compare values correctly" in {
    new Geolocation(List(32.399, 154.213, 6.0)).equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe true
    new Geolocation(List(12.031, -23.44, 6.0)).equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe false
    FeatureTypeDefaults.Geolocation.equals(new Geolocation(List(32.399, 154.213, 6.0))) shouldBe false
    FeatureTypeDefaults.Geolocation.equals(FeatureTypeDefaults.Geolocation) shouldBe true
    FeatureTypeDefaults.Geolocation.equals(Geolocation(List.empty[Double])) shouldBe true

    (35.123, -94.094, 5.0).toGeolocation shouldBe a[Geolocation]
  }

  it should "correctly generate a Lucene GeoPoint object" in {
    val myGeo = new Geolocation(List(32.399, 154.213, 6.0))
    myGeo.toGeoPoint shouldBe new GeoPoint(PlanetModel.WGS84, math.toRadians(myGeo.lat), math.toRadians(myGeo.lon))
  }

} 
Example 129
Source File: ListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import org.apache.lucene.spatial3d.geom.{GeoPoint, PlanetModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class ListTest extends FlatSpec with TestCommon {

  
  Spec[DateTimeList] should "extend OPList[Long]" in {
    val myDateTimeList = new DateTimeList(List.empty[Long])
    myDateTimeList shouldBe a[FeatureType]
    myDateTimeList shouldBe a[OPCollection]
    myDateTimeList shouldBe a[OPList[_]]
    myDateTimeList shouldBe a[DateList]
  }
  it should "compare values correctly" in {
    new DateTimeList(List(456L, 13L)) shouldBe new DateTimeList(List(456L, 13L))
    new DateTimeList(List(13L, 456L)) should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList should not be new DateTimeList(List(456L, 13L))
    FeatureTypeDefaults.DateTimeList shouldBe DateTimeList(List.empty[Long])

    List(12237834L, 4890489839L).toDateTimeList shouldBe a[DateTimeList]
  }


} 
Example 130
Source File: OPVectorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features.types

import com.salesforce.op.test.TestCommon
import com.salesforce.op.utils.spark.RichVector._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OPVectorTest extends FlatSpec with TestCommon {

  val vectors = Seq(
    Vectors.sparse(4, Array(0, 3), Array(1.0, 1.0)).toOPVector,
    Vectors.dense(Array(2.0, 3.0, 4.0)).toOPVector,
    // Purposely added a very large sparse vector to verify the efficiency
    Vectors.sparse(100000000, Array(1), Array(777.0)).toOPVector
  )

  Spec[OPVector] should "be empty" in {
    val zero = Vectors.zeros(0)
    new OPVector(zero).isEmpty shouldBe true
    new OPVector(zero).nonEmpty shouldBe false
    zero.toOPVector shouldBe a[OPVector]
  }

  it should "error on size mismatch" in {
    val ones = Array.fill(vectors.size)(Vectors.sparse(1, Array(0), Array(1.0)).toOPVector)
    for {
      (v1, v2) <- vectors.zip(ones)
      res <- Seq(() => v1 + v2, () => v1 - v2, () => v1 dot v2)
    } intercept[IllegalArgumentException](res()).getMessage should {
      startWith("requirement failed: Vectors must") and include("same length")
    }
  }

  it should "compare values" in {
    val zero = Vectors.zeros(0)
    new OPVector(zero) shouldBe new OPVector(zero)
    new OPVector(zero).value shouldBe zero

    Vectors.dense(Array(1.0, 2.0)).toOPVector shouldBe Vectors.dense(Array(1.0, 2.0)).toOPVector
    Vectors.sparse(5, Array(3, 4), Array(1.0, 2.0)).toOPVector shouldBe
      Vectors.sparse(5, Array(3, 4), Array(1.0, 2.0)).toOPVector
    Vectors.dense(Array(1.0, 2.0)).toOPVector should not be Vectors.dense(Array(2.0, 2.0)).toOPVector
    new OPVector(Vectors.dense(Array(1.0, 2.0))) should not be Vectors.dense(Array(2.0, 2.0)).toOPVector
    OPVector.empty shouldBe new OPVector(zero)
  }

  it should "'+' add" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 + v2) shouldBe (v1.value + v2.value).toOPVector
    }
  }

  it should "'-' subtract" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 - v2) shouldBe (v1.value - v2.value).toOPVector
    }
  }

  it should "compute dot product" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      (v1 dot v2) shouldBe (v1.value dot v2.value)
    }
  }

  it should "combine" in {
    for {(v1, v2) <- vectors.zip(vectors)} {
      v1.combine(v2) shouldBe v1.value.combine(v2.value).toOPVector
      v1.combine(v2, v2, v1) shouldBe v1.value.combine(v2.value, v2.value, v1.value).toOPVector
    }
  }

} 
Example 131
Source File: FeatureSparkTypeTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features

import com.salesforce.op.features.types.FeatureType
import com.salesforce.op.test.{TestCommon, TestSparkContext}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.{Assertion, FlatSpec}
import org.scalatest.junit.JUnitRunner

import scala.reflect.runtime.universe._

@RunWith(classOf[JUnitRunner])
class FeatureSparkTypeTest extends FlatSpec with TestCommon {
  val primitiveTypes = Seq(
    (DoubleType, weakTypeTag[types.Real], DoubleType),
    (FloatType, weakTypeTag[types.Real], DoubleType),
    (LongType, weakTypeTag[types.Integral], LongType),
    (IntegerType, weakTypeTag[types.Integral], LongType),
    (ShortType, weakTypeTag[types.Integral], LongType),
    (ByteType, weakTypeTag[types.Integral], LongType),
    (DateType, weakTypeTag[types.Date], LongType),
    (TimestampType, weakTypeTag[types.DateTime], LongType),
    (StringType, weakTypeTag[types.Text], StringType),
    (BooleanType, weakTypeTag[types.Binary], BooleanType),
    (VectorType, weakTypeTag[types.OPVector], VectorType)
  )

  val nonNullable = Seq(
    (DoubleType, weakTypeTag[types.RealNN], DoubleType),
    (FloatType, weakTypeTag[types.RealNN], DoubleType)
  )

  private def mapType(v: DataType) = MapType(StringType, v, valueContainsNull = true)
  private def arrType(v: DataType) = ArrayType(v, containsNull = true)

  val collectionTypes = Seq(
    (arrType(LongType), weakTypeTag[types.DateList], arrType(LongType)),
    (arrType(DoubleType), weakTypeTag[types.Geolocation], arrType(DoubleType)),
    (arrType(StringType), weakTypeTag[types.TextList], arrType(StringType)),
    (mapType(StringType), weakTypeTag[types.TextMap], mapType(StringType)),
    (mapType(DoubleType), weakTypeTag[types.RealMap], mapType(DoubleType)),
    (mapType(LongType), weakTypeTag[types.IntegralMap], mapType(LongType)),
    (mapType(BooleanType), weakTypeTag[types.BinaryMap], mapType(BooleanType)),
    (mapType(arrType(StringType)), weakTypeTag[types.MultiPickListMap], mapType(arrType(StringType))),
    (mapType(arrType(DoubleType)), weakTypeTag[types.GeolocationMap], mapType(arrType(DoubleType)))
  )

  Spec(FeatureSparkTypes.getClass) should "assign appropriate feature type tags for valid types and vice versa" in {
    primitiveTypes.map(scala.Function.tupled(assertTypes()))
  }

  it should "assign appropriate feature type tags for valid non-nullable types and versa" in {
    nonNullable.map(scala.Function.tupled(assertTypes(isNullable = false)))
  }

  it should "assign appropriate feature type tags for collection types and versa" in {
    collectionTypes.map(scala.Function.tupled(assertTypes()))
  }

  it should "error for unsupported types" in {
    val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(BinaryType, isNullable = false))
    error.getMessage shouldBe "Spark BinaryType is currently not supported"
  }

  it should "error for unknown types" in {
    val unknownType = NullType
    val error = intercept[IllegalArgumentException](FeatureSparkTypes.featureTypeTagOf(unknownType, isNullable = false))
    error.getMessage shouldBe s"No feature type tag mapping for Spark type $unknownType"
  }

  def assertTypes(
    isNullable: Boolean = true
  )(
    sparkType: DataType,
    featureType: WeakTypeTag[_ <: FeatureType],
    expectedSparkType: DataType
  ): Assertion = {
    FeatureSparkTypes.featureTypeTagOf(sparkType, isNullable) shouldBe featureType
    FeatureSparkTypes.sparkTypeOf(featureType) shouldBe expectedSparkType
  }

} 
Example 132
Source File: RichStructTypeTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.spark

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RichStructTypeTest extends FlatSpec with TestSparkContext {

  import com.salesforce.op.utils.spark.RichStructType._

  case class Human
  (
    name: String,
    age: Double,
    height: Double,
    heightIsNull: Double,
    isBlueEyed: Double,
    gender: Double,
    testFeatNegCor: Double
  )

  // scalastyle:off
  val humans = Seq(
    Human("alex",     32,  5.0,  0,  1,  1,  0),
    Human("alice",    32,  4.0,  1,  0,  0,  1),
    Human("bob",      32,  6.0,  1,  1,  1,  0),
    Human("charles",  32,  5.5,  0,  1,  1,  0),
    Human("diana",    32,  5.4,  1,  0,  0,  1),
    Human("max",      32,  5.4,  1,  0,  0,  1)
  )
  // scalastyle:on

  val humansDF = spark.createDataFrame(humans).select(col("*"), col("name").as("(name)_blarg_123"))
  val schema = humansDF.schema

  Spec[RichStructType] should "find schema fields by name (case insensitive)" in {
    schema.findFields("name").map(_.name) shouldBe Seq("name", "(name)_blarg_123")
    schema.findFields("blArg").map(_.name) shouldBe Seq("(name)_blarg_123")
  }

  it should "find schema fields by name (case sensitive)" in {
    schema.findFields("Name", ignoreCase = false) shouldBe Seq.empty
    schema.findFields("aGe", ignoreCase = false) shouldBe Seq.empty
    schema.findFields("age", ignoreCase = false).map(_.name) shouldBe Seq("age")
  }

  it should "fail on duplication" in {
    the[IllegalArgumentException] thrownBy schema.findField("a")
  }

  it should "throw an error if no such name" in {
    the[IllegalArgumentException] thrownBy schema.findField("???")
  }

} 
Example 133
Source File: TimeBasedAggregatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.aggregators

import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.stages.FeatureGeneratorStage
import com.salesforce.op.test.TestCommon
import org.joda.time.Duration
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimeBasedAggregatorTest extends FlatSpec with TestCommon {

  private val data = Seq(TimeBasedTest(100L, 1.0, "a", Map("a" -> "a")),
    TimeBasedTest(200L, 2.0, "b", Map("b" -> "b")),
    TimeBasedTest(300L, 3.0, "c", Map("c" -> "c")),
    TimeBasedTest(400L, 4.0, "d", Map("d" -> "d")),
    TimeBasedTest(500L, 5.0, "e", Map("e" -> "e")),
    TimeBasedTest(600L, 6.0, "f", Map("f" -> "f"))
  )

  private val timeExt = Option((d: TimeBasedTest) => d.time)

  Spec[LastAggregator[_]] should "return the most recent event" in {
    val feature = FeatureBuilder.Real[TimeBasedTest].extract(_.real.toRealNN)
      .aggregate(LastReal).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.NoCutoff())
    extracted shouldBe Real(Some(6.0))
  }

  it should "return the most recent event within the time window" in {
    val feature = FeatureBuilder.Text[TimeBasedTest].extract(_.string.toText)
      .aggregate(LastText).asResponse
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(300L),
      responseWindow = Option(new Duration(201L)))
    extracted shouldBe Text(Some("e"))
  }

  it should "return the feature type empty value when no events are passed in" in {
    val feature = FeatureBuilder.TextMap[TimeBasedTest].extract(_.map.toTextMap)
      .aggregate(LastTextMap).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(Seq(), timeExt, CutOffTime.NoCutoff())
    extracted shouldBe TextMap.empty
  }

  Spec[FirstAggregator[_]] should "return the first event" in {
    val feature = FeatureBuilder.TextAreaMap[TimeBasedTest].extract(_.map.toTextAreaMap)
      .aggregate(FirstTextAreaMap).asResponse
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(301L))
    extracted shouldBe TextAreaMap(Map("d" -> "d"))
  }

  it should "return the first event within the time window" in {
    val feature = FeatureBuilder.Currency[TimeBasedTest].extract(_.real.toCurrency)
      .aggregate(FirstCurrency).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(data, timeExt, CutOffTime.UnixEpoch(400L),
      predictorWindow = Option(new Duration(201L)))
    extracted shouldBe Currency(Some(2.0))
  }

  it should "return the feature type empty value when no events are passed in" in {
    val feature = FeatureBuilder.State[TimeBasedTest].extract(_.string.toState)
      .aggregate(FirstState).asPredictor
    val aggregator = feature.originStage.asInstanceOf[FeatureGeneratorStage[TimeBasedTest, _]].featureAggregator
    val extracted = aggregator.extract(Seq(), timeExt, CutOffTime.NoCutoff())
    extracted shouldBe State.empty
  }
}

case class TimeBasedTest(time: Long, real: Double, string: String, map: Map[String, String]) 
Example 134
Source File: RichGenericRecordTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.avro

import com.salesforce.op.test.{TestCommon, TestSparkContext}
import com.salesforce.op.utils.io.avro.AvroInOut
import org.apache.avro.generic.GenericRecord
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class RichGenericRecordTest extends FlatSpec with Matchers with TestSparkContext with TestCommon {

  import com.salesforce.op.utils.avro.RichGenericRecord._

  val dataPath = resourceFile(parent = "../test-data", name = s"PassengerData.avro").getPath
  val passengerData = AvroInOut.read[GenericRecord](dataPath).getOrElse(throw new Exception("Couldn't read data"))
  val firstRow = passengerData.sortBy(_.get("passengerId").toString.toInt).first

  Spec[RichGenericRecord] should "get value of Int" in {
    val id = firstRow.getValue[Int]("passengerId")
    id shouldBe Some(1)
  }

  it should "get value of Double" in {
    val survived = firstRow.getValue[Double]("survived")
    survived shouldBe Some(0.0)
  }

  it should "get value of Long" in {
    val height = firstRow.getValue[Long]("height")
    height shouldBe Some(168L)
  }

  it should "get value of String" in {
    val gender = firstRow.getValue[String]("gender")
    gender shouldBe Some("Female")
  }

  it should "get value of Char" in {
    val gender = firstRow.getValue[Char]("gender")
    gender shouldBe Some("Female")
  }

  it should "get value of Float" in {
    val age = firstRow.getValue[Float]("age")
    age shouldBe Some(32.0)
  }

  it should "get value of Short" in {
    val weight = firstRow.getValue[Short]("weight")
    weight shouldBe Some(67)
  }

  it should "throw error for invalid field" in {
    val error = intercept[IllegalArgumentException](firstRow.getValue[Short]("invalidField"))
    error.getMessage shouldBe "requirement failed: invalidField is not found in Avro schema!"
  }
} 
Example 135
Source File: SpecialDoubleSerializerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.json

import com.salesforce.op.test.TestCommon
import org.json4s.jackson.JsonMethods._
import org.json4s.{DefaultFormats, Extraction, Formats}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SpecialDoubleSerializerTest extends FlatSpec with TestCommon {

  val data = Map(
    "normal" -> Seq(-1.1, 0.0, 2.3),
    "infs" -> Seq(Double.NegativeInfinity, Double.PositiveInfinity),
    "minMax" -> Seq(Double.MinValue, Double.MaxValue),
    "nan" -> Seq(Double.NaN)
  )

  Spec[SpecialDoubleSerializer] should behave like
    readWriteDoubleValues(data)(
      json = """{"normal":[-1.1,0.0,2.3],"infs":["-Infinity","Infinity"],"minMax":[-1.7976931348623157E308,1.7976931348623157E308],"nan":["NaN"]}""" // scalastyle:off
    )(DefaultFormats + new SpecialDoubleSerializer)

  Spec[SpecialDoubleSerializer] + " (with big decimal)" should behave like
    readWriteDoubleValues(data)(
      json = """{"normal":[-1.1,0.0,2.3],"infs":["-Infinity","Infinity"],"minMax":[-1.7976931348623157E+308,1.7976931348623157E+308],"nan":["NaN"]}""" // scalastyle:off
    )(DefaultFormats.withBigDecimal + new SpecialDoubleSerializer)


  def readWriteDoubleValues(input: Map[String, Seq[Double]])(json: String)(implicit formats: Formats): Unit = {
    it should "write double entries" in {
      compact(Extraction.decompose(input)) shouldBe json
    }
    it should "read double entries" in {
      val parsed = parse(json).extract[Map[String, Seq[Double]]]
      parsed.keys shouldBe input.keys
      parsed zip input foreach {
        case (("nan", a), ("nan", b)) => a.foreach(_.isNaN shouldBe true)
        case ((_, a), (_, b)) => a should contain theSameElementsAs b
      }
    }
  }
} 
Example 136
Source File: RichTupleTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.tuples

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.utils.tuples.RichTuple._

@RunWith(classOf[JUnitRunner])
class RichTupleTest extends FlatSpec with TestCommon {
  Spec(RichTuple.getClass) should "map a function to provided elements in tuples" in {
    val res = (Some(1), Some(2)).map((x, y) => x + y)
    res.get shouldBe 3
  }

  it should "map on empty tuples" in {
    val none: (Option[String], Option[String]) = None -> None
    none.map((x, y) => x + y) shouldBe None
  }

  it should "map the function with no effect for left param alone" in {
    val res = (Some(1), None).map((x, y) => x + y)
    res shouldBe Some(1)
  }

  it should "map the function with no effect for right param alone" in {
    val res = (None, Some(1)).map((x, y) => x + y)
    res shouldBe Some(1)
  }
} 
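The four cases above pin down the semantics completely; here is a plain-Scala sketch consistent with them (an assumption, not the library's actual RichTuple).

object RichTupleSketch {
  // Sketch only: apply f when both sides are defined, keep the single defined
  // side otherwise, and return None when both sides are empty.
  implicit class RichOptionPair[A](t: (Option[A], Option[A])) {
    def map(f: (A, A) => A): Option[A] = t match {
      case (Some(a), Some(b)) => Some(f(a, b))
      case (Some(a), None)    => Some(a)
      case (None, Some(b))    => Some(b)
      case (None, None)       => None
    }
  }
}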
Example 137
Source File: SequenceAggregatorsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.spark

import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.SequenceAggregators.{SeqMapDouble, SeqMapLong}
import org.apache.spark.sql.Row
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SequenceAggregatorsTest extends FlatSpec with TestSparkContext {
  import spark.implicits._
  val meanSeqMapDouble = SequenceAggregators.MeanSeqMapDouble(2)
  val modeSeqMapLong = SequenceAggregators.ModeSeqMapLong(2)
  val modeSeqNullInt = SequenceAggregators.ModeSeqNullInt(2)
  implicit val encMeanMapAgg = meanSeqMapDouble.outputEncoder
  implicit val encModeMapAgg = modeSeqMapLong.outputEncoder

  Spec(SequenceAggregators.getClass) should "correctly compute the mean by key of maps" in {
    val data = Seq(
      (Map("a" -> 1.0, "b" -> 5.0), Map("z" -> 10.0)),
      (Map("c" -> 11.0), Map("y" -> 3.0, "x" -> 0.0)),
      (Map.empty[String, Double], Map.empty[String, Double])
    ).toDF("f1", "f2").map(Helper.toSeqMapDouble)

    val res = data.select(meanSeqMapDouble.toColumn).first()
    res shouldBe Seq(Map("a" -> 1.0, "c" -> 11.0, "b" -> 5.0), Map("z" -> 10.0, "y" -> 3.0, "x" -> 0.0))
  }

  it should "correctly compute the mean by key of maps again" in {
    val meanData = Seq(
      (Map("a" -> 1.0, "b" -> 5.0), Map("y" -> 4.0, "x" -> 0.0, "z" -> 10.0)),
      (Map("a" -> -3.0, "b" -> 3.0, "c" -> 11.0), Map("y" -> 3.0, "x" -> 0.0)),
      (Map("a" -> 1.0, "b" -> 5.0), Map("y" -> 1.0, "x" -> 0.0, "z" -> 5.0))
    ).toDF("f1", "f2").map(Helper.toSeqMapDouble)

    val res = meanData.select(meanSeqMapDouble.toColumn).first()
    res shouldBe Seq(Map("a" -> -1.0 / 3, "c" -> 11.0, "b" -> 13.0 / 3), Map("z" -> 7.5, "y" -> 8.0 / 3, "x" -> 0.0))
  }

  it should "correctly compute the mode by key of maps" in {
    val data = Seq(
      (Map("a" -> 1L, "b" -> 5L), Map("z" -> 10L)),
      (Map("c" -> 11L), Map("y" -> 3L, "x" -> 0L)),
      (Map.empty[String, Long], Map.empty[String, Long])
    ).toDF("f1", "f2").map(Helper.toSeqMapLong)

    val res = data.select(modeSeqMapLong.toColumn).first()
    res shouldBe Seq(Map("a" -> 1L, "b" -> 5L, "c" -> 11L), Map("x" -> 0L, "y" -> 3L, "z" -> 10L))
  }

  it should "correctly compute the mode by key of maps again" in {
    val modeData = Seq(
      (Map("a" -> 1L, "b" -> 5L), Map("y" -> 4L, "x" -> 0L, "z" -> 10L)),
      (Map("a" -> -3L, "b" -> 3L, "c" -> 11L), Map("y" -> 3L, "x" -> 0L)),
      (Map("a" -> 1L, "b" -> 5L), Map("y" -> 1L, "x" -> 0L, "z" -> 5L))
    ).toDF("f1", "f2").map(Helper.toSeqMapLong)

    val res = modeData.select(modeSeqMapLong.toColumn).first()
    res shouldBe Seq(Map("a" -> 1L, "b" -> 5L, "c" -> 11L), Map("x" -> 0L, "y" -> 1L, "z" -> 5L))
  }

  it should "correctly compute the mode" in {
    val data = Seq(
      (Some(3L), None),
      (Some(3L), Some(2L)),
      (Some(1L), Some(5L))
    ).toDF("f1", "f2").map(r => Seq(if (r.isNullAt(0)) None else Option(r.getLong(0)),
      if (r.isNullAt(1)) None else Option(r.getLong(1))))

    val res = data.select(modeSeqNullInt.toColumn).first()
    res shouldBe Seq(3L, 2L)
  }
}

private object Helper {
  def toSeqMapDouble(r: Row): SeqMapDouble = Seq(r.getMap[String, Double](0).toMap, r.getMap[String, Double](1).toMap)
  def toSeqMapLong(r: Row): SeqMapLong = Seq(r.getMap[String, Long](0).toMap, r.getMap[String, Long](1).toMap)
} 
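For reference, the mean-by-key result asserted above can be reproduced in plain Scala. This is a sketch of the expected semantics only, not how SequenceAggregators is implemented:

// Mean of each key's values across a sequence of maps; keys missing from a row are simply skipped.
def meanByKey(maps: Seq[Map[String, Double]]): Map[String, Double] =
  maps.flatten
    .groupBy { case (k, _) => k }
    .map { case (k, kvs) => k -> kvs.map(_._2).sum / kvs.size }

meanByKey(Seq(Map("a" -> 1.0), Map("a" -> -3.0), Map("a" -> 1.0))) // Map("a" -> -1.0 / 3)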
Example 138
Source File: TextUtilsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.utils.text

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TextUtilsTest extends FlatSpec with TestCommon {
  Spec(TextUtils.getClass) should "concat strings" in {
    TextUtils.concat("Left", "Right", ",") shouldBe "Left,Right"
  }

  it should "concat with no effect for right half alone" in {
    TextUtils.concat("", "Right", ",") shouldBe "Right"
  }

  it should "concat with no effect for left half alone" in {
    TextUtils.concat("Left", "", ",") shouldBe "Left"
  }

  it should "concat empty strings" in {
    TextUtils.concat("", "", ",") shouldBe ""
  }

  it should "clean a string with special chars" in {
    TextUtils.cleanString("A string wit#h %bad pun&ctuation mark<=>s") shouldBe "AStringWitHBadPunCtuationMarkS"
  }

  it should "clean an Option(string) with special chars" in {
    val testString: Option[String] = Some("A string wit#h %bad pun&ctuation mark<=>s")
    TextUtils.cleanOptString(testString) shouldBe Some("AStringWitHBadPunCtuationMarkS")
  }
} 
Example 139
Source File: DateTimeUtilsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.utils.date

import com.salesforce.op.test.TestCommon
import org.joda.time.{DateTime, DateTimeZone}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class DateTimeUtilsTest extends FlatSpec with TestCommon {

  val dateStr = "2017-03-29T14:00:07.000Z"
  val date = DateTime.parse(dateStr)
  val now = DateTime.now(DateTimeZone.UTC)

  Spec(DateTimeUtils.getClass) should "parse date in Iso format" in {
    DateTimeUtils.parse(dateStr) shouldBe date.getMillis
  }

  it should "parse date in yyyy-MM-dd HH:mm:ss.SSS format" in {
    val formattedStr = "2017-03-29 14:00:07.000"
    DateTimeUtils.parse(formattedStr) shouldBe date.getMillis
  }

  it should "parse Unix" in {
    DateTimeUtils.parseUnix(now.getMillis) shouldBe now.toString("YYYY/MM/dd")
  }

  it should "get range between two dates" in {
    val numberOfDays = 500
    val diff = DateTimeUtils.getRange(
      date.minusDays(numberOfDays).toString("YYYY/MM/dd"),
      date.toString("YYYY/MM/dd")
    )
    diff.length shouldBe numberOfDays + 1
  }

  it should "get date difference days from start date" in {
    val datePlusDays = DateTimeUtils.getDatePlusDays(now.toString("YYYY/MM/dd"), 31)
    datePlusDays shouldBe now.plusDays(31).toString("YYYY/MM/dd")
  }
} 
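As a sanity check, the two input formats in the tests above denote the same instant. The sketch below uses plain joda-time and assumes zone-less input is treated as UTC, which is what the test appears to rely on:

import org.joda.time.DateTime
import org.joda.time.format.DateTimeFormat

val isoMillis = DateTime.parse("2017-03-29T14:00:07.000Z").getMillis
val patternMillis = DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSS")
  .withZoneUTC() // assumption: zone-less input is interpreted as UTC
  .parseDateTime("2017-03-29 14:00:07.000").getMillis
// isoMillis == patternMillis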
Example 140
Source File: AvroInOutTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.utils.io.avro

import java.io.{File, FileNotFoundException, FileWriter}
import java.nio.file.Paths

import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.io.avro.AvroInOut._
import org.apache.avro.generic.GenericRecord
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.rdd.RDD
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class AvroInOutTest extends FlatSpec with TestSparkContext {
  val avroSchemaPath = s"$testDataDir/PassengerDataAll.avsc"
  val avroFilePath = s"$testDataDir/PassengerDataAll.avro"
  val avroFileRecordCount = 891
  val hdfs: FileSystem = FileSystem.get(sc.hadoopConfiguration)
  lazy val avroTemp: String = tempDir + "/avro-inout-test"

  Spec(AvroInOut.getClass) should "creates RDD from an avro file" in {
    val res = readPathSeq(avroFilePath, withCount = true, deepCopy = true, persist = false)
    res shouldBe a[RDD[_]]
    res.count shouldBe avroFileRecordCount
  }

  it should "creates RDD from a sequence of avro files" in {
    val res = readPathSeq(s"$avroFilePath,$avroFilePath")
    res.count shouldBe avroFileRecordCount*2
  }

  it should "create RDD from a mixed sequence of valid and invalid avro files" in {
    val res = readPathSeq(s"badfile/path1,$avroFilePath,badfile/path2,$avroFilePath,badfile/path3")
    res.count shouldBe avroFileRecordCount*2
  }

  it should "throw an error if passed in avro files are invalid" in {
    val error = intercept[IllegalArgumentException](readPathSeq("badfile/path1,badfile/path2"))
    error.getMessage shouldBe "No valid directory found in path 'badfile/path1,badfile/path2'"
  }

  it should "creates Some(RDD) from an avro file" in {
    val res = read(avroFilePath)
    res.size shouldBe 1
    res.get shouldBe an[RDD[_]]
    res.get.count shouldBe avroFileRecordCount
  }

  it should "create None from an invalid avro file" in {
    val res = read("badfile/path")
    res shouldBe None
  }

  Spec[AvroWriter[_]] should "writeAvro to filesystem" in {
    val avroData = readPathSeq(avroFilePath).asInstanceOf[RDD[GenericRecord]]
    val avroSchema = loadFile(avroSchemaPath)

    val error = intercept[FileNotFoundException](hdfs.listStatus(new Path(avroTemp)))
    error.getMessage shouldBe s"File $avroTemp does not exist"

    AvroWriter(avroData).writeAvro(avroTemp, avroSchema)
    val hdfsFiles = hdfs.listStatus(new Path(avroTemp)) filter (x => x.getPath.getName.contains("part"))
    val res = readPathSeq((for { x <- hdfsFiles } yield avroTemp + "/" + x.getPath.getName).mkString(","))
    res.count shouldBe avroFileRecordCount
  }

  it should "checkPathsExist" in {
    val tmpDir = Paths.get(File.separator, "tmp").toFile
    val f1 = new File(tmpDir, "avroinouttest")
    f1.delete()
    val w = new FileWriter(f1)
    w.write("just checking")
    w.close()
    val f2 = new File(tmpDir, "thisfilecannotexist")
    f2.delete()
    val f3 = new File(tmpDir, "this file cannot exist")
    f3.delete()
    assume(f1.exists && !f2.exists && !f3.exists)

    // check for one dir being invalid in the path amongst two
    selectExistingPaths(s"$f1,$f2") shouldBe f1.toString

    // check if all dirs in the path are invalid then we get an exception
    intercept[IllegalArgumentException] { selectExistingPaths(f2.toString) }

    // also, check if all dirs in the path are invalid ( in a different way ) then we get an exception
    intercept[IllegalArgumentException] { selectExistingPaths(f3.toString) }

    // check for one dir being invalid ( in a different way ) in the path amongst the two dirs in it
    selectExistingPaths(s"$f1,$f3") shouldBe f1.toString

    // check for paths order insensitivity
    selectExistingPaths(s"$f3,$f1") shouldBe f1.toString

    // check for an exception if the path is an empty string
    intercept[IllegalArgumentException] { selectExistingPaths("") }
  }

} 
Example 141
Source File: CSVToAvroTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.utils.io.csv

import com.salesforce.op.test.{Passenger, TestSparkContext}
import org.apache.spark.SparkException
import org.apache.spark.rdd.RDD
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CSVToAvroTest extends FlatSpec with TestSparkContext {
  val avroSchema: String = loadFile(s"$resourceDir/PassengerSchemaModifiedDataTypes.avsc")
  val csvReader: CSVInOut = new CSVInOut(CSVOptions(header = true))
  lazy val csvRDD: RDD[Seq[String]] = csvReader.readRDD(s"$resourceDir/PassengerDataModifiedDataTypes.csv")
  lazy val csvFileRecordCount: Long = csvRDD.count

  Spec(CSVToAvro.getClass) should "convert RDD[Seq[String]] to RDD[GenericRecord]" in {
    val res = CSVToAvro.toAvro(csvRDD, avroSchema)
    res shouldBe a[RDD[_]]
    res.count shouldBe csvFileRecordCount
  }

  it should "convert RDD[Seq[String]] to RDD[T]" in {
    val res = CSVToAvro.toAvroTyped[Passenger](csvRDD, avroSchema)
    res shouldBe a[RDD[_]]
    res.count shouldBe csvFileRecordCount
  }

  it should "throw an error for nested schema" in {
    val invalidAvroSchema = loadFile(s"$resourceDir/PassengerSchemaNestedTypeCSV.avsc")
    val exceptionMsg = "CSV should be a flat file and not have nested records (unsupported column(Sex schemaType=ENUM)"
    val error = intercept[SparkException](CSVToAvro.toAvro(csvRDD, invalidAvroSchema).count())
    error.getCause.getMessage shouldBe exceptionMsg
  }

  it should "throw an error for mis-matching schema fields" in {
    val invalidAvroSchema = loadFile(s"$resourceDir/PassengerSchemaInvalidField.avsc")
    val error = intercept[SparkException](CSVToAvro.toAvro(csvRDD, invalidAvroSchema).count())
    error.getCause.getMessage shouldBe "Mismatch number of fields in csv record and avro schema"
  }

  it should "throw an error for bad data" in {
    val invalidDataRDD = csvReader.readRDD(s"$resourceDir/PassengerDataContentTypeMisMatch.csv")
    val error = intercept[SparkException](CSVToAvro.toAvro(invalidDataRDD, avroSchema).count())
    error.getCause.getMessage shouldBe "Boolean column not actually a boolean. Invalid value: 'fail'"
  }
} 
Example 142
Source File: CSVInOutTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.utils.io.csv

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, DataFrame}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CSVInOutTest extends FlatSpec with TestSparkContext {
  private val csvReader = new CSVInOut(CSVOptions(header = true))
  private val csvFile = s"$testDataDir/PassengerDataAllWithHeader.csv"

  Spec[CSVInOut] should "throw error for bad file paths with DataFrame" in {
    val error = intercept[AnalysisException](csvReader.readDataFrame("/bad/file/path/read/dataframe"))
    error.getMessage should endWith ("Path does not exist: file:/bad/file/path/read/dataframe;")
  }

  it should "throw error for bad file paths with RDD" in {
    val error = intercept[AnalysisException](csvReader.readRDD("/bad/file/path/read/rdd"))
    error.getMessage should endWith ("Path does not exist: file:/bad/file/path/read/rdd;")
  }

  it should "read a CSV file to DataFrame" in {
    val res = csvReader.readDataFrame(csvFile)
    res shouldBe a[DataFrame]
    res.count shouldBe 891
  }

  it should "read a CSV file to RDD" in {
    val res = csvReader.readRDD(csvFile)
    res shouldBe a[RDD[_]]
    res.count shouldBe 891
  }
} 
Example 143
Source File: UIDTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class UIDTest extends FlatSpec with TestCommon {

  Spec(UID.getClass) should "generate UIDs" in {
    (1 to 100000).map(_ => UID[UIDTest]).toSet.size shouldBe 100000
  }

  it should "allow counting UIDs" in {
    val start = UID.count()
    (1 to 100).foreach(_ => UID[UIDTest])
    val end = UID.count()
    end - start shouldBe 100
  }

  it should "allow reset UIDs to a specific count" in {
    val count = UID.count()
    val first = (1 to 100).map(_ => UID[UIDTest])
    UID.reset(count)
    val second = (1 to 100).map(_ => UID[UIDTest])
    first should contain theSameElementsAs second
    UID.reset()[UIDTest] shouldBe "UIDTest_000000000001"
  }

  it should "allow reset UIDs" in {
    UID.reset()
    val first = (1 to 100).map(_ => UID[UIDTest])
    UID.reset()
    val second = (1 to 100).map(_ => UID[UIDTest])
    first should contain theSameElementsAs second
  }

  it should "parse from string" in {
    UID.reset().fromString(UID[UIDTest]) shouldBe ("UIDTest", "000000000001")
  }

  it should "error on invalid string" in {
    intercept[IllegalArgumentException](UID.fromString("foo")).getMessage shouldBe "Invalid UID: foo"
  }
} 
Example 144
Source File: FeatureHistoryTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op

import com.salesforce.op.FeatureHistory.{OriginFeatureKey, StagesKey}
import com.salesforce.op.test.TestCommon
import org.apache.spark.sql.types.MetadataBuilder
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class FeatureHistoryTest extends FlatSpec with TestCommon {

  val feature1 = "feature1"
  val feature2 = "feature2"
  val stage1 = "stage1"
  val stage2 = "stage2"

  Spec[FeatureHistory] should "convert a feature history to metadata" in {
    val featureHistory = FeatureHistory(originFeatures = Seq(feature1, feature2), stages = Seq(stage1, stage2))

    val featureHistoryMetadata = featureHistory.toMetadata

    featureHistoryMetadata.contains(OriginFeatureKey) shouldBe true
    featureHistoryMetadata.contains(StagesKey) shouldBe true

    featureHistoryMetadata.getStringArray(OriginFeatureKey) shouldBe Array(feature1, feature2)
    featureHistoryMetadata.getStringArray(StagesKey) shouldBe Array(stage1, stage2)
  }

  it should "merge two instances" in {
    val featureHistory1 = FeatureHistory(originFeatures = Seq(feature1), stages = Seq(stage1))
    val featureHistory2 = FeatureHistory(originFeatures = Seq(feature2), stages = Seq(stage2))

    val featureHistoryCombined = featureHistory1.merge(featureHistory2)
    featureHistoryCombined.originFeatures shouldBe Seq(feature1, feature2)
    featureHistoryCombined.stages shouldBe Seq(stage1, stage2)
  }

  it should "create a metadata for a map" in {
    val featureHistory1 = FeatureHistory(originFeatures = Seq(feature1), stages = Seq(stage1))
    val featureHistory2 = FeatureHistory(originFeatures = Seq(feature2), stages = Seq(stage2))

    val map = Map(("1" -> featureHistory1), ("2" -> featureHistory2))
    val featureHistoryMetadata = FeatureHistory.toMetadata(map)

    featureHistoryMetadata.contains("1") shouldBe true
    featureHistoryMetadata.contains("2") shouldBe true

    val f1 = featureHistoryMetadata.getMetadata("1")

    f1.contains(OriginFeatureKey) shouldBe true
    f1.contains(StagesKey) shouldBe true

    f1.getStringArray(OriginFeatureKey) shouldBe Array(feature1)
    f1.getStringArray(StagesKey) shouldBe Array(stage1)

    val f2 = featureHistoryMetadata.getMetadata("2")

    f2.contains(OriginFeatureKey) shouldBe true
    f2.contains(StagesKey) shouldBe true

    f2.getStringArray(OriginFeatureKey) shouldBe Array(feature2)
    f2.getStringArray(StagesKey) shouldBe Array(stage2)
  }

  it should "create a map from metadata" in {

    val featureHistory1 = FeatureHistory(originFeatures = Seq(feature1), stages = Seq(stage1))
    val featureHistory2 = FeatureHistory(originFeatures = Seq(feature2), stages = Seq(stage2))

    val featureHistoryMapMetadata = new MetadataBuilder()
      .putMetadata("1", featureHistory1.toMetadata)
      .putMetadata("2", featureHistory2.toMetadata)
      .build()

    val featureHistoryMap = FeatureHistory.fromMetadataMap(featureHistoryMapMetadata)

    featureHistoryMap.contains("1") shouldBe true
    featureHistoryMap("1") shouldBe featureHistory1

    featureHistoryMap.contains("2") shouldBe true
    featureHistoryMap("2") shouldBe featureHistory2
  }
} 
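The tests above lean on Spark's Metadata round-trip. Here is a self-contained sketch of that mechanism; the keys are illustrative, not the library's actual constants:

import org.apache.spark.sql.types.MetadataBuilder

val md = new MetadataBuilder()
  .putStringArray("originFeatures", Array("feature1", "feature2"))
  .putStringArray("stages", Array("stage1", "stage2"))
  .build()

md.contains("stages")                      // true
md.getStringArray("originFeatures").toSeq  // Seq("feature1", "feature2")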
Example 145
Source File: JoinedReadersTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.readers

import com.salesforce.op.aggregators.CutOffTime
import com.salesforce.op.test._
import org.joda.time.{DateTimeConstants, Duration}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class JoinedReadersTest extends FlatSpec with PassengerSparkFixtureTest {

  val sparkReader = DataReaders.Aggregate.csv[SparkExample](
    path = Some("../test-data/SparkExample.csv"),
    schema = SparkExample.getClassSchema.toString,
    key = _.getLabel.toString,
    aggregateParams = AggregateParams(None, CutOffTime.NoCutoff())
  )

  val passengerReader = DataReaders.Conditional.avro[Passenger](
    path = Some(passengerAvroPath), // Path is optional, so it can also be passed in as a parameter
    key = _.getPassengerId.toString, // Entity to score
    conditionalParams = ConditionalParams(
      timeStampFn = _.getRecordDate.toLong, // Record field which defines the date for the rest of the columns
      targetCondition = _.getBoarded >= 1471046600, // Function to figure out if target event has occurred
      responseWindow = None, // How many days after target event to include in response aggregation
      predictorWindow = None, // How many days before target event to include in predictor aggregation
      timeStampToKeep = TimeStampToKeep.Min
    )
  )

  Spec[JoinedReader[_, _]] should "take any kind of reader as the leftmost input" in {
    profileReader.innerJoin(sparkReader) shouldBe a[JoinedDataReader[_, _]]
    dataReader.outerJoin(sparkReader) shouldBe a[JoinedDataReader[_, _]]
    passengerReader.leftOuterJoin(sparkReader) shouldBe a[JoinedDataReader[_, _]]

  }

  it should "allow simple readers for right inputs" in {
    sparkReader.innerJoin(profileReader).joinType shouldBe JoinTypes.Inner
    sparkReader.outerJoin(profileReader).joinType shouldBe JoinTypes.Outer
    sparkReader.leftOuterJoin(profileReader).joinType shouldBe JoinTypes.LeftOuter
  }

  it should "have all subreaders correctly ordered" in {
    val joinedReader = profileReader.innerJoin(sparkReader).outerJoin(dataReader)
    joinedReader.subReaders should contain theSameElementsAs Seq(profileReader, sparkReader, dataReader)
  }

  it should "correctly set leftKey in left outer and inner joins" in {
    dataReader.leftOuterJoin(sparkReader, joinKeys = JoinKeys(leftKey = "id")).joinKeys.leftKey shouldBe "id"
    dataReader.innerJoin(sparkReader, joinKeys = JoinKeys(leftKey = "id")).joinKeys.leftKey shouldBe "id"
  }

  it should "throw an error if you try to perform a self join" in {
    a[IllegalArgumentException] should be thrownBy {
      dataReader.innerJoin(dataReader)
    }
  }

  it should "throw an error if you try to use the same reader twice" in {
    a[IllegalArgumentException] should be thrownBy {
      dataReader.innerJoin(sparkReader).innerJoin(dataReader)
    }
  }

  it should "throw an error if you try to read the same data type twice with different readers" in {
    a[IllegalArgumentException] should be thrownBy {
      passengerReader.innerJoin(sparkReader).outerJoin(dataReader)
    }
  }

  it should "throw an error if you try to use an invalid key combination" in {
    a[RuntimeException] should be thrownBy {
      dataReader.innerJoin(sparkReader, joinKeys = JoinKeys(resultKey = DataFrameFieldNames.KeyFieldName))
        .generateDataFrame(Array.empty)
    }
  }

  it should "produce a JoinedAggregateDataReader when withSecondaryAggregation is called" in {
    val joinedReader = profileReader.innerJoin(sparkReader)
    val timeFilter = TimeBasedFilter(
      condition = new TimeColumn(boardedTime),
      primary = new TimeColumn(boardedTime),
      timeWindow = Duration.standardDays(DateTimeConstants.DAYS_PER_WEEK)
    )
    joinedReader.withSecondaryAggregation(timeFilter) shouldBe a[JoinedAggregateDataReader[_, _]]
  }

} 
Example 146
Source File: ParquetProductReaderTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.readers

import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestCommon, TestSparkContext}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

// Need this case class to be external to (not nested in) ParquetProductReaderTest for spark sql to work correctly.
// Fields in the case class are case-sensitive and should exactly match the parquet column names.
case class PassengerType
(
  PassengerId: Int,
  Survived: Int,
  Pclass: Option[Int],
  Name: Option[String],
  Sex: String,
  Age: Option[Double],
  SibSp: Option[Int],
  Parch: Option[Int],
  Ticket: String,
  Fare: Double,
  Cabin: Option[String],
  Embarked: Option[String]
)

@RunWith(classOf[JUnitRunner])
class ParquetProductReaderTest extends FlatSpec with TestSparkContext with TestCommon {
  def passengerFilePath: String = s"$testDataDir/PassengerDataAll.parquet"

  val parquetRecordCount = 891

  import spark.implicits._
  val dataReader = new ParquetProductReader[PassengerType](
    readPath = Some(passengerFilePath),
    key = _.PassengerId.toString
  )

  Spec[ParquetProductReader[_]] should "read in data correctly" in {
    val data = dataReader.readDataset().collect()
    data.foreach(_ shouldBe a[PassengerType])
    data.length shouldBe parquetRecordCount
  }

  it should "read in byte arrays as valid strings" in {
    val caseReader = DataReaders.Simple.parquetCase[PassengerType](
      path = Some(passengerFilePath),
      key = _.PassengerId.toString
    )

    val records = caseReader.readDataset().collect()
    records.collect { case r if r.PassengerId == 1 => r.Ticket } shouldBe Array("A/5 21171")
  }

  it should "map the columns of data to types defined in schema" in {
    val caseReader = DataReaders.Simple.parquetCase[PassengerType](
      path = Some(passengerFilePath),
      key = _.PassengerId.toString
    )

    val records = caseReader.readDataset().collect()
    records(0).Survived shouldBe a[java.lang.Integer]
    records(0).Fare shouldBe a[java.lang.Double]
    records(0).Ticket shouldBe a[java.lang.String]
    records.collect { case r if r.PassengerId == 1 => r.Age } shouldBe Array(Some(22.0))
  }

  it should "generate a dataframe" in {
    val tokens = FeatureBuilder.TextList[PassengerType]
      .extract(p => p.Name.map(_.split(" ")).toSeq.flatten.toTextList)
      .asPredictor
    val data = dataReader.generateDataFrame(rawFeatures = Array(tokens)).collect()

    data.collect { case r if r.get(0) == "3" => r.get(1) } shouldBe Array(Array("Heikkinen,", "Miss.", "Laina"))
    data.length shouldBe parquetRecordCount
  }
} 
Example 147
Source File: StreamingReadersTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.readers

import org.apache.hadoop.fs.Path
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.test.TestCommon

@RunWith(classOf[JUnitRunner])
class StreamingReadersTest extends FlatSpec with TestCommon {
  val dotPath = new Path(".part.avro")
  val underScorePath = new Path("_part.avro")
  val goodPath = new Path("part.avro")

  val readerCls = StreamingReaders.getClass

  Spec[StreamingReaders.type] should "ignore hidden files" in {
    assert(!StreamingReaders.Simple.defaultPathFiler(dotPath))
    assert(!StreamingReaders.Simple.defaultPathFiler(underScorePath))
  }

  it should "accept non-hidden files" in {
    assert(StreamingReaders.Simple.defaultPathFiler(goodPath))
  }
} 
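The "hidden" convention tested here is the usual Hadoop one: paths whose final component starts with '.' or '_' are skipped. A minimal sketch of such a filter, as an illustration rather than the library's actual defaultPathFiler:

import org.apache.hadoop.fs.Path

def isVisible(p: Path): Boolean = {
  val name = p.getName // last path component
  !name.startsWith(".") && !name.startsWith("_")
}

isVisible(new Path("part.avro"))  // true
isVisible(new Path("_part.avro")) // false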
Example 148
Source File: CSVProductReadersTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.readers

import com.salesforce.op.features.FeatureBuilder
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestCommon, TestSparkContext}
import com.salesforce.op.utils.io.csv.CSVOptions
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


// need this to be external to (not nested in) CSVProductReaderTest for spark sql to work correctly
case class PassengerCaseClass
(
  passengerId: Int,
  age: Option[Int],
  gender: Option[String],
  height: Option[Int],
  weight: Option[scala.math.BigInt],
  description: Option[String],
  boarded: Option[Long],
  recordDate: Option[Long],
  survived: Option[Boolean],
  randomTime: Option[java.sql.Timestamp],
  randomFloating: Option[Double]
)

@RunWith(classOf[JUnitRunner])
class CSVProductReadersTest extends FlatSpec with TestSparkContext with TestCommon {
  def csvWithoutHeaderPath: String = s"$testDataDir/BigPassenger.csv"

  def csvWithHeaderPath: String = s"$testDataDir/BigPassengerWithHeader.csv"

  import spark.implicits._

  Spec[CSVProductReader[_]] should "read in data correctly with header" in {
    val dataReader = new CSVProductReader[PassengerCaseClass](
      readPath = Some(csvWithHeaderPath),
      key = _.passengerId.toString,
      options = CSVOptions(header = true)
    )
    val data = dataReader.readDataset().collect()
    data.foreach(_ shouldBe a[PassengerCaseClass])
    data.length shouldBe 8
  }

  it should "read in data correctly without header" in {
    val dataReader = DataReaders.Simple.csvCase[PassengerCaseClass](
      path = Some(csvWithoutHeaderPath),
      key = _.passengerId.toString
    )
    val data = dataReader.readDataset().collect()
    data.foreach(_ shouldBe a[PassengerCaseClass])
    data.length shouldBe 8
  }

  it should "generate a dataframe" in {
    val dataReader = new CSVProductReader[PassengerCaseClass](
      readPath = Some(csvWithHeaderPath),
      key = _.passengerId.toString,
      options = CSVOptions(header = true)
    )
    val tokens =
      FeatureBuilder.TextList[PassengerCaseClass]
        .extract(p => p.description.map(_.split(" ")).toSeq.flatten.toTextList).asPredictor
    val data = dataReader.generateDataFrame(rawFeatures = Array(tokens)).collect()
    data.collect { case r if r.get(0) == "3" => r.get(1) } shouldBe Array(Array("this", "is", "a", "description"))
    data.length shouldBe 8
  }
} 
Example 149
Source File: CSVAutoReadersTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.readers

import com.salesforce.op.test.PassengerSparkFixtureTest
import org.apache.avro.Schema
import org.apache.avro.generic.GenericRecord
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.collection.JavaConverters._


@RunWith(classOf[JUnitRunner])
class CSVAutoReadersTest extends FlatSpec with PassengerSparkFixtureTest {

  private val expectedSchema = new Schema.Parser().parse(resourceFile(name = "PassengerAuto.avsc"))
  private val allFields = expectedSchema.getFields.asScala.map(_.name())
  private val keyField: String = allFields.head

  Spec[CSVAutoReader[_]] should "read in data correctly and infer schema" in {
    val dataReader = DataReaders.Simple.csvAuto[GenericRecord](
      path = Some(passengerCsvWithHeaderPath),
      key = _.get(keyField).toString
    )
    val data = dataReader.readRDD().collect()
    data.foreach(_ shouldBe a[GenericRecord])
    data.length shouldBe 8

    val inferredSchema = data.head.getSchema
    inferredSchema shouldBe expectedSchema
  }

  it should "read in data correctly and infer schema based with headers provided" in {
    val dataReader = DataReaders.Simple.csvAuto[GenericRecord](
      path = Some(passengerCsvPath),
      key = _.get(keyField).toString,
      headers = allFields
    )
    val data = dataReader.readRDD().collect()
    data.foreach(_ shouldBe a[GenericRecord])
    data.length shouldBe 8

    val inferredSchema = data.head.getSchema
    inferredSchema shouldBe expectedSchema

  }

} 
Example 150
Source File: AvroFieldTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.cli.gen

import com.salesforce.op.cli.gen.AvroField._
import com.salesforce.op.test.TestCommon
import org.apache.avro.Schema
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.collection.JavaConverters._
import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class AvroFieldTest extends FlatSpec with TestCommon with Assertions {

  Spec[AvroField] should "do from" in {
    val types = List(
      Schema.Type.STRING,
      //  Schema.Type.BYTES, // somehow this avro type is not covered (yet)
      Schema.Type.INT,
      Schema.Type.LONG,
      Schema.Type.FLOAT,
      Schema.Type.DOUBLE,
      Schema.Type.BOOLEAN
    )
    val simpleSchemas = types map Schema.create

    val unions = List(
      Schema.createUnion((Schema.Type.NULL::Schema.Type.INT::Nil) map Schema.create asJava),
      Schema.createUnion((Schema.Type.INT::Schema.Type.NULL::Nil) map Schema.create asJava)
    )

    val enum = Schema.createEnum("Aliens", "undocumented", "outer",
      List("Edgar_the_Bug", "Boris_the_Animal", "Laura_Vasquez") asJava)

    val allSchemas = (enum::unions)++simpleSchemas // NULL does not work

    val fields = allSchemas.zipWithIndex map {
      case (s, i) => new Schema.Field("x" + i, s, "Who", null: Object)
    }

    val expected = List(
      AEnum(fields(0), isNullable = false),
      AInt(fields(1), isNullable = true),
      AInt(fields(2), isNullable = true),
      AString(fields(3), isNullable = false),
      AInt(fields(4), isNullable = false),
      ALong(fields(5), isNullable = false),
      AFloat(fields(6), isNullable = false),
      ADouble(fields(7), isNullable = false),
      ABoolean(fields(8), isNullable = false)
    )

    an[IllegalArgumentException] should be thrownBy {
      val nullSchema = Schema.create(Schema.Type.NULL)
      val nullField = new Schema.Field("xxx", null, "Nobody", null: Object)
      AvroField from nullField
    }

    fields.size shouldBe expected.size

    for {
      (field, expected) <- fields zip expected
    } {
      val actual = AvroField from field
      actual shouldBe expected
    }
  }

} 
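The nullability expectations above follow the usual Avro convention that a union containing NULL marks an optional field. A small sketch of that check, as an illustration rather than AvroField's actual code:

import org.apache.avro.Schema
import scala.collection.JavaConverters._

val nullableInt = Schema.createUnion(
  List(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.INT)).asJava
)
val isNullable = nullableInt.getType == Schema.Type.UNION &&
  nullableInt.getTypes.asScala.exists(_.getType == Schema.Type.NULL) // true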
Example 151
Source File: OpsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.cli.gen

import java.io.File
import java.nio.file.Paths

import com.salesforce.op.cli.{AvroSchemaFromFile, CliParameters, GeneratorConfig}
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.io.Source


@RunWith(classOf[JUnitRunner])
class OpsTest extends FlatSpec with TestCommon with Assertions {

  val tempFolder = new File(System.getProperty("java.io.tmpdir"))
  val projectFolder = new File(tempFolder, "cli_test")
  projectFolder.deleteOnExit()

  val testParams = CliParameters(
    location = tempFolder,
    projName = "cli_test",
    inputFile = Some(Paths.get("templates", "simple", "src", "main", "resources", "PassengerData.csv").toFile),
    response = Some("survived"),
    idField = Some("passengerId"),
    schemaSource = Some(
      AvroSchemaFromFile(Paths.get("..", "utils", "src", "main", "avro", "PassengerCSV.avsc").toFile)
    ),
    answersFile = Some(new File("passengers.answers")),
    overwrite = true).values

  Spec[Ops] should "generate project files" in {

    testParams match {
      case None =>
        fail("Could not create config, I wonder why")
      case Some(conf: GeneratorConfig) =>
        val ops = Ops(conf)
        ops.run()
        val buildFile = new File(projectFolder, "build.gradle")
        buildFile should exist

        val buildFileContent = Source.fromFile(buildFile).mkString

        buildFileContent should include("mainClassName = 'com.salesforce.app.cli_test'")

        val scalaSourcesFolder = Paths.get(projectFolder.toString, "src", "main", "scala", "com", "salesforce", "app")
        val featuresFile = Source.fromFile(new File(scalaSourcesFolder.toFile, "Features.scala")).getLines
        val descriptionLine = featuresFile.find(_ contains "description") map (_.trim)
        descriptionLine shouldBe Some(
          "val description = FB.Text[PassengerCSV].extract(_.getDescription.toText).asPredictor"
        )
    }

  }

} 
Example 152
Source File: UserIOTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.cli.gen

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}


@RunWith(classOf[JUnitRunner])
class UserIOTest extends FlatSpec with TestCommon with Assertions {


  private case class Oracle(answers: String*) extends UserIO {
    private var i = -1
    var question = "---"

    override def readLine(q: String): Option[String] = {
      question = q
      i += 1
      if (i < answers.length) Some(answers(i)) else throw new IllegalStateException(s"Out of answers, q=$q")
    }
  }

  Spec[UserIO] should "do qna" in {
    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme(q: String, answers: String*): Option[String] = {
      Oracle(answers: _*).qna(q, _.length == 1, Map("2*3" -> "6", "3+2" -> "5"))
    }

    aksme("2+2", "11", "22", "?") shouldBe Some("?")
    aksme("2+2", "4", "5", "?") shouldBe Some("4")
    aksme("2+3", "44", "", "?") shouldBe Some("?")
    aksme("2*3", "4", "?") shouldBe Some("6")
    aksme("3+2", "4", "?") shouldBe Some("5")
  }

  it should "ask" in {

    // @see https://www.urbandictionary.com/define.php?term=aks
    def aksme[Int](q: String, opts: Map[Int, List[String]], answers: String*): (String, Int) = {
      val console = Oracle(answers: _*)
      val answer = console.ask(q, opts) getOrElse fail(s"A problem answering question $q")
      (console.question, answer)
    }

    an[IllegalStateException] should be thrownBy
      aksme("what is your name?", Map(1 -> List("one", "uno")), "11", "1", "?")

    aksme("what is your name?",
      Map(
        1 -> List("Nessuno", "Nobody"),
        2 -> List("Ishmael", "Gantenbein")),
      "5", "1", "?") shouldBe("what is your name? [0] Nessuno [1] Ishmael: ", 2)
  }

} 
Example 153
Source File: OpTransformerWrapperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.sparkwrappers.specific

import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.{Normalizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpTransformerWrapperTest extends FlatSpec with TestSparkContext {

  val (testData, featureVector) = TestFeatureBuilder(
    Seq[MultiPickList](
      Set("I", "saw", "the", "red", "balloon").toMultiPickList,
      Set("Mary", "had", "a", "little", "lamb").toMultiPickList
    )
  )

  val (testDataNorm, _, _) = TestFeatureBuilder("label", "features",
    Seq[(Real, OPVector)](
      0.0.toReal -> Vectors.dense(1.0, 0.5, -1.0).toOPVector,
      1.0.toReal -> Vectors.dense(2.0, 1.0, 1.0).toOPVector,
      2.0.toReal -> Vectors.dense(4.0, 10.0, 2.0).toOPVector
    )
  )
  val (targetDataNorm, targetLabelNorm, featureVectorNorm) = TestFeatureBuilder("label", "features",
    Seq[(Real, OPVector)](
      0.0.toReal -> Vectors.dense(0.4, 0.2, -0.4).toOPVector,
      1.0.toReal -> Vectors.dense(0.5, 0.25, 0.25).toOPVector,
      2.0.toReal -> Vectors.dense(0.25, 0.625, 0.125).toOPVector
    )
  )

  Spec[OpTransformerWrapper[_, _, _]] should "remove stop words with caseSensitivity=true" in {
    val remover = new StopWordsRemover().setCaseSensitive(true)
    val swFilter =
      new OpTransformerWrapper[MultiPickList, MultiPickList, StopWordsRemover](remover).setInput(featureVector)
    val output = swFilter.transform(testData)

    output.collect(swFilter.getOutput()) shouldBe Array(
      Seq("I", "saw", "red", "balloon").toMultiPickList,
      Seq("Mary", "little", "lamb").toMultiPickList
    )
  }

  it should "should properly normalize each feature vector instance with non-default norm of 1" in {
    val baseNormalizer = new Normalizer().setP(1.0)
    val normalizer =
      new OpTransformerWrapper[OPVector, OPVector, Normalizer](baseNormalizer).setInput(featureVectorNorm)
    val output = normalizer.transform(testDataNorm)

    val sumSqDist = validateDataframeDoubleColumn(output, normalizer.getOutput().name, targetDataNorm, "features")
    assert(sumSqDist <= 1E-6, "==> the sum of squared distances between actual and expected should be below tolerance.")
  }

  def validateDataframeDoubleColumn(
    normalizedFeatureDF: DataFrame, normalizedFeatureName: String, targetFeatureDF: DataFrame, targetColumnName: String
  ): Double = {
    val sqDistUdf = udf { (leftColVec: Vector, rightColVec: Vector) => Vectors.sqdist(leftColVec, rightColVec) }

    val targetColRename = "targetFeatures"
    val renamedTargedDF = targetFeatureDF.withColumnRenamed(targetColumnName, targetColRename)
    val joinedDF = normalizedFeatureDF.join(renamedTargedDF, Seq("label"))

    // compute sum of squared distances between expected and actual
    val finalDF = joinedDF.withColumn("sqDist", sqDistUdf(joinedDF(normalizedFeatureName), joinedDF(targetColRename)))
    val sumSqDist: Double = finalDF.agg(sum(finalDF("sqDist"))).first().getDouble(0)
    sumSqDist
  }
} 
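The normalization targets above follow from Spark's Normalizer with p = 1, which divides each vector by its L1 norm (the sum of absolute values). A quick check in plain Scala, shown only to make the expected numbers traceable:

// L1-normalize a vector: divide every component by the sum of absolute values
def l1Normalize(v: Seq[Double]): Seq[Double] = {
  val norm = v.map(math.abs).sum
  v.map(_ / norm)
}

l1Normalize(Seq(1.0, 0.5, -1.0)) // Seq(0.4, 0.2, -0.4), matching the first target row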
Example 154
Source File: OpEstimatorWrapperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.sparkwrappers.specific

import com.salesforce.op.features.types._
import com.salesforce.op.test.{PrestigeData, TestFeatureBuilder, _}
import org.apache.spark.ml.feature.{MinMaxScaler, MinMaxScalerModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import org.slf4j.LoggerFactory


@RunWith(classOf[JUnitRunner])
class OpEstimatorWrapperTest extends FlatSpec with TestSparkContext with PrestigeData {

  val log = LoggerFactory.getLogger(this.getClass)

  val (ds, education, income, women, prestige) =
    TestFeatureBuilder[OPVector, OPVector, OPVector, OPVector]("education", "income", "women", "prestige",
      prestigeSeq.map(p =>
        (Vectors.dense(p.prestige).toOPVector, Vectors.dense(p.education).toOPVector,
          Vectors.dense(p.income).toOPVector, Vectors.dense(p.women).toOPVector)
      )
    )

  Spec[OpEstimatorWrapper[_, _, _, _]] should "scale variables properly with default min/max params" in {
    val baseScaler = new MinMaxScaler()
    val scalerModel: MinMaxScalerModel = fitScalerModel(baseScaler)

    (scalerModel.getMax - 1.0).abs should be < 1E-6
  }

  it should "scale variables properly with custom min/max params" in {
    val maxParam = 100
    val baseScaler = new MinMaxScaler().setMax(maxParam)
    val scalerModel: MinMaxScalerModel = fitScalerModel(baseScaler)

    (scalerModel.getMax - maxParam).abs should be < 1E-6
  }

  it should "should have the expected feature name" in {
    val wrappedEstimator =
      new OpEstimatorWrapper[OPVector, OPVector, MinMaxScaler, MinMaxScalerModel](new MinMaxScaler()).setInput(income)
    wrappedEstimator.getOutput().name shouldBe wrappedEstimator.getOutputFeatureName
  }

  private def fitScalerModel(baseScaler: MinMaxScaler): MinMaxScalerModel = {
    val scaler =
      new OpEstimatorWrapper[OPVector, OPVector, MinMaxScaler, MinMaxScalerModel](baseScaler).setInput(income)

    val model = scaler.fit(ds)
    val scalerModel = model.getSparkMlStage().get

    if (log.isInfoEnabled) {
      val output = scalerModel.transform(ds)
      output.show(false)
    }
    scalerModel
  }
} 
Example 155
Source File: OpPredictorWrapperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.sparkwrappers.specific

import com.salesforce.op.features.types._
import com.salesforce.op.stages.sparkwrappers.generic.SparkWrapperParams
import com.salesforce.op.test.{PrestigeData, TestFeatureBuilder, TestSparkContext}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import org.slf4j.LoggerFactory


@RunWith(classOf[JUnitRunner])
class OpPredictorWrapperTest extends FlatSpec with TestSparkContext with PrestigeData {

  val log = LoggerFactory.getLogger(this.getClass)

  val (ds, targetLabel, featureVector) = TestFeatureBuilder[RealNN, OPVector](
    prestigeSeq.map(p => p.prestige.toRealNN -> Vectors.dense(p.education, p.income, p.women).toOPVector)
  )

  Spec[OpPredictorWrapper[_, _]] should
    "be able to run a simple linear regression model (fitIntercept=true)" in {
    val lrModel: LinearRegressionModel = fitLinRegModel(fitIntercept = true)
    lrModel.intercept.abs should be > 1E-6
  }

  it should "be able to run a simple logistic regression model (fitIntercept=false)" in {
    val lrModel: LinearRegressionModel = fitLinRegModel(fitIntercept = false)
    lrModel.intercept.abs should be < Double.MinPositiveValue
  }

  private def fitLinRegModel(fitIntercept: Boolean): LinearRegressionModel = {
    val lrBase =
      new LinearRegression()
        .setMaxIter(10)
        .setRegParam(0.3)
        .setElasticNetParam(0.8)
        .setFitIntercept(fitIntercept)

    val lr = new OpPredictorWrapper[LinearRegression, LinearRegressionModel](lrBase)
      .setInput(targetLabel, featureVector)

    // Fit the model
    val model = lr.fit(ds).asInstanceOf[SparkWrapperParams[LinearRegressionModel]]
    val lrModel = model.getSparkMlStage().get

    // Print the coefficients and intercept for linear regression
    log.info(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

    // Summarize the model over the training set and print out some metrics
    val trainingSummary = lrModel.summary
    log.info(s"numIterations: ${trainingSummary.totalIterations}")
    log.info(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
    if (log.isInfoEnabled) trainingSummary.residuals.show()
    log.info(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
    log.info(s"r2: ${trainingSummary.r2}")
    // checking r2 as a cheap way to make sure things are running as intended.
    assert(trainingSummary.r2 > 0.9)

    if (log.isInfoEnabled) {
      val output = lrModel.transform(ds)
      output.show(false)
    }

    lrModel
  }
} 
Example 156
Source File: SparkWrapperParamsTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.sparkwrappers.generic

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.apache.spark.ml.feature.{StandardScaler, StandardScalerModel}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FlatSpec}

@RunWith(classOf[JUnitRunner])
class SparkWrapperParamsTest extends FlatSpec with BeforeAndAfterEach with TestCommon {

  private def estimator(sparkMlStageIn: Option[StandardScaler] = None) = {
    new SwUnaryEstimator[Real, Real, StandardScalerModel, StandardScaler](
      inputParamName = "in", outputParamName = "out",
      operationName = "test-op", sparkMlStageIn = sparkMlStageIn
    )
  }

  Spec[SparkWrapperParams[_]] should "have proper default values for path and stage" in {
    val stage = estimator()
    stage.getStageSavePath() shouldBe None
    stage.getSparkMlStage() shouldBe None
  }
  it should "when setting path, it should also set path to the stage param" in {
    val stage = estimator()
    stage.setStageSavePath("/test/path")
    stage.getStageSavePath() shouldBe Some("/test/path")
  }
  it should "allow set/get spark params on a wrapped stage" in {
    val sparkStage = new StandardScaler()
    val stage = estimator(sparkMlStageIn = Some(sparkStage))
    stage.getSparkMlStage() shouldBe Some(sparkStage)
    for {
      sparkStage <- stage.getSparkMlStage()
      withMean = sparkStage.getOrDefault(sparkStage.withMean)
    } {
      withMean shouldBe false
      sparkStage.set[Boolean](sparkStage.withMean, true)
      sparkStage.get(sparkStage.withMean) shouldBe Some(true)
    }
  }

} 
Example 157
Source File: OPLogLossTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.evaluator

import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OPLogLossTest extends FlatSpec with TestSparkContext {

  val (ds, rawLabel, pred) = TestFeatureBuilder[RealNN, Prediction](
    Seq(
      (1.0, Vectors.dense(8, 1, 1), Vectors.dense(0.8, 0.1, 0.1), 0.0),
      (0.0, Vectors.dense(1.0, 0.0, 0.0), Vectors.dense(1.0, 0.0, 0.0), 0.0),
      (0.0, Vectors.dense(1.0, 0.8, 0.2), Vectors.dense(0.5, 0.4, 0.1), 0.0),
      (1.0, Vectors.dense(10.0, 80.0, 10.0), Vectors.dense(0.1, 0.8, 0.1), 1.0),
      (2.0, Vectors.dense(0.0, 0.0, 14.0), Vectors.dense(0.0, 0.0, 1.0), 2.0),
      (2.0, Vectors.dense(0.0, 0.0, 13.0), Vectors.dense(0.0, 0.0, 1.0), 2.0),
      (1.0, Vectors.dense(0.1, 0.4, 0.5), Vectors.dense(0.1, 0.4, 0.5), 2.0),
      (0.0, Vectors.dense(0.1, 0.6, 0.3), Vectors.dense(0.1, 0.6, 0.3), 1.0),
      (1.0, Vectors.dense(1.0, 0.8, 0.2), Vectors.dense(0.5, 0.4, 0.1), 0.0),
      (2.0, Vectors.dense(1.0, 0.8, 0.2), Vectors.dense(0.5, 0.4, 0.1), 0.0)
    ).map(v => (v._1.toRealNN, Prediction(v._4, v._2, v._3)))
  )

  val label = rawLabel.copy(isResponse = true)
  val expected: Double = -math.log(0.1 * 0.5 * 0.8 * 0.4 * 0.1 * 0.4 * 0.1) / 10.0

  val (dsEmpty, rawLabelEmpty, predEmpty) = TestFeatureBuilder[RealNN, Prediction](
    Seq()
  )

  val labelEmpty = rawLabel.copy(isResponse = true)

  val logLoss = LogLoss.multiLogLoss

  Spec(LogLoss.getClass) should "compute logarithmic loss metric" in {
    val metric = logLoss.setLabelCol(label).setPredictionCol(pred).evaluate(ds)
    metric shouldBe expected
  }

  it should "throw an error when the dataset is empty" in {
    the[IllegalArgumentException] thrownBy {
      logLoss.setLabelCol(labelEmpty).setPredictionCol(predEmpty).evaluate(dsEmpty)
    } should have message "requirement failed: Dataset is empty, log loss cannot be calculated"
  }
} 
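The `expected` constant above encodes the standard multi-class log loss: the mean of -log(probability assigned to the true label) over all rows. A plain-Scala sketch of that definition, not the evaluator's implementation:

// probsForTrueLabel(i) = predicted probability of row i's actual label
def multiLogLoss(probsForTrueLabel: Seq[Double]): Double =
  probsForTrueLabel.map(p => -math.log(p)).sum / probsForTrueLabel.size

// For the ten rows above (three of them predicted with probability 1.0):
multiLogLoss(Seq(0.1, 1.0, 0.5, 0.8, 1.0, 1.0, 0.4, 0.1, 0.4, 0.1))
// == -math.log(0.1 * 0.5 * 0.8 * 0.4 * 0.1 * 0.4 * 0.1) / 10.0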
Example 158
Source File: TextVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TextVectorizerTest extends FlatSpec with TestSparkContext with AttributeAsserts {
  // scalastyle:off
  lazy val (data, f1, f2) = TestFeatureBuilder(
    Seq[(Text, Text)](
      (Text("Hamlet: To be or not to be - that is the question."), Text("Enter Hamlet")),
      (Text("Гамлет: Быть или не быть - вот в чём вопрос."), Text("Входит Гамлет")),
      (Text("המלט: להיות או לא להיות - זאת השאלה."), Text("נככס המלט"))
    )
  )
  // scalastyle:on

  "TextVectorizer" should "work correctly out of the box" in {
    val vectorized = f1.vectorize(numHashes = TransmogrifierDefaults.DefaultNumOfFeatures,
      autoDetectLanguage = TextTokenizer.AutoDetectLanguage,
      minTokenLength = TextTokenizer.MinTokenLength,
      toLowercase = TextTokenizer.ToLowercase
    )
    vectorized.originStage shouldBe a[VectorsCombiner]
    vectorized.parents.head.originStage shouldBe a[OPCollectionHashingVectorizer[_]]
    val hasher = vectorized.parents.head.originStage.asInstanceOf[OPCollectionHashingVectorizer[_]].hashingTF()
    val transformed = new OpWorkflow().setResultFeatures(vectorized).transform(data)
    val result = transformed.collect(vectorized)
    val f1NameHash = hasher.indexOf(vectorized.parents.head.originStage.getInputFeatures().head.name)
    val field = transformed.schema(vectorized.name)
    assertNominal(field, Array.fill(result.head.value.size - 1)(false) :+ true, result)
    // scalastyle:off
    result(0).value(hasher.indexOf(s"${f1NameHash}_" + "hamlet")) should be >= 1.0
    result(0).value(hasher.indexOf(s"${f1NameHash}_" + "question")) should be >= 1.0
    result(1).value(hasher.indexOf(s"${f1NameHash}_" + "гамлет")) should be >= 1.0
    result(1).value(hasher.indexOf(s"${f1NameHash}_" + "вопрос")) should be >= 1.0
    result(1).value(hasher.indexOf(s"${f1NameHash}_" + "быть")) should be >= 2.0
    result(2).value(hasher.indexOf(s"${f1NameHash}_" + "המלט")) should be >= 1.0
    result(2).value(hasher.indexOf(s"${f1NameHash}_" + "להיות")) should be >= 2.0
    // scalastyle:on
  }

  it should "allow forcing hashing into a shared hash space" in {
    val vectorized = f1.vectorize(numHashes = TransmogrifierDefaults.DefaultNumOfFeatures,
      autoDetectLanguage = TextTokenizer.AutoDetectLanguage,
      minTokenLength = TextTokenizer.MinTokenLength,
      toLowercase = TextTokenizer.ToLowercase,
      binaryFreq = true,
      others = Array(f2))
    val hasher = vectorized.parents.head.originStage.asInstanceOf[OPCollectionHashingVectorizer[_]].hashingTF()
    val transformed = new OpWorkflow().setResultFeatures(vectorized).transform(data)
    val result = transformed.collect(vectorized)
    val f1NameHash = hasher.indexOf(vectorized.parents.head.originStage.getInputFeatures().head.name)
    val field = transformed.schema(vectorized.name)
    assertNominal(field, Array.fill(result.head.value.size - 2)(false) ++ Array(true, true), result)
    // scalastyle:off
    result(0).value(hasher.indexOf(s"${f1NameHash}_" + "hamlet")) shouldBe 1.0
    result(0).value(hasher.indexOf(s"${f1NameHash}_" + "hamlet")) shouldBe 1.0
    result(1).value(hasher.indexOf(s"${f1NameHash}_" + "гамлет")) shouldBe 1.0
    result(1).value(hasher.indexOf(s"${f1NameHash}_" + "гамлет")) shouldBe 1.0
    result(2).value(hasher.indexOf(s"${f1NameHash}_" + "המלט")) shouldBe 1.0
    result(2).value(hasher.indexOf(s"${f1NameHash}_" + "המלט")) shouldBe 1.0
    // scalastyle:on
  }
} 
Example 159
Source File: ScalerMetadataTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.utils.json.JsonUtils
import org.apache.spark.sql.types.MetadataBuilder

import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class ScalerMetadataTest extends FlatSpec with TestSparkContext {
  val linearArgs = LinearScalerArgs(slope = 2.0, intercept = 1.0)

  Spec[ScalerMetadata] should "properly construct ScalerMetadata for a LinearScaler" in {
    val metadata = ScalerMetadata(scalingType = ScalingType.Linear,
      scalingArgs = linearArgs).toMetadata()
    metadata.getString(ScalerMetadata.scalingTypeName) shouldBe ScalingType.Linear.entryName
    val args = JsonUtils.fromString[LinearScalerArgs](metadata.getString(ScalerMetadata.scalingArgsName))
    args match {
      case Failure(err) => fail(err)
      case Success(x) => x shouldBe linearArgs
    }
  }

  it should "properly construct ScalerMetaData for a LogScaler" in {
    val metadata = ScalerMetadata(scalingType = ScalingType.Logarithmic, scalingArgs = EmptyScalerArgs()).toMetadata()
    metadata.getString(ScalerMetadata.scalingTypeName) shouldBe ScalingType.Logarithmic.entryName
    metadata.getString(ScalerMetadata.scalingArgsName) shouldBe "{}"
  }

  it should "use apply to properly convert metadata to ScalerMetadata" in {
    val metadata = new MetadataBuilder().putString(ScalerMetadata.scalingTypeName, ScalingType.Linear.entryName)
      .putString(ScalerMetadata.scalingArgsName, linearArgs.toJson(pretty = false)).build()
    ScalerMetadata.apply(metadata) match {
      case Failure(err) => fail(err)
      case Success(x) => x shouldBe ScalerMetadata(ScalingType.Linear, linearArgs)
    }
  }

  it should "error when apply is given an invalid scaling type" in {
    val invalidMetaData = new MetadataBuilder().putString(ScalerMetadata.scalingTypeName, "unsupportedScaling")
      .putString(ScalerMetadata.scalingArgsName, linearArgs.toJson(pretty = false)).build()

    val err = intercept[NoSuchElementException] (
      ScalerMetadata.apply(invalidMetaData).get
    )
    err.getMessage shouldBe "unsupportedScaling is not a member of Enum (Linear, Logarithmic)"
  }
} 
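For context, `LinearScalerArgs(slope, intercept)` is presumably describing an affine map; the one-liner below spells out that assumption and is not taken from the library:

// Assumption: a linear scaler with these args maps x to slope * x + intercept
def linearScale(x: Double, slope: Double, intercept: Double): Double =
  slope * x + intercept

linearScale(3.0, slope = 2.0, intercept = 1.0) // 7.0 under the assumed interpretation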
Example 160
Source File: OpLDATest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpLDATest extends FlatSpec with TestSparkContext {

  val inputData = Seq(
    (0.0, Vectors.sparse(11, Array(0, 1, 2, 4, 5, 6, 7, 10), Array(1.0, 2.0, 6.0, 2.0, 3.0, 1.0, 1.0, 3.0))),
    (1.0, Vectors.sparse(11, Array(0, 1, 3, 4, 7, 10), Array(1.0, 3.0, 1.0, 3.0, 2.0, 1.0))),
    (2.0, Vectors.sparse(11, Array(0, 1, 2, 5, 6, 8, 9), Array(1.0, 4.0, 1.0, 4.0, 9.0, 1.0, 2.0))),
    (3.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 3.0, 9.0))),
    (4.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 6, 9, 10), Array(3.0, 1.0, 1.0, 9.0, 3.0, 2.0, 1.0, 3.0))),
    (5.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7, 8, 9), Array(4.0, 2.0, 3.0, 4.0, 5.0, 1.0, 1.0, 1.0, 4.0))),
    (6.0, Vectors.sparse(11, Array(0, 1, 3, 6, 8, 9, 10), Array(2.0, 1.0, 3.0, 5.0, 2.0, 2.0, 9.0))),
    (7.0, Vectors.sparse(11, Array(0, 1, 2, 3, 4, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 1.0, 2.0, 1.0, 3.0))),
    (8.0, Vectors.sparse(11, Array(0, 1, 3, 4, 5, 6, 7), Array(4.0, 4.0, 3.0, 4.0, 2.0, 1.0, 3.0))),
    (9.0, Vectors.sparse(11, Array(0, 1, 2, 4, 6, 8, 9, 10), Array(2.0, 8.0, 2.0, 3.0, 2.0, 2.0, 7.0, 2.0))),
    (10.0, Vectors.sparse(11, Array(0, 1, 2, 3, 5, 6, 9, 10), Array(1.0, 1.0, 1.0, 9.0, 2.0, 2.0, 3.0, 3.0))),
    (11.0, Vectors.sparse(11, Array(0, 1, 4, 5, 6, 7, 9), Array(4.0, 1.0, 4.0, 5.0, 1.0, 3.0, 1.0)))
  ).map(v => v._1.toReal -> v._2.toOPVector)

  lazy val (ds, f1, f2) = TestFeatureBuilder(inputData)

  lazy val inputDS = ds.persist()

  val seed = 1234567890L
  val k = 3
  val maxIter = 100

  lazy val expected = new LDA()
    .setFeaturesCol(f2.name)
    .setK(k)
    .setSeed(seed)
    .fit(inputDS)
    .transform(inputDS)
    .select("topicDistribution")
    .collect()
    .toSeq
    .map(_.getAs[Vector](0))

  Spec[OpLDA] should "convert document term vectors into topic vectors" in {
    val f2Vec = new OpLDA().setInput(f2).setK(k).setSeed(seed).setMaxIter(maxIter)
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val output = f2Vec.getOutput()
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  it should "convert document term vectors into topic vectors (shortcut version)" in {
    val output = f2.lda(k = k, seed = seed, maxIter = maxIter)
    val f2Vec = output.originStage.asInstanceOf[OpLDA]
    val testTransformedData = f2Vec.fit(inputDS).transform(inputDS)
    val estimate = testTransformedData.collect(output)
    val mse = computeMeanSqError(estimate, expected)
    val expectedMse = 0.5
    withClue(s"Computed mse $mse (expected $expectedMse)") {
      mse should be < expectedMse
    }
  }

  private def computeMeanSqError(estimate: Seq[OPVector], expected: Seq[Vector]): Double = {
    val n = estimate.length.toDouble
    estimate.zip(expected).map { case (est, exp) => Vectors.sqdist(est.value, exp) }.sum / n
  }
} 
Example 161
Source File: EmailParserTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryLambdaTransformer
import com.salesforce.op.test.{TestFeatureBuilder, _}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class EmailParserTest extends FlatSpec with TestCommon with TestSparkContext {

  val (df, email) = TestFeatureBuilder("email", Seq(
    Email("test@example.com"),
    Email("@example.com"),
    Email("test@"),
    Email("@"),
    Email(""),
    Email("notanemail"),
    Email.empty,
    Email("first.last@example.com")
  ))

  "Email Extraction" should "extract prefix from simple email addresses" in {
    val prefix = email.toEmailPrefix
    val result = prefix.originStage.asInstanceOf[UnaryLambdaTransformer[Email, Text]].transform(df)

    result.collect(prefix) should contain theSameElementsInOrderAs
      Seq(Text("test"), Text.empty, Text.empty, Text.empty, Text.empty, Text.empty, Text.empty, Text("first.last"))
  }

  it should "extract domain from simple email addresses" in {
    val domain = email.toEmailDomain
    val result = domain.originStage.asInstanceOf[UnaryLambdaTransformer[Email, Text]].transform(df)

    result.collect(domain) should contain theSameElementsInOrderAs
      Seq(Text("example.com"), Text.empty, Text.empty, Text.empty, Text.empty, Text.empty, Text.empty,
        Text("example.com"))
  }
} 
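A note on the behaviour asserted above: a prefix and domain are only produced when both sides of a single "@" are non-empty, so inputs such as "@example.com", "test@" or "notanemail" all map to Text.empty. A minimal plain-Scala sketch of that rule (assumed semantics for illustration, not the library's implementation):

def splitEmail(email: String): Option[(String, String)] =
  email.split("@", -1) match {
    case Array(prefix, domain) if prefix.nonEmpty && domain.nonEmpty => Some(prefix -> domain)
    case _ => None // "@example.com", "test@", "@", "", "notanemail" all fall through here
  }

splitEmail("test@example.com")  // Some((test, example.com))
splitEmail("@example.com")      // None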
Example 162
Source File: FillMissingWithMeanTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.{UnaryEstimator, UnaryModel}
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.sql.DataFrame
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.reflect.ClassTag


@RunWith(classOf[JUnitRunner])
class FillMissingWithMeanTest extends FlatSpec with TestSparkContext {
  val data = Seq[Real](Real(4.0), Real(2.0), Real.empty, Real(6.0))
  val dataNull = List.fill(7)(Real.empty)
  val binData = Seq[Binary](true.toBinary, false.toBinary, Binary.empty)

  lazy val (ds, f) = TestFeatureBuilder(data = data, f1name = "f")
  lazy val (dsi, fi) = TestFeatureBuilder(data = data.map(_.value.map(_.toLong).toIntegral), f1name = "fi")
  lazy val (dsNull, fNull) = TestFeatureBuilder(data = dataNull, f1name = "fNull")
  lazy val (dsb, fb) = TestFeatureBuilder(data = binData, f1name = "fb")

  Spec[FillMissingWithMean[_, _]] should "fill missing values with mean" in {
    assertUnaryEstimator[Real, RealNN](
      output = new FillMissingWithMean[Double, Real]().setInput(f).getOutput(),
      data = ds,
      expected = Array(4.0, 2.0, 4.0, 6.0).map(_.toRealNN)
    )
  }

  it should "fill missing values with mean from a shortcut on Real feature" in {
    assertUnaryEstimator[Real, RealNN](
      output = f.fillMissingWithMean(),
      data = ds,
      expected = Array(4.0, 2.0, 4.0, 6.0).map(_.toRealNN)
    )
  }

  it should "fill missing values with mean from a shortcut on Integral feature" in {
    assertUnaryEstimator[Real, RealNN](
      output = fi.fillMissingWithMean(),
      data = dsi,
      expected = Array(4.0, 2.0, 4.0, 6.0).map(_.toRealNN)
    )
  }

  it should "fill missing values with mean from a shortcut on Binary feature" in {
    assertUnaryEstimator[Real, RealNN](
      output = fb.fillMissingWithMean(),
      data = dsb,
      expected = Array(1.0, 0.0, 0.5).map(_.toRealNN)
    )
  }

  it should "fill a feature of only nulls with default value" in {
    val default = 3.14159
    assertUnaryEstimator[Real, RealNN](
      output = fNull.fillMissingWithMean(default = default),
      data = dsNull,
      expected = List.fill(dataNull.length)(default).map(_.toRealNN)
    )
  }

  // TODO move this assert to testkit
  private def assertUnaryEstimator[I <: FeatureType, O <: FeatureType : FeatureTypeSparkConverter : ClassTag]
  (
    output: FeatureLike[O], data: DataFrame, expected: Seq[O]
  ): Unit = {
    output.originStage shouldBe a[UnaryEstimator[_, _]]
    val model = output.originStage.asInstanceOf[UnaryEstimator[I, O]].fit(data)
    model shouldBe a[UnaryModel[_, _]]
    val transformed = model.asInstanceOf[UnaryModel[I, O]].transform(data)
    val results = transformed.collect(output)
    results should contain theSameElementsAs expected
  }
} 
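The expected arrays above follow directly from the mean of the non-missing values (the mean of 4, 2, 6 is 4; the mean of the binary values 1, 0 is 0.5), with the configured default used when every value is missing. A small plain-Scala sketch of that fill rule (an illustration of the assumed semantics, not the estimator itself):

def fillWithMean(xs: Seq[Option[Double]], default: Double): Seq[Double] = {
  val present = xs.flatten
  val mean = if (present.nonEmpty) present.sum / present.size else default // fall back to default when all missing
  xs.map(_.getOrElse(mean))
}

fillWithMean(Seq(Some(4.0), Some(2.0), None, Some(6.0)), default = 0.0) // Seq(4.0, 2.0, 4.0, 6.0)
fillWithMean(Seq.fill(7)(None), default = 3.14159)                      // every value becomes 3.14159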
Example 163
Source File: IsotonicRegressionCalibratorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.features.{Feature, FeatureLike}
import com.salesforce.op.stages.impl.regression.IsotonicRegressionCalibrator
import com.salesforce.op.stages.sparkwrappers.specific.OpBinaryEstimatorWrapper
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.regression.{IsotonicRegression, IsotonicRegressionModel}
import org.apache.spark.sql._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class IsotonicRegressionCalibratorTest extends FlatSpec with TestSparkContext {

  val isoExpectedPredictions = Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)
  val isoExpectedModelBoundaries = Array(0, 1, 3, 4, 5, 6, 7, 8)
  val isoExpectedModelPredictions = Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)

  val isoDataLabels = Seq(1, 2, 3, 1, 6, 17, 16, 17, 18)
  val isoTestData = isoDataLabels.zipWithIndex.map {
    case (label, i) => label.toRealNN -> i.toRealNN
  }

  val (isoScoresDF, isoLabels, isoScores): (DataFrame, Feature[RealNN], Feature[RealNN]) =
    TestFeatureBuilder(isoTestData)

  val antiExpectedPredictions = Array(7.0, 5.0, 4.0, 4.0, 1.0)
  val antiExpectedModelBoundaries = Array(0, 1, 2, 3, 4)
  val antiExpectedModelPredictions = Array(7.0, 5.0, 4.0, 4.0, 1.0)

  val antiDataLabels = Seq(7, 5, 3, 5, 1)
  val antiTestData = antiDataLabels.zipWithIndex.map {
    case (label, i) => label.toRealNN -> i.toRealNN
  }

  val (antiScoresDF, antiLabels, antiScores): (DataFrame, Feature[RealNN], Feature[RealNN]) =
    TestFeatureBuilder(antiTestData)

  Spec[IsotonicRegressionCalibrator] should "isotonically calibrate scores using shortcut" in {
    val calibratedScores = isoScores.toIsotonicCalibrated(isoLabels)

    val estimator = calibratedScores.originStage
      .asInstanceOf[OpBinaryEstimatorWrapper[RealNN, RealNN, RealNN, IsotonicRegression, IsotonicRegressionModel]]

    val model = estimator.fit(isoScoresDF).getSparkMlStage().get

    val predictionsDF = model.asInstanceOf[Transformer]
      .transform(isoScoresDF)

    validateOutput(calibratedScores, model, predictionsDF, true, isoExpectedPredictions, isoExpectedModelBoundaries,
      isoExpectedModelPredictions)
  }

  it should "isotonically calibrate scores" in {
    val isotonicCalibrator = new IsotonicRegressionCalibrator().setInput(isoLabels, isoScores)

    val calibratedScores = isotonicCalibrator.getOutput()

    val model = isotonicCalibrator.fit(isoScoresDF).getSparkMlStage().get

    val predictionsDF = model.asInstanceOf[Transformer]
      .transform(isoScoresDF)

    validateOutput(calibratedScores, model, predictionsDF, true, isoExpectedPredictions, isoExpectedModelBoundaries,
      isoExpectedModelPredictions)
  }

  it should "antitonically calibrate scores" in {
    val isIsotonic: Boolean = false
    val isotonicCalibrator = new IsotonicRegressionCalibrator().setIsotonic(isIsotonic).setInput(isoLabels, isoScores)

    val calibratedScores = isotonicCalibrator.getOutput()

    val model = isotonicCalibrator.fit(antiScoresDF).getSparkMlStage().get

    val predictionsDF = model.asInstanceOf[Transformer]
      .transform(antiScoresDF)

    validateOutput(calibratedScores, model, predictionsDF, isIsotonic, antiExpectedPredictions,
      antiExpectedModelBoundaries, antiExpectedModelPredictions)
  }

  def validateOutput(calibratedScores: FeatureLike[RealNN],
    model: IsotonicRegressionModel, predictionsDF: DataFrame, expectedIsIsotonic: Boolean,
    expectedPredictions: Array[Double], expectedModelBoundaries: Array[Int],
    expectedModelPredictions: Array[Double]): Unit = {

    val predictions = predictionsDF.select(calibratedScores.name).rdd.map { case Row(pred) => pred }.collect()
    val isIsotonic = model.getIsotonic

    isIsotonic should be(expectedIsIsotonic)
    predictions should contain theSameElementsInOrderAs expectedPredictions
    model.boundaries.toArray should contain theSameElementsInOrderAs expectedModelBoundaries
    model.predictions.toArray should contain theSameElementsInOrderAs expectedModelPredictions
  }
} 
Example 164
Source File: OpWord2VecTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpWord2VecTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    "I I I like like Spark".split(" "),
    "Hi I heard about Spark".split(" "),
    "I wish Java could use case classes".split(" "),
    "Logistic regression models are neat".split(" ")
  ).map(_.toSeq.toTextList)

  lazy val (inputData, f1) = TestFeatureBuilder(Seq(data.head))
  lazy val (testData, _) = TestFeatureBuilder(data.tail)

  lazy val expected = data.tail.zip(Seq(
    Vectors.dense(-0.029884086549282075, -0.055613189935684204, 0.04186216294765473).toOPVector,
    Vectors.dense(-0.0026281912411962234, -0.016138136386871338, 0.010740748473576136).toOPVector,
    Vectors.dense(0.0, 0.0, 0.0).toOPVector
  )).toArray

  Spec[OpWord2Vec] should "convert array of strings into a vector" in {
    val f1Vec = new OpWord2Vec().setInput(f1).setMinCount(0).setVectorSize(3).setSeed(1234567890L)
    val output = f1Vec.getOutput()
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }

  it should "convert array of strings into a vector (shortcut version)" in {
    val output = f1.word2vec(minCount = 0, vectorSize = 3)
    val f1Vec = output.originStage.asInstanceOf[OpWord2Vec].setSeed(1234567890L)
    val testTransformedData = f1Vec.fit(inputData).transform(testData)
    testTransformedData.orderBy(f1.name).collect(f1, output) shouldBe expected
  }

} 
Example 165
Source File: OpStringIndexerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.features.types.Text
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.utils.spark.RichDataset._

@RunWith(classOf[JUnitRunner])
class OpStringIndexerTest extends FlatSpec with TestSparkContext {

  val txtData = Seq("a", "b", "c", "a", "a", "c").map(_.toText)
  val (ds, txtF) = TestFeatureBuilder(txtData)
  val expected = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)


  Spec[OpStringIndexer[_]] should "correctly set the wrapped spark stage params" in {
    val indexer = new OpStringIndexer[Text]()
    indexer.setHandleInvalid(StringIndexerHandleInvalid.Skip)
    indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Skip.entryName.toLowerCase
    indexer.setHandleInvalid(StringIndexerHandleInvalid.Error)
    indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Error.entryName.toLowerCase
    indexer.setHandleInvalid(StringIndexerHandleInvalid.Keep)
    indexer.getSparkMlStage().get.getHandleInvalid shouldBe StringIndexerHandleInvalid.Keep.entryName.toLowerCase
  }

  it should "throw an error if you try to set noFilter as the indexer" in {
    val indexer = new OpStringIndexer[Text]()
    intercept[IllegalArgumentException](indexer.setHandleInvalid(StringIndexerHandleInvalid.NoFilter))
  }

  it should "correctly index a text column" in {
    val stringIndexer = new OpStringIndexer[Text]().setInput(txtF)
    val indices = stringIndexer.fit(ds).transform(ds).collect(stringIndexer.getOutput())

    indices shouldBe expected
  }

  it should "correctly deindex a numeric column" in {
    val indexedStage = new OpStringIndexer[Text]().setInput(txtF)
    val indexed = indexedStage.getOutput()
    val indices = indexedStage.fit(ds).transform(ds)
    val deindexedStage = new OpIndexToString().setInput(indexed)
    val deindexed = deindexedStage.getOutput()
    val deindexedData = deindexedStage.transform(indices).collect(deindexed)
    deindexedData shouldBe txtData
  }

} 
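The expected indices above come from ordering labels by descending frequency: "a" occurs three times (index 0.0), "c" twice (1.0) and "b" once (2.0). A plain-Scala sketch of that assignment (tie-breaking here is alphabetical purely for illustration and is not asserted by the test):

def indexByFrequency(labels: Seq[String]): Map[String, Double] =
  labels.groupBy(identity).toSeq
    .sortBy { case (label, occurrences) => (-occurrences.size, label) } // most frequent first
    .map(_._1)
    .zipWithIndex
    .map { case (label, idx) => label -> idx.toDouble }
    .toMap

val index = indexByFrequency(Seq("a", "b", "c", "a", "a", "c"))
Seq("a", "b", "c", "a", "a", "c").map(index) // Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0)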
Example 166
Source File: LinearScalerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LinearScalerTest extends FlatSpec with TestSparkContext {

  Spec[LinearScaler] should "error on construction of a non-invertible transformation" in {
    val error = intercept[java.lang.IllegalArgumentException](
      LinearScaler(LinearScalerArgs(slope = 0.0, intercept = 1.0))
    )
    error.getMessage shouldBe "requirement failed: LinearScaler must have a non-zero slope to be invertible"
  }

  it should "correctly construct the linear scaling and inverse scaling function" in {
    val sampleData = Seq(0.0, 1.0, 2.0, 3.0, 4.0)
    val scaler = LinearScaler(LinearScalerArgs(slope = 2.0, intercept = 1.0))
    sampleData.map(x => scaler.scale(x)) shouldEqual sampleData.map(x => 2.0*x + 1.0)
    sampleData.map(x => scaler.descale(x)) shouldEqual sampleData.map(x => 0.5*x - 0.5)
  }
} 
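The two assertions above encode the invertibility requirement: scale(x) = a*x + b can only be undone when a != 0, in which case descale(y) = (y - b) / a, which for slope 2 and intercept 1 is 0.5*y - 0.5. A minimal plain-Scala sketch of that relationship (not the library class):

case class SimpleLinearScaler(slope: Double, intercept: Double) {
  require(slope != 0.0, "a linear scaler needs a non-zero slope to be invertible")
  def scale(x: Double): Double = slope * x + intercept
  def descale(y: Double): Double = (y - intercept) / slope
}

val s = SimpleLinearScaler(slope = 2.0, intercept = 1.0)
s.scale(3.0)            // 7.0
s.descale(s.scale(3.0)) // 3.0 -- round-trips back to the input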
Example 167
Source File: IDFTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.IDF
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.{Estimator, Transformer}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class IDFTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)),
    Vectors.dense(0.0, 1.0, 2.0, 3.0),
    Vectors.sparse(4, Array(1), Array(1.0))
  )

  lazy val (ds, f1) = TestFeatureBuilder(data.map(_.toOPVector))

  Spec[IDF] should "compute inverted document frequency" in {
    val idf = f1.idf()
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)

    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      math.log((data.length + 1.0) / (x + 1.0))
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  it should "compute inverted document frequency when minDocFreq is 1" in {
    val idf = f1.idf(minDocFreq = 1)
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)
    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      if (x > 0) math.log((data.length + 1.0) / (x + 1.0)) else 0
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  private def scaleDataWithIDF(dataSet: Seq[Vector], model: Vector): Seq[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) =>
          (id, value * model(id))
        }
        Vectors.sparse(data.size, res)
    }
  }

} 
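The expectedIdf computation above is Spark's smoothed inverse document frequency, idf(t) = log((m + 1) / (df(t) + 1)) for m documents and document frequency df(t); with minDocFreq set, terms below the threshold are zeroed instead. A tiny sketch of just that formula:

def idf(numDocs: Long, docFreq: Long): Double =
  math.log((numDocs + 1.0) / (docFreq + 1.0))

idf(numDocs = 3, docFreq = 3) // 0.0 -- a term appearing in every document carries no weight
idf(numDocs = 3, docFreq = 1) // log(2) ≈ 0.693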
Example 168
Source File: OpCountVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType.{IndCol, IndVal}
import com.salesforce.op.test.{TestFeatureBuilder, TestOpVectorMetadataBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpCountVectorizerTest extends FlatSpec with TestSparkContext {

  val data = Seq[(Real, TextList)](
    (Real(0), Seq("a", "b", "c").toTextList),
    (Real(1), Seq("a", "b", "b", "b", "a", "c").toTextList)
  )

  lazy val (ds, f1, f2) = TestFeatureBuilder(data)

  lazy val expected = Array[(Real, OPVector)](
    (Real(0), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)).toOPVector),
    (Real(1), Vectors.sparse(3, Array(0, 1, 2), Array(3.0, 2.0, 1.0)).toOPVector)
  )

  val f2vec = new OpCountVectorizer().setInput(f2).setVocabSize(3).setMinDF(2)

  Spec[OpCountVectorizer] should "convert array of strings into count vector" in {
    val transformedData = f2vec.fit(ds).transform(ds)
    val output = f2vec.getOutput()
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }

  it should "return a fitted vectorizer with the correct parameters" in {
    val fitted = f2vec.fit(ds)
    val vectorMetadata = fitted.getMetadata()
    val expectedMeta = TestOpVectorMetadataBuilder(
      f2vec,
      f2 -> List(IndVal(Some("b")), IndVal(Some("a")), IndVal(Some("c")))
    )
    // cannot just do equals because fitting is nondeterministic
    OpVectorMetadata(f2vec.getOutputFeatureName, vectorMetadata).columns should contain theSameElementsAs
      expectedMeta.columns
  }

  it should "convert array of strings into count vector (shortcut version)" in {
    val output = f2.countVec(minDF = 2, vocabSize = 3)
    val f2vec = output.originStage.asInstanceOf[OpCountVectorizer]
    val transformedData = f2vec.fit(ds).transform(ds)
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }
} 
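The `expected` vectors above follow from a vocabulary ordered by corpus term frequency, here b (4 occurrences), a (3), c (2), with each document mapped to per-term counts. A plain-Scala sketch of that mapping (minDF is not modelled, and the tie-breaking is illustrative only):

def countVectorize(docs: Seq[Seq[String]], vocabSize: Int): (Seq[String], Seq[Seq[Double]]) = {
  val vocab = docs.flatten.groupBy(identity).toSeq
    .sortBy { case (term, occurrences) => (-occurrences.size, term) } // most frequent terms first
    .take(vocabSize)
    .map(_._1)
  val vectors = docs.map(doc => vocab.map(term => doc.count(_ == term).toDouble))
  (vocab, vectors)
}

val (vocab, vectors) = countVectorize(Seq(Seq("a", "b", "c"), Seq("a", "b", "b", "b", "a", "c")), vocabSize = 3)
// vocab: Seq(b, a, c); vectors: Seq(Seq(1.0, 1.0, 1.0), Seq(3.0, 2.0, 1.0))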
Example 169
Source File: ScalerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ScalerTest extends FlatSpec with TestSparkContext {

  Spec[Scaler] should "error on invalid data" in {
    val error = intercept[IllegalArgumentException](
      Scaler.apply(scalingType = ScalingType.Linear, args = EmptyScalerArgs())
    )
    error.getMessage shouldBe
      s"Invalid combination of scaling type '${ScalingType.Linear}' " +
        s"and args type '${EmptyScalerArgs().getClass.getSimpleName}'"
  }

  it should "correctly construct a LinearScaler" in {
    val linearScaler = Scaler.apply(scalingType = ScalingType.Linear,
      args = LinearScalerArgs(slope = 1.0, intercept = 2.0))
    linearScaler shouldBe a[LinearScaler]
    linearScaler.scalingType shouldBe ScalingType.Linear
  }

  it should "correctly construct a LogScaler" in {
    val logScaler = Scaler.apply(scalingType = ScalingType.Logarithmic, args = EmptyScalerArgs())
    logScaler shouldBe a[LogScaler]
    logScaler.scalingType shouldBe ScalingType.Logarithmic
  }
} 
Example 170
Source File: TransmogrifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType._
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestOpVectorMetadataBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichStructType._
import com.salesforce.op._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TransmogrifierTest extends FlatSpec with PassengerSparkFixtureTest with AttributeAsserts {

  val inputFeatures = Array[OPFeature](heightNoWindow, weight, gender)

  Spec(Transmogrifier.getClass) should "return a single output feature of type vector with the correct name" in {
    val feature = inputFeatures.transmogrify()
    feature.name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector") shouldBe true
  }

  it should "return a model when fitted" in {
    val feature = inputFeatures.transmogrify()
    val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train()

    model.getResultFeatures() should contain theSameElementsAs Array(feature)
    val name = model.getResultFeatures().map(_.name).head
    name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector") shouldBe true
  }

  it should "correctly transform the data and store the feature names in metadata" in {
    val feature = inputFeatures.toSeq.transmogrify()
    val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train()
    val transformed = model.score(keepRawFeatures = true, keepIntermediateFeatures = true)
    val hist = feature.parents.flatMap { f =>
      val h = f.history()
      h.originFeatures.map(o => o -> FeatureHistory(Seq(o), h.stages))
    }.toMap
    transformed.schema.toOpVectorMetadata(feature.name) shouldEqual
      TestOpVectorMetadataBuilder.withOpNamesAndHist(
        feature.originStage,
        hist,
        (gender, "vecSet", List(IndCol(Some("OTHER")), IndCol(Some(TransmogrifierDefaults.NullString)))),
        (heightNoWindow, "vecReal", List(RootCol,
          IndColWithGroup(Some(TransmogrifierDefaults.NullString), heightNoWindow.name))),
        (weight, "vecReal", List(RootCol, IndColWithGroup(Some(TransmogrifierDefaults.NullString), weight.name)))
      )

    transformed.schema.findFields("heightNoWindow-weight_1-stagesApplied_OPVector").nonEmpty shouldBe true

    val collected = transformed.collect(feature)

    collected.head.v.size shouldEqual 6
    collected.map(_.v.toArray.toList).toSet shouldEqual
      Set(
        List(0.0, 1.0, 211.4, 1.0, 96.0, 1.0),
        List(1.0, 0.0, 172.0, 0.0, 78.0, 0.0),
        List(1.0, 0.0, 168.0, 0.0, 67.0, 0.0),
        List(1.0, 0.0, 363.0, 0.0, 172.0, 0.0),
        List(1.0, 0.0, 186.0, 0.0, 96.0, 0.0)
      )
    val field = transformed.schema(feature.name)
    assertNominal(field, Array(false, true, false, true, false, true), collected)
  }

} 
Example 171
Source File: Base64VectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.OpWorkflow
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class Base64VectorizerTest extends FlatSpec with TestSparkContext with Base64TestData with AttributeAsserts {

  "Base64Vectorizer" should "vectorize random binary data" in {
    val vec = randomBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, trackNulls = false)
    val result = new OpWorkflow().setResultFeatures(vec).transform(randomData)

    result.collect(vec) should contain theSameElementsInOrderAs
      OPVector(Vectors.dense(0.0, 0.0)) +:
        Array.fill(expectedRandom.length - 1)(OPVector(Vectors.dense(1.0, 0.0)))
  }
  it should "vectorize some real binary content" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true)
    assertVectorizer(vec, expectedMime)
  }
  it should "vectorize some real binary content with a type hint" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, typeHint = Some("application/json"))
    assertVectorizer(vec, expectedMimeJson)
  }

  def assertVectorizer(vec: FeatureLike[OPVector], expected: Seq[Text]): Unit = {
    val result = new OpWorkflow().setResultFeatures(vec).transform(realData)
    val vectors = result.collect(vec)
    val schema = result.schema(vec.name)
    assertNominal(schema, Array.fill(vectors.head.value.size)(true), vectors)

    vectors.length shouldBe expected.length
    // TODO add a more robust check
  }

} 
Example 172
Source File: DataSplitterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.tuning

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.random.RandomRDDs
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSummaryAsserts {
  import spark.implicits._

  val seed = 1234L
  val dataCount = 1000
  val trainingLimitDefault = 1E6.toLong

  val data =
    RandomRDDs.normalVectorRDD(sc, 1000, 3, seed = seed)
      .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()

  val dataSplitter = DataSplitter(seed = seed)

  Spec[DataSplitter] should "split the data in the appropriate proportion - 0.0" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.0).split(data)
    test.count() shouldBe 0
    train.count() shouldBe dataCount
  }

  it should "down-sample when the data count is above the default training limit" in {
    val numRows = trainingLimitDefault * 2
    val data =
      RandomRDDs.normalVectorRDD(sc, numRows, 3, seed = seed)
        .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()
    dataSplitter.preValidationPrepare(data)

    val dataBalanced = dataSplitter.validationPrepare(data)
    // validationPrepare calls the data sampling method, which only hits the target ratio to within
    // an epsilon, so the count is checked against that tolerance rather than an exact value
    val samplingErrorEpsilon = (0.1 * trainingLimitDefault).toLong

    dataBalanced.count() shouldBe trainingLimitDefault +- samplingErrorEpsilon
  }

  it should "set and get all data splitter params" in {
    val maxRows = dataCount / 2
    val downSampleFraction = maxRows / dataCount.toDouble

    val dataSplitter = DataSplitter()
      .setReserveTestFraction(0.0)
      .setSeed(seed)
      .setMaxTrainingSample(maxRows)
      .setDownSampleFraction(downSampleFraction)

    dataSplitter.getReserveTestFraction shouldBe 0.0
    dataSplitter.getDownSampleFraction shouldBe downSampleFraction
    dataSplitter.getSeed shouldBe seed
    dataSplitter.getMaxTrainingSample shouldBe maxRows
  }

  it should "split the data in the appropriate proportion - 0.2" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.2).split(data)
    math.abs(test.count() - 200) < 30 shouldBe true
    math.abs(train.count() - 800) < 30 shouldBe true
  }

  it should "split the data in the appropriate proportion - 0.6" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.6).split(data)
    math.abs(test.count() - 600) < 30 shouldBe true
    math.abs(train.count() - 400) < 30 shouldBe true
  }

  it should "keep the data unchanged when prepare is called" in {
    val dataCount = data.count()
    val summary = dataSplitter.preValidationPrepare(data)
    val train = dataSplitter.validationPrepare(data)
    val sampleF = trainingLimitDefault / dataCount.toDouble
    val downSampleFraction = math.min(sampleF, 1.0)
    train.collect().zip(data.collect()).foreach { case (a, b) => a shouldBe b }
    assertDataSplitterSummary(summary.summaryOpt) { s => s shouldBe DataSplitterSummary(dataCount, downSampleFraction) }
  }

} 
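The last test relies on the down-sample fraction being min(maxTrainingSample / rowCount, 1.0), so data below the training limit is kept unchanged. A one-line sketch of that calculation (assumed semantics matching the assertions above):

def downSampleFraction(rowCount: Long, maxTrainingSample: Long): Double =
  math.min(maxTrainingSample / rowCount.toDouble, 1.0)

downSampleFraction(rowCount = 2000000L, maxTrainingSample = 1000000L) // 0.5 -- large data is halved
downSampleFraction(rowCount = 1000L, maxTrainingSample = 1000000L)    // 1.0 -- small data is left as-is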
Example 173
Source File: RandomParamBuilderTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.selector

import com.salesforce.op.stages.impl.classification.{OpLogisticRegression, OpRandomForestClassifier, OpXGBoostClassifier}
import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class RandomParamBuilderTest extends FlatSpec with TestSparkContext {

  private val lr = new OpLogisticRegression()
  private val rf = new OpRandomForestClassifier()
  private val xgb = new OpXGBoostClassifier()


  Spec[RandomParamBuilder] should "build a param grid of the desired length with one param variable" in {
    val min = 0.00001
    val max = 10
    val lrParams = new RandomParamBuilder()
      .uniform(lr.regParam, min, max)
      .build(5)
    lrParams.length shouldBe 5
    lrParams.foreach(_.toSeq.length shouldBe 1)
    lrParams.foreach(_.toSeq.foreach( p => (p.value.asInstanceOf[Double] < max &&
      p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))

    val lrParams2 = new RandomParamBuilder()
      .exponential(lr.regParam, min, max)
      .build(20)
    lrParams2.length shouldBe 20
    lrParams2.foreach(_.toSeq.length shouldBe 1)
    lrParams2.foreach(_.toSeq.foreach( p => (p.value.asInstanceOf[Double] < max &&
      p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams2.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))
  }

  it should "build a param grid of the desired length with many param variables" in {
    val lrParams = new RandomParamBuilder()
      .exponential(lr.regParam, .000001, 10)
      .subset(lr.family, Seq("auto", "binomial", "multinomial"))
      .uniform(lr.maxIter, 2, 50)
      .build(23)
    lrParams.length shouldBe 23
    lrParams.foreach(_.toSeq.length shouldBe 3)
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam, lr.family, lr.maxIter))
  }

  it should "work for all param types" in {
    val xgbParams = new RandomParamBuilder()
      .subset(xgb.checkpointPath, Seq("a", "b")) // string
      .uniform(xgb.alpha, 0, 1) // double
      .uniform(xgb.missing, 0, 100) // float
      .uniform(xgb.checkpointInterval, 2, 5) // int
      .uniform(xgb.seed, 5, 1000) // long
      .uniform(xgb.useExternalMemory) // boolean
      .exponential(xgb.baseScore, 0.0001, 1) // double
      .exponential(xgb.missing, 0.000001F, 1) // float - overwrites first call
      .build(2)

    xgbParams.length shouldBe 2
    xgbParams.foreach(_.toSeq.length shouldBe 7)
    xgbParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(xgb.checkpointPath, xgb.alpha, xgb.missing,
      xgb.checkpointInterval, xgb.seed, xgb.useExternalMemory, xgb.baseScore))
  }

  it should "throw a requirement error if an improper min value is passed in for exponential scale" in {
    intercept[IllegalArgumentException]( new RandomParamBuilder()
      .exponential(xgb.baseScore, 0, 1)).getMessage() shouldBe
      "requirement failed: Min value must be greater than zero for exponential distribution to work"
  }

  it should "throw a requirement error if min is greater than max" in {
    intercept[IllegalArgumentException]( new RandomParamBuilder()
      .uniform(xgb.baseScore, 1, 0)).getMessage() shouldBe
      "requirement failed: Min must be less than max"
  }
} 
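The exponential builder requires min > 0 because it presumably samples uniformly in log space and then exponentiates, which is undefined at zero; the uniform builder only needs min < max. A plain-Scala sketch of a log-uniform draw under those assumptions (not the library's implementation):

def logUniform(min: Double, max: Double, rng: scala.util.Random): Double = {
  require(min > 0, "Min value must be greater than zero for exponential distribution to work")
  require(min < max, "Min must be less than max")
  val (lo, hi) = (math.log(min), math.log(max))
  math.exp(lo + rng.nextDouble() * (hi - lo)) // uniform in log space, exponentiated back
}

val rng = new scala.util.Random(1234L)
logUniform(1e-6, 10.0, rng) // a value in (1e-6, 10), spread evenly across orders of magnitude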
Example 174
Source File: PredictionDeIndexerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.preparators


import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryLambdaTransformer
import com.salesforce.op.stages.impl.feature.OpStringIndexerNoFilter
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PredictionDeIndexerTest extends FlatSpec with TestSparkContext {

  val data = Seq(("a", 0.0), ("b", 1.0), ("c", 2.0)).map { case (txt, num) => (txt.toText, num.toRealNN) }
  val (ds, txtF, numF) = TestFeatureBuilder(data)

  val response = txtF.indexed()
  val indexedData = response.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(ds).transform(ds)

  val permutation = new UnaryLambdaTransformer[RealNN, RealNN](
    operationName = "modulo",
    transformFn = v => ((v.value.get + 1).toInt % 3).toRealNN
  ).setInput(response)
  val pred = permutation.getOutput()
  val permutedData = permutation.transform(indexedData)

  val expected = Array("b", "c", "a").map(_.toText)

  Spec[PredictionDeIndexer] should "deindex the feature correctly" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(response, pred)
    val deIndexed = predDeIndexer.getOutput()

    val results = predDeIndexer.fit(permutedData).transform(permutedData).collect(deIndexed)
    results shouldBe expected
  }


  it should "throw a nice error when there is no metadata" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(numF, pred)
    the[Error] thrownBy {
      predDeIndexer.fit(permutedData).transform(permutedData)
    } should have message
      s"The feature ${numF.name} does not contain any label/index mapping in its metadata"
  }
} 
Example 175
Source File: MinVarianceFilterMetadataTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.preparators

import com.salesforce.op.stages.impl.preparators.MinVarianceSummary.statisticsFromMetadata
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.RichMetadata._
import org.apache.spark.sql.types.Metadata
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class MinVarianceFilterMetadataTest extends FlatSpec with TestSparkContext {

  val summary = MinVarianceSummary(
    dropped = Seq("f1"),
    featuresStatistics = SummaryStatistics(3, 0.01, Seq(0.1, 0.2, 0.3), Seq(0.1, 0.2, 0.3),
      Seq(0.1, 0.2, 0.3), Seq(0.1, 0.2, 0.3)),
    names = Seq("f1", "f2", "f3")
  )

  Spec[MinVarianceSummary] should "convert to and from metadata correctly" in {
    val meta = summary.toMetadata()
    meta.isInstanceOf[Metadata] shouldBe true

    val retrieved = MinVarianceSummary.fromMetadata(meta)
    retrieved shouldBe a[MinVarianceSummary]

    retrieved.dropped should contain theSameElementsAs summary.dropped
    retrieved.featuresStatistics.count shouldBe summary.featuresStatistics.count
    retrieved.featuresStatistics.max should contain theSameElementsAs summary.featuresStatistics.max
    retrieved.featuresStatistics.min should contain theSameElementsAs summary.featuresStatistics.min
    retrieved.featuresStatistics.mean should contain theSameElementsAs summary.featuresStatistics.mean
    retrieved.featuresStatistics.variance should contain theSameElementsAs summary.featuresStatistics.variance
    retrieved.names should contain theSameElementsAs summary.names
  }

  it should "convert to and from JSON and give the same values" in {
    val meta = summary.toMetadata()
    val json = meta.wrapped.prettyJson
    val recovered = Metadata.fromJson(json).wrapped
    val dropped = recovered.getArray[String](MinVarianceNames.Dropped).toSeq
    val stats = statisticsFromMetadata(recovered.get[Metadata](MinVarianceNames.FeaturesStatistics))
    val names = recovered.getArray[String](MinVarianceNames.Names).toSeq

    dropped should contain theSameElementsAs summary.dropped
    stats.count shouldBe summary.featuresStatistics.count
    stats.max should contain theSameElementsAs summary.featuresStatistics.max
    stats.min should contain theSameElementsAs summary.featuresStatistics.min
    stats.mean should contain theSameElementsAs summary.featuresStatistics.mean
    stats.variance should contain theSameElementsAs summary.featuresStatistics.variance
    names should contain theSameElementsAs summary.names
  }

} 
Example 176
Source File: OpRegressionModelTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.regression

import com.salesforce.op.features.types.{Prediction, RealNN}
import com.salesforce.op.stages.sparkwrappers.specific.SparkModelConverter.toOP
import com.salesforce.op.test._
import com.salesforce.op.testkit._
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoost, OpXGBoostQuietLogging, XGBoostRegressor}
import org.apache.spark.ml.regression._
import org.apache.spark.sql.DataFrame
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRegressionModelTest extends FlatSpec with TestSparkContext with OpXGBoostQuietLogging {

  private val label = RandomIntegral.integrals(0, 2).limit(1000)
    .map{ v => RealNN(v.value.map(_.toDouble).getOrElse(0.0)) }
  private val fv = RandomVector.binary(10, 0.3).limit(1000)

  private val data = label.zip(fv)

  private val (rawDF, labelF, featureV) = TestFeatureBuilder("label", "features", data)

  Spec[OpDecisionTreeRegressionModel] should "produce the same values as the spark version" in {
    val spk = new DecisionTreeRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpLinearRegressionModel] should "produce the same values as the spark version" in {
    val spk = new LinearRegression()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpGBTRegressionModel] should "produce the same values as the spark version" in {
    val spk = new GBTRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpRandomForestRegressionModel] should "produce the same values as the spark version" in {
    val spk = new RandomForestRegressor()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpGeneralizedLinearRegressionModel] should "produce the same values as the spark version" in {
    val spk = new GeneralizedLinearRegression()
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
      .fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  Spec[OpXGBoostRegressionModel] should "produce the same values as the spark version" in {
    val reg = new XGBoostRegressor()
    reg.set(reg.trackerConf, OpXGBoost.DefaultTrackerConf)
      .setFeaturesCol(featureV.name)
      .setLabelCol(labelF.name)
    val spk = reg.fit(rawDF)

    val op = toOP(spk, spk.uid).setInput(labelF, featureV)
    compareOutputs(spk.transform(rawDF), op.transform(rawDF))
  }

  def compareOutputs(df1: DataFrame, df2: DataFrame): Unit = {
    val sorted1 = df1.collect().sortBy(_.getAs[Double](2))
    val sorted2 = df2.collect().sortBy(_.getAs[Map[String, Double]](2)(Prediction.Keys.PredictionName))
    sorted1.zip(sorted2).foreach{ case (r1, r2) =>
      val map = r2.getAs[Map[String, Double]](2)
      r1.getAs[Double](2) shouldEqual map(Prediction.Keys.PredictionName)
    }
  }
} 
Example 177
Source File: OpPipelineStageReaderWriterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages

import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.OpPipelineStageReaderWriter._
import com.salesforce.op.test.PassengerSparkFixtureTest
import com.salesforce.op.utils.reflection.ReflectionUtils
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.{Model, Transformer}
import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder}
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods.{compact, parse, pretty, render}
import org.json4s.{JArray, JObject}
import org.scalatest.FlatSpec
import org.slf4j.LoggerFactory


// TODO: consider adding a read/write test for a spark wrapped stage as well
private[stages] abstract class OpPipelineStageReaderWriterTest
  extends FlatSpec with PassengerSparkFixtureTest {

  val meta = new MetadataBuilder().putString("foo", "bar").build()
  val expectedFeaturesLength = 1
  def stage: OpPipelineStageBase with Transformer
  val expected: Array[Real]
  val hasOutputName = true

  private val log = LoggerFactory.getLogger(this.getClass)
  private lazy val savePath = tempDir + "/" + this.getClass.getSimpleName + "-" + System.currentTimeMillis()
  private lazy val writer = new OpPipelineStageWriter(stage)
  private lazy val stageJsonString: String = writer.writeToJsonString(savePath)
  private lazy val stageJson: JValue = parse(stageJsonString)
  private lazy val isModel = stage.isInstanceOf[Model[_]]
  private val FN = FieldNames

  Spec(this.getClass) should "write stage uid" in {
    log.info(pretty(stageJson))
    (stageJson \ FN.Uid.entryName).extract[String] shouldBe stage.uid
  }
  it should "write class name" in {
    (stageJson \ FN.Class.entryName).extract[String] shouldBe stage.getClass.getName
  }
  it should "write params map" in {
    val params = extractParams(stageJson).extract[Map[String, Any]]
    if (hasOutputName) {
      params should have size 4
      params.keys shouldBe Set("inputFeatures", "outputMetadata", "inputSchema", "outputFeatureName")
    } else {
      params should have size 3
      params.keys shouldBe Set("inputFeatures", "outputMetadata", "inputSchema")
    }
  }
  it should "write outputMetadata" in {
    val params = extractParams(stageJson)
    val metadataStr = compact(render(params \ "outputMetadata"))
    val metadata = Metadata.fromJson(metadataStr)
    metadata shouldBe stage.getMetadata()
  }
  it should "write inputSchema" in {
    val schemaStr = compact(render(extractParams(stageJson) \ "inputSchema"))
    val schema = DataType.fromJson(schemaStr)
    schema shouldBe stage.getInputSchema()
  }
  it should "write input features" in {
    val jArray = (extractParams(stageJson) \ "inputFeatures").extract[JArray]
    jArray.values should have length expectedFeaturesLength
    val obj = jArray(0).extract[JObject]
    obj.values.keys shouldBe Set("name", "isResponse", "isRaw", "uid", "typeName", "stages", "originFeatures")
  }
  it should "write model ctor args" in {
    if (stage.isInstanceOf[Model[_]]) {
      val ctorArgs = (stageJson \ FN.CtorArgs.entryName).extract[JObject]
      val (_, args) = ReflectionUtils.bestCtorWithArgs(stage)
      ctorArgs.values.keys shouldBe args.map(_._1).toSet
    }
  }
  it should "load stage correctly" in {
    val reader = new OpPipelineStageReader(stage)
    val stageLoaded = reader.loadFromJsonString(stageJsonString, path = savePath)
    stageLoaded shouldBe a[OpPipelineStageBase]
    stageLoaded shouldBe a[Transformer]
    stageLoaded.getOutput() shouldBe a[FeatureLike[_]]
    val _ = stage.asInstanceOf[Transformer].transform(passengersDataSet)
    val transformed = stageLoaded.asInstanceOf[Transformer].transform(passengersDataSet)
    transformed.collect(stageLoaded.getOutput().asInstanceOf[FeatureLike[Real]]) shouldBe expected
    stageLoaded.uid shouldBe stage.uid
    stageLoaded.operationName shouldBe stage.operationName
    stageLoaded.getInputFeatures() shouldBe stage.getInputFeatures()
    stageLoaded.getInputSchema() shouldBe stage.getInputSchema()
  }

  private def extractParams(stageJson: JValue): JValue = {
    val defaultParamsMap = stageJson \ FN.DefaultParamMap.entryName
    val paramsMap = stageJson \ FN.ParamMap.entryName
    defaultParamsMap.merge(paramsMap)
  }

} 
Example 178
Source File: FeatureJsonHelperTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.features

import com.salesforce.op._
import com.salesforce.op.test.{PassengerFeaturesTest, TestCommon}
import org.json4s.MappingException
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FeatureJsonHelperTest extends FlatSpec with PassengerFeaturesTest with TestCommon {

  trait DifferentParents {
    val feature = height + weight
    val stages = Map(feature.originStage.uid -> feature.originStage)
    val features = Map(height.uid -> height, weight.uid -> weight)
  }

  trait SameParents {
    val feature = height + height
    val stages = Map(feature.originStage.uid -> feature.originStage)
    val features = Map(height.uid -> height, height.uid -> height)
  }

  Spec(FeatureJsonHelper.getClass) should "serialize/deserialize a feature properly" in new DifferentParents {
    val json = feature.toJson()
    val parsedFeature = FeatureJsonHelper.fromJsonString(json, stages, features)
    if (parsedFeature.isFailure) fail(s"Failed to deserialize from json: $json", parsedFeature.failed.get)

    val res = parsedFeature.get
    res shouldBe a[Feature[_]]
    res.equals(feature) shouldBe true
    res.uid shouldBe feature.uid
    res.wtt.tpe =:= feature.wtt.tpe shouldBe true
  }

  it should "deserialize a set of parent features from one reference" in new SameParents {
    val json = feature.toJson()
    val parsedFeature = FeatureJsonHelper.fromJsonString(feature.toJson(), stages, features)
    if (parsedFeature.isFailure) fail(s"Failed to deserialize from json: $json", parsedFeature.failed.get)

    val res = parsedFeature.get
    res.equals(feature) shouldBe true
    res.wtt.tpe =:= feature.wtt.tpe shouldBe true
  }

  it should "fail to deserialize invalid json" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString("{}", stages, features)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[MappingException]
  }

  it should "fail when origin stage is not found" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString(feature.toJson(), stages = Map.empty, features)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[RuntimeException]
  }

  it should "fail when not all parents are found" in new DifferentParents {
    val res = FeatureJsonHelper.fromJsonString(feature.toJson(), stages, features = Map.empty)
    res.isFailure shouldBe true
    res.failed.get shouldBe a[RuntimeException]
  }


} 
Example 179
Source File: JobGroupUtilTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.spark

import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner
import com.salesforce.op.test.{TestCommon, TestSparkContext}

@RunWith(classOf[JUnitRunner])
class JobGroupUtilTest extends FlatSpec with TestCommon with TestSparkContext {

  Spec(JobGroupUtil.getClass) should "be able to set a job group ID around a code block" in {
    JobGroupUtil.withJobGroup(OpStep.DataReadingAndFiltering) {
      spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    }
    spark.sparkContext.statusTracker.getJobIdsForGroup("DataReadingAndFiltering") should not be empty
  }

  it should "reset the job group ID after a code block" in {
    JobGroupUtil.withJobGroup(OpStep.DataReadingAndFiltering) {
      spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    }
    spark.sparkContext.parallelize(Seq(1, 2, 3, 4, 5)).collect()
    // Ensure that the last `.collect()` was not tagged with "DataReadingAndFiltering"
    spark.sparkContext.statusTracker.getJobIdsForGroup(null) should not be empty
  }
} 
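JobGroupUtil.withJobGroup is exercised above only through Spark's status tracker; a plausible shape for such a helper, sketched with the standard SparkContext job-group API (an assumption for illustration, not the library's actual code):

import org.apache.spark.SparkContext

def withJobGroup[T](sc: SparkContext, groupId: String)(block: => T): T = {
  sc.setJobGroup(groupId, s"Jobs for $groupId") // tag all jobs submitted inside the block
  try block finally sc.clearJobGroup()          // always clear the group afterwards
}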
Example 180
Source File: OpenNLPSentenceSplitterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.utils.text

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.feature.TextTokenizer
import com.salesforce.op.stages.impl.feature.TextTokenizer.TextTokenizerResult
import com.salesforce.op.test.TestCommon
import com.salesforce.op.utils.text.Language._
import opennlp.tools.sentdetect.SentenceModel
import opennlp.tools.tokenize.TokenizerModel
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpenNLPSentenceSplitterTest extends FlatSpec with TestCommon {

  val splitter = new OpenNLPSentenceSplitter()

  Spec[OpenNLPSentenceSplitter] should "split an English paragraph into sentences" in {
    val input =
      "Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov 29. " +
        "Mr Vinken is chairman of Elsevier N.V., the Dutch publishing group. Rudolph Agnew, 55 years old and " +
        "former chairman of Consolidated Gold Fields PLC, was named a director of this British industrial conglomerate."

    splitter.getSentences(input, language = English) shouldEqual Seq(
      "Pierre Vinken, 61 years old, will join the board as a nonexecutive director Nov 29.",
      "Mr Vinken is chairman of Elsevier N.V., the Dutch publishing group.",
      "Rudolph Agnew, 55 years old and former chairman of Consolidated Gold Fields PLC, " +
        "was named a director of this British industrial conglomerate."
    )

    TextTokenizer.tokenize(input.toText, sentenceSplitter = Option(splitter), defaultLanguage = English) shouldEqual
      TextTokenizerResult(English, Seq(
        Seq("pierr", "vinken", "61", "year", "old", "will", "join", "board",
          "nonexecut", "director", "nov", "29").toTextList,
        Seq("mr", "vinken", "chairman", "elsevi", "n.v", "dutch", "publish", "group").toTextList,
        Seq("rudolph", "agnew", "55", "year", "old", "former", "chairman", "consolid", "gold", "field", "plc",
          "name", "director", "british", "industri", "conglomer").toTextList))

    TextTokenizer.tokenize(input.toText, analyzer = new OpenNLPAnalyzer(), sentenceSplitter = Option(splitter),
      defaultLanguage = English) shouldEqual TextTokenizerResult(
      English, Seq(
        Seq("pierre", "vinken", ",", "61", "years", "old", ",", "will", "join", "the", "board", "as", "a",
          "nonexecutive", "director", "nov", "29", ".").toTextList,
        Seq("mr", "vinken", "is", "chairman", "of", "elsevier", "n", ".v.", ",", "the", "dutch", "publishing",
          "group", ".").toTextList,
        Seq("rudolph", "agnew", ",", "55", "years", "old", "and", "former", "chairman", "of", "consolidated",
          "gold", "fields", "plc", ",", "was", "named", "a", "director", "of", "this", "british", "industrial",
          "conglomerate", ".").toTextList))
  }

  it should "split a Portuguese text into sentences" in {
    // scalastyle:off
    val input = "Depois de Guimarães, o North Music Festival estaciona este ano no Porto. A partir de sexta-feira, " +
      "a Alfândega do Porto recebe a segunda edição deste festival de dois dias. No cartaz há nomes como os " +
      "portugueses Linda Martini e Mão Morta, mas também Guano Apes ou os DJ’s portugueses Rich e Mendes."

    splitter.getSentences(input, language = Portuguese) shouldEqual Seq(
      "Depois de Guimarães, o North Music Festival estaciona este ano no Porto.",
      "A partir de sexta-feira, a Alfândega do Porto recebe a segunda edição deste festival de dois dias.",
      "No cartaz há nomes como os portugueses Linda Martini e Mão Morta, mas também Guano Apes ou os DJ’s " +
        "portugueses Rich e Mendes."
    )
    // scalastyle:on
  }

  it should "load a sentence detection and tokenizer model for a language if they exist" in {
    val languages = Seq(Danish, Portuguese, English, Dutch, German, Sami)
    languages.map { language =>
      OpenNLPModels.getSentenceModel(language).exists(_.isInstanceOf[SentenceModel]) shouldBe true
      OpenNLPModels.getTokenizerModel(language).exists(_.isInstanceOf[TokenizerModel]) shouldBe true
    }
  }

  it should "load not a sentence detection and tokenizer model for a language if they do not exist" in {
    val languages = Seq(Japanese, Czech)
    languages.map { language =>
      OpenNLPModels.getSentenceModel(language) shouldEqual None
      OpenNLPModels.getTokenizerModel(language) shouldEqual None
    }
  }

  it should "return non-preprocessed input if no such a sentence detection model exist" in {
    // scalastyle:off
    val input = "ピエール・ヴィンケン(61歳)は、11月29日に臨時理事に就任します。" +
      "ヴィンケン氏は、オランダの出版グループであるエルゼビアN.V.の会長です。 " +
      "55歳のルドルフ・アグニュー(Rudolph Agnew、元コネチカットゴールドフィールドPLC)会長は、" +
      "この英国の産業大企業の取締役に任命されました。"
    // scalastyle:on
    splitter.getSentences(input, language = Language.Japanese) shouldEqual Seq(input)
  }
} 
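
The splitter in this test delegates sentence detection to OpenNLP, as the SentenceModel/TokenizerModel imports suggest. For orientation, a minimal hypothetical sketch of such a wrapper is shown below; the class name and model wiring are assumptions, not TransmogrifAI's actual OpenNLPSentenceSplitter internals.

import opennlp.tools.sentdetect.{SentenceDetectorME, SentenceModel}

// Hypothetical sketch: wrap a pre-loaded OpenNLP SentenceModel.
// SentenceDetectorME is not thread-safe, so real code would guard or pool it.
class SimpleSentenceSplitter(model: SentenceModel) {
  private val detector = new SentenceDetectorME(model)

  def getSentences(text: String): Seq[String] = detector.sentDetect(text).toSeq
}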
Example 181
Source File: RawFeatureFilterResultsComparison.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.filters

import com.salesforce.op.DoubleEquality
import org.scalatest.{FlatSpec, Matchers}


object RawFeatureFilterResultsComparison extends FlatSpec with Matchers with DoubleEquality {

  def compareConfig(c1: RawFeatureFilterConfig, c2: RawFeatureFilterConfig): Unit = {
    c1.minFill shouldBe c2.minFill
    c1.maxFillDifference shouldBe c2.maxFillDifference
    c1.maxFillRatioDiff shouldBe c2.maxFillRatioDiff
    c1.maxJSDivergence shouldBe c2.maxJSDivergence
    c1.maxCorrelation shouldBe c2.maxCorrelation
    c1.correlationType shouldBe c2.correlationType
    c1.jsDivergenceProtectedFeatures shouldBe c2.jsDivergenceProtectedFeatures
    c1.protectedFeatures shouldBe c2.protectedFeatures
  }

  def compareDistributions(d1: FeatureDistribution, d2: FeatureDistribution): Unit = {
    d1.name shouldEqual d2.name
    d1.key shouldEqual d2.key
    d1.count shouldEqual d2.count
    d1.nulls shouldEqual d2.nulls
    d1.distribution shouldEqual d2.distribution
    d1.summaryInfo shouldEqual d2.summaryInfo
  }

  def compareSeqDistributions(d1: Seq[FeatureDistribution], d2: Seq[FeatureDistribution]): Unit = {
    d1.zip(d2).foreach { case (a, b) => compareDistributions(a, b) }
  }

  def compareMetrics(m1: RawFeatureFilterMetrics, m2: RawFeatureFilterMetrics): Unit = {
    m1.name shouldBe m2.name
    m1.key shouldBe m2.key
    m1.trainingFillRate shouldBe m2.trainingFillRate
    m1.trainingNullLabelAbsoluteCorr shouldEqual m2.trainingNullLabelAbsoluteCorr
    m1.scoringFillRate shouldEqual m2.scoringFillRate
    m1.jsDivergence shouldEqual m2.jsDivergence
    m1.fillRateDiff shouldEqual m2.fillRateDiff
    m1.fillRatioDiff shouldEqual m2.fillRatioDiff
  }

  def compareSeqMetrics(m1: Seq[RawFeatureFilterMetrics], m2: Seq[RawFeatureFilterMetrics]): Unit = {
    m1.zip(m2).foreach { case (a, b) => compareMetrics(a, b) }
  }

  def compareExclusionReasons(er1: ExclusionReasons, er2: ExclusionReasons): Unit = {
    er1.name shouldBe er2.name
    er1.key shouldBe er2.key
    er1.trainingUnfilledState shouldBe er2.trainingUnfilledState
    er1.trainingNullLabelLeaker shouldBe er2.trainingNullLabelLeaker
    er1.scoringUnfilledState shouldBe er2.scoringUnfilledState
    er1.jsDivergenceMismatch shouldBe er2.jsDivergenceMismatch
    er1.fillRateDiffMismatch shouldBe er2.fillRateDiffMismatch
    er1.fillRatioDiffMismatch shouldBe er2.fillRatioDiffMismatch
    er1.excluded shouldBe er2.excluded
  }

  def compareSeqExclusionReasons(er1: Seq[ExclusionReasons], er2: Seq[ExclusionReasons]): Unit = {
    er1.zip(er2).foreach { case (a, b) => compareExclusionReasons(a, b) }
  }

  def compare(rff1: RawFeatureFilterResults, rff2: RawFeatureFilterResults): Unit = {
    compareConfig(rff1.rawFeatureFilterConfig, rff2.rawFeatureFilterConfig)
    compareSeqDistributions(rff1.rawFeatureDistributions, rff2.rawFeatureDistributions)
    compareSeqMetrics(rff1.rawFeatureFilterMetrics, rff2.rawFeatureFilterMetrics)
    compareSeqExclusionReasons(rff1.exclusionReasons, rff2.exclusionReasons)
  }
} 
Example 182
Source File: SummaryTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.filters

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SummaryTest extends FlatSpec with TestCommon {
  Spec[Summary] should "be correctly created from a sequence of features" in {
    val f1 = Left(Seq("a", "b", "c"))
    val f2 = Right(Seq(0.5, 1.0))
    val f1s = Summary(f1)
    val f2s = Summary(f2)
    f1s.min shouldBe 3
    f1s.max shouldBe 3
    f1s.sum shouldBe 3
    f1s.count shouldBe 1
    f2s.min shouldBe 0.5
    f2s.max shouldBe 1.0
    f2s.sum shouldBe 1.5
    f2s.count shouldBe 2
  }
} 
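
The assertions above pin down the intended contract of Summary: a categorical feature (Left) is summarized by the length of its token sequence with a count of one, while a numeric feature (Right) gets the usual min/max/sum/count statistics. A minimal sketch that would satisfy exactly these assertions, as a hypothetical illustration rather than TransmogrifAI's actual implementation:

// Hypothetical sketch of a Summary satisfying the test above.
case class Summary(min: Double, max: Double, sum: Double, count: Double)

object Summary {
  def apply(feature: Either[Seq[String], Seq[Double]]): Summary = feature match {
    case Left(tokens) =>
      val size = tokens.size.toDouble
      Summary(min = size, max = size, sum = size, count = 1.0)
    case Right(values) =>
      Summary(values.min, values.max, values.sum, values.size.toDouble)
  }
}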
Example 183
Source File: OpRegressionEvaluatorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.evaluators

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.classification.OpLogisticRegression
import com.salesforce.op.stages.impl.regression.{OpLinearRegression, RegressionModelSelector}
import com.salesforce.op.stages.impl.selector.ModelSelectorNames.EstimatorType
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.ParamGridBuilder
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRegressionEvaluatorTest extends FlatSpec with TestSparkContext {

  val (ds, rawLabel, features) = TestFeatureBuilder[RealNN, OPVector](
    Seq(
      (10.0, Vectors.dense(1.0, 4.3, 1.3)),
      (20.0, Vectors.dense(2.0, 0.3, 0.1)),
      (30.0, Vectors.dense(3.0, 3.9, 4.3)),
      (40.0, Vectors.dense(4.0, 1.3, 0.9)),
      (50.0, Vectors.dense(5.0, 4.7, 1.3)),
      (10.0, Vectors.dense(1.0, 4.3, 1.3)),
      (20.0, Vectors.dense(2.0, 0.3, 0.1)),
      (30.0, Vectors.dense(3.0, 3.9, 4.3)),
      (40.0, Vectors.dense(4.0, 1.3, 0.9)),
      (50.0, Vectors.dense(5.0, 4.7, 1.3))
    ).map(v => v._1.toRealNN -> v._2.toOPVector)
  )

  val label = rawLabel.copy(isResponse = true)

  val lr = new OpLogisticRegression()
  val lrParams = new ParamGridBuilder().addGrid(lr.regParam, Array(0.0)).build()

  val testEstimator = RegressionModelSelector.withTrainValidationSplit(dataSplitter = None, trainRatio = 0.5,
    modelsAndParameters = Seq(lr -> lrParams))
    .setInput(label, features)

  val prediction = testEstimator.getOutput()
  val testEvaluator = new OpRegressionEvaluator().setLabelCol(label).setPredictionCol(prediction)

  val testEstimator2 = new OpLinearRegression().setInput(label, features)

  val prediction2 = testEstimator2.getOutput()
  val testEvaluator2 = new OpRegressionEvaluator().setLabelCol(label).setPredictionCol(prediction2)


  Spec[OpRegressionEvaluator] should "copy" in {
    val testEvaluatorCopy = testEvaluator.copy(ParamMap())
    testEvaluatorCopy.uid shouldBe testEvaluator.uid
  }

  it should "evaluate the metrics from a model selector" in {
    val model = testEstimator.fit(ds)
    val transformedData = model.setInput(label, features).transform(ds)
    val metrics = testEvaluator.evaluateAll(transformedData).toMetadata()

    assert(metrics.getDouble(RegressionEvalMetrics.RootMeanSquaredError.toString) <= 1E-12, "rmse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanSquaredError.toString) <= 1E-24, "mse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.R2.toString) == 1.0, "R2 should equal 1.0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanAbsoluteError.toString) <= 1E-12, "mae should be close to 0")
  }

  it should "evaluate the metrics from a single model" in {
    val model = testEstimator2.fit(ds)
    val transformedData = model.setInput(label, features).transform(ds)
    val metrics = testEvaluator2.evaluateAll(transformedData).toMetadata()

    assert(metrics.getDouble(RegressionEvalMetrics.RootMeanSquaredError.toString) <= 1E-12, "rmse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanSquaredError.toString) <= 1E-24, "mse should be close to 0")
    assert(metrics.getDouble(RegressionEvalMetrics.R2.toString) == 1.0, "R2 should equal 1.0")
    assert(metrics.getDouble(RegressionEvalMetrics.MeanAbsoluteError.toString) <= 1E-12, "mae should be close to 0")
  }
} 
Example 184
Source File: ScalaStyleValidationTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class ScalaStyleValidationTest extends FlatSpec with Matchers with Assertions {
  import scala.Throwable
  private def +(x: Int, y: Int) = x + y
  private def -(x: Int, y: Int) = x - y
  private def *(x: Int, y: Int) = x * y
  private def /(x: Int, y: Int) = x / y
  private def +-(x: Int, y: Int) = x + (-y)
  private def xx_=(y: Int) = println(s"setting xx to $y")

  "bad names" should "never happen" in {
    "def _=abc = ???" shouldNot compile
    true shouldBe true
  }

  "non-ascii" should "not be allowed" in {
//    "def ⇒ = ???" shouldNot compile // it does not even compile as a string
  }

} 
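
The noteworthy ScalaTest feature here is the compile-time assertion: Matchers can assert that a snippet of code, written as a string literal, fails to compile. A small self-contained illustration, independent of this project's style rules:

import org.scalatest.{FlatSpec, Matchers}

// Minimal illustration of ScalaTest's compile-time assertions.
class CompileCheckSpec extends FlatSpec with Matchers {
  "an ill-typed expression" should "fail to type check" in {
    "val x: Int = \"not an int\"" shouldNot typeCheck // parses, but is ill-typed
  }

  "a syntactically broken expression" should "fail to compile" in {
    "val x: = 1" shouldNot compile // does not even parse
  }
}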
Example 185
Source File: RandomListTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import com.salesforce.op.testkit.RandomList.{NormalGeolocation, UniformGeolocation}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomListTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPList[D]](
    g: RandomList[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.length < minLen) shouldBe 0
    segment count (_.value.length > maxLen) shouldBe 0
    segment foreach (list => list.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[Text, RandomList[String, TextList]] should "generate lists of strings" in {
    val sut = RandomList.ofTexts(RandomText.countries, 0, 4)
    check[String, TextList](sut, 0, 4, _.length > 0)

    (sut limit 7 map (_.value.toList)) shouldBe
      List(
        List("Madagascar", "Gondal", "Zephyria"),
        List("Holy Alliance"),
        List("North American Union"),
        List("Guatemala", "Estonia", "Kolechia"),
        List(),
        List("Myanmar", "Bhutan"),
        List("Equatorial Guinea")
      )
  }

  Spec[Date, RandomList[Long, DateList]] should "generate lists of dates" in {
    val dates = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDates(dates, 11, 22)
    var d0 = 0L
    check[Long, DateList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[DateTimeList, RandomList[Long, DateTimeList]] should "generate lists of datetimes" in {
    val datetimes = RandomIntegral.datetimes(df.parse("01/01/2017"), 1000, 1000000)
    val sut = RandomList.ofDateTimes(datetimes, 11, 22)
    var d0 = 0L
    check[Long, DateTimeList](sut, 11, 22, d => {
      val d1 = d0
      d0 = d
      d > d1
    })
  }

  Spec[UniformGeolocation] should "generate uniformly distributed geolocations" in {
    val sut = RandomList.ofGeolocations
    val segment = sut limit numTries
    segment foreach (_.value.length shouldBe 3)
  }

  Spec[NormalGeolocation] should "generate geolocations around given point" in {
    for {accuracy <- GeolocationAccuracy.values} {
      val geolocation = RandomList.ofGeolocationsNear(37.444136, 122.163160, accuracy)
      val segment = geolocation limit numTries
      segment foreach (_.value.length shouldBe 3)
    }
  }
} 
Example 186
Source File: RandomRealTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types.{Currency, Percent, Real, RealNN}
import com.salesforce.op.test.TestCommon
import com.salesforce.op.testkit.RandomReal._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RandomRealTest extends FlatSpec with TestCommon {
  val numTries = 1000000

  
  // ignore should "cast to default data type" in {
  //  check(uniform(1.0, 2.0), probabilityOfEmpty = 0.5, range = (1.0, 2.0))
  // }

  Spec[RandomReal[Real]] should "Give Normal distribution with mean 1, sigma 0.2, 10% nulls" in {
    val normalReals = normal[Real](1.0, 0.2)
    check(normalReals, probabilityOfEmpty = 0.1, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Real]] should "Give Uniform distribution on 1..2, half nulls" in {
    check(uniform[Real](1.0, 2.0), probabilityOfEmpty = 0.5, range = (1.0, 2.0))
  }

  it should "Give Poisson distribution with mean 4, 20% nulls" in {
    check(poisson[Real](4.0), probabilityOfEmpty = 0.2, range = (0.0, 15.0))
  }

  it should "Give Exponential distribution with mean 1, 1% nulls" in {
    check(exponential[Real](1.0), probabilityOfEmpty = 0.01, range = (0.0, 15.0))
  }

  it should "Give Gamma distribution with mean 5, 0% nulls" in {
    check(gamma[Real](5.0), probabilityOfEmpty = 0.0, range = (0.0, 25.0))
  }

  it should "Give LogNormal distribution with mean 0.25, 20% nulls" in {
    check(logNormal[Real](0.25, 0.001), probabilityOfEmpty = 0.7, range = (0.1, 15.0))
  }

  it should "Weibull distribution (4.0, 5.0), 20% nulls" in {
    check(weibull[Real](4.0, 5.0), probabilityOfEmpty = 0.2, range = (0.0, 15.0))
  }

  Spec[RandomReal[RealNN]] should "give no nulls" in {
    check(normal[RealNN](1.0, 0.2), probabilityOfEmpty = 0.0, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Currency]] should "distribute money normally" in {
    check(normal[Currency](1.0, 0.2), probabilityOfEmpty = 0.5, range = (-2.0, 4.0))
  }

  Spec[RandomReal[Percent]] should "distribute percentage evenly" in {
    check(uniform[Percent](1.0, 2.0), probabilityOfEmpty = 0.5, range = (0.0, 2.0))
  }

  private val rngSeed = 7688721

  private def check[T <: Real](
    src: RandomReal[T],
    probabilityOfEmpty: Double,
    range: (Double, Double)) = {
    val sut = src withProbabilityOfEmpty probabilityOfEmpty
    sut reset rngSeed

    val found = sut.next
    sut reset rngSeed
    val foundAfterReseed = sut.next
    if (foundAfterReseed != found) {
      sut.reset(rngSeed)
    }
    withClue(s"generator reset did not work for $sut") {
      foundAfterReseed shouldBe found
    }
    sut reset rngSeed

    val numberOfNulls = sut limit numTries count (_.isEmpty)

    val expectedNumberOfNulls = probabilityOfEmpty * numTries
    withClue(s"numNulls = $numberOfNulls, expected $expectedNumberOfNulls") {
      math.abs(numberOfNulls - expectedNumberOfNulls) < numTries / 100 shouldBe true
    }

    val numberOfOutliers = sut limit numTries count (xOpt => xOpt.value.exists(x => x < range._1 || x > range._2))

    numberOfOutliers should be < (numTries / 1000)

  }
} 
Example 187
Source File: InfiniteStreamTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class InfiniteStreamTest extends FlatSpec with TestCommon {

  Spec[InfiniteStream[_]] should "map" in {
    var i = 0
    val src = new InfiniteStream[Int] {
      override def next: Int = {
        i += 1
        i
      }
    }

    val sut = src map (5 +)

    while (i < 10) sut.next shouldBe (i + 5)
  }

} 
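
The test captures the essence of the testkit's stream type: an endless supply of values plus a map combinator. A minimal sketch of such a type follows (named InfiniteStreamSketch to make clear it is a hypothetical stand-in; the real class also offers limit, reset and friends used in the other examples):

// Hypothetical sketch sufficient for a map test like the one above.
abstract class InfiniteStreamSketch[T] { self =>
  def next: T

  def map[U](f: T => U): InfiniteStreamSketch[U] = new InfiniteStreamSketch[U] {
    def next: U = f(self.next)
  }
}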
Example 188
Source File: RandomIntegralTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import java.text.SimpleDateFormat

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps

@RunWith(classOf[JUnitRunner])
class RandomIntegralTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142135L

  private def check[T <: Integral](
    g: RandomIntegral[T],
    predicate: Long => Boolean = _ => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    val numberOfEmpties = segment count (_.isEmpty)

    val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries

    withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") {
      math.abs(numberOfEmpties - expectedNumberOfEmpties) < 2 * math.sqrt(numTries) shouldBe true
    }

    val maybeValues = segment filterNot (_.isEmpty) map (_.value)
    val values = maybeValues collect { case Some(s) => s }

    values foreach (x => predicate(x) shouldBe true)

    withClue(s"number of distinct values = ${values.size}, expected:") {
      math.abs(maybeValues.size - values.toSet.size) < maybeValues.size / 20
    }

  }

  private val df = new SimpleDateFormat("dd/MM/yy")

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers" in {
    val sut0 = RandomIntegral.integrals
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Integral]] should "generate empties and distinct numbers in some range" in {
    val sut0 = RandomIntegral.integrals(100, 200)
    val sut = sut0.withProbabilityOfEmpty(0.3)
    check(sut, i => i >= 100 && i < 200)
    sut.probabilityOfEmpty shouldBe 0.3
  }

  Spec[RandomIntegral[Date]] should "generate dates" in {
    val sut = RandomIntegral.dates(df.parse("01/01/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.01, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }

  Spec[RandomIntegral[DateTime]] should "generate dates with times" in {
    val sut = RandomIntegral.datetimes(df.parse("24/08/2017"), 1000, 1000000)
    var d0 = 0L
    check(sut withProbabilityOfEmpty 0.001, d => {
      val d1 = d0
      d0 = d
      d0 > d1
    })
  }
} 
Example 189
Source File: RandomBinaryTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types.Binary
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomBinaryTest extends FlatSpec with TestCommon {
  val numTries = 1000000
  val rngSeed = 12345

  private def truthWithProbability(probabilityOfTrue: Double) = {
    RandomBinary(probabilityOfTrue)
  }

  Spec[RandomBinary] should "generate empties, truths and falses" in {
    check(truthWithProbability(0.5) withProbabilityOfEmpty 0.5)
    check(truthWithProbability(0.3) withProbabilityOfEmpty 0.65)
    check(truthWithProbability(0.0) withProbabilityOfEmpty 0.1)
    check(truthWithProbability(1.0) withProbabilityOfEmpty 0.0)
  }

  private def check(g: RandomBinary) = {
    g reset rngSeed
    val numberOfEmpties = g limit numTries count (_.isEmpty)
    val expectedNumberOfEmpties = g.probabilityOfEmpty * numTries
    withClue(s"numEmpties = $numberOfEmpties, expected $expectedNumberOfEmpties") {
      math.abs(numberOfEmpties - expectedNumberOfEmpties) < numTries / 100 shouldBe true
    }

    val expectedNumberOfTruths = g.probabilityOfSuccess * (1 - g.probabilityOfEmpty) * numTries
    val numberOfTruths = g limit numTries count (Binary(true) ==)
    withClue(s"numTruths = $numberOfTruths, expected $expectedNumberOfTruths") {
      math.abs(numberOfTruths - expectedNumberOfTruths) < numTries / 100 shouldBe true
    }
  }
} 
Example 190
Source File: RandomSetTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.testkit

import com.salesforce.op.features.types._
import com.salesforce.op.test.TestCommon
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec}

import scala.language.postfixOps


@RunWith(classOf[JUnitRunner])
class RandomSetTest extends FlatSpec with TestCommon with Assertions {
  private val numTries = 10000
  private val rngSeed = 314159214142136L

  private def check[D, T <: OPSet[D]](
    g: RandomSet[D, T],
    minLen: Int, maxLen: Int,
    predicate: (D => Boolean) = (_: D) => true
  ) = {
    g reset rngSeed

    def segment = g limit numTries

    segment count (_.value.size < minLen) shouldBe 0
    segment count (_.value.size > maxLen) shouldBe 0
    segment foreach (set => set.value foreach { x =>
      predicate(x) shouldBe true
    })
  }

  Spec[MultiPickList] should "generate multipicklists" in {
    val sut = RandomMultiPickList.of(RandomText.countries, maxLen = 5)

    check[String, MultiPickList](sut, 0, 5, _.nonEmpty)

    val expected = List(
      Set(),
      Set("Aldorria", "Palau", "Glubbdubdrib"),
      Set(),
      Set(),
      Set("Sweden", "Wuhu Islands", "Tuvalu")
    )

    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected

    {sut reset 42; sut limit 5 map (_.value)} shouldBe expected
  }

} 
Example 191
Source File: TypeConversionTest.scala    From spark-hbase-connector   with Apache License 2.0 5 votes vote down vote up
package com.user.integration

import java.util.UUID

import it.nerdammer.spark.hbase._
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

class TypeConversionTest extends FlatSpec with Matchers with BeforeAndAfterAll  {

  val tables = Seq(UUID.randomUUID().toString, UUID.randomUUID().toString)
  val columnFamilies = Seq("cfconv", "cfconv2")

  override def beforeAll() = {
    (tables zip columnFamilies) foreach (t => {
      IntegrationUtils.createTable(t._1, t._2)
    })
  }

  override def afterAll() = {
    tables.foreach(table => IntegrationUtils.dropTable(table))
  }

  "type conversion" should "work" in {

    val sc = IntegrationUtils.sparkContext

    sc.parallelize(1 to 100)
      .map(i => (i.toString, i, i.toShort, i.toLong, i % 2 == 0, i.toDouble, i.toFloat, BigDecimal(i), i.toString))
      .toHBaseTable(tables(0)).toColumns("col-int", "col-sho", "col-lon", "col-boo", "col-dou", "col-flo", "col-big", "col-str")
      .inColumnFamily(columnFamilies(0))
      .save()


    val retrieved = sc.hbaseTable[(String, Int, Short, Long, Boolean, Double, Float, BigDecimal, String)](tables(0))
      .select("col-int", "col-sho", "col-lon", "col-boo", "col-dou", "col-flo", "col-big", "col-str")
      .inColumnFamily(columnFamilies(0))
      .sortBy(_._1.toInt)
      .collect()

    val cmp = (1 to 100) zip retrieved

    cmp.foreach(p => {
      p._1 should be(p._2._2)
      p._1.toShort should be(p._2._3)
      p._1.toLong should be(p._2._4)
      (p._1 % 2 == 0) should be(p._2._5)
      p._1.toDouble should be(p._2._6)
      p._1.toFloat should be(p._2._7)
      BigDecimal(p._1) should be(p._2._8)
      p._1.toString should be(p._2._9)
    })

  }

  "type conversion" should "support empty values" in {

    val sc = IntegrationUtils.sparkContext

    sc.parallelize(1 to 100)
      .map(i => (i.toString, i, None.asInstanceOf[Option[Short]]))
      .toHBaseTable(tables(1))
      .inColumnFamily(columnFamilies(1))
      .toColumns("myint", "myshort")
      .save()


    val chk = sc.hbaseTable[(String, Option[Int], Option[Short], Option[Long], Option[Boolean], Option[Double], Option[Float], Option[BigDecimal], Option[String])](tables(1))
      .inColumnFamily(columnFamilies(1))
      .select("myint", "myshort", "mynonexistentlong", "mynonexistentbool", "mynonexistentdouble", "mynonexistentfloat", "mynonexistentbigd", "mynonexistentstr")
      .filter {
        case (_, Some(_), None, None, None, None, None, None, None) => true
        case _ => false
      }
      .count

    chk should be (100)

  }

} 
Example 192
Source File: HostAndPortSpec.scala    From akka-persistence-redis   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.utils

import org.scalatest.{ FlatSpec, Matchers }

class HostAndPortSpec extends FlatSpec with Matchers {

  "A HostAndPort" should "have localhost as host and 8080 as port" in {
    val hostAndPort = HostAndPort("localhost:8080")
    assert(hostAndPort.host == "localhost")
    assert(hostAndPort.port == 8080)
  }

  it should "return a tuple (String, Int)" in {
    val hostAndPort = HostAndPort("my.host.name:10001")
    assert(hostAndPort.asTuple == ("my.host.name", 10001))
  }

  it should "return 8080 instead of 9090" in {
    val hostAndPort = HostAndPort("localhost:8080")
    assert(hostAndPort.portOrDefault(9090) == 8080)
  }

} 
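
The spec fully determines the observable behaviour of HostAndPort. A hypothetical sketch that would pass all three tests is shown below; the real akka-persistence-redis class may represent a missing port differently.

// Hypothetical sketch; a missing port is modelled as -1 so that
// portOrDefault has something to fall back on.
case class HostAndPort(host: String, port: Int) {
  def asTuple: (String, Int) = (host, port)
  def portOrDefault(default: Int): Int = if (port > 0) port else default
}

object HostAndPort {
  def apply(hostAndPort: String): HostAndPort =
    hostAndPort.split(":", 2) match {
      case Array(host, port) => HostAndPort(host, port.toInt)
      case Array(host)       => HostAndPort(host, -1)
    }
}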
Example 193
Source File: RedisUtilsSpec.scala    From akka-persistence-redis   with Apache License 2.0 5 votes vote down vote up
package akka.persistence.redis

import com.typesafe.config.ConfigFactory
import org.scalatest.{ FlatSpec, Matchers }

class RedisUtilsSpec extends FlatSpec with Matchers {

  "A RedisClient" should "use redis.sentinels as config if sentinel-list is also present" in {
    val config = ConfigFactory.parseString(s"""
         |redis {
         |  mode = sentinel
         |  master = foo
         |  database = 0
         |  sentinel-list = "host1:1234,host2:1235"
         |  sentinels = [
         |   {
         |     host = "host3"
         |     port = 1236
         |    },
         |    {
         |      host = "host4"
         |      port = 1237
         |     }
         |   ]
         |}""".stripMargin)
    val sentinels = RedisUtils.sentinels(config)
    sentinels should be(List(("host3", 1236), ("host4", 1237)))
  }

  it should "use redis.sentinel-list if redis.sentinels is not present" in {
    val config = ConfigFactory.parseString(s"""
      |redis {
      |  mode = sentinel
      |  master = foo
      |  database = 0
      |  sentinel-list = "host1:1234,host2:1235"
      |}""".stripMargin)
    val sentinels = RedisUtils.sentinels(config)
    sentinels should be(List(("host1", 1234), ("host2", 1235)))
  }

} 
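
Both tests describe a precedence rule: a structured redis.sentinels block wins over the comma-separated redis.sentinel-list string. A sketch of that rule against the Typesafe Config API follows; RedisUtilsSketch is a hypothetical stand-in, not the library's actual RedisUtils.

import com.typesafe.config.Config
import scala.collection.JavaConverters._

object RedisUtilsSketch {
  // Prefer the structured list of {host, port} objects; otherwise parse
  // the "host1:1234,host2:1235" style string.
  def sentinels(config: Config): List[(String, Int)] =
    if (config.hasPath("redis.sentinels")) {
      config.getConfigList("redis.sentinels").asScala.toList
        .map(c => (c.getString("host"), c.getInt("port")))
    } else {
      config.getString("redis.sentinel-list").split(",").toList
        .map(_.trim.split(":", 2))
        .map { case Array(host, port) => (host, port.toInt) }
    }
}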
Example 194
Source File: FilterUDFSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine

import org.scalatest.{FlatSpec, Matchers}

class FilterUDFSpec extends FlatSpec with Matchers with BaseSivaSpec with BaseSparkSpec {

  var engine: Engine = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    engine = Engine(ss, resourcePath, "siva")
  }

  "Filter by language" should "work properly" in {
    val langDf = engine
      .getRepositories
      .getReferences
      .getCommits
      .getBlobs
      .classifyLanguages

    val filteredLang = langDf.select("repository_id", "path", "lang").where("lang='Python'")
    filteredLang.count() should be(6)
  }

  override protected def afterAll(): Unit = {
    super.afterAll()
    engine = null
  }
} 
Example 195
Source File: RepositoryRDDProviderSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.provider

import java.nio.file.{Path, Paths}
import java.util.UUID

import org.apache.commons.io.FileUtils
import org.scalatest.{BeforeAndAfterEach, FlatSpec, Matchers}
import tech.sourced.engine.util.RepoUtils
import tech.sourced.engine.{BaseSivaSpec, BaseSparkSpec}

class RepositoryRDDProviderSpec extends FlatSpec with Matchers with BeforeAndAfterEach
  with BaseSparkSpec with BaseSivaSpec {

  private var provider: RepositoryRDDProvider = _
  private var tmpPath: Path = _

  override def beforeEach(): Unit = {
    super.beforeEach()
    provider = RepositoryRDDProvider(ss.sparkContext)
    tmpPath = Paths.get(
      System.getProperty("java.io.tmpdir"),
      UUID.randomUUID().toString
    )
  }

  override def afterEach(): Unit = {
    super.afterEach()

    FileUtils.deleteQuietly(tmpPath.toFile)
  }

  "RepositoryRDDProvider" should "retrieve bucketized raw repositories" in {
    tmpPath.resolve("a").toFile.mkdir()
    createRepo(tmpPath.resolve("a").resolve("repo"))

    tmpPath.resolve("b").toFile.mkdir()
    createRepo(tmpPath.resolve("b").resolve("repo"))

    createRepo(tmpPath.resolve("repo"))

    val repos = provider.get(tmpPath.toString, "standard").collect()
    repos.length should be(3)
  }

  it should "retrieve non-bucketized raw repositories" in {
    tmpPath.resolve("a").toFile.mkdir()
    createRepo(tmpPath.resolve("repo"))

    tmpPath.resolve("b").toFile.mkdir()
    createRepo(tmpPath.resolve("repo2"))

    val repos = provider.get(tmpPath.toString, "standard").collect()
    repos.length should be(2)
  }

  it should "retrieve bucketized siva repositories" in {
    val repos = provider.get(resourcePath, "siva").collect()
    repos.length should be(3)
  }

  it should "retrieve non-bucketized siva repositories" in {
    val repos = provider.get(Paths.get(resourcePath, "ff").toString, "siva").collect()
    repos.length should be(1)
  }

  private def createRepo(path: Path) = {
    val repo = RepoUtils.createRepo(path)
    RepoUtils.commitFile(repo, "file.txt", "something something", "some commit")
  }

} 
Example 196
Source File: StorageLevelSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine

import org.scalatest.{FlatSpec, Matchers}

class StorageLevelSpec extends FlatSpec with Matchers with BaseSivaSpec with BaseSparkSpec {

  var engine: Engine = _

  override protected def beforeAll(): Unit = {
    super.beforeAll()
    engine = Engine(ss, resourcePath, "siva")
  }

  "A Dataframe" should "work with all storage levels" in {
    import org.apache.spark.storage.StorageLevel._
    val storageLevels = List(
      DISK_ONLY,
      DISK_ONLY_2,
      MEMORY_AND_DISK,
      MEMORY_AND_DISK_2,
      MEMORY_AND_DISK_SER,
      MEMORY_AND_DISK_SER_2,
      MEMORY_ONLY,
      MEMORY_ONLY_2,
      MEMORY_ONLY_SER,
      MEMORY_ONLY_SER_2,
      NONE,
      OFF_HEAP
    )

    storageLevels.foreach(level => {
      val df = engine.getRepositories.persist(level)
      df.count()
      df.unpersist()
    })
  }
} 
Example 197
Source File: FilterSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.util

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType
import org.scalatest.{FlatSpec, Matchers}

class FilterSpec extends FlatSpec with Matchers {
  "CompiledFilters" should "filter properly depending of his type" in {
    val eq = EqualFilter(Attr("test", ""), "a")

    eq.eval("a") should be(true)
    eq.eval("b") should be(false)

    val notEq = NotFilter(EqualFilter(Attr("test", ""), "a"))

    notEq.eval("a") should be(false)
    notEq.eval("b") should be(true)

    val in = InFilter(Attr("test", ""), Array("a", "b", "c"))

    in.eval("a") should be(true)
    in.eval("b") should be(true)
    in.eval("c") should be(true)
    in.eval("d") should be(false)

    val gt = GreaterThanFilter(Attr("test", ""), 5)

    gt.eval(4) should be(false)
    gt.eval(5) should be(false)
    gt.eval(6) should be(true)

    val gte = GreaterThanOrEqualFilter(Attr("test", ""), 5)

    gte.eval(4) should be(false)
    gte.eval(5) should be(true)
    gte.eval(6) should be(true)

    val lt = LessThanFilter(Attr("test", ""), 5)

    lt.eval(4) should be(true)
    lt.eval(5) should be(false)
    lt.eval(6) should be(false)

    val lte = LessThanOrEqualFilter(Attr("test", ""), 5)

    lte.eval(4) should be(true)
    lte.eval(5) should be(true)
    lte.eval(6) should be(false)
  }

  "ColumnFilter" should "process correctly columns" in {
    // test = 'val' AND test IS NOT NULL AND test2 = 'val2' AND test3 IN ('a', 'b')
    val f = Filter.compile(And(
      And(
        And(
          EqualTo(AttributeReference("test", StringType)(), Literal("val")),
          IsNotNull(AttributeReference("test", StringType)())
        ),
        EqualTo(AttributeReference("test2", StringType)(), Literal("val2"))
      ),
      In(AttributeReference("test3", StringType)(), Seq(Literal("a"), Literal("b")))
    ))

    f.length should be(4)
    val filters = Filters(f)
    filters.matches(Seq("test"), "val") should be(true)
    filters.matches(Seq("test2"), "val") should be(false)
    filters.matches(Seq("test3"), "b") should be(true)
  }

  "ColumnFilter" should "handle correctly unsupported filters" in {
    val f = Filter.compile(StartsWith(AttributeReference("test", StringType)(), Literal("a")))

    f.length should be(0)
  }
} 
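
The first test reads like an executable definition of the compiled filter hierarchy: each filter carries an attribute and evaluates a single value. A hypothetical sketch of a few of those cases is shown below; the in/less-than variants follow the same pattern, and the real engine additionally implements Filter.compile and Filters.matches.

// Hypothetical sketch of the compiled filters exercised above.
case class Attr(name: String, table: String)

sealed trait CompiledFilter { def eval(value: Any): Boolean }

case class EqualFilter(attr: Attr, expected: Any) extends CompiledFilter {
  def eval(value: Any): Boolean = value == expected
}

case class NotFilter(inner: CompiledFilter) extends CompiledFilter {
  def eval(value: Any): Boolean = !inner.eval(value)
}

case class GreaterThanFilter(attr: Attr, threshold: Int) extends CompiledFilter {
  def eval(value: Any): Boolean = value match {
    case n: Int => n > threshold
    case _      => false
  }
}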
Example 198
Source File: ReferenceIteratorSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.iterator

import org.scalatest.FlatSpec
import tech.sourced.engine.util.{Attr, EqualFilter}

class ReferenceIteratorSpec extends FlatSpec with BaseChainableIterator {

  "ReferenceIterator" should "return all references from all repositories into a siva file" in {
    testIterator(
      new ReferenceIterator(Array("repository_id", "name", "hash"), _, null, Seq(), false), {
        case (0, row) =>
          row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/HEAD")
          row.getString(2) should be("fff7062de8474d10a67d417ccea87ba6f58ca81d")
        case (1, row) =>
          row.getString(0) should be("github.com/mawag/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/HEAD")
          row.getString(2) should be("fff7062de8474d10a67d417ccea87ba6f58ca81d")
        case (2, row) =>
          row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/develop")
          row.getString(2) should be("880653c14945dbbc915f1145561ed3df3ebaf168")
        case _ =>
      }, total = 43, columnsCount = 3
    )
  }

  it should "return only specified columns" in {
    testIterator(
      new ReferenceIterator(Array("repository_id", "name"), _, null, Seq(), false), {
        case (0, row) =>
          row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/HEAD")
        case (1, row) =>
          row.getString(0) should be("github.com/mawag/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/HEAD")
        case (2, row) =>
          row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/develop")
        case _ =>
      }, total = 43, columnsCount = 2
    )
  }

  it should "apply passed filters" in {
    testIterator(
      new ReferenceIterator(
        Array("repository_id", "name"),
        _,
        null,
        Seq(EqualFilter(Attr("name", "references"), "refs/heads/develop")),
        false
      ), {
        case (0, row) =>
          row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/develop")
        case (1, row) =>
          row.getString(0) should be("github.com/mawag/faq-xiyoulinux")
          row.getString(1) should be("refs/heads/develop")
      }, total = 2, columnsCount = 2
    )
  }

  it should "use previously passed iterator" in {
    testIterator(repo =>
      new ReferenceIterator(
        Array("repository_id", "name"),
        repo,
        new RepositoryIterator(
          "/foo/bar",
          Array("id"),
          repo,
          Seq(EqualFilter(Attr("id", "repository"), "github.com/xiyou-linuxer/faq-xiyoulinux")),
          false
        ),
        Seq(EqualFilter(Attr("name", "references"), "refs/heads/develop")),
        false
      ), {
      case (0, row) =>
        row.getString(0) should be("github.com/xiyou-linuxer/faq-xiyoulinux")
        row.getString(1) should be("refs/heads/develop")
    }, total = 1, columnsCount = 2
    )
  }
} 
Example 199
Source File: MetadataIteratorSpec.scala    From jgit-spark-connector   with Apache License 2.0 5 votes vote down vote up
package tech.sourced.engine.iterator

import java.nio.file.Paths
import java.util.{Properties, UUID}

import org.apache.commons.io.FileUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{Metadata, StringType, StructType}
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}
import tech.sourced.engine.{BaseSparkSpec, Schema}

class JDBCQueryIteratorSpec
  extends FlatSpec with Matchers with BeforeAndAfterAll with BaseSparkSpec {
  private val tmpPath = Paths.get(
    System.getProperty("java.io.tmpdir"),
    UUID.randomUUID.toString
  )

  private val dbPath = tmpPath.resolve("test.db")

  override def beforeAll(): Unit = {
    super.beforeAll()
    tmpPath.toFile.mkdir()
    val rdd = ss.sparkContext.parallelize(Seq(
      Row("id1"),
      Row("id2"),
      Row("id3")
    ))

    val properties = new Properties()
    properties.put("driver", "org.sqlite.JDBC")
    val df = ss.createDataFrame(rdd, StructType(Seq(Schema.repositories.head)))
    df.write.jdbc(s"jdbc:sqlite:${dbPath.toString}", "repositories", properties)
  }

  override def afterAll(): Unit = {
    super.afterAll()
    FileUtils.deleteQuietly(tmpPath.toFile)
  }

  "JDBCQueryIterator" should "return all rows for the query" in {
    val iter = new JDBCQueryIterator(
      Seq(attr("id")),
      dbPath.toString,
      "SELECT id FROM repositories ORDER BY id"
    )

    // calling hasNext more than one time does not cause rows to be lost
    iter.hasNext
    iter.hasNext
    val rows = (for (row <- iter) yield row).toArray
    rows.length should be(3)
    rows(0).length should be(1)
    rows(0)(0).toString should be("id1")
    rows(1)(0).toString should be("id2")
    rows(2)(0).toString should be("id3")
  }

  private def attr(name: String): Attribute = AttributeReference(
    name, StringType, nullable = false, Metadata.empty
  )()
} 
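
The repeated hasNext calls in this test point at the tricky part of wrapping JDBC in an Iterator: hasNext must be idempotent and advance the ResultSet at most once per returned row. A hypothetical sketch of that pattern follows (it assumes the SQLite JDBC driver is on the classpath and is not the engine's actual JDBCQueryIterator):

import java.sql.{DriverManager, ResultSet}

// Hypothetical sketch: expose a JDBC result set as Iterator[Seq[Any]] with an
// idempotent hasNext, so repeated hasNext calls never skip rows.
class ResultSetIterator(dbPath: String, query: String) extends Iterator[Seq[Any]] {
  private val connection = DriverManager.getConnection(s"jdbc:sqlite:$dbPath")
  private val resultSet: ResultSet = connection.createStatement().executeQuery(query)
  private val columnCount = resultSet.getMetaData.getColumnCount
  private var advanced = false // has the cursor already been moved for the next row?
  private var hasRow = false

  override def hasNext: Boolean = {
    if (!advanced) {
      hasRow = resultSet.next()
      advanced = true
      if (!hasRow) connection.close()
    }
    hasRow
  }

  override def next(): Seq[Any] = {
    if (!hasNext) throw new NoSuchElementException("no more rows")
    advanced = false
    (1 to columnCount).map(i => resultSet.getObject(i))
  }
}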
Example 200
Source File: ExamplesTest.scala    From json-schema-codegen   with Apache License 2.0 5 votes vote down vote up
import java.net.{Inet6Address, InetAddress, Inet4Address, URI}
import java.util.Date

import argonaut.Argonaut._
import argonaut._
import org.scalatest.{FlatSpec, Matchers}
import product.vox.shop._


class ExamplesTest extends FlatSpec with Matchers {


  "AdditionalPropertiesJson" should "encode and decode" in {
    import additional.Codecs._
    test(additional.Properties("bvalue", Some(Map("p1" -> additional.PropertiesAdditional(1)))))
  }

  "AdditionalPropertiesOnlyJson" should "encode and decode" in {
    import additional.properties.Codecs._
    test(additional.properties.Only(Some(Map("p1" -> additional.properties.OnlyAdditional(1)))))
  }

  "EnumStrings" should "encode and decode" in {
    import Codecs._
    test(Strings.a)
    test(Strings.b)
  }

  "EnumIntegers" should "encode and decode" in {
    import Codecs._
    test(Integers.v10)
    test(Integers.v20)
  }

  "Formats" should "encode and decode" in {
    import Codecs._
    test(Formats(
      new URI("http://uri/address"),
      InetAddress.getByName("127.0.0.1").asInstanceOf[Inet4Address],
      InetAddress.getByName("FE80:0000:0000:0000:0202:B3FF:FE1E:8329").asInstanceOf[Inet6Address],
      new Date()
    ))
  }

  "Product" should "decode from string and encode to string" in {
    import product.vox.shop.Codecs._
    val js = """{"name":"Recharge Cards (5 PIN)","prices":[{"cost":0.0187,"currency":"USD","moq":200000}],"eid":"iso-card-5-pin","description":"<p>ISO card, 5 PINs, printed 4 colour front and back</p>\n<p>Every card option shown below meets Tier 1 operator quality standards, at a competitive pricing including freight to your country that’s always openly visible, with streamlined fulfillment and support included, creating what we believe is the best overall value at the lowest total cost of ownership in the industry.</p>\n<p>Material:        Cardboard 300 GSM, UV varnish both sides</p>\n<p>Scratch panel:   Silver/Black Ink with black overprint</p> \n<p>Individually plastic wrapped in chain of 50 cards</p>\n<p>Small boxes of 500 cards, Master Carton of 5000 cards</p>\n<p>Alternate names: Scratch cards, RCV, top-up cards</p>\n","properties":[{"name":"Overscratch Protection","options":[{"name":"No protection"},{"name":"Protective measures against over scratching","prices":[{"cost":0.0253,"currency":"USD","moq":200000},{"cost":0.021,"currency":"USD","moq":500000},{"cost":0.02,"currency":"USD","moq":1000000},{"cost":0.0188,"currency":"USD","moq":5000000,"leadtime":21},{"cost":0.0173,"currency":"USD","moq":10000000},{"cost":0.0171,"currency":"USD","moq":50000000,"leadtime":28}]}]},{"name":"Payment terms","options":[{"name":"Payment on shipment readiness"},{"name":"Net 30 (subject to approval)"}]},{"name":"Order Timing","options":[{"name":"Ship order when ready"},{"name":"Pre-order for shipment in 3 months"}]}],"client":"112","sample":{"price":{"cost":250,"currency":"USD"}},"category":"recharge_cards","leadtime":14,"imageUrl":["https://d2w2n7dk76p3lq.cloudfront.net/product_image/recharge_cards/iso-5pin.png"],"types":[{"name":"Recharge Cards (5 PIN)","prices":[{"cost":0.0187,"currency":"USD","moq":200000},{"cost":0.0175,"currency":"USD","moq":500000},{"cost":0.0162,"currency":"USD","moq":1000000},{"cost":0.0153,"currency":"USD","moq":5000000,"leadtime":21},{"cost":0.0138,"currency":"USD","moq":10000000,"leadtime":28},{"cost":0.0137,"currency":"USD","moq":50000000,"leadtime":28}]}],"presentation":1000}"""
    val po = js.decodeValidation[Product]
    println(po)
    po.isSuccess shouldBe true
    test(po.toOption.get)
  }


  def test[T: CodecJson](value: T) = {
    val json = value.asJson
    println(json)
    json.jdecode[T] shouldBe DecodeResult.ok(value)
  }
}
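
The test helper at the bottom asserts that encoding a value to JSON and decoding it back yields the original. In this generated codebase the codecs come from the code generator, but the same round-trip check works with a hand-written Argonaut codec; a small hypothetical example (Price and its field names are illustrative only):

import argonaut._, Argonaut._

// Hypothetical codec for a small case class; casecodec2 pairs the
// constructor/extractor with the JSON field names.
case class Price(cost: Double, currency: String)

object Price {
  implicit val codec: CodecJson[Price] =
    casecodec2(Price.apply, Price.unapply)("cost", "currency")
}

// test(Price(0.0187, "USD")) would then round-trip through asJson/jdecode.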