import tensorflow as tf def preprocess_batch(images_batch, preproc_func=None): """ Creates a preprocessing graph for a batch given a function that processes a single image. :param images_batch: A tensor for an image batch. :param preproc_func: (optional function) A function that takes in a tensor and returns a preprocessed input. """ if preproc_func is None: return images_batch with tf.variable_scope('preprocess'): images_list = tf.split(images_batch, int(images_batch.shape[0])) result_list = [] for img in images_list: reshaped_img = tf.reshape(img, img.shape[1:]) processed_img = preproc_func(reshaped_img) result_list.append(tf.expand_dims(processed_img, axis=0)) result_images = tf.concat(result_list, axis=0) return result_images