#-*- coding:utf-8 -*- from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys import pickle import numpy as np from keras.models import Model from keras.preprocessing import image from keras.layers import Flatten, Input from keras.applications.resnet50 import ResNet50 from keras.applications.resnet50 import preprocess_input, decode_predictions from keras.utils.vis_utils import plot_model import gc res50_base_model = ResNet50(weights='imagenet', pooling=max, include_top=False) image_input = Input(shape=(224, 224, 3), name="image_input") x = res50_base_model(image_input) x = Flatten()(x) res50_model = Model(inputs=image_input, outputs=x) res50_model.summary() # plot_model(res50_base_model, to_file='model.png', show_shapes=True) batch_size = 1000 # image file abs_path, label image_dir = "/home/ai-i-liuguiyang/ImageRetireval/dataset/OxBuild/src/" image_list_file_path = "/home/ai-i-liuguiyang/ImageRetireval/dataset/OxBuild/src/index_file.csv" nn_feature_save_dir = "/home/ai-i-liuguiyang/ImageRetireval/dataset/OxBuild/src/" with open(image_list_file_path, "r") as fl_reader: image_nn_feature_dict = dict() def __fetch_nn_feature(batch_image, batch_file_name): batch_image = np.concatenate(batch_image, axis=0) x = preprocess_input(batch_image) features = res50_model.predict(x) features_reduce = features.squeeze() for idx in range(len(batch_file_name)): image_nn_feature_dict[batch_file_name[idx]] = features_reduce[idx] # print(features_reduce) print(features_reduce.shape) batch_image, batch_file_name = [], [] for line in fl_reader.readlines(): image_name, image_label = line.strip().split(",") image_path = image_dir + image_name if not os.path.exists(image_path): print("{} not found !".format(image_path)) continue image_data = image.load_img(image_path, target_size=(224, 224)) x = image.img_to_array(image_data) x = np.expand_dims(x, axis=0) batch_image.append(x) batch_file_name.append(image_name) if len(batch_image) == batch_size: __fetch_nn_feature(batch_image, batch_file_name) batch_image = list() batch_file_name = list() if len(batch_image): __fetch_nn_feature(batch_image, batch_file_name) print("Before dump the data !") pickle.dump(image_nn_feature_dict, open(nn_feature_save_dir+"nn_features.pkl", "wb"), True) print("After dump the data !") gc.collect()