org.apache.spark.mllib.evaluation.BinaryClassificationMetrics Java Examples

The following examples show how to use org.apache.spark.mllib.evaluation.BinaryClassificationMetrics. Each example is preceded by a line naming its source file, the project it comes from, and that project's license. You can vote examples up or down, follow the links above each example to the original project or source file, and check out related API usage in the sidebar.
Example #1
Source file: TrainingSparkRunner.java — from the ambiverse-nlu project (Apache License 2.0), 4 votes
/**
 * Computes binary-classification metrics for {@code predictions} and writes them as a text report
 * to {@code output + "binary_evaluation_" + trainingSettings.getClassificationMethod() + ".txt"},
 * deleting any existing file at that path first.
 *
 * <p>Each row's score is element 1 of its {@code rawPrediction} vector (presumably the
 * positive-class score — confirm for the classifier in use); its label is the {@code label}
 * column read as a {@code Double}.
 *
 * @param predictions      model output with {@code rawPrediction} (vector) and {@code label} columns
 * @param output           path prefix to which the report file name is appended
 * @param trainingSettings source of the classification method embedded in the file name
 * @throws IOException if the report file cannot be deleted, created, written, or closed
 */
private void binaryEvaluation(DataFrame predictions, String output, TrainingSettings trainingSettings) throws IOException {

    FileSystem fs = FileSystem.get(new Configuration());
    Path evalPath = new Path(output + "binary_evaluation_" + trainingSettings.getClassificationMethod() + ".txt");
    fs.delete(evalPath, true);

    // Write the report explicitly as UTF-8 instead of in the platform default charset.
    String charset = "UTF-8";

    // try-with-resources closes the stream even when a Spark job below fails. It also surfaces a
    // failed close (which can lose buffered data) instead of swallowing it as closeQuietly did.
    try (FSDataOutputStream fsdos = fs.create(evalPath)) {

        BinaryClassificationMetrics metrics = new BinaryClassificationMetrics(predictions
                .select("rawPrediction", "label")
                .javaRDD()
                .map((Row row) -> {
                    Vector vector = row.getAs("rawPrediction");
                    Double label = row.getAs("label");
                    return new Tuple2<Object, Object>(vector.apply(1), label);
                }).rdd());

        // Precision by threshold
        JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
        IOUtils.write("\nPrecision by threshold: " + precision.collect(), fsdos, charset);

        // Recall by threshold
        JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
        IOUtils.write("\nRecall by threshold: " + recall.collect(), fsdos, charset);

        // F Score by threshold
        JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
        IOUtils.write("\nF1 Score by threshold: " + f1Score.collect(), fsdos, charset);

        JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
        IOUtils.write("\nF2 Score by threshold: " + f2Score.collect(), fsdos, charset);

        // Precision-recall curve
        JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
        IOUtils.write("\nPrecision-recall curve: " + prc.collect(), fsdos, charset);

        // ROC Curve
        JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
        IOUtils.write("\nROC curve: " + roc.collect(), fsdos, charset);

        // AUPRC
        IOUtils.write("\nArea under precision-recall curve = " + metrics.areaUnderPR(), fsdos, charset);

        // AUROC
        IOUtils.write("\nArea under ROC = " + metrics.areaUnderROC(), fsdos, charset);

        fsdos.flush();
    }
}
 
Example #2
Source file: SparkMultiClassClassifier.java — from the mmtf-spark project (Apache License 2.0), 4 votes
/**
 * Trains the configured {@code predictor} on a random split of {@code data} and evaluates it on
 * the held-out part.
 *
 * <p>The dataset must contain at least the following two columns:
 * <ul>
 *   <li>the label column (named by the {@code label} field): the class labels</li>
 *   <li>{@code features}: the feature vector</li>
 * </ul>
 *
 * <p>Per-class train/test counts and a sample of predictions are printed to standard output.
 *
 * @param data labeled dataset to split, train on, and evaluate
 * @return insertion-ordered map of metric name to value; the "AUC" entry is present only when
 *         the label column has exactly two distinct values
 */
public Map<String,String> fit(Dataset<Row> data) {
	int classCount = (int) data.select(label).distinct().count();

	StringIndexerModel labelIndexer = new StringIndexer()
		.setInputCol(label)
		.setOutputCol("indexedLabel")
		.fit(data);

	// Split the data into training and test sets (a testFraction share is held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[] {1.0 - testFraction, testFraction}, seed);
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	String[] labels = labelIndexer.labels();

	System.out.println();
	System.out.println("Class\tTrain\tTest");
	for (String l : labels) {
		// Compare through a Column rather than a string-built SQL expression, so a label that
		// contains quotes or other SQL syntax cannot break or alter the filter.
		long trainCount = trainingData.filter(trainingData.col(label).equalTo(l)).count();
		long testCount = testData.filter(testData.col(label).equalTo(l)).count();
		System.out.println(l + "\t" + trainCount + "\t" + testCount);
	}

	// Set input columns
	predictor
		.setLabelCol("indexedLabel")
		.setFeaturesCol("features");

	// Convert indexed labels back to original labels.
	IndexToString labelConverter = new IndexToString()
		.setInputCol("prediction")
		.setOutputCol("predictedLabel")
		.setLabels(labels);

	// Chain the label indexer, the predictor, and the label converter in a Pipeline
	Pipeline pipeline = new Pipeline()
		.setStages(new PipelineStage[] {labelIndexer, predictor, labelConverter});

	// Train model. This also runs the indexers.
	PipelineModel model = pipeline.fit(trainingData);

	// Make predictions. Keep a handle on the cached Dataset itself so it can be unpersisted once
	// the metrics are computed; the renamed views derived from it below are different Datasets.
	Dataset<Row> cachedPredictions = model.transform(testData).cache();
	try {
		// Display some sample predictions
		System.out.println();
		System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());

		cachedPredictions.sample(false, 0.1, seed).show(25);

		// Evaluate against the indexed label: move the original label column aside and expose
		// "indexedLabel" under the label column name.
		Dataset<Row> predictions = cachedPredictions
			.withColumnRenamed(label, "stringLabel")
			.withColumnRenamed("indexedLabel", label);

		// collect metrics
		Dataset<Row> pred = predictions.select("prediction", label);
		Map<String,String> metrics = new LinkedHashMap<>();
		metrics.put("Method", predictor.getClass().getSimpleName());

		if (classCount == 2) {
			// NOTE(review): this scores the hard 0/1 "prediction" column, so the ROC curve has
			// effectively one operating point. A probability/rawPrediction score, if the
			// predictor produces one, would give a more informative AUC.
			BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred);
			metrics.put("AUC", Float.toString((float) b.areaUnderROC()));
		}

		MulticlassMetrics m = new MulticlassMetrics(pred);
		metrics.put("F", Float.toString((float) m.weightedFMeasure()));
		metrics.put("Accuracy", Float.toString((float) m.accuracy()));
		metrics.put("Precision", Float.toString((float) m.weightedPrecision()));
		metrics.put("Recall", Float.toString((float) m.weightedRecall()));
		metrics.put("False Positive Rate", Float.toString((float) m.weightedFalsePositiveRate()));
		metrics.put("True Positive Rate", Float.toString((float) m.weightedTruePositiveRate()));
		metrics.put("", "\nConfusion Matrix\n"
			+ Arrays.toString(labels) + "\n"
			+ m.confusionMatrix().toString());

		return metrics;
	} finally {
		// All metric values above are materialized as strings, so the cache is no longer needed.
		cachedPredictions.unpersist();
	}
}
 
Example #3
Source file: JavaSVMWithSGDExample.java — from the SparkDemo project (MIT License), 4 votes
/**
 * Trains a linear SVM with SGD on the sample LIBSVM data, prints the area under the ROC curve on
 * a held-out test set, then saves and reloads the model.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample");
  SparkContext sc = new SparkContext(conf);
  try {
    // $example on$
    String path = "data/mllib/sample_libsvm_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

    // Split initial RDD into two... [60% training data, 40% testing data].
    // randomSplit yields disjoint, complementary splits without the extra shuffle that
    // sample() + subtract() needs to derive the test set.
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];

    // Run training algorithm to build the model.
    int numIterations = 100;
    final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);

    // Clear the default threshold so predict() returns raw margins instead of 0/1 labels.
    model.clearThreshold();

    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        @Override
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double score = model.predict(p.features());
          return new Tuple2<Object, Object>(score, p.label());
        }
      }
    );

    // Get evaluation metrics.
    BinaryClassificationMetrics metrics =
      new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
    double auROC = metrics.areaUnderROC();

    System.out.println("Area under ROC = " + auROC);

    // Save and load model
    model.save(sc, "target/tmp/javaSVMWithSGDModel");
    SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel");
    // $example off$
  } finally {
    // Stop the context even if loading, training, or saving fails.
    sc.stop();
  }
}
 
Example #4
Source file: JavaBinaryClassificationMetricsExample.java — from the SparkDemo project (MIT License), 4 votes
/**
 * Trains a binary logistic regression model with L-BFGS on the sample binary-classification
 * data and prints threshold-based metrics, the PR and ROC curves, and their areas for a
 * held-out test set. It then saves and reloads the model.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
  SparkContext sc = new SparkContext(conf);
  try {
    // $example on$
    String path = "data/mllib/sample_binary_classification_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

    // Split initial RDD into two... [60% training data, 40% testing data].
    JavaRDD<LabeledPoint>[] splits =
      data.randomSplit(new double[]{0.6, 0.4}, 11L);
    JavaRDD<LabeledPoint> training = splits[0].cache();
    JavaRDD<LabeledPoint> test = splits[1];

    // Run training algorithm to build the model.
    final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
      .setNumClasses(2)
      .run(training.rdd());

    // Clear the prediction threshold so the model will return probabilities
    model.clearThreshold();

    // Compute raw scores on the test set.
    JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
      new Function<LabeledPoint, Tuple2<Object, Object>>() {
        @Override
        public Tuple2<Object, Object> call(LabeledPoint p) {
          Double prediction = model.predict(p.features());
          return new Tuple2<Object, Object>(prediction, p.label());
        }
      }
    );

    // Get evaluation metrics.
    BinaryClassificationMetrics metrics =
      new BinaryClassificationMetrics(predictionAndLabels.rdd());

    // Precision by threshold
    JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
    System.out.println("Precision by threshold: " + precision.collect());

    // Recall by threshold
    JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
    System.out.println("Recall by threshold: " + recall.collect());

    // F Score by threshold
    JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
    System.out.println("F1 Score by threshold: " + f1Score.collect());

    JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
    System.out.println("F2 Score by threshold: " + f2Score.collect());

    // Precision-recall curve
    JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
    System.out.println("Precision-recall curve: " + prc.collect());

    // ROC Curve
    JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
    System.out.println("ROC curve: " + roc.collect());

    // AUPRC
    System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());

    // AUROC
    System.out.println("Area under ROC = " + metrics.areaUnderROC());

    // Save and load model
    model.save(sc, "target/tmp/LogisticRegressionModel");
    LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel");
    // $example off$
  } finally {
    // The original never stopped the context; always release it, even on failure.
    sc.stop();
  }
}