package regression;

import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.stat.regression.GLSMultipleLinearRegression;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.neo4j.logging.Log;

// NOTE: This model is not currently working; development is paused.

/**
 * Linear regression model backed by Commons Math's {@link GLSMultipleLinearRegression}.
 *
 * <p>Training observations are buffered in memory by {@link #addTrain} and the regression is
 * fitted lazily: {@link #predict} and {@link #data} call {@link #train} when the inherited
 * {@code state} is {@code State.training}. All test-data operations, {@link #clearAll},
 * {@link #removeTrain} and {@link #copy} are currently no-op stubs.
 *
 * <p>NOTE(review): {@link #train} passes the covariance of the predictor columns as the GLS
 * covariance argument, which looks dimensionally incompatible with what the library expects;
 * see the comment inside {@link #train}.
 *
 * Created by Lauren on 6/5/18.
 */
public class GlsLRModel extends LRModel {
    /** Underlying Commons Math regression; also holds the intercept setting. */
    private final GLSMultipleLinearRegression regression;
    /** Buffered predictor rows, one list of {@code numVars} values per observation. */
    private final List<List<Double>> data = new ArrayList<>();
    /** Buffered response values, index-aligned with {@link #data}. */
    private final List<Double> response = new ArrayList<>();
    /** Number of predictor variables expected in every observation. */
    private final int numVars;
    /** Number of training observations added so far. */
    private int numObs;
    /** Fitted regression parameters; {@code null} until {@link #train} succeeds. */
    private double[] params;

    /**
     * Creates an empty model.
     *
     * @param model     model name, forwarded to the superclass constructor
     * @param numVars   number of predictor variables per observation
     * @param intercept {@code true} to fit an intercept term, {@code false} to omit it
     */
    GlsLRModel(String model, int numVars, boolean intercept) {
        super(model, Framework.GLS);
        this.regression = new GLSMultipleLinearRegression();
        this.regression.setNoIntercept(!intercept);
        this.numObs = 0;
        this.numVars = numVars;
    }

    /** Test data is not tracked by this model; always returns 0. */
    @Override
    protected long getNTest() {
        return 0;
    }

    /** @return the number of training observations added so far */
    @Override
    long getNTrain() {
        return numObs;
    }

    /** No-op stub: this model keeps no test data. */
    @Override
    GlsLRModel clearTest() {
        return this;
    }

    /** No-op stub: buffered training data and fitted parameters are NOT cleared. */
    @Override
    GlsLRModel clearAll() {
        return this;
    }

    /** @return the number of predictor variables per observation */
    @Override
    int getNumVars() {
        return numVars;
    }

    /** @return {@code true} if the underlying regression fits an intercept term */
    @Override
    boolean hasConstant() {
        return !regression.isNoIntercept();
    }

    /**
     * Buffers one training observation and marks the model as needing (re)training.
     *
     * @param given    predictor values; must contain exactly {@code numVars} entries
     * @param expected observed response for {@code given}
     * @param log      unused by this implementation
     * @throws IllegalArgumentException if {@code given} has the wrong number of entries
     */
    @Override
    public void addTrain(List<Double> given, double expected, Log log) {
        if (given.size() != numVars) {
            throw new IllegalArgumentException("incorrect number of variables in given.");
        }
        // Store a copy so that later changes to the caller's list cannot alter the training set.
        data.add(new ArrayList<Double>(given));
        response.add(expected);
        numObs += 1;
        this.state = State.training;
    }

    /**
     * Evaluates the fitted linear function on one observation, training first if the model
     * is in {@code State.training}.
     *
     * @param given predictor values; must contain exactly {@code numVars} entries
     * @return the intercept (when fitted) plus the sum of coefficient * predictor
     * @throws IllegalArgumentException if {@code given} has the wrong number of entries
     * @throws IllegalStateException    if no fitted parameters are available
     */
    @Override
    public double predict(List<Double> given) {
        if (given.size() != numVars) {
            throw new IllegalArgumentException("incorrect number of variables in given.");
        }
        if (this.state == State.training) {
            train();
        }
        if (params == null) {
            // Previously this surfaced as a bare NullPointerException on the first params access.
            throw new IllegalStateException(this.name + " has not been trained; no parameters are available.");
        }
        // With an intercept, params[0] is the constant and the coefficients start at index 1.
        final int offset = regression.isNoIntercept() ? 0 : 1;
        double result = 0;
        if (offset == 1) {
            result += params[0];
        }
        for (int i = 0; i < numVars; i++) {
            result += params[i + offset] * given.get(i);
        }
        return result;
    }

    /**
     * Returns the fitted parameter array, training first if the model is in
     * {@code State.training}.
     *
     * @return the {@code double[]} of regression parameters (the internal array, not a copy)
     * @throws RuntimeException if the model is not in {@code State.ready} afterwards
     */
    @Override
    public Object data() {
        if (this.state == State.training) {
            train();
        }
        if (this.state == State.ready) {
            return this.params;
        }
        throw new RuntimeException(this.name + " is not in a state for serialization.");
    }

    /*@Override
    public LR.StatResult stats() {
        return new LR.StatResult(this.getN(), this.numVars);
    }*/

    /**
     * Fits the regression on all buffered training observations, stores the estimated
     * parameters and moves the model to {@code State.ready}.
     *
     * @return this model
     */
    @Override
    public GlsLRModel train() {
        final double[][] dataArray = new double[this.numObs][this.numVars];
        for (int i = 0; i < numObs; i++) {
            final List<Double> row = data.get(i);
            for (int j = 0; j < numVars; j++) {
                dataArray[i][j] = row.get(j);
            }
        }
        final RealMatrix designMatrix = new BlockRealMatrix(dataArray);
        final double[] obs = LR.doubleListToArray(response);

        // NOTE(review): Covariance(designMatrix) yields the numVars x numVars covariance of the
        // predictor columns, whereas GLSMultipleLinearRegression.newSampleData appears to require
        // the numObs x numObs covariance of the error terms (it validates that the covariance has
        // as many rows as x). This call is therefore expected to fail with a dimension mismatch
        // and is the likely reason the model is marked as not working. Choosing the correct error
        // covariance is a modelling decision, so the original computation is kept unchanged here.
        final double[][] covariance = new Covariance(designMatrix).getCovarianceMatrix().getData();

        regression.newSampleData(obs, dataArray, covariance);
        params = regression.estimateRegressionParameters();
        this.state = State.ready;
        return this;
    }

    /** No-op stub: test observations are ignored. */
    @Override
    void addTest(List<Double> given, double expected, Log log) {

    }

    /** No-op stub: no test statistics are computed. */
    @Override
    GlsLRModel test() {
        return this;
    }

    /** Stub: returns this same instance rather than a copy; the argument is ignored. */
    @Override
    GlsLRModel copy(String string) {
        return this;
    }

    /** No-op stub: test observations are not tracked, so nothing is removed. */
    @Override
    void removeTest(List<Double> input, double output, Log log) {

    }

    /** No-op stub: the given training observation is NOT removed. */
    @Override
    void removeTrain(List<Double> input, double output, Log log) {

    }
}