Java Code Examples for org.apache.spark.sql.types.Metadata#empty()

The following examples show how to use org.apache.spark.sql.types.Metadata#empty(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also check out the related API usage on the sidebar.
Example 1
Source File: TestRangeRowRule.java    From envelope with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that a range rule configured with {@code IGNORE_NULLS_CONFIG} set to
 * {@code true} accepts a row whose checked field ("age") holds {@code null}.
 */
@Test
public void testIgnoreNulls() {
  // Rule settings: check the "age" field as an int, range values 0 and 105, nulls ignored.
  Map<String, Object> settings = new HashMap<>();
  settings.put(RangeRowRule.FIELDS_CONFIG, Lists.newArrayList("age"));
  settings.put(RangeRowRule.FIELD_TYPE_CONFIG, "int");
  settings.put(RangeRowRule.RANGE_CONFIG, Lists.newArrayList(0,105));
  settings.put(RangeRowRule.IGNORE_NULLS_CONFIG, true);
  Config ruleConfig = ConfigFactory.parseMap(settings);

  RangeRowRule ageRangeRule = new RangeRowRule();
  assertNoValidationFailures(ageRangeRule, ruleConfig);
  ageRangeRule.configure(ruleConfig);
  ageRangeRule.configureName("agerange");

  StructField[] fields = {
      new StructField("name", DataTypes.StringType, false, Metadata.empty()),
      new StructField("nickname", DataTypes.StringType, false, Metadata.empty()),
      new StructField("age", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("candycrushscore", DataTypes.createDecimalType(), false, Metadata.empty())
  };
  StructType rowSchema = new StructType(fields);

  // The "age" value is deliberately null.
  Row rowWithNullAge = new RowWithSchema(rowSchema, "Ian", "Ian", null, new BigDecimal("0.00"));

  assertTrue("Row should pass rule", ageRangeRule.check(rowWithNullAge));
}
 
Example 2
Source File: TestRangeRowRule.java    From envelope with Apache License 2.0 6 votes vote down vote up
/**
 * Loads the rules found under {@code dq1.deriver.rules} in the "steps" section of
 * {@code /dq/dq-range-rules.conf} and asserts that each of them passes for a single
 * row holding the value 2 in long, int, float, double and decimal columns.
 *
 * @throws Exception propagated from config loading or rule evaluation
 */
@Test
public void testRangeDataTypes() throws Exception {
  Config config = ConfigUtils.configFromResource("/dq/dq-range-rules.conf").getConfig("steps");
  StructType schema = new StructType(new StructField[] {
    new StructField("fa", DataTypes.LongType, false, Metadata.empty()),
    new StructField("fi", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("fl", DataTypes.LongType, false, Metadata.empty()),
    new StructField("ff", DataTypes.FloatType, false, Metadata.empty()),
    new StructField("fe", DataTypes.DoubleType, false, Metadata.empty()),
    new StructField("fd", DataTypes.createDecimalType(), false, Metadata.empty())
  });
  // valueOf(...) replaces the deprecated boxed-type constructors; the boxed types are unchanged.
  Row row = new RowWithSchema(schema,
      Long.valueOf(2L), 2, Long.valueOf(2L), Float.valueOf(2.0f), 2.0, new BigDecimal("2.0"));

  ConfigObject rules = config.getObject("dq1.deriver.rules");
  // Convert once; the conversion does not depend on the rule being iterated.
  Config rulesConfig = rules.toConfig();
  for (String ruleName : rules.keySet()) {
    RangeRowRule rule = new RangeRowRule();
    rule.configure(rulesConfig.getConfig(ruleName));
    rule.configureName(ruleName);
    assertTrue("Row should pass rule " + ruleName, rule.check(row));
  }
}
 
Example 3
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Tags a distributed list of sentences and writes the result to an output file with
 * a desired output format.
 * <p>
 * The rows are wrapped in a data frame with a single non-nullable string column
 * named "sentence", which is then handed to the data-frame based {@code tag} overload.
 *
 * @param sentences rows to place in the "sentence" column
 * @param outputFileName forwarded unchanged to the data-frame overload
 * @param outputFormat forwarded unchanged to the data-frame overload
 */
public void tag(JavaRDD<Row> sentences, String outputFileName, OutputFormat outputFormat) {
	StructField sentenceField = new StructField("sentence", DataTypes.StringType, false, Metadata.empty());
	StructType sentenceSchema = new StructType(new StructField[]{sentenceField});
	DataFrame sentenceFrame = new SQLContext(jsc).createDataFrame(sentences, sentenceSchema);
	tag(sentenceFrame, outputFileName, outputFormat);
}
 
Example 4
Source File: JavaAFTSurvivalRegressionExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Fits an AFT survival regression model on five in-memory rows and prints the
 * fitted coefficients, intercept and scale, followed by the transformed data.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaAFTSurvivalRegressionExample")
    .getOrCreate();

  // $example on$
  StructField[] columns = {
    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
    new StructField("censor", DataTypes.DoubleType, false, Metadata.empty()),
    new StructField("features", new VectorUDT(), false, Metadata.empty())
  };
  List<Row> rows = Arrays.asList(
    RowFactory.create(1.218, 1.0, Vectors.dense(1.560, -0.605)),
    RowFactory.create(2.949, 0.0, Vectors.dense(0.346, 2.158)),
    RowFactory.create(3.627, 0.0, Vectors.dense(1.380, 0.231)),
    RowFactory.create(0.273, 1.0, Vectors.dense(0.520, 1.151)),
    RowFactory.create(4.199, 0.0, Vectors.dense(0.795, -0.226))
  );
  Dataset<Row> trainingSet = spark.createDataFrame(rows, new StructType(columns));

  AFTSurvivalRegression estimator = new AFTSurvivalRegression()
    .setQuantileProbabilities(new double[]{0.3, 0.6})
    .setQuantilesCol("quantiles");
  AFTSurvivalRegressionModel fitted = estimator.fit(trainingSet);

  // Print the coefficients, intercept and scale parameter for AFT survival regression
  System.out.println("Coefficients: " + fitted.coefficients());
  System.out.println("Intercept: " + fitted.intercept());
  System.out.println("Scale: " + fitted.scale());
  fitted.transform(trainingSet).show(false);
  // $example off$

  spark.stop();
}
 
Example 5
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Tags a list of sequences and writes the result to an output file with a
 * desired output format.
 * <p>
 * Each string becomes one row of a data frame with a single non-nullable string
 * column named "sentence"; that frame is then handed to the data-frame based
 * {@code tag} overload.
 *
 * @param sentences strings to place in the "sentence" column, one row each
 * @param outputFileName forwarded unchanged to the data-frame overload
 * @param outputFormat forwarded unchanged to the data-frame overload
 */
public void tag(List<String> sentences, String outputFileName, OutputFormat outputFormat) {
	StructType sentenceSchema = new StructType(new StructField[]{
		new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
	});
	List<Row> sentenceRows = new LinkedList<Row>();
	for (String text : sentences) {
		sentenceRows.add(RowFactory.create(text));
	}
	DataFrame sentenceFrame = new SQLContext(jsc).createDataFrame(sentenceRows, sentenceSchema);
	tag(sentenceFrame, outputFileName, outputFormat);
}
 
Example 6
Source File: JavaTfIdfExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Runs a tokenizer, a hashing term-frequency transformer and an IDF model over
 * three labelled sentences and shows the resulting "label" and "features" columns.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaTfIdfExample")
    .getOrCreate();

  // $example on$
  StructField[] columns = {
    new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
    new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
  };
  List<Row> rows = Arrays.asList(
    RowFactory.create(0.0, "Hi I heard about Spark"),
    RowFactory.create(0.0, "I wish Java could use case classes"),
    RowFactory.create(1.0, "Logistic regression models are neat")
  );
  Dataset<Row> sentences = spark.createDataFrame(rows, new StructType(columns));

  // sentence -> words
  Dataset<Row> tokenized = new Tokenizer()
    .setInputCol("sentence")
    .setOutputCol("words")
    .transform(sentences);

  // words -> rawFeatures, hashed into 20 features
  HashingTF termFrequency = new HashingTF()
    .setInputCol("words")
    .setOutputCol("rawFeatures")
    .setNumFeatures(20);
  Dataset<Row> hashed = termFrequency.transform(tokenized);
  // alternatively, CountVectorizer can also be used to get term frequency vectors

  // rawFeatures -> features
  IDFModel fittedIdf = new IDF()
    .setInputCol("rawFeatures")
    .setOutputCol("features")
    .fit(hashed);

  fittedIdf.transform(hashed).select("label", "features").show();
  // $example off$

  spark.stop();
}
 
Example 7
Source File: JavaNormalizerExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Normalizes three in-memory vectors twice, once with p = 1.0 and once with
 * p = positive infinity, and shows both results.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaNormalizerExample")
    .getOrCreate();

  // $example on$
  StructField[] columns = {
      new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
      new StructField("features", new VectorUDT(), false, Metadata.empty())
  };
  List<Row> rows = Arrays.asList(
      RowFactory.create(0, Vectors.dense(1.0, 0.1, -8.0)),
      RowFactory.create(1, Vectors.dense(2.0, 1.0, -4.0)),
      RowFactory.create(2, Vectors.dense(4.0, 10.0, 8.0))
  );
  Dataset<Row> input = spark.createDataFrame(rows, new StructType(columns));

  // Normalize each Vector using $L^1$ norm.
  Normalizer normalizer = new Normalizer()
    .setInputCol("features")
    .setOutputCol("normFeatures")
    .setP(1.0);
  normalizer.transform(input).show();

  // Normalize each Vector using $L^\infty$ norm.
  // The p parameter is overridden for this call only via the extra param pair.
  normalizer
    .transform(input, normalizer.p().w(Double.POSITIVE_INFINITY))
    .show();
  // $example off$

  spark.stop();
}
 
Example 8
Source File: SimplePredictionFromTextFile.java    From net.jgp.labs.spark with Apache License 2.0 5 votes vote down vote up
/**
 * Reads a headerless CSV file into a data frame, derives "label" and "features"
 * columns from it, fits a linear regression and prints the training summary
 * together with a single prediction.
 */
private void start() {
  SparkSession spark = SparkSession.builder().appName(
      "Simple prediction from Text File").master("local").getOrCreate();

  // UDF applied to the "valuefeatures" column below to produce the "features" column.
  spark.udf().register("vectorBuilder", new VectorBuilder(), new VectorUDT());

  String filename = "data/tuple-data-file.csv";
  StructType schema = new StructType(
      new StructField[] { new StructField("_c0", DataTypes.DoubleType, false,
          Metadata.empty()),
          new StructField("_c1", DataTypes.DoubleType, false, Metadata
              .empty()),
          new StructField("features", new VectorUDT(), true, Metadata
              .empty()), });

  Dataset<Row> df = spark.read().format("csv").schema(schema).option("header",
      "false")
      .load(filename);
  // Rename _c0 -> valuefeatures and _c1 -> label.
  df = df.withColumn("valuefeatures", df.col("_c0")).drop("_c0");
  df = df.withColumn("label", df.col("_c1")).drop("_c1");
  df.printSchema();

  df = df.withColumn("features", callUDF("vectorBuilder", df.col(
      "valuefeatures")));
  df.printSchema();
  df.show();

  LinearRegression lr = new LinearRegression().setMaxIter(20);// .setRegParam(1).setElasticNetParam(1);

  // Fit the model to the data.
  LinearRegressionModel model = lr.fit(df);

  // Given a dataset, predict each point's label, and show the results.
  model.transform(df).show();

  LinearRegressionTrainingSummary trainingSummary = model.summary();
  System.out.println("numIterations: " + trainingSummary.totalIterations());
  System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary
      .objectiveHistory()));
  trainingSummary.residuals().show();
  System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
  System.out.println("r2: " + trainingSummary.r2());

  double intercept = model.intercept();
  // Label previously read "Interesection"; the value printed is the model intercept.
  System.out.println("Intercept: " + intercept);
  double regParam = model.getRegParam();
  System.out.println("Regression parameter: " + regParam);
  double tol = model.getTol();
  System.out.println("Tol: " + tol);
  // Primitive double: no boxed value is needed here.
  double feature = 7.0;
  Vector features = Vectors.dense(feature);
  double p = model.predict(features);

  System.out.println("Prediction for feature " + feature + " is " + p);
  // NOTE(review): this uses the constant 8 and the regularization parameter, while the
  // prediction above is for feature 7.0 - confirm whether this is the intended check.
  System.out.println(8 * regParam + intercept);
}
 
Example 9
Source File: FirstPrediction.java    From net.jgp.labs.spark with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a local Spark session and a two-column schema ("label", "features").
 * Nothing is done with either yet; see the TODO below.
 */
private void start() {
  SparkSession spark = SparkSession.builder()
      .appName("First Prediction")
      .master("local")
      .getOrCreate();

  StructField labelField =
      new StructField("label", DataTypes.DoubleType, false, Metadata.empty());
  StructField featuresField =
      new StructField("features", new VectorUDT(), false, Metadata.empty());
  StructType schema = new StructType(new StructField[] { labelField, featuresField });

  // TODO this example is not working yet
}
 
Example 10
Source File: EntitySalienceFeatureExtractorSpark.java    From ambiverse-nlu with Apache License 2.0 5 votes vote down vote up
/**
 * Extract a DataFrame ready for training or testing.
 * <p>
 * Each document is expanded into training instances, and each instance becomes
 * one row with the columns "docId", "entityId", "label" and "features".
 * Three named accumulators count processed documents and instances whose label
 * is / is not equal to 1.0.
 *
 * @param jsc context used to create the accumulators
 * @param documents documents whose JCas is passed to the feature extractor
 * @param sqlContext context used to create the resulting data frame
 * @return data frame built from the per-instance rows and the four-column schema
 * @throws ResourceInitializationException declared by this method; the throwing call is not visible here
 */
public DataFrame extract(JavaSparkContext jsc, JavaRDD<SCAS> documents, SQLContext sqlContext) throws ResourceInitializationException {
    Accumulator<Integer> totalDocs = jsc.accumulator(0, "TOTAL_DOCS");
    Accumulator<Integer> salientInstances = jsc.accumulator(0, "SALIENT_ENTITY_INSTANCES");
    Accumulator<Integer> nonSalientInstances = jsc.accumulator(0, "NON_SALIENT_ENTITY_INSTANCES");

    TrainingSettings settings = getTrainingSettings();

    FeatureExtractor extractor = new NYTEntitySalienceFeatureExtractor();
    final int vectorSize = FeatureSetFactory
            .createFeatureSet(TrainingSettings.FeatureExtractor.ENTITY_SALIENCE)
            .getFeatureVectorSize();

    // One document yields any number of training instances.
    JavaRDD<TrainingInstance> instances = documents.flatMap(doc -> {
        totalDocs.add(1);
        return extractor.getTrainingInstances(
                doc.getJCas(),
                settings.getFeatureExtractor(),
                settings.getPositiveInstanceScalingFactor());
    });

    StructField[] columns = {
            new StructField("docId", DataTypes.StringType, false, Metadata.empty()),
            new StructField("entityId", DataTypes.StringType, false, Metadata.empty()),
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
    };

    JavaRDD<Row> rows = instances.map(instance -> {
        // Count the instance by label before converting it to a row.
        if (instance.getLabel() == 1.0) {
            salientInstances.add(1);
        } else {
            nonSalientInstances.add(1);
        }
        Vector featureVector = FeatureValueInstanceUtils.convertToSparkMLVector(instance, vectorSize);
        return RowFactory.create(instance.getDocId(), instance.getEntityId(), instance.getLabel(), featureVector);
    });

    return sqlContext.createDataFrame(rows, new StructType(columns));
}
 
Example 11
Source File: VectorBinarizerBridgeTest.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Exports a {@code VectorBinarizer}, re-imports it as a standalone transformer and
 * checks that, for every input row, the imported transformer produces exactly the
 * same binarized vector as Spark does.
 */
@Test
public void testVectorBinarizerDense() {
    // prepare data

    JavaRDD<Row> jrdd = sc.parallelize(Arrays.asList(
            RowFactory.create(0d, 1d, new DenseVector(new double[]{-2d, -3d, -4d, -1d, 6d, -7d, 8d, 0d, 0d, 0d, 0d, 0d})),
            RowFactory.create(1d, 2d, new DenseVector(new double[]{4d, -5d, 6d, 7d, -8d, 9d, -10d, 0d, 0d, 0d, 0d, 0d})),
            RowFactory.create(2d, 3d, new DenseVector(new double[]{-5d, 6d, -8d, 9d, 10d, 11d, 12d, 0d, 0d, 0d, 0d, 0d}))
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("vector1", new VectorUDT(), false, Metadata.empty())
    });

    DataFrame df = sqlContext.createDataFrame(jrdd, schema);
    VectorBinarizer vectorBinarizer = new VectorBinarizer()
            .setInputCol("vector1")
            .setOutputCol("binarized")
            .setThreshold(2d);


    //Export this model
    byte[] exportedModel = ModelExporter.export(vectorBinarizer, df);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //compare predictions
    Row[] sparkOutput = vectorBinarizer.transform(df).orderBy("id").select("id", "value1", "vector1", "binarized").collect();
    for (Row row : sparkOutput) {

        Map<String, Object> data = new HashMap<>();
        data.put(vectorBinarizer.getInputCol(), ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] actual = (double[]) data.get(vectorBinarizer.getOutputCol());
        // Spark's output is the reference, so it goes in the "expected" position.
        double[] expected = ((DenseVector) row.get(3)).toArray();
        assertArrayEquals(expected, actual, 0d);
    }
}
 
Example 12
Source File: JavaPolynomialExpansionExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Applies a degree-3 polynomial expansion to three two-dimensional vectors and
 * shows the result without truncation.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaPolynomialExpansionExample")
    .getOrCreate();

  // $example on$
  PolynomialExpansion expansion = new PolynomialExpansion()
    .setInputCol("features")
    .setOutputCol("polyFeatures")
    .setDegree(3);

  StructField[] columns = {
    new StructField("features", new VectorUDT(), false, Metadata.empty()),
  };
  List<Row> rows = Arrays.asList(
    RowFactory.create(Vectors.dense(2.0, 1.0)),
    RowFactory.create(Vectors.dense(0.0, 0.0)),
    RowFactory.create(Vectors.dense(3.0, -1.0))
  );
  Dataset<Row> input = spark.createDataFrame(rows, new StructType(columns));

  expansion.transform(input).show(false);
  // $example off$

  spark.stop();
}
 
Example 13
Source File: TestSparkSchema.java    From iceberg with Apache License 2.0 5 votes vote down vote up
/**
 * Writes one record to a fresh unpartitioned table, reads it back with an explicit
 * Spark read schema ("id", "data") combined with a projection on "id", and checks
 * that exactly one single-column row with id 1 is returned.
 *
 * @throws IOException if the temporary table folder cannot be created
 */
@Test
public void testSparkReadSchemaCombinedWithProjection() throws IOException {
  String location = temp.newFolder("iceberg-table").toString();

  new HadoopTables(CONF).create(SCHEMA, PartitionSpec.unpartitioned(), null, location);

  List<SimpleRecord> records = Lists.newArrayList(
      new SimpleRecord(1, "a")
  );
  spark.createDataFrame(records, SimpleRecord.class)
      .select("id", "data")
      .write()
      .format("iceberg")
      .mode("append")
      .save(location);

  StructField[] readFields = {
      new StructField("id", DataTypes.IntegerType, true, Metadata.empty()),
      new StructField("data", DataTypes.StringType, true, Metadata.empty())
  };
  StructType readSchema = new StructType(readFields);

  // The read schema has two columns; the projection keeps only "id".
  Dataset<Row> projected = spark.read()
      .schema(readSchema)
      .format("iceberg")
      .load(location)
      .select("id");

  Row[] rows = (Row[]) projected.collect();

  Assert.assertEquals("Result size matches", 1, rows.length);
  Assert.assertEquals("Row length matches with sparkReadSchema", 1, rows[0].length());
  Assert.assertEquals("Row content matches data", 1, rows[0].getInt(0));
}
 
Example 14
Source File: CMMModel.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Maps every row of the input through a {@code DecodeFunction} and returns the
 * result as a data frame with the string columns "sentence" and "prediction".
 *
 * @param dataset input data frame; its SQL context is reused for the result
 * @return data frame built from the decoded rows
 */
@Override
public DataFrame transform(DataFrame dataset) {
	StructField[] outputFields = {
		new StructField("sentence", DataTypes.StringType, false, Metadata.empty()),
		new StructField("prediction", DataTypes.StringType, false, Metadata.empty())
	};
	JavaRDD<Row> decodedRows = dataset.javaRDD().map(new DecodeFunction());
	return dataset.sqlContext().createDataFrame(decodedRows, new StructType(outputFields));
}
 
Example 15
Source File: AverageUDAF.java    From Apache-Spark-2x-for-Java-Developers with MIT License 4 votes vote down vote up
/**
 * Describes the input of this aggregate: a single nullable double column named "counter".
 *
 * @return the one-field input schema
 */
@Override
public StructType inputSchema() {
	StructField counterField =
			new StructField("counter", DataTypes.DoubleType, true, Metadata.empty());
	return new StructType(new StructField[] { counterField });
}
 
Example 16
Source File: StringSanitizerBridgeTest.java    From spark-transformers with Apache License 2.0 4 votes vote down vote up
/**
 * Exports a {@code StringSanitizer}, re-imports it as a standalone transformer and
 * checks that, for every raw text, the imported transformer yields the same token
 * list as the Spark model.
 */
@Test
public void testStringSanitizer() {

	//prepare data
	List<Row> rawRows = Arrays.asList(
			RowFactory.create(1, "Jyoti complex near Sananda clothes store; English Bazar; Malda;WB;India,"),
			RowFactory.create(2, "hallalli vinayaka tent road c/o B K vishwanath Mandya"),
			RowFactory.create(3, "M.sathish S/o devudu Lakshmi opticals Gokavaram bus stand Rajhamundry 9494954476")
	);
	StructField[] columns = {
			new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
			new StructField("rawText", DataTypes.StringType, false, Metadata.empty())
	};
	JavaRDD<Row> rowRdd = jsc.parallelize(rawRows);
	Dataset<Row> input = spark.createDataFrame(rowRdd, new StructType(columns));
	input.show();

	//train model in spark
	StringSanitizer sanitizer = new StringSanitizer()
			.setInputCol("rawText")
			.setOutputCol("token");

	//Export this model
	byte[] serialized = ModelExporter.export(sanitizer);

	//Import and get Transformer
	Transformer imported = ModelImporter.importAndGetTransformer(serialized);

	List<Row> sparkResults = sanitizer.transform(input).select("rawText", "token").collectAsList();

	for (Row result : sparkResults) {
		// Feed the raw text (column 0) to the imported transformer.
		Map<String, Object> record = new HashMap<>();
		record.put(sanitizer.getInputCol(), result.getString(0));
		imported.transform(record);

		String[] producedTokens = (String[]) record.get(sanitizer.getOutputCol());
		List<String> produced = Arrays.asList(producedTokens);
		// Column 1 holds the tokens computed by Spark.
		List<String> expected = result.getList(1);

		assertTrue("both should be same", produced.equals(expected));
	}
}
 
Example 17
Source File: AttributeReference.java    From indexr with Apache License 2.0 4 votes vote down vote up
/**
 * Creates an attribute reference with the given name and data type, the shared
 * {@code Metadata.empty} instance as metadata, and an expression id obtained from
 * {@code NamedExpression.newExprId()}.
 *
 * @param name attribute name
 * @param dataType attribute data type
 */
public AttributeReference(String name, DataType dataType) {
    this.exprId = NamedExpression.newExprId();
    this.metadata = Metadata.empty;
    this.dataType = dataType;
    this.name = name;
}
 
Example 18
Source File: JavaIndexToStringExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Indexes a string column with {@code StringIndexer}, prints the attribute stored
 * on the indexed column, then maps the indices back to strings with
 * {@code IndexToString} and shows the result.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaIndexToStringExample")
    .getOrCreate();

  // $example on$
  StructField[] columns = {
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("category", DataTypes.StringType, false, Metadata.empty())
  };
  List<Row> rows = Arrays.asList(
    RowFactory.create(0, "a"),
    RowFactory.create(1, "b"),
    RowFactory.create(2, "c"),
    RowFactory.create(3, "a"),
    RowFactory.create(4, "a"),
    RowFactory.create(5, "c")
  );
  Dataset<Row> input = spark.createDataFrame(rows, new StructType(columns));

  // category -> categoryIndex
  StringIndexerModel indexer = new StringIndexer()
    .setInputCol("category")
    .setOutputCol("categoryIndex")
    .fit(input);
  Dataset<Row> indexed = indexer.transform(input);

  System.out.println("Transformed string column '" + indexer.getInputCol() + "' " +
      "to indexed column '" + indexer.getOutputCol() + "'");
  indexed.show();

  // Schema field of the indexer's OUTPUT column.
  StructField indexedField = indexed.schema().apply(indexer.getOutputCol());
  System.out.println("StringIndexer will store labels in output column metadata: " +
      Attribute.fromStructField(indexedField).toString() + "\n");

  // categoryIndex -> originalCategory
  IndexToString converter = new IndexToString()
    .setInputCol("categoryIndex")
    .setOutputCol("originalCategory");
  Dataset<Row> restored = converter.transform(indexed);

  System.out.println("Transformed indexed column '" + converter.getInputCol() + "' back to " +
      "original string column '" + converter.getOutputCol() + "' using labels in metadata");
  restored.select("id", "categoryIndex", "originalCategory").show();

  // $example off$
  spark.stop();
}
 
Example 19
Source File: VectorAssemblerBridgeTest.java    From spark-transformers with Apache License 2.0 4 votes vote down vote up
/**
 * Exports a {@code VectorAssembler}, re-imports it as a standalone transformer and
 * checks that, for every input row, the imported transformer assembles exactly the
 * same feature vector as Spark does.
 */
@Test
public void testVectorAssembler() {
    // prepare data

    JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
            RowFactory.create(0d, 1d, new DenseVector(new double[]{2d, 3d})),
            RowFactory.create(1d, 2d, new DenseVector(new double[]{3d, 4d})),
            RowFactory.create(2d, 3d, new DenseVector(new double[]{4d, 5d})),
            RowFactory.create(3d, 4d, new DenseVector(new double[]{5d, 6d})),
            RowFactory.create(4d, 5d, new DenseVector(new double[]{6d, 7d}))
    ));

    StructType schema = new StructType(new StructField[]{
            new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("value1", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("vector1", new VectorUDT(), false, Metadata.empty())
    });

    Dataset<Row> df = spark.createDataFrame(jrdd, schema);
    VectorAssembler vectorAssembler = new VectorAssembler()
            .setInputCols(new String[]{"value1", "vector1"})
            .setOutputCol("feature");


    //Export this model
    byte[] exportedModel = ModelExporter.export(vectorAssembler);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
    //compare predictions
    List<Row> sparkOutput = vectorAssembler.transform(df).orderBy("id").select("id", "value1", "vector1", "feature").collectAsList();
    for (Row row : sparkOutput) {

        Map<String, Object> data = new HashMap<>();
        data.put(vectorAssembler.getInputCols()[0], row.get(1));
        data.put(vectorAssembler.getInputCols()[1], ((DenseVector) row.get(2)).toArray());
        transformer.transform(data);
        double[] actual = (double[]) data.get(vectorAssembler.getOutputCol());
        // Spark's output is the reference, so it goes in the "expected" position.
        double[] expected = ((DenseVector) row.get(3)).toArray();
        assertArrayEquals(expected, actual, 0d);
    }
}
 
Example 20
Source File: JavaBucketedRandomProjectionLSHExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Fits a bucketed random projection LSH model on one small dataset and shows the
 * hashed features, approximate similarity joins and approximate nearest-neighbour
 * lookups against a second dataset and a query key.
 *
 * @param args unused
 */
public static void main(String[] args) {
  SparkSession spark = SparkSession.builder()
    .appName("JavaBucketedRandomProjectionLSHExample")
    .getOrCreate();

  // $example on$
  StructField[] columns = {
    new StructField("id", DataTypes.IntegerType, false, Metadata.empty()),
    new StructField("keys", new VectorUDT(), false, Metadata.empty())
  };
  StructType pointSchema = new StructType(columns);

  List<Row> pointsA = Arrays.asList(
    RowFactory.create(0, Vectors.dense(1.0, 1.0)),
    RowFactory.create(1, Vectors.dense(1.0, -1.0)),
    RowFactory.create(2, Vectors.dense(-1.0, -1.0)),
    RowFactory.create(3, Vectors.dense(-1.0, 1.0))
  );
  List<Row> pointsB = Arrays.asList(
    RowFactory.create(4, Vectors.dense(1.0, 0.0)),
    RowFactory.create(5, Vectors.dense(-1.0, 0.0)),
    RowFactory.create(6, Vectors.dense(0.0, 1.0)),
    RowFactory.create(7, Vectors.dense(0.0, -1.0))
  );

  Dataset<Row> frameA = spark.createDataFrame(pointsA, pointSchema);
  Dataset<Row> frameB = spark.createDataFrame(pointsB, pointSchema);

  Vector queryKey = Vectors.dense(1.0, 0.0);

  BucketedRandomProjectionLSH lsh = new BucketedRandomProjectionLSH()
    .setBucketLength(2.0)
    .setNumHashTables(3)
    .setInputCol("keys")
    .setOutputCol("values");

  BucketedRandomProjectionLSHModel fitted = lsh.fit(frameA);

  // Feature Transformation
  fitted.transform(frameA).show();
  // Cache the transformed columns
  Dataset<Row> hashedA = fitted.transform(frameA).cache();
  Dataset<Row> hashedB = fitted.transform(frameB).cache();

  // Approximate similarity join
  fitted.approxSimilarityJoin(frameA, frameB, 1.5).show();
  fitted.approxSimilarityJoin(hashedA, hashedB, 1.5).show();
  // Self Join
  fitted.approxSimilarityJoin(frameA, frameA, 2.5)
    .filter("datasetA.id < datasetB.id")
    .show();

  // Approximate nearest neighbor search
  fitted.approxNearestNeighbors(frameA, queryKey, 2).show();
  fitted.approxNearestNeighbors(hashedA, queryKey, 2).show();
  // $example off$

  spark.stop();
}