from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import sys import numpy as np import tensorflow as tf import tensorflow.contrib.slim as slim from config import TF_MODELS_PATH sys.path.append(TF_MODELS_PATH + '/research/im2txt/im2txt') sys.path.append(TF_MODELS_PATH + '/research/slim') from inference_utils import vocabulary from inference_utils.caption_generator import Caption from inference_utils.caption_generator import TopN from nets import inception_v4 FLAGS = tf.flags.FLAGS tf.flags.DEFINE_string('job_dir', 'saving', 'job dir') tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') tf.flags.DEFINE_integer('batch_size', 1, 'batch size') tf.flags.DEFINE_string("vocab_file", "data/word_counts.txt", "Text file containing the vocabulary.") tf.flags.DEFINE_integer('beam_size', 3, 'beam size') tf.flags.DEFINE_integer('max_caption_length', 20, 'beam size') tf.flags.DEFINE_float('length_normalization_factor', 0.0, 'l n f') tf.flags.DEFINE_string('data_dir', None, 'path to all images') tf.flags.DEFINE_string('inc_ckpt', None, 'InceptionV4 checkpoint path') def _tower_fn(im, is_training=False): with slim.arg_scope(inception_v4.inception_v4_arg_scope()): net, _ = inception_v4.inception_v4(im, None, is_training=False) net = tf.squeeze(net, [1, 2]) with tf.variable_scope('Generator'): feat = slim.fully_connected(net, FLAGS.mem_dim, activation_fn=None) feat = tf.nn.l2_normalize(feat, axis=1) embedding = tf.get_variable( name='embedding', shape=[FLAGS.vocab_size, FLAGS.emb_dim], initializer=tf.random_uniform_initializer(-0.08, 0.08)) softmax_w = tf.matrix_transpose(embedding) softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) cell = tf.nn.rnn_cell.BasicLSTMCell(FLAGS.mem_dim) if is_training: cell = tf.nn.rnn_cell.DropoutWrapper(cell, FLAGS.keep_prob, FLAGS.keep_prob) zero_state = cell.zero_state(FLAGS.batch_size, tf.float32) _, state = cell(feat, zero_state) init_state = state tf.get_variable_scope().reuse_variables() state_feed = tf.placeholder(dtype=tf.float32, shape=[None, sum(cell.state_size)], name="state_feed") state_tuple = tf.split(value=state_feed, num_or_size_splits=2, axis=1) input_feed = tf.placeholder(dtype=tf.int64, shape=[None], # batch_size name="input_feed") inputs = tf.nn.embedding_lookup(embedding, input_feed) out, state_tuple = cell(inputs, state_tuple) tf.concat(axis=1, values=state_tuple, name="state") logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) tower_pred = tf.nn.softmax(logits, name="softmax") return tf.concat(init_state, axis=1, name='initial_state') def read_image(im): """Reads an image.""" filename = tf.string_join([FLAGS.data_dir, im]) image = tf.read_file(filename) image = tf.image.decode_jpeg(image, 3) image = tf.image.convert_image_dtype(image, tf.float32) image = tf.image.resize_images(image, [346, 346]) image = image[23:-24, 23:-24] image = image * 2 - 1 return image class Infer: def __init__(self, job_dir=FLAGS.job_dir): im_inp = tf.placeholder(tf.string, []) im = read_image(im_inp) im = tf.expand_dims(im, 0) initial_state_op = _tower_fn(im) vocab = vocabulary.Vocabulary(FLAGS.vocab_file) self.saver = tf.train.Saver(tf.trainable_variables('Generator')) self.im_inp = im_inp self.init_state = initial_state_op self.vocab = vocab config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True)) self.sess = tf.Session(config=config) inc_saver = tf.train.Saver(tf.global_variables('InceptionV4')) self.restore_fn(job_dir) inc_saver.restore(self.sess, FLAGS.inc_ckpt) def restore_fn(self, checkpoint_path): if tf.gfile.IsDirectory(checkpoint_path): checkpoint_path = tf.train.latest_checkpoint(checkpoint_path) if checkpoint_path: self.saver.restore(self.sess, checkpoint_path) else: self.sess.run(tf.global_variables_initializer()) def infer(self, im): vocab = self.vocab sess = self.sess im_inp = self.im_inp initial_state_op = self.init_state initial_state = sess.run(initial_state_op, feed_dict={im_inp: im}) initial_beam = Caption( sentence=[vocab.start_id], state=initial_state[0], logprob=0.0, score=0.0, metadata=[""]) partial_captions = TopN(FLAGS.beam_size) partial_captions.push(initial_beam) complete_captions = TopN(FLAGS.beam_size) # Run beam search. for _ in range(FLAGS.max_caption_length - 1): partial_captions_list = partial_captions.extract() partial_captions.reset() input_feed = np.array([c.sentence[-1] for c in partial_captions_list]) state_feed = np.array([c.state for c in partial_captions_list]) softmax, new_states = sess.run( fetches=["Generator/softmax:0", "Generator/state:0"], feed_dict={ "Generator/input_feed:0": input_feed, "Generator/state_feed:0": state_feed, }) metadata = None for i, partial_caption in enumerate(partial_captions_list): word_probabilities = softmax[i] word_probabilities[-1] = 0 state = new_states[i] # For this partial caption, get the beam_size most probable next words. words_and_probs = list(enumerate(word_probabilities)) words_and_probs.sort(key=lambda x: -x[1]) words_and_probs = words_and_probs[0:FLAGS.beam_size] # Each next word gives a new partial caption. for w, p in words_and_probs: if p < 1e-12: continue # Avoid log(0). sentence = partial_caption.sentence + [w] logprob = partial_caption.logprob + math.log(p) score = logprob if metadata: metadata_list = partial_caption.metadata + [metadata[i]] else: metadata_list = None if w == vocab.end_id: if FLAGS.length_normalization_factor > 0: score /= len(sentence) ** FLAGS.length_normalization_factor beam = Caption(sentence, state, logprob, score, metadata_list) complete_captions.push(beam) else: beam = Caption(sentence, state, logprob, score, metadata_list) partial_captions.push(beam) if partial_captions.size() == 0: # We have run out of partial candidates; happens when beam_size = 1. break # If we have no complete captions then fall back to the partial captions. # But never output a mixture of complete and partial captions because a # partial caption could have a higher score than all the complete captions. if not complete_captions.size(): complete_captions = partial_captions captions = complete_captions.extract(sort=True) ret = [] for i, caption in enumerate(captions): # Ignore begin and end words. sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]] sentence = " ".join(sentence) # print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob))) ret.append((sentence, math.exp(caption.logprob))) return ret