Java Code Examples for org.apache.spark.sql.Dataset#randomSplit()

Example 1
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load and parse the data file, converting it to a DataFrame.
  Dataset<Row> data ="libsvm").load("data/mllib/sample_libsvm_data.txt");

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  StringIndexerModel labelIndexer = new StringIndexer()
  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()

  // Split the data into training and test sets (30% held out for testing)
  Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a RandomForest model.
  RandomForestClassifier rf = new RandomForestClassifier()

  // Convert indexed labels back to original labels.
  IndexToString labelConverter = new IndexToString()

  // Chain indexers and forest in a Pipeline
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {labelIndexer, featureIndexer, rf, labelConverter});

  // Train model. This also runs the indexers.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."predictedLabel", "label", "features").show(5);

  // Select (prediction, true label) and compute test error
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  double accuracy = evaluator.evaluate(predictions);
  System.out.println("Test Error = " + (1.0 - accuracy));

  RandomForestClassificationModel rfModel = (RandomForestClassificationModel)(model.stages()[2]);
  System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
  // $example off$

Example 2
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  JavaRDD<Rating> ratingsRDD = spark
    .read().textFile(Constant.LOCAL_FILE_PREX +"data/mllib/als/sample_movielens_ratings.txt").javaRDD()
    .map(new Function<String, Rating>() {
      public Rating call(String str) {
        return Rating.parseRating(str);
  Dataset<Row> ratings = spark.createDataFrame(ratingsRDD, Rating.class);
  Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2});
  Dataset<Row> training = splits[0];
  Dataset<Row> test = splits[1];

  // Build the recommendation model using ALS on the training data
  ALS als = new ALS()
  ALSModel model =;

  // Evaluate the model by computing the RMSE on the test data
  Dataset<Row> predictions = model.transform(test);

  RegressionEvaluator evaluator = new RegressionEvaluator()
  Double rmse = evaluator.evaluate(predictions);
  System.out.println("Root-mean-square error = " + rmse);
  // $example off$
Example 3
Source File:    From spark-transformers with Apache License 2.0
public void testGradientBoostClassification() {
	// Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/binary_classification_test.libsvm";

	Dataset<Row> data ="libsvm").load(datapath);

	// Split the data into training and test sets (30% held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	// Train a gradient-boosted tree (GBT) classification model.
	GBTClassificationModel classificationModel = new GBTClassifier().fit(trainingData);

	byte[] exportedModel = ModelExporter.export(classificationModel);

	Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

	List<Row> sparkOutput =
	        classificationModel.transform(testData).select("features", "prediction","label").collectAsList();

	// compare predictions
	for (Row row : sparkOutput) {
		Map<String, Object> data_ = new HashMap<>();
		data_.put("features", ((SparseVector) row.get(0)).toArray());
		data_.put("label", (row.get(2)).toString());
		System.out.println(data_.get("prediction")+" ,"+row.get(1));
		assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);

Example 4
Source File:    From -Data-Stream-Development-with-Apache-Spark-Kafka-and-Spring-Boot with MIT License
public static void main(String[] args) throws InterruptedException, StreamingQueryException {

                System.setProperty("hadoop.home.dir", HADOOP_HOME_DIR_VALUE);

                // * the schema can be written to disk and read back from disk
                // * the schema does not have to be complete; it can contain only the needed fields
                StructType HOUSES_SCHEMA = 
                       new StructType()
                           .add("House", LongType, true)
                           .add("Taxes", LongType, true)
                           .add("Bedrooms", LongType, true)
                           .add("Baths", FloatType, true)
                           .add("Quadrant", LongType, true)
                           .add("NW", StringType, true)
                           .add("Price($)", LongType, false)
                           .add("Size(sqft)", LongType, false)
                           .add("lot", LongType, true);

                final SparkConf conf = new SparkConf()
                    .set("spark.sql.caseSensitive", CASE_SENSITIVE);

                SparkSession sparkSession = SparkSession.builder()

                Dataset<Row> housesDF =
                // Gathering Data				
                Dataset<Row> gatheredDF ="Taxes"), 
                    col("Bedrooms"), col("Baths"),
                    col("Size(sqft)"), col("Price($)"));
                // Data Preparation  
                Dataset<Row> labelDF = gatheredDF.withColumnRenamed("Price($)", "label");
                Imputer imputer = new Imputer()
                    // .setMissingValue(1.0d)
                    .setInputCols(new String[] { "Baths" })
                    .setOutputCols(new String[] { "~Baths~" });

                VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[] { "Taxes", "Bedrooms", "~Baths~", "Size(sqft)" })
                // Choosing a Model               
                LinearRegression linearRegression = new LinearRegression();

                Pipeline pipeline = new Pipeline()
                                .setStages(new PipelineStage[] {
                                    imputer, assembler, linearRegression 

                // Training The Data
                Dataset<Row>[] splitDF = labelDF.randomSplit(new double[] { 0.8, 0.2 });

                Dataset<Row> trainDF = splitDF[0];
                Dataset<Row> evaluationDF = splitDF[1];

                PipelineModel pipelineModel =;
                // Evaluation 
                Dataset<Row> predictionsDF = pipelineModel.transform(evaluationDF);


                Dataset<Row> forEvaluationDF ="label"), 

                RegressionEvaluator evaluteR2 = new RegressionEvaluator().setMetricName("r2");
                RegressionEvaluator evaluteRMSE = new RegressionEvaluator().setMetricName("rmse");

                double r2 = evaluteR2.evaluate(forEvaluationDF);
                double rmse = evaluteRMSE.evaluate(forEvaluationDF);

      "R2 =" + r2);
      "RMSE =" + rmse);
Example 5
Source File:    From spark-transformers with Apache License 2.0
public void testDecisionTreeClassificationPrediction() {
    // Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/classification_test.libsvm";
	Dataset<Row> data ="libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeClassificationModel classifierModel = new DecisionTreeClassifier().fit(trainingData);
    List<Row> output = classifierModel.transform(testData).select("features", "prediction","rawPrediction").collectAsList();
    byte[] exportedModel = ModelExporter.export(classifierModel);

    DecisionTreeTransformer transformer = (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    for (Row row : output) {
    	Map<String, Object> data_ = new HashMap<>();
    	double [] actualRawPrediction = ((DenseVector) row.get(2)).toArray();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        assertEquals((double)data_.get("prediction"), (double)row.get(1), EPSILON);
        assertArrayEquals((double[]) data_.get("rawPrediction"), actualRawPrediction, EPSILON);
Example 6
Source File:    From spark-transformers with Apache License 2.0
public void testDecisionTreeRegressionPrediction() {
    // Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/regression_test.libsvm";
	Dataset<Row> data ="libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    // Train a DecisionTree model.
    DecisionTreeRegressionModel regressionModel = new DecisionTreeRegressor().fit(trainingData);
    List<Row> output = regressionModel.transform(testData).select("features", "prediction").collectAsList();
    byte[] exportedModel = ModelExporter.export(regressionModel);

    DecisionTreeTransformer transformer = (DecisionTreeTransformer) ModelImporter.importAndGetTransformer(exportedModel);

    //compare predictions
    for (Row row : output) {
    	Map<String, Object> data_ = new HashMap<>();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        assertEquals((double)data_.get("prediction"), (double)row.get(1), EPSILON);
Example 7
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  Dataset<Row> data ="libsvm")

  // Prepare training and test data.
  Dataset<Row>[] splits = data.randomSplit(new double[] {0.9, 0.1}, 12345);
  Dataset<Row> training = splits[0];
  Dataset<Row> test = splits[1];

  LinearRegression lr = new LinearRegression();

  // We use a ParamGridBuilder to construct a grid of parameters to search over.
  // TrainValidationSplit will try all combinations of values and determine best model using
  // the evaluator.
  ParamMap[] paramGrid = new ParamGridBuilder()
    .addGrid(lr.regParam(), new double[] {0.1, 0.01})
    .addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})

  // In this case the estimator is simply the linear regression.
  // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
  TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
    .setEvaluator(new RegressionEvaluator())
    .setTrainRatio(0.8);  // 80% for training and the remaining 20% for validation

  // Run train validation split, and choose the best set of parameters.
  TrainValidationSplitModel model =;

  // Make predictions on test data. model is the model with combination of parameters
  // that performed best.
    .select("features", "label", "prediction")
  // $example off$

Example 8
Source File:    From mmtf-spark with Apache License 2.0
/**
 * Dataset must at least contain the following two columns:
 * label: the class labels
 * features: feature vector
 * @param data dataset containing at least the label and features columns
 * @return map with metrics
 */
public Map<String,String> fit(Dataset<Row> data) {
	int classCount = (int);

	StringIndexerModel labelIndexer = new StringIndexer()

	// Split the data into training and test sets (testFraction of the rows held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[] {1.0-testFraction, testFraction}, seed);
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];
	String[] labels = labelIndexer.labels();
	for (String l: labels) {
		System.out.println(l + "\t" + + " = '" + l + "'").count()
				+ "\t" 
				+ + " = '" + l + "'").count());
	// Set input columns

	// Convert indexed labels back to original labels.
	IndexToString labelConverter = new IndexToString()

	// Chain the label indexer, the predictor, and the label converter in a Pipeline
	Pipeline pipeline = new Pipeline()
	  .setStages(new PipelineStage[] {labelIndexer, predictor, labelConverter});

	// Train model. This also runs the indexers.
	PipelineModel model =;

	// Make predictions.
	Dataset<Row> predictions = model.transform(testData).cache();
	// Display some sample predictions
	System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());

	predictions.sample(false, 0.1, seed).show(25);	

	predictions = predictions.withColumnRenamed(label, "stringLabel");
	predictions = predictions.withColumnRenamed("indexedLabel", label);
	// collect metrics
	Dataset<Row> pred ="prediction",label);
       Map<String,String> metrics = new LinkedHashMap<>();       
       metrics.put("Method", predictor.getClass().getSimpleName());
       if (classCount == 2) {
       	    BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred);
         	metrics.put("AUC", Float.toString((float)b.areaUnderROC()));
       MulticlassMetrics m = new MulticlassMetrics(pred); 
       metrics.put("F", Float.toString((float)m.weightedFMeasure()));
       metrics.put("Accuracy", Float.toString((float)m.accuracy()));
       metrics.put("Precision", Float.toString((float)m.weightedPrecision()));
       metrics.put("Recall", Float.toString((float)m.weightedRecall()));
       metrics.put("False Positive Rate", Float.toString((float)m.weightedFalsePositiveRate()));
       metrics.put("True Positive Rate", Float.toString((float)m.weightedTruePositiveRate()));
       metrics.put("", "\nConfusion Matrix\n" 
           + Arrays.toString(labels) +"\n" 
       		+ m.confusionMatrix().toString());
       return metrics;
Example 9
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load and parse the data file, converting it to a DataFrame.
  Dataset<Row> data ="libsvm").load("data/mllib/sample_libsvm_data.txt");

  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()

  // Split the data into training and test sets (30% held out for testing).
  Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a GBT model.
  GBTRegressor gbt = new GBTRegressor()

  // Chain indexer and GBT in a Pipeline.
  Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] {featureIndexer, gbt});

  // Train model. This also runs the indexer.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."prediction", "label", "features").show(5);

  // Select (prediction, true label) and compute test error.
  RegressionEvaluator evaluator = new RegressionEvaluator()
  double rmse = evaluator.evaluate(predictions);
  System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

  GBTRegressionModel gbtModel = (GBTRegressionModel)(model.stages()[1]);
  System.out.println("Learned regression GBT model:\n" + gbtModel.toDebugString());
  // $example off$

Example 10
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load training data
  String path = "data/mllib/sample_multiclass_classification_data.txt";
  Dataset<Row> dataFrame ="libsvm").load(path);

  // Split the data into train and test
  Dataset<Row>[] splits = dataFrame.randomSplit(new double[]{0.6, 0.4}, 1234L);
  Dataset<Row> train = splits[0];
  Dataset<Row> test = splits[1];

  // specify layers for the neural network:
  // input layer of size 4 (features), two intermediate of size 5 and 4
  // and output of size 3 (classes)
  int[] layers = new int[] {4, 5, 4, 3};

  // create the trainer and set its parameters
  MultilayerPerceptronClassifier trainer = new MultilayerPerceptronClassifier()

  // train the model
  MultilayerPerceptronClassificationModel model =;

  // compute accuracy on the test set
  Dataset<Row> result = model.transform(test);
  Dataset<Row> predictionAndLabels ="prediction", "label");
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()

  System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
  // $example off$

Example 11
Source File:    From mmtf-spark with Apache License 2.0
/**
 * Dataset must at least contain the following two columns:
 * label: the class labels
 * features: feature vector
 * @param data dataset containing at least the label and features columns
 * @return map with metrics
 */
public Map<String,String> fit(Dataset<Row> data) {

	// Split the data into training and test sets (testFraction of the rows held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[] {1.0-testFraction, testFraction}, seed);
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	// Train a RandomForest model.

	// Wrap the predictor in a single-stage Pipeline
	Pipeline pipeline = new Pipeline()
	  .setStages(new PipelineStage[] {predictor});

	// Train model. The pipeline's only stage is the predictor.
	PipelineModel model =;

	// Make predictions.
	Dataset<Row> predictions = model.transform(testData);

	// Display some sample predictions
	System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());
	String primaryKey = predictions.columns()[0];, label, "prediction").sample(false, 0.1, seed).show(50);
	Map<String,String> metrics = new LinkedHashMap<>();
    metrics.put("Method", predictor.getClass().getSimpleName());
    // Select (prediction, true label) and compute test error
    RegressionEvaluator evaluator = new RegressionEvaluator()
    metrics.put("rmse", Double.toString(evaluator.evaluate(predictions)));

	return metrics;
Example 12
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load and parse the data file, converting it to a DataFrame.
  Dataset<Row> data ="libsvm").load("data/mllib/sample_libsvm_data.txt");

  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()

  // Split the data into training and test sets (30% held out for testing)
  Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a RandomForest model.
  RandomForestRegressor rf = new RandomForestRegressor()

  // Chain indexer and forest in a Pipeline
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {featureIndexer, rf});

  // Train model. This also runs the indexer.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."prediction", "label", "features").show(5);

  // Select (prediction, true label) and compute test error
  RegressionEvaluator evaluator = new RegressionEvaluator()
  double rmse = evaluator.evaluate(predictions);
  System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

  RandomForestRegressionModel rfModel = (RandomForestRegressionModel)(model.stages()[1]);
  System.out.println("Learned regression forest model:\n" + rfModel.toDebugString());
  // $example off$

Example 13
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load and parse the data file, converting it to a DataFrame.
  Dataset<Row> data = spark

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  StringIndexerModel labelIndexer = new StringIndexer()
  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()

  // Split the data into training and test sets (30% held out for testing)
  Dataset<Row>[] splits = data.randomSplit(new double[] {0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a GBT model.
  GBTClassifier gbt = new GBTClassifier()

  // Convert indexed labels back to original labels.
  IndexToString labelConverter = new IndexToString()

  // Chain indexers and GBT in a Pipeline.
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {labelIndexer, featureIndexer, gbt, labelConverter});

  // Train model. This also runs the indexers.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."predictedLabel", "label", "features").show(5);

  // Select (prediction, true label) and compute test error.
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  double accuracy = evaluator.evaluate(predictions);
  System.out.println("Test Error = " + (1.0 - accuracy));

  GBTClassificationModel gbtModel = (GBTClassificationModel)(model.stages()[2]);
  System.out.println("Learned classification GBT model:\n" + gbtModel.toDebugString());
  // $example off$

Example 14
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // Load the data stored in LIBSVM format as a DataFrame.
  Dataset<Row> data = spark

  // Index labels, adding metadata to the label column.
  // Fit on whole dataset to include all labels in index.
  StringIndexerModel labelIndexer = new StringIndexer()

  // Automatically identify categorical features, and index them.
  VectorIndexerModel featureIndexer = new VectorIndexer()
    .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.

  // Split the data into training and test sets (30% held out for testing).
  Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a DecisionTree model.
  DecisionTreeClassifier dt = new DecisionTreeClassifier()

  // Convert indexed labels back to original labels.
  IndexToString labelConverter = new IndexToString()

  // Chain indexers and tree in a Pipeline.
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[]{labelIndexer, featureIndexer, dt, labelConverter});

  // Train model. This also runs the indexers.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."predictedLabel", "label", "features").show(5);

  // Select (prediction, true label) and compute test error.
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
  double accuracy = evaluator.evaluate(predictions);
  System.out.println("Test Error = " + (1.0 - accuracy));

  DecisionTreeClassificationModel treeModel =
    (DecisionTreeClassificationModel) (model.stages()[2]);
  System.out.println("Learned classification tree model:\n" + treeModel.toDebugString());
  // $example off$

Example 15
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
  // $example on$
  // Load the data stored in LIBSVM format as a DataFrame.
  Dataset<Row> data ="libsvm")

  // Automatically identify categorical features, and index them.
  // Set maxCategories so features with > 4 distinct values are treated as continuous.
  VectorIndexerModel featureIndexer = new VectorIndexer()

  // Split the data into training and test sets (30% held out for testing).
  Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  Dataset<Row> trainingData = splits[0];
  Dataset<Row> testData = splits[1];

  // Train a DecisionTree model.
  DecisionTreeRegressor dt = new DecisionTreeRegressor()

  // Chain indexer and tree in a Pipeline.
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[]{featureIndexer, dt});

  // Train model. This also runs the indexer.
  PipelineModel model =;

  // Make predictions.
  Dataset<Row> predictions = model.transform(testData);

  // Select example rows to display."label", "features").show(5);

  // Select (prediction, true label) and compute test error.
  RegressionEvaluator evaluator = new RegressionEvaluator()
  double rmse = evaluator.evaluate(predictions);
  System.out.println("Root Mean Squared Error (RMSE) on test data = " + rmse);

  DecisionTreeRegressionModel treeModel =
    (DecisionTreeRegressionModel) (model.stages()[1]);
  System.out.println("Learned regression tree model:\n" + treeModel.toDebugString());
  // $example off$

Example 16
Source File:    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession

  // $example on$
  // load data file.
  Dataset<Row> inputData ="libsvm")

  // generate the train/test split.
  Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
  Dataset<Row> train = tmp[0];
  Dataset<Row> test = tmp[1];

  // configure the base classifier.
  LogisticRegression classifier = new LogisticRegression()

  // instantiate the One Vs Rest Classifier.
  OneVsRest ovr = new OneVsRest().setClassifier(classifier);

  // train the multiclass model.
  OneVsRestModel ovrModel =;

  // score the model on test data.
  Dataset<Row> predictions = ovrModel.transform(test)
    .select("prediction", "label");

  // obtain evaluator.
  MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()

  // compute the classification error on test data.
  double accuracy = evaluator.evaluate(predictions);
  System.out.println("Test Error = " + (1 - accuracy));
  // $example off$

Example 17
Source File:    From spark-transformers with Apache License 2.0
  public void testDecisionTreeRegressionPrediction() {
      // Load the data stored in LIBSVM format as a DataFrame.
  	String datapath = "src/test/resources/regression_test.libsvm";
  	Dataset<Row> data ="libsvm").load(datapath);

      // Split the data into training and test sets (30% held out for testing)
      Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
      Dataset<Row> trainingData = splits[0];
      Dataset<Row> testData = splits[1];

      StringIndexer indexer = new StringIndexer()
DecisionTreeRegressor regressionModel =
        new DecisionTreeRegressor().setLabelCol("labelIndex").setFeaturesCol("features");

Pipeline pipeline = new Pipeline()
              .setStages(new PipelineStage[]{indexer, regressionModel});

PipelineModel sparkPipeline =;

      byte[] exportedModel = ModelExporter.export(sparkPipeline);

      Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);
      List<Row> output = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();

      //compare predictions
      for (Row row : output) {
      	Map<String, Object> data_ = new HashMap<>();
          data_.put("features", ((SparseVector) row.get(0)).toArray());
          data_.put("label", (row.get(2)).toString());
          assertEquals((double)data_.get("prediction"), (double)row.get(1), EPSILON);
Example 18
Source File:    From spark-transformers with Apache License 2.0
public void testGradientBoostClassification() {
	// Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/binary_classification_test.libsvm";

	Dataset<Row> data ="libsvm").load(datapath);
	StringIndexer indexer = new StringIndexer()
	// Split the data into training and test sets (30% held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];

	// Configure a gradient-boosted tree (GBT) classifier.
	GBTClassifier classificationModel = new GBTClassifier().setLabelCol("labelIndex")

        Pipeline pipeline = new Pipeline()
                .setStages(new PipelineStage[]{indexer, classificationModel});

	 PipelineModel sparkPipeline =;

	// Export this model
	byte[] exportedModel = ModelExporter.export(sparkPipeline);

	// Import and get Transformer
	Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

	List<Row> sparkOutput = sparkPipeline.transform(testData).select("features", "prediction", "label").collectAsList();
	// compare predictions
	for (Row row : sparkOutput) {
		Map<String, Object> data_ = new HashMap<>();
		data_.put("features", ((SparseVector) row.get(0)).toArray());
		data_.put("label", (row.get(2)).toString());
		System.out.println(data_.get("prediction")+" ,"+row.get(1));
		assertEquals((double) data_.get("prediction"), (double) row.get(1), EPSILON);

Example 19
Source File:    From spark-transformers with Apache License 2.0
public void testDecisionTreeClassificationWithPipeline() {

    // Load the data stored in LIBSVM format as a DataFrame.
	String datapath = "src/test/resources/classification_test.libsvm";
	Dataset<Row> data ="libsvm").load(datapath);

    // Split the data into training and test sets (30% held out for testing)
    Dataset<Row>[] splits = data.randomSplit(new double[]{0.7, 0.3});        

    Dataset<Row> trainingData = splits[0];
    Dataset<Row> testData = splits[1];

    StringIndexer indexer = new StringIndexer()

    // Train a DecisionTree model.
    DecisionTreeClassifier classificationModel = new DecisionTreeClassifier()

    Pipeline pipeline = new Pipeline()
            .setStages(new PipelineStage[]{indexer, classificationModel});

    // Train model.  This also runs the indexer.
    PipelineModel sparkPipeline =;

    //Export this model
    byte[] exportedModel = ModelExporter.export(sparkPipeline);

    //Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    List<Row> output = sparkPipeline.transform(testData).select("features", "label","prediction","rawPrediction").collectAsList();

    //compare predictions
    for (Row row : output) {
    	Map<String, Object> data_ = new HashMap<>();
    	double [] actualRawPrediction = ((DenseVector) row.get(3)).toArray();
        data_.put("features", ((SparseVector) row.get(0)).toArray());
        data_.put("label", (row.get(1)).toString());
        assertEquals((double)data_.get("prediction"), (double)row.get(2), EPSILON);
        assertArrayEquals((double[]) data_.get("rawPrediction"), actualRawPrediction, EPSILON);
Example 20
Source File:    From Spark_ALS with MIT License
public static void main(String[] args) {
    SparkConf conf = new SparkConf().setAppName("JavaALSExample").setMaster("local");
    JavaSparkContext jsc = new JavaSparkContext(conf);
    SQLContext sqlContext = new SQLContext(jsc);

    JavaRDD<Rating> ratingsRDD = jsc.textFile("data/sample_movielens_ratings.txt")
            .map(new Function<String, Rating>() {
                public Rating call(String str) {
                    return Rating.parseRating(str);
    Dataset<Row> ratings = sqlContext.createDataFrame(ratingsRDD, Rating.class);
    Dataset<Row>[] splits = ratings.randomSplit(new double[]{0.8, 0.2}); // Split the data: 80% for training, the remainder for testing.
    Dataset<Row> training = splits[0];
    Dataset<Row> test = splits[1];

    // Build the recommendation model using ALS on the training data
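    ALS als = new ALS().setMaxIter(5) // set the number of iterations
            .setRegParam(0.01) // regularization parameter; smooths each iteration (0.1 seemed to give a lower error rate on this dataset)
    ALSModel model =; // invoke the algorithm to start training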

    Dataset<Row> itemFactors = model.itemFactors();;
    Dataset<Row> userFactors = model.userFactors();;

    // Evaluate the model by computing the RMSE on the test data
    Dataset<Row> rawPredictions = model.transform(test); // run predictions on the test data
    Dataset<Row> predictions = rawPredictions
            .withColumn("rating", rawPredictions.col("rating").cast(DataTypes.DoubleType))
            .withColumn("prediction", rawPredictions.col("prediction").cast(DataTypes.DoubleType));

    RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating")
    Double rmse = evaluator.evaluate(predictions);"Root-mean-square error = {} ", rmse);
