Java Code Examples for org.apache.spark.api.java.JavaRDD#repartition()

The following examples show how to use org.apache.spark.api.java.JavaRDD#repartition(). You can vote up the examples you like or vote down the ones you don't, and follow the links above each example to go to the original project or source file. Related API usage is listed in the sidebar.
Example 1
Source File: TestRepartitioning.java (from deeplearning4j, Apache License 2.0; 6 votes)
@Test
    public void testRepartitioningApprox() {
        // Build the strings "0".."999" and spread them over 200 initial partitions
        final int elementCount = 1000;
        final List<String> values = new ArrayList<>(elementCount);
        int next = 0;
        while (next < elementCount) {
            values.add(String.valueOf(next++));
        }
        final JavaRDD<String> source = sc.parallelize(values).repartition(200);

        // Approximate balancing into 10 partitions must yield a new RDD object
        final JavaRDD<String> balanced =
                SparkUtils.repartitionApproximateBalance(source, Repartition.Always, 10);
        assertFalse(source == balanced);

        assertEquals(10, balanced.partitions().size());

        // 1000 elements over 10 partitions: each must hold between 90 and 110 elements
        for (int p = 0; p < 10; p++) {
            final int size = balanced.collectPartitions(new int[] {p})[0].size();
            assertTrue(size >= 90 && size <= 110);
        }
    }
 
Example 2
Source File: TestRepartitioning.java (from deeplearning4j, Apache License 2.0; 6 votes)
@Test
    public void testRepartitioning() {
        // The strings "0".."999", initially spread across 200 partitions
        final int elementCount = 1000;
        final List<String> values = new ArrayList<>(elementCount);
        for (int v = 0; v < elementCount; v++) {
            values.add(Integer.toString(v));
        }
        final JavaRDD<String> source = sc.parallelize(values).repartition(200);

        final JavaRDD<String> balanced =
                SparkUtils.repartitionBalanceIfRequired(source, Repartition.Always, 100, 10);
        assertFalse(source == balanced); //Should be different objects due to repartitioning

        // Expect 10 partitions, each holding exactly 100 elements
        assertEquals(10, balanced.partitions().size());
        for (int p = 0; p < 10; p++) {
            final List<String> contents = balanced.collectPartitions(new int[] {p})[0];
            //Should be exactly 100, for the util method (but NOT spark .repartition)
            assertEquals(100, contents.size());
        }
    }
 
Example 3
Source File: ReadsSparkSinkUnitTest.java (from gatk, BSD 3-Clause "New" or "Revised" License; 6 votes)
@Test(dataProvider = "loadReadsBAM", groups = "spark")
public void readsSinkShardedTest(String inputBam, String outputFileName, String referenceFile, String outputFileExtension, boolean writeBai, boolean writeSbi, long sbiGranularity) throws IOException {
    final GATKPath input = new GATKPath(inputBam);
    final File output = createTempFile(outputFileName, outputFileExtension);
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final GATKPath reference = (referenceFile != null) ? new GATKPath(referenceFile) : null;

    final ReadsSparkSource source = new ReadsSparkSource(ctx);
    // Force exactly two partitions so that the sharded output is written as two shards
    final JavaRDD<GATKRead> reads = source.getParallelReads(input, reference).repartition(2);
    final SAMFileHeader header = source.getHeader(input, reference);

    ReadsSparkSink.writeReads(ctx, output.getAbsolutePath(), reference, reads, header, ReadsWriteFormat.SHARDED, 0, null, false, sbiGranularity);

    // Shards are the output entries whose names start with neither '.' nor '_'
    final File[] shardFiles = output.listFiles((dir, name) -> !name.startsWith(".") && !name.startsWith("_"));
    Assert.assertEquals(shardFiles.length, 2);
    // No local hidden .crc files may be produced
    final File[] crcFiles = output.listFiles((dir, name) -> name.startsWith(".") && name.endsWith(".crc"));
    Assert.assertEquals(crcFiles.length, 0);

    // Read the output back; reads are not globally sorted, so only the counts are compared
    final JavaRDD<GATKRead> reread = source.getParallelReads(new GATKPath(output.getAbsolutePath()), reference);
    Assert.assertEquals(reads.count(), reread.count());
}
 
Example 4
Source File: VariantsSparkSinkUnitTest.java (from gatk, BSD 3-Clause "New" or "Revised" License; 6 votes)
private void assertSingleShardedWritingWorks(String vcf, String outputPath, boolean writeTabixIndex) throws IOException {
    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final VariantsSparkSource source = new VariantsSparkSource(ctx);

    JavaRDD<VariantContext> input = source.getParallelVariantContexts(vcf, null);
    // Make sure the input spans more than one partition before writing a single output file
    input = (input.getNumPartitions() == 1) ? input.repartition(3) : input;
    final VCFHeader vcfHeader = getHeader(vcf);

    VariantsSparkSink.writeVariants(ctx, outputPath, input, vcfHeader, writeTabixIndex);
    checkFileExtensionConsistentWithContents(outputPath, writeTabixIndex);

    // Round-trip: what was written must equal what was read from the original VCF
    final List<VariantContext> written = source.getParallelVariantContexts(outputPath, null).collect();
    VariantContextTestUtils.assertEqualVariants(readVariants(vcf), written);
}
 
Example 5
Source File: SparkTrainWorker.java (from ytk-learn, MIT License; 5 votes)
public boolean sparkTrain(JavaRDD<String> rdd) {
    // One partition per slave; trainFunc runs on each partition (preservesPartitioning = true)
    // and yields per-partition success flags, which are gathered on the driver.
    final List<Boolean> outcomes = rdd.repartition(slaveNum)
            .mapPartitionsWithIndex(trainFunc, true)
            .collect();
    // Training succeeds only if no partition reported failure (stops at the first false)
    return outcomes.stream().allMatch(Boolean::booleanValue);
}
 
Example 6
Source File: SqoopSparkDriver.java (from sqoop-on-spark, Apache License 2.0; 5 votes)
/**
 * Runs a Sqoop extract/load job on Spark.
 *
 * <p>Extraction runs with one Spark partition per {@link Partition} returned by
 * {@code getPartitions}. Loading runs with {@code NUM_LOADERS} partitions, and the extracted
 * data is repartitioned only when that differs from the actual extract partition count.
 *
 * @param request the job request (connector configuration and contexts)
 * @param conf    Spark configuration carrying the extractor/loader settings
 * @param sc      the Spark context used to run the job
 * @throws Exception if partitioning, extraction or loading fails
 */
public static void execute(JobRequest request, SparkConf conf, JavaSparkContext sc)
    throws Exception {

  LOG.info("Executing sqoop spark job");

  final long startTimeMillis = System.currentTimeMillis();
  SparkPrefixContext driverContext = new SparkPrefixContext(request.getConf(),
      JobConstants.PREFIX_CONNECTOR_DRIVER_CONTEXT);

  int defaultExtractors = conf.getInt(DEFAULT_EXTRACTORS, 10);
  long numExtractors = driverContext.getLong(JobConstants.JOB_ETL_EXTRACTOR_NUM,
      defaultExtractors);
  int numLoaders = conf.getInt(NUM_LOADERS, 1);

  List<Partition> sp = getPartitions(request, numExtractors);
  // numExtractors is only passed to getPartitions as a request; nothing here guarantees the
  // returned list has that size. The extract stage runs with one Spark partition per returned
  // Partition, so the actual extract parallelism is sp.size().
  final int numExtractPartitions = sp.size();
  System.out.println(">>> Partition size:" + numExtractPartitions);

  JavaRDD<Partition> rdd = sc.parallelize(sp, numExtractPartitions);
  JavaRDD<List<IntermediateDataFormat<?>>> mapRDD = rdd.map(new SqoopExtractFunction(
      request));
  // Repartition (a full shuffle) only when the requested loader parallelism differs from the
  // actual number of extract partitions; otherwise load straight from the mapped RDD.
  // Comparing against numExtractors instead would skip a needed repartition (or do a useless
  // one) whenever getPartitions returns a different count than requested.
  if (numLoaders != numExtractPartitions) {
    JavaRDD<List<IntermediateDataFormat<?>>> repartitionedRDD = mapRDD.repartition(numLoaders);
    System.out.println(">>> RePartition RDD size:" + repartitionedRDD.partitions().size());
    repartitionedRDD.mapPartitions(new SqoopLoadFunction(request)).collect();
  } else {
    System.out.println(">>> Mapped RDD size:" + mapRDD.partitions().size());
    mapRDD.mapPartitions(new SqoopLoadFunction(request)).collect();
  }

  System.out.println(">>> TOTAL time ms:" + (System.currentTimeMillis() - startTimeMillis));

  LOG.info("Done EL in sqoop spark job, next call destroy apis");

}
 
Example 7
Source File: MLLibUtil.java (from deeplearning4j, Apache License 2.0; 5 votes)
/**
 * Convert an RDD of labeled points into an RDD of data sets, repartitioned so that
 * each partition holds approximately {@code batchSize} elements.
 *
 * @param data              the data to convert
 * @param numPossibleLabels the number of possible labels
 * @param batchSize         the target number of elements per partition; must be positive
 * @return the new rdd, with at least one partition
 * @throws IllegalArgumentException if {@code batchSize} is not positive
 */
public static JavaRDD<DataSet> fromLabeledPoint(JavaRDD<LabeledPoint> data, final long numPossibleLabels,
                long batchSize) {
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
    }

    JavaRDD<DataSet> mappedData = data.map(new Function<LabeledPoint, DataSet>() {
        @Override
        public DataSet call(LabeledPoint lp) {
            return fromLabeledPoint(lp, numPossibleLabels);
        }
    });

    // map() is one-to-one, so counting the input gives the same number without running the
    // conversion just to count it.
    long numPartitions = data.count() / batchSize;
    // Clamp to [1, Integer.MAX_VALUE]: with fewer elements than batchSize the division yields 0,
    // which repartition rejects, and a very large count would overflow the int cast.
    numPartitions = Math.max(1L, Math.min(numPartitions, Integer.MAX_VALUE));
    return mappedData.repartition((int) numPartitions);
}
 
Example 8
Source File: MiniBatchTests.java (from deeplearning4j, Apache License 2.0; 5 votes)
@Test
public void testMiniBatches() throws Exception {
    log.info("Setting up Spark Context...");
    JavaRDD<String> lines = sc.textFile(new ClassPathResource("svmLight/iris_svmLight_0.txt")
                    .getTempFileFromArchive().toURI().toString()).cache();
    long count = lines.count();
    assertEquals(300, count);
    // gotta map this to a Matrix/INDArray
    RecordReader rr = new SVMLightRecordReader();
    Configuration c = new Configuration();
    c.set(SVMLightRecordReader.NUM_FEATURES, "5");
    rr.setConf(c);
    JavaRDD<DataSet> points = lines.map(new RecordReaderFunction(rr, 4, 3)).cache();
    count = points.count();
    assertEquals(300, count);

    // Keep a separate reference to the single-partition view so that the cached `points`
    // RDD (not the uncached repartitioned one) is what gets unpersisted below.
    JavaRDD<DataSet> singlePartition = points.repartition(1);
    JavaRDD<DataSet> miniBatches = new RDDMiniBatches(10, singlePartition).miniBatchesJava();
    count = miniBatches.count();
    assertEquals(30, count);    //Expect exactly 30 from 1 partition... could be more for multiple input partitions

    // map() is lazy: without an action the per-minibatch assertion function would never run.
    miniBatches.map(new DataSetAssertionFunction()).count();

    lines.unpersist();
    points.unpersist();
}
 
Example 9
Source File: TestSparkComputationGraph.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testEvaluationAndRocMDS() {
    for (int evalWorkers : new int[] {1, 4, 8}) {

        DataSetIterator iter = new IrisDataSetIterator(5, 150);

        //Make a 2-class version of iris:
        List<MultiDataSet> l = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
            newL.putColumn(0, ds.getLabels().getColumn(0));
            newL.putColumn(1, ds.getLabels().getColumn(1));
            newL.getColumn(1).addi(ds.getLabels().getColumn(2));

            MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(ds.getFeatures(), newL);
            l.add(mds);
        }

        MultiDataSetIterator mdsIter = new IteratorMultiDataSetIterator(l.iterator(), 5);

        ComputationGraph cg = getBasicNetIris2Class();

        // Reference results: local evaluation on the ComputationGraph itself
        IEvaluation[] es = cg.doEvaluation(mdsIter, new Evaluation(), new ROC(32));
        Evaluation e = (Evaluation) es[0];
        ROC roc = (ROC) es[1];

        // Same data evaluated through Spark, spread over 20 partitions
        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<MultiDataSet> rdd = sc.parallelize(l);
        rdd = rdd.repartition(20);

        IEvaluation[] es2 = scg.doEvaluationMDS(rdd, 5, new Evaluation(), new ROC(32));
        Evaluation e2 = (Evaluation) es2[0];
        ROC roc2 = (ROC) es2[1];

        // JUnit's assertEquals takes (expected, actual): the local result is the expected value
        assertEquals(e.accuracy(), e2.accuracy(), 1e-3);
        assertEquals(e.f1(), e2.f1(), 1e-3);
        assertEquals(e.getNumRowCounter(), e2.getNumRowCounter(), 1e-3);
        assertEquals(e.falseNegatives(), e2.falseNegatives());
        assertEquals(e.falsePositives(), e2.falsePositives());
        assertEquals(e.trueNegatives(), e2.trueNegatives());
        assertEquals(e.truePositives(), e2.truePositives());
        assertEquals(e.precision(), e2.precision(), 1e-3);
        assertEquals(e.recall(), e2.recall(), 1e-3);
        assertEquals(e.getConfusionMatrix(), e2.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
 
Example 10
Source File: ElephasQuadParser.java (from rdf2x, Apache License 2.0; 4 votes)
@Override
public JavaRDD<Quad> parseQuads(String path) {

    final Configuration conf = new Configuration();

    final Integer batchSize = config.getBatchSize();
    conf.set(NLineInputFormat.LINES_PER_MAP, batchSize.toString());

    // Bad tuples are ignored ("true") unless error handling is configured to Throw ("false")
    final boolean ignoreBadTuples = config.getErrorHandling() != ParseErrorHandling.Throw;
    conf.set(RdfIOConstants.INPUT_IGNORE_BAD_TUPLES, String.valueOf(ignoreBadTuples));

    // Fall back to guessing from the path when the line-based flag is not configured
    Boolean lineBased = config.getLineBasedFormat();
    if (lineBased == null) {
        lineBased = guessIsLineBasedFormat(path);
    }

    JavaRDD<Quad> quads;
    final Integer repartitionCount = config.getRepartition();
    if (!lineBased) {
        // let Jena guess the format, load whole files
        log.info("Input format is not line based, parsing RDF by Master node only.");
        quads = sc.newAPIHadoopFile(path, TriplesOrQuadsInputFormat.class,
                        LongWritable.class /* position */, QuadWritable.class /* value */, conf)
                .values()
                .map(QuadWritable::get);

        if (repartitionCount == null) {
            log.warn("Reading non-line based formats by master node only, consider setting --parsing.repartition to redistribute work to other nodes.");
        }
    } else {
        log.info("Parsing RDF in parallel with batch size: {}", batchSize);
        quads = sc.newAPIHadoopFile(path, NQuadsInputFormat.class,
                        LongWritable.class /* position */, QuadWritable.class /* value */, conf)
                .values()
                .map(QuadWritable::get);
    }

    if (repartitionCount != null) {
        log.info("Distributing workload, repartitioning into {} partitions", repartitionCount);
        quads = quads.repartition(repartitionCount);
    }

    final List<String> acceptedLanguages = config.getAcceptedLanguages();
    // An empty accepted-language list means every language is accepted
    if (acceptedLanguages.isEmpty()) {
        return quads;
    }

    // Keep non-literals, literals without a language tag, and literals in an accepted language.
    // The language is only read after the isLiteral() check, as in a short-circuit chain.
    return quads.filter(quad -> {
        if (!quad.getObject().isLiteral()) {
            return true;
        }
        final String language = quad.getObject().getLiteralLanguage();
        return language == null || language.isEmpty() || acceptedLanguages.contains(language);
    });
}
 
Example 11
Source File: TestSparkComputationGraph.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test(timeout = 60000L)
public void testEvaluationAndRoc() {
    for (int evalWorkers : new int[] {1, 4, 8}) {
        DataSetIterator iter = new IrisDataSetIterator(5, 150);

        //Make a 2-class version of iris:
        List<DataSet> l = new ArrayList<>();
        iter.reset();
        while (iter.hasNext()) {
            DataSet ds = iter.next();
            INDArray newL = Nd4j.create(ds.getLabels().size(0), 2);
            newL.putColumn(0, ds.getLabels().getColumn(0));
            newL.putColumn(1, ds.getLabels().getColumn(1));
            newL.getColumn(1).addi(ds.getLabels().getColumn(2));
            ds.setLabels(newL);
            l.add(ds);
        }

        iter = new ListDataSetIterator<>(l);

        ComputationGraph cg = getBasicNetIris2Class();

        // Reference results: local evaluation on the ComputationGraph itself
        Evaluation e = cg.evaluate(iter);
        ROC roc = cg.evaluateROC(iter, 32);

        // Same data evaluated through Spark, spread over 20 partitions
        SparkComputationGraph scg = new SparkComputationGraph(sc, cg, null);
        scg.setDefaultEvaluationWorkers(evalWorkers);

        JavaRDD<DataSet> rdd = sc.parallelize(l);
        rdd = rdd.repartition(20);

        Evaluation e2 = scg.evaluate(rdd);
        ROC roc2 = scg.evaluateROC(rdd);

        // JUnit's assertEquals takes (expected, actual): the local result is the expected value
        assertEquals(e.accuracy(), e2.accuracy(), 1e-3);
        assertEquals(e.f1(), e2.f1(), 1e-3);
        assertEquals(e.getNumRowCounter(), e2.getNumRowCounter(), 1e-3);
        assertEquals(e.falseNegatives(), e2.falseNegatives());
        assertEquals(e.falsePositives(), e2.falsePositives());
        assertEquals(e.trueNegatives(), e2.trueNegatives());
        assertEquals(e.truePositives(), e2.truePositives());
        assertEquals(e.precision(), e2.precision(), 1e-3);
        assertEquals(e.recall(), e2.recall(), 1e-3);
        assertEquals(e.getConfusionMatrix(), e2.getConfusionMatrix());

        assertEquals(roc.calculateAUC(), roc2.calculateAUC(), 1e-5);
        assertEquals(roc.calculateAUCPR(), roc2.calculateAUCPR(), 1e-5);
    }
}
 
Example 12
Source File: TestRepartitioning.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
public void testRepartitioning2() throws Exception {

    int[] ns;
    if (isIntegrationTests()) {
        ns = new int[]{320, 321, 25600, 25601, 25615};
    } else {
        ns = new int[]{320, 2561};
    }

    for (int n : ns) {

        List<String> list = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            list.add(String.valueOf(i));
        }

        // NOTE(review): this test used to call `rdd.repartition(65);` and discard the result.
        // RDDs are immutable, so that call was a no-op, and the input has always kept
        // sc.parallelize's default partitioning. The dead call was removed rather than assigned
        // because changing the input partitioning may change how balancedRandomSplit sizes the
        // splits. Verify the expected count below before repartitioning the input here.
        JavaRDD<String> rdd = sc.parallelize(list);

        int totalDataSetObjectCount = n;
        int dataSetObjectsPerSplit = 8 * 4 * 10;
        int valuesPerPartition = 10;
        int nPartitions = 32;

        JavaRDD<String>[] splits = org.deeplearning4j.spark.util.SparkUtils.balancedRandomSplit(
                        totalDataSetObjectCount, dataSetObjectsPerSplit, rdd, new Random().nextLong());

        // Repartition every split with the Balanced strategy and collect, per partition,
        // a tuple whose second element is that partition's size
        List<List<Tuple2<Integer, Integer>>> partitionCountList = new ArrayList<>();
        for (JavaRDD<String> split : splits) {
            JavaRDD<String> repartitioned = SparkUtils.repartition(split, Repartition.Always,
                            RepartitionStrategy.Balanced, valuesPerPartition, nPartitions);
            List<Tuple2<Integer, Integer>> partitionCounts = repartitioned
                            .mapPartitionsWithIndex(new CountPartitionsFunction<String>(), true).collect();
            partitionCountList.add(partitionCounts);
        }

        // Every split must have exactly nPartitions partitions. Across all splits, the number of
        // partitions holding more than valuesPerPartition elements must equal n % nPartitions.
        int expNumPartitionsWithMore = totalDataSetObjectCount % nPartitions;
        int actNumPartitionsWithMore = 0;
        for (List<Tuple2<Integer, Integer>> l : partitionCountList) {
            assertEquals(nPartitions, l.size());

            for (Tuple2<Integer, Integer> t2 : l) {
                int partitionSize = t2._2();
                if (partitionSize > valuesPerPartition) {
                    actNumPartitionsWithMore++;
                }
            }
        }

        assertEquals(expNumPartitionsWithMore, actNumPartitionsWithMore);
    }
}
 
Example 13
Source File: TestExport.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
    public void testBatchAndExportMultiDataSetsFunction() throws Exception {
        String baseDir = System.getProperty("java.io.tmpdir");
        baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExportMDS/");
        baseDir = baseDir.replaceAll("\\\\", "/");
        File f = new File(baseDir);
        if (f.exists()) {
            FileUtils.deleteDirectory(f);
        }
        // Fail fast with a clear message instead of failing later on a missing directory
        if (!f.mkdir()) {
            throw new IllegalStateException("Could not create output directory: " + f);
        }
        f.deleteOnExit();
        try {
            int minibatchSize = 5;
            int nIn = 4;
            int nOut = 3;

            // 1 x 10 + 49 x 5 + 49 x (1 + 1 + 3) = 500 examples -> 100 minibatches of 5
            List<MultiDataSet> dataSets = new ArrayList<>();
            dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut))); //Larger than minibatch size -> tests splitting
            for (int i = 0; i < 98; i++) {
                if (i % 2 == 0) {
                    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
                } else {
                    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                    dataSets.add(new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
                }
            }

            Collections.shuffle(dataSets, new Random(12345));

            JavaRDD<MultiDataSet> rdd = sc.parallelize(dataSets);
            rdd = rdd.repartition(1); //For testing purposes (should get exactly 100 out, but maybe more with more partitions)

            JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(
                            new BatchAndExportMultiDataSetsFunction(minibatchSize, "file:///" + baseDir), true);

            List<String> paths = pathsRdd.collect();
            assertEquals(100, paths.size());

            File[] files = f.listFiles();
            assertNotNull(files);

            // Every exported .bin file must hold exactly one full minibatch
            int count = 0;
            for (File file : files) {
                if (!file.getPath().endsWith(".bin")) {
                    continue;
                }
                MultiDataSet ds = new org.nd4j.linalg.dataset.MultiDataSet();
                ds.load(file);
                assertEquals(minibatchSize, ds.getFeatures(0).size(0));
                assertEquals(minibatchSize, ds.getLabels(0).size(0));

                count++;
            }

            assertEquals(100, count);
        } finally {
            // Clean up even when an assertion fails. deleteQuietly never throws, so it cannot
            // mask the original failure.
            FileUtils.deleteQuietly(f);
        }
    }
 
Example 14
Source File: TestExport.java (from deeplearning4j, Apache License 2.0; 4 votes)
@Test
    public void testBatchAndExportDataSetsFunction() throws Exception {
        String baseDir = System.getProperty("java.io.tmpdir");
        baseDir = FilenameUtils.concat(baseDir, "dl4j_spark_testBatchAndExport/");
        baseDir = baseDir.replaceAll("\\\\", "/");
        File f = new File(baseDir);
        if (f.exists()) {
            FileUtils.deleteDirectory(f);
        }
        // Fail fast with a clear message instead of failing later on a missing directory
        if (!f.mkdir()) {
            throw new IllegalStateException("Could not create output directory: " + f);
        }
        f.deleteOnExit();
        try {
            int minibatchSize = 5;
            int nIn = 4;
            int nOut = 3;

            // 1 x 10 + 49 x 5 + 49 x (1 + 1 + 3) = 500 examples -> 100 minibatches of 5
            List<DataSet> dataSets = new ArrayList<>();
            dataSets.add(new DataSet(Nd4j.create(10, nIn), Nd4j.create(10, nOut))); //Larger than minibatch size -> tests splitting
            for (int i = 0; i < 98; i++) {
                if (i % 2 == 0) {
                    dataSets.add(new DataSet(Nd4j.create(5, nIn), Nd4j.create(5, nOut)));
                } else {
                    dataSets.add(new DataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                    dataSets.add(new DataSet(Nd4j.create(1, nIn), Nd4j.create(1, nOut)));
                    dataSets.add(new DataSet(Nd4j.create(3, nIn), Nd4j.create(3, nOut)));
                }
            }

            Collections.shuffle(dataSets, new Random(12345));

            JavaRDD<DataSet> rdd = sc.parallelize(dataSets);
            rdd = rdd.repartition(1); //For testing purposes (should get exactly 100 out, but maybe more with more partitions)

            JavaRDD<String> pathsRdd = rdd.mapPartitionsWithIndex(
                            new BatchAndExportDataSetsFunction(minibatchSize, "file:///" + baseDir), true);

            List<String> paths = pathsRdd.collect();
            assertEquals(100, paths.size());

            File[] files = f.listFiles();
            assertNotNull(files);

            // Every exported .bin file must hold exactly one full minibatch
            int count = 0;
            for (File file : files) {
                if (!file.getPath().endsWith(".bin")) {
                    continue;
                }
                DataSet ds = new DataSet();
                ds.load(file);
                assertEquals(minibatchSize, ds.numExamples());

                count++;
            }

            assertEquals(100, count);
        } finally {
            // Clean up even when an assertion fails. deleteQuietly never throws, so it cannot
            // mask the original failure.
            FileUtils.deleteQuietly(f);
        }
    }
 
Example 15
Source File: InsertSizeMetricsCollectorSparkUnitTest.java (from gatk, BSD 3-Clause "New" or "Revised" License; 4 votes)
@Test(dataProvider="metricsfiles", groups={"R", "spark"})
public void test(
        final String fileName,
        final String referenceName,
        final boolean allLevels,
        final String expectedResultsFile) throws IOException {

    final GATKPath inputPathSpecifier = new GATKPath(new File(TEST_DATA_DIR, fileName).getAbsolutePath());
    final GATKPath referencePath = referenceName != null ?
            new GATKPath(referenceName) :
            null;

    final File outfile = GATKBaseTest.createTempFile("test", ".insert_size_metrics");

    JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    ReadsSparkSource readSource = new ReadsSparkSource(ctx, ValidationStringency.DEFAULT_STRINGENCY);

    SAMFileHeader samHeader = readSource.getHeader(inputPathSpecifier, referencePath);
    JavaRDD<GATKRead> rddParallelReads = readSource.getParallelReads(inputPathSpecifier, referencePath);

    InsertSizeMetricsArgumentCollection isArgs = new InsertSizeMetricsArgumentCollection();
    isArgs.output = outfile.getAbsolutePath();
    if (allLevels) {
        isArgs.metricAccumulationLevel.accumulationLevels = new HashSet<>();
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.ALL_READS);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.SAMPLE);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.LIBRARY);
        isArgs.metricAccumulationLevel.accumulationLevels.add(MetricAccumulationLevel.READ_GROUP);
    }

    InsertSizeMetricsCollectorSpark isSpark = new InsertSizeMetricsCollectorSpark();
    isSpark.initialize(isArgs, samHeader, null);

    // Since we're bypassing the framework in order to force this test to run on multiple
    // partitions, there is no plugin descriptor to build the read filter for us, so it is built
    // manually from the collector's default read filters.
    // NOTE(review): an earlier version tried to drop FirstOfPairReadFilter with
    // readFilters.stream().filter(...) but never consumed the stream (no terminal operation).
    // That had no effect, and no SECOND_IN_PAIR filter was ever added either. The default
    // filters are therefore applied unchanged, as they always were. If the filter swap is
    // really required, implement it explicitly and re-check the expected metrics files.
    final ReadFilter rf = ReadFilter.fromList(isSpark.getDefaultReadFilters(), samHeader);

    // Force the input RDD to be split into two partitions to ensure that the
    // reduce/combiners run
    rddParallelReads = rddParallelReads.repartition(2);
    isSpark.collectMetrics(rddParallelReads.filter(r -> rf.test(r)), samHeader);

    isSpark.saveMetrics(fileName);

    IntegrationTestSpec.assertEqualTextFiles(
            outfile,
            new File(TEST_DATA_DIR, expectedResultsFile),
            "#"
    );
}
 
Example 16
Source File: PathSeqPipelineSpark.java (from gatk, BSD 3-Clause "New" or "Revised" License; 4 votes)
@Override
protected void runTool(final JavaSparkContext ctx) {

    filterArgs.doReadFilterArgumentWarnings(getCommandLineParser().getPluginDescriptor(GATKReadFilterPluginDescriptor.class), logger);
    SAMFileHeader header = PSUtils.checkAndClearHeaderSequences(getHeaderForReads(), filterArgs, logger);

    //Do not allow use of numReducers
    if (numReducers > 0) {
        throw new UserException.BadInput("Use --readsPerPartitionOutput instead of --num-reducers.");
    }

    //Filter. The filter is closed in a finally block so that it is released even when
    // filtering or counting throws.
    JavaRDD<GATKRead> pairedReads;
    JavaRDD<GATKRead> unpairedReads;
    final long numPairedReads;
    final long numUnpairedReads;
    final PSFilter filter = new PSFilter(ctx, filterArgs, header);
    try {
        final Tuple2<JavaRDD<GATKRead>, JavaRDD<GATKRead>> filterResult;
        try (final PSFilterLogger filterLogger = filterArgs.filterMetricsFileUri != null ? new PSFilterFileLogger(getMetricsFile(), filterArgs.filterMetricsFileUri) : new PSFilterEmptyLogger()) {
            final JavaRDD<GATKRead> inputReads = getReads();
            filterResult = filter.doFilter(inputReads, filterLogger);
        }
        pairedReads = filterResult._1;
        unpairedReads = filterResult._2;

        //Counting forces an action on the RDDs to guarantee we're done with the Bwa image and kmer filter
        numPairedReads = pairedReads.count();
        numUnpairedReads = unpairedReads.count();
    } finally {
        //Closes Bwa image, kmer filter, and metrics file if used
        //Note the host Bwa image must be unloaded before trying to load the pathogen image
        filter.close();
    }
    final long numTotalReads = numPairedReads + numUnpairedReads;

    //Rebalance partitions using the counts
    final int numPairedPartitions = 1 + (int) (numPairedReads / readsPerPartition);
    final int numUnpairedPartitions = 1 + (int) (numUnpairedReads / readsPerPartition);
    pairedReads = repartitionPairedReads(pairedReads, numPairedPartitions, numPairedReads);
    unpairedReads = unpairedReads.repartition(numUnpairedPartitions);

    //Bwa pathogen alignment. The aligner is closed in a finally block so that it is released
    // even when alignment, scoring or writing throws.
    final PSBwaAlignerSpark aligner = new PSBwaAlignerSpark(ctx, bwaArgs);
    try {
        PSBwaUtils.addReferenceSequencesToHeader(header, bwaArgs.microbeDictionary);
        final Broadcast<SAMFileHeader> headerBroadcast = ctx.broadcast(header);
        JavaRDD<GATKRead> alignedPairedReads = aligner.doBwaAlignment(pairedReads, true, headerBroadcast);
        JavaRDD<GATKRead> alignedUnpairedReads = aligner.doBwaAlignment(unpairedReads, false, headerBroadcast);

        //Cache this expensive result. Note serialization significantly reduces memory consumption.
        alignedPairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());
        alignedUnpairedReads.persist(StorageLevel.MEMORY_AND_DISK_SER());

        //Score pathogens
        final PSScorer scorer = new PSScorer(scoreArgs);
        final JavaRDD<GATKRead> readsFinal = scorer.scoreReads(ctx, alignedPairedReads, alignedUnpairedReads, header);

        //Clean up header
        header = PSBwaUtils.removeUnmappedHeaderSequences(header, readsFinal, logger);

        //Log read counts
        if (scoreArgs.scoreMetricsFileUri != null) {
            try (final PSScoreLogger scoreLogger = new PSScoreFileLogger(getMetricsFile(), scoreArgs.scoreMetricsFileUri)) {
                scoreLogger.logReadCounts(readsFinal);
            }
        }

        //Write reads to BAM, if specified
        if (outputPath != null) {
            try {
                //Reduce number of partitions since we previously went to ~5K reads per partition, which
                // is far too small for sharded output.
                final int numPartitions = Math.max(1, (int) (numTotalReads / readsPerPartitionOutput));
                final JavaRDD<GATKRead> readsFinalRepartitioned = readsFinal.coalesce(numPartitions, false);
                ReadsSparkSink.writeReads(ctx, outputPath, null, readsFinalRepartitioned, header,
                        shardedOutput ? ReadsWriteFormat.SHARDED : ReadsWriteFormat.SINGLE, numPartitions, shardedPartsDir, true, splittingIndexGranularity);
            } catch (final IOException e) {
                throw new UserException.CouldNotCreateOutputFile(outputPath, "writing failed", e);
            }
        }
    } finally {
        aligner.close();
    }
}
 
Example 17
Source File: FindAssemblyRegionsSpark.java (from gatk, BSD 3-Clause "New" or "Revised" License; 4 votes)
/**
 * Builds an RDD of assembly regions (with reference and feature context) for the given reads and intervals
 * using the <i>strict</i> algorithm, which looks for assembly regions in each contig in parallel.
 *
 * <p>Pipeline: compute per-locus activity over read shards, group the activity ranges by contig, run the band
 * pass filter over each whole contig to obtain read-less regions, rebalance those regions across partitions,
 * re-shard the reads onto the regions, convert each shard into an {@code AssemblyRegion}, and finally attach
 * reference/feature context.
 *
 * @param ctx the Spark context
 * @param reads the coordinate-sorted reads
 * @param header the header for the reads
 * @param sequenceDictionary the sequence dictionary for the reads
 * @param referenceFileName the file name for the reference
 * @param features source of arbitrary features (may be null, in which case nothing is broadcast)
 * @param intervalShards the sharded intervals to find assembly regions for
 * @param assemblyRegionEvaluatorSupplierBroadcast evaluator used to determine whether a locus is active
 * @param shardingArgs the arguments for sharding reads
 * @param assemblyRegionArgs the arguments for finding assembly regions
 * @param shuffle whether to use a shuffle or not when sharding reads
 * @return an RDD of assembly regions
 */
public static JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegionsStrict(
        final JavaSparkContext ctx,
        final JavaRDD<GATKRead> reads,
        final SAMFileHeader header,
        final SAMSequenceDictionary sequenceDictionary,
        final String referenceFileName,
        final FeatureManager features,
        final List<ShardBoundary> intervalShards,
        final Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast,
        final AssemblyRegionReadShardArgumentCollection shardingArgs,
        final AssemblyRegionArgumentCollection assemblyRegionArgs,
        final boolean shuffle) {
    final JavaRDD<Shard<GATKRead>> intervalReadShards = SparkSharder.shard(
            ctx, reads, GATKRead.class, sequenceDictionary, intervalShards, shardingArgs.readShardSize, shuffle);

    final Broadcast<FeatureManager> featureManagerBroadcast;
    if (features != null) {
        featureManagerBroadcast = ctx.broadcast(features);
    } else {
        featureManagerBroadcast = null;
    }

    // Step 1: per-locus activity for the requested intervals, computed shard by shard in parallel.
    final JavaRDD<ActivityProfileStateRange> activityRanges = intervalReadShards.mapPartitions(
            getActivityProfileStatesFunction(referenceFileName, featureManagerBroadcast, header,
                    assemblyRegionEvaluatorSupplierBroadcast, assemblyRegionArgs));

    // Step 2: collect every contig's activity into one group. The band pass filter has to see a whole contig
    // so that the regions match the ones AssemblyRegionWalker produces. This costs a shuffle, but activity
    // ranges are expected to be small.
    final JavaPairRDD<String, Iterable<ActivityProfileStateRange>> activityByContig = activityRanges
            .keyBy((Function<ActivityProfileStateRange, String>) ActivityProfileStateRange::getContig)
            .groupByKey();

    // Step 3: band pass filter per contig. Reads are deliberately NOT attached yet. The small read-less
    // regions are first spread evenly over the cluster so that the later (expensive) assembly work is balanced.
    final JavaRDD<ReadlessAssemblyRegion> unbalancedRegions =
            activityByContig.flatMap(getReadlessAssemblyRegionsFunction(header, assemblyRegionArgs));
    final int regionPartitionCount = unbalancedRegions.getNumPartitions();
    final JavaRDD<ReadlessAssemblyRegion> balancedRegions = unbalancedRegions.repartition(regionPartitionCount);

    // Step 4: attach reads; every resulting shard is one assembly region plus its overlapping reads.
    // NOTE(review): step 1 shards with the `sequenceDictionary` argument, but this step uses the header's
    // dictionary. Confirm the two are always meant to be the same.
    final JavaRDD<Shard<GATKRead>> regionReadShards = SparkSharder.shard(
            ctx, reads, GATKRead.class, header.getSequenceDictionary(), balancedRegions, shardingArgs.readShardSize);

    // Step 5: turn shards into AssemblyRegions. Reads are downsampled again here (one downsampler per
    // partition). This only agrees with the downsampling in step 1 once
    // https://github.com/broadinstitute/gatk/issues/5437 is resolved.
    final JavaRDD<AssemblyRegion> regions = regionReadShards.mapPartitions(
            (FlatMapFunction<Iterator<Shard<GATKRead>>, AssemblyRegion>) shardIterator -> {
                final ReadsDownsampler partitionDownsampler;
                if (assemblyRegionArgs.maxReadsPerAlignmentStart > 0) {
                    partitionDownsampler =
                            new PositionalDownsampler(assemblyRegionArgs.maxReadsPerAlignmentStart, header);
                } else {
                    partitionDownsampler = null;
                }
                return Utils.stream(shardIterator)
                        .map(regionShard -> toAssemblyRegion(regionShard, header, partitionDownsampler))
                        .iterator();
            });

    // Step 6: decorate each region with reference and feature context.
    return regions.mapPartitions(getAssemblyRegionWalkerContextFunction(referenceFileName, featureManagerBroadcast));
}
 
Example 18
Source File: SparkReader.java    From GeoTriples with Apache License 2.0 4 votes vote down vote up
/**
 * Reads the input file(s) with the reader that matches {@code source} and optionally repartitions the result.
 *
 * <p>For CSV, TSV and GeoJSON input, a column named {@code Config.GEOTRIPLES_AUTO_ID} filled by
 * {@code functions.monotonicallyIncreasingId()} is appended, and {@code headers} is set to the resulting column
 * names. Shapefile input is partitioned by {@code readSHP}, which receives the numeric value of
 * {@code repartition} (or 0 when it is not numeric). KML input is not supported: an error is logged and no rows
 * are read.
 *
 * <p>If {@code source} is {@code null}, or another {@code NullPointerException} escapes the readers, the error is
 * logged and the JVM exits with status 1. An {@code IOException} while sizing the input has the same effect.
 *
 * @param repartition {@code null} to keep the partitioning produced by the reader; {@code "default"} to use
 *                    {@code ceil(size / 120,000,000) + 1} partitions per input file, summed over all files; or a
 *                    numeric string giving the desired number of partitions. Any other value, or a value that
 *                    yields 0 partitions, leaves the partitioning unchanged.
 * @return a Spark {@code JavaRDD} with the rows of the input, or {@code null} if nothing was read
 *         (NOTE(review): for KML input combined with a repartition request the null RDD reaches
 *         {@code repartition}/{@code coalesce} and takes the "Not Supported file format" exit path instead)
 */
public JavaRDD<Row> read(String repartition){

    long startTime = System.currentTimeMillis();
    JavaRDD<Row> rowRDD = null;
    Dataset<Row> dt;
    try {
        switch (source) {
            case SHP:
                int p = StringUtils.isNumeric(repartition) ? Integer.parseInt(repartition) : 0;
                rowRDD = readSHP(p);
                break;
            case CSV:
                dt = readCSV();
                // insert a column with ID
                dt = dt.withColumn(Config.GEOTRIPLES_AUTO_ID, functions.monotonicallyIncreasingId());
                headers = dt.columns();
                rowRDD = dt.javaRDD();
                break;
            case TSV:
                dt = readTSV();
                // insert a column with ID
                dt = dt.withColumn(Config.GEOTRIPLES_AUTO_ID, functions.monotonicallyIncreasingId());
                headers = dt.columns();
                rowRDD = dt.javaRDD();
                break;
            case GEOJSON:
                dt = readGeoJSON();
                // insert a column with ID
                dt = dt.withColumn(Config.GEOTRIPLES_AUTO_ID, functions.monotonicallyIncreasingId());
                headers = dt.columns();
                rowRDD = dt.javaRDD();
                break;
            case KML:
                log.error("KML files are not Supported yet");
                break;
        }

        /*
             Repartition the loaded dataset if the user requested it.
             If "repartition" is set to "default", the number of partitions is calculated from the input's size;
             otherwise the number must be given by the user.
        */
        int partitions = rowRDD == null ? 0 : rowRDD.getNumPartitions();
        log.info("The input data was read into " + partitions + " partitions");
        if (repartition != null && source != Source.SHP) {
            // Input bytes per partition used by the "default" sizing heuristic.
            final double bytesPerPartition = 120000000;
            int newPartitions = 0;
            if (repartition.equals("default")) {
                try {
                    Configuration conf = new Configuration();
                    for (String filename : filenames) {
                        Path inputPath = new Path(filename);
                        // Resolve the file system from the path's own scheme/authority. FileSystem.get(conf)
                        // always returns the default file system (fs.defaultFS), which rejects fully-qualified
                        // paths that live elsewhere (e.g. hdfs:// vs file://, s3a://) with "Wrong FS".
                        FileSystem fs = inputPath.getFileSystem(conf);
                        double fileSize = fs.getContentSummary(inputPath).getLength();
                        newPartitions += (int) (Math.ceil(fileSize / bytesPerPartition) + 1);
                    }
                }
                catch (IOException e) {
                    e.printStackTrace();
                    System.exit(1);
                }
            }
            else if (StringUtils.isNumeric(repartition)) {
                newPartitions = Integer.parseInt(repartition);
            }

            if (newPartitions > 0) {
                // Shrinking needs no shuffle; growing does.
                if (partitions > newPartitions) {
                    rowRDD = rowRDD.coalesce(newPartitions);
                } else {
                    rowRDD = rowRDD.repartition(newPartitions);
                }
                log.info("Dataset was repartitioned into: " + newPartitions + " partitions");
            }
        }
    }
    catch (NullPointerException ex) {
        // Reached when `source` is null (switching on a null enum throws NPE), and also for any other NPE
        // raised above.
        log.error("Not Supported file format");
        ex.printStackTrace();
        System.exit(1);
    }
    log.info("Input dataset(s) was loaded in " + (System.currentTimeMillis() - startTime) + " msec");
    return rowRDD;
}
 
Example 19
Source File: MpBoostLearner.java    From sparkboost with Apache License 2.0 4 votes vote down vote up
/**
 * Build a new classifier by analyzing the training data available in the
 * specified documents set.
 *
 * <p>The documents, and the per-label and per-feature views derived from them, are each repartitioned and
 * persisted with {@code MEMORY_AND_DISK_SER} for the duration of training. These cached RDDs are created by
 * this method and are unpersisted before it returns, also when training throws. The caller's {@code docs} RDD
 * itself is never unpersisted.
 *
 * @param docs The set of documents used as training data.
 * @return A new MP-Boost classifier.
 * @throws NullPointerException if {@code docs} is {@code null}.
 */
public BoostClassifier buildModel(JavaRDD<MultilabelPoint> docs) {
    if (docs == null) {
        throw new NullPointerException("The set of training documents is 'null'");
    }

    // Repartition documents. After these two statements `docs` refers to an RDD owned by this method,
    // so unpersisting it later does not affect the caller's RDD.
    Logging.l().info("Load initial data and generating internal data representations...");
    docs = docs.repartition(getNumDocumentsPartitions());
    docs = docs.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Docs: num partitions " + docs.partitions().size());

    JavaRDD<DataUtils.LabelDocuments> labelDocuments = null;
    JavaRDD<DataUtils.FeatureDocuments> featureDocuments = null;
    try {
        int numDocs = DataUtils.getNumDocuments(docs);
        int numLabels = DataUtils.getNumLabels(docs);
        labelDocuments = DataUtils.getLabelDocuments(docs);

        // Repartition labels.
        labelDocuments = labelDocuments.repartition(getNumLabelsPartitions());
        labelDocuments = labelDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Labels: num partitions " + labelDocuments.partitions().size());

        // Repartition features.
        featureDocuments = DataUtils.getFeatureDocuments(docs);
        featureDocuments = featureDocuments.repartition(getNumFeaturesPartitions());
        featureDocuments = featureDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Features: num partitions " + featureDocuments.partitions().size());


        Logging.l().info("Ok, done!");

        WeakHypothesis[] computedWH = new WeakHypothesis[numIterations];
        double[][] localDM = initDistributionMatrix(numLabels, numDocs);
        for (int i = 0; i < numIterations; i++) {

            // Generate new weak hypothesis.
            WeakHypothesis localWH = learnWeakHypothesis(localDM, labelDocuments, featureDocuments);

            // Update distribution matrix with the new hypothesis.
            updateDistributionMatrix(sc, docs, localDM, localWH);

            // Save current generated weak hypothesis.
            computedWH[i] = localWH;

            Logging.l().info("Completed iteration " + (i + 1));
        }

        Logging.l().info("Model built!");

        return new BoostClassifier(computedWH);
    } finally {
        // Release the cached training data. The classifier is built only from the weak hypotheses collected
        // above. Unpersisting an RDD that never got persisted is a no-op.
        if (featureDocuments != null) {
            featureDocuments.unpersist();
        }
        if (labelDocuments != null) {
            labelDocuments.unpersist();
        }
        docs.unpersist();
    }
}
 
Example 20
Source File: AdaBoostMHLearner.java    From sparkboost with Apache License 2.0 4 votes vote down vote up
/**
 * Build a new classifier by analyzing the training data available in the
 * specified documents set.
 *
 * <p>The documents, and the per-label and per-feature views derived from them, are each repartitioned and
 * persisted with {@code MEMORY_AND_DISK_SER} for the duration of training. These cached RDDs are created by
 * this method and are unpersisted before it returns, also when training throws. The caller's {@code docs} RDD
 * itself is never unpersisted.
 *
 * @param docs The set of documents used as training data.
 * @return A new AdaBoost.MH classifier.
 * @throws NullPointerException if {@code docs} is {@code null}.
 */
public BoostClassifier buildModel(JavaRDD<MultilabelPoint> docs) {
    if (docs == null) {
        throw new NullPointerException("The set of input documents is 'null'");
    }


    // Repartition documents. After these two statements `docs` refers to an RDD owned by this method,
    // so unpersisting it later does not affect the caller's RDD.
    Logging.l().info("Load initial data and generating internal data representations...");
    docs = docs.repartition(getNumDocumentsPartitions());
    docs = docs.persist(StorageLevel.MEMORY_AND_DISK_SER());
    Logging.l().info("Docs: num partitions " + docs.partitions().size());

    JavaRDD<DataUtils.LabelDocuments> labelDocuments = null;
    JavaRDD<DataUtils.FeatureDocuments> featureDocuments = null;
    try {
        int numDocs = DataUtils.getNumDocuments(docs);
        int numLabels = DataUtils.getNumLabels(docs);
        labelDocuments = DataUtils.getLabelDocuments(docs);

        // Repartition labels.
        labelDocuments = labelDocuments.repartition(getNumLabelsPartitions());
        labelDocuments = labelDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Labels: num partitions " + labelDocuments.partitions().size());

        // Repartition features.
        featureDocuments = DataUtils.getFeatureDocuments(docs);
        featureDocuments = featureDocuments.repartition(getNumFeaturesPartitions());
        featureDocuments = featureDocuments.persist(StorageLevel.MEMORY_AND_DISK_SER());
        Logging.l().info("Features: num partitions " + featureDocuments.partitions().size());
        Logging.l().info("Ok, done!");

        WeakHypothesis[] computedWH = new WeakHypothesis[numIterations];
        double[][] localDM = initDistributionMatrix(numLabels, numDocs);
        for (int i = 0; i < numIterations; i++) {

            // Generate new weak hypothesis.
            WeakHypothesis localWH = learnWeakHypothesis(localDM, labelDocuments, featureDocuments);

            // Update distribution matrix with the new hypothesis.
            updateDistributionMatrix(sc, docs, localDM, localWH);

            // Save current generated weak hypothesis.
            computedWH[i] = localWH;

            Logging.l().info("Completed iteration " + (i + 1));
        }

        Logging.l().info("Model built!");

        return new BoostClassifier(computedWH);
    } finally {
        // Release the cached training data. The classifier is built only from the weak hypotheses collected
        // above. Unpersisting an RDD that never got persisted is a no-op.
        if (featureDocuments != null) {
            featureDocuments.unpersist();
        }
        if (labelDocuments != null) {
            labelDocuments.unpersist();
        }
        docs.unpersist();
    }
}