# -*- coding: utf-8 -*-
# Created by li huayong on 2019/9/28
import os
import re
import tempfile
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from tqdm import tqdm, trange

from utils.data.conll_file import CoNLLFile
from utils.data.graph_vocab import GraphVocab
from utils.model.get_optimizer import get_optimizer
from utils.model.parser_funs import sdp_decoder, parse_semgraph
import utils.model.sdp_simple_scorer as sdp_scorer
from utils.best_result import BestResult
from utils.model.label_smoothing import label_smoothed_kl_div_loss
from PyToolkit.PyToolkit import get_logger
from PyToolkit.PyToolkit.seed import set_seed

try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter


class BaseDependencyTrainer(metaclass=ABCMeta):
    """Abstract trainer for a dependency-graph parser that scores arcs and labels.

    Subclasses must implement ``_unpack_batch`` to translate a raw batch into
    the dictionary consumed by ``train``, ``dev`` and ``inference``.
    """

    def __init__(self, args, model):
        """Store the model and configuration and build the label vocabulary.

        Args:
            args: configuration object; this class reads (among others)
                ``graph_vocab_file`` and ``log_name`` from it.
            model: the network to train; it is called with the ``'inputs'``
                entry of an unpacked batch.
        """
        self.model = model
        # Both are created lazily in train() via get_optimizer().
        self.optimizer = self.optim_scheduler = None
        self.graph_vocab = GraphVocab(args.graph_vocab_file)
        self.configs = args
        self.logger = get_logger(args.log_name)

    @abstractmethod
    def _unpack_batch(self, batch: TensorDataset) -> Dict:
        """Split a batch into the encoder inputs and the auxiliary tensors.

        Example of the dataset a batch is drawn from:
            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                    all_start_pos, all_end_pos, all_dep_ids, all_pos_ids)

        Args:
            batch: a single batch (TensorDataset-like, or a torchtext dataset)
                whose fields can be retrieved by index.

        Returns:
            A dict. The callers in this class read the keys ``'inputs'``
            (passed to the model), ``'word_mask'`` (word-level, non-zero for
            real input and 0 for PAD), ``'sent_len'`` (sentence lengths, a
            Python list), ``'dep_ids'`` and, when POS prediction is enabled,
            ``'pos_ids'``. Other entries may be included as needed.
        """
        raise NotImplementedError('must implement in sub class')

    def _custom_train_operations(self, epoch: int):
        """Hook for model-specific work at the start of every training epoch.

        Some models need customised behaviour during training; for example a
        BERT-style model may dynamically freeze certain layers. This hook
        supports that without making the generic trainer model-specific.
        Subclasses may override it; it is called once at the beginning of
        each epoch and does nothing by default.

        Args:
            epoch: the 1-based index of the epoch about to start.
        """
        pass

    def _update_and_predict(self, unlabeled_scores, labeled_scores, unlabeled_target, labeled_target, word_pad_mask,
                            label_loss_ratio=0.5, sentence_lengths=None,
                            calc_loss=True, update=True, calc_prediction=False,
                            pos_logits=None, pos_target=None, pos_loss_ratio=1.0,
                            summary_writer=None, global_step=None):
        """Process one batch: compute the loss, back-propagate, and/or decode.

        Args:
            unlabeled_scores: arc logits; presumably
                (batch, max_seq_len, max_seq_len) -- TODO confirm.
            labeled_scores: label logits; 4-D with the label dimension last
                (softmax is taken over dim 3).
            unlabeled_target: arc targets; required when ``calc_loss`` is True.
            labeled_target: label-id targets, where id 0 means "no dependency";
                required when ``calc_loss`` is True.
            word_pad_mask: word-level boolean mask, True for PAD and False for
                real input.
            label_loss_ratio: interpolation weight between the arc loss and
                the label loss; must be truthy (non-zero) when computing loss.
            sentence_lengths: sentence lengths, required when
                ``calc_prediction`` is True.
            calc_loss: whether to compute the loss.
            update: whether to back-propagate and step the optimizer (only
                honoured when ``calc_loss`` is True).
            calc_prediction: whether to decode predictions.
            pos_logits: POS logits; required when ``configs.use_pos`` is set.
            pos_target: POS targets; required when ``configs.use_pos`` is set.
            pos_loss_ratio: weight of the POS loss in the total loss.
            summary_writer: optional TensorBoard writer for per-component losses.
            global_step: step index used when logging to ``summary_writer``.

        Returns:
            ``(loss, batch_prediction)``: ``loss`` is a Python float or None
            when ``calc_loss`` is False; ``batch_prediction`` is the output of
            ``graph_vocab.parse_to_sent_batch`` or None when
            ``calc_prediction`` is False.
        """
        weights = torch.ones(word_pad_mask.size(0),
                             self.configs.max_seq_len,
                             self.configs.max_seq_len,
                             dtype=unlabeled_scores.dtype,
                             device=unlabeled_scores.device)
        # Zero the weight of every (head, dependent) pair that touches a PAD
        # position; all other pairs keep weight 1.
        weights = weights.masked_fill(word_pad_mask.unsqueeze(1), 0)
        weights = weights.masked_fill(word_pad_mask.unsqueeze(2), 0)

        if calc_loss:
            assert label_loss_ratio
            assert unlabeled_target is not None and labeled_target is not None
            dep_arc_loss_func = nn.BCEWithLogitsLoss(weight=weights, reduction='sum')
            dep_arc_loss = dep_arc_loss_func(unlabeled_scores, unlabeled_target)

            dep_label_loss_func = nn.CrossEntropyLoss(ignore_index=-1, reduction='sum')
            # Label id 0 means "no arc": remap it to the ignore index so only
            # real dependencies contribute to the label loss.
            dependency_mask = labeled_target.eq(0)
            labeled_target = labeled_target.masked_fill(dependency_mask, -1)
            # Keep the flattened view in its own variable: the 4-D
            # ``labeled_scores`` is still needed by the decoding branch below.
            flat_labeled_scores = labeled_scores.contiguous().view(-1, len(self.graph_vocab.get_labels()))
            dep_label_loss = dep_label_loss_func(flat_labeled_scores, labeled_target.view(-1))

            pos_loss = None
            if self.configs.use_pos:
                assert pos_logits is not None
                assert pos_target is not None
                pos_loss_func = nn.CrossEntropyLoss(ignore_index=self.configs.pos_label_pad_idx)
                pos_loss = pos_loss_func(pos_logits.view(-1, self.configs.pos_label_num), pos_target.view(-1))

            loss = 2 * ((1 - label_loss_ratio) * dep_arc_loss + label_loss_ratio * dep_label_loss)
            if self.configs.use_pos:
                loss = loss + pos_loss_ratio * pos_loss

            if self.configs.average_loss_by_words_num:
                # Number of real (non-PAD) words in the batch.
                words_num = torch.sum(torch.eq(word_pad_mask, False)).item()
                loss = loss / words_num
            if self.configs.scale_loss:
                loss = loss * self.configs.loss_scaling_ratio
            if self.configs.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training

            # Compare against None so that step 0 is logged as well.
            if summary_writer is not None and global_step is not None:
                summary_writer.add_scalar('train_loss/dep_arc_loss', dep_arc_loss, global_step)
                summary_writer.add_scalar('train_loss/dep_label_loss', dep_label_loss, global_step)
                if self.configs.use_pos:
                    summary_writer.add_scalar('train_loss/pos_loss', pos_loss, global_step)

            if update:
                # Clear gradients BEFORE loss.backward(). model.zero_grad() is
                # used instead of optimizer.zero_grad() because it clears every
                # model parameter, even ones not registered with this
                # optimizer (the two are equivalent when all parameters are).
                # Ref: https://discuss.pytorch.org/t/model-zero-grad-or-optimizer-zero-grad/28426/5
                # Ref: https://discuss.pytorch.org/t/zero-grad-optimizer-or-net/1887/9
                self.model.zero_grad()
                loss.backward()
                if self.configs.max_grad_norm > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.configs.max_grad_norm)
                self.optimizer.step()
                if self.optim_scheduler:
                    self.optim_scheduler.step()  # Update learning rate schedule
            loss = loss.detach().cpu().item()
        else:
            loss = None

        if calc_prediction:
            assert sentence_lengths
            # Joint probability of (arc exists) x (label), zeroed on PAD pairs.
            pair_weights = weights.unsqueeze(3)
            head_probs = torch.sigmoid(unlabeled_scores).unsqueeze(3)
            label_probs = torch.softmax(labeled_scores, dim=3)
            batch_probs = head_probs * label_probs * pair_weights
            batch_probs = batch_probs.detach().cpu().numpy()
            sem_graph = sdp_decoder(batch_probs, sentence_lengths)
            sem_sents = parse_semgraph(sem_graph, sentence_lengths)
            batch_prediction = self.graph_vocab.parse_to_sent_batch(sem_sents)
        else:
            batch_prediction = None
        return loss, batch_prediction

    def train(self, train_data_loader, dev_data_loader=None, dev_CoNLLU_file=None):
        """Run the full training loop with periodic evaluation and checkpointing.

        Args:
            train_data_loader: iterable of training batches (must support ``len``).
            dev_data_loader: optional iterable of development batches; when
                given, ``dev`` is run every ``configs.eval_interval`` steps.
            dev_CoNLLU_file: the ``CoNLLFile`` matching ``dev_data_loader``.
        """
        self.optimizer, self.optim_scheduler = get_optimizer(self.configs, self.model)
        global_step = 0
        best_result = BestResult()
        self.model.zero_grad()
        set_seed(self.configs)  # Added here for reproducibility (even between python 2 and 3)
        train_stop = False
        is_main_process = self.configs.local_rank in [-1, 0]

        # Always bind the name so the code below can test it against None;
        # a writer only exists on the main process with output enabled.
        summary_writer = None
        if is_main_process and not self.configs.no_output:
            summary_writer = SummaryWriter(log_dir=self.configs.summary_dir)

        for epoch in range(1, self.configs.max_train_epochs + 1):
            epoch_ave_loss = 0
            # Wrap into a fresh local each epoch; rebinding train_data_loader
            # itself would stack one tqdm wrapper per epoch.
            epoch_iterator = tqdm(train_data_loader,
                                  desc=f'Training epoch {epoch}',
                                  disable=not is_main_process)
            # Model-specific per-epoch hook; does nothing by default.
            self._custom_train_operations(epoch)
            for step, batch in enumerate(epoch_iterator):
                batch = tuple(t.to(self.configs.device) for t in batch)
                # dev() switches the model to eval mode, so re-enable training
                # mode on every step.
                self.model.train()
                # word_mask: word-level, 1 for real input, 0 for PAD.
                unpacked_batch = self._unpack_batch(batch)
                # word_pad_mask: word-level, True for PAD, False for real input.
                word_pad_mask = torch.eq(unpacked_batch['word_mask'], 0)
                model_output = self.model(unpacked_batch['inputs'])
                unlabeled_scores = model_output['unlabeled_scores']
                labeled_scores = model_output['labeled_scores']
                labeled_target = unpacked_batch['dep_ids']
                # Any label id >= 1 marks an existing arc.
                unlabeled_target = labeled_target.ge(1).to(unlabeled_scores.dtype)
                if self.configs.use_pos:
                    pos_logits = model_output['pos_logits']
                    pos_target = unpacked_batch['pos_ids']
                else:
                    pos_target = pos_logits = None

                # Calc loss and update (the default label_loss_ratio is used).
                loss, _ = self._update_and_predict(unlabeled_scores, labeled_scores,
                                                   unlabeled_target, labeled_target,
                                                   word_pad_mask,
                                                   calc_loss=True, update=True, calc_prediction=False,
                                                   pos_logits=pos_logits, pos_target=pos_target,
                                                   summary_writer=summary_writer,
                                                   global_step=global_step)
                global_step += 1
                if loss is not None:
                    epoch_ave_loss += loss

                if global_step % self.configs.eval_interval == 0 and is_main_process:
                    if summary_writer is not None:
                        summary_writer.add_scalar('train_loss/loss', loss, global_step)
                        # Record the learning rate of every parameter group.
                        for i, param_group in enumerate(self.optimizer.param_groups):
                            summary_writer.add_scalar(f'lr/group_{i}', param_group['lr'], global_step)
                    if dev_data_loader:
                        UAS, LAS = self.dev(dev_data_loader, dev_CoNLLU_file)
                        if summary_writer is not None:
                            summary_writer.add_scalar('metrics/uas', UAS, global_step)
                            summary_writer.add_scalar('metrics/las', LAS, global_step)
                        if best_result.is_new_record(LAS=LAS, UAS=UAS, epoch=epoch):
                            self.logger.info(f"\n## NEW BEST RESULT in epoch {epoch} ##")
                            self.logger.info('\n' + str(best_result))
                            # Save the best model so far.
                            if not self.configs.no_output:
                                # Under multi-GPU wrapping (torch.nn.DataParallel)
                                # the real model lives in ``.module``.
                                model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
                                model_to_save.save_pretrained(self.configs.output_model_dir)
                        # Early stop is only checked right after an evaluation,
                        # the only point where best_result can change. It is
                        # disabled under torch.distributed (local_rank != -1).
                        # todo fix bug [bug when use torch.distributed.launch !!]
                        if self.configs.early_stop \
                                and epoch - best_result.best_LAS_epoch > self.configs.early_stop_epochs \
                                and self.configs.local_rank == -1:
                            self.logger.info(f'\n## Early stop in step:{global_step} ##')
                            train_stop = True
                            break
            if train_stop:
                break
            if summary_writer is not None:
                summary_writer.add_scalar('epoch_loss', epoch_ave_loss / len(train_data_loader), epoch)

        if is_main_process and not self.configs.no_output:
            with open(self.configs.dev_result_path, 'w', encoding='utf-8') as f:
                f.write(str(best_result) + '\n')
        if is_main_process:
            self.logger.info("\n## BEST RESULT in Training ##")
            self.logger.info('\n' + str(best_result))
        if summary_writer is not None:
            summary_writer.close()

    def _predict_batch(self, batch):
        """Decode predictions for one batch in eval mode without gradients.

        Args:
            batch: iterable of tensors, as yielded by a data loader.

        Returns:
            The batch prediction produced by ``_update_and_predict``.
        """
        self.model.eval()
        batch = tuple(t.to(self.configs.device) for t in batch)
        unpacked_batch = self._unpack_batch(batch)
        # Convert word_mask (non-zero = real) into a pad mask (True = PAD).
        word_pad_mask = torch.eq(unpacked_batch['word_mask'], 0)
        try:
            # The forward pass is inside no_grad too, so no autograd graph is
            # built during evaluation.
            with torch.no_grad():
                model_output = self.model(unpacked_batch['inputs'])
                _, batch_prediction = self._update_and_predict(model_output['unlabeled_scores'],
                                                               model_output['labeled_scores'],
                                                               None, None, word_pad_mask,
                                                               sentence_lengths=unpacked_batch['sent_len'],
                                                               calc_loss=False, update=False,
                                                               calc_prediction=True)
        except Exception:
            # Debug aid: dump the tensor shapes of the failing batch, then
            # re-raise with the original traceback.
            for t in batch:
                self.logger.error(f'batch tensor shape: {t.shape}')
            raise
        return batch_prediction

    def dev(self, dev_data_loader, dev_CoNLLU_file, input_conllu_path=None, output_conllu_path=None):
        """Evaluate the model on the development set.

        Args:
            dev_data_loader: iterable of development batches.
            dev_CoNLLU_file: ``CoNLLFile`` that receives the predicted ``deps``.
            input_conllu_path: gold CoNLL-U path; defaults to
                ``configs.data_dir / configs.dev_file``.
            output_conllu_path: where to write predictions; defaults to
                ``configs.dev_output_path``, or to a temporary file (deleted
                after scoring) when ``configs.no_output`` is set.

        Returns:
            ``(UAS, LAS)`` as returned by ``sdp_scorer.score``.

        Raises:
            RuntimeError: if ``dev_CoNLLU_file`` is not a ``CoNLLFile``.
        """
        if not isinstance(dev_CoNLLU_file, CoNLLFile):
            raise RuntimeError(f'dev_conllu_file type:{type(dev_CoNLLU_file)}')
        if input_conllu_path is None:
            input_conllu_path = os.path.join(self.configs.data_dir, self.configs.dev_file)
        if output_conllu_path is None:
            output_conllu_path = self.configs.dev_output_path if not self.configs.no_output else None

        predictions = []
        for batch in tqdm(dev_data_loader, desc='Evaluation'):
            predictions += self._predict_batch(batch)
        dev_CoNLLU_file.set(['deps'], [dep for sent in predictions for dep in sent])

        if output_conllu_path:
            dev_CoNLLU_file.write_conll(output_conllu_path)
            UAS, LAS = sdp_scorer.score(output_conllu_path, input_conllu_path)
            return UAS, LAS

        # No persistent output requested: the scorer still needs a file, so
        # write the predictions to a temporary one and remove it afterwards.
        fd, tmp_conllu_path = tempfile.mkstemp(suffix='.conllu')
        os.close(fd)
        try:
            dev_CoNLLU_file.write_conll(tmp_conllu_path)
            UAS, LAS = sdp_scorer.score(tmp_conllu_path, input_conllu_path)
        finally:
            os.remove(tmp_conllu_path)
        return UAS, LAS

    def inference(self, inference_data_loader, inference_CoNLLU_file, output_conllu_path):
        """Predict dependencies for unlabeled data and write them to disk.

        Args:
            inference_data_loader: iterable of batches to parse.
            inference_CoNLLU_file: CoNLL file object that receives the
                predicted ``deps``.
            output_conllu_path: path the predictions are written to.

        Returns:
            The list of per-sentence predictions, in loader order.
        """
        predictions = []
        for batch in tqdm(inference_data_loader, desc='Inference'):
            predictions += self._predict_batch(batch)
        inference_CoNLLU_file.set(['deps'], [dep for sent in predictions for dep in sent])
        inference_CoNLLU_file.write_conll(output_conllu_path)
        return predictions


class TransformerBaseTrainer(BaseDependencyTrainer):
    """Trainer stub for transformer encoders; batch unpacking is not implemented yet."""

    def _unpack_batch(self, batch):
        # Placeholder: returns None, so this class cannot yet drive
        # train()/dev()/inference().
        pass


# class CharRNNBaseTrainer(BaseDependencyTrainer):
#     pass


if __name__ == '__main__':
    pass