/* * Copyright (c) 2014, Cloudera, Inc. All Rights Reserved. * * Cloudera, Inc. licenses this file to you under the Apache License, * Version 2.0 (the "License"). You may not use this file except in * compliance with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR * CONDITIONS OF ANY KIND, either express or implied. See the License for * the specific language governing permissions and limitations under the * License. */ package com.cloudera.oryx.app.batch.mllib.rdf; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Queue; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.IntStream; import com.google.common.base.Preconditions; import com.typesafe.config.Config; import org.apache.hadoop.fs.Path; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.RandomForest; import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.FeatureType; import org.apache.spark.mllib.tree.model.DecisionTreeModel; import org.apache.spark.mllib.tree.model.Predict; import org.apache.spark.mllib.tree.model.RandomForestModel; import org.apache.spark.mllib.tree.model.Split; import org.dmg.pmml.Array; import org.dmg.pmml.DataDictionary; import org.dmg.pmml.FieldName; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.Model; import org.dmg.pmml.PMML; import org.dmg.pmml.Predicate; import org.dmg.pmml.ScoreDistribution; import org.dmg.pmml.SimplePredicate; import org.dmg.pmml.SimpleSetPredicate; import org.dmg.pmml.True; import org.dmg.pmml.mining.MiningModel; import org.dmg.pmml.mining.Segment; import org.dmg.pmml.mining.Segmentation; import org.dmg.pmml.tree.ComplexNode; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; import org.eclipse.collections.api.map.primitive.IntLongMap; import org.eclipse.collections.impl.map.mutable.primitive.IntLongHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.collection.JavaConversions; import com.cloudera.oryx.app.classreg.example.Example; import com.cloudera.oryx.app.classreg.example.ExampleUtils; import com.cloudera.oryx.app.common.fn.MLFunctions; import com.cloudera.oryx.app.pmml.AppPMMLUtils; import com.cloudera.oryx.app.rdf.RDFPMMLUtils; import com.cloudera.oryx.app.rdf.tree.DecisionForest; import com.cloudera.oryx.app.schema.CategoricalValueEncodings; import com.cloudera.oryx.app.schema.InputSchema; import com.cloudera.oryx.common.collection.Pair; import com.cloudera.oryx.common.pmml.PMMLUtils; import com.cloudera.oryx.common.random.RandomManager; import com.cloudera.oryx.common.text.TextUtils; import com.cloudera.oryx.ml.MLUpdate; import com.cloudera.oryx.ml.param.HyperParamValues; import com.cloudera.oryx.ml.param.HyperParams; /** * Update function that builds and evaluates random decision forest models in the Batch Layer. */ public final class RDFUpdate extends MLUpdate<String> { private static final Logger log = LoggerFactory.getLogger(RDFUpdate.class); private final int numTrees; private final List<HyperParamValues<?>> hyperParamValues; private final InputSchema inputSchema; public RDFUpdate(Config config) { super(config); numTrees = config.getInt("oryx.rdf.num-trees"); Preconditions.checkArgument(numTrees >= 1); hyperParamValues = Arrays.asList( HyperParams.fromConfig(config, "oryx.rdf.hyperparams.max-split-candidates"), HyperParams.fromConfig(config, "oryx.rdf.hyperparams.max-depth"), HyperParams.fromConfig(config, "oryx.rdf.hyperparams.impurity")); inputSchema = new InputSchema(config); Preconditions.checkArgument(inputSchema.hasTarget()); } @Override public List<HyperParamValues<?>> getHyperParameterValues() { return hyperParamValues; } @Override public PMML buildModel(JavaSparkContext sparkContext, JavaRDD<String> trainData, List<?> hyperParameters, Path candidatePath) { int maxSplitCandidates = (Integer) hyperParameters.get(0); int maxDepth = (Integer) hyperParameters.get(1); String impurity = (String) hyperParameters.get(2); Preconditions.checkArgument(maxSplitCandidates >= 2, "max-split-candidates must be at least 2"); Preconditions.checkArgument(maxDepth > 0, "max-depth must be at least 1"); JavaRDD<String[]> parsedRDD = trainData.map(MLFunctions.PARSE_FN); CategoricalValueEncodings categoricalValueEncodings = new CategoricalValueEncodings(getDistinctValues(parsedRDD)); JavaRDD<LabeledPoint> trainPointData = parseToLabeledPointRDD(parsedRDD, categoricalValueEncodings); Map<Integer,Integer> categoryInfo = categoricalValueEncodings.getCategoryCounts(); categoryInfo.remove(inputSchema.getTargetFeatureIndex()); // Don't specify target count // Need to translate indices to predictor indices Map<Integer,Integer> categoryInfoByPredictor = new HashMap<>(categoryInfo.size()); categoryInfo.forEach((k, v) -> categoryInfoByPredictor.put(inputSchema.featureToPredictorIndex(k), v)); int seed = RandomManager.getRandom().nextInt(); RandomForestModel model; if (inputSchema.isClassification()) { int numTargetClasses = categoricalValueEncodings.getValueCount(inputSchema.getTargetFeatureIndex()); model = RandomForest.trainClassifier(trainPointData, numTargetClasses, categoryInfoByPredictor, numTrees, "auto", impurity, maxDepth, maxSplitCandidates, seed); } else { model = RandomForest.trainRegressor(trainPointData, categoryInfoByPredictor, numTrees, "auto", impurity, maxDepth, maxSplitCandidates, seed); } List<IntLongHashMap> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model); IntLongHashMap predictorIndexCounts = predictorExampleCounts(trainPointData, model); return rdfModelToPMML(model, categoricalValueEncodings, maxDepth, maxSplitCandidates, impurity, treeNodeIDCounts, predictorIndexCounts); } @Override public double evaluate(JavaSparkContext sparkContext, PMML model, Path modelParentPath, JavaRDD<String> testData, JavaRDD<String> trainData) { RDFPMMLUtils.validatePMMLVsSchema(model, inputSchema); Pair<DecisionForest,CategoricalValueEncodings> forestAndEncoding = RDFPMMLUtils.read(model); DecisionForest forest = forestAndEncoding.getFirst(); CategoricalValueEncodings valueEncodings = forestAndEncoding.getSecond(); InputSchema inputSchema = this.inputSchema; JavaRDD<Example> examplesRDD = testData.map(MLFunctions.PARSE_FN). map(data -> ExampleUtils.dataToExample(data, inputSchema, valueEncodings)); double eval; if (inputSchema.isClassification()) { double accuracy = Evaluation.accuracy(forest, examplesRDD); log.info("Accuracy: {}", accuracy); eval = accuracy; } else { double rmse = Evaluation.rmse(forest, examplesRDD); log.info("RMSE: {}", rmse); eval = -rmse; } return eval; } private Map<Integer,Collection<String>> getDistinctValues(JavaRDD<String[]> parsedRDD) { int[] categoricalIndices = IntStream.range(0, inputSchema.getNumFeatures()). filter(inputSchema::isCategorical).toArray(); return parsedRDD.mapPartitions(data -> { Map<Integer,Collection<String>> categoryValues = new HashMap<>(); for (int i : categoricalIndices) { categoryValues.put(i, new HashSet<>()); } data.forEachRemaining(datum -> categoryValues.forEach((category, values) -> values.add(datum[category])) ); return Collections.singleton(categoryValues).iterator(); }).reduce((v1, v2) -> { // Assumes both have the same key set v1.forEach((category, values) -> values.addAll(v2.get(category))); return v1; }); } private JavaRDD<LabeledPoint> parseToLabeledPointRDD( JavaRDD<String[]> parsedRDD, CategoricalValueEncodings categoricalValueEncodings) { return parsedRDD.map(data -> { try { double[] features = new double[inputSchema.getNumPredictors()]; double target = Double.NaN; for (int featureIndex = 0; featureIndex < data.length; featureIndex++) { double encoded; if (inputSchema.isNumeric(featureIndex)) { encoded = Double.parseDouble(data[featureIndex]); } else if (inputSchema.isCategorical(featureIndex)) { Map<String,Integer> valueEncoding = categoricalValueEncodings.getValueEncodingMap(featureIndex); encoded = valueEncoding.get(data[featureIndex]); } else { continue; } if (inputSchema.isTarget(featureIndex)) { target = encoded; } else { features[inputSchema.featureToPredictorIndex(featureIndex)] = encoded; } } Preconditions.checkState(!Double.isNaN(target)); return new LabeledPoint(target, Vectors.dense(features)); } catch (NumberFormatException | ArrayIndexOutOfBoundsException e) { log.warn("Bad input: {}", Arrays.toString(data)); throw e; } }); } /** * @param trainPointData data to run down trees * @param model random decision forest model to count on * @return maps of node IDs to the count of training examples that reached that node, one * per tree in the model * @see #predictorExampleCounts(JavaRDD,RandomForestModel) */ private static List<IntLongHashMap> treeNodeExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData, RandomForestModel model) { return trainPointData.mapPartitions(data -> { DecisionTreeModel[] trees = model.trees(); List<IntLongHashMap> treeNodeIDCounts = IntStream.range(0, trees.length). mapToObj(i -> new IntLongHashMap()).collect(Collectors.toList()); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (int i = 0; i < trees.length; i++) { DecisionTreeModel tree = trees[i]; IntLongHashMap nodeIDCount = treeNodeIDCounts.get(i); org.apache.spark.mllib.tree.model.Node node = tree.topNode(); // This logic cloned from Node.predict: while (!node.isLeaf()) { // Count node ID nodeIDCount.addToValue(node.id(), 1); Split split = node.split().get(); int featureIndex = split.feature(); node = nextNode(featureVector, node, split, featureIndex); } nodeIDCount.addToValue(node.id(), 1); } }); return Collections.singleton(treeNodeIDCounts).iterator(); } ).reduce((a, b) -> { Preconditions.checkArgument(a.size() == b.size()); for (int i = 0; i < a.size(); i++) { merge(a.get(i), b.get(i)); } return a; }); } /** * @param trainPointData data to run down trees * @param model random decision forest model to count on * @return map of predictor index to the number of training examples that reached a * node whose decision is based on that feature. The index is among predictors, not all * features, since there are fewer predictors than features. That is, the index will * match the one used in the {@link RandomForestModel}. */ private static IntLongHashMap predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData, RandomForestModel model) { return trainPointData.mapPartitions(data -> { IntLongHashMap featureIndexCount = new IntLongHashMap(); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (DecisionTreeModel tree : model.trees()) { org.apache.spark.mllib.tree.model.Node node = tree.topNode(); // This logic cloned from Node.predict: while (!node.isLeaf()) { Split split = node.split().get(); int featureIndex = split.feature(); // Count feature featureIndexCount.addToValue(featureIndex, 1); node = nextNode(featureVector, node, split, featureIndex); } } }); return Collections.singleton(featureIndexCount).iterator(); }).reduce(RDFUpdate::merge); } private static org.apache.spark.mllib.tree.model.Node nextNode( double[] featureVector, org.apache.spark.mllib.tree.model.Node node, Split split, int featureIndex) { double featureValue = featureVector[featureIndex]; if (split.featureType().equals(FeatureType.Continuous())) { if (featureValue <= split.threshold()) { return node.leftNode().get(); } else { return node.rightNode().get(); } } else { if (split.categories().contains(featureValue)) { return node.leftNode().get(); } else { return node.rightNode().get(); } } } private static IntLongHashMap merge(IntLongHashMap a, IntLongHashMap b) { if (b.size() > a.size()) { return merge(b, a); } b.forEachKeyValue(a::addToValue); return a; } private PMML rdfModelToPMML(RandomForestModel rfModel, CategoricalValueEncodings categoricalValueEncodings, int maxDepth, int maxSplitCandidates, String impurity, List<? extends IntLongMap> nodeIDCounts, IntLongMap predictorIndexCounts) { boolean classificationTask = rfModel.algo().equals(Algo.Classification()); Preconditions.checkState(classificationTask == inputSchema.isClassification()); DecisionTreeModel[] trees = rfModel.trees(); Model model; if (trees.length == 1) { model = toTreeModel(trees[0], categoricalValueEncodings, nodeIDCounts.get(0)); } else { MiningModel miningModel = new MiningModel(); model = miningModel; Segmentation.MultipleModelMethod multipleModelMethodType = classificationTask ? Segmentation.MultipleModelMethod.WEIGHTED_MAJORITY_VOTE : Segmentation.MultipleModelMethod.WEIGHTED_AVERAGE; List<Segment> segments = new ArrayList<>(trees.length); for (int treeID = 0; treeID < trees.length; treeID++) { TreeModel treeModel = toTreeModel(trees[treeID], categoricalValueEncodings, nodeIDCounts.get(treeID)); segments.add(new Segment() .setId(Integer.toString(treeID)) .setPredicate(new True()) .setModel(treeModel) .setWeight(1.0)); // No weights in MLlib impl now } miningModel.setSegmentation(new Segmentation(multipleModelMethodType, segments)); } model.setMiningFunction(classificationTask ? MiningFunction.CLASSIFICATION : MiningFunction.REGRESSION); double[] importances = countsToImportances(predictorIndexCounts); model.setMiningSchema(AppPMMLUtils.buildMiningSchema(inputSchema, importances)); DataDictionary dictionary = AppPMMLUtils.buildDataDictionary(inputSchema, categoricalValueEncodings); PMML pmml = PMMLUtils.buildSkeletonPMML(); pmml.setDataDictionary(dictionary); pmml.addModels(model); AppPMMLUtils.addExtension(pmml, "maxDepth", maxDepth); AppPMMLUtils.addExtension(pmml, "maxSplitCandidates", maxSplitCandidates); AppPMMLUtils.addExtension(pmml, "impurity", impurity); return pmml; } private TreeModel toTreeModel(DecisionTreeModel dtModel, CategoricalValueEncodings categoricalValueEncodings, IntLongMap nodeIDCounts) { boolean classificationTask = dtModel.algo().equals(Algo.Classification()); Preconditions.checkState(classificationTask == inputSchema.isClassification()); Node root = new ComplexNode(); root.setId("r"); Queue<Node> modelNodes = new ArrayDeque<>(); modelNodes.add(root); Queue<Pair<org.apache.spark.mllib.tree.model.Node,Split>> treeNodes = new ArrayDeque<>(); treeNodes.add(new Pair<>(dtModel.topNode(), null)); while (!treeNodes.isEmpty()) { Pair<org.apache.spark.mllib.tree.model.Node,Split> treeNodePredicate = treeNodes.remove(); Node modelNode = modelNodes.remove(); // This is the decision that got us here from the parent, if any; // not the predicate at this node Predicate predicate = buildPredicate(treeNodePredicate.getSecond(), categoricalValueEncodings); modelNode.setPredicate(predicate); org.apache.spark.mllib.tree.model.Node treeNode = treeNodePredicate.getFirst(); long nodeCount = nodeIDCounts.get(treeNode.id()); modelNode.setRecordCount((double) nodeCount); if (treeNode.isLeaf()) { Predict prediction = treeNode.predict(); int targetEncodedValue = (int) prediction.predict(); if (classificationTask) { Map<Integer,String> targetEncodingToValue = categoricalValueEncodings.getEncodingValueMap(inputSchema.getTargetFeatureIndex()); double predictedProbability = prediction.prob(); Preconditions.checkState(predictedProbability >= 0.0 && predictedProbability <= 1.0); // Not sure how nodeCount == 0 can happen but it does in the MLlib model long effectiveNodeCount = Math.max(1, nodeCount); // Problem: MLlib only gives a predicted class and its probability, and no distribution // over the rest. Infer that the rest of the probability is evenly distributed. double restProbability = (1.0 - predictedProbability) / (targetEncodingToValue.size() - 1); targetEncodingToValue.forEach((encodedValue, value) -> { double probability = encodedValue == targetEncodedValue ? predictedProbability : restProbability; // Yes, recordCount may be fractional; it's a relative indicator double recordCount = probability * effectiveNodeCount; if (recordCount > 0.0) { ScoreDistribution distribution = new ScoreDistribution(value, recordCount); // Not "confident" enough in the "probability" to call it one distribution.setConfidence(probability); modelNode.addScoreDistributions(distribution); } }); } else { modelNode.setScore(Double.toString(targetEncodedValue)); } } else { Split split = treeNode.split().get(); Node positiveModelNode = new ComplexNode().setId(modelNode.getId() + "+"); Node negativeModelNode = new ComplexNode().setId(modelNode.getId() + "-"); modelNode.addNodes(positiveModelNode, negativeModelNode); org.apache.spark.mllib.tree.model.Node rightTreeNode = treeNode.rightNode().get(); org.apache.spark.mllib.tree.model.Node leftTreeNode = treeNode.leftNode().get(); boolean defaultRight = nodeIDCounts.get(rightTreeNode.id()) > nodeIDCounts.get(leftTreeNode.id()); modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId()); // Right node is "positive", so carries the predicate. It must evaluate first // and therefore come first in the tree modelNodes.add(positiveModelNode); modelNodes.add(negativeModelNode); treeNodes.add(new Pair<>(rightTreeNode, split)); treeNodes.add(new Pair<>(leftTreeNode, null)); } } return new TreeModel() .setNode(root) .setSplitCharacteristic(TreeModel.SplitCharacteristic.BINARY_SPLIT) .setMissingValueStrategy(TreeModel.MissingValueStrategy.DEFAULT_CHILD); } private Predicate buildPredicate(Split split, CategoricalValueEncodings categoricalValueEncodings) { if (split == null) { // Left child always applies, but is evaluated second return new True(); } int featureIndex = inputSchema.predictorToFeatureIndex(split.feature()); FieldName fieldName = FieldName.create(inputSchema.getFeatureNames().get(featureIndex)); if (split.featureType().equals(FeatureType.Categorical())) { // Note that categories in MLlib model select the *left* child but the // convention here will be that the predicate selects the *right* child // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") Collection<Double> javaCategories = (Collection<Double>) (Collection<?>) JavaConversions.seqAsJavaList(split.categories()); Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); Map<Integer,String> encodingToValue = categoricalValueEncodings.getEncodingValueMap(featureIndex); List<String> negativeValues = negativeEncodings.stream().map(encodingToValue::get).collect(Collectors.toList()); String joinedValues = TextUtils.joinPMMLDelimited(negativeValues); return new SimpleSetPredicate(fieldName, SimpleSetPredicate.BooleanOperator.IS_NOT_IN, new Array(Array.Type.STRING, joinedValues)); } else { // For MLlib, left means <= threshold, so right means > return new SimplePredicate(fieldName, SimplePredicate.Operator.GREATER_THAN, Double.toString(split.threshold())); } } private double[] countsToImportances(IntLongMap predictorIndexCounts) { double[] importances = new double[inputSchema.getNumPredictors()]; long total = predictorIndexCounts.sum(); predictorIndexCounts.forEachKeyValue( (k, count) -> importances[k] = total == 0 ? 0.0 : (double) count / total); return importances; } }