import torch as th


def _prepare_model(model, device):
    """Resolve the decoding device and move the model there if needed.

    If ``device`` is None, decode on the device the model's parameters
    already live on; otherwise move the model to ``device``.

    Returns the (possibly moved) model and the resolved device.
    """
    if device is None:
        # next() avoids materializing the whole parameter list
        device = next(model.parameters()).device
    else:
        model = model.to(device)
    return model, device


def sample(model, src_tokens, temperature=1.0, max_len=200, device=None):
    """Sample a target sequence from the model, one token at a time.

    Args:
        model: seq2seq model exposing ``encode``, ``initial_state``,
            ``decode_step`` and a ``vocab`` mapping containing "<sos>"
            and "<eos>".
        src_tokens: list of source token indices.
        temperature: softmax temperature applied to the log-probabilities;
            1.0 samples from the model distribution, higher is flatter.
        max_len: maximum number of generated tokens.
        device: optional device to decode on (the model is moved there).

    Returns:
        List of generated token indices, without "<sos>" and "<eos>".
    """
    model, device = _prepare_model(model, device)
    # Go into eval mode (e.g. disable dropout)
    model.eval()
    # No gradients needed at inference time
    with th.no_grad():
        # Encode source sentence (column vector: batch size 1)
        src_tensor = th.LongTensor(src_tokens).to(device).view(-1, 1)
        encodings = model.encode(src_tensor)
        # Initialize decoder state
        state = model.initial_state()
        # Start decoding from <sos>; stop at <eos> or after max_len steps
        out_tokens = [model.vocab["<sos>"]]
        eos_token = model.vocab["<eos>"]
        while out_tokens[-1] != eos_token and len(out_tokens) <= max_len:
            current_token = th.LongTensor([out_tokens[-1]]).view(1, 1).to(device)
            # One step of the decoder
            log_p, state = model.decode_step(current_token, encodings, state)
            # Temperature-scaled probabilities; th.multinomial does not
            # require them to be normalized
            probs = th.exp(log_p / temperature).view(-1)
            # Sample the next token
            next_token = th.multinomial(probs, 1).item()
            # Add to the generated sentence
            out_tokens.append(next_token)
    # Return generated token (idxs) without <sos> and <eos>
    out_tokens = out_tokens[1:]
    # Guard the empty case (e.g. max_len=0): nothing was generated
    if out_tokens and out_tokens[-1] == eos_token:
        out_tokens = out_tokens[:-1]
    return out_tokens


def greedy(model, src_tokens, max_len=200, device=None):
    """Greedy decoding: pick the most likely token at each step.

    Arguments and return value mirror :func:`sample` (no temperature).
    """
    model, device = _prepare_model(model, device)
    # TODO 3: implement greedy decoding
    #
    # (hint: the implementation is very similar to sampling)
    raise NotImplementedError("TODO 3")


def beam_search(
    model,
    src_tokens,
    beam_size=1,
    max_len=200,
    device=None
):
    """Beam-search decoding keeping ``beam_size`` hypotheses per step.

    Arguments and return value mirror :func:`sample`.
    """
    model, device = _prepare_model(model, device)
    # TODO 4: implement beam search
    # Hints:
    # - For each beam you need to keep track of at least:
    #   1. The previously generated tokens
    #   2. The decoder state
    #   3. The score (log probability of the generated tokens)
    # - Be careful of how many decoding steps you need to perform at each step
    # - Think carefully of the stopping criterion (there are 2)
    # - As a sanity check you can check that setting beam_size to 1 returns
    #   the same result as greedy decoding
    raise NotImplementedError("TODO 4")