package edu.brown.cs.burlap.experiencereplay; import burlap.debugtools.RandomFactory; import burlap.mdp.auxiliary.StateMapping; import burlap.mdp.core.state.State; import burlap.mdp.singleagent.environment.EnvironmentOutcome; import edu.brown.cs.burlap.ALEState; import edu.brown.cs.burlap.action.ActionSet; import edu.brown.cs.burlap.preprocess.PreProcessor; import edu.brown.cs.burlap.vfa.StateVectorizor; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.FloatPointer; import java.io.*; import java.util.ArrayList; import java.util.List; import java.util.Random; import static org.bytedeco.javacpp.opencv_core.Mat; /** * A fixed size memory made specifically for storing vectorized states that overlap with previous states. * For example, if the state is 4 frames from an Atari game. * * @author Melrose Roderick. */ public class FrameExperienceMemory implements SavableExperienceMemory, StateVectorizor, StateMapping, Serializable { public transient BytePointer frameMemory; public transient PreProcessor preProcessor; public transient ActionSet actionSet; public FrameHistory currentFrameHistory; public int next = 0; public FrameExperience[] experiences; public int size = 0; public boolean alwaysIncludeMostRecent; int maxHistoryLength; // the history size public FrameExperienceMemory(int size, int maxHistoryLength, PreProcessor preProcessor, ActionSet actionSet) { this(size, maxHistoryLength, preProcessor, actionSet, false); } public FrameExperienceMemory(int size, int maxHistoryLength, PreProcessor preProcessor, ActionSet actionSet, boolean alwaysIncludeMostRecent) { if(size < 1){ throw new RuntimeException("FixedSizeMemory requires memory size > 0; was request size of " + size); } this.alwaysIncludeMostRecent = alwaysIncludeMostRecent; this.experiences = new FrameExperience[size]; this.currentFrameHistory = new FrameHistory(0, 0); this.maxHistoryLength = maxHistoryLength; this.preProcessor = preProcessor; this.actionSet = actionSet; long outputSize = preProcessor.outputSize(); // Create the frame history data size to be totalHistorySize + a padding on both sides of n - 1 long paddingSize = (this.maxHistoryLength - 1) * outputSize; frameMemory = (new BytePointer(size * outputSize + 2 * paddingSize)).zero(); } @Override public void vectorizeState(State state, FloatPointer input) { FrameHistory frameHistory = (FrameHistory) state; long frameSize = preProcessor.outputSize(); long index = frameHistory.index; int historyLength = frameHistory.historyLength; long pos = input.position(); input.limit(pos + maxHistoryLength * frameSize); // Fill unused frames with 0s if (historyLength < maxHistoryLength) { if (historyLength > 0) { input.limit(pos + (maxHistoryLength - historyLength)*frameSize).zero(); input.limit(pos + maxHistoryLength * frameSize); } else { input.zero(); return; } } // Convert compressed frameHistory data to CNN input preProcessor.convertDataToInput( this.frameMemory.position(index - (historyLength - 1)*frameSize), input.position(pos + (maxHistoryLength - historyLength)*frameSize), historyLength); input.position(pos); } @Override /** Assumes the input state is the most recently added state to the history **/ public State mapState(State s) { return currentFrameHistory; } @Override public void addExperience(EnvironmentOutcome eo) { // If this is the first frame of the episode, add the o frame. if (currentFrameHistory.historyLength == 0) { currentFrameHistory = addFrame(((ALEState)eo.o).getScreen()); } // If this is experience ends in a terminal state, // the terminal frame will never be used so don't add it. FrameHistory op; if (eo.terminated) { op = new FrameHistory(currentFrameHistory.index, 0); } else { op = addFrame(((ALEState)eo.op).getScreen()); } experiences[next] = new FrameExperience(currentFrameHistory, actionSet.map(eo.a), op, eo.r, eo.terminated); next = (next+1) % experiences.length; size = Math.min(size+1, experiences.length); currentFrameHistory = op; } protected FrameHistory addFrame(Mat screenMat) { long outputSize = preProcessor.outputSize(); long frameHistoryDataSize = frameMemory.capacity(); long paddingSize = (maxHistoryLength - 1) * outputSize; // Find new index long newIndex = currentFrameHistory.index + outputSize; if (newIndex >= frameHistoryDataSize) { // Copy the buffer to the start of the history BytePointer frameHistoryCopy = new BytePointer(frameMemory); frameHistoryCopy.limit(frameHistoryCopy.capacity()); frameMemory.position(0).limit(paddingSize).put(frameHistoryCopy.position(frameHistoryDataSize - paddingSize)); frameMemory.limit(frameMemory.capacity()); newIndex = paddingSize; } // Increment length if smaller than n int newHistoryLength = currentFrameHistory.historyLength >= maxHistoryLength ? maxHistoryLength : currentFrameHistory.historyLength + 1; // Process the new screen and place it in the memory preProcessor.convertScreenToData(screenMat, frameMemory.position(newIndex)); // Create new frame return new FrameHistory(newIndex, newHistoryLength); } @Override public List<EnvironmentOutcome> sampleExperiences(int n) { List<FrameExperience> samples = sampleFrameExperiences(n); List<EnvironmentOutcome> sampleOutcomes = new ArrayList<>(samples.size()); for (FrameExperience exp : samples) { sampleOutcomes.add(new EnvironmentOutcome(exp.o, actionSet.get(exp.a), exp.op, exp.r, exp.terminated)); } return sampleOutcomes; } public List<FrameExperience> sampleFrameExperiences(int n) { List<FrameExperience> samples; if(this.size == 0){ return new ArrayList<>(); } if(this.alwaysIncludeMostRecent){ n--; } if(this.size < n){ samples = new ArrayList<>(this.size); for(int i = 0; i < this.size; i++){ samples.add(this.experiences[i]); } return samples; } else{ samples = new ArrayList<>(Math.max(n, 1)); Random r = RandomFactory.getMapped(0); for(int i = 0; i < n; i++) { int sind = r.nextInt(this.size); samples.add(this.experiences[sind]); } } if(this.alwaysIncludeMostRecent){ FrameExperience eo; if(next > 0) { eo = this.experiences[next - 1]; } else if(size > 0){ eo = this.experiences[this.experiences.length-1]; } else{ throw new RuntimeException("FixedSizeMemory getting most recent fails because memory is size 0."); } samples.add(eo); } return samples; } @Override public void resetMemory() { this.size = 0; this.next = 0; this.currentFrameHistory = new FrameHistory(0, 0); } @Override public void saveMemory(String filePrefix) { String frameHistoryFilename = filePrefix + ".framehist"; String frameExperienceFilename = filePrefix + ".ser"; try (ObjectOutputStream objOut = new ObjectOutputStream(new FileOutputStream(frameExperienceFilename)); FileOutputStream historyOut = new FileOutputStream(frameHistoryFilename)) { objOut.writeObject(this); // write frame history long pos = 0; byte[] buffer = new byte[10000000]; int numRead; while (pos < frameMemory.limit()) { numRead = (int)Math.min(buffer.length, frameMemory.limit() - pos); frameMemory.position(pos).get(buffer, 0, numRead); pos += numRead; historyOut.write(buffer, 0, numRead); } } catch (IOException e) { System.out.println("Unable to save experience memory"); e.printStackTrace(); return; } } @Override public void loadMemory(String filePrefix) { String frameHistoryFilename = filePrefix + ".framehist"; String frameExperienceFilename = filePrefix + ".ser"; try (ObjectInputStream objIn = new ObjectInputStream(new FileInputStream(frameExperienceFilename)); FileInputStream historyIn = new FileInputStream(frameHistoryFilename)) { // load object FrameExperienceMemory experienceMemory = (FrameExperienceMemory) objIn.readObject(); this.currentFrameHistory = experienceMemory.currentFrameHistory; this.next = experienceMemory.next; this.size = experienceMemory.size; this.experiences = experienceMemory.experiences; // load frame history long pos = 0; byte[] buffer = new byte[10000000]; int numRead; while ((numRead = historyIn.read(buffer)) != -1) { this.frameMemory.position(pos).put(buffer, 0, numRead); pos += numRead; } } catch (FileNotFoundException e) { e.printStackTrace(); } catch (IOException e) { e.printStackTrace(); } catch (ClassNotFoundException e) { e.printStackTrace(); } } }