org.apache.spark.mllib.classification.LogisticRegressionModel Java Examples

The following examples show how to use `org.apache.spark.mllib.classification.LogisticRegressionModel`. On the original page you can vote examples up or down, follow the links above each example to its original project or source file, and browse related API usage in the sidebar.
Example #1
Source File: LogisticRegressionModelInfoAdapter.java — from spark-transformers (Apache License 2.0), 6 votes
/**
 * Converts a trained Spark MLlib {@link LogisticRegressionModel} into an exportable
 * {@code LogisticRegressionModelInfo}.
 *
 * <p>Copies the weights, intercept, number of classes, number of features and decision
 * threshold from the Spark model. Declares {@code "features"} as the only input key and
 * {@code "prediction"} and {@code "probability"} as output keys. Both key sets are
 * {@link LinkedHashSet}s, so insertion order is kept.
 *
 * @param sparkLRModel the trained Spark model to convert
 * @return a new, fully populated model info
 * @throws IllegalArgumentException if the model has no threshold (for example after
 *     {@code clearThreshold()}), because there is then no threshold value to export
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel) {
    // getThreshold() returns a scala.Option that is empty once the threshold has been cleared.
    // Calling get() on an empty Option throws an uninformative NoSuchElementException ("None.get"),
    // so fail early with a message that tells the caller what to do.
    if (!sparkLRModel.getThreshold().isDefined()) {
        throw new IllegalArgumentException(
                "Cannot export LogisticRegressionModel without a threshold; "
                        + "call setThreshold(...) on the model before exporting it");
    }

    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    final Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    final Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
 
Example #2
Source File: LogisticRegressionExporterTest.java — from spark-transformers (Apache License 2.0), 6 votes
/**
 * Trains a model with SGD, runs it through the ModelExporter/ModelImporter round trip, and
 * checks that the imported model info matches the Spark model field by field.
 */
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    // Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel);

    // Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Compare the scalar fields. The three-argument assertEquals widens the integer counts to
    // double, and the tolerance absorbs floating-point serialization noise.
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), 0.01);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), 0.01);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), 0.01);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), 0.01);

    // Compare the weights element-wise. Check the lengths first and iterate over the Spark
    // model's weights (not the imported numFeatures, which is itself under test). A truncated
    // or padded imported array then fails with a clear assertion instead of being partially
    // checked or throwing an index error.
    double[] expectedWeights = lrmodel.weights().toArray();
    double[] actualWeights = importedModel.getWeights();
    assertEquals(expectedWeights.length, actualWeights.length);
    for (int i = 0; i < expectedWeights.length; i++) {
        assertEquals(expectedWeights[i], actualWeights[i], 0.01);
    }
}
 
Example #3
Source File: LogisticRegressionModelInfoAdapter.java — from spark-transformers (Apache License 2.0), 6 votes
/**
 * Converts a trained Spark MLlib {@link LogisticRegressionModel} into an exportable
 * {@code LogisticRegressionModelInfo}.
 *
 * <p>Copies the weights, intercept, number of classes, number of features and decision
 * threshold from the Spark model. Declares {@code "features"} as the only input key and
 * {@code "prediction"} and {@code "probability"} as output keys. Both key sets are
 * {@link LinkedHashSet}s, so insertion order is kept.
 *
 * @param sparkLRModel the trained Spark model to convert
 * @param df not used by this implementation
 * @return a new, fully populated model info
 * @throws IllegalArgumentException if the model has no threshold (for example after
 *     {@code clearThreshold()}), because there is then no threshold value to export
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    // getThreshold() returns a scala.Option that is empty once the threshold has been cleared.
    // Calling get() on an empty Option throws an uninformative NoSuchElementException ("None.get"),
    // so fail early with a message that tells the caller what to do.
    if (!sparkLRModel.getThreshold().isDefined()) {
        throw new IllegalArgumentException(
                "Cannot export LogisticRegressionModel without a threshold; "
                        + "call setThreshold(...) on the model before exporting it");
    }

    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    final Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    final Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
 
Example #4
Source File: LogisticRegressionExporterTest.java — from spark-transformers (Apache License 2.0), 6 votes
/**
 * Trains a model with SGD, runs it through the ModelExporter/ModelImporter round trip, and
 * checks that the imported model info matches the Spark model field by field, within
 * {@code EPSILON}.
 */
@Test
public void shouldExportAndImportCorrectly() {
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    // Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(data.rdd());

    // Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, null);

    // Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // Compare the scalar fields. The three-argument assertEquals widens the integer counts to
    // double, and EPSILON absorbs floating-point serialization noise.
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals((double) lrmodel.getThreshold().get(), importedModel.getThreshold(), EPSILON);

    // Compare the weights element-wise. Check the lengths first and iterate over the Spark
    // model's weights (not the imported numFeatures, which is itself under test). A truncated
    // or padded imported array then fails with a clear assertion instead of being partially
    // checked or throwing an index error.
    double[] expectedWeights = lrmodel.weights().toArray();
    double[] actualWeights = importedModel.getWeights();
    assertEquals(expectedWeights.length, actualWeights.length);
    for (int i = 0; i < expectedWeights.length; i++) {
        assertEquals(expectedWeights[i], actualWeights[i], EPSILON);
    }
}
 
Example #5
Source File: JavaLogisticRegressionWithLBFGSExample.java — from SparkDemo (MIT License), 5 votes
/**
 * Trains a logistic regression model with L-BFGS (configured for 10 classes) on 60% of the
 * sample LIBSVM data and prints its accuracy on the remaining 40%. It then saves the model,
 * loads it back, and stops the SparkContext.
 */
public static void main(String[] args) {
  SparkConf sparkConf = new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample");
  SparkContext sc = new SparkContext(sparkConf);
  // $example on$
  String path = "data/mllib/sample_libsvm_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

  // Hold out 40% of the points for evaluation; the fixed seed keeps the split reproducible.
  double[] splitWeights = {0.6, 0.4};
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(splitWeights, 11L);
  JavaRDD<LabeledPoint> training = splits[0].cache();
  JavaRDD<LabeledPoint> test = splits[1];

  // Configure the trainer, then fit it on the cached training split.
  LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
  trainer.setNumClasses(10);
  final LogisticRegressionModel model = trainer.run(training.rdd());

  // Pair each test point's predicted label with its true label.
  Function<LabeledPoint, Tuple2<Object, Object>> toPredictionAndLabel =
    new Function<LabeledPoint, Tuple2<Object, Object>>() {
      public Tuple2<Object, Object> call(LabeledPoint point) {
        double prediction = model.predict(point.features());
        return new Tuple2<Object, Object>(prediction, point.label());
      }
    };
  JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(toPredictionAndLabel);

  // Summarize how often the predictions match the labels.
  MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
  double accuracy = metrics.accuracy();
  System.out.println("Accuracy = " + accuracy);

  // Persist the model, then read it back from the same location.
  String modelPath = "target/tmp/javaLogisticRegressionWithLBFGSModel";
  model.save(sc, modelPath);
  LogisticRegressionModel.load(sc, modelPath);
  // $example off$

  sc.stop();
}
 
Example #6
Source File: LogisticRegressionBridgeTest.java — from spark-transformers (Apache License 2.0), 5 votes
/**
 * Checks that the Transformer obtained by exporting and re-importing a Spark
 * LogisticRegressionModel predicts the same label as the Spark model for every training point.
 */
@Test
public void testLogisticRegression() {
    // Load the training data.
    final String datapath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();

    // Train in Spark, then push the model through the export/import round trip.
    final LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());
    final byte[] exportedModel = ModelExporter.export(lrmodel);
    final Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // Both implementations must agree on each point.
    final List<LabeledPoint> testPoints = trainingData.collect();
    for (final LabeledPoint point : testPoints) {
        final Vector features = point.features();
        final double expected = lrmodel.predict(features);

        // The transformer reads "features" from the row and writes "prediction" back into it.
        final Map<String, Object> row = new HashMap<String, Object>();
        row.put("features", features.toArray());
        transformer.transform(row);
        final double predicted = (double) row.get("prediction");

        assertEquals(expected, predicted, 0.01);
    }
}
 
Example #7
Source File: LogisticRegressionBridgeTest.java — from spark-transformers (Apache License 2.0), 5 votes
/**
 * Checks that the Transformer obtained by exporting and re-importing a Spark
 * LogisticRegressionModel predicts the same label as the Spark model for every training point,
 * within {@code EPSILON}.
 */
@Test
public void testLogisticRegression() {
    // Load the training data.
    final String datapath = "src/test/resources/binary_classification_test.libsvm";
    final JavaRDD<LabeledPoint> trainingData = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

    // Train in Spark, then push the model through the export/import round trip.
    final LogisticRegressionModel lrmodel = new LogisticRegressionWithSGD().run(trainingData.rdd());
    final byte[] exportedModel = ModelExporter.export(lrmodel, null);
    final Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // Both implementations must agree on each point.
    final List<LabeledPoint> testPoints = trainingData.collect();
    for (final LabeledPoint point : testPoints) {
        final Vector features = point.features();
        final double expected = lrmodel.predict(features);

        // The transformer reads "features" from the row and writes "prediction" back into it.
        final Map<String, Object> row = new HashMap<String, Object>();
        row.put("features", features.toArray());
        transformer.transform(row);
        final double predicted = (double) row.get("prediction");

        assertEquals(expected, predicted, EPSILON);
    }
}
 
Example #8
Source File: JavaMulticlassClassificationMetricsExample.java — from SparkDemo (MIT License), 4 votes
/**
 * Trains a 3-class logistic regression model with L-BFGS on a 60/40 split of the sample
 * multiclass data. It prints the confusion matrix, accuracy, per-label and weighted metrics
 * for the held-out split. It then saves and reloads the model and stops the SparkContext.
 */
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("Multi class Classification Metrics Example");
  SparkContext sc = new SparkContext(conf);
  // $example on$
  String path = "data/mllib/sample_multiclass_classification_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

  // Split initial RDD into two... [60% training data, 40% testing data].
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.6, 0.4}, 11L);
  JavaRDD<LabeledPoint> training = splits[0].cache();
  JavaRDD<LabeledPoint> test = splits[1];

  // Run training algorithm to build the model.
  final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
    .setNumClasses(3)
    .run(training.rdd());

  // Compute (predicted label, true label) pairs on the test set.
  JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
    new Function<LabeledPoint, Tuple2<Object, Object>>() {
      public Tuple2<Object, Object> call(LabeledPoint p) {
        Double prediction = model.predict(p.features());
        return new Tuple2<Object, Object>(prediction, p.label());
      }
    }
  );

  // Get evaluation metrics.
  MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());

  // Confusion matrix
  Matrix confusion = metrics.confusionMatrix();
  System.out.println("Confusion matrix: \n" + confusion);

  // Overall statistics
  System.out.println("Accuracy = " + metrics.accuracy());

  // Stats by labels. Fetch the label array once rather than on every access.
  double[] labels = metrics.labels();
  for (double label : labels) {
    System.out.format("Class %f precision = %f\n", label, metrics.precision(label));
    System.out.format("Class %f recall = %f\n", label, metrics.recall(label));
    System.out.format("Class %f F1 score = %f\n", label, metrics.fMeasure(label));
  }

  //Weighted stats
  System.out.format("Weighted precision = %f\n", metrics.weightedPrecision());
  System.out.format("Weighted recall = %f\n", metrics.weightedRecall());
  System.out.format("Weighted F1 score = %f\n", metrics.weightedFMeasure());
  System.out.format("Weighted false positive rate = %f\n", metrics.weightedFalsePositiveRate());

  // Save and load model
  model.save(sc, "target/tmp/LogisticRegressionModel");
  LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel");
  // $example off$

  // Release the resources held by the SparkContext.
  sc.stop();
}
 
Example #9
Source File: JavaBinaryClassificationMetricsExample.java — from SparkDemo (MIT License), 4 votes
/**
 * Trains a binary logistic regression model with L-BFGS on a 60/40 split of the sample binary
 * classification data and clears its threshold so predictions are probabilities. It prints
 * threshold-based and curve-based evaluation metrics for the held-out split, then saves and
 * reloads the model and stops the SparkContext.
 */
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("Java Binary Classification Metrics Example");
  SparkContext sc = new SparkContext(conf);
  // $example on$
  String path = "data/mllib/sample_binary_classification_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

  // Split initial RDD into two... [60% training data, 40% testing data].
  JavaRDD<LabeledPoint>[] splits =
    data.randomSplit(new double[]{0.6, 0.4}, 11L);
  JavaRDD<LabeledPoint> training = splits[0].cache();
  JavaRDD<LabeledPoint> test = splits[1];

  // Run training algorithm to build the model.
  final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
    .setNumClasses(2)
    .run(training.rdd());

  // Clear the prediction threshold so the model will return probabilities
  model.clearThreshold();

  // Compute raw scores on the test set.
  JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
    new Function<LabeledPoint, Tuple2<Object, Object>>() {
      @Override
      public Tuple2<Object, Object> call(LabeledPoint p) {
        Double prediction = model.predict(p.features());
        return new Tuple2<Object, Object>(prediction, p.label());
      }
    }
  );

  // Get evaluation metrics.
  BinaryClassificationMetrics metrics =
    new BinaryClassificationMetrics(predictionAndLabels.rdd());

  // Precision by threshold
  JavaRDD<Tuple2<Object, Object>> precision = metrics.precisionByThreshold().toJavaRDD();
  System.out.println("Precision by threshold: " + precision.collect());

  // Recall by threshold
  JavaRDD<Tuple2<Object, Object>> recall = metrics.recallByThreshold().toJavaRDD();
  System.out.println("Recall by threshold: " + recall.collect());

  // F Score by threshold
  JavaRDD<Tuple2<Object, Object>> f1Score = metrics.fMeasureByThreshold().toJavaRDD();
  System.out.println("F1 Score by threshold: " + f1Score.collect());

  // F2 Score by threshold (beta = 2.0)
  JavaRDD<Tuple2<Object, Object>> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD();
  System.out.println("F2 Score by threshold: " + f2Score.collect());

  // Precision-recall curve
  JavaRDD<Tuple2<Object, Object>> prc = metrics.pr().toJavaRDD();
  System.out.println("Precision-recall curve: " + prc.collect());

  // Thresholds: the first element of each (threshold, precision) pair. This RDD is only
  // defined for illustration; no action is run on it, so the map is never evaluated.
  JavaRDD<Double> thresholds = precision.map(
    new Function<Tuple2<Object, Object>, Double>() {
      @Override
      public Double call(Tuple2<Object, Object> t) {
        // Double.valueOf replaces the deprecated Double(String) constructor.
        return Double.valueOf(t._1().toString());
      }
    }
  );

  // ROC Curve
  JavaRDD<Tuple2<Object, Object>> roc = metrics.roc().toJavaRDD();
  System.out.println("ROC curve: " + roc.collect());

  // AUPRC
  System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR());

  // AUROC
  System.out.println("Area under ROC = " + metrics.areaUnderROC());

  // Save and load model
  model.save(sc, "target/tmp/LogisticRegressionModel");
  LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel");
  // $example off$

  // Release the resources held by the SparkContext.
  sc.stop();
}
 
Example #10
Source File: LogisticRegressionModelInfoAdapter.java — from spark-transformers (Apache License 2.0), 4 votes
/**
 * Identifies the Spark model type that this adapter converts.
 *
 * @return the {@code LogisticRegressionModel} class object
 */
@Override
public Class<LogisticRegressionModel> getSource() {
    final Class<LogisticRegressionModel> sourceType = LogisticRegressionModel.class;
    return sourceType;
}
 
Example #11
Source File: LogisticRegressionModelInfoAdapter.java — from spark-transformers (Apache License 2.0), 4 votes
/**
 * Identifies the Spark model type that this adapter converts.
 *
 * @return the {@code LogisticRegressionModel} class object
 */
@Override
public Class<LogisticRegressionModel> getSource() {
    final Class<LogisticRegressionModel> sourceType = LogisticRegressionModel.class;
    return sourceType;
}