/*
 * Copyright (C) 2017  Radicalbit
 *
 * This file is part of flink-JPMML
 *
 * flink-JPMML is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * flink-JPMML is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with flink-JPMML.  If not, see <http://www.gnu.org/licenses/>.
 */

package io.radicalbit.flink.pmml.scala

import io.radicalbit.flink.pmml.scala.api.reader.ModelReader
import io.radicalbit.flink.pmml.scala.models.prediction.{Prediction, Score, Target}
import io.radicalbit.flink.pmml.scala.utils.PmmlLoaderKit
import io.radicalbit.flink.streaming.spec.core.{FlinkPipelineTestKit, FlinkTestKitCompanion}
import org.apache.flink.api.scala.ClosureCleaner
import org.apache.flink.ml.math.{DenseVector, SparseVector, Vector}
import org.apache.flink.runtime.client.JobExecutionException
import org.apache.flink.streaming.api.scala._
import org.scalatest.{Matchers, WordSpecLike}

/** Test-kit companion for [[QuickDataStreamSpec]]; made available implicitly inside the spec. */
object QuickDataStreamSpec extends FlinkTestKitCompanion[(Prediction, Vector)]

/**
  * Spec for `quickEvaluate` on a `DataStream[Vector]`.
  *
  * Every scenario builds a pipeline through [[pipelineBuilder]] and runs it with
  * `executePipeline`, comparing the collected `(Prediction, Vector)` pairs with the expected
  * ones, or asserting that the job fails with a [[JobExecutionException]].
  */
class QuickDataStreamSpec
    extends FlinkPipelineTestKit[Vector, (Prediction, Vector)]
    with WordSpecLike
    with Matchers
    with PmmlLoaderKit {

  // Explicitly typed so implicit resolution does not rely on the inferred singleton type.
  private implicit val companion: QuickDataStreamSpec.type = QuickDataStreamSpec

  // The same four values expressed once as a dense and once as a sparse vector.
  private val defaultInput: Vector = DenseVector(1.0, 1.0, 1.0, 1.0)
  private val defaultSparseInput: Vector =
    SparseVector(4, Array(0, 1, 2, 3), Array(1.0, 1.0, 1.0, 1.0))

  // Expected pipeline outputs for the two inputs above.
  private val defaultPrediction = (Prediction(Score(3.0)), defaultInput)
  private val sparsePrediction = (Prediction(Score(3.0)), defaultSparseInput)

  /**
    * Builds the pipeline under test.
    *
    * @param source path of the PMML model to read; when `None`, the path returned by
    *               `getPMMLSource(Source.KmeansPmml)` is used
    * @return a function applying `quickEvaluate` with a [[ModelReader]] over `source`
    */
  private def pipelineBuilder(source: Option[String]) = {
    val modelPath = source.getOrElse(getPMMLSource(Source.KmeansPmml))
    val reader = ModelReader(modelPath)

    (in: DataStream[Vector]) => in.quickEvaluate(reader)
  }

  "QuickDataStream" should {

    "quick DataStream should be serializable" in {
      noException should be thrownBy
        ClosureCleaner.clean(pipelineBuilder(None), checkSerializable = true)
    }

    "return correct output sequence on heterogeneous input" in {
      val in: Seq[Vector] = Seq(defaultInput, defaultSparseInput)
      val expected = Seq(defaultPrediction, sparsePrediction)

      executePipeline(in)(pipelineBuilder(None)) shouldBe expected
    }

    "compute quick prediction with any dense input vector" in {
      val in: Seq[Vector] = Seq(defaultInput)
      val expected = Seq(defaultPrediction)

      executePipeline(in)(pipelineBuilder(None)) shouldBe expected
    }

    "compute quick predictions with any sparse input vector" in {
      val in: Seq[Vector] = Seq(defaultSparseInput)
      val expected = Seq(sparsePrediction)

      executePipeline(in)(pipelineBuilder(None)) shouldBe expected
    }

    "throw JobExecutionException if the model path cannot be loaded" in {
      val invalidSource = Source.NotExistingPath

      // Only the failure is asserted: an output comparison placed inside this block could
      // never make the test pass, because any outcome other than the expected exception fails.
      an[JobExecutionException] should be thrownBy {
        executePipeline(Seq(defaultInput))(pipelineBuilder(Some(invalidSource)))
      }
    }

    "Emit empty prediction if the input is not valid" in {
      val shortInput: Vector = SparseVector(2, Array(0, 3), Array(1.0, 1.0))
      val expected = Seq((Prediction(Target.empty), shortInput))

      executePipeline(Seq(shortInput))(pipelineBuilder(None)) shouldBe expected
    }

    // Named after what is actually asserted: the job fails, no empty prediction is emitted.
    "throw JobExecutionException if the model is not valid" in {
      val invalidModelSource = getPMMLSource(Source.KmeansPmmlEmpty)

      an[JobExecutionException] should be thrownBy {
        executePipeline(Seq(defaultInput))(pipelineBuilder(Some(invalidModelSource)))
      }
    }
  }
}