Python torch.chunk() Examples

The following are 30 code examples of torch.chunk(), drawn from open-source projects. You can view the original project or source file by following the link above each example. You may also want to check out all available functions and classes of the torch module.
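Before the project examples, here is a minimal sketch of the basic call, with shapes chosen purely for illustration: torch.chunk(input, chunks, dim) splits a tensor into up to chunks views along dim, and may return fewer chunks than requested when that dimension does not divide evenly.

import torch

t = torch.arange(12).view(3, 4)
rows = torch.chunk(t, 3, dim=0)    # 3 chunks, each of shape [1, 4]
cols = torch.chunk(t, 3, dim=1)    # only 2 chunks: each gets ceil(4 / 3) == 2 columns,
                                   # so fewer than the requested 3 are returned
print(len(rows), rows[0].shape)    # 3 torch.Size([1, 4])
print(len(cols), cols[0].shape)    # 2 torch.Size([3, 2])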
Example #1
Source File: nnet.py    From pase with MIT License
def forward(self, x):
        h = self.frontend(x)
        if not self.ft_fe:
            h = h.detach()
        if hasattr(self, 'z_bnorm'):
            h = self.z_bnorm(h)
        ht, state = self.rnn(h.transpose(1, 2))
        if self.return_sequence:
            ht = ht.transpose(1, 2)
        else:
            if not self.uni:
                # pick last time-step for each dir
                # first chunk feat dim
                bsz, slen, feats = ht.size()
                ht = torch.chunk(ht.view(bsz, slen, 2, feats // 2), 2, dim=2)
                # now select last fwd step and first bwd step
                ht_fwd = ht[0][:, -1, 0, :].unsqueeze(2)
                ht_bwd = ht[1][:, 0, 0, :].unsqueeze(2)
                ht = torch.cat((ht_fwd, ht_bwd), dim=1)
            else:
                # just last time-step works
                ht = ht[:, -1, :].unsqueeze(2)
        y = self.model(ht)
        return y 
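The reshape-then-chunk idiom above separates the forward and backward halves of a bidirectional RNN output. A standalone sketch with dummy sizes (the names below are illustrative, not part of pase):

import torch

bsz, slen, hidden = 4, 10, 16              # hypothetical sizes
ht = torch.randn(bsz, slen, 2 * hidden)    # bidirectional output: fwd and bwd concatenated
fwd, bwd = torch.chunk(ht.view(bsz, slen, 2, hidden), 2, dim=2)
last_fwd = fwd[:, -1, 0, :]                # last time step of the forward direction
first_bwd = bwd[:, 0, 0, :]                # first time step of the backward direction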
Example #2
Source File: modules.py    From pase with MIT License
def format_frontend_chunk(batch, device='cpu'):
    if type(batch) == dict:
        if 'chunk_ctxt' in batch and 'chunk_rand' in batch:
            keys = ['chunk', 'chunk_ctxt', 'chunk_rand', 'cchunk']
            # cluster all 'chunk's, including possible 'cchunk'
            batches = [batch[k] for k in keys if k in batch]
            x = torch.cat(batches, dim=0).to(device)
            # record how many tensors were concatenated, as the data format
            data_fmt = len(batches)
        else:
            x = batch['chunk'].to(device)
            data_fmt = 1
    else:
        x = batch
        data_fmt = 0
    return x, data_fmt 
Example #3
Source File: revnet.py    From imgclsmob with MIT License
def forward(ctx, x, fm, gm, *params):

        with torch.no_grad():
            x1, x2 = torch.chunk(x, chunks=2, dim=1)
            x1 = x1.contiguous()
            x2 = x2.contiguous()

            y1 = x1 + fm(x2)
            y2 = x2 + gm(y1)

            y = torch.cat((y1, y2), dim=1)

            x1.set_()
            x2.set_()
            y1.set_()
            y2.set_()
            del x1, x2, y1, y2

        ctx.save_for_backward(x, y)
        ctx.fm = fm
        ctx.gm = gm

        return y 
Example #4
Source File: shufflenetv2b.py    From imgclsmob with MIT License
def forward(self, x):
        if self.downsample:
            y1 = self.shortcut_dconv(x)
            y1 = self.shortcut_conv(y1)
            x2 = x
        else:
            y1, x2 = torch.chunk(x, chunks=2, dim=1)
        y2 = self.conv1(x2)
        y2 = self.dconv(y2)
        y2 = self.conv2(y2)
        if self.use_se:
            y2 = self.se(y2)
        if self.use_residual and not self.downsample:
            y2 = y2 + x2
        x = torch.cat((y1, y2), dim=1)
        x = self.c_shuffle(x)
        return x 
Example #5
Source File: esim.py    From video_captioning_rl with MIT License
def seq2seq_cross_entropy(logits, label, l, chuck=None, sos_truncate=True):
    """
    :param logits: [exB, V] : exB = sum(l)
    :param label: [B, T] : a batch of labels
    :param l: [B] : a LongTensor holding the length of each input
    :param chuck: number of chunks to split the logits into when computing the loss
    :return: A loss value
    """
    packed_label = pack_sequence_for_linear(label, l)
    cross_entropy_loss = functools.partial(F.cross_entropy, size_average=False)
    total = sum(l)

    assert total == logits.size(0) or packed_label.size(0) == logits.size(0),\
        "logits length mismatch with label length."

    if chuck:
        logits_losses = 0
        for x, y in zip(torch.chunk(logits, chuck, dim=0), torch.chunk(packed_label, chuck, dim=0)):
            logits_losses += cross_entropy_loss(x, y)
        return logits_losses * (1 / total)
    else:
        return cross_entropy_loss(logits, packed_label) * (1 / total) 
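A self-contained sketch of the chunking idea above: summing the cross-entropy over pieces bounds peak memory while giving the same total. The tensors below are dummies, not the esim data path, and reduction='sum' replaces the deprecated size_average=False:

import torch
import torch.nn.functional as F

logits = torch.randn(10, 7)            # [exB, V]
labels = torch.randint(0, 7, (10,))    # [exB], already packed
pieces = 4
loss = sum(F.cross_entropy(x, y, reduction='sum')
           for x, y in zip(torch.chunk(logits, pieces, dim=0),
                           torch.chunk(labels, pieces, dim=0)))
loss = loss / logits.size(0)           # normalize by the total token count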
Example #6
Source File: perceptual_loss.py    From PerceptualGAN with GNU General Public License v3.0
def forward(self, input):

        input = input.clone()
        input = self.preprocess_range(input)

        if self.preprocessing_type == 'caffe':

            r, g, b = torch.chunk(input, 3, dim=1)
            bgr = torch.cat([b, g, r], 1)
            input = bgr * 255 - self.vgg_mean

        elif self.preprocessing_type == 'pytorch':

            input = input - self.vgg_mean
            input = input / self.vgg_std

        output = input
        outputs = []
        
        for block in self.blocks:
            output = block(output)
            outputs.append(output)

        return outputs 
Example #7
Source File: esim.py    From video_captioning_rl with MIT License
def pack_sequence_for_linear(inputs, lengths, batch_first=True):
    """
    :param inputs: [B, T, D] if batch_first 
    :param lengths:  [B]
    :param batch_first:  
    :return: 
    """
    batch_list = []
    if batch_first:
        for i, l in enumerate(lengths):
            # print(inputs[i, :l].size())
            batch_list.append(inputs[i, :l])
        packed_sequence = torch.cat(batch_list, 0)
        # if chuck:
        #     return list(torch.chunk(packed_sequence, chuck, dim=0))
        # else:
        return packed_sequence
    else:
        raise NotImplementedError()
Example #8
Source File: grassdata.py    From grass_pytorch with Apache License 2.0
def __init__(self, dir, transform=None):
        self.dir = dir
        box_data = torch.from_numpy(loadmat(self.dir+'/box_data.mat')['boxes']).float()
        op_data = torch.from_numpy(loadmat(self.dir+'/op_data.mat')['ops']).int()
        sym_data = torch.from_numpy(loadmat(self.dir+'/sym_data.mat')['syms']).float()
        #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float()
        num_examples = op_data.size()[1]
        box_data = torch.chunk(box_data, num_examples, 1)
        op_data = torch.chunk(op_data, num_examples, 1)
        sym_data = torch.chunk(sym_data, num_examples, 1)
        #weight_list = torch.chunk(weight_list, num_examples, 1)
        self.transform = transform
        self.trees = []
        for i in range(len(op_data)):
            boxes = torch.t(box_data[i])
            ops = torch.t(op_data[i])
            syms = torch.t(sym_data[i])
            tree = Tree(boxes, ops, syms)
            self.trees.append(tree) 
Example #9
Source File: classifiers.py    From pase with MIT License
def forward(self, x):
        # input x with shape [B, F, T]
        # FORWARD THROUGH DRN
        # ----------------------------
        if self.frontend is not None:
            x = self.frontend(x)
            if not self.ft_fe:
                x = x.detach()
        x = F.pad(x, (4, 5))
        x = self.drn(x)
        # FORWARD THROUGH RNN
        # ----------------------------
        x = x.transpose(1, 2)
        x, _ = self.rnn(x)
        xt = torch.chunk(x, x.shape[1], dim=1)
        x = xt[-1].transpose(1, 2)
        # FORWARD THROUGH DNN
        # ----------------------------
        x = self.mlp(x)
        return x 
Example #10
Source File: grassdata.py    From grass_pytorch with Apache License 2.0
def __init__(self, dir, transform=None):
        self.dir = dir
        box_data = torch.from_numpy(loadmat(self.dir+'/box_data.mat')['boxes']).float()
        op_data = torch.from_numpy(loadmat(self.dir+'/op_data.mat')['ops']).int()
        sym_data = torch.from_numpy(loadmat(self.dir+'/sym_data.mat')['syms']).float()
        #weight_list = torch.from_numpy(loadmat(self.dir+'/weights.mat')['weights']).float()
        num_examples = op_data.size()[1]
        box_data = torch.chunk(box_data, num_examples, 1)
        op_data = torch.chunk(op_data, num_examples, 1)
        sym_data = torch.chunk(sym_data, num_examples, 1)
        #weight_list = torch.chunk(weight_list, num_examples, 1)
        self.transform = transform
        self.trees = []
        for i in range(len(op_data)):
            boxes = torch.t(box_data[i])
            ops = torch.t(op_data[i])
            syms = torch.t(sym_data[i])
            tree = Tree(boxes, ops, syms)
            self.trees.append(tree) 
Example #11
Source File: frontend.py    From pase with MIT License
def forward(self, batch, device=None, mode=None):

        # batch possible chunk and contexts, or just forward non-dict tensor
        x, data_fmt = format_frontend_chunk(batch, device)

        sinc_out = self.sinc(x).unsqueeze(1)

        # print(sinc_out.shape)

        conv_out = self.conv1(sinc_out)

        # print(conv_out.shape)

        res_out = self.resnet(conv_out)

        # print(res_out.shape)

        h = self.conv2(res_out).squeeze(2)

        # print(h.shape)

        return format_frontend_output(h, data_fmt, mode) 
Example #12
Source File: treelstm.py    From treenet with GNU General Public License v3.0
def forward(self, inputs, children, arities):

        i = self.wi_net(inputs)
        o = self.wo_net(inputs)
        u = self.wu_net(inputs)

        f_base = self.wf_net(inputs)
        fc_sum = inputs.new_zeros(self.memory_size)
        for k, child in enumerate(children):
            child_h, child_c = torch.chunk(child, 2, dim=1)
            i.add_(self.ui_nets[k](child_h))
            o.add_(self.uo_nets[k](child_h))
            u.add_(self.uu_nets[k](child_h))

            f = f_base
            for l, other_child in enumerate(children):
                other_child_h, _ = torch.chunk(other_child, 2, dim=1)
                f = f.add(self.uf_nets[k][l](other_child_h))
            fc_sum = fc_sum + torch.sigmoid(f) * child_c  # accumulate (plain .add() discarded the result)

        c = torch.sigmoid(i) * torch.tanh(u) + fc_sum
        h = torch.sigmoid(o) * torch.tanh(c)
        return torch.cat([h, c], dim=1) 
Example #13
Source File: train.py    From dgl with Apache License 2.0
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """ One step of training. """
    deg_g = deg_g.to(dev)
    deg_lg = deg_lg.to(dev)
    pm_pd = pm_pd.to(dev)
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0

    z_list = th.chunk(z, args.batch_size, 0)
    loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()

    return loss, overlap, t_forward, t_backward 
Example #14
Source File: blow.py    From blow with Apache License 2.0
def forward(self,h,emb):
        sbatch,nsq,lchunk=h.size()
        h=h.contiguous()
        """
        # Slower version
        ws=list(self.adapt_w(emb).view(sbatch,self.ncha,1,self.kw))
        bs=list(self.adapt_b(emb))
        hs=list(torch.chunk(h,sbatch,dim=0))
        out=[]
        for hi,wi,bi in zip(hs,ws,bs):
            out.append(torch.nn.functional.conv1d(hi,wi,bias=bi,padding=self.kw//2,groups=nsq))
        h=torch.cat(out,dim=0)
        """
        # Faster version fully using group convolution
        w=self.adapt_w(emb).view(-1,1,self.kw)
        b=self.adapt_b(emb).view(-1)
        h=torch.nn.functional.conv1d(h.view(1,-1,lchunk),w,bias=b,padding=self.kw//2,groups=sbatch*nsq).view(sbatch,self.ncha,lchunk)
        #"""
        h=self.net.forward(h)
        s,m=torch.chunk(h,2,dim=1)
        s=torch.sigmoid(s+2)+1e-7
        return s,m

Example #15
Source File: simple.py    From kge with MIT License
def score_emb(self, s_emb, p_emb, o_emb, combine: str):
        n = p_emb.size(0)

        # split left/right
        s_emb_h, s_emb_t = torch.chunk(s_emb, 2, dim=1)
        p_emb_forward, p_emb_backward = torch.chunk(p_emb, 2, dim=1)
        o_emb_h, o_emb_t = torch.chunk(o_emb, 2, dim=1)

        if combine == "spo":
            out1 = (s_emb_h * p_emb_forward * o_emb_t).sum(dim=1)
            out2 = (s_emb_t * p_emb_backward * o_emb_h).sum(dim=1)
        elif combine == "sp_":
            out1 = (s_emb_h * p_emb_forward).mm(o_emb_t.transpose(0, 1))
            out2 = (s_emb_t * p_emb_backward).mm(o_emb_h.transpose(0, 1))
        elif combine == "_po":
            out1 = (o_emb_t * p_emb_forward).mm(s_emb_h.transpose(0, 1))
            out2 = (o_emb_h * p_emb_backward).mm(s_emb_t.transpose(0, 1))
        else:
            return super().score_emb(s_emb, p_emb, o_emb, combine)

        return (out1 + out2).view(n, -1) / 2.0 
Example #16
Source File: transformer.py    From naru with Apache License 2.0
def forward(self, x, query_input=None):
        """x: [bs, num cols, d_model].  Output has the same shape."""
        assert x.dim() == 3, x.size()
        bs, ncols, _ = x.size()

        # [bs, num cols, d_state * 3 * num_heads]
        qkv = self.qkv_linear(x)
        # [bs, num heads, num cols, d_state] each
        qs, ks, vs = map(self._split_heads, torch.chunk(qkv, 3, dim=-1))

        if query_input is not None:
            # TODO: obviously can avoid redundant calc.
            qkv = self.qkv_linear(query_input)
            qs, _, _ = map(self._split_heads, torch.chunk(qkv, 3, dim=-1))

        # [bs, num heads, num cols, d_state]
        x = self._do_attention(qs, ks, vs, mask=self.attn_mask.to(x.device))

        # [bs, num cols, num heads, d_state]
        x = x.transpose(1, 2)
        # Concat all heads' outputs: [bs, num cols, num heads * d_state]
        x = x.contiguous().view(bs, ncols, -1)
        # Then do a transform: [bs, num cols, d_model].
        x = self.linear(x)
        return x 
Example #17
Source File: modulated_deform_conv.py    From mmcv with Apache License 2.0
def forward(self, x):
        out = self.conv_offset(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups) 
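The pattern above recurs in several deformable-convolution examples below: a single convolution predicts x-offsets, y-offsets, and a modulation mask, which torch.chunk separates along the channel dimension. A shape-only sketch with made-up sizes:

import torch

deform_groups, kh, kw = 1, 3, 3
out = torch.randn(2, deform_groups * 3 * kh * kw, 8, 8)    # dummy conv_offset output
o1, o2, mask = torch.chunk(out, 3, dim=1)                  # each [2, 9, 8, 8]
offset = torch.cat((o1, o2), dim=1)                        # x- and y-offsets, [2, 18, 8, 8]
mask = torch.sigmoid(mask)                                 # modulation weights in (0, 1)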
Example #18
Source File: segan.py    From tfm-franroldan-wav2pix with GNU General Public License v3.0
def forward(self, x):
        #print("Input: {}".format(x.data.shape))
        h = x
        # store intermediate activations
        int_act = {}
        for ii, layer in enumerate(self.disc):
            #print(ii)
            h, _ = layer(h)
            #print("After layer: {}".format(h.data.shape))
            int_act['h_{}'.format(ii)] = h
        if self.pool_type == 'rnn':
            if hasattr(self, 'ln'):
                h = self.ln(h)
                int_act['ln_conv'] = h
            ht, state = self.rnn(h.transpose(1, 2))
            h = state[0]
            # concat both states (fwd, bwd)
            hfwd, hbwd = torch.chunk(h, 2, 0)
            h = torch.cat((hfwd, hbwd), dim=2)
            h = h.squeeze(0)
            int_act['rnn_h'] = h
        elif self.pool_type == 'conv':
            h = self.pool_conv(h)
            h = h.view(h.size(0), -1)
            int_act['avg_conv_h'] = h
        elif self.pool_type == 'none':
            h = h.view(h.size(0), -1)
        #print("Final h: {}".format(h.data.shape))
        #print(type(h.data))
        y = self.fc(h)
        #print(type(y.data))
        int_act['logit'] = y
        # return F.sigmoid(y), int_act
        return y, int_act 
Example #19
Source File: deform_conv_module.py    From DetNAS with MIT License
def forward(self, input):
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(
            input, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups) 
Example #20
Source File: treelstm.py    From treenet with GNU General Public License v3.0
def forward(self, *args, **kwargs):
        hc = super(TreeLSTM, self).forward(*args, **kwargs)
        h, _ = torch.chunk(hc, 2, dim=1)
        return h 
Example #21
Source File: irevnet.py    From imgclsmob with MIT License
def forward(self, x, _):
        x1, x2 = torch.chunk(x, chunks=2, dim=1)
        return x1, x2 
Example #22
Source File: toy.py    From vae-lagging-encoder with MIT License
def plot_multiple(model, plot_data, grid_z,
                  iter_, args):

    plot_data, sents_len = plot_data
    plot_data_list = torch.chunk(plot_data, round(args.num_plot / args.batch_size))

    infer_posterior_mean = []
    report_loss_kl = report_mi = report_num_sample = 0
    for data in plot_data_list:
        report_loss_kl += model.KL(data).sum().item()
        report_num_sample += data.size(0)
        report_mi += model.calc_mi_q(data) * data.size(0)

        # [batch, 1]
        posterior_mean = model.calc_model_posterior_mean(data, grid_z)

        infer_mean = model.calc_infer_mean(data)

        infer_posterior_mean.append(torch.cat([posterior_mean, infer_mean], 1))

    # [*, 2]
    infer_posterior_mean = torch.cat(infer_posterior_mean, 0)

    save_path = os.path.join(args.plot_dir, 'aggr%d_iter%d_multiple.pickle' % (args.aggressive, iter_))

    save_data = {'posterior': infer_posterior_mean[:,0].cpu().numpy(),
                 'inference': infer_posterior_mean[:,1].cpu().numpy(),
                 'kl': report_loss_kl / report_num_sample,
                 'mi': report_mi / report_num_sample
                 }
    pickle.dump(save_data, open(save_path, 'wb')) 
Example #23
Source File: revnet.py    From imgclsmob with MIT License
def backward(ctx, grad_y):
        fm = ctx.fm
        gm = ctx.gm

        x, y = ctx.saved_variables
        y1, y2 = torch.chunk(y, chunks=2, dim=1)
        y1 = y1.contiguous()
        y2 = y2.contiguous()

        with torch.no_grad():
            y1_z = Variable(y1.data, requires_grad=True)
            x2 = y2 - gm(y1_z)
            x1 = y1 - fm(x2)

        with set_grad_enabled(True):
            x1_ = Variable(x1.data, requires_grad=True)
            x2_ = Variable(x2.data, requires_grad=True)
            y1_ = x1_ + fm.forward(x2_)
            y2_ = x2_ + gm(y1_)
            y = torch.cat((y1_, y2_), dim=1)

            dd = torch.autograd.grad(y, (x1_, x2_) + tuple(gm.parameters()) + tuple(fm.parameters()), grad_y)

            gm_params_len = len([p for p in gm.parameters()])
            gm_params_grads = dd[2:2 + gm_params_len]
            fm_params_grads = dd[2 + gm_params_len:]
            grad_x = torch.cat((dd[0], dd[1]), dim=1)

            y1_.detach_()
            y2_.detach_()
            del y1_, y2_

        x.data.set_(torch.cat((x1, x2), dim=1).data.contiguous())

        return (grad_x, None, None) + fm_params_grads + gm_params_grads 
Example #24
Source File: modulated_dcn.py    From openseg.pytorch with MIT License
def forward(self, input):
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        func = ModulatedDeformConvFunction(self.stride, self.padding, self.dilation, self.deformable_groups)
        return func(input, offset, mask, self.weight, self.bias) 
Example #25
Source File: deform_conv.py    From GCNet with Apache License 2.0
def forward(self, x):
        out = self.conv_offset_mask(x)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
                                     self.stride, self.padding, self.dilation,
                                     self.groups, self.deformable_groups) 
Example #26
Source File: deform_conv_module.py    From Clothing-Detection with GNU General Public License v3.0
def forward(self, input):
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return modulated_deform_conv(
            input, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups) 
Example #27
Source File: test_additive_shared.py    From PySyft with Apache License 2.0
def test_chunk(workers):
    bob, alice, james = (workers["bob"], workers["alice"], workers["james"])
    t = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
    x = t.share(bob, alice, crypto_provider=james)

    res0 = torch.chunk(x, 2, dim=0)
    res1 = torch.chunk(x, 2, dim=1)

    expected0 = [torch.tensor([[1, 2, 3, 4]]), torch.tensor([[5, 6, 7, 8]])]
    expected1 = [torch.tensor([[1, 2], [5, 6]]), torch.tensor([[3, 4], [7, 8]])]

    assert all(((res0[i].get() == expected0[i]).all() for i in range(2)))
    assert all(((res1[i].get() == expected1[i]).all() for i in range(2))) 
Example #28
Source File: average_attn.py    From ITDD with MIT License
def forward(self, inputs, mask=None, layer_cache=None, step=None):
        """
        Args:
            inputs (`FloatTensor`): `[batch_size x input_len x model_dim]`

        Returns:
            (`FloatTensor`, `FloatTensor`):

            * gating_outputs `[batch_size x 1 x model_dim]`
            * average_outputs average attention `[batch_size x 1 x model_dim]`
        """
        batch_size = inputs.size(0)
        inputs_len = inputs.size(1)

        device = inputs.device
        mask_or_step = (self.cumulative_average_mask(batch_size, inputs_len)
                        .to(device).float()
                        if layer_cache is None else step)
        average_outputs = self.cumulative_average(
            inputs, mask_or_step, layer_cache=layer_cache)
        average_outputs = self.average_layer(average_outputs)
        gating_outputs = self.gating_layer(torch.cat((inputs,
                                                      average_outputs), -1))
        input_gate, forget_gate = torch.chunk(gating_outputs, 2, dim=2)
        gating_outputs = torch.sigmoid(input_gate) * inputs + \
            torch.sigmoid(forget_gate) * average_outputs

        return gating_outputs, average_outputs 
Example #29
Source File: dcn_v2.py    From centerpose with MIT License
def forward(self, input):
        out = self.conv_offset_mask(input)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)
        return dcn_v2_conv(input, offset, mask,
                           self.weight, self.bias,
                           self.stride,
                           self.padding,
                           self.dilation,
                           self.deformable_groups) 
Example #30
Source File: score_fun.py    From dgl with Apache License 2.0
def edge_func(self, edges):
        re_head, im_head = th.chunk(edges.src['emb'], 2, dim=-1)
        re_tail, im_tail = th.chunk(edges.dst['emb'], 2, dim=-1)

        phase_rel = edges.data['emb'] / (self.emb_init / np.pi)
        re_rel, im_rel = th.cos(phase_rel), th.sin(phase_rel)
        re_score = re_head * re_rel - im_head * im_rel
        im_score = re_head * im_rel + im_head * re_rel
        re_score = re_score - re_tail
        im_score = im_score - im_tail
        score = th.stack([re_score, im_score], dim=0)
        score = score.norm(dim=0)
        return {'score': self.gamma - score.sum(-1)}
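This last example stores complex embeddings as real and imaginary halves along the final dimension and recovers them with chunk, RotatE-style. An isolated sketch with dummy tensors:

import torch

emb = torch.randn(5, 8)                 # 5 entities, 4 complex dims stored as 8 reals
re, im = torch.chunk(emb, 2, dim=-1)    # real part [5, 4], imaginary part [5, 4]
modulus = torch.stack([re, im], dim=0).norm(dim=0)    # elementwise complex modulus, [5, 4]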