# Copyright (C) 2016 The TensorFlow Authors.
# Copyright (C) 2019 Alibaba Group Holding Limited.
# All Rights Reserved.
# ==============================================================================
"""A factory-pattern map that returns a classification dataset iterator."""

from __future__ import division
from __future__ import print_function

import tensorflow as tf

import cifar10
import flowers
import mnist
from flags import FLAGS
from utils import dataset_utils

# Maps a dataset name to the module whose `parse_fn` decodes its examples.
datasets_map = {
    'cifar10': cifar10,
    'flowers': flowers,
    'mnist': mnist,
}


def get_dataset_iterator(dataset_name, train_image_size, preprocessing_fn=None,
                         data_sources=None, reader=None):
    """Build an iterator over (image, label) pairs for a named dataset.

    Args:
        dataset_name: Key into `datasets_map`, or the special value 'mock'
            to get the iterator from `dataset_utils._create_mock_iterator`.
        train_image_size: Passed to the mock iterator, and passed as both
            size arguments of `preprocessing_fn`.
        preprocessing_fn: Optional callable invoked as
            `preprocessing_fn(image, train_image_size, train_image_size)`
            on every parsed image.
        data_sources: Forwarded unchanged to
            `dataset_utils._create_dataset_iterator`.
        reader: Forwarded unchanged to
            `dataset_utils._create_dataset_iterator`.

    Returns:
        The value returned by `dataset_utils._create_mock_iterator` (for
        'mock') or by `dataset_utils._create_dataset_iterator` otherwise.

    Raises:
        ValueError: If `dataset_name` is empty/None or is not a key of
            `datasets_map`.
    """
    # Ops created while building the input pipeline are pinned to the CPU.
    with tf.device("/cpu:0"):
        if not dataset_name:
            raise ValueError('expect dataset_name not None.')
        if dataset_name == 'mock':
            return dataset_utils._create_mock_iterator(train_image_size)
        if dataset_name not in datasets_map:
            # Report the *dataset* name that failed lookup.
            raise ValueError('Name of dataset unknown %s' % dataset_name)

        # Resolve the dataset module once rather than on every parse call.
        dataset_module = datasets_map[dataset_name]

        def parse_fn(example):
            """Parse one example into an (image, label) pair."""
            # parse_fn may be traced outside the outer device scope, so pin
            # its ops to the CPU explicitly as well.
            with tf.device("/cpu:0"):
                image, label = dataset_module.parse_fn(example)
                if preprocessing_fn is not None:
                    image = preprocessing_fn(image, train_image_size,
                                             train_image_size)
                if FLAGS.use_fp16:
                    image = tf.cast(image, tf.float16)
                # Shift labels by the configured offset (presumably to drop
                # a background class — confirm against the flag definition).
                label -= FLAGS.labels_offset
                return image, label

        return dataset_utils._create_dataset_iterator(data_sources, parse_fn,
                                                      reader)