# from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from keras import backend as K

from WrappedGRU import WrappedGRU
from helpers import compute_mask, softmax

class QuestionAttnGRU(WrappedGRU):
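    """Gated attention-based recurrent layer (question-passage matching).

    For every passage word u_P_t, the step attends over the question
    encoding u_Q, gates the concatenated [u_P_t; c_t] vector, and feeds it
    through the wrapped GRU, producing a question-aware passage
    representation (v_P in the R-Net formulation this layer appears to follow).
    """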

    def build(self, input_shape):
        H = self.units
        assert(isinstance(input_shape, list))
        
        nb_inputs = len(input_shape)
        assert(nb_inputs >= 2)

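        # input_shape[0]: passage encoding u_P, shape (B, P, 2 * H)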
        assert(len(input_shape[0]) == 3)
        B, P, H_ = input_shape[0]
        assert(H_ == 2 * H)

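        # input_shape[1]: question encoding u_Q, shape (B, Q, 2 * H)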
        assert(len(input_shape[1]) == 3)
        B, Q, H_ = input_shape[1]
        assert(H_ == 2 * H)

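        # The wrapped GRU consumes the gated concatenation [u_P_t; c_t],
        # so it is built for an input dimension of 2 * H + 2 * H = 4 * H.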
        self.input_spec = [None]
        super(QuestionAttnGRU, self).build(input_shape=(B, P, 4 * H))
        self.GRU_input_spec = self.input_spec
        self.input_spec = [None] * nb_inputs

    def step(self, inputs, states):
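        """Single recurrence step of the gated attention GRU.

        The computation below corresponds to:
            s_j   = v^T tanh(WQ_u u^Q_j + WP_u u^P_t + WP_v v^P_{t-1})
            a     = softmax(s, mask=uQ_mask)
            c_t   = sum_j a_j u^Q_j
            g     = sigmoid(W_g1 [u^P_t; c_t])
            v^P_t = GRU(v^P_{t-1}, g * [u^P_t; c_t])
        """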
        uP_t = inputs
        vP_tm1 = states[0]
        _ = states[1:3] # ignore internal dropout/masks
        uQ, WQ_u, WP_v, WP_u, v, W_g1 = states[3:9]
        uQ_mask, = states[9:10]

        # Project the question encoding, the previous hidden state and the
        # current passage word into the attention space; the latter two
        # broadcast over the question length in the sum below.
        WQ_u_Dot = K.dot(uQ, WQ_u)
        WP_v_Dot = K.dot(K.expand_dims(vP_tm1, axis=1), WP_v)
        WP_u_Dot = K.dot(K.expand_dims(uP_t, axis=1), WP_u)

        s_t_hat = K.tanh(WQ_u_Dot + WP_v_Dot + WP_u_Dot)

        s_t = K.dot(s_t_hat, v)                    # attention scores over question words
        s_t = K.batch_flatten(s_t)                 # (B, Q)
        a_t = softmax(s_t, mask=uQ_mask, axis=1)   # masked attention weights
        c_t = K.batch_dot(a_t, uQ, axes=[1, 1])    # attended question context c_t

        # Gate the concatenated [u_P_t; c_t] vector before it enters the GRU.
        GRU_inputs = K.concatenate([uP_t, c_t])
        g = K.sigmoid(K.dot(GRU_inputs, W_g1))
        GRU_inputs = g * GRU_inputs
        vP_t, s = super(QuestionAttnGRU, self).step(GRU_inputs, states)

        return vP_t, s
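
# A minimal usage sketch, assuming (as the states unpacked in step() suggest)
# that WrappedGRU forwards every input after the first one to step() as
# constants, and that WQ_u, WP_v, WP_u, v and W_g1 are shared weight tensors
# defined elsewhere in the model; the question mask uQ_mask is assumed to be
# derived from uQ's Keras mask:
#
#     vP = QuestionAttnGRU(units=H, return_sequences=True)(
#         [uP, uQ, WQ_u, WP_v, WP_u, v, W_g1])
#
# Here uP is the passage encoding (B, P, 2 * H), uQ the question encoding
# (B, Q, 2 * H), and vP the resulting question-aware passage representation.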