# @Author : bamtercelboo
# @Datetime : 2018/9/14 9:51
# @File : CRF.py
# @Last Modify Time : 2018/9/14 9:51
# @Contact : bamtercelboo@{gmail.com, 163.com}

"""
    FILE : CRF.py
    FUNCTION : Batched linear-chain CRF layer: forward algorithm (partition
               function), Viterbi decoding and gold-sequence scoring.
    REFERENCE : https://github.com/jiesutd/NCRFpp/blob/master/model/crf.py
"""
import torch
import torch.nn as nn


def log_sum_exp(vec, m_size):
    """
    Numerically stable log-sum-exp over dim 1.

    Args:
        vec: size=(batch_size, vanishing_dim, hidden_dim)
        m_size: hidden_dim

    Returns:
        size=(batch_size, hidden_dim)
    """
    _, idx = torch.max(vec, 1)  # B * 1 * M
    # subtract the per-column max before exponentiating to avoid overflow
    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * M
    return max_score.view(-1, m_size) + torch.log(torch.sum(
        torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)


class CRF(nn.Module):
    """
    Linear-chain CRF with two extra states (START_TAG, STOP_TAG) appended to
    the label set, so the transition matrix is (target_size + 2) ** 2.

    Required kwargs:
        target_size: int, number of real labels (without START/STOP)
        device: str or torch.device, where the transition matrix lives
    """

    def __init__(self, **kwargs):
        super(CRF, self).__init__()
        # expose every kwarg as an attribute (expects at least
        # `target_size` and `device`)
        for k in kwargs:
            self.__setattr__(k, kwargs[k])
        device = self.device
        # START_TAG / STOP_TAG index the last two rows/columns of the
        # (target_size + 2)-sized transition matrix
        self.START_TAG, self.STOP_TAG = -2, -1
        init_transitions = torch.zeros(self.target_size + 2, self.target_size + 2, device=device)
        # forbid transitions into START and out of STOP
        init_transitions[:, self.START_TAG] = -10000.0
        init_transitions[self.STOP_TAG, :] = -10000.0
        self.transitions = nn.Parameter(init_transitions)

    def _forward_alg(self, feats, mask):
        """
        Do the forward algorithm to compute the partition function (batched).

        Args:
            feats: size=(batch_size, seq_len, target_size + 2)
            mask: 0/1 tensor, size=(batch_size, seq_len)

        Returns:
            (partition_sum, scores):
                partition_sum: scalar tensor, sum over the batch of log Z
                scores: size=(seq_len, batch_size, tag_size, tag_size),
                        emission + transition scores, reused by
                        _score_sentence
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        # (seq_len, batch_size) boolean mask
        mask = mask.bool().transpose(1, 0).contiguous()
        ins_num = seq_len * batch_size
        # be careful with the view shape: it is (ins_num, 1, tag_size),
        # NOT (ins_num, tag_size, 1) -- each emission row is broadcast
        # over all "from" tags
        feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
        # emission + transition, shape (seq_len, batch_size, from_tag, to_tag)
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
        # build iter
        seq_iter = enumerate(scores)
        _, inivalues = next(seq_iter)  # batch_size * from_target * to_target
        # only paths starting from START_TAG are valid at position 0
        partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size, 1)
        for idx, cur_values in seq_iter:
            # previous "to" tag is the current "from" tag;
            # partition: previous results log(sum(exp(...))), (batch, from_tag)
            # cur_values: (batch, from_tag, to_tag)
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            cur_partition = log_sum_exp(cur_values, tag_size)
            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
            # keep only the partition values at positions where mask == 1
            masked_cur_partition = cur_partition.masked_select(mask_idx)
            # make mask_idx broadcastable against partition
            mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
            # write back updated partitions only where mask == 1; padded
            # positions keep their previous value
            partition.masked_scatter_(mask_idx, masked_cur_partition)
        # until the last state: add the transition score toward every tag,
        # log_sum_exp, then select the STOP_TAG column
        cur_values = self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
        cur_partition = log_sum_exp(cur_values, tag_size)
        final_partition = cur_partition[:, self.STOP_TAG]
        return final_partition.sum(), scores

    def _viterbi_decode(self, feats, mask):
        """
        Viterbi decoding (batched).

        Args:
            feats: size=(batch_size, seq_len, target_size + 2)
            mask: 0/1 tensor, size=(batch_size, seq_len)

        Returns:
            path_score: None (scoring the best path is not implemented)
            decode_idx: LongTensor, size=(batch_size, seq_len), decoded
                        sequences; padded positions hold label 0 and are
                        expected to be filtered in post-processing
        """
        batch_size = feats.size(0)
        seq_len = feats.size(1)
        tag_size = feats.size(2)
        # sentence length for each sentence, shape (batch_size, 1)
        length_mask = torch.sum(mask.long(), dim=1).view(batch_size, 1).long()
        # (seq_len, batch_size) boolean mask
        mask = mask.bool().transpose(1, 0).contiguous()
        ins_num = seq_len * batch_size
        # be careful with the view shape: it is (ins_num, 1, tag_size),
        # NOT (ins_num, tag_size, 1)
        feats = feats.transpose(1, 0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
        # emission + transition, shape (seq_len, batch_size, from_tag, to_tag)
        scores = feats + self.transitions.view(1, tag_size, tag_size).expand(ins_num, tag_size, tag_size)
        scores = scores.view(seq_len, batch_size, tag_size, tag_size)

        seq_iter = enumerate(scores)
        # record the position of the best score at each step
        back_points = list()
        partition_history = list()
        # inverted mask: True at padded positions
        pad_mask = ~mask
        _, inivalues = next(seq_iter)  # batch_size * from_target * to_target
        # only paths starting from START_TAG are valid at position 0
        partition = inivalues[:, self.START_TAG, :].clone().view(batch_size, tag_size)
        partition_history.append(partition)
        for idx, cur_values in seq_iter:
            # previous "to" tag is the current "from" tag
            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
            partition, cur_bp = torch.max(cur_values, 1)
            partition_history.append(partition)
            # cur_bp: (batch_size, tag_size), best previous tag for each
            # current tag; padded positions get label 0 (filtered later)
            cur_bp.masked_fill_(pad_mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
            back_points.append(cur_bp)
        # (batch_size, seq_len, tag_size)
        partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1, 0).contiguous()
        # gather the partition at the last real token of each sentence
        last_position = length_mask.view(batch_size, 1, 1).expand(batch_size, 1, tag_size) - 1
        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size, tag_size, 1)
        # add the transition toward the end state and pick the best last tag
        last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1, tag_size, tag_size).expand(batch_size, tag_size, tag_size)
        _, last_bp = torch.max(last_values, 1)
        # pad the back-pointer list to seq_len entries; indices carry no
        # gradient, so build them directly as long tensors
        pad_zero = torch.zeros(batch_size, tag_size, dtype=torch.long, device=self.device)
        back_points.append(pad_zero)
        back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size)
        # select end ids in STOP_TAG
        pointer = last_bp[:, self.STOP_TAG]
        insert_last = pointer.contiguous().view(batch_size, 1, 1).expand(batch_size, 1, tag_size)
        back_points = back_points.transpose(1, 0).contiguous()
        # move the end ids (expanded to tag_size) to each sentence's last
        # position in back_points, replacing the 0 values there
        back_points.scatter_(1, last_position, insert_last)
        back_points = back_points.transpose(1, 0).contiguous()
        # decode from the end; padded position ids are 0 and will be
        # filtered in the following evaluation
        decode_idx = torch.empty(seq_len, batch_size, dtype=torch.long, device=self.device)
        decode_idx[-1] = pointer.detach()
        for idx in range(len(back_points) - 2, -1, -1):
            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
            decode_idx[idx] = pointer.detach().view(batch_size)
        path_score = None
        decode_idx = decode_idx.transpose(1, 0)
        return path_score, decode_idx

    def forward(self, feats, mask):
        """
        Decode the best tag sequence with Viterbi.

        Args:
            feats: size=(batch_size, seq_len, target_size + 2)
            mask: 0/1 tensor, size=(batch_size, seq_len)

        Returns:
            (path_score, best_path) as produced by _viterbi_decode
        """
        path_score, best_path = self._viterbi_decode(feats, mask)
        return path_score, best_path

    def _score_sentence(self, scores, mask, tags):
        """
        Score the gold tag sequences against precomputed scores.

        Args:
            scores: size=(seq_len, batch_size, tag_size, tag_size),
                    as returned by _forward_alg
            mask: 0/1 tensor, size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)

        Returns:
            scalar tensor: sum over the batch of gold path scores
        """
        batch_size = scores.size(1)
        seq_len = scores.size(0)
        tag_size = scores.size(-1)
        tags = tags.view(batch_size, seq_len)
        # encode each label bigram (prev, cur) as a single index
        # prev * tag_size + cur into the flattened (from, to) score matrix
        new_tags = torch.empty(batch_size, seq_len, dtype=torch.long, device=self.device)
        for idx in range(seq_len):
            if idx == 0:
                # the first bigram starts from START_TAG (= tag_size - 2)
                new_tags[:, 0] = (tag_size - 2) * tag_size + tags[:, 0]
            else:
                new_tags[:, idx] = tags[:, idx - 1] * tag_size + tags[:, idx]
        # transition score from every label to STOP_TAG
        end_transition = self.transitions[:, self.STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
        # length per sentence; last word position = length - 1
        length_mask = torch.sum(mask, dim=1).view(batch_size, 1).long()
        # label id of the last real word of each sentence
        end_ids = torch.gather(tags, 1, length_mask - 1)
        # transition score end_id -> STOP_TAG
        end_energy = torch.gather(end_transition, 1, end_ids)
        # convert tags to (seq_len, batch_size, 1) bigram indices
        new_tags = new_tags.transpose(1, 0).contiguous().view(seq_len, batch_size, 1)
        # gather bigram scores from the flattened tag_size*tag_size axis
        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)
        tg_energy = tg_energy.masked_select(mask.bool().transpose(1, 0))
        # gold_score = bigram (emission + transition) terms + end transition
        gold_score = tg_energy.sum() + end_energy.sum()
        return gold_score

    def neg_log_likelihood_loss(self, feats, mask, tags):
        """
        Negative log-likelihood of the gold tag sequences.

        Args:
            feats: size=(batch_size, seq_len, target_size + 2)
            mask: 0/1 tensor, size=(batch_size, seq_len)
            tags: size=(batch_size, seq_len)

        Returns:
            scalar tensor: sum over the batch of (log Z - gold path score)
        """
        forward_score, scores = self._forward_alg(feats, mask)
        gold_score = self._score_sentence(scores, mask, tags)
        return forward_score - gold_score