org.apache.spark.ml.classification.NaiveBayes Scala Examples

The following examples show how to use org.apache.spark.ml.classification.NaiveBayes. Each example is taken from an open-source project; the source file and license are noted where available.
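
Before the project-specific examples, here is a minimal, self-contained sketch of the pattern most of them follow: load a DataFrame with label/features columns, fit a NaiveBayes estimator, and evaluate the predictions with MulticlassClassificationEvaluator. The file path and application name are placeholders, not taken from any of the projects below.

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession

object NaiveBayesQuickStart {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("NaiveBayesQuickStart")
      .getOrCreate()

    // Placeholder path to a LIBSVM-formatted dataset.
    val data = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
    val Array(train, test) = data.randomSplit(Array(0.8, 0.2), seed = 42L)

    // Fit a (multinomial by default) Naive Bayes model with Laplace smoothing.
    val model = new NaiveBayes().setSmoothing(1.0).fit(train)
    val predictions = model.transform(test)

    val accuracy = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
      .evaluate(predictions)
    println(s"Test set accuracy = $accuracy")

    spark.stop()
  }
}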
Example 1
Source File: NaiveBayesTraining.scala    From piflow   with BSD 2-Clause "Simplified" License
package cn.piflow.bundle.ml_classification

import cn.piflow.conf.bean.PropertyDescriptor
import cn.piflow.conf.util.{ImageUtil, MapUtil}
import cn.piflow.conf.{ConfigurableStop, Port, StopGroup}
import cn.piflow.{JobContext, JobInputStream, JobOutputStream, ProcessContext}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.sql.SparkSession

class NaiveBayesTraining extends ConfigurableStop{
  val authorEmail: String = "[email protected]"
  val description: String = "Train a NaiveBayes model"
  val inportList: List[String] = List(Port.DefaultPort)
  val outportList: List[String] = List(Port.DefaultPort)
  var training_data_path: String = _
  var smoothing_value: String = _
  var model_save_path: String = _


  def perform(in: JobInputStream, out: JobOutputStream, pec: JobContext): Unit = {
    val spark = pec.get[SparkSession]()

    //load data stored in libsvm format as a dataframe
    val data=spark.read.format("libsvm").load(training_data_path)

    //get smoothing factor
    var smoothing_factor:Double=0
    if(smoothing_value!=""){
      smoothing_factor=smoothing_value.toDouble
    }

    //training a NaiveBayes model
    val model=new NaiveBayes().setSmoothing(smoothing_factor).fit(data)

    //model persistence
    model.save(model_save_path)

    import spark.implicits._
    val dfOut=Seq(model_save_path).toDF
    dfOut.show()
    out.write(dfOut)

  }

  def initialize(ctx: ProcessContext): Unit = {

  }


  def setProperties(map: Map[String, Any]): Unit = {
    training_data_path=MapUtil.get(map,key="training_data_path").asInstanceOf[String]
    smoothing_value=MapUtil.get(map,key="smoothing_value").asInstanceOf[String]
    model_save_path=MapUtil.get(map,key="model_save_path").asInstanceOf[String]
  }

  override def getPropertyDescriptor(): List[PropertyDescriptor] = {
    var descriptor : List[PropertyDescriptor] = List()
    val training_data_path = new PropertyDescriptor().name("training_data_path").displayName("TRAINING_DATA_PATH").defaultValue("").required(true)
    val smoothing_value = new PropertyDescriptor().name("smoothing_value").displayName("SMOOTHING_FACTOR").defaultValue("0").required(false)
    val model_save_path = new PropertyDescriptor().name("model_save_path").displayName("MODEL_SAVE_PATH").defaultValue("").required(true)
    descriptor = training_data_path :: descriptor
    descriptor = smoothing_value :: descriptor
    descriptor = model_save_path :: descriptor
    descriptor
  }

  override def getIcon(): Array[Byte] = {
    ImageUtil.getImage("icon/ml_classification/NavieBayesTraining.png")
  }

  override def getGroup(): List[String] = {
    List(StopGroup.MLGroup.toString)
  }

} 
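
Example 1 only trains and persists the model; a corresponding prediction stage would reload it with NaiveBayesModel.load and transform new data. A minimal sketch of that step, assuming the same SparkSession and configuration fields as above (test_data_path is a hypothetical property, not part of the piflow project):

import org.apache.spark.ml.classification.NaiveBayesModel

// Inside a perform()-style method with a SparkSession in scope:
val model = NaiveBayesModel.load(model_save_path)                 // path written by the training stop
val testData = spark.read.format("libsvm").load(test_data_path)   // hypothetical placeholder property
val predictions = model.transform(testData)
predictions.show()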
Example 2
Source File: NaiveBayesClassifierParitySpec.scala    From mleap   with Apache License 2.0
package org.apache.spark.ml.parity.classification

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, Transformer}
import org.apache.spark.ml.parity.SparkParityBase
import org.apache.spark.sql._


class NaiveBayesClassifierParitySpec extends SparkParityBase {
  override val dataset: DataFrame = baseDataset.select("fico_score_group_fnl", "approved")
  override val sparkTransformer: Transformer = new Pipeline().setStages(Array(new StringIndexer().
    setInputCol("fico_score_group_fnl").
    setOutputCol("fico_index"),
    new VectorAssembler().
      setInputCols(Array("fico_index")).
      setOutputCol("features"),
    new StringIndexer().
      setInputCol("approved").
      setOutputCol("label"),
    new NaiveBayes(uid = "nb").
      setModelType("multinomial").
      setThresholds(Array(0.4)).
      setFeaturesCol("features").
      setLabelCol("label"))).fit(dataset)

  override val unserializedParams = Set("stringOrderType", "labelCol", "smoothing")
} 
Example 3
package org.sparksamples.classification.stumbleupon

import org.apache.log4j.Logger
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.DataFrame

import scala.collection.mutable


object NaiveBayesPipeline {
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  def naiveBayesPipeline(vectorAssembler: VectorAssembler, dataFrame: DataFrame) = {
    val Array(training, test) = dataFrame.randomSplit(Array(0.9, 0.1), seed = 12345)

    // Set up Pipeline
    val stages = new mutable.ArrayBuffer[PipelineStage]()

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
    stages += labelIndexer

    val nb = new NaiveBayes()

    stages += vectorAssembler
    stages += nb
    val pipeline = new Pipeline().setStages(stages.toArray)

    // Fit the Pipeline
    val startTime = System.nanoTime()
    //val model = pipeline.fit(training)
    val model = pipeline.fit(dataFrame)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")

    //val holdout = model.transform(test).select("prediction","label")
    val holdout = model.transform(dataFrame).select("prediction","label")

    // Select (prediction, true label) and compute test error
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val mAccuracy = evaluator.evaluate(holdout)
    println("Test set accuracy = " + mAccuracy)
  }
} 
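
The helper above expects a pre-built VectorAssembler and a DataFrame that already contains a "label" column; a hypothetical call site might look like the following sketch (the column names are assumptions, not taken from the StumbleUpon project):

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.DataFrame

// df: DataFrame with a "label" column plus the numeric columns listed below.
def run(df: DataFrame): Unit = {
  val assembler = new VectorAssembler()
    .setInputCols(Array("numericFeature1", "numericFeature2"))  // hypothetical column names
    .setOutputCol("features")
  NaiveBayesPipeline.naiveBayesPipeline(assembler, df)
}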
Example 4
package org.sparksamples.classification.stumbleupon

import org.apache.log4j.Logger
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics}
import org.apache.spark.sql.DataFrame

import scala.collection.mutable


object NaiveBayesPipeline {
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  def naiveBayesPipeline(vectorAssembler: VectorAssembler, dataFrame: DataFrame) = {
    val Array(training, test) = dataFrame.randomSplit(Array(0.9, 0.1), seed = 12345)

    // Set up Pipeline
    val stages = new mutable.ArrayBuffer[PipelineStage]()

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
    stages += labelIndexer

    val nb = new NaiveBayes()

    stages += vectorAssembler
    stages += nb
    val pipeline = new Pipeline().setStages(stages.toArray)

    // Fit the Pipeline
    val startTime = System.nanoTime()
    //val model = pipeline.fit(training)
    val model = pipeline.fit(dataFrame)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")

    //val holdout = model.transform(test).select("prediction","label")
    val holdout = model.transform(dataFrame).select("prediction","label")

    // have to do a type conversion for RegressionMetrics
    val rm = new RegressionMetrics(holdout.rdd.map(x => (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double])))

    logger.info("Test Metrics")
    logger.info("Test Explained Variance:")
    logger.info(rm.explainedVariance)
    logger.info("Test R^2 Coef:")
    logger.info(rm.r2)
    logger.info("Test MSE:")
    logger.info(rm.meanSquaredError)
    logger.info("Test RMSE:")
    logger.info(rm.rootMeanSquaredError)

    val predictions = model.transform(test).select("prediction").rdd.map(_.getDouble(0))
    val labels = model.transform(test).select("label").rdd.map(_.getDouble(0))
    val accuracy = new MulticlassMetrics(predictions.zip(labels)).precision
    println(s"  Accuracy : $accuracy")

    holdout.rdd.map(x => x(0).asInstanceOf[Double]).repartition(1).saveAsTextFile("/home/ubuntu/work/ml-resources/spark-ml/results/NB.xls")

    savePredictions(holdout, test, rm, "/home/ubuntu/work/ml-resources/spark-ml/results/NaiveBayes.csv")
  }

  def savePredictions(predictions:DataFrame, testRaw:DataFrame, regressionMetrics: RegressionMetrics, filePath:String) = {
    predictions
      .coalesce(1)
      .write.format("com.databricks.spark.csv")
      .option("header", "true")
      .save(filePath)
  }
} 
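
Example 4 computes its metrics with the RDD-based mllib classes; on the DataFrame API the same accuracy figure can be obtained directly, as Example 3 does. A short equivalent sketch, reusing the holdout DataFrame defined above:

import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
println("Accuracy = " + evaluator.evaluate(holdout))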
Example 5
package org.stumbleuponclassifier

import org.apache.log4j.Logger
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.sql.DataFrame

import scala.collection.mutable


object NaiveBayesPipeline {
  @transient lazy val logger = Logger.getLogger(getClass.getName)

  def naiveBayesPipeline(vectorAssembler: VectorAssembler, dataFrame: DataFrame) = {
    val Array(training, test) = dataFrame.randomSplit(Array(0.9, 0.1), seed = 12345)

    // Set up Pipeline
    val stages = new mutable.ArrayBuffer[PipelineStage]()

    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
    stages += labelIndexer

    val nb = new NaiveBayes()

    stages += vectorAssembler
    stages += nb
    val pipeline = new Pipeline().setStages(stages.toArray)

    // Fit the Pipeline
    val startTime = System.nanoTime()
    //val model = pipeline.fit(training)
    val model = pipeline.fit(dataFrame)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")

    //val holdout = model.transform(test).select("prediction","label")
    val holdout = model.transform(dataFrame).select("prediction","label")

    // Select (prediction, true label) and compute test error
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val mAccuracy = evaluator.evaluate(holdout)
    println("Test set accuracy = " + mAccuracy)
  }
} 
Example 6
package org.apache.spark.examples.ml

import org.apache.spark.SparkConf
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

import org.apache.spark.sql.SparkSession

object DocumentClassificationLibSVM {
  def main(args: Array[String]): Unit = {

    val spConfig = (new SparkConf).setMaster("local").setAppName("SparkApp")
    val spark = SparkSession
      .builder()
      .appName("SparkRatingData").config(spConfig)
      .getOrCreate()

    val data = spark.read.format("libsvm").load("./output/20news-by-date-train-libsvm/part-combined")

    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1L)

    // Train a NaiveBayes model.
    val model = new NaiveBayes()
      .fit(trainingData)
    val predictions = model.transform(testData)
    predictions.show()

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println("Test set accuracy = " + accuracy)
    spark.stop()
  }
} 
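
Accuracy is only one view of a multi-class document classifier; the same evaluator reports other aggregate metrics by changing setMetricName. An optional addition to the example above (the metric names are the standard Spark 2.x ones, not part of the original file):

val f1 = evaluator.setMetricName("f1").evaluate(predictions)
val weightedPrecision = evaluator.setMetricName("weightedPrecision").evaluate(predictions)
val weightedRecall = evaluator.setMetricName("weightedRecall").evaluate(predictions)
println(s"F1 = $f1, weighted precision = $weightedPrecision, weighted recall = $weightedRecall")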
Example 7
Source File: ClassifiersImpl.scala    From spark_training   with Apache License 2.0
package com.malaska.spark.training.machinelearning.common

import org.apache.spark.ml.classification.{DecisionTreeClassifier, GBTClassifier, LogisticRegression, NaiveBayes}
import org.apache.spark.ml.evaluation.{MulticlassClassificationEvaluator, RegressionEvaluator}
import org.apache.spark.ml.regression.RandomForestRegressor
import org.apache.spark.sql._

object ClassifiersImpl {
  def logisticRegression(trainingLabeledPointDf: DataFrame,
                         testPercentage:Double): Unit = {
    val mlr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)

    val splits = trainingLabeledPointDf.randomSplit(Array(testPercentage, 1-testPercentage))

    val model = mlr.fit(splits(0))

    val trainTransformed = model.transform(splits(1))

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(trainTransformed)
    println("Test set accuracy of logisticRegression = " + accuracy)

    //println(model)
  }

  def gbtClassifer(trainingLabeledPointDf: DataFrame,
                   testPercentage:Double): Unit = {
    val gbt = new GBTClassifier()

    val splits = trainingLabeledPointDf.randomSplit(Array(testPercentage, 1-testPercentage))

    val model = gbt.fit(splits(0))

    val trainTransformed = model.transform(splits(1))

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(trainTransformed)
    println("Test set accuracy of gbtClassifier = " + accuracy)

    //println(model)
    //println(model.toDebugString)
  }

  def randomForestRegressor(trainingLabeledPointDf: DataFrame,
                            impurity:String,
                            maxDepth:Int,
                            maxBins:Int,
                            testPercentage:Double): Unit = {
    val rf = new RandomForestRegressor()

    rf.setImpurity(impurity)
    rf.setMaxDepth(maxDepth)
    rf.setMaxBins(maxBins)

    val splits = trainingLabeledPointDf.randomSplit(Array(testPercentage, 1-testPercentage))

    val model = rf.fit(splits(0))
    val trainTransformed = model.transform(splits(1))

    

    // Evaluate the regressor with RMSE; classification accuracy is not a meaningful
    // metric for a regressor's continuous predictions.
    val evaluator = new RegressionEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("rmse")
    val rmse = evaluator.evaluate(trainTransformed)
    println("Test set RMSE of randomForestRegressor = " + rmse)
  }
} 
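
ClassifiersImpl imports NaiveBayes, but the corresponding method is not part of this excerpt. By analogy with the other methods in the object, it would presumably look like the sketch below (an assumption, not the project's actual code):

  def naiveBayes(trainingLabeledPointDf: DataFrame,
                 testPercentage: Double): Unit = {
    val nb = new NaiveBayes()

    val splits = trainingLabeledPointDf.randomSplit(Array(testPercentage, 1 - testPercentage))

    val model = nb.fit(splits(0))
    val trainTransformed = model.transform(splits(1))

    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(trainTransformed)
    println("Test set accuracy of naiveBayes = " + accuracy)
  }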
Example 8
Source File: OpNaiveBayes.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.UID
import com.salesforce.op.features.types.{OPVector, Prediction, RealNN}
import com.salesforce.op.stages.impl.CheckIsResponseValues
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpProbabilisticClassifierModel}
import com.salesforce.op.utils.reflection.ReflectionUtils.reflectMethod
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel, OpNaiveBayesParams}

import scala.reflect.runtime.universe.TypeTag


class OpNaiveBayesModel
(
  sparkModel: NaiveBayesModel,
  uid: String = UID[OpNaiveBayesModel],
  operationName: String = classOf[NaiveBayes].getSimpleName
)(
  implicit tti1: TypeTag[RealNN],
  tti2: TypeTag[OPVector],
  tto: TypeTag[Prediction],
  ttov: TypeTag[Prediction#Value]
) extends OpProbabilisticClassifierModel[NaiveBayesModel](
  sparkModel = sparkModel, uid = uid, operationName = operationName
) {
  @transient lazy val predictRawMirror = reflectMethod(getSparkMlStage().get, "predictRaw")
  @transient lazy val raw2probabilityMirror = reflectMethod(getSparkMlStage().get, "raw2probability")
  @transient lazy val probability2predictionMirror =
    reflectMethod(getSparkMlStage().get, "probability2prediction")
} 
Example 9
Source File: OpNaiveBayesTest.scala    From TransmogrifAI   with BSD 3-Clause "New" or "Revised" License
package com.salesforce.op.stages.impl.classification

import com.salesforce.op.features.types._
import com.salesforce.op.stages.impl.PredictionEquality
import com.salesforce.op.stages.sparkwrappers.specific.{OpPredictorWrapper, OpPredictorWrapperModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.ml.linalg.Vectors
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class OpNaiveBayesTest extends OpEstimatorSpec[Prediction, OpPredictorWrapperModel[NaiveBayesModel],
  OpPredictorWrapper[NaiveBayes, NaiveBayesModel]] with PredictionEquality {

  override def specName: String = Spec[OpNaiveBayes]

  val (inputData, rawFeature1, feature2) = TestFeatureBuilder("label", "features",
    Seq[(RealNN, OPVector)](
      1.0.toRealNN -> Vectors.dense(12.0, 4.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.0, 0.3, 0.1).toOPVector,
      0.0.toRealNN -> Vectors.dense(1.0, 3.9, 4.3).toOPVector,
      1.0.toRealNN -> Vectors.dense(10.0, 1.3, 0.9).toOPVector,
      1.0.toRealNN -> Vectors.dense(15.0, 4.7, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.5, 0.9, 10.1).toOPVector,
      1.0.toRealNN -> Vectors.dense(11.5, 2.3, 1.3).toOPVector,
      0.0.toRealNN -> Vectors.dense(0.1, 3.3, 0.1).toOPVector
    )
  )
  val feature1 = rawFeature1.copy(isResponse = true)
  val estimator = new OpNaiveBayes().setInput(feature1, feature2)

  val expectedResult = Seq(
    Prediction(1.0, Array(-34.41, -14.85), Array(0.0, 1.0)),
    Prediction(0.0, Array(-1.07, -1.42), Array(0.58, 0.41)),
    Prediction(0.0, Array(-9.70, -17.99), Array(1.0, 0.0)),
    Prediction(1.0, Array(-26.22, -8.33), Array(0.0, 1.0)),
    Prediction(1.0, Array(-41.93, -16.49), Array(0.0, 1.0)),
    Prediction(0.0, Array(-8.60, -27.31), Array(1.0, 0.0)),
    Prediction(1.0, Array(-31.07, -11.44), Array(0.0, 1.0)),
    Prediction(0.0, Array(-4.54, -6.32), Array(0.85, 0.14))
  )

  it should "allow the user to set the desired spark parameters" in {
    estimator.setSmoothing(2)
    estimator.fit(inputData)
    estimator.predictor.getSmoothing shouldBe 2
  }
} 
Example 10
Source File: NaiveBayesExample.scala    From spark1.52   with Apache License 2.0
// scalastyle:off println
package org.apache.spark.examples.ml

// $example on$
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.types.StringType
import org.apache.spark.sql.{SQLContext, DataFrame}
// $example off$


    predictions.show(5)

    // Select (prediction, true label) and compute test error
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label") // label column name
      .setPredictionCol("prediction") // prediction column name
      .setMetricName("precision") // overall accuracy ("precision" is the Spark 1.x metric name)
    //Accuracy: 1.0
    val accuracy = evaluator.evaluate(predictions)
    println("Accuracy: " + accuracy)
    // $example off$

    sc.stop()
  }
}
// scalastyle:on println 
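
The spark1.52 snippet above is truncated: the object declaration, context setup, data loading and model training are elided before predictions.show(5). The following is a plausible, self-contained reconstruction of the full flow, based on the imports and the variables used later; it is an assumption, the data path is a placeholder, and MLUtils is used because the "libsvm" data source reader was only added in Spark 1.6.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SQLContext

object NaiveBayesExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("NaiveBayesExample").setMaster("local[*]")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Load LIBSVM data as a DataFrame (placeholder path).
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()

    // Split into training and test sets, train the model, and score the test set.
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
    val model = new NaiveBayes().fit(trainingData)
    val predictions = model.transform(testData)
    predictions.show(5)

    // "precision" is the Spark 1.x name for overall accuracy in this evaluator.
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("label")
      .setPredictionCol("prediction")
      .setMetricName("precision")
    println("Accuracy: " + evaluator.evaluate(predictions))

    sc.stop()
  }
}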
Example 11
Source File: NaiveBayes.scala    From Scala-and-Spark-for-Big-Data-Analytics   with MIT License
package com.chapter12.NaiveBayes

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineStage
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}



object NaiveBayesExample {
  def main(args: Array[String]): Unit = {    
    // Create the Spark session 
    val spark = SparkSession
      .builder
      .master("local[*]")
      .config("spark.sql.warehouse.dir", "E:/Exp/")
      .appName(s"OneVsRestExample")
      .getOrCreate()

    // Load the data stored in LIBSVM format as a DataFrame.
    val data = spark.read.format("libsvm").load("C:/Users/rezkar/Downloads/spark-2.1.0-bin-hadoop2.7/data/sample.data")

    // Split the data into training and test sets (30% held out for testing)
    val Array(trainingData, validationData) = data.randomSplit(Array(0.75, 0.25), seed = 12345L)

    // Train a NaiveBayes model.
    val nb = new NaiveBayes().setSmoothing(0.00001)        
    val model = nb.fit(trainingData)

    // Select example rows to display.
    val predictions = model.transform(validationData)
    predictions.show()

    // Obtain evaluators and compute classification performance metrics (accuracy, weighted precision, weighted recall and F1) on the test data.
    val evaluator = new BinaryClassificationEvaluator().setLabelCol("label").setMetricName("areaUnderROC")
    val evaluator1 = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("accuracy")
    val evaluator2 = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("weightedPrecision")
    val evaluator3 = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("weightedRecall")
    val evaluator4 = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("f1")

    // compute the classification accuracy, precision, recall, f1 measure and error on test data.
    val areaUnderROC = evaluator.evaluate(predictions)
    val accuracy = evaluator1.evaluate(predictions)
    val precision = evaluator2.evaluate(predictions)
    val recall = evaluator3.evaluate(predictions)
    val f1 = evaluator4.evaluate(predictions)
    
    // Print the performance metrics
    println("areaUnderROC = " + areaUnderROC)
    println("Accuracy = " + accuracy)
    println("Precision = " + precision)
    println("Recall = " + recall)
    println("F1 = " + f1)
    println(s"Test Error = ${1 - accuracy}")
    
    data.show(20)

    spark.stop()
  }
} 
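
Example 11 imports CrossValidator and ParamGridBuilder but never uses them. A hedged sketch of how the smoothing parameter could be tuned with those classes, reusing nb, evaluator1, trainingData and validationData from the example (the grid values are illustrative, not part of the original code):

val paramGrid = new ParamGridBuilder()
  .addGrid(nb.smoothing, Array(0.00001, 0.001, 0.1, 1.0))
  .build()

val cv = new CrossValidator()
  .setEstimator(nb)
  .setEvaluator(evaluator1)          // the accuracy evaluator defined above
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

val cvModel = cv.fit(trainingData)
println("Cross-validated accuracy = " + evaluator1.evaluate(cvModel.transform(validationData)))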
Example 12
Source File: NaiveBayesSuite.scala    From aardpfark   with Apache License 2.0
package com.ibm.aardpfark.spark.ml.classification

import com.ibm.aardpfark.pfa.ProbClassifierResult

import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.{Vector, Vectors}

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

class NaiveBayesSuite extends SparkClassifierPFASuiteBase[ProbClassifierResult] {

  import spark.implicits._

  val inputPath = "data/sample_multiclass_classification_data.txt"
  val dataset = spark.read.format("libsvm").load(inputPath)
  val multinomialData = dataset
    .as[(Double, Vector)]
    .map { case (label, vector) =>
      val nonZeroVector = Vectors.dense(vector.toArray.map(math.max(0.0, _)))
      (label, nonZeroVector)
    }.toDF("label", "features")

  val multinomialDataBinary = multinomialData.select(
    when(col("label") >= 1, 1.0).otherwise(0.0).alias("label"), col("features")
  )

  val bernoulliData = dataset
    .as[(Double, Vector)]
    .map { case (label, vector) =>
      val binaryData = vector.toArray.map {
        case e if e > 0.0 =>
          1.0
        case e if e <= 0.0 =>
          0.0
      }
      (label, Vectors.dense(binaryData))
    }.toDF("label", "features")

  val bernoulliDataBinary = bernoulliData.select(
    when(col("label") >= 1, 1.0).otherwise(0.0).alias("label"), col("features")
  )

  val clf = new NaiveBayes()

  override val sparkTransformer = clf.fit(multinomialData)
  val result = sparkTransformer.transform(multinomialData)
  override val input = withColumnAsArray(result, clf.getFeaturesCol).toJSON.collect()
  override val expectedOutput = result.select(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).map {
    case Row(p: Double, raw: Vector, pr: Vector) => (p, raw.toArray, pr.toArray)
  }.toDF(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).toJSON.collect()

  // Additional tests
  test("Multinomial model binary classification") {
    val sparkTransformer = clf.fit(multinomialDataBinary)
    val result = sparkTransformer.transform(multinomialDataBinary)
    val input = withColumnAsArray(result, clf.getFeaturesCol).toJSON.collect()
    val expectedOutput = result.select(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).map {
      case Row(p: Double, raw: Vector, pr: Vector) => (p, raw.toArray, pr.toArray)
    }.toDF(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

  test("Bernoulli model") {
    val sparkTransformer = clf.setModelType("bernoulli").fit(bernoulliData)
    val result = sparkTransformer.transform(bernoulliData)
    val input = withColumnAsArray(result, clf.getFeaturesCol).toJSON.collect()
    val expectedOutput = result.select(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).map {
      case Row(p: Double, raw: Vector, pr: Vector) => (p, raw.toArray, pr.toArray)
    }.toDF(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

  test("Bernoulli model binary classification") {
    val sparkTransformer = clf.setModelType("bernoulli").fit(bernoulliDataBinary)
    val result = sparkTransformer.transform(bernoulliDataBinary)
    val input = withColumnAsArray(result, clf.getFeaturesCol).toJSON.collect()
    val expectedOutput = result.select(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).map {
      case Row(p: Double, raw: Vector, pr: Vector) => (p, raw.toArray, pr.toArray)
    }.toDF(clf.getPredictionCol, clf.getRawPredictionCol, clf.getProbabilityCol).toJSON.collect()
    parityTest(sparkTransformer, input, expectedOutput)
  }

}