/* * Copyright (c) 2015 Villu Ruusmann * * This file is part of JPMML-R * * JPMML-R is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-R is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-R. If not, see <http://www.gnu.org/licenses/>. */ package org.jpmml.rexp; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.OpType; import org.dmg.pmml.Output; import org.dmg.pmml.Predicate; import org.dmg.pmml.ScoreDistribution; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; import org.dmg.pmml.tree.BranchNode; import org.dmg.pmml.tree.ClassifierNode; import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.Feature; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.Schema; import org.jpmml.converter.SchemaUtil; public class BinaryTreeConverter extends TreeModelConverter<S4Object> { private MiningFunction miningFunction = null; private Map<FieldName, Integer> featureIndexes = new LinkedHashMap<>(); public BinaryTreeConverter(S4Object binaryTree){ super(binaryTree); } @Override public void encodeSchema(RExpEncoder encoder){ S4Object binaryTree = getObject(); S4Object responses = (S4Object)binaryTree.getAttribute("responses"); RGenericVector tree = binaryTree.getGenericAttribute("tree"); encodeResponse(responses, encoder); encodeVariableList(tree, encoder); } @Override public TreeModel encodeModel(Schema schema){ S4Object binaryTree = getObject(); RGenericVector tree = binaryTree.getGenericAttribute("tree"); Output output; switch(this.miningFunction){ case REGRESSION: output = new Output(); break; case CLASSIFICATION: CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); output = ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel); break; default: throw new IllegalArgumentException(); } output.addOutputFields(ModelUtil.createEntityIdField(FieldName.create("nodeId"))); TreeModel treeModel = encodeTreeModel(tree, schema) .setOutput(output); return treeModel; } private void encodeResponse(S4Object responses, RExpEncoder encoder){ RGenericVector variables = responses.getGenericAttribute("variables"); RBooleanVector is_nominal = responses.getBooleanAttribute("is_nominal"); RGenericVector levels = responses.getGenericAttribute("levels"); RStringVector variableNames = variables.names(); String variableName = variableNames.asScalar(); DataField dataField; Boolean categorical = is_nominal.getElement(variableName); if((Boolean.TRUE).equals(categorical)){ this.miningFunction = MiningFunction.CLASSIFICATION; RExp targetVariable = variables.getElement(variableName); RStringVector targetVariableClass = RExpUtil.getClassNames(targetVariable); RStringVector targetCategories = levels.getStringElement(variableName); dataField = encoder.createDataField(FieldName.create(variableName), OpType.CATEGORICAL, RExpUtil.getDataType(targetVariableClass.asScalar()), targetCategories.getValues()); } else if((Boolean.FALSE).equals(categorical)){ this.miningFunction = MiningFunction.REGRESSION; dataField = encoder.createDataField(FieldName.create(variableName), OpType.CONTINUOUS, DataType.DOUBLE); } else { throw new IllegalArgumentException(); } encoder.setLabel(dataField); } private void encodeVariableList(RGenericVector tree, RExpEncoder encoder){ RBooleanVector terminal = tree.getBooleanElement("terminal"); RGenericVector psplit = tree.getGenericElement("psplit"); RGenericVector left = tree.getGenericElement("left"); RGenericVector right = tree.getGenericElement("right"); if((Boolean.TRUE).equals(terminal.asScalar())){ return; } RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint"); RStringVector variableName = psplit.getStringElement("variableName"); FieldName name = FieldName.create(variableName.asScalar()); DataField dataField = encoder.getDataField(name); if(dataField == null){ if(splitpoint instanceof RIntegerVector){ RStringVector levels = splitpoint.getStringAttribute("levels"); dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues()); } else if(splitpoint instanceof RDoubleVector){ dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE); } else { throw new IllegalArgumentException(); } encoder.addFeature(dataField); this.featureIndexes.put(name, this.featureIndexes.size()); } encodeVariableList(left, encoder); encodeVariableList(right, encoder); } private TreeModel encodeTreeModel(RGenericVector tree, Schema schema){ Node root = encodeNode(True.INSTANCE, tree, schema); TreeModel treeModel = new TreeModel(this.miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT); return treeModel; } private Node encodeNode(Predicate predicate, RGenericVector tree, Schema schema){ RIntegerVector nodeId = tree.getIntegerElement("nodeID"); RBooleanVector terminal = tree.getBooleanElement("terminal"); RGenericVector psplit = tree.getGenericElement("psplit"); RGenericVector ssplits = tree.getGenericElement("ssplits"); RDoubleVector prediction = tree.getDoubleElement("prediction"); RGenericVector left = tree.getGenericElement("left"); RGenericVector right = tree.getGenericElement("right"); Integer id = nodeId.asScalar(); if((Boolean.TRUE).equals(terminal.asScalar())){ Node result = new LeafNode(null, predicate) .setId(id); return encodeScore(result, prediction, schema); } RNumberVector<?> splitpoint = psplit.getNumericElement("splitpoint"); RStringVector variableName = psplit.getStringElement("variableName"); if(ssplits.size() > 0){ throw new IllegalArgumentException(); } Predicate leftPredicate; Predicate rightPredicate; FieldName name = FieldName.create(variableName.asScalar()); Integer index = this.featureIndexes.get(name); if(index == null){ throw new IllegalArgumentException(); } Feature feature = schema.getFeature(index); if(feature instanceof CategoricalFeature){ CategoricalFeature categoricalFeature = (CategoricalFeature)feature; List<?> values = categoricalFeature.getValues(); List<Integer> splitValues = (List<Integer>)splitpoint.getValues(); leftPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, true)); rightPredicate = createSimpleSetPredicate(categoricalFeature, selectValues(values, splitValues, false)); } else { ContinuousFeature continuousFeature = feature.toContinuousFeature(); Number value = splitpoint.asScalar(); leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, value); rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, value); } Node leftChild = encodeNode(leftPredicate, left, schema); Node rightChild = encodeNode(rightPredicate, right, schema); Node result = new BranchNode(null, predicate) .setId(id) .addNodes(leftChild, rightChild); return result; } private Node encodeScore(Node node, RDoubleVector probabilities, Schema schema){ switch(this.miningFunction){ case REGRESSION: return encodeRegressionScore(node, probabilities); case CLASSIFICATION: return encodeClassificationScore(node, probabilities, schema); default: throw new IllegalArgumentException(); } } static private <E> List<E> selectValues(List<E> values, List<Integer> splits, boolean left){ if(values.size() != splits.size()){ throw new IllegalArgumentException(); } List<E> result = new ArrayList<>(); for(int i = 0; i < values.size(); i++){ E value = values.get(i); Integer split = splits.get(i); boolean append; if(left){ append = (split == 1); } else { append = (split == 0); } // End if if(append){ result.add(value); } } return result; } static private Node encodeRegressionScore(Node node, RDoubleVector probabilities){ if(probabilities.size() != 1){ throw new IllegalArgumentException(); } Double probability = probabilities.asScalar(); node.setScore(probability); return node; } static private Node encodeClassificationScore(Node node, RDoubleVector probabilities, Schema schema){ CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); SchemaUtil.checkSize(probabilities.size(), categoricalLabel); node = new ClassifierNode(node); List<ScoreDistribution> scoreDistributions = node.getScoreDistributions(); Double maxProbability = null; for(int i = 0; i < categoricalLabel.size(); i++){ Object value = categoricalLabel.getValue(i); Double probability = probabilities.getValue(i); if(maxProbability == null || (maxProbability).compareTo(probability) < 0){ node.setScore(value); maxProbability = probability; } ScoreDistribution scoreDistribution = new ScoreDistribution(value, probability); scoreDistributions.add(scoreDistribution); } return node; } }