import numpy as np
from numpy.fft import fft2, ifft2, fftshift, ifftshift
from .FFTBase import FFTBase


class T2FFT(FFTBase):
    r"""
    The Fast Fourier Transform on the 2-torus.

    NOTE: the original author marked this class with "REMOVE?", i.e. it is a
    candidate for removal.

    The torus is parameterized by two cyclic variables (x, y).
    The standard domain is (x, y) in [0, 1) x [0, 1), in which case the
    Fourier basis functions are
        exp(i 2 pi xi^T (x; y))
    where xi = (p, q) in Z^2 is the spectral variable.

    The Fourier transform on the standard domain is
        \hat{f}[p, q] = int_0^1 int_0^1 f(x, y) exp(-i 2 pi (p x + q y)) dx dy

    More generally, the domain may be scaled and shifted:
        D = [l_x, u_x) x [l_y, u_y)
    Let the widths of the domain be
        alpha_x = u_x - l_x
        alpha_y = u_y - l_y
    The basis functions on D are
        exp(i 2 pi xi^T ((x - l_x) / alpha_x; (y - l_y) / alpha_y))
    and the normalized Haar measure is dx dy / (alpha_x * alpha_y)
    (in terms of the Lebesgue measure dx dy).

    So the Fourier transform on this parameterization of the torus is
        \hat{f}_pq = 1 / (alpha_x alpha_y) int_lx^ux int_ly^uy f(x, y)
                     e^{-2 pi i (p (x - l_x) / alpha_x + q (y - l_y) / alpha_y)} dx dy

    This is what the class computes, given discrete samples on the grid
        (x_i, y_j), i = 0, ..., N_x - 1; j = 0, ..., N_y - 1
        x_i = l_x + alpha_x * (i / N_x)
        y_j = l_y + alpha_y * (j / N_y)
    which is the output of
        x = np.linspace(lx, ux, N_x, endpoint=False)
        y = np.linspace(ly, uy, N_y, endpoint=False)
        X, Y = np.meshgrid(x, y)
    (with numpy's default indexing='xy', axis 0 of X and Y runs over y).

    On this grid (x_i - l_x) / alpha_x = i / N_x, so the discrete transform
    depends only on the sample indices. The bounds are stored on the instance
    but are not read by the static methods `analyze` and `synthesize`.
    """

    def __init__(self, lower_bound=(0., 0.), upper_bound=(1., 1.)):
        """
        Store the bounds of the domain D = [l_x, u_x) x [l_y, u_y).

        :param lower_bound: pair (l_x, l_y); stored as a numpy array.
        :param upper_bound: pair (u_x, u_y); stored as a numpy array.
        """
        # NOTE(review): FFTBase.__init__ is not invoked here -- confirm that
        # the base class needs no initialization.
        self.lower_bound = np.array(lower_bound)
        self.upper_bound = np.array(upper_bound)

    @staticmethod
    def analyze(f, axes=(0, 1)):
        r"""
        Compute the Fourier transform of the discretely sampled function f : T^2 -> C.

        Let N_x and N_y be the lengths of f along the two transformed axes
        and let f[k, l] denote the samples along those axes on a regular grid
        (the original documentation refers to spaces.T1.linspace for the
        per-axis grid theta_k = 2 pi k / N, k = 0, ..., N - 1).

        This function computes
            \hat{f}[p, q] = 1 / (N_x N_y) \sum_{k=0}^{N_x-1} \sum_{l=0}^{N_y-1}
                            f[k, l] e^{-2 pi i (p k / N_x + q l / N_y)}
        which, for a sufficiently band-limited f, equals the integral of
        f times the conjugate basis function with respect to the normalized
        Haar measure on T^2.

        With this normalization Parseval's identity reads
            mean(|f|^2) == sum(|\hat{f}|^2)
        over the transformed axes.

        Along each transformed axis of length N the frequencies are ordered
            -floor(N/2) <= n <= ceil(N/2) - 1
        i.e. the zero frequency sits at index floor(N/2).

        :param f: array_like of samples; must have at least the dimensions
            listed in `axes`. Other axes are left untouched (batch axes).
        :param axes: the two axes over which to transform.
        :return: complex ndarray of the same shape as f holding the centered,
            normalized Fourier coefficients.
        """
        # Accept any array_like (e.g. nested lists); fft2 does so as well,
        # but the shape lookup below needs a real ndarray.
        f = np.asarray(f)

        # Number of samples in the transformed plane: N_x * N_y.
        size = np.prod([f.shape[ax] for ax in axes])

        # numpy returns unnormalized coefficients with the zero frequency at
        # index 0; center the spectrum and apply the 1 / (N_x N_y) factor.
        f_hat = fftshift(fft2(f, axes=axes), axes=axes)
        return f_hat / size

    @staticmethod
    def synthesize(f_hat, axes=(0, 1)):
        r"""
        Compute the inverse Fourier transform; the inverse of `analyze`.

        Given centered, normalized coefficients \hat{f}[p, q] as produced by
        `analyze`, this computes the samples
            f[k, l] = \sum_p \sum_q \hat{f}[p, q] e^{2 pi i (p k / N_x + q l / N_y)}

        :param f_hat: array_like of Fourier coefficients, ordered along each
            transformed axis from frequency -floor(N/2) to ceil(N/2) - 1.
        :param axes: the two axes over which to transform.
        :return: complex ndarray of the same shape as f_hat with the samples
            of the synthesized function.
        """
        # Accept any array_like; also guarantees that the multiplication
        # below is an element-wise scaling rather than list repetition.
        f_hat = np.asarray(f_hat)

        # ifft2 divides by N_x * N_y, so undo the normalization of `analyze`
        # first, then move the zero frequency back to index 0.
        size = np.prod([f_hat.shape[ax] for ax in axes])
        unshifted = ifftshift(f_hat * size, axes=axes)
        return ifft2(unshifted, axes=axes)