import torch from torch_scatter import scatter_log_softmax, scatter_softmax def test_softmax(): src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) src.requires_grad_() index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) out = scatter_softmax(src, index) out0 = torch.softmax(torch.tensor([0.2, 0.2]), dim=-1) out1 = torch.softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) out2 = torch.softmax(torch.tensor([7], dtype=torch.float), dim=-1) out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1) expected = torch.stack([ out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] ], dim=0) assert torch.allclose(out, expected) out.backward(torch.randn_like(out)) def test_log_softmax(): src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) src.requires_grad_() index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) out = scatter_log_softmax(src, index) out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1) out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1) out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1) expected = torch.stack([ out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] ], dim=0) assert torch.allclose(out, expected) out.backward(torch.randn_like(out))