Java Code Examples for org.apache.spark.sql.Dataset#filter()

The following examples show how to use org.apache.spark.sql.Dataset#filter(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: DataFilterStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
/**
 * Filters the dataset with the filter expression supplied in the {@code "query"} parameter.
 *
 * <p>If no parameters are given, or the {@code "query"} entry is missing or blank, a warning is
 * logged and the dataset is returned unchanged. If the filter leaves no rows, the JVM is
 * terminated with exit code 1.
 *
 * @param dataset the incoming dataset for this processing step
 * @param parameters step parameters; the entry {@code "query"} holds the filter expression
 * @param config the runner configuration (not read by this step)
 * @return the filtered (and cached) dataset, or the input dataset if there is nothing to filter by
 */
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, Map<String, Object> parameters, SparkRunnerConfig config) {

    if (parameters == null || parameters.isEmpty()) {
        BpmnaiLogger.getInstance().writeWarn("No parameters found for the DataFilterStep");
        return dataset;
    }

    String query = (String) parameters.get("query");
    // A parameter map without a usable "query" entry would otherwise hand null or an empty
    // expression to Dataset.filter(String) and fail there with an unhelpful error.
    if (query == null || query.trim().isEmpty()) {
        BpmnaiLogger.getInstance().writeWarn("No filter query found for the DataFilterStep");
        return dataset;
    }

    BpmnaiLogger.getInstance().writeInfo("Filtering data with filter query: " + query + ".");
    dataset = dataset.filter(query);

    // cache before counting so the count below does not force a second evaluation later on
    dataset.cache();
    if (dataset.count() == 0) {
        BpmnaiLogger.getInstance().writeInfo("Filtering resulted in zero lines of data. Aborting. Please check your filter query.");
        System.exit(1);
    }

    return dataset;
}
 
Example 2
Source File: WaterInteractions.java    From mmtf-spark with Apache License 2.0 6 votes vote down vote up
/**
 * Remove rows where the water interaction does not include at least one organic ligand (LGO)
 * and one protein residue (PRO).
 *
 * <p>The columns {@code type1} .. {@code typeN} are inspected, where N is the value of
 * {@code maxInteractions} (2, 3 or 4). For values below 2 the data is returned unfiltered.
 *
 * TODO need to handle cases of maxInteractions > 4
 * @param data dataset with the columns {@code type1} .. {@code typeN}
 * @param maxInteractions number of interaction columns, given as a decimal string
 * @return the dataset restricted to rows that contain both an LGO and a PRO entry
 * @throws IllegalArgumentException if {@code maxInteractions} is greater than 4 or is not a
 *         valid integer ({@link NumberFormatException})
 */
private static Dataset<Row> filterBridgingWaterInteractions(Dataset<Row> data, String maxInteractions) {
    // Compare numerically: a lexicographic String comparison ranks "10" below "4", which
    // would let unsupported values slip through and return the data unfiltered.
    int interactionCount = Integer.parseInt(maxInteractions);
    if (interactionCount > 4) {
        throw new IllegalArgumentException("maxInteractions > 4 are not supported, yet");
    }

    switch (interactionCount) {
        case 2:
            data = data.filter(col("type1").equalTo("LGO").or(col("type2").equalTo("LGO")));
            data = data.filter(col("type1").equalTo("PRO").or(col("type2").equalTo("PRO")));
            break;
        case 3:
            data = data.filter(col("type1").equalTo("LGO").or(col("type2").equalTo("LGO"))
                    .or(col("type3").equalTo("LGO")));
            data = data.filter(col("type1").equalTo("PRO").or(col("type2").equalTo("PRO"))
                    .or(col("type3").equalTo("PRO")));
            break;
        case 4:
            data = data.filter(col("type1").equalTo("LGO").or(col("type2").equalTo("LGO"))
                    .or(col("type3").equalTo("LGO")).or(col("type4").equalTo("LGO")));
            data = data.filter(col("type1").equalTo("PRO").or(col("type2").equalTo("PRO"))
                    .or(col("type3").equalTo("PRO")).or(col("type4").equalTo("PRO")));
            break;
        default:
            // fewer than two interactions: nothing to filter on
            break;
    }
    return data;
}
 
Example 3
Source File: PdbDrugBankMapping.java    From mmtf-spark with Apache License 2.0 6 votes vote down vote up
/**
 * Joins PDB ligand annotations with the open DrugBank dataset and prints one example
 * structure per drug whose common name ends in "tinib".
 *
 * @param args not used
 * @throws IOException declared by the data access calls
 */
public static void main(String[] args) throws IOException {
    SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName(PdbDrugBankMapping.class.getSimpleName())
            .getOrCreate();

    // download the open DrugBank dataset and keep some tyrosine kinase inhibitors,
    // recognised by the generic name stem "tinib"
    Dataset<Row> kinaseInhibitors = DrugBankDataset.getOpenDrugLinks()
            .filter("Commonname LIKE '%tinib'");

    // get PDB ligand annotations
    Dataset<Row> pdbLigands = CustomReportService.getDataset(
            "ligandId", "ligandMolecularWeight", "ligandFormula", "ligandSmiles", "InChIKey");

    // join ligand dataset with DrugBank info by InChIKey
    Dataset<Row> annotatedLigands = pdbLigands.join(
            kinaseInhibitors,
            pdbLigands.col("InChIKey").equalTo(kinaseInhibitors.col("StandardInChIKey")));

    // show one example per drug molecule
    annotatedLigands
            .dropDuplicates("Commonname")
            .select("structureChainId", "ligandId", "DrugBankID", "Commonname", "ligandMolecularWeight","ligandFormula", "InChIKey", "ligandSmiles")
            .sort("Commonname")
            .show(50);

    spark.close();
}
 
Example 4
Source File: DrugBankDemo.java    From mmtf-spark with Apache License 2.0 6 votes vote down vote up
/**
 * Downloads the open DrugBank dataset and prints sample rows for all drugs that have an
 * InChIKey.
 *
 * @param args not used
 * @throws IOException declared by the data access calls
 */
public static void main(String[] args) throws IOException {
    SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName(DrugBankDemo.class.getSimpleName())
            .getOrCreate();

    // download open DrugBank dataset and keep all drugs with an InChIKey
    Dataset<Row> drugsWithInChIKey = DrugBankDataset.getOpenDrugLinks()
            .filter("StandardInChIKey IS NOT NULL");

    // show some sample data
    drugsWithInChIKey
            .select("DrugBankID", "Commonname", "CAS", "StandardInChIKey")
            .show();

    // The password protected DrugBank datasets contain more information.
    // A DrugBank account (username/password) is required to access them.
    //
    // Download DrugBank dataset for approved drugs:
    // String username = args[0];
    // String password = args[1];
    // Dataset<Row> drugLinks =
    // DrugBankDataset.getDrugLinks(DrugGroup.APPROVED, username, password);
    // drugLinks.show();

    spark.close();
}
 
Example 5
Source File: HoodieClientTestUtils.java    From hudi with Apache License 2.0 6 votes vote down vote up
/**
 * Obtain all new data written into the Hoodie table since the given timestamp.
 *
 * @param basePath base path of the table, passed on to the file lookup
 * @param sqlContext context used to read the data files
 * @param commitTimeline timeline from which the instants after {@code lastCommitTime} are taken
 * @param lastCommitTime only rows with a commit time greater than this value are returned
 * @return the matching rows, or an empty DataFrame when the lookup yields no files to read
 * @throws HoodieException if the files are not in Parquet format or the commit metadata
 *         cannot be read
 */
public static Dataset<Row> readSince(String basePath, SQLContext sqlContext,
                                     HoodieTimeline commitTimeline, String lastCommitTime) {
  List<HoodieInstant> commitsToReturn =
      commitTimeline.findInstantsAfter(lastCommitTime, Integer.MAX_VALUE).getInstants().collect(Collectors.toList());
  try {
    // Go over the commit metadata, and obtain the new files that need to be read.
    HashMap<String, String> fileIdToFullPath = getLatestFileIDsToFullPath(basePath, commitTimeline, commitsToReturn);
    if (fileIdToFullPath.isEmpty()) {
      // No files to read: answer with an empty result instead of failing on paths[0].
      return sqlContext.emptyDataFrame();
    }

    String[] paths = fileIdToFullPath.values().toArray(new String[0]);
    // The format is decided by the first file, as before; only Parquet can be read here.
    if (!paths[0].endsWith(HoodieFileFormat.PARQUET.getFileExtension())) {
      throw new HoodieException("Unsupported file format for incremental read: " + paths[0]);
    }

    return sqlContext.read().parquet(paths)
        .filter(String.format("%s >'%s'", HoodieRecord.COMMIT_TIME_METADATA_FIELD, lastCommitTime));
  } catch (IOException e) {
    throw new HoodieException("Error pulling data incrementally from commitTimestamp :" + lastCommitTime, e);
  }
}
 
Example 6
Source File: ParDoTranslatorBatch.java    From beam with Apache License 2.0 6 votes vote down vote up
/**
 * Selects the elements that belong to one tagged output from the dataset holding all
 * outputs, strips the tag and registers the resulting dataset in the translation context
 * under the output's value.
 *
 * @param context translation context that receives the dataset for this output
 * @param allOutputs dataset of (tag, value) tuples for every output
 * @param output the tag and the value of the output to extract
 * @param windowCoder coder for the windows of the values
 */
private void pruneOutputFilteredByTag(
    TranslationContext context,
    Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> allOutputs,
    Map.Entry<TupleTag<?>, PValue> output,
    Coder<? extends BoundedWindow> windowCoder) {
  // keep only the tuples selected by the filter function for this output's tag
  Dataset<Tuple2<TupleTag<?>, WindowedValue<?>>> taggedTuples =
      allOutputs.filter(new DoFnFilterFunction(output.getKey()));

  // coder for the windowed values, built from the output collection's element coder
  Coder<WindowedValue<?>> valueCoder =
      (Coder<WindowedValue<?>>)
          (Coder<?>)
              WindowedValue.getFullCoder(
                  ((PCollection<OutputT>) output.getValue()).getCoder(), windowCoder);

  // drop the tag, keeping the second tuple component
  MapFunction<Tuple2<TupleTag<?>, WindowedValue<?>>, WindowedValue<?>> dropTag =
      tuple -> tuple._2;
  Dataset<WindowedValue<?>> taggedValues =
      taggedTuples.map(dropTag, EncoderHelpers.fromBeamCoder(valueCoder));

  context.putDatasetWildcard(output.getValue(), taggedValues);
}
 
Example 7
Source File: CustomReportDemo.java    From mmtf-spark with Apache License 2.0 5 votes vote down vote up
/**
 * Retrieves PDB annotations through the custom report service and selects the same
 * structures in two ways: (A) with a dataset filter and (B) with an SQL query on a
 * temporary view.
 *
 * @param args no input arguments
 * @throws IOException if custom report web service fails
 */
public static void main(String[] args) throws IOException {
    long start = System.nanoTime();

    SparkConf conf = new SparkConf().setMaster("local[1]").setAppName(CustomReportDemo.class.getSimpleName());
    JavaSparkContext sc = new JavaSparkContext(conf);

    try {
        // retrieve PDB annotation: Binding affinities (Ki, Kd),
        // group name of the ligand (hetId), and the
        // Enzyme Classification number (ecNo)
        Dataset<Row> ds = CustomReportService.getDataset("Ki","Kd","hetId","ecNo");

        // show the schema of this dataset
        ds.printSchema();

        // select structures that either have a Ki or Kd value(s) and
        // whose EC number starts with 2.7.11. (protein-serine/threonine kinases):

        // A. by using dataset operations
        Dataset<Row> filtered = ds.filter("(Ki IS NOT NULL OR Kd IS NOT NULL) AND ecNo LIKE '2.7.11.%'");
        filtered.show(10);

        // B. by creating a temporary view of the unfiltered data and running SQL.
        // The query result must be captured and shown; showing the dataset from A
        // again would not demonstrate the SQL path at all.
        ds.createOrReplaceTempView("table");
        Dataset<Row> queried = ds.sparkSession().sql("SELECT * from table WHERE (Ki IS NOT NULL OR Kd IS NOT NULL) AND ecNo LIKE '2.7.11.%'");
        queried.show(10);

        long end = System.nanoTime();

        System.out.println("Time:     " + (end-start)/1E9 + "sec.");
    } finally {
        // release the Spark context even if the web service or a query fails
        sc.close();
    }
}
 
Example 8
Source File: PdbMetadataDemo.java    From mmtf-spark with Apache License 2.0 5 votes vote down vote up
/**
 * Queries citation metadata through PDBj's Mine2 web service and prints the journals with
 * the most entries as well as the number of entries with a PubMed Id per year.
 *
 * @param args not used
 * @throws IOException declared by the data access calls
 */
public static void main(String[] args) throws IOException {
    SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName(PdbMetadataDemo.class.getSimpleName())
            .getOrCreate();

    // Query journal_abbrev, pdbx_database_id_PubMed and year from the _citation category.
    // Note, mixed case column names must be quoted and escaped with \".
    String sqlQuery = "SELECT pdbid, journal_abbrev, \"pdbx_database_id_PubMed\", year from citation WHERE id = 'primary'";
    Dataset<Row> citations = PdbjMineDataset.getDataset(sqlQuery);

    System.out.println("First 10 results from query: " + sqlQuery);
    citations.show(10, false);

    // unpublished entries contain the word "published" in various upper/lower case
    // combinations; remove them
    Dataset<Row> publishedCitations = citations.filter("UPPER(journal_abbrev) NOT LIKE '%PUBLISHED%'");

    // print the top 10 journals
    System.out.println("Top 10 journals that publish PDB structures:");
    publishedCitations
            .groupBy("journal_abbrev")
            .count()
            .sort(col("count").desc())
            .show(10, false);

    // entries without a PubMed Id carry -1; keep only those that have one
    Dataset<Row> pubMedCitations = publishedCitations.filter("pdbx_database_id_PubMed > 0");
    System.out.println("Entries with PubMed Ids: " + pubMedCitations.count());

    // show growth of papers in PubMed
    System.out.println("PubMed Ids per year: ");
    pubMedCitations
            .groupBy("year")
            .count()
            .sort(col("year").desc())
            .show(10, false);

    spark.close();
}
 
Example 9
Source File: SparkMLScoringOnline.java    From -Data-Stream-Development-with-Apache-Spark-Kafka-and-Spring-Boot with MIT License 4 votes vote down vote up
/**
 * Reads RSVP messages from Kafka as a stream, parses them with an explicit schema, scores
 * them with a previously stored pipeline model and writes the predictions as JSON files.
 * Blocks until the streaming query terminates.
 *
 * @param args not used
 * @throws InterruptedException declared by the streaming API
 * @throws StreamingQueryException if the streaming query terminates with an error
 */
public static void main(String[] args) throws InterruptedException, StreamingQueryException {

    System.setProperty("hadoop.home.dir", HADOOP_HOME_DIR_VALUE);

    // * the schema can be written on disk, and read from disk
    // * the schema is not mandatory to be complete, it can contain only the needed fields
    StructType eventType = new StructType()
            .add("event_id", StringType, true)
            .add("event_name", StringType, true)
            .add("event_url", StringType, true)
            .add("time", LongType, true);

    StructType topicType = new StructType()
            .add("topicName", StringType, true)
            .add("urlkey", StringType, true);

    StructType groupType = new StructType()
            .add("group_city", StringType, true)
            .add("group_country", StringType, true)
            .add("group_id", LongType, true)
            .add("group_lat", DoubleType, true)
            .add("group_lon", DoubleType, true)
            .add("group_name", StringType, true)
            .add("group_state", StringType, true)
            .add("group_topics", DataTypes.createArrayType(topicType), true)
            .add("group_urlname", StringType, true);

    StructType memberType = new StructType()
            .add("member_id", LongType, true)
            .add("member_name", StringType, true)
            .add("photo", StringType, true);

    StructType venueType = new StructType()
            .add("lat", DoubleType, true)
            .add("lon", DoubleType, true)
            .add("venue_id", LongType, true)
            .add("venue_name", StringType, true);

    StructType rsvpSchema = new StructType()
            .add("event", eventType)
            .add("group", groupType)
            .add("guests", LongType, true)
            .add("member", memberType)
            .add("mtime", LongType, true)
            .add("response", StringType, true)
            .add("rsvp_id", LongType, true)
            .add("venue", venueType)
            .add("visibility", StringType, true);

    final SparkConf conf = new SparkConf()
            .setMaster(RUN_LOCAL_WITH_AVAILABLE_CORES)
            .setAppName(APPLICATION_NAME)
            .set("spark.sql.caseSensitive", CASE_SENSITIVE);

    SparkSession spark = SparkSession.builder().config(conf).getOrCreate();

    PipelineModel pipelineModel = PipelineModel.load(MODEL_FOLDER_PATH);

    // raw Kafka records
    Dataset<Row> meetupStream = spark.readStream()
            .format(KAFKA_FORMAT)
            .option("kafka.bootstrap.servers", KAFKA_BROKERS)
            .option("subscribe", KAFKA_TOPIC)
            .load();

    // parse the record value as JSON into a single struct column named "rsvp"
    Dataset<Row> gatheredDF = meetupStream
            .select(from_json(col("value").cast("string"), rsvpSchema).alias("rsvp"))
            .alias("meetup")
            .select("meetup.*");

    // discard rows for which Row.anyNull() reports a null value
    Dataset<Row> filteredDF = gatheredDF.filter(row -> !row.anyNull());

    // only these fields are handed to the model
    Dataset<Row> preparedDF = filteredDF.select(
            col("rsvp.group.group_city"),
            col("rsvp.group.group_lat"),
            col("rsvp.group.group_lon"),
            col("rsvp.response"));

    preparedDF.printSchema();

    Dataset<Row> predictionDF = pipelineModel.transform(preparedDF);

    StreamingQuery query = predictionDF.writeStream()
            .format(JSON_FORMAT)
            .option("path", RESULT_FOLDER_PATH)
            .option("checkpointLocation", CHECKPOINT_LOCATION)
            .trigger(Trigger.ProcessingTime(QUERY_INTERVAL_SECONDS))
            .option("truncate", false)
            .start();

    query.awaitTermination();
}
 
Example 10
Source File: DetermineProcessVariablesStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
/**
 * Removes all rows whose variable name is configured as "not used" or starts with
 * {@code _CORRELATION_ID_}.
 *
 * <p>Configured variables that do not occur in the data are reported with a warning.
 *
 * @param dataset the dataset to filter
 * @param writeStepResultIntoFile whether the filtered dataset is additionally written to CSV
 * @param config the runner configuration used to look up the variable configuration
 * @return the filtered dataset
 */
private Dataset<Row> doFilterVariables(Dataset<Row> dataset, boolean writeStepResultIntoFile, SparkRunnerConfig config) {
    List<String> variablesToFilter = new ArrayList<>();

    // collect the names of all variables that are configured as not to be used
    Configuration configuration = ConfigurationUtils.getInstance().getConfiguration(config);
    PreprocessingConfiguration preprocessingConfiguration =
            configuration == null ? null : configuration.getPreprocessingConfiguration();
    if (preprocessingConfiguration != null) {
        for (VariableConfiguration vc : preprocessingConfiguration.getVariableConfiguration()) {
            if (vc.isUseVariable()) {
                continue;
            }
            variablesToFilter.add(vc.getVariableName());
            BpmnaiLogger.getInstance().writeInfo("The variable '" + vc.getVariableName() + "' will be filtered out. Comment: " + vc.getComment());
        }
    }

    // check if all variables that should be filtered actually exist, otherwise log a warning
    List<String> existingVariables = dataset
            .select(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME)
            .distinct()
            .collectAsList()
            .stream()
            .map(r -> r.getString(0))
            .collect(Collectors.toList());

    for (String configuredName : variablesToFilter) {
        if (!existingVariables.contains(configuredName)) {
            // log the fact that a variable that should be filtered does not exist
            BpmnaiLogger.getInstance().writeWarn("The variable '" + configuredName + "' is configured to be filtered, but does not exist in the data.");
        }
    }

    dataset = dataset.filter((FilterFunction<Row>) row -> {
        String variable = row.getAs(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME);

        //TODO: cleanup
        boolean isCorrelationId = variable != null && variable.startsWith("_CORRELATION_ID_");

        // keep the row unless it is a correlation id or a variable that should be filtered
        return !isCorrelationId && !variablesToFilter.contains(variable);
    });

    if (writeStepResultIntoFile) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "variable_filter", config);
    }

    return dataset;
}
 
Example 11
Source File: DataFilterOnActivityStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License 4 votes vote down vote up
/**
 * Restricts the dataset to the process instances that contain the activity given in the
 * {@code "query"} parameter, keeping the rows selected by the
 * {@code activityBeforeTimestamp} UDF together with the variable updates that belong to the
 * kept activity instances.
 *
 * <p>If no parameters are given or the {@code "query"} entry is missing, a warning is logged
 * and the dataset is returned unchanged.
 *
 * @param dataSet the incoming dataset for this processing step
 * @param parameters step parameters; the entry {@code "query"} holds the activity id pattern
 * @param config the runner configuration; decides whether the result is also written to CSV
 * @return the filtered DataSet
 */
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataSet, Map<String, Object> parameters, SparkRunnerConfig config) {
    // any parameters set?
    if (parameters == null || parameters.isEmpty()) {
        BpmnaiLogger.getInstance().writeWarn("No parameters found for the DataFilterOnActivityStep");
        return dataSet;
    }

    // get query parameter; without it there is nothing to filter by
    String query = (String) parameters.get("query");
    if (query == null) {
        BpmnaiLogger.getInstance().writeWarn("No query parameter found for the DataFilterOnActivityStep");
        return dataSet;
    }
    BpmnaiLogger.getInstance().writeInfo("Filtering data with activity instance filter query: " + query + ".");

    // save size of initial dataset for log
    dataSet.cache();
    long initialDSCount = dataSet.count();

    // repartition by process instance and order by start_time for this operation
    dataSet = dataSet.repartition(dataSet.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID)).sortWithinPartitions(BpmnaiVariables.VAR_START_TIME);

    // we temporarily store variable updates (rows with a var type set) separately.
    Dataset<Row> variables = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).isNotNull());

    // now we look for the first occurrence of the activity id contained in "query". The result comprises of a dataset of corresponding activity instances.
    final Dataset<Row> dsActivityInstances = dataSet.filter(dataSet.col(BpmnaiVariables.VAR_ACT_ID).like(query)).filter(dataSet.col(BpmnaiVariables.VAR_END_TIME).isNull()); //TODO: ENSURING THAT THIS ISN'T A VARIABLE ROW

    // we slim the resulting dataset down: only the activity instances process id and the instances start time are relevant.
    List<Row> activityRows = dsActivityInstances.select(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID, BpmnaiVariables.VAR_START_TIME).collectAsList();
    // A process instance may contain the activity more than once. Without a merge function
    // toMap() fails on the duplicate key; keep the first row instead. Rows of one process
    // instance share a partition that is sorted by start time (see the repartition above),
    // so the first row should be the earliest occurrence.
    Map<String, String> activities = activityRows.stream().collect(Collectors.toMap(
            r -> r.getAs(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID),
            r -> r.getAs(BpmnaiVariables.VAR_START_TIME),
            (first, later) -> first));
    // broadcasting the PID - Start time Map to use it in a user defined function
    SparkBroadcastHelper.getInstance().broadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_INSTANCE_TIMESTAMP_MAP, activities);

    // now we have to select for each process instance in our initial dataset all events that happened before the first occurrence of our selected activity.
    // We first narrow it down to the process instances in question
    Dataset<Row> selectedProcesses = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID).isin(activities.keySet().toArray()));
    // Then, we mark all events that should be removed
    Dataset<Row> activityDataSet = selectedProcesses.withColumn("data_filter_on_activity",
            callUDF("activityBeforeTimestamp",
                    selectedProcesses.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID),
                    selectedProcesses.col(BpmnaiVariables.VAR_START_TIME)));
    // And we keep the rest
    activityDataSet = activityDataSet.filter(col("data_filter_on_activity").like("TRUE"));
    // Clean up
    activityDataSet = activityDataSet.drop("data_filter_on_activity");

    // However, we lost all variable updates in this approach, so now we add the variables in question to the dataset
    // first, we narrow it down to keep only variables that have a corresponding activity instance
    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID, BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");

    variables = variables.join(activityDataSet.select(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT").distinct(), variables.col(BpmnaiVariables.VAR_ACT_INST_ID).equalTo(activityDataSet.col(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT")),"inner");

    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT", BpmnaiVariables.VAR_ACT_INST_ID);
    variables = variables.drop(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");
    dataSet = activityDataSet.union(variables);

    dataSet.cache();
    BpmnaiLogger.getInstance().writeInfo("DataFilterOnActivityStep: The filtered DataSet contains "+dataSet.count()+" rows, (before: "+ initialDSCount+" rows)");

    if (config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataSet, "data_filter_on_activity_step", config);
    }

    return dataSet;
}
 
Example 12
Source File: AtpInteractionAnalysis.java    From mmtf-spark with Apache License 2.0 4 votes vote down vote up
/**
 * Finds the interactions of ATP within 3 Angstroms in a subset of the PDB and prints the
 * most frequently interacting groups and group/atom combinations.
 *
 * @param args input arguments (not used)
 * @throws IOException declared by the data access calls
 */
public static void main(String[] args) throws IOException {

    String path = MmtfReader.getMmtfFullPath();

    long start = System.nanoTime();

    SparkConf conf = new SparkConf()
            .setMaster("local[*]")
            .setAppName(AtpInteractionAnalysis.class.getSimpleName());
    JavaSparkContext sc = new JavaSparkContext(conf);

    // read PDB in MMTF format and filter by sequence identity subset
    int sequenceIdentity = 20;
    double resolution = 2.0;
    JavaPairRDD<String, StructureDataInterface> pdb = MmtfReader.readSequenceFile(path, sc)
            .filter(new Pisces(sequenceIdentity, resolution));

    // find ATP interactions within 3 Angstroms
    GroupInteractionExtractor finder = new GroupInteractionExtractor("ATP", 3);
    Dataset<Row> allInteractions = finder.getDataset(pdb).cache();

    // only analyze interactions with the oxygens in the
    // terminal phosphate group of ATP (O1G, O2G, O3G)
    Dataset<Row> interactions = allInteractions.filter("atom1 LIKE('O%G')");

    // show the data schema of the dataset and some data
    interactions.printSchema();
    interactions.show(20);

    long n = interactions.count();
    System.out.println("# interactions: " + n);

    System.out.println("Top interacting groups");

    // number of interactions per group, sorted descending by count
    interactions
            .groupBy("residue2")
            .count()
            .sort(col("count").desc())
            .show(10);

    System.out.println("Top interacting group/atoms types");

    // number of interactions per group/atom, with the frequency of occurrence
    // as an additional column, sorted descending
    interactions
            .groupBy("residue2","atom2")
            .count()
            .withColumn("frequency", col("count").divide(n))
            .sort(col("frequency").desc())
            .show(10);

    long end = System.nanoTime();

    System.out.println("Time:     " + (end-start)/1E9 + "sec.");

    sc.close();
}
 
Example 13
Source File: Basic.java    From learning-spark-with-java with MIT License 4 votes vote down vote up
/**
 * Demonstrates basic typed Dataset usage: creating a Dataset of integers, filtering it with
 * a lambda, and querying a Dataset of tuples through its generated column names.
 *
 * @param args not used
 */
public static void main(String[] args) {
    SparkSession spark = SparkSession
        .builder()
        .appName("Dataset-Basic")
        .master("local[4]")
        .getOrCreate();

    List<Integer> data = Arrays.asList(10, 11, 12, 13, 14, 15);
    Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());

    System.out.println("*** only one column, and it always has the same name");
    ds.printSchema();

    ds.show();

    System.out.println("*** values > 12");

    // the harder way to filter
    Dataset<Integer> ds2 = ds.filter((Integer value) -> value > 12);

    // show the filtered dataset; showing ds here would print the unfiltered values again
    ds2.show();

    List<Tuple3<Integer, String, String>> tuples =
        Arrays.asList(
            new Tuple3<>(1, "one", "un"),
            new Tuple3<>(2, "two", "deux"),
            new Tuple3<>(3, "three", "trois"));

    Encoder<Tuple3<Integer, String, String>> encoder =
        Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.STRING());

    Dataset<Tuple3<Integer, String, String>> tupleDS =
        spark.createDataset(tuples, encoder);

    System.out.println("*** Tuple Dataset types");
    tupleDS.printSchema();

    // the tuple columns have unfriendly names, but you can use them to query
    System.out.println("*** filter by one column and fetch another");
    tupleDS.where(col("_1").gt(2)).select(col("_2"), col("_3")).show();

    spark.stop();
}
 
Example 14
Source File: InListDeriver.java    From envelope with Apache License 2.0 4 votes vote down vote up
/**
 * Filters the targeted step's rows to those whose target field value is contained in the
 * IN list.
 *
 * <p>The IN list is either the inline list held by this deriver or, if none is set, the
 * distinct values of a field of the reference step. Reference values are fetched through a
 * local iterator and applied in batches of {@code batchSize} values; the per-batch results
 * are combined with a union.
 *
 * @param dependencies the datasets of the steps this deriver depends on, by step name
 * @return the rows of the target step that match the IN list
 * @throws RuntimeException if the target has no columns or the filtering fails
 */
@Override
public Dataset<Row> derive(Map<String, Dataset<Row>> dependencies) {

  Dataset<Row> target = getStepDataFrame(dependencies);
  if (target.columns().length < 1) {
    throw new RuntimeException("Targeted step, '" + stepName + ",' has no columns");
  }

  try {
    String targetField = fieldName == null ? target.columns()[0] : fieldName;
    Column targetColumn = target.col(targetField);

    LOGGER.debug("Targeting '{}[{}]'", stepName, targetField);

    // If the IN list is inline, there is no batch
    if (inList != null) {
      LOGGER.debug("IN list is inline");
      return target.filter(targetColumn.isin(inList.toArray()));
    }

    // Otherwise, collect the values from the reference, executed within the batch
    LOGGER.trace("IN list is a reference");
    Dataset<Row> reference = dependencies.get(refStepName);
    String referenceField = refFieldName == null ? reference.columns()[0] : refFieldName;

    LOGGER.debug("Referencing using {}[{}]", refStepName, referenceField);
    Column referenceColumn = reference.col(referenceField);

    Iterator<Row> referenceIterator = reference.select(referenceColumn).distinct().toLocalIterator();

    // Keep the batch local. Storing it in the inList field would make every later call
    // take the inline branch above with the leftovers of this call's last batch.
    ArrayList<Object> batch = new ArrayList<>();

    // Union of the batches flushed so far; null until the first flush, which tells us
    // whether anything was flushed without running a Spark job.
    Dataset<Row> flushed = null;

    while (referenceIterator.hasNext()) {
      // Flush the batch
      if (batch.size() == batchSize) {
        LOGGER.trace("Flushing batch");
        Dataset<Row> batchResult = target.filter(targetColumn.isin(batch.toArray()));
        flushed = (flushed == null) ? batchResult : flushed.union(batchResult);
        batch.clear();
      }

      // Gather the elements of the IN list from the reference
      batch.add(referenceIterator.next().get(0));
    }

    // Apply the remaining IN list values
    Dataset<Row> remainder = target.filter(targetColumn.isin(batch.toArray()));

    // If the selection stayed under the batch threshold there is nothing to union with
    return (flushed == null) ? remainder : flushed.union(remainder);
  } catch (Throwable ae) {
    throw new RuntimeException("Error executing IN list filtering", ae);
  }

}