from pathlib import Path

import numpy as np
import scipy.signal as ss
import spikeextractors as se

# Bounds used when drawing the random per-channel template amplitudes.
_AMP_MIN = 0.2          # smallest accepted random amplitude
_AMP_MAX = 0.5          # largest accepted random amplitude
_AMP_SUM_LIMIT = 0.9    # the n_ch - 1 random amplitudes must sum below this
_MAX_AMP_DRAWS = 999    # random draws attempted per template before giving up


def check_signal_power_signal1_below_signal2(signals1, signals2, freq_range, fs):
    '''
    Check that spectrum power of signal1 is below the one of signal 2 in the range freq_range

    The power spectra are estimated with Welch's method (nfft=1024) and, for
    every channel, summed over the frequencies f with
    freq_range[0] < f <= freq_range[1].

    Parameters
    ----------
    signals1, signals2: array-like
        Signals to compare. Either 2-D, iterated channel by channel along the
        first axis, or 1-D, treated as a single channel.
    freq_range: sequence
        Lower (exclusive) and upper (inclusive) frequency bounds.
    fs: float
        Sampling frequency passed to scipy.signal.welch.

    Returns
    -------
    below: bool
        True if, for every compared channel pair, the band power of signals1 is
        strictly below the band power of signals2.
    '''
    f1, pow1 = ss.welch(signals1, fs, nfft=1024)
    f2, pow2 = ss.welch(signals2, fs, nfft=1024)

    # The frequency axes do not depend on the channel: build the masks once.
    low, high = freq_range[0], freq_range[1]
    band1 = (f1 > low) & (f1 <= high)
    band2 = (f2 > low) & (f2 <= high)

    # A 1-D input yields a 1-D spectrum; promote it to a single "channel" row
    # so the per-channel loop below works for it too.
    pow1 = np.atleast_2d(pow1)
    pow2 = np.atleast_2d(pow2)

    # zip stops at the shorter of the two channel lists.
    return not any(np.sum(p1[band1]) >= np.sum(p2[band2])
                   for p1, p2 in zip(pow1, pow2))


def create_wf(min_val=-100, max_val=50, n_samples=100):
    '''
    Creates stereotyped waveform

    The waveform is piecewise linear: it ramps from 0 to min_val over the first
    quarter of the samples, from min_val to max_val over the next two quarters,
    and from max_val back to 0 over the remaining samples.

    Parameters
    ----------
    min_val: float
        Value reached at the end of the first ramp.
    max_val: float
        Value reached at the end of the second ramp.
    n_samples: int
        Number of samples of the waveform.

    Returns
    -------
    wf: np.ndarray
        1-D array with n_samples values.
    '''
    wf = np.zeros(n_samples)
    inter = n_samples // 4
    wf[:inter] = np.linspace(0, min_val, inter)
    wf[inter:3 * inter] = np.linspace(min_val, max_val, 2 * inter)
    # The last segment absorbs the remainder when n_samples is not a multiple of 4.
    wf[3 * inter:] = np.linspace(max_val, 0, n_samples - 3 * inter)
    return wf


def generate_template_with_random_amps(n_ch, wf):
    '''
    Creates stereotyped templates from waveform

    Draws n_ch - 1 random amplitudes in [0.2, 0.5] whose running sum stays
    below 0.9 and sets the last amplitude to 1 minus their sum. Each channel of
    the template is the waveform scaled by the corresponding amplitude. At most
    999 random numbers are drawn per call; the attempt fails if that is not
    enough to collect n_ch - 1 amplitudes.

    Parameters
    ----------
    n_ch: int
        Number of channels of the template.
    wf: np.ndarray
        1-D waveform to scale.

    Returns
    -------
    template: np.ndarray or list
        (n_ch, len(wf)) array on success, empty list on failure.
    amps: list
        Amplitudes collected so far (n_ch values on success).
    found: bool
        Whether a full set of amplitudes was found.
    '''
    amps = []
    n_draws = 0
    while len(amps) < n_ch - 1 and n_draws < _MAX_AMP_DRAWS:
        a = np.random.rand()
        n_draws += 1
        if a < _AMP_MIN or a > _AMP_MAX:
            continue
        if sum(amps) + a < _AMP_SUM_LIMIT:
            amps.append(a)

    if len(amps) != n_ch - 1:
        return [], amps, False

    # The last channel takes whatever is left so the amplitudes sum to 1.
    amps.append(1 - sum(amps))
    template = np.zeros((n_ch, len(wf)))
    for ch, a in enumerate(amps):
        template[ch] = a * wf
    return template, amps, True


def create_signal_with_known_waveforms(n_channels=4, n_waveforms=2, n_wf_samples=100, duration=5, fs=30000):
    '''
    Creates stereotyped recording, sorting, with waveforms, templates, and max_chans

    Parameters
    ----------
    n_channels: int
        Number of channels of the recording. Must be between 1 and 5: with more
        channels the random amplitudes (each >= 0.2) can never sum below 0.9.
    n_waveforms: int
        Number of distinct waveforms (units).
    n_wf_samples: int
        Number of samples of each waveform.
    duration: float
        Duration of the recording; the number of samples is int(fs * duration).
    fs: float
        Sampling frequency.

    Returns
    -------
    rec: se.NumpyRecordingExtractor
        Recording built from the generated timeseries.
    sort: se.NumpySortingExtractor
        Sorting holding the spike times and labels.
    waveforms: list of np.ndarray
        For each unit, the scaled templates inserted in the timeseries.
    templates: np.ndarray
        One multi-channel template per unit.
    max_chans: list
        For each unit, the index of the channel with the largest amplitude factor.
    amplitudes: list of np.ndarray
        For each unit, the minimum value of every inserted waveform.

    Raises
    ------
    ValueError
        If no valid template can exist for n_channels (previously this case
        looped forever).
    '''
    if n_channels < 1 or (n_channels - 1) * _AMP_MIN >= _AMP_SUM_LIMIT:
        raise ValueError(
            'n_channels must be between 1 and 5 to generate templates, got {}'.format(n_channels))

    a_min = [-200, -50]
    a_max = [10, 50]

    # gen waveforms
    wfs = []
    for _ in range(n_waveforms):
        amp_min = np.random.randint(a_min[0], a_min[1])
        amp_max = np.random.randint(a_max[0], a_max[1])
        wfs.append(create_wf(amp_min, amp_max, n_wf_samples))

    # gen templates
    templates = []
    max_chans = []
    for wf in wfs:
        found = False
        # A single attempt can fail; retry until a full amplitude set is drawn.
        while not found:
            template, amps, found = generate_template_with_random_amps(n_channels, wf)
        templates.append(template)
        max_chans.append(np.argmax(amps))
    templates = np.array(templates)
    n_samples = int(fs * duration)

    # gen spiketrains: one spike every `interval` samples, units interleaved
    interval = 10 * n_wf_samples
    times = np.arange(interval, duration * fs - interval, interval).astype(int)
    labels = np.zeros(len(times)).astype(int)
    for i in range(len(wfs)):
        labels[i::len(wfs)] = i

    timeseries = np.zeros((n_channels, n_samples))
    half = n_wf_samples // 2
    waveforms = []
    amplitudes = []
    for i, tem in enumerate(templates):
        wav = []
        amps = []
        for t in times[labels == i]:
            # Small multiplicative jitter around 1 on every spike.
            scaled = (np.random.randn() * 0.01 + 1) * tem
            # The window is exactly n_wf_samples wide so that odd waveform
            # lengths fit too; for even lengths it equals [t - half, t + half).
            start = t - half
            timeseries[:, start:start + n_wf_samples] = scaled
            wav.append(scaled)
            amps.append(np.min(scaled))
        waveforms.append(np.array(wav))
        amplitudes.append(np.array(amps))

    rec = se.NumpyRecordingExtractor(timeseries=timeseries, sampling_frequency=fs)
    sort = se.NumpySortingExtractor()
    sort.set_times_labels(times=times, labels=labels)
    sort.set_sampling_frequency(fs)

    return rec, sort, waveforms, templates, max_chans, amplitudes


def create_fake_waveforms_with_known_pc():
    '''
    Placeholder, not implemented yet: currently does nothing and returns None.
    '''
    # HINT: start from Gaussians in PC space and stereotyped waveforms and build dataset.
    pass


def create_dumpable_extractors_from_existing(folder, RX, SX):
    '''
    Writes a recording and a sorting to disk and reloads them from there

    Parameters
    ----------
    folder: str or Path
        Folder passed to se.MdaRecordingExtractor.write_recording; the sorting
        is written to 'sorting.npz' inside it.
    RX: RecordingExtractor
        Recording to write. If it has no 'location' channel property, random
        2-D channel locations are set on it (in place) first.
    SX: SortingExtractor
        Sorting to write.

    Returns
    -------
    RX_mda: se.MdaRecordingExtractor
        Recording loaded back from folder.
    SX_npz: se.NpzSortingExtractor
        Sorting loaded back from folder / 'sorting.npz'.
    '''
    folder = Path(folder)

    if 'location' not in RX.get_shared_channel_property_names():
        RX.set_channel_locations(np.random.randn(RX.get_num_channels(), 2))

    se.MdaRecordingExtractor.write_recording(RX, folder)
    RX_mda = se.MdaRecordingExtractor(folder)

    sorting_file = folder / 'sorting.npz'
    se.NpzSortingExtractor.write_sorting(SX, sorting_file)
    SX_npz = se.NpzSortingExtractor(sorting_file)

    return RX_mda, SX_npz