org.apache.spark.ml.tuning.CrossValidatorModel Java Examples

The following examples show how to use org.apache.spark.ml.tuning.CrossValidatorModel. Each example lists the open-source project and source file it was taken from.
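Before the project examples, a minimal sketch of the class's typical lifecycle on recent Spark 2.x releases may help; the cv, training, test, and path values below are placeholders and not part of any example on this page.

import org.apache.spark.ml.Model;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Fit a configured CrossValidator to obtain a CrossValidatorModel.
CrossValidatorModel cvModel = cv.fit(training);

// Inspect the winning model and the average metric for each parameter combination.
Model<?> best = cvModel.bestModel();
double[] metrics = cvModel.avgMetrics();

// Persist the whole model, reload it later, and score new data with it.
cvModel.write().overwrite().save(path);
CrossValidatorModel reloaded = CrossValidatorModel.load(path);
Dataset<Row> predictions = reloaded.transform(test);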
Example #1
Source File: SparkClassificationModel.java    From ambiverse-nlu with Apache License 2.0
/**
 * Builds the pipeline and parameter grid, then runs k-fold cross-validation over them.
 *
 * @param trainData        training data as a DataFrame of labeled feature vectors
 * @param trainingSettings settings controlling the classification method and number of folds
 * @return the fitted CrossValidatorModel containing the best model and per-fold metrics
 */
public CrossValidatorModel crossValidate(DataFrame trainData, TrainingSettings trainingSettings) {

    //First create the pipeline and the ParamGrid
    createPipeline(trainData, trainingSettings);

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    CrossValidator cv = new CrossValidator()
            .setEstimator(pipeline)
            .setEvaluator(evaluator)
            .setEstimatorParamMaps(paramGrid)
            .setNumFolds(trainingSettings.getNumFolds());

    if(classificationMethod.equals(TrainingSettings.ClassificationMethod.LOG_REG)) {
        // For logistic regression, add a per-row weight column to counter class imbalance:
        // positive examples get the fraction of negatives, negatives the fraction of positives.
        long numPositive = trainData.filter(col("label").equalTo("1.0")).count();
        long datasetSize = trainData.count();
        double balancingRatio = (double) (datasetSize - numPositive) / datasetSize;

        trainData = trainData
                .withColumn("classWeightCol",
                        when(col("label").equalTo("1.0"), balancingRatio)
                                .otherwise(1.0 - balancingRatio));
    }
    // Run cross-validation, and choose the best set of parameters.
    bestModel = cv.fit(trainData);
    System.out.println("IS LARGER BETTER ?"+bestModel.getEvaluator().isLargerBetter());
    return bestModel;
}
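The classWeightCol column added above only influences training if the LogisticRegression stage built in createPipeline (not shown here) is told to read it. A minimal sketch of that wiring, assuming the stage is configured roughly as follows; setWeightCol is standard Spark ML API, while the surrounding variable is illustrative:

import org.apache.spark.ml.classification.LogisticRegression;

// Hypothetical excerpt of createPipeline(): make the classifier consume the per-row weights
LogisticRegression lr = new LogisticRegression()
        .setLabelCol("label")
        .setFeaturesCol("features")
        .setWeightCol("classWeightCol");   // column added in crossValidate(...) above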
 
Example #2
Source File: SparkClassificationModel.java    From ambiverse-nlu with Apache License 2.0
public static void debugOutputModel(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output+"debug_"+trainingSettings.getClassificationMethod()+".txt");
    fs.delete(statsPath, true);

    FSDataOutputStream fsdos = fs.create(statsPath);
    // Unwrap the best PipelineModel and write a textual description of its classifier stage
    PipelineModel pipelineModel = (PipelineModel) model.bestModel();
    switch (trainingSettings.getClassificationMethod()) {
        case RANDOM_FOREST:
            for(int i=0; i< pipelineModel.stages().length; i++) {
                if (pipelineModel.stages()[i] instanceof RandomForestClassificationModel) {
                    RandomForestClassificationModel rfModel = (RandomForestClassificationModel) (pipelineModel.stages()[i]);
                    IOUtils.write(rfModel.toDebugString(), fsdos);
                    logger.info(rfModel.toDebugString());
                }
            }
            break;
        case LOG_REG:
            for(int i=0; i< pipelineModel.stages().length; i++) {
                if (pipelineModel.stages()[i] instanceof LogisticRegressionModel) {
                    LogisticRegressionModel lgModel = (LogisticRegressionModel) (pipelineModel.stages()[i]);
                    IOUtils.write(lgModel.toString(), fsdos);
                    logger.info(lgModel.toString());
                }
            }
            break;
    }
    fsdos.flush();
    IOUtils.closeQuietly(fsdos);
}
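A hedged usage sketch of the two methods above; the SparkClassificationModel instance and the output prefix are placeholders:

// Cross-validate, then dump the best model's textual description next to the other outputs
CrossValidatorModel model = sparkModel.crossValidate(trainData, trainingSettings);
SparkClassificationModel.debugOutputModel(model, trainingSettings, "/user/example/output/");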
 
Example #3
Source File: EntitySalienceTrainingSparkRunner.java    From ambiverse-nlu with Apache License 2.0
    @Override
    protected int run() throws Exception {

        SparkConf sparkConf = new SparkConf()
                .setAppName("EntitySalienceTrainingSparkRunner")
                .set("spark.hadoop.validateOutputSpecs", "false")
                .set("spark.yarn.executor.memoryOverhead", "3072")
                .set("spark.rdd.compress", "true")
                .set("spark.core.connection.ack.wait.timeout", "600")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                //.set("spark.kryo.registrationRequired", "true")
                .registerKryoClasses(new Class[] {SCAS.class, LabeledPoint.class, SparseVector.class, int[].class, double[].class,
                        InternalRow[].class, GenericInternalRow.class, Object[].class, GenericArrayData.class,
                        VectorIndexer.class})
                ;//.setMaster("local[4]"); //Remove this if you run it on the server.

        TrainingSettings trainingSettings = new TrainingSettings();

        if(folds != null) {
            trainingSettings.setNumFolds(folds);
        }
        if(method != null) {
            trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.valueOf(method));
        }
        if(defaultConf != null) {
            trainingSettings.setAidaDefaultConf(defaultConf);
        }

        if(scalingFactor != null) {
            trainingSettings.setPositiveInstanceScalingFactor(scalingFactor);
        }

        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        int totalCores = Integer.parseInt(sc.getConf().get("spark.executor.instances"))
                * Integer.parseInt(sc.getConf().get("spark.executor.cores"));

        //Add the cache files to each node only if annotation is required.
        //The input documents could already be annotated, and in this case no caches are needed.
        if(trainingSettings.getFeatureExtractor().equals(TrainingSettings.FeatureExtractor.ANNOTATE_AND_ENTITY_SALIENCE)) {
            sc.addFile(trainingSettings.getBigramCountCache());
            sc.addFile(trainingSettings.getKeywordCountCache());
            sc.addFile(trainingSettings.getWordContractionsCache());
            sc.addFile(trainingSettings.getWordExpansionsCache());
            if (trainingSettings.getAidaDefaultConf().equals("db")) {
                sc.addFile(trainingSettings.getDatabaseAida());
            } else {
                sc.addFile(trainingSettings.getCassandraConfig());
            }
        }

        SQLContext sqlContext = new SQLContext(sc);


        FileSystem fs = FileSystem.get(new Configuration());

        int partitionNumber = 3 * totalCores;
        if(partitions != null) {
            partitionNumber = partitions;
        }

        //Read training documents serialized as SCAS
        JavaRDD<SCAS> documents = sc.sequenceFile(input, Text.class, SCAS.class, partitionNumber).values();

        //Instantiate a training Spark runner
        TrainingSparkRunner trainingSparkRunner = new TrainingSparkRunner();

        //Train a model
        CrossValidatorModel model = trainingSparkRunner.crossValidate(sc, sqlContext, documents, trainingSettings);


        //Create the model path
        String modelPath = output+"/"+sc.getConf().getAppId()+"/model_"+trainingSettings.getClassificationMethod();

        //Delete the old model if there is one
        fs.delete(new Path(modelPath), true);

        //Save the new model
        List<Model> models = new ArrayList<>();
        models.add(model.bestModel());
        sc.parallelize(models, 1).saveAsObjectFile(modelPath);

        //Save the model stats
        SparkClassificationModel.saveStats(model, trainingSettings, output+"/"+sc.getConf().getAppId()+"/");


        return 0;
    }
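Because the best model is persisted as a single-element Hadoop object file, it can be restored later without re-running cross-validation. A hedged sketch of reading it back, with modelPath as computed above and testData a placeholder DataFrame of extracted features:

// Reload the persisted best model from the object file written above
JavaRDD<Object> modelRdd = sc.objectFile(modelPath);
PipelineModel reloaded = (PipelineModel) modelRdd.first();

// Score new data with the restored pipeline
DataFrame predictions = reloaded.transform(testData);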
 
Example #4
Source File: SparkClassificationModel.java    From ambiverse-nlu with Apache License 2.0
public static void saveStats(CrossValidatorModel model, TrainingSettings trainingSettings, String output) throws IOException {
    double[] avgMetrics = model.avgMetrics();
    double bestMetric = 0;
    int bestIndex = 0;

    // Find the parameter combination with the highest average metric
    // (this assumes the evaluator's metric is larger-is-better).
    for(int i = 0; i < avgMetrics.length; i++) {
        if(avgMetrics[i] > bestMetric) {
            bestMetric = avgMetrics[i];
            bestIndex = i;
        }
    }


    FileSystem fs = FileSystem.get(new Configuration());
    Path statsPath = new Path(output+"stats_"+trainingSettings.getClassificationMethod()+".txt");
    fs.delete(statsPath, true);

    FSDataOutputStream fsdos = fs.create(statsPath);

    String avgLine="Average cross-validation metrics: "+ Arrays.toString(model.avgMetrics());
    String bestMetricLine="\nBest cross-validation metric ["+trainingSettings.getMetricName()+"]: "+bestMetric;
    String bestSetParamLine= "\nBest set of parameters: "+model.getEstimatorParamMaps()[bestIndex];

    logger.info(avgLine);
    logger.info(bestMetricLine);
    logger.info(bestSetParamLine);


    IOUtils.write(avgLine, fsdos);
    IOUtils.write(bestMetricLine, fsdos);
    IOUtils.write(bestSetParamLine, fsdos);

    PipelineModel pipelineModel = (PipelineModel) model.bestModel();
    for(Transformer t : pipelineModel.stages()) {
        if(t instanceof ClassificationModel) {
            IOUtils.write("\n"+((Model) t).parent().extractParamMap().toString(), fsdos);
            logger.info(((Model) t).parent().extractParamMap().toString());
        }
    }

    fsdos.flush();
    IOUtils.closeQuietly(fsdos);

    debugOutputModel(model,trainingSettings, output);
}
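Note that saveStats picks the largest average metric, which matches the evaluator only when its metric is larger-is-better (true for areaUnderROC, not for error-style metrics). A hedged, direction-aware variant of the selection loop:

// Respect the evaluator's direction when selecting the best grid point
boolean largerIsBetter = model.getEvaluator().isLargerBetter();
double[] avg = model.avgMetrics();
int bestIndex = 0;
for (int i = 1; i < avg.length; i++) {
    if (largerIsBetter ? avg[i] > avg[bestIndex] : avg[i] < avg[bestIndex]) {
        bestIndex = i;
    }
}
ParamMap bestParams = model.getEstimatorParamMaps()[bestIndex];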
 
Example #5
Source File: TrainingSparkRunner.java    From ambiverse-nlu with Apache License 2.0
/**
 * Trains a classification model for documents, doing cross-validation and hyperparameter
 * optimization at the same time. The produced model contains the best model and statistics
 * about the runs, which are later saved by the caller.
 *
 * @param jsc              the Java Spark context
 * @param sqlContext       the SQL context used to build DataFrames
 * @param documents        the training documents, serialized as SCAS
 * @param trainingSettings settings controlling feature extraction and the classification method
 * @return the fitted CrossValidatorModel
 * @throws ResourceInitializationException if the feature extractor cannot be initialized
 * @throws IOException if reading resources fails
 */
public CrossValidatorModel crossValidate(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {

    FeatureExtractorSpark fesr = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);

    //Extract features for each document as LabelPoints
    DataFrame trainData = fesr.extract(jsc, documents, sqlContext);

    //Save the data for future use, instead of recomputing it all the time
    trainData.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    //Wrap the classification model based on the training settings
    SparkClassificationModel model = new SparkClassificationModel(trainingSettings.getClassificationMethod());

    //Train the best model using a CrossValidator
    CrossValidatorModel cvModel = model.crossValidate(trainData, trainingSettings);

    return cvModel;
}
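Since the feature DataFrame is cached before cross-validation, the caller may want to release it once the model has been produced. A small hedged addition to the end of the method:

CrossValidatorModel cvModel = model.crossValidate(trainData, trainingSettings);
trainData.unpersist();   // release the cached feature DataFrame once training has finished
return cvModel;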
 
Example #6
Source File: JavaModelSelectionViaCrossValidationExample.java    From SparkDemo with MIT License
public static void main(String[] args) {
  SparkSession spark = SparkSession
    .builder()
    .appName("JavaModelSelectionViaCrossValidationExample")
    .getOrCreate();

  // $example on$
  // Prepare training documents, which are labeled.
  Dataset<Row> training = spark.createDataFrame(Arrays.asList(
    new JavaLabeledDocument(0L, "a b c d e spark", 1.0),
    new JavaLabeledDocument(1L, "b d", 0.0),
    new JavaLabeledDocument(2L, "spark f g h", 1.0),
    new JavaLabeledDocument(3L, "hadoop mapreduce", 0.0),
    new JavaLabeledDocument(4L, "b spark who", 1.0),
    new JavaLabeledDocument(5L, "g d a y", 0.0),
    new JavaLabeledDocument(6L, "spark fly", 1.0),
    new JavaLabeledDocument(7L, "was mapreduce", 0.0),
    new JavaLabeledDocument(8L, "e spark program", 1.0),
    new JavaLabeledDocument(9L, "a e c l", 0.0),
    new JavaLabeledDocument(10L, "spark compile", 1.0),
    new JavaLabeledDocument(11L, "hadoop software", 0.0)
  ), JavaLabeledDocument.class);

  // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
  Tokenizer tokenizer = new Tokenizer()
    .setInputCol("text")
    .setOutputCol("words");
  HashingTF hashingTF = new HashingTF()
    .setNumFeatures(1000)
    .setInputCol(tokenizer.getOutputCol())
    .setOutputCol("features");
  LogisticRegression lr = new LogisticRegression()
    .setMaxIter(10)
    .setRegParam(0.01);
  Pipeline pipeline = new Pipeline()
    .setStages(new PipelineStage[] {tokenizer, hashingTF, lr});

  // We use a ParamGridBuilder to construct a grid of parameters to search over.
  // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
  // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
  ParamMap[] paramGrid = new ParamGridBuilder()
    .addGrid(hashingTF.numFeatures(), new int[] {10, 100, 1000})
    .addGrid(lr.regParam(), new double[] {0.1, 0.01})
    .build();

  // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
  // This will allow us to jointly choose parameters for all Pipeline stages.
  // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
  // Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
  // is areaUnderROC.
  CrossValidator cv = new CrossValidator()
    .setEstimator(pipeline)
    .setEvaluator(new BinaryClassificationEvaluator())
    .setEstimatorParamMaps(paramGrid).setNumFolds(2);  // Use 3+ in practice

  // Run cross-validation, and choose the best set of parameters.
  CrossValidatorModel cvModel = cv.fit(training);

  // Prepare test documents, which are unlabeled.
  Dataset<Row> test = spark.createDataFrame(Arrays.asList(
    new JavaDocument(4L, "spark i j k"),
    new JavaDocument(5L, "l m n"),
    new JavaDocument(6L, "mapreduce spark"),
    new JavaDocument(7L, "apache hadoop")
  ), JavaDocument.class);

  // Make predictions on test documents. cvModel uses the best model found during cross-validation.
  Dataset<Row> predictions = cvModel.transform(test);
  for (Row r : predictions.select("id", "text", "probability", "prediction").collectAsList()) {
    System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
      + ", prediction=" + r.get(3));
  }
  // $example off$

  spark.stop();
}
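As a hedged follow-up, the hyperparameters selected by cross-validation can be read back from the best model; lr is the third pipeline stage configured above, and the imports for PipelineModel, LogisticRegressionModel and java.util.Arrays are assumed:

// Inspect which parameters cross-validation selected
PipelineModel bestPipeline = (PipelineModel) cvModel.bestModel();
LogisticRegressionModel bestLr = (LogisticRegressionModel) bestPipeline.stages()[2];
HashingTF bestTf = (HashingTF) bestPipeline.stages()[1];
System.out.println("Chosen regParam: " + bestLr.getRegParam());
System.out.println("Chosen numFeatures: " + bestTf.getNumFeatures());
System.out.println("Average metrics per grid point: " + Arrays.toString(cvModel.avgMetrics()));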