import torch
torch.manual_seed(0)
# from _ext import th_fft
import pytorch_fft.fft as cfft
import pytorch_fft.fft.autograd as afft
import numpy as np
import numpy.fft as nfft

def run_c2c(x, z, _f1, _f2, _if1, _if2, atol):
    """Compare a complex-to-complex transform pair against a NumPy reference.

    Args:
        x: real part of the input, as a tensor.
        z: imaginary part of the input, as a tensor. Only ``x`` is handed to
            the NumPy reference, so ``z`` must be all zeros for the comparison
            to be meaningful (every caller in this file passes zeros).
        _f1: forward transform under test, called as ``_f1(x, z)``; must
            return a ``(real, imag)`` pair of tensors.
        _f2: NumPy reference forward transform, applied to ``x`` as an ndarray.
        _if1: inverse transform under test, called as ``_if1(real, imag)``;
            must return a ``(real, imag)`` pair of tensors.
        _if2: NumPy reference inverse transform, applied to the reference
            spectrum.
        atol: absolute tolerance for ``np.allclose`` (rtol stays at default).

    Raises:
        AssertionError: if a forward or inverse output disagrees with NumPy.
    """
    y1, y2 = _f1(x, z)
    # Keep the full shape: squeezing would drop size-1 batch/channel/signal
    # dims, so NumPy would transform over different axes than the code under
    # test (and a fixed-axes reference such as axes=(1, 2, 3) would fail on a
    # batch of one).
    x_np = x.cpu().numpy()
    y_np = _f2(x_np)
    assert np.allclose(y1.cpu().numpy(), y_np.real, atol=atol)
    assert np.allclose(y2.cpu().numpy(), y_np.imag, atol=atol)

    # Round trip: invert the computed spectrum and compare with NumPy's
    # inverse of the reference spectrum.
    x0, z0 = _if1(y1, y2)
    x0_np = _if2(y_np)
    assert np.allclose(x0.cpu().numpy(), x0_np.real, atol=atol)
    assert np.allclose(z0.cpu().numpy(), x0_np.imag, atol=atol)


def test_c2c(_f1, _f2, _if1, _if2):
    """Check one complex-to-complex transform pair against NumPy.

    The input is a random (3, 4, 5, 7) CUDA tensor with an all-zero
    imaginary part. It goes through ``run_c2c`` twice: first at the
    default precision with atol 1e-6, then cast to float64 with atol 1e-14.
    """
    dims = (3, 4, 5, 7)  # batch, channels, n, m
    count = dims[0] * dims[1] * dims[2] * dims[3]
    real_part = torch.randn(count).view(*dims).cuda()
    imag_part = torch.zeros(*dims).cuda()

    precisions = (
        (lambda t: t, 1e-6),
        (lambda t: t.double(), 1e-14),
    )
    for convert, tolerance in precisions:
        run_c2c(convert(real_part), convert(imag_part),
                _f1, _f2, _if1, _if2, tolerance)



def run_r2c(x, _f1, _f2, _if1, _if2, atol):
    """Compare a real-to-complex transform pair against a NumPy reference.

    Args:
        x: real input tensor.
        _f1: forward transform under test, called as ``_f1(x)``; must return
            a ``(real, imag)`` pair of tensors.
        _f2: NumPy reference forward transform, applied to ``x`` as an ndarray.
        _if1: inverse transform under test, called as ``_if1(real, imag)``;
            must return a single real tensor.
        _if2: NumPy reference inverse transform, applied to the reference
            spectrum.
        atol: absolute tolerance for ``np.allclose`` (rtol stays at default).

    Raises:
        AssertionError: if a forward or inverse output disagrees with NumPy.
    """
    y1, y2 = _f1(x)
    # Keep the full shape: squeezing would drop size-1 batch/channel/signal
    # dims, so NumPy would transform over different axes than the code under
    # test (and a fixed-axes reference such as axes=(1, 2, 3) would fail on a
    # batch of one).
    x_np = x.cpu().numpy()
    y_np = _f2(x_np)
    assert np.allclose(y1.cpu().numpy(), y_np.real, atol=atol)
    assert np.allclose(y2.cpu().numpy(), y_np.imag, atol=atol)

    # Round trip back to the real domain. `.real` is a no-op for a real
    # ndarray and strips a zero imaginary part if the reference returns one.
    x0 = _if1(y1, y2)
    x0_np = _if2(y_np)
    assert np.allclose(x0.cpu().numpy(), x0_np.real, atol=atol)


def test_r2c(_f1, _f2, _if1, _if2):
    """Check one real-to-complex transform pair against NumPy.

    The input is a random (3, 2, 2, 4) CUDA tensor. It goes through
    ``run_r2c`` twice: first at the default precision with atol 1e-6, then
    cast to float64 with atol 1e-14.
    """
    dims = (3, 2, 2, 4)  # batch, channels, n, m
    count = dims[0] * dims[1] * dims[2] * dims[3]
    signal = torch.randn(count).view(*dims).cuda()

    precisions = (
        (lambda t: t, 1e-6),
        (lambda t: t.double(), 1e-14),
    )
    for convert, tolerance in precisions:
        run_r2c(convert(signal), _f1, _f2, _if1, _if2, tolerance)

def test_expand():
    """Check that ``cfft.expand`` reproduces full-FFT outputs from rFFT outputs.

    For float64 CUDA inputs, each ``(real_fn, full_fn)`` pair is run on the
    same data. Both outputs of ``real_fn`` are passed through ``cfft.expand``:
    the second output with ``imag=True``, and both with ``odd=True`` when the
    input's last dims are odd (5). The results must match the corresponding
    outputs of ``full_fn(X, zeros)``.
    """
    def check(real_fn, full_fn, signal, zero_imag, **expand_kwargs):
        # Compute both spectra first, then compare component by component.
        half_re, half_im = real_fn(signal)
        full_re, full_im = full_fn(signal, zero_imag)
        expanded_re = cfft.expand(half_re, **expand_kwargs)
        assert np.allclose(expanded_re.cpu().numpy(), full_re.cpu().numpy())
        expanded_im = cfft.expand(half_im, imag=True, **expand_kwargs)
        assert np.allclose(expanded_im.cpu().numpy(), full_im.cpu().numpy())

    # Even-sized trailing dims (4x4): expand() with no `odd` argument.
    even_signal = torch.randn(2, 2, 4, 4).cuda().double()
    even_zeros = torch.zeros(2, 2, 4, 4).cuda().double()
    check(cfft.rfft2, cfft.fft2, even_signal, even_zeros)
    check(cfft.rfft3, cfft.fft3, even_signal, even_zeros)

    # Odd-sized trailing dims (5x5): expand() needs odd=True.
    odd_signal = torch.randn(2, 2, 5, 5).cuda().double()
    odd_zeros = torch.zeros(2, 2, 5, 5).cuda().double()
    check(cfft.rfft3, cfft.fft3, odd_signal, odd_zeros, odd=True)

def create_real_var(*args):
    """Build gradcheck input for a real-input transform.

    Returns a 1-tuple holding a float64 CUDA tensor of shape ``args``. The
    tensor holds ``torch.randn`` samples and has ``requires_grad=True``.
    """
    samples = torch.randn(*args).double().cuda()
    leaf = torch.autograd.Variable(samples, requires_grad=True)
    return (leaf,)

def create_complex_var(*args):
    """Build gradcheck inputs for a two-component (complex) transform.

    Returns a 2-tuple of independent float64 CUDA tensors of shape ``args``.
    Each holds ``torch.randn`` samples and has ``requires_grad=True``. The
    first tensor is drawn before the second.
    """
    return tuple(
        torch.autograd.Variable(torch.randn(*args).double().cuda(), requires_grad=True)
        for _ in range(2)
    )

def test_fft_gradcheck():
    """Gradient-check ``afft.Fft`` on two random float64 CUDA tensors of shape (5, 10)."""
    inputs = create_complex_var(5, 10)
    transform = afft.Fft()
    assert torch.autograd.gradcheck(transform, inputs)

def test_ifft_gradcheck():
    """Gradient-check ``afft.Ifft`` on two random float64 CUDA tensors of shape (5, 10)."""
    inputs = create_complex_var(5, 10)
    transform = afft.Ifft()
    assert torch.autograd.gradcheck(transform, inputs)

def test_fft2d_gradcheck():
    """Gradient-check ``afft.Fft2d`` on two random float64 CUDA tensors of shape (5, 5, 5)."""
    inputs = create_complex_var(5, 5, 5)
    transform = afft.Fft2d()
    assert torch.autograd.gradcheck(transform, inputs)

def test_ifft2d_gradcheck():
    """Gradient-check ``afft.Ifft2d`` on two random float64 CUDA tensors of shape (5, 5, 5)."""
    inputs = create_complex_var(5, 5, 5)
    transform = afft.Ifft2d()
    assert torch.autograd.gradcheck(transform, inputs)

def test_fft3d_gradcheck():
    """Gradient-check ``afft.Fft3d`` on two random float64 CUDA tensors of shape (5, 3, 3, 3)."""
    inputs = create_complex_var(5, 3, 3, 3)
    transform = afft.Fft3d()
    assert torch.autograd.gradcheck(transform, inputs)

def test_ifft3d_gradcheck():
    """Gradient-check ``afft.Ifft3d`` on two random float64 CUDA tensors of shape (5, 3, 3, 3)."""
    inputs = create_complex_var(5, 3, 3, 3)
    transform = afft.Ifft3d()
    assert torch.autograd.gradcheck(transform, inputs)

def test_rfft_gradcheck():
    """Gradient-check ``afft.Rfft`` for an even (10) and an odd (11) last dimension."""
    for shape in ((5, 10), (5, 11)):
        inputs = create_real_var(*shape)
        assert torch.autograd.gradcheck(afft.Rfft(), inputs)

def test_rfft2d_gradcheck():
    """Gradient-check ``afft.Rfft2d`` on even (6x6) and odd (5x5) trailing dimensions."""
    for shape in ((5, 6, 6), (5, 5, 5)):
        inputs = create_real_var(*shape)
        assert torch.autograd.gradcheck(afft.Rfft2d(), inputs)

def test_rfft3d_gradcheck():
    """Gradient-check ``afft.Rfft3d`` on even (4x4x4) and odd (3x3x3) trailing dimensions."""
    for shape in ((5, 4, 4, 4), (5, 3, 3, 3)):
        inputs = create_real_var(*shape)
        assert torch.autograd.gradcheck(afft.Rfft3d(), inputs)

def test_irfft_gradcheck():
    """Gradient-check ``afft.Irfft`` on two random float64 CUDA tensors of shape (5, 11)."""
    inputs = create_complex_var(5, 11)
    transform = afft.Irfft()
    assert torch.autograd.gradcheck(transform, inputs)

def test_irfft2d_gradcheck():
    """Gradient-check ``afft.Irfft2d`` on two random float64 CUDA tensors of shape (5, 5, 5)."""
    inputs = create_complex_var(5, 5, 5)
    transform = afft.Irfft2d()
    assert torch.autograd.gradcheck(transform, inputs)

def test_irfft3d_gradcheck():
    """Gradient-check ``afft.Irfft3d`` on two random float64 CUDA tensors of shape (5, 3, 3, 3)."""
    inputs = create_complex_var(5, 3, 3, 3)
    transform = afft.Irfft3d()
    assert torch.autograd.gradcheck(transform, inputs)

if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("Cuda not available, cannot test.")
    else:
        def _on_axes_123(np_fn):
            # The NumPy n-D references must act on axes 1-3 of the 4-D test
            # tensors, leaving axis 0 (batch) untouched.
            return lambda arr: np_fn(arr, axes=(1, 2, 3))

        # Each row: (transform under test, NumPy reference,
        #            inverse under test, NumPy inverse reference).
        c2c_cases = [
            (cfft.fft, nfft.fft, cfft.ifft, nfft.ifft),
            (cfft.fft2, nfft.fft2, cfft.ifft2, nfft.ifft2),
            (cfft.fft3, _on_axes_123(nfft.fftn),
             cfft.ifft3, _on_axes_123(nfft.ifftn)),
        ]
        for case in c2c_cases:
            test_c2c(*case)

        r2c_cases = [
            (cfft.rfft, nfft.rfft, cfft.irfft, nfft.irfft),
            (cfft.rfft2, nfft.rfft2, cfft.irfft2, nfft.irfft2),
            (cfft.rfft3, _on_axes_123(nfft.rfftn),
             cfft.irfft3, _on_axes_123(nfft.irfftn)),
        ]
        for case in r2c_cases:
            test_r2c(*case)

        # Argument-free checks, run in a fixed order.
        remaining_checks = (
            test_expand,
            test_fft_gradcheck,
            test_ifft_gradcheck,
            test_fft2d_gradcheck,
            test_ifft2d_gradcheck,
            test_fft3d_gradcheck,
            test_ifft3d_gradcheck,
            test_rfft_gradcheck,
            test_irfft_gradcheck,
            test_rfft2d_gradcheck,
            test_irfft2d_gradcheck,
            test_rfft3d_gradcheck,
            test_irfft3d_gradcheck,
        )
        for check in remaining_checks:
            check()