Java Code Examples for org.apache.spark.api.java.JavaRDD#cache()

The following examples show how to use org.apache.spark.api.java.JavaRDD#cache(). They are drawn from open-source projects; the source file, originating project, and license are listed above each example.
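Before the project examples, here is a minimal, self-contained sketch of the basic cache() pattern that most of the examples below follow: mark a reused RDD as cached, run at least one action so the cache is actually populated, and unpersist the RDD once it is no longer needed. The class name, the local master setting, and the in-memory input data are illustrative assumptions, not taken from any of the projects listed here.

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class JavaRDDCacheSketch {
  public static void main(String[] args) {
    // Local master and toy data are assumptions for illustration only.
    SparkConf conf = new SparkConf().setAppName("JavaRDDCacheSketch").setMaster("local[*]");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    JavaRDD<Integer> numbers = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5));

    // cache() only marks the RDD for storage at the default MEMORY_ONLY level;
    // nothing is materialized until an action runs.
    JavaRDD<Integer> squares = numbers.map(x -> x * x).cache();

    // Both actions below reuse the cached partitions instead of recomputing map().
    long count = squares.count();
    int sum = squares.reduce(Integer::sum);
    System.out.println("count=" + count + ", sum=" + sum);

    // Release the cached blocks once the RDD is no longer needed.
    squares.unpersist();
    jsc.stop();
  }
}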
Example 1
Source File: SparkCacheOperator.java    From rheem with Apache License 2.0
@Override
public Tuple<Collection<ExecutionLineageNode>, Collection<ChannelInstance>> evaluate(
        ChannelInstance[] inputs,
        ChannelInstance[] outputs,
        SparkExecutor sparkExecutor,
        OptimizationContext.OperatorContext operatorContext) {
    RddChannel.Instance input = (RddChannel.Instance) inputs[0];
    final JavaRDD<Object> rdd = input.provideRdd();
    final JavaRDD<Object> cachedRdd = rdd.cache();
    // Run a no-op action on each partition so the cached RDD is materialized eagerly.
    cachedRdd.foreachPartition(iterator -> {
    });

    RddChannel.Instance output = (RddChannel.Instance) outputs[0];
    output.accept(cachedRdd, sparkExecutor);

    return ExecutionOperator.modelQuasiEagerExecution(inputs, outputs, operatorContext);
}
 
Example 2
Source File: AnalyzeSpark.java    From deeplearning4j with Apache License 2.0
public static DataAnalysis analyze(Schema schema, JavaRDD<List<Writable>> data, int maxHistogramBuckets) {
    data.cache();
    /*
     * TODO: Some care should be given to add histogramBuckets and histogramBucketCounts to this in the future
     */

    List<ColumnType> columnTypes = schema.getColumnTypes();
    List<AnalysisCounter> counters =
                    data.aggregate(null, new AnalysisAddFunction(schema), new AnalysisCombineFunction());

    double[][] minsMaxes = new double[counters.size()][2];
    List<ColumnAnalysis> list = DataVecAnalysisUtils.convertCounters(counters, minsMaxes, columnTypes);

    List<HistogramCounter> histogramCounters =
                    data.aggregate(null, new HistogramAddFunction(maxHistogramBuckets, schema, minsMaxes),
                                    new HistogramCombineFunction());

    DataVecAnalysisUtils.mergeCounters(list, histogramCounters);
    return new DataAnalysis(schema, list);
}
 
Example 3
Source File: AssemblyContigAlignmentsConfigPicker.java    From gatk with BSD 3-Clause "New" or "Revised" License
/**
 * Parses input alignments into the custom {@link AlignmentInterval} format, and
 * performs primitive filtering, implemented in
 * {@link #notDiscardForBadMQ(AlignedContig)}, that
 * gets rid of contigs with no good alignments.
 *
 * It is important to remember that this step doesn't select alignments,
 * but only parses alignments and either keeps the whole contig or drops it completely.
 */
private static JavaRDD<AlignedContig> convertRawAlignmentsToAlignedContigAndFilterByQuality(final JavaRDD<GATKRead> assemblyAlignments,
                                                                                            final SAMFileHeader header,
                                                                                            final Logger toolLogger) {
    assemblyAlignments.cache();
    toolLogger.info( "Processing " + assemblyAlignments.count() + " raw alignments from " +
                     assemblyAlignments.map(GATKRead::getName).distinct().count() + " contigs.");

    final JavaRDD<AlignedContig> parsedContigAlignments =
            new SvDiscoverFromLocalAssemblyContigAlignmentsSpark.SAMFormattedContigAlignmentParser(assemblyAlignments, header, false)
                    .getAlignedContigs()
                    .filter(AssemblyContigAlignmentsConfigPicker::notDiscardForBadMQ).cache();
    assemblyAlignments.unpersist();
    toolLogger.info( "Filtering on MQ left " + parsedContigAlignments.count() + " contigs.");
    return parsedContigAlignments;
}
 
Example 4
Source File: MLLibUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Convert an RDD of DataSet into an RDD of LabeledPoint.
 * @param data the DataSet RDD to convert
 * @param preCache whether to cache the RDD before the conversion
 * @return an RDD of LabeledPoint
 */
public static JavaRDD<LabeledPoint> fromDataSet(JavaRDD<DataSet> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<DataSet, LabeledPoint>() {
        @Override
        public LabeledPoint call(DataSet dataSet) {
            return toLabeledPoint(dataSet);
        }
    });
}
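A brief note on the preCache guard above (this note is not part of the original source): cache() is shorthand for persist(StorageLevel.MEMORY_ONLY), and an RDD's storage level cannot be changed once one has been assigned, so the getStorageLevel().useMemory() check skips the call for RDDs that are already persisted at a memory-backed level rather than attempting to reassign one. A minimal sketch of the same guard, using hypothetical data:

import java.util.Arrays;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

public class PreCacheGuardSketch {
  public static void main(String[] args) {
    // Local master and toy data are assumptions for illustration only.
    JavaSparkContext jsc = new JavaSparkContext(
        new SparkConf().setAppName("PreCacheGuardSketch").setMaster("local[*]"));

    JavaRDD<Double> values = jsc.parallelize(Arrays.asList(1.0, 2.0, 3.0));
    values.persist(StorageLevel.MEMORY_AND_DISK());

    // Same guard as in MLLibUtil: only call cache() when no memory-backed
    // storage level has been assigned yet, since a level cannot be reassigned.
    if (!values.getStorageLevel().useMemory()) {
      values.cache();
    }

    System.out.println("count = " + values.count());
    jsc.stop();
  }
}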
 
Example 5
Source File: JavaGaussianMixtureExample.java    From SparkDemo with MIT License
public static void main(String[] args) {

    SparkConf conf = new SparkConf().setAppName("JavaGaussianMixtureExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse data
    String path = "data/mllib/gmm_data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
      new Function<String, Vector>() {
        public Vector call(String s) {
          String[] sarray = s.trim().split(" ");
          double[] values = new double[sarray.length];
          for (int i = 0; i < sarray.length; i++) {
            values[i] = Double.parseDouble(sarray[i]);
          }
          return Vectors.dense(values);
        }
      }
    );
    parsedData.cache();

    // Cluster the data into two classes using GaussianMixture
    GaussianMixtureModel gmm = new GaussianMixture().setK(2).run(parsedData.rdd());

    // Save and load GaussianMixtureModel
    gmm.save(jsc.sc(), "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel");
    GaussianMixtureModel sameModel = GaussianMixtureModel.load(jsc.sc(),
      "target/org/apache/spark/JavaGaussianMixtureExample/GaussianMixtureModel");

    // Output the parameters of the mixture model
    for (int j = 0; j < gmm.k(); j++) {
      System.out.printf("weight=%f\nmu=%s\nsigma=\n%s\n",
        gmm.weights()[j], gmm.gaussians()[j].mu(), gmm.gaussians()[j].sigma());
    }
    // $example off$

    jsc.stop();
}
 
Example 6
Source File: MLLibUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Converts an RDD of LabeledPoint into an RDD of DataSet.
 * @param data JavaRDD of LabeledPoint
 * @param numPossibleLabels the number of possible labels
 * @param preCache whether to cache the RDD before the conversion
 * @return JavaRDD of DataSet
 */
public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final long numPossibleLabels,
                boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return fromLabeledPoint(lp, numPossibleLabels);
        }
    });
}
 
Example 7
Source File: MLLibUtil.java    From deeplearning4j with Apache License 2.0
/**
 * Converts an RDD of LabeledPoint with continuous labels into an RDD of DataSet.
 * @param data JavaRDD of LabeledPoint
 * @param preCache whether to cache the RDD before the conversion
 * @return JavaRDD of DataSet
 */
public static JavaRDD<DataSet> fromContinuousLabeledPoint(JavaRDD<LabeledPoint> data, boolean preCache) {
    if (preCache && !data.getStorageLevel().useMemory()) {
        data.cache();
    }
    return data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return convertToDataset(lp);
        }
    });
}
 
Example 8
Source File: CollectMultipleMetricsSpark.java    From gatk with BSD 3-Clause "New" or "Revised" License
@Override
protected void runTool( final JavaSparkContext ctx ) {
    final JavaRDD<GATKRead> unFilteredReads = getUnfilteredReads();
    List<SparkCollectorProvider> collectorsToRun = getCollectorsToRun();
    if (collectorsToRun.size() > 1) {
        // if there is more than one collector to run, cache the
        // unfiltered RDD so we don't recompute it
        unFilteredReads.cache();
    }
    for (final SparkCollectorProvider provider : collectorsToRun) {
        MetricsCollectorSpark<? extends MetricsArgumentCollection> metricsCollector =
                provider.createCollector(
                    outputBaseName,
                    metricAccumulationLevel.accumulationLevels,
                    getDefaultHeaders(),
                    getHeaderForReads()
                );
        validateCollector(metricsCollector, collectorsToRun.get(collectorsToRun.indexOf(provider)).getClass().getName());

        // Execute the collector's lifecycle

        //Bypass the framework merging of command line filters and just apply the default
        //ones specified by the collector
        ReadFilter readFilter = ReadFilter.fromList(metricsCollector.getDefaultReadFilters(), getHeaderForReads());

        metricsCollector.collectMetrics(
                unFilteredReads.filter(r -> readFilter.test(r)),
                getHeaderForReads()
        );
        metricsCollector.saveMetrics(getReadSourceName());
    }
}
 
Example 9
Source File: SparkSharder.java    From gatk with BSD 3-Clause "New" or "Revised" License
private static <L extends Locatable, I extends Locatable, T> JavaRDD<T> joinOverlapping(JavaSparkContext ctx, JavaRDD<L> locatables, Class<L> locatableClass,
                                                                                        SAMSequenceDictionary sequenceDictionary, JavaRDD<I> intervals,
                                                                                        int maxLocatableLength, FlatMapFunction2<Iterator<L>, Iterator<I>, T> f) {

    List<PartitionLocatable<SimpleInterval>> partitionReadExtents = computePartitionReadExtents(locatables, sequenceDictionary, maxLocatableLength);
    List<SimpleInterval> firstLocatablesList = partitionReadExtents.stream().map(PartitionLocatable::getLocatable).collect(Collectors.toList());
    Broadcast<List<SimpleInterval>> firstLocatablesBroadcast = ctx.broadcast(firstLocatablesList);

    // For each interval find which partition it starts and ends in.
    // An interval is processed in the partition it starts in. However, we need to make sure that
    // subsequent partitions are coalesced if needed, so for each partition p find the latest subsequent
    // partition that is needed to read all of the intervals that start in p.
    OverlapDetector<PartitionLocatable<SimpleInterval>> overlapDetector = OverlapDetector.create(partitionReadExtents);
    Broadcast<OverlapDetector<PartitionLocatable<SimpleInterval>>> overlapDetectorBroadcast = ctx.broadcast(overlapDetector);
    JavaRDD<PartitionLocatable<I>> indexedIntervals = intervals.map(interval -> {
        int[] partitionIndexes = overlapDetectorBroadcast.getValue().getOverlaps(interval).stream()
                .mapToInt(PartitionLocatable::getPartitionIndex).toArray();
        if (partitionIndexes.length == 0) {
            final List<SimpleInterval> firstLocatables = firstLocatablesBroadcast.getValue();
            // interval does not overlap any partition - add it to the one after the interval start
            int i = Collections.binarySearch(firstLocatables, new SimpleInterval(interval), (o1, o2) -> IntervalUtils.compareLocatables(o1, o2, sequenceDictionary));
            if (i >= 0) {
                throw new IllegalStateException(); // TODO: no overlaps, yet start of interval matches a partition read extent start
            }
            int insertionPoint = -i - 1;
            if (insertionPoint == firstLocatables.size()) {
                insertionPoint = firstLocatables.size() - 1;
            }
            return new PartitionLocatable<>(insertionPoint, interval);
        }
        Arrays.sort(partitionIndexes);
        int startIndex = partitionIndexes[0];
        int endIndex = partitionIndexes[partitionIndexes.length - 1];
        return new PartitionLocatable<>(startIndex, endIndex, interval);
    });

    // Create an RDD of intervals with the same number of partitions as the locatables, and where each interval
    // is in its start partition. Within each partition, intervals are sorted by IntervalUtils#compareLocatables.
    JavaRDD<PartitionLocatable<I>> indexedIntervalsRepartitioned = indexedIntervals
            .mapToPair(interval ->
                    new Tuple2<>(interval, (Void) null))
            .repartitionAndSortWithinPartitions(new PartitionLocatablePartitioner(locatables.getNumPartitions()), new PartitionLocatableComparator<I>(sequenceDictionary))
            .keys();

    indexedIntervalsRepartitioned.cache(); // cache since we need to do two calculations on the intervals

    // Find the end partition index for each partition.
    Map<Integer, Integer> maxEndPartitionIndexesMap = indexedIntervalsRepartitioned.mapToPair((PairFunction<PartitionLocatable<I>, Integer, Integer>) partitionLocatable ->
            new Tuple2<>(partitionLocatable.getPartitionIndex(), partitionLocatable.getEndPartitionIndex()))
            .reduceByKey((Function2<Integer, Integer, Integer>) Math::max)
            .collectAsMap();
    List<Integer> maxEndPartitionIndexes = IntStream.range(0, locatables.getNumPartitions()).boxed().collect(Collectors.toList());
    maxEndPartitionIndexesMap.forEach((startIndex, endIndex) -> {
        if (endIndex > maxEndPartitionIndexes.get(startIndex)) {
            maxEndPartitionIndexes.set(startIndex, endIndex);
        }
    });

    JavaRDD<L> coalescedRdd = coalesce(locatables, locatableClass, new RangePartitionCoalescer(maxEndPartitionIndexes));

    // zipPartitions on coalesced locatable partitions and intervals, and apply the function f
    return coalescedRdd.zipPartitions(indexedIntervalsRepartitioned.map(PartitionLocatable::getLocatable), f);
}
 
Example 10
Source File: DataSparkFromRDD.java    From toolbox with Apache License 2.0
public DataSparkFromRDD(JavaRDD<DataInstance> input, Attributes atts) {

    // FIXME: is this a good idea?
    amidstRDD = input.cache();
    attributes = atts;
}
 
Example 11
Source File: ALSUpdate.java    From oryx with Apache License 2.0
@Override
public PMML buildModel(JavaSparkContext sparkContext,
                       JavaRDD<String> trainData,
                       List<?> hyperParameters,
                       Path candidatePath) {
  int features = (Integer) hyperParameters.get(0);
  double lambda = (Double) hyperParameters.get(1);
  double alpha = (Double) hyperParameters.get(2);
  double epsilon = Double.NaN;
  if (logStrength) {
    epsilon = (Double) hyperParameters.get(3);
  }
  Preconditions.checkArgument(features > 0);
  Preconditions.checkArgument(lambda >= 0.0);
  Preconditions.checkArgument(alpha > 0.0);
  if (logStrength) {
    Preconditions.checkArgument(epsilon > 0.0);
  }

  JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN);
  parsedRDD.cache();

  Map<String,Integer> userIDIndexMap = buildIDIndexMapping(parsedRDD, true);
  Map<String,Integer> itemIDIndexMap = buildIDIndexMapping(parsedRDD, false);

  log.info("Broadcasting ID-index mappings for {} users, {} items",
           userIDIndexMap.size(), itemIDIndexMap.size());

  Broadcast<Map<String,Integer>> bUserIDToIndex = sparkContext.broadcast(userIDIndexMap);
  Broadcast<Map<String,Integer>> bItemIDToIndex = sparkContext.broadcast(itemIDIndexMap);

  JavaRDD<Rating> trainRatingData = parsedToRatingRDD(parsedRDD, bUserIDToIndex, bItemIDToIndex);
  trainRatingData = aggregateScores(trainRatingData, epsilon);
  ALS als = new ALS()
      .setRank(features)
      .setIterations(iterations)
      .setLambda(lambda)
      .setCheckpointInterval(5);
  if (implicit) {
    als = als.setImplicitPrefs(true).setAlpha(alpha);
  }

  RDD<Rating> trainingRatingDataRDD = trainRatingData.rdd();
  trainingRatingDataRDD.cache();
  MatrixFactorizationModel model = als.run(trainingRatingDataRDD);
  trainingRatingDataRDD.unpersist(false);

  bUserIDToIndex.unpersist();
  bItemIDToIndex.unpersist();

  parsedRDD.unpersist();

  Broadcast<Map<Integer,String>> bUserIndexToID = sparkContext.broadcast(invertMap(userIDIndexMap));
  Broadcast<Map<Integer,String>> bItemIndexToID = sparkContext.broadcast(invertMap(itemIDIndexMap));

  PMML pmml = mfModelToPMML(model,
                            features,
                            lambda,
                            alpha,
                            epsilon,
                            implicit,
                            logStrength,
                            candidatePath,
                            bUserIndexToID,
                            bItemIndexToID);
  unpersist(model);

  bUserIndexToID.unpersist();
  bItemIndexToID.unpersist();

  return pmml;
}
 
Example 12
Source File: ALSUpdate.java    From oryx with Apache License 2.0
@Override
public double evaluate(JavaSparkContext sparkContext,
                       PMML model,
                       Path modelParentPath,
                       JavaRDD<String> testData,
                       JavaRDD<String> trainData) {

  JavaRDD<String[]> parsedTestRDD = testData.map(MLFunctions.PARSE_FN);
  parsedTestRDD.cache();

  Map<String,Integer> userIDToIndex = buildIDIndexOneWayMap(model, parsedTestRDD, true);
  Map<String,Integer> itemIDToIndex = buildIDIndexOneWayMap(model, parsedTestRDD, false);

  log.info("Broadcasting ID-index mappings for {} users, {} items",
           userIDToIndex.size(), itemIDToIndex.size());

  Broadcast<Map<String,Integer>> bUserIDToIndex = sparkContext.broadcast(userIDToIndex);
  Broadcast<Map<String,Integer>> bItemIDToIndex = sparkContext.broadcast(itemIDToIndex);

  JavaRDD<Rating> testRatingData = parsedToRatingRDD(parsedTestRDD, bUserIDToIndex, bItemIDToIndex);
  double epsilon = Double.NaN;
  if (logStrength) {
    epsilon = Double.parseDouble(AppPMMLUtils.getExtensionValue(model, "epsilon"));
  }
  testRatingData = aggregateScores(testRatingData, epsilon);

  MatrixFactorizationModel mfModel =
      pmmlToMFModel(sparkContext, model, modelParentPath, bUserIDToIndex, bItemIDToIndex);

  parsedTestRDD.unpersist();

  double eval;
  if (implicit) {
    double auc = Evaluation.areaUnderCurve(sparkContext, mfModel, testRatingData);
    log.info("AUC: {}", auc);
    eval = auc;
  } else {
    double rmse = Evaluation.rmse(mfModel, testRatingData);
    log.info("RMSE: {}", rmse);
    eval = -rmse;
  }
  unpersist(mfModel);

  bUserIDToIndex.unpersist();
  bItemIDToIndex.unpersist();

  return eval;
}
 
Example 13
Source File: MLUpdate.java    From oryx with Apache License 2.0
@Override
public void runUpdate(JavaSparkContext sparkContext,
                      long timestamp,
                      JavaPairRDD<Object,M> newKeyMessageData,
                      JavaPairRDD<Object,M> pastKeyMessageData,
                      String modelDirString,
                      TopicProducer<String,String> modelUpdateTopic)
    throws IOException, InterruptedException {

  Objects.requireNonNull(newKeyMessageData);

  JavaRDD<M> newData = newKeyMessageData.values();
  JavaRDD<M> pastData = pastKeyMessageData == null ? null : pastKeyMessageData.values();

  if (newData != null) {
    newData.cache();
    // This forces caching of the RDD. This shouldn't be necessary but we see some freezes
    // when many workers try to materialize the RDDs at once. Hence the workaround.
    newData.foreachPartition(p -> {});
  }
  if (pastData != null) {
    pastData.cache();
    pastData.foreachPartition(p -> {});
  }

  List<List<?>> hyperParameterCombos = HyperParams.chooseHyperParameterCombos(
      getHyperParameterValues(), hyperParamSearch, candidates);

  Path modelDir = new Path(modelDirString);
  Path tempModelPath = new Path(modelDir, ".temporary");
  Path candidatesPath = new Path(tempModelPath, Long.toString(System.currentTimeMillis()));

  FileSystem fs = FileSystem.get(modelDir.toUri(), sparkContext.hadoopConfiguration());
  fs.mkdirs(candidatesPath);

  Path bestCandidatePath = findBestCandidatePath(
      sparkContext, newData, pastData, hyperParameterCombos, candidatesPath);

  Path finalPath = new Path(modelDir, Long.toString(System.currentTimeMillis()));
  if (bestCandidatePath == null) {
    log.info("Unable to build any model");
  } else {
    // Move best model into place
    fs.rename(bestCandidatePath, finalPath);
  }
  // Then delete everything else
  fs.delete(candidatesPath, true);

  if (modelUpdateTopic == null) {
    log.info("No update topic configured, not publishing models to a topic");
  } else {
    // Push PMML model onto update topic, if it exists
    Path bestModelPath = new Path(finalPath, MODEL_FILE_NAME);
    if (fs.exists(bestModelPath)) {
      FileStatus bestModelPathFS = fs.getFileStatus(bestModelPath);
      PMML bestModel = null;
      boolean modelNeededForUpdates = canPublishAdditionalModelData();
      boolean modelNotTooLarge = bestModelPathFS.getLen() <= maxMessageSize;
      if (modelNeededForUpdates || modelNotTooLarge) {
        // Either the model is required for publishAdditionalModelData, or required because it's going to
        // be serialized to Kafka
        try (InputStream in = fs.open(bestModelPath)) {
          bestModel = PMMLUtils.read(in);
        }
      }

      if (modelNotTooLarge) {
        modelUpdateTopic.send("MODEL", PMMLUtils.toString(bestModel));
      } else {
        modelUpdateTopic.send("MODEL-REF", fs.makeQualified(bestModelPath).toString());
      }

      if (modelNeededForUpdates) {
        publishAdditionalModelData(
            sparkContext, bestModel, newData, pastData, finalPath, modelUpdateTopic);
      }
    }
  }

  if (newData != null) {
    newData.unpersist();
  }
  if (pastData != null) {
    pastData.unpersist();
  }
}
 
Example 14
Source File: AnalyzeSpark.java    From DataVec with Apache License 2.0
/**
 * Analyze the data quality of the given data, producing a per-column quality report.
 * @param schema the schema of the data
 * @param data the data to analyze
 * @return a DataQualityAnalysis containing a ColumnQuality entry for each column
 */
public static DataQualityAnalysis analyzeQuality(final Schema schema, final JavaRDD<List<Writable>> data) {
    data.cache();
    int nColumns = schema.numColumns();


    List<ColumnType> columnTypes = schema.getColumnTypes();
    List<QualityAnalysisState> states = data.aggregate(null, new QualityAnalysisAddFunction(schema),
                    new QualityAnalysisCombineFunction());

    List<ColumnQuality> list = new ArrayList<>(nColumns);

    for (QualityAnalysisState qualityState : states) {
        list.add(qualityState.getColumnQuality());
    }

    return new DataQualityAnalysis(schema, list);

}
 
Example 15
Source File: JavaLinearRegressionWithSGDExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("JavaLinearRegressionWithSGDExample");
  JavaSparkContext sc = new JavaSparkContext(conf);

  // $example on$
  // Load and parse the data
  String path = "data/mllib/ridge-data/lpsa.data";
  JavaRDD<String> data = sc.textFile(path);
  JavaRDD<LabeledPoint> parsedData = data.map(
    new Function<String, LabeledPoint>() {
      public LabeledPoint call(String line) {
        String[] parts = line.split(",");
        String[] features = parts[1].split(" ");
        double[] v = new double[features.length];
        for (int i = 0; i < features.length; i++) {
          v[i] = Double.parseDouble(features[i]);
        }
        return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
      }
    }
  );
  parsedData.cache();

  // Building the model
  int numIterations = 100;
  double stepSize = 0.00000001;
  final LinearRegressionModel model =
    LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize);

  // Evaluate model on training examples and compute training error
  JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData.map(
    new Function<LabeledPoint, Tuple2<Double, Double>>() {
      public Tuple2<Double, Double> call(LabeledPoint point) {
        double prediction = model.predict(point.features());
        return new Tuple2<>(prediction, point.label());
      }
    }
  );
  double MSE = new JavaDoubleRDD(valuesAndPreds.map(
    new Function<Tuple2<Double, Double>, Object>() {
      public Object call(Tuple2<Double, Double> pair) {
        return Math.pow(pair._1() - pair._2(), 2.0);
      }
    }
  ).rdd()).mean();
  System.out.println("training Mean Squared Error = " + MSE);

  // Save and load model
  model.save(sc.sc(), "target/tmp/javaLinearRegressionWithSGDModel");
  LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(),
    "target/tmp/javaLinearRegressionWithSGDModel");
  // $example off$

  sc.stop();
}
 
Example 16
Source File: JavaKMeansExample.java    From SparkDemo with MIT License
public static void main(String[] args) {

    SparkConf conf = new SparkConf().setAppName("JavaKMeansExample");
    JavaSparkContext jsc = new JavaSparkContext(conf);

    // $example on$
    // Load and parse data
    String path = "data/mllib/kmeans_data.txt";
    JavaRDD<String> data = jsc.textFile(path);
    JavaRDD<Vector> parsedData = data.map(
      new Function<String, Vector>() {
        public Vector call(String s) {
          String[] sarray = s.split(" ");
          double[] values = new double[sarray.length];
          for (int i = 0; i < sarray.length; i++) {
            values[i] = Double.parseDouble(sarray[i]);
          }
          return Vectors.dense(values);
        }
      }
    );
    parsedData.cache();

    // Cluster the data into two classes using KMeans
    int numClusters = 2;
    int numIterations = 20;
    KMeansModel clusters = KMeans.train(parsedData.rdd(), numClusters, numIterations);

    System.out.println("Cluster centers:");
    for (Vector center: clusters.clusterCenters()) {
      System.out.println(" " + center);
    }
    double cost = clusters.computeCost(parsedData.rdd());
    System.out.println("Cost: " + cost);

    // Evaluate clustering by computing Within Set Sum of Squared Errors
    double WSSSE = clusters.computeCost(parsedData.rdd());
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    // Save and load model
    clusters.save(jsc.sc(), "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    KMeansModel sameModel = KMeansModel.load(jsc.sc(),
      "target/org/apache/spark/JavaKMeansExample/KMeansModel");
    // $example off$

    jsc.stop();
}
 
Example 17
Source File: JavaSVMWithSGDExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkConf conf = new SparkConf().setAppName("JavaSVMWithSGDExample");
  SparkContext sc = new SparkContext(conf);
  // $example on$
  String path = "data/mllib/sample_libsvm_data.txt";
  JavaRDD<LabeledPoint> data = MLUtils.loadLibSVMFile(sc, path).toJavaRDD();

  // Split initial RDD into two... [60% training data, 40% testing data].
  JavaRDD<LabeledPoint> training = data.sample(false, 0.6, 11L);
  training.cache();
  JavaRDD<LabeledPoint> test = data.subtract(training);

  // Run training algorithm to build the model.
  int numIterations = 100;
  final SVMModel model = SVMWithSGD.train(training.rdd(), numIterations);

  // Clear the default threshold.
  model.clearThreshold();

  // Compute raw scores on the test set.
  JavaRDD<Tuple2<Object, Object>> scoreAndLabels = test.map(
    new Function<LabeledPoint, Tuple2<Object, Object>>() {
      public Tuple2<Object, Object> call(LabeledPoint p) {
        Double score = model.predict(p.features());
        return new Tuple2<Object, Object>(score, p.label());
      }
    }
  );

  // Get evaluation metrics.
  BinaryClassificationMetrics metrics =
    new BinaryClassificationMetrics(JavaRDD.toRDD(scoreAndLabels));
  double auROC = metrics.areaUnderROC();

  System.out.println("Area under ROC = " + auROC);

  // Save and load model
  model.save(sc, "target/tmp/javaSVMWithSGDModel");
  SVMModel sameModel = SVMModel.load(sc, "target/tmp/javaSVMWithSGDModel");
  // $example off$

  sc.stop();
}