from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import copy import json import math import logging import torch from torch import nn from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss # from torch.nn import LayerNorm import torch.nn.functional as F from config import BertConfig from graph_models import FuseEmbeddings logger = logging.getLogger(__name__) CONFIG_NAME = 'bert_config.json' WEIGHTS_NAME = 'pytorch_model.bin' def gelu(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) class LayerNorm(nn.Module): def __init__(self, hidden_size, eps=1e-12): """Construct a layernorm module in the TF style (epsilon inside the square root). """ super(LayerNorm, self).__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.bias = nn.Parameter(torch.zeros(hidden_size)) self.variance_epsilon = eps def forward(self, x): u = x.mean(-1, keepdim=True) s = (x - u).pow(2).mean(-1, keepdim=True) x = (x - u) / torch.sqrt(s + self.variance_epsilon) return self.weight * x + self.bias class MultiHeadedAttention(nn.Module): """ Take in model size and number of heads. """ def __init__(self, config: BertConfig): super().__init__() assert config.hidden_size % config.num_attention_heads == 0 # We assume d_v always equals d_k self.d_k = config.hidden_size // config.num_attention_heads self.h = config.num_attention_heads self.linear_layers = nn.ModuleList( [nn.Linear(config.hidden_size, config.hidden_size, bias=False) for _ in range(3)]) self.output_linear = nn.Linear(config.hidden_size, config.hidden_size) self.attention = Attention() self.dropout = nn.Dropout(p=config.attention_probs_dropout_prob) def forward(self, query, key, value, mask=None): batch_size = query.size(0) # 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = [l(x).view(batch_size, -1, self.h, self.d_k).transpose(1, 2) for l, x in zip(self.linear_layers, (query, key, value))] # 2) Apply attention on all the projected vectors in batch. x, attn = self.attention( query, key, value, mask=mask, dropout=self.dropout) # 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous().view( batch_size, -1, self.h * self.d_k) return self.output_linear(x) class Attention(nn.Module): """ Compute 'Scaled Dot Product Attention """ def forward(self, query, key, value, mask=None, dropout=None): scores = torch.matmul(query, key.transpose(-2, -1)) \ / math.sqrt(query.size(-1)) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = F.softmax(scores, dim=-1) if dropout is not None: p_attn = dropout(p_attn) return torch.matmul(p_attn, value), p_attn class SublayerConnection(nn.Module): """ A residual connection followed by a layer norm. Note for code simplicity the norm is first as opposed to last. """ def __init__(self, config: BertConfig): super(SublayerConnection, self).__init__() self.norm = LayerNorm(config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, x, sublayer): "Apply residual connection to any sublayer with the same size." return x + self.dropout(sublayer(self.norm(x))) class PositionwiseFeedForward(nn.Module): "Implements FFN equation." def __init__(self, config: BertConfig): super(PositionwiseFeedForward, self).__init__() self.w_1 = nn.Linear(config.hidden_size, config.intermediate_size) self.w_2 = nn.Linear(config.intermediate_size, config.hidden_size) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, x): return self.w_2(self.dropout(gelu(self.w_1(x)))) class TransformerBlock(nn.Module): """ Bidirectional Encoder = Transformer (self-attention) Transformer = MultiHead_Attention + Feed_Forward with sublayer connection """ def __init__(self, config: BertConfig): """ :param hidden: hidden size of transformer :param attn_heads: head sizes of multi-head attention :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size :param dropout: dropout rate """ super().__init__() self.attention = MultiHeadedAttention(config) self.feed_forward = PositionwiseFeedForward(config) self.input_sublayer = SublayerConnection(config) self.output_sublayer = SublayerConnection(config) self.dropout = nn.Dropout(p=config.hidden_dropout_prob) def forward(self, x, mask): x = self.input_sublayer( x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) x = self.output_sublayer(x, self.feed_forward) return self.dropout(x) class BertEmbeddings(nn.Module): """Construct the embeddings from word, visit and token_type embeddings. """ def __init__(self, config): super(BertEmbeddings, self).__init__() self.word_embeddings = nn.Embedding( config.vocab_size, config.hidden_size) self.token_type_embeddings = nn.Embedding(2, config.hidden_size) # self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load # any TensorFlow checkpoint file self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, input_ids, token_type_ids=None): if token_type_ids is None: token_type_ids = torch.zeros_like(input_ids) words_embeddings = self.word_embeddings(input_ids) embeddings = words_embeddings + \ self.token_type_embeddings(token_type_ids) embeddings = self.LayerNorm(embeddings) embeddings = self.dropout(embeddings) return embeddings class PreTrainedBertModel(nn.Module): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ def __init__(self, config: BertConfig, *inputs, **kwargs): super(PreTrainedBertModel, self).__init__() if not isinstance(config, BertConfig): raise ValueError( "Parameter config in `{}(config)` should be an instance of class `BertConfig`. " "To create a model from a Google pretrained model use " "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format( self.__class__.__name__, self.__class__.__name__ )) self.config = config def init_bert_weights(self, module): """ Initialize the weights. """ if isinstance(module, (nn.Linear, nn.Embedding)): # Slightly different from the TF version which uses truncated_normal for initialization # cf https://github.com/pytorch/pytorch/pull/5617 module.weight.data.normal_( mean=0.0, std=self.config.initializer_range) elif isinstance(module, LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) if isinstance(module, nn.Linear) and module.bias is not None: module.bias.data.zero_() @classmethod def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir='', *inputs, **kwargs): serialization_dir = os.path.join(cache_dir, pretrained_model_name) # Load config config_file = os.path.join(serialization_dir, CONFIG_NAME) config = BertConfig.from_json_file(config_file) logger.info("Model config {}".format(config)) # Instantiate model. model = cls(config, *inputs, **kwargs) if state_dict is None: weights_path = os.path.join(serialization_dir, WEIGHTS_NAME) state_dict = torch.load(weights_path) old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None if 'gamma' in key: new_key = key.replace('gamma', 'weight') if 'beta' in key: new_key = key.replace('beta', 'bias') if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) missing_keys = [] unexpected_keys = [] error_msgs = [] # copy state_dict so _load_from_state_dict can modify it metadata = getattr(state_dict, '_metadata', None) state_dict = state_dict.copy() if metadata is not None: state_dict._metadata = metadata def load(module, prefix=''): local_metadata = {} if metadata is None else metadata.get( prefix[:-1], {}) module._load_from_state_dict( state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) for name, child in module._modules.items(): if child is not None: load(child, prefix + name + '.') load(model, prefix='' if hasattr(model, 'bert') else 'bert.') if len(missing_keys) > 0: logger.info("Weights of {} not initialized from pretrained model: {}".format( model.__class__.__name__, missing_keys)) if len(unexpected_keys) > 0: logger.info("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) return model class BERT(PreTrainedBertModel): """ BERT model : Bidirectional Encoder Representations from Transformers. """ def __init__(self, config: BertConfig, dx_voc=None, rx_voc=None): """ :param vocab_size: vocab_size of total words :param hidden: BERT model hidden size :param n_layers: numbers of Transformer blocks(layers) :param attn_heads: number of attention heads :param dropout: dropout rate """ super().__init__(config) if config.graph: assert dx_voc is not None assert rx_voc is not None # embedding for BERT, sum of positional, segment, token embeddings self.embedding = FuseEmbeddings( config, dx_voc, rx_voc) if config.graph else BertEmbeddings(config) # multi-layers transformer blocks, deep network self.transformer_blocks = nn.ModuleList( [TransformerBlock(config) for _ in range(config.num_hidden_layers)]) # pool first output # self.pooler = BertPooler(config) self.apply(self.init_bert_weights) def forward(self, x, token_type_ids=None, input_positions=None, input_sides=None): # attention masking for padded token # torch.ByteTensor([batch_size, 1, seq_len, seq_len) mask = (x > 1).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1) # embedding the indexed sequence to sequence of vectors x = self.embedding(x, token_type_ids) # running over multiple transformer blocks for transformer in self.transformer_blocks: x = transformer.forward(x, mask) return x, x[:, 0] class BertPooler(nn.Module): def __init__(self, config): super(BertPooler, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.activation = nn.Tanh() def forward(self, hidden_states): # We "pool" the model by simply taking the hidden state corresponding # to the first token. first_token_tensor = hidden_states[:, 0] pooled_output = self.dense(first_token_tensor) pooled_output = self.activation(pooled_output) return pooled_output # pretaining class BertPredictionHeadTransform(nn.Module): def __init__(self, config): super(BertPredictionHeadTransform, self).__init__() self.dense = nn.Linear(config.hidden_size, config.hidden_size) self.transform_act_fn = gelu self.LayerNorm = LayerNorm(config.hidden_size, eps=1e-12) def forward(self, hidden_states): hidden_states = self.dense(hidden_states) hidden_states = self.transform_act_fn(hidden_states) hidden_states = self.LayerNorm(hidden_states) return hidden_states class BertLMPredictionHead(nn.Module): def __init__(self, config, voc_size=None): super(BertLMPredictionHead, self).__init__() self.transform = BertPredictionHeadTransform(config) self.decoder = nn.Linear(config.hidden_size, config.vocab_size if voc_size is None else voc_size) def forward(self, hidden_states): hidden_states = self.transform(hidden_states) hidden_states = self.decoder(hidden_states) return hidden_states