""" Train """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import random import tqdm import tensorflow as tf from utils import start_threads, set_logging_verbosity, MovingAverage, count_number_of_parameters from data_loader import DATA_PATH, queue_context, tokenize, vectorize from layers import _phase_train, _phase_infer from model import Model from search import reverse_decode, greedy_argmax from losses import gan_loss flags = tf.flags # -- saver and logging options flags.DEFINE_string("model_dir", "./tmp", ( "Model directory.")) flags.DEFINE_string("logging_verbosity", "INFO", ( "Set verbosity to INFO, WARN, DEBUG or ERROR")) # -- train options flags.DEFINE_string("corpus_name", "ptb", ( "Corpus name.")) flags.DEFINE_integer("batch_size", 32, ( "Batch size for dequeue.")) flags.DEFINE_integer("epoch_size", 10, ( "Quit after max number of epochs.")) flags.DEFINE_string("seed_text", "how are", ( "Seed the sampling from the generator with this text.")) flags.DEFINE_string("gan_strategy", "pretrain", ( "GAN training strategy (pretrain, generator, discriminator, simultaneous, alternating).")) flags.DEFINE_string("gan_type", "jsd", ( "GAN type (jsd, emd, ls).")) flags.DEFINE_float("gan_gd_ratio", 0.5, ( "Ratio > 0.5 will run generator more, < 0.5 will run discriminator more.")) # -- optimizer options flags.DEFINE_string("optimizer_type", "adam", ( "Optimizer type (adam, sgd, rmsprop).")) flags.DEFINE_float("learning_rate", 1e-4, ( "Learning rate for optimizer.")) flags.DEFINE_float("learning_rate_decay", 0.8, ( "Decay the learning rate once criterion is passed.")) flags.DEFINE_float("minimum_learning_rate", 1e-6, ( "Early stop when lowering than minimum.")) flags.DEFINE_float("max_grads", 5.0, ( "Max clipping of gradients.")) # -- model options flags.DEFINE_integer("embedding_dim", 128, ( "Hidden dimensions for embedding.")) flags.DEFINE_integer("rnn_hidden_dim", 128, ( "Hidden dimensions for RNN hidden vectors.")) flags.DEFINE_integer("output_hidden_dim", 128, ( "Hidden dimensions for output hidden vectors before softmax layer.")) flags.DEFINE_float("word_dropout_keep_prob", 0.9, ( "Dropout keep rate for word embeddings.")) flags.DEFINE_float("recurrent_dropout_keep_prob", 0.6, ( "Dropout keep rate for recurrent input and output vectors.")) flags.DEFINE_float("output_dropout_keep_prob", 0.5, ( "Dropout keep rate for output vectors.")) FLAGS = flags.FLAGS opts = FLAGS.__flags # dict TODO: make class? 
set_logging_verbosity(FLAGS.logging_verbosity)


def _get_n_batches(batch_size, corpus_size):
    return int(corpus_size // batch_size)


def set_initial_ops():
    local_init_op = tf.local_variables_initializer()
    global_init_op = tf.global_variables_initializer()
    init_op = tf.group(local_init_op, global_init_op)
    return init_op


def set_train_op(loss, tvars):
    if FLAGS.optimizer_type == "sgd":
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=FLAGS.learning_rate)
    elif FLAGS.optimizer_type == "rmsprop":
        optimizer = tf.train.RMSPropOptimizer(learning_rate=FLAGS.learning_rate)
    elif FLAGS.optimizer_type == "adam":
        optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    else:
        raise ValueError("Wrong optimizer_type.")

    gradients = optimizer.compute_gradients(loss, var_list=tvars)
    clipped_gradients = [(grad if grad is None else tf.clip_by_norm(grad, FLAGS.max_grads), var)
                         for grad, var in gradients]
    train_op = optimizer.apply_gradients(clipped_gradients)
    return train_op


def get_supervisor(model):
    saver = tf.train.Saver()
    summary_writer = tf.summary.FileWriter(FLAGS.model_dir)
    supervisor = tf.train.Supervisor(
        logdir=FLAGS.model_dir,
        is_chief=True,
        saver=saver,
        init_op=set_initial_ops(),
        summary_op=tf.summary.merge_all(),
        summary_writer=summary_writer,
        save_summaries_secs=100,  # TODO: add as flags
        save_model_secs=1000,
        global_step=model.global_step,
    )
    return supervisor


def get_sess_config():
    # gpu_options = tf.GPUOptions(
    #     per_process_gpu_memory_fraction=self.gpu_memory_fraction,
    #     allow_growth=True)  # seems to be not working
    sess_config = tf.ConfigProto(
        # log_device_placement=True,
        inter_op_parallelism_threads=8,  # TODO: add as flags
        # allow_soft_placement=True,
        # gpu_options=gpu_options,
    )
    return sess_config


def print_loss(sess, loss, moving_average=None):
    l = sess.run(loss)
    if moving_average is None:
        tf.logging.info(" loss: %.4f", l)
    else:
        l_ma = moving_average.next(l)
        tf.logging.info(" loss: %.4f", l_ma)
    # _g, _d = sess.run([g_loss, d_loss])
    # tf.logging.info("g_loss: %.4f, d_loss: %.4f", _g, _d)


# TODO: add to TensorBoard
def print_valid_loss(sess, loss):
    sess.run(_phase_infer)
    total_loss = 0.0
    for _ in range(100):  # TODO: change, use all test data
        l = sess.run(loss)
        total_loss += l
    valid_loss = total_loss / 100.
tf.logging.info(" valid_loss: %.4f", valid_loss) sess.run(_phase_train) # TODO: configurable seed_text def print_sample(sess, seed_text, probs, input_ph, word2idx, idx2word): # seed_text = "how are you" vector = vectorize(seed_text, word2idx) out = greedy_argmax(vector[:-1], lambda x: sess.run(probs, {input_ph: [x]})) text = reverse_decode(out, idx2word) tf.logging.info(" generated text:\n%s", text) # TODO: learning rate decay def main(): corpus = DATA_PATH[FLAGS.corpus_name] model = Model(corpus, **opts) n_batches = _get_n_batches(FLAGS.batch_size, model.corpus_size) # TODO: rename to pretrain g_loss = model.g_tensors_pretrain.loss g_train_op = set_train_op(g_loss, model.g_tvars) g_loss_valid = model.g_tensors_pretrain_valid.loss d_logits_real = model.d_tensors_real.prediction_logits d_logits_fake = model.d_tensors_fake.prediction_logits gan_d_loss, gan_g_loss = gan_loss( d_logits_real, d_logits_fake, gan_type=FLAGS.gan_type) gan_d_train_op = set_train_op(gan_d_loss, model.d_tvars) gan_g_train_op = set_train_op(gan_g_loss, model.g_tvars) g_loss_ma = MovingAverage(10) sv = get_supervisor(model) sess_config = get_sess_config() tf.logging.info(" number of parameters %i", count_number_of_parameters()) with sv.managed_session(config=sess_config) as sess: sess.run(_phase_train) start_threads(model.enqueue_data, (sess,)) start_threads(model.enqueue_data_valid, (sess,)) # TODO: add learning rate decay -> early_stop if FLAGS.gan_strategy == "pretrain": sv.loop(60, print_loss, (sess, g_loss, g_loss_ma)) sv.loop(600, print_valid_loss, (sess, g_loss_valid)) sv.loop(100, print_sample, (sess, FLAGS.seed_text, model.g_tensors_pretrain_valid.flat_logits, model.input_ph, model.word2idx, model.idx2word)) # TODO: cleanup elif FLAGS.gan_strategy in ["generator", "simultaneous", "alternating"]: # sv.loop(60, print_loss, (sess, g_loss, g_loss_ma)) # sv.loop(600, print_valid_loss, (sess, g_loss_valid)) sv.loop(100, print_sample, (sess, FLAGS.seed_text, model.g_tensors_fake_valid.flat_logits, model.input_ph, model.word2idx, model.idx2word)) # make graph read only sess.graph.finalize() for epoch in range(FLAGS.epoch_size): tf.logging.info(" epoch: %i", epoch) for _ in tqdm.tqdm(range(n_batches)): if sv.should_stop(): break if FLAGS.gan_strategy == "pretrain": sess.run([g_train_op, model.increment_global_step_op]) elif FLAGS.gan_strategy == "generator": sess.run([gan_g_train_op, model.increment_global_step_op]) elif FLAGS.gan_strategy == "discriminator": sess.run([gan_d_train_op, model.increment_global_step_op]) elif FLAGS.gan_strategy == "simultaneous": sess.run([gan_g_train_op, gan_d_train_op, model.increment_global_step_op]) elif FLAGS.gan_strategy == "alternating": assert 0. < FLAGS.gan_gd_ratio < 1.0 u = random.random() if FLAGS.gan_gd_ratio < u: sess.run([gan_g_train_op, model.increment_global_step_op]) elif FLAGS.gan_gd_ratio > u: sess.run([gan_d_train_op, model.increment_global_step_op]) else: raise ValueError("Wrong gan_strategy.") if False: # some criterion sv.stop() sv.saver.save(sess, sv.save_path, global_step=sv.global_step) tf.logging.info(" training finished") if __name__ == "__main__": main()