package org.lightgbm.predict4j.v2;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.List;

import org.apache.commons.io.IOUtils;
import org.lightgbm.predict4j.SparseVector;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * @author lyg5623
 */
/**
 * Base type for boosting models used for prediction.
 *
 * <p>Provides static factories that instantiate a concrete implementation ({@code GBDT},
 * {@code DART} or {@code GOSS}) and optionally populate it from a model file, plus the abstract
 * operations those implementations provide.
 *
 * @author lyg5623
 */
public abstract class Boosting implements Serializable {
    private static final Logger logger = LoggerFactory.getLogger(Boosting.class);
    private static final long serialVersionUID = -3370589073161617590L;

    /** Submodel type that must appear on the first line of a supported model file. */
    private static final String TREE_SUBMODEL_TYPE = "tree";

    /**
     * Creates a {@code GBDT} boosting and loads it from the given model file.
     *
     * @param filename path of the model file
     * @return the model loaded from {@code filename}, or {@code null} if the file's first line is
     *         not {@code "tree"} (an error is logged in that case)
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException if reading the file fails
     */
    static Boosting createBoosting(String filename) throws FileNotFoundException, IOException {
        String type = getBoostingTypeFromModelFile(filename);
        if (!TREE_SUBMODEL_TYPE.equals(type)) {
            logger.error("unknown submodel type in model file {}", filename);
            return null;
        }
        Boosting boosting = new GBDT();
        loadOrLogFailure(boosting, filename);
        return boosting;
    }

    /**
     * Creates a boosting of the requested type, optionally loading it from a model file.
     *
     * @param type boosting type: {@code "gbdt"}, {@code "dart"} or {@code "goss"}
     * @param filename model file to load; if {@code null} or empty, an empty model is returned
     * @return the new (and possibly loaded) model, or {@code null} if {@code type} is unknown or
     *         the file's submodel type is not {@code "tree"} (errors are logged only when a file
     *         was given)
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException if reading the file fails
     */
    static Boosting createBoosting(String type, String filename) throws FileNotFoundException, IOException {
        if (filename == null || filename.length() == 0) {
            return newBoostingOfType(type);
        }
        if (!TREE_SUBMODEL_TYPE.equals(getBoostingTypeFromModelFile(filename))) {
            logger.error("unknown submodel type in model file {}", filename);
            return null;
        }
        Boosting boosting = newBoostingOfType(type);
        if (boosting == null) {
            logger.error("unknown boosting type {}", type);
            return null;
        }
        loadOrLogFailure(boosting, filename);
        return boosting;
    }

    /**
     * Reads the whole model file and passes its text to {@link #loadModelFromString(String)}.
     *
     * <p>Line terminators are normalised to {@code '\n'}, and every line, including the last, is
     * followed by one. The file is decoded as UTF-8.
     *
     * @param boosting model to populate; if {@code null}, nothing is read and {@code true} is returned
     * @param filename path of the model file
     * @return {@code false} if {@code loadModelFromString} rejected the content, {@code true} otherwise
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException if reading the file fails
     */
    static boolean loadFileToBoosting(Boosting boosting, String filename) throws FileNotFoundException, IOException {
        if (boosting == null) {
            return true;
        }
        StringBuilder sb = new StringBuilder();
        // try-with-resources: the original never closed the underlying FileInputStream.
        try (BufferedReader reader = openModelFile(filename)) {
            for (String line : IOUtils.readLines(reader)) {
                sb.append(line).append('\n');
            }
        }
        return boosting.loadModelFromString(sb.toString());
    }

    /**
     * Returns the first line of the model file, which names its submodel type.
     *
     * <p>Only the first line is read, not the whole file.
     *
     * @param filename path of the model file
     * @return the first line, or an empty string if the file is empty
     * @throws FileNotFoundException if the file does not exist or cannot be opened
     * @throws IOException if reading the file fails
     */
    static String getBoostingTypeFromModelFile(String filename) throws FileNotFoundException, IOException {
        try (BufferedReader reader = openModelFile(filename)) {
            String firstLine = reader.readLine();
            // Empty file: report an unknown type rather than throwing IndexOutOfBoundsException.
            return firstLine == null ? "" : firstLine;
        }
    }

    /** Instantiates an empty model for {@code type}, or returns {@code null} for an unknown or null type. */
    private static Boosting newBoostingOfType(String type) {
        if ("gbdt".equals(type)) {
            return new GBDT();
        }
        if ("dart".equals(type)) {
            return new DART();
        }
        if ("goss".equals(type)) {
            return new GOSS();
        }
        return null;
    }

    /** Loads {@code filename} into {@code boosting}, logging an error if the model text is rejected. */
    private static void loadOrLogFailure(Boosting boosting, String filename) throws IOException {
        if (!loadFileToBoosting(boosting, filename)) {
            logger.error("failed to load model from file {}", filename);
        }
    }

    /**
     * Opens {@code filename} as UTF-8 text.
     *
     * <p>{@link FileInputStream} is used so a missing file still surfaces as
     * {@link FileNotFoundException}, as callers expect.
     */
    private static BufferedReader openModelFile(String filename) throws FileNotFoundException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(filename), StandardCharsets.UTF_8));
    }

    /**
     * Populates this model from the full text of a model file (as built by
     * {@link #loadFileToBoosting(Boosting, String)}).
     *
     * @return {@code false} if the text could not be loaded
     */
    abstract boolean loadModelFromString(String modelStr);

    // Remaining operations are implemented by the concrete boosting types; see those
    // implementations for their exact contracts.

    abstract boolean needAccuratePrediction();

    abstract int numberOfClasses();

    abstract void initPredict(int num_iteration);

    abstract int numPredictOneRow(int num_iteration, boolean is_pred_leaf);

    abstract int getCurrentIteration();

    abstract int maxFeatureIdx();

    abstract List<Double> predictLeafIndex(SparseVector vector);

    abstract List<Double> predictRaw(SparseVector vector, PredictionEarlyStopInstance early_stop);

    abstract List<Double> predict(SparseVector vector, PredictionEarlyStopInstance early_stop);
}