# -*- coding: UTF-8 -*-
# !/usr/bin/python
# @time     :2019/7/22 9:36
# @author   :Mo
# @function :

import keras.backend as K
from keras.layers import Layer


class TriglePositiomEmbedding(Layer):
    """Sinusoidal (sine/cosine) position embedding.

    See: https://arxiv.org/pdf/1706.03762

    Expand mode:
        # Input shape
            2D tensor with shape: `(batch_size, sequence_length)`.
        # Output shape
            3D tensor with shape: `(batch_size, sequence_length, output_dim)`.
    Add mode:
        # Input shape
            3D tensor with shape: `(batch_size, sequence_length, feature_dim)`.
        # Output shape
            3D tensor with shape: `(batch_size, sequence_length, feature_dim)`.
    Concat mode:
        # Input shape
            3D tensor with shape: `(batch_size, sequence_length, feature_dim)`.
        # Output shape
            3D tensor with shape: `(batch_size, sequence_length, feature_dim + output_dim)`.
    """

    MODE_EXPAND = 'expand'
    MODE_ADD = 'add'
    MODE_CONCAT = 'concat'

    def __init__(self, mode=MODE_ADD, output_dim=None, **kwargs):
        """
        :param mode: one of ``'expand'``, ``'add'`` or ``'concat'``.
        :param output_dim: the embedding dimension; required (and must be even)
            for ``'expand'`` and ``'concat'`` modes, unused in ``'add'`` mode.
        :param kwargs: forwarded to the ``Layer`` base class.
        """
        # expand/concat need an explicit, even output_dim; add mode reuses the
        # input's own feature dimension at call time.
        if mode in [self.MODE_EXPAND, self.MODE_CONCAT]:
            if output_dim is None:
                raise NotImplementedError('`output_dim` is required in `%s` mode' % mode)
            if output_dim % 2 != 0:
                raise NotImplementedError('It does not make sense to use an odd output dimension: %d' % output_dim)
        self.mode = mode
        self.output_dim = output_dim
        self.supports_masking = True
        super(TriglePositiomEmbedding, self).__init__(**kwargs)

    def get_config(self):
        """Return the layer config so the layer can be re-created from it."""
        merged = dict(super(TriglePositiomEmbedding, self).get_config())
        merged.update({
            'mode': self.mode,
            'output_dim': self.output_dim,
        })
        return merged

    def compute_mask(self, inputs, mask=None):
        # The positional encoding never hides timesteps: pass the mask through.
        return mask

    def compute_output_shape(self, input_shape):
        if self.mode == self.MODE_EXPAND:
            # (batch, seq) -> (batch, seq, output_dim)
            return input_shape + (self.output_dim,)
        if self.mode == self.MODE_CONCAT:
            # (batch, seq, feat) -> (batch, seq, feat + output_dim)
            return input_shape[:-1] + (input_shape[-1] + self.output_dim,)
        # add mode keeps the input shape unchanged
        return input_shape

    def call(self, inputs, mask=None):
        shape = K.shape(inputs)
        if self.mode == self.MODE_ADD:
            # Encoding dim follows the input's feature dim; positions are 0..seq-1
            # replicated across the batch.
            batch_size, seq_len = shape[0], shape[1]
            enc_dim = shape[2]
            positions = K.tile(K.expand_dims(K.arange(seq_len), axis=0), [batch_size, 1])
        elif self.mode == self.MODE_CONCAT:
            batch_size, seq_len = shape[0], shape[1]
            enc_dim = self.output_dim
            positions = K.tile(K.expand_dims(K.arange(seq_len), axis=0), [batch_size, 1])
        else:
            # expand mode: the inputs themselves are the position indices.
            enc_dim = self.output_dim
            positions = inputs
        if K.dtype(positions) != K.floatx():
            positions = K.cast(positions, K.floatx())
        # Frequency indices: evens drive the sine channels, odds the cosine
        # channels; (odd - 1) == even, so each sin/cos pair shares a frequency
        # as in the original transformer formulation.
        even_idx = K.arange(enc_dim // 2) * 2
        odd_idx = K.arange(enc_dim // 2) * 2 + 1
        pos_col = K.expand_dims(positions, -1)
        dim_f = K.cast(enc_dim, K.floatx())
        sin_part = K.sin(
            K.dot(
                pos_col,
                K.expand_dims(1.0 / K.pow(
                    10000.0,
                    K.cast(even_idx, K.floatx()) / dim_f
                ), 0)
            )
        )
        cos_part = K.cos(
            K.dot(
                pos_col,
                K.expand_dims(1.0 / K.pow(
                    10000.0,
                    K.cast((odd_idx - 1), K.floatx()) / dim_f
                ), 0)
            )
        )
        # Interleave sin/cos along the last axis, then flatten the pair axis
        # back into the embedding dimension.
        interleaved = K.stack([sin_part, cos_part], axis=-1)
        output = K.reshape(interleaved, [-1, K.shape(inputs)[1], enc_dim])
        if self.mode == self.MODE_CONCAT:
            output = K.concatenate([inputs, output], axis=-1)
        if self.mode == self.MODE_ADD:
            output += inputs
        return output