package com.alibaba.alink.operator.common.dataproc; import org.apache.flink.api.common.functions.BroadcastVariableInitializer; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.typeutils.TupleTypeInfo; import org.apache.flink.api.java.utils.DataSetUtils; import org.apache.flink.configuration.Configuration; import org.apache.flink.util.Collector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; import java.util.Comparator; import java.util.HashSet; import java.util.List; import java.util.Set; public final class SortUtilsNext { private static final Logger LOG = LoggerFactory.getLogger(SortUtilsNext.class); private final static int SPLIT_POINT_SIZE = 1000; /** * <p> * reference: Yang, X. (2014). Chong gou da shu ju tong ji (1st ed., pp. 25-29). * <p> * Note: This algorithm is improved on the base of the parallel * sorting by regular sampling(PSRS). * * @param input input dataset * @return f0: dataset which is indexed by partition id, f1: dataset which has partition id and count. */ public static <T extends Comparable<T>> Tuple2<DataSet<T>, DataSet<Tuple2<Integer, Long>>> pSort(DataSet<T> input) { DataSet<Tuple2<Integer, Long>> cnt = DataSetUtils.countElementsPerPartition(input); DataSet<T> sorted = input .mapPartition(new RichMapPartitionFunction<T, T>() { int taskId; int cnt; @Override public void open(Configuration parameters) throws Exception { super.open(parameters); taskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info("{} open.", getRuntimeContext().getTaskName()); List<Tuple2<Integer, Long>> cntVar = getRuntimeContext().getBroadcastVariable("cnt"); for (Tuple2<Integer, Long> var : cntVar) { if (var.f0 == taskId) { cnt = var.f1.intValue(); break; } } } @Override public void close() throws Exception { super.close(); LOG.info("{} close.", getRuntimeContext().getTaskName()); } @Override public void mapPartition(Iterable<T> values, Collector<T> out) throws Exception { ArrayList<T> all = new ArrayList<>(cnt); for (T val : values) { all.add(val); } all.sort(Comparator.naturalOrder()); for (T val : all) { out.collect(val); } } }) .withBroadcastSet(cnt, "cnt") .returns(input.getType()); DataSet<Tuple2<Object, Integer>> splitPoints = sorted .mapPartition(new SampleSplitPoint<>()) .withBroadcastSet(cnt, "cnt") .reduceGroup(new SplitPointReducer()); DataSet<Tuple2<Integer, T>> splitData = sorted .mapPartition(new SplitData<>()) .withBroadcastSet(splitPoints, "splitPoints") .returns(new TupleTypeInfo<>(Types.INT, input.getType())); DataSet<T> partitioned = splitData .partitionCustom(new Partitioner<Integer>() { @Override public int partition(Integer key, int numPartitions) { return key % numPartitions; } }, 0) .map(new MapFunction<Tuple2<Integer, T>, T>() { @Override public T map(Tuple2<Integer, T> value) throws Exception { return value.f1; } }) .returns(input.getType()); DataSet<Tuple2<Integer, Long>> partitionedCnt = DataSetUtils.countElementsPerPartition(partitioned); return Tuple2.of(partitioned, partitionedCnt); } private static long genSampleIndex(long splitPointIdx, long count, long splitPointSize) { splitPointIdx++; splitPointSize++; long div = count / splitPointSize; long mod = count % splitPointSize; return div * splitPointIdx + (Math.min(mod, splitPointIdx)) - 1; } /** * */ public final static class SampleSplitPoint<T> extends RichMapPartitionFunction<T, Tuple2<Object, Integer>> { private int taskId; private int cnt; public SampleSplitPoint() { } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); this.taskId = getRuntimeContext().getIndexOfThisSubtask(); LOG.info("{} open.", getRuntimeContext().getTaskName()); List<Tuple2<Integer, Long>> allCnt = getRuntimeContext().getBroadcastVariable("cnt"); for (Tuple2<Integer, Long> localCnt : allCnt) { if (localCnt.f0 == taskId) { cnt = localCnt.f1.intValue(); break; } } } @Override public void close() throws Exception { super.close(); LOG.info("{} close.", getRuntimeContext().getTaskName()); } @Override public void mapPartition(Iterable<T> values, Collector<Tuple2<Object, Integer>> out) throws Exception { if (cnt <= 0) { out.collect(new Tuple2( getRuntimeContext().getNumberOfParallelSubtasks(), -taskId - 1)); return; } int localSplitPointSize = Math.min(SPLIT_POINT_SIZE, cnt - 1); ArrayList<Tuple2<Object, Integer>> splitPoints = new ArrayList<>(localSplitPointSize); int dataIndex = 0; int splitPointIdx = 0; int splitPointInDataIndex = (int) genSampleIndex(splitPointIdx, cnt, localSplitPointSize); for (T val : values) { if (dataIndex == splitPointInDataIndex) { out.collect(Tuple2.of(val, taskId)); splitPointInDataIndex = (int) genSampleIndex(++splitPointIdx, cnt, localSplitPointSize); if (splitPointInDataIndex >= cnt) { break; } } dataIndex++; } out.collect(new Tuple2( getRuntimeContext().getNumberOfParallelSubtasks(), -taskId - 1)); } } public static class SplitPointReducer extends RichGroupReduceFunction<Tuple2<Object, Integer>, Tuple2<Object, Integer>> { public SplitPointReducer() { } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); LOG.info("{} open.", getRuntimeContext().getTaskName()); } @Override public void close() throws Exception { super.close(); LOG.info("{} close.", getRuntimeContext().getTaskName()); } @Override public void reduce( Iterable<Tuple2<Object, Integer>> values, Collector<Tuple2<Object, Integer>> out) throws Exception { ArrayList<Tuple2<Object, Integer>> all = new ArrayList<>(); int instanceCount = -1; for (Tuple2<Object, Integer> value : values) { if (value.f1 < 0) { instanceCount = (int) value.f0; continue; } all.add(Tuple2.of(value.f0, value.f1)); } if (all.isEmpty()) { return; } int count = all.size(); all.sort(new SortUtils.PairComparator()); Set<Tuple2<Object, Integer>> split = new HashSet<>(); int splitPointSize = instanceCount - 1; for (int i = 0; i < splitPointSize; ++i) { int index = (int) genSampleIndex(i, count, splitPointSize); if (index >= count) { throw new Exception("Index error. index: " + index + ". totalCount: " + count); } split.add(all.get(index)); } for (Tuple2<Object, Integer> sSplit : split) { out.collect(sSplit); } } } public final static class SplitData<T> extends RichMapPartitionFunction<T, Tuple2<Integer, T>> { private int taskId; private List<Tuple2<Object, Integer>> splitPoints; private Tuple2<Integer, T> outBuff; public SplitData() { } @Override public void close() throws Exception { super.close(); LOG.info("{} close.", getRuntimeContext().getTaskName()); } @Override public void open(Configuration parameters) throws Exception { super.open(parameters); RuntimeContext ctx = getRuntimeContext(); this.taskId = ctx.getIndexOfThisSubtask(); this.splitPoints = ctx.getBroadcastVariableWithInitializer( "splitPoints", new BroadcastVariableInitializer<Tuple2<Object, Integer>, List<Tuple2<Object, Integer>>>() { @Override public List<Tuple2<Object, Integer>> initializeBroadcastVariable( Iterable<Tuple2<Object, Integer>> data) { // sort the list by task id to calculate the correct offset List<Tuple2<Object, Integer>> sortedData = new ArrayList<>(); for (Tuple2<Object, Integer> datum : data) { sortedData.add(datum); } sortedData.sort(new SortUtils.PairComparator()); return sortedData; } }); outBuff = new Tuple2<>(); LOG.info("{} open.", getRuntimeContext().getTaskName()); } /** * use binary search to partition data into sorted subsets * notice: data within each subset will not be sorted. */ @Override public void mapPartition(Iterable<T> values, Collector<Tuple2<Integer, T>> out) throws Exception { if (splitPoints.isEmpty()) { for (T val : values) { outBuff.setFields(0, val); out.collect(outBuff); } return; } int splitSize = splitPoints.size(); int curIndex = 0; SortUtils.PairComparator pairComparator = new SortUtils.PairComparator(); Tuple2<Object, Integer> curTuple = Tuple2.of(null, taskId); for (T val : values) { if (curIndex < splitSize) { curTuple.f0 = val; int code = pairComparator.compare( curTuple, splitPoints.get(curIndex) ); if (code > 0) { ++curIndex; while (curIndex < splitSize) { code = pairComparator.compare( curTuple, splitPoints.get(curIndex)); if (code <= 0) { break; } ++curIndex; } } } outBuff.setFields(curIndex, val); out.collect(outBuff); } } } }