Java Code Examples for org.apache.spark.sql.Dataset#sort()

The following examples show how to use org.apache.spark.sql.Dataset#sort(). Each example lists the project and source file it was taken from above the code.
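Before the project examples, here is a minimal, self-contained sketch of Dataset#sort(). The people dataset and column names below are hypothetical, chosen only for illustration; sort() accepts either column names or Column expressions (it is an alias for orderBy()), and Column#desc() reverses the default ascending order.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import static org.apache.spark.sql.functions.col;

public class DatasetSortSketch {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder().master("local[*]")
                .appName("DatasetSortSketch").getOrCreate();

        // build a small in-memory DataFrame (hypothetical example data)
        StructType schema = DataTypes.createStructType(new StructField[] {
                DataTypes.createStructField("name", DataTypes.StringType, false),
                DataTypes.createStructField("age", DataTypes.IntegerType, false) });
        List<Row> rows = Arrays.asList(
                RowFactory.create("alice", 34),
                RowFactory.create("bob", 28),
                RowFactory.create("carol", 41));
        Dataset<Row> df = spark.createDataFrame(rows, schema);

        // sort by column name; ascending by default
        df.sort("age").show();

        // sort by multiple criteria using Column expressions
        df.sort(col("age").desc(), col("name")).show();

        spark.close();
    }
}

Note that sort() is a transformation: it returns a new, sorted Dataset and leaves the original unchanged.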
Example 1
Source File: MLContextTest.java    From systemds with Apache License 2.0
@Test
public void testOutputDataFrameFromMatrixDML() {
	System.out.println("MLContextTest - output DataFrame from matrix DML");

	String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
	Script script = dml(s).out("M");
	Dataset<Row> df = ml.execute(script).getMatrix("M").toDF();
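	// sort by the internal row-ID column so the collected row order is deterministic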
	Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
	List<Row> list = sortedDF.collectAsList();
	Row row1 = list.get(0);
	Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
	Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
	Assert.assertEquals(2.0, row1.getDouble(2), 0.0);

	Row row2 = list.get(1);
	Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
	Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
	Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
}
 
Example 2
Source File: MLContextTest.java    From systemds with Apache License 2.0
@Test
public void testOutputDataFrameDoublesWithIDColumnFromMatrixDML() {
	System.out.println("MLContextTest - output DataFrame of doubles with ID column from matrix DML");

	String s = "M = matrix('1 2 3 4', rows=2, cols=2);";
	Script script = dml(s).out("M");
	Dataset<Row> df = ml.execute(script).getMatrix("M").toDFDoubleWithIDColumn();
	Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);
	List<Row> list = sortedDF.collectAsList();

	Row row1 = list.get(0);
	Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
	Assert.assertEquals(1.0, row1.getDouble(1), 0.0);
	Assert.assertEquals(2.0, row1.getDouble(2), 0.0);

	Row row2 = list.get(1);
	Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
	Assert.assertEquals(3.0, row2.getDouble(1), 0.0);
	Assert.assertEquals(4.0, row2.getDouble(2), 0.0);
}
 
Example 3
Source File: MutationToStructureDemo.java    From mmtf-spark with Apache License 2.0
public static void main(String[] args) throws IOException {
    SparkSession spark = SparkSession.builder().master("local[*]").appName(MutationToStructureDemo.class.getSimpleName())
            .getOrCreate();

    // find missense mutations that map to UniProt ID P15056 (BRAF)
    // that are annotated as pathogenic or likely pathogenic in ClinVar.
    List<String> uniprotIds = Arrays.asList("P15056"); // BRAF: P15056
    String query = "clinvar.rcv.clinical_significance:pathogenic OR clinvar.rcv.clinical_significance:likely pathogenic";
    Dataset<Row> df = MyVariantDataset.getVariations(uniprotIds, query).cache();
    System.out.println("BRAF missense mutations: " + df.count());
    df.show();
    
    // extract the list of variant Ids
    List<String> variantIds = df.select("variationId").as(Encoders.STRING()).collectAsList();
    
    // map to PDB structures
    Dataset<Row> ds = G2SDataset.getPositionDataset(variantIds);
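    // sort by structure, chain, and position for readable output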
    ds = ds.sort("structureId", "chainId", "pdbPosition");
    ds.show();

    spark.close(); 
}
 
Example 4
Source File: MLContextTest.java    From systemds with Apache License 2.0
@Test
public void testOutputDataFrameOfVectorsDML() {
	System.out.println("MLContextTest - output DataFrame of vectors DML");

	String s = "m=matrix('1 2 3 4',rows=2,cols=2);";
	Script script = dml(s).out("m");
	MLResults results = ml.execute(script);
	Dataset<Row> df = results.getDataFrame("m", true);
	Dataset<Row> sortedDF = df.sort(RDDConverterUtils.DF_ID_COLUMN);

	// verify column types
	StructType schema = sortedDF.schema();
	StructField[] fields = schema.fields();
	StructField idColumn = fields[0];
	StructField vectorColumn = fields[1];
	Assert.assertTrue(idColumn.dataType() instanceof DoubleType);
	Assert.assertTrue(vectorColumn.dataType() instanceof VectorUDT);

	List<Row> list = sortedDF.collectAsList();

	Row row1 = list.get(0);
	Assert.assertEquals(1.0, row1.getDouble(0), 0.0);
	Vector v1 = (DenseVector) row1.get(1);
	double[] arr1 = v1.toArray();
	Assert.assertArrayEquals(new double[] { 1.0, 2.0 }, arr1, 0.0);

	Row row2 = list.get(1);
	Assert.assertEquals(2.0, row2.getDouble(0), 0.0);
	Vector v2 = (DenseVector) row2.get(1);
	double[] arr2 = v2.toArray();
	Assert.assertArrayEquals(new double[] { 3.0, 4.0 }, arr2, 0.0);
}
 
Example 5
Source File: AggregateActivityInstancesStep.java    From bpmn.ai with BSD 3-Clause "New" or "Revised" License
@Override
public Dataset<Row> runPreprocessingStep(Dataset<Row> dataset, Map<String, Object> parameters, SparkRunnerConfig config) {

    // build the aggregation map: max for duration and revision columns, ProcessState for the state column, AllButEmptyString for everything else
    Map<String, String> aggregationMap = new HashMap<>();
    for(String column : dataset.columns()) {
        if(column.equals(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID)) {
            continue;
        } else if(column.equals(BpmnaiVariables.VAR_DURATION) || column.endsWith("_rev")) {
            aggregationMap.put(column, "max");
        } else if(column.equals(BpmnaiVariables.VAR_STATE)) {
            aggregationMap.put(column, "ProcessState");
        } else if(column.equals(BpmnaiVariables.VAR_ACT_INST_ID)) {
            //ignore it, as we aggregate by it
            continue;
        } else {
            aggregationMap.put(column, "AllButEmptyString");
        }
    }

    // first aggregation on the activity level: exclude process-instance-level rows,
    // then group by process instance and activity instance
    dataset = dataset
            .filter(dataset.col(BpmnaiVariables.VAR_DATA_SOURCE).notEqual(BpmnaiVariables.EVENT_PROCESS_INSTANCE))
            .groupBy(BpmnaiVariables.VAR_PROCESS_INSTANCE_ID, BpmnaiVariables.VAR_ACT_INST_ID)
            .agg(aggregationMap);

    // rename columns back to their original names after aggregation, e.g. max(duration) -> duration
    String pattern = "(max|allbutemptystring|processstate)\\((.+)\\)";
    Pattern r = Pattern.compile(pattern);

    for(String columnName : dataset.columns()) {
        Matcher m = r.matcher(columnName);
        if(m.find()) {
            String newColumnName = m.group(2);
            dataset = dataset.withColumnRenamed(columnName, newColumnName);
        }
    }


    // if CSV input was added, the first dataset of the join carries a name column, so drop it again to make sure it is gone
    dataset = dataset.drop(BpmnaiVariables.VAR_PROCESS_INSTANCE_VARIABLE_NAME);
    dataset = dataset.drop(BpmnaiVariables.VAR_DATA_SOURCE);

    dataset = dataset.sort(BpmnaiVariables.VAR_START_TIME);

    dataset.cache();
    BpmnaiLogger.getInstance().writeInfo("Found " + dataset.count() + " activity instances.");

    if(config.isWriteStepResultsIntoFile()) {
        BpmnaiUtils.getInstance().writeDatasetToCSV(dataset, "agg_of_activity_instances", config);
    }

    //return preprocessed data
    return dataset;
}
 
Example 6
Source File: DemoQueryVsAll.java    From mmtf-spark with Apache License 2.0
public static void main(String[] args) throws IOException {

        String path = MmtfReader.getMmtfReducedPath();

        SparkConf conf = new SparkConf().setMaster("local[*]").setAppName(DemoQueryVsAll.class.getSimpleName());
        JavaSparkContext sc = new JavaSparkContext(conf);

        long start = System.nanoTime();

        // default example
        List<String> queryIds = Arrays.asList("2W47");
        
        // use list of PDB IDs from the command line
        if (args.length > 0) {
            queryIds = Arrays.asList(args);
        }
        
        System.out.println("DemoQueryVsAll Query structures: " + queryIds);
        
        // download query structure
        JavaPairRDD<String, StructureDataInterface> query = MmtfReader.downloadReducedMmtfFiles(queryIds, sc)
                .flatMapToPair(new StructureToPolymerChains());

        // use a 25% random sample of the Pisces non-redundant set
        // at 20% sequence identity and a resolution better than 1.6 A.
        double fraction = 0.25;
        long seed = 11;
        JavaPairRDD<String, StructureDataInterface> target = MmtfReader.readSequenceFile(path, fraction, seed, sc)
                .flatMapToPair(new StructureToPolymerChains())
                .filter(new Pisces(20, 1.6))
                .sample(false, fraction, seed);

        // specialized algorithms
        // String alignmentAlgorithm = CeMain.algorithmName;
        // String alignmentAlgorithm = CeCPMain.algorithmName;
        // String alignmentAlgorithm = FatCatFlexible.algorithmName;

        // two standard algorithms
        // String alignmentAlgorithm = CeMain.algorithmName;
        String alignmentAlgorithm = FatCatRigid.algorithmName;

        // String alignmentAlgorithm = ExhaustiveAligner.alignmentAlgorithm;

        // calculate alignments
        Dataset<Row> alignments = StructureAligner.getQueryVsAllAlignments(query, target, alignmentAlgorithm).cache();

        // sort alignments by TM score
        alignments = alignments.sort(col("tm").desc());
        
        // show results
        int count = (int) alignments.count();
        alignments.show(count);
        System.out.println("Pairs: " + count);

        // save results to file
        alignments.write().mode("overwrite").option("compression", "gzip").parquet("alignments.parquet");
        
        long end = System.nanoTime();
        System.out.println("Time per alignment: " + TimeUnit.NANOSECONDS.toMillis((end - start) / count) + " msec.");
        System.out.println("Time: " + TimeUnit.NANOSECONDS.toSeconds(end - start) + " sec.");

        sc.close();
    }