from __future__ import print_function import numpy as np import importlib from sklearn.metrics import roc_auc_score from numpy import fft import pandas import pickle from scipy.signal import lfilter, butter from copy import copy import warnings import json with open("SETTINGS.json") as file: config = json.load(file) warnings.filterwarnings('ignore', module='.*nolearn.lasagne.*') # Constants that are fixed for the competition N_EVENTS = 6 SAMPLE_RATE = 500 SUBJECTS = list(range(1,13)) TRAIN_SERIES = list(range(1,9)) TEST_SERIES = [9,10] # By default we train on DEFAULT_TRAIN_SIZE randomly selected location each "epoch" # (yes, that's not really an epoch). Similarly, we validate each "epoch" on # VALID_SIZE randomly chosen points. The validation points are chosen with the # random seed set to VALID_SEED. Note that the training points are different # each epoch while the validation points are the same. DEFAULT_TRAIN_SIZE = 16*1024 DEFAULT_VALID_SIZE = 8*1024 VALID_SEED = 199 # Utility Functions def path(subject, series, kind): prefix = config["TRAIN_DATA_PATH"] if (series in TRAIN_SERIES) else config["TEST_DATA_PATH"] return "{0}/subj{1}_series{2}_{3}.csv".format(prefix, subject, series, kind) def read_csv(path): return pandas.read_csv(path, index_col=0).values def splice_at(x, n, size=32): # Make a smooth splice at a junction created by earlier concatenation. 
# n-1 is one side of junction, n is the other mean = 0.5 * (x[n-1] + x[n]) region = x[n-size:n+size] - mean region[:size] *= np.linspace(1,0,size)[:,None] region[size:] *= np.linspace(0,1,size)[:,None] x[n-size:n+size] = region + mean FILTER_N = 4 # Order of the filters to use def butter_lowpass(highcut, fs, order): nyq = 0.5 * fs high = highcut / nyq b, a = butter(order, high, btype="lowpass") return b, a def butter_bandpass(lowcut, highcut, fs, order): nyq = 0.5 * fs cutoff = [lowcut / nyq, highcut / nyq] b, a = butter(order, cutoff, btype="bandpass") return b, a def butter_highpass(highcut, fs, order): nyq = 0.5 * fs high = highcut / nyq b, a = butter(order, high, btype="highpass") return b, a # Sources for Batch Iterators # # These classes load training and test data and perform some basic preprocessing on it. # They are then passed to factory functions that create the net. There they are used # as data sources for the batch iterators that feed data to the net. # All classes band pass or low pass filter their data based on min / max freq using # a causal filter (lfilter) when the data is first loaded. # * TrainSource loads a several series of EEG data and events, splices them together into # one long stream, then normalizes the EEG data to zero mean and unit standard deviation. # * TestSource is like TrainSource except that it uses the mean and standard deviation # computed for the associated training source to normalize the EEG data. # * SubmitSource is like TestSource except that it does not load and event data. 
class Source:
    """Base class for the data sources that feed the batch iterators.

    Subclasses set `min_freq` / `max_freq` before loading, then assemble the
    loaded series into `self.data` (and, where applicable, `self.events`).
    """

    mean = None
    std = None

    # Shared by every Source: maps (subject, series, min_freq, max_freq) to the
    # filtered data for that series.
    _series_cache = {}
    # Big enough to cache first two subjects for interactive trial and error
    MAX_CACHE_SIZE = 20

    def load_series(self, subject, series):
        """Return the filtered data for one series, loading it on a cache miss."""
        min_freq = self.min_freq
        max_freq = self.max_freq
        key = (subject, series, min_freq, max_freq)
        if key not in self._series_cache:
            # Make room first so the cache never holds more than MAX_CACHE_SIZE
            # entries once the new one is stored.
            while len(self._series_cache) >= self.MAX_CACHE_SIZE:
                # Throw away an arbitrary item
                self._series_cache.popitem()
            print("Loading", subject, series)
            data = read_csv(path(subject, series, "data"))
            # Filter here since it's slow and we don't want to filter multiple
            # times. `lfilter` is CAUSAL and thus doesn't violate the ban on future data.
            if (min_freq is None) or (min_freq == 0):
                print("Low pass filtering, f_h =", max_freq)
                b, a = butter_lowpass(max_freq, SAMPLE_RATE, FILTER_N)
            else:
                print("Band pass filtering, f_l =", min_freq, "f_h =", max_freq)
                b, a = butter_bandpass(min_freq, max_freq, SAMPLE_RATE, FILTER_N)
            self._series_cache[key] = lfilter(b, a, data, axis=0)
        return self._series_cache[key]

    def load_raw_data(self, subject, series_list):
        """Load the filtered data of every series in `series_list` into `self.raw_data`."""
        self.raw_data = [self.load_series(subject, i) for i in series_list]

    def assemble_data(self):
        """Concatenate `self.raw_data` into `self.data`, smoothing each junction."""
        if self.raw_data:
            self.data = np.concatenate(self.raw_data, axis=0)
        else:
            self.data = np.zeros([0,32])
        n = 0
        # Every series except the last one ends at a junction.
        for x in self.raw_data[:-1]:
            n += len(x)
            splice_at(self.data, n)

    def load_events(self, subject, series):
        """Read the events file of every series in the list `series` into `self.raw_events`."""
        self.raw_events = [read_csv(path(subject, i, "events")) for i in series]

    def assemble_events(self):
        """Concatenate `self.raw_events` into `self.events`."""
        if self.raw_events:
            self.events = np.concatenate(self.raw_events, axis=0)
        else:
            self.events = np.zeros([0])

    def normalize(self):
        """Shift and scale `self.data` in place using `self.mean` and `self.std`."""
        self.data -= self.mean
        self.data /= self.std

    def __len__(self):
        return len(self.data)


class TrainSource(Source):
    """Data and events of several series, normalized with their own mean and std."""

    def __init__(self, subject, series_list, min_freq, max_freq):
        self.subject = subject
        self.series_list = series_list
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.load_raw_data(subject, series_list)
        self.load_events(subject, series_list)
        self._init()

    def _init(self):
        """(Re)build data and events from the raw lists and normalize the data."""
        self.assemble_data()
        self.assemble_events()
        self.mean = self.data.mean(axis=0)
        self.std = self.data.std(axis=0)
        self.normalize()


class TestSource(Source):
    """Data and events of the training series that `train_source` does not use.

    The data is normalized with the mean and std of `train_source`.
    """

    def __init__(self, train_source):
        self.min_freq = train_source.min_freq
        self.max_freq = train_source.max_freq
        vseries = sorted(set(TRAIN_SERIES) - set(train_source.series_list))
        self.series_list = vseries
        self.load_raw_data(train_source.subject, vseries)
        self.load_events(train_source.subject, vseries)
        self._init(train_source)

    def _init(self, train_source):
        """(Re)build data and events and normalize with the statistics of `train_source`."""
        self.assemble_data()
        self.assemble_events()
        self.mean = train_source.mean
        self.std = train_source.std
        self.normalize()


class SubmitSource(Source):
    """Data only (no events) for `series`, normalized with the statistics of `train_source`."""

    def __init__(self, subj, series, train_source):
        self.series_list = series
        self.min_freq = train_source.min_freq
        self.max_freq = train_source.max_freq
        self.load_raw_data(subj, series)
        self._init(train_source)

    def _init(self, train_source):
        """Build the data and normalize with the statistics of `train_source`."""
        self.assemble_data()
        self.mean = train_source.mean
        self.std = train_source.std
        self.normalize()


# These two functions support validation using the last several trials in each series.
# This is specified in train/train_all with validation=<integer>

def find_split_index(events, count):
    """Return a row index that leaves the last `count` trials after it.

    The events are scanned backwards from the last row. A row with a 1 in
    column 0 enters a trial, and the next all-zero row leaves it. Once `count`
    trials have been left, the scan continues to the next row with any event
    set and returns the midpoint of that gap.

    Arguments:
    events -- 2-D array indexed as events[row, column]
    count -- number of trials to leave after the split; must be at least 1

    Raises ValueError if `count` is less than 1 or no such gap is found.
    """
    if count < 1:
        # Without a completed trial there is no gap to split in.
        raise ValueError("couldn't find a good split point")
    in_event = False
    events_seen = 0
    ndx = len(events) - 1
    while ndx > 0:
        if in_event:
            if events[ndx].sum() == 0:
                in_event = False
                events_seen += 1
                if events_seen >= count:
                    start = ndx
        else:
            if events_seen >= count and events[ndx].sum():
                return (ndx + start) // 2
            if events[ndx,0] == 1:
                in_event = True
        ndx -= 1
    raise ValueError("couldn't find a good split point")


def split_source(train_source, validation_count=1):
    """Split a training source into a training and a test Source

    This takes `validation_count` trials from each series and uses them in
    the returned test source. The remaining trials are used for the train source.

    Arguments:
    train_source -- the source to split; it is not modified
    validation_count -- number of trials from each series to place into the test source

    Returns the tuple (new_train, new_test).
    """
    raw_events = train_source.raw_events
    raw_data = train_source.raw_data
    # A shallow copy is enough: the raw lists are replaced and _init() rebuilds
    # data, events, mean and std.
    new_train = copy(train_source)
    new_train.raw_data = []
    new_train.raw_events = []
    new_test = TestSource(train_source)
    new_test.raw_data = []
    new_test.raw_events = []
    for events, data in zip(raw_events, raw_data):
        split_index = find_split_index(events, validation_count)
        new_train.raw_events.append(events[:split_index])
        new_test.raw_events.append(events[split_index:])
        new_train.raw_data.append(data[:split_index])
        new_test.raw_data.append(data[split_index:])
    new_train._init()
    new_test._init(new_train)
    return new_train, new_test


# These are utility functions associated with computing the ROC AUC score

def make_valid_indices(source, count):
    """Make a set of `count` indices to use for validation

    NOTE: this reseeds the global numpy random generator with VALID_SEED, so
    the same indices are returned for a source of the same length.
    """
    test_indices = np.arange(len(source.data))
    np.random.seed(VALID_SEED)
    np.random.shuffle(test_indices)
    return test_indices[:count]


def score(net, samples=4096):
    """Compute the area under the curve, ROC score from a trained net

    We take `samples` random samples and compute the ROC AUC score on those samples.
    Returns 0 when `roc_auc_score` rejects its input with a ValueError.
    """
    source = net.batch_iterator_test.source
    test_indices = make_valid_indices(source, samples)
    predicted = net.predict_proba(test_indices)
    if predicted.shape[-1] != N_EVENTS:
        # NOTE(review): `decode` is not defined in the visible part of this
        # module; confirm where it is expected to come from.
        predicted = decode(predicted)
    actual = source.events[test_indices]
    try:
        return roc_auc_score(actual.reshape(-1), predicted.reshape(-1))
    except ValueError:
        # Only the metric's own rejection is turned into a score of 0; anything
        # else (including KeyboardInterrupt) propagates.
        return 0


def score_for(train_info, subj, series, samples=1024, **kwargs):
    """Compute the roc_auc score from train_info

    Arguments:
    train_info -- a tuple whose first item is the factory and whose last item
                  maps subject to a tuple starting with (weights, train_source);
                  both (factory, info) and (factory, kwargs, info) are accepted
    subj -- subject to score
    series -- an integer (score on the last `series` trials of each training
              series) or anything else (score on the series `train_source`
              does not use)
    samples -- number of points to score on
    **kwargs -- extra arguments to be passed to the factory
    """
    factory, info = train_info[0], train_info[-1]
    weights, train_source = info[subj][:2]
    if isinstance(series, int):
        min_freq = train_source.min_freq
        max_freq = train_source.max_freq
        base_source = TrainSource(subj, TRAIN_SERIES, min_freq, max_freq)
        _, test_source = split_source(base_source, series)
    else:
        test_source = TestSource(train_source)
    net = factory(train_source=None, test_source=test_source, **kwargs)
    net.load_weights_from(weights)
    return score(net, samples)


# *******************************************************************
# These are the core functions meant to be used from outside of the
# module: train, train_all, submit, load, dump

def train(factory, subject, max_epochs=100, validation=[3,6], min_freq=0.2, max_freq=50,
          params=None, train_size=DEFAULT_TRAIN_SIZE, valid_size=DEFAULT_VALID_SIZE,
          **kwargs):
    """Train a net created by `factory` for `subject`

    Arguments:
    factory -- function that returns a net. Arguments vary and can be passed in `kwargs`
    subject -- the subject to train on
    max_epochs -- maximum number of epochs to train for
    validation -- type of validation to use. This can be either:
                  - a list: series specified in the list are used for validation
                  - an integer: last `validation` trials of each series are used for validation
                  - a false value (None, [], 0): no validation; every training
                    series is used for training and `patience` is forced to 0
    min_freq -- lower frequency to band pass filter at. If None low pass filter instead
    max_freq -- upper frequency to band pass or low pass filter at
    params -- if not None, parameter values loaded into the net before fitting
    train_size -- the number of points to train with each epoch
    valid_size -- the number of points to validate with
    **kwargs -- extra arguments to be passed to `factory`. A dict value is
                treated as per subject: the entry for `subject` is used, or the
                argument is dropped when there is no such entry

    Returns the tuple (params, train_source, score); score is None without validation.
    """
    # by passing in -1s to the train source, we get a random set of points
    # to train at each time.
    train_indices = np.zeros([train_size], dtype=int) - 1

    if not validation:
        # No validation requested: train on everything we have.
        train_source = TrainSource(subject, TRAIN_SERIES, min_freq, max_freq)
        test_source = None
        valid_indices = []
        kwargs['patience'] = 0
    elif isinstance(validation, int):
        base_source = TrainSource(subject, TRAIN_SERIES, min_freq, max_freq)
        train_source, test_source = split_source(base_source, validation)
        valid_indices = make_valid_indices(test_source, valid_size)
    else:
        tseries = sorted(set(TRAIN_SERIES) - set(validation))
        train_source = TrainSource(subject, tseries, min_freq, max_freq)
        test_source = TestSource(train_source)
        valid_indices = make_valid_indices(test_source, valid_size)

    # Iterate over a snapshot since entries may be deleted.
    for k, v in list(kwargs.items()):
        if isinstance(v, dict):
            if subject in v:
                kwargs[k] = v[subject]
            else:
                del kwargs[k]

    net = factory(train_source, test_source, max_epochs=max_epochs, **kwargs)
    if params is not None:
        net.load_params_from(params)

    while True:
        try:
            net.fit(train_indices, valid_indices)
        except MemoryError:
            input("Memory Error press any key to retry (^C) to stop")
        else:
            break

    params = net.get_all_params_values()
    if validation:
        subject_score = score_for((factory, {subject : (params, train_source)}),
                                  subject, validation, **kwargs)
        print("Score:", subject_score)
    else:
        subject_score = None
    return (params, train_source, subject_score)


def train_all(factory, max_epochs=20, epoch_boost=20, **kwargs):
    """Train a net created by `factory` for all subjects

    We train the net for the first subject for a maximum of `max_epochs`+`epoch_boost`
    epochs. Subsequent subjects are trained for only `max_epochs`, but we use a
    warm start, initializing their weights based on the weights computed for the
    previous subject. This greatly speeds up the fit.

    This is the primary function for training nets. Typical usage would be:

        import grasp
        import net_stf7
        info = grasp.train_all(net_stf7.create_net, max_epochs=50)
        # wait for several hours .....
        grasp.make_submission(info, "path_to_write_output_to.csv")

    Arguments:
    factory -- a function that returns a net. Arguments vary and can be passed in `kwargs`
    max_epochs -- the maximum number of epochs to train all but the first subject for
    epoch_boost -- extra epochs to train for on first subject
    **kwargs -- args to forward on to `train`

    Returns the tuple (factory, kwargs, info) where `kwargs` also records
    `max_epochs` and `epoch_boost`, and `info` maps each subject to
    (params, source, score).
    """
    info = {}
    params = None
    scores = []
    # Keep `epoch_boost` itself untouched so the value recorded below is the
    # one that was actually requested.
    boost = epoch_boost
    for subj in SUBJECTS:
        print("Subject:", subj)
        epochs = max_epochs + boost
        params, source, subject_score = train(factory, subj, epochs, params=params, **kwargs)
        scores.append(subject_score)
        info[subj] = (params, source, subject_score)
        boost = 0
    # `train` returns a score of None when no validation was requested.
    known_scores = [s for s in scores if s is not None]
    overall = np.mean(known_scores) if known_scores else None
    print("Overall score:", overall)
    kwargs.update({'max_epochs' : max_epochs, 'epoch_boost' : epoch_boost})
    return (factory, kwargs, info)


def submit_only_kwargs(kwargs):
    """Strip out kwargs that are not used in submit

    Returns a copy; `kwargs` itself is left unchanged.
    """
    kwargs = kwargs.copy()
    for key in ['patience', 'min_freq', 'max_freq', 'validation',
                "max_epochs", "epoch_boost", "train_size", "valid_size"]:
        kwargs.pop(key, None)
    return kwargs


def make_submission(train_info, path):
    """create a submission file based on `train_info` at `path`

    Arguments:
    train_info -- the (factory, kwargs, info) tuple returned by `train_all` or `load`
    path -- file to write the CSV to
    """
    factory, kwargs, info = train_info
    submit_kwargs = submit_only_kwargs(kwargs)
    all_probs = []
    for subj in SUBJECTS:
        weights, train_source, _ = info[subj]
        for series in TEST_SERIES:
            print("Subject:", subj, ", series:", series)
            submit_source = SubmitSource(subj, [series], train_source)
            indices = np.arange(len(submit_source.data))
            net = factory(train_source=None, test_source=submit_source, **submit_kwargs)
            net.load_weights_from(weights)
            all_probs.append((subj, series, net.predict_proba(indices)))

    with open(path, 'w') as out:
        out.write("id,HandStart,FirstDigitTouch,BothStartLoadPhase,LiftOff,Replace,BothReleased\n")
        for subj, series, probs in all_probs:
            for i, p in enumerate(probs):
                row_id = "subj{0}_series{1}_{2},".format(subj, series, i)
                out.write(row_id + ",".join(str(x) for x in p)+"\n")


def dump(train_info, path):
    """Pickle `train_info` to `path`.

    The factory is stored as its (module name, function name) and each source
    is reduced to its `series_list`; `load` rebuilds both.
    """
    factory, kwargs, info = train_info
    factory_tuple = (factory.__module__, factory.__name__)
    stripped_info = {k : {'params':params,
                          'series_list':source.series_list,
                          'score' : score}
                     for (k, (params, source, score)) in info.items()}
    dumpable_info = {'factory' : factory_tuple,
                     'kwargs' : kwargs,
                     'subject_info' : stripped_info}
    with open(path, 'wb') as f:
        pickle.dump(dumpable_info, f)


def load(path):
    """Load a file written by `dump` and return (factory, kwargs, info).

    The training sources are rebuilt from the stored series lists, which
    reloads and refilters the corresponding data.

    NOTE: this unpickles `path`; only use it on files you trust.
    """
    with open(path, 'rb') as f:
        try:
            # This works if data was saved under py2 and this is py3
            loaded = pickle.load(f, encoding="latin-1")
        except TypeError:
            # py2's pickle.load does not accept `encoding`. Rewind in case
            # anything was consumed before the error.
            f.seek(0)
            loaded = pickle.load(f)
    mod_name, func_name = loaded["factory"]
    try:
        mod = importlib.import_module(mod_name)
    except ImportError:
        # Some of the dumps were created before nets moved to `nets` directory
        mod = importlib.import_module("nets."+mod_name)
    factory = getattr(mod, func_name)
    info = {}
    # Some of the dumps were computed before kwargs were dumped.
    kwargs = loaded.get('kwargs', {})
    min_freq = kwargs.get("min_freq", 0.2)
    max_freq = kwargs.get("max_freq", 50)
    for subj, value in loaded['subject_info'].items():
        source = TrainSource(subj, value['series_list'], min_freq, max_freq)
        subject_score = value.get("score", 0)
        info[subj] = (value['params'], source, subject_score)
    return factory, kwargs, info