import gzip
import os
import pickle

import numpy as np


class MNIST:
    """MNIST dataset wrapper that serves random training mini-batches.

    Loads a gzip-compressed pickle containing ``(train, valid, test)``
    splits, each an ``(X, y)`` pair. The training labels are also stored
    one-hot encoded.
    """

    # Number of label classes used for the one-hot encoding of y_train.
    NUM_CLASSES = 10

    def __init__(self, batch_size, data_path=None):
        """Load the dataset and build one-hot training labels.

        Args:
            batch_size: Number of samples yielded per training batch.
            data_path: Optional path to the ``.pkl.gz`` file. Defaults to
                ``data/mnist.pkl.gz`` in the directory of this module.
        """
        self.batch_size = batch_size
        train, valid, test = self._load_data(data_path)

        self.X_train, self.y_train = train[0], train[1]
        # One-hot encode y_train: row i gets a 1 in column y_train[i].
        # Presumably labels are integers in [0, NUM_CLASSES) — confirm for
        # non-MNIST data files.
        num_train = self.y_train.shape[0]
        self.y_train_one_hot = np.zeros((num_train, self.NUM_CLASSES))
        self.y_train_one_hot[np.arange(num_train), self.y_train] = 1

        self.X_valid, self.y_valid = valid[0], valid[1]
        self.X_test, self.y_test = test[0], test[1]

    def train_batch_generator(self):
        """Yield random training batches forever.

        Each batch draws ``batch_size`` distinct row indices (no repeats
        within a batch) using NumPy's global random state. Raises
        ValueError if ``batch_size`` exceeds the number of training rows.

        Yields:
            Tuple ``(X_batch, y_batch_one_hot)``.
        """
        while True:
            indices = np.random.choice(
                self.X_train.shape[0], self.batch_size, replace=False
            )
            yield self.X_train[indices], self.y_train_one_hot[indices]

    def validation(self):
        """Return the validation split as ``(X_valid, y_valid)``."""
        return self.X_valid, self.y_valid

    def testing(self):
        """Return the test split as ``(X_test, y_test)``."""
        return self.X_test, self.y_test

    def num_features(self):
        """Return the number of columns of the training inputs."""
        return self.X_train.shape[1]

    def _load_data(self, path=None):
        """Read the ``(train, valid, test)`` splits from a gzipped pickle.

        Args:
            path: File to read; defaults to ``data/mnist.pkl.gz`` next to
                this module.

        Returns:
            Tuple ``(train, valid, test)`` as stored in the pickle.

        Raises:
            FileNotFoundError: If the data file does not exist.
        """
        if path is None:
            script_dir = os.path.dirname(__file__)
            path = os.path.join(script_dir, 'data', 'mnist.pkl.gz')
        # SECURITY: unpickling can execute arbitrary code — only load
        # trusted data files.
        with gzip.open(path, 'rb') as fh:
            # encoding='latin1' allows NumPy arrays pickled by Python 2
            # (presumably how this file was produced) to load in Python 3.
            train, val, test = pickle.load(fh, encoding='latin1')
        return train, val, test