# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np

import layers
import utils.ops
from custom_lstm import CustomBasicLSTMCell


class Network(object):
    """Common base for Q and Policy/V networks.

    Reads hyper-parameters from ``conf``, creates the input/action (and,
    for recurrent variants, LSTM-state) placeholders, and provides the
    shared encoder, value-loss, gradient-clipping and shared-memory sync
    ops used by the concrete network subclasses.
    """

    def __init__(self, conf):
        """Initialize hyper-parameters and create the placeholders.

        Args:
            conf: dict with keys 'name', 'num_act', 'input_shape' and an
                'args' namespace carrying the command-line hyper-parameters
                (arch, batch_size, opt_type, clip settings, etc.).
        """
        self.name = conf['name']
        self.num_actions = conf['num_act']
        self.arch = conf['args'].arch
        self.batch_size = conf['args'].batch_size
        self.optimizer_type = conf['args'].opt_type
        self.optimizer_mode = conf['args'].opt_mode
        self.clip_loss_delta = conf['args'].clip_loss_delta
        self.clip_norm = conf['args'].clip_norm
        self.clip_norm_type = conf['args'].clip_norm_type
        self.input_shape = conf['input_shape']
        self.activation = conf['args'].activation
        self.max_local_steps = conf['args'].max_local_steps
        # RGB input uses 3 channels; otherwise history_length (grayscale)
        # frames are stacked along the channel axis.
        self.input_channels = 3 if conf['args'].use_rgb else conf['args'].history_length
        self.use_recurrent = 'lstm' in conf['args'].alg_type
        self.fc_layer_sizes = conf['args'].fc_layer_sizes

        self._init_placeholders()

    def _init_placeholders(self):
        """Create input, LSTM-state and selected-action placeholders."""
        with tf.variable_scope(self.name):
            if self.arch == 'FC':
                self.input_ph = tf.placeholder(
                    'float32',
                    [self.batch_size] + self.input_shape + [self.input_channels],
                    name='input')
            else:  # assume image input
                self.input_ph = tf.placeholder(
                    'float32',
                    [self.batch_size, 84, 84, self.input_channels],
                    name='input')

            if self.use_recurrent:
                self.hidden_state_size = 256
                self.step_size = tf.placeholder(
                    tf.float32, [None], name='step_size')
                # c and h states packed side by side along axis 1.
                # NOTE: fixed typo in the placeholder name (was
                # 'initital_state'); update any code that referenced the
                # placeholder by its graph name.
                self.initial_lstm_state = tf.placeholder(
                    tf.float32, [None, 2 * self.hidden_state_size],
                    name='initial_state')

            self.selected_action_ph = tf.placeholder(
                'float32', [self.batch_size, self.num_actions],
                name='selected_action')

    def _build_encoder(self):
        """Build the feature encoder for the configured architecture.

        Sets ``self.ox`` to the final feature tensor (optionally passed
        through an LSTM layer) and returns it.

        Raises:
            Exception: if ``self.arch`` is not one of
                'FC', 'ATARI-TRPO', 'NIPS', 'NATURE'.
        """
        with tf.variable_scope(self.name):
            if self.arch == 'FC':
                layer_i = layers.flatten(self.input_ph)
                for i, layer_size in enumerate(self.fc_layer_sizes):
                    layer_i = layers.fc(
                        'fc{}'.format(i + 1), layer_i, layer_size,
                        activation=self.activation)[-1]
                self.ox = layer_i
            elif self.arch == 'ATARI-TRPO':
                self.w1, self.b1, self.o1 = layers.conv2d(
                    'conv1', self.input_ph, 16, 4, self.input_channels, 2,
                    activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d(
                    'conv2', self.o1, 16, 4, 16, 2,
                    activation=self.activation)
                self.w3, self.b3, self.o3 = layers.fc(
                    'fc3', layers.flatten(self.o2), 20,
                    activation=self.activation)
                self.ox = self.o3
            elif self.arch == 'NIPS':
                self.w1, self.b1, self.o1 = layers.conv2d(
                    'conv1', self.input_ph, 16, 8, self.input_channels, 4,
                    activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d(
                    'conv2', self.o1, 32, 4, 16, 2,
                    activation=self.activation)
                self.w3, self.b3, self.o3 = layers.fc(
                    'fc3', layers.flatten(self.o2), 256,
                    activation=self.activation)
                self.ox = self.o3
            elif self.arch == 'NATURE':
                self.w1, self.b1, self.o1 = layers.conv2d(
                    'conv1', self.input_ph, 32, 8, self.input_channels, 4,
                    activation=self.activation)
                self.w2, self.b2, self.o2 = layers.conv2d(
                    'conv2', self.o1, 64, 4, 32, 2,
                    activation=self.activation)
                self.w3, self.b3, self.o3 = layers.conv2d(
                    'conv3', self.o2, 64, 3, 64, 1,
                    activation=self.activation)
                self.w4, self.b4, self.o4 = layers.fc(
                    'fc4', layers.flatten(self.o3), 512,
                    activation=self.activation)
                self.ox = self.o4
            else:
                raise Exception('Invalid architecture `{}`'.format(self.arch))

            if self.use_recurrent:
                with tf.variable_scope('lstm_layer') as vs:
                    self.lstm_cell = tf.contrib.rnn.BasicLSTMCell(
                        self.hidden_state_size,
                        state_is_tuple=True,
                        forget_bias=1.0)

                    # Reshape flat features into [batch, time, features] for
                    # dynamic_rnn; the time dimension is inferred (-1).
                    batch_size = tf.shape(self.step_size)[0]
                    self.ox_reshaped = tf.reshape(
                        self.ox,
                        [batch_size, -1, self.ox.get_shape().as_list()[-1]])

                    # Unpack the flat [batch, 2*H] placeholder into (c, h).
                    state_tuple = tf.contrib.rnn.LSTMStateTuple(
                        *tf.split(self.initial_lstm_state, 2, 1))

                    self.lstm_outputs, self.lstm_state = tf.nn.dynamic_rnn(
                        self.lstm_cell,
                        self.ox_reshaped,
                        initial_state=state_tuple,
                        sequence_length=self.step_size,
                        time_major=False)

                    # Re-pack (c, h) into a flat tensor matching the
                    # initial-state placeholder layout.
                    self.lstm_state = tf.concat(self.lstm_state, 1)
                    self.ox = tf.reshape(
                        self.lstm_outputs, [-1, self.hidden_state_size],
                        name='reshaped_lstm_outputs')

                    # Get all LSTM trainable params
                    self.lstm_trainable_variables = [
                        v for v in tf.trainable_variables()
                        if v.name.startswith(vs.name)]

        return self.ox

    def _value_function_loss(self, diff):
        """Return the value-function loss for TD errors ``diff``.

        With ``clip_loss_delta > 0`` this is the summed Huber loss:
        quadratic within the delta, linear beyond it. (The previous
        implementation used ``delta**2 * |diff|`` for the linear branch,
        which is not the Huber loss and is discontinuous at the
        threshold.) With clipping disabled it is the plain squared loss.
        """
        if self.clip_loss_delta > 0:
            abs_diff = tf.abs(diff)
            quadratic = tf.minimum(abs_diff, self.clip_loss_delta)
            linear = abs_diff - quadratic
            return tf.reduce_sum(
                0.5 * tf.square(quadratic) + self.clip_loss_delta * linear)
        else:
            return tf.nn.l2_loss(diff)

    def _clip_grads(self, grads):
        """Clip gradients according to ``self.clip_norm_type``.

        Raises:
            ValueError: if ``clip_norm_type`` is not one of
                'ignore', 'global', 'avg', 'local'. (Previously an
                unknown value silently returned ``None``.)
        """
        if self.clip_norm_type == 'ignore':
            return grads
        elif self.clip_norm_type == 'global':
            return tf.clip_by_global_norm(grads, self.clip_norm)[0]
        elif self.clip_norm_type == 'avg':
            return tf.clip_by_average_norm(grads, self.clip_norm)[0]
        elif self.clip_norm_type == 'local':
            return [tf.clip_by_norm(g, self.clip_norm) for g in grads]
        else:
            raise ValueError(
                'Invalid clip_norm_type `{}`'.format(self.clip_norm_type))

    def _setup_shared_memory_ops(self):
        """Create placeholders and assign ops to sync with shared memory."""
        # Placeholders for shared memory vars
        self.params_ph = []
        for p in self.params:
            self.params_ph.append(tf.placeholder(
                tf.float32,
                shape=p.get_shape(),
                name='shared_memory_for_{}'.format(
                    (p.name.split('/', 1)[1]).replace(':', '_'))))

        # Ops to sync net with shared memory vars (was a Python-2-only
        # `xrange` index loop).
        self.sync_with_shared_memory = [
            param.assign(ph)
            for param, ph in zip(self.params, self.params_ph)]

    def _build_gradient_ops(self, loss):
        """Collect this net's variables and build clipped-gradient ops."""
        self.params = tf.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES, scope=self.name)
        self.flat_vars = utils.ops.flatten_vars(self.params)

        grads = tf.gradients(loss, self.params)
        self.get_gradients = self._clip_grads(grads)
        self._setup_shared_memory_ops()

    def get_input_shape(self):
        """Return the input placeholder's shape, excluding the batch dim."""
        return self.input_ph.get_shape().as_list()[1:]