Java Code Examples for org.apache.spark.api.java.JavaRDD#randomSplit()

The following examples show how to use org.apache.spark.api.java.JavaRDD#randomSplit(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: SparkExport.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Exports {@code data} as {@code numFiles} CSV files on the local file system.
 *
 * <p>Every record is converted to one text line by {@code WritablesToStringFunction}, the lines are
 * randomly split with equal weight per file, and the lines of each file are shuffled before being
 * written to {@code outputDir/baseFileName + index + ".csv"} (index starting at 0).
 *
 * @param outputDir    directory the files are written to
 * @param baseFileName file name prefix; the file index and ".csv" are appended
 * @param numFiles     number of output files (also the number of split weights)
 * @param delimiter    passed to {@code WritablesToStringFunction}
 * @param quote        passed to {@code WritablesToStringFunction}
 * @param data         records to export
 * @param rngSeed      seed used for both the random split and the per-file shuffle
 * @throws Exception if the conversion, the split or writing a file fails
 */
public static void exportCSVLocal(String outputDir, String baseFileName, int numFiles, String delimiter,
                String quote, JavaRDD<List<Writable>> data, int rngSeed) throws Exception {

    JavaRDD<String> lines = data.map(new WritablesToStringFunction(delimiter, quote));
    double[] split = new double[numFiles];
    for (int i = 0; i < split.length; i++) {
        split[i] = 1.0 / numFiles;
    }
    // Seed the split as well as the shuffle: without an explicit seed the assignment of lines to
    // files would differ from run to run even though the caller supplied rngSeed.
    JavaRDD<String>[] splitData = lines.randomSplit(split, rngSeed);

    int count = 0;
    Random r = new Random(rngSeed);
    for (JavaRDD<String> subset : splitData) {
        String path = FilenameUtils.concat(outputDir, baseFileName + (count++) + ".csv");
        List<String> linesList = subset.collect();
        // Copy into an ArrayList unless the collected list already is one, before shuffling in place.
        if (!(linesList instanceof ArrayList)) {
            linesList = new ArrayList<>(linesList);
        }
        Collections.shuffle(linesList, r);
        FileUtils.writeLines(new File(path), linesList);
    }
}
 
Example 2
Source File: SparkExport.java    From DataVec with Apache License 2.0 6 votes vote down vote up
/**
 * Writes {@code data} to {@code numFiles} CSV files on the local file system.
 *
 * <p>Records are rendered as text lines by {@code WritablesToStringFunction}, randomly split with
 * equal weight per file, and each part is written to
 * {@code outputDir/baseFileName + index + ".csv"} (index starting at 0).
 *
 * @param outputDir    directory the files are written to
 * @param baseFileName file name prefix; the file index and ".csv" are appended
 * @param numFiles     number of output files
 * @param delimiter    passed to {@code WritablesToStringFunction}
 * @param quote        passed to {@code WritablesToStringFunction}
 * @param data         records to export
 * @throws Exception if the conversion, the split or writing a file fails
 */
public static void exportCSVLocal(String outputDir, String baseFileName, int numFiles, String delimiter,
                String quote, JavaRDD<List<Writable>> data) throws Exception {

    final JavaRDD<String> csvLines = data.map(new WritablesToStringFunction(delimiter, quote));

    // One equal weight per output file.
    final double[] weights = new double[numFiles];
    int w = 0;
    while (w < weights.length) {
        weights[w] = 1.0 / numFiles;
        w++;
    }
    final JavaRDD<String>[] parts = csvLines.randomSplit(weights);

    for (int fileIdx = 0; fileIdx < parts.length; fileIdx++) {
        final String target = FilenameUtils.concat(outputDir, baseFileName + fileIdx + ".csv");
        // Each part is collected and written through FileUtils rather than saveAsTextFile.
        final List<String> content = parts[fileIdx].collect();
        FileUtils.writeLines(new File(target), content);
    }
}
 
Example 3
Source File: SparkExport.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Exports {@code data} as {@code numFiles} CSV files on the local file system.
 *
 * <p>Every record is converted to one text line by {@code WritablesToStringFunction}, the lines are
 * randomly split with equal weight per file, and the lines of each file are shuffled before being
 * written to {@code outputDir/baseFileName + index + ".csv"} (index starting at 0).
 *
 * @param outputDir    directory the files are written to
 * @param baseFileName file name prefix; the file index and ".csv" are appended
 * @param numFiles     number of output files (also the number of split weights)
 * @param delimiter    passed to {@code WritablesToStringFunction}
 * @param quote        passed to {@code WritablesToStringFunction}
 * @param data         records to export
 * @param rngSeed      seed used for both the random split and the per-file shuffle
 * @throws Exception if the conversion, the split or writing a file fails
 */
public static void exportCSVLocal(String outputDir, String baseFileName, int numFiles, String delimiter,
                String quote, JavaRDD<List<Writable>> data, int rngSeed) throws Exception {

    JavaRDD<String> lines = data.map(new WritablesToStringFunction(delimiter, quote));
    double[] split = new double[numFiles];
    for (int i = 0; i < split.length; i++) {
        split[i] = 1.0 / numFiles;
    }
    // Seed the split as well as the shuffle: without an explicit seed the assignment of lines to
    // files would differ from run to run even though the caller supplied rngSeed.
    JavaRDD<String>[] splitData = lines.randomSplit(split, rngSeed);

    int count = 0;
    Random r = new Random(rngSeed);
    for (JavaRDD<String> subset : splitData) {
        String path = FilenameUtils.concat(outputDir, baseFileName + (count++) + ".csv");
        List<String> linesList = subset.collect();
        // Copy into an ArrayList unless the collected list already is one, before shuffling in place.
        if (!(linesList instanceof ArrayList)) {
            linesList = new ArrayList<>(linesList);
        }
        Collections.shuffle(linesList, r);
        FileUtils.writeLines(new File(path), linesList);
    }
}
 
Example 4
Source File: SparkExport.java    From deeplearning4j with Apache License 2.0 6 votes vote down vote up
/**
 * Writes {@code data} to {@code numFiles} CSV files on the local file system.
 *
 * <p>Records are rendered as text lines by {@code WritablesToStringFunction}, randomly split with
 * equal weight per file, and each part is written to
 * {@code outputDir/baseFileName + index + ".csv"} (index starting at 0).
 *
 * @param outputDir    directory the files are written to
 * @param baseFileName file name prefix; the file index and ".csv" are appended
 * @param numFiles     number of output files
 * @param delimiter    passed to {@code WritablesToStringFunction}
 * @param quote        passed to {@code WritablesToStringFunction}
 * @param data         records to export
 * @throws Exception if the conversion, the split or writing a file fails
 */
public static void exportCSVLocal(String outputDir, String baseFileName, int numFiles, String delimiter,
                String quote, JavaRDD<List<Writable>> data) throws Exception {

    final JavaRDD<String> csvLines = data.map(new WritablesToStringFunction(delimiter, quote));

    // One equal weight per output file.
    final double[] weights = new double[numFiles];
    int w = 0;
    while (w < weights.length) {
        weights[w] = 1.0 / numFiles;
        w++;
    }
    final JavaRDD<String>[] parts = csvLines.randomSplit(weights);

    for (int fileIdx = 0; fileIdx < parts.length; fileIdx++) {
        final String target = FilenameUtils.concat(outputDir, baseFileName + fileIdx + ".csv");
        // Each part is collected and written through FileUtils rather than saveAsTextFile.
        final List<String> content = parts[fileIdx].collect();
        FileUtils.writeLines(new File(target), content);
    }
}
 
Example 5
Source File: JavaNaiveBayesExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Naive Bayes example: loads LIBSVM-formatted data, randomly splits it with weights 0.6/0.4 into
 * a training and a test set, trains a model on the training set, computes the fraction of test
 * points whose prediction equals the label, then saves the model and loads it back.
 */
public static void main(String[] args) {
  JavaSparkContext jsc =
    new JavaSparkContext(new SparkConf().setAppName("JavaNaiveBayesExample"));
  // $example on$
  JavaRDD<LabeledPoint> inputData =
    MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
  JavaRDD<LabeledPoint>[] parts = inputData.randomSplit(new double[]{0.6, 0.4});
  JavaRDD<LabeledPoint> training = parts[0]; // weight 0.6
  JavaRDD<LabeledPoint> test = parts[1]; // weight 0.4
  final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);

  // Pair every test point's prediction with its actual label.
  JavaPairRDD<Double, Double> predictionAndLabel =
    test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        return new Tuple2<>(model.predict(point.features()), point.label());
      }
    });

  // Count the pairs where prediction and label agree, then normalise by the test set size.
  long numCorrect = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override
    public Boolean call(Tuple2<Double, Double> pair) {
      return pair._1().equals(pair._2());
    }
  }).count();
  double accuracy = numCorrect / (double) test.count();

  // Save and load model
  String modelPath = "target/tmp/myNaiveBayesModel";
  model.save(jsc.sc(), modelPath);
  NaiveBayesModel sameModel = NaiveBayesModel.load(jsc.sc(), modelPath);
  // $example off$

  jsc.stop();
}
 
Example 6
Source File: JavaLogisticRegressionWithLBFGSExample.java    From SparkDemo with MIT License 5 votes vote down vote up
/**
 * Logistic regression (LBFGS) example: loads LIBSVM-formatted data, splits it with weights
 * 0.6/0.4 using a fixed seed, trains a 10-class model on the cached training part, prints the
 * accuracy measured on the test part, then saves the model and loads it back.
 */
public static void main(String[] args) {
  SparkContext sc =
    new SparkContext(new SparkConf().setAppName("JavaLogisticRegressionWithLBFGSExample"));
  // $example on$
  JavaRDD<LabeledPoint> data =
    MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toJavaRDD();

  // Split initial RDD into two... [60% training data, 40% testing data].
  JavaRDD<LabeledPoint>[] parts = data.randomSplit(new double[] {0.6, 0.4}, 11L);
  JavaRDD<LabeledPoint> training = parts[0].cache();
  JavaRDD<LabeledPoint> test = parts[1];

  // Run training algorithm to build the model.
  LogisticRegressionWithLBFGS trainer = new LogisticRegressionWithLBFGS();
  final LogisticRegressionModel model = trainer.setNumClasses(10).run(training.rdd());

  // Compute (prediction, label) pairs on the test set.
  JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
    new Function<LabeledPoint, Tuple2<Object, Object>>() {
      @Override
      public Tuple2<Object, Object> call(LabeledPoint point) {
        Double predicted = model.predict(point.features());
        return new Tuple2<Object, Object>(predicted, point.label());
      }
    }
  );

  // Get evaluation metrics.
  MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());
  double accuracy = metrics.accuracy();
  System.out.println("Accuracy = " + accuracy);

  // Save and load model
  String modelPath = "target/tmp/javaLogisticRegressionWithLBFGSModel";
  model.save(sc, modelPath);
  LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, modelPath);
  // $example off$

  sc.stop();
}
 
Example 7
Source File: RandomForestMlib.java    From Java-Data-Science-Cookbook with MIT License 4 votes vote down vote up
/**
 * Random forest classification example: loads LIBSVM-formatted data, splits it with weights
 * 0.7/0.3 into training and test sets, trains a random forest classifier and prints the test
 * error (fraction of test points whose prediction differs from the label) and the learned model.
 * The Spark context is always stopped before the method returns.
 */
public static void main(String args[]){

		SparkConf configuration = new SparkConf().setMaster("local[4]").setAppName("Any");
		JavaSparkContext sc = new JavaSparkContext(configuration);

		try {
			// Load and parse the data file.
			String input = "data/rf-data.txt";
			JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc.sc(), input).toJavaRDD();
			// Split the data into training and test sets (30% held out for testing)
			JavaRDD<LabeledPoint>[] dataSplits = data.randomSplit(new double[]{0.7, 0.3});
			JavaRDD<LabeledPoint> trainingData = dataSplits[0];
			JavaRDD<LabeledPoint> testData = dataSplits[1];

			// Train a RandomForest model.
			Integer numClasses = 2;
			HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();//  Empty categoricalFeaturesInfo indicates all features are continuous.
			Integer numTrees = 3; // Use more in practice.
			String featureSubsetStrategy = "auto"; // Let the algorithm choose.
			String impurity = "gini";
			Integer maxDepth = 5;
			Integer maxBins = 32;
			Integer seed = 12345;

			final RandomForestModel rfModel = RandomForest.trainClassifier(trainingData, numClasses,
					categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
					seed);

			// Evaluate model on test instances and compute test error
			JavaPairRDD<Double, Double> label =
					testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
						@Override
						public Tuple2<Double, Double> call(LabeledPoint p) {
							return new Tuple2<Double, Double>(rfModel.predict(p.features()), p.label());
						}
					});

			// Multiplying by 1.0 first forces floating-point division of the two counts.
			Double testError =
					1.0 * label.filter(new Function<Tuple2<Double, Double>, Boolean>() {
						@Override
						public Boolean call(Tuple2<Double, Double> pl) {
							return !pl._1().equals(pl._2());
						}
					}).count() / testData.count();

			System.out.println("Test Error: " + testError);
			System.out.println("Learned classification forest model:\n" + rfModel.toDebugString());
		} finally {
			// Release the local Spark context even when loading, training or evaluation fails.
			sc.stop();
		}
	}
 
Example 8
Source File: JavaRandomForestRegressionExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Random forest regression example: loads LIBSVM-formatted data, splits it with weights 0.7/0.3
 * into training and test sets, trains a random forest regressor, prints the mean squared error
 * on the test set and the learned model, then saves the model and loads it back.
 */
public static void main(String[] args) {
  // $example on$
  JavaSparkContext jsc =
    new JavaSparkContext(new SparkConf().setAppName("JavaRandomForestRegressionExample"));
  // Load and parse the data file.
  JavaRDD<LabeledPoint> data =
    MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] parts = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = parts[0];
  JavaRDD<LabeledPoint> testData = parts[1];

  // Set parameters.
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  Integer numTrees = 3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "variance";
  Integer maxDepth = 4;
  Integer maxBins = 32;
  Integer seed = 12345;
  // Train a RandomForest model.
  final RandomForestModel model = RandomForest.trainRegressor(trainingData,
    categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed);

  // Pair every test point's prediction with its actual label.
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        return new Tuple2<>(model.predict(point.features()), point.label());
      }
    });

  // Squared error of every pair ...
  JavaRDD<Double> squaredErrors =
    predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
      @Override
      public Double call(Tuple2<Double, Double> pair) {
        Double diff = pair._1() - pair._2();
        return diff * diff;
      }
    });
  // ... summed up and averaged over the test set.
  Double squaredErrorSum = squaredErrors.reduce(new Function2<Double, Double, Double>() {
    @Override
    public Double call(Double left, Double right) {
      return left + right;
    }
  });
  Double testMSE = squaredErrorSum / testData.count();
  System.out.println("Test Mean Squared Error: " + testMSE);
  System.out.println("Learned regression forest model:\n" + model.toDebugString());

  // Save and load model
  String modelPath = "target/tmp/myRandomForestRegressionModel";
  model.save(jsc.sc(), modelPath);
  RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), modelPath);
  // $example off$

  jsc.stop();
}
 
Example 9
Source File: JavaRandomForestClassificationExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Random forest classification example: loads LIBSVM-formatted data, splits it with weights
 * 0.7/0.3 into training and test sets, trains a random forest classifier, prints the test error
 * (fraction of test points whose prediction differs from the label) and the learned model, then
 * saves the model and loads it back.
 */
public static void main(String[] args) {
  // $example on$
  JavaSparkContext jsc =
    new JavaSparkContext(new SparkConf().setAppName("JavaRandomForestClassificationExample"));
  // Load and parse the data file.
  JavaRDD<LabeledPoint> data =
    MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] parts = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = parts[0];
  JavaRDD<LabeledPoint> testData = parts[1];

  // Train a RandomForest model.
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Integer numClasses = 2;
  HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  Integer numTrees = 3; // Use more in practice.
  String featureSubsetStrategy = "auto"; // Let the algorithm choose.
  String impurity = "gini";
  Integer maxDepth = 5;
  Integer maxBins = 32;
  Integer seed = 12345;

  final RandomForestModel model = RandomForest.trainClassifier(trainingData, numClasses,
    categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins,
    seed);

  // Pair every test point's prediction with its actual label.
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        return new Tuple2<>(model.predict(point.features()), point.label());
      }
    });

  // Count the misclassified test points, then normalise by the test set size.
  long numWrong = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override
    public Boolean call(Tuple2<Double, Double> pair) {
      return !pair._1().equals(pair._2());
    }
  }).count();
  Double testErr = 1.0 * numWrong / testData.count();
  System.out.println("Test Error: " + testErr);
  System.out.println("Learned classification forest model:\n" + model.toDebugString());

  // Save and load model
  String modelPath = "target/tmp/myRandomForestClassificationModel";
  model.save(jsc.sc(), modelPath);
  RandomForestModel sameModel = RandomForestModel.load(jsc.sc(), modelPath);
  // $example off$

  jsc.stop();
}
 
Example 10
Source File: JavaGradientBoostingClassificationExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Gradient-boosted trees classification example: loads LIBSVM-formatted data, splits it with
 * weights 0.7/0.3 into training and test sets, trains a GBT classifier, prints the test error
 * (fraction of test points whose prediction differs from the label) and the learned model, then
 * saves the model and loads it back.
 */
public static void main(String[] args) {
  // $example on$
  JavaSparkContext jsc = new JavaSparkContext(
    new SparkConf().setAppName("JavaGradientBoostedTreesClassificationExample"));

  // Load and parse the data file.
  JavaRDD<LabeledPoint> data =
    MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] parts = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = parts[0];
  JavaRDD<LabeledPoint> testData = parts[1];

  // Train a GradientBoostedTrees model.
  // The defaultParams for Classification use LogLoss by default.
  BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Classification");
  boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
  boostingStrategy.getTreeStrategy().setNumClasses(2);
  boostingStrategy.getTreeStrategy().setMaxDepth(5);
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

  final GradientBoostedTreesModel model =
    GradientBoostedTrees.train(trainingData, boostingStrategy);

  // Pair every test point's prediction with its actual label.
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint point) {
        return new Tuple2<>(model.predict(point.features()), point.label());
      }
    });

  // Count the misclassified test points, then normalise by the test set size.
  long numWrong = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
    @Override
    public Boolean call(Tuple2<Double, Double> pair) {
      return !pair._1().equals(pair._2());
    }
  }).count();
  Double testErr = 1.0 * numWrong / testData.count();
  System.out.println("Test Error: " + testErr);
  System.out.println("Learned classification GBT model:\n" + model.toDebugString());

  // Save and load model
  String modelPath = "target/tmp/myGradientBoostingClassificationModel";
  model.save(jsc.sc(), modelPath);
  GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(), modelPath);
  // $example off$

  jsc.stop();
}
 
Example 11
Source File: JavaGradientBoostingRegressionExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Gradient-boosted trees regression example: loads LIBSVM-formatted data, splits it with weights
 * 0.7/0.3 into training and test sets, trains a GBT regressor, prints the mean squared error on
 * the test set and the learned model, then saves the model and loads it back.
 */
public static void main(String[] args) {
  // $example on$
  SparkConf sparkConf = new SparkConf()
    .setAppName("JavaGradientBoostedTreesRegressionExample");
  JavaSparkContext jsc = new JavaSparkContext(sparkConf);
  // Load and parse the data file.
  String datapath = "data/mllib/sample_libsvm_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
  // Split the data into training and test sets (30% held out for testing)
  JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
  JavaRDD<LabeledPoint> trainingData = splits[0];
  JavaRDD<LabeledPoint> testData = splits[1];

  // Train a GradientBoostedTrees model.
  // The defaultParams for Regression use SquaredError by default.
  BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams("Regression");
  boostingStrategy.setNumIterations(3); // Note: Use more iterations in practice.
  boostingStrategy.getTreeStrategy().setMaxDepth(5);
  // Empty categoricalFeaturesInfo indicates all features are continuous.
  Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
  boostingStrategy.treeStrategy().setCategoricalFeaturesInfo(categoricalFeaturesInfo);

  final GradientBoostedTreesModel model =
    GradientBoostedTrees.train(trainingData, boostingStrategy);

  // Evaluate model on test instances and compute test error
  JavaPairRDD<Double, Double> predictionAndLabel =
    testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
      @Override
      public Tuple2<Double, Double> call(LabeledPoint p) {
        return new Tuple2<>(model.predict(p.features()), p.label());
      }
    });
  // The squared errors are summed over the test set only, so the mean must be taken over
  // testData.count(); dividing by the size of the full data set would understate the error.
  Double testMSE =
    predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
      @Override
      public Double call(Tuple2<Double, Double> pl) {
        Double diff = pl._1() - pl._2();
        return diff * diff;
      }
    }).reduce(new Function2<Double, Double, Double>() {
      @Override
      public Double call(Double a, Double b) {
        return a + b;
      }
    }) / testData.count();
  System.out.println("Test Mean Squared Error: " + testMSE);
  System.out.println("Learned regression GBT model:\n" + model.toDebugString());

  // Save and load model
  model.save(jsc.sc(), "target/tmp/myGradientBoostingRegressionModel");
  GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(jsc.sc(),
    "target/tmp/myGradientBoostingRegressionModel");
  // $example off$

  jsc.stop();
}
 
Example 12
Source File: JavaDecisionTreeClassificationExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Decision tree classification example: loads LIBSVM-formatted data, splits it with weights
 * 0.7/0.3 into training and test sets, trains a decision tree classifier, prints the test error
 * (fraction of test points whose prediction differs from the label) and the learned model, saves
 * the model and loads it back, and finally stops the Spark context.
 */
public static void main(String[] args) {

    // $example on$
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeClassificationExample");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    // Load and parse the data file.
    String datapath = "data/mllib/sample_libsvm_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    // Split the data into training and test sets (30% held out for testing)
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    //  Empty categoricalFeaturesInfo indicates all features are continuous.
    Integer numClasses = 2;
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 32;

    // Train a DecisionTree model for classification.
    final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Evaluate model on test instances and compute test error
    JavaPairRDD<Double, Double> predictionAndLabel =
      testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override
        public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<>(model.predict(p.features()), p.label());
        }
      });
    // Multiplying by 1.0 first forces floating-point division of the two counts.
    Double testErr =
      1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
        @Override
        public Boolean call(Tuple2<Double, Double> pl) {
          return !pl._1().equals(pl._2());
        }
      }).count() / testData.count();

    System.out.println("Test Error: " + testErr);
    System.out.println("Learned classification tree model:\n" + model.toDebugString());

    // Save and load model
    model.save(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
    DecisionTreeModel sameModel = DecisionTreeModel
      .load(jsc.sc(), "target/tmp/myDecisionTreeClassificationModel");
    // $example off$

    // Shut the context down once the example is done instead of leaving it running.
    jsc.stop();
  }
 
Example 13
Source File: JavaDecisionTreeRegressionExample.java    From SparkDemo with MIT License 4 votes vote down vote up
/**
 * Decision tree regression example: loads LIBSVM-formatted data, splits it with weights 0.7/0.3
 * into training and test sets, trains a decision tree regressor, prints the mean squared error
 * on the test set and the learned model, saves the model and loads it back, and finally stops
 * the Spark context.
 */
public static void main(String[] args) {

    // $example on$
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTreeRegressionExample");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    // Load and parse the data file.
    String datapath = "data/mllib/sample_libsvm_data.txt";
    JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(jsc.sc(), datapath).toJavaRDD();
    // Split the data into training and test sets (30% held out for testing)
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
    String impurity = "variance";
    Integer maxDepth = 5;
    Integer maxBins = 32;

    // Train a DecisionTree model.
    final DecisionTreeModel model = DecisionTree.trainRegressor(trainingData,
      categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Evaluate model on test instances and compute test error
    JavaPairRDD<Double, Double> predictionAndLabel =
      testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
        @Override
        public Tuple2<Double, Double> call(LabeledPoint p) {
          return new Tuple2<>(model.predict(p.features()), p.label());
        }
      });
    // The squared errors are summed over the test set only, so the mean must be taken over
    // testData.count(); dividing by the size of the full data set would understate the error.
    Double testMSE =
      predictionAndLabel.map(new Function<Tuple2<Double, Double>, Double>() {
        @Override
        public Double call(Tuple2<Double, Double> pl) {
          Double diff = pl._1() - pl._2();
          return diff * diff;
        }
      }).reduce(new Function2<Double, Double, Double>() {
        @Override
        public Double call(Double a, Double b) {
          return a + b;
        }
      }) / testData.count();
    System.out.println("Test Mean Squared Error: " + testMSE);
    System.out.println("Learned regression tree model:\n" + model.toDebugString());

    // Save and load model
    model.save(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
    DecisionTreeModel sameModel = DecisionTreeModel
      .load(jsc.sc(), "target/tmp/myDecisionTreeRegressionModel");
    // $example off$

    // Shut the context down once the example is done instead of leaving it running.
    jsc.stop();
  }
 
Example 14
Source File: SparkUtils.java    From DataVec with Apache License 2.0 4 votes vote down vote up
/**
 * Splits {@code data} according to the given strategy.
 *
 * <p>Only {@code RandomSplit} is handled: the data is randomly split with weights
 * {@code fractionTrain} and {@code 1.0 - fractionTrain}, and the resulting RDDs are returned
 * in that order.
 *
 * @param splitStrategy strategy describing how to split; must be a {@code RandomSplit}
 * @param data          RDD to split
 * @param seed          seed passed to the random split
 * @return list holding the RDDs produced by the split, in the order of the weights
 * @throws RuntimeException if {@code splitStrategy} is not a {@code RandomSplit}
 */
public static <T> List<JavaRDD<T>> splitData(SplitStrategy splitStrategy, JavaRDD<T> data, long seed) {

        // Reject every strategy other than RandomSplit up front.
        if (!(splitStrategy instanceof RandomSplit)) {
            throw new RuntimeException("Not yet implemented");
        }

        final double trainFraction = ((RandomSplit) splitStrategy).getFractionTrain();
        final double[] weights = new double[] {trainFraction, 1.0 - trainFraction};
        final JavaRDD<T>[] parts = data.randomSplit(weights, seed);

        final List<JavaRDD<T>> result = new ArrayList<>(2);
        for (JavaRDD<T> part : parts) {
            result.add(part);
        }
        return result;
    }
 
Example 15
Source File: SparkUtils.java    From deeplearning4j with Apache License 2.0 4 votes vote down vote up
/**
 * Splits {@code data} according to the given strategy.
 *
 * <p>Only {@code RandomSplit} is handled: the data is randomly split with weights
 * {@code fractionTrain} and {@code 1.0 - fractionTrain}, and the resulting RDDs are returned
 * in that order.
 *
 * @param splitStrategy strategy describing how to split; must be a {@code RandomSplit}
 * @param data          RDD to split
 * @param seed          seed passed to the random split
 * @return list holding the RDDs produced by the split, in the order of the weights
 * @throws RuntimeException if {@code splitStrategy} is not a {@code RandomSplit}
 */
public static <T> List<JavaRDD<T>> splitData(SplitStrategy splitStrategy, JavaRDD<T> data, long seed) {

        // Reject every strategy other than RandomSplit up front.
        if (!(splitStrategy instanceof RandomSplit)) {
            throw new RuntimeException("Not yet implemented");
        }

        final double trainFraction = ((RandomSplit) splitStrategy).getFractionTrain();
        final double[] weights = new double[] {trainFraction, 1.0 - trainFraction};
        final JavaRDD<T>[] parts = data.randomSplit(weights, seed);

        final List<JavaRDD<T>> result = new ArrayList<>(2);
        for (JavaRDD<T> part : parts) {
            result.add(part);
        }
        return result;
    }