#!/usr/bin/env python # -*- coding: latin-1 -*- # # Copyright 2016-2019 Blaise Frederick # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # $Author: frederic $ # $Date: 2016/07/12 13:50:29 $ # $Id: tide_funcs.py,v 1.4 2016/07/12 13:50:29 frederic Exp $ # from __future__ import print_function, division import matplotlib.pyplot as pl import numpy as np import scipy as sp import warnings import sys import rapidtide.util as tide_util import rapidtide.fit as tide_fit import rapidtide.miscmath as tide_math import rapidtide.correlate as tide_corr class fmridata: thedata = None theshape = None xsize = None ysize = None numslices = None realtimepoints = None timepoints = None slicesize = None numvox = None numskip = 0 def __init__(self, thedata, zerodata=False, copydata=False, numskip=0): if zerodata: self.thedata = thedata * 0.0 else: if copydata: self.thedata = thedata + 0.0 else: self.thedata = thedata self.getsizes() self.setnumskip(numskip) def getsizes(self): self.theshape = self.thedata.shape self.xsize = self.theshape[0] self.ysize = self.theshape[1] self.numslices = self.theshape[2] try: self.realtimepoints = self.theshape[3] except KeyError: self.realtimepoints = 1 self.slicesize = self.xsize * self.ysize self.numvox = self.slicesize * self.numslices def setnumskip(self, numskip): self.numskip = numskip self.timepoints = self.realtimepoints - self.numskip def byslice(self): return self.thedata[:, :, :, self.numskip:].reshape((self.slicesize, self.numslices, self.timepoints)) def byvol(self): return self.thedata[:, :, :, self.numskip:].reshape((self.numvox, self.timepoints)) def byvox(self): return self.thedata[:, :, :, self.numskip:] class proberegressor: inputtimeaxis = None inputvec = None inputfreq = None inputstart = 0.0 inputoffset = 0.0 targettimeaxis = None targetvec = None targetfreq = None targetstart = 0.0 targetoffset = 0.0 def __init__(self, inputvec, inputfreq, targetperiod, targetpoints, targetstartpoint, targetoversample=1, inputstart=0.0, inputoffset=0.0, targetstart=0.0, targetoffset=0.0, ): self.inputoffset = inputoffset self.setinputvec(inputvec, inputfreq, inputstart=inputstart) self.targetperiod = targetperiod self.makeinputtimeaxis(self) self.targetoversample = targetoversample self.targetpoints = targetpoints self.targetstartpoint = targetstartpoint def setinputvec(self, inputvec, inputfreq, inputstart=0.0): self.inputvec = inputvec self.inputfreq = inputfreq self.inputstart = inputstart def makeinputtimeaxis(self): self.inputtimeaxis = np.linspace(0.0, len(self.inputvec)) / self.inputfreq - (self.inputstarttime + self.inputoffset) def maketargettimeaxis(self): self.targettimeaxis = np.linspace(self.targetperiod * self.targetstartpoint, self.targetperiod * self.targetstartpoint + self.targetperiod * self.targetpoints, num=self.targetpoints, endpoint=True) os_fmri_x = np.arange(0.0, (validtimepoints - optiondict['addedskip']) * self.targetoversample - ( self.targetoversample - 1)) * self.targetoversample * self.targetperiod + skiptime class correlator: reftc = None prepreftc = None testtc = None preptesttc = None timeaxis = None corrlen = 0 datavalid = False timeaxisvalid = False corrorigin = 0 def __init__(self, Fs=0.0, corrorigin=0, lagmininpts=0, lagmaxinpts=0, ncprefilter=None, reftc=None, detrendorder=1, windowfunc='hamming', corrweighting='none'): self.Fs = Fs self.corrorigin = corrorigin self.lagmininpts = lagmininpts self.lagmaxinpts = lagmaxinpts self.ncprefilter = ncprefilter self.reftc = reftc self.detrendorder = detrendorder self.windowfunc = windowfunc if self.windowfunc is not None: self.usewindowfunc = True else: self.usewindowfunc = False self.corrweighting = corrweighting if self.reftc is not None: self.setreftc(self.reftc) def preptc(self, thetc): # prepare timecourse by filtering, normalizing, detrending, and applying a window function return tide_math.corrnormalize(self.ncprefilter.apply(self.Fs, thetc), prewindow=self.usewindowfunc, detrendorder=self.detrendorder, windowfunc=self.windowfunc) def setreftc(self, reftc): self.reftc = reftc + 0.0 self.prepreftc = self.preptc(self.reftc) self.corrlen = len(self.reftc) * 2 - 1 self.corrorigin = self.corrlen // 2 + 1 # make the time axis self.timeaxis = np.arange(0.0, self.corrlen) * (1.0 / self.Fs) \ - ((self.corrlen - 1) * (1.0 / self.Fs)) / 2.0 self.timeaxisvalid = True self.datavalid = False def setlimits(self, lagmininpts, lagmaxinpts): self.lagmininpts = lagmininpts self.lagmaxinpts = lagmaxinpts def trim(self, vector): return vector[self.corrorigin - self.lagmininpts:self.corrorigin + self.lagmaxinpts] def getcorrelation(self, trim=True): if self.datavalid: if trim: return self.trim(self.thexcorr), self.trim(self.timeaxis), self.theglobalmax else: return self.thexcorr, self.timeaxis, self.theglobalmax else: if self.timeaxisvalid: if trim: return None, self.trim(self.timeaxis), None else: return None, self.timeaxis, None else: print('must run correlation before fetching data') return None, None, None def run(self, thetc, trim=True): if len(thetc) != len(self.reftc): print('timecourses are of different sizes - exiting') sys.exit() self.testtc = thetc self.preptesttc = self.preptc(self.testtc) # now actually do the correlation self.thexcorr = tide_corr.fastcorrelate(self.preptesttc, self.prepreftc, usefft=True, weighting=self.corrweighting) self.corrlen = len(self.thexcorr) self.corrorigin = self.corrlen // 2 + 1 # find the global maximum value self.theglobalmax = np.argmax(self.thexcorr) self.datavalid = True if trim: return self.trim(self.thexcorr), self.trim(self.timeaxis), self.theglobalmax else: return self.thexcorr, self.timeaxis, self.theglobalmax class correlation_fitter: corrtimeaxis = None FML_BADAMPLOW = np.uint16(0x01) FML_BADAMPHIGH = np.uint16(0x02) FML_BADSEARCHWINDOW = np.uint16(0x04) FML_BADWIDTHLOW = np.uint16(0x08) FML_BADWIDTHHIGH = np.uint16(0x10) FML_BADLAG = np.uint16(0x20) FML_FITFAIL = np.uint16(0x40) FML_INITFAIL = np.uint16(0x80) def __init__(self, corrtimeaxis=None, lagmin=-30.0, lagmax=30.0, absmaxsigma=1000.0, absminsigma=0.25, hardlimit=True, bipolar=False, lthreshval=0.0, uthreshval=1.0, debug=False, findmaxtype='gauss', zerooutbadfit=True, refine=False, maxguess=0.0, useguess=False, searchfrac=0.5, fastgauss=False, lagmod=1000.0, enforcethresh=True, displayplots=False): r""" Parameters ---------- corrtimeaxis: 1D float array The time axis of the correlation function lagmin: float The minimum allowed lag time in seconds lagmax: float The maximum allowed lag time in seconds absmaxsigma: float The maximum allowed peak halfwidth in seconds hardlimit bipolar: boolean If true find the correlation peak with the maximum absolute value, regardless of sign threshval uthreshval debug zerooutbadfit refine maxguess useguess searchfrac fastgauss lagmod enforcethresh displayplots Returns ------- Methods ------- fit(corrfunc): Fit the correlation function given in corrfunc and return the location of the peak in seconds, the maximum correlation value, the peak width setrange(lagmin, lagmax): Specify the search range for lag peaks, in seconds """ self.setcorrtimeaxis(corrtimeaxis) self.lagmin = lagmin self.lagmax = lagmax self.absmaxsigma = absmaxsigma self.absminsigma = absminsigma self.hardlimit = hardlimit self.bipolar = bipolar self.lthreshval = lthreshval self.uthreshval = uthreshval self.debug=debug self.findmaxtype=findmaxtype self.zerooutbadfit = zerooutbadfit self.refine = refine self.maxguess = maxguess self.useguess = useguess self.searchfrac = searchfrac self.fastgauss = fastgauss self.lagmod = lagmod self.enforcethresh = enforcethresh self.displayplots = displayplots def _maxindex_noedge(self, corrfunc): """ Parameters ---------- corrfunc Returns ------- """ lowerlim = 0 upperlim = len(self.corrtimeaxis) - 1 done = False while not done: flipfac = 1.0 done = True maxindex = (np.argmax(corrfunc[lowerlim:upperlim]) + lowerlim).astype('int32') if self.bipolar: minindex = (np.argmax(np.fabs(corrfunc[lowerlim:upperlim])) + lowerlim).astype('int32') if np.fabs(corrfunc[minindex]) > np.fabs(corrfunc[maxindex]): maxindex = minindex flipfac = -1.0 else: maxindex = (np.argmax(corrfunc[lowerlim:upperlim]) + lowerlim).astype('int32') if upperlim == lowerlim: done = True if maxindex == 0: lowerlim += 1 done = False if maxindex == upperlim: upperlim -= 1 done = False return maxindex, flipfac def setrange(self, lagmin, lagmax): self.lagmin = lagmin self.lagmax = lagmax def setcorrtimeaxis(self, corrtimeaxis): if corrtimeaxis is not None: self.corrtimeaxis = corrtimeaxis + 0.0 else: self.corrtimeaxis = corrtimeaxis def setguess(self, useguess, maxguess=0.0): self.useguess = useguess self.maxguess = maxguess def setlthresh(self, lthreshval): self.lthreshval = lthreshval def setuthresh(self, uthreshval): self.uthreshval = uthreshval def diagnosefail(self, failreason): # define error values reasons = [] if failreason.astype(np.uint16) & self.FML_BADAMPLOW: reasons.append('Fit amplitude too low') if failreason.astype(np.uint16) & self.FML_BADAMPHIGH: reasons.append('Fit amplitude too high') if failreason.astype(np.uint16) & self.FML_BADSEARCHWINDOW: reasons.append('Bad search window') if failreason.astype(np.uint16) & self.FML_BADWIDTHLOW: reasons.append('Bad fit width - value too low') if failreason.astype(np.uint16) & self.FML_BADWIDTHHIGH: reasons.append('Bad fit width - value too high') if failreason.astype(np.uint16) & self.FML_BADLAG: reasons.append('Lag out of range') if failreason.astype(np.uint16) & self.FML_FITFAIL: reasons.append('Refinement failed') if failreason.astype(np.uint16) & self.FML_INITFAIL: reasons.append('Initialization failed') if len(reasons) > 0: return ', '.join(reasons) else: return 'No error' def fit(self, corrfunc): # check to make sure xcorr_x and xcorr_y match if self.corrtimeaxis is None: print("Correlation time axis is not defined - exiting") sys.exit() if len(self.corrtimeaxis) != len(corrfunc): print('Correlation time axis and values do not match in length (', len(self.corrtimeaxis), '!=', len(corrfunc), '- exiting') sys.exit() # set initial parameters # absmaxsigma is in seconds # maxsigma is in Hz # maxlag is in seconds warnings.filterwarnings("ignore", "Number*") failreason = np.uint(0) maskval = np.uint16(1) # start out assuming the fit will succeed binwidth = self.corrtimeaxis[1] - self.corrtimeaxis[0] # set the search range lowerlim = 0 upperlim = len(self.corrtimeaxis) - 1 if self.debug: print('initial search indices are', lowerlim, 'to', upperlim, '(', self.corrtimeaxis[lowerlim], self.corrtimeaxis[upperlim], ')') # make an initial guess at the fit parameters for the gaussian # start with finding the maximum value and its location flipfac = 1.0 if self.useguess: maxindex = tide_util.valtoindex(self.corrtimeaxis, self.maxguess) else: maxindex, flipfac = self._maxindex_noedge(corrfunc) corrfunc *= flipfac maxlag_init = (1.0 * self.corrtimeaxis[maxindex]).astype('float64') maxval_init = corrfunc[maxindex].astype('float64') if self.debug: print('maxindex, maxlag_init, maxval_init:', maxindex, maxlag_init, maxval_init) # then calculate the width of the peak thegrad = np.gradient(corrfunc).astype('float64') # the gradient of the correlation function peakpoints = np.where(corrfunc > self.searchfrac * maxval_init, 1, 0) # mask for places where correlaion exceeds serchfrac*maxval_init peakpoints[0] = 0 peakpoints[-1] = 0 peakstart = np.max([1, maxindex - 1]) peakend = np.min([len(self.corrtimeaxis) - 2, maxindex + 1]) while thegrad[peakend + 1] <= 0.0 and peakpoints[peakend + 1] == 1: peakend += 1 while thegrad[peakstart - 1] >= 0.0 and peakpoints[peakstart - 1] == 1: peakstart -= 1 # deal with flat peak top while peakend < (len(self.corrtimeaxis) - 3) and corrfunc[peakend] == corrfunc[peakend - 1]: peakend += 1 while peakstart > 2 and corrfunc[peakstart] == corrfunc[peakstart + 1]: peakstart -= 1 # This is calculated from first principles, but it's always big by a factor or ~1.4. # Which makes me think I dropped a factor if sqrt(2). So fix that with a final division maxsigma_init = np.float64( ((peakend - peakstart + 1) * binwidth / (2.0 * np.sqrt(-np.log(self.searchfrac)))) / np.sqrt(2.0)) if self.debug: print('maxsigma_init:', maxsigma_init) # now check the values for errors if self.hardlimit: rangeextension = 0.0 else: rangeextension = (self.lagmax - self.lagmin) * 0.75 if not ((self.lagmin - rangeextension - binwidth) <= maxlag_init <= (self.lagmax + rangeextension + binwidth)): failreason |= (self.FML_INITFAIL | self.FML_BADLAG) if maxlag_init <= (self.lagmin - rangeextension - binwidth): maxlag_init = self.lagmin - rangeextension - binwidth else: maxlag_init = self.lagmax + rangeextension + binwidth if self.debug: print('bad initial') if maxsigma_init > self.absmaxsigma: failreason |= (self.FML_INITFAIL | self.FML_BADWIDTHHIGH) maxsigma_init = self.absmaxsigma if self.debug: print('bad initial width - too high') if peakend - peakstart < 2: failreason |= (self.FML_INITFAIL | self.FML_BADSEARCHWINDOW) maxsigma_init = np.float64( ((2 + 1) * binwidth / (2.0 * np.sqrt(-np.log(self.searchfrac)))) / np.sqrt(2.0)) if self.debug: print('bad initial width - too low') if not (self.lthreshval <= maxval_init <= self.uthreshval) and self.enforcethresh: failreason |= (self.FML_INITFAIL | self.FML_BADAMPLOW) if self.debug: print('bad initial amp:', maxval_init, 'is less than', self.lthreshval) if (maxval_init < 0.0): failreason |= (self.FML_INITFAIL | self.FML_BADAMPLOW) maxval_init = 0.0 if self.debug: print('bad initial amp:', maxval_init, 'is less than 0.0') if (maxval_init > 1.0): failreason |= (self.FML_INITFAIL | self.FML_BADAMPHIGH) maxval_init = 1.0 if self.debug: print('bad initial amp:', maxval_init, 'is greater than 1.0') if failreason > 0 and self.zerooutbadfit: maxval = np.float64(0.0) maxlag = np.float64(0.0) maxsigma = np.float64(0.0) else: maxval = np.float64(maxval_init) maxlag = np.float64(maxlag_init) maxsigma = np.float64(maxsigma_init) # refine if necessary if self.refine: X = self.corrtimeaxis[peakstart:peakend + 1] data = corrfunc[peakstart:peakend + 1] '''if self.debug: print('peakstart, peakend', peakstart, peakend) #for i in range(len(data)): # print(X[i], data[i], thegrad[i], ) pl.figure() pl.plot(X, data, 'b') pl.plot(X,peakpoints[peakstart:peakend + 1], 'r') pl.plot(X, thegrad[peakstart:peakend + 1], 'g')''' if self.fastgauss: # do a non-iterative fit over the top of the peak # 6/12/2015 This is just broken. Gives quantized maxima maxlag = np.float64(1.0 * sum(X * data) / sum(data)) maxsigma = np.float64(np.sqrt(np.abs(np.sum((X - maxlag) ** 2 * data) / np.sum(data)))) maxval = np.float64(data.max()) else: # do a least squares fit over the top of the peak # p0 = np.array([maxval_init, np.fmod(maxlag_init, lagmod), maxsigma_init], dtype='float64') p0 = np.array([maxval_init, maxlag_init, maxsigma_init], dtype='float64') if self.debug: print('fit input array:', p0) try: plsq, dummy = sp.optimize.leastsq(tide_fit.gaussresiduals, p0, args=(data, X), maxfev=5000) maxval = plsq[0] maxlag = np.fmod((1.0 * plsq[1]), self.lagmod) maxsigma = plsq[2] except: maxval = np.float64(0.0) maxlag = np.float64(0.0) maxsigma = np.float64(0.0) if self.debug: print('fit output array:', [maxval, maxlag, maxsigma]) # check for errors in fit fitfail = False failreason = np.uint16(0) if self.bipolar: lowestcorrcoeff = -1.0 else: lowestcorrcoeff = 0.0 if maxval < lowestcorrcoeff: failreason |= (self.FML_FITFAIL + self.FML_BADAMPLOW) maxval = lowestcorrcoeff if self.debug: print('bad fit amp: maxval is lower than lower limit') fitfail = True if (np.abs(maxval) > 1.0): failreason |= (self.FML_FITFAIL | self.FML_BADAMPHIGH) maxval = 1.0 * np.sign(maxval) if self.debug: print('bad fit amp: magnitude of', maxval, 'is greater than 1.0') fitfail = True if (self.lagmin > maxlag) or (maxlag > self.lagmax): failreason |= (self.FML_FITFAIL + self.FML_BADLAG) if self.debug: print('bad lag after refinement') if self.lagmin > maxlag: maxlag = self.lagmin else: maxlag = self.lagmax fitfail = True if maxsigma > self.absmaxsigma: failreason |= (self.FML_FITFAIL + self.FML_BADWIDTHHIGH) if self.debug: print('bad width after refinement:', maxsigma, '>', self.absmaxsigma) maxsigma = self.absmaxsigma fitfail = True if maxsigma < self.absminsigma: failreason |= (self.FML_FITFAIL + self.FML_BADWIDTHLOW) if self.debug: print('bad width after refinement:', maxsigma, '<', self.absminsigma) maxsigma = self.absminsigma fitfail = True if fitfail: if self.debug: print('fit fail') if self.zerooutbadfit: maxval = np.float64(0.0) maxlag = np.float64(0.0) maxsigma = np.float64(0.0) maskval = np.int16(0) # print(maxlag_init, maxlag, maxval_init, maxval, maxsigma_init, maxsigma, maskval, failreason, fitfail) else: maxval = np.float64(maxval_init) maxlag = np.float64(np.fmod(maxlag_init, self.lagmod)) maxsigma = np.float64(maxsigma_init) if failreason > 0: maskval = np.uint16(0) if self.debug or self.displayplots: print("init to final: maxval", maxval_init, maxval, ", maxlag:", maxlag_init, maxlag, ", width:", maxsigma_init, maxsigma) if self.displayplots and self.refine and (maskval != 0.0): fig = pl.figure() ax = fig.add_subplot(111) ax.set_title('Data and fit') hiresx = np.arange(X[0], X[-1], (X[1] - X[0]) / 10.0) pl.plot(X, data, 'ro', hiresx, gauss_eval(hiresx, np.array([maxval, maxlag, maxsigma])), 'b-') pl.show() return maxindex, maxlag, flipfac * maxval, maxsigma, maskval, failreason, peakstart, peakend class freqtrack: freqs = None times = None def __init__(self, lowerlim=0.1, upperlim=0.6, nperseg=32, Q=10.0, debug=False): self.lowerlim = lowerlim self.upperlim = upperlim self.nperseg = nperseg self.Q = Q self.debug = debug self.nfft = self.nperseg def track(self, x, fs): self.freqs, self.times, thespectrogram = sp.signal.spectrogram(np.concatenate([np.zeros(int(self.nperseg // 2)), x, np.zeros(int(self.nperseg // 2))], axis=0), fs=fs, detrend='constant', scaling='spectrum', nfft=None, window=np.hamming(self.nfft), noverlap=(self.nperseg - 1)) lowerliminpts = tide_util.valtoindex(self.freqs, self.lowerlim) upperliminpts = tide_util.valtoindex(self.freqs, self.upperlim) if self.debug: print(self.times.shape, self.freqs.shape, thespectrogram.shape) print(self.times) # intitialize the peak fitter thefitter = correlation_fitter(corrtimeaxis=self.freqs, lagmin=self.lowerlim, lagmax=self.upperlim, absmaxsigma=10.0, absminsigma=0.1, debug=self.debug, findmaxtype='gauss', zerooutbadfit=False, refine=True, useguess=False, fastgauss=False ) peakfreqs = np.zeros((thespectrogram.shape[1] - 1), dtype=float) for i in range(0, thespectrogram.shape[1] - 1): maxindex, peakfreqs[i], maxval, maxsigma, maskval, failreason, peakstart, peakend = thefitter.fit(thespectrogram[:, i]) if not (lowerliminpts <= maxindex <= upperliminpts): peakfreqs[i] = -1.0 return self.times[:-1], peakfreqs def clean(self, x, fs, times, peakfreqs, numharmonics=2): nyquistfreq = 0.5 * fs y = x * 0.0 halfwidth = int(self.nperseg // 2) padx = np.concatenate([np.zeros(halfwidth), x, np.zeros(halfwidth)], axis=0) pady = np.concatenate([np.zeros(halfwidth), y, np.zeros(halfwidth)], axis=0) padweight = padx * 0.0 if self.debug: print(fs, len(times), len(peakfreqs)) for i in range(0, len(times)): centerindex = int(times[i] * fs) xstart = centerindex - halfwidth xend = centerindex + halfwidth if peakfreqs[i] > 0.0: filtsignal = padx[xstart:xend] numharmonics = np.min([numharmonics, int((nyquistfreq // peakfreqs[i]) - 1)]) if self.debug: print('numharmonics:', numharmonics, nyquistfreq // peakfreqs[i]) for j in range(numharmonics + 1): workingfreq = (j + 1) * peakfreqs[i] if self.debug: print('workingfreq:', workingfreq) ws = [workingfreq * 0.95, workingfreq * 1.05] wp = [workingfreq * 0.9, workingfreq * 1.1] gpass = 1.0 gstop = 40.0 b, a = sp.signal.iirdesign(wp, ws, gpass, gstop, ftype='cheby2', fs=fs) if self.debug: print(i, j, times[i], centerindex, halfwidth, xstart, xend, xend - xstart, wp, ws, len(a), len(b)) filtsignal = sp.signal.filtfilt(b, a, sp.signal.filtfilt(b, a, filtsignal)) pady[xstart:xend] += filtsignal else: pady[xstart:xend] += padx[xstart:xend] padweight[xstart:xend] += 1.0 return (pady / padweight)[halfwidth:-halfwidth]