Python numpy.fft.fftn() Examples

The following are 26 code examples showing how to use numpy.fft.fftn(). These examples are extracted from open source projects. You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example.

You may check out the related API usage on the sidebar.

You may also want to check out all available functions/classes of the module numpy.fft, or try the search function.

Example 1
Project: AcousticNLOS   Author: computational-imaging   File: AcousticNLOSReconstruction.py    License: MIT License 6 votes vote down vote up
def run_lct(self, meas, X, Y, Z, S, x_max, y_max):
        """Reconstruct a volume from `meas` by Wiener-style deconvolution with an LCT PSF.

        Steps visible here: apply `mtx` from `lct.interpMtx` along the first
        axis of the measurement (presumably a resampling matrix -- TODO
        confirm), zero-pad, divide out the PSF in the Fourier domain with a
        filter regularised by ``1 / self.snr``, crop, take the magnitude, and
        apply `mtxi` along the first axis.

        Parameters
        ----------
        meas : ndarray
            Measurement; reshaped to ``(Z, -1)`` (presumably laid out as
            ``(Z, Y, X)`` -- TODO confirm).
        X, Y, Z : int
            Grid sizes passed to `lct.getPSF` and used for padding/cropping.
        S : int
            Factor passed to `lct.getPSF` / `lct.interpMtx`; the cropped volume
            is reshaped to ``(S*Z, -1)`` before `mtxi` is applied.
        x_max, y_max : float
            Scale factors entering the PSF slopes along x and y.

        Returns
        -------
        ndarray
            Reconstruction reshaped to ``(-1, Y, X)``.
        """
        # Maximum distance index: int(T * v / 2 / channels[0]).
        self.max_dist = int(self.T * self.v/2 / self.channels[0])
        # PSF slopes along x and y, scaled using the band edges fstart/B and fend/B.
        slope_x = self.dx * x_max / (self.fend/self.B * self.max_dist) * (1 + ((self.fstart/self.B)/(self.fend/self.B))**2)
        slope_y = self.dy * y_max / (self.fend/self.B * self.max_dist) * (1 + ((self.fstart/self.B)/(self.fend/self.B))**2)

        # params
        psf, fpsf = lct.getPSF(X, Y, Z, S, slope_x, slope_y)
        mtx, mtxi = lct.interpMtx(Z, S, self.fstart/self.B * self.max_dist, self.fend/self.B * self.max_dist)

        def pad_array(x, S, Z, X, Y):
            # Zero-pad each axis symmetrically (presumably up to the PSF size).
            return np.pad(x, ((S*Z//2, S*Z//2), (Y//2, Y//2), (X//2, X//2)), 'constant')

        def trim_array(x, S, Z, X, Y):
            # Crop the padding off; each window is shifted one sample past the
            # padding (+1 on both start and stop).
            # NOTE(review): `-Y//2` parses as `(-Y)//2`; for Y == 2 (or X == 2)
            # the stop index evaluates to 0 and the slice is empty.
            return x[S*int(np.floor(Z/2))+1:-S*int(np.ceil(Z/2))+1, Y//2+1:-Y//2+1, X//2+1:-X//2+1]

        # Wiener-style inverse filter: conj(H) / (|H|^2 + 1/snr).
        invpsf = np.conj(fpsf) / (abs(fpsf)**2 + 1 / self.snr)
        tdata = np.matmul(mtx, meas.reshape((Z, -1))).reshape((-1, Y, X))

        fdata = fftn(pad_array(tdata, S, Z, X, Y))
        tvol = abs(trim_array(ifftn(fdata * invpsf), S, Z, X, Y))
        out = np.matmul(mtxi, tvol.reshape((S*Z, -1))).reshape((-1, Y, X))

        return out
Example 2
Project: AcousticNLOS   Author: computational-imaging   File: lct.py    License: MIT License 6 votes vote down vote up
def getPSF(X, Y, Z, S, slope_x, slope_y=None):
    """Build the discrete light-cone PSF and its n-dimensional FFT.

    The PSF is sampled on a ``(2*S*Z, 2*Y, 2*X)`` grid with ``x`` and ``y``
    in ``[-1, 1]`` and ``z`` in ``[0, 2]``. In every ``(y, x)`` column, the
    ``z`` sample(s) minimising ``|slope_x**2 x**2 + slope_y**2 y**2 - z|`` are
    set to one (ties are all kept) and every other sample to zero. The result
    is normalised to unit sum and circularly shifted by ``Y`` along y and
    ``X`` along x (half the grid), presumably so that the cone apex sits at
    the FFT origin.

    Parameters
    ----------
    X, Y, Z : int
        Base grid sizes; the PSF grid doubles each of them (and scales ``Z``
        by ``S``).
    S : int
        Oversampling factor along z.
    slope_x : float
        Cone slope along x.
    slope_y : float, optional
        Cone slope along y. Defaults to `slope_x` (an isotropic cone), which
        also supports the single-slope call ``getPSF(X, Y, Z, S, slope)``.

    Returns
    -------
    psf : ndarray of float32, shape ``(2*S*Z, 2*Y, 2*X)``
        Normalised, shifted PSF.
    fpsf : complex ndarray
        ``fftn(psf)``.
    """
    if slope_y is None:
        slope_y = slope_x

    x = np.linspace(-1, 1, 2 * X)
    y = np.linspace(-1, 1, 2 * Y)
    z = np.linspace(0, 2, 2 * S * Z)
    grid_z, grid_y, grid_x = np.meshgrid(z, y, x, indexing='ij')
    psf = np.abs(slope_x**2 * grid_x**2 + slope_y**2 * grid_y**2 - grid_z)

    # Keep the z sample(s) closest to the cone surface in every (y, x) column.
    # keepdims=True lets the per-column minimum broadcast along z, so there is
    # no need to materialise a tiled copy of the full grid.
    psf = (psf == np.min(psf, axis=0, keepdims=True)).astype(np.float32)
    psf = psf / np.sum(psf)

    # Circular shift by half the lateral grid (Y along y, X along x).
    psf = np.roll(psf, (Y, X), axis=(1, 2))
    fpsf = fftn(psf)

    return psf, fpsf
Example 3
Project: AcousticNLOS   Author: computational-imaging   File: lct.py    License: MIT License 6 votes vote down vote up
def lct(x1, y1, t1, v, vol, snr):
    """Reconstruct a volume with the light-cone transform (LCT).

    The measurement is mapped along its first axis with the matrix from
    `interpMtx`, zero-padded, deconvolved in the Fourier domain with a
    Wiener-style filter built from the `getPSF` PSF, cropped, and mapped back
    along the first axis with the inverse matrix.

    Parameters
    ----------
    x1, y1 : array_like
        Lateral sample positions; only their lengths and maxima are used.
    t1 : array_like
        Sample times (presumably); ``len(t1)`` gives the number of bins ``Z``.
    v : float
        Speed used to convert ``t1`` into distance (presumably the
        propagation speed).
    vol : array_like
        Measurement whose first axis has length ``Z``. `getPSF` returns a
        ``(2*S*Z, 2*Y, 2*X)`` PSF, so `vol` must be laid out as ``(Z, Y, X)``
        with even ``X`` and ``Y`` for the Fourier-domain product to broadcast.
    snr : float
        Signal-to-noise ratio regularising the inverse filter.

    Returns
    -------
    ndarray
        Reconstruction with first axis of length ``mtxi.shape[0]`` and the
        remaining axes of `vol`.
    """
    vol = np.asarray(vol)
    X = len(x1)
    Y = len(y1)
    Z = len(t1)
    S = 2
    # One slope per lateral axis. Previously the x slope was computed and then
    # overwritten by the y slope before use, although getPSF takes both.
    slope_x = np.max(x1) / (np.max(t1) * v / 2)
    slope_y = np.max(y1) / (np.max(t1) * v / 2)
    psf, fpsf = getPSF(X, Y, Z, S, slope_x, slope_y)
    mtx, mtxi = interpMtx(Z, S, 0, np.max(t1) * v)

    def pad_array(x):
        # Zero-pad an (S*Z, Y, X) volume towards the (2*S*Z, 2*Y, 2*X) PSF size,
        # in the same (z, y, x) axis order as the PSF.
        return np.pad(x, ((S*Z//2, S*Z//2), (Y//2, Y//2), (X//2, X//2)), 'constant')

    def trim_array(x):
        # Crop a window of the unpadded size, starting one sample past the
        # padding. Explicit stop indices avoid the empty slice that negative
        # stops such as ``-Y//2+1`` produce when they evaluate to 0, and
        # keep S*Z time samples for odd Z as well.
        t0 = S * int(np.floor(Z / 2)) + 1
        y0 = Y // 2 + 1
        x0 = X // 2 + 1
        return x[t0:t0 + S*Z, y0:y0 + Y, x0:x0 + X]

    # Wiener-style inverse filter: conj(H) / (|H|^2 + 1/snr).
    invpsf = np.conj(fpsf) / (abs(fpsf)**2 + 1 / snr)
    # Apply mtx along the first axis only, treating all other axes as columns.
    tdata = np.matmul(mtx, vol.reshape((Z, -1))).reshape((-1,) + vol.shape[1:])
    fdata = fftn(pad_array(tdata))
    tvol = abs(trim_array(ifftn(fdata * invpsf)))
    out = np.matmul(mtxi, tvol.reshape((S*Z, -1))).reshape((-1,) + tvol.shape[1:])
    return out
Example 4
Project: sporco   Author: bwohlberg   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def fl2norm2(xf, axis=(0, 1)):
    r"""Compute the squared :math:`\ell_2` norm in the DFT domain.

    Compute the squared :math:`\ell_2` norm in the DFT domain, taking
    into account the unnormalised DFT scaling, i.e. given the DFT of a
    multi-dimensional array computed via :func:`fftn`, return the
    squared :math:`\ell_2` norm of the original array.

    Parameters
    ----------
    xf : array_like
      Input array
    axis : int or sequence of ints, optional (default (0,1))
      Axis or axes on which the input is in the frequency domain

    Returns
    -------
    x : float
      :math:`\|\mathbf{x}\|_2^2` where the input array is the result of
      applying :func:`fftn` to the specified axes of multi-dimensional
      array :math:`\mathbf{x}`
    """

    xfs = np.shape(xf)
    # By Parseval, the unnormalised DFT scales the squared norm by the number
    # of transformed samples. atleast_1d also accepts a single int axis.
    nsamples = np.prod([xfs[k] for k in np.atleast_1d(axis)])
    return (np.linalg.norm(xf)**2) / nsamples
Example 5
Project: bifrost   Author: ledatelescope   File: test_fft.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def run_test_c2c_impl(self, shape, axes, inverse=False, fftshift=False):
        """Check a bifrost complex-to-complex FFT against the gold reference.

        Random complex64 data of the requested `shape` is transformed on the
        GPU with `Fft` (forward or inverse, optionally fftshifted). The result
        is then compared with ``gold_fftn`` / ``gold_ifftn`` computed on the
        host.
        """
        shape = list(shape)
        shape[-1] *= 2  # float32 pairs are viewed as complex64 below
        known_data = np.random.normal(size=shape).astype(np.float32).view(np.complex64)
        idata = bf.ndarray(known_data, space='cuda')
        odata = bf.empty_like(idata)
        fft = Fft()
        fft.init(idata, odata, axes=axes, apply_fftshift=fftshift)
        fft.execute(idata, odata, inverse)
        if inverse:
            if fftshift:
                known_data = np.fft.ifftshift(known_data, axes=axes)
            # Note: Numpy applies normalization while CUFFT does not
            norm = reduce(lambda a, b: a * b, [known_data.shape[d] for d in axes])
            known_result = gold_ifftn(known_data, axes=axes) * norm
        else:
            known_result = gold_fftn(known_data, axes=axes)
            if fftshift:
                known_result = np.fft.fftshift(known_result, axes=axes)
        # Unused diagnostic locals (an element-wise relative-error array and
        # two aliases) were removed; compare() performs the actual check.
        compare(odata.copy('system'), known_result)
Example 6
Project: pysteps   Author: pySTEPS   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def get_numpy(shape, fftn_shape=None, **kwargs):
    """Return a namespace of ``numpy.fft`` routines.

    ``irfft2`` is bound to the output shape `shape`. ``fftn`` is included only
    when `fftn_shape` is not ``None`` (its value is otherwise unused). Extra
    keyword arguments are accepted and ignored.
    """
    import numpy.fft as numpy_fft

    def irfft2(X):
        return numpy_fft.irfft2(X, s=shape)

    functions = dict(
        fft2=numpy_fft.fft2,
        ifft2=numpy_fft.ifft2,
        rfft2=numpy_fft.rfft2,
        irfft2=irfft2,
        fftshift=numpy_fft.fftshift,
        ifftshift=numpy_fft.ifftshift,
        fftfreq=numpy_fft.fftfreq,
    )
    if fftn_shape is not None:
        functions["fftn"] = numpy_fft.fftn

    return SimpleNamespace(**functions)
Example 7
Project: pysteps   Author: pySTEPS   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 6 votes vote down vote up
def get_scipy(shape, fftn_shape=None, **kwargs):
    """Return a namespace of ``scipy.fftpack`` routines.

    ``rfft2``/``irfft2`` come from ``numpy.fft`` because ``scipy.fftpack``
    does not provide them; ``irfft2`` is bound to the output shape `shape`.
    ``fftn`` is included only when `fftn_shape` is not ``None``. Extra
    keyword arguments are accepted and ignored.
    """
    import numpy.fft as numpy_fft
    import scipy.fftpack as scipy_fft

    def irfft2(X):
        return numpy_fft.irfft2(X, s=shape)

    functions = dict(
        fft2=scipy_fft.fft2,
        ifft2=scipy_fft.ifft2,
        rfft2=numpy_fft.rfft2,
        irfft2=irfft2,
        fftshift=scipy_fft.fftshift,
        ifftshift=scipy_fft.ifftshift,
        fftfreq=scipy_fft.fftfreq,
    )
    if fftn_shape is not None:
        functions["fftn"] = scipy_fft.fftn

    return SimpleNamespace(**functions)
Example 8
Project: ocelot   Author: ocelot-collab   File: sc.py    License: GNU General Public License v3.0 5 votes vote down vote up
def potential(self, q, steps):
        """Compute a potential from the grid `q` by FFT-based convolution.

        `q` is zero-padded into a ``(2*Nx-1, 2*Ny-1, 2*Nz-1)`` array and
        circularly convolved with a kernel of the same size, built from
        ``self.sym_kernel(q.shape, steps)`` by mirroring along z, y and x.
        The first ``(Nx, Ny, Nz)`` samples of the result are divided by
        ``4*pi*epsilon_0*hx*hy*hz`` and returned.

        When ``pyfftw_flag`` is set, the FFTs use ``pyfftw.builders`` with
        ``conf.OCELOT_NUM_THREADS`` threads (at least 1). Otherwise
        ``fftn``/``ifftn`` are used. The FFT wall time is logged at debug level.

        Parameters
        ----------
        q : ndarray
            3-D grid (presumably charge density -- TODO confirm units).
        steps : sequence of float
            Grid spacings ``(hx, hy, hz)``.

        Returns
        -------
        ndarray
            Array with the shape of `q`.
        """
        hx = steps[0]
        hy = steps[1]
        hz = steps[2]
        Nx = q.shape[0]
        Ny = q.shape[1]
        Nz = q.shape[2]
        # Padding to 2N-1 per axis makes the circular convolution below equal
        # the aperiodic convolution on the first N samples.
        out = np.zeros((2*Nx-1, 2*Ny-1, 2*Nz-1))
        out[:Nx, :Ny, :Nz] = q
        K1 = self.sym_kernel(q.shape, steps)
        K2 = np.zeros((2*Nx-1, 2*Ny-1, 2*Nz-1))
        K2[0:Nx, 0:Ny, 0:Nz] = K1
        # Mirror the kernel so that K2[k] == K2[-k] under periodic wrap.
        # Order matters: each step copies slabs already filled by the previous one.
        K2[0:Nx, 0:Ny, Nz:2*Nz-1] = K2[0:Nx, 0:Ny, Nz-1:0:-1] #z-mirror
        K2[0:Nx, Ny:2*Ny-1,:] = K2[0:Nx, Ny-1:0:-1, :]        #y-mirror
        K2[Nx:2*Nx-1, :, :] = K2[Nx-1:0:-1, :, :]             #x-mirror
        t0 = time.time()
        if pyfftw_flag:
            nthreads = int(conf.OCELOT_NUM_THREADS)
            if nthreads < 1:
                nthreads = 1
            K2_fft = pyfftw.builders.fftn(K2, axes=None, overwrite_input=False, planner_effort='FFTW_ESTIMATE',
                                       threads=nthreads, auto_align_input=False, auto_contiguous=False, avoid_copy=True)
            out_fft = pyfftw.builders.fftn(out, axes=None, overwrite_input=False, planner_effort='FFTW_ESTIMATE',
                                          threads=nthreads, auto_align_input=False, auto_contiguous=False, avoid_copy=True)
            # Builders return callables; the forward FFTs execute inside this argument.
            out_ifft = pyfftw.builders.ifftn(out_fft()*K2_fft(), axes=None, overwrite_input=False, planner_effort='FFTW_ESTIMATE',
                                          threads=nthreads, auto_align_input=False, auto_contiguous=False, avoid_copy=True)
            out = np.real(out_ifft())

        else:
            out = np.real(ifftn(fftn(out)*fftn(K2)))
        t1 = time.time()
        logger.debug('fft time:' + str(t1-t0) + ' sec')
        out[:Nx, :Ny, :Nz] = out[:Nx,:Ny,:Nz]/(4*pi*epsilon_0*hx*hy*hz)
        return out[:Nx, :Ny, :Nz]
Example 9
Project: Computable   Author: ktraunmueller   File: bench_basic.py    License: MIT License 5 votes vote down vote up
def bench_random(self):
        """Print a timing table for the module-level ``fftn`` versus numpy's.

        The module-level ``fftn`` is presumably scipy's, per the table header.
        For each 2-D size, one real (double) and one complex (cdouble) random
        input are transformed. numpy's result is checked against the
        module-level result, and both are timed with ``measure`` over
        ``repeat`` calls.
        """
        from numpy.fft import fftn as numpy_fftn
        print()
        print('    Multi-dimensional Fast Fourier Transform')
        print('===================================================')
        print('          |    real input     |   complex input    ')
        print('---------------------------------------------------')
        print('   size   |  scipy  |  numpy  |  scipy  |  numpy ')
        print('---------------------------------------------------')
        for size, repeat in [((100, 100), 100), ((1000, 100), 7),
                             ((256, 256), 10),
                             ((512, 512), 3),
                             ]:
            print('%9s' % ('%sx%s' % size), end=' ')
            sys.stdout.flush()

            for x in [random(size).astype(double),
                      random(size).astype(cdouble)+random(size).astype(cdouble)*1j
                      ]:
                # Reference result for the numpy cross-check below. Comparing a
                # second fftn(x) call against y (as before) could never fail.
                y = fftn(x)
                print('|%8.2f' % measure('fftn(x)', repeat), end=' ')
                sys.stdout.flush()

                assert_array_almost_equal(numpy_fftn(x), y)
                print('|%8.2f' % measure('numpy_fftn(x)', repeat), end=' ')
                sys.stdout.flush()

            print(' (secs for %s calls)' % (repeat))

        sys.stdout.flush()
Example 10
Project: ProxImaL   Author: comp-imaging   File: utils.py    License: MIT License 5 votes vote down vote up
def fftd(I, dims=None):
    """Fourier-transform the leading `dims` axes of `I`.

    With ``dims=None`` every axis is transformed. ``dims == 2`` uses ``fft2``
    over axes ``(0, 1)``. Any other value transforms axes ``0 .. dims-1``
    with ``fftn``.
    """
    if dims is None:
        return fftn(I)
    if dims == 2:
        return fft2(I, axes=(0, 1))
    leading_axes = tuple(range(dims))
    return fftn(I, axes=leading_axes)
Example 11
Project: dl-cs   Author: MRSRL   File: fftc.py    License: MIT License 5 votes vote down vote up
def fftnc(x, axes, ortho=True):
    """Shifted n-dimensional FFT over `axes`.

    Computes ``ifftshift(fftn(fftshift(x)))`` over `axes`, with orthonormal
    scaling when `ortho` is true.

    NOTE(review): this equals the usual centred DFT
    ``fftshift(fftn(ifftshift(x)))`` only when every transformed axis has even
    length. It is kept as-is because a matching inverse elsewhere may rely on
    this shift order.
    """
    norm = "ortho" if ortho else None
    shifted = fft.fftshift(x, axes=axes)
    spectrum = fft.fftn(shifted, axes=axes, norm=norm)
    return fft.ifftshift(spectrum, axes=axes)
Example 12
Project: aitom   Author: xulabs   File: band_pass.py    License: GNU General Public License v3.0 5 votes vote down vote up
def filter_given_curve(v, curve):
    """Apply a radially symmetric Fourier-space filter to `v`.

    Each coefficient of ``fftshift(fftn(v))`` whose rounded radius (from
    ``GV.grid_distance_to_center`` on displacements from
    ``GV.fft_mid_co(v.shape)``, presumably the centred zero frequency) equals
    ``i`` is multiplied by ``curve[i]``. Coefficients with radius
    ``>= len(curve)`` are zeroed.

    Returns
    -------
    ndarray
        Real part of the inverse transform of the filtered spectrum.
    """
    grid = GV.grid_displacement_to_center(v.shape, GV.fft_mid_co(v.shape))
    rad = GV.grid_distance_to_center(grid)
    # Builtin int: the N.int alias was removed in NumPy 1.24 (AttributeError).
    rad = N.round(rad).astype(int)
    b = N.zeros(rad.shape)
    for i, a in enumerate(curve):
        b[rad == i] = a
    vf = ifftn(ifftshift(fftshift(fftn(v)) * b))
    vf = N.real(vf)
    return vf
Example 13
Project: aitom   Author: xulabs   File: ssnr2d.py    License: GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, images, band_width_radius=1.0):
        """Store centred FFTs of `images` and set up radial bookkeeping.

        Parameters
        ----------
        images : sequence of array_like
            Images indexed ``0 .. len(images)-1``. Each is stored as
            ``fftshift(fftn(image))``, and the shape of the last one is
            recorded in ``self.img_siz``.
        band_width_radius : float, optional
            Stored as ``self.band_width_radius``.

        Raises
        ------
        ValueError
            If `images` is empty. Previously this surfaced as an
            UnboundLocalError on the loop variable.
        """
        if len(images) == 0:
            raise ValueError('images must not be empty')

        im_f = {}
        for k in range(len(images)):
            im_f[k] = fftshift(fftn(images[k]))

        self.im_f = im_f        # fft transformed images
        self.ks = set()
        self.img_siz = im_f[k].shape    # k is the last index after the loop
        self.set_fft_mid_co()
        self.set_rad()
        self.band_width_radius = band_width_radius
Example 14
Project: aitom   Author: xulabs   File: ssnr3d.py    License: GNU General Public License v3.0 5 votes vote down vote up
def __init__(self, images, masks, band_width_radius=1.0):
        """Store centred FFTs of `images`, the `masks`, and radial bookkeeping.

        Parameters
        ----------
        images : mapping
            Iterated by key; each ``images[k]`` is stored as
            ``fftshift(fftn(images[k]))`` under the same key. The shape of the
            entry for the last key iterated is recorded in ``self.img_siz``.
        masks : object
            Stored unchanged as ``self.ms``.
        band_width_radius : float, optional
            Stored as ``self.band_width_radius``.

        Raises
        ------
        ValueError
            If `images` is empty. Previously this surfaced as an
            UnboundLocalError on the loop variable.
        """
        if len(images) == 0:
            raise ValueError('images must not be empty')

        im_f = {}
        for k in images:
            im_f[k] = fftshift(fftn(images[k]))

        self.im_f = im_f        # fft transformed images
        self.ms = masks         # masks
        self.ks = set()
        self.img_siz = im_f[k].shape    # k is the last key after the loop
        self.set_fft_mid_co()
        self.set_rad()
        self.band_width_radius = band_width_radius
Example 15
Project: aitom   Author: xulabs   File: faml.py    License: GNU General Public License v3.0 5 votes vote down vote up
def fourier_transform(v):
    """Return the n-dimensional FFT of `v` with zero frequency shifted to the centre."""
    spectrum = fftn(v)
    return fftshift(spectrum)
Example 16
Project: aitom   Author: xulabs   File: vol.py    License: GNU General Public License v3.0 5 votes vote down vote up
def fsc(v1, v2, band_width_radius=1.0):
    """Compute the Fourier shell correlation between two 3-D volumes.

    Both volumes are transformed with ``fftshift(fftn(.))``. For every integer
    radius ``r`` in ``0 .. floor(min(shape) / 2)``, the coefficients whose
    distance from ``GV.fft_mid_co(shape)`` lies within `band_width_radius` of
    ``r`` are correlated::

        FSC(r) = Re( sum(F1 * conj(F2)) / sqrt(sum|F1|^2 * sum|F2|^2) )

    A shell that selects no coefficients (or only zeros) gives ``nan``.

    Parameters
    ----------
    v1, v2 : ndarray
        3-D volumes of identical shape (checked with ``assert``).
    band_width_radius : float, optional
        Half-width of each shell.

    Returns
    -------
    ndarray, shape ``(floor(min(shape) / 2) + 1,)``
        FSC value per radius.
    """
    siz = v1.shape
    assert siz == v2.shape

    origin_co = GV.fft_mid_co(siz)

    x = N.mgrid[0:siz[0], 0:siz[1], 0:siz[2]]
    # Builtin float: the N.float alias was removed in NumPy 1.24 (AttributeError).
    x = x.astype(float)

    for dim_i in range(3):
        x[dim_i] -= origin_co[dim_i]

    rad = N.sqrt(N.square(x).sum(axis=0))

    vol_rad = int(N.floor(N.min(siz) / 2.0) + 1)

    v1f = NF.fftshift(NF.fftn(v1))
    v2f = NF.fftshift(NF.fftn(v2))

    fsc_cors = N.zeros(vol_rad)

    # the interpolation can also be performed using scipy.ndimage.interpolation.map_coordinates()
    for r in range(vol_rad):
        ind = abs(rad - r) <= band_width_radius

        c1 = v1f[ind]
        c2 = v2f[ind]

        fsc_cor_t = N.sum(c1 * N.conj(c2)) / N.sqrt(N.sum(N.abs(c1)**2) * N.sum(N.abs(c2)**2))
        fsc_cors[r] = N.real(fsc_cor_t)

    return fsc_cors
Example 17
Project: aitom   Author: xulabs   File: gaussian.py    License: GNU General Public License v3.0 5 votes vote down vote up
def dog_smooth__large_map(v, s1, s2=None):
    """Smooth a 3-D volume with a difference-of-Gaussians (DoG) filter via FFT.

    `v` is reflect-padded by ``round(2*s2)`` voxels on every side, and its FFT
    is multiplied by the conjugate FFT of a DoG kernel. The real part of the
    inverse FFT is then cropped back to ``v.shape`` and returned as float32.
    Intermediates are cast to complex64 and released with ``del`` +
    ``GC.collect()`` as soon as they are no longer needed, presumably to
    limit peak memory on large maps.

    Parameters
    ----------
    v : ndarray
        3-D volume (the crop below indexes three axes).
    s1 : float
        First sigma passed to `difference_of_gauss_function`.
    s2 : float, optional
        Second sigma; defaults to ``1.1 * s1``. Must exceed `s1` (asserted).

    Returns
    -------
    ndarray of float32 with shape ``v.shape``.
    """

    if s2 is None:      s2 = s1 * 1.1       # the 1.1 is according to a DoG particle picking paper
    assert      s1 < s2

    size = v.shape


    pad_width = int(N.round(s2*2))
    vp = N.pad(array=v, pad_width=pad_width, mode='reflect')

    v_fft = fftn(vp).astype(N.complex64)
    # Only drops this function's reference; the caller's array is unaffected.
    del v;      GC.collect()


    # Kernel of round(4*s2) voxels per axis, pasted into a zero volume of the padded shape.
    g_small = difference_of_gauss_function(size=N.array([int(N.round(s2 * 4))]*3), sigma1=s1, sigma2=s2)
    assert      N.all(N.array(g_small.shape) <= N.array(vp.shape))       # make sure we can use CV.paste_to_whole_map()

    g = N.zeros(vp.shape)
    paste_to_whole_map(whole_map=g, vol=g_small, c=None)

    # Multiplying by the conjugate spectrum is a correlation; it equals a
    # convolution when the kernel is symmetric (presumably true for a DoG).
    g_fft_conj = N.conj(   fftn(ifftshift(g)).astype(N.complex64)   )    # use ifftshift(g) to move center of gaussian to origin
    del g;      GC.collect()

    prod_t = (v_fft * g_fft_conj).astype(N.complex64)
    del v_fft;      GC.collect()
    del g_fft_conj;      GC.collect()

    prod_t_ifft = ifftn( prod_t ).astype(N.complex64)
    del prod_t;      GC.collect()

    v_conv = N.real( prod_t_ifft )
    del prod_t_ifft;      GC.collect()
    v_conv = v_conv.astype(N.float32)

    # NOTE(review): the crop starts at pad_width+1 rather than pad_width;
    # presumably this compensates a one-voxel centre offset from the kernel
    # paste/ifftshift convention -- confirm.
    v_conv = v_conv[(pad_width+1):(pad_width+size[0]+1), (pad_width+1):(pad_width+size[1]+1), (pad_width+1):(pad_width+size[2]+1)]
    assert      size == v_conv.shape

    return v_conv
Example 18
Project: aitom   Author: xulabs   File: util.py    License: GNU General Public License v3.0 5 votes vote down vote up
def fast_rotation_align(v1, m1, v2, m2, max_l=36):
    """Return candidate rotation angles aligning `v2` to `v1`.

    Works on masked Fourier amplitudes: the masked amplitude cross-correlation
    from ``core.rot_search_cor`` is normalised by the two masked
    auto-correlation terms. Angles are then taken from
    ``core.local_max_angles`` (second argument 8).
    """
    radius = int(max(v1.shape) / 2)
    # radii must start from 1, not 0!
    radii = N.asarray(list(range(1, radius + 1)), dtype=N.float64)

    # fftshift breaks order='F', so take explicit Fortran-ordered copies.
    amp1 = abs(fftshift(fftn(v1))).copy(order='F')
    amp2 = abs(fftshift(fftn(v2))).copy(order='F')

    mask1_sq = N.square(m1)
    mask2_sq = N.square(m2)

    cor12 = core.rot_search_cor(amp1 * mask1_sq, amp2 * mask2_sq, radii, max_l)
    norm1 = N.sqrt(N.real(core.rot_search_cor(N.square(amp1) * mask1_sq, mask2_sq, radii, max_l)))
    norm2 = N.sqrt(N.real(core.rot_search_cor(mask1_sq, N.square(amp2) * mask2_sq, radii, max_l)))

    # N.real breaks order='F' by not making explicit copy.
    cors = N.real(cor12 / (norm1 * norm2)).copy(order='F')

    _cor, angs = core.local_max_angles(cors, 8)
    return angs
Example 19
Project: aitom   Author: xulabs   File: util.py    License: GNU General Public License v3.0 5 votes vote down vote up
def translation_align_given_rotation_angles(v1, m1, v2, m2, angs):
    """For each candidate rotation, find the best translation of `v2` onto `v1`.

    For every angle in `angs`, `v2` is rotated (mean-padded) and `m2` is
    rotated (zero-padded). Both spectra, with the zero-frequency term removed,
    are masked by ``m1 * rotated m2`` and normalised to unit energy. They are
    then passed (unshifted) to ``translation_align__given_unshifted_fft``.

    Returns
    -------
    list of dict
        One ``{'ang', 'loc', 'score'}`` per angle; ``loc`` and ``score`` are
        ``lc['loc']`` and ``lc['cor']`` from the translation search.
    """
    v1f = fftn(v1)
    v1f[0, 0, 0] = 0.0      # remove the zero-frequency term
    v1f = fftshift(v1f)

    a = [None] * len(angs)
    for i, ang in enumerate(angs):
        v2r = GR.rotate_pad_mean(v2, angle=ang)
        v2rf = fftn(v2r)
        v2rf[0, 0, 0] = 0.0
        v2rf = fftshift(v2rf)

        m2r = GR.rotate_pad_zero(m2, angle=ang)
        # Combine with the *rotated* mask of v2. Previously the unrotated m2
        # was used, leaving m2r computed but unused.
        m1_m2r = m1 * m2r

        # masked images
        v1fm = v1f * m1_m2r
        v2rfm = v2rf * m1_m2r

        # normalize values
        v1fmn = v1fm / N.sqrt(N.square(N.abs(v1fm)).sum())
        v2rfmn = v2rfm / N.sqrt(N.square(N.abs(v2rfm)).sum())

        lc = translation_align__given_unshifted_fft(ifftshift(v1fmn), ifftshift(v2rfmn))

        a[i] = {'ang': ang, 'loc': lc['loc'], 'score': lc['cor']}

    return a
Example 20
Project: AcousticNLOS   Author: computational-imaging   File: ADMMReconstruction.py    License: MIT License 5 votes vote down vote up
def fconv(x, otf):
    """Circularly convolve `x` with the kernel whose FFT is `otf`; return the real part."""
    spectrum = fftn(x) * otf
    return np.real(ifftn(spectrum))
Example 21
Project: sporco   Author: bwohlberg   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def fftn(a, s=None, axes=None):
    """Multi-dimensional discrete Fourier transform via pyfftw.

    Thin wrapper around :func:`pyfftw.interfaces.numpy_fft.fftn` with a
    :func:`numpy.fft.fftn`-like interface. The input is never overwritten,
    and the module-level ``pyfftw_planner_effort`` and ``pyfftw_threads``
    settings are applied.

    Parameters
    ----------
    a : array_like
      Input array (can be complex)
    s : sequence of ints, optional (default None)
      Shape of the output along each transformed axis (input is cropped or
      zero-padded to match).
    axes : sequence of ints, optional (default None)
      Axes over which to compute the DFT.

    Returns
    -------
    af : complex ndarray
      DFT of input array
    """
    backend = pyfftw.interfaces.numpy_fft
    return backend.fftn(a, s=s, axes=axes, overwrite_input=False,
                        planner_effort=pyfftw_planner_effort,
                        threads=pyfftw_threads)
Example 22
Project: sporco   Author: bwohlberg   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def fftconv(a, b, axes=(0, 1), origin=None):
    """Multi-dimensional convolution via the Discrete Fourier Transform.

    Computes a circular convolution of `a` and `b` over `axes`, using real
    transforms when both inputs are real. Along each axis the transform size
    is the larger of the two input lengths. The output has a phase shift
    relative to :func:`scipy.ndimage.convolve` with its default `origin`.

    Parameters
    ----------
    a : array_like
      Input array
    b : array_like
      Input array
    axes : sequence of ints, optional (default (0, 1))
      Axes on which to perform convolution
    origin : sequence of ints or None optional (default None)
      Indices of centre of `a` filter. The default of None corresponds
      to a centre at 0 on all axes of `a`

    Returns
    -------
    ab : ndarray
      Convolution of input arrays, `a` and `b`, along specified `axes`
    """
    both_real = np.isrealobj(a) and np.isrealobj(b)
    forward, inverse = (rfftn, irfftn) if both_real else (fftn, ifftn)

    shape = np.maximum([a.shape[ax] for ax in axes],
                       [b.shape[ax] for ax in axes])
    product = forward(a, shape, axes) * forward(b, shape, axes)
    ab = inverse(product, shape, axes)

    if origin is None:
        return ab
    return np.roll(ab, -np.array(origin), axis=axes)
Example 23
Project: sporco   Author: bwohlberg   File: fft.py    License: BSD 3-Clause "New" or "Revised" License 5 votes vote down vote up
def _fftn(a, s=None, axes=None):
        """Compute ``npfft.fftn(a, s, axes)`` and cast it to ``complex_dtype(a.dtype)``.

        ``complex_dtype`` presumably maps ``a.dtype`` to the complex type of
        matching precision -- TODO confirm.
        """
        transformed = npfft.fftn(a, s, axes)
        return transformed.astype(complex_dtype(a.dtype))
Example 24
Project: diffsims   Author: pyxem   File: fourier_transform.py    License: GNU General Public License v3.0 5 votes vote down vote up
def fftn(a, s=None, axes=None, norm=None, **_):
        """Forward n-dimensional FFT with a ``numpy.fft.fftn``-like signature.

        Forwards ``a, s, axes, norm`` positionally to the module-level
        ``_fftn`` backend (not visible here -- presumably numpy's or pyfftw's
        ``fftn``). Any other keyword arguments are accepted and ignored,
        presumably so callers can pass backend-specific options uniformly.
        """
        return _fftn(a, s, axes, norm)
Example 25
Project: diffsims   Author: pyxem   File: fourier_transform.py    License: GNU General Public License v3.0 5 votes vote down vote up
def convolve(arr1, arr2, dx=None, axes=None):
    """
    Performs a centred convolution of input arrays

    The second array is ``ifftshift``-ed before the transform, so its centre
    is treated as the origin. The result is complex, C-contiguous and aligned.

    Parameters
    ----------
    arr1, arr2 : `numpy.ndarray`
        Arrays to be convolved. If dimensions are not equal then 1s are appended
        to the lower dimensional array. Otherwise, arrays must be broadcastable.
    dx : float > 0, list of float, or `None` , optional
        Grid spacing of input arrays. A list is multiplied out. A scalar is
        raised to ``len(axes)``, or to the larger ndim when `axes` is `None`.
        default=`None` applies no scaling
    axes : tuple of ints or `None`, optional
        Choice of axes to convolve. default=`None` convolves all axes

    NOTE(review): when ``arr2.ndim > arr1.ndim`` and `axes` is `None`, the
    axes default to the *lower* ndim, so a scalar `dx` is raised to the lower
    ndim. In the opposite argument order it is raised to the higher ndim, so
    ``convolve(a, b)`` and ``convolve(b, a)`` can differ by a power of `dx`.
    """
    if arr1.ndim < arr2.ndim:
        # Put the higher-dimensional array first.
        arr1, arr2 = arr2, arr1
        if axes is None:
            axes = range(arr2.ndim)
    # Append singleton axes so that arr2 broadcasts against arr1.
    missing = arr1.ndim - arr2.ndim
    arr2 = arr2.reshape(arr2.shape + (1,) * missing)

    if dx is None:
        scale = 1
    elif isscalar(dx):
        scale = dx ** (arr1.ndim if axes is None else len(axes))
    else:
        scale = prod(dx)

    spectrum = fftn(arr1, axes=axes) * fftn(ifftshift(arr2), axes=axes)
    result = ifftn(spectrum, axes=axes) * scale
    return require(result, requirements="CA")
Example 26
Project: diffsims   Author: pyxem   File: fourier_transform.py    License: GNU General Public License v3.0 4 votes vote down vote up
def plan_fft(A, n=None, axis=None, norm=None, **_):
        """
        Return a reusable FFT "plan" bound to the array `A`.

        This fallback forwards to the module-level ``fftn``. Options that only
        make sense for the `pyfftw` backend are accepted and ignored. The
        returned plan reads `A` every time it is called, so fill the returned
        buffer in place and then call the plan.

        Parameters
        ----------
        A : `numpy.ndarray`, of dimension `d`
            Array of same shape to be input for the fft
        n : iterable or `None`, `len(n) == d`, optional
            The output shape of fft (default=`None` is same as `A.shape`)
        axis : `int`, iterable length `d`, or `None`, optional
            The axis (or axes) to transform (default=`None` is all axes)
        overwrite : `bool`, optional
            Whether the input array can be overwritten during computation
            (default=False); `pyfftw` only
        planner : {0, 1, 2, 3}, optional
            Amount of effort put into optimising Fourier transform where 0 is low
            and 3 is high (default=`1`); `pyfftw` only
        threads : `int`, `None`
            Number of threads to use (default=`None` is all threads); `pyfftw` only
        auto_align_input : `bool`, optional
            If `True` then may re-align input (default=`True`); `pyfftw` only
        auto_contiguous : `bool`, optional
            If `True` then may re-order input (default=`True`); `pyfftw` only
        avoid_copy : `bool`, optional
            If `True` then may over-write initial input (default=`False`); `pyfftw` only
        norm : {None, 'ortho'}, optional
            Indicate whether fft is normalised (default=`None`)

        Returns
        -------
        plan : function
            Zero-argument callable returning the Fourier transform of `B`,
            i.e. ``plan() == fftn(B)``
        B : `numpy.ndarray`, `A.shape`
            Buffer to modify in place before calling `plan`. Here ``B is A``.


        Example
        -------
        A = numpy.zeros((8,16))
        plan, B = plan_fft(A)

        B[:,:] = numpy.random.rand(8,16)
        numpy.fft.fftn(B) == plan()

        B = numpy.random.rand(8,16)   # rebinding B does not affect the plan
        numpy.fft.fftn(B) != plan()

        """
        def plan():
            return fftn(A, n, axis, norm)

        return plan, A