import math

import tensorflow as tf


def variable_summaries(var, groupname, name):
    """Attach a lot of summaries to a Tensor. This is also quite expensive."""
    with tf.device("/cpu:0"), tf.name_scope(None):
        s_var = tf.cast(var, tf.float32)
        amean = tf.reduce_mean(tf.abs(s_var))
        tf.summary.scalar(groupname + '/amean/' + name, amean)
        mean = tf.reduce_mean(s_var)
        tf.summary.scalar(groupname + '/mean/' + name, mean)
        # Mean (not sum) of squared deviations, so this is a true
        # standard deviation.
        stddev = tf.sqrt(tf.reduce_mean(tf.square(s_var - mean)))
        tf.summary.scalar(groupname + '/stddev/' + name, stddev)
        tf.summary.scalar(groupname + '/max/' + name, tf.reduce_max(s_var))
        tf.summary.scalar(groupname + '/min/' + name, tf.reduce_min(s_var))
        tf.summary.histogram(groupname + "/" + name, s_var)


def getdtype(hps, is_rnn=False):
    if is_rnn:
        return tf.float16 if hps.float16_rnn else tf.float32
    else:
        return tf.float16 if hps.float16_non_rnn else tf.float32


def linear(x, size, name):
    w = tf.get_variable(name + "/W", [x.get_shape()[-1], size])
    b = tf.get_variable(name + "/b", [1, size],
                        initializer=tf.zeros_initializer())
    return tf.matmul(x, w) + b


def sharded_variable(name, shape, num_shards, dtype=tf.float32,
                     transposed=False):
    # The final size of the sharded variable may be larger than requested.
    # This should be fine for embeddings.
    shard_size = int((shape[0] + num_shards - 1) / num_shards)
    # Note: the original branched on `transposed` but used the same
    # initializer in both branches, so the flag currently has no effect.
    initializer = tf.uniform_unit_scaling_initializer(dtype=dtype)
    return [tf.get_variable(name + "_" + str(i), [shard_size, shape[1]],
                            initializer=initializer, dtype=dtype)
            for i in range(num_shards)]


# XXX(rafal): Code below copied from rnn_cell.py
def _get_sharded_variable(name, shape, dtype, num_shards):
    """Get a list of sharded variables with the given dtype."""
    if num_shards > shape[0]:
        raise ValueError("Too many shards: shape=%s, num_shards=%d" %
                         (shape, num_shards))
    unit_shard_size = int(math.floor(shape[0] / num_shards))
    remaining_rows = shape[0] - unit_shard_size * num_shards

    shards = []
    for i in range(num_shards):
        current_size = unit_shard_size
        if i < remaining_rows:
            current_size += 1
        shards.append(tf.get_variable(name + "_%d" % i,
                                      [current_size] + shape[1:],
                                      dtype=dtype))
    return shards


def _get_concat_variable(name, shape, dtype, num_shards):
    """Get a sharded variable concatenated into one tensor."""
    _sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)
    if len(_sharded_variable) == 1:
        return _sharded_variable[0]
    return tf.concat(_sharded_variable, 0)
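
# Illustrative sketch (not part of the original module): the helpers above
# shard a [rows, cols] variable row-wise across `num_shards` variables and
# concatenate them back into one tensor. The name "emb" and the shape below
# are hypothetical; usage looks roughly like:
#
#   with tf.variable_scope("example"):
#       emb = _get_concat_variable("emb", [1000, 128], tf.float32,
#                                  num_shards=4)
#   # emb is a single [1000, 128] tensor backed by variables emb_0..emb_3,
#   # each [250, 128] here; when rows % num_shards != 0, the first
#   # `remaining_rows` shards each get one extra row.
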
class FLSTMCell(tf.contrib.rnn.RNNCell):
    """LSTM cell with a factorized weight matrix: W is approximated as
    W1 * W2, optionally with a non-linearity between the two factors."""

    def __init__(self, num_units, input_size, initializer=None,
                 num_proj=None, num_shards=1, factor_size=None,
                 fnon_linearity=None, dtype=tf.float32):
        self._num_units = num_units
        self._initializer = initializer
        self._num_proj = num_proj
        self._num_unit_shards = num_shards
        self._num_proj_shards = num_shards
        self._forget_bias = 1.0
        if factor_size:
            self._factor_size = int(factor_size)
        else:
            self._factor_size = None
        self._fnon_linearity = fnon_linearity

        if num_proj:
            self._state_size = num_units + num_proj
            self._output_size = num_proj
        else:
            self._state_size = 2 * num_units
            self._output_size = num_units

        # Without a projection, the recurrent input m_prev has size
        # num_units, not num_proj.
        recur_size = num_proj if num_proj else num_units

        with tf.variable_scope("LSTMCell"):
            if self._factor_size:
                # Factorized parameterization: W ~ W1 * W2.
                self._concat_w1 = _get_concat_variable(
                    "W1", [input_size + recur_size, self._factor_size],
                    dtype, self._num_unit_shards)
                self._concat_w2 = _get_concat_variable(
                    "W2", [self._factor_size, 4 * self._num_units],
                    dtype, self._num_unit_shards)
                if self._fnon_linearity:
                    # Bias between the factors, only needed when a
                    # non-linearity is applied between them.
                    self._b1 = tf.get_variable(name="b1",
                                               shape=[self._factor_size])
            else:
                self._concat_w = _get_concat_variable(
                    "W", [input_size + recur_size, 4 * self._num_units],
                    dtype, self._num_unit_shards)
            # The gate bias is used by both the factorized and the full
            # parameterizations, so it is created unconditionally.
            self._b = tf.get_variable("B", shape=[4 * self._num_units])
            if num_proj:
                self._concat_w_proj = _get_concat_variable(
                    "W_P", [self._num_units, self._num_proj],
                    dtype, self._num_proj_shards)

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def __call__(self, inputs, state, scope=None):
        num_proj = self._num_units if self._num_proj is None else self._num_proj
        c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
        m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError(
                "Could not infer input size from inputs.get_shape()[-1]")

        with tf.variable_scope(type(self).__name__,  # "FLSTMCell"
                               initializer=self._initializer):
            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            cell_inputs = tf.concat([inputs, m_prev], 1)
            if self._factor_size:
                if self._fnon_linearity:
                    # Non-linear factorization: act(x W1 + b1) W2 + b.
                    hidden = self._fnon_linearity(tf.nn.bias_add(
                        tf.matmul(cell_inputs, self._concat_w1), self._b1))
                    lstm_matrix = tf.nn.bias_add(
                        tf.matmul(hidden, self._concat_w2), self._b)
                else:
                    # Linear factorization: (x W1) W2 + b.
                    lstm_matrix = tf.nn.bias_add(
                        tf.matmul(tf.matmul(cell_inputs, self._concat_w1),
                                  self._concat_w2), self._b)
            else:
                lstm_matrix = tf.matmul(cell_inputs, self._concat_w) + self._b
            i, j, f, o = tf.split(lstm_matrix, 4, 1)

            c = (tf.sigmoid(f + self._forget_bias) * c_prev +
                 tf.sigmoid(i) * tf.tanh(j))
            m = tf.sigmoid(o) * tf.tanh(c)

            if self._num_proj is not None:
                m = tf.matmul(m, self._concat_w_proj)

        new_state = tf.concat([c, m], 1)
        return m, new_state
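
# Usage sketch (illustrative; the batch size, layer sizes, and the tf.tanh
# factor non-linearity below are assumptions, not part of the original
# module). With factor_size=64, the [384, 2048] gate matrix becomes
# [384, 64] x [64, 2048], cutting its parameter count by roughly 5x.
#
#   batch_size, input_size = 32, 256
#   cell = FLSTMCell(num_units=512, input_size=input_size, num_proj=128,
#                    factor_size=64, fnon_linearity=tf.tanh)
#   inputs = tf.placeholder(tf.float32, [batch_size, input_size])
#   state = tf.zeros([batch_size, cell.state_size])  # [32, 512 + 128]
#   output, new_state = cell(inputs, state)          # output: [32, 128]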