package com.alibaba.alink.common.model;

import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.api.misc.param.Params;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;

import java.util.*;

/**
 * A utility class for converting model data to a collection of rows, and back.
 *
 * <p>Every model row carries a {@code Long} model id in field 0. The id encodes
 * {@code stringIndex * MAX_NUM_SLICES + sliceIndex}, so sorting rows by id restores the
 * original order:
 * <ul>
 *     <li>string index 0 holds the model meta data (its JSON form) in field 1;</li>
 *     <li>string indices 1, 2, ... hold the serialized model data strings in field 1;</li>
 *     <li>string index {@link Integer#MAX_VALUE} holds auxiliary data in fields 2 and up
 *     (field 1 is left null).</li>
 * </ul>
 * Strings longer than {@link #SEGMENT_SIZE} are split into several consecutive slices that
 * share the same string index.
 */
class ModelConverterUtils {

    /**
     * The size of a string segment. When serializing model data to a table,
     * the string data will be sliced to segments with size no larger than "SEGMENT_SIZE".
     */
    static final int SEGMENT_SIZE = 32 * 1024;

    /**
     * Maximum number of slices a string can split to. It is also the multiplier that
     * separates string indices inside a model id.
     */
    static final long MAX_NUM_SLICES = 1024L * 1024L;

    /** String index under which the model meta data is stored. */
    private static final int META_STRING_INDEX = 0;

    /**
     * String index under which auxiliary data is stored. It is the largest int so that
     * auxiliary rows sort after every meta and data row.
     */
    private static final int AUX_STRING_INDEX = Integer.MAX_VALUE;

    private ModelConverterUtils() {
        // Static utility class; not meant to be instantiated.
    }

    /**
     * Append model meta data to the collection of rows.
     *
     * @param meta      The model meta data; nothing is emitted when it is null.
     * @param collector The collector of model rows.
     * @param numFields Number of fields of the model table.
     */
    static void appendMetaRow(Params meta, Collector<Row> collector, final int numFields) {
        if (meta != null) {
            appendStringData(meta.toJson(), collector, numFields, META_STRING_INDEX);
        }
    }

    /**
     * Append a list of strings to the collection of rows.
     *
     * <p>Each of these strings will be sliced to segments of size "SEGMENT_SIZE". The i-th
     * string (0-based) is stored under string index i + 1. A null element produces no row
     * and is therefore absent when the data is read back.
     *
     * @param data      The model data serialized to a list of strings; nothing is emitted when null.
     * @param collector The collector of model rows.
     * @param numFields Number of fields of the model table.
     */
    static void appendDataRows(Iterable<String> data, Collector<Row> collector, final int numFields) {
        if (data == null) {
            return;
        }
        // Data strings are numbered from 1 because string index 0 is reserved for the meta data.
        int stringIndex = META_STRING_INDEX + 1;
        for (String s : data) {
            appendStringData(s, collector, numFields, stringIndex++);
        }
    }

    /**
     * Append a list of additional data to the collection of rows.
     *
     * <p>Each element becomes one row under the auxiliary string index. If an element is a
     * {@link Row}, its first {@code numFields - 2} fields are copied into fields 2 and up;
     * otherwise the element itself is stored in field 2.
     *
     * @param auxData   The additional model data; nothing is emitted when it is null.
     * @param collector The collector of model rows.
     * @param numFields Number of fields of the model table.
     * @param <T>       The type of additional data.
     */
    static <T> void appendAuxiliaryData(Iterable<T> auxData, Collector<Row> collector, final int numFields) {
        if (auxData == null) {
            return;
        }

        final int numAdditionalFields = numFields - 2;
        int sliceIndex = 0;

        for (T data : auxData) {
            Row row = new Row(numFields);
            row.setField(0, getModelId(AUX_STRING_INDEX, sliceIndex));
            if (data instanceof Row) {
                Row r = (Row) data;
                for (int j = 0; j < numAdditionalFields; j++) {
                    row.setField(2 + j, r.getField(j));
                }
            } else {
                row.setField(2, data);
            }
            collector.collect(row);
            sliceIndex++;
        }
    }

    /**
     * Extract from a collection of rows the model meta and model data.
     *
     * <p>The returned {@code Iterable} may be iterated any number of times; each iteration
     * reads from {@code rows}, so the list should not be modified while it is in use.
     * NOTE(review): when there is no meta row, an empty string is passed to
     * {@code Params.fromJson}; how that is handled depends on {@code Params}.
     *
     * @param rows Model rows.
     * @return A tuple of model meta and serialized model data.
     */
    static Tuple2<Params, Iterable<String>> extractModelMetaAndData(List<Row> rows) {
        Integer[] order = orderModelRows(rows);

        // The meta segments are the leading rows (in id order) whose string index is 0.
        List<String> metaSegments = new ArrayList<>();
        for (Integer rowIndex : order) {
            Row row = rows.get(rowIndex);
            if (getStringIndex((Long) row.getField(0)) != META_STRING_INDEX) {
                break;
            }
            metaSegments.add((String) row.getField(1));
        }
        String metaStr = mergeString(metaSegments);

        return Tuple2.of(Params.fromJson(metaStr), new StringDataIterable(rows, order));
    }

    /**
     * Extract the additional data from a collection of rows.
     *
     * <p>The returned {@code Iterable} may be iterated any number of times.
     *
     * @param rows    Model rows.
     * @param isLabel Whether the additional data is label data. If true, each element is
     *                field 2 of an auxiliary row; otherwise it is a {@link Row} made of
     *                fields 2 and up.
     * @param <T>     The type of additional data.
     * @return The list of additional data.
     */
    static <T> Iterable<T> extractAuxiliaryData(List<Row> rows, boolean isLabel) {
        Integer[] order = orderModelRows(rows);
        return new AuxiliaryDataIterable<T>(rows, order, isLabel);
    }

    /**
     * Iterates over the model data strings, merging the slices of each string back together.
     * The meta string (index 0), if present, is skipped. Iteration stops at the first row
     * whose field 1 is null, i.e. at the auxiliary rows.
     */
    private static class StringDataIterator implements Iterator<String> {
        private final List<Row> modelRows;
        private final Integer[] order;
        /** The next string to return, or null once the iterator is exhausted. */
        private String curr;
        /** Position in {@code order} of the next row to be read. */
        private int listPos = 0;

        StringDataIterator(List<Row> modelRows, Integer[] order) {
            this.modelRows = modelRows;
            this.order = order;
            // Pre-fetch the first string; if it is the meta data, skip over it.
            if (getNextValue() == META_STRING_INDEX) {
                getNextValue();
            }
        }

        @Override
        public boolean hasNext() {
            return curr != null;
        }

        @Override
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException("Iterator do not has next value.");
            }
            String ret = curr;
            getNextValue();
            return ret;
        }

        /**
         * Reads all consecutive rows sharing the next string index and merges their segments
         * into {@link #curr}.
         *
         * @return the string index that was read, or -1 (with {@code curr} set to null)
         * when no data row remains.
         */
        private int getNextValue() {
            List<String> segments = new ArrayList<>();
            int lastStringId = -1;
            while (listPos < order.length) {
                Row row = modelRows.get(order[listPos]);
                if (row.getField(1) == null) {
                    // Rows without a string segment (auxiliary rows) end the string section.
                    break;
                }
                long id = (Long) row.getField(0);
                String segment = (String) row.getField(1);

                int stringId = getStringIndex(id);
                if (lastStringId == -1) {
                    lastStringId = stringId;
                }
                if (stringId != lastStringId) {
                    break;
                }
                segments.add(segment);
                listPos++;
            }
            if (segments.isEmpty()) {
                curr = null;
                return -1;
            }
            curr = mergeString(segments);
            return lastStringId;
        }
    }

    /** An {@code Iterable} over the model data strings; each call yields a fresh iterator. */
    private static class StringDataIterable implements Iterable<String> {
        private final List<Row> modelRows;
        private final Integer[] order;

        StringDataIterable(List<Row> modelRows, Integer[] order) {
            this.modelRows = modelRows;
            this.order = order;
        }

        @Override
        public Iterator<String> iterator() {
            // A new iterator per call, so the data can be traversed more than once.
            return new StringDataIterator(modelRows, order);
        }
    }

    /** Iterates over the auxiliary rows, which are sorted after all meta and data rows. */
    private static class AuxiliaryDataIterator<T> implements Iterator<T> {
        private final List<Row> modelRows;
        private final Integer[] order;
        private final boolean isLabel;
        /** Position in {@code order} of the next auxiliary row. */
        private int listPos = 0;

        AuxiliaryDataIterator(List<Row> modelRows, Integer[] order, boolean isLabel) {
            this.modelRows = modelRows;
            this.order = order;
            this.isLabel = isLabel;

            // Skip the meta and data rows to reach the first auxiliary row.
            while (listPos < order.length
                && getStringIndex((Long) modelRows.get(order[listPos]).getField(0)) != AUX_STRING_INDEX) {
                listPos++;
            }
        }

        @Override
        public boolean hasNext() {
            return listPos < order.length;
        }

        @Override
        @SuppressWarnings("unchecked") // the caller chooses T to match the stored auxiliary data
        public T next() {
            if (!hasNext()) {
                throw new NoSuchElementException("The iterator do not have next value.");
            }
            Object ret;
            Row modelRow = modelRows.get(order[listPos]);
            if (isLabel) {
                ret = modelRow.getField(2);
            } else {
                Row sub = new Row(modelRow.getArity() - 2);
                for (int j = 0; j < sub.getArity(); j++) {
                    sub.setField(j, modelRow.getField(2 + j));
                }
                ret = sub;
            }
            listPos++;
            return (T) ret;
        }
    }

    /** An {@code Iterable} over the auxiliary data; each call yields a fresh iterator. */
    private static class AuxiliaryDataIterable<T> implements Iterable<T> {
        private final List<Row> modelRows;
        private final Integer[] order;
        private final boolean isLabel;

        AuxiliaryDataIterable(List<Row> modelRows, Integer[] order, boolean isLabel) {
            this.modelRows = modelRows;
            this.order = order;
            this.isLabel = isLabel;
        }

        @Override
        public Iterator<T> iterator() {
            // A new iterator per call, so the data can be traversed more than once.
            return new AuxiliaryDataIterator<>(modelRows, order, isLabel);
        }
    }

    /**
     * Slices {@code data} into segments of at most {@link #SEGMENT_SIZE} characters and emits
     * one row per segment under string index {@code pos}. A null string emits nothing; an
     * empty string emits a single empty segment so that it survives the round trip.
     */
    private static void appendStringData(String data, Collector<Row> collector,
                                         final int numFields, int pos) {
        if (data != null && data.isEmpty()) {
            // The slicer yields no segment for "", which would drop the string from the data
            // read back. One empty segment keeps it (the reader accepts any non-null segment).
            collectSegment(collector, numFields, getModelId(pos, 0), data);
            return;
        }
        StringSlicer slicer = new StringSlicer(data, SEGMENT_SIZE);
        int sliceIndex = 0;
        while (slicer.hasNextSegment()) {
            collectSegment(collector, numFields, getModelId(pos, sliceIndex), slicer.nextSegment());
            sliceIndex++;
        }
    }

    /** Emits one row holding {@code modelId} in field 0 and {@code segment} in field 1. */
    private static void collectSegment(Collector<Row> collector, int numFields, long modelId, String segment) {
        Row row = new Row(numFields);
        row.setField(0, modelId);
        row.setField(1, segment);
        collector.collect(row);
    }

    /** Encodes a (string index, slice index) pair into a model id; computed in long arithmetic. */
    private static long getModelId(int stringIndex, int sliceIndex) {
        return MAX_NUM_SLICES * stringIndex + sliceIndex;
    }

    /** Recovers the string index from a model id produced by {@link #getModelId}. */
    private static int getStringIndex(long modelId) {
        return (int) (modelId / MAX_NUM_SLICES);
    }

    /**
     * Returns the positions of {@code rows} sorted by ascending model id (field 0). The list
     * itself is not reordered.
     */
    private static Integer[] orderModelRows(List<Row> rows) {
        Integer[] order = new Integer[rows.size()];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        Arrays.sort(order, Comparator.comparingLong((Integer i) -> (Long) rows.get(i).getField(0)));
        return order;
    }

    /** Cuts a string into consecutive segments of at most {@code segmentSize} characters. */
    private static class StringSlicer {
        private final int segmentSize;
        private final String str;
        private final int len;
        private int pos;

        StringSlicer(String str, int segmentSize) {
            this.segmentSize = segmentSize;
            this.str = str;
            this.pos = 0;
            // A null string is treated as having no segments.
            this.len = str == null ? 0 : str.length();
        }

        boolean hasNextSegment() {
            return pos < len;
        }

        String nextSegment() {
            String segment = str.substring(pos, Math.min(pos + segmentSize, len));
            pos += segment.length();
            return segment;
        }
    }

    /** Concatenates segments back into one string; an empty list yields "". */
    private static String mergeString(List<String> strings) {
        // Most strings fit in a single segment, so skip the copy in that case.
        if (strings.size() == 1) {
            return strings.get(0);
        }
        return String.join("", strings);
    }
}