package com.alibaba.alink.operator.common.optim.subfunc; import com.alibaba.alink.common.linalg.DenseVector; import com.alibaba.alink.common.model.ModelParamName; import org.apache.flink.api.common.functions.RichMapPartitionFunction; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; /** * Parse model rows to dense vector. * */ public class ParseRowModel extends RichMapPartitionFunction<Row, Tuple2<DenseVector, double[]>> { @Override public void mapPartition(Iterable<Row> iterable, Collector<Tuple2<DenseVector, double[]>> collector) throws Exception { DenseVector coefVector = null; double[] lossCurve = null; int taskId = getRuntimeContext().getIndexOfThisSubtask(); if (taskId == 0) { for (Row row : iterable) { Params params = Params.fromJson((String)row.getField(0)); coefVector = params.get(ModelParamName.COEF); lossCurve = params.get(ModelParamName.LOSS_CURVE); } if (coefVector != null) { collector.collect(Tuple2.of(coefVector, lossCurve)); } } } }