"""Module for constructing RNN Cells.

## RNN Cell wrappers (RNNCells that wrap other RNNCells)

@@ZoneoutWrapper
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.rnn_cell import RNNCell, LSTMStateTuple

__author__ = "Mohammed AlQuraishi"
__copyright__ = "Copyright 2018, Harvard Medical School"
__license__ = "MIT"

class ZoneoutWrapper(RNNCell):
  """Operator adding zoneout to hidden state and memory of the given cell.

  Zoneout mixes the previous and the freshly computed LSTM state. For each of
  the two state components (memory cell `c` and hidden state `h`) the new
  value is used where a random mask is 1 and the old value is carried over
  where the mask is 0. Outside of training the two values are blended
  deterministically, weighted by the keep probability.
  """

  def __init__(self, cell, memory_cell_keep_prob=1.0, hidden_state_keep_prob=1.0,
               seed=None, is_training=True):
    """Create a cell with hidden state and memory zoneout.

    If this class is used to wrap a Dropout cell, then it will override the
    output Dropout but maintain input Dropout (the override only happens when
    at least one zoneout is active, because only then is the output replaced
    by the zoned-out hidden state). If a Dropout cell wraps a Zoneout cell,
    then both Dropout and Zoneout will be applied to the outputs.

    This function assumes that LSTM Cells are using the new tuple-based state.

    Args:
      cell: a BasicLSTMCell or LSTMCell. The type is not checked here; the
        cell only has to accept `(inputs, state, scope)` and use a two-element
        `(c, h)` state.
      memory_cell_keep_prob: unit Tensor or float between 0 and 1, memory cell
        keep probability; if it is float and 1, no zoneout will be added.
      hidden_state_keep_prob: unit Tensor or float between 0 and 1, hidden state
        keep probability; if it is float and 1, no zoneout will be added.
      seed: (optional) integer, the randomness seed.
      is_training: boolean, determines which mode of the zoneout is used.

    Raises:
      ValueError: if memory_cell_keep_prob or hidden_state_keep_prob is a float
        that is not between 0 and 1.
    """
    # Only Python floats can be range-checked at construction time; anything
    # else (e.g. a Tensor) is passed through unchecked.
    for name, keep_prob in (("memory_cell_keep_prob", memory_cell_keep_prob),
                            ("hidden_state_keep_prob", hidden_state_keep_prob)):
      # `not (lo <= x <= hi)` rather than `x < lo or x > hi` so NaN is rejected.
      if isinstance(keep_prob, float) and not 0.0 <= keep_prob <= 1.0:
        # %s (not %d) so the offending value is reported without truncation.
        raise ValueError("Parameter %s must be between 0 and 1: %s"
                         % (name, keep_prob))

    self._cell = cell
    self._memory_cell_keep_prob = memory_cell_keep_prob
    self._hidden_state_keep_prob = hidden_state_keep_prob
    self._seed = seed
    self._is_training = is_training

    # Zoneout is skipped only for a float keep probability of 1 (or more
    # precisely, a float that is not < 1); non-float values always enable it.
    self._has_memory_cell_zoneout = (
        not isinstance(memory_cell_keep_prob, float) or memory_cell_keep_prob < 1)
    self._has_hidden_state_zoneout = (
        not isinstance(hidden_state_keep_prob, float) or hidden_state_keep_prob < 1)

  @property
  def input_size(self):
    """Input size of the wrapped cell."""
    return self._cell.input_size

  @property
  def output_size(self):
    """Output size of the wrapped cell."""
    return self._cell.output_size

  @property
  def state_size(self):
    """State size of the wrapped cell."""
    return self._cell.state_size

  def _zoneout(self, s_old, s_new, keep_prob):
    """Mix one state component's old and new values according to keep_prob."""
    if self._is_training:
      # dropout() rescales the surviving entries by 1 / keep_prob; multiplying
      # by keep_prob turns the result back into a 0/1 mask.
      # TODO: sample the mask with random ops directly instead of via dropout.
      # NOTE(review): the same seed is handed to the memory-cell and the
      # hidden-state mask; confirm whether the two masks are meant to be
      # drawn independently when a seed is set.
      mask = nn_ops.dropout(array_ops.ones_like(s_new), keep_prob,
                            seed=self._seed) * keep_prob
      return ((1. - mask) * s_old) + (mask * s_new)

    # Not training: deterministic blend weighted by the keep probability.
    return ((1. - keep_prob) * s_old) + (keep_prob * s_new)

  def __call__(self, inputs, state, scope=None):
    """Run the cell with the declared zoneouts.

    Args:
      inputs: inputs forwarded unchanged to the wrapped cell.
      state: previous state; must unpack into `(c, h)` when any zoneout is
        active.
      scope: optional scope forwarded to the wrapped cell.

    Returns:
      A pair `(output, new_state)`. Without any active zoneout this is exactly
      what the wrapped cell returned. Otherwise `new_state` is an
      `LSTMStateTuple` of the zoned-out `(c, h)` and `output` is its `h`.
    """
    # Compute output and new state as the wrapped cell normally would.
    output, new_state = self._cell(inputs, state, scope)

    if not (self._has_hidden_state_zoneout or self._has_memory_cell_zoneout):
      return output, new_state

    # Split both states into memory cell (c) and hidden state (h).
    c_old, h_old = state
    c_new, h_new = new_state

    # Memory cell first, then hidden state.
    if self._has_memory_cell_zoneout:
      c = self._zoneout(c_old, c_new, self._memory_cell_keep_prob)
    else:
      c = c_new

    if self._has_hidden_state_zoneout:
      h = self._zoneout(h_old, h_new, self._hidden_state_keep_prob)
    else:
      h = h_new

    # Package final results; the output becomes the zoned-out hidden state.
    new_state = LSTMStateTuple(c, h)
    return new_state.h, new_state