# vim: tabstop=8 expandtab shiftwidth=4 softtabstop=4
import numpy as np
import cv2
import time


class optical_flow_track(object):
    """Track a fixed grid of points across frames with Lucas-Kanade flow.

    The first frame passed to :meth:`feed` seeds a regular grid of points.
    Every call then tracks the surviving points from the previous frame
    into the new one and draws them on the frame.  Each point keeps the
    random colour assigned to it at seeding time; that colour doubles as
    the point's identity in :meth:`save_final_state`.
    """

    def __init__(self):
        # Parameters for Lucas-Kanade optical flow.
        self.lk_params = dict(
            winSize=(15, 15),
            maxLevel=1,
            criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                      10, 0.08))
        # Pool of random colours, one row per tracked point.
        self.color = np.random.randint(0, 255, (2000, 3))
        self.old_gray = None      # single-channel image of the previous frame
        self.old_frame = None     # undrawn copy of the previous frame
        self.p0 = self.p1 = None  # p0: current points, shape (N, 1, 2)
        self.initial_state = None  # [seed points, seed colours]

    def feed(self, frame):
        """Track the points into ``frame`` and draw them on it.

        Args:
            frame: image array indexed as ``frame[row, col, channel]`` with
                at least three channels.  It is drawn on in place.

        Returns:
            The same ``frame`` with one filled circle per tracked point.
        """
        if self.old_gray is None:
            self.old_frame = frame.copy()
            # Channel index 2 is used as the grayscale input instead of a
            # colour conversion (the red channel for BGR frames such as
            # those returned by cv2.imread).
            self.old_gray = self.old_frame[:, :, 2]
            margx = 120
            margy = 30
            # Regular 30-pixel grid, kept away from the image borders.
            self.p0 = np.array(
                [(i, j)
                 for i in range(margx, frame.shape[1] - margx, 30)
                 for j in range(margy, frame.shape[0] - margy, 30)],
                dtype='float32').reshape(-1, 1, 2)
            # The colour table must have exactly one row per point, else
            # the boolean masks applied below would not line up.
            missing = len(self.p0) - len(self.color)
            if missing > 0:
                extra = np.random.randint(0, 255, (missing, 3))
                self.color = np.vstack([self.color, extra])
            self.color = self.color[:len(self.p0)]
            self.initial_state = [self.p0, self.color]

        frame_gray = frame[:, :, 2].copy()
        # Keep an undrawn copy for save_final_state().
        self.old_frame = frame.copy()

        if len(self.p0) == 0:
            # Nothing left to track; calcOpticalFlowPyrLK rejects an
            # empty point set, so only roll the state forward.
            self.old_gray = frame_gray
            return frame

        p1, st, _err = cv2.calcOpticalFlowPyrLK(
            self.old_gray, frame_gray, self.p0, None, **self.lk_params)

        if p1 is None or st is None:
            # Defensive: treat a missing result as "every point lost".
            good_new = np.empty((0, 2), dtype='float32')
            self.color = self.color[:0]
        else:
            # Select good points (status 1 means the flow was found).
            found = (st == 1).flatten()
            good_new = p1[found].reshape(-1, 2)
            self.color = self.color[found]

        # Draw the tracks.  Drawing functions need integer pixel
        # coordinates; the tracked positions are float32.
        for i, new in enumerate(good_new):
            a, b = new.ravel()
            frame = cv2.circle(frame, (int(a), int(b)), 2,
                               self.color[i].tolist(), -1)

        # Now update the previous frame and previous points.
        self.old_gray = frame_gray
        self.p0 = good_new.reshape(-1, 1, 2)
        return frame

    def save_final_state(self, save_path):
        """Write ``<save_path>/tracks.png`` summarising the tracking run.

        On a copy of the last fed frame, a line is drawn from each seed
        position to its final position for every point that is still
        tracked, and a dot is drawn at the seed position of every point
        that was lost.  A caption built from the last two components of
        ``save_path`` is drawn near the top.

        Args:
            save_path: directory path (``/``-separated) to write into.

        Returns:
            dict with key ``'init_ftr_cnt'``: the number of seed points.

        Raises:
            RuntimeError: if :meth:`feed` has not been called yet.
        """
        if self.initial_state is None:
            raise RuntimeError(
                'save_final_state() called before any frame was fed')

        # Colours identify points: map colour -> position.
        colors_dict_final = {}
        for pos, color in zip(self.p0, self.color):
            colors_dict_final[tuple(color)] = pos
        colors_dict_inital = {}
        for pos, color in zip(*self.initial_state):
            colors_dict_inital[tuple(color)] = pos

        last_frame = self.old_frame.copy()
        for k, start in colors_dict_inital.items():
            a, b = (int(v) for v in start.ravel())
            if k in colors_dict_final:
                c, d = (int(v) for v in colors_dict_final[k].ravel())
                last_frame = cv2.line(last_frame, (a, b), (c, d),
                                      (255, 0, 0), 2)
            else:
                last_frame = cv2.circle(last_frame, (a, b), 2,
                                        (0, 0, 255), -1)

        font = cv2.FONT_HERSHEY_COMPLEX_SMALL
        text = ' '.join(save_path.split('/')[-2:]).replace('_', ' ')
        textsize = cv2.getTextSize(text, font, 1, 1)[0]
        # Centre the caption horizontally.
        cv2.putText(last_frame, text,
                    (last_frame.shape[1] // 2 - textsize[0] // 2, 30),
                    font, 1, (0, 0, 255), 1, cv2.LINE_AA)
        cv2.imwrite(save_path + '/tracks.png', last_frame)

        ret = {}
        ret['init_ftr_cnt'] = len(self.initial_state[0])
        return ret


if __name__ == "__main__":
    frame = cv2.imread('/tmp/screenshot.png')
    if frame is None:
        # cv2.imread returns None instead of raising on a bad path.
        raise SystemExit('could not read /tmp/screenshot.png')
    of = optical_flow_track()
    out = of.feed(frame)
    cv2.imshow('cv window', out)
    cv2.waitKey(0)