package com.alibaba.alink.operator.stream.onlinelearning; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; import com.alibaba.alink.common.io.directreader.DataBridge; import com.alibaba.alink.common.io.directreader.DirectReader; import com.alibaba.alink.common.utils.DataStreamConversionUtil; import com.alibaba.alink.common.utils.TableUtil; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.common.linear.LinearModelData; import com.alibaba.alink.operator.common.linear.LinearModelDataConverter; import com.alibaba.alink.operator.common.linear.LinearModelMapper; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.params.onlinelearning.FtrlPredictParams; import org.apache.flink.api.common.functions.FlatMapFunction; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.Partitioner; import org.apache.flink.api.common.functions.RichFlatMapFunction; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.configuration.Configuration; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.functions.co.RichCoFlatMapFunction; import org.apache.flink.table.api.TableSchema; import org.apache.flink.types.Row; import org.apache.flink.util.Collector; /** * Ftrl predictor receive two stream : model stream and data stream. It using updated model by model stream real-time, * and using the newest model predict data stream. */ public final class FtrlPredictStreamOp extends StreamOperator<FtrlPredictStreamOp> implements FtrlPredictParams<FtrlPredictStreamOp> { private DataBridge dataBridge = null; public FtrlPredictStreamOp(BatchOperator model) { super(new Params()); if (model != null) { dataBridge = DirectReader.collect(model); } else { throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model."); } } public FtrlPredictStreamOp(BatchOperator model, Params params) { super(params); if (model != null) { dataBridge = DirectReader.collect(model); } else { throw new IllegalArgumentException("Ftrl algo: initial model is null. Please set a valid initial model."); } } @Override public FtrlPredictStreamOp linkFrom(StreamOperator<?>... inputs) { checkOpSize(2, inputs); try { DataStream<LinearModelData> modelstr = inputs[0].getDataStream() .flatMap(new RichFlatMapFunction<Row, Tuple2<Integer, Row>>() { @Override public void flatMap(Row row, Collector<Tuple2<Integer, Row>> out) throws Exception { int numTasks = getRuntimeContext().getNumberOfParallelSubtasks(); for (int i = 0; i < numTasks; ++i) { out.collect(Tuple2.of(i, row)); } } }).partitionCustom(new Partitioner<Integer>() { @Override public int partition(Integer key, int numPartitions) { return key; } }, 0).map(new MapFunction<Tuple2<Integer, Row>, Row>() { @Override public Row map(Tuple2<Integer, Row> value) throws Exception { return value.f1; } }) .flatMap(new CollectModel()); TypeInformation[] types = new TypeInformation[3]; String[] names = new String[3]; for (int i = 0; i < 3; ++i) { names[i] = inputs[0].getSchema().getFieldNames()[i + 2]; types[i] = inputs[0].getSchema().getFieldTypes()[i + 2]; } TableSchema modelSchema = new TableSchema(names, types); /* predict samples */ DataStream<Row> prediction = inputs[1].getDataStream() .connect(modelstr) .flatMap(new PredictProcess(TableUtil.toSchemaJson(modelSchema), TableUtil.toSchemaJson(inputs[1].getSchema()), this.getParams(), dataBridge)); this.setOutputTable(DataStreamConversionUtil.toTable(getMLEnvironmentId(), prediction, new LinearModelMapper(modelSchema, inputs[1].getSchema(), getParams()).getOutputSchema())); } catch (Exception ex) { ex.printStackTrace(); throw new RuntimeException(ex.toString()); } return this; } public static class CollectModel implements FlatMapFunction<Row, LinearModelData> { private Map<Long, List<Row>> buffers = new HashMap<>(0); @Override public void flatMap(Row inRow, Collector<LinearModelData> out) throws Exception { long id = (long)inRow.getField(0); Long nTab = (long)inRow.getField(1); Row row = new Row(inRow.getArity() - 2); for (int i = 0; i < row.getArity(); ++i) { row.setField(i, inRow.getField(i + 2)); } if (buffers.containsKey(id) && buffers.get(id).size() == nTab.intValue() - 1) { buffers.get(id).add(row); LinearModelData ret = new LinearModelDataConverter().load(buffers.get(id)); buffers.get(id).clear(); System.out.println("collect model : " + id); out.collect(ret); } else { if (buffers.containsKey(id)) { buffers.get(id).add(row); } else { List<Row> buffer = new ArrayList<>(0); buffer.add(row); buffers.put(id, buffer); } } } } public static class PredictProcess extends RichCoFlatMapFunction<Row, LinearModelData, Row> { private LinearModelMapper predictor = null; private String modelSchemaJson; private String dataSchemaJson; private Params params; private int iter = 0; private DataBridge dataBridge; public PredictProcess(String modelSchemaJson, String dataSchemaJson, Params params, DataBridge dataBridge) { this.dataBridge = dataBridge; this.modelSchemaJson = modelSchemaJson; this.dataSchemaJson = dataSchemaJson; this.params = params; } @Override public void open(Configuration parameters) throws Exception { this.predictor = new LinearModelMapper(TableUtil.fromSchemaJson(modelSchemaJson), TableUtil.fromSchemaJson(dataSchemaJson), this.params); if (dataBridge != null) { // read init model List<Row> modelRows = DirectReader.directRead(dataBridge); LinearModelData model = new LinearModelDataConverter().load(modelRows); this.predictor.loadModel(model); } } @Override public void flatMap1(Row row, Collector<Row> collector) throws Exception { collector.collect(this.predictor.map(row)); } @Override public void flatMap2(LinearModelData linearModel, Collector<Row> collector) throws Exception { this.predictor.loadModel(linearModel); System.out.println(getRuntimeContext().getIndexOfThisSubtask() + " load model : " + iter++); } } }