Python torch.sum() Examples

The following are 23 code examples of torch.sum(). These examples are extracted from open source projects. The source project, author, file, and license are noted above each example. You may also want to check out all available functions and classes of the torch module.
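Before the project examples, here is a minimal illustration of torch.sum() itself, covering the dim and keepdim arguments that the examples below rely on:

import torch

x = torch.arange(6.).reshape(2, 3)       # tensor([[0., 1., 2.], [3., 4., 5.]])
torch.sum(x)                             # tensor(15.) -- reduce over all elements
torch.sum(x, dim=0)                      # tensor([3., 5., 7.]) -- collapse the rows
torch.sum(x, dim=1)                      # tensor([ 3., 12.]) -- collapse the columns
torch.sum(x, dim=1, keepdim=True)        # tensor([[ 3.], [12.]]) -- keep the reduced dim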
Example #1
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: model.py    License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Take in and process masked src and target sequences.
        """
        latent = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
        latent = self.sigmoid(latent)
        # memory = self.position_layer(memory)

        latent = torch.sum(latent, dim=1)  # (batch_size, d_model)

        # latent = self.memory2latent(memory)  # (batch_size, max_src_seq, latent_size)

        # latent = self.memory2latent(memory)
        # memory = self.latent2memory(latent)  # (batch_size, max_src_seq, d_model)

        logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
        prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
        return latent, prob 
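In this example, torch.sum(latent, dim=1) pools the per-token encoder states into a single sentence-level vector, which unsqueeze(1) then reshapes for the decoder. A toy sketch of the shape changes (hypothetical sizes, not from the project):

import torch

batch_size, max_src_seq, d_model = 4, 10, 256      # hypothetical sizes
latent = torch.rand(batch_size, max_src_seq, d_model)
pooled = torch.sum(latent, dim=1)                  # (4, 256): one vector per sequence
decoder_in = pooled.unsqueeze(1)                   # (4, 1, 256): sequence length of 1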
Example #2
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: model2.py    License: Apache License 2.0
def forward(self, src, tgt, src_mask, tgt_mask):
        """
        Take in and process masked src and target sequences.
        """
        memory = self.encode(src, src_mask)  # (batch_size, max_src_seq, d_model)
        # attented_mem=self.attention(memory,memory,memory,src_mask)
        # memory=attented_mem
        score = self.attention(memory, memory, src_mask)
        attent_memory = score.bmm(memory)
        # memory=self.linear(torch.cat([memory,attent_memory],dim=-1))

        memory, _ = self.gru(attent_memory)
        '''
        score=torch.sigmoid(self.linear(memory))
        memory=memory*score
        '''
        latent = torch.sum(memory, dim=1)  # (batch_size, d_model)
        logit = self.decode(latent.unsqueeze(1), tgt, tgt_mask)  # (batch_size, max_tgt_seq, d_model)
        # logit,_=self.gru_decoder(logit)
        prob = self.generator(logit)  # (batch_size, max_seq, vocab_size)
        return latent, prob 
Example #3
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, beta):
        x_tensors = make_cuda(x_tensors)
        y_tensors = make_cuda(y_tensors)
        x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
        _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
        kl_div = tree_kl + graph_kl

        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
        x_tree_vecs = self.W_tree( torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1) )
        x_graph_vecs = self.W_graph( torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1) )

        loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
        return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc 
Example #4
Source Project: nmp_qc   Author: priba   File: demo_letter_duvenaud.py    License: MIT License
def plot_examples(data_loader, model, epoch, plotter, ind = [0, 10, 20]):

    # switch to evaluate mode
    model.eval()

    for i, (g, h, e, target) in enumerate(data_loader):
        if i in ind:
            subfolder_path = 'batch_' + str(i) + '_t_' + str(int(target[0][0])) + '/epoch_' + str(epoch) + '/'
            if not os.path.isdir(args.plotPath + subfolder_path):
                os.makedirs(args.plotPath + subfolder_path)

            num_nodes = torch.sum(torch.sum(torch.abs(h[0, :, :]), 1) > 0)
            am = g[0, 0:num_nodes, 0:num_nodes].numpy()
            pos = h[0, 0:num_nodes, :].numpy()

            plotter.plot_graph(am, position=pos, fig_name=subfolder_path+str(i) + '_input.png')

            # Prepare input data
            if args.cuda:
                g, h, e, target = g.cuda(), h.cuda(), e.cuda(), target.cuda()
            g, h, e, target = Variable(g), Variable(h), Variable(e), Variable(target)

            # Compute output
            model(g, h, e, lambda cls, id: plotter.plot_graph(am, position=pos, cls=cls,
                                                          fig_name=subfolder_path+ id)) 
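The nested sum in this example is a counting idiom: the inner torch.sum collapses each node's feature row, and the outer torch.sum counts the rows with nonzero mass, since summing a boolean mask counts its True entries. A small sketch with toy data:

import torch

h = torch.tensor([[1., 0.], [0., 2.], [0., 0.]])   # 3 node slots; the last is padding
row_mass = torch.sum(torch.abs(h), 1)              # tensor([1., 2., 0.])
num_nodes = torch.sum(row_mass > 0)                # tensor(2): padding rows excluded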
Example #5
Source Project: deep-learning-note   Author: wdxtub   File: 49_word2vec.py    License: MIT License
def train(net, lr, num_epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("train on", device)
    net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    for epoch in range(num_epochs):
        start, l_sum, n = time.time(), 0.0, 0
        for batch in data_iter:
            center, context_negative, mask, label = [d.to(device) for d in batch]

            pred = skip_gram(center, context_negative, net[0], net[1])

            # use the mask variable to keep padding from affecting the loss computation
            l = (loss(pred.view(label.shape), label, mask) *
                 mask.shape[1] / mask.float().sum(dim=1)).mean()  # average loss over the batch
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            l_sum += l.cpu().item()
            n += 1
        print('epoch %d, loss %.2f, time %.2fs'
              % (epoch + 1, l_sum / n, time.time() - start)) 
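The mask.float().sum(dim=1) term rescales each row's loss by its number of real (non-padding) tokens. Assuming the loss function returns a per-row mean over all positions, multiplying by mask.shape[1] and dividing by the per-row mask sum converts that into a mean over valid positions only. A stripped-down sketch with toy values:

import torch

per_token_loss = torch.tensor([[0.5, 0.3, 0.2], [0.4, 0.6, 0.9]])
mask = torch.tensor([[1., 1., 0.], [1., 1., 1.]])            # 0 marks padding
per_row = (per_token_loss * mask).mean(dim=1)                # mean over all positions
per_row = per_row * mask.shape[1] / mask.float().sum(dim=1)  # mean over valid positions only
loss = per_row.mean()                                        # batch average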
Example #6
Source Project: PolarSeg   Author: edwardzhou130   File: lovasz_losses.py    License: BSD 3-Clause "New" or "Revised" License
def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
    """
    IoU for foreground class
    binary: 1 foreground, 0 background
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        intersection = ((label == 1) & (pred == 1)).sum()
        union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
        if not union:
            iou = EMPTY
        else:
            iou = float(intersection) / float(union)
        ious.append(iou)
    iou = mean(ious)    # mean across images if per_image
    return 100 * iou 
Example #7
Source Project: PolarSeg   Author: edwardzhou130   File: lovasz_losses.py    License: BSD 3-Clause "New" or "Revised" License
def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
    """
    Array of IoU for each (non ignored) class
    """
    if not per_image:
        preds, labels = (preds,), (labels,)
    ious = []
    for pred, label in zip(preds, labels):
        iou = []    
        for i in range(C):
            if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes)
                intersection = ((label == i) & (pred == i)).sum()
                union = ((label == i) | ((pred == i) & (label != ignore))).sum()
                if not union:
                    iou.append(EMPTY)
                else:
                    iou.append(float(intersection) / float(union))
        ious.append(iou)
    ious = [mean(iou) for iou in zip(*ious)] # mean across images if per_image
    return 100 * np.array(ious)


# --------------------------- BINARY LOSSES --------------------------- 
Example #8
Source Project: PolarSeg   Author: edwardzhou130   File: lovasz_losses.py    License: BSD 3-Clause "New" or "Revised" License
def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss
      logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
      ignore: label to ignore
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * Variable(signs))
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), Variable(grad))
    return loss 
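Note the return logits.sum() * 0. on the all-void path: summing the logits and multiplying by zero yields a zero loss that is still attached to the autograd graph, so backward() produces (zero) gradients instead of failing.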
Example #9
Source Project: dogTorch   Author: ehsanik   File: metrics.py    License: MIT License
def final_report(self):
        correct_preds = self.confusion[:, :,
                                       range(self.args.num_classes),
                                       range(self.args.num_classes)]
        correct_percentage = correct_preds / (self.confusion.sum(3) + 1e-6) * 100
        balance_accuracy = correct_percentage.mean()
        per_sequence_element_accuracy = correct_percentage.view(
            correct_percentage.size(0), -1).mean(1)
        per_sequence_report = ', '.join(
            '{:.2f}'.format(acc) for acc in per_sequence_element_accuracy)
        report = ('Accuracy {meter.avg[0]:.2f}   Balanced {balanced:.2f}   '
                  'PerSeq [{per_seq}]').format(meter=self.meter,
                                               balanced=balance_accuracy,
                                               per_seq=per_sequence_report)
        report += '   Accuracy Matrix (seq x imu x label): {}'.format(
            correct_percentage)
        return report 
Example #10
Source Project: dogTorch   Author: ehsanik   File: metrics.py    License: MIT License
def record_output(self, output, output_indices, target, prev_absolutes,
                      next_absolutes, batch_size=1):
        assert output.dim() == 4
        assert target.dim() == 3

        _, predictions = output.max(3)

        # Compute per class accuracy for unbalanced data.
        sequence_length = output.size(1)
        num_label = output.size(2)
        num_class = output.size(3)
        correct_alljoint = (target == predictions).float().sum(2)
        sum_of_corrects = correct_alljoint.sum(1)
        max_value = num_label * sequence_length
        count_correct = (sum_of_corrects == max_value).float().mean()
        correct_per_seq = ((correct_alljoint == num_label - 1).sum(1).float() /
                           sequence_length).mean()
        self.meter.update(
            torch.Tensor([count_correct * 100, correct_per_seq * 100]),
            batch_size) 
Example #11
Source Project: ACAN   Author: miraiaroha   File: losses.py    License: MIT License
def forward(self, Q, P):
        """
        Parameters
        ----------
        P: ground truth probability distribution [batch_size, n, n]
        Q: predicted probability distribution [batch_size, n, n]

        Description
        -----------
        compute the KL divergence of attention maps. Here P and Q denote 
        the pixel-level attention map with n spatial positions.
        """
        kl_loss = P * safe_log(P / Q)
        pixel_loss = torch.sum(kl_loss, dim=-1)
        total_loss = torch.mean(pixel_loss)
        return total_loss 
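A minimal sketch of the same reduction with toy attention maps; safe_log here is a hypothetical stand-in for the project's helper, assumed to guard against log(0):

import torch

def safe_log(t, eps=1e-12):                          # hypothetical stand-in
    return torch.log(t.clamp(min=eps))

P = torch.softmax(torch.rand(2, 4, 4), dim=-1)       # ground-truth attention map
Q = torch.softmax(torch.rand(2, 4, 4), dim=-1)       # predicted attention map
kl_loss = P * safe_log(P / Q)
total_loss = torch.mean(torch.sum(kl_loss, dim=-1))  # sum over positions, average the rest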
Example #12
Source Project: ACAN   Author: miraiaroha   File: losses.py    License: MIT License
def __init__(self, ignore_index=None, reduction='sum', use_weights=False, weight=None):
        """
        Parameters
        ----------
        ignore_index : Specifies a target value that is ignored
                       and does not contribute to the input gradient
        reduction : Specifies the reduction to apply to the output: 
                    'mean' | 'sum'. 'mean': elementwise mean,
                    'sum': class dim will be summed and batch dim will be averaged.
        use_weights : whether to use weights of classes.
        weight : Tensor, optional
                a manual rescaling weight given to each class.
                If given, has to be a Tensor of size "nclasses"
        """
        super(_BaseEntropyLoss2d, self).__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.use_weights = use_weights
        if use_weights:
            print("w/ class balance")
            print(weight)
            self.weight = torch.FloatTensor(weight).cuda()
        else:
            print("w/o class balance")
            self.weight = None 
Example #13
Source Project: cascade-rcnn_Pytorch   Author: guoruoqian   File: gridgen.py    License: MIT License
def forward(self, input1):
        self.batchgrid3d = torch.zeros(torch.Size([input1.size(0)]) + self.grid3d.size())

        for i in range(input1.size(0)):
            self.batchgrid3d[i] = self.grid3d

        self.batchgrid3d = Variable(self.batchgrid3d)
        #print(self.batchgrid3d)

        x = torch.sum(torch.mul(self.batchgrid3d, input1[:,:,:,0:4]), 3)
        y = torch.sum(torch.mul(self.batchgrid3d, input1[:,:,:,4:8]), 3)
        z = torch.sum(torch.mul(self.batchgrid3d, input1[:,:,:,8:]), 3)
        #print(x)
        r = torch.sqrt(x**2 + y**2 + z**2) + 1e-5

        #print(r)
        theta = torch.acos(z/r)/(np.pi/2)  - 1
        #phi = torch.atan(y/x)
        phi = torch.atan(y/(x + 1e-5))  + np.pi * x.lt(0).type(torch.FloatTensor) * (y.ge(0).type(torch.FloatTensor) - y.lt(0).type(torch.FloatTensor))
        phi = phi/np.pi


        output = torch.cat([theta,phi], 3)

        return output 
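Each torch.sum(torch.mul(...), 3) above is an elementwise-multiply-and-reduce, i.e. a batched dot product along the last dimension. The same contraction can be written as (a * b).sum(dim=-1) or with einsum:

import torch

a, b = torch.rand(2, 5, 5, 4), torch.rand(2, 5, 5, 4)
dot = torch.sum(torch.mul(a, b), 3)                        # as in the example
assert torch.allclose(dot, torch.einsum('...i,...i->...', a, b))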
Example #14
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: model2.py    License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum() 
Example #15
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: model.py    License: Apache License 2.0
def __init__(self, src, trg=None, pad=0):
        self.src = src
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:, :-1]
            self.trg_y = trg[:, 1:]
            self.trg_mask = self.make_std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum() 
Example #16
Source Project: DDPAE-video-prediction   Author: jthsieh   File: DDPAE.py    License: MIT License
def get_output(self, components, latent):
    '''
    Take the sum of the components.
    '''
    # components: batch_size x n_frames_total x total_components x C x H x W
    batch_size = components.size(0)
    # Sum the components
    output = torch.sum(components, dim=2)
    output = torch.clamp(output, max=1)
    return output 
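Summing over dim=2 composites the decomposed components back into a single frame, and torch.clamp(output, max=1) keeps overlapping components from pushing pixel intensities past the valid range.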
Example #17
Source Project: hgraph2graph   Author: wengong-jin   File: nnutils.py    License: MIT License
def avg_pool(all_vecs, scope, dim):
    size = create_var(torch.Tensor([le for _,le in scope]))
    return all_vecs.sum(dim=dim) / size.unsqueeze(-1) 
Example #18
Source Project: hgraph2graph   Author: wengong-jin   File: nnutils.py    License: MIT License
def get_accuracy_bin(scores, labels):
    preds = torch.ge(scores, 0).long()
    acc = torch.eq(preds, labels).float()
    return torch.sum(acc) / labels.nelement() 
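Since acc is a float tensor, torch.sum(acc) / labels.nelement() is simply acc.mean(); the same pattern appears in the next two examples.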
Example #19
Source Project: hgraph2graph   Author: wengong-jin   File: nnutils.py    License: MIT License
def get_accuracy(scores, labels):
    _,preds = torch.max(scores, dim=-1)
    acc = torch.eq(preds, labels).float()
    return torch.sum(acc) / labels.nelement() 
Example #20
Source Project: hgraph2graph   Author: wengong-jin   File: nnutils.py    License: MIT License
def get_accuracy_sym(scores, labels):
    max_scores,max_idx = torch.max(scores, dim=-1)
    lab_scores = scores[torch.arange(len(scores)), labels]
    acc = torch.eq(lab_scores, max_scores).float()
    return torch.sum(acc) / labels.nelement() 
Example #21
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License
def rsample(self, z_vecs, W_mean, W_var, perturb=True):
        batch_size = z_vecs.size(0)
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs( W_var(z_vecs) )
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
        epsilon = torch.randn_like(z_mean).cuda()
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon if perturb else z_mean
        return z_vecs, kl_loss 
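The kl_loss line is the closed-form KL divergence between the diagonal Gaussian N(z_mean, exp(z_log_var)) and a standard normal, summed over latent dimensions and averaged over the batch. A quick illustrative check against torch.distributions:

import torch
from torch.distributions import Normal, kl_divergence

z_mean = torch.randn(3, 8)                           # batch size 3, latent size 8
z_log_var = -torch.abs(torch.randn(3, 8))
manual = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / 3
reference = kl_divergence(Normal(z_mean, torch.exp(z_log_var / 2)), Normal(0., 1.)).sum() / 3
assert torch.allclose(manual, reference, atol=1e-5)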
Example #22
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License
def rsample(self, z_vecs, W_mean, W_var):
        batch_size = z_vecs.size(0)
        z_mean = W_mean(z_vecs)
        z_log_var = -torch.abs( W_var(z_vecs) )
        kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
        epsilon = torch.randn_like(z_mean).cuda()
        z_vecs = z_mean + torch.exp(z_log_var / 2) * epsilon
        return z_vecs, kl_loss 
Example #23
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, cond, beta):
        x_tensors = make_cuda(x_tensors)
        y_tensors = make_cuda(y_tensors)
        cond = torch.tensor(cond).float().cuda()

        x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
        _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
        diff_tree_vecs = self.U_tree( torch.cat([diff_tree_vecs, cond], dim=-1) ) #combine condition for posterior
        diff_graph_vecs = self.U_graph( torch.cat([diff_graph_vecs, cond], dim=-1) ) #combine condition for posterior

        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
        kl_div = tree_kl + graph_kl

        diff_tree_vecs = torch.cat([diff_tree_vecs, cond], dim=-1) #combine condition for posterior
        diff_graph_vecs = torch.cat([diff_graph_vecs, cond], dim=-1) #combine condition for posterior

        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
        x_tree_vecs = self.W_tree( torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1) )
        x_graph_vecs = self.W_graph( torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1) )

        loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
        return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc