# -*- coding: utf-8 -*-
"""Shared helpers: a cross-process counter, ØMQ tensor transport, frame
preprocessing and reward plotting."""
import os
from PIL import Image
import plotly
from plotly.graph_objs import Scatter, Line
import msgpack
import msgpack_numpy
import torch
from torch import multiprocessing as mp
from torchvision import transforms


# Module-level transform reused by `_preprocess` for every frame
to_tensor = transforms.ToTensor()
# Patch MessagePack to work with numpy arrays
msgpack_numpy.patch()


# Global counter
class Counter():
  """Integer counter stored in an `mp.Value`, with every access guarded by an `mp.Lock`."""

  def __init__(self):
    # 'i' is the typecode for a signed int; the counter starts at 0
    self.val = mp.Value('i', 0)
    self.lock = mp.Lock()

  def increment(self):
    """Adds one to the counter while holding the lock."""
    with self.lock:
      self.val.value += 1

  def value(self):
    """Returns the current count, read while holding the lock."""
    with self.lock:
      return self.val.value


# Sends Torch tensor over ØMQ (via numpy format as torch storage does not provide a buffer interface)
def send_tensors(socket, tensors, flags=0, copy=True, track=False):
  """Packs `tensors` into a single MessagePack message and sends it on `socket`.

  Each tensor is converted with `.numpy()`; `flags`, `copy` and `track` are
  forwarded unchanged to `socket.send`, whose return value is returned.
  """
  # TODO: Investigate options
  payload = msgpack.packb([tensor.numpy() for tensor in tensors])
  return socket.send(payload, flags, copy=copy, track=track)


# Receives Torch tensor over ØMQ
def receive_tensors(socket, flags=0, copy=True, track=False):
  """Receives one message from `socket` and returns its contents as a list of tensors.

  `flags`, `copy` and `track` are forwarded unchanged to `socket.recv`; every
  unpacked array is wrapped with `torch.from_numpy`.
  """
  msg = socket.recv(flags=flags, copy=copy, track=track)
  return [torch.from_numpy(tensor) for tensor in msgpack.unpackb(msg)]


# Preprocesses ALE frames for A3C
def _preprocess(img):
  """Builds an RGB image from the array `img`, resizes it to 84x84 and converts it with `to_tensor`."""
  frame = Image.fromarray(img, mode='RGB').resize([84, 84])
  return to_tensor(frame)


# Converts a state from the OpenAI Gym (a numpy array) to a batch tensor
def state_to_tensor(state):
  """Preprocesses `state` and adds a leading dimension of size 1 at index 0."""
  return _preprocess(state).unsqueeze(0)


# Plots min, max and mean + standard deviation bars of a population over time
def plot_line(xs, ys_population, path=''):
  """Writes `rewards.html` under `path` with min/max/mean and a +/-1 std. dev. band.

  Args:
    xs: x-axis values, one per row of `ys_population`.
    ys_population: 2D nested sequence; statistics are reduced over dimension 1,
      so each row supplies the population values for one entry of `xs`.
    path: directory joined with 'rewards.html' to form the output file name.
  """
  max_colour, mean_colour, std_colour = 'rgb(0, 132, 180)', 'rgb(0, 172, 237)', 'rgba(29, 202, 255, 0.2)'

  ys = torch.Tensor(ys_population)
  # Flatten every statistic with view(-1) rather than an argument-less
  # squeeze(): squeeze() would also drop the time axis when there is only a
  # single x value, leaving 0-d tensors that cannot serve as a y series.
  # view(-1) gives a 1D series whether or not the reduction keeps dimension 1.
  ys_min = ys.min(1)[0].view(-1)
  ys_max = ys.max(1)[0].view(-1)
  ys_mean = ys.mean(1).view(-1)
  ys_std = ys.std(1).view(-1)
  ys_upper, ys_lower = ys_mean + ys_std, ys_mean - ys_std

  trace_max = Scatter(x=xs, y=ys_max.numpy(), line=Line(color=max_colour, dash='dash'), name='Max')
  trace_upper = Scatter(x=xs, y=ys_upper.numpy(), line=Line(color='transparent'), name='+1 Std. Dev.', showlegend=False)
  trace_mean = Scatter(x=xs, y=ys_mean.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color=mean_colour), name='Mean')
  trace_lower = Scatter(x=xs, y=ys_lower.numpy(), fill='tonexty', fillcolor=std_colour, line=Line(color='transparent'), name='-1 Std. Dev.', showlegend=False)
  trace_min = Scatter(x=xs, y=ys_min.numpy(), line=Line(color=max_colour, dash='dash'), name='Min')

  # Trace order matters: 'tonexty' fills towards the trace listed just before,
  # so upper -> mean -> lower yields the shaded standard deviation band.
  plotly.offline.plot({
    'data': [trace_upper, trace_mean, trace_lower, trace_min, trace_max],
    'layout': dict(title='Rewards',
                   xaxis={'title': 'Step'},
                   yaxis={'title': 'Average Reward'})
  }, filename=os.path.join(path, 'rewards.html'), auto_open=False)