from LSPGlobals import FLAGS import GetLSPData import LSPModels import os import tensorflow as tf from tensorflow.python.platform import gfile import time from datetime import datetime import LSPGlobals from LSPDrawLines import draw_pose_on_image as draw # Constants used for dealing with the files, matches convert_to_records. train_set_file = os.path.join(FLAGS.data_dir, 'train.tfrecords') validation_set_file = os.path.join(FLAGS.data_dir, 'validation.tfrecords') def main(): if not (os.path.exists(train_set_file) & os.path.exists(validation_set_file)): GetLSPData.main() if not gfile.Exists(FLAGS.train_dir): gfile.MakeDirs(FLAGS.train_dir) train() def read_and_decode(filename_queue): reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) # The serialized example is converted back to actual values. # One needs to describe the format of the objects to be returned features = tf.parse_single_example( serialized_example, features={ # We know the length of both fields. If not the # tf.VarLenFeature could be used 'label': tf.FixedLenFeature([LSPGlobals.TotalLabels], tf.int64), 'image_raw': tf.FixedLenFeature([], tf.string) }) # now return the converted data image_as_vector = tf.decode_raw(features['image_raw'], tf.uint8) image_as_vector.set_shape([LSPGlobals.TotalImageBytes]) image = tf.reshape(image_as_vector, [FLAGS.input_size, FLAGS.input_size, FLAGS.input_depth]) # Convert from [0, 255] -> [-0.5, 0.5] floats. image_float = tf.cast(image, tf.float32) * (1. / 255) - 0.5 # Convert label from a scalar uint8 tensor to an int32 scalar. label = tf.cast(features['label'], tf.int32) return label, image_float def inputs(is_train): """Reads input data num_epochs times.""" filename = train_set_file if is_train else validation_set_file with tf.name_scope('input'): filename_queue = tf.train.string_input_producer( [filename], num_epochs=None) # get single examples label, image = read_and_decode(filename_queue) # groups examples into batches randomly images_batch, labels_batch = tf.train.shuffle_batch( [image, label], batch_size=FLAGS.batch_size, capacity=3000, min_after_dequeue=1000) return images_batch, labels_batch def train(): with tf.Graph().as_default(): # Global step variable for tracking processes. global_step = tf.Variable(0, trainable=False) # Prepare data batches train_set_batch, train_label_batch = inputs(is_train=True) validation_set_batch, validation_label_batch = inputs(is_train=False) # Placeholder to switch between train and test sets. image_batch = tf.placeholder(tf.float32, shape=[FLAGS.batch_size, FLAGS.input_size, FLAGS.input_size, FLAGS.input_depth]) label_batch = tf.placeholder(tf.int32, shape=[FLAGS.batch_size, LSPGlobals.TotalLabels]) keep_probability = tf.placeholder(tf.float32) # Build a Graph that computes the logits predictions from the inference model. logits = LSPModels.inference(image_batch, keep_prob=keep_probability) # Calculate loss. loss, mean_pixel_error = LSPModels.loss(logits, label_batch) # Build a Graph that trains the model with one batch of examples and updates the model parameters. train_op = LSPModels.train(loss, global_step) # Create a saver. saver = tf.train.Saver(tf.global_variables()) # Build the summary operation based on the TF collection of Summaries. summary_op = tf.summary.merge_all() # Build an initialization operation to run below. init = tf.global_variables_initializer() with tf.Session() as sess: # Start populating the filename queue. coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) sess.run(init) step_init = 0 checkpoint = tf.train.get_checkpoint_state(FLAGS.train_dir) if checkpoint and checkpoint.model_checkpoint_path: saver.restore(sess, checkpoint.model_checkpoint_path) step_init = sess.run(global_step) else: print("No checkpoint found...") summary_writer = tf.summary.FileWriter(FLAGS.train_dir, graph=sess.graph) for step in range(step_init, FLAGS.max_steps): start_time = time.time() images, labels = sess.run([train_set_batch, train_label_batch]) feed_dict = {image_batch: images, label_batch: labels, keep_probability: 0.6} _, pixel_error_value = sess.run([train_op, mean_pixel_error], feed_dict=feed_dict) duration = time.time() - start_time if not step == 0: # Print current results. if step % 50 == 0: num_examples_per_step = FLAGS.batch_size examples_per_sec = num_examples_per_step / duration sec_per_batch = float(duration) format_str = '%s: step %d, MeanPixelError = %.1f pixels (%.1f examples/sec; %.3f sec/batch)' print(format_str % (datetime.now(), step, pixel_error_value, examples_per_sec, sec_per_batch)) # Check results for validation set if step % 500 == 0: images, labels = sess.run([validation_set_batch, validation_label_batch]) feed_dict = {image_batch: images, label_batch: labels, keep_probability: 1} produced_labels, pixel_error_value = sess.run([logits, mean_pixel_error], feed_dict=feed_dict) draw(images[0, ...], produced_labels[0, ...], FLAGS.drawing_dir, step/500) print('Test Set MeanPixelError: %.1f pixels' % pixel_error_value) # Add summary to summary writer if step % 1000 == 0: summary_str = sess.run(summary_op, feed_dict=feed_dict) summary_writer.add_summary(summary_str, step) # Save the model checkpoint periodically. if step % 5000 == 0 or (step + 1) == FLAGS.max_steps: checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=step) print('Model checkpoint saved for step %d' % step) coord.request_stop() coord.join(threads) if __name__ == '__main__': main()