from __future__ import absolute_import, division, print_function, unicode_literals
from builtins import open

import os
import operator
import logging
from timeit import default_timer

import tensorflow as tf
import tensorflow_fold as td
from tensorflow.contrib.tensorboard.plugins import projector

from . import data
from . import apputil
from .config import hyper, param

logger = logging.getLogger(__name__)


def linear_combine(clen, pclen, idx):
    """For a batch of (clen, pclen, idx) scalars, build the per-child weight matrices
    f * (l * Wl + r * Wr), where f = clen / pclen, l = (pclen - idx - 1) / (pclen - 1)
    and r = idx / (pclen - 1)."""
    Wl = param.get('Wl')
    Wr = param.get('Wr')

    dim = tf.unstack(tf.shape(Wl))[0]
    batch_shape = tf.shape(clen)

    f = clen / pclen
    l = (pclen - idx - 1) / (pclen - 1)
    r = idx / (pclen - 1)

    # when pclen == 1, replace nan items with 0.5
    l = tf.where(tf.is_nan(l), tf.ones_like(l) * 0.5, l)
    r = tf.where(tf.is_nan(r), tf.ones_like(r) * 0.5, r)

    # batches of scaled identity matrices: l * I, r * I and f * I
    lb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * l)
    rb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * r)
    fb = tf.transpose(tf.transpose(tf.eye(dim, batch_shape=batch_shape)) * f)

    lb = tf.reshape(lb, [-1, hyper.word_dim])
    rb = tf.reshape(rb, [-1, hyper.word_dim])

    tmp = tf.matmul(lb, Wl) + tf.matmul(rb, Wr)
    tmp = tf.reshape(tmp, [-1, hyper.word_dim, hyper.word_dim])

    return tf.matmul(fb, tmp)


def batch_mul(batch, weight):
    batch = tf.expand_dims(batch, axis=1)
    mul = tf.matmul(batch, weight)
    return tf.squeeze(mul, axis=1)


def expand_dim_blk(axis):
    return td.Function(lambda tensor: tf.expand_dims(tensor, axis=axis))


def linear_combine_blk():
    blk = td.Function(linear_combine, infer_output_type=False)
    blk.set_output_type(td.TensorType([hyper.word_dim, hyper.word_dim]))
    return blk


def continous_weighted_add_blk():
    """Fold step over a node's children: the state is (accumulated vector, child index);
    each step adds the current child's embedding multiplied by its linear_combine weight."""
    block = td.Composition(name='continous_weighted_add')
    with block.scope():
        initial = td.GetItem(0).reads(block.input)
        cur = td.GetItem(1).reads(block.input)

        last = td.GetItem(0).reads(initial)
        idx = td.GetItem(1).reads(initial)

        cur_fea = td.GetItem(0).reads(cur)
        cur_clen = td.GetItem(1).reads(cur)
        pclen = td.GetItem(2).reads(cur)

        Wi = linear_combine_blk().reads(cur_clen, pclen, idx)
        weighted_fea = td.Function(batch_mul).reads(cur_fea, Wi)

        block.output.reads(
            td.Function(tf.add, name='add_last_weighted_fea').reads(last, weighted_fea),
            # XXX: rewrite using tf.range
            td.Function(tf.add, name='add_idx_1').reads(idx, td.FromTensor(tf.constant(1.)))
        )
    return block


def clip_by_norm_blk(norm=1.0):
    return td.Function(lambda x: tf.clip_by_norm(x, norm, axes=[1]))


def direct_embed_blk():
    """Look up a node's embedding by its 'name' id and clip it by L2 norm."""
    return (td.GetItem('name') >> td.Scalar('int32')
            >> td.Function(lambda x: tf.nn.embedding_lookup(param.get('We'), x))
            >> clip_by_norm_blk())


def composed_embed_blk():
    """Compose a node's embedding from its children's direct embeddings;
    leaf nodes (clen == 0) fall back to the direct embedding."""
    leaf_case = direct_embed_blk()

    nonleaf_case = td.Composition(name='composed_embed_nonleaf')
    with nonleaf_case.scope():
        children = td.GetItem('children').reads(nonleaf_case.input)
        clen = td.Scalar().reads(td.GetItem('clen').reads(nonleaf_case.input))
        cclens = td.Map(td.GetItem('clen') >> td.Scalar()).reads(children)
        fchildren = td.Map(direct_embed_blk()).reads(children)

        initial_state = td.Composition()
        with initial_state.scope():
            initial_state.output.reads(
                td.FromTensor(tf.zeros(hyper.word_dim)),
                td.FromTensor(tf.zeros([])),
            )
        summed = td.Zip().reads(fchildren, cclens, td.Broadcast().reads(clen))
        summed = td.Fold(continous_weighted_add_blk(), initial_state).reads(summed)[0]

        added = td.Function(tf.add, name='add_bias').reads(summed, td.FromTensor(param.get('B')))
        normed = clip_by_norm_blk().reads(added)

        act_fn = tf.nn.relu if hyper.use_relu else tf.nn.tanh
        relu = td.Function(act_fn).reads(normed)
        nonleaf_case.output.reads(relu)

    return td.OneOf(lambda node: node['clen'] == 0,
                    {True: leaf_case, False: nonleaf_case})
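

# The blocks above weight each child of a node by the matrix built in linear_combine,
# f * (l * Wl + r * Wr), before summing.  The helper below is a minimal, non-batched
# sketch of that same weight for a single child, kept only as a readable reference and
# sanity check; it is not used by the Fold graph, and the name `single_child_weight`
# is illustrative rather than part of the original module.  Wl / Wr may be NumPy
# arrays or the tensors returned by param.get.
def single_child_weight(clen, pclen, idx, Wl, Wr):
    """Reference for one child: the [word_dim, word_dim] matrix linear_combine builds."""
    f = clen / pclen
    if pclen == 1:
        # a single child gets equal left/right coefficients (the nan -> 0.5 case above)
        l = r = 0.5
    else:
        l = (pclen - idx - 1) / (pclen - 1)
        r = idx / (pclen - 1)
    return f * (l * Wl + r * Wr)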


def batch_nn_l2loss(a, b):
    """L2 loss between a and b, similar to tf.nn.l2_loss, but treats dim 0 as the batch dim."""
    diff = tf.subtract(a, b)
    diff = tf.multiply(diff, diff)
    s = tf.reduce_sum(diff, axis=1)
    s = s / 2
    return s


def l2loss_blk():
    """Per-node loss: batch_nn_l2loss between the direct and composed embeddings of a
    non-leaf node, a constant 1. for leaves."""
    # TODO: rewrite using metric
    leaf_case = td.Composition()
    with leaf_case.scope():
        leaf_case.output.reads(td.FromTensor(tf.constant(1.)))

    nonleaf_case = td.Composition()
    with nonleaf_case.scope():
        direct = direct_embed_blk().reads(nonleaf_case.input)
        com = composed_embed_blk().reads(nonleaf_case.input)
        loss = td.Function(batch_nn_l2loss).reads(direct, com)
        nonleaf_case.output.reads(loss)

    return td.OneOf(lambda node: node['clen'] != 0,
                    {False: leaf_case, True: nonleaf_case})


# TODO: generalize to tree_reduce, accepting one block that takes two nodes and
# returns a value (see the sketch below)
def tree_sum_blk(loss_blk):
    # traverse the tree to sum up the loss
    tree_sum_fwd = td.ForwardDeclaration(td.PyObjectType(), td.TensorType([]))
    tree_sum = td.Composition()
    with tree_sum.scope():
        myloss = loss_blk().reads(tree_sum.input)
        children = td.GetItem('children').reads(tree_sum.input)

        mapped = td.Map(tree_sum_fwd()).reads(children)
        summed = td.Reduce(td.Function(tf.add)).reads(mapped)
        summed = td.Function(tf.add).reads(summed, myloss)
        tree_sum.output.reads(summed)
    tree_sum_fwd.resolve_to(tree_sum)
    return tree_sum
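

# One possible shape for the tree_reduce generalization mentioned above tree_sum_blk:
# a sketch that assumes `node_blk()` returns a block mapping one node dict to a scalar
# tensor (like l2loss_blk) and `reduce_fn` is a two-argument TensorFlow op such as
# tf.add or tf.maximum.  With the default reduce_fn this behaves exactly like
# tree_sum_blk; the helper is illustrative and not used elsewhere in this module.
def tree_reduce_blk(node_blk, reduce_fn=tf.add):
    tree_reduce_fwd = td.ForwardDeclaration(td.PyObjectType(), td.TensorType([]))
    tree_reduce = td.Composition()
    with tree_reduce.scope():
        myval = node_blk().reads(tree_reduce.input)
        children = td.GetItem('children').reads(tree_reduce.input)

        # recurse into the children, combine their values, then fold in this node's value
        mapped = td.Map(tree_reduce_fwd()).reads(children)
        reduced = td.Reduce(td.Function(reduce_fn)).reads(mapped)
        reduced = td.Function(reduce_fn).reads(reduced, myval)
        tree_reduce.output.reads(reduced)
    tree_reduce_fwd.resolve_to(tree_reduce)
    return tree_reduce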


def write_embedding_metadata(writer, word2int):
    """Write the projector metadata (one node type per line) and the projector config
    so TensorBoard can label the embedding."""
    metadata_path = os.path.join(hyper.train_dir, 'embedding_meta.tsv')
    # dump the embedding mapping, sorted by integer id
    items = sorted(word2int.items(), key=operator.itemgetter(1))
    with open(metadata_path, 'w') as f:
        for item in items:
            print(item[0], file=f)

    config = projector.ProjectorConfig()
    config.model_checkpoint_dir = hyper.train_dir
    # The line above does not work yet: TF doesn't support model_checkpoint_dir,
    # so create a symlink from train_dir into log_dir instead.
    os.symlink(os.path.join(hyper.train_dir, 'checkpoint'),
               os.path.join(hyper.log_dir, 'checkpoint'))

    embedding = config.embeddings.add()
    embedding.tensor_name = param.get('We').name
    # Link this tensor to its metadata file (e.g. labels).
    embedding.metadata_path = metadata_path
    # Saves a configuration file that TensorBoard will read during startup.
    projector.visualize_embeddings(writer, config)


def main():
    apputil.initialize(variable_scope='embedding')

    # load data early so we can initialize hyper parameters accordingly
    ds = data.load_dataset('data/statements')
    hyper.node_type_num = len(ds.word2int)
    hyper.dump()

    # create model variables
    param.initialize_embedding_weights()

    # Compile the block
    tree_sum = td.GetItem(0) >> tree_sum_blk(l2loss_blk)
    compiler = td.Compiler.create(tree_sum)
    (batched_loss, ) = compiler.output_tensors
    loss = tf.reduce_mean(batched_loss)

    opt = tf.train.AdamOptimizer(learning_rate=hyper.learning_rate)
    global_step = tf.Variable(0, trainable=False, name='global_step')
    train_step = opt.minimize(loss, global_step=global_step)

    # Attach summaries
    tf.summary.histogram('Wl', param.get('Wl'))
    tf.summary.histogram('Wr', param.get('Wr'))
    tf.summary.histogram('B', param.get('B'))
    tf.summary.histogram('Embedding', param.get('We'))
    tf.summary.scalar('loss', loss)
    summary_op = tf.summary.merge_all()

    # create missing dir
    if not os.path.exists(hyper.train_dir):
        os.makedirs(hyper.train_dir)

    # train loop
    saver = tf.train.Saver()
    train_set = compiler.build_loom_inputs(ds.get_split('all')[1])
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        summary_writer = tf.summary.FileWriter(hyper.log_dir, graph=sess.graph)
        write_embedding_metadata(summary_writer, ds.word2int)

        for epoch, shuffled in enumerate(td.epochs(train_set, hyper.num_epochs), 1):
            for step, batch in enumerate(td.group_by_batches(shuffled, hyper.batch_size), 1):
                train_feed_dict = {compiler.loom_input_tensor: batch}
                start_time = default_timer()
                _, loss_value, summary, gstep = sess.run(
                    [train_step, loss, summary_op, global_step], train_feed_dict)
                duration = default_timer() - start_time

                logger.info('global %d epoch %d step %d loss = %.2f (%.1f samples/sec; %.3f sec/batch)',
                            gstep, epoch, step, loss_value,
                            hyper.batch_size / duration, duration)
                if gstep % 10 == 0:
                    summary_writer.add_summary(summary, gstep)
                if gstep % 100 == 0:
                    saver.save(sess, os.path.join(hyper.train_dir, "model.ckpt"), global_step=gstep)


if __name__ == '__main__':
    main()