# coding=utf-8
# Copyright (c) 2019 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import collections
import logging

import tensorflow as tf
from tensorflow.contrib.rnn import MultiRNNCell, LSTMStateTuple
from tensorflow.python.framework import dtypes, tensor_shape
from tensorflow.python.framework import ops
from tensorflow.python.util import nest

from ludwig.models.modules.fully_connected_modules import fc_layer
from ludwig.models.modules.initializer_modules import get_initializer
from ludwig.models.modules.reduction_modules import reduce_sequence
from ludwig.utils.tf_utils import sequence_length_3D, sequence_length_2D


logger = logging.getLogger(__name__)


def get_cell_fun(cell_type):
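    """Map a cell_type string to the corresponding TensorFlow RNN cell class.

    Unrecognized values fall back to tf.nn.rnn_cell.BasicRNNCell.
    """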
    if cell_type == 'rnn':
        cell_fn = tf.nn.rnn_cell.BasicRNNCell
    elif cell_type == 'lstm':
        # allows for optional peephole connections and cell clipping
        cell_fn = tf.nn.rnn_cell.LSTMCell
    elif cell_type == 'lstm_block':
        # Faster version of basic LSTM
        cell_fn = tf.contrib.rnn.LSTMBlockCell
    elif cell_type == 'lstm_ln':
        cell_fn = tf.contrib.rnn.LayerNormBasicLSTMCell
    elif cell_type == 'lstm_cudnn':
        cell_fn = tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell
    elif cell_type == 'gru':
        cell_fn = tf.nn.rnn_cell.GRUCell
    elif cell_type == 'gru_block':
        # Faster version of GRU (25% faster in my tests)
        cell_fn = tf.contrib.rnn.GRUBlockCell
    elif cell_type == 'gru_cudnn':
        # CuDNN-compatible version of the GRU cell
        cell_fn = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell
    else:
        cell_fn = tf.nn.rnn_cell.BasicRNNCell
    return cell_fn


class Projection(tf.compat.v1.layers.Layer):
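    """Layer that applies externally created projection weights and biases.

    This allows the same projection variables to be shared between the
    training and inference decoders.
    """
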
    def __init__(self, projection_weights, projection_biases, name=None,
                 **kwargs):
        super(Projection, self).__init__(name=name, **kwargs)
        self.projection_weights = projection_weights
        self.projection_biases = projection_biases

    def call(self, inputs, **kwargs):
        inputs_shape = inputs.shape.as_list()
        weights_shape = self.projection_weights.shape.as_list()
        assert inputs_shape[-1] == weights_shape[0]
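        # flatten all leading dimensions so the matmul works for inputs of any rank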
        inputs = tf.reshape(inputs, [-1, inputs_shape[-1]])

        outputs = tf.matmul(inputs, self.projection_weights)
        if self.projection_biases is not None:
            outputs = tf.nn.bias_add(outputs, self.projection_biases)

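        # restore the original rank, with the projected size as the last dimension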
        outputs_shape = inputs_shape
        outputs_shape[0] = -1  # batch_size
        outputs_shape[-1] = weights_shape[1]
        outputs = tf.reshape(outputs, outputs_shape)
        return outputs

    def compute_output_shape(self, input_shape):
        input_shape = tensor_shape.TensorShape(input_shape).as_list()
        output_shape = input_shape
        output_shape[-1] = self.projection_weights.shape.as_list()[1]
        # output_shape = [input_shape[0], self.projection_biases.shape.as_list()[0]]
        return tensor_shape.TensorShape(output_shape)


class BasicDecoderOutput(
    collections.namedtuple('BasicDecoderOutput',
                           ('rnn_output', 'sample_id', 'projection_input'))):
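    """BasicDecoderOutput extended with the pre-projection cell outputs.

    projection_input is used to compute the sampled softmax cross-entropy loss.
    """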
    pass


class BasicDecoder(tf.contrib.seq2seq.BasicDecoder):
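    """BasicDecoder that also returns the cell outputs before the output layer."""
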
    def _projection_input_size(self):
        return self._cell.output_size

    @property
    def output_size(self):
        return BasicDecoderOutput(
            rnn_output=self._rnn_output_size(),
            sample_id=self._helper.sample_ids_shape,
            projection_input=self._projection_input_size())

    @property
    def output_dtype(self):
        dtype = nest.flatten(self._initial_state)[0].dtype
        return BasicDecoderOutput(
            nest.map_structure(lambda _: dtype, self._rnn_output_size()),
            self._helper.sample_ids_dtype,
            nest.map_structure(lambda _: dtype, self._projection_input_size()))

    def step(self, time, inputs, state, name=None):
        with ops.name_scope(name, 'BasicDecoderStep', (time, inputs, state)):
            cell_outputs, cell_state = self._cell(inputs, state)
            # keep the pre-projection cell outputs to compute the
            # sampled softmax cross-entropy loss
            projection_inputs = cell_outputs
            if self._output_layer is not None:
                cell_outputs = self._output_layer(cell_outputs)
            sample_ids = self._helper.sample(
                time=time, outputs=cell_outputs, state=cell_state)
            (finished, next_inputs, next_state) = self._helper.next_inputs(
                time=time,
                outputs=cell_outputs,
                state=cell_state,
                sample_ids=sample_ids)
        outputs = BasicDecoderOutput(cell_outputs, sample_ids,
                                     projection_inputs)
        return (outputs, next_state, next_inputs, finished)


class TimeseriesTrainingHelper(tf.contrib.seq2seq.TrainingHelper):
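    """TrainingHelper that emits dummy (all-zero) sample ids.

    Timeseries outputs are continuous values, so there are no token ids to sample.
    """
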
    def sample(self, time, outputs, name=None, **unused_kwargs):
        with ops.name_scope(name, 'TrainingHelperSample', [time, outputs]):
            return tf.zeros(tf.shape(outputs)[:-1], dtype=dtypes.int32)


class RecurrentStack:
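    """Stack of (optionally bidirectional) recurrent layers.

    The output sequence is reduced according to reduce_output. A minimal
    usage sketch (the arguments shown are illustrative, not fixed defaults):

        stack = RecurrentStack(state_size=256, cell_type='lstm', num_layers=2)
        hidden, hidden_size = stack(input_sequence, regularizer=None,
                                    dropout_rate=0.0, is_training=True)
    """
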
    def __init__(
            self,
            state_size=256,
            cell_type='rnn',
            num_layers=1,
            bidirectional=False,
            dropout=False,
            regularize=True,
            reduce_output='last',
            **kwargs
    ):
        self.state_size = state_size
        self.cell_type = cell_type
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.regularize = regularize
        self.reduce_output = reduce_output

    def __call__(
            self,
            input_sequence,
            regularizer,
            dropout_rate,
            is_training=True
    ):
        if not self.regularize:
            regularizer = None

        # Calculate the length of each sequence in input_sequence
        sequence_length = sequence_length_3D(input_sequence)

        # RNN cell
        cell_fn = get_cell_fun(self.cell_type)

        # initial state
        # init_state = tf.compat.v1.get_variable(
        #   'init_state',
        #   [1, state_size],
        #   initializer=tf.constant_initializer(0.0),
        # )
        # init_state = tf.tile(init_state, [batch_size, 1])

        # main RNN operation
        with tf.compat.v1.variable_scope('rnn_stack', reuse=tf.compat.v1.AUTO_REUSE,
                                         regularizer=regularizer) as vs:
            if self.bidirectional:
                # forward direction cell
                fw_cell = lambda state_size: cell_fn(state_size)
                bw_cell = lambda state_size: cell_fn(state_size)
                fw_cells = [fw_cell(self.state_size) for _ in
                            range(self.num_layers)]
                bw_cells = [bw_cell(self.state_size) for _ in
                            range(self.num_layers)]
                rnn_outputs, final_state_fw, final_state_bw = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
                    cells_fw=fw_cells,
                    cells_bw=bw_cells,
                    dtype=tf.float32,
                    sequence_length=sequence_length,
                    inputs=input_sequence
                )

            else:
                cell = lambda state_size: cell_fn(state_size)
                cells = MultiRNNCell(
                    [cell(self.state_size) for _ in range(self.num_layers)],
                    state_is_tuple=True)
                rnn_outputs, final_state = tf.nn.dynamic_rnn(
                    cells,
                    input_sequence,
                    sequence_length=sequence_length,
                    dtype=tf.float32)
                # initial_state=init_state)

            for v in tf.global_variables():
                if v.name.startswith(vs.name):
                    logger.debug('  {}: {}'.format(v.name, v))
            logger.debug('  rnn_outputs: {0}'.format(rnn_outputs))

            rnn_output = reduce_sequence(rnn_outputs, self.reduce_output)
            logger.debug('  reduced_rnn_output: {0}'.format(rnn_output))

        # dropout
        if self.dropout and dropout_rate is not None:
            rnn_output = tf.layers.dropout(
                rnn_output,
                rate=dropout_rate,
                training=is_training
            )
            logger.debug('  dropout_rnn: {0}'.format(rnn_output))

        return rnn_output, rnn_output.shape.as_list()[-1]


def recurrent_decoder(encoder_outputs, targets, max_sequence_length, vocab_size,
                      cell_type='rnn', state_size=256, embedding_size=50,
                      num_layers=1,
                      attention_mechanism=None, beam_width=1, projection=True,
                      tied_target_embeddings=True, embeddings=None,
                      initializer=None, regularizer=None,
                      is_timeseries=False):
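    """Decode encoder outputs into a target sequence with a recurrent decoder.

    Returns a tuple of (predictions_sequence, predictions_sequence_scores,
    predictions_sequence_length_with_eos, targets_sequence_length_with_eos,
    eval_logits, train_logits, class_weights, class_biases).
    """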
    with tf.compat.v1.variable_scope('rnn_decoder', reuse=tf.compat.v1.AUTO_REUSE,
                                     regularizer=regularizer):

        # ================ Setup ================
        if beam_width > 1 and is_timeseries:
            raise ValueError(
                'Invalid beam_width: {}. '
                'Beam search decoding is not supported for timeseries features'.format(
                    beam_width))

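        # GO uses the extra id vocab_size (an extra row is added to the
        # embedding table for it), while END reuses id 0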
        GO_SYMBOL = vocab_size
        END_SYMBOL = 0
        batch_size = tf.shape(encoder_outputs)[0]

        # ================ Projection ================
        # Project the encoder outputs to the size of the decoder state
        encoder_outputs_size = encoder_outputs.shape[-1]
        if projection and encoder_outputs_size != state_size:
            with tf.compat.v1.variable_scope('projection'):
                encoder_output_rank = len(encoder_outputs.shape)
                if encoder_output_rank > 2:
                    sequence_length = tf.shape(encoder_outputs)[1]
                    encoder_outputs = tf.reshape(encoder_outputs,
                                                 [-1, encoder_outputs_size])
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)
                    encoder_outputs = tf.reshape(encoder_outputs,
                                                 [-1, sequence_length,
                                                  state_size])
                else:
                    encoder_outputs = fc_layer(encoder_outputs,
                                               encoder_outputs.shape[-1],
                                               state_size,
                                               activation=None,
                                               initializer=initializer)

        # ================ Targets sequence ================
        # Calculate the length of the targets and add GO and EOS symbols
        with tf.compat.v1.variable_scope('sequence'):
            targets_sequence_length = sequence_length_2D(targets)
            start_tokens = tf.tile([GO_SYMBOL], [batch_size])
            end_tokens = tf.tile([END_SYMBOL], [batch_size])
            if is_timeseries:
                start_tokens = tf.cast(start_tokens, tf.float32)
                end_tokens = tf.cast(end_tokens, tf.float32)
            targets_with_go_and_eos = tf.concat([
                tf.expand_dims(start_tokens, 1),
                targets,
                tf.expand_dims(end_tokens, 1)], 1)
            logger.debug('  targets_with_go_and_eos: {0}'.format(
                targets_with_go_and_eos))
            # the EOS symbol is 0, so it does not increase the real length of the sequence
            targets_sequence_length_with_eos = targets_sequence_length + 1

        # ================ Embeddings ================
        if is_timeseries:
            targets_embedded = tf.expand_dims(targets_with_go_and_eos, -1)
            targets_embeddings = None
        else:
            with tf.compat.v1.variable_scope('embedding'):
                if embeddings is not None:
                    embedding_size = embeddings.shape.as_list()[-1]
                    if tied_target_embeddings:
                        state_size = embedding_size
                elif tied_target_embeddings:
                    embedding_size = state_size

                if embeddings is not None:
                    embedding_go = tf.compat.v1.get_variable('embedding_GO',
                                                   initializer=tf.random_uniform(
                                                       [1, embedding_size],
                                                       -1.0, 1.0))
                    targets_embeddings = tf.concat([embeddings, embedding_go],
                                                   axis=0)
                else:
                    initializer_obj = get_initializer(initializer)
                    targets_embeddings = tf.compat.v1.get_variable(
                        'embeddings',
                        initializer=initializer_obj(
                            [vocab_size + 1, embedding_size]),
                        regularizer=regularizer
                    )
                logger.debug(
                    '  targets_embeddings: {0}'.format(targets_embeddings))

                targets_embedded = tf.nn.embedding_lookup(targets_embeddings,
                                                          targets_with_go_and_eos,
                                                          name='decoder_input_embeddings')
        logger.debug('  targets_embedded: {0}'.format(targets_embedded))

        # ================ Class prediction ================
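        # with tied target embeddings, the transposed embedding matrix is reused
        # as the output projection (weight tying)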
        if tied_target_embeddings:
            class_weights = tf.transpose(targets_embeddings)
        else:
            initializer_obj = get_initializer(initializer)
            class_weights = tf.compat.v1.get_variable(
                'class_weights',
                initializer=initializer_obj([state_size, vocab_size + 1]),
                regularizer=regularizer
            )
        logger.debug('  class_weights: {0}'.format(class_weights))
        class_biases = tf.compat.v1.get_variable('class_biases', [vocab_size + 1])
        logger.debug('  class_biases: {0}'.format(class_biases))
        projection_layer = Projection(class_weights, class_biases)

        # ================ RNN ================
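        # the (projected) encoder output is used as the decoder initial state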
        initial_state = encoder_outputs
        with tf.compat.v1.variable_scope('rnn_cells') as vs:
            # Cell
            cell_fun = get_cell_fun(cell_type)

            if num_layers == 1:
                cell = cell_fun(state_size)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
            elif num_layers > 1:
                cell = MultiRNNCell(
                    [cell_fun(state_size) for _ in range(num_layers)],
                    state_is_tuple=True)
                if cell_type.startswith('lstm'):
                    initial_state = LSTMStateTuple(c=initial_state,
                                                   h=initial_state)
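                # replicate the initial state for every layer of the stacked cell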
                initial_state = tuple([initial_state] * num_layers)
            else:
                raise ValueError(
                    'Invalid num_layers in recurrent decoder: {}. '
                    'The number of layers of a recurrent decoder cannot be <= 0'.format(
                        num_layers))

            # Attention
            if attention_mechanism is not None:
                if attention_mechanism == 'bahdanau':
                    attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
                        num_units=state_size, memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                elif attention_mechanism == 'luong':
                    attention_mechanism = tf.contrib.seq2seq.LuongAttention(
                        num_units=state_size, memory=encoder_outputs,
                        memory_sequence_length=sequence_length_3D(
                            encoder_outputs))
                else:
                    raise ValueError(
                        'Attention mechanism {} not supported'.format(
                            attention_mechanism))
                cell = tf.contrib.seq2seq.AttentionWrapper(
                    cell, attention_mechanism, attention_layer_size=state_size)
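                # start from the wrapper zero state and seed its cell state
                # with the last output of the encoder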
                initial_state = cell.zero_state(
                    dtype=tf.float32,
                    batch_size=batch_size)
                initial_state = initial_state.clone(
                    cell_state=reduce_sequence(encoder_outputs, 'last'))

            for v in tf.global_variables():
                if v.name.startswith(vs.name):
                    logger.debug('  {}: {}'.format(v.name, v))

        # ================ Decoding ================
        def decode(initial_state, cell, helper, beam_width=1,
                   projection_layer=None):
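            """Build a beam search or basic decoder and run dynamic decoding."""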
            # The decoder itself
            if beam_width > 1:
                # Tile inputs for beam search decoder
                beam_initial_state = tf.contrib.seq2seq.tile_batch(
                    initial_state, beam_width)
                decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                    cell=cell,
                    embedding=targets_embeddings,
                    start_tokens=start_tokens,
                    end_token=END_SYMBOL,
                    initial_state=beam_initial_state,
                    beam_width=beam_width,
                    output_layer=projection_layer)
            else:
                decoder = BasicDecoder(
                    cell=cell, helper=helper,
                    initial_state=initial_state,
                    output_layer=projection_layer)

            # The decoding operation
            outputs = tf.contrib.seq2seq.dynamic_decode(
                decoder=decoder,
                output_time_major=False,
                # the beam search decoder does not support impute_finished=True
                impute_finished=(beam_width <= 1),
                maximum_iterations=max_sequence_length
            )

            return outputs

        # ================ Decoding helpers ================
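        # training uses a helper that feeds the ground truth targets (teacher forcing);
        # for non-timeseries features, inference uses greedy embedding or beam search decoding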
        if is_timeseries:
            train_helper = TimeseriesTrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            final_outputs_pred, final_state_pred, final_sequence_lengths_pred = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_pred.rnn_output
            train_logits = final_outputs_pred.projection_input
            predictions_sequence = tf.reshape(eval_logits, [batch_size, -1])
            predictions_sequence_length_with_eos = final_sequence_lengths_pred

        else:
            train_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs=targets_embedded,
                sequence_length=targets_sequence_length_with_eos)
            final_outputs_train, final_state_train, final_sequence_lengths_train = decode(
                initial_state,
                cell,
                train_helper,
                projection_layer=projection_layer)
            eval_logits = final_outputs_train.rnn_output
            train_logits = final_outputs_train.projection_input
            # train_predictions = final_outputs_train.sample_id

            pred_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
                embedding=targets_embeddings,
                start_tokens=start_tokens,
                end_token=END_SYMBOL)
            final_outputs_pred, final_state_pred, final_sequence_lengths_pred = decode(
                initial_state,
                cell,
                pred_helper,
                beam_width,
                projection_layer=projection_layer)

            if beam_width > 1:
                predictions_sequence = final_outputs_pred.beam_search_decoder_output.predicted_ids[
                                       :, :, 0]
                # final_outputs_pred.predicted_ids[:, :, 0] would work too,
                # but it contains -1s for padding
                predictions_sequence_scores = final_outputs_pred.beam_search_decoder_output.scores[
                                              :, :, 0]
                predictions_sequence_length_with_eos = final_sequence_lengths_pred[
                                                       :, 0]
            else:
                predictions_sequence = final_outputs_pred.sample_id
                predictions_sequence_scores = final_outputs_pred.rnn_output
                predictions_sequence_length_with_eos = final_sequence_lengths_pred

    logger.debug('  train_logits: {0}'.format(train_logits))
    logger.debug('  eval_logits: {0}'.format(eval_logits))
    logger.debug('  predictions_sequence: {0}'.format(predictions_sequence))
    logger.debug('  predictions_sequence_scores: {0}'.format(
        predictions_sequence_scores))

    return (
        predictions_sequence,
        predictions_sequence_scores,
        predictions_sequence_length_with_eos,
        targets_sequence_length_with_eos,
        eval_logits,
        train_logits,
        class_weights,
        class_biases
    )