# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import scipy.misc as misc
import cv2
import matplotlib.pyplot as plt
from flowlib import read_flo, read_pfm
from data_augmentation import *
from utils import imshow


class BasicDataset(object):
    """Build ``tf.data`` input pipelines for pairs of consecutive frames.

    The data list file is a whitespace-separated text table. For every row,
    column 0 and column 1 are the file names of the first and second image
    (relative to ``img_dir``). The distillation pipeline additionally reads
    column 2, which is inserted into the file names
    ``flow_occ_fw_<col2>.png`` / ``flow_occ_bw_<col2>.png`` looked up in
    ``fake_flow_occ_dir``.
    """

    def __init__(self, crop_h=320, crop_w=896, batch_size=4,
                 data_list_file='path_to_your_data_list_file',
                 img_dir='path_to_your_image_directory',
                 fake_flow_occ_dir='path_to_your_fake_flow_occlusion_directory'):
        """Store the configuration and load the data list.

        Args:
            crop_h: Crop height handed to ``random_crop`` during augmentation.
            crop_w: Crop width handed to ``random_crop`` during augmentation.
            batch_size: Stored on the instance; the iterator factories take
                their batch size as an explicit argument.
            data_list_file: Path of the text file listing the samples.
            img_dir: Directory that the image names are joined onto.
            fake_flow_occ_dir: Directory holding the ``flow_occ_*`` PNG files
                used by the distillation pipeline.
        """
        self.crop_h = crop_h
        self.crop_w = crop_w
        self.batch_size = batch_size
        self.img_dir = img_dir
        # `np.str` was only an alias of the builtin `str` and has been removed
        # from recent NumPy releases. `ndmin=2` keeps a one-line list file as a
        # single row instead of collapsing it to a 1-D array of columns.
        self.data_list = np.loadtxt(data_list_file, dtype=str, ndmin=2)
        self.data_num = self.data_list.shape[0]
        self.fake_flow_occ_dir = fake_flow_occ_dir

    # KITTI's data format for storing flow and mask.
    # The first two channels are flow, the third channel is mask.
    def extract_flow_and_mask(self, flow):
        """Split a 3-channel flow image into a flow field and a binary mask.

        Args:
            flow: Rank-3 tensor whose last axis has at least three channels.

        Returns:
            ``(optical_flow, mask)``: the first two channels mapped through
            ``(value - 32768) / 64.0``, and a float32 mask with a trailing
            singleton axis that is 1.0 where the third channel is positive.
        """
        optical_flow = (flow[:, :, :2] - 32768) / 64.0
        mask = tf.cast(tf.greater(flow[:, :, 2], 0), tf.float32)
        mask = tf.expand_dims(mask, -1)
        return optical_flow, mask

    def _read_png_pair(self, filename_queue):
        """Decode the PNG images named in columns 0 and 1 as float32 tensors."""
        frames = []
        for column in (0, 1):
            path = tf.string_join([self.img_dir, '/', filename_queue[column]])
            frame = tf.image.decode_png(tf.read_file(path), channels=3)
            frames.append(tf.cast(frame, tf.float32))
        return frames[0], frames[1]

    def _read_flow_occ(self, prefix, sample_id):
        """Decode one 16-bit ``flow_occ`` PNG and split it into flow and mask."""
        path = tf.string_join([self.fake_flow_occ_dir, prefix, sample_id, '.png'])
        encoded = tf.image.decode_png(tf.read_file(path), dtype=tf.uint16, channels=3)
        return self.extract_flow_and_mask(tf.cast(encoded, tf.float32))

    # The default image type is PNG.
    def read_and_decode(self, filename_queue):
        """Read the image pair of one data-list row from PNG files.

        Args:
            filename_queue: String tensor holding one row of the data list.

        Returns:
            ``(img1, img2)`` as float32 tensors with 3 channels (values are
            not rescaled here).
        """
        return self._read_png_pair(filename_queue)

    # For Flying Chairs, the image type is ppm, please use "read_and_decode_ppm"
    # instead of "read_and_decode".
    # Similarly, for other image types, please write their decode functions by yourself.
    def read_and_decode_ppm(self, filename_queue, img_h=384, img_w=512):
        """Read the image pair of one data-list row from PPM files.

        Args:
            filename_queue: String tensor holding one row of the data list.
            img_h: Height the decoded images are reshaped to (default 384).
            img_w: Width the decoded images are reshaped to (default 512).

        Returns:
            ``(img1, img2)`` as float32 tensors of shape ``[img_h, img_w, 3]``.
        """
        def read_ppm(filename):
            # Called by tf.py_func with exactly one positional argument (the
            # file name), so this helper must not declare a `self` parameter.
            # NOTE(review): `scipy.misc.imread` only exists in old SciPy
            # releases - confirm the pinned SciPy version still provides it.
            return misc.imread(filename).astype('float32')

        img1_name = tf.string_join([self.img_dir, '/', filename_queue[0]])
        img2_name = tf.string_join([self.img_dir, '/', filename_queue[1]])
        img1 = tf.py_func(read_ppm, [img1_name], tf.float32)
        img2 = tf.py_func(read_ppm, [img2_name], tf.float32)
        # py_func outputs have no static shape, so it is restored explicitly.
        img1 = tf.reshape(img1, [img_h, img_w, 3])
        img2 = tf.reshape(img2, [img_h, img_w, 3])
        return img1, img2

    def read_and_decode_distillation(self, filename_queue):
        """Read an image pair plus its stored forward/backward flow and masks.

        Args:
            filename_queue: String tensor holding one row of the data list;
                column 2 identifies the ``flow_occ_fw_*`` / ``flow_occ_bw_*``
                files.

        Returns:
            ``(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)``.
        """
        img1, img2 = self._read_png_pair(filename_queue)
        flow_fw, occ_fw = self._read_flow_occ('/flow_occ_fw_', filename_queue[2])
        flow_bw, occ_bw = self._read_flow_occ('/flow_occ_bw_', filename_queue[2])
        return img1, img2, flow_fw, flow_bw, occ_fw, occ_bw

    def augmentation(self, img1, img2):
        """Apply random crop, random flip and random channel swap to the pair."""
        img1, img2 = random_crop([img1, img2], self.crop_h, self.crop_w)
        img1, img2 = random_flip([img1, img2])
        img1, img2 = random_channel_swap([img1, img2])
        return img1, img2

    def augmentation_distillation(self, img1, img2, flow_fw, flow_bw, occ_fw, occ_bw):
        """Augment images, flows and masks jointly.

        All six tensors go through the same ``random_crop`` call, the flip is
        done by ``random_flip_with_flow`` (flows are passed as a separate
        list), and the channel swap is applied to the two images only.
        """
        img1, img2, flow_fw, flow_bw, occ_fw, occ_bw = random_crop(
            [img1, img2, flow_fw, flow_bw, occ_fw, occ_bw], self.crop_h, self.crop_w)
        [img1, img2, occ_fw, occ_bw], [flow_fw, flow_bw] = random_flip_with_flow(
            [img1, img2, occ_fw, occ_bw], [flow_fw, flow_bw])
        img1, img2 = random_channel_swap([img1, img2])
        return img1, img2, flow_fw, flow_bw, occ_fw, occ_bw

    def preprocess_augmentation(self, filename_queue):
        """Load a PNG pair, divide by 255 and apply the training augmentation."""
        img1, img2 = self.read_and_decode(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return self.augmentation(img1, img2)

    def preprocess_augmentation_distillation(self, filename_queue):
        """Load a distillation sample, divide the images by 255 and augment it."""
        img1, img2, flow_fw, flow_bw, occ_fw, occ_bw = \
            self.read_and_decode_distillation(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return self.augmentation_distillation(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)

    def preprocess_one_shot(self, filename_queue):
        """Load a PNG pair and divide by 255, without cropping or augmentation."""
        img1, img2 = self.read_and_decode(filename_queue)
        img1 = img1 / 255.
        img2 = img2 / 255.
        return img1, img2

    def _make_batch_iterator(self, data_list, preprocess_fn, batch_size, shuffle,
                             buffer_size, num_parallel_calls):
        """Build the repeat-forever batched pipeline shared by the training iterators."""
        data_list = tf.convert_to_tensor(data_list, dtype=tf.string)
        dataset = tf.data.Dataset.from_tensor_slices(data_list)
        if shuffle:
            # Shuffle the file-name rows rather than the decoded samples: the
            # resulting order is just as random, but the shuffle buffer then
            # holds short strings instead of `buffer_size` full image tensors.
            dataset = dataset.shuffle(buffer_size=buffer_size)
        dataset = dataset.map(preprocess_fn, num_parallel_calls=num_parallel_calls)
        dataset = dataset.batch(batch_size)
        dataset = dataset.repeat()
        return dataset.make_initializable_iterator()

    def create_batch_iterator(self, data_list, batch_size, shuffle=True,
                              buffer_size=5000, num_parallel_calls=4):
        """Create an initializable iterator over augmented image-pair batches.

        Args:
            data_list: Rows of file names (convertible to a string tensor).
            batch_size: Number of samples per batch.
            shuffle: Whether to shuffle the samples.
            buffer_size: Size of the shuffle buffer.
            num_parallel_calls: Parallelism of the preprocessing ``map``.

        Returns:
            An initializable iterator; run ``iterator.initializer`` before use.
        """
        return self._make_batch_iterator(
            data_list, self.preprocess_augmentation, batch_size, shuffle,
            buffer_size, num_parallel_calls)

    def create_batch_distillation_iterator(self, data_list, batch_size, shuffle=True,
                                           buffer_size=5000, num_parallel_calls=4):
        """Create an initializable iterator over augmented distillation batches.

        Same arguments as ``create_batch_iterator``; every element is
        ``(img1, img2, flow_fw, flow_bw, occ_fw, occ_bw)``.
        """
        return self._make_batch_iterator(
            data_list, self.preprocess_augmentation_distillation, batch_size, shuffle,
            buffer_size, num_parallel_calls)

    def create_one_shot_iterator(self, data_list, num_parallel_calls=4):
        """ For Validation or Testing
            Generate image and flow one_by_one without cropping, image and flow size may change every iteration

            The dataset is batched with size 1 and repeated indefinitely. Despite
            the method name, an initializable iterator is returned, so
            ``iterator.initializer`` has to be run before use.
        """
        data_list = tf.convert_to_tensor(data_list, dtype=tf.string)
        dataset = tf.data.Dataset.from_tensor_slices(data_list)
        dataset = dataset.map(self.preprocess_one_shot, num_parallel_calls=num_parallel_calls)
        dataset = dataset.batch(1)
        dataset = dataset.repeat()
        iterator = dataset.make_initializable_iterator()
        return iterator