/* * Copyright (c) 2015 Villu Ruusmann * * This file is part of JPMML-R * * JPMML-R is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * JPMML-R is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with JPMML-R. If not, see <http://www.gnu.org/licenses/>. */ package org.jpmml.rexp; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import org.dmg.pmml.DataField; import org.dmg.pmml.DataType; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.Model; import org.dmg.pmml.OpType; import org.dmg.pmml.Predicate; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.True; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segmentation; import org.dmg.pmml.regression.RegressionModel; import org.dmg.pmml.tree.BranchNode; import org.dmg.pmml.tree.LeafNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.jpmml.converter.CMatrixUtil; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.CategoryManager; import org.jpmml.converter.ContinuousFeature; import org.jpmml.converter.ContinuousLabel; import org.jpmml.converter.Feature; import org.jpmml.converter.FlagManager; import org.jpmml.converter.ModelUtil; import org.jpmml.converter.Schema; import org.jpmml.converter.ValueUtil; import org.jpmml.converter.mining.MiningModelUtil; public class GBMConverter extends TreeModelConverter<RGenericVector> { public GBMConverter(RGenericVector gbm){ super(gbm); } @Override public void encodeSchema(RExpEncoder encoder){ RGenericVector gbm = getObject(); RGenericVector distribution = gbm.getGenericElement("distribution"); RStringVector response_name = gbm.getStringElement("response.name", false); RGenericVector var_levels = gbm.getGenericElement("var.levels"); RStringVector var_names = gbm.getStringElement("var.names"); RNumberVector<?> var_type = gbm.getNumericElement("var.type"); RStringVector classes = gbm.getStringElement("classes", false); RStringVector distributionName = distribution.getStringElement("name"); { FieldName responseName; if(response_name != null){ responseName = FieldName.create(response_name.asScalar()); } else { responseName = FieldName.create("y"); } DataField dataField; switch(distributionName.asScalar()){ case "gaussian": dataField = encoder.createDataField(responseName, OpType.CONTINUOUS, DataType.DOUBLE); break; case "adaboost": case "bernoulli": dataField = encoder.createDataField(responseName, OpType.CATEGORICAL, DataType.INTEGER, GBMConverter.BINARY_CLASSES); break; case "multinomial": dataField = encoder.createDataField(responseName, OpType.CATEGORICAL, DataType.STRING, classes.getValues()); break; default: throw new IllegalArgumentException(); } encoder.setLabel(dataField); } for(int i = 0; i < var_names.size(); i++){ FieldName varName = FieldName.create(var_names.getValue(i)); DataField dataField; boolean categorical = (ValueUtil.asInt(var_type.getValue(i)) > 0); if(categorical){ RStringVector var_level = (RStringVector)var_levels.getValue(i); dataField = encoder.createDataField(varName, OpType.CATEGORICAL, DataType.STRING, var_level.getValues()); } else { dataField = encoder.createDataField(varName, OpType.CONTINUOUS, DataType.DOUBLE); } encoder.addFeature(dataField); } } @Override public MiningModel encodeModel(Schema schema){ RGenericVector gbm = getObject(); RDoubleVector initF = gbm.getDoubleElement("initF"); RGenericVector trees = gbm.getGenericElement("trees"); RGenericVector c_splits = gbm.getGenericElement("c.splits"); RGenericVector distribution = gbm.getGenericElement("distribution"); RStringVector distributionName = distribution.getStringElement("name"); Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); List<TreeModel> treeModels = new ArrayList<>(); for(int i = 0; i < trees.size(); i++){ RGenericVector tree = (RGenericVector)trees.getValue(i); TreeModel treeModel = encodeTreeModel(MiningFunction.REGRESSION, tree, c_splits, segmentSchema); treeModels.add(treeModel); } MiningModel miningModel = encodeMiningModel(distributionName, treeModels, initF.asScalar(), schema); return miningModel; } private MiningModel encodeMiningModel(RStringVector distributionName, List<TreeModel> treeModels, Double initF, Schema schema){ switch(distributionName.asScalar()){ case "gaussian": return encodeRegression(treeModels, initF, schema); case "adaboost": return encodeBinaryClassification(treeModels, initF, -2d, schema); case "bernoulli": return encodeBinaryClassification(treeModels, initF, -1d, schema); case "multinomial": return encodeMultinomialClassification(treeModels, initF, schema); default: throw new IllegalArgumentException(); } } private MiningModel encodeRegression(List<TreeModel> treeModels, Double initF, Schema schema){ MiningModel miningModel = createMiningModel(treeModels, initF, schema); return miningModel; } private MiningModel encodeBinaryClassification(List<TreeModel> treeModels, Double initF, double coefficient, Schema schema){ Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); MiningModel miningModel = createMiningModel(treeModels, initF, segmentSchema) .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue"), OpType.CONTINUOUS, DataType.DOUBLE)); return MiningModelUtil.createBinaryLogisticClassification(miningModel, -coefficient, 0d, RegressionModel.NormalizationMethod.LOGIT, true, schema); } private MiningModel encodeMultinomialClassification(List<TreeModel> treeModels, Double initF, Schema schema){ CategoricalLabel categoricalLabel = (CategoricalLabel)schema.getLabel(); Schema segmentSchema = schema.toAnonymousRegressorSchema(DataType.DOUBLE); List<Model> miningModels = new ArrayList<>(); for(int i = 0, columns = categoricalLabel.size(), rows = (treeModels.size() / columns); i < columns; i++){ MiningModel miningModel = createMiningModel(CMatrixUtil.getColumn(treeModels, rows, columns, i), initF, segmentSchema) .setOutput(ModelUtil.createPredictedOutput(FieldName.create("gbmValue(" + categoricalLabel.getValue(i) + ")"), OpType.CONTINUOUS, DataType.DOUBLE)); miningModels.add(miningModel); } return MiningModelUtil.createClassification(miningModels, RegressionModel.NormalizationMethod.SOFTMAX, true, schema); } private TreeModel encodeTreeModel(MiningFunction miningFunction, RGenericVector tree, RGenericVector c_splits, Schema schema){ Node root = encodeNode(True.INSTANCE, 0, tree, c_splits, new FlagManager(), new CategoryManager(), schema); TreeModel treeModel = new TreeModel(miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.MULTI_SPLIT); return treeModel; } private Node encodeNode(Predicate predicate, int i, RGenericVector tree, RGenericVector c_splits, FlagManager flagManager, CategoryManager categoryManager, Schema schema){ RIntegerVector splitVar = (RIntegerVector)tree.getValue(0); RDoubleVector splitCodePred = (RDoubleVector)tree.getValue(1); RIntegerVector leftNode = (RIntegerVector)tree.getValue(2); RIntegerVector rightNode = (RIntegerVector)tree.getValue(3); RIntegerVector missingNode = (RIntegerVector)tree.getValue(4); RDoubleVector prediction = (RDoubleVector)tree.getValue(7); Integer id = Integer.valueOf(i + 1); Integer var = splitVar.getValue(i); if(var == -1){ Double value = prediction.getValue(i); Node result = new LeafNode(value, predicate) .setId(id); return result; } Boolean isMissing; FlagManager missingFlagManager = flagManager; FlagManager nonMissingFlagManager = flagManager; Predicate missingPredicate; Feature feature = schema.getFeature(var); { FieldName name = feature.getName(); isMissing = flagManager.getValue(name); if(isMissing == null){ missingFlagManager = missingFlagManager.fork(name, Boolean.TRUE); nonMissingFlagManager = nonMissingFlagManager.fork(name, Boolean.FALSE); } missingPredicate = createSimplePredicate(feature, SimplePredicate.Operator.IS_MISSING, null); } CategoryManager leftCategoryManager = categoryManager; CategoryManager rightCategoryManager = categoryManager; Predicate leftPredicate; Predicate rightPredicate; Double split = splitCodePred.getValue(i); if(feature instanceof CategoricalFeature){ CategoricalFeature categoricalFeature = (CategoricalFeature)feature; FieldName name = categoricalFeature.getName(); List<?> values = categoricalFeature.getValues(); int index = ValueUtil.asInt(split); RIntegerVector c_split = (RIntegerVector)c_splits.getValue(index); List<Integer> splitValues = c_split.getValues(); java.util.function.Predicate<Object> valueFilter = categoryManager.getValueFilter(name); List<Object> leftValues = selectValues(values, valueFilter, splitValues, true); List<Object> rightValues = selectValues(values, valueFilter, splitValues, false); leftCategoryManager = leftCategoryManager.fork(name, leftValues); rightCategoryManager = rightCategoryManager.fork(name, rightValues); leftPredicate = createSimpleSetPredicate(categoricalFeature, leftValues); rightPredicate = createSimpleSetPredicate(categoricalFeature, rightValues); } else { ContinuousFeature continuousFeature = feature.toContinuousFeature(); leftPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.LESS_THAN, split); rightPredicate = createSimplePredicate(continuousFeature, SimplePredicate.Operator.GREATER_OR_EQUAL, split); } Node result = new BranchNode(null, predicate) .setId(id); List<Node> nodes = result.getNodes(); Integer missing = missingNode.getValue(i); if(missing != -1 && (isMissing == null || isMissing)){ Node missingChild = encodeNode(missingPredicate, missing, tree, c_splits, missingFlagManager, categoryManager, schema); nodes.add(missingChild); } Integer left = leftNode.getValue(i); if(left != -1 && (isMissing == null || !isMissing)){ Node leftChild = encodeNode(leftPredicate, left, tree, c_splits, nonMissingFlagManager, leftCategoryManager, schema); nodes.add(leftChild); } Integer right = rightNode.getValue(i); if(right != -1 && (isMissing == null || !isMissing)){ Node rightChild = encodeNode(rightPredicate, right, tree, c_splits, nonMissingFlagManager, rightCategoryManager, schema); nodes.add(rightChild); } return result; } static private MiningModel createMiningModel(List<TreeModel> treeModels, Double initF, Schema schema){ ContinuousLabel continuousLabel = (ContinuousLabel)schema.getLabel(); MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel)) .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.SUM, treeModels)) .setTargets(ModelUtil.createRescaleTargets(null, initF, continuousLabel)); return miningModel; } static private List<Object> selectValues(List<?> values, java.util.function.Predicate<Object> valueFilter, List<Integer> splitValues, boolean left){ if(values.size() != splitValues.size()){ throw new IllegalArgumentException(); } List<Object> result = new ArrayList<>(); for(int i = 0; i < values.size(); i++){ Object value = values.get(i); Integer splitValue = splitValues.get(i); boolean append; if(left){ append = (splitValue == -1); } else { append = (splitValue == 1); } // End if if(append && valueFilter.test(value)){ result.add(value); } } return result; } private static final List<Integer> BINARY_CLASSES = Arrays.asList(0, 1); }