Python torch.exp() Examples
The following are 30 code examples of torch.exp().
You can vote up the ones you like or vote down the ones you don't like,
and go to the original project or source file by following the links above each example.
You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File: utils.py From pruning_yolov3 with GNU General Public License v3.0 | 8 votes |
def plot_wh_methods():  # from utils.utils import *; plot_wh_methods()
    """Plot three width-height anchor multiplication curves and save them.

    Compares the exponential ("yolo") mapping against the squared and
    2.5-power mappings of a doubled sigmoid, for inputs in [-4, 4) with a
    step of 0.1.  See https://github.com/ultralytics/yolov3/issues/168.

    The figure is written to ``comparison.png`` in the current directory.
    """
    inputs = np.arange(-4.0, 4.0, .1)
    exp_curve = np.exp(inputs)
    doubled_sigmoid = torch.sigmoid(torch.from_numpy(inputs)).numpy() * 2

    fig = plt.figure(figsize=(6, 3), dpi=150)
    curves = (
        (exp_curve, 'yolo method'),
        (doubled_sigmoid ** 2, '^2 power method'),
        (doubled_sigmoid ** 2.5, '^2.5 power method'),
    )
    for values, label in curves:
        plt.plot(inputs, values, '.-', label=label)

    plt.xlim(left=-4, right=4)
    plt.ylim(bottom=0, top=6)
    plt.xlabel('input')
    plt.ylabel('output')
    plt.legend()
    fig.tight_layout()
    fig.savefig('comparison.png', dpi=200)
Example #2
Source File: functional.py From audio with BSD 2-Clause "Simplified" License | 6 votes |
def mu_law_decoding(x_mu: Tensor, quantization_channels: int) -> Tensor:
    r"""Invert mu-law companding.

    Background: `Wikipedia Entry
    <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    The input is expected to hold values between 0 and
    ``quantization_channels - 1``; the result is scaled between -1 and 1.

    Args:
        x_mu (Tensor): Input tensor; non-floating inputs are cast to float.
        quantization_channels (int): Number of channels

    Returns:
        Tensor: Input after mu-law decoding
    """
    if not x_mu.is_floating_point():
        x_mu = x_mu.to(torch.float)
    mu = torch.tensor(quantization_channels - 1.0, dtype=x_mu.dtype)

    # Map [0, mu] onto [-1, 1] before expanding.
    normalized = ((x_mu) / mu) * 2 - 1.0
    expanded = torch.exp(torch.abs(normalized) * torch.log1p(mu)) - 1.0
    return torch.sign(normalized) * expanded / mu
Example #3
Source File: box_utils.py From CSD-SSD with MIT License | 6 votes |
def decode(loc, priors, variances):
    """Turn regression offsets back into corner-form boxes using the priors.

    This undoes the encoding applied for offset regression at train time.

    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """
    prior_centers, prior_sizes = priors[:, :2], priors[:, 2:]

    centers = prior_centers + loc[:, :2] * variances[0] * prior_sizes
    sizes = prior_sizes * torch.exp(loc[:, 2:] * variances[1])

    # Center/size -> (min corner, max corner).
    min_corner = centers - sizes / 2
    max_corner = sizes + min_corner
    return torch.cat((min_corner, max_corner), 1)
Example #4
Source File: adaptive_inference.py From MSDNet-PyTorch with MIT License | 6 votes |
def dynamic_eval_with_threshold(self, logits, targets, flops, T):
    """Evaluate early-exit accuracy and expected cost for given thresholds.

    Each sample exits at the first stage ``k`` whose maximum logit is at
    least ``T[k]``.  A sample that never reaches any threshold is counted
    neither as correct nor towards the expected cost.

    Args:
        logits: tensor unpacked as (n_stage, n_sample, n_class).
        targets: indexable of per-sample labels exposing ``.item()``.
        flops: indexable giving the cost of exiting at each stage.
        T: indexable giving the confidence threshold of each stage.

    Returns:
        ``(accuracy_percent, expected_flops)`` where the accuracy is taken
        over all ``n_sample`` samples.
    """
    n_stage, n_sample, _ = logits.size()
    # take the max logits as confidence
    max_preds, argmax_preds = logits.max(dim=2, keepdim=False)

    exit_counts = torch.zeros(n_stage)
    n_correct = 0
    for i in range(n_sample):
        for k in range(n_stage):
            if max_preds[k][i].item() >= T[k]:  # force to exit at k
                if int(targets[i].item()) == int(argmax_preds[k][i].item()):
                    n_correct += 1
                exit_counts[k] += 1
                break

    # Expected cost = sum over stages of P(exit at stage) * cost(stage).
    expected_flops = 0
    for k in range(n_stage):
        exit_fraction = exit_counts[k] * 1.0 / n_sample
        expected_flops += exit_fraction * flops[k]

    return n_correct * 100.0 / n_sample, expected_flops
Example #5
Source File: adaptive_inference.py From MSDNet-PyTorch with MIT License | 6 votes |
def dynamic_evaluate(model, test_loader, val_loader, args):
    """Sweep exit distributions and log test accuracy against expected FLOPs.

    Validation and test logits are loaded from ``logits_single.pth`` under
    ``args.save`` when that file exists; otherwise they are computed with
    ``Tester.calc_logit`` and saved there.  For each ``p`` in 1..39 an exit
    distribution proportional to ``(p / 20) ** k`` for ``k = 1..args.nBlocks``
    is built, thresholds are found on the validation split, and the
    resulting test accuracy and expected FLOPs are printed and appended as a
    tab-separated line to ``dynamic.txt`` under ``args.save``.

    Args:
        model: passed to ``Tester``.
        test_loader: loader handed to ``Tester.calc_logit`` for the test split.
        val_loader: loader handed to ``Tester.calc_logit`` for validation.
        args: object providing ``save`` (directory) and ``nBlocks``.
    """
    tester = Tester(model, args)

    # NOTE: torch.load unpickles the file; only point args.save at trusted data.
    logits_path = os.path.join(args.save, 'logits_single.pth')
    if os.path.exists(logits_path):
        val_pred, val_target, test_pred, test_target = torch.load(logits_path)
    else:
        val_pred, val_target = tester.calc_logit(val_loader)
        test_pred, test_target = tester.calc_logit(test_loader)
        torch.save((val_pred, val_target, test_pred, test_target), logits_path)

    flops = torch.load(os.path.join(args.save, 'flops.pth'))

    # torch.range is deprecated; arange with an exclusive end of nBlocks + 1
    # yields the same 1..nBlocks float sequence.
    exponents = torch.arange(1, args.nBlocks + 1, dtype=torch.float)

    with open(os.path.join(args.save, 'dynamic.txt'), 'w') as fout:
        for p in range(1, 40):
            print("*********************")
            _p = torch.FloatTensor(1).fill_(p * 1.0 / 20)
            probs = torch.exp(torch.log(_p) * exponents)
            probs /= probs.sum()
            acc_val, _, T = tester.dynamic_eval_find_threshold(
                val_pred, val_target, probs, flops)
            acc_test, exp_flops = tester.dynamic_eval_with_threshold(
                test_pred, test_target, flops, T)
            print('valid acc: {:.3f}, test acc: {:.3f}, test flops: {:.2f}M'.format(acc_val, acc_test, exp_flops / 1e6))
            fout.write('{}\t{}\n'.format(acc_test, exp_flops.item()))
Example #6
Source File: loss.py From overhaul-distillation with MIT License | 6 votes |
def FocalLoss(self, logit, target, gamma=2, alpha=0.5):
    """Focal loss computed from a reduced cross-entropy value.

    The cross-entropy criterion is configured from ``self.weight``,
    ``self.ignore_index`` and ``self.size_average``, and moved to CUDA when
    ``self.cuda`` is truthy.  When ``self.batch_average`` is truthy the
    result is divided by the batch size (first dimension of ``logit``).

    Args:
        logit: 4-D tensor, unpacked as (n, c, h, w).
        target: label tensor; cast to long before use.
        gamma: exponent applied to ``1 - pt``.
        alpha: scale applied to the log-probability; skipped when ``None``.
    """
    batch_size, _, _, _ = logit.size()

    cross_entropy = nn.CrossEntropyLoss(weight=self.weight,
                                        ignore_index=self.ignore_index,
                                        size_average=self.size_average)
    if self.cuda:
        cross_entropy = cross_entropy.cuda()

    log_pt = -cross_entropy(logit, target.long())
    # pt is taken before alpha is folded into the log term.
    pt = torch.exp(log_pt)
    if alpha is not None:
        log_pt = log_pt * alpha

    focal = -((1 - pt) ** gamma) * log_pt
    if self.batch_average:
        focal = focal / batch_size
    return focal
Example #7
Source File: trainer.py From treelstm.pytorch with MIT License | 6 votes |
def test(self, dataset):
    """Evaluate the model on every item of ``dataset`` without gradients.

    For each item the criterion value is accumulated, and the prediction is
    the dot product of the class indices ``1..num_classes`` with
    ``exp(output)``.

    Returns:
        ``(mean_loss, predictions)`` where ``predictions`` is a CPU float
        tensor with one entry per dataset item.
    """
    self.model.eval()
    with torch.no_grad():
        n_items = len(dataset)
        loss_sum = 0.0
        predictions = torch.zeros(n_items, dtype=torch.float, device='cpu')
        class_ids = torch.arange(1, dataset.num_classes + 1, dtype=torch.float, device='cpu')

        for idx in tqdm(range(n_items), desc='Testing epoch ' + str(self.epoch) + ''):
            ltree, linput, rtree, rinput, label = dataset[idx]
            target = utils.map_label_to_target(label, dataset.num_classes).to(self.device)

            output = self.model(ltree, linput.to(self.device), rtree, rinput.to(self.device))
            loss_sum += self.criterion(output, target).item()

            scores = output.squeeze().to('cpu')
            predictions[idx] = torch.dot(class_ids, torch.exp(scores))

    return loss_sum / n_items, predictions
Example #8
Source File: mmd.py From transferlearning with MIT License | 6 votes |
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of Gaussian kernels over all pairs of rows from source and target.

    Args:
        source: 2-D tensor of samples (rows).
        target: 2-D tensor of samples (rows) with the same width.
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of kernels summed.
        fix_sigma: base bandwidth; when falsy, the mean pairwise squared
            distance (excluding the diagonal count) is used instead.

    Returns:
        Square tensor of side ``len(source) + len(target)`` holding the
        summed kernel values.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)

    count, width = int(total.size(0)), int(total.size(1))
    as_rows = total.unsqueeze(0).expand(count, count, width)
    as_cols = total.unsqueeze(1).expand(count, count, width)
    sq_dist = ((as_rows - as_cols) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(sq_dist.data) / (n_samples ** 2 - n_samples)
    # Center the geometric bandwidth ladder on the base bandwidth.
    bandwidth /= kernel_mul ** (kernel_num // 2)

    kernel_total = 0
    for i in range(kernel_num):
        kernel_total = kernel_total + torch.exp(-sq_dist / (bandwidth * (kernel_mul ** i)))
    return kernel_total
Example #9
Source File: mmd_pytorch.py From transferlearning with MIT License | 6 votes |
def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-bandwidth Gaussian kernel matrix for the stacked samples.

    ``source`` and ``target`` are concatenated row-wise; the squared
    Euclidean distance between every pair of rows is fed to ``kernel_num``
    Gaussian kernels whose bandwidths form a geometric sequence with ratio
    ``kernel_mul``.  The base bandwidth is ``fix_sigma`` when it is truthy,
    otherwise it is estimated from the summed pairwise distances.

    Returns:
        The element-wise sum of the individual kernel matrices.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    stacked = torch.cat([source, target], dim=0)

    grid_shape = (int(stacked.size(0)), int(stacked.size(0)), int(stacked.size(1)))
    left = stacked.unsqueeze(0).expand(*grid_shape)
    right = stacked.unsqueeze(1).expand(*grid_shape)
    L2_distance = ((left - right) ** 2).sum(2)

    if not fix_sigma:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    else:
        bandwidth = fix_sigma
    bandwidth /= kernel_mul ** (kernel_num // 2)

    # One kernel per rung of the bandwidth ladder.
    ladder = [bandwidth * (kernel_mul ** step) for step in range(kernel_num)]
    return sum([torch.exp(-L2_distance / rung) for rung in ladder])
Example #10
Source File: tsd_net.py From ConvLab with MIT License | 6 votes |
def forward(self, z_enc_out, u_enc_out, u_input_np, m_t_input, degree_input, last_hidden, z_input_np):
    """Run one decoding step combining generation and copy scores.

    The GRU input concatenates the embedded token, both attention contexts
    and ``degree_input``.  Copy scores over ``z_enc_out`` are projected onto
    the vocabulary through the sparse selection matrix using a max-shifted
    exp/log for numerical stability, then normalised jointly with the
    generation scores by a softmax.

    Returns:
        ``(proba, last_hidden, gru_out)``.
    """
    sparse_z_input = Variable(self.get_sparse_selective_input(z_input_np), requires_grad=False)

    m_embed = self.emb(m_t_input)
    z_context = self.attn_z(last_hidden, z_enc_out)
    u_context = self.attn_u(last_hidden, u_enc_out)
    gru_in = torch.cat([m_embed, u_context, z_context, degree_input.unsqueeze(0)], dim=2)
    gru_out, last_hidden = self.gru(gru_in, last_hidden)

    gen_score = self.proj(torch.cat([z_context, u_context, gru_out], 2)).squeeze(0)

    copy_raw = F.tanh(self.proj_copy2(z_enc_out.transpose(0, 1)))
    copy_raw = torch.matmul(copy_raw, gru_out.squeeze(0).unsqueeze(2)).squeeze(2)
    copy_raw = copy_raw.cpu()
    copy_max = torch.max(copy_raw, dim=1, keepdim=True)[0]
    copy_exp = torch.exp(copy_raw - copy_max)  # [B,T]
    copy_vocab = torch.log(torch.bmm(copy_exp.unsqueeze(1), sparse_z_input)).squeeze(1) + copy_max  # [B,V]
    copy_vocab = cuda_(copy_vocab)

    scores = F.softmax(torch.cat([gen_score, copy_vocab], dim=1), dim=1)
    gen_part = scores[:, :cfg.vocab_size]
    copy_part = scores[:, cfg.vocab_size:]

    in_vocab = gen_part + copy_part[:, :cfg.vocab_size]  # [B,V]
    proba = torch.cat([in_vocab, copy_part[:, cfg.vocab_size:]], 1)
    return proba, last_hidden, gru_out
Example #11
Source File: model_utils.py From medicaldetectiontoolkit with Apache License 2.0 | 6 votes |
def apply_box_deltas_2D(boxes, deltas):
    """Applies the given deltas to the given boxes.

    All arithmetic is out-of-place, so the function is safe to
    differentiate through: the earlier in-place updates of ``height`` and
    ``width`` invalidated tensors autograd had saved and made ``backward()``
    raise whenever ``deltas`` required gradients.

    boxes: [N, 4] where each row is y1, x1, y2, x2
    deltas: [N, 4] where each row is [dy, dx, log(dh), log(dw)]

    Returns: [N, 4] tensor where each row is y1, x1, y2, x2 of the
    refined box.
    """
    # Convert to y, x, h, w
    height = boxes[:, 2] - boxes[:, 0]
    width = boxes[:, 3] - boxes[:, 1]
    center_y = boxes[:, 0] + 0.5 * height
    center_x = boxes[:, 1] + 0.5 * width

    # Apply deltas
    new_center_y = center_y + deltas[:, 0] * height
    new_center_x = center_x + deltas[:, 1] * width
    new_height = height * torch.exp(deltas[:, 2])
    new_width = width * torch.exp(deltas[:, 3])

    # Convert back to y1, x1, y2, x2
    y1 = new_center_y - 0.5 * new_height
    x1 = new_center_x - 0.5 * new_width
    y2 = y1 + new_height
    x2 = x1 + new_width
    return torch.stack([y1, x1, y2, x2], dim=1)
Example #12
Source File: bbox_transform.py From Collaborative-Learning-for-Weakly-Supervised-Object-Detection with MIT License | 5 votes |
def bbox_transform_inv(boxes, deltas):
    """Apply regression deltas to boxes and return corner coordinates.

    Input should be both tensor or both Variable and on the same device.
    ``deltas`` is read in groups of four columns (dx, dy, dw, dh); box
    widths and heights use the inclusive ``+ 1.0`` pixel convention.

    Returns:
        Tensor with one row per box and the same column layout as
        ``deltas``, holding (x1, y1, x2, y2) per group.  For empty
        ``boxes`` a zeroed, detached copy of ``deltas`` is returned.
    """
    if len(boxes) == 0:
        return deltas.detach() * 0

    widths = boxes[:, 2] - boxes[:, 0] + 1.0
    heights = boxes[:, 3] - boxes[:, 1] + 1.0
    ctr_x = boxes[:, 0] + 0.5 * widths
    ctr_y = boxes[:, 1] + 0.5 * heights

    # Column vectors so they broadcast across the per-class delta groups.
    w_col, h_col = widths.unsqueeze(1), heights.unsqueeze(1)

    new_ctr_x = deltas[:, 0::4] * w_col + ctr_x.unsqueeze(1)
    new_ctr_y = deltas[:, 1::4] * h_col + ctr_y.unsqueeze(1)
    new_w = torch.exp(deltas[:, 2::4]) * w_col
    new_h = torch.exp(deltas[:, 3::4]) * h_col

    corners = (
        new_ctr_x - 0.5 * new_w,
        new_ctr_y - 0.5 * new_h,
        new_ctr_x + 0.5 * new_w,
        new_ctr_y + 0.5 * new_h,
    )
    # Interleave the four coordinates back into groups of four columns.
    return torch.stack(corners, dim=2).view(len(boxes), -1)
Example #13
Source File: models_task.py From ConvLab with MIT License | 5 votes |
def gaussian_logprob(self, mu, logvar, sample_z):
    """Element-wise log-density of ``sample_z`` under N(mu, exp(logvar))."""
    log_norm = float(-0.5 * np.log(2*np.pi))
    squared_error = th.pow((mu-sample_z), 2)
    return log_norm - 0.5 * logvar - squared_error / (2.0*th.exp(logvar))
Example #14
Source File: utils.py From conv-social-pooling with MIT License | 5 votes |
def outputActivation(x):
    """Activate the five channels on the last axis of a 3-D tensor.

    Channels 0-1 (means) pass through unchanged, channels 2-3 go through
    ``exp`` and channel 4 through ``tanh``.  The result has the same five
    channels in the same order.
    """
    mu_x, mu_y, raw_sig_x, raw_sig_y, raw_rho = (x[:, :, c:c + 1] for c in range(5))
    activated = [
        mu_x,
        mu_y,
        torch.exp(raw_sig_x),
        torch.exp(raw_sig_y),
        torch.tanh(raw_rho),
    ]
    return torch.cat(activated, dim=2)

## Batchwise NLL loss, uses mask for variable output lengths
Example #15
Source File: encoder.py From pytorch_sac_ae with MIT License | 5 votes |
def reparameterize(self, mu, logstd):
    """Draw ``mu + eps * exp(logstd)`` with ``eps`` sampled from N(0, 1)."""
    scale = torch.exp(logstd)
    noise = torch.randn_like(scale)
    return mu + noise * scale
Example #16
Source File: models_task.py From ConvLab with MIT License | 5 votes |
def forward_rl(self, data_feed, max_words, temp=0.1):
    """Sample a latent from the prior and decode with the RL decoder path.

    Args:
        data_feed: mapping read for the keys ``'context_lens'``,
            ``'contexts'``, ``'bs'`` and ``'db'``.
        max_words: forwarded to ``self.decoder.forward_rl``.
        temp: forwarded to ``self.decoder.forward_rl``.  It used to be
            ignored in favour of a hard-coded ``0.1``; the default is the
            same value, so callers relying on the default are unaffected.

    Returns:
        ``(logprobs, outs, joint_logpz, sample_z)`` where ``joint_logpz`` is
        the per-row sum of the latent sample's log-probabilities.
    """
    ctx_lens = data_feed['context_lens']  # (batch_size, )
    short_ctx_utts = self.np2var(self.extract_short_ctx(data_feed['contexts'], ctx_lens), LONG)
    bs_label = self.np2var(data_feed['bs'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
    db_label = self.np2var(data_feed['db'], FLOAT)  # (batch_size, max_ctx_len, max_utt_len)
    batch_size = len(ctx_lens)

    utt_summary, _, enc_outs = self.utt_encoder(short_ctx_utts.unsqueeze(1))

    # create decoder initial states
    enc_last = th.cat([bs_label, db_label, utt_summary.squeeze(1)], dim=1)

    # create decoder initial states
    p_mu, p_logvar = self.c2z(enc_last)
    sample_z = th.normal(p_mu, th.sqrt(th.exp(p_logvar))).detach()
    # NOTE(review): the log-probability is evaluated with self.zero as the
    # log-variance rather than p_logvar; kept as-is, confirm it is intended.
    logprob_sample_z = self.gaussian_logprob(p_mu, self.zero, sample_z)
    joint_logpz = th.sum(logprob_sample_z, dim=1)

    # pack attention context
    dec_init_state = self.z_embedding(sample_z.unsqueeze(0))
    attn_context = None

    # decode
    if self.config.dec_rnn_cell == 'lstm':
        dec_init_state = tuple([dec_init_state, dec_init_state])

    # decode
    logprobs, outs = self.decoder.forward_rl(batch_size=batch_size,
                                             dec_init_state=dec_init_state,
                                             attn_context=attn_context,
                                             vocab=self.vocab,
                                             max_words=max_words,
                                             temp=temp)
    return logprobs, outs, joint_logpz, sample_z
Example #17
Source File: criterions.py From ConvLab with MIT License | 5 votes |
def forward(self, log_qy, batch_size=None, unit_average=False):
    """Entropy ``-qy log(qy)`` summed over dim 1.

    Inputs with more than two dimensions are squeezed first.  With
    ``unit_average`` the mean over rows is returned; otherwise the row sum
    is divided by ``batch_size``, which must then be supplied.
    """
    if log_qy.dim() > 2:
        log_qy = log_qy.squeeze()

    row_entropy = th.sum(-1 * log_qy * th.exp(log_qy), dim=1)
    if not unit_average:
        return th.sum(row_entropy) / batch_size
    return th.mean(row_entropy)
Example #18
Source File: criterions.py From ConvLab with MIT License | 5 votes |
def forward(self, log_qy, log_py, batch_size=None, unit_average=False):
    """KL term ``qy * log(q(y)/p(y))`` summed over dim 1.

    With ``unit_average`` the mean over rows is returned; otherwise the row
    sum is divided by ``batch_size``, which must then be supplied.
    """
    log_ratio = log_qy - log_py
    row_kl = th.sum(th.exp(log_qy) * log_ratio, dim=1)
    if not unit_average:
        return th.sum(row_kl)/batch_size
    return th.mean(row_kl)
Example #19
Source File: utils.py From conv-social-pooling with MIT License | 5 votes |
def logsumexp(inputs, dim=None, keepdim=False):
    """Max-shifted log-sum-exp along ``dim``.

    When ``dim`` is ``None`` the input is flattened and reduced over all
    elements.  The reduced dimension is kept only if ``keepdim`` is truthy.
    """
    if dim is None:
        inputs, dim = inputs.view(-1), 0

    # Subtract the maximum before exponentiating.
    peak = torch.max(inputs, dim=dim, keepdim=True)[0]
    shifted_total = (inputs - peak).exp().sum(dim=dim, keepdim=True)
    result = peak + shifted_total.log()
    return result if keepdim else result.squeeze(dim)
Example #20
Source File: utils.py From conv-social-pooling with MIT License | 5 votes |
def maskedMSETest(y_pred, y_gt, mask):
    """Masked squared error summed along dim 1, plus the mask counts.

    The squared distance between channels 0-1 of ``y_pred`` and ``y_gt`` is
    written into the first two channels of a buffer shaped like ``mask``
    and multiplied by ``mask``.

    Returns:
        ``(lossVal, counts)``: channel 0 of the masked error and channel 0
        of ``mask``, each summed over dim 1.
    """
    err_x = y_gt[:, :, 0] - y_pred[:, :, 0]
    err_y = y_gt[:, :, 1] - y_pred[:, :, 1]
    squared = torch.pow(err_x, 2) + torch.pow(err_y, 2)

    masked = torch.zeros_like(mask)
    for channel in (0, 1):
        masked[:, :, channel] = squared
    masked = masked * mask

    return torch.sum(masked[:, :, 0], dim=1), torch.sum(mask[:, :, 0], dim=1)

## Helper function for log sum exp calculation:
Example #21
Source File: criterions.py From ConvLab with MIT License | 5 votes |
def forward(self, recog_mu, recog_logvar, prior_mu, prior_logvar):
    """KL divergence between two diagonal Gaussians, averaged over rows.

    The per-element terms are reduced over dim 1 with a mean when
    ``self.unit_average`` is truthy and a sum otherwise; the result is the
    mean of those per-row values.
    """
    prior_var = th.exp(prior_logvar)

    terms = 1.0 + (recog_logvar - prior_logvar)
    terms.sub_(th.div(th.pow(prior_mu - recog_mu, 2), prior_var))
    terms.sub_(th.div(th.exp(recog_logvar), prior_var))

    reduce_rows = th.mean if self.unit_average else th.sum
    per_row_kl = -0.5 * reduce_rows(terms, dim=1)
    return th.mean(per_row_kl)
Example #22
Source File: utils.py From pixel-cnn-pp with MIT License | 5 votes |
def log_sum_exp(x):
    """ numerically stable log_sum_exp over the last axis (TF ordering) """
    last = len(x.size()) - 1
    # The kept-dim maximum broadcasts for the shift; the reduced one is
    # added back to the already-reduced sum.
    peak_keep = torch.max(x, dim=last, keepdim=True)[0]
    peak = torch.max(x, dim=last)[0]
    shifted_total = torch.sum(torch.exp(x - peak_keep), dim=last)
    return peak + torch.log(shifted_total)
Example #23
Source File: logistic_mixture.py From L3C-PyTorch with GNU General Public License v3.0 | 5 votes |
def log_softmax(logit_probs, dim):
    """ numerically stable log_softmax along ``dim`` that prevents overflow """
    peak = torch.max(logit_probs, dim=dim, keepdim=True)[0]
    shifted = logit_probs - peak
    log_norm = torch.log(torch.sum(torch.exp(shifted), dim=dim, keepdim=True))
    return logit_probs - peak - log_norm
Example #24
Source File: functional.py From audio with BSD 2-Clause "Simplified" License | 5 votes |
def bass_biquad(
        waveform: Tensor,
        sample_rate: int,
        gain: float,
        central_freq: float = 100,
        Q: float = 0.707
) -> Tensor:
    r"""Bass tone-control effect built on ``biquad``.  Similar to SoX implementation.

    Args:
        waveform (Tensor): audio waveform of dimension of `(..., time)`
        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
        gain (float): desired gain at the boost (or attenuation) in dB.
        central_freq (float, optional): central frequency (in Hz). (Default: ``100``)
        Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).

    Returns:
        Tensor: Waveform of dimension of `(..., time)`

    References:
        http://sox.sourceforge.net/sox.html
        https://www.w3.org/2011/audio/audio-eq-cookbook.html#APF
    """
    omega = 2 * math.pi * central_freq / sample_rate
    alpha = math.sin(omega) / 2 / Q
    amp = math.exp(gain / 40 * math.log(10))

    slope_term = 2 * math.sqrt(amp) * alpha
    minus_term = (amp - 1) * math.cos(omega)
    plus_term = (amp + 1) * math.cos(omega)

    b0 = amp * ((amp + 1) - minus_term + slope_term)
    b1 = 2 * amp * ((amp - 1) - plus_term)
    b2 = amp * ((amp + 1) - minus_term - slope_term)
    a0 = (amp + 1) + minus_term + slope_term
    a1 = -2 * ((amp - 1) + plus_term)
    a2 = (amp + 1) + minus_term - slope_term

    # Every coefficient, a0 included, is normalised by a0.
    normalized = [coeff / a0 for coeff in (b0, b1, b2, a0, a1, a2)]
    return biquad(waveform, *normalized)
Example #25
Source File: ada_lanczos_net.py From LanczosNetwork with MIT License | 5 votes |
def _get_graph_laplacian(self, node_feat, adj_mask):
    """ Build a masked, degree-normalised Gaussian affinity matrix.

      Args:
        node_feat: float tensor, shape B X N X D
        adj_mask: float tensor, shape B X N X N, binary mask,
                  should contain self-loop

      Returns:
        L: float tensor, shape B X N X N
    """
    batch_size, num_node = node_feat.shape[0], node_feat.shape[1]

    # Enumerate all N^2 ordered node pairs: the first index cycles fastest,
    # the second index is constant within each run of N pairs.
    node_ids = torch.arange(num_node, device=node_feat.device)
    idx_row = node_ids.repeat(num_node)
    idx_col = node_ids.repeat_interleave(num_node)

    # compute pairwise distance
    diff = node_feat[:, idx_row, :] - node_feat[:, idx_col, :]  # shape B X N^2 X D
    dist2 = (diff * diff).sum(dim=2)  # shape B X N^2

    # The mean is used as the kernel scale (a median can be 0).
    sigma2 = torch.mean(dist2, dim=1, keepdim=True)
    affinity = torch.exp(-dist2 / sigma2)  # shape B X N^2
    affinity = affinity.reshape(batch_size, num_node, num_node) * adj_mask  # shape B X N X N

    row_sum = torch.sum(affinity, dim=2, keepdim=True)
    # Rows with zero mass get a padding of 1 so the inverse stays finite.
    pad_row_sum = (row_sum == 0.0).to(row_sum.dtype)

    alpha = 0.5
    inv_degree = 1.0 / (row_sum + pad_row_sum).pow(alpha)  # shape B X N X 1
    return inv_degree * affinity * inv_degree.transpose(1, 2)  # shape B X N X N
Example #26
Source File: position_encoding.py From TVQAplus with MIT License | 5 votes |
def __init__(self, n_filters=128, max_len=500):
    """
    Precompute a sinusoidal table and store it as the buffer ``pe``.

    Even columns hold sines and odd columns hold cosines.  Odd values of
    ``n_filters`` are supported: the table then has one fewer cosine
    column than sine columns.

    :param n_filters: same with input hidden size
    :param max_len: maximum sequence length
    """
    super(PositionEncoding, self).__init__()
    # Compute the positional encodings once in log space.
    pe = torch.zeros(max_len, n_filters)  # (L, D)
    position = torch.arange(0, max_len).float().unsqueeze(1)
    div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
    pe[:, 0::2] = torch.sin(position * div_term)
    # There are only n_filters // 2 odd columns; trimming div_term avoids a
    # shape mismatch when n_filters is odd and is a no-op when it is even.
    pe[:, 1::2] = torch.cos(position * div_term[:n_filters // 2])
    self.register_buffer('pe', pe)  # buffer is a tensor, not a variable, (L, D)
Example #27
Source File: logistic_mixture.py From L3C-PyTorch with GNU General Public License v3.0 | 5 votes |
def log_prob_from_logits(logit_probs):
    """ numerically stable log_softmax over dim 1 that prevents overflow """
    # logit_probs is NKHW
    peak = torch.max(logit_probs, dim=1, keepdim=True)[0]
    shifted = logit_probs - peak
    log_norm = torch.log(torch.sum(torch.exp(shifted), dim=1, keepdim=True))
    return logit_probs - peak - log_norm

# TODO(pytorch): replace with pytorch internal in 1.0, there is a bug in 0.4.1
Example #28
Source File: utils.py From ICDAR-2019-SROIE with MIT License | 5 votes |
def gcxgcy_to_cxcy(gcxgcy, priors_cxcy):
    """
    Decode model-predicted box offsets into center-size coordinates.

    Center offsets are scaled by the prior size divided by 10 and shifted by
    the prior center; size offsets are divided by 5, exponentiated and
    multiplied by the prior size.

    :param gcxgcy: encoded bounding boxes, i.e. output of the model, a tensor of size (n_priors, 4)
    :param priors_cxcy: prior boxes with respect to which the encoding is defined, a tensor of size (n_priors, 4)
    :return: decoded bounding boxes in center-size form, a tensor of size (n_priors, 4)
    """
    prior_centers, prior_sizes = priors_cxcy[:, :2], priors_cxcy[:, 2:]

    centers = gcxgcy[:, :2] * prior_sizes / 10 + prior_centers  # c_x, c_y
    sizes = torch.exp(gcxgcy[:, 2:] / 5) * prior_sizes  # w, h
    return torch.cat([centers, sizes], 1)
Example #29
Source File: torchac.py From L3C-PyTorch with GNU General Public License v3.0 | 5 votes |
def _get_C_cur(targets, means_c, log_scales_c):  # NKHWL
    """
    Logistic CDF of ``targets`` for the given means and log-scales.

    :param targets: Lp floats
    :param means_c: NKHW
    :param log_scales_c: NKHW
    :return: sigmoid((targets - mean) * exp(-log_scale)), with the targets
        broadcast along a new trailing axis
    """
    # NKHWL'
    centered = targets - means_c.unsqueeze(-1)
    # NKHW1
    inverse_scale = torch.exp(-log_scales_c).unsqueeze(-1)
    # NKHWL'
    return torch.sigmoid(centered.mul(inverse_scale))  # sigma' * (x - mu)
Example #30
Source File: logistic_mixture.py From L3C-PyTorch with GNU General Public License v3.0 | 5 votes |
def log_sum_exp(log_probs, dim):
    """ numerically stable log_sum_exp implementation that prevents overflow

    WARNING: ``log_probs`` is overwritten in place (shifted by its maximum
    and exponentiated); callers must not reuse it afterwards.
    """
    peak_keep = torch.max(log_probs, dim=dim, keepdim=True)[0]
    peak = peak_keep.squeeze(dim)

    # == peak + torch.log(torch.sum(torch.exp(log_probs - peak_keep), dim=dim)),
    # done in place on the input buffer.
    log_probs.sub_(peak_keep)
    log_probs.exp_()
    reduced = log_probs.sum(dim=dim)
    reduced.log_()
    return reduced.add(peak)