"""One-epoch inference and evaluation for a trained model (TF1-style graph mode).

Restores the latest (or a given) checkpoint, streams the test split through
the input queue for one epoch, and reports the average batch accuracy.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
from pprint import pprint

import h5py
import numpy as np
from six.moves import xrange
import tensorflow as tf

from input_ops import create_input_ops, check_data_id
from util import log
from config import argparser


class EvalManager(object):
    """Accumulates per-batch predictions and reports the average accuracy."""

    def __init__(self):
        # Collection of batches (kept as minibatch chunks, not flattened).
        self._ids = []
        self._predictions = []
        self._groundtruths = []

    def add_batch(self, id, prediction, groundtruth):
        """Store one minibatch chunk of ids, predictions, and ground truths."""
        # For now, store them all (as a list of minibatch chunks).
        self._ids.append(id)
        self._predictions.append(prediction)
        self._groundtruths.append(groundtruth)

    def compute_accuracy(self, pred, gt):
        """Return the fraction of rows where argmax(pred) == argmax(gt).

        NOTE(review): the last column of `pred` is dropped before the argmax
        — presumably an auxiliary/non-answer output of the model; confirm
        against the model definition.
        """
        correct_prediction = np.sum(
            np.argmax(pred[:, :-1], axis=1) == np.argmax(gt, axis=1))
        return float(correct_prediction) / pred.shape[0]

    def report(self):
        """Log the accuracy averaged over all batches stored so far."""
        log.info("Computing scores...")
        # FIX: the original assigned `score = {}` and immediately overwrote
        # it with a list; the dead dict assignment is removed. The unused
        # `id` loop variable is dropped as well.
        scores = [self.compute_accuracy(pred, gt)
                  for pred, gt in zip(self._predictions, self._groundtruths)]
        if not scores:
            # FIX: np.average([]) would emit a RuntimeWarning and yield NaN.
            log.warn("No batches were evaluated; nothing to report.")
            return
        avg = np.average(scores)
        log.infov("Average accuracy : %.4f", avg * 100)


class Evaler(object):
    """Builds the eval graph, restores a checkpoint, and runs one epoch."""

    def __init__(self, config, model, dataset):
        self.config = config
        self.model = model
        self.train_dir = config.train_dir
        log.info("self.train_dir = %s", self.train_dir)

        # --- input ops ---
        self.batch_size = config.batch_size

        self.dataset = dataset

        check_data_id(dataset, config.data_id)
        # No shuffling for evaluation: one deterministic pass over the data.
        _, self.batch = create_input_ops(dataset, self.batch_size,
                                         data_id=config.data_id,
                                         is_training=False,
                                         shuffle=False)

        self.global_step = tf.contrib.framework.get_or_create_global_step(graph=None)
        self.step_op = tf.no_op(name='step_no_op')

        tf.set_random_seed(1234)

        session_config = tf.ConfigProto(
            allow_soft_placement=True,
            gpu_options=tf.GPUOptions(allow_growth=True),
            device_count={'GPU': 1},
        )
        self.session = tf.Session(config=session_config)

        # --- checkpoint and monitoring ---
        self.saver = tf.train.Saver(max_to_keep=100)

        # Prefer an explicitly-given checkpoint; otherwise fall back to the
        # latest one in train_dir. If neither exists, run from random init.
        self.checkpoint = config.checkpoint
        if self.checkpoint is None and self.train_dir:
            self.checkpoint = tf.train.latest_checkpoint(self.train_dir)
        if self.checkpoint is None:
            log.warn("No checkpoint is given. Just random initialization :-)")
            self.session.run(tf.global_variables_initializer())
        else:
            log.info("Checkpoint path : %s", self.checkpoint)

    def eval_run(self):
        """Run one epoch of inference and report the aggregate accuracy."""
        # Load checkpoint (restore covers variable initialization).
        if self.checkpoint:
            self.saver.restore(self.session, self.checkpoint)
            log.info("Loaded from checkpoint!")

        log.infov("Start 1-epoch Inference and Evaluation")

        log.info("# of examples = %d", len(self.dataset))
        length_dataset = len(self.dataset)

        # FIX: the original used int(length / batch_size) + 1, which runs one
        # extra (wrapped-around, duplicated) batch whenever the dataset size
        # is exactly divisible by the batch size. Use a true ceiling instead.
        max_steps = -(-length_dataset // self.batch_size)
        log.info("max_steps = %d", max_steps)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(self.session,
                                               coord=coord, start=True)

        evaler = EvalManager()
        try:
            for s in xrange(max_steps):
                step, loss, step_time, batch_chunk, prediction_pred, prediction_gt = \
                    self.run_single_step(self.batch)
                self.log_step_message(s, loss, step_time)
                evaler.add_batch(batch_chunk['id'], prediction_pred, prediction_gt)
        except Exception as e:
            # Propagate the failure to the queue-runner threads, then fall
            # through to the normal shutdown path below.
            coord.request_stop(e)

        coord.request_stop()
        try:
            coord.join(threads, stop_grace_period_secs=3)
        except RuntimeError as e:
            log.warn(str(e))

        evaler.report()
        log.infov("Evaluation complete.")

    def run_single_step(self, batch, step=None, is_train=True):
        """Dequeue one batch and run the model on it.

        Returns (global_step, accuracy, wall_time_sec, batch_chunk,
        all_preds, all_targets).
        """
        _start_time = time.time()

        batch_chunk = self.session.run(batch)

        [step, accuracy, all_preds, all_targets, _] = self.session.run(
            [self.global_step, self.model.accuracy,
             self.model.all_preds, self.model.all_targets, self.step_op],
            feed_dict=self.model.get_feed_dict(batch_chunk)
        )

        _end_time = time.time()

        return step, accuracy, (_end_time - _start_time), \
            batch_chunk, all_preds, all_targets

    def log_step_message(self, step, accuracy, step_time, is_train=False):
        """Log one step's accuracy and throughput."""
        if step_time == 0:
            step_time = 0.001  # avoid division by zero in instances/sec
        log_fn = (is_train and log.info or log.infov)
        log_fn((" [{split_mode:5s} step {step:4d}] " +
                "batch total-accuracy (test): {test_accuracy:.2f}% " +
                "({sec_per_batch:.3f} sec/batch, {instance_per_sec:.3f} instances/sec) "
                ).format(split_mode=(is_train and 'train' or 'val'),
                         step=step,
                         test_accuracy=accuracy * 100,
                         sec_per_batch=step_time,
                         instance_per_sec=self.batch_size / step_time,
                         )
               )


def main():
    """Parse the config, build the evaler on the test split, and run it."""
    config, model, dataset_train, dataset_test = argparser(is_train=False)

    evaler = Evaler(config, model, dataset_test)

    log.warning("dataset: %s", config.dataset)
    evaler.eval_run()


if __name__ == '__main__':
    main()