import torch import torch.nn as nn import torch.nn.functional as F from . import utils class DuelingDQN(nn.Module): def __init__(self, n_action, input_shape=(4, 84, 84)): super(DuelingDQN, self).__init__() self.n_action = n_action self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4) self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) r = int((int(input_shape[1] / 4) - 1) / 2) - 3 c = int((int(input_shape[2] / 4) - 1) / 2) - 3 self.adv1 = nn.Linear(r * c * 64, 512) self.adv2 = nn.Linear(512, self.n_action) self.val1 = nn.Linear(r * c * 64, 512) self.val2 = nn.Linear(512, 1) def forward(self, x): x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.size(0), -1) adv = F.relu(self.adv1(x)) adv = self.adv2(adv) val = F.relu(self.val1(x)) val = self.val2(val) return val + adv - adv.mean(1, keepdim=True) def calc_priorities(self, target_net, transitions, alpha=0.6, gamma=0.999, device=torch.device("cpu")): batch = utils.Transition(*zip(*transitions)) next_state_batch = torch.stack(batch.next_state).to(device) state_batch = torch.stack(batch.state).to(device) action_batch = torch.stack(batch.action).to(device) reward_batch = torch.stack(batch.reward).to(device) done_batch = torch.stack(batch.done).to(device) state_action_values = self.forward(state_batch).gather(1, action_batch) next_action = self.forward(next_state_batch).argmax(dim=1).unsqueeze(1) next_state_values = target_net(next_state_batch).gather(1, next_action).detach() expected_state_action_values = (next_state_values * gamma * (1.0 - done_batch)) \ + reward_batch delta = F.smooth_l1_loss(state_action_values, expected_state_action_values, reduce=False) prios = (delta.abs() + 1e-5).pow(alpha) return delta, prios.detach() class DuelingLSTMDQN(nn.Module): def __init__(self, n_action, batch_size, n_burn_in=40, nstep_return=5, input_shape=(4, 84, 84)): super(DuelingLSTMDQN, self).__init__() self.n_action = n_action self.batch_size = batch_size self.n_burn_in = n_burn_in self.nstep_return = nstep_return self.input_shape = input_shape self.conv1 = nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4) self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2) self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1) r = int((int(input_shape[1] / 4) - 1) / 2) - 3 c = int((int(input_shape[2] / 4) - 1) / 2) - 3 self.lstm = nn.LSTMCell(r * c * 64, 512) self.adv1 = nn.Linear(512, 512) self.adv2 = nn.Linear(512, self.n_action, bias=False) self.val1 = nn.Linear(512, 512) self.val2 = nn.Linear(512, 1) self.hx = torch.zeros(self.batch_size, 512) self.cx = torch.zeros(self.batch_size, 512) def to(self, device): super(DuelingLSTMDQN, self).to(device) self.hx = self.hx.to(device) self.cx = self.cx.to(device) return self def reset(self, done=False): self.hx.detach_() self.cx.detach_() if done: self.hx.zero_() self.cx.zero_() def get_state(self): return self.hx.detach().clone().cpu(), self.cx.detach().clone().cpu() def set_state(self, state, device): hx, cx = state self.hx = hx.clone().to(device) self.cx = cx.clone().to(device) def forward(self, x): x = F.relu(self.conv1(x)) x = F.relu(self.conv2(x)) x = F.relu(self.conv3(x)) x = x.view(x.size(0), -1) self.hx, self.cx = self.lstm(x, (self.hx, self.cx)) adv = F.relu(self.adv1(self.hx)) adv = self.adv2(adv) val = F.relu(self.val1(self.hx)) val = self.val2(val) return val + adv - adv.mean(1, keepdim=True) def calc_priorities(self, target_net, transitions, eta=0.9, gamma=0.997, require_grad=True, device=torch.device("cpu")): n_transitions = len(transitions) self_cp = DuelingLSTMDQN(self.n_action, self.batch_size, self.n_burn_in, self.nstep_return, self.input_shape).to(device) self_cp.load_state_dict(self.state_dict()) self_cp.eval() batch = utils.Sequence(*zip(*transitions)) batch = utils.Sequence(list(zip(*(batch.transitions))), list(zip(*(batch.recurrent_state)))) hx = torch.cat(batch.recurrent_state[0]) cx = torch.cat(batch.recurrent_state[1]) self.set_state((hx, cx), device) target_net.set_state((hx, cx), device) # burn-in with torch.no_grad(): for t in range(self.n_burn_in): trans = utils.Transition(*zip(*(batch.transitions[t]))) state_batch = torch.stack(trans.state).to(device) self.forward(state_batch) target_net.forward(state_batch) self_cp.set_state(self.get_state(), device) for t in range(self.n_burn_in, self.n_burn_in + self.nstep_return): trans = utils.Transition(*zip(*(batch.transitions[t]))) state_batch = torch.stack(trans.state).to(device) self_cp.forward(state_batch) target_net.forward(state_batch) self.reset() n_sequence = len(batch.transitions) delta = torch.zeros(n_sequence - self.n_burn_in - self.nstep_return, n_transitions, 1, device=device) with torch.set_grad_enabled(require_grad): for t in range(self.n_burn_in, n_sequence - self.nstep_return): trans0 = utils.Transition(*zip(*(batch.transitions[t]))) trans1 = utils.Transition(*zip(*(batch.transitions[t + self.nstep_return]))) state_batch = torch.stack(trans0.state).to(device) action_batch = torch.stack(trans0.action).to(device) reward_batch = torch.stack(trans0.reward).to(device) next_state_batch = torch.stack(trans1.state).to(device) done_batch = torch.stack(trans1.done).to(device) state_action_values = self.forward(state_batch).gather(1, action_batch) next_action = self_cp.forward(next_state_batch).argmax(dim=1).unsqueeze(1).detach() next_state_values = target_net(next_state_batch).gather(1, next_action).detach() expected_state_action_values = utils.rescale((utils.inv_rescale(next_state_values) \ * gamma * (1.0 - done_batch)) + reward_batch) delta[t - self.n_burn_in] = F.l1_loss(state_action_values, expected_state_action_values, reduce=False) prios = eta * delta.max(dim=0)[0] + (1.0 - eta) * delta.mean(dim=0) return delta.pow(2).sum(dim=0), prios.detach()