"""Prepare the Flickr8k data for caption training.

Running this script:

* reads the train/test image lists and the caption token file from
  ``DATA_PATH + 'Flickr8k_text/'``,
* writes tab-separated ``image_id<TAB><start> caption <end>`` files for the
  train, validation and test splits into the same directory,
* runs every listed image through ResNet50 and pickles the resulting
  ``{image name: feature vector}`` dict to ``DATA_PATH + 'encoded_images.dat'``.

NOTE: this is a Python 2 script (it imports ``cPickle`` and treats the
contents of files opened in binary mode as ``str``).
"""
import cPickle as pickle
from time import time

import numpy as np
from keras.applications import ResNet50
from keras.applications import imagenet_utils
from keras.preprocessing import image
from tqdm import tqdm

# Number of images passed through get_encoding() so far.
counter = 0
DATA_PATH = "/data/vision/fisher/data1/Flickr8k/"


def load_image(path):
    """Load the image at `path` as a preprocessed, single-item batch.

    The image is resized to 224x224, converted to an array, given a leading
    batch axis and passed through ``imagenet_utils.preprocess_input``.
    """
    img = image.load_img(path, target_size=(224, 224))
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = imagenet_utils.preprocess_input(x)
    return np.asarray(x)


def load_encoding_model():
    """Return ResNet50 with ImageNet weights and no classification top."""
    model = ResNet50(weights='imagenet', include_top=False,
                     input_shape=(224, 224, 3))
    return model


def get_encoding(model, img):
    """Return the flattened `model` prediction for one dataset image.

    `img` is a file name relative to ``DATA_PATH + 'Flicker8k_Dataset/'``.
    Each call also increments the module-level `counter`.
    """
    global counter
    counter += 1
    # Named `batch` so the imported `image` module is not shadowed.
    batch = load_image(DATA_PATH + 'Flicker8k_Dataset/' + str(img))
    pred = model.predict(batch)
    # This reshape only succeeds when every axis except the last has size 1.
    # NOTE(review): confirm that holds for the installed Keras ResNet50.
    pred = np.reshape(pred, pred.shape[-1])
    return pred


def _read_lines(path, limit=-1):
    """Return the non-blank lines of `path`, cut to `[:limit]` unless -1."""
    with open(path, 'rb') as f:
        lines = [ln for ln in f.read().strip().split('\n') if ln.strip()]
    return lines if limit == -1 else lines[:limit]


def _load_captions(path):
    """Group the captions in the token file at `path` by image name.

    Each row is ``<image id><TAB><caption>``. The last two characters of the
    image id are dropped (presumably the ``#<n>`` caption index -- verify
    against the token file). A row without a tab raises ``IndexError``.
    """
    captions = {}
    for row in _read_lines(path):
        fields = row.split("\t")
        captions.setdefault(fields[0][:-2], []).append(fields[1])
    return captions


def _write_captions(out_file, img, captions):
    """Write one ``<start> ... <end>`` row per caption; return the row count."""
    for capt in captions:
        out_file.write(img + "\t" + "<start> " + capt + " <end>" + "\n")
    # One flush per image keeps the file current for anyone watching it
    # without paying for a flush on every caption.
    out_file.flush()
    return len(captions)


def prepare_dataset(no_imgs=-1, num_val=500):
    """Build the caption split files and the pickled image encodings.

    Args:
        no_imgs: when not -1, both the train and the test image lists are
            cut to ``[:no_imgs]``.
        num_val: the last `num_val` images of the train list go to the
            validation file instead of the training file.

    Returns:
        ``[c_train, c_val, c_test]``: the number of caption rows written to
        the train, validation and test files.

    Raises:
        KeyError: a listed image has no captions in the token file.
        IndexError: a row of the token file has no tab separator.
    """
    text_dir = DATA_PATH + 'Flickr8k_text/'
    header = "image_id\tcaptions\n"

    train_imgs = _read_lines(text_dir + 'Flickr_8k.trainImages.txt', no_imgs)
    test_imgs = _read_lines(text_dir + 'Flickr_8k.testImages.txt', no_imgs)

    print("processing captions...")
    data = _load_captions(text_dir + 'Flickr8k.token.txt')

    encoded_images = {}
    encoding_model = load_encoding_model()
    # Images at or past this index belong to the validation split.
    first_val_idx = len(train_imgs) - num_val

    c_train, c_val = 0, 0
    print("processing training and validation images...")
    with open(text_dir + 'flickr_8k_train_dataset.txt', 'wb') as f_train, \
            open(text_dir + 'flickr_8k_val_dataset.txt', 'wb') as f_val:
        f_train.write(header)
        f_val.write(header)
        # tqdm wraps the list itself (not the enumerate object) so it can
        # read the length and show a total / ETA.
        for idx, img in enumerate(tqdm(train_imgs)):
            encoded_images[img] = get_encoding(encoding_model, img)
            if idx < first_val_idx:
                c_train += _write_captions(f_train, img, data[img])
            else:
                c_val += _write_captions(f_val, img, data[img])

    c_test = 0
    print("processing test images...")
    with open(text_dir + 'flickr_8k_test_dataset.txt', 'wb') as f_test:
        f_test.write(header)
        for img in tqdm(test_imgs):
            encoded_images[img] = get_encoding(encoding_model, img)
            c_test += _write_captions(f_test, img, data[img])

    with open(DATA_PATH + "encoded_images.dat", "wb") as pickle_f:
        pickle.dump(encoded_images, pickle_f)

    return [c_train, c_val, c_test]


if __name__ == '__main__':
    c_train, c_val, c_test = prepare_dataset()
    # Two spaces after the colon reproduce what `print "label: ", n` emitted.
    print("num training captions:  %s" % c_train)
    print("num validation captions:  %s" % c_val)
    print("num test captions:  %s" % c_test)