import numpy as np


class CTCPrefixScore():
    ''' CTC Prefix score calculator
    An implementation of Algo. 2 in https://www.merl.com/publications/docs/TR2017-190.pdf (Watanabe et. al.)
    Reference (official implementation): https://github.com/espnet/espnet/tree/master/espnet/nets

    All scores are kept in the log domain (they are combined with
    np.logaddexp and addition); self.logzero stands in for log(0).

    State arrays ("r") hold two scores per time step:
        index 0 = paths ending in a non-blank at that step
        index 1 = paths ending in a blank at that step
    '''

    def __init__(self, x):
        '''Store the frame-wise scores of a single utterance.

        x must provide .cpu().numpy() and .shape; only the first entry along
        its leading axis is kept.
        NOTE(review): presumably a torch tensor of log-probabilities shaped
        (1, input_length, odim) - confirm against the caller.
        '''
        self.logzero = -100000000.0
        self.blank = 0
        self.eos = 1
        self.x = x.cpu().numpy()[0]
        self.odim = x.shape[-1]
        self.input_length = len(self.x)

    def init_state(self):
        '''Return the initial state for an empty prefix, shape (input_length, 2).

        The non-blank column is logzero everywhere; the blank column holds the
        running sum of the blank scores up to and including each step.
        '''
        # 0 = non-blank, 1 = blank
        r = np.full((self.input_length, 2), self.logzero, dtype=np.float32)

        # Accumulate blank at each step
        r[0, 1] = self.x[0, self.blank]
        for i in range(1, self.input_length):
            r[i, 1] = r[i-1, 1] + self.x[i, self.blank]
        return r

    def full_compute(self, g, r_prev):
        '''Given prefix g, return the probability of all possible sequence y (where y = concat(g,c))
        This function computes all possible tokens for c (memory inefficient)

        Args:
            g: token sequence of the current prefix (len() and g[-1] are used).
            r_prev: state of prefix g, indexed as r_prev[t, 0 or 1].
        Returns:
            (psi, r): psi has shape (odim,) with one score per extension token;
            r is the new state with shape (odim, input_length, 2).
        '''
        prefix_length = len(g)
        last_char = g[-1] if prefix_length > 0 else 0

        # init. r
        r = np.full((self.input_length, 2, self.odim),
                    self.logzero, dtype=np.float32)

        # start from len(g) because is impossible for CTC to generate |y|>|X|
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, :]    # if g = <sos>

        # Copy so the returned psi never aliases the returned state when the
        # loop below does not execute (start >= input_length).
        psi = r[start-1, 0, :].copy()

        for t in range(start, self.input_length):
            # prev_blank
            prev_blank = np.full(
                (self.odim), r_prev[t-1, 1], dtype=np.float32)
            # prev_nonblank; repeating the prefix's last token requires a
            # blank in between, so its non-blank predecessor is masked out
            prev_nonblank = np.full(
                (self.odim), r_prev[t-1, 0], dtype=np.float32)
            prev_nonblank[last_char] = self.logzero

            phi = np.logaddexp(prev_nonblank, prev_blank)
            # P(h|current step is non-blank) = [P(prev. step = y) + phi]*P(c)
            r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi) + self.x[t, :]
            # P(h|current step is blank) = [P(prev. step is blank) + P(prev. step is non-blank)]*P(now=blank)
            r[t, 1, :] = np.logaddexp(
                r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi+self.x[t, :])

        # NOTE: unlike cheap_compute, the <eos> entry of psi is not replaced
        # by the prefix's own end-of-input score here.
        return psi, np.rollaxis(r, 2)

    def cheap_compute(self, g, r_prev, candidates):
        '''Given prefix g, return the probability of all possible sequence y (where y = concat(g,c))
        This function considers only those tokens in candidates for c (memory efficient)

        Args:
            g: token sequence of the current prefix (len() and g[-1] are used).
            r_prev: state of prefix g, shape (input_length, 2).
            candidates: token ids to score; any iterable of ints is accepted.
        Returns:
            (psi, r): psi has shape (len(candidates),); r is the new state
            with shape (len(candidates), input_length, 2).
        '''
        # .index() is needed below, so normalise to a list
        candidates = list(candidates)
        prefix_length = len(g)
        odim = len(candidates)
        last_char = g[-1] if prefix_length > 0 else 0

        # init. r
        r = np.full((self.input_length, 2, odim),
                    self.logzero, dtype=np.float32)

        # start from len(g) because is impossible for CTC to generate |y|>|X|
        start = max(1, prefix_length)

        if prefix_length == 0:
            r[0, 0, :] = self.x[0, candidates]    # if g = <sos>

        # Copy: psi is written in place for <eos> below, and without the copy
        # that write would land in r whenever the loop does not execute.
        psi = r[start-1, 0, :].copy()

        # Phi = (prev_nonblank,prev_blank)
        sum_prev = np.logaddexp(r_prev[:, 0], r_prev[:, 1])
        phi = np.repeat(sum_prev[..., None], odim, axis=-1)
        # Handle edge case : last tok of prefix in candidates
        if prefix_length > 0 and last_char in candidates:
            phi[:, candidates.index(last_char)] = r_prev[:, 1]

        for t in range(start, self.input_length):
            x_t = self.x[t, candidates]
            # P(h|current step is non-blank) = [P(prev. step = y) + phi]*P(c)
            r[t, 0, :] = np.logaddexp(r[t-1, 0, :], phi[t-1]) + x_t
            # P(h|current step is blank) = [P(prev. step is blank) + P(prev. step is non-blank)]*P(now=blank)
            r[t, 1, :] = np.logaddexp(
                r[t-1, 1, :], r[t-1, 0, :]) + self.x[t, self.blank]
            psi = np.logaddexp(psi, phi[t-1] + x_t)

        # P(end of sentence) = P(g)
        if self.eos in candidates:
            psi[candidates.index(self.eos)] = sum_prev[-1]
        return psi, np.rollaxis(r, 2)