/**
*
* Copyright (c) 2017 ytk-learn https://github.com/yuantiku
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:

* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.

* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

package com.fenbi.ytklearn.dataflow;

import com.fenbi.mp4j.comm.ThreadCommSlave;
import com.fenbi.mp4j.exception.Mp4jException;
import com.fenbi.ytklearn.data.Constants;
import com.fenbi.ytklearn.exception.YtkLearnException;
import com.fenbi.ytklearn.loss.ILossFunction;
import com.fenbi.ytklearn.utils.CheckUtils;
import lombok.Data;
import org.apache.commons.lang.ArrayUtils;
import org.python.core.PyFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;

/**
 * Dense sample storage for the GBDT tree makers.
 *
 * <p>Samples are stored in fixed-size blocks: block {@code b} of {@link #x} holds up to
 * {@link #DENSE_MAX_1D_SAMPLE_CNT} samples, each occupying {@link #maxFeatureDim}
 * consecutive ints. A feature value is stored as the raw int bits of its float value,
 * and slots that were never written hold {@code Constants.INT_MISSING_VALUE}.
 * Labels, initial scores, scores and predictions use {@link #numTreeInGroup}
 * consecutive floats per sample.
 *
 * used for data-parallel tree maker
 * @author wufan
 * @author xialong
 */

@Data
public class GBDTCoreData extends CoreData {
    // number of samples that fit into one block of x
    public final int DENSE_MAX_1D_SAMPLE_CNT;
    // usable length of one block of x: DENSE_MAX_1D_SAMPLE_CNT * maxFeatureDim
    public final int DENSE_MAX_1D_LEN;

    // number of int slots reserved for every sample
    public final int maxFeatureDim;
    public int usefulFeatureDim;

    // number of label/score slots per sample (1 = regression / binary, >1 = multiclass)
    public final int numTreeInGroup;

    public ILossFunction obj;
    private float baseScore;
    // when true, info[3] of every line carries a per-sample base prediction
    private boolean sampleDepdtBasePrediction;

    // set outside
    public Map<Integer, String> fIndex2NameMap;
    // set outside, compute once, used in thread data
    public int sampleNum;
    public double weightSum;

    // === used for train phase ===
    public int lastPredRound;

    // staging buffer for the initial scores of the block that is currently being filled
    public float[] TMP_INIT_SCORE;
    public float[][] initScore;

    // prediction buffer, save raw prediction(score before loss function) in last iter
    public float[][] score;
    // first order and second order gradient  eg: grad1, hess1, grad2, hess2,
    public float[][] gradPairs;

    // len(xColRange) = 2, [startXColIndex, endXColIndex)
    public int[] xColRange;
    // global feature range for each worker(thread)
    public int[][] globalFeatureAssignFrom;
    public int[][] globalFeatureAssignTo;

    // === used in local(feature-parallel) version ===
    // len(xRowRange) = 2, [startXRowIndex, endXRowIndex)
    public int[] xRowRange;
    public int[][] globalSampleAssignFrom;
    public int[][] globalSampleAssignTo;

    public int[][] globalGradAssignFrom;
    public int[][] globalGradAssignTo;
    public GBDTCoreData allData;
    // feature cols for thread data, only used in local version
    public FeatureColData xT;


    /**
     * Creates an empty container.
     *
     * @param comm                      passed through to {@link CoreData}
     * @param coreParams                passed through to {@link CoreData}
     * @param featureMap                passed through to {@link CoreData}
     * @param pyTransformFunc           passed through to {@link CoreData}
     * @param needPyTransfor            passed through to {@link CoreData}
     * @param maxFeatureDim             number of int slots reserved per sample; must be in
     *                                  {@code [1, MAX_1D_LEN]}
     * @param numTreeInGroup            number of label/score slots per sample
     * @param obj                       loss function used to validate labels and to convert
     *                                  base predictions to scores
     * @param baseScore                 initial score assigned to every sample
     * @param sampleDepdtBasePrediction whether each line carries its own base prediction
     */
    public GBDTCoreData(ThreadCommSlave comm,
                        DataFlow.CoreParams coreParams,
                        IFeatureMap featureMap,
                        PyFunction pyTransformFunc,
                        boolean needPyTransfor,
                        int maxFeatureDim,
                        int numTreeInGroup,
                        ILossFunction obj,
                        float baseScore,
                        boolean sampleDepdtBasePrediction) {
        super(comm, coreParams, featureMap, pyTransformFunc, needPyTransfor);

        // Guard the division below and make sure at least one sample fits into a block;
        // otherwise every sid / DENSE_MAX_1D_SAMPLE_CNT later on would divide by zero.
        CheckUtils.check(maxFeatureDim > 0 && maxFeatureDim <= MAX_1D_LEN,
                "[GBDT] max_feature_dim(%d) must be in range [1, %d]", maxFeatureDim, MAX_1D_LEN);

        this.initScore = new float[MAX_2D_LEN][];
        this.score = new float[MAX_2D_LEN][];

        this.DENSE_MAX_1D_SAMPLE_CNT = MAX_1D_LEN / maxFeatureDim;
        this.DENSE_MAX_1D_LEN = DENSE_MAX_1D_SAMPLE_CNT * maxFeatureDim;
        this.lastPredRound = 0;

        this.maxFeatureDim = maxFeatureDim;
        this.numTreeInGroup = numTreeInGroup;
        this.obj = obj;
        this.baseScore = baseScore;
        this.sampleDepdtBasePrediction = sampleDepdtBasePrediction;
    }

    /**
     * Builds one container that references the feature blocks of all thread-local
     * containers (used by the feature-parallel version).
     *
     * <p>Only {@code x}, {@code realNum}, {@code cursor2d} and {@code sampleNum} of the
     * result are populated; the feature blocks are shared with the inputs, not copied.
     * Construction parameters are taken from the first element.
     *
     * @param threadCorData thread-local containers; element 0 must be a {@code GBDTCoreData}
     * @return the merged container, or {@code null} when the input is {@code null} or empty
     */
    public static GBDTCoreData mergeThreadFeatData(CoreData[] threadCorData) {
        if (threadCorData == null || threadCorData.length == 0) {
            return null;
        }

        int total2D = 0;
        for (CoreData data : threadCorData) {
            total2D += data.cursor2d;
        }

        GBDTCoreData tdata0 = (GBDTCoreData) threadCorData[0];
        GBDTCoreData allData = new GBDTCoreData(tdata0.comm, tdata0.coreParams, tdata0.featureMap,
                tdata0.pyTransformFunc, tdata0.needPyTransform,
                tdata0.maxFeatureDim, tdata0.numTreeInGroup, tdata0.obj,
                tdata0.baseScore, tdata0.sampleDepdtBasePrediction);
        allData.sampleNum = (int) tdata0.gRealNum;
        allData.cursor2d = total2D;
        allData.x = new int[total2D][];
        allData.realNum = new int[total2D];

        // Append the first cursor2d blocks of every thread, in thread order.
        int index2D = 0;
        for (CoreData data : threadCorData) {
            System.arraycopy(data.x, 0, allData.x, index2D, data.cursor2d);
            System.arraycopy(data.realNum, 0, allData.realNum, index2D, data.cursor2d);
            index2D += data.cursor2d;
        }
        return allData;
    }

    /**
     * local(feature-parallel) version: reverse features, get features for thread data.
     * Builds {@link #xT} from {@link #allData} and {@link #xColRange}.
     */
    public void createFeatureColData() {
        xT = new FeatureColData(allData, xColRange);
    }


    /**
     * Allocates the staging buffers that hold labels, weights and initial scores of the
     * block currently being filled.
     */
    @Override
    public void initAssistData() {
        LINE_LIST = new ArrayList<>();
        LINE_LIST.add("temp");

        TMP_Y = new float[DENSE_MAX_1D_SAMPLE_CNT * numTreeInGroup];
        TMP_WEIGHT = new float[DENSE_MAX_1D_SAMPLE_CNT];
        TMP_INIT_SCORE = new float[DENSE_MAX_1D_SAMPLE_CNT * numTreeInGroup];
    }

    /**
     * Returns the stored int for feature {@code fid} of sample {@code sid}.
     *
     * @param sid sample index across all blocks
     * @param fid feature index inside the sample, {@code [0, maxFeatureDim)}
     */
    public int getFeatureVal(int sid, int fid) {
        int index2D = sid / DENSE_MAX_1D_SAMPLE_CNT;
        int index1D = sid % DENSE_MAX_1D_SAMPLE_CNT;
        return x[index2D][index1D * maxFeatureDim + fid];
    }

    /**
     * Overwrites the stored int for feature {@code fid} of sample {@code sid}.
     *
     * @param sid sample index across all blocks
     * @param fid feature index inside the sample, {@code [0, maxFeatureDim)}
     * @param val value to store
     */
    public void setFeatureVal(int sid, int fid, int val) {
        int index2D = sid / DENSE_MAX_1D_SAMPLE_CNT;
        int index1D = sid % DENSE_MAX_1D_SAMPLE_CNT;
        x[index2D][index1D * maxFeatureDim + fid] = val;
    }

    /**
     * Replaces the stored value with {@code newVal} if it currently equals {@code refVal}.
     *
     * @return {@code true} when the value matched and was replaced, {@code false} otherwise
     */
    public boolean isEqualReplaceFeaVal(int sid, int fid, int refVal, int newVal) {
        int[] block = x[sid / DENSE_MAX_1D_SAMPLE_CNT];
        int pos = (sid % DENSE_MAX_1D_SAMPLE_CNT) * maxFeatureDim + fid;
        if (block[pos] != refVal) {
            return false;
        }
        block[pos] = newVal;
        return true;
    }

    /**
     * Returns the weight of sample {@code sid}.
     */
    public float getSampleWeight(int sid) {
        int index2D = sid / DENSE_MAX_1D_SAMPLE_CNT;
        int index1D = sid % DENSE_MAX_1D_SAMPLE_CNT;
        return weight[index2D][index1D];
    }

    /**
     * used in train phase, alloc space for gradPairs.
     * Every block gets {@code realNum[i] * 2 * numTreeInGroup} floats, all zero.
     */
    public void initGradPairs() {
        gradPairs = new float[cursor2d][];
        for (int i = 0; i < cursor2d; i++) {
            // a freshly allocated float[] is already zero-filled
            gradPairs[i] = new float[realNum[i] * 2 * numTreeInGroup];
        }
    }

    /**
     * Splits a raw line by {@code coreParams.x_delim} and requires at least three fields.
     */
    @Override
    protected String[] trainDataSplit(String line) {
        String[] info = line.trim().split(coreParams.x_delim);
        CheckUtils.check(info.length >=3, "[GBDT] data format error! line:%s", line);
        return info;
    }

    /**
     * Copies the parsed label vector into the staging buffer slot of the current sample.
     */
    @Override
    protected void updateY() {
        System.arraycopy(label, 0, TMP_Y, count * numTreeInGroup, numTreeInGroup);
    }

    /**
     * Parses the label field ({@code info[1]}) into {@code label} and, when label sampling
     * is enabled, adjusts the sample weight and decides whether the sample is kept.
     *
     * <p>For multiclass data the field is either a single class index in
     * {@code [0, numTreeInGroup)} or {@code numTreeInGroup} values separated by
     * {@code coreParams.y_delim}.
     *
     * @param line raw input line, used in error messages
     * @param info fields of {@code line}
     * @return {@code true} if the sample should be kept
     * @throws Exception if the label cannot be parsed or fails validation
     */
    @Override
    protected boolean yExtract(String line, String[] info) throws Exception {

        //regression & binary classification & ...
        if (numTreeInGroup == 1) {
            label[0] = Float.parseFloat(info[1]);
            CheckUtils.check(obj.checkLabel(label[0]), "[GBDT] label error, line: %s", line);
            labelIdx = (int) label[0];

        } else { // multiclass softmax
            String[] linfo = info[1].split(coreParams.y_delim);
            CheckUtils.check(linfo.length == numTreeInGroup || linfo.length == 1, "[GBDT] label num must equal %d or 1, line: %s", numTreeInGroup, line);

            if (linfo.length == 1) {
                Arrays.fill(label, 0, numTreeInGroup, 0);
                int clazz = Integer.parseInt(linfo[0]);
                // reject negative indexes too, so that bad labels are reported with the
                // domain exception instead of an ArrayIndexOutOfBoundsException
                if (clazz < 0 || clazz >= numTreeInGroup) {
                    throw new YtkLearnException("multi classification label must in range [0,K-1]!\n" + line);
                }
                label[clazz] = 1.0f;
            } else {
                for (int i = 0; i < numTreeInGroup; i++) {
                    label[i] = Float.parseFloat(linfo[i]);
                }
            }

            CheckUtils.check(obj.checkLabel(label), "[GBDT] all label sum must equal 1.0, line: %s", line);
            // labelIdx is read below when sampling is enabled, so it has to be refreshed in
            // that case as well; otherwise the value of a previous line would be reused.
            if (coreParams.needYStat || coreParams.needYSampling) {
                labelIdx = -1;
                for (int i = 0; i < numTreeInGroup; i++) {
                    if (label[i] == 1.0) {
                        labelIdx = i;
                    }
                }
            }

        }

        if (!coreParams.needYSampling) {
            return true;
        }

        CheckUtils.check(labelIdx != -1, "[GBDT] label error! line: %s", line);
        float rate = coreParams.ySampling[labelIdx];
        if (rate <= 1.0f) {
            // kept samples are up-weighted by the inverse of the sampling rate
            wei *= (1.0f / rate);
        } else {
            wei *= rate;
        }
        return rand.nextFloat() <= rate;
    }

    /**
     * Tells whether the current block of {@code x} has no room for another sample.
     */
    @Override
    protected boolean exceed1DRange() {
        return xindex >= DENSE_MAX_1D_LEN;
    }

    /**
     * Finalizes the current (full) block and moves on to the next one.
     */
    @Override
    protected void exceed1DHandle() throws Mp4jException {
        int localnum = realNum[cursor2d];

        if (localnum != count) {
            LOG_UTILS.verboseInfo(loadingPrefix + "----error! localnum:" + localnum + ", count:" + count, false);
        }

        storeCurrentBlock(localnum);

        xindex = 0;
        count = 0;
        cursor2d++;
    }

    /**
     * Fills the current block of {@code x} with the missing-value marker if it has not
     * been allocated yet.
     */
    @Override
    protected void alloc1D() {
        if (x[cursor2d] == null) {
            int[] block = new int[DENSE_MAX_1D_LEN];
            Arrays.fill(block, Constants.INT_MISSING_VALUE);
            x[cursor2d] = block;
        }
    }

    /**
     * Doubles the number of block slots of all per-block arrays, keeping the first
     * {@code x.length} entries of each.
     */
    @Override
    protected void exceed2DHandle() {
        int oldLen = x.length;
        int newLen = oldLen * 2;

        int[][] grownX = new int[newLen][];
        float[][] grownY = new float[newLen][];
        float[][] grownWeight = new float[newLen][];
        int[] grownRealNum = new int[newLen];
        double[] grownWeightNum = new double[newLen];
        float[][] grownInitScore = new float[newLen][];
        float[][] grownScore = new float[newLen][];
        float[][] grownPredict = new float[newLen][];

        System.arraycopy(x, 0, grownX, 0, oldLen);
        System.arraycopy(y, 0, grownY, 0, oldLen);
        System.arraycopy(weight, 0, grownWeight, 0, oldLen);
        System.arraycopy(realNum, 0, grownRealNum, 0, oldLen);
        System.arraycopy(weightNum, 0, grownWeightNum, 0, oldLen);
        System.arraycopy(initScore, 0, grownInitScore, 0, oldLen);
        System.arraycopy(score, 0, grownScore, 0, oldLen);
        System.arraycopy(predict, 0, grownPredict, 0, oldLen);

        x = grownX;
        y = grownY;
        weight = grownWeight;
        realNum = grownRealNum;
        weightNum = grownWeightNum;
        initScore = grownInitScore;
        score = grownScore;
        predict = grownPredict;
    }

    /**
     * Stores the weight of the current sample in the staging buffer.
     */
    @Override
    protected void updateXidx() throws Mp4jException {
        TMP_WEIGHT[count] = wei;
    }

    /**
     * Writes the features of one line into the slot of the current sample and advances
     * {@code xindex} by {@code maxFeatureDim}.
     *
     * <p>Features unknown to {@code featureMap} are skipped. A failure is counted and
     * logged; it is rethrown only once the error count exceeds {@code maxErrorTolNum}.
     *
     * @param line raw input line
     * @param info fields of {@code line}
     * @return {@code true} if the sample was stored, {@code false} if it was dropped
     *         because of a tolerated error
     * @throws Mp4jException if the tolerated number of errors is exceeded
     */
    @Override
    protected boolean updateX(String line, String[] info) throws Mp4jException {

        try {
            Map<String, Float> fnvMap = line2FeatureMap(line, info);
            for (Map.Entry<String, Float> fnvMapEntry : fnvMap.entrySet()) {
                String fn = fnvMapEntry.getKey();
                Float fv = fnvMapEntry.getValue();

                if (loadingTrainData && !coreParams.need_dict) {
                    XLong xcnt = featureXCntMap.get(fn);
                    if (xcnt == null) {
                        xcnt = new XLong();
                        xcnt.val = 1;
                        featureXCntMap.put(fn, xcnt);
                    } else {
                        xcnt.val ++;
                    }
                }

                Integer findex = featureMap.getIndex(fn);
                if (findex == null) {
                    continue;
                }
                if (loadingTrainData) {
                    findex += biasDelta;
                }

                CheckUtils.check(findex < maxFeatureDim, "[GBDT] max_feature_dim(%d) smaller than real feature number in data set, local feature index is %d, sample:%s",
                        maxFeatureDim, findex, line);
                x[cursor2d][xindex + findex] = Float.floatToRawIntBits(fv);
            }

            xindex += maxFeatureDim;
            if (lineCnt < 10) {
                LOG_UTILS.verboseInfo("weight:" + wei + ", label:" + ArrayUtils.toString(label) + ", features:" + fnvMap, false);
            }

        } catch (Exception e) {
            // The failed sample may already have written some features into its slot.
            // xindex was not advanced, so the next sample reuses the slot: wipe it so that
            // no value of the dropped sample leaks into the next one.
            resetCurrentSampleSlot();

            errorNum++;
            LOG_UTILS.error(loadingPrefix + "[ERROR] error format:" + line +
                    ", local error total num:" + errorNum +
                    ", max error tol:" + maxErrorTolNum +
                    ", has read lines:" + lineCnt);
            if (errorNum > maxErrorTolNum) {
                LOG_UTILS.error("[ERROR] train error num:" + errorNum +
                        " > " + "max tol:" + maxErrorTolNum);
                throw e;
            }
            return false;
        }

        return true;
    }

    /**
     * Finalizes the last, possibly partially filled, block and logs a summary.
     */
    protected void lastSampleHandle() throws Mp4jException {
        storeCurrentBlock(realNum[cursor2d]);

        LOG_UTILS.verboseInfo(loadingPrefix + "finished read data, cursor2d:" + cursor2d +
                ", real num:" + ArrayUtils.toString(Arrays.copyOfRange(realNum, 0, cursor2d + 1)) +
                ", weight sum:" + ArrayUtils.toString(Arrays.copyOfRange(weightNum, 0, cursor2d + 1)), false);
    }

    /**
     * Computes the initial score(s) of the current sample: {@code baseScore}, plus the
     * score of the per-sample base prediction in {@code info[3]} when
     * {@code sampleDepdtBasePrediction} is set.
     *
     * @param line raw input line, used in error messages
     * @param info fields of {@code line}
     */
    @Override
    protected void otherHandle(String line, String[] info) {
        if (sampleDepdtBasePrediction) {
            // trainDataSplit only guarantees three fields
            CheckUtils.check(info.length > 3,
                    "[GBDT] sample dependent base prediction is missing, line: %s", line);
        }

        if (numTreeInGroup == 1) {
            TMP_INIT_SCORE[count] = baseScore;
            if (sampleDepdtBasePrediction) {
                TMP_INIT_SCORE[count] += (float)obj.pred2Score(Float.parseFloat(info[3]));
            }
        } else { // multi-class
            String[] linfo = null;
            if (sampleDepdtBasePrediction) {
                linfo = info[3].split(coreParams.y_delim);
                CheckUtils.check(linfo.length == numTreeInGroup,
                        "[GBDT] sample dependent score num must equal %d, %s", numTreeInGroup, line);
            }

            int offset = count * numTreeInGroup;
            for (int i = 0; i < numTreeInGroup; i++) {
                TMP_INIT_SCORE[offset + i] = baseScore;
                if (sampleDepdtBasePrediction) {
                    TMP_INIT_SCORE[offset + i] += (float)obj.pred2Score(Float.parseFloat(linfo[i]));
                }
            }
        }
    }

    // Copies the first localnum staged samples into block cursor2d and allocates the
    // score / predict buffers of that block.
    private void storeCurrentBlock(int localnum) {
        int slotNum = localnum * numTreeInGroup;

        weight[cursor2d] = new float[localnum];
        System.arraycopy(TMP_WEIGHT, 0, weight[cursor2d], 0, localnum);

        y[cursor2d] = new float[slotNum];
        System.arraycopy(TMP_Y, 0, y[cursor2d], 0, slotNum);

        initScore[cursor2d] = new float[slotNum];
        System.arraycopy(TMP_INIT_SCORE, 0, initScore[cursor2d], 0, slotNum);

        score[cursor2d] = new float[slotNum];
        predict[cursor2d] = new float[slotNum];
    }

    // Resets the slot of the sample at xindex in the current block to the missing-value
    // marker. Never throws: it runs on the error path of updateX.
    private void resetCurrentSampleSlot() {
        if (x == null || cursor2d < 0 || cursor2d >= x.length || x[cursor2d] == null || xindex < 0) {
            return;
        }
        int[] block = x[cursor2d];
        int from = Math.min(xindex, block.length);
        int to = Math.min(xindex + maxFeatureDim, block.length);
        Arrays.fill(block, from, to, Constants.INT_MISSING_VALUE);
    }

}