from layers.feed_forward import *
from layers.attention_layer import *
from layers.embedding_layer import *
from layers.layer_norm import LayerNormalization
from tensorflow.python.framework import tensor_shape
from utils.tf_utils import *
import tensorflow as tf
import os

_ROOT = os.path.abspath(os.path.dirname(__file__))
LOG_DIR = _ROOT + "/log"

train_step_signature = [
    tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="Inputs"),
    tf.TensorSpec(shape=(None, None), dtype=tf.int32, name="Targets"),
    tf.TensorSpec(shape=(), dtype=tf.int32, name="Step"),  # scalar global step
]


class Gpt2(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, max_seq_len,
                 vocab_size, optimizer="adam", learning_rate=1e-3,
                 rev_embedding_projection=True):
        super(Gpt2, self).__init__()
        self.rev_embedding_projection = rev_embedding_projection
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dff = dff
        self.max_seq_len = max_seq_len
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.learning_rate = learning_rate
        self.optimizer_t = optimizer
        self.dataset = None
        self.mirrored_strategy = None

        self.embedding = EmbeddingLayer(self.vocab_size, self.d_model)
        self.pos_embedding = PositionEmbeddingLayer(self.max_seq_len, self.d_model)
        self.decoder_layers = [DecoderLayer(self.d_model, self.num_heads, self.dff)
                               for _ in range(self.num_layers)]
        self.layer_norm = LayerNormalization(self.d_model)

        # With tied (reversed) embeddings the input embedding matrix doubles as
        # the output projection, so a separate output layer is only needed otherwise.
        if not self.rev_embedding_projection:
            self.output_layer = OutputLayer(self.vocab_size)

        self.loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
        self.accuracy_object = tf.keras.metrics.SparseCategoricalAccuracy(
            name='accuracy')

        self.train_step_signature = [
            tf.TensorSpec(shape=(None, None), dtype=tf.int32)]

    def call(self, x, training=True, past=None):
        x = tf.cast(x, tf.int32)
        # `past` holds cached keys/values from previous decode steps,
        # one entry per decoder layer.
        if past is None:
            pasts = [None] * self.num_layers
        else:
            pasts = past
        assert len(pasts) == self.num_layers

        att_mask = create_masks(x)
        past_length = 1 if past is None else tf.shape(past)[-2]

        with tf.name_scope("embeddings"):
            embedded_x = self.embedding(x)
            hidden_states = embedded_x + self.pos_embedding(x, start=past_length)

        presents = []
        for decoder_layer, layer_past in zip(self.decoder_layers, pasts):
            hidden_states, present = decoder_layer(hidden_states, training,
                                                   att_mask, past=layer_past)
            presents.append(present)

        hidden_states = self.layer_norm(hidden_states)

        if self.rev_embedding_projection:
            logits = self.embedding(hidden_states, mode="projection")
        else:
            logits = self.output_layer(hidden_states)

        return logits, presents

    @staticmethod
    def get_padded_accuracy(labels, logits):
        with tf.name_scope("padded_accuracy"):
            # Padding positions (label id 0) are excluded from the accuracy.
            weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
            outputs = tf.cast(tf.argmax(logits, axis=-1), tf.int32)
            padded_labels = tf.cast(labels, tf.int32)

            nonpad_seq = tf.math.count_nonzero(weights, dtype=tf.dtypes.float32)
            acc = tf.cast(tf.equal(outputs, padded_labels), tf.float32)
            accuracy = tf.reduce_sum(acc * weights) / nonpad_seq
            return tf.cast(accuracy, tf.float32)

    def create_optimizer(self):
        optimizer = self.optimizer_t.lower()
        with tf.name_scope("optimizer"):
            if optimizer == "adam":
                self.optimizer = tf.keras.optimizers.Adam(
                    self.learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)
            elif optimizer == "adadelta":
                self.optimizer = tf.keras.optimizers.Adadelta(self.learning_rate)
            elif optimizer == "rms":
                self.optimizer = tf.keras.optimizers.RMSprop(self.learning_rate)
            else:
                self.optimizer = tf.keras.optimizers.SGD(self.learning_rate)
            return self.optimizer
    def get_loss(self, real, pred):
        with tf.name_scope("loss_layer"):
            # Mask padding tokens (id 0) out of the per-token loss, then average
            # over the real tokens of each sequence.
            mask = tf.math.logical_not(tf.math.equal(real, 0))
            loss_ = self.loss_object(real, pred)

            with tf.name_scope("loss_masking"):
                mask = tf.cast(mask, dtype=loss_.dtype)
                loss_ *= mask
            loss_ = tf.reduce_sum(loss_, axis=1)
            sequence_avg_loss = loss_ / tf.reduce_sum(mask, axis=1)
            return sequence_avg_loss

    def create_checkpoint_manager(self, checkpoint_path, max_to_keep=5,
                                  load_model=True):
        with tf.name_scope('checkpoint_manager'):
            ckpt = tf.train.Checkpoint(optimizer=self.optimizer, model=self)
            self.ckpt_manager = tf.train.CheckpointManager(
                ckpt, checkpoint_path, max_to_keep=max_to_keep)

            if load_model:  # Restore the most recent trained weights.
                ckpt.restore(self.ckpt_manager.latest_checkpoint)
                print('Latest checkpoint restored...............')
            else:
                print("Initializing model from scratch..........")

    def load_model(self, filepath):
        ckpt = tf.train.Checkpoint(model=self)
        # max_to_keep is a required argument of CheckpointManager; only the
        # latest checkpoint matters when restoring.
        ckpt_manager = tf.train.CheckpointManager(ckpt, filepath, max_to_keep=1)
        ckpt.restore(ckpt_manager.latest_checkpoint)
        print("Model Restored..........................")

    def create_summary_writer(self, summary_path):
        train_summary_path = summary_path + "/train"
        test_summary_path = summary_path + "/test"

        with tf.name_scope('summary'):
            self.train_writer = tf.summary.create_file_writer(train_summary_path)
            self.test_writer = tf.summary.create_file_writer(test_summary_path)
        return self.train_writer, self.test_writer

    @tf.function(input_signature=train_step_signature)
    def train_step(self, inputs, targets, step, grad_clip=True, clip_value=2.5):
        with tf.GradientTape() as tape:
            predictions, _ = self(inputs, training=True)
            loss = tf.reduce_mean(self.get_loss(targets, predictions))

        with tf.name_scope("gradients"):
            gradients = tape.gradient(loss, self.trainable_variables)
            if grad_clip:
                gradients = [tf.clip_by_value(grad, -clip_value, clip_value)
                             for grad in gradients]
            self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        accuracy = self.get_padded_accuracy(targets, predictions)

        with tf.name_scope("summary_writer"):
            with self.train_writer.as_default():
                tf.summary.scalar("loss", loss, step=tf.cast(step, tf.int64))
                tf.summary.scalar("accuracy", accuracy, step=tf.cast(step, tf.int64))

        return loss, accuracy

    @tf.function
    def distributed_train_step(self, inputs, targets, step,
                               grad_clip=True, clip_value=1.0):
        def step_fn(inp, tar):
            with tf.GradientTape() as tape:
                # call() returns (logits, presents); only the logits feed the loss.
                logits, _ = self(inp, training=True)
                cross_entropy = self.get_loss(tar, logits)
                loss = tf.reduce_mean(cross_entropy)

            with tf.name_scope("gradients"):
                gradients = tape.gradient(loss, self.trainable_variables)
                if grad_clip:
                    gradients = [tf.clip_by_value(grad, -clip_value, clip_value)
                                 for grad in gradients]
                self.optimizer.apply_gradients(
                    zip(gradients, self.trainable_variables))
            return cross_entropy

        # Note: experimental_run_v2 was renamed to Strategy.run in later TF releases.
        per_example_losses = self.mirrored_strategy.experimental_run_v2(
            step_fn, args=(inputs, targets))
        mean_loss = self.mirrored_strategy.reduce(
            tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        with tf.name_scope("summary_writer"):
            with self.train_writer.as_default():
                tf.summary.scalar("loss", mean_loss, step=tf.cast(step, tf.int64))

        return mean_loss

    def fit(self, train_dataset):
        if self.mirrored_strategy is None:
            tf.summary.trace_on(graph=True, profiler=True)
            for (step, (inputs, targets)) in enumerate(train_dataset):
                train_loss, train_acc = self.train_step(inputs, targets, step)
                if step % 10 == 0:
                    print('Step {} Train_Loss {:.4f} Train_Accuracy {:.4f}'.format(
                        step, train_loss, train_acc))
                if step == 0:
                    with self.train_writer.as_default():
                        tf.summary.trace_export(
                            name="gpt-2",
                            step=0,
                            profiler_outdir=LOG_DIR)
                if step % 1000 == 0:
                    ckpt_save_path = self.ckpt_manager.save()
                    print('Saving checkpoint for step {} at {}'.format(
                        step, ckpt_save_path))
        else:
            with self.mirrored_strategy.scope():
                tf.summary.trace_on(graph=True, profiler=True)
                for (step, (inputs, targets)) in enumerate(train_dataset):
                    train_loss = self.distributed_train_step(inputs, targets, step)
                    if step == 0:
                        with self.train_writer.as_default():
                            tf.summary.trace_export(
                                name="gpt-2",
                                step=0,
                                profiler_outdir=LOG_DIR)
                    if step % 100 == 0:
                        print('Step {} Train_Loss {:.4f}'.format(step, train_loss))
                    if step % 1000 == 0:
                        ckpt_save_path = self.ckpt_manager.save()
                        print('Saving checkpoint for step {} at {}'.format(
                            step, ckpt_save_path))
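# Usage sketch (illustrative, not from the original project): the distributed
# branch of fit() only runs when a tf.distribute strategy has been attached to
# the model from outside, along the lines of:
#
#   strategy = tf.distribute.MirroredStrategy()
#   with strategy.scope():
#       model = Gpt2(num_layers=12, d_model=768, num_heads=12, dff=3072,
#                    max_seq_len=512, vocab_size=50257)
#       model.create_optimizer()
#   model.mirrored_strategy = strategy  # fit() now uses distributed_train_step
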
class OutputLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim, proj_weights=None, kernel_initializer=None):
        super(OutputLayer, self).__init__()
        self.proj_weights = proj_weights
        self.output_dim = output_dim
        self.layer_weights = None
        self.kernel_initializer = kernel_initializer

    def build(self, input_shape):
        # Only create a weight matrix when no shared projection is supplied.
        if self.proj_weights is None:
            input_dim = tensor_shape.dimension_value(input_shape[-1])
            self.layer_weights = self.add_weight(
                'output_layer_weights',
                shape=[input_dim, self.output_dim],
                initializer=self.kernel_initializer,
                trainable=True)
        super(OutputLayer, self).build(input_shape)

    def call(self, x):
        batch, sequence, d_model = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[-1]
        h_flat = tf.reshape(x, [-1, d_model])

        if self.proj_weights is None:
            out = tf.matmul(h_flat, self.layer_weights)
        else:
            # Shared (tied) weights are stored as [output_dim, d_model],
            # hence the transposed matmul.
            out = tf.matmul(h_flat, self.proj_weights, transpose_b=True)
        out = tf.reshape(out, [batch, sequence, self.output_dim])
        return out


class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dr_rate=0.1):
        super(DecoderLayer, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.dff = dff
        self.dr_rate = dr_rate

        self.mha = MultiHeadAttention(self.d_model, self.num_heads)
        self.feed_forward = FeedForward(self.d_model, self.dff, self.dr_rate)
        self.layer_norm1 = LayerNormalization(self.d_model)
        self.layer_norm2 = LayerNormalization(self.d_model)

    def call(self, x, training, mask, past=None):
        # Pre-norm self-attention block: (batch_size, input_seq_len, d_model)
        out, present = self.mha(self.layer_norm1(x), mask=mask,
                                past_layer=past, training=training)
        with tf.name_scope("residual_conn"):
            x = x + out
        # Pre-norm feed-forward block: (batch_size, input_seq_len, d_model)
        out = self.feed_forward(self.layer_norm2(x), training=training)
        with tf.name_scope("residual_conn"):
            x = x + out
        return x, present
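

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the tiny hyperparameters,
    # the "./summaries" / "./checkpoints" paths and the random toy dataset below
    # are placeholder assumptions, not values used by the project.
    model = Gpt2(num_layers=2, d_model=128, num_heads=4, dff=512,
                 max_seq_len=64, vocab_size=1000)
    model.create_optimizer()
    model.create_summary_writer("./summaries")
    model.create_checkpoint_manager("./checkpoints", max_to_keep=5,
                                    load_model=False)

    # Toy next-token-prediction data: targets are the inputs shifted left by one.
    tokens = tf.random.uniform((32, 17), minval=1, maxval=1000, dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices(
        (tokens[:, :-1], tokens[:, 1:])).batch(8)

    model.fit(dataset)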