Python torch.clamp_min() Examples

The following are 11 code examples of torch.clamp_min(), collected from open-source projects. Each example notes the original project, source file, and license it was taken from. You may also want to check out all other available functions and classes of the torch module.
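
As a quick orientation: torch.clamp_min(input, min) returns a copy of input in which every element smaller than min is replaced by min. It is equivalent to torch.clamp(input, min=min), and an in-place variant Tensor.clamp_min_() exists as well. A minimal sketch:

import torch

x = torch.tensor([-2.0, -0.5, 0.0, 1.5])
print(torch.clamp_min(x, 0.0))  # tensor([0.0000, 0.0000, 0.0000, 1.5000])
print(torch.clamp(x, min=0.0))  # identical result via the public clamp API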
Example #1
Source File: loss.py    From torch-toolbox with BSD 3-Clause "New" or "Revised" License
def forward(self, x, target):
        similarity_matrix = x @ x.T  # need grad here
        label_matrix = target.unsqueeze(1) == target.unsqueeze(0)
        negative_matrix = label_matrix.logical_not()
        positive_matrix = label_matrix.fill_diagonal_(False)  # in-place, so negative_matrix was taken first

        sp = torch.where(positive_matrix, similarity_matrix,
                         torch.zeros_like(similarity_matrix))
        sn = torch.where(negative_matrix, similarity_matrix,
                         torch.zeros_like(similarity_matrix))

        ap = torch.clamp_min(1 + self.m - sp.detach(), min=0.)
        an = torch.clamp_min(sn.detach() + self.m, min=0.)

        logit_p = -self.gamma * ap * (sp - self.dp)
        logit_n = self.gamma * an * (sn - self.dn)

        logit_p = torch.where(positive_matrix, logit_p,
                              torch.zeros_like(logit_p))
        logit_n = torch.where(negative_matrix, logit_n,
                              torch.zeros_like(logit_n))

        loss = F.softplus(torch.logsumexp(logit_p, dim=1) +
                          torch.logsumexp(logit_n, dim=1)).mean()
        return loss 
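
To make the snippet runnable on its own, here is a hypothetical minimal wrapper supplying the attributes the method expects (m, gamma, and the standard Circle Loss thresholds dp = 1 - m, dn = m; the class name and default values below are assumptions for illustration):

import torch
import torch.nn.functional as F

class CircleLossSketch(torch.nn.Module):
    # Hypothetical wrapper: attribute names mirror what forward() above uses.
    def __init__(self, m=0.25, gamma=256.0):
        super().__init__()
        self.m, self.gamma = m, gamma
        self.dp, self.dn = 1.0 - m, m  # similarity thresholds

    forward = forward  # bind the module-level forward defined above

features = F.normalize(torch.randn(8, 64), dim=1)  # rows must be unit-norm
labels = torch.randint(0, 4, (8,))
print(CircleLossSketch()(features, labels))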
Example #2
Source File: loss.py    From SegmenTron with Apache License 2.0
def _aux_forward(self, *inputs, **kwargs):
        *preds, target = tuple(inputs)
        valid_mask = (target != self.ignore_index).long()
        target_one_hot = F.one_hot(torch.clamp_min(target, 0))
        loss = self._base_forward(preds[0], target_one_hot, valid_mask)
        for i in range(1, len(preds)):
            aux_loss = self._base_forward(preds[i], target_one_hot, valid_mask)
            loss += self.aux_weight * aux_loss
        return loss 
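
The clamp_min call here is what makes F.one_hot safe: positions carrying a negative ignore_index would otherwise make the one-hot encoding fail, so they are first clamped up to class 0 and later neutralized through valid_mask. The trick in isolation:

import torch
import torch.nn.functional as F

target = torch.tensor([2, 0, -1, 1])           # -1 marks ignored positions
valid_mask = (target != -1).long()             # 1 where the label is real
one_hot = F.one_hot(torch.clamp_min(target, 0), num_classes=3)
one_hot = one_hot * valid_mask.unsqueeze(-1)   # ignored rows become all-zero
print(one_hot)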
Example #3
Source File: agent.py    From trading-bitcoin-with-reinforcement-learning with MIT License
def train(self):
        s_seq, a_seq, r_seq = zip(*self.history)

        # Concatenate s_seq and cast it into a 2-dim Torch Tensor
        state = torch.from_numpy(np.concatenate(s_seq).reshape(len(s_seq), -1)).float()

        # Concatenate a_seq and cast it into a 2-dim Long Tensor
        action = torch.LongTensor(a_seq).unsqueeze(1)

        # Reverse computation of the Q-value
        q_seq = []
        q = 0
        for r in reversed(r_seq):
            q = r + self.discount_factor * q  # running discounted return
            q_seq.append(q)
        q_seq.reverse()

        # Standardize Q-value to improve the efficiency of gradient descent
        q_seq = np.array(q_seq)
        q_seq -= q_seq.mean()
        q_seq /= (q_seq.std() + 1e-6)
        np.clip(q_seq, -10, 10, out=q_seq)

        # Cast to 2-dim Torch Tensor
        q_val = torch.from_numpy(q_seq).float().unsqueeze(1)

        # Compute loss function: negative of Q * log(Action_Prob)
        a_prob = self.net(state).gather(1, action)
        a_prob = torch.clamp_min(a_prob, 1e-6)  # floor the probability so log() below never hits zero
        loss = -(q_val * torch.log(a_prob)).mean()

        # Gradient ascent on expected return (descent on the negated loss)
        self.net.train()
        self.net.optim.zero_grad()
        loss.backward()
        self.net.optim.step()
        self.net.eval()

        # Clear buffer after training
        del self.history[:] 
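
The clamp before the logarithm is a standard numerical guard: a gathered action probability of exactly zero would produce log(0) = -inf and destroy the gradient step. In isolation:

import torch

p = torch.tensor([0.7, 0.0, 0.3])            # a zero probability can occur
print(torch.log(p))                          # contains -inf
print(torch.log(torch.clamp_min(p, 1e-6)))   # finite, floored near log(1e-6)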
Example #4
Source File: fastspeech.py    From NeMo with Apache License 2.0
def forward(self, encoder_output, encoder_output_mask, target=None, alpha=1.0, mel_max_length=None):
        duration_predictor_output = self.duration_predictor(encoder_output, encoder_output_mask)

        if self.training:
            output, dec_pos = self.get_output(encoder_output, target, alpha, mel_max_length)
        else:
            duration_predictor_output = torch.clamp_min(torch.exp(duration_predictor_output) - 1, 0)

            output, dec_pos = self.get_output(encoder_output, duration_predictor_output, alpha)

        return output, dec_pos, duration_predictor_output 
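
At inference time, exp(.) - 1 followed by clamp_min(., 0) suggests the duration predictor was trained against log(1 + duration) targets; the clamp guarantees non-negative durations even when the predictor undershoots. A sketch of that inversion (the log1p-style encoding is an assumption inferred from the decode step):

import torch

durations = torch.tensor([0.0, 1.0, 4.0])  # frame counts per token
log_durations = torch.log(1 + durations)   # assumed training target
recovered = torch.clamp_min(torch.exp(log_durations) - 1, 0)
print(recovered)                           # tensor([0., 1., 4.])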
Example #5
Source File: pmath.py    From hyperbolic-image-embeddings with MIT License
def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x) 
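
This is the usual projection onto the Poincaré ball of curvature c: the norm is clamped away from zero so the division is safe, and any point whose norm exceeds (1 - 1e-3)/sqrt(c) is pulled back onto that radius. A usage sketch:

import torch

c = 1.0
x = torch.tensor([[0.9, 0.9], [0.1, 0.0]])  # first row lies outside the ball
print(_project(x, c).norm(dim=-1))          # both norms now below 1/sqrt(c)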
Example #6
Source File: pmath.py    From hyperbolic-image-embeddings with MIT License
def _expmap(x, u, c):  # pragma: no cover
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    second_term = (
        tanh(sqrt_c / 2 * _lambda_x(x, c, keepdim=True) * u_norm)
        * u
        / (sqrt_c * u_norm)
    )
    gamma_1 = _mobius_add(x, second_term, c)
    return gamma_1 
Example #7
Source File: pmath.py    From hyperbolic-image-embeddings with MIT License
def _expmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1 
Example #8
Source File: pmath.py    From hyperbolic-image-embeddings with MIT License
def _logmap0(y, c):
    sqrt_c = c ** 0.5
    y_norm = torch.clamp_min(y.norm(dim=-1, p=2, keepdim=True), 1e-5)
    return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm) 
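
Taken together, _expmap0 and _logmap0 are mutually inverse maps between the tangent space at the origin and the Poincaré ball. The snippets rely on tanh and artanh helpers defined elsewhere in pmath.py; the stand-ins below are assumptions mirroring the usual numerically safe versions, and the round trip checks the inverse relationship:

import torch

def tanh(x):  # stand-in helper
    return torch.tanh(x)

def artanh(x):  # stand-in helper, clamped away from +/-1 for safety
    return torch.atanh(torch.clamp(x, -1 + 1e-7, 1 - 1e-7))

c = 1.0
u = 0.1 * torch.randn(4, 8)  # small tangent vectors at the origin
y = _expmap0(u, c)           # tangent space -> ball
print(torch.allclose(_logmap0(y, c), u, atol=1e-5))  # True: round trip holds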
Example #9
Source File: pmath.py    From hyperbolic-image-embeddings with MIT License
def _mobius_matvec(m, x, c):
    x_norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    sqrt_c = c ** 0.5
    mx = x @ m.transpose(-1, -2)
    mx_norm = mx.norm(dim=-1, keepdim=True, p=2)
    res_c = tanh(mx_norm / x_norm * artanh(sqrt_c * x_norm)) * mx / (mx_norm * sqrt_c)
    cond = (mx == 0).all(dim=-1, keepdim=True)  # rows where m @ x vanished; bool mask for torch.where
    res_0 = torch.zeros(1, dtype=res_c.dtype, device=res_c.device)
    res = torch.where(cond, res_0, res_c)
    return _project(res, c) 
Example #10
Source File: circle.py    From catalyst with Apache License 2.0
def forward(self, normed_features: Tensor, labels: Tensor) -> Tensor:
        """

        Args:
            normed_features: batch of sample features with shape
                [bs; feature_len]
            labels: batch of the samples' correct labels with shape [bs]

        Returns:
            torch.Tensor: circle loss
        """
        sp, sn = _convert_label_to_similarity(normed_features, labels)

        ap = torch.clamp_min(-sp.detach() + 1 + self.margin, min=0.0)
        an = torch.clamp_min(sn.detach() + self.margin, min=0.0)

        delta_p = 1 - self.margin
        delta_n = self.margin

        logit_p = -ap * (sp - delta_p) * self.gamma
        logit_n = an * (sn - delta_n) * self.gamma

        loss = self.soft_plus(
            torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0)
        )

        return loss 
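
The helper _convert_label_to_similarity is not shown in this snippet. A sketch consistent with the shapes the forward expects, following the reference Circle Loss formulation, is below; it returns flat vectors of within-class and between-class cosine similarities (treat the exact body as an assumption):

import torch
from torch import Tensor
from typing import Tuple

def _convert_label_to_similarity(normed_features: Tensor,
                                 labels: Tensor) -> Tuple[Tensor, Tensor]:
    # Pairwise cosine similarities of already L2-normalized features.
    similarity = normed_features @ normed_features.T
    label_matrix = labels.unsqueeze(1) == labels.unsqueeze(0)
    # Keep each unordered pair once; the diagonal (self-pairs) is excluded.
    positive = label_matrix.triu(diagonal=1)
    negative = label_matrix.logical_not().triu(diagonal=1)
    return similarity[positive], similarity[negative]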
Example #11
Source File: model.py    From forte with Apache License 2.0
def _compute_soft_head_attention_brute(self, start_ids: torch.LongTensor,
                                           end_ids: torch.LongTensor,
                                           sent_lengths: torch.LongTensor,
                                           states: torch.Tensor,
                                           word_inputs: torch.Tensor) \
            -> Tuple[torch.Tensor, torch.LongTensor]:
        device = start_ids.device
        batch_size, max_len = states.size()[:2]
        num_spans = start_ids.size(1)
        max_span_width = self._hparams.max_span_width
        batch_offset = torch.arange(batch_size, device=device) * max_len
        span_indices = torch.arange(max_span_width, device=device)
        # span_indices: (batch_size, num_spans, max_span_width)
        span_indices = (span_indices.expand(batch_size, num_spans, -1) +
                        start_ids.unsqueeze(-1) + batch_offset.view(-1, 1, 1))
        # valid_spans: (batch_size, num_spans)
        valid_spans = end_ids < sent_lengths.unsqueeze(-1)
        # valid_spans_idx: (total_spans)
        valid_spans_idx = valid_spans.view(-1).nonzero().view(-1)
        # flat_span_indices: (total_spans, max_span_width)
        flat_span_indices = torch.index_select(
            span_indices.view(-1, max_span_width), dim=0, index=valid_spans_idx)

        # flat_sent_lengths: (total_spans)
        flat_sent_lengths = torch.index_select(
            (torch.min(end_ids + 1, sent_lengths.unsqueeze(-1)) +
             batch_offset.unsqueeze(-1)).view(-1),
            dim=0, index=valid_spans_idx)
        # flat_mask: (total_spans, max_span_width)
        flat_mask = flat_span_indices < flat_sent_lengths.unsqueeze(-1)
        flat_span_indices *= flat_mask.type_as(flat_span_indices)

        # span_word_inputs: (total_spans, max_span_width, word_input_dim)
        span_word_inputs = torch.index_select(
            word_inputs.view(-1, word_inputs.size(-1)),
            dim=0, index=flat_span_indices.view(-1)
        ).view(*flat_span_indices.size(), -1)

        # logits: (batch_size, max_len)
        logits = self.head_attention(states).squeeze(-1)
        # flat_span_logits: (total_spans, max_span_width)
        flat_span_logits = torch.index_select(
            logits.view(-1), dim=0, index=flat_span_indices.view(-1)
        ).view(flat_span_indices.size())
        masked_span_logits = (flat_span_logits -
                              1e10 * (~flat_mask).type_as(flat_span_logits))
        weights = torch.softmax(masked_span_logits, dim=-1)

        # weighted_inputs: (total_spans, max_span_width, word_input_dim)
        weighted_inputs = span_word_inputs * weights.unsqueeze(-1)
        # soft_head: (total_spans, word_input_dim)
        soft_head = torch.sum(weighted_inputs, dim=1)
        # indices: (batch_size, num_spans)
        indices = torch.cumsum(valid_spans.view(-1).type(torch.long), dim=0) - 1
        indices = torch.clamp_min(indices, 0).view_as(valid_spans)

        return soft_head, indices
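
The closing cumsum/clamp_min pair is a compact index-building trick: the cumulative sum over the flattened validity mask assigns each valid span its row in the packed soft_head tensor, and clamping at zero turns the -1 produced by any leading invalid spans into a safe (if meaningless) index. In isolation:

import torch

valid = torch.tensor([[False, True, True], [True, False, True]])
indices = torch.cumsum(valid.view(-1).long(), dim=0) - 1
indices = torch.clamp_min(indices, 0).view_as(valid)
print(indices)
# tensor([[0, 0, 1],
#         [2, 2, 3]])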