package com.alibaba.alink.operator.common.tree;

import com.alibaba.alink.common.comqueue.ComContext;
import com.alibaba.alink.common.comqueue.CompareCriterionFunction;
import com.alibaba.alink.common.comqueue.CompleteResultFunction;
import com.alibaba.alink.common.comqueue.IterativeComQueue;
import com.alibaba.alink.common.comqueue.communication.AllReduce;
import com.alibaba.alink.common.model.ModelParamName;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.io.types.FlinkTypeConverter;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeInitObj;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeObj;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeSplit;
import com.alibaba.alink.operator.common.tree.paralleltree.TreeStat;
import com.alibaba.alink.operator.common.tree.seriestree.DecisionTree;
import com.alibaba.alink.operator.common.tree.seriestree.DenseData;
import com.alibaba.alink.params.shared.colname.HasCategoricalCols;
import com.alibaba.alink.params.shared.colname.HasFeatureCols;
import com.alibaba.alink.params.shared.colname.HasLabelCol;
import com.alibaba.alink.params.shared.tree.*;

import org.apache.flink.api.common.functions.*;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.operators.MapPartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * Base class for fitting random forest and decision tree models.
 * The random forest uses bagging to prevent overfitting.
 *
 * <p>To increase the diversity of the forest, this operator implements
 * three types of decision tree:
 * <ul>
 * <li>id3</li>
 * <li>cart</li>
 * <li>c4.5</li>
 * </ul>
 * and the split criteria are:
 * <ul>
 * <li>information gain</li>
 * <li>gini</li>
 * <li>information gain ratio</li>
 * <li>mse</li>
 * </ul>
 *
 * @see <a href="https://en.wikipedia.org/wiki/Random_forest">Random_forest</a>
 *
 * @param <T> the type of the concrete train operator.
 */
public abstract class BaseRandomForestTrainBatchOp<T extends BaseRandomForestTrainBatchOp<T>>
	extends BatchOperator<T> {

	protected DataSet<Object[]> labels;
	protected BatchOperator<?> stringIndexerModel;

	protected BaseRandomForestTrainBatchOp(Params params) {
		super(params);
	}
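	/**
	 * Fits the forest: records the label and feature type meta, preprocesses
	 * the input (label encoding, string indexing, casting), then trains in
	 * either parallel mode or series mode.
	 *
	 * <p>A minimal usage sketch. The concrete subclass, column names, and
	 * setters below are illustrative assumptions, not part of this class:
	 * <pre>{@code
	 * BatchOperator<?> data = ...; // a table with feature and label columns
	 * BatchOperator<?> model = new RandomForestTrainBatchOp()
	 *     .setFeatureCols("f0", "f1")
	 *     .setCategoricalCols("f1")
	 *     .setLabelCol("label")
	 *     .setNumTrees(10)
	 *     .linkFrom(data);
	 * }</pre>
	 */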
	@Override
	public T linkFrom(BatchOperator<?>... inputs) {
		BatchOperator<?> in = checkAndGetFirst(inputs);

		// Record the label type: regression always uses double labels,
		// classification keeps the type of the label column.
		if (Criteria.isRegression(getParams().get(TreeUtil.TREE_TYPE))) {
			getParams().set(ModelParamName.LABEL_TYPE, FlinkTypeConverter.getTypeString(Types.DOUBLE));
		} else {
			getParams().set(
				ModelParamName.LABEL_TYPE,
				FlinkTypeConverter.getTypeString(
					TableUtil.findColTypeWithAssertAndHint(in.getSchema(), getParams().get(HasLabelCol.LABEL_COL))
				)
			);
		}

		getParams().set(ModelParamName.FEATURE_TYPES,
			FlinkTypeConverter.getTypeString(
				TableUtil.findColTypesWithAssertAndHint(in.getSchema(), getParams().get(HasFeatureCols.FEATURE_COLS))
			)
		);

		in = in.select(TreeUtil.trainColNames(getParams()));

		set(
			HasCategoricalCols.CATEGORICAL_COLS,
			TableUtil.getCategoricalCols(
				in.getSchema(),
				getParams().get(HasFeatureCols.FEATURE_COLS),
				getParams().contains(HasCategoricalCols.CATEGORICAL_COLS)
					? getParams().get(HasCategoricalCols.CATEGORICAL_COLS) : null
			)
		);

		labels = Preprocessing.generateLabels(
			in, getParams(),
			Criteria.isRegression(getParams().get(TreeUtil.TREE_TYPE))
		);

		in = Preprocessing.castLabel(
			in, getParams(), labels,
			Criteria.isRegression(getParams().get(TreeUtil.TREE_TYPE))
		);

		stringIndexerModel = Preprocessing.generateStringIndexerModel(in, getParams());

		in = Preprocessing.castWeightCol(
			Preprocessing.castContinuousCols(
				Preprocessing.castCategoricalCols(in, stringIndexerModel, getParams()),
				getParams()
			),
			getParams()
		);

		DataSet<Row> model;

		if ("PARALLEL".equalsIgnoreCase(getParams().get(HasCreateTreeMode.CREATE_TREE_MODE))) {
			model = parallelTrain(in);
		} else {
			model = seriesTrain(in);
		}

		// Note: the label type was stored under ModelParamName.LABEL_TYPE
		// above, so it must be read back under the same key here.
		setOutput(model,
			new TreeModelDataConverter(
				FlinkTypeConverter.getFlinkType(getParams().get(ModelParamName.LABEL_TYPE))
			).getModelSchema()
		);

		return (T) this;
	}

	private DataSet<Row> parallelTrain(BatchOperator<?> in) {
		BatchOperator<?> quantileModel = Preprocessing.generateQuantileDiscretizerModel(in, getParams());

		DataSet<Row> trainingDataSet = Preprocessing
			.castToQuantile(in, quantileModel, getParams())
			.getDataSet()
			// Check for null values in the training dataset and throw an
			// exception when any are found.
			.map(new CheckNullValue(in.getColNames()));

		final Params meta = getParams().clone();

		// Each round of the queue grows the trees: TreeInitObj sets up the
		// shared state, TreeStat computes local split statistics, AllReduce
		// merges them across workers, and TreeSplit applies the best splits.
		return new IterativeComQueue()
			.setMaxIter(Integer.MAX_VALUE)
			.initWithPartitionedData("treeInput", trainingDataSet)
			.initWithBroadcastData("quantileModel", quantileModel.getDataSet())
			.initWithBroadcastData("stringIndexerModel", stringIndexerModel.getDataSet())
			.initWithBroadcastData("labels", labels)
			.add(new TreeInitObj(meta))
			.add(new TreeStat())
			.add(new AllReduce("allReduce", "allReduceCnt"))
			.add(new TreeSplit())
			.setCompareCriterionOfNode0(new Criterion())
			.closeWith(new SerializeModelCompleteResultFunction(meta))
			.exec();
	}
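	/**
	 * Collects the fitted trees at the end of the iterative queue and
	 * serializes them into model rows. Only worker 0 emits the model; all
	 * other workers return null.
	 */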
	private static class SerializeModelCompleteResultFunction extends CompleteResultFunction {
		private final Params meta;

		SerializeModelCompleteResultFunction(Params meta) {
			this.meta = meta;
		}

		@Override
		public List<Row> calc(ComContext context) {
			if (context.getTaskId() != 0) {
				return null;
			}

			TreeObj treeObj = context.getObj("treeObj");
			List<Row> stringIndexerModel = context.getObj("stringIndexerModel");
			List<Object[]> labelsList = context.getObj("labels");

			return TreeModelDataConverter.saveModelWithData(
				treeObj.getRoots(), meta, stringIndexerModel,
				labelsList == null || labelsList.isEmpty() ? null : labelsList.get(0)
			);
		}
	}

	/**
	 * Stops the iterative queue when tree growing reaches its termination
	 * criterion.
	 */
	private static class Criterion extends CompareCriterionFunction {
		@Override
		public boolean calc(ComContext context) {
			TreeObj treeObj = context.getObj("treeObj");
			return treeObj.terminationCriterion();
		}
	}

	private static class CheckNullValue implements MapFunction<Row, Row> {
		private String[] cols;

		public CheckNullValue(String[] cols) {
			this.cols = cols;
		}

		@Override
		public Row map(Row value) throws Exception {
			for (int i = 0; i < value.getArity(); ++i) {
				if (value.getField(i) == null) {
					throw new IllegalArgumentException(
						"There should not be null values in the training dataset. col: " + cols[i]
							+ ". Maybe you can use Imputer to fill the missing values."
					);
				}
			}
			return value;
		}
	}

	private DataSet<Row> seriesTrain(BatchOperator<?> in) {
		DataSet<Row> trainDataSet = in.getDataSet();

		// Bagging: emit each row to each tree independently with probability
		// equal to the subsampling ratio, keyed by the tree id.
		MapPartitionOperator<Row, Tuple2<Integer, Row>> sampled = trainDataSet
			.mapPartition(new SampleData(
				get(HasSeed.SEED),
				get(HasSubsamplingRatio.SUBSAMPLING_RATIO),
				get(HasNumTreesDefaltAs10.NUM_TREES)
			));

		// When the ratio is an absolute sample count (> 1.0), broadcast the
		// total row count so it can be normalized to a probability.
		if (getParams().get(HasSubsamplingRatio.SUBSAMPLING_RATIO) > 1.0) {
			DataSet<Long> cnt = DataSetUtils
				.countElementsPerPartition(trainDataSet)
				.sum(1)
				.map(new MapFunction<Tuple2<Integer, Long>, Long>() {
					@Override
					public Long map(Tuple2<Integer, Long> value) throws Exception {
						return value.f1;
					}
				});

			sampled = sampled.withBroadcastSet(cnt, "totalCnt");
		}

		DataSet<Integer> labelSize = labels.map(new MapFunction<Object[], Integer>() {
			@Override
			public Integer map(Object[] objects) throws Exception {
				return objects.length;
			}
		});

		// Group the samples by tree id so that each group trains one tree.
		DataSet<Tuple2<Integer, String>> pModel = sampled
			.groupBy(0)
			.withPartitioner(new AvgPartition())
			.reduceGroup(new SeriesTrainFunction(getParams()))
			.withBroadcastSet(stringIndexerModel.getDataSet(), "stringIndexerModel")
			.withBroadcastSet(labelSize, "labelSize");

		return pModel
			.reduceGroup(new SerializeModel(getParams()))
			.withBroadcastSet(stringIndexerModel.getDataSet(), "stringIndexerModel")
			.withBroadcastSet(labels, "labels");
	}
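	/**
	 * Trains one complete decision tree per group of sampled rows. The group
	 * key is the tree id; the split criterion of each tree is chosen by
	 * {@link #getGainFromParams(Params, int)}, which rotates the gain across
	 * trees for the AVG and PARTITION tree types.
	 */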
	private static class SeriesTrainFunction
		extends RichGroupReduceFunction<Tuple2<Integer, Row>, Tuple2<Integer, String>> {
		private static final Logger LOG = LoggerFactory.getLogger(SeriesTrainFunction.class);

		private Map<String, Integer> categoricalColsSize;
		private Params params;

		public SeriesTrainFunction(Params params) {
			this.params = params;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			categoricalColsSize = getRuntimeContext()
				.getBroadcastVariableWithInitializer(
					"stringIndexerModel",
					new BroadcastVariableInitializer<Row, Map<String, Integer>>() {
						@Override
						public Map<String, Integer> initializeBroadcastVariable(Iterable<Row> iterable) {
							List<Row> stringIndexerSerialized = new ArrayList<>();

							for (Row row : iterable) {
								stringIndexerSerialized.add(row);
							}

							return TreeUtil.extractCategoricalColsSize(
								stringIndexerSerialized,
								params.get(HasCategoricalCols.CATEGORICAL_COLS));
						}
					});

			// For classification, the label column is treated as a categorical
			// column whose size is the number of distinct labels.
			if (!Criteria.isRegression(params.get(TreeUtil.TREE_TYPE))) {
				categoricalColsSize.put(
					params.get(HasLabelCol.LABEL_COL),
					getRuntimeContext()
						.getBroadcastVariableWithInitializer(
							"labelSize",
							new BroadcastVariableInitializer<Integer, Integer>() {
								@Override
								public Integer initializeBroadcastVariable(Iterable<Integer> iterable) {
									return iterable.iterator().next();
								}
							}
						)
				);
			}
		}

		@Override
		public void reduce(Iterable<Tuple2<Integer, Row>> values, Collector<Tuple2<Integer, String>> out)
			throws Exception {
			LOG.info("start the random forests training");

			List<Row> dataCache = new ArrayList<>();
			int treeId = 0;

			for (Tuple2<Integer, Row> value : values) {
				treeId = value.f0;
				dataCache.add(value.f1);
			}

			// Create the dense data.
			DenseData data = new DenseData(
				dataCache.size(),
				TreeUtil.getFeatureMeta(params.get(HasFeatureCols.FEATURE_COLS), categoricalColsSize),
				TreeUtil.getLabelMeta(
					params.get(HasLabelCol.LABEL_COL),
					params.get(HasFeatureCols.FEATURE_COLS).length,
					categoricalColsSize
				)
			);

			// Read the instances into the data.
			data.readFromInstances(dataCache);

			// Rewrite the gain for this tree.
			params.set(Criteria.Gain.GAIN, getGainFromParams(params, treeId));

			// Fit the decision tree.
			Node root = new DecisionTree(data, params).fit();

			// Serialize the tree.
			for (String serialized : TreeModelDataConverter.serializeTree(root)) {
				out.collect(Tuple2.of(treeId, serialized));
			}

			LOG.info("end the random forests training");
		}
	}

	/**
	 * Spreads the tree ids evenly over the partitions.
	 */
	public static class AvgPartition implements Partitioner<Integer> {
		@Override
		public int partition(Integer key, int numPartitions) {
			return key % numPartitions;
		}
	}

	private static class SerializeModel extends RichGroupReduceFunction<Tuple2<Integer, String>, Row> {
		private Params params;
		private transient List<Row> stringIndexerModelSerialized;
		private transient Object[] labels;

		public SerializeModel(Params params) {
			this.params = params;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			super.open(parameters);

			stringIndexerModelSerialized = getRuntimeContext().getBroadcastVariable("stringIndexerModel");
			labels = getRuntimeContext().getBroadcastVariableWithInitializer(
				"labels",
				new BroadcastVariableInitializer<Object[], Object[]>() {
					@Override
					public Object[] initializeBroadcastVariable(Iterable<Object[]> data) {
						Iterator<Object[]> iter = data.iterator();
						return iter.hasNext() ? iter.next() : null;
					}
				});
		}

		@Override
		public void reduce(Iterable<Tuple2<Integer, String>> values, Collector<Row> out) throws Exception {
			// Group the serialized rows by tree id, deserialize each tree in
			// tree-id order, then save the whole forest as model rows.
			TreeModelDataConverter.saveModelWithData(
				StreamSupport.stream(values.spliterator(), false)
					.collect(Collectors.groupingBy(x -> x.f0, Collectors.mapping(x -> x.f1, Collectors.toList())))
					.entrySet()
					.stream()
					.sorted(Map.Entry.comparingByKey())
					.map(x -> TreeModelDataConverter.deserializeTree(x.getValue()))
					.collect(Collectors.toList()),
				params,
				stringIndexerModelSerialized,
				labels
			).forEach(out::collect);
		}
	}

	private static Criteria.Gain getGainFromParams(Params params, int treeId) {
		TreeUtil.TreeType treeType = params.get(TreeUtil.TREE_TYPE);

		switch (treeType) {
			case AVG:
				return getAvgGain(params.get(HasNumTreesDefaltAs10.NUM_TREES), treeId);
			case PARTITION:
				return getIntervalGain(params.get(HasTreePartition.TREE_PARTITION), treeId);
			case MSE:
				return Criteria.Gain.MSE;
			case GINI:
				return Criteria.Gain.GINI;
			case INFOGAIN:
				return Criteria.Gain.INFOGAIN;
			case INFOGAINRATIO:
				return Criteria.Gain.INFOGAINRATIO;
			default:
				throw new IllegalArgumentException("Could not parse the gain type from params. type: " + treeType);
		}
	}
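	// Worked example for the gain rotation: with numTrees = 10, div = 3 and
	// mod = 1, so startGini = 4 and startInfoGainRatio = 7. Trees 0-3 use
	// INFOGAIN, trees 4-6 use GINI, and trees 7-9 use INFOGAINRATIO.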
type: " + treeType); } } private static Criteria.Gain getIntervalGain(String treeType, int id) { String[] intervalStrs = treeType.split(","); Preconditions.checkState(intervalStrs.length == 2, "Error format of treeType: " + treeType); return getIntervalGain( Integer.parseInt(intervalStrs[0]), Integer.parseInt(intervalStrs[1]), id); } private static Criteria.Gain getIntervalGain(int startGini, int startInfoGainRatio, int id) { if (id < startGini) { return Criteria.Gain.INFOGAIN; } else if (id < startInfoGainRatio) { return Criteria.Gain.GINI; } else { return Criteria.Gain.INFOGAINRATIO; } } private static Criteria.Gain getAvgGain(int treeNum, int id) { int div = treeNum / 3; int mod = treeNum % 3; int startGini = mod < 1 ? div : div + 1; int startInfoGainRatio = mod < 2 ? startGini + div : startGini + div + 1; return getIntervalGain(startGini, startInfoGainRatio, id); } public static class SampleData extends RichMapPartitionFunction<Row, Tuple2<Integer, Row>> { private long seed; private double factor; private int treeNum; public SampleData(long seed, double factor, int treeNum) { this.seed = seed; this.factor = factor; this.treeNum = treeNum; } @Override public void open(Configuration parameters) throws Exception { if (factor > 1.0) { factor = Math.min(factor / getRuntimeContext() .getBroadcastVariableWithInitializer("totalCnt", new BroadcastVariableInitializer<Long, Double>() { @Override public Double initializeBroadcastVariable(Iterable<Long> data) { for (Long cnt : data) { return cnt.doubleValue(); } throw new RuntimeException( "Can not find total sample count of sample in training dataset if factor > 1.0" ); } }), 1.0); } } @Override public void mapPartition( Iterable<Row> values, Collector<Tuple2<Integer, Row>> out) throws Exception { Random rand = new Random(this.seed + getRuntimeContext().getIndexOfThisSubtask()); for (Row row : values) { for (int i = 0; i < treeNum; ++i) { double randNum = rand.nextDouble(); if (randNum < factor) { out.collect(new Tuple2<>(i, row)); } } } } } }