Python torch.tensor() Examples

The following are 30 code examples of torch.tensor(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source Project: mmdetection   Author: open-mmlab   File: gfl_head.py    License: Apache License 2.0 6 votes vote down vote up
def forward(self, feats):
        """Run the head over the multi-level features.

        The call is delegated to ``multi_apply`` with ``self.forward_single``,
        ``feats`` and ``self.scales`` as arguments.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
                cls_scores (list[Tensor]): Classification and quality (IoU)
                    joint scores for all scale levels, each is a 4D-tensor,
                    the channel number is num_classes.
                bbox_preds (list[Tensor]): Box distribution logits for all
                    scale levels, each is a 4D-tensor, the channel number is
                    4*(n+1), n is max value of integral set.
        """
        per_level_outputs = multi_apply(
            self.forward_single, feats, self.scales)
        return per_level_outputs
Example #2
Source Project: mmdetection   Author: open-mmlab   File: atss_head.py    License: Apache License 2.0 6 votes vote down vote up
def forward(self, feats):
        """Run the head over the multi-level features.

        The call is delegated to ``multi_apply`` with ``self.forward_single``,
        ``feats`` and ``self.scales`` as arguments.

        Args:
            feats (tuple[Tensor]): Features from the upstream network, each is
                a 4D-tensor.

        Returns:
            tuple: Usually a tuple of classification scores and bbox prediction
                cls_scores (list[Tensor]): Classification scores for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * num_classes.
                bbox_preds (list[Tensor]): Box energies / deltas for all scale
                    levels, each is a 4D-tensor, the channels number is
                    num_anchors * 4.
        """
        per_level_outputs = multi_apply(
            self.forward_single, feats, self.scales)
        return per_level_outputs
Example #3
Source Project: mmdetection   Author: open-mmlab   File: point_sample.py    License: Apache License 2.0 6 votes vote down vote up
def generate_grid(num_grid, size, device):
    """Build a regular square grid of points in [0, 1] x [0, 1] coordinates.

    Args:
        num_grid (int): The number of grids to sample, one for each region.
        size (tuple(int, int)): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        (torch.Tensor): A tensor of shape (num_grid, size[0]*size[1], 2) that
            contains coordinates for the regular grids.
    """
    # One identity 2x3 affine transform, allocated directly on `device`.
    identity = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
    target_shape = torch.Size((1, 1, *size))
    base_grid = F.affine_grid(identity, target_shape, align_corners=False)
    base_grid = normalize(base_grid)
    # Flatten the spatial dimensions, then repeat the same grid `num_grid`
    # times as an expanded view.
    flat_grid = base_grid.view(1, -1, 2)
    return flat_grid.expand(num_grid, -1, -1)
Example #4
Source Project: mmdetection   Author: open-mmlab   File: score_hlr_sampler.py    License: Apache License 2.0 6 votes vote down vote up
def random_choice(gallery, num):
        """Randomly select some elements from the gallery.

        If `gallery` is a Tensor, the returned indices will be a Tensor;
        If `gallery` is a ndarray or list, the returned indices will be a
        ndarray.

        Non-tensor input is converted to a temporary long tensor. That tensor
        lives on the current CUDA device when CUDA is available and on the CPU
        otherwise, so the function also works on CPU-only machines.

        Args:
            gallery (Tensor | ndarray | list): indices pool.
            num (int): expected sample num.

        Returns:
            Tensor or ndarray: sampled indices.
        """
        assert len(gallery) >= num

        is_tensor = isinstance(gallery, torch.Tensor)
        if not is_tensor:
            # Only query the CUDA device when one exists; asking for the
            # current device on a CPU-only host raises.
            if torch.cuda.is_available():
                device = torch.cuda.current_device()
            else:
                device = 'cpu'
            gallery = torch.tensor(gallery, dtype=torch.long, device=device)
        # Shuffle positions on the gallery's own device and keep the first
        # `num` of them: sampling without replacement.
        perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
        rand_inds = gallery[perm]
        if not is_tensor:
            rand_inds = rand_inds.cpu().numpy()
        return rand_inds
Example #5
Source Project: mmdetection   Author: open-mmlab   File: test_anchor.py    License: Apache License 2.0 6 votes vote down vote up
def test_strides():
    """Check the anchors generated on a 2x2 feature map for several strides."""
    from mmdet.core import AnchorGenerator

    # Each case is (strides, anchors expected for the first level).
    cases = [
        # Square strides
        ([10],
         torch.tensor([[-5., -5., 5., 5.], [5., -5., 15., 5.],
                       [-5., 5., 5., 15.], [5., 5., 15., 15.]])),
        # Different strides in x and y direction
        ([(10, 20)],
         torch.tensor([[-5., -5., 5., 5.], [5., -5., 15., 5.],
                       [-5., 15., 5., 25.], [5., 15., 15., 25.]])),
    ]

    for strides, expected_anchors in cases:
        generator = AnchorGenerator(strides, [1.], [1.], [10])
        anchors = generator.grid_anchors([(2, 2)], device='cpu')
        assert torch.equal(anchors[0], expected_anchors)
Example #6
Source Project: mmdetection   Author: open-mmlab   File: test_losses.py    License: Apache License 2.0 6 votes vote down vote up
def test_ce_loss():
    """Exercise CrossEntropyLoss construction and its weighted/unweighted value."""
    # use_mask and use_sigmoid cannot be true at the same time
    conflicting_cfg = dict(
        type='CrossEntropyLoss',
        use_mask=True,
        use_sigmoid=True,
        loss_weight=1.0)
    with pytest.raises(AssertionError):
        build_loss(conflicting_cfg)

    # test loss with class weights
    weighted_loss = build_loss(
        dict(
            type='CrossEntropyLoss',
            use_sigmoid=False,
            class_weight=[0.8, 0.2],
            loss_weight=1.0))
    fake_pred = torch.Tensor([[100, -100]])
    fake_label = torch.Tensor([1]).long()
    weighted_value = weighted_loss(fake_pred, fake_label)
    assert torch.allclose(weighted_value, torch.tensor(40.))

    # same prediction and label, no class weights
    plain_loss = build_loss(
        dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))
    plain_value = plain_loss(fake_pred, fake_label)
    assert torch.allclose(plain_value, torch.tensor(200.))
Example #7
Source Project: subword-qac   Author: clovaai   File: dataset.py    License: MIT License 6 votes vote down vote up
def collate_fn(queries, tokenizer, sample, max_seq_len=None):
    """Tokenize queries and batch them as time-major input/target tensors.

    Every query is tokenized with ``tokenizer(query, **sample)`` and wrapped
    with id 1 at the front and id 2 at the end. Sequences are clipped to
    ``max_seq_len`` and right-padded with id 0.

    Args:
        queries: iterable of raw queries passed to ``tokenizer``.
        tokenizer: callable returning a list of token ids.
        sample: dict of keyword arguments forwarded to ``tokenizer``.
        max_seq_len: optional cap on the sequence length; when it is ``None``
            or larger than the longest wrapped sequence, the longest wrapped
            sequence length is used.

    Returns:
        Tuple ``(inputs, targets, length, mask)``. ``inputs`` and ``targets``
        are the padded batch transposed to (time, batch) without the last and
        the first step respectively. ``length`` holds ``len(sequence) - 1``
        for each wrapped sequence, computed before clipping. ``mask`` is
        (time, batch) with 1 on non-padded target positions.
    """
    sequences = [[1] + tokenizer(query, **sample) + [2] for query in queries]

    length = [len(seq) - 1 for seq in sequences]
    longest = max(length) + 1
    if max_seq_len is None or max_seq_len > longest:
        max_seq_len = longest

    padded_rows, mask_rows = [], []
    for seq in sequences:
        clipped = seq[:max_seq_len]
        n_pad = max_seq_len - len(clipped)
        padded_rows.append(clipped + [0] * n_pad)
        mask_rows.append([1] * (len(clipped) - 1) + [0] * n_pad)

    # Switch from (batch, time) to (time, batch).
    tokens = torch.tensor(padded_rows).t().contiguous()
    target_mask = torch.tensor(mask_rows).t().contiguous()
    lengths = torch.tensor(length)
    return tokens[:-1], tokens[1:], lengths, target_mask
Example #8
Source Project: models   Author: kipoi   File: model.py    License: MIT License 6 votes vote down vote up
def process_embedding(self, embedding,
                          residue_reduction=True, protein_reduction=False):
        '''
            Reduce a raw ELMo embedding and return it as a numpy array.

            The direct ELMo output has shape (3,L,1024): L is the protein's
            length, 3 is the number of layers used to train SeqVec (1 CharCNN,
            2 LSTMs) and 1024 is a hyperparameter describing each amino acid.

            residue_reduction: sum over the first dimension, giving a
                per-residue representation of size (L,1024).
            protein_reduction: only consulted when residue_reduction is
                False; sum over the first dimension and then average over L,
                giving one fixed-size vector per protein.
            With both flags False the embedding is returned unreduced.
        '''
        layers = torch.tensor(embedding)
        if residue_reduction:
            reduced = layers.sum(dim=0)
        elif protein_reduction:
            per_residue = layers.sum(dim=0)
            reduced = per_residue.mean(dim=0)
        else:
            reduced = layers

        return reduced.cpu().detach().numpy()
Example #9
Source Project: deep-learning-note   Author: wdxtub   File: 33_gru_raw.py    License: MIT License 6 votes vote down vote up
def get_params():
    """Create the trainable GRU parameters as an ``nn.ParameterList``.

    Weights are drawn from a normal distribution (mean 0, std 0.01) and
    biases start at zero. Sizes come from the module-level ``num_inputs``,
    ``num_hiddens`` and ``num_outputs``; tensors are placed on ``device``.
    """
    def _gaussian(shape):
        values = np.random.normal(0, 0.01, size=shape)
        weight = torch.tensor(values, device=device, dtype=torch.float32)
        return torch.nn.Parameter(weight, requires_grad=True)

    def _zeros(size):
        bias = torch.zeros(size, device=device, dtype=torch.float32)
        return torch.nn.Parameter(bias, requires_grad=True)

    def _gate():
        # input-to-hidden weight, hidden-to-hidden weight, bias
        w_x = _gaussian((num_inputs, num_hiddens))
        w_h = _gaussian((num_hiddens, num_hiddens))
        return w_x, w_h, _zeros(num_hiddens)

    W_xz, W_hz, b_z = _gate()  # update gate parameters
    W_xr, W_hr, b_r = _gate()  # reset gate parameters
    W_xh, W_hh, b_h = _gate()  # candidate hidden state parameters

    # output layer parameters
    W_hq = _gaussian((num_hiddens, num_outputs))
    b_q = _zeros(num_outputs)

    params = [W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q]
    return nn.ParameterList(params)
Example #10
Source Project: deep-learning-note   Author: wdxtub   File: utils.py    License: MIT License 6 votes vote down vote up
def data_iter_random(corpus_indices, batch_size, num_steps, device=None):
    """Yield random minibatches of (input, label) windows from a corpus.

    The corpus is cut into consecutive windows of ``num_steps`` indices, the
    windows are shuffled, and each batch takes ``batch_size`` of them. The
    label window is the input window shifted one position to the right.

    Args:
        corpus_indices: sequence of token indices.
        batch_size: number of windows per batch.
        num_steps: length of every window.
        device: target device; defaults to CUDA when available, else CPU.

    Yields:
        Tuple ``(X, Y)`` of float32 tensors of shape (batch_size, num_steps).
    """
    # Subtract 1 because each label index is the matching input index plus 1.
    num_examples = (len(corpus_indices) - 1) // num_steps
    epoch_size = num_examples // batch_size
    example_indices = list(range(num_examples))
    random.shuffle(example_indices)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def _window(start):
        # The slice of length num_steps beginning at `start`.
        return corpus_indices[start: start + num_steps]

    for batch_no in range(epoch_size):
        # Read batch_size random examples each time.
        offset = batch_no * batch_size
        chosen = example_indices[offset: offset + batch_size]
        inputs = [_window(j * num_steps) for j in chosen]
        labels = [_window(j * num_steps + 1) for j in chosen]
        X = torch.tensor(inputs, dtype=torch.float32, device=device)
        Y = torch.tensor(labels, dtype=torch.float32, device=device)
        yield X, Y
Example #11
Source Project: deep-learning-note   Author: wdxtub   File: utils.py    License: MIT License 6 votes vote down vote up
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,
                        char_to_idx):
    """Generate ``num_chars`` characters that continue ``prefix``.

    The model is fed one index at a time. While the prefix is being consumed
    the next prefix character is used; afterwards the arg-max of the model
    output is used.

    Args:
        prefix: non-empty seed string.
        num_chars: number of characters to generate after the prefix.
        model: callable ``model(X, state) -> (Y, state)``.
        vocab_size: not used by this function.
        device: device for the input tensor and the recurrent state.
        idx_to_char: index -> character lookup.
        char_to_idx: character -> index lookup.

    Returns:
        The prefix followed by the generated characters, as a string.
    """
    state = None
    output = [char_to_idx[prefix[0]]]  # holds the prefix plus the predictions
    warmup_steps = len(prefix) - 1
    for t in range(num_chars + warmup_steps):
        X = torch.tensor([output[-1]], device=device).view(1, 1)
        if isinstance(state, tuple):  # LSTM, state:(h, c)
            state = (state[0].to(device), state[1].to(device))
        elif state is not None:
            state = state.to(device)

        # The forward pass does not need the model parameters passed in.
        Y, state = model(X, state)
        if t < warmup_steps:
            next_idx = char_to_idx[prefix[t + 1]]
        else:
            next_idx = int(Y.argmax(dim=1).item())
        output.append(next_idx)
    return ''.join(idx_to_char[i] for i in output)
Example #12
Source Project: deep-learning-note   Author: wdxtub   File: 3_linear_regression_raw.py    License: MIT License 6 votes vote down vote up
def generate_dataset(true_w, true_b):
    """Create a noisy two-feature linear-regression dataset and plot it.

    Draws 1000 examples with ``num_inputs`` standard-normal features, builds
    labels from the first two features with ``true_w`` and ``true_b``, adds
    normal noise (std 0.01) and shows a scatter plot of the second feature
    against the labels.

    Returns:
        Tuple ``(features, labels)`` of float tensors.
    """
    num_examples = 1000

    raw_features = np.random.normal(0, 1, (num_examples, num_inputs))
    features = torch.tensor(raw_features, dtype=torch.float)
    # ground-truth labels
    weighted_first = true_w[0] * features[:, 0]
    weighted_second = true_w[1] * features[:, 1]
    labels = weighted_first + weighted_second + true_b
    # add noise
    noise = np.random.normal(0, 0.01, size=labels.size())
    labels += torch.tensor(noise, dtype=torch.float)
    # show the distribution
    plt.scatter(features[:, 1].numpy(), labels.numpy(), 1)
    plt.show()

    return features, labels


# Read the dataset in batches
Example #13
Source Project: deep-learning-note   Author: wdxtub   File: 53_machine_translation.py    License: MIT License 6 votes vote down vote up
def batch_loss(encoder, decoder, X, Y, loss):
    """Compute the teacher-forced decoder loss averaged over non-PAD tokens.

    Args:
        encoder: encoder exposing ``begin_state()`` and callable as
            ``encoder(X, state)``.
        decoder: decoder exposing ``begin_state(enc_state)`` and callable as
            ``decoder(dec_input, dec_state, enc_outputs)``.
        X: input batch; its first dimension is the batch size.
        Y: target batch of shape (batch, seq_len).
        loss: callable returning one loss value per example.

    Returns:
        A one-element tensor: the masked loss sum divided by the number of
        counted tokens.
    """
    batch_size = X.shape[0]
    initial_enc_state = encoder.begin_state()
    enc_outputs, enc_state = encoder(X, initial_enc_state)
    # Initialise the decoder hidden state from the encoder state.
    dec_state = decoder.begin_state(enc_state)
    # The decoder input at the first time step is BOS.
    bos_id = out_vocab.stoi[BOS]
    dec_input = torch.tensor([bos_id] * batch_size)
    # `mask` zeroes the loss at positions whose label is the PAD filler.
    mask = torch.ones(batch_size,)
    num_not_pad_tokens = 0
    total = torch.tensor([0.0])
    for y in Y.permute(1, 0):  # Y shape: (batch, seq_len)
        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)
        step_loss = mask * loss(dec_output, y)
        total = total + step_loss.sum()
        dec_input = y  # teacher forcing
        num_not_pad_tokens += mask.sum().item()
        # Zero the mask where the label is PAD. The original text compared
        # against out_vocab.stoi[EOS] here, which looked wrong.
        not_pad = (y != out_vocab.stoi[PAD]).float()
        mask = mask * not_pad
    return total / num_not_pad_tokens
Example #14
Source Project: deep-learning-note   Author: wdxtub   File: 53_machine_translation.py    License: MIT License 6 votes vote down vote up
def translate(encoder, decoder, input_seq, max_seq_len):
    """Greedily decode a translation for one space-separated sentence.

    Args:
        encoder: encoder exposing ``begin_state()`` and callable as
            ``encoder(enc_input, state)``.
        decoder: decoder exposing ``begin_state(enc_state)`` and callable as
            ``decoder(dec_input, dec_state, enc_output)``.
        input_seq: source sentence with tokens separated by single spaces.
        max_seq_len: padded source length and maximum number of decode steps.

    Returns:
        List of predicted output tokens, without the terminating EOS.
    """
    in_tokens = input_seq.split(' ')
    padding = [PAD] * (max_seq_len - len(in_tokens) - 1)
    in_tokens = in_tokens + [EOS] + padding
    token_ids = [in_vocab.stoi[tk] for tk in in_tokens]
    enc_input = torch.tensor([token_ids])  # batch=1
    initial_enc_state = encoder.begin_state()
    enc_output, enc_state = encoder(enc_input, initial_enc_state)
    dec_input = torch.tensor([out_vocab.stoi[BOS]])
    dec_state = decoder.begin_state(enc_state)
    output_tokens = []
    for _ in range(max_seq_len):
        dec_output, dec_state = decoder(dec_input, dec_state, enc_output)
        pred = dec_output.argmax(dim=1)
        pred_token = out_vocab.itos[int(pred.item())]
        # The output sequence is complete as soon as EOS is predicted.
        if pred_token == EOS:
            break
        output_tokens.append(pred_token)
        dec_input = pred
    return output_tokens
Example #15
Source Project: OpenNRE   Author: thunlp   File: data_loader.py    License: MIT License 6 votes vote down vote up
def collate_fn(data):
        """Collate bag-level samples into one batch.

        Every sample is ``(label, bag_name, count, seq_0, seq_1, ...)``. For
        each sequence field the per-bag tensors are concatenated along
        dimension 0 and then expanded with a new leading dimension whose size
        is the CUDA device count (1 when no device is reported).

        Returns:
            list: ``[label, bag_name, scope, seq_0, seq_1, ...]`` where
            ``label`` is a long tensor (B), ``bag_name`` is the tuple of bag
            names, and ``scope`` is a long tensor (B, 2) holding the
            ``(start, end)`` row range of each bag in the concatenated
            sequences.
        """
        columns = list(zip(*data))
        label, bag_name, count = columns[:3]

        n_replicas = torch.cuda.device_count() if torch.cuda.device_count() > 0 else 1
        seqs = []
        for column in columns[3:]:
            merged = torch.cat(column, 0)  # (sumn, L)
            seqs.append(merged.expand((n_replicas, ) + merged.size()))

        scope = []  # (B, 2)
        start = 0
        for c in count:
            end = start + c
            scope.append((start, end))
            start = end
        # The bag sizes must add up to the number of concatenated rows.
        assert start == seqs[0].size(1)

        scope = torch.tensor(scope).long()
        label = torch.tensor(label).long()  # (B)
        return [label, bag_name, scope] + seqs
Example #16
Source Project: graph-neural-networks   Author: alelab-upenn   File: graphML.py    License: GNU General Public License v3.0 6 votes vote down vote up
def addGSO(self, S):
        """Store the graph shift operator and its k-hop neighborhood matrices.

        Saves ``S`` (shape E x N x N) in ``self.S`` and its node count in
        ``self.N``. For every ``k`` from 1 to ``self.K`` the neighborhood is
        computed with ``outputType='matrix'``, stored as a tensor on the
        device of ``self.S``, and its second dimension is recorded in
        ``self.maxNeighborhoodSizes``.
        """
        # Every S has 3 dimensions.
        assert len(S.shape) == 3
        # S is of shape E x N x N
        self.N = S.shape[1]
        assert S.shape[2] == self.N
        self.S = S
        # graphTools receives a numpy copy; the tensor stays in self.S.
        S_np = S.cpu().numpy()
        # The neighborhood matrix has to be a tensor of shape
        #   nOutputNodes x maxNeighborhoodSize
        hop_neighborhoods = []
        hop_sizes = []
        for hops in range(1, self.K + 1):
            # compute the neighborhood for this number of hops
            hop_matrix = graphTools.computeNeighborhood(
                S_np, hops, outputType='matrix')
            hop_tensor = torch.tensor(hop_matrix).to(self.S.device)
            hop_neighborhoods.append(hop_tensor)
            hop_sizes.append(hop_matrix.shape[1])
        self.maxNeighborhoodSizes = hop_sizes
        self.neighborhood = hop_neighborhoods
Example #17
Source Project: graph-neural-networks   Author: alelab-upenn   File: graphML.py    License: GNU General Public License v3.0 6 votes vote down vote up
def addGSO(self, S):
        """Store the graph shift operator and its k-hop neighborhood lists.

        Saves ``S`` (shape E x N x N) in ``self.S`` and its node count in
        ``self.N``, then fills ``self.neighborhood`` with one entry for every
        ``k`` from 1 to ``self.K``, computed with ``outputType='list'``.
        """
        # Every S has 3 dimensions.
        assert len(S.shape) == 3
        # S is of shape E x N x N
        self.N = S.shape[1]
        assert S.shape[2] == self.N
        self.S = S
        # graphTools receives a numpy copy; the tensor stays in self.S.
        S_np = S.cpu().numpy()
        # One neighborhood per hop count, in increasing order of hops.
        self.neighborhood = [
            graphTools.computeNeighborhood(S_np, hops, outputType='list')
            for hops in range(1, self.K + 1)
        ]
Example #18
Source Project: graph-neural-networks   Author: alelab-upenn   File: graphML.py    License: GNU General Public License v3.0 6 votes vote down vote up
def forward(self, x):
        """Apply the node-variant graph filter to ``x``.

        Args:
            x: tensor of shape batchSize x dimInFeatures x numberNodesIn.

        Returns:
            Tensor of shape batchSize x dimOutFeatures x numberNodesIn.
        """
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B, F, Nin = x.shape[0], x.shape[1], x.shape[2]
        needs_padding = Nin < self.N
        # If we have less filter coefficients than the required ones, we need
        # to use the copying scheme
        self.h = (self.weight if self.M == self.N
                  else torch.index_select(self.weight, 4, self.copyNodes))
        # Zero-pad the node dimension up to self.N
        if needs_padding:
            padding = torch.zeros(B, F, self.N - Nin)
            padding = padding.type(x.dtype).to(x.device)
            x = torch.cat((x, padding), dim=2)
        # Compute the filter output
        u = NVGF(self.h, self.S, x, self.bias)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes.
        # Only the first numberNodesIn nodes are required, so drop the rest.
        if needs_padding:
            kept_nodes = torch.arange(Nin).to(u.device)
            u = torch.index_select(u, 2, kept_nodes)
        return u
Example #19
Source Project: graph-neural-networks   Author: alelab-upenn   File: graphML.py    License: GNU General Public License v3.0 6 votes vote down vote up
def forward(self, x):
        """Apply the jARMA graph filter to ``x``.

        Args:
            x: tensor of shape batchSize x dimInFeatures x numberNodesIn.

        Returns:
            Tensor of shape batchSize x dimOutFeatures x numberNodesIn.
        """
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B, F, Nin = x.shape[0], x.shape[1], x.shape[2]
        needs_padding = Nin < self.N
        # Zero-pad the node dimension up to self.N
        if needs_padding:
            padding = torch.zeros(B, F, self.N - Nin)
            padding = padding.type(x.dtype).to(x.device)
            x = torch.cat((x, padding), dim=2)
        # Compute the filter output
        u = jARMA(self.inverseWeight, self.directWeight, self.filterWeight,
                  self.S, x, b=self.bias, tMax=self.tMax)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes.
        # Only the first numberNodesIn nodes are required, so drop the rest.
        if needs_padding:
            kept_nodes = torch.arange(Nin).to(u.device)
            u = torch.index_select(u, 2, kept_nodes)
        return u
Example #20
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: data.py    License: Apache License 2.0 5 votes vote down vote up
def get_cuda(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Args:
        tensor: object exposing a ``.cuda()`` method (tensor or module).

    Returns:
        ``tensor.cuda()`` when CUDA is available, otherwise ``tensor``
        unchanged, so CPU-only hosts do not raise.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import torch

    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor
Example #21
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: data.py    License: Apache License 2.0 5 votes vote down vote up
def get_cuda(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Args:
        tensor: object exposing a ``.cuda()`` method (tensor or module).

    Returns:
        ``tensor.cuda()`` when CUDA is available, otherwise ``tensor``
        unchanged, so CPU-only hosts do not raise.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import torch

    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor
Example #22
Source Project: controllable-text-attribute-transfer   Author: Nrgeup   File: data.py    License: Apache License 2.0 5 votes vote down vote up
def get_cuda(tensor):
    """Move ``tensor`` to the GPU when CUDA is available.

    Args:
        tensor: object exposing a ``.cuda()`` method (tensor or module).

    Returns:
        ``tensor.cuda()`` when CUDA is available, otherwise ``tensor``
        unchanged, so CPU-only hosts do not raise.
    """
    # Imported locally so the helper does not rely on a module-level import.
    import torch

    if torch.cuda.is_available():
        return tensor.cuda()
    return tensor
Example #23
Source Project: comet-commonsense   Author: atcbosselut   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def make_new_tensor_from_list(items, device_num, dtype=torch.float32):
    """Build a tensor from ``items`` on the requested device.

    Args:
        items: data accepted by ``torch.tensor``.
        device_num: CUDA device index, or ``None`` for the CPU.
        dtype: dtype of the new tensor.

    Returns:
        The new tensor.
    """
    device_name = "cpu" if device_num is None else "cuda:{}".format(device_num)
    target = torch.device(device_name)
    return torch.tensor(items, dtype=dtype, device=target)


# is_dir looks at whether the name we make
# should be a directory or a filename
Example #24
Source Project: comet-commonsense   Author: atcbosselut   File: utils.py    License: Apache License 2.0 5 votes vote down vote up
def initialize_progress_bar(data_loader_list):
    """Create a progress bar sized to the total length of all entries.

    Args:
        data_loader_list: mapping whose values support ``len()``.

    Returns:
        Whatever ``set_progress_bar`` returns for the summed length.
    """
    num_examples = 0
    for split in data_loader_list.values():
        num_examples += len(split)
    return set_progress_bar(num_examples)
Example #25
Source Project: comet-commonsense   Author: atcbosselut   File: sampler.py    License: Apache License 2.0 5 votes vote down vote up
def make_batch(self, X):
        """Pair token ids with position ids and return them as a long tensor.

        ``X`` is converted to an array and must be 1-D or 2-D; a 1-D input
        gets a leading axis of size 1. Position ids start at
        ``n_vocab + n_special`` and increase by one per token. Tokens and
        positions are stacked on a new last axis and moved to ``device``.
        """
        tokens = np.array(X)
        assert tokens.ndim in [1, 2]
        if tokens.ndim == 1:
            tokens = np.expand_dims(tokens, axis=0)
        first_position = n_vocab + n_special
        positions = np.arange(first_position, first_position + tokens.shape[-1])
        positions = np.expand_dims(positions, axis=0)
        paired = np.stack([tokens, positions], axis=-1)
        return torch.tensor(paired, dtype=torch.long).to(device)
Example #26
Source Project: comet-commonsense   Author: atcbosselut   File: generate_conceptnet_beam_search.py    License: Apache License 2.0 5 votes vote down vote up
def make_batch(X):
    """Pair token ids with position ids and return them as a long tensor.

    ``X`` is converted to an array and must be 1-D or 2-D; a 1-D input gets
    a leading axis of size 1. Position ids start at ``n_vocab + n_special``
    and increase by one per token. Tokens and positions are stacked on a new
    last axis and moved to ``device``.
    """
    tokens = np.array(X)
    assert tokens.ndim in [1, 2]
    if tokens.ndim == 1:
        tokens = np.expand_dims(tokens, axis=0)
    first_position = n_vocab + n_special
    positions = np.arange(first_position, first_position + tokens.shape[-1])
    positions = np.expand_dims(positions, axis=0)
    paired = np.stack([tokens, positions], axis=-1)
    return torch.tensor(paired, dtype=torch.long).to(device)
Example #27
Source Project: comet-commonsense   Author: atcbosselut   File: generate_atomic_greedy.py    License: Apache License 2.0 5 votes vote down vote up
def make_batch(X):
    """Pair token ids with position ids and return them as a long tensor.

    ``X`` is converted to an array and must be 1-D or 2-D; a 1-D input gets
    a leading axis of size 1. Position ids start at ``n_vocab + n_special``
    and increase by one per token. Tokens and positions are stacked on a new
    last axis and moved to ``device``.
    """
    tokens = np.array(X)
    assert tokens.ndim in [1, 2]
    if tokens.ndim == 1:
        tokens = np.expand_dims(tokens, axis=0)
    first_position = n_vocab + n_special
    positions = np.arange(first_position, first_position + tokens.shape[-1])
    positions = np.expand_dims(positions, axis=0)
    paired = np.stack([tokens, positions], axis=-1)
    return torch.tensor(paired, dtype=torch.long).to(device)
Example #28
Source Project: comet-commonsense   Author: atcbosselut   File: generate_atomic_beam_search.py    License: Apache License 2.0 5 votes vote down vote up
def make_batch(X):
    """Pair token ids with position ids and return them as a long tensor.

    ``X`` is converted to an array and must be 1-D or 2-D; a 1-D input gets
    a leading axis of size 1. Position ids start at ``n_vocab + n_special``
    and increase by one per token. Tokens and positions are stacked on a new
    last axis and moved to ``device``.
    """
    tokens = np.array(X)
    assert tokens.ndim in [1, 2]
    if tokens.ndim == 1:
        tokens = np.expand_dims(tokens, axis=0)
    first_position = n_vocab + n_special
    positions = np.arange(first_position, first_position + tokens.shape[-1])
    positions = np.expand_dims(positions, axis=0)
    paired = np.stack([tokens, positions], axis=-1)
    return torch.tensor(paired, dtype=torch.long).to(device)
Example #29
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License 5 votes vote down vote up
def make_cuda(tensors):
    """Move the tree and graph tensor groups to CUDA as long tensors.

    Args:
        tensors: pair ``(tree_tensors, graph_tensors)``. In each group every
            element except the last is converted; the last is passed through
            untouched.

    Returns:
        Tuple ``(tree_tensors, graph_tensors)`` of lists.
    """
    tree_tensors, graph_tensors = tensors

    def _to_cuda_long(item):
        # Only exact torch.Tensor instances are reused; anything else is
        # wrapped with torch.tensor first.
        as_tensor = item if type(item) is torch.Tensor else torch.tensor(item)
        return as_tensor.cuda().long()

    def _convert(group):
        moved = [_to_cuda_long(item) for item in group[:-1]]
        moved.append(group[-1])
        return moved

    tree_tensors = _convert(tree_tensors)
    graph_tensors = _convert(graph_tensors)
    return tree_tensors, graph_tensors
Example #30
Source Project: hgraph2graph   Author: wengong-jin   File: hgnn.py    License: MIT License 5 votes vote down vote up
def forward(self, x_graphs, x_tensors, y_graphs, y_tensors, y_orders, cond, beta):
        """Compute the conditional training loss for a source/target pair.

        The source (``x``) and target (``y``) tensors are encoded, the
        difference of their summed tree and graph vectors is combined with
        ``cond`` and sampled through ``self.rsample``, and the sampled
        difference is merged back into the source vectors before decoding
        against the target.

        Args:
            x_graphs: not used in this method.
            x_tensors: source ``(tree_tensors, graph_tensors)`` pair, moved to
                CUDA by ``make_cuda``.
            y_graphs: target graphs handed to the decoder.
            y_tensors: target ``(tree_tensors, graph_tensors)`` pair, moved to
                CUDA by ``make_cuda``.
            y_orders: handed to the decoder.
            cond: condition values; converted to a float CUDA tensor.
            beta: weight of the KL term in the returned loss.

        Returns:
            Tuple ``(loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc,
            sacc)``; the last four values come straight from the decoder
            (presumably accuracy metrics -- confirm against the decoder).
        """
        # Everything below runs on the GPU.
        x_tensors = make_cuda(x_tensors)
        y_tensors = make_cuda(y_tensors)
        cond = torch.tensor(cond).float().cuda()

        # Encode both sides; the root vectors of y are not needed.
        x_root_vecs, x_tree_vecs, x_graph_vecs = self.encode(x_tensors)
        _, y_tree_vecs, y_graph_vecs = self.encode(y_tensors)

        # Difference between target and source after summing over dim 1.
        diff_tree_vecs = y_tree_vecs.sum(dim=1) - x_tree_vecs.sum(dim=1)
        diff_graph_vecs = y_graph_vecs.sum(dim=1) - x_graph_vecs.sum(dim=1)
        diff_tree_vecs = self.U_tree( torch.cat([diff_tree_vecs, cond], dim=-1) ) #combine condition for posterior
        diff_graph_vecs = self.U_graph( torch.cat([diff_graph_vecs, cond], dim=-1) ) #combine condition for posterior

        # Sample the tree difference first, then the graph difference; each
        # call also returns its KL term.
        diff_tree_vecs, tree_kl = self.rsample(diff_tree_vecs, self.T_mean, self.T_var)
        diff_graph_vecs, graph_kl = self.rsample(diff_graph_vecs, self.G_mean, self.G_var)
        kl_div = tree_kl + graph_kl

        diff_tree_vecs = torch.cat([diff_tree_vecs, cond], dim=-1) #combine condition for posterior
        diff_graph_vecs = torch.cat([diff_graph_vecs, cond], dim=-1) #combine condition for posterior

        # Repeat the sampled difference along dim 1 so it can be concatenated
        # with every source tree/graph vector.
        diff_tree_vecs = diff_tree_vecs.unsqueeze(1).expand(-1, x_tree_vecs.size(1), -1)
        diff_graph_vecs = diff_graph_vecs.unsqueeze(1).expand(-1, x_graph_vecs.size(1), -1)
        x_tree_vecs = self.W_tree( torch.cat([x_tree_vecs, diff_tree_vecs], dim=-1) )
        x_graph_vecs = self.W_graph( torch.cat([x_graph_vecs, diff_graph_vecs], dim=-1) )

        # Decode from the modified source vectors against the target.
        loss, wacc, iacc, tacc, sacc = self.decoder((x_root_vecs, x_tree_vecs, x_graph_vecs), y_graphs, y_tensors, y_orders)
        return loss + beta * kl_div, kl_div.item(), wacc, iacc, tacc, sacc