Python torch.norm() Examples

The following are 30 code examples of torch.norm(). You can go to the original project or source file by following the link above each example. You may also want to check out all available functions and classes of the torch module.
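Before diving into the project examples, here is a minimal self-contained sketch of the call patterns that recur below: the default 2-norm (Frobenius norm for matrices), the p argument, and per-dimension reduction via dim.

import torch

x = torch.tensor([[3.0, 4.0], [0.0, 0.0]])

# With no arguments, torch.norm flattens the input and returns its
# 2-norm (the Frobenius norm for matrices): sqrt(3^2 + 4^2) = 5.
print(torch.norm(x))         # tensor(5.)

# p selects the norm order; the 1-norm here is |3| + |4| = 7.
print(torch.norm(x, p=1))    # tensor(7.)

# dim reduces along one dimension, returning one norm per row.
print(torch.norm(x, dim=1))  # tensor([5., 0.])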
Example #1
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def sphere_subspace_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    subspace = torch.rand(shape[-1], 2, dtype=torch.float64)

    Q, _ = geoopt.linalg.batch_linalg.qr(subspace)
    P = Q @ Q.t()

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    x = (ex @ P.t()) / torch.norm(ex @ P.t())
    v = (ev - (x @ ev) * x) @ P.t()

    manifold = geoopt.Sphere(intersection=subspace)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.SphereExact(intersection=subspace)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
Example #2
Source File: utils.py    From pase with MIT License
def get_grad_norms(model, keys=[]):
    grads = {}
    for i, (k, param) in enumerate(dict(model.named_parameters()).items()):
        accept = False
        for key in keys:
            # match substring in collection of model keys
            if key in k:
                accept = True
                break
        if not accept:
            continue
        if param.grad is None:
            print('WARNING getting grads: {} param grad is None'.format(k))
            continue
        grads[k] = torch.norm(param.grad).cpu().item()
    return grads 
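A usage sketch for the helper above, assuming a throwaway nn.Linear model; the printed value is illustrative:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)
net(torch.randn(3, 4)).sum().backward()

# Collect the gradient norm of every parameter whose name contains 'weight'.
print(get_grad_norms(net, keys=['weight']))  # e.g. {'weight': 2.31}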
Example #3
Source File: manager.py    From gnn-comparison with GNU General Public License v3.0
def _power_iteration(self, A, num_simulations=30):
        # Start from a random vector to reduce the chance that it is
        # orthogonal to the dominant eigenvector.
        b_k = torch.rand(A.shape[1]).unsqueeze(dim=1) * 0.5 - 1

        for _ in range(num_simulations):
            # calculate the matrix-by-vector product Ab
            b_k1 = torch.mm(A, b_k)

            # calculate the norm
            b_k1_norm = torch.norm(b_k1)

            # re-normalize the vector
            b_k = b_k1 / b_k1_norm

        return b_k 
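The returned b_k approximates the dominant eigenvector of A; its eigenvalue can then be recovered with the Rayleigh quotient. A minimal standalone sketch, assuming a symmetric matrix so that power iteration converges:

import torch

torch.manual_seed(0)
A = torch.rand(5, 5)
A = A @ A.t()  # symmetric PSD, so iteration converges to the top eigenvector

b_k = torch.rand(A.shape[1]).unsqueeze(dim=1)
for _ in range(30):
    b_k1 = torch.mm(A, b_k)
    b_k = b_k1 / torch.norm(b_k1)

# Rayleigh quotient b^T A b / b^T b estimates the largest eigenvalue (a 1x1 tensor).
eig_est = (b_k.t() @ A @ b_k) / (b_k.t() @ b_k)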
Example #4
Source File: adaptive_sampling.py    From dgl with Apache License 2.0
def __init__(self,
                 in_feats,
                 out_feats,
                 aggregator_type='mean',
                 feat_drop=0.,
                 bias=True,
                 norm=None,
                 activation=None,
                 G=None):
        super(AdaptSAGEConv, self).__init__()
        self._in_feats = in_feats
        self._out_feats = out_feats
        self.norm = norm
        self.feat_drop = nn.Dropout(feat_drop)
        self.activation = activation
        # self.fc_self = nn.Linear(in_feats, out_feats, bias=bias).double()
        self.fc_neigh = nn.Linear(in_feats, out_feats, bias=bias)
        self.reset_parameters()
        self.G = G 
Example #5
Source File: lipschitz.py    From residual-flows with MIT License
def operator_norm_settings(domain, codomain):
    if domain == 1 and codomain == 1:
        # maximum l1-norm of column
        max_across_input_dims = True
        norm_type = 1
    elif domain == 1 and codomain == 2:
        # maximum l2-norm of column
        max_across_input_dims = True
        norm_type = 2
    elif domain == 1 and codomain == float("inf"):
        # maximum l-inf norm of column
        max_across_input_dims = True
        norm_type = float("inf")
    elif domain == 2 and codomain == float("inf"):
        # maximum l2-norm of row
        max_across_input_dims = False
        norm_type = 2
    elif domain == float("inf") and codomain == float("inf"):
        # maximum l1-norm of row
        max_across_input_dims = False
        norm_type = 1
    else:
        raise ValueError('Unknown combination of domain "{}" and codomain "{}"'.format(domain, codomain))

    return max_across_input_dims, norm_type 
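These branches encode the classical induced-norm identities: for example, the 1-to-1 operator norm is the maximum column-wise l1 norm, and the inf-to-inf norm is the maximum row-wise l1 norm. In torch.norm terms, for a weight matrix W:

import torch

W = torch.randn(4, 6)

# domain=1, codomain=1: maximum l1-norm over columns.
l1_to_l1 = torch.norm(W, p=1, dim=0).max()

# domain=inf, codomain=inf: maximum l1-norm over rows.
linf_to_linf = torch.norm(W, p=1, dim=1).max()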
Example #6
Source File: label_model.py    From metal with Apache License 2.0
def loss_l2(self, l2=0):
        """L2 loss centered around mu_init, scaled optionally per-source.

        In other words, diagonal Tikhonov regularization,
            ||D(\mu-\mu_{init})||_2^2
        where D is diagonal.

        Args:
            - l2: A float or np.array representing the per-source regularization
                strengths to use
        """
        if isinstance(l2, (int, float)):
            D = l2 * torch.eye(self.d)
        else:
            D = torch.diag(torch.from_numpy(l2))

        # Note that mu is a matrix and this is the *Frobenius norm*
        return torch.norm(D @ (self.mu - self.mu_init)) ** 2 
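As the closing comment notes, torch.norm on a matrix without a dim argument is the Frobenius norm, i.e. the 2-norm of the flattened matrix:

import torch

M = torch.randn(4, 4)
assert torch.allclose(torch.norm(M), M.pow(2).sum().sqrt())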
Example #7
Source File: score_fun.py    From dgl with Apache License 2.0
def create_neg(self, neg_head):
        gamma = self.gamma
        if neg_head:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                relations = relations.reshape(num_chunks, -1, self.relation_dim)
                tails = tails - relations
                tails = tails.reshape(num_chunks, -1, 1, self.relation_dim)
                score = heads - tails
                return gamma - th.norm(score, p=1, dim=-1)
            return fn
        else:
            def fn(heads, relations, tails, num_chunks, chunk_size, neg_sample_size):
                relations = relations.reshape(num_chunks, -1, self.relation_dim)
                heads = heads - relations
                heads = heads.reshape(num_chunks, -1, 1, self.relation_dim)
                score = heads - tails
                return gamma - th.norm(score, p=1, dim=-1)
            return fn 
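Since heads - (tails - relations) equals heads + relations - tails, both branches reduce to the TransE score gamma - ||h + r - t||_1, broadcast over chunks of negative samples. A scalar sketch of the underlying formula:

import torch as th

gamma = 12.0
h, r, t = th.randn(10), th.randn(10), th.randn(10)

# TransE: a triple scores highly when h + r is close to t in L1 distance.
score = gamma - th.norm(h + r - t, p=1)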
Example #8
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def sphere_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    x = ex / torch.norm(ex)
    v = ev - (x @ ev) * x

    manifold = geoopt.Sphere()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.SphereExact()
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
Example #9
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def sphere_compliment_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.Sphere]
    complement = torch.rand(shape[-1], 1, dtype=torch.float64)

    Q, _ = geoopt.linalg.batch_linalg.qr(complement)
    P = -Q @ Q.transpose(-1, -2)
    P[..., torch.arange(P.shape[-2]), torch.arange(P.shape[-2])] += 1

    ex = torch.randn(*shape, dtype=torch.float64)
    ev = torch.randn(*shape, dtype=torch.float64)
    x = (ex @ P.t()) / torch.norm(ex @ P.t())
    v = (ev - (x @ ev) * x) @ P.t()

    manifold = geoopt.Sphere(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.SphereExact(complement=complement)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
Example #10
Source File: test_manifold_basic.py    From geoopt with Apache License 2.0
def poincare_case():
    torch.manual_seed(42)
    shape = manifold_shapes[geoopt.manifolds.PoincareBall]
    ex = torch.randn(*shape, dtype=torch.float64) / 3
    ev = torch.randn(*shape, dtype=torch.float64) / 3
    x = torch.tanh(torch.norm(ex)) * ex / torch.norm(ex)
    ex = x.clone()
    v = ev.clone()
    manifold = geoopt.PoincareBall().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case
    manifold = geoopt.PoincareBallExact().to(dtype=torch.float64)
    x = geoopt.ManifoldTensor(x, manifold=manifold)
    case = UnaryCase(shape, x, ex, v, ev, manifold)
    yield case 
Example #11
Source File: Losses.py    From 3D-HourGlass-Network with MIT License
def AccelerationMatchingError(input, target):
	"""
	Takes input as (N,C,D,3) 3D coordinates and similar targets (here C is the number of channels, i.e. the number of joints).
	"""
	global lossfunc
	assert input.shape == target.shape
	assert len(input.shape) == 4
	#print('\n')
	#print(input[0,:8,0,:])
	#print(target[0,:8,0,:])
	input = input.cuda()
	inputdistances = input[:,:,1:,:] - input[:,:,:-1,:]
	inputdistances = torch.norm(inputdistances, dim=3)
	inputaccn = inputdistances[:,:,2:] + inputdistances[:,:,:-2] - 2*inputdistances[:,:,1:-1]
	targetdistances = target[:,:,1:,:] - target[:,:,:-1,:]
	targetdistances = torch.norm(targetdistances, dim=3)
	targetaccn = targetdistances[:,:,2:] + targetdistances[:,:,:-2] - 2*targetdistances[:,:,1:-1]
	return lossfunc(inputaccn, targetaccn) 
Example #12
Source File: sgcn.py    From SGCN with GNU General Public License v3.0
def calculate_positive_embedding_loss(self, z, positive_edges):
        """
        Calculating the loss on the positive edge embedding distances
        :param z: Hidden vertex representation.
        :param positive_edges: Positive training edges.
        :return loss_term: Loss value on positive edge embedding.
        """
        self.positive_surrogates = [random.choice(self.nodes) for node in range(positive_edges.shape[1])]
        self.positive_surrogates = torch.from_numpy(np.array(self.positive_surrogates, dtype=np.int64).T)
        self.positive_surrogates = self.positive_surrogates.type(torch.long).to(self.device)
        positive_edges = torch.t(positive_edges)
        self.positive_z_i = z[positive_edges[:, 0], :]
        self.positive_z_j = z[positive_edges[:, 1], :]
        self.positive_z_k = z[self.positive_surrogates, :]
        norm_i_j = torch.norm(self.positive_z_i-self.positive_z_j, 2, 1, True).pow(2)
        norm_i_k = torch.norm(self.positive_z_i-self.positive_z_k, 2, 1, True).pow(2)
        term = norm_i_j-norm_i_k
        term[term < 0] = 0
        loss_term = term.mean()
        return loss_term 
Example #13
Source File: polar.py    From pytorch_geometric with MIT License
def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
        assert pos.dim() == 2 and pos.size(1) == 2

        cart = pos[col] - pos[row]

        rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)

        theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1)
        theta = theta + (theta < 0).type_as(theta) * (2 * PI)

        if self.norm:
            rho = rho / (rho.max() if self.max is None else self.max)
            theta = theta / (2 * PI)

        polar = torch.cat([rho, theta], dim=-1)

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            data.edge_attr = torch.cat([pseudo, polar.type_as(pos)], dim=-1)
        else:
            data.edge_attr = polar

        return data 
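A quick numeric check of the rho/theta computation on the single 2D offset (1, 1):

import math
import torch

cart = torch.tensor([[1.0, 1.0]])
rho = torch.norm(cart, p=2, dim=-1)              # sqrt(2)
theta = torch.atan2(cart[..., 1], cart[..., 0])  # pi/4
assert abs(rho.item() - math.sqrt(2)) < 1e-6
assert abs(theta.item() - math.pi / 4) < 1e-6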
Example #14
Source File: DCCLoss.py    From DCC with MIT License
def forward(self, enc_out, sampweights, pairweights, pairs, index, _sigma1, _sigma2, _lambda):
        centroids = self.U[index]

        # note that sigmas here are labelled mu in the paper
        # data loss
        # enc_out is Y, the original embedding without shift
        out1 = torch.norm((enc_out - centroids).view(len(enc_out), -1), p=2, dim=1) ** 2
        out11 = torch.sum(_sigma1 * sampweights * out1 / (_sigma1 + out1))

        # pairwise loss
        out2 = torch.norm((centroids[pairs[:, 0]] - centroids[pairs[:, 1]]).view(len(pairs), -1), p=2, dim=1) ** 2

        out21 = _lambda * torch.sum(_sigma2 * pairweights * out2 / (_sigma2 + out2))

        out = out11 + out21

        if self.size_average:
            out = out / enc_out.nelement()

        return out 
Example #15
Source File: sgcn.py    From SGCN with GNU General Public License v3.0
def calculate_negative_embedding_loss(self, z, negative_edges):
        """
        Calculating the loss on the negative edge embedding distances
        :param z: Hidden vertex representation.
        :param negative_edges: Negative training edges.
        :return loss_term: Loss value on negative edge embedding.
        """
        self.negative_surrogates = [random.choice(self.nodes) for node in range(negative_edges.shape[1])]
        self.negative_surrogates = torch.from_numpy(np.array(self.negative_surrogates, dtype=np.int64).T)
        self.negative_surrogates = self.negative_surrogates.type(torch.long).to(self.device)
        negative_edges = torch.t(negative_edges)
        self.negative_z_i = z[negative_edges[:, 0], :]
        self.negative_z_j = z[negative_edges[:, 1], :]
        self.negative_z_k = z[self.negative_surrogates, :]
        norm_i_j = torch.norm(self.negative_z_i-self.negative_z_j, 2, 1, True).pow(2)
        norm_i_k = torch.norm(self.negative_z_i-self.negative_z_k, 2, 1, True).pow(2)
        term = norm_i_k-norm_i_j
        term[term < 0] = 0
        loss_term = term.mean()
        return loss_term 
Example #16
Source File: functional.py    From audio with BSD 2-Clause "Simplified" License
def magphase(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tuple[Tensor, Tensor]:
    r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`)

    Returns:
        (Tensor, Tensor): The magnitude and phase of the complex tensor
    """
    mag = complex_norm(complex_tensor, power)
    phase = angle(complex_tensor)
    return mag, phase 
Example #17
Source File: functional.py    From audio with BSD 2-Clause "Simplified" License
def complex_norm(
        complex_tensor: Tensor,
        power: float = 1.0
) -> Tensor:
    r"""Compute the norm of complex tensor input.

    Args:
        complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
        power (float): Power of the norm. (Default: `1.0`).

    Returns:
        Tensor: Power of the normed input tensor. Shape of `(..., )`
    """

    # Replace by torch.norm once issue is fixed
    # https://github.com/pytorch/pytorch/issues/34279
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power) 
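For power=1 the manual computation matches torch.norm over the trailing (real, imag) dimension, which is exactly the substitution the linked issue tracks:

import torch

spec = torch.randn(3, 100, 2)  # (..., real/imag)
manual = spec.pow(2.).sum(-1).pow(0.5)
assert torch.allclose(manual, torch.norm(spec, dim=-1))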
Example #18
Source File: score_fun.py    From dgl with Apache License 2.0
def edge_func(self, edges):
        head = edges.src['emb']
        tail = edges.dst['emb'].unsqueeze(-1)
        rel = edges.data['emb']
        rel = rel.view(-1, self.relation_dim, self.entity_dim)
        score = head * th.matmul(rel, tail).squeeze(-1)
        # TODO: check if use self.gamma
        return {'score': th.sum(score, dim=-1)}
        # return {'score': self.gamma - th.norm(score, p=1, dim=-1)} 
Example #19
Source File: nn_proc.py    From signaltrain with GNU General Public License v3.0
def forward(self, x_cuda, knobs_cuda, return_acts=False):
        # trainable STFT, outputs spectrograms for real & imag parts
        x_real, x_imag = self.dft_analysis.forward(x_cuda/2)  # the /2 is a cheap way to help us approach 'unit variance' of -0.5 and .5
        # Magnitude-Phase computation
        mag = torch.norm(torch.cat((x_real.unsqueeze(0), x_imag.unsqueeze(0)), 0), 2, dim=0)
        phs = torch.atan2(x_imag.float(), x_real.float()+1e-7).to(x_cuda.dtype)
        if return_acts:
            layer_acts = [x_real, x_imag, mag, phs]

        # Processes Magnitude and phase individually
        mag_hat, m_acts = self.aenc.forward(mag, knobs_cuda, skip_connections='sf', return_acts=return_acts)
        phs_hat, p_acts = self.phs_aenc.forward(phs, knobs_cuda, skip_connections='', return_acts=return_acts)
        if return_acts:
            layer_acts.extend(m_acts)
            layer_acts.extend(p_acts)

        output_phs_dim = phs_hat.size()[1]
        phs_hat = phs_hat + phs[:,-output_phs_dim:,:] # <-- residual skip connection. Slightly smoother convergence

        # Back to Real and Imaginary
        an_real = mag_hat * torch.cos(phs_hat)
        an_imag = mag_hat * torch.sin(phs_hat)

        # Forward synthesis pass
        x_fwdsyn = self.dft_synthesis.forward(an_real, an_imag)

        # final skip residual
        y_hat = x_fwdsyn + x_cuda[:,-x_fwdsyn.size()[-1]:]/2

        if return_acts:
            layer_acts.extend([mag_hat, phs_hat, an_real, an_imag, x_fwdsyn, y_hat])

        if return_acts:
            return 2*y_hat, mag, mag_hat, layer_acts   # undo the /2 at the beginning
        else:
            return 2*y_hat, mag, mag_hat 
Example #20
Source File: basemodel.py    From DeepCTR-Torch with Apache License 2.0
def add_regularization_loss(self, weight_list, weight_decay, p=2):
        reg_loss = torch.zeros((1,), device=self.device)
        for w in weight_list:
            if isinstance(w, tuple):
                l2_reg = torch.norm(w[1], p=p, )
            else:
                l2_reg = torch.norm(w, p=p, )
            reg_loss = reg_loss + l2_reg
        reg_loss = weight_decay * reg_loss
        self.reg_loss = self.reg_loss + reg_loss 
Example #21
Source File: class_balance.py    From metal with Apache License 2.0
def get_loss(O, Q, mask):
        # Main constraint: match empirical three-way overlaps matrix
        # (entries O_{ijk} for i != j != k)
        diffs = (O - torch.einsum("aby,cdy,efy->acebdf", [Q, Q, Q]))[mask]
        return torch.norm(diffs) ** 2 
Example #22
Source File: utils.py    From crosentgec with GNU General Public License v3.0
def clip_grad_norm_(tensor, max_norm):
    grad_norm = item(torch.norm(tensor))
    if grad_norm > max_norm > 0:
        clip_coef = max_norm / (grad_norm + 1e-6)
        tensor.mul_(clip_coef)
    return grad_norm 
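A standalone usage sketch of the same clipping logic, with the module's item() helper replaced by a plain .item() call:

import torch

grad = torch.randn(1000) * 10  # stand-in for a flat gradient tensor
max_norm = 1.0
grad_norm = torch.norm(grad).item()
if grad_norm > max_norm > 0:
    grad.mul_(max_norm / (grad_norm + 1e-6))
assert torch.norm(grad) <= max_norm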
Example #23
Source File: label_model.py    From metal with Apache License 2.0
def loss_mu(self, *args, l2=0):
        loss_1 = torch.norm((self.O - self.mu @ self.P @ self.mu.t())[self.mask]) ** 2
        loss_2 = torch.norm(torch.sum(self.mu @ self.P, 1) - torch.diag(self.O)) ** 2
        return loss_1 + loss_2 + self.loss_l2(l2=l2) 
Example #24
Source File: label_model.py    From metal with Apache License 2.0
def loss_inv_mu(self, *args, l2=0):
        loss_1 = torch.norm(self.Q - self.mu @ self.P @ self.mu.t()) ** 2
        loss_2 = torch.norm(torch.sum(self.mu @ self.P, 1) - torch.diag(self.O)) ** 2
        return loss_1 + loss_2 + self.loss_l2(l2=l2) 
Example #25
Source File: label_model.py    From metal with Apache License 2.0
def loss_inv_Z(self, *args):
        return torch.norm((self.O_inv + self.Z @ self.Z.t())[self.mask]) ** 2 
Example #26
Source File: embedding_utils.py    From Talking-Face-Generation-DAVS with MIT License
def l2_sim(feature1, feature2):
    Feature = feature1.expand(feature1.size(0), feature1.size(0), feature1.size(1)).transpose(0, 1)
    return torch.norm(Feature - feature2, p=2, dim=2) 
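Despite the name, this returns the full pairwise distance matrix D[i][j] = ||feature1[i] - feature2[j]||_2. On recent PyTorch versions, torch.cdist computes the same thing directly:

import torch

f1, f2 = torch.randn(8, 16), torch.randn(8, 16)
expanded = f1.expand(f1.size(0), f1.size(0), f1.size(1)).transpose(0, 1)
dist = torch.norm(expanded - f2, p=2, dim=2)
assert torch.allclose(dist, torch.cdist(f1, f2), atol=1e-5)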
Example #27
Source File: train_funcs.py    From BERT-Relation-Extraction with Apache License 2.0
def p_(self, f1_vec, f2_vec):
        if self.normalize:
            factor = 1/(torch.norm(f1_vec)*torch.norm(f2_vec))
        else:
            factor = 1.0
        
        if not self.use_logits:
            p = 1/(1 + torch.exp(-factor*torch.dot(f1_vec, f2_vec)))
        else:
            p = factor*torch.dot(f1_vec, f2_vec)
        return p 
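With normalize=True and use_logits=True, factor*torch.dot(f1_vec, f2_vec) is just the cosine similarity of the two vectors, which F.cosine_similarity computes directly:

import torch
import torch.nn.functional as F

f1, f2 = torch.randn(128), torch.randn(128)
manual = torch.dot(f1, f2) / (torch.norm(f1) * torch.norm(f2))
assert torch.allclose(manual, F.cosine_similarity(f1, f2, dim=0))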
Example #28
Source File: dif_fms.py    From DenseMatchingBenchmark with MIT License
def fast_dif_fms(reference_fm, target_fm, max_disp=192, start_disp=0, dilation=1, disp_sample=None,
                 normalize=False, p=1.0,):
    device = reference_fm.device
    B, C, H, W = reference_fm.shape

    if disp_sample is None:
        end_disp = start_disp + max_disp - 1

        disp_sample_number = (max_disp + dilation - 1) // dilation
        D = disp_sample_number

        # generate disparity samples, in [B,D, H, W] layout
        disp_sample = torch.linspace(start_disp, end_disp, D)
        disp_sample = disp_sample.view(1, D, 1, 1).expand(B, D, H, W).to(device).float()

    else:  # direct provide disparity samples
        # the number of disparity samples
        D = disp_sample.shape[1]

    # expand D dimension
    dif_reference_fm = reference_fm.unsqueeze(2).expand(B, C, D, H, W)
    dif_target_fm = target_fm.unsqueeze(2).expand(B, C, D, H, W)

    # shift reference feature map with disparity through grid sample
    # shift target feature according to disparity samples
    dif_target_fm = inverse_warp_3d(dif_target_fm, -disp_sample, padding_mode='zeros')

    # mask out features in reference
    dif_reference_fm = dif_reference_fm * (dif_target_fm > 0).type_as(dif_reference_fm)

    # [B, C, D, H, W)
    dif_fm = dif_reference_fm - dif_target_fm

    if normalize:
        # [B, D, H, W]
        dif_fm = torch.norm(dif_fm, p=p, dim=1, keepdim=False)

    return dif_fm 
Example #29
Source File: functional.py    From SlowFast-Network-pytorch with MIT License
def normalize(input, p=2, dim=1, eps=1e-12, out=None):
    # type: (Tensor, float, int, float, Optional[Tensor]) -> Tensor
    r"""Performs :math:`L_p` normalization of inputs over specified dimension.

    For a tensor :attr:`input` of sizes :math:`(n_0, ..., n_{dim}, ..., n_k)`, each
    :math:`n_{dim}` -element vector :math:`v` along dimension :attr:`dim` is transformed as

    .. math::
        v = \frac{v}{\max(\lVert v \rVert_p, \epsilon)}.

    With the default arguments it uses the Euclidean norm over vectors along dimension :math:`1` for normalization.

    Args:
        input: input tensor of any shape
        p (float): the exponent value in the norm formulation. Default: 2
        dim (int): the dimension to reduce. Default: 1
        eps (float): small value to avoid division by zero. Default: 1e-12
        out (Tensor, optional): the output tensor. If :attr:`out` is used, this
                                operation won't be differentiable.
    """
    if out is None:
        denom = input.norm(p, dim, True).clamp(min=eps).expand_as(input)
        ret = input / denom
    else:
        denom = input.norm(p, dim, True).clamp_(min=eps).expand_as(input)
        ret = torch.div(input, denom, out=torch.jit._unwrap_optional(out))
    return ret 
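A quick check that the function above leaves each vector along dim with unit norm (for inputs whose norm is comfortably above eps):

import torch

x = torch.randn(5, 10)
y = normalize(x)  # defined above; rows of y have unit L2 norm
assert torch.allclose(torch.norm(y, dim=1), torch.ones(5), atol=1e-6)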
Example #30
Source File: spherical.py    From pytorch_geometric with MIT License
def __call__(self, data):
        (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
        assert pos.dim() == 2 and pos.size(1) == 3

        cart = pos[col] - pos[row]

        rho = torch.norm(cart, p=2, dim=-1).view(-1, 1)

        theta = torch.atan2(cart[..., 1], cart[..., 0]).view(-1, 1)
        theta = theta + (theta < 0).type_as(theta) * (2 * PI)

        phi = torch.acos(cart[..., 2] / rho.view(-1)).view(-1, 1)

        if self.norm:
            rho = rho / (rho.max() if self.max is None else self.max)
            theta = theta / (2 * PI)
            phi = phi / PI

        spher = torch.cat([rho, theta, phi], dim=-1)

        if pseudo is not None and self.cat:
            pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
            data.edge_attr = torch.cat([pseudo, spher.type_as(pos)], dim=-1)
        else:
            data.edge_attr = spher

        return data