Java Code Examples for org.apache.spark.sql.Dataset#withColumn()

The following examples show how to use org.apache.spark.sql.Dataset#withColumn(). The original project and source file are named in the header above each example.
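As a minimal sketch of the call itself (not taken from any of the projects below; df, price, and quantity are placeholder names), withColumn() returns a new Dataset with one column added or replaced:

import static org.apache.spark.sql.functions.col;

Dataset<Row> withTotal = df.withColumn("total", col("price").multiply(col("quantity")));
withTotal.show();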
Example 1
Source File: TypeCastStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
private Dataset<Row> castColumn(Dataset<Row> dataset, String columnToCast, String castColumnName, DataType newDataType, String parseFormat) {

        Dataset<Row> newDataset = dataset;

        if(newDataType.equals(DataTypes.DateType)) {
            if(parseFormat != null && !parseFormat.equals("")) {
                // parse format given in config, so use it
                newDataset = dataset.withColumn(castColumnName, when(callUDF("isalong", dataset.col(columnToCast)), to_date(from_unixtime(callUDF("timestampstringtolong", dataset.col(columnToCast))), parseFormat)).otherwise(to_date(dataset.col(columnToCast), parseFormat)));
            } else {
                newDataset = dataset.withColumn(castColumnName, when(callUDF("isalong", dataset.col(columnToCast)), to_date(from_unixtime(callUDF("timestampstringtolong", dataset.col(columnToCast))))).otherwise(to_date(dataset.col(columnToCast))));
            }
        } else if(newDataType.equals(DataTypes.TimestampType)) {
            if(parseFormat != null && !parseFormat.equals("")) {
                // parse format given in config, so use it
                newDataset = dataset.withColumn(castColumnName, when(callUDF("isalong", dataset.col(columnToCast)), to_timestamp(from_unixtime(callUDF("timestampstringtolong", dataset.col(columnToCast))), parseFormat)).otherwise(to_timestamp(dataset.col(columnToCast), parseFormat)));
            } else {
                newDataset = dataset.withColumn(castColumnName, when(callUDF("isalong", dataset.col(columnToCast)), to_timestamp(from_unixtime(callUDF("timestampstringtolong", dataset.col(columnToCast))))).otherwise(to_timestamp(dataset.col(columnToCast))));
            }
        } else {
            newDataset = dataset.withColumn(castColumnName, dataset.col(columnToCast).cast(newDataType));
        }

        return newDataset;
    }
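castColumn() assumes that two UDFs, "isalong" and "timestampstringtolong", were registered elsewhere in bpmn.ai. A hedged sketch of such registrations, assuming a SparkSession named spark and millisecond timestamp strings (the real bpmn.ai UDFs may differ):

import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

// Hypothetical registrations; the actual bpmn.ai implementations may differ.
spark.udf().register("isalong", (UDF1<String, Boolean>) value -> {
    try {
        Long.parseLong(value);
        return true;
    } catch (Exception e) {
        return false;
    }
}, DataTypes.BooleanType);

// from_unixtime() expects seconds, so a millisecond timestamp string is divided by 1000 here.
spark.udf().register("timestampstringtolong", (UDF1<String, Long>) value ->
        Long.parseLong(value) / 1000L, DataTypes.LongType);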
 
Example 2
Source File: VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java    From iceberg with Apache License 2.0
private static Dataset<Row> withFloatColumnDictEncoded(Dataset<Row> df) {
  return df.withColumn(
      "floatCol",
      when(modColumn(9, 0), lit(0.0f))
          .when(modColumn(9, 1), lit(1.0f))
          .when(modColumn(9, 2), lit(2.0f))
          .when(modColumn(9, 3), lit(3.0f))
          .when(modColumn(9, 4), lit(4.0f))
          .when(modColumn(9, 5), lit(5.0f))
          .when(modColumn(9, 6), lit(6.0f))
          .when(modColumn(9, 7), lit(7.0f))
          .when(modColumn(9, 8), lit(8.0f)));
}
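All of the benchmark methods in this file rely on a modColumn(divisor, remainder) helper that is not part of these snippets. Judging from the inline pmod expression in Example 10 below, a sketch of it (assuming the table has a longCol column) would look roughly like this:

import org.apache.spark.sql.Column;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.lit;
import static org.apache.spark.sql.functions.pmod;

// Hypothetical helper: true for rows whose longCol falls into the given modulo bucket.
private static Column modColumn(int divisor, int remainder) {
  return pmod(col("longCol"), lit(divisor)).equalTo(lit(remainder));
}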
 
Example 3
Source File: AdvancedSearchDataset.java    From mmtf-spark with Apache License 2.0
/**
 * Runs an RCSB PDB Advanced Search web service using an XML query description.
 * The returned dataset contains the following fields dependent on the query type:
 * <pre> 
 *   structureId, e.g., 1STP
 *   structureChainId, e.g., 4HHB.A
 *   ligandId, e.g., HEM
 * </pre>
 *   
 * @param xmlQuery RCSB PDB Advanced Query XML
 * @return dataset of ids
 * @throws IOException
 */
public static Dataset<Row> getDataset(String xmlQuery) throws IOException {
    // run advanced query
    List<String> results = AdvancedQueryService.postQuery(xmlQuery);

    // convert the list of query results to a dataframe
    SparkSession spark = SparkSession.builder().getOrCreate();

    // handle 3 types of results based on the length of the id string:
    //   structureId: 4 characters (e.g., 4HHB)
    //   structureEntityId: > 4 characters (e.g., 4HHB:1)
    //   ligandId: < 4 characters (e.g., HEM)
    Dataset<Row> ds = null;
    if (results.size() > 0) {
        if (results.get(0).length() > 4) {
            ds = spark.createDataset(results, Encoders.STRING()).toDF("structureEntityId");
        
            // if results contain an entity id, e.g., 101M:1, then map entityId to structureChainId
            ds = ds.withColumn("structureId", substring_index(col("structureEntityId"), ":", 1));
            ds = ds.withColumn("entityId", substring_index(col("structureEntityId"), ":", -1));
          
            Dataset<Row> mapping = getEntityToChainId();
            ds = ds.join(mapping, ds.col("structureId").equalTo(mapping.col("structureId")).and(ds.col("entityId").equalTo(mapping.col("entity_id"))));
        
            ds = ds.select("structureChainId");
        } else if (results.get(0).length() < 4) {
            ds = spark.createDataset(results, Encoders.STRING()).toDF("ligandId");
        } else {
            ds = spark.createDataset(results, Encoders.STRING()).toDF("structureId");
        }
    }

    return ds;
}
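A hypothetical caller of getDataset() might look as follows; the XML query is illustrative only, and the available query types are defined by the RCSB PDB Advanced Search service:

// Illustrative only: a chemical component id query returning ligand ids.
String xmlQuery =
      "<orgPdbQuery>"
    + "<queryType>org.pdb.query.simple.ChemCompIdQuery</queryType>"
    + "<chemCompId>HEM</chemCompId>"
    + "</orgPdbQuery>";

Dataset<Row> ligandIds = AdvancedSearchDataset.getDataset(xmlQuery);
ligandIds.show();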
 
Example 4
Source File: PdbjMineDataset.java    From mmtf-spark with Apache License 2.0
/**
 * Fetches data using the PDBj Mine 2 SQL service.
 * 
 * @param sqlQuery
 *            query in SQL format
 * @return dataset of query results
 * @throws IOException
 */
public static Dataset<Row> getDataset(String sqlQuery) throws IOException {
	String encodedSQL = URLEncoder.encode(sqlQuery, "UTF-8");

	URL u = new URL(SERVICELOCATION + "?format=csv&q=" + encodedSQL);
	InputStream in = u.openStream();

	// save as a temporary CSV file
	Path tempFile = Files.createTempFile(null, ".csv");
	Files.copy(in, tempFile, StandardCopyOption.REPLACE_EXISTING);
	in.close();

	SparkSession spark = SparkSession.builder().getOrCreate();

	// load temporary CSV file into Spark dataset
	Dataset<Row> ds = spark.read().format("csv").option("header", "true").option("inferSchema", "true")
			// .option("parserLib", "UNIVOCITY")
			.load(tempFile.toString());

	// rename/concatenate columns to assign
	// consistent primary keys to datasets
	List<String> columns = Arrays.asList(ds.columns());

	if (columns.contains("pdbid")) {
		// this project uses upper case pdbids
		ds = ds.withColumn("pdbid", upper(col("pdbid")));

		if (columns.contains("chain")) {
			ds = ds.withColumn("structureChainId", concat(col("pdbid"), lit("."), col("chain")));
			ds = ds.drop("pdbid", "chain");
		} else {
			ds = ds.withColumnRenamed("pdbid", "structureId");
		}
	}

	return ds;
}
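A hypothetical caller; the table and column names are illustrative and depend on the PDBj Mine 2 schema:

String sql = "SELECT pdbid, resolution FROM brief_summary LIMIT 10";
Dataset<Row> summary = PdbjMineDataset.getDataset(sql);
// the pdbid column is upper-cased and renamed to structureId by getDataset()
summary.show();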
 
Example 5
Source File: SimplePredictionFromTextFile.java    From net.jgp.labs.spark with Apache License 2.0
private void start() {
  SparkSession spark = SparkSession.builder().appName(
      "Simple prediction from Text File").master("local").getOrCreate();

  spark.udf().register("vectorBuilder", new VectorBuilder(), new VectorUDT());

  String filename = "data/tuple-data-file.csv";
  StructType schema = new StructType(
      new StructField[] { new StructField("_c0", DataTypes.DoubleType, false,
          Metadata.empty()),
          new StructField("_c1", DataTypes.DoubleType, false, Metadata
              .empty()),
          new StructField("features", new VectorUDT(), true, Metadata
              .empty()), });

  Dataset<Row> df = spark.read().format("csv").schema(schema).option("header",
      "false")
      .load(filename);
  df = df.withColumn("valuefeatures", df.col("_c0")).drop("_c0");
  df = df.withColumn("label", df.col("_c1")).drop("_c1");
  df.printSchema();

  df = df.withColumn("features", callUDF("vectorBuilder", df.col(
      "valuefeatures")));
  df.printSchema();
  df.show();

  LinearRegression lr = new LinearRegression().setMaxIter(20);// .setRegParam(1).setElasticNetParam(1);

  // Fit the model to the data.
  LinearRegressionModel model = lr.fit(df);

  // Given a dataset, predict each point's label, and show the results.
  model.transform(df).show();

  LinearRegressionTrainingSummary trainingSummary = model.summary();
  System.out.println("numIterations: " + trainingSummary.totalIterations());
  System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary
      .objectiveHistory()));
  trainingSummary.residuals().show();
  System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
  System.out.println("r2: " + trainingSummary.r2());

  double intercept = model.intercept();
  System.out.println("Interesection: " + intercept);
  double regParam = model.getRegParam();
  System.out.println("Regression parameter: " + regParam);
  double tol = model.getTol();
  System.out.println("Tol: " + tol);
  Double feature = 7.0;
  Vector features = Vectors.dense(feature);
  double p = model.predict(features);

  System.out.println("Prediction for feature " + feature + " is " + p);
  System.out.println(8 * regParam + intercept);
}
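The vectorBuilder UDF registered at the top of this example is backed by a VectorBuilder class that is not shown. A minimal sketch, assuming it simply wraps the single double value into an ML Vector (the class in net.jgp.labs.spark may differ), could be:

import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.api.java.UDF1;

// Hypothetical implementation of the vectorBuilder UDF.
public class VectorBuilder implements UDF1<Double, Vector> {
  private static final long serialVersionUID = 1L;

  @Override
  public Vector call(Double value) {
    // wrap the single feature value into a dense ML vector
    return Vectors.dense(value);
  }
}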
 
Example 6
Source File: TestHoodieDeltaStreamer.java    From hudi with Apache License 2.0
@Override
public Dataset<Row> apply(JavaSparkContext jsc, SparkSession sparkSession, Dataset<Row> rowDataset,
                          TypedProperties properties) {
  rowDataset.sqlContext().udf().register("distance_udf", new DistanceUDF(), DataTypes.DoubleType);
  return rowDataset.withColumn("haversine_distance", functions.callUDF("distance_udf", functions.col("begin_lat"),
      functions.col("end_lat"), functions.col("begin_lon"), functions.col("end_lat")));
}
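The distance_udf registered above is backed by a DistanceUDF class that is not shown. A hedged sketch, assuming it computes a haversine great-circle distance (unit is an assumption) with the argument order used in the callUDF call (begin_lat, end_lat, begin_lon, end_lon):

import org.apache.spark.sql.api.java.UDF4;

// Hypothetical DistanceUDF: haversine distance in kilometres between two lat/lon points.
public class DistanceUDF implements UDF4<Double, Double, Double, Double, Double> {
  private static final double EARTH_RADIUS_KM = 6371.0;

  @Override
  public Double call(Double lat1, Double lat2, Double lon1, Double lon2) {
    double dLat = Math.toRadians(lat2 - lat1);
    double dLon = Math.toRadians(lon2 - lon1);
    double a = Math.sin(dLat / 2) * Math.sin(dLat / 2)
        + Math.cos(Math.toRadians(lat1)) * Math.cos(Math.toRadians(lat2))
            * Math.sin(dLon / 2) * Math.sin(dLon / 2);
    return EARTH_RADIUS_KM * 2 * Math.atan2(Math.sqrt(a), Math.sqrt(1 - a));
  }
}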
 
Example 7
Source File: ColumnHashStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataSet, Map<String, Object> parameters, SparkRunnerConfig config) {

    //check if all variables that should be hashed actually exist, otherwise log a warning
    List<String> existingColumns = new ArrayList<>(Arrays.asList(dataSet.columns()));

    Configuration configuration = ConfigurationUtils.getInstance().getConfiguration(config);
    if(configuration != null) {
        PreprocessingConfiguration preprocessingConfiguration = configuration.getPreprocessingConfiguration();
        if(preprocessingConfiguration != null) {
            for(ColumnHashConfiguration chc : preprocessingConfiguration.getColumnHashConfiguration()) {
                if(chc.isHashColumn()) {
                    if(!existingColumns.contains(chc.getColumnName())) {
                        // log the fact that a column that should be hashed does not exist
                        BpmnaiLogger.getInstance().writeWarn("The column '" + chc.getColumnName() + "' is configured to be hashed, but does not exist in the data.");
                    } else {
                        dataSet = dataSet.withColumn(chc.getColumnName(), sha1(dataSet.col(chc.getColumnName())));
                        BpmnaiLogger.getInstance().writeInfo("The column '" + chc.getColumnName() + "' is being hashed.");
                    }
                }

            }
        }
    }

    if(config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataSet, "column_hash_step", config);
    }

    return dataSet;
}
 
Example 8
Source File: DetermineProcessVariablesStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
private Dataset<Row> doVariableNameMapping(Dataset<Row> dataset, boolean writeStepResultIntoFile, SparkRunnerConfig config) {
    Map<String, String> variableNameMappings = new HashMap<>();

    // getting variable name mappings from configuration
    Configuration configuration = ConfigurationUtils.getInstance().getConfiguration(config);
    if(configuration != null) {
        PreprocessingConfiguration preprocessingConfiguration = configuration.getPreprocessingConfiguration();
        if(preprocessingConfiguration != null) {
            for(VariableNameMapping vm : preprocessingConfiguration.getVariableNameMappings()) {
                if(!vm.getOldName().equals("") && !vm.getNewName().equals("")) {
                    variableNameMappings.put(vm.getOldName(), vm.getNewName());
                } else {
                    BpmnaiLogger.getInstance().writeWarn("Ignoring variable name mapping '" + vm.getOldName() + "' -> '" + vm.getNewName() + "'.");
                }
            }
        }
    }

    // rename all variables
    for(String oldName : variableNameMappings.keySet()) {
        String newName = variableNameMappings.get(oldName);

        BpmnaiLogger.getInstance().writeInfo("Renaming variable '" + oldName + "' to '" + newName + "' as per user configuration.");

        dataset = dataset.withColumn(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME,
                when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME).equalTo(oldName), lit(newName))
                        .otherwise(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME)));
    }

    if(writeStepResultIntoFile) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "variable_name_mapping", config);
    }

    return dataset;
}
 
Example 9
Source File: CheckApprovedStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, boolean writeStepResultIntoFile, String dataLevel,
        Map<String, Object> parameters) {

    if (parameters == null) {
        SparkImporterLogger.getInstance().writeWarn("No parameters found for the CheckApprovedStep");
        return dataset;
    }

    String colName = (String) parameters.get("column");

    dataset = dataset.withColumn("approved2", functions.when(dataset.col(colName).equalTo("true"), "OK").otherwise("NOT OK"));

    return dataset;
}
 
Example 10
Source File: VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java    From iceberg with Apache License 2.0
private static Dataset<Row> withStringColumnDictEncoded(Dataset<Row> df) {
  return df.withColumn(
      "stringCol",
      when(pmod(col("longCol"), lit(9)).equalTo(lit(0)), lit("0"))
          .when(modColumn(9, 1), lit("1"))
          .when(modColumn(9, 2), lit("2"))
          .when(modColumn(9, 3), lit("3"))
          .when(modColumn(9, 4), lit("4"))
          .when(modColumn(9, 5), lit("5"))
          .when(modColumn(9, 6), lit("6"))
          .when(modColumn(9, 7), lit("7"))
          .when(modColumn(9, 8), lit("8")));
}
 
Example 11
Source File: DataframeCheckpointApp.java    From net.jgp.labs.spark with Apache License 2.0
private void start() {
  SparkConf conf = new SparkConf()
      .setAppName("Checkpoint")
      .setMaster("local[*]");
  SparkContext sparkContext = new SparkContext(conf);

  // We need to specify where Spark will save the checkpoint file.
  // It can be an HDFS location.
  sparkContext.setCheckpointDir("/tmp");
  SparkSession spark = SparkSession.builder()
      .appName("Checkpoint")
      .master("local[*]")
      .getOrCreate();

  String filename = "data/tuple-data-file.csv";
  Dataset<Row> df1 =
      spark.read().format("csv").option("inferSchema", "true")
          .option("header", "false")
          .load(filename);
  System.out.println("DF #1 - step #1: simple dump of the dataframe");
  df1.show();

  System.out.println("DF #2 - step #2: same as DF #1 - step #1");
  Dataset<Row> df2 = df1.checkpoint(false);
  df2.show();

  df1 = df1.withColumn("x", df1.col("_c0"));
  System.out.println(
      "DF #1 - step #2: new column x, which is the same as _c0");
  df1.show();

  System.out.println("DF #2 - step #2: no operation was done on df2");
  df2.show();
}
 
Example 12
Source File: VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java    From iceberg with Apache License 2.0
private static Dataset<Row> withDateColumnDictEncoded(Dataset<Row> df) {
  return df.withColumn(
      "dateCol",
      when(modColumn(9, 0), to_date(lit("04/12/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 1), to_date(lit("04/13/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 2), to_date(lit("04/14/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 3), to_date(lit("04/15/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 4), to_date(lit("04/16/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 5), to_date(lit("04/17/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 6), to_date(lit("04/18/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 7), to_date(lit("04/19/2019"), "MM/dd/yyyy"))
          .when(modColumn(9, 8), to_date(lit("04/20/2019"), "MM/dd/yyyy")));
}
 
Example 13
Source File: VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java    From iceberg with Apache License 2.0
private static Dataset<Row> withDecimalColumnDictEncoded(Dataset<Row> df) {
  Types.DecimalType type = Types.DecimalType.of(20, 5);
  return df.withColumn(
      "decimalCol",
      when(modColumn(9, 0), bigDecimal(type, 0))
          .when(modColumn(9, 1), bigDecimal(type, 1))
          .when(modColumn(9, 2), bigDecimal(type, 2))
          .when(modColumn(9, 3), bigDecimal(type, 3))
          .when(modColumn(9, 4), bigDecimal(type, 4))
          .when(modColumn(9, 5), bigDecimal(type, 5))
          .when(modColumn(9, 6), bigDecimal(type, 6))
          .when(modColumn(9, 7), bigDecimal(type, 7))
          .when(modColumn(9, 8), bigDecimal(type, 8)));
}
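The decimal benchmark additionally relies on a bigDecimal(type, value) helper that is not shown. A plausible sketch, assuming it builds a literal with the scale declared by the Iceberg decimal type:

import java.math.BigDecimal;

import org.apache.spark.sql.Column;

import static org.apache.spark.sql.functions.lit;

// Hypothetical helper: a literal decimal column using the scale of the Iceberg type.
private static Column bigDecimal(Types.DecimalType type, int value) {
  return lit(new BigDecimal(value).setScale(type.scale()));
}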
 
Example 14
Source File: AWSDmsTransformer.java    From hudi with Apache License 2.0
@Override
public Dataset<Row> apply(JavaSparkContext jsc, SparkSession sparkSession, Dataset<Row> rowDataset,
    TypedProperties properties) {
  Option<String> opColumnOpt = Option.fromJavaOptional(
      Arrays.stream(rowDataset.columns()).filter(c -> c.equals(AWSDmsAvroPayload.OP_FIELD)).findFirst());
  if (opColumnOpt.isPresent()) {
    return rowDataset;
  } else {
    return rowDataset.withColumn(AWSDmsAvroPayload.OP_FIELD, lit(""));
  }
}
 
Example 15
Source File: VectorizedReadDictionaryEncodedFlatParquetDataBenchmark.java    From iceberg with Apache License 2.0
private static Dataset<Row> withIntColumnDictEncoded(Dataset<Row> df) {
  return df.withColumn(
      "intCol",
      when(modColumn(9, 0), lit(0))
          .when(modColumn(9, 1), lit(1))
          .when(modColumn(9, 2), lit(2))
          .when(modColumn(9, 3), lit(3))
          .when(modColumn(9, 4), lit(4))
          .when(modColumn(9, 5), lit(5))
          .when(modColumn(9, 6), lit(6))
          .when(modColumn(9, 7), lit(7))
          .when(modColumn(9, 8), lit(8)));
}
 
Example 16
Source File: TypeCastStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, Map<String, Object> parameters, SparkRunnerConfig config) {

    // get variables
    Map<String, String> varMap = (Map<String, String>) SparkBroadcastHelper.getInstance().getBroadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_VARIABLES_ESCALATED);

    List<StructField> datasetFields = Arrays.asList(dataset.schema().fields());

    List<ColumnConfiguration> columnConfigurations = null;
    List<VariableConfiguration> variableConfigurations = null;

    Configuration configuration = ConfigurationUtils.getInstance().getConfiguration(config);
    if(configuration != null) {
        PreprocessingConfiguration preprocessingConfiguration = configuration.getPreprocessingConfiguration();
        columnConfigurations = preprocessingConfiguration.getColumnConfiguration();
        variableConfigurations = preprocessingConfiguration.getVariableConfiguration();
    }

    Map<String, ColumnConfiguration> columnTypeConfigMap = new HashMap<>();
    Map<String, VariableConfiguration> variableTypeConfigMap = new HashMap<>();

    if(columnConfigurations != null) {
        for(ColumnConfiguration cc : columnConfigurations) {
            columnTypeConfigMap.put(cc.getColumnName(), cc);
        }
    }

    if(variableConfigurations != null) {
        for(VariableConfiguration vc : variableConfigurations) {
            variableTypeConfigMap.put(vc.getVariableName(), vc);
        }
    }

    for(String column : dataset.columns()) {

        // skip revision columns as they are handled for each variable column
        if(column.endsWith("_rev")) {
            continue;
        }

        DataType newDataType = null;
        boolean isVariableColumn  = false;
        String configurationDataType = null;
        String configurationParseFormat = null;

        if(variableTypeConfigMap.keySet().contains(column)) {
            // was initially a variable
            configurationDataType = variableTypeConfigMap.get(column).getVariableType();
            configurationParseFormat = variableTypeConfigMap.get(column).getParseFormat();
            if (config.getPipelineMode().equals(BpmnaiVariables.PIPELINE_MODE_LEARN)) {
                isVariableColumn = varMap.keySet().contains(column);
            } else {
                isVariableColumn = true;
            }
        } else if(columnTypeConfigMap.keySet().contains(column)){
            // was initially a column
            configurationDataType = columnTypeConfigMap.get(column).getColumnType();
            configurationParseFormat = columnTypeConfigMap.get(column).getParseFormat();
        }

        newDataType = mapDataType(datasetFields, column, configurationDataType);

        // only check for cast errors if dev feature is enabled and if a change in the datatype has been done
        if(config.isDevTypeCastCheckEnabled() && !newDataType.equals(getCurrentDataType(datasetFields, column))) {
            // add a column with casted value to be able to check the cast results
            dataset = castColumn(dataset, column, column+"_casted", newDataType, configurationParseFormat);

            // add a column for cast results and write CAST_ERROR? in it if there might be a cast error
            dataset = dataset.withColumn(column+"_castresult",
                    when(dataset.col(column).isNotNull().and(dataset.col(column).notEqual(lit(""))),
                            when(dataset.col(column+"_casted").isNull(), lit("CAST_ERROR?"))
                                    .otherwise(lit(""))
                    ).otherwise(lit(""))
            );
            dataset.cache();

            // check for cast errors and write warning to application log
            if(dataset.filter(column+"_castresult == 'CAST_ERROR?'").count() > 0) {
                BpmnaiLogger.getInstance().writeWarn("Column '" + column + "' seems to have cast errors. Please check the data type (is defined as '" + configurationDataType + "')");
            } else {
                // drop help columns as there are no cast errors for this column and rename casted column to actual column name
                dataset = dataset.drop(column, column+"_castresult").withColumnRenamed(column+"_casted", column);
            }
        } else {
            // cast without checking the cast result; entries are null if Spark can't cast them
            dataset = castColumn(dataset, column, column, newDataType, configurationParseFormat);
        }

        // cast revision columns for former variables; revision columns only exist on process level
        if(config.getDataLevel().equals(BpmnaiVariables.DATA_LEVEL_PROCESS) && config.isRevCountEnabled() && isVariableColumn) {
            dataset = dataset.withColumn(column+"_rev", dataset.col(column+"_rev").cast("integer"));
        }
    }

    if(config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "type_cast_columns", config);
    }

    //return preprocessed data
    return dataset;
}
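runPreprocessingStep() relies on two helper methods, getCurrentDataType() and mapDataType(), that are not shown here. A hedged sketch of what they might look like (the actual bpmn.ai implementations may support additional types and different fallbacks):

// Hypothetical helpers; they use the same org.apache.spark.sql.types imports as the step above.
private DataType getCurrentDataType(List<StructField> datasetFields, String column) {
    for (StructField field : datasetFields) {
        if (field.name().equals(column)) {
            return field.dataType();
        }
    }
    return DataTypes.StringType;
}

private DataType mapDataType(List<StructField> datasetFields, String column, String configurationDataType) {
    if (configurationDataType == null) {
        // no configuration entry: keep the current type
        return getCurrentDataType(datasetFields, column);
    }
    switch (configurationDataType) {
        case "integer":   return DataTypes.IntegerType;
        case "long":      return DataTypes.LongType;
        case "double":    return DataTypes.DoubleType;
        case "boolean":   return DataTypes.BooleanType;
        case "date":      return DataTypes.DateType;
        case "timestamp": return DataTypes.TimestampType;
        default:          return DataTypes.StringType;
    }
}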
 
Example 17
Source File: AddVariableColumnsStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
private Dataset<Row> doAddVariableColumns(Dataset<Row> dataset, boolean writeStepResultIntoFile, String dataLevel, SparkRunnerConfig config) {
    Map<String, String> varMap = (Map<String, String>) SparkBroadcastHelper.getInstance().getBroadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_VARIABLES_ESCALATED);
    Set<String> variables = varMap.keySet();

    for(String v : variables) {
        dataset = dataset.withColumn(v, when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME).equalTo(v),
                when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("string"), dataset.col(BpmnaiVariables.VAR_TEXT))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("null"), dataset.col(BpmnaiVariables.VAR_TEXT))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("boolean"), dataset.col(BpmnaiVariables.VAR_LONG))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("integer"), dataset.col(BpmnaiVariables.VAR_LONG))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("long"), dataset.col(BpmnaiVariables.VAR_LONG))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("double"), dataset.col(BpmnaiVariables.VAR_DOUBLE))
                        .when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).equalTo("date"), dataset.col(BpmnaiVariables.VAR_LONG))
                        .otherwise(dataset.col(BpmnaiVariables.VAR_TEXT2)))
                .otherwise(null));

        //rev count is only relevant on process level
        if(dataLevel.equals(BpmnaiVariables.DATA_LEVEL_PROCESS) && config.isRevCountEnabled()) {
            dataset = dataset.withColumn(v+"_rev",
                    when(dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME).equalTo(v), dataset.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_REVISION))
                            .otherwise("0"));
        }
    }

    // drop unnecessary columns
    dataset = dataset.drop(
            BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE,
            BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_REVISION,
            BpmnaiVariables.VAR_DOUBLE,
            BpmnaiVariables.VAR_LONG,
            BpmnaiVariables.VAR_TEXT,
            BpmnaiVariables.VAR_TEXT2);

    if(!config.isDevProcessStateColumnWorkaroundEnabled()) {
        dataset = dataset.drop(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME);
    }

    if(writeStepResultIntoFile) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "add_var_columns", config);
    }

    //return preprocessed data
    return dataset;
}
 
Example 18
Source File: DataFilterOnActivityStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
/**
 * @param dataSet the incoming dataset for this processing step
 * @param parameters the step parameters; a "query" entry with the activity id to filter on is expected
 * @return the filtered DataSet
 */
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataSet, Map<String, Object> parameters, SparkRunnerConfig config) {
    // any parameters set?
    if (parameters == null || parameters.size() == 0) {
        BpmnaiLogger.getInstance().writeWarn("No parameters found for the DataFilterOnActivityStep");
        return dataSet;
    }

    // get query parameter
    String query = (String) parameters.get("query");
    BpmnaiLogger.getInstance().writeInfo("Filtering data with activity instance filter query: " + query + ".");

    // save size of initial dataset for log
    dataSet.cache();
    Long initialDSCount = dataSet.count();

    // repartition by process instance and order by start_time for this operation
    dataSet = dataSet.repartition(dataSet.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID)).sortWithinPartitions(BpmnaiVariables.VAR_START_TIME);

    // we temporarily store variable updates (rows with a var type set) separately.
    Dataset<Row> variables = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).isNotNull());
    //find first occurrence of activity instance
    final Dataset<Row> dsTmp = dataSet.filter(dataSet.col(BpmnaiVariables.VAR_ACT_ID).equalTo(query)).filter(dataSet.col(BpmnaiVariables.VAR_END_TIME).isNull()); //TODO: ENSURING THAT THIS ISN'T A VARIABLE ROW

    // now we look for the first occurrence of the activity id contained in "query". The result is a dataset of the corresponding activity instances.
    final Dataset<Row> dsActivityInstances = dataSet.filter(dataSet.col(BpmnaiVariables.VAR_ACT_ID).like(query)).filter(dataSet.col(BpmnaiVariables.VAR_END_TIME).isNull()); //TODO: ENSURING THAT THIS ISN'T A VARIABLE ROW

    // we slim the resulting dataset down: only the activity instance's process id and its start time are relevant.
    List<Row> activityRows = dsActivityInstances.select(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID, BpmnaiVariables.VAR_START_TIME).collectAsList();
    Map<String, String> activities = activityRows.stream().collect(Collectors.toMap(
            r -> r.getAs(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID), r -> r.getAs(BpmnaiVariables.VAR_START_TIME)));
    // broadcasting the PID - Start time Map to use it in a user defined function
    SparkBroadcastHelper.getInstance().broadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_INSTANCE_TIMESTAMP_MAP, activities);

    // now we have to select, for each process instance in our initial dataset, all events that happened before the first occurrence of our selected activity.
    // We first narrow it down to the process instances in question
    Dataset<Row> selectedProcesses = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID).isin(activities.keySet().toArray()));
    // Then, we mark all events that should be removed
    Dataset<Row> activityDataSet = selectedProcesses.withColumn("data_filter_on_activity",
            callUDF("activityBeforeTimestamp",
                    selectedProcesses.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID),
                    selectedProcesses.col(BpmnaiVariables.VAR_START_TIME)));
    // And we keep the rest
    activityDataSet = activityDataSet.filter(col("data_filter_on_activity").like("TRUE"));
    // Clean up
    activityDataSet = activityDataSet.drop("data_filter_on_activity");

    // However, we lost all variable updates in this approach, so now we add the variables in question to the dataset
    // first, we narrow it down to keep only variables that have a corresponding activity instance
    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID, BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");

    variables = variables.join(activityDataSet.select(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT").distinct(), variables.col(BpmnaiVariables.VAR_ACT_INST_ID).equalTo(activityDataSet.col(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT")),"inner");

    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT", BpmnaiVariables.VAR_ACT_INST_ID);
    variables = variables.drop(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");
    dataSet = activityDataSet.union(variables);

    dataSet.cache();
    BpmnaiLogger.getInstance().writeInfo("DataFilterOnActivityStep: The filtered DataSet contains "+dataSet.count()+" rows, (before: "+ initialDSCount+" rows)");

    if (config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataSet, "data_filter_on_activity_step", config);
    }

    return dataSet;


}
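The activityBeforeTimestamp UDF called above is registered elsewhere in bpmn.ai and is not part of this snippet. A hedged sketch of what it might do, assuming it reads the broadcast PID-to-start-time map and compares the timestamp strings lexicographically (the real implementation may parse the timestamps instead):

import java.util.Map;

import org.apache.spark.sql.api.java.UDF2;

// Hypothetical implementation; it would be registered e.g. with
// spark.udf().register("activityBeforeTimestamp", new ActivityBeforeTimestamp(), DataTypes.StringType);
public class ActivityBeforeTimestamp implements UDF2<String, String, String> {

    @Override
    @SuppressWarnings("unchecked")
    public String call(String processInstanceId, String startTime) {
        Map<String, String> activities = (Map<String, String>) SparkBroadcastHelper.getInstance()
                .getBroadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_INSTANCE_TIMESTAMP_MAP);

        String activityStartTime = activities.get(processInstanceId);
        if (activityStartTime == null || startTime == null) {
            return "FALSE";
        }
        // keep rows that happened up to (and including) the first occurrence of the activity
        return startTime.compareTo(activityStartTime) <= 0 ? "TRUE" : "FALSE";
    }
}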
 
Example 19
Source File: PdbToUniProt.java    From mmtf-spark with Apache License 2.0
/**
 * Returns an up-to-date dataset of PDB to UniProt 
 * residue-level mappings for a list of ids.
 * Valid ids are either a list of pdbIds (e.g. 1XYZ) or pdbId.chainId (e.g., 1XYZ.A).
 * This method reads a cached file and downloads updates.
 * 
 * @param ids list of pdbIds or pdbId.chainIds
 * @return dataset of PDB to UniProt residue-level mappings
 * @throws IOException
 */
public static Dataset<Row> getResidueMappings(List<String> ids) throws IOException {
    SparkSession spark = SparkSession.builder().getOrCreate();
    
    boolean withChainId = ids.size() > 0 && ids.get(0).length() > 4;
    
    // create dataset of ids
    Dataset<Row> df = spark.createDataset(ids, Encoders.STRING()).toDF("id");
    // get cached mappings
    Dataset<Row> mapping = getCachedResidueMappings();  
    
    // dataset for non-cached mappings
    Dataset<Row> notCached = null;
    // dataset with PDB Ids to be downloaded
    Dataset<Row> toDownload = null; 
    
    if (withChainId) {
        // get subset of requested ids from cached dataset
        mapping = mapping.join(df, mapping.col("structureChainId").equalTo(df.col("id"))).drop("id");
        // get ids that are not in the cached dataset
        notCached = df.join(mapping, df.col("id").equalTo(mapping.col("structureChainId")), "left_anti").cache(); 
        // create dataset of PDB Ids to be downloaded
        toDownload = notCached.withColumn("id", col("id").substr(0, 4)).distinct().cache();
    } else {
        // get subset of requested ids from cached dataset
        mapping = mapping.withColumn("pdbId", col("structureChainId").substr(0, 4));
        mapping = mapping.join(df, mapping.col("pdbId").equalTo(df.col("id"))).drop("id");
        // create dataset of PDB Ids to be downloaded
        toDownload = df.join(mapping, df.col("id").equalTo(mapping.col("pdbId")), "left_anti").distinct().cache();
        mapping = mapping.drop("pdbId");
    }
    
    toDownload = toDownload.distinct().cache();
        
    // download data that are not in the cache
    if (toDownload.count() > 0) {
        Dataset<Row> unpData = getChainMappings().select("structureId").distinct();
        toDownload = toDownload.join(unpData, toDownload.col("id").equalTo(unpData.col("structureId"))).drop("structureId").cache();
        System.out.println("Downloading mapping for " + toDownload.count() + " PDB structures.");
        Dataset<Row> downloadedData = downloadData(toDownload);
  
        // since data are downloaded for all chains in structure, make sure to only include the requested chains.
        if (withChainId) {
            downloadedData = downloadedData.join(notCached, downloadedData.col("structureChainId").equalTo(notCached.col("id"))).drop("id");
        }
        mapping = mapping.union(downloadedData);
    }
    
    return mapping;
}
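A hypothetical caller, using pdbId.chainId style ids so that the chain-level branch above is taken:

List<String> ids = Arrays.asList("1XYZ.A", "4HHB.A");
Dataset<Row> mappings = PdbToUniProt.getResidueMappings(ids);
mappings.show();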
 
Example 20
Source File: ProteinFoldDatasetCreator.java    From mmtf-spark with Apache License 2.0
/**
 * Adds a column "foldType" that labels each entry with one of the major secondary structure
 * classes "alpha", "beta", "alpha+beta", or "other", based upon the fraction of alpha/beta content.
 * 
 * The simplified syntax used in this method relies on two static imports:
 * import static org.apache.spark.sql.functions.when;
 * import static org.apache.spark.sql.functions.col;
 * 
 * @param data input dataset with alpha, beta composition
 * @param minThreshold below this threshold, the secondary structure type is ignored
 * @param maxThreshold above this threshold, the secondary structure type is assigned
 * @return dataset with the added "foldType" column
 */
public static Dataset<Row> addProteinFoldType(Dataset<Row> data, double minThreshold, double maxThreshold) {
	return data.withColumn("foldType",
			when(col("alpha").gt(maxThreshold).and(col("beta").lt(minThreshold)), "alpha")
			.when(col("beta").gt(maxThreshold).and(col("alpha").lt(minThreshold)), "beta")
			.when(col("alpha").gt(maxThreshold).and(col("beta").gt(maxThreshold)), "alpha+beta")
			.otherwise("other")
			);
}
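A hypothetical caller, assuming data already contains fractional "alpha" and "beta" columns (the 0.15/0.25 thresholds are illustrative):

Dataset<Row> labeled = ProteinFoldDatasetCreator.addProteinFoldType(data, 0.15, 0.25);
labeled.groupBy("foldType").count().show();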