import tensorflow as tf
import numpy as np
from models import BaseModel, AttentionCell, multi_head_attention
from tensorflow.python.ops.rnn_cell import MultiRNNCell
from tensorflow.python.ops.rnn import bidirectional_dynamic_rnn, dynamic_rnn
from tensorflow.contrib.rnn.python.ops.rnn import stack_bidirectional_dynamic_rnn
from models.nns import multi_conv1d, highway_network, layer_normalize
from utils import Progbar


class SequenceLabelModel(BaseModel):
    def __init__(self, config):
        super(SequenceLabelModel, self).__init__(config)

    def _add_placeholders(self):
        # input placeholders
        self.words = tf.placeholder(tf.int32, shape=[None, None], name="words")
        self.tags = tf.placeholder(tf.int32, shape=[None, None], name="tags")
        self.seq_len = tf.placeholder(tf.int32, shape=[None], name="seq_len")
        if self.cfg["use_chars"]:
            self.chars = tf.placeholder(tf.int32, shape=[None, None, None], name="chars")
            self.char_seq_len = tf.placeholder(tf.int32, shape=[None, None], name="char_seq_len")
        # hyper-parameters
        self.is_train = tf.placeholder(tf.bool, shape=[], name="is_train")
        self.batch_size = tf.placeholder(tf.int32, name="batch_size")
        self.keep_prob = tf.placeholder(tf.float32, name="rnn_keep_probability")
        self.drop_rate = tf.placeholder(tf.float32, name="dropout_rate")
        self.lr = tf.placeholder(tf.float32, name="learning_rate")

    def _get_feed_dict(self, batch, keep_prob=1.0, is_train=False, lr=None):
        # `batch` is a dict keyed by "words", "seq_len", "batch_size" and,
        # optionally, "tags", "chars" and "char_seq_len"
        feed_dict = {self.words: batch["words"], self.seq_len: batch["seq_len"],
                     self.batch_size: batch["batch_size"]}
        if "tags" in batch:
            feed_dict[self.tags] = batch["tags"]
        if self.cfg["use_chars"]:
            feed_dict[self.chars] = batch["chars"]
            feed_dict[self.char_seq_len] = batch["char_seq_len"]
        feed_dict[self.keep_prob] = keep_prob
        feed_dict[self.drop_rate] = 1.0 - keep_prob
        feed_dict[self.is_train] = is_train
        if lr is not None:
            feed_dict[self.lr] = lr
        return feed_dict

    def _create_rnn_cell(self):
        if self.cfg["num_layers"] is None or self.cfg["num_layers"] <= 1:
            return self._create_single_rnn_cell(self.cfg["num_units"])
        else:
            if self.cfg["use_stack_rnn"]:
                # stack_bidirectional_dynamic_rnn expects a list of cells per direction
                return [self._create_single_rnn_cell(self.cfg["num_units"])
                        for _ in range(self.cfg["num_layers"])]
            else:
                return MultiRNNCell([self._create_single_rnn_cell(self.cfg["num_units"])
                                     for _ in range(self.cfg["num_layers"])])

    def _build_embedding_op(self):
        with tf.variable_scope("embeddings"):
            if not self.cfg["use_pretrained"]:
                self.word_embeddings = tf.get_variable(name="emb", dtype=tf.float32, trainable=True,
                                                       shape=[self.word_vocab_size, self.cfg["emb_dim"]])
            else:  # initialize word embeddings from a pretrained matrix
                self.word_embeddings = tf.Variable(np.load(self.cfg["pretrained_emb"])["embeddings"],
                                                   name="emb", dtype=tf.float32,
                                                   trainable=self.cfg["tuning_emb"])
            word_emb = tf.nn.embedding_lookup(self.word_embeddings, self.words, name="word_emb")
            print("word embedding shape: {}".format(word_emb.get_shape().as_list()))
            if self.cfg["use_chars"]:
                self.char_embeddings = tf.get_variable(name="c_emb", dtype=tf.float32, trainable=True,
                                                       shape=[self.char_vocab_size, self.cfg["char_emb_dim"]])
                char_emb = tf.nn.embedding_lookup(self.char_embeddings, self.chars, name="chars_emb")
                # character-level representation via multi-width 1D convolutions
                char_represent = multi_conv1d(char_emb, self.cfg["filter_sizes"], self.cfg["channel_sizes"],
                                              drop_rate=self.drop_rate, is_train=self.is_train)
                print("chars representation shape: {}".format(char_represent.get_shape().as_list()))
                word_emb = tf.concat([word_emb, char_represent], axis=-1)
            if self.cfg["use_highway"]:
                self.word_emb = highway_network(word_emb, self.cfg["highway_layers"], use_bias=True,
                                                bias_init=0.0, keep_prob=self.keep_prob, is_train=self.is_train)
            else:
                self.word_emb = tf.layers.dropout(word_emb, rate=self.drop_rate, training=self.is_train)
            print("word and chars concatenation shape: {}".format(self.word_emb.get_shape().as_list()))
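
    # Shape walk-through for `_build_embedding_op` above (dimensions symbolic;
    # the `char_represent` width assumes multi_conv1d concatenates its
    # per-filter conv outputs, which is the usual convention for this helper):
    #   word_emb:       [batch, max_len, emb_dim]
    #   char_emb:       [batch, max_len, max_word_len, char_emb_dim]
    #   char_represent: [batch, max_len, sum(channel_sizes)]
    #   self.word_emb:  [batch, max_len, emb_dim + sum(channel_sizes)]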

    def _build_model_op(self):
        with tf.variable_scope("bi_directional_rnn"):
            cell_fw = self._create_rnn_cell()
            cell_bw = self._create_rnn_cell()
            if self.cfg["use_stack_rnn"]:
                # the stacked bi-RNN already concatenates forward/backward outputs
                rnn_outs, *_ = stack_bidirectional_dynamic_rnn(cell_fw, cell_bw, self.word_emb,
                                                               dtype=tf.float32, sequence_length=self.seq_len)
            else:
                rnn_outs, *_ = bidirectional_dynamic_rnn(cell_fw, cell_bw, self.word_emb,
                                                         sequence_length=self.seq_len, dtype=tf.float32)
                rnn_outs = tf.concat(rnn_outs, axis=-1)  # join forward/backward outputs
            rnn_outs = tf.layers.dropout(rnn_outs, rate=self.drop_rate, training=self.is_train)
            if self.cfg["use_residual"]:
                # project the embeddings so the residual connection matches the RNN output size
                word_project = tf.layers.dense(self.word_emb, units=2 * self.cfg["num_units"], use_bias=False)
                rnn_outs = rnn_outs + word_project
            outputs = layer_normalize(rnn_outs) if self.cfg["use_layer_norm"] else rnn_outs
            print("rnn output shape: {}".format(outputs.get_shape().as_list()))

        if self.cfg["use_attention"] == "self_attention":
            with tf.variable_scope("self_attention"):
                attn_outs = multi_head_attention(outputs, outputs, self.cfg["num_heads"],
                                                 self.cfg["attention_size"], drop_rate=self.drop_rate,
                                                 is_train=self.is_train)
                if self.cfg["use_residual"]:
                    attn_outs = attn_outs + outputs
                outputs = layer_normalize(attn_outs) if self.cfg["use_layer_norm"] else attn_outs
                print("self-attention output shape: {}".format(outputs.get_shape().as_list()))
        elif self.cfg["use_attention"] == "normal_attention":
            with tf.variable_scope("normal_attention"):
                # AttentionCell consumes time-major inputs, so transpose to [max_time, batch, dim]
                context = tf.transpose(outputs, [1, 0, 2])
                p_context = tf.layers.dense(outputs, units=2 * self.cfg["num_units"], use_bias=False)
                p_context = tf.transpose(p_context, [1, 0, 2])
                attn_cell = AttentionCell(self.cfg["num_units"], context, p_context)
                attn_outs, _ = dynamic_rnn(attn_cell, context, sequence_length=self.seq_len,
                                           time_major=True, dtype=tf.float32)
                outputs = tf.transpose(attn_outs, [1, 0, 2])  # back to batch-major
                print("attention output shape: {}".format(outputs.get_shape().as_list()))

        with tf.variable_scope("project"):
            self.logits = tf.layers.dense(outputs, units=self.tag_vocab_size, use_bias=True)
            print("logits shape: {}".format(self.logits.get_shape().as_list()))

    def train_epoch(self, train_set, valid_data, epoch):
        num_batches = len(train_set)
        prog = Progbar(target=num_batches)
        for i, batch_data in enumerate(train_set):
            feed_dict = self._get_feed_dict(batch_data, is_train=True, keep_prob=self.cfg["keep_prob"],
                                            lr=self.cfg["lr"])
            _, train_loss, summary = self.sess.run([self.train_op, self.loss, self.summary],
                                                   feed_dict=feed_dict)
            cur_step = (epoch - 1) * num_batches + (i + 1)
            prog.update(i + 1, [("Global Step", int(cur_step)), ("Train Loss", train_loss)])
            self.train_writer.add_summary(summary, cur_step)
            if i % 100 == 0:  # periodically record a validation summary
                valid_feed_dict = self._get_feed_dict(valid_data)
                valid_summary = self.sess.run(self.summary, feed_dict=valid_feed_dict)
                self.test_writer.add_summary(valid_summary, cur_step)
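
    # Note on the learning-rate schedule in `train` below: inverse-time decay
    # floored at `minimal_lr`, i.e. lr_e = max(init_lr / (1 + lr_decay * e), minimal_lr)
    # at epoch e. As a worked example (values illustrative): with init_lr = 0.01,
    # lr_decay = 0.05 and minimal_lr = 0.001, epoch 4 gives
    # max(0.01 / (1 + 0.05 * 4), 0.001) = 0.01 / 1.2 ≈ 0.0083.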

    def train(self, train_set, valid_data, valid_set, test_set):
        self.logger.info("Start training...")
        best_f1, no_imprv_epoch, init_lr = -np.inf, 0, self.cfg["lr"]
        self._add_summary()
        for epoch in range(1, self.cfg["epochs"] + 1):
            self.logger.info('Epoch {}/{}:'.format(epoch, self.cfg["epochs"]))
            self.train_epoch(train_set, valid_data, epoch)  # train one epoch
            if self.cfg["use_lr_decay"]:  # learning rate decay
                self.cfg["lr"] = max(init_lr / (1.0 + self.cfg["lr_decay"] * epoch), self.cfg["minimal_lr"])
            if self.cfg["task_name"] == "pos":
                self.eval_accuracy(valid_set, "dev")
                acc = self.eval_accuracy(test_set, "test")
                cur_test_score = acc
            else:
                self.evaluate(valid_set, "dev")
                score = self.evaluate(test_set, "test")
                cur_test_score = score["FB1"]
            if cur_test_score > best_f1:
                best_f1 = cur_test_score
                no_imprv_epoch = 0
                self.save_session(epoch)
                self.logger.info(' -- new BEST score on test dataset: {:04.2f}'.format(best_f1))
            else:
                no_imprv_epoch += 1
                if no_imprv_epoch >= self.cfg["no_imprv_tolerance"]:  # early stopping
                    self.logger.info('early stop at epoch {} without improvement, BEST score on test set: {:04.2f}'
                                     .format(epoch, best_f1))
                    break
        self.train_writer.close()
        self.test_writer.close()

    def eval_accuracy(self, dataset, name):  # used for the POS task
        accuracy = []
        for data in dataset:
            predicts = self._predict_op(data)
            for preds, tags, seq_len in zip(predicts, data["tags"], data["seq_len"]):
                # only score the non-padded positions of each sequence
                preds = preds[:seq_len]
                tags = tags[:seq_len]
                accuracy += [p == t for p, t in zip(preds, tags)]
        acc = np.mean(accuracy) * 100.0
        self.logger.info("{} dataset -- accuracy: {:04.2f}".format(name, acc))
        return acc
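
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the config keys listed are the ones this
# class reads, but the values and the dataset objects come from the wider
# project and are assumed here, not defined in this module.
#
#   config = {"use_chars": True, "use_pretrained": False, "num_layers": 2,
#             "num_units": 300, "use_stack_rnn": False, "use_highway": True,
#             "use_residual": True, "use_layer_norm": True,
#             "use_attention": "self_attention", "task_name": "ner", ...}
#   model = SequenceLabelModel(config)
#   model.train(train_set, valid_data, valid_set, test_set)
# ---------------------------------------------------------------------------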