""" Tools that are not necessary for the Transformer by itself, but might be useful in building models with it. """ import math from keras import activations, regularizers # noinspection PyPep8Naming from keras import backend as K from keras.engine import Layer from keras.layers import Embedding from keras.utils import get_custom_objects class ReusableEmbedding(Embedding): """ A "reusable" form of the Embedding layer, which returns its full embedding matrix as one of the outputs. This is necessary to guarantee correct work of Keras when the matrix is being re-used again in TiedOutputEmbedding layer. """ def call(self, inputs, **kwargs): result = super().call(inputs, **kwargs) return [result, self.embeddings] def compute_output_shape(self, input_shape): return [super().compute_output_shape(input_shape), K.int_shape(self.embeddings)] def compute_mask(self, inputs, mask=None): return [super().compute_mask(inputs, mask), None] class TiedOutputEmbedding(Layer): """ Allows to reuse the same word embedding matrix both for the input and the output layers of the network. This is called Weight Tying and is proven to improve performance of neural network language models, as well as decrease their number of parameters (eliminating the need for a separate huge matrix of output weights). The layers is supposed to be called with two inputs, like TiedOutputEmbedding()([main_input, embedding_matrix]) where the `main_input` is the output of the previous layer (like LSTM) and the `embedding_matrix` coming from the `ReusableEmbedding` layer. https://arxiv.org/abs/1608.05859 https://arxiv.org/abs/1611.01462 https://blog.openai.com/language-unsupervised/ """ def __init__(self, activation=None, add_biases=False, projection_regularizer=None, projection_dropout: float = 0.0, scaled_attention=False, **kwargs): self.activation = activations.get(activation) self.add_biases = add_biases self.projection_regularizer = regularizers.get(projection_regularizer) self.projection_dropout = projection_dropout self.scaled_attention = scaled_attention super().__init__(**kwargs) def get_config(self): config = super().get_config() return dict( config, activation=activations.serialize(self.activation), add_biases=self.add_biases, projection_regularizer=regularizers.serialize( self.projection_regularizer), projection_dropout=self.projection_dropout, scaled_attention=self.scaled_attention) # noinspection PyAttributeOutsideInit def build(self, input_shape): main_input_shape, embedding_matrix_shape = input_shape emb_input_dim, emb_output_dim = embedding_matrix_shape assert len(main_input_shape) == 3 self.projection = self.add_weight( name='kernel', shape=(main_input_shape[-1], emb_output_dim), initializer='glorot_uniform', regularizer=self.projection_regularizer, trainable=True) if self.add_biases: self.biases = self.add_weight( name='biases', shape=(emb_output_dim,), initializer='zeros', trainable=True) return super().build(input_shape) def call(self, inputs, **kwargs): main_input, embedding_matrix = inputs input_shape_tensor = K.shape(main_input) last_input_dim = K.int_shape(main_input)[-1] emb_input_dim, emb_output_dim = K.int_shape(embedding_matrix) projected = K.dot(K.reshape(main_input, (-1, last_input_dim)), self.projection) if self.add_biases: projected = K.bias_add(projected, self.biases, data_format='channels_last') if 0 < self.projection_dropout < 1: projected = K.in_train_phase( lambda: K.dropout(projected, self.projection_dropout), projected, training=kwargs.get('training')) attention = K.dot(projected, 
K.transpose(embedding_matrix)) if self.scaled_attention: # scaled dot-product attention, described in # "Attention is all you need" (https://arxiv.org/abs/1706.03762) sqrt_d = K.constant(math.sqrt(emb_output_dim), dtype=K.floatx()) attention = attention / sqrt_d result = K.reshape( self.activation(attention), (input_shape_tensor[0], input_shape_tensor[1], emb_input_dim)) return result def compute_output_shape(self, input_shape): main_input_shape, embedding_matrix_shape = input_shape emb_input_dim, emb_output_dim = embedding_matrix_shape return main_input_shape[0], main_input_shape[1], emb_input_dim get_custom_objects().update({ 'ReusableEmbedding': ReusableEmbedding, 'TiedOutputEmbedding': TiedOutputEmbedding, })
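# Usage sketch (an illustrative addition, not part of the library proper).
# It shows how ReusableEmbedding and TiedOutputEmbedding are meant to be
# combined: the embedding matrix returned by ReusableEmbedding is passed to
# TiedOutputEmbedding, so the input and output projections share weights.
# The vocabulary size, sequence length, LSTM body, and layer names below are
# assumed values chosen only for this example.
if __name__ == '__main__':
    from keras.layers import Input, LSTM
    from keras.models import Model

    vocab_size, embedding_dim, max_len = 10000, 128, 50  # assumed sizes
    word_ids = Input(shape=(max_len,), dtype='int32', name='word_ids')
    # ReusableEmbedding returns both the embedded sequence and the matrix
    embedded, embedding_matrix = ReusableEmbedding(
        vocab_size, embedding_dim, name='word_embeddings')(word_ids)
    hidden = LSTM(256, return_sequences=True)(embedded)
    # The same matrix is re-used to produce per-token vocabulary outputs
    probabilities = TiedOutputEmbedding(
        activation='softmax', name='tied_output')([hidden, embedding_matrix])
    model = Model(inputs=word_ids, outputs=probabilities)
    model.summary()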