package com.alibaba.alink.operator.batch.dataproc; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.params.dataproc.SplitParams; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.utils.DataSetUtils; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.Table; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; import org.apache.flink.util.Preconditions; import java.util.*; /** * Split a dataset into two parts. */ public final class SplitBatchOp extends BatchOperator<SplitBatchOp> implements SplitParams<SplitBatchOp> { public SplitBatchOp() { this(new Params()); } public SplitBatchOp(Params params) { super(params); } public SplitBatchOp(double fraction) { this(new Params().set(FRACTION, fraction)); } @Override public SplitBatchOp linkFrom(BatchOperator<?>... inputs) { BatchOperator<?> in = checkAndGetFirst(inputs); final double fraction = getFraction(); if (fraction < 0. || fraction > 1.0) { throw new RuntimeException("invalid fraction " + fraction); } DataSet<Row> rows = in.getDataSet(); DataSet<Tuple2<Integer, Long>> countsPerPartition = DataSetUtils.countElementsPerPartition(rows); DataSet<long[]> numPickedPerPartition = countsPerPartition .mapPartition(new CountInPartition(fraction)) .setParallelism(1) .name("decide_count_of_each_partition"); DataSet<Row> out = rows .mapPartition(new PickInPartition()) .withBroadcastSet(numPickedPerPartition, "counts") .name("pick_in_each_partition"); this.setOutput(out, in.getSchema()); this.setSideOutputTables(new Table[]{in.getOutputTable().minusAll(this.getOutputTable())}); return this; } /** * Randomly decide the number of elements to select in each task */ private static class CountInPartition extends RichMapPartitionFunction<Tuple2<Integer, Long>, long[]> { private double fraction; public CountInPartition(double fraction) { this.fraction = fraction; } @Override public void mapPartition(Iterable<Tuple2<Integer, Long>> values, Collector<long[]> out) throws Exception { Preconditions.checkArgument(getRuntimeContext().getIndexOfThisSubtask() == 0); long totCount = 0L; List<Tuple2<Integer, Long>> buffer = new ArrayList<>(); for (Tuple2<Integer, Long> value : values) { totCount += value.f1; buffer.add(value); } int npart = buffer.size(); // num tasks long[] eachCount = new long[npart]; long numTarget = Math.round((totCount * fraction)); long[] eachSelect = new long[npart]; for (Tuple2<Integer, Long> value : buffer) { eachCount[value.f0] = value.f1; } long totSelect = 0L; for (int i = 0; i < npart; i++) { eachSelect[i] = Math.round(Math.floor(eachCount[i] * fraction)); totSelect += eachSelect[i]; } if (totSelect < numTarget) { long remain = numTarget - totSelect; remain = Math.min(remain, totCount - totSelect); if (remain == totCount - totSelect) { for (int i = 0; i < npart; i++) { eachSelect[i] = eachCount[i]; } } else { // select 'remain' out of 'npart' List<Integer> shuffle = new ArrayList<>(npart); while (remain > 0) { for (int i = 0; i < npart; i++) { shuffle.add(i); } Collections.shuffle(shuffle, new Random()); for (int i = 0; i < Math.min(remain, npart); i++) { int taskId = shuffle.get(i); while (eachSelect[taskId] >= eachCount[taskId]) { taskId = (taskId + 1) % npart; } eachSelect[taskId]++; } remain -= npart; } } } long[] statistics = new long[npart * 2]; for (int i = 0; i < npart; i++) { statistics[i] = eachCount[i]; statistics[i + npart] = eachSelect[i]; } out.collect(statistics); } } /** * Randomly pick elements in each task */ private static class PickInPartition extends RichMapPartitionFunction<Row, Row> { @Override public void mapPartition(Iterable<Row> values, Collector<Row> out) throws Exception { int npart = getRuntimeContext().getNumberOfParallelSubtasks(); List<long[]> bc = getRuntimeContext().getBroadcastVariable("counts"); long[] eachCount = Arrays.copyOfRange(bc.get(0), 0, npart); long[] eachSelect = Arrays.copyOfRange(bc.get(0), npart, npart * 2); if (bc.get(0).length / 2 != getRuntimeContext().getNumberOfParallelSubtasks()) { throw new RuntimeException("parallelism has changed"); } int taskId = getRuntimeContext().getIndexOfThisSubtask(); // emit the selected int[] selected = null; int iRow = 0; int numEmits = 0; for (Row row : values) { if (0 == iRow) { long count = eachCount[taskId]; long select = eachSelect[taskId]; List<Integer> shuffle = new ArrayList<>((int) count); for (int i = 0; i < count; i++) { shuffle.add(i); } Collections.shuffle(shuffle, new Random(taskId)); selected = new int[(int) select]; for (int i = 0; i < select; i++) { selected[i] = shuffle.get(i); } Arrays.sort(selected); } if (numEmits < selected.length && iRow == selected[numEmits]) { out.collect(row); numEmits++; } iRow++; } } } }