org.apache.spark.mllib.util.DataValidators Scala Examples

The following examples show how to use org.apache.spark.mllib.util.DataValidators. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.
Example 1
Source File: LogisticRegressionModel.scala (from the keystone project, Apache License 2.0; 5 votes)
package keystoneml.nodes.learning

import breeze.linalg.Vector
import org.apache.spark.mllib.classification.{LogisticRegressionModel => MLlibLRM}
import org.apache.spark.mllib.linalg.{Vector => MLlibVector}
import org.apache.spark.mllib.optimization.{SquaredL2Updater, LogisticGradient, LBFGS}
import org.apache.spark.mllib.regression.{GeneralizedLinearAlgorithm, LabeledPoint}
import org.apache.spark.mllib.util.DataValidators
import org.apache.spark.rdd.RDD
import keystoneml.utils.MLlibUtils.breezeVectorToMLlib
import keystoneml.workflow.{LabelEstimator, Transformer}

import scala.reflect.ClassTag


  /**
   * A [[GeneralizedLinearAlgorithm]] producing MLlib logistic regression models, optimized with
   * LBFGS over a [[LogisticGradient]] and a [[SquaredL2Updater]].
   *
   * @param numClasses       number of target classes; must be greater than 1
   * @param numFeaturesValue value assigned to this algorithm's `numFeatures`
   */
  private[this] class LogisticRegressionWithLBFGS(numClasses: Int, numFeaturesValue: Int)
      extends GeneralizedLinearAlgorithm[MLlibLRM] with Serializable {

    this.numFeatures = numFeaturesValue

    // Starts with the default gradient; it is swapped below for the multiclass case.
    override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)

    override protected val validators = multiLabelValidator :: Nil

    require(numClasses > 1)

    // One linear predictor fewer than the number of classes.
    numOfLinearPredictor = numClasses - 1
    if (numOfLinearPredictor > 1) {
      optimizer.setGradient(new LogisticGradient(numClasses))
    }

    /**
     * Builds the model from the trained parameters: the two-argument constructor when there is a
     * single linear predictor, otherwise the constructor that also takes the feature count and
     * the number of classes.
     */
    override protected def createModel(weights: MLlibVector, intercept: Double) =
      numOfLinearPredictor match {
        case 1 => new MLlibLRM(weights, intercept)
        case predictors => new MLlibLRM(weights, intercept, numFeatures, predictors + 1)
      }

    /**
     * Label validator that delegates to the multi-label check when more than one linear predictor
     * is configured, and to the binary-label check otherwise. The choice is made each time the
     * returned function is applied.
     */
    private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
      val check: RDD[LabeledPoint] => Boolean =
        if (numOfLinearPredictor > 1) DataValidators.multiLabelValidator(numOfLinearPredictor + 1)
        else DataValidators.binaryLabelValidator
      check(data)
    }
  }

  /**
   * Trains a logistic regression model with LBFGS.
   *
   * Labels are zipped with the inputs and turned into `LabeledPoint`s, then a
   * `LogisticRegressionWithLBFGS` trainer is run on them with data validation switched off,
   * using `numIters` iterations and `regParam` regularization taken from the enclosing scope.
   *
   * @param in     input feature vectors
   * @param labels integer labels, zipped element-wise with `in`
   * @return the trained MLlib model wrapped in a `LogisticRegressionModel`
   */
  override def fit(in: RDD[T], labels: RDD[Int]): LogisticRegressionModel[T] = {
    val trainingSet = labels.zip(in).map { case (label, features) =>
      LabeledPoint(label, breezeVectorToMLlib(features))
    }

    val lbfgsTrainer = new LogisticRegressionWithLBFGS(numClasses, numFeatures)
    lbfgsTrainer.setValidateData(false)
    lbfgsTrainer.optimizer.setNumIterations(numIters).setRegParam(regParam)

    new LogisticRegressionModel(lbfgsTrainer.run(trainingSet))
  }
}