package edu.brown.cs.burlap.examples;

import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.SolverDerivedPolicy;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import edu.brown.cs.burlap.ALEDomainGenerator;
import edu.brown.cs.burlap.ALEEnvironment;
import edu.brown.cs.burlap.action.ActionSet;
import edu.brown.cs.burlap.experiencereplay.FrameExperienceMemory;
import edu.brown.cs.burlap.gui.ALEVisualExplorer;
import edu.brown.cs.burlap.gui.ALEVisualizer;
import edu.brown.cs.burlap.io.PoolingMethod;
import edu.brown.cs.burlap.learners.DeepQLearner;
import edu.brown.cs.burlap.policies.AnnealedEpsilonGreedy;
import edu.brown.cs.burlap.preprocess.ALEPreProcessor;
import edu.brown.cs.burlap.testing.DeepQTester;
import edu.brown.cs.burlap.vfa.DQN;
import org.bytedeco.javacpp.Loader;

import static org.bytedeco.javacpp.caffe.Caffe;

/**
 * A burlap_caffe example that trains a DQN agent on the Atari domain,
 * using the learning constants from the DeepMind Nature paper.
 *
 * @author Melrose Roderick
 */
public class AtariDQN extends TrainingHelper {

    // TODO: set to true if you have downloaded our version of ALE and want to replicate the DeepMind
    // results (training episodes then terminate when a life is lost; test episodes run the full game)
    static final boolean TERMINATE_ON_END_LIFE = false;

    // Experience memories that also serve as the DQN's state converters during training and testing
    protected FrameExperienceMemory trainingMemory;
    protected FrameExperienceMemory testMemory;

    public AtariDQN(DeepQLearner learner, DeepQTester tester, DQN vfa, ActionSet actionSet, Environment env,
                    FrameExperienceMemory trainingMemory,
                    FrameExperienceMemory testMemory) {
        super(learner, tester, vfa, actionSet, env);

        this.trainingMemory = trainingMemory;
        this.testMemory = testMemory;
    }

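    /**
     * Before a training phase: optionally terminate episodes when a life is lost and point
     * the network's state converter at the training memory.
     */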
    @Override
    public void prepareForTraining() {
        if (TERMINATE_ON_END_LIFE) {
            ((ALEEnvironment) env).setTerminateOnEndLife(true);
        }

        vfa.stateConverter = trainingMemory;
    }

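    /**
     * Before a testing phase: restore full-game episodes and point the network's state
     * converter at the test memory.
     */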
    @Override
    public void prepareForTesting() {
        if (TERMINATE_ON_END_LIFE) {
            ((ALEEnvironment) env).setTerminateOnEndLife(false);
        }

        vfa.stateConverter = testMemory;
    }

    public static void main(String[] args) {

        // Learning constants defined in the DeepMind Nature paper
        // (http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
        int experienceMemoryLength = 1000000;
        int maxHistoryLength = 4;
        int staleUpdateFreq = 10000;
        double gamma = 0.99;
        int frameSkip = 4;
        int updateFreq = 4;
        double rewardClip = 1.0;
        float gradientClip = 1.0f;
        double epsilonStart = 1;
        double epsilonEnd = 0.1;
        int epsilonAnnealDuration = 1000000;
        int replayStartSize = 50000;
        int noopMax = 30;
        int totalTrainingSteps = 50000000;
        double testEpsilon = 0.05;

        // Testing and recording constants
        int testInterval = 250000;
        int totalTestSteps = 125000;
        int maxEpisodeSteps = 100000;
        int snapshotInterval = 1000000;
        String snapshotPrefix = "snapshots/experiment1";
        String resultsDirectory = "results/experiment1";

        // ALE Paths
        // TODO: Set to appropriate paths for your machine
        String alePath = "/path/to/atari/executable";
        String romPath = "/path/to/atari/rom/file";

        // Caffe solver file
        String solverFile = "example_models/atari_dqn_solver.prototxt";

        // Load the native Caffe library via JavaCPP
        Loader.load(Caffe.class);

        // Create the domain
        ALEDomainGenerator domGen = new ALEDomainGenerator(ALEDomainGenerator.saActionSet());
        SADomain domain = domGen.generateDomain();

        // Create the ALEEnvironment (max pooling over skipped frames) and a live visualizer (~60 fps)
        ALEEnvironment env = new ALEEnvironment(alePath, romPath, frameSkip, PoolingMethod.POOLING_METHOD_MAX);
        env.setRandomNoopMax(noopMax);
        ALEVisualExplorer exp = new ALEVisualExplorer(domain, env, ALEVisualizer.create());
        exp.initGUI();
        exp.startLiveStatePolling(1000/60);

        // Set up the ActionSet from the ALE domain so the learner uses the ALE actions
        ActionSet actionSet = new ActionSet(domain);

        // Set up the training and test experience-replay memories
        FrameExperienceMemory trainingExperienceMemory =
                new FrameExperienceMemory(experienceMemoryLength, maxHistoryLength, new ALEPreProcessor(), actionSet);
        // The test memory size is arbitrary, but should be much larger than 1 to minimize copying
        FrameExperienceMemory testExperienceMemory =
                new FrameExperienceMemory(10000, maxHistoryLength, new ALEPreProcessor(), actionSet);

        // Initialize the DQN with the solver file.
        // NOTE: this Caffe architecture is configured for 3 actions (the number of actions in Pong)
        DQN dqn = new DQN(solverFile, actionSet, trainingExperienceMemory, gamma);
        dqn.setRewardClip(rewardClip);
        dqn.setGradientClip(gradientClip);

        // Create the policies: annealed epsilon-greedy for learning and fixed epsilon-greedy for testing
        SolverDerivedPolicy learningPolicy =
                new AnnealedEpsilonGreedy(dqn, epsilonStart, epsilonEnd, epsilonAnnealDuration);
        SolverDerivedPolicy testPolicy = new EpsilonGreedy(dqn, testEpsilon);

        // Set up the learner
        DeepQLearner deepQLearner = new DeepQLearner(domain, gamma, replayStartSize, learningPolicy, dqn, trainingExperienceMemory);
        deepQLearner.setExperienceReplay(trainingExperienceMemory, dqn.batchSize);
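        // Keep a separate stale (target) network, refreshed every staleUpdateFreq steps, as in the Nature DQN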
        deepQLearner.useStaleTarget(staleUpdateFreq);
        deepQLearner.setUpdateFreq(updateFreq);

        // Set up the tester
        DeepQTester tester = new DeepQTester(testPolicy, testExperienceMemory, testExperienceMemory);

        // Set up the training helper
        TrainingHelper helper =
                new AtariDQN(deepQLearner, tester, dqn, actionSet, env, trainingExperienceMemory, testExperienceMemory);
        helper.setTotalTrainingSteps(totalTrainingSteps);
        helper.setTestInterval(testInterval);
        helper.setTotalTestSteps(totalTestSteps);
        helper.setMaxEpisodeSteps(maxEpisodeSteps);
        helper.enableSnapshots(snapshotPrefix, snapshotInterval);
        helper.recordResultsTo(resultsDirectory);
        //helper.verbose = true;

        // Uncomment this line to load the learning state when resuming from a snapshot
        //helper.loadLearningState(snapshotPrefix);

        // Run helper
        helper.run();
    }
}