from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import tensorflow as tf
import tensorflow.contrib.rnn as rnn

from .rnn_cell_util import (
    GeneratorCellBuilder, GeneratorRNNCellBuilder, mean_field_cell,
    OutputRangeWrapper
)
from functools import partial
from itertools import product
from tensorflow.python.ops import variable_scope


class RNNCellGenerator(object):
    """Generator that samples fixed-length discrete action sequences.

    The graph unrolls an RNN cell for `seq_len` steps, sampling one of `m`
    actions at each step. It exposes `action_seq` and `logits`, plus a
    policy-gradient objective (`get_policy_loss_op`) for training the cell.
    """
    def __init__(self, m, seq_len, cell_builder, name='gen', reuse=False,
        **kwargs):
        """
        @m: Number of discrete actions (size of the one-hot action encoding)
        @seq_len: Number of actions drawn per sequence
        @cell_builder: Object providing
            `build_cell_and_init_state(batch_size, feed_actions)`, which must
            return a `(cell, initial_state)` pair
        @name: Variable scope under which the generator graph is built
        @reuse: Passed to `tf.variable_scope` to reuse existing variables
        @kwargs: Forwarded to `_build_generator` (`feed_actions`, `init_type`)
        """
        self.m          = m
        self.seq_len    = seq_len
        self.cell_build = cell_builder
        self.name       = name
        self.action_seq = None
        self.logits     = None
        with tf.variable_scope(name, reuse=reuse):
            # Batch size is supplied at run time (see get_action_sequence and
            # get_feed), so the same graph serves any batch size
            self.batch_size = tf.placeholder(tf.int32, [])
            self._build_generator(**kwargs)

    def _build_generator(self, feed_actions=True, init_type='zeros'):
        """Build the RNN TF sequence generator

        @feed_actions: Feed one-hot actions taken in previous step, rather than
            logits (from which action was sampled)
        @init_type: How to initialize the sequence generation:
            * train: Train a variable to init with
            * zeros: Send in an all-zeros vector

        Sets `self.rerun`, `self.input_actions`, `self.coo_actions`,
        `self.logits` ([batch_size, seq_len, n_outputs]) and
        `self.action_seq` ([batch_size, seq_len], int32).
        Raises ValueError for an unrecognized `init_type`.
        """
        # Build the cell
        cell, state = self.cell_build.build_cell_and_init_state(
            self.batch_size, feed_actions
        )
        # If train input, train a variable input. Otherwise, all zeros
        if init_type.lower().startswith('train'):
            # NOTE(review): tf.Variable does not honour variable_scope reuse,
            # so constructing with reuse=True creates a new (unshared)
            # variable here rather than reusing the existing one — confirm
            # this is intended.
            input_var = tf.Variable(tf.zeros((1, self.m), dtype=tf.float32))
            feed      = tf.tile(input_var, (self.batch_size, 1))
        elif init_type.lower().startswith('zero'):
            feed = tf.zeros((self.batch_size, self.m), dtype=tf.float32)
        else:
            raise ValueError(
                "RNN cell generator init_type %s not recognized" % init_type)

        # Placeholders to recover policy for updates: when `rerun` is fed
        # True, the actions in `input_actions` are replayed instead of
        # sampled, so `logits` correspond to a previously drawn sequence
        self.rerun         = tf.placeholder_with_default(False, [])
        self.input_actions = tf.placeholder_with_default(
            tf.zeros((1, self.seq_len), dtype=tf.int32), (None, self.seq_len)
        )
        # Rows of [batch_index, step_index, action] used to gather the
        # probability of each drawn action (built by get_feed)
        self.coo_actions   = tf.placeholder(tf.int32, (None, 3))
        # Run loopy feed forward
        actions_arr = tf.TensorArray(tf.int32, self.seq_len)
        logits_arr  = tf.TensorArray(tf.float32, self.seq_len)
        for t in range(self.seq_len):
            if t > 0:
                # Share the cell's weights across all unrolled steps: they
                # are created at t == 0 and reused afterwards
                variable_scope.get_variable_scope().reuse_variables()
            # Compute logits for next action using RNN cell
            logits, state = cell(feed, state)
            # Samplers to draw actions. tf.cond builds both branches during
            # the call below, so these closures capture the current
            # `logits` and `t` (no late-binding issue)
            def sample():
                return tf.to_int32(tf.multinomial(logits, 1))
            def rerun_sample():
                return self.input_actions[:, t]
            # If rerunning to apply policy gradients, draw is the input
            draw = tf.reshape(tf.cond(self.rerun, rerun_sample, sample), (-1,))
            # Write to arrays
            logits_arr  = logits_arr.write(t, logits)
            actions_arr = actions_arr.write(t, draw)
            # Update feed- either with the action taken (default), or with
            # the logits output at the previous timestep
            if feed_actions:
                feed = tf.one_hot(draw, self.m)
            else:
                feed = logits
        # Reshape logits to [batch_size, seq_len, n_actions]
        self.logits = tf.transpose(logits_arr.stack(), (1, 0, 2))
        # Reshape action_seq to [batch_size, seq_len]
        self.action_seq = tf.transpose(actions_arr.stack())

    def _get_generated_probabilities(self):
        """Returns a [batch_size, seq_len] Tensor with probabilities for each
           action that was drawn

        Relies on `coo_actions` being fed as row-major [i, j, action] triples
        (as built by `get_feed`), so the gathered values reshape back to
        [batch_size, seq_len].
        """
        input_batch_size = tf.shape(self.input_actions)[0]
        dists            = tf.nn.softmax(self.logits)
        r_dists          = tf.gather_nd(dists, self.coo_actions)
        return tf.reshape(r_dists, (input_batch_size, self.seq_len))

    def _get_generated_log_probabilities(self):
        """Returns a [batch_size, seq_len] Tensor with the log-probability of
           each action that was drawn

        Same gather as `_get_generated_probabilities`, but computed with
        log_softmax so a vanishing probability yields a large negative
        finite value instead of log(0) = -inf.
        """
        input_batch_size = tf.shape(self.input_actions)[0]
        log_dists        = tf.nn.log_softmax(self.logits)
        r_log_dists      = tf.gather_nd(log_dists, self.coo_actions)
        return tf.reshape(r_log_dists, (input_batch_size, self.seq_len))

    def _build_discounts_matrix(self, T, gamma):
        """Build lower-triangular matrix of discounts.
        For example for T = 3: D = [[1,       0,     0]
                                   [gamma,   1,     0]
                                   [gamma^2, gamma, 1]]
        Then with R, our N x T incremental rewards matrix, the discounted sum
        of future rewards is the matrix product R @ D.
        """
        # Lower-triangular 0/1 mask: row i has ones in columns 0..i
        ltri_mask  = tf.sequence_mask(tf.range(T)+1, T, dtype=tf.float32)
        # Exclusive cumsum down the rows turns the mask into the exponents
        # i - j below the diagonal (e.g. [[0,0,0],[1,0,0],[2,1,0]])
        power_ltri = tf.cumsum(ltri_mask, exclusive=True)
        # gamma**0 == 1 above the diagonal as well, so re-apply the mask to
        # zero out the strictly-upper triangle
        return tf.pow(gamma, power_ltri) * ltri_mask

    def get_policy_loss_op(self, incremental_rewards, gamma):
        """Build the policy-gradient objective for the drawn action sequences.

        @incremental_rewards: [batch_size, seq_len] Tensor where each entry
            represents the incremental reward for an action on a data point.
            Must be float32 to match the discount matrix in the matmul.
        @gamma: Scalar discount factor applied to future rewards

        Returns a scalar Tensor: the sum over batch and time of
        log pi(a_t) * advantage_t. Here advantage_t is the discounted future
        reward at step t minus its mean over the batch. Advantages are
        wrapped in stop_gradient, so only the log-probabilities get gradients.
        NOTE(review): the sum is not negated; presumably callers maximize it
        or negate it before minimizing — confirm at the call site.
        """
        T = tf.shape(incremental_rewards)[1]
        # Form matrix of discounts to apply
        gamma_ltri = self._build_discounts_matrix(T, gamma)
        # future_rewards[n, j] = sum_{i >= j} gamma^(i - j) * R[n, i]
        future_rewards = tf.matmul(incremental_rewards, gamma_ltri)
        # Per-timestep batch-mean baseline and the resulting advantage
        baseline   = tf.reduce_mean(future_rewards, axis=0)
        advantages = future_rewards - baseline
        # log_softmax + gather rather than log(softmax): same value, but it
        # avoids log(0) = -inf (and NaN gradients) if a probability underflows
        log_policy = self._get_generated_log_probabilities()
        return tf.reduce_sum(log_policy * tf.stop_gradient(advantages))

    def get_action_sequence(self, session, batch_size):
        """Sample action sequences

        Runs `action_seq` with only `batch_size` fed, so `rerun` keeps its
        default (False) and actions are drawn by sampling. Returns the
        evaluated [batch_size, seq_len] array of actions.
        """
        return session.run(self.action_seq, {self.batch_size: batch_size})

    def get_feed(self, actions, **kwargs):
        """Get the feed_dict for the training step.
        @actions: [batch_size, seq_len] integer array (indexed as
            `actions[i, j]` and sized via `actions.shape`) holding the
            sequence of actions taken to generate the transformed data in
            this training step.
        @kwargs: Additional feed_dict entries. Entries built here take
            precedence over any with the same key.

        Note that we feed `actions` back in and set rerun=True to indicate
        that the exact same sequence of actions should be used in all other
        operations in this step!

        Returns the merged feed dict (the `kwargs` dict, updated in place).
        """
        coord   = product(range(actions.shape[0]), range(actions.shape[1]))
        feed    = {
            self.batch_size   : actions.shape[0],
            self.input_actions: actions,
            # Row-major [i, j, action] triples; the order must match the
            # reshape in _get_generated_(log_)probabilities
            self.coo_actions:   [[i, j, actions[i, j]] for i, j in coord],
            self.rerun:         True,
        }
        kwargs.update(feed)
        return kwargs


class GRUGenerator(RNNCellGenerator):
    """RNNCellGenerator whose cell is built from `rnn.GRUCell`.

    @n_stack: Forwarded to GeneratorRNNCellBuilder as `n_stack`
    @logit_range: Passed as `output_range` to the OutputRangeWrapper that is
        handed to the cell builder
    Remaining keyword arguments are forwarded to RNNCellGenerator.
    """
    def __init__(self, m, seq_len, name='gen', reuse=False, n_stack=1,
        logit_range=4.0, **kwargs):
        # Wrapper factories handed to the builder via `wrappers`
        wrapper_factories = [
            partial(OutputRangeWrapper, output_range=logit_range),
        ]
        gru_builder = GeneratorRNNCellBuilder(
            rnn.GRUCell,
            m=m,
            n_stack=n_stack,
            wrappers=wrapper_factories,
        )
        super(GRUGenerator, self).__init__(
            m,
            seq_len,
            cell_builder=gru_builder,
            name=name,
            reuse=reuse,
            **kwargs
        )


class LSTMGenerator(RNNCellGenerator):
    """RNNCellGenerator whose cell is built from `rnn.BasicLSTMCell`.

    @n_stack: Forwarded to GeneratorRNNCellBuilder as `n_stack`
    @logit_range: Passed as `output_range` to the OutputRangeWrapper that is
        handed to the cell builder (together with a custom `norm_op`)
    Remaining keyword arguments are forwarded to RNNCellGenerator.
    """
    def __init__(self, m, seq_len, name='gen', reuse=False, n_stack=1,
        logit_range=4.0, **kwargs):
        def shift_to_unit(x):
            # Affine map x -> (x + 1) / 2, e.g. sends [-1, 1] onto [0, 1]
            return 0.5 * (x + 1.)

        # Wrapper factories handed to the builder via `wrappers`
        wrapper_factories = [
            partial(OutputRangeWrapper, output_range=logit_range,
                    norm_op=shift_to_unit),
        ]
        lstm_builder = GeneratorRNNCellBuilder(
            rnn.BasicLSTMCell,
            m=m,
            n_stack=n_stack,
            wrappers=wrapper_factories,
        )
        super(LSTMGenerator, self).__init__(
            m,
            seq_len,
            cell_builder=lstm_builder,
            name=name,
            reuse=reuse,
            **kwargs
        )


class MeanFieldGenerator(RNNCellGenerator):
    """RNNCellGenerator driven by `mean_field_cell` via GeneratorCellBuilder.

    Extra keyword arguments are forwarded to RNNCellGenerator.
    """
    def __init__(self, m, seq_len, name='gen', reuse=False, **kwargs):
        mean_field_builder = GeneratorCellBuilder(mean_field_cell)
        super(MeanFieldGenerator, self).__init__(
            m,
            seq_len,
            cell_builder=mean_field_builder,
            name=name,
            reuse=reuse,
            **kwargs
        )