from keras.models import Model
from keras.layers import Input, Dense, Embedding, Lambda, TimeDistributed, Add, Conv1D
from kulc.layer_normalization import LayerNormalization
from kulc.attention import MultiHeadAttention
import numpy as np
import keras.backend as K


class PositionWiseFeedForward(object):
    """Position-wise feed-forward sublayer, implemented as two 1x1 convolutions."""

    # NOTE: the original Transformer paper uses d_ff=2048; this implementation defaults to 512.
    def __init__(self, d_model=512, d_ff=512, **kwargs):
        self._d_model = d_model
        self._d_ff = d_ff
        self._conv1 = Conv1D(self._d_ff, kernel_size=1, activation="relu")
        self._conv2 = Conv1D(self._d_model, kernel_size=1)

    def __call__(self, x):
        intermediate_x = self._conv1(x)
        return self._conv2(intermediate_x)


class EncoderLayer(object):
    """One encoder block: self-attention and feed-forward, each wrapped in a residual connection followed by layer normalization."""

    def __init__(self, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048):
        self._mha = MultiHeadAttention(h=h, d_k=d_k, d_v=d_v, d_model=d_model)
        self._ln_a = LayerNormalization()
        self._psfw = PositionWiseFeedForward(d_model=d_model, d_ff=d_inner_hid)
        self._ln_b = LayerNormalization()
        self._add_a = Add()
        self._add_b = Add()

    def __call__(self, x):
        # self-attention sublayer: queries, keys and values all come from x
        y = self._mha([x, x, x])
        y = self._add_a([x, y])
        x = self._ln_a(y)
        # position-wise feed-forward sublayer
        y = self._psfw(x)
        y = self._add_b([x, y])
        x = self._ln_b(y)
        return x


class DecoderLayer(object):
    """One decoder block: self-attention over the target, attention over the encoder output, and a feed-forward sublayer."""

    def __init__(self, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, return_attention=True):
        self._mha_a = MultiHeadAttention(h=h, d_k=d_k, d_v=d_v, d_model=d_model, return_attention=return_attention)
        self._mha_b = MultiHeadAttention(h=h, d_k=d_k, d_v=d_v, d_model=d_model, return_attention=return_attention)
        self._psfw = PositionWiseFeedForward(d_model=d_model, d_ff=d_inner_hid)
        self._ln_a = LayerNormalization()
        self._ln_b = LayerNormalization()
        self._ln_c = LayerNormalization()
        self._add_a = Add()
        self._add_b = Add()
        self._add_c = Add()
        self._return_attention = return_attention

    def __call__(self, x, encoder_output):
        # self-attention over the (shifted) target sequence
        y, self_atn = self._mha_a([x, x, x])
        y = self._add_a([x, y])
        x = self._ln_a(y)
        # encoder-decoder attention: queries from the decoder, keys/values from the encoder
        y, enc_atn = self._mha_b([x, encoder_output, encoder_output])
        y = self._add_b([x, y])
        x = self._ln_b(y)
        # position-wise feed-forward sublayer
        y = self._psfw(x)
        y = self._add_c([x, y])
        x = self._ln_c(y)
        if self._return_attention:
            return [x, self_atn, enc_atn]
        else:
            return x


class Encoder(object):
    def __init__(self, embedding, position_embedding, n=6, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, null_token_value=0):
        self._embedding = embedding
        self._position_embedding = position_embedding
        self._n = n
        self._position_encoding = Lambda(_get_pos_seq, arguments={"null_token_value": null_token_value})
        self._layers = [EncoderLayer(h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=d_inner_hid) for _ in range(n)]

    def __call__(self, x):
        # sum token embeddings and positional embeddings, then apply the n encoder blocks
        x_embedded = self._embedding(x)
        pos_encoding = self._position_encoding(x)
        pos_encoding_embedded = self._position_embedding(pos_encoding)
        x = Add()([x_embedded, pos_encoding_embedded])
        for layer in self._layers:
            x = layer(x)
        return x


class Decoder(object):
    def __init__(self, embedding, position_embedding, n=6, h=8, d_k=64, d_v=64, d_model=512, d_inner_hid=2048, null_token_value=0):
        self._embedding = embedding
        self._position_embedding = position_embedding
        self._n = n
        self._position_encoding = Lambda(_get_pos_seq, arguments={"null_token_value": null_token_value})
        self._layers = [DecoderLayer(h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=d_inner_hid) for _ in range(n)]

    def __call__(self, x, encoder_output, return_attention=False):
        # sum token embeddings and positional embeddings, then apply the n decoder blocks
        x_embedded = self._embedding(x)
        pos_encoding = self._position_encoding(x)
        pos_encoding_embedded = self._position_embedding(pos_encoding)
        x = Add()([x_embedded, pos_encoding_embedded])
        self_atts = []
        enc_atts = []
        for layer in self._layers:
            x, self_att, enc_att = layer(x, encoder_output)
            if return_attention:
                self_atts.append(self_att)
                enc_atts.append(enc_att)
        if return_attention:
            return [x, self_atts, enc_atts]
        else:
            return x
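
# Illustrative sketch (not part of the original API): wiring a single
# EncoderLayer/DecoderLayer pair by hand, mirroring what Encoder and Decoder do
# internally. Inputs are assumed to be already-embedded sequences of width
# d_model=512; DecoderLayer defaults to return_attention=True, hence the
# three-way unpacking.
def _demo_single_layer_pair():
    src = Input(shape=(None, 512), name="embedded_source")
    tgt = Input(shape=(None, 512), name="embedded_target")
    enc_out = EncoderLayer()(src)
    dec_out, self_atn, enc_atn = DecoderLayer()(tgt, enc_out)
    return Model([src, tgt], dec_out)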
def build_transformer(source_vocabulary_size, target_vocabulary_size, max_length, share_word_embedding=False, n=6, h=8, d_k=64, d_v=64, d_model=512, optimizer="adam", null_token_value=0):
    source_input = Input(shape=(None,), name="source_input")
    target_input = Input(shape=(None,), name="target_input")

    # drop the leading token of the source; shift the target for teacher forcing:
    # the decoder sees tokens [0, T-1] and is trained to predict tokens [1, T]
    enc_input = Lambda(lambda x: x[:, 1:])(source_input)
    dec_input = Lambda(lambda x: x[:, :-1])(target_input)
    dec_target_output = Lambda(lambda x: x[:, 1:])(target_input)

    # create the word embeddings (optionally shared between source and target)
    source_word_embedding = Embedding(source_vocabulary_size, d_model, name="shared_embedding" if share_word_embedding else "source_embedding")
    if share_word_embedding:
        target_word_embedding = source_word_embedding
    else:
        target_word_embedding = Embedding(target_vocabulary_size, d_model, name="target_embedding")

    # embedding for the position encoding: a frozen lookup into the sinusoidal matrix
    position_encoding = Embedding(max_length, d_model, trainable=False, weights=[_get_positional_encoding_matrix(max_length, d_model)], name="position_embedding")

    enc = Encoder(source_word_embedding, position_encoding, n=n, h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=512)
    dec = Decoder(target_word_embedding, position_encoding, n=n, h=h, d_k=d_k, d_v=d_v, d_model=d_model, d_inner_hid=512)

    enc_output = enc(enc_input)
    dec_output = dec(dec_input, enc_output)

    # project decoder states to vocabulary logits; the softmax is applied inside
    # the loss via from_logits=True, so no activation is needed here
    fin_output = TimeDistributed(Dense(target_vocabulary_size, activation=None, use_bias=False), name="output")
    fin_output_out = fin_output(dec_output)

    accuracy = Lambda(_get_accuracy, arguments={"null_token_value": null_token_value})([fin_output_out, dec_target_output])
    loss = Lambda(_get_loss, arguments={"null_token_value": null_token_value})([fin_output_out, dec_target_output])

    # the loss is computed inside the graph, so the model is compiled without
    # targets; metrics_names/metrics_tensors is a Keras 2.x internal used here
    # to surface the masked accuracy during training
    train_model = Model(inputs=[source_input, target_input], outputs=loss)
    train_model.add_loss([loss])
    train_model.compile(optimizer, None)
    train_model.metrics_names.append('accuracy')
    train_model.metrics_tensors.append(accuracy)

    inference_model = Model([source_input, target_input], fin_output_out)

    return train_model, inference_model


def create_model(source_vocabulary_size, target_vocabulary_size, max_length, share_word_embedding=False, n=6, h=8, d_k=64, d_v=64, d_model=512, optimizer="adam", null_token_value=0):
    return build_transformer(
        source_vocabulary_size=source_vocabulary_size,
        target_vocabulary_size=target_vocabulary_size,
        max_length=max_length,
        share_word_embedding=share_word_embedding,
        n=n, h=h, d_k=d_k, d_v=d_v, d_model=d_model,
        optimizer=optimizer,
        null_token_value=null_token_value)


def _get_loss(args, null_token_value):
    # masked sparse categorical cross-entropy: positions holding the null
    # (padding) token contribute neither to the loss nor to the average
    y_pred, y_true = args
    y_true_id = K.cast(y_true, "int32")
    mask = 1.0 - K.cast(K.equal(y_true_id, null_token_value), K.floatx())
    loss = K.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True) * mask
    # take the average w.r.t. the number of unmasked entries
    return K.sum(loss) / K.sum(mask)
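
# Hedged usage sketch (the vocabulary sizes, max_length, batch size and epoch
# count are assumptions, not from this file): because build_transformer adds
# the loss tensor via add_loss and compiles with compile(optimizer, None), the
# training model takes no explicit targets; fit() is called with y=None and
# both token arrays as inputs.
def _demo_training_usage(src_ids, tgt_ids):
    train_model, inference_model = build_transformer(
        source_vocabulary_size=8000, target_vocabulary_size=8000, max_length=70)
    train_model.fit([src_ids, tgt_ids], None, batch_size=64, epochs=1)
    return inference_model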
def _get_accuracy(args, null_token_value):
    # token-level accuracy, ignoring null (padding) positions
    y_pred, y_true = args
    y_true = K.cast(y_true, "int32")
    mask = 1.0 - K.cast(K.equal(y_true, null_token_value), K.floatx())
    y_pred = K.cast(K.argmax(y_pred, axis=-1), "int32")
    correct = K.cast(K.equal(y_pred, y_true), K.floatx())
    correct = K.sum(correct * mask, -1) / K.sum(mask, -1)
    return K.mean(correct)


def _get_pos_seq(x, null_token_value=0):
    # map each token to its 1-based position; padding positions map to 0
    mask = K.cast(K.not_equal(x, null_token_value), "float32")
    pos = K.cumsum(K.ones_like(x, "float32"), 1)
    return pos * mask


def _get_positional_encoding_matrix(max_len, d_emb):
    # sinusoidal positional encodings; row 0 is all zeros so that padding
    # positions (index 0 from _get_pos_seq) receive a zero positional vector
    pos_enc = np.array([
        [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]
        if pos != 0 else np.zeros(d_emb)
        for pos in range(max_len)
    ])
    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i
    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1
    return pos_enc
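
# Quick self-check (illustrative; max_len and d_emb values are arbitrary):
# row 0 of the positional-encoding matrix is all zeros, matching _get_pos_seq,
# which assigns position index 0 to padding tokens.
def _demo_positional_encoding():
    pe = _get_positional_encoding_matrix(max_len=10, d_emb=8)
    assert np.allclose(pe[0], 0.0)
    assert pe.shape == (10, 8)
    return pe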