# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Sequence-to-vector (s2v) sentence encoder model."""

import collections

import numpy as np
import tensorflow as tf

from ops import input_ops

FLAGS = tf.flags.FLAGS


def read_vocab_embs(vocabulary_file, embedding_matrix_file):
  tf.logging.info("Reading vocabulary from %s", vocabulary_file)
  with tf.gfile.GFile(vocabulary_file, mode="rb") as f:
    vocab = [line.decode("utf-8").strip() for line in f]

  tf.logging.info("Loading embedding matrix from %s", embedding_matrix_file)
  # Note: tf.gfile.GFile doesn't work here because np.load() calls f.seek()
  # with 3 arguments.
  with open(embedding_matrix_file, "rb") as f:
    embedding_matrix = np.load(f)
  tf.logging.info("Loaded embedding matrix with shape %s",
                  embedding_matrix.shape)

  word_embedding_dict = collections.OrderedDict(
      zip(vocab, embedding_matrix))
  return word_embedding_dict


def read_vocab(vocabulary_file):
  tf.logging.info("Reading vocabulary from %s", vocabulary_file)
  with tf.gfile.GFile(vocabulary_file, mode="rb") as f:
    reverse_vocab = [line.decode("utf-8").strip() for line in f]
  tf.logging.info("Loaded vocabulary with %d words.", len(reverse_vocab))

  # Map each word to its row index in the vocabulary.
  word_embedding_dict = collections.OrderedDict(
      zip(reverse_vocab, range(len(reverse_vocab))))
  return word_embedding_dict

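# Illustrative sketch (assumptions, not executed by this module): vocabulary
# files are expected to hold one token per line. For a file containing
#
#   <unk>
#   the
#   of
#
# read_vocab returns OrderedDict([(u"<unk>", 0), (u"the", 1), (u"of", 2)]),
# mapping each word to its row index, while read_vocab_embs pairs the same
# tokens with the rows of the .npy matrix, which is assumed to hold one
# embedding row per vocabulary line:
#
#   vocab_ids = read_vocab("vocab.txt")                    # word -> index
#   vocab_embs = read_vocab_embs("vocab.txt", "embs.npy")  # word -> vector
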

class s2v(object):
  """Skip-thoughts model."""

  def __init__(self, config, mode="train", input_reader=None,
               input_queue=None):
    """Basic setup. The actual TensorFlow graph is constructed in build().

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "encode".
      input_reader: Subclass of tf.ReaderBase for reading the input serialized
        tf.Example protocol buffers. Defaults to TFRecordReader.
      input_queue: Optional pre-built queue of serialized tf.Example protos.

    Raises:
      ValueError: If mode is invalid.
    """
    if mode not in ["train", "eval", "encode"]:
      raise ValueError("Unrecognized mode: %s" % mode)

    self.config = config
    self.mode = mode
    self.reader = input_reader if input_reader else tf.TFRecordReader()
    self.input_queue = input_queue

    # Initializer used for non-recurrent weights.
    self.uniform_initializer = tf.random_uniform_initializer(
        minval=-FLAGS.uniform_init_scale,
        maxval=FLAGS.uniform_init_scale)

    # Source sentences represented as sequences of word ids.
    # An int64 Tensor with shape [batch_size, padded_length].
    self.encode_ids = None

    # Boolean mask distinguishing real words (1) from padded words (0).
    # An integer Tensor with shape [batch_size, padded_length].
    self.encode_mask = None

    # Source sentences represented as sequences of word embeddings, one entry
    # each for the input ("") and output ("_out") encoders. A list of two
    # float32 Tensors with shape [batch_size, padded_length, emb_dim].
    self.encode_emb = None

    # The outputs of the two sentence encoders.
    # A list of two float32 Tensors with shape [batch_size, encoder_dim].
    self.thought_vectors = None

    # The total loss to optimize.
    self.total_loss = None

  def build_inputs(self):
    if self.mode == "encode":
      encode_ids = tf.placeholder(tf.int64, (None, None), name="encode_ids")
      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          FLAGS.input_file_pattern,
          shuffle=FLAGS.shuffle_input_data,
          capacity=FLAGS.input_queue_capacity,
          num_reader_threads=FLAGS.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(FLAGS.batch_size)
      encode = input_ops.parse_example_batch(serialized)

      encode_ids = encode.ids
      encode_mask = encode.mask

    self.encode_ids = encode_ids
    self.encode_mask = encode_mask

  def build_word_embeddings(self):
    rand_init = self.uniform_initializer
    self.word_embeddings = []
    self.encode_emb = []
    self.init = None
    for v in self.config.vocab_configs:
      if v.mode == "fixed":
        if self.mode == "train":
          word_emb = tf.get_variable(
              name=v.name,
              shape=[v.size, v.dim],
              trainable=False)
          embedding_placeholder = tf.placeholder(
              tf.float32, [v.size, v.dim])
          embedding_init = word_emb.assign(embedding_placeholder)

          # Row 0 (e.g. <unk>) is randomly initialized; the pretrained
          # vectors fill the remaining v.size - 1 rows.
          rand = np.random.rand(1, v.dim)
          word_vecs = np.load(v.embs_file)
          load_vocab_size = word_vecs.shape[0]
          assert load_vocab_size == v.size - 1
          word_init = np.concatenate((rand, word_vecs), axis=0)
          self.init = (embedding_init, embedding_placeholder, word_init)
        else:
          word_emb = tf.get_variable(
              name=v.name,
              shape=[v.size, v.dim])

        encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
        self.word_emb = word_emb
        # The same fixed embeddings feed both the input and output encoders.
        self.encode_emb.extend([encode_emb, encode_emb])

      if v.mode == "trained":
        for inout in ["", "_out"]:
          word_emb = tf.get_variable(
              name=v.name + inout,
              shape=[v.size, v.dim],
              initializer=rand_init)
          if self.mode == "train":
            self.word_embeddings.append(word_emb)

          encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
          self.encode_emb.append(encode_emb)

      if v.mode == "expand":
        for inout in ["", "_out"]:
          encode_emb = tf.placeholder(tf.float32, (
              None, None, v.dim), v.name + inout)
          self.encode_emb.append(encode_emb)
          word_emb_dict = read_vocab_embs(v.vocab_file + inout + ".txt",
                                          v.embs_file + inout + ".npy")
          self.word_embeddings.append(word_emb_dict)

      if v.mode != "expand" and self.mode == "encode":
        word_emb_dict = read_vocab(v.vocab_file)
        self.word_embeddings.extend([word_emb_dict, word_emb_dict])

  def _initialize_cell(self, num_units, cell_type="GRU"):
    if cell_type == "GRU":
      return tf.contrib.rnn.GRUCell(num_units=num_units)
    elif cell_type == "LSTM":
      return tf.contrib.rnn.LSTMCell(num_units=num_units)
    else:
      raise ValueError("Invalid cell type")

  def bow(self, word_embs, mask):
    # Zero out the embeddings of padded positions, then sum over time.
    mask_f = tf.expand_dims(tf.cast(mask, tf.float32), -1)
    word_embs_mask = word_embs * mask_f
    bow = tf.reduce_sum(word_embs_mask, axis=1)
    return bow
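
  # Illustrative sketch of bow() with hypothetical values:
  #   word_embs = [[[1., 1.], [2., 2.], [9., 9.]]]  # [batch=1, time=3, dim=2]
  #   mask      = [[1, 1, 0]]                       # third position is padding
  # The cast/expand_dims mask zeroes the padded embedding, and the sum over
  # the time axis yields [[3., 3.]]: an unnormalized bag-of-words vector.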

  def rnn(self, word_embs, mask, scope, encoder_dim, cell_type="GRU"):
    # The number of real (non-padded) words in each sentence.
    length = tf.to_int32(tf.reduce_sum(mask, 1), name="length")
    if self.config.bidir:
      if encoder_dim % 2:
        raise ValueError(
            "encoder_dim must be even when using a bidirectional encoder.")
      num_units = encoder_dim // 2
      cell_fw = self._initialize_cell(num_units, cell_type=cell_type)
      cell_bw = self._initialize_cell(num_units, cell_type=cell_type)
      outputs, states = tf.nn.bidirectional_dynamic_rnn(
          cell_fw=cell_fw,
          cell_bw=cell_bw,
          inputs=word_embs,
          sequence_length=length,
          dtype=tf.float32,
          scope=scope)
      if cell_type == "LSTM":
        # Keep only the hidden state h from each (c, h) LSTMStateTuple.
        states = [states[0][1], states[1][1]]
      state = tf.concat(states, 1)
    else:
      cell = self._initialize_cell(encoder_dim, cell_type=cell_type)
      outputs, state = tf.nn.dynamic_rnn(
          cell=cell,
          inputs=word_embs,
          sequence_length=length,
          dtype=tf.float32,
          scope=scope)
      if cell_type == "LSTM":
        state = state[1]
    return state

  def build_encoder(self):
    """Builds the sentence encoder.

    Inputs:
      self.encode_emb
      self.encode_mask

    Outputs:
      self.thought_vectors

    Raises:
      ValueError: if config.bidir is True and config.encoder_dim is odd.
    """
    names = ["", "_out"]
    self.thought_vectors = []
    tf.logging.info("Using encoder: %s", self.config.encoder)
    for i in range(2):
      with tf.variable_scope("encoder" + names[i]) as scope:
        if self.config.encoder == "gru":
          sent_rep = self.rnn(self.encode_emb[i], self.encode_mask, scope,
                              self.config.encoder_dim, cell_type="GRU")
        elif self.config.encoder == "lstm":
          sent_rep = self.rnn(self.encode_emb[i], self.encode_mask, scope,
                              self.config.encoder_dim, cell_type="LSTM")
        elif self.config.encoder == "bow":
          sent_rep = self.bow(self.encode_emb[i], self.encode_mask)
        else:
          raise ValueError("Invalid encoder")

        thought_vectors = tf.identity(sent_rep, name="thought_vectors")
        self.thought_vectors.append(thought_vectors)

  def build_loss(self):
    """Builds the loss Tensor.

    Outputs:
      self.total_loss
    """
    all_sen_embs = self.thought_vectors

    if FLAGS.dropout:
      # Apply the same random binary mask to both encoders' outputs.
      mask_shp = [1, self.config.encoder_dim]
      bin_mask = tf.random_uniform(mask_shp) > FLAGS.dropout_rate
      bin_mask = tf.where(bin_mask, tf.ones(mask_shp), tf.zeros(mask_shp))
      src = all_sen_embs[0] * bin_mask
      dst = all_sen_embs[1] * bin_mask
      scores = tf.matmul(src, dst, transpose_b=True)
    else:
      scores = tf.matmul(all_sen_embs[0], all_sen_embs[1], transpose_b=True)

    # Ignore the score of each sentence with itself.
    scores = tf.matrix_set_diag(scores,
                                np.zeros(FLAGS.batch_size, dtype=np.float32))

    # Targets: each sentence should predict its neighbors within
    # context_size positions, with uniform weight over the neighbors.
    targets_np = np.zeros((FLAGS.batch_size, FLAGS.batch_size))
    ctxt_sent_pos = list(range(-FLAGS.context_size, FLAGS.context_size + 1))
    ctxt_sent_pos.remove(0)
    for ctxt_pos in ctxt_sent_pos:
      targets_np += np.eye(FLAGS.batch_size, k=ctxt_pos)
    targets_np_sum = np.sum(targets_np, axis=1, keepdims=True)
    targets_np = targets_np / targets_np_sum
    targets = tf.constant(targets_np, dtype=tf.float32)

    # Forward and backward scores.
    f_scores = scores[:-1]
    b_scores = scores[1:]

    losses = tf.nn.softmax_cross_entropy_with_logits(
        labels=targets, logits=scores)
    loss = tf.reduce_mean(losses)

    tf.summary.scalar("losses/ent_loss", loss)
    self.total_loss = loss

    if self.mode == "eval":
      f_max = tf.to_int64(tf.argmax(f_scores, axis=1))
      b_max = tf.to_int64(tf.argmax(b_scores, axis=1))

      targets = list(range(FLAGS.batch_size - 1))
      targets = tf.constant(targets, dtype=tf.int64)
      fwd_targets = targets + 1
      names_to_values, names_to_updates = (
          tf.contrib.slim.metrics.aggregate_metric_map({
              "Acc/Fwd Acc": tf.contrib.slim.metrics.streaming_accuracy(
                  f_max, fwd_targets),
              "Acc/Bwd Acc": tf.contrib.slim.metrics.streaming_accuracy(
                  b_max, targets)
          }))
      for name, value in names_to_values.items():
        tf.summary.scalar(name, value)
      self.eval_op = list(names_to_updates.values())

  def build(self):
    """Creates all ops for training, evaluation or encoding."""
    self.build_inputs()
    self.build_word_embeddings()
    self.build_encoder()
    self.build_loss()

  def build_enc(self):
    """Creates all ops for encoding only; no loss is built."""
    self.build_inputs()
    self.build_word_embeddings()
    self.build_encoder()
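

# Minimal usage sketch (hypothetical; assumes the FLAGS consumed above, such
# as batch_size, context_size and input_file_pattern, have been defined and
# parsed by an accompanying training script, and that `model_config()` builds
# a config object with vocab_configs, encoder, encoder_dim and bidir set):
#
#   config = model_config()
#   model = s2v(config, mode="train")
#   model.build()  # inputs -> word embeddings -> encoder -> loss
#   train_op = tf.contrib.layers.optimize_loss(
#       loss=model.total_loss,
#       global_step=tf.train.get_or_create_global_step(),
#       learning_rate=0.0005,
#       optimizer="Adam")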