package com.alibaba.alink.operator.batch.dataproc; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.common.utils.DataSetConversionUtil; import com.alibaba.alink.common.utils.RowUtil; import com.alibaba.alink.params.dataproc.AppendIdParams; import org.apache.commons.lang3.ArrayUtils; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.functions.RuntimeContext; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.utils.DataSetUtils; import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; /** * Append an id column to BatchOperator. the id can be DENSE or UNIQUE * * @see DataSetUtils#zipWithIndex * @see DataSetUtils#zipWithUniqueId */ public final class AppendIdBatchOp extends BatchOperator<AppendIdBatchOp> implements AppendIdParams <AppendIdBatchOp> { public final static String appendIdColName = "append_id"; public final static TypeInformation appendIdColType = BasicTypeInfo.LONG_TYPE_INFO; public AppendIdBatchOp() { super(null); } public AppendIdBatchOp(Params params) { super(params); } public static Table appendId(DataSet <Row> dataSet, TableSchema schema, Long sessionId) { return AppendIdBatchOp.appendId( dataSet, schema, AppendIdBatchOp.appendIdColName, AppendType.DENSE, sessionId); } public static Table appendId( DataSet <Row> dataSet, TableSchema schema, String appendIdColName, AppendType appendType, Long sessionId) { String[] rawColNames = schema.getFieldNames(); TypeInformation[] rawColTypes = schema.getFieldTypes(); String[] colNames = ArrayUtils.add(rawColNames, appendIdColName); TypeInformation[] colTypes = ArrayUtils.add(rawColTypes, appendIdColType); DataSet <Row> ret = null; switch (appendType) { case DENSE: ret = DataSetUtils.zipWithIndex(dataSet) .map(new TransTupleToRowMapper()); break; case UNIQUE: ret = DataSetUtils.zipWithUniqueId(dataSet) .map(new TransTupleToRowMapper()); ret = dataSet.map(new AppendIdMapper()); break; default: throw new IllegalArgumentException("Error append type."); } return DataSetConversionUtil.toTable(sessionId, ret, colNames, colTypes); } @Override public AppendIdBatchOp linkFrom(BatchOperator<?>... inputs) { checkOpSize(1, inputs); this.setOutputTable(appendId( inputs[0].getDataSet(), inputs[0].getSchema(), getIdCol(), getAppendType(), getMLEnvironmentId() )); return this; } public static class AppendIdMapper extends RichMapFunction <Row, Row> { private long parallelism; private long counter; @Override public void open(Configuration parameters) throws Exception { RuntimeContext ctx = getRuntimeContext(); parallelism = ctx.getNumberOfParallelSubtasks(); counter = ctx.getIndexOfThisSubtask(); } @Override public Row map(Row value) throws Exception { Row ret = RowUtil.merge(value, Long.valueOf(counter)); counter += parallelism; return ret; } } public static class TransTupleToRowMapper implements MapFunction <Tuple2 <Long, Row>, Row> { @Override public Row map(Tuple2 <Long, Row> value) throws Exception { Row ret = RowUtil.merge(value.f1, value.f0); return ret; } } }