package com.alibaba.alink.operator.batch.recommendation;

import com.alibaba.alink.common.utils.DataSetConversionUtil;
import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.batch.sql.JoinBatchOp;
import com.alibaba.alink.operator.common.recommendation.AlsModelDataConverter;
import com.alibaba.alink.operator.common.recommendation.AlsPredict;
import com.alibaba.alink.params.recommendation.AlsPredictParams;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

/**
 * Predicts ratings with a model trained by AlsTrainBatchOp.
 * <p>
 * The operator takes two inputs: the ALS model (input 0) and the prediction data (input 1). The user and item ids are
 * read from the columns given by {@link #getUserCol()} and {@link #getItemCol()} (any numeric type; they are converted
 * to {@code long}). The matching user and item factors are looked up in the model. The predicted rating is appended
 * to the prediction data as a {@code DOUBLE} column named by {@link #getPredictionCol()}.
 * <p>
 * Only rating prediction (given a user and an item, predict the rating) is implemented. This operator does not
 * produce top-k recommendations.
 */
public final class AlsPredictBatchOp
    extends BatchOperator<AlsPredictBatchOp>
    implements AlsPredictParams<AlsPredictBatchOp> {

    /** {@code identity} value for {@link #getFactors}: keep model rows flagged as users. */
    private static final int USER_FACTORS = 0;

    /** {@code identity} value for {@link #getFactors}: keep model rows not flagged as users (items). */
    private static final int ITEM_FACTORS = 1;

    public AlsPredictBatchOp() {
        this(new Params());
    }

    public AlsPredictBatchOp(Params params) {
        super(params);
    }

    /**
     * Extracts either the user factors or the item factors from the model.
     *
     * @param model    The model trained from AlsTrainBatchOp.
     * @param identity {@link #USER_FACTORS} to keep rows for which {@code AlsModelDataConverter.getIsUser} is true,
     *                 {@link #ITEM_FACTORS} to keep the remaining rows.
     * @return (vertex id, factors) pairs of the selected kind.
     */
    private static DataSet<Tuple2<Long, float[]>> getFactors(BatchOperator<?> model, final int identity) {
        return model.getDataSet()
            .flatMap(new FlatMapFunction<Row, Tuple2<Long, float[]>>() {
                @Override
                public void flatMap(Row value, Collector<Tuple2<Long, float[]>> out) throws Exception {
                    int w = AlsModelDataConverter.getIsUser(value) ? USER_FACTORS : ITEM_FACTORS;
                    if (w != identity) {
                        return;
                    }

                    long idx = AlsModelDataConverter.getVertexId(value);
                    float[] factors = AlsModelDataConverter.getFactors(value);
                    out.collect(Tuple2.of(idx, factors));
                }
            });
    }

    /**
     * Links the model (inputs[0]) and the prediction data (inputs[1]) and sets the rating predictions as output.
     *
     * @param inputs exactly two operators: the ALS model followed by the prediction data.
     * @return this operator.
     */
    @Override
    public AlsPredictBatchOp linkFrom(BatchOperator<?>... inputs) {
        checkOpSize(2, inputs);

        BatchOperator<?> model = inputs[0];
        BatchOperator<?> data = inputs[1];

        this.setOutputTable(rate(model, data));
        return this;
    }

    /**
     * Predicts ratings given a user and an item.
     * <p>
     * Each distinct (user, item) pair in {@code data} is passed to {@code AlsPredict.rate} once. The resulting
     * ratings are then joined back onto {@code data} on (user, item), so the output holds the columns of {@code data}
     * plus the prediction column.
     * <p>
     * NOTE(review): whether input rows whose user or item is absent from the model are kept depends on what
     * {@code AlsPredict.rate} emits for such pairs and on the join type used by {@code JoinBatchOp}. Confirm.
     *
     * @param model The model trained from AlsTrainBatchOp.
     * @param data  The prediction data which contains user-item pairs.
     * @return {@code data} with the predicted rating appended as the prediction column.
     */
    public Table rate(BatchOperator<?> model, BatchOperator<?> data) {
        final String userColName = getUserCol();
        final String itemColName = getItemCol();
        final String predResultColName = getPredictionCol();
        final int userColIdx = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), userColName);
        final int itemColIdx = TableUtil.findColIndexWithAssertAndHint(data.getColNames(), itemColName);

        DataSet<Tuple2<Long, float[]>> userFactors = getFactors(model, USER_FACTORS);
        DataSet<Tuple2<Long, float[]>> itemFactors = getFactors(model, ITEM_FACTORS);

        // The ratings are joined back onto `data` by (user, item) below. Each pair must therefore be rated only once:
        // a pair occurring k times in `data` would otherwise yield k rating rows, and the join would emit k * k rows.
        DataSet<Tuple2<Long, Long>> ui = data
            .getDataSet()
            .map(new MapFunction<Row, Tuple2<Long, Long>>() {
                @Override
                public Tuple2<Long, Long> map(Row value) throws Exception {
                    // NOTE(review): a null user or item value fails here with a NullPointerException.
                    return new Tuple2<>(((Number) value.getField(userColIdx)).longValue(),
                        ((Number) value.getField(itemColIdx)).longValue());
                }
            })
            .distinct();

        DataSet<Tuple3<Long, Long, Double>> rating = AlsPredict.rate(userFactors, itemFactors, ui);

        DataSet<Row> output = rating
            .map(new MapFunction<Tuple3<Long, Long, Double>, Row>() {
                @Override
                public Row map(Tuple3<Long, Long, Double> value) throws Exception {
                    return Row.of(value.f0, value.f1, value.f2);
                }
            });

        Table ratingTable = DataSetConversionUtil.toTable(getMLEnvironmentId(), output,
            new String[]{userColName, itemColName, predResultColName},
            new TypeInformation<?>[]{Types.LONG, Types.LONG, Types.DOUBLE});

        BatchOperator<?> ratingOp = BatchOperator.fromTable(ratingTable).setMLEnvironmentId(getMLEnvironmentId());

        // Keep every column of the input (a.*) and append only the predicted rating from the rating table.
        return new JoinBatchOp(
            String.format("a.`%s`=b.`%s` and a.`%s`=b.`%s`", userColName, userColName, itemColName, itemColName),
            String.format("a.*,b.`%s`", predResultColName))
            .setMLEnvironmentId(getMLEnvironmentId())
            .linkFrom(data, ratingOp)
            .getOutputTable();
    }
}