from .loader import Loader
import tensorflow as tf
import threading
import numpy as np
import time
import glob
import os
import imageio
import cv2
import deepdish as dd

SAMPLES_PER_VIDEO = 1
SAMPLES_PER_FRAME = 1
FRAMES = 6


def pad(x, min_side):
    """Zero-pad `x` so both spatial dimensions are at least `min_side`.

    Padding is added at the bottom/right; the original content stays in the
    top-left corner. Returns `x` unchanged when it is already large enough.
    """
    if np.min(x.shape[:2]) >= min_side:
        return x
    sh = (max(min_side, x.shape[0]), max(min_side, x.shape[1])) + x.shape[2:]
    new_x = np.zeros(sh, dtype=x.dtype)
    new_x[:x.shape[0], :x.shape[1]] = x
    return new_x


def extract_optical_flow(fn, n_frames=34):
    """Load a 34-frame vertical "film strip" image and compute a flow-magnitude
    saliency map over a random window of `n_frames` consecutive frames.

    The file at `fn` must decode to a (128*34, 128, 3) image (34 stacked
    128x128 RGB frames); anything else yields an empty list. Farneback
    optical flow is accumulated pairwise over the window (with a randomized
    `winsize`), summed, and min-max normalized to [0, 1].

    Returns a list with a single (frame, norm_mag) pair, where `frame` is an
    RGB frame drawn at random from the window, or [] on shape mismatch.
    """
    img = dd.image.load(fn)
    if img.shape != (128 * 34, 128, 3):
        return []
    frames = np.array_split(img, 34, axis=0)
    grayscale_frames = [fr.mean(-1) for fr in frames]

    mags = []
    # Pick a random starting frame so the n_frames window fits in 34 frames.
    skip_frames = np.random.randint(34 - n_frames + 1)
    # NOTE(review): despite the name, this is a uniformly random frame from
    # the window, not necessarily its middle frame.
    middle_frame = frames[np.random.randint(skip_frames, skip_frames + n_frames)]
    im0 = grayscale_frames[skip_frames]
    for f in range(1 + skip_frames, 1 + skip_frames + n_frames - 1):
        im1 = grayscale_frames[f]
        flow = cv2.calcOpticalFlowFarneback(im0, im1,
                                            None,  # flow
                                            0.5,  # pyr_scale
                                            3,  # levels
                                            np.random.randint(3, 20),  # winsize
                                            3,  # iterations
                                            5,  # poly_n
                                            1.2,  # poly_sigma
                                            0  # flags
                                            )
        mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
        mags.append(mag)
        im0 = im1

    mag = np.sum(mags, 0)
    mag = mag.clip(min=0)  # magnitudes are non-negative already; kept for safety
    # Min-max normalize to [0, 1]; epsilon guards against a constant map.
    norm_mag = (mag - mag.min()) / (mag.max() - mag.min() + 1e-5)

    outputs = []
    outputs.append((middle_frame, norm_mag))
    return outputs


class VideoJPEGRollsFlowSaliency(Loader):
    """Multi-threaded TF1 queue loader producing (frame, flow-saliency) batches.

    Reads a list file of film-strip image paths, extracts optical-flow
    saliency maps via `extract_optical_flow`, and feeds fixed-size batches
    into a `tf.FIFOQueue` from `num_threads` background threads.
    """

    def __init__(self, path, root_path='', batch_size=16, input_size=227,
                 num_threads=10):
        """Build placeholders, the FIFO queue, and its enqueue/close ops.

        path: text file with one image path per line (relative to root_path).
        """
        self._path = path
        self._root_path = root_path
        with open(path) as f:
            self._list_files = [x.rstrip('\n') for x in f.readlines()]
        print('list_files', len(self._list_files))
        self._batch_size = batch_size
        self._input_size = input_size
        self._num_threads = num_threads
        self._coord = tf.train.Coordinator()

        self._base_shape = [batch_size, input_size, input_size]
        self._image_shape = self._base_shape + [3]
        self._label_shape = self._base_shape + [1]

        p_x = tf.placeholder(tf.float32, self._image_shape, name='x')
        p_y = tf.placeholder(tf.float32, self._label_shape, name='y')

        inputs = [p_x, p_y]
        self._queue = tf.FIFOQueue(400,
                                   [i.dtype for i in inputs],
                                   [i.get_shape() for i in inputs])
        self._inputs = inputs
        self._enqueue_op = self._queue.enqueue(inputs)
        self._queue_close_op = self._queue.close(cancel_pending_enqueues=True)
        self._threads = []

    def __feed(self, rank):
        """Generator yielding feed_dicts of one batch each, forever.

        Maintains a shuffled sample pool (~30 batches deep) of random-cropped,
        randomly-flipped (image, saliency) pairs drawn from random files.
        """
        # Stagger thread start-up so threads don't hammer the disk in lockstep.
        time.sleep(np.random.uniform(0, 3))
        batch_x = np.zeros(self._image_shape, dtype=np.float32)
        batch_y = np.zeros(self._label_shape, dtype=np.float32)

        pool = []
        N = len(self._list_files)
        input_size = self._input_size
        while True:
            while len(pool) < self._batch_size * 30:
                i = np.random.randint(N)
                fn = os.path.join(self._root_path, self._list_files[i])
                outputs = extract_optical_flow(fn, n_frames=FRAMES)
                for img, mag in outputs:
                    img0 = dd.image.resize(img, min_side=input_size)
                    mag0 = dd.image.resize(mag, min_side=input_size)

                    # Random input_size x input_size window.
                    # BUG FIX: h/w were computed but never applied, so the
                    # crop was silently skipped (and would crash on any
                    # non-square resized image).
                    h = np.random.randint(img0.shape[0] - input_size + 1)
                    w = np.random.randint(img0.shape[1] - input_size + 1)
                    img0 = img0[h:h + input_size, w:w + input_size]
                    mag0 = mag0[h:h + input_size, w:w + input_size]

                    # Random horizontal flip (applied jointly to both maps).
                    if np.random.randint(2) == 0:
                        ss = np.s_[:]
                    else:
                        ss = np.s_[:, ::-1]

                    pool.append((img0[ss], mag0[ss]))
                if len(pool) >= self._batch_size:
                    break

            # Draw a batch uniformly without replacement from the pool.
            for b in range(self._batch_size):
                i = np.random.randint(len(pool))
                img, mag = pool.pop(i)
                batch_x[b] = img
                batch_y[b, ..., 0] = mag

            yield {self._inputs[0]: batch_x, self._inputs[1]: batch_y}

    def __thread(self, session, rank):
        """Worker loop: enqueue batches until the coordinator requests a stop."""
        with self._coord.stop_on_exception():
            for feed_dict in self.__feed(rank):
                session.run(self._enqueue_op, feed_dict)

    def batch(self):
        """Dequeue one batch; returns (images, {'saliency': labels})."""
        x, y = self._queue.dequeue()
        return x, {'saliency': y}

    @property
    def batch_size(self):
        return self._batch_size

    def start(self, session):
        """Spawn `num_threads` daemon feeder threads for `session`."""
        for i in range(self._num_threads):
            t = threading.Thread(target=VideoJPEGRollsFlowSaliency.__thread,
                                 args=(self, session, i))
            t.daemon = True
            t.start()
            self._threads.append(t)

    def check_status(self):
        """Return True if any feeder thread has died (and report which)."""
        ret = False
        for i, t in enumerate(self._threads):
            if not t.is_alive():
                print(f'Thread #{i} has died')
                ret = True
        return ret

    def stop(self, session):
        """Request stop, close the queue (cancelling pending enqueues), join threads."""
        self._coord.request_stop()
        session.run(self._queue_close_op)
        self._coord.join(self._threads)