import tensorflow as tf
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import init_ops
from tensorflow.python.ops.rnn_cell import LSTMStateTuple

# Future : Replace it with tensorflow.python.util.nest
import collections
import six

# `collections.Sequence` moved to `collections.abc` in Python 3.3 and the
# top-level alias was removed in Python 3.10; fall back for Python 2.
try:
    _collections_abc = collections.abc
except AttributeError:
    _collections_abc = collections


def _is_sequence(seq):
    """Return True if `seq` is a sequence (list, tuple, ...) but not a string."""
    return (isinstance(seq, _collections_abc.Sequence)
            and not isinstance(seq, six.string_types))


class ConvLSTMCell(rnn_cell.RNNCell):
    """Convolutional LSTM network cell (ConvLSTM).

    The implementation is based on http://arxiv.org/abs/1506.04214
    and BasicLSTMCell in TensorFlow.
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py

    States and inputs are 4-D tensors of shape
    (batch_size, height, width, channels); the gates are computed with a
    shared 2-D convolution instead of a matrix multiply.

    Future : Peephole connection will be added as the full LSTMCell
    implementation of TensorFlow.
    """

    def __init__(self, num_units, input_size=None,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=1, num_proj_shards=1,
                 forget_bias=1.0, state_is_tuple=False,
                 activation=tanh):
        """Initialize the ConvLSTM cell.

        Args:
          num_units: int, number of output channels of the cell.
          input_size: Deprecated; only triggers a warning if given.
          use_peepholes, cell_clip, initializer, num_proj, proj_clip,
            num_unit_shards, num_proj_shards: accepted for signature
            compatibility with LSTMCell but currently unused
            (peephole/projection support is future work).
          forget_bias: float, bias added to the forget gate so that
            forgetting is discouraged early in training.
          state_is_tuple: if True, states are `LSTMStateTuple(c, h)`;
            otherwise a single tensor concatenated along the channel axis.
          activation: activation function for the new input and cell output.
        """
        # if not state_is_tuple:
        #   logging.warn(
        #       "%s: Using a concatenated state is slower and will soon be "
        #       "deprecated. Use state_is_tuple=True." % self)
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated." % self)
        self._num_units = num_units
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation

    @property
    def state_size(self):
        # Channel sizes only; the spatial dimensions are determined by the
        # inputs at call time.
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units

    def zero_state(self, batch_size, height, width):
        """Return an all-zeros concatenated (c, h) state tensor.

        NOTE(review): signature differs from RNNCell.zero_state(batch_size,
        dtype) because the spatial shape is needed; it always returns the
        concatenated layout regardless of `state_is_tuple`.
        """
        return tf.zeros([batch_size, height, width, self._num_units * 2])

    def __call__(self, inputs, state, k_size=3, scope=None):
        """Run one step of the convolutional LSTM.

        Args:
          inputs: 4-D tensor (batch_size, height, width, channels).
          state: cell state — `LSTMStateTuple(c, h)` if `state_is_tuple`,
            otherwise a single tensor with c and h concatenated on axis 3.
          k_size: convolution kernel size for the gate computation.
          scope: optional variable scope (defaults to the class name).

        Returns:
          A `(new_h, new_state)` pair, with `new_state` in the same layout
          as `state`.
        """
        with vs.variable_scope(scope or type(self).__name__):  # "ConvLSTMCell"
            if self._state_is_tuple:
                c, h = state
            else:
                # batch_size * height * width * channel
                c, h = array_ops.split(3, 2, state)

            # Compute all four gates with a single convolution over the
            # concatenated input and previous hidden state.
            concat = _conv([inputs, h], 4 * self._num_units, k_size, True)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = array_ops.split(3, 4, concat)

            new_c = (c * sigmoid(f + self._forget_bias) +
                     sigmoid(i) * self._activation(j))
            new_h = self._activation(new_c) * sigmoid(o)

            if self._state_is_tuple:
                new_state = LSTMStateTuple(new_c, new_h)
            else:
                new_state = array_ops.concat(3, [new_c, new_h])
            return new_h, new_state


def _conv(args, output_size, k_size, bias=True, bias_start=0.0, scope=None):
    """Convolve the concatenation of `args` to `output_size` channels.

    Args:
      args: a 4-D tensor or a list of 4-D tensors, each shaped
        (batch_size, height, width, channels); all must share the same
        height and width.
      output_size: int, number of output channels.
      k_size: int, square convolution kernel size.
      bias: whether to add a bias term.
      bias_start: initial value for the bias.
      scope: optional variable scope (defaults to "Conv").

    Returns:
      A 4-D tensor (batch_size, height, width, output_size).

    Raises:
      ValueError: if `args` is empty, an argument is not 4-D, has an
        unknown channel dimension, or spatial sizes are inconsistent.
    """
    if args is None or (_is_sequence(args) and not args):
        raise ValueError("`args` must be specified")
    if not _is_sequence(args):
        args = [args]

    # Calculate the total size of arguments on dimension 3.
    # (batch_size x height x width x arg_size)
    total_arg_size = 0
    shapes = [a.get_shape().as_list() for a in args]
    height = shapes[0][1]
    width = shapes[0][2]
    for shape in shapes:
        if len(shape) != 4:
            # Inputs must be (batch, height, width, channels).
            raise ValueError("Conv is expecting 4D arguments: %s" % str(shapes))
        if not shape[3]:
            raise ValueError("Conv expects shape[3] of arguments: %s" % str(shapes))
        if shape[1] == height and shape[2] == width:
            total_arg_size += shape[3]
        else:
            raise ValueError("Inconsistent height and width size in arguments: %s" % str(shapes))

    with vs.variable_scope(scope or "Conv"):
        kernel = vs.get_variable("Kernel",
                                 [k_size, k_size, total_arg_size, output_size])
        if len(args) == 1:
            res = tf.nn.conv2d(args[0], kernel, [1, 1, 1, 1], padding='SAME')
        else:
            # Concatenate on the channel axis, then apply one shared kernel.
            res = tf.nn.conv2d(array_ops.concat(3, args), kernel,
                               [1, 1, 1, 1], padding='SAME')
        if not bias:
            return res
        bias_term = vs.get_variable(
            "Bias", [output_size],
            initializer=init_ops.constant_initializer(bias_start))
    return res + bias_term