Python tensorflow.fft() Examples

The following are 14 code examples of `tensorflow.fft()` (the TensorFlow 1.x name; in TensorFlow 2.x the op is `tf.signal.fft`, with `tf.compat.v1.fft` kept as an alias). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow, or try the search function.
Example #1
Source File: audio.py | From vqvae-speech | MIT License | 7 votes
def spectrogram(x, frame_length, nfft=1024):
  '''Log-magnitude spectrogram over non-overlapping Hann-windowed frames.

  The signal is cut into consecutive, non-overlapping frames of
  `frame_length` samples. Trailing samples that do not fill a whole frame
  are dropped. Each frame is Hann-windowed, zero-padded to `nfft` points,
  transformed with `tf.fft`, and converted to a log magnitude. The result
  is min-max normalized to [0, 1], passed through `gray2jet` (defined
  elsewhere; presumably a gray-to-jet colormap), and logged as an image
  summary named 'spectrogram'.

  Args:
    x: Real-valued tensor indexed as [batch, samples]. It is assumed to be
      float32 so it matches the float32 zero padding (TODO confirm callers).
    frame_length: Samples per frame. Must not exceed `nfft`, otherwise the
      padding size below is negative.
    nfft: FFT size. Only the first nfft // 2 + 1 bins are kept.

  Returns:
    The `gray2jet` output for the normalized spectrogram. The spectrogram
    is laid out as [batch, nfft // 2 + 1, frames, 1].
  '''
  with tf.name_scope('Spectrogram'):
    shape = tf.shape(x)
    b = shape[0]
    D = frame_length
    t = shape[1] // D
    # Drop the incomplete trailing frame. tf.reshape needs the element
    # count to match exactly, so a length that is not a multiple of D
    # would otherwise fail at run time.
    x = x[:, :t * D]
    x = tf.reshape(x, [b, t, D])

    window = tf.contrib.signal.hann_window(frame_length)
    window = tf.expand_dims(window, 0)
    window = tf.expand_dims(window, 0)  # [1, 1, D]
    x = x * window

    # tf.fft has no length argument, so zero-pad each frame to nfft by hand.
    pad = tf.zeros([b, t, nfft - D])
    x = tf.concat([x, pad], -1)
    x = tf.cast(x, tf.complex64)
    X = tf.fft(x)

    # Log magnitude. The 1e-2 floor keeps log() finite for zero bins.
    X = tf.log(tf.abs(X) + 1e-2)

    # Keep the first nfft // 2 + 1 bins, move frequency to axis 1, and
    # reverse it so the highest bin comes first (presumably so low
    # frequencies end up at the bottom of the image). Add a channel axis.
    X = X[:, :, :nfft // 2 + 1]
    X = tf.transpose(X, [0, 2, 1])
    X = tf.reverse(X, [1])
    X = tf.expand_dims(X, -1)

    # Min-max normalize to [0, 1]. The denominator is guarded so that a
    # constant spectrogram (e.g. all-zero input) gives zeros, not NaNs.
    X_min = tf.reduce_min(X)
    X_range = tf.maximum(tf.reduce_max(X) - X_min, 1e-8)
    X = (X - X_min) / X_range
    X = gray2jet(X)

    tf.summary.image('spectrogram', X)
    return X
Example #2
Source File: layers.py | From neuron | GNU General Public License v3.0 | 6 votes
def call(self, inputx):
    """Forward FFT over the `self.ndims` volume axes of a channels-last tensor.

    The input is laid out as [batch_size, *vol_size, nb_features]. The
    feature axis is moved next to the batch axis, so that the volume axes
    are innermost where tf.fft / tf.fft2d / tf.fft3d operate. The transform
    is applied there and the axes are then moved back.

    Args:
        inputx: Tensor of shape [batch_size, *vol_size, nb_features]. A
            non-complex input is cast to complex64 (zero imaginary part)
            after a warning is printed to stderr.

    Returns:
        Complex tensor with the same layout as `inputx`.
    """
    if inputx.dtype not in [tf.complex64, tf.complex128]:
        print('Warning: inputx is not complex. Converting.', file=sys.stderr)
        # A real input becomes the real part; the imaginary part is zero.
        inputx = tf.cast(inputx, tf.complex64)

    rank = self.ndims
    if rank == 1:
        transform = tf.fft
    elif rank == 2:
        transform = tf.fft2d
    else:
        # Any other ndims (not only 3) falls through to the 3-D transform.
        transform = tf.fft3d

    # [batch, *vol, feat] -> [batch, feat, *vol]
    features_first = [0, rank + 1] + list(range(1, rank + 1))
    # [batch, feat, *vol] -> [batch, *vol, feat]
    features_last = [0] + list(range(2, rank + 2)) + [1]

    moved = K.permute_dimensions(inputx, features_first)  # [batch_size, nb_features, *vol_size]
    return K.permute_dimensions(transform(moved), features_last)
Example #3
Source File: layers.py | From neuron | GNU General Public License v3.0 | 6 votes
def call(self, inputx):
    """Inverse FFT over the `self.ndims` volume axes of a channels-last tensor.

    The input is laid out as [batch_size, *vol_size, nb_features]. The
    feature axis is moved next to the batch axis, so that the volume axes
    are innermost where tf.ifft / tf.ifft2d / tf.ifft3d operate. The inverse
    transform is applied there and the axes are then moved back.

    Args:
        inputx: Tensor of shape [batch_size, *vol_size, nb_features]. A
            non-complex input is cast to complex64 (zero imaginary part)
            after a warning is printed to stderr.

    Returns:
        Complex tensor with the same layout as `inputx`.
    """
    if inputx.dtype not in [tf.complex64, tf.complex128]:
        print('Warning: inputx is not complex. Converting.', file=sys.stderr)
        # A real input becomes the real part; the imaginary part is zero.
        inputx = tf.cast(inputx, tf.complex64)

    # Pick the inverse transform matching the number of volume axes.
    rank = self.ndims
    if rank == 1:
        inverse = tf.ifft
    elif rank == 2:
        inverse = tf.ifft2d
    else:
        # Any other ndims (not only 3) falls through to the 3-D transform.
        inverse = tf.ifft3d

    # [batch, *vol, feat] -> [batch, feat, *vol]
    features_first = [0, rank + 1] + list(range(1, rank + 1))
    # [batch, feat, *vol] -> [batch, *vol, feat]
    features_last = [0] + list(range(2, rank + 2)) + [1]

    moved = K.permute_dimensions(inputx, features_first)  # [batch_size, nb_features, *vol_size]
    return K.permute_dimensions(inverse(moved), features_last)
Example #4
Source File: compact_bilinear_pooling.py | From RGB-N | MIT License | 6 votes
def _fft(bottom, sequential, compute_size):
    """Forward FFT of `bottom`, optionally computed in sequential batches.

    Args:
        bottom: Tensor to transform (expected to be complex, as tf.fft
            requires).
        sequential: If truthy, delegate to `sequential_batch_fft` (defined
            elsewhere) instead of a single tf.fft call.
        compute_size: Forwarded to `sequential_batch_fft` (presumably the
            per-step batch size -- verify). Ignored when `sequential` is
            falsy.
    """
    if not sequential:
        return tf.fft(bottom)
    return sequential_batch_fft(bottom, compute_size)
Example #5
Source File: tfmri.py | From dl-cs | MIT License | 6 votes
def fftshift(im, axis=-1, name='fftshift'):
    """Swap the two halves of `im` along each requested axis (FFT shift).

    Each axis is cut into two equal halves with `tf.split`, and the halves
    are concatenated in reverse order. This assumes every shifted axis has
    even length, since tf.split into 2 equal parts requires it.

    Args:
        im: Tensor to shift.
        axis: Integer or iterable of integers. The axes to shift, processed
            in order.
        name: TensorFlow name scope.

    Returns:
        Tensor of the same shape as `im`, with its halves swapped along
        every requested axis.
    """
    axes = axis if hasattr(axis, '__iter__') else [axis]
    shifted = im
    with tf.name_scope(name):
        for ax in axes:
            front, back = tf.split(shifted, 2, axis=ax)
            shifted = tf.concat((back, front), axis=ax)
    return shifted
Example #6
Source File: fft_ops_test.py | From deep_image_model | Apache License 2.0 | 5 votes
def _npFFT(self, x, rank):
    """NumPy reference forward FFT over the innermost `rank` axes.

    Reference counterpart of tf.fft / tf.fft2d / tf.fft3d for rank 1 / 2 / 3.
    np.fft.fftn is used for every rank. The previous code called
    np.fft.fft2 with 1 or 3 axes, which only worked because fft2 forwards
    to the n-d routine.

    Args:
      x: Array-like input.
      rank: Number of trailing axes to transform; must be 1, 2 or 3.

    Returns:
      Complex ndarray with the same shape as `x`.

    Raises:
      ValueError: If `rank` is not 1, 2 or 3.
    """
    axes_by_rank = {1: (-1,), 2: (-2, -1), 3: (-3, -2, -1)}
    if rank not in axes_by_rank:
      raise ValueError("invalid rank")
    return np.fft.fftn(x, axes=axes_by_rank[rank])
Example #7
Source File: fft_ops_test.py | From deep_image_model | Apache License 2.0 | 5 votes
def _npIFFT(self, x, rank):
    """NumPy reference inverse FFT over the innermost `rank` axes.

    Reference counterpart of tf.ifft / tf.ifft2d / tf.ifft3d for rank
    1 / 2 / 3. np.fft.ifftn is used for every rank. The previous code
    called np.fft.ifft2 with 1 or 3 axes, which only worked because ifft2
    forwards to the n-d routine.

    Args:
      x: Array-like input.
      rank: Number of trailing axes to transform; must be 1, 2 or 3.

    Returns:
      Complex ndarray with the same shape as `x`.

    Raises:
      ValueError: If `rank` is not 1, 2 or 3.
    """
    axes_by_rank = {1: (-1,), 2: (-2, -1), 3: (-3, -2, -1)}
    if rank not in axes_by_rank:
      raise ValueError("invalid rank")
    return np.fft.ifftn(x, axes=axes_by_rank[rank])
Example #8
Source File: fft_ops_test.py | From deep_image_model | Apache License 2.0 | 5 votes
def _tfFFTForRank(self, rank):
    """Return the TensorFlow forward-FFT op for a transform of `rank` axes.

    Args:
      rank: 1, 2 or 3, selecting tf.fft, tf.fft2d or tf.fft3d.

    Raises:
      ValueError: If `rank` is not 1, 2 or 3.
    """
    if rank not in (1, 2, 3):
      raise ValueError("invalid rank")
    if rank == 3:
      return tf.fft3d
    if rank == 2:
      return tf.fft2d
    return tf.fft
Example #9
Source File: tfmri.py | From dl-cs | MIT License | 5 votes
def fftc(im,
         data_format='channels_last',
         orthonorm=True,
         transpose=False,
         name='fftc'):
    """Centered 1-D FFT along the last non-channel axis of `im`.

    With data_format == 'channels_last', the last two axes are swapped so
    the transform runs over the second-to-last axis, and they are swapped
    back afterwards. Otherwise the transform runs over the last axis. That
    axis is passed through `fftshift` before and after the transform, which
    assumes an even length.

    Args:
        im: Tensor to transform. It is presumably complex64, to match
            tf.fft / tf.ifft and the complex64 scale factor (verify callers).
        data_format: 'channels_last' treats the final axis as channels;
            any other value transforms the final axis directly.
        orthonorm: If True, divide the forward transform by sqrt(N) and
            multiply the inverse by sqrt(N). N is the static length of the
            transformed axis.
        transpose: If True, apply tf.ifft (the adjoint) instead of tf.fft.
        name: TensorFlow name scope.

    Returns:
        Tensor with the same shape as `im`.
    """
    channels_last = data_format == 'channels_last'
    with tf.name_scope(name):
        out = im
        if channels_last:
            # Swap the last two axes so the channel axis is not transformed.
            swap = np.arange(len(im.shape))
            swap[[-2, -1]] = swap[[-1, -2]]
            out = tf.transpose(out, swap)

        scale = (tf.sqrt(tf.cast(out.shape[-1], tf.float32))
                 if orthonorm else 1.0)
        scale = tf.cast(scale, dtype=tf.complex64)

        out = fftshift(out, axis=-1)
        out = tf.ifft(out) * scale if transpose else tf.fft(out) / scale
        out = fftshift(out, axis=-1)

        if channels_last:
            # Swapping two axes is its own inverse.
            out = tf.transpose(out, swap)

    return out
Example #10
Source File: HolE.py | From KagNet | MIT License | 5 votes
def _cconv(self, a, b):
	"""Circular convolution of `a` and `b` via the FFT.

	Computes real(ifft(fft(a) * fft(b))). Both inputs are first cast to
	complex64, because tf.fft only accepts complex tensors. The real part
	is taken with tf.real, because TensorFlow tensors have no NumPy-style
	`.real` attribute. This matches the casting convention of `_ccorr`.
	Returns a real-valued tensor.
	"""
	a = tf.cast(a, tf.complex64)
	b = tf.cast(b, tf.complex64)
	return tf.real(tf.ifft(tf.fft(a) * tf.fft(b)))
Example #11
Source File: HolE.py | From KagNet | MIT License | 5 votes
def _ccorr(self, a, b):
	"""Circular correlation of `a` and `b` via the FFT.

	Computes real(ifft(conj(fft(a)) * fft(b))) after casting both inputs
	to complex64, which tf.fft requires. Returns a real-valued tensor.
	"""
	spec_a = tf.fft(tf.cast(a, tf.complex64))
	spec_b = tf.fft(tf.cast(b, tf.complex64))
	return tf.real(tf.ifft(tf.conj(spec_a) * spec_b))
Example #12
Source File: tfnp_compatibility.py | From spherical-cnn | MIT License | 5 votes
def fft(x, *args, **kwargs):
    """FFT of `x`, computed with tf.fft or np.fft.fft depending on its type.

    `istf` (presumably defined elsewhere in this module) decides whether
    `x` is a TensorFlow object. If it is, the result of tf.fft is returned;
    otherwise the result of np.fft.fft. Extra positional and keyword
    arguments are forwarded unchanged to the chosen backend.
    """
    if istf(x):
        return tf.fft(x, *args, **kwargs)
    return np.fft.fft(x, *args, **kwargs)
Example #13
Source File: HolE.py | From CPL | MIT License | 5 votes
def _cconv(self, a, b):
	"""Circular convolution of `a` and `b` via the FFT.

	Computes real(ifft(fft(a) * fft(b))). Both inputs are first cast to
	complex64, because tf.fft only accepts complex tensors. The real part
	is taken with tf.real, because TensorFlow tensors have no NumPy-style
	`.real` attribute. This matches the casting convention of `_ccorr`.
	Returns a real-valued tensor.
	"""
	a = tf.cast(a, tf.complex64)
	b = tf.cast(b, tf.complex64)
	return tf.real(tf.ifft(tf.fft(a) * tf.fft(b)))
Example #14
Source File: HolE.py | From CPL | MIT License | 5 votes
def _ccorr(self, a, b):
	"""Circular correlation of `a` and `b` via the FFT.

	Computes real(ifft(conj(fft(a)) * fft(b))) after casting both inputs
	to complex64, which tf.fft requires. Returns a real-valued tensor.
	"""
	spec_a = tf.fft(tf.cast(a, tf.complex64))
	spec_b = tf.fft(tf.cast(b, tf.complex64))
	return tf.real(tf.ifft(tf.conj(spec_a) * spec_b))