import os
import tensorflow as tf
from tensorflow.contrib.rnn import DropoutWrapper, ResidualWrapper
from tensorflow.contrib.rnn import BasicRNNCell
from tensorflow.contrib.rnn import GRUCell
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.rnn import MultiRNNCell
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.python.layers.core import Dense

from icecaps.estimators.abstract_icecaps_estimator import AbstractIcecapsEstimator
from icecaps.util.vocabulary import Vocabulary


class AbstractRecurrentEstimator(AbstractIcecapsEstimator):
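    """Abstract base class for Icecaps estimators built on recurrent cells.

    Provides shared machinery for RNN-based models: hyperparameter defaults,
    vocabulary loading, cell construction, token embedding, input trimming,
    and sequence loss.
    """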

    @classmethod
    def construct_expected_params(cls):
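        """Extends the base estimator's expected params with RNN defaults."""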
        expected_params = super().construct_expected_params()
        expected_params["max_length"] = cls.make_param(50)
        expected_params["cell_type"] = cls.make_param('gru')
        expected_params["hidden_units"] = cls.make_param(32)
        expected_params["depth"] = cls.make_param(1)
        expected_params["token_embed_dim"] = cls.make_param(16)
        expected_params["tie_token_embeddings"] = cls.make_param(True)
        expected_params["beam_width"] = cls.make_param(8)
        expected_params["vocab_file"] = cls.make_param("./dummy_data/vocab.dic")
        expected_params["vocab_size"] = cls.make_param(0)
        expected_params["skip_tokens"] = cls.make_param('')
        expected_params["skip_tokens_start"] = cls.make_param('')
        return expected_params

    def extract_args(self, features, mode, params):
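        """Builds the vocabulary, either as a bare index space of vocab_size
        tokens or loaded from vocab_file."""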
        super().extract_args(features, mode, params)
        if self.hparams.vocab_size > 0:
            self.vocab = Vocabulary(size=self.hparams.vocab_size)
        else:
            self.vocab = Vocabulary(
                fname=self.hparams.vocab_file,
                skip_tokens=self.hparams.skip_tokens,
                skip_tokens_start=self.hparams.skip_tokens_start)

    def build_cell(self, name=None):
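        """Builds a single RNN cell of the configured type.

        'linear', 'tanh', and 'relu' map to a BasicRNNCell with the
        corresponding activation; 'gru' and 'lstm' map to GRUCell and
        LSTMCell (with a flat, non-tuple state).
        """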
        if self.hparams.cell_type == 'linear':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.identity, name=name)
        elif self.hparams.cell_type == 'tanh':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.tanh, name=name)
        elif self.hparams.cell_type == 'relu':
            cell = BasicRNNCell(self.hparams.hidden_units,
                                activation=tf.nn.relu, name=name)
        elif self.hparams.cell_type == 'gru':
            cell = GRUCell(self.hparams.hidden_units, name=name)
        elif self.hparams.cell_type == 'lstm':
            cell = LSTMCell(self.hparams.hidden_units, name=name, state_is_tuple=False)
        else:
            raise ValueError(
                "Cell type '{}' not supported.".format(self.hparams.cell_type))
        return cell

    def build_deep_cell(self, cell_list=None, name=None, return_raw_list=False):
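        """Stacks `depth` dropout-wrapped cells into a multi-layer cell.

        If cell_list is given it is used as-is; otherwise one cell per layer
        is built and wrapped with output dropout. With return_raw_list=True
        the list of cells is returned instead of a MultiRNNCell.
        """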
        if name is None:
            name = "cell"
        if cell_list is None:
            cell_list = []
            for i in range(self.hparams.depth):
                cell = self.build_cell(name=name+"_"+str(i))
                cell = DropoutWrapper(cell, output_keep_prob=self.keep_prob)
                cell_list.append(cell)
        if return_raw_list:
            return cell_list
        if len(cell_list) == 1:
            return cell_list[0]
        return MultiRNNCell(cell_list, state_is_tuple=False)

    def build_rnn(self, input_key="inputs"):
        with tf.variable_scope('rnn'):
            self.cell = self.build_deep_cell()
            self.build_inputs(input_key)
            # outputs: [batch_size, max_time, cell_output_size]
            # last_state: [batch_size, cell_state_size]
            self.outputs, self.last_state = tf.nn.dynamic_rnn(
                cell=self.cell, inputs=self.inputs_dense,
                sequence_length=self.inputs_length,
                time_major=False, dtype=tf.float32)

    def build_embeddings(self):
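        """Creates or reuses the token embedding matrix.

        When tie_token_embeddings is set and the features carry pre-built
        embeddings, those are reused; otherwise a fresh matrix is created
        and, if token_embed_dim differs from hidden_units, projected up to
        the cell's hidden size.
        """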
        if "token_embeddings" in self.features and self.hparams.tie_token_embeddings:
            self.token_embeddings = self.features["token_embeddings"]
        else:
            self.token_embeddings = tf.get_variable(
                name='token_embeddings', shape=[self.vocab.size(), self.hparams.token_embed_dim])
            if self.hparams.token_embed_dim != self.hparams.hidden_units:
                projection = tf.get_variable(
                    name='token_embed_proj', shape=[self.hparams.token_embed_dim, self.hparams.hidden_units])
                self.token_embeddings = self.token_embeddings @ projection

    def embed_sparse_to_dense(self, sparse):
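        """Looks up dense embedding vectors for a batch of token ids."""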
        with tf.variable_scope('embed_sparse_to_dense', reuse=tf.AUTO_REUSE):
            dense = tf.nn.embedding_lookup(self.token_embeddings, sparse)
        return dense

    def build_inputs(self, input_key):
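        """Embeds raw token ids, computing lengths and trimming padding."""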
        self.build_embeddings()
        self.inputs_sparse_untrimmed = tf.cast(self.features[input_key], tf.int32)
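        # Each example's length is the number of tokens that differ from the
        # end token, i.e. inputs are assumed to be padded with end_token_id.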
        self.inputs_length = tf.cast(tf.count_nonzero(
            self.inputs_sparse_untrimmed - self.vocab.end_token_id, axis=-1), tf.int32)
        self.inputs_max_length = tf.reduce_max(self.inputs_length)
        self.inputs_sparse = tf.slice(self.inputs_sparse_untrimmed, [0, 0], [-1, self.inputs_max_length])
        self.inputs_dense = self.embed_sparse_to_dense(self.inputs_sparse)
        self.batch_size = tf.shape(self.inputs_sparse)[0]

    def build_loss(self):
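        """Computes masked cross-entropy loss averaged over time and batch."""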
        with tf.name_scope('build_loss'):
            self.loss = seq2seq.sequence_loss(
                logits=self.logits, targets=self.targets_sparse, weights=self.target_mask,
                average_across_timesteps=True, average_across_batch=True)
        self.reported_loss = tf.identity(self.loss, 'reported_loss')
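
# A minimal sketch of how a concrete subclass might use the pieces above.
# `EchoEstimator`, `_model_fn`, and the feature keys are hypothetical and for
# illustration only; consult the concrete Icecaps estimators for the real API.
#
#   class EchoEstimator(AbstractRecurrentEstimator):
#       def _model_fn(self, features, mode, params):
#           self.extract_args(features, mode, params)
#           self.build_rnn(input_key="inputs")  # embeds inputs + runs dynamic_rnn
#           # ...project self.outputs to logits, then call self.build_loss()...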