Python torch.ones() Examples

The following are 30 code examples of torch.ones(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module torch, or try the search function.
Example #1
Source File:    From cvpr2018-hnd with MIT License 7 votes vote down vote up
def __init__(self, T, opts):
        """Configure the loss from the taxonomy dict ``T`` and the options ``opts``.

        Reads ``opts.gpu``, ``opts.loo``, ``opts.method``, ``opts.label_smooth``
        and ``opts.class_wise``, and the entries ``'wnids'``, ``'relevant'``,
        ``'labels_relevant'``, ``'ch_slice'`` and (only when ``opts.class_wise``
        is set) ``'num_children'`` of ``T``.
        """
        super(LOOLoss, self).__init__()
        self.gpu = opts.gpu
        # opts.loo is honoured only when the method name contains 'LOO'.
        self.loo = opts.loo if 'LOO' in opts.method else 0.
        self.label_smooth = opts.label_smooth
        self.kld_u_const = math.log(len(T['wnids']))
        self.relevant = [torch.from_numpy(rel) for rel in T['relevant']]
        self.labels_relevant = torch.from_numpy(T['labels_relevant'].astype(np.uint8))
        ch_slice = T['ch_slice']
        if opts.class_wise:
            # Each slice [ch_slice[m], ch_slice[m+1]) gets the weight
            # 1 / (num_children[m] * number of slices).
            num_children = T['num_children']
            num_supers = len(num_children)
            self.class_weight = torch.zeros(ch_slice[-1])
            for m, num_ch in enumerate(num_children):
                self.class_weight[ch_slice[m]:ch_slice[m+1]] = 1. / (num_ch * num_supers)
        else:
            # Uniform weights; without this branch the class-wise weights
            # computed above were always overwritten.
            self.class_weight = torch.ones(ch_slice[-1]) / ch_slice[-1]
Example #2
Source File:    From graph-neural-networks with GNU General Public License v3.0 6 votes vote down vote up
def forward(self, x):
        """Filter ``x`` with ``NVGF`` and return the first ``numberNodesIn`` nodes.

        ``x`` is of shape batchSize x dimInFeatures x numberNodesIn. When it has
        fewer nodes than ``self.N`` it is zero-padded along the node dimension
        before filtering, and the output is cut back to numberNodesIn nodes.
        """
        # x is of shape: batchSize x dimInFeatures x numberNodesIn
        B = x.shape[0]
        F = x.shape[1]
        Nin = x.shape[2]
        # If we have less filter coefficients than the required ones, we need
        # to use the copying scheme
        if self.M == self.N:
            self.h = self.weight
        else:
            self.h = torch.index_select(self.weight, 4, self.copyNodes)
        # And now we add the zero padding
        if Nin < self.N:
            zeroPad = torch.zeros(B, F, self.N - Nin, dtype=x.dtype, device=x.device)
            x = torch.cat((x, zeroPad), dim=2)
        # Compute the filter output
        u = NVGF(self.h, self.S, x, self.bias)
        # So far, u is of shape batchSize x dimOutFeatures x numberNodes
        # And we want to return a tensor of shape
        # batchSize x dimOutFeatures x numberNodesIn
        # since the nodes between numberNodesIn and numberNodes are not required
        if Nin < self.N:
            u = torch.index_select(u, 2, torch.arange(Nin, device=u.device))
        return u
Example #3
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def greedy_decode(self, latent, max_len, start_id):
        """Greedily decode token ids from ``latent``, one position per step.

        Starts every sequence with ``start_id``, runs ``max_len - 1`` decoding
        steps that each append the highest-scoring token, and returns the
        generated ids without the start column.

        Shape notes carried over from the original docstring:
        latent: (batch_size, max_src_seq, d_model)
        src_mask: (batch_size, 1, max_src_len)
        """
        batch_size = latent.size(0)
        ys = get_cuda(torch.ones(batch_size, 1).fill_(start_id).long())  # (batch_size, 1)
        for i in range(max_len - 1):
            # The decoder sees everything generated so far, masked with
            # subsequent_mask of the current length.
            out = self.decode(latent.unsqueeze(1), to_var(ys), to_var(subsequent_mask(ys.size(1)).long()))
            prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
            _, next_word = torch.max(prob, dim=1)  # (batch_size)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        return ys[:, 1:]
Example #4
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        """Create the layers required by the chosen ``correlation_func``.

        * 2 or 3: a bias-free linear map ``input_size -> hidden_size`` plus a
          ``diagonal`` parameter. With ``do_similarity`` it is a frozen scalar
          ``1 / sqrt(hidden_size)``; otherwise a trainable vector of ones of
          shape ``(1, 1, hidden_size)``.
        * 4: a bias-free linear map ``input_size -> input_size``.
        * 5: a bias-free linear map ``input_size -> hidden_size``.
        * any other value: no layers are created here.
        """
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size

        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
            if do_similarity:
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5), requires_grad=False)
            else:
                # Without this branch the trainable diagonal always replaced
                # the fixed scaling one.
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size), requires_grad=True)

        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)

        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
Example #5
Source File:    From mmdetection with Apache License 2.0 6 votes vote down vote up
def __init__(self,
                             torch.ones(self.out_channels, 1, 1, 1))
                             torch.zeros(self.out_channels, 1, 1, 1)) 
Example #6
Source File:    From transferlearning with MIT License 6 votes vote down vote up
def CORAL(source, target):
    """Return the covariance-alignment distance between two 2-D batches.

    Rows of ``source`` and ``target`` are samples and columns are features.
    The covariance matrix of each batch is computed, and the Frobenius norm of
    their difference is divided by ``4 * d * d``, where ``d`` is the number of
    columns of ``source``.

    The helper row of ones is created on the same device and with the same
    dtype as the batch it multiplies, so the function does not depend on a
    module-level device constant.

    NOTE(review): the norm is used as-is (not squared); confirm that this is
    the intended definition.
    """
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # source covariance
    tmp_s = torch.ones((1, ns), device=source.device, dtype=source.dtype) @ source
    cs = (source.t() @ source - (tmp_s.t() @ tmp_s) / ns) / (ns - 1)

    # target covariance
    tmp_t = torch.ones((1, nt), device=target.device, dtype=target.dtype) @ target
    ct = (target.t() @ target - (tmp_t.t() @ tmp_t) / nt) / (nt - 1)

    # frobenius norm
    loss = (cs - ct).pow(2).sum().sqrt()
    loss = loss / (4 * d * d)

    return loss
Example #7
Source File:    From Pytorch-Project-Template with MIT License 6 votes vote down vote up
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, condense_factor=None, dropout_rate=0.):
        """Build the batch-norm / ReLU / (optional dropout) / convolution stack.

        ``in_channels`` must be divisible by ``groups`` and by
        ``condense_factor``, and ``out_channels`` by ``groups``. The
        convolution itself is created with ``groups=1``; the ``_count``,
        ``_stage`` and ``_mask`` buffers are registered alongside it, with
        ``_mask`` shaped like the convolution weight.
        """
        # Submodules and buffers can only be registered after the base class
        # has been initialised.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        # The default condense_factor=None would make the modulo check below
        # raise TypeError. NOTE(review): falling back to `groups` follows the
        # CondenseNet reference implementation; confirm for this project.
        self.condense_factor = groups if condense_factor is None else condense_factor
        self.dropout_rate = dropout_rate

        # Check if given configs are valid
        assert self.in_channels % self.groups == 0, "group value is not divisible by input channels"
        assert self.in_channels % self.condense_factor == 0, "condensation factor is not divisible by input channels"
        assert self.out_channels % self.groups == 0, "group value is not divisible by output channels"

        self.batch_norm = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        if self.dropout_rate > 0:
            self.dropout = nn.Dropout(self.dropout_rate, inplace=False)
        self.conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, groups=1, bias=False)
        # register conv buffers
        self.register_buffer('_count', torch.zeros(1))
        self.register_buffer('_stage', torch.zeros(1))
        self.register_buffer('_mask', torch.ones(self.conv.weight.size()))
Example #8
Source File:    From torch-toolbox with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, num_features, eps=1e-5, momentum=0.9, affine=True):
        """Create the parameters and running statistics of the layer.

        With ``affine`` the ``weight`` and ``bias`` parameters are allocated
        with ``num_features`` entries each (uninitialised storage; they are not
        given values in this method). Otherwise both are registered as
        ``None``. ``mean_weight`` and ``var_weight`` are trainable vectors of
        three ones; ``running_mean`` / ``running_var`` are buffers of zeros /
        ones.
        """
        super(_SwitchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(num_features))
            self.bias = nn.Parameter(torch.Tensor(num_features))
        else:
            # Without this branch the parameters created above were
            # immediately replaced by None.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))

        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
Example #9
Source File:    From torch-toolbox with BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def __init__(self, prefix, num_features, eps=1e-5, momentum=0.9, groups=32,
             affine=True):
        """Create the parameters and the running-variance buffer of the layer.

        ``prefix`` must be ``'s0'`` or ``'b0'``. With ``affine`` the ``weight``,
        ``bias`` and ``v`` parameters are allocated with shape
        ``(1, num_features, 1, 1)`` (uninitialised storage; they are not given
        values in this method). Otherwise all three are registered as ``None``.

        NOTE(review): the tail of the signature was missing in the source; the
        body reads ``affine``, and its default of ``True`` is an assumption to
        confirm.
        """
        super(_EvoNorm, self).__init__()
        assert prefix in ('s0', 'b0')
        self.prefix = prefix
        self.groups = groups
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.bias = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
            self.v = nn.Parameter(torch.Tensor(1, num_features, 1, 1))
        else:
            # Without this branch the parameters created above were
            # immediately replaced by None.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
            self.register_parameter('v', None)
        self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
Example #10
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def greedy_decode(self, latent, max_len, start_id):
        """Greedily decode token ids from ``latent``, one position per step.

        Starts every sequence with ``start_id``, runs ``max_len - 1`` decoding
        steps that each append the highest-scoring token, and returns the
        generated ids without the start column.

        Shape notes carried over from the original docstring:
        latent: (batch_size, max_src_seq, d_model)
        src_mask: (batch_size, 1, max_src_len)
        """
        batch_size = latent.size(0)
        ys = get_cuda(torch.ones(batch_size, 1).fill_(start_id).long())  # (batch_size, 1)
        for i in range(max_len - 1):
            # The decoder sees everything generated so far, masked with
            # subsequent_mask of the current length.
            out = self.decode(latent.unsqueeze(1), to_var(ys), to_var(subsequent_mask(ys.size(1)).long()))
            prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
            _, next_word = torch.max(prob, dim=1)  # (batch_size)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        return ys[:, 1:]
Example #11
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        """Create the layers required by the chosen ``correlation_func``.

        * 2 or 3: a bias-free linear map ``input_size -> hidden_size`` plus a
          ``diagonal`` parameter. With ``do_similarity`` it is a frozen scalar
          ``1 / sqrt(hidden_size)``; otherwise a trainable vector of ones of
          shape ``(1, 1, hidden_size)``.
        * 4: a bias-free linear map ``input_size -> input_size``.
        * 5: a bias-free linear map ``input_size -> hidden_size``.
        * any other value: no layers are created here.
        """
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size

        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
            if do_similarity:
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5), requires_grad=False)
            else:
                # Without this branch the trainable diagonal always replaced
                # the fixed scaling one.
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size), requires_grad=True)

        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)

        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
Example #12
Source File:    From deep-learning-note with MIT License 6 votes vote down vote up
def batch_loss(encoder, decoder, X, Y, loss):
    """Return the teacher-forced loss of one batch, averaged over counted tokens.

    The encoder is run on ``X``; the decoder is then stepped once per column
    of ``Y`` (``Y`` has shape (batch, seq_len)), always receiving the previous
    label as input. Per-position losses are multiplied by a mask that becomes
    zero for a sequence once one of its labels equals PAD, summed, and divided
    by the accumulated mask total.
    """
    n_samples = X.shape[0]
    enc_outputs, enc_state = encoder(X, encoder.begin_state())
    # The decoder's hidden state is initialised from the state returned by the encoder.
    dec_state = decoder.begin_state(enc_state)
    # At the first time step the decoder input is BOS for every sequence.
    dec_input = torch.tensor([out_vocab.stoi[BOS]] * n_samples)
    # The mask is used to ignore the loss of padded (PAD) labels.
    mask = torch.ones(n_samples,)
    num_not_pad_tokens = 0
    total = torch.tensor([0.0])
    for y in Y.permute(1, 0):  # iterate over time steps
        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)
        total = total + torch.sum(mask * loss(dec_output, y))
        dec_input = y  # teacher forcing
        num_not_pad_tokens += mask.sum().item()
        # Zero the mask where the label is PAD. The text this was taken from
        # compared against EOS here, which looked wrong to the author.
        mask = mask * (y != out_vocab.stoi[PAD]).float()
    return total / num_not_pad_tokens
Example #13
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def greedy_decode(self, latent, max_len, start_id):
        """Greedily decode token ids from ``latent``, one position per step.

        Starts every sequence with ``start_id``, runs ``max_len - 1`` decoding
        steps that each append the highest-scoring token, and returns the
        generated ids without the start column.

        Shape notes carried over from the original docstring:
        latent: (batch_size, max_src_seq, d_model)
        src_mask: (batch_size, 1, max_src_len)
        """
        batch_size = latent.size(0)
        ys = get_cuda(torch.ones(batch_size, 1).fill_(start_id).long())  # (batch_size, 1)
        for i in range(max_len - 1):
            # The decoder sees everything generated so far, masked with
            # subsequent_mask of the current length.
            out = self.decode(latent.unsqueeze(1), to_var(ys), to_var(subsequent_mask(ys.size(1)).long()))
            prob = self.generator(out[:, -1])  # (batch_size, vocab_size)
            _, next_word = torch.max(prob, dim=1)  # (batch_size)
            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
        return ys[:, 1:]
Example #14
Source File:    From audio with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _feature_window_function(window_type: str,
                             window_size: int,
                             blackman_coeff: float,
                             device: torch.device,
                             dtype: int,
                             ) -> Tensor:
    r"""Returns a window function with the given type and size

    ``window_type`` is compared against the constants HANNING, HAMMING, POVEY,
    RECTANGULAR and BLACKMAN; any other value raises ``Exception``.
    ``blackman_coeff`` is only used for the BLACKMAN window. The result is a
    1-D tensor of ``window_size`` samples on ``device`` with ``dtype``.
    """
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
    elif window_type == POVEY:
        # like hanning but goes to zero at edges
        return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, device=device, dtype=dtype)
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = torch.arange(window_size, device=device, dtype=dtype)
        # can't use torch.blackman_window as they use different coefficients
        return (blackman_coeff - 0.5 * torch.cos(a * window_function) +
                (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)).to(device=device, dtype=dtype)
    else:
        # Previously unreachable: an unknown type silently returned None.
        raise Exception('Invalid window type ' + window_type)
Example #15
Source File:    From controllable-text-attribute-transfer with Apache License 2.0 6 votes vote down vote up
def __init__(self, input_size, hidden_size, correlation_func=1, do_similarity=False):
        """Create the layers required by the chosen ``correlation_func``.

        * 2 or 3: a bias-free linear map ``input_size -> hidden_size`` plus a
          ``diagonal`` parameter. With ``do_similarity`` it is a frozen scalar
          ``1 / sqrt(hidden_size)``; otherwise a trainable vector of ones of
          shape ``(1, 1, hidden_size)``.
        * 4: a bias-free linear map ``input_size -> input_size``.
        * 5: a bias-free linear map ``input_size -> hidden_size``.
        * any other value: no layers are created here.
        """
        super(AttentionScore, self).__init__()
        self.correlation_func = correlation_func
        self.hidden_size = hidden_size

        if correlation_func == 2 or correlation_func == 3:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
            if do_similarity:
                self.diagonal = Parameter(torch.ones(1, 1, 1) / (hidden_size ** 0.5), requires_grad=False)
            else:
                # Without this branch the trainable diagonal always replaced
                # the fixed scaling one.
                self.diagonal = Parameter(torch.ones(1, 1, hidden_size), requires_grad=True)

        if correlation_func == 4:
            self.linear = nn.Linear(input_size, input_size, bias=False)

        if correlation_func == 5:
            self.linear = nn.Linear(input_size, hidden_size, bias=False)
Example #16
Source File:    From audio with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _fade_in(self, waveform_length: int) -> Tensor:
        """Return a gain curve of ``waveform_length`` samples that ramps up.

        The first ``self.fade_in_len`` samples follow the curve selected by
        ``self.fade_shape``, built from a linear ramp from 0 to 1; the
        remaining samples are 1. The result is clamped to [0, 1]. An
        unrecognised ``fade_shape`` leaves the linear ramp unchanged.
        """
        fade = torch.linspace(0, 1, self.fade_in_len)
        ones = torch.ones(waveform_length - self.fade_in_len)

        if self.fade_shape == "linear":
            fade = fade

        if self.fade_shape == "exponential":
            fade = torch.pow(2, (fade - 1)) * fade

        if self.fade_shape == "logarithmic":
            fade = torch.log10(.1 + fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi - math.pi / 2) / 2 + 0.5

        # Ramp first, then the constant part.
        return torch.cat((fade, ones)).clamp_(0, 1)
Example #17
Source File:    From audio with BSD 2-Clause "Simplified" License 6 votes vote down vote up
def _fade_out(self, waveform_length: int) -> Tensor:
        """Return a gain curve of ``waveform_length`` samples that ramps down.

        The leading samples are 1; the last ``self.fade_out_len`` samples
        follow the curve selected by ``self.fade_shape``, built from a linear
        ramp from 0 to 1. The result is clamped to [0, 1]. An unrecognised
        ``fade_shape`` leaves the rising ramp untransformed.
        """
        fade = torch.linspace(0, 1, self.fade_out_len)
        ones = torch.ones(waveform_length - self.fade_out_len)

        if self.fade_shape == "linear":
            fade = - fade + 1

        if self.fade_shape == "exponential":
            fade = torch.pow(2, - fade) * (1 - fade)

        if self.fade_shape == "logarithmic":
            fade = torch.log10(1.1 - fade) + 1

        if self.fade_shape == "quarter_sine":
            fade = torch.sin(fade * math.pi / 2 + math.pi / 2)

        if self.fade_shape == "half_sine":
            fade = torch.sin(fade * math.pi + math.pi / 2) / 2 + 0.5

        # Constant part first, then the ramp.
        return torch.cat((ones, fade)).clamp_(0, 1)
Example #18
Source File:    From fast-MPN-COV with MIT License 5 votes vote down vote up
def forward(ctx, input):
         """Return the per-sample covariance of the flattened spatial positions.

         ``input`` is indexed as a 4-D tensor (batch, channels, height, width).
         Its last two dimensions are flattened into ``M = h * w`` positions and
         ``y = x @ I_hat @ x^T`` is returned with shape (batch, channels,
         channels), where ``I_hat = I / M - ones / M**2``.

         NOTE(review): ``ctx`` is not used here; if a matching backward pass
         needs saved tensors, confirm where they are stored.
         """
         x = input
         batchSize, dim, h, w = x.shape
         M = h*w
         x = x.reshape(batchSize,dim,M)
         I_hat = (-1./M/M)*torch.ones(M,M,device = x.device) + (1./M)*torch.eye(M,M,device = x.device)
         # One copy of I_hat per sample, in the dtype of the input.
         I_hat = I_hat.view(1,M,M).repeat(batchSize,1,1).type(x.dtype)
         y = x.bmm(I_hat).bmm(x.transpose(1,2))
         return y
Example #19
Source File:    From PolarSeg with BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def grp_range_torch(a,dev):
    """Return, for group sizes ``a``, the concatenation of ``0..size-1`` per group.

    A vector of unit steps is built on ``dev``; at the start of every group
    after the first the step is replaced by ``1 - previous group size``, so the
    running sum restarts at zero there.
    """
    group_ends = a.cumsum(0)
    steps = torch.ones(group_ends[-1], dtype=torch.int64, device=dev)
    steps[0] = 0
    # The start of each later group is the end offset of the group before it.
    steps[group_ends[:-1]] = 1 - a[:-1]
    return steps.cumsum(0)
Example #20
Source File:    From fast-MPN-COV with MIT License 5 votes vote down vote up
def forward(ctx, input):
         """Return the upper-triangular entries of each ``dim x dim`` matrix.

         ``input`` is indexed as (batch, dim, dim, ...). Each matrix is
         flattened row by row and the positions on or above the diagonal are
         gathered. Because the index produced by ``nonzero()`` has shape
         (K, 1), the result has shape (batch, K, 1) with
         ``K = dim * (dim + 1) / 2``.

         NOTE(review): ``ctx`` is not used here; if a matching backward pass
         needs saved tensors, confirm where they are stored.
         """
         x = input
         batchSize = x.shape[0]
         dim = x.shape[1]
         x = x.reshape(batchSize, dim*dim)
         # Flat positions of the upper triangle (diagonal included).
         I = torch.ones(dim,dim).triu().reshape(dim*dim)
         index = I.nonzero()
         y = x[:,index]
         return y
Example #21
Source File:    From deep-learning-note with MIT License 5 votes vote down vote up
def __init__(self, num_features, num_dims):
        """Create scale/shift parameters and moving statistics.

        ``num_dims == 2`` selects the shape ``(1, num_features)``; any other
        value selects ``(1, num_features, 1, 1)``. ``gamma`` (ones) and
        ``beta`` (zeros) are trainable parameters; ``moving_mean`` and
        ``moving_var`` are plain tensor attributes initialised to zeros, not
        registered buffers.
        """
        super(BatchNorm, self).__init__()
        if num_dims == 2:
            # fully connected layer
            shape = (1, num_features)
        else:
            # convolutional layer
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
Example #22
Source File:    From ACAN with MIT License 5 votes vote down vote up
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics as.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation functions, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            # Without this branch the parameters created above were
            # immediately replaced by None.
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
Example #23
Source File:    From mmdetection with Apache License 2.0 5 votes vote down vote up
def convert_bn(blobs, state_dict, caffe_name, torch_name, converted_names):
    """Copy one affine-channel layer from ``blobs`` into ``state_dict``.

    ``blobs[caffe_name + '_b']`` becomes ``torch_name + '.bias'`` and
    ``blobs[caffe_name + '_s']`` becomes ``torch_name + '.weight'``. The
    running mean / variance entries are filled with zeros / ones of the same
    size as the weight. Both blob names are added to ``converted_names``.
    ``state_dict`` and ``converted_names`` are modified in place; nothing is
    returned.
    """
    # detectron replace bn with affine channel layer
    bias_key = caffe_name + '_b'
    scale_key = caffe_name + '_s'
    state_dict[torch_name + '.bias'] = torch.from_numpy(blobs[bias_key])
    state_dict[torch_name + '.weight'] = torch.from_numpy(blobs[scale_key])
    bn_size = state_dict[torch_name + '.weight'].size()
    state_dict[torch_name + '.running_mean'] = torch.zeros(bn_size)
    state_dict[torch_name + '.running_var'] = torch.ones(bn_size)
    converted_names.add(bias_key)
    converted_names.add(scale_key)
Example #24
Source File:    From mmdetection with Apache License 2.0 5 votes vote down vote up
def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
        """Run forward function and calculate loss for mask head in
        if not self.share_roi_extractor:
            pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
            if pos_rois.shape[0] == 0:
                return dict(loss_mask=None)
            mask_results = self._mask_forward(x, pos_rois)
            pos_inds = []
            device = bbox_feats.device
            for res in sampling_results:
            pos_inds =
            if pos_inds.shape[0] == 0:
                return dict(loss_mask=None)
            mask_results = self._mask_forward(
                x, pos_inds=pos_inds, bbox_feats=bbox_feats)

        mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
        pos_labels =[res.pos_gt_labels for res in sampling_results])
        loss_mask = self.mask_head.loss(mask_results['mask_pred'],
                                        mask_targets, pos_labels)

        mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
        return mask_results 
Example #25
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def forward(self, fnode, fmess, agraph, bgraph):
        """Encode nodes from their own features and their incoming messages.

        ``self.rnn`` is run on ``fmess`` / ``bgraph`` and reduced to its hidden
        state ``h``. The rows selected by ``agraph`` via ``index_select_ND``
        are summed over dim 1, concatenated with ``fnode`` along dim 1 and
        projected by ``self.W_o``. Row 0 of the result is zeroed.

        Returns the pair ``(masked node hiddens, h)``.
        """
        h = self.rnn(fmess, bgraph)
        h = self.rnn.get_hidden_state(h)
        # presumably gathers, per node, the hidden states of its neighbor
        # messages; verify against index_select_ND
        nei_message = index_select_ND(h, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        node_hiddens = torch.cat([fnode, nei_message], dim=1)
        node_hiddens = self.W_o(node_hiddens)

        mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device)
        mask[0, 0] = 0 #first node is padding
        return node_hiddens * mask, h #return only the hidden state (different from IncMPNEncoder in LSTM case)
Example #26
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def forward(self, fmess, bgraph):
        """Run ``self.depth`` rounds of LSTM message updates.

        Hidden and cell states start at zero with one row per row of
        ``fmess``. In every round the states selected by ``bgraph`` through
        ``index_select_ND`` are fed to ``self.LSTM`` together with ``fmess``,
        and row 0 of both results is reset to zero.

        Returns the final ``(h, c)`` pair.
        """
        num_mess = fmess.size(0)
        dev = fmess.device
        h = torch.zeros(num_mess, self.hidden_size, device=dev)
        c = torch.zeros(num_mess, self.hidden_size, device=dev)
        # first message is padding: its state is multiplied by zero each round
        mask = torch.ones(num_mess, 1, device=dev)
        mask[0, 0] = 0

        for _ in range(self.depth):
            h, c = self.LSTM(
                fmess,
                index_select_ND(h, 0, bgraph),
                index_select_ND(c, 0, bgraph),
            )
            h, c = h * mask, c * mask
        return h, c
Example #27
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def index_scatter(sub_data, all_data, index):
    """Return ``all_data`` with the rows listed in ``index`` replaced by ``sub_data``.

    ``all_data`` must be 2-D. Row ``index[i]`` of the result is
    ``sub_data[i]``; every other row is taken from ``all_data``. The inputs
    are not modified.
    """
    num_rows, num_cols = all_data.size()

    # Rows of sub_data written at their target positions, zeros elsewhere.
    row_index = index.repeat(num_cols, 1).t()
    scattered = torch.zeros_like(all_data)
    scattered.scatter_(0, row_index, sub_data)

    # 1 for rows that are kept, 0 for rows that are replaced.
    keep = torch.ones(num_rows, device=all_data.device)
    keep.scatter_(0, index, 0)

    return all_data * keep.unsqueeze(-1) + scattered
Example #28
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def forward(self, fnode, fmess, agraph, bgraph):
        """Encode nodes from their own features and their incoming messages.

        ``self.rnn`` is run on ``fmess`` / ``bgraph`` and reduced to its hidden
        state ``h``. The rows selected by ``agraph`` via ``index_select_ND``
        are summed over dim 1, concatenated with ``fnode`` along dim 1 and
        projected by ``self.W_o``. Row 0 of the result is zeroed.

        Returns the pair ``(masked node hiddens, h)``.
        """
        h = self.rnn(fmess, bgraph)
        h = self.rnn.get_hidden_state(h)
        # presumably gathers, per node, the hidden states of its neighbor
        # messages; verify against index_select_ND
        nei_message = index_select_ND(h, 0, agraph)
        nei_message = nei_message.sum(dim=1)
        node_hiddens = torch.cat([fnode, nei_message], dim=1)
        node_hiddens = self.W_o(node_hiddens)

        mask = torch.ones(node_hiddens.size(0), 1, device=fnode.device)
        mask[0, 0] = 0 #first node is padding
        return node_hiddens * mask, h #return only the hidden state (different from IncMPNEncoder in LSTM case)
Example #29
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def forward(self, fmess, bgraph):
        """Run ``self.depth`` rounds of LSTM message updates.

        Hidden and cell states start at zero with one row per row of
        ``fmess``. In every round the states selected by ``bgraph`` through
        ``index_select_ND`` are fed to ``self.LSTM`` together with ``fmess``,
        and row 0 of both results is reset to zero.

        Returns the final ``(h, c)`` pair.
        """
        num_mess = fmess.size(0)
        dev = fmess.device
        h = torch.zeros(num_mess, self.hidden_size, device=dev)
        c = torch.zeros(num_mess, self.hidden_size, device=dev)
        # first message is padding: its state is multiplied by zero each round
        mask = torch.ones(num_mess, 1, device=dev)
        mask[0, 0] = 0

        for _ in range(self.depth):
            h, c = self.LSTM(
                fmess,
                index_select_ND(h, 0, bgraph),
                index_select_ND(c, 0, bgraph),
            )
            h, c = h * mask, c * mask
        return h, c
Example #30
Source File:    From hgraph2graph with MIT License 5 votes vote down vote up
def forward(self, fmess, bgraph):
        """Run ``self.depth`` rounds of GRU message updates.

        The hidden state starts at zero with one row per row of ``fmess``. In
        every round the states selected by ``bgraph`` through
        ``index_select_ND`` are fed to ``self.GRU`` together with ``fmess``,
        and row 0 of the result is reset to zero.

        Returns the final hidden state.
        """
        num_mess = fmess.size(0)
        dev = fmess.device
        h = torch.zeros(num_mess, self.hidden_size, device=dev)
        # first message is padding: its state is multiplied by zero each round
        mask = torch.ones(num_mess, 1, device=dev)
        mask[0, 0] = 0

        for _ in range(self.depth):
            h = self.GRU(fmess, index_select_ND(h, 0, bgraph)) * mask
        return h