package com.alibaba.alink.operator.batch.feature;

import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.functions.RichReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.utils.DataSetUtils;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

import com.alibaba.alink.common.utils.TableUtil;
import com.alibaba.alink.operator.batch.BatchOperator;
import com.alibaba.alink.operator.common.dataproc.SortUtils;
import com.alibaba.alink.operator.common.dataproc.SortUtilsNext;
import com.alibaba.alink.operator.common.feature.QuantileDiscretizerModelDataConverter;
import com.alibaba.alink.operator.common.feature.binning.BinTypes;
import com.alibaba.alink.operator.common.feature.binning.FeatureBorder;
import com.alibaba.alink.operator.common.feature.quantile.PairComparable;
import com.alibaba.alink.operator.common.tree.Preprocessing;
import com.alibaba.alink.params.feature.QuantileDiscretizerTrainParams;
import com.alibaba.alink.params.statistics.HasRoundMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.StreamSupport;

/**
 * Fit a quantile discretizer model.
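 *
 * <p>The split points of each selected column are its empirical quantiles, so each
 * bucket receives roughly the same number of non-missing rows.
 *
 * <p>A minimal usage sketch; {@code source} and the column names are hypothetical:
 * <pre>{@code
 * QuantileDiscretizerTrainBatchOp model = new QuantileDiscretizerTrainBatchOp()
 *     .setSelectedCols("f0", "f1")
 *     .setNumBuckets(4)
 *     .linkFrom(source);
 * }</pre>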
 */
public final class QuantileDiscretizerTrainBatchOp extends BatchOperator<QuantileDiscretizerTrainBatchOp>
	implements QuantileDiscretizerTrainParams<QuantileDiscretizerTrainBatchOp> {

	private static final Logger LOG = LoggerFactory.getLogger(QuantileDiscretizerTrainBatchOp.class);

	public QuantileDiscretizerTrainBatchOp() {
	}

	public QuantileDiscretizerTrainBatchOp(Params params) {
		super(params);
	}

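	/**
	 * Compute the quantile points of every column of the input.
	 *
	 * <p>The computation proceeds in four steps: count the total number of rows, count
	 * the missing values per column, flatten the rows into (columnIndex, value) pairs
	 * and sort them globally, then pick the values at the quantile indices of each column.
	 *
	 * @param input         dataset holding only the selected (numeric) columns
	 * @param quantileNum   number of buckets requested for each column; a column gets at
	 *                      most quantileNum - 1 split points
	 * @param roundMode     rounding strategy used when a quantile index is fractional
	 * @param zeroAsMissing whether 0.0 is treated as a missing value
	 * @return rows of (columnIndex, sorted distinct split points)
	 */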
	public static DataSet<Row> quantile(
		DataSet<Row> input,
		final int[] quantileNum,
		final HasRoundMode.RoundMode roundMode,
		final boolean zeroAsMissing) {
		/* total instance count of the dataset */
		DataSet<Long> cnt = DataSetUtils
			.countElementsPerPartition(input)
			.sum(1)
			.map(new MapFunction<Tuple2<Integer, Long>, Long>() {
				@Override
				public Long map(Tuple2<Integer, Long> value) throws Exception {
					return value.f1;
				}
			});

		/* per-column missing-value counts */
		DataSet<Tuple2<Integer, Long>> missingCount = input
			.mapPartition(new RichMapPartitionFunction<Row, Tuple2<Integer, Long>>() {
				@Override
				public void open(Configuration parameters) throws Exception {
					super.open(parameters);
					LOG.info("{} open.", getRuntimeContext().getTaskName());
				}

				@Override
				public void close() throws Exception {
					super.close();
					LOG.info("{} close.", getRuntimeContext().getTaskName());
				}

				@Override
				public void mapPartition(Iterable<Row> values, Collector<Tuple2<Integer, Long>> out)
					throws Exception {
					StreamSupport.stream(values.spliterator(), false)
						.flatMap(x -> {
							long[] counts = new long[x.getArity()];

							for (int i = 0; i < x.getArity(); ++i) {
								if (x.getField(i) == null
									|| (zeroAsMissing && ((Number) x.getField(i)).doubleValue() == 0.0)
									|| Double.isNaN(((Number) x.getField(i)).doubleValue())) {
									counts[i]++;
								}
							}

							return IntStream.range(0, x.getArity())
								.mapToObj(y -> Tuple2.of(y, counts[y]));
						})
						.collect(Collectors.groupingBy(
							x -> x.f0,
							Collectors.mapping(x -> x.f1, Collectors.reducing((a, b) -> a + b))
							)
						)
						.entrySet()
						.stream()
						.map(x -> Tuple2.of(x.getKey(), x.getValue().get()))
						.forEach(out::collect);
				}
			})
			.groupBy(0)
			.reduce(new RichReduceFunction<Tuple2<Integer, Long>>() {
				@Override
				public void open(Configuration parameters) throws Exception {
					super.open(parameters);
					LOG.info("{} open.", getRuntimeContext().getTaskName());
				}

				@Override
				public void close() throws Exception {
					super.close();
					LOG.info("{} close.", getRuntimeContext().getTaskName());
				}

				@Override
				public Tuple2<Integer, Long> reduce(Tuple2<Integer, Long> value1, Tuple2<Integer, Long> value2)
					throws Exception {
					return Tuple2.of(value1.f0, value1.f1 + value2.f1);
				}
			});

		/* flatten the dataset into (columnIndex, value) pairs */
		DataSet<PairComparable> flatten = input
			.mapPartition(new RichMapPartitionFunction<Row, PairComparable>() {

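				// Reused output record; safe under Flink's default object-reuse
				// setting because each collected record is copied/serialized downstream.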
				PairComparable pairBuff;

				@Override
				public void open(Configuration parameters) throws Exception {
					super.open(parameters);
					LOG.info("{} open.", getRuntimeContext().getTaskName());
					pairBuff = new PairComparable();
				}

				@Override
				public void close() throws Exception {
					super.close();
					LOG.info("{} close.", getRuntimeContext().getTaskName());
				}

				@Override
				public void mapPartition(Iterable<Row> values, Collector<PairComparable> out) {
					for (Row value : values) {
						for (int i = 0; i < value.getArity(); ++i) {
							pairBuff.first = i;
							if (value.getField(i) == null
								|| (zeroAsMissing && ((Number) value.getField(i)).doubleValue() == 0.0)
								|| Double.isNaN(((Number) value.getField(i)).doubleValue())) {
								pairBuff.second = null;
							} else {
								pairBuff.second = (Number) value.getField(i);
							}
							out.collect(pairBuff);
						}
					}
				}
			});

		/* sort data */
		Tuple2<DataSet<PairComparable>, DataSet<Tuple2<Integer, Long>>> sortedData
			= SortUtilsNext.pSort(flatten);

		/* calculate quantile */
		return sortedData.f0
			.mapPartition(new MultiQuantile(quantileNum, roundMode))
			.withBroadcastSet(sortedData.f1, "counts")
			.withBroadcastSet(cnt, "totalCnt")
			.withBroadcastSet(missingCount, "missingCounts")
			.groupBy(0)
			.reduceGroup(new RichGroupReduceFunction<Tuple2<Integer, Number>, Row>() {
				@Override
				public void open(Configuration parameters) throws Exception {
					super.open(parameters);
					LOG.info("{} open.", getRuntimeContext().getTaskName());
				}

				@Override
				public void close() throws Exception {
					super.close();
					LOG.info("{} close.", getRuntimeContext().getTaskName());
				}

				@Override
				public void reduce(Iterable<Tuple2<Integer, Number>> values, Collector<Row> out) throws Exception {
					TreeSet<Number> set = new TreeSet<>(new Comparator<Number>() {
						@Override
						public int compare(Number o1, Number o2) {
							return SortUtils.OBJECT_COMPARATOR.compare(o1, o2);
						}
					});

					int id = -1;
					for (Tuple2<Integer, Number> val : values) {
						id = val.f0;
						set.add(val.f1);
					}

					out.collect(Row.of(id, set.toArray(new Number[0])));
				}
			});
	}

	@Override
	public QuantileDiscretizerTrainBatchOp linkFrom(BatchOperator<?>... inputs) {
		BatchOperator<?> in = checkAndGetFirst(inputs);
		if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS) && getParams().contains(
			QuantileDiscretizerTrainParams.NUM_BUCKETS_ARRAY)) {
			throw new RuntimeException("It can not set num_buckets and num_buckets_array at the same time.");
		}

		String[] quantileColNames =
			getSelectedCols();

		int[] quantileNum;

		if (getParams().contains(QuantileDiscretizerTrainParams.NUM_BUCKETS)) {
			quantileNum = new int[quantileColNames.length];
			Arrays.fill(quantileNum, getNumBuckets());
		} else {
			quantileNum = Arrays.stream(getNumBucketsArray()).mapToInt(Integer::intValue).toArray();
		}

		/* filter the selected column from input */
		DataSet<Row> input = Preprocessing.select(in, quantileColNames).getDataSet();

		DataSet<Row> quantile = quantile(
			input, quantileNum,
			getParams().get(HasRoundMode.ROUND_MODE),
			getParams().get(Preprocessing.ZERO_AS_MISSING)
		);

		quantile = quantile.reduceGroup(
			new SerializeModel(
				getParams(),
				quantileColNames,
				TableUtil.findColTypesWithAssertAndHint(in.getSchema(), quantileColNames),
				BinTypes.BinDivideType.QUANTILE
			)
		);

		/* set output */
		setOutput(quantile, new QuantileDiscretizerModelDataConverter().getModelSchema());

		return this;
	}

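	/**
	 * Picks the quantile points out of the globally sorted (columnIndex, value) pairs.
	 *
	 * <p>Each subtask receives a contiguous slice of the sorted data. The broadcast
	 * "counts" variable holds the element count of every subtask, from which the global
	 * range [start, end) of the local slice is derived; "totalCnt" and "missingCounts"
	 * then locate the per-column quantile indices inside that range.
	 */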
	public static class MultiQuantile
		extends RichMapPartitionFunction<PairComparable, Tuple2<Integer, Number>> {
		private List<Tuple2<Integer, Long>> counts;
		private List<Tuple2<Integer, Long>> missingCounts;
		private long totalCnt = 0;
		private final int[] quantileNum;
		private final HasRoundMode.RoundMode roundType;
		private int taskId;

		public MultiQuantile(int[] quantileNum, HasRoundMode.RoundMode roundType) {
			this.quantileNum = quantileNum;
			this.roundType = roundType;
		}

		@Override
		public void open(Configuration parameters) throws Exception {
			this.counts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"counts",
				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
					@Override
					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
						Iterable<Tuple2<Integer, Long>> data) {
						ArrayList<Tuple2<Integer, Long>> sortedData = new ArrayList<>();
						for (Tuple2<Integer, Long> datum : data) {
							sortedData.add(datum);
						}

						sortedData.sort(Comparator.comparing(o -> o.f0));

						return sortedData;
					}
				});

			this.totalCnt = getRuntimeContext().getBroadcastVariableWithInitializer("totalCnt",
				new BroadcastVariableInitializer<Long, Long>() {
					@Override
					public Long initializeBroadcastVariable(Iterable<Long> data) {
						return data.iterator().next();
					}
				});

			this.missingCounts = getRuntimeContext().getBroadcastVariableWithInitializer(
				"missingCounts",
				new BroadcastVariableInitializer<Tuple2<Integer, Long>, List<Tuple2<Integer, Long>>>() {
					@Override
					public List<Tuple2<Integer, Long>> initializeBroadcastVariable(
						Iterable<Tuple2<Integer, Long>> data) {
						return StreamSupport.stream(data.spliterator(), false)
							.sorted(Comparator.comparing(o -> o.f0))
							.collect(Collectors.toList());
					}
				}
			);

			taskId = getRuntimeContext().getIndexOfThisSubtask();
			LOG.info("{} open.", getRuntimeContext().getTaskName());
		}

		@Override
		public void close() throws Exception {
			super.close();
			LOG.info("{} close.", getRuntimeContext().getTaskName());
		}

		@Override
		public void mapPartition(Iterable<PairComparable> values, Collector<Tuple2<Integer, Number>> out)
			throws Exception {

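			// Derive this subtask's global range [start, end) in the sorted data:
			// sum the element counts of all subtasks that precede this one.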
			long start = 0;
			long end;

			int curListIndex = -1;
			int size = counts.size();

			for (int i = 0; i < size; ++i) {
				int curId = counts.get(i).f0;

				if (curId == taskId) {
					curListIndex = i;
					break;
				}

				if (curId > taskId) {
					throw new RuntimeException("Error curId: " + curId
						+ ". id: " + taskId);
				}

				start += counts.get(i).f1;
			}

			end = start + counts.get(curListIndex).f1;

			ArrayList<PairComparable> allRows = new ArrayList<>((int) (end - start));

			for (PairComparable value : values) {
				allRows.add(value);
			}

			if (allRows.isEmpty()) {
				return;
			}

			if (allRows.size() != end - start) {
				throw new Exception("Error start end."
					+ " start: " + start
					+ ". end: " + end
					+ ". size: " + allRows.size());
			}

			LOG.info("taskId: {}, size: {}", getRuntimeContext().getIndexOfThisSubtask(), allRows.size());

			allRows.sort(Comparator.naturalOrder());

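			// The sorted data lays the columns out contiguously, each occupying
			// totalCnt entries; 'size' is the number of column blocks that
			// intersect this subtask's slice [start, end).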
			size = (int) ((end - 1) / totalCnt - start / totalCnt) + 1;

			int localStart = 0;

			for (int i = 0; i < size; ++i) {
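				// fIdx is the column index of the i-th intersected block; subStart and
				// subEnd bound the covered part of that block's [0, totalCnt) range.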
				int fIdx = (int) (start / totalCnt + i);
				int subStart = 0;
				int subEnd = (int) totalCnt;

				if (i == 0) {
					subStart = (int) (start % totalCnt);
				}

				if (i == size - 1) {
					subEnd = (int) (end % totalCnt == 0 ? totalCnt : end % totalCnt);
				}

				if (totalCnt - missingCounts.get(fIdx).f1 == 0) {
					localStart += subEnd - subStart;
					continue;
				}

				QIndex qIndex = new QIndex(
					totalCnt - missingCounts.get(fIdx).f1, quantileNum[fIdx], roundType);

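				// Emit the value at every interior quantile index that falls inside
				// the covered part of this column's block.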
				for (int j = 1; j < quantileNum[fIdx]; ++j) {
					long index = qIndex.genIndex(j);

					if (index >= subStart && index < subEnd) {
						PairComparable pairComparable = allRows.get(
							(int) (index + localStart - subStart));
						out.collect(Tuple2.of(pairComparable.first, pairComparable.second));
					}
				}

				localStart += subEnd - subStart;
			}
		}
	}

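	/**
	 * Converts the per-column split points into {@link FeatureBorder}s and serializes
	 * them as model rows; columns for which no splits were produced are stored with
	 * null splits so that every selected column appears in the model.
	 */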
	public static class SerializeModel implements GroupReduceFunction<Row, Row> {
		private final Params meta;
		private final String[] colNames;
		private final TypeInformation<?>[] colTypes;
		private final BinTypes.BinDivideType binDivideType;

		public SerializeModel(
			Params meta, String[] colNames, TypeInformation<?>[] colTypes,
			BinTypes.BinDivideType binDivideType) {
			this.meta = meta;
			this.colNames = colNames;
			this.colTypes = colTypes;
			this.binDivideType = binDivideType;
		}

		@Override
		public void reduce(Iterable<Row> values, Collector<Row> out) throws Exception {
			Map<String, FeatureBorder> m = new HashMap<>();
			for (Row val : values) {
				int index = (int) val.getField(0);
				Number[] splits = (Number[]) val.getField(1);
				m.put(
					colNames[index],
					QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
						colNames[index],
						colTypes[index],
						splits,
						meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
						binDivideType
					)
				);
			}

			for (int i = 0; i < colNames.length; ++i) {
				if (m.containsKey(colNames[i])) {
					continue;
				}

				m.put(
					colNames[i],
					QuantileDiscretizerModelDataConverter.arraySplit2FeatureBorder(
						colNames[i],
						colTypes[i],
						null,
						meta.get(QuantileDiscretizerTrainParams.LEFT_OPEN),
						binDivideType
					)
				);
			}

			QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m, meta);

			model.save(model, out);
		}
	}

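	/**
	 * Maps the k-th quantile point to an index in the sorted, non-missing values of a
	 * column: index(k) = round(k / quantileNum * (totalCount - 1)), where the rounding
	 * function is given by the configured {@link HasRoundMode.RoundMode}.
	 */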
	public static class QIndex {
		private final double totalCount;
		private final double q1;
		private final HasRoundMode.RoundMode roundMode;

		public QIndex(double totalCount, int quantileNum, HasRoundMode.RoundMode type) {
			this.totalCount = totalCount;
			this.q1 = 1.0 / (double) quantileNum;
			this.roundMode = type;
		}

		public long genIndex(int k) {
			return roundMode.calc(this.q1 * (this.totalCount - 1.0) * (double) k);
		}
	}

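	/**
	 * Loads a quantile discretizer model and emits one {@link FeatureBorder} per column.
	 */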
	public static DataSet<FeatureBorder> transformModelToFeatureBorder(DataSet<Row> modelDataSet) {
		return modelDataSet
			.reduceGroup(
				new GroupReduceFunction<Row, FeatureBorder>() {
					@Override
					public void reduce(Iterable<Row> values, Collector<FeatureBorder> out) throws Exception {
						List<Row> list = new ArrayList<>();
						values.forEach(list::add);
						QuantileDiscretizerModelDataConverter model
							= new QuantileDiscretizerModelDataConverter().load(list);
						for (Map.Entry<String, FeatureBorder> entry : model.data.entrySet()) {
							out.collect(entry.getValue());
						}
					}
				}
			);
	}

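	/**
	 * Serializes {@link FeatureBorder}s back into model rows. The parallelism is forced
	 * to 1 so that all borders end up in a single model.
	 */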
	public static DataSet<Row> transformFeatureBorderToModel(DataSet<FeatureBorder> featureBorderDataSet) {
		return featureBorderDataSet.mapPartition(new MapPartitionFunction<FeatureBorder, Row>() {
			@Override
			public void mapPartition(Iterable<FeatureBorder> values, Collector<Row> out) throws Exception {
				transformFeatureBorderToModel(values, out);
			}
		}).setParallelism(1);
	}

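	/**
	 * Serializes the given {@link FeatureBorder}s into model rows on the given collector.
	 */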
	public static void transformFeatureBorderToModel(Iterable<FeatureBorder> values, Collector<Row> out) {
		List<String> colNames = new ArrayList<>();
		Map<String, FeatureBorder> m = new HashMap<>();
		for (FeatureBorder featureBorder : values) {
			m.put(featureBorder.featureName, featureBorder);
			colNames.add(featureBorder.featureName);
		}

		Params meta = new Params()
			.set(QuantileDiscretizerTrainParams.SELECTED_COLS, colNames.toArray(new String[0]));
		QuantileDiscretizerModelDataConverter model = new QuantileDiscretizerModelDataConverter(m, meta);
		model.save(model, out);
	}
}