#!/usr/bin/env python """STFT feature extractors""" import numpy as np from librosa import stft, magphase from librosa import amplitude_to_db, get_duration from librosa.util import fix_length from .base import FeatureExtractor from ._utils import phase_diff, to_dtype __all__ = ['STFT', 'STFTMag', 'STFTPhaseDiff'] class STFT(FeatureExtractor): '''Short-time Fourier Transform (STFT) with both magnitude and phase. Attributes ---------- name : str The name of this transformer sr : number > 0 The sampling rate of audio hop_length : int > 0 The hop length of STFT frames n_fft : int > 0 The number of FFT bins per frame log : bool If `True`, scale magnitude in decibels. Otherwise use linear magnitude. conv : str Convolution mode dtype : np.dtype The data type for the output features. Default is `float32`. Setting to `uint8` will produce quantized features. See Also -------- STFTMag STFTPhaseDiff ''' def __init__(self, name, sr, hop_length, n_fft, log=False, conv=None, dtype='float32'): super(STFT, self).__init__(name, sr, hop_length, conv=conv, dtype=dtype) self.n_fft = n_fft self.log = log self.register('mag', 1 + n_fft // 2, self.dtype) self.register('phase', 1 + n_fft // 2, self.dtype) def transform_audio(self, y): '''Compute the STFT magnitude and phase. Parameters ---------- y : np.ndarray The audio buffer Returns ------- data : dict data['mag'] : np.ndarray, shape=(n_frames, 1 + n_fft//2) STFT magnitude data['phase'] : np.ndarray, shape=(n_frames, 1 + n_fft//2) STFT phase ''' n_frames = self.n_frames(get_duration(y=y, sr=self.sr)) D = stft(y, hop_length=self.hop_length, n_fft=self.n_fft) D = fix_length(D, n_frames) mag, phase = magphase(D) if self.log: mag = amplitude_to_db(mag, ref=np.max) return {'mag': to_dtype(mag.T[self.idx], self.dtype), 'phase': to_dtype(np.angle(phase.T)[self.idx], self.dtype)} class STFTPhaseDiff(STFT): '''STFT with phase differentials See Also -------- STFT ''' def __init__(self, *args, **kwargs): super(STFTPhaseDiff, self).__init__(*args, **kwargs) phase_field = self.pop('phase') self.register('dphase', 1 + self.n_fft // 2, phase_field.dtype) def transform_audio(self, y): '''Compute the STFT magnitude and phase differential. Parameters ---------- y : np.ndarray The audio buffer Returns ------- data : dict data['mag'] : np.ndarray, shape=(n_frames, 1 + n_fft//2) STFT magnitude data['dphase'] : np.ndarray, shape=(n_frames, 1 + n_fft//2) STFT phase ''' n_frames = self.n_frames(get_duration(y=y, sr=self.sr)) D = stft(y, hop_length=self.hop_length, n_fft=self.n_fft) D = fix_length(D, n_frames) mag, phase = magphase(D) if self.log: mag = amplitude_to_db(mag, ref=np.max) phase = phase_diff(np.angle(phase.T)[self.idx], self.conv) return {'mag': to_dtype(mag.T[self.idx], self.dtype), 'dphase': to_dtype(phase, self.dtype)} class STFTMag(STFT): '''STFT with only magnitude. See Also -------- STFT ''' def __init__(self, *args, **kwargs): super(STFTMag, self).__init__(*args, **kwargs) self.pop('phase') def transform_audio(self, y): '''Compute the STFT Parameters ---------- y : np.ndarray The audio buffer Returns ------- data : dict data['mag'] : np.ndarray, shape=(n_frames, 1 + n_fft//2) The STFT magnitude ''' data = super(STFTMag, self).transform_audio(y) data.pop('phase') return data