Python torch.multinomial() Examples

The following are 30 code examples of torch.multinomial(), drawn from open-source projects; the source file, project, and license for each example are listed above it.
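Before the project examples, a quick orientation: torch.multinomial(input, num_samples, replacement=False) draws num_samples indices from each row of input, treating the row values as non-negative, not necessarily normalized sampling weights. A minimal illustration:

import torch

weights = torch.tensor([0.1, 0.2, 0.3, 0.4])   # weights need not sum to 1

# without replacement: each index can be drawn at most once
idx = torch.multinomial(weights, 2)             # e.g. tensor([3, 1])

# with replacement: indices are drawn independently from the same weights
idx_rep = torch.multinomial(weights, 10, replacement=True)

# with a 2-D input, every row is treated as an independent distribution
batch_weights = torch.rand(5, 8)                # 5 distributions over 8 categories
batch_idx = torch.multinomial(batch_weights, 3, replacement=True)  # shape (5, 3)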
Example #1
Source File: crossentropyloss.py    From backpack with MIT License
def _sqrt_hessian_sampled(self, module, g_inp, g_out, mc_samples=1):
        self._check_2nd_order_parameters(module)

        M = mc_samples
        C = module.input0.shape[1]

        probs = self._get_probs(module)
        V_dim = 0
        probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1)

        multi = multinomial(probs, M, replacement=True)
        classes = one_hot(multi, num_classes=C)
        classes = einsum("nvc->vnc", classes).float()

        sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M)

        if module.reduction == "mean":
            N = module.input0.shape[0]
            sqrt_mc_h /= sqrt(N)

        return sqrt_mc_h 
Example #2
Source File: decoder.py    From A-Hierarchical-Latent-Structure-for-Variational-Conversation-Modeling with MIT License
def decode(self, out):
        """
        Args:
            out: unnormalized word distribution [batch_size, vocab_size]
        Return:
            x: word_index [batch_size]
        """

        # Sample next word from multinomial word distribution
        if self.sample:
            # x: [batch_size] - word index (next input)
            x = torch.multinomial(self.softmax(out / self.temperature), 1).view(-1)

        # Greedy sampling
        else:
            # x: [batch_size] - word index (next input)
            _, x = out.max(dim=1)
        return x 
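Example #2 switches between temperature-controlled multinomial sampling and greedy argmax decoding. A standalone sketch of the same pattern, outside any class (the function name and the [batch_size, vocab_size] logits tensor are illustrative assumptions):

import torch
import torch.nn.functional as F

def next_token(logits, temperature=1.0, sample=True):
    # logits: [batch_size, vocab_size] unnormalized scores
    if sample:
        probs = F.softmax(logits / temperature, dim=-1)  # higher temperature -> flatter distribution
        return torch.multinomial(probs, num_samples=1).view(-1)  # [batch_size]
    # greedy: take the highest-scoring token
    return logits.argmax(dim=-1)  # [batch_size]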
Example #3
Source File: modeling_transfo_xl_utilities.py    From Bert-Chinese-Text-Classification-Pytorch with MIT License
def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples 
Example #4
Source File: torch_generator_agent.py    From ParlAI with MIT License
def select_paths(self, logprobs, prior_scores, current_length):
        # Unlike the other treesearch methods, we have to switch to linspace
        # for the probabilities in order to compute the CDF.
        probs = torch.softmax(logprobs, dim=-1)
        sprobs, sinds = probs.sort(dim=-1, descending=True)
        # The subtraction here is to get the exclusive prefix sum,
        # to guarantee the first element is not masked
        mask = (sprobs.cumsum(dim=-1) - sprobs) >= self.p
        sprobs[mask] = 0
        sprobs.div_(sprobs.sum(dim=-1).unsqueeze(1))
        choices = torch.multinomial(sprobs, 1)[:, 0]
        hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
        tok_ids = sinds[hyp_ids, choices]
        # Convert back to logspace.
        scores = sprobs[hyp_ids, choices].log()
        best_scores = prior_scores.expand_as(scores) + scores
        return (hyp_ids, tok_ids, best_scores) 
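Example #4 is ParlAI's nucleus (top-p) sampling: sort the probabilities, mask everything past cumulative mass p using an exclusive prefix sum (so the single most likely token always survives), renormalize, and draw with torch.multinomial. A minimal standalone sketch of the same procedure; the function name and the p=0.9 default are illustrative, not from ParlAI:

import torch

def nucleus_sample(logprobs, p=0.9):
    # logprobs: [batch, vocab] log-probabilities
    probs = torch.softmax(logprobs, dim=-1)
    sprobs, sinds = probs.sort(dim=-1, descending=True)
    # exclusive prefix sum, so the top token is never masked out
    mask = (sprobs.cumsum(dim=-1) - sprobs) >= p
    sprobs = sprobs.masked_fill(mask, 0.0)
    sprobs = sprobs / sprobs.sum(dim=-1, keepdim=True)  # renormalize the surviving mass
    choice = torch.multinomial(sprobs, num_samples=1)   # position in the sorted order
    return sinds.gather(-1, choice).squeeze(-1)         # map back to vocabulary indices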
Example #5
Source File: main.py    From char_rnn_lm_zh with MIT License
def generate(model, idx2word, word_len=200, temperature=1.0):
    """生成一定数量的文本,temperature结合多项式分布可增添抽样的多样性。"""
    model.eval()
    hidden = model.init_hidden(1)  # batch_size is 1
    inputs = Variable(torch.rand(1, 1).mul(len(idx2word)).long(), volatile=True)  # randomly pick one character as the starting input
    if use_cuda:
        inputs = inputs.cuda()

    word_list = []
    for i in range(word_len):  # generate one character at a time
        output, hidden = model(inputs, hidden)
        word_weights = output.squeeze().data.div(temperature).exp().cpu()

        # Resample from the word weights to add diversity; without this, the most common characters would repeat in an endless loop
        word_idx = torch.multinomial(word_weights, 1)[0]
        inputs.data.fill_(word_idx)  # feed the newly generated character back into inputs
        word = idx2word[word_idx]
        word_list.append(word)
    return word_list 
Example #6
Source File: run_gpt2.py    From squash-generation with MIT License
def sample_sequence(model, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, device='cuda', sample=True):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
        context = torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0).repeat(batch_size, 1)
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = torch.full((batch_size, 1), start_token, device=device, dtype=torch.long)
    prev = context
    output = context
    past = None
    with torch.no_grad():
        for i in trange(length):
            logits, past = model(prev, past=past)
            logits = logits[:, -1, :] / temperature
            logits = top_k_logits(logits, k=top_k)
            log_probs = F.softmax(logits, dim=-1)  # note: softmax (not log_softmax), so these are probabilities despite the name
            if sample:
                prev = torch.multinomial(log_probs, num_samples=1)
            else:
                _, prev = torch.topk(log_probs, k=1, dim=-1)
            output = torch.cat((output, prev), dim=1)
    return output 
Example #7
Source File: heuristics.py    From baal with Apache License 2.0
def _draw_choices(self, probs, n_choices):
        """
        Draw `n_choices` sample from `probs`.

        References:
            Code from https://github.com/BlackHC/BatchBALD/blob/master/src/torch_utils.py#L187

        Returns:
            choices: B... x `n_choices`

        """
        probs = probs.permute(0, 2, 1)
        probs_B_C = probs.reshape((-1, probs.shape[-1]))

        # samples: Ni... x draw_per_xx
        choices = torch.multinomial(probs_B_C,
                                    num_samples=n_choices, replacement=True)

        choices_b_M = choices.reshape(list(probs.shape[:-1]) + [n_choices])
        return choices_b_M.long() 
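Example #7 (after BatchBALD) uses the standard trick for sampling from a batch of categorical distributions with extra leading dimensions: flatten everything except the category axis, call torch.multinomial once, then reshape the draws back. A small sketch of that pattern on a hypothetical probability tensor with the categories last:

import torch

def draw_choices(probs, n_choices):
    # probs: [..., C] -- any number of leading batch dims, categories last
    flat = probs.reshape(-1, probs.shape[-1])  # [prod(batch dims), C]
    choices = torch.multinomial(flat, num_samples=n_choices, replacement=True)
    return choices.reshape(*probs.shape[:-1], n_choices)  # [..., n_choices]

# e.g. 4 MC samples x 16 data points x 10 classes, 5 draws each
probs = torch.softmax(torch.randn(4, 16, 10), dim=-1)
samples = draw_choices(probs, 5)  # shape (4, 16, 5)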
Example #8
Source File: modeling_transfo_xl_utilities.py    From squash-generation with MIT License
def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples 
Example #9
Source File: modeling_transfo_xl_utilities.py    From TextClassify with Apache License 2.0
def sample(self, labels):
        """
            labels: [b1, b2]
        Return
            true_log_probs: [b1, b2]
            samp_log_probs: [n_sample]
            neg_samples: [n_sample]
        """

        # neg_samples = torch.empty(0).long()
        n_sample = self.n_sample
        n_tries = 2 * n_sample

        with torch.no_grad():
            neg_samples = torch.multinomial(self.dist, n_tries, replacement=True).unique()
            device = labels.device
            neg_samples = neg_samples.to(device)
            true_log_probs = self.log_q[labels].to(device)
            samp_log_probs = self.log_q[neg_samples].to(device)
            return true_log_probs, samp_log_probs, neg_samples 
Example #10
Source File: base_model.py    From KBGAN with MIT License
def gen_step(self, src, rel, dst, n_sample=1, temperature=1.0, train=True):
        if not hasattr(self, 'opt'):
            self.opt = Adam(self.mdl.parameters(), weight_decay=self.weight_decay)
        n, m = dst.size()
        rel_var = Variable(rel.cuda())
        src_var = Variable(src.cuda())
        dst_var = Variable(dst.cuda())

        logits = self.mdl.prob_logit(src_var, rel_var, dst_var) / temperature
        probs = nnf.softmax(logits)
        row_idx = torch.arange(0, n).type(torch.LongTensor).unsqueeze(1).expand(n, n_sample)
        sample_idx = torch.multinomial(probs, n_sample, replacement=True)
        sample_srcs = src[row_idx, sample_idx.data.cpu()]
        sample_dsts = dst[row_idx, sample_idx.data.cpu()]
        rewards = yield sample_srcs, sample_dsts
        if train:
            self.mdl.zero_grad()
            log_probs = nnf.log_softmax(logits)
            reinforce_loss = -torch.sum(Variable(rewards) * log_probs[row_idx.cuda(), sample_idx.data])
            reinforce_loss.backward()
            self.opt.step()
            self.mdl.constraint()
        yield None 
Example #11
Source File: testing.py    From funsor with Apache License 2.0
def random_tensor(inputs, output=reals()):
    """
    Creates a random :class:`funsor.tensor.Tensor` with given inputs and output.
    """
    backend = get_backend()
    assert isinstance(inputs, OrderedDict)
    assert isinstance(output, Domain)
    shape = tuple(d.dtype for d in inputs.values()) + output.shape
    if output.dtype == 'real':
        data = randn(shape)
    else:
        num_elements = reduce(operator.mul, shape, 1)
        if backend == "torch":
            import torch

            data = torch.multinomial(torch.ones(output.dtype), num_elements, replacement=True)
        else:
            data = np.random.choice(output.dtype, num_elements, replace=True)
        data = data.reshape(shape)
    return Tensor(data, inputs, output.dtype) 
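Example #11 works around the fact that torch has no direct counterpart of numpy's random.choice: with uniform weights, torch.multinomial(torch.ones(K), n, replacement=True) plays the same role. A short side-by-side sketch (K, n and the weights are made up for illustration):

import numpy as np
import torch

K, n = 10, 7

# numpy: n uniform draws from range(K), with replacement
np_draws = np.random.choice(K, n, replace=True)

# torch: equal weights over K categories give the same uniform sampling
torch_draws = torch.multinomial(torch.ones(K), n, replacement=True)

# weighted variant: pass unnormalized weights instead of ones
weights = torch.tensor([1., 1., 5., 1., 1., 1., 1., 1., 1., 1.])
biased = torch.multinomial(weights, n, replacement=True)  # index 2 is ~5x more likely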
Example #12
Source File: weighted_random_sampler.py    From PyTorch-NLP with BSD 3-Clause "New" or "Revised" License
def __iter__(self):
        if self.num_samples == 0:
            return iter([])

        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist()) 
Example #13
Source File: ShowTellModel.py    From AAT with MIT License
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                else:
                    it = seq[:, i-1].clone()                
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].data.sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt.unsqueeze(0), state)
            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 
Example #14
Source File: dataset.py    From pytorch-planet-amazon with Apache License 2.0
def __iter__(self):
        base_samples = torch.arange(0, len(self.weights)).long()
        remaining = self.num_samples - len(self.weights)
        over_samples = torch.multinomial(self.weights, remaining, True)
        samples = torch.cat((base_samples, over_samples), dim=0)
        print('num samples', len(samples))
        return (samples[i] for i in torch.randperm(len(samples))) 
Example #15
Source File: torch_generator_agent.py    From ParlAI with MIT License
def select_paths(self, logprobs, prior_scores, current_length):
        values, indices = logprobs.topk(self.k, dim=-1)
        probs = torch.softmax(values, dim=-1)
        choices = torch.multinomial(probs, 1)[:, 0]
        hyp_ids = torch.arange(logprobs.size(0)).to(logprobs.device)
        tok_ids = indices[hyp_ids, choices]
        scores = values[hyp_ids, choices]
        best_scores = prior_scores.expand_as(scores) + scores
        return (hyp_ids, tok_ids, best_scores) 
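Example #15 is the top-k counterpart of Example #4: keep the k highest log-probabilities, softmax only those, and let torch.multinomial pick among the survivors. A standalone sketch (the function name and k=10 default are illustrative):

import torch

def topk_sample(logprobs, k=10):
    # logprobs: [batch, vocab]
    values, indices = logprobs.topk(k, dim=-1)        # k best scores per row
    probs = torch.softmax(values, dim=-1)             # renormalize over the k survivors
    choice = torch.multinomial(probs, num_samples=1)  # position within the top-k
    return indices.gather(-1, choice).squeeze(-1)     # map back to vocabulary ids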
Example #16
Source File: pytorch_model.py    From char-rnn-text-generation with MIT License
def sample_from_probs(probs, top_n=10):
    """
    truncated weighted random choice.
    """
    _, indices = torch.sort(probs)
    # set probabilities after top_n to 0
    probs[indices.data[:-top_n]] = 0
    sampled_index = torch.multinomial(probs, 1)
    return sampled_index 
Example #17
Source File: model.py    From reinvent-randomized with MIT License
def sample_smiles(self, num):
        """
        Samples n SMILES from the model.
        :param num: Number of SMILES to sample.
        :return: An iterator with (smiles, likelihood) pairs
        """
        input_vector = torch.full((num, 1), self.vocabulary["^"], dtype=torch.long).cuda()  # (batch, 1)
        seq_lengths = torch.ones(num).cuda()  # (batch)
        sequences = []
        hidden_state = None
        nlls = torch.zeros(num).cuda()
        not_finished = torch.ones(num, 1, dtype=torch.long).cuda()
        for _ in range(self.max_sequence_length - 1):
            logits, hidden_state = self.network(input_vector, seq_lengths, hidden_state)  # (batch, 1, voc)
            probs = logits.softmax(dim=2).squeeze()  # (batch, voc)
            log_probs = logits.log_softmax(dim=2).squeeze()
            input_vector = torch.multinomial(probs, 1)*not_finished  # (batch, 1)
            sequences.append(input_vector)
            nlls += self.nll_loss(log_probs, input_vector.squeeze())
            not_finished = (input_vector > 1).type(torch.long)
            if not_finished.sum() == 0:
                break

        smiles = [self.tokenizer.untokenize(self.vocabulary.decode(seq))
                  for seq in torch.cat(sequences, 1).data.cpu().numpy()]
        nlls = nlls.data.cpu().numpy().tolist()
        return zip(smiles, nlls) 
Example #18
Source File: masked_language_pair_dataset.py    From NLP_Toolkit with Apache License 2.0
def random_word(self, w, pred_probs):
        cands = [self.vocab.mask_index, np.random.randint(self.vocab.nspecial, len(self.vocab)), w]
        prob = torch.multinomial(self.pred_probs, 1, replacement=True)
        return cands[prob] 
Example #19
Source File: composition.py    From torchio with MIT License
def apply_transform(self, sample: Subject):
        weights = torch.Tensor(list(self.transforms_dict.values()))
        index = torch.multinomial(weights, 1)
        transforms = list(self.transforms_dict.keys())
        transform = transforms[index]
        transformed = transform(sample)
        return transformed 
Example #20
Source File: imbalanced.py    From imbalanced-dataset-sampler with MIT License
def __iter__(self):
        return (self.indices[i] for i in torch.multinomial(
            self.weights, self.num_samples, replacement=True)) 
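Examples #12, #20 and #21 all reduce to the same call: torch.multinomial over per-sample weights inside a sampler's __iter__. For class-imbalanced data the weights are typically inverse class frequencies, so rare classes get drawn about as often as common ones. A sketch of building such weights (the label tensor is hypothetical):

import torch

labels = torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 2])  # 9 samples, class 0 dominates
class_counts = torch.bincount(labels).float()        # tensor([6., 2., 1.])
weights = 1.0 / class_counts[labels]                 # per-sample weight = 1 / count(its class)

# draw one rebalanced epoch of indices; replacement lets rare samples repeat
indices = torch.multinomial(weights, num_samples=len(labels), replacement=True)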
Example #21
Source File: sampler_unchange.py    From MetaFGNet with MIT License
def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement)) 
Example #22
Source File: action.py    From bindsnet with GNU Affero General Public License v3.0
def select_softmax(pipeline: EnvironmentPipeline, **kwargs) -> int:
    # language=rst
    """
    Selects an action using softmax function based on spiking from a network layer.

    :param pipeline: EnvironmentPipeline with environment that has an integer action
        space and :code:`spike_record` set.
    :return: Action sampled from softmax over activity of similarly-sized output layer.

    Keyword arguments:

    :param str output: Name of output layer whose activity to base action selection on.
    """
    try:
        output = kwargs["output"]
    except KeyError:
        raise KeyError('select_softmax() requires an "output" layer argument.')

    assert (
        pipeline.network.layers[output].n == pipeline.env.action_space.n
    ), "Output layer size is not equal to the size of the action space."

    assert hasattr(
        pipeline, "spike_record"
    ), "EnvironmentPipeline is missing the attribute: spike_record."

    spikes = torch.sum(pipeline.spike_record[output], dim=0)
    probabilities = torch.softmax(spikes, dim=0)
    return torch.multinomial(probabilities, num_samples=1).item() 
Example #23
Source File: module.py    From lightNLP with Apache License 2.0
def _predict_next_word_sample(self, sentence_list: list):
        # Sample from the distribution so the result is stochastic
        test_item = torch.tensor([[self._word_vocab.stoi[x]] for x in sentence_list], device=DEVICE)
        pred_index = torch.multinomial(torch.softmax(self._model(test_item)[-1], dim=0).cpu().data, 1)
        pred_word = self._word_vocab.itos[pred_index]
        return pred_word 
Example #24
Source File: sampling.py    From lightNLP with Apache License 2.0
def sampling(self, num):
        return torch.multinomial(torch.tensor(self.weighted_list), num).tolist() 
Example #25
Source File: controller.py    From torchsupport with MIT License
def forward(self, input, hidden, prev_attention):
    trace = {
      "incoming": [],
      "operations": [],
      "incoming_logits": [],
      "operations_logits": []
    }
    
    for idx, indices in enumerate(self.incoming):
      history, hidden = self.lstm(input, hidden)
      embedding_out = self.link_embedding_out(hidden[-1])
      prev_attention.append(embedding_out)
      embedding_in = self.link_embedding_in(input)
      sum_attention = torch.tanh(
        embedding_in + torch.cat([self.prev_attention[index] for index in indices], dim=0)
      )
      logits = self.link_attention[idx](sum_attention)
      choice = torch.multinomial(logits.softmax(dim=-1), 1)  # multinomial needs non-negative weights and an explicit num_samples
      trace["incoming"].append(choice)
      trace["incoming_logits"].append(logits)
      input = history[choice]

    for idx, indices in enumerate(self.operations):
      history, hidden = self.lstm(input, hidden)
      logits = self.choice[idx](hidden[-1])
      choice = torch.multinomial(logits.softmax(dim=-1), 1)  # same fix: normalize the logits and sample one index
      trace["operations"].append(choice)
      trace["operations_logits"].append(logits)
      input = self.op_embedding[idx][choice]

    return history, hidden, trace 
Example #26
Source File: model.py    From pytorch-sgns with MIT License
def forward(self, iword, owords):
        batch_size = iword.size()[0]
        context_size = owords.size()[1]
        if self.weights is not None:
            nwords = t.multinomial(self.weights, batch_size * context_size * self.n_negs, replacement=True).view(batch_size, -1)
        else:
            nwords = FT(batch_size, context_size * self.n_negs).uniform_(0, self.vocab_size - 1).long()
        ivectors = self.embedding.forward_i(iword).unsqueeze(2)
        ovectors = self.embedding.forward_o(owords)
        nvectors = self.embedding.forward_o(nwords).neg()
        oloss = t.bmm(ovectors, ivectors).squeeze().sigmoid().log().mean(1)
        nloss = t.bmm(nvectors, ivectors).squeeze().sigmoid().log().view(-1, context_size, self.n_negs).sum(2).mean(1)
        return -(oloss + nloss).mean() 
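Examples #26 and #29 use torch.multinomial to draw negative samples for skip-gram with negative sampling (self.weights is the noise distribution). In word2vec this noise distribution is conventionally the unigram distribution raised to the 3/4 power. A sketch of constructing those weights from word counts (the counts and sizes are hypothetical):

import torch

word_counts = torch.tensor([1000., 500., 120., 30., 5.])  # frequency of each vocabulary word
noise_dist = word_counts.pow(0.75)                         # unigram^0.75, as in word2vec
noise_dist = noise_dist / noise_dist.sum()                 # normalizing is optional for multinomial

batch_size, context_size, n_negs = 4, 3, 5
nwords = torch.multinomial(noise_dist,
                           batch_size * context_size * n_negs,
                           replacement=True).view(batch_size, -1)  # [4, 15]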
Example #27
Source File: net.py    From skorch with BSD 3-Clause "New" or "Revised" License
def sample(self, input, temperature=1., hidden=None):
        hidden = self.module_.init_hidden(1) if hidden is None else hidden
        output, hidden = self.module_(input, hidden)
        probas = output.squeeze().data.div(temperature).exp()
        sample = torch.multinomial(probas, 1)[-1]
        if probas.dim() > 1:
            sample = sample[0]
        return sample, self.repackage_hidden(hidden) 
Example #28
Source File: FCModel.py    From AAT with MIT License
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = []

        for i in range(seq.size(1)):
            if i == 0:
                xt = self.img_embed(fc_feats)
            else:
                if self.training and i >= 2 and self.ss_prob > 0.0: # otherwise no need to sample
                    sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
                    sample_mask = sample_prob < self.ss_prob
                    if sample_mask.sum() == 0:
                        it = seq[:, i-1].clone()
                    else:
                        sample_ind = sample_mask.nonzero().view(-1)
                        it = seq[:, i-1].data.clone()
                        #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                        #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                        prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                        it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
                else:
                    it = seq[:, i-1].clone()
                # break if all the sequences end
                if i >= 2 and seq[:, i-1].sum() == 0:
                    break
                xt = self.embed(it)

            output, state = self.core(xt, state)
            output = F.log_softmax(self.logit(output), dim=1)
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous() 
Example #29
Source File: model.py    From ML_CIA with MIT License
def forward(self, iword, owords):
        batch_size = iword.size()[0]
        context_size = owords.size()[1]
        if self.weights is not None:
            nwords = t.multinomial(self.weights, batch_size * context_size * self.n_negs, replacement=True).view(batch_size, -1)
        else:
            nwords = FT(batch_size, context_size * self.n_negs).uniform_(0, self.vocab_size - 1).long()
        ivectors = self.embedding.forward_i(iword).unsqueeze(2)
        ovectors = self.embedding.forward_o(owords)
        nvectors = self.embedding.forward_o(nwords).neg()
        oloss = t.bmm(ovectors, ivectors).squeeze().sigmoid().log().mean(1)
        nloss = t.bmm(nvectors, ivectors).squeeze().sigmoid().log().view(-1, context_size, self.n_negs).sum(2).mean(1)
        return -(oloss + nloss).mean() 
Example #30
Source File: AttModel.py    From AAT with MIT License
def _forward(self, fc_feats, att_feats, seq, att_masks=None):
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1)

        # Prepare the features
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
        # pp_att_feats is used for attention, we cache it in advance to reduce computation cost

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
                sample_prob = fc_feats.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
                    #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
                    # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
                    prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
            else:
                it = seq[:, i].clone()          
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break

            output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
            outputs[:, i] = output

        return outputs