import torch.nn as nn
import torch.nn.functional as F
from pytorch_pretrained_bert.modeling import BertConfig, BertLayer


class Masker(nn.Module):
    """Predicts a per-token masking distribution from concatenated BERT hidden states."""

    def __init__(self, vocab_size, original_hidden_size, num_layers, tau=1):
        super().__init__()
        # A single BERT layer operating on the concatenation of `num_layers` hidden states.
        self.bert_layer = BertLayer(BertConfig(
            vocab_size_or_config_json_file=vocab_size,
            hidden_size=original_hidden_size * num_layers,
        ))
        # Project each token's representation down to a single masking logit.
        self.linear_layer = nn.Linear(original_hidden_size * num_layers, 1)
        self.log_sigmoid = nn.LogSigmoid()
        self.tau = tau  # default Gumbel-softmax temperature

    def forward(self, x, attention_mask, gumbel_softmax=True, tau=None):
        extended_attention_mask = self.convert_mask(attention_mask)
        h = self.bert_layer(x, extended_attention_mask)
        h = self.linear_layer(h)
        # (batch, seq_len, 1) -> (batch, seq_len) per-token log-probabilities.
        log_probs = self.log_sigmoid(h).squeeze(dim=2)
        if gumbel_softmax:
            tau = self.tau if tau is None else tau
            # Differentiable sample over token positions.
            return F.gumbel_softmax(log_probs, tau=tau)
        else:
            return log_probs

    def convert_mask(self, attention_mask):
        # Broadcast (batch, seq_len) -> (batch, 1, 1, seq_len) for the attention scores.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        extended_attention_mask = extended_attention_mask.to(
            dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # Padded positions receive a large negative bias so attention ignores them.
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
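
A minimal usage sketch, assuming x holds the hidden states of num_layers BERT layers concatenated along the feature dimension; the batch size, sequence length, hidden size, and vocabulary size below are illustrative placeholders, not values from the original.

import torch

batch_size, seq_len, hidden_size, num_layers = 2, 16, 768, 4

masker = Masker(vocab_size=30522, original_hidden_size=hidden_size, num_layers=num_layers)

# Concatenated per-token hidden states and a padding mask (1 = real token, 0 = padding).
x = torch.randn(batch_size, seq_len, hidden_size * num_layers)
attention_mask = torch.ones(batch_size, seq_len)

mask_sample = masker(x, attention_mask)                      # Gumbel-softmax sample, shape (batch_size, seq_len)
log_probs = masker(x, attention_mask, gumbel_softmax=False)  # raw per-token log-probabilities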