package edu.brown.cs.burlap.experiencereplay;

import burlap.mdp.core.action.Action;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import edu.brown.cs.burlap.ALEState;
import edu.brown.cs.burlap.action.ActionSet;
import edu.brown.cs.burlap.preprocess.PreProcessor;
import edu.brown.cs.burlap.vfa.StateVectorizor;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.opencv_core;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import static org.bytedeco.javacpp.opencv_core.CV_32F;
import static org.bytedeco.javacpp.opencv_core.CV_8U;

/**
 * Unit tests for {@link FrameExperienceMemory}.
 *
 * <p>Each test feeds synthetic single-row {@code CV_8U} frames through
 * {@link FrameExperienceMemory#addExperience} and uses {@link #compare} to check that vectorizing a
 * recorded {@link FrameHistory} reproduces the expected sequence of frames.
 *
 * @author Melrose Roderick.
 */
public class TestFrameExperienceMemory {
    /** Output buffer that {@link #compare} vectorizes states into; each test sizes it. */
    FloatPointer input;

    /** Fixed seed so that a failure in {@link #TestRandom()} can be reproduced. */
    private static final long RANDOM_SEED = 42L;


    @Before
    public void setup() {
        // Load the native OpenCV core library before any Mat is created.
        Loader.load(opencv_core.class);
    }

    @After
    public void teardown() {
        // Intentionally empty.
    }

    /**
     * Hand-checked scenario: a memory of capacity 5 with a history length of 2 and 2-byte frames.
     * The first states are checked, then more experiences are added and the most recent states
     * are checked again.
     */
    @Test
    public void TestSmall() {
        BytePointer data0 = new BytePointer((byte)0, (byte)0);
        BytePointer data1 = new BytePointer((byte)0, (byte)1);
        BytePointer data2 = new BytePointer((byte)2, (byte)3);
        BytePointer data3 = new BytePointer((byte)4, (byte)5);
        BytePointer data4 = new BytePointer((byte)6, (byte)7);
        BytePointer data5 = new BytePointer((byte)8, (byte)9);
        BytePointer data6 = new BytePointer((byte)10, (byte)11);
        BytePointer data7 = new BytePointer((byte)12, (byte)13);

        opencv_core.Mat frame0 = new opencv_core.Mat(1, 2, CV_8U, data0);
        opencv_core.Mat frame1 = new opencv_core.Mat(1, 2, CV_8U, data1);
        opencv_core.Mat frame2 = new opencv_core.Mat(1, 2, CV_8U, data2);
        opencv_core.Mat frame3 = new opencv_core.Mat(1, 2, CV_8U, data3);
        opencv_core.Mat frame4 = new opencv_core.Mat(1, 2, CV_8U, data4);
        opencv_core.Mat frame5 = new opencv_core.Mat(1, 2, CV_8U, data5);
        opencv_core.Mat frame6 = new opencv_core.Mat(1, 2, CV_8U, data6);
        opencv_core.Mat frame7 = new opencv_core.Mat(1, 2, CV_8U, data7);

        ALEState aleState0 = new ALEState(frame0);
        ALEState aleState1 = new ALEState(frame1);
        ALEState aleState2 = new ALEState(frame2);
        ALEState aleState3 = new ALEState(frame3);
        ALEState aleState4 = new ALEState(frame4);
        ALEState aleState5 = new ALEState(frame5);
        ALEState aleState6 = new ALEState(frame6);
        ALEState aleState7 = new ALEState(frame7);


        // history (2) * frameSize (2) floats per vectorized state.
        input = new FloatPointer(2 * 2);

        ActionSet actionSet = new ActionSet(new String[]{"Action0"});
        Action action0 = actionSet.get(0);

        FrameExperienceMemory experienceMemory = new FrameExperienceMemory(5, 2, new TestPreprocessor(2), actionSet);
        FrameHistory state0 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState0, action0, aleState1, 0, false));
        FrameHistory state1 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState1, action0, aleState2, 0, false));
        FrameHistory state2 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState2, action0, aleState3, 0, false));
        FrameHistory state3 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState3, action0, aleState4, 0, false));
        FrameHistory state4 = experienceMemory.currentFrameHistory;

        // The initial state is expected to vectorize as an all-zero history.
        compare(state0, experienceMemory, new BytePointer[]{data0, data0}, 2);
        compare(state1, experienceMemory, new BytePointer[]{data0, data1}, 2);
        compare(state2, experienceMemory, new BytePointer[]{data1, data2}, 2);
        compare(state3, experienceMemory, new BytePointer[]{data2, data3}, 2);
        compare(state4, experienceMemory, new BytePointer[]{data3, data4}, 2);

        experienceMemory.addExperience(new EnvironmentOutcome(aleState4, action0, aleState5, 0, false));
        FrameHistory state5 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState5, action0, aleState6, 0, false));
        FrameHistory state6 = experienceMemory.currentFrameHistory;
        experienceMemory.addExperience(new EnvironmentOutcome(aleState6, action0, aleState7, 0, false));
        FrameHistory state7 = experienceMemory.currentFrameHistory;

        // After more insertions, only the most recent states (3..7) are re-checked.
        compare(state3, experienceMemory, new BytePointer[]{data2, data3}, 2);
        compare(state4, experienceMemory, new BytePointer[]{data3, data4}, 2);
        compare(state5, experienceMemory, new BytePointer[]{data4, data5}, 2);
        compare(state6, experienceMemory, new BytePointer[]{data5, data6}, 2);
        compare(state7, experienceMemory, new BytePointer[]{data6, data7}, 2);
    }

    /**
     * Randomized scenario: streams many random frames through the memory. After every insertion it
     * re-checks the newest state and every state remembered from the last {@code replaySize}
     * insertion slots against the frames that produced them.
     */
    @Test
    public void TestRandom() {
        int replaySize = 50;
        int history = 4;
        int frameSize = 10;
        input = new FloatPointer(frameSize * history);
        Random random = new Random(RANDOM_SEED);

        ActionSet actionSet = new ActionSet(new String[]{"Action0"});
        Action action0 = actionSet.get(0);

        FrameExperienceMemory experienceMemory = new FrameExperienceMemory(replaySize, history, new TestPreprocessor(frameSize), actionSet);
        FrameHistory initialState = experienceMemory.currentFrameHistory;

        // All-zero frame. It must hold frameSize bytes, because frameSize bytes are written here and
        // compare() reads frameSize bytes per frame.
        BytePointer data0 = new BytePointer(frameSize);
        for (int f = 0; f < frameSize; f++) {
            data0.position(f).put((byte)0);
        }
        data0.position(0);

        // Sliding window of the last `history` frames; it starts as all zeros.
        List<BytePointer> dataList = new ArrayList<>();
        for (int h = 0; h < history; h++) {
            dataList.add(data0);
        }
        compare(initialState, experienceMemory, dataList.toArray(new BytePointer[history]), frameSize);

        // Expected frame windows and the matching states, indexed by insertion slot i.
        List<List<BytePointer>> dataListList = new ArrayList<>();
        List<FrameHistory> states = new ArrayList<>();

        ALEState prevAleState = null;
        for (int n = 0; n < 100; n++) {
            for (int i = 0; i < replaySize; i++) {
                BytePointer data = new BytePointer(frameSize);
                for (int f = 0; f < frameSize; f++) {
                    // Values in [0, 125], the same range as the previous Math.random() * 126.0.
                    byte d = (byte) random.nextInt(126);
                    data.position(f).put(d);
                }
                data.position(0);
                dataList.remove(0);
                dataList.add(data);

                opencv_core.Mat frame = new opencv_core.Mat(1, frameSize, CV_8U, data);
                ALEState aleState = new ALEState(frame);

                if (prevAleState != null) {
                    experienceMemory.addExperience(new EnvironmentOutcome(prevAleState, action0, aleState, 0, false));
                    FrameHistory state = experienceMemory.currentFrameHistory;

                    compare(state, experienceMemory, dataList.toArray(new BytePointer[history]), frameSize);

                    if (i < dataListList.size()) {
                        dataListList.set(i, new ArrayList<>(dataList));
                        states.set(i, state);
                    } else {
                        dataListList.add(new ArrayList<>(dataList));
                        states.add(state);
                    }

                    for (int k = 0; k < states.size(); k++) {
                        compare(states.get(k), experienceMemory, dataListList.get(k).toArray(new BytePointer[history]), frameSize);
                    }
                }

                prevAleState = aleState;
            }
        }
    }

    /**
     * Vectorizes {@code state} into {@link #input} and asserts that the result equals the
     * concatenation of the first {@code outputSize} bytes of each element of {@code dataArray}.
     *
     * @param state      the frame history to vectorize
     * @param vectorizor the vectorizer under test
     * @param dataArray  expected frames, oldest first
     * @param outputSize number of values per frame
     */
    public void compare(FrameHistory state, StateVectorizor vectorizor, BytePointer[] dataArray, long outputSize) {
        vectorizor.vectorizeState(state, input);

        int i = 0;
        for (BytePointer data : dataArray) {
            for (int k = 0; k < outputSize; k++) {
                // JUnit's order is (expected, actual); the raw frame byte is the expected value.
                Assert.assertEquals(data.get(k), input.get(i), 1e-6);
                i++;
            }
        }
    }


    /**
     * Identity preprocessor for tests. It copies raw {@code CV_8U} screen bytes verbatim and
     * converts them to floats without scaling.
     */
    public static class TestPreprocessor implements PreProcessor {
        final int frameSize;

        public TestPreprocessor(int frameSize) {
            this.frameSize = frameSize;
        }

        /** Copies {@code frameSize} bytes of the screen into {@code data} at its current position. */
        @Override
        public void convertScreenToData(opencv_core.Mat screen, BytePointer data) {
            BytePointer screenData = screen.data();
            // Nothing to copy when the destination already starts at the screen's buffer.
            if (screenData.address() == data.address()) {
                return;
            }

            data.limit(data.position() + frameSize).put(screenData.limit(frameSize));
        }

        /**
         * Converts {@code size} frames of bytes from {@code data} into floats, written into
         * {@code input} starting at its current position.
         */
        @Override
        public void convertDataToInput(BytePointer data, FloatPointer input, long size) {
            int dataSize = outputSize() * (int)size;

            opencv_core.Mat mat = new opencv_core.Mat(1, dataSize, CV_8U, data);
            // The float Mat wraps input's memory at its current element offset. convertTo writes
            // into that existing buffer because the destination already has the right size and type.
            opencv_core.Mat floatMat = new opencv_core.Mat(1, dataSize, CV_32F, (new BytePointer(input)).position(input.position() * input.sizeof()));

            mat.convertTo(floatMat, CV_32F, 1, 0);
        }

        @Override
        public int outputSize() {
            return frameSize;
        }
    }
}