org.apache.spark.ml.Model Java Examples

The following examples show how to use org.apache.spark.ml.Model. The examples are drawn from open-source projects; the source file and license are noted above each one.
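In Spark ML, a Model is the Transformer produced by fitting an Estimator to a DataFrame: fit() learns parameters, and the resulting Model's transform() applies them to new data. Below is a minimal, self-contained sketch of that lifecycle; the LogisticRegression estimator, the local master, and the libsvm input path are illustrative assumptions, not taken from the examples that follow.

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class ModelLifecycleSketch {

    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("ModelLifecycleSketch")
                .master("local[*]") // local master for illustration only
                .getOrCreate();

        // Hypothetical libsvm-formatted training data with "label" and "features" columns.
        Dataset<Row> training = spark.read().format("libsvm").load("data/training.libsvm");

        // An Estimator: calling fit() produces a Model.
        LogisticRegression lr = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01);

        // LogisticRegressionModel extends org.apache.spark.ml.Model.
        LogisticRegressionModel model = lr.fit(training);

        // A Model is a Transformer: transform() appends a "prediction" column.
        Dataset<Row> predictions = model.transform(training);
        predictions.show();

        spark.stop();
    }
}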
Example #1
Source File: TreeModelUtil.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Encodes each member tree of a Spark tree ensemble as a PMML TreeModel.
public static <C extends ModelConverter<? extends M> & HasTreeOptions, M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(C converter, PredicateManager predicateManager, Schema schema){
	M model = converter.getTransformer();

	Schema segmentSchema = schema.toAnonymousSchema();

	List<TreeModel> treeModels = new ArrayList<>();

	T[] trees = model.trees();
	for(T tree : trees){
		TreeModel treeModel = encodeDecisionTree(converter, tree, predicateManager, segmentSchema);

		treeModels.add(treeModel);
	}

	return treeModels;
}
 
Example #2
Source File: TreeModelUtil.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Encodes a single Spark decision tree as a PMML TreeModel with binary splits.
private static <M extends Model<M> & DecisionTreeModel> TreeModel encodeTreeModel(M model, PredicateManager predicateManager, MiningFunction miningFunction, ScoreEncoder scoreEncoder, Schema schema){
	Node root = encodeNode(True.INSTANCE, model.rootNode(), predicateManager, new CategoryManager(), scoreEncoder, schema);

	TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root)
		.setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT);

	return treeModel;
}
 
Example #3
Source File: EntitySalienceTrainingSparkRunner.java    From ambiverse-nlu with Apache License 2.0
@Override
    protected int run() throws Exception {

        SparkConf sparkConf = new SparkConf()
                .setAppName("EntitySalienceTrainingSparkRunner")
                .set("spark.hadoop.validateOutputSpecs", "false")
                .set("spark.yarn.executor.memoryOverhead", "3072")
                .set("spark.rdd.compress", "true")
                .set("spark.core.connection.ack.wait.timeout", "600")
                .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
                //.set("spark.kryo.registrationRequired", "true")
                .registerKryoClasses(new Class[] {SCAS.class, LabeledPoint.class, SparseVector.class, int[].class, double[].class,
                        InternalRow[].class, GenericInternalRow.class, Object[].class, GenericArrayData.class,
                        VectorIndexer.class})
                ;//.setMaster("local[4]"); // Uncomment to run locally; keep commented when running on the cluster.

        TrainingSettings trainingSettings = new TrainingSettings();

        if(folds != null) {
            trainingSettings.setNumFolds(folds);
        }
        if(method != null) {
            trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.valueOf(method));
        }
        if(defaultConf != null) {
            trainingSettings.setAidaDefaultConf(defaultConf);
        }

        if(scalingFactor != null) {
            trainingSettings.setPositiveInstanceScalingFactor(scalingFactor);
        }

        JavaSparkContext sc = new JavaSparkContext(sparkConf);
        int totalCores = Integer.parseInt(sc.getConf().get("spark.executor.instances"))
                * Integer.parseInt(sc.getConf().get("spark.executor.cores"));

        //Add the cache files to each node only if annotation is required.
        //The input documents could already be annotated, and in this case no caches are needed.
        if(trainingSettings.getFeatureExtractor().equals(TrainingSettings.FeatureExtractor.ANNOTATE_AND_ENTITY_SALIENCE)) {
            sc.addFile(trainingSettings.getBigramCountCache());
            sc.addFile(trainingSettings.getKeywordCountCache());
            sc.addFile(trainingSettings.getWordContractionsCache());
            sc.addFile(trainingSettings.getWordExpansionsCache());
            if (trainingSettings.getAidaDefaultConf().equals("db")) {
                sc.addFile(trainingSettings.getDatabaseAida());
            } else {
                sc.addFile(trainingSettings.getCassandraConfig());
            }
        }

        SQLContext sqlContext = new SQLContext(sc);

        FileSystem fs = FileSystem.get(new Configuration());

        int partitionNumber = 3 * totalCores;
        if(partitions != null) {
            partitionNumber = partitions;
        }

        //Read training documents serialized as SCAS
        JavaRDD<SCAS> documents = sc.sequenceFile(input, Text.class, SCAS.class, partitionNumber).values();

        //Instantiate a training Spark runner
        TrainingSparkRunner trainingSparkRunner = new TrainingSparkRunner();

        //Train a model via cross validation
        CrossValidatorModel model = trainingSparkRunner.crossValidate(sc, sqlContext, documents, trainingSettings);

        //Create the model path
        String modelPath = output+"/"+sc.getConf().getAppId()+"/model_"+trainingSettings.getClassificationMethod();

        //Delete the old model if there is one
        fs.delete(new Path(modelPath), true);

        //Save the new model
        List<Model> models = new ArrayList<>();
        models.add(model.bestModel());
        sc.parallelize(models, 1).saveAsObjectFile(modelPath);

        //Save the model stats
        SparkClassificationModel.saveStats(model, trainingSettings, output+"/"+sc.getConf().getAppId()+"/");

        return 0;
    }
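The best model persisted above can be read back with the object-file counterpart of the save. A minimal sketch, assuming the supplied path matches the modelPath written by the runner and that the object file contains exactly one model, as saved above:

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Model;

public class ModelLoaderSketch {

    public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext(
                new SparkConf().setAppName("ModelLoaderSketch"));

        // Hypothetical argument; must match the modelPath used by the runner above.
        String modelPath = args[0];

        // saveAsObjectFile wrote a single-element RDD, so first() recovers the model.
        JavaRDD<Object> models = sc.objectFile(modelPath);
        Model<?> bestModel = (Model<?>) models.first();

        System.out.println("Loaded model: " + bestModel.uid());

        sc.stop();
    }
}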
 
Example #4
Source File: TrainingSparkRunner.java    From ambiverse-nlu with Apache License 2.0
/**
 * Trains a single model without any cross validation or hyper-parameter optimization.
 * The chosen hyper parameters should be set in the trainingSettings object: set the
 * corresponding map of hyper parameters, not the individual parameters.
 *
 * @param jsc the Java Spark context
 * @param sqlContext the SQL context used to build the training DataFrame
 * @param documents the training documents, serialized as SCAS
 * @param trainingSettings the classification method and its hyper parameters
 * @return the trained model
 * @throws ResourceInitializationException
 * @throws IOException
 */
public Model train(JavaSparkContext jsc, SQLContext sqlContext, JavaRDD<SCAS> documents, TrainingSettings trainingSettings) throws ResourceInitializationException, IOException {

    FeatureExtractorSpark fesr = FeatureExtractionFactory.createFeatureExtractorSparkRunner(trainingSettings);

    //Extract features for each document as LabeledPoints
    DataFrame trainData = fesr.extract(jsc, documents, sqlContext);

    //Save the data for future use, instead of recomputing it all the time
    trainData.persist(StorageLevel.MEMORY_AND_DISK_SER_2());

    //Wrap the classification model based on the training settings
    SparkClassificationModel model = new SparkClassificationModel(trainingSettings.getClassificationMethod());

    Model resultModel = model.train(trainData, trainingSettings);

    return resultModel;
}
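A hedged sketch of how this method might be called, mirroring the document loading in Example #3; the input path and the "LOG_REG" method name are illustrative assumptions, not project defaults:

// Illustrative wiring only; the path and method name are assumptions.
JavaSparkContext jsc = new JavaSparkContext(new SparkConf().setAppName("TrainOnly"));
SQLContext sqlContext = new SQLContext(jsc);

// Read SCAS-serialized training documents, as in Example #3.
JavaRDD<SCAS> documents = jsc
        .sequenceFile("hdfs:///input/documents", Text.class, SCAS.class)
        .values();

TrainingSettings trainingSettings = new TrainingSettings();
trainingSettings.setClassificationMethod(TrainingSettings.ClassificationMethod.valueOf("LOG_REG"));

Model trainedModel = new TrainingSparkRunner().train(jsc, sqlContext, documents, trainingSettings);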
 
Example #5
Source File: ClusteringModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0
@Override
public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model pmmlModel, SparkMLEncoder encoder){
	T model = getTransformer();

	// One target category per cluster index.
	List<Integer> clusters = LabelUtil.createTargetCategories(getNumberOfClusters());

	String predictionCol = model.getPredictionCol();

	// The model's native prediction: a string category, hidden from the final result set.
	OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, DataType.STRING)
		.setFinalResult(false);

	DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, true);

	// Re-expose the prediction as an integer cluster index under the Spark prediction column name.
	OutputField predictedOutputField = new OutputField(FieldName.create(predictionCol), OpType.CATEGORICAL, DataType.INTEGER)
		.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
		.setExpression(new FieldRef(pmmlPredictedField.getName()));

	DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, true);

	encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, clusters));

	return Collections.emptyList();
}
 
Example #6
Source File: TreeModelUtil.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Convenience overload that creates a fresh PredicateManager.
public static <C extends ModelConverter<? extends M> & HasTreeOptions, M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(C converter, Schema schema){
	PredicateManager predicateManager = new PredicateManager();

	return encodeDecisionTree(converter, predicateManager, schema);
}
 
Example #7
Source File: TreeModelUtil.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Delegates to the full overload using the converter's own transformer.
public static <C extends ModelConverter<? extends M> & HasTreeOptions, M extends Model<M> & DecisionTreeModel> TreeModel encodeDecisionTree(C converter, PredicateManager predicateManager, Schema schema){
	return encodeDecisionTree(converter, converter.getTransformer(), predicateManager, schema);
}
 
Example #8
Source File: TreeModelUtil.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Convenience overload that creates a fresh PredicateManager for the ensemble.
public static <C extends ModelConverter<? extends M> & HasTreeOptions, M extends Model<M> & TreeEnsembleModel<T>, T extends Model<T> & DecisionTreeModel> List<TreeModel> encodeDecisionTreeEnsemble(C converter, Schema schema){
	PredicateManager predicateManager = new PredicateManager();

	return encodeDecisionTreeEnsemble(converter, predicateManager, schema);
}
 
Example #9
Source File: ModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Subclasses encode the fitted Spark transformer as a PMML model.
public abstract org.dmg.pmml.Model encodeModel(Schema schema);
 
Example #10
Source File: ModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0
// Default: no model-specific output fields; subclasses may override (see Example #5).
public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model model, SparkMLEncoder encoder){
	return null;
}
 
Example #11
Source File: ModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0
public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder){
	Schema schema = encodeSchema(encoder);

	Label label = schema.getLabel();

	org.dmg.pmml.Model model = encodeModel(schema);

	List<OutputField> sparkOutputFields = registerOutputFields(label, model, encoder);
	if(sparkOutputFields != null && !sparkOutputFields.isEmpty()){
		org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

		Output output = ModelUtil.ensureOutput(finalModel);

		List<OutputField> outputFields = output.getOutputFields();

		outputFields.addAll(sparkOutputFields);
	}

	return model;
}