import torch
import torch.nn as nn
from transformers import BertModel, BertTokenizer
# from .base_encoder import BaseEncoder
import pickle as pk


class BERTEncoder(nn.Module):
    def __init__(self, max_length, pretrain_path, blank_padding=True):
        """
        Args:
            max_length: max length of sentence
            pretrain_path: path of pretrain model
        """
        super().__init__()
        self.max_length = max_length
        self.blank_padding = blank_padding
        self.bert = BertModel.from_pretrained(pretrain_path)
        self.hidden_size = self.bert.config.hidden_size
        self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)

    def forward(self, token, att_mask):
        """
        Args:
            token: (B, L), index of tokens
            att_mask: (B, L), attention mask (1 for contents and 0 for padding)
        Return:
            (B, H), representations for sentences
        """
        _, x = self.bert(token, attention_mask=att_mask)
        return x

    def tokenize(self, item):
        """
        Args:
            item: data instance containing 'text' / 'token', 'h' and 't'
        Return:
            indexed_tokens (1, L) and att_mask (1, L) for the sentence, with the
            head/tail entity spans wrapped in the [unused0-3] marker tokens
        """
        # Sentence -> token
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        if not is_token:
            # Character-level spans: split the raw text around both entities,
            # tokenize each piece, and convert the spans to token-level positions.
            pos_min = pos_head
            pos_max = pos_tail
            if pos_head[0] > pos_tail[0]:
                pos_min = pos_tail
                pos_max = pos_head
                rev = True
            else:
                rev = False
            sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
            pos_head = [len(sent0), len(sent0) + len(ent0)]
            pos_tail = [
                len(sent0) + len(ent0) + len(sent1),
                len(sent0) + len(ent0) + len(sent1) + len(ent1)
            ]
            if rev:
                pos_tail = [len(sent0), len(sent0) + len(ent0)]
                pos_head = [
                    len(sent0) + len(ent0) + len(sent1),
                    len(sent0) + len(ent0) + len(sent1) + len(ent1)
                ]
            tokens = sent0 + ent0 + sent1 + ent1 + sent2
        else:
            tokens = sentence

        # Token -> index
        re_tokens = ['[CLS]']
        cur_pos = 0
        for token in tokens:
            token = token.lower()
            if cur_pos == pos_head[0]:
                re_tokens.append('[unused0]')
            if cur_pos == pos_tail[0]:
                re_tokens.append('[unused1]')
            re_tokens += self.tokenizer.tokenize(token)
            if cur_pos == pos_head[1] - 1:
                re_tokens.append('[unused2]')
            if cur_pos == pos_tail[1] - 1:
                re_tokens.append('[unused3]')
            cur_pos += 1
        re_tokens.append('[SEP]')
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        avai_len = len(indexed_tokens)

        # Padding
        if self.blank_padding:
            while len(indexed_tokens) < self.max_length:
                indexed_tokens.append(0)  # 0 is id for [PAD]
            indexed_tokens = indexed_tokens[:self.max_length]
        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

        # Attention mask
        att_mask = torch.zeros(indexed_tokens.size()).long()  # (1, L)
        att_mask[0, :avai_len] = 1

        return indexed_tokens, att_mask
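

# Usage sketch for BERTEncoder (illustrative only, never called on import). It
# assumes a BERT checkpoint such as 'bert-base-uncased' is reachable from
# `pretrain_path` and the tuple-return transformers API this module targets.
def _bert_encoder_example(pretrain_path='bert-base-uncased'):
    encoder = BERTEncoder(max_length=80, pretrain_path=pretrain_path)
    item = {
        'token': ['Bill', 'Gates', 'founded', 'Microsoft', '.'],
        'h': {'pos': [0, 2]},  # head entity span: 'Bill Gates'
        't': {'pos': [3, 4]},  # tail entity span: 'Microsoft'
    }
    indexed_tokens, att_mask = encoder.tokenize(item)  # both (1, max_length)
    # The head span is wrapped with [unused0]/[unused2] and the tail span with
    # [unused1]/[unused3]; print the unpadded part to see the markers.
    length = int(att_mask.sum())
    print(encoder.tokenizer.convert_ids_to_tokens(indexed_tokens[0, :length].tolist()))
    rep = encoder(indexed_tokens, att_mask)  # (1, hidden_size) pooled sentence vector
    return rep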


class BERTHiddenStateEncoder(BERTEncoder):
    def __init__(self, pretrain_path, blank_padding=True):
        super().__init__(80, pretrain_path, blank_padding)  # fixed max_length of 80
        self.bert = BertModel.from_pretrained(
            pretrain_path, output_hidden_states=True, output_attentions=True)

    def forward(self, token, att_mask):
        """
        Args:
            token: (B, L), index of tokens
            att_mask: (B, L), attention mask (1 for contents and 0 for padding)
        Return:
            x: (B, H), pooled representations for sentences
            hs: tuple of per-layer hidden states, each (B, L, H)
            atts: tuple of per-layer attention maps, each (B, num_heads, L, L)
        """
        _, x, hs, atts = self.bert(token, attention_mask=att_mask)
        return x, hs, atts

    def tokenize(self, item):
        """
        Args:
            item: data instance containing 'token' (list of tokens)
        Return:
            indexed_tokens (1, L), att_mask (1, L) and new_index, which maps
            every wordpiece position back to the index of its original token
        """
        # Sentence -> token
        # if 'text' in item:
        #     sentence = item['text']
        #     is_token = False
        # else:
        sentence = item['token']
        is_token = True
        tokens = sentence

        re_tokens = ['[CLS]']
        cur_pos = 0
        new_index = [-1]  # [CLS] has no source token
        for idx, token in enumerate(tokens):
            # token = token.lower()
            re_tokens += self.tokenizer.tokenize(token)
            # every wordpiece produced by this token maps back to index idx
            new_index.extend([idx] * (len(re_tokens) - len(new_index)))
            cur_pos += 1
        re_tokens.append('[SEP]')
        new_index.append(max(new_index) + 1)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

        # Attention mask
        att_mask = torch.ones(indexed_tokens.size()).long()  # (1, L)

        return indexed_tokens, att_mask, new_index
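

# Sketch of the wordpiece/token alignment produced by BERTHiddenStateEncoder
# (illustrative only; assumes a loadable BERT checkpoint). new_index[i] holds
# the index of the original token that produced wordpiece i, with -1 for [CLS]
# and one past the last token index for the trailing [SEP].
def _hidden_state_encoder_example(pretrain_path='bert-base-uncased'):
    encoder = BERTHiddenStateEncoder(pretrain_path)
    item = {'token': ['Transformers', 'tokenize', 'subwords']}
    indexed_tokens, att_mask, new_index = encoder.tokenize(item)
    wordpieces = encoder.tokenizer.convert_ids_to_tokens(indexed_tokens[0].tolist())
    for piece, src in zip(wordpieces, new_index):
        print(piece, '->', src)  # pieces from a split word share one source index
    x, hidden_states, attentions = encoder(indexed_tokens, att_mask)
    # x: (1, H) pooled output; hidden_states / attentions: per-layer tuples
    return x, hidden_states, attentions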


class BERTEntityEncoder(nn.Module):
    def __init__(self, max_length, pretrain_path, blank_padding=True):
        """
        Args:
            max_length: max length of sentence
            pretrain_path: path of pretrain model
        """
        super().__init__()
        self.max_length = max_length
        self.blank_padding = blank_padding
        self.hidden_size = 768 * 2  # head and tail entity-start states concatenated
        self.bert = BertModel.from_pretrained(pretrain_path)
        self.tokenizer = BertTokenizer.from_pretrained(pretrain_path)
        self.linear = nn.Linear(self.hidden_size, self.hidden_size)

    def forward(self, token, att_mask, pos1, pos2):
        """
        Args:
            token: (B, L), index of tokens
            att_mask: (B, L), attention mask (1 for contents and 0 for padding)
            pos1: (B, 1), position of the head entity starter
            pos2: (B, 1), position of the tail entity starter
        Return:
            (B, 2H), representations for sentences
        """
        hidden, _ = self.bert(token, attention_mask=att_mask)  # last-layer hidden states, (B, L, H)
        # Get entity start hidden states: one-hot masks select the marker positions
        onehot_head = torch.zeros(hidden.size()[:2], device=hidden.device).float()  # (B, L)
        onehot_tail = torch.zeros(hidden.size()[:2], device=hidden.device).float()  # (B, L)
        onehot_head = onehot_head.scatter_(1, pos1, 1)
        onehot_tail = onehot_tail.scatter_(1, pos2, 1)
        head_hidden = (onehot_head.unsqueeze(2) * hidden).sum(1)  # (B, H)
        tail_hidden = (onehot_tail.unsqueeze(2) * hidden).sum(1)  # (B, H)
        x = torch.cat([head_hidden, tail_hidden], 1)  # (B, 2H)
        x = self.linear(x)
        return x

    def tokenize(self, item):
        """
        Args:
            item: data instance containing 'text' / 'token', 'h' and 't'
        Return:
            indexed_tokens (1, L), att_mask (1, L), and pos1 / pos2 (1, 1), the
            positions of the head and tail entity start markers
        """
        # Sentence -> token
        if 'text' in item:
            sentence = item['text']
            is_token = False
        else:
            sentence = item['token']
            is_token = True
        pos_head = item['h']['pos']
        pos_tail = item['t']['pos']

        if not is_token:
            pos_min = pos_head
            pos_max = pos_tail
            if pos_head[0] > pos_tail[0]:
                pos_min = pos_tail
                pos_max = pos_head
                rev = True
            else:
                rev = False
            sent0 = self.tokenizer.tokenize(sentence[:pos_min[0]])
            ent0 = self.tokenizer.tokenize(sentence[pos_min[0]:pos_min[1]])
            sent1 = self.tokenizer.tokenize(sentence[pos_min[1]:pos_max[0]])
            ent1 = self.tokenizer.tokenize(sentence[pos_max[0]:pos_max[1]])
            sent2 = self.tokenizer.tokenize(sentence[pos_max[1]:])
            pos_head = [len(sent0), len(sent0) + len(ent0)]
            pos_tail = [
                len(sent0) + len(ent0) + len(sent1),
                len(sent0) + len(ent0) + len(sent1) + len(ent1)
            ]
            if rev:
                pos_tail = [len(sent0), len(sent0) + len(ent0)]
                pos_head = [
                    len(sent0) + len(ent0) + len(sent1),
                    len(sent0) + len(ent0) + len(sent1) + len(ent1)
                ]
            tokens = sent0 + ent0 + sent1 + ent1 + sent2
        else:
            tokens = sentence

        # Token -> index
        re_tokens = ['[CLS]']
        cur_pos = 0
        pos1 = 0
        pos2 = 0
        for token in tokens:
            token = token.lower()
            if cur_pos == pos_head[0]:
                pos1 = len(re_tokens)
                re_tokens.append('[unused0]')
            if cur_pos == pos_tail[0]:
                pos2 = len(re_tokens)
                re_tokens.append('[unused1]')
            re_tokens += self.tokenizer.tokenize(token)
            if cur_pos == pos_head[1] - 1:
                re_tokens.append('[unused2]')
            if cur_pos == pos_tail[1] - 1:
                re_tokens.append('[unused3]')
            cur_pos += 1
        re_tokens.append('[SEP]')

        pos1 = min(self.max_length - 1, pos1)
        pos2 = min(self.max_length - 1, pos2)
        indexed_tokens = self.tokenizer.convert_tokens_to_ids(re_tokens)
        avai_len = len(indexed_tokens)

        # Position
        pos1 = torch.tensor([[pos1]]).long()
        pos2 = torch.tensor([[pos2]]).long()

        # Padding
        # if self.blank_padding:
        #     while len(indexed_tokens) < self.max_length:
        #         indexed_tokens.append(0)  # 0 is id for [PAD]
        #     indexed_tokens = indexed_tokens[:self.max_length]
        indexed_tokens = torch.tensor(indexed_tokens).long().unsqueeze(0)  # (1, L)

        # Attention mask
        att_mask = torch.ones(indexed_tokens.size()).long()  # (1, L)

        return indexed_tokens, att_mask, pos1, pos2
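

# Minimal end-to-end sketch for BERTEntityEncoder, guarded so nothing runs on
# import. 'bert-base-uncased' is a placeholder checkpoint path; the shape in
# the final comment assumes a bert-base model (hidden size 768).
if __name__ == '__main__':
    encoder = BERTEntityEncoder(max_length=80, pretrain_path='bert-base-uncased')
    item = {
        'token': ['Bill', 'Gates', 'founded', 'Microsoft', '.'],
        'h': {'pos': [0, 2]},  # head entity span over the token list
        't': {'pos': [3, 4]},  # tail entity span over the token list
    }
    indexed_tokens, att_mask, pos1, pos2 = encoder.tokenize(item)
    rep = encoder(indexed_tokens, att_mask, pos1, pos2)
    print(rep.shape)  # torch.Size([1, 1536]): head and tail start states, concatenated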