"""Cifar100 dataset preprocessing and specifications.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import matplotlib.pyplot as plt import numpy as np import os from six.moves import cPickle from six.moves import urllib import sys import tarfile import tensorflow as tf from common import dataset from common import misc_utils REMOTE_URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" LOCAL_DIR = "data/cifar100/" ARCHIVE_NAME = "cifar-100-python.tar.gz" DATA_DIR = "cifar-100-python/" TRAIN_BATCHES = ["train"] TEST_BATCHES = ["test"] IMAGE_SIZE = 32 NUM_CLASSES = 100 class Cifar100(dataset.AbstractDataset): def get_params(self): return { "image_size": IMAGE_SIZE, "num_classes": NUM_CLASSES, } def prepare(self, params): """Download the cifar 100 dataset.""" if not os.path.exists(LOCAL_DIR): os.makedirs(LOCAL_DIR) if not os.path.exists(LOCAL_DIR + ARCHIVE_NAME): print("Downloading...") urllib.request.urlretrieve(REMOTE_URL, LOCAL_DIR + ARCHIVE_NAME) if not os.path.exists(LOCAL_DIR + DATA_DIR): print("Extracting files...") tar = tarfile.open(LOCAL_DIR + ARCHIVE_NAME) tar.extractall(LOCAL_DIR) tar.close() def read(self, mode, params): """Create an instance of the dataset object.""" batches = { tf.estimator.ModeKeys.TRAIN: TRAIN_BATCHES, tf.estimator.ModeKeys.EVAL: TEST_BATCHES }[mode] all_images = [] all_labels = [] for batch in batches: with open("%s%s%s" % (LOCAL_DIR, DATA_DIR, batch), "rb") as fo: dict = cPickle.load(fo) images = np.array(dict["data"]) labels = np.array(dict["fine_labels"]) num = images.shape[0] images = np.reshape(images, [num, 3, IMAGE_SIZE, IMAGE_SIZE]) images = np.transpose(images, [0, 2, 3, 1]) print("Loaded %d examples." % num) all_images.append(images) all_labels.append(labels) all_images = np.concatenate(all_images) all_labels = np.concatenate(all_labels) return tf.data.Dataset.from_tensor_slices((all_images, all_labels)) def parse(self, mode, params, image, label): """Parse input record to features and labels.""" image = tf.cast(image, tf.float32) image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3]) if mode == tf.estimator.ModeKeys.TRAIN: image = tf.image.resize_image_with_crop_or_pad( image, IMAGE_SIZE + 4, IMAGE_SIZE + 4) image = tf.random_crop(image, [IMAGE_SIZE, IMAGE_SIZE, 3]) image = tf.image.random_flip_left_right(image) image = tf.image.per_image_standardization(image) return {"image": image}, {"label": label} dataset.DatasetFactory.register("cifar100", Cifar100) if __name__ == "__main__": if len(sys.argv) != 2: print("Usage: python dataset.cifar100 download") sys.exit(1) d = Cifar100() if sys.argv[1] == "download": d.prepare(misc_utils.Tuple(d.get_params())) else: print("Unknown command", sys.argv[1])