import numpy as np def softmax(x): """Compute softmax values for each sets of scores in x.""" e_x = np.exp(x - np.max(x)) return e_x / e_x.sum() def reduce_mul(l): out = 1.0 for x in l: out *= x return out def decode_step(step, encoder_logits): words_prob = encoder_logits[step] words_prob = softmax(words_prob) ouput_step = [(idx, prob) for idx, prob in enumerate(words_prob)] ouput_step = sorted(ouput_step, key=lambda x: x[1], reverse=True) return ouput_step def check_exceed(seq, mask, max_count): count = 0 for i, word in enumerate(seq): sign_index = word[0] if sign_index > 0 and mask[i]: count += 1 if count > max_count: return False else: return True def beam_search_step(step, encoder_logits, top_seqs, mask, k, max_count): all_seqs = [] for seq in top_seqs: seq_score = reduce_mul([_score for _, _score in seq]) # get current step using encoder_context & seq current_step = decode_step(step, encoder_logits) for i, word in enumerate(current_step): if i >= k: break word_index, word_score = word score = seq_score * word_score rs_seq = seq + [word] all_seqs.append((rs_seq, score)) all_seqs = sorted(all_seqs, key=lambda seq: seq[1], reverse=True) # Expression constraint filtered_seqs = [seq for seq, _ in all_seqs if check_exceed(seq, mask, max_count)] # topk_seqs = [seq for seq, _ in all_seqs[:k]] topk_seqs = [seq for seq in filtered_seqs[:k]] return topk_seqs def beam_search(encoder_logits, mask, beam_size, max_count): max_len = sum(mask) # START top_seqs = [[(0, 1.0)]] # loop for i in range(1, max_len + 1): top_seqs = beam_search_step(i, encoder_logits, top_seqs, mask, beam_size, max_count) number_indices_list, sign_indices_list, scores_list = [], [], [] for seq in top_seqs: number_indices, sign_indices = [], [] for i, word in enumerate(seq): sign_index, score = word if sign_index > 0 and mask[i]: number_indices.append(i) sign_indices.append(sign_index) if number_indices == [] and sign_indices == []: continue number_indices_list.append(number_indices) sign_indices_list.append(sign_indices) seq_score = reduce_mul([_score for _, _score in seq]) scores_list.append(seq_score) if scores_list != []: scores_list = softmax(np.array(scores_list)) return number_indices_list, sign_indices_list, scores_list.tolist()