Java Code Examples for org.dmg.pmml.mining.Segmentation#MultipleModelMethod

The following examples show how to use org.dmg.pmml.mining.Segmentation#MultipleModelMethod. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may check out the related API usage on the sidebar.
Example 1
Source File: MiningModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Views a predictions map as a {@link SegmentResult}, but only for the selection-style methods.
 *
 * @param multipleModelMethod The multiple model method. It is switched upon, so it must not be {@code null}.
 * @param predictions The predictions map to inspect.
 * @return The predictions object itself, cast to {@link SegmentResult}, if the method is
 * {@code SELECT_FIRST}, {@code SELECT_ALL} or {@code MODEL_CHAIN} and the map is a segment result;
 * {@code null} in every other case.
 */
static
public SegmentResult asSegmentResult(Segmentation.MultipleModelMethod multipleModelMethod, Map<FieldName, ?> predictions){
	boolean selective;

	switch(multipleModelMethod){
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
			selective = true;
			break;
		default:
			selective = false;
			break;
	}

	return (selective && predictions instanceof SegmentResult) ? (SegmentResult)predictions : null;
}
 
Example 2
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Creates the nested output fields that follow from the segmentation of this mining model.
 *
 * @return For {@code SELECT_FIRST}, the nested output fields of the active head of the segments;
 * for {@code MODEL_CHAIN}, those of the active tail of the segments;
 * for every other method (including {@code SELECT_ALL}), an empty list.
 */
private List<OutputField> createNestedOutputFields(){
	Segmentation segmentation = getModel().getSegmentation();

	List<Segment> candidates = segmentation.getSegments();

	List<OutputField> nestedOutputFields;

	switch(segmentation.getMultipleModelMethod()){
		case SELECT_FIRST:
			nestedOutputFields = createNestedOutputFields(getActiveHead(candidates));
			break;
		case MODEL_CHAIN:
			nestedOutputFields = createNestedOutputFields(getActiveTail(candidates));
			break;
		case SELECT_ALL:
			// Ignored
		default:
			nestedOutputFields = Collections.emptyList();
			break;
	}

	return nestedOutputFields;
}
 
Example 3
Source File: TargetCategoryParser.java    From jpmml-evaluator with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Processes a mining model.
 *
 * If the model has a segmentation whose method is {@code SELECT_FIRST}, {@code SELECT_ALL} or {@code MODEL_CHAIN},
 * then a singleton map from the default target name to {@code null} is pushed onto the target data types stack,
 * the current data type is cleared, and {@link #processModel} is not invoked.
 * In every other case the work is delegated to {@link #processModel}.
 *
 * @param miningModel The mining model to process.
 */
private void processMiningModel(MiningModel miningModel){
	Segmentation segmentation = miningModel.getSegmentation();

	// No segmentation: handle like any other model
	if(segmentation == null){
		processModel(miningModel);

		return;
	}

	switch(segmentation.getMultipleModelMethod()){
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
			this.targetDataTypes.push(Collections.singletonMap(Evaluator.DEFAULT_TARGET_NAME, null));
			this.dataType = null;
			return;
		default:
			break;
	}

	processModel(miningModel);
}
 
Example 4
Source File: BaggingClassifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Encodes this bagging classifier as a mining model with a probability output.
 *
 * The member models are combined using {@code AVERAGE} if every estimator has a probability distribution,
 * and using {@code MAJORITY_VOTE} as soon as one of them does not.
 *
 * @param schema The schema; its label is cast to {@link CategoricalLabel}.
 * @return The encoded mining model.
 */
@Override
public MiningModel encodeModel(Schema schema){
	List<? extends Classifier> estimators = getEstimators();
	List<List<Integer>> estimatorsFeatures = getEstimatorsFeatures();

	// Short-circuits on the first estimator that lacks a probability distribution
	boolean allProbabilistic = estimators.stream()
		.allMatch(estimator -> estimator.hasProbabilityDistribution());

	Segmentation.MultipleModelMethod multipleModelMethod = (allProbabilistic ? Segmentation.MultipleModelMethod.AVERAGE : Segmentation.MultipleModelMethod.MAJORITY_VOTE);

	return BaggingUtil.encodeBagging(estimators, estimatorsFeatures, multipleModelMethod, MiningFunction.CLASSIFICATION, schema)
		.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel)schema.getLabel()));
}
 
Example 5
Source File: VotingClassifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Translates a voting mode and a weighting flag into a multiple model method.
 *
 * @param voting The voting mode, either {@code "hard"} or {@code "soft"}.
 * @param weighted Whether the weighted variant of the method should be chosen.
 * @return For {@code "hard"}, {@code MAJORITY_VOTE} or {@code WEIGHTED_MAJORITY_VOTE};
 * for {@code "soft"}, {@code AVERAGE} or {@code WEIGHTED_AVERAGE}.
 * @throws IllegalArgumentException If the voting mode is neither {@code "hard"} nor {@code "soft"}.
 */
static
private Segmentation.MultipleModelMethod parseVoting(String voting, boolean weighted){
	Segmentation.MultipleModelMethod plainMethod;
	Segmentation.MultipleModelMethod weightedMethod;

	switch(voting){
		case "hard":
			plainMethod = Segmentation.MultipleModelMethod.MAJORITY_VOTE;
			weightedMethod = Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE;
			break;
		case "soft":
			plainMethod = Segmentation.MultipleModelMethod.AVERAGE;
			weightedMethod = Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
			break;
		default:
			throw new IllegalArgumentException(voting);
	}

	if(weighted){
		return weightedMethod;
	}

	return plainMethod;
}
 
Example 6
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Attempts to resolve the segmentation result without aggregating the segment results.
 *
 * @param multipleModelMethods The aggregating methods that the caller is able to handle.
 * @param segmentResults The results of the evaluated segments.
 * @return For {@code SELECT_FIRST} the first segment result, for {@code MODEL_CHAIN} the last one,
 * and for {@code SELECT_ALL} the outcome of {@code selectAll}.
 * If there are no segment results, a singleton map from the target name to {@code null}.
 * Otherwise {@code null}.
 * @throws UnsupportedAttributeException If the method is not a selection-style method and is not among the given methods.
 */
private Map<FieldName, ?> getSegmentationResult(Set<Segmentation.MultipleModelMethod> multipleModelMethods, List<SegmentResult> segmentResults){
	Segmentation segmentation = getModel().getSegmentation();

	Segmentation.MultipleModelMethod method = segmentation.getMultipleModelMethod();

	switch(method){
		case SELECT_ALL:
			return selectAll(segmentResults);
		case SELECT_FIRST:
		case MODEL_CHAIN:
			if(!segmentResults.isEmpty()){
				// The head for SELECT_FIRST, the tail for MODEL_CHAIN
				int index = (method == Segmentation.MultipleModelMethod.SELECT_FIRST ? 0 : segmentResults.size() - 1);

				return segmentResults.get(index);
			}
			break;
		default:
			if(!multipleModelMethods.contains(method)){
				throw new UnsupportedAttributeException(segmentation, method);
			}
			break;
	}

	// "If no segments have predicates that evaluate to true, then the result is a missing value"
	if(segmentResults.isEmpty()){
		return Collections.singletonMap(getTargetName(), null);
	}

	return null;
}
 
Example 7
Source File: FieldResolver.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
/**
 * Collects the outputs of the models of the segments that come before the target segment.
 *
 * @param segmentation The segmentation to scan.
 * @param targetSegment The segment at which the scan stops. If {@code null}, then all segments are scanned.
 * @return An empty list if the multiple model method is anything other than {@code MODEL_CHAIN}.
 * Otherwise the non-{@code null} outputs, in segment order, up to (but excluding) the target segment.
 */
static
private List<Output> getEarlierOutputs(Segmentation segmentation, Segment targetSegment){

	switch(segmentation.getMultipleModelMethod()){
		case MODEL_CHAIN:
			break;
		default:
			return Collections.emptyList();
	}

	List<Output> earlierOutputs = new ArrayList<>();

	for(Segment candidate : segmentation.getSegments()){
		Model candidateModel = candidate.getModel();

		boolean reachedTarget = (targetSegment != null && targetSegment.equals(candidate));
		if(reachedTarget){
			break;
		}

		Output candidateOutput = candidateModel.getOutput();
		if(candidateOutput != null){
			earlierOutputs.add(candidateOutput);
		}
	}

	return earlierOutputs;
}
 
Example 8
Source File: MiningModelUtil.java    From pyramid with Apache License 2.0 5 votes vote down vote up
/**
 * Creates a segmentation that holds one always-true segment per model.
 *
 * Segment identifiers are the 1-based positions of the models.
 * A segment weight is set only if the corresponding weight is non-{@code null}
 * and {@code ValueUtil.isOne(weight)} does not hold for it.
 *
 * @param multipleModelMethod The multiple model method of the segmentation.
 * @param models The models, one per segment.
 * @param weights The segment weights, or {@code null} if there are none.
 * @return The new segmentation.
 * @throws IllegalArgumentException If the weights are given, but their number differs from the number of models.
 */
static
public Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models, List<? extends Number> weights){

    if((weights != null) && (models.size() != weights.size())){
        // Report both sizes, so that the mismatch can be diagnosed from the stack trace alone
        throw new IllegalArgumentException("Expected " + models.size() + " weight(s), got " + weights.size());
    }

    List<Segment> segments = new ArrayList<>(models.size());

    for(int i = 0; i < models.size(); i++){
        Model model = models.get(i);
        Number weight = (weights != null ? weights.get(i) : null);

        Segment segment = new Segment()
                .setId(String.valueOf(i + 1))
                .setPredicate(new True())
                .setModel(model);

        if(weight != null && !ValueUtil.isOne(weight)){
            segment.setWeight(ValueUtil.asDouble(weight));
        }

        segments.add(segment);
    }

    return new Segmentation(multipleModelMethod, segments);
}
 
Example 9
Source File: ForestUtil.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes a tree ensemble estimator as a mining model.
 *
 * The tree models produced by {@code TreeUtil.encodeTreeModelEnsemble} become the segmentation of a new mining model,
 * which is then passed through {@code TreeUtil.transform}.
 *
 * @param estimator The ensemble estimator.
 * @param multipleModelMethod The multiple model method of the segmentation.
 * @param miningFunction The mining function of the mining model.
 * @param schema The schema; its label is used for creating the mining schema.
 * @return The result of {@code TreeUtil.transform}.
 */
static
public <E extends Estimator & HasEstimatorEnsemble<T> & HasTreeOptions, T extends Estimator & HasTree> MiningModel encodeBaseForest(E estimator, Segmentation.MultipleModelMethod multipleModelMethod, MiningFunction miningFunction, Schema schema){
	List<TreeModel> memberModels = TreeUtil.encodeTreeModelEnsemble(estimator, miningFunction, schema);

	return TreeUtil.transform(
		estimator,
		new MiningModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()))
			.setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, memberModels))
	);
}
 
Example 10
Source File: RDFUpdate.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Converts a random forest model into a PMML document.
 *
 * A forest with exactly one tree is emitted as that tree's model. Otherwise a mining model is
 * emitted, whose segmentation holds one always-true segment per tree; the segments are combined by
 * {@code WEIGHTED_MAJORITY_VOTE} for classification and by {@code WEIGHTED_AVERAGE} otherwise.
 *
 * @param rfModel the random forest model to convert
 * @param categoricalValueEncodings encodings, passed on to the tree and data dictionary builders
 * @param maxDepth recorded in the PMML as the "maxDepth" extension
 * @param maxSplitCandidates recorded in the PMML as the "maxSplitCandidates" extension
 * @param impurity recorded in the PMML as the "impurity" extension
 * @param nodeIDCounts per-tree counts, indexed by tree position
 * @param predictorIndexCounts counts that are turned into importances for the mining schema
 * @return the PMML document holding the data dictionary, the model and the three extensions
 */
private PMML rdfModelToPMML(RandomForestModel rfModel,
                            CategoricalValueEncodings categoricalValueEncodings,
                            int maxDepth,
                            int maxSplitCandidates,
                            String impurity,
                            List<? extends IntLongMap> nodeIDCounts,
                            IntLongMap predictorIndexCounts) {

  // The task type of the model must agree with the input schema
  boolean classificationTask = rfModel.algo().equals(Algo.Classification());
  Preconditions.checkState(classificationTask == inputSchema.isClassification());

  DecisionTreeModel[] trees = rfModel.trees();

  Model model;
  if (trees.length == 1) {
    // A single tree needs no ensemble wrapper
    model = toTreeModel(trees[0], categoricalValueEncodings, nodeIDCounts.get(0));
  } else {
    MiningModel miningModel = new MiningModel();
    model = miningModel;
    Segmentation.MultipleModelMethod multipleModelMethodType = classificationTask ?
        Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE :
        Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE;
    List<Segment> segments = new ArrayList<>(trees.length);
    // One segment per tree; the segment ID is the 0-based tree position
    for (int treeID = 0; treeID < trees.length; treeID++) {
      TreeModel treeModel =
          toTreeModel(trees[treeID], categoricalValueEncodings, nodeIDCounts.get(treeID));
      segments.add(new Segment()
           .setId(Integer.toString(treeID))
           .setPredicate(new True())
           .setModel(treeModel)
           .setWeight(1.0)); // No weights in MLlib impl now
    }
    miningModel.setSegmentation(new Segmentation(multipleModelMethodType, segments));
  }

  // The mining function and mining schema are set on the top-level model in both cases
  model.setMiningFunction(classificationTask ?
                          MiningFunction.CLASSIFICATION :
                          MiningFunction.REGRESSION);

  double[] importances = countsToImportances(predictorIndexCounts);
  model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances));
  DataDictionary dictionary =
      AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings);

  PMML pmml = PMMLUtils.buildSkeletonPMML();
  pmml.setDataDictionary(dictionary);
  pmml.addModels(model);

  // Record the training hyperparameters alongside the model
  AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth);
  AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates);
  AppPMMLUtils.addExtension(pmml, "impurity", impurity);

  return pmml;
}
 
Example 11
Source File: MiningModelUtil.java    From pyramid with Apache License 2.0 4 votes vote down vote up
/**
 * Creates a segmentation without segment weights.
 *
 * @param multipleModelMethod The multiple model method of the segmentation.
 * @param models The models, one per segment.
 * @return The outcome of the three-argument overload, invoked with {@code null} weights.
 */
static
public Segmentation createSegmentation(Segmentation.MultipleModelMethod multipleModelMethod, List<? extends Model> models){
    List<? extends Number> noWeights = null;

    return createSegmentation(multipleModelMethod, models, noWeights);
}
 
Example 12
Source File: BaggingUtil.java    From jpmml-sklearn with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Encodes a bagging ensemble as a mining model.
 *
 * Every estimator is encoded against a sub-schema of the anonymous form of the schema.
 * The sub-schema is selected by the feature indices found at the same position of the estimators features list.
 *
 * @param estimators The member estimators.
 * @param estimatorsFeatures The feature indices of each member estimator, in the same order as the estimators.
 * @param multipleModelMethod The multiple model method of the segmentation.
 * @param miningFunction The mining function of the mining model.
 * @param schema The schema; its label is used for creating the mining schema.
 * @return The mining model, whose segmentation holds the encoded member models.
 */
static
public <E extends Estimator> MiningModel encodeBagging(List<E> estimators, List<List<Integer>> estimatorsFeatures, Segmentation.MultipleModelMethod multipleModelMethod, MiningFunction miningFunction, Schema schema){
	Schema anonymousSchema = schema.toAnonymousSchema();

	List<Model> memberModels = new ArrayList<>();

	for(int index = 0; index < estimators.size(); index++){
		E member = estimators.get(index);

		int[] featureIndices = Ints.toArray(estimatorsFeatures.get(index));

		memberModels.add(member.encodeModel(anonymousSchema.toSubSchema(featureIndices)));
	}

	return new MiningModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()))
		.setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, memberModels));
}
 
Example 13
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Creates an evaluator for a mining model, rejecting the constructs that are checked for below.
 *
 * @param pmml The PMML document; passed to the superclass constructor.
 * @param miningModel The mining model; passed to the superclass constructor and validated here.
 * @throws UnsupportedElementException If the mining model has embedded models,
 * if any segment declares a variable weight, or if the segmentation declares local transformations.
 * @throws MissingElementException If the mining model has no segmentation, or if the segmentation has no segments.
 * @throws MissingAttributeException If the segmentation has no multiple model method.
 */
public MiningModelEvaluator(PMML pmml, MiningModel miningModel){
	super(pmml, miningModel);

	// Embedded models are rejected; the first one is reported as the offending element
	if(miningModel.hasEmbeddedModels()){
		List<EmbeddedModel> embeddedModels = miningModel.getEmbeddedModels();

		EmbeddedModel embeddedModel = Iterables.getFirst(embeddedModels, null);

		throw new UnsupportedElementException(embeddedModel);
	}

	Segmentation segmentation = miningModel.getSegmentation();
	if(segmentation == null){
		throw new MissingElementException(miningModel, PMMLElements.MININGMODEL_SEGMENTATION);
	}

	Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
	if(multipleModelMethod == null){
		throw new MissingAttributeException(segmentation, PMMLAttributes.SEGMENTATION_MULTIPLEMODELMETHOD);
	} // End if

	if(!segmentation.hasSegments()){
		throw new MissingElementException(segmentation, PMMLElements.SEGMENTATION_SEGMENTS);
	}

	// Variable weights are rejected; the first one found is reported
	List<Segment> segments = segmentation.getSegments();
	for(Segment segment : segments){
		VariableWeight variableWeight = segment.getVariableWeight();

		if(variableWeight != null){
			throw new UnsupportedElementException(variableWeight);
		}
	}

	LocalTransformations localTransformations = segmentation.getLocalTransformations();
	if(localTransformations != null){
		throw new UnsupportedElementException(localTransformations);
	}

	// The segment result features are looked up through a cache that is keyed by the output element
	Output output = miningModel.getOutput();
	if(output != null && output.hasOutputFields()){
		this.segmentResultFeatures = CacheUtil.getValue(output, MiningModelEvaluator.segmentResultFeaturesCache);
	}
}
 
Example 14
Source File: MiningModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Aggregates the probability distributions of segment results.
 *
 * @param valueFactory The value factory, passed to the aggregator and to the missing fraction tracker.
 * @param multipleModelMethod The aggregation method; one of {@code AVERAGE}, {@code WEIGHTED_AVERAGE}, {@code MEDIAN} or {@code MAX}.
 * @param missingPredictionTreatment The treatment of segment results whose target value is {@code null}.
 * @param missingThreshold The threshold that is passed to the missing fraction tracker.
 * @param categories The categories, used by the {@code MEDIAN} and {@code MAX} methods only.
 * @param segmentResults The segment results to aggregate.
 * @return The aggregated map, or {@code null} if the missing prediction treatment aborted the aggregation.
 * @throws IllegalArgumentException If the aggregation method or the missing prediction treatment is not handled.
 */
static
public <V extends Number> ValueMap<Object, V> aggregateProbabilities(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<?> categories, List<SegmentResult> segmentResults){
	ProbabilityAggregator<V> aggregator;

	// Choose the aggregator implementation that matches the method
	switch(multipleModelMethod){
		case AVERAGE:
			aggregator = new ProbabilityAggregator.Average<>(valueFactory);
			break;
		case WEIGHTED_AVERAGE:
			aggregator = new ProbabilityAggregator.WeightedAverage<>(valueFactory);
			break;
		case MEDIAN:
			aggregator = new ProbabilityAggregator.Median<>(valueFactory, segmentResults.size());
			break;
		case MAX:
			aggregator = new ProbabilityAggregator.Max<>(valueFactory, segmentResults.size());
			break;
		default:
			throw new IllegalArgumentException();
	}

	// Created lazily, on the first missing prediction under SKIP_SEGMENT
	Fraction<V> missingFraction = null;

	segmentResults:
	for(SegmentResult segmentResult : segmentResults){
		Object targetValue = segmentResult.getTargetValue();

		if(targetValue == null){

			// Every branch below leaves this iteration: by returning, by continuing, or by throwing
			switch(missingPredictionTreatment){
				case RETURN_MISSING:
					return null;
				case SKIP_SEGMENT:
					if(missingFraction == null){
						missingFraction = new Fraction<>(valueFactory, segmentResults);
					} // End if

					// NOTE(review): update() presumably reports that the missing threshold was exceeded - confirm against Fraction
					if(missingFraction.update(segmentResult, missingThreshold)){
						return null;
					}

					continue segmentResults;
				case CONTINUE:
					return null;
				default:
					throw new IllegalArgumentException();
			}
		}

		HasProbability hasProbability;

		// A type mismatch is re-thrown with the offending segment as context
		try {
			hasProbability = TypeUtil.cast(HasProbability.class, targetValue);
		} catch(TypeCheckException tce){
			throw tce.ensureContext(segmentResult.getSegment());
		}

		switch(multipleModelMethod){
			case AVERAGE:
			case MEDIAN:
			case MAX:
				aggregator.add(hasProbability);
				break;
			case WEIGHTED_AVERAGE:
				Number weight = segmentResult.getWeight();

				aggregator.add(hasProbability, weight);
				break;
			default:
				throw new IllegalArgumentException();
		}
	}

	// Extract the aggregate that matches the method
	switch(multipleModelMethod){
		case AVERAGE:
			return aggregator.averageMap();
		case WEIGHTED_AVERAGE:
			return aggregator.weightedAverageMap();
		case MEDIAN:
			return aggregator.medianMap(categories);
		case MAX:
			return aggregator.maxMap(categories);
		default:
			throw new IllegalArgumentException();
	}
}
 
Example 15
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Evaluates this mining model as a regression model.
 *
 * @param valueFactory The value factory used for the aggregation.
 * @param context The evaluation context; cast to {@link MiningModelEvaluationContext}.
 * @return The map from {@code getSegmentationResult} if it is non-{@code null};
 * the regression default if the aggregation yields {@code null};
 * otherwise the regression result built from the aggregated value.
 * @throws InvalidAttributeException If the missing threshold is outside of the range [0, 1],
 * or if the multiple model method is one of those rejected in the switch below.
 * @throws UnsupportedAttributeException If the multiple model method is not handled at all.
 */
@Override
protected <V extends Number> Map<FieldName, ?> evaluateRegression(ValueFactory<V> valueFactory, EvaluationContext context){
	MiningModel miningModel = getModel();

	List<SegmentResult> segmentResults = evaluateSegmentation((MiningModelEvaluationContext)context);

	// A non-null map means that the result has been determined without aggregation
	Map<FieldName, ?> predictions = getSegmentationResult(REGRESSION_METHODS, segmentResults);
	if(predictions != null){
		return predictions;
	}

	TargetField targetField = getTargetField();

	Segmentation segmentation = miningModel.getSegmentation();

	Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
	Segmentation.MissingPredictionTreatment missingPredictionTreatment = segmentation.getMissingPredictionTreatment();
	Number missingThreshold = segmentation.getMissingThreshold();
	if(missingThreshold.doubleValue() < 0d || missingThreshold.doubleValue() > 1d){
		throw new InvalidAttributeException(segmentation, PMMLAttributes.SEGMENTATION_MISSINGTHRESHOLD, missingThreshold);
	}

	Value<V> value;

	switch(multipleModelMethod){
		case AVERAGE:
		case WEIGHTED_AVERAGE:
		case MEDIAN:
		case WEIGHTED_MEDIAN:
		case SUM:
		case WEIGHTED_SUM:
			value = MiningModelUtil.aggregateValues(valueFactory, multipleModelMethod, missingPredictionTreatment, missingThreshold, segmentResults);
			// No aggregate value: fall back to the regression default of the target field
			if(value == null){
				return TargetUtil.evaluateRegressionDefault(valueFactory, targetField);
			}
			break;
		// These methods are treated as invalid for regression
		case MAJORITY_VOTE:
		case WEIGHTED_MAJORITY_VOTE:
		case MAX:
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
			throw new InvalidAttributeException(segmentation, multipleModelMethod);
		default:
			throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
	}

	value = TargetUtil.evaluateRegressionInternal(targetField, value);

	// The result exposes the segment results that it was computed from
	Regression<V> result = new MiningScore<V>(value){

		@Override
		public Collection<? extends SegmentResult> getSegmentResults(){
			return segmentResults;
		}
	};

	return TargetUtil.evaluateRegression(targetField, result);
}
 
Example 16
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Evaluates this mining model as a clustering model.
 *
 * @param valueFactory The value factory used for the aggregation.
 * @param context The evaluation context; cast to {@link MiningModelEvaluationContext}.
 * @return The map from {@code getSegmentationResult} if it is non-{@code null};
 * a singleton map from the target name to {@code null} if the vote aggregation yields {@code null};
 * otherwise a singleton map from the target name to the vote distribution.
 * @throws InvalidAttributeException If the missing threshold is outside of the range [0, 1],
 * or if the multiple model method is one of those rejected in the switch below.
 * @throws UnsupportedAttributeException If the multiple model method is not handled at all.
 */
@Override
protected <V extends Number> Map<FieldName, ?> evaluateClustering(ValueFactory<V> valueFactory, EvaluationContext context){
	MiningModel miningModel = getModel();

	List<SegmentResult> segmentResults = evaluateSegmentation((MiningModelEvaluationContext)context);

	// A non-null map means that the result has been determined without aggregation
	Map<FieldName, ?> predictions = getSegmentationResult(CLUSTERING_METHODS, segmentResults);
	if(predictions != null){
		return predictions;
	}

	Segmentation segmentation = miningModel.getSegmentation();

	Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
	Segmentation.MissingPredictionTreatment missingPredictionTreatment = segmentation.getMissingPredictionTreatment();
	Number missingThreshold = segmentation.getMissingThreshold();
	if(missingThreshold.doubleValue() < 0d || missingThreshold.doubleValue() > 1d){
		throw new InvalidAttributeException(segmentation, PMMLAttributes.SEGMENTATION_MISSINGTHRESHOLD, missingThreshold);
	}

	MiningVoteDistribution<V> result;

	switch(multipleModelMethod){
		case MAJORITY_VOTE:
		case WEIGHTED_MAJORITY_VOTE:
			{
				ValueMap<Object, V> values = MiningModelUtil.aggregateVotes(valueFactory, multipleModelMethod, missingPredictionTreatment, missingThreshold, segmentResults);
				// No aggregate votes: the result is a missing value
				if(values == null){
					return Collections.singletonMap(getTargetName(), null);
				}

				// The result exposes the segment results that it was computed from
				result = new MiningVoteDistribution<V>(values){

					@Override
					public Collection<? extends SegmentResult> getSegmentResults(){
						return segmentResults;
					}
				};
			}
			break;
		// These methods are treated as invalid for clustering
		case AVERAGE:
		case WEIGHTED_AVERAGE:
		case MEDIAN:
		case WEIGHTED_MEDIAN:
		case MAX:
		case SUM:
		case WEIGHTED_SUM:
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
			throw new InvalidAttributeException(segmentation, multipleModelMethod);
		default:
			throw new UnsupportedAttributeException(segmentation, multipleModelMethod);
	}

	result.computeResult(DataType.STRING);

	return Collections.singletonMap(getTargetName(), result);
}
 
Example 17
Source File: MiningModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Aggregates the votes of segment results.
 *
 * @param valueFactory The value factory, passed to the aggregator and to the missing fraction tracker.
 * @param multipleModelMethod The aggregation method; either {@code MAJORITY_VOTE} or {@code WEIGHTED_MAJORITY_VOTE}.
 * @param missingPredictionTreatment The treatment of segment results whose decoded target value is {@code null}.
 * @param missingThreshold The threshold that is passed to the missing fraction tracker.
 * @param segmentResults The segment results to aggregate.
 * @return The vote sums per target value, or {@code null} if the missing prediction treatment yields a missing result.
 * @throws IllegalArgumentException If the aggregation method or the missing prediction treatment is not handled.
 */
static
public <V extends Number> ValueMap<Object, V> aggregateVotes(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<SegmentResult> segmentResults){
	VoteAggregator<Object, V> aggregator = new VoteAggregator<>(valueFactory);

	// Created lazily, on the first missing prediction
	Fraction<V> missingFraction = null;

	segmentResults:
	for(SegmentResult segmentResult : segmentResults){
		Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());

		if(targetValue == null){

			// Stage one: both SKIP_SEGMENT and CONTINUE track the missing prediction
			switch(missingPredictionTreatment){
				case RETURN_MISSING:
					return null;
				case SKIP_SEGMENT:
				case CONTINUE:
					if(missingFraction == null){
						missingFraction = new Fraction<>(valueFactory, segmentResults);
					} // End if

					// NOTE(review): update() presumably reports that the missing threshold was exceeded - confirm against Fraction
					if(missingFraction.update(segmentResult, missingThreshold)){
						return null;
					}
					break;
				default:
					throw new IllegalArgumentException();
			} // End switch

			// Stage two: SKIP_SEGMENT drops the segment, whereas CONTINUE lets the null target value be voted for below
			switch(missingPredictionTreatment){
				case SKIP_SEGMENT:
					continue segmentResults;
				case CONTINUE:
					break;
				default:
					throw new IllegalArgumentException();
			}
		}

		switch(multipleModelMethod){
			case MAJORITY_VOTE:
				aggregator.add(targetValue);
				break;
			case WEIGHTED_MAJORITY_VOTE:
				Number weight = segmentResult.getWeight();

				aggregator.add(targetValue, weight);
				break;
			default:
				throw new IllegalArgumentException();
		}
	}

	ValueMap<Object, V> result = aggregator.sumMap();

	switch(missingPredictionTreatment){
		case CONTINUE:
			// Remove the "missing" pseudo-category
			Value<V> missingVoteSum = result.remove(null);

			if(missingVoteSum != null){
				Collection<Value<V>> voteSums = result.values();

				// "The missing result is returned if it gets the most (possibly weighted) votes"
				// The comparison is strict, so a tie with the best non-missing category does not yield a missing result
				if(!voteSums.isEmpty() && (missingVoteSum).compareTo(Collections.max(voteSums)) > 0){
					return null;
				}
			}
			break;
		default:
			break;
	}

	return result;
}
 
Example 18
Source File: MiningModelEvaluator.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Creates and configures the evaluator of a segment model.
 *
 * The evaluator is obtained from the model evaluator factory of the configuration.
 * For the {@code SELECT_FIRST}, {@code SELECT_ALL} and {@code MODEL_CHAIN} methods it receives the result features
 * of this evaluator; for every method it receives the result features that are specific to the segment, if any.
 *
 * If both this mining model and the segment model are classification models with exactly one target field each,
 * and the target field of the segment is a {@link DefaultTargetField}, then the segment evaluator is given
 * a categorical default data field that has the data type of the target field of this mining model.
 *
 * @param segmentId The identifier of the segment; used for looking up the segment result features.
 * @param model The model of the segment.
 * @return The configured evaluator of the segment model.
 */
private ModelEvaluator<?> createSegmentModelEvaluator(String segmentId, Model model){
	MiningModel miningModel = getModel();

	MiningFunction miningFunction = miningModel.getMiningFunction();

	Segmentation segmentation = miningModel.getSegmentation();

	Configuration configuration = ensureConfiguration();

	ModelEvaluatorFactory modelEvaluatorFactory = configuration.getModelEvaluatorFactory();

	ModelEvaluator<?> modelEvaluator = modelEvaluatorFactory.newModelEvaluator(getPMML(), model);

	Segmentation.MultipleModelMethod multipleModelMethod = segmentation.getMultipleModelMethod();
	switch(multipleModelMethod){
		case SELECT_FIRST:
		case SELECT_ALL:
		case MODEL_CHAIN:
			{
				Set<ResultFeature> resultFeatures = getResultFeatures();

				if(!resultFeatures.isEmpty()){
					modelEvaluator.addResultFeatures(resultFeatures);
				}
			}
			// Falls through
		default:
			{
				Set<ResultFeature> segmentResultFeatures = getSegmentResultFeatures(segmentId);

				if(segmentResultFeatures != null && !segmentResultFeatures.isEmpty()){
					modelEvaluator.addResultFeatures(segmentResultFeatures);
				}
			}
			break;
	}

	MiningFunction segmentMiningFunction = model.getMiningFunction();

	if((MiningFunction.CLASSIFICATION).equals(miningFunction) && (MiningFunction.CLASSIFICATION).equals(segmentMiningFunction)){
		List<TargetField> targetFields = getTargetFields();
		List<TargetField> segmentTargetFields = modelEvaluator.getTargetFields();

		if(targetFields.size() == 1 && segmentTargetFields.size() == 1){
			TargetField targetField = targetFields.get(0);
			TargetField segmentTargetField = segmentTargetFields.get(0);

			// Only the type of the segment target field matters here; the field itself is not read
			if(segmentTargetField instanceof DefaultTargetField){
				modelEvaluator.setDefaultDataField(new DataField(Evaluator.DEFAULT_TARGET_NAME, OpType.CATEGORICAL, targetField.getDataType()));
			}
		}
	}

	modelEvaluator.configure(configuration);

	return modelEvaluator;
}
 
Example 19
Source File: VotingClassifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Encodes this voting classifier as a mining model with a probability output.
 *
 * Every estimator is encoded against the same schema.
 * The weighted variant of the multiple model method is requested if the weights are non-{@code null} and non-empty.
 *
 * @param schema The schema; its label is cast to {@link CategoricalLabel}.
 * @return The encoded mining model.
 */
@Override
public Model encodeModel(Schema schema){
	List<? extends Classifier> estimators = getEstimators();
	List<? extends Number> weights = getWeights();

	CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel();

	List<Model> memberModels = new ArrayList<>();

	for(Classifier estimator : estimators){
		memberModels.add(estimator.encodeModel(schema));
	}

	String votingMode = getVoting();

	boolean weighted = (weights != null && !weights.isEmpty());

	Segmentation.MultipleModelMethod multipleModelMethod = parseVoting(votingMode, weighted);

	return new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
		.setSegmentation(MiningModelUtil.createSegmentation(multipleModelMethod, memberModels, weights))
		.setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));
}
 
Example 20
Source File: MiningModelUtil.java    From jpmml-evaluator with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Aggregates the numeric target values of segment results.
 *
 * @param valueFactory The value factory, passed to the aggregator and to the missing fraction tracker.
 * @param multipleModelMethod The aggregation method; one of {@code AVERAGE}, {@code WEIGHTED_AVERAGE},
 * {@code SUM}, {@code WEIGHTED_SUM}, {@code MEDIAN} or {@code WEIGHTED_MEDIAN}.
 * @param missingPredictionTreatment The treatment of segment results whose decoded target value is {@code null}.
 * @param missingThreshold The threshold that is passed to the missing fraction tracker.
 * @param segmentResults The segment results to aggregate.
 * @return The aggregated value, or {@code null} if the missing prediction treatment aborted the aggregation.
 * @throws IllegalArgumentException If the aggregation method or the missing prediction treatment is not handled.
 */
static
public <V extends Number> Value<V> aggregateValues(ValueFactory<V> valueFactory, Segmentation.MultipleModelMethod multipleModelMethod, Segmentation.MissingPredictionTreatment missingPredictionTreatment, Number missingThreshold, List<SegmentResult> segmentResults){
	ValueAggregator<V> aggregator;

	// Choose the aggregator implementation that matches the method
	switch(multipleModelMethod){
		case AVERAGE:
		case SUM:
			aggregator = new ValueAggregator.UnivariateStatistic<>(valueFactory);
			break;
		case MEDIAN:
			aggregator = new ValueAggregator.Median<>(valueFactory, segmentResults.size());
			break;
		case WEIGHTED_AVERAGE:
		case WEIGHTED_SUM:
			aggregator = new ValueAggregator.WeightedUnivariateStatistic<>(valueFactory);
			break;
		case WEIGHTED_MEDIAN:
			aggregator = new ValueAggregator.WeightedMedian<>(valueFactory, segmentResults.size());
			break;
		default:
			throw new IllegalArgumentException();
	}

	// Created lazily, on the first missing prediction under SKIP_SEGMENT
	Fraction<V> missingFraction = null;

	segmentResults:
	for(SegmentResult segmentResult : segmentResults){
		Object targetValue = EvaluatorUtil.decode(segmentResult.getTargetValue());

		if(targetValue == null){

			// Every branch below leaves this iteration: by returning, by continuing, or by throwing
			switch(missingPredictionTreatment){
				case RETURN_MISSING:
					return null;
				case SKIP_SEGMENT:
					if(missingFraction == null){
						missingFraction = new Fraction<>(valueFactory, segmentResults);
					} // End if

					// NOTE(review): update() presumably reports that the missing threshold was exceeded - confirm against Fraction
					if(missingFraction.update(segmentResult, missingThreshold)){
						return null;
					}

					continue segmentResults;
				case CONTINUE:
					return null;
				default:
					throw new IllegalArgumentException();
			}
		}

		Number value;

		// Numbers are used as-is; anything else goes through a cast to the double data type.
		// A type mismatch is re-thrown with the offending segment as context
		try {
			if(targetValue instanceof Number){
				value = (Number)targetValue;
			} else

			{
				value = (Number)TypeUtil.cast(DataType.DOUBLE, targetValue);
			}
		} catch(TypeCheckException tce){
			throw tce.ensureContext(segmentResult.getSegment());
		}

		switch(multipleModelMethod){
			case AVERAGE:
			case SUM:
			case MEDIAN:
				aggregator.add(value);
				break;
			case WEIGHTED_AVERAGE:
			case WEIGHTED_SUM:
			case WEIGHTED_MEDIAN:
				Number weight = segmentResult.getWeight();

				aggregator.add(value, weight);
				break;
			default:
				throw new IllegalArgumentException();
		}
	}

	// Extract the aggregate that matches the method
	switch(multipleModelMethod){
		case AVERAGE:
			return aggregator.average();
		case WEIGHTED_AVERAGE:
			return aggregator.weightedAverage();
		case SUM:
			return aggregator.sum();
		case WEIGHTED_SUM:
			return aggregator.weightedSum();
		case MEDIAN:
			return aggregator.median();
		case WEIGHTED_MEDIAN:
			return aggregator.weightedMedian();
		default:
			throw new IllegalArgumentException();
	}
}