package com.alibaba.alink.pipeline; import com.alibaba.alink.operator.batch.BatchOperator; import com.alibaba.alink.operator.batch.source.TableSourceBatchOp; import com.alibaba.alink.operator.stream.StreamOperator; import com.alibaba.alink.operator.stream.source.TableSourceStreamOp; import org.apache.flink.ml.api.core.Estimator; import org.apache.flink.ml.api.misc.param.Params; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.TableEnvironment; import org.apache.flink.table.api.java.StreamTableEnvironment; import org.apache.flink.util.Preconditions; /** * The base class for estimator implementations. * * @param <E> A subclass of the {@link EstimatorBase}, used by * {@link org.apache.flink.ml.api.misc.param.WithParams} * @param <M> class type of the {@link ModelBase} this Estimator produces. */ public abstract class EstimatorBase<E extends EstimatorBase<E, M>, M extends ModelBase<M>> extends PipelineStageBase<E> implements Estimator<E, M> { public EstimatorBase() { super(); } public EstimatorBase(Params params) { super(params); } @Override public M fit(TableEnvironment tEnv, Table input) { Preconditions.checkArgument(input != null, "Input CAN NOT BE null!"); Preconditions.checkArgument( tableEnvOf(input) == tEnv, "The input table is not in the specified table environment."); return fit(input); } /** * Train and produce a {@link ModelBase} which fits the records in the given {@link Table}. * * @param input the table with records to train the Model. * @return a model trained to fit on the given Table. */ public M fit(Table input) { Preconditions.checkArgument(input != null, "Input CAN NOT BE null!"); if (tableEnvOf(input) instanceof StreamTableEnvironment) { TableSourceStreamOp source = new TableSourceStreamOp(input); if(this.params.contains(ML_ENVIRONMENT_ID)){ source.setMLEnvironmentId(this.params.get(ML_ENVIRONMENT_ID)); } return fit(source); } else { TableSourceBatchOp source = new TableSourceBatchOp(input); if(this.params.contains(ML_ENVIRONMENT_ID)){ source.setMLEnvironmentId(this.params.get(ML_ENVIRONMENT_ID)); } return fit(source); } } /** * Train and produce a {@link ModelBase} which fits the records from the given {@link BatchOperator}. * * @param input the table with records to train the Model. * @return a model trained to fit on the given Table. */ public abstract M fit(BatchOperator input); /** * Online learning and produce {@link ModelBase} series which fit the streaming records from the given {@link * StreamOperator}. * * @param input the StreamOperator with streaming records to online train the Model series. * @return the model series trained to fit on the streaming data from given StreamOperator. */ public M fit(StreamOperator input) { throw new UnsupportedOperationException("NOT supported yet!"); } }