# Taken from https://github.com/ikostrikov/pytorch-a3c
from __future__ import absolute_import, division, print_function

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def normalized_columns_initializer(weights, std=1.0):
    """Return a random tensor shaped like ``weights`` with rows of norm ``std``.

    Values are drawn from a standard normal distribution and each slice
    along dimension 1 is then rescaled so that its L2 norm equals ``std``.

    Args:
        weights: tensor whose size determines the size of the result; its
            values are not read. Must have at least two dimensions.
        std: target L2 norm of every slice along dimension 1.

    Returns:
        A new tensor of size ``weights.size()``.
    """
    out = torch.randn(weights.size())
    # keepdim=True keeps the reduced dimension so the norm broadcasts back
    # over ``out``; without it ``expand_as`` fails for non-square shapes.
    out *= std / torch.sqrt(out.pow(2).sum(1, keepdim=True).expand_as(out))
    return out


def weights_init(m):
    """Initialise ``m`` in place with a uniform fan-in/fan-out bound.

    Not actually using this but let's keep it here in case that changes.

    Modules whose class name contains ``'Conv'`` or ``'Linear'`` get their
    weights drawn uniformly from ``[-b, b]`` with
    ``b = sqrt(6 / (fan_in + fan_out))`` and their bias (if any) zeroed.
    Every other module is left untouched.

    Args:
        m: the module to initialise; intended for ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight_shape = list(m.weight.data.size())
        fan_in = np.prod(weight_shape[1:4])
        fan_out = np.prod(weight_shape[2:4]) * weight_shape[0]
    elif 'Linear' in classname:
        weight_shape = list(m.weight.data.size())
        fan_in = weight_shape[1]
        fan_out = weight_shape[0]
    else:
        return

    w_bound = np.sqrt(6. / (fan_in + fan_out))
    m.weight.data.uniform_(-w_bound, w_bound)
    # Layers created with bias=False expose ``bias`` as None.
    if m.bias is not None:
        m.bias.data.fill_(0)


def selu(x):
    """Apply the SELU activation, ``scale * elu(x, alpha)``, to ``x``."""
    alpha = 1.6732632423543772848170429916717
    scale = 1.0507009873554804934193349852946
    return scale * F.elu(x, alpha)


class ES(torch.nn.Module):
    """Policy network producing one output per discrete action.

    Two architectures are available, selected by ``small_net``:

    * ``small_net=True``: a three-layer fully connected network.
    * ``small_net=False``: four stride-2 convolutions followed by an
      ``LSTMCell`` and a linear head.
    """

    def __init__(self, num_inputs, action_space, small_net=False):
        """Build the layers of the selected architecture.

        Really I should be using inheritance for the small_net here.

        Args:
            num_inputs: input feature count (small net) or number of input
                channels (conv net).
            action_space: object whose ``n`` attribute gives the number of
                outputs of the final linear layer.
            small_net: build the fully connected variant when true.
        """
        super(ES, self).__init__()
        num_outputs = action_space.n
        self.small_net = small_net
        if self.small_net:
            self.linear1 = nn.Linear(num_inputs, 64)
            self.linear2 = nn.Linear(64, 64)
            self.actor_linear = nn.Linear(64, num_outputs)
        else:
            self.conv1 = nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1)
            self.conv2 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
            self.conv3 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
            self.conv4 = nn.Conv2d(32, 32, 3, stride=2, padding=1)
            # The LSTM input size fixes the conv output at 32 x 3 x 3
            # (reached, for example, from 42 x 42 inputs).
            self.lstm = nn.LSTMCell(32*3*3, 256)
            self.actor_linear = nn.Linear(256, num_outputs)
        self.train()

    def forward(self, inputs):
        """Run the network.

        Args:
            inputs: for the small net, a tensor fed to the first linear
                layer; for the conv net, a tuple ``(x, (hx, cx))`` holding
                the image batch and the LSTM hidden and cell states.

        Returns:
            The action outputs for the small net, or
            ``(outputs, (hx, cx))`` with the updated LSTM state for the
            conv net.
        """
        if self.small_net:
            x = selu(self.linear1(inputs))
            x = selu(self.linear2(x))
            return self.actor_linear(x)

        inputs, (hx, cx) = inputs
        x = selu(self.conv1(inputs))
        x = selu(self.conv2(x))
        x = selu(self.conv3(x))
        x = selu(self.conv4(x))
        # Flatten the conv features to match the LSTM input size.
        x = x.view(-1, 32*3*3)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        return self.actor_linear(x), (hx, cx)

    def count_parameters(self):
        """Return the total number of scalar values across all parameters."""
        # numel() needs no NumPy round trip, so it also works for
        # parameters that live on a non-CPU device.
        return sum(param.numel() for param in self.parameters())

    def es_params(self):
        """Return the params that should be trained by ES (all of them).

        Returns:
            A list of ``(name, tensor)`` pairs taken from ``state_dict()``.
        """
        return list(self.state_dict().items())