package com.alibaba.alink.operator.common.dataproc;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.TableSchema;
import org.apache.flink.types.Row;

import com.alibaba.alink.common.mapper.ModelMapper;
import com.alibaba.alink.common.model.RichModelDataConverter;
import com.alibaba.alink.common.utils.OutputColsHelper;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.params.dataproc.SrtPredictMapperParams;

import java.util.List;

/**
 * A model mapper that scales the selected columns of each row into the range [-1, 1] by dividing
 * each value by the maximum absolute value of its feature, as recorded in the model.
 *
 * <p>The per-column maximum absolute values are read by {@link #loadModel(List)}; the scaling of a
 * single value is delegated to {@code ScalerUtil.maxAbsScaler}.
 */
public class MaxAbsScalerModelMapper extends ModelMapper {
    /** Indices, in the input data schema, of the columns to scale (in model column order). */
    private final int[] selectedColIndices;
    /** Per-column maximum absolute values, aligned with {@link #selectedColIndices}; set by loadModel. */
    private double[] maxAbs;
    /** Builds the output schema and merges the scaled values into the input row. */
    private final OutputColsHelper predictResultColsHelper;

    /**
     * Creates the mapper and resolves the model's selected columns against the data schema.
     *
     * @param modelSchema the model schema; the selected column names and types are extracted from it.
     * @param dataSchema  the schema of the input data.
     * @param params      the params; {@code SrtPredictMapperParams.OUTPUT_COLS} names the output
     *                    columns and defaults to the selected column names when it is not set.
     */
    public MaxAbsScalerModelMapper(TableSchema modelSchema, TableSchema dataSchema, Params params) {
        super(modelSchema, dataSchema, params);
        String[] selectedColNames = RichModelDataConverter.extractSelectedColNames(modelSchema);
        TypeInformation<?>[] selectedColTypes = RichModelDataConverter.extractSelectedColTypes(modelSchema);
        // NOTE(review): presumably rejects selected columns that are missing from dataSchema — confirm.
        this.selectedColIndices = TableUtil.findColIndicesWithAssert(dataSchema, selectedColNames);

        String[] outputColNames = params.get(SrtPredictMapperParams.OUTPUT_COLS);
        if (outputColNames == null) {
            // Without explicit output columns, results are emitted under the selected column names.
            outputColNames = selectedColNames;
        }

        // NOTE(review): output columns are declared with the model's selected column types, while
        // map() always emits double values; confirm the model records numeric (double) types.
        this.predictResultColsHelper = new OutputColsHelper(dataSchema,
            outputColNames, selectedColTypes, null);
    }

    /**
     * Loads the per-column maximum absolute values from the model rows.
     *
     * @param modelRows the model data rows, decoded with {@link MaxAbsScalerModelDataConverter}.
     */
    @Override
    public void loadModel(List<Row> modelRows) {
        this.maxAbs = new MaxAbsScalerModelDataConverter().load(modelRows);
    }

    /**
     * Returns the table schema (column names and types) of the mapped output.
     *
     * @return the schema of the rows produced by {@link #map(Row)}.
     */
    @Override
    public TableSchema getOutputSchema() {
        return this.predictResultColsHelper.getResultSchema();
    }

    /**
     * Scales the selected columns of one input row.
     *
     * <p>A {@code null} input row yields {@code null}. A {@code null} field yields a {@code null}
     * output field. Non-{@link Number} values are converted by parsing their string form.
     *
     * @param row the input row.
     * @return the input row combined with the scaled values, or {@code null} if {@code row} is null.
     * @throws Exception if mapping fails, e.g. a {@link NumberFormatException} for a field whose
     *                   string form is not a valid double.
     */
    @Override
    public Row map(Row row) throws Exception {
        if (row == null) {
            return null;
        }
        Row scaled = new Row(selectedColIndices.length);
        for (int i = 0; i < selectedColIndices.length; i++) {
            Object value = row.getField(selectedColIndices[i]);
            if (value == null) {
                // Leave the output field unset (null) for missing input values.
                continue;
            }
            // Convert before reading maxAbs[i] to keep the original evaluation order.
            double d = asDouble(value);
            scaled.setField(i, ScalerUtil.maxAbsScaler(this.maxAbs[i], d));
        }
        return this.predictResultColsHelper.getResultRow(row, scaled);
    }

    /**
     * Converts a non-null field value to a double.
     *
     * @param value the field value; {@link Number}s use {@code doubleValue()}, anything else is parsed
     *              from its {@code toString()} form.
     * @return the value as a double.
     * @throws NumberFormatException if a non-numeric value's string form is not a valid double.
     */
    private static double asDouble(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        return Double.parseDouble(value.toString());
    }
}