Java Code Examples for org.apache.spark.sql.DataFrame#persist()

The following examples show how to use org.apache.spark.sql.DataFrame#persist().
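Before the full examples, here is a minimal, self-contained sketch of the call itself. The SQLContext, input path, and variable names are assumed for illustration; note that DataFrame is the Spark 1.x API (Spark 2.x replaced it with Dataset<Row>).

import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.storage.StorageLevel;

// Hypothetical setup: sqlContext and the input file are assumed.
DataFrame df = sqlContext.read().json("events.json");

// Cache the rows in serialized form, spilling to disk when memory is
// tight and replicating each partition on two nodes.
df.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

long first = df.count();  // materializes and caches the DataFrame
long second = df.count(); // answered from the cache, no recomputation

df.unpersist();           // release the cached blocks when done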
Example 1
Source File: TrainingSparkRunner.java    From ambiverse-nlu (Apache License 2.0)
/**
 * Trains a classification model for documents, performing cross validation and hyperparameter
 * optimization at the same time. The produced model contains the best model and statistics
 * about the runs, which are later saved by the caller.
 *
 * @param jsc the Java Spark context
 * @param sqlContext the SQL context used to build the training DataFrame
 * @param documents the documents to extract training features from
 * @param trainingSettings the classification method and hyperparameter grid to search
 * @return the fitted CrossValidatorModel, containing the best model and run statistics
 * @throws ResourceInitializationException
 * @throws IOException
 */
public CrossValidatorModel crossValidate(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {

    FeatureExtractorSpark fesr = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);

    //Extract features for each document as LabeledPoints
    DataFrame trainData = fesr.extract(jsc, documents, sqlContext);

    //Persist the data so repeated cross-validation passes reuse it instead of recomputing it
    trainData.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    //Wrap the classification model based on the training settings
    SparkClassificationModel model = new SparkClassificationModel(trainingSettings.getClassificationMethod());

    //Train the best model using CrossValidator
    CrossValidatorModel cvModel = model.crossValidate(trainData, trainingSettings);

    return cvModel;
}
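A note on the storage level: MEMORY_AND_DISK_SER_2() keeps partitions in serialized form, spills to disk when memory runs low, and replicates each partition on two nodes. Cross validation scans the training data many times, so this choice trades some serialization and network overhead for never re-running the feature extraction, and for resilience if an executor is lost mid-run.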
 
Example 2
Source File: TrainingSparkRunner.java    From ambiverse-nlu (Apache License 2.0)
/**
 * Trains a specific model directly, without any cross validation or hyperparameter optimization.
 * The chosen hyperparameters should be set in the trainingSettings object; set the corresponding
 * map of hyperparameters, not the individual parameters.
 *
 * @param jsc the Java Spark context
 * @param sqlContext the SQL context used to build the training DataFrame
 * @param documents the documents to extract training features from
 * @param trainingSettings the classification method and fixed hyperparameters to use
 * @return the trained Model
 * @throws ResourceInitializationException
 * @throws IOException
 */
public Model train(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {

    FeatureExtractorSpark fesr = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);

    //Extract features for each document as LabeledPoints
    DataFrame trainData = fesr.extract(jsc, documents, sqlContext);

    //Persist the data so it is reused during training instead of being recomputed
    trainData.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    //Wrap the classification model based on the training settings
    SparkClassificationModel model = new SparkClassificationModel(trainingSettings.getClassificationMethod());

    //Train the model with the fixed hyperparameters from the training settings
    Model resultModel = model.train(trainData, trainingSettings);

    return resultModel;
}
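One caveat that applies to both examples: neither method unpersists trainData before returning, so the cached, replicated blocks stay in the executors until Spark's LRU eviction reclaims them or the application ends. If several models are trained in one session, calling trainData.unpersist() once the model has been fit would free that memory for the next run.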