import random
from collections import namedtuple, deque

import numpy as np
import torch

from model import QNet
from config import small_epsilon, gamma, alpha, device

Transition = namedtuple('Transition', ('state', 'next_state', 'action', 'reward', 'mask'))


class Memory_With_TDError:
    """Replay memory with proportional prioritized sampling.

    Every stored transition carries a priority. ``sample`` draws transitions
    with probability proportional to their priority, then overwrites the
    sampled priorities with ``(|td_error| + small_epsilon) ** alpha``.
    Newly pushed transitions receive the largest priority currently stored,
    so fresh experience is favoured for replay.
    """

    def __init__(self, capacity):
        # Parallel bounded buffers: memory[i] has priority memory_probabiliy[i].
        # Both share the same maxlen, so evicting the oldest entry keeps them aligned.
        self.memory = deque(maxlen=capacity)
        # NOTE: the attribute name is misspelled; kept as-is for backward compatibility.
        self.memory_probabiliy = deque(maxlen=capacity)
        self.capacity = capacity

    def push(self, state, next_state, action, reward, mask):
        """Save a transition, giving it the highest priority currently stored.

        The very first transition gets ``small_epsilon`` as its priority.
        """
        max_probability = max(self.memory_probabiliy, default=small_epsilon)
        self.memory.append(Transition(state, next_state, action, reward, mask))
        self.memory_probabiliy.append(max_probability)

    def sample(self, batch_size, net, target_net, beta):
        """Draw a prioritized batch and refresh the sampled priorities.

        Args:
            batch_size: number of transitions to draw (sampled with replacement,
                so an index may appear more than once).
            net, target_net: networks forwarded to ``QNet.get_td_error``.
            beta: importance-sampling exponent.

        Returns:
            ``(batch, weights)``. ``batch`` is a ``Transition`` whose fields are
            tuples holding the sampled values. ``weights`` is a 1-D tensor on
            ``device`` containing ``(N * P(i)) ** -beta``, normalised so its
            maximum is 1. Here N is the number of stored transitions.

        Raises:
            ValueError: if the memory is empty, or if the TD-error count does not
                match the number of sampled transitions.
        """
        size = len(self.memory)
        if size == 0:
            raise ValueError("cannot sample from an empty memory")

        priorities = np.asarray(self.memory_probabiliy, dtype=np.float64)
        probs = priorities / priorities.sum()
        indexes = np.random.choice(size, batch_size, p=probs)

        transitions = [self.memory[idx] for idx in indexes]
        batch = Transition(*zip(*transitions))

        # Weights use the pre-update probabilities. Any constant factor in N
        # cancels after the max-normalisation below.
        weights = (size * probs[indexes]) ** (-beta)
        weights = torch.as_tensor(
            weights / weights.max(), dtype=torch.get_default_dtype(), device=device
        )

        # Priorities are refreshed here, at sampling time, using the current networks.
        td_error = QNet.get_td_error(
            net, target_net, batch.state, batch.next_state,
            batch.action, batch.reward, batch.mask,
        )
        # Compute (|delta| + eps)^alpha for the whole batch and copy to host once,
        # instead of forcing one device sync per element with .item().
        new_priorities = (
            (torch.as_tensor(td_error).detach().abs().reshape(-1) + small_epsilon)
            .pow(alpha)
            .tolist()
        )
        if len(new_priorities) != len(indexes):
            raise ValueError(
                "expected %d TD errors, got %d" % (len(indexes), len(new_priorities))
            )
        for idx, priority in zip(indexes, new_priorities):
            self.memory_probabiliy[idx] = priority

        return batch, weights

    def __len__(self):
        return len(self.memory)