org.junit.runner.RunWith Scala Examples

The following examples show how to use org.junit.runner.RunWith in Scala. Each example is taken from an open-source project; the source file and project are noted above the code.
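All of the examples share the same basic pattern: a ScalaTest suite annotated with @RunWith(classOf[JUnitRunner]) so that JUnit-based tooling (Maven Surefire, sbt's JUnit interface, IDEs) can discover and run it. The minimal sketch below shows that pattern in isolation; GreetingSpec and its assertion are illustrative only and do not come from any of the projects listed here, and it assumes ScalaTest 3.0.x, where JUnitRunner lives in org.scalatest.junit.

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// JUnitRunner bridges JUnit and ScalaTest: JUnit tooling sees a regular
// JUnit test class, while ScalaTest actually executes the spec body.
@RunWith(classOf[JUnitRunner])
class GreetingSpec extends WordSpec with Matchers {

  "A greeting" should {
    "contain the recipient's name" in {
      val greeting = "Hello, " + "World"
      greeting should include("World")
    }
  }
}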
Example 1
Source File: RegressITCase.scala    From flink-tensorflow   with Apache License 2.0
package org.apache.flink.contrib.tensorflow.ml

import com.twitter.bijection.Conversion._
import org.apache.flink.api.common.functions.RichFlatMapFunction
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.contrib.tensorflow.ml.signatures.RegressionMethod._
import org.apache.flink.contrib.tensorflow.types.TensorInjections.{message2Tensor, messages2Tensor}
import org.apache.flink.contrib.tensorflow.util.TestData._
import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils}
import org.apache.flink.core.fs.Path
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.util.Collector
import org.apache.flink.util.Preconditions.checkState
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.Tensor
import org.tensorflow.contrib.scala.Arrays._
import org.tensorflow.contrib.scala.Rank._
import org.tensorflow.contrib.scala._
import org.tensorflow.example.Example
import resource._

@RunWith(classOf[JUnitRunner])
class RegressITCase extends WordSpecLike
  with Matchers
  with FlinkTestBase {

  override val parallelism = 1

  type LabeledExample = (Example, Float)

  def examples(): Seq[LabeledExample] = {
    for (v <- Seq(0.0f -> 2.0f, 1.0f -> 2.5f, 2.0f -> 3.0f, 3.0f -> 3.5f))
      yield (example("x" -> feature(v._1)), v._2)
  }

  "A RegressFunction" should {
    "process elements" in {
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      RegistrationUtils.registerTypes(env.getConfig)

      val model = new HalfPlusTwo(new Path("../models/half_plus_two"))

      val outputs = env
        .fromCollection(examples())
        .flatMap(new RichFlatMapFunction[LabeledExample, Float] {
          override def open(parameters: Configuration): Unit = model.open()
          override def close(): Unit = model.close()

          override def flatMap(value: (Example, Float), out: Collector[Float]): Unit = {
            for {
              x <- managed(Seq(value._1).toList.as[Tensor].taggedAs[ExampleTensor])
              y <- model.regress_x_to_y(x)
            } {
              // cast as a 1D tensor to use the available conversion
              val o = y.taggedAs[TypedTensor[`1D`,Float]].as[Array[Float]]
              val actual = o(0)
              checkState(actual == value._2)
              out.collect(actual)
            }
          }
        })
        .print()

      env.execute()
    }
  }
} 
Example 2
Source File: LoggedUserTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.core.models

import com.stratio.sparta.serving.core.models.dto.{LoggedUser, LoggedUserConstant}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class LoggedUserTest extends WordSpec with Matchers {

  val dummyGroupID = "66"

  "An input String" when {
    "containing a well-formed JSON" should {
      "be correctly transformed into a LoggedUser" in {
        val objectUser = LoggedUser("1234-qwerty", "user1",
          LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin"))
        val stringJson =
          """
        {"id":"1234-qwerty",
        "attributes":[
          {"cn":"user1"},
          {"mail":"[email protected]"},
          {"gidNumber":"66"},
          {"groups":[]},
          {"roles":["admin"]}
        ]}"""

        val parsedUser = LoggedUser.jsonToDto(stringJson)
        parsedUser shouldBe defined
        parsedUser.get should equal(objectUser)
      }
    }
  }

  "An input String" when {
    "has missing fields" should {
      "be correctly parsed " in {
        val stringSparta =
          """{"id":"sparta","attributes":[
          |{"cn":"sparta"},
          |{"mail":"[email protected]"},
          |{"groups":["Developers"]},
          |{"roles":[]}]}""".stripMargin
        val parsedUser = LoggedUser.jsonToDto(stringSparta)
        val objectUser = LoggedUser("sparta", "sparta",
          "[email protected]", "", Seq("Developers"), Seq.empty[String])
        parsedUser shouldBe defined
        parsedUser.get should equal (objectUser)
      }
    }
  }


  "An input String" when {
    "is empty" should {
      "be transformed into None" in {
        val stringJson = ""
        val parsedUser = LoggedUser.jsonToDto(stringJson)
        parsedUser shouldBe None
      }
    }
  }

  "A user" when {
    "Oauth2 security is enabled" should {
      "be authorized only if one of its roles is contained inside allowedRoles" in {
        val objectUser = LoggedUser("1234-qwerty", "user1",
          LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin"))
        objectUser.isAuthorized(securityEnabled = true, allowedRoles = Seq("admin")) === true &&
          objectUser.isAuthorized(securityEnabled = true,
            allowedRoles = Seq("OtherAdministratorRole", "dummyUser")) === false
      }
    }
  }

  "A user" when {
    "Oauth2 security is disabled" should {
      "always be authorized" in {
        val objectUser = LoggedUser("1234-qwerty", "user1",
          LoggedUserConstant.dummyMail, dummyGroupID, Seq.empty[String], Seq("admin"))
        objectUser.isAuthorized(securityEnabled = false, allowedRoles = LoggedUserConstant.allowedRoles) === true
      }
    }
  }

} 
Example 3
Source File: ErrorsModelTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.core.models

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ErrorsModelTest extends WordSpec with Matchers {

  val error = new ErrorModel("100", "Error 100", None, None)

  "ErrorModel" should {

    "toString method should return the number of the error and the error" in {
      val res = ErrorModel.toString(error)
      res should be ("""{"i18nCode":"100","message":"Error 100"}""")
    }

    "toError method should return the number of the error and the error" in {
      val res = ErrorModel.toErrorModel(
        """
          |{
          | "i18nCode": "100",
          | "message": "Error 100"
          |}
        """.stripMargin)
      res should be (error)
    }
  }
} 
Example 4
Source File: ServingExceptionTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.core.exception

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ServingExceptionTest extends WordSpec with Matchers {

  "A ServingException" should {
    "create an exception with message" in {
      ServingCoreException.create("message").getMessage should be("message")
    }
    "create an exception with message and a cause" in {
      val cause = new IllegalArgumentException("any exception")
      val exception = ServingCoreException.create("message", cause)
      exception.getMessage should be("message")
      exception.getCause should be theSameInstanceAs(cause)
    }
  }
} 
Example 5
Source File: SparkContextFactoryTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.driver.test.factory

import com.stratio.sparta.driver.factory.SparkContextFactory
import com.stratio.sparta.serving.core.config.SpartaConfig
import com.stratio.sparta.serving.core.helpers.PolicyHelper
import com.typesafe.config.ConfigFactory
import org.apache.spark.streaming.Duration
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, _}

@RunWith(classOf[JUnitRunner])
class SparkContextFactoryTest extends FlatSpec with ShouldMatchers with BeforeAndAfterAll {
  self: FlatSpec =>

  override def afterAll {
    SparkContextFactory.destroySparkContext()
  }

  trait WithConfig {

    val config = SpartaConfig.initConfig("sparta.local")
    val wrongConfig = ConfigFactory.empty
    val seconds = 6
    val batchDuration = Duration(seconds)
    val specificConfig = Map("spark.driver.allowMultipleContexts" -> "true") ++
      PolicyHelper.getSparkConfFromProps(config.get)
  }

  "SparkContextFactorySpec" should "fails when properties is missing" in new WithConfig {
    an[Exception] should be thrownBy SparkContextFactory.sparkStandAloneContextInstance(
      Map.empty[String, String], Seq())
  }

  it should "create and reuse same context" in new WithConfig {
    val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq())
    val otherSc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq())
    sc should be equals (otherSc)
    SparkContextFactory.destroySparkContext()
  }

  it should "create and reuse same SparkSession" in new WithConfig {
    val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq())
    val sqc = SparkContextFactory.sparkSessionInstance
    sqc shouldNot be equals (null)
    val otherSqc = SparkContextFactory.sparkSessionInstance
    sqc should be equals (otherSqc)
    SparkContextFactory.destroySparkContext()
  }

  it should "create and reuse same SparkStreamingContext" in new WithConfig {
    val checkpointDir = "checkpoint/SparkContextFactorySpec"
    val sc = SparkContextFactory.sparkStandAloneContextInstance(specificConfig, Seq())
    val ssc = SparkContextFactory.sparkStreamingInstance(batchDuration, checkpointDir, None)
    ssc shouldNot be equals (None)
    val otherSsc = SparkContextFactory.sparkStreamingInstance(batchDuration, checkpointDir, None)
    ssc should be equals (otherSsc)
  }
} 
Example 6
Source File: CubeMakerTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.driver.test.cube

import java.sql.Timestamp

import com.github.nscala_time.time.Imports._
import com.stratio.sparta.driver.step.{Cube, CubeOperations, Trigger}
import com.stratio.sparta.driver.writer.WriterOptions
import com.stratio.sparta.plugin.default.DefaultField
import com.stratio.sparta.plugin.cube.field.datetime.DateTimeField
import com.stratio.sparta.plugin.cube.operator.count.CountOperator
import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionValue, DimensionValuesTime, InputFields}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import com.stratio.sparta.sdk.utils.AggregationTime
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.streaming.TestSuiteBase
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CubeMakerTest extends TestSuiteBase {

  val PreserverOrder = false

  
  // Builds the expected (DimensionValuesTime, InputFields) tuples for three input events
  // that share the same timestamp dimension value.
  def getEventOutput(timestamp: Timestamp, millis: Long): Seq[Seq[(DimensionValuesTime, InputFields)]] = {
    val dimensionString = Dimension("dim1", "eventKey", "identity", new DefaultField)
    val dimensionTime = Dimension("minute", "minute", "minute", new DateTimeField)
    val dimensionValueString1 = DimensionValue(dimensionString, "value1")
    val dimensionValueString2 = dimensionValueString1.copy(value = "value2")
    val dimensionValueString3 = dimensionValueString1.copy(value = "value3")
    val dimensionValueTs = DimensionValue(dimensionTime, timestamp)
    val tsMap = Row(timestamp)
    val valuesMap1 = InputFields(Row("value1", timestamp), 1)
    val valuesMap2 = InputFields(Row("value2", timestamp), 1)
    val valuesMap3 = InputFields(Row("value3", timestamp), 1)

    Seq(Seq(
      (DimensionValuesTime("cubeName", Seq(dimensionValueString1, dimensionValueTs)), valuesMap1),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString2, dimensionValueTs)), valuesMap2),
      (DimensionValuesTime("cubeName", Seq(dimensionValueString3, dimensionValueTs)), valuesMap3)
    ))
  }
} 
Example 7
Source File: RawStageTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.driver.test.stage

import akka.actor.ActorSystem
import akka.testkit.TestKit
import com.stratio.sparta.driver.stage.{LogError, RawDataStage}
import com.stratio.sparta.sdk.pipeline.autoCalculations.AutoCalculatedField
import com.stratio.sparta.sdk.properties.JsoneyString
import com.stratio.sparta.serving.core.models.policy.writer.{AutoCalculatedFieldModel, WriterModel}
import com.stratio.sparta.serving.core.models.policy.{PolicyModel, RawDataModel}
import org.junit.runner.RunWith
import org.mockito.Mockito.when
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpecLike, ShouldMatchers}

@RunWith(classOf[JUnitRunner])
class RawStageTest
  extends TestKit(ActorSystem("RawStageTest"))
    with FlatSpecLike with ShouldMatchers with MockitoSugar {

  case class TestRawData(policy: PolicyModel) extends RawDataStage with LogError

  def mockPolicy: PolicyModel = {
    val policy = mock[PolicyModel]
    when(policy.id).thenReturn(Some("id"))
    policy
  }

  "rawDataStage" should "Generate a raw data" in {
    val field = "field"
    val timeField = "time"
    val tableName = Some("table")
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val autocalculateFields = Seq(AutoCalculatedFieldModel())
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(writerModel.autoCalculatedFields).thenReturn(autocalculateFields)
    when(rawData.configuration).thenReturn(configuration)

    val result = TestRawData(policy).rawDataStage()

    result.timeField should be(timeField)
    result.dataField should be(field)
    result.writerOptions.tableName should be(tableName)
    result.writerOptions.partitionBy should be(partitionBy)
    result.configuration should be(configuration)
    result.writerOptions.outputs should be(outputs)
  }

  "rawDataStage" should "Fail with bad table name" in {
    val field = "field"
    val timeField = "time"
    val tableName = None
    val outputs = Seq("output")
    val partitionBy = Some("field")
    val configuration = Map.empty[String, JsoneyString]

    val policy = mockPolicy
    val rawData = mock[RawDataModel]
    val writerModel = mock[WriterModel]

    when(policy.rawData).thenReturn(Some(rawData))
    when(rawData.dataField).thenReturn(field)
    when(rawData.timeField).thenReturn(timeField)
    when(rawData.writer).thenReturn(writerModel)
    when(writerModel.tableName).thenReturn(tableName)
    when(writerModel.outputs).thenReturn(outputs)
    when(writerModel.partitionBy).thenReturn(partitionBy)
    when(rawData.configuration).thenReturn(configuration)


    the[IllegalArgumentException] thrownBy {
      TestRawData(policy).rawDataStage()
    } should have message "Something gone wrong saving the raw data. Please re-check the policy."
  }

} 
Example 8
Source File: DimensionTypeTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.aggregation.cube

import java.io.{Serializable => JSerializable}

import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class DimensionTypeTest extends WordSpec with Matchers {

  val prop = Map("hello" -> "bye")

  "DimensionType" should {

    "the return operations properties" in {
      val dimensionTypeTest = new DimensionTypeMock(prop)
      val result = dimensionTypeTest.operationProps
      result should be(prop)
    }

    "the return properties" in {
      val dimensionTypeTest = new DimensionTypeMock(prop)
      val result = dimensionTypeTest.properties
      result should be(prop)
    }

    "the return precisionValue" in {
      val dimensionTypeTest = new DimensionTypeMock(prop)
      val expected = (DimensionType.getIdentity(None, dimensionTypeTest.defaultTypeOperation), "hello")
      val result = dimensionTypeTest.precisionValue("", "hello")
      result should be(expected)
    }

    "the return precision" in {
      val dimensionTypeTest = new DimensionTypeMock(prop)
      val expected = (DimensionType.getIdentity(None, dimensionTypeTest.defaultTypeOperation))
      val result = dimensionTypeTest.precision("")
      result should be(expected)
    }
  }

  "DimensionType object" should {

    "getIdentity must be " in {
      val identity = DimensionType.getIdentity(None, TypeOp.Int)
      identity.typeOp should be(TypeOp.Int)
      identity.id should be(DimensionType.IdentityName)
      val identity2 = DimensionType.getIdentity(Some(TypeOp.String), TypeOp.Int)
      identity2.typeOp should be(TypeOp.String)
    }

    "getIdentityField must be " in {
      val identity = DimensionType.getIdentityField(None, TypeOp.Int)
      identity.typeOp should be(TypeOp.Int)
      identity.id should be(DimensionType.IdentityFieldName)
      val identity2 = DimensionType.getIdentityField(Some(TypeOp.String), TypeOp.Int)
      identity2.typeOp should be(TypeOp.String)
    }

    "getTimestamp must be " in {
      val identity = DimensionType.getTimestamp(None, TypeOp.Int)
      identity.typeOp should be(TypeOp.Int)
      identity.id should be(DimensionType.TimestampName)
      val identity2 = DimensionType.getTimestamp(Some(TypeOp.String), TypeOp.Int)
      identity2.typeOp should be(TypeOp.String)
    }
  }
} 
Example 9
Source File: DimensionTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.aggregation.cube

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DimensionTest extends WordSpec with Matchers {

  "Dimension" should {
    val defaultDimensionType = new DimensionTypeMock(Map())
    val dimension = Dimension("dim1", "eventKey", "identity", defaultDimensionType)
    val dimensionIdentity = Dimension("dim1", "identity", "identity", defaultDimensionType)
    val dimensionNotIdentity = Dimension("dim1", "key", "key", defaultDimensionType)

    "Return the associated identity precision name" in {
      val expected = "identity"
      val result = dimensionIdentity.getNamePrecision
      result should be(expected)
    }

    "Return the associated name precision name" in {
      val expected = "key"
      val result = dimensionNotIdentity.getNamePrecision
      result should be(expected)
    }

    "Return the associated precision name" in {
      val expected = "eventKey"
      val result = dimension.getNamePrecision
      result should be(expected)
    }

    "Compare function with other dimension must be less" in {
      val dimension2 = Dimension("dim2", "eventKey", "identity", defaultDimensionType)
      val expected = -1
      val result = dimension.compare(dimension2)
      result should be(expected)
    }

    "Compare function with other dimension must be equal" in {
      val dimension2 = Dimension("dim1", "eventKey", "identity", defaultDimensionType)
      val expected = 0
      val result = dimension.compare(dimension2)
      result should be(expected)
    }

    "Compare function with other dimension must be higher" in {
      val dimension2 = Dimension("dim0", "eventKey", "identity", defaultDimensionType)
      val expected = 1
      val result = dimension.compare(dimension2)
      result should be(expected)
    }

    "classSuffix must be " in {
      val expected = "Field"
      val result = Dimension.FieldClassSuffix
      result should be(expected)
    }
  }
} 
Example 10
Source File: OutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.output

import com.stratio.sparta.sdk.pipeline.aggregation.cube.{Dimension, DimensionTypeMock, DimensionValue, DimensionValuesTime}
import com.stratio.sparta.sdk.pipeline.transformation.OutputMock
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class OutputTest extends WordSpec with Matchers {

  trait CommonValues {
    val timeDimension = "minute"
    val tableName = "table"
    val timestamp = 1L
    val defaultDimension = new DimensionTypeMock(Map())
    val dimensionValuesT = DimensionValuesTime("testCube", Seq(
      DimensionValue(Dimension("dim1", "eventKey", "identity", defaultDimension), "value1"),
      DimensionValue(Dimension("dim2", "eventKey", "identity", defaultDimension), "value2"),
      DimensionValue(Dimension("minute", "eventKey", "identity", defaultDimension), 1L)))
    val dimensionValuesTFixed = DimensionValuesTime("testCube", Seq(
      DimensionValue(Dimension("dim1", "eventKey", "identity", defaultDimension), "value1"),
      DimensionValue(Dimension("minute", "eventKey", "identity", defaultDimension), 1L)))
    val outputName = "outputName"
    val output = new OutputMock(outputName, Map())
    val outputOperation = new OutputMock(outputName, Map())
    val outputProps = new OutputMock(outputName, Map())
  }

  "Output" should {

    "Name must be " in new CommonValues {
      val expected = outputName
      val result = output.name
      result should be(expected)
    }

    "the spark geo field returned must be " in new CommonValues {
      val expected = StructField("field", ArrayType(DoubleType), false)
      val result = Output.defaultGeoField("field", false)
      result should be(expected)
    }

    "classSuffix must be " in {
      val expected = "Output"
      val result = Output.ClassSuffix
      result should be(expected)
    }
  }
} 
Example 11
Source File: InputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.input

import org.apache.spark.storage.StorageLevel
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class InputTest extends WordSpec with Matchers {

  "Input" should {
    val input = new InputMock(Map())
    val expected = StorageLevel.DISK_ONLY
    val result = input.storageLevel("DISK_ONLY")

    "Return the associated storageLevel" in {
      result should be(expected)
    }
  }

  "classSuffix must be " in {
    val expected = "Input"
    val result = Input.ClassSuffix
    result should be(expected)
  }
} 
Example 12
Source File: ParserTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.transformation

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ParserTest extends WordSpec with Matchers {

  "Parser" should {

    val parserTest = new ParserMock(
      1,
      Some("input"),
      Seq("output"),
      StructType(Seq(StructField("some", StringType))),
      Map()
    )

    "Order must be " in {
      val expected = 1
      val result = parserTest.getOrder
      result should be(expected)
    }

    "Parse must be " in {
      val event = Row("value")
      val expected = Seq(event)
      val result = parserTest.parse(event)
      result should be(expected)
    }

    "checked fields not be contained in outputs must be " in {
      val keyMap = Map("field" -> "value")
      val expected = Map()
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "checked fields are contained in outputs must be " in {
      val keyMap = Map("output" -> "value")
      val expected = keyMap
      val result = parserTest.checkFields(keyMap)
      result should be(expected)
    }

    "classSuffix must be " in {
      val expected = "Parser"
      val result = Parser.ClassSuffix
      result should be(expected)
    }
  }
} 
Example 13
Source File: TypeConversionsTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.pipeline.schema

import com.stratio.sparta.sdk.pipeline.aggregation.cube.Precision
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class TypeConversionsTest extends WordSpec with Matchers {

  "TypeConversions" should {

    val typeConvesions = new TypeConversionsMock

    "typeOperation must be " in {
      val expected = TypeOp.Int
      val result = typeConvesions.defaultTypeOperation
      result should be(expected)
    }

    "operationProps must be " in {
      val expected = Map("typeOp" -> "string")
      val result = typeConvesions.operationProps
      result should be(expected)
    }

    "the operation type must be " in {
      val expected = Some(TypeOp.String)
      val result = typeConvesions.getTypeOperation
      result should be(expected)
    }

    "the detailed operation type must be " in {
      val expected = Some(TypeOp.String)
      val result = typeConvesions.getTypeOperation("string")
      result should be(expected)
    }

    "the precision type must be " in {
      val expected = Precision("precision", TypeOp.String, Map())
      val result = typeConvesions.getPrecision("precision", Some(TypeOp.String))
      result should be(expected)
    }
  }
} 
Example 14
Source File: JsoneyStringTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.sdk.properties

import org.json4s.jackson.JsonMethods._
import org.json4s.jackson.Serialization.write
import org.json4s.{DefaultFormats, _}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class JsoneyStringTest extends WordSpecLike
  with Matchers {

  "A JsoneyString" should {
    "have toString equivalent to its internal string" in {
      assertResult("foo")(new JsoneyString("foo").toString)
    }

    "be deserialized if its JSON" in {
      implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer()
      val result = parse( """{ "foo": "bar" }""").extract[JsoneyString]
      assertResult(new JsoneyString( """{"foo":"bar"}"""))(result)
    }

    "be deserialized if it's a String" in {
      implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer()
      val result = parse("\"foo\"").extract[JsoneyString]
      assertResult(new JsoneyString("foo"))(result)
    }

    "be deserialized if it's an Int" in {
      implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer()
      val result = parse("1").extract[JsoneyString]
      assertResult(new JsoneyString("1"))(result)
    }

    "be serialized as JSON" in {
      implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer()

      var result = write(new JsoneyString("foo"))
      assertResult("\"foo\"")(result)

      result = write(new JsoneyString("{\"foo\":\"bar\"}"))
      assertResult("\"{\\\"foo\\\":\\\"bar\\\"}\"")(result)
    }

    "be deserialized if it's an JBool" in {
      implicit val json4sJacksonFormats = DefaultFormats + new JsoneyStringSerializer()
      val result = parse("true").extract[JsoneyString]
      assertResult(new JsoneyString("true"))(result)
    }

    "have toSeq equivalent to its internal string" in {
      assertResult(Seq("o"))(new JsoneyString("foo").toSeq)
    }
  }
} 
Example 15
Source File: SpartaClusterLauncherActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.helpers

import com.stratio.sparta.serving.core.config.{SpartaConfigFactory, SpartaConfig}
import com.typesafe.config.ConfigFactory
import org.junit.runner.RunWith
import org.scalamock.scalatest._
import org.scalatest._
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class SpartaClusterLauncherActorTest extends FlatSpec with MockFactory with ShouldMatchers with Matchers {

  it should "init SpartaConfig from a file with a configuration" in {
    val config = ConfigFactory.parseString(
      """
        |sparta {
        | testKey : "testValue"
        |}
      """.stripMargin)

    val spartaConfig = SpartaConfig.initConfig(node = "sparta", configFactory = SpartaConfigFactory(config))
    spartaConfig.get.getString("testKey") should be("testValue")
  }

  it should "init a config from a given config" in {
    val config = ConfigFactory.parseString(
      """
        |sparta {
        |  testNode {
        |    testKey : "testValue"
        |  }
        |}
      """.stripMargin)

    val spartaConfig = SpartaConfig.initConfig(node = "sparta", configFactory = SpartaConfigFactory(config))
    val testNodeConfig = SpartaConfig.initConfig("testNode", spartaConfig, SpartaConfigFactory(config))
    testNodeConfig.get.getString("testKey") should be("testValue")
  }
} 
Example 16
Source File: ControllerActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.actor

import akka.actor.{ActorSystem, Props}
import akka.testkit.{ImplicitSender, TestKit}
import com.stratio.sparta.driver.service.StreamingContextService
import com.stratio.sparta.serving.core.actor.{RequestActor, FragmentActor, StatusActor}
import com.stratio.sparta.serving.core.config.SpartaConfig
import com.stratio.sparta.serving.core.constants.AkkaConstant
import org.apache.curator.framework.CuratorFramework
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ControllerActorTest(_system: ActorSystem) extends TestKit(_system)
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with MockFactory {

  SpartaConfig.initMainConfig()
  SpartaConfig.initApiConfig()

  val curatorFramework = mock[CuratorFramework]
  val statusActor = _system.actorOf(Props(new StatusActor(curatorFramework)))
  val executionActor = _system.actorOf(Props(new RequestActor(curatorFramework)))
  val streamingContextService = new StreamingContextService(curatorFramework)
  val fragmentActor = _system.actorOf(Props(new FragmentActor(curatorFramework)))
  val policyActor = _system.actorOf(Props(new PolicyActor(curatorFramework, statusActor)))
  val sparkStreamingContextActor = _system.actorOf(
    Props(new LauncherActor(streamingContextService, curatorFramework)))
  val pluginActor = _system.actorOf(Props(new PluginActor()))
  val configActor = _system.actorOf(Props(new ConfigActor()))

  def this() =
    this(ActorSystem("ControllerActorSpec", SpartaConfig.daemonicAkkaConfig))

  implicit val actors = Map(
    AkkaConstant.StatusActorName -> statusActor,
    AkkaConstant.FragmentActorName -> fragmentActor,
    AkkaConstant.PolicyActorName -> policyActor,
    AkkaConstant.LauncherActorName -> sparkStreamingContextActor,
    AkkaConstant.PluginActorName -> pluginActor,
    AkkaConstant.ExecutionActorName -> executionActor,
    AkkaConstant.ConfigActorName -> configActor
  )

  override def afterAll {
    TestKit.shutdownActorSystem(system)
  }

  "ControllerActor" should {
    "set up the controller actor that contains all sparta's routes without any error" in {
      _system.actorOf(Props(new ControllerActor(actors, curatorFramework)))
    }
  }
} 
Example 17
Source File: DriverActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.actor

import java.nio.file.{Files, Path}

import akka.actor.{ActorSystem, Props}
import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit}
import akka.util.Timeout
import com.stratio.sparta.serving.api.actor.DriverActor.UploadDrivers
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.SpartaSerializer
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import com.typesafe.config.{Config, ConfigFactory}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import spray.http.BodyPart

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class DriverActorTest extends TestKit(ActorSystem("PluginActorSpec"))
  with DefaultTimeout
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with BeforeAndAfterEach
  with MockitoSugar with SpartaSerializer {

  val tempDir: Path = Files.createTempDirectory("test")
  tempDir.toFile.deleteOnExit()

  val localConfig: Config = ConfigFactory.parseString(
    s"""
       |sparta{
       |   api {
       |     host = local
       |     port= 7777
       |   }
       |}
       |
       |sparta.config.driverPackageLocation = "$tempDir"
    """.stripMargin)

  val fileList = Seq(BodyPart("reference.conf", "file"))

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
    SpartaConfig.initApiConfig()
  }

  override def afterAll: Unit = {
    shutdown()
  }

  override implicit val timeout: Timeout = Timeout(15 seconds)

  "DriverActor " must {

    "Not save files with wrong extension" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(fileList)
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true
      }
    }
    "Not upload empty files" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(Seq.empty)
      expectMsgPF() {
        case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected"
      }
    }
    "Save a file" in {
      val driverActor = system.actorOf(Props(new DriverActor()))
      driverActor ! UploadDrivers(Seq(BodyPart("reference.conf", "file.jar")))
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true
      }
    }
  }
} 
Example 18
Source File: PluginActorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.actor

import java.nio.file.{Files, Path}

import akka.actor.{ActorSystem, Props}
import akka.testkit.{DefaultTimeout, ImplicitSender, TestKit}
import akka.util.Timeout
import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins}
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.SpartaSerializer
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import com.typesafe.config.{Config, ConfigFactory}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import spray.http.BodyPart

import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class PluginActorTest extends TestKit(ActorSystem("PluginActorSpec"))
  with DefaultTimeout
  with ImplicitSender
  with WordSpecLike
  with Matchers
  with BeforeAndAfterAll
  with BeforeAndAfterEach
  with MockitoSugar with SpartaSerializer {

  val tempDir: Path = Files.createTempDirectory("test")
  tempDir.toFile.deleteOnExit()

  val localConfig: Config = ConfigFactory.parseString(
    s"""
       |sparta{
       |   api {
       |     host = local
       |     port= 7777
       |   }
       |}
       |
       |sparta.config.pluginPackageLocation = "$tempDir"
    """.stripMargin)


  val fileList = Seq(BodyPart("reference.conf", "file"))

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
    SpartaConfig.initApiConfig()
  }

  override def afterAll: Unit = {
    shutdown()
  }

  override implicit val timeout: Timeout = Timeout(15 seconds)

  "PluginActor " must {

    "Not save files with wrong extension" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(fileList)
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.isEmpty shouldBe true
      }
    }
    "Not upload empty files" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(Seq.empty)
      expectMsgPF() {
        case SpartaFilesResponse(Failure(f)) => f.getMessage shouldBe "At least one file is expected"
      }
    }
    "Save a file" in {
      val pluginActor = system.actorOf(Props(new PluginActor()))
      pluginActor ! UploadPlugins(Seq(BodyPart("reference.conf", "file.jar")))
      expectMsgPF() {
        case SpartaFilesResponse(Success(f: Seq[SpartaFile])) => f.head.fileName.endsWith("file.jar") shouldBe true
      }
    }
  }

} 
Example 19
Source File: CustomExceptionHandlerTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.handler

import akka.actor.ActorSystem
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}
import spray.http.StatusCodes
import spray.httpx.Json4sJacksonSupport
import spray.routing.{Directives, HttpService, StandardRoute}
import spray.testkit.ScalatestRouteTest
import com.stratio.sparta.sdk.exception.MockException
import com.stratio.sparta.serving.api.service.handler.CustomExceptionHandler._
import com.stratio.sparta.serving.core.exception.ServingCoreException
import com.stratio.sparta.serving.core.models.{ErrorModel, SpartaSerializer}

@RunWith(classOf[JUnitRunner])
class CustomExceptionHandlerTest extends WordSpec
  with Directives with ScalatestRouteTest with Matchers
  with Json4sJacksonSupport with HttpService with SpartaSerializer {

  def actorRefFactory: ActorSystem = system

  trait MyTestRoute {

    val exception: Throwable
    val route: StandardRoute = complete(throw exception)
  }

  def route(throwable: Throwable): StandardRoute = complete(throw throwable)

  "CustomExceptionHandler" should {
    "encapsulate a unknow error in an error model and response with a 500 code" in new MyTestRoute {
      val exception = new MockException
      Get() ~> sealRoute(route) ~> check {
        status should be(StatusCodes.InternalServerError)
        response.entity.asString should be(ErrorModel.toString(new ErrorModel("666", "unknown")))
      }
    }
    "encapsulate a serving api error in an error model and response with a 400 code" in new MyTestRoute {
      val exception = ServingCoreException.create(ErrorModel.toString(new ErrorModel("333", "testing exception")))
      Get() ~> sealRoute(route) ~> check {
        status should be(StatusCodes.NotFound)
        response.entity.asString should be(ErrorModel.toString(new ErrorModel("333", "testing exception")))
      }
    }
  }
} 
Example 20
Source File: ConfigHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import akka.testkit.TestProbe
import com.stratio.sparta.serving.api.actor.ConfigActor
import com.stratio.sparta.serving.api.actor.ConfigActor._
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.constants.{AkkaConstant, AppConstant}
import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant
import com.stratio.sparta.serving.core.models.frontend.FrontendConfiguration
import org.junit.runner.RunWith
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ConfigHttpServiceTest extends WordSpec
  with ConfigHttpService
  with HttpServiceBaseTest {

  val configActorTestProbe = TestProbe()

  val dummyUser = Some(LoggedUserConstant.AnonymousUser)

  override implicit val actors: Map[String, ActorRef] = Map(
    AkkaConstant.ConfigActorName -> configActorTestProbe.ref
  )

  override val supervisor: ActorRef = testProbe.ref

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
  }

  protected def retrieveStringConfig(): FrontendConfiguration =
    FrontendConfiguration(AppConstant.DefaultFrontEndTimeout, Option(AppConstant.DefaultOauth2CookieName))

  "ConfigHttpService.FindAll" should {
    "retrieve a FrontendConfiguration item" in {
      startAutopilot(ConfigResponse(retrieveStringConfig()))
      Get(s"/${HttpConstant.ConfigPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[ConfigActor.FindAll.type]
        responseAs[FrontendConfiguration] should equal(retrieveStringConfig())
      }
    }
  }

} 
Example 21
Source File: AppStatusHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import com.stratio.sparta.serving.api.constants.HttpConstant
import org.apache.curator.framework.CuratorFramework
import org.junit.runner.RunWith
import org.scalamock.scalatest.MockFactory
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner
import spray.http.StatusCodes

@RunWith(classOf[JUnitRunner])
class AppStatusHttpServiceTest extends WordSpec
  with AppStatusHttpService
  with HttpServiceBaseTest
  with MockFactory {

  override implicit val actors: Map[String, ActorRef] = Map()
  override val supervisor: ActorRef = testProbe.ref
  override val curatorInstance = mock[CuratorFramework]

  "AppStatusHttpService" should {
    "check the status of the server" in {
      Get(s"/${HttpConstant.AppStatus}") ~> routes() ~> check {
        status should be (StatusCodes.InternalServerError)
      }
    }
  }
} 
Example 22
Source File: PluginsHttpServiceTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.serving.api.service.http

import akka.actor.ActorRef
import akka.testkit.TestProbe
import com.stratio.sparta.serving.api.actor.PluginActor.{PluginResponse, UploadPlugins}
import com.stratio.sparta.serving.api.constants.HttpConstant
import com.stratio.sparta.serving.core.config.{SpartaConfig, SpartaConfigFactory}
import com.stratio.sparta.serving.core.models.dto.LoggedUserConstant
import com.stratio.sparta.serving.core.models.files.{SpartaFile, SpartaFilesResponse}
import org.junit.runner.RunWith
import org.scalatest.WordSpec
import org.scalatest.junit.JUnitRunner
import spray.http._

import scala.util.{Failure, Success}

@RunWith(classOf[JUnitRunner])
class PluginsHttpServiceTest extends WordSpec
  with PluginsHttpService
  with HttpServiceBaseTest {

  override val supervisor: ActorRef = testProbe.ref

  val pluginTestProbe = TestProbe()

  val dummyUser = Some(LoggedUserConstant.AnonymousUser)

  override implicit val actors: Map[String, ActorRef] = Map.empty

  override def beforeEach(): Unit = {
    SpartaConfig.initMainConfig(Option(localConfig), SpartaConfigFactory(localConfig))
  }

  "PluginsHttpService.upload" should {
    "Upload a file" in {
      val response = SpartaFilesResponse(Success(Seq(SpartaFile("", "", "", ""))))
      startAutopilot(response)
      Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[UploadPlugins]
        status should be(StatusCodes.OK)
      }
    }
    "Fail when service is not available" in {
      val response = SpartaFilesResponse(Failure(new IllegalArgumentException("Error")))
      startAutopilot(response)
      Put(s"/${HttpConstant.PluginsPath}") ~> routes(dummyUser) ~> check {
        testProbe.expectMsgType[UploadPlugins]
        status should be(StatusCodes.InternalServerError)
      }
    }
  }
} 
Example 23
Source File: FileSystemOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.filesystem

import java.io.File

import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.plugin.output.fileSystem.FileSystemOutput
import com.stratio.sparta.sdk.pipeline.output.{Output, OutputFormatEnum, SaveModeEnum}
import org.apache.commons.io.FileUtils
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.Matchers
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FileSystemOutputIT extends TemporalSparkContext with Matchers {

  val directory = getClass().getResource("/origin.txt")
  val parentFile = new File(directory.getPath).getParent
  val properties = Map(("path", parentFile + "/testRow"), ("outputFormat", "row"))
  val fields = StructType(StructField("name", StringType, false) ::
    StructField("age", IntegerType, false) ::
    StructField("year", IntegerType, true) :: Nil)
  val fsm = new FileSystemOutput("key", properties)


  "An object of type FileSystemOutput " should "have the same values as the properties Map" in {
    fsm.outputFormat should be(OutputFormatEnum.ROW)
  }

  
  // Builds a small (name, age, year) DataFrame from the shared SparkContext for the write tests below.
  private def dfGen(): DataFrame = {
    val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate()
    val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990), ("user3", 21, 1995)))
      .map { case (name, age, year) => Row(name, age, year) }

    sqlCtx.createDataFrame(dataRDD, fields)
  }

  def fileExists(path: String): Boolean = new File(path).exists()

  "Given a DataFrame, a directory" should "be created with the data written inside" in {
    fsm.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test"))
    fileExists(fsm.path.get) should equal(true)
  }

  it should "exist with the given path and be deleted" in {
    if (fileExists(fsm.path.get))
      FileUtils.deleteDirectory(new File(fsm.path.get))
    fileExists(fsm.path.get) should equal(false)
  }

  val fsm2 = new FileSystemOutput("key", properties.updated("outputFormat", "json")
    .updated("path", parentFile + "/testJson"))

  "Given another DataFrame, a directory" should "be created with the data inside in JSON format" in {
    fsm2.outputFormat should be(OutputFormatEnum.JSON)
    fsm2.save(dfGen(), SaveModeEnum.Append, Map(Output.TableNameKey -> "test"))
    fileExists(fsm2.path.get) should equal(true)
  }

  it should "exist with the given path and be deleted" in {
    if (fileExists(s"${fsm2.path.get}/test"))
      FileUtils.deleteDirectory(new File(s"${fsm2.path.get}/test"))
    fileExists(s"${fsm2.path.get}/test") should equal(false)
  }
} 
Example 24
Source File: AvroOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.avro

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class AvroOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new AvroOutput("avro-test", properties)
  }


  "AvroOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new AvroOutput("avro-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.avro(s"$tmpPath/person")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
Example 25
Source File: HttpOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.http

import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.OutputFormatEnum
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest.Matchers
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class HttpOutputTest extends TemporalSparkContext with Matchers {

  val properties = Map(
    "url" -> "https://httpbin.org/post",
    "delimiter" -> ",",
    "parameterName" -> "thisIsAKeyName",
    "readTimeOut" -> "5000",
    "outputFormat" -> "ROW",
    "postType" -> "body",
    "connTimeout" -> "6000"
  )

  val fields = StructType(StructField("name", StringType, false) ::
    StructField("age", IntegerType, false) ::
    StructField("year", IntegerType, true) :: Nil)
  val OkHTTPResponse = 200

  "An object of type RestOutput " should "have the same values as the properties Map" in {
    val rest = new HttpOutput("key", properties)

    rest.outputFormat should be(OutputFormatEnum.ROW)
    rest.readTimeout should be(5000)
  }
  it should "throw a NoSuchElementException" in {
    val properties2 = properties.updated("postType", "vooooooody")
    a[NoSuchElementException] should be thrownBy {
      new HttpOutput("keyName", properties2)
    }
  }

  
  // Builds a two-row (name, age, year) DataFrame used as input for the POST requests below.
  private def dfGen(): DataFrame = {
    val sqlCtx = SparkSession.builder().config(sc.getConf).getOrCreate()
    val dataRDD = sc.parallelize(List(("user1", 23, 1993), ("user2", 26, 1990))).map { case (name, age, year) =>
      Row(name, age, year)
    }
    sqlCtx.createDataFrame(dataRDD, fields)
  }

  val restMock1 = new HttpOutput("key", properties)
  "Given a DataFrame it" should "be parsed and send through a Raw data POST request" in {

    dfGen().collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock1.sendData(row.mkString(restMock1.delimiter)).code)
    })
  }

  it should "return the same amount of responses as rows in the DataFrame" in {
    val size = dfGen().collect().map(row => restMock1.sendData(row.mkString(restMock1.delimiter)).code).size
    assertResult(dfGen().count())(size)
  }

  val restMock2 = new HttpOutput("key", properties.updated("postType", "parameter"))
  it should "be parsed and send as a POST request along with a parameter stated by properties.parameterKey " in {
    dfGen().collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock2.sendData(row.mkString(restMock2.delimiter)).code)
    })
  }

  val restMock3 = new HttpOutput("key", properties.updated("outputFormat", "JSON"))
  "Given a DataFrame it" should "be sent as JSON through a Raw data POST request" in {

    dfGen().toJSON.collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock3.sendData(row).code)
    })
  }

  val restMock4 = new HttpOutput("key", properties.updated("postType", "parameter").updated("format", "JSON"))
  it should "sent as a POST request along with a parameter stated by properties.parameterKey " in {

    dfGen().toJSON.collect().foreach(row => {
      assertResult(OkHTTPResponse)(restMock4.sendData(row).code)
    })
  }
} 
Example 26
Source File: CassandraOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.cassandra

import java.io.{Serializable => JSerializable}

import com.datastax.spark.connector.cql.CassandraConnector
import com.stratio.sparta.sdk._
import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class CassandraOutputTest extends FlatSpec with Matchers with MockitoSugar with AnswerSugar {

  val s = "sum"
  val properties = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))

  "getSparkConfiguration" should "return a Seq with the configuration" in {
    val configuration = Map(("connectionHost", "127.0.0.1"), ("connectionPort", "9042"))
    val cass = CassandraOutput.getSparkConfiguration(configuration)

    cass should be(List(("spark.cassandra.connection.host", "127.0.0.1"), ("spark.cassandra.connection.port", "9042")))
  }

  "getSparkConfiguration" should "return all cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("sparkProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(true)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(true)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }

  "getSparkConfiguration" should "not return cassandra-spark config" in {
    val config: Map[String, JSerializable] = Map(
      ("hadoopProperties" -> JsoneyString(
        "[{\"sparkPropertyKey\":\"spark.cassandra.input.fetch.size_in_rows\",\"sparkPropertyValue\":\"2000\"}," +
          "{\"sparkPropertyKey\":\"spark.cassandra.input.split.size_in_mb\",\"sparkPropertyValue\":\"64\"}]")),
      ("anotherProperty" -> "true")
    )

    val sparkConfig = CassandraOutput.getSparkConfiguration(config)

    sparkConfig.exists(_ == ("spark.cassandra.input.fetch.size_in_rows" -> "2000")) should be(false)
    sparkConfig.exists(_ == ("spark.cassandra.input.split.size_in_mb" -> "64")) should be(false)
    sparkConfig.exists(_ == ("anotherProperty" -> "true")) should be(false)
  }
} 
Example 27
Source File: ElasticSearchOutputTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.elasticsearch

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.types._
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ElasticSearchOutputTest extends FlatSpec with ShouldMatchers {

  trait BaseValues {

    final val localPort = 9200
    final val remotePort = 9300
    val output = getInstance()
    val outputMultipleNodes = new ElasticSearchOutput("ES-out",
      Map("nodes" ->
        new JsoneyString(
          s"""[{"node":"host-a","tcpPort":"$remotePort","httpPort":"$localPort"},{"node":"host-b",
              |"tcpPort":"9301","httpPort":"9201"}]""".stripMargin),
        "dateType" -> "long"))

    def getInstance(host: String = "localhost", httpPort: Int = localPort, tcpPort: Int = remotePort)
    : ElasticSearchOutput =
      new ElasticSearchOutput("ES-out",
        Map("nodes" -> new JsoneyString( s"""[{"node":"$host","httpPort":"$httpPort","tcpPort":"$tcpPort"}]"""),
          "clusterName" -> "elasticsearch"))
  }

  trait NodeValues extends BaseValues {

    val ipOutput = getInstance("127.0.0.1", localPort, remotePort)
    val ipv6Output = getInstance("0:0:0:0:0:0:0:1", localPort, remotePort)
    val remoteOutput = getInstance("dummy", localPort, remotePort)
  }

  trait TestingValues extends BaseValues {

    val indexNameType = "spartatable/sparta"
    val tableName = "spartaTable"
    val baseFields = Seq(StructField("string", StringType), StructField("int", IntegerType))
    val schema = StructType(baseFields)
    val extraFields = Seq(StructField("id", StringType, false), StructField("timestamp", LongType, false))
    val properties = Map("nodes" -> new JsoneyString(
      """[{"node":"localhost","httpPort":"9200","tcpPort":"9300"}]""".stripMargin),
      "dateType" -> "long",
      "clusterName" -> "elasticsearch")
    override val output = new ElasticSearchOutput("ES-out", properties)
    val dateField = StructField("timestamp", TimestampType, false)
    val expectedDateField = StructField("timestamp", LongType, false)
    val stringField = StructField("string", StringType)
    val expectedStringField = StructField("string", StringType)
  }

  trait SchemaValues extends BaseValues {

    val fields = Seq(
      StructField("long", LongType),
      StructField("double", DoubleType),
      StructField("decimal", DecimalType(10, 0)),
      StructField("int", IntegerType),
      StructField("boolean", BooleanType),
      StructField("date", DateType),
      StructField("timestamp", TimestampType),
      StructField("array", ArrayType(StringType)),
      StructField("map", MapType(StringType, IntegerType)),
      StructField("string", StringType),
      StructField("binary", BinaryType))
    val completeSchema = StructType(fields)
  }

  "ElasticSearchOutput" should "format properties" in new NodeValues with SchemaValues {
    output.httpNodes should be(Seq(("localhost", 9200)))
    outputMultipleNodes.httpNodes should be(Seq(("host-a", 9200), ("host-b", 9201)))
    output.clusterName should be("elasticsearch")
  }

  it should "parse correct index name type" in new TestingValues {
    output.indexNameType(tableName) should be(indexNameType)
  }

  it should "return a Seq of tuples (host,port) format" in new NodeValues {

    output.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(("localhost", 9200)))
    output.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(("localhost", 9300)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9200", "node", "httpPort") should be(List(
      ("host-a", 9200), ("host-b", 9201)))
    outputMultipleNodes.getHostPortConfs("nodes", "localhost", "9300", "node", "tcpPort") should be(List(
      ("host-a", 9300), ("host-b", 9301)))
  }
} 
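ElasticSearchOutputTest above builds its fixtures from plain Scala traits (BaseValues, NodeValues, TestingValues, SchemaValues) and instantiates them per test with `in new ... with ...`, so every test gets fresh values. Below is a minimal sketch of that fixture-trait idiom with no sparta dependencies; all names are illustrative.

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class FixtureTraitSpec extends FlatSpec with Matchers {

  // Base fixture: every test that mixes this in gets a fresh instance.
  trait BaseValues {
    val defaultPort = 9200
    def hostPort(host: String, port: Int = defaultPort): (String, Int) = (host, port)
  }

  // Derived fixture building on the base values.
  trait NodeValues extends BaseValues {
    val nodes = Seq(hostPort("host-a"), hostPort("host-b", 9201))
  }

  "hostPort" should "default to the base port" in new BaseValues {
    hostPort("localhost") should be(("localhost", 9200))
  }

  it should "compose with derived fixtures" in new NodeValues {
    nodes.map(_._2) should contain theSameElementsAs Seq(9200, 9201)
  }
}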
Example 28
Source File: CsvOutputIT.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.output.csv

import java.sql.Timestamp
import java.time.Instant

import com.databricks.spark.avro._
import com.stratio.sparta.plugin.TemporalSparkContext
import com.stratio.sparta.sdk.pipeline.output.{Output, SaveModeEnum}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Row, SparkSession}
import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

import scala.reflect.io.File
import scala.util.Random


@RunWith(classOf[JUnitRunner])
class CsvOutputIT extends TemporalSparkContext with Matchers {

  trait CommonValues {
    val tmpPath: String = File.makeTemp().name
    val sparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType),
      StructField("minute", LongType)
    ))

    val data =
      sparkSession.createDataFrame(sc.parallelize(Seq(
        Row("Kevin", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Kira", Random.nextInt, Timestamp.from(Instant.now).getTime),
        Row("Ariadne", Random.nextInt, Timestamp.from(Instant.now).getTime)
      )), schema)
  }

  trait WithEventData extends CommonValues {
    val properties = Map("path" -> tmpPath)
    val output = new CsvOutput("csv-test", properties)
  }


  "CsvOutput" should "throw an exception when path is not present" in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map.empty)
  }

  it should "throw an exception when empty path " in {
    an[Exception] should be thrownBy new CsvOutput("csv-test", Map("path" -> "    "))
  }

  it should "save a dataframe " in new WithEventData {
    output.save(data, SaveModeEnum.Append, Map(Output.TableNameKey -> "person"))
    val read = sparkSession.read.csv(s"$tmpPath/person.csv")
    read.count should be(3)
    read should be eq data
    File(tmpPath).deleteRecursively
    File("spark-warehouse").deleteRecursively
  }

} 
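CsvOutputIT is an integration test: it needs a live SparkContext/SparkSession, provided by the project's TemporalSparkContext helper. A rough, self-contained sketch of the same setup using a plain local SparkSession follows; the helper class is not reproduced here, and LocalSparkSessionIT is an illustrative name.

import org.apache.spark.sql.SparkSession
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class LocalSparkSessionIT extends FlatSpec with Matchers with BeforeAndAfterAll {

  // Lazily created local SparkSession standing in for TemporalSparkContext.
  private lazy val spark: SparkSession =
    SparkSession.builder().master("local[2]").appName("local-spark-it").getOrCreate()

  override def afterAll(): Unit = spark.stop()

  "A local SparkSession" should "build a small DataFrame" in {
    import spark.implicits._
    val df = Seq(("Kevin", 30), ("Kira", 25)).toDF("name", "age")
    df.count() should be(2L)
  }
}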
Example 29
Source File: LastValueOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.lastValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class LastValueOperatorTest extends WordSpec with Matchers {

  "LastValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new LastValueOperator("lastValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new LastValueOperator("lastValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(2))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("b"))
    }

    "associative process must be " in {
      val inputFields = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new LastValueOperator("lastValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))

      val inputFields4 = new LastValueOperator("lastValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new LastValueOperator("lastValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
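The operator tests in the remaining examples all follow the same WordSpec shape: one `should` block per operator, with an `in` clause per behaviour (processMap, processReduce, associativity). A minimal sketch of that shape with no sparta dependencies, using reduceOption as a stand-in aggregation:

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MinimalWordSpecTest extends WordSpec with Matchers {

  "A sum aggregation" should {

    // One `in` clause per behaviour, mirroring the processMap/processReduce split above.
    "return None for empty input" in {
      Seq.empty[Int].reduceOption(_ + _) should be(None)
    }

    "reduce its inputs" in {
      Seq(1, 2, 3).reduceOption(_ + _) should be(Some(6))
    }
  }
}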
Example 30
Source File: StddevOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.stddev

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class StddevOperatorTest extends WordSpec with Matchers {

  "Std dev operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new StddevOperator("stdev", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new StddevOperator("stdev", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new StddevOperator("stdev", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new StddevOperator("stdev", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(
        Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }

    "processReduce distinct must be " in {
      val inputFields = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be
      (Some(2.8284271247461903))

      val inputFields3 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(2.850438562747845))

      val inputFields4 = new StddevOperator("stdev", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new StddevOperator("stdev", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(
        Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some("2.850438562747845"))
    }
  }
} 
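The non-distinct expected values in StddevOperatorTest appear to be the sample (n-1) standard deviation of the inputs. A small self-contained check of that arithmetic in plain Scala, without the sparta operator (sampleStddev is a helper written for this sketch):

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class SampleStddevCheck extends WordSpec with Matchers {

  // Sample standard deviation: divide the sum of squared deviations by n - 1.
  def sampleStddev(xs: Seq[Double]): Double = {
    val mean = xs.sum / xs.size
    math.sqrt(xs.map(x => math.pow(x - mean, 2)).sum / (xs.size - 1))
  }

  "sampleStddev" should {
    "reproduce the expected values used above" in {
      sampleStddev(Seq(1.0, 2.0, 3.0, 7.0, 7.0)) should be(2.8284271247461903 +- 1e-12)
      sampleStddev(Seq(1.0, 2.0, 3.0, 6.5, 7.5)) should be(2.850438562747845 +- 1e-12)
    }
  }
}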
Example 31
Source File: MedianOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.median

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MedianOperatorTest extends WordSpec with Matchers {

  "Median operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MedianOperator("median", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MedianOperator("median", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MedianOperator("median", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MedianOperator("median", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(3d))

      val inputFields3 = new MedianOperator("median", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("3.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(2.5))

      val inputFields3 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(3))

      val inputFields4 = new MedianOperator("median", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MedianOperator("median", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("2.5"))
    }
  }
} 
Example 32
Source File: ModeOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.mode

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class ModeOperatorTest extends WordSpec with Matchers {

  "Mode operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new ModeOperator("mode", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new ModeOperator("mode", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new ModeOperator("mode", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new ModeOperator("mode", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new ModeOperator("mode", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new ModeOperator("mode", initSchema, Map())
      inputFields2.processReduce(Seq(Some("hey"), Some("hey"), Some("hi"))) should be(Some(List("hey")))

      val inputFields3 = new ModeOperator("mode", initSchema, Map())
      inputFields3.processReduce(Seq(Some("1"), Some("1"), Some("4"))) should be(Some(List("1")))

      val inputFields4 = new ModeOperator("mode", initSchema, Map())
      inputFields4.processReduce(Seq(
        Some("1"), Some("1"), Some("4"), Some("4"), Some("4"), Some("4"))) should be(Some(List("4")))

      val inputFields5 = new ModeOperator("mode", initSchema, Map())
      inputFields5.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"))) should be(Some(List("1", "2", "4")))

      val inputFields6 = new ModeOperator("mode", initSchema, Map())
      inputFields6.processReduce(Seq(
        Some("1"), Some("1"), Some("2"), Some("2"), Some("4"), Some("4"), Some("5"))
      ) should be(Some(List("1", "2", "4")))
    }
  }
} 
Example 33
Source File: RangeOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.range

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class RangeOperatorTest extends WordSpec with Matchers {

  "Range operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new RangeOperator("range", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new RangeOperator("range", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new RangeOperator("range", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new RangeOperator("range", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(0))

      val inputFields3 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(4))) should be(Some(3))

      val inputFields4 = new RangeOperator("range", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new RangeOperator("range", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some("6.0"))
    }
  }
} 
Example 34
Source File: AccumulatorOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.accumulator

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class AccumulatorOperatorTest extends WordSpec with Matchers {

  "Accumulator operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new AccumulatorOperator("accumulator", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new AccumulatorOperator("accumulator", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":2}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new AccumulatorOperator("accumulator", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":2}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(Seq("1", "1")))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(Seq("a", "b")))
    }

    "associative process must be " in {
      val inputFields = new AccumulatorOperator("accumulator", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Seq(1L))),
        (Operator.NewValuesKey, Some(Seq(2L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Seq("1", "2")))

      val inputFields2 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> "arraydouble"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(3))))
      inputFields2.associativity(resultInput2) should be(Some(Seq(1d, 3d)))

      val inputFields3 = new AccumulatorOperator("accumulator", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Seq(1))),
        (Operator.NewValuesKey, Some(Seq(1))))
      inputFields3.associativity(resultInput3) should be(Some(Seq("1", "1")))
    }
  }
} 
Example 35
Source File: FirstValueOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.firstValue

import java.util.Date

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FirstValueOperatorTest extends WordSpec with Matchers {

  "FirstValue operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FirstValueOperator("firstValue", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FirstValueOperator("firstValue", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields.processReduce(Seq()) should be(None)

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2))) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some("a"))
    }

    "associative process must be " in {
      val inputFields = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(1L))

      val inputFields2 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)))
      inputFields2.associativity(resultInput2) should be(Some(1))

      val inputFields3 = new FirstValueOperator("firstValue", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(1)),
        (Operator.NewValuesKey, Some(1)),
        (Operator.NewValuesKey, None))
      inputFields3.associativity(resultInput3) should be(Some(1))

      val inputFields4 = new FirstValueOperator("firstValue", initSchema, Map())
      val resultInput4 = Seq()
      inputFields4.associativity(resultInput4) should be(None)

      val inputFields5 = new FirstValueOperator("firstValue", initSchema, Map())
      val date = new Date()
      val resultInput5 = Seq((Operator.NewValuesKey, Some(date)))
      inputFields5.associativity(resultInput5) should be(Some(date))
    }
  }
} 
Example 36
Source File: MeanAssociativeOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.mean

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanAssociativeOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanAssociativeOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanAssociativeOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanAssociativeOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanAssociativeOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be (Some(List(1.0, 1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(List(1.0, 2.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(List()))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(List(1.0)))

      val inputFields3 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(List(1.0, 3.0)))

      val inputFields4 = new MeanAssociativeOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(List()))
    }

    "associative process must be " in {
      val inputFields = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput =
        Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("count" -> 1.0, "sum" -> 2.0, "mean" -> 2.0)))

      val inputFields2 = new MeanAssociativeOperator("avg", initSchema, Map())
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("count" -> 1d, "sum" -> 2d, "mean" -> 2d))),
        (Operator.NewValuesKey, Some(Seq(1d))))
      inputFields2.associativity(resultInput2) should be(Some(Map("sum" -> 3.0, "count" -> 2.0, "mean" -> 1.5)))
    }
  }
} 
Example 37
Source File: MeanOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.mean

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MeanOperatorTest extends WordSpec with Matchers {

  "Mean operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new MeanOperator("avg", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new MeanOperator("avg", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields4.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new MeanOperator("avg", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new MeanOperator("avg", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }

    "processReduce distinct must be " in {
      val inputFields = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(1), None)) should be(Some(1))

      val inputFields3 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields3.processReduce(Seq(Some(1), Some(3), Some(1), None)) should be(Some(2))

      val inputFields4 = new MeanOperator("avg", initSchema, Map("distinct" -> "true"))
      inputFields4.processReduce(Seq(None)) should be(Some(0d))

      val inputFields5 = new MeanOperator("avg", initSchema, Map("typeOp" -> "string", "distinct" -> "true"))
      inputFields5.processReduce(Seq(Some(1), Some(1))) should be(Some("1.0"))
    }
  }
} 
Example 38
Source File: OperatorEntityCountTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.entityCount

import java.io.{Serializable => JSerializable}

import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class OperatorEntityCountTest extends WordSpec with Matchers {

  "EntityCount" should {
    val props = Map(
      "inputField" -> "inputField".asInstanceOf[JSerializable],
      "split" -> ",".asInstanceOf[JSerializable])
    val schema = StructType(Seq(StructField("inputField", StringType)))
    val entityCount = new OperatorEntityCountMock("op1", schema, props)
    val inputFields = Row("hello,bye")

    "Return the associated precision name" in {
      val expected = Option(Seq("hello", "bye"))
      val result = entityCount.processMap(inputFields)
      result should be(expected)
    }

    "Return empty list" in {
      val expected = None
      val result = entityCount.processMap(Row())
      result should be(expected)
    }
  }
} 
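OperatorEntityCountTest exercises an abstract SDK class through a small concrete subclass (OperatorEntityCountMock, defined elsewhere in the project). The sketch below shows that test-through-a-stub-subclass pattern without the sparta SDK; Splitter and SplitterMock are hypothetical stand-ins, not the real classes.

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

// Hypothetical abstract class standing in for OperatorEntityCount.
abstract class Splitter(separator: String) {
  def split(text: String): Seq[String] = text.split(separator).toSeq
}

// Concrete stub used only by the test, mirroring the role of OperatorEntityCountMock.
class SplitterMock(separator: String) extends Splitter(separator)

@RunWith(classOf[JUnitRunner])
class SplitterMockTest extends WordSpec with Matchers {

  "A Splitter stub" should {
    "split on the configured separator" in {
      new SplitterMock(",").split("hello,bye") should be(Seq("hello", "bye"))
    }
  }
}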
Example 39
Source File: EntityCountOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.entityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class EntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new EntityCountOperator("entityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new EntityCountOperator("entityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 = new EntityCountOperator("entityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new EntityCountOperator("entityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be(Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(Seq()))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(Seq("hola", "holo")))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(Seq("hola", "holo", "hola")))
    }

    "associative process must be " in {
      val inputFields = new EntityCountOperator("entityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))

      val inputFields2 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))),
        (Operator.NewValuesKey, Some(Seq("hola"))))
      inputFields2.associativity(resultInput2) should be(Some(Map()))

      val inputFields3 = new EntityCountOperator("entityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(Map("hola" -> 1L, "holo" -> 1L))))
      inputFields3.associativity(resultInput3) should be(Some(Map("hola" -> 1L, "holo" -> 1L)))
    }
  }
} 
Example 40
Source File: SumOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.sum

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class SumOperatorTest extends WordSpec with Matchers {

  "Sum operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new SumOperator("sum", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new SumOperator("sum", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("1", 2)) should be(Some(1))

      val inputFields6 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields6.processMap(Row(1.5, 2)) should be(Some(1.5))

      val inputFields7 = new SumOperator("sum", initSchema, Map("inputField" -> "field1"))
      inputFields7.processMap(Row(5L, 2)) should be(Some(5L))

      val inputFields8 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields8.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields9 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields9.processMap(Row(1, 2)) should be(None)

      val inputFields10 = new SumOperator("sum", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields10.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(3), Some(7), Some(7))) should be(Some(20d))

      val inputFields3 = new SumOperator("sum", initSchema, Map())
      inputFields3.processReduce(Seq(Some(1), Some(2), Some(3), Some(6.5), Some(7.5))) should be(Some(20d))

      val inputFields4 = new SumOperator("sum", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0d))
    }

    "processReduce distinct must be " in {
      val inputFields = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0d))

      val inputFields2 = new SumOperator("sum", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(1), Some(2), Some(1))) should be(Some(3d))
    }

    "associative process must be " in {
      val inputFields = new SumOperator("count", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2d))

      val inputFields2 = new SumOperator("count", initSchema, Map("typeOp" -> "string"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(1L)),
        (Operator.NewValuesKey, Some(1L)),
        (Operator.NewValuesKey, None))
      inputFields2.associativity(resultInput2) should be(Some("2.0"))
    }
  }
} 
Example 41
Source File: FullTextOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.fullText

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class FullTextOperatorTest extends WordSpec with Matchers {

  "FullText operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new FullTextOperator("fullText", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new FullTextOperator("fullText", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row(1, 2)) should be(Some(1))

      val inputFields4 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"<\", \"value\":2}]"))
      inputFields4.processMap(Row(1, 2)) should be(Some(1L))

      val inputFields5 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \">\", \"value\":\"2\"}]"))
      inputFields5.processMap(Row(1, 2)) should be(None)

      val inputFields6 = new FullTextOperator("fullText", initSchema,
        Map("inputField" -> "field1", "filters" -> {
          "[{\"field\":\"field1\", \"type\": \"<\", \"value\":\"2\"}," +
            "{\"field\":\"field2\", \"type\": \"<\", \"value\":\"2\"}]"
        }))
      inputFields6.processMap(Row(1, 2)) should be(None)
    }

    "processReduce must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(""))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map())
      inputFields2.processReduce(Seq(Some(1), Some(1))) should be(Some(s"1${Operator.SpaceSeparator}1"))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map())
      inputFields3.processReduce(Seq(Some("a"), Some("b"))) should be(Some(s"a${Operator.SpaceSeparator}b"))
    }

    "associative process must be " in {
      val inputFields = new FullTextOperator("fullText", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some("2"))

      val inputFields2 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> "arraystring"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(Seq(s"2${Operator.SpaceSeparator}1")))

      val inputFields3 = new FullTextOperator("fullText", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)), (Operator.OldValuesKey, Some(3)))
      inputFields3.associativity(resultInput3) should be(Some(s"2${Operator.SpaceSeparator}3"))
    }
  }
} 
Example 42
Source File: TotalEntityCountOperatorTest.scala    From sparta   with Apache License 2.0
package com.stratio.sparta.plugin.cube.operator.totalEntityCount

import com.stratio.sparta.sdk.pipeline.aggregation.operator.Operator
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class TotalEntityCountOperatorTest extends WordSpec with Matchers {

  "Entity Count Operator" should {

    val initSchema = StructType(Seq(
      StructField("field1", IntegerType, false),
      StructField("field2", IntegerType, false),
      StructField("field3", IntegerType, false)
    ))

    val initSchemaFail = StructType(Seq(
      StructField("field2", IntegerType, false)
    ))

    "processMap must be " in {
      val inputField = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputField.processMap(Row(1, 2)) should be(None)

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchemaFail, Map("inputField" -> "field1"))
      inputFields2.processMap(Row(1, 2)) should be(None)

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1"))
      inputFields3.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields4 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields4.processMap(Row("hola holo", 2)) should be(Some(Seq("hola holo")))

      val inputFields5 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> "-"))
      inputFields5.processMap(Row("hola-holo", 2)) should be(Some(Seq("hola", "holo")))

      val inputFields6 =
        new TotalEntityCountOperator("totalEntityCount", initSchema, Map("inputField" -> "field1", "split" -> ","))
      inputFields6.processMap(Row("hola,holo adios", 2)) should be(Some(Seq("hola", "holo " + "adios")))

      val inputFields7 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]"))
      inputFields7.processMap(Row("hola", 2)) should be(None)

      val inputFields8 = new TotalEntityCountOperator("totalEntityCount", initSchema,
        Map("inputField" -> "field1", "filters" -> "[{\"field\":\"field1\", \"type\": \"!=\", \"value\":\"hola\"}]",
          "split" -> " "))
      inputFields8.processMap(Row("hola holo", 2)) should be
      (Some(Seq("hola", "holo")))
    }

    "processReduce must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo")))) should be(Some(2L))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields3.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(3L))

      val inputFields4 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      inputFields4.processReduce(Seq(None)) should be(Some(0L))
    }

    "processReduce distinct must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields.processReduce(Seq()) should be(Some(0L))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("distinct" -> "true"))
      inputFields2.processReduce(Seq(Some(Seq("hola", "holo", "hola")))) should be(Some(2L))
    }

    "associative process must be " in {
      val inputFields = new TotalEntityCountOperator("totalEntityCount", initSchema, Map())
      val resultInput = Seq((Operator.OldValuesKey, Some(2)), (Operator.NewValuesKey, None))
      inputFields.associativity(resultInput) should be(Some(2))

      val inputFields2 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> "int"))
      val resultInput2 = Seq((Operator.OldValuesKey, Some(2)),
        (Operator.NewValuesKey, Some(1)))
      inputFields2.associativity(resultInput2) should be(Some(3))

      val inputFields3 = new TotalEntityCountOperator("totalEntityCount", initSchema, Map("typeOp" -> null))
      val resultInput3 = Seq((Operator.OldValuesKey, Some(2)))
      inputFields3.associativity(resultInput3) should be(Some(2))
    }
  }
} 
Example 43
Source File: HierarchyFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.hierarchy

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.prop.TableDrivenPropertyChecks
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class HierarchyFieldTest extends WordSpecLike
with Matchers
with BeforeAndAfter
with BeforeAndAfterAll
with TableDrivenPropertyChecks {

  var hbs: Option[HierarchyField] = _

  before {
    hbs = Some(new HierarchyField())
  }

  after {
    hbs = None
  }

  "A HierarchyDimension" should {
    "In default implementation, get 4 precisions for all precision sizes" in {
      val precisionLeftToRight = hbs.get.precisionValue(HierarchyField.LeftToRightName, "")
      val precisionRightToLeft = hbs.get.precisionValue(HierarchyField.RightToLeftName, "")
      val precisionLeftToRightWithWildCard = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, "")
      val precisionRightToLeftWithWildCard = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, "")

      precisionLeftToRight._1.id should be(HierarchyField.LeftToRightName)
      precisionRightToLeft._1.id should be(HierarchyField.RightToLeftName)
      precisionLeftToRightWithWildCard._1.id should be(HierarchyField.LeftToRightWithWildCardName)
      precisionRightToLeftWithWildCard._1.id should be(HierarchyField.RightToLeftWithWildCardName)
    }

    "In default implementation, every proposed combination should be ok" in {
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "*.com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightWithWildCardName, i)
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio.*", "com.*", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftWithWildCardName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("com.stratio.sparta", Seq("com.stratio.sparta", "com.stratio", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.RightToLeftName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
    "In non-reverse implementation without wildcards, every proposed combination should be ok" in {
      hbs = Some(new HierarchyField())
      val data = Table(
        ("i", "o"),
        ("google.com", Seq("google.com", "com", "*"))
      )

      forAll(data) { (i: String, o: Seq[String]) =>
        val result = hbs.get.precisionValue(HierarchyField.LeftToRightName, i.asInstanceOf[Any])
        assertResult(o)(result._2)
      }
    }
  }
} 
Example 44
Source File: DateTimeFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.datetime

import java.io.{Serializable => JSerializable}
import java.util.Date

import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeFieldTest extends WordSpecLike with Matchers {

  val dateTimeDimension = new DateTimeField(Map("second" -> "long", "minute" -> "date", "typeOp" -> "datetime"))

  "A DateTimeDimension" should {
    "In default implementation, get 6 dimensions for a specific time" in {
      val newDate = new Date()
      val precision5s =
        dateTimeDimension.precisionValue("5s", newDate.asInstanceOf[JSerializable])
      val precision10s =
        dateTimeDimension.precisionValue("10s", newDate.asInstanceOf[JSerializable])
      val precision15s =
        dateTimeDimension.precisionValue("15s", newDate.asInstanceOf[JSerializable])
      val precisionSecond =
        dateTimeDimension.precisionValue("second", newDate.asInstanceOf[JSerializable])
      val precisionMinute =
        dateTimeDimension.precisionValue("minute", newDate.asInstanceOf[JSerializable])
      val precisionHour =
        dateTimeDimension.precisionValue("hour", newDate.asInstanceOf[JSerializable])
      val precisionDay =
        dateTimeDimension.precisionValue("day", newDate.asInstanceOf[JSerializable])
      val precisionMonth =
        dateTimeDimension.precisionValue("month", newDate.asInstanceOf[JSerializable])
      val precisionYear =
        dateTimeDimension.precisionValue("year", newDate.asInstanceOf[JSerializable])

      precision5s._1.id should be("5s")
      precision10s._1.id should be("10s")
      precision15s._1.id should be("15s")
      precisionSecond._1.id should be("second")
      precisionMinute._1.id should be("minute")
      precisionHour._1.id should be("hour")
      precisionDay._1.id should be("day")
      precisionMonth._1.id should be("month")
      precisionYear._1.id should be("year")
    }

    "Each precision dimension have their output type, second must be long, minute must be date, others datetime" in {
      dateTimeDimension.precision("5s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("10s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("15s").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("second").typeOp should be(TypeOp.Long)
      dateTimeDimension.precision("minute").typeOp should be(TypeOp.Date)
      dateTimeDimension.precision("day").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("month").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision("year").typeOp should be(TypeOp.DateTime)
      dateTimeDimension.precision(DateTimeField.TimestampPrecision.id).typeOp should be(TypeOp.Timestamp)
    }
  }
} 
Example 45
Source File: DefaultFieldTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.cube.field.defaultField

import com.stratio.sparta.plugin.default.DefaultField
import com.stratio.sparta.sdk.pipeline.aggregation.cube.{DimensionType, Precision}
import com.stratio.sparta.sdk.pipeline.schema.TypeOp
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DefaultFieldTest extends WordSpecLike with Matchers {

  val defaultDimension: DefaultField = new DefaultField(Map("typeOp" -> "int"))

  "A DefaultDimension" should {
    "In default implementation, get one precisions for a specific time" in {
      val precision: (Precision, Any) = defaultDimension.precisionValue("", "1".asInstanceOf[Any])

      precision._2 should be(1)

      precision._1.id should be(DimensionType.IdentityName)
    }

    "The precision must be int" in {
      defaultDimension.precision(DimensionType.IdentityName).typeOp should be(TypeOp.Int)
    }
  }
} 
Example 46
Source File: SocketInputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.socket

import java.io.{Serializable => JSerializable}

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SocketInputTest extends WordSpec {

  "A SocketInput" should {
    "instantiate successfully with parameters" in {
      new SocketInput(Map("hostname" -> "localhost", "port" -> 9999).mapValues(_.asInstanceOf[JSerializable]))
    }
    "fail without parameters" in {
      intercept[IllegalStateException] {
        new SocketInput(Map())
      }
    }
    "fail with bad port argument" in {
      intercept[IllegalStateException] {
        new SocketInput(Map("hostname" -> "localhost", "port" -> "BADPORT").mapValues(_.asInstanceOf[JSerializable]))
      }
    }
  }
} 
Example 47
Source File: TwitterJsonInputTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.twitter

import java.io.{Serializable => JSerializable}

import org.junit.runner.RunWith
import org.scalatest._
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TwitterJsonInputTest extends WordSpec {

  "A TwitterInput" should {

    "fail without parameters" in {
      intercept[IllegalStateException] {
        new TwitterJsonInput(Map())
      }
    }
    "fail with bad arguments argument" in {
      intercept[IllegalStateException] {
        new TwitterJsonInput(Map("hostname" -> "localhost", "port" -> "BADPORT")
          .mapValues(_.asInstanceOf[JSerializable]))
      }
    }
  }
} 
Example 48
Source File: RabbitMQInputIT.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.rabbitmq

import java.util.UUID

import akka.pattern.{ask, gracefulStop}
import com.github.sstone.amqp.Amqp._
import com.github.sstone.amqp.{Amqp, ChannelOwner, ConnectionOwner, Consumer}
import com.rabbitmq.client.ConnectionFactory
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


import scala.concurrent.Await

@RunWith(classOf[JUnitRunner])
class RabbitMQInputIT extends RabbitIntegrationSpec {

  val queueName = s"$configQueueName-${this.getClass.getName}-${UUID.randomUUID().toString}"

  def initRabbitMQ(): Unit = {
    val connFactory = new ConnectionFactory()
    connFactory.setUri(RabbitConnectionURI)
    val conn = system.actorOf(ConnectionOwner.props(connFactory, RabbitTimeOut))
    val producer = ConnectionOwner.createChildActor(
      conn,
      ChannelOwner.props(),
      timeout = RabbitTimeOut,
      name = Some("RabbitMQ.producer")
    )

    val queue = QueueParameters(
      name = queueName,
      passive = false,
      exclusive = false,
      durable = true,
      autodelete = false
    )

    Amqp.waitForConnection(system, conn, producer).await()

    val deleteQueueResult = producer ? DeleteQueue(queueName)
    Await.result(deleteQueueResult, RabbitTimeOut)
    val createQueueResult = producer ? DeclareQueue(queue)
    Await.result(createQueueResult, RabbitTimeOut)

    //Send some messages to the queue
    val results = for (register <- 1 to totalRegisters)
      yield producer ? Publish(
        exchange = "",
        key = queueName,
        body = register.toString.getBytes
      )
    results.map(result => Await.result(result, RabbitTimeOut))
    
    conn ! Close()
    Await.result(gracefulStop(conn, RabbitTimeOut), RabbitTimeOut * 2)
    Await.result(gracefulStop(consumer, RabbitTimeOut), RabbitTimeOut * 2)
  }


  "RabbitMQInput " should {

    "Read all the records" in {
      val props = Map(
        "hosts" -> hosts,
        "queueName" -> queueName)

      val input = new RabbitMQInput(props)
      val distributedStream = input.initStream(ssc.get, DefaultStorageLevel)
      val totalEvents = ssc.get.sparkContext.accumulator(0L, "Number of events received")

      // Fires each time the configured window has passed.
      distributedStream.foreachRDD(rdd => {
        if (!rdd.isEmpty()) {
          val count = rdd.count()
          // Do something with this message
          log.info(s"EVENTS COUNT : $count")
          totalEvents.add(count)
        } else log.info("RDD is empty")
        log.info(s"TOTAL EVENTS : $totalEvents")
      })

      ssc.get.start() // Start the computation
      ssc.get.awaitTerminationOrTimeout(SparkTimeOut) // Wait for the computation to terminate

      totalEvents.value should ===(totalRegisters.toLong)
    }
  }

} 
Example 49
Source File: MessageHandlerTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.rabbitmq.handler

import com.rabbitmq.client.QueueingConsumer.Delivery
import com.stratio.sparta.plugin.input.rabbitmq.handler.MessageHandler.{ByteArrayMessageHandler, StringMessageHandler}
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mock.MockitoSugar
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class MessageHandlerTest extends WordSpec with Matchers with MockitoSugar {

  val message = "This is the message for testing"
  "RabbitMQ MessageHandler Factory " should {

    "Get correct handler for string with a map " in {
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "")) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map.empty[String, String]) should matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "badInput")) should
        matchPattern { case StringMessageHandler => }
      MessageHandler(Map(MessageHandler.KeyDeserializer -> "arraybyte")) should
        matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for string " in {
      val result = MessageHandler("string")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for arraybyte " in {
      val result = MessageHandler("arraybyte")
      result should matchPattern { case ByteArrayMessageHandler => }
    }

    "Get correct handler for empty input " in {
      val result = MessageHandler("")
      result should matchPattern { case StringMessageHandler => }
    }

    "Get correct handler for bad input " in {
      val result = MessageHandler("badInput")
      result should matchPattern { case StringMessageHandler => }
    }
  }

  "StringMessageHandler " should {
    "Handle strings" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = StringMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getString(0) shouldBe message
    }
  }

  "ByteArrayMessageHandler " should {
    "Handle bytes" in {
      val delivery = mock[Delivery]
      when(delivery.getBody).thenReturn(message.getBytes)
      val result = ByteArrayMessageHandler.handler(delivery)
      verify(delivery, times(1)).getBody
      result.getAs[Array[Byte]](0) shouldBe message.getBytes
    }

  }

} 
Example 50
Source File: HostPortZkTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.input.kafka


import java.io.Serializable

import com.stratio.sparta.sdk.properties.JsoneyString
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpec}

@RunWith(classOf[JUnitRunner])
class HostPortZkTest extends WordSpec with Matchers {

  class KafkaTestInput(val properties: Map[String, Serializable]) extends KafkaBase
  
  "getHostPortZk" should {

    "return a chain (zookeper:conection , host:port)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port, zookeeper.path:path)" in {
      val conn = """[{"host": "localhost", "port": "2181"}]"""
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn))
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181"))
    }

    "return a chain (zookeper:conection , host:port,host:port,host:port, zookeeper.path:path)" in {
      val conn =
        """[{"host": "localhost", "port": "2181"},{"host": "localhost", "port": "2181"},
          |{"host": "localhost", "port": "2181"}]""".stripMargin
      val props = Map("zookeeper.connect" -> JsoneyString(conn), "zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181,localhost:2181,localhost:2181/test"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport)" in {

      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }

    "return a chain with default port (zookeper:conection , host: defaultport, zookeeper.path:path)" in {
      val props = Map("zookeeper.path" -> "/test")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181/test"))
    }

    "return a chain with default host and default porty (zookeeper.connect: ," +
      "defaultHost: defaultport," +
      "zookeeper.path:path)" in {
      val props = Map("foo" -> "var")
      val input = new KafkaTestInput(props)

      input.getHostPortZk("zookeeper.connect", "localhost", "2181") should
        be(Map("zookeeper.connect" -> "localhost:2181"))
    }
  }
} 
Example 51
Source File: MorphlinesParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.morphline

import java.io.Serializable

import com.stratio.sparta.sdk.pipeline.input.Input
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, Matchers, WordSpecLike}


@RunWith(classOf[JUnitRunner])
class MorphlinesParserTest extends WordSpecLike with Matchers with BeforeAndAfter with BeforeAndAfterAll {

  val morphlineConfig = """
          id : test1
          importCommands : ["org.kitesdk.**"]
          commands: [
          {
              readJson {},
          }
          {
              extractJsonPaths {
                  paths : {
                      col1 : /col1
                      col2 : /col2
                  }
              }
          }
          {
            java {
              code : "return child.process(record);"
            }
          }
          {
              removeFields {
                  blacklist:["literal:_attachment_body"]
              }
          }
          ]
                        """
  val inputField = Some(Input.RawDataKey)
  val outputsFields = Seq("col1", "col2")
  val props: Map[String, Serializable] = Map("morphline" -> morphlineConfig)

  val schema = StructType(Seq(StructField("col1", StringType), StructField("col2", StringType)))

  val parser = new MorphlinesParser(1, inputField, outputsFields, schema, props)

  "A MorphlinesParser" should {

    "parse a simple json" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }

    "parse a simple json removing raw" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row("hello", "world"))

      result should be eq(expected)
    }

    "exclude not configured fields" in {
      val simpleJson =
        """{
            "col1":"hello",
            "col2":"word",
            "col3":"!"
            }
        """
      val input = Row(simpleJson)
      val result = parser.parse(input)

      val expected = Seq(Row(simpleJson, "hello", "world"))

      result should be eq(expected)
    }
  }
} 
Example 52
Source File: DateTimeParserTest.scala    From sparta   with Apache License 2.0 5 votes vote down vote up
package com.stratio.sparta.plugin.transformation.datetime

import com.stratio.sparta.sdk.properties.JsoneyString
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class DateTimeParserTest extends WordSpecLike with Matchers {

  val inputField = Some("ts")
  val outputsFields = Seq("ts")

  //scalastyle:off
  "A DateTimeParser" should {
    "parse unixMillis to string" in {
      val input = Row(1416330788000L)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis"))
          .parse(input)

      val expected = Seq(Row(1416330788000L, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix"))
          .parse(input)

      val expected = Seq(Row(1416330788, "1416330788000"))

      assertResult(result)(expected)
    }

    "parse unix to string removing raw" in {
      val input = Row(1416330788)
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unix",
          "removeInputField" -> JsoneyString.apply("true")))
          .parse(input)

      val expected = Seq(Row("1416330788000"))

      assertResult(result)(expected)
    }

    "not parse anything if the field does not match" in {
      val input = Row("1212")
      val schema = StructType(Seq(StructField("otherField", StringType)))

      an[IllegalStateException] should be thrownBy new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "unixMillis")).parse(input)
    }

    "not parse anything and generate a new Date" in {
      val input = Row("anything")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "autoGenerated"))
          .parse(input)

      assertResult(result.head.size)(2)
    }

    "Auto generated if inputFormat does not exist" in {
      val input = Row("1416330788")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map()).parse(input)

      assertResult(result.head.size)(2)
    }

    "parse dateTime in hive format" in {
      val input = Row("2015-11-08 15:58:58")
      val schema = StructType(Seq(StructField("ts", StringType)))

      val result =
        new DateTimeParser(1, inputField, outputsFields, schema, Map("inputFormat" -> "hive"))
          .parse(input)

      val expected = Seq(Row("2015-11-08 15:58:58", "1446998338000"))

      assertResult(result)(expected)
    }
  }
} 
Example 53
Source File: LongInputTests.scala    From boson   with Apache License 2.0 5 votes vote down vote up
package io.zink.boson

import bsonLib.BsonObject
import io.netty.util.ResourceLeakDetector
import io.vertx.core.json.JsonObject
import io.zink.boson.bson.bsonImpl.BosonImpl
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner
import org.junit.Assert._

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.io.Source


@RunWith(classOf[JUnitRunner])
class LongInputTests extends FunSuite {
  ResourceLeakDetector.setLevel(ResourceLeakDetector.Level.ADVANCED)

  val bufferedSource: Source = Source.fromURL(getClass.getResource("/jsonOutput.txt"))
  val finale: String = bufferedSource.getLines.toSeq.head
  bufferedSource.close

  val json: JsonObject = new JsonObject(finale)
  val bson: BsonObject = new BsonObject(json)

  test("extract top field") {
    val expression: String = ".Epoch"
    val boson: Boson = Boson.extractor(expression, (out: Int) => {
      assertTrue(3 == out)
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract bottom field") {
    val expression: String = "SSLNLastName"
    val expected: String = "de Huanuco"
    val boson: Boson = Boson.extractor(expression, (out: String) => {
      assertTrue(expected.zip(out).forall(e => e._1.equals(e._2)))
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
  }

  test("extract positions of an Array") {
    val expression: String = "Markets[3 to 5]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(3, mutableBuffer.size)
  }

  test("extract further positions of an Array") {
    val expression: String = "Markets[50 to 55]"
    val mutableBuffer: ArrayBuffer[Array[Byte]] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Array[Byte]) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(6, mutableBuffer.size)
  }

  test("size of all occurrences of Key") {
    val expression: String = "Price"
    val mutableBuffer: ArrayBuffer[Float] = ArrayBuffer()
    val boson: Boson = Boson.extractor(expression, (out: Float) => {
      mutableBuffer += out
    })
    val res = boson.go(bson.encode.getBytes)
    Await.result(res, Duration.Inf)
    assertEquals(195, mutableBuffer.size)
  }

} 
Example 54
Source File: StorageTest.scala    From mqttd   with MIT License 5 votes vote down vote up
package plantae.citrus.mqtt.actors.session

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class StorageTest extends FunSuite {
  test("persist test") {
    val storage = Storage("persist-test")
    Range(1, 2000).foreach(count => {
      storage.persist((count + " persist").getBytes, (count % 3).toShort, true, "topic" + count)
    })

    assert(
      !Range(1, 2000).exists(count => {
        storage.nextMessage match {
          case Some(message) =>
            storage.complete(message.packetId match {
              case Some(x) => Some(x)
              case None => None
            })
            println(new String(message.payload.toArray))
            count + " persist" != new String(message.payload.toArray)

          case None => true
        }
      })
    )
  }
} 
Example 55
Source File: LogisticRegressionTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LogisticRegressionTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val lr = new LogisticRegression(properTrain, 30)
    val model = ICP.trainClassifier(lr, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 56
Source File: SVMTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SVMTest extends FunSuite with SharedSparkContext {

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val svm = new SVM(properTrain, 30)
    val model = ICP.trainClassifier(svm, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 57
Source File: GBTTest.scala    From spark-cp   with Apache License 2.0 5 votes vote down vote up
package se.uu.farmbio.cp.alg

import scala.util.Random
import org.apache.spark.SharedSparkContext
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import se.uu.farmbio.cp.ICP
import se.uu.farmbio.cp.TestUtils
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GBTTest extends FunSuite with SharedSparkContext {
  
  Random.setSeed(11)

  test("test performance") {
    val trainData = TestUtils.generateBinaryData(100, 11)
    val testData = TestUtils.generateBinaryData(30, 22)
    val (calibration, properTrain) = ICP.calibrationSplit(sc.parallelize(trainData), 16)  
    val gbt = new GBT(properTrain, 30)
    val model = ICP.trainClassifier(gbt, numClasses=2, calibration)
    assert(TestUtils.testPerformance(model, sc.parallelize(testData)))
  }

} 
Example 58
Source File: KafkaTestUtilsTest.scala    From spark-testing-base   with Apache License 2.0 5 votes vote down vote up
package com.holdenkarau.spark.testing.kafka

import java.util.Properties

import scala.collection.JavaConversions._

import kafka.consumer.ConsumerConfig
import org.apache.spark.streaming.kafka.KafkaTestUtils
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, FunSuite}

@RunWith(classOf[JUnitRunner])
class KafkaTestUtilsTest extends FunSuite with BeforeAndAfterAll {

  private var kafkaTestUtils: KafkaTestUtils = _

  override def beforeAll(): Unit = {
    kafkaTestUtils = new KafkaTestUtils
    kafkaTestUtils.setup()
  }

  override def afterAll(): Unit = if (kafkaTestUtils != null) {
    kafkaTestUtils.teardown()
    kafkaTestUtils = null
  }

  test("Kafka send and receive message") {
    val topic = "test-topic"
    val message = "HelloWorld!"
    kafkaTestUtils.createTopic(topic)
    kafkaTestUtils.sendMessages(topic, message.getBytes)

    val consumerProps = new Properties()
    consumerProps.put("zookeeper.connect", kafkaTestUtils.zkAddress)
    consumerProps.put("group.id", "test-group")
    consumerProps.put("flow-topic", topic)
    consumerProps.put("auto.offset.reset", "smallest")
    consumerProps.put("zookeeper.session.timeout.ms", "2000")
    consumerProps.put("zookeeper.connection.timeout.ms", "6000")
    consumerProps.put("zookeeper.sync.time.ms", "2000")
    consumerProps.put("auto.commit.interval.ms", "2000")

    val consumer = kafka.consumer.Consumer.createJavaConsumerConnector(new ConsumerConfig(consumerProps))

    try {
      val topicCountMap = Map(topic -> new Integer(1))
      val consumerMap = consumer.createMessageStreams(topicCountMap)
      val stream = consumerMap.get(topic).get(0)
      val it = stream.iterator()
      val mess = it.next
      assert(new String(mess.message().map(_.toChar)) === message)
    } finally {
      consumer.shutdown()
    }
  }

} 
Example 59
Source File: TestBase.scala    From open-korean-text   with Apache License 2.0 5 votes vote down vote up
package org.openkoreantext.processor

import java.util.logging.{Level, Logger}

import org.junit.runner.RunWith
import org.openkoreantext.processor.util.KoreanDictionaryProvider._
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

object TestBase {

  case class ParseTime(time: Long, chunk: String)

  def time[R](block: => R): Long = {
    val t0 = System.currentTimeMillis()
    block
    val t1 = System.currentTimeMillis()
    t1 - t0
  }

  def assertExamples(exampleFiles: String, log: Logger, f: (String => String)) {
    assert({
      val input = readFileByLineFromResources(exampleFiles)

      val (parseTimes, allMatched) = input.foldLeft((List[ParseTime](), true)) {
        case ((l: List[ParseTime], output: Boolean), line: String) =>
          val s = line.split("\t")
          val (chunk, parse) = (s(0), if (s.length == 2) s(1) else "")

          val oldTokens = parse
          val t0 = System.currentTimeMillis()
          val newTokens = f(chunk)
          val t1 = System.currentTimeMillis()

          val oldParseMatches = oldTokens == newTokens

          if (!oldParseMatches) {
            System.err.println("Example set match error: %s \n - EXPECTED: %s\n - ACTUAL  : %s".format(
              chunk, oldTokens, newTokens))
          }

          (ParseTime(t1 - t0, chunk) :: l, output && oldParseMatches)
      }

      val averageTime = parseTimes.map(_.time).sum.toDouble / parseTimes.size
      val maxItem = parseTimes.maxBy(_.time)

      log.log(Level.INFO, ("Parsed %d chunks. \n" +
          "       Total time: %d ms \n" +
          "       Average time: %.2f ms \n" +
          "       Max time: %d ms, %s").format(
            parseTimes.size,
            parseTimes.map(_.time).sum,
            averageTime,
            maxItem.time,
            maxItem.chunk
          ))
      allMatched
    }, "Some parses did not match the example set.")
  }
}

@RunWith(classOf[JUnitRunner])
abstract class TestBase extends FunSuite 
Example 60
Source File: StreamingFormulaDemo1.scala    From sscheck   with Apache License 2.0 5 votes vote down vote up
package es.ucm.fdi.sscheck.spark.demo

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.Specification
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream

import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula,DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.gen.{PDStreamGen,BatchGen}

@RunWith(classOf[JUnitRunner])
class StreamingFormulaDemo1 
  extends Specification 
  with DStreamTLProperty
  with ResultMatchers
  with ScalaCheck {
  
  // Spark configuration
  override def sparkMaster : String = "local[*]"
  override def batchDuration = Duration(150)
  override def defaultParallelism = 4  

  def is = 
    sequential ^ s2"""
    Simple demo Specs2 example for ScalaCheck properties with temporal
    formulas on Spark Streaming programs
      - where a simple property for DStream.count is a success ${countForallAlwaysProp(_.count)}     
      - where a faulty implementation of the DStream.count is detected ${countForallAlwaysProp(faultyCount) must beFailing}
    """
      
  def faultyCount(ds : DStream[Double]) : DStream[Long] = 
    ds.count.transform(_.map(_ - 1))
      
  def countForallAlwaysProp(testSubject : DStream[Double] => DStream[Long]) = {
    type U = (RDD[Double], RDD[Long])
    val (inBatch, transBatch) = ((_ : U)._1, (_ : U)._2)
    val numBatches = 10
    val formula : Formula[U] = always { (u : U) =>
      transBatch(u).count === 1 and
      inBatch(u).count === transBatch(u).first 
    } during numBatches

    val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Double]), numBatches)
    
    forAllDStream(
      gen)(
      testSubject)(
      formula)
  }.set(minTestsOk = 10).verbose  
  
} 
Example 61
Source File: StreamingFormulaDemo2.scala    From sscheck   with Apache License 2.0 5 votes vote down vote up
package es.ucm.fdi.sscheck.spark.demo

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.Specification
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Gen

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.dstream.DStream._

import scalaz.syntax.std.boolean._
    
import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula,DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.gen.{PDStreamGen,BatchGen}
import es.ucm.fdi.sscheck.gen.BatchGenConversions._
import es.ucm.fdi.sscheck.gen.PDStreamGenConversions._
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

@RunWith(classOf[JUnitRunner])
class StreamingFormulaDemo2 
  extends Specification 
  with DStreamTLProperty
  with ResultMatchers
  with ScalaCheck {
  
  // Spark configuration
  override def sparkMaster : String = "local[*]"
  override def batchDuration = Duration(300)
  override def defaultParallelism = 3
  override def enableCheckpointing = true

  def is = 
    sequential ^ s2"""
    Check process to persistently detect and ban bad users
      - where a stateful implementation extracts the banned users correctly ${checkExtractBannedUsersList(listBannedUsers)}
      - where a trivial implementation ${checkExtractBannedUsersList(statelessListBannedUsers) must beFailing}
    """
  type UserId = Long
  
  def listBannedUsers(ds : DStream[(UserId, Boolean)]) : DStream[UserId] = 
    ds.updateStateByKey((flags : Seq[Boolean], maybeFlagged : Option[Unit]) =>
      maybeFlagged match {
        case Some(_) => maybeFlagged  
        case None => flags.contains(false) option {()}
      } 
    ).transform(_.keys)
      
  def statelessListBannedUsers(ds : DStream[(UserId, Boolean)]) : DStream[UserId] =
    ds.map(_._1)
    
  def checkExtractBannedUsersList(testSubject : DStream[(UserId, Boolean)] => DStream[UserId]) = {
    val batchSize = 20 
    val (headTimeout, tailTimeout, nestedTimeout) = (10, 10, 5) 
    val (badId, ids) = (15L, Gen.choose(1L, 50L))   
    val goodBatch = BatchGen.ofN(batchSize, ids.map((_, true)))
    val badBatch = goodBatch + BatchGen.ofN(1, (badId, false))
    val gen = BatchGen.until(goodBatch, badBatch, headTimeout) ++ 
               BatchGen.always(Gen.oneOf(goodBatch, badBatch), tailTimeout)
    
    type U = (RDD[(UserId, Boolean)], RDD[UserId])
    val (inBatch, outBatch) = ((_ : U)._1, (_ : U)._2)
    
    val formula = {
      val badInput = at(inBatch)(_ should existsRecord(_ == (badId, false)))
      val allGoodInputs = at(inBatch)(_ should foreachRecord(_._2 == true))
      val noIdBanned = at(outBatch)(_.isEmpty)
      val badIdBanned = at(outBatch)(_ should existsRecord(_ == badId))
      
      ( ( allGoodInputs and noIdBanned ) until badIdBanned on headTimeout ) and
      ( always { badInput ==> (always(badIdBanned) during nestedTimeout) } during tailTimeout )  
    }  
    
    forAllDStream(    
      gen)(
      testSubject)( 
      formula)
  }.set(minTestsOk = 10).verbose

} 
Example 62
Source File: SimpleStreamingFormulas.scala    From sscheck   with Apache License 2.0 5 votes vote down vote up
package es.ucm.fdi.sscheck.spark.simple

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.matcher.ResultMatchers
import org.scalacheck.Arbitrary.arbitrary
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.Duration
import org.apache.spark.streaming.dstream.DStream
import es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEach
import es.ucm.fdi.sscheck.prop.tl.{Formula,DStreamTLProperty}
import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._
import es.ucm.fdi.sscheck.gen.{PDStreamGen,BatchGen}
import org.scalacheck.Gen
import es.ucm.fdi.sscheck.gen.PDStream
import es.ucm.fdi.sscheck.gen.Batch

@RunWith(classOf[JUnitRunner])
class SimpleStreamingFormulas 
  extends org.specs2.Specification 
  with DStreamTLProperty
  with org.specs2.ScalaCheck {
  
   // Spark configuration
  override def sparkMaster : String = "local[*]"
  override def batchDuration = Duration(50)
  override def defaultParallelism = 4  

  def is = 
    sequential ^ s2"""
    Simple demo Specs2 example for ScalaCheck properties with temporal
    formulas on Spark Streaming programs
      - Given a stream of integers
        When we filter out negative numbers
        Then we get only numbers greater or equal to 
          zero $filterOutNegativeGetGeqZero
      - where time increments for each batch $timeIncreasesMonotonically
      """
      
    def filterOutNegativeGetGeqZero = {
      type U = (RDD[Int], RDD[Int])
      val numBatches = 10
      val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Int]), 
                                numBatches)
      val formula = always(nowTime[U]{ (letter, time) => 
        val (_input, output) = letter
        output should foreachRecord {_ >= 0} 
      }) during numBatches
      
      forAllDStream(
      gen)(
      _.filter{ x => !(x < 0)})(
      formula)
    }.set(minTestsOk = 50).verbose

    def timeIncreasesMonotonically = {
      type U = (RDD[Int], RDD[Int])
      val numBatches = 10
      val gen = BatchGen.always(BatchGen.ofNtoM(10, 50, arbitrary[Int]))

      val formula = always(nextTime[U]{ (letter, time) =>
        nowTime[U]{ (nextLetter, nextTime) =>
          time.millis <= nextTime.millis
        }
      }) during numBatches-1

      forAllDStream(
      gen)(
      identity[DStream[Int]])(
      formula)
    }.set(minTestsOk = 10).verbose
} 
Example 63
Source File: SharedStreamingContextBeforeAfterEachTest.scala    From sscheck   with Apache License 2.0 5 votes vote down vote up
package es.ucm.fdi.sscheck.spark.streaming

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner 
import org.specs2.execute.Result

import org.apache.spark.streaming.Duration
import org.apache.spark.rdd.RDD

import scala.collection.mutable.Queue
import scala.concurrent.duration._

import org.slf4j.LoggerFactory

import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

// sbt "test-only es.ucm.fdi.sscheck.spark.streaming.SharedStreamingContextBeforeAfterEachTest"

@RunWith(classOf[JUnitRunner])
class SharedStreamingContextBeforeAfterEachTest 
  extends org.specs2.Specification 
  with org.specs2.matcher.MustThrownExpectations 
  with org.specs2.matcher.ResultMatchers
  with SharedStreamingContextBeforeAfterEach {
  
  // cannot use private[this] due to https://issues.scala-lang.org/browse/SI-8087
  @transient private val logger = LoggerFactory.getLogger("SharedStreamingContextBeforeAfterEachTest")
  
  // Spark configuration
  override def sparkMaster : String = "local[5]"
  override def batchDuration = Duration(250) 
  override def defaultParallelism = 3
  override def enableCheckpointing = false // as queueStream doesn't support checkpointing 
  
  def is = 
    sequential ^ s2"""
    Simple test for SharedStreamingContextBeforeAfterEach 
      where a simple queueStream test must be successful $successfulSimpleQueueStreamTest
      where a simple queueStream test can also fail $failingSimpleQueueStreamTest
    """      
            
  def successfulSimpleQueueStreamTest = simpleQueueStreamTest(expectedCount = 0)
  def failingSimpleQueueStreamTest = simpleQueueStreamTest(expectedCount = 1) must beFailing
        
  def simpleQueueStreamTest(expectedCount : Int) : Result = {
    val record = "hola"
    val batches = Seq.fill(5)(Seq.fill(10)(record))
    val queue = new Queue[RDD[String]]
    queue ++= batches.map(batch => sc.parallelize(batch, numSlices = defaultParallelism))
    val inputDStream = ssc.queueStream(queue, oneAtATime = true)
    val sizesDStream = inputDStream.map(_.length)
    
    var batchCount = 0
    // NOTE wrapping assertions with a Result object is needed
    // to avoid the Spark Streaming runtime capturing the exceptions
    // from failing assertions
    var result : Result = ok
    inputDStream.foreachRDD { rdd =>
      batchCount += 1
      println(s"completed batch number $batchCount: ${rdd.collect.mkString(",")}")
      result = result and {
        rdd.filter(_!= record).count() === expectedCount
        rdd should existsRecord(_ == "hola")
      }
    }
    sizesDStream.foreachRDD { rdd =>
      result = result and { 
        rdd should foreachRecord(record.length)(len => _ == len)      
      }
    }
    
    // should only start the dstream after all the transformations and actions have been defined
    ssc.start()
    
    // wait for completion of batches.length batches
    StreamingContextUtils.awaitForNBatchesCompleted(batches.length, atMost = 10 seconds)(ssc)
    
    result
  }
} 
Example 64
Source File: ScalaCheckStreamingTest.scala    From sscheck   with Apache License 2.0 5 votes vote down vote up
package es.ucm.fdi.sscheck.spark.streaming

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import org.specs2.ScalaCheck
import org.specs2.execute.{AsResult, Result}

import org.scalacheck.{Prop, Gen}
import org.scalacheck.Arbitrary.arbitrary

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration}
import org.apache.spark.streaming.dstream.DStream

import es.ucm.fdi.sscheck.prop.tl.Formula._
import es.ucm.fdi.sscheck.prop.tl.DStreamTLProperty
import es.ucm.fdi.sscheck.matcher.specs2.RDDMatchers._

@RunWith(classOf[JUnitRunner])
class ScalaCheckStreamingTest 
  extends org.specs2.Specification  
  with DStreamTLProperty
  with org.specs2.matcher.ResultMatchers
  with ScalaCheck {
    
  override def sparkMaster : String = "local[5]"
  override def batchDuration = Duration(350)
  override def defaultParallelism = 4  
  
  def is = 
    sequential ^ s2"""
    Simple properties for Spark Streaming
      - where the first property is a success $prop1
      - where a simple property for DStream.count is a success ${countProp(_.count)}
      - where a faulty implementation of the DStream.count is detected ${countProp(faultyCount) must beFailing}
    """    
      
  def prop1 = {
    val batchSize = 30   
    val numBatches = 10 
    val dsgenSeqSeq1 = {
      val zeroSeqSeq = Gen.listOfN(numBatches,  Gen.listOfN(batchSize, 0)) 
      val oneSeqSeq = Gen.listOfN(numBatches, Gen.listOfN(batchSize, 1))
      Gen.oneOf(zeroSeqSeq, oneSeqSeq)  
    } 
    type U = (RDD[Int], RDD[Int])
    
    forAllDStream[Int, Int](
      "inputDStream" |: dsgenSeqSeq1)(
      (inputDs : DStream[Int]) => {  
        val transformedDs = inputDs.map(_+1)
        transformedDs
      })(always ((u : U) => {
          val (inputBatch, transBatch) = u
          inputBatch.count === batchSize and 
          inputBatch.count === transBatch.count and
          (inputBatch.intersection(transBatch).isEmpty should beTrue) and
          ( inputBatch should foreachRecord(_ == 0) or 
            (inputBatch should foreachRecord(_ == 1)) 
          )
        }) during numBatches 
      )}.set(minTestsOk = 10).verbose
      
  def faultyCount(ds : DStream[Double]) : DStream[Long] = 
    ds.count.transform(_.map(_ - 1))
    
  def countProp(testSubject : DStream[Double] => DStream[Long]) = {
    type U = (RDD[Double], RDD[Long])
    val numBatches = 10 
    forAllDStream[Double, Long]( 
      Gen.listOfN(numBatches,  Gen.listOfN(30, arbitrary[Double])))(
      testSubject
      )(always ((u : U) => {
         val (inputBatch, transBatch) = u
         transBatch.count === 1 and
         inputBatch.count === transBatch.first
      }) during numBatches
    )}.set(minTestsOk = 10).verbose
    
} 
Example 65
Source File: ITSelectorSuite.scala    From spark-infotheoretic-feature-selection   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.ml.feature

import org.apache.spark.sql.{DataFrame, SQLContext}
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.junit.JUnitRunner
import TestHelper._

@RunWith(classOf[JUnitRunner])
class ITSelectorSuite extends FunSuite with BeforeAndAfterAll {
  // The Spark/SQL setup that provides `sqlContext` comes from the suite's beforeAll and TestHelper (not shown in this excerpt).


  test("Run ITFS on nci data (nPart = 10, nfeat = 10)") {

    val df = readCSVData(sqlContext, "test_nci9_s3.csv")
    val cols = df.columns
    val pad = 2
    val allVectorsDense = true
    val model = getSelectorModel(sqlContext, df, cols.drop(1), cols.head, 
        10, 10, allVectorsDense, pad)

    assertResult("443, 755, 1369, 1699, 3483, 5641, 6290, 7674, 9399, 9576") {
      model.selectedFeatures.mkString(", ")
    }
  }
} 
Example 66
Source File: instagram_api_yaml.scala    From play-swagger   with MIT License 5 votes vote down vote up
package instagram.api.yaml

import de.zalando.play.controllers._
import org.scalacheck._
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop._
import org.scalacheck.Test._
import org.specs2.mutable._
import play.api.test.Helpers._
import play.api.test._
import play.api.mvc.MultipartFormData.FilePart
import play.api.mvc._

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import java.net.URLEncoder
import com.fasterxml.jackson.databind.ObjectMapper

import play.api.http.Writeable
import play.api.libs.Files.TemporaryFile
import play.api.test.Helpers.{status => requestStatusCode_}
import play.api.test.Helpers.{contentAsString => requestContentAsString_}
import play.api.test.Helpers.{contentType => requestContentType_}

import scala.math.BigInt
import scala.math.BigDecimal

import Generators._

    @RunWith(classOf[JUnitRunner])
    class Instagram_api_yamlSpec extends Specification {
        def toPath[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")
        def toQuery[T](key: String, value: T)(implicit binder: QueryStringBindable[T]): String = Option(binder.unbind(key, value)).getOrElse("")
        def toHeader[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")

      def checkResult(props: Prop) =
        Test.check(Test.Parameters.default, props).status match {
          case Failed(args, labels) =>
            val failureMsg = labels.mkString("\n") + " given args: " + args.map(_.arg).mkString("'", "', '","'")
            failure(failureMsg)
          case Proved(_) | Exhausted | Passed => success
          case PropException(_, e, labels) =>
            val error = if (labels.isEmpty) e.getLocalizedMessage() else labels.mkString("\n")
            failure(error)
        }

      private def parserConstructor(mimeType: String) = PlayBodyParsing.jacksonMapper(mimeType)

      def parseResponseContent[T](mapper: ObjectMapper, content: String, mimeType: Option[String], expectedType: Class[T]) =
        mapper.readValue(content, expectedType)

} 
Example 67
Source File: security_api_yaml.scala    From play-swagger   with MIT License 5 votes vote down vote up
package security.api.yaml

import de.zalando.play.controllers._
import org.scalacheck._
import org.scalacheck.Arbitrary._
import org.scalacheck.Prop._
import org.scalacheck.Test._
import org.specs2.mutable._
import play.api.test.Helpers._
import play.api.test._
import play.api.mvc.MultipartFormData.FilePart
import play.api.mvc._

import org.junit.runner.RunWith
import org.specs2.runner.JUnitRunner
import java.net.URLEncoder
import com.fasterxml.jackson.databind.ObjectMapper

import play.api.http.Writeable
import play.api.libs.Files.TemporaryFile
import play.api.test.Helpers.{status => requestStatusCode_}
import play.api.test.Helpers.{contentAsString => requestContentAsString_}
import play.api.test.Helpers.{contentType => requestContentType_}

import de.zalando.play.controllers.ArrayWrapper

import Generators._

    @RunWith(classOf[JUnitRunner])
    class Security_api_yamlSpec extends Specification {
        def toPath[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")
        def toQuery[T](key: String, value: T)(implicit binder: QueryStringBindable[T]): String = Option(binder.unbind(key, value)).getOrElse("")
        def toHeader[T](value: T)(implicit binder: PathBindable[T]): String = Option(binder.unbind("", value)).getOrElse("")

      def checkResult(props: Prop) =
        Test.check(Test.Parameters.default, props).status match {
          case Failed(args, labels) =>
            val failureMsg = labels.mkString("\n") + " given args: " + args.map(_.arg).mkString("'", "', '","'")
            failure(failureMsg)
          case Proved(_) | Exhausted | Passed => success
          case PropException(_, e, labels) =>
            val error = if (labels.isEmpty) e.getLocalizedMessage() else labels.mkString("\n")
            failure(error)
        }

      private def parserConstructor(mimeType: String) = PlayBodyParsing.jacksonMapper(mimeType)

      def parseResponseContent[T](mapper: ObjectMapper, content: String, mimeType: Option[String], expectedType: Class[T]) =
        mapper.readValue(content, expectedType)

} 
Example 68
Source File: Downloader$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.File
import java.net.URL

import org.junit.runner.RunWith
import org.scalatest.{Ignore, FunSuite}
import org.scalatest.junit.JUnitRunner


@Ignore
class Downloader$Test extends FunSuite {

  test("downloading-something") {

    val hello = new File("hello-test.html")
    val mystem = new File("atmta.binary")

    Downloader.downloadBinaryFile(new URL("http://www.stachek66.ru/"), hello)

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
      mystem
    )

    Downloader.downloadBinaryFile(
      new URL("http://download.cdn.yandex.net/mystem/mystem-3.1-win-64bit.zip"),
      mystem
    )

    hello.delete
    mystem.delete
  }

  test("download-and-unpack") {
    val bin = new File("atmta.binary.tar.gz")
    val bin2 = new File("executable")

    Decompressor.select.unpack(
      Downloader.downloadBinaryFile(
        new URL("http://download.cdn.yandex.net/mystem/mystem-3.0-linux3.1-64bit.tar.gz"),
        bin),
      bin2
    )

    bin.delete
    bin2.delete
  }
} 
Example 69
Source File: Zip$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

import org.junit.runner.RunWith


class Zip$Test extends FunSuite {

  test("zip-test") {
    val src = new File("src/test/resources/test.txt")
    Zip.unpack(
      new File("src/test/resources/test.zip"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }

} 
Example 70
Source File: TarGz$Test.scala    From mystem-scala   with MIT License 5 votes vote down vote up
package ru.stachek66.tools

import java.io.{File, FileInputStream}

import org.apache.commons.io.IOUtils
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


class TarGz$Test extends FunSuite {

  test("tgz-test") {
    val src = new File("src/test/resources/test.txt")
    TarGz.unpack(
      new File("src/test/resources/test.tar.gz"),
      new File("src/test/resources/res.txt")) match {
      case f =>
        val content0 = IOUtils.toString(new FileInputStream(f))
        val content1 = IOUtils.toString(new FileInputStream(src))
        print(content0.trim + " vs " + content1.trim)
        assert(content0 === content1)
    }
  }
} 
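The three mystem-scala suites above (Examples 68-70) import RunWith and JUnitRunner but never apply the annotation, so they run under ScalaTest's own runner (and the first one is @Ignore'd entirely). For completeness, a minimal sketch of what wiring such a FunSuite into JUnit looks like; the suite name and assertion are hypothetical:

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

// With the annotation in place, JUnit-based build tooling can discover and
// run the ScalaTest suite directly.
@RunWith(classOf[JUnitRunner])
class RoundTripSketchTest extends FunSuite {

  test("reversing a string twice returns the original") {
    assert("mystem".reverse.reverse === "mystem")
  }
}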
Example 71
Source File: DaoServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.dao

import com.ivan.nikolov.scheduler.TestEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfter, FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class DaoServiceTest extends FlatSpec with Matchers with BeforeAndAfter with TestEnvironment {

  override val databaseService = new H2DatabaseService
  override val migrationService = new MigrationService
  override val daoService = new DaoServiceImpl
  
  before {
    // we run this here. Generally migrations will only
    // be dealing with data layout and we will be able to have
    // test classes that insert test data.
    migrationService.runMigrations()
  }
  
  after {
    migrationService.cleanupDatabase()
  }
  
  "readResultSet" should "properly iterate over a result set and apply a function to it." in {
    val connection = daoService.getConnection()
    try {
      val result = daoService.executeSelect(
        connection.prepareStatement(
          "SELECT name FROM people"
        )
      ) {
        case rs =>
          daoService.readResultSet(rs) {
            case row =>
              row.getString("name")
          }
      }
      result should have size(3)
      result should contain("Ivan")
      result should contain("Maria")
      result should contain("John")
    } finally {
      connection.close()
      
    }
  }
} 
Example 72
Source File: JobConfigReaderServiceTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.services

import com.ivan.nikolov.scheduler.TestEnvironment
import com.ivan.nikolov.scheduler.config.job.{Console, Daily, JobConfig, TimeOptions}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class JobConfigReaderServiceTest extends FlatSpec with Matchers with TestEnvironment {

  override val ioService: IOService = new IOService
  override val jobConfigReaderService: JobConfigReaderService = new JobConfigReaderService

  "readJobConfigs" should "read and parse configurations successfully." in {
    val result = jobConfigReaderService.readJobConfigs()
    result should have size(1)
    result should contain(
      JobConfig(
        "Test Command",
        "ping google.com -c 10",
        Console,
        Daily,
        TimeOptions(12, 10)
      )
    )
  }
} 
Example 73
Source File: TimeOptionsTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.scheduler.config.job

import java.time.LocalDateTime

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TimeOptionsTest extends FlatSpec with Matchers {

  "getInitialDelay" should "get the right initial delay for hourly less than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 10)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toMinutes should equal(20)
  }
  
  it should "get the right initial delay for hourly more than an hour after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 18, 51, 17)
    val later = now.plusHours(3)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Hourly)
    result.toHours should equal(3)
  }
  
  it should "get the right initial delay for hourly less than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 11, 18, 55)
    val earlier = now.minusMinutes(25)
    // Only run when the earlier time stays on the same day; the delay logic
    // assumes that, and the assertion would fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }
  
  it should "get the right initial delay for hourly more than an hour before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 12, 43, 59)
    val earlier = now.minusHours(1).minusMinutes(25)
    // Only run when the earlier time stays on the same day; the delay logic
    // assumes that, and the assertion would fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Hourly)
      result.toMinutes should equal(35)
    }
  }
  
  it should "get the right initial delay for daily before now." in {
    val now = LocalDateTime.of(2018, 3, 20, 14, 43, 10)
    val earlier = now.minusMinutes(25)
    // Only run when the earlier time stays on the same day; the delay logic
    // assumes that, and the assertion would fail otherwise.
    if (earlier.getDayOfWeek == now.getDayOfWeek) {
      val timeOptions = TimeOptions(earlier.getHour, earlier.getMinute)
      val result = timeOptions.getInitialDelay(now, Daily)
      result.toMinutes should equal(24 * 60 - 25)
    }
  }
  
  it should "get the right initial delay for daily after now." in {
    val now = LocalDateTime.of(2018, 3, 20, 16, 21, 6)
    val later = now.plusMinutes(20)
    val timeOptions = TimeOptions(later.getHour, later.getMinute)
    val result = timeOptions.getInitialDelay(now, Daily)
    result.toMinutes should equal(20)
  }
} 
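In the "daily before now" case above, the scheduled time passed 25 minutes ago, so the next daily run is a full day minus those 25 minutes away, hence the expected 24 * 60 - 25 minutes. A quick hand-check of that arithmetic with java.time (a standalone sketch, not the project's TimeOptions implementation):

import java.time.{Duration, LocalDateTime}

object DailyDelaySketch extends App {
  val now       = LocalDateTime.of(2018, 3, 20, 14, 43, 10)
  val scheduled = now.minusMinutes(25)                          // the configured run time, already in the past
  val delay     = Duration.between(now, scheduled.plusDays(1))  // next occurrence is tomorrow at the same time
  println(delay.toMinutes)                                      // 1415 == 24 * 60 - 25
}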
Example 74
Source File: TraitATest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitATest extends FlatSpec with Matchers with A {

  "hello" should "greet properly." in {
    hello() should equal("Hello, I am trait A!")
  }
  
  "pass" should "return the right string with the number." in {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }
  
  it should "be correct also for negative values." in {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
} 
Example 75
Source File: TraitACaseScopeTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.composition

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class TraitACaseScopeTest extends FlatSpec with Matchers {
  "hello" should "greet properly." in new A {
    hello() should equal("Hello, I am trait A!")
  }

  "pass" should "return the right string with the number." in new A {
    pass(10) should equal("Trait A said: 'You passed 10.'")
  }

  it should "be correct also for negative values." in new A {
    pass(-10) should equal("Trait A said: 'You passed -10.'")
  }
} 
Example 76
Source File: UserComponentTest.scala    From Scala-Design-Patterns-Second-Edition   with MIT License 5 votes vote down vote up
package com.ivan.nikolov.cake

import com.ivan.nikolov.cake.model.Person
import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

@RunWith(classOf[JUnitRunner])
class UserComponentTest extends FlatSpec with Matchers with MockitoSugar with TestEnvironment {
  val className = "A"
  val emptyClassName = "B"
  val people = List(
    Person(1, "a", 10),
    Person(2, "b", 15),
    Person(3, "c", 20)
  )
  
  override val userService = new UserService
  
  when(dao.getPeopleInClass(className)).thenReturn(people)
  when(dao.getPeopleInClass(emptyClassName)).thenReturn(List())
  
  "getAverageAgeOfUsersInClass" should "properly calculate the average of all ages." in {
    userService.getAverageAgeOfUsersInClass(className) should equal(15.0)
  }
  
  it should "properly handle an empty result." in {
    userService.getAverageAgeOfUsersInClass(emptyClassName) should equal(0.0)
  }
} 
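The cake-pattern test above stubs the DAO provided by TestEnvironment with Mockito's when/thenReturn and then exercises the real UserService against those canned results. A self-contained sketch of that stubbing style, with a hypothetical Repository trait standing in for the project's DAO component:

import org.junit.runner.RunWith
import org.mockito.Mockito._
import org.scalatest.junit.JUnitRunner
import org.scalatest.mockito.MockitoSugar
import org.scalatest.{FlatSpec, Matchers}

// Hypothetical collaborator; not part of the Scala-Design-Patterns code.
trait Repository {
  def namesInClass(className: String): List[String]
}

@RunWith(classOf[JUnitRunner])
class RepositorySketchTest extends FlatSpec with Matchers with MockitoSugar {

  val repo = mock[Repository]
  when(repo.namesInClass("A")).thenReturn(List("Ivan", "Maria", "John"))
  when(repo.namesInClass("B")).thenReturn(List())

  "namesInClass" should "return the stubbed people for a known class" in {
    repo.namesInClass("A") should have size (3)
  }

  it should "return an empty list for an unknown class" in {
    repo.namesInClass("B") shouldBe empty
  }
}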
Example 77
Source File: MetricsStatsReceiverTest.scala    From finagle-metrics   with MIT License 5 votes vote down vote up
package com.twitter.finagle.metrics

import com.twitter.finagle.metrics.MetricsStatsReceiver._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.FunSuite

@RunWith(classOf[JUnitRunner])
class MetricsStatsReceiverTest extends FunSuite {

  private[this] val receiver = new MetricsStatsReceiver()

  private[this] def readGauge(name: String): Option[Number] =
    Option(metrics.getGauges.get(name)) match {
      case Some(gauge) => Some(gauge.getValue.asInstanceOf[Float])
      case _ => None
    }

  private[this] def readCounter(name: String): Option[Number] =
    Option(metrics.getMeters.get(name)) match {
      case Some(counter) => Some(counter.getCount)
      case _ => None
    }

  private[this] def readStat(name: String): Option[Number] =
    Option(metrics.getHistograms.get(name)) match {
      case Some(stat) => Some(stat.getSnapshot.getValues.toSeq.sum)
      case _ => None
    }

  test("MetricsStatsReceiver should store and read gauge into the Codahale Metrics library") {
    val x = 1.5f
    receiver.addGauge("my_gauge")(x)

    assert(readGauge("my_gauge") === Some(x))
  }

  test("MetricsStatsReceiver should always assume the latest value of an already created gauge") {
    val gaugeName = "my_gauge2"
    val expectedValue = 8.8f

    receiver.addGauge(gaugeName)(2.2f)
    receiver.addGauge(gaugeName)(9.9f)
    receiver.addGauge(gaugeName)(expectedValue)

    assert(readGauge(gaugeName) === Some(expectedValue))
  }

  test("MetricsStatsReceiver should store and remove gauge into the Codahale Metrics Library") {
    val gaugeName = "temp-gauge"
    val expectedValue = 2.8f

    val tempGauge = receiver.addGauge(gaugeName)(expectedValue)
    assert(readGauge(gaugeName) === Some(expectedValue))

    tempGauge.remove()

    assert(readGauge(gaugeName) === None)
  }

  test("MetricsStatsReceiver should store and read stat into the Codahale Metrics library") {
    val x = 1
    val y = 3
    val z = 5

    val s = receiver.stat("my_stat")
    s.add(x)
    s.add(y)
    s.add(z)

    assert(readStat("my_stat") === Some(x + y + z))
  }

  test("MetricsStatsReceiver should store and read counter into the Codahale Metrics library") {
    val x = 2
    val y = 5
    val z = 8

    val c = receiver.counter("my_counter")
    c.incr(x)
    c.incr(y)
    c.incr(z)

    assert(readCounter("my_counter") === Some(x + y + z))
  }

} 
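The helpers above read the receiver's output straight out of the Codahale registry: gauges via getGauges, counters via getMeters, and stats via getHistograms. A tiny sketch of the Codahale side of that mapping, using a fresh MetricRegistry rather than the one MetricsStatsReceiver writes to (assuming only the standard com.codahale.metrics API):

import com.codahale.metrics.MetricRegistry

object CodahaleRegistrySketch extends App {
  val registry = new MetricRegistry

  // A counter backed by a meter, mirroring how readCounter looks it up above.
  val meter = registry.meter("my_counter")
  meter.mark(2)
  meter.mark(5)
  meter.mark(8)
  println(registry.getMeters.get("my_counter").getCount) // 15

  // A stat backed by a histogram, mirroring readStat.
  val histogram = registry.histogram("my_stat")
  Seq(1L, 3L, 5L).foreach(v => histogram.update(v))
  println(registry.getHistograms.get("my_stat").getSnapshot.getValues.sum) // 9
}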
Example 78
Source File: ReceiverWithoutOffsetIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverWithoutOffsetIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records on each batch without offset conditions") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      if (!rdd.isEmpty())
        assert(streamingEvents === totalRegisters.toLong)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(10000L)

    assert(totalEvents.value === totalRegisters.toLong * 10)
  }
} 
Example 79
Source File: ReceiverNotStopContextIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverNotStopContextIT extends TemporalDataSuite {

  test("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 80
Source File: ReceiverLimitedIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ReceiverLimitedIT extends TemporalDataSuite {

  test("DataSource Receiver should read the records limited on each batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)

    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt"), limitRecords = 1000),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    // Start up the receiver.
    distributedStream.start()

    // Fires each time the configured window has passed.
    distributedStream.foreachRDD(rdd => {
      totalEvents += rdd.count()
    })

    ssc.start() // Start the computation
    ssc.awaitTerminationOrTimeout(15000L) // Wait for the computation to terminate

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
Example 81
Source File: ReceiverBasicIT.scala    From datasource-receiver   with Apache License 2.0 5 votes vote down vote up
package org.apache.spark.streaming.datasource

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext
import org.apache.spark.streaming.datasource.models.{InputSentences, OffsetConditions, OffsetField, StopConditions}
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class ReceiverBasicIT extends TemporalDataSuite {

  test ("DataSource Receiver should read all the records in one streaming batch") {
    sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    val rdd = sc.parallelize(registers)
    sqlContext.createDataFrame(rdd, schema).registerTempTable(tableName)
    ssc = new StreamingContext(sc, Seconds(1))
    val totalEvents = ssc.sparkContext.accumulator(0L, "Number of events received")
    val inputSentences = InputSentences(
      s"select * from $tableName",
      OffsetConditions(OffsetField("idInt")),
      StopConditions(stopWhenEmpty = true, finishContextWhenEmpty = true),
      initialStatements = Seq.empty[String]
    )
    val distributedStream = DatasourceUtils.createStream(ssc, inputSentences, datasourceParams)

    distributedStream.start()
    distributedStream.foreachRDD(rdd => {
      val streamingEvents = rdd.count()
      log.info(s" EVENTS COUNT : \t $streamingEvents")
      totalEvents += streamingEvents
      log.info(s" TOTAL EVENTS : \t $totalEvents")
      val streamingRegisters = rdd.collect()
      if (!rdd.isEmpty())
        assert(streamingRegisters === registers.reverse)
    })
    ssc.start()
    ssc.awaitTerminationOrTimeout(15000L)

    assert(totalEvents.value === totalRegisters.toLong)
  }
} 
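Examples 78-81 all follow the same counting pattern: register the test data as a temp table, open the datasource stream, bump an accumulator inside foreachRDD, and assert on the total once the streaming context times out. A stripped-down sketch of that pattern using a plain queue stream in place of the datasource receiver (the master URL, app name, and record count are illustrative only):

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.{SparkConf, SparkContext}

import scala.collection.mutable

object CountingPatternSketch extends App {
  val sc  = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("counting-sketch"))
  val ssc = new StreamingContext(sc, Seconds(1))
  val totalEvents = sc.accumulator(0L, "Number of events received")

  // One queued RDD plays the role of the receiver's batches.
  val queue = mutable.Queue(sc.parallelize(1 to 100))
  ssc.queueStream(queue).foreachRDD(rdd => totalEvents += rdd.count())

  ssc.start()
  ssc.awaitTerminationOrTimeout(5000L)
  assert(totalEvents.value == 100L)
  ssc.stop()
}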
Example 82
Source File: GeohashTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc.examples

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.SpaceFillingCurve._
import org.eichelberger.sfc.study.composition.CompositionSampleData._
import org.eichelberger.sfc.utils.Timing
import org.eichelberger.sfc.{DefaultDimensions, Dimension}
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class GeohashTest extends Specification with LazyLogging {
  val xCville = -78.488407
  val yCville = 38.038668

  "Geohash example" should {
    val geohash = new Geohash(35)

    "encode/decode round-trip for an interior point" >> {
      // encode
      val hash = geohash.pointToHash(Seq(xCville, yCville))
      hash must equalTo("dqb0muw")

      // decode
      val cell = geohash.hashToCell(hash)
      println(s"[Geohash example, Charlottesville] POINT($xCville $yCville) -> $hash -> $cell")
      cell(0).containsAny(xCville) must beTrue
      cell(1).containsAny(yCville) must beTrue
    }

    "encode/decode properly at the four corners and the center" >> {
      for (x <- Seq(-180.0, 0.0, 180.0); y <- Seq(-90.0, 0.0, 90.0)) {
        // encode
        val hash = geohash.pointToHash(Seq(x, y))

        // decode
        val cell = geohash.hashToCell(hash)
        println(s"[Geohash example, extrema] POINT($x $y) -> $hash -> $cell")
        cell(0).containsAny(x) must beTrue
        cell(1).containsAny(y) must beTrue
      }

      // degenerate test outcome
      1 must equalTo(1)
    }

    def getCvilleRanges(curve: Geohash): (OrdinalPair, OrdinalPair, Iterator[OrdinalPair]) = {
      val lonIdxRange = OrdinalPair(
        curve.children(0).asInstanceOf[Dimension[Double]].index(bboxCville._1),
        curve.children(1).asInstanceOf[Dimension[Double]].index(bboxCville._3)
      )
      val latIdxRange = OrdinalPair(
        curve.children(0).asInstanceOf[Dimension[Double]].index(bboxCville._2),
        curve.children(1).asInstanceOf[Dimension[Double]].index(bboxCville._4)
      )
      val query = Query(Seq(OrdinalRanges(lonIdxRange), OrdinalRanges(latIdxRange)))
      val cellQuery = Cell(Seq(
        DefaultDimensions.createDimension("x", bboxCville._1, bboxCville._3, 0),
        DefaultDimensions.createDimension("y", bboxCville._2, bboxCville._4, 0)
      ))
      (lonIdxRange, latIdxRange, curve.getRangesCoveringCell(cellQuery))
    }
    
    "generate valid selection indexes" >> {
      val (_, _, ranges) = getCvilleRanges(geohash)

      ranges.size must equalTo(90)
    }
    
    "report range efficiency" >> {
      def atPrecision(xBits: OrdinalNumber, yBits: OrdinalNumber): (Long, Long) = {
        val curve = new Geohash(xBits + yBits)
        val (lonRange, latRange, ranges) = getCvilleRanges(curve)
        (lonRange.size * latRange.size, ranges.size.toLong)
      }

      for (dimPrec <- 10 to 25) {
        val ((numCells, numRanges), ms) = Timing.time{ () => atPrecision(dimPrec, dimPrec - 1) }
        println(s"[ranges across scales, Charlottesville] precision ($dimPrec, ${dimPrec - 1}) -> $numCells / $numRanges = ${numCells / numRanges} in $ms milliseconds")
      }

      1 must equalTo(1)
    }
  }

} 
Example 83
Source File: LexicographicTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc.utils

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.SpaceFillingCurve.{OrdinalVector, ords2ordvec}
import org.eichelberger.sfc.{DefaultDimensions, ZCurve}
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LexicographicTest extends Specification with LazyLogging {
  sequential
  
  "Lexicographical encoding" should {
    val precisions = new ords2ordvec(Seq(18L, 17L)).toOrdinalVector

    val sfc = ZCurve(precisions)

    val Longitude = DefaultDimensions.createLongitude(18L)
    val Latitude = DefaultDimensions.createLatitude(17L)

    "work for a known point" >> {
      val x = -78.488407
      val y = 38.038668

      val point = OrdinalVector(Longitude.index(x), Latitude.index(y))
      val idx = sfc.index(point)
      val gh = sfc.lexEncodeIndex(idx)

      gh must equalTo("dqb0muw")
    }

    "be consistent round-trip" >> {
      val xs = (-180.0 to 180.0 by 33.3333).toSeq ++ Seq(180.0)
      val ys = (-90.0 to 90.0 by 33.3333).toSeq ++ Seq(90.0)
      for (x <- xs; y <- ys) {
        val ix = Longitude.index(x)
        val iy = Latitude.index(y)
        val point = OrdinalVector(ix, iy)
        val idx = sfc.index(point)
        val gh = sfc.lexEncodeIndex(idx)
        val idx2 = sfc.lexDecodeIndex(gh)
        idx2 must equalTo(idx)
        val point2 = sfc.inverseIndex(idx2)
        point2(0) must equalTo(ix)
        point2(1) must equalTo(iy)
        val rx = Longitude.inverseIndex(ix)
        val ry = Latitude.inverseIndex(iy)

        val sx = x.formatted("%8.3f")
        val sy = y.formatted("%8.3f")
        val sidx = idx.formatted("%20d")
        println(s"[LEXI ROUND-TRIP] POINT($sx $sy) -> $sidx = $gh -> ($rx, $ry)")
      }

      // degenerate
      1 must equalTo(1)
    }
  }

  "multiple lexicographical encoders" should {
    "return different results for different base resolutions" >> {
      val x = -78.488407
      val y = 38.038668

      for (xBits <- 1 to 30; yBits <- xBits - 1 to xBits if yBits > 0) {
        val precisions = new ords2ordvec(Seq(xBits, yBits)).toOrdinalVector
        val sfc = ZCurve(precisions)

        val Longitude = DefaultDimensions.createLongitude(xBits)
        val Latitude = DefaultDimensions.createLatitude(yBits)

        val idx = sfc.index(OrdinalVector(Longitude.index(x), Latitude.index(y)))
        val gh = sfc.lexEncodeIndex(idx)
        val idx2 = sfc.lexDecodeIndex(gh)

        idx2 must equalTo(idx)

        println(s"[LEXI ACROSS RESOLUTIONS] mx $xBits + my $yBits = base ${sfc.alphabet.size}, idx $idx -> gh $gh -> $idx2")
      }

      // degenerate
      1 must equalTo(1)
    }
  }
} 
Example 84
Source File: BitManipulationsTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc.utils

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

import BitManipulations._

@RunWith(classOf[JUnitRunner])
class BitManipulationsTest extends Specification with LazyLogging {
  "static methods" should {
    "usedMask" >> {
      // single bits
      for (pos <- 0 to 62) {
        val v = 1L << pos.toLong
        val actual = usedMask(v)
        val expected = (1L << (pos + 1L)) - 1L
        println(s"[usedMask single bit]  pos $pos, value $v, actual $actual, expected $expected")
        actual must equalTo(expected)
      }

      // full bit masks
      for (pos <- 0 to 62) {
        val expected = (1L << (pos.toLong + 1L)) - 1L
        val actual = usedMask(expected)
        println(s"[usedMask full bit masks]  pos $pos, value $expected, actual $actual, expected $expected")
        actual must equalTo(expected)
      }

      usedMask(0) must equalTo(0)
    }

    "sharedBitPrefix" >> {
      sharedBitPrefix(2, 3) must equalTo(2)
      sharedBitPrefix(178, 161) must equalTo(160)
    }

    "common block extrema" >> {
      commonBlockMin(178, 161) must equalTo(160)
      commonBlockMax(178, 161) must equalTo(191)
    }
  }
} 
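The shared-prefix expectations above fall straight out of the binary forms of the operands: 178 = 0b10110010 and 161 = 0b10100001 agree on their top three bits 0b101, so the shared prefix padded with zeros is 160 and the block covered by that prefix ends at 191. A quick hand-check of those constants (my arithmetic, independent of the project's BitManipulations):

object SharedPrefixSketch extends App {
  println(178L.toBinaryString)                      // 10110010
  println(161L.toBinaryString)                      // 10100001
  println(java.lang.Long.parseLong("10100000", 2))  // 160: common prefix, low bits zeroed
  println(java.lang.Long.parseLong("10111111", 2))  // 191: common prefix, low bits set
}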
Example 85
Source File: CompositionParserTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc.utils

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.SpaceFillingCurve.SpaceFillingCurve
import org.eichelberger.sfc._
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CompositionParserTest extends Specification {
  sequential

  def parsableCurve(curve: SpaceFillingCurve): String = curve match {
    case c: ComposedCurve =>
      c.delegate.name.charAt(0).toString + c.children.map {
        case d: Dimension[_]      => d.precision
        case s: SubDimension[_]   => s.precision
        case c: SpaceFillingCurve => parsableCurve(c)
      }.mkString("(", ", ", ")")
    case s =>
      s.name.charAt(0).toString + s.precisions.toSeq.map(_.toString).mkString("(", ", ", ")")
  }

  def eval(curve: ComposedCurve): Boolean = {
    val toParse: String = parsableCurve(curve)
    val parsed: ComposedCurve = CompositionParser.buildWholeNumberCurve(toParse)
    val fromParse: String = parsableCurve(parsed)
    println(s"[CURVE PARSER]\n  Input:  $toParse\n  Output:  $fromParse")
    toParse == fromParse
  }

  "simple expressions" should {
    val R23 = new ComposedCurve(
      RowMajorCurve(2, 3),
      Seq(
        DefaultDimensions.createIdentityDimension(2),
        DefaultDimensions.createIdentityDimension(3)
      )
    )
    val H_2_R23 = new ComposedCurve(
      CompactHilbertCurve(2),
      Seq(
        DefaultDimensions.createIdentityDimension(2),
        R23
      )
    )
    val Z_R23_2 = new ComposedCurve(
      ZCurve(2),
      Seq(
        R23,
        DefaultDimensions.createIdentityDimension(2)
      )
    )

    "parse correctly" >> {
      eval(R23) must beTrue
      eval(H_2_R23) must beTrue
      eval(Z_R23_2) must beTrue
    }
  }
} 
Example 86
Source File: LocalityEstimatorTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc.utils

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.{CompactHilbertCurve, RowMajorCurve, ZCurve}
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class LocalityEstimatorTest extends Specification with LazyLogging {
  sequential

  "locality" should {
    "evaluate on square 2D curves" >> {
      (1 to 6).foreach { p =>
        val locR = LocalityEstimator(RowMajorCurve(p, p)).locality
        println(s"[LOCALITY R($p, $p)] $locR")

        val locZ = LocalityEstimator(ZCurve(p, p)).locality
        println(s"[LOCALITY Z($p, $p)] $locZ")

        val locH = LocalityEstimator(CompactHilbertCurve(p, p)).locality
        println(s"[LOCALITY H($p, $p)] $locH")
      }

      1 must beEqualTo(1)
    }

    "evaluate on non-square 2D curves" >> {
      (1 to 6).foreach { p =>
        val locR = LocalityEstimator(RowMajorCurve(p << 1L, p)).locality
        println(s"[LOCALITY R(${p*2}, $p)] $locR")

        val locZ = LocalityEstimator(ZCurve(p << 1L, p)).locality
        println(s"[LOCALITY Z(${p*2}, $p)] $locZ")

        val locH = LocalityEstimator(CompactHilbertCurve(p << 1L, p)).locality
        println(s"[LOCALITY H(${p*2}, $p)] $locH")
      }

      1 must beEqualTo(1)
    }
  }
} 
Example 87
Source File: RowMajorCurveTest.scala    From sfseize   with Apache License 2.0 5 votes vote down vote up
package org.eichelberger.sfc

import com.typesafe.scalalogging.slf4j.LazyLogging
import org.eichelberger.sfc.CompactHilbertCurve.Mask
import org.eichelberger.sfc.SpaceFillingCurve.{OrdinalVector, SpaceFillingCurve, _}
import org.junit.runner.RunWith
import org.specs2.mutable.Specification
import org.specs2.runner.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RowMajorCurveTest extends Specification with GenericCurveValidation with LazyLogging {
  sequential

  def curveName = "RowmajorCurve"

  def createCurve(precisions: OrdinalNumber*): SpaceFillingCurve =
    RowMajorCurve(precisions.toOrdinalVector)

  "rowmajor space-filling curves" should {
    "satisfy the ordering constraints" >> {
      timeTestOrderings() must beTrue
    }

    "identify sub-ranges correctly" >> {
      val sfc = createCurve(3, 3)
      val query = Query(Seq(OrdinalRanges(OrdinalPair(1, 2)), OrdinalRanges(OrdinalPair(1, 3))))
      val ranges = sfc.getRangesCoveringQuery(query).toList

      for (i <- 0 until ranges.size) {
        println(s"[rowmajor ranges:  query $query] range $i = ${ranges(i)}")
      }

      ranges(0) must equalTo(OrdinalPair(9, 11))
      ranges(1) must equalTo(OrdinalPair(17, 19))
    }
  }
} 
Example 88
Source File: DefaultSaverITCase.scala    From flink-tensorflow   with Apache License 2.0 5 votes vote down vote up
package org.apache.flink.contrib.tensorflow.io

import org.apache.flink.contrib.tensorflow.models.savedmodel.DefaultSavedModelLoader
import org.apache.flink.contrib.tensorflow.util.{FlinkTestBase, RegistrationUtils}
import org.apache.flink.core.fs.Path
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.{Session, Tensor}

import scala.collection.JavaConverters._

@RunWith(classOf[JUnitRunner])
class DefaultSaverITCase extends WordSpecLike
  with Matchers
  with FlinkTestBase {

  override val parallelism = 1

  "A DefaultSaver" should {
    "run the save op" in {
      val env = StreamExecutionEnvironment.getExecutionEnvironment
      RegistrationUtils.registerTypes(env.getConfig)

      val loader = new DefaultSavedModelLoader(new Path("../models/half_plus_two"), "serve")
      val bundle = loader.load()
      val saverDef = loader.metagraph.getSaverDef
      val saver = new DefaultSaver(saverDef)

      def getA = getVariable(bundle.session(), "a").floatValue()
      def setA(value: Float) = setVariable(bundle.session(), "a", Tensor.create(value))

      val initialA = getA
      println("Initial value: " + initialA)

      setA(1.0f)
      val savePath = tempFolder.newFolder("model-0").getAbsolutePath
      val path = saver.save(bundle.session(), savePath)
      val savedA = getA
      savedA shouldBe (1.0f)
      println("Saved value: " + getA)

      setA(2.0f)
      val updatedA = getA
      updatedA shouldBe (2.0f)
      println("Updated value: " + updatedA)

      saver.restore(bundle.session(), path)
      val restoredA = getA
      restoredA shouldBe (savedA)
      println("Restored value: " + restoredA)
    }

    def getVariable(sess: Session, name: String): Tensor = {
      val result = sess.runner().fetch(name).run().asScala
      result.head
    }

    def setVariable(sess: Session, name: String, value: Tensor): Unit = {
      sess.runner()
        .addTarget(s"$name/Assign")
        .feed(s"$name/initial_value", value)
        .run()
    }
  }
} 
Example 89
Source File: ArraysTest.scala    From flink-tensorflow   with Apache License 2.0 5 votes vote down vote up
package org.tensorflow.contrib.scala

import com.twitter.bijection.Conversion._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Matchers, WordSpecLike}
import org.tensorflow.contrib.scala.Arrays._
import org.tensorflow.contrib.scala.Rank._
import resource._

@RunWith(classOf[JUnitRunner])
class ArraysTest extends WordSpecLike
  with Matchers {

  "Arrays" when {
    "Array[Float]" should {
      "convert to Tensor[`1D`,Float]" in {
        val expected = Array(1f,2f,3f)
        managed(expected.as[TypedTensor[`1D`,Float]]).foreach { t =>
          t.shape shouldEqual Array(expected.length)
          val actual = t.as[Array[Float]]
          actual shouldEqual expected
        }
      }
    }
  }
} 
Example 90
Source File: TestRenaming.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import at.forsyte.apalache.tla.lir.transformations.impl.TrackerWithListeners
import at.forsyte.apalache.tla.lir.transformations.standard.Renaming
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}


@RunWith(classOf[JUnitRunner])
class TestRenaming extends FunSuite with BeforeAndAfterEach with TestingPredefs {
  import at.forsyte.apalache.tla.lir.Builder._

  private var renaming = new Renaming(TrackerWithListeners())

  override protected def beforeEach(): Unit = {
    renaming = new Renaming(TrackerWithListeners())
  }

  test("test renaming exists/forall") {
    val original =
        and(
          exists(n_x, n_S, gt(n_x, int(1))),
          forall(n_x, n_T, lt(n_x, int(42))))
    ///
    val expected =
      and(
        exists(name("x_1"), n_S, gt(name("x_1"), int(1))),
        forall(name("x_2"), n_T, lt(name("x_2"), int(42))))
    val renamed = renaming.renameBindingsUnique(original)
    assert(expected == renamed)
  }

  test("test renaming filter") {
    val original =
        cup(
          filter(name("x"), name("S"), eql(name("x"), int(1))),
          filter(name("x"), name("S"), eql(name("x"), int(2)))
        )
    val expected =
      cup(
        filter(name("x_1"), name("S"), eql(name("x_1"), int(1))),
        filter(name("x_2"), name("S"), eql(name("x_2"), int(2))))
    val renamed = renaming.renameBindingsUnique(original)
    assert(expected == renamed)
  }

  test( "Test renaming LET-IN" ) {
    // LET p(t) == \A x \in S . R(t,x) IN \E x \in S . p(x)
    val original =
      letIn(
        exists( n_x, n_S, appOp( name( "p" ), n_x ) ),
        declOp( "p", forall( n_x, n_S, appOp( name( "R" ), name( "t" ), n_x ) ), "t" )
      )

    val expected =
      letIn(
        exists( name( "x_2" ), n_S, appOp( name( "p_1" ), name( "x_2" ) ) ),
        declOp( "p_1", forall( name( "x_1" ), n_S, appOp( name( "R" ), name( "t_1" ), name( "x_1" ) ) ), "t_1" )
      )

    val actual = renaming( original )

    assert(expected == actual)
  }

  test( "Test renaming multiple LET-IN" ) {
    // LET X == TRUE IN X /\ LET X == FALSE IN X
    val original =
      and(
        letIn(
          appOp( name( "X" ) ),
          declOp( "X", trueEx )
        ),
        letIn(
          appOp( name( "X" ) ),
          declOp( "X", falseEx )
        )
      )

    val expected =
      and(
      letIn(
        appOp( name( "X_1" ) ),
        declOp( "X_1", trueEx )
      ),
      letIn(
        appOp( name( "X_2" ) ),
        declOp( "X_2", falseEx )
      )
    )

    val actual = renaming( original )

    assert(expected == actual)
  }

} 
Example 91
Source File: TestLirValues.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import at.forsyte.apalache.tla.lir.values._
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TestLirValues extends FunSuite {
  test("create booleans") {
    val b = TlaBool(false)
    assert(!b.value)
  }

  test("create int") {
    val i = TlaInt(1)
    assert(i.value == BigInt(1))
    assert(i == TlaInt(1))
    assert(i.isNatural)
    assert(TlaInt(0).isNatural)
    assert(!TlaInt(-1).isNatural)
  }

  test("create a string") {
    val s = TlaStr("hello")
    assert(s.value == "hello")
  }


  test("create a constant") {
    val c = new TlaConstDecl("x")
    assert("x" == c.name)
  }

  test("create a variable") {
    val c = new TlaVarDecl("x")
    assert("x" == c.name)
  }
} 
Example 92
Source File: TestAux.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.lir

import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith( classOf[JUnitRunner] )
class TestAux extends FunSuite with TestingPredefs {

  test( "Test aux::collectSegments" ){

    val ar0Decl1 = TlaOperDecl( "X", List.empty, n_x )
    val ar0Decl2 = TlaOperDecl( "Y", List.empty, n_y )
    val ar0Decl3 = TlaOperDecl( "Z", List.empty, n_z )

    val arGe0Decl1 = TlaOperDecl( "A", List( SimpleFormalParam( "t" ) ), n_a )
    val arGe0Decl2 = TlaOperDecl( "B", List( SimpleFormalParam( "t" ) ), n_b )
    val arGe0Decl3 = TlaOperDecl( "C", List( SimpleFormalParam( "t" ) ), n_c )

    val pa1 =
      List( ar0Decl1 ) ->
        List( List( ar0Decl1 ) )
    val pa2 =
      List( ar0Decl1, ar0Decl2 ) ->
        List( List( ar0Decl1, ar0Decl2 ) )
    val pa3 =
      List( arGe0Decl1, ar0Decl1 ) ->
        List( List( arGe0Decl1 ), List( ar0Decl1 ) )
    val pa4 =
      List( arGe0Decl1, arGe0Decl2 ) ->
        List( List( arGe0Decl1, arGe0Decl2 ) )
    val pa5 =
      List( arGe0Decl1, arGe0Decl2, ar0Decl1, ar0Decl2, arGe0Decl3 ) ->
        List( List( arGe0Decl1, arGe0Decl2 ), List( ar0Decl1, ar0Decl2 ), List( arGe0Decl3 ) )

    val expected = Seq(
      pa1, pa2, pa3, pa4, pa5
    )
    val cmp = expected map { case (k, v) =>
      (v, aux.collectSegments( k ))
    }
    cmp foreach { case (ex, act) =>
      assert( ex == act )
    }
  }
} 
Example 93
Source File: TestTypeReduction.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.types

import at.forsyte.apalache.tla.lir.TestingPredefs
import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.junit.JUnitRunner

@RunWith( classOf[JUnitRunner] )
class TestTypeReduction extends FunSuite with TestingPredefs with BeforeAndAfter {

  var gen = new SmtVarGenerator
  var tr  = new TypeReduction( gen )

  before {
    gen = new SmtVarGenerator
    tr = new TypeReduction( gen )
  }

  test( "Test nesting" ) {
    val tau = FunT( IntT, SetT( IntT ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr( tau, m )
    assert( rr.t == fun( int, set( int ) ) )
  }

  test("Test tuples"){
    val tau = SetT( FunT( TupT( IntT, StrT ), SetT( IntT ) ) )
    val m = Map.empty[TypeVar, SmtTypeVariable]
    val rr = tr(tau, m)
    val idx = SmtIntVariable( 0 )
    assert( rr.t == set( fun( tup( idx ), set( int ) ) ) )
    assert( rr.phi.contains( hasIndex( idx, 0, int ) ) )
    assert( rr.phi.contains( hasIndex( idx, 1, str ) ) )
  }
} 
Example 94
Source File: TestSymbStateRewriterStr.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.values.TlaStr
import at.forsyte.apalache.tla.lir.{NameEx, ValEx}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterStr extends RewriterBase {
  test("SE-STR-CTOR: \"red\" -> $C$k") {
    val state = new SymbState(ValEx(TlaStr("red")),
      CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextStateRed = rewriter.rewriteUntilDone(state)
    nextStateRed.ex match {
      case predEx@NameEx(name) =>
        assert(CellTheory().hasConst(name))
        assert(CellTheory() == state.theory)
        assert(solverContext.sat())
        val redEqBlue = tla.eql(tla.str("blue"), tla.str("red"))
        val nextStateEq = rewriter.rewriteUntilDone(nextStateRed.setRex(redEqBlue))
        rewriter.push()
        solverContext.assertGroundExpr(nextStateEq.ex)
        assert(!solverContext.sat())
        rewriter.pop()
        solverContext.assertGroundExpr(tla.not(nextStateEq.ex))
        assert(solverContext.sat())


      case _ =>
        fail("Unexpected rewriting result")
    }
  }
} 
Example 95
Source File: TestVCGenerator.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.imp.SanyImporter
import at.forsyte.apalache.tla.imp.src.SourceStore
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import at.forsyte.apalache.tla.lir.{TlaModule, TlaOperDecl}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

import scala.io.Source

@RunWith(classOf[JUnitRunner])
class TestVCGenerator extends FunSuite {
  private def mkVCGen(): VCGenerator = {
    new VCGenerator(new IdleTracker)
  }

  test("simple invariant") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x
        |Inv == x > 0
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", "x > 0")
    assertDecl(newMod, "VCNotInv$0", "¬(x > 0)")
  }

  test("conjunctive invariant") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x
        |Inv == x > 0 /\ x < 10
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", "x > 0")
    assertDecl(newMod, "VCInv$1", "x < 10")
    assertDecl(newMod, "VCNotInv$0", "¬(x > 0)")
    assertDecl(newMod, "VCNotInv$1", "¬(x < 10)")
  }

  test("conjunction under universals") {
    val text =
      """---- MODULE inv ----
        |EXTENDS Integers
        |VARIABLE x, S
        |Inv == \A z \in S: \A y \in S: y > 0 /\ y < 10
        |====================
      """.stripMargin

    val mod = loadFromText("inv", text)
    val newMod = mkVCGen().gen(mod, "Inv")
    assertDecl(newMod, "VCInv$0", """∀z ∈ S: (∀y ∈ S: (y > 0))""")
    assertDecl(newMod, "VCInv$1", """∀z ∈ S: (∀y ∈ S: (y < 10))""")
    assertDecl(newMod, "VCNotInv$0", """¬(∀z ∈ S: (∀y ∈ S: (y > 0)))""")
    assertDecl(newMod, "VCNotInv$1", """¬(∀z ∈ S: (∀y ∈ S: (y < 10)))""")
  }

  private def assertDecl(mod: TlaModule, name: String, expectedBodyText: String): Unit = {
    val vc = mod.declarations.find(_.name == name)
    assert(vc.nonEmpty, s"(VC $name not found)")
    assert(vc.get.isInstanceOf[TlaOperDecl])
    assert(vc.get.asInstanceOf[TlaOperDecl].body.toString == expectedBodyText)
  }

  private def loadFromText(moduleName: String, text: String): TlaModule = {
    val locationStore = new SourceStore
    val (rootName, modules) = new SanyImporter(locationStore)
      .loadFromSource(moduleName, Source.fromString(text))
    modules(moduleName)
  }
} 
Example 96
Source File: TestTypeInference.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{Signatures, TypeInference}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience._
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

// TODO: remove?
@RunWith( classOf[JUnitRunner] )
class TestTypeInference extends FunSuite with TestingPredefs {

  ignore( "Signatures" ) {
    val exs = List(
      tla.and( n_x, n_y ),
      tla.choose( n_x, n_S, n_p ),
      tla.enumSet( seq( 10 ) : _* ),
      tla.in( n_x, n_S ),
      tla.map( n_e, n_x, n_S )
    )

    val sigs = exs map Signatures.get

    exs zip sigs foreach { case (x, y) => println( s"${x}  ...  ${y}" ) }

    val funDef = tla.funDef( tla.plus( n_x, n_y ), n_x, n_S, n_y, n_T )

    val sig = Signatures.get( funDef )

    printsep()
    println( sig )
    printsep()
  }

  ignore( "TypeInference" ) {
    val ex = tla.and( tla.primeEq( n_a, tla.choose( n_x, n_S, n_p ) ), tla.in( 2, n_S ) )

    val r = TypeInference.theta( ex )

    println( r )

  }

  ignore( "Application" ) {

    val ex = tla.eql( tla.plus(  tla.appFun( n_f, n_x ) , 2), 4 )
    val ex2 =
      tla.and(
        tla.in( n_x, n_S ),
        tla.le(
          tla.plus(
            tla.mult( 2, n_x ),
            5
          ),
          10
        ),
        tla.primeEq( n_x,
          tla.appFun(
            n_f,
            n_x
          )
        )
      )

    val r = TypeInference( ex )
  }
} 
Example 97
Source File: TestSymbStateRewriterChoose.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{AnnotationParser, FinSetT, IntT}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterChoose extends RewriterBase with TestingPredefs {
  test("""CHOOSE x \in {1, 2, 3}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.enumSet(tla.int(1), tla.int(2), tla.int(3)),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    def assertEq(i: Int): SymbState = {
      val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i))))
      solverContext.assertGroundExpr(ns.ex)
      ns
    }

    rewriter.push()
    assertEq(3)
    assert(solverContext.sat())
    rewriter.pop()
    rewriter.push()
    assertEq(2)
    assert(solverContext.sat())
    rewriter.pop()
    rewriter.push()
    val ns = assertEq(1)
    assertUnsatOrExplain(rewriter, ns)
  }

  test("""CHOOSE x \in {1}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.enumSet(tla.int(1)),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    // the buggy implementation of choose fails on a dynamically empty set
    assert(solverContext.sat())
    // The semantics of choose does not restrict the outcome on the empty sets,
    // so we do not test for anything here. Our previous implementation of CHOOSE produced default values in this case,
    // but this happened to be error-prone and sometimes conflicting with other rules. So, no default values.
  }

  test("""CHOOSE x \in {}: x > 1""") {
    val ex = tla.choose(tla.name("x"),
      tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))),
      tla.gt(tla.name("x"), tla.int(1)))
    val state = new SymbState(ex, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    // the buggy implementation of choose fails on a dynamically empty set
    assert(solverContext.sat())
    def assertEq(i: Int): SymbState = {
      val ns = rewriter.rewriteUntilDone(nextState.setRex(tla.eql(nextState.ex, tla.int(i))))
      solverContext.assertGroundExpr(ns.ex)
      ns
    }

    // Actually, semantics of choose does not restrict the outcome on the empty sets.
    // But we know that our implementation would always return 0 in this case.
    val ns = assertEq(1)
    assertUnsatOrExplain(rewriter, ns)
  }
} 
Example 98
Source File: TestSymbStateRewriterFiniteSets.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types._
import at.forsyte.apalache.tla.lir.{NameEx, TlaEx}
import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.oper.TlaFunOper
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterFiniteSets extends RewriterBase {
  test("""Cardinality({1, 2, 3}) = 3""") {
    val set = tla.enumSet(1.to(3).map(tla.int) :_*)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex)))
  }

  test("""Cardinality({1, 2, 2, 2, 3, 3}) = 3""") {
    val set = tla.enumSet(Seq(1, 2, 2, 2, 3, 3).map(tla.int) :_*)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(3), nextState.ex)))
  }

  test("""Cardinality({1, 2, 3} \ {2}) = 2""") {
    def setminus(set: TlaEx, intVal: Int): TlaEx = {
      tla.filter(tla.name("t"),
        set,
        tla.not(tla.eql(tla.name("t"), tla.int(intVal))))
    }

    val set = setminus(tla.enumSet(1.to(3).map(tla.int) :_*), 2)
    val card = tla.card(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.int(2), nextState.ex)))
  }

  test("""IsFiniteSet({1, 2, 3}) = TRUE""") {
    val set = tla.enumSet(1.to(3).map(tla.int) :_*)
    val card = tla.isFin(set)
    val state = new SymbState(card, CellTheory(), arena, new Binding)
    val rewriter = create()
    val nextState = rewriter.rewriteUntilDone(state)
    assert(solverContext.sat())
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(tla.bool(true), nextState.ex)))
  }

} 
Example 99
Source File: TestUninterpretedConstOracle.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt.rules.aux

import at.forsyte.apalache.tla.bmcmt.types.BoolT
import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestUninterpretedConstOracle extends RewriterBase with TestingPredefs {
  test("""Oracle.create""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
  }

  test("""Oracle.whenEqualTo""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4))
    assert(!solverContext.sat())
  }

  test("""Oracle.evalPosition""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    val position = oracle.evalPosition(rewriter.solverContext, nextState)
    assert(3 == position)
  }

  test("""Oracle.caseAssertions""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    state = state.updateArena(_.appendCell(BoolT()))
    val flag = state.arena.topCell
    // introduce an oracle
    val (nextState, oracle) = UninterpretedConstOracle.create(rewriter, state, 2)
    // assert flag == true iff oracle = 0
    rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx))))
    // assert oracle = 1
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false))
    rewriter.pop()
    // assert oracle = 0
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true))
    rewriter.pop()
  }
} 
Example 100
Source File: TestPropositionalOracle.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt.rules.aux

import at.forsyte.apalache.tla.bmcmt.types.BoolT
import at.forsyte.apalache.tla.bmcmt.{Binding, CellTheory, RewriterBase, SymbState}
import at.forsyte.apalache.tla.lir.TestingPredefs
import at.forsyte.apalache.tla.lir.convenience.tla
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestPropositionalOracle extends RewriterBase with TestingPredefs {
  test("""Oracle.create""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
  }

  test("""Oracle.whenEqualTo""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 4))
    assert(!solverContext.sat())
  }

  test("""Oracle.evalPosition""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 6)
    assert(solverContext.sat())
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 3))
    assert(solverContext.sat())
    val position = oracle.evalPosition(rewriter.solverContext, nextState)
    assert(3 == position)
  }

  test("""Oracle.caseAssertions""") {
    val rewriter = create()
    var state = new SymbState(tla.bool(true), CellTheory(), arena, new Binding)
    state = state.updateArena(_.appendCell(BoolT()))
    val flag = state.arena.topCell
    // introduce an oracle
    val (nextState, oracle) = PropositionalOracle.create(rewriter, state, 2)
    // assert flag == true iff oracle = 0
    rewriter.solverContext.assertGroundExpr(oracle.caseAssertions(nextState, Seq(flag.toNameEx, tla.not(flag.toNameEx))))
    // assert oracle = 1
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 1))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(false))
    rewriter.pop()
    // assert oracle = 0
    rewriter.push()
    rewriter.solverContext.assertGroundExpr(oracle.whenEqualTo(nextState, 0))
    assert(solverContext.sat())
    assert(solverContext.evalGroundExpr(flag.toNameEx) == tla.bool(true))
    rewriter.pop()
  }
} 
Example 101
Source File: TestSymbStateRewriterAction.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.SymbStateRewriter.Continue
import at.forsyte.apalache.tla.bmcmt.types.IntT
import at.forsyte.apalache.tla.lir.NameEx
import at.forsyte.apalache.tla.lir.convenience._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterAction extends RewriterBase {
  test("""SE-PRIME: x' ~~> NameEx(x')""") {
    val rewriter = create()
    arena.appendCell(IntT()) // the type finder is strict about unassigned types, so let's create a cell for x'
    val state = new SymbState(tla.prime(NameEx("x")), CellTheory(), arena, Binding("x'" -> arena.topCell))
    rewriter.rewriteOnce(state) match {
      case Continue(next) =>
        assert(next.ex == NameEx("x'"))

      case _ =>
        fail("Expected x to be renamed to x'")
    }
  }
} 
Example 102
Source File: TestArena.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types.{BoolT, FinSetT, UnknownT}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestArena extends FunSuite {
  test("create cells") {
    val solverContext = new Z3SolverContext()
    val emptyArena = Arena.create(solverContext)
    val arena = emptyArena.appendCell(UnknownT())
    assert(emptyArena.cellCount + 1 == arena.cellCount)
    assert(UnknownT() == arena.topCell.cellType)
    val arena2 = arena.appendCell(BoolT())
    assert(emptyArena.cellCount + 2 == arena2.cellCount)
    assert(BoolT() == arena2.topCell.cellType)
  }

  test("add 'has' edges") {
    val solverContext = new Z3SolverContext()
    val arena = Arena.create(solverContext).appendCell(FinSetT(UnknownT()))
    val set = arena.topCell
    val arena2 = arena.appendCell(BoolT())
    val elem = arena2.topCell
    val arena3 = arena2.appendHas(set, elem)
    assert(List(elem) == arena3.getHas(set))
  }

  test("BOOLEAN has FALSE and TRUE") {
    val solverContext = new Z3SolverContext()
    val arena = Arena.create(solverContext)
    val boolean = arena.cellBooleanSet()
    assert(List(arena.cellFalse(), arena.cellTrue()) == arena.getHas(boolean))
  }
} 
Example 103
Source File: TestSymbStateRewriterExpand.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.bmcmt

import at.forsyte.apalache.tla.bmcmt.types._
import at.forsyte.apalache.tla.lir.{NameEx, OperEx, TlaEx}
import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.oper.BmcOper
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestSymbStateRewriterExpand extends RewriterBase {
  test("""Expand(SUBSET {1, 2})""") {
    val baseset = tla.enumSet(tla.int(1), tla.int(2))
    val expandPowset = OperEx(BmcOper.expand, tla.powSet(baseset))
    val state = new SymbState(expandPowset, CellTheory(), arena, new Binding)
    val rewriter = create()
    var nextState = rewriter.rewriteUntilDone(state)
    val powCell = nextState.asCell
    // check equality
    val eq = tla.eql(nextState.ex,
      tla.enumSet(tla.withType(tla.enumSet(), AnnotationParser.toTla(FinSetT(IntT()))),
        tla.enumSet(tla.int(1)),
        tla.enumSet(tla.int(2)),
          tla.enumSet(tla.int(1), tla.int(2))))
    assertTlaExAndRestore(rewriter, nextState.setRex(eq))
  }

  test("""Expand([{1, 2, 3} -> {FALSE, TRUE}]) should fail""") {
    val domain = tla.enumSet(tla.int(1), tla.int(2), tla.int(3))
    val codomain = tla.enumSet(tla.bool(false), tla.bool(true))
    val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain))
    val state = new SymbState(funSet, CellTheory(), arena, new Binding)
    val rewriter = create()
    assertThrows[RewriterException](rewriter.rewriteUntilDone(state))
  }

  // Constructing an explicit set of functions is, of course, expensive. But it should work for small values.
  // Left for the future...
  ignore("""Expand([{1, 2} -> {FALSE, TRUE}]) should work""") {
    val domain = tla.enumSet(tla.int(1), tla.int(2))
    val codomain = tla.enumSet(tla.bool(false), tla.bool(true))
    val funSet = OperEx(BmcOper.expand, tla.funSet(domain, codomain))
    val state = new SymbState(funSet, CellTheory(), arena, new Binding)
    val rewriter = create()
    var nextState = rewriter.rewriteUntilDone(state)
    val funSetCell = nextState.asCell
    def mkFun(v1: Boolean, v2: Boolean): TlaEx = {
      val mapEx = tla.ite(tla.eql(NameEx("x"), tla.int(1)), tla.bool(v1), tla.bool(v2))
      tla.funDef(mapEx, tla.name("x"), domain)
    }
    val expected = tla.enumSet(mkFun(false, false), mkFun(false, true),
                               mkFun(true, false), mkFun(true, true))
    assertTlaExAndRestore(rewriter, nextState.setRex(tla.eql(expected, funSetCell.toNameEx)))
  }

} 
Example 104
Source File: TestSourceStore.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.imp.src

import at.forsyte.apalache.tla.lir.convenience.tla
import at.forsyte.apalache.tla.lir.src.SourceRegion
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TestSourceStore extends FunSuite {
  test("basic add and find") {
    val store = new SourceStore()
    val ex = tla.int(1)
    val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(ex, loc)
    val foundLoc = store.find(ex.ID)
    assert(loc == foundLoc.get)
  }

  test("recursive add and find") {
    val store = new SourceStore()
    val int1 = tla.int(1)
    val set = tla.enumSet(int1)
    val loc = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(set, loc)
    val foundLoc = store.find(set.ID)
    assert(loc == foundLoc.get)
    val foundLoc2 = store.find(int1.ID)
    assert(loc == foundLoc2.get)
  }

  test("locations are not overwritten") {
    val store = new SourceStore()
    val int1 = tla.int(1)
    val set = tla.enumSet(int1)
    val set2 = tla.enumSet(set)
    val loc1 = SourceLocation("tada", SourceRegion(100, 200, 300, 400))
    store.addRec(int1, loc1)
    val loc2 = SourceLocation("root", SourceRegion(1, 2, 3, 4))
    store.addRec(set2, loc2)
    assert(loc2 == store.find(set2.ID).get)
    assert(loc2 == store.find(set.ID).get)
    assert(loc1 == store.find(int1.ID).get)
  }
} 
Example 105
Source File: TestRegionTree.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.imp.src

import at.forsyte.apalache.tla.lir.src.{RegionTree, SourcePosition, SourceRegion}
import org.junit.runner.RunWith
import org.scalatest.FunSuite
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestRegionTree extends FunSuite {
  test("add") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    tree.add(region)
  }

  test("add a subregion, then size") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    tree.add(reg1)
    assert(tree.size == 1)
    val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(2, 5))
    tree.add(reg2)
    assert(tree.size == 2)
    val reg3 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg3)
    assert(tree.size == 3)
  }

  test("add an overlapping subregion") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(1, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(1, 20), SourcePosition(5, 20))
    assertThrows[IllegalArgumentException] {
      tree.add(reg2)
    }
  }

  test("add a small region, then a larger region") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(1, 1), SourcePosition(4, 1))
    tree.add(reg2)
  }

  test("add a region twice") {
    val tree = new RegionTree()
    val reg1 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg1)
    val reg2 = SourceRegion(SourcePosition(2, 10), SourcePosition(3, 10))
    tree.add(reg2)
  }

  test("add and find") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    val idx = tree.add(region)
    val found = tree(idx)
    assert(found == region)
  }

  test("find non-existing index") {
    val tree = new RegionTree()
    val region = SourceRegion(SourcePosition(1, 20), SourcePosition(3, 10))
    val idx = tree.add(region)
    assertThrows[IndexOutOfBoundsException] {
      tree(999)
    }
  }
} 
Example 106
Source File: TestConstAndDefRewriter.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.pp

import at.forsyte.apalache.tla.imp.SanyImporter
import at.forsyte.apalache.tla.imp.src.SourceStore
import at.forsyte.apalache.tla.lir.{SimpleFormalParam, TlaOperDecl}
import at.forsyte.apalache.tla.lir.convenience._
import at.forsyte.apalache.tla.lir.transformations.impl.IdleTracker
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}

import scala.io.Source

@RunWith(classOf[JUnitRunner])
class TestConstAndDefRewriter extends FunSuite with BeforeAndAfterEach {
  test("override a constant") {
    val text =
      """---- MODULE const ----
        |CONSTANT n
        |OVERRIDE_n == 10
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    val rewritten = new ConstAndDefRewriter(new IdleTracker())(root)
    assert(rewritten.constDeclarations.isEmpty) // no constants anymore
    assert(rewritten.operDeclarations.size == 2)
    val expected_n = TlaOperDecl("n", List(), tla.int(10))
    assert(expected_n == rewritten.operDeclarations.head)
    val expected_A = TlaOperDecl("A", List(), tla.enumSet(tla.appOp(tla.name("n"))))
    assert(expected_A == rewritten.operDeclarations(1))
  }

  // In TLA+, constants may be operators with multiple arguments.
  // We do not support that yet.
  test("override a constant with a unary operator") {
    val text =
      """---- MODULE const ----
        |CONSTANT n
        |OVERRIDE_n(x) == x
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }

  test("overriding a variable with an operator => error") {
    val text =
      """---- MODULE const ----
        |VARIABLE n, m
        |OVERRIDE_n == m
        |A == {n}
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("const", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }

  test("override an operator") {
    val text =
      """---- MODULE op ----
        |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y
        |OVERRIDE_BoolMin(S) == CHOOSE x \in S: TRUE
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("op", Source.fromString(text))
    val root = modules(rootName)
    val rewritten = new ConstAndDefRewriter(new IdleTracker())(root)
    assert(rewritten.constDeclarations.isEmpty)
    assert(rewritten.operDeclarations.size == 1)
    val expected = TlaOperDecl("BoolMin", List(SimpleFormalParam("S")),
      tla.choose(tla.name("x"), tla.name("S"), tla.bool(true)))
    assert(expected == rewritten.operDeclarations.head)
  }

  test("override a unary operator with a binary operator") {
    val text =
      """---- MODULE op ----
        |BoolMin(S) == CHOOSE x \in S: \A y \in S: x => y
        |OVERRIDE_BoolMin(S, T) == CHOOSE x \in S: x \in T
        |================================
      """.stripMargin

    val (rootName, modules) = new SanyImporter(new SourceStore)
      .loadFromSource("op", Source.fromString(text))
    val root = modules(rootName)
    assertThrows[OverridingError](new ConstAndDefRewriter(new IdleTracker())(root))
  }
} 
Example 107
Source File: TestUniqueNameGenerator.scala    From apalache   with Apache License 2.0 5 votes vote down vote up
package at.forsyte.apalache.tla.pp

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterEach, FunSuite}

@RunWith(classOf[JUnitRunner])
class TestUniqueNameGenerator extends FunSuite with BeforeAndAfterEach {
  test("first three") {
    val gen = new UniqueNameGenerator
    assert("t_1" == gen.newName())
    assert("t_2" == gen.newName())
    assert("t_3" == gen.newName())
  }

  test("after 10000") {
    val gen = new UniqueNameGenerator
    for (i <- 1.to(10000)) {
      gen.newName()
    }
    assert("t_7pt" == gen.newName())
  }
} 
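The expected suffix after 10,000 additional calls is consistent with the generator simply encoding its running counter in base 36: the 10,001st name corresponds to 10001 in decimal, which is "7pt" in base 36. A minimal check of that arithmetic (the base-36 encoding and the "t_" prefix are assumptions inferred from the assertions above, not taken from the UniqueNameGenerator source):

object Base36NameCheck extends App {
  // 10001 (decimal) rendered in base 36 with lowercase digits
  val suffix = java.lang.Long.toString(10001L, 36)
  assert(suffix == "7pt")
  println(s"t_$suffix") // prints t_7pt, matching the test above
}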
Example 108
Source File: DumThroAwayTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.semantics.compiled.plugin.csv

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.concurrent.ExecutionContext.Implicits.global

@RunWith(classOf[BlockJUnit4ClassRunner])
class DumThroAwayTest {
    @Test def test1() {
        val compiler = TwitterEvalCompiler()
        val plugin = CompiledSemanticsCsvPlugin(Map("profile.user_id" -> CsvTypes.withNameExtended("oi")))
        val imports = Seq("com.eharmony.aloha.feature.BasicFunctions._", "scala.math._")
        val semantics = CompiledSemantics(compiler, plugin, imports)
        val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]())

        val model = factory.fromResource("fizzbuzz.json").get

        val lineProducer = CsvLines(Map("profile.user_id" -> 0))
        val examples = "" :: (-16 to 16 map { _.toString }).toList
        val lines = lineProducer(examples)

        val expected = Seq(
            (None, -1.0),
            (Some(-16), 16.0),
            (Some(-15), -6.0),
            (Some(-14), 14.0),
            (Some(-13), 13.0),
            (Some(-12), -2.0),
            (Some(-11), 11.0),
            (Some(-10), -4.0),
            (Some(-9), -2.0),
            (Some(-8), 8.0),
            (Some(-7), 7.0),
            (Some(-6), -2.0),
            (Some(-5), -4.0),
            (Some(-4), 4.0),
            (Some(-3), -2.0),
            (Some(-2), 2.0),
            (Some(-1), 1.0),
            (Some(0), -6.0),

            (Some(1), 1.0),
            (Some(2), 2.0),
            (Some(3), -2.0),
            (Some(4), 4.0),
            (Some(5), -4.0),
            (Some(6), -2.0),
            (Some(7), 7.0),
            (Some(8), 8.0),
            (Some(9), -2.0),
            (Some(10), -4.0),
            (Some(11), 11.0),
            (Some(12), -2.0),
            (Some(13), 13.0),
            (Some(14), 14.0),
            (Some(15), -6.0),
            (Some(16), 16.0)
        )

        val results = lines.map { l => (l.oi("profile.user_id"), model(l)) }.
                            map { case (optId, s) => (optId, s.get) }

        assertEquals(expected, results)
    }
} 
Example 109
Source File: CsvColumnTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.csv.json

import com.eharmony.aloha.reflect.{RefInfo, RefInfoOps}
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import spray.json._
import spray.json.DefaultJsonProtocol._


@RunWith(classOf[BlockJUnit4ClassRunner])
class CsvColumnTest {
    @Test def test1() {
        val examples = Seq(
            """{ "name": "long",       "type": "long",   "spec": "${long}" }""",
            """{ "name": "opt_double", "type": "double", "spec": "${opt_double}" }""",
            """{ "name": "syn_enum",   "type": "enum",   "spec": "${opt_string}", "values":    [ "e1v1" ] }""",
            """{ "name": "enum",       "type": "enum",   "spec": "${string}", "enumClass": "com.eharmony.matching.notaloha.AnEnum" }"""
        )

        val expected = Seq(
            CsvColumnWithDefault[Long]("long", "${long}"),
            CsvColumnWithDefault[Double]("opt_double", "${opt_double}"),
            SyntheticEnumCsvColumn("syn_enum", "${opt_string}", Seq("e1v1")),
            EnumCsvColumn("enum", "${string}", "com.eharmony.matching.notaloha.AnEnum")
        )

        val act = examples.map { ex => CsvColumn.csvColumnSpecFormat.read(ex.parseJson) }

        assertEquals(expected, act)
    }

    @Test def testReqEnum() {
        val jsonTxt = """{ "name": "some_enum",
                        |  "type": "enum",
                        |  "spec": "${string}",
                        |  "enumClass": "com.eharmony.matching.notaloha.AnEnum"
                        |}""".stripMargin
        val json = jsonTxt.parseJson
        val col = json.convertTo[CsvColumn]
        assertTrue(col.isInstanceOf[EnumCsvColumn])
    }

    @Test def testOptEnum() {
        val jsonTxt = """{ "name": "some_enum",
                        |  "type": "enum",
                        |  "spec": "${string}",
                        |  "enumClass": "com.eharmony.matching.notaloha.AnEnum",
                        |  "defVal": "VALUE_2",
                        |  "optional": true
                        |}""".stripMargin
        val json = jsonTxt.parseJson
        val col = json.convertTo[CsvColumn]
        assertTrue(col.isInstanceOf[OptionEnumCsvColumn[_]])
    }

    @Test def testSizedByte(): Unit = testSizedCreation[Byte]
    @Test def testSizedChar(): Unit = testSizedCreation[Char]
    @Test def testSizedShort(): Unit = testSizedCreation[Short]
    @Test def testSizedInt(): Unit = testSizedCreation[Int]
    @Test def testSizedLong(): Unit = testSizedCreation[Long]
    @Test def testSizedFloat(): Unit = testSizedCreation[Float]
    @Test def testSizedDouble(): Unit = testSizedCreation[Double]
    @Test def testSizedString(): Unit = testSizedCreation[String]

    private def testSizedCreation[A: RefInfo: JsonFormat]: Unit = {
        val tpe = RefInfoOps.toString(RefInfo[A]).split("\\.").last.toLowerCase
        val name = tpe.replaceAll("[aeiou]", "")
        val spec = "${string}"
        val jsonTxt = s"""{ "name": "$name", "type": "$tpe", "size": 2, "spec": "$spec"}"""
        val col = jsonTxt.parseJson.convertTo[CsvColumn]
        val exp = SeqCsvColumnWithNoDefault[A](name, spec, 2)
        assertEquals(exp, col)
    }
} 
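In testSizedCreation above, the column name is derived from the lower-cased type name by stripping vowels, so "double" becomes "dbl" and "string" becomes "strng". A small illustration of that derivation (the values follow directly from the replaceAll call in the test; the surrounding names are only for illustration):

val tpe = "double"
val name = tpe.replaceAll("[aeiou]", "")   // "dbl"
val jsonTxt = s"""{ "name": "$name", "type": "$tpe", "size": 2, "spec": "$${string}"}"""
// jsonTxt == { "name": "dbl", "type": "double", "size": 2, "spec": "${string}"}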
Example 110
Source File: OptionCsvColumnWithDefaultTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.csv.json

import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLines, CsvTypes}
import com.eharmony.aloha.semantics.func.GenAggFunc
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import spray.json.DefaultJsonProtocol.DoubleJsonFormat

import scala.concurrent.ExecutionContext.Implicits.global


@RunWith(classOf[BlockJUnit4ClassRunner])
class OptionCsvColumnWithDefaultTest {
  import OptionCsvColumnWithDefaultTest._

  // NOTE: reconstructed class wrapper; the @Test methods of this class are elided
  // in this excerpt, only the compilation helper below and the companion object are shown.
  private[this] def compileOptFn[A, C](s: CompiledSemantics[A], c: TypedColumnCol[C]): GenAggFunc[A, Option[C]] = {
    s.createFunction[Option[C]](c.wrappedSpec, Some(c.defVal))(c.refInfo).fold(
      errs => throw new RuntimeException(s"Problem compiling function:\n${errs.mkString("\n")}"),
      fn => c match {
        case col: OptionCsvColumnWithDefault[C] => fn.andThenGenAggFunc(_ orElse c.defVal)
        case _ => fn
      }
    )
  }
}

private[json] object OptionCsvColumnWithDefaultTest {
  type TypedColumnCol[A] = CsvColumn { type ColType = A }

  private[this] val features = Seq(
    "height_mm" -> CsvTypes.DoubleOptionType,
    "height_cm" -> CsvTypes.IntType
  )

  private[this] val missing = ""

  // Test height actual data:  height_mm [TAB] height_cm
  val lines = CsvLines(indices = features.unzip._1.zipWithIndex.toMap)(
    "1800\t180",
    s"$missing\t165"
  )

  lazy val plugin = CompiledSemanticsCsvPlugin(features: _*)
  lazy val semantics = CompiledSemantics(TwitterEvalCompiler(), plugin, Nil)
} 
Example 111
Source File: ComparisonsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.feature

import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.runner.RunWith
import org.junit.Test
import org.junit.Assert._

@RunWith(classOf[BlockJUnit4ClassRunner])
class ComparisonsTest {
    import ComparisonsTest._

    @Test def test_gtLt_1_1(): Unit = assertFalse(gtLt2(1, 1))
    @Test def test_gtLt_1_2(): Unit = assertFalse(gtLt2(1, 2))
    @Test def test_gtLt_1_3(): Unit = assertTrue(gtLt2(1, 3))
    @Test def test_gtLt_2_1(): Unit = assertFalse(gtLt2(2, 1))
    @Test def test_gtLt_2_2(): Unit = assertFalse(gtLt2(2, 2))
    @Test def test_gtLt_2_3(): Unit = assertFalse(gtLt2(2, 3))
    @Test def test_gtLt_3_1(): Unit = assertFalse(gtLt2(3, 1))
    @Test def test_gtLt_3_2(): Unit = assertFalse(gtLt2(3, 2))
    @Test def test_gtLt_3_3(): Unit = assertFalse(gtLt2(3, 3))

    @Test def test_gtLte_1_1(): Unit = assertFalse(gtLte2(1, 1))
    @Test def test_gtLte_1_2(): Unit = assertTrue(gtLte2(1, 2))
    @Test def test_gtLte_1_3(): Unit = assertTrue(gtLte2(1, 3))
    @Test def test_gtLte_2_1(): Unit = assertFalse(gtLte2(2, 1))
    @Test def test_gtLte_2_2(): Unit = assertFalse(gtLte2(2, 2))
    @Test def test_gtLte_2_3(): Unit = assertFalse(gtLte2(2, 3))
    @Test def test_gtLte_3_1(): Unit = assertFalse(gtLte2(3, 1))
    @Test def test_gtLte_3_2(): Unit = assertFalse(gtLte2(3, 2))
    @Test def test_gtLte_3_3(): Unit = assertFalse(gtLte2(3, 3))

    @Test def test_gteLt_1_1(): Unit = assertFalse(gteLt2(1, 1))
    @Test def test_gteLt_1_2(): Unit = assertFalse(gteLt2(1, 2))
    @Test def test_gteLt_1_3(): Unit = assertTrue(gteLt2(1, 3))
    @Test def test_gteLt_2_1(): Unit = assertFalse(gteLt2(2, 1))
    @Test def test_gteLt_2_2(): Unit = assertFalse(gteLt2(2, 2))
    @Test def test_gteLt_2_3(): Unit = assertTrue(gteLt2(2, 3))
    @Test def test_gteLt_3_1(): Unit = assertFalse(gteLt2(3, 1))
    @Test def test_gteLt_3_2(): Unit = assertFalse(gteLt2(3, 2))
    @Test def test_gteLt_3_3(): Unit = assertFalse(gteLt2(3, 3))

    @Test def test_gteLte_1_1(): Unit = assertFalse(gteLte2(1, 1))
    @Test def test_gteLte_1_2(): Unit = assertTrue(gteLte2(1, 2))
    @Test def test_gteLte_1_3(): Unit = assertTrue(gteLte2(1, 3))
    @Test def test_gteLte_2_1(): Unit = assertFalse(gteLte2(2, 1))
    @Test def test_gteLte_2_2(): Unit = assertTrue(gteLte2(2, 2))
    @Test def test_gteLte_2_3(): Unit = assertTrue(gteLte2(2, 3))
    @Test def test_gteLte_3_1(): Unit = assertFalse(gteLte2(3, 1))
    @Test def test_gteLte_3_2(): Unit = assertFalse(gteLte2(3, 2))
    @Test def test_gteLte_3_3(): Unit = assertFalse(gteLte2(3, 3))
}


object ComparisonsTest {
    import Comparisons._
    val gtLt2: (Int, Int) => Boolean = gtLt(2, _, _)
    val gtLte2: (Int, Int) => Boolean = gtLe(2, _, _)
    val gteLt2: (Int, Int) => Boolean = geLt(2, _, _)
    val gteLte2: (Int, Int) => Boolean = geLe(2, _, _)
} 
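Judging purely from the assertions above, the helpers pin the middle value to 2 and test whether it lies strictly or weakly between the two remaining arguments: gtLt(y, a, b) behaves like a < y < b, gtLe like a < y <= b, geLt like a <= y < b, and geLe like a <= y <= b. A minimal sketch of an equivalent predicate (an assumption inferred from these tests, not the aloha Comparisons source):

// Sketch: strict "between" check matching the observed gtLt behaviour
def gtLt[A](y: A, lower: A, upper: A)(implicit ord: Ordering[A]): Boolean =
  ord.lt(lower, y) && ord.lt(y, upper)   // e.g. gtLt(2, 1, 3) == true, gtLt(2, 1, 2) == false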
Example 112
Source File: BasicFunctionsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.feature

import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import BasicFunctions._

object BasicFunctionsTest {
  type KVPair = Iterable[(String, Double)]
  val one: KVPair = Iterable(("", 1.0))
}

@RunWith(classOf[BlockJUnit4ClassRunner])
class BasicFunctionsTest {
  import BasicFunctionsTest._

  @Test def testByteToSeq(): Unit   = testAnyValToSeq(1.toByte, one)
  @Test def testShortToSeq(): Unit  = testAnyValToSeq(1.toShort, one)
  @Test def testIntToSeq(): Unit    = testAnyValToSeq(1, one)
  @Test def testLongToSeq(): Unit   = testAnyValToSeq(1L, one)
  @Test def testFloatToSeq(): Unit  = testAnyValToSeq(1f, one)
  @Test def testDoubleToSeq(): Unit = testAnyValToSeq(1d, one)

  @Test def testOptByteToSeq(): Unit   = assertEquals(one, Option(1.toByte).toKv)
  @Test def testOptShortToSeq(): Unit  = assertEquals(one, Option(1.toShort).toKv)
  @Test def testOptIntToSeq(): Unit    = assertEquals(one, Option(1).toKv)
  @Test def testOptLongToSeq(): Unit   = assertEquals(one, Option(1L).toKv)
  @Test def testOptFloatToSeq(): Unit  = assertEquals(one, Option(1f).toKv)
  @Test def testOptDoubleToSeq(): Unit = assertEquals(one, Option(1d).toKv)

  @Test def testNoneByteToSeq(): Unit   = testNoneToSeq[Byte]
  @Test def testNoneShortToSeq(): Unit  = testNoneToSeq[Short]
  @Test def testNoneIntToSeq(): Unit    = testNoneToSeq[Int]
  @Test def testNoneLongToSeq(): Unit   = testNoneToSeq[Long]
  @Test def testNoneFloatToSeq(): Unit  = testNoneToSeq[Float]
  @Test def testNoneDoubleToSeq(): Unit = testNoneToSeq[Double]

  def testAnyValToSeq[A](a: A, kv: KVPair)(implicit f: A => KVPair): Unit = assertEquals(kv, f(a))

  def testNoneToSeq[A](implicit f: A => Double): Unit = assertEquals(Nil, Option.empty[A].toKv)
} 
Example 113
Source File: SparsityTransformsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.feature

import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.runner.RunWith
import org.junit.Test
import org.junit.Assert._
import scala.util.Random

@RunWith(classOf[BlockJUnit4ClassRunner])
class SparsityTransformsTest {
    import SparsityTransforms._

    
    @Test def testSparsifiedDensify() {
        implicit val r = new Random(0)

        (1 to 100) foreach { i => {
            val n = r.nextInt(100)
            val d = Seq.fill(r.nextInt(10000))(r.nextInt(1000))
            val k = Seq.fill(n)(r.nextInt(1000))
            val v = Seq.fill(n)(r.nextDouble())

            val m = k.zip(v).toMap
            val f = m.get _

            assertTrue(s" test $i iterable: ", parIterableInverseLaw(d, k, v))
            assertTrue(s" test $i map: ", mapInverseLaw(d, m))
            assertTrue(s" test $i function: ", fnInverseLaw(d, f))
        }}
    }

    @Test def testDensifyPI() {

        val res = densifyPI(_3to6, Array(4, 6), Array(1, 2), 0)
        assertEquals(Vector(0, 1, 0, 2), res)

        // Show off the cool CBF / functor stuff.  Vector because 3 to 6 is an IndexedSeq.
        assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName)
    }

    @Test def testDensifyPIwithEmptyKeys() {
        val res = densifyPI(_3to6, Seq.empty, Array(1, 2), 0)
        assertEquals(Vector.fill(4)(0), res)
    }

    @Test def testDensifyPIwithEmptyValues() {
        val res = densifyPI(_3to6, Array(4, 6), Seq.empty, 0)
        assertEquals(Vector.fill(4)(0), res)
    }

    @Test def testDensifyF() {
        val f = _map.get _

        val res = densifyFn(_3to6, f, 0)
        assertEquals(Vector(0, 1, 0, 2), res)

        // Show off the cool CBF / functor stuff.  Vector because 3 to 6 is an IndexedSeq.
        assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName)
    }

    @Test def testDensifyMap() {
        val res = densifyMap(_3to6, _map, 0)
        assertEquals(Vector(0, 1, 0, 2), res)

        // Show off the cool CBF / functor stuff.  Vector because 3 to 6 is an IndexedSeq.
        assertEquals("scala.collection.immutable.Vector", res.getClass.getCanonicalName)
    }
} 
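The fixtures _3to6 and _map, and the three *InverseLaw property helpers used in testSparsifiedDensify, live in a companion object that is elided from this excerpt. The expected Vector(0, 1, 0, 2) pins them down: the dense index range is 3 to 6 and the sparse pairs are 4 -> 1 and 6 -> 2. A sketch of what the elided fixtures presumably look like (inferred from the assertions, not copied from the aloha source):

object SparsityFixturesSketch {
  val _3to6 = 3 to 6               // dense index range: 3, 4, 5, 6
  val _map  = Map(4 -> 1, 6 -> 2)  // sparse key -> value pairs
  // Densifying _map over _3to6 with default 0 yields Vector(0, 1, 0, 2),
  // matching testDensifyPI, testDensifyF and testDensifyMap above.
}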
Example 114
Source File: FactoryImportedModelTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.factory

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ex.{AlohaFactoryException, RecursiveModelDefinitionException}
import com.eharmony.aloha.semantics.NoSemantics
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class FactoryImportedModelTest {
  private[this] val factory = ModelFactory.defaultFactory(NoSemantics[Any](), OptionAuditor[Int]())

  @Test(expected = classOf[RecursiveModelDefinitionException])
  def test1CycleDetected() {
    factory.fromResource("com/eharmony/aloha/factory/cycle1_A.json").get
  }

  @Test(expected = classOf[RecursiveModelDefinitionException])
  def test2CycleDetected() {
    factory.fromResource("com/eharmony/aloha/factory/cycle2_A.json").get
  }

  @Test(expected = classOf[RecursiveModelDefinitionException])
  def test3CycleDetected() {
    factory.fromResource("com/eharmony/aloha/factory/cycle3_A.json").get
  }

  @Test def test1LevelSuccessDefault() {
    val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_default.json").get
    assertEquals(Option(1), m(null))
  }

  @Test def test1LevelSuccessVfs1() {
    val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_vfs1.json").get
    assertEquals(Option(3), m(null))
  }

  @Test def test1LevelSuccessVfs2() {
    val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_vfs2.json").get
    assertEquals(Option(4), m(null))
  }

  @Test def test1LevelSuccessFile() {
    val m = factory.fromResource("com/eharmony/aloha/factory/success_1_level_file.json").get
    assertEquals(Option(2), m(null))
  }

  @Test def test1LevelAppropriateFailureWithDefaultProtocol() {
    try {
      factory.fromResource("com/eharmony/aloha/factory/bad_reference_default.json").get
      fail()
    }
    catch {
      case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't resolve VFS2 file"))
      case e: Exception => fail()
    }
  }

  @Test def test1LevelAppropriateFailureWithVfs1() {
    try {
      factory.fromResource("com/eharmony/aloha/factory/bad_reference_vfs1.json").get
      fail()
    }
    catch {
      case e: AlohaFactoryException => assertTrue(e.getMessage, e.getMessage.startsWith("Couldn't resolve VFS1 file"))
      case e: Exception => fail()
    }
  }

  @Test def test1LevelAppropriateFailureWithVfs2() {
    try {
      factory.fromResource("com/eharmony/aloha/factory/bad_reference_vfs2.json").get
      fail()
    }
    catch {
      case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't resolve VFS2 file"))
      case e: Exception => fail()
    }
  }

  @Test def test1LevelAppropriateFailureWithFile() {
    try {
      factory.fromResource("com/eharmony/aloha/factory/bad_reference_file.json").get
      fail()
    }
    catch {
      case e: AlohaFactoryException => assertTrue(e.getMessage.startsWith("Couldn't get JSON for file"))
      case e: Exception => fail()
    }
  }
} 
Example 115
Source File: FormatsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.factory

import com.eharmony.matching.notaloha.AnEnum
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import spray.json.DefaultJsonProtocol.jsonFormat1
import spray.json._


object FormatsTest {
  case class GenEnumPossessor[E <: Enum[E]](value: E)
  case class EnumPossessor(value: AnEnum)
  implicit val AnEnumFormat = JavaJsonFormats.enumFormat(classOf[AnEnum])
  implicit val EnumPossessorFormat: RootJsonFormat[EnumPossessor] = jsonFormat1(EnumPossessor)
}

@RunWith(classOf[BlockJUnit4ClassRunner])
class FormatsTest {
  import FormatsTest._

  @Test(expected = classOf[DeserializationException]) def testEnumFormatValue1(): Unit =
    """{ "value": "VALUE_1" }""".parseJson.convertTo[EnumPossessor]

  @Test def testEnumFormatValue2(): Unit =
    assertEquals(EnumPossessor(AnEnum.VALUE_2), """{ "value": "VALUE_2" }""".parseJson.convertTo[EnumPossessor])

  @Test def testEnumFormatValue3(): Unit =
    assertEquals(EnumPossessor(AnEnum.VALUE_3), """{ "value": "VALUE_3" }""".parseJson.convertTo[EnumPossessor])

  @Test def testGenEnumFormatValue3(): Unit = {
    val clas = Class.forName(classOf[AnEnum].getName)
    implicit val ge = geFormat(clas)
    val v = """{ "value": "VALUE_3" }""".parseJson.convertTo(ge)
    assertEquals(GenEnumPossessor(AnEnum.VALUE_3), v)
  }

  def geFormat[E <: Enum[E]](clas: Class[_]): RootJsonFormat[GenEnumPossessor[E]] = {
    implicit val ef = Formats.enumFormat(clas.asInstanceOf[Class[E]])
    jsonFormat1(GenEnumPossessor[E])
  }
} 
Example 116
Source File: ErrorModelTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models

import com.eharmony.aloha.ModelSerializationTestHelper
import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.audit.impl.tree.RootedTreeAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.id.ModelId
import com.eharmony.aloha.semantics.NoSemantics
import org.junit.Assert.{assertEquals, assertNotNull, assertTrue}
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class ErrorModelTest extends ModelSerializationTestHelper {

  private val factory = ModelFactory.defaultFactory(NoSemantics[Unit](), OptionAuditor[Byte]())

  @Test def test1() {
    val em = ErrorModel(ModelId(), Seq("There should be a valid user ID.  Couldn't find one...", "blah blah"), RootedTreeAuditor.noUpperBound[Byte]())
    val s = em(null)
    assertNotNull(s)
    assertTrue(s.value.isEmpty)
  }

  @Test def testEmptyErrors() {

    val json =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": { "id": 0, "name": "" }
        |}
      """.stripMargin

    val m1 = factory.fromString(json)
    assertTrue(m1.isSuccess)

    val json2 =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": { "id": 0, "name": "" },
        |  "errors": []
        |}
      """.stripMargin


    val m2 = factory.fromString(json2)
    assertTrue(m2.isSuccess)
  }

  @Test def testSerialization(): Unit = {
    val m = ErrorModel(ModelId(2, "abc"), Seq("def", "ghi"), OptionAuditor[Byte]())
    val m1 = serializeDeserializeRoundTrip(m)
    assertEquals(m, m1)
  }
} 
Example 117
Source File: BigModelParseTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.reg

import com.eharmony.aloha.models.reg.json.RegressionModelJson
import spray.json.pimpString
import com.eharmony.aloha.io.StringReadable
import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.runner.RunWith
import org.junit.Test
import org.junit.Assert._
import com.eharmony.aloha.util.{Logging, Timing}

@RunWith(classOf[BlockJUnit4ClassRunner])
class BigModelParseTest extends RegressionModelJson with Timing with Logging {

    
    @Test def testBigJsonParsedToAstForRegModel() {
        val ((s, data), t) = time(getBigZippedData("/com/eharmony/aloha/models/reg/semi_cleaned_big_model.json.gz"))
        assertTrue(s"Should take less than 10 seconds to parse, took $t", t < 10)

        assertEquals("file lines", 184846, scala.io.Source.fromString(s).getLines().size)
        assertEquals("Features", 94, data.features.size)
        assertEquals("First order weights", 874, data.weights.size)
        assertEquals("Higher order weights", 30598, data.higherOrderFeatures.map(_.size).getOrElse(0))
        assertEquals("spline size", 341, data.spline.map(_.knots.size).getOrElse(0))

        debug("file lines: 184846, features: 94, first order weights: 874, higher order weights: 30598, spline size: 341")
    }

    private[this] def getBigZippedData(resourcePath: String) = {
        val s = StringReadable.gz.fromResource(resourcePath)
        (s, s.parseJson.convertTo[RegData])
    }
} 
Example 118
Source File: PolynomialEvaluationAlgoTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.reg

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.reflect.RefInfo
import com.eharmony.aloha.semantics.Semantics
import com.eharmony.aloha.semantics.func.{GenAggFunc, GenFunc, GeneratedAccessor}
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class PolynomialEvaluationAlgoTest {
  private[this] val expected = (1 << 7) - 1 // 0111 1111
  private[this] val tolerance = 1.0e-6

  private[this] val accessor = (key: String) => (_:Map[String, String]).get(key).map(k => Seq((k, 1.0))).getOrElse(Nil)
  private[this] val semantics = new Semantics[Map[String, String]] {
    def close(): Unit = {}
    def refInfoA: RefInfo[Map[String, String]] = RefInfo[Map[String, String]]
    def accessorFunctionNames = Nil
    def createFunction[B: RefInfo](codeSpec: String, default: Option[B]): Either[Seq[String], GenAggFunc[Map[String, String], B]] = {
      val acc = GeneratedAccessor(codeSpec, accessor(codeSpec))
      val f = GenFunc.f1(acc)(codeSpec, identity)
      Right(f.asInstanceOf[GenAggFunc[Map[String, String], B]])
    }
  }

  private[this] val factory = ModelFactory.defaultFactory(semantics, OptionAuditor[Double]())

  @Test def testManualPolyEval() {
    val x = IndexedSeq(
      Seq(("intercept",          1.0)),
      Seq(("female_country=1", 1.0)),
      Seq(("male_country=2",   1.0)),
      Seq(("user_gender=MALE",   1.0)),
      Seq(("cand_gender=FEMALE", 1.0))
    )

    val weightPaths = Map[Map[String, Int], Double](
      Map("intercept"        -> 0                           ) -> (1 << 0),
      Map("female_country=1" -> 1                           ) -> (1 << 1),
      Map("male_country=2"   -> 2                           ) -> (1 << 2),
      Map("user_gender=MALE" -> 3                           ) -> (1 << 3),
      Map("female_country=1" -> 1, "user_gender=MALE"   -> 3) -> (1 << 4),
      Map("female_country=1" -> 1, "cand_gender=FEMALE" -> 4) -> (1 << 5),
      Map("male_country=2"   -> 2, "user_gender=MALE"   -> 3) -> (1 << 6)
    )

    val w = (PolynomialEvaluator.builder ++= weightPaths).result()

    val y = w at x
    assertEquals(expected, y, tolerance)
    assertEquals(weightPaths.values.sum, y, tolerance)
  }

  
  @Test def testJsonParsedPolyEval() {
    val jStr =
      """
        |{
        |  "modelType": "Regression",
        |  "modelId": { "id": 0, "name": "" },
        |  "features": {
        |    "intercept": "intercept",
        |    "female_country": "female_country",
        |    "male_country": "male_country",
        |    "user_gender": "user_gender",
        |    "cand_gender": "cand_gender"
        |  },
        |  "weights": {
        |    "intercept": 1,
        |    "female_country=1": 2,
        |    "male_country=2": 4,
        |    "user_gender=MALE": 8
        |  },
        |  "higherOrderFeatures": [
        |    { "features": { "female_country": ["female_country=1"], "user_gender": ["user_gender=MALE"] },   "wt": 16 },
        |    { "features": { "female_country": ["female_country=1"], "cand_gender": ["cand_gender=FEMALE"] }, "wt": 32 },
        |    { "features": { "male_country":   ["male_country=2"],   "user_gender": ["user_gender=MALE"] },   "wt": 64 }
        |  ]
        |}
      """.stripMargin.trim

    val m = factory.fromString(jStr).get

    val x = Map(
      "intercept" -> "",
      "female_country" -> "=1",
      "male_country" -> "=2",
      "user_gender" -> "=MALE",
      "cand_gender" -> "=FEMALE"
    )

    val score = m(x)
    assertEquals(expected, score.get, tolerance)
  }
} 
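The weights in both variants are distinct powers of two, so each first-order and higher-order term contributes exactly one bit to the score, and the expected value (1 << 7) - 1 = 127 is reached only if all seven terms fire exactly once. A quick arithmetic check:

// The seven power-of-two weights sum to the expected score of 127.
val weights = Seq(1, 2, 4, 8, 16, 32, 64)
assert(weights.sum == (1 << 7) - 1)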
Example 119
Source File: ConstantModelTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models

import com.eharmony.aloha.ModelSerializationTestHelper
import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.id.ModelId
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class ConstantModelTest extends ModelSerializationTestHelper {

  @Test def testSerialization(): Unit = {
    val m = ConstantModel(Option(1), ModelId(2, "abc"), OptionAuditor[Int]())
    val m1 = serializeDeserializeRoundTrip(m)
    assertEquals(m, m1)

    val m2 = ConstantModel(None: Option[String], ModelId(3, "abc"), OptionAuditor[String]())
    val m3 = serializeDeserializeRoundTrip(m2)
    assertEquals(m2, m3)
  }
} 
Example 120
Source File: ConstantModelParserTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.factory.ex.AlohaFactoryException
import com.eharmony.aloha.semantics.NoSemantics
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class ConstantModelParserTest {

  private val factory = ModelFactory.defaultFactory(NoSemantics[String](), OptionAuditor[Int]())


  @Test def testValueOnly() {
    val js =
      """
        |{
        |  "modelType": "Constant",
        |  "modelId": {"id": 0, "name": ""},
        |  "value": 1
        |}
      """.stripMargin

    val m = factory.fromString(js).get
    val s = m(null)
    assertEquals(Option(1), s)
  }

  @Test(expected = classOf[Exception])
  def testNoOutputSpecified() {
    val js =
      """
        |{
        |  "modelType": "Constant",
        |  "modelId": {"id": 0, "name": ""}
        |}
      """.stripMargin

    val m = factory.fromString(js)
    m.get
  }

  @Test(expected = classOf[Exception])
  def testNoModelIdSpecified() {
    val js =
      """
        |{
        |  "modelType": "Constant",
        |  "value": 1
        |}
      """.stripMargin

    val m = factory.fromString(js)
    m.get
  }

  @Test(expected = classOf[AlohaFactoryException])
  def testNothingSpecified() {
    val js =
      """
        |{
        |}
      """.stripMargin

    val m = factory.fromString(js)
    m.get
  }
} 
Example 121
Source File: ErrorModelParserTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models

import com.eharmony.aloha.audit.impl.OptionAuditor
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.semantics.NoSemantics
import org.junit.Assert.assertTrue
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class ErrorModelParserTest {
    private val factory = ModelFactory.defaultFactory(NoSemantics[String](), OptionAuditor[Int]())

  @Test def testErrorsFieldMissing() {
    val js =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": {"id":0, "name": ""}
        |}
      """.stripMargin

    val m = factory.fromString(js)
    assertTrue(m.isSuccess)
  }

  @Test def test0Errors() {
    val js =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": {"id":0, "name": ""},
        |  "errors": []
        |}
      """.stripMargin

    val m = factory.fromString(js)
    assertTrue(m.isSuccess)
  }

  @Test def test1Error() {
    val js =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": {"id":0, "name": ""},
        |  "errors": [
        |    "error 1"
        |  ]
        |}
      """.stripMargin

    val m = factory.fromString(js)
    assertTrue(m.isSuccess)
  }

  @Test def test2Errors() {
    val js =
      """
        |{
        |  "modelType": "Error",
        |  "modelId": {"id":0, "name": ""},
        |  "errors": [
        |    "error 1",
        |    "error 2"
        |  ]
        |}
      """.stripMargin

    val m = factory.fromString(js)
    assertTrue(m.isSuccess)
  }
} 
Example 122
Source File: PkgTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha

import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class PkgTest {
    @Test def testPkgLocationOk(): Unit = {
        assertEquals("com.eharmony.aloha", pkgName)
    }

    
    @Test def testVersionFormatOk(): Unit = {
        val ok = """(\d+)\.(\d+)\.(\d+)(-(SNAPSHOT))?""".r
        version match {
            case ok(major, minor, fix, _, snapshot) => ()
            case notOk => fail(s"Bad version format: $notOk")
        }
    }
} 
Example 123
Source File: HeaderTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.csv

import com.eharmony.aloha.dataset.RowCreatorBuilder
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import com.eharmony.aloha.semantics.compiled.plugin.csv._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.Assert.assertEquals

import scala.concurrent.ExecutionContext.Implicits.global


@RunWith(classOf[BlockJUnit4ClassRunner])
class HeaderTest {
  // NOTE: reconstructed class wrapper; the @Test methods of this class are elided
  // in this excerpt, only the dataset-JSON and row-creator helpers below are shown.
  private def dsJson(encoding: String) =
    s"""
       |{
       |  "imports": [],
       |  "separator": ",",
       |  "nullValue": "null",
       |  "encoding": "$encoding",
       |  "features": [
       |    { "spec": "1 to 3", "type": "int", "size": 3, "name": "vec" },
       |    { "spec": "\\"some_string_value\\"", "type": "string", "name": "str" },
       |    { "spec": "4", "type": "double", "name": "doub" },
       |    { "spec": "true", "type": "boolean", "name": "bool" },
       |    { "spec": "com.eharmony.matching.notaloha.AnEnum.VALUE_2",
       |      "type": "enum",
       |      "enumClass": "com.eharmony.matching.notaloha.AnEnum",
       |      "name": "enum"
       |    }
       |  ]
       |}
     """.stripMargin

  private def csvRowCreator(encoding: String) = {
    val json = dsJson(encoding)
    val plugin = CompiledSemanticsCsvPlugin()
    val semantics = CompiledSemantics(TwitterEvalCompiler(classCacheDir = None), plugin, Nil)
    val sb = RowCreatorBuilder(semantics, List(CsvRowCreator.Producer[CsvLine]()))
    sb.fromString(json).get
  }

  // Since dsJson doesn't rely on any input data, this can be anything, including null.
  private val EmptyLine: CsvLineImpl = null
} 
Example 124
Source File: CsvTypesTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.semantics.compiled.plugin.csv

import org.junit.Test
import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.runner.RunWith
import org.junit.Assert._


@RunWith(classOf[BlockJUnit4ClassRunner])
class CsvTypesTest {
    @Test def testNumTypesCorrect() {
        assertEquals("Wrong number of types found in CsvTypes", 28, CsvTypes.values.size)
    }

    @Test def testTypeMethodCorrespondence() {
        val typeNames = CsvTypes.values.map(_.toString).toSet
        val methodNames = classOf[CsvLine].getDeclaredMethods.map(_.getName).toSet

        val typesWithoutMethods = typeNames -- methodNames
        val methodsWithoutTypes = methodNames -- typeNames

        assertEquals(s"The following types in CsvTypes seem to be missing methods in CsvLine: ${typesWithoutMethods.mkString("{", ", ", "}" )}", 0, typesWithoutMethods.size)
        assertEquals(s"The following methods in CsvLine don't seem to have associated types in CsvTypes: ${methodsWithoutTypes.mkString("{", ", ", "}" )}", 0, methodsWithoutTypes.size)
    }
} 
Example 125
Source File: CompiledSemanticsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.semantics.compiled

import java.{lang => jl}

import com.eharmony.aloha.FileLocations
import com.eharmony.aloha.reflect.RefInfo
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.concurrent.ExecutionContext.Implicits.global
import scala.language.implicitConversions

@RunWith(classOf[BlockJUnit4ClassRunner])
class CompiledSemanticsTest {
    private[this] val compiler = TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses))

    @Test def test0() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq())
        val f = s.createFunction[Int]("List(${five:-5L}).sum.toInt").right.get
        val x1 = Map("five" -> 1L)
        val x2 = Map.empty[String, Long]
        assertEquals(1, f(x1))
        assertEquals(5, f(x2))
    }

    @Test def test1() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq())
        val f = s.createFunction[Int]("List(${one}, ${two}, ${three}).sum.toInt", Option(Int.MinValue)).right.get
        val x1 = Map[String, Long]("one" -> 2, "two" -> 4, "three" -> 6)
        val y1 = f(x1)
        assertEquals(12, y1)
    }

    @Test def test2() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq())
        val f = s.createFunction[Double]("${user.inboundComm} / ${user.pageViews}.toDouble", Some(Double.NaN)).right.get
        val x1 = Map[String, Long]("user.inboundComm" -> 5, "user.pageViews" -> 10)
        val x2 = Map[String, Long]("user.inboundComm" -> 5)
        val y1 = f(x1)
        val y2 = f(x2)
        assertEquals(0.5, y1, 1.0e-6)
        assertEquals(Double.NaN, y2, 0)
    }

    @Test def test3() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq())
        val f = s.createFunction[Long]("new util.Random(0).nextLong").right.get
        val y1 = f(null)
        assertEquals(-4962768465676381896L, y1)
    }

    @Test def testNullDefaultOnExistingValue() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq("com.eharmony.aloha.semantics.compiled.StaticFuncs._"))
        val f = s.createFunction[Long]("f(${one})").left.map(_.foreach(println)).right.get
        val y1 = f(Map("one" -> 1))
        assertEquals(18, y1)
    }

    
    @Test
    def testNullDefaultOnNonMissingPrimitiveValue() {
        val s = CompiledSemantics(compiler, MapStringLongPlugin, Seq("com.eharmony.aloha.semantics.compiled.StaticFuncs._"))
        var errors: Seq[String] = Nil
        val f = s.createFunction[Long]("f(${missing:-null}.asInstanceOf[java.lang.Long])").
            left.map(e => errors = e).
            right.get
        val y1 = f(Map("missing" -> 13))
        assertEquals("Should process correctly when defaulting to null", 18, y1)
        assertEquals("No errors should appear", 0, errors.size)
    }


    private[this] object MapStringLongPlugin extends CompiledSemanticsPlugin[Map[String, Long]] {
        def refInfoA = RefInfo[Map[String, Long]]
        def accessorFunctionCode(spec: String) = {
            val required = Seq("user.inboundComm", "one", "two", "three")
            spec match {
                case s if required contains s  => Right(RequiredAccessorCode(Seq("(_:Map[String, Long]).apply(\"" + spec + "\")")))
                case _                         => Right(OptionalAccessorCode(Seq("(_:Map[String, Long]).get(\"" + spec + "\")")))
            }
        }
    }
}

object StaticFuncs {
    def f(a: jl.Long): Long = if (null == a) 13 else 18

    implicit def doubletoJlDouble(d: Double): java.lang.Double = java.lang.Double.valueOf(d)
} 
Example 126
Source File: PostgresJsonMarshallerTest.scala    From sundial   with MIT License 5 votes vote down vote up
package dao.postgres.marshalling

import com.fasterxml.jackson.databind.PropertyNamingStrategy.SNAKE_CASE
import com.hbc.svc.sundial.v2.models.NotificationOptions
import model.{EmailNotification, PagerdutyNotification, Team}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatestplus.play.PlaySpec
import util.Json

@RunWith(classOf[JUnitRunner])
class PostgresJsonMarshallerTest extends PlaySpec {

  private val postgresJsonMarshaller = new PostgresJsonMarshaller()

  private val objectMapper = Json.mapper()
  objectMapper.setPropertyNamingStrategy(SNAKE_CASE)
  //  objectMapper.setVisibility(PropertyAccessor.FIELD,Visibility.ANY)

  "PostgresJsonMarshaller" should {

    "correctly deserialize a json string into Seq[Team]" in {
      val json =
        """
          | [{
          |   "name" : "teamName",
          |   "email" : "teamEmail",
          |   "notify_action": "on_state_change_and_failures"
          | }]
        """.stripMargin

      val expectedTeams: Seq[Team] =
        Vector(Team("teamName", "teamEmail", "on_state_change_and_failures"))
      val actualTeams = postgresJsonMarshaller.toTeams(json)
      actualTeams must be(expectedTeams)
    }

    "correctly serialise a Seq[Team] in a json string" in {
      val expectedJson =
        """
          | [{
          |   "name" : "teamName",
          |   "email" : "teamEmail",
          |   "notify_action": "on_state_change_and_failures"
          | }]
        """.stripMargin
      val expectedTeams: Seq[Team] =
        Vector(Team("teamName", "teamEmail", "on_state_change_and_failures"))
      val actualJson = postgresJsonMarshaller.toJson(expectedTeams)
      objectMapper.readTree(actualJson) must be(
        objectMapper.readTree(expectedJson))
    }

    "correctly deserialize a json string into Seq[Notification]" in {

      val json =
        """
          |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}]
        """.stripMargin

      val notifications = Vector(
        EmailNotification(
          "name",
          "email",
          NotificationOptions.OnStateChangeAndFailures.toString),
        PagerdutyNotification("service-key", "http://google.com", 1)
      )

      val actualNotifications = postgresJsonMarshaller.toNotifications(json)

      actualNotifications must be(notifications)

    }

    "correctly serialise a Seq[Notification] in a json string" in {

      val json =
        """
          |[{"name":"name","email":"email","notify_action":"on_state_change_and_failures", "type": "email"},{"service_key":"service-key","api_url":"http://google.com", "type": "pagerduty","num_consecutive_failures":1}]
        """.stripMargin

      val notifications = Vector(
        EmailNotification(
          "name",
          "email",
          NotificationOptions.OnStateChangeAndFailures.toString),
        PagerdutyNotification("service-key", "http://google.com", 1)
      )

      println(s"bla1: ${postgresJsonMarshaller.toJson(notifications)}")
      println(s"bla2: ${objectMapper.writeValueAsString(notifications)}")

      objectMapper.readTree(json) must be(
        objectMapper.readTree(postgresJsonMarshaller.toJson(notifications)))

    }

  }

} 
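Both directions of the Team mapping are covered above; a small round-trip sketch makes the invariant explicit. It uses only the toJson/toTeams methods already exercised in this spec and would sit alongside the cases above:

    "round-trip a Seq[Team] through toJson and back" in {
      val teams = Vector(Team("teamName", "teamEmail", "on_state_change_and_failures"))
      postgresJsonMarshaller.toTeams(postgresJsonMarshaller.toJson(teams)) must be(teams)
    }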
Example 127
Source File: CronScheduleSpec.scala    From sundial   with MIT License 5 votes vote down vote up
package model

import java.text.ParseException
import java.util.GregorianCalendar

import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatestplus.play.PlaySpec

@RunWith(classOf[JUnitRunner])
class CronScheduleSpec extends PlaySpec {

  "Cron scheduler" should {

    "successfully parse cron entry for 10pm every day" in {
      val cronSchedule = CronSchedule("0", "22", "*", "*", "?")
      val date = new GregorianCalendar(2015, 10, 5, 21, 0).getTime
      val expectedNextDate = new GregorianCalendar(2015, 10, 5, 22, 0).getTime
      val nextDate = cronSchedule.nextRunAfter(date)
      nextDate must be(expectedNextDate)
    }

    "Throw exception on creation if cron schedlue is invalid" in {
      intercept[ParseException] {
        CronSchedule("0", "22", "*", "*", "*")
      }
    }
  }

} 
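The argument order in CronSchedule appears to be (minute, hour, day-of-month, month, day-of-week), Quartz style, and the "invalid" case above matches the Quartz rule that day-of-month and day-of-week cannot both be "*"; one of them must be "?". That reading is inferred from these two tests, not from the CronSchedule source, which this listing does not include. A sketch under that assumption:

    "successfully parse cron entry for 6:15am every day" in {
      val cronSchedule = CronSchedule("15", "6", "*", "*", "?")
      val date = new GregorianCalendar(2015, 10, 5, 5, 0).getTime
      val expectedNextDate = new GregorianCalendar(2015, 10, 5, 6, 15).getTime
      cronSchedule.nextRunAfter(date) must be(expectedNextDate)
    }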
Example 128
Source File: PluginTest.scala    From marathon-vault-plugin   with MIT License 5 votes vote down vote up
package com.avast.marathon.plugin.vault

import java.util.concurrent.TimeUnit

import com.bettercloud.vault.{Vault, VaultConfig}
import org.junit.runner.RunWith
import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.junit.JUnitRunner

import scala.collection.JavaConverters._
import scala.concurrent.Await
import scala.concurrent.duration.Duration

@RunWith(classOf[JUnitRunner])
class PluginTest extends FlatSpec with Matchers {

  private lazy val marathonUrl = s"http://${System.getProperty("marathon.host")}:${System.getProperty("marathon.tcp.8080")}"
  private lazy val mesosSlaveUrl = s"http://${System.getProperty("mesos-slave.host")}:${System.getProperty("mesos-slave.tcp.5051")}"
  private lazy val vaultUrl = s"http://${System.getProperty("vault.host")}:${System.getProperty("vault.tcp.8200")}"

  it should "read existing shared secret" in {
    check("SECRETVAR", env => deployWithSecret("testappjson", env, "/test@testKey")) { envVarValue =>
      envVarValue shouldBe "testValue"
    }
  }

  it should "read existing private secret" in {
    check("SECRETVAR", env => deployWithSecret("testappjson", env, "test@testKey")) { envVarValue =>
      envVarValue shouldBe "privateTestValue"
    }
  }

  it should "read existing private secret from application in folder" in {
    check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test@testKey")) { envVarValue =>
      envVarValue shouldBe "privateTestFolderValue"
    }
  }

  it should "fail when using .. in secret" in {
    intercept[RuntimeException] {
      check("SECRETVAR", env => deployWithSecret("folder/testappjson", env, "test/../test@testKey"), java.time.Duration.ofSeconds(1)) { envVarValue =>
        envVarValue shouldNot be("privateTestFolderValue")
      }
    }
  }

  private def deployWithSecret(appId: String, envVarName: String, secret: String): String = {
    val json = s"""{ "id": "$appId","cmd": "${EnvAppCmd.create(envVarName)}","env": {"$envVarName": {"secret": "pwd"}},"secrets": {"pwd": {"source": "$secret"}}}"""

    val marathonResponse = new MarathonClient(marathonUrl).put(appId, json)
    appId
  }

  private def check(envVarName: String, deployApp: String => String, timeout: java.time.Duration = java.time.Duration.ofSeconds(30))(verifier: String => Unit): Unit = {
    val client = new MarathonClient(marathonUrl)
    val eventStream = new MarathonEventStream(marathonUrl)

    val vaultConfig = new VaultConfig().address(vaultUrl).token("testroottoken").build()
    val vault = new Vault(vaultConfig)
    vault.logical().write("secret/shared/test", Map[String, AnyRef]("testKey" -> "testValue").asJava)
    vault.logical().write("secret/private/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestValue").asJava)
    vault.logical().write("secret/private/folder/testappjson/test", Map[String, AnyRef]("testKey" -> "privateTestFolderValue").asJava)

    val appId = deployApp(envVarName)
    val appCreatedFuture = eventStream.when(_.eventType.contains("deployment_success"))
    Await.result(appCreatedFuture, Duration.create(20, TimeUnit.SECONDS))

    val agentClient = MesosAgentClient(mesosSlaveUrl)
    val state = agentClient.fetchState()

    try {
      val envVarValue = agentClient.waitForStdOutContentsMatch(envVarName, state.frameworks(0).executors(0),
        o => EnvAppCmd.extractEnvValue(envVarName, o),
        timeout)
      verifier(envVarValue)
    } finally {
      client.delete(appId)
      val appRemovedFuture = eventStream.when(_.eventType.contains("deployment_success"))
      Await.result(appRemovedFuture, Duration.create(20, TimeUnit.SECONDS))
      eventStream.close()
    }
  }
} 
Example 129
Source File: MinMaxActorSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.actors.transform

import akka.actor.{Actor, ActorSystem, Props}
import akka.testkit.{TestProbe, ImplicitSender, TestActorRef, TestKit}
import akka.util.Timeout
import io.coral.actors.CoralActorFactory
import io.coral.api.DefaultModule
import org.json4s.JsonDSL._
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
import scala.concurrent.duration._

@RunWith(classOf[JUnitRunner])
class MinMaxActorSpec(_system: ActorSystem)
	extends TestKit(_system)
	with ImplicitSender
	with WordSpecLike
	with Matchers
	with BeforeAndAfterAll {
	implicit val timeout = Timeout(100.millis)
	implicit val formats = org.json4s.DefaultFormats
	implicit val injector = new DefaultModule(system.settings.config)
	def this() = this(ActorSystem("ZscoreActorSpec"))

	override def afterAll() {
		TestKit.shutdownActorSystem(system)
	}

	"A MinMaxActor" must {
		val createJson = parse(
			"""{ "type": "minmax", "params": { "field": "field1", "min": 10.0, "max": 13.5 }}"""
				.stripMargin).asInstanceOf[JObject]

		implicit val injector = new DefaultModule(system.settings.config)

		val props = CoralActorFactory.getProps(createJson).get
		val threshold = TestActorRef[MinMaxActor](props)

		// subscribe the testprobe for emitting
		val probe = TestProbe()
		threshold.underlyingActor.emitTargets += probe.ref

		"Emit the minimum when lower than the min" in {
			val json = parse( """{"field1": 7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{ "field1": 10.0 }"""))
		}

		"Emit the maximum when higher than the max" in {
			val json = parse( """{"field1": 15.3 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"field1": 13.5 }"""))
		}

		"Emit the value itself when between the min and the max" in {
			val json = parse( """{"field1": 11.7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"field1": 11.7 }"""))
		}

		"Emit object unchanged when key is not present in triggering json" in {
			val json = parse( """{"otherfield": 15.3 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"otherfield": 15.3 }"""))
		}
	}
} 
Example 130
Source File: ThresholdActorSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.actors.transform

import io.coral.actors.CoralActorFactory
import io.coral.api.DefaultModule
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import scala.concurrent.duration._
import akka.actor.ActorSystem
import akka.testkit._
import akka.util.Timeout
import org.json4s._
import org.json4s.jackson.JsonMethods._
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}

@RunWith(classOf[JUnitRunner])
class ThresholdActorSpec(_system: ActorSystem) extends TestKit(_system)
	with ImplicitSender
	with WordSpecLike
	with Matchers
	with BeforeAndAfterAll {
	implicit val timeout = Timeout(100.millis)
	def this() = this(ActorSystem("ThresholdActorSpec"))

	override def afterAll() {
		TestKit.shutdownActorSystem(system)
	}

	"A ThresholdActor" must {
		val createJson = parse(
			"""{ "type": "threshold", "params": { "key": "key1", "threshold": 10.5 }}"""
				.stripMargin).asInstanceOf[JObject]

		implicit val injector = new DefaultModule(system.settings.config)

		// test invalid definition json as well !!!
		val props = CoralActorFactory.getProps(createJson).get
		val threshold = TestActorRef[ThresholdActor](props)

		// subscribe the testprobe for emitting
		val probe = TestProbe()
		threshold.underlyingActor.emitTargets += probe.ref

		"Emit when equal to the threshold" in {
			val json = parse( """{"key1": 10.5}""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{ "key1": 10.5 }"""))
		}

		"Emit when higher than the threshold" in {
			val json = parse( """{"key1": 10.7}""").asInstanceOf[JObject]
			threshold ! json
			probe.expectMsg(parse( """{"key1": 10.7 }"""))
		}

		"Not emit when lower than the threshold" in {
			val json = parse( """{"key1": 10.4 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectNoMsg()
		}

		"Not emit when key is not present in triggering json" in {
			val json = parse( """{"key2": 10.7 }""").asInstanceOf[JObject]
			threshold ! json
			probe.expectNoMsg()
		}
	}
} 
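The "// test invalid definition json as well !!!" note above is left open in the spec. A hedged sketch of such a case, assuming CoralActorFactory.getProps returns an empty Option for a definition it cannot parse (the factory's actual failure mode is not shown in this listing):

		"Not be constructible from a definition without the required params" in {
			val invalidJson = parse("""{ "type": "threshold", "params": { } }""").asInstanceOf[JObject]
			CoralActorFactory.getProps(invalidJson) should be(None)
		}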
Example 131
Source File: BootConfigSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.api

import org.junit.runner.RunWith
import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, WordSpecLike}
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class BootConfigSpec
	extends WordSpecLike
	with BeforeAndAfterAll
	with BeforeAndAfterEach {
	"A Boot program actor" should {
		"Properly process given command line arguments for api and akka ports" in {
			val commandLine = CommandLineConfig(apiPort = Some(1234), akkaPort = Some(5345))
			val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine)
			assert(actual.akka.remote.nettyTcpPort == 5345)
			assert(actual.coral.api.port == 1234)
		}

		"Properly process a given configuration file through the command line" in {
			val configPath = getClass().getResource("bootconfigspec.conf").getFile()
			val commandLine = CommandLineConfig(config = Some(configPath), apiPort = Some(4321))
			val actual: CoralConfig = io.coral.api.Boot.getFinalConfig(commandLine)
			// Overridden in bootconfigspec.conf
			assert(actual.akka.remote.nettyTcpPort == 6347)
			// Overridden in command line parameter
			assert(actual.coral.api.port == 4321)
			// Not overridden in command line or bootconfigspec.conf
			assert(actual.coral.cassandra.port == 9042)
		}
	}
} 
Example 132
Source File: RuntimeStatisticsSpec.scala    From coral   with Apache License 2.0 5 votes vote down vote up
package io.coral.api

import io.coral.TestHelper
import org.junit.runner.RunWith
import org.scalatest.WordSpecLike
import org.scalatest.junit.JUnitRunner
import org.json4s._
import org.json4s.jackson.JsonMethods._

@RunWith(classOf[JUnitRunner])
class RuntimeStatisticsSpec
	extends WordSpecLike {
	"A RuntimeStatistics class" should {
		"Properly sum multiple statistics objects together" in {
			val counters1 = Map(
				(("actor1", "stat1") -> 100L),
				(("actor1", "stat2") -> 20L),
				(("actor1", "stat3") -> 15L))
			val counters2 = Map(
				(("actor2", "stat1") -> 20L),
				(("actor2", "stat2") -> 30L),
				(("actor2", "stat3") -> 40L))
			val counters3 = Map(
				(("actor2", "stat1") -> 20L),
				(("actor2", "stat2") -> 30L),
				(("actor2", "stat3") -> 40L),
				(("actor2", "stat4") -> 12L))
			val stats1 = RuntimeStatistics(1, 2, 3, counters1)
			val stats2 = RuntimeStatistics(2, 3, 4, counters2)
			val stats3 = RuntimeStatistics(4, 5, 6, counters3)

			val actual = RuntimeStatistics.merge(List(stats1, stats2, stats3))

			val expected = RuntimeStatistics(7, 10, 13,
				Map(("actor1", "stat1") -> 100,
					("actor1", "stat2") -> 20,
					("actor1", "stat3") -> 15,
					("actor2", "stat1") -> 20,
					("actor2", "stat2") -> 30,
					("actor2", "stat3") -> 40,
					("actor2", "stat4") -> 12))

			assert(actual == expected)
		}

		"Create a JSON object from a RuntimeStatistics object" in {
			val input = RuntimeStatistics(1, 2, 3,
				Map((("actor1", "stat1") -> 10L),
					(("actor1", "stat2") -> 20L)))

			val expected = parse(
				s"""{
				   |  "totalActors": 1,
				   |  "totalMessages": 2,
				   |  "totalExceptions": 3,
				   |  "counters": {
				   |    "total": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }, "actor1": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }
				   |  }
				   |}
				 """.stripMargin).asInstanceOf[JObject]

			val actual = RuntimeStatistics.toJson(input)
			assert(actual == expected)
		}

		"Create a RuntimeStatistics object from a JSON object" in {
			val input = parse(
				s"""{
				   |  "totalActors": 1,
				   |  "totalMessages": 2,
				   |  "totalExceptions": 3,
				   |  "counters": {
				   |    "total": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }, "actor1": {
				   |      "stat1": 10,
				   |      "stat2": 20
				   |    }
				   |  }
				   |}
				 """.stripMargin).asInstanceOf[JObject]

			val actual = RuntimeStatistics.fromJson(input)

			val expected = RuntimeStatistics(1, 2, 3,
				Map((("actor1", "stat1") -> 10L),
					(("actor1", "stat2") -> 20L)))

			assert(actual == expected)
		}
	}
} 
Example 133
Source File: XmlScoverageReportParserSpec.scala    From sonar-scala   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package com.buransky.plugins.scoverage.xml

import org.scalatest.{FlatSpec, Matchers}
import org.scalatest.junit.JUnitRunner
import org.junit.runner.RunWith
import com.buransky.plugins.scoverage.ScoverageException

@RunWith(classOf[JUnitRunner])
class XmlScoverageReportParserSpec extends FlatSpec with Matchers {
  behavior of "parse file path"

  it must "fail for null path" in {
    the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse(null.asInstanceOf[String], null)
  }

  it must "fail for empty path" in {
    the[IllegalArgumentException] thrownBy XmlScoverageReportParser().parse("", null)
  }

  it must "fail for not existing path" in {
    the[ScoverageException] thrownBy XmlScoverageReportParser().parse("/x/a/b/c/1/2/3/4.xml", null)
  }
} 
Example 134
Source File: PathUtilSpec.scala    From sonar-scala   with GNU Lesser General Public License v3.0 5 votes vote down vote up
package com.buransky.plugins.scoverage.util

import org.scalatest.{FlatSpec, Matchers}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PathUtilSpec extends FlatSpec with Matchers {
  
  val osName = System.getProperty("os.name")
  val separator = System.getProperty("file.separator")
  
  behavior of s"splitPath for $osName"
  
  it should "ignore the empty path" in {
    PathUtil.splitPath("") should equal(List.empty[String])
  }

  it should "ignore a separator at the beginning" in {
    PathUtil.splitPath(s"${separator}a") should equal(List("a"))
  }

  it should "work with separator in the middle" in {
    PathUtil.splitPath(s"a${separator}b") should equal(List("a", "b"))
  }
  
  it should "work with an OS dependent absolute path" in {
    if (osName.startsWith("Windows")) {
      PathUtil.splitPath("C:\\test\\2") should equal(List("test", "2"))
    } else {
      PathUtil.splitPath("/test/2") should equal(List("test", "2"))
    }
  }
} 
Example 135
Source File: BasicSimulation.scala    From warp-core   with MIT License 5 votes vote down vote up
package com.workday.warp.adapters

import com.workday.warp.adapters.gatling.{GatlingJUnitRunner, WarpSimulation}
import io.gatling.core.Predef._
import io.gatling.http.Predef._
import org.junit.runner.RunWith
import io.gatling.core.structure.ScenarioBuilder
import io.gatling.http.protocol.HttpProtocolBuilder

@RunWith(classOf[GatlingJUnitRunner])
class BasicSimulation extends WarpSimulation {
  val httpConf: HttpProtocolBuilder = http
    .baseUrl("http://google.com")

  val scn: ScenarioBuilder = scenario("Positive Scenario")
    .exec(
      http("request_1").get("/")
    )

  setUp(scn.inject(atOnceUsers(1)).protocols(httpConf))
} 
Example 136
Source File: BotPluginTestKit.scala    From sumobot   with Apache License 2.0 5 votes vote down vote up
package com.sumologic.sumobot.test.annotated

import akka.actor.ActorSystem
import akka.testkit.{TestKit, TestProbe}
import com.sumologic.sumobot.core.model.{IncomingMessage, InstantMessageChannel, OutgoingMessage, UserSender}
import org.junit.runner.RunWith
import org.scalatest.concurrent.Eventually
import org.scalatest.junit.JUnitRunner
import org.scalatest.{BeforeAndAfterAll, Matchers, WordSpecLike}
import slack.models.User

import scala.concurrent.duration.{FiniteDuration, _}

@RunWith(classOf[JUnitRunner])
abstract class BotPluginTestKit(actorSystem: ActorSystem)
  extends TestKit(actorSystem)
    with WordSpecLike with Eventually with Matchers
    with BeforeAndAfterAll {

  protected val outgoingMessageProbe = TestProbe()
  system.eventStream.subscribe(outgoingMessageProbe.ref, classOf[OutgoingMessage])

  protected def confirmOutgoingMessage(test: OutgoingMessage => Unit, timeout: FiniteDuration = 1.second): Unit = {
    outgoingMessageProbe.expectMsgClass(timeout, classOf[OutgoingMessage]) match {
      case msg: OutgoingMessage =>
        test(msg)
    }
  }

  protected def instantMessage(text: String, user: User = mockUser("123", "jshmoe")): IncomingMessage = {
    IncomingMessage(text, true, InstantMessageChannel("125", user), "1527239216000090", sentBy = UserSender(user))
  }

  protected def mockUser(id: String, name: String): User = {
    User(id, name, None, None, None, None, None, None, None, None, None, None, None, None, None, None)
  }

  protected def send(message: IncomingMessage): Unit = {
    system.eventStream.publish(message)
  }

  override protected def afterAll(): Unit = {
    TestKit.shutdownActorSystem(system)
  }
} 
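BotPluginTestKit is abstract and is meant to be extended by concrete plugin specs. A minimal sketch of such a spec: EchoPlugin is a hypothetical plugin actor that repeats incoming messages back, and the sketch assumes OutgoingMessage carries the reply text in a text field; only the kit's own helpers (send, instantMessage, confirmOutgoingMessage) are taken from the code above.

import akka.actor.{ActorSystem, Props}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class EchoPluginSpec extends BotPluginTestKit(ActorSystem("EchoPluginSpec")) {

  // Hypothetical plugin under test.
  system.actorOf(Props(new EchoPlugin), "echo")

  "EchoPlugin" should {
    "reply to an instant message" in {
      send(instantMessage("hello"))
      confirmOutgoingMessage { msg =>
        // assumes OutgoingMessage exposes the reply as msg.text
        msg.text should include ("hello")
      }
    }
  }
}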
Example 137
Source File: ModelSerializabilityTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.vw.jni

import com.eharmony.aloha.ModelSerializabilityTestBase
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class ModelSerializabilityTest extends ModelSerializabilityTestBase(
  Seq(ModelSerializabilityTest.pkg),
  Seq(
    ".*Test.*",
    ".*\\$.*"
  )
)


object ModelSerializabilityTest {
  def pkg = getClass.getPackage.getName
} 
Example 138
Source File: ControlThrowable.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.neg

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import lacasa.util._

@RunWith(classOf[JUnit4])
class ControlThrowableSpec {
  @Test
  def test1() {
    println(s"ControlThrowableSpec.test1")
    expectError("propagated") {
      """
        class C {
          import scala.util.control.ControlThrowable
          import lacasa.Box._
          def m(): Unit = {
            try {
              val x = 0
              val y = x + 10
              println(s"res: ${x + y}")
            } catch {
              case t: ControlThrowable =>
                println("hello")
                uncheckedCatchControl
            }
          }
        }
      """
    }
  }

  @Test
  def test2() {
    println(s"ControlThrowableSpec.test2")
    expectError("propagated") {
      """
        class C {
          import scala.util.control.ControlThrowable
          def m(): Unit = {
            try {
              throw new ControlThrowable {}
            } catch {
              case t: Throwable =>
                println("hello")
            }
          }
        }
      """
    }
  }

  @Test
  def test3() {
    println(s"ControlThrowableSpec.test3")
    expectError("propagated") {
      """
        class SpecialException(msg: String) extends RuntimeException
        class C {
          import scala.util.control.ControlThrowable
          def m(): Unit = {
            val res = try { 5 } catch {
              case s: SpecialException => println("a")
              case c: ControlThrowable => println("b")
              case t: Throwable => println("c")
            }
          }
        }
      """
    }
  }
} 
Example 139
Source File: CaptureSpec.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test.plugin.capture

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import lacasa.util._


@RunWith(classOf[JUnit4])
class CaptureSpec {

  @Test
  def test() {
    println(s"CaptureSpec.test")
    expectError("invalid reference to value acc") {
      """
        import lacasa.{Box, Packed}
        import Box._
        import scala.spores._
        class Data {
          var name: String = _
        }
        class Data2 {
          var num: Int = _
          var dat: Data = _
        }
        object Use {
          mkBox[Data] { packed =>
            implicit val acc = packed.access
            val box: packed.box.type = packed.box

            box.open { _.name = "John" }

            mkBox[Data2] { packed2 =>
              implicit val acc2 = packed2.access
              val box2: packed2.box.type = packed2.box

              box2.capture(box)((x, y) => x.dat = y)(spore {
                val localBox = box
                (packedData: Packed[Data2]) =>
                  implicit val accessData = packedData.access

                  localBox.open { x => assert(false) }
              })
            }
          }
        }
      """
    }
  }

} 
Example 140
Source File: BoxOcap.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.neg

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import lacasa.util._

@RunWith(classOf[JUnit4])
class BoxOcapSpec {

  @Test
  def test1() {
    println("neg.BoxOcapSpec.test1")
    expectError("NonOcap") {
      """
        object Global {
          var state = "a"
        }
        class NonOcap {
          def doIt(): Unit = {
            Global.state = "b"
          }
        }
        class Data {
          var arr: Array[Int] = _
        }
        class Test {
          import lacasa.Box._
          import scala.spores._
          def m(): Unit = {
            mkBox[Data] { packed => // ok, Data ocap
              implicit val acc = packed.access
              packed.box.open(spore {
                (d: Data) =>
                  d.arr = Array(0, 1, 2) // ok
                  val obj = new NonOcap  // not ok: cannot inst. non-ocap class
              })
            }
          }
        }
      """
    }
  }

} 
Example 141
Source File: Stack2.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.neg

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import lacasa.util._

@RunWith(classOf[JUnit4])
class Stack2Spec {
  @Test
  def test1() {
    println(s"Stack2Spec.test1")
    expectError("confined") {
      """
        class D { }
        class C {
          import scala.spores._
          import lacasa.Box
          def m(): Unit = {
            Box.mkBox[D] { packed =>
              val fun = () => {
                val acc = packed.access
              }
            }
          }
        }
      """
    }
  }

  @Test
  def test2() {
    println(s"Stack2Spec.test2")
    expectError("propagated") {
      """
        class D { var arr: Array[Int] = _ }
        class C {
          import scala.spores._
          import lacasa.Box
          def m(): Unit = {
            try {
              Box.mkBox[D] { packed =>
                val access = packed.access
              }
            } catch {
              case ct: scala.util.control.ControlThrowable =>
            }
          }
        }
      """
    }
  }

} 
Example 142
Source File: Stack1.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.run

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.util.control.ControlThrowable

class Message {
  var arr: Array[Int] = _
}

@RunWith(classOf[JUnit4])
class Stack1Spec {

  import lacasa.Box._

  @Test
  def test1(): Unit = {
    println(s"run.Stack1Spec.test1")
    try {
      mkBox[Message] { packed =>
        implicit val access = packed.access
        packed.box open { msg =>
          msg.arr = Array(1, 2, 3, 4)
        }
      }
    } catch {
      case ct: ControlThrowable =>
        uncheckedCatchControl
        assert(true, "this should not fail!")
    }
  }

} 
Example 143
Source File: Control.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.run

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

@RunWith(classOf[JUnit4])
class ControlSpec {
  import scala.util.control.ControlThrowable
  import lacasa.Box._

  @Test
  def test1(): Unit = {
    println("run.ControlSpec.test1")
    val res = try { 5 } catch {
      case c: ControlThrowable =>
        throw c
      case t: Throwable =>
        println("hello")
    }
    assert(res == 5, "this should not fail")
  }

} 
Example 144
Source File: BoxSpec.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import lacasa.Box._


class DoesNotHaveNoArgCtor(val num: Int) {
  def incNum = new DoesNotHaveNoArgCtor(num + 1)
}

@RunWith(classOf[JUnit4])
class BoxSpec {

  @Test
  def testMkBoxFor1(): Unit = {
    try {
      mkBoxFor(new DoesNotHaveNoArgCtor(0)) { packed =>
        implicit val access = packed.access
        val box: packed.box.type = packed.box
        box.open { dnh =>
          assert(dnh.num == 0)
        }
      }
    } catch {
      case t: Throwable =>
        
    }
  }

} 
Example 145
Source File: example1.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test.examples

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits.global

import scala.concurrent.{Future, Promise, Await}
import scala.concurrent.duration._

import scala.spores._

import lacasa.{System, Box, CanAccess, Actor, ActorRef, doNothing}
import Box._


class Message1 {
  var arr: Array[Int] = _
}

class Start {
  var next: ActorRef[Message1] = _
}

class ActorA extends Actor[Any] {
  override def receive(b: Box[Any])
      (implicit acc: CanAccess { type C = b.C }) {
    b.open(spore { x =>
      x match {
        case s: Start =>
          mkBox[Message1] { packed =>
            implicit val access = packed.access
            packed.box open { msg =>
              msg.arr = Array(1, 2, 3, 4)
            }
            s.next.send(packed.box) { doNothing.consume(packed.box) }
          }

        case other => // ..
      }
    })
  }
}

class ActorB(p: Promise[String]) extends Actor[Message1] {
  override def receive(box: Box[Message1])
      (implicit acc: CanAccess { type C = box.C }) {
    // Strings are Safe, and can therefore be extracted from the box.
    p.success(box.extract(_.arr.mkString(",")))
  }
}

@RunWith(classOf[JUnit4])
class Spec {

  @Test
  def test(): Unit = {
    // to check result
    val p: Promise[String] = Promise()

    val sys = System()
    val a = sys.actor[ActorA, Any]
    val b = sys.actor[Message1](new ActorB(p))

    try {
      mkBox[Start] { packed =>
        import packed.access
        val box: packed.box.type = packed.box
        box open { s =>
          s.next = capture(b) // !!! captures `b` within `open`
        }
        a.send(box) { doNothing.consume(packed.box) }
      }
    } catch {
      case t: Throwable =>
        val res = Await.result(p.future, 2.seconds)
        assert(res == "1,2,3,4")
    }
  }
} 
Example 146
Source File: CaptureSpec.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test.capture

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.spores._
import scala.spores.SporeConv._

import lacasa.{Box, Packed}
import Box._


class Data {
  var name: String = _
}

class Data2 {
  var num: Int = _
  var dat: Data = _
}

@RunWith(classOf[JUnit4])
class CaptureSpec {

  @Test
  def test(): Unit = {
    try {
      mkBox[Data] { packed =>
        implicit val acc = packed.access
        val box: packed.box.type = packed.box

        box.open { _.name = "John" }

        mkBox[Data2] { packed2 =>
          implicit val acc2 = packed2.access
          val box2: packed2.box.type = packed2.box

          box2.capture(box)(_.dat = _)(spore { (packedData: Packed[Data2]) =>
            implicit val accessData = packedData.access

            packedData.box.open { d =>
              assert(d.dat.name == "John")
            }
          })
        }
      }
    } catch {
      case t: Throwable =>
    }
  }

} 
Example 147
Source File: box.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test.uniqueness

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits.global

import scala.concurrent.{Future, Promise, Await}
import scala.concurrent.duration._

import scala.spores._
import scala.spores.SporeConv._

import lacasa.{System, Box, CanAccess, Actor, ActorRef}
import Box._


class C {
  var f: D = null
  //var count = 0
}

class D {
  var g: C = null
}

sealed abstract class Msg
final case class Start() extends Msg
//final case class Repeat(obj: C) extends Msg

class ActorA(next: ActorRef[C]) extends Actor[Msg] {
  def receive(msg: Box[Msg])(implicit access: CanAccess { type C = msg.C }): Unit = {
    // create box with externally-unique object
    mkBox[C] { packed =>
      implicit val acc = packed.access
      val box: packed.box.type = packed.box

      // initialize object in box
      box.open(spore { obj =>
        val d = new D
        d.g = obj
        obj.f = d
      })

      next.send(box)(spore { () => })
    }
  }
}

class ActorB(p: Promise[Boolean]) extends Actor[C] {
  def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = {
    msg.open(spore { x =>
      val d = x.f
      // check that `d` refers back to `x`
      p.success(d.g == x)
    })
  }
}


@RunWith(classOf[JUnit4])
class Spec {

  @Test
  def test(): Unit = {
    // to check result
    val p: Promise[Boolean] = Promise()

    val sys = System()
    val b = sys.actor[C](new ActorB(p))
    val a = sys.actor[Msg](new ActorA(b))

    try {
      mkBox[Start] { packed =>
        import packed.access
        val box: packed.box.type = packed.box
        a.send(box)(spore { () => })
      }
    } catch {
      case t: Throwable =>
        val res = Await.result(p.future, 2.seconds)
        assert(res)
    }

  }

} 
Example 148
Source File: actor.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.test

import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4

import scala.concurrent.ExecutionContext
import scala.concurrent.ExecutionContext.Implicits.global

import scala.concurrent.{Future, Promise, Await}
import scala.concurrent.duration._

import scala.spores._
import scala.spores.SporeConv._

import lacasa.{System, Box, CanAccess, Actor, ActorRef}
import Box._


class NonSneaky {
  def process(a: Array[Int]): Unit = {
    for (i <- 0 until a.length)
      a(i) = a(i) + 1
  }
}

class ActorA(next: ActorRef[C]) extends Actor[C] {
  def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = {
    msg.open(spore { (obj: C) =>
      // OK: update array
      obj.arr(0) = 100

      // OK: create instance of ocap class
      val ns = new NonSneaky
      ns.process(obj.arr)
    })
    next.send(msg)(spore { () => })
  }
}

class ActorB(p: Promise[String]) extends Actor[C] {
  def receive(msg: Box[C])(implicit access: CanAccess { type C = msg.C }): Unit = {
    msg.open(spore { x =>
      p.success(x.arr.mkString(","))
    })
  }
}

class C {
  var arr: Array[Int] = _
}


@RunWith(classOf[JUnit4])
class Spec {

  @Test
  def test(): Unit = {
    // to check result
    val p: Promise[String] = Promise()

    val sys = System()
    val b = sys.actor[C](new ActorB(p))
    val a = sys.actor[C](new ActorA(b))

    try {
      mkBox[C] { packed =>
        import packed.access
        val box: packed.box.type = packed.box

        // initialize object in box with new array
        box.open(spore { obj =>
          obj.arr = Array(1, 2, 3, 4)
        })

        a.send(box)(spore { () => })
      }
    } catch {
      case t: Throwable =>
        val res = Await.result(p.future, 2.seconds)
        assert(res == "101,3,4,5")
    }

  }

} 
Example 149
Source File: ModelSerializabilityTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.h2o

import com.eharmony.aloha.ModelSerializabilityTestBase
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class ModelSerializabilityTest extends ModelSerializabilityTestBase(
  Seq(ModelSerializabilityTest.pkg),
  Seq(
    ".*Test.*",
    ".*\\$.*"
  )
)

object ModelSerializabilityTest {
  def pkg = getClass.getPackage.getName
} 
Example 150
Source File: CompilerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.h2o.compiler

import hex.genmodel.GenModel
import hex.genmodel.easy.prediction.RegressionModelPrediction
import hex.genmodel.easy.{RowData, EasyPredictModelWrapper}
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class CompilerTest {
  @Test def testNoPackage(): Unit = {
    val compiler = new Compiler[GenModel]
    val genModel = compiler.fromResource("com/eharmony/aloha/models/h2o/glm_afa04e31_17ad_4ca6_9bd1_8ab80005ce38.java")
    val y: GenModel = genModel.get
    assertTrue(classOf[GenModel].isAssignableFrom(y.getClass))

    val x = new RowData
    x.put("Sex", "F")
    x.put("Length", java.lang.Double.valueOf(0.0))
    x.put("Diameter", java.lang.Double.valueOf(0.0))
    x.put("Height", java.lang.Double.valueOf(0.0))
    x.put("Whole weight", java.lang.Double.valueOf(0.0))
    x.put("Shucked weight", java.lang.Double.valueOf(0.0))
    x.put("Viscera weight", java.lang.Double.valueOf(0.0))
    x.put("Shell weight", java.lang.Double.valueOf(0.0))
    println(new EasyPredictModelWrapper(y).predictRegression(x).value)
  }

  @Test def testWithPackage(): Unit = {
    val compiler = new Compiler[GenModel]()
    val genModel = compiler.fromResource("com/eharmony/aloha/models/h2o/domain.glm_afa04e31_17ad_4ca6_9bd1_8ab80005ce37.java")
    val y: GenModel = genModel.get
    assertTrue(classOf[GenModel].isAssignableFrom(y.getClass))
  }

  @Test def testDrfCompiles(): Unit = {
    val compiler = new Compiler[GenModel]()
    val modelTry = compiler.fromResource("com/eharmony/aloha/models/h2o/DRF_model_1463074092542_1.java")
    val model = new EasyPredictModelWrapper(modelTry.get)
    val complexPrediction = model.predict(new RowData)
    val pred = complexPrediction match {
      case r: RegressionModelPrediction => Option(r.value)
      case _ => None
    }
    assertEquals(Option(0.0), pred)
  }
} 
Example 151
Source File: VwSparseMultilabelPredictorTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.models.vw.jni.multilabel

import java.io.{ByteArrayOutputStream, File, FileInputStream}

import com.eharmony.aloha.ModelSerializationTestHelper
import com.eharmony.aloha.io.sources.{Base64StringSource, ExternalSource, ModelSource}
import org.apache.commons.codec.binary.Base64
import org.apache.commons.io.IOUtils
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import vowpalWabbit.learner.{VWActionScoresLearner, VWLearners}


@RunWith(classOf[BlockJUnit4ClassRunner])
class VwSparseMultilabelPredictorTest extends ModelSerializationTestHelper {
  import VwSparseMultilabelPredictorTest._

  @Test def testSerializability(): Unit = {
    val predictor = getPredictor(getModelSource(), 3)
    val ds = serializeDeserializeRoundTrip(predictor)
    assertEquals(predictor, ds)
    assertEquals(predictor.vwParams(), ds.vwParams())
    assertNotNull(ds.vwModel)
  }

  @Test def testVwParameters(): Unit = {
    val numLabelsInTrainingSet = 3
    val predictor = getPredictor(getModelSource(), numLabelsInTrainingSet)

    predictor.vwParams() match {
      case Data(vwBinFilePath, ringSize) =>
        checkVwBinFile(vwBinFilePath)
        checkVwRingSize(numLabelsInTrainingSet, ringSize.toInt)
      case ps => fail(s"Unexpected VW parameters format.  Found string: $ps")
    }
  }
}

object VwSparseMultilabelPredictorTest {
  private val Data = """\s*-i\s+(\S+)\s+--ring_size\s+(\d+)\s+--testonly\s+--quiet""".r

  private def getModelSource(): ModelSource = {
    val f = File.createTempFile("i_dont", "care")
    f.deleteOnExit()
    val learner = VWLearners.create[VWActionScoresLearner](s"--quiet --csoaa_ldf mc --csoaa_rank -f ${f.getCanonicalPath}")
    learner.close()
    val baos = new ByteArrayOutputStream()
    IOUtils.copy(new FileInputStream(f), baos)
    val src = Base64StringSource(Base64.encodeBase64URLSafeString(baos.toByteArray))
    ExternalSource(src.localVfs)
  }

  private def getPredictor(modelSrc: ModelSource, numLabelsInTrainingSet: Int) =
    VwSparseMultilabelPredictor[Any](modelSrc, Nil, Nil, numLabelsInTrainingSet)

  private def checkVwBinFile(vwBinFilePath: String): Unit = {
    val vwBinFile = new File(vwBinFilePath)
    assertTrue("VW binary file should have been written to disk", vwBinFile.exists())
    vwBinFile.deleteOnExit()
  }

  private def checkVwRingSize(numLabelsInTrainingSet: Int, ringSize: Int): Unit = {
    assertEquals(
      "vw --ring_size parameter is incorrect:",
      numLabelsInTrainingSet + VwSparseMultilabelPredictor.AddlVwRingSize,
      ringSize.toInt
    )
  }
} 
Example 152
Source File: Stack.scala    From lacasa   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package lacasa.neg

import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.Test

import lacasa.util._

@RunWith(classOf[JUnit4])
class StackSpec {
  @Test
  def test1() {
    println(s"StackSpec.test1")
    expectError("confined") {
      """
        class D { }
        class C {
          import scala.spores._
          import lacasa.Box
          var b: Box[D] = _
          def m(): Unit = {
            Box.mkBox[D] { packed =>
              b = packed.box // assign box to field
            }
          }
        }
      """
    }
  }

  @Test
  def test2() {
    println(s"StackSpec.test2")
    expectError("confined") {
      """
        class D { }
        class C {
          import scala.spores._
          import lacasa.Box
          var b: lacasa.CanAccess = _
          def m(): Unit = {
            Box.mkBox[D] { packed =>
              b = packed.access // assign permission to field
            }
          }
        }
      """
    }
  }

  @Test
  def test3() {
    println(s"StackSpec.test3")
    expectError("confined") {
      """
        class D { }
        class C {
          import scala.spores._
          import lacasa.Box
          var b: Any = _
          def m(): Unit = {
            Box.mkBox[D] { packed =>
              b = packed.access // assign permission to field
            }
          }
        }
      """
    }
  }

  @Test
  def test4() {
    println(s"StackSpec.test4")
    expectError("confined") {
      """
        class E(x: Any) {}
        class D { }
        class C {
          import scala.spores._
          import lacasa.Box
          def m(): Unit = {
            Box.mkBox[D] { packed =>
              new E(packed.box)
            }
          }
        }
      """
    }
  }
} 
Example 153
Source File: ImplicitsTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.audit.impl.avro

import com.google.common.collect.Lists
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.collection.JavaConverters.seqAsJavaListConverter
import com.eharmony.aloha.audit.impl.avro.Implicits.{RichFlatScore, RichScore}
import java.{lang => jl, util => ju}

import org.apache.avro.generic.GenericRecord


@RunWith(classOf[BlockJUnit4ClassRunner])
class ImplicitsTest {
  import ImplicitsTest._

  @Test def testAllFieldsAppear(): Unit = {
    val s = filledInScore
    assertEquals(s, s.toFlatScore.toScore)
  }

  @Test def testSameFieldsInGenericRecord(): Unit = {
    val s = filledInScore
    val s1 = s.asInstanceOf[GenericRecord]
    val s2 = s.toFlatScore.asInstanceOf[GenericRecord]

    testStuff(s1, s2, Map(
      "model" -> modelId,
      "value" -> value,
      "errorMsgs" -> errors,
      "missingVarNames" -> missing,
      "prob" -> prob
    ))
  }

  private[this] def testStuff(r1: GenericRecord, r2: GenericRecord, data: Map[String, Any]): Unit = {
    data.foreach { case (k, v) =>
      val v1 = r1.get(k)
      val v2 = r2.get(k)
      assertEquals(s"for r1('$k') = $v1.  Expected $v", v, r1.get(k))
      assertEquals(s"for r2('$k') = $v2.  Expected $v", v, r2.get(k))
    }
  }
}


object ImplicitsTest {
  private def filledInScore = new Score(modelId, value, subvalues, errors, missing, prob)
  private def modelId = new ModelId(5L, "five")
  private def value: jl.Double = 13d
  private def subvalues = Lists.newArrayList(scr(12L, 8))
  private def errors: ju.List[CharSequence] = Lists.newArrayList("one error", "two errors")
  private def missing: ju.List[CharSequence] =
    Lists.newArrayList("some feature", "another feature", "yet another feature")
  private def prob: jl.Float = 1f

  private lazy val score: Score =
    scr(1, 1,
      scr(2L, 2,
        scr(4f, 4),
        scr(5,  5)
      ),
      scr(3d, 3,
        scr(6d, 6),
        scr(7L, 7)
      )
    )

  private lazy val irregularTree: Score =
    scr(1, 1,
      scr(2L, 2),
      scr(3d, 3,
        scr(5d, 5),
        scr(6L, 6)
      ),
      scr(4d, 4,
        scr(7L, 7)
      )
    )

  private[this] def scr(value: Any, id: Long, children: Score*): Score = {
    new Score(
      new ModelId(id, ""),
      value,
      Lists.newArrayList(children.asJava),
      java.util.Collections.emptyList(),
      java.util.Collections.emptyList(),
      null
    )
  }
} 
Example 154
Source File: FlatScoreTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.audit.impl.avro

import com.google.common.collect.Lists
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import com.eharmony.aloha.audit.impl.avro.AvroScoreAuditorTest.serializeRoundTrip
import scala.collection.JavaConverters.seqAsJavaListConverter
import java.{util => ju}


@RunWith(classOf[BlockJUnit4ClassRunner])
class FlatScoreTest {
  import FlatScoreTest.flatScore

  @Test def testSerializability(): Unit = {
    val serDeserFS =
      serializeRoundTrip(FlatScore.getClassSchema, flatScore).head

    // When comparing the records instead of the JSON strings, equality doesn't
    // hold because they are different types: flatScore is a SpecificRecord,
    // and SpecificRecord's equals checks whether the other value is also a SpecificRecord.
    assertEquals(flatScore.toString, serDeserFS.toString)
  }
}

object FlatScoreTest {
  private[this] def empty[A]: ju.List[A] = ju.Collections.emptyList[A]

  private[this] implicit def toArrayList[A, B](as: Seq[A])(implicit ev: A => B): ju.ArrayList[B] =
    Lists.newArrayList(as.map(ev).asJava)

  private[this] def fsd(value: Any, id: Long, children: Int*): FlatScoreDescendant = {
    new FlatScoreDescendant(
      new ModelId(id, ""),
      value,
      children,
      empty[CharSequence],
      empty[CharSequence],
      null
    )
  }

  private[avro] lazy val flatScore: FlatScore = {
    new FlatScore(new ModelId(1L, ""), 1, Vector(0, 1), empty[CharSequence], empty[CharSequence], null,
      Seq(
        fsd(2L, 2, 2, 3),  // 0
        fsd(3d, 3, 4, 5),  // 1
        fsd(4f, 4),        // 2
        fsd(5,  5),        // 3
        fsd(6d, 6),        // 4
        fsd(7L, 7)         // 5
      )
    )
  }
} 
Example 155
Source File: StdAvroModelFactoryTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.factory.avro

import com.eharmony.aloha.audit.impl.avro.Score
import com.eharmony.aloha.factory.ModelFactory
import com.eharmony.aloha.io.vfs.Vfs1
import com.eharmony.aloha.models.Model
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.commons.io.IOUtils
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.util.Try


@RunWith(classOf[BlockJUnit4ClassRunner])
class StdAvroModelFactoryTest {
  import StdAvroModelFactoryTest._

  // Helper used by the factory tests (the tests themselves are omitted from this excerpt).
  private[this] def record = {
    val r = new GenericData.Record(TheSchema)
    r.put("req_str_1", "smart handsome stubborn")
    r
  }
}

object StdAvroModelFactoryTest {
  private lazy val TheSchema = {
    val is = getClass.getClassLoader.getResourceAsStream(SchemaUrlResource)
    try new Schema.Parser().parse(is) finally IOUtils.closeQuietly(is)
  }

  private val ExpectedResult = 7d

  private val SchemaUrlResource = "avro/class7.avpr"

  private val SchemaUrl = s"res:$SchemaUrlResource"

  private val SchemaFile = new java.io.File(getClass.getClassLoader.getResource(SchemaUrlResource).getFile)

  private val SchemaVfs1FileObject = org.apache.commons.vfs.VFS.getManager.resolveFile(SchemaUrl)

  private val SchemaVfs2FileObject = org.apache.commons.vfs2.VFS.getManager.resolveFile(SchemaUrl)

  private val Imports = Seq("com.eharmony.aloha.feature.BasicFunctions._", "scala.math._")

  private val ReturnType = "Double"

  private val ModelJson =
    """
      |{
      |  "modelType": "Regression",
      |  "modelId": { "id": 0, "name": "" },
      |  "features" : {
      |    "my_attributes": "${req_str_1}.split(\"\\\\W+\").map(v => (s\"=$v\", 1.0))"
      |  },
      |  "weights": {
      |    "my_attributes=handsome": 1,
      |    "my_attributes=smart": 2,
      |    "my_attributes=stubborn": 4
      |  }
      |}
    """.stripMargin
} 
Example 156
Source File: PrintProtosTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.cli.dataset

import java.io.{ByteArrayOutputStream, IOException}
import java.util.Arrays

import com.eharmony.aloha.test.proto.Testing.{PhotoProto, UserProto}
import com.eharmony.aloha.test.proto.Testing.GenderProto.{FEMALE, MALE}
import com.google.protobuf.GeneratedMessage
import org.apache.commons.codec.binary.Base64
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner
import org.junit.{Ignore, Test}


@RunWith(classOf[BlockJUnit4ClassRunner])
@Ignore
class PrintProtosTest {
    @Test def testPrintProtos(): Unit = {
        System.out.println(alan)
        System.out.println(kate)
    }

    @throws(classOf[IOException])
    def alan: String = {
        val t = UserProto.newBuilder.
            setId(1).
            setName("Alan").
            setGender(MALE).
            setBmi(23).
            addAllPhotos(Arrays.asList(
                PhotoProto.newBuilder.
                    setId(1).
                    setAspectRatio(1).
                    setHeight(1).
                    build,
                PhotoProto.newBuilder.
                    setId(2).
                    setAspectRatio(2).
                    setHeight(2).build
            )).build
        b64(t)
    }

    def kate: String = {
        val t = UserProto.newBuilder.
            setId(1).
            setName("Kate").
            setGender(FEMALE).
            addAllPhotos(Arrays.asList(
                PhotoProto.newBuilder.
                    setId(3).
                    setAspectRatio(3).
                    setHeight(3).
                    build
            )).build
        b64(t)
    }

    def b64[M <: GeneratedMessage](p: M): String = {
        val baos: ByteArrayOutputStream = new ByteArrayOutputStream
        p.writeTo(baos)
        new String(Base64.encodeBase64(baos.toByteArray))
    }
} 
Example 157
Source File: ModelTypesTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.cli

import com.eharmony.aloha.factory.ModelFactory
import org.junit.Assert.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner



@RunWith(classOf[BlockJUnit4ClassRunner])
class ModelTypesTest {
  @Test def testKnownModels(): Unit = {
    val expected = Seq(
      "BootstrapExploration",
      "CategoricalDistribution",
      "CloserTester",              // A test model.
      "Constant",
      "DecisionTree",
      "DoubleToLong",
      "EpsilonGreedyExploration",
      "Error",
      "ErrorSwallowingModel",
      "H2o",
      "ModelDecisionTree",
      "Regression",
      "Segmentation",
      "SparseMultilabel",
      "VwJNI"
    )

    val actual = ModelFactory.defaultFactory(null, null).parsers.map(_.modelType).sorted
    assertEquals(expected, actual)
  }
} 
Example 158
Source File: CliTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.cli

import com.eharmony.aloha
import com.eharmony.aloha.util.io.TestWithIoCapture
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner


@RunWith(classOf[BlockJUnit4ClassRunner])
class CliTest extends TestWithIoCapture {

    @Test def testNoArgs(): Unit = {
        val res = run(Cli.main)(Array.empty)
        assertEquals("No arguments supplied. Supply one of: '--dataset', '--h2o', '--modelrunner', '--vw'.", res.err.contents.trim)
    }

    @Test def testBadFlag(): Unit = {
        val res = run(Cli.main)(Array("-BADFLAG"))
        assertEquals("'-BADFLAG' supplied. Supply one of: '--dataset', '--h2o', '--modelrunner', '--vw'.", res.err.contents.trim)
    }

    @Test def testVw(): Unit = {
        val res = run(Cli.main)(Array("--vw"))
        val expected =
            """
              |Error: Missing option --spec
              |Error: Missing option --model
              |vw """.stripMargin + aloha.version + """
              |Usage: vw [options]
              |
              |  -s <value> | --spec <value>
              |        spec is an Apache VFS URL to an aloha spec file.
              |  -m <value> | --model <value>
              |        model is an Apache VFS URL to a VW binary model.
              |  --fs-type <value>
              |        file system type: vfs1, vfs2, file. default = vfs2.
              |  -n <value> | --name <value>
              |        name of the model.
              |  -i <value> | --id <value>
              |        numeric id of the model.
              |  --vw-args <value>
              |        arguments to vw
              |  --external
              |        link to a binary VW model rather than embedding it inline in the aloha model.
              |  --num-missing-thresh <value>
              |        number of missing features to allow before returning a 'no-prediction'.
              |  --note <value>
              |        notes to add to the model. Can provide this many parameter times.
              |  --spline-min <value>
              |        min value for spline domain. (must additional provide spline-max and spline-knots).
              |  --spline-max <value>
              |        max value for spline domain. (must additional provide spline-min and spline-knots).
              |  --spline-knots <value>
              |        max value for spline domain. (must additional provide spline-min, spline-delta, and spline-knots).
            """.stripMargin

        assertEquals(expected.trim, res.err.contents.trim)
    }
} 
Example 159
Source File: RowCreatorProducerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset

import java.lang.reflect.Modifier

import com.eharmony.aloha
import com.eharmony.aloha.dataset.vw.multilabel.VwMultilabelRowCreator
import org.junit.Assert._
import org.junit.{Ignore, Test}
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.collection.JavaConversions.asScalaSet
import org.reflections.Reflections

@RunWith(classOf[BlockJUnit4ClassRunner])
class RowCreatorProducerTest {
    import RowCreatorProducerTest._

    private[this] def scanPkg = aloha.pkgName + ".dataset"

    @Test def testAllRowCreatorProducersHaveOnlyZeroArgConstructors() {
        val reflections = new Reflections(scanPkg)
        val specProdClasses = reflections.getSubTypesOf(classOf[RowCreatorProducer[_, _, _]]).toSet
        specProdClasses.foreach { clazz =>
            val cons = clazz.getConstructors
            assertTrue(s"There should only be one constructor for ${clazz.getCanonicalName}.  Found ${cons.length} constructors.", cons.length <= 1)
            cons.headOption.foreach { c =>
                if (!(WhitelistedRowCreatorProducers contains clazz)) {
                    val nParams = c.getParameterTypes.length
                    assertEquals(s"The constructor for ${clazz.getCanonicalName} should take 0 arguments.  It takes $nParams.", 0, nParams)
                }
            }
        }
    }

    
    // TODO: Report the above bug!
    @Ignore @Test def testAllRowCreatorProducersAreFinalClasses() {
        val reflections = new Reflections(scanPkg)
        val specProdClasses = reflections.getSubTypesOf(classOf[RowCreatorProducer[_, _, _]]).toSet
        specProdClasses.foreach { clazz =>
            assertTrue(s"${clazz.getCanonicalName} needs to be declared final.", Modifier.isFinal(clazz.getModifiers))
        }
    }
}

object RowCreatorProducerTest {
    private val WhitelistedRowCreatorProducers = Set[Class[_]](
        classOf[VwMultilabelRowCreator.Producer[_, _]]
    )
} 
Example 160
Source File: VwFeatureNormalizerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw

import org.junit.Assert._
import org.junit.{Test, Before}
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class VwFeatureNormalizerTest {
    private[this] var normalizer: VwFeatureNormalizer = _

    @Before def setup() {
        normalizer = new VwFeatureNormalizer
    }

    @Test def testBlank() {
        assertEquals("", normalizer("").toString)
    }

    @Test def testSimple() {
        val vwLine: String = "1 1| |A a b c"
        assertEquals("1 1| |A:0.57735 a b c", normalizer.apply(vwLine).toString)
    }

    @Test def testMultipleNamespaces() {
        val vwLine: String = "1 1| |A a b c |b 1=2 3=4"
        assertEquals("1 1| |A:0.57735 a b c |b:0.70711 1=2 3=4", normalizer.apply(vwLine).toString)
    }

    @Test def testWithWeights() {
        val vwLine: String = "1 1| |A a:0.987 b c:0.435"
        assertEquals("1 1| |A:0.67988 a:0.987 b c:0.435", normalizer.apply(vwLine).toString)
    }
} 
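The normalization constants asserted above are the inverse Euclidean norms of each namespace's feature weights: three unit-weight features give 1/sqrt(3) ≈ 0.57735, two give 1/sqrt(2) ≈ 0.70711, and the weighted namespace gives 1/sqrt(0.987^2 + 1 + 0.435^2) ≈ 0.67988. A minimal sketch reproducing those constants (not the VwFeatureNormalizer implementation itself):

object NamespaceScaleSketch extends App {
  // Inverse L2 norm of a namespace's feature weights; unnamed weights default to 1.0.
  def namespaceScale(weights: Seq[Double]): Double =
    1.0 / math.sqrt(weights.map(w => w * w).sum)

  println(f"${namespaceScale(Seq(1.0, 1.0, 1.0))}%.5f")     // 0.57735  for "|A a b c"
  println(f"${namespaceScale(Seq(1.0, 1.0))}%.5f")          // 0.70711  for "|b 1=2 3=4"
  println(f"${namespaceScale(Seq(0.987, 1.0, 0.435))}%.5f") // 0.67988  for "|A a:0.987 b c:0.435"
}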
Example 161
Source File: VwContextualBanditRowCreatorProducerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw.cb

import com.eharmony.aloha.dataset.RowCreatorBuilder
import com.eharmony.aloha.dataset.vw.VwParsingAndChainOfRespTest
import com.eharmony.aloha.semantics.compiled.plugin.csv.CsvLine
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class VwContextualBanditRowCreatorProducerTest {

    
    @Test def testAnyMissingDvFails(): Unit = {
        val semantics = VwParsingAndChainOfRespTest.semantics
        val sb = RowCreatorBuilder(semantics, List(new VwContextualBanditRowCreator.Producer[CsvLine]))
        val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleCbSpec.json").get

        val lines = VwParsingAndChainOfRespTest.csvLines(
            "Alex,,,,,,,2,1,0",
            "Bill,,,,,,,2,1,",
            "Carl,,,,,,,2,,0",
            "Dale,,,,,,,,1,0"
        )

        // TODO: Work on removing trailing and leading spaces.  This is clearly not perfect.
        val expected = Seq(
            "2:1:0 |A name=Alex",
            "|A name=Bill",
            "|A name=Carl",
            "|A name=Dale"
        )

        (lines zip expected).zipWithIndex.foreach {
            case ((x, exp), i) =>
                val act = spec(x)._2.toString
                assertEquals(s"On test $i: ", exp, act)
            case d => fail(s"bad: $d")
        }
    }
} 
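The expected lines above follow Vowpal Wabbit's contextual-bandit label layout, action:cost:probability, taken here from the last three CSV columns (2, 1 and 0 for Alex); whenever any of the three values is missing the whole label is dropped, leaving only the shared features. A minimal sketch of that rule, formatting the components as integers to match the expected strings:

object CbLabelSketch extends App {
  // "action:cost:probability " prefix, or an empty string if any component is missing.
  def cbLabel(action: Option[Int], cost: Option[Int], prob: Option[Int]): String =
    (action, cost, prob) match {
      case (Some(a), Some(c), Some(p)) => s"$a:$c:$p "
      case _ => ""
    }

  println(cbLabel(Some(2), Some(1), Some(0)) + "|A name=Alex") // "2:1:0 |A name=Alex"
  println(cbLabel(Some(2), Some(1), None) + "|A name=Bill")    // "|A name=Bill"
}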
Example 162
Source File: VwCovariateProducerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw

import com.eharmony.aloha.FileLocations
import com.eharmony.aloha.dataset.json.SparseSpec
import com.eharmony.aloha.dataset.vw.VwCovariateProducerTest.{X, semantics}
import com.eharmony.aloha.dataset.vw.json.VwJsonLike
import com.eharmony.aloha.dataset.{CompilerFailureMessages, SparseCovariateProducer, SparseFeatureExtractorFunction}
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLine, CsvTypes}
import com.eharmony.aloha.semantics.func.GenAggFunc
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.concurrent.ExecutionContext.Implicits.global
import scala.util.Success


@RunWith(classOf[BlockJUnit4ClassRunner])
class VwCovariateProducerTest {
    @Test def testGetVwDataWith1Function() {
        val j = new VwJsonLike {
            val namespaces = None
            val normalizeFeatures = None
            val features = Vector(SparseSpec("i_plus_d", """List(("", ${i} + ${d}))"""))
            val imports = Nil
        }

        val (covariates, default, nss, normalizer) = X.getVwData(semantics, j)

        covariates match {
            case Success(SparseFeatureExtractorFunction(IndexedSeq(("i_plus_d", f)))) =>
                assertTrue("Wrong covariate function", f.isInstanceOf[GenAggFunc[CsvLine, Iterable[(String, Double)]]])
            case _ => fail("Wrong covariates.")
        }

        assertEquals(1, default.size)
        assertEquals(0, nss.size)
        assertEquals(None, normalizer)
    }

    @Test def testGetVwDataEverythingMissing() {

        val j = new VwJsonLike {
            val namespaces = None
            val normalizeFeatures = None
            val features = Vector.empty
            val imports = Nil
        }

        val (covariates, default, nss, normalizer) = X.getVwData(semantics, j)

        covariates match {
            case Success(SparseFeatureExtractorFunction(IndexedSeq())) =>
            case _ => fail("Wrong covariates.")
        }

        assertEquals(0, default.size)
        assertEquals(0, nss.size)
        assertEquals(None, normalizer)
    }
}

private object VwCovariateProducerTest {
    object X extends VwCovariateProducer[CsvLine] with SparseCovariateProducer with CompilerFailureMessages {
        // To expose for testing.
        override def getVwData(semantics: CompiledSemantics[CsvLine], json: VwJsonLike) =
            super.getVwData(semantics, json)
    }

    lazy val semantics = {
        val compiler = TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses))
        val plugin = CompiledSemanticsCsvPlugin(
            "i" -> CsvTypes.IntType,
            "d" -> CsvTypes.DoubleType
        )
        CompiledSemantics[CsvLine](compiler, plugin, Nil)
    }
} 
Example 163
Source File: VwRowCreatorProducerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw.unlabeled

import com.eharmony.aloha.dataset.RowCreatorBuilder

import scala.concurrent.ExecutionContext.Implicits.global
import com.eharmony.aloha.FileLocations
import com.eharmony.aloha.semantics.compiled.CompiledSemantics
import com.eharmony.aloha.semantics.compiled.compiler.TwitterEvalCompiler
import com.eharmony.aloha.semantics.compiled.plugin.csv.{CompiledSemanticsCsvPlugin, CsvLine}
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class VwRowCreatorProducerTest {
    @Test def test1() {
        val p = CompiledSemanticsCsvPlugin()
        val sem = CompiledSemantics(TwitterEvalCompiler(classCacheDir = Option(FileLocations.testGeneratedClasses)), p, Nil)
        val sb = RowCreatorBuilder(sem, List(new VwRowCreator.Producer[CsvLine]))

        val json1 =
            """
              |{
              |  "imports": [],
              |  "features": [ { "name":"x", "spec":"Nil" } ]
              |}
            """.stripMargin.trim

        val xOpt = sb.fromString(json1)
        assertTrue(xOpt.isSuccess)

        val x = xOpt.get
        assertEquals(Seq(0), x.defaultNamespace)
        assertEquals(1, x.featuresFunction.features.size)
        assertEquals("x", x.featuresFunction.features.head._1)
        assertEquals(0, x.featuresFunction.features.head._2.accessors.size)
        assertEquals(0, x.featuresFunction.features.head._2.arity)
        assertTrue(x.namespaces.isEmpty)
        assertEquals(None, x.normalizer)
    }
} 
Example 164
Source File: VwLabelRowCreatorProducerTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw.labeled

import com.eharmony.aloha.dataset.RowCreatorBuilder
import com.eharmony.aloha.dataset.vw.VwParsingAndChainOfRespTest
import com.eharmony.aloha.semantics.compiled.plugin.csv.CsvLine
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

@RunWith(classOf[BlockJUnit4ClassRunner])
class VwLabelRowCreatorProducerTest {

    @Test def testNonDefaultTagThatsMissingDoesntRemoveLabel() {
        val semantics = VwParsingAndChainOfRespTest.semantics

        val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine]))
        val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpecWithTag.json").get

        val lines = VwParsingAndChainOfRespTest.csvLines(
            "Alex,,1,,2,,,,,",
            "Bill,,2,,3,,,,,",
            "Carl,,0,,,,,,,",
            "Dale,,3,,1,,,,,"
        )

        val expected = Seq(
            "1 2|A name=Alex marriages=UNK",
            "2 3|A name=Bill marriages=UNK",
            "0 |A name=Carl marriages=UNK",
            "3 1|A name=Dale marriages=UNK"
        )

        lines.zip(expected).foreach{
            case(x, exp) => assertEquals(
                s"for ${x.line}: ",
                exp,
                spec(x)._2.toString
            )
        }
    }


    @Test def testImportanceMissingRemovesLabel() {
        val semantics = VwParsingAndChainOfRespTest.semantics

        val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine]))
        val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpecWithImp.json").get

        val lines = VwParsingAndChainOfRespTest.csvLines(
            "Alex,,1,,2,,,,,",
            "Bill,,2,,3,,,,,",
            "Carl,,0,,,,,,,",
            "Dale,,3,,1,,,,,"
        )

        val expected = Seq(
            "1 2 1|A name=Alex marriages=UNK",
            "2 3 2|A name=Bill marriages=UNK",
            "|A name=Carl marriages=UNK",  // Omitting the importance variable removes the entire label.
            "3 3|A name=Dale marriages=UNK"
        )

        lines.zip(expected).foreach{
            case(x, exp) => assertEquals(
                s"for ${x.line}: ",
                exp,
                spec(x)._2.toString
            )
        }
    }

    @Test def testLabelMissingRemovesLabel() {

        val semantics = VwParsingAndChainOfRespTest.semantics

        val sb = RowCreatorBuilder(semantics, List(new VwLabelRowCreator.Producer[CsvLine]))
        val spec = sb.fromResource("com/eharmony/aloha/dataset/simpleSpec.json").get

        val lines = VwParsingAndChainOfRespTest.csvLines(
            "Alex,,1,,,,,,,",
            "Bill,,2,,,,,,,",
            "Carl,,,,,,,,,"
        )

        val expected = Seq(
            "1 1|A name=Alex marriages=UNK",
            "2 2|A name=Bill marriages=UNK",
            "|A name=Carl marriages=UNK"
        )

        lines.zip(expected).foreach{ case(x, exp) => assertEquals(s"for ${x.line}: ", exp, spec(x)._2.toString) }
    }
} 
Example 165
Source File: VwLabelRowCreatorTest.scala    From aloha   with MIT License 5 votes vote down vote up
package com.eharmony.aloha.dataset.vw.labeled

import com.eharmony.aloha.dataset.SparseFeatureExtractorFunction
import com.eharmony.aloha.semantics.func.GenFunc.f0
import org.junit.Assert._
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.BlockJUnit4ClassRunner

import scala.language.{postfixOps, implicitConversions}

@RunWith(classOf[BlockJUnit4ClassRunner])
final class VwLabelRowCreatorTest {

    private[this] val lab = 3d
    private[this] val imp0 = 0d
    private[this] val imp1 = 1d
    private[this] val imp2 = 2d
    private[this] val emptyTag = ""
    private[this] val tag = "t"

    private[this] implicit def liftToOption[A](a: A): Option[A] = Option(a)

    private[this] def spec(lab: Option[Double] = None, imp: Option[Double] = None, tag: Option[String] = None): VwLabelRowCreator[Any] = {
        val fef = new SparseFeatureExtractorFunction[Any](Vector("f1" -> f0("Empty", _ => Nil)))
        VwLabelRowCreator(fef, 0 to 0 toList, Nil, None, f0("", _ => lab), f0("", _ => imp), f0("", _ => tag))
    }

    private[this] def testLabelRemoval(spec: VwLabelRowCreator[Any], exp: String = ""): Unit = assertEquals(exp, spec(())._2.toString)

    // All of these should return empty label because the Label function returns a missing label.
    @Test def testS___() = testLabelRemoval(spec())
    @Test def testS__e() = testLabelRemoval(spec(tag = emptyTag))
    @Test def testS__t() = testLabelRemoval(spec(tag = tag))
    @Test def testS_0_() = testLabelRemoval(spec(imp = imp0))
    @Test def testS_0e() = testLabelRemoval(spec(imp = imp0, tag = emptyTag))
    @Test def testS_0t() = testLabelRemoval(spec(imp = imp0, tag = tag))
    @Test def testS_1_() = testLabelRemoval(spec(imp = imp1))
    @Test def testS_1e() = testLabelRemoval(spec(imp = imp1, tag = emptyTag))
    @Test def testS_1t() = testLabelRemoval(spec(imp = imp1, tag = tag))
    @Test def testS_2_() = testLabelRemoval(spec(imp = imp2))
    @Test def testS_2e() = testLabelRemoval(spec(imp = imp2, tag = emptyTag))
    @Test def testS_2t() = testLabelRemoval(spec(imp = imp2, tag = tag))

    // Importance not provided makes entire label vanish
    @Test def testS1_e() = testLabelRemoval(spec(lab = lab, tag = emptyTag))
    @Test def testS1_t() = testLabelRemoval(spec(lab = lab, tag = tag))

    // Importance of zero is given explicitly.
    @Test def testS10_() = testLabelRemoval(spec(lab = lab, imp = imp0), "3 0 |")
    @Test def testS10e() = testLabelRemoval(spec(lab = lab, imp = imp0, tag = emptyTag), "3 0 |")
    @Test def testS10t() = testLabelRemoval(spec(lab = lab, imp = imp0, tag = tag), "3 0 t|")

    // Importance of 1 is omitted.
    @Test def testS11_() = testLabelRemoval(spec(lab = lab, imp = imp1), "3 |")
    @Test def testS11e() = testLabelRemoval(spec(lab = lab, imp = imp1, tag = emptyTag), "3 |")
    @Test def testS11t() = testLabelRemoval(spec(lab = lab, imp = imp1, tag = tag), "3 t|")

    @Test def testS12_() = testLabelRemoval(spec(lab = lab, imp = imp2), "3 2 |")
    @Test def testS12e() = testLabelRemoval(spec(lab = lab, imp = imp2, tag = emptyTag), "3 2 |")
    @Test def testS12t() = testLabelRemoval(spec(lab = lab, imp = imp2, tag = tag), "3 2 t|")


    @Test def testStringLabel() {
        val spec = new VwLabelRowCreator(
            new SparseFeatureExtractorFunction(Vector("f1" -> f0("Empty", (_: Double) => Nil))),
            0 to 0 toList,
            Nil,
            None,
            f0("", (s: Double) => Option(s)),  // Label
            f0("", (_: Double) => Option(1d)), // Importance
            f0("", (_: Double) => None))       // Tag

        val values = Seq(
            -1.0                 -> "-1",
            -0.99999999999999999 -> "-1",
            -0.9999999999999999  -> "-0.9999999999999999",
            -1.0E-16             -> "-0.0000000000000001",
            -1.0E-17             -> "-0.00000000000000001",
            -1.0E-18             -> "-0",
             0.0                 ->  "0",
             1.0E-18             ->  "0",
             1.0E-17             ->  "0.00000000000000001",
             1.0E-16             ->  "0.0000000000000001",
             0.9999999999999999  ->  "0.9999999999999999",
             0.99999999999999999 ->  "1",
             1.0                 ->  "1"
        )

        values foreach { case(v, ex) => assertEquals(s"for line: $v", Option(ex), spec.stringLabel(v)) }
    }
} 
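The expectations above encode the VW label prefix rules exercised by this row creator: a missing label or a missing importance suppresses the label entirely, an importance of exactly 1 is left implicit, any other importance is written after the label, and the (possibly empty) tag sits immediately before the bar. A minimal sketch of that prefix logic, using integral labels for simplicity (the stringLabel test above shows the fuller number formatting):

object VwLabelPrefixSketch extends App {
  // Builds the text before the namespaces, e.g. "3 2 t|"; returns "" when no label applies.
  def vwLabelPrefix(label: Option[Int], importance: Option[Double], tag: Option[String]): String =
    (label, importance) match {
      case (Some(l), Some(imp)) =>
        val impStr = if (imp == 1d) "" else s"${imp.toInt} " // importance 1 is implied
        s"$l $impStr${tag.getOrElse("")}|"
      case _ => "" // no label, or label without importance => no label at all
    }

  println(vwLabelPrefix(Some(3), Some(2d), Some("t"))) // "3 2 t|"
  println(vwLabelPrefix(Some(3), Some(1d), None))      // "3 |"
  println(vwLabelPrefix(Some(3), None, Some("t")))     // ""
}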
Example 166
Source File: DateMapToUnitCircleVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.SequenceModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.OpVectorMetadata
import org.apache.spark.ml.{Estimator, Transformer}
import org.apache.spark.ml.linalg.Vectors
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichMetadata._
import org.joda.time.{DateTime => JDateTime}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DateMapToUnitCircleVectorizerTest extends OpEstimatorSpec[OPVector, SequenceModel[DateMap, OPVector],
  DateMapToUnitCircleVectorizer[DateMap]] with AttributeAsserts {

  val eps = 1E-4
  val sampleDateTimes = Seq[JDateTime](
    new JDateTime(2018, 2, 11, 0, 0, 0, 0),
    new JDateTime(2018, 11, 28, 6, 0, 0, 0),
    new JDateTime(2018, 2, 17, 12, 0, 0, 0),
    new JDateTime(2017, 4, 17, 18, 0, 0, 0),
    new JDateTime(1918, 2, 13, 3, 0, 0, 0)
  )

  val (inputData, f1) = TestFeatureBuilder(
    sampleDateTimes.map(x => Map("a" -> x.getMillis, "b" -> x.getMillis).toDateMap)
  )

  
  override val expectedResult: Seq[OPVector] = sampleDateTimes
    .map{ v =>
      val rad = DateToUnitCircle.convertToRandians(Option(v.getMillis), TimePeriod.HourOfDay)
      (rad ++ rad).toOPVector
    }

  it should "work with its shortcut as a DateMap" in {
    val output = f1.toUnitCircle(TimePeriod.HourOfDay)
    val transformed = output.originStage.asInstanceOf[DateMapToUnitCircleVectorizer[DateMap]]
      .fit(inputData).transform(inputData)
    val field = transformed.schema(output.name)
    val actual = transformed.collect(output)
    assertNominal(field, Array.fill(actual.head.value.size)(false), actual)
    all (actual.zip(expectedResult).map(g => Vectors.sqdist(g._1.value, g._2.value))) should be < eps
  }

  it should "work with its shortcut as a DateTimeMap" in {
    val (inputDataDT, f1DT) = TestFeatureBuilder(
      sampleDateTimes.map(x => Map("a" -> x.getMillis, "b" -> x.getMillis).toDateTimeMap)
    )
    val output = f1DT.toUnitCircle(TimePeriod.HourOfDay)
    val transformed = output.originStage.asInstanceOf[DateMapToUnitCircleVectorizer[DateMap]]
      .fit(inputData).transform(inputData)
    val field = transformed.schema(output.name)
    val actual = transformed.collect(output)
    assertNominal(field, Array.fill(actual.head.value.size)(false), actual)
    all (actual.zip(expectedResult).map(g => Vectors.sqdist(g._1.value, g._2.value))) should be < eps
  }

  it should "make the correct metadata" in {
    val fitted = estimator.fit(inputData)
    val meta = OpVectorMetadata(fitted.getOutputFeatureName, fitted.getMetadata())
    meta.columns.length shouldBe 4
    meta.columns.flatMap(_.grouping) shouldEqual Seq("a", "a", "b", "b")
    meta.columns.flatMap(_.descriptorValue) shouldEqual Seq("x_HourOfDay", "y_HourOfDay", "x_HourOfDay", "y_HourOfDay")
  }

} 
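The expected vectors are built by projecting the hour of day onto the unit circle: an hour h becomes the angle 2*Pi*h/24 and is represented as (cos, sin), so midnight maps to (1, 0), 06:00 to roughly (0, 1) and 18:00 to roughly (0, -1); the map vectorizer then concatenates one such pair per key. A minimal sketch of the hour-of-day conversion (the library's convertToRandians also covers the other TimePeriod values):

object HourOfDayCircleSketch extends App {
  // Hour of day -> (cos, sin) coordinates on the unit circle.
  def hourToUnitCircle(hourOfDay: Double): Array[Double] = {
    val theta = 2 * math.Pi * hourOfDay / 24.0
    Array(math.cos(theta), math.sin(theta))
  }

  println(hourToUnitCircle(0).mkString(", "))  // 1.0, 0.0    (the 2018-02-11 00:00 sample)
  println(hourToUnitCircle(6).mkString(", "))  // ~0.0, 1.0   (the 2018-11-28 06:00 sample)
  println(hourToUnitCircle(18).mkString(", ")) // ~0.0, -1.0  (the 2017-04-17 18:00 sample)
}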
Example 167
Source File: OpIndexToStringNoFilterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpIndexToStringNoFilterTest extends OpTransformerSpec[Text, OpIndexToStringNoFilter] {
  val (inputData, indF) = TestFeatureBuilder(Seq(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN))
  val labels = Array("a", "c")

  override val transformer: OpIndexToStringNoFilter = new OpIndexToStringNoFilter().setInput(indF).setLabels(labels)

  override val expectedResult: Seq[Text] =
    Array("a", OpIndexToStringNoFilter.unseenDefault, "c", "a", "a", "c").map(_.toText)

  it should "correctly deindex a numeric column using shortcut" in {
    val str2 = indF.deindexed(labels, handleInvalid = IndexToStringHandleInvalid.NoFilter)
    val strs2 = str2.originStage.asInstanceOf[OpIndexToStringNoFilter].transform(inputData).collect(str2)
    strs2 shouldBe expectedResult
  }
} 
Example 168
Source File: SetNGramSimilarityTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.Transformer
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class SetNGramSimilarityTest extends OpTransformerSpec[RealNN, SetNGramSimilarity] {

  val (inputData, f1, f2) = TestFeatureBuilder(
    Seq(
      (Seq("Red", "Green"), Seq("Red")),
      (Seq("Red", "Green"), Seq("Yellow, Blue")),
      (Seq("Red", "Yellow"), Seq("Red", "Yellow")),
      (Seq[String](), Seq("Red", "Yellow")),
      (Seq[String](), Seq[String]()),
      (Seq[String](""), Seq[String]("asdf")),
      (Seq[String](""), Seq[String]("")),
      (Seq[String]("", ""), Seq[String]("", ""))
    ).map(v => v._1.toMultiPickList -> v._2.toMultiPickList)
  )

  val expectedResult = Seq(0.3333333134651184, 0.09722214937210083, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0).toRealNN
  val catNGramSimilarity = f1.toNGramSimilarity(f2)
  val transformer = catNGramSimilarity.originStage.asInstanceOf[SetNGramSimilarity]

  it should "correctly compute char-n-gram similarity with nondefault ngram param" in {
    val cat5GramSimilarity = f1.toNGramSimilarity(f2, 5)
    val transformedDs = cat5GramSimilarity.originStage.asInstanceOf[Transformer].transform(inputData)
    val actualOutput = transformedDs.collect(cat5GramSimilarity)

    actualOutput shouldBe Seq(0.3333333432674408, 0.12361115217208862, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0).toRealNN
  }
} 
Example 169
Source File: RoundTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class RoundTransformerTest extends OpTransformerSpec[Integral, RoundTransformer[Real]] {
  val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4))
  val (inputData, f1) = TestFeatureBuilder(sample)
  val transformer: RoundTransformer[Real] = new RoundTransformer[Real]().setInput(f1)
  val expectedResult: Seq[Integral] = Seq(Integral(-1), Integral(-5), Integral.empty, Integral(5),
    Integral(-5), Integral(0), Integral(3), Integral(0))

  it should "have a working shortcut" in {
    val f2 = f1.round()
    f2.originStage.isInstanceOf[RoundTransformer[_]] shouldBe true
  }
} 
Example 170
Source File: OpStringIndexerNoFilterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryModel
import com.salesforce.op.stages.impl.feature.StringIndexerHandleInvalid.Skip
import com.salesforce.op.stages.sparkwrappers.generic.SwUnaryModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.StringIndexerModel
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpStringIndexerNoFilterTest
  extends OpEstimatorSpec[RealNN, UnaryModel[Text, RealNN], OpStringIndexerNoFilter[Text]] {

  val txtData = Seq("a", "b", "c", "a", "a", "c").map(_.toText)
  val (inputData, txtF) = TestFeatureBuilder(txtData)
  override val expectedResult: Seq[RealNN] = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0).map(_.toRealNN)

  override val estimator: OpStringIndexerNoFilter[Text] = new OpStringIndexerNoFilter[Text]().setInput(txtF)

  val txtDataNew = Seq("a", "b", "c", "a", "a", "c", "d", "e").map(_.toText)
  val (dsNew, txtFNew) = TestFeatureBuilder(txtDataNew)
  val expectedNew = Array(0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 3.0, 3.0).map(_.toRealNN)

  it should "correctly index a text column (shortcut)" in {
    val indexed = txtF.indexed()
    val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]]
      .fit(inputData).transform(inputData).collect(indexed)
    indices shouldBe expectedResult

    val indexed2 = txtF.indexed(handleInvalid = Skip)
    val indicesfit = indexed2.originStage.asInstanceOf[OpStringIndexer[_]].fit(inputData)
    val indices2 = indicesfit.transform(inputData).collect(indexed2)
    val indices3 = indicesfit.asInstanceOf[SwUnaryModel[Text, RealNN, StringIndexerModel]]
      .setInput(txtFNew).transform(dsNew).collect(indexed2)
    indices2 shouldBe expectedResult
    indices3 shouldBe expectedResult
  }

  it should "correctly deinxed a numeric column" in {
    val indexed = txtF.indexed()
    val indices = indexed.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(inputData).transform(inputData)
    val deindexed = indexed.deindexed()
    val deindexedData = deindexed.originStage.asInstanceOf[OpIndexToStringNoFilter]
      .transform(indices).collect(deindexed)
    deindexedData shouldBe txtData
  }

  it should "assign new strings to the unseen string category" in {
    val indices = estimator.fit(inputData).setInput(txtFNew).transform(dsNew).collect(estimator.getOutput())
    indices shouldBe expectedNew
  }
} 
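The expected indices come from the usual frequency-based ordering of string indexers: at fit time "a" occurs three times and gets index 0.0, "c" twice for 1.0, "b" once for 2.0, and anything unseen at fit time ("d", "e") falls into a trailing bucket, 3.0 here. A rough sketch of that assignment (the real indexer breaks frequency ties deterministically; this sketch does not):

object FrequencyIndexSketch extends App {
  val fitData = Seq("a", "b", "c", "a", "a", "c")

  // Labels ordered by descending frequency: Seq("a", "c", "b").
  val labels = fitData.groupBy(identity).toSeq
    .sortBy { case (_, occurrences) => -occurrences.size }
    .map(_._1)

  // Unseen strings map to labels.size (the "unseen" bucket).
  val index = labels.zipWithIndex.toMap.withDefaultValue(labels.size)

  println(Seq("a", "b", "c", "d").map(s => index(s).toDouble)) // List(0.0, 2.0, 1.0, 3.0)
}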
Example 171
Source File: IDFTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.feature.IDF
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.{Estimator, Transformer}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class IDFTest extends FlatSpec with TestSparkContext {

  val data = Seq(
    Vectors.sparse(4, Array(1, 3), Array(1.0, 2.0)),
    Vectors.dense(0.0, 1.0, 2.0, 3.0),
    Vectors.sparse(4, Array(1), Array(1.0))
  )

  lazy val (ds, f1) = TestFeatureBuilder(data.map(_.toOPVector))

  Spec[IDF] should "compute inverted document frequency" in {
    val idf = f1.idf()
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)

    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      math.log((data.length + 1.0) / (x + 1.0))
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  it should "compute inverted document frequency when minDocFreq is 1" in {
    val idf = f1.idf(minDocFreq = 1)
    val model = idf.originStage.asInstanceOf[Estimator[_]].fit(ds)
    val transformedData = model.asInstanceOf[Transformer].transform(ds)
    val results = transformedData.select(idf.name).collect(idf)
    idf.name shouldBe idf.originStage.getOutputFeatureName

    val expectedIdf = Vectors.dense(Array(0, 3, 1, 2).map { x =>
      if (x > 0) math.log((data.length + 1.0) / (x + 1.0)) else 0
    })
    val expected = scaleDataWithIDF(data, expectedIdf)

    for {
      (res, exp) <- results.zip(expected)
      (x, y) <- res.value.toArray.zip(exp.toArray)
    } assert(math.abs(x - y) <= 1e-5)
  }

  private def scaleDataWithIDF(dataSet: Seq[Vector], model: Vector): Seq[Vector] = {
    dataSet.map {
      case data: DenseVector =>
        val res = data.toArray.zip(model.toArray).map { case (x, y) => x * y }
        Vectors.dense(res)
      case data: SparseVector =>
        val res = data.indices.zip(data.values).map { case (id, value) =>
          (id, value * model(id))
        }
        Vectors.sparse(data.size, res)
    }
  }

} 
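Both tests recompute the smoothed inverse document frequency that the code itself uses as a reference: idf(t) = log((N + 1) / (df(t) + 1)) with N = 3 documents and per-dimension document frequencies of 0, 3, 1 and 2 (the minDocFreq variant zeroes out dimensions below the threshold). A quick worked check of those values:

object IdfSketch extends App {
  val numDocs = 3.0
  val docFreq = Seq(0, 3, 1, 2)

  // Smoothed IDF: log((N + 1) / (df + 1)).
  val idf = docFreq.map(df => math.log((numDocs + 1.0) / (df + 1.0)))

  println(idf.map(v => f"$v%.3f")) // List(1.386, 0.000, 0.693, 0.288)
}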
Example 172
Source File: OpCountVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType.{IndCol, IndVal}
import com.salesforce.op.test.{TestFeatureBuilder, TestOpVectorMetadataBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner
import org.scalatest.{Assertions, FlatSpec, Matchers}


@RunWith(classOf[JUnitRunner])
class OpCountVectorizerTest extends FlatSpec with TestSparkContext {

  val data = Seq[(Real, TextList)](
    (Real(0), Seq("a", "b", "c").toTextList),
    (Real(1), Seq("a", "b", "b", "b", "a", "c").toTextList)
  )

  lazy val (ds, f1, f2) = TestFeatureBuilder(data)

  lazy val expected = Array[(Real, OPVector)](
    (Real(0), Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 1.0, 1.0)).toOPVector),
    (Real(1), Vectors.sparse(3, Array(0, 1, 2), Array(3.0, 2.0, 1.0)).toOPVector)
  )

  val f2vec = new OpCountVectorizer().setInput(f2).setVocabSize(3).setMinDF(2)

  Spec[OpCountVectorizer] should "convert array of strings into count vector" in {
    val transformedData = f2vec.fit(ds).transform(ds)
    val output = f2vec.getOutput()
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }

  it should "return the a fitted vectorizer with the correct parameters" in {
    val fitted = f2vec.fit(ds)
    val vectorMetadata = fitted.getMetadata()
    val expectedMeta = TestOpVectorMetadataBuilder(
      f2vec,
      f2 -> List(IndVal(Some("b")), IndVal(Some("a")), IndVal(Some("c")))
    )
    // cannot just do equals because fitting is nondeterministic
    OpVectorMetadata(f2vec.getOutputFeatureName, vectorMetadata).columns should contain theSameElementsAs
      expectedMeta.columns
  }

  it should "convert array of strings into count vector (shortcut version)" in {
    val output = f2.countVec(minDF = 2, vocabSize = 3)
    val f2vec = output.originStage.asInstanceOf[OpCountVectorizer]
    val transformedData = f2vec.fit(ds).transform(ds)
    transformedData.orderBy(f1.name).collect(f1, output) should contain theSameElementsInOrderAs expected
  }
} 
Example 173
Source File: OpTextPivotVectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.sequence.SequenceModel
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpTextPivotVectorizerTest
  extends OpEstimatorSpec[OPVector, SequenceModel[Text, OPVector], OpTextPivotVectorizer[Text]] {

  lazy val (inputData, f1, f2) = TestFeatureBuilder("text1", "text2",
    Seq[(Text, Text)](
      ("hello world".toText, "Hello world!".toText),
      ("hello world".toText, "What's up".toText),
      ("good evening".toText, "How are you doing, my friend?".toText),
      ("hello world".toText, "Not bad, my friend.".toText),
      (Text.empty, Text.empty)
    )
  )

  
  override val expectedResult: Seq[OPVector] = Seq(
    Vectors.sparse(8, Array(0, 4), Array(1.0, 1.0)),
    Vectors.sparse(8, Array(0, 6), Array(1.0, 1.0)),
    Vectors.sparse(8, Array(1, 5), Array(1.0, 1.0)),
    Vectors.sparse(8, Array(0, 6), Array(1.0, 1.0)),
    Vectors.sparse(8, Array(3, 7), Array(1.0, 1.0))
  ).map(_.toOPVector)
} 
Example 174
Source File: ScalerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class ScalerTest extends FlatSpec with TestSparkContext {

  Spec[Scaler] should "error on invalid data" in {
    val error = intercept[IllegalArgumentException](
      Scaler.apply(scalingType = ScalingType.Linear, args = EmptyScalerArgs())
    )
    error.getMessage shouldBe
      s"Invalid combination of scaling type '${ScalingType.Linear}' " +
        s"and args type '${EmptyScalerArgs().getClass.getSimpleName}'"
  }

  it should "correctly build construct a LinearScaler" in {
    val linearScaler = Scaler.apply(scalingType = ScalingType.Linear,
      args = LinearScalerArgs(slope = 1.0, intercept = 2.0))
    linearScaler shouldBe a[LinearScaler]
    linearScaler.scalingType shouldBe ScalingType.Linear
  }

  it should "correctly build construct a LogScaler" in {
    val linearScaler = Scaler.apply(scalingType = ScalingType.Logarithmic, args = EmptyScalerArgs())
    linearScaler shouldBe a[LogScaler]
    linearScaler.scalingType shouldBe ScalingType.Logarithmic
  }
} 
Example 175
Source File: RoundDigitsTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class RoundDigitsTransformerTest extends OpTransformerSpec[Real, RoundDigitsTransformer[Real]] {
  val sample = Seq(Real(1.4231092), Real(4.3231), Real.empty, Real(-1.0), Real(2.03728181))
  val (inputData, f1) = TestFeatureBuilder(sample)
  val transformer: RoundDigitsTransformer[Real] = new RoundDigitsTransformer[Real](2)
    .setInput(f1)
  val expectedResult: Seq[Real] = Seq(Real(1.42), Real(4.32), Real.empty, Real(-1.0), Real(2.04))

  it should "have a working shortcut" in {
    val f2 = f1.round(4)
    f2.originStage.isInstanceOf[RoundDigitsTransformer[_]] shouldBe true
  }
} 
Example 176
Source File: LangDetectorTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.text.Language
import org.apache.spark.ml.Transformer
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class LangDetectorTest extends OpTransformerSpec[RealMap, LangDetector[Text]] {

  // scalastyle:off
  val (inputData, f1, f2, f3) = TestFeatureBuilder(
    Seq(
      (
        "I've got a lovely bunch of coconuts".toText,
        "文化庁によりますと、世界文化遺産への登録を目指している、福岡県の「宗像・沖ノ島と関連遺産群」について、ユネスコの諮問機関は、8つの構成資産のうち、沖ノ島など4つについて、「世界遺産に登録することがふさわしい」とする勧告をまとめました。".toText,
        "Première détection d’une atmosphère autour d’une exoplanète de la taille de la Terre".toText
      ),
      (
        "There they are, all standing in a row".toText,
        "地磁気発生の謎に迫る地球内部の環境、再現実験".toText,
        "Les deux commissions, créées respectivement en juin 2016 et janvier 2017".toText
      ),
      (
        "Big ones, small ones, some as big as your head".toText,
        "大学レスリング界で「黒船」と呼ばれたカザフスタン出身の大型レスラーが、日本の男子グレコローマンスタイルの重量級強化のために一役買っている。山梨学院大をこの春卒業したオレッグ・ボルチン(24)。4月から新日本プロレスの親会社ブシロードに就職。自身も日本を拠点に、アマチュアレスリングで2020年東京五輪を目指す。".toText,
        "Il publie sa théorie de la relativité restreinte en 1905".toText
      )
    )
  )
  // scalastyle:on
  val transformer = new LangDetector[Text]().setInput(f1)

  private val langMap = f1.detectLanguages()

  // English result
  val expectedResult: Seq[RealMap] = Seq(
    Map("en" -> 0.9999984360934321),
    Map("en" -> 0.9999900853228016),
    Map("en" -> 0.9999900116744931)
  ).map(_.toRealMap)

  it should "return empty RealMap when input text is empty" in {
    transformer.transformFn(Text.empty) shouldBe RealMap.empty
  }

  it should "detect Japanese language" in {
    assertDetectionResults(
      results = transformer.setInput(f2).transform(inputData).collect(transformer.getOutput()),
      expectedLanguage = Language.Japanese
    )
  }

  it should "detect French language" in {
    assertDetectionResults(
      results = transformer.setInput(f3).transform(inputData).collect(transformer.getOutput()),
      expectedLanguage = Language.French
    )
  }

  it should "has a working shortcut" in {
    val tokenized = f1.detectLanguages()

    assertDetectionResults(
      results = tokenized.originStage.asInstanceOf[Transformer].transform(inputData).collect(tokenized),
      expectedLanguage = Language.English
    )
  }

  private def assertDetectionResults
  (
    results: Array[RealMap],
    expectedLanguage: Language,
    confidence: Double = 0.99
  ): Unit =
    results.foreach(res => {
      res.value.size shouldBe 1
      res.value.contains(expectedLanguage.entryName) shouldBe true
      res.value(expectedLanguage.entryName) should be >= confidence
    })

} 
Example 177
Source File: TextLenTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder, TestSparkContext}
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class TextLenTransformerTest extends OpTransformerSpec[OPVector, TextLenTransformer[_]]
  with TestSparkContext with AttributeAsserts {

  val (ds, f1, f2) = TestFeatureBuilder(
    Seq[(TextList, TextList)](
      (TextList(Seq("A", "giraffe", "drinks", "by", "the", "watering", "hole")),
        TextList(Seq("A giraffe drinks by the watering hole"))),
      (TextList(Seq("A giraffe drinks by the watering hole")), TextList(Seq("Cheese"))),
      (TextList(Seq("Cheese", "cake")), TextList(Seq("A giraffe drinks by the watering hole"))),
      (TextList(Seq("Cheese")), TextList(Seq("Cheese"))),
      (TextList.empty, TextList(Seq("A giraffe drinks by the watering hole"))),
      (TextList.empty, TextList(Seq("Cheese", "tart"))),
      (TextList(Seq("A giraffe drinks by the watering hole")), TextList.empty),
      (TextList(Seq("Cheese")), TextList.empty),
      (TextList.empty, TextList.empty)
    )
  )

  // Variables for OpTransformer base tests
  val inputData = ds

  val transformer = new TextLenTransformer().setInput(f1, f2)

  val expectedResult = Seq(
    Array(31.0, 37.0),
    Array(37.0, 6.0),
    Array(10.0, 37.0),
    Array(6.0, 6.0),
    Array(0.0, 37.0),
    Array(0.0, 10.0),
    Array(37.0, 0.0),
    Array(6.0, 0.0),
    Array(0.0, 0.0)
  ).map(Vectors.dense(_).toOPVector)

  Spec[TextLenTransformer[_]] should "take an array of features as input and return a single vector feature" in {
    val vector = transformer.getOutput()

    vector.name shouldBe transformer.getOutputFeatureName
    vector.typeName shouldBe FeatureType.typeName[OPVector]
    vector.isResponse shouldBe false
  }

  it should "transform the data correctly" in {
    val transformed = transformer.transform(ds)
    val vector = transformer.getOutput()

    val result = transformed.collect(vector)
    result should contain theSameElementsAs expectedResult
  }
} 
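Each expected vector holds one value per input feature: the total number of characters across the strings in that row's TextList, with an empty list contributing 0. For instance the tokens "A", "giraffe", "drinks", "by", "the", "watering", "hole" sum to 1+7+6+2+3+8+4 = 31 characters, while the single string "A giraffe drinks by the watering hole" contributes 37. A minimal sketch of that aggregation (not the transformer itself):

object TextLenSketch extends App {
  // Total character count of a TextList-like sequence; empty lists contribute 0.0.
  def textListLen(texts: Seq[String]): Double = texts.map(_.length).sum.toDouble

  println(textListLen(Seq("A", "giraffe", "drinks", "by", "the", "watering", "hole"))) // 31.0
  println(textListLen(Seq("A giraffe drinks by the watering hole")))                   // 37.0
  println(textListLen(Seq.empty))                                                      // 0.0
}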
Example 178
Source File: VectorsCombinerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.TransientFeature
import com.salesforce.op.features.types.{Text, _}
import com.salesforce.op.stages.base.sequence.SequenceModel
import com.salesforce.op.test.{OpEstimatorSpec, PassengerSparkFixtureTest, TestFeatureBuilder}
import com.salesforce.op.utils.spark.OpVectorMetadata
import com.salesforce.op.utils.spark.RichMetadata._
import org.apache.spark.ml.attribute.MetadataHelper
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.types.Metadata
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class VectorsCombinerTest
  extends OpEstimatorSpec[OPVector, SequenceModel[OPVector, OPVector], VectorsCombiner]
    with PassengerSparkFixtureTest {

  override def specName: String = classOf[VectorsCombiner].getSimpleName

  val (inputData, f1, f2) = TestFeatureBuilder(Seq(
    Vectors.sparse(4, Array(0, 3), Array(1.0, 1.0)).toOPVector ->
      Vectors.sparse(4, Array(0, 3), Array(2.0, 3.0)).toOPVector,
    Vectors.dense(Array(2.0, 3.0, 4.0)).toOPVector ->
      Vectors.dense(Array(12.0, 13.0, 14.0)).toOPVector,
    // Purposely added some very large sparse vectors to verify the efficiency
    Vectors.sparse(100000000, Array(1), Array(777.0)).toOPVector ->
      Vectors.sparse(500000000, Array(0), Array(888.0)).toOPVector
  ))

  val estimator = new VectorsCombiner().setInput(f1, f2)

  val expectedResult = Seq(
    Vectors.sparse(8, Array(0, 3, 4, 7), Array(1.0, 1.0, 2.0, 3.0)).toOPVector,
    Vectors.dense(Array(2.0, 3.0, 4.0, 12.0, 13.0, 14.0)).toOPVector,
    Vectors.sparse(600000000, Array(1, 100000000), Array(777.0, 888.0)).toOPVector
  )

  it should "combine metadata correctly" in {
    val vector = Seq(height, description, stringMap).transmogrify()
    val inputs = vector.parents
    val outputData = new OpWorkflow().setReader(dataReader)
      .setResultFeatures(vector, inputs(0), inputs(1), inputs(2))
      .train().score()
    val inputMetadata = OpVectorMetadata.flatten(vector.name,
      inputs.map(i => OpVectorMetadata(outputData.schema(i.name))))
    OpVectorMetadata(outputData.schema(vector.name)).columns should contain theSameElementsAs inputMetadata.columns
  }

  it should "create metadata correctly" in {
    val descVect = description.map[Text] { t =>
      Text(t.value match {
        case Some(text) => "this is dumb " + text
        case None => "some STUFF to tokenize"
      })
    }.tokenize().tf(numTerms = 5)
    val vector = Seq(height, stringMap, descVect).transmogrify()
    val Seq(inputs1, inputs2, inputs3) = vector.parents

    val outputData = new OpWorkflow().setReader(dataReader)
      .setResultFeatures(vector, inputs1, inputs2, inputs3)
      .train().score()
    outputData.schema(inputs1.name).metadata.wrapped
      .get[Metadata](MetadataHelper.attributeKeys.ML_ATTR)
      .getLong(MetadataHelper.attributeKeys.NUM_ATTRIBUTES) shouldBe 5

    val inputMetadata = OpVectorMetadata.flatten(vector.name,
      Array(TransientFeature(inputs1).toVectorMetaData(5, Option(inputs1.name)),
        OpVectorMetadata(outputData.schema(inputs2.name)), OpVectorMetadata(outputData.schema(inputs3.name))))
    OpVectorMetadata(outputData.schema(vector.name)).columns should contain theSameElementsAs inputMetadata.columns
  }
} 
Example 179
Source File: OpStopWordsRemoverTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op._
import com.salesforce.op.features.types._
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.test.{SwTransformerSpec, TestFeatureBuilder}
import org.apache.spark.ml.feature.StopWordsRemover
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpStopWordsRemoverTest extends SwTransformerSpec[TextList, StopWordsRemover, OpStopWordsRemover] {
  val data = Seq(
    "I AM groot", "Groot call me human", "or I will crush you"
  ).map(_.split(" ").toSeq.toTextList)

  val (inputData, textListFeature) = TestFeatureBuilder(data)

  val bigrams = textListFeature.removeStopWords()
  val transformer = bigrams.originStage.asInstanceOf[OpStopWordsRemover]

  val expectedResult = Seq(Seq("groot"), Seq("Groot", "call", "human"), Seq("crush")).map(_.toTextList)

  it should "allow case sensitivity" in {
    val noStopWords = textListFeature.removeStopWords(caseSensitive = true)
    val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData)
    res.collect(noStopWords) shouldBe Seq(
      Seq("I", "AM", "groot"), Seq("Groot", "call", "human"), Seq("I", "crush")).map(_.toTextList)
  }

  it should "set custom stop words" in {
    val noStopWords = textListFeature.removeStopWords(stopWords = Array("Groot", "I"))
    val res = noStopWords.originStage.asInstanceOf[OpStopWordsRemover].transform(inputData)
    res.collect(noStopWords) shouldBe Seq(
      Seq("AM"), Seq("call", "me", "human"), Seq("or", "will", "crush", "you")).map(_.toTextList)
  }
} 
Example 180
Source File: TransmogrifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features._
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestOpVectorColumnType._
import com.salesforce.op.test.{PassengerSparkFixtureTest, TestOpVectorMetadataBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.spark.RichStructType._
import com.salesforce.op._
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TransmogrifierTest extends FlatSpec with PassengerSparkFixtureTest with AttributeAsserts {

  val inputFeatures = Array[OPFeature](heightNoWindow, weight, gender)

  Spec(Transmogrifier.getClass) should "return a single output feature of type vector with the correct name" in {
    val feature = inputFeatures.transmogrify()
    feature.name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector")
  }

  it should "return a model when fitted" in {
    val feature = inputFeatures.transmogrify()
    val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train()

    model.getResultFeatures() should contain theSameElementsAs Array(feature)
    val name = model.getResultFeatures().map(_.name).head
    name.contains("gender-heightNoWindow-weight_3-stagesApplied_OPVector")
  }

  it should "correctly transform the data and store the feature names in metadata" in {
    val feature = inputFeatures.toSeq.transmogrify()
    val model = new OpWorkflow().setResultFeatures(feature).setReader(dataReader).train()
    val transformed = model.score(keepRawFeatures = true, keepIntermediateFeatures = true)
    val hist = feature.parents.flatMap { f =>
      val h = f.history()
      h.originFeatures.map(o => o -> FeatureHistory(Seq(o), h.stages))
    }.toMap
    transformed.schema.toOpVectorMetadata(feature.name) shouldEqual
      TestOpVectorMetadataBuilder.withOpNamesAndHist(
        feature.originStage,
        hist,
        (gender, "vecSet", List(IndCol(Some("OTHER")), IndCol(Some(TransmogrifierDefaults.NullString)))),
        (heightNoWindow, "vecReal", List(RootCol,
          IndColWithGroup(Some(TransmogrifierDefaults.NullString), heightNoWindow.name))),
        (weight, "vecReal", List(RootCol, IndColWithGroup(Some(TransmogrifierDefaults.NullString), weight.name)))
      )

    transformed.schema.findFields("heightNoWindow-weight_1-stagesApplied_OPVector").nonEmpty shouldBe true

    val collected = transformed.collect(feature)

    collected.head.v.size shouldEqual 6
    collected.map(_.v.toArray.toList).toSet shouldEqual
      Set(
        List(0.0, 1.0, 211.4, 1.0, 96.0, 1.0),
        List(1.0, 0.0, 172.0, 0.0, 78.0, 0.0),
        List(1.0, 0.0, 168.0, 0.0, 67.0, 0.0),
        List(1.0, 0.0, 363.0, 0.0, 172.0, 0.0),
        List(1.0, 0.0, 186.0, 0.0, 96.0, 0.0)
      )
    val field = transformed.schema(feature.name)
    assertNominal(field, Array(false, true, false, true, false, true), collected)
  }

} 
Example 181
Source File: OPMapTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.UID
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryTransformer
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OPMapTransformerTest
  extends OpTransformerSpec[IntegralMap, OPMapTransformer[Email, Integral, EmailMap, IntegralMap]] {

  lazy val (inputData, top) = TestFeatureBuilder("name", Seq(
    Map("p1" -> "[email protected]", "p2" -> "[email protected]").toEmailMap
  ))

  val transformer: OPMapTransformer[Email, Integral, EmailMap, IntegralMap] =
    new LengthMapTransformer().setInput(top)

  val expectedResult: Seq[IntegralMap] = Seq(
    Map("p1" -> 10L, "p2" -> 11L).toIntegralMap
  )
}

class LengthTransformer extends UnaryTransformer[Email, Integral](
  operationName = "lengthUnary",
  uid = UID[LengthTransformer]
) {
  override def transformFn: (Email => Integral) = (input: Email) => input.value.map(_.length).toIntegral
}


class LengthMapTransformer
(
  uid: String = UID[LengthMapTransformer],
  operationName: String = "lengthMap"
) extends OPMapTransformer[Email, Integral, EmailMap, IntegralMap](
  uid = uid,
  operationName = operationName,
  transformer = new LengthTransformer
) 
Example 182
Source File: TimePeriodListTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.date.DateTimeUtils
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.Transformer
import org.joda.time.{DateTime => JDateTime}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimePeriodListTransformerTest extends OpTransformerSpec[OPVector, TimePeriodListTransformer[DateList]] {

  val dateList: DateList = Seq[Long](
    new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis,
    new JDateTime(1955, 11, 12, 10, 4, DateTimeUtils.DefaultTimeZone).getMillis,
    new JDateTime(1999, 3, 8, 12, 0, DateTimeUtils.DefaultTimeZone).getMillis,
    new JDateTime(2019, 4, 30, 13, 0, DateTimeUtils.DefaultTimeZone).getMillis
  ).toDateList

  val (inputData, f1) = TestFeatureBuilder(Seq(dateList))

  override val transformer: TimePeriodListTransformer[DateList] =
    new TimePeriodListTransformer(TimePeriod.DayOfMonth).setInput(f1)

  override val expectedResult: Seq[OPVector] = Seq(Seq(14, 12, 8, 30).map(_.toDouble).toVector.toOPVector)

  it should "transform with rich shortcuts" in {
    val dlist = List(new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis)
    val (inputData2, d1, d2) = TestFeatureBuilder(
      Seq[(DateList, DateTimeList)]((dlist.toDateList, dlist.toDateTimeList))
    )

    def assertFeature(feature: FeatureLike[OPVector], expected: Seq[OPVector]): Unit = {
      val transformed = feature.originStage.asInstanceOf[Transformer].transform(inputData2)
      val actual = transformed.collect(feature)
      actual shouldBe expected
    }

    assertFeature(d1.toTimePeriod(TimePeriod.DayOfMonth), Seq(Vector(14.0).toOPVector))
    assertFeature(d2.toTimePeriod(TimePeriod.DayOfMonth), Seq(Vector(14.0).toOPVector))
  }
} 
Example 183
Source File: FilterIntegralMapTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FilterIntegralMapTest extends OpTransformerSpec[IntegralMap, FilterMap[IntegralMap]] {

  val (inputData, f1Int) = TestFeatureBuilder[IntegralMap](
    Seq(
      IntegralMap(Map("Arthur" -> 1, "Lancelot" -> 2, "Galahad" -> 3)),
      IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)),
      IntegralMap(Map("Knight" -> 5))
    )
  )
  val transformer = new FilterMap[IntegralMap]().setInput(f1Int)

  val expectedResult: Seq[IntegralMap] = Seq(
    IntegralMap(Map("Arthur" -> 1, "Lancelot" -> 2, "Galahad" -> 3)),
    IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)),
    IntegralMap(Map("Knight" -> 5))
  )

  it should "filter by whitelisted keys" in {
    transformer.setWhiteListKeys(Array("Arthur", "Knight"))
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      IntegralMap(Map("Arthur" -> 1)),
      IntegralMap.empty,
      IntegralMap(Map("Knight" -> 5))
    )

    filtered should contain theSameElementsAs dataExpected
  }

  it should "filter by blacklisted keys" in {
    transformer.setInput(f1Int)
      .setWhiteListKeys(Array[String]())
      .setBlackListKeys(Array("Arthur", "Knight"))
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3)),
      IntegralMap(Map("Lancelot" -> 2, "Galahad" -> 3, "Bedevere" -> 4)),
      IntegralMap.empty
    )

    filtered should contain theSameElementsAs dataExpected
  }

} 
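The whitelist/blacklist behaviour asserted above reduces to filtering map keys: when a whitelist is present only those keys survive, and blacklisted keys are removed from whatever remains, possibly leaving an empty map. A minimal sketch of that key filtering (not the FilterMap stage itself):

object FilterMapKeysSketch extends App {
  def filterKeys[V](m: Map[String, V], whiteList: Set[String], blackList: Set[String]): Map[String, V] = {
    // An empty whitelist means "keep everything"; the blacklist is applied afterwards.
    val kept = if (whiteList.isEmpty) m else m.filter { case (k, _) => whiteList.contains(k) }
    kept.filterNot { case (k, _) => blackList.contains(k) }
  }

  val knights = Map("Arthur" -> 1L, "Lancelot" -> 2L, "Galahad" -> 3L)
  println(filterKeys(knights, Set("Arthur", "Knight"), Set.empty)) // Map(Arthur -> 1)
  println(filterKeys(knights, Set.empty, Set("Arthur", "Knight"))) // Map(Lancelot -> 2, Galahad -> 3)
}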
Example 184
Source File: Base64VectorizerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.OpWorkflow
import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.TestSparkContext
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class Base64VectorizerTest extends FlatSpec with TestSparkContext with Base64TestData with AttributeAsserts {

  "Base64Vectorizer" should "vectorize random binary data" in {
    val vec = randomBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, trackNulls = false)
    val result = new OpWorkflow().setResultFeatures(vec).transform(randomData)

    result.collect(vec) should contain theSameElementsInOrderAs
      OPVector(Vectors.dense(0.0, 0.0)) +:
        Array.fill(expectedRandom.length - 1)(OPVector(Vectors.dense(1.0, 0.0)))
  }
  it should "vectorize some real binary content" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true)
    assertVectorizer(vec, expectedMime)
  }
  it should "vectorize some real binary content with a type hint" in {
    val vec = realBase64.vectorize(topK = 10, minSupport = 0, cleanText = true, typeHint = Some("application/json"))
    assertVectorizer(vec, expectedMimeJson)
  }

  def assertVectorizer(vec: FeatureLike[OPVector], expected: Seq[Text]): Unit = {
    val result = new OpWorkflow().setResultFeatures(vec).transform(realData)
    val vectors = result.collect(vec)
    val schema = result.schema(vec.name)
    assertNominal(schema, Array.fill(vectors.head.value.size)(true), vectors)

    vectors.length shouldBe expected.length
    // TODO add a more robust check
  }

} 
Example 185
Source File: FilterMultiPickListMapTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class FilterMultiPickListMapTest extends OpTransformerSpec[MultiPickListMap, FilterMap[MultiPickListMap]] {
  val (inputData, f1Cat) = TestFeatureBuilder[MultiPickListMap](
    Seq(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"),
        "Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"))),
      MultiPickListMap(Map("Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"),
        "Bedevere" -> Set("Wise", "Knight"))),
      MultiPickListMap(Map("Knight" -> Set("Ni", "Ekke Ekke Ekke Ekke Ptang Zoo Boing")))
    )
  )
  val transformer = new FilterMap[MultiPickListMap]().setInput(f1Cat)

  val expectedResult = Seq(
    MultiPickListMap(Map("Arthur" -> Set("King", "Briton"),
      "Lancelot" -> Set("Brave", "Knight"),
      "Galahad" -> Set("Pure", "Knight"))),
    MultiPickListMap(Map("Lancelot" -> Set("Brave", "Knight"),
      "Galahad" -> Set("Pure", "Knight"),
      "Bedevere" -> Set("Wise", "Knight"))),
    MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
  )

  it should "filter whitelisted keys" in {
    transformer.setWhiteListKeys(Array("Arthur", "Knight"))
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
    )

    filtered should contain theSameElementsAs dataExpected
  }

  it should "filter blacklisted keys" in {
    transformer
      .setWhiteListKeys(Array[String]())
      .setBlackListKeys(Array("Arthur", "Knight"))

    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"))),
      MultiPickListMap(Map("Lancelot" -> Set("Brave", "Knight"),
        "Galahad" -> Set("Pure", "Knight"),
        "Bedevere" -> Set("Wise", "Knight"))),
      MultiPickListMap.empty
    )

    filtered should contain theSameElementsAs dataExpected
  }

  it should "not clean map when flag set to false" in {
    transformer
      .setCleanText(false)
      .setCleanKeys(false)
      .setWhiteListKeys(Array("Arthur", "Knight"))
      .setBlackListKeys(Array())
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "Ekke Ekke Ekke Ekke Ptang Zoo Boing")))
    )
    filtered should contain theSameElementsAs dataExpected
  }

  it should "clean map when flag set to true" in {
    transformer
      .setCleanKeys(true)
      .setCleanText(true)
      .setWhiteListKeys(Array("Arthur", "Knight"))
      .setBlackListKeys(Array())
    val filtered = transformer.transform(inputData).collect(transformer.getOutput())

    val dataExpected = Array(
      MultiPickListMap(Map("Arthur" -> Set("King", "Briton"))),
      MultiPickListMap.empty,
      MultiPickListMap(Map("Knight" -> Set("Ni", "EkkeEkkeEkkeEkkePtangZooBoing")))
    )
    filtered should contain theSameElementsAs dataExpected
  }

} 
Example 186
Source File: AliasTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.binary.BinaryLambdaTransformer
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.utils.tuples.RichTuple._
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class AliasTransformerTest extends OpTransformerSpec[RealNN, AliasTransformer[RealNN]] {
  val sample = Seq((RealNN(1.0), RealNN(2.0)), (RealNN(4.0), RealNN(4.0)))
  val (inputData, f1, f2) = TestFeatureBuilder(sample)
  val transformer = new AliasTransformer(name = "feature").setInput(f1)
  val expectedResult: Seq[RealNN] = sample.map(_._1)

  it should "have a shortcut that changes feature name on a raw feature" in {
    val feature = f1.alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[AliasTransformer[_]]
    val origin = feature.originStage.asInstanceOf[AliasTransformer[RealNN]]
    val transformed = origin.transform(inputData)
    transformed.collect(feature) shouldEqual expectedResult
  }
  it should "have a shortcut that changes feature name on a derived feature" in {
    val feature = (f1 / f2).alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[DivideTransformer[_, _]]
    val origin = feature.originStage.asInstanceOf[DivideTransformer[_, _]]
    val transformed = origin.transform(inputData)
    transformed.columns should contain (feature.name)
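    // element-wise division of the two optional values via RichTuple; 0.0 is assumed to be the fallback when a value is missing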
    transformed.collect(feature) shouldEqual sample.map { case (v1, v2) => (v1.v -> v2.v).map(_ / _).toRealNN(0.0) }
  }
  it should "have a shortcut that changes feature name on a derived wrapped feature" in {
    val feature = f1.toIsotonicCalibrated(label = f2).alias
    feature.name shouldBe "feature"
    feature.originStage shouldBe a[AliasTransformer[_]]
  }
} 
Example 187
Source File: PowerTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PowerTransformerTest extends OpTransformerSpec[Real, PowerTransformer[Real]] {
  val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4))
  val (inputData, f1) = TestFeatureBuilder(sample)
  val transformer: PowerTransformer[Real] = new PowerTransformer[Real](3.0).setInput(f1)
  override val expectedResult: Seq[Real] = Seq(Some(-1.3), Some(-4.9), None,
    Some(5.1), Some(-5.1), Some(0.1), Some(2.5),
    Some(0.4)).map(_.map(v => math.pow(v, 3)).toReal)

  it should "have a working shortcut" in {
    val f2 = f1.power(4)
    f2.originStage.isInstanceOf[PowerTransformer[_]] shouldBe true
  }
} 
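For context, here is a minimal sketch (not part of the original test) of materializing the power shortcut through an OpWorkflow, following the same pattern as the Base64VectorizerTest example earlier in this listing. It assumes the usual imports (com.salesforce.op.OpWorkflow and com.salesforce.op.utils.spark.RichDataset._) and that inputData and f1 were built with TestFeatureBuilder as in this spec:

// sketch only, reusing the workflow pattern shown in the examples above
val cubed = f1.power(3)
val scored = new OpWorkflow().setResultFeatures(cubed).transform(inputData)
scored.collect(cubed) // Seq[Real] with each present value raised to the third power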
Example 188
Source File: CeilTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class CeilTransformerTest extends OpTransformerSpec[Integral, CeilTransformer[Real]] {
  val sample = Seq(Real(-1.3), Real(-4.9), Real.empty, Real(5.1), Real(-5.1), Real(0.1), Real(2.5), Real(0.4))
  val (inputData, f1) = TestFeatureBuilder(sample)
  val transformer: CeilTransformer[Real] = new CeilTransformer[Real]().setInput(f1)
  override val expectedResult: Seq[Integral] = Seq(Integral(-1), Integral(-4), Integral.empty, Integral(6),
    Integral(-5), Integral(1), Integral(3), Integral(1))

  it should "have a working shortcut" in {
    val f2 = f1.ceil()
    f2.originStage.isInstanceOf[CeilTransformer[_]] shouldBe true
  }
} 
Example 189
Source File: TimePeriodTransformerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.FeatureLike
import com.salesforce.op.features.types._
import com.salesforce.op.test.{OpTransformerSpec, TestFeatureBuilder}
import com.salesforce.op.utils.date.DateTimeUtils
import com.salesforce.op.utils.spark.RichDataset._
import org.apache.spark.ml.Transformer
import org.joda.time.{DateTime => JDateTime}
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TimePeriodTransformerTest extends OpTransformerSpec[Integral, TimePeriodTransformer[Date]] {

  val (inputData, f1) = TestFeatureBuilder(Seq[Date](
    new JDateTime(1879, 3, 14, 0, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    new JDateTime(1955, 11, 12, 10, 4, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    new JDateTime(1999, 3, 8, 12, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate,
    Date.empty,
    new JDateTime(2019, 4, 30, 13, 0, DateTimeUtils.DefaultTimeZone).getMillis.toDate
  ))

  override val transformer: TimePeriodTransformer[Date] = new TimePeriodTransformer(TimePeriod.DayOfMonth).setInput(f1)

  override val expectedResult: Seq[Integral] =
    Seq(Integral(14), Integral(12), Integral(8), Integral.empty, Integral(30))

  it should "correctly transform for all TimePeriod types" in {
    def assertFeature(feature: FeatureLike[Integral], expected: Seq[Integral]): Unit = {
      val transformed = feature.originStage.asInstanceOf[Transformer].transform(inputData)
      val actual = transformed.collect(feature)
      actual shouldBe expected
    }

    TimePeriod.values.foreach(tp => {
      val expected = tp match {
        case TimePeriod.DayOfMonth => Array(Integral(14), Integral(12), Integral(8), Integral.empty, Integral(30))
        case TimePeriod.DayOfWeek => Array(Integral(5), Integral(6), Integral(1), Integral.empty, Integral(2))
        case TimePeriod.DayOfYear => Array(Integral(73), Integral(316), Integral(67), Integral.empty, Integral(120))
        case TimePeriod.HourOfDay => Array(Integral(0), Integral(10), Integral(12), Integral.empty, Integral(13))
        case TimePeriod.MonthOfYear => Array(Integral(3), Integral(11), Integral(3), Integral.empty, Integral(4))
        case TimePeriod.WeekOfMonth => Array(Integral(3), Integral(2), Integral(2), Integral.empty, Integral(5))
        case TimePeriod.WeekOfYear => Array(Integral(11), Integral(46), Integral(11), Integral.empty, Integral(18))
        case _ => throw new Exception(s"Unexpected TimePeriod encountered, $tp")
      }

      withClue(s"Assertion failed for TimePeriod $tp: ") {
        assertFeature(f1.toTimePeriod(tp), expected)
      }
    })
  }
} 
Example 190
Source File: DataSplitterTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.tuning

import com.salesforce.op.test.TestSparkContext
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.mllib.random.RandomRDDs
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class DataSplitterTest extends FlatSpec with TestSparkContext with SplitterSummaryAsserts {
  import spark.implicits._

  val seed = 1234L
  val dataCount = 1000
  val trainingLimitDefault = 1E6.toLong

  val data =
    RandomRDDs.normalVectorRDD(sc, 1000, 3, seed = seed)
      .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()

  val dataSplitter = DataSplitter(seed = seed)

  Spec[DataSplitter] should "split the data in the appropriate proportion - 0.0" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.0).split(data)
    test.count() shouldBe 0
    train.count() shouldBe dataCount
  }

  it should "down-sample when the data count is above the default training limit" in {
    val numRows = trainingLimitDefault * 2
    val data =
      RandomRDDs.normalVectorRDD(sc, numRows, 3, seed = seed)
        .map(v => (1.0, Vectors.dense(v.toArray), "A")).toDF()
    dataSplitter.preValidationPrepare(data)

    val dataBalanced = dataSplitter.validationPrepare(data)
    // validationPrepare calls the data sampling method, which samples the data down to a target ratio.
    // The sampling is only approximate, so the resulting count is checked to within an epsilon.
    val samplingErrorEpsilon = (0.1 * trainingLimitDefault).toLong

    dataBalanced.count() shouldBe trainingLimitDefault +- samplingErrorEpsilon
  }

  it should "set and get all data splitter params" in {
    val maxRows = dataCount / 2
    val downSampleFraction = maxRows / dataCount.toDouble

    val dataSplitter = DataSplitter()
      .setReserveTestFraction(0.0)
      .setSeed(seed)
      .setMaxTrainingSample(maxRows)
      .setDownSampleFraction(downSampleFraction)

    dataSplitter.getReserveTestFraction shouldBe 0.0
    dataSplitter.getDownSampleFraction shouldBe downSampleFraction
    dataSplitter.getSeed shouldBe seed
    dataSplitter.getMaxTrainingSample shouldBe maxRows
  }

  it should "split the data in the appropriate proportion - 0.2" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.2).split(data)
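    // the split is random, so the counts are only checked to within 30 rows of the expected 200/800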
    math.abs(test.count() - 200) < 30 shouldBe true
    math.abs(train.count() - 800) < 30 shouldBe true
  }

  it should "split the data in the appropriate proportion - 0.6" in {
    val (train, test) = dataSplitter.setReserveTestFraction(0.6).split(data)
    math.abs(test.count() - 600) < 30 shouldBe true
    math.abs(train.count() - 400) < 30 shouldBe true
  }

  it should "keep the data unchanged when prepare is called" in {
    val dataCount = data.count()
    val summary = dataSplitter.preValidationPrepare(data)
    val train = dataSplitter.validationPrepare(data)
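    // dataCount (1000) is far below the default training limit (1E6), so the down-sample fraction is capped at 1.0 and no rows are dropped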
    val sampleF = trainingLimitDefault / dataCount.toDouble
    val downSampleFraction = math.min(sampleF, 1.0)
    train.collect().zip(data.collect()).foreach { case (a, b) => a shouldBe b }
    assertDataSplitterSummary(summary.summaryOpt) { s => s shouldBe DataSplitterSummary(dataCount, downSampleFraction) }
  }

} 
Example 191
Source File: RandomParamBuilderTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.selector

import com.salesforce.op.stages.impl.classification.{OpLogisticRegression, OpRandomForestClassifier, OpXGBoostClassifier}
import com.salesforce.op.test.TestSparkContext
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class RandomParamBuilderTest extends FlatSpec with TestSparkContext {

  private val lr = new OpLogisticRegression()
  private val rf = new OpRandomForestClassifier()
  private val xgb = new OpXGBoostClassifier()


  Spec[RandomParamBuilder] should "build a param grid of the desired length with one param variable" in {
    val min = 0.00001
    val max = 10
    val lrParams = new RandomParamBuilder()
      .uniform(lr.regParam, min, max)
      .build(5)
    lrParams.length shouldBe 5
    lrParams.foreach(_.toSeq.length shouldBe 1)
    lrParams.foreach(_.toSeq.foreach( p => (p.value.asInstanceOf[Double] < max &&
      p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))

    val lrParams2 = new RandomParamBuilder()
      .exponential(lr.regParam, min, max)
      .build(20)
    lrParams2.length shouldBe 20
    lrParams2.foreach(_.toSeq.length shouldBe 1)
    lrParams2.foreach(_.toSeq.foreach( p => (p.value.asInstanceOf[Double] < max &&
      p.value.asInstanceOf[Double] > min) shouldBe true))
    lrParams2.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam))
  }

  it should "build a param grid of the desired length with many param variables" in {
    val lrParams = new RandomParamBuilder()
      .exponential(lr.regParam, .000001, 10)
      .subset(lr.family, Seq("auto", "binomial", "multinomial"))
      .uniform(lr.maxIter, 2, 50)
      .build(23)
    lrParams.length shouldBe 23
    lrParams.foreach(_.toSeq.length shouldBe 3)
    lrParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(lr.regParam, lr.family, lr.maxIter))
  }

  it should "work for all param types" in {
    val xgbParams = new RandomParamBuilder()
      .subset(xgb.checkpointPath, Seq("a", "b")) // string
      .uniform(xgb.alpha, 0, 1) // double
      .uniform(xgb.missing, 0, 100) // float
      .uniform(xgb.checkpointInterval, 2, 5) // int
      .uniform(xgb.seed, 5, 1000) // long
      .uniform(xgb.useExternalMemory) // boolean
      .exponential(xgb.baseScore, 0.0001, 1) // double
      .exponential(xgb.missing, 0.000001F, 1) // float - overwrites first call
      .build(2)

    xgbParams.length shouldBe 2
    xgbParams.foreach(_.toSeq.length shouldBe 7)
    xgbParams.foreach(_.toSeq.map(_.param).toSet shouldBe Set(xgb.checkpointPath, xgb.alpha, xgb.missing,
      xgb.checkpointInterval, xgb.seed, xgb.useExternalMemory, xgb.baseScore))
  }

  it should "throw a requirement error if an improper min value is passed in for exponential scale" in {
    intercept[IllegalArgumentException]( new RandomParamBuilder()
      .exponential(xgb.baseScore, 0, 1)).getMessage() shouldBe
      "requirement failed: Min value must be greater than zero for exponential distribution to work"
  }

  it should "throw a requirement error if an min max are passed in" in {
    intercept[IllegalArgumentException]( new RandomParamBuilder()
      .uniform(xgb.baseScore, 1, 0)).getMessage() shouldBe
      "requirement failed: Min must be less than max"
  }
} 
Example 192
Source File: OpLinearSVCTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{LinearSVC, LinearSVCModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpLinearSVCTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[LinearSVCModel],
  OpPredictorWrapper[LinearSVC, LinearSVCModel]] with PredictionEquality {

  override def specName: String = Spec[OpLinearSVC]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpLinearSVC().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Vectors.dense(Array(-1.33, 1.33))),
    Prediction(0.0, Vectors.dense(Array(1.04, -1.04))),
    Prediction(0.0, Vectors.dense(Array(2.69, -2.69))),
    Prediction(1.0, Vectors.dense(Array(-1.32, 1.32))),
    Prediction(1.0, Vectors.dense(Array(-2.11, 2.11))),
    Prediction(0.0, Vectors.dense(Array(4.41, -4.41))),
    Prediction(1.0, Vectors.dense(Array(-1.46, 1.46))),
    Prediction(0.0, Vectors.dense(Array(1.42, -1.42)))
  )


  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setRegParam(0.1)
      .setMaxIter(20)
      .setTol(1E-4)
    estimator.fit(inputData)

    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getMaxIter shouldBe 20
    estimator.predictor.getTol shouldBe 1E-4
  }
} 
Example 193
Source File: OpLogisticRegressionTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpLogisticRegressionTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[LogisticRegressionModel],
  OpPredictorWrapper[LogisticRegression, LogisticRegressionModel]] with PredictionEquality {

  override def specName: String = Spec[OpLogisticRegression]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpLogisticRegression().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-20.88, 20.88), Array(0.0, 1.0)),
    Prediction(0.0, Array(16.70, -16.7), Array(1.0, 0.0)),
    Prediction(0.0, Array(22.2, -22.2), Array(1.0, 0.0)),
    Prediction(1.0, Array(-18.35, 18.35), Array(0.0, 1.0)),
    Prediction(1.0, Array(-31.46, 31.46), Array(0.0, 1.0)),
    Prediction(0.0, Array(24.67, -24.67), Array(1.0, 0.0)),
    Prediction(1.0, Array(-22.07, 22.07), Array(0.0, 1.0)),
    Prediction(0.0, Array(20.9, -20.9), Array(1.0, 0.0))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setRegParam(0.1)
      .setElasticNetParam(0.1)
      .setMaxIter(20)
    estimator.fit(inputData)

    estimator.predictor.getRegParam shouldBe 0.1
    estimator.predictor.getElasticNetParam shouldBe 0.1
    estimator.predictor.getMaxIter shouldBe 20
  }
} 
Example 194
Source File: OpXGBoostClassifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import ml.dmlc.xgboost4j.scala.spark.{OpXGBoostQuietLogging, XGBoostClassificationModel, XGBoostClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpXGBoostClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[XGBoostClassificationModel],
  OpPredictorWrapper[XGBoostClassifier, XGBoostClassificationModel]]
  with PredictionEquality with OpXGBoostQuietLogging {

  override def specName: String = Spec[OpXGBoostClassifier]

  val rawData = Seq(
    1.0 -> Vectors.dense(12.0, 4.3, 1.3),
    0.0 -> Vectors.dense(0.0, 0.3, 0.1),
    0.0 -> Vectors.dense(1.0, 3.9, 4.3),
    1.0 -> Vectors.dense(10.0, 1.3, 0.9),
    1.0 -> Vectors.dense(15.0, 4.7, 1.3),
    0.0 -> Vectors.dense(0.5, 0.9, 10.1),
    1.0 -> Vectors.dense(11.5, 2.3, 1.3),
    0.0 -> Vectors.dense(0.1, 3.3, 0.1)
  ).map { case (l, v) => l.toRealNN -> v.toOPVector }

  val (inputData, label, features) = TestFeatureBuilder("label", "features", rawData)

  val estimator = new OpXGBoostClassifier().setInput(label.copy(isResponse = true), features)
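  // silent = 1 suppresses XGBoost's per-iteration training output (complementing the OpXGBoostQuietLogging mixin above)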
  estimator.setSilent(1)

  val expectedResult = Seq(
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284)),
    Prediction(1.0, Array(-0.6200000047683716, 0.6200000047683716), Array(0.3799999952316284, 0.6200000047683716)),
    Prediction(0.0, Array(-0.3799999952316284, 0.3799999952316284), Array(0.6200000047683716, 0.3799999952316284))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setAlpha(0.872).setEta(0.99912)
    estimator.fit(inputData)
    estimator.predictor.getAlpha shouldBe 0.872
    estimator.predictor.getEta shouldBe 0.99912
  }
} 
Example 195
Source File: OpNaiveBayesTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpNaiveBayesTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[NaiveBayesModel],
  OpPredictorWrapper[NaiveBayes, NaiveBayesModel]] with PredictionEquality {

  override def specName: String = Spec[OpNaiveBayes]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpNaiveBayes().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-34.41, -14.85), Array(0.0, 1.0)),
    Prediction(0.0, Array(-1.07, -1.42), Array(0.58, 0.41)),
    Prediction(0.0, Array(-9.70, -17.99), Array(1.0, 0.0)),
    Prediction(1.0, Array(-26.22, -8.33), Array(0.0, 1.0)),
    Prediction(1.0, Array(-41.93, -16.49), Array(0.0, 1.0)),
    Prediction(0.0, Array(-8.60, -27.31), Array(1.0, 0.0)),
    Prediction(1.0, Array(-31.07, -11.44), Array(0.0, 1.0)),
    Prediction(0.0, Array(-4.54, -6.32), Array(0.85, 0.14))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setSmoothing(2)
    estimator.fit(inputData)
    estimator.predictor.getSmoothing shouldBe 2
  }
} 
Example 196
Source File: OpMultilayerPerceptronClassifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpMultilayerPerceptronClassifierTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[MultilayerPerceptronClassificationModel],
  OpPredictorWrapper[MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpMultilayerPerceptronClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
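  // layer sizes below: 3 inputs (the feature vector dimension), two hidden layers of 5 and 4 units, and 2 output classes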
  val estimator = new OpMultilayerPerceptronClassifier()
    .setInput(feature1, feature2)
    .setLayers(Array(3, 5, 4, 2))

  val expectedResult = Seq(
    Prediction(1.0, Array(-9.655814651428148, 9.202335441336952), Array(6.456683124562021E-9, 0.9999999935433168)),
    Prediction(0.0, Array(9.475612761543069, -10.617525149157993), Array(0.9999999981221492, 1.877850786773977E-9)),
    Prediction(0.0, Array(9.715293827870028, -10.885255922155942), Array(0.9999999988694366, 1.130563392364822E-9)),
    Prediction(1.0, Array(-9.66776357765489, 9.215079716735316), Array(6.299199338896916E-9, 0.9999999937008006)),
    Prediction(1.0, Array(-9.668041712561456, 9.215387575592239), Array(6.2955091287182745E-9, 0.9999999937044908)),
    Prediction(0.0, Array(9.692904797559496, -10.860273756796797), Array(0.9999999988145918, 1.1854083109077814E-9)),
    Prediction(1.0, Array(-9.667687253240183, 9.214995747770411), Array(6.300209139771467E-9, 0.9999999936997908)),
    Prediction(0.0, Array(9.703097414537668, -10.872171694864653), Array(0.9999999988404908, 1.1595091005698914E-9))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setMaxIter(50).setBlockSize(2).setSeed(42)
    estimator.fit(inputData)
    estimator.predictor.getMaxIter shouldBe 50
    estimator.predictor.getBlockSize shouldBe 2
    estimator.predictor.getSeed shouldBe 42
  }
} 
Example 197
Source File: OpDecisionTreeClassifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpDecisionTreeClassifierTest extends OpEstimatorSpec[Prediction,
  OpPredictorWrapperModel[DecisionTreeClassificationModel],
  OpPredictorWrapper[DecisionTreeClassifier, DecisionTreeClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpDecisionTreeClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpDecisionTreeClassifier().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0)),
    Prediction(1.0, Array(0.0, 4.0), Array(0.0, 1.0)),
    Prediction(0.0, Array(4.0, 0.0), Array(1.0, 0.0))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1
  }
} 
Example 198
Source File: OpRandomForestClassifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class OpRandomForestClassifierTest extends
  OpEstimatorSpec[Prediction, OpPredictorWrapperModel[RandomForestClassificationModel],
    OpPredictorWrapper[RandomForestClassifier, RandomForestClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpRandomForestClassifier]

  lazy val (inputData, rawLabelMulti, featuresMulti) =
    TestFeatureBuilder[RealNN, OPVector]("labelMulti", "featuresMulti",
      Seq(
        (1.0.toRealNN, Vectors.dense(12.0, 4.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.0, 0.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 3.9, 4.3).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.3, 0.9).toOPVector),
        (1.0.toRealNN, Vectors.dense(15.0, 4.7, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.5, 0.9, 10.1).toOPVector),
        (1.0.toRealNN, Vectors.dense(11.5, 2.3, 1.3).toOPVector),
        (0.0.toRealNN, Vectors.dense(0.1, 3.3, 0.1).toOPVector),
        (2.0.toRealNN, Vectors.dense(1.0, 4.0, 4.5).toOPVector),
        (2.0.toRealNN, Vectors.dense(10.0, 1.5, 1.0).toOPVector)
      )
    )

  val labelMulti = rawLabelMulti.copy(isResponse = true)

  val estimator = new OpRandomForestClassifier().setInput(labelMulti, featuresMulti)

  val expectedResult = Seq(
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(19.0, 0.0, 1.0), Array(0.95, 0.0, 0.05)),
    Prediction(2.0, Array(0.0, 1.0, 19.0), Array(0.0, 0.05, 0.95)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85)),
    Prediction(1.0, Array(0.0, 17.0, 3.0), Array(0.0, 0.85, 0.15)),
    Prediction(0.0, Array(16.0, 0.0, 4.0), Array(0.8, 0.0, 0.2)),
    Prediction(1.0, Array(1.0, 17.0, 2.0), Array(0.05, 0.85, 0.1)),
    Prediction(0.0, Array(17.0, 0.0, 3.0), Array(0.85, 0.0, 0.15)),
    Prediction(2.0, Array(2.0, 1.0, 17.0), Array(0.1, 0.05, 0.85)),
    Prediction(2.0, Array(1.0, 2.0, 17.0), Array(0.05, 0.1, 0.85))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxDepth(10)
      .setImpurity(Impurity.Gini.sparkName)
      .setMaxBins(33)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.2)
      .setSubsamplingRate(0.9)
      .setNumTrees(21)
      .setSeed(2L)
    estimator.fit(inputData)

    estimator.predictor.getMaxDepth shouldBe 10
    estimator.predictor.getMaxBins shouldBe 33
    estimator.predictor.getImpurity shouldBe Impurity.Gini.sparkName
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.2
    estimator.predictor.getSubsamplingRate shouldBe 0.9
    estimator.predictor.getNumTrees shouldBe 21
    estimator.predictor.getSeed shouldBe 2L
  }

} 
Example 199
Source File: OpGBTClassifierTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpGBTClassifierTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[GBTClassificationModel],
  OpPredictorWrapper[GBTClassifier, GBTClassificationModel]] with PredictionEquality {

  override def specName: String = Spec[OpGBTClassifier]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpGBTClassifier().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04)),
    Prediction(1.0, Array(-1.54, 1.54), Array(0.04, 0.95)),
    Prediction(0.0, Array(1.54, -1.54), Array(0.95, 0.04))
  )


  it should "allow the user to set the desired spark parameters" in {
    estimator
      .setMaxIter(10)
      .setMaxDepth(6)
      .setMaxBins(2)
      .setMinInstancesPerNode(2)
      .setMinInfoGain(0.1)
    estimator.fit(inputData)

    estimator.predictor.getMaxIter shouldBe 10
    estimator.predictor.getMaxDepth shouldBe 6
    estimator.predictor.getMaxBins shouldBe 2
    estimator.predictor.getMinInstancesPerNode shouldBe 2
    estimator.predictor.getMinInfoGain shouldBe 0.1

  }
} 
Example 200
Source File: PredictionDeIndexerTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
package com.salesforce.op.stages.impl.preparators


import com.salesforce.op.utils.spark.RichDataset._
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.UnaryLambdaTransformer
import com.salesforce.op.stages.impl.feature.OpStringIndexerNoFilter
import com.salesforce.op.test.{TestFeatureBuilder, TestSparkContext}
import org.junit.runner.RunWith
import org.scalatest.FlatSpec
import org.scalatest.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class PredictionDeIndexerTest extends FlatSpec with TestSparkContext {

  val data = Seq(("a", 0.0), ("b", 1.0), ("c", 2.0)).map { case (txt, num) => (txt.toText, num.toRealNN) }
  val (ds, txtF, numF) = TestFeatureBuilder(data)

  val response = txtF.indexed()
  val indexedData = response.originStage.asInstanceOf[OpStringIndexerNoFilter[_]].fit(ds).transform(ds)

  val permutation = new UnaryLambdaTransformer[RealNN, RealNN](
    operationName = "modulo",
    transformFn = v => ((v.value.get + 1).toInt % 3).toRealNN
  ).setInput(response)
  val pred = permutation.getOutput()
  val permutedData = permutation.transform(indexedData)
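  // the permutation maps each label index i to (i + 1) % 3, so de-indexing the permuted
  // predictions should return the original labels rotated by one: "b", "c", "a"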

  val expected = Array("b", "c", "a").map(_.toText)

  Spec[PredictionDeIndexer] should "de-index the feature correctly" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(response, pred)
    val deIndexed = predDeIndexer.getOutput()

    val results = predDeIndexer.fit(permutedData).transform(permutedData).collect(deIndexed)
    results shouldBe expected
  }


  it should "throw a nice error when there is no metadata" in {
    val predDeIndexer = new PredictionDeIndexer().setInput(numF, pred)
    the[Error] thrownBy {
      predDeIndexer.fit(permutedData).transform(permutedData)
    } should have message
      s"The feature ${numF.name} does not contain any label/index mapping in its metadata"
  }
}