# -*- coding: utf-8 -*-
# AUTHOR: Shun Zheng
# DATE: 19-9-19

import torch
from torch import nn
import torch.nn.functional as F
import math

from pytorch_pretrained_bert.modeling import PreTrainedBertModel, BertModel

from . import transformer


class BertForBasicNER(PreTrainedBertModel):
    """BERT model for basic NER functionality.

    This module is composed of the BERT model with a linear layer on top of
    the output sequences.

    Params:
        `config`: a BertConfig class instance with the configuration to build a new model.
        `num_entity_labels`: the number of entity classes for the classifier.

    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary.
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with the token type indices selected in [0, 1]. Type 0 corresponds to a
            `sentence A` token and type 1 corresponds to a `sentence B` token
            (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [0, 1]. It's a mask to be used if the input
            sequence length is smaller than the max input sequence length in the
            current batch. It's the mask that we typically use for attention when
            a batch has varying-length sentences.
        `label_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with label indices selected in [0, ..., num_labels-1].

    Outputs:
        if `label_ids` is not `None`:
            Outputs the CrossEntropy classification loss of the output with the labels.
        if `label_ids` is `None`:
            Outputs the classification logits sequence.
    """

    def __init__(self, config, num_entity_labels):
        super(BertForBasicNER, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_entity_labels)
        self.apply(self.init_bert_weights)

        self.num_entity_labels = num_entity_labels

    def old_forward(self, input_ids, input_masks,
                    token_type_ids=None, label_ids=None,
                    eval_flag=False, eval_for_metric=True):
        """Assume input size [batch_size, seq_len]"""
        if input_masks.dtype != torch.uint8:
            input_masks = input_masks == 1
        enc_seq_out, _ = self.bert(input_ids,
                                   token_type_ids=token_type_ids,
                                   attention_mask=input_masks,
                                   output_all_encoded_layers=False)
        # [batch_size, seq_len, hidden_size]
        enc_seq_out = self.dropout(enc_seq_out)
        # [batch_size, seq_len, num_entity_labels]
        seq_logits = self.classifier(enc_seq_out)

        if eval_flag:  # for evaluation purposes
            if label_ids is None:
                raise Exception('Cannot do evaluation without label info')
            else:
                if eval_for_metric:
                    batch_metrics = produce_ner_batch_metrics(seq_logits, label_ids, input_masks)
                    return batch_metrics
                else:
                    seq_logp = F.log_softmax(seq_logits, dim=-1)
                    seq_pred = seq_logp.argmax(dim=-1, keepdim=True)  # [batch_size, seq_len, 1]
                    seq_gold = label_ids.unsqueeze(-1)  # [batch_size, seq_len, 1]
                    seq_mask = input_masks.unsqueeze(-1).long()  # [batch_size, seq_len, 1]
                    seq_pred_gold_mask = torch.cat(
                        [seq_pred, seq_gold, seq_mask], dim=-1
                    )  # [batch_size, seq_len, 3]
                    return seq_pred_gold_mask
        elif label_ids is not None:  # if label_ids are provided, calculate the loss
            # [num_valid_token, num_entity_labels]
            batch_logits = seq_logits[input_masks, :]
            # [num_valid_token], lid \in {0, ..., num_entity_labels-1}
            batch_labels = label_ids[input_masks]
            loss = F.cross_entropy(batch_logits, batch_labels)
            return loss, enc_seq_out
        else:  # just return seq_pred_logps
            return F.log_softmax(seq_logits, dim=-1), enc_seq_out
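
    # A minimal shape sketch of the masked loss above, assuming a hypothetical
    # batch of size [batch_size=2, seq_len=4] with num_entity_labels=5:
    #   seq_logits[input_masks, :] -> [num_valid_token, 5]
    #   label_ids[input_masks]     -> [num_valid_token]
    #   F.cross_entropy(...)       -> scalar, averaged over valid tokens only,
    # whereas `forward` below keeps a per-example summed NLL of size [batch_size].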

    def forward(self, input_ids, input_masks,
                label_ids=None, train_flag=True, decode_flag=True):
        """Assume input size [batch_size, seq_len]"""
        if input_masks.dtype != torch.uint8:
            input_masks = input_masks == 1
        batch_seq_enc, _ = self.bert(input_ids,
                                     attention_mask=input_masks,
                                     output_all_encoded_layers=False)
        # [batch_size, seq_len, hidden_size]
        batch_seq_enc = self.dropout(batch_seq_enc)
        # [batch_size, seq_len, num_entity_labels]
        batch_seq_logits = self.classifier(batch_seq_enc)

        batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1)

        if train_flag:
            batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1))
            batch_label = label_ids.view(-1)
            # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum')
            ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none')
            ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1)  # [batch_size]
        else:
            ner_loss = None

        if decode_flag:
            batch_seq_preds = batch_seq_logp.argmax(dim=-1)
        else:
            batch_seq_preds = None

        return batch_seq_enc, ner_loss, batch_seq_preds


class NERModel(nn.Module):
    def __init__(self, config):
        super(NERModel, self).__init__()

        self.config = config
        # Word Embedding, Word Local Position Embedding
        self.token_embedding = NERTokenEmbedding(
            config.vocab_size, config.hidden_size,
            max_sent_len=config.max_sent_len, dropout=config.dropout
        )
        # Multi-layer Transformer Layers to Incorporate Contextual Information
        self.token_encoder = transformer.make_transformer_encoder(
            config.num_tf_layers, config.hidden_size,
            ff_size=config.ff_size, dropout=config.dropout
        )
        if self.config.use_crf_layer:
            self.crf_layer = CRFLayer(config.hidden_size, self.config.num_entity_labels)
        else:
            # Token Label Classification
            self.classifier = nn.Linear(config.hidden_size, self.config.num_entity_labels)

    def forward(self, input_ids, input_masks,
                label_ids=None, train_flag=True, decode_flag=True):
        """Assume input size [batch_size, seq_len]"""
        if input_masks.dtype != torch.uint8:
            input_masks = input_masks == 1
        if train_flag:
            assert label_ids is not None

        # get contextual info
        input_emb = self.token_embedding(input_ids)
        input_masks = input_masks.unsqueeze(-2)  # to fit the transformer code
        batch_seq_enc = self.token_encoder(input_emb, input_masks)

        if self.config.use_crf_layer:
            ner_loss, batch_seq_preds = self.crf_layer(
                batch_seq_enc, seq_token_label=label_ids, batch_first=True,
                train_flag=train_flag, decode_flag=decode_flag
            )
        else:
            # [batch_size, seq_len, num_entity_labels]
            batch_seq_logits = self.classifier(batch_seq_enc)
            batch_seq_logp = F.log_softmax(batch_seq_logits, dim=-1)
            if train_flag:
                batch_logp = batch_seq_logp.view(-1, batch_seq_logp.size(-1))
                batch_label = label_ids.view(-1)
                # ner_loss = F.nll_loss(batch_logp, batch_label, reduction='sum')
                ner_loss = F.nll_loss(batch_logp, batch_label, reduction='none')
                ner_loss = ner_loss.view(label_ids.size()).sum(dim=-1)  # [batch_size]
            else:
                ner_loss = None
            if decode_flag:
                batch_seq_preds = batch_seq_logp.argmax(dim=-1)
            else:
                batch_seq_preds = None

        return batch_seq_enc, ner_loss, batch_seq_preds
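

# A minimal usage sketch for NERModel, not part of the original training code.
# The config field names match those read in NERModel.__init__; the concrete
# values below are hypothetical.
def _demo_ner_model_forward():
    from types import SimpleNamespace

    config = SimpleNamespace(
        vocab_size=100, hidden_size=32, max_sent_len=16, dropout=0.1,
        num_tf_layers=2, ff_size=64, use_crf_layer=True, num_entity_labels=5,
    )
    model = NERModel(config)

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    input_masks = torch.ones(2, 16, dtype=torch.uint8)
    label_ids = torch.randint(0, config.num_entity_labels, (2, 16))

    batch_seq_enc, ner_loss, batch_seq_preds = model(
        input_ids, input_masks, label_ids=label_ids,
        train_flag=True, decode_flag=True,
    )
    # batch_seq_enc: [2, 16, 32], ner_loss: [2], batch_seq_preds: [2, 16]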


class NERTokenEmbedding(nn.Module):
    """Add token position information"""

    def __init__(self, vocab_size, hidden_size, max_sent_len=256, dropout=0.1):
        super(NERTokenEmbedding, self).__init__()

        self.token_embedding = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedding = nn.Embedding(max_sent_len, hidden_size)

        self.layer_norm = transformer.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch_token_ids):
        batch_size, sent_len = batch_token_ids.size()
        device = batch_token_ids.device

        batch_pos_ids = torch.arange(
            sent_len, dtype=torch.long, device=device, requires_grad=False
        )
        batch_pos_ids = batch_pos_ids.unsqueeze(0).expand_as(batch_token_ids)

        batch_token_emb = self.token_embedding(batch_token_ids)
        batch_pos_emb = self.pos_embedding(batch_pos_ids)

        batch_token_emb = batch_token_emb + batch_pos_emb

        batch_token_out = self.layer_norm(batch_token_emb)
        batch_token_out = self.dropout(batch_token_out)

        return batch_token_out
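

# A minimal shape sketch for NERTokenEmbedding, not part of the original code;
# all sizes below are hypothetical.
def _demo_token_embedding():
    emb = NERTokenEmbedding(vocab_size=100, hidden_size=32, max_sent_len=16)
    batch_token_ids = torch.randint(0, 100, (2, 8))  # [batch_size, seq_len]
    # position ids [0..7] are broadcast across the batch via expand_as,
    # added to the token embeddings, then layer-normalized and dropped out
    batch_token_out = emb(batch_token_ids)
    assert batch_token_out.size() == (2, 8, 32)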


class CRFLayer(nn.Module):
    """
    Conditional Random Field Layer

    Reference:
        https://pytorch.org/tutorials/beginner/nlp/advanced_tutorial.html#sphx-glr-beginner-nlp-advanced-tutorial-py
    The original example code operates on one sequence, while this version
    operates on one batch.
    """

    NEG_LOGIT = -100000.

    def __init__(self, hidden_size, num_entity_labels):
        super(CRFLayer, self).__init__()

        self.tag_size = num_entity_labels + 2  # add a start tag and an end tag
        self.start_tag = self.tag_size - 2
        self.end_tag = self.tag_size - 1

        # Map token-level hidden states into tag scores
        self.hidden2tag = nn.Linear(hidden_size, self.tag_size)
        # Transition Matrix
        # trans_mat[i, j] denotes the score of transitioning from tag j to tag i
        self.trans_mat = nn.Parameter(torch.randn(self.tag_size, self.tag_size))
        self.reset_trans_mat()

    def reset_trans_mat(self):
        nn.init.kaiming_uniform_(self.trans_mat, a=math.sqrt(5))  # copied from nn.Linear init
        # Set constraints that will not be updated during training but are important:
        # no transition into the start tag and no transition out of the end tag.
        self.trans_mat.data[self.start_tag, :] = self.NEG_LOGIT
        self.trans_mat.data[:, self.end_tag] = self.NEG_LOGIT

    def get_log_partition(self, seq_emit_score):
        """
        Calculate the log of the partition function
        :param seq_emit_score: [seq_len, batch_size, tag_size]
        :return: Tensor with Size([batch_size])
        """
        seq_len, batch_size, tag_size = seq_emit_score.size()
        # dynamic programming table to store previously summarized tag logits
        dp_table = seq_emit_score.new_full(
            (batch_size, tag_size), self.NEG_LOGIT, requires_grad=False
        )
        dp_table[:, self.start_tag] = 0.

        batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size)

        for token_idx in range(seq_len):
            prev_logit = dp_table.unsqueeze(1)  # [batch_size, 1, tag_size]
            batch_emit_score = seq_emit_score[token_idx].unsqueeze(-1)  # [batch_size, tag_size, 1]
            cur_logit = batch_trans_mat + batch_emit_score + prev_logit  # [batch_size, tag_size, tag_size]
            dp_table = log_sum_exp(cur_logit)  # [batch_size, tag_size]
        batch_logit = dp_table + self.trans_mat[self.end_tag, :].unsqueeze(0)
        log_partition = log_sum_exp(batch_logit)  # [batch_size]

        return log_partition

    def get_gold_score(self, seq_emit_score, seq_token_label):
        """
        Calculate the score of the given sequence label
        :param seq_emit_score: [seq_len, batch_size, tag_size]
        :param seq_token_label: [seq_len, batch_size]
        :return: Tensor with Size([batch_size])
        """
        seq_len, batch_size, tag_size = seq_emit_score.size()

        end_token_label = seq_token_label.new_full(
            (1, batch_size), self.end_tag, requires_grad=False
        )
        seq_cur_label = torch.cat(
            [seq_token_label, end_token_label], dim=0
        ).unsqueeze(-1).unsqueeze(-1).expand(seq_len + 1, batch_size, 1, tag_size)

        start_token_label = seq_token_label.new_full(
            (1, batch_size), self.start_tag, requires_grad=False
        )
        seq_prev_label = torch.cat(
            [start_token_label, seq_token_label], dim=0
        ).unsqueeze(-1).unsqueeze(-1)  # [seq_len+1, batch_size, 1, 1]

        seq_trans_score = self.trans_mat.unsqueeze(0).unsqueeze(0).expand(
            seq_len + 1, batch_size, tag_size, tag_size)
        # gather according to the token label at the current token
        gold_trans_score = torch.gather(seq_trans_score, 2, seq_cur_label)  # [seq_len+1, batch_size, 1, tag_size]
        # gather according to the token label at the previous token
        gold_trans_score = torch.gather(gold_trans_score, 3, seq_prev_label)  # [seq_len+1, batch_size, 1, 1]
        batch_trans_score = gold_trans_score.sum(dim=0).squeeze(-1).squeeze(-1)  # [batch_size]

        gold_emit_score = torch.gather(seq_emit_score, 2, seq_token_label.unsqueeze(-1))  # [seq_len, batch_size, 1]
        batch_emit_score = gold_emit_score.sum(dim=0).squeeze(-1)  # [batch_size]

        gold_score = batch_trans_score + batch_emit_score  # [batch_size]

        return gold_score
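
    # The two methods above compute the standard linear-chain CRF quantities:
    # the gold path score s(x, y) = sum_t [trans(y_t, y_{t-1}) + emit_t(y_t)]
    # (with y_0 = start_tag and y_{T+1} = end_tag), and the log-partition
    # log Z(x) via the forward recurrence
    #   alpha_t(i) = emit_t(i) + logsumexp_j(trans(i, j) + alpha_{t-1}(j)),
    # so the negative log-likelihood used in forward() is log Z(x) - s(x, y).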

    def viterbi_decode(self, seq_emit_score):
        """
        Use Viterbi decoding to get the prediction
        :param seq_emit_score: [seq_len, batch_size, tag_size]
        :return:
            batch_best_path: [batch_size, seq_len], the best tag for each token
            batch_best_score: [batch_size], the corresponding score for each path
        """
        seq_len, batch_size, tag_size = seq_emit_score.size()

        dp_table = seq_emit_score.new_full(
            (batch_size, tag_size), self.NEG_LOGIT, requires_grad=False
        )
        dp_table[:, self.start_tag] = 0
        backpointers = []

        for token_idx in range(seq_len):
            last_tag_score = dp_table.unsqueeze(-2)  # [batch_size, 1, tag_size]
            batch_trans_mat = self.trans_mat.unsqueeze(0).expand(batch_size, tag_size, tag_size)
            cur_emit_score = seq_emit_score[token_idx].unsqueeze(-1)  # [batch_size, tag_size, 1]
            cur_trans_score = batch_trans_mat + last_tag_score + cur_emit_score  # [batch_size, tag_size, tag_size]
            dp_table, cur_tag_bp = cur_trans_score.max(dim=-1)  # [batch_size, tag_size]
            backpointers.append(cur_tag_bp)
        # transition to the end tag
        last_trans_arr = self.trans_mat[self.end_tag].unsqueeze(0).expand(batch_size, tag_size)
        dp_table = dp_table + last_trans_arr

        # get the best path score and the best tag of the last token
        batch_best_score, best_tag = dp_table.max(dim=-1)  # [batch_size]
        best_tag = best_tag.unsqueeze(-1)  # [batch_size, 1]
        best_tag_list = [best_tag]
        # traverse the back pointers in reverse to recover the best path
        for last_tag_bp in reversed(backpointers):
            # best_tag, of Size([batch_size, 1]), records the current tag with the highest score;
            # last_tag_bp, of Size([batch_size, tag_size]), records the best previous tag
            # that each current tag is based on
            best_tag = torch.gather(last_tag_bp, 1, best_tag)  # [batch_size, 1]
            best_tag_list.append(best_tag)
        batch_start = best_tag_list.pop()
        assert (batch_start == self.start_tag).sum().item() == batch_size
        best_tag_list.reverse()
        batch_best_path = torch.cat(best_tag_list, dim=-1)  # [batch_size, seq_len]

        return batch_best_path, batch_best_score

    def forward(self, seq_token_emb, seq_token_label=None, batch_first=False,
                train_flag=True, decode_flag=True):
        """
        Get the loss and prediction with CRF support.
        :param seq_token_emb: assumed size [seq_len, batch_size, hidden_size] if not batch_first
        :param seq_token_label: assumed size [seq_len, batch_size] if not batch_first
        :param batch_first: flag to denote the meaning of the first dimension
        :param train_flag: whether to calculate the loss
        :param decode_flag: whether to decode the path based on the current parameters
        :return:
            nll_loss: negative log-likelihood loss
            seq_token_pred: sequence predictions
        """
        if batch_first:
            # the CRF assumes an input size of [seq_len, batch_size, hidden_size]
            seq_token_emb = seq_token_emb.transpose(0, 1).contiguous()
            if seq_token_label is not None:
                seq_token_label = seq_token_label.transpose(0, 1).contiguous()

        seq_emit_score = self.hidden2tag(seq_token_emb)  # [seq_len, batch_size, tag_size]
        if train_flag:
            gold_score = self.get_gold_score(seq_emit_score, seq_token_label)  # [batch_size]
            log_partition = self.get_log_partition(seq_emit_score)  # [batch_size]
            nll_loss = log_partition - gold_score
        else:
            nll_loss = None

        if decode_flag:
            # Use Viterbi decoding to get the current prediction;
            # no matter what batch_first is, the returned size is [batch_size, seq_len]
            batch_best_path, batch_best_score = self.viterbi_decode(seq_emit_score)
        else:
            batch_best_path = None

        return nll_loss, batch_best_path
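

# A minimal usage sketch for CRFLayer on its own, not part of the original
# code; all sizes below are hypothetical.
def _demo_crf_layer():
    crf = CRFLayer(hidden_size=32, num_entity_labels=5)
    seq_token_emb = torch.randn(2, 8, 32)  # [batch_size, seq_len, hidden_size]
    seq_token_label = torch.randint(0, 5, (2, 8))  # [batch_size, seq_len]
    nll_loss, batch_best_path = crf(
        seq_token_emb, seq_token_label=seq_token_label,
        batch_first=True, train_flag=True, decode_flag=True,
    )
    # nll_loss: [2] (per-example log Z - gold score), batch_best_path: [2, 8]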


# Compute log-sum-exp in a numerically stable way
def log_sum_exp(batch_logit):
    """
    Calculate the log-sum-exp operation for the last dimension.
    :param batch_logit: Size([*, logit_size]), * should be at least 1
    :return: Size([*])
    """
    batch_max, _ = batch_logit.max(dim=-1)
    batch_broadcast = batch_max.unsqueeze(-1)
    return batch_max + torch.log(torch.sum(torch.exp(batch_logit - batch_broadcast), dim=-1))


def produce_ner_batch_metrics(seq_logits, gold_labels, masks):
    # seq_logits: [batch_size, seq_len, num_entity_labels]
    # gold_labels: [batch_size, seq_len]
    # masks: [batch_size, seq_len]
    batch_size, seq_len, num_entities = seq_logits.size()

    # [batch_size, seq_len, num_entity_labels]
    seq_logp = F.log_softmax(seq_logits, dim=-1)
    # [batch_size, seq_len]
    pred_labels = seq_logp.argmax(dim=-1)
    # [batch_size*seq_len, num_entity_labels]
    token_logp = seq_logp.view(-1, num_entities)
    # [batch_size*seq_len]
    token_labels = gold_labels.view(-1)
    # [batch_size, seq_len]
    seq_token_loss = F.nll_loss(token_logp, token_labels, reduction='none').view(batch_size, seq_len)

    batch_metrics = []
    for bid in range(batch_size):
        ex_loss = seq_token_loss[bid, masks[bid]].mean().item()
        ex_acc = (pred_labels[bid, masks[bid]] == gold_labels[bid, masks[bid]]).float().mean().item()
        ex_pred_lids = pred_labels[bid, masks[bid]].tolist()
        ex_gold_lids = gold_labels[bid, masks[bid]].tolist()
        ner_tp_set, ner_fp_set, ner_fn_set = judge_ner_prediction(ex_pred_lids, ex_gold_lids)
        batch_metrics.append([ex_loss, ex_acc, len(ner_tp_set), len(ner_fp_set), len(ner_fn_set)])

    return torch.tensor(batch_metrics, dtype=torch.float, device=seq_logits.device)


def judge_ner_prediction(pred_label_ids, gold_label_ids):
    """Very strong assumption on label ids: 0 is 'others',
    an odd id marks an entity start, an even id marks an entity continuation."""
    if isinstance(pred_label_ids, torch.Tensor):
        pred_label_ids = pred_label_ids.tolist()
    if isinstance(gold_label_ids, torch.Tensor):
        gold_label_ids = gold_label_ids.tolist()
    # element: (ner_start_index, ner_end_index, ner_type_id)
    pred_ner_set = set()
    gold_ner_set = set()

    pred_ner_sid = None
    for idx, ner in enumerate(pred_label_ids):
        if pred_ner_sid is None:
            if ner % 2 == 1:
                pred_ner_sid = idx
            continue
        else:
            prev_ner = pred_label_ids[pred_ner_sid]
            if ner == 0:
                pred_ner_set.add((pred_ner_sid, idx, prev_ner))
                pred_ner_sid = None
                continue
            elif ner == prev_ner + 1:  # same entity
                continue
            elif ner % 2 == 1:
                pred_ner_set.add((pred_ner_sid, idx, prev_ner))
                pred_ner_sid = idx
                continue
            else:  # an invalid continuation tag: close the current entity and skip this token
                pred_ner_set.add((pred_ner_sid, idx, prev_ner))
                pred_ner_sid = None
    if pred_ner_sid is not None:
        prev_ner = pred_label_ids[pred_ner_sid]
        pred_ner_set.add((pred_ner_sid, len(pred_label_ids), prev_ner))

    gold_ner_sid = None
    for idx, ner in enumerate(gold_label_ids):
        if gold_ner_sid is None:
            if ner % 2 == 1:
                gold_ner_sid = idx
            continue
        else:
            prev_ner = gold_label_ids[gold_ner_sid]
            if ner == 0:
                gold_ner_set.add((gold_ner_sid, idx, prev_ner))
                gold_ner_sid = None
                continue
            elif ner == prev_ner + 1:  # same entity
                continue
            elif ner % 2 == 1:
                gold_ner_set.add((gold_ner_sid, idx, prev_ner))
                gold_ner_sid = idx
                continue
            else:  # an invalid continuation tag: close the current entity and skip this token
                gold_ner_set.add((gold_ner_sid, idx, prev_ner))
                gold_ner_sid = None
    if gold_ner_sid is not None:
        prev_ner = gold_label_ids[gold_ner_sid]
        gold_ner_set.add((gold_ner_sid, len(gold_label_ids), prev_ner))

    ner_tp_set = pred_ner_set.intersection(gold_ner_set)
    ner_fp_set = pred_ner_set - gold_ner_set
    ner_fn_set = gold_ner_set - pred_ner_set

    return ner_tp_set, ner_fp_set, ner_fn_set
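

# A worked example for judge_ner_prediction, not part of the original code.
# Under the label scheme above, id 2k+1 starts an entity of type k and
# id 2k+2 continues it; the inputs below are hypothetical.
def _demo_judge_ner_prediction():
    gold = [0, 1, 2, 2, 0, 3, 0]  # gold spans: (1, 4, 1) and (5, 6, 3)
    pred = [0, 1, 2, 0, 0, 3, 4]  # pred spans: (1, 3, 1) and (5, 7, 3)
    tp, fp, fn = judge_ner_prediction(pred, gold)
    assert tp == set()  # no span matches exactly on (start, end, type)
    assert fp == {(1, 3, 1), (5, 7, 3)}
    assert fn == {(1, 4, 1), (5, 6, 3)}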