__author__ = 'charlie' import numpy as np import os, sys, inspect import random from six.moves import cPickle as pickle from tensorflow.python.platform import gfile import glob utils_path = os.path.abspath( os.path.realpath(os.path.join(os.path.split(inspect.getfile(inspect.currentframe()))[0], ".."))) if utils_path not in sys.path: sys.path.insert(0, utils_path) import TensorflowUtils as utils DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz' def read_dataset(data_dir): pickle_filename = "flowers_data.pickle" pickle_filepath = os.path.join(data_dir, pickle_filename) if not os.path.exists(pickle_filepath): utils.maybe_download_and_extract(data_dir, DATA_URL, is_tarfile=True) flower_folder = os.path.splitext(DATA_URL.split("/")[-1])[0] result = create_image_lists(os.path.join(data_dir, flower_folder)) print "Training set: %d" % len(result['train']) print "Test set: %d" % len(result['test']) print "Validation set: %d" % len(result['validation']) print "Pickling ..." with open(pickle_filepath, 'wb') as f: pickle.dump(result, f, pickle.HIGHEST_PROTOCOL) else: print "Found pickle file!" with open(pickle_filepath, 'rb') as f: result = pickle.load(f) training_images = result['train'] testing_images = result['test'] validation_images = result['validation'] del result print ("Training: %d, Validation: %d, Test: %d" % ( len(training_images), len(validation_images), len(testing_images))) return training_images, testing_images, validation_images def create_image_lists(image_dir, testing_percentage=0.0, validation_percentage=0.0): """ Code modified from tensorflow/tensorflow/examples/image_retraining """ if not gfile.Exists(image_dir): print("Image directory '" + image_dir + "' not found.") return None training_images = [] sub_dirs = [x[0] for x in os.walk(image_dir)] # The root directory comes first, so skip it. is_root_dir = True for sub_dir in sub_dirs: if is_root_dir: is_root_dir = False continue extensions = ['jpg', 'jpeg', 'JPG', 'JPEG'] file_list = [] dir_name = os.path.basename(sub_dir) if dir_name == image_dir: continue print("Looking for images in '" + dir_name + "'") for extension in extensions: file_glob = os.path.join(image_dir, dir_name, '*.' + extension) file_list.extend(glob.glob(file_glob)) if not file_list: print('No files found') continue print "No. of files found: %d" % len(file_list) training_images.extend([f for f in file_list]) random.shuffle(training_images) no_of_images = len(training_images) validation_offset = int(validation_percentage * no_of_images) validation_images = training_images[:validation_offset] test_offset = int(testing_percentage * no_of_images) testing_images = training_images[validation_offset:validation_offset + test_offset] training_images = training_images[validation_offset + test_offset:] result = { 'train': training_images, 'test': testing_images, 'validation': validation_images, } return result