"""Train object-to-sentence model. python initialization/obj2sen.py --batch_size 512 --save_checkpoint_steps 5000 """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import functools import os import tensorflow as tf from config import NUM_DESCRIPTIONS from misc_fn import crop_sentence from misc_fn import transform_grads_fn from misc_fn import validate_batch_size_for_multi_gpu from input_pipeline import AUTOTUNE tf.logging.set_verbosity(tf.logging.INFO) tf.flags.DEFINE_integer('intra_op_parallelism_threads', 0, 'Number of threads') tf.flags.DEFINE_integer('inter_op_parallelism_threads', 0, 'Number of threads') tf.flags.DEFINE_bool('multi_gpu', False, 'use multi gpus') tf.flags.DEFINE_integer('emb_dim', 512, 'emb dim') tf.flags.DEFINE_integer('mem_dim', 512, 'mem dim') tf.flags.DEFINE_float('keep_prob', 0.8, 'keep prob') tf.flags.DEFINE_string('job_dir', 'obj2sen', 'job dir') tf.flags.DEFINE_integer('batch_size', 512, 'batch size') tf.flags.DEFINE_integer('max_steps', 1000000, 'training steps') tf.flags.DEFINE_float('weight_decay', 0, 'weight decay') tf.flags.DEFINE_float('lr', 0.001, 'learning rate') tf.flags.DEFINE_integer('save_summary_steps', 100, 'save summary steps') tf.flags.DEFINE_integer('save_checkpoint_steps', 5000, 'save ckpt') FLAGS = tf.flags.FLAGS def model_fn(features, labels, mode, params): is_training = mode == tf.estimator.ModeKeys.TRAIN with tf.variable_scope('Discriminator'): embedding = tf.get_variable( name='embedding', shape=[FLAGS.vocab_size, FLAGS.emb_dim], initializer=tf.random_uniform_initializer(-0.08, 0.08)) key, lk = features['key'], features['len'] key = tf.nn.embedding_lookup(embedding, key) sentence, ls = labels['sentence'], labels['len'] targets = sentence[:, 1:] sentence = sentence[:, :-1] ls -= 1 sentence = tf.nn.embedding_lookup(embedding, sentence) cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) if is_training: cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, params.keep_prob) out, initial_state = tf.nn.dynamic_rnn(cell, key, lk, dtype=tf.float32) feat = tf.nn.l2_normalize(initial_state[1], axis=1) batch_size = tf.shape(feat)[0] with tf.variable_scope('Generator'): embedding = tf.get_variable( name='embedding', shape=[FLAGS.vocab_size, FLAGS.emb_dim], initializer=tf.random_uniform_initializer(-0.08, 0.08)) softmax_w = tf.matrix_transpose(embedding) softmax_b = tf.get_variable('softmax_b', [FLAGS.vocab_size]) cell = tf.nn.rnn_cell.BasicLSTMCell(params.mem_dim) if is_training: cell = tf.nn.rnn_cell.DropoutWrapper(cell, params.keep_prob, params.keep_prob) zero_state = cell.zero_state(batch_size, tf.float32) _, state = cell(feat, zero_state) tf.get_variable_scope().reuse_variables() out, state = tf.nn.dynamic_rnn(cell, sentence, ls, state) out = tf.reshape(out, [-1, FLAGS.mem_dim]) logits = tf.nn.bias_add(tf.matmul(out, softmax_w), softmax_b) logits = tf.reshape(logits, [batch_size, -1, FLAGS.vocab_size]) predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) mask = tf.sequence_mask(ls, tf.shape(sentence)[1]) targets = tf.boolean_mask(targets, mask) logits = tf.boolean_mask(logits, mask) loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits) loss = tf.reduce_mean(loss) opt = tf.train.AdamOptimizer(params.lr) if params.multi_gpu: opt = tf.contrib.estimator.TowerOptimizer(opt) grads = opt.compute_gradients(loss) grads = transform_grads_fn(grads) train_op = opt.apply_gradients(grads, global_step=tf.train.get_global_step()) train_hooks = None if 
  if not FLAGS.multi_gpu or opt._graph_state().is_the_last_tower:
    with open('data/word_counts.txt', 'r') as f:
      dic = list(f)
    dic = [i.split()[0] for i in dic]
    end_id = dic.index('</S>')
    dic.append('<unk>')
    dic = tf.convert_to_tensor(dic)
    noise = features['key'][0]
    m = tf.sequence_mask(features['len'][0], tf.shape(noise)[0])
    noise = tf.boolean_mask(noise, m)
    noise = tf.gather(dic, noise)
    sentence = crop_sentence(labels['sentence'][0], end_id)
    sentence = tf.gather(dic, sentence)
    pred = crop_sentence(predictions[0], end_id)
    pred = tf.gather(dic, pred)
    train_hooks = [tf.train.LoggingTensorHook(
        {'sentence': sentence, 'noise': noise, 'pred': pred},
        every_n_iter=100)]
    for variable in tf.trainable_variables():
      tf.summary.histogram(variable.op.name, variable)

  predictions = tf.boolean_mask(predictions, mask)
  metrics = {'acc': tf.metrics.accuracy(targets, predictions)}
  return tf.estimator.EstimatorSpec(
      mode=mode,
      loss=loss,
      train_op=train_op,
      training_hooks=train_hooks,
      eval_metric_ops=metrics)


def batching_func(x, batch_size):
  """Pads each (key, num_key, sentence, sentence_length) tuple in a batch."""
  return x.padded_batch(
      batch_size,
      padded_shapes=(tf.TensorShape([None]),  # key
                     tf.TensorShape([]),  # num_key
                     tf.TensorShape([None]),  # sentence
                     tf.TensorShape([])))  # sentence_length


def parse_sentence(serialized):
  """Parses a tensorflow.SequenceExample into a caption.

  Args:
    serialized: A scalar string Tensor; a single serialized SequenceExample.

  Returns:
    key: The keywords in a sentence.
    num_key: The number of keywords.
    sentence: A description.
    sentence_length: The length of the description.
  """
  context, sequence = tf.parse_single_sequence_example(
      serialized,
      context_features={},
      sequence_features={
          'key': tf.FixedLenSequenceFeature([], dtype=tf.int64),
          'sentence': tf.FixedLenSequenceFeature([], dtype=tf.int64),
      })
  key = tf.to_int32(sequence['key'])
  key = tf.random_shuffle(key)
  sentence = tf.to_int32(sequence['sentence'])
  return key, tf.shape(key)[0], sentence, tf.shape(sentence)[0]


def input_fn(batch_size, subset='train'):
  """Input pipeline; 1/50 of the descriptions are held out for validation."""
  sentence_ds = tf.data.TFRecordDataset('data/sentence.tfrec')
  num_val = NUM_DESCRIPTIONS // 50
  if subset == 'train':
    sentence_ds = sentence_ds.skip(num_val)
  else:
    sentence_ds = sentence_ds.take(num_val)
  sentence_ds = sentence_ds.map(parse_sentence, num_parallel_calls=AUTOTUNE)
  # Drop sentences that contain no object keywords.
  sentence_ds = sentence_ds.filter(lambda k, lk, s, ls: tf.not_equal(lk, 0))
  if subset == 'train':
    sentence_ds = sentence_ds.apply(tf.contrib.data.shuffle_and_repeat(65536))
  sentence_ds = batching_func(sentence_ds, batch_size)
  sentence_ds = sentence_ds.prefetch(AUTOTUNE)
  iterator = sentence_ds.make_one_shot_iterator()
  key, lk, sentence, ls = iterator.get_next()
  return {'key': key, 'len': lk}, {'sentence': sentence, 'len': ls}


def main(_):
  os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'
  if FLAGS.multi_gpu:
    validate_batch_size_for_multi_gpu(FLAGS.batch_size)
    model_function = tf.contrib.estimator.replicate_model_fn(
        model_fn, loss_reduction=tf.losses.Reduction.MEAN)
  else:
    model_function = model_fn

  sess_config = tf.ConfigProto(
      allow_soft_placement=True,
      intra_op_parallelism_threads=FLAGS.intra_op_parallelism_threads,
      inter_op_parallelism_threads=FLAGS.inter_op_parallelism_threads,
      gpu_options=tf.GPUOptions(allow_growth=True))

  run_config = tf.estimator.RunConfig(
      session_config=sess_config,
      save_checkpoints_steps=FLAGS.save_checkpoint_steps,
      save_summary_steps=FLAGS.save_summary_steps,
      keep_checkpoint_max=100)

  train_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size)
  eval_input_fn = functools.partial(input_fn, batch_size=FLAGS.batch_size,
                                    subset='val')
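  # `train_and_evaluate` (below) trains until a new checkpoint is written
  # (every `save_checkpoint_steps` steps) and then runs evaluation;
  # `steps=None` in the EvalSpec means one full pass over the validation split.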
  estimator = tf.estimator.Estimator(
      model_fn=model_function,
      model_dir=FLAGS.job_dir,
      config=run_config,
      params=FLAGS)
  train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                      max_steps=FLAGS.max_steps)
  eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=None)
  tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


if __name__ == '__main__':
  tf.app.run()
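# A quick, hypothetical sanity check for the input pipeline (not wired into
# the script): fetch one batch from the validation split and print its shapes.
#
#   features, labels = input_fn(batch_size=4, subset='val')
#   with tf.Session() as sess:
#     f, l = sess.run([features, labels])
#     print(f['key'].shape, f['len'], l['sentence'].shape, l['len'])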