org.apache.spark.sql.DataFrame Java Examples
The following examples show how to use org.apache.spark.sql.DataFrame.
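Before the project examples, here is a minimal, self-contained sketch of typical DataFrame usage with the Spark 1.x Java API that these examples target. The class name, data, and column names are illustrative assumptions rather than code from any project below; it assumes a Spark 1.6-style SQLContext.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Minimal sketch (hypothetical class and data): create a DataFrame from
// in-memory rows with an explicit schema, filter it, and print the result.
public class DataFrameQuickStart {
    public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("DataFrameQuickStart").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        // A small in-memory dataset with an explicit schema.
        List<Row> rows = Arrays.asList(
                RowFactory.create("alice", 34),
                RowFactory.create("bob", 29));
        StructType schema = new StructType(new StructField[]{
                new StructField("name", DataTypes.StringType, false, Metadata.empty()),
                new StructField("age", DataTypes.IntegerType, false, Metadata.empty())
        });
        DataFrame df = sqlContext.createDataFrame(rows, schema);

        // Typical DataFrame operations: filter, select, show.
        df.filter(df.col("age").gt(30)).select("name").show();

        jsc.stop();
    }
}

The examples that follow use the same SQLContext/DataFrame entry points, typically obtained from an existing JavaSparkContext or HiveContext.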
Example #1
Source File: SparkDataSourceManager.java From DDF with Apache License 2.0
@Override
public DDF loadFromJDBC(JDBCDataSourceDescriptor dataSource) throws DDFException {
    SparkDDFManager sparkDDFManager = (SparkDDFManager) mDDFManager;
    HiveContext sqlContext = sparkDDFManager.getHiveContext();
    JDBCDataSourceCredentials cred = (JDBCDataSourceCredentials) dataSource.getDataSourceCredentials();
    String fullURL = dataSource.getDataSourceUri().getUri().toString();
    if (cred.getUsername() != null && !cred.getUsername().equals("")) {
        fullURL += String.format("?user=%s&password=%s", cred.getUsername(), cred.getPassword());
    }

    Map<String, String> options = new HashMap<String, String>();
    options.put("url", fullURL);
    options.put("dbtable", dataSource.getDbTable());

    DataFrame df = sqlContext.load("jdbc", options);

    DDF ddf = sparkDDFManager.newDDF(sparkDDFManager, df, new Class<?>[]{DataFrame.class},
        null, SparkUtils.schemaFromDataFrame(df));
    // TODO?
    ddf.getRepresentationHandler().get(RDD.class, Row.class);
    ddf.getMetaDataHandler().setDataSourceDescriptor(dataSource);
    return ddf;
}
Example #2
Source File: FillNAValuesTransformerBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void shouldBehaveExactlyAsSparkNAFillerForAllSupportedDataTypes() {
    DataFrame df = getDataFrame();
    DataFrame df1 = df.na().fill(getFillNAMap());

    FillNAValuesTransformer fillNAValuesTransformer = new FillNAValuesTransformer();
    fillNAValuesTransformer.setNAValueMap(getFillNAMap());
    DataFrame df2 = fillNAValuesTransformer.transform(df);

    Row[] data1 = df1.orderBy("id").select("id", "a", "b", "c", "d").collect();
    Row[] data2 = df2.orderBy("id").select("id", "a", "b", "c", "d").collect();

    for (int i = 0; i < data1.length; i++) {
        for (int j = 1; j <= 4; j++) {
            assertEquals(data1[i].get(j), data2[i].get(j));
        }
    }
}
Example #3
Source File: ProbabilityTransformModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public ProbabilityTransformModelInfo getModelInfo(final ProbabilityTransformModel from, DataFrame df) {
    ProbabilityTransformModelInfo modelInfo = new ProbabilityTransformModelInfo();

    modelInfo.setActualClickProportion(from.getActualClickProportion());
    modelInfo.setUnderSampledClickProportion(from.getUnderSampledClickProportion());
    modelInfo.setProbIndex(from.getProbIndex());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
Example #4
Source File: IfZeroVectorBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testIfZeroVectorSparse() {
    IfZeroVector sparkModel = new IfZeroVector()
            .setInputCol("vectorized_count")
            .setOutputCol("product_title_filtered")
            .setThenSetValue("others")
            .setElseSetCol("product_title");

    System.out.println(sparseOrderDF.schema());
    DataFrame transformed = sparkModel.transform(sparseOrderDF).orderBy("product_title");
    System.out.println(transformed.schema());

    // compare predictions
    Row[] sparkOutput = transformed.select("product_title_filtered").collect();
    assertEquals("others", sparkOutput[0].get(0));
    assertEquals("Nike Airmax 2015", sparkOutput[1].get(0));
    assertEquals("Xiaomi Redmi Note", sparkOutput[2].get(0));
}
Example #5
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
void testRandomSplit(String inputFileName, int numFeatures, String modelFileName) {
    CMMParams params = new CMMParams()
            .setMaxIter(600)
            .setRegParam(1E-6)
            .setMarkovOrder(2)
            .setNumFeatures(numFeatures);

    JavaRDD<String> lines = jsc.textFile(inputFileName);
    DataFrame dataset = createDataFrame(lines.collect());
    DataFrame[] splits = dataset.randomSplit(new double[]{0.9, 0.1});
    DataFrame trainingData = splits[0];
    System.out.println("Number of training sequences = " + trainingData.count());
    DataFrame testData = splits[1];
    System.out.println("Number of test sequences = " + testData.count());
    // train and save a model on the training data
    cmmModel = train(trainingData, modelFileName, params);
    // test the model on the test data
    System.out.println("Test accuracy:");
    evaluate(testData);
    // test the model on the training data
    System.out.println("Training accuracy:");
    evaluate(trainingData);
}
Example #6
Source File: LogisticRegressionModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.weights().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold((double) sparkLRModel.getThreshold().get());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add("features");
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add("prediction");
    outputKeys.add("probability");
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
Example #7
Source File: JavaStocks.java From spark-ts-examples with Apache License 2.0
private static DataFrame loadObservations(JavaSparkContext sparkContext, SQLContext sqlContext, String path) {
    JavaRDD<Row> rowRdd = sparkContext.textFile(path).map((String line) -> {
        String[] tokens = line.split("\t");
        // year, month, day of month
        ZonedDateTime dt = ZonedDateTime.of(Integer.parseInt(tokens[0]), Integer.parseInt(tokens[1]),
                Integer.parseInt(tokens[2]), 0, 0, 0, 0, ZoneId.systemDefault());
        String symbol = tokens[3];
        double price = Double.parseDouble(tokens[5]);
        return RowFactory.create(Timestamp.from(dt.toInstant()), symbol, price);
    });
    List<StructField> fields = new ArrayList<>();
    fields.add(DataTypes.createStructField("timestamp", DataTypes.TimestampType, true));
    fields.add(DataTypes.createStructField("symbol", DataTypes.StringType, true));
    fields.add(DataTypes.createStructField("price", DataTypes.DoubleType, true));
    StructType schema = DataTypes.createStructType(fields);
    return sqlContext.createDataFrame(rowRdd, schema);
}
Example #8
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
/**
 * Tags a list of sequences and returns a list of tag sequences.
 * @param sentences
 * @return a list of tagged sequences.
 */
public List<String> tag(List<String> sentences) {
    List<Row> rows = new LinkedList<Row>();
    for (String sentence : sentences) {
        rows.add(RowFactory.create(sentence));
    }
    StructType schema = new StructType(new StructField[]{
        new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
    });
    SQLContext sqlContext = new SQLContext(jsc);
    DataFrame input = sqlContext.createDataFrame(rows, schema);
    if (cmmModel != null) {
        DataFrame output = cmmModel.transform(input).repartition(1);
        return output.javaRDD().map(new RowToStringFunction(1)).collect();
    } else {
        System.err.println("Tagging model is null. You need to create or load a model first.");
        return null;
    }
}
Example #9
Source File: LogisticRegressionModelInfoAdapter1.java From spark-transformers with Apache License 2.0
@Override
public LogisticRegressionModelInfo getModelInfo(final LogisticRegressionModel sparkLRModel, DataFrame df) {
    final LogisticRegressionModelInfo logisticRegressionModelInfo = new LogisticRegressionModelInfo();
    logisticRegressionModelInfo.setWeights(sparkLRModel.coefficients().toArray());
    logisticRegressionModelInfo.setIntercept(sparkLRModel.intercept());
    logisticRegressionModelInfo.setNumClasses(sparkLRModel.numClasses());
    logisticRegressionModelInfo.setNumFeatures(sparkLRModel.numFeatures());
    logisticRegressionModelInfo.setThreshold(sparkLRModel.getThreshold());
    logisticRegressionModelInfo.setProbabilityKey(sparkLRModel.getProbabilityCol());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(sparkLRModel.getFeaturesCol());
    logisticRegressionModelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(sparkLRModel.getPredictionCol());
    outputKeys.add(sparkLRModel.getProbabilityCol());
    logisticRegressionModelInfo.setOutputKeys(outputKeys);

    return logisticRegressionModelInfo;
}
Example #10
Source File: CustomOneHotEncoderModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public OneHotEncoderModelInfo getModelInfo(final CustomOneHotEncoderModel from, DataFrame df) {
    OneHotEncoderModelInfo modelInfo = new OneHotEncoderModelInfo();
    modelInfo.setNumTypes(from.vectorSize());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
Example #11
Source File: IfZeroVectorModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public IfZeroVectorModelInfo getModelInfo(final IfZeroVector from, DataFrame df) {
    IfZeroVectorModelInfo modelInfo = new IfZeroVectorModelInfo();

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    modelInfo.setThenSetValue(from.getThenSetValue());
    modelInfo.setElseSetCol(from.getElseSetCol());

    return modelInfo;
}
Example #12
Source File: CountVectorizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public CountVectorizerModelInfo getModelInfo(final CountVectorizerModel from, final DataFrame df) {
    final CountVectorizerModelInfo modelInfo = new CountVectorizerModelInfo();
    modelInfo.setMinTF(from.getMinTF());
    modelInfo.setVocabulary(from.vocabulary());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
Example #13
Source File: BucketizerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public BucketizerModelInfo getModelInfo(final Bucketizer from, final DataFrame df) {
    final BucketizerModelInfo modelInfo = new BucketizerModelInfo();
    modelInfo.setSplits(from.getSplits());

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
Example #14
Source File: AbstractJavaEsSparkSQLTest.java From elasticsearch-hadoop with Apache License 2.0
@Test
public void testEsdataFrame1WriteWithId() throws Exception {
    DataFrame dataFrame = artistsAsDataFrame();

    String target = resource("sparksql-test-scala-basic-write-id-mapping", "data", version);
    String docEndpoint = docEndpoint("sparksql-test-scala-basic-write-id-mapping", "data", version);

    JavaEsSparkSQL.saveToEs(dataFrame, target, ImmutableMap.of(ES_MAPPING_ID, "id"));

    assertTrue(RestUtils.exists(target));
    assertThat(RestUtils.get(target + "/_search?"), containsString("345"));
    assertThat(RestUtils.exists(docEndpoint + "/1"), is(true));
}
Example #15
Source File: RandomForestRegressionModelInfoAdapterBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testRandomForestRegression() {
    // Load the data stored in LIBSVM format as a DataFrame.
    DataFrame data = sqlContext.read().format("libsvm").load("src/test/resources/regression_test.libsvm");

    // Split the data into training and test sets (30% held out for testing)
    DataFrame[] splits = data.randomSplit(new double[]{0.7, 0.3});
    DataFrame trainingData = splits[0];
    DataFrame testData = splits[1];

    // Train a RandomForest model.
    RandomForestRegressionModel regressionModel = new RandomForestRegressor()
            .setFeaturesCol("features").fit(trainingData);

    byte[] exportedModel = ModelExporter.export(regressionModel, null);

    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    Row[] sparkOutput = regressionModel.transform(testData).select("features", "prediction").collect();

    // compare predictions
    for (Row row : sparkOutput) {
        Vector v = (Vector) row.get(0);
        double actual = row.getDouble(1);

        Map<String, Object> inputData = new HashMap<String, Object>();
        inputData.put(transformer.getInputKeys().iterator().next(), v.toArray());
        transformer.transform(inputData);
        double predicted = (double) inputData.get(transformer.getOutputKeys().iterator().next());

        System.out.println(actual + ", " + predicted);
        assertEquals(actual, predicted, EPSILON);
    }
}
Example #16
Source File: CMMModel.java From vn.vitk with GNU General Public License v3.0
@Override
public CMMModel load(String path) {
    org.apache.spark.ml.util.DefaultParamsReader.Metadata metadata =
        DefaultParamsReader.loadMetadata(path, sc(), CMMModel.class.getName());
    String pipelinePath = new Path(path, "pipelineModel").toString();
    PipelineModel pipelineModel = PipelineModel.load(pipelinePath);
    String dataPath = new Path(path, "data").toString();
    DataFrame df = sqlContext().read().format("parquet").load(dataPath);
    Row row = df.select("markovOrder", "weights", "tagDictionary").head();
    // load the Markov order
    MarkovOrder order = MarkovOrder.values()[row.getInt(0) - 1];
    // load the weight vector
    Vector w = row.getAs(1);
    // load the tag dictionary
    @SuppressWarnings("unchecked")
    scala.collection.immutable.HashMap<String, WrappedArray<Integer>> td =
        (scala.collection.immutable.HashMap<String, WrappedArray<Integer>>) row.get(2);
    Map<String, Set<Integer>> tagDict = new HashMap<String, Set<Integer>>();
    Iterator<Tuple2<String, WrappedArray<Integer>>> iterator = td.iterator();
    while (iterator.hasNext()) {
        Tuple2<String, WrappedArray<Integer>> tuple = iterator.next();
        Set<Integer> labels = new HashSet<Integer>();
        scala.collection.immutable.List<Integer> list = tuple._2().toList();
        for (int i = 0; i < list.size(); i++)
            labels.add(list.apply(i));
        tagDict.put(tuple._1(), labels);
    }
    // build a CMM model
    CMMModel model = new CMMModel(pipelineModel, w, order, tagDict);
    DefaultParamsReader.getAndSetParams(model, metadata);
    return model;
}
Example #17
Source File: LogisticRegression1ExporterTest.java From spark-transformers with Apache License 2.0
@Test
public void shouldExportAndImportCorrectly() {
    // prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    DataFrame trainingData = sqlContext.read().format("libsvm").load(datapath);

    // Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData);

    // Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, trainingData);

    // Import it back
    LogisticRegressionModelInfo importedModel = (LogisticRegressionModelInfo) ModelImporter.importModelInfo(exportedModel);

    // check if they are exactly equal with respect to their fields;
    // there may be edge cases, e.g. the order of elements in a list may change
    assertEquals(lrmodel.intercept(), importedModel.getIntercept(), EPSILON);
    assertEquals(lrmodel.numClasses(), importedModel.getNumClasses(), EPSILON);
    assertEquals(lrmodel.numFeatures(), importedModel.getNumFeatures(), EPSILON);
    assertEquals(lrmodel.getThreshold(), importedModel.getThreshold(), EPSILON);
    for (int i = 0; i < importedModel.getNumFeatures(); i++)
        assertEquals(lrmodel.weights().toArray()[i], importedModel.getWeights()[i], EPSILON);

    assertEquals(lrmodel.getFeaturesCol(), importedModel.getInputKeys().iterator().next());
    assertEquals(lrmodel.getPredictionCol(), importedModel.getOutputKeys().iterator().next());
}
Example #18
Source File: TransitionClassifier.java From vn.vitk with GNU General Public License v3.0
/**
 * Converts a list of dependency graphs to a data frame.
 * @param jsc
 * @param graphs
 * @param featureFrame
 * @return a data frame
 */
private DataFrame toDataFrame(JavaSparkContext jsc, List<DependencyGraph> graphs, FeatureFrame featureFrame) {
    List<ParsingContext> list = new ArrayList<ParsingContext>();
    for (DependencyGraph graph : graphs) {
        List<ParsingContext> xy = TransitionDecoder.decode(graph, featureFrame);
        list.addAll(xy);
    }
    JavaRDD<ParsingContext> javaRDD = jsc.parallelize(list);
    return sqlContext.createDataFrame(javaRDD, ParsingContext.class);
}
Example #19
Source File: DataSparkFromRDD.java From toolbox with Apache License 2.0
@Override
public DataFrame getDataFrame(SQLContext sql) {
    // Obtain the schema
    StructType schema = SchemaConverter.getSchema(attributes);

    // Transform the RDD
    JavaRDD<Row> rowRDD = DataFrameOps.toRowRDD(amidstRDD, attributes);

    // Create the DataFrame
    return sql.createDataFrame(rowRDD, schema);
}
Example #20
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
/**
 * Tags a data frame containing a column named 'sentence'.
 * @param input
 * @param outputFileName
 * @param outputFormat
 */
public void tag(DataFrame input, String outputFileName, OutputFormat outputFormat) {
    long tic = System.currentTimeMillis();
    long duration = 0;
    if (cmmModel != null) {
        DataFrame output = cmmModel.transform(input).repartition(1);
        duration = System.currentTimeMillis() - tic;
        switch (outputFormat) {
        case JSON:
            output.write().json(outputFileName);
            break;
        case PARQUET:
            output.write().parquet(outputFileName);
            break;
        case TEXT:
            toTaggedSentence(output).repartition(1).saveAsTextFile(outputFileName);
            // output.select("prediction").write().text(outputFileName);
            break;
        }
    } else {
        System.err.println("Tagging model is null. You need to create or load a model first.");
    }
    if (verbose) {
        long n = input.count();
        System.out.println(" Number of sentences = " + n);
        System.out.println(" Total tagging time = " + duration + " milliseconds.");
        System.out.println("Average tagging time = " + ((float) duration) / n + " milliseconds.");
    }
}
Example #21
Source File: Log1PScalerModelInfoAdapter.java From spark-transformers with Apache License 2.0
@Override
public Log1PScalerModelInfo getModelInfo(final Log1PScaler from, DataFrame df) {
    Log1PScalerModelInfo modelInfo = new Log1PScalerModelInfo();

    Set<String> inputKeys = new LinkedHashSet<String>();
    inputKeys.add(from.getInputCol());
    modelInfo.setInputKeys(inputKeys);

    Set<String> outputKeys = new LinkedHashSet<String>();
    outputKeys.add(from.getOutputCol());
    modelInfo.setOutputKeys(outputKeys);

    return modelInfo;
}
Example #22
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
/**
 * Creates a data frame from a list of tagged sentences.
 * @param taggedSentences
 * @return a data frame of two columns: "sentence" and "partOfSpeech".
 */
public DataFrame createDataFrame(List<String> taggedSentences) {
    List<String> wordSequences = new LinkedList<String>();
    List<String> tagSequences = new LinkedList<String>();
    for (String taggedSentence : taggedSentences) {
        StringBuilder wordBuf = new StringBuilder();
        StringBuilder tagBuf = new StringBuilder();
        String[] tokens = taggedSentence.split("\\s+");
        for (String token : tokens) {
            String[] parts = token.split("/");
            if (parts.length == 2) {
                wordBuf.append(parts[0]);
                wordBuf.append(' ');
                tagBuf.append(parts[1]);
                tagBuf.append(' ');
            } else { // this token is "///"
                wordBuf.append('/');
                wordBuf.append(' ');
                tagBuf.append('/');
                tagBuf.append(' ');
            }
        }
        wordSequences.add(wordBuf.toString().trim());
        tagSequences.add(tagBuf.toString().trim());
    }
    if (verbose) {
        System.out.println("Number of sentences = " + wordSequences.size());
    }
    List<Row> rows = new LinkedList<Row>();
    for (int i = 0; i < wordSequences.size(); i++) {
        rows.add(RowFactory.create(wordSequences.get(i), tagSequences.get(i)));
    }
    JavaRDD<Row> jrdd = jsc.parallelize(rows);
    StructType schema = new StructType(new StructField[]{
        new StructField("sentence", DataTypes.StringType, false, Metadata.empty()),
        new StructField("partOfSpeech", DataTypes.StringType, false, Metadata.empty())
    });
    return new SQLContext(jsc).createDataFrame(jrdd, schema);
}
Example #23
Source File: StringIndexerBridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testStringIndexer() {
    // prepare data
    StructType schema = createStructType(new StructField[]{
        createStructField("id", IntegerType, false),
        createStructField("label", StringType, false)
    });
    List<Row> trainingData = Arrays.asList(
        cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
    DataFrame dataset = sqlContext.createDataFrame(trainingData, schema);

    // train model in spark
    StringIndexerModel model = new StringIndexer()
            .setInputCol("label")
            .setOutputCol("labelIndex").fit(dataset);

    // Export this model
    byte[] exportedModel = ModelExporter.export(model, dataset);

    // Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // compare predictions
    Row[] sparkOutput = model.transform(dataset).orderBy("id").select("id", "label", "labelIndex").collect();
    for (Row row : sparkOutput) {
        Map<String, Object> data = new HashMap<String, Object>();
        data.put(model.getInputCol(), (String) row.get(1));
        transformer.transform(data);
        double indexerOutput = (double) data.get(model.getOutputCol());
        assertEquals(indexerOutput, (double) row.get(2), EPSILON);
    }
}
Example #24
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
/**
 * Trains a tagger with data specified in a data frame. The data frame has
 * two columns, one column "sentence" contains a word sequence, and the other
 * column "partOfSpeech" contains the corresponding tag sequence. Each row of
 * the data frame specifies a tagged sequence in the training set.
 * @param dataset
 * @param modelFileName
 * @param params
 * @return a {@link CMMModel}
 */
public CMMModel train(DataFrame dataset, String modelFileName, CMMParams params) {
    CMM cmm = new CMM(params).setVerbose(verbose);
    cmmModel = cmm.fit(dataset);
    try {
        cmmModel.write().overwrite().save(modelFileName);
    } catch (IOException e) {
        e.printStackTrace();
    }
    return cmmModel;
}
Example #25
Source File: AbstractJavaEsSparkSQLTest.java From elasticsearch-hadoop with Apache License 2.0
@Test
public void testEsSchemaRDD1WriteWithMappingExclude() throws Exception {
    DataFrame dataFrame = artistsAsDataFrame();
    String target = resource("sparksql-test-scala-basic-write-exclude-mapping", "data", version);
    JavaEsSparkSQL.saveToEs(dataFrame, target, ImmutableMap.of(ES_MAPPING_EXCLUDE, "url"));

    assertTrue(RestUtils.exists(target));
    assertThat(RestUtils.get(target + "/_search?"), not(containsString("url")));
}
Example #26
Source File: InstanceRelationWriterTest.java From rdf2x with Apache License 2.0
private DataFrame getTestRelations() {
    List<Row> rows = new ArrayList<>();

    rows.add(RowFactory.create(
        uriIndex.getIndex("http://example.com/knows"),
        uriIndex.getIndex("http://example.com/a"), 1L,
        uriIndex.getIndex("http://example.com/b"), 3L
    ));

    rows.add(RowFactory.create(
        uriIndex.getIndex("http://example.com/likes"),
        uriIndex.getIndex("http://example.com/a"), 2L,
        uriIndex.getIndex("http://example.com/b"), 3L
    ));

    return sql.createDataFrame(rows, new StructType()
        .add("predicateIndex", DataTypes.IntegerType, false)
        .add("fromTypeIndex", DataTypes.IntegerType, false)
        .add("fromID", DataTypes.LongType, false)
        .add("toTypeIndex", DataTypes.IntegerType, false)
        .add("toID", DataTypes.LongType, false)
    );
}
Example #27
Source File: EntitySalienceFeatureExtractorSpark.java From ambiverse-nlu with Apache License 2.0
/**
 * Extract a DataFrame ready for training or testing.
 * @param jsc
 * @param documents
 * @param sqlContext
 * @return
 * @throws ResourceInitializationException
 */
public DataFrame extract(JavaSparkContext jsc, JavaRDD<SCAS> documents, SQLContext sqlContext) throws ResourceInitializationException {
    Accumulator<Integer> TOTAL_DOCS = jsc.accumulator(0, "TOTAL_DOCS");
    Accumulator<Integer> SALIENT_ENTITY_INSTANCES = jsc.accumulator(0, "SALIENT_ENTITY_INSTANCES");
    Accumulator<Integer> NON_SALIENT_ENTITY_INSTANCES = jsc.accumulator(0, "NON_SALIENT_ENTITY_INSTANCES");

    TrainingSettings trainingSettings = getTrainingSettings();

    FeatureExtractor fe = new NYTEntitySalienceFeatureExtractor();
    final int featureVectorSize = FeatureSetFactory.createFeatureSet(TrainingSettings.FeatureExtractor.ENTITY_SALIENCE).getFeatureVectorSize();

    JavaRDD<TrainingInstance> trainingInstances = documents.flatMap(s -> {
        TOTAL_DOCS.add(1);
        return fe.getTrainingInstances(s.getJCas(),
            trainingSettings.getFeatureExtractor(),
            trainingSettings.getPositiveInstanceScalingFactor());
    });

    StructType schema = new StructType(new StructField[]{
        new StructField("docId", DataTypes.StringType, false, Metadata.empty()),
        new StructField("entityId", DataTypes.StringType, false, Metadata.empty()),
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty())
    });

    JavaRDD<Row> withFeatures = trainingInstances.map(ti -> {
        if (ti.getLabel() == 1.0) {
            SALIENT_ENTITY_INSTANCES.add(1);
        } else {
            NON_SALIENT_ENTITY_INSTANCES.add(1);
        }
        Vector vei = FeatureValueInstanceUtils.convertToSparkMLVector(ti, featureVectorSize);
        return RowFactory.create(ti.getDocId(), ti.getEntityId(), ti.getLabel(), vei);
    });

    return sqlContext.createDataFrame(withFeatures, schema);
}
Example #28
Source File: InstanceRelationWriterTest.java From rdf2x with Apache License 2.0
@Test
public void testWriteRelationTablesWithoutPredicateIndex() throws IOException {
    InstanceRelationWriter writer = new InstanceRelationWriter(config
            .setStorePredicate(false), jsc(), persistor, rdfSchema);
    writer.writeRelationTables(getTestRelationSchema(), getTestRelations());

    List<Row> rows = new ArrayList<>();
    rows.add(RowFactory.create(1L, 3L));
    rows.add(RowFactory.create(2L, 3L));

    DataFrame result = this.result.values().iterator().next();
    assertEquals("Expected schema of A_B was extracted", getExpectedSchemaOfAB(false, false), result.schema());
    assertRDDEquals("Expected rows of A_B were extracted", jsc().parallelize(rows), result.toJavaRDD());
}
Example #29
Source File: LogisticRegression1BridgeTest.java From spark-transformers with Apache License 2.0
@Test
public void testLogisticRegression() {
    // prepare data
    String datapath = "src/test/resources/binary_classification_test.libsvm";
    DataFrame trainingData = sqlContext.read().format("libsvm").load(datapath);

    // Train model in spark
    LogisticRegressionModel lrmodel = new LogisticRegression().fit(trainingData);

    // Export this model
    byte[] exportedModel = ModelExporter.export(lrmodel, trainingData);

    // Import and get Transformer
    Transformer transformer = ModelImporter.importAndGetTransformer(exportedModel);

    // validate predictions
    List<LabeledPoint> testPoints = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().collect();
    for (LabeledPoint i : testPoints) {
        Vector v = i.features();
        double actual = lrmodel.predict(v);

        Map<String, Object> data = new HashMap<String, Object>();
        data.put("features", v.toArray());
        transformer.transform(data);
        double predicted = (double) data.get("prediction");

        assertEquals(actual, predicted, EPSILON);
    }
}
Example #30
Source File: Tagger.java From vn.vitk with GNU General Public License v3.0
/**
 * Tags a distributed list of sentences and writes the result to an output file with
 * a desired output format.
 * @param sentences
 * @param outputFileName
 * @param outputFormat
 */
public void tag(JavaRDD<Row> sentences, String outputFileName, OutputFormat outputFormat) {
    StructType schema = new StructType(new StructField[]{
        new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
    });
    SQLContext sqlContext = new SQLContext(jsc);
    DataFrame input = sqlContext.createDataFrame(sentences, schema);
    tag(input, outputFileName, outputFormat);
}