#!/usr/bin/python3
#-*-coding:utf-8-*-
#$File: model_eval.py
#$Date: Sat Jun 4 09:59:32 2016
#$Author: Like Ma <milkpku[at]gmail[dot]com>
"""Evaluate a restored policy model on one batch of validation data.

Run as a script, this restores a checkpoint, prints top-1 accuracy and
top-N accuracy for a single validation batch, and pickles the batch inputs,
labels and predictions to ``pred.tensor`` in the current directory.
"""

import argparse
import pickle
import sys

import numpy as np
import tensorflow as tf


def _accuracy_and_topn(pred, label, N):
    """Return ``(accuracy, top_n)`` for a batch of predictions.

    Both arrays are flattened to rows of ``9*10*16`` scores; the class of a
    row is the position of its largest entry.
    NOTE(review): 9*10*16 is presumably the size of the policy output per
    sample -- confirm against the model definition.

    Args:
        pred: Array of predicted scores; total size must be a multiple of
            ``9*10*16``.
        label: Array of target scores with the same total size as ``pred``.
        N: Rank cutoff. A row is a top-N hit when the score predicted for
            the labelled class is at least the N-th largest score in that
            row (ties count as hits). Values larger than the number of
            classes are clamped, since every class is then inside the top N.

    Returns:
        Tuple ``(accuracy, top_n)`` of fractions in ``[0, 1]``, both
        averaged over the flattened rows.

    Raises:
        ValueError: If ``N`` is smaller than 1 or the batch is empty.
    """
    if N < 1:
        # The original indexing ``sorted_row[-N]`` silently picked the
        # minimum for N == 0 (reporting 100%) and an arbitrary rank for
        # negative N.
        raise ValueError('N must be at least 1, got %r' % (N,))

    pred_flat = np.asarray(pred).reshape([-1, 9 * 10 * 16])
    label_flat = np.asarray(label).reshape([-1, 9 * 10 * 16])
    n_rows, n_classes = pred_flat.shape
    if n_rows == 0:
        raise ValueError('cannot evaluate an empty batch')

    target = label_flat.argmax(axis=1)
    accuracy = np.mean(pred_flat.argmax(axis=1) == target)

    # Score the model assigned to the labelled class of every row.
    target_score = pred_flat[np.arange(n_rows), target]
    # After partitioning at -k, column -k holds each row's k-th largest
    # score without paying for a full sort.
    k = min(N, n_classes)
    kth_best = np.partition(pred_flat, -k, axis=1)[:, -k]
    # Averaging over the flattened rows keeps the denominator identical to
    # the one used for ``accuracy``.
    top_n = np.mean(target_score >= kth_best)
    return accuracy, top_n


class Evaluer(object):
    """Restore a saved policy model and evaluate it on validation batches."""

    def __init__(self, model_folder, checkpoint_file):
        """Build the model graph and load its weights from a checkpoint.

        Args:
            model_folder: Directory containing the ``model`` and ``dataset``
                modules; it is appended to ``sys.path`` so they can be
                imported.
            checkpoint_file: Checkpoint path handed to
                ``tf.train.Saver.restore``.
        """
        # Imported here because the modules only become importable once the
        # model folder is on sys.path.
        sys.path.append(model_folder)
        from model import get_model
        from dataset import load_data

        self.dataset = load_data('validation')
        # Created before the model so that ``Tensor.eval`` in evalue_topN
        # can rely on the session InteractiveSession installs as default.
        self.sess = tf.InteractiveSession()
        self.model = get_model('policy')
        saver = tf.train.Saver()
        saver.restore(self.sess, checkpoint_file)

    def evalue_topN(self, batch, N=5):
        """Evaluate one validation batch and print accuracy and top-N.

        Args:
            batch: Batch size passed to ``self.dataset.next_batch``.
            N: Rank cutoff for the top-N metric; must be at least 1.

        Returns:
            Tuple ``(data, label, pred)``: the batch inputs and labels as
            returned by the dataset, and the model's evaluated predictions.

        Raises:
            ValueError: If ``N`` is smaller than 1 or the batch is empty.
        """
        data, label = self.dataset.next_batch(batch)
        # Pair every model input placeholder with its slice of the batch.
        input_dict = dict(zip(self.model.inputs, data))
        pred = self.model.pred.eval(feed_dict=input_dict)

        accuracy, topN = _accuracy_and_topn(pred, label, N)
        print('accuracy %.2f, top%d %.2f' % (accuracy, N, topN))
        return data, label, pred


if __name__ == '__main__':
    # Kept inline (not wrapped in a main()) so that ``evaluer``, ``data``,
    # ``label`` and ``pred`` stay available as globals under ``python -i``.
    parser = argparse.ArgumentParser()
    parser.add_argument('model_folder', help='Foler for model')
    parser.add_argument('checkpoint_file', help='File path of checkpoint')
    args = parser.parse_args()

    evaluer = Evaluer(args.model_folder, args.checkpoint_file)
    #for i in range(10):
    #    evaluer.evalue_topN(100, 3)
    data, label, pred = evaluer.evalue_topN(100, 3)

    # The context manager closes the file even if pickling fails.
    with open('pred.tensor', 'wb') as fh:
        pickle.dump([data, label, pred], fh)