#!/usr/bin/env python """Preparing the data.""" # pylint: disable=invalid-name, no-member from __future__ import print_function import os as _os import logging as _logging import cv2 as _cv2 import numpy as _np import click as _click import progressbar as _progressbar from sklearn.datasets import fetch_mldata as _fetch_mldata _LOGGER = _logging.getLogger(__name__) _DATA_FOLDER = _os.path.join(_os.path.dirname(__file__), 'data') if not _os.path.exists(_DATA_FOLDER): _LOGGER.info("Data folder not found. Creating...") _os.mkdir(_DATA_FOLDER) def training_data(): """Get the `MNIST original` training data.""" _np.random.seed(1) permutation = _np.random.permutation(range(60000)) mnist = _fetch_mldata('MNIST original', data_home=_os.path.join(_DATA_FOLDER, 'MNIST_original')) return (mnist.data[:60000, :][permutation, :].reshape((60000, 1, 28, 28)).astype('float32'), mnist.target[:60000][permutation].reshape((60000, 1)).astype('float32')) def test_data(): """Get the `MNIST original` test data.""" mnist = _fetch_mldata('MNIST original', data_home=_os.path.join(_DATA_FOLDER, 'MNIST_original')) return (mnist.data[60000:, :].reshape((10000, 1, 28, 28)).astype('float32'), mnist.target[60000:].reshape((10000, 1)).astype('float32')) @_click.group() def _cli(): """Handle the experiment data.""" pass @_cli.command() def validate_storage(): """Validate the data.""" _LOGGER.info("Validating storage...") val_folder = _os.path.join(_DATA_FOLDER, 'images') _LOGGER.info("Writing images to %s.", val_folder) if not _os.path.exists(val_folder): _os.mkdir(val_folder) _LOGGER.info("Train...") tr_folder = _os.path.join(val_folder, 'train') if not _os.path.exists(tr_folder): _os.mkdir(tr_folder) tr_data, tr_labels = training_data() pbar = _progressbar.ProgressBar(maxval=60000 - 1, widgets=[_progressbar.Percentage(), _progressbar.Bar(), _progressbar.ETA()]) pbar.start() for idx in range(60000): _cv2.imwrite(_os.path.join(tr_folder, '%05d_%d.jpg' % (idx, int(tr_labels[idx, 0]))), tr_data[idx, 0]) pbar.update(idx) pbar.finish() _LOGGER.info("Test...") te_folder = _os.path.join(val_folder, 'test') if not _os.path.exists(te_folder): _os.mkdir(te_folder) te_data, te_labels = test_data() pbar = _progressbar.ProgressBar(maxval=10000 - 1, widgets=[_progressbar.Percentage(), _progressbar.Bar(), _progressbar.ETA()]) pbar.start() for idx in range(10000): _cv2.imwrite(_os.path.join(te_folder, '%05d_%d.jpg' % (idx, int(te_labels[idx, 0]))), te_data[idx, 0]) pbar.update(idx) pbar.finish() if __name__ == '__main__': _logging.basicConfig(level=_logging.INFO) _cli()