package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.modelinfo.StringIndexerModelInfo;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.types.DataTypes.*;
import static org.junit.Assert.assertEquals;

/**
 * Round-trip tests for {@link StringIndexerModel}: train in Spark, export via
 * {@link ModelExporter}, re-import via {@link ModelImporter}, and verify that the
 * standalone {@link Transformer} reproduces Spark's own predictions row by row.
 *
 * Created by akshay.us on 3/2/16.
 */
public class StringIndexerBridgeTest extends SparkTestBase {

    /** Builds the (id: int, label: double) training dataset shared by the double-column tests. */
    private DataFrame doubleLabelDataset() {
        StructType schema = createStructType(new StructField[]{
                createStructField("id", IntegerType, false),
                createStructField("label", DoubleType, false)
        });
        List<Row> trainingData = Arrays.asList(
                cr(0, 1.0), cr(1, 2.0), cr(2, 3.0),
                cr(3, 1.0), cr(4, 1.0), cr(5, 3.0));
        return sqlContext.createDataFrame(trainingData, schema);
    }

    /** Fits a StringIndexer that reads column "label" and writes indices to "labelIndex". */
    private StringIndexerModel trainModel(DataFrame dataset) {
        return new StringIndexer()
                .setInputCol("label")
                .setOutputCol("labelIndex")
                .fit(dataset);
    }

    /**
     * Exports and re-imports a model trained on a string label column, then checks
     * the imported transformer agrees with Spark's transform on every training row.
     */
    @Test
    public void testStringIndexer() {
        //prepare data with a string label column
        StructType schema = createStructType(new StructField[]{
                createStructField("id", IntegerType, false),
                createStructField("label", StringType, false)
        });
        List<Row> trainingData = Arrays.asList(
                cr(0, "a"), cr(1, "b"), cr(2, "c"),
                cr(3, "a"), cr(4, "a"), cr(5, "c"));
        DataFrame dataset = sqlContext.createDataFrame(trainingData, schema);

        //train model in spark
        StringIndexerModel model = trainModel(dataset);

        //Export this model
        byte[] exportedModel = ModelExporter.export(model, dataset);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //compare predictions
        Row[] sparkOutput = model.transform(dataset).orderBy("id")
                .select("id", "label", "labelIndex").collect();
        for (Row row : sparkOutput) {
            Map<String, Object> data = new HashMap<>();
            data.put(model.getInputCol(), row.getString(1));
            transformer.transform(data);
            double indexerOutput = (double) data.get(model.getOutputCol());
            // JUnit contract: expected first, then actual, then delta
            assertEquals(row.getDouble(2), indexerOutput, EPSILON);
        }
    }

    /**
     * Same round trip as {@link #testStringIndexer()} but with a numeric (double)
     * label column, which StringIndexer indexes via its string representation.
     */
    @Test
    public void testStringIndexerForDoubleColumn() {
        //prepare data and train model in spark
        DataFrame dataset = doubleLabelDataset();
        StringIndexerModel model = trainModel(dataset);

        //Export this model
        byte[] exportedModel = ModelExporter.export(model, dataset);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //compare predictions
        Row[] sparkOutput = model.transform(dataset).orderBy("id")
                .select("id", "label", "labelIndex").collect();
        for (Row row : sparkOutput) {
            Map<String, Object> data = new HashMap<>();
            data.put(model.getInputCol(), row.getDouble(1));
            transformer.transform(data);
            double indexerOutput = (double) data.get(model.getOutputCol());
            assertEquals(row.getDouble(2), indexerOutput, EPSILON);
        }
    }

    /**
     * With the default fail-on-unseen behavior, transforming a label that never
     * appeared in training must throw a RuntimeException.
     */
    @Test(expected = RuntimeException.class)
    public void testStringIndexerForUnseenValues() {
        //prepare data and train model in spark
        DataFrame dataset = doubleLabelDataset();
        StringIndexerModel model = trainModel(dataset);

        //Export this model
        byte[] exportedModel = ModelExporter.export(model, dataset);

        //Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        //unseen value 7.0 — transform must throw
        Map<String, Object> data = new HashMap<>();
        data.put(model.getInputCol(), 7.0);
        transformer.transform(data);
    }

    /**
     * With fail-on-unseen disabled, every unseen label is mapped to index 3.0
     * (one past the three trained labels), while seen labels keep their trained index.
     */
    @Test
    public void testStringIndexerForHandlingUnseenValues() {
        //prepare data and train model in spark
        DataFrame dataset = doubleLabelDataset();
        StringIndexerModel model = trainModel(dataset);

        //Export this model and disable failing on unseen values
        byte[] exportedModel = ModelExporter.export(model, dataset);
        StringIndexerModelInfo stringIndexerModelInfo =
                (StringIndexerModelInfo) ModelImporter.importModelInfo(exportedModel);
        stringIndexerModelInfo.setFailOnUnseenValues(false);

        //Import and get Transformer
        Transformer transformer = stringIndexerModelInfo.getTransformer();

        //unseen values all map to index 3.0
        assertUnseenMapsToDefault(transformer, model, 7.0);
        assertUnseenMapsToDefault(transformer, model, 9.0);
        assertUnseenMapsToDefault(transformer, model, 0.0);

        //seen value keeps the index recorded in the exported label map
        Map<String, Object> data = new HashMap<>();
        data.put(model.getInputCol(), 2.0);
        transformer.transform(data);
        double indexerOutput = (double) data.get(model.getOutputCol());
        assertEquals(stringIndexerModelInfo.getLabelToIndex().get("2.0"), indexerOutput, EPSILON);
    }

    /** Asserts that an unseen label value is indexed to the fallback index 3.0 rather than failing. */
    private void assertUnseenMapsToDefault(Transformer transformer, StringIndexerModel model, double unseen) {
        Map<String, Object> data = new HashMap<>();
        data.put(model.getInputCol(), unseen);
        transformer.transform(data);
        assertEquals(3.0, (double) data.get(model.getOutputCol()), EPSILON);
    }
}