"""Dump visualization masks for a batch of training images.

Restores the latest checkpoint from ``--model_dir``, evaluates the model's
``visualization_mask`` on one batch drawn from the training set, and writes
the source image, the normalized mask and a green-channel overlay for each
sample into ``--result_dir`` as JPEG files.
"""
import argparse
import os
import sys

import cv2
import numpy as np
import tensorflow as tf

from model import Nivdia_Model
import reader

# Parsed command-line flags; populated in the ``__main__`` guard below.
FLAGS = None


def visualize(image, mask):
    """Build the displayable image, mask and overlay for one sample.

    Args:
        image: A single input frame in YUV colour space, as fed to the model.
        mask: The visualization mask produced by the model for that frame.

    Returns:
        A tuple ``(image, mask, overlay)`` where ``image`` is the frame
        converted to BGR, ``mask`` is the mask min-max scaled to ``uint8``
        in ``[0, 255]``, and ``overlay`` is a copy of the BGR frame with the
        mask added onto channel 1 (green in BGR order).
    """
    # Convert the image from YUV to BGR, the channel order cv2.imwrite expects.
    image = cv2.cvtColor(image, cv2.COLOR_YUV2BGR)

    min_val = np.min(mask)
    max_val = np.max(mask)
    span = max_val - min_val
    if span > 0:
        mask = (mask - min_val) / span
    else:
        # A constant mask carries no contrast; dividing by a zero span would
        # produce NaN and an undefined uint8 cast, so emit an all-zero mask.
        mask = np.zeros(np.shape(mask))
    mask = (mask * 255.0).astype(np.uint8)

    overlay = np.copy(image)
    # NOTE(review): cv2.add requires both operands to share shape and dtype
    # unless an output dtype is given -- confirm what the reader/model yield.
    overlay[:, :, 1] = cv2.add(image[:, :, 1], mask)
    return image, mask, overlay


def main():
    """Restore the latest checkpoint and write visualization images to disk.

    Reads its configuration from the module-level ``FLAGS``. Exits the
    process with status 1 when no checkpoint is found in ``FLAGS.model_dir``.
    """
    x_image = tf.placeholder(tf.float32, [None, 66, 200, 3])
    keep_prob = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32, [None, 1])
    # NOTE(review): the trailing False is presumably a training-mode switch;
    # verify against Nivdia_Model's signature.
    model = Nivdia_Model(x_image, y, keep_prob, FLAGS, False)

    # Dataset reader.
    dataset = reader.Reader(FLAGS.data_dir, FLAGS)

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # Initialize all variables before restoring the checkpoint over them.
        sess.run(tf.global_variables_initializer())

        # Restore the model.
        print(FLAGS.model_dir)
        path = tf.train.latest_checkpoint(FLAGS.model_dir)
        if path is None:
            print("Err: the model does NOT exist")
            # Non-zero status so callers can detect the failure.
            sys.exit(1)
        saver.restore(sess, path)
        print("Restore model from", path)

        # Labels are not needed for visualization, so they are discarded.
        # NOTE(review): the False argument presumably disables shuffling;
        # verify against the reader implementation.
        batch_x, _ = dataset.train.next_batch(FLAGS.visualization_num, False)
        # keep_prob is fed as 1.0 for this evaluation pass.
        masks = sess.run(
            model.visualization_mask,
            feed_dict={
                x_image: batch_x,
                keep_prob: 1.0
            })

    if not os.path.exists(FLAGS.result_dir):
        os.makedirs(FLAGS.result_dir)

    # Never index past what was actually returned, in case the reader yields
    # fewer samples than requested.
    count = min(FLAGS.visualization_num, len(batch_x), len(masks))
    for i in range(count):
        image, mask, overlay = visualize(batch_x[i], masks[i])
        cv2.imwrite(
            os.path.join(FLAGS.result_dir, "image_" + str(i) + ".jpg"), image)
        cv2.imwrite(
            os.path.join(FLAGS.result_dir, "mask_" + str(i) + ".jpg"), mask)
        cv2.imwrite(
            os.path.join(FLAGS.result_dir, "overlay_" + str(i) + ".jpg"),
            overlay)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        type=str,
        default=os.path.join('.', 'saved_model'),
        help='Directory of saved model')
    parser.add_argument(
        '--data_dir',
        type=str,
        default=os.path.join('.', 'driving_dataset'),
        help='Directory of data')
    parser.add_argument(
        '--result_dir',
        type=str,
        default=os.path.join('.', 'visualization_mask'),
        help='Directory of visualization result')
    parser.add_argument(
        '--visualization_num',
        type=int,
        default=10,
        help='The image number of visualization')
    # Unknown arguments are tolerated and left in ``unparsed``.
    FLAGS, unparsed = parser.parse_known_args()
    main()