package com.flipkart.fdp.ml.adapter;

import com.flipkart.fdp.ml.export.ModelExporter;
import com.flipkart.fdp.ml.importer.ModelImporter;
import com.flipkart.fdp.ml.transformer.Transformer;
import org.apache.spark.ml.feature.RegexTokenizer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.types.DataTypes.*;
import static org.junit.Assert.assertEquals;

/**
 * Verifies that a Spark {@link RegexTokenizer} exported via {@link ModelExporter}
 * and re-imported via {@link ModelImporter} produces exactly the same tokens as
 * the original Spark model.
 *
 * Created by akshay.us on 3/14/16.
 */
public class RegexTokenizerBridgeTest extends SparkTestBase {

    @Test
    public void testRegexTokenizer() {
        // Prepare a tiny dataset with a single raw-text column.
        StructType schema = createStructType(new StructField[]{
                createStructField("rawText", StringType, false),
        });
        List<Row> trainingData = Arrays.asList(
                cr("Test of tok."),
                cr("Te,st. punct")
        );
        DataFrame dataset = sqlContext.createDataFrame(trainingData, schema);

        // Configure the Spark-side tokenizer: split on whitespace (gaps=true),
        // preserve case, and drop tokens shorter than 3 characters.
        RegexTokenizer sparkModel = new RegexTokenizer()
                .setInputCol("rawText")
                .setOutputCol("tokens")
                .setPattern("\\s")
                .setGaps(true)
                .setToLowercase(false)
                .setMinTokenLength(3);

        // Export the trained model to bytes, then import it back as a Transformer.
        byte[] exportedModel = ModelExporter.export(sparkModel, dataset);
        Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

        // For every row, the imported transformer's tokens must match Spark's output.
        Row[] pairs = sparkModel.transform(dataset).select("rawText", "tokens").collect();
        for (Row row : pairs) {
            Map<String, Object> data = new HashMap<String, Object>();
            data.put(sparkModel.getInputCol(), row.getString(0));
            transformer.transform(data);

            String[] actual = (String[]) data.get(sparkModel.getOutputCol());
            // Spark returns the token column as a List<String>; compare element-wise.
            List<String> expected = row.getList(1);
            assertEquals("tokens mismatch for input: " + row.getString(0),
                    expected, Arrays.asList(actual));
        }
    }
}