import torch_dct
import scipy.fftpack as fftpack
import numpy as np
import torch

# Seed NumPy's global RNG so the np.random.normal inputs drawn by the tests
# in this module are reproducible from run to run.
np.random.seed(1)

# Largest absolute element-wise difference tolerated between an expected and
# an actual array in the assertions of this module.
EPS = 1e-3
# NOTE: THIS IS NOT HOW THESE LAYERS SHOULD BE USED IN PRACTICE.
# Each helper below constructs a brand-new LinearDCT layer on every call; it is
# only written this way for testing convenience.
def dct1(x):
    """Run ``x`` through a throwaway ``LinearDCT`` layer of type ``'dct1'``.

    The layer is sized from ``x.size(1)`` and rebuilt on every call, which is
    acceptable only in tests.

    Returns the ``.data`` tensor of the layer's output.
    """
    return torch_dct.LinearDCT(x.size(1), type='dct1')(x).data


def idct1(x):
    """Run ``x`` through a throwaway ``LinearDCT`` layer of type ``'idct1'``.

    The layer is sized from ``x.size(1)`` and rebuilt on every call, which is
    acceptable only in tests.

    Returns the ``.data`` tensor of the layer's output.
    """
    return torch_dct.LinearDCT(x.size(1), type='idct1')(x).data
def dct(x, norm=None):
    """Apply a freshly built ``'dct'`` ``LinearDCT`` layer to ``x``.

    The layer size is read from ``x.size(1)`` and ``norm`` is forwarded to the
    layer constructor unchanged. Returns the ``.data`` of the layer output.
    """
    n_features = x.size(1)
    layer = torch_dct.LinearDCT(n_features, type='dct', norm=norm)
    transformed = layer(x)
    return transformed.data
def idct(x, norm=None):
    """Apply a freshly built ``'idct'`` ``LinearDCT`` layer to ``x``.

    The layer size is read from ``x.size(1)`` and ``norm`` is forwarded to the
    layer constructor unchanged. Returns the ``.data`` of the layer output.
    """
    n_features = x.size(1)
    layer = torch_dct.LinearDCT(n_features, type='idct', norm=norm)
    restored = layer(x)
    return restored.data

def dct_2d(x):
    """Pass ``x`` and a fresh ``'dct'`` layer to ``torch_dct.apply_linear_2d``.

    A single ``LinearDCT`` layer sized from ``x.size(1)`` is handed over, so
    the tests only call this with inputs whose transformed axes all share
    that length. Returns the ``.data`` of the result.
    """
    layer = torch_dct.LinearDCT(x.size(1), type='dct')
    return torch_dct.apply_linear_2d(x, layer).data


def dct_3d(x):
    """Pass ``x`` and a fresh ``'dct'`` layer to ``torch_dct.apply_linear_3d``.

    A single ``LinearDCT`` layer sized from ``x.size(1)`` is handed over, so
    the tests only call this with inputs whose transformed axes all share
    that length. Returns the ``.data`` of the result.
    """
    layer = torch_dct.LinearDCT(x.size(1), type='dct')
    return torch_dct.apply_linear_3d(x, layer).data


def idct_2d(x):
    """Pass ``x`` and a fresh ``'idct'`` layer to ``torch_dct.apply_linear_2d``.

    A single ``LinearDCT`` layer sized from ``x.size(1)`` is handed over, so
    the tests only call this with inputs whose transformed axes all share
    that length. Returns the ``.data`` of the result.
    """
    layer = torch_dct.LinearDCT(x.size(1), type='idct')
    return torch_dct.apply_linear_2d(x, layer).data


def idct_3d(x):
    """Pass ``x`` and a fresh ``'idct'`` layer to ``torch_dct.apply_linear_3d``.

    A single ``LinearDCT`` layer sized from ``x.size(1)`` is handed over, so
    the tests only call this with inputs whose transformed axes all share
    that length. Returns the ``.data`` of the result.
    """
    layer = torch_dct.LinearDCT(x.size(1), type='idct')
    return torch_dct.apply_linear_3d(x, layer).data

def test_dct1():
    """Compare the ``dct1`` helper with ``scipy.fftpack.dct(..., type=1)``.

    Single-row inputs of several lengths are checked first, followed by
    arrays with 2, 3 and 4 dimensions whose every axis has length 2.
    """
    row_shapes = [(1, length) for length in (2, 5, 32, 111)]
    all_twos_shapes = [(2,) * ndim for ndim in (2, 3, 4)]

    for shape in row_shapes + all_twos_shapes:
        sample = np.random.normal(size=shape)
        expected = fftpack.dct(sample, type=1)
        # The layer runs in float32, hence the cast before calling it.
        got = dct1(torch.tensor(sample).float()).numpy()
        worst_error = np.abs(expected - got).max()
        assert worst_error < EPS, expected


def test_idct1():
    """Check that ``idct1`` undoes ``dct1`` on single-row inputs."""
    for length in (2, 5, 32, 111):
        x = np.random.normal(size=(1, length))
        coefficients = dct1(torch.tensor(x).float())
        recovered = idct1(coefficients).numpy()
        round_trip_error = np.abs(x - recovered).max()
        assert round_trip_error < EPS, x


def test_dct():
    """Compare the ``dct`` helper with ``scipy.fftpack.dct(..., type=2)``.

    Runs once with ``norm=None`` and once with ``norm='ortho'``. For each,
    single-row inputs of several lengths are checked, then arrays with 2, 3,
    4 and 11 dimensions whose every axis has length 2.
    """
    for norm in (None, 'ortho'):
        # Each case pairs an input shape with the value reported on failure.
        cases = [((1, length), length) for length in (2, 3, 5, 32, 111)]
        cases += [((2,) * ndim, ndim) for ndim in (2, 3, 4, 11)]

        for shape, label in cases:
            sample = np.random.normal(size=shape)
            expected = fftpack.dct(sample, type=2, norm=norm)
            got = dct(torch.tensor(sample).float(), norm=norm).numpy()
            worst_error = np.abs(expected - got).max()
            assert worst_error < EPS, (norm, label)


def test_idct():
    """Check that ``idct`` undoes ``dct`` for ``norm=None`` and ``'ortho'``."""
    for norm in (None, 'ortho'):
        for length in (5, 2, 32, 111):
            x = np.random.normal(size=(1, length))
            coefficients = dct(torch.tensor(x).float(), norm=norm)
            recovered = idct(coefficients, norm=norm).numpy()
            round_trip_error = np.abs(x - recovered).max()
            assert round_trip_error < EPS, x

def test_dct_2d():
    """Compare ``dct_2d`` with ``scipy.fftpack.dct`` applied along axes 2 and 1.

    Inputs have shape ``(1, side, side)``.
    """
    for side in (2, 5, 32):
        x = np.random.normal(size=(1, side, side))

        # Build the reference by transforming the last axis first, then axis 1.
        expected = x
        for axis in (2, 1):
            expected = fftpack.dct(expected, axis=axis, type=2)

        got = dct_2d(torch.tensor(x).float()).numpy()
        worst_error = np.abs(expected - got).max()
        assert worst_error < EPS, (expected, got)


def test_idct_2d():
    """Check that ``idct_2d`` undoes ``dct_2d`` on ``(1, side, side)`` inputs."""
    for side in (2, 5, 32):
        x = np.random.normal(size=(1, side, side))
        coefficients = dct_2d(torch.tensor(x).float())
        recovered = idct_2d(coefficients).numpy()
        round_trip_error = np.abs(x - recovered).max()
        assert round_trip_error < EPS, x


def test_dct_3d():
    """Compare ``dct_3d`` with ``scipy.fftpack.dct`` applied along axes 3, 2, 1.

    Inputs have shape ``(1, side, side, side)``.
    """
    for side in (2, 5, 32):
        x = np.random.normal(size=(1, side, side, side))

        # Build the reference one axis at a time, innermost axis first.
        expected = x
        for axis in (3, 2, 1):
            expected = fftpack.dct(expected, axis=axis, type=2)

        got = dct_3d(torch.tensor(x).float()).numpy()
        worst_error = np.abs(expected - got).max()
        assert worst_error < EPS, (expected, got)


def test_idct_3d():
    """Check that ``idct_3d`` undoes ``dct_3d`` on ``(1, side, side, side)`` inputs."""
    for side in (2, 5, 32):
        x = np.random.normal(size=(1, side, side, side))
        coefficients = dct_3d(torch.tensor(x).float())
        recovered = idct_3d(coefficients).numpy()
        round_trip_error = np.abs(x - recovered).max()
        assert round_trip_error < EPS, x