import numpy as np
import tensorflow as tf

from babi.base_model import BaseTower, BaseRunner
from my.tensorflow.nn import linear


class Embedder(object):
    def __call__(self, content):
        # Abstract interface; subclasses implement the actual lookup.
        raise NotImplementedError()


class VariableEmbedder(Embedder):
    def __init__(self, params, wd=0.0, initializer=None, name="variable_embedder"):
        V, d = params.vocab_size, params.hidden_size
        with tf.variable_scope(name):
            self.emb_mat = tf.get_variable("emb_mat", dtype='float', shape=[V, d], initializer=initializer)
            # TODO: not sure wd is appropriate for the embedding matrix
            if wd:
                weight_decay = tf.mul(tf.nn.l2_loss(self.emb_mat), wd, name='weight_loss')
                tf.add_to_collection('losses', weight_decay)

    def __call__(self, word, name="embedded_content"):
        out = tf.nn.embedding_lookup(self.emb_mat, word, name=name)
        return out


class PositionEncoder(object):
    def __init__(self, max_sent_size, hidden_size):
        self.max_sent_size, self.hidden_size = max_sent_size, hidden_size
        J, d = max_sent_size, hidden_size
        with tf.name_scope("pe_constants"):
            b = [1 - k / d for k in range(1, d + 1)]
            w = [[j * (2 * k / d - 1) for k in range(1, d + 1)] for j in range(1, J + 1)]
            self.b = tf.constant(b, shape=[d])
            self.w = tf.constant(w, shape=[J, d])

    def __call__(self, Ax, mask, scope=None):
        with tf.name_scope(scope or "position_encoder"):
            shape = Ax.get_shape().as_list()
            length_dim_index = len(shape) - 2
            length = tf.reduce_sum(tf.cast(mask, 'float'), length_dim_index)
            length = tf.maximum(length, 1.0)  # masked sentences will have length 0
            length_aug = tf.expand_dims(tf.expand_dims(length, -1), -1)
            # l = self.b + self.w / length_aug
            l = self.b + self.w / self.max_sent_size
            mask_aug = tf.expand_dims(mask, -1)
            f = tf.reduce_sum(Ax * l * tf.cast(mask_aug, 'float'), length_dim_index, name='f')  # [N, S, d]
            return f


class VariablePositionEncoder(object):
    def __init__(self, max_sent_size, hidden_size, scope=None):
        self.max_sent_size, self.hidden_size = max_sent_size, hidden_size
        J, d = max_sent_size, hidden_size
        with tf.variable_scope(scope or self.__class__.__name__):
            self.w = tf.get_variable('w', shape=[J, d], dtype='float')

    def __call__(self, Ax, mask, scope=None):
        with tf.name_scope(scope or self.__class__.__name__):
            shape = Ax.get_shape().as_list()
            length_dim_index = len(shape) - 2
            mask_aug = tf.expand_dims(mask, -1)
            f = tf.reduce_sum(Ax * self.w * tf.cast(mask_aug, 'float'), length_dim_index, name='f')  # [N, S, d]
            return f
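
# Standalone sketch for reference only: what the PositionEncoder weights look
# like in numpy, with hypothetical sizes J=4, d=6. With the default
# l = b + w / max_sent_size used above, the weight for word position j and
# dimension k appears to reduce to (1 - j/J) - (k/d) * (1 - 2*j/J), the usual
# position-encoding scheme for bag-of-words sentence encoders.
def _position_encoding_weights_sketch(J=4, d=6):
    b = np.array([1 - k / d for k in range(1, d + 1)], dtype='float32')          # [d]
    w = np.array([[j * (2 * k / d - 1) for k in range(1, d + 1)]
                  for j in range(1, J + 1)], dtype='float32')                    # [J, d]
    l = b + w / J                                                                # [J, d], as in PositionEncoder.__call__
    l_ref = np.array([[(1 - j / J) - (k / d) * (1 - 2 * j / J)
                       for k in range(1, d + 1)]
                      for j in range(1, J + 1)], dtype='float32')
    assert np.allclose(l, l_ref)
    return l
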
class ReductionLayer(object):
    def __init__(self, batch_size, mem_size, hidden_size):
        self.hidden_size = hidden_size
        self.mem_size = mem_size
        self.batch_size = batch_size
        N, M, d = batch_size, mem_size, hidden_size
        self.L = np.tril(np.ones([M, M], dtype='float32'))         # lower-triangular ones (incl. diagonal)
        self.sL = np.tril(np.ones([M, M], dtype='float32'), k=-1)  # strictly lower-triangular ones

    def __call__(self, u_t, a, b, scope=None):
        """
        :param u_t: [N, M, d]
        :param a: [N, M, 1]
        :param b: [N, M, 1]
        :return: [N, M, d]
        """
        N, M, d = self.batch_size, self.mem_size, self.hidden_size
        L, sL = self.L, self.sL
        with tf.name_scope(scope or self.__class__.__name__):
            L = tf.tile(tf.expand_dims(L, 0), [N, 1, 1])
            sL = tf.tile(tf.expand_dims(sL, 0), [N, 1, 1])
            logb = tf.log(b + 1e-9)
            # The first step's log(b) never enters the cumulative products, so replace it with zeros.
            logb = tf.concat(1, [tf.zeros([N, 1, 1]), tf.slice(logb, [0, 1, 0], [-1, -1, -1])])
            left = L * tf.exp(tf.batch_matmul(L, logb * sL))  # [N, M, M]
            right = a * u_t  # [N, M, d]
            u = tf.batch_matmul(left, right)  # [N, M, d]
        return u


class VectorReductionLayer(object):
    def __init__(self, batch_size, mem_size, hidden_size):
        self.hidden_size = hidden_size
        self.mem_size = mem_size
        self.batch_size = batch_size
        N, M, d = batch_size, mem_size, hidden_size
        self.L = np.tril(np.ones([M, M], dtype='float32'))         # lower-triangular ones (incl. diagonal)
        self.sL = np.tril(np.ones([M, M], dtype='float32'), k=-1)  # strictly lower-triangular ones

    def __call__(self, u_t, a, b, scope=None):
        """
        :param u_t: [N, M, d]
        :param a: [N, M, d]
        :param b: [N, M, d]
        :return: [N, M, d]
        """
        N, M, d = self.batch_size, self.mem_size, self.hidden_size
        L, sL = self.L, self.sL
        with tf.name_scope(scope or self.__class__.__name__):
            L = tf.tile(tf.expand_dims(tf.expand_dims(L, 0), 0), [N, d, 1, 1])
            sL = tf.tile(tf.expand_dims(tf.expand_dims(sL, 0), 0), [N, d, 1, 1])
            logb = tf.log(b + 1e-9)  # [N, M, d]
            logb = tf.concat(1, [tf.zeros([N, 1, d]), tf.slice(logb, [0, 1, 0], [-1, -1, -1])])  # [N, M, d]
            logb = tf.expand_dims(tf.transpose(logb, [0, 2, 1]), -1)  # [N, d, M, 1]
            left = L * tf.exp(tf.batch_matmul(L, logb * sL))  # [N, d, M, M]
            right = a * u_t  # [N, M, d]
            right = tf.expand_dims(tf.transpose(right, [0, 2, 1]), -1)  # [N, d, M, 1]
            u = tf.batch_matmul(left, right)  # [N, d, M, 1]
            u = tf.transpose(tf.squeeze(u, [3]), [0, 2, 1])  # [N, M, d]
        return u
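
# Standalone numpy sketch of the recurrence that ReductionLayer evaluates in
# closed form. With b = 1 - a (how Tower calls it below), the layer computes
#     h_i = a_i * u_i + (1 - a_i) * h_{i-1},   with h_{-1} = 0,
# for every position at once via the lower-triangular / log-space matmul trick
# above. Sizes and random values here are hypothetical; shapes follow a single
# example (M memories, d hidden units).
def _reduction_layer_sketch(M=5, d=3, seed=0):
    rng = np.random.RandomState(seed)
    u_t = rng.randn(M, d).astype('float32')          # candidate updates, [M, d]
    a = rng.uniform(size=(M, 1)).astype('float32')   # scalar gates in (0, 1), [M, 1]
    b = 1.0 - a

    # Sequential reference implementation of the recurrence.
    h_seq = np.zeros((M, d), dtype='float32')
    prev = np.zeros(d, dtype='float32')
    for i in range(M):
        prev = a[i] * u_t[i] + b[i] * prev
        h_seq[i] = prev

    # Closed form: h_i = sum_{j<=i} (prod_{p=j+1..i} b_p) * a_j * u_t_j.
    L = np.tril(np.ones((M, M), dtype='float32'))
    sL = np.tril(np.ones((M, M), dtype='float32'), k=-1)
    logb = np.log(b + 1e-9)
    left = L * np.exp(L @ (logb * sL))               # [M, M]: cumulative products of b
    h_mat = left @ (a * u_t)                         # [M, d]

    assert np.allclose(h_seq, h_mat, atol=1e-4)
    return h_mat
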
class Tower(BaseTower):
    def initialize(self):
        params = self.params
        placeholders = self.placeholders
        tensors = self.tensors
        variables_dict = self.variables_dict
        N, J, V, Q, M = params.batch_size, params.max_sent_size, params.vocab_size, params.max_ques_size, params.mem_size
        d = params.hidden_size
        L = params.mem_num_layers
        att_forget_bias = params.att_forget_bias
        use_vector_gate = params.use_vector_gate
        wd = params.wd
        initializer = tf.random_uniform_initializer(-np.sqrt(3), np.sqrt(3))

        with tf.name_scope("placeholders"):
            x = tf.placeholder('int32', shape=[N, M, J], name='x')
            x_mask = tf.placeholder('bool', shape=[N, M, J], name='x_mask')
            q = tf.placeholder('int32', shape=[N, J], name='q')
            q_mask = tf.placeholder('bool', shape=[N, J], name='q_mask')
            y = tf.placeholder('int32', shape=[N], name='y')
            is_train = tf.placeholder('bool', shape=[], name='is_train')
            placeholders['x'] = x
            placeholders['x_mask'] = x_mask
            placeholders['q'] = q
            placeholders['q_mask'] = q_mask
            placeholders['y'] = y
            placeholders['is_train'] = is_train

        with tf.variable_scope("embedding"):
            A = VariableEmbedder(params, wd=wd, initializer=initializer, name='A')
            Aq = A(q, name='Aq')  # [N, J, d]
            Ax = A(x, name='Ax')  # [N, M, J, d]

        with tf.name_scope("encoding"):
            encoder = PositionEncoder(J, d)
            u = encoder(Aq, q_mask)  # [N, d]
            m = encoder(Ax, x_mask)  # [N, M, d]

        with tf.variable_scope("networks"):
            m_mask = tf.reduce_max(tf.cast(x_mask, 'int64'), 2, name='m_mask')  # [N, M]
            gate_mask = tf.expand_dims(m_mask, -1)
            m_length = tf.reduce_sum(m_mask, 1, name='m_length')  # [N]
            prev_u = tf.tile(tf.expand_dims(u, 1), [1, M, 1])  # [N, M, d]
            reg_layer = VectorReductionLayer(N, M, d) if use_vector_gate else ReductionLayer(N, M, d)
            gate_size = d if use_vector_gate else 1
            h = None  # [N, M, d]
            as_, rfs, rbs = [], [], []
            hs = []
            for layer_idx in range(L):
                with tf.name_scope("layer_{}".format(layer_idx)):
                    u_t = tf.tanh(linear([prev_u, m], d, True, wd=wd, scope='u_t'))
                    a = tf.cast(gate_mask, 'float') * tf.sigmoid(
                        linear([prev_u * m], gate_size, True, initializer=initializer, wd=wd, scope='a') - att_forget_bias)
                    h = reg_layer(u_t, a, 1.0 - a, scope='h')
                    if layer_idx + 1 < L:
                        if params.use_reset:
                            rf, rb = tf.split(2, 2, tf.cast(gate_mask, 'float') * tf.sigmoid(
                                linear([prev_u * m], 2 * gate_size, True, initializer=initializer, wd=wd, scope='r')))
                        else:
                            rf = rb = tf.ones(a.get_shape().as_list())
                        u_t_rev = tf.reverse_sequence(u_t, m_length, 1)
                        a_rev, rb_rev = tf.reverse_sequence(a, m_length, 1), tf.reverse_sequence(rb, m_length, 1)
                        uf = reg_layer(u_t, a * rf, 1.0 - a, scope='uf')
                        ub_rev = reg_layer(u_t_rev, a_rev * rb_rev, 1.0 - a_rev, scope='ub_rev')
                        ub = tf.reverse_sequence(ub_rev, m_length, 1)
                        prev_u = uf + ub
                    else:
                        rf = rb = tf.zeros(a.get_shape().as_list())
                    rfs.append(rf)
                    rbs.append(rb)
                    as_.append(a)
                    hs.append(h)
                    tf.get_variable_scope().reuse_variables()

            h_last = tf.squeeze(tf.slice(h, [0, M - 1, 0], [-1, -1, -1]), [1])  # [N, d]
            hs_last = [tf.squeeze(tf.slice(each, [0, M - 1, 0], [-1, -1, -1]), [1]) for each in hs]
            a = tf.transpose(tf.pack(as_, name='a'), [1, 0, 2, 3])
            rf = tf.transpose(tf.pack(rfs, name='rf'), [1, 0, 2, 3])
            rb = tf.transpose(tf.pack(rbs, name='rb'), [1, 0, 2, 3])
            tensors['a'] = a
            tensors['rf'] = rf
            tensors['rb'] = rb

        with tf.variable_scope("class"):
            class_mode = params.class_mode
            use_class_bias = params.use_class_bias
            if class_mode == 'h':
                # W = tf.transpose(A.emb_mat, name='W')
                logits = linear([h_last], V, use_class_bias, wd=wd)
            elif class_mode == 'uh':
                logits = linear([h_last, u], V, use_class_bias, wd=wd)
            elif class_mode == 'hs':
                logits = linear(hs_last, V, use_class_bias, wd=wd)
            elif class_mode == 'hss':
                logits = linear(sum(hs_last), V, use_class_bias, wd=wd)
            else:
                raise Exception("Invalid class mode: {}".format(class_mode))
            yp = tf.cast(tf.argmax(logits, 1), 'int32')
            correct = tf.equal(yp, y)
            tensors['yp'] = yp
            tensors['correct'] = correct

        with tf.name_scope("loss"):
            with tf.name_scope("ans_loss"):
                ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, y, name='ce')
                avg_ce = tf.reduce_mean(ce, name='avg_ce')
                tf.add_to_collection('losses', avg_ce)

            losses = tf.get_collection('losses')
            loss = tf.add_n(losses, name='loss')
            tensors['loss'] = loss

        variables_dict['all'] = tf.trainable_variables()

    def get_feed_dict(self, batch, mode, **kwargs):
        params = self.params
        N, J, V, M = params.batch_size, params.max_sent_size, params.vocab_size, params.mem_size
        x = np.zeros([N, M, J], dtype='int32')
        x_mask = np.zeros([N, M, J], dtype='bool')
        q = np.zeros([N, J], dtype='int32')
        q_mask = np.zeros([N, J], dtype='bool')
        y = np.zeros([N], dtype='int32')

        ph = self.placeholders
        feed_dict = {ph['x']: x, ph['x_mask']: x_mask,
                     ph['q']: q, ph['q_mask']: q_mask,
                     ph['y']: y,
                     ph['is_train']: mode == 'train'}
        if batch is None:
            return feed_dict

        X, Q, S, Y, H, T = batch
        for i, para in enumerate(X):
            if len(para) > M:
                para = para[-M:]
            for jj, sent in enumerate(para):
                # j = len(para) - jj - 1  # reversing story sequence, last to first
                j = jj
                for k, word in enumerate(sent):
                    x[i, j, k] = word
                    x_mask[i, j, k] = True
        for i, ques in enumerate(Q):
            for j, word in enumerate(ques):
                q[i, j] = word
                q_mask[i, j] = True
        for i, ans in enumerate(Y):
            y[i] = ans

        return feed_dict


class Runner(BaseRunner):
    def _get_train_op(self, **kwargs):
        return self.train_ops['all']
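

# Standalone sketch of how get_feed_dict above packs a batch into the padded
# index/mask arrays that the placeholders expect. The word ids and sizes below
# are hypothetical; in the repo they come from the bAbI vocabulary and params.
def _pack_batch_sketch():
    N, M, J = 2, 3, 4                                   # batch size, memory slots, max sentence length
    X = [[[1, 2], [3, 4, 5]],                           # story 1: two sentences of word ids
         [[6, 7, 8, 9]]]                                # story 2: one sentence
    Q = [[10, 11],                                      # question for story 1
         [12, 13, 14]]                                  # question for story 2
    x = np.zeros([N, M, J], dtype='int32')
    x_mask = np.zeros([N, M, J], dtype='bool')
    q = np.zeros([N, J], dtype='int32')
    q_mask = np.zeros([N, J], dtype='bool')
    for i, para in enumerate(X):
        para = para[-M:]                                # keep only the most recent M sentences
        for j, sent in enumerate(para):
            for k, word in enumerate(sent):
                x[i, j, k] = word
                x_mask[i, j, k] = True
    for i, ques in enumerate(Q):
        for j, word in enumerate(ques):
            q[i, j] = word
            q_mask[i, j] = True
    return x, x_mask, q, q_mask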