# coding=utf-8 # Copyright 2018 Google LLC & Hwalsuk Lee. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Data-related utility functions. Includes: - A helper class to hold images and Inception features for evaluation. - A method to load a dataset as NumPy array. - Sample from the generator and return the data as a NumPy array. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os from absl import logging import numpy as np from six.moves import range import tensorflow as tf import tensorflow_gan as tfgan # Special value returned when fake image generated by GAN has nans. NAN_DETECTED = 31337.0 INCEPTION_URL = "http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz" INCEPTION_FROZEN_GRAPH = "inceptionv1_for_inception_score.pb" def get_inception_graph_def(): return tfgan.eval.get_graph_def_from_url_tarball( # pylint: disable=unreachable url=INCEPTION_URL, filename=INCEPTION_FROZEN_GRAPH, tar_filename=os.path.basename(INCEPTION_URL)) class NanFoundError(Exception): """Exception thrown, when the Nans are present in the output.""" class EvalDataSample(object): """Helper class to hold images and Inception features for evaluation. All properties are tensors. Images are in [0, 255]. """ def __init__(self, images): self.images = images self.activations = None self.logits = None def discard_images(self): logging.info("Deleting references to images: %s", self.images.shape) del self.images def set_inception_features(self, activations, logits): self.activations = activations self.logits = logits def set_num_examples(self, num_examples): if self.images is not None: assert self.images.shape[0] >= num_examples self.images = self.images[:num_examples] if self.activations is not None: assert self.activations.shape[0] >= num_examples self.activations = self.activations[:num_examples] if self.logits is not None: assert self.logits.shape[0] >= num_examples self.logits = self.logits[:num_examples] def get_real_images(dataset, num_examples, split=None, failure_on_insufficient_examples=True): """Get num_examples images from the given dataset/split. Args: dataset: `ImageDataset` object. num_examples: Number of images to read. split: Split of the dataset to use. If None will use the default split for eval defined by the dataset. failure_on_insufficient_examples: If True raise an exception if the dataset/split does not images. Otherwise will log to error and return fewer images. Returns: 4-D NumPy array with images with values in [0, 256]. Raises: ValueError: If the dataset/split does not of the number of requested number requested images and `failure_on_insufficient_examples` is True. """ logging.info("Start loading real data.") with tf.Graph().as_default(): ds = dataset.eval_input_fn(split=split) # Get real images from the dataset. In the case of a 1-channel # dataset (like MNIST) convert it to 3 channels. next_batch = ds.make_one_shot_iterator().get_next()[0] shape = [num_examples] + next_batch.shape.as_list() is_single_channel = shape[-1] == 1 if is_single_channel: shape[-1] = 3 real_images = np.empty(shape, dtype=np.float32) with tf.Session() as sess: for i in range(num_examples): try: b = sess.run(next_batch) b *= 255.0 if is_single_channel: b = np.tile(b, [1, 1, 3]) real_images[i] = b except tf.errors.OutOfRangeError: logging.error("Reached the end of dataset. Read: %d samples.", i) break if real_images.shape[0] != num_examples: if failure_on_insufficient_examples: raise ValueError("Not enough examples in the dataset %s: %d / %d" % (dataset, real_images.shape[0], num_examples)) else: logging.error("Not enough examples in the dataset %s: %d / %d", dataset, real_images.shape[0], num_examples) logging.info("Done loading real data.") return real_images def sample_fake_dataset(sess, generator, num_batches): """Returns a generated data set as a NumPy array.""" logging.info("Generating a fake data set.") samples = [] for _ in range(num_batches): x = sess.run(generator) # If NaNs were generated, ignore this checkpoint and assign a very high # FID score which we handle specially later. if np.isnan(x).any(): logging.error("Detected NaN in fake_images! Returning NaN.") raise NanFoundError("Detected NaN in fake images.") samples.append(x) fake_images = np.concatenate(samples, axis=0) fake_images *= 255.0 # Convert 1-channel datasets (like MNIST) to 3 channels. if fake_images.shape[3] == 1: fake_images = np.tile(fake_images, [1, 1, 1, 3]) logging.info("Done sampling a generated data set.") return fake_images def inception_transform(inputs): with tf.control_dependencies([ tf.assert_greater_equal(inputs, 0.0), tf.assert_less_equal(inputs, 255.0)]): inputs = tf.identity(inputs) preprocessed_inputs = tf.map_fn( fn=tfgan.eval.preprocess_image, elems=inputs, back_prop=False) return tfgan.eval.run_inception( preprocessed_inputs, graph_def=get_inception_graph_def(), output_tensor=["pool_3:0", "logits:0"]) def inception_transform_np(inputs, batch_size): """Computes the inception features and logits for a given NumPy array. The inputs are first preprocessed to match the input shape required for Inception. Args: inputs: NumPy array of shape [-1, H, W, 3]. batch_size: Batch size. Returns: A tuple of NumPy arrays with Inception features and logits for each input. """ with tf.Session(graph=tf.Graph()) as sess: inputs_placeholder = tf.placeholder( dtype=tf.float32, shape=[None] + list(inputs[0].shape)) features_and_logits = inception_transform(inputs_placeholder) features = [] logits = [] num_batches = int(np.ceil(inputs.shape[0] / batch_size)) for i in range(num_batches): input_batch = inputs[i * batch_size:(i + 1) * batch_size] x = sess.run( features_and_logits, feed_dict={inputs_placeholder: input_batch}) features.append(x[0]) logits.append(x[1]) features = np.vstack(features) logits = np.vstack(logits) return features, logits