"""
This file contains the STFT function and related helper functions.
"""
import numpy as np
from math import ceil
import scipy

from scipy import signal
from numpy.fft import rfft, irfft

import string

from nara_wpe.wpe import segment_axis as segment_axis_v2


# http://stackoverflow.com/a/3153267
def roll_zeropad(a, shift, axis=None):
    """
    Shift the elements of an array along an axis, filling with zeros.

    Unlike `np.roll`, elements pushed past the end of the axis are dropped
    and the vacated positions are filled with zeros.

    Args:
        a: array_like
            Input array.
        shift: int
            Number of positions to shift. Positive values move elements
            towards higher indices, negative values towards lower indices.
        axis (int): optional,
            Axis along which to shift. If None (default), the array is
            flattened, shifted, and then reshaped to the original shape.

    Returns:
        ndarray: Output array, with the same shape as `a`. For `shift == 0`
        the (array-converted) input itself is returned.

    Examples:
        >>> x = np.arange(10)
        >>> roll_zeropad(x, 2)
        array([0, 0, 0, 1, 2, 3, 4, 5, 6, 7])
        >>> roll_zeropad(x, -2)
        array([2, 3, 4, 5, 6, 7, 8, 9, 0, 0])

        >>> x2 = np.reshape(x, (2,5))
        >>> x2
        array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9]])
        >>> roll_zeropad(x2, 1)
        array([[0, 0, 1, 2, 3],
               [4, 5, 6, 7, 8]])
        >>> roll_zeropad(x2, -2)
        array([[2, 3, 4, 5, 6],
               [7, 8, 9, 0, 0]])
        >>> roll_zeropad(x2, 1, axis=0)
        array([[0, 0, 0, 0, 0],
               [0, 1, 2, 3, 4]])
        >>> roll_zeropad(x2, -1, axis=0)
        array([[5, 6, 7, 8, 9],
               [0, 0, 0, 0, 0]])
        >>> roll_zeropad(x2, 1, axis=1)
        array([[0, 0, 1, 2, 3],
               [0, 5, 6, 7, 8]])
        >>> roll_zeropad(x2, -2, axis=1)
        array([[2, 3, 4, 0, 0],
               [7, 8, 9, 0, 0]])

        >>> roll_zeropad(x2, 50)
        array([[0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0]])
        >>> roll_zeropad(x2, -50)
        array([[0, 0, 0, 0, 0],
               [0, 0, 0, 0, 0]])
        >>> roll_zeropad(x2, 0)
        array([[0, 1, 2, 3, 4],
               [5, 6, 7, 8, 9]])

    """
    arr = np.asanyarray(a)
    if shift == 0:
        return arr

    flatten = axis is None
    length = arr.size if flatten else arr.shape[axis]

    if np.abs(shift) > length:
        # Everything is shifted out of the array.
        rolled = np.zeros_like(arr)
    else:
        if shift < 0:
            # Drop the first |shift| entries and append zeros at the end.
            survivors = arr.take(np.arange(-shift, length), axis)
            filler = np.zeros_like(arr.take(np.arange(-shift), axis))
            pieces = (survivors, filler)
        else:
            # Drop the last `shift` entries and prepend zeros at the start.
            survivors = arr.take(np.arange(length - shift), axis)
            filler = np.zeros_like(
                arr.take(np.arange(length - shift, length), axis)
            )
            pieces = (filler, survivors)
        rolled = np.concatenate(pieces, axis)

    return rolled.reshape(arr.shape) if flatten else rolled


def stft(
        time_signal,
        size,
        shift,
        axis=-1,
        window=signal.windows.blackman,
        window_length=None,
        fading=True,
        pad=True,
        symmetric_window=False,
):
    """
    ToDo: Open points:
     - sym_window need literature
     - fading why it is better?
     - should pad have more degrees of freedom?

    Calculates the short time Fourier transformation of a multi channel multi
    speaker time signal. It is able to add additional zeros for fade-in and
    fade out and should yield an STFT signal which allows perfect
    reconstruction.

    Args:
        time_signal: Multi channel time signal with dimensions
            AA x ... x AZ x T x BA x ... x BZ.
        size: Scalar FFT-size.
        shift: Scalar FFT-shift, the step between successive frames in
            samples. Typically shift is a fraction of size.
        axis: Scalar axis of time. Negative values count from the end.
            Default: -1 (the last axis).
        window: Window function handle, or the name of a function in
            `scipy.signal.windows`. Default is the blackman window.
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the windowed frame is zero padded
            to `size` by the FFT.
            The default is to use the fft-size as a window size.
        fading: If true, pads `window_length - shift` zeros on both ends of
            the time axis for better reconstruction.
        pad: If true zero pad the signal to match the shape, else cut
            (forwarded to the segmentation as end='pad' / end='cut').
        symmetric_window: If true, use `window(window_length)` (symmetric).
            Otherwise use `window(window_length + 1)[:-1]` (periodic), which
            equals the behaviour of MATLAB; see
            https://github.com/scipy/scipy/issues/4551.

    Returns:
        Complex STFT signal with dimensions
            AA x ... x AZ x T' x (size // 2 + 1) x BA x ... x BZ,
        i.e. the frame axis is at `axis` and the frequency axis at
        `axis + 1`.

    Raises:
        ValueError: If the window multiplication or the FFT fails because
            the shapes do not match. The original error is chained.
    """
    time_signal = np.array(time_signal)

    # Normalize negative axes, because `axis + 1` is used further below.
    axis = axis % time_signal.ndim

    if window_length is None:
        window_length = size

    # Pad with zeros to have enough samples for the window function to fade.
    if fading:
        # Builtin `int`: the `np.int` alias no longer exists in NumPy.
        pad_width = np.zeros((time_signal.ndim, 2), dtype=int)
        pad_width[axis, :] = window_length - shift
        time_signal = np.pad(time_signal, pad_width, mode='constant')

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if symmetric_window:
        window = window(window_length)
    else:
        # https://github.com/scipy/scipy/issues/4551
        window = window(window_length + 1)[:-1]

    time_signal_seg = segment_axis_v2(
        time_signal,
        window_length,
        shift=shift,
        axis=axis,
        end='pad' if pad else 'cut'
    )

    # Multiply the window along the in-frame sample axis (axis + 1) while
    # keeping all other axes untouched, e.g. 'abc,c->abc'.
    letters = string.ascii_lowercase[:time_signal_seg.ndim]
    mapping = letters + ',' + letters[axis + 1] + '->' + letters

    try:
        # ToDo: Implement this more memory efficient
        return rfft(
            np.einsum(mapping, time_signal_seg, window),
            n=size,
            axis=axis + 1
        )
    except ValueError as e:
        raise ValueError(
            'Could not calculate the stft, something does not match.\n' +
            'mapping: {}, '.format(mapping) +
            'time_signal_seg.shape: {}, '.format(time_signal_seg.shape) +
            'window.shape: {}, '.format(window.shape) +
            'size: {}, '.format(size) +
            'axis+1: {}'.format(axis + 1)
        ) from e


def _samples_to_stft_frames(
        samples,
        size,
        shift,
        *,
        pad=True,
        fading=False,
):
    """
    Convert a number of time domain samples to a number of STFT frames.

    Args:
        samples: Number of samples in time domain.
        size: FFT size.
            window_length often equal to FFT size. The name size should be
            marked as deprecated and replaced with window_length.
        shift: Hop in samples.
        pad: See stft. If true, a partially filled last frame is counted
            (round up), otherwise it is dropped (truncate).
        fading: See stft. If true, `size - shift` extra samples are assumed
            on both ends. Note to keep old behavior, default value is False.

    Returns:
        Number of STFT frames.

    >>> _samples_to_stft_frames(19, 16, 4)
    2
    >>> _samples_to_stft_frames(20, 16, 4)
    2
    >>> _samples_to_stft_frames(21, 16, 4)
    3

    >>> stft(np.zeros(19), 16, 4, fading=False).shape
    (2, 9)
    >>> stft(np.zeros(20), 16, 4, fading=False).shape
    (2, 9)
    >>> stft(np.zeros(21), 16, 4, fading=False).shape
    (3, 9)

    >>> _samples_to_stft_frames(19, 16, 4, fading=True)
    8
    >>> _samples_to_stft_frames(20, 16, 4, fading=True)
    8
    >>> _samples_to_stft_frames(21, 16, 4, fading=True)
    9

    >>> stft(np.zeros(19), 16, 4).shape
    (8, 9)
    >>> stft(np.zeros(20), 16, 4).shape
    (8, 9)
    >>> stft(np.zeros(21), 16, 4).shape
    (9, 9)

    >>> _samples_to_stft_frames(21, 16, 3, fading=True)
    12
    >>> stft(np.zeros(21), 16, 3).shape
    (12, 9)
    >>> _samples_to_stft_frames(21, 16, 3)
    3
    >>> stft(np.zeros(21), 16, 3, fading=False).shape
    (3, 9)
    """
    # Fading adds (size - shift) samples at the start and at the end.
    total_samples = samples + 2 * (size - shift) if fading else samples

    fractional_frames = (total_samples - size + shift) / shift

    # math.ceil (instead of np.ceil) is used to yield an integer result.
    return ceil(fractional_frames) if pad else int(fractional_frames)


def _stft_frames_to_samples(frames, size, shift):
    """
    Convert a number of STFT frames to a number of time domain samples.

    Args:
        frames: Number of STFT frames.
        size: FFT size.
        shift: Hop in samples.

    Returns:
        Number of samples in time domain, i.e.
        `frames * shift + size - shift`.
    """
    hop_samples = frames * shift
    return hop_samples + size - shift


def _biorthogonal_window_brute_force(analysis_window, shift,
                                     use_amplitude=False):
    """
    Compute the synthesis window that belongs to an analysis window.

    The biorthogonal window (synthesis_window) is intended to satisfy the
    criterion that synthesis_window * analysis_window plus its shifts sums
    to one:
        1 == sum m from -inf to inf over (synthesis_window(n - mB) * analysis_window(n - mB))
        B ... shift
        n ... time index
        m ... shift index

    Args:
        analysis_window: One dimensional analysis window.
        shift: Hop between successive frames in samples.
        use_amplitude: If False (default), the overlap sum is built from the
            squared analysis window and the result is
            `analysis_window / overlap_sum`. If True, the overlap sum is
            built from the analysis window itself and the result is
            `1 / overlap_sum`.

    Returns:
        The synthesis window, same length as `analysis_window`.
    """
    window_size = len(analysis_window)

    # Largest number of hops for which a shifted copy still overlaps.
    max_hops = (window_size - 1) // shift

    weights = analysis_window if use_amplitude else analysis_window ** 2

    overlap_sum = np.zeros_like(analysis_window)
    for hop in range(-max_hops, max_hops + 1):
        overlap_sum += roll_zeropad(weights, shift * hop)

    numerator = 1 if use_amplitude else analysis_window
    return numerator / overlap_sum


# Name under which istft and istft_single_channel obtain the synthesis
# window; it currently points at the brute force implementation.
_biorthogonal_window_fastest = _biorthogonal_window_brute_force


def istft(
        stft_signal,
        size=1024,
        shift=256,
        window=signal.windows.blackman,
        fading=True,
        window_length=None,
        symmetric_window=False,
):
    """
    Calculates the inverse short time Fourier transform to reconstruct
    the time signal by windowed overlap-add.

    Notes:
        Be careful if you make modifications in the frequency domain (e.g.
        beamforming) because the synthesis window is calculated according to
        the unmodified! analysis window.

    Args:
        stft_signal: Complex STFT signal
            with dimensions (..., frames, size // 2 + 1).
        size: Scalar FFT-size.
        shift: Scalar FFT-shift. Typically shift is a fraction of size.
        window: Window function handle, or the name of a function in
            `scipy.signal.windows`. Default is the blackman window.
        fading: Removes the additional padding, if done during STFT
            (`window_length - shift` samples on both ends).
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, only the first `window_length`
            samples of each inverse transformed frame are used.
            The default is to use the fft-size as a window size.
        symmetric_window: If true, use `window(window_length)` (symmetric).
            Otherwise use `window(window_length + 1)[:-1]` (periodic), which
            equals the behaviour of MATLAB.

    Returns:
        Real time signal with dimensions (..., samples).
    """
    # Note: frame_axis and frequency_axis would make this function much more
    #       complicated
    stft_signal = np.array(stft_signal)

    assert stft_signal.shape[-1] == size // 2 + 1, str(stft_signal.shape)

    if window_length is None:
        window_length = size

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if symmetric_window:
        window = window(window_length)
    else:
        window = window(window_length + 1)[:-1]

    # Turn the analysis window into the matching synthesis window.
    window = _biorthogonal_window_fastest(window, shift)

    time_signal = np.zeros(
        list(stft_signal.shape[:-2]) +
        [stft_signal.shape[-2] * shift + window_length - shift]
    )

    # Get the correct view to time_signal
    # NOTE(review): the overlap-add below only reaches `time_signal` if
    # segment_axis_v2 returns a view (not a copy) - confirm.
    time_signal_seg = segment_axis_v2(
        time_signal, window_length, shift, end=None
    )

    # Passing n=size makes the frame length explicit. Without it, irfft
    # returns 2 * (bins - 1) samples, which is size - 1 for an odd size.
    frames = irfft(stft_signal, n=size)[..., :window_length]
    # The [..., :window_length] is the inverse of the window padding in rfft.

    # Unbuffered inplace add, overlapping segments must accumulate.
    np.add.at(
        time_signal_seg,
        Ellipsis,
        window * frames
    )

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[
            ..., window_length - shift:time_signal.shape[-1] - (window_length - shift)]

    return time_signal


def istft_single_channel(stft_signal, size=1024, shift=256,
                         window=signal.windows.blackman, fading=True,
                         window_length=None,
                         use_amplitude_for_biorthogonal_window=False,
                         disable_sythesis_window=False):
    """
    Calculates the inverse short time Fourier transform of a single channel
    STFT signal by windowed overlap-add.

    Notes:
        Be careful if you make modifications in the frequency domain (e.g.
        beamforming) because the synthesis window is calculated according to the
        unmodified! analysis window.

        This function evaluates the window as `window(size)` (or
        `window(window_length)`), whereas `stft` with its default
        `symmetric_window=False` uses `window(window_length + 1)[:-1]`.

    Args:
        stft_signal: Single channel complex STFT signal (ndarray)
            with dimensions frames times size // 2 + 1.
        size: Scalar FFT-size.
        shift: Scalar FFT-shift. Typically shift is a fraction of size.
        window: Window function handle, or the name of a function in
            `scipy.signal.windows`. Default is the blackman window.
        fading: Removes the additional padding, if done during STFT
            (`size - shift` samples on both ends).
        window_length: Sometimes one desires to use a shorter window than
            the fft size. In that case, the window is padded with zeros.
            The default is to use the fft-size as a window size.
        use_amplitude_for_biorthogonal_window: Forwarded as `use_amplitude`
            to the synthesis window calculation.
        disable_sythesis_window: If true, the synthesis window is replaced
            by ones, i.e. the frames are overlap-added unweighted.

    Returns:
        Single channel real time signal (one dimensional array).
    """
    assert stft_signal.shape[1] == size // 2 + 1, str(stft_signal.shape)

    if isinstance(window, str):
        window = getattr(signal.windows, window)

    if window_length is None:
        window = window(size)
    else:
        window = window(window_length)
        window = np.pad(window, (0, size - window_length), mode='constant')
    window = _biorthogonal_window_fastest(
        window, shift, use_amplitude_for_biorthogonal_window)
    if disable_sythesis_window:
        window = np.ones_like(window)

    num_frames = stft_signal.shape[0]

    # np.zeros replaces scipy.zeros, a deprecated NumPy alias.
    time_signal = np.zeros(num_frames * shift + size - shift)

    for frame_index in range(num_frames):
        start = frame_index * shift
        # n=size keeps the frame length equal to `size` for odd sizes, too.
        time_signal[start:start + size] += window * irfft(
            stft_signal[frame_index], n=size)

    # Compensate fade-in and fade-out
    if fading:
        time_signal = time_signal[size - shift:len(time_signal) - (size - shift)]

    return time_signal


def stft_to_spectrogram(stft_signal):
    """
    Calculates the power spectrum (spectrogram) of an stft signal.

    The power is computed as the sum of the squared real and the squared
    imaginary part, so the output is real.

    Args:
        stft_signal: Complex STFT signal with dimensions
            #time_frames times #frequency_bins.

    Returns:
        Real spectrogram with same dimensions as input.
    """
    real_power = stft_signal.real ** 2
    imag_power = stft_signal.imag ** 2
    return real_power + imag_power


def spectrogram(time_signal, *args, **kwargs):
    """
    Thin wrapper of stft with power spectrum calculation.

    Args:
        time_signal: Time signal, passed to stft as first argument.
        *args: Further positional arguments for stft.
        **kwargs: Further keyword arguments for stft.

    Returns:
        The result of stft_to_spectrogram applied to the STFT of
        `time_signal`.
    """
    stft_signal = stft(time_signal, *args, **kwargs)
    return stft_to_spectrogram(stft_signal)


def spectrogram_to_energy_per_frame(spectrogram):
    """
    The energy per frame is sometimes used as an additional feature to the MFCC
    features. Here, it is calculated from the power spectrum by summing over
    axis 1.

    Args:
        spectrogram: Real valued power spectrum.

    Returns:
        Real valued energy per frame. Entries that are exactly zero are
        replaced by the float machine epsilon.
    """
    frame_energy = np.sum(spectrogram, axis=1)

    # If energy is zero, we get problems with log
    floor = np.finfo(float).eps
    is_silent = frame_energy == 0
    return np.where(is_silent, floor, frame_energy)


def get_stft_center_frequencies(size=1024, sample_rate=16000):
    """
    It is often necessary to know, which center frequency is
    represented by each frequency bin index.

    Args:
        size: Scalar FFT-size.
        sample_rate: Scalar sample frequency in Hertz.

    Returns:
        Array with the center frequency in Hertz of each of the
        `size // 2 + 1` frequency bins of a real FFT of length `size`.
    """
    # Integer division: a real FFT of length `size` has size // 2 + 1 bins.
    # `size / 2 + 1` would yield one bin too many for an odd size.
    frequency_index = np.arange(size // 2 + 1)
    return frequency_index * sample_rate / size