org.apache.spark.sql.DataFrame Java Examples

The following examples show how to use org.apache.spark.sql.DataFrame. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example #1
Source File: SparkDataSourceManager.java    From DDF with Apache License 2.0 8 votes vote down vote up
/**
 * Loads a JDBC table into a Spark {@code DataFrame} and wraps it in a DDF.
 *
 * <p>When the descriptor's credentials carry a non-empty user name, the user name (and the
 * password, if one is set) are appended to the connection URL as URL-encoded query
 * parameters. If the URL already contains a query string the parameters are appended with
 * {@code &} rather than starting a second {@code ?} section.
 *
 * @param dataSource descriptor supplying the connection URI, the credentials and the table name
 * @return a DDF backed by the loaded table, with {@code dataSource} recorded on its metadata handler
 * @throws DDFException declared by the overridden method
 */
@Override
public DDF loadFromJDBC(JDBCDataSourceDescriptor dataSource) throws DDFException {
    SparkDDFManager sparkDDFManager = (SparkDDFManager) mDDFManager;
    HiveContext sqlContext = sparkDDFManager.getHiveContext();

    JDBCDataSourceCredentials cred = (JDBCDataSourceCredentials) dataSource.getDataSourceCredentials();
    String fullURL = dataSource.getDataSourceUri().getUri().toString();
    String username = cred.getUsername();
    if (username != null && !username.isEmpty()) {
        // Continue an existing query string instead of emitting a second '?'.
        StringBuilder url = new StringBuilder(fullURL);
        url.append(fullURL.indexOf('?') >= 0 ? '&' : '?');
        url.append("user=").append(encodeJdbcUrlComponent(username));
        String password = cred.getPassword();
        if (password != null) {
            // A missing password is left out rather than sent as the literal text "null".
            url.append("&password=").append(encodeJdbcUrlComponent(password));
        }
        fullURL = url.toString();
    }

    Map<String, String> options = new HashMap<String, String>();
    options.put("url", fullURL);
    options.put("dbtable", dataSource.getDbTable());
    DataFrame df = sqlContext.load("jdbc", options);

    DDF ddf = sparkDDFManager.newDDF(sparkDDFManager, df, new Class<?>[]{DataFrame.class},
        null, SparkUtils.schemaFromDataFrame(df));
    // TODO? (kept from the original) requests the RDD<Row> representation right after creation.
    ddf.getRepresentationHandler().get(RDD.class, Row.class);
    ddf.getMetaDataHandler().setDataSourceDescriptor(dataSource);
    return ddf;
}

/**
 * URL-encodes a value so that characters such as {@code &}, {@code =} or {@code ?} inside a
 * user name or password cannot break the query string of the JDBC URL.
 */
private static String encodeJdbcUrlComponent(String value) {
    try {
        return java.net.URLEncoder.encode(value, "UTF-8");
    } catch (java.io.UnsupportedEncodingException e) {
        // UTF-8 support is mandatory on every Java platform, so this cannot normally happen.
        throw new IllegalStateException("UTF-8 encoding is not available", e);
    }
}
 
Example #2
Source File: FillNAValuesTransformerBridgeTest.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Checks that {@code FillNAValuesTransformer} produces the same values as Spark's own
 * {@code df.na().fill(...)} when both are given the same fill map.
 *
 * <p>Both results are ordered by {@code id} and compared row by row on the columns
 * {@code a}, {@code b}, {@code c} and {@code d} (indices 1 to 4; index 0 is the ordering key).
 */
@Test
public void shouldBehaveExactlyAsSparkNAFillerForAllSupportedDataTypes() {

    DataFrame df = getDataFrame();
    DataFrame sparkFilled = df.na().fill( getFillNAMap() );

    FillNAValuesTransformer fillNAValuesTransformer = new FillNAValuesTransformer();
    fillNAValuesTransformer.setNAValueMap( getFillNAMap() );
    DataFrame transformerFilled = fillNAValuesTransformer.transform(df);

    Row[] expectedRows = sparkFilled.orderBy("id").select("id", "a", "b", "c", "d").collect();
    Row[] actualRows = transformerFilled.orderBy("id").select("id", "a", "b", "c", "d").collect();

    // Without this check extra rows would pass unnoticed and missing rows would only
    // surface as an ArrayIndexOutOfBoundsException in the loop below.
    assertEquals(expectedRows.length, actualRows.length);

    for (int i = 0; i < expectedRows.length; i++) {
        for (int j = 1; j <= 4; j++) {
            assertEquals(expectedRows[i].get(j), actualRows[i].get(j));
        }
    }
}
 
Example #3
Source File: ProbabilityTransformModelInfoAdapter.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the parameters of a {@code ProbabilityTransformModel} into its exportable model info.
 *
 * @param from the model to read the click proportions, probability index and column names from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public ProbabilityTransformModelInfo getModelInfo(final ProbabilityTransformModel from, DataFrame df) {
    final ProbabilityTransformModelInfo info = new ProbabilityTransformModelInfo();

    // scalar parameters
    info.setActualClickProportion(from.getActualClickProportion());
    info.setUnderSampledClickProportion(from.getUnderSampledClickProportion());
    info.setProbIndex(from.getProbIndex());

    // single input column
    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    // single output column
    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #4
Source File: IfZeroVectorBridgeTest.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Runs {@code IfZeroVector} over {@code sparseOrderDF} and checks the first three values of
 * the output column after ordering by {@code product_title}.
 */
@Test
public void testIfZeroVectorSparse() {
    final String outputColumn = "product_title_filtered";
    final String titleColumn = "product_title";

    IfZeroVector ifZeroVector = new IfZeroVector()
            .setInputCol("vectorized_count")
            .setOutputCol(outputColumn)
            .setThenSetValue("others")
            .setElseSetCol(titleColumn);

    System.out.println(sparseOrderDF.schema());
    DataFrame result = ifZeroVector.transform(sparseOrderDF).orderBy(titleColumn);
    System.out.println(result.schema());

    // compare the transformed values with the expected ones
    Row[] filteredTitles = result.select(outputColumn).collect();
    assertEquals("others", filteredTitles[0].get(0));
    assertEquals("Nike Airmax 2015", filteredTitles[1].get(0));
    assertEquals("Xiaomi Redmi Note", filteredTitles[2].get(0));
}
 
Example #5
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Splits the tagged sentences of a file 90/10 at random, trains a model on the larger part,
 * saves it, and prints the accuracy on both parts.
 *
 * @param inputFileName file read line by line as the data set
 * @param numFeatures   number of features passed to the training parameters
 * @param modelFileName passed on to {@code train} as the model location
 */
void testRandomSplit(String inputFileName, int numFeatures, String modelFileName) {
	CMMParams trainingParams = new CMMParams()
		.setMaxIter(600)
		.setRegParam(1E-6)
		.setMarkovOrder(2)
		.setNumFeatures(numFeatures);

	JavaRDD<String> sentenceLines = jsc.textFile(inputFileName);
	DataFrame all = createDataFrame(sentenceLines.collect());
	DataFrame[] parts = all.randomSplit(new double[]{0.9, 0.1});

	DataFrame train = parts[0];
	System.out.println("Number of training sequences = " + train.count());
	DataFrame test = parts[1];
	System.out.println("Number of test sequences = " + test.count());

	// fit on the training part and persist the model
	cmmModel = train(train, modelFileName, trainingParams);

	// held-out accuracy first, then accuracy on the data the model was fitted on
	System.out.println("Test accuracy:");
	evaluate(test);
	System.out.println("Training accuracy:");
	evaluate(train);
}
 
Example #6
Source File: LogisticRegressionModelInfoAdapter.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the parameters of a Spark {@code LogisticRegressionModel} into its exportable model info.
 *
 * <p>The input key is fixed to {@code "features"} and the output keys to {@code "prediction"}
 * and {@code "probability"}.
 *
 * @param sparkLRModel the fitted Spark model
 * @param df           not read by this adapter
 * @return the populated model info
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo info = new LogisticRegressionModelInfo();

    // numeric model parameters
    info.setWeights(sparkLRModel.weights().toArray());
    info.setIntercept(sparkLRModel.intercept());
    info.setNumClasses(sparkLRModel.numClasses());
    info.setNumFeatures(sparkLRModel.numFeatures());
    info.setThreshold((double) sparkLRModel.getThreshold().get());

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add("features");
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add("prediction");
    outputs.add("probability");
    info.setOutputKeys(outputs);

    return info;
}
 
Example #7
Source File: JavaStocks.java    From spark-ts-examples with Apache License 2.0 6 votes vote down vote up
/**
 * Reads tab-separated observations from a text file into a DataFrame with the columns
 * {@code timestamp}, {@code symbol} and {@code price}.
 *
 * <p>For every line, fields 0, 1 and 2 are parsed as year, month and day of month (midnight in
 * the system default time zone), field 3 is the symbol and field 5 the price.
 *
 * @param sparkContext context used to read the text file
 * @param sqlContext   context used to create the DataFrame
 * @param path         location of the tab-separated input
 * @return a DataFrame with one row per input line
 */
private static DataFrame loadObservations(JavaSparkContext sparkContext, SQLContext sqlContext,
    String path) {
  JavaRDD<Row> rowRdd = sparkContext.textFile(path).map((String line) -> {
      String[] tokens = line.split("\t");
      int year = Integer.parseInt(tokens[0]);
      int month = Integer.parseInt(tokens[1]);
      // The day lives in its own field; the original re-read the month field here.
      int dayOfMonth = Integer.parseInt(tokens[2]);
      ZonedDateTime dt = ZonedDateTime.of(year, month, dayOfMonth, 0, 0, 0, 0,
          ZoneId.systemDefault());
      String symbol = tokens[3];
      double price = Double.parseDouble(tokens[5]);
      return RowFactory.create(Timestamp.from(dt.toInstant()), symbol, price);
  });
  List<StructField> fields = new ArrayList<>();
  fields.add(DataTypes.createStructField("timestamp", DataTypes.TimestampType, true));
  fields.add(DataTypes.createStructField("symbol", DataTypes.StringType, true));
  fields.add(DataTypes.createStructField("price", DataTypes.DoubleType, true));
  StructType schema = DataTypes.createStructType(fields);
  return sqlContext.createDataFrame(rowRdd, schema);
}
 
Example #8
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 6 votes vote down vote up
/**
 * Tags every sentence of a list with the currently loaded model.
 *
 * @param sentences the sentences to tag, one string per sentence
 * @return the collected result strings, or {@code null} (after printing a message to
 *         standard error) when no model has been created or loaded
 */
public List<String> tag(List<String> sentences) {
	// one single-column row per sentence
	List<Row> sentenceRows = new LinkedList<Row>();
	for (String text : sentences) {
		sentenceRows.add(RowFactory.create(text));
	}
	StructField sentenceField = new StructField("sentence", DataTypes.StringType, false, Metadata.empty());
	StructType schema = new StructType(new StructField[]{ sentenceField });
	DataFrame input = new SQLContext(jsc).createDataFrame(sentenceRows, schema);

	if (cmmModel == null) {
		System.err.println("Tagging model is null. You need to create or load a model first.");
		return null;
	}
	DataFrame tagged = cmmModel.transform(input).repartition(1);
	return tagged.javaRDD().map(new RowToStringFunction(1)).collect();
}
 
Example #9
Source File: LogisticRegressionModelInfoAdapter1.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the parameters of a Spark {@code LogisticRegressionModel} into its exportable model
 * info, taking the input and output keys from the model's own column settings.
 *
 * @param sparkLRModel the fitted Spark model
 * @param df           not read by this adapter
 * @return the populated model info
 */
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo info = new LogisticRegressionModelInfo();

    // numeric model parameters
    info.setWeights(sparkLRModel.coefficients().toArray());
    info.setIntercept(sparkLRModel.intercept());
    info.setNumClasses(sparkLRModel.numClasses());
    info.setNumFeatures(sparkLRModel.numFeatures());
    info.setThreshold(sparkLRModel.getThreshold());
    info.setProbabilityKey(sparkLRModel.getProbabilityCol());

    // input: the features column
    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(sparkLRModel.getFeaturesCol());
    info.setInputKeys(inputs);

    // outputs: prediction column first, then probability column
    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(sparkLRModel.getPredictionCol());
    outputs.add(sparkLRModel.getProbabilityCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #10
Source File: CustomOneHotEncoderModelInfoAdapter.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the vector size and column names of a {@code CustomOneHotEncoderModel} into its
 * exportable model info.
 *
 * @param from the encoder model to read from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public OneHotEncoderModelInfo getModelInfo(final CustomOneHotEncoderModel from, DataFrame df) {
    final OneHotEncoderModelInfo info = new OneHotEncoderModelInfo();
    info.setNumTypes(from.vectorSize());

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #11
Source File: IfZeroVectorModelInfoAdapter.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the column names and the then/else settings of an {@code IfZeroVector} transformer
 * into its exportable model info.
 *
 * @param from the transformer to read from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public IfZeroVectorModelInfo getModelInfo(final IfZeroVector from, DataFrame df) {
    final IfZeroVectorModelInfo info = new IfZeroVectorModelInfo();

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    // replacement value and fallback column
    info.setThenSetValue(from.getThenSetValue());
    info.setElseSetCol(from.getElseSetCol());

    return info;
}
 
Example #12
Source File: CountVectorizerModelInfoAdapter.java    From spark-transformers with Apache License 2.0 6 votes vote down vote up
/**
 * Copies the minimum term frequency, vocabulary and column names of a
 * {@code CountVectorizerModel} into its exportable model info.
 *
 * @param from the vectorizer model to read from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from, final DataFrame df) {
    final CountVectorizerModelInfo info = new CountVectorizerModelInfo();
    info.setMinTF(from.getMinTF());
    info.setVocabulary(from.vocabulary());

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #13
Source File: BucketizerModelInfoAdapter.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Copies the split points and column names of a {@code Bucketizer} into its exportable model info.
 *
 * @param from the bucketizer to read from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public BucketizerModelInfo getModelInfo(final Bucketizer from, final DataFrame df) {
    final BucketizerModelInfo info = new BucketizerModelInfo();
    info.setSplits(from.getSplits());

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #14
Source File: AbstractJavaEsSparkSQLTest.java    From elasticsearch-hadoop with Apache License 2.0 5 votes vote down vote up
/**
 * Saves the artists DataFrame with the {@code id} column configured as the mapping id, then
 * checks that the target exists, that a search response contains "345", and that the
 * document endpoint {@code /1} exists.
 */
@Test
public void testEsdataFrame1WriteWithId() throws Exception {
	final String index = "sparksql-test-scala-basic-write-id-mapping";

	DataFrame artists = artistsAsDataFrame();
	String target = resource(index, "data", version);
	String documentPath = docEndpoint(index, "data", version);

	JavaEsSparkSQL.saveToEs(artists, target, ImmutableMap.of(ES_MAPPING_ID, "id"));

	assertTrue(RestUtils.exists(target));
	String searchResponse = RestUtils.get(target + "/_search?");
	assertThat(searchResponse, containsString("345"));
	assertThat(RestUtils.exists(documentPath + "/1"), is(true));
}
 
Example #15
Source File: RandomForestRegressionModelInfoAdapterBridgeTest.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a random forest regressor in Spark, exports and re-imports it as a
 * {@code Transformer}, and checks that both give the same prediction (within
 * {@code EPSILON}) for every row of the held-out split.
 */
@Test
public void testRandomForestRegression() {
    // LIBSVM data, split 70/30 into training and test rows
    DataFrame all = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");
    DataFrame[] parts = all.randomSplit(new double[]{0.7, 0.3});
    DataFrame train = parts[0];
    DataFrame test = parts[1];

    // fit in Spark
    RandomForestRegressionModel sparkModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(train);

    // round-trip through the exporter / importer
    byte[] serialized = ModelExporter.export(sparkModel, null);
    Transformer transformer = ModelImporter.importAndGetTransformer(serialized);

    Row[] sparkRows = sparkModel.transform(test).select("features", "prediction").collect();

    // compare Spark's prediction with the imported transformer's, row by row
    for (int r = 0; r < sparkRows.length; r++) {
        Row sparkRow = sparkRows[r];
        Vector features = (Vector) sparkRow.get(0);
        double sparkPrediction = sparkRow.getDouble(1);

        Map<String, Object> record = new HashMap<String, Object>();
        record.put(transformer.getInputKeys().iterator().next(), features.toArray());
        transformer.transform(record);
        double transformerPrediction = (double) record.get(transformer.getOutputKeys().iterator().next());

        System.out.println(sparkPrediction + ", " + transformerPrediction);
        assertEquals(sparkPrediction, transformerPrediction, EPSILON);
    }
}
 
Example #16
Source File: CMMModel.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Loads a {@code CMMModel} that was saved under {@code path}.
 *
 * <p>The pipeline model is read from the {@code pipelineModel} sub-directory and the model
 * data (Markov order, weight vector, tag dictionary) from a parquet file in the {@code data}
 * sub-directory. The Markov order is stored as a 1-based index into {@code MarkovOrder.values()}.
 *
 * @param path root directory of the saved model
 * @return the reconstructed model with its parameters restored from the saved metadata
 */
@Override
public CMMModel load(String path) {
	org.apache.spark.ml.util.DefaultParamsReader.Metadata metadata = DefaultParamsReader.loadMetadata(path, sc(), CMMModel.class.getName());
	String pipelinePath = new Path(path, "pipelineModel").toString();
	PipelineModel pipelineModel = PipelineModel.load(pipelinePath);
	String dataPath = new Path(path, "data").toString();
	DataFrame df = sqlContext().read().format("parquet").load(dataPath);
	Row row = df.select("markovOrder", "weights", "tagDictionary").head();
	// load the Markov order (stored 1-based)
	MarkovOrder order = MarkovOrder.values()[row.getInt(0)-1];
	// load the weight vector
	Vector w = row.getAs(1);
	// load the tag dictionary
	@SuppressWarnings("unchecked")
	scala.collection.immutable.HashMap<String, WrappedArray<Integer>> td = (scala.collection.immutable.HashMap<String, WrappedArray<Integer>>)row.get(2);
	Map<String, Set<Integer>> tagDict = new HashMap<String, Set<Integer>>();
	Iterator<Tuple2<String, WrappedArray<Integer>>> iterator = td.iterator();
	while (iterator.hasNext()) {
		Tuple2<String, WrappedArray<Integer>> tuple = iterator.next();
		WrappedArray<Integer> tagIds = tuple._2();
		Set<Integer> labels = new HashSet<Integer>();
		// Index the array directly: converting to a linked List and calling apply(i)
		// walks the list from its head on every access.
		for (int i = 0, n = tagIds.length(); i < n; i++) {
			labels.add(tagIds.apply(i));
		}
		tagDict.put(tuple._1(), labels);
	}
	// build a CMM model
	CMMModel model = new CMMModel(pipelineModel, w, order, tagDict);
	DefaultParamsReader.getAndSetParams(model, metadata);
	return model;
}
 
Example #17
Source File: LogisticRegression1ExporterTest.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a logistic regression model in Spark, exports it, imports the model info back and
 * checks that intercept, class count, feature count, threshold, weights and the first
 * input/output keys match the Spark model.
 */
@Test
public void shouldExportAndImportCorrectly() {
    // training data in LIBSVM format
    final String trainingFile = "src/test/resources/binary_classification_test.libsvm";
    DataFrame training = sqlContext.read().format("libsvm").load(trainingFile);

    // fit in Spark, then export and read the model info back
    LogisticRegressionModel sparkModel = new LogisticRegression().fit(training);
    byte[] serialized = ModelExporter.export(sparkModel, training);
    LogisticRegressionModelInfo imported = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(serialized);

    // scalar fields must agree within EPSILON
    assertEquals(sparkModel.intercept(), imported.getIntercept(), EPSILON);
    assertEquals(sparkModel.numClasses(), imported.getNumClasses(), EPSILON);
    assertEquals(sparkModel.numFeatures(), imported.getNumFeatures(), EPSILON);
    assertEquals(sparkModel.getThreshold(), imported.getThreshold(), EPSILON);

    // weights are compared position by position
    int featureIndex = 0;
    while (featureIndex < imported.getNumFeatures()) {
        assertEquals(sparkModel.weights().toArray()[featureIndex], imported.getWeights()[featureIndex], EPSILON);
        featureIndex++;
    }

    // first input key is the features column, first output key the prediction column
    assertEquals(sparkModel.getFeaturesCol(), imported.getInputKeys().iterator().next());
    assertEquals(sparkModel.getPredictionCol(), imported.getOutputKeys().iterator().next());
}
 
Example #18
Source File: TransitionClassifier.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Decodes every dependency graph into parsing contexts and turns the combined list into a
 * data frame.
 *
 * @param jsc          Spark context used to parallelize the decoded contexts
 * @param graphs       the dependency graphs to decode
 * @param featureFrame passed through to the decoder for each graph
 * @return a data frame built from the {@code ParsingContext} objects of all graphs
 */
private DataFrame toDataFrame(JavaSparkContext jsc, List<DependencyGraph> graphs, FeatureFrame featureFrame) {
	List<ParsingContext> contexts = new ArrayList<ParsingContext>();
	for (DependencyGraph g : graphs) {
		contexts.addAll(TransitionDecoder.decode(g, featureFrame));
	}
	JavaRDD<ParsingContext> contextRDD = jsc.parallelize(contexts);
	return sqlContext.createDataFrame(contextRDD, ParsingContext.class);
}
 
Example #19
Source File: DataSparkFromRDD.java    From toolbox with Apache License 2.0 5 votes vote down vote up
/**
 * Converts the wrapped RDD into a Spark DataFrame.
 *
 * @param sql the SQL context used to create the DataFrame
 * @return a DataFrame whose schema is derived from {@code attributes} and whose rows come
 *         from {@code amidstRDD}
 */
@Override
public DataFrame getDataFrame(SQLContext sql) {
    // schema first, then the rows, then combine both
    final StructType structType = SchemaConverter.getSchema(attributes);
    final JavaRDD<Row> rows = DataFrameOps.toRowRDD(amidstRDD, attributes);
    return sql.createDataFrame(rows, structType);
}
 
Example #20
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
	 * Tags a data frame and writes the result in the requested format.
	 * <p>
	 * When no model is available a message is printed to standard error and nothing is written.
	 * In verbose mode the sentence count and timing figures are printed in either case; the
	 * measured duration covers the calls to {@code transform} and {@code repartition} only.
	 *
	 * @param input          data frame to tag (the original documentation states it contains a
	 *                       column named 'sentence')
	 * @param outputFileName destination passed to the writer
	 * @param outputFormat   one of JSON, PARQUET or TEXT; other values write nothing
	 */
	public void tag(DataFrame input, String outputFileName, OutputFormat outputFormat) {
		final long start = System.currentTimeMillis();
		long elapsed = 0;
		if (cmmModel == null) {
			System.err.println("Tagging model is null. You need to create or load a model first.");
		} else {
			DataFrame tagged = cmmModel.transform(input).repartition(1);
			elapsed = System.currentTimeMillis() - start;
			switch (outputFormat) {
			case JSON:
				tagged.write().json(outputFileName);
				break;
			case PARQUET:
				tagged.write().parquet(outputFileName);
				break;
			case TEXT:
				toTaggedSentence(tagged).repartition(1).saveAsTextFile(outputFileName);
				break;
			}
		}
		if (verbose) {
			long sentenceCount = input.count();
			float averageMillis = ((float) elapsed) / sentenceCount;
			System.out.println(" Number of sentences = " + sentenceCount);
			System.out.println("  Total tagging time = " + elapsed + " milliseconds.");
			System.out.println("Average tagging time = " + averageMillis + " milliseconds.");
		}
	}
 
Example #21
Source File: Log1PScalerModelInfoAdapter.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Copies the input and output column names of a {@code Log1PScaler} into its exportable
 * model info.
 *
 * @param from the scaler to read from
 * @param df   not read by this adapter
 * @return the populated model info
 */
@Override
public Log1PScalerModelInfo getModelInfo(final Log1PScaler from, DataFrame df) {
    final Log1PScalerModelInfo info = new Log1PScalerModelInfo();

    final Set<String> inputs = new LinkedHashSet<String>();
    inputs.add(from.getInputCol());
    info.setInputKeys(inputs);

    final Set<String> outputs = new LinkedHashSet<String>();
    outputs.add(from.getOutputCol());
    info.setOutputKeys(outputs);

    return info;
}
 
Example #22
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Creates a data frame from a list of tagged sentences.
 *
 * <p>Each sentence is split on whitespace into tokens, and each token on {@code '/'}. A token
 * with exactly two parts contributes its first part to the word sequence and its second part
 * to the tag sequence; any other token contributes {@code "/"} to both. Words and tags are
 * joined with single spaces.
 *
 * @param taggedSentences the tagged sentences, one string per sentence
 * @return a data frame of two columns: "sentence" and "partOfSpeech".
 */
public DataFrame createDataFrame(List<String> taggedSentences) {
	// Rows are built in the same pass that parses the sentences. The original kept two
	// parallel LinkedLists and read them back with get(i), which is quadratic in the
	// number of sentences.
	List<Row> rows = new LinkedList<Row>();
	for (String taggedSentence : taggedSentences) {
		StringBuilder wordBuf = new StringBuilder();
		StringBuilder tagBuf = new StringBuilder();
		String[] tokens = taggedSentence.split("\\s+");
		for (String token : tokens) {
			String[] parts = token.split("/");
			if (parts.length == 2) {
				wordBuf.append(parts[0]);
				wordBuf.append(' ');
				tagBuf.append(parts[1]);
				tagBuf.append(' ');
			} else { // this token is "///"
				wordBuf.append('/');
				wordBuf.append(' ');
				tagBuf.append('/');
				tagBuf.append(' ');
			}
		}
		rows.add(RowFactory.create(wordBuf.toString().trim(), tagBuf.toString().trim()));
	}
	if (verbose) {
		System.out.println("Number of sentences = " + rows.size());
	}
	JavaRDD<Row> jrdd = jsc.parallelize(rows);
	StructType schema = new StructType(new StructField[]{
			new StructField("sentence", DataTypes.StringType, false, Metadata.empty()),
			new StructField("partOfSpeech", DataTypes.StringType, false, Metadata.empty())
		});

	return new SQLContext(jsc).createDataFrame(jrdd, schema);
}
 
Example #23
Source File: StringIndexerBridgeTest.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Fits a {@code StringIndexer} in Spark on six labelled rows, exports and re-imports it as a
 * {@code Transformer}, and checks that both assign the same index to every label.
 */
@Test
public void testStringIndexer() {
    // six (id, label) rows
    StructField idField = createStructField("id", IntegerType, false);
    StructField labelField = createStructField("label", StringType, false);
    StructType schema = createStructType(new StructField[]{ idField, labelField });
    List<Row> labelledRows = Arrays.asList(
            cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
    DataFrame dataset = sqlContext.createDataFrame(labelledRows, schema);

    // fit in Spark
    StringIndexerModel sparkIndexer = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex").fit(dataset);

    // round-trip through the exporter / importer
    byte[] serialized = ModelExporter.export(sparkIndexer, dataset);
    Transformer transformer = ModelImporter.importAndGetTransformer(serialized);

    // compare the index computed by Spark with the transformer's, row by row
    Row[] sparkRows = sparkIndexer.transform(dataset).orderBy("id").select("id", "label", "labelIndex").collect();
    for (int r = 0; r < sparkRows.length; r++) {
        Row sparkRow = sparkRows[r];

        Map<String, Object> record = new HashMap<String, Object>();
        record.put(sparkIndexer.getInputCol(), (String) sparkRow.get(1));
        transformer.transform(record);
        double transformerIndex = (double) record.get(sparkIndexer.getOutputCol());

        assertEquals(transformerIndex, (double) sparkRow.get(2), EPSILON);
    }
}
 
Example #24
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Fits a tagging model on a data frame, keeps it as the current model and saves it.
 *
 * <p>According to the original documentation the data frame has two columns: "sentence" with
 * a word sequence and "partOfSpeech" with the corresponding tag sequence, one tagged
 * sequence per row. An {@code IOException} raised while saving is printed and otherwise
 * ignored, so the fitted model is returned even when it could not be written.
 *
 * @param dataset       the training data
 * @param modelFileName location the fitted model is written to, overwriting existing content
 * @param params        training parameters
 * @return the fitted {@link CMMModel}
 */
public CMMModel train(DataFrame dataset, String modelFileName, CMMParams params) {
	CMM estimator = new CMM(params).setVerbose(verbose);
	cmmModel = estimator.fit(dataset);
	try {
		cmmModel.write().overwrite().save(modelFileName);
	} catch (IOException saveFailure) {
		saveFailure.printStackTrace();
	}
	return cmmModel;
}
 
Example #25
Source File: AbstractJavaEsSparkSQLTest.java    From elasticsearch-hadoop with Apache License 2.0 5 votes vote down vote up
/**
 * Saves the artists DataFrame with the {@code url} field configured as excluded, then checks
 * that the target exists and that a search response does not contain "url".
 */
@Test
public void testEsSchemaRDD1WriteWithMappingExclude() throws Exception {
	DataFrame artists = artistsAsDataFrame();
	String target = resource("sparksql-test-scala-basic-write-exclude-mapping", "data", version);

	JavaEsSparkSQL.saveToEs(artists, target, ImmutableMap.of(ES_MAPPING_EXCLUDE, "url"));

	assertTrue(RestUtils.exists(target));
	String searchResponse = RestUtils.get(target + "/_search?");
	assertThat(searchResponse, not(containsString("url")));
}
 
Example #26
Source File: InstanceRelationWriterTest.java    From rdf2x with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a two-row relation DataFrame used as test input: instance 1 "knows" instance 3 and
 * instance 2 "likes" instance 3, both going from type {@code a} to type {@code b}.
 *
 * @return a DataFrame with the columns predicateIndex, fromTypeIndex, fromID, toTypeIndex, toID
 */
private DataFrame getTestRelations() {
    List<Row> relations = new ArrayList<>();

    // a(1) --knows--> b(3)
    Row knows = RowFactory.create(
            uriIndex.getIndex("http://example.com/knows"),
            uriIndex.getIndex("http://example.com/a"),
            1L,
            uriIndex.getIndex("http://example.com/b"),
            3L);
    relations.add(knows);

    // a(2) --likes--> b(3)
    Row likes = RowFactory.create(
            uriIndex.getIndex("http://example.com/likes"),
            uriIndex.getIndex("http://example.com/a"),
            2L,
            uriIndex.getIndex("http://example.com/b"),
            3L);
    relations.add(likes);

    StructType schema = new StructType()
            .add("predicateIndex", DataTypes.IntegerType, false)
            .add("fromTypeIndex", DataTypes.IntegerType, false)
            .add("fromID", DataTypes.LongType, false)
            .add("toTypeIndex", DataTypes.IntegerType, false)
            .add("toID", DataTypes.LongType, false);
    return sql.createDataFrame(relations, schema);
}
 
Example #27
Source File: EntitySalienceFeatureExtractorSpark.java    From ambiverse-nlu with Apache License 2.0 5 votes vote down vote up
/**
 * Extracts a DataFrame of training instances from a set of documents.
 *
 * <p>Three accumulators are registered and updated while the RDDs are evaluated: the number
 * of documents, the number of instances whose label equals 1.0, and the number of all other
 * instances. Their values are not read inside this method.
 *
 * @param jsc        Spark context used to create the accumulators
 * @param documents  the documents to extract training instances from
 * @param sqlContext context used to create the resulting DataFrame
 * @return a DataFrame with the columns docId, entityId, label and features
 * @throws ResourceInitializationException declared by the original signature
 */
public DataFrame extract(JavaSparkContext jsc, JavaRDD<SCAS> documents, SQLContext sqlContext) throws ResourceInitializationException {
    final Accumulator<Integer> docCounter = jsc.accumulator(0, "TOTAL_DOCS");
    final Accumulator<Integer> salientCounter = jsc.accumulator(0, "SALIENT_ENTITY_INSTANCES");
    final Accumulator<Integer> nonSalientCounter = jsc.accumulator(0, "NON_SALIENT_ENTITY_INSTANCES");

    final TrainingSettings settings = getTrainingSettings();

    final FeatureExtractor extractor = new NYTEntitySalienceFeatureExtractor();
    final int featureVectorSize = FeatureSetFactory.createFeatureSet(TrainingSettings.FeatureExtractor.ENTITY_SALIENCE).getFeatureVectorSize();

    // one pass over the documents: count each document and emit its training instances
    JavaRDD<TrainingInstance> instances = documents.flatMap(doc -> {
        docCounter.add(1);
        return extractor.getTrainingInstances(doc.getJCas(),
                settings.getFeatureExtractor(),
                settings.getPositiveInstanceScalingFactor());
    });

    StructField docIdField = new StructField("docId", DataTypes.StringType, false, Metadata.empty());
    StructField entityIdField = new StructField("entityId", DataTypes.StringType, false, Metadata.empty());
    StructField labelField = new StructField("label", DataTypes.DoubleType, false, Metadata.empty());
    StructField featuresField = new StructField("features", new VectorUDT(), false, Metadata.empty());
    StructType schema = new StructType(new StructField[]{ docIdField, entityIdField, labelField, featuresField });

    // count the instance by label, then convert it to a row with its feature vector
    JavaRDD<Row> rows = instances.map(instance -> {
        if (instance.getLabel() == 1.0) {
            salientCounter.add(1);
        } else {
            nonSalientCounter.add(1);
        }
        Vector features = FeatureValueInstanceUtils.convertToSparkMLVector(instance, featureVectorSize);
        return RowFactory.create(instance.getDocId(), instance.getEntityId(), instance.getLabel(), features);
    });

    return sqlContext.createDataFrame(rows, schema);
}
 
Example #28
Source File: InstanceRelationWriterTest.java    From rdf2x with Apache License 2.0 5 votes vote down vote up
/**
 * Writes the test relations with predicate storage switched off and checks the schema and
 * rows of the first table found in the collected results.
 */
@Test
public void testWriteRelationTablesWithoutPredicateIndex() throws IOException {
    InstanceRelationWriter relationWriter = new InstanceRelationWriter(
            config.setStorePredicate(false), jsc(), persistor, rdfSchema);
    relationWriter.writeRelationTables(getTestRelationSchema(), getTestRelations());

    // expected (fromID, toID) pairs
    List<Row> expectedRows = new ArrayList<>();
    expectedRows.add(RowFactory.create(1L, 3L));
    expectedRows.add(RowFactory.create(2L, 3L));

    DataFrame written = this.result.values().iterator().next();
    assertEquals("Expected schema of A_B was extracted", getExpectedSchemaOfAB(false, false), written.schema());
    assertRDDEquals("Expected rows of A_B were extracted", jsc().parallelize(expectedRows), written.toJavaRDD());
}
 
Example #29
Source File: LogisticRegression1BridgeTest.java    From spark-transformers with Apache License 2.0 5 votes vote down vote up
/**
 * Trains a logistic regression model in Spark, exports and re-imports it as a
 * {@code Transformer}, and checks that both give the same prediction (within
 * {@code EPSILON}) for every point of the data file.
 */
@Test
public void testLogisticRegression() {
    // training data in LIBSVM format
    final String dataFile = "src/test/resources/binary_classification_test.libsvm";
    DataFrame training = sqlContext.read().format("libsvm").load(dataFile);

    // fit in Spark
    LogisticRegressionModel sparkModel = new LogisticRegression().fit(training);

    // round-trip through the exporter / importer
    byte[] serialized = ModelExporter.export(sparkModel, training);
    Transformer transformer = ModelImporter.importAndGetTransformer(serialized);

    // compare Spark's prediction with the imported transformer's, point by point
    List<LabeledPoint> points = MLUtils.loadLibSVMFile(sc.sc(), dataFile).toJavaRDD().collect();
    for (LabeledPoint point : points) {
        Vector features = point.features();
        double sparkPrediction = sparkModel.predict(features);

        Map<String, Object> record = new HashMap<String, Object>();
        record.put("features", features.toArray());
        transformer.transform(record);
        double transformerPrediction = (double) record.get("prediction");

        assertEquals(sparkPrediction, transformerPrediction, EPSILON);
    }
}
 
Example #30
Source File: Tagger.java    From vn.vitk with GNU General Public License v3.0 5 votes vote down vote up
/**
 * Wraps a distributed collection of sentence rows in a single-column data frame and
 * delegates to {@code tag(DataFrame, String, OutputFormat)}.
 *
 * @param sentences      rows to tag; they are given a schema with one non-nullable string
 *                       column named "sentence"
 * @param outputFileName destination passed on to the delegate
 * @param outputFormat   output format passed on to the delegate
 */
public void tag(JavaRDD<Row> sentences, String outputFileName, OutputFormat outputFormat) {
	StructField sentenceField = new StructField("sentence", DataTypes.StringType, false, Metadata.empty());
	StructType schema = new StructType(new StructField[]{ sentenceField });
	DataFrame input = new SQLContext(jsc).createDataFrame(sentences, schema);
	tag(input, outputFileName, outputFormat);
}