import torch_dct as dct
import scipy.fftpack as fftpack
import numpy as np
import torch

# Absolute tolerance shared by every numeric comparison in this module.
EPS = 1e-10

# Seed numpy's global RNG so the sequence of random test inputs is
# reproducible across runs (given the same test execution order).
np.random.seed(1)


def test_dct1():
    """Compare ``dct.dct1`` with SciPy's type-I DCT on random float64 inputs.

    Covers inputs of shape ``(1, N)`` for several lengths ``N`` followed by
    side-2 hypercubes with 2, 3 and 4 dimensions.
    """
    # Build the full list of shapes up front; the order matches the original
    # two loops, so the random draws happen in the same sequence.
    shapes = [(1, n) for n in (2, 5, 32, 111)]
    shapes += [(2,) * ndim for ndim in (2, 3, 4)]

    for shape in shapes:
        sample = np.random.normal(size=shape)
        expected = fftpack.dct(sample, type=1)
        got = dct.dct1(torch.tensor(sample)).numpy()
        assert np.abs(expected - got).max() < EPS, expected


def test_idct1():
    """Check that ``dct.idct1`` inverts ``dct.dct1`` on ``(1, N)`` inputs."""
    for length in (2, 5, 32, 111):
        original = np.random.normal(size=(1, length))
        coeffs = dct.dct1(torch.tensor(original))
        restored = dct.idct1(coeffs).numpy()
        max_err = np.abs(original - restored).max()
        assert max_err < EPS, original


def test_dct():
    """Compare ``dct.dct`` with SciPy's type-II DCT for both normalizations.

    For each ``norm`` in ``(None, 'ortho')`` this checks ``(1, N)`` inputs
    and then side-2 hypercubes with 2, 3, 4 and 11 dimensions.
    """
    def check(sample, norm, label):
        # One comparison against the SciPy reference; ``label`` identifies
        # the failing case in the assertion message.
        expected = fftpack.dct(sample, type=2, norm=norm)
        got = dct.dct(torch.tensor(sample), norm=norm).numpy()
        assert np.abs(expected - got).max() < EPS, (norm, label)

    for norm in (None, 'ortho'):
        for n in (2, 3, 5, 32, 111):
            check(np.random.normal(size=(1, n)), norm, n)
        for ndim in (2, 3, 4, 11):
            check(np.random.normal(size=(2,) * ndim), norm, ndim)


def test_idct():
    """Round-trip ``(1, N)`` inputs through ``dct.dct`` and ``dct.idct``.

    Both the unnormalized (``norm=None``) and orthonormal (``'ortho'``)
    variants must reproduce the input to within ``EPS``.
    """
    # norm is the outer loop and N the inner one, matching the draw order.
    cases = [(norm, n) for norm in (None, 'ortho') for n in (5, 2, 32, 111)]
    for norm, n in cases:
        signal = np.random.normal(size=(1, n))
        spectrum = dct.dct(torch.tensor(signal), norm=norm)
        recovered = dct.idct(spectrum, norm=norm).numpy()
        assert np.abs(signal - recovered).max() < EPS, signal


def test_cuda():
    """Repeat the DCT-I, DCT-II and inverse DCT-II checks on a CUDA device.

    Inputs are created directly on ``cuda:0``. Results are copied back to
    the host with ``.cpu()`` before being compared with the SciPy reference.

    Raises:
        unittest.SkipTest: if ``torch.cuda.is_available()`` is False, so the
            run reports the GPU path as skipped instead of silently passing.
    """
    import unittest

    if not torch.cuda.is_available():
        # Previously this case fell through and the test "passed" without
        # checking anything. pytest (and unittest) report SkipTest as a skip.
        raise unittest.SkipTest('CUDA is not available')

    device = torch.device('cuda:0')

    def on_device(array):
        # Move a numpy array onto the CUDA device as a tensor.
        return torch.tensor(array, device=device)

    # DCT-I on (1, N) inputs and on side-2 hypercubes.
    for N in [2, 5, 32, 111]:
        x = np.random.normal(size=(1, N))
        ref = fftpack.dct(x, type=1)
        act = dct.dct1(on_device(x)).cpu().numpy()
        assert np.abs(ref - act).max() < EPS, ref

    for d in [2, 3, 4]:
        x = np.random.normal(size=(2,) * d)
        ref = fftpack.dct(x, type=1)
        act = dct.dct1(on_device(x)).cpu().numpy()
        assert np.abs(ref - act).max() < EPS, ref

    # DCT-II and its inverse, for both normalizations.
    for norm in [None, 'ortho']:
        for N in [2, 3, 5, 32, 111]:
            x = np.random.normal(size=(1, N))
            ref = fftpack.dct(x, type=2, norm=norm)
            act = dct.dct(on_device(x), norm=norm).cpu().numpy()
            assert np.abs(ref - act).max() < EPS, (norm, N)

        for d in [2, 3, 4, 11]:
            x = np.random.normal(size=(2,) * d)
            ref = fftpack.dct(x, type=2, norm=norm)
            act = dct.dct(on_device(x), norm=norm).cpu().numpy()
            assert np.abs(ref - act).max() < EPS, (norm, d)

        for N in [5, 2, 32, 111]:
            x = np.random.normal(size=(1, N))
            X = dct.dct(on_device(x), norm=norm)
            y = dct.idct(X, norm=norm).cpu().numpy()
            assert np.abs(x - y).max() < EPS, x

def test_dct_2d():
    """Compare ``dct.dct_2d`` with a separable SciPy type-II DCT.

    The reference applies ``fftpack.dct`` along axis 2 and then along axis 1
    of a ``(1, N1, N2)`` input. Both ``norm=None`` and ``norm='ortho'`` are
    checked, matching the normalization coverage of ``test_dct``.
    """
    import itertools

    sizes = [2, 5, 32]
    for norm, N1, N2 in itertools.product([None, 'ortho'], sizes, sizes):
        x = np.random.normal(size=(1, N1, N2))
        ref = fftpack.dct(x, axis=2, type=2, norm=norm)
        ref = fftpack.dct(ref, axis=1, type=2, norm=norm)
        act = dct.dct_2d(torch.tensor(x), norm=norm).numpy()
        # Identify the failing case rather than dumping whole arrays.
        assert np.abs(ref - act).max() < EPS, (norm, N1, N2)


def test_idct_2d():
    """Check that ``dct.idct_2d`` inverts ``dct.dct_2d`` on ``(1, N1, N2)`` inputs.

    Both ``norm=None`` and ``norm='ortho'`` round trips are tested, matching
    the normalization coverage of ``test_idct``.
    """
    import itertools

    sizes = [2, 5, 32]
    for norm, N1, N2 in itertools.product([None, 'ortho'], sizes, sizes):
        x = np.random.normal(size=(1, N1, N2))
        X = dct.dct_2d(torch.tensor(x), norm=norm)
        y = dct.idct_2d(X, norm=norm).numpy()
        assert np.abs(x - y).max() < EPS, (norm, N1, N2)


def test_dct_3d():
    """Compare ``dct.dct_3d`` with a separable SciPy type-II DCT.

    The reference applies ``fftpack.dct`` along axes 3, 2 and 1 of a
    ``(1, N1, N2, N3)`` input. Both ``norm=None`` and ``norm='ortho'`` are
    checked, matching the normalization coverage of ``test_dct``.
    """
    import itertools

    sizes = [2, 5, 32]
    grid = itertools.product([None, 'ortho'], sizes, sizes, sizes)
    for norm, N1, N2, N3 in grid:
        x = np.random.normal(size=(1, N1, N2, N3))
        ref = x
        for axis in (3, 2, 1):
            ref = fftpack.dct(ref, axis=axis, type=2, norm=norm)
        act = dct.dct_3d(torch.tensor(x), norm=norm).numpy()
        # Identify the failing case rather than dumping whole arrays.
        assert np.abs(ref - act).max() < EPS, (norm, N1, N2, N3)


def test_idct_3d():
    """Check that ``dct.idct_3d`` inverts ``dct.dct_3d`` on ``(1, N1, N2, N3)`` inputs.

    Both ``norm=None`` and ``norm='ortho'`` round trips are tested, matching
    the normalization coverage of ``test_idct``.
    """
    import itertools

    sizes = [2, 5, 32]
    grid = itertools.product([None, 'ortho'], sizes, sizes, sizes)
    for norm, N1, N2, N3 in grid:
        x = np.random.normal(size=(1, N1, N2, N3))
        X = dct.dct_3d(torch.tensor(x), norm=norm)
        y = dct.idct_3d(X, norm=norm).numpy()
        assert np.abs(x - y).max() < EPS, (norm, N1, N2, N3)