Java Code Examples for org.deeplearning4j.spark.util.SparkUtils#repartitionEqually()

The following examples show how to use org.deeplearning4j.spark.util.SparkUtils#repartitionEqually() . These examples are extracted from open source projects. You can vote up the examples you find helpful or vote down the ones you don't, and you can navigate to the original project or source file via the links above each example. You can also check out the related API usage examples in the sidebar.
Example 1
/**
 * Performs one fit pass on a single split of {@code DataSet} training data using the given
 * {@link SparkDl4jMultiLayer} network.
 * <p>
 * Flow: log the split, optionally record stats timings, repartition the data (custom
 * {@code repartitioner} if configured, otherwise {@link SparkUtils#repartitionEqually}),
 * run the shared-training flat-map over partitions, then hand results to {@code processResults}.
 *
 * @param network   the Spark-wrapped MultiLayerNetwork to train
 * @param split     the RDD of DataSet objects for this split
 * @param splitNum  1-based index of this split (used for logging only)
 * @param numSplits total number of splits (used for logging only)
 */
protected void doIteration(SparkDl4jMultiLayer network, JavaRDD<DataSet> split, int splitNum, int numSplits) {
    log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers",
                    splitNum, numSplits, batchSizePerWorker, thresholdAlgorithm, numWorkers);

    if (collectTrainingStats) {
        stats.logMapPartitionsStart();
        stats.logRepartitionStart();
    }

    JavaRDD<DataSet> trainingData;
    if (repartitioner == null) {
        log.info("Repartitioning training data using SparkUtils repartitioner");
        trainingData = SparkUtils.repartitionEqually(split, repartition, numWorkers);
    } else {
        log.info("Repartitioning training data using repartitioner: {}", repartitioner);
        // Minimum objects per worker; assumes rddDataSetNumExamples > 0 — TODO confirm at call site
        int minObjectsPerWorker = Math.max(1, batchSizePerWorker / rddDataSetNumExamples);
        trainingData = repartitioner.repartition(split, minObjectsPerWorker, numWorkers);
    }
    // Capture partition count after repartitioning, for the stats log at the end
    int partitionCount = trainingData.partitions().size();

    // NOTE(review): end-of-repartition stat is skipped for Repartition.Never even though the
    // start stat above is unconditional — preserved as-is from the original behavior.
    if (collectTrainingStats && repartition != Repartition.Never) {
        stats.logRepartitionEnd();
    }

    FlatMapFunction<Iterator<DataSet>, SharedTrainingResult> workerFn =
            new SharedFlatMapDataSet<>(getWorkerInstance(network));

    JavaRDD<SharedTrainingResult> workerResults = trainingData.mapPartitions(workerFn);

    processResults(network, null, workerResults);

    if (collectTrainingStats) {
        stats.logMapPartitionsEnd(partitionCount);
    }
}
 
Example 2
/**
 * Performs one fit pass on a single split of {@code MultiDataSet} training data using the given
 * {@link SparkComputationGraph}.
 * <p>
 * Mirrors the {@code DataSet} variant: log the split, optionally record stats timings,
 * repartition (custom {@code repartitioner} if set, otherwise
 * {@link SparkUtils#repartitionEqually}), execute the shared-training flat-map, and pass
 * the per-worker results to {@code processResults}.
 *
 * @param network   the Spark-wrapped ComputationGraph to train
 * @param split     the RDD of MultiDataSet objects for this split
 * @param splitNum  1-based index of this split (used for logging only)
 * @param numSplits total number of splits (used for logging only)
 */
protected void doIterationMDS(SparkComputationGraph network, JavaRDD<MultiDataSet> split, int splitNum,
                int numSplits) {
    log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers",
                    splitNum, numSplits, batchSizePerWorker, thresholdAlgorithm, numWorkers);

    if (collectTrainingStats) {
        stats.logMapPartitionsStart();
        stats.logRepartitionStart();
    }

    JavaRDD<MultiDataSet> trainingData;
    if (repartitioner == null) {
        log.info("Repartitioning training data using SparkUtils repartitioner");
        trainingData = SparkUtils.repartitionEqually(split, repartition, numWorkers);
    } else {
        log.info("Repartitioning training data using repartitioner: {}", repartitioner);
        // Minimum objects per worker; assumes rddDataSetNumExamples > 0 — TODO confirm at call site
        int minObjectsPerWorker = Math.max(1, batchSizePerWorker / rddDataSetNumExamples);
        trainingData = repartitioner.repartition(split, minObjectsPerWorker, numWorkers);
    }
    // Capture partition count after repartitioning, for the stats log at the end
    int partitionCount = trainingData.partitions().size();

    // NOTE(review): end-of-repartition stat is skipped for Repartition.Never even though the
    // start stat above is unconditional — preserved as-is from the original behavior.
    if (collectTrainingStats && repartition != Repartition.Never) {
        stats.logRepartitionEnd();
    }

    FlatMapFunction<Iterator<MultiDataSet>, SharedTrainingResult> workerFn =
            new SharedFlatMapMultiDataSet<>(getWorkerInstance(network));

    JavaRDD<SharedTrainingResult> workerResults = trainingData.mapPartitions(workerFn);

    processResults(null, network, workerResults);

    if (collectTrainingStats) {
        stats.logMapPartitionsEnd(partitionCount);
    }
}
 
Example 3
/**
 * Performs one fit pass on a single split of {@code DataSet} training data using the given
 * {@link SparkComputationGraph}.
 * <p>
 * Same pipeline as the MultiLayer variant: log the split, optionally record stats timings,
 * repartition (custom {@code repartitioner} if set, otherwise
 * {@link SparkUtils#repartitionEqually}), run the shared-training flat-map, then hand the
 * results to {@code processResults}.
 *
 * @param network   the Spark-wrapped ComputationGraph to train
 * @param data      the RDD of DataSet objects for this split
 * @param splitNum  1-based index of this split (used for logging only)
 * @param numSplits total number of splits (used for logging only)
 */
protected void doIteration(SparkComputationGraph network, JavaRDD<DataSet> data, int splitNum, int numSplits) {
    log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers",
                    splitNum, numSplits, batchSizePerWorker, thresholdAlgorithm, numWorkers);

    if (collectTrainingStats) {
        stats.logMapPartitionsStart();
        stats.logRepartitionStart();
    }

    JavaRDD<DataSet> trainingData;
    if (repartitioner == null) {
        log.info("Repartitioning training data using SparkUtils repartitioner");
        trainingData = SparkUtils.repartitionEqually(data, repartition, numWorkers);
    } else {
        log.info("Repartitioning training data using repartitioner: {}", repartitioner);
        // Minimum objects per worker; assumes rddDataSetNumExamples > 0 — TODO confirm at call site
        int minObjectsPerWorker = Math.max(1, batchSizePerWorker / rddDataSetNumExamples);
        trainingData = repartitioner.repartition(data, minObjectsPerWorker, numWorkers);
    }
    // Capture partition count after repartitioning, for the stats log at the end
    int partitionCount = trainingData.partitions().size();

    // NOTE(review): end-of-repartition stat is skipped for Repartition.Never even though the
    // start stat above is unconditional — preserved as-is from the original behavior.
    if (collectTrainingStats && repartition != Repartition.Never) {
        stats.logRepartitionEnd();
    }

    FlatMapFunction<Iterator<DataSet>, SharedTrainingResult> workerFn =
            new SharedFlatMapDataSet<>(getWorkerInstance(network));

    JavaRDD<SharedTrainingResult> workerResults = trainingData.mapPartitions(workerFn);

    processResults(null, network, workerResults);

    if (collectTrainingStats) {
        stats.logMapPartitionsEnd(partitionCount);
    }
}
 
Example 4
/**
 * Performs one fit pass on a split of training data given as an RDD of file paths, for either a
 * {@link SparkDl4jMultiLayer} OR a {@link SparkComputationGraph} (exactly one is expected to be
 * non-null).
 * <p>
 * Paths are repartitioned (custom {@code repartitioner} if set, otherwise
 * {@link SparkUtils#repartitionEqually}), then each partition of paths is processed by a
 * shared-training flat-map that loads the data via the supplied loader — {@code dsLoader} for
 * DataSet paths when present, {@code mdsLoader} for MultiDataSet paths otherwise.
 *
 * @param network                  the MultiLayer network to train, or null if training a graph
 * @param graph                    the ComputationGraph to train, or null if training a network
 * @param data                     RDD of paths (presumably remote/local file URIs — resolved by the loaders)
 * @param splitNum                 1-based index of this split (used for logging only)
 * @param numSplits                total number of splits (used for logging only)
 * @param dsLoader                 loader for DataSet objects; may be null if mdsLoader is provided
 * @param mdsLoader                loader for MultiDataSet objects; used only when dsLoader is null
 * @param dataSetObjectNumExamples number of examples per serialized data set object
 * @throws DL4JInvalidConfigException if both network and graph are null
 */
protected void doIterationPaths(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> data,
                int splitNum, int numSplits, DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectNumExamples) {
    if (network == null && graph == null) {
        throw new DL4JInvalidConfigException("Both MLN & CompGraph are NULL");
    }

    log.info("Starting training of split {} of {}. workerMiniBatchSize={}, thresholdAlgorithm={}, Configured for {} workers",
                    splitNum, numSplits, batchSizePerWorker, thresholdAlgorithm, numWorkers);

    if (collectTrainingStats) {
        stats.logMapPartitionsStart();
        stats.logRepartitionStart();
    }

    JavaRDD<String> paths;
    if (repartitioner == null) {
        log.info("Repartitioning training data using SparkUtils repartitioner");
        paths = SparkUtils.repartitionEqually(data, repartition, numWorkers);
    } else {
        log.info("Repartitioning training data using repartitioner: {}", repartitioner);
        // Minimum objects per worker; assumes dataSetObjectNumExamples > 0 — TODO confirm at call site
        int minObjectsPerWorker = Math.max(1, batchSizePerWorker / dataSetObjectNumExamples);
        paths = repartitioner.repartition(data, minObjectsPerWorker, numWorkers);
    }
    // Capture partition count after repartitioning, for the stats log at the end
    int partitionCount = paths.partitions().size();

    // NOTE(review): end-of-repartition stat is skipped for Repartition.Never even though the
    // start stat above is unconditional — preserved as-is from the original behavior.
    if (collectTrainingStats && repartition != Repartition.Never) {
        stats.logRepartitionEnd();
    }

    JavaSparkContext sc = (network != null ? network.getSparkContext() : graph.getSparkContext());

    // DataSet loader takes priority; fall back to the MultiDataSet loader when it is absent
    FlatMapFunction<Iterator<String>, SharedTrainingResult> workerFn;
    if (dsLoader == null) {
        workerFn = new SharedFlatMapPathsMDS<>(
                network != null ? getWorkerInstance(network) : getWorkerInstance(graph), mdsLoader, BroadcastHadoopConfigHolder.get(sc));
    } else {
        workerFn = new SharedFlatMapPaths<>(
                network != null ? getWorkerInstance(network) : getWorkerInstance(graph), dsLoader, BroadcastHadoopConfigHolder.get(sc));
    }

    JavaRDD<SharedTrainingResult> workerResults = paths.mapPartitions(workerFn);

    processResults(network, graph, workerResults);

    if (collectTrainingStats) {
        stats.logMapPartitionsEnd(partitionCount);
    }
}