import torch
import torch.nn as nn

from src import text
from pytorch_pretrained_bert import BertForMaskedLM
from pytorch_pretrained_bert.modeling import BertOnlyMLMHead


class BertLikeSentencePieceTextEncoder(object):
    """Wraps a SentencePiece subword encoder with BERT's special tokens.

    The three special tokens [CLS], [SEP] and [MASK] are appended at the
    end of the subword vocabulary, in that order.
    """

    def __init__(self, text_encoder):
        if not isinstance(text_encoder, text.SubwordTextEncoder):
            raise TypeError(
                "`text_encoder` must be an instance of "
                "`src.text.SubwordTextEncoder`.")
        self.text_encoder = text_encoder

    @property
    def vocab_size(self):
        # +3 accounts for [CLS], [SEP] and [MASK].
        return self.text_encoder.vocab_size + 3

    @property
    def cls_idx(self):
        return self.vocab_size - 3

    @property
    def sep_idx(self):
        return self.vocab_size - 2

    @property
    def mask_idx(self):
        return self.vocab_size - 1

    @property
    def eos_idx(self):
        return self.text_encoder.eos_idx


def generate_embedding(bert_model, labels):
    """Generate BERT embeddings for `labels` using the fine-tuned model.

    Args:
        bert_model: a model returned by `load_fine_tuned_model`.
        labels: LongTensor of token ids with shape (batch, time).

    Returns:
        Tensor of shape (batch, time, hidden_size).
    """
    batch_size, _ = labels.shape

    # Prepend [CLS] to every sequence, as BERT expects.
    cls_ids = torch.full(
        (batch_size, 1),
        bert_model.bert_text_encoder.cls_idx,
        dtype=labels.dtype,
        device=labels.device)
    bert_labels = torch.cat([cls_ids, labels], 1)

    # Replace the subword encoder's <eos> with BERT's [SEP].
    eos_idx = bert_model.bert_text_encoder.eos_idx
    sep_idx = bert_model.bert_text_encoder.sep_idx
    bert_labels[bert_labels == eos_idx] = sep_idx

    embedding, _ = bert_model.bert(
        bert_labels, output_all_encoded_layers=True)
    # Sum the hidden states over all encoder layers.
    embedding = torch.stack(embedding).sum(0)
    # Drop the [CLS] position so the output aligns with `labels`.
    embedding = embedding[:, 1:]
    assert labels.shape == embedding.shape[:-1]

    return embedding


def load_fine_tuned_model(bert_model, text_encoder, path):
    """Load a fine-tuned BERT model given a text encoder and checkpoint path."""
    bert_text_encoder = BertLikeSentencePieceTextEncoder(text_encoder)

    model = BertForMaskedLM.from_pretrained(bert_model)
    model.bert_text_encoder = bert_text_encoder
    # Resize the input embeddings and the tied MLM head to match the new
    # vocabulary (subwords plus [CLS], [SEP] and [MASK]).
    model.bert.embeddings.word_embeddings = nn.Embedding(
        bert_text_encoder.vocab_size,
        model.bert.embeddings.word_embeddings.weight.shape[1])
    model.config.vocab_size = bert_text_encoder.vocab_size
    model.cls = BertOnlyMLMHead(
        model.config, model.bert.embeddings.word_embeddings.weight)
    model.load_state_dict(torch.load(path))

    return model


class BertEmbeddingPredictor(nn.Module):
    """Fine-tuned BERT used purely as an embedding extractor."""

    def __init__(self, bert_model, text_encoder, path):
        super(BertEmbeddingPredictor, self).__init__()
        self.model = load_fine_tuned_model(bert_model, text_encoder, path)

    def forward(self, labels):
        # Do not modify this: the model must stay in eval mode so dropout
        # is disabled while extracting embeddings.
        self.eval()
        return generate_embedding(self.model, labels)
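
# ---------------------------------------------------------------------------
# Minimal usage sketch. Everything below is illustrative: the encoder
# construction assumes `src.text` provides some way to build a
# `SubwordTextEncoder` from a trained SentencePiece model (the
# `load_from_file` name is a guess), and both file paths are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical: build the subword encoder the checkpoint was trained with.
    encoder = text.SubwordTextEncoder.load_from_file("data/subword.model")

    predictor = BertEmbeddingPredictor(
        bert_model="bert-base-uncased",       # any pretrained BERT name
        text_encoder=encoder,
        path="ckpt/bert_fine_tuned.pth")      # hypothetical checkpoint path

    # Dummy batch of subword ids, shape (batch, time).
    labels = torch.randint(0, encoder.vocab_size, (2, 10), dtype=torch.long)
    with torch.no_grad():
        embedding = predictor(labels)
    print(embedding.shape)  # -> torch.Size([2, 10, 768]) for bert-base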