# -*- coding: utf-8 -*-
"""
Created on Fri Mar 22 2019

@author: Alexandre Sauve

Content : Helper for wavelet related functions
"""
import pywt
import numpy as np
import matplotlib.pyplot as plt
import six


DEFAULT_WAVELET = 'cmor1-1.5'

# Precision forwarded to ``pywt.integrate_wavelet()`` / ``pywt.scale2frequency()``
# (same value as the one hard-coded in pywt.cwt() 1.0.2).
_CWT_PRECISION = 10

# Family names which need explicit parameters to avoid a pywt warning,
# mapped to a fully parametrized wavelet name.
_WAVELET_NAME_COMPLETION = {
    'cmor': 'cmor1.5-1.0',
    'fbsp': 'fbsp1-1.5-1.0',
    'shan': 'shan1.5-1.0',
}


def set_default_wavelet(wavelet):
    """Sets the default wavelet to be used for scaleograms"""
    global DEFAULT_WAVELET
    DEFAULT_WAVELET = wavelet


def get_default_wavelet():
    """Returns the default wavelet to be used for scaleograms"""
    return DEFAULT_WAVELET


def _wavelet_instance(wavelet):
    """Function responsible for returning the correct pywt.ContinuousWavelet

    Arguments
    ---------
    - wavelet : string name (like 'cmor1-1.5') or pywt.ContinuousWavelet

    Raises
    ------
    - ValueError when ``wavelet`` is neither a string nor a
      ``pywt.ContinuousWavelet`` instance
    """
    if isinstance(wavelet, pywt.ContinuousWavelet):
        return wavelet
    if isinstance(wavelet, six.string_types):
        return pywt.ContinuousWavelet(wavelet)
    raise ValueError("Expecting a string name for the wavelet," +
                     " or pywt.ContinuousWavelet. Got: " + str(wavelet))


def periods2scales(periods, wavelet=None, dt=1.0):
    """Helper function to convert periods values (in the pseudo period
    wavelet sense) to scales values

    Arguments
    ---------
    - periods : np.ndarray() of positive strictly increasing values
        The ``periods`` array should be consistent with the ``time`` array
        passed to ``cws()``. If no ``time`` values are provided the period on
        the scaleogram will be in sample units.

        Note: you should check that periods minimum value is larger than the
        duration of two data sample because the spectrum has no physical
        meaning below these values.

    - wavelet : pywt.ContinuousWavelet instance or string name
        When not provided, the default wavelet is used
        (see ``get_default_wavelet()``).

    - dt=[1.0] : specify the time interval between two samples of data
        When no ``time`` array is passed to ``cws()``, there is no need to
        set this parameter and the default value of 1 is used.

    Returns
    -------
    - scales : ``(periods / dt) * central_frequency(wavelet)``

    Raises
    ------
    - ValueError when ``wavelet`` is neither a string nor a
      ``pywt.ContinuousWavelet`` instance

    Note: for a scale value of ``s`` and a wavelet Central frequency ``C``,
    the period ``p`` is::

        p = s / C

    Example : Build a spectrum with constant period bins in log space
    -------
        import numpy as np
        import scaleogram as scg

        periods = np.logspace(np.log10(2), np.log10(100), 100)
        wavelet = 'cgau5'
        scales  = periods2scales(periods, wavelet)
        data    = np.random.randn(512)  # gaussian noise
        scg.cws( data, scales=scales, wavelet=wavelet, yscale='log',
                title="CWT of gaussian noise with constant binning in Y logscale")
    """
    if wavelet is None:
        wavelet = get_default_wavelet()
    wavelet = _wavelet_instance(wavelet)
    # np.asarray() allows plain sequences (list, tuple) as well as arrays
    return (np.asarray(periods) / dt) * pywt.central_frequency(wavelet)


def get_wavlist():
    """Returns the list of continuous wavelet functions available in the
    PyWavelets library.

    Each element is a string formatted as ``"<name> :\\t<family name>"``.
    """
    descriptions = []
    for name in pywt.wavelist(kind='continuous'):
        # suppress warnings when the wavelet name is missing parameters
        name = _WAVELET_NAME_COMPLETION.get(name, name)
        descriptions.append(
            name + " :\t" + pywt.ContinuousWavelet(name).family_name)
    return descriptions


WAVLIST = get_wavlist()


def _scaled_int_psi(int_psi, x, scale):
    """Resample the integrated wavelet ``int_psi`` (defined on the grid ``x``)
    for the dilatation factor ``scale``.

    The code has been extracted from pywt.cwt() 1.0.2
    """
    step = x[1] - x[0]
    j = np.floor(np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step))
    # drop the indices falling beyond the wavelet support
    j = j[j < np.size(int_psi)]
    # builtin int: the ``np.int`` alias has been removed from NumPy (>= 1.24)
    return int_psi[j.astype(int)]


def _fft_pad_size(data_len, wav_len):
    """Returns the smallest power of 2 not lower than ``data_len + wav_len``

    - the size is as large as the sum of data length and wavelet support
      to avoid circular convolution effects
    - it is rounded up to a power of 2 for CPU-optimal FFT
    """
    return 2 ** int(np.ceil(np.log2(data_len + wav_len)))


def child_wav(wavelet, scale):
    """Returns an array of complex values with the child wavelet used at the
    given ``scale``.

    The ``wavelet`` argument can be either a string like 'cmor1-1.5' or
    a ``pywt.ContinuousWavelet`` instance
    """
    wavelet = _wavelet_instance(wavelet)
    int_psi, x = pywt.integrate_wavelet(wavelet, precision=_CWT_PRECISION)
    return _scaled_int_psi(int_psi, x, scale)


def fastcwt(data, scales, wavelet, sampling_period=1.0, method='auto'):
    """
    Compute the continuous wavelet transform (CWT) and has the same signature
    as ``pywt.cwt()`` but is faster for large signals length and scales.

    Parameters
    ----------
    data : 1D array_like to compute the CWT on

    scales: dilatation factors for the CWT (scalar or 1D array_like)

    wavelet: wavelet name or pywt.ContinuousWavelet

    sampling_period=[1.0] : the returned frequencies are divided by this value

    method=['auto'] | 'conv' | 'fft' for selecting the convolution method
        the `'auto'` keyword switch automatically to the best complexity at
        each scales. While the `'fft'` and `'conv'` uses `numpy.fft` and
        `numpy.convolve` respectively.

    In practice the `'fft'` method is implemented by using the convolution
    theorem which states::

        convolve(wav,sig) == ifft(fft(wav)*fft(sig))

    Zero padding is adjusted to keep at bay circular convolution side effects.

    Returns
    -------
    out : array of shape ``(len(scales), len(data))`` with the coefficients,
        complex when the wavelet is complex, float otherwise

    frequencies : array with one frequency value per scale

    Raises
    ------
    ValueError when ``data`` is not 1D, when ``method`` is unknown or when
    a scale is too small for the wavelet

    Example::

        %time (coef1, freq1) = fastcwt(np.arange(140000), np.arange(2,200), 'cmor1-1')
        => CPU times: user 12.6 s, sys: 2.2 s, total: 14.8 s
        => Wall time: 14.9 s

        %time (coef1, freq1) = pywt.cwt(np.arange(140000), np.arange(2,200), 'cmor1-1')
        => CPU times: user 1min 50s, sys: 401 ms, total: 1min 51s
        => Wall time: 1min 51s
    """
    # accept array_like input; make a copy to ensure a contiguous array
    data = np.array(data)
    if not isinstance(wavelet, (pywt.ContinuousWavelet, pywt.Wavelet)):
        wavelet = pywt.DiscreteContinuousWavelet(wavelet)
    # accept scalar, sequence or array for the scales
    scales = np.atleast_1d(scales)

    if data.ndim != 1:
        raise ValueError("Only dim == 1 supported")
    if method not in ('auto', 'fft', 'conv'):
        raise ValueError("method must be in: 'conv', 'fft' or 'auto'")

    # currently keep the 1.0.2 behaviour: TODO fix in/out dtype consistency
    dt_out = complex if wavelet.complex_cwt else None
    out = np.zeros((np.size(scales), data.size), dtype=dt_out)
    int_psi, x = pywt.integrate_wavelet(wavelet, precision=_CWT_PRECISION)

    # FFT of the data is cached as long as the padding size does not change
    fft_data = None
    fft_size_prev = None

    for i in range(np.size(scales)):
        scale = scales[i]
        int_psi_scale = _scaled_int_psi(int_psi, x, scale)[::-1]

        use_fft = False
        if method != 'conv':
            fft_size = _fft_pad_size(len(data), len(int_psi_scale))
            if fft_size != fft_size_prev:
                # the fft of data changes when padding size changes thus
                # it has to be recomputed
                fft_data = None
                fft_size_prev = fft_size
            if method == 'fft':
                use_fft = True
            else:
                # 'auto': select the method with the lowest operations count
                # (one additional FFT is needed when the cache is empty)
                nops_conv = len(data) * len(int_psi_scale)
                nops_fft = ((2 + (fft_data is None))
                            * fft_size * np.log2(fft_size))
                use_fft = nops_fft < nops_conv

        if use_fft:
            if fft_data is None:
                fft_data = np.fft.fft(data, fft_size)
            fft_wav = np.fft.fft(int_psi_scale, fft_size)
            conv = np.fft.ifft(fft_wav * fft_data)
            # keep the same length as np.convolve() in 'full' mode
            conv = conv[0:len(data) + len(int_psi_scale) - 1]
        else:
            conv = np.convolve(data, int_psi_scale)

        coef = - np.sqrt(scale) * np.diff(conv)
        if not np.iscomplexobj(out):
            coef = np.real(coef)
        # center the result to keep the same length as the data
        d = (coef.size - data.size) / 2.
        if d > 0:
            out[i, :] = coef[int(np.floor(d)):int(-np.ceil(d))]
        elif d == 0.:
            out[i, :] = coef
        else:
            raise ValueError(
                "Selected scale of {} too small.".format(scale))

    frequencies = pywt.scale2frequency(wavelet, scales, _CWT_PRECISION)
    frequencies = np.atleast_1d(frequencies) / sampling_period
    return out, frequencies


def plot_wav_time(wav=None, real=True, imag=True, figsize=None, ax=None,
                  legend=True, clearx=False):
    """Plot wavelet representation in **time domain**
    see ``plot_wav()`` for parameters.

    Returns the matplotlib axes used for the plot.
    """
    if wav is None:
        wav = get_default_wavelet()
    wav = _wavelet_instance(wav)
    fun_wav, time = wav.wavefun()
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # time domain plot
    if real:
        ax.plot(time, fun_wav.real, label="real")
    if imag:
        ax.plot(time, fun_wav.imag, "r-", label="imag")
    if legend:
        ax.legend()
    ax.set_title(wav.name)
    if clearx:
        ax.set_xticks([])
    else:
        ax.set_xlabel('Time (s)')
    #ax.set_ylabel("Amplitude")
    ax.set_ylim(-1, 1)

    return ax


def plot_wav_freq(wav=None, figsize=None, ax=None, yscale='linear',
                  annotate=True, clearx=False):
    """Plot wavelet representation in **frequency domain**
    see ``plot_wav()`` for parameters.

    Returns the matplotlib axes used for the plot.
    """
    if wav is None:
        wav = get_default_wavelet()
    wav = _wavelet_instance(wav)
    fun_wav, time = wav.wavefun()
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    # frequency domain plot
    df = 1 / (time[-1] - time[0])
    freq = np.arange(len(time)) * df
    fft = np.abs(np.fft.fft(fun_wav) / np.sqrt(len(fun_wav)))
    ax.plot(freq, fft)
    #ax2.set_yscale('log')
    # only the first half of the spectrum is displayed
    ax.set_xlim(0, df * len(freq) / 2)
    ax.set_title("Frequency support")
    if clearx:
        ax.set_xticks([])
    else:
        ax.set_xlabel("Frequency [Hz]")
    ax.set_yscale(yscale)
    ax.set_ylim(-0.1, 1.1)

    # fall back on the pywt estimate when the wavelet has no center frequency
    central_frequency = wav.center_frequency
    if not central_frequency:
        central_frequency = pywt.central_frequency(wav)
    bandwidth_frequency = (wav.bandwidth_frequency
                           if wav.bandwidth_frequency else 0)
    # vertical line at the central frequency
    ax.plot(central_frequency * np.ones(2), ax.get_ylim())
    if annotate:
        ax.annotate("central_freq=%0.1fHz\nbandwidth_param=%0.1f" % (
                        central_frequency, bandwidth_frequency),
                    xy=(central_frequency, 0.5),
                    xytext=(central_frequency + 2, 0.6),
                    arrowprops=dict(facecolor='black', shrink=0.01))

    return ax


def plot_wav(wav=None, figsize=None, axes=None, real=True, imag=True,
             yscale='linear', legend=True, annotate=True, clearx=False):
    # NOTE: the docstring is assigned below to include the wavelet list
    if wav is None:
        wav = get_default_wavelet()
    wav = _wavelet_instance(wav)

    if axes is None:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)
    else:
        ax1, ax2 = axes

    plot_wav_time(wav, real=real, imag=imag, ax=ax1, legend=legend,
                  clearx=clearx)
    plot_wav_freq(wav, yscale=yscale, ax=ax2, annotate=annotate,
                  clearx=clearx)

    return ax1, ax2


plot_wav.__doc__ = """
Quick helper function to check visually the properties of a wavelet
in time domain and the filter view in frequency domain.

Parameters
----------
- wav : continuous wavelet name or pywt.ContinuousWavelet
    If not provided, then the default wavelet for ``cws()`` is used.

- axes= (ax1, ax2) : allow to customize the plot destinations

- figsize=(width,height) : forward the size (inches) to matplotlib
    If this parameter is provided, a new figure is created under the hood
    for display. It is only used if axes is absent

- real= [True]/False : plot real part in time domain

- imag= [True]/False : plot imaginary part in time domain

- yscale=['linear']|'log' allow to select Y axis scale in frequency domain

Returns
-------
- ax1, ax2 : matplotlib graphics elements

Continuous Wavelet list
-----------------------

- """+("\n- ".join(WAVLIST))


def plot_wavelets(wavlist=None, figsize=None):
    """Plot the matrix of all available continuous wavelets

    Parameters
    ----------
    - wavlist : list of wavelet descriptions, the wavelet name being the
        first word of each element. Defaults to ``WAVLIST``.

    - figsize=(width,height) : forward the size (inches) to matplotlib
        Defaults to a width of 12 and a height of 1.5 per row.
    """
    wavlist = WAVLIST if wavlist is None else wavlist
    names = [desc.split()[0] for desc in wavlist]
    if not names:
        return  # nothing to plot

    # each wavelet uses two axes (time + frequency): two wavelets per row
    ncol = 4
    nrow = (len(names) + 1) // 2
    figsize = (12, 1.5 * nrow) if figsize is None else figsize
    # squeeze=False: always get a 2D array of axes, even with a single row
    fig, axes = plt.subplots(nrow, ncol, figsize=figsize, squeeze=False)
    plt.subplots_adjust(hspace=0.5, wspace=0.25)
    axes = axes.ravel()  # flatten in row-major order
    for i, name in enumerate(names):
        plot_wav(name, axes=(axes[i * 2], axes[i * 2 + 1]),
                 legend=False, annotate=False, clearx=True)


#if __name__ == '__main__':
#    plot_wav()
#    plot_wavelets()
#    plt.draw()
#    plt.show()