package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.VectorBinarizer;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertArrayEquals;

/**
 * Verifies that a Spark {@link VectorBinarizer} exported with {@link ModelExporter}
 * and re-imported with {@link ModelImporter} produces the same output as the
 * original Spark transformer, for both dense and sparse input vectors.
 *
 * Created by karan.verma on 09/11/16.
 */
public class VectorBinarizerBridgeTest extends SparkTestBase {

    @Test(expected = IllegalArgumentException.class)
    public void testVectorBinarizerNegativeThresholdValue() {
        // A negative threshold is invalid and should be rejected when it is set
        VectorBinarizer vectorBinarizer = new VectorBinarizer()
                .setInputCol("vector1")
                .setOutputCol("binarized")
                .setThreshold(-1d);
    }

    @Test
    public void testVectorBinarizerDense() {
        // prepare data
        JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
                RowFactory.create(0d, 1d, new DenseVector(new double[]{-2d, -3d, -4d, -1d, 6d, -7d, 8d, 0d, 0d, 0d, 0d, 0d})),
                RowFactory.create(1d, 2d, new DenseVector(new double[]{4d, -5d, 6d, 7d, -8d, 9d, -10d, 0d, 0d, 0d, 0d, 0d})),
                RowFactory.create(2d, 3d, new DenseVector(new double[]{-5d, 6d, -8d, 9d, 10d, 11d, 12d, 0d, 0d, 0d, 0d, 0d}))
        ));

        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("vector1", new VectorUDT(), false, Metadata.empty())
        });
        DataFrame df = sqlContext.createDataFrame(jrdd, schema);

        VectorBinarizer vectorBinarizer = new VectorBinarizer()
                .setInputCol("vector1")
                .setOutputCol("binarized")
                .setThreshold(2d);

        // Export this model
        byte[] exportedModel = ModelExporter.export(vectorBinarizer, df);

        // Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        // Compare the imported transformer's predictions against Spark's output
        Row[] sparkOutput = vectorBinarizer.transform(df).orderBy("id")
                .select("id", "value1", "vector1", "binarized").collect();
        for (Row row : sparkOutput) {
            Map<String, Object> data = new HashMap<>();
            data.put(vectorBinarizer.getInputCol(), ((DenseVector) row.get(2)).toArray());
            transformer.transform(data);
            double[] output = (double[]) data.get(vectorBinarizer.getOutputCol());

            assertArrayEquals(output, ((DenseVector) row.get(3)).toArray(), 0d);
        }
    }

    @Test
    public void testVectorBinarizerSparse() {
        // prepare data
        int[] sparseArray1 = {5, 6, 11, 4, 7, 9, 8, 14, 13};
        double[] sparseArray1Values = {-5d, 7d, 1d, -2d, -4d, -1d, 31d, -1d, -3d};

        int[] sparseArray2 = {2, 6, 1};
        double[] sparseArray2Values = {1d, 11d, 2d};

        int[] sparseArray3 = {4, 6, 1};
        double[] sparseArray3Values = {52d, 71d, 11d};

        int[] sparseArray4 = {4, 1, 2};
        double[] sparseArray4Values = {17d, 7d, 9d};

        JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
                RowFactory.create(3d, 4d, new SparseVector(20, sparseArray1, sparseArray1Values)),
                RowFactory.create(4d, 5d, new SparseVector(20, sparseArray2, sparseArray2Values)),
                RowFactory.create(5d, 5d, new SparseVector(20, sparseArray3, sparseArray3Values)),
                RowFactory.create(6d, 5d, new SparseVector(20, sparseArray4, sparseArray4Values))
        ));

        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
                new StructField("vector1", new VectorUDT(), false, Metadata.empty())
        });
        DataFrame df = sqlContext.createDataFrame(jrdd, schema);

        VectorBinarizer vectorBinarizer = new VectorBinarizer()
                .setInputCol("vector1")
                .setOutputCol("binarized");

        // Export this model
        byte[] exportedModel = ModelExporter.export(vectorBinarizer, null);

        // Import and get Transformer
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        // Compare the imported transformer's predictions against Spark's output
        Row[] sparkOutput = vectorBinarizer.transform(df).orderBy("id")
                .select("id", "value1", "vector1", "binarized").collect();
        for (Row row : sparkOutput) {
            Map<String, Object> data = new HashMap<>();
            data.put(vectorBinarizer.getInputCol(), ((SparseVector) row.get(2)).toArray());
            transformer.transform(data);
            double[] output = (double[]) data.get(vectorBinarizer.getOutputCol());

            assertArrayEquals(output, ((SparseVector) row.get(3)).toArray(), 0d);
        }
    }
}