import random class Pool(): def __init__(self, size): self.size = size self.data = [None] * size self.idx = 0 self.sum_len = 0 self.total = 0 def put(self, x): if(self.total >= self.size): old_x = self.data[self.idx] self.sum_len -= len(old_x[0]) self.sum_len += len(x[0]) self.data[self.idx] = x self.idx = (self.idx + 1) % self.size self.total += 1 ''' Sample a batch of #size episodes. ''' def sample(self, size): return random.choices(self.data, k=size) ''' Samples a batch of episodes. Size of total steps is close to #size.''' def sample_steps(self, size): avg_len = self.sum_len / self.size eps_to_fetch = int(size / avg_len) return random.choices(self.data, k=eps_to_fetch)