/* * Copyright (c) 2014 Villu Ruusmann * * This file is part of JPMML-R * * JPMML-R is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-R is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-R. If not, see <http://www.gnu.org/licenses/>. */ package org.jpmml.rexp; import java.util.ArrayList; import java.util.List; import com.google.common.math.DoubleMath; import com.google.common.primitives.UnsignedLong; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.OpType; import org.dmg.pmml.Predicate; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; import org.dmg.pmml.Visitor; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segmentation; import org.dmg.pmml.tree.BranchNode; import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.BooleanFeature; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.CategoryManager; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.Feature; import org.jpmml.converter.FortranMatrixUtil; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.Schema; import org.jpmml.converter.ValueUtil; import org.jpmml.converter.mining.MiningModelUtil; import org.jpmml.rexp.visitors.RandomForestCompactor; public class RandomForestConverter extends TreeModelConverter<RGenericVector> { private boolean compact = true; public RandomForestConverter(RGenericVector randomForest){ super(randomForest); this.compact = getOption("compact", Boolean.TRUE); } @Override public void encodeSchema(RExpEncoder encoder){ RGenericVector randomForest = getObject(); if(randomForest.hasElement("terms")){ encodeFormula(encoder); } else { encodeNonFormula(encoder); } } @Override public MiningModel encodeModel(Schema schema){ RGenericVector randomForest = getObject(); RStringVector type = randomForest.getStringElement("type"); RGenericVector forest = randomForest.getGenericElement("forest"); switch(type.asScalar()){ case "regression": return encodeRegression(forest, schema); case "classification": return encodeClassification(forest, schema); default: throw new IllegalArgumentException(); } } private void encodeFormula(RExpEncoder encoder){ RGenericVector randomForest = getObject(); RGenericVector forest = randomForest.getGenericElement("forest"); RNumberVector<?> y = randomForest.getNumericElement("y", false); RExp terms = randomForest.getElement("terms"); RNumberVector<?> ncat = forest.getNumericElement("ncat"); RGenericVector xlevels = forest.getGenericElement("xlevels"); FormulaContext context = new XLevelsFormulaContext(xlevels){ @Override public List<String> getCategories(String variable){ if(ncat != null && ncat.hasElement(variable)){ if((ncat.getElement(variable)).doubleValue() > 1d){ return super.getCategories(variable); } } return null; } }; Formula formula = FormulaUtil.createFormula(terms, context, encoder); if(y instanceof RIntegerVector){ FormulaUtil.setLabel(formula, terms, y, encoder); } else { FormulaUtil.setLabel(formula, terms, null, encoder); } FormulaUtil.addFeatures(formula, xlevels.names(), false, encoder); } private void encodeNonFormula(RExpEncoder encoder){ RGenericVector randomForest = getObject(); RGenericVector forest = randomForest.getGenericElement("forest"); RNumberVector<?> y = randomForest.getNumericElement("y", false); RStringVector xNames = randomForest.getStringElement("xNames", false); RNumberVector<?> ncat = forest.getNumericElement("ncat"); RGenericVector xlevels = forest.getGenericElement("xlevels"); if(xNames == null){ xNames = xlevels.names(); } { FieldName name = FieldName.create("_target"); DataField dataField; if(y instanceof RIntegerVector){ y = randomForest.getFactorElement("y"); dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, RExpUtil.getFactorLevels((RIntegerVector)y)); } else { dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE); } encoder.setLabel(dataField); } for(int i = 0; i < ncat.size(); i++){ FieldName name = FieldName.create(xNames.getValue(i)); DataField dataField; boolean categorical = ((ncat.getValue(i)).doubleValue() > 1d); if(categorical){ RStringVector levels = (RStringVector)xlevels.getValue(i); dataField = encoder.createDataField(name, OpType.CATEGORICAL, null, levels.getValues()); } else { dataField = encoder.createDataField(name, OpType.CONTINUOUS, DataType.DOUBLE); } encoder.addFeature(dataField); } } private MiningModel encodeRegression(RGenericVector forest, Schema schema){ RNumberVector<?> leftDaughter = forest.getNumericElement("leftDaughter"); RNumberVector<?> rightDaughter = forest.getNumericElement("rightDaughter"); RDoubleVector nodepred = forest.getDoubleElement("nodepred"); RNumberVector<?> bestvar = forest.getNumericElement("bestvar"); RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit"); RIntegerVector nrnodes = forest.getIntegerElement("nrnodes"); RNumberVector<?> ntree = forest.getNumericElement("ntree"); ScoreEncoder<Double> scoreEncoder = new ScoreEncoder<Double>(){ @Override public Double encode(Double value){ return value; } }; int rows = nrnodes.asScalar(); int columns = ValueUtil.asInt(ntree.asScalar()); Schema segmentSchema = schema.toAnonymousSchema(); List<TreeModel> treeModels = new ArrayList<>(); for(int i = 0; i < columns; i++){ TreeModel treeModel = encodeTreeModel( MiningFunction.REGRESSION, scoreEncoder, FortranMatrixUtil.getColumn(leftDaughter.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(rightDaughter.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i), segmentSchema ); treeModels.add(treeModel); } MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.AVERAGE, treeModels)); return miningModel; } private MiningModel encodeClassification(RGenericVector forest, Schema schema){ RNumberVector<?> bestvar = forest.getNumericElement("bestvar"); RNumberVector<?> treemap = forest.getNumericElement("treemap"); RIntegerVector nodepred = forest.getIntegerElement("nodepred"); RDoubleVector xbestsplit = forest.getDoubleElement("xbestsplit"); RIntegerVector nrnodes = forest.getIntegerElement("nrnodes"); RDoubleVector ntree = forest.getDoubleElement("ntree"); int rows = nrnodes.asScalar(); int columns = ValueUtil.asInt(ntree.asScalar()); CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>(){ @Override public Object encode(Integer value){ return categoricalLabel.getValue(value - 1); } }; Schema segmentSchema = schema.toAnonymousSchema(); List<TreeModel> treeModels = new ArrayList<>(); for(int i = 0; i < columns; i++){ List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * rows, columns, i); TreeModel treeModel = encodeTreeModel( MiningFunction.CLASSIFICATION, scoreEncoder, FortranMatrixUtil.getColumn(daughters, rows, 2, 0), FortranMatrixUtil.getColumn(daughters, rows, 2, 1), FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i), FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i), segmentSchema ); treeModels.add(treeModel); } MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel)) .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels)) .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel)); return miningModel; } private <P extends Number> TreeModel encodeTreeModel(MiningFunction miningFunction, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<P> nodepred, List<? extends Number> bestvar, List<Double> xbestsplit, Schema schema){ RGenericVector randomForest = getObject(); Node root = encodeNode(True.INSTANCE, 0, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, new CategoryManager(), schema); TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root) .setMissingValueStrategy(TreeModel.MissingValueStrategy.NULL_PREDICTION) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT); if(this.compact){ Visitor visitor = new RandomForestCompactor(); visitor.applyTo(treeModel); } return treeModel; } private <P extends Number> Node encodeNode(Predicate predicate, int i, ScoreEncoder<P> scoreEncoder, List<? extends Number> leftDaughter, List<? extends Number> rightDaughter, List<? extends Number> bestvar, List<Double> xbestsplit, List<P> nodepred, CategoryManager categoryManager, Schema schema){ Integer id = Integer.valueOf(i + 1); int var = ValueUtil.asInt(bestvar.get(i)); if(var == 0){ P prediction = nodepred.get(i); Node result = new LeafNode(scoreEncoder.encode(prediction), predicate) .setId(id); return result; } CategoryManager leftCategoryManager = categoryManager; CategoryManager rightCategoryManager = categoryManager; Predicate leftPredicate; Predicate rightPredicate; Feature feature = schema.getFeature(var - 1); Double split = xbestsplit.get(i); if(feature instanceof BooleanFeature){ BooleanFeature booleanFeature = (BooleanFeature)feature; if(split != 0.5d){ throw new IllegalArgumentException(); } leftPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(0)); rightPredicate = createSimplePredicate(booleanFeature, SimplePredicate.Operator.EQUAL, booleanFeature.getValue(1)); } else if(feature instanceof CategoricalFeature){ CategoricalFeature categoricalFeature = (CategoricalFeature)feature; FieldName name = categoricalFeature.getName(); List<?> values = categoricalFeature.getValues(); java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name); List<Object> leftValues = selectValues(values, valueFilter, split, true); List<Object> rightValues = selectValues(values, valueFilter, split, false); leftCategoryManager = categoryManager.fork(name, leftValues); rightCategoryManager = categoryManager.fork(name, rightValues); leftPredicate = createSimpleSetPredicate(categoricalFeature, leftValues); rightPredicate = createSimpleSetPredicate(categoricalFeature, rightValues); } else { ContinuousFeature continuousFeature = feature.toContinuousFeature(); leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_OR_EQUAL, split); rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_THAN, split); } Node result = new BranchNode(null, predicate) .setId(id); List<Node> nodes = result.getNodes(); int left = ValueUtil.asInt(leftDaughter.get(i)); if(left != 0){ Node leftChild = encodeNode(leftPredicate, left - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, leftCategoryManager, schema); nodes.add(leftChild); } int right = ValueUtil.asInt(rightDaughter.get(i)); if(right != 0){ Node rightChild = encodeNode(rightPredicate, right - 1, scoreEncoder, leftDaughter, rightDaughter, bestvar, xbestsplit, nodepred, rightCategoryManager, schema); nodes.add(rightChild); } return result; } static List<Object> selectValues(List<?> values, java.util.function.Predicate<Object> valueFilter, Double split, boolean left){ UnsignedLong bits = toUnsignedLong(split.doubleValue()); List<Object> result = new ArrayList<>(); for(int i = 0; i < values.size(); i++){ Object value = values.get(i); boolean append; // Send "true" categories to the left if(left){ // Test if the least significant bit (LSB) is 1 append = (bits.mod(RandomForestConverter.TWO)).equals(UnsignedLong.ONE); } else // Send all other categories to the right { // Test if the LSB is 0 append = (bits.mod(RandomForestConverter.TWO)).equals(UnsignedLong.ZERO); } // End if if(append && valueFilter.test(value)){ result.add(value); } bits = bits.dividedBy(RandomForestConverter.TWO); } return result; } static UnsignedLong toUnsignedLong(double value){ if(!DoubleMath.isMathematicalInteger(value)){ throw new IllegalArgumentException(); } return UnsignedLong.fromLongBits((long)value); } static private interface ScoreEncoder<V extends Number> { Object encode(V value); } private static final UnsignedLong TWO = UnsignedLong.valueOf(2L); }