"""
 Reference: https://github.com/iwyoo/ConvLSTMCell-tensorflow
"""
import tensorflow as tf

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import init_ops
from tensorflow.python.ops.rnn_cell import LSTMStateTuple
from tensorflow.python.util import nest

# TODO: Replace the local _is_sequence helper below with tensorflow.python.util.nest.is_sequence.
import collections
import six

def _is_sequence(seq):
  return (isinstance(seq, collections.Sequence)
          and not isinstance(seq, six.string_types))

def ln(input, s, b, epsilon = 1e-5, max = 1000):
    """Normalise `input` over axes 1, 2 and 3, then apply a scale and a shift.

    The mean and variance are computed jointly over axes [1, 2, 3] with the
    reduced dimensions kept, so each index along axis 0 is normalised
    independently of the others.

    Args:
      input: Tensor to normalise; must have at least four axes because axes
        1-3 are reduced.
      s: Scale multiplied onto the normalised tensor.
      b: Shift added after scaling.
      epsilon: Small constant added to the variance before the square root.
      max: Unused; kept only so existing callers keep working.

    Returns:
      `(input - mean) / sqrt(variance + epsilon) * s + b`.
    """
    mean, variance = tf.nn.moments(input, [1, 2, 3], keep_dims=True)
    centered = input - mean
    stddev = tf.sqrt(variance + epsilon)
    scaled = (centered / stddev) * s
    return scaled + b

class ConvGRUCell(rnn_cell.RNNCell):
  """Convolutional Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078).

  The gate and candidate transforms are computed by `_conv`, i.e. a 2-D
  convolution over the inputs and the state concatenated along axis 3.
  """

  def __init__(self, num_units, k_size=3, height=23, width=30, input_size=None, activation=tanh, initializer=None):
    """Create the cell.

    Args:
      num_units: Number of channels of the state and of the output.
      k_size: Kernel size forwarded to `_conv` (the kernel is k_size x k_size).
      height: Height used by `zero_state` to build the initial state.
      width: Width used by `zero_state` to build the initial state.
      input_size: Deprecated; passing a value only logs a warning.
      activation: Activation applied to the candidate state.
      initializer: Initializer forwarded to `_conv` for the kernels.
    """
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation
    self._initializer = initializer
    self._k_size = k_size
    self._height = height
    self._width = width

  @property
  def state_size(self):
    """Number of state channels (`num_units`)."""
    return self._num_units

  @property
  def output_size(self):
    """Number of output channels (`num_units`)."""
    return self._num_units

  def zero_state(self, batch_size=3, dtype=None):
    """Return an all-zero state of shape [batch_size, height, width, num_units].

    Args:
      batch_size: Size of the leading dimension.
      dtype: Element type of the state; float32 when None.
    """
    # The original ignored `dtype`; honour it while keeping float32 as the
    # result for callers that pass nothing.
    if dtype is None:
      dtype = tf.float32
    return tf.zeros([batch_size, self._height, self._width, self._num_units],
                    dtype=dtype)

  def __call__(self, inputs, state, scope=None):
    """Run one step of the convolutional GRU.

    Args:
      inputs: Input tensor; `_conv` requires it to be 4-D with the same
        height and width as `state`.
      state: Previous state tensor.
      scope: Optional variable scope name; defaults to the class name.

    Returns:
      A pair `(new_h, new_h)`: the output and the new state are the same tensor.
    """
    with vs.variable_scope(scope or type(self).__name__):  # "ConvGRUCell"
      with vs.variable_scope("Gates"):  # Reset gate and update gate.
        # NOTE(review): TensorFlow's GRUCell starts the gate bias at 1.0 so
        # the cell initially neither resets nor updates. This call leaves
        # `_conv`'s `bias_start` at its default of 0.0 -- confirm which one
        # is intended before changing the initialisation.
        r, u = array_ops.split(3, 2, _conv([inputs, state],
                                             2 * self._num_units, self._k_size, True, initializer=self._initializer))
        r, u = sigmoid(r), sigmoid(u)
      with vs.variable_scope("Candidate"):
        # The candidate sees the state only through the reset gate.
        c = self._activation(_conv([inputs, r * state],
                                     self._num_units, self._k_size, True, initializer=self._initializer))
      # Blend the previous state and the candidate with the update gate.
      new_h = u * state + (1 - u) * c
    return new_h, new_h

class ConvLSTMCell(rnn_cell.RNNCell):
  """ Convolutional LSTM network cell (ConvLSTM).
  The implementation is based on http://arxiv.org/abs/1506.04214.
   and BasicLSTMCell in TensorFlow.
   https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py

   Future : Peephole connection will be added as the full LSTMCell
            implementation of TensorFlow.
  """
  def __init__(self, num_units, k_size=3, batch_size=4, height=23, width=30, input_size=None,
               use_peepholes=False, cell_clip=None,
               initializer=None, num_proj=None, proj_clip=None,
               num_unit_shards=1, num_proj_shards=1,
               forget_bias=1.0, state_is_tuple=False,
               activation=tanh):
    """Create the cell.

    Args:
      num_units: Number of channels of each of the cell state `c` and the
        hidden state `h`.
      k_size: Kernel size forwarded to `_conv`.
      batch_size: Stored on the instance; not read by any method of this class.
      height: Height used by `zero_state`.
      width: Width used by `zero_state`.
      input_size: Deprecated; passing a value only logs a warning.
      use_peepholes, cell_clip, num_proj, proj_clip, num_unit_shards,
        num_proj_shards: Accepted for signature compatibility; unused.
      initializer: Initializer forwarded to `_conv` for the kernel.
      forget_bias: Constant added to the forget gate before the sigmoid.
      state_is_tuple: If True the state is an `LSTMStateTuple(c, h)`;
        otherwise `c` and `h` are concatenated along axis 3.
      activation: Activation for the new input and for the cell output.
    """
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True." % self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated." % self)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._initializer = initializer
    self._k_size = k_size
    self._height = height
    self._width = width
    self._batch_size = batch_size

  @property
  def state_size(self):
    """`LSTMStateTuple(num_units, num_units)` in tuple mode, else 2 * num_units."""
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    """Number of output channels (`num_units`)."""
    return self._num_units

  def zero_state(self, batch_size=4, dtype=None):
    """Return an all-zero initial state in the layout `__call__` expects.

    Args:
      batch_size: Size of the leading dimension.
      dtype: Element type of the state; float32 when None.

    Returns:
      With `state_is_tuple=True`, an `LSTMStateTuple` of two zero tensors of
      shape [batch_size, height, width, num_units]; otherwise one zero tensor
      of shape [batch_size, height, width, 2 * num_units].
    """
    if dtype is None:
      dtype = tf.float32
    if self._state_is_tuple:
      # `__call__` unpacks the state with `c, h = state` in tuple mode, so a
      # single concatenated tensor (what the original returned) is unusable.
      shape = [batch_size, self._height, self._width, self._num_units]
      return LSTMStateTuple(tf.zeros(shape, dtype=dtype),
                            tf.zeros(shape, dtype=dtype))
    return tf.zeros(
        [batch_size, self._height, self._width, self._num_units * 2],
        dtype=dtype)

  def __call__(self, inputs, state, scope=None):
    """Convolutional Long short-term memory cell (ConvLSTM).

    Args:
      inputs: Input tensor; `_conv` requires it to be 4-D with the same
        height and width as the hidden state.
      state: `(c, h)` pair in tuple mode, else `c` and `h` concatenated
        along axis 3.
      scope: Optional variable scope name; defaults to the class name.

    Returns:
      A pair `(new_h, new_state)` where `new_state` uses the same layout as
      the incoming `state`.
    """
    with vs.variable_scope(scope or type(self).__name__): # "ConvLSTMCell"
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(3, 2, state)

      # batch_size * height * width * channel
      concat = _conv([inputs, h], 4 * self._num_units, self._k_size, True, initializer=self._initializer)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(3, 4, concat)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat(3, [new_c, new_h])
      return new_h, new_state

class LNConvLSTMCell(rnn_cell.RNNCell):
  """ Convolutional LSTM network cell with layer normalisation (LN-ConvLSTM).

  Same recurrence as a ConvLSTM (http://arxiv.org/abs/1506.04214), except
  that the input and the hidden state are convolved separately (without
  bias), each result is normalised with `ln` using its own learned scale and
  shift, and the two are summed to form the gate pre-activations.

   Future : Peephole connection will be added as the full LSTMCell
            implementation of TensorFlow.
  """
  def __init__(self, num_units, k_size=3, batch_size=4, height=23, width=30, input_size=None,
               use_peepholes=False, cell_clip=None,
               initializer=None, num_proj=None, proj_clip=None,
               num_unit_shards=1, num_proj_shards=1,
               forget_bias=1.0, state_is_tuple=False,
               activation=tanh):
    """Create the cell.

    Args:
      num_units: Number of channels of each of the cell state `c` and the
        hidden state `h`.
      k_size: Kernel size forwarded to `_conv`.
      batch_size: Stored on the instance; not read by any method of this class.
      height: Height of the state; also sizes the normalisation variables.
      width: Width of the state; also sizes the normalisation variables.
      input_size: Deprecated; passing a value only logs a warning.
      use_peepholes, cell_clip, num_proj, proj_clip, num_unit_shards,
        num_proj_shards: Accepted for signature compatibility; unused.
      initializer: Initializer forwarded to `_conv` for the kernels.
      forget_bias: Constant added to the forget gate before the sigmoid.
      state_is_tuple: If True the state is an `LSTMStateTuple(c, h)`;
        otherwise `c` and `h` are concatenated along axis 3.
      activation: Activation for the new input and for the cell output.
    """
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True." % self)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated." % self)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._initializer = initializer
    self._k_size = k_size
    self._height = height
    self._width = width
    self._batch_size = batch_size

  @property
  def state_size(self):
    """`LSTMStateTuple(num_units, num_units)` in tuple mode, else 2 * num_units."""
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    """Number of output channels (`num_units`)."""
    return self._num_units

  def zero_state(self, batch_size=4, dtype=None):
    """Return an all-zero initial state in the layout `__call__` expects.

    Args:
      batch_size: Size of the leading dimension.
      dtype: Element type of the state; float32 when None.

    Returns:
      With `state_is_tuple=True`, an `LSTMStateTuple` of two zero tensors of
      shape [batch_size, height, width, num_units]; otherwise one zero tensor
      of shape [batch_size, height, width, 2 * num_units].
    """
    if dtype is None:
      dtype = tf.float32
    if self._state_is_tuple:
      # `__call__` unpacks the state with `c, h = state` in tuple mode, so a
      # single concatenated tensor (what the original returned) is unusable.
      shape = [batch_size, self._height, self._width, self._num_units]
      return LSTMStateTuple(tf.zeros(shape, dtype=dtype),
                            tf.zeros(shape, dtype=dtype))
    return tf.zeros(
        [batch_size, self._height, self._width, self._num_units * 2],
        dtype=dtype)

  def __call__(self, inputs, state, scope=None):
    """Run one step of the layer-normalised convolutional LSTM.

    Args:
      inputs: Input tensor; `_conv` requires it to be 4-D.
      state: `(c, h)` pair in tuple mode, else `c` and `h` concatenated
        along axis 3.
      scope: Optional variable scope name; defaults to the class name.

    Returns:
      A pair `(new_h, new_state)` where `new_state` uses the same layout as
      the incoming `state`.
    """
    with vs.variable_scope(scope or type(self).__name__): # "LNConvLSTMCell"
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(3, 2, state)

      # Learned scale (s*) and shift (b*) for the two normalised branches:
      # index 1 is the input branch, index 2 the hidden-state branch.
      norm_shape = [self._height, self._width, 4 * self._num_units]
      s1 = vs.get_variable("s1", initializer=tf.ones(norm_shape), dtype=tf.float32)
      s2 = vs.get_variable("s2", initializer=tf.ones(norm_shape), dtype=tf.float32)

      b1 = vs.get_variable("b1", initializer=tf.zeros(norm_shape), dtype=tf.float32)
      b2 = vs.get_variable("b2", initializer=tf.zeros(norm_shape), dtype=tf.float32)

      # The convolutions carry no bias; b1/b2 supply the shift instead.
      input_below_ = _conv([inputs], 4 * self._num_units, self._k_size, False, initializer=self._initializer, scope="out_1")
      input_below_ = ln(input_below_, s1, b1)
      state_below_ = _conv([h], 4 * self._num_units, self._k_size, False, initializer=self._initializer, scope="out_2")
      state_below_ = ln(state_below_, s2, b2)
      lstm_matrix = tf.add(input_below_, state_below_)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(3, 4, lstm_matrix)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat(3, [new_c, new_h])
      return new_h, new_state

class MultiRNNCell(rnn_cell.RNNCell):
  """Stack of convolutional RNN cells applied one after another.

  In non-tuple mode the state is a single tensor holding one per-layer state
  per index of axis 0; in tuple mode it is a tuple with one entry per cell.
  """

  def __init__(self, cells, state_is_tuple=False):
    """
      Stacked convLSTM , modified from ops.rnn_cell MultiRNNCell

    Args:
      cells: Non-empty list of cells, applied in order.
      state_is_tuple: If True, states are tuples with one entry per cell;
        otherwise the per-cell states are packed into one tensor.

    Raises:
      ValueError: If `cells` is empty, or if a cell has a tuple-valued
        `state_size` while `state_is_tuple` is False.
    """
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    self._cells = cells
    self._state_is_tuple = state_is_tuple
    self._num_units = cells[0].output_size
    if not state_is_tuple:
      if any(nest.is_sequence(c.state_size) for c in self._cells):
        raise ValueError("Some cells return tuples of states, but the flag "
                         "state_is_tuple is not set.  State sizes are: %s"
                         % str([c.state_size for c in self._cells]))

  @property
  def state_size(self):
    """Tuple of the cells' state sizes in tuple mode, else their sum."""
    if self._state_is_tuple:
      return tuple(cell.state_size for cell in self._cells)
    else:
      return sum([cell.state_size for cell in self._cells])

  @property
  def output_size(self):
    """Output size of the last cell in the stack."""
    return self._cells[-1].output_size

  def zero_state(self, batch_size=3, dtype=None, height=23, width=30):
    """Return an all-zero initial state in the layout `__call__` expects.

    Args:
      batch_size: Size of the batch dimension of every per-cell state.
      dtype: Element type of the state; float32 when None.
      height: Height of every per-cell state.
      width: Width of every per-cell state.

    Returns:
      In tuple mode, a tuple with one entry per cell, each mirroring the
      structure of that cell's `state_size` with zero tensors of shape
      [batch_size, height, width, size]. Otherwise a single zero tensor of
      shape [num_cells, batch_size, height, width, state_size], where
      `state_size` is taken from the first cell.
    """
    if dtype is None:
      dtype = tf.float32

    def _zeros(channels):
      return tf.zeros([batch_size, height, width, channels], dtype=dtype)

    if self._state_is_tuple:
      # The original called `tf.zeros(1, batch_size, height, width)`, which
      # is not a valid call. Build one state per cell that follows the cell's
      # own state_size structure (e.g. an LSTMStateTuple).
      states = []
      for cell in self._cells:
        size = cell.state_size
        if nest.is_sequence(size):
          states.append(nest.pack_sequence_as(
              size, [_zeros(n) for n in nest.flatten(size)]))
        else:
          states.append(_zeros(size))
      return tuple(states)

    # __init__ guarantees every state_size is a plain integer here. Using the
    # cell's own state_size instead of a hard-coded `num_units * 2` gives the
    # same result for concatenated-state LSTM cells and the right one for
    # cells whose state has `num_units` channels. Packing requires all cells
    # to share one state shape, so the first cell's size is used for all.
    return tf.zeros([len(self._cells), batch_size, height, width,
                     self._cells[0].state_size], dtype=dtype)

  def __call__(self, inputs, state, scope=None):
    """Run this multi-layer cell on inputs, starting from state.

    Args:
      inputs: Input to the first cell.
      state: Tuple with one state per cell in tuple mode, else a tensor whose
        axis 0 indexes the cells.
      scope: Optional variable scope name; defaults to the class name.

    Returns:
      A pair `(output, new_state)`: the output of the last cell and the new
      states in the same layout as `state`.

    Raises:
      ValueError: In tuple mode, if `state` is not a sequence.
    """
    with vs.variable_scope(scope or type(self).__name__):  # "MultiRNNCell"
      if self._state_is_tuple:
        if not nest.is_sequence(state):
          raise ValueError(
              "Expected state to be a tuple of length %d, but received: %s"
              % (len(self.state_size), state))
        layer_states = state
      else:
        # Unpack once instead of once per layer.
        layer_states = array_ops.unpack(state)

      cur_inp = inputs
      new_states = []
      for i, cell in enumerate(self._cells):
        with vs.variable_scope("Cell%d" % i):
          cur_inp, new_state = cell(cur_inp, layer_states[i])
          new_states.append(new_state)

    # Return the state in the layout it was received in, so the result can be
    # fed straight back in as the next step's `state`.
    if self._state_is_tuple:
      return cur_inp, tuple(new_states)
    return cur_inp, array_ops.pack(new_states)

def _conv(args, output_size, k_size, bias=True, bias_start=0.0, initializer=None, scope=None):
  if args is None or (_is_sequence(args) and not args):
    raise ValueError("`args` must be specified")
  if not _is_sequence(args):
    args = [args]

  # Calculate the total size of arguments on dimension 3.
  # (batch_size x height x width x arg_size)
  total_arg_size = 0
  shapes = [a.get_shape().as_list() for a in args]
  height = shapes[0][1]
  width  = shapes[0][2]
  for shape in shapes:
    if len(shape) != 4:
      raise ValueError("Conv is expecting 3D arguments: %s" % str(shapes))
    if not shape[3]:
      raise ValueError("Conv expects shape[3] of arguments: %s" % str(shapes))
    if shape[1] == height and shape[2] == width:
      total_arg_size += shape[3]
    else :
      raise ValueError("Inconsistent height and width size in arguments: %s" % str(shapes))

  with vs.variable_scope(scope or "Conv"):
    kernel = vs.get_variable("Kernel", [k_size, k_size, total_arg_size, output_size], initializer=initializer)

    if len(args) == 1:
      res = tf.nn.conv2d(args[0], kernel, [1, 1, 1, 1], padding='SAME')
    else:
      res = tf.nn.conv2d(array_ops.concat(3, args), kernel, [1, 1, 1, 1], padding='SAME')

    if not bias: return res
    bias_term = vs.get_variable( "Bias", [output_size],
      initializer=init_ops.constant_initializer(bias_start))
  return res + bias_term