Java Code Examples for org.apache.spark.api.java.JavaRDD#persist()

The following examples show how to use org.apache.spark.api.java.JavaRDD#persist(). Each example is drawn from an open-source project; the source file, project, and license are noted above each snippet.
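Before looking at the project examples, here is a minimal, self-contained sketch of the typical persist()/unpersist() lifecycle. It is not taken from any of the projects below and the class name is made up: the RDD is persisted once, reused across two actions, and then released.

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

public class PersistLifecycleSketch {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setMaster("local").setAppName("PersistLifecycleSketch");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        JavaRDD<Integer> numbers = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5));

        // Persist once so the RDD is materialized by the first action and reused by the second.
        numbers.persist(StorageLevel.MEMORY_AND_DISK_SER());

        long count = numbers.count();                 // first action populates the cache
        int sum = numbers.reduce(Integer::sum);       // second action reads the cached blocks
        System.out.println("count=" + count + ", sum=" + sum);

        // Release the cached blocks once the RDD is no longer needed.
        numbers.unpersist();
        jsc.stop();
    }
}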
Example 1
Source File: PathSeqBwaSpark.java    From gatk with BSD 3-Clause "New" or "Revised" License
/**
 * Writes RDD of reads to path. Note writeReads() is not used because there are separate paired/unpaired outputs.
 * Header sequence dictionary is reduced to only those that were aligned to.
 */
private void writeBam(final JavaRDD<GATKRead> reads, final String inputBamPath, final boolean isPaired,
                      final JavaSparkContext ctx, SAMFileHeader header) {

    //Only retain header sequences that were aligned to.
    //This invokes an action and therefore the reads must be cached.
    reads.persist(StorageLevel.MEMORY_AND_DISK_SER());
    header = PSBwaUtils.removeUnmappedHeaderSequences(header, reads, logger);

    final String outputPath = isPaired ? outputPaired : outputUnpaired;
    try {
        ReadsSparkSink.writeReads(ctx, outputPath, null, reads, header,
                shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE,
                PSUtils.pathseqGetRecommendedNumReducers(inputBamPath, numReducers, getTargetPartitionSize()), shardedPartsDir, true, splittingIndexGranularity);
    } catch (final IOException e) {
        throw new UserException.CouldNotCreateOutputFile(outputPath, "Writing failed", e);
    }
}
 
Example 2
Source File: ParameterAveragingTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingDirect(SparkComputationGraph graph, JavaRDD<MultiDataSet> trainingData) {
    if (collectTrainingStats)
        stats.logFitStart();
    //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
    // number of minibatches between averaging
    //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
    if (storageLevel != null)
        trainingData.persist(storageLevel);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);

    JavaRDD<MultiDataSet>[] splits =
                    getSplitRDDs(trainingData, (int) totalDataSetObjectCount, rddDataSetNumExamples);

    int splitNum = 1;
    for (JavaRDD<MultiDataSet> split : splits) {
        doIteration(graph, split, splitNum++, splits.length);
    }

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 3
Source File: ParameterAveragingTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingPathsHelper(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths,
                                          DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectsNumExamples) {
    if (numWorkers == null)
        numWorkers = network.getSparkContext().defaultParallelism();

    if (collectTrainingStats)
        stats.logFitStart();
    if (storageLevelStreams != null)
        trainingDataPaths.persist(storageLevelStreams);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingDataPaths);
    JavaRDD<String>[] splits =
                    getSplitRDDs(trainingDataPaths, (int) totalDataSetObjectCount, dataSetObjectsNumExamples);

    int splitNum = 1;
    for (JavaRDD<String> split : splits) {
        doIterationPaths(network, graph, split, splitNum++, splits.length, dataSetObjectsNumExamples, dsLoader, mdsLoader);
    }

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 4
Source File: ParameterAveragingTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingDirect(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
    if (collectTrainingStats)
        stats.logFitStart();
    //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
    // number of minibatches between averaging
    //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
    if (storageLevel != null)
        trainingData.persist(storageLevel);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);
    JavaRDD<DataSet>[] splits = getSplitRDDs(trainingData, (int) totalDataSetObjectCount, rddDataSetNumExamples);

    int splitNum = 1;
    for (JavaRDD<DataSet> split : splits) {
        doIteration(network, split, splitNum++, splits.length);
    }

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 5
Source File: SharedTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingPathsHelper(SparkDl4jMultiLayer network, SparkComputationGraph graph, JavaRDD<String> trainingDataPaths,
                                          DataSetLoader dsLoader, MultiDataSetLoader mdsLoader, int dataSetObjectsNumExamples) {

    if (numWorkers == null) {
        if(network != null){
            numWorkers = network.getSparkContext().defaultParallelism();
        } else {
            numWorkers = graph.getSparkContext().defaultParallelism();
        }
    }

    if (collectTrainingStats)
        stats.logFitStart();

    if (storageLevelStreams != null)
        trainingDataPaths.persist(storageLevelStreams);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingDataPaths);

    doIterationPaths(network, graph, trainingDataPaths, 1, 1, dsLoader, mdsLoader, dataSetObjectsNumExamples);

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 6
Source File: SharedTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingDirect(SparkComputationGraph network, JavaRDD<DataSet> trainingData) {
    if (collectTrainingStats)
        stats.logFitStart();

    //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
    // number of minibatches between averaging
    //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
    if (storageLevel != null)
        trainingData.persist(storageLevel);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);

    // since this is real distributed training, we don't need to split data
    doIteration(network, trainingData, 1, 1);

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 7
Source File: SharedTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingDirectMDS(SparkComputationGraph network, JavaRDD<MultiDataSet> trainingData) {
    if (collectTrainingStats)
        stats.logFitStart();

    //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
    // number of minibatches between averaging
    //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
    if (storageLevel != null)
        trainingData.persist(storageLevel);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);

    // since this is real distributed training, we don't need to split data
    doIterationMDS(network, trainingData, 1, 1);

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 8
Source File: SharedTrainingMaster.java    From deeplearning4j with Apache License 2.0
protected void executeTrainingDirect(SparkDl4jMultiLayer network, JavaRDD<DataSet> trainingData) {
    if (collectTrainingStats)
        stats.logFitStart();

    //For "vanilla" parameter averaging training, we need to split the full data set into batches of size N, such that we can process the specified
    // number of minibatches between averaging
    //But to do that, we need to know: (a) the number of examples, and (b) the number of workers
    if (storageLevel != null)
        trainingData.persist(storageLevel);

    long totalDataSetObjectCount = getTotalDataSetObjectCount(trainingData);

    // since this is real distributed training, we don't need to split data
    doIteration(network, trainingData, 1, 1);

    if (collectTrainingStats)
        stats.logFitEnd((int) totalDataSetObjectCount);
}
 
Example 9
Source File: HoodieSimpleIndex.java    From hudi with Apache License 2.0
/**
 * Tags records location for incoming records.
 *
 * @param inputRecordRDD {@link JavaRDD} of incoming records
 * @param jsc            instance of {@link JavaSparkContext} to use
 * @param hoodieTable    instance of {@link HoodieTable} to use
 * @return {@link JavaRDD} of records with record locations set
 */
protected JavaRDD<HoodieRecord<T>> tagLocationInternal(JavaRDD<HoodieRecord<T>> inputRecordRDD, JavaSparkContext jsc,
                                                       HoodieTable<T> hoodieTable) {
  if (config.getSimpleIndexUseCaching()) {
    inputRecordRDD.persist(SparkConfigUtils.getSimpleIndexInputStorageLevel(config.getProps()));
  }

  JavaPairRDD<HoodieKey, HoodieRecord<T>> keyedInputRecordRDD = inputRecordRDD.mapToPair(record -> new Tuple2<>(record.getKey(), record));
  JavaPairRDD<HoodieKey, HoodieRecordLocation> existingLocationsOnTable = fetchRecordLocationsForAffectedPartitions(keyedInputRecordRDD.keys(), jsc, hoodieTable,
      config.getSimpleIndexParallelism());

  JavaRDD<HoodieRecord<T>> taggedRecordRDD = keyedInputRecordRDD.leftOuterJoin(existingLocationsOnTable)
      .map(entry -> {
        final HoodieRecord<T> untaggedRecord = entry._2._1;
        final Option<HoodieRecordLocation> location = Option.ofNullable(entry._2._2.orNull());
        return HoodieIndexUtils.getTaggedRecord(untaggedRecord, location);
      });

  if (config.getSimpleIndexUseCaching()) {
    inputRecordRDD.unpersist();
  }
  return taggedRecordRDD;
}
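Several of the Hudi examples (9, 10, 11, 13, and 14) resolve the storage level from configuration (SparkConfigUtils.getSimpleIndexInputStorageLevel, getWriteStatusStorageLevel, and so on) rather than hard-coding it. A rough, stand-alone sketch of that idea, assuming a hypothetical property name and helper class, could look like this:

import java.util.Properties;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.storage.StorageLevel;

// Hypothetical helper: resolve the storage level from a property, falling back to a default.
final class ConfigurablePersist {
    static <T> JavaRDD<T> persistFromConfig(JavaRDD<T> rdd, Properties props) {
        String level = props.getProperty("example.input.storage.level", "MEMORY_AND_DISK_SER");
        return rdd.persist(StorageLevel.fromString(level));
    }
}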
 
Example 10
Source File: RunCompactionActionExecutor.java    From hudi with Apache License 2.0
@Override
public HoodieWriteMetadata execute() {
  HoodieInstant instant = HoodieTimeline.getCompactionRequestedInstant(instantTime);
  HoodieTimeline pendingCompactionTimeline = table.getActiveTimeline().filterPendingCompactionTimeline();
  if (!pendingCompactionTimeline.containsInstant(instant)) {
    throw new IllegalStateException(
        "No Compaction request available at " + instantTime + " to run compaction");
  }

  HoodieWriteMetadata compactionMetadata = new HoodieWriteMetadata();
  try {
    HoodieActiveTimeline timeline = table.getActiveTimeline();
    HoodieCompactionPlan compactionPlan =
        CompactionUtils.getCompactionPlan(table.getMetaClient(), instantTime);
    // Mark instant as compaction inflight
    timeline.transitionCompactionRequestedToInflight(instant);
    table.getMetaClient().reloadActiveTimeline();

    HoodieMergeOnReadTableCompactor compactor = new HoodieMergeOnReadTableCompactor();
    JavaRDD<WriteStatus> statuses = compactor.compact(jsc, compactionPlan, table, config, instantTime);

    statuses.persist(SparkConfigUtils.getWriteStatusStorageLevel(config.getProps()));
    List<HoodieWriteStat> updateStatusMap = statuses.map(WriteStatus::getStat).collect();
    HoodieCommitMetadata metadata = new HoodieCommitMetadata(true);
    for (HoodieWriteStat stat : updateStatusMap) {
      metadata.addWriteStat(stat.getPartitionPath(), stat);
    }
    metadata.addMetadata(HoodieCommitMetadata.SCHEMA_KEY, config.getSchema());

    compactionMetadata.setWriteStatuses(statuses);
    compactionMetadata.setCommitted(false);
    compactionMetadata.setCommitMetadata(Option.of(metadata));
  } catch (IOException e) {
    throw new HoodieCompactionException("Could not compact " + config.getBasePath(), e);
  }

  return compactionMetadata;
}
 
Example 11
Source File: BaseCommitActionExecutor.java    From hudi with Apache License 2.0
protected void updateIndexAndCommitIfNeeded(JavaRDD<WriteStatus> writeStatusRDD, HoodieWriteMetadata result) {
  // cache writeStatusRDD before updating index, so that all actions before this are not triggered again for future
  // RDD actions that are performed after updating the index.
  writeStatusRDD = writeStatusRDD.persist(SparkConfigUtils.getWriteStatusStorageLevel(config.getProps()));
  Instant indexStartTime = Instant.now();
  // Update the index back
  JavaRDD<WriteStatus> statuses = ((HoodieTable<T>)table).getIndex().updateLocation(writeStatusRDD, jsc,
      (HoodieTable<T>)table);
  result.setIndexUpdateDuration(Duration.between(indexStartTime, Instant.now()));
  result.setWriteStatuses(statuses);
  commitOnAutoCommit(result);
}
 
Example 12
Source File: BaseCommitActionExecutor.java    From hudi with Apache License 2.0
public HoodieWriteMetadata execute(JavaRDD<HoodieRecord<T>> inputRecordsRDD) {
  HoodieWriteMetadata result = new HoodieWriteMetadata();
  // Cache the tagged records, so we don't end up computing both
  // TODO: Consistent contract in HoodieWriteClient regarding preppedRecord storage level handling
  if (inputRecordsRDD.getStorageLevel() == StorageLevel.NONE()) {
    inputRecordsRDD.persist(StorageLevel.MEMORY_AND_DISK_SER());
  } else {
    LOG.info("RDD PreppedRecords was persisted at: " + inputRecordsRDD.getStorageLevel());
  }

  WorkloadProfile profile = null;
  if (isWorkloadProfileNeeded()) {
    profile = new WorkloadProfile(inputRecordsRDD);
    LOG.info("Workload profile :" + profile);
    saveWorkloadProfileMetadataToInflight(profile, instantTime);
  }

  // partition using the insert partitioner
  final Partitioner partitioner = getPartitioner(profile);
  JavaRDD<HoodieRecord<T>> partitionedRecords = partition(inputRecordsRDD, partitioner);
  JavaRDD<WriteStatus> writeStatusRDD = partitionedRecords.mapPartitionsWithIndex((partition, recordItr) -> {
    if (WriteOperationType.isChangingRecords(operationType)) {
      return handleUpsertPartition(instantTime, partition, recordItr, partitioner);
    } else {
      return handleInsertPartition(instantTime, partition, recordItr, partitioner);
    }
  }, true).flatMap(List::iterator);

  updateIndexAndCommitIfNeeded(writeStatusRDD, result);
  return result;
}
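The storage-level check at the top of Example 12 persists the input only when rdd.getStorageLevel() is StorageLevel.NONE(), so a caller that has already persisted the RDD keeps its chosen level. A small stand-alone sketch of that guard (class and method names here are made up for illustration):

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.storage.StorageLevel;

// Hypothetical helper mirroring the guard in Example 12: persist only if nothing is cached yet.
final class PersistGuard {
    static <T> JavaRDD<T> persistIfNeeded(JavaRDD<T> rdd, StorageLevel level) {
        if (rdd.getStorageLevel() == StorageLevel.NONE()) {
            return rdd.persist(level);
        }
        return rdd; // already persisted; keep the caller's storage level
    }
}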
 
Example 13
Source File: HoodieBloomIndex.java    From hudi with Apache License 2.0
@Override
public JavaRDD<HoodieRecord<T>> tagLocation(JavaRDD<HoodieRecord<T>> recordRDD, JavaSparkContext jsc,
                                            HoodieTable<T> hoodieTable) {

  // Step 0: cache the input record RDD
  if (config.getBloomIndexUseCaching()) {
    recordRDD.persist(SparkConfigUtils.getBloomIndexInputStorageLevel(config.getProps()));
  }

  // Step 1: Extract out thinner JavaPairRDD of (partitionPath, recordKey)
  JavaPairRDD<String, String> partitionRecordKeyPairRDD =
      recordRDD.mapToPair(record -> new Tuple2<>(record.getPartitionPath(), record.getRecordKey()));

  // Lookup indexes for all the partition/recordkey pair
  JavaPairRDD<HoodieKey, HoodieRecordLocation> keyFilenamePairRDD =
      lookupIndex(partitionRecordKeyPairRDD, jsc, hoodieTable);

  // Cache the result, for subsequent stages.
  if (config.getBloomIndexUseCaching()) {
    keyFilenamePairRDD.persist(StorageLevel.MEMORY_AND_DISK_SER());
  }
  if (LOG.isDebugEnabled()) {
    long totalTaggedRecords = keyFilenamePairRDD.count();
    LOG.debug("Number of update records (ones tagged with a fileID): " + totalTaggedRecords);
  }

  // Step 4: Tag the incoming records, as inserts or updates, by joining with existing record keys
  // Cost: 4 sec.
  JavaRDD<HoodieRecord<T>> taggedRecordRDD = tagLocationBacktoRecords(keyFilenamePairRDD, recordRDD);

  if (config.getBloomIndexUseCaching()) {
    recordRDD.unpersist(); // unpersist the input Record RDD
    keyFilenamePairRDD.unpersist();
  }
  return taggedRecordRDD;
}
 
Example 14
Source File: HBaseIndex.java    From hudi with Apache License 2.0
@Override
public JavaRDD<WriteStatus> updateLocation(JavaRDD<WriteStatus> writeStatusRDD, JavaSparkContext jsc,
    HoodieTable<T> hoodieTable) {
  final HBaseIndexQPSResourceAllocator hBaseIndexQPSResourceAllocator = createQPSResourceAllocator(this.config);
  setPutBatchSize(writeStatusRDD, hBaseIndexQPSResourceAllocator, jsc);
  LOG.info("multiPutBatchSize before hbase puts: " + multiPutBatchSize);
  JavaRDD<WriteStatus> writeStatusJavaRDD = writeStatusRDD.mapPartitionsWithIndex(updateLocationFunction(), true);
  // caching the index updated status RDD
  writeStatusJavaRDD = writeStatusJavaRDD.persist(SparkConfigUtils.getWriteStatusStorageLevel(config.getProps()));
  return writeStatusJavaRDD;
}
 
Example 15
Source File: MpBoostLearner.java    From sparkboost with Apache License 2.0
/**
 * Build a new classifier by analyzing the training data available in the
 * specified documents set.
 *
 * @param docs The set of documents used as training data.
 * @return A new MP-Boost classifier.
 */
public BoostClassifier buildModel(JavaRDD<MultilabelPoint> docs) {
    if (docs == null)
        throw new NullPointerException("The set of training documents is 'null'");

    // Repartition documents.
    Logging.l().info("Load initial data and generating internal data representations...");
    docs = docs.repartition(getNumDocumentsPartitions());
    docs = docs.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Docs: num partitions " + docs.partitions().size());

    int numDocs = DataUtils.getNumDocuments(docs);
    int numLabels = DataUtils.getNumLabels(docs);
    JavaRDD<DataUtils.LabelDocuments> labelDocuments = DataUtils.getLabelDocuments(docs);

    // Repartition labels.
    labelDocuments = labelDocuments.repartition(getNumLabelsPartitions());
    labelDocuments = labelDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Labels: num partitions " + labelDocuments.partitions().size());

    // Repartition features.
    JavaRDD<DataUtils.FeatureDocuments> featureDocuments = DataUtils.getFeatureDocuments(docs);
    featureDocuments = featureDocuments.repartition(getNumFeaturesPartitions());
    featureDocuments = featureDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Features: num partitions " + featureDocuments.partitions().size());


    Logging.l().info("Ok, done!");

    WeakHypothesis[] computedWH = new WeakHypothesis[numIterations];
    double[][] localDM = initDistributionMatrix(numLabels, numDocs);
    for (int i = 0; i < numIterations; i++) {

        // Generate new weak hypothesis.
        WeakHypothesis localWH = learnWeakHypothesis(localDM, labelDocuments, featureDocuments);

        // Update distribution matrix with the new hypothesis.
        updateDistributionMatrix(sc, docs, localDM, localWH);

        // Save current generated weak hypothesis.
        computedWH[i] = localWH;

        Logging.l().info("Completed iteration " + (i + 1));
    }

    Logging.l().info("Model built!");

    return new BoostClassifier(computedWH);
}
 
Example 16
Source File: PersistExample.java    From Apache-Spark-2x-for-Java-Developers with MIT License
/**
 * Demonstrates cache(), persist(StorageLevel.MEMORY_AND_DISK()), and unpersist() on simple RDDs.
 *
 * @param args command-line arguments (unused)
 */
public static void main(String[] args) {
    Logger rootLogger = LogManager.getRootLogger();
    rootLogger.setLevel(Level.WARN);

    SparkConf conf = new SparkConf().setMaster("local").setAppName("ActionExamples")
            .set("spark.hadoop.validateOutputSpecs", "false");
    JavaSparkContext sparkContext = new JavaSparkContext(conf);

    // cache() keeps the source RDD at the default MEMORY_ONLY level.
    JavaRDD<Integer> rdd = sparkContext.parallelize(Arrays.asList(1, 2, 3, 4, 5), 3).cache();

    JavaRDD<Integer> evenRDD = rdd.filter(new org.apache.spark.api.java.function.Function<Integer, Boolean>() {
        @Override
        public Boolean call(Integer v1) throws Exception {
            return v1 % 2 == 0;
        }
    });

    // Persist the filtered RDD in memory, spilling to disk if necessary.
    evenRDD.persist(StorageLevel.MEMORY_AND_DISK());
    evenRDD.foreach(new VoidFunction<Integer>() {
        @Override
        public void call(Integer t) throws Exception {
            System.out.println("Value in the RDD: " + t);
        }
    });

    // Unpersist both RDDs once they are no longer needed.
    evenRDD.unpersist();
    rdd.unpersist();
}
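Example 16 uses both caching entry points: cache() on the source RDD and persist(StorageLevel.MEMORY_AND_DISK()) on the filtered RDD. For a JavaRDD, cache() is shorthand for persist(StorageLevel.MEMORY_ONLY()); the short sketch below (not part of the book's project, class name made up) shows the two forms side by side.

import java.util.Arrays;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.storage.StorageLevel;

public class CacheVsPersist {
    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setMaster("local").setAppName("CacheVsPersist"));

        // cache() requests the default MEMORY_ONLY storage level...
        JavaRDD<Integer> cached = sc.parallelize(Arrays.asList(1, 2, 3)).cache();
        // ...which is the same level persist(StorageLevel.MEMORY_ONLY()) requests explicitly.
        JavaRDD<Integer> persisted = sc.parallelize(Arrays.asList(1, 2, 3))
                .persist(StorageLevel.MEMORY_ONLY());

        // Both RDDs report the same storage level.
        System.out.println(cached.getStorageLevel());
        System.out.println(persisted.getStorageLevel());

        sc.stop();
    }
}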
 
Example 17
Source File: PathSeqPipelineSpark.java    From gatk with BSD 3-Clause "New" or "Revised" License
@Override
protected void runTool(final JavaSparkContext ctx) {

    filterArgs.doReadFilterArgumentWarnings(getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class), logger);
    SAMFileHeader header = PSUtils.checkAndClearHeaderSequences(getHeaderForReads(), filterArgs, logger);

    //Do not allow use of numReducers
    if (numReducers > 0) {
        throw new UserException.BadInput("Use --readsPerPartitionOutput instead of --num-reducers.");
    }

    //Filter
    final Tuple2<JavaRDD<GATKRead>, JavaRDD<GATKRead>> filterResult;
    final PSFilter filter = new PSFilter(ctx, filterArgs, header);
    try (final PSFilterLogger filterLogger = filterArgs.filterMetricsFileUri != null ? new PSFilterFileLogger(getMetricsFile(), filterArgs.filterMetricsFileUri) : new PSFilterEmptyLogger()) {
        final JavaRDD<GATKRead> inputReads = getReads();
        filterResult = filter.doFilter(inputReads, filterLogger);
    }
    JavaRDD<GATKRead> pairedReads = filterResult._1;
    JavaRDD<GATKRead> unpairedReads = filterResult._2;

    //Counting forces an action on the RDDs to guarantee we're done with the Bwa image and kmer filter
    final long numPairedReads = pairedReads.count();
    final long numUnpairedReads = unpairedReads.count();
    final long numTotalReads = numPairedReads + numUnpairedReads;

    //Closes Bwa image, kmer filter, and metrics file if used
    //Note the host Bwa image must be unloaded before trying to load the pathogen image
    filter.close();

    //Rebalance partitions using the counts
    final int numPairedPartitions = 1 + (int) (numPairedReads / readsPerPartition);
    final int numUnpairedPartitions = 1 + (int) (numUnpairedReads / readsPerPartition);
    pairedReads = repartitionPairedReads(pairedReads, numPairedPartitions, numPairedReads);
    unpairedReads = unpairedReads.repartition(numUnpairedPartitions);

    //Bwa pathogen alignment
    final PSBwaAlignerSpark aligner = new PSBwaAlignerSpark(ctx, bwaArgs);
    PSBwaUtils.addReferenceSequencesToHeader(header, bwaArgs.microbeDictionary);
    final Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(header);
    JavaRDD<GATKRead> alignedPairedReads = aligner.doBwaAlignment(pairedReads, true, headerBroadcast);
    JavaRDD<GATKRead> alignedUnpairedReads = aligner.doBwaAlignment(unpairedReads, false, headerBroadcast);

    //Cache this expensive result. Note serialization significantly reduces memory consumption.
    alignedPairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());
    alignedUnpairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());

    //Score pathogens
    final PSScorer scorer = new PSScorer(scoreArgs);
    final JavaRDD<GATKRead> readsFinal = scorer.scoreReads(ctx, alignedPairedReads, alignedUnpairedReads, header);

    //Clean up header
    header = PSBwaUtils.removeUnmappedHeaderSequences(header, readsFinal, logger);

    //Log read counts
    if (scoreArgs.scoreMetricsFileUri != null) {
        try (final PSScoreLogger scoreLogger = new PSScoreFileLogger(getMetricsFile(), scoreArgs.scoreMetricsFileUri)) {
            scoreLogger.logReadCounts(readsFinal);
        }
    }

    //Write reads to BAM, if specified
    if (outputPath != null) {
        try {
            //Reduce number of partitions since we previously went to ~5K reads per partition, which
            // is far too small for sharded output.
            final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput));
            final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false);
            ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinalRepartitioned, header,
                    shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir, true, splittingIndexGranularity);
        } catch (final IOException e) {
            throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
        }
    }
    aligner.close();
}
 
Example 18
Source File: AdaBoostMHLearner.java    From sparkboost with Apache License 2.0
/**
 * Build a new classifier by analyzing the training data available in the
 * specified documents set.
 *
 * @param docs The set of documents used as training data.
 * @return A new AdaBoost.MH classifier.
 */
public BoostClassifier buildModel(JavaRDD<MultilabelPoint> docs) {
    if (docs == null)
        throw new NullPointerException("The set of input documents is 'null'");


    // Repartition documents.
    Logging.l().info("Load initial data and generating internal data representations...");
    docs = docs.repartition(getNumDocumentsPartitions());
    docs = docs.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Docs: num partitions " + docs.partitions().size());

    int numDocs = DataUtils.getNumDocuments(docs);
    int numLabels = DataUtils.getNumLabels(docs);
    JavaRDD<DataUtils.LabelDocuments> labelDocuments = DataUtils.getLabelDocuments(docs);

    // Repartition labels.
    labelDocuments = labelDocuments.repartition(getNumLabelsPartitions());
    labelDocuments = labelDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Labels: num partitions " + labelDocuments.partitions().size());

    // Repartition features.
    JavaRDD<DataUtils.FeatureDocuments> featureDocuments = DataUtils.getFeatureDocuments(docs);
    featureDocuments = featureDocuments.repartition(getNumFeaturesPartitions());
    featureDocuments = featureDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Features: num partitions " + featureDocuments.partitions().size());
    Logging.l().info("Ok, done!");

    WeakHypothesis[] computedWH = new WeakHypothesis[numIterations];
    double[][] localDM = initDistributionMatrix(numLabels, numDocs);
    for (int i = 0; i < numIterations; i++) {

        // Generate new weak hypothesis.
        WeakHypothesis localWH = learnWeakHypothesis(localDM, labelDocuments, featureDocuments);

        // Update distribution matrix with the new hypothesis.
        updateDistributionMatrix(sc, docs, localDM, localWH);

        // Save current generated weak hypothesis.
        computedWH[i] = localWH;

        Logging.l().info("Completed iteration " + (i + 1));
    }

    Logging.l().info("Model built!");

    return new BoostClassifier(computedWH);
}