import tensorflow as tf
import numpy as np

EPS = 1e-20
def forward_ngctc(pol_probs, T, inference=False):
    # pol_probs are the policy stop probabilities: similar to per-policy stop
    # actions, but produced through a global controller. Here the stop action
    # can be read as the probability of advancing to the next policy in the
    # sequence.
    T = tf.to_int32(T)
    # pol_probs has shape [T, |l|]: one probability per timestep and label
    # position.
    lab_len = tf.shape(pol_probs)[1]
    # alpha[k] is the normalized forward probability of being at label
    # position k; all mass starts at the first position.
    alpha = tf.zeros([lab_len], dtype=tf.float64)
    alpha = tf.concat([[pol_probs[0, 0]], alpha[1:]], 0)
    C = tf.reduce_sum(alpha) + EPS  # per-step normalizer; EPS keeps log(C) finite
    alpha = alpha / C


    def scan_op(inp, i):
        curr_alpha, _ = inp  # the previous normalizer is not used in the recursion
        # Stay at label position k: alpha[k] * p[i, k].
        non_stops = curr_alpha * pol_probs[i, :]
        # Advance from position k-1 to k: alpha[k-1] * p[i, k].
        stops = curr_alpha[:-1] * pol_probs[i, 1:]
        cut_stop = non_stops[1:] + stops
        # Forward recursion: new_alpha[k] = (alpha[k] + alpha[k-1]) * p[i, k].
        all_stops = tf.concat([[non_stops[0]], cut_stop], 0)

        def time_mask_full():
            return tf.ones([lab_len], dtype=tf.float64)

        # Zero out positions that cannot be completed in the remaining T - i
        # timesteps: position k is reachable only if lab_len - k <= T - i.
        def time_mask_partial():
            return tf.concat((tf.zeros([lab_len - T + i], dtype=tf.float64),
                              tf.ones([T - i], dtype=tf.float64)), 0)

        time_mask = tf.cond(i < T - lab_len + 1, time_mask_full, time_mask_partial)
        new_alpha = all_stops * time_mask
        # Renormalize and emit the normalizer; EPS guards the division and the
        # downstream log against an all-zero alpha.
        new_C = tf.reduce_sum(new_alpha) + EPS
        return new_alpha / new_C, new_C

    # Scan over timesteps 1..T-1 (timestep 0 was handled above); anything
    # beyond T timesteps is ignored, so longer inputs are truncated.
    irange = tf.range(1, T, dtype=tf.int32)
    alphas, Cs = tf.scan(scan_op, irange, initializer=(alpha, C))
    if not inference:
        return tf.concat(([C], Cs), 0)  # length-T vector of per-step normalizers
    else:
        return tf.concat(([alpha], alphas), 0)  # [T, |l|] forward probabilities
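
# For sanity-checking forward_ngctc on small inputs, here is a minimal NumPy
# sketch of the same scaled forward recursion as I read it above. The helper
# name np_forward_ngctc is hypothetical, not part of the original module; it
# returns the log likelihood, i.e. the sum of the log normalizers.
def np_forward_ngctc(pol_probs, T):
    L = pol_probs.shape[1]
    alpha = np.zeros(L)
    alpha[0] = pol_probs[0, 0]
    C = alpha.sum() + EPS
    log_lik = np.log(C)
    alpha /= C
    for t in range(1, T):
        # new_alpha[k] = (alpha[k] + alpha[k-1]) * p[t, k]
        shifted = np.concatenate(([0.0], alpha[:-1]))
        alpha = (alpha + shifted) * pol_probs[t]
        # mask positions that cannot finish in the remaining T - t timesteps
        if t >= T - L + 1:
            alpha[:L - T + t] = 0.0
        C = alpha.sum() + EPS
        log_lik += np.log(C)
        alpha /= C
    return log_lik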


def ngctc_loss(term_probs, targets, seq_len, tar_len):
    bs = tf.to_int32(tf.shape(term_probs)[0])
    cond = lambda j, loss: tf.less(j, bs)
    j = tf.constant(0, dtype=tf.int32)
    loss = tf.constant(0, dtype=tf.float64)
    def body(j, loss):
        # Gather the columns of term_probs[j] that correspond to the target
        # labels, giving a [seq_len, tar_len] matrix of per-step probabilities.
        idx = tf.expand_dims(targets[j, :tar_len[j]], 1)
        st = tf.transpose(term_probs[j], (1, 0))
        st = tf.transpose(tf.gather_nd(st, idx), (1, 0))
        length = seq_len[j]
        # Negative log likelihood for this example, averaged over the batch.
        loss += -tf.reduce_sum(tf.log(forward_ngctc(st, length))) / tf.to_double(bs)
        return tf.add(j, 1), loss

    out = tf.while_loop(cond, body, loop_vars=[j, loss])

    return out[1]
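
# A minimal training-step sketch, assuming the probabilities come from a
# trainable model; here a bare logits Variable stands in for a real network.
# example_train_step and its arguments are illustrative, not part of the
# original module.
def example_train_step(targets, seq_len, tar_len, bs=4, T=50, D=20):
    logits = tf.Variable(tf.zeros([bs, T, D], dtype=tf.float64))
    term_probs = tf.nn.softmax(logits, axis=2)  # [batch, time, labels]
    loss_op = ngctc_loss(term_probs, targets, seq_len, tar_len)
    train_op = tf.train.AdamOptimizer(1e-2).minimize(loss_op)
    return loss_op, train_op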


def ngctc_decode(term_probs, targets, seq_len, tar_len):
    max_seq_len = tf.to_int32(tf.reduce_max(seq_len))
    bs = tf.to_int32(tf.shape(term_probs)[0])
    cond = lambda j, decoded: tf.less(j, bs)
    j = tf.constant(0, dtype=tf.int32)
    # Dummy first row so the concat in the loop body has a rank-2 tensor to
    # grow; it is stripped before returning.
    decoded = tf.zeros([1, max_seq_len], dtype=tf.int32)
    def body(j, decoded):
        idx = tf.expand_dims(targets[j, :tar_len[j]], 1)
        st = tf.transpose(term_probs[j], (1, 0))
        st = tf.transpose(tf.gather_nd(st, idx), (1, 0))
        length = tf.to_int32(seq_len[j])
        # Forward probabilities of being at each label position per timestep.
        alphas = forward_ngctc(st, length, inference=True)
        # Decode by taking the most likely label position at each timestep,
        # then pad to the batch-wide maximum sequence length.
        dec = tf.to_int32(tf.argmax(alphas, axis=1))
        dec = tf.concat([dec, tf.zeros([max_seq_len - length], dtype=tf.int32)], axis=0)

        decoded = tf.concat([decoded, [dec]], axis=0)

        return tf.add(j, 1), decoded

    out = tf.while_loop(cond, body, loop_vars=[j, decoded],
                        shape_invariants=[tf.TensorShape(None), tf.TensorShape([None, None])])
    return out[1][1:]  # drop the dummy first row
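
# ngctc_decode returns per-timestep *positions* into each target sequence, not
# label ids. A small hypothetical helper (assuming NumPy arrays) to map the
# decoded positions back to label ids:
def positions_to_labels(decoded, targets):
    # decoded: [bs, max_seq_len] int positions; targets: [bs, max_tar_len] ids.
    return np.take_along_axis(targets, decoded, axis=1)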



if __name__=="__main__":
    def softmax(x,axis = 0):
        """Compute softmax values for each sets of scores in x."""
        e_x = np.exp(x - np.expand_dims(np.max(x,axis=axis),axis=axis))
        return e_x / np.expand_dims(e_x.sum(axis=axis),axis=axis)

    def test_loss():
        # create some dummy inputs
        D = 20  # dictionary length
        bs = 100
        T = 500  # maximum number of timesteps
        variable_lengths = np.random.randint(100, 390, size=bs)
        targ_lengths = np.random.randint(5, 15, size=bs, dtype=np.int32)
        T_arr = tf.convert_to_tensor(variable_lengths)
        targ_len = tf.convert_to_tensor(targ_lengths)


        # forward_ngctc builds float64 ops, so the probabilities must be
        # float64 as well.
        p_stop = np.array(np.random.uniform(0, 1, size=[bs, T, D]), dtype=np.float64)
        p_stop = softmax(p_stop, axis=2)
        prob = tf.convert_to_tensor(p_stop)


        targets = tf.convert_to_tensor(np.random.randint(0,D-1,size = [bs,np.amax(targ_lengths)]))



        with tf.Session() as sess:
            out = sess.run(ngctc_loss(prob, targets, T_arr, targ_len))
        print("mean batch NLL:", out)
    test_loss()
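
    # A companion decoding sketch (hypothetical, mirroring test_loss) that
    # exercises ngctc_decode end to end on small random inputs.
    def test_decode():
        D = 20
        bs = 4
        T = 50
        variable_lengths = np.random.randint(20, 40, size=bs)
        targ_lengths = np.random.randint(5, 15, size=bs, dtype=np.int32)
        p_stop = softmax(np.random.uniform(0, 1, size=[bs, T, D]).astype(np.float64), axis=2)
        prob = tf.convert_to_tensor(p_stop)
        targets = tf.convert_to_tensor(np.random.randint(0, D, size=[bs, np.amax(targ_lengths)]))
        with tf.Session() as sess:
            decoded = sess.run(ngctc_decode(prob, targets,
                                            tf.convert_to_tensor(variable_lengths),
                                            tf.convert_to_tensor(targ_lengths)))
        print("decoded label positions:\n", decoded)

    test_decode()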