""" Input pipeline (tf.dataset and input_fn) for GQN datasets. Adapted from the implementation provided here: https://github.com/deepmind/gqn-datasets/blob/acca9db6d9aa7cfa4c41ded45ccb96fecc9b272e/data_reader.py """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import os import tensorflow as tf # ---------- ad-hoc data structures ---------- DatasetInfo = collections.namedtuple( 'DatasetInfo', ['basepath', 'train_size', 'test_size', 'frame_size', 'sequence_size'] ) Context = collections.namedtuple('Context', ['frames', 'cameras']) Query = collections.namedtuple('Query', ['context', 'query_camera']) TaskData = collections.namedtuple('TaskData', ['query', 'target']) # ---------- dataset constants ---------- _DATASETS = dict( jaco=DatasetInfo( basepath='jaco', train_size=3600, test_size=400, frame_size=64, sequence_size=11), mazes=DatasetInfo( basepath='mazes', train_size=1080, test_size=120, frame_size=84, sequence_size=300), rooms_free_camera_with_object_rotations=DatasetInfo( basepath='rooms_free_camera_with_object_rotations', train_size=2034, test_size=226, frame_size=128, sequence_size=10), rooms_ring_camera=DatasetInfo( basepath='rooms_ring_camera', train_size=2160, test_size=240, frame_size=64, sequence_size=10), # super-small subset of rooms_ring for debugging purposes rooms_ring_camera_debug=DatasetInfo( basepath='rooms_ring_camera_debug', train_size=1, test_size=1, frame_size=64, sequence_size=10), rooms_free_camera_no_object_rotations=DatasetInfo( basepath='rooms_free_camera_no_object_rotations', train_size=2160, test_size=240, frame_size=64, sequence_size=10), shepard_metzler_5_parts=DatasetInfo( basepath='shepard_metzler_5_parts', train_size=900, test_size=100, frame_size=64, sequence_size=15), shepard_metzler_7_parts=DatasetInfo( basepath='shepard_metzler_7_parts', train_size=900, test_size=100, frame_size=64, sequence_size=15) ) _NUM_CHANNELS = 3 _NUM_RAW_CAMERA_PARAMS = 5 
# Names of the dataset splits as they appear in the on-disk folder layout.
_MODES = ('train', 'test')


# ---------- helper functions ----------

def _convert_frame_data(jpeg_data):
    """Decodes one JPEG-encoded string tensor into a float32 image tensor."""
    decoded_frames = tf.image.decode_jpeg(jpeg_data)
    return tf.image.convert_image_dtype(decoded_frames, dtype=tf.float32)


def _get_dataset_files(dataset_info, mode, root):
    """Generates the list of tfrecord paths for a given dataset version.

    Args:
      dataset_info: DatasetInfo (only basepath, train_size and test_size are
        read).
      mode: string; 'train' selects the training split, any other value
        selects the test split.
      root: string, path to the root folder of the data.

    Returns:
      List of paths '<root>/<basepath>/<mode>/<i>-of-<n>.tfrecord' for
      i = 1..n, where both numbers are zero-padded to the width of n.
    """
    basepath = dataset_info.basepath
    base = os.path.join(root, basepath, mode)
    if mode == 'train':
        num_files = dataset_info.train_size
    else:
        num_files = dataset_info.test_size

    # zero-pad both numbers to the number of digits of the file count
    length = len(str(num_files))
    template = '{:0%d}-of-{:0%d}.tfrecord' % (length, length)
    # indexing runs from 1 to n
    return [
        os.path.join(base, template.format(i, num_files))
        for i in range(1, num_files + 1)
    ]


def _get_randomized_indices(context_size, dataset_info, seed=None):
    """Generates randomized indices into a sequence of a specific length.

    Args:
      context_size: integer, number of context views.
      dataset_info: DatasetInfo of the dataset being read.
      seed: (optional) integer, op-level seed for the shuffle.

    Returns:
      1-D int32 tensor with context_size + 1 distinct indices drawn from
      [0, dataset_info.sequence_size).
    """
    example_size = context_size + 1  # context views + one target view
    indices = tf.range(0, dataset_info.sequence_size)
    indices = tf.random_shuffle(indices, seed=seed)
    indices = tf.slice(indices, begin=[0], size=[example_size])
    return indices


def _parse(raw_data, dataset_info):
    """Parses a single serialized example from the tfrecord.

    Returns:
      Dict with 'frames' (sequence_size JPEG strings) and 'cameras'
      (sequence_size * _NUM_RAW_CAMERA_PARAMS floats, flattened).
    """
    feature_map = {
        'frames': tf.FixedLenFeature(
            shape=dataset_info.sequence_size, dtype=tf.string),
        'cameras': tf.FixedLenFeature(
            shape=[dataset_info.sequence_size * _NUM_RAW_CAMERA_PARAMS],
            dtype=tf.float32),
    }
    return tf.parse_single_example(raw_data, feature_map)


def _preprocess(example, indices, context_size, custom_frame_size,
                dataset_info):
    """Preprocesses the parsed data of a single example.

    Args:
      example: dict as returned by _parse.
      indices: 1-D integer tensor with context_size + 1 view indices.
      context_size: integer, number of context views.
      custom_frame_size: integer or None; if set and different from the
        dataset frame size, frames are bilinearly resized to this size.
      dataset_info: DatasetInfo of the dataset being read.

    Returns:
      Dict with
        'frames':  float32 tensor [context_size + 1, H, W, _NUM_CHANNELS]
        'cameras': float32 tensor [context_size + 1, 7]
    """
    example_size = context_size + 1

    # frames: pick the selected views and decode them one by one
    frames = tf.concat(example['frames'], axis=0)
    frames = tf.gather(frames, indices, axis=0)
    frames = tf.map_fn(
        _convert_frame_data, tf.reshape(frames, [-1]),
        dtype=tf.float32, back_prop=False)
    dataset_image_dimensions = tuple(
        [dataset_info.frame_size] * 2 + [_NUM_CHANNELS])
    frames = tf.reshape(
        frames, (example_size,) + dataset_image_dimensions)

    if custom_frame_size and custom_frame_size != dataset_info.frame_size:
        new_frame_dimensions = (custom_frame_size,) * 2 + (_NUM_CHANNELS,)
        # frames is already the 4-D [example_size, H, W, C] tensor that
        # resize_bilinear expects, so it is resized directly.
        frames = tf.image.resize_bilinear(
            frames, new_frame_dimensions[:2], align_corners=True)
        # Keep the same rank as the non-resized path so that _prepare can
        # split context and target along axis 0.
        frames = tf.reshape(
            frames, (example_size,) + new_frame_dimensions)

    # cameras: un-flatten, pick the selected views, encode angles as sin/cos
    raw_pose_params = example['cameras']
    raw_pose_params = tf.reshape(
        raw_pose_params,
        [dataset_info.sequence_size, _NUM_RAW_CAMERA_PARAMS])
    raw_pose_params = tf.gather(raw_pose_params, indices, axis=0)
    pos = raw_pose_params[:, 0:3]
    yaw = raw_pose_params[:, 3:4]
    pitch = raw_pose_params[:, 4:5]
    cameras = tf.concat(
        [pos, tf.sin(yaw), tf.cos(yaw), tf.sin(pitch), tf.cos(pitch)],
        axis=-1)

    return {'frames': frames, 'cameras': cameras}


def _prepare(preprocessed_example):
    """Prepares the preprocessed data into a (features, labels) tuple.

    The last selected view becomes the query/target; all preceding views
    form the context.
    """
    frames = preprocessed_example['frames']
    cameras = preprocessed_example['cameras']

    # split data into context (all but last view) and target (last view)
    context_frames = frames[:-1]
    context_cameras = cameras[:-1]
    target = frames[-1]
    query_camera = cameras[-1]

    context = Context(cameras=context_cameras, frames=context_frames)
    query = Query(context=context, query_camera=query_camera)
    data = TaskData(query=query, target=target)
    return data, data.target


# ---------- input_fn ----------

def gqn_input_fn(
        dataset_name,
        root,
        mode,
        context_size,
        batch_size=1,
        num_epochs=1,
        # optionally reshape frames
        custom_frame_size=None,
        # queue params
        num_threads=4,
        buffer_size=256,
        seed=None):
    """
    Creates a tf.data.Dataset based op that returns data.

    Args:
      dataset_name: string, one of ['jaco', 'mazes', 'rooms_ring_camera',
        'rooms_ring_camera_debug', 'rooms_free_camera_no_object_rotations',
        'rooms_free_camera_with_object_rotations', 'shepard_metzler_5_parts',
        'shepard_metzler_7_parts'].
      root: string, path to the root folder of the data.
      mode: one of tf.estimator.ModeKeys. TRAIN reads the 'train' split,
        every other value reads the 'test' split.
      context_size: integer, number of views to be used to assemble the
        context.
      batch_size: (optional) batch size, defaults to 1.
      num_epochs: (optional) number of times to go through the dataset,
        defaults to 1.
      custom_frame_size: (optional) integer, required size of the returned
        frames, defaults to None.
      num_threads: (optional) integer, number of threads used to read and
        parse the record files, defaults to 4.
      buffer_size: (optional) integer, capacity of the underlying prefetch
        or shuffle buffer, defaults to 256.
      seed: (optional) integer, seed for the random number generators used
        in the dataset (view selection and example shuffling).

    Returns:
      tf.data.dataset yielding tuples of the form (features, labels)
      shapes:
        features.query.context.cameras: [N, K, 7]
        features.query.context.frames: [N, K, H, W, 3]
        features.query.query_camera: [N, 7]
        features.target (same as labels): [N, H, W, 3]

    Raises:
      ValueError: if the required version does not exist; if the requested
        context_size is bigger than the maximum supported for the given
        dataset version.
    """
    # map estimator mode key to dataset internal mode strings
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    str_mode = 'train' if is_training else 'test'

    # check validity of requested dataset and split
    if dataset_name not in _DATASETS:
        raise ValueError('Unrecognized dataset {} requested. Available datasets '
                         'are {}'.format(dataset_name, sorted(_DATASETS)))
    # Guard only: with the mapping above str_mode is always in _MODES.
    if str_mode not in _MODES:
        raise ValueError('Unsupported mode {} requested. Supported modes '
                         'are {}'.format(str_mode, _MODES))

    # retrieve dataset parameters
    dataset_info = _DATASETS[dataset_name]
    if context_size >= dataset_info.sequence_size:
        raise ValueError(
            'Maximum support context size for dataset {} is {}, but '
            'was {}.'.format(
                dataset_name, dataset_info.sequence_size-1, context_size))

    # collect the paths to all tfrecord files
    record_paths = _get_dataset_files(dataset_info, str_mode, root)

    # create TFRecordDataset
    dataset = tf.data.TFRecordDataset(
        filenames=record_paths, num_parallel_reads=num_threads)

    # parse the data from tfrecords
    dataset = dataset.map(
        lambda raw_data: _parse(raw_data, dataset_info),
        num_parallel_calls=num_threads)

    # preprocess into context and target
    def _sample_and_preprocess(example):
        # The indices are sampled inside the mapped function so that every
        # example draws its own random views; a tensor created outside of
        # map() would be captured once and reused for all examples.
        indices = _get_randomized_indices(context_size, dataset_info, seed)
        return _preprocess(
            example, indices, context_size, custom_frame_size, dataset_info)

    dataset = dataset.map(
        _sample_and_preprocess, num_parallel_calls=num_threads)

    # parse into tuple expected by tf.estimator input_fn
    dataset = dataset.map(_prepare, num_parallel_calls=num_threads)

    # shuffle data (training only)
    if is_training:
        dataset = dataset.shuffle(
            buffer_size=(buffer_size * batch_size), seed=seed)

    # set up batching
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size)
    return dataset