########################################################################################
#
# Forge
# Copyright (C) 2018 Adam R. Kosiorek, Oxford Robotics Institute and
# Department of Statistics, University of Oxford
#
# email: adamk@robots.ox.ac.uk
# webpage: http://akosiorek.github.io/
# github: https://github.com/akosiorek/forge/
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
########################################################################################
"""MNIST data config."""
from attrdict import AttrDict
import os

from tensorflow.examples.tutorials.mnist import input_data

from forge import flags
from forge.data import tensors_from_data


flags.DEFINE_string('data_folder', '../data/MNIST_data', 'Path to a data folder.')


def load(config, **unused_kwargs):
    """Build batched train/validation tensors for MNIST.

    The folder named by `config.data_folder` is created when it does not
    exist yet, and the dataset is then read from it with
    `input_data.read_data_sets`. The train and validation splits are each
    passed to `tensors_from_data` together with `config.batch_size`; only the
    train split is requested with `shuffle=True`.

    Args:
        config: object exposing `data_folder` and `batch_size` attributes.
        **unused_kwargs: accepted and discarded immediately.

    Returns:
        AttrDict with the entries `train_img`, `valid_img`, `train_label` and
        `valid_label`, taken from the 'imgs' and 'labels' entries returned by
        `tensors_from_data` for the respective split.
    """
    del unused_kwargs

    folder = config.data_folder
    if not os.path.exists(folder):
        os.makedirs(folder)

    # NOTE(review): read_data_sets presumably fetches the files into `folder`
    # when they are missing -- confirm against the TensorFlow tutorial loader.
    mnist = input_data.read_data_sets(folder)

    train_arrays = dict(imgs=mnist.train.images, labels=mnist.train.labels)
    valid_arrays = dict(imgs=mnist.validation.images, labels=mnist.validation.labels)

    # Train tensors are built before validation tensors, as in the original,
    # so any order-dependent behaviour inside tensors_from_data is unchanged.
    train = tensors_from_data(train_arrays, config.batch_size, shuffle=True)
    valid = tensors_from_data(valid_arrays, config.batch_size, shuffle=False)

    return AttrDict(
        train_img=train['imgs'],
        valid_img=valid['imgs'],
        train_label=train['labels'],
        valid_label=valid['labels'],
    )