package com.alibaba.alink.common.utils; import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.ReduceFunction; import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.DataSet; import org.apache.flink.util.Collector; /** * Utils for handling dataset. */ public class DataSetUtil { /** * Count number of records in the dataset. * * @return a dataset of one record, recording the number of records of [[dataset]] */ public static <T> DataSet<Long> count(DataSet<T> dataSet) { return dataSet .mapPartition(new MapPartitionFunction<T, Long>() { @Override public void mapPartition(Iterable<T> values, Collector<Long> out) throws Exception { long cnt = 0L; for (T v : values) { cnt++; } out.collect(cnt); } }) .name("count_dataset") .returns(Types.LONG) .reduce(new ReduceFunction<Long>() { @Override public Long reduce(Long value1, Long value2) throws Exception { return value1 + value2; } }); } /** * Returns an empty dataset of the same type as [[dataSet]]. */ public static <T> DataSet<T> empty(DataSet<T> dataSet) { return dataSet .mapPartition(new MapPartitionFunction<T, T>() { @Override public void mapPartition(Iterable<T> values, Collector<T> out) throws Exception { } }) .returns(dataSet.getType()); } }