import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class MPCNN(nn.Module):
    """Sentence-pair classifier built from multi-width convolutions.

    Each sentence is embedded and passed through two families of 1-D
    convolutions, one pair of layers per finite filter width:

    * "holistic" filters (block A) that span all embedding dimensions, and
    * "per-dimension" filters (block B), implemented as a grouped
      convolution with one group per embedding dimension.

    An infinite width (``np.inf``) means "pool over the whole embedded
    sentence" with no convolution.  The pooled outputs of both sentences are
    compared with cosine similarity, L2 distance and element-wise absolute
    difference, and the concatenated comparison features feed a small MLP
    that ends in ``LogSoftmax``.

    Args:
        embedding: callable (e.g. ``nn.Embedding``) applied to
            ``batch.sentence_a`` / ``batch.sentence_b``.  Its output is
            transposed with ``transpose(1, 2)`` and fed to ``Conv1d`` with
            ``n_word_dim`` input channels, so it must be
            ``(batch, seq_len, n_word_dim)``.
        n_word_dim: embedding size; number of input channels of every conv.
        n_holistic_filters: output channels of each holistic convolution.
        n_per_dim_filters: filters per embedding dimension in block B.
        filter_widths: sequence of convolution widths.  Finite entries must
            be ints accepted by ``nn.Conv1d``; ``np.inf`` entries select
            whole-sentence pooling.  Finite widths need not be ``1..k``.
        hidden_layer_units: size of the hidden layer of the final MLP.
        num_classes: number of output classes.
        dropout: dropout probability applied after the hidden layer.
    """

    def __init__(self, embedding, n_word_dim, n_holistic_filters, n_per_dim_filters,
                 filter_widths, hidden_layer_units, num_classes, dropout):
        super(MPCNN, self).__init__()
        self.embedding = embedding
        self.n_word_dim = n_word_dim
        self.n_holistic_filters = n_holistic_filters
        self.n_per_dim_filters = n_per_dim_filters
        self.filter_widths = filter_widths
        self.in_channels = n_word_dim

        # One (holistic, per-dim) pair of conv layers per finite width, stored
        # in the order the finite widths appear in ``filter_widths``.
        holistic_conv_layers = []
        per_dim_conv_layers = []
        for ws in filter_widths:
            if np.isinf(ws):
                continue

            holistic_conv_layers.append(nn.Sequential(
                nn.Conv1d(self.in_channels, n_holistic_filters, ws),
                nn.Tanh()
            ))

            # groups=in_channels gives every embedding dimension its own
            # n_per_dim_filters filters.
            per_dim_conv_layers.append(nn.Sequential(
                nn.Conv1d(self.in_channels, self.in_channels * n_per_dim_filters, ws,
                          groups=self.in_channels),
                nn.Tanh()
            ))

        self.holistic_conv_layers = nn.ModuleList(holistic_conv_layers)
        self.per_dim_conv_layers = nn.ModuleList(per_dim_conv_layers)

        # Compute the number of inputs to the first hidden layer.  Every
        # comparison yields cosine (1) + L2 (1), plus |x1 - x2| where used.
        n_finite = len(holistic_conv_layers)
        n_infinite = len(self.filter_widths) - n_finite
        comp_1_components_holistic = 2 + n_holistic_filters
        comp_1_components_per_dim = 2 + self.in_channels
        comp_1_components_infinite = 2 + 1  # whole-sentence pooling yields a scalar
        comp_2_components = 2

        # Horizontal comparison: 3 poolings x every width x (cos, L2).
        n_feat_h = 3 * len(self.filter_widths) * comp_2_components
        n_feat_v = (
            # holistic conv, max/min/mean pooling, all finite x finite width pairs
            3 * (n_finite ** 2) * comp_1_components_holistic +
            # whole-sentence pooling, max/min/mean, all infinite x infinite pairs
            3 * (n_infinite ** 2) * comp_1_components_infinite +
            # per-dim conv, max/min pooling, per finite width and per filter
            2 * n_finite * n_per_dim_filters * comp_1_components_per_dim
        )
        n_feat = n_feat_h + n_feat_v

        self.final_layers = nn.Sequential(
            nn.Linear(n_feat, hidden_layer_units),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_layer_units, num_classes),
            nn.LogSoftmax(1)
        )

    def _get_blocks_for_sentence(self, sent):
        """Compute pooled convolution outputs for one embedded sentence.

        Args:
            sent: tensor of shape ``(batch, n_word_dim, seq_len)``.

        Returns:
            ``(block_a, block_b)``.  ``block_a[ws]`` maps ``'max'``,
            ``'min'`` and ``'mean'`` to ``(batch, n_holistic_filters)``
            tensors for finite ``ws`` and to ``(batch, 1)`` tensors for an
            infinite ``ws``.  ``block_b[ws]`` (finite widths only) maps
            ``'max'`` and ``'min'`` to
            ``(batch, n_word_dim, n_per_dim_filters)`` tensors.

        Note:
            ``'min'`` is computed as the max-pool of the negated input, i.e.
            it holds the *negated* minimum.  Both sentences are treated the
            same way, so the comparisons made by this class are unaffected.
        """
        block_a = {}
        block_b = {}
        batch_size = sent.size(0)
        # Conv layers were appended in the order finite widths occur in
        # filter_widths; walk them with a running index rather than assuming
        # the widths are exactly 1, 2, 3, ...
        conv_idx = 0
        for ws in self.filter_widths:
            if np.isinf(ws):
                # Pool over every value of the embedded sentence at once.
                sent_flattened = sent.contiguous().view(batch_size, 1, -1)
                sent_flattened_size = sent.size(1) * sent.size(2)
                block_a[ws] = {
                    'max': F.max_pool1d(sent_flattened, sent_flattened_size).view(batch_size, -1),
                    'min': F.max_pool1d(-1 * sent_flattened, sent_flattened_size).view(batch_size, -1),
                    'mean': F.avg_pool1d(sent_flattened, sent_flattened_size).view(batch_size, -1)
                }
                continue

            holistic_layer = self.holistic_conv_layers[conv_idx]
            per_dim_layer = self.per_dim_conv_layers[conv_idx]
            conv_idx += 1

            holistic_conv_out = holistic_layer(sent)
            h_len = holistic_conv_out.size(2)
            block_a[ws] = {
                'max': F.max_pool1d(holistic_conv_out, h_len).contiguous().view(-1, self.n_holistic_filters),
                'min': F.max_pool1d(-1 * holistic_conv_out, h_len).contiguous().view(-1, self.n_holistic_filters),
                'mean': F.avg_pool1d(holistic_conv_out, h_len).contiguous().view(-1, self.n_holistic_filters)
            }

            per_dim_conv_out = per_dim_layer(sent)
            p_len = per_dim_conv_out.size(2)
            # Grouped-conv output channels are ordered group-major, so they
            # unfold to (embedding dim, filter index).
            block_b[ws] = {
                'max': F.max_pool1d(per_dim_conv_out, p_len).contiguous().view(
                    -1, self.in_channels, self.n_per_dim_filters),
                'min': F.max_pool1d(-1 * per_dim_conv_out, p_len).contiguous().view(
                    -1, self.in_channels, self.n_per_dim_filters)
            }
        return block_a, block_b

    @staticmethod
    def _cos_and_l2(x1, x2):
        """Return ``[cosine similarity, L2 distance]`` as ``(batch, 1)`` columns.

        Both results are reshaped explicitly: ``F.pairwise_distance`` returns
        ``(batch,)`` on current PyTorch (``keepdim=False``) but returned
        ``(batch, 1)`` on old releases, and ``torch.cat(..., dim=1)`` needs
        2-D inputs.
        """
        batch_size = x1.size(0)
        return [
            F.cosine_similarity(x1, x2).contiguous().view(batch_size, 1),
            F.pairwise_distance(x1, x2).contiguous().view(batch_size, 1),
        ]

    def _algo_1_horiz_comp(self, sent1_block_a, sent2_block_a):
        """Compare same-width, same-pooling block-A outputs of two sentences.

        Returns:
            ``(batch, 3 * len(filter_widths) * 2)`` tensor of
            (cosine, L2) pairs, ordered by pooling type then width.
        """
        comparison_feats = []
        for pool in ('max', 'min', 'mean'):
            for ws in self.filter_widths:
                x1 = sent1_block_a[ws][pool]
                x2 = sent2_block_a[ws][pool]
                comparison_feats.extend(self._cos_and_l2(x1, x2))
        return torch.cat(comparison_feats, dim=1)

    def _algo_2_vert_comp(self, sent1_block_a, sent2_block_a, sent1_block_b, sent2_block_b):
        """Compare block outputs across filter widths and per-dimension filters.

        Block A: for each pooling type, every pair of widths that are both
        finite or both infinite is compared (mixed pairs have different
        feature sizes and are skipped).  Block B: for each of max/min
        pooling, each finite width and each per-dimension filter, the two
        sentences are compared.  Every comparison contributes cosine, L2 and
        the element-wise absolute difference.

        Returns:
            ``(batch, n_features)`` tensor of concatenated comparisons.
        """
        comparison_feats = []
        ws_no_inf = [w for w in self.filter_widths if not np.isinf(w)]
        for pool in ('max', 'min', 'mean'):
            for ws1 in self.filter_widths:
                x1 = sent1_block_a[ws1][pool]
                ws1_is_inf = bool(np.isinf(ws1))
                for ws2 in self.filter_widths:
                    if ws1_is_inf != bool(np.isinf(ws2)):
                        continue
                    x2 = sent2_block_a[ws2][pool]
                    comparison_feats.extend(self._cos_and_l2(x1, x2))
                    comparison_feats.append(torch.abs(x1 - x2))

        for pool in ('max', 'min'):
            for ws in ws_no_inf:
                oG_1B = sent1_block_b[ws][pool]
                oG_2B = sent2_block_b[ws][pool]
                for i in range(self.n_per_dim_filters):
                    x1 = oG_1B[:, :, i]
                    x2 = oG_2B[:, :, i]
                    comparison_feats.extend(self._cos_and_l2(x1, x2))
                    comparison_feats.append(torch.abs(x1 - x2))

        return torch.cat(comparison_feats, dim=1)

    def forward(self, batch):
        """Score a batch of sentence pairs.

        Args:
            batch: object exposing ``sentence_a`` and ``sentence_b``, each a
                valid input to ``self.embedding``.

        Returns:
            ``(batch, num_classes)`` tensor of log-probabilities.
        """
        # Conv1d expects (batch, channels, length).
        sent1 = self.embedding(batch.sentence_a).transpose(1, 2)
        sent2 = self.embedding(batch.sentence_b).transpose(1, 2)

        # Sentence modeling module
        sent1_block_a, sent1_block_b = self._get_blocks_for_sentence(sent1)
        sent2_block_a, sent2_block_b = self._get_blocks_for_sentence(sent2)

        # Similarity measurement layer
        feat_h = self._algo_1_horiz_comp(sent1_block_a, sent2_block_a)
        feat_v = self._algo_2_vert_comp(sent1_block_a, sent2_block_a, sent1_block_b, sent2_block_b)
        feat_all = torch.cat([feat_h, feat_v], dim=1)

        preds = self.final_layers(feat_all)
        return preds