Java Code Examples for org.apache.spark.api.java.JavaRDD#treeAggregate()

The following examples show how to use `org.apache.spark.api.java.JavaRDD#treeAggregate()`. You can vote up the examples you find useful and vote down the ones you don't, and follow the links above each example to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: BaseRecalibratorSparkFn.java — from gatk, BSD 3-Clause "New" or "Revised" License (6 votes)
/**
 * Run the {@link BaseRecalibrationEngine} on reads and overlapping variants.
 * <p>
 * Each partition builds its own engine and emits a single {@link RecalibrationTables}. The
 * per-partition tables are merged with {@code treeAggregate}, finalized, quantized and packaged
 * into a {@link RecalibrationReport}.
 *
 * @param readsWithVariants the RDD of reads with overlapping variants
 * @param header the reads header
 * @param referenceFileName the name of the reference file added via {@code SparkContext#addFile()}
 * @param recalArgs arguments to use during recalibration
 * @return the recalibration report object
 */
public static RecalibrationReport apply(final JavaPairRDD<GATKRead, Iterable<GATKVariant>> readsWithVariants, final SAMFileHeader header, final String referenceFileName, final RecalibrationArgumentCollection recalArgs) {
    final JavaRDD<RecalibrationTables> unmergedTables = readsWithVariants.mapPartitions(readsWithVariantsIterator -> {
        final String pathOnExecutor = SparkFiles.get(referenceFileName);
        // Close the reference when this partition is done. The stream below is fully consumed
        // before the try block exits, so nothing reads from the source after close().
        try (final ReferenceDataSource referenceDataSource = new ReferenceFileSource(IOUtils.getPath(pathOnExecutor))) {
            final BaseRecalibrationEngine bqsr = new BaseRecalibrationEngine(recalArgs, header);
            bqsr.logCovariatesUsed();
            Utils.stream(readsWithVariantsIterator).forEach(t -> bqsr.processRead(t._1, referenceDataSource, t._2));
            return Iterators.singletonIterator(bqsr.getRecalibrationTables());
        }
    });

    // Built once and shared by the aggregation zero value and the final report.
    final StandardCovariateList covariates = new StandardCovariateList(recalArgs, header);
    final RecalibrationTables emptyRecalibrationTable = new RecalibrationTables(covariates);

    // Tree depth ~ log2(#partitions), clamped to at least 1 because Spark's treeAggregate requires
    // depth >= 1. log2 of 0 or 1 partitions is <= 0, or even -Infinity cast to Integer.MIN_VALUE.
    final int numPartitions = unmergedTables.partitions().size();
    final int aggregationDepth = Math.max(1, (int) (Math.log(numPartitions) / Math.log(2)));

    // inPlaceCombine is used both within a partition (seqOp) and across partitions (combOp).
    final RecalibrationTables combinedTables = unmergedTables.treeAggregate(emptyRecalibrationTable,
            RecalibrationTables::inPlaceCombine,
            RecalibrationTables::inPlaceCombine,
            aggregationDepth);

    BaseRecalibrationEngine.finalizeRecalibrationTables(combinedTables);

    final QuantizationInfo quantizationInfo = new QuantizationInfo(combinedTables, recalArgs.QUANTIZING_LEVELS);

    return RecalUtils.createRecalibrationReport(recalArgs.generateReportTable(covariates.covariateNames()), quantizationInfo.generateReportTable(), RecalUtils.generateReportTables(combinedTables, covariates));
}
 
Example 2
Source File: SparkDl4jMultiLayer.java — from deeplearning4j, Apache License 2.0 (5 votes)
/**
 * Performs distributed evaluation over data referenced by path strings, loading each entry with
 * {@code loader} or {@code mdsLoader}.
 *
 * @param data             paths of the data to evaluate on
 * @param evalNumWorkers   number of evaluation workers to use; must be at least 1
 * @param evalBatchSize    evaluation minibatch size
 * @param loader           loader for DataSet files
 * @param mdsLoader        loader for MultiDataSet files
 * @param emptyEvaluations empty evaluation instances used as the starting point
 * @return the merged evaluations. NOTE(review): may be null if no partition produced a result,
 *         since null is the treeAggregate zero value; confirm this is acceptable to callers.
 * @throws IllegalArgumentException if {@code evalNumWorkers < 1}
 */
protected IEvaluation[] doEvaluation(JavaRDD<String> data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, MultiDataSetLoader mdsLoader, IEvaluation... emptyEvaluations){
    // Validate first, so no configuration/parameter broadcasts are created for an invalid argument.
    Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers);
    IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(sc.broadcast(conf.toJson()),
            SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, loader, mdsLoader,
            BroadcastHadoopConfigHolder.get(sc), emptyEvaluations);
    JavaRDD<IEvaluation[]> evaluations = data.mapPartitions(evalFn);
    // The same merge function serves as both seqOp and combOp.
    return evaluations.treeAggregate(null, new IEvaluateAggregateFunction<>(), new IEvaluateAggregateFunction<>());
}
 
Example 3
Source File: SparkComputationGraph.java — from deeplearning4j, Apache License 2.0 (5 votes)
/**
 * Evaluates the graph on an RDD of {@link MultiDataSet}s, computing every supplied evaluation in a
 * single distributed pass.
 *
 * @param data             the MultiDataSets to evaluate on
 * @param evalNumWorkers   number of evaluation workers (model copies) to use; must be at least 1
 * @param evalBatchSize    evaluation minibatch size
 * @param emptyEvaluations empty evaluation instances used as the starting point
 * @param <T>              type of evaluation instance
 * @return the merged evaluations. NOTE(review): may be null if no partition produced a result,
 *         since null is the treeAggregate zero value.
 */
public <T extends IEvaluation> T[] doEvaluationMDS(JavaRDD<MultiDataSet> data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) {
    Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers);

    // Evaluates one partition using the broadcast configuration JSON and network parameters.
    IEvaluateMDSFlatMapFunction<T> perPartitionEval = new IEvaluateMDSFlatMapFunction<>(
            sc.broadcast(conf.toJson()),
            SparkUtils.asByteArrayBroadcast(sc, network.params()),
            evalNumWorkers,
            evalBatchSize,
            emptyEvaluations);

    // Identical merge logic within a partition (seqOp) and across partitions (combOp).
    IEvaluateAggregateFunction<T> seqOp = new IEvaluateAggregateFunction<T>();
    IEvaluateAggregateFunction<T> combOp = new IEvaluateAggregateFunction<T>();
    return data.mapPartitions(perPartitionEval).treeAggregate(null, seqOp, combOp);
}
 
Example 4
Source File: SparkComputationGraph.java — from deeplearning4j, Apache License 2.0 (5 votes)
/**
 * Performs distributed evaluation over data referenced by path strings, loading each entry with
 * {@code loader} or {@code mdsLoader}.
 *
 * @param data             paths of the data to evaluate on
 * @param evalNumWorkers   number of evaluation workers to use; must be at least 1
 * @param evalBatchSize    evaluation minibatch size
 * @param loader           loader for DataSet files
 * @param mdsLoader        loader for MultiDataSet files
 * @param emptyEvaluations empty evaluation instances used as the starting point
 * @return the merged evaluations. NOTE(review): may be null if no partition produced a result,
 *         since null is the treeAggregate zero value.
 * @throws IllegalArgumentException if {@code evalNumWorkers < 1}
 */
protected IEvaluation[] doEvaluation(JavaRDD<String> data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, MultiDataSetLoader mdsLoader, IEvaluation... emptyEvaluations){
    // Validate first, so no configuration/parameter broadcasts are created for an invalid argument
    // (matches doEvaluationMDS).
    Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers);
    IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(sc.broadcast(conf.toJson()),
            SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, loader, mdsLoader,
            BroadcastHadoopConfigHolder.get(sc), emptyEvaluations);
    JavaRDD<IEvaluation[]> evaluations = data.mapPartitions(evalFn);
    // The same merge function serves as both seqOp and combOp.
    return evaluations.treeAggregate(null, new IEvaluateAggregateFunction<>(), new IEvaluateAggregateFunction<>());
}
 
Example 5
Source File: SharedTrainingMaster.java — from deeplearning4j, Apache License 2.0 (4 votes)
/**
 * Aggregates the per-worker {@link SharedTrainingResult}s and applies the merged outcome to the
 * local model. {@code network} is used when non-null; otherwise {@code graph} is used.
 * <p>
 * The merged updater state, if present, is divided by the aggregation count and copied into the
 * model's updater state view. The mean score is set on the model. Listener metadata, static info
 * and updates are forwarded to {@code statsStorage} when one is configured. The merged threshold
 * algorithm is kept for the next epoch.
 *
 * @param network the multi-layer network being trained, or null when training a graph
 * @param graph   the computation graph being trained, or null when training a multi-layer network
 * @param results per-worker training results; aggregating them evaluates this RDD
 * @throws IllegalStateException if both {@code network} and {@code graph} are null, or setup has
 *                               not been completed
 */
protected void processResults(SparkDl4jMultiLayer network, SparkComputationGraph graph,
                JavaRDD<SharedTrainingResult> results) {
    Preconditions.checkState(network != null || graph != null, "Both MLN & CG are null");
    Preconditions.checkState(setupDone, "Setup was not completed before trying to process results");

    if (collectTrainingStats) {
        stats.logAggregateStartTime();
    }

    // treeAggregate is a Spark action: this call is what evaluates `results` (aggregation tree depth 4).
    // NOTE(review): for an empty RDD, treeAggregate returns the null zero value and the next line
    // would throw an NPE. Confirm callers never pass an empty RDD.
    SharedTrainingAccumulationTuple finalResult = results.treeAggregate(null, new SharedTrainingAggregateFunction(),
                    new SharedTrainingAccumulationFunction(), 4);
    SparkTrainingStats aggregatedStats = finalResult.getSparkTrainingStats();
    if (collectTrainingStats) {
        stats.logAggregationEndTime();
    }

    // finalizeTraining has to be *after* training has completed, i.e. after the treeAggregate
    // action above has evaluated the results RDD.
    finalizeTraining();

    if (collectTrainingStats) {
        stats.logProcessParamsUpdaterStart();
    }

    if (finalResult.getUpdaterStateArray() != null) {
        // Presumably the accumulated updater state is a sum across aggregations; divide to average it.
        if (finalResult.getAggregationsCount() > 1) {
            finalResult.getUpdaterStateArray().divi(finalResult.getAggregationsCount());
        }

        if (network != null) {
            if (network.getNetwork().getUpdater() != null
                            && network.getNetwork().getUpdater().getStateViewArray() != null) {
                network.getNetwork().getUpdater().getStateViewArray().assign(finalResult.getUpdaterStateArray());
            }
        } else {
            if (graph.getNetwork().getUpdater() != null
                            && graph.getNetwork().getUpdater().getStateViewArray() != null) {
                graph.getNetwork().getUpdater().getStateViewArray().assign(finalResult.getUpdaterStateArray());
            }
        }
    }

    // Mean score over aggregations; max(1, ...) avoids division by zero.
    double score = finalResult.getScoreSum() / Math.max(1, finalResult.getAggregationsCount());

    if (network != null) {
        network.getNetwork().setScore(score);
    } else {
        graph.getNetwork().setScore(score);
    }

    if (collectTrainingStats) {
        // Log the end of param/updater processing exactly once, then merge in the worker stats.
        stats.logProcessParamsUpdaterEnd();
        stats.addWorkerStats(aggregatedStats);
    }

    if (statsStorage != null) {
        Collection<StorageMetaData> meta = finalResult.getListenerMetaData();
        if (meta != null && !meta.isEmpty()) {
            statsStorage.putStorageMetaData(meta);
        }

        Collection<Persistable> staticInfo = finalResult.getListenerStaticInfo();
        if (staticInfo != null && !staticInfo.isEmpty()) {
            statsStorage.putStaticInfo(staticInfo);
        }

        Collection<Persistable> updates = finalResult.getListenerUpdates();
        if (updates != null && !updates.isEmpty()) {
            statsStorage.putUpdate(updates);
        }
    }

    if (logMinibatchesPerWorker && finalResult.getMinibatchesPerExecutor() != null) {
        // Log executors in sorted key order so the output is stable across epochs.
        List<String> l = new ArrayList<>(finalResult.getMinibatchesPerExecutor().keySet());
        Collections.sort(l);
        Map<String,Integer> linkedMap = new LinkedHashMap<>();
        for (String s : l) {
            linkedMap.put(s, finalResult.getMinibatchesPerExecutor().get(s));
        }
        log.info("Number of minibatches processed per JVM/executor: {}", linkedMap);
    }

    if (finalResult.getThresholdAlgorithmReducer() != null) {
        //Store the final threshold algorithm after aggregation
        //Some threshold algorithms contain state/history, used to adapt the threshold algorithm
        //The idea is we want to keep this history/state for next epoch, rather than simply throwing it away
        // and starting the threshold adaption process from scratch on each epoch
        this.thresholdAlgorithm = finalResult.getThresholdAlgorithmReducer().getFinalResult();
    }

    Nd4j.getExecutioner().commit();
}
 
Example 6
Source File: SparkDl4jMultiLayer.java — from deeplearning4j, Apache License 2.0 (3 votes)
/**
 * Perform distributed evaluation of any type of {@link IEvaluation} - or multiple IEvaluation instances.
 * Distributed equivalent of {@link MultiLayerNetwork#doEvaluation(DataSetIterator, IEvaluation[])}
 *
 * @param data             Data to evaluate on
 * @param emptyEvaluations Empty evaluation instances. Starting point (serialized/duplicated, then merged)
 * @param evalNumWorkers   Number of workers (copies of the MultiLayerNetwork model) to use; must be at least 1.
 *                         Generally this should be smaller than the number of threads - 2 to 4 is often good
 *                         enough. If using CUDA GPUs, this should ideally be set to the number of GPUs on each
 *                         node (i.e., 1 for a single GPU node)
 * @param evalBatchSize    Evaluation batch size
 * @param <T>              Type of evaluation instance to return
 * @return IEvaluation instances. NOTE(review): may be null if no partition produced a result,
 *         since null is the treeAggregate zero value.
 * @throws IllegalArgumentException if {@code evalNumWorkers < 1}
 */
public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) {
    // Same validation as the path-based doEvaluation. Checked before any broadcast is created.
    Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers);
    IEvaluateFlatMapFunction<T> evalFn = new IEvaluateFlatMapFunction<>(false, sc.broadcast(conf.toJson()),
                    SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
    JavaRDD<T[]> evaluations = data.mapPartitions(evalFn);
    // NOTE(review): combOp here is IEvaluationReduceFunction, while sibling methods use
    // IEvaluateAggregateFunction for both ops. Confirm both handle null partial results identically.
    return evaluations.treeAggregate(null, new IEvaluateAggregateFunction<T>(), new IEvaluationReduceFunction<T>());
}
 
Example 7
Source File: SparkComputationGraph.java — from deeplearning4j, Apache License 2.0 (3 votes)
/**
 * Perform distributed evaluation on a <i>single output</i> ComputationGraph from DataSet objects using Spark.
 * Can be used to perform multiple evaluations on this single output (for example, {@link Evaluation} and
 * {@link ROC}) at the same time.<br>
 *
 * @param data             Data to evaluate
 * @param evalNumWorkers   Number of worker threads (per machine) to use for evaluation; must be at least 1.
 *                         You may want this to be less than the number of Spark threads per machine/JVM to
 *                         reduce memory requirements
 * @param evalBatchSize    Minibatch size for evaluation
 * @param emptyEvaluations Evaluations to perform
 * @return                 Evaluations. NOTE(review): may be null if no partition produced a result,
 *                         since null is the treeAggregate zero value.
 * @throws IllegalArgumentException if {@code evalNumWorkers < 1}
 */
public <T extends IEvaluation> T[] doEvaluation(JavaRDD<DataSet> data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) {
    // Same validation as doEvaluationMDS. Checked before any broadcast is created.
    Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers);
    IEvaluateFlatMapFunction<T> evalFn = new IEvaluateFlatMapFunction<>(true, sc.broadcast(conf.toJson()),
            SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations);
    JavaRDD<T[]> evaluations = data.mapPartitions(evalFn);
    // The same merge function serves as both seqOp and combOp.
    return evaluations.treeAggregate(null, new IEvaluateAggregateFunction<T>(),
                    new IEvaluateAggregateFunction<T>());
}