Python torch.triu() Examples

The following are 30 code examples of torch.triu(). Each example is taken from an open-source project; the original project and source file are named above each snippet. You may also want to check out all available functions and classes of the torch module.
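As a quick orientation before the examples, here is a minimal illustration of what the diagonal argument of torch.triu() does on a small matrix (values are made up):

import torch

x = torch.arange(1, 17, dtype=torch.float32).reshape(4, 4)

torch.triu(x)               # keeps the main diagonal and everything above it
torch.triu(x, diagonal=1)   # keeps only elements strictly above the main diagonal
torch.triu(x, diagonal=-1)  # also keeps the first diagonal below the main one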
Example #1
Source File: basics.py    From heat with MIT License
def triu(m, k=0):
    """
    Returns the upper triangular part of the tensor; the other elements of the result tensor are set to 0.

    The upper triangular part of the tensor is defined as the elements on and above the diagonal.

    The argument k controls which diagonal to consider. If k=0, all elements on and above the main diagonal are
    retained. A positive value excludes just as many diagonals above the main diagonal, and similarly a negative
    value includes just as many diagonals below the main diagonal.

    Parameters
    ----------
    m : ht.DNDarray
        Input tensor for which to compute the upper triangle.
    k : int, optional
        Diagonal below which to zero elements. k=0 (default) is the main diagonal, k<0 is below and k>0 is above.

    Returns
    -------
    upper_triangle : ht.DNDarray
        Upper triangle of the input tensor.
    """
    return __tri_op(m, k, torch.triu) 
Example #2
Source File: list_probability.py    From pt-ranking.github.io with MIT License
def log_ranking_prob_Bradley_Terry(batch_preds):
    '''
    :param batch_preds: [batch_size, list_size]
    :return: [batch_size] log-probability of each list's ordering under the Bradley-Terry model
    '''
    assert 2 == len(batch_preds.size())

    max_v = torch.max(batch_preds)
    new_batch_preds = torch.exp(batch_preds - max_v)

    batch_numerators = torch.unsqueeze(new_batch_preds, dim=2).repeat(1, 1, batch_preds.size(1))

    batch_denominators = torch.unsqueeze(new_batch_preds, dim=2) + torch.unsqueeze(new_batch_preds, dim=1)

    batch_BT_probs = batch_numerators / batch_denominators

    batch_log_ranking_prob = torch.sum(torch.sum(torch.triu(torch.log(batch_BT_probs), diagonal=1), dim=2), dim=1)

    return batch_log_ranking_prob 
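A usage sketch with made-up scores: since torch.triu(..., diagonal=1) keeps only entries with j > i, the double sum accumulates log P(item i beats item j) over every pair listed in order, i.e., the log-probability of the given ordering under the Bradley-Terry model:

import torch

batch_preds = torch.tensor([[3.0, 2.0, 1.0],
                            [0.5, 0.1, 0.9]])   # hypothetical scores: 2 lists of 3
log_probs = log_ranking_prob_Bradley_Terry(batch_preds)  # shape: (2,)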
Example #3
Source File: modeling_bart.py    From exbert with Apache License 2.0
def _prepare_bart_decoder_inputs(
    config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
):
    """Prepare masks that ignore padding tokens  decoder and a causal lm mask for the decoder if
    none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
    """
    pad_token_id = config.pad_token_id
    need_causal_mask = not config.output_past
    if decoder_input_ids is None:
        decoder_input_ids = shift_tokens_right(input_ids, pad_token_id)
    bsz, tgt_len = decoder_input_ids.size()[:2]
    if decoder_attn_mask is None:
        decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
        if need_causal_mask:
            causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
        else:
            causal_lm_mask = None
        new_shape = (bsz, tgt_len, tgt_len)
        # make it broadcastable so can just be added to the attention coefficients
        decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
    assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
    return decoder_input_ids, decoder_attn_mask 
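For illustration, the causal LM mask built above looks like this; fill_with_neg_inf below is a stand-in mirroring the fairseq/transformers helper of the same name:

import torch

def fill_with_neg_inf(t):
    # stand-in for the helper used in the snippet above
    return t.float().fill_(float("-inf")).type_as(t)

tgt_len = 4
causal_lm_mask = torch.triu(fill_with_neg_inf(torch.zeros(tgt_len, tgt_len)), 1)
# tensor([[0., -inf, -inf, -inf],
#         [0.,   0., -inf, -inf],
#         [0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.]])
# added to the attention scores, the -inf entries block attention to future positions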
Example #4
Source File: textual_heads.py    From virtex with MIT License
def _generate_future_mask(
        self, size: int, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        r"""
        Generate a mask for "future" positions, useful when using this module
        for language modeling.

        Parameters
        ----------
        size: int
        """
        # Default mask is for forward direction. Flip for backward direction.
        mask = torch.triu(
            torch.ones(size, size, device=device, dtype=dtype), diagonal=1
        )
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask 
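Pulled out of the class, the same construction can be sketched standalone:

import torch

size = 4
mask = torch.triu(torch.ones(size, size, dtype=torch.float32), diagonal=1)
mask = mask.masked_fill(mask == 1, float("-inf"))
# row t has -inf in every column > t, so position t cannot see the future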
Example #5
Source File: cloze_transformer_model.py    From translate with BSD 3-Clause "New" or "Revised" License
def buffered_future_mask(self, tensor):
        """attend all surounding words except itself
           [[0, -inf, 0]
            [0,  0, -inf]
            [0,  0,   0]]
        The attention map is not the usual triangular mask since we predict y_{t+1} at time-step t
        """
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
            self._future_mask = torch.tril(self._future_mask, 1)
        if self._future_mask.size(0) < dim:
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(self._future_mask.resize_(dim, dim)), 1
            )
            self._future_mask = torch.tril(self._future_mask, 1)
        return self._future_mask[:dim, :dim] 
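The triu/tril combination isolates a single band: triu(..., 1) makes everything strictly above the diagonal -inf, and tril(..., 1) then clears everything above the first superdiagonal, so only positions j == i + 1 stay masked. A minimal sketch:

import torch

dim = 4
m = torch.tril(torch.triu(torch.full((dim, dim), float("-inf")), 1), 1)
# tensor([[0., -inf, 0., 0.],
#         [0., 0., -inf, 0.],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])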
Example #6
Source File: model.py    From sodeep with BSD 3-Clause Clear License
def comp(self, inpu):
        in_mat1 = torch.triu(inpu.repeat(inpu.size(0), 1), diagonal=1)
        in_mat2 = torch.triu(inpu.repeat(inpu.size(0), 1).t(), diagonal=1)

        comp_first = (in_mat1 - in_mat2)
        comp_second = (in_mat2 - in_mat1)

        std1 = torch.std(comp_first).item()
        std2 = torch.std(comp_second).item()

        comp_first = torch.sigmoid(comp_first * (6.8 / std1))
        comp_second = torch.sigmoid(comp_second * (6.8 / std2))

        comp_first = torch.triu(comp_first, diagonal=1)
        comp_second = torch.triu(comp_second, diagonal=1)

        return (torch.sum(comp_first, 1) + torch.sum(comp_second, 0) + 1) / inpu.size(0) 
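The two triu calls build strictly-upper-triangular matrices of pairwise values; for a small vector:

import torch

inpu = torch.tensor([0.2, 0.9, 0.5])
in_mat1 = torch.triu(inpu.repeat(inpu.size(0), 1), diagonal=1)      # [i, j] = inpu[j] for j > i
in_mat2 = torch.triu(inpu.repeat(inpu.size(0), 1).t(), diagonal=1)  # [i, j] = inpu[i] for j > i
# in_mat1 - in_mat2 therefore holds every pairwise difference inpu[j] - inpu[i]
# exactly once, above the diagonal; the sigmoids turn these into soft comparisons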
Example #7
Source File: decoders.py    From meshed-memory-transformer with BSD 3-Clause "New" or "Revised" License
def forward(self, input, encoder_output, mask_encoder):
        # input (b_s, seq_len)
        b_s, seq_len = input.shape[:2]
        mask_queries = (input != self.padding_idx).unsqueeze(-1).float()  # (b_s, seq_len, 1)
        mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8, device=input.device),
                                         diagonal=1)
        mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, seq_len)
        mask_self_attention = mask_self_attention + (input == self.padding_idx).unsqueeze(1).unsqueeze(1).byte()
        mask_self_attention = mask_self_attention.gt(0)  # (b_s, 1, seq_len, seq_len)
        if self._is_stateful:
            self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention], -1)
            mask_self_attention = self.running_mask_self_attention

        seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device)  # (b_s, seq_len)
        seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
        if self._is_stateful:
            self.running_seq.add_(1)
            seq = self.running_seq

        out = self.word_emb(input) + self.pos_emb(seq)
        for i, l in enumerate(self.layers):
            out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)

        out = self.fc(out)
        return F.log_softmax(out, dim=-1) 
Example #8
Source File: model_utils.py    From TVQAplus with MIT License
def find_max_triples(p1, p2, topN=5, prob_thd=None):
    """ Find a list of (k1, k2) where k1 >= k2 with the maximum values of p1[k1] * p2[k2]
    Args:
        p1 (torch.CudaTensor): (N, L) batched start_idx probabilities
        p2 (torch.CudaTensor): (N, L) batched end_idx probabilities
        topN (int): return topN pairs with highest values
        prob_thd (float): minimum confidence; triples scoring below this threshold are dropped
    Returns:
        batched_sorted_triple: N * [(st_idx, ed_idx, confidence), ...]
    """
    product = torch.bmm(p1.unsqueeze(2), p2.unsqueeze(1))  # (N, L, L), end_idx >= start_idx
    upper_product = torch.stack([torch.triu(p) for p in product]
                                ).data.cpu().numpy()  # (N, L, L) the lower part becomes zeros
    batched_sorted_triple = []
    for idx, e in enumerate(upper_product):
        sorted_triple = topN_array_2d(e, topN=topN)
        if prob_thd is not None:
            sorted_triple = [t for t in sorted_triple if t[2] >= prob_thd]
        batched_sorted_triple.append(sorted_triple)
    return batched_sorted_triple 
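The triu call is what enforces the span constraint: in the (N, L, L) product tensor, entry [k1, k2] is p1[k1] * p2[k2], and keeping only the upper triangle discards every pair with end_idx < start_idx. On a single 2x2 slice:

import torch

slice_ = torch.tensor([[0.4, 0.3],
                       [0.2, 0.1]])
torch.triu(slice_)
# tensor([[0.4000, 0.3000],
#         [0.0000, 0.1000]])  # the (start=1, end=0) entry is zeroed out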
Example #9
Source File: made_test.py    From nsf with MIT License
def test_total_mask_random(self):
        features = 10
        hidden_features = 50
        num_blocks = 5
        output_multiplier = 1

        model = made.MADE(
            features=features,
            hidden_features=hidden_features,
            num_blocks=num_blocks,
            output_multiplier=output_multiplier,
            use_residual_blocks=False,
            random_mask=True,
        )
        total_mask = model.initial_layer.mask
        for block in model.blocks:
            self.assertIsInstance(block, made.MaskedFeedforwardBlock)
            total_mask = block.linear.mask @ total_mask
        total_mask = model.final_layer.mask @ total_mask
        total_mask = (total_mask > 0).float()
        self.assertEqual(torch.triu(total_mask), torch.zeros([features, features])) 
Example #10
Source File: xlnet_encoder.py    From texar-pytorch with Apache License 2.0
def _create_causal_attn_mask(self,
                                 seq_len: int,
                                 mem_len: int,
                                 same_length: bool = False) -> torch.Tensor:
        r"""Create causal attention mask of shape
        `(seq_len, mem_len + seq_len)`.
        """
        assert self.r_w_bias is not None
        device = self.r_w_bias.device
        attn_mask = torch.ones(seq_len, seq_len, device=device)
        mask_u = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros(seq_len, mem_len, device=device)
        ret = torch.cat([attn_mask_pad, mask_u], dim=1)
        if same_length:
            mask_l = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :seq_len] + mask_l, ret[:, seq_len:]], 1)
        return ret 
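With small numbers (seq_len=3, mem_len=2, same_length=False), the construction yields:

import torch

seq_len, mem_len = 3, 2
attn_mask = torch.ones(seq_len, seq_len)
mask_u = torch.triu(attn_mask, diagonal=1)
ret = torch.cat([torch.zeros(seq_len, mem_len), mask_u], dim=1)
# tensor([[0., 0., 0., 1., 1.],
#         [0., 0., 0., 0., 1.],
#         [0., 0., 0., 0., 0.]])
# memory positions are always visible; 1 marks a masked (future) position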
Example #11
Source File: Transformer.py    From ConvLab with MIT License
def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask 
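Calling it on a dummy batch shows the shape and the convention (1 marks a future position to be masked):

import torch

seq = torch.zeros(2, 3)  # only the shape (batch=2, length=3) is used
get_subsequent_mask(seq)
# tensor([[[0, 1, 1],
#          [0, 0, 1],
#          [0, 0, 0]],
#         ...], dtype=torch.uint8)  # shape (2, 3, 3)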
Example #12
Source File: utils.py    From Speech-Transformer with MIT License
def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask 
Example #13
Source File: adhoc_metric.py    From pt-ranking.github.io with MIT License
def torch_kendall_tau(sys_ranking, natural_ascending_as_reference = True):
	'''
	$\tau = 1.0 - \frac{2S(\pi, \delta)}{N(N-1)/2}$, cf. 2006-Automatic Evaluation of Information Ordering: Kendall’s Tau
	The tie issue is not considered in this version.
	The current implementation simply counts the number of inversions, then normalizes by n(n-1)/2. The underlying assumption is that the reference ranking is the ideal one, i.e., labels are ordered in descending order.
	:param sys_ranking: the system's ranking, whose entries can be predicted values, labels, etc.
	:return:
	'''
	assert 1 == len(sys_ranking.size()) # one-dimension vector

	ranking_size = sys_ranking.size(0)
	pair_diffs = sys_ranking.view(-1, 1) - sys_ranking.view(1, -1)

	if natural_ascending_as_reference:
		bi_pair_diffs = torch.clamp(pair_diffs, min=0, max=1)
		bi_pair_diffs_triu1 = torch.triu(bi_pair_diffs, diagonal=1)
		#print('bi_pair_diffs_triu1\n', bi_pair_diffs_triu1)

		tau = 1.0 - 4 * torch.sum(bi_pair_diffs_triu1) / (ranking_size*(ranking_size-1))

	else: # i.e., natural descending as the reference
		bi_pair_diffs = torch.clamp(pair_diffs, min=-1, max=0)
		bi_pair_diffs_triu1 = torch.triu(bi_pair_diffs, diagonal=1)
		#print('bi_pair_diffs_triu1\n', bi_pair_diffs_triu1)
		print('total discordant: ', 2*torch.sum(bi_pair_diffs_triu1))

		tau = 1.0 + 4 * torch.sum(bi_pair_diffs_triu1) / (ranking_size*(ranking_size-1))

	return tau 
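A quick check with a made-up ranking: for [3, 1, 2] against a descending reference, two of the three pairs are concordant and one is discordant, so tau should be (2 - 1) / 3 ≈ 0.333:

import torch

sys_ranking = torch.tensor([3., 1., 2.])
tau = torch_kendall_tau(sys_ranking, natural_ascending_as_reference=False)
# only the (1, 2) pair is discordant, so the triu'd sum is -1 and
# tau = 1.0 + 4 * (-1) / (3 * 2) = 0.3333...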
Example #14
Source File: flows.py    From Maximally_Interfered_Retrieval with MIT License
def __init__(self, num_ortho_vecs):

        super(Sylvester, self).__init__()

        self.num_ortho_vecs = num_ortho_vecs

        self.h = nn.Tanh()

        triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0)
        diag_idx = torch.arange(0, num_ortho_vecs).long()

        self.register_buffer('triu_mask', Variable(triu_mask))
        self.triu_mask.requires_grad = False
        self.register_buffer('diag_idx', diag_idx) 
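In Sylvester flows, a buffer like this is typically used to keep the amortized R matrices triangular; a sketch of that masking step (the call site is not shown above, so this is an assumption about usage):

import torch

num_ortho_vecs = 3
triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0)
full_r = torch.randn(5, num_ortho_vecs, num_ortho_vecs)  # batch of unconstrained matrices
r_strict_upper = full_r * triu_mask                      # zeros on and below the diagonal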
Example #15
Source File: multiheadattention.py    From text with BSD 3-Clause "New" or "Revised" License
def generate_square_subsequent_mask(nbatch, sz):
    r"""Generate a square mask for the sequence. The masked positions are filled with True.
        Unmasked positions are filled with False.

    Args:
        nbatch: the batch size
        sz: the size of the square mask
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1).repeat(nbatch, 1, 1)
    return mask 
Example #16
Source File: seq2slate.py    From ReAgent with BSD 3-Clause "New" or "Revised" License
def subsequent_mask(size, device):
    """
    Mask out subsequent positions. Mainly used in the decoding process,
    in which an item should not attend to subsequent items.
    """
    attn_shape = (1, size, size)
    subsequent_mask = (
        1 - torch.triu(torch.ones(*attn_shape, device=device), diagonal=1)
    ).type(torch.int8)
    return subsequent_mask 
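For size=3 the result is the lower-triangular "may attend" pattern:

import torch

subsequent_mask(3, torch.device("cpu"))
# tensor([[[1, 0, 0],
#          [1, 1, 0],
#          [1, 1, 1]]], dtype=torch.int8)  # 1 = allowed to attend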
Example #17
Source File: modeling_drop.py    From MTMSN with Apache License 2.0
def get_best_span(span_start_logits: torch.Tensor, span_end_logits: torch.Tensor) -> torch.Tensor:
    """
    This acts the same as the static method ``BidirectionalAttentionFlow.get_best_span()``
    in ``allennlp/models/reading_comprehension/bidaf.py``. We keep it here so that users can
    directly import this function without the class.

    We call the inputs "logits" - they could either be unnormalized logits or normalized log
    probabilities.  A log_softmax operation is a constant shifting of the entire logit
    vector, so taking an argmax over either one gives the same result.
    """
    if span_start_logits.dim() != 2 or span_end_logits.dim() != 2:
        raise ValueError("Input shapes must be (batch_size, passage_length)")
    batch_size, passage_length = span_start_logits.size()
    device = span_start_logits.device
    # (batch_size, passage_length, passage_length)
    span_log_probs = span_start_logits.unsqueeze(2) + span_end_logits.unsqueeze(1)
    # Only the upper triangle of the span matrix is valid; the lower triangle has entries where
    # the span ends before it starts.
    span_log_mask = torch.triu(torch.ones((passage_length, passage_length),
                                          device=device)).log()
    valid_span_log_probs = span_log_probs + span_log_mask

    # Here we take the span matrix and flatten it, then find the best span using argmax.  We
    # can recover the start and end indices from this flattened list using simple modular
    # arithmetic.
    # (batch_size, passage_length * passage_length)
    best_spans = valid_span_log_probs.view(batch_size, -1).argmax(-1)
    span_start_indices = best_spans // passage_length
    span_end_indices = best_spans % passage_length
    return torch.stack([span_start_indices, span_end_indices], dim=-1) 
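The key step is taking the log of a triu mask: log(1) = 0 leaves valid spans untouched, while log(0) = -inf knocks out spans that end before they start:

import torch

L = 3
torch.triu(torch.ones(L, L)).log()
# tensor([[0., 0., 0.],
#         [-inf, 0., 0.],
#         [-inf, -inf, 0.]])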
Example #18
Source File: layers.py    From MultiTurnDialogZoo with MIT License
def gen_nopeek_mask(length):
    # for transformer masking
    mask = torch.triu(torch.ones(length, length)) == 1
    mask = mask.transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    
    if torch.cuda.is_available():
        mask = mask.cuda()

    return mask
Example #19
Source File: flows.py    From UMNN with BSD 3-Clause "New" or "Revised" License
def __init__(self, num_ortho_vecs):

        super(Sylvester, self).__init__()

        self.num_ortho_vecs = num_ortho_vecs

        self.h = nn.Tanh()

        triu_mask = torch.triu(torch.ones(num_ortho_vecs, num_ortho_vecs), diagonal=1).unsqueeze(0)
        diag_idx = torch.arange(0, num_ortho_vecs).long()

        self.register_buffer('triu_mask', Variable(triu_mask))
        self.triu_mask.requires_grad = False
        self.register_buffer('diag_idx', diag_idx) 
Example #20
Source File: model.py    From examples with BSD 3-Clause "New" or "Revised" License
def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask 
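For sz=4 the generated mask is the standard additive causal mask expected by nn.Transformer:

import torch

sz = 4
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])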
Example #21
Source File: transformer.py    From Guyu with MIT License
def get_mask(size):
        weights = torch.triu(torch.ones((size, size), dtype = torch.bool), 1)
        return weights 
Example #22
Source File: transformer_utils.py    From Count-Sketch-Optimizers with Apache License 2.0
def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim] 
Example #23
Source File: layers_GGT.py    From generative-graph-transformer with MIT License
def generate_mask_sequence(size, device="cuda:0"):
    """
    :param size: seq_len
    :param device: cuda or cpu
    :return: mask with future timesteps zero-valued. shape 1 x size x size
    """
    x = torch.ones((size, size), device=device)
    x = torch.triu(x, diagonal=1)
    return x.unsqueeze(0) == 0  # invert: future positions become False (bool mask)
Example #24
Source File: modeling_transfo_xl.py    From bert_on_stilts with Apache License 2.0
def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0) 
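On a non-square example (h=3, w=5) the two triangular fills carve a diagonal band of ones out of the rectangle, which flip(0) then mirrors to run the other way:

import torch

h, w = 3, 5
mask = torch.ones((h, w)).byte()
m = min(h, w)
mask[:m, :m] = torch.triu(mask[:m, :m])
mask[-m:, -m:] = torch.tril(mask[-m:, -m:])
# tensor([[1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]], dtype=torch.uint8)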
Example #25
Source File: sylvester.py    From ddsp_pytorch with GNU General Public License v3.0
def __init__(self, dim, num_ortho_vecs=16, steps=50, amortized='none'):
        """
        :param zk: shape: (batch_size, z_size)
        :param r1: shape: (batch_size, z_size, z_size)
        :param r2: shape: (batch_size, z_size, z_size)
        :param b: shape: (batch_size, 1, z_size)
        """
        super(TriangularSylvesterFlow, self).__init__(dim, num_ortho_vecs, steps, amortized)
        self.num_ortho_vecs = num_ortho_vecs
        self.diag_idx = torch.arange(0, self.dim).long()
        self.mask = torch.triu(torch.ones(self.dim, self.dim), diagonal=1).unsqueeze(0)
        # Register buffers
        self.register_buffer('idx_d', self.diag_idx) 
Example #26
Source File: common.py    From torch-light with MIT License
def get_subsequent_mask(seq):
    sz_b, len_s = seq.size()
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)

    return subsequent_mask 
Example #27
Source File: mem_transformer.py    From transformer-xl with Apache License 2.0
def _parallelogram_mask(self, h, w, left=False):
        mask = torch.ones((h, w)).byte()
        m = min(h, w)
        mask[:m,:m] = torch.triu(mask[:m,:m])
        mask[-m:,-m:] = torch.tril(mask[-m:,-m:])

        if left:
            return mask
        else:
            return mask.flip(0) 
Example #28
Source File: multihead_attention.py    From inversecooking with MIT License
def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask is None:
            self._mask = torch.triu(fill_with_neg_inf(tensor.new(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim] 
Example #29
Source File: modeling_xlnet.py    From CCF-BDCI-Sentiment-Analysis-Baseline with Apache License 2.0
def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length
            mlen: Mask length

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(next(self.parameters()))
        return ret 
Example #30
Source File: modeling_xlnet.py    From exbert with Apache License 2.0
def create_mask(self, qlen, mlen):
        """
        Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.

        Args:
            qlen: Sequence length
            mlen: Mask length

        ::

                  same_length=False:      same_length=True:
                  <mlen > <  qlen >       <mlen > <  qlen >
               ^ [0 0 0 0 0 1 1 1 1]     [0 0 0 0 0 1 1 1 1]
                 [0 0 0 0 0 0 1 1 1]     [1 0 0 0 0 0 1 1 1]
            qlen [0 0 0 0 0 0 0 1 1]     [1 1 0 0 0 0 0 1 1]
                 [0 0 0 0 0 0 0 0 1]     [1 1 1 0 0 0 0 0 1]
               v [0 0 0 0 0 0 0 0 0]     [1 1 1 1 0 0 0 0 0]

        """
        attn_mask = torch.ones([qlen, qlen])
        mask_up = torch.triu(attn_mask, diagonal=1)
        attn_mask_pad = torch.zeros([qlen, mlen])
        ret = torch.cat([attn_mask_pad, mask_up], dim=1)
        if self.same_length:
            mask_lo = torch.tril(attn_mask, diagonal=-1)
            ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)

        ret = ret.to(next(self.parameters()))
        return ret