"""Question-pooling attention layer.

Collapses a sequence of encoded question vectors into a single vector by an
attention-weighted sum.  The attention parameters are passed to the layer as
extra inputs rather than being created as layer weights.
"""
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from keras import backend as K
from keras.layers import Layer
from keras.layers.wrappers import TimeDistributed

from helpers import compute_mask, softmax


def _require_five(value, what):
    """Raise ValueError unless *value* is a list holding exactly five items.

    Used instead of bare ``assert`` statements, which are removed when Python
    runs with ``-O`` and would let malformed inputs through silently.
    """
    if not (isinstance(value, list) and len(value) == 5):
        raise ValueError(
            "QuestionPooling expects a list of 5 {}, got {!r}".format(what, value))


class QuestionPooling(Layer):
    """Attention-pool a question encoding into one vector per batch item.

    Presumably the question-pooling step of R-NET; verify against the model
    that uses it.

    Inputs (a list of exactly five tensors, in this order):
        uQ:   question encoding.  Treated as rank 3, and its last axis is
              interpreted as 2H; presumably (batch, Q, 2H) -- TODO confirm.
        WQ_u: matrix applied to ``uQ`` with ``K.dot``.
        WQ_v: left factor of a position-independent bias, ``K.dot(WQ_v, VQ_r)``.
        v:    scoring weights applied after the ``tanh``.
        VQ_r: right factor of the bias term.

    For every question position j the layer computes

        s_j = tanh(uQ_j . WQ_u + WQ_v . VQ_r) . v
        a   = softmax(s) over the question axis, using uQ's mask if given
        rQ  = sum_j a_j * uQ_j

    and returns rQ with shape (batch, last dim of uQ).  No mask is
    propagated to the output.
    """

    def __init__(self, **kwargs):
        super(QuestionPooling, self).__init__(**kwargs)
        # Accept an incoming mask so that call() can apply uQ's mask.
        self.supports_masking = True

    def compute_output_shape(self, input_shape):
        """Return ``(batch, features)`` taken from the shape of ``uQ``.

        Raises:
            ValueError: if *input_shape* is not a list of five shapes, or
                uQ's shape is not rank 3.
        """
        _require_five(input_shape, "input shapes")
        # Unpacking also enforces that uQ's shape has exactly three entries.
        B, _, H = input_shape[0]
        return (B, H)

    def build(self, input_shape):
        """Validate the input shapes; the layer creates no weights of its own.

        Raises:
            ValueError: if *input_shape* is not a list of five shapes, or
                uQ's shape is not rank 3.
        """
        _require_five(input_shape, "input shapes")
        if len(input_shape[0]) != 3:
            raise ValueError(
                "QuestionPooling expects uQ to be rank 3, got shape {!r}".format(
                    input_shape[0]))
        # All attention parameters arrive as inputs, so there is nothing to
        # create; finish with the base implementation as Keras layers should.
        super(QuestionPooling, self).build(input_shape)

    def call(self, inputs, mask=None):
        """Return the attention-pooled question vector ``rQ``.

        Args:
            inputs: list ``[uQ, WQ_u, WQ_v, v, VQ_r]``.
            mask: optional list of masks aligned with *inputs*; only the
                first entry (uQ's mask) is used.

        Raises:
            ValueError: if *inputs* is not a list of five tensors.
        """
        _require_five(inputs, "inputs")
        uQ, WQ_u, WQ_v, v, VQ_r = inputs
        uQ_mask = mask[0] if mask is not None else None

        # (batch, 1, features) tensor of ones.  Dotting the bias through it
        # gives the bias a batch axis so it broadcasts over question positions.
        ones = K.ones_like(K.sum(uQ, axis=1, keepdims=True))
        s_hat = K.dot(uQ, WQ_u)
        s_hat += K.dot(ones, K.dot(WQ_v, VQ_r))
        s_hat = K.tanh(s_hat)

        # One score per position, assuming v yields a single column --
        # flattened to (batch, Q).
        s = K.batch_flatten(K.dot(s_hat, v))

        # helpers.softmax, presumably a masked softmax over the question axis.
        a = softmax(s, mask=uQ_mask, axis=1)

        # Weighted sum over the question axis: contract axis 1 of uQ with a.
        rQ = K.batch_dot(uQ, a, axes=[1, 1])
        return rQ

    def compute_mask(self, inputs, mask=None):
        """Return None: the pooled output has no sequence axis left to mask."""
        return None