Java Code Examples for org.apache.spark.broadcast.Broadcast

The following examples show how to use org.apache.spark.broadcast.Broadcast. These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source Project: sparkResearch   Source File: BroadCastParam.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Broadcast-variable demo: broadcasts a small read-only list of IDs and keeps only the
 * key/value pairs whose value appears in that list, printing the survivors.
 *
 * @param args unused
 */
public static void main(String[] args) {
    SparkSession sparkSession = SparkSession.builder()
            .master("local[4]").appName("AttackFind").getOrCreate();
    try {
        // Wrap the session's SparkContext for the Java-friendly API.
        JavaSparkContext javaSparkContext = JavaSparkContext.fromSparkContext(sparkSession.sparkContext());
        // Sample data to broadcast; broadcast variables are read-only.
        final List<String> broadcastList = Arrays.asList("190099HJLL","98392QUEYY","561788LLKK");
        // Ship the list to the executors as a broadcast variable.
        final Broadcast<List<String>> broadcast = javaSparkContext.broadcast(broadcastList);
        // Input data.
        JavaPairRDD<String,String> pairRDD = javaSparkContext.parallelizePairs(Arrays.asList(new Tuple2<>("000", "000")));
        JavaPairRDD<String,String> resultPairRDD = pairRDD.filter((Function<Tuple2<String, String>, Boolean>) v1 -> broadcast.value().contains(v1._2));
        resultPairRDD.foreach((VoidFunction<Tuple2<String, String>>) System.out::println);
    } finally {
        // Release the Spark resources even if the job fails.
        sparkSession.stop();
    }
}
 
Example 2
@Override
protected void runTool(final JavaSparkContext ctx) {
    final String referenceName = addReferenceFilesForSpark(ctx, referenceArguments.getReferencePath());
    final List<String> knownSitesPaths = addVCFsForSpark(ctx, knownVariants);

    // Open question: should this use getUnfilteredReads()? getReads() merges the default and
    // command-line filters, but the BQSR-specific filtering below ignores the command line.
    final JavaRDD<GATKRead> reads = getReads();

    // getReads() has already applied the WellformedReadFilter, which is all the filtering
    // ApplyBQSR wants. BQSR itself wants additional filtering, applied here.
    // NOTE: this filter list does not honor command-line enabled/disabled filters.
    final ReadFilter recalibrationFilter =
            ReadFilter.fromList(BaseRecalibrator.getBQSRSpecificReadFilterList(), getHeaderForReads());
    final JavaRDD<GATKRead> readsForRecalibration = reads.filter(recalibrationFilter::test);

    final JavaPairRDD<GATKRead, Iterable<GATKVariant>> readsAndOverlappingVariants =
            JoinReadsWithVariants.join(readsForRecalibration, knownSitesPaths);
    // The reference dictionary comes from the reads' own header.
    final RecalibrationReport recalibrationReport = BaseRecalibratorSparkFn.apply(
            readsAndOverlappingVariants, getHeaderForReads(), referenceName, bqsrArgs);

    // Broadcast the report so the recalibration step on the executors reads a shared copy.
    final Broadcast<RecalibrationReport> broadcastReport = ctx.broadcast(recalibrationReport);
    final JavaRDD<GATKRead> recalibratedReads = ApplyBQSRSparkFn.apply(
            reads, broadcastReport, getHeaderForReads(), applyBqsrArgs.toApplyBQSRArgumentCollection(bqsrArgs));

    writeReads(ctx, output, recalibratedReads);
}
 
Example 3
Source Project: deeplearning4j   Source File: ExecuteWorkerPathFlatMap.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Creates the flat-map that runs a training worker over dataset paths.
 *
 * <p>Also computes {@code maxDataSetObjects}: how many dataset objects (each holding
 * {@code getDataSetObjectSizeExamples()} examples) are needed to feed the worker's maximum
 * number of minibatches. It is {@code Integer.MAX_VALUE} when the worker has no batch limit.
 *
 * @param worker        training worker to wrap
 * @param dataSetLoader loader used to turn a path into a dataset
 * @param hadoopConfig  broadcast Hadoop configuration
 */
public ExecuteWorkerPathFlatMap(TrainingWorker<R> worker, DataSetLoader dataSetLoader, Broadcast<SerializableHadoopConfig> hadoopConfig) {
    this.workerFlatMap = new ExecuteWorkerFlatMap<>(worker);
    this.dataSetLoader = dataSetLoader;
    this.hadoopConfig = hadoopConfig;

    //How many dataset objects of size 'dataSetObjectNumExamples' should we load?
    //Only pass on the required number, not all of them (to avoid async preloading data that won't be used)
    //Most of the time we'll get exactly the number we want, but this isn't guaranteed all the time for all
    // splitting strategies
    WorkerConfiguration conf = worker.getDataConfiguration();
    int dataSetObjectNumExamples = conf.getDataSetObjectSizeExamples();
    int workerMinibatchSize = conf.getBatchSizePerWorker();
    int maxMinibatches = (conf.getMaxBatchesPerWorker() > 0 ? conf.getMaxBatchesPerWorker() : Integer.MAX_VALUE);

    if (maxMinibatches == Integer.MAX_VALUE) {
        maxDataSetObjects = Integer.MAX_VALUE;
    } else {
        // Required: total number of examples / examples per dataset object.
        // Multiply in floating point: the int product maxMinibatches * workerMinibatchSize
        // can overflow. The double->int cast saturates at Integer.MAX_VALUE.
        double totalExamples = (double) maxMinibatches * workerMinibatchSize;
        maxDataSetObjects = (int) Math.ceil(totalExamples / dataSetObjectNumExamples);
    }
}
 
Example 4
Source Project: deeplearning4j   Source File: TextPipelineTest.java    License: Apache License 2.0 6 votes vote down vote up
@Test @Ignore   //AB 2020/04/20 https://github.com/eclipse/deeplearning4j/issues/8849
public void testCountCumSum() throws Exception {
    JavaSparkContext sc = getContext();
    try {
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vec.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        pipeline.buildVocabCache();
        pipeline.buildVocabWordListRDD();
        JavaRDD<AtomicLong> sentenceCountRDD = pipeline.getSentenceCountRDD();

        // Cumulative sum of per-sentence counts; the first two prefix sums are checked.
        CountCumSum countCumSum = new CountCumSum(sentenceCountRDD);
        JavaRDD<Long> sentenceCountCumSumRDD = countCumSum.buildCumSum();
        List<Long> sentenceCountCumSumList = sentenceCountCumSumRDD.collect();
        // assertEquals reports expected vs actual on failure, unlike assertTrue(x == y).
        assertEquals(6L, sentenceCountCumSumList.get(0).longValue());
        assertEquals(9L, sentenceCountCumSumList.get(1).longValue());
    } finally {
        // Stop the context even if an assertion fails so later tests can create a new one.
        sc.stop();
    }
}
 
Example 5
Source Project: geowave   Source File: GeoWaveRDDLoader.java    License: Apache License 2.0 6 votes vote down vote up
public static GeoWaveIndexedRDD loadIndexedRDD(
    final SparkContext sc,
    final DataStorePluginOptions storeOptions,
    final RDDOptions rddOpts,
    final NumericIndexStrategy indexStrategy) throws IOException {
  final GeoWaveRDD baseRDD = GeoWaveRDDLoader.loadRDD(sc, storeOptions, rddOpts);
  // The index strategy can be expensive, so it is broadcast once and stored with the RDD.
  // A null strategy yields a null broadcast handle.
  final Broadcast<NumericIndexStrategy> strategyBroadcast =
      (indexStrategy == null)
          ? null
          : (Broadcast<NumericIndexStrategy>) RDDUtils.broadcastIndexStrategy(sc, indexStrategy);
  return new GeoWaveIndexedRDD(baseRDD, strategyBroadcast);
}
 
Example 6
Source Project: oryx   Source File: ALSUpdate.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Re-keys a feature RDD from string IDs to integer indices and widens each feature vector
 * from float[] to double[].
 *
 * <p>The result is persisted at MEMORY_AND_DISK and returned with its key type erased to
 * Object.
 *
 * @param javaRDD    features keyed by string ID
 * @param bIdToIndex broadcast lookup table from string ID to integer index
 * @return persisted RDD of (index, double[] features)
 */
private static RDD<Tuple2<Object,double[]>> readAndConvertFeatureRDD(
    JavaPairRDD<String,float[]> javaRDD,
    Broadcast<? extends Map<String,Integer>> bIdToIndex) {

  // Replace each string ID with its integer index.
  // NOTE(review): IDs missing from the table map to a null key.
  JavaPairRDD<Integer,float[]> indexed = javaRDD.mapToPair(idAndVector ->
      new Tuple2<>(bIdToIndex.value().get(idAndVector._1()), idAndVector._2()));

  // Widen every float component to double.
  RDD<Tuple2<Integer,double[]>> scalaRDD = indexed.mapValues(floats -> {
      double[] doubles = new double[floats.length];
      int pos = 0;
      for (float component : floats) {
        doubles[pos++] = component;
      }
      return doubles;
    }
  ).rdd();

  // This mimics the persistence level establish by ALS training methods
  scalaRDD.persist(StorageLevel.MEMORY_AND_DISK());

  // Integer keys are reinterpreted as Object keys; this is safe because RDDs are erased.
  @SuppressWarnings("unchecked")
  RDD<Tuple2<Object,double[]>> objectKeyed = (RDD<Tuple2<Object,double[]>>) (RDD<?>) scalaRDD;
  return objectKeyed;
}
 
Example 7
Source Project: deeplearning4j   Source File: TextPipelineTest.java    License: Apache License 2.0 6 votes vote down vote up
@Test
public void testWordFreqAccNotIdentifyingStopWords() throws Exception {

    JavaSparkContext sc = getContext();
    try {
        // Uses word2vecNoStop, presumably configured to keep stop words, so words such as
        // "is", "this", "are" and "a" are expected to be counted.
        JavaRDD<String> corpusRDD = getCorpusRDD(sc);
        Broadcast<Map<String, Object>> broadcastTokenizerVarMap = sc.broadcast(word2vecNoStop.getTokenizerVarMap());

        TextPipeline pipeline = new TextPipeline(corpusRDD, broadcastTokenizerVarMap);
        JavaRDD<List<String>> tokenizedRDD = pipeline.tokenize();
        pipeline.updateAndReturnAccumulatorVal(tokenizedRDD);

        // JUnit's assertEquals takes (expected, actual, delta); the order matters for the failure message.
        Counter<String> wordFreqCounter = pipeline.getWordFreqAcc().value();
        assertEquals(1, wordFreqCounter.getCount("is"), 0);
        assertEquals(1, wordFreqCounter.getCount("this"), 0);
        assertEquals(1, wordFreqCounter.getCount("are"), 0);
        assertEquals(1, wordFreqCounter.getCount("a"), 0);
        assertEquals(2, wordFreqCounter.getCount("strange"), 0);
        assertEquals(1, wordFreqCounter.getCount("flowers"), 0);
        assertEquals(1, wordFreqCounter.getCount("world"), 0);
        assertEquals(1, wordFreqCounter.getCount("red"), 0);
    } finally {
        // Stop the context even if an assertion fails.
        sc.stop();
    }
}
 
Example 8
Source Project: systemds   Source File: RemoteParForSpark.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Broadcasts the parfor input variables, skipping result variables, scalars and
 * partitioned matrices.
 *
 * @param sec        execution context holding the current variable map
 * @param resultVars parfor result variables; these are never broadcast
 * @return map from variable name to its broadcast handle
 */
@SuppressWarnings("unchecked")
private static Map<String, Broadcast<CacheBlock>> broadcastInputs(SparkExecutionContext sec, ArrayList<ParForStatementBlock.ResultVar> resultVars) {
	LocalVariableMap inputs = sec.getVariables();
	Set<String> resultNames = resultVars.stream()
		.map(rv -> rv._name)
		.collect(Collectors.toSet());
	// TODO use optimizer-picked list of amenable objects (e.g., size constraints)
	Set<String> candidateNames = inputs.keySet().stream()
		.filter(name -> !resultNames.contains(name))
		.collect(Collectors.toSet());

	Map<String, Broadcast<CacheBlock>> broadcasts = new HashMap<>();
	for (String name : candidateNames) {
		Data value = sec.getVariable(name);
		boolean isScalar = value instanceof ScalarObject;
		boolean isPartitionedMatrix = value instanceof MatrixObject && ((MatrixObject) value).isPartitioned();
		if (!isScalar && !isPartitionedMatrix) {
			broadcasts.put(name, sec.broadcastVariable((CacheableData<CacheBlock>) value));
		}
	}
	return broadcasts;
}
 
Example 9
/**
 * Tiles every contig in the read header into DEPTH_WINDOW_SIZE-wide windows. It then finds
 * high-coverage sub-intervals of those windows and returns them in an interval tree (each
 * interval maps to itself).
 */
static SVIntervalTree<SVInterval> findGenomewideHighCoverageIntervalsToIgnore(final FindBreakpointEvidenceSparkArgumentCollection params,
                                                                              final ReadMetadata readMetadata,
                                                                              final JavaSparkContext ctx,
                                                                              final SAMFileHeader header,
                                                                              final JavaRDD<GATKRead> unfilteredReads,
                                                                              final SVReadFilter filter,
                                                                              final Logger logger,
                                                                              final Broadcast<ReadMetadata> broadcastMetadata) {
    final List<SAMSequenceRecord> contigs = header.getSequenceDictionary().getSequences();

    // Pre-size the window list: ceil(length / DEPTH_WINDOW_SIZE) windows per contig.
    int capacity = 0;
    for (final SAMSequenceRecord contig : contigs) {
        capacity += (contig.getSequenceLength() + DEPTH_WINDOW_SIZE - 1) / DEPTH_WINDOW_SIZE;
    }

    final List<SVInterval> depthIntervals = new ArrayList<>(capacity);
    for (final SAMSequenceRecord contig : contigs) {
        final int contigID = readMetadata.getContigID(contig.getSequenceName());
        final int contigLength = contig.getSequenceLength();
        // Windows start at position 1; the last one is clipped to the contig length.
        for (int start = 1; start < contigLength; start += DEPTH_WINDOW_SIZE) {
            depthIntervals.add(new SVInterval(contigID, start, Math.min(contigLength, start + DEPTH_WINDOW_SIZE)));
        }
    }

    final SVIntervalTree<SVInterval> highCoverageTree = new SVIntervalTree<>();
    for (final SVInterval interval : findHighCoverageSubintervalsAndLog(
            params, ctx, broadcastMetadata, depthIntervals, unfilteredReads, filter, logger)) {
        highCoverageTree.put(interval, interval);
    }
    return highCoverageTree;
}
 
Example 10
/**
 * Bundles the inputs shared by the variant-discovery stage.
 *
 * <p>Broadcasts the canonical chromosome set, the reference, the sequence dictionary of the
 * read header, and the read header itself, in that order.
 */
public SvDiscoveryInputMetaData(final JavaSparkContext ctx,
                                final DiscoverVariantsFromContigAlignmentsSparkArgumentCollection discoverStageArgs,
                                final String nonCanonicalChromosomeNamesFile,
                                final String outputPath,
                                final ReadMetadata readMetadata,
                                final List<SVInterval> assembledIntervals,
                                final PairedStrandedIntervalTree<EvidenceTargetLink> evidenceTargetLinks,
                                final Broadcast<SVIntervalTree<VariantContext>> cnvCallsBroadcast,
                                final SAMFileHeader headerForReads,
                                final ReferenceMultiSparkSource reference,
                                final Set<VCFHeaderLine> defaultToolVCFHeaderLines,
                                final Logger toolLogger) {

    final SAMSequenceDictionary dictionary = headerForReads.getSequenceDictionary();
    final Set<String> canonicalChromosomes =
            SVUtils.getCanonicalChromosomes(nonCanonicalChromosomeNamesFile, dictionary);
    final Broadcast<Set<String>> canonicalChromosomesBroadcast = ctx.broadcast(canonicalChromosomes);
    final String sampleId = SVUtils.getSampleId(headerForReads);

    final Broadcast<ReferenceMultiSparkSource> referenceBroadcast = ctx.broadcast(reference);
    final Broadcast<SAMSequenceDictionary> dictionaryBroadcast = ctx.broadcast(dictionary);
    this.referenceData = new ReferenceData(canonicalChromosomesBroadcast, referenceBroadcast, dictionaryBroadcast);

    this.sampleSpecificData = new SampleSpecificData(sampleId, cnvCallsBroadcast, assembledIntervals,
            evidenceTargetLinks, readMetadata, ctx.broadcast(headerForReads));
    this.discoverStageArgs = discoverStageArgs;
    this.outputPath = outputPath;
    this.defaultToolVCFHeaderLines = defaultToolVCFHeaderLines;
    this.toolLogger = toolLogger;
}
 
Example 11
/**
 * Builds the per-shard function that converts a shard of variants into VariantWalkerContexts.
 *
 * <p>Only variants whose start lies in [shard start, shard end] are emitted.
 *
 * @param referenceFileName name of a reference file resolved through SparkFiles, or null for no reference
 * @param bFeatureManager   broadcast FeatureManager, or null for no features
 * @return a function yielding one context per variant that starts in the shard
 */
private static FlatMapFunction<Shard<VariantContext>, VariantWalkerContext> getVariantsFunction(
        final String referenceFileName,
        final Broadcast<FeatureManager> bFeatureManager) {
    return (FlatMapFunction<Shard<VariantContext>, VariantWalkerContext>) shard -> {
        // A reference source is opened per shard from the file shipped with SparkFiles.
        // NOTE(review): it is never closed in this code. The returned iterator is consumed
        // lazily, so any close() must wait until the iterator is exhausted.
        ReferenceDataSource reference = referenceFileName == null ? null : new ReferenceFileSource(IOUtils.getPath(SparkFiles.get(referenceFileName)));
        FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();

        return StreamSupport.stream(shard.spliterator(), false)
                .filter(v -> v.getStart() >= shard.getStart() && v.getStart() <= shard.getEnd()) // only include variants that start in the shard
                .map(v -> {
                    final SimpleInterval variantInterval = new SimpleInterval(v);
                    return new VariantWalkerContext(v,
                            new ReadsContext(), // empty
                            new ReferenceContext(reference, variantInterval),
                            new FeatureContext(features, variantInterval));
                }).iterator();
    };
}
 
Example 12
Source Project: beam   Source File: SparkBatchPortablePipelineTranslator.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Broadcast the side inputs of an executable stage. *This can be expensive.*
 *
 * <p>Each side-input PCollection is broadcast at most once, even if several side-input IDs
 * resolve to the same collection.
 *
 * @param stagePayload payload whose side inputs are broadcast
 * @param context      translation context used to perform each broadcast
 * @return Map from PCollection ID to Spark broadcast variable and coder to decode its contents.
 */
private static <SideInputT>
    ImmutableMap<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
        broadcastSideInputs(
            RunnerApi.ExecutableStagePayload stagePayload, SparkTranslationContext context) {
  Map<String, Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>>>
      byCollectionId = new HashMap<>();
  for (SideInputId sideInput : stagePayload.getSideInputsList()) {
    RunnerApi.Components components = stagePayload.getComponents();
    String pCollectionId =
        components
            .getTransformsOrThrow(sideInput.getTransformId())
            .getInputsOrThrow(sideInput.getLocalName());
    if (!byCollectionId.containsKey(pCollectionId)) {
      Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<SideInputT>> broadcastAndCoder =
          broadcastSideInput(pCollectionId, components, context);
      byCollectionId.put(pCollectionId, broadcastAndCoder);
    }
  }
  return ImmutableMap.copyOf(byCollectionId);
}
 
Example 13
Source Project: GeoTriples   Source File: SparkMaster.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Convert the input Dataset into RDF triples and store the results.
 * The conversion runs once per partition via the mapPartitions Spark transformation.
 * The mapping list and the reader headers are broadcast for use inside each partition.
 * @param mapping_list list of TripleMaps
 */
private void convert_partition(ArrayList<TriplesMap> mapping_list){
    SparkContext sc = SparkContext.getOrCreate();

    Pair<ArrayList<TriplesMap>, List<String>> transformation_info = new Pair<>(mapping_list, Arrays.asList(reader.getHeaders()));
    ClassTag<Pair<ArrayList<TriplesMap>, List<String>>> classTag_pair = scala.reflect.ClassTag$.MODULE$.apply(Pair.class);
    Broadcast<Pair<ArrayList<TriplesMap>, List<String>>> bd_info = sc.broadcast(transformation_info, classTag_pair);

    rowRDD
        .mapPartitions(
        (Iterator<Row> rows_iter) -> {
            // Read the broadcast value once per partition.
            Pair<ArrayList<TriplesMap>, List<String>> p_info = bd_info.value();
            RML_Converter rml_converter = new RML_Converter(p_info.getKey(), p_info.getValue());
            rml_converter.start();
            try {
                rml_converter.registerFunctions();
                return rml_converter.convertPartition(rows_iter);
            } finally {
                // Always stop the converter, including when registration or conversion throws.
                // NOTE(review): stop() runs before the returned iterator is consumed (as before);
                // confirm convertPartition materializes its output eagerly.
                rml_converter.stop();
            }
        })
        .saveAsTextFile(outputDir);
}
 
Example 14
Source Project: iceberg   Source File: RewriteManifestsAction.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Rewrites manifest entries of a partitioned table into about {@code numManifests} manifests.
 *
 * <p>Entries are range-partitioned and sorted by the data file partition before manifests
 * are written.
 *
 * @param manifestEntryDF          manifest entries to rewrite
 * @param numManifests             number of range partitions (and so of manifest groups)
 * @param targetNumManifestEntries desired number of entries per manifest
 * @return the newly written manifests
 */
private List<ManifestFile> writeManifestsForPartitionedTable(
    Dataset<Row> manifestEntryDF, int numManifests,
    int targetNumManifestEntries) {

  Broadcast<FileIO> fileIOBroadcast = sparkContext.broadcast(fileIO);
  StructType dataFileType = (StructType) manifestEntryDF.schema().apply("data_file").dataType();

  // Allow a manifest to run up to 10% over the target, since the size estimate may be imprecise.
  long maxEntriesPerManifest = (long) (1.1 * targetNumManifestEntries);

  return withReusableDS(manifestEntryDF, df -> {
    Column partitionColumn = df.col("data_file.partition");
    Dataset<Row> clustered = df
        .repartitionByRange(numManifests, partitionColumn)
        .sortWithinPartitions(partitionColumn);
    return clustered
        .mapPartitions(
            toManifests(fileIOBroadcast, maxEntriesPerManifest, stagingLocation, formatVersion, spec, dataFileType),
            manifestEncoder)
        .collectAsList();
  });
}
 
Example 15
/**
 * Grab template names for all reads that contain kmers associated with a given breakpoint.
 *
 * <p>The kmer multimap is broadcast for the duration of the job. The broadcast is destroyed
 * afterwards, even if the job fails.
 *
 * @param params          arguments supplying the kmer size
 * @param ctx             Spark context used to broadcast the kmer map
 * @param kmerMultiMap    kmers mapped to their intervals
 * @param unfilteredReads reads to scan; junk and non-primary-line reads are dropped first
 * @param filter          read filter supplying the junk/primary-line predicates
 * @return collected template names with their intervals
 */
@VisibleForTesting static List<QNameAndInterval> getAssemblyQNames(
        final FindBreakpointEvidenceSparkArgumentCollection params,
        final JavaSparkContext ctx,
        final HopscotchUniqueMultiMap<SVKmer, Integer, KmerAndInterval> kmerMultiMap,
        final JavaRDD<GATKRead> unfilteredReads,
        final SVReadFilter filter ) {
    final Broadcast<HopscotchUniqueMultiMap<SVKmer, Integer, KmerAndInterval>> broadcastKmersAndIntervals =
            ctx.broadcast(kmerMultiMap);

    final int kSize = params.kSize;
    try {
        return unfilteredReads
                .filter(filter::notJunk)
                .filter(filter::isPrimaryLine)
                .mapPartitions(readItr ->
                        new FlatMapGluer<>(new QNameIntervalFinder(kSize,broadcastKmersAndIntervals.getValue()), readItr))
                .collect();
    } finally {
        // collect() has finished (or failed) here, so the broadcast is no longer needed.
        SparkUtils.destroyBroadcast(broadcastKmersAndIntervals, "cleaned kmers and intervals");
    }
}
 
Example 16
/**
 * Filters input assembly contigs that are not strong enough to support an event,
 * then delegates to {@link BreakpointsInference} to infer the reference locations
 * that bound the bi-path bubble in the graph caused by the event,
 * as well as the alternative path encoded in the contig sequence.
 *
 * <p>Evidence from different contigs supporting the same novel adjacency is grouped before
 * the type is inferred.
 */
private static JavaPairRDD<SimpleNovelAdjacencyAndChimericAlignmentEvidence, List<SvType>>
inferTypeFromSingleContigSimpleChimera(final JavaRDD<AssemblyContigWithFineTunedAlignments> assemblyContigs,
                                       final SvDiscoveryInputMetaData svDiscoveryInputMetaData) {

    final Broadcast<SAMSequenceDictionary> dictionaryBroadcast =
            svDiscoveryInputMetaData.getReferenceData().getReferenceSequenceDictionaryBroadcast();
    final Broadcast<ReferenceMultiSparkSource> referenceBroadcast =
            svDiscoveryInputMetaData.getReferenceData().getReferenceBroadcast();

    // Keep only contigs whose head/tail alignment pair is strong enough evidence.
    final JavaRDD<AssemblyContigWithFineTunedAlignments> strongContigs = assemblyContigs.filter(contig ->
            SimpleChimera.splitPairStrongEnoughEvidenceForCA(contig.getHeadAlignment(), contig.getTailAlignment(),
                    MORE_RELAXED_ALIGNMENT_MIN_MQ, MORE_RELAXED_ALIGNMENT_MIN_LENGTH));

    return strongContigs
            .mapToPair(contig -> getNovelAdjacencyAndEvidence(contig, dictionaryBroadcast.getValue()))
            .groupByKey() // same novel adjacency from different contigs -> one group
            .mapToPair(adjacencyAndEvidence -> inferType(adjacencyAndEvidence, dictionaryBroadcast, referenceBroadcast));
}
 
Example 17
Source Project: iceberg   Source File: RemoveOrphanFilesAction.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Lists the files under {@code location} modified before {@code olderThanTimestamp}.
 *
 * <p>The top of the tree is listed on the driver. Remaining subdirectories, if any, are
 * listed in parallel on the executors.
 *
 * @return a single-column ("file_path") DataFrame of matching files
 */
private Dataset<Row> buildActualFileDF() {
  List<String> subDirs = Lists.newArrayList();
  List<String> matchingFiles = Lists.newArrayList();

  Predicate<FileStatus> olderThanCutoff = file -> file.getModificationTime() < olderThanTimestamp;

  // list at most 3 levels and only dirs that have less than 10 direct sub dirs on the driver
  listDirRecursively(location, olderThanCutoff, hadoopConf.value(), 3, 10, subDirs, matchingFiles);

  JavaRDD<String> driverListedFiles = sparkContext.parallelize(matchingFiles, 1);

  JavaRDD<String> allMatchingFiles;
  if (subDirs.isEmpty()) {
    allMatchingFiles = driverListedFiles;
  } else {
    // Spread the remaining subdirectories across executors for listing.
    int parallelism = Math.min(subDirs.size(), partitionDiscoveryParallelism);
    JavaRDD<String> subDirRDD = sparkContext.parallelize(subDirs, parallelism);

    Broadcast<SerializableConfiguration> conf = sparkContext.broadcast(hadoopConf);
    JavaRDD<String> executorListedFiles = subDirRDD.mapPartitions(listDirsRecursively(conf, olderThanTimestamp));
    allMatchingFiles = driverListedFiles.union(executorListedFiles);
  }
  return spark.createDataset(allMatchingFiles.rdd(), Encoders.STRING()).toDF("file_path");
}
 
Example 18
Source Project: iceberg   Source File: SparkBatchWrite.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Captures everything needed for a batch write to {@code table}.
 *
 * <p>The file format is taken from the table properties and the write options. The target
 * file size comes from the "target-file-size-bytes" option, falling back to the table's
 * WRITE_TARGET_FILE_SIZE_BYTES property (or its default).
 */
SparkBatchWrite(Table table, Broadcast<FileIO> io, Broadcast<EncryptionManager> encryptionManager,
                CaseInsensitiveStringMap options, boolean overwriteDynamic, boolean overwriteByFilter,
                Expression overwriteExpr, String applicationId, String wapId, Schema writeSchema,
                StructType dsSchema) {
  // Target table and the format its files are written in.
  this.table = table;
  this.format = getFileFormat(table.properties(), options);

  // Broadcast handles used by the write tasks.
  this.io = io;
  this.encryptionManager = encryptionManager;

  // Overwrite mode.
  this.overwriteDynamic = overwriteDynamic;
  this.overwriteByFilter = overwriteByFilter;
  this.overwriteExpr = overwriteExpr;

  // Identifiers for the writing application and job.
  this.applicationId = applicationId;
  this.wapId = wapId;
  this.genieId = options.get("genie-id");

  // Schemas.
  this.writeSchema = writeSchema;
  this.dsSchema = dsSchema;

  this.targetFileSize = options.getLong("target-file-size-bytes", PropertyUtil.propertyAsLong(
      table.properties(), WRITE_TARGET_FILE_SIZE_BYTES, WRITE_TARGET_FILE_SIZE_BYTES_DEFAULT));
}
 
Example 19
@Override
protected void runTool(final JavaSparkContext ctx) {

    validateParams();

    final Broadcast<SVIntervalTree<VariantContext>> cnvCalls = StructuralVariationDiscoveryPipelineSpark
            .broadcastCNVCalls(ctx, getHeaderForReads(), discoverStageArgs.cnvCallsFile);
    final String outputPrefix = getOutputPrefix();
    // This tool passes null for readMetadata, assembledIntervals and evidenceTargetLinks.
    final SvDiscoveryInputMetaData inputMetaData = new SvDiscoveryInputMetaData(
            ctx, discoverStageArgs, nonCanonicalChromosomeNamesFile, outputPrefix,
            null, null, null,
            cnvCalls,
            getHeaderForReads(), getReference(), getDefaultToolVCFHeaderLines(), localLogger);
    final JavaRDD<GATKRead> rawAlignments = getReads();

    final AssemblyContigsClassifiedByAlignmentSignatures classifiedContigs =
            preprocess(inputMetaData, rawAlignments);

    final List<VariantContext> calls =
            dispatchJobs(ctx, classifiedContigs, inputMetaData, rawAlignments, writeSAMFiles);
    // The classified contigs are no longer needed once the variants are collected.
    classifiedContigs.unpersist();

    filterAndWriteMergedVCF(outputPrefix, calls, inputMetaData);
}
 
Example 20
Source Project: iceberg   Source File: SparkWriteBuilder.java    License: Apache License 2.0 6 votes vote down vote up
/**
 * Validates the write and builds an append-only streaming write.
 *
 * @throws IllegalStateException if dynamic partition overwrite is requested, or if an
 *     overwrite-by-filter uses any filter other than alwaysTrue
 */
@Override
public StreamingWrite buildForStreaming() {
  // Validate the incoming schema against the table schema and the partition spec.
  Schema validatedSchema = SparkSchemaUtil.convert(table.schema(), dsSchema);
  TypeUtil.validateWriteSchema(table.schema(), validatedSchema,
      checkNullability(spark, options), checkOrdering(spark, options));
  SparkUtil.validatePartitionTransforms(table.spec());

  // Streaming supports append only; an overwrite of everything (alwaysTrue) is the one exception.
  Preconditions.checkState(!overwriteDynamic,
      "Unsupported streaming operation: dynamic partition overwrite");
  Preconditions.checkState(!overwriteByFilter || overwriteExpr == Expressions.alwaysTrue(),
      "Unsupported streaming operation: overwrite by filter: %s", overwriteExpr);

  String sparkApplicationId = spark.sparkContext().applicationId();
  // Write-audit-publish id; null when "spark.wap.id" is not set.
  String writeAuditPublishId = spark.conf().get("spark.wap.id", null);

  Broadcast<FileIO> fileIOBroadcast = lazySparkContext().broadcast(SparkUtil.serializableFileIO(table));
  Broadcast<EncryptionManager> encryptionBroadcast = lazySparkContext().broadcast(table.encryption());

  return new SparkStreamingWrite(table, fileIOBroadcast, encryptionBroadcast, options, overwriteByFilter,
      writeQueryId, sparkApplicationId, writeAuditPublishId, validatedSchema, dsSchema);
}
 
Example 21
/**
 * Adds the EXTERNAL_CNV_CALLS attribute listing every external CNV call that overlaps
 * [pos, end] on {@code recordContig}.
 *
 * @return {@code inputBuilder}, annotated only if CNV calls were provided and at least one
 *     overlaps the variant
 */
@VisibleForTesting
static VariantContextBuilder annotateWithExternalCNVCalls(final String recordContig, final int pos, final int end,
                                                          final VariantContextBuilder inputBuilder,
                                                          final Broadcast<SAMSequenceDictionary> broadcastSequenceDictionary,
                                                          final Broadcast<SVIntervalTree<VariantContext>> broadcastCNVCalls,
                                                          final String sampleId) {
    if (broadcastCNVCalls == null) {
        return inputBuilder;
    }
    final int contigIndex = broadcastSequenceDictionary.getValue().getSequenceIndex(recordContig);
    final SVInterval variantInterval = new SVInterval(contigIndex, pos, end);
    final String annotation =
            Utils.stream(broadcastCNVCalls.getValue().overlappers(variantInterval))
                    .map(overlap -> formatExternalCNVCallAnnotation(overlap.getValue(), sampleId))
                    .collect(Collectors.joining(VCFConstants.INFO_FIELD_ARRAY_SEPARATOR));
    return annotation.isEmpty()
            ? inputBuilder
            : inputBuilder.attribute(GATKSVVCFConstants.EXTERNAL_CNV_CALLS, annotation);
}
 
Example 22
/**
 * Builds the per-partition function that calls variants on each assembly region.
 *
 * <p>One HaplotypeCallerEngine is constructed per partition and reused for every region in
 * it. The variant iterators of all regions are concatenated into one output iterator.
 *
 * @param header                      read header passed to the engine
 * @param referenceFileName           reference file name used to open a per-task reference
 * @param hcArgsBroadcast             broadcast HaplotypeCaller arguments
 * @param assemblyRegionArgsBroadcast broadcast assembly-region arguments
 * @param annotatorEngineBroadcast    broadcast variant annotator engine
 * @return a function mapping a partition of region contexts to the variants called in them
 */
private static FlatMapFunction<Iterator<AssemblyRegionWalkerContext>, VariantContext> assemblyFunction(final SAMFileHeader header,
                                                                                                       final String referenceFileName,
                                                                                                       final Broadcast<HaplotypeCallerArgumentCollection> hcArgsBroadcast,
                                                                                                       final Broadcast<AssemblyRegionArgumentCollection> assemblyRegionArgsBroadcast,
                                                                                                       final Broadcast<VariantAnnotatorEngine> annotatorEngineBroadcast) {
    return (FlatMapFunction<Iterator<AssemblyRegionWalkerContext>, VariantContext>) contexts -> {
        // HaplotypeCallerEngine isn't serializable but is expensive to instantiate, so construct and reuse one for every partition
        // NOTE(review): neither the reference file nor the engine is closed/shut down here. The
        // result is a lazily consumed iterator, so cleanup would have to run after it is exhausted.
        final ReferenceSequenceFile taskReferenceSequenceFile = taskReferenceSequenceFile(referenceFileName);
        final HaplotypeCallerEngine hcEngine = new HaplotypeCallerEngine(hcArgsBroadcast.value(), assemblyRegionArgsBroadcast.value(), false, false, header, taskReferenceSequenceFile, annotatorEngineBroadcast.getValue());
        // Lazily map each region context to the iterator of variants called in that region.
        Iterator<Iterator<VariantContext>> iterators = Utils.stream(contexts).map(context -> {
            AssemblyRegion region = context.getAssemblyRegion();
            FeatureContext featureContext = context.getFeatureContext();
            return hcEngine.callRegion(region, featureContext, context.getReferenceContext()).iterator();
        }).iterator();

        return Iterators.concat(iterators);
    };
}
 
Example 23
Source Project: flight-spark-source   Source File: DefaultSource.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Creates a Flight data source reader from the data source options.
 *
 * <p>Recognized options and defaults: "host" ("localhost"), "port" (47470), "path" (SQL text,
 * ""), "username" ("anonymous"), "password" (null), "parallel" (false). The resulting
 * options are broadcast to the reader.
 */
public DataSourceReader createReader(DataSourceOptions dataSourceOptions) {
  String host = dataSourceOptions.get("host").orElse("localhost");
  int port = dataSourceOptions.getInt("port", 47470);
  Location defaultLocation = Location.forGrpcInsecure(host, port);

  String sql = dataSourceOptions.get("path").orElse("");
  String username = dataSourceOptions.get("username").orElse("anonymous");
  String password = dataSourceOptions.get("password").orElse(null);
  boolean parallel = dataSourceOptions.getBoolean("parallel", false);

  FlightDataSourceReader.FactoryOptions factoryOptions = new FlightDataSourceReader.FactoryOptions(
    defaultLocation, sql, username, password, parallel, null);
  Broadcast<FlightDataSourceReader.FactoryOptions> optionsBroadcast = lazySparkContext().broadcast(factoryOptions);
  return new FlightDataSourceReader(optionsBroadcast);
}
 
Example 24
Source Project: geowave   Source File: RDDUtils.java    License: Apache License 2.0 5 votes vote down vote up
/**
 * Translate a set of objects in a JavaRDD to a provided type and push to GeoWave
 *
 * <p>Each feature is written under the adapter's type name and the index's name through a
 * Hadoop job configured with GeoWaveOutputFormat.
 *
 * @param sc                 Spark context whose Hadoop configuration is copied for the job
 * @param index              index to write into
 * @param outputStoreOptions target data store options
 * @param adapter            data adapter describing the feature type
 * @param inputRDD           features to write
 * @throws IOException if the Hadoop job cannot be created
 */
private static void writeToGeoWave(
    final SparkContext sc,
    final Index index,
    final DataStorePluginOptions outputStoreOptions,
    final DataTypeAdapter adapter,
    final JavaRDD<SimpleFeature> inputRDD) throws IOException {

  // setup the configuration and the output format
  final Configuration conf = new org.apache.hadoop.conf.Configuration(sc.hadoopConfiguration());

  GeoWaveOutputFormat.setStoreOptions(conf, outputStoreOptions);
  GeoWaveOutputFormat.addIndex(conf, index);
  GeoWaveOutputFormat.addDataAdapter(conf, adapter);

  // create the job (Job.getInstance replaces the deprecated Job(Configuration) constructor)
  final Job job = Job.getInstance(conf);
  job.setOutputKeyClass(GeoWaveOutputKey.class);
  job.setOutputValueClass(SimpleFeature.class);
  job.setOutputFormatClass(GeoWaveOutputFormat.class);

  // broadcast string names
  final ClassTag<String> stringTag = scala.reflect.ClassTag$.MODULE$.apply(String.class);
  final Broadcast<String> typeName = sc.broadcast(adapter.getTypeName(), stringTag);
  final Broadcast<String> indexName = sc.broadcast(index.getName(), stringTag);

  // map to a pair containing the output key and the output value
  inputRDD.mapToPair(
      feat -> new Tuple2<>(
          new GeoWaveOutputKey(typeName.value(), indexName.value()),
          feat)).saveAsNewAPIHadoopDataset(job.getConfiguration());
}
 
Example 25
/**
 * Method to get the vehicles which are in radius of POI and their distance from POI.
 *
 * <p>Records are kept when they match the monitored route, their vehicle type contains the
 * monitored type, and they lie within the POI radius. Each kept record is paired with the
 * POI, transformed to POITrafficData, and saved to Cassandra with a 120-second TTL.
 * 
 * @param nonFilteredIotDataStream original IoT data stream
 * @param broadcastPOIValues variable containing POI coordinates, route and vehicle types to monitor.
 */
public void processPOIData(JavaDStream<IoTData> nonFilteredIotDataStream,Broadcast<Tuple3<POIData, String, String>> broadcastPOIValues) {
	// Filter by routeId, vehicleType and POI range.
	JavaDStream<IoTData> iotDataStreamFiltered = nonFilteredIotDataStream.filter(iot -> {
		Tuple3<POIData, String, String> poiValues = broadcastPOIValues.value();
		POIData poi = poiValues._1();
		return iot.getRouteId().equals(poiValues._2())
				&& iot.getVehicleType().contains(poiValues._3())
				&& GeoDistanceCalculator.isInPOIRadius(Double.valueOf(iot.getLatitude()),
						Double.valueOf(iot.getLongitude()), poi.getLatitude(),
						poi.getLongitude(), poi.getRadius());
	});

	// Pair each record with the POI it was matched against.
	JavaPairDStream<IoTData, POIData> poiDStreamPair = iotDataStreamFiltered
			.mapToPair(iot -> new Tuple2<>(iot, broadcastPOIValues.value()._1()));

	JavaDStream<POITrafficData> trafficDStream = poiDStreamPair.map(poiTrafficDataFunc);

	// Bean property -> Cassandra column.
	Map<String, String> columnNameMappings = new HashMap<>();
	columnNameMappings.put("vehicleId", "vehicleid");
	columnNameMappings.put("distance", "distance");
	columnNameMappings.put("vehicleType", "vehicletype");
	columnNameMappings.put("timeStamp", "timestamp");

	javaFunctions(trafficDStream)
			.writerBuilder("traffickeyspace", "poi_traffic",CassandraJavaUtil.mapToRow(POITrafficData.class, columnNameMappings))
			.withConstantTTL(120)// keep rows for 2 minutes
			.saveToCassandra();
}
 
Example 26
/**
 * Builds annotated VCF records for two linked variant types inferred from the same novel adjacency
 * and cross-links them: each record's {@code linkKey} attribute is set to the other record's ID.
 * If the first type is a {@link SimpleSVType.Deletion}, or otherwise if the second one is, that
 * deletion record has its inserted-sequence annotations removed.
 *
 * @return the two linked records, in the same order as {@code linkedVariants}
 */
public static List<VariantContext> produceLinkedAssemblyBasedVariants(final Tuple2<SvType, SvType> linkedVariants,
                                                                      final SimpleNovelAdjacencyAndChimericAlignmentEvidence simpleNovelAdjacencyAndChimericAlignmentEvidence,
                                                                      final Broadcast<ReferenceMultiSparkSource> broadcastReference,
                                                                      final Broadcast<SAMSequenceDictionary> broadcastSequenceDictionary,
                                                                      final Broadcast<SVIntervalTree<VariantContext>> broadcastCNVCalls,
                                                                      final String sampleId,
                                                                      final String linkKey) {

    final VariantContext leftRecord = produceAnnotatedVcFromAssemblyEvidence(linkedVariants._1, simpleNovelAdjacencyAndChimericAlignmentEvidence,
            broadcastReference, broadcastSequenceDictionary, broadcastCNVCalls, sampleId).make();
    final VariantContext rightRecord = produceAnnotatedVcFromAssemblyEvidence(linkedVariants._2, simpleNovelAdjacencyAndChimericAlignmentEvidence,
            broadcastReference, broadcastSequenceDictionary, broadcastCNVCalls, sampleId).make();

    // Point each record at its partner.
    final VariantContextBuilder leftBuilder = new VariantContextBuilder(leftRecord);
    leftBuilder.attribute(linkKey, rightRecord.getID());
    final VariantContextBuilder rightBuilder = new VariantContextBuilder(rightRecord);
    rightBuilder.attribute(linkKey, leftRecord.getID());

    // A DEL produced from an RPL event loses its inserted-sequence info when it is linked with an INS;
    // at most one side is stripped, and the first side takes precedence.
    final boolean stripLeft = linkedVariants._1 instanceof SimpleSVType.Deletion;
    final boolean stripRight = !stripLeft && linkedVariants._2 instanceof SimpleSVType.Deletion;

    final VariantContext linkedLeft = stripLeft ? makeWithoutInsertedSequenceAnnotations(leftBuilder) : leftBuilder.make();
    final VariantContext linkedRight = stripRight ? makeWithoutInsertedSequenceAnnotations(rightBuilder) : rightBuilder.make();
    return Arrays.asList(linkedLeft, linkedRight);
}

/** Builds a record from {@code builder} after removing its inserted-sequence annotations. */
private static VariantContext makeWithoutInsertedSequenceAnnotations(final VariantContextBuilder builder) {
    return builder.rmAttribute(GATKSVVCFConstants.INSERTED_SEQUENCE)
                  .rmAttribute(GATKSVVCFConstants.INSERTED_SEQUENCE_LENGTH)
                  .rmAttribute(GATKSVVCFConstants.SEQ_ALT_HAPLOTYPE)
                  .make();
}
 
Example 27
Source Project: systemds · Source File: SparkExecutionContext.java · License: Apache License 2.0 · 5 votes
/**
 * Destroys a broadcast variable on all executors and on the driver.
 * Meant to be used by rmvar only. Whether the call blocks is decided by
 * ASYNCHRONOUS_VAR_DESTROY: the destroy blocks when that flag is false.
 *
 * @param bvar broadcast variable to destroy; an already invalid one is ignored
 */
public static void cleanupBroadcastVariable(Broadcast<?> bvar)
{
	// Nothing to release once the variable has been invalidated.
	if( !bvar.isValid() ) {
		return;
	}
	// Unlike 'unpersist', which only drops the executor copies,
	// destroy also removes the related data held by the driver.
	final boolean blocking = !ASYNCHRONOUS_VAR_DESTROY;
	bvar.destroy( blocking );
}
 
Example 28
Source Project: beam · Source File: SparkBatchPortablePipelineTranslator.java · License: Apache License 2.0 · 5 votes
/**
 * Pops the bounded dataset registered under {@code collectionId}, collects and encodes its
 * elements with the collection's windowed-value coder, and broadcasts the encoded bytes.
 * <em>This can be expensive.</em>
 *
 * @param collectionId id of the side-input PCollection
 * @param components pipeline components used to resolve the coder
 * @param context translation context holding the dataset and the Spark context
 * @return the Spark broadcast of encoded elements, paired with the coder that decodes them
 */
private static <T> Tuple2<Broadcast<List<byte[]>>, WindowedValueCoder<T>> broadcastSideInput(
    String collectionId, RunnerApi.Components components, SparkTranslationContext context) {
  @SuppressWarnings("unchecked")
  final BoundedDataset<T> sideInputDataset = (BoundedDataset<T>) context.popDataset(collectionId);
  final WindowedValueCoder<T> windowedCoder = getWindowedValueCoder(collectionId, components);
  final List<byte[]> encodedElements = sideInputDataset.getBytes(windowedCoder);
  final Broadcast<List<byte[]>> encodedBroadcast = context.getSparkContext().broadcast(encodedElements);
  return new Tuple2<>(encodedBroadcast, windowedCoder);
}
 
Example 29
/**
 * Broadcasts the supplier of assembly-region evaluators. The reference is opened from the copy
 * that SparkFiles resolves by file name, and the annotation engine is configured from {@code hcArgs}.
 */
@Override
protected Broadcast<Supplier<AssemblyRegionEvaluator>> assemblyRegionEvaluatorSupplierBroadcast(final JavaSparkContext ctx) {
    // SparkFiles looks files up by bare name, so only the last path component of the reference is kept.
    final String referenceFileName = IOUtils.getPath(referenceArguments.getReferenceFileName()).getFileName().toString();
    final ReferenceSequenceFile taskReferenceSequenceFile =
            new CachingIndexedFastaSequenceFile(IOUtils.getPath(SparkFiles.get(referenceFileName)));

    final boolean emitReferenceConfidence = hcArgs.emitReferenceConfidence != ReferenceConfidenceMode.NONE;
    final VariantAnnotatorEngine annotatorEngine = new VariantAnnotatorEngine(
            makeVariantAnnotations(), hcArgs.dbsnp.dbsnp, hcArgs.comps, emitReferenceConfidence, false);

    return assemblyRegionEvaluatorSupplierBroadcastFunction(ctx, hcArgs, assemblyRegionArgs, getHeaderForReads(),
            taskReferenceSequenceFile, annotatorEngine);
}
 
Example 30
Source Project: gatk · Source File: PSScorerTest.java · License: BSD 3-Clause "New" or "Revised" License · 5 votes
/**
 * Checks that a single unpaired read, with its alternate alignments stored first in the XA tag
 * and then in the SA tag, is mapped by {@link PSScorer#mapGroupedReadsToTax} to the expected taxa.
 */
@Test(dataProvider = "mapUnpaired", groups = "spark")
public void testMapGroupedReadsToTaxUnpaired(final int readLength, final List<Integer> NM, final List<Integer> clip,
                                             final List<Integer> insert, final List<Integer> delete,
                                             final List<String> contig, final List<Integer> truthTax) {

    if (!(NM.size() == clip.size() && NM.size() == insert.size() && NM.size() == delete.size() && NM.size() == contig.size())) {
        throw new TestException("Input lists for read must be of uniform length");
    }

    final JavaSparkContext ctx = SparkContextFactory.getTestSparkContext();
    final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast = ctx.broadcast(taxonomyDatabase);
    try {
        //Test with alternate alignments assigned to the XA tag
        assertUnpairedReadMapsToTruthTaxa(ctx,
                generateUnpairedRead(readLength, NM, clip, insert, delete, contig, "XA"),
                taxonomyDatabaseBroadcast, truthTax);

        //Test SA tag
        assertUnpairedReadMapsToTruthTaxa(ctx,
                generateUnpairedRead(readLength, NM, clip, insert, delete, contig, "SA"),
                taxonomyDatabaseBroadcast, truthTax);
    } finally {
        // NOTE(review): the test context is presumably reused across tests; release this
        // invocation's broadcast so repeated data-provider runs do not accumulate copies.
        taxonomyDatabaseBroadcast.destroy();
    }
}

/**
 * Runs {@link PSScorer#mapGroupedReadsToTax} on a single unpaired read group. It then asserts
 * that the first result carries a hit with the same number of taxa as {@code truthTax},
 * containing all of them, and counting exactly one mate.
 */
private void assertUnpairedReadMapsToTruthTaxa(final JavaSparkContext ctx, final Iterable<GATKRead> readGroup,
                                               final Broadcast<PSTaxonomyDatabase> taxonomyDatabaseBroadcast,
                                               final List<Integer> truthTax) {
    final List<Iterable<GATKRead>> readList = new ArrayList<>();
    readList.add(readGroup);
    final JavaRDD<Iterable<GATKRead>> reads = ctx.parallelize(readList);
    final JavaRDD<Tuple2<Iterable<GATKRead>, PSPathogenAlignmentHit>> result = PSScorer.mapGroupedReadsToTax(reads,
            MIN_IDENT, IDENT_MARGIN, taxonomyDatabaseBroadcast);
    final PSPathogenAlignmentHit info = result.first()._2;

    Assert.assertNotNull(info);
    Assert.assertEquals(info.taxIDs.size(), truthTax.size());
    Assert.assertTrue(info.taxIDs.containsAll(truthTax));
    Assert.assertEquals(info.numMates, 1);
}