# Author: Jean-Baptiste Schiratti <jean.baptiste.schiratti@gmail.com>
#         Alexandre Gramfort <alexandre.gramfort@inria.fr>
# License: BSD 3 clause

import numpy as np
from numpy.testing import assert_almost_equal, assert_equal, assert_raises
from scipy import signal

from mne_features.utils import (_idxiter, power_spectrum, _embed, _filt,
                                _psd_params_checker)

rng = np.random.RandomState(42)
sfreq = 172.
# 20 channels of white noise, each int(sfreq) samples long.
data = rng.standard_normal((20, int(sfreq)))


def test_psd():
    """Test ``power_spectrum`` for the 'welch', 'multitaper' and 'fft' methods.

    The output shape is checked for all three methods. The 'fft' result is
    also compared with :func:`scipy.signal.welch` using a rectangular
    (boxcar) window that spans the whole signal.
    """
    n_channels, n_times = data.shape
    # power_spectrum is fed the data with an extra leading axis of length 1.
    _data = data[None, ...]
    # Only test output shape when `method='welch'` or `method='multitaper'`
    # since it is actually just a wrapper for MNE functions:
    psd_welch, _ = power_spectrum(sfreq, _data, psd_method='welch')
    psd_multitaper, _ = power_spectrum(sfreq, _data, psd_method='multitaper')
    psd_fft, freqs_fft = power_spectrum(sfreq, _data, psd_method='fft')
    expected_shape = (1, n_channels, n_times // 2 + 1)
    assert_equal(psd_welch.shape, expected_shape)
    assert_equal(psd_multitaper.shape, expected_shape)
    assert_equal(psd_fft.shape, expected_shape)
    # Compare result obtained with `method='fft'` to Scipy's result
    # (implementation of Welch's method with a rectangular window):
    expected_freqs, expected_psd = signal.welch(
        data, sfreq, window=signal.get_window('boxcar', data.shape[-1]),
        return_onesided=True, scaling='density')
    assert_almost_equal(expected_freqs, freqs_fft)
    assert_almost_equal(expected_psd, psd_fft[0, ...])


def test_idxiter():
    """Test ``_idxiter`` against NumPy's index helpers.

    Items yielded by ``_idxiter`` are unpacked here as ``(counter, i, j)``.
    The triangular variants are compared against ``np.arange`` /
    ``np.triu_indices``. The ``triu=False`` variant is compared against
    every off-diagonal ``(i, j)`` pair in ``np.ndindex`` order.
    """
    n_channels = data.shape[0]
    # Upper-triangular part, including diag
    idx0, idx1 = np.triu_indices(n_channels)
    triu_indices = np.array([np.arange(idx0.size), idx0, idx1])
    triu_indices2 = np.array(list(_idxiter(n_channels, include_diag=True)))
    # Upper-triangular part, without diag
    idx2, idx3 = np.triu_indices(n_channels, 1)
    triu_indices_nodiag = np.array([np.arange(idx2.size), idx2, idx3])
    triu_indices2_nodiag = np.array(list(_idxiter(n_channels,
                                                  include_diag=False)))
    # These are integer index arrays: compare them exactly.
    assert_equal(triu_indices, triu_indices2.T)
    assert_equal(triu_indices_nodiag, triu_indices2_nodiag.T)
    # Upper and lower-triangular parts, without diag
    expected = [(i, j) for i, j in np.ndindex((n_channels, n_channels))
                if i != j]
    actual = [(i, j) for _, i, j in _idxiter(n_channels, triu=False)]
    assert_equal(np.array(actual), expected)


def test_embed():
    """Test ``_embed`` against an explicit delay-embedding construction.

    For each start index ``j``, the expected window gathers the samples
    ``j, j + tau, ..., j + (d - 1) * tau`` of the last axis. The windows
    are then concatenated along axis ``data.ndim - 1``.
    """
    d, tau = 10, 10
    emb_data = _embed(data, d=d, tau=tau)
    n_windows = data.shape[-1] - (d - 1) * tau
    expected = np.concatenate([data[..., None, j + tau * np.arange(d)]
                               for j in range(n_windows)],
                              axis=data.ndim - 1)
    assert_almost_equal(emb_data, expected)


def test_filt():
    """Check that ``_filt`` preserves the shape of its input.

    Two band specifications are exercised: ``[None, 50.]`` (low-pass) and
    ``[1., 70.]`` (band-pass).
    """
    filt_low_pass = _filt(sfreq, data, [None, 50.])
    filt_bandpass = _filt(sfreq, data, [1., 70.])
    assert_equal(filt_low_pass.shape, data.shape)
    assert_equal(filt_bandpass.shape, data.shape)


def test_psd_params_checker():
    """Test ``_psd_params_checker`` on valid, ``None`` and invalid input.

    A valid dict is returned unchanged and ``None`` maps to an empty dict.
    The two invalid inputs used here must raise ``ValueError``.
    """
    valid_params = {'welch_n_fft': 2048, 'welch_n_per_seg': 1024}
    assert_equal(valid_params, _psd_params_checker(valid_params))
    assert_equal(dict(), _psd_params_checker(None))
    # Build each invalid input outside its context manager so that only
    # the checker call itself is expected to raise.
    invalid_params1 = {'n_fft': 1024, 'psd_method': 'fft'}
    with assert_raises(ValueError):
        _psd_params_checker(invalid_params1)
    invalid_params2 = [1024, 1024]  # not a dict
    with assert_raises(ValueError):
        _psd_params_checker(invalid_params2)


if __name__ == '__main__':
    test_psd()
    test_idxiter()
    test_embed()
    test_filt()
    test_psd_params_checker()