Java Code Examples for org.apache.spark.sql.Dataset#withColumnRenamed()

The following examples show how to use org.apache.spark.sql.Dataset#withColumnRenamed(). Each example is taken from an open-source project; the original source file, project, and license are noted above each snippet.
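
Before looking at the project examples, here is a minimal, self-contained sketch of the call itself. It is not taken from any of the projects below; the class name, session settings, and column names are illustrative. withColumnRenamed(existingName, newName) returns a new Dataset with the given column renamed, leaves the original Dataset unchanged, and is a no-op if the column does not exist.

import java.util.Arrays;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

public class WithColumnRenamedSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("withColumnRenamedSketch")
                .master("local[*]")
                .getOrCreate();

        // a tiny dataset whose column name ("Price($)") is awkward for parquet and ML pipelines
        StructType schema = new StructType()
                .add("id", DataTypes.LongType, false)
                .add("Price($)", DataTypes.LongType, false);
        Dataset<Row> df = spark.createDataFrame(
                Arrays.asList(RowFactory.create(1L, 350000L)), schema);

        // returns a new Dataset with the column renamed; df itself is unchanged,
        // and renaming a column that does not exist is simply a no-op
        Dataset<Row> renamed = df.withColumnRenamed("Price($)", "label");
        renamed.printSchema();

        spark.stop();
    }
}
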
Example 1
Source File: WriteToDiscStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, Map<String, Object> parameters, SparkRunnerConfig config) {
	
    // remove spaces from column names as parquet does not support them
    for(String columnName : dataset.columns()) {
        if(columnName.contains(" ")) {
            String newColumnName = columnName.replace(' ', '_');
            dataset = dataset.withColumnRenamed(columnName, newColumnName);
        }
    }

    dataset.cache();
    BpmnaiUtils.getInstance().writeDatasetToParquet(dataset, "result", config);

    if(config.isGenerateResultPreview()) {
        dataset.limit(config.getResultPreviewLineCount()).write().mode(SaveMode.Overwrite).saveAsTable(BpmnaiVariables.RESULT_PREVIEW_TEMP_TABLE);
    }

    return dataset;
}
 
Example 2
Source File: ParseJSONDeriver.java    From envelope with Apache License 2.0
@Override
public Dataset<Row> derive(Map<String, Dataset<Row>> dependencies) {
  String parsedStructTemporaryFieldName = "__parsed_json";

  Dataset<Row> dependency = dependencies.get(stepName);

  Dataset<Row> parsed = dependency.select(
      functions.from_json(new Column(fieldName), schema, options).as(parsedStructTemporaryFieldName));

  if (asStruct) {
    return parsed.withColumnRenamed(parsedStructTemporaryFieldName, structFieldName);
  }
  else {
    for (StructField parsedField : schema.fields()) {
      parsed = parsed.withColumn(
          parsedField.name(), new Column(parsedStructTemporaryFieldName + "." + parsedField.name()));
    }

    return parsed.drop(parsedStructTemporaryFieldName);
  }
}
 
Example 3
Source File: SparkMLHouses.java    From -Data-Stream-Development-with-Apache-Spark-Kafka-and-Spring-Boot with MIT License
public static void main(String[] args) throws InterruptedException, StreamingQueryException {

                System.setProperty("hadoop.home.dir", HADOOP_HOME_DIR_VALUE);

                // * the schema can be written to disk and read back from disk
                // * the schema does not have to be complete; it only needs the fields used below
                StructType HOUSES_SCHEMA = 
                       new StructType()
                           .add("House", LongType, true)
                           .add("Taxes", LongType, true)
                           .add("Bedrooms", LongType, true)
                           .add("Baths", FloatType, true)
                           .add("Quadrant", LongType, true)
                           .add("NW", StringType, true)
                           .add("Price($)", LongType, false)
                           .add("Size(sqft)", LongType, false)
                           .add("lot", LongType, true);

                final SparkConf conf = new SparkConf()
                    .setMaster(RUN_LOCAL_WITH_AVAILABLE_CORES)
                    .setAppName(APPLICATION_NAME)
                    .set("spark.sql.caseSensitive", CASE_SENSITIVE);

                SparkSession sparkSession = SparkSession.builder()
                    .config(conf)
                    .getOrCreate();

                Dataset<Row> housesDF = sparkSession.read()
                     .schema(HOUSES_SCHEMA)
                     .json(HOUSES_FILE_PATH);
             
                // Gathering Data				
                Dataset<Row> gatheredDF = housesDF.select(col("Taxes"), 
                    col("Bedrooms"), col("Baths"),
                    col("Size(sqft)"), col("Price($)"));
                
                // Data Preparation  
                Dataset<Row> labelDF = gatheredDF.withColumnRenamed("Price($)", "label");
                
                Imputer imputer = new Imputer()
                    // .setMissingValue(1.0d)
                    .setInputCols(new String[] { "Baths" })
                    .setOutputCols(new String[] { "~Baths~" });

                VectorAssembler assembler = new VectorAssembler()
                    .setInputCols(new String[] { "Taxes", "Bedrooms", "~Baths~", "Size(sqft)" })
                    .setOutputCol("features");
                
                // Choosing a Model               
                LinearRegression linearRegression = new LinearRegression();
                linearRegression.setMaxIter(1000);

                Pipeline pipeline = new Pipeline()
                                .setStages(new PipelineStage[] {
                                    imputer, assembler, linearRegression 
                                });

                // Training The Data
                Dataset<Row>[] splitDF = labelDF.randomSplit(new double[] { 0.8, 0.2 });

                Dataset<Row> trainDF = splitDF[0];
                Dataset<Row> evaluationDF = splitDF[1];

                PipelineModel pipelineModel = pipeline.fit(trainDF);
                
                // Evaluation 
                Dataset<Row> predictionsDF = pipelineModel.transform(evaluationDF);

                predictionsDF.show(false);

                Dataset<Row> forEvaluationDF = predictionsDF.select(col("label"), 
                    col("prediction"));

                RegressionEvaluator evaluteR2 = new RegressionEvaluator().setMetricName("r2");
                RegressionEvaluator evaluteRMSE = new RegressionEvaluator().setMetricName("rmse");

                double r2 = evaluteR2.evaluate(forEvaluationDF);
                double rmse = evaluteRMSE.evaluate(forEvaluationDF);

                logger.info("---------------------------");
                logger.info("R2 =" + r2);
                logger.info("RMSE =" + rmse);
                logger.info("---------------------------");
        }
 
Example 4
Source File: PdbjMineDataset.java    From mmtf-spark with Apache License 2.0
/**
 * Fetches data using the PDBj Mine 2 SQL service.
 * 
 * @param sqlQuery
 *            query in SQL format
 * @return dataset with the query results
 * @throws IOException
 */
public static Dataset<Row> getDataset(String sqlQuery) throws IOException {
	String encodedSQL = URLEncoder.encode(sqlQuery, "UTF-8");

	URL u = new URL(SERVICELOCATION + "?format=csv&q=" + encodedSQL);
	InputStream in = u.openStream();

	// save as a temporary CSV file
	Path tempFile = Files.createTempFile(null, ".csv");
	Files.copy(in, tempFile, StandardCopyOption.REPLACE_EXISTING);
	in.close();

	SparkSession spark = SparkSession.builder().getOrCreate();

	// load temporary CSV file into Spark dataset
	Dataset<Row> ds = spark.read().format("csv").option("header", "true").option("inferSchema", "true")
			// .option("parserLib", "UNIVOCITY")
			.load(tempFile.toString());

	// rename/concatenate columns to assign
	// consistent primary keys to datasets
	List<String> columns = Arrays.asList(ds.columns());

	if (columns.contains("pdbid")) {
		// this project uses upper case pdbids
		ds = ds.withColumn("pdbid", upper(col("pdbid")));

		if (columns.contains("chain")) {
			ds = ds.withColumn("structureChainId", concat(col("pdbid"), lit("."), col("chain")));
			ds = ds.drop("pdbid", "chain");
		} else {
			ds = ds.withColumnRenamed("pdbid", "structureId");
		}
	}

	return ds;
}
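
A usage sketch only (not part of the project): the SQL below is a placeholder and must be replaced with a query that is valid against the tables the PDBj Mine 2 service actually exposes. Per the code above, a pdbid column in the result is upper-cased and either combined with chain into structureChainId or renamed to structureId.

// hypothetical query string; substitute a query that matches the PDBj Mine 2 schema
Dataset<Row> ds = PdbjMineDataset.getDataset("SELECT pdbid, chain FROM some_table LIMIT 10");

ds.printSchema();
ds.show(10, false);
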
 
Example 5
Source File: DrugBankDataset.java    From mmtf-spark with Apache License 2.0
/**
 * Removes spaces from column names to ensure compatibility with parquet
 * files.
 *
 * @param original
 *            dataset
 * @return dataset with columns renamed
 */
private static Dataset<Row> removeSpacesFromColumnNames(Dataset<Row> original) {

    for (String existingName : original.columns()) {
        String newName = existingName.replaceAll(" ", "");
        original = original.withColumnRenamed(existingName, newName);
    }

    return original;
}
 
Example 6
Source File: AggregateActivityInstancesStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, Map<String, Object> parameters, SparkRunnerConfig config) {

    //build the aggregation map: max, ProcessState, or AllButEmptyString per column
    Map<String, String> aggregationMap = new HashMap<>();
    for(String column : dataset.columns()) {
        if(column.equals(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID)) {
            continue;
        } else if(column.equals(BpmnaiVariables.VAR_DURATION) || column.endsWith("_rev")) {
            aggregationMap.put(column, "max");
        } else if(column.equals(BpmnaiVariables.VAR_STATE)) {
            aggregationMap.put(column, "ProcessState");
        } else if(column.equals(BpmnaiVariables.VAR_ACT_INST_ID)) {
            //ignore it, as we aggregate by it
            continue;
        } else {
            aggregationMap.put(column, "AllButEmptyString");
        }
    }

    //first aggregation on activity level: exclude process instance level rows,
    //then group by process instance id and activity instance id
    dataset = dataset
            .filter(dataset.col(BpmnaiVariables.VAR_DATA_SOURCE).notEqual(BpmnaiVariables.EVENT_PROCESS_INSTANCE))
            .groupBy(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID, BpmnaiVariables.VAR_ACT_INST_ID)
            .agg(aggregationMap);

    //rename columns back to their original names after aggregation (Spark wraps them, e.g. max(columnName))
    String pattern = "(max|allbutemptystring|processstate)\\((.+)\\)";
    Pattern r = Pattern.compile(pattern);

    for(String columnName : dataset.columns()) {
        Matcher m = r.matcher(columnName);
        if(m.find()) {
            String newColumnName = m.group(2);
            dataset = dataset.withColumnRenamed(columnName, newColumnName);
        }
    }


    //in case the CSV was added, the first dataset of the join contains a name column, so drop it again to make sure it is gone
    dataset = dataset.drop(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME);
    dataset = dataset.drop(BpmnaiVariables.VAR_DATA_SOURCE);

    dataset = dataset.sort(BpmnaiVariables.VAR_START_TIME);

    dataset.cache();
    BpmnaiLogger.getInstance().writeInfo("Found " + dataset.count() + " activity instances.");

    if(config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "agg_of_activity_instances", config);
    }

    //return preprocessed data
    return dataset;
}
 
Example 7
Source File: DataFilterOnActivityStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
/**
 * @param dataSet the incoming dataset for this processing step
 * @param parameters the parameter map for this step; the "query" entry selects the activity to filter on
 * @param config the Spark runner configuration
 * @return the filtered DataSet
 */
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataSet, Map<String, Object> parameters, SparkRunnerConfig config) {
    // any parameters set?
    if (parameters == null || parameters.size() == 0) {
        BpmnaiLogger.getInstance().writeWarn("No parameters found for the DataFilterOnActivityStep");
        return dataSet;
    }

    // get query parameter
    String query = (String) parameters.get("query");
    BpmnaiLogger.getInstance().writeInfo("Filtering data with activity instance filter query: " + query + ".");

    // save size of initial dataset for log
    dataSet.cache();
    Long initialDSCount = dataSet.count();

    // repartition by process instance and order by start_time for this operation
    dataSet = dataSet.repartition(dataSet.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID)).sortWithinPartitions(BpmnaiVariables.VAR_START_TIME);

    // we temporarily store variable updates (rows with a var type set) separately.
    Dataset<Row> variables = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_TYPE).isNotNull());
    //find first occurrence of activity instance
    final Dataset<Row> dsTmp = dataSet.filter(dataSet.col(BpmnaiVariables.VAR_ACT_ID).equalTo(query)).filter(dataSet.col(BpmnaiVariables.VAR_END_TIME).isNull()); //TODO: ENSURING THAT THIS ISN'T A VARIABLE ROW

    // now we look for the first occurrence of the activity id contained in "query". The result is a dataset of the corresponding activity instances.
    final Dataset<Row> dsActivityInstances = dataSet.filter(dataSet.col(BpmnaiVariables.VAR_ACT_ID).like(query)).filter(dataSet.col(BpmnaiVariables.VAR_END_TIME).isNull()); //TODO: ENSURING THAT THIS ISN'T A VARIABLE ROW

    // we slim the resulting dataset down: only the process instance id and the start time of each activity instance are relevant.
    List<Row> activityRows = dsActivityInstances.select(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID, BpmnaiVariables.VAR_START_TIME).collectAsList();
    Map<String, String> activities = activityRows.stream().collect(Collectors.toMap(
            r -> r.getAs(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID), r -> r.getAs(BpmnaiVariables.VAR_START_TIME)));
    // broadcasting the PID - Start time Map to use it in a user defined function
    SparkBroadcastHelper.getInstance().broadcastVariable(SparkBroadcastHelper.BROADCAST_VARIABLE.PROCESS_INSTANCE_TIMESTAMP_MAP, activities);

    // now we have to select, for each process instance in our initial dataset, all events that happened before the first occurrence of our selected activity.
    // We first narrow it down to the process instances in question
    Dataset<Row> selectedProcesses = dataSet.filter(col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID).isin(activities.keySet().toArray()));
    // Then, we mark all events that should be removed
    Dataset<Row> activityDataSet = selectedProcesses.withColumn("data_filter_on_activity",
            callUDF("activityBeforeTimestamp",
                    selectedProcesses.col(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID),
                    selectedProcesses.col(BpmnaiVariables.VAR_START_TIME)));
    // And we keep the rest
    activityDataSet = activityDataSet.filter(col("data_filter_on_activity").like("TRUE"));
    // Clean up
    activityDataSet = activityDataSet.drop("data_filter_on_activity");

    // However, we lost all variable updates in this approach, so now we add the variables in question to the dataset
    // first, we narrow it down to keep only variables that have a corresponding activity instance
    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID, BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");

    variables = variables.join(activityDataSet.select(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT").distinct(), variables.col(BpmnaiVariables.VAR_ACT_INST_ID).equalTo(activityDataSet.col(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT")),"inner");

    activityDataSet = activityDataSet.withColumnRenamed(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT", BpmnaiVariables.VAR_ACT_INST_ID);
    variables = variables.drop(BpmnaiVariables.VAR_ACT_INST_ID+"_RIGHT");
    dataSet = activityDataSet.union(variables);

    dataSet.cache();
    BpmnaiLogger.getInstance().writeInfo("DataFilterOnActivityStep: The filtered DataSet contains "+dataSet.count()+" rows, (before: "+ initialDSCount+" rows)");

    if (config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataSet, "data_filter_on_activity_step", config);
    }

    return dataSet;
}
 
Example 8
Source File: SparkMultiClassClassifier.java    From mmtf-spark with Apache License 2.0
/**
 * Fits the classifier and returns evaluation metrics. The dataset must contain
 * at least the following two columns:
 * label: the class labels
 * features: feature vector
 * 
 * @param data input dataset with label and features columns
 * @return map with evaluation metrics
 */
public Map<String,String> fit(Dataset<Row> data) {
	int classCount = (int)data.select(label).distinct().count();

	StringIndexerModel labelIndexer = new StringIndexer()
	  .setInputCol(label)
	  .setOutputCol("indexedLabel")
	  .fit(data);

	// Split the data into training and test sets (testFraction held out for testing)
	Dataset<Row>[] splits = data.randomSplit(new double[] {1.0-testFraction, testFraction}, seed);
	Dataset<Row> trainingData = splits[0];
	Dataset<Row> testData = splits[1];
	
	String[] labels = labelIndexer.labels();
	
	System.out.println();
	System.out.println("Class\tTrain\tTest");
	for (String l: labels) {
		System.out.println(l + "\t" + trainingData.select(label).filter(label + " = '" + l + "'").count()
				+ "\t" 
				+ testData.select(label).filter(label + " = '" + l + "'").count());
	}
	
	// Set input columns
	predictor
	.setLabelCol("indexedLabel")
	.setFeaturesCol("features");

	// Convert indexed labels back to original labels.
	IndexToString labelConverter = new IndexToString()
	  .setInputCol("prediction")
	  .setOutputCol("predictedLabel")
	  .setLabels(labelIndexer.labels());

	// Chain the label indexer, predictor, and label converter in a Pipeline
	Pipeline pipeline = new Pipeline()
	  .setStages(new PipelineStage[] {labelIndexer, predictor, labelConverter});

	// Train model. This also runs the indexers.
	PipelineModel model = pipeline.fit(trainingData);

	// Make predictions.
	Dataset<Row> predictions = model.transform(testData).cache();
	
	// Display some sample predictions
	System.out.println();
	System.out.println("Sample predictions: " + predictor.getClass().getSimpleName());

	predictions.sample(false, 0.1, seed).show(25);	

	predictions = predictions.withColumnRenamed(label, "stringLabel");
	predictions = predictions.withColumnRenamed("indexedLabel", label);
	
	// collect metrics
	Dataset<Row> pred = predictions.select("prediction",label);
	Map<String, String> metrics = new LinkedHashMap<>();
	metrics.put("Method", predictor.getClass().getSimpleName());

	if (classCount == 2) {
		BinaryClassificationMetrics b = new BinaryClassificationMetrics(pred);
		metrics.put("AUC", Float.toString((float) b.areaUnderROC()));
	}

	MulticlassMetrics m = new MulticlassMetrics(pred);
	metrics.put("F", Float.toString((float) m.weightedFMeasure()));
	metrics.put("Accuracy", Float.toString((float) m.accuracy()));
	metrics.put("Precision", Float.toString((float) m.weightedPrecision()));
	metrics.put("Recall", Float.toString((float) m.weightedRecall()));
	metrics.put("False Positive Rate", Float.toString((float) m.weightedFalsePositiveRate()));
	metrics.put("True Positive Rate", Float.toString((float) m.weightedTruePositiveRate()));
	metrics.put("", "\nConfusion Matrix\n"
			+ Arrays.toString(labels) + "\n"
			+ m.confusionMatrix().toString());

	return metrics;
}