import numpy as np
from numpy.fft import fft, ifft, fftshift, ifftshift
from .FFTBase import FFTBase


class T1FFT(FFTBase):
    """
    The Fast Fourier Transform on the Circle / 1-Torus / 1-Sphere.
    """

    @staticmethod
    def analyze(f, axis=0):
        """
        Compute the Fourier Transform of the discretely sampled function f : T^1 -> C.

        Let f : T^1 -> C be a band-limited function on the circle.
        The samples f(theta_k) correspond to points on a regular grid on the circle, as returned by spaces.T1.linspace:
        theta_k = 2 pi k / N
        for k = 0, ..., N - 1

        This function computes
        \hat{f}_n = (1/N) \sum_{k=0}^{N-1} f(theta_k) e^{-i n theta_k}
        which, if f has band-limit less than N, is equal to:
        \hat{f}_n = \int_0^{2pi} f(theta) e^{-i n theta} dtheta / 2pi,
                  = <f(theta), e^{i n theta}>
        where dtheta / 2pi is the normalized Haar measure on T^1, and < , > denotes the inner product on Hilbert space,
        with respect to which this transform is unitary.

        The range of frequencies n is -floor(N/2) <= n <= ceil(N/2) - 1

        :param f:
        :param axis:
        :return:
        """
        # The numpy FFT returns coefficients in a different order than we want them,
        # and using a different normalization.
        fhat = fft(f, axis=axis)
        fhat = fftshift(fhat, axes=axis)
        return fhat / f.shape[axis]

    @staticmethod
    def synthesize(f_hat, axis=0):
        """
        Compute the inverse / synthesis Fourier transform of the function f_hat : Z -> C.
        The function f_hat(n) is sampled at points in a limited range -floor(N/2) <= n <= ceil(N/2) - 1

        This function returns
        f[k] = f(theta_k) = sum_{n=-floor(N/2)}^{ceil(N/2)-1} f_hat(n) exp(i n theta_k)
        where theta_k = 2 pi k / N
        for k = 0, ..., N - 1

        :param f_hat:
        :param axis:
        :return:
        """

        f_hat = ifftshift(f_hat * f_hat.shape[axis], axes=axis)
        f = ifft(f_hat, axis=axis)
        return f

    @staticmethod
    def analyze_naive(f):
        f_hat = np.zeros_like(f)
        for n in range(f.size):
            for k in range(f.size):
                theta_k = k * 2 * np.pi / f.size
                f_hat[n] += f[k] * np.exp(-1j * n * theta_k)
        return fftshift(f_hat / f.size, axes=0)