"""Evaluate a restored supervised grounding model on the test image list.

The script parses its command-line flags at import time, rebuilds the graph
through ``ground_model``, restores checkpoint ``model_<restore_id>.ckpt`` from
the model's save directory and prints per-image and overall accuracy.
"""
import argparse
import os
import sys
import time

import numpy as np
import tensorflow as tf

from dataprovider_supervise import dataprovider
from model_supervise import ground_model
from util.iou import calc_iou

parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model_name", type=str, default='grounder')
parser.add_argument("-g", "--gpu", type=str, default='0')
parser.add_argument("--restore_id", type=int, default=0)
args = parser.parse_args()

# Restrict the GPUs visible to this process before any session is created.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu


class Config(object):
    """Static settings shared by the data provider and the model.

    Only ``batch_size``, ``img_feat_dir``, ``sen_dir``, ``vocab_size``, the
    two file lists and ``save_path`` are read in this script; the remaining
    fields are handed to ``ground_model`` via the config object
    (presumably consumed there -- verify against model_supervise).
    """
    batch_size = 40
    img_feat_dir = './feature'
    sen_dir = './annotation'
    train_file_list = 'flickr30k_train_val.lst'
    test_file_list = 'flickr30k_test.lst'
    log_file = './log/ground_supervised'
    # Prefix only: run_evaluate() appends '_<model_name>' to get the directory.
    save_path = './model/ground_supervised'
    vocab_size = 17150
    num_epoch = 3
    max_step = 12000
    optim = 'adam'
    dropout = 0.5
    lr = 0.001
    weight_decay = 0.0


def update_feed_dict(dataprovider, model, is_train):
    """Fetch the next batch from ``dataprovider`` and map it onto the model inputs.

    Args:
        dataprovider: object exposing ``get_next_batch()`` that returns
            ``(img_feat, sen_feat, bbx_label)``.
        model: object exposing the ``sen_data``, ``vis_data``, ``bbx_label``
            and ``is_train`` feed keys.
        is_train: value fed to ``model.is_train``.

    Returns:
        dict mapping the four model inputs to the fetched batch values.
    """
    img_feat, sen_feat, bbx_label = dataprovider.get_next_batch()
    return {
        model.sen_data: sen_feat,
        model.vis_data: img_feat,
        model.bbx_label: bbx_label,
        model.is_train: is_train,
    }


def eval_cur_batch(gt_label, cur_logits, is_train=True, num_sample=0):
    """Compute the accuracy of the arg-max predictions for one batch.

    Args:
        gt_label: when ``is_train`` is true, one label per row that is
            compared element-wise with the predictions; otherwise one
            container of acceptable labels per row (membership test).
        cur_logits: 2-D scores; the prediction is ``argmax`` over axis 1.
        is_train: selects between the two comparison modes above.
        num_sample: denominator used in evaluation mode.  When it is not
            positive, ``len(gt_label)`` is used instead.

    Returns:
        Accuracy as a float; ``0.0`` when there is nothing to score
        (previously this raised ``ZeroDivisionError``).
    """
    res_label = np.argmax(cur_logits, axis=1)

    if is_train:
        total = len(gt_label)
        if total == 0:
            return 0.0
        return float(np.sum(res_label == gt_label)) / float(total)

    hits = 0.0
    for gt_id, cur_gt in enumerate(gt_label):
        if res_label[gt_id] in cur_gt:
            hits += 1.0

    # The default num_sample=0 used to divide by zero; fall back to the
    # number of rows actually scored.
    denom = num_sample if num_sample > 0 else len(gt_label)
    if denom <= 0:
        return 0.0
    return hits / float(denom)


def load_img_id_list(file_list):
    """Read integer image ids, one per line, from ``file_list``.

    Blank lines (for example a trailing newline at the end of the file) are
    skipped instead of crashing on ``int('')``.

    Returns:
        numpy integer array with the ids in file order.
    """
    img_list = []
    with open(file_list) as fin:
        for line in fin:
            entry = line.strip()
            if entry:
                img_list.append(int(entry))
    return np.array(img_list).astype('int')


def run_eval(sess, dataprovider, model, eval_op, feed_dict):
    """Evaluate ``eval_op`` on every test image and report overall accuracy.

    Args:
        sess: session used to run ``eval_op``.
        dataprovider: object exposing ``test_list`` and ``get_test_feat(img_id)``
            returning ``(img_feat_raw, sen_feat_batch, bbx_gt_batch,
            num_sample_all)``.
        model: object exposing the feed keys used by ``update_feed_dict``.
        eval_op: op whose output rows are scored with ``eval_cur_batch``.
        feed_dict: a previously built feed dict; its arrays are reused as
            scratch buffers and are overwritten in place.

    Returns:
        Sample-weighted accuracy over all images that have ground truth, or
        ``0.0`` when no image has any (previously a ``ZeroDivisionError``).
    """
    num_cnt = 0.0
    num_cor = 0.0
    num_img = len(dataprovider.test_list)

    for img_ind, img_id in enumerate(dataprovider.test_list):
        img_feat_raw, sen_feat_batch, bbx_gt_batch, num_sample_all = \
            dataprovider.get_test_feat(img_id)

        if num_sample_all <= 0:
            print('No gt for %d' % img_id)
            continue

        num_sample = len(bbx_gt_batch)

        # Reuse the batch-sized buffers: the first num_sample rows are filled
        # with this image's data; later rows keep stale values and their
        # outputs are dropped by the slice below.
        # NOTE(review): assumes num_sample never exceeds the buffers' first
        # dimension -- confirm against dataprovider.get_test_feat.
        img_feat = feed_dict[model.vis_data]
        for i in range(num_sample):
            img_feat[i] = img_feat_raw
        sen_feat = feed_dict[model.sen_data]
        sen_feat[:num_sample] = sen_feat_batch
        bbx_label = feed_dict[model.bbx_label]

        eval_feed_dict = {
            model.sen_data: sen_feat,
            model.vis_data: img_feat,
            model.bbx_label: bbx_label,
            model.is_train: False,
        }

        cur_att_logits = sess.run(eval_op, feed_dict=eval_feed_dict)
        cur_att_logits = cur_att_logits[:num_sample]

        cur_accuracy = eval_cur_batch(
            bbx_gt_batch, cur_att_logits, False, num_sample_all)

        print('%d/%d: %d/%d, %.4f' % (
            img_ind, num_img, num_sample, num_sample_all, cur_accuracy))

        num_cor += float(num_sample_all) * cur_accuracy
        num_cnt += float(num_sample_all)

    # Guard the final ratio: with no ground truth at all num_cnt stays 0.
    accu = num_cor / num_cnt if num_cnt > 0 else 0.0
    print('Accuracy = %.4f' % accu)
    return accu


def run_evaluate():
    """Restore checkpoint ``--restore_id`` of ``--model_name`` and evaluate it.

    Raises:
        IOError: if the model's save directory does not exist.
        ValueError: if ``--restore_id`` is not a positive integer.
    """
    config = Config()
    train_list = load_img_id_list(config.train_file_list)
    test_list = load_img_id_list(config.test_file_list)

    config.save_path = config.save_path + '_' + args.model_name

    # Explicit checks instead of assert, which is removed under `python -O`.
    if not os.path.isdir(config.save_path):
        raise IOError('Model directory not found: %s' % config.save_path)
    restore_id = args.restore_id
    if restore_id <= 0:
        raise ValueError(
            '--restore_id must be a positive integer, got %d' % restore_id)

    cur_dataset = dataprovider(
        train_list, test_list, config.img_feat_dir, config.sen_dir,
        config.vocab_size, batch_size=config.batch_size)

    model = ground_model(config)
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)

    with tf.Graph().as_default():
        _loss, _train_op, _loss_vec, logits = model.build_model()

        # The context manager closes the session (and frees its resources)
        # even if restoring or evaluating raises.
        with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
            # Run the Op to initialize the variables.
            sess.run(tf.global_variables_initializer())
            saver = tf.train.Saver(max_to_keep=100)

            feed_dict = update_feed_dict(cur_dataset, model, False)

            print('Restore model_%d' % restore_id)
            # config.save_path already contains the './model/' prefix, so the
            # checkpoint lives directly inside it; prefixing './model/' again
            # pointed at './model/./model/...', not the directory checked above.
            ckpt_path = os.path.join(
                config.save_path, 'model_%d.ckpt' % restore_id)
            saver.restore(sess, ckpt_path)

            print("-----------------------------------------------")
            run_eval(sess, cur_dataset, model, logits, feed_dict)
            print("-----------------------------------------------")


def main(_):
    """Entry point invoked by ``tf.app.run()``; the argument is unused."""
    run_evaluate()


if __name__ == '__main__':
    tf.app.run()