"""Image coloring by fully convolutional networks - incomplete.

Trains a small convolutional encoder/decoder that maps a gray-scale
CIFAR-10 image to its colour version. The ``mode`` and ``test_image_path``
flags and the ``get_image`` helper are defined but not used by ``main``.
"""
__author__ = 'Charlie'

import inspect
import os
import sys
from datetime import datetime

import numpy as np
import scipy.misc as misc
import tensorflow as tf

# Make the parent directory importable so that TensorflowUtils can be found.
lib_path = os.path.realpath(
    os.path.abspath(os.path.join(os.path.split(inspect.getfile(inspect.currentframe()))[0], "..")))
if lib_path not in sys.path:
    sys.path.insert(0, lib_path)
import TensorflowUtils as utils

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string("data_dir", "Data_zoo/CIFAR10_data/", """Path to the CIFAR10 data""")
tf.flags.DEFINE_string("mode", "train", "Network mode train/ test")
tf.flags.DEFINE_string("test_image_path", "", "Path to test image - read only if mode is test")
# The default is an int (not the string "128") so that it matches the flag type.
tf.flags.DEFINE_integer("batch_size", 128, "train batch size")
tf.flags.DEFINE_string("logs_dir", "logs/ImageColoring_logs/", """Path to save logs and checkpoint if needed""")

DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'

LEARNING_RATE = 1e-3
MAX_ITERATIONS = 100001
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 20000
IMAGE_SIZE = 32


def read_cifar10(filename_queue):
    """Read one CIFAR-10 binary record and derive its gray-scale version.

    Each record is read as 1 label byte followed by
    ``IMAGE_SIZE * IMAGE_SIZE * 3`` image bytes; the label byte is skipped.

    Args:
        filename_queue: queue of CIFAR-10 binary file names to read from.

    Returns:
        An object with the attributes ``height``, ``width``, ``depth``,
        ``key`` (the record key returned by the reader), ``color_image``
        (float32, height x width x 3) and ``gray_image``
        (float32, height x width x 1).
    """
    class CIFAR10Record(object):
        pass

    result = CIFAR10Record()

    label_bytes = 1
    result.height = IMAGE_SIZE
    result.width = IMAGE_SIZE
    result.depth = 3
    image_bytes = result.height * result.width * result.depth
    record_bytes = label_bytes + image_bytes

    reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
    result.key, value = reader.read(filename_queue)
    raw_record = tf.decode_raw(value, tf.uint8)

    # Drop the label byte; the remaining bytes are laid out channel-major
    # (depth, height, width).
    depth_major = tf.cast(
        tf.reshape(tf.slice(raw_record, [label_bytes], [image_bytes]),
                   [result.depth, result.height, result.width]),
        tf.float32)
    # Convert to (height, width, depth).
    result.color_image = tf.transpose(depth_major, [1, 2, 0])
    print(result.color_image.get_shape())

    print("Converting image to gray scale")
    # The CIFAR-10 binary format stores the red, green and blue planes in
    # that order, so after the transpose channel 0 is red and channel 2 is
    # blue. Luminosity weights: 0.21 R + 0.72 G + 0.07 B.
    red = result.color_image[:, :, 0]
    green = result.color_image[:, :, 1]
    blue = result.color_image[:, :, 2]
    # Re-add the channel axis so the gray image is (height, width, 1).
    result.gray_image = tf.expand_dims(0.21 * red + 0.72 * green + 0.07 * blue, 2)
    print(result.gray_image.get_shape())
    return result


def get_image(image_dir):
    """Load an image file as a float32 array with a leading batch axis.

    Args:
        image_dir: path of the image file to read (despite the name, this
            is passed directly to ``misc.imread``).

    Returns:
        A float32 numpy array of shape ``(1,) + image.shape``.
    """
    image = misc.imread(image_dir)
    return np.expand_dims(image.astype(np.float32), axis=0)


def inputs():
    """Build the shuffled training input pipeline.

    Reads ``data_batch_1.bin`` .. ``data_batch_5.bin`` from
    ``<data_dir>/cifar-10-batches-bin``.

    Returns:
        A ``(input_gray, input_colored)`` pair of batched tensors, both
        normalised as ``(value - 128) / 255``.

    Raises:
        ValueError: if one of the expected data files does not exist.
    """
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in range(1, 6)]

    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    filename_queue = tf.train.string_input_producer(filenames)
    read_input = read_cifar10(filename_queue)

    num_preprocess_threads = 8
    min_queue_examples = int(0.4 * NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN)
    print("Shuffling")
    input_gray, input_colored = tf.train.shuffle_batch(
        [read_input.gray_image, read_input.color_image],
        batch_size=FLAGS.batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * FLAGS.batch_size,
        min_after_dequeue=min_queue_examples)

    input_gray = (input_gray - 128) / 255.0
    input_colored = (input_colored - 128) / 255.0
    return input_gray, input_colored


def _layer_variables(index, weight_shape, num_biases):
    """Create the weight and bias variables of layer ``index`` with summaries.

    The creation order (weights, then biases) is the same for every layer so
    that automatically generated variable names do not depend on this helper.
    """
    weights = utils.weight_variable_xavier_initialized(weight_shape)
    biases = utils.bias_variable([num_biases])
    tf.histogram_summary("W%d" % index, weights)
    tf.histogram_summary("b%d" % index, biases)
    return weights, biases


def inference(image):
    """Build the colouring network.

    Three convolutions (the last two strided) are followed by two strided
    transposed convolutions and a final convolution with a tanh activation.
    The exact strides and padding are defined by the ``TensorflowUtils``
    helpers and are not visible here.

    Args:
        image: batch of single-channel (gray-scale) images.

    Returns:
        A 3-channel prediction tensor with values in [-1, 1] (tanh output).
    """
    W1, b1 = _layer_variables(1, [9, 9, 1, 32], 32)
    h_conv1 = tf.nn.relu(utils.conv2d_basic(image, W1, b1))

    W2, b2 = _layer_variables(2, [3, 3, 32, 64], 64)
    h_conv2 = tf.nn.relu(utils.conv2d_strided(h_conv1, W2, b2))

    W3, b3 = _layer_variables(3, [3, 3, 64, 128], 128)
    h_conv3 = tf.nn.relu(utils.conv2d_strided(h_conv2, W3, b3))

    # upstrides
    # Transposed-convolution filters are laid out as
    # [height, width, output_channels, input_channels], hence the bias size
    # equals the third dimension of the weight shape.
    W4, b4 = _layer_variables(4, [3, 3, 64, 128], 64)
    h_conv4 = tf.nn.relu(utils.conv2d_transpose_strided(h_conv3, W4, b4))

    W5, b5 = _layer_variables(5, [3, 3, 32, 64], 32)
    h_conv5 = tf.nn.relu(utils.conv2d_transpose_strided(h_conv4, W5, b5))

    W6, b6 = _layer_variables(6, [9, 9, 32, 3], 3)
    pred_image = tf.nn.tanh(utils.conv2d_basic(h_conv5, W6, b6))

    return pred_image


def loss(pred, colored):
    """Return the training loss between prediction and ground truth.

    The value is the square root of the summed squared differences divided
    by the batch size (``tf.nn.l2_loss`` is ``sum(t ** 2) / 2``, hence the
    factor 2). It is logged under the summary tag "RMSE".

    Args:
        pred: predicted colour images.
        colored: ground-truth colour images, on the same scale as ``pred``.

    Returns:
        A scalar loss tensor.
    """
    rmse = tf.sqrt(2 * tf.nn.l2_loss(tf.sub(colored, pred))) / FLAGS.batch_size
    tf.scalar_summary("RMSE", rmse)
    return rmse


def train(loss_val, step):
    """Create the training op.

    Uses Adam with a learning rate that decays exponentially from
    ``LEARNING_RATE`` (decay rate 0.99 every ``0.4 * MAX_ITERATIONS`` steps).

    Args:
        loss_val: scalar loss tensor to minimise.
        step: global step variable; incremented by the returned op.

    Returns:
        The op that performs one optimisation step.
    """
    learning_rate = tf.train.exponential_decay(LEARNING_RATE, step, 0.4 * MAX_ITERATIONS, 0.99)
    train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss_val, global_step=step)
    return train_op


def main(argv=None):
    """Download the data if needed, build the graph and run training.

    Args:
        argv: command-line arguments passed by ``tf.app.run``; unused.
    """
    utils.maybe_download_and_extract(FLAGS.data_dir, DATA_URL, is_tarfile=True)
    print("Setting up model...")
    global_step = tf.Variable(0, trainable=False)
    gray, color = inputs()

    # ``color`` is normalised by inputs() as (value - 128) / 255, so the loss
    # is computed against the raw network output, which lives on the same
    # scale. The de-normalised prediction is only used for visualisation.
    pred_normalized = inference(gray)
    pred = 255 * pred_normalized + 128

    tf.image_summary("Gray", gray, max_images=1)
    tf.image_summary("Ground_truth", color, max_images=1)
    tf.image_summary("Prediction", pred, max_images=1)

    image_loss = loss(pred_normalized, color)
    train_op = train(image_loss, global_step)

    summary_op = tf.merge_all_summaries()
    with tf.Session() as sess:
        print("Setting up summary writer, queue, saver...")
        sess.run(tf.initialize_all_variables())

        summary_writer = tf.train.SummaryWriter(FLAGS.logs_dir, sess.graph)
        saver = tf.train.Saver()

        ckpt = tf.train.get_checkpoint_state(FLAGS.logs_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Restoring model from checkpoint...")
            saver.restore(sess, ckpt.model_checkpoint_path)

        # os.path.join keeps the path valid even if logs_dir has no
        # trailing separator.
        checkpoint_path = os.path.join(FLAGS.logs_dir, "model.ckpt")

        # The coordinator lets the input threads be stopped and joined when
        # training ends or fails, instead of leaving them running.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(MAX_ITERATIONS):
                if step % 400 == 0:
                    loss_val, summary_str = sess.run([image_loss, summary_op])
                    print("Step %d, Loss: %g" % (step, loss_val))
                    summary_writer.add_summary(summary_str, global_step=step)

                if step % 1000 == 0:
                    saver.save(sess, checkpoint_path, global_step=step)
                    print("%s" % datetime.now())

                sess.run(train_op)
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == "__main__":
    tf.app.run()