package com.flipkart.fdp.ml.adapter; import com.flipkart.fdp.ml.export.ModelExporter; import com.flipkart.fdp.ml.importer.ModelImporter; import com.flipkart.fdp.ml.transformer.Transformer; import org.apache.spark.ml.Pipeline; import org.apache.spark.ml.PipelineModel; import org.apache.spark.ml.PipelineStage; import org.apache.spark.ml.classification.RandomForestClassificationModel; import org.apache.spark.ml.classification.RandomForestClassifier; import org.apache.spark.ml.feature.StringIndexer; import org.apache.spark.ml.feature.StringIndexerModel; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; import org.junit.Test; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; /** * Created by akshay.us on 8/29/16. */ public class RandomForestClassificationModelInfoAdapterBridgeTest extends SparkTestBase { @Test public void testRandomForestClassification() { // Load the data stored in LIBSVM format as a DataFrame. DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm"); StringIndexerModel stringIndexerModel = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") .fit(data); data = stringIndexerModel.transform(data); // Split the data into training and test sets (30% held out for testing) DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); DataFrame trainingData = splits[0]; DataFrame testData = splits[1]; // Train a RandomForest model. RandomForestClassificationModel classificationModel = new RandomForestClassifier() .setLabelCol("labelIndex") .setFeaturesCol("features") .setPredictionCol("prediction") .setRawPredictionCol("rawPrediction") .setProbabilityCol("probability") .fit(trainingData); byte[] exportedModel = ModelExporter.export(classificationModel, null); Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel); Row[] sparkOutput = classificationModel.transform(testData).select("features", "prediction", "rawPrediction", "probability").collect(); //compare predictions for (Row row : sparkOutput) { Vector v = (Vector) row.get(0); double actual = row.getDouble(1); double [] actualProbability = ((Vector) row.get(3)).toArray(); double[] actualRaw = ((Vector) row.get(2)).toArray(); Map<String, Object> inputData = new HashMap<String, Object>(); inputData.put(transformer.getInputKeys().iterator().next(), v.toArray()); transformer.transform(inputData); double predicted = (double) inputData.get("prediction"); double[] probability = (double[]) inputData.get("probability"); double[] rawPrediction = (double[]) inputData.get("rawPrediction"); assertEquals(actual, predicted, EPSILON); assertArrayEquals(actualProbability, probability, EPSILON); assertArrayEquals(actualRaw, rawPrediction, EPSILON); } } @Test public void testRandomForestClassificationWithPipeline() { // Load the data stored in LIBSVM format as a DataFrame. DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/classification_test.libsvm"); // Split the data into training and test sets (30% held out for testing) DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3}); DataFrame trainingData = splits[0]; DataFrame testData = splits[1]; StringIndexer indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex"); // Train a DecisionTree model. RandomForestClassifier classifier = new RandomForestClassifier() .setLabelCol("labelIndex") .setFeaturesCol("features") .setPredictionCol("prediction") .setRawPredictionCol("rawPrediction") .setProbabilityCol("probability"); Pipeline pipeline = new Pipeline() .setStages(new PipelineStage[]{indexer, classifier}); // Train model. This also runs the indexer. PipelineModel sparkPipeline = pipeline.fit(trainingData); //Export this model byte[] exportedModel = ModelExporter.export(sparkPipeline, null); //Import and get Transformer Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel); Row[] sparkOutput = sparkPipeline.transform(testData).select("label", "features", "prediction", "rawPrediction", "probability").collect(); //compare predictions for (Row row : sparkOutput) { Vector v = (Vector) row.get(1); double actual = row.getDouble(2); double [] actualProbability = ((Vector) row.get(4)).toArray(); double[] actualRaw = ((Vector) row.get(3)).toArray(); Map<String, Object> inputData = new HashMap<String, Object>(); inputData.put("features", v.toArray()); inputData.put("label", row.get(0).toString()); transformer.transform(inputData); double predicted = (double) inputData.get("prediction"); double[] probability = (double[]) inputData.get("probability"); double[] rawPrediction = (double[]) inputData.get("rawPrediction"); assertEquals(actual, predicted, EPSILON); assertArrayEquals(actualProbability, probability, EPSILON); assertArrayEquals(actualRaw, rawPrediction, EPSILON); } } }