import math


def log_sum_exp(a, b):
    """
    Stable log sum exp.

    Returns ``log(exp(a) + exp(b))`` without exponentiating the larger
    argument. If the larger argument is infinite it is returned directly,
    so ``log_sum_exp(-inf, -inf)`` is ``-inf`` rather than ``nan``.
    """
    hi = max(a, b)
    if math.isinf(hi):
        # With two equal infinities ``a - b`` is nan; the answer is simply
        # the larger argument in every infinite case.
        return hi
    return hi + math.log1p(math.exp(-abs(a - b)))


def decode_static(log_probs, beam_size=1, blank=0):
    """
    Decode best prefix in the RNN Transducer. This decoder is static, it does
    not update the next step distribution based on the previous prediction.

    Args:
        log_probs: 3-D array with a ``.shape`` of ``(T, U, V)`` indexed as
            ``log_probs[t, u, v]``: the log-score of emitting symbol ``v``
            at frame ``t`` after ``u`` labels have been emitted.
        beam_size: number of hypotheses kept after every step.
        blank: index (along the last axis) of the blank symbol.

    Returns:
        A tuple ``(hyp, score)``. ``hyp`` is a tuple of ``U - 1`` non-blank
        symbol indices, and ``score`` is the log-sum over the explored
        alignments of ``hyp``, including the final blank at
        ``log_probs[-1, -1, blank]``.
    """
    T, U, V = log_probs.shape
    beam = [((), 0)]

    # Every step consumes either one frame (blank) or one label, so a full
    # alignment short of the final blank takes (T - 1) + (U - 1) steps.
    for i in range(T + U - 2):
        new_beam = {}
        for hyp, score in beam:
            u = len(hyp)
            t = i - u  # frames consumed so far by this hypothesis
            for v in range(V):
                if v == blank:
                    if t >= T - 1:
                        # Already on the last frame: the only blank left is
                        # the terminal one added after the loop. Skip it
                        # here so we never fall through with stale
                        # ``new_hyp``/``new_score`` from an earlier
                        # iteration.
                        continue
                    new_hyp = hyp
                elif u < U - 1:
                    new_hyp = hyp + (v,)
                else:
                    # Hypothesis already has all U - 1 labels.
                    continue

                new_score = score + log_probs[t, u, v]

                # Different alignments of the same label sequence are
                # merged by summing their probabilities in log space.
                old_score = new_beam.get(new_hyp)
                if old_score is None:
                    new_beam[new_hyp] = new_score
                else:
                    new_beam[new_hyp] = log_sum_exp(old_score, new_score)

        ranked = sorted(new_beam.items(), key=lambda x: x[1], reverse=True)
        beam = ranked[:beam_size]

    hyp, score = beam[0]
    return hyp, score + log_probs[-1, -1, blank]


if __name__ == "__main__":
    # Self-check: the beam score of the best hypothesis should match the
    # reference forward pass for that same label sequence.
    import transducer.ref_transduce as rt
    import numpy as np

    np.random.seed(10)
    T = 10
    U = 5
    V = 5
    blank = V - 1
    beam_size = 500
    log_probs = np.random.randn(T, U, V)
    log_probs = rt.log_softmax(log_probs, axis=2)
    labels, beam_ll = decode_static(log_probs, beam_size, blank)
    _, ll = rt.forward_pass(log_probs, labels, blank)
    assert np.allclose(ll, beam_ll, rtol=1e-9, atol=1e-9), \
        "Bad result from beam search."