package edu.brown.cs.burlap.examples;

import burlap.behavior.singleagent.Episode;
import burlap.mdp.singleagent.environment.Environment;
import edu.brown.cs.burlap.action.ActionSet;
import edu.brown.cs.burlap.learners.DeepQLearner;
import edu.brown.cs.burlap.testing.Tester;
import edu.brown.cs.burlap.vfa.DQN;

import java.io.*;

/**
 * A class to coordinate all the steps for training and testing a DQN on a given Domain.
 *
 * @author Melrose Roderick.
 */
public class TrainingHelper {

    /** Agent used to run the learning episodes. */
    protected DeepQLearner learner;
    /** Runs the evaluation episodes of each test set. */
    protected Tester tester;
    /** Value function approximator; snapshotted whenever a test set produces a new best average reward. */
    protected DQN vfa;

    /** Environment that both training and test episodes are run in. */
    protected Environment env;

    protected ActionSet actionSet;

    /** Per-episode step cap. A negative value (the default) means no per-episode cap. */
    protected int maxEpisodeSteps = -1;
    /** Training stops once {@link #stepCounter} reaches this value. */
    protected int totalTrainingSteps = 10000000;

    /** Number of training steps between two test sets. */
    protected int testInterval = 100000;
    /** Step budget of a single test set. */
    protected int totalTestSteps = 125000;

    /** Prefix handed to the learner when saving its state; snapshots are disabled while this is null. */
    protected String snapshotPrefix;
    /** Number of training steps between two learning-state snapshots. */
    protected int snapshotInterval = -1;

    /** Total number of training steps taken so far. */
    protected int stepCounter;
    /** Total number of training episodes run so far. */
    protected int episodeCounter;

    /** Best average test reward seen so far. */
    protected double highestAverageReward = Double.NEGATIVE_INFINITY;
    /** Stream the test results are written to; null until {@link #recordResultsTo(String)} is called. */
    protected PrintStream testOutput;
    /** Directory holding the test results and the best network; null until {@link #recordResultsTo(String)} is called. */
    protected String resultsPrefix;

    /** If true, prints out episode information at the end of every episode */
    public boolean verbose = false;


    /**
     * Creates a helper with zeroed step and episode counters.
     *
     * @param learner   the agent that runs the learning episodes
     * @param tester    the object that runs the test episodes
     * @param vfa       the network that is snapshotted when a test set sets a new best
     * @param actionSet the action set (stored, not otherwise used by this class)
     * @param env       the environment to train and test in
     */
    public TrainingHelper(DeepQLearner learner, Tester tester, DQN vfa, ActionSet actionSet, Environment env) {
        this.learner = learner;
        this.vfa = vfa;
        this.tester = tester;
        this.env = env;
        this.actionSet = actionSet;

        this.stepCounter = 0;
        this.episodeCounter = 0;
    }

    /** Hook invoked before every training episode; does nothing by default. */
    public void prepareForTraining() {}

    /** Hook invoked once at the start of every test set; does nothing by default. */
    public void prepareForTesting() {}

    /** @param n the total number of training steps after which {@link #run()} stops */
    public void setTotalTrainingSteps(int n) {
        totalTrainingSteps = n;
    }

    /** @param n the step budget of each test set */
    public void setTotalTestSteps(int n) {
        totalTestSteps = n;
    }

    /** @param i the number of training steps between two test sets */
    public void setTestInterval(int i) {
        testInterval = i;
    }

    /** @param f the per-episode step cap; a negative value removes the per-episode cap */
    public void setMaxEpisodeSteps(int f) {
        maxEpisodeSteps = f;
    }

    /**
     * Turns on periodic saving of the learning state during {@link #run()}.
     *
     * @param snapshotPrefix   path prefix passed to {@link #saveLearningState(String)}; its parent
     *                         directory is created if it does not exist
     * @param snapshotInterval number of training steps between two snapshots
     * @throws RuntimeException if the parent directory is missing and cannot be created
     */
    public void enableSnapshots(String snapshotPrefix, int snapshotInterval) {
        // getParentFile() is null for a prefix without a directory component (e.g. "snap"),
        // in which case there is nothing to create.
        File parent = new File(snapshotPrefix).getParentFile();
        if (parent != null && !parent.exists() && !parent.mkdirs()) {
            throw new RuntimeException(String.format("Could not create the directory: %s", snapshotPrefix));
        }

        this.snapshotPrefix = snapshotPrefix;
        this.snapshotInterval = snapshotInterval;
    }

    /**
     * Turns on recording of test results. Results are written to a file named {@code testResults}
     * inside the given directory, and the best network is saved there as well.
     *
     * @param resultsPrefix the results directory; created if it does not exist
     * @throws RuntimeException if the directory cannot be created or the results file cannot be opened
     */
    public void recordResultsTo(String resultsPrefix) {
        File dir = new File(resultsPrefix);
        if (!dir.exists() && !dir.mkdirs()) {
            throw new RuntimeException(String.format("Could not create the directory: %s", resultsPrefix));
        }

        this.resultsPrefix = resultsPrefix;

        // Release a stream opened by an earlier call before replacing it.
        if (testOutput != null) {
            testOutput.close();
            testOutput = null;
        }

        File resultsFile = new File(resultsPrefix, "testResults");
        try {
            testOutput = new PrintStream(new BufferedOutputStream(new FileOutputStream(resultsFile)));
        } catch (FileNotFoundException e) {
            // Keep the cause so the underlying I/O problem is not lost.
            throw new RuntimeException(String.format("Can't open %s", resultsFile), e);
        }
    }

    /**
     * Runs training episodes until {@link #totalTrainingSteps} steps have been taken, saving a
     * learning-state snapshot every {@link #snapshotInterval} steps (when snapshots are enabled)
     * and running a test set every {@link #testInterval} steps.
     */
    public void run() {

        int testCountDown = testInterval;
        int snapshotCountDown = snapshotInterval;

        long trainingStart = System.currentTimeMillis();
        // Steps trained since the last test set; used to report the training rate.
        int trainingSteps = 0;
        while (stepCounter < totalTrainingSteps) {
            long epStartTime = 0;
            if (verbose) {
                System.out.printf("Training Episode %d at step %d%n", episodeCounter, stepCounter);
                epStartTime = System.currentTimeMillis();
            }

            // Set variables needed for training
            prepareForTraining();
            env.resetEnvironment();

            // run learning episode, never asking for more than the remaining training budget
            Episode ea = learner.runLearningEpisode(env, episodeStepLimit(totalTrainingSteps - stepCounter));
            int episodeSteps = ea.numTimeSteps();

            if (verbose) {
                // add up episode reward
                double totalReward = 0;
                for (double r : ea.rewardSequence) {
                    totalReward += r;
                }

                // output episode data
                double timeInterval = (System.currentTimeMillis() - epStartTime)/1000.0;
                System.out.printf("Episode reward: %.2f -- %.1f steps/sec%n", totalReward, episodeSteps/timeInterval);
                System.out.println();
            }

            stepCounter += episodeSteps;
            trainingSteps += episodeSteps;
            episodeCounter++;

            // take snapshot every snapshotInterval steps
            if (snapshotPrefix != null) {
                snapshotCountDown -= episodeSteps;
                if (snapshotCountDown <= 0) {
                    saveLearningState(snapshotPrefix);
                    snapshotCountDown += snapshotInterval;
                }
            }

            // run a test set every testInterval steps
            testCountDown -= episodeSteps;
            if (testCountDown <= 0) {
                // measured before the test set so test time is not counted as training time
                double trainingTimeInterval = (System.currentTimeMillis() - trainingStart)/1000.0;

                // run test set
                runTestSet();
                testCountDown += testInterval;

                // output training rate over the steps actually trained since the last test set
                System.out.printf("Training rate: %.1f steps/sec\n\n",
                        trainingSteps/trainingTimeInterval);

                // restart training timer and step count
                trainingSteps = 0;
                trainingStart = System.currentTimeMillis();
            }
        }

        if (testOutput != null) {
            testOutput.printf("Final best: %.2f\n", highestAverageReward);
            testOutput.flush();
        }

        System.out.println("Done Training!");
    }

    /**
     * Runs test episodes until {@link #totalTestSteps} steps have been taken and reports the
     * average episode reward. The episode that reaches the step budget is left out of the average
     * unless it is the only episode. When the average beats the best so far and a results
     * directory is set, the network is saved there as {@code best_net.caffemodel}.
     */
    public void runTestSet() {

        long testStart = System.currentTimeMillis();
        int numSteps = 0;
        int numEpisodes = 0;

        // Change any learning variables to test values (e.g. experience memory)
        prepareForTesting();

        // Run the test policy on test episodes
        System.out.println("Running Test Set...");
        double totalTestReward = 0;
        while (true) {
            env.resetEnvironment();
            Episode e = tester.runTestEpisode(env, episodeStepLimit(totalTestSteps - numSteps));

            double totalReward = 0;
            for (double reward : e.rewardSequence) {
                totalReward += reward;
            }

            if (verbose) {
                // numSteps here is the step count before this episode is added
                System.out.printf("%d: Reward = %.2f, Steps = %d%n", numEpisodes, totalReward, numSteps);
            }

            numSteps += e.numTimeSteps();
            if (numSteps >= totalTestSteps) {
                // The episode that exhausts the budget only counts when there is no other episode.
                if (numEpisodes == 0) {
                    totalTestReward = totalReward;
                    numEpisodes = 1;
                }
                break;
            }

            totalTestReward += totalReward;
            numEpisodes += 1;
        }

        double averageReward = totalTestReward/numEpisodes;
        if (averageReward > highestAverageReward) {
            if (resultsPrefix != null) {
                vfa.snapshot(new File(resultsPrefix, "best_net.caffemodel").toString(),  null);
            }
            highestAverageReward = averageReward;
        }

        double testTimeInterval = (System.currentTimeMillis() - testStart)/1000.0;
        System.out.printf("Average Test Reward: %.2f -- highest: %.2f, Test rate: %.1f\n\n", averageReward, highestAverageReward, numSteps/testTimeInterval);

        if (testOutput != null) {
            testOutput.printf("Frame %d: %.2f\n", stepCounter, averageReward);
            testOutput.flush();
        }
    }

    /**
     * Asks the learner to save its learning state, logging progress to standard output.
     *
     * @param filePrefix the path prefix forwarded to the learner
     */
    public void saveLearningState(String filePrefix) {
        System.out.print("Saving learning snapshot... ");
        learner.saveLearningState(filePrefix);
        System.out.println("Done");
    }

    /**
     * Asks the learner to load a previously saved learning state.
     *
     * @param filePrefix the path prefix forwarded to the learner
     */
    public void loadLearningState(String filePrefix) {
        learner.loadLearningState(filePrefix);
    }

    /**
     * Computes the step limit for the next episode: the smaller of the per-episode cap and the
     * remaining step budget, or just the remaining budget when no per-episode cap is set.
     */
    private int episodeStepLimit(int remainingSteps) {
        if (maxEpisodeSteps < 0) {
            return remainingSteps;
        }
        return Math.min(maxEpisodeSteps, remainingSteps);
    }
}