import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np

from .._functions import GELU
from efficiency.log import show_var
import pdb

PAD_ID_WORD = 1
max_doc_n_words = 1534 + 2


class AttEncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, n_head, d_graph, d_inner_hid, d_k, d_v, p_gcn):
        super(AttEncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_graph, d_k, d_v, dropout=p_gcn)
        self.pos_ffn = PositionwiseFeedForward(d_graph, d_inner_hid, dropout=p_gcn)

    def forward(self, enc_input, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, attn_mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


''' Define the sublayers in encoder/decoder layer '''


def get_attn_padding_mask(seq_q, seq_k):
    ''' Indicate the padding-related part to mask '''
    assert seq_q.dim() == 2 and seq_k.dim() == 2
    mb_size, len_q = seq_q.size()
    mb_size, len_k = seq_k.size()
    pad_attn_mask = seq_k.eq(PAD_ID_WORD).unsqueeze(1)  # b x 1 x sk
    pad_attn_mask = pad_attn_mask.expand(mb_size, len_q, len_k)  # b x sq x sk
    return pad_attn_mask


def get_attn_adj_mask(adjs):
    adjs_mask = adjs.ne(0)  # batch x n_node x n_node
    # torch.set_printoptions(precision=None, threshold=float('inf'))
    # pdb.set_trace()
    n_neig = adjs_mask.sum(dim=2)
    adjs_mask[:, :, 0] += n_neig.eq(0)  # this is for making PAD not all zeros
    return adjs_mask.eq(0)


def adj_normalization(adj):
    r'''
    Symmetric normalization of the adjacency matrix, i.e.
    :param adj: batch x n_node x n_node adjacency tensor
    :return: D^{-0.5} \dot (A+I) \dot D^{-0.5}
    '''
    rowsum = torch.clamp(adj.sum(-1), min=1)
    d_inv_sqrt = torch.pow(rowsum, -0.5)
    diag = torch.zeros_like(adj)
    # write d_inv_sqrt onto the diagonal of each matrix in the batch
    diag.as_strided(d_inv_sqrt.size(), [diag.stride(0), diag.size(2) + 1]).copy_(d_inv_sqrt)
    normed = torch.bmm(torch.bmm(diag, adj), diag)
    return normed


class PositionEncoder(nn.Module):
    def __init__(self, d_graph, mode="lookup"):
        super(PositionEncoder, self).__init__()
        self.mode = mode

        max_n_node = max_doc_n_words
        d_pos = 1
        if self.mode == "lookup":
            self.position_enc = nn.Embedding(max_n_node, d_graph, padding_idx=0)
            self.position_enc.weight.data = self._position_encoding_init(max_n_node, d_graph)
        elif self.mode == "linear":
            self.position_enc = nn.Linear(d_pos, d_graph)

    def _position_encoding_init(self, n_position, d_pos_vec):
        ''' Init the sinusoid position encoding table '''
        # keep dim 0 for padding token position encoding zero vector
        position_enc = np.array([
            [pos / np.power(10000, 2 * (j // 2) / d_pos_vec) for j in range(d_pos_vec)]
            if pos != 0 else np.zeros(d_pos_vec)
            for pos in range(n_position)])

        position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i
        position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1
        return torch.from_numpy(position_enc).type(torch.FloatTensor)

    def forward(self, pos, h_gcn):
        # x: (N, input_dim)
        if self.mode == "lookup":
            pos_enc = self.position_enc(pos)
            pos_enc = pos_enc.squeeze(2)
        elif self.mode == "linear":
            pos_enc = F.tanh(self.position_enc(pos))
        elif self.mode == "none":
            pos_enc = torch.zeros_like(h_gcn)
        return pos_enc


class Linear(nn.Module):
    ''' Simple Linear layer with xavier init '''

    def __init__(self, d_in, d_out, bias=True):
        super(Linear, self).__init__()
        self.linear = nn.Linear(d_in, d_out, bias=bias)
        init.xavier_normal_(self.linear.weight)

    def forward(self, x):
        return self.linear(x)


class Bottle(nn.Module):
    ''' Perform the reshape routine before and after an operation '''

    def forward(self, input):
        if len(input.size()) <= 2:
            return super(Bottle, self).forward(input)
        size = input.size()[:2]
        out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
        return out.view(size[0], size[1], -1)
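# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): illustrates the
# symmetric normalization performed by adj_normalization on a toy batched
# adjacency matrix.  The graph below is an illustrative assumption.
# ---------------------------------------------------------------------------
def _demo_adj_normalization():
    # a single 3-node path graph (batch size 1)
    adj = torch.tensor([[[0., 1., 0.],
                         [1., 0., 1.],
                         [0., 1., 0.]]])
    adj_with_self_loops = adj + torch.eye(3).unsqueeze(0)  # A + I
    normed = adj_normalization(adj_with_self_loops)        # D^-0.5 (A+I) D^-0.5
    # each entry (i, j) is now a_ij / sqrt(deg_i * deg_j)
    return normed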
class BottleLinear(Bottle, Linear):
    ''' Perform the reshape routine before and after a linear projection '''
    pass


class BottleSoftmax(Bottle, nn.Softmax):
    ''' Perform the reshape routine before and after a softmax operation '''
    pass


class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, use_residual=True):
        super(MultiHeadAttention, self).__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_ks = nn.Parameter(torch.FloatTensor(n_head, d_model, d_k))
        self.w_vs = nn.Parameter(torch.FloatTensor(n_head, d_model, d_v))

        self.attention = ScaledDotProductAttention(d_model)
        self.layer_norm = LayerNormalization(d_model)
        self.proj = BottleLinear(n_head * d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.use_residual = use_residual

        init.xavier_normal_(self.w_qs)
        init.xavier_normal_(self.w_ks)
        init.xavier_normal_(self.w_vs)

    def forward(self, q, k, v, attn_mask=None):
        d_k, d_v = self.d_k, self.d_v
        n_head = self.n_head

        residual = q if self.use_residual else 0

        mb_size, len_q, d_model = q.size()
        mb_size, len_k, d_model = k.size()
        mb_size, len_v, d_value = v.size()

        # get v_s
        v_s = v.repeat(n_head, 1, 1)  # (n_head*mb_size) x len_v x d_value
        if self.use_residual:
            v_s = v.repeat(n_head, 1, 1).view(n_head, -1, d_value)  # n_head x (mb_size*len_v) x d_model
            v_s = torch.bmm(v_s, self.w_vs).view(-1, len_v, d_v)  # (n_head*mb_size) x len_v x d_v

        # treat as a (n_head) size batch
        q_s = q.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_q) x d_model
        k_s = k.repeat(n_head, 1, 1).view(n_head, -1, d_model)  # n_head x (mb_size*len_k) x d_model

        # treat the result as a (n_head * mb_size) size batch
        q_s = torch.bmm(q_s, self.w_qs).view(-1, len_q, d_k)  # (n_head*mb_size) x len_q x d_k
        k_s = torch.bmm(k_s, self.w_ks).view(-1, len_k, d_k)  # (n_head*mb_size) x len_k x d_k

        # perform attention, result size = (n_head * mb_size) x len_q x d_v
        attn_mask = attn_mask.repeat(n_head, 1, 1) if attn_mask is not None else attn_mask
        outputs, attns = self.attention(q_s, k_s, v_s, attn_mask=attn_mask)

        if self.use_residual:
            # back to original mb_size batch, result size = mb_size x len_q x (n_head*d_v)
            outputs = torch.cat(torch.split(outputs, mb_size, dim=0), dim=-1)
            # project back to residual size
            outputs = self.proj(outputs)
        else:
            outputs = outputs.mean(0, True)
            attns = attns.mean(0, True)

        outputs = self.dropout(outputs)

        if self.use_residual:
            outputs = self.layer_norm(outputs + residual)
        return outputs, attns


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_hid, d_inner_hid, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Conv1d(d_hid, d_inner_hid, 1)  # position-wise
        self.w_2 = nn.Conv1d(d_inner_hid, d_hid, 1)  # position-wise
        self.layer_norm = LayerNormalization(d_hid)
        self.dropout = nn.Dropout(dropout)
        self.elu = nn.ReLU()  # note: kept the original name "elu", but this is a ReLU

    def forward(self, x):
        residual = x
        output = self.elu(self.w_1(x.transpose(1, 2)))
        output = self.w_2(output).transpose(2, 1)
        output = self.dropout(output)
        return self.layer_norm(output + residual)


class LayerNormalization(nn.Module):
    ''' Layer normalization module '''

    def __init__(self, d_hid, eps=1e-3):
        super(LayerNormalization, self).__init__()

        self.eps = eps
        self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
        self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)

    def forward(self, z):
        if z.size(1) == 1:
            return z

        mu = torch.mean(z, keepdim=True, dim=-1)
        sigma = torch.std(z, keepdim=True, dim=-1)
        ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
        ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
        return ln_out
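# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): shows how
# MultiHeadAttention is typically driven together with a padding mask from
# get_attn_padding_mask.  Batch size, sequence length and all dimensions below
# are illustrative assumptions, not values prescribed by this code base.
# ---------------------------------------------------------------------------
def _demo_multi_head_attention():
    batch, seq_len, d_model = 2, 5, 16
    word_ids = torch.full((batch, seq_len), 3, dtype=torch.long)
    word_ids[:, -1] = PAD_ID_WORD                      # last token is padding
    mask = get_attn_padding_mask(word_ids, word_ids)   # batch x seq_len x seq_len
    layer = MultiHeadAttention(n_head=4, d_model=d_model, d_k=8, d_v=8)
    x = torch.randn(batch, seq_len, d_model)
    out, attns = layer(x, x, x, attn_mask=mask)
    # out: batch x seq_len x d_model, attns: (n_head*batch) x seq_len x seq_len
    return out, attns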
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, d_model, attn_dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.temper = np.power(d_model, 0.5)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = BottleSoftmax(dim=-1)

    def forward(self, q, k, v, attn_mask=None, show_net=False):
        attn = torch.bmm(q, k.transpose(1, 2)) / self.temper  # cos sim needs normalization

        if attn_mask is not None:
            assert attn_mask.size() == attn.size(), \
                'Attention mask shape {} mismatch ' \
                'with Attention logit tensor shape ' \
                '{}.'.format(attn_mask.size(), attn.size())
            attn.masked_fill_(attn_mask, -float('inf'))

        attn = self.softmax(attn)  # attn: [32, 27, 27]
        if attn_mask is not None:
            attn.data.masked_fill_(attn_mask, 0)  # convert NaN to 0
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn


class WeightedScaledDotProductAttention(nn.Module):
    ''' Weighted Scaled Dot-Product Attention '''

    def __init__(self, d_model, attn_dropout=0.1, ff_layers=1, ff_drop_p=0.2, comb_mode=2):
        super(WeightedScaledDotProductAttention, self).__init__()
        self.temper = np.power(d_model, 0.5)
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = BottleSoftmax(dim=-1)

        self.ff_layers = ff_layers
        self.linear1 = nn.Linear(d_model, d_model)
        self.ff_dropout = nn.Dropout(ff_drop_p)
        if ff_layers == 2:
            self.linear2 = nn.Linear(d_model, d_model)
        self.elu = GELU()

        if comb_mode == 1:
            self.comb_att_n_init_adj = add_n_norm
        elif comb_mode == 2:
            self.comb_att_n_init_adj = learn_n_norm
        elif comb_mode == 0:
            self.comb_att_n_init_adj = use_init_adj
        self.comb_mode = comb_mode

    def _fc(self, lin, q, k, use_elu=True):
        # assume q == k
        q_out = lin(q)
        q_out = self.ff_dropout(q_out)
        if use_elu:
            q_out = self.elu(q_out)
        return q_out, q_out

    def forward(self, q, k, v, attn_mask=None, show_net=False):
        assert len(q.size()) == 3

        if self.ff_layers == 1:
            q_out, k_out = self._fc(self.linear1, q, k, use_elu=False)
            if show_net:
                show_var(["self.linear1"])
        elif self.ff_layers == 2:
            q_out, k_out = self._fc(self.linear1, q, k)
            q_out, k_out = self._fc(self.linear2, q_out, k_out, use_elu=False)
            if show_net:
                show_var(["self.linear1", "self.linear2"])

        if show_net:
            print("bmm --> self.dropout")
            show_var(["self.comb_att_n_init_adj"])

        attn = torch.bmm(q_out, k_out.transpose(1, 2))  # / self.temper  # cos sim needs normalization

        if attn_mask is not None:
            assert attn_mask.size() == attn.size(), \
                'Attention mask shape {} mismatch ' \
                'with Attention logit tensor shape ' \
                '{}.'.format(attn_mask.size(), attn.size())
            attn.masked_fill_(attn_mask, -float('inf'))

        if attn_mask is not None:
            attn.data.masked_fill_(attn_mask, 0)  # convert NaN to 0
        attn = self.dropout(attn)

        output = self.comb_att_n_init_adj(attn, v)
        return output, attn


def add_n_norm(attn, v):
    output = attn + v
    output = adj_normalization(output)
    return output


def learn_n_norm(attn, v):
    output = adj_normalization(attn)
    return output


def use_init_adj(attn, v):
    return v
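# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): contrasts the three
# comb_mode strategies that WeightedScaledDotProductAttention dispatches to.
# The toy attention scores and adjacency below are illustrative assumptions.
# ---------------------------------------------------------------------------
def _demo_comb_modes():
    attn = torch.rand(1, 3, 3)                  # learned attention scores
    adj = torch.tensor([[[1., 1., 0.],
                         [1., 1., 1.],
                         [0., 1., 1.]]])        # initial adjacency (with self loops)
    combined = add_n_norm(attn, adj)            # comb_mode == 1: add, then renormalize
    learned = learn_n_norm(attn, adj)           # comb_mode == 2: normalize the scores only
    initial = use_init_adj(attn, adj)           # comb_mode == 0: keep the initial adjacency
    return combined, learned, initial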
class ConcatProductAttention(nn.Module):
    ''' Concat (additive-style) attention '''

    def __init__(self, d_model, attn_dropout=0.1, ff_drop_p=0.2, use_elu=False):
        super(ConcatProductAttention, self).__init__()
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = BottleSoftmax(dim=-1)

        self.linear1 = nn.Linear(d_model, 1)
        self.linear2 = nn.Linear(d_model, 1)
        self.ff_dropout = nn.Dropout(ff_drop_p)
        if use_elu:
            self.elu = GELU()

    def _fc(self, lin, q, use_elu=False):
        q_out = lin(q)
        q_out = self.ff_dropout(q_out)
        if use_elu:
            q_out = self.elu(q_out)
        return q_out

    def forward(self, q, k, v, attn_mask=None, show_net=False):
        batch, sent_len, dim = q.size()
        q_out = self._fc(self.linear1, q)
        k_out = self._fc(self.linear2, k)
        k_out = k_out.permute(0, 2, 1)

        q_out = q_out.expand(batch, sent_len, sent_len)
        k_out = k_out.expand(batch, sent_len, sent_len)
        attn = q_out + k_out  # cos sim needs normalization

        if attn_mask is not None:
            assert attn_mask.size() == attn.size(), \
                'Attention mask shape {} mismatch ' \
                'with Attention logit tensor shape ' \
                '{}.'.format(attn_mask.size(), attn.size())
            attn.masked_fill_(attn_mask, -float('inf'))

        attn = self.softmax(attn)  # attn: [32, 27, 27]
        if attn_mask is not None:
            attn.data.masked_fill_(attn_mask, 0)  # convert NaN to 0
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)
        return output, attn
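# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): minimal call into
# ConcatProductAttention; the dimensions below are illustrative assumptions.
# ---------------------------------------------------------------------------
def _demo_concat_product_attention():
    batch, sent_len, d_model = 2, 4, 8
    x = torch.randn(batch, sent_len, d_model)
    layer = ConcatProductAttention(d_model)
    out, attn = layer(x, x, x)
    # out: batch x sent_len x d_model, attn: batch x sent_len x sent_len
    return out, attn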