# License: MIT
# Author: Karl Stelzner

import os

import numpy as np

# The packages below are only needed by some of the loaders.  They are
# imported defensively so that the pure-NumPy utilities in this module
# (noise, sprites, cached data sets) stay usable when they are missing; the
# functions that need them raise an ImportError with an explanation.
try:
    import scipy.misc
except ImportError:
    scipy = None

try:
    import imageio
except ImportError:
    imageio = None

try:
    from observations import mnist
except ImportError:
    mnist = None

# Side length (in pixels) of the square canvases produced by the sprite
# generator; it is also part of the sprite cache file name.
_SPRITE_CANVAS_SIZE = 50


def _imresize(image, size):
    """Resize `image` with the legacy `scipy.misc.imresize`.

    NOTE: `scipy.misc.imresize` was removed in SciPy 1.3.  No substitute is
    used here because a different interpolation routine would silently change
    the generated data sets; a clear error is raised instead.

    Args:
      image: np.ndarray. Image to resize.
      size: passed through unchanged to `scipy.misc.imresize` (this module
        uses a float scale factor and a 2-tuple output shape).

    Returns:
      The resized image as returned by `scipy.misc.imresize`.

    Raises:
      ImportError: if `scipy.misc.imresize` is not available.
    """
    resize = getattr(getattr(scipy, 'misc', None), 'imresize', None)
    if resize is None:
        raise ImportError(
            'scipy.misc.imresize is not available (it was removed in '
            'SciPy 1.3); install an older SciPy to generate this data set.')
    return resize(image, size)


def _to_object_array(items):
    """Pack a (possibly ragged) sequence into a 1-D object array.

    Passing a list of differently sized arrays straight to `np.savez*` relies
    on implicit ragged-array creation, which recent NumPy versions reject.
    """
    packed = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        packed[i] = item
    return packed


def add_noise(x):
    """Return a noisy copy of `x`, clipped to [0, 1].

    The input is first rescaled by `0.8 * x + 0.1`, then Gaussian noise with
    standard deviation 0.2 is added.  The input array is not modified.
    """
    x = 0.8 * x + 0.1
    x += np.random.normal(0.0, 0.20, size=x.shape)
    return np.clip(x, 0.0, 1.0)


def add_structured_noise(images):
    """Overlay a grid of lines with a 5-pixel period on every image, in place.

    For each image a random row offset and column offset in [0, 5) are drawn;
    every fifth row and every fifth column starting at those offsets is
    raised to at least 0.4.

    Args:
      images: np.ndarray indexed as (image, row, column, ...).  Modified in
        place.

    Returns:
      The same `images` array that was passed in.
    """
    n = int(images.shape[0])
    # Draw order (columns first, then rows) is kept so that results for a
    # given random seed do not change.
    x_offset = np.random.randint(0, 5, n)
    y_offset = np.random.randint(0, 5, n)
    for i in range(n):
        x, y = x_offset[i], y_offset[i]
        images[i, y::5] = np.maximum(images[i, y::5], 0.4)
        images[i, :, x::5] = np.maximum(images[i, :, x::5], 0.4)
    return images


def preprocess(data):
    """Return `data` as float32, divided by its maximum value.

    Empty input and input whose maximum is zero are returned unscaled
    instead of raising or producing NaNs.
    """
    data = data.astype(np.float32)
    if data.size == 0:
        return data
    peak = data.max()
    if peak != 0:
        data /= peak  # Squash to [0, 1]
    return data


def _sample_one(canvas_size, x_data, y_data):
    """Place one randomly chosen, randomly scaled digit on an empty canvas.

    Returns:
      Tuple `(positioned, label, bbox)`.  `bbox` is
      `(row_start, col_start, row_end, col_end)` in array index coordinates
      of the canvas.
    """
    i = np.random.randint(x_data.shape[0])
    digit = x_data[i]
    label = y_data[i]
    scale = 0.1 * np.random.randn() + 1.3
    resized = _imresize(digit, 1.0 / scale)
    # The digit is treated as square: the first dimension is used for both
    # axes.
    width = resized.shape[0]
    padding = canvas_size - width
    if padding <= 0:
        raise ValueError(
            'canvas_size ({}) must be larger than the resized digit '
            '({})'.format(canvas_size, width))
    pad_top = np.random.randint(0, padding)
    pad_left = np.random.randint(0, padding)
    pad_width = ((pad_top, padding - pad_top), (pad_left, padding - pad_left))
    positioned = np.pad(resized, pad_width, 'constant', constant_values=0)
    bbox = (pad_top, pad_left, pad_top + width, pad_left + width)
    return positioned, label, bbox


def _sample_multi(num_digits, canvas_size, x_data, y_data):
    """Sample a canvas with `num_digits` digits, retrying on overlap.

    Overlap detection is crude: a canvas is rejected only if the summed
    pixel intensities exceed 255 somewhere.  Retries are done in a loop
    rather than recursively so that many rejections cannot exhaust the
    recursion limit.
    """
    while True:
        canvas = np.zeros((canvas_size, canvas_size))
        labels = []
        bboxes = []
        for _ in range(num_digits):
            positioned_digit, label, bbox = _sample_one(
                canvas_size, x_data, y_data)
            canvas += positioned_digit
            labels.append(label)
            bboxes.append(bbox)
        if np.max(canvas) <= 255:
            return canvas, np.array(labels, dtype=np.uint8), bboxes


def _build_dataset(x_data, y_data, max_digits, canvas_size):
    """Build one multi-digit image per row of `x_data`.

    Returns:
      Tuple `(x, y, bboxes)`: `x` is a uint8 array of canvases, `y` a list
      with one label array per canvas, and `bboxes` a float array of shape
      `(num_images, max_digits, 4)` whose unused rows stay zero.
    """
    data_size = x_data.shape[0]
    data_num_digits = np.random.randint(max_digits + 1, size=data_size)
    x_data = np.reshape(x_data, [data_size, 28, 28])
    bboxes_arr = np.zeros((data_size, max_digits, 4))
    x = []
    y = []
    for i, num_digits in enumerate(data_num_digits):
        canvas, labels, bboxes = _sample_multi(
            num_digits, canvas_size, x_data, y_data)
        x.append(canvas)
        y.append(labels)
        for j, bbox in enumerate(bboxes):
            bboxes_arr[i, j] = bbox
    x = np.array(x, dtype=np.uint8)
    return x, y, bboxes_arr


def load_multi_mnist(path, max_digits=2, canvas_size=50, seed=42):
    """Load the multiple MNIST data set [@eslami2016attend].

    Code pulled from the observations library and customized to collect
    bounding box information.  Each generated image contains between 0 and
    `max_digits` randomly scaled and placed MNIST digits, each count being
    equally likely.  Canvases on which digits visibly overlap (summed
    intensity above 255) are rejected and resampled.

    Args:
      path: str. Path to directory which either stores the cache file or
        otherwise the file will be generated and stored there.  Filename is
        `'multi_mnist_{}_{}_{}.npz'.format(max_digits, canvas_size, seed)`.
      max_digits: int, optional. Maximum number of MNIST digits per image to
        generate if not cached.
      canvas_size: int, optional. Side length in pixels of the square
        generated images if not cached.
      seed: int, optional. Random seed to generate the data set from MNIST
        if not cached.  Note that this reseeds NumPy's global random state.

    Returns:
      Tuple `(x_train, y_train, bbox_train), (x_test, y_test, bbox_test)`.
      The x's are uint8 image arrays.  Each element in the y's is a
      np.ndarray of labels, one label for each digit in the image (a list
      when freshly generated, an object array when read from the cache).
      The bbox arrays have shape `(num_images, max_digits, 4)`; each row is
      `(row_start, col_start, row_end, col_end)`, zero for unused slots.

    Raises:
      ImportError: if the data set must be generated but the `observations`
        package or `scipy.misc.imresize` is unavailable.
    """
    path = os.path.expanduser(path)
    cache_filename = 'multi_mnist_{}_{}_{}.npz'.format(
        max_digits, canvas_size, seed)
    cache_path = os.path.join(path, cache_filename)

    if os.path.exists(cache_path):
        # NOTE: allow_pickle is needed for the ragged label arrays; only
        # load cache files from a trusted location.
        with np.load(cache_path, allow_pickle=True) as data:
            # The cache keys 'x_bbox' / 'y_bbox' hold the train / test
            # bounding boxes; the names are kept for cache compatibility.
            return ((data['x_train'], data['y_train'], data['x_bbox']),
                    (data['x_test'], data['y_test'], data['y_bbox']))

    if mnist is None:
        raise ImportError(
            "The 'observations' package is required to generate the "
            'multi-MNIST data set when no cache file is present.')

    np.random.seed(seed)
    (x_train, y_train), (x_test, y_test) = mnist(path)
    x_train, y_train, x_bbox = _build_dataset(
        x_train, y_train, max_digits, canvas_size)
    x_test, y_test, y_bbox = _build_dataset(
        x_test, y_test, max_digits, canvas_size)

    os.makedirs(path, exist_ok=True)
    with open(cache_path, 'wb') as f:
        np.savez_compressed(f,
                            x_train=x_train,
                            y_train=_to_object_array(y_train),
                            x_test=x_test,
                            y_test=_to_object_array(y_test),
                            x_bbox=x_bbox,
                            y_bbox=y_bbox)
    return (x_train, y_train, x_bbox), (x_test, y_test, y_bbox)


def load_mnist(canvas_size, max_digits=5, path='./data'):
    """Load multi-MNIST with per-image digit counts.

    Images are loaded through `load_multi_mnist` (seed fixed to 42), scaled
    with `preprocess` and given a trailing channel axis.

    Returns:
      Tuple `(x, counts, y, bbox), (x_test, counts_test, y_test, bbox_test)`
      where `counts` holds the number of labels of each image.
    """
    (x, y, bbox), (x_test, y_test, bbox_test) = load_multi_mnist(
        path, max_digits=max_digits, canvas_size=canvas_size, seed=42)
    x = preprocess(x)
    x_test = preprocess(x_test)
    counts = np.array([len(objs) for objs in y])
    counts_test = np.array([len(objs) for objs in y_test])
    x = np.expand_dims(x, -1)
    x_test = np.expand_dims(x_test, -1)
    return (x, counts, y, bbox), (x_test, counts_test, y_test, bbox_test)


def load_svhn(path):
    """Load the SVHN arrays stored as `.npy` files in directory `path`.

    Reads `svhn_data.npy`, `svhn_counts.npy`, `svhn_objects.npy` and
    `svhn_bg.npy`.

    Returns:
      Tuple `(data, counts, objects, bgs)`; `data` gets a trailing axis of
      size one appended.
    """
    def _load(filename):
        # NOTE: allow_pickle=True executes pickled content; only load files
        # from a trusted location.
        with open(os.path.join(path, filename), 'rb') as f:
            return np.load(f, allow_pickle=True)

    data = _load('svhn_data.npy')
    counts = _load('svhn_counts.npy')
    objects = _load('svhn_objects.npy')
    bgs = _load('svhn_bg.npy')
    data = np.expand_dims(data, -1)
    return data, counts, objects, bgs


def add_sprite(canvas):
    """Return `canvas` plus one random sprite that overlaps nothing on it.

    A sprite is a circle (channel 0), an axis-aligned square (channel 1) or
    a square turned by 45 degrees (channel 2).  Position, size and shape are
    resampled until the channel sum of the result is at most 1 everywhere.

    Args:
      canvas: np.ndarray of shape (50, 50, 3).  Not modified.

    Returns:
      A new array `canvas + sprite`.
    """
    size = _SPRITE_CANVAS_SIZE
    # Broadcastable row / column index grids used to rasterise the shapes.
    rows, cols = np.ogrid[:size, :size]
    while True:
        # `randint` excludes its upper bound, hence the `+ 1`s; the draws
        # match the inclusive ranges [0, 38], [12, <=16] and [0, 2].
        pos_x = np.random.randint(0, 38 + 1)
        pos_y = np.random.randint(0, 38 + 1)
        scale = np.random.randint(
            12, min(16, size - pos_x, size - pos_y) + 1)
        cat = np.random.randint(0, 2 + 1)

        sprite = np.zeros((size, size, 3))
        half = scale // 2.0
        center_x = pos_x + half
        center_y = pos_y + half
        if cat == 0:
            # draw circle
            inside = (rows - center_x)**2 + (cols - center_y)**2 < half**2
            sprite[:, :, cat][inside] = 1.0
        elif cat == 1:
            # draw square
            sprite[pos_x:pos_x + scale, pos_y:pos_y + scale, cat] = 1.0
        else:
            # draw square turned by 45 degrees
            inside = np.abs(rows - center_x) + np.abs(cols - center_y) < half
            sprite[:, :, cat][inside] = 1.0

        mod_canvas = canvas + sprite
        if np.all(np.sum(mod_canvas, axis=2) <= 1.):
            return mod_canvas


def make_sprites(n=50000, path='./data'):
    """Generate (or load from cache) `n` images with 0 to 2 sprites each.

    The first 100 images cycle deterministically through 0, 1 and 2 sprites;
    the remaining counts are drawn at random.  The first four fifths of the
    images form the training set, the rest the test set.  The result is
    cached in `path` as `'sprites_{}_{}.npz'.format(n, 50)`.

    Returns:
      Tuple `(x_train, count_train, None), (x_test, count_test, None)`.
    """
    path = os.path.expanduser(path)
    cache_filename = 'sprites_{}_{}.npz'.format(n, _SPRITE_CANVAS_SIZE)
    cache_path = os.path.join(path, cache_filename)

    if os.path.exists(cache_path):
        # NOTE: allow_pickle=True executes pickled content; only load cache
        # files from a trusted location.
        with np.load(cache_path, allow_pickle=True) as data:
            return ((data['x_train'], data['count_train'], None),
                    (data['x_test'], data['count_test'], None))

    images = np.zeros((n, _SPRITE_CANVAS_SIZE, _SPRITE_CANVAS_SIZE, 3))
    counts = np.zeros((n,))
    for i in range(n):
        if i < 100:
            num_sprites = i % 3
        else:
            num_sprites = np.random.randint(0, 2 + 1)
        counts[i] = num_sprites
        for _ in range(num_sprites):
            images[i] = add_sprite(images[i])
    # Clip in place; without `out=` the clipped copy would be thrown away.
    np.clip(images, 0.0, 1.0, out=images)

    split = 4 * n // 5
    x_train, count_train = images[:split], counts[:split]
    x_test, count_test = images[split:], counts[split:]

    os.makedirs(path, exist_ok=True)
    with open(cache_path, 'wb') as f:
        np.savez_compressed(f, x_train=x_train, count_train=count_train,
                            x_test=x_test, count_test=count_test)
    return (x_train, count_train, None), (x_test, count_test, None)


def load_omniglot(path):
    """Load every `.png` below `path` as a shuffled 50x50 float32 image set.

    Images are resized to 50x50, stacked, given a trailing channel axis,
    divided by 255 and shuffled after reseeding NumPy's global random state
    with 42.  The value range and the shape are printed.

    NOTE(review): files are collected in `os.walk` order, which is not
    sorted, so the shuffled order may differ between file systems.

    Returns:
      np.ndarray of shape `(num_images, 50, 50, 1)` and dtype float32.

    Raises:
      ImportError: if `imageio` or `scipy.misc.imresize` is unavailable.
      ValueError: if no `.png` file is found below `path`.
    """
    if imageio is None:
        raise ImportError(
            "The 'imageio' package is required to load Omniglot images.")

    images = []
    for dirname, _, filenames in os.walk(path):
        for filename in filenames:
            if len(filename) > 4 and filename.endswith('.png'):
                image = imageio.imread(os.path.join(dirname, filename))
                images.append(_imresize(image, (50, 50)))
    if not images:
        raise ValueError('No .png images found under {!r}'.format(path))

    images = np.stack(images, axis=0)
    images = np.expand_dims(images, -1)
    images = images.astype(np.float32) / 255.0
    np.random.seed(42)
    np.random.shuffle(images)
    print(np.min(images), np.max(images))
    print(images.shape)
    return images