"""Generate a synthetic "captioned MNIST" dataset.

Each example is a 60x60 canvas holding one or two digits at randomly
jittered positions, paired with a 12-token caption describing where the
digit(s) sit (e.g. "the digit 3 is on the left of the digit 7 .").
"""
import datetime
from random import randint

import numpy as np
import scipy

try:
    import h5py
except ImportError:  # optional: only the __main__ demo reads HDF5 files
    h5py = None
try:
    import pylab
except ImportError:  # optional: only the __main__ demo plots
    pylab = None
try:
    # Not referenced by the code below. scipy.misc.toimage was removed in
    # SciPy 1.2, so keep the name but do not fail the whole module on it.
    from scipy.misc import toimage
except ImportError:
    toimage = None

# NOTE(review): this seeds NumPy's generator from NumPy's own unseeded state,
# so it does not make runs reproducible. The image/caption generators below
# draw from Python's ``random`` module, which this line does not affect.
np.random.seed(np.random.randint(1 << 30))

_IMAGE_SIZE = 60      # output canvas is _IMAGE_SIZE x _IMAGE_SIZE pixels
_DIGIT_SIZE = 28      # each input digit is reshaped to _DIGIT_SIZE x _DIGIT_SIZE
_CAPTION_LENGTH = 12  # every caption template below tokenizes to 12 words


def create_reverse_dictionary(dictionary):
    """Return the inverse of *dictionary*, mapping each index back to its word.

    If several words share an index, the word iterated last wins.
    """
    return {index: word for word, index in dictionary.items()}


dictionary = {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7,
              '8': 8, '9': 9, 'the': 10, 'digit': 11, 'is': 12, 'on': 13,
              'at': 14, 'left': 15, 'right': 16, 'bottom': 17, 'top': 18,
              'of': 19, 'image': 20, '.': 21}
reverse_dictionary = create_reverse_dictionary(dictionary)


def sent2matrix(sentence, dictionary):
    """Encode a whitespace-tokenized sentence as word indices.

    Returns an int32 array of shape (1, number_of_words). Raises KeyError
    if a word is missing from *dictionary*.
    """
    indices = [dictionary[word] for word in sentence.split()]
    # The explicit reshape keeps the (1, 0) shape for an empty sentence.
    return np.array(indices, dtype=np.int32).reshape(1, len(indices))


def matrix2sent(matrix, reverse_dictionary):
    """Decode a sequence of word indices back into a sentence string.

    *matrix* may be 1-D or the (1, n) array returned by ``sent2matrix``;
    it is flattened before decoding. Every word is preceded by one space,
    so a non-empty result starts with a leading space.
    """
    return "".join(" " + reverse_dictionary[index] for index in np.ravel(matrix))


def _as_digit(digit):
    """Reshape a flattened digit (784 values) to a 28x28 array."""
    return digit.reshape(_DIGIT_SIZE, _DIGIT_SIZE)


def _compose_image(placements):
    """Paste digits onto a blank canvas and return it flattened.

    *placements* is a sequence of ``(digit, row, col)``, where ``(row, col)``
    is the top-left corner of a 28x28 digit. Later digits overwrite earlier
    ones where they overlap. Returns a float array of shape (3600,).
    """
    image = np.zeros((_IMAGE_SIZE, _IMAGE_SIZE))
    for digit, row, col in placements:
        image[row:row + _DIGIT_SIZE, col:col + _DIGIT_SIZE] = digit
    return image.reshape(-1)


def _create_1digit_image(digit1, row_bounds, col_bounds):
    """Place one digit with its corner drawn uniformly from the inclusive bounds.

    The row is drawn before the column, matching the original draw order.
    """
    digit1 = _as_digit(digit1)
    row = randint(*row_bounds)
    col = randint(*col_bounds)
    return _compose_image([(digit1, row, col)])


def create_2digit_mnist_image_leftright(digit1, digit2):
    """Place *digit1* on the left and *digit2* on the right of a 60x60 canvas.

    Each digit is an array of 28*28 values. Both share one jittered row.
    The column ranges allow up to 4 pixels of horizontal overlap, where
    *digit2* wins. Returns the canvas flattened to shape (3600,).
    """
    digit1 = _as_digit(digit1)
    digit2 = _as_digit(digit2)
    # Draw order (row, left col, right col) is fixed so seeded runs reproduce.
    row = randint(16, 18)
    left_col = randint(0, 4)
    right_col = randint(28, 32)
    return _compose_image([(digit1, row, left_col), (digit2, row, right_col)])


def create_2digit_mnist_image_topbottom(digit1, digit2):
    """Place *digit1* at the top and *digit2* at the bottom of a 60x60 canvas.

    Each digit is an array of 28*28 values. Both share one jittered column,
    and the row ranges never overlap. Returns the canvas flattened to
    shape (3600,).
    """
    digit1 = _as_digit(digit1)
    digit2 = _as_digit(digit2)
    # Draw order (col, top row, bottom row) is fixed so seeded runs reproduce.
    col = randint(16, 18)
    top_row = randint(0, 2)
    bottom_row = randint(30, 32)
    return _compose_image([(digit1, top_row, col), (digit2, bottom_row, col)])


def create_1digit_mnist_image_topleft(digit1):
    """Place a single 28*28-value digit near the top-left corner; returns shape (3600,)."""
    return _create_1digit_image(digit1, (0, 2), (0, 4))


def create_1digit_mnist_image_topright(digit1):
    """Place a single 28*28-value digit near the top-right corner; returns shape (3600,)."""
    return _create_1digit_image(digit1, (0, 2), (28, 32))


def create_1digit_mnist_image_bottomright(digit1):
    """Place a single 28*28-value digit near the bottom-right corner; returns shape (3600,)."""
    return _create_1digit_image(digit1, (30, 32), (28, 32))


def create_1digit_mnist_image_bottomleft(digit1):
    """Place a single 28*28-value digit near the bottom-left corner; returns shape (3600,)."""
    return _create_1digit_image(digit1, (30, 32), (0, 4))


# One entry per caption case k (the index into this tuple), as
# (caption template, swap, image builder).
# Cases 0-3 are two-digit cases: the builder receives (data[i], data[j]) and
# puts data[i] left/top. ``swap`` means the caption names data[j] first.
# Cases 4-7 are single-digit cases: the builder receives data[i] only.
_CAPTION_CASES = (
    ('the digit %d is on the left of the digit %d .', False,
     create_2digit_mnist_image_leftright),
    ('the digit %d is on the right of the digit %d .', True,
     create_2digit_mnist_image_leftright),
    ('the digit %d is at the top of the digit %d .', False,
     create_2digit_mnist_image_topbottom),
    ('the digit %d is at the bottom of the digit %d .', True,
     create_2digit_mnist_image_topbottom),
    ('the digit %d is at the top left of the image .', False,
     create_1digit_mnist_image_topleft),
    ('the digit %d is at the bottom right of the image .', False,
     create_1digit_mnist_image_bottomright),
    ('the digit %d is at the top right of the image .', False,
     create_1digit_mnist_image_topright),
    ('the digit %d is at the bottom left of the image .', False,
     create_1digit_mnist_image_bottomleft),
)


def create_mnist_captions_dataset(data, labels, banned, num=10000):
    """Build *num* random (image, caption) pairs from individual digits.

    Args:
        data: 2-D array whose rows (``data[i, :]``) are flattened 28x28 digits.
        labels: label of each row of *data*. Labels are formatted with ``%d``
            and looked up in ``dictionary``, so they must be integers 0-9.
        banned: held-out label combinations, with at least 12 entries.
            Two-digit cases (k = 0..3) are rejected when the label of the
            left/top digit equals ``banned[2*k]`` or the label of the
            right/bottom digit equals ``banned[2*k + 1]``. Single-digit cases
            (k = 4..7) are rejected when the label equals ``banned[k + 4]``.
            Rejected samples are redrawn.
        num: number of examples to generate.

    Returns:
        ``(images, captions, counts)``:
        - images: float32 array of shape (num, 3600).
        - captions: int32 word-index array of shape (num, 12).
        - counts: list with the number of examples produced for each of the
          8 caption cases.
    """
    images = np.zeros((num, _IMAGE_SIZE * _IMAGE_SIZE))
    captions = np.zeros((num, _CAPTION_LENGTH))
    counts = [0] * len(_CAPTION_CASES)
    last_index = data.shape[0] - 1
    curr_num = 0
    # Loop on the target count so num=0 returns empty arrays instead of
    # writing one sample into a zero-length array.
    while curr_num < num:
        # Draw order (case, i, j) is fixed so seeded runs reproduce; j is
        # drawn even for the single-digit cases.
        k = randint(0, len(_CAPTION_CASES) - 1)
        i = randint(0, last_index)
        j = randint(0, last_index)
        template, swap, make_image = _CAPTION_CASES[k]
        if k <= 3:
            if labels[i] == banned[2 * k] or labels[j] == banned[2 * k + 1]:
                continue
            named = (labels[j], labels[i]) if swap else (labels[i], labels[j])
            sentence = template % named
            image = make_image(data[i, :], data[j, :])
        else:
            if labels[i] == banned[k + 4]:
                continue
            sentence = template % (labels[i],)
            image = make_image(data[i, :])
        counts[k] += 1
        captions[curr_num, :] = sent2matrix(sentence, dictionary)
        images[curr_num, :] = image
        curr_num += 1
    return np.float32(images), np.int32(captions), counts


if __name__ == '__main__':
    if h5py is None or pylab is None:
        raise SystemExit('This demo requires h5py and matplotlib (pylab).')
    mnist_path = '/ais/gobi3/u/nitish/mnist/mnist.h5'
    # Open the file once and close it as soon as the arrays are copied out.
    with h5py.File(mnist_path, 'r') as mnist:
        data = np.copy(mnist["train"])
        labels = np.copy(mnist["train_labels"])
    image = create_1digit_mnist_image_topright(data[327, :])
    pylab.figure()
    pylab.gray()
    pylab.imshow(image.reshape((60, 60)), interpolation='nearest')
    pylab.show(block=True)