# Experience Replay
# Following paper: Playing Atari with Deep Reinforcement Learning
#     https://www.cs.toronto.edu/~vmnih/docs/dqn.pdf
#
# ---
# @author Yiren Lu
# @email luyiren [at] seas [dot] upenn [dot] edu
#

import numpy as np
import random
from collections import namedtuple

# One transition: current state, action taken, resulting state, reward, terminal flag.
Step = namedtuple('Step','cur_step action next_step reward done')

class ExpReplay():
    """Experience replay buffer, DQN-style.

    Stores transitions (Step namedtuples) and serves states built from the
    last `kth` consecutive frames, following "Playing Atari with Deep
    Reinforcement Learning".
    """

    def __init__(self, mem_size, start_mem=None, state_size=[84, 84], kth=4, drop_rate=0.2, batch_size=32):
        """
        args
          mem_size    maximum number of steps kept in memory
          start_mem   minimum steps stored before sample() returns anything;
                      defaults to mem_size // 20
          state_size  shape of a single state (read-only; never mutated here)
          kth         number of consecutive frames stacked into one state;
                      kth = -1 means send the raw state unstacked
          drop_rate   fraction of the oldest memory dropped when full
          batch_size  default number of transitions drawn by sample()
        """
        self.state_size = state_size
        self.drop_rate = drop_rate
        self.mem_size = mem_size
        self.start_mem = start_mem
        if start_mem is None:
            # // : the original Python-2 era `/` would yield a float on Python 3
            self.start_mem = mem_size // 20
        self.kth = kth
        self.batch_size = batch_size
        self.mem = []
        self.total_steps = 0

    def add_step(self, step):
        """Store one step; when memory exceeds mem_size, drop the oldest
        drop_rate fraction in one cut.

        NOTE(review): the `def` line was missing from the mangled source;
        the header is reconstructed from the orphaned docstring and body.

        args
          step   namedtuple Step, where step.cur_step and step.next_step
                 are of size {state_size}
        """
        self.mem.append(step)
        self.total_steps += 1
        while len(self.mem) > self.mem_size:
            self.mem = self.mem[int(len(self.mem) * self.drop_rate):]

    def get_last_state(self):
        """Return the most recent state.

        Raw step for kth == -1, a plain list of the last kth frames for 1-d
        states, otherwise the last kth frames stacked along a new trailing
        axis. Returns [] when fewer than abs(kth)+1 steps are stored.
        """
        if len(self.mem) > abs(self.kth):
            if self.kth == -1:
                return self.mem[-1].cur_step
            # hoisted: the original built this list twice (plus a dead local)
            recent = [s.cur_step for s in self.mem[-abs(self.kth):]]
            if len(self.state_size) == 1:
                return recent
            # stacks to shape state_size + [kth]
            return np.stack(recent, axis=len(self.state_size))
        return []

    def sample(self, num=None):
        """Randomly draw [num] transitions (default: batch_size).

        Each sampled cur_step/next_step is a window of abs(kth) consecutive
        frames stacked along a trailing axis; action/reward/done come from
        the window's last frame. Returns [] until start_mem steps are stored.

        raises ValueError (from random.sample) if fewer than
        abs(kth) + num steps are available.
        """
        if num is None:
            num = self.batch_size
        if len(self.mem) < self.start_mem:
            return []
        # sample window *end* indices so every mem[idx-kth:idx] slice is full
        sampled_idx = random.sample(range(abs(self.kth), len(self.mem)), num)
        samples = []
        for idx in sampled_idx:
            steps = self.mem[idx - abs(self.kth):idx]
            # Special cases are checked FIRST: the original unconditionally
            # ran np.stack and discarded the result here, which is wasted
            # work and can raise (axis out of bounds) for scalar frames.
            if self.kth == -1:
                cur_state = steps[0].cur_step
                next_state = steps[0].next_step
            elif len(self.state_size) == 1:
                # NOTE(review): keeps only the window's first frame, unlike
                # get_last_state which returns all kth frames for 1-d states
                # — preserved as-is; confirm the asymmetry is intended.
                cur_state = [steps[0].cur_step]
                next_state = [steps[0].next_step]
            else:
                cur_state = np.stack([s.cur_step for s in steps], axis=len(self.state_size))
                next_state = np.stack([s.next_step for s in steps], axis=len(self.state_size))
            reward = steps[-1].reward
            action = steps[-1].action
            done = steps[-1].done
            samples.append(Step(cur_step=cur_state, action=action, next_step=next_state, reward=reward, done=done))
        return samples
