from __future__ import absolute_import from __future__ import division from __future__ import print_function import multiprocessing import numpy as np import tensorflow as tf from tflib.data.dataset import batch_dataset, Dataset _N_CPU = multiprocessing.cpu_count() def memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, shuffle_buffer_size=None, repeat=-1): """Memory data batch dataset. `memory_data_dict` example: {'img': img_ndarray, 'label': label_ndarray} or {'img': img_tftensor, 'label': label_tftensor} * The value of each item of `memory_data_dict` is in shape of (N, ...). """ dataset = tf.data.Dataset.from_tensor_slices(memory_data_dict) dataset = batch_dataset(dataset, batch_size, prefetch_batch, drop_remainder, filter, map_func, num_threads, shuffle, shuffle_buffer_size, repeat) return dataset class MemoryData(Dataset): """MemoryData. `memory_data_dict` example: {'img': img_ndarray, 'label': label_ndarray} or {'img': img_tftensor, 'label': label_tftensor} * The value of each item of `memory_data_dict` is in shape of (N, ...). """ def __init__(self, memory_data_dict, batch_size, prefetch_batch=_N_CPU + 1, drop_remainder=True, filter=None, map_func=None, num_threads=_N_CPU, shuffle=True, shuffle_buffer_size=None, repeat=-1, sess=None): super(MemoryData, self).__init__() dataset = memory_data_batch_dataset(memory_data_dict, batch_size, prefetch_batch, drop_remainder, filter, map_func, num_threads, shuffle, shuffle_buffer_size, repeat) self._bulid(dataset, sess) if isinstance(list(memory_data_dict.values())[0], np.ndarray): self._n_data = len(list(memory_data_dict.values())[0]) else: self._n_data = list(memory_data_dict.values())[0].get_shape().as_list()[0] def __len__(self): return self._n_data if __name__ == '__main__': data = {'a': np.array([1.0, 2, 3, 4, 5]), 'b': np.array([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6]])} def filter(x): return tf.cond(x['a'] > 2, lambda: tf.constant(True), lambda: tf.constant(False)) def map_func(x): x['a'] = x['a'] * 10 return x # tf.enable_eager_execution() s = tf.Session() dataset = MemoryData(data, 2, filter=None, map_func=map_func, shuffle=True, shuffle_buffer_size=None, drop_remainder=True, repeat=4, sess=s) for i in range(5): print(map(dataset.get_next().__getitem__, ['b', 'a'])) print([n.name for n in tf.get_default_graph().as_graph_def().node])