import pytest
from itertools import product
from functools import partial

import torch
from torch.autograd import gradcheck

from entmax.root_finding import sparsemax_bisect, entmax_bisect
from entmax.activations import sparsemax, entmax15


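# Sketch of a dimension-agreement test: assumes entmax_bisect's `dim`
# keyword (exercised further below) and that bisection at alpha=1.5
# matches the exact entmax15 along any axis within allclose tolerance.
@pytest.mark.parametrize("dim", (0, 1, 2))
def test_entmax15_dim(dim):
    x = torch.randn(5, 6, 7, dtype=torch.float64)
    p1 = entmax15(x, dim=dim)
    p2 = entmax_bisect(x, alpha=1.5, dim=dim)
    assert torch.allclose(p1, p2)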


def test_sparsemax():
    """Bisection-based sparsemax agrees with the exact sorting-based one."""
    x = 0.5 * torch.randn(4, 6, dtype=torch.float32)
    p1 = sparsemax(x, 1)
    p2 = sparsemax_bisect(x)
    assert torch.sum((p1 - p2) ** 2) < 1e-7


def test_entmax15():
    """Bisection at alpha=1.5 agrees with the exact entmax15."""
    x = 0.5 * torch.randn(4, 6, dtype=torch.float32)
    p1 = entmax15(x, 1)
    p2 = entmax_bisect(x, alpha=1.5)
    assert torch.sum((p1 - p2) ** 2) < 1e-7
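

# Sketch of a known special case: 2-entmax reduces to sparsemax, so the
# bisection solver with alpha=2.0 should reproduce the exact sparsemax
# (assuming allclose default tolerances absorb the bisection error).
def test_entmax_bisect_alpha2_matches_sparsemax():
    x = 0.5 * torch.randn(4, 6, dtype=torch.float64)
    p1 = sparsemax(x, 1)
    p2 = entmax_bisect(x, alpha=2.0)
    assert torch.allclose(p1, p2)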


def test_sparsemax_grad():
    """Finite-difference gradient check of sparsemax_bisect."""
    x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
    gradcheck(sparsemax_bisect, (x,), eps=1e-5)


@pytest.mark.parametrize("alpha", (1.2, 1.5, 1.75, 2.25))
def test_entmax_grad(alpha):
    """Gradient check w.r.t. both the input and a scalar tensor alpha."""
    alpha = torch.tensor(alpha, dtype=torch.float64, requires_grad=True)
    x = torch.randn(4, 6, dtype=torch.float64, requires_grad=True)
    gradcheck(entmax_bisect, (x, alpha), eps=1e-5)


def test_entmax_correct_multiple_alphas():
    """A batched (n, 1) alpha matches n independent per-row calls."""
    n = 4
    x = torch.randn(n, 6, dtype=torch.float64, requires_grad=True)
    alpha = 1.05 + torch.rand((n, 1), dtype=torch.float64, requires_grad=True)

    p1 = entmax_bisect(x, alpha)
    p2_ = [
        entmax_bisect(x[i].unsqueeze(0), alpha[i].item()).squeeze()
        for i in range(n)
    ]
    p2 = torch.stack(p2_)

    assert torch.allclose(p1, p2)
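

# Sketch: a 0-dim tensor alpha should behave exactly like the equivalent
# Python float, since both call forms already appear in this file.
def test_entmax_scalar_tensor_alpha_matches_float():
    x = torch.randn(4, 6, dtype=torch.float64)
    p1 = entmax_bisect(x, alpha=torch.tensor(1.5, dtype=torch.float64))
    p2 = entmax_bisect(x, alpha=1.5)
    assert torch.allclose(p1, p2)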


def test_entmax_grad_multiple_alphas():
    """Gradient check w.r.t. the input and a per-row (n, 1) alpha."""
    n = 4
    x = torch.randn(n, 6, dtype=torch.float64, requires_grad=True)
    alpha = 1.05 + torch.rand((n, 1), dtype=torch.float64, requires_grad=True)
    gradcheck(entmax_bisect, (x, alpha), eps=1e-5)


@pytest.mark.parametrize("dim", (0, 1, 2, 3))
def test_arbitrary_dimension(dim):
    """Broadcast alphas along `dim` match per-slice calls with dim=-1."""
    shape = [3, 4, 2, 5]
    X = torch.randn(*shape, dtype=torch.float64)

    # copy the list: aliasing `shape` here would mutate it for the
    # `enumerate(shape)` loop below
    alpha_shape = list(shape)
    alpha_shape[dim] = 1

    alphas = 1.05 + torch.rand(alpha_shape, dtype=torch.float64)

    P = entmax_bisect(X, alpha=alphas, dim=dim)

    ranges = [
        list(range(k)) if i != dim else [slice(None)]
        for i, k in enumerate(shape)
    ]

    for ix in product(*ranges):
        x = X[ix].unsqueeze(0)
        alpha = alphas[ix].item()
        p_true = entmax_bisect(x, alpha=alpha, dim=-1)
        assert torch.allclose(P[ix], p_true)
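

# Sketch: a negative `dim` should follow standard PyTorch indexing (dim=-1
# is already used above), agreeing with the equivalent positive index.
def test_negative_dim_matches_positive():
    X = torch.randn(3, 4, 2, 5, dtype=torch.float64)
    alphas = 1.05 + torch.rand((3, 4, 2, 1), dtype=torch.float64)
    p1 = entmax_bisect(X, alpha=alphas, dim=-1)
    p2 = entmax_bisect(X, alpha=alphas, dim=3)
    assert torch.allclose(p1, p2)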


@pytest.mark.parametrize("dim", (0, 1, 2, 3))
def test_arbitrary_dimension_grad(dim):
    """Gradient check of entmax_bisect along an arbitrary dimension."""
    shape = [3, 4, 2, 5]

    # copy the list: `alpha_shape = shape` would alias it, so setting
    # alpha_shape[dim] = 1 would also collapse X's shape below
    alpha_shape = list(shape)
    alpha_shape[dim] = 1

    f = partial(entmax_bisect, dim=dim)

    X = torch.randn(*shape, dtype=torch.float64, requires_grad=True)
    alphas = 1.05 + torch.rand(
        alpha_shape, dtype=torch.float64, requires_grad=True
    )
    gradcheck(f, (X, alphas), eps=1e-5)