# -*- coding: utf-8 -*-
import os
import time
from itertools import count

import numpy as np
import redis
import torch

from ..libs import utils
from . import replay

# if gpu is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Learner(object):
    """Learner of Ape-X

    Args:
        policy_net (torch.nn.Module): Q-function network
        target_net (torch.nn.Module): target network
        optimizer (torch.optim.Optimizer): optimizer
        vis (visdom.Visdom): visdom object
        replay_size (int, optional): size of replay memory
        hostname (str, optional): hostname of the redis server
        beta_decay (int, optional): number of steps over which the
            importance-sampling exponent beta is annealed to 1.0
        use_memory_compress (bool, optional): use the compressed replay
            memory to save memory
    """

    def __init__(self, policy_net, target_net, optimizer, vis,
                 replay_size=30000, hostname='localhost',
                 beta_decay=1000000, use_memory_compress=False):
        self._vis = vis
        self._policy_net = policy_net
        self._target_net = target_net
        self._target_net.load_state_dict(self._policy_net.state_dict())
        self._target_net.eval()
        self._beta_decay = beta_decay
        self._connect = redis.StrictRedis(host=hostname)
        self._connect.delete('params')
        self._optimizer = optimizer
        self._win = self._vis.line(X=np.array([0]), Y=np.array([0]),
                                   opts=dict(title='Memory size'))
        self._win2 = self._vis.line(X=np.array([0]), Y=np.array([0]),
                                    opts=dict(title='Q loss'))
        self._memory = replay.Replay(replay_size, self._connect,
                                     use_compress=use_memory_compress)
        self._memory.start()

    def _sleep(self):
        # Back off in proportion to the backlog of experiences queued in
        # redis, so the replay process can keep up with the actors.
        mlen = self._connect.llen('experience')
        time.sleep(0.01 * mlen)

    def _wait_memory(self, memory_size):
        # Block until the replay memory holds more than memory_size
        # transitions.
        while len(self._memory) <= memory_size:
            time.sleep(0.1)

    def optimize_loop(self, batch_size=512, gamma=0.999 ** 3, beta0=0.4,
                      max_grad_norm=40, start_memory_size=10000,
                      fit_timing=100, target_update=1000, actor_device=device,
                      save_timing=10000, save_model_dir='./models'):
        self._wait_memory(max(batch_size, start_memory_size))
        for t in count():
            transitions, prios, indices = self._memory.sample(batch_size)
            total = len(self._memory)

            # Anneal the importance-sampling exponent from beta0 toward 1.0
            # over beta_decay steps, then compute the prioritized-replay
            # weights w_i = (N * P(i))^(-beta), normalized by their maximum.
            beta = min(1.0, beta0 + (1.0 - beta0) / self._beta_decay * t)
            weights = (total * np.array(prios) / self._memory.total_prios) ** (-beta)
            weights /= weights.max()

            delta, prio = self._policy_net.calc_priorities(
                self._target_net, transitions, gamma=gamma, device=device)
            loss = (delta * torch.from_numpy(
                np.expand_dims(weights, 1).astype(np.float32)).to(device)).mean()

            # Optimize the model
            self._optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self._policy_net.parameters(),
                                           max_grad_norm)
            self._optimizer.step()

            # Feed the new TD-error-based priorities back into the replay
            # memory, then publish the latest parameters for the actors.
            self._memory.update_priorities(
                indices, prio.squeeze(1).cpu().numpy().tolist())
            self._connect.set('params', utils.dumps(
                self._policy_net.to(actor_device).state_dict()))
            self._policy_net.to(device)

            self._vis.line(X=np.array([t]),
                           Y=np.array([loss.detach().cpu().numpy()]),
                           win=self._win2, update='append')

            if t % fit_timing == 0:
                print('[Learner] Remove to fit.')
                self._memory.remove_to_fit()
                self._vis.line(X=np.array([t]),
                               Y=np.array([len(self._memory)]),
                               win=self._win, update='append')
            if t % target_update == 0:
                print('[Learner] Update target.')
                self._target_net.load_state_dict(self._policy_net.state_dict())
            if t % save_timing == 0:
                print('[Learner] Save model.')
                torch.save(self._policy_net.state_dict(),
                           os.path.join(save_model_dir, 'model_%d.pth' % t))
            self._sleep()
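

# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative, not part of the
# module's API: it assumes a redis server and a visdom server are running
# locally, that actors are already pushing transitions onto the redis
# 'experience' list, and that a hypothetical `DQN` class exposes the
# `calc_priorities(target_net, transitions, gamma, device)` interface used
# above. Because this module uses relative imports, launch it with
# `python -m <package>.learner` rather than as a plain script.
if __name__ == '__main__':
    import visdom

    from .model import DQN  # hypothetical module providing the Q-network

    policy_net = DQN().to(device)
    target_net = DQN().to(device)
    optimizer = torch.optim.RMSprop(policy_net.parameters(), lr=2.5e-4)
    vis = visdom.Visdom()

    learner = Learner(policy_net, target_net, optimizer, vis,
                      replay_size=30000, hostname='localhost')
    learner.optimize_loop(batch_size=512)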