Java Code Examples for org.dmg.pmml.MiningFunction#CLASSIFICATION

The following examples show how to use org.dmg.pmml.MiningFunction#CLASSIFICATION. You can vote up the examples you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You can also check out the related API usage in the sidebar.
Example 1
Source File: GLMConverter.java    From jpmml-r with GNU Affero General Public License v3.0 6 votes vote down vote up
/**
 * Maps a family name to the mining function of the resulting model.
 *
 * <p>
 * The family name is first resolved to a distribution via {@code parseFamily}.
 * {@code BINOMIAL} maps to classification; {@code NORMAL}, {@code GAMMA},
 * {@code IGAUSS} and {@code POISSON} map to regression.
 * </p>
 *
 * @param family The family name, as accepted by {@code parseFamily}.
 *
 * @return The mining function that corresponds to the resolved distribution.
 *
 * @throws IllegalArgumentException If the resolved distribution is none of the five listed above.
 */
static
private MiningFunction getMiningFunction(String family){
	GeneralRegressionModel.Distribution distribution = parseFamily(family);

	switch(distribution){
		case BINOMIAL:
			return MiningFunction.CLASSIFICATION;
		case NORMAL:
		case GAMMA:
		case IGAUSS:
		case POISSON:
			return MiningFunction.REGRESSION;
		default:
			// Name the offending input, so that the failure can be diagnosed from the exception alone
			throw new IllegalArgumentException("Unsupported family " + family + " (distribution " + distribution + ")");
	}
}
 
Example 2
Source File: GeneralizedLinearRegressionModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Derives the mining function from the family of the underlying model.
 *
 * @return {@link MiningFunction#CLASSIFICATION} when the family is {@code "binomial"},
 * {@link MiningFunction#REGRESSION} for every other family.
 */
@Override
public MiningFunction getMiningFunction(){
	GeneralizedLinearRegressionModel glm = getTransformer();

	// Deliberately not null-safe: a null family fails fast with a NullPointerException
	boolean binomial = glm.getFamily().equals("binomial");

	return (binomial ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION);
}
 
Example 3
Source File: BaseEstimator.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Derives the mining function from the estimator type.
 *
 * @return {@link MiningFunction#CLASSIFICATION} for a {@code "classifier"},
 * {@link MiningFunction#REGRESSION} for a {@code "regressor"}.
 *
 * @throws IllegalArgumentException If the estimator type is any other value; the message is the estimator type itself.
 */
@Override
public MiningFunction getMiningFunction(){
	String type = getEstimatorType();

	// Deliberately not null-safe: a null estimator type fails fast with a NullPointerException
	if(type.equals("classifier")){
		return MiningFunction.CLASSIFICATION;
	} else

	if(type.equals("regressor")){
		return MiningFunction.REGRESSION;
	}

	throw new IllegalArgumentException(type);
}
 
Example 4
Source File: RuleSetClassifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 5 votes vote down vote up
/**
 * Encodes the rules of this classifier as a first-hit rule set model.
 *
 * <p>
 * Every rule is a two-element tuple: a predicate expression (element 0) and a score (element 1).
 * Predicate expressions are translated against a data frame scope named {@code X} that exposes the schema features.
 * If a default score is available, then it is registered with a default confidence of {@code 1}.
 * </p>
 *
 * @param schema Supplies the label and the features.
 *
 * @return A classification-type rule set model.
 */
@Override
public RuleSetModel encodeModel(Schema schema){
	String fallbackScore = getDefaultScore();
	List<Object[]> ruleTuples = getRules();

	Label label = schema.getLabel();
	List<? extends Feature> features = schema.getFeatures();

	RuleSelectionMethod firstHit = new RuleSelectionMethod(RuleSelectionMethod.Criterion.FIRST_HIT);

	RuleSet ruleSet = new RuleSet()
		.addRuleSelectionMethods(firstHit);

	if(fallbackScore != null){
		ruleSet
			.setDefaultConfidence(1d)
			.setDefaultScore(fallbackScore);
	}

	Scope scope = new DataFrameScope(FieldName.create("X"), features);

	for(Object[] ruleTuple : ruleTuples){
		String expression = TupleUtil.extractElement(ruleTuple, 0, String.class);
		String score = TupleUtil.extractElement(ruleTuple, 1, String.class);

		Predicate translatedPredicate = PredicateTranslator.translate(expression, scope);

		ruleSet.addRules(new SimpleRule(score, translatedPredicate));
	}

	return new RuleSetModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(label), ruleSet);
}
 
Example 5
Source File: AppPMMLUtilsTest.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a skeleton PMML document that holds one classification tree model,
 * whose only node is a leaf with a record count of 123.
 *
 * @return the populated PMML document
 */
private static PMML buildDummyModel() {
  TreeModel singleLeafTree = new TreeModel(
      MiningFunction.CLASSIFICATION,
      null,
      new CountingLeafNode().setRecordCount(123.0));

  PMML skeleton = PMMLUtils.buildSkeletonPMML();
  skeleton.addModels(singleLeafTree);
  return skeleton;
}
 
Example 6
Source File: PMMLUtilsTest.java    From oryx with Apache License 2.0 5 votes vote down vote up
/**
 * Builds a skeleton PMML document that holds one classification tree model,
 * whose only node is a leaf with a record count of 123.
 *
 * @return the populated PMML document
 */
public static PMML buildDummyModel() {
  TreeModel singleLeafTree = new TreeModel(
      MiningFunction.CLASSIFICATION,
      null,
      new CountingLeafNode().setRecordCount(123.0));

  PMML skeleton = PMMLUtils.buildSkeletonPMML();
  skeleton.addModels(singleLeafTree);
  return skeleton;
}
 
Example 7
Source File: ClassificationModelConverter.java    From jpmml-sparkml with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Always reports the classification mining function.
 *
 * @return {@link MiningFunction#CLASSIFICATION}
 */
@Override
public MiningFunction getMiningFunction(){
	return MiningFunction.CLASSIFICATION;
}
 
Example 8
Source File: Classifier.java    From jpmml-sklearn with GNU Affero General Public License v3.0 4 votes vote down vote up
/**
 * Always reports the classification mining function.
 *
 * @return {@link MiningFunction#CLASSIFICATION}
 */
@Override
public MiningFunction getMiningFunction(){
	return MiningFunction.CLASSIFICATION;
}
 
Example 9
Source File: RDFPMMLUtilsTest.java    From oryx with Apache License 2.0 4 votes vote down vote up
/**
 * Builds a small PMML classification document for tests.
 *
 * <p>The document declares a categorical "color" predictor (yellow, red) and a categorical
 * "fruit" target (banana, apple). The decision tree has a root with two children: "r+"
 * scores "banana" when color is not in {red}, and "r-" scores "apple" unconditionally.
 *
 * <p>When {@code numTrees} is greater than 1, the document holds a mining model with
 * {@code numTrees} equally weighted segments combined by weighted majority vote; every segment
 * refers to the same tree model instance. Otherwise the document holds the tree model directly.
 *
 * @param numTrees number of segments to generate; values of 1 or less yield a bare tree model
 * @return the populated PMML document
 */
private static PMML buildDummyClassificationModel(int numTrees) {
  PMML pmml = PMMLUtils.buildSkeletonPMML();

  // Data dictionary: one predictor and one target, both categorical strings
  DataField colorField =
      new DataField(FieldName.create("color"), OpType.CATEGORICAL, DataType.STRING);
  colorField.addValues(new Value("yellow"), new Value("red"));
  DataField fruitField =
      new DataField(FieldName.create("fruit"), OpType.CATEGORICAL, DataType.STRING);
  fruitField.addValues(new Value("banana"), new Value("apple"));

  List<DataField> fields = new ArrayList<>();
  fields.add(colorField);
  fields.add(fruitField);
  pmml.setDataDictionary(new DataDictionary(fields).setNumberOfFields(fields.size()));

  // Mining schema: "color" is the active input, "fruit" is the predicted output
  List<MiningField> schemaFields = new ArrayList<>();
  schemaFields.add(new MiningField(FieldName.create("color"))
      .setOpType(OpType.CATEGORICAL)
      .setUsageType(MiningField.UsageType.ACTIVE)
      .setImportance(0.5));
  schemaFields.add(new MiningField(FieldName.create("fruit"))
      .setOpType(OpType.CATEGORICAL)
      .setUsageType(MiningField.UsageType.PREDICTED));
  MiningSchema miningSchema = new MiningSchema(schemaFields);

  // Tree: the root's record count is split evenly between its two children
  double totalCount = 2.0;
  double perChildCount = totalCount / 2;

  Node root =
      new ComplexNode().setId("r").setRecordCount(totalCount).setPredicate(new True());

  Node appleNode =
      new ComplexNode().setId("r-").setRecordCount(perChildCount).setPredicate(new True());
  appleNode.addScoreDistributions(new ScoreDistribution("apple", perChildCount));

  Node bananaNode = new ComplexNode().setId("r+").setRecordCount(perChildCount)
      .setPredicate(new SimpleSetPredicate(FieldName.create("color"),
                                           SimpleSetPredicate.BooleanOperator.IS_NOT_IN,
                                           new Array(Array.Type.STRING, "red")));
  bananaNode.addScoreDistributions(new ScoreDistribution("banana", perChildCount));

  // The "r+" child is registered ahead of the "r-" child
  root.addNodes(bananaNode, appleNode);

  TreeModel tree = new TreeModel(MiningFunction.CLASSIFICATION, miningSchema, root)
      .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT)
      .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD);

  if (numTrees <= 1) {
    pmml.addModels(tree);
    return pmml;
  }

  MiningModel forest = new MiningModel(MiningFunction.CLASSIFICATION, miningSchema);
  List<Segment> segments = new ArrayList<>();
  for (int treeIndex = 0; treeIndex < numTrees; treeIndex++) {
    // Every segment shares the one tree model instance
    segments.add(new Segment()
        .setId(Integer.toString(treeIndex))
        .setPredicate(new True())
        .setModel(tree)
        .setWeight(1.0));
  }
  forest.setSegmentation(
      new Segmentation(Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE, segments));
  pmml.addModels(forest);

  return pmml;
}
 
Example 10
Source File: BinaryTreeConverter.java    From jpmml-r with GNU Affero General Public License v3.0 3 votes vote down vote up
/**
 * Declares the response variable as the label of the encoder, and records the matching mining function.
 *
 * <p>
 * The variable name is the scalar value of the names of the {@code variables} attribute.
 * If the {@code is_nominal} flag of that variable is {@code TRUE}, then the mining function is set to classification,
 * and a categorical data field is created from the {@code levels} attribute.
 * If the flag is {@code FALSE}, then the mining function is set to regression,
 * and a continuous double data field is created.
 * </p>
 *
 * @param responses The S4 object that provides the {@code variables}, {@code is_nominal} and {@code levels} attributes.
 * @param encoder The encoder that creates the data field and receives the label.
 *
 * @throws IllegalArgumentException If the {@code is_nominal} flag of the response variable is neither {@code TRUE} nor {@code FALSE}.
 */
private void encodeResponse(S4Object responses, RExpEncoder encoder){
	RGenericVector variables = responses.getGenericAttribute("variables");
	RBooleanVector is_nominal = responses.getBooleanAttribute("is_nominal");
	RGenericVector levels = responses.getGenericAttribute("levels");

	RStringVector variableNames = variables.names();

	// NOTE(review): asScalar() presumably requires exactly one response variable - confirm
	String variableName = variableNames.asScalar();

	DataField dataField;

	Boolean categorical = is_nominal.getElement(variableName);
	if((Boolean.TRUE).equals(categorical)){
		this.miningFunction = MiningFunction.CLASSIFICATION;

		RExp targetVariable = variables.getElement(variableName);

		RStringVector targetVariableClass = RExpUtil.getClassNames(targetVariable);

		RStringVector targetCategories = levels.getStringElement(variableName);

		dataField = encoder.createDataField(FieldName.create(variableName), OpType.CATEGORICAL, RExpUtil.getDataType(targetVariableClass.asScalar()), targetCategories.getValues());
	} else

	if((Boolean.FALSE).equals(categorical)){
		this.miningFunction = MiningFunction.REGRESSION;

		dataField = encoder.createDataField(FieldName.create(variableName), OpType.CONTINUOUS, DataType.DOUBLE);
	} else

	{
		// Name the variable and the flag value, so that the failure can be traced back to the input
		throw new IllegalArgumentException("Expected the is_nominal flag of response variable " + variableName + " to be TRUE or FALSE, got " + categorical);
	}

	encoder.setLabel(dataField);
}
 
Example 11
Source File: TreePathFinderTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 3 votes vote down vote up
/**
 * Checks that {@link TreePathFinder} reports one path per leaf node, each running from the root down to that leaf.
 *
 * <pre>
 * root
 *  +- shallowLeaf
 *  +- middleBranch
 *  |   +- innerBranch
 *  |       +- deepLeaf
 *  +- lastBranch
 *      +- lastLeaf
 * </pre>
 */
@Test
public void find(){
	Node root = new BranchNode();

	Node shallowLeaf = new LeafNode();
	Node middleBranch = new BranchNode();
	Node lastBranch = new BranchNode();

	root.addNodes(shallowLeaf, middleBranch, lastBranch);

	Node innerBranch = new BranchNode();
	Node lastLeaf = new LeafNode();

	middleBranch.addNodes(innerBranch);
	lastBranch.addNodes(lastLeaf);

	Node deepLeaf = new LeafNode();

	innerBranch.addNodes(deepLeaf);

	TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, new MiningSchema(), root);

	TreePathFinder pathFinder = new TreePathFinder();
	pathFinder.applyTo(treeModel);

	Map<Node, List<Node>> pathsByLeaf = pathFinder.getPaths();

	// Three leaves, three paths
	assertEquals(3, pathsByLeaf.size());

	List<Node> expectedShallowPath = Arrays.asList(root, shallowLeaf);
	List<Node> expectedDeepPath = Arrays.asList(root, middleBranch, innerBranch, deepLeaf);
	List<Node> expectedLastPath = Arrays.asList(root, lastBranch, lastLeaf);

	assertEquals(expectedShallowPath, pathsByLeaf.get(shallowLeaf));
	assertEquals(expectedDeepPath, pathsByLeaf.get(deepLeaf));
	assertEquals(expectedLastPath, pathsByLeaf.get(lastLeaf));
}
 
Example 12
Source File: ArrayListTransformerTest.java    From jpmml-model with BSD 3-Clause "New" or "Revised" License 3 votes vote down vote up
/**
 * Checks that {@link ArrayListTransformer} replaces the child node lists of a tree model with
 * {@code DoubletonList} (two children) and {@code SingletonList} (one child) instances,
 * while the value of a complex array stays both an {@code ArrayList} and a {@code ComplexValue}.
 */
@Test
public void transform(){
	Node root = new BranchNode();

	Node innerBranch = new BranchNode();
	Node siblingLeaf = new LeafNode();

	root.addNodes(innerBranch, siblingLeaf);

	Array intArray = new ComplexArray()
		.setType(Array.Type.INT)
		.setValue(Arrays.asList(-1, 1));

	Predicate membership = new SimpleSetPredicate(FieldName.create("x"), SimpleSetPredicate.BooleanOperator.IS_IN, intArray);

	Node guardedLeaf = new LeafNode(null, membership);

	innerBranch.addNodes(guardedLeaf);

	// Before the transformation: plain ArrayList instances everywhere
	assertTrue(root.getNodes() instanceof ArrayList);
	assertTrue(innerBranch.getNodes() instanceof ArrayList);

	Object valueBefore = intArray.getValue();

	assertTrue(valueBefore instanceof ArrayList);
	assertTrue(valueBefore instanceof ComplexValue);

	TreeModel treeModel = new TreeModel(MiningFunction.CLASSIFICATION, new MiningSchema(), root);

	ArrayListTransformer listTransformer = new ArrayListTransformer();
	listTransformer.applyTo(treeModel);

	// After the transformation: child lists are sized to fit
	assertTrue(root.getNodes() instanceof DoubletonList);
	assertTrue(innerBranch.getNodes() instanceof SingletonList);

	Object valueAfter = intArray.getValue();

	assertTrue(valueAfter instanceof ArrayList);
	assertTrue(valueAfter instanceof ComplexValue);
}