Python torch.sum() Examples
The following are 30 code examples of torch.sum(), extracted from open source projects. Each example lists its source project, author, file, and license. You may also want to check out the other available functions and classes of the torch module.
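As a quick orientation before the project examples, here is a minimal standalone sketch of torch.sum(): it can reduce a whole tensor to a scalar, or reduce along a chosen dimension, optionally keeping that dimension.

import torch

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])                 # shape (2, 3)

total = torch.sum(x)                             # tensor(21.), sum of all elements
row_sums = torch.sum(x, dim=1)                   # tensor([ 6., 15.]), reduce the last dim
col_sums = torch.sum(x, dim=0, keepdim=True)     # tensor([[5., 7., 9.]]), keep the reduced dim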

Example #1
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    latent = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    latent = self.sigmoid(latent)
    # memory = self.position_layer(memory)
    latent = torch.sum(latent, dim=1)  # (batch_size, d_model)
    # latent = self.memory2latent(memory)  # (batch_size, max_src_seq, latent_size)
    # latent = self.memory2latent(memory)
    # memory = self.latent2memory(latent)  # (batch_size, max_src_seq, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
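In this example, torch.sum(latent, dim=1) pools the per-token encoder outputs into one latent vector per sentence before decoding. A standalone sketch of just that pooling step, with made-up shapes rather than the project's configuration:

import torch

batch_size, max_src_seq, d_model = 4, 16, 256
encoder_out = torch.sigmoid(torch.randn(batch_size, max_src_seq, d_model))

latent = torch.sum(encoder_out, dim=1)     # (batch_size, d_model): one vector per sequence
decoder_memory = latent.unsqueeze(1)       # (batch_size, 1, d_model), as passed to decode()
print(latent.shape, decoder_memory.shape)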
Example #2
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model2.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    memory = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    # attented_mem=self.attention(memory,memory,memory,src_mask)
    # memory=attented_mem
    score = self.attention(memory, memory, src_mask)
    attent_memory = score.bmm(memory)
    # memory=self.linear(torch.cat([memory,attent_memory],dim=-1))
    memory, _ = self.gru(attent_memory)  # original referenced the undefined name 'attented_mem'; attent_memory appears intended
    '''
    score=torch.sigmoid(self.linear(memory))
    memory=memory*score
    '''
    latent = torch.sum(memory, dim=1)  # (batch_size, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    # logit,_=self.gru_decoder(logit)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
Example #3
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    latent = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    latent = self.sigmoid(latent)
    # memory = self.position_layer(memory)
    latent = torch.sum(latent, dim=1)  # (batch_size, d_model)
    # latent = self.memory2latent(memory)  # (batch_size, max_src_seq, latent_size)
    # latent = self.memory2latent(memory)
    # memory = self.latent2memory(latent)  # (batch_size, max_src_seq, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
Example #4
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model2.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    memory = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    # attented_mem=self.attention(memory,memory,memory,src_mask)
    # memory=attented_mem
    score = self.attention(memory, memory, src_mask)
    attent_memory = score.bmm(memory)
    # memory=self.linear(torch.cat([memory,attent_memory],dim=-1))
    memory, _ = self.gru(attent_memory)  # original referenced the undefined name 'attented_mem'; attent_memory appears intended
    '''
    score=torch.sigmoid(self.linear(memory))
    memory=memory*score
    '''
    latent = torch.sum(memory, dim=1)  # (batch_size, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    # logit,_=self.gru_decoder(logit)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
Example #5
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    latent = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    latent = self.sigmoid(latent)
    # memory = self.position_layer(memory)
    latent = torch.sum(latent, dim=1)  # (batch_size, d_model)
    # latent = self.memory2latent(memory)  # (batch_size, max_src_seq, latent_size)
    # latent = self.memory2latent(memory)
    # memory = self.latent2memory(latent)  # (batch_size, max_src_seq, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
Example #6
Source Project: hgraph2graph Author: wengong-jin File: hgnn.py License: MIT License
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, beta):
    x_tensors = make_cuda(x_tensors)
    y_tensors = make_cuda(y_tensors)

    x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
    _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

    diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
    diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
    diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
    diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
    kl_div = tree_kl + graph_kl

    diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
    diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
    x_tree_vecs = self.W_tree(torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1))
    x_graph_vecs = self.W_graph(torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1))

    loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
    return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc
Example #7
Source Project: hgraph2graph Author: wengong-jin File: hgnn.py License: MIT License
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, beta):
    x_tensors = make_cuda(x_tensors)
    y_tensors = make_cuda(y_tensors)

    x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
    _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

    diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
    diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
    diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
    diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
    kl_div = tree_kl + graph_kl

    diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
    diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
    x_tree_vecs = self.W_tree(torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1))
    x_graph_vecs = self.W_graph(torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1))

    loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
    return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc
Example #8
Source Project: nmp_qc Author: priba File: demo_letter_duvenaud.py License: MIT License
def plot_examples(data_loader, model, epoch, plotter, ind=[0, 10, 20]):
    # switch to evaluate mode
    model.eval()

    for i, (g, h, e, target) in enumerate(data_loader):
        if i in ind:
            subfolder_path = 'batch_' + str(i) + '_t_' + str(int(target[0][0])) + '/epoch_' + str(epoch) + '/'
            if not os.path.isdir(args.plotPath + subfolder_path):
                os.makedirs(args.plotPath + subfolder_path)

            num_nodes = torch.sum(torch.sum(torch.abs(h[0, :, :]), 1) > 0)
            am = g[0, 0:num_nodes, 0:num_nodes].numpy()
            pos = h[0, 0:num_nodes, :].numpy()

            plotter.plot_graph(am, position=pos, fig_name=subfolder_path + str(i) + '_input.png')

            # Prepare input data
            if args.cuda:
                g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
            g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)

            # Compute output
            model(g, h, e, lambda cls, id: plotter.plot_graph(am, position=pos, cls=cls, fig_name=subfolder_path + id))
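The nested torch.sum call above counts the non-padded nodes of the first graph in the batch: it sums the absolute node features per row and then counts how many rows are non-zero. A small illustration with invented features:

import torch

h = torch.zeros(1, 5, 3)              # one graph, padded to 5 nodes, 3 features each
h[0, :2] = torch.rand(2, 3) + 0.1     # only the first 2 nodes carry real features

row_energy = torch.sum(torch.abs(h[0, :, :]), 1)   # per-node sum of absolute features
num_nodes = torch.sum(row_energy > 0)              # tensor(2)
print(int(num_nodes))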
Example #9
Source Project: deep-learning-note Author: wdxtub File: 49_word2vec.py License: MIT License
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]
            pred = skip_gram(center, context_negative, net[0], net[1])
            # use the mask variable to keep padded positions from affecting the loss
            l = (loss(pred.view(label.shape), label, mask) *
                 mask.shape[1] / mask.float().sum(dim=1)).mean()  # average loss for one batch
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs' % (epoch + 1, l_sum / n, time.time() - start))
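The mask arithmetic in the loss above divides each example's summed loss by mask.float().sum(dim=1), the number of real (non-padded) tokens, so padding does not change the average. A simplified sketch of just that step, not the project's skip-gram loss:

import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])              # 1 = real token, 0 = padding
per_token_loss = torch.rand(2, 4) * mask         # padded positions contribute zero loss

valid_tokens = mask.float().sum(dim=1)           # tensor([3., 2.])
per_example = per_token_loss.sum(dim=1) / valid_tokens
batch_loss = per_example.mean()
print(batch_loss)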
Example #10
Source Project: PolarSeg Author: edwardzhou130 File: lovasz_losses.py License: BSD 3-Clause "New" or "Revised" License
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)    # mean across images if per_image
    return 100 * iou
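The two .sum() calls above count boolean mask elements: the number of pixels where both prediction and label are foreground (intersection) and where either is (union). A minimal numeric illustration, leaving out the ignore-label term:

import torch

pred = torch.tensor([[1, 1, 0, 0]])
label = torch.tensor([[1, 0, 1, 0]])

intersection = ((label == 1) & (pred == 1)).sum()   # tensor(1)
union = ((label == 1) | (pred == 1)).sum()          # tensor(3)
iou = float(intersection) / float(union)
print(100 * iou)                                    # ~33.3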
Example #11
Source Project: PolarSeg Author: edwardzhou130 File: lovasz_losses.py License: BSD 3-Clause "New" or "Revised" License
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []
        for i in range(C):
            if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES ---------------------------
Example #12
Source Project: PolarSeg Author: edwardzhou130 File: lovasz_losses.py License: BSD 3-Clause "New" or "Revised" License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss
Example #13
Source Project: dogTorch Author: ehsanik File: metrics.py License: MIT License
def final_report(self):
    correct_preds = self.confusion[:, :, range(self.args.num_classes),
                                   range(self.args.num_classes)]
    correct_percentage = correct_preds / (self.confusion.sum(3) + 1e-6) * 100
    balance_accuracy = correct_percentage.mean()
    per_sequence_element_accuracy = correct_percentage.view(
        correct_percentage.size(0), -1).mean(1)
    per_sequence_report = ', '.join(
        '{:.2f}'.format(acc) for acc in per_sequence_element_accuracy)
    report = ('Accuracy {meter.avg[0]:.2f} Balanced {balanced:.2f} '
              'PerSeq [{per_seq}]').format(meter=self.meter,
                                           balanced=balance_accuracy,
                                           per_seq=per_sequence_report)
    report += ' Accuracy Matrix (seq x imu x label): {}'.format(
        correct_percentage)
    return report
Example #14
Source Project: dogTorch Author: ehsanik File: metrics.py License: MIT License
def record_output(self, output, output_indices, target, prev_absolutes,
                  next_absolutes, batch_size=1):
    assert output.dim() == 4
    assert target.dim() == 3
    _, predictions = output.max(3)

    # Compute per class accuracy for unbalanced data.
    sequence_length = output.size(1)
    num_label = output.size(2)
    num_class = output.size(3)

    correct_alljoint = (target == predictions).float().sum(2)
    sum_of_corrects = correct_alljoint.sum(1)
    max_value = num_label * sequence_length
    count_correct = (sum_of_corrects == max_value).float().mean()
    correct_per_seq = ((correct_alljoint == num_label - 1).sum(1).float() /
                       sequence_length).mean()
    self.meter.update(
        torch.Tensor([count_correct * 100, correct_per_seq * 100]), batch_size)
Example #15
Source Project: ACAN Author: miraiaroha File: losses.py License: MIT License
def forward(self, Q, P):
    """
    Parameters
    ----------
    P: ground truth probability distribution [batch_size, n, n]
    Q: predicted probability distribution [batch_size, n, n]

    Description
    -----------
    Compute the KL divergence of attention maps. Here P and Q denote
    the pixel-level attention map with n spatial positions.
    """
    kl_loss = P * safe_log(P / Q)
    pixel_loss = torch.sum(kl_loss, dim=-1)
    total_loss = torch.mean(pixel_loss)
    return total_loss
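Here torch.sum(kl_loss, dim=-1) adds up the pointwise KL contributions for each spatial position, and torch.mean averages them into a scalar loss. A standalone sketch with randomly generated attention maps; safe_log is a project helper not shown on this page, so a plausible stand-in is defined below:

import torch

def safe_log(x, eps=1e-8):
    # stand-in for the project's helper: log with clamping to avoid log(0)
    return torch.log(x.clamp(min=eps))

batch_size, n = 2, 4
P = torch.softmax(torch.randn(batch_size, n, n), dim=-1)   # "ground truth" attention map
Q = torch.softmax(torch.randn(batch_size, n, n), dim=-1)   # predicted attention map

kl_loss = P * safe_log(P / Q)                  # elementwise KL terms
pixel_loss = torch.sum(kl_loss, dim=-1)        # KL per spatial position, shape (batch_size, n)
total_loss = torch.mean(pixel_loss)
print(total_loss)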
Example #16
Source Project: ACAN Author: miraiaroha File: losses.py License: MIT License
def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
    """
    Parameters
    ----------
    ignore_index : Specifies a target value that is ignored
                   and does not contribute to the input gradient.
    reduction : Specifies the reduction to apply to the output: 'mean' | 'sum'.
                'mean': elementwise mean,
                'sum': class dim will be summed and batch dim will be averaged.
    use_weights : whether to use weights of classes.
    weight : Tensor, optional
             a manual rescaling weight given to each class.
             If given, has to be a Tensor of size "nclasses".
    """
    super(_BaseEntropyLoss2d, self).__init__()
    self.ignore_index = ignore_index
    self.reduction = reduction
    self.use_weights = use_weights
    if use_weights:
        print("w/ class balance")
        print(weight)
        self.weight = torch.FloatTensor(weight).cuda()
    else:
        print("w/o class balance")
        self.weight = None
Example #17
Source Project: cascade-rcnn_Pytorch Author: guoruoqian File: gridgen.py License: MIT License
def forward(self, input1):
    self.batchgrid3d = torch.zeros(torch.Size([input1.size(0)]) + self.grid3d.size())

    for i in range(input1.size(0)):
        self.batchgrid3d[i] = self.grid3d

    self.batchgrid3d = Variable(self.batchgrid3d)
    # print(self.batchgrid3d)

    x = torch.sum(torch.mul(self.batchgrid3d, input1[:, :, :, 0:4]), 3)
    y = torch.sum(torch.mul(self.batchgrid3d, input1[:, :, :, 4:8]), 3)
    z = torch.sum(torch.mul(self.batchgrid3d, input1[:, :, :, 8:]), 3)
    # print(x)
    r = torch.sqrt(x**2 + y**2 + z**2) + 1e-5
    # print(r)
    theta = torch.acos(z / r) / (np.pi / 2) - 1
    # phi = torch.atan(y/x)
    phi = torch.atan(y / (x + 1e-5)) + np.pi * x.lt(0).type(torch.FloatTensor) * (
        y.ge(0).type(torch.FloatTensor) - y.lt(0).type(torch.FloatTensor))
    phi = phi / np.pi

    output = torch.cat([theta, phi], 3)
    return output
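Each torch.sum(torch.mul(...), 3) above is an elementwise product followed by a reduction over the last axis, i.e. a batched dot product that projects the 3D grid through a slice of the input weights. A small check of that equivalence, with arbitrary illustrative shapes rather than the project's:

import torch

grid = torch.randn(2, 8, 8, 4)       # e.g. homogeneous grid coordinates
weights = torch.randn(2, 8, 8, 4)    # per-position projection coefficients

x = torch.sum(torch.mul(grid, weights), 3)   # (2, 8, 8)
x_alt = (grid * weights).sum(dim=3)          # same reduction, written with operators
print(torch.allclose(x, x_alt))              # True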
Example #18
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model2.py License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
    """Take in and process masked src and target sequences."""
    memory = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
    # attented_mem=self.attention(memory,memory,memory,src_mask)
    # memory=attented_mem
    score = self.attention(memory, memory, src_mask)
    attent_memory = score.bmm(memory)
    # memory=self.linear(torch.cat([memory,attent_memory],dim=-1))
    memory, _ = self.gru(attent_memory)  # original referenced the undefined name 'attented_mem'; attent_memory appears intended
    '''
    score=torch.sigmoid(self.linear(memory))
    memory=memory*score
    '''
    latent = torch.sum(memory, dim=1)  # (batch_size, d_model)
    logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
    # logit,_=self.gru_decoder(logit)
    prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
    return latent, prob
Example #19
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model2.py License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)
    if trg is not None:
        self.trg = trg[:, :-1]
        self.trg_y = trg[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        self.ntokens = (self.trg_y != pad).data.sum()
Example #20
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model.py License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)
    if trg is not None:
        self.trg = trg[:, :-1]
        self.trg_y = trg[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        self.ntokens = (self.trg_y != pad).data.sum()
Example #21
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model.py License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)
    if trg is not None:
        self.trg = trg[:, :-1]
        self.trg_y = trg[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        self.ntokens = (self.trg_y != pad).data.sum()
Example #22
Source Project: controllable-text-attribute-transfer Author: Nrgeup File: model2.py License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
    self.src = src
    self.src_mask = (src != pad).unsqueeze(-2)
    if trg is not None:
        self.trg = trg[:, :-1]
        self.trg_y = trg[:, 1:]
        self.trg_mask = self.make_std_mask(self.trg, pad)
        self.ntokens = (self.trg_y != pad).data.sum()
Example #23
Source Project: DDPAE-video-prediction Author: jthsieh File: DDPAE.py License: MIT License
def get_output(self, components, latent):
    '''
    Take the sum of the components.
    '''
    # components: batch_size x n_frames_total x total_components x C x H x W
    batch_size = components.size(0)
    # Sum the components
    output = torch.sum(components, dim=2)
    output = torch.clamp(output, max=1)
    return output
Example #24
Source Project: hgraph2graph Author: wengong-jin File: nnutils.py License: MIT License
def avg_pool(all_vecs, scope, dim):
    size = create_var(torch.Tensor([le for _, le in scope]))
    return all_vecs.sum(dim=dim) / size.unsqueeze(-1)
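avg_pool turns a sum over the node dimension into a mean by dividing by each graph's true length, taken from scope. The sketch below reproduces only that arithmetic (create_var and scope are project-specific helpers), and zeroes out padded slots so the division really yields a per-graph mean:

import torch

all_vecs = torch.randn(3, 5, 8)              # 3 graphs, padded to 5 nodes, 8-dim vectors
lengths = torch.tensor([2., 5., 3.])         # true number of nodes per graph

# zero out padded node slots so the sum only covers real nodes
node_mask = (torch.arange(5).unsqueeze(0) < lengths.unsqueeze(1)).float()
all_vecs = all_vecs * node_mask.unsqueeze(-1)

pooled = all_vecs.sum(dim=1) / lengths.unsqueeze(-1)   # (3, 8) mean vector per graph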
Example #25
Source Project: hgraph2graph Author: wengong-jin File: nnutils.py License: MIT License
def get_accuracy_bin(scores, labels):
    preds = torch.ge(scores, 0).long()
    acc = torch.eq(preds, labels).float()
    return torch.sum(acc) / labels.nelement()
Example #26
Source Project: hgraph2graph Author: wengong-jin File: nnutils.py License: MIT License
def get_accuracy(scores, labels):
    _, preds = torch.max(scores, dim=-1)
    acc = torch.eq(preds, labels).float()
    return torch.sum(acc) / labels.nelement()
Example #27
Source Project: hgraph2graph Author: wengong-jin File: nnutils.py License: MIT License
def get_accuracy_sym(scores, labels):
    max_scores, max_idx = torch.max(scores, dim=-1)
    lab_scores = scores[torch.arange(len(scores)), labels]
    acc = torch.eq(lab_scores, max_scores).float()
    return torch.sum(acc) / labels.nelement()
Example #28
Source Project: hgraph2graph Author: wengong-jin File: hgnn.py License: MIT License
def rsample(self, z_vecs, W_mean, W_var, perturb=True):
    batch_size = z_vecs.size(0)
    z_mean = W_mean(z_vecs)
    z_log_var = -torch.abs(W_var(z_vecs))
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
    epsilon = torch.randn_like(z_mean).cuda()
    z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon if perturb else z_mean
    return z_vecs, kl_loss
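The torch.sum in rsample reduces the per-dimension KL terms of a diagonal Gaussian against a standard normal over every element, then divides by the batch size; the last line is the usual reparameterization trick. A CPU-only standalone sketch of the same computation, with made-up shapes:

import torch

batch_size, latent_dim = 4, 8
z_mean = torch.randn(batch_size, latent_dim)
z_log_var = -torch.abs(torch.randn(batch_size, latent_dim))   # kept <= 0, as in the example

# KL( N(mean, var) || N(0, I) ), summed over all elements, averaged over the batch
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size

# reparameterization: sample while keeping gradients through mean and variance
epsilon = torch.randn_like(z_mean)
z = z_mean + torch.exp(z_log_var / 2) * epsilon
print(kl_loss.item(), z.shape)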
Example #29
Source Project: hgraph2graph Author: wengong-jin File: hgnn.py License: MIT License
def rsample(self, z_vecs, W_mean, W_var):
    batch_size = z_vecs.size(0)
    z_mean = W_mean(z_vecs)
    z_log_var = -torch.abs(W_var(z_vecs))
    kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
    epsilon = torch.randn_like(z_mean).cuda()
    z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
    return z_vecs, kl_loss
Example #30
Source Project: hgraph2graph Author: wengong-jin File: hgnn.py License: MIT License
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, cond, beta):
    x_tensors = make_cuda(x_tensors)
    y_tensors = make_cuda(y_tensors)
    cond = torch.tensor(cond).float().cuda()

    x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
    _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

    diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
    diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
    diff_tree_vecs = self.U_tree(torch.cat([diff_tree_vecs, cond], dim=-1))     # combine condition for posterior
    diff_graph_vecs = self.U_graph(torch.cat([diff_graph_vecs, cond], dim=-1))  # combine condition for posterior
    diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
    diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
    kl_div = tree_kl + graph_kl

    diff_tree_vecs = torch.cat([diff_tree_vecs, cond], dim=-1)    # combine condition for posterior
    diff_graph_vecs = torch.cat([diff_graph_vecs, cond], dim=-1)  # combine condition for posterior
    diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
    diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
    x_tree_vecs = self.W_tree(torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1))
    x_graph_vecs = self.W_graph(torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1))

    loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
    return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc