import operator

import autograd.numpy as np
from autograd.extend import primitive, defvjp
from scipy import fftpack


def _centered(arr, newshape):
    """Return the center newshape portion of the array.

    This function is used by `fft_convolve`-style code to remove
    the zero padded region of the convolution.

    The window starts at ``(arr.shape - newshape + 1) // 2`` in every
    dimension. Two cases follow from that choice:

    - If a dimension of `arr` is even and the target is odd, the
      center-right pixel of `arr` (index ``n // 2``, the center used by
      `np.fft.fftshift`) becomes the center of the result. The scipy
      implementation instead uses the center-left pixel.
    - If a dimension of `arr` is odd and the target is even, the center
      pixel of `arr` ends up at the center-left position of the result.

    Following the `np.fft.fftshift` convention keeps round trips between
    standard FFT order (0 frequency/position in the bottom left) and
    center-origin order consistent with `_pad`.

    Parameters
    ----------
    arr: array
        The array to trim.
    newshape: tuple
        Target shape. It must have one entry per dimension of `arr`.

    Returns
    -------
    result: array
        A view of the central region of `arr` with shape `newshape`.

    Raises
    ------
    ValueError
        If `newshape` is larger than `arr.shape` in any dimension.
    """
    newshape = np.asarray(newshape)
    currshape = np.array(arr.shape)

    if np.any(newshape > currshape):
        msg = (
            "arr must be larger than newshape in both dimensions, received {0}, and {1}"
        )
        raise ValueError(msg.format(arr.shape, newshape))

    # ceil((curr - new) / 2): puts the extra pixel of an odd difference on the left
    startind = (currshape - newshape + 1) // 2
    endind = startind + newshape
    return arr[tuple(slice(start, end) for start, end in zip(startind, endind))]


@primitive
def fast_zero_pad(arr, pad_width):
    """Fast version of `numpy.pad` when `mode="constant"` and the fill is 0

    Instead of calling `numpy.pad`, this allocates a zero array of the
    padded shape and copies `arr` into its interior. According to the
    original author, `numpy.pad` with zeros was observed to be roughly
    1000 times slower. The function is wrapped with autograd's
    `@primitive`, so autograd does not trace its body and a custom
    gradient has to be registered for it.

    Parameters
    ----------
    arr: array
        The array to pad.
    pad_width: sequence of (before, after) pairs
        Number of zeros added before and after each axis, one pair per
        axis of `arr` (the per-axis pair form accepted by `numpy.pad`).

    Returns
    -------
    result: array
        Array with the same dtype as `arr` and shape
        ``arr.shape[k] + before_k + after_k`` along each axis. It is zero
        everywhere except the interior, which holds a copy of `arr`.
    """
    newshape = tuple(
        size + before + after
        for size, (before, after) in zip(arr.shape, pad_width)
    )

    result = np.zeros(newshape, dtype=arr.dtype)
    # `size - after` (not `-after`) so that a zero `after` still means "to the end"
    interior = tuple(
        slice(before, size - after)
        for size, (before, after) in zip(newshape, pad_width)
    )
    result[interior] = arr
    return result


def _fast_zero_pad_grad(result, arr, pad_width):
    """Build the vector-Jacobian product for `fast_zero_pad`.

    Padding only copies `arr` into the interior of a zero array, so the
    gradient with respect to `arr` is the interior of the incoming
    gradient.

    Parameters
    ----------
    result: array
        The padded output of `fast_zero_pad`. Only its shape is used.
    arr: array
        The unpadded input. It is unused and accepted only to match
        autograd's VJP signature.
    pad_width: sequence of (before, after) pairs
        The padding that produced `result`.

    Returns
    -------
    vjp: callable
        Maps a gradient shaped like `result` to the gradient for `arr`.
    """
    interior = tuple(
        slice(before, size - after)
        for size, (before, after) in zip(result.shape, pad_width)
    )

    def vjp(grad_chain):
        return grad_chain[interior]

    return vjp


# Register `_fast_zero_pad_grad` as the vector-Jacobian product (VJP) of the
# `fast_zero_pad` primitive so autograd can backpropagate through the padding.
defvjp(fast_zero_pad, _fast_zero_pad_grad)


def _pad(arr, newshape, axes=None, mode="constant", constant_values=0):
    """Pad an array to fit into newshape

    Pad `arr` to fit into `newshape`, which uses the `np.fft.fftshift`
    convention of moving the center pixel of `arr` (if `arr.shape` is odd)
    to the center-right pixel in an even shaped `newshape`. When the
    padding is odd, the extra element goes before the data.

    Parameters
    ----------
    arr: array
        The array to pad.
    newshape: tuple
        Target shape. If `axes` is None, it gives the size of every
        dimension of `arr`. Otherwise it gives the size along each axis in
        `axes`, in the same order.
    axes: int or tuple, optional
        Axes to pad. All other axes are left unpadded.
    mode: str
        Padding mode, passed to `np.pad` when the fast zero path is not used.
    constant_values: scalar
        Fill value used when `mode == "constant"`.

    Returns
    -------
    result: array
        The padded array.
    """
    if axes is None:
        newshape = np.asarray(newshape)
        currshape = np.array(arr.shape)
        dS = newshape - currshape
        startind = (dS + 1) // 2
        endind = dS - startind
        pad_width = list(zip(startind, endind))
    else:
        # only pad the axes that will be transformed
        pad_width = [(0, 0)] * len(arr.shape)
        try:
            len(axes)
        except TypeError:
            axes = [axes]
        for a, axis in enumerate(axes):
            dS = newshape[a] - arr.shape[axis]
            startind = (dS + 1) // 2
            endind = dS - startind
            pad_width[axis] = (startind, endind)

    if mode == "constant":
        if constant_values == 0:
            # fast path, which also has a registered autograd gradient
            return fast_zero_pad(arr, pad_width)
        # Previously `constant_values` was dropped here, so the array was
        # padded with 0 instead of the requested value.
        return np.pad(arr, pad_width, mode=mode, constant_values=constant_values)
    # `np.pad` rejects `constant_values` for non-constant modes
    return np.pad(arr, pad_width, mode=mode)


def _get_fft_shape(im_or_shape1, im_or_shape2, padding=3, axes=None, max=False):
    """Return the fast fft shapes for each spatial axis

    Calculate the fast fft shape for each dimension in axes.

    Parameters
    ----------
    im_or_shape1, im_or_shape2: array or tuple
        Either arrays, whose ``.shape`` is used, or shapes. Both must
        have the same number of dimensions.
    padding: int
        Extra pixels added to every combined dimension before choosing
        a fast length.
    axes: int or sequence of int, optional
        Dimensions to compute the FFT shape for. If `None`, all
        dimensions are used.
    max: bool
        If `True`, combine the shapes with their element-wise maximum
        instead of their sum.

    Returns
    -------
    shape: list of int
        FFT-friendly length for each selected dimension. The last entry
        is always even. The second-to-last entry is forced to be even
        when ``shape2[-2]`` is even.

    Raises
    ------
    ValueError
        If the two shapes have a different number of dimensions.
    """
    shape1 = np.asarray(getattr(im_or_shape1, "shape", im_or_shape1))
    shape2 = np.asarray(getattr(im_or_shape2, "shape", im_or_shape2))
    # Make sure the shapes are the same size
    if len(shape1) != len(shape2):
        msg = (
            "img1 and img2 must have the same number of dimensions, but got {0} and {1}"
        )
        raise ValueError(msg.format(len(shape1), len(shape2)))

    # Normalize `axes` *before* anything calls len() on it, so a single int works
    if axes is None:
        axes = range(len(shape1))
    else:
        try:
            len(axes)
        except TypeError:
            axes = [axes]

    if max:
        # element-wise maximum of the two shapes (not the largest dim of each)
        combined = [np.max([shape1[ax], shape2[ax]]) for ax in axes]
    else:
        combined = [shape1[ax] + shape2[ax] for ax in axes]
    shape = np.array(combined, dtype="int") + padding

    # Use the next fastest shape in each dimension
    shape = [fftpack.next_fast_len(int(s)) for s in shape]
    # autograd.numpy.fft does not currently work
    # if the last dimension is odd
    while shape[-1] % 2 != 0:
        shape[-1] = fftpack.next_fast_len(shape[-1] + 1)
    # NOTE(review): the parity of the second-to-last FFT dimension follows
    # image axis -2 of the second input, independently of `axes` — confirm
    # this is the intended coupling.
    if len(shape) > 1 and len(shape2) > 1 and shape2[-2] % 2 == 0:
        while shape[-2] % 2 != 0:
            shape[-2] = fftpack.next_fast_len(shape[-2] + 1)

    return shape


class Fourier(object):
    """An array that stores its Fourier Transform

    The `Fourier` class is used for images that will make
    use of their Fourier Transform multiple times.
    In order to prevent numerical artifacts the same image
    convolved with different images might require different
    padding, so the FFT for each different shape is stored
    in a dictionary. Its keys are ``(fft_shape, axes, all_axes)``
    tuples, where `fft` builds ``all_axes`` as
    ``tuple(range(image.ndim))``.
    """

    def __init__(self, image, image_fft=None):
        """Initialize the object

        Parameters
        ----------
        image: array
            The real space image.
        image_fft: dict, optional
            Precalculated FFTs keyed by ``(fft_shape, axes, all_axes)``
            tuples (the key format built by `fft`). The dictionary is
            used directly, not copied. If `None`, an empty cache is
            created.
        """
        self._fft = {} if image_fft is None else image_fft
        self._image = image

    @staticmethod
    def from_fft(image_fft, fft_shape, image_shape, axes=None):
        """Generate a new Fourier object from an FFT array

        If the fft of an image has been generated but not its
        real space image (for example when creating a convolution kernel),
        this method can be called to create a new `Fourier` instance
        from the k-space representation.

        Parameters
        ----------
        image_fft: array
            The FFT of the image.
        fft_shape: tuple
            "Fast" shape of the image used to generate the FFT.
            This differs from `image_fft.shape` because `np.fft.rfftn`
            keeps only half of the last transformed axis. From that half
            spectrum, `np.fft.irfftn` cannot tell whether the original
            length was even or odd, so this shape is passed to it
            explicitly.
        image_shape: tuple
            The shape of the image *before padding*.
            This will regenerate the image with the extra
            padding stripped.
        axes: int or tuple, optional
            The dimension(s) of the array that will be transformed.
            If `None`, every dimension of `image_fft` is used.

        Returns
        -------
        result: `Fourier`
            A `Fourier` object generated from the FFT.
        """
        if axes is None:
            # Use the number of dimensions. `len(image_fft)` would be the
            # size of the first axis instead.
            axes = range(len(image_fft.shape))
        else:
            try:
                iter(axes)
            except TypeError:
                # accept a single integer axis, as `fft` does
                axes = (axes,)
        all_axes = range(len(image_shape))
        image = np.fft.irfftn(image_fft, fft_shape, axes=axes)
        # Shift the center of the image from the bottom left to the center
        image = np.fft.fftshift(image, axes=axes)
        # Trim the image to remove the padding added
        # to reduce fft artifacts
        image = _centered(image, image_shape)
        key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        return Fourier(image, {key: image_fft})

    @property
    def image(self):
        """The real space image"""
        return self._image

    @property
    def shape(self):
        """The shape of the real space image"""
        return self._image.shape

    def fft(self, fft_shape, axes):
        """The FFT of an image for a given `fft_shape` along desired `axes`

        The FFT is computed the first time a given
        ``(fft_shape, axes)`` combination is requested and is cached
        for later calls.

        Parameters
        ----------
        fft_shape: tuple
            Padded ("fast") length along each axis in `axes`.
        axes: int or tuple
            The dimension(s) of the image to transform.

        Returns
        -------
        result: array
            `np.fft.rfftn` of the zero-padded, `ifftshift`-ed image.

        Raises
        ------
        ValueError
            If `fft_shape` and `axes` have different lengths and the
            result is not already cached.
        """
        try:
            iter(axes)
        except TypeError:
            axes = (axes,)
        all_axes = range(len(self.image.shape))
        fft_key = (tuple(fft_shape), tuple(axes), tuple(all_axes))

        # If this is the first time calling `fft` for this shape,
        # generate the FFT.
        if fft_key not in self._fft:
            if len(fft_shape) != len(axes):
                msg = "fft_shape and axes must have the same number of dimensions, got {0}, {1}"
                raise ValueError(msg.format(fft_shape, axes))
            image = _pad(self.image, fft_shape, axes)
            # move the image center to the origin (FFT standard order) before transforming
            self._fft[fft_key] = np.fft.rfftn(np.fft.ifftshift(image, axes), axes=axes)
        return self._fft[fft_key]

    def __len__(self):
        """Length of the first axis of the real space image"""
        return len(self.image)

    def __getitem__(self, index):
        """Index the real space image and every cached FFT with `index`.

        The cache keys are rewritten by dropping every dimension that a
        non-slice, non-`None` entry of `index` removes.

        NOTE(review): `index` is presumably meant to address only
        non-transformed axes. Slicing a transformed axis would slice
        k-space, not real space. Also, the rewritten ``all_axes`` keeps
        the original axis numbers (for example ``(1, 2)``), while `fft`
        on the result builds ``tuple(range(ndim))``, so `fft` may not find
        those cached entries. Confirm this is intended.
        """
        # Make the index a tuple
        if not hasattr(index, "__getitem__"):
            index = (index,)

        # Axes that are removed from the shape of the new object
        removed = tuple(
            n
            for n, idx in enumerate(index)
            if not isinstance(idx, slice) and idx is not None
        )

        # Create views into the fft transformed values, appropriately adjusting
        # the shapes for the new axes
        fft_kernels = {}
        for key, kernel in self._fft.items():
            fft_shape, fft_axes, all_axes = key[0], key[1], key[2]
            new_key = (
                tuple(s for s, ax in zip(fft_shape, fft_axes) if ax not in removed),
                tuple(ax for ax in fft_axes if ax not in removed),
                tuple(ax for ax in all_axes if ax not in removed),
            )
            fft_kernels[new_key] = kernel[index]
        return Fourier(self.image[index], fft_kernels)


def _kspace_operation(image1, image2, padding, op, shape, axes):
    """Combine two images in k-space using a given `operator`

    `image1` and `image2` are required to be `Fourier` objects and
    `op` should be an operator (either `operator.mul` for a convolution
    or `operator.truediv` for deconvolution). `shape` is the shape of the
    output image (`Fourier` instance).

    Parameters
    ----------
    image1, image2: `Fourier`
        The images to combine.
    padding: int
        Extra padding used when choosing the FFT shape.
    op: callable
        Binary operator applied to the two FFTs.
    shape: tuple
        Shape of the real space output image after trimming the padding.
    axes: int or tuple
        Dimension(s) to transform.

    Returns
    -------
    result: `Fourier`
        The combined image together with its cached FFT.

    Raises
    ------
    ValueError
        If the two images have a different number of dimensions.
    """
    if len(image1.shape) != len(image2.shape):
        msg = "Both images must have the same number of axes, got {0} and {1}"
        # ValueError (a subclass of Exception) matches `_get_fft_shape`
        raise ValueError(msg.format(len(image1.shape), len(image2.shape)))
    fft_shape = _get_fft_shape(image1.image, image2.image, padding, axes)
    convolved_fft = op(image1.fft(fft_shape, axes), image2.fft(fft_shape, axes))
    # The output is trimmed to the caller-supplied `shape` instead of
    # `image1.shape`. This lets callers such as `match_psfs` pick which
    # input's shape the result takes.
    return Fourier.from_fft(convolved_fft, fft_shape, shape, axes)


def match_psfs(psf1, psf2, padding=3, axes=(-2, -1)):
    """Calculate the difference kernel between two psfs

    The kernel is computed in k-space as the FFT of `psf1` divided by
    the FFT of `psf2`.

    Parameters
    ----------
    psf1: `Fourier`
        `Fourier` object representing the psf and its FFT (numerator).
    psf2: `Fourier`
        `Fourier` object representing the psf and its FFT (denominator).
    padding: int
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes: tuple or None
        Axes that contain the spatial information for the PSFs.

    Returns
    -------
    kernel: `Fourier`
        The difference kernel. It takes the shape of `psf2` if
        ``psf2.shape[0]`` is larger, otherwise the shape of `psf1`.
    """
    # NOTE(review): only the first dimension is compared. For inputs
    # shaped like (bands, H, W) that is not a spatial axis — confirm intent.
    larger = psf2 if psf1.shape[0] < psf2.shape[0] else psf1
    return _kspace_operation(
        psf1, psf2, padding, operator.truediv, larger.shape, axes=axes
    )


def convolve(image1, image2, padding=3, axes=(-2, -1)):
    """Convolve two images

    The FFTs of the two images are multiplied in k-space, and the
    result is trimmed back to the shape of `image1`.

    Parameters
    ----------
    image1: `Fourier`
        `Fourier` object representing the image and its FFT.
    image2: `Fourier`
        `Fourier` object representing the image and its FFT.
    padding: int
        Additional padding to use when generating the FFT
        to suppress artifacts.
    axes: int or tuple
        Axes that are transformed (the spatial axes).

    Returns
    -------
    result: `Fourier`
        The convolved image, with the shape of `image1`.
    """
    output_shape = image1.shape
    return _kspace_operation(
        image1, image2, padding, operator.mul, output_shape, axes=axes
    )