package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import com.flipkart.transformer.ml.CommonAddressFeatures;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.*;
import org.junit.Test;

import java.util.*;

import static org.junit.Assert.assertEquals;

/**
 * Round-trip bridge test for {@link CommonAddressFeatures}: configures the Spark-side
 * transformer, exports it via {@link ModelExporter}, re-imports it as a
 * framework-independent {@link Transformer}, and verifies that both produce identical
 * derived address features for every input row.
 */
public class CommonAddressFeaturesBridgeTest extends SparkTestBase {

    @Test
    public void testCommonAddressFeatures() {
        final String[] addressLine1 = new String[]{
                "Jyoti complex near Sananda clothes store; English Bazar; Malda;WB;India",
                "hallalli vinayaka tent road c/o B K vishwanath Mandya",
                "M.sathish S/o devudu Lakshmi opticals Gokavaram bus stand Rajhamundry 9494954476"
        };
        final String[] addressLine2 = new String[]{
                "",
                "harishchandra circle",
                "Near Lilly's Textile"
        };
        final String[] mergeAddress = new String[]{
                addressLine1[0] + " " + addressLine2[0],
                addressLine1[1] + " " + addressLine2[1],
                addressLine1[2] + " " + addressLine2[2]
        };

        final List<String[]> sanitizedAddress = new ArrayList<>();
        sanitizedAddress.add(mergeAddress[0].split(" "));
        sanitizedAddress.add(mergeAddress[1].split(" "));
        sanitizedAddress.add(mergeAddress[2].split(" "));

        // Prepare data. Ids are distinct (1, 2, 3) so that the orderBy("id") below yields a
        // deterministic row order; previously every row used id=1, which made the sort a no-op
        // and left the collected row order up to Spark's shuffle.
        JavaRDD<Row> rdd = jsc.parallelize(Arrays.asList(
                RowFactory.create(1, mergeAddress[0], sanitizedAddress.get(0)),
                RowFactory.create(2, mergeAddress[1], sanitizedAddress.get(1)),
                RowFactory.create(3, mergeAddress[2], sanitizedAddress.get(2))
        ));

        StructType schema = new StructType(new StructField[]{
                new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("mergedAddress", DataTypes.StringType, false, Metadata.empty()),
                new StructField("sanitizedAddress", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
        });
        Dataset<Row> dataset = spark.createDataFrame(rdd, schema);
        dataset.show();

        CommonAddressFeatures sparkModel = new CommonAddressFeatures()
                .setInputCol("sanitizedAddress")
                .setRawInputCol("mergedAddress");

        // Export the Spark model and re-import it as a framework-independent Transformer.
        byte[] exportedModel = ModelExporter.export(sparkModel);
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        // Run the Spark transformer and project the columns compared in assertCorrectness.
        Dataset<Row> rowDataset = sparkModel.transform(dataset)
                .orderBy("id")
                .select("mergedAddress", "sanitizedAddress",
                        "numWords", "numCommas", "numericPresent", "addressLength",
                        "favouredStart", "unfavouredStart");
        rowDataset.show();
        assertCorrectness(rowDataset, transformer);
    }

    /**
     * Feeds each Spark output row through the imported {@link Transformer} and asserts the
     * transformer reproduces every derived feature value exactly.
     * <p>
     * Column layout is fixed by the select in the test: 0=mergedAddress, 1=sanitizedAddress,
     * 2=numWords, 3=numCommas, 4=numericPresent, 5=addressLength, 6=favouredStart,
     * 7=unfavouredStart.
     *
     * @param rowDataset  Spark output with the column layout described above
     * @param transformer imported transformer that mutates the input map in place,
     *                    adding the derived feature entries
     */
    private void assertCorrectness(Dataset<Row> rowDataset, Transformer transformer) {
        List<Row> sparkOutput = rowDataset.collectAsList();
        for (Row row : sparkOutput) {
            Map<String, Object> data = new HashMap<>();
            data.put("mergedAddress", row.get(0));

            // Spark materializes the array column as a List; the transformer expects String[].
            List<Object> list = row.getList(1);
            String[] sanitizedAddress = new String[list.size()];
            for (int j = 0; j < sanitizedAddress.length; j++) {
                sanitizedAddress[j] = (String) list.get(j);
            }
            data.put("sanitizedAddress", sanitizedAddress);

            transformer.transform(data);

            assertEquals("number of words should be equals", row.get(2), data.get("numWords"));
            assertEquals("number of commas should be equals", row.get(3), data.get("numCommas"));
            assertEquals("numericPresent should be equals", row.get(4), data.get("numericPresent"));
            assertEquals("addressLength should be equals", row.get(5), data.get("addressLength"));
            assertEquals("favouredStart should be equals", row.get(6), data.get("favouredStart"));
            assertEquals("unfavouredStart should be equals", row.get(7), data.get("unfavouredStart"));
        }
    }
}