import gzip
import hashlib
import io
import json
import os.path
import random
from operator import methodcaller
from typing import List

import numpy as np
from PIL import Image

from vergeml.img import INPUT_PATTERNS, open_image, fixext, ImageType
from vergeml.io import source, SourcePlugin, Sample
from vergeml.data import Labels
from vergeml.utils import VergeMLError
from vergeml.sources.labeled_image import LabeledImageSource

# The four gzipped IDX files expected in samples_dir, in the order
# (train images, train labels, test images, test labels).
_FILES = ("train-images-idx3-ubyte.gz",
          "train-labels-idx1-ubyte.gz",
          "t10k-images-idx3-ubyte.gz",
          "t10k-labels-idx1-ubyte.gz")

_MNIST_LABELS = ("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")

_FASHION_MNIST_LABELS = ("tshirt_top", "trouser", "pullover", "dress", "coat",
                         "sandal", "shirt", "sneaker", "bag", "ankle_boot")

# we use the md5 to check for fashion mnist, so we can provide the labels
# automatically
_MD5_FASHION = "8d4fb7e6c68d591d4c3dfef9ec88bf0d"

# Sizes of the IDX headers that precede the raw uint8 payload.
_IMAGE_HEADER_SIZE = 16  # magic_number, n_imgs, n_rows, n_cols
_LABEL_HEADER_SIZE = 8   # magic_number, n_labels

_TOO_MANY_VAL = ("number of validation samples is greater than "
                 "number of available samples.")
_TOO_MANY_TEST = ("number of test samples is greater than "
                  "number of available samples.")


def _md5(fname):
    """Return the hex MD5 digest of the file at `fname`, read in 4 KiB chunks."""
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def _read_idx(path, header_size):
    """Return the payload of a gzipped IDX file as a flat uint8 array.

    The first `header_size` bytes of the decompressed stream are skipped.
    """
    with gzip.open(path) as f:
        return np.frombuffer(f.read(), 'B', offset=header_size)


def _shuffled_indices(count, seed):
    """Return the indices 0..count-1 in a seeded, reproducible random order."""
    rng = random.Random(seed)
    return rng.sample(range(count), count)


# NOTE(review): this plugin is registered under the name 'image' -- confirm
# that this does not clash with another source plugin of the same name.
@source('image', descr="Load images in MNIST format.")
class InputMnist(SourcePlugin):
    """Source that preloads MNIST-format IDX files from `samples_dir`.

    After `_check_files` has run, `data` maps each split name ('train',
    'val', 'test') to a list of `(PIL image, one-hot label, metadata)` tuples.
    """

    # Filled by _check_files(); None until then.
    data = None

    def num_samples(self, split: str) -> int:
        """Return the number of preloaded samples in `split`."""
        return len(self.data[split])

    def read_sample(self, split: str, index: int):
        """Return the preloaded sample tuple at `index` of `split`."""
        return self.data[split][index]

    def _check_files(self):
        """Verify the dataset files exist, then preload and split all samples.

        Reads 'samples_dir', 'val_num', 'val_perc', 'test_num' and
        'random_seed' from `self.config`, and stores the label names in
        `self.meta['labels']` (Fashion-MNIST names when the training image
        file matches the known Fashion-MNIST MD5, digit names otherwise).

        Raises:
            VergeMLError: if one of the expected files is missing, or if the
                requested validation/test size exceeds the available samples.
        """
        self.data = dict(train=[], val=[], test=[])
        samples_dir = self.config["samples_dir"]
        files = [os.path.join(samples_dir, file) for file in _FILES]

        for path in files:
            if not os.path.exists(path):
                raise VergeMLError("File not found in samples_dir: {}".format(
                    os.path.basename(path)))

        if _md5(files[0]) == _MD5_FASHION:
            self.meta['labels'] = _FASHION_MNIST_LABELS
        else:
            self.meta['labels'] = _MNIST_LABELS

        # preload
        for split, images, labels in (('train', files[0], files[1]),
                                      ('test', files[2], files[3])):
            pixels = _read_idx(images, _IMAGE_HEADER_SIZE).reshape(-1, 28, 28)
            integer_labels = _read_idx(labels, _LABEL_HEADER_SIZE)

            # Convert to a Python int so the +1 is not done in uint8
            # arithmetic; an empty label file yields zero classes.
            n_cols = int(integer_labels.max()) + 1 if integer_labels.size else 0

            for ix, imagearr in enumerate(pixels):
                onehot = np.zeros(n_cols, dtype='float32')
                onehot[integer_labels[ix]] = 1.0
                meta = dict(labels=self.meta['labels'],
                            filename=images,
                            split=split,
                            types=('pil', 'labels'))
                self.data[split].append(
                    (Image.fromarray(imagearr), onehot, meta))

            if split == 'train':
                self._split_off_validation()
            else:
                self._limit_test_samples()

    def _split_off_validation(self):
        """Move a seeded random subset of the training samples into 'val'."""
        train = self.data['train']

        # val_perc, when set, takes precedence over val_num.
        n = self.config['val_num']
        if self.config['val_perc'] is not None:
            n = int(len(train) * self.config['val_perc'] // 100)

        if n is None:
            return
        if n > len(train):
            raise VergeMLError(_TOO_MANY_VAL)

        indices = _shuffled_indices(len(train), self.config['random_seed'])
        self.data['val'] = [train[i] for i in indices[:n]]
        self.data['train'] = [train[i] for i in indices[n:]]

    def _limit_test_samples(self):
        """Reduce 'test' to a seeded random subset of `test_num` samples."""
        test_num = self.config['test_num']
        if not test_num:
            return

        test = self.data['test']
        if test_num > len(test):
            raise VergeMLError(_TOO_MANY_TEST)

        indices = _shuffled_indices(len(test), self.config['random_seed'])
        # Slice with test_num itself, not the validation count.
        self.data['test'] = [test[i] for i in indices[:test_num]]


plugin = InputMnist