from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest

def attention_decoder_fn_train(encoder_state,
                               attention_keys,
                               attention_values,
                               attention_score_fn,
                               attention_construct_fn,
                               output_alignments=False,
                               max_length=None,
                               name=None):
    """Attentional decoder function for `dynamic_rnn_decoder` during training.

    The `attention_decoder_fn_train` is a training function for an
    attention-based sequence-to-sequence model. It should be used when
    `dynamic_rnn_decoder` is in the training mode.

    The `attention_decoder_fn_train` is called with a set of the user arguments
    and returns the `decoder_fn`, which can be passed to the
    `dynamic_rnn_decoder`, such that

    ```
    dynamic_fn_train = attention_decoder_fn_train(encoder_state)
    outputs_train, state_train = dynamic_rnn_decoder(
            decoder_fn=dynamic_fn_train, ...)
    ```

    Further usage can be found in the `kernel_tests/seq2seq_test.py`.
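
    A minimal sketch of how this is typically wired together with
    `prepare_attention` from this module (`encoder_outputs`, `num_units`, and the
    other lowercase variables are placeholders, not part of the API):

    ```
    attention_keys, attention_values, attention_score_fn, attention_construct_fn = (
            prepare_attention(encoder_outputs, "bahdanau", num_units))
    decoder_fn_train = attention_decoder_fn_train(
            encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn)
    outputs_train, state_train = dynamic_rnn_decoder(
            decoder_fn=decoder_fn_train, ...)
    ```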

    Args:
        encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
        attention_keys: to be compared with target states.
        attention_values: to be used to construct context vectors.
        attention_score_fn: to compute similarity between key and target states.
        attention_construct_fn: to build attention states.
        output_alignments: (default: `False`) whether to also collect attention
            alignments; when `True`, the alignments of each step are written to a
            `TensorArray` carried in the decoder's context state.
        max_length: (default: `None`) initial size of the alignments `TensorArray`;
            only used when `output_alignments` is `True`.
        name: (default: `None`) NameScope for the decoder function;
            defaults to "attention_decoder_fn_train"

    Returns:
        A decoder function with the required interface of `dynamic_rnn_decoder`
        intended for training.
    """
    with ops.name_scope(name, "attention_decoder_fn_train", [
            encoder_state, attention_keys, attention_values, attention_score_fn,
            attention_construct_fn
    ]):
        pass

    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for training.

        Args:
            time: positive integer constant reflecting the current timestep.
            cell_state: state of RNNCell.
            cell_input: input provided by `dynamic_rnn_decoder`.
            cell_output: output of RNNCell.
            context_state: context state provided by `dynamic_rnn_decoder`.

        Returns:
            A tuple (done, next state, next input, emit output, next context state)
            where:

            done: `None`, which is used by the `dynamic_rnn_decoder` to indicate
            that `sequence_lengths` in `dynamic_rnn_decoder` should be used.

            next state: `cell_state`, this decoder function does not modify the
            given state.

            next input: the concatenation of `cell_input` and the attention vector
            computed for this step.

            emit output: on the first call the given `cell_output` is passed through
            unchanged; on later calls the attention-constructed vector is emitted.

            next context state: `context_state`, unless `output_alignments` is `True`,
            in which case it is a `TensorArray` accumulating the attention alignments
            of each step.
        """
        with ops.name_scope(
                name, "attention_decoder_fn_train",
                [time, cell_state, cell_input, cell_output, context_state]):
            if cell_state is None:    # first call, return encoder_state
                cell_state = encoder_state

                # init attention
                attention = _init_attention(encoder_state)
                if output_alignments:
                    context_state = tensor_array_ops.TensorArray(
                            dtype=dtypes.float32, tensor_array_name="alignments_ta",
                            size=max_length, dynamic_size=True, infer_shape=False)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys, attention_values)
                if output_alignments:
                    attention, alignments = attention
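                    # store this step's alignments in the TensorArray carried
                    # through the decoder's context state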
                    context_state = context_state.write(time-1, alignments)

                cell_output = attention

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            return (None, cell_state, next_input, cell_output, context_state)

    return decoder_fn


def attention_decoder_fn_inference(output_fn,
                                   encoder_state,
                                   attention_keys,
                                   attention_values,
                                   attention_score_fn,
                                   attention_construct_fn,
                                   embeddings,
                                   start_of_sequence_id,
                                   end_of_sequence_id,
                                   maximum_length,
                                   num_decoder_symbols,
                                   dtype=dtypes.int32,
                                   selector_fn=None,
                                   imem=None,
                                   name=None):
    """Attentional decoder function for `dynamic_rnn_decoder` during inference.

    The `attention_decoder_fn_inference` is a simple inference function for a
    sequence-to-sequence model. It should be used when `dynamic_rnn_decoder` is
    in the inference mode.

    The `attention_decoder_fn_inference` is called with user arguments
    and returns the `decoder_fn`, which can be passed to the
    `dynamic_rnn_decoder`, such that

    ```
    dynamic_fn_inference = attention_decoder_fn_inference(...)
    outputs_inference, state_inference = dynamic_rnn_decoder(
            decoder_fn=dynamic_fn_inference, ...)
    ```

    Further usage can be found in the `kernel_tests/seq2seq_test.py`.
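
    A minimal sketch of a greedy-decoding setup using the helpers in this module
    (`encoder_outputs`, `num_units`, and the id/length constants are placeholders,
    not part of the API):

    ```
    attention_keys, attention_values, attention_score_fn, attention_construct_fn = (
            prepare_attention(encoder_outputs, "bahdanau", num_units))
    decoder_fn_inference = attention_decoder_fn_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, embeddings,
            start_of_sequence_id, end_of_sequence_id, maximum_length,
            num_decoder_symbols)
    outputs_inference, state_inference = dynamic_rnn_decoder(
            decoder_fn=decoder_fn_inference, ...)
    ```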

    Args:
        output_fn: An output function to project your `cell_output` onto class
        logits.

        An example of an output function;

        ```
            tf.variable_scope("decoder") as varscope
                output_fn = lambda x: layers.linear(x, num_decoder_symbols,
                                                                                        scope=varscope)

                outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
                logits_train = output_fn(outputs_train)

                varscope.reuse_variables()
                logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
                        output_fn=output_fn, ...)
        ```

        If `None` is supplied it will act as an identity function, which
        might be wanted when using the RNNCell `OutputProjectionWrapper`.

        encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
        attention_keys: to be compared with target states.
        attention_values: to be used to construct context vectors.
        attention_score_fn: to compute similarity between key and target states.
        attention_construct_fn: to build attention states.
        embeddings: The embeddings matrix used for the decoder sized
        `[num_decoder_symbols, embedding_size]`.
        start_of_sequence_id: The start of sequence ID in the decoder embeddings.
        end_of_sequence_id: The end of sequence ID in the decoder embeddings.
        maximum_length: The maximum allowed number of time steps to decode.
        num_decoder_symbols: The number of classes to decode at each time step.
        dtype: (default: `dtypes.int32`) The default data type to use when
        handling integer objects.
        selector_fn: (default: `None`) optional function mapping `cell_output` to a
            gating value used to choose between generating a vocabulary word and
            copying an entity from `imem`.
        imem: (default: `None`) optional external memory of entity vectors to copy
            from; when provided, the chosen output ids of each step are written to a
            `TensorArray` carried in the context state.
        name: (default: `None`) NameScope for the decoder function;
            defaults to "attention_decoder_fn_inference"

    Returns:
        A decoder function with the required interface of `dynamic_rnn_decoder`
        intended for inference.
    """
    with ops.name_scope(name, "attention_decoder_fn_inference", [
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, embeddings, imem,
            start_of_sequence_id, end_of_sequence_id, maximum_length,
            num_decoder_symbols, dtype
    ]):
        start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
        end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
        maximum_length = ops.convert_to_tensor(maximum_length, dtype)
        num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
        encoder_info = nest.flatten(encoder_state)[0]
        batch_size = encoder_info.get_shape()[0].value
        if output_fn is None:
            output_fn = lambda x: x
        if batch_size is None:
            batch_size = array_ops.shape(encoder_info)[0]

    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for inference.

        The main difference between this decoder function and the `decoder_fn` in
        `attention_decoder_fn_train` is how `next_cell_input` is calculated. In this
        decoder function we calculate the next input by applying an argmax across
        the feature dimension of the output from the decoder. This is a
        greedy-search approach. (Bahdanau et al., 2014) & (Sutskever et al., 2014)
        use beam-search instead.

        Args:
            time: positive integer constant reflecting the current timestep.
            cell_state: state of RNNCell.
            cell_input: input provided by `dynamic_rnn_decoder`.
            cell_output: output of RNNCell.
            context_state: context state provided by `dynamic_rnn_decoder`.

        Returns:
            A tuple (done, next state, next input, emit output, next context state)
            where:

            done: A boolean vector to indicate which sentences have reached an
            `end_of_sequence_id`. This is used for early stopping by the
            `dynamic_rnn_decoder`. When `time>=maximum_length` a boolean vector with
            all elements as `true` is returned.

            next state: `cell_state`, this decoder function does not modify the
            given state.

            next input: the embedding of the chosen output id (or the corresponding
            entity vector from `imem`) concatenated with the attention vector computed
            for this step.

            emit output: If `output_fn is None` the supplied `cell_output` is
            returned, else the `output_fn` is used to update the `cell_output`
            before calculating `next_input` and returning `cell_output`.

            next context state: `context_state`, unless `imem` is given, in which case
            it is a `TensorArray` accumulating the chosen output ids (entity picks are
            stored as negative indices).

        Raises:
            ValueError: if cell_input is not None.

        """
        with ops.name_scope(
                name, "attention_decoder_fn_inference",
                [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                raise ValueError("Expected cell_input to be None, but saw: %s" %
                                                 cell_input)
            if cell_output is None:
                # invariant that this is time == 0
                next_input_id = array_ops.ones(
                        [batch_size,], dtype=dtype) * (start_of_sequence_id)
                done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
                cell_state = encoder_state
                cell_output = array_ops.zeros(
                        [num_decoder_symbols], dtype=dtypes.float32)
                cell_input = array_ops.gather(embeddings, next_input_id)

                # init attention
                attention = _init_attention(encoder_state)
                if imem is not None:
                    context_state = tensor_array_ops.TensorArray(
                            dtype=dtypes.int32, tensor_array_name="output_ids_ta",
                            size=maximum_length, dynamic_size=True, infer_shape=False)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys,
                                                                                     attention_values)
                if type(attention) is tuple:
                    # attention_construct_fn returned (attention, alignment over imem)
                    attention, alignment = attention
                    cell_output = attention
                    alignment = tf.reshape(alignment, [batch_size, -1])
                    # selector gates between generating a vocabulary word and copying
                    # an entity from the external memory `imem`
                    selector = selector_fn(cell_output)
                    logit = output_fn(cell_output)
                    word_prob = nn_ops.softmax(logit) * (1 - selector)
                    entity_prob = alignment * selector
                    # mask is 1 where the best vocabulary word beats the best entity
                    mask = array_ops.reshape(
                            math_ops.cast(
                                    math_ops.greater(tf.reduce_max(word_prob, 1),
                                                     tf.reduce_max(entity_prob, 1)),
                                    dtype=dtypes.float32), [-1, 1])
                    word_input = array_ops.gather(
                            embeddings,
                            math_ops.cast(math_ops.argmax(word_prob, 1), dtype=dtype))
                    entity_indices = array_ops.concat(
                            [array_ops.reshape(math_ops.range(batch_size, dtype=dtype), [-1, 1]),
                             array_ops.reshape(
                                     math_ops.cast(math_ops.argmax(entity_prob, 1), dtype=dtype),
                                     [-1, 1])], axis=1)
                    entity_input = array_ops.gather_nd(imem, entity_indices)
                    cell_input = mask * word_input + (1 - mask) * entity_input

                    # record the chosen id: vocabulary ids stay positive, entity picks
                    # are stored as negative indices ((mask - 1) is -1 for entities)
                    mask = array_ops.reshape(math_ops.cast(mask, dtype=dtype), [-1])
                    input_id = (mask * math_ops.cast(math_ops.argmax(word_prob, 1), dtype=dtype)
                                            + (mask - 1) * math_ops.cast(math_ops.argmax(entity_prob, 1), dtype=dtype))
                    context_state = context_state.write(time-1, input_id)
                    done = array_ops.reshape(math_ops.equal(input_id, end_of_sequence_id), [-1])
                    cell_output = logit

                else:
                    cell_output = attention

                    # argmax decoder
                    cell_output = output_fn(cell_output)    # logits
                    next_input_id = math_ops.cast(
                            math_ops.argmax(cell_output, 1), dtype=dtype)
                    done = math_ops.equal(next_input_id, end_of_sequence_id)
                    cell_input = array_ops.gather(embeddings, next_input_id)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            # if time > maxlen, return all true vector
            done = control_flow_ops.cond(
                    math_ops.greater(time, maximum_length),
                    lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
                    lambda: done)
            return (done, cell_state, next_input, cell_output, context_state)

    return decoder_fn

def attention_decoder_fn_beam_inference(output_fn,
                                        encoder_state,
                                        attention_keys,
                                        attention_values,
                                        attention_score_fn,
                                        attention_construct_fn,
                                        embeddings,
                                        start_of_sequence_id,
                                        end_of_sequence_id,
                                        maximum_length,
                                        num_decoder_symbols,
                                        beam_size,
                                        remove_unk=False,
                                        d_rate=0.0,
                                        dtype=dtypes.int32,
                                        name=None):
    """Attentional decoder function for `dynamic_rnn_decoder` during inference.
    The `attention_decoder_fn_beam_inference` is a beam-search inference function
    for a sequence-to-sequence model. It should be used when `dynamic_rnn_decoder`
    is in the inference mode.
    The `attention_decoder_fn_beam_inference` is called with user arguments
    and returns the `decoder_fn`, which can be passed to the
    `dynamic_rnn_decoder`, such that
    ```
    dynamic_fn_inference = attention_decoder_fn_beam_inference(...)
    outputs_inference, state_inference = dynamic_rnn_decoder(
            decoder_fn=dynamic_fn_inference, ...)
    ```
    Further usage can be found in the `kernel_tests/seq2seq_test.py`.
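    A minimal sketch of the call itself (all variable names below are placeholders,
    not part of the API):
    ```
    decoder_fn_beam = attention_decoder_fn_beam_inference(
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, embeddings,
            start_of_sequence_id, end_of_sequence_id, maximum_length,
            num_decoder_symbols, beam_size)
    ```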
    Args:
        output_fn: An output function to project your `cell_output` onto class
        logits.
        An example of an output function;
        ```
            tf.variable_scope("decoder") as varscope
                output_fn = lambda x: layers.linear(x, num_decoder_symbols,
                                                                                        scope=varscope)
                outputs_train, state_train = seq2seq.dynamic_rnn_decoder(...)
                logits_train = output_fn(outputs_train)
                varscope.reuse_variables()
                logits_inference, state_inference = seq2seq.dynamic_rnn_decoder(
                        output_fn=output_fn, ...)
        ```
        If `None` is supplied it will act as an identity function, which
        might be wanted when using the RNNCell `OutputProjectionWrapper`.
        encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
        attention_keys: to be compared with target states.
        attention_values: to be used to construct context vectors.
        attention_score_fn: to compute similarity between key and target states.
        attention_construct_fn: to build attention states.
        embeddings: The embeddings matrix used for the decoder sized
        `[num_decoder_symbols, embedding_size]`.
        start_of_sequence_id: The start of sequence ID in the decoder embeddings.
        end_of_sequence_id: The end of sequence ID in the decoder embeddings.
        maximum_length: The maximum allowed number of time steps to decode.
        num_decoder_symbols: The number of classes to decode at each time step.
        beam_size: The number of hypotheses kept per example during the search.
        remove_unk: (default: `False`) accepted for API compatibility; not used by
            this implementation.
        d_rate: (default: `0.0`) accepted for API compatibility; not used by this
            implementation.
        dtype: (default: `dtypes.int32`) The default data type to use when
        handling integer objects.
        name: (default: `None`) NameScope for the decoder function;
            defaults to "attention_decoder_fn_inference"
    Returns:
        A decoder function with the required interface of `dynamic_rnn_decoder`
        intended for inference.
    """
    with ops.name_scope(name, "attention_decoder_fn_inference", [
            output_fn, encoder_state, attention_keys, attention_values,
            attention_score_fn, attention_construct_fn, embeddings,
            start_of_sequence_id, end_of_sequence_id, maximum_length,
            num_decoder_symbols, dtype
    ]):
        # Tile the encoder state and the attention keys/values `beam_size` times so
        # that every beam hypothesis has its own copy.
        state_size = int(encoder_state[0].get_shape().with_rank(2)[1])
        state = []
        for s in encoder_state:
            state.append(array_ops.reshape(
                    array_ops.concat([array_ops.reshape(s, [-1, 1, state_size])] * beam_size, 1),
                    [-1, state_size]))
        encoder_state = tuple(state)
        origin_batch = array_ops.shape(attention_values)[0]
        attn_length = array_ops.shape(attention_values)[1]
        attention_values = array_ops.reshape(
                array_ops.concat(
                        [array_ops.reshape(attention_values, [-1, 1, attn_length, state_size])] * beam_size, 1),
                [-1, attn_length, state_size])
        attn_size = array_ops.shape(attention_keys)[2]
        attention_keys = array_ops.reshape(
                array_ops.concat(
                        [array_ops.reshape(attention_keys, [-1, 1, attn_length, attn_size])] * beam_size, 1),
                [-1, attn_length, attn_size])
        start_of_sequence_id = ops.convert_to_tensor(start_of_sequence_id, dtype)
        end_of_sequence_id = ops.convert_to_tensor(end_of_sequence_id, dtype)
        maximum_length = ops.convert_to_tensor(maximum_length, dtype)
        num_decoder_symbols = ops.convert_to_tensor(num_decoder_symbols, dtype)
        encoder_info = nest.flatten(encoder_state)[0]
        batch_size = encoder_info.get_shape()[0].value
        if output_fn is None:
            output_fn = lambda x: x
        if batch_size is None:
            batch_size = array_ops.shape(encoder_info)[0]
        #beam_size = ops.convert_to_tensor(beam_size, dtype)

    def decoder_fn(time, cell_state, cell_input, cell_output, context_state):
        """Decoder function used in the `dynamic_rnn_decoder` for inference.
        The main difference between this decoder function and the `decoder_fn` in
        `attention_decoder_fn_inference` is how `next_cell_input` is calculated. Here
        the `beam_size` best extensions of each hypothesis are kept at every step
        (beam search) instead of a single argmax over the feature dimension of the
        decoder output (greedy search).
        Args:
            time: positive integer constant reflecting the current timestep.
            cell_state: state of RNNCell.
            cell_input: input provided by `dynamic_rnn_decoder`.
            cell_output: output of RNNCell.
            context_state: context state provided by `dynamic_rnn_decoder`.
        Returns:
            A tuple (done, next state, next input, emit output, next context state)
            where:
            done: A boolean vector that is all `False` until `time > maximum_length`,
            at which point it is all `True`; the beam search itself runs for the full
            `maximum_length` steps and finished hypotheses are collected in the
            result `TensorArray`s of the context state.
            next state: `cell_state` reordered so that each surviving beam continues
            from its parent hypothesis.
            next input: the concatenation of the embeddings of the surviving beams'
            symbols and their attention vectors.
            emit output: the `output_fn` projection of the attention-constructed
            output, softmaxed and restricted to symbol ids 2..num_decoder_symbols-1.
            next context state: the tuple of `TensorArray`s (log_beam_probs,
            beam_parents, beam_symbols, result_probs, result_parents,
            result_symbols) carrying the beam-search bookkeeping.
        Raises:
            ValueError: if cell_input is not None.
        """
        with ops.name_scope(
                name, "attention_decoder_fn_inference",
                [time, cell_state, cell_input, cell_output, context_state]):
            if cell_input is not None:
                raise ValueError("Expected cell_input to be None, but saw: %s" %
                                                 cell_input)
            if cell_output is None:
                # invariant that this is time == 0
                next_input_id = array_ops.ones(
                        [batch_size,], dtype=dtype) * (start_of_sequence_id)
                done = array_ops.zeros([batch_size,], dtype=dtypes.bool)
                cell_state = encoder_state
                cell_output = array_ops.zeros(
                        [num_decoder_symbols], dtype=dtypes.float32)
                cell_input = array_ops.gather(embeddings, next_input_id)

                # init attention
                attention = _init_attention(encoder_state)
                # init context state
                log_beam_probs = tensor_array_ops.TensorArray(dtype=dtypes.float32, tensor_array_name="log_beam_probs", size=maximum_length, dynamic_size=True, infer_shape=False)
                beam_parents = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="beam_parents", size=maximum_length, dynamic_size=True, infer_shape=False)
                beam_symbols = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="beam_symbols", size=maximum_length, dynamic_size=True, infer_shape=False)
                result_probs = tensor_array_ops.TensorArray(dtype=dtypes.float32, tensor_array_name="result_probs", size=maximum_length, dynamic_size=True, infer_shape=False)
                result_parents = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="result_parents", size=maximum_length, dynamic_size=True, infer_shape=False)
                result_symbols = tensor_array_ops.TensorArray(dtype=dtypes.int32, tensor_array_name="result_symbols", size=maximum_length, dynamic_size=True, infer_shape=False)
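                # context_state carries the beam-search bookkeeping: log_beam_probs,
                # beam_parents and beam_symbols track the surviving beams, while the
                # result_* arrays log every expanded candidate for later readout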
                context_state = (log_beam_probs, beam_parents, beam_symbols, result_probs, result_parents, result_symbols)
            else:
                # construct attention
                attention = attention_construct_fn(cell_output, attention_keys,
                        attention_values)
                cell_output = attention

                # beam search decoder
                (log_beam_probs, beam_parents, beam_symbols, result_probs, result_parents, result_symbols) = context_state
                
                cell_output = output_fn(cell_output)    # logits
                cell_output = nn_ops.softmax(cell_output)

                # drop the first two symbol columns (ids 0 and 1) before expanding the
                # beams; the remaining ids are shifted back by +2 below
                cell_output = array_ops.split(cell_output, [2, num_decoder_symbols-2], 1)[1]

                # at the first step all beams of an example are identical, so keep only
                # one row per example; afterwards add each beam's accumulated log-prob
                tmp_output = array_ops.gather(cell_output, math_ops.range(origin_batch)*beam_size)

                probs = control_flow_ops.cond(
                        math_ops.equal(time, ops.convert_to_tensor(1, dtype)),
                        lambda: math_ops.log(tmp_output + ops.convert_to_tensor(1e-20, dtypes.float32)),
                        lambda: math_ops.log(cell_output + ops.convert_to_tensor(1e-20, dtypes.float32))
                                        + array_ops.reshape(log_beam_probs.read(time-2), [-1, 1]))

                probs = array_ops.reshape(probs, [origin_batch, -1])
                # keep the 2*beam_size best candidates per example, then map the
                # per-example indices into the flattened [batch*beam, vocab-2] layout
                best_probs, indices = nn_ops.top_k(probs, beam_size * 2)
                indices_flatten = array_ops.reshape(indices, [-1]) + array_ops.reshape(
                        array_ops.concat([array_ops.reshape(
                                math_ops.range(origin_batch)*((num_decoder_symbols-2)*beam_size),
                                [-1, 1])]*(beam_size*2), 1), [origin_batch*beam_size*2])
                best_probs_flatten = array_ops.reshape(best_probs, [-1])

                # recover symbol ids (shifted back by the 2 dropped columns) and the
                # parent beam each candidate was expanded from
                symbols = indices_flatten % (num_decoder_symbols - 2)
                symbols = symbols + 2
                parents = indices_flatten // (num_decoder_symbols - 2)

                # add a large bonus to non-EOS candidates so EOS hypotheses are not
                # kept among the surviving beams
                probs_wo_eos = best_probs + 1e5*math_ops.cast(
                        math_ops.cast((indices%(num_decoder_symbols-2)+2)-end_of_sequence_id, dtypes.bool),
                        dtypes.float32)
                
                best_probs_wo_eos, indices_wo_eos = nn_ops.top_k(probs_wo_eos, beam_size)

                indices_wo_eos = array_ops.reshape(indices_wo_eos, [-1]) + array_ops.reshape(
                        array_ops.concat([array_ops.reshape(
                                math_ops.range(origin_batch)*(beam_size*2), [-1, 1])]*beam_size, 1),
                        [origin_batch*beam_size])

                _probs = array_ops.gather(best_probs_flatten, indices_wo_eos)
                _symbols = array_ops.gather(symbols, indices_wo_eos)
                _parents = array_ops.gather(parents, indices_wo_eos)


                # surviving (non-EOS) beams drive the next step, while every expanded
                # candidate is also logged in result_* for later readout
                log_beam_probs = log_beam_probs.write(time-1, _probs)
                beam_symbols = beam_symbols.write(time-1, _symbols)
                beam_parents = beam_parents.write(time-1, _parents)
                result_probs = result_probs.write(time-1, best_probs_flatten)
                result_symbols = result_symbols.write(time-1, symbols)
                result_parents = result_parents.write(time-1, parents)


                next_input_id = array_ops.reshape(_symbols, [batch_size])

                # reorder states and attention so each surviving beam continues from
                # its parent hypothesis
                state_size = int(cell_state[0].get_shape().with_rank(2)[1])
                attn_size = int(attention.get_shape().with_rank(2)[1])
                state = []
                for j in cell_state:
                    state.append(array_ops.reshape(array_ops.gather(j, _parents), [-1, state_size]))
                cell_state = tuple(state)
                attention = array_ops.reshape(array_ops.gather(attention, _parents), [-1, attn_size])

                done = math_ops.equal(next_input_id, end_of_sequence_id)
                cell_input = array_ops.gather(embeddings, next_input_id)

            # combine cell_input and attention
            next_input = array_ops.concat([cell_input, attention], 1)

            # if time > maxlen, return an all-true vector; otherwise all-false: the
            # beam search runs for the full maximum_length steps and the per-beam
            # `done` flags computed above are not used for early stopping
            done = control_flow_ops.cond(
                    math_ops.greater(time, maximum_length),
                    lambda: array_ops.ones([batch_size,], dtype=dtypes.bool),
                    lambda: array_ops.zeros([batch_size,], dtype=dtypes.bool))
            return (done, cell_state, next_input, cell_output,
                            (log_beam_probs, beam_parents, beam_symbols,
                             result_probs, result_parents, result_symbols))

    return decoder_fn

## Helper functions ##
def prepare_attention(attention_states,
                      attention_option,
                      num_units,
                      imem=None,
                      output_alignments=False,
                      reuse=False):
    """Prepare keys/values/functions for attention.
    Args:
        attention_states: hidden states to attend over.
        attention_option: how to compute attention, either "luong" or "bahdanau".
        num_units: hidden state dimension.
        imem: (default: `None`) optional external memory to attend over in addition
            to `attention_states`; either a single tensor or a (graph, triple) tuple.
        output_alignments: (default: `False`) whether the memory score function
            should also return its alignments.
        reuse: whether to reuse variable scope.
    Returns:
        attention_keys: to be compared with target states.
        attention_values: to be used to construct context vectors.
        attention_score_fn: to compute similarity between key and target states.
        attention_construct_fn: to build attention states.
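
    A minimal usage sketch for the plain (no external memory) case
    (`encoder_outputs` and `num_units` are placeholders, not part of the API):
    ```
    attention_keys, attention_values, attention_score_fn, attention_construct_fn = (
            prepare_attention(encoder_outputs, "bahdanau", num_units))
    ```
    When `imem` is provided, `attention_keys`/`attention_values` become tuples and
    `attention_score_fn`/`attention_construct_fn` also handle the external memory.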
    """

    # Prepare attention keys / values from attention_states
    with variable_scope.variable_scope("attention_keys", reuse=reuse) as scope:
        attention_keys = layers.linear(
            attention_states, num_units, biases_initializer=None, scope=scope)
        attention_values = attention_states

    if imem is not None:
        if type(imem) is tuple:
            with variable_scope.variable_scope("imem_graph", reuse=reuse) as scope:
                attention_keys2, attention_states2 = array_ops.split(layers.linear(
                    imem[0], num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=2)
            with variable_scope.variable_scope("imem_triple", reuse=reuse) as scope:
                attention_keys3, attention_states3 = array_ops.split(layers.linear(
                    imem[1], num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=3)
            attention_keys = (attention_keys, attention_keys2, attention_keys3)
            attention_values = (attention_states, attention_states2, attention_states3)
        else:
            with variable_scope.variable_scope("imem", reuse=reuse) as scope:
                attention_keys2, attention_states2 = array_ops.split(layers.linear(
                    imem, num_units*2, biases_initializer=None, scope=scope), [num_units, num_units], axis=2)
                attention_keys = (attention_keys, attention_keys2)
                attention_values = (attention_states, attention_states2)

        

    # Attention score function
    if imem is None:
        attention_score_fn = _create_attention_score_fn("attention_score", num_units,
                                                            attention_option, reuse)
    else:
        attention_score_fn = (_create_attention_score_fn("attention_score", num_units,
                                                            attention_option, reuse),
                            _create_attention_score_fn("imem_score", num_units,
                                                            "luong", reuse, output_alignments=output_alignments))

    # Attention construction function
    attention_construct_fn = _create_attention_construct_fn("attention_construct",
                                  num_units,
                                  attention_score_fn,
                                  reuse)

    return (attention_keys, attention_values, attention_score_fn,
                    attention_construct_fn)


def _init_attention(encoder_state):
    """Initialize attention. Handling both LSTM and GRU.
    Args:
        encoder_state: The encoded state to initialize the `dynamic_rnn_decoder`.
    Returns:
        attn: initial zero attention vector.
    """

    # Multi- vs single-layer
    # TODO(thangluong): is this the best way to check?
    if isinstance(encoder_state, tuple):
        top_state = encoder_state[-1]
    else:
        top_state = encoder_state

    # LSTM vs GRU
    if isinstance(top_state, rnn_cell_impl.LSTMStateTuple):
        attn = array_ops.zeros_like(top_state.h)
    else:
        attn = array_ops.zeros_like(top_state)

    return attn


def _create_attention_construct_fn(name, num_units, attention_score_fn, reuse):
    """Function to compute attention vectors.
    Args:
        name: to label variables.
        num_units: hidden state dimension.
        attention_score_fn: to compute similarity between key and target states.
        reuse: whether to reuse variable scope.
    Returns:
        attention_construct_fn: to build attention states.
    """
    with variable_scope.variable_scope(name, reuse=reuse) as scope:

        def construct_fn(attention_query, attention_keys, attention_values):
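            # When attention_score_fn is a tuple, the first score function attends
            # over the encoder states and the second over the external (graph/triple)
            # memory, which may additionally return alignments.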
            alignments = None
            if type(attention_score_fn) is tuple:
                context0 = attention_score_fn[0](attention_query, attention_keys[0],
                                                                         attention_values[0])
                if len(attention_keys) == 2:
                    context1 = attention_score_fn[1](attention_query, attention_keys[1],
                                                                             attention_values[1])
                elif len(attention_keys) == 3:
                    context1 = attention_score_fn[1](attention_query, attention_keys[1:],
                            attention_values[1:])
                if type(context1) is tuple:
                    if len(context1) == 2:
                        context1, alignments = context1
                        concat_input = array_ops.concat([attention_query, context0, context1], 1)
                    elif len(context1) == 3:
                        context1, context2, alignments = context1
                        concat_input = array_ops.concat([attention_query, context0, context1, context2], 1)
                else:
                    concat_input = array_ops.concat([attention_query, context0, context1], 1)
            else:
                context = attention_score_fn(attention_query, attention_keys,
                                                                         attention_values)
                concat_input = array_ops.concat([attention_query, context], 1)
            attention = layers.linear(
                    concat_input, num_units, biases_initializer=None, scope=scope)
            if alignments is None:
                return attention
            else:
                return attention, alignments

        return construct_fn


# keys: [batch_size, attention_length, attn_size]
# query: [batch_size, 1, attn_size]
# return weights [batch_size, attention_length]
@function.Defun(func_name="attn_add_fun", noinline=True)
def _attn_add_fun(v, keys, query):
    return math_ops.reduce_sum(v * math_ops.tanh(keys + query), [2])


@function.Defun(func_name="attn_mul_fun", noinline=True)
def _attn_mul_fun(keys, query):
    return math_ops.reduce_sum(keys * query, [2])


def _create_attention_score_fn(name,
                               num_units,
                               attention_option,
                               reuse,
                               output_alignments=False,
                               dtype=dtypes.float32):
    """Different ways to compute attention scores.
    Args:
        name: to label variables.
        num_units: hidden state dimension.
        attention_option: how to compute attention, either "luong" or "bahdanau".
            "bahdanau": additive (Bahdanau et al., ICLR'2015)
            "luong": multiplicative (Luong et al., EMNLP'2015)
        reuse: whether to reuse variable scope.
        output_alignments: (default: `False`) whether the returned score function
            should also return its attention alignments.
        dtype: (default: `dtypes.float32`) data type to use.
    Returns:
        attention_score_fn: to compute similarity between key and target states.
    """
    with variable_scope.variable_scope(name, reuse=reuse):
        if attention_option == "bahdanau":
            query_w = variable_scope.get_variable(
                    "attnW", [num_units, num_units], dtype=dtype)
            score_v = variable_scope.get_variable("attnV", [num_units], dtype=dtype)

        def attention_score_fn(query, keys, values):
            """Put attention masks on attention_values using attention_keys and query.
            Args:
                query: A Tensor of shape [batch_size, num_units].
                keys: A Tensor of shape [batch_size, attention_length, num_units].
                values: A Tensor of shape [batch_size, attention_length, num_units].
            Returns:
                context_vector: A Tensor of shape [batch_size, num_units].
            Raises:
                ValueError: if attention_option is neither "luong" nor "bahdanau".
            """
            triple_keys, triple_values = None, None

            if type(keys) is tuple:
                keys, triple_keys = keys
                values, triple_values = values

            if attention_option == "bahdanau":
                # transform query
                query = math_ops.matmul(query, query_w)

                # reshape query: [batch_size, 1, num_units]
                query = array_ops.reshape(query, [-1, 1, num_units])
                    

                # attn_fun
                scores = _attn_add_fun(score_v, keys, query)
            elif attention_option == "luong":
                # reshape query: [batch_size, 1, num_units]
                query = array_ops.reshape(query, [-1, 1, num_units])

                # attn_fun
                scores = _attn_mul_fun(keys, query)
            else:
                raise ValueError("Unknown attention option %s!" % attention_option)

            # Compute alignment weights
            #     scores: [batch_size, length]
            #     alignments: [batch_size, length]
            # TODO(thangluong): not normalize over padding positions.
            alignments = nn_ops.softmax(scores)


            # Now calculate the attention-weighted vector.
            new_alignments = array_ops.expand_dims(alignments, 2)
            context_vector = math_ops.reduce_sum(new_alignments * values, [1])
            
            context_vector.set_shape([None, num_units])
            
            if triple_values is not None:
                triple_scores = math_ops.reduce_sum(triple_keys * array_ops.reshape(query, [-1, 1, 1, num_units]), [3])
                triple_alignments = nn_ops.softmax(triple_scores)
                context_triples = math_ops.reduce_sum(array_ops.expand_dims(triple_alignments, 3) * triple_values, [2])
                context_graph_triples = math_ops.reduce_sum(new_alignments * context_triples, [1])
                context_graph_triples.set_shape([None, num_units])
                return context_vector, context_graph_triples, new_alignments * triple_alignments 
            else:
                if output_alignments:
                    return context_vector, alignments
                else:
                    return context_vector

        return attention_score_fn