package com.flipkart.fdp.ml.adapter; import com.flipkart.fdp.ml.export.ModelExporter; import com.flipkart.fdp.ml.importer.ModelImporter; import com.flipkart.fdp.ml.transformer.Transformer; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.ml.feature.MinMaxScaler; import org.apache.spark.ml.feature.MinMaxScalerModel; import org.apache.spark.ml.linalg.Vector; import org.apache.spark.ml.linalg.VectorUDT; import org.apache.spark.ml.linalg.Vectors; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.junit.Test; import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import static org.junit.Assert.assertArrayEquals; public class MinMaxScalerBridgeTest extends SparkTestBase { private final double data[][] = {{1, 0, Long.MIN_VALUE}, {2, 0, 0}, {3, 0, Long.MAX_VALUE}, {1.0, 0, 0}}; private final double expected[][] = {{-5, 0, -5}, {0, 0, 0}, {5, 0, 5}, {-2.5, 0, 0}}; @Test public void testMinMaxScaler() { //prepare data JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList( RowFactory.create(1.0, Vectors.dense(data[0])), RowFactory.create(2.0, Vectors.dense(data[1])), RowFactory.create(3.0, Vectors.dense(data[2])), RowFactory.create(4.0, Vectors.dense(data[3])) )); StructType schema = new StructType(new StructField[]{ new StructField("label", DataTypes.DoubleType, false, Metadata.empty()), new StructField("features", new VectorUDT(), false, Metadata.empty()) }); Dataset<Row> df = spark.createDataFrame(jrdd, schema); //train model in spark MinMaxScalerModel sparkModel = new MinMaxScaler() .setInputCol("features") .setOutputCol("scaled") .setMin(-5) .setMax(5) .fit(df); //Export model, import it back and get transformer byte[] exportedModel = ModelExporter.export(sparkModel); final Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel); //compare predictions List<Row> sparkOutput = sparkModel.transform(df).orderBy("label").select("features", "scaled").collectAsList(); assertCorrectness(sparkOutput, expected, transformer); } private void assertCorrectness(List<Row> sparkOutput, double[][] expected, Transformer transformer) { for (int i = 0; i < 3; i++) { double[] input = ((Vector) sparkOutput.get(i).get(0)).toArray(); Map<String, Object> data = new HashMap<String, Object>(); data.put("features", input); transformer.transform(data); double[] transformedOp = (double[]) data.get("scaled"); double[] sparkOp = ((Vector) sparkOutput.get(i).get(1)).toArray(); assertArrayEquals(transformedOp, sparkOp, 0.01); assertArrayEquals(transformedOp, expected[i], 0.01); } } }