org.apache.spark.ml.evaluation.BinaryClassificationEvaluator Java Examples

The following examples show how to use org.apache.spark.ml.evaluation.BinaryClassificationEvaluator. Each example notes the project and source file it was taken from.
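Before looking at the project examples, here is a minimal sketch of the evaluator's basic contract. The helper name areaUnderRoc and the predictions argument are illustrative assumptions; the column names shown are the evaluator's documented defaults, expected to be present in the output of a fitted binary classifier's transform().

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Hypothetical helper: `predictions` is assumed to be the output of a fitted
// binary classifier, e.g. logisticRegressionModel.transform(testData).
static double areaUnderRoc(Dataset<Row> predictions) {
  BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")                  // default label column
      .setRawPredictionCol("rawPrediction")  // default raw-prediction column
      .setMetricName("areaUnderROC");        // "areaUnderPR" is the other supported metric
  return evaluator.evaluate(predictions);
}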
Example #1
Source File: SparkClassificationModel.java    From ambiverse-nlu with Apache License 2.0
private Evaluator getEvaluator(TrainingSettings trainingSettings, Predictor predictor) {

    Evaluator evaluator = null;

    // Tree-based classifiers are scored with a multiclass evaluator on the
    // indexed label column produced earlier in the pipeline.
    if (predictor instanceof RandomForestClassifier
            || predictor instanceof GBTClassifier
            || predictor instanceof DecisionTreeClassifier) {
        evaluator = new MulticlassClassificationEvaluator()
                .setLabelCol("indexedLabel")
                .setPredictionCol("prediction")
                .setMetricName(trainingSettings.getMetricName());
    }

    // Logistic regression is binary, so use the binary evaluator with its
    // default label and rawPrediction columns.
    if (predictor instanceof LogisticRegression) {
        evaluator = new BinaryClassificationEvaluator()
                .setMetricName(trainingSettings.getMetricName());
    }

    // Falls through to null for any other predictor type.
    return evaluator;
}
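A point worth noting: the method returns null for any predictor type it does not recognize, and the metric name carried in TrainingSettings must be valid for whichever evaluator gets built, since the two evaluators accept disjoint metric sets. A quick sketch of the valid names in recent Spark versions (the variable names here are illustrative):

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;

Evaluator binary = new BinaryClassificationEvaluator()
    .setMetricName("areaUnderPR");  // or "areaUnderROC" (the default)
Evaluator multi = new MulticlassClassificationEvaluator()
    .setMetricName("accuracy");     // or "f1" (the default), "weightedPrecision", "weightedRecall"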
 
Example #2
Source File: JavaModelSelectionViaCrossValidationExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaModelSelectionViaCrossValidationExample")
    .getOrCreate();

  // $example on$
  // Prepare training documents, which are labeled.
  Dataset<Row> training = spark.createDataFrame(Arrays.asList(
    new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
    new JavaLabeledDocument(1L, "b d", 0.0),
new JavaLabeledDocument(2L, "spark f g h", 1.0),
    new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
    new JavaLabeledDocument(4L, "b spark who", 1.0),
    new JavaLabeledDocument(5L, "g d a y", 0.0),
    new JavaLabeledDocument(6L, "spark fly", 1.0),
    new JavaLabeledDocument(7L, "was mapreduce", 0.0),
    new JavaLabeledDocument(8L, "e spark program", 1.0),
    new JavaLabeledDocument(9L, "a e c l", 0.0),
    new JavaLabeledDocument(10L, "spark compile", 1.0),
    new JavaLabeledDocument(11L, "hadoop software", 0.0)
  ), JavaLabeledDocument.class);

  // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
  Tokenizer tokenizer = new Tokenizer()
    .setInputCol("text")
    .setOutputCol("words");
  HashingTF hashingTF = new HashingTF()
    .setNumFeatures(1000)
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");
  LogisticRegression lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.01);
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

  // We use a ParamGridBuilder to construct a grid of parameters to search over.
  // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
  // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
  ParamMap[] paramGrid = new ParamGridBuilder()
    .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
    .addGrid(lr.regParam(), new double[] {0.1, 0.01})
    .build();

  // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
  // This will allow us to jointly choose parameters for all Pipeline stages.
  // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
  // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
  // is areaUnderROC.
  CrossValidator cv = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(new BinaryClassificationEvaluator())
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(2);  // Use 3+ in practice

  // Run cross-validation, and choose the best set of parameters.
  CrossValidatorModel cvModel = cv.fit(training);

  // Prepare test documents, which are unlabeled.
  Dataset<Row> test = spark.createDataFrame(Arrays.asList(
    new JavaDocument(4L, "spark i j k"),
    new JavaDocument(5L, "l m n"),
    new JavaDocument(6L, "mapreduce spark"),
    new JavaDocument(7L, "apache hadoop")
  ), JavaDocument.class);

  // Make predictions on test documents. cvModel uses the best model found (lrModel).
  Dataset<Row> predictions = cvModel.transform(test);
  for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
    System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
      + ", prediction=" + r.get(3));
  }
  // $example off$

  spark.stop();
}
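If you want to see which parameter combination won, the fitted CrossValidatorModel exposes the averaged metrics and the refit best model. The following sketch is not part of the original example; it assumes it runs right after cv.fit(training) and that org.apache.spark.ml.PipelineModel is imported.

  // Average cross-validated metric (areaUnderROC here) for each ParamMap in the grid.
  ParamMap[] grid = cvModel.getEstimatorParamMaps();
  double[] avgMetrics = cvModel.avgMetrics();
  for (int i = 0; i < grid.length; i++) {
    System.out.println(grid[i] + " -> avg areaUnderROC = " + avgMetrics[i]);
  }

  // The best model is the whole pipeline, refit on the full training set.
  PipelineModel bestModel = (PipelineModel) cvModel.bestModel();
  System.out.println("Best model stages: " + bestModel.stages().length);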