# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module for constructing fused RNN cells."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import rnn


# `__metaclass__ = abc.ABCMeta` inside a class body is honoured only by
# Python 2; Python 3 silently ignores it, which leaves `@abc.abstractmethod`
# unenforced. Deriving from a base class created by calling the metaclass
# directly attaches `ABCMeta` under both interpreters.
class FusedRNNCell(abc.ABCMeta("_FusedRNNCellBase", (object,), {})):
  """Abstract object representing a fused RNN cell.

  A fused RNN cell represents the entire RNN expanded over the time
  dimension. In effect, this represents an entire recurrent network.

  Unlike RNN cells which are subclasses of `rnn_cell.RNNCell`, a `FusedRNNCell`
  operates on the entire time sequence at once, by putting the loop over time
  inside the cell. This usually leads to much more efficient, but more complex
  and less flexible implementations.

  Every `FusedRNNCell` must implement `__call__` with the following signature.
  Instantiating a class that does not override `__call__` raises `TypeError`.
  """

  @abc.abstractmethod
  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    """Run this fused RNN on inputs, starting from the given state.

    Args:
      inputs: `3-D` tensor with shape `[time_len x batch_size x input_size]`
        or a list of `time_len` tensors of shape `[batch_size x input_size]`.
      initial_state: either a tensor with shape `[batch_size x state_size]`
        or a tuple with shapes `[batch_size x s] for s in state_size`, if the
        cell takes tuples. If this is not provided, the cell is expected to
        create a zero initial state of type `dtype`.
      dtype: The data type for the initial state and expected output. Required
        if `initial_state` is not provided or RNN state has a heterogeneous
        dtype.
      sequence_length: Specifies the length of each sequence in inputs. An
        `int32` or `int64` vector (tensor) size `[batch_size]`, values in
        `[0, time_len]`. Defaults to `time_len` for each element.
      scope: `VariableScope` or `string` for the created subgraph; defaults to
        class name.

    Returns:
      A pair containing:

      - Output: A `3-D` tensor of shape `[time_len x batch_size x output_size]`
        or a list of `time_len` tensors of shape `[batch_size x output_size]`,
        to match the type of the `inputs`.
      - Final state: Either a single `2-D` tensor, or a tuple of tensors
        matching the arity and shapes of `initial_state`.
    """
    pass


class FusedRNNCellAdaptor(FusedRNNCell):
  """Exposes a step-wise `RNNCell` through the `FusedRNNCell` call interface.

  The wrapped cell is unrolled over time with either `rnn.dynamic_rnn` or
  `rnn.rnn`, depending on how the adaptor was constructed.
  """

  def __init__(self, cell, use_dynamic_rnn=False):
    """Store the wrapped cell and the unrolling strategy.

    Args:
      cell: an instance of a subclass of a `rnn_cell.RNNCell`.
      use_dynamic_rnn: if true, unroll with `rnn.dynamic_rnn`; otherwise unroll
        with the static `rnn.rnn`.
    """
    self._use_dynamic_rnn = use_dynamic_rnn
    self._cell = cell

  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    """Unroll the wrapped cell over `inputs`.

    `inputs` may be a single time-major tensor or a Python list holding one
    tensor per time step. It is converted to whatever form the selected RNN
    driver consumes, and the outputs are converted back so that they come out
    in the same form (tensor or list) as `inputs`.

    Args:
      inputs: a time-major tensor, or a list of per-time-step tensors.
      initial_state: forwarded unchanged to the RNN driver.
      dtype: forwarded unchanged to the RNN driver.
      sequence_length: forwarded unchanged to the RNN driver.
      scope: forwarded unchanged to the RNN driver.

    Returns:
      A pair `(outputs, state)` as produced by the RNN driver, with `outputs`
      in the same form (tensor or list) as `inputs`.
    """
    given_as_list = isinstance(inputs, list)

    if not self._use_dynamic_rnn:
      # Static unrolling consumes one tensor per time step.
      steps = inputs if given_as_list else array_ops.unpack(inputs)
      step_outputs, final_state = rnn.rnn(
          self._cell,
          steps,
          initial_state=initial_state,
          dtype=dtype,
          sequence_length=sequence_length,
          scope=scope)
      if given_as_list:
        return step_outputs, final_state
      # The caller passed a tensor, so hand a tensor back.
      return array_ops.pack(step_outputs), final_state

    # Dynamic unrolling consumes a single time-major tensor.
    stacked = array_ops.pack(inputs) if given_as_list else inputs
    stacked_outputs, final_state = rnn.dynamic_rnn(
        self._cell,
        stacked,
        sequence_length=sequence_length,
        initial_state=initial_state,
        dtype=dtype,
        time_major=True,
        scope=scope)
    if given_as_list:
      # The caller passed a list, so hand a list back.
      return array_ops.unpack(stacked_outputs), final_state
    return stacked_outputs, final_state


class TimeReversedFusedRNN(FusedRNNCell):
  """This is an adaptor to time-reverse a FusedRNNCell.

  The inputs are reversed along the time dimension, fed to the wrapped cell,
  and the resulting outputs are reversed back. When `sequence_length` is
  given, only the first `sequence_length[b]` steps of batch entry `b` are
  reversed, so padding steps stay at the end of each sequence.

  For example,

  ```python
  cell = tf.nn.rnn_cell.BasicRNNCell(10)
  fw_rnn = tf.contrib.rnn.FusedRNNCellAdaptor(cell, use_dynamic_rnn=True)
  bw_rnn = tf.contrib.rnn.TimeReversedFusedRNN(fw_rnn)
  fw_out, fw_state = fw_rnn(inputs)
  bw_out, bw_state = bw_rnn(inputs)
  ```
  """

  def __init__(self, cell):
    """Initialize the adaptor.

    Args:
      cell: the fused cell to wrap; it is called with the same arguments as
        `FusedRNNCell.__call__`.
    """
    self._cell = cell

  def _reverse(self, t, lengths):
    """Time reverse the provided tensor or list of tensors.

    Assumes the top dimension is the time dimension.

    Args:
      t: 3D tensor or list of 2D tensors to be reversed
      lengths: 1D tensor of lengths, or `None`

    Returns:
      A reversed tensor or list of tensors, in the same form as `t`. If
      `lengths` is given, each batch entry is reversed only over its own
      length.
    """
    if isinstance(t, list):
      if lengths is None or not t:
        return list(reversed(t))
      # Flipping the Python list would reverse every sequence over the full
      # `time_len`, moving the padding of shorter sequences to the front.
      # Pack the steps into one tensor so the per-sequence reversal can be
      # applied, then return to the list form the caller passed in.
      packed = array_ops.pack(t)
      return array_ops.unpack(
          array_ops.reverse_sequence(packed, lengths, 0, 1))
    if lengths is None:
      return array_ops.reverse(t, [True, False, False])
    # Time is dimension 0 and batch is dimension 1.
    return array_ops.reverse_sequence(t, lengths, 0, 1)

  def __call__(self,
               inputs,
               initial_state=None,
               dtype=None,
               sequence_length=None,
               scope=None):
    """Run the wrapped cell on time-reversed inputs.

    Args:
      inputs: 3D time-major tensor or list of 2D per-time-step tensors.
      initial_state: forwarded unchanged to the wrapped cell.
      dtype: forwarded unchanged to the wrapped cell.
      sequence_length: 1D tensor of lengths, or `None`; used for the reversal
        and also forwarded to the wrapped cell.
      scope: forwarded unchanged to the wrapped cell.

    Returns:
      A pair `(outputs, state)`: the wrapped cell's outputs reversed back so
      they line up with the original time steps, and the wrapped cell's final
      state.
    """
    inputs = self._reverse(inputs, sequence_length)
    outputs, state = self._cell(
        inputs,
        initial_state=initial_state,
        dtype=dtype,
        sequence_length=sequence_length,
        scope=scope)
    # Undo the reversal so output step i corresponds to input step i.
    outputs = self._reverse(outputs, sequence_length)
    return outputs, state