# Experience Replay
# Following paper: Playing Atari with Deep Reinforcement Learning
#     https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
#
# ---
# @author Yiren Lu
# @email luyiren [at] seas [dot] upenn [dot] edu
#
# MIT License

import numpy as np
import random
from collections import namedtuple

# One transition: state before the action, the action taken, the state after,
# the reward received and the episode-termination flag.
Step = namedtuple('Step', 'cur_step action next_step reward done')


class ExpReplay(object):
    """Experience replay memory that stores `Step` transitions.

    Transitions are appended with `add_step`; `sample` draws random training
    samples and `get_last_state` returns the most recent state. Unless raw
    mode is selected (`kth == -1`) or the state is 1-D, states are built by
    stacking the `cur_step` of `abs(kth)` consecutive steps along a new
    trailing axis.
    """

    def __init__(self, mem_size, start_mem=None, state_size=None, kth=4,
                 drop_rate=0.2, batch_size=32):
        """Create an empty replay memory.

        args
          mem_size    maximum number of steps kept in memory
          start_mem   minimum number of stored steps before `sample` returns
                      anything; defaults to mem_size / 20
          state_size  shape of a single state; defaults to [84, 84]
          kth         number of consecutive steps stacked into one state;
                      -1 for sending the raw state
          drop_rate   fraction of the oldest memory dropped when the memory
                      exceeds mem_size
          batch_size  default number of samples drawn by `sample`
        """
        # A fresh list per instance: a list literal used as the default value
        # would be shared (and mutable) across all instances.
        self.state_size = [84, 84] if state_size is None else state_size
        self.drop_rate = drop_rate
        self.mem_size = mem_size
        self.start_mem = mem_size / 20 if start_mem is None else start_mem
        self.kth = kth
        self.batch_size = batch_size
        self.mem = []
        self.total_steps = 0

    def add_step(self, step):
        """Store a step in memory and trim the memory if it exceeds mem_size.

        While the memory holds more than mem_size steps, the oldest
        [self.drop_rate] fraction of it is dropped.

        args
          step      namedtuple Step, where step.cur_step and step.next_step
                    are of size {state_size}
        """
        self.mem.append(step)
        self.total_steps += 1
        while len(self.mem) > self.mem_size:
            # Always drop at least one step: int(len * drop_rate) can be 0
            # (small memory or tiny drop_rate), which would otherwise make
            # this loop spin forever.
            n_drop = max(1, int(len(self.mem) * self.drop_rate))
            self.mem = self.mem[n_drop:]

    def get_last_state(self):
        """Return the most recent state, or [] if too few steps are stored.

        returns
          kth == -1            cur_step of the newest step, as stored
          1-D state_size       list with the cur_step of the last abs(kth) steps
          otherwise            np.ndarray stacking the cur_step of the last
                               abs(kth) steps along a new trailing axis
        """
        window = abs(self.kth)
        # A full window is available as soon as `window` steps are stored.
        if not self.mem or len(self.mem) < window:
            return []
        if self.kth == -1:
            return self.mem[-1].cur_step
        frames = [s.cur_step for s in self.mem[-window:]]
        if len(self.state_size) == 1:
            return frames
        return np.stack(frames, axis=len(self.state_size))

    def sample(self, num=None):
        """Randomly draw [num] samples without replacement.

        args
          num       number of samples; defaults to self.batch_size

        returns
          A list of Step namedtuples whose action, reward and done come from
          the newest step of each sampled window. Returns [] when fewer than
          start_mem steps are stored or when there are fewer than [num]
          distinct windows to draw from.
        """
        if num is None:
            num = self.batch_size
        if len(self.mem) < self.start_mem:
            return []

        window = abs(self.kth)
        # idx is the exclusive end of the window mem[idx - window:idx], so it
        # must be allowed to reach len(mem) for the newest step to be sampled.
        candidates = range(window, len(self.mem) + 1)
        if num > len(candidates):
            # Not enough distinct windows yet; random.sample would raise.
            return []

        stack_axis = len(self.state_size)
        samples = []
        for idx in random.sample(candidates, num):
            steps = self.mem[idx - window:idx]
            if self.kth == -1:
                # Raw mode: the window holds exactly one step.
                cur_state = steps[0].cur_step
                next_state = steps[0].next_step
            elif stack_axis == 1:
                # NOTE(review): for 1-D states only the oldest step of the
                # window is returned while reward/action come from the newest;
                # the two coincide only when abs(kth) == 1 -- confirm intent.
                cur_state = [steps[0].cur_step]
                next_state = [steps[0].next_step]
            else:
                cur_state = np.stack([s.cur_step for s in steps],
                                     axis=stack_axis)
                next_state = np.stack([s.next_step for s in steps],
                                      axis=stack_axis)
            newest = steps[-1]
            samples.append(Step(cur_step=cur_state, action=newest.action,
                                next_step=next_state, reward=newest.reward,
                                done=newest.done))
        return samples