import tensorflow as tf
import numpy as np

EPS = 0.000001


def forward_tac_log(act_probs, stop_probs, T, inference=False):
    # Log-space forward pass. act_probs: Tx|l|, stop_probs: Tx|l|x2.
    T = tf.to_int32(T)
    lab_len = tf.shape(act_probs)[1]
    alpha = tf.zeros_like(act_probs[0, :], dtype=tf.float64)
    alpha = tf.concat([[tf.log(act_probs[0, 0])], alpha[1:]], 0)
    stop_probs = tf.log(stop_probs)  # in log space no normalisation is needed, in theory

    def scan_op(curr_alpha, i):
        # Zero entries of alpha mark unreachable states; map them to -inf in log space.
        non_stops = tf.where(tf.equal(tf.zeros_like(curr_alpha), curr_alpha),
                             -np.inf * tf.ones_like(curr_alpha),
                             curr_alpha + stop_probs[i, :, 0])
        stops = tf.where(tf.equal(tf.zeros_like(curr_alpha), curr_alpha),
                         -np.inf * tf.ones_like(curr_alpha),
                         curr_alpha + stop_probs[i, :, 1])
        cut_stop = tf.reduce_logsumexp([non_stops[1:], stops[:-1]], axis=0)
        all_stops = tf.concat([[non_stops[0]], cut_stop], 0)

        def time_mask_full():
            return tf.ones([lab_len], dtype=tf.float64)

        def time_mask_partial():
            # Zero out states that cannot finish the label in the remaining time.
            return tf.concat((tf.zeros([lab_len - T + i], dtype=tf.float64),
                              tf.ones([T - i], dtype=tf.float64)), 0)

        time_mask = tf.cond(i < T - lab_len + 1, time_mask_full, time_mask_partial)
        new_alpha = tf.where(tf.is_inf(all_stops),
                             tf.zeros_like(all_stops),
                             (all_stops + tf.log(act_probs[i, :])) * time_mask)
        return new_alpha

    irange = tf.range(1, T, dtype=tf.int32)  # only goes until T, so longer sequences are ignored
    alphas = tf.scan(scan_op, irange, initializer=alpha)
    if not inference:
        return alphas[-1, -1]  # scalar log likelihood of the full label
    else:
        return tf.concat(([alpha], alphas), 0)


def forward_tac_tf(act_probs, stop_probs, T, inference=False):
    # Scaled forward pass: alpha is renormalised at every step and the per-step
    # normalisers C are collected. act_probs: Tx|l|, stop_probs: Tx|l|x2.
    T = tf.to_int32(T)
    lab_len = tf.shape(act_probs)[1]
    alpha = tf.zeros_like(act_probs[0, :], dtype=tf.float64)
    alpha = tf.concat([[act_probs[0, 0]], alpha[1:]], 0)
    eye = tf.eye(lab_len, dtype=tf.float64)
    C = tf.reduce_sum(alpha)
    alpha = alpha / C

    def scan_op(inp, i):
        curr_alpha, curr_C = inp
        diag = curr_alpha * eye
        non_stops = tf.matmul(diag, tf.expand_dims(stop_probs[i, :, 0], 1))[:, 0]
        stops = tf.matmul(diag, tf.expand_dims(stop_probs[i, :, 1], 1))[:, 0]
        cut_stop = non_stops[1:] + stops[:-1]
        all_stops = tf.concat([[non_stops[0]], cut_stop], 0)

        def time_mask_full():
            return tf.ones([lab_len], dtype=tf.float64)

        def time_mask_partial():
            # Zero out states that cannot finish the label in the remaining time.
            return tf.concat((tf.zeros([lab_len - T + i], dtype=tf.float64),
                              tf.ones([T - i], dtype=tf.float64)), 0)

        time_mask = tf.cond(i < T - lab_len + 1, time_mask_full, time_mask_partial)
        new_alpha = all_stops * act_probs[i, :] * time_mask  # alternative for inference: all_stops * time_mask
        return new_alpha / tf.reduce_sum(new_alpha), tf.reduce_sum(new_alpha)

    irange = tf.range(1, T, dtype=tf.int32)  # only goes until T, so longer sequences are ignored
    alphas, Cs = tf.scan(scan_op, irange, initializer=(alpha, C))
    if not inference:
        return tf.concat(([C], Cs), 0)  # vector of per-step normalisers
    else:
        return tf.concat(([alpha], alphas), 0)
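
# The scan above implements a forward recursion over alignment states. As a
# sanity check, the same recursion can be written in plain NumPy. This is a
# minimal reference sketch added for illustration (not part of the original
# graph code); the helper name `_forward_reference_np` is an assumption.
def _forward_reference_np(act_probs, stop_probs, T):
    # act_probs: [T, L] numpy array, stop_probs: [T, L, 2] numpy array.
    # Returns the total log likelihood, i.e. sum(log C_t) over all steps,
    # mirroring sum(log(forward_tac_tf(...))) above.
    L = act_probs.shape[1]
    alpha = np.zeros(L, dtype=np.float64)
    alpha[0] = act_probs[0, 0]           # start in the first label state
    log_lik = np.log(alpha.sum())
    alpha /= alpha.sum()
    for i in range(1, T):
        non_stops = alpha * stop_probs[i, :, 0]  # stay in the same state
        stops = alpha * stop_probs[i, :, 1]      # advance to the next state
        all_stops = np.concatenate(([non_stops[0]], non_stops[1:] + stops[:-1]))
        mask = np.ones(L, dtype=np.float64)
        if i >= T - L + 1:
            mask[:L - (T - i)] = 0.0     # states that can no longer finish in time
        alpha = all_stops * act_probs[i, :] * mask
        log_lik += np.log(alpha.sum())
        alpha /= alpha.sum()             # rescale, as forward_tac_tf does
    return log_lik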
def tac_decode(action_probs, term_probs, targets, seq_len, tar_len):
    # For now a non-batch version.
    # T: length of trajectory. D: size of dictionary. l: length of label. B: batch size.
    # action_probs.shape = [B, max(seq_len), D]
    # term_probs.shape   = [B, max(seq_len), D, 2]
    # targets.shape      = [B, max(tar_len)]  (zero-padded label sequences)
    # seq_len: the actual length of each sequence.
    # tar_len: the actual length of each target sequence.
    # Because the loss was only implemented per example, the batch version is
    # simply a loop rather than a matrix operation.
    max_seq_len = tf.to_int32(tf.reduce_max(seq_len))
    bs = tf.to_int32(tf.shape(action_probs)[0])
    cond = lambda j, decoded: tf.less(j, bs)
    j = tf.constant(0, dtype=tf.int32)
    decoded = tf.zeros([1, max_seq_len], dtype=tf.int32)

    def body(j, decoded):
        idx = tf.expand_dims(targets[j, :tar_len[j]], 1)
        ac = tf.transpose(tf.gather_nd(tf.transpose(action_probs[j]), idx))
        st = tf.transpose(term_probs[j], (1, 0, 2))
        st = tf.transpose(tf.gather_nd(st, idx), (1, 0, 2))
        length = tf.to_int32(seq_len[j])
        # alphas is essentially the probability of being at each node
        alphas = forward_tac_tf(ac, st, length, inference=True)
        # decode by taking the argmax over label states at each time step
        dec = tf.to_int32(tf.argmax(alphas, axis=1))
        dec = tf.concat([dec, tf.zeros([max_seq_len - length], dtype=tf.int32)], axis=0)
        decoded = tf.concat([decoded, [dec]], axis=0)
        return tf.add(j, 1), decoded

    out = tf.while_loop(cond, body, loop_vars=[j, decoded],
                        shape_invariants=[tf.TensorShape(None), tf.TensorShape([None, None])])
    return out[1][1:]  # drop the dummy all-zeros row used to initialise the loop


def tac_loss(action_probs, term_probs, targets, seq_len, tar_len, safe=False):
    # For now a non-batch version; same shape conventions as tac_decode:
    # action_probs.shape = [B, max(seq_len), D]
    # term_probs.shape   = [B, max(seq_len), D, 2]
    # targets.shape      = [B, max(tar_len)]  (zero-padded label sequences)
    # seq_len: the actual length of each sequence.
    # tar_len: the actual length of each target sequence.
    # Because the loss was only implemented per example, the batch version is
    # simply a loop rather than a matrix operation.
    bs = tf.to_int32(tf.shape(action_probs)[0])
    cond = lambda j, loss: tf.less(j, bs)
    j = tf.constant(0, dtype=tf.int32)
    loss = tf.constant(0, dtype=tf.float64)

    def body(j, loss):
        idx = tf.expand_dims(targets[j, :tar_len[j]], 1)
        ac = tf.transpose(tf.gather_nd(tf.transpose(action_probs[j]), idx))
        st = tf.transpose(term_probs[j], (1, 0, 2))
        st = tf.transpose(tf.gather_nd(st, idx), (1, 0, 2))
        length = seq_len[j]
        if safe:
            # log-space forward pass: negative log likelihood
            loss += -forward_tac_log(ac, st, length) / tf.to_double(bs)
        else:
            # scaled forward pass: negative log likelihood for the whole batch
            loss += -tf.reduce_sum(tf.log(forward_tac_tf(ac, st, length)) / tf.to_double(bs))
        return tf.add(j, 1), loss  # loss is averaged over the batch

    out = tf.while_loop(cond, body, loop_vars=[j, loss])
    return out[1]
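
# A minimal usage sketch (not from the original code): how tac_loss could be
# minimised with a stock optimiser. The logits variables, the sizes, and the
# Adam learning rate below are illustrative assumptions only.
def _example_train_step_sketch():
    D, bs, T, L = 20, 4, 50, 5
    act_logits = tf.Variable(np.random.randn(bs, T, D))      # float64 variables
    term_logits = tf.Variable(np.random.randn(bs, T, D, 2))
    act_probs = tf.nn.softmax(act_logits)    # normalise over the dictionary axis
    term_probs = tf.nn.softmax(term_logits)  # normalise over the stop/non-stop axis
    targets = tf.convert_to_tensor(np.random.randint(0, D - 1, size=[bs, L]))
    seq_len = tf.convert_to_tensor(np.full(bs, T, dtype=np.int32))
    tar_len = tf.convert_to_tensor(np.full(bs, L, dtype=np.int32))
    loss = tac_loss(act_probs, term_probs, targets, seq_len, tar_len, safe=False)
    train_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(5):
            _, l = sess.run([train_op, loss])
            print(l)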
if __name__ == "__main__":
    def test_sum_log():
        x = tf.convert_to_tensor(-np.inf * np.ones(2))
        x_mask = tf.where(tf.is_inf(x), -np.inf * tf.ones_like(x), x)
        sum_log = tf.reduce_logsumexp(x)
        with tf.Session() as sess:
            mas = sess.run(x_mask)
            out = sess.run(sum_log)
            # print(mas)
            # print(out)

    def test_loss():
        # create some dummy inputs
        D = 20    # dictionary length
        bs = 100
        T = 1000  # trajectory length
        variable_lengths = np.random.randint(T - 10, T - 1, size=bs, dtype=np.int32)
        targ_lengths = np.random.randint(3, 6, size=bs, dtype=np.int32)
        T_arr = tf.convert_to_tensor(variable_lengths)
        targ_len = tf.convert_to_tensor(targ_lengths)
        p_stop = np.array(np.random.uniform(0, 1, size=[bs, T, D]), dtype=np.float64)
        p_n_stop = 1 - p_stop
        prob = tf.convert_to_tensor(np.array(np.stack((p_stop, p_n_stop), axis=3), dtype=np.float64))
        # tiny action probabilities to stress numerical underflow in the unscaled path
        action_probs = tf.convert_to_tensor(np.array(np.random.uniform(0, 1e-70, size=[bs, T, D]), dtype=np.float64))
        targets = tf.convert_to_tensor(np.random.randint(0, D - 1, size=[bs, np.amax(targ_lengths)]))
        import time
        with tf.Session() as sess:
            tic = time.time()
            out1 = sess.run(tac_loss(action_probs, prob, targets, T_arr, targ_len, safe=True))
            # print("SAFE", time.time() - tic)
            tic = time.time()
            out2 = sess.run(tac_loss(action_probs, prob, targets, T_arr, targ_len, safe=False))
            # print("Classic", time.time() - tic)

    def test_decode():
        # create some dummy inputs
        D = 20  # dictionary length
        bs = 100
        T = 40  # trajectory length
        variable_lengths = np.random.randint(10, 39, size=bs)
        targ_lengths = np.random.randint(10, 20, size=bs, dtype=np.int32)
        T_arr = tf.convert_to_tensor(variable_lengths)
        targ_len = tf.convert_to_tensor(targ_lengths)
        p_stop = np.array(np.random.uniform(0, 1, size=[bs, T, D]), dtype=np.float64)
        p_n_stop = 1 - p_stop
        prob = tf.convert_to_tensor(np.array(np.stack((p_stop, p_n_stop), axis=3), dtype=np.float64))
        action_probs = tf.convert_to_tensor(np.array(np.random.uniform(0, 1, size=[bs, T, D]), dtype=np.float64))
        targets = tf.convert_to_tensor(np.random.randint(0, D - 1, size=[bs, np.amax(targ_lengths)]))
        with tf.Session() as sess:
            out = sess.run(tac_decode(action_probs, prob, targets, T_arr, targ_len))

    test_loss()
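
    # Hedged cross-check (added for illustration, not in the original file):
    # compares sum(log C_t) from forward_tac_tf against the pure-NumPy
    # reference sketch _forward_reference_np defined above.
    def test_forward_against_numpy():
        L, T = 4, 12
        ac = np.random.uniform(0.1, 1.0, size=[T, L])
        st = np.random.uniform(0.1, 0.9, size=[T, L])
        stop = np.stack((st, 1 - st), axis=2)  # [T, L, 2], same layout as the tests above
        ref = _forward_reference_np(ac, stop, T)
        Cs = forward_tac_tf(tf.convert_to_tensor(ac), tf.convert_to_tensor(stop), T)
        with tf.Session() as sess:
            tf_ll = np.sum(np.log(sess.run(Cs)))
        assert np.allclose(ref, tf_ll)

    # test_forward_against_numpy()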