from math import pi as PI

import torch
from torch_geometric.transforms import Polar
from torch_geometric.data import Data


def test_polar():
    assert Polar().__repr__() == 'Polar(norm=True, max_value=None)'

    pos = torch.Tensor([[0, 0], [1, 0]])
    edge_index = torch.tensor([[0, 1], [1, 0]])
    edge_attr = torch.Tensor([1, 1])

    data = Data(edge_index=edge_index, pos=pos)
    data = Polar(norm=False)(data)
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index.tolist() == edge_index.tolist()
    assert torch.allclose(
        data.edge_attr, torch.Tensor([[1, 0], [1, PI]]), atol=1e-04)

    data = Data(edge_index=edge_index, pos=pos, edge_attr=edge_attr)
    data = Polar(norm=True)(data)
    assert len(data) == 3
    assert data.pos.tolist() == pos.tolist()
    assert data.edge_index.tolist() == edge_index.tolist()
    assert torch.allclose(
        data.edge_attr, torch.Tensor([[1, 1, 0], [1, 1, 0.5]]), atol=1e-04)