import copy import csv import gc import glob import itertools import logging import os import random import time import traceback import warnings from collections import defaultdict from datetime import datetime from functools import partial from itertools import product, groupby import matplotlib.pyplot as plt import mne import numpy as np # import pandas from mne.beamformer import lcmv from mne.time_frequency import compute_epochs_csd from scipy.ndimage import gaussian_filter1d from scipy.optimize import leastsq, minimize from sklearn import mixture from sklearn.datasets.base import Bunch from sklearn.linear_model import Ridge, RidgeCV, Lasso, LassoCV, ElasticNetCV # from src import dtw from src.utils import utils from src.misc.beamformers import tf_dics as tf from src.preproc import meg from src.preproc.meg import (calc_cov, calc_csd, get_cond_fname, get_file_name, make_forward_solution_to_specific_points, TASKS) warnings.filterwarnings("ignore") try: from tqdm import tqdm except: print('no tqdm!') LINKS_DIR = utils.get_links_dir() SUBJECTS_MEG_DIR = os.path.join(LINKS_DIR, 'meg') SUBJECTS_MRI_DIR = utils.get_link_dir(LINKS_DIR, 'subjects', 'SUBJECTS_DIR') FREESURFER_HOME = utils.get_link_dir(LINKS_DIR, 'freesurfer', 'FREESURFER_HOME') BLENDER_ROOT_DIR = os.path.join(LINKS_DIR, 'mmvt_addon') # Setting frequency bins as in Dalal et al. 2008 # BAD_ELECS = {'noninterference':['RAF4-RAF3', 'RAF5-RAF4', 'RAF8-RAF7', 'RAT8-RAT7'], 'interference': ['RAF8-RAF7', 'RAF4-RAF3']} # BAD_ELECS = ['RAF4-RAF3', 'RAF5-RAF4', 'RAF8-RAF7', 'RAT8-RAT7'] BAD_ELECS = [] OUTLIERS = {'neutral': ['RPT8-RPT7', 'RMT3-RMT2', 'RPT2-RPT1', 'LPT2-LPT1'], 'interference': []} ERROR_RECONSTRUCT_METHODS = ['RMS', 'RMSN', 'rol_corr'] # 'maxabs', 'dtw', 'diff_rms', 'diff' def load_all_subcorticals(subject_meg_fol, sub_corticals_codes_file, cond, from_t, to_t, normalize=True, all_vertices=False, inverse_method='lcmv'): sub_corticals = utils.read_sub_corticals_code_file(sub_corticals_codes_file) meg_data = {} for sub_cortical_index in sub_corticals: sub_cortical, _ = utils.get_numeric_index_to_label(sub_cortical_index, None, FREESURFER_HOME) meg_data_file_name = '{}-{}-{}{}.npy'.format(cond, sub_cortical, inverse_method, '-all-vertices' if all_vertices else '') data = np.load(os.path.join(subject_meg_fol, 'subcorticals', meg_data_file_name)) data = data.T data = data[from_t: to_t] if normalize: data = data * 1/max(data) meg_data[sub_cortical] = data return meg_data def read_vars(events_id, region, read_csd=True, read_cov=True): for event in events_id.keys(): forward, data_cov, noise_cov, noise_csd, data_csd = None, None, None, None, None if not region is None: forward = mne.read_forward_solution(get_cond_fname(FWD_X, event, region=region)) #, surf_ori=True) epochs = mne.read_epochs(get_cond_fname(EPO, event)) evoked = mne.read_evokeds(get_cond_fname(EVO, event), baseline=(None, 0))[0] if read_cov: noise_cov = calc_cov(get_cond_fname(DATA_COV, event), event, epochs, None, 0) data_cov = calc_cov(get_cond_fname(NOISE_COV, event), event, epochs, 0.0, 1.0) if read_csd: noise_csd = calc_csd(NOISE_CSD, event, epochs, -0.5, 0., mode='multitaper', fmin=6, fmax=10, overwrite=False) data_csd = calc_csd(DATA_CSD, event, epochs, 0.0, 1.0, mode='multitaper', fmin=6, fmax=10, overwrite=False) yield event, forward, evoked, epochs, data_cov, noise_cov, data_csd, noise_csd def get_fif_name(subject, fname_format, raw_cleaning_method, constrast, task): root_dir = os.path.join(SUBJECTS_MEG_DIR, task, subject) return partial(get_file_name, root_dir=root_dir, fname_format=fname_format, subject=subject, file_type='fif', raw_cleaning_method=raw_cleaning_method, constrast=constrast) def init_msit(subject, constrast, raw_cleaning_method, region): events_id = dict(interference=1) # dict(interference=1, neutral=2) fname_format = '{subject}_msit_{raw_cleaning_method}_{constrast}_{cond}_1-15-{ana_type}.{file_type}' _get_fif_name = get_fif_name(subject, fname_format, raw_cleaning_method, constrast, 'MSIT') fwd_fname = _get_fif_name('{region}-fwd')#.format(region)) evo_fname = _get_fif_name('ave') epo_fname = _get_fif_name('epo') data_cov_fname = _get_fif_name('data-cov') noise_cov_fnmae = _get_fif_name('noise-cov') return events_id, fwd_fname, evo_fname, epo_fname, data_cov_fname, noise_cov_fnmae def calc_electrode_fwd(subject, electrode, events_id, bipolar=False, overwrite_fwd=False, read_if_exist=False, n_jobs=4): fwd_elec_fnames = [get_cond_fname(FWD_X, cond, region=electrode) for cond in events_id.keys()] if not np.all([os.path.isfile(fname) for fname in fwd_elec_fnames]) or overwrite_fwd: names, pos, org_pos = get_electrodes_positions(subject, bipolar) if bipolar and org_pos is None: raise Exception('bipolar and org_pos is None!') index = np.where(names==electrode)[0][0] elec_pos = org_pos[index] if bipolar else np.array([pos[index]]) fwds = make_forward_solution_to_specific_points(subject, events_id, elec_pos, electrode, EPO, EVO, FWD_X, n_jobs=n_jobs, usingEEG=True) else: if read_if_exist: fwds = [] for fwd_fname in fwd_elec_fnames: fwds.append(mne.read_forward_solution(fwd_fname)) else: fwds = [None] * len(events_id) return fwds[0] if len(events_id) == 1 else fwds def calc_electrodes_fwd(subject, electrodes, events_id, bipolar=False, overwrite_fwd=False, read_if_exist=False, n_jobs=4): region = 'bipolar_electrodes' if bipolar else 'regular_electrodes' fwd_elec_fnames = [get_cond_fname(FWD_X, cond, region=region) for cond in events_id.keys()] if not np.all([os.path.isfile(fname) for fname in fwd_elec_fnames]) or overwrite_fwd: names, pos, org_pos = get_electrodes_positions(subject, bipolar) if bipolar and org_pos is None: raise Exception('bipolar and org_pos is None!') elec_pos = org_pos if bipolar else np.array(pos) fwds = make_forward_solution_to_specific_points(subject, events_id, elec_pos, region, EPO, EVO, FWD_X, n_jobs=n_jobs, usingEEG=True) else: if read_if_exist: fwds = [] for fwd_fname in fwd_elec_fnames: fwds.append(mne.read_forward_solution(fwd_fname)) else: fwds = [None] * len(events_id) return fwds[0] if len(events_id) == 1 else fwds # def calc_bipolar_electrode_fwd(subject, electrode, n_jobs=4): # electrodes_biplor, elecs_pos_biplor, elecs_pos_org_biplor = get_electrodes_positions(subject, bipolar=True) # for elec, elec_pos, elec_pos_org in zip(electrodes_biplor, elecs_pos_biplor, elecs_pos_org_biplor): # if elec==electrode: # fwd = make_forward_solution_to_specific_points(events_id, elec_pos_org, elec, EPO, FWD_X, # n_jobs=n_jobs, usingEEG=True) # return fwd def calc_all_electrodes_fwd(subject, events_id, overwrite_fwd=False, n_jobs=6): electrodes, elecs_pos, _ = get_electrodes_positions(subject, bipolar=False) electrodes_biplor, elecs_pos_biplor, elecs_pos_org_biplor = get_electrodes_positions(subject, bipolar=True) for elec, elec_pos in zip(electrodes, elecs_pos): if not check_if_fwd_exist(elec, events_id) or overwrite_fwd: print('make forward solution for {}'.format(elec)) try: make_forward_solution_to_specific_points( subject, events_id, [np.array(elec_pos)], elec, EPO, EVO, FWD_X, n_jobs=n_jobs, usingEEG=True) except: print(traceback.format_exc()) for elec, elec_pos, elec_pos_org in zip(electrodes_biplor, elecs_pos_biplor, elecs_pos_org_biplor): if not check_if_fwd_exist(elec, events_id) or overwrite_fwd: print('make forward solution for {}'.format(elec)) try: make_forward_solution_to_specific_points(subject, events_id, elec_pos_org, elec, EPO, EVO, FWD_X, n_jobs=n_jobs, usingEEG=True) except: print(traceback.format_exc()) def check_if_fwd_exist(elec, events_id): fwd_elec_fnames = [get_cond_fname(FWD_X, cond, region=elec) for cond in events_id.keys()] return np.all([os.path.isfile(fname) for fname in fwd_elec_fnames]) def get_electrodes_positions(subject, bipolar): subject_mri_dir = os.path.join(SUBJECTS_MRI_DIR, subject) positions_file_name = 'electrodes{}_positions.npz'.format('_bipolar' if bipolar else '') positions_file_name = os.path.join(subject_mri_dir, 'electrodes', positions_file_name) d = np.load(positions_file_name) return d['names'], d['pos'], d['pos_org'] if 'pos_org' in d else None def call_lcmv(electrode, forward, data_cov, noise_cov, evoked, epochs, cond, data_fname='', all_verts=False, pick_ori=None, whitened_data_cov_reg=0.01): if data_fname == '': data_fname = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'lcmv_{}-{}-{}.npy'.format(cond, electrode, pick_ori)) if not os.path.isfile(data_fname): stc = lcmv(evoked, forward, noise_cov, data_cov, reg=whitened_data_cov_reg, pick_ori=pick_ori, rank=None) data = stc.data[:len(stc.vertices)] data = data.T if all_verts else data.mean(0).T np.save(data_fname[:-4], data) else: data = np.load(data_fname) return data def call_dics(forward, evoked, bipolar, noise_csd, data_csd, cond='', data_fname='', all_verts=False, overwrite=False, electrode=''): from mne.beamformer import dics if data_fname == '': fmin, fmax = map(int, np.round(data_csd.frequencies)[[0, -1]]) data_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular') data_fname = os.path.join(data_fol, 'dics_{}-{}-{}-{}.npy'.format(cond, electrode, fmin, fmax)) if not os.path.isfile(data_fname) or overwrite: stc = dics(evoked, forward, noise_csd, data_csd) data = stc.data[:len(stc.vertices)] data = data.T if all_verts else data.mean(0).T np.save(data_fname[:-4], data) else: data = np.load(data_fname) return data def plot_activation(events_id, meg_data, elec_data, electrode, from_t, to_t, method): xaxis = range(from_t, to_t) T = len(xaxis) f, axs = plt.subplots(2, sharex=True)#, sharey=True) for cond, ax in zip(events_id.keys(), axs): ax.plot(xaxis, meg_data[cond][:T], label='MEG', color='r') if not elec_data is None: ax.plot(xaxis, elec_data[cond][:T], label='electrode', color='b') ax.axvline(x=0, linestyle='--', color='k') ax.legend() ax.set_title('{}-{}'.format(cond, method)) plt.xlabel('Time(ms)') plt.show() def plot_activation_cond(cond, meg_data, elec_data, electrode, T, method='', do_plot=True, plt_fname=''): xaxis = range(T) plt.figure() if type(meg_data) is dict: colors = utils.get_spaced_colors(len(meg_data.keys()) + 1) colors = set(colors) - set(['b']) for (key, data), color in zip(meg_data.items(), colors): plt.plot(xaxis, meg_data[key], label=key, color=color) else: plt.plot(xaxis, meg_data, label='MEG', color='r') if not elec_data is None: plt.plot(xaxis, elec_data, label=electrode, color='b') plt.axvline(x=0, linestyle='--', color='k') plt.legend() plt.title('{} {}'.format(cond, method)) plt.xlabel('Time(ms)') if do_plot: plt.show() if plt_fname != '': plt.savefig(plt_fname) def plot_activation_options(meg_data_all, elec_data, electrode, T, elec_opts=False): xaxis = range(T) f, axs = plt.subplots(len(meg_data_all.keys()), sharex=True)#, sharey=True) if len(meg_data_all.keys())==1: axs = [axs] for ind, ((params_option), ax) in enumerate(zip(meg_data_all.keys(), axs)): ax.plot(xaxis, meg_data_all[params_option], label='MEG', color='r') elec = elec_data if not elec_opts else elec_data[params_option] label = electrode if not elec_opts else params_option ax.plot(xaxis, elec, label=label, color='b') ax.axvline(x=0, linestyle='--', color='k') plt.setp(ax.get_yticklabels(), visible=False) ax.legend() plt.setp(ax.get_legend().get_texts(), fontsize='20') plt.xlabel('Time(ms)', fontsize=20) plt.show() def plot_activation_one_fig(cond, meg_data_all, elec_data, electrode, T): xaxis = range(T) colors = utils.get_spaced_colors(len(meg_data_all.keys())) plt.figure() for (label, meg_data), color in zip(meg_data_all.items(), colors): plt.plot(xaxis, meg_data, label=label, color=color) plt.plot(xaxis, elec_data, label=electrode, color='k') plt.axvline(x=0, linestyle='--', color='k') plt.legend() # plt.set_title('{}-{}'.format(cond, params_option), fontsize=20) plt.xlabel('Time(ms)', fontsize=20) plt.show() def plot_all_vertices(cond, electrode, meg_data, elec_data, from_t, to_t, from_i, to_i, params_option): xaxis = np.arange(from_i, to_i) - 500 plt.plot(xaxis, meg_data[from_i: to_i], label=params_option, color='r') plt.plot(xaxis, elec_data[cond][from_i: to_i], label=electrode, color='b') plt.axvline(x=0, linestyle='--', color='k') plt.legend() plt.title('{}-{}'.format(cond, params_option), fontsize=20) plt.xlabel('Time(ms)', fontsize=20) plt.xlim((from_t, to_t)) plt.show() def test_pick_ori(forward, data_cov, noise_cov, evoked, epochs): meg_data = {} for pick_ori, key in zip([None, 'max-power'], ['None', 'max-power']): meg_data[key] = call_lcmv(forward, data_cov, noise_cov, evoked, epochs, pick_ori=pick_ori) return meg_data def test_whitened_data_cov_reg(forward, data_cov, noise_cov, evoked, epochs): meg_data = {} for reg in [0.001, 0.01, 0.1]: meg_data[reg] = call_lcmv(forward, data_cov, noise_cov, evoked, epochs, whitened_data_cov_reg=reg) return meg_data def test_all_verts(forward, data_cov, noise_cov, evoked, epochs): return dict(all=call_lcmv(forward, data_cov, noise_cov, evoked, epochs, all_verts=True)) def normalize_meg_data(meg_data, elec_data, from_t, to_t, sigma=0, norm_max=True): if sigma != 0: meg_data = gaussian_filter1d(meg_data, sigma) meg_data = meg_data[from_t:to_t] if norm_max: meg_data *= 1/max(meg_data) if not elec_data is None: meg_data -= meg_data[0] - elec_data[0] return meg_data def normalize_elec_data(elec_data, from_t, to_t): elec_data = elec_data[from_t:to_t] elec_data = elec_data - min(elec_data) elec_data *= 1/max(elec_data) return elec_data # def smooth_meg_data(meg_data): # meg_data_all = {} # for sigma in [8, 10, 12]: # meg_data_all[sigma] = gaussian_filter1d(meg_data, sigma) # return meg_data_all # def check_electrodes(): # meg_data_all, elec_data_all = {}, {} # electrodes = ['LAT1', 'LAT2', 'LAT3', 'LAT4'] # vars = read_vars(events_id, None) # for cond, forward, evoked, epochs, data_cov, noise_cov, data_csd, noise_csd in vars: # for electrode in electrodes: # calc_electrode_fwd(MRI_SUBJECT, electrode, events_id, bipolar, overwrite_fwd=False) # forward = mne.read_forward_solution(get_cond_fname(FWD_X, cond, region=electrode)) #, surf_ori=True) # elec_data = load_electrode_msit_data(bipolar, electrode, BLENDER_SUB_FOL, positive=True, normalize_data=True) # meg_data = call_dics(forward, evoked, bipolar, noise_csd, data_csd, cond) # elec_data_norm, meg_data_norm = normalize_data(elec_data[cond], meg_data, from_t, to_t) # meg_data_norm = gaussian_filter1d(meg_data_norm, 10) # meg_data_all[electrode] = meg_data_norm # elec_data_all[electrode] = elec_data_norm # plot_activation_options(meg_data_all, elec_data_all, electrodes, 500, elec_opts=True) def get_dics_fname(cond, bipolar, electrode, fmin, fmax): dics_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular') return os.path.join(dics_fol, 'dics_{}-{}-{}-{}.npy'.format(cond, electrode, fmin, fmax)) def calc_dics_freqs_csd(events_id, electrodes, bipolar, from_t, to_t, time_split, freqs_bands, overwrite_csds=False, overwrite_dics=False, gk_sigma=3, njobs=6): vars = list(read_vars(events_id, None, read_csd=False, read_cov=False)) # electrodes = get_all_electrodes_names(bipolar) dics_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular') utils.make_dir(dics_fol) for cond, _, evoked, epochs, _, _, _, _ in vars: event_id = {cond: events_id[cond]} all_electrodes_freqs = list(product(electrodes, freqs_bands)) electrodes_freqs = [(el, (fmin, fmax)) for el, (fmin, fmax) in all_electrodes_freqs \ if not os.path.isfile(get_dics_fname(cond, bipolar, el, fmin, fmax)) \ or overwrite_csds or overwrite_dics] np.random.shuffle(electrodes_freqs) chunks = utils.chunks(electrodes_freqs, len(electrodes_freqs) / njobs) params = [(event_id, chunk, evoked, epochs, bipolar, overwrite_csds, overwrite_dics, gk_sigma) for chunk in chunks] utils.run_parallel(_par_calc_dics_chunk_electrodes, params, njobs) def calc_all_fwds(events_id, electrodes, bipolar, from_t, to_t, time_split, overwrite_fwd=False, njobs=6): vars = list(read_vars(events_id, None, read_csd=False)) electrodes = get_all_electrodes_names(bipolar) for cond, _, evoked, epochs, data_cov, noise_cov, _, _ in vars: event_id = {cond: events_id[cond]} for electrode in electrodes: calc_electrode_fwd(MRI_SUBJECT, electrode, event_id, bipolar, overwrite_fwd=overwrite_fwd, read_if_exist=(not overwrite_fwd), n_jobs=njobs) def check_bipolar_fwd(forward, event_id, electrode, bipolar): vertno = utils.fwd_vertno(forward) fwd_was_wrong = True if bipolar and vertno != 2: fwd_was_wrong = True forward = calc_electrode_fwd(MRI_SUBJECT, electrode, event_id, bipolar, overwrite_fwd=True) vertno = utils.fwd_vertno(forward) if vertno != 2: raise Exception('vertno != 2') return forward, fwd_was_wrong def load_all_dics(freqs_bins, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=False, njobs=2): meg_data_dic = {} cond = utils.first_key(event_id) freqs_bins_sorted = sorted(freqs_bins) for electrode in electrodes: dics_files = glob.glob(os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular', 'dics_{}-{}-*.npy'.format(cond, electrode))) if len(dics_files) < len(freqs_bins_sorted) and dont_calc_new_csd: print('{} does not have all the csd files'.format(electrode)) continue params = [(event_id, None, None, None, None, None, electrode, bipolar, from_t, to_t, False, False, gk_sigma, True, dont_calc_new_csd, ifreq, fmin, fmax) for ifreq, (fmin, fmax) in enumerate(freqs_bins_sorted)] results = utils.run_parallel(_par_calc_dics_frqs, params, njobs) meg_data_arr = np.zeros((len(results), to_t-from_t)) data_is_none = False for data, ifreq, fmin, fmax in results: if data is None: print('{}, {}-{}: data is None!'.format(electrode, fmin, fmax)) data_is_none = True break meg_data_arr[ifreq, :] = data if not data_is_none: meg_data_dic[electrode] = meg_data_arr return meg_data_dic def reconstruct_meg(events_id, freqs_bins, electrodes, from_t, to_t, time_split, gk_sigma=3, bipolar=True, plot_elecs=False, title=None, predicted_electrodes=[], plot_results=False, dont_calc_new_csd=True, vars=None, all_meg_data=None, res_ind=0, elec_data=None, optimization_method='leastsq', error_calc_method='RMS', optimization_params={}, save_plots_in_pred_electrode_fol=False, do_save_plots=False, uuid='', njobs=6): root_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular') opt_ps, errors, opt_cv_params = {}, {}, {} time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 if not all_dics_are_already_computed(events_id, freqs_bins, electrodes, predicted_electrodes, dont_calc_new_csd, root_fol): return None, None, None if elec_data is None: elec_data = load_electrodes_data(events_id, bipolar, electrodes, from_t, to_t, subtract_min=True, normalize_data=True) if plot_elecs: plot_electrodes(events_id, electrodes, elec_data) vars = prepare_vars(events_id, vars, dont_calc_new_csd) for cond, _, evoked, epochs, data_cov, noise_cov, _, _ in vars: event_id = {cond: events_id[cond]} meg_data_dic = calc_meg_data_dic(event_id, evoked, epochs, data_cov, noise_cov, all_meg_data, freqs_bins, electrodes, predicted_electrodes, bipolar, from_t, to_t, gk_sigma, dont_calc_new_csd, root_fol, njobs) if not meg_data_dic: continue ps, cvs_parameters = [], [] for from_t, to_t in zip(time_split, time_split + time_diff): p, cv_parameters = calc_optimization_features(optimization_method, freqs_bins, cond, meg_data_dic, elec_data, electrodes, from_t, to_t, optimization_params) ps = p if len(ps) == 0 else np.vstack((ps, p)) cvs_parameters.append(cv_parameters) opt_ps[cond] = ps opt_cv_params[cond] = cvs_parameters errors[cond] = calc_reconstruction_errors(ps, cond, electrodes + predicted_electrodes, elec_data, meg_data_dic, time_split, time_diff, error_calc_method) if plot_results: plot_leastsq_results(meg_data_dic, cond, elec_data, electrodes, opt_ps[cond], time_split, optimization_method, predicted_electrodes, same_ps=True, do_plot=True, title=title, res_ind=res_ind, save_in_pred_electrode_fol=save_plots_in_pred_electrode_fol, do_save=do_save_plots, uuid=uuid) return errors, opt_ps, opt_cv_params def prepare_vars(events_id, vars, dont_calc_new_csd): if dont_calc_new_csd: vars = [(event, None, None, None, None, None, None, None) for event in events_id.keys()] else: if vars is None: vars = list(read_vars(events_id, None, read_csd=False)) return vars def all_dics_are_already_computed(events_id, freqs_bins, electrodes, predicted_electrodes, dont_calc_new_csd, root_fol): # Check if all the dics files are already computed for electrode in electrodes + predicted_electrodes: dics_files_num = np.array([len(list(glob.glob(os.path.join(root_fol, 'dics_{}-{}-*.npy'.format(cond, electrode))))) for cond in events_id]) if np.any(dics_files_num < len(freqs_bins)) and dont_calc_new_csd: mes = '{}: not all the csds are calculated {}/{} for {}'.format(electrode, dics_files_num, len(freqs_bins), events_id.keys()) print(mes) logging.error(mes) return False return True def plot_electrodes(events_id, electrodes, elec_data): plt.figure() for electrode in electrodes: for cond in events_id: plt.plot(elec_data[electrode][cond], label='{} {}'.format(cond, electrode)) plt.legend() plt.show() def calc_meg_data_dic(event_id, evoked, epochs, data_cov, noise_cov, all_meg_data, freqs_bins, electrodes, predicted_electrodes, bipolar, from_t, to_t, gk_sigma, dont_calc_new_csd, root_fol, njobs): cond = utils.first_key(event_id) meg_data_dic = {} for electrode in electrodes + predicted_electrodes: if dont_calc_new_csd: if not all_meg_data is None: dics_files_num = len(list(glob.glob(os.path.join(root_fol, 'dics_{}-{}-*.npy'.format(cond, electrode))))) if dics_files_num < len(freqs_bins): logging.error('dics_files_num ({}) < len(CSD_FREQS) ({})!'.format(dics_files_num, len(freqs_bins))) continue else: params = [(event_id, None, None, None, None, None, electrode, bipolar, from_t, to_t, False, False, gk_sigma, True, dont_calc_new_csd, ifreq, fmin, fmax) for ifreq, (fmin, fmax) in enumerate(freqs_bins)] else: params = [(event_id, None, evoked, epochs, data_cov, noise_cov, electrode, bipolar, from_t, to_t, False, False, gk_sigma, True, dont_calc_new_csd, ifreq, fmin, fmax) for ifreq, (fmin, fmax) in enumerate(freqs_bins)] try: if all_meg_data is None: results = utils.run_parallel(_par_calc_dics_frqs, params, njobs) meg_data_arr = [] for data, fmin, fmax in results: meg_data_arr = data if len(meg_data_arr)==0 else np.vstack((meg_data_arr, data)) meg_data_dic[electrode] = meg_data_arr else: meg_data_dic[electrode] = all_meg_data[electrode] except: print('check_freqs: Error in gathering the csd files!') print(traceback.format_exc()) logging.error(traceback.format_exc()) if electrode in meg_data_dic: meg_data_dic.pop(electrode) return meg_data_dic def calc_optimization_features(optimization_method, freqs_bins, cond, meg_data_dic, elec_data, electrodes, from_t, to_t, optimization_params={}): # scorer = make_scorer(rol_corr, False) cv_parameters = [] if optimization_method in ['Ridge', 'RidgeCV', 'Lasso', 'LassoCV', 'ElasticNet', 'ElasticNetCV']: # vstack all meg data, such that X.shape = T*n X F, where n is the electrodes num # Y is T*n * 1 X = np.hstack((meg_data_dic[electrode][:, from_t:to_t] for electrode in electrodes)) Y = np.hstack((elec_data[electrode][cond][from_t:to_t] for electrode in electrodes)) funcs_dic = {'Ridge': Ridge(alpha=0.1), 'RidgeCV':RidgeCV(np.logspace(0, -10, 11)), # scoring=scorer 'Lasso': Lasso(alpha=1.0/X.shape[0]), 'LassoCV':LassoCV(alphas=np.logspace(0, -10, 11), max_iter=1000), 'ElasticNetCV': ElasticNetCV(alphas= np.logspace(0, -10, 11), l1_ratio=np.linspace(0, 1, 11))} clf = funcs_dic[optimization_method] clf.fit(X.T, Y) p = clf.coef_ if len(p) != len(freqs_bins): raise Exception('{} (len(clf.coef)) != {} (len(freqs_bin))!!!'.format(len(p), len(freqs_bins))) if optimization_method in ['RidgeCV', 'LassoCV']: cv_parameters = clf.alpha_ elif optimization_method == 'ElasticNetCV': cv_parameters = [clf.alpha_, clf.l1_ratio_] args = [(meg_pred(p, meg_data_dic[electrode][:, from_t:to_t]), elec_data[electrode][cond][from_t:to_t]) for electrode in electrodes] p0 = leastsq(post_ridge_err_func, [1], args=args, maxfev=0)[0] p = np.hstack((p0, p)) elif optimization_method in ['leastsq', 'dtw', 'minmax', 'diff_rms', 'rol_corr']: args = ([(meg_data_dic[electrode][:, from_t:to_t], elec_data[electrode][cond][from_t:to_t]) for electrode in electrodes], optimization_params) p0 = np.ones((1, len(freqs_bins)+1)) funcs_dic = {'leastsq': partial(leastsq, func=err_func, x0=p0, args=args), 'dtw': partial(minimize, fun=dtw_err_func, x0=p0, args=args), 'minmax': partial(minimize, fun=minmax_err_func, x0=p0, args=args), 'diff_rms': partial(minimize, fun=min_diff_rms_err_func, x0=p0, args=args), 'rol_corr': partial(minimize, fun=max_rol_corr, x0=p0, args=args)} res = funcs_dic[optimization_method]() p = res[0] if optimization_method=='leastsq' else res.x cv_parameters = optimization_params else: raise Exception('Unknown optimization_method! {}'.format(optimization_method)) return p, cv_parameters def calc_reconstruction_errors(electrode_ps, cond, electrodes, elec_data, meg_data_dic, time_split, time_diff, error_calc_method='RMS', dtw_window=10): errors = {} for electrode in electrodes: err = electrode_reconstruction_error(electrode, elec_data[electrode][cond], electrode_ps, meg_data_dic, error_calc_method, time_split, time_diff, dtw_window=dtw_window) errors[electrode] = err return errors def electrode_reconstruction_error(electrode, electrode_data, electrode_ps, meg_data_dic, error_calc_method, time_split, time_diff, dtw_window=10, rol_corr_window=30, meg=None): if meg is None: meg = combine_meg_chunks(meg_data_dic[electrode], electrode_ps, time_split, time_diff) if error_calc_method == 'RMS': err = sum((electrode_data - meg)**2) elif error_calc_method == 'RMSN': err = sum((electrode_data - meg)**2) * 1.0/utils.max_min_diff(electrode_data) elif error_calc_method == 'maxabs': err = maxabs(electrode_data, meg) elif error_calc_method == 'dtw': err = dtw.distance_w(electrode_data, meg, dtw_window) elif error_calc_method == 'diff': err = sum(abs(np.diff(electrode_data) - np.diff(meg))) elif error_calc_method == 'diff_rms': err = diff_rms(electrode_data, meg) elif error_calc_method == 'rol_corr': err = rol_corr(electrode_data, meg, window=rol_corr_window) else: raise Exception('Unreconize error_calc_method! {}'.format(error_calc_method)) return err def meg_pred(p, X): if len(p) == X.shape[0]: return np.dot(p, X) else: return p[0] + np.dot(p[1:], X) def electrode_err_func(p, X, y): return y - meg_pred(p, X) def err_func(p, XY, params): return sum([electrode_err_func(p, X, y) for X, y in XY]) def post_ridge_err_func(p, XY): return sum([y - (p + X) for X, y in XY]) def dtw_err_func(p, XY, params): dists = sum([dtw.distance_w(y, meg_pred(p, X), 10) for X, y in XY]) return dists + np.mean(p**2) def minmax_err_func(p, XY, params): return sum([max(abs(y - meg_pred(p, X))) for X, y in XY]) def min_diff_rms_err_func(p, XY, params): err = 0 for X, y in XY: meg = meg_pred(p, X) err += diff_rms(y, meg) return err def max_rol_corr(p, XY, params): err = 0 window = params.get('window', 30) alpha = params.get('alpha', 1) for X, y in XY: meg = meg_pred(p, X) err += rol_corr(y, meg, window=window, alpha=alpha) #+ 0.0001 * np.sum((p)**2) return err def maxabs(y, meg): return max(abs(y - meg)) * 1/utils.max_min_diff(y) def diff_rms(y, meg): # diffs_sum = sum(abs(np.diff(y) - np.diff(meg))) # diffs_sum = sum(abs(utils.diff_4pc(y) - utils.diff_4pc(meg))) # y = gaussian_filter1d(y, 3) diffs_sum = sum(abs(np.gradient(y) - np.gradient(meg))) rms = np.sum((y-meg)**2) max_abs = max(abs(y-meg)) if max_abs > 0.3: max_abs = np.inf if rms * 1/utils.max_min_diff(y) > 10: rms = np.inf return (diffs_sum + rms + max_abs) * 1/utils.max_min_diff(y) def rol_corr(y, meg, window=30, alpha=5): rol_corr = pandas.rolling_corr(y, meg, window) return (1 - np.nanmean(rol_corr)) * alpha + np.sum((y-meg)**2) * 1/utils.max_min_diff(y) # def rol_corr(y, meg, window=30): # rol_corr = pandas.rolling_corr(y, meg, window) # max_cor = len(y) - window + 1 # sum(~np.isnan(rol_corr)) # alpha = 30 * len(y) / 500 # corr_term = (max_cor - np.nansum(rol_corr)) * alpha / max_cor # rms_term = np.sum((y-meg)**2) * 1/utils.max_min_diff(y) # err = corr_term + rms_term # return err def _par_calc_dics_chunk_electrodes(params_chunck): (event_id, elecs_freqs_chunck, evoked, epochs, bipolar, overwrite_csd, overwrite_dics, gk_sigma) = params_chunck cond = utils.first_key(event_id) data_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular') forwards = {} for electrode, (fmin, fmax) in elecs_freqs_chunck: data_fname = os.path.join(data_fol, 'dics_{}-{}-{}-{}.npy'.format(cond, electrode, fmin, fmax)) csd_fname = os.path.join(data_fol, 'csd_{}-{}-{}-{}.npy'.format(cond, electrode, fmin, fmax)) if not os.path.isfile(data_fname) or overwrite_csd or overwrite_dics: print('compute csd and dics for {} {}-{}'.format(electrode, fmin, fmax)) if electrode not in forwards: fwd_fname = get_cond_fname(FWD_X, cond, region=electrode) forwards[electrode] = mne.read_forward_solution(fwd_fname) if bipolar: forwards[electrode], fwd_was_wrong = check_bipolar_fwd(forwards[electrode], event_id, electrode, bipolar) if not os.path.isfile(csd_fname) or overwrite_csd: noise_csd = compute_epochs_csd(epochs, 'multitaper', tmin=-0.5, tmax=0.0, fmin=fmin, fmax=fmax) data_csd = compute_epochs_csd(epochs, 'multitaper', tmin=0.0, tmax=1.0, fmin=fmin, fmax=fmax) utils.save((data_csd, noise_csd), csd_fname) else: data_csd, noise_csd = utils.load(csd_fname) data = call_dics(forwards[electrode], evoked, bipolar, noise_csd, data_csd, data_fname=data_fname, all_verts=bipolar, overwrite=overwrite_dics, electrode=electrode) if bipolar: if data.shape[1] != 2: raise Exception('Should be 2 sources in the bipolar fwd!') data = np.diff(data).squeeze() np.save(data_fname, data) del noise_csd, data_csd, data gc.collect() else: print('{} already exists'.format(utils.namebase(data_fname))) def _par_calc_dics_frqs(p): event_id, forward, evoked, epochs, data_cov, noise_cov, electrode, bipolar, from_t, to_t,\ overwrite_csd, overwrite_dics, gk_sigma, load_data, dont_calc_new_csd, ifreq, fmin, fmax = p cond = utils.first_key(event_id) data_fname = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular', 'dics_{}-{}-{}-{}.npy'.format(cond, electrode, fmin, fmax)) data = None if not os.path.isfile(data_fname) or overwrite_csd: if dont_calc_new_csd: raise Exception('dont_calc_new_csd flag and not all csds are computed! ({} {} {}-{})'.format(cond, electrode, fmin, fmax)) print('compute csd and dics for {}-{}'.format(fmin, fmax)) noise_csd = compute_epochs_csd(epochs, 'multitaper', tmin=-0.5, tmax=0.0, fmin=fmin, fmax=fmax) data_csd = compute_epochs_csd(epochs, 'multitaper', tmin=0.0, tmax=1.0, fmin=fmin, fmax=fmax) data = call_dics(forward, evoked, bipolar, noise_csd, data_csd, data_fname=data_fname, all_verts=bipolar, overwrite=overwrite_dics, electrode=electrode) if bipolar: if data.shape[1] != 2: raise Exception('Should be 2 sources in the bipolar fwd!') data = np.diff(data).squeeze() np.save(data_fname, data) else: if load_data: data = np.load(data_fname) if not data is None and not from_t is None: data = normalize_meg_data(data, None, from_t, to_t, gk_sigma, norm_max=False) if not data is None: data = data.squeeze() return data, ifreq, fmin, fmax # def comp_lcmv_dics_electrode(events_id, electrode, bipolar): # elec_data = load_electrode_msit_data(bipolar, electrode, BLENDER_SUB_FOL, positive=True, normalize_data=True) # calc_electrode_fwd(MRI_SUBJECT, electrode, events_id, bipolar, overwrite_fwd=False) # vars = read_vars(events_id, electrode) # for cond, forward, evoked, epochs, data_cov, noise_cov, data_csd, noise_csd in vars: # meg_data_lcmv = call_lcmv(forward, data_cov, noise_cov, evoked, epochs, cond, all_verts=True) # meg_data_dics = call_dics(forward, evoked, bipolar, noise_csd, data_csd, cond, all_verts=True) # if bipolar: # meg_data_lcmv = np.diff(meg_data_lcmv).squeeze() # meg_data_dics = np.diff(meg_data_dics).squeeze() # elec_data_norm = normalize_elec_data(elec_data[cond], from_t, to_t) # meg_data_lcmv_norm = normalize_meg_data(meg_data_lcmv, elec_data_norm, from_t, to_t, 3) # meg_data_dics_norm = normalize_meg_data(meg_data_dics, elec_data_norm, from_t, to_t, 3) # plot_activation_cond(cond, {'lcmv': meg_data_lcmv_norm, 'dics': meg_data_dics_norm}, elec_data_norm, electrode, 500) # def check_bipolar_meg(events_id, electrode, bipolar, from_t, to_t): elec_name2, elec_name1 = electrode.split('-') # Warning: load_electrodes_data parameters have been changed! elec_bip_data = load_electrodes_data (True, electrode, BLENDER_SUB_FOL, positive=True, normalize_data=False) elec2_data = load_electrodes_data(False, elec_name2, BLENDER_SUB_FOL, positive=True, normalize_data=False) elec1_data = load_electrodes_data(False, elec_name1, BLENDER_SUB_FOL, positive=True, normalize_data=False) vars = read_vars(events_id, electrode) for cond, forward, evoked, epochs, data_cov, noise_cov, data_csd, noise_csd in vars: max_electrode = max([max(data) for data in [elec_bip_data[cond][from_t:to_t], elec2_data[cond][from_t:to_t], elec1_data[cond][from_t:to_t]]]) elec_bip_data_norm = (elec_bip_data[cond] * 1.0/max_electrode)[from_t:to_t] elec2_data_norm = (elec2_data[cond] * 1.0/max_electrode)[from_t:to_t] elec1_data_norm = (elec1_data[cond] * 1.0/max_electrode)[from_t:to_t] meg_data_lcmv = call_lcmv(forward, data_cov, noise_cov, evoked, epochs, cond, all_verts=True) meg_data_dics = call_dics(forward, evoked, bipolar, noise_csd, data_csd, cond, all_verts=True) for data, method in zip([meg_data_lcmv, meg_data_dics], ['lcmv', 'dics']): data_1 = normalize_meg_data(data[:, 0], elec_bip_data_norm, from_t, to_t, 3) data_2 = normalize_meg_data(data[:, 1], elec_bip_data_norm, from_t, to_t, 3) plt.figure() plt.plot(data_1-data_2, label='{} diff'.format(method)) plt.plot(data_1, label='{} 1'.format(method)) plt.plot(data_2, label='{} 2'.format(method)) plt.plot(elec_bip_data_norm, label=electrode) plt.plot(elec1_data_norm, label=elec_name1) plt.plot(elec2_data_norm, label=elec_name2) plt.plot() plt.legend() plt.title(cond) plt.show() def get_electrodes_parcellation(electrodes, bipolar, include_white_matter=True): parc = defaultdict(dict) parc_fname = os.path.join(electrode_parc_fol(), '{}_laus250_electrodes_all_rois_cigar_r_3_l_4{}.csv'.format( MRI_SUBJECT, '_bipolar_stretch' if bipolar else '')) if os.path.isfile(parc_fname): electrodes_probs = np.genfromtxt(parc_fname, dtype=np.str, delimiter=',') rois = electrodes_probs[0, 1:] for electrode_probs in electrodes_probs[1:, :]: elec_name = electrode_probs[0] for elec_prob, roi in zip(electrode_probs[1:], rois): if float(elec_prob) > 0: if include_white_matter or not 'White-Matter' in roi: parc[elec_name][roi] = elec_prob return parc def get_figs_fol(): return os.path.join(utils.get_figs_fol(), 'meg_electrodes') def electrode_parc_fol(): return os.path.join(utils.get_parent_fol(), 'electrodes_parcellation') # def load_all_electrodes_data(root_fol, bipolar): # d = np.load(os.path.join(root_fol, 'electrodes{}_data.npz'.format('_bipolar' if bipolar else ''))) # data, names, elecs_conditions = d['data'], d['names'], d['conditions'] # return data, names, elecs_conditions def find_significant_electrodes(events_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False): elec_data = load_electrodes_data(events_id, bipolar, subtract_min=False, normalize_data=False) T = to_t - from_t sig_electrodes = defaultdict(list) if do_save and do_plot: fol = os.path.join(SUBJECT_MRI_FOL, 'electrodes', 'figs', 'bipolar' if bipolar else 'regular') utils.delete_folder_files(fol) for cond in events_id.keys(): # cond_id = MEG_ELEC_CONDS_TRANS[cond] # cond = EVENTS_TRANS_INV[cond] for electrode in elec_data.keys(): if electrode in BAD_ELECS: continue org_data = elec_data[electrode][cond] data_std = np.std(org_data[:from_t]) data_mean = np.mean(org_data[:from_t]) data = org_data[from_t:to_t] - data_mean sig = False for stds_num, sig_len in [(3, 30), (4, 20), (5, 10)]: sig_indices = np.where((data > data_mean + stds_num * data_std) | (data < data_mean - stds_num * data_std))[0] diff = np.diff(sig_indices) sig = sig or max(map(len, ''.join([str(x==y)[0] for (x,y) in zip(diff[:-1], diff[1:])]).split('F'))) > sig_len if sig: dics_files = glob.glob(os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular', 'dics_{}-{}-*.npy'.format(cond, electrode))) if len(dics_files) > 0: sig_electrodes[cond].append(electrode) # print(cond, name) # print('dic file exist for {}-{} sig'.format(cond, name)) if do_plot or (plot_only_sig and sig): plt.figure() plt.plot(data - data_mean, 'b') plt.plot((0, T), (3 * data_std, 3 * data_std), 'r--') plt.plot((0, T), (-3 * data_std, -3 * data_std), 'r--') plt.plot((0, T), (2 * data_std, 2 * data_std), 'y--') plt.plot((0, T), (-2 * data_std, -2 * data_std), 'y--') plt.plot((0, T), (2.5 * data_std, 2.5 * data_std), 'c--') plt.plot((0, T), (-2.5 * data_std, -2.5 * data_std), 'c--') title = '{}-{}{}'.format(cond, electrode, '-sig' if sig else '') plt.title(title) if do_save: plt.savefig(os.path.join(fol, '{}.png'.format(title))) plt.close() else: plt.show() print(title) # fname = 'sig_{}electrodes.pkl'.format('bipolar_' if bipolar else '') # utils.save(sig_electrodes, os.path.join(SUBJECT_MRI_FOL, 'electrodes', fname)) return sig_electrodes def check_freqs_for_all_electrodes(events_id, from_t, to_t, time_split, njobs=4): for bipolar in [True, False]: # electrodes, _, _ = get_electrodes_positions(MRI_SUBJECT, bipolar) electrodes = get_all_electrodes_names(bipolar) for electrode in electrodes: try: reconstruct_meg(events_id, [electrode], from_t, to_t, time_split, gk_sigma=3, bipolar=bipolar, plot_elecs=False, plot_results=False, predicted_electrodes=[], njobs=njobs) except: pass def learn_and_pred(events_id, bipolar, from_t, to_t, time_split): sig = find_significant_electrodes(events_id, bipolar, from_t, to_t) electrodes, predicted_electrodes = [],[] for k, (cond, name) in enumerate(sig.items()): if EVENTS_TRANS[cond] in events_id.keys(): if k%3 == 0: predicted_electrodes.append(name) else: electrodes.append(name) reconstruct_meg(events_id, electrodes, from_t, to_t, time_split, predicted_electrodes=predicted_electrodes, gk_sigma=3, bipolar=bipolar, njobs=1) def find_fit(events_id, bipolar, from_t, to_t, time_split, gk_sigma=3, err_threshold=np.inf, plot_results=False, njobs=3): sig_electrodes = find_significant_electrodes(events_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) bad_channels = {} for cond in events_id: event_id = {cond: events_id[cond]} if not cond in bad_channels: bad_channels[cond] = [] print('calc leastsq for {}'.format(cond)) electrodes = list(set(sig_electrodes[cond]) - set(bad_channels[cond])) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=False, normalize_data=False) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) errors, ps, cv_params = {}, {}, {} for electrode in electrodes: elec_errs, elec_ps, opt_cv_params = reconstruct_meg(event_id, [electrode], from_t, to_t, time_split, plot_results=plot_results, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=njobs) if elec_errs: print(electrode, elec_errs[cond][electrode]) errors[electrode] = elec_errs[cond][electrode] ps[electrode] = elec_ps[cond] cv_params[electrode] = opt_cv_params[cond] utils.save((errors, ps, cv_params), get_pkl_file('{}_leastsq_time_split_{}.pkl'.format(cond, len(time_split)))) def analyze_leastsq_results(events_id, time_split): for cond in events_id: print(cond) errors, ps = np.load(get_pkl_file('{}_leastsq_time_split_{}.pkl'.format(cond, len(time_split)))) electrodes = errors.keys() print(sorted(ps.keys())) print(cond, len(errors)) X = [] for elec, p in ps.items(): X = p if len(X)==0 else np.vstack((X, p)) utils.plot_3d_PCA(X, electrodes) x_pca = utils.calc_PCA(X, n_components=3) res, best_gmm, bic = utils.calc_clusters_bic(X, 10) # means = res['spherical'][6].means_ best_gmm = res['spherical'][6] utils.plot_3d_scatter(x_pca, classifier=best_gmm) print('sdf') def find_best_groups(event_id, bipolar, from_t, to_t, time_split, err_threshold=7, groups_panelty=5, only_sig_electrodes=False, electrodes_positive=True, electrodes_normalize=True, gk_sigma=3, njobs=4): cond = utils.first_key(event_id) if only_sig_electrodes: sig_electrodes = find_significant_electrodes(event_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) all_electrodes = sig_electrodes[cond] else: all_electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(event_id, bipolar, all_electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, all_electrodes, from_t, to_t, gk_sigma, njobs=njobs) run_min_score = np.inf run_num = 1 uuid = utils.rand_letters(5) print('find_best_groups', cond, err_threshold) while True: electrodes = set(list(all_electrodes)) new_electrode, elec_ps, elec_err = pick_new_electrode(event_id, bipolar, from_t, to_t, time_split, electrodes, meg_data_dic, elec_data, njobs) electrodes.remove(new_electrode) groups_errs = [[elec_err]] groups_ps = [elec_ps] groups = [[new_electrode]] while len(electrodes) > 0: new_electrode, new_err, new_ps = find_best_new_electrode(groups[-1], electrodes, event_id, meg_data_dic, elec_data, from_t, to_t, time_split, bipolar, njobs) if new_err < err_threshold: groups[-1].append(new_electrode) groups_errs[-1].append(new_err) groups_ps[-1] = new_ps else: print('new group!') # for debug: plot_leastsq_results(meg_data_dic, cond, elec_data, groups[-1], groups_ps[-1], time_split, same_ps=True, do_plot=True, do_save=True, uuid=uuid) new_electrode, elec_ps, elec_err = pick_new_electrode(event_id, bipolar, from_t, to_t, time_split, electrodes, meg_data_dic, elec_data, njobs) groups.append([new_electrode]) groups_errs.append([elec_err]) groups_ps.append(elec_ps) print(groups[-1], groups_errs[-1]) electrodes.remove(new_electrode) run_score = sum(map(np.mean, groups_errs)) + len(groups) * groups_panelty if run_score < run_min_score: print('new min err!') run_min_score = run_score utils.save((groups, groups_ps, groups_errs), get_pkl_file( '{}_find_best_groups_{}_{}_{}_gp{}.pkl'.format(cond, len(time_split), err_threshold, 'only_sig' if only_sig_electrodes else 'all', groups_panelty))) print('{} run: {}, run score: {}, run min score: {}, groups err: {}, threshold: {}'.format( cond, run_num, run_score, run_min_score, map(np.mean, groups_errs), err_threshold)) utils.save((groups, groups_ps, groups_errs), get_pkl_file( '{}_find_best_groups_run_{}_{}_{}_{}_gp{}_{}.pkl'.format(cond, run_num, len(time_split), err_threshold, 'only_sig' if only_sig_electrodes else 'all', groups_panelty, uuid))) run_num += 1 def pick_new_electrode(event_id, bipolar, from_t, to_t, time_split, electrodes, meg_data_dic, elec_data, njobs): cond = utils.first_key(event_id) new_electrode = random.sample(electrodes, 1)[0] elec_errors, elec_ps, opt_cv_params = reconstruct_meg(event_id, [new_electrode], from_t, to_t, time_split, plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=njobs) print(new_electrode, elec_errors[cond][new_electrode]) return new_electrode, elec_ps[cond], elec_errors[cond][new_electrode], opt_cv_params[cond] def find_best_new_electrode(group_electrodes, other_electrodes, event_id, freqs_bins, meg_data_dic, elec_data, from_t, to_t, time_split, bipolar, njobs): cond = utils.first_key(event_id) errors, ps = {}, {} for electrode in other_electrodes: elec_errors, elec_ps, opt_cv_params = reconstruct_meg(event_id, freqs_bins, [electrode] + group_electrodes, from_t, to_t, time_split, plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=njobs) # for debug # plot_leastsq_results(meg_data_dic, cond, elec_data, [electrode] + group_electrodes, elec_ps[cond], time_split, same_ps=True, do_plot=True) errors[electrode] = max([err for elc, err in elec_errors[cond].items()]) ps[electrode] = elec_ps[cond] best_electrode = min(errors, key=errors.get) min_err = errors[best_electrode] best_ps = ps[best_electrode] return best_electrode, min_err, best_ps def find_best_predictive_subset(event_id, bipolar, freqs_bins, from_t, to_t, time_split, k=4, only_sig_electrodes=False, check_only_pred_score=True, only_from_same_lead=False, electrodes_positive=False, electrodes_normalize=False, electrodes_subtract_mean=False, gk_sigma=3, error_threshold=20, uuid_len=5, optimization_method='Ridge', error_calc_method='RMS', do_plot_results=False, do_plot_all_results=False, do_save_partial_results=True, combs=None, optimization_params={}, vebrose=True, meg_data_dic=None, elec_data=None, all_electrodes=None, save_results=True, njobs=4): if vebrose: print('find_best_predictive_subset:\nk={}, optimization_method={} '.format(k, optimization_method) + 'error_threshold={}, electrodes_positive={} '.format(error_threshold, str(electrodes_positive)[0]) + 'electrodes_normalize={}'.format(str(electrodes_normalize)[0])) if only_from_same_lead and only_sig_electrodes: raise Exception("Can't handle only_from_same_lead and only_sig_electrodes!") cond = utils.first_key(event_id) uuid = results_fol = '' if save_results: results_fol = get_results_fol(optimization_method, electrodes_normalize, electrodes_positive) utils.make_dir(results_fol) uuid, output_file = bps_find_unique_results_name(cond, k, results_fol, optimization_params, uuid_len) if do_save_partial_results: utils.make_dir(os.path.join(results_fol, uuid)) if all_electrodes is None: if only_sig_electrodes: sig_electrodes = find_significant_electrodes(event_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) all_electrodes = sig_electrodes[cond] else: all_electrodes = get_all_electrodes_names(bipolar) if elec_data is None: elec_data = load_electrodes_data(event_id, bipolar, all_electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize, subtract_mean=electrodes_subtract_mean) if len(set(all_electrodes) - set(elec_data.keys())) > 0: print('data_electrodes_set - all_electrodes_set:') print(set(all_electrodes) - set(elec_data.keys())) raise Exception('Not the same electrodes in all_electrodes and electodes_data!') if meg_data_dic is None: meg_data_dic = load_all_dics(freqs_bins, event_id, bipolar, all_electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=True, njobs=njobs) take_only_first_prediction = not combs is None if combs is None: if only_from_same_lead: combs = get_lead_groups(k, bipolar) else: electrodes = set(meg_data_dic.keys()) combs = list(itertools.combinations(electrodes, k)) N = len(combs) # np.random.shuffle(combs) combs_chuncked = utils.chunks(combs, int(N / njobs)) params = [(comb_chuncked, event_id, elec_data, meg_data_dic, freqs_bins, from_t, to_t, time_split, bipolar, k, check_only_pred_score, vebrose, optimization_method, error_calc_method, optimization_params, error_threshold, run, int(N / njobs), uuid, do_plot_results, do_plot_all_results, do_save_partial_results, results_fol, take_only_first_prediction) for run, comb_chuncked in enumerate(combs_chuncked)] runs_results = utils.run_parallel(_find_best_predictive_subset_parallel, params, njobs) all_results = [] for run_results in runs_results: all_results.extend(run_results) if save_results: if vebrose: print('saving results in {}'.format(output_file)) utils.save(all_results, output_file) return all_results def bps_find_unique_results_name(cond, k, results_fol, optimization_params, uuid_len=10): uuid = utils.rand_letters(uuid_len) # find unique file name params_suffix = utils.params_suffix(optimization_params) output_file = os.path.join(results_fol, 'bps_{}_{}_{}{}.pkl'.format(cond, k, uuid, params_suffix)) while os.path.isfile(output_file): uuid = utils.rand_letters(uuid_len) output_file = os.path.join(results_fol, 'bps_{}_{}_{}{}.pkl'.format(cond, k, uuid, params_suffix)) return uuid, output_file def bps_find_unique_thread_name(thread_fol, uuid_len=5): thread_uuid = utils.rand_letters(uuid_len) thred_output = os.path.join(thread_fol, '{}.pkl'.format(thread_uuid)) while os.path.isfile(thred_output): thread_uuid = utils.rand_letters(uuid_len) thred_output = os.path.join(thread_fol, '{}.pkl'.format(thread_uuid)) return thred_output def _find_best_predictive_subset_parallel(params_chunks): (comb_chuncked, event_id, elec_data, meg_data_dic, freqs_bins, from_t, to_t, time_split, bipolar, k, check_only_pred_score, vebrose, optimization_method, error_calc_method, optimization_params, error_threshold, run, N, uuid, do_plot_results, do_plot_all_results, do_save_partial_results, results_fol, take_only_first_prediction) = params_chunks results = [] if do_save_partial_results: thread_fol = os.path.join(results_fol, uuid) thred_output = bps_find_unique_thread_name(uuid, thread_fol) # todo for run, comb in enumerate(comb_chuncked): cond = utils.first_key(event_id) if run % 1000 == 0 and vebrose: print('{}/{}'.format(run, N)) for predicted_electrode in comb: train_electrodes = [e for e in comb if e != predicted_electrode] elec_errors, elec_ps, opt_cv_params = reconstruct_meg(event_id, freqs_bins, train_electrodes, from_t, to_t, time_split, optimization_method = optimization_method, error_calc_method=error_calc_method, optimization_params=optimization_params, predicted_electrodes=[predicted_electrode], plot_results=do_plot_all_results, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) if elec_ps is None or elec_errors is None: print('No meg reconstruction for {}!'.format(train_electrodes + [predicted_electrode])) continue if check_only_pred_score: good_result = elec_errors[cond][predicted_electrode] < error_threshold else: good_result = np.all(np.array(list(elec_errors[cond].values())) < error_threshold) if good_result: if vebrose: errors_str = ','.join(['{:.2f}'.format(elec_errors[cond][elec]) for elec in train_electrodes + [predicted_electrode]]) print('{}->{}: {}'.format(train_electrodes, predicted_electrode, errors_str))#, opt_cv_params[cond])) if do_plot_results: plot_leastsq_results(meg_data_dic, cond, elec_data, train_electrodes, elec_ps[cond], time_split, optimization_method, [predicted_electrode], do_plot=True, do_save=False, uuid=uuid, res_ind=utils.rand_letters(3)) # errors = copy.deepcopy(elec_errors[cond]) # ps = copy.deepcopy(elec_ps[cond]) # train = copy.deepcopy(train_electrodes) results.append((predicted_electrode, train_electrodes, elec_errors[cond], elec_ps[cond], opt_cv_params[cond])) if do_save_partial_results: utils.save(results, thred_output) if take_only_first_prediction: break return results def best_predictive_subset_collect_results(event_id, freqs_bin, bipolar, from_t, to_t, time_split, uuid, k=3, gk_sigma=3, electrodes_positive=False, electrodes_normalize=False, electrodes_subtract_mean=False, error_threshold=10, optimization_method='', elec_data=None, error_calc_method='RMS', sort_only_accoring_to_pred=True, calc_all_errors=False, dtw_window=10, do_save=False, do_plot=True, save_in_pred_electrode_fol=False, write_errors_csv=False, optimization_params={}, do_plot_electrodes=False, error_functions=(), check_only_pred_score=True, do_plot_together=False, njobs=4): print('best_predictive_subset_collect_results:\nk={}, optimization_method={} '.format(k, optimization_method) + 'error_calc_method={} error_threshold={} '.format(error_calc_method, error_threshold) + 'electrodes_positive={} electrodes_normalize={}'.format(str(electrodes_positive)[0], str(electrodes_normalize)[0])) cond = utils.first_key(event_id) electrodes = get_all_electrodes_names(bipolar) if elec_data is None: elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize, subtract_mean=electrodes_subtract_mean) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) if len(error_functions) == 0: error_functions = ERROR_RECONSTRUCT_METHODS results, results_fol = bps_load_results(cond, uuid, k, optimization_method, electrodes_normalize, electrodes_positive, optimization_params) results, errors, results_errors = sort_results(event_id, results, meg_data_dic, elec_data, time_split, error_calc_method, sort_only_accoring_to_pred, calc_all_errors, dtw_window) results_num = sum([np.all(result_errors < error_threshold) for result_errors in results_errors]) print('{} results < error_threshold'.format(results_num)) utils.save((errors, results_errors), os.path.join(results_fol, 'bps_errors_{}_{}_{}.pkl'.format(k, uuid, error_calc_method))) if write_errors_csv: csv_file = open(os.path.join(results_fol, 'bps_errors_{}_{}_{}.csv'.format(k, uuid, error_calc_method)), 'w') csv_writer = csv.writer(csv_file, delimiter=',') csv_writer.writerows([['index', 'predicted', 'train'] + error_functions]) parc = get_electrodes_parcellation(electrodes, bipolar) for res_ind, (result, result_errors) in enumerate(zip(results, results_errors)): good_result = result_errors[-1] < error_threshold if check_only_pred_score else \ np.all(result_errors < error_threshold) if good_result: predicted_electrode, train_electrodes, elecs_errors, ps, cv_params = open_bps_result(result) if res_ind == 0: print('#freqs: {}'.format(ps.shape[1])) print('{}) {}->{}: {}'.format(res_ind, train_electrodes, predicted_electrode, result_errors)) if write_errors_csv: write_error_line(csv_writer, res_ind, cond, elec_data, predicted_electrode, train_electrodes, ps, meg_data_dic, time_split, dtw_window) print_parc_info(parc, predicted_electrode, train_electrodes, k=-1) plot_leastsq_results(meg_data_dic, cond, elec_data, train_electrodes, ps, time_split, predicted_electrodes=[predicted_electrode], same_ps=True, res_ind=res_ind, do_save=do_save, uuid=uuid, do_plot=do_plot, optimization_method=optimization_method, save_in_pred_electrode_fol=save_in_pred_electrode_fol, error_functions=error_functions) if do_plot_together: plt.figure() plt.plot(elec_data[predicted_electrode][cond], label=predicted_electrode) plt.plot(elec_data[train_electrodes[0]][cond], label=train_electrodes[0]) plt.plot(elec_data[train_electrodes[1]][cond], label=train_electrodes[1]) plt.legend() plt.show() if do_plot_electrodes: plot_electrodes(bipolar, [predicted_electrode] + train_electrodes) if write_errors_csv: csv_file.close() def print_parc_info(parc, predicted_electrode, train_electrodes, k=3): if parc: for electrode, elec_type in zip([predicted_electrode] + train_electrodes, ['pred'] + ['train'] * len(train_electrodes)): elec_parc = sorted([(float(prob), region) for region, prob in parc[electrode].items()], reverse=True) elec_parc_str = ['{}: {:.4f}'.format(region, prob) for prob, region in elec_parc] print('{}: {} in {}'.format(elec_type, electrode, elec_parc_str[:k])) def write_error_line(csv_writer, res_ind, cond, elec_data, predicted_electrode, train_electrodes, ps, meg_data_dic, time_split, dtw_window): time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 meg = combine_meg_chunks(meg_data_dic[predicted_electrode], ps, time_split, time_diff) err_func = partial(electrode_reconstruction_error, electrode=predicted_electrode, meg=meg, electrode_data=elec_data[predicted_electrode][cond], electrode_ps=ps, meg_data_dic=meg_data_dic, time_split=time_split, time_diff=time_diff, dtw_window=dtw_window) errors_strs = ['{:.2f}'.format(err_func(error_calc_method=em)) for em in ERROR_RECONSTRUCT_METHODS] csv_writer.writerows([[res_ind, predicted_electrode, train_electrodes] + errors_strs]) def bps_load_results(cond, uuid, k, optimization_method, electrodes_normalize, electrodes_positive, optimization_params): results_fol = get_results_fol(optimization_method, electrodes_normalize, electrodes_positive) params_suffix = '' if len(optimization_params) == 0 else \ ''.join(sorted(['_{}_{}'.format(param_key, param_val) for param_key, param_val in sorted(optimization_params.items())]) ) results_file = os.path.join(results_fol, 'bps_{}_{}_{}{}.pkl'.format(cond, k, uuid, params_suffix)) if os.path.isfile(results_file): results = utils.load(results_file) else: # Try to collect partial results results = [] partial_files = glob.glob(os.path.join(results_fol, uuid, '*.pkl')) for partial_file in partial_files: results.extend(utils.load(partial_file)) if len(results) == 0: raise Exception('No results found!') return results, results_fol def sort_results(event_id, results, meg_data_dic, electrodes_data, time_split, error_calc_method='RMS', only_accoring_to_pred=True, re_calc_all_errors=False, dtw_window=10, max_error_threshold=np.inf): if not re_calc_all_errors: results_errors = [[res_errors[elec] for elec in train + [pred]] for pred, train, res_errors, _, _ in results if res_errors[pred] < max_error_threshold] if only_accoring_to_pred: errors = [res_errors[pred] for pred, _, res_errors, _, _ in results if res_errors[pred] < max_error_threshold] else: errors = copy.copy(results_errors) sorted_results = [(pred, train, res_errors, ps, params) for pred, train, res_errors, ps, params in results if res_errors[pred] < max_error_threshold] else: sorted_results, errors, results_errors = [], [], [] cond = utils.first_key(event_id) time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 now = time.time() for ind, result in enumerate(results): predicted_electrode, train_electrodes, elecs_errors, electrode_ps, cv_params = open_bps_result(result) if ind % 100 == 0: print('sorting ({}) {}/{} {}'.format(error_calc_method, ind, len(results), time.time() - now)) err = 0 electrodes_errors = [] for electrode in train_electrodes + [predicted_electrode]: elec_err = electrode_reconstruction_error(electrode, electrodes_data[electrode][cond], electrode_ps, meg_data_dic, error_calc_method, time_split, time_diff, dtw_window) electrodes_errors.append(elec_err) if (only_accoring_to_pred and electrode == predicted_electrode) or not only_accoring_to_pred: err += elec_err # Check if the predicted electrode's error is below the threshld if electrodes_errors[-1] < max_error_threshold: sorted_results.append(result) errors.append(err) results_errors.append(electrodes_errors) sorted_results = [res for (err, res) in sorted(zip(errors, sorted_results))] results_errors = [res_err for (err, res_err) in sorted(zip(errors, results_errors))] errors = sorted(errors) return sorted_results, np.array(errors), np.array(results_errors) def open_bps_result(result): if len(result) == 5: predicted_electrode, train_electrodes, elecs_errors, electrode_ps, cv_params = result else: predicted_electrode, train_electrodes, elecs_errors, electrode_ps = result cv_params = [] return predicted_electrode, train_electrodes, elecs_errors, electrode_ps, cv_params def bps_collect_results_gs(uuid, event_id, optimization_method, error_calc_method, from_t, to_t, time_split, gk_sigma=3, bipolar=True, electrodes_positive=False, electrodes_normalize=False, dtw_window=10, do_plot=True, error_threshold=10, write_csv=False, njobs=4): cond = utils.first_key(event_id) time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) results_fol = os.path.join(get_results_fol(optimization_method, electrodes_normalize, electrodes_positive), 'params_grid_search_{}'.format(uuid)) gs_results = defaultdict(lambda : defaultdict(list)) gs_files = glob.glob(os.path.join(results_fol, '*.pkl')) gs_files = sorted(gs_files, key=lambda x:int(utils.namebase(x).split('_')[-3])) alphas = [int(utils.namebase(x).split('_')[-3]) for x in gs_files] if write_csv: csv_file = open(os.path.join(results_fol, 'grid_search.csv'), 'w') csv_writer = csv.writer(csv_file, delimiter=',') csv_writer.writerows([['predicted', 'train'] + alphas]) for result_fname in gs_files: alpha = int(utils.namebase(result_fname).split('_')[-3]) results = utils.load(result_fname) # results, _, prediction_errors = sort_results(event_id, results, meg_data_dic, elec_data, time_split, # error_calc_method=error_calc_method, only_accoring_to_pred=False, calc_all_errors=True, dtw_window=dtw_window) for result in results: predicted_electrode, train_electrodes, _, ps, _ = open_bps_result(result) errors = [] for electrode in [predicted_electrode] + train_electrodes: errors.append(electrode_reconstruction_error(electrode, elec_data[electrode][cond], ps, meg_data_dic, error_calc_method, time_split, time_diff, dtw_window)) gs_results[(predicted_electrode, tuple(train_electrodes))][alpha] = (errors, ps) best_alpha = {} for (predicted_electrode, train_electrodes), pred_results in gs_results.items(): pred_results = utils.sort_dict_by_values(pred_results) pred_errors = [result_info[0][0] for result_info in pred_results.values()] train1_errors = [result_info[0][1] for result_info in pred_results.values()] train2_errors = [result_info[0][2] for result_info in pred_results.values()] # min_train_index = np.argmin(train_errors) # best_alpha[predicted_electrode] = alphas[min_train_index] if write_csv: result_errors_str = ['{:.2f}'.format(err) for err in pred_errors] csv_writer.writerows([[predicted_electrode, train_electrodes] + result_errors_str]) if do_plot and min(pred_errors) < error_threshold: plt.figure() plt.plot(pred_errors, label='pred') plt.plot(train1_errors, label='train1') plt.plot(train2_errors, label='train2') plt.title('{} {}'.format(predicted_electrode, train_electrodes)) plt.legend() plt.show() def plot_reconstruction_for_different_freqs(event_id, electrode, two_electrodes, from_t, to_t, time_split, gk_sigma=3, bipolar=True, electrodes_positive=False, electrodes_normalize=False, njobs=4): cond = utils.first_key(event_id) electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) reconstruct_meg(event_id, [electrode], from_t, to_t, time_split, plot_results=True, all_meg_data=meg_data_dic, elec_data=elec_data, title='{}: {}'.format(cond, electrode)) reconstruct_meg(event_id, two_electrodes, from_t, to_t, time_split, optimization_method='RidgeCV', plot_results=True, all_meg_data=meg_data_dic,elec_data=elec_data, title='{}: {} and {}'.format(cond, two_electrodes[0], two_electrodes[1])) freqs_inds = np.array([2, 6, 9, 10, 11, 15, 16]) plt.plot(elec_data[electrode][cond]) plt.plot(meg_data_dic[electrode][freqs_inds, :].T, '--') plt.legend([electrode] + np.array(CSD_FREQS)[freqs_inds].tolist()) # plt.title('{}: {}'.format(cond, electrode)) plt.show() def plot_predictive_subset(electrodes, event_id, bipolar, from_t, to_t, time_split, elec_data, meg_data_dic, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, njobs=4): for predicted_electrode in electrodes: train_electrodes = [e for e in electrodes if e != predicted_electrode] reconstruct_meg(event_id, train_electrodes, from_t, to_t, time_split, predicted_electrodes=[predicted_electrode], plot_results=True, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) def calc_lead_predictiveness(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, k=3, error_threshold=10, njobs=4): cond = utils.first_key(event_id) electrodes = get_all_electrodes_names(True) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) electrodes = set(meg_data_dic.keys()) elecs_groups = get_lead_groups(k) good_groups = [] for elecs in elecs_groups: group_errors, group_pss, group_predicted = [], [], [] if not np.all([e in electrodes for e in elecs]): continue for predicted_electrode in elecs: train_electrodes = [e for e in elecs if e != predicted_electrode] elec_errors, elec_ps = reconstruct_meg(event_id, train_electrodes, from_t, to_t, time_split, predicted_electrodes=[predicted_electrode], plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) group_errors.append(elec_errors[cond][predicted_electrode]) group_pss.append(elec_ps[cond]) group_predicted.append(predicted_electrode) good_groups.append((elecs, group_errors, group_pss, group_predicted)) if min(group_errors) < error_threshold: print('low error!') print(group_errors, group_predicted) utils.save(good_groups, get_pkl_file('{}_calc_lead_predictiveness_{}.pkl'.format(cond, k))) def plot_lead_predictiveness(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, k=3, error_threshold=10, njobs=4): cond = utils.first_key(event_id) good_groups = utils.load(get_pkl_file('{}_calc_lead_predictiveness_{}.pkl'.format(cond, k))) electrodes = get_all_electrodes_names(True) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) for elecs, group_errors, group_pss, group_predicted in good_groups: if min(group_errors) < error_threshold: best_index = np.argmin(group_errors) best_predictive_elc = group_predicted[best_index] print(best_predictive_elc, group_errors[best_index]) best_ps = group_pss[best_index] plot_leastsq_results(meg_data_dic, cond, elec_data, elecs, best_ps, time_split, same_ps=True, do_plot=True, do_save=False) def get_lead_groups(k, bipolar): electrodes = get_all_electrodes_names(False) groups = defaultdict(list) elecs_groups = [] for electrode in electrodes: elc_group, elc_num = utils.elec_group_number(electrode) groups[elc_group].append(elc_num) for group_name, nums in groups.items(): max_num = max(nums) - k if bipolar else max(nums) - k + 1 for num in range(max_num): if bipolar: elecs_groups.append(['{}{}-{}{}'.format(group_name, num+l+1, group_name, num+l) for l in range(1, k+1)]) else: elecs_groups.append(['{}{}'.format(group_name, num+l) for l in range(1, k+1)]) return elecs_groups def get_inter_lead_combs(inter_leads_groups, bipolar, calc_groups_product): electrodes = get_all_electrodes_names(False) groups = defaultdict(list) elecs_groups = [] for electrode in electrodes: elc_group, elc_num = utils.elec_group_number(electrode) groups[elc_group].append(elc_num) for group in groups.keys(): groups[group] = sorted(groups[group]) for inter_leads_group in inter_leads_groups: if calc_groups_product: electrodes_groups = [['{}{}'.format(group, k) for k in groups[group]] for group in inter_leads_group] elecs_groups.extend(product(*electrodes_groups)) else: max_k = min(map(max, [groups[inter_leads_group[i]] for i in range(len(inter_leads_group))])) for k in range(1, max_k + 1): if not bipolar: elecs_groups.append(['{}{}'.format(group, k) for group in inter_leads_group]) return elecs_groups def analyze_best_predictive_subset(k, event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, njobs=1): cond = utils.first_key(event_id) elec_data = load_electrodes_data(event_id, bipolar, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) electrodes = elec_data.keys() meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs=njobs) min_err = min([float(utils.namebase(fname).split('_')[-1]) for fname in glob.glob(get_pkl_file('best_predictive_subset_{}_*.pkl'.format(k)))]) min_comb, min_err, min_ps = utils.load(get_pkl_file('best_predictive_subset_{}_{}.pkl'.format(k, min_err))) plot_leastsq_results(meg_data_dic, cond, elec_data, min_comb, min_ps, time_split, same_ps=True, do_plot=True, do_save=False) def load_electrodes_data(events_id, bipolar, electrodes=None, from_t=None, to_t=None, subtract_min=False, normalize_data=False, subtract_mean=False): meg_elec_conditions_translate = {'interference':1, 'neutral':0} d = np.load(os.path.join(BLENDER_SUB_FOL, 'electrodes{}_data.npz'.format('_bipolar' if bipolar else ''))) data, names, elecs_conditions = d['data'], d['names'], d['conditions'] elec_data = defaultdict(dict) if from_t is None: from_t = 0 if to_t is None: to_t = -1 if electrodes is None: electrodes = names if subtract_mean: data -= np.mean(data, axis=0) for electrode, elec_data_mat in zip(electrodes, data): for cond in events_id: ind = meg_elec_conditions_translate[cond] data = elec_data_mat[:, ind] if not from_t is None and not to_t is None: data = data[from_t: to_t] if subtract_min: data = data - min(data) if normalize_data: data = data * 1.0/max(data) elec_data[electrode][cond] = data return elec_data def analyze_best_groups(event_id, bipolar, from_t, to_t, time_split, err_threshold=10, groups_panelty=5, only_sig_electrodes=False, electrodes_positive=True, electrodes_normalize=True, gk_sigma=3, njobs=4): cond = utils.first_key(event_id) # electrodes_parc = get_electrodes_parcellation(electrodes, bipolar) elec_data = load_electrodes_data(event_id, bipolar, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, elec_data.keys(), from_t, to_t, gk_sigma, njobs=njobs) print('analyze_best_groups', cond, err_threshold) res_fname = get_pkl_file( '{}_find_best_groups_{}_{}_{}_gp{}.pkl'.format(cond, len(time_split), err_threshold, 'only_sig' if only_sig_electrodes else 'all', groups_panelty)) if os.path.isfile(res_fname): groups, groups_ps, groups_err = utils.load(res_fname) print(len(groups)) print(map(len, groups)) print(map(np.mean, groups_err)) for electrodes_group, electrodes_ps, group_err in zip(groups, groups_ps, groups_err): # for elec in electrodes_group: # print(elec, electrodes_parc[elec]) # print(electrodes_group, group_err) # plot_leastsq_results(meg_data_dic, cond, elec_data, electrodes_group, electrodes_ps, time_split, # same_ps=True, do_plot=True) for elec in electrodes_group: electrodes_train = [e for e in electrodes_group if e!=elec] reconstruct_meg(event_id, electrodes_train, from_t, to_t, time_split, predicted_electrodes=[elec], plot_results=True, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, calc_ridge=False, njobs=njobs) reconstruct_meg(event_id, electrodes_train, from_t, to_t, time_split, predicted_electrodes=[elec], plot_results=True, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, calc_ridge=True, njobs=njobs) # plot_electrodes(bipolar, electrodes_group) else: print('No such file {}'.format(res_fname)) # elecctrodes = get_all_electrodes_names(bipolar) # elec_data = load_all_electrodes(electrodes, positive=True, normalize_data=True) # for cond in events_id: # cum_errors, cum_electrodes, ps_electrodes = utils.load(get_pkl_file('{}_find_best_fit_electrodes_{}.pkl'.format(cond, len(time_split)))) # plt.figure() # plt.plot(cum_errors) # plt.title(cond) # plt.show() # meg_data_dic = load_all_dics(freqs_bin, cond, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) # # plot_leastsq_results(meg_data_dic, cond, elec_data, cum_electrodes, ps_electrodes, time_split) def plot_leastsq_results(meg_data_dic, cond, elec_data, electrodes, electrodes_ps, time_split, optimization_method, predicted_electrodes=[], same_ps=True, do_plot=True, do_save=False, uuid='', title=None, full_title=True, error_functions=(), res_ind=0, save_in_pred_electrode_fol=False, dtw_window=10, tight_layout=True): electrodes = list(electrodes) if uuid=='': uuid = utils.rand_letters(5) time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 # ps_range = range(0, len(time_split) * (len(CSD_FREQS) + 1), len(CSD_FREQS) + 1) if same_ps: electrodes_ps = [electrodes_ps] * (len(electrodes) + len(predicted_electrodes)) pics_num_x, pics_num_y = utils.how_many_subplots(len(electrodes) + len(predicted_electrodes)) f, axs = plt.subplots(pics_num_x, pics_num_y, sharex='col', sharey='row', figsize=(12, 8)) if pics_num_x==1 and pics_num_y==1: axs = [axs] elif pics_num_x>1 and pics_num_y>1: axs = list(itertools.chain(*axs)) if do_save: fig_fol = os.path.join(get_figs_fol(), optimization_method, uuid) utils.make_dir(fig_fol) if len(error_functions)==0: error_functions = ERROR_RECONSTRUCT_METHODS electrode_types = ['training'] * len(electrodes) + ['prediction'] * len(predicted_electrodes) for electrode, electrode_ps, ax, electrode_type in zip(electrodes + predicted_electrodes, electrodes_ps, axs, electrode_types): meg = combine_meg_chunks(meg_data_dic[electrode], electrode_ps, time_split, time_diff) err_func = partial(electrode_reconstruction_error, electrode=electrode, meg=meg, electrode_data=elec_data[electrode][cond], electrode_ps=electrode_ps, meg_data_dic=meg_data_dic, time_split=time_split, time_diff=time_diff, dtw_window=dtw_window) errors = {em:err_func(error_calc_method=em) for em in error_functions} errors_str = ''.join(['{}:{:.2f} '.format(em, errors[em]) for em in error_functions]) ax.plot(elec_data[electrode][cond], label=electrode) ax.plot(meg, label='plsq') # ax.set_ylim([0, 1]) # ax.legend() if title is None: if full_title: ax.set_title('{}) {}: {} '.format(res_ind, electrode_type, electrode) + errors_str) else: ax.set_title('{}: {} ({:.2f})'.format(electrode_type, electrode, err_func(error_calc_method=error_functions[0]))) if not title is None: axs[0].set_title(title) if tight_layout: plt.tight_layout() if do_save: if save_in_pred_electrode_fol == True and len(predicted_electrodes)==1: elec_fol = os.path.join(fig_fol, predicted_electrodes[0]) utils.make_dir(elec_fol) fig_fname = os.path.join(elec_fol, '{}_{}.png'.format(uuid, res_ind)) else: fig_fname = os.path.join(fig_fol, '{}_{}.png'.format(uuid, res_ind)) f.savefig(fig_fname) print('saved to {}'.format(fig_fname)) if do_plot and not do_save: plt.show() else: plt.close() def combine_meg_chunks(electrod_meg_data, electrode_ps, time_split, time_diff): meg = [] for ps, from_t, to_t in zip(electrode_ps, time_split, time_split + time_diff): meg_chunk = meg_pred(ps, electrod_meg_data[:, from_t:to_t]) meg = meg_chunk if len(meg) == 0 else np.hstack((meg, meg_chunk)) return meg def get_all_electrodes_names(bipolar): subject_mri_dir = os.path.join(SUBJECTS_MRI_DIR, MRI_SUBJECT) positions_file_name = 'electrodes{}_positions.npz'.format('_bipolar' if bipolar else '') positions_file_name = os.path.join(subject_mri_dir, 'electrodes', positions_file_name) d = np.load(positions_file_name) names = d['names'] names = [elc.astype(str) for elc in names if not elc in BAD_ELECS] return names def calc_p_for_each_electrode(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, outliers=[], only_sig_electrodes=False, plot_results=False, njobs=4): cond = utils.first_key(event_id) if only_sig_electrodes: sig_electrodes = find_significant_electrodes(event_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) electrodes = sig_electrodes[cond] else: electrodes = get_all_electrodes_names(bipolar) errors, pss = [], [] electrodes = [elc for elc in electrodes if elc not in outliers] elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=True, njobs=njobs) electrodes = meg_data_dic.keys() chunked_electrodes = utils.chunks(electrodes, len(electrodes) / njobs) params = [(chunk_electrodes, bipolar, event_id, from_t, to_t, time_split, gk_sigma, meg_data_dic, elec_data) for chunk_electrodes in chunked_electrodes] results = utils.run_parallel(_calc_p_for_each_electrode_parallel, params, njobs) for chunk_errors, chunk_pss in results: errors.extend(chunk_errors) pss.extend(chunk_pss) utils.save((electrodes, errors, pss), get_pkl_file('{}_calc_p_for_each_electrode.pkl'.format(cond))) def _calc_p_for_each_electrode_parallel(chunked_params): errors, pss = [], [] electrodes, bipolar, event_id, from_t, to_t, time_split, gk_sigma, meg_data_dic, elec_data = chunked_params cond = utils.first_key(event_id) for electrode in electrodes: elec_errors, elec_ps = reconstruct_meg(event_id, [electrode], from_t, to_t, time_split, gk_sigma=gk_sigma, plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) errors.append(elec_errors[cond][electrode]) pss.append(elec_ps[cond]) return errors, pss def analyze_p_for_each_electrode(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, njobs=4): cond = utils.first_key(event_id) electrodes, errors, pss = utils.load(get_pkl_file('{}_calc_p_for_each_electrode.pkl'.format(cond))) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=True, njobs=njobs) # plot_leastsq_results(meg_data_dic, cond, elec_data, electrodes, pss, time_split, # same_ps=False, do_plot=True, do_save=False) X = utils.stack(pss) find_best_n_componenets(X, electrodes, event_id, bipolar, from_t, to_t, time_split, meg_data_dic, elec_data, gk_sigma) # utils.plot_3d_PCA(X) # res, best_gmm, bic = utils.calc_clusters_bic(X) def analyze_best_n_componenets(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, electrodes_positive=True, electrodes_normalize=True, njobs=4): cond = utils.first_key(event_id) all_clusters, all_errors = utils.load(get_pkl_file('{}_best_n_componenets'.format(cond))) electrodes, errors, pss = utils.load(get_pkl_file('{}_calc_p_for_each_electrode.pkl'.format(cond))) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=True, njobs=njobs) X = utils.stack(pss) for k, cluster_error in enumerate(all_errors): print(k, cluster_error) plt.plot(all_errors) plt.show() gmm = mixture.GMM(n_components=22, covariance_type='spherical') gmm.fit(X) clusters = gmm.predict(X) unique_clusters = np.unique(clusters) cluster_errors = [] for cluster in unique_clusters: cluster_electrodes = np.array(electrodes)[np.where(clusters == cluster)].tolist() elec_errors, elec_ps = reconstruct_meg(event_id, cluster_electrodes, from_t, to_t, time_split, gk_sigma=gk_sigma, plot_results=True, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) cluster_errors.append(max(elec_errors[cond].values())) def find_best_n_componenets(X, electrodes, event_id, bipolar, from_t, to_t, time_split, meg_data_dic, elec_data, gk_sigma): cond = utils.first_key(event_id) all_errors = [] all_clusters = [] for n_components in range(1, X.shape[0]): gmm = mixture.GMM(n_components=n_components, covariance_type='spherical') gmm.fit(X) clusters = gmm.predict(X) unique_clusters = np.unique(clusters) cluster_errors = [] for cluster in unique_clusters: cluster_electrodes = np.array(electrodes)[np.where(clusters == cluster)].tolist() elec_errors, elec_ps = reconstruct_meg(event_id, cluster_electrodes, from_t, to_t, time_split, gk_sigma=gk_sigma, plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) cluster_errors.append(max(elec_errors[cond].values())) print(n_components, max(cluster_errors)) all_clusters.append(clusters) all_errors.append(max(cluster_errors)) utils.save((all_clusters, all_errors), get_pkl_file('{}_best_n_componenets'.format(cond))) plt.plot(all_errors) plt.show() def find_best_subset(event_id, k, bipolar, from_t, to_t, time_split, gk_sigma=3, only_first_subset=False, electrodes_positive=True, electrodes_normalize=True, outliers=[], split_to_learn_and_pred=False, only_sig_electrodes=False, plot_results=False, njobs=4): cond = utils.first_key(event_id) if only_sig_electrodes: sig_electrodes = find_significant_electrodes(event_id, bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) electrodes = sig_electrodes[cond] else: electrodes = get_all_electrodes_names(bipolar) electrodes = [elc for elc in electrodes if elc not in outliers] if k==0: k = len(electrodes) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) run, min_error = 0, np.inf cutof = 1 # int(len(used_electrodes)/2) while(True): if only_first_subset: used_electrodes_sets = [np.random.choice(electrodes, k, replace=False).tolist()] else: used_electrodes_sets = utils.find_subsets(electrodes, k) errors = 0 elec_pss = [] for used_electrodes in used_electrodes_sets: if split_to_learn_and_pred: predicted_electrodes = np.random.choice(used_electrodes, cutof, replace=False).tolist() learned_electrodes = [elc for elc in used_electrodes if elc not in predicted_electrodes] elec_errors, elec_ps = reconstruct_meg(event_id, learned_electrodes, from_t, to_t, time_split, gk_sigma=gk_sigma, plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, predicted_electrodes=predicted_electrodes, njobs=njobs) else: elec_errors, elec_ps = reconstruct_meg(event_id, used_electrodes, from_t, to_t, time_split, gk_sigma=gk_sigma, plot_results=plot_results, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=njobs) errors += sum([elec_errors[cond][electrode] for electrode in used_electrodes]) elec_pss.append(elec_ps[cond]) if errors < min_error: min_error = errors print('new min was found! {}'.format(errors)) utils.save((used_electrodes_sets, errors, elec_pss), best_subset_fname( cond, k, time_split, split_to_learn_and_pred, cutof, only_sig_electrodes)) run += 1 print(cond, run, k, '' if only_first_subset else 's', min_error) def analyze_best_subset(event_id, k, bipolar, from_t, to_t, time_split, gk_sigma=3, only_sig_electrodes=True, split_to_learn_and_pred=True, only_first_subset=False, electrodes_positive=True, electrodes_normalize=True, plot_locations=False, do_plot=True, njobs=3): cond = utils.first_key(event_id) electrodes_sets, electrodes_errors, electrodes_pss = utils.load(best_subset_fname( cond, k, time_split, split_to_learn_and_pred, only_sig_electrodes, 1)) electrodes = utils.flat_list_of_lists(electrodes_sets) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=electrodes_positive, normalize_data=electrodes_normalize) meg_data_dic = load_all_dics(freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) for electrodes_set, electrodes_ps in zip(electrodes_sets, electrodes_pss): if plot_locations: plot_electrodes(bipolar, electrodes_set) # plot_leastsq_results(meg_data_dic, cond, elec_data, electrodes_set, electrodes_ps, time_split, same_ps=True, do_plot=do_plot) def plot_electrodes(bipolar, electrodes=None): names, pos, _ = get_electrodes_positions(MRI_SUBJECT, bipolar) utils.plot_3d_scatter(pos, names.tolist(), electrodes) def best_subset_fname(cond, k, time_split, split_to_learn_and_pred=False, only_sig_electrodes=False, cutof=0, only_first_subset=False, electrodes_normalize=True, electrodes_positive=True): if electrodes_normalize and electrodes_positive: fname = get_pkl_file('{}_{}{}_find_best_subset_{}{}{}{}.pkl'.format( cond, k, '' if only_first_subset else 's', len(time_split), '_split' if split_to_learn_and_pred else '', '_{}'.format(cutof) if cutof > 0 else '', '' if only_sig_electrodes else '_all_elecs')) else: fname = get_pkl_file('{}_{}{}_find_best_subset_{}{}{}{}_norm_{}_positive_{}.pkl'.format( cond, k, '' if only_first_subset else 's', len(time_split), '_split_' if split_to_learn_and_pred else '', '_{}'.format(cutof) if cutof > 0 else '', '' if only_sig_electrodes else '_all_elecs', int(electrodes_normalize), int(electrodes_positive))) return fname def get_pkl_file(fname): return os.path.join(utils.get_files_fol(), fname) def find_best_groups_parralel(params): event_id, bipolar, from_t, to_t, time_split, err_threshold, groups_panelty = params random.seed(utils.rand_letters(5)) find_best_groups(event_id, bipolar, from_t, to_t, time_split, err_threshold=err_threshold, groups_panelty=groups_panelty, only_sig_electrodes=False, electrodes_positive=True, electrodes_normalize=True, gk_sigma=3, njobs=1) def calc_noise_epoches_from_empty_room(events_id, data_raw_fname, empty_room_raw_fname, from_t, to_t, overwrite_epochs=False): from mne.event import make_fixed_length_events from mne.io import Raw epochs_noise_dic = {} epochs_noise_fnames = [get_cond_fname(EPO_NOISE, event) for event in events_id.keys()] if np.all([os.path.isfile(fname) for fname in epochs_noise_fnames]) and not overwrite_epochs: for event in events_id.keys(): epochs_noise_dic[event] = mne.read_epochs(get_cond_fname(EPO_NOISE, event)) else: raw = Raw(data_raw_fname) raw_noise = Raw(empty_room_raw_fname) # raw_noise.info['bads'] = ['MEG0321'] # 1 bad MEG channel picks = mne.pick_types(raw.info, meg=True)#, exclude='bads') events_noise = make_fixed_length_events(raw_noise, 1) epochs_noise = mne.Epochs(raw_noise, events_noise, 1, from_t, to_t, proj=True, picks=picks, baseline=None, preload=True) for event, event_id in events_id.items(): # then make sure the number of epochs is the same epochs = mne.read_epochs(get_cond_fname(EPO, event)) epochs_noise_dic[event] = epochs_noise[:len(epochs.events)] epochs_noise_dic[event].save(get_cond_fname(EPO_NOISE, event)) return epochs_noise_dic def calc_empty_room_noise_csd(events_id, epochs_from_t, epochs_to_t, freq_bins, win_lengths, overwrite_csds=False, overwrite_epochs=False): noise_csds = defaultdict(list) epochs_noise = calc_noise_epoches_from_empty_room(events_id, RAW, RAW_NOISE, epochs_from_t, epochs_to_t, overwrite_epochs=overwrite_epochs) for event in events_id.keys(): if not os.path.isfile(get_cond_fname(NOISE_CSD_EMPTY_ROOM, event)) or overwrite_csds: for freq_bin, win_length in zip(freq_bins, win_lengths): noise_csd = compute_epochs_csd(epochs_noise[event], mode='multitaper', #mode='fourier', fmin=freq_bin[0], fmax=freq_bin[1], fsum=True, tmin=-win_length, tmax=0.0, n_fft=None) noise_csds[event].append(noise_csd) print('saving csd to {}'.format(get_cond_fname(NOISE_CSD_EMPTY_ROOM, event))) utils.save(noise_csds[event], get_cond_fname(NOISE_CSD_EMPTY_ROOM, event)) else: noise_csds[event] = utils.load(get_cond_fname(NOISE_CSD_EMPTY_ROOM, event)) print(tuple(c.data.shape for c in noise_csds[event])) return noise_csds def calc_td_dics(events_id, bipolar, epochs_from_t, epochs_to_t, csd_from_t, csd_to_t, tstep, freq_bins, win_lengths, subtract_evoked=False, overwrite_epochs=False, overwrite_csds=False): region = 'bipolar_electrodes' if bipolar else 'regular_electrodes' noise_csds = calc_empty_room_noise_csd(events_id, epochs_from_t, epochs_to_t, freq_bins=freq_bins, win_lengths=win_lengths, overwrite_csds=overwrite_csds, overwrite_epochs=overwrite_epochs) stcs = {} data_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'all_{}'.format(region)) utils.make_dir(data_fol) for event in events_id.keys(): forward = mne.read_forward_solution(get_cond_fname(FWD_X, event, region=region), surf_ori=True) epochs = mne.read_epochs(get_cond_fname(EPO, event)) stcs[event] = tf.tf_dics(event, epochs, forward, noise_csds[event], csd_from_t, csd_to_t, tstep, win_lengths, freq_bins=freq_bins, subtract_evoked=subtract_evoked, mode='multitaper', reg=0.001, subject=MRI_SUBJECT, data_fol=data_fol, overwrite_csds=False, overwrite_dics_sp=False, overwrite_stc=True) def plot_td_dics(events_id, bipolar, tmin_plot, tmax_plot, freq_bins): from mne.viz import plot_source_spectrogram from mne.source_estimate import _make_stc # Plotting source spectrogram for source with maximum activity # Note that tmin and tmax are set to display a time range that is smaller than # the one for which beamforming estimates were calculated. This ensures that # all time bins shown are a result of smoothing across an identical number of # time windows. region = 'bipolar_electrodes' if bipolar else 'regular_electrodes' data_fol = os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'all_{}'.format(region)) stcs = [] for event in events_id.keys(): stcs.append(tf.load_stcs(event, freq_bins, data_fol, MRI_SUBJECT)) # plot_source_spectrogram(stcs[-1], freq_bins, tmin=tmin_plot, tmax=tmax_plot, # source_index=None, colorbar=True) stcs_diff = [] for stc1, stc2 in zip(stcs[0], stcs[1]): stc_diff = _make_stc(stc1.data - stc2.data, vertices=stc1.vertices, tmin=stc1.tmin, tstep=stc1.tstep, subject=stc1.subject) stcs_diff.append(stc_diff) plot_source_spectrogram(stcs_diff, freq_bins, tmin=tmin_plot, tmax=tmax_plot, source_index=None, colorbar=True) def find_best_freqs_subset(event_id, bipolar, freqs_bins, from_t, to_t, time_split, combs, optimization_method='RidgeCV', optimization_params={}, k=3, gk_sigma=3, njobs=6): freqs_bins = sorted(freqs_bins) all_electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(event_id, bipolar, all_electrodes, from_t, to_t, subtract_min=False, normalize_data=False) meg_data_dic = load_all_dics(freqs_bins, event_id, bipolar, all_electrodes, from_t, to_t, gk_sigma, dont_calc_new_csd=True, njobs=njobs) uuid = utils.rand_letters(5) results_fol = get_results_fol(optimization_method) partial_results_fol = os.path.join(results_fol, 'best_freqs_subset_{}'.format(uuid)) utils.make_dir(results_fol) utils.make_dir(partial_results_fol) cond = utils.first_key(event_id) all_freqs_bins_subsets = list(utils.superset(freqs_bins)) random.shuffle(all_freqs_bins_subsets) N = len(all_freqs_bins_subsets) print('There are {} freqs subsets'.format(N)) all_freqs_bins_subsets_chunks = utils.chunks(all_freqs_bins_subsets, int(len(all_freqs_bins_subsets) / njobs)) params = [Bunch(event_id=event_id, bipolar=bipolar, freqs_bins_chunks=freqs_bins_subsets_chunk, cond=cond, from_t=from_t, to_t=to_t, freqs_bins=freqs_bins, partial_results_fol=partial_results_fol, time_split=time_split, only_sig_electrodes=False, only_from_same_lead=True, electrodes_positive=False, electrodes_normalize=False, gk_sigma=gk_sigma, k=k, do_plot_results=False, do_save_partial_results=False, optimization_params=optimization_params, check_only_pred_score=True, njobs=1, N=int(N / njobs), elec_data=elec_data, meg_data_dic=meg_data_dic, all_electrodes=all_electrodes, optimization_method=optimization_method, error_calc_method='rol_corr', error_threshold=30, combs=combs) for freqs_bins_subsets_chunk in all_freqs_bins_subsets_chunks] results = utils.run_parallel(_find_best_freqs_subset_parallel, params, njobs) all_results = [] for chunk_results in results: all_results.extend(chunk_results) params_suffix = utils.params_suffix(optimization_params) output_file = os.path.join(results_fol, 'best_freqs_subset_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) print('saving results to {}'.format(output_file)) utils.save((chunk_results, freqs_bins), output_file) def _find_best_freqs_subset_parallel(p): chunk_results = [] uuid = utils.rand_letters(5) output_file = os.path.join(p.partial_results_fol, 'best_freqs_subset_{}_{}.pkl'.format(p.cond, uuid)) now = time.time() for run, freqs_bin in enumerate(p.freqs_bins_chunks): freqs_indices = [p.freqs_bins.index(fb) for fb in freqs_bin] meg_data_dic = get_sub_meg_data_dic(p.meg_data_dic, freqs_indices) if run % 10 == 0 and len(chunk_results) > 0: utils.time_to_go(now, run, p.N) utils.save((chunk_results, p.freqs_bins), output_file) results = find_best_predictive_subset(event_id=p.event_id, bipolar=p.bipolar, freqs_bins=freqs_bin, from_t=p.from_t, to_t=p.to_t, time_split=p.time_split, only_sig_electrodes=False, only_from_same_lead=True, electrodes_positive=False, electrodes_normalize=False, gk_sigma=p.gk_sigma, k=p.k, do_plot_results=False, do_save_partial_results=False, optimization_params=p.optimization_params, check_only_pred_score=True, njobs=1, vebrose=False, uuid_len=5, optimization_method=p.optimization_method, error_calc_method='rol_corr', error_threshold=100, combs=p.combs, save_results=False, elec_data=p.elec_data, meg_data_dic=meg_data_dic, all_electrodes=p.all_electrodes) results = [result + (freqs_bin, ) for result in results] chunk_results.extend(results) utils.save((chunk_results, p.freqs_bins), output_file) return chunk_results def get_sub_meg_data_dic(orig_meg_data_dic, freqs_indices): meg_data_dic = copy.deepcopy(orig_meg_data_dic) for elec in meg_data_dic.keys(): meg_data_dic[elec] = meg_data_dic[elec][freqs_indices, :] return meg_data_dic def pickup_freqs_subsets(event_id, uuid, optimization_method, optimization_params, k=3): results_fol = get_results_fol(optimization_method) partial_results_fol = os.path.join(results_fol, 'best_freqs_subset_{}'.format(uuid)) cond = utils.first_key(event_id) all_results = [] freqs_bins = set() for results_fname in glob.glob(os.path.join(partial_results_fol, 'best_freqs_subset*.pkl')): try: results, freqs_bins = utils.load(results_fname) except: results = utils.load(results_fname) for result in results: freqs_bins |= set([f for f in result[-1]]) all_results.extend(results) params_suffix = utils.params_suffix(optimization_params) output_file = os.path.join(results_fol, 'best_freqs_subset_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) freqs_bins = sorted([(fmin, fmax) for fmin, fmax in freqs_bins]) print('saving {} results to {}'.format(len(all_results), output_file)) print('freqs_bins = {}'.format(freqs_bins)) utils.save((all_results, freqs_bins), output_file) def get_results_fol(optimization_method, electrodes_normalize=False, electrodes_positive=False): return os.path.join(utils.get_files_fol(), '{}_norm_{}_positive_{}'.format( optimization_method, str(electrodes_normalize)[0], str(electrodes_positive)[0])) def load_best_freqs_subset(event_id, uuid, optimization_method, optimization_params, from_t, to_t, time_split, k=3, bipolar=False, gk_sigma=3, do_plot=False, verbose=False, recalculate=False, write_errors_csv=False, new_optimization_method='', new_error_calc_method='', resort=False, top_k=np.inf, top_err=np.inf, group_by_predicted=True, best_k_in_group=1, save_plots_in_pred_electrode_fol=False, do_save_plots=False, njobs=4): results_fol = get_results_fol(optimization_method) cond = utils.first_key(event_id) params_suffix = utils.params_suffix(optimization_params) results_fname = os.path.join(results_fol, 'best_freqs_subset_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) print('loading {}'.format(results_fname)) results, all_freqs_bins = utils.load(results_fname) print('{} results for {}'.format(len(results), all_freqs_bins)) all_best_freqs_bins = set() if do_plot or recalculate or resort: electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(event_id, bipolar, electrodes, from_t, to_t, subtract_min=False, normalize_data=False) if top_err < np.inf: results = sorted(results, key=lambda x:x[2][x[0]]) errors = np.array([x[2][x[0]] for x in results]) top_ind = np.where(errors > top_err)[0][0] print('top ind for max_err {} is {}'.format(top_err, top_ind)) results = results[:top_ind] if group_by_predicted: only_best_results = [] results = sorted(results, key=lambda x:x[0]) for predicted, pred_results in groupby(results, lambda x:x[0]): pred_results = sorted(pred_results, key=lambda x:x[2][x[0]]) # pred_results = sorted(pred_results, key=lambda x:len(x[-1])) only_best_results.extend(pred_results[:best_k_in_group]) results = only_best_results print('new results len after group by: {}'.format(len(results))) if resort: results = freqs_subsets_sort_results(event_id, results, electrodes, elec_data, all_freqs_bins, gk_sigma, from_t, to_t, time_split, new_optimization_method, new_error_calc_method, optimization_params, bipolar, njobs) else: results = sorted(results, key=lambda x:x[2][x[0]]) if write_errors_csv: csv_file = open(os.path.join(results_fol, 'freqs_subsets_{}_{}.csv'.format(k, uuid)), 'w') csv_writer = csv.writer(csv_file, delimiter=',') csv_writer.writerows([['index', 'predicted', 'train', 'predicted RMS', '', '', 'freqs bins']]) for res_ind, best_result in enumerate(results): predicted, train, errors, best_ps, params, best_freqs_bin = best_result for fb in best_freqs_bin: all_best_freqs_bins.add(fb) if do_plot or recalculate: meg_data_dic = load_all_dics(best_freqs_bin, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) if verbose: print('best result:') print(best_freqs_bin) errors_str = ','.join(['{:.2f}'.format(errors[elec] if recalculate else errors[elec]) for elec in train + [predicted]]) print('{}->{}: {}'.format(train, predicted, errors_str)) if recalculate: errors, ps, params = reconstruct_meg(event_id, best_freqs_bin, train, from_t, to_t, time_split, optimization_method=new_optimization_method, error_calc_method=new_error_calc_method, optimization_params=optimization_params, predicted_electrodes=[predicted], plot_results=do_plot, bipolar=bipolar, dont_calc_new_csd=True, res_ind=res_ind, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1, uuid=uuid, save_plots_in_pred_electrode_fol=save_plots_in_pred_electrode_fol, do_save_plots=do_save_plots) print(ps[cond]) elif do_plot and not do_save_plots: plot_leastsq_results(meg_data_dic, cond, elec_data, train, best_ps, time_split, optimization_method, predicted_electrodes=[predicted],save_in_pred_electrode_fol=save_plots_in_pred_electrode_fol, do_save=do_save_plots) if write_errors_csv: errors_strs = ['{:.2f}'.format(errors[cond][elec]) for elec in [predicted] + train] csv_writer.writerows([[res_ind, predicted, train] + errors_strs + list(best_freqs_bin)]) if res_ind == top_k: break print('freqs_bins in all {} results:'.format(top_k)) print(sorted(all_best_freqs_bins)) print('not used freqs: {}'.format(set(all_freqs_bins) - all_best_freqs_bins)) if write_errors_csv: csv_file.close() def print_result(errors, train, predicted, cond=None): err = errors[list(errors)[0]] if predicted not in errors else errors errors_str = ','.join(['{:.2f}'.format(err[elec]) for elec in train + [predicted]]) print('{}{}->{}: {}'.format('' if cond is None else '{} '.format(cond), train, predicted, errors_str)) def cut_best_freqs_subset(events_id, uuid, optimization_method, optimization_params, k, max_err): results_fol = get_results_fol(optimization_method) params_suffix = utils.params_suffix(optimization_params) for cond in events_id.keys(): results_fname = os.path.join(results_fol, 'best_freqs_subset_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) if os.path.isfile(results_fname): print('loading {}'.format(results_fname)) results, all_freqs_bins = utils.load(results_fname) print('{} results for {}'.format(len(results), all_freqs_bins)) results = sorted(results, key=lambda x:x[2][x[0]]) errors = np.array([x[2][x[0]] for x in results]) top_ind = np.where(errors > max_err)[0][0] print('top ind for max_err {} is {}'.format(max_err, top_ind)) results = results[:top_ind] utils.save((results, all_freqs_bins), results_fname) else: print("{} does't exist!".format(results_fname)) def best_freqs_subset_cv(events_id, uuid, all_freqs_bins, optimization_method, error_calc_method, optimization_params, from_t, to_t, time_split, k=3, bipolar=False, gk_sigma=3, max_err=30, verbose=False, njobs=4): results_fol = get_results_fol(optimization_method) params_suffix = utils.params_suffix(optimization_params) electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(events_id, bipolar, electrodes, from_t, to_t) all_meg_data_dic = {} for cond in events_id.keys(): event_id = {cond: events_id[cond]} all_meg_data_dic[cond] = load_all_dics(all_freqs_bins, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) for cond in events_id.keys(): event_id = {cond: events_id[cond]} other_cond = [c for c in events_id.keys() if c !=cond][0] other_event_id = {other_cond: events_id[other_cond]} results_fname = os.path.join(results_fol, 'best_freqs_subset_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) if not os.path.isfile(results_fname): print("{} doesn't exist!".format(results_fname)) continue print('loading {}'.format(results_fname)) results, all_freqs_bins = utils.load(results_fname) print('{} results for {}'.format(len(results), all_freqs_bins)) results = sorted(results, key=lambda x:x[2][x[0]]) errors = np.array([x[2][x[0]] for x in results]) top_inds = np.where(errors > max_err)[0] if len(top_inds) > 0: results = results[:top_inds[0]] print('top ind for max_err {} is {}'.format(max_err, len(results))) all_pred_results = [] results = sorted(results, key=lambda x:x[0]) for predicted, pred_results in groupby(results, lambda x:x[0]): pred_results = sorted(pred_results, key=lambda x:x[2][x[0]]) all_pred_results.append(pred_results) all_pred_results_chunks = utils.chunks(all_pred_results, int(len(all_pred_results) / njobs)) params = [(pred_results_chunk, all_meg_data_dic, elec_data, all_freqs_bins, cond, other_cond, other_event_id, from_t, to_t, time_split, optimization_method, error_calc_method, optimization_params, bipolar, max_err, uuid, verbose) for pred_results_chunk in all_pred_results_chunks] all_results = utils.run_parallel(_best_freqs_subset_cv_parallel, params, njobs) all_good_results = defaultdict(list) for good_results in all_results: all_good_results[cond].extend(good_results[cond]) results_fol = get_results_fol(optimization_method) params_suffix = utils.params_suffix(optimization_params) good_results_fname = os.path.join(results_fol, 'best_freqs_subset_cv_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) utils.save(good_results, good_results_fname) print('{} good result for max_err {}'.format(len(all_good_results[cond]), max_err)) def _best_freqs_subset_cv_parallel(params): pred_results_chunk, all_meg_data_dic, elec_data, all_freqs_bins, cond, other_cond, other_event_id, from_t, to_t, time_split,\ optimization_method, error_calc_method, optimization_params, bipolar, max_err, uuid, verbose = params good_results = defaultdict(list) for pred_results in pred_results_chunk: for result in pred_results: predicted, train, errors, ps, params, freqs_bins = result freqs_indices = [all_freqs_bins.index(fb) for fb in freqs_bins] other_meg_data_dic = get_sub_meg_data_dic(all_meg_data_dic[other_cond], freqs_indices) other_errors, other_ps, other_params = reconstruct_meg(other_event_id, freqs_bins, train, from_t, to_t, time_split, optimization_method=optimization_method, error_calc_method=error_calc_method, optimization_params=optimization_params, predicted_electrodes=[predicted], plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=other_meg_data_dic, elec_data=elec_data, njobs=1, uuid=uuid) if other_errors[other_cond][predicted] < max_err and errors[predicted] < max_err: if verbose: print('results for {}'.format(sorted(freqs_bins))) print_result(errors, train, predicted, cond) print_result(other_errors, train, predicted, other_cond) good_results[cond].append((result, other_errors[other_cond], other_ps[other_cond], other_params)) return good_results def load_best_freqs_subset_cv(events_id, uuid, all_freqs_bins, optimization_method, optimization_params, from_t, to_t, time_split, gk_sigma, max_err, k=3, bipolar=False, njobs=4): results_fol = get_results_fol(optimization_method) params_suffix = utils.params_suffix(optimization_params) electrodes = get_all_electrodes_names(bipolar) elec_data = load_electrodes_data(events_id, bipolar, electrodes, from_t, to_t) all_meg_data_dic = {} for cond in events_id.keys(): event_id = {cond: events_id[cond]} all_meg_data_dic[cond] = load_all_dics(all_freqs_bins, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) for cond in events_id.keys(): other_cond = [c for c in events_id.keys() if c !=cond][0] results_fname = os.path.join(results_fol, 'best_freqs_subset_cv_{}_{}_{}{}.pkl'.format(cond, uuid, k, params_suffix)) if not os.path.isfile(results_fname): continue results = utils.load(results_fname)[cond] print('{} good result for max_err {}'.format(len(results), max_err)) results = sorted(results, key=lambda x:x[0][0]) for predicted, pred_results in groupby(results, lambda x:x[0][0]): pred_results = list(pred_results) print('{}: {} good results'.format(predicted, len(pred_results))) pred_results = sort_both_conds(pred_results, other_cond) (predicted, train, best_errors, best_ps, best_params, best_freqs_bins),\ best_other_errors, other_best_ps, _ = pred_results[0] print(best_freqs_bins) print_result(best_errors, train, predicted, cond) print_result(best_other_errors, train, predicted, other_cond) freqs_indices = [all_freqs_bins.index(fb) for fb in best_freqs_bins] cond_meg_data_dic = get_sub_meg_data_dic(all_meg_data_dic[cond], freqs_indices) other_meg_data_dic = get_sub_meg_data_dic(all_meg_data_dic[other_cond], freqs_indices) plot_both_cond_prediction(predicted, cond_meg_data_dic, other_meg_data_dic, cond, other_cond, elec_data, best_ps, other_best_ps[other_cond], time_split, optimization_method) def sort_both_conds(pred_results, other_cond): both_errors = [] for result, other_errors, other_ps, other_params in pred_results: predicted, train, errors, ps, params, freqs_bins = result cond_error = errors[predicted] other_cond_error = other_errors[predicted] both_errors.append(cond_error * other_cond_error) return [res for err, res in sorted(zip(both_errors, pred_results))] def plot_both_cond_prediction(predicted_electrode, cond_meg_data_dic, other_meg_data_dic, cond, other_cond, elec_data, cond_ps, other_cond_ps, time_split, optimization_method, do_plot=True, do_save=False, uuid='', title=None, full_title=True, error_functions=(), res_ind=0, save_in_pred_electrode_fol=False, dtw_window=10, tight_layout=True): if uuid=='': uuid = utils.rand_letters(5) time_diff = np.diff(time_split)[0] if len(time_split) > 1 else 500 pics_num_x, pics_num_y = utils.how_many_subplots(2) f, axs = plt.subplots(pics_num_x, pics_num_y, sharex='col', sharey='row', figsize=(12, 8)) if do_save: fig_fol = os.path.join(get_figs_fol(), optimization_method, uuid) utils.make_dir(fig_fol) if len(error_functions) == 0: error_functions = ERROR_RECONSTRUCT_METHODS electrodes_data = [elec_data[predicted_electrode][cond], elec_data[predicted_electrode][other_cond]] meg_data_dics = [cond_meg_data_dic, other_meg_data_dic] conds = [cond, other_cond] pss = [cond_ps, other_cond_ps] for electrode_data, meg_data_dic, current_cond, ps, ax in zip(electrodes_data, meg_data_dics, conds, pss, axs): meg = combine_meg_chunks(meg_data_dic[predicted_electrode], ps, time_split, time_diff) err_func = partial(electrode_reconstruction_error, electrode=predicted_electrode, meg=meg, electrode_data=electrode_data, electrode_ps=ps, meg_data_dic=meg_data_dic, time_split=time_split, time_diff=time_diff, dtw_window=dtw_window) errors = {em:err_func(error_calc_method=em) for em in error_functions} errors_str = ''.join(['{}:{:.2f} '.format(em, errors[em]) for em in error_functions]) ax.plot(elec_data[predicted_electrode][current_cond], label=predicted_electrode) ax.plot(meg, label='plsq') if title is None: ax.set_title('{}: {} '.format(current_cond, predicted_electrode) + errors_str) if not title is None: axs[0].set_title(title) if tight_layout: plt.tight_layout() if do_save: if save_in_pred_electrode_fol: elec_fol = os.path.join(fig_fol, predicted_electrode) utils.make_dir(elec_fol) fig_fname = os.path.join(elec_fol, '{}_{}.png'.format(uuid, res_ind)) else: fig_fname = os.path.join(fig_fol, '{}_{}.png'.format(uuid, res_ind)) f.savefig(fig_fname) print('saved to {}'.format(fig_fname)) if do_plot and not do_save: plt.show() else: plt.close() def freqs_subsets_sort_results(event_id, results, electrodes, elec_data, all_freqs_bins, gk_sigma, from_t, to_t, time_split, new_optimization_method, new_error_calc_method, optimization_params, bipolar, njobs=4): cond = utils.first_key(event_id) orig_meg_data_dic = load_all_dics(all_freqs_bins, event_id, bipolar, electrodes, from_t, to_t, gk_sigma, njobs) results_errors = [] for result in tqdm(results): predicted, train, errors, ps, params, freqs_bin = result freqs_indices = [all_freqs_bins.index(fb) for fb in freqs_bin] meg_data_dic = get_sub_meg_data_dic(orig_meg_data_dic, freqs_indices) errors, ps, params = reconstruct_meg(event_id, freqs_bin, train, from_t, to_t, time_split, optimization_method=new_optimization_method, error_calc_method=new_error_calc_method, optimization_params=optimization_params, predicted_electrodes=[predicted], plot_results=False, bipolar=bipolar, dont_calc_new_csd=True, all_meg_data=meg_data_dic, elec_data=elec_data, njobs=1) results_errors.append(errors[cond][predicted]) return sorted(res for (err, res) in zip(results, results_errors)) def fix_names(bipolar): #laf = lof #lmt = lpt dics = glob.glob(os.path.join(SUBJECT_MEG_FOL, 'subcorticals', 'dics', 'bipolar' if bipolar else 'regular', 'dics_*.npy')) for dic_fname in dics: if 'LAF' in dic_fname: new_dic_fname = dic_fname.replace('LAF', 'LOF') os.rename(dic_fname, new_dic_fname) elif 'LMT' in dic_fname: new_dic_fname = dic_fname.replace('MTF', 'LPT') os.rename(dic_fname, new_dic_fname) def main(): from_t, to_t = 500, 1000# -500, 2000 bipolar = False gk_sigma = 3 njobs = int(utils.how_many_cores() / 2) logging.basicConfig(filename='errors.log',level=logging.ERROR) random.seed(datetime.now()) events_id = dict(neutral=2, interference=1) # electrode = 'LAT3-LAT2' if bipolar else 'LAT3' # electrodes = ['LAT3-LAT2', 'LPT2-LPT1', 'LAT2-LAT1', 'LAT4-LAT3'] # predicted_electrodes = []#['LAT2-LAT1', 'LAT4-LAT3'] use_fwd_for_region = False sub_corticals_codes_file = os.path.join(BLENDER_ROOT_DIR, 'sub_cortical_codes.txt') time_split = np.arange(0, 500, 100) # np.arange(0, 500, 500) CSD_FREQS_NO_INF = [(0, 4), (1, 3), (2, 6), (3, 5), (4,8), (6,10), (8,12), (10, 14), (12, 16), (12, 25), (25, 40), (40, 100), (60, 100), (80, 120), (100, 140)]#, (80, np.inf), (0, np.inf)] CSD_FREQS_INF = [(80, np.inf), (0, np.inf)] CSD_FREQS = CSD_FREQS_NO_INF + CSD_FREQS_INF CSD_FREQS_DALAL = [(4, 8), (8, 12), (12, 30), (30, 55), (65, 300)] # Hz # find_fit(events_id, bipolar, from_t, to_t, time_split, err_threshold=10, plot_results=False) # analyze_leastsq_results(events_id, time_split) cond = utils.first_key(events_id)# list(events_id)[0] # neutral = 'neutral' event_id = {cond: events_id[cond]} # nice_combs = [['RMF5-RMF4', 'RMF4-RMF3', 'RMF6-RMF5'],['RMT4-RMT3','RMT3-RMT2','RMT5-RMT4'],['RAT6-RAT5','RAT5-RAT4', 'RAT7-RAT6'], ['RMF6-RMF5','RMF5-RMF4', 'RMF7-RMF6']] # groups_panelty = 1 only_sig_electrodes = False # params = [({cond: events_id[cond]}, bipolar, from_t, to_t, time_split, err_threshold, groups_panelty) for # (cond, err_threshold) in product(events_id.keys(), [3,5,7,10])] # utils.run_parallel(find_best_groups_parralel, params, 8) # find_best_groups(event_id, bipolar, from_t, to_t, time_split, err_threshold=10, groups_panelty=1, only_sig_electrodes=False, njobs=njobs) # analyze_best_groups(event_id, bipolar, from_t, to_t, time_split, err_threshold=5, groups_panelty=1) optimization_params={'window':30, 'alpha':5} freqs_bin = CSD_FREQS_DALAL find_bps = partial(find_best_predictive_subset, event_id=event_id, bipolar=bipolar, freqs_bins=freqs_bin, from_t=from_t, to_t=to_t, time_split=time_split, only_sig_electrodes=False, only_from_same_lead=True, electrodes_positive=False, electrodes_normalize=False, electrodes_subtract_mean=False, gk_sigma=3, k=3, do_plot_results=False, do_save_partial_results=False, optimization_params=optimization_params, check_only_pred_score=True, njobs=4) bps_collect_results = partial(best_predictive_subset_collect_results, event_id=event_id, bipolar=bipolar, freqs_bin=freqs_bin, from_t=from_t, to_t=to_t, time_split=time_split, sort_only_accoring_to_pred=True, calc_all_errors=False, dtw_window=10, electrodes_positive=False, electrodes_normalize=False, njobs=1, electrodes_subtract_mean=True, optimization_params=optimization_params, do_save=False, do_plot=True, save_in_pred_electrode_fol=True, write_errors_csv=False, do_plot_electrodes=False, error_functions=ERROR_RECONSTRUCT_METHODS) mg78_inter_leads = [['RAT', 'RMT', 'RPT'], ['RAF', 'RMF', 'RPF'], ['ROF', 'RAF', 'RMF']] mg78_inter_leads_combs = get_inter_lead_combs(mg78_inter_leads, bipolar, False) # find_bps(optimization_method='rol_corr', error_calc_method='rol_corr', error_threshold=100, combs=mg78_inter_leads_combs)#, elec_data=elec_data) # find_bps(optimization_method='RidgeCV', error_calc_method='rol_corr', error_threshold=30, combs=mg78_inter_leads_combs) # bps_collect_results(uuid='3f70f', k=3, optimization_method='rol_corr', error_calc_method='rol_corr', error_threshold=100)#, elec_data=elec_data) # bps_collect_results(uuid='7df21', k=3, optimization_method='RidgeCV', error_calc_method='rol_corr', error_threshold=30) # bps_collect_results(uuid='6d6a1', optimization_method='rol_corr', error_calc_method='rol_corr', error_threshold=10) # find_best_freqs_subset(event_id, bipolar, CSD_FREQS_NO_INF, from_t, to_t, time_split, # mg78_inter_leads_combs, 'RidgeCV', optimization_params, k=3, gk_sigma=3, njobs=7) # pickup_freqs_subsets(event_id, '20793', 'RidgeCV', optimization_params) # cut_best_freqs_subset(events_id, '20793', 'RidgeCV', optimization_params, k=3, max_err=50) # load_best_freqs_subset(event_id, '20793', 'RidgeCV', optimization_params, from_t, to_t, time_split, k=3, # load_best_freqs_subset(event_id, '9222b', 'RidgeCV', optimization_params, from_t, to_t, time_split, k=3, # bipolar=False, gk_sigma=3, do_plot=True, recalculate=False, write_errors_csv=False, top_err=25, # group_by_predicted=True, new_optimization_method='rol_corr', new_error_calc_method='rol_corr', # verbose=True, save_plots_in_pred_electrode_fol=False, do_save_plots=False, best_k_in_group=1, njobs=4) # best_freqs_subset_cv(events_id, '20793', CSD_FREQS_NO_INF, 'RidgeCV', 'rol_corr', optimization_params, # from_t, to_t, time_split, k=3, bipolar=False, gk_sigma=3, max_err=50, njobs=4) load_best_freqs_subset_cv(events_id, '20793', CSD_FREQS_NO_INF, 'RidgeCV', optimization_params, from_t, to_t, time_split, gk_sigma, max_err=50) # all_electrodes = get_all_electrodes_names(bipolar) # elec_data = load_electrodes_data(event_id, bipolar, all_electrodes, from_t, to_t, # subtract_min=False, normalize_data=False, subtract_mean=False) # new_combs = [] # for comb in get_inter_lead_combs(mg78_inter_leads, bipolar, False): # # avg = np.mean([elec_data[comb[l]][cond] for l in range(3)], axis=0) # # for l in range(3): # # elec_data[comb[l]][cond] -= avg # elec_data[comb[0]][cond] -= elec_data[comb[2]][cond] # elec_data[comb[1]][cond] -= elec_data[comb[2]][cond] # elec_data[comb[2]][cond] = elec_data[comb[0]][cond] # new_combs.append([comb[0], comb[1]]) # plt.figure() # for l in range(3): # plt.plot(elec_data[comb[l]][cond], label=comb[l]) # plt.legend() # plt.show() # for alpha in range(21): # optimization_params={'window':30, 'alpha':alpha} # print(optimization_params) # find_bps(optimization_method='rol_corr', error_calc_method='rol_corr', error_threshold=100, optimization_params=optimization_params) # plot_reconstruction_for_different_freqs(event_id, 'RMF5-RMF4', ['RMT3-RMT2', 'RMT6-RMT5'], from_t, to_t, time_split) # 3df71 # analyze_best_predictive_subset(4, event_id, bipolar, from_t, to_t, time_split) # plot_predictive_subset(['ROF5-ROF4', 'LAT3-LAT2', 'LAT2-LAT1', 'LAT4-LAT3'], 4, event_id, bipolar, from_t, to_t, time_split, njobs=4) # calc_lead_predictiveness(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, # electrodes_positive=True, electrodes_normalize=True, k=7, njobs=4) # plot_lead_predictiveness(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, # electrodes_positive=True, electrodes_normalize=True, k=5, error_threshold=10, njobs=4) # calc_p_for_each_electrode(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, njobs=4) # analyze_p_for_each_electrode(event_id, bipolar, from_t, to_t, time_split, gk_sigma=3, njobs=4) # analyze_best_n_componenets(event_id, bipolar, from_t, to_t, time_split, njobs=4) # find_best_subset(event_id, 1, bipolar, from_t, to_t, time_split, gk_sigma=3, split_to_learn_and_pred=False, only_sig_electrodes=False, plot_results=True, njobs=njobs) # analyze_best_subset(event_id, 4, bipolar, from_t, to_t, time_split, split_to_learn_and_pred=False, only_sig_electrodes=False, plot_locations=True, do_plot=False) # sig_elecs = find_significant_electrodes(bipolar, from_t, to_t, do_plot=False, do_save=False, plot_only_sig=False) # plot_electrodes(bipolar, sig_elecs[cond]) # learn_and_pred(events_id, bipolar, from_t, to_t, time_split) # sig = find_significant_electrodes(bipolar, from_t, to_t) # check_freqs(events_id, electrodes, from_t, to_t, time_split, predicted_electrodes=predicted_electrodes, gk_sigma=3, bipolar=bipolar, njobs=1) # names, pos, pos_org = get_electrodes_positions(MRI_SUBJECT, bipolar) bipolar = False # electrodes = get_all_electrodes_names(bipolar) # calc_all_fwds(events_id, electrodes, bipolar, from_t, to_t, time_split, overwrite_fwd=True, njobs=7) # calc_electrodes_fwd(MRI_SUBJECT, electrodes, events_id, bipolar=False, overwrite_fwd=True, read_if_exist=False, n_jobs=4) # calc_dics_freqs_csd(events_id, electrodes, bipolar, from_t, to_t, time_split, freqs_bands=CSD_FREQS_DALAL, # overwrite_csds=False, overwrite_dics=False, gk_sigma=3, njobs=7) # calc_dics_freqs_csd(events_id, electrodes, bipolar, from_t, to_t, time_split, freqs_bands=CSD_FREQS_INF, # overwrite_csds=False, overwrite_dics=False, gk_sigma=3, njobs=1) # region = 'bipolar_electrodes' if bipolar else 'regular_electrodes' # calc_dics_freqs_csd(events_id, [region], bipolar, from_t, to_t, time_split, freqs_bands=CSD_FREQS_DALAL, # overwrite_csds=True, overwrite_dics=True, gk_sigma=3, njobs=1) win_lengths = [0.3, 0.3, 0.2, 0.15, 0.1] freqs = CSD_FREQS_DALAL epochs_from_t, epochs_to_t, tstep = -0.5, 2.0, 0.05 csd_from_t, csd_to_t = -0.5, 2.0 tmin_plot, tmax_plot = -0.25, 1.75 # calc_td_dics(events_id, bipolar, epochs_from_t, epochs_to_t, csd_from_t, csd_to_t, tstep, # win_lengths=win_lengths, freq_bins=freqs, overwrite_csds=True, overwrite_epochs=False) # plot_td_dics(events_id, bipolar, tmin_plot, tmax_plot, freq_bins=freqs) # calc_all_electrodes_fwd(MRI_SUBJECT, events_id, overwrite_fwd=False, n_jobs=6) # calc_electrode_fwd(MRI_SUBJECT, electrode, events_id, bipolar, overwrite_fwd=False) # check_electrodes() # check_bipolar_meg(events_id, electrode) # comp_lcmv_dics_electrode(events_id, electrode, bipolar) # plot_activation_one_fig(cond, meg_data_norm, elec_data_norm, electrode, 500) # cond = 'interference' # meg_data = load_all_subcorticals(subject_meg_fol, sub_corticals_codes_file, cond, from_t, to_t, normalize=True) # plot_activation_one_fig(cond, meg_data, elec_data, 'elec', from_t, to_t) # meg_data = call_lcmv(forward, data_cov, noise_cov, evoked, epochs) # plot_activation(events_id, meg_data, elec_data, 'elec', from_t, to_t, 'lcmv') # meg_data = test_all_verts(forward, data_cov, noise_cov, evoked, epochs) if __name__ == '__main__': MEG_SUBJECT = 'ep001' MRI_SUBJECT = 'mg78' constrast='interference' raw_cleaning_method='nTSSS' task = meg.TASK_MSIT fname_format = '{subject}_msit_{raw_cleaning_method}_{constrast}_{cond}_1-15-{ana_type}.{file_type}' SUBJECT_MEG_FOL = os.path.join(SUBJECTS_MEG_DIR, TASKS[task], MEG_SUBJECT) SUBJECT_MRI_FOL = os.path.join(SUBJECTS_MRI_DIR, MRI_SUBJECT) BLENDER_SUB_FOL = os.path.join(BLENDER_ROOT_DIR, MRI_SUBJECT) MEG_ELEC_CONDS_TRANS = {'noninterference':0, 'interference':1} EVENTS_TRANS = {'noninterference':'neutral', 'interference':'interference'} EVENTS_TRANS_INV = {v:k for k, v in EVENTS_TRANS.items()} meg.init_globals(MEG_SUBJECT, MRI_SUBJECT, fname_format, True, raw_cleaning_method, constrast, SUBJECTS_MEG_DIR, TASKS, task, SUBJECTS_MRI_DIR, BLENDER_ROOT_DIR) from src.preproc.meg import RAW, RAW_NOISE, FWD_X, EVO, EPO, EPO_NOISE, DATA_COV, NOISE_COV, \ DATA_CSD, NOISE_CSD, NOISE_CSD_EMPTY_ROOM now = time.time() main() print('Finish! {}'.format(time.time() - now))