'''
Created on Sep 2017

@author: hugo

'''
import numpy as np

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
import torch.nn.functional as F

from .utils import to_cuda


INF = 1e20
VERY_SMALL_NUMBER = 1e-10
class BAMnet(nn.Module):
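    """BAMnet (bidirectional attentive memory network) scorer: encodes the question
    and the candidate-answer KB memory (type / path / context components), lets them
    attend to each other, and returns one candidate-score vector per reasoning stage
    and memory hop."""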
    def __init__(self, vocab_size, vocab_embed_size, o_embed_size, \
        hidden_size, num_ent_types, num_relations, num_query_words, \
        word_emb_dropout=None,\
        que_enc_dropout=None,\
        ans_enc_dropout=None, \
        pre_w2v=None, \
        num_hops=1, \
        att='add', \
        use_cuda=True):
        super(BAMnet, self).__init__()
        self.use_cuda = use_cuda
        self.word_emb_dropout = word_emb_dropout
        self.que_enc_dropout = que_enc_dropout
        self.ans_enc_dropout = ans_enc_dropout
        self.num_hops = num_hops
        self.hidden_size = hidden_size
        self.que_enc = SeqEncoder(vocab_size, vocab_embed_size, hidden_size, \
                        seq_enc_type='lstm', \
                        word_emb_dropout=word_emb_dropout, bidirectional=True, \
                        init_word_embed=pre_w2v, use_cuda=use_cuda).que_enc

        self.ans_enc = AnsEncoder(o_embed_size, hidden_size, \
                        num_ent_types, num_relations, \
                        vocab_size=vocab_size, \
                        vocab_embed_size=vocab_embed_size, \
                        shared_embed=self.que_enc.embed, \
                        word_emb_dropout=word_emb_dropout, \
                        ans_enc_dropout=ans_enc_dropout, \
                        use_cuda=use_cuda)

        self.qw_embed = nn.Embedding(num_query_words, o_embed_size // 8, padding_idx=0)
        self.batchnorm = nn.BatchNorm1d(hidden_size)

        self.init_atten = Attention(hidden_size, hidden_size, hidden_size, atten_type=att)
        self.self_atten = SelfAttention_CoAtt(hidden_size)
        print('[ Using self-attention on question encoder ]')

        self.memory_hop = RomHop(hidden_size, hidden_size, hidden_size, atten_type=att)
        print('[ Using {}-hop memory update ]'.format(self.num_hops))

    def kb_aware_query_enc(self, memories, queries, query_lengths, ans_mask, ctx_mask=None):
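        """Encode the question (BiLSTM + self-attention) and let it attend over the
        candidate-answer memory components (type / path / context).
        Returns ((ans_comp_val, ans_comp_key), (q_att, Q_r), query_mask), where Q_r
        holds the per-token question states and q_att is the resulting attention
        distribution over question positions."""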
        # Question encoder
        Q_r = self.que_enc(queries, query_lengths)[0]
        if self.que_enc_dropout:
            Q_r = F.dropout(Q_r, p=self.que_enc_dropout, training=self.training)

        query_mask = create_mask(query_lengths, Q_r.size(1), self.use_cuda)
        q_r_init = self.self_atten(Q_r, query_lengths, query_mask)

        # Answer encoder
        _, _, _, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num, _, _, _, _ = memories
        ans_comp_val, ans_comp_key = self.ans_enc(x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ent, x_ctx_ent_len, x_ctx_ent_num)
        if self.ans_enc_dropout:
            for i in range(len(ans_comp_key)):
                ans_comp_key[i] = F.dropout(ans_comp_key[i], p=self.ans_enc_dropout, training=self.training)
        # KB memory summary
        ans_comp_atts = [self.init_atten(q_r_init, each, atten_mask=ans_mask) for each in ans_comp_key]
        if ctx_mask is not None:
            ans_comp_atts[-1] = ctx_mask * ans_comp_atts[-1] - (1 - ctx_mask) * INF
        ans_comp_probs = [torch.softmax(each, dim=-1) for each in ans_comp_atts]
        memory_summary = []
        for i, probs in enumerate(ans_comp_probs):
            memory_summary.append(torch.bmm(probs.unsqueeze(1), ans_comp_val[i]))
        memory_summary = torch.cat(memory_summary, 1)

        # Co-attention
        CoAtt = torch.bmm(Q_r, memory_summary.transpose(1, 2)) # co-attention matrix
        CoAtt = query_mask.unsqueeze(-1) * CoAtt - (1 - query_mask).unsqueeze(-1) * INF
        if ctx_mask is not None:
            # mask over empty ctx elements
            ctx_mask_global = (ctx_mask.sum(-1, keepdim=True) > 0).float()
            CoAtt[:, :, -1] = ctx_mask_global * CoAtt[:, :, -1].clone() - (1 - ctx_mask_global) * INF

        q_att = F.max_pool1d(CoAtt, kernel_size=CoAtt.size(-1)).squeeze(-1)
        q_att = torch.softmax(q_att, dim=-1)
        return (ans_comp_val, ans_comp_key), (q_att, Q_r), query_mask

    def forward(self, memories, queries, query_lengths, query_words, ctx_mask=None):
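        """Score all answer candidates against the question. Returns `mem_hop_scores`,
        a list of (batch_size x num_cands) score tensors: the answer-type matching
        score, a score before and after the co-attention memory update, and one score
        per generalization hop."""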
        ctx_mask = None # The ctx_mask argument is ignored here: context masking is disabled.
        mem_hop_scores = []
        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)

        # Multi-task learning on answer type matching
        # question word vec
        self.qw_vec = torch.mean(self.qw_embed(query_words), 1)
        # answer type vec
        x_types = memories[4]
        ans_types = torch.mean(self.ans_enc.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)
        qw_anstype_score = torch.bmm(ans_types, self.qw_vec.unsqueeze(2)).squeeze(2)
        if ans_mask is not None:
            qw_anstype_score = ans_mask * qw_anstype_score - (1 - ans_mask) * INF # Make dummy candidates have large negative scores
        mem_hop_scores.append(qw_anstype_score)


        # Kb-aware question attention module
        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)
        ans_val = torch.cat([each.unsqueeze(2) for each in ans_val], 2)
        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)

        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        mid_score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)
        mem_hop_scores.append(mid_score)

        Q_r, ans_key, ans_val = self.memory_hop(Q_r, ans_key, ans_val, q_att, atten_mask=ans_mask, ctx_mask=ctx_mask, query_mask=query_mask)
        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        mid_score = self.scoring(ans_key, q_r, mask=ans_mask)
        mem_hop_scores.append(mid_score)

        # Generalization module
        for _ in range(self.num_hops):
            q_r_tmp = self.memory_hop.gru_step(q_r, ans_key, ans_val, atten_mask=ans_mask)
            q_r = self.batchnorm(q_r + q_r_tmp)
            mid_score = self.scoring(ans_key, q_r, mask=ans_mask)
            mem_hop_scores.append(mid_score)
        return mem_hop_scores

    def premature_score(self, memories, queries, query_lengths, ctx_mask=None):
        ctx_mask = None # The ctx_mask argument is ignored here: context masking is disabled.
        ans_mask = create_mask(memories[0], memories[2].size(1), self.use_cuda)

        # Kb-aware question attention module
        (ans_val, ans_key), (q_att, Q_r), query_mask = self.kb_aware_query_enc(memories, queries, query_lengths, ans_mask, ctx_mask=ctx_mask)
        ans_key = torch.cat([each.unsqueeze(2) for each in ans_key], 2)

        q_r = torch.bmm(q_att.unsqueeze(1), Q_r).squeeze(1)
        score = self.scoring(ans_key.sum(2), q_r, mask=ans_mask)
        return score

    def scoring(self, ans_r, q_r, mask=None):
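        """Dot-product score between each candidate representation in `ans_r`
        (batch_size x num_cands x d) and the question vector `q_r` (batch_size x d);
        masked-out candidates get a large negative score."""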
        score = torch.bmm(ans_r, q_r.unsqueeze(2)).squeeze(2)
        if mask is not None:
            score = mask * score - (1 - mask) * INF # Make dummy candidates have large negative scores
        return score

class RomHop(nn.Module):
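    """Memory-hop module: a co-attention based read over the answer memory
    (`forward` / `update_coatt_cat_maxpool`) plus a GRU-style gated update of the
    query state (`gru_step`)."""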
    def __init__(self, query_embed_size, in_memory_embed_size, hidden_size, atten_type='add'):
        super(RomHop, self).__init__()
        self.hidden_size = hidden_size
        self.gru_linear_z = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_linear_r = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_linear_t = nn.Linear(2 * hidden_size, hidden_size, bias=False)
        self.gru_atten = Attention(hidden_size, query_embed_size, in_memory_embed_size, atten_type=atten_type)

    def forward(self, query_embed, in_memory_embed, out_memory_embed, query_att, \
                atten_mask=None, ctx_mask=None, query_mask=None):
        output = self.update_coatt_cat_maxpool(query_embed, in_memory_embed, out_memory_embed, query_att, \
                    atten_mask=atten_mask, ctx_mask=ctx_mask, query_mask=query_mask)
        return output

    def gru_step(self, h_state, in_memory_embed, out_memory_embed, atten_mask=None):
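        """Read the memory with attention from `h_state`, then apply a GRU-style
        gated update: output = (1 - z) * h_state + z * tanh(W_t [r * h_state; m]),
        where m is the attended memory read and the gates z, r are computed from
        [h_state; m]."""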
        attention = self.gru_atten(h_state, in_memory_embed, atten_mask=atten_mask)
        probs = torch.softmax(attention, dim=-1)

        memory_output = torch.bmm(probs.unsqueeze(1), out_memory_embed).squeeze(1)
        # GRU-like memory update
        z = torch.sigmoid(self.gru_linear_z(torch.cat([h_state, memory_output], -1)))
        r = torch.sigmoid(self.gru_linear_r(torch.cat([h_state, memory_output], -1)))
        t = torch.tanh(self.gru_linear_t(torch.cat([r * h_state, memory_output], -1)))
        output = (1 - z) * h_state + z * t
        return output

    def update_coatt_cat_maxpool(self, query_embed, in_memory_embed, out_memory_embed, query_att, atten_mask=None, ctx_mask=None, query_mask=None):
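        """Co-attention between question tokens and the per-candidate memory
        components, followed by max-pooling. The importance module re-weights the
        memory components into a single key per candidate; the enhanced module
        updates both the question states and the memory keys with attended context.
        Returns (new_query_embed, in_memory_embed, out_memory_embed)."""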
        attention = torch.bmm(query_embed, in_memory_embed.view(in_memory_embed.size(0), -1, in_memory_embed.size(-1))\
            .transpose(1, 2)).view(query_embed.size(0), query_embed.size(1), in_memory_embed.size(1), -1) # bs * N * M * k
        if ctx_mask is not None:
            attention[:, :, :, -1] = ctx_mask.unsqueeze(1) * attention[:, :, :, -1].clone() - (1 - ctx_mask).unsqueeze(1) * INF
        if atten_mask is not None:
            attention = atten_mask.unsqueeze(1).unsqueeze(-1) * attention - (1 - atten_mask).unsqueeze(1).unsqueeze(-1) * INF
        if query_mask is not None:
            attention = query_mask.unsqueeze(2).unsqueeze(-1) * attention - (1 - query_mask).unsqueeze(2).unsqueeze(-1) * INF

        # Importance module
        kb_feature_att = F.max_pool1d(attention.view(attention.size(0), attention.size(1), -1).transpose(1, 2), kernel_size=attention.size(1)).squeeze(-1).view(attention.size(0), -1, attention.size(-1))
        kb_feature_att = torch.softmax(kb_feature_att, dim=-1).view(-1, kb_feature_att.size(-1)).unsqueeze(1)
        in_memory_embed = torch.bmm(kb_feature_att, in_memory_embed.view(-1, in_memory_embed.size(2), in_memory_embed.size(-1))).squeeze(1).view(in_memory_embed.size(0), in_memory_embed.size(1), -1)
        out_memory_embed = out_memory_embed.sum(2)

        # Enhanced module
        attention = F.max_pool1d(attention.view(attention.size(0), -1, attention.size(-1)), kernel_size=attention.size(-1)).squeeze(-1).view(attention.size(0), attention.size(1), attention.size(2))
        probs = torch.softmax(attention, dim=-1)
        new_query_embed = query_embed + query_att.unsqueeze(2) * torch.bmm(probs, out_memory_embed)

        probs2 = torch.softmax(attention, dim=1)
        kb_att = torch.bmm(query_att.unsqueeze(1), probs).squeeze(1)
        in_memory_embed = in_memory_embed + kb_att.unsqueeze(2) * torch.bmm(probs2.transpose(1, 2), new_query_embed)
        return new_query_embed, in_memory_embed, out_memory_embed

class AnsEncoder(nn.Module):
    """Answer Encoder"""
    def __init__(self, o_embed_size, hidden_size, num_ent_types, num_relations, vocab_size=None, \
                    vocab_embed_size=None, shared_embed=None, word_emb_dropout=None, \
                    ans_enc_dropout=None, use_cuda=True):
        super(AnsEncoder, self).__init__()
        # shared_embed and vocab_size must not both be None.
        self.use_cuda = use_cuda
        self.ans_enc_dropout = ans_enc_dropout
        self.hidden_size = hidden_size
        self.ent_type_embed = nn.Embedding(num_ent_types, o_embed_size // 8, padding_idx=0)
        self.relation_embed = nn.Embedding(num_relations, o_embed_size, padding_idx=0)
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, vocab_embed_size, padding_idx=0)
        self.vocab_embed_size = self.embed.weight.data.size(1)

        self.linear_type_bow_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_paths_key = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)
        self.linear_ctx_key = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_type_bow_val = nn.Linear(hidden_size, hidden_size, bias=False)
        self.linear_paths_val = nn.Linear(hidden_size + o_embed_size, hidden_size, bias=False)
        self.linear_ctx_val = nn.Linear(hidden_size, hidden_size, bias=False)

        # lstm for ans encoder
        self.lstm_enc_type = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)
        self.lstm_enc_path = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)
        self.lstm_enc_ctx = EncoderRNN(vocab_size, self.vocab_embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=True, \
                        shared_embed=shared_embed, \
                        rnn_type='lstm', \
                        use_cuda=use_cuda)

    def forward(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):
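        """Encode the candidates' type, path and context features and project them
        into per-component value and key memories. Returns (ans_comp_val, ans_comp_key),
        each a list of three (batch_size x num_cands x hidden_size) tensors."""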
        ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent = self.enc_ans_features(x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num)
        ans_val, ans_key = self.enc_comp_kv(ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent)
        return ans_val, ans_key

    def enc_comp_kv(self, ans_type_bow, ans_types, ans_path_bow, ans_paths, ans_ctx_ent):
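        """Project each encoded answer component into a (value, key) memory pair."""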
        ans_type_bow_val = self.linear_type_bow_val(ans_type_bow)
        ans_paths_val = self.linear_paths_val(torch.cat([ans_path_bow, ans_paths], -1))
        ans_ctx_val = self.linear_ctx_val(ans_ctx_ent)

        ans_type_bow_key = self.linear_type_bow_key(ans_type_bow)
        ans_paths_key = self.linear_paths_key(torch.cat([ans_path_bow, ans_paths], -1))
        ans_ctx_key = self.linear_ctx_key(ans_ctx_ent)

        ans_comp_val = [ans_type_bow_val, ans_paths_val, ans_ctx_val]
        ans_comp_key = [ans_type_bow_key, ans_paths_key, ans_ctx_key]
        return ans_comp_val, ans_comp_key

    def enc_ans_features(self, x_type_bow, x_types, x_type_bow_len, x_path_bow, x_paths, x_path_bow_len, x_ctx_ents, x_ctx_ent_len, x_ctx_ent_num):
        '''
        x_types: answer type ids (a bag of types per candidate)
        x_paths: answer path, i.e., a bag of relation ids per candidate
        x_ctx_ents: answer context, i.e., bags of entity words, (batch_size, num_cands, num_ctx, L)
        '''
        # ans_types = torch.mean(self.ent_type_embed(x_types.view(-1, x_types.size(-1))), 1).view(x_types.size(0), x_types.size(1), -1)
        ans_type_bow = (self.lstm_enc_type(x_type_bow.view(-1, x_type_bow.size(-1)), x_type_bow_len.view(-1))[1]).view(x_type_bow.size(0), x_type_bow.size(1), -1)
        ans_path_bow = (self.lstm_enc_path(x_path_bow.view(-1, x_path_bow.size(-1)), x_path_bow_len.view(-1))[1]).view(x_path_bow.size(0), x_path_bow.size(1), -1)
        ans_paths = torch.mean(self.relation_embed(x_paths.view(-1, x_paths.size(-1))), 1).view(x_paths.size(0), x_paths.size(1), -1)

        # Avg over ctx
        ctx_num_mask = create_mask(x_ctx_ent_num.view(-1), x_ctx_ents.size(2), self.use_cuda).view(x_ctx_ent_num.shape + (-1,))
        ans_ctx_ent = (self.lstm_enc_ctx(x_ctx_ents.view(-1, x_ctx_ents.size(-1)), x_ctx_ent_len.view(-1))[1]).view(x_ctx_ents.size(0), x_ctx_ents.size(1), x_ctx_ents.size(2), -1)
        ans_ctx_ent = ctx_num_mask.unsqueeze(-1) * ans_ctx_ent
        ans_ctx_ent = torch.sum(ans_ctx_ent, dim=2) / torch.clamp(x_ctx_ent_num.float().unsqueeze(-1), min=VERY_SMALL_NUMBER)

        if self.ans_enc_dropout:
            # ans_types = F.dropout(ans_types, p=self.ans_enc_dropout, training=self.training)
            ans_type_bow = F.dropout(ans_type_bow, p=self.ans_enc_dropout, training=self.training)
            ans_path_bow = F.dropout(ans_path_bow, p=self.ans_enc_dropout, training=self.training)
            ans_paths = F.dropout(ans_paths, p=self.ans_enc_dropout, training=self.training)
            ans_ctx_ent = F.dropout(ans_ctx_ent, p=self.ans_enc_dropout, training=self.training)
        return ans_type_bow, None, ans_path_bow, ans_paths, ans_ctx_ent

class SeqEncoder(object):
    """Question Encoder"""
    def __init__(self, vocab_size, embed_size, hidden_size, \
                seq_enc_type='lstm', word_emb_dropout=None,
                cnn_kernel_size=[3], bidirectional=False, \
                shared_embed=None, init_word_embed=None, use_cuda=True):
        if seq_enc_type in ('lstm', 'gru'):
            self.que_enc = EncoderRNN(vocab_size, embed_size, hidden_size, \
                        dropout=word_emb_dropout, \
                        bidirectional=bidirectional, \
                        shared_embed=shared_embed, \
                        init_word_embed=init_word_embed, \
                        rnn_type=seq_enc_type, \
                        use_cuda=use_cuda)

        elif seq_enc_type == 'cnn':
            self.que_enc = EncoderCNN(vocab_size, embed_size, hidden_size, \
                        kernel_size=cnn_kernel_size, dropout=word_emb_dropout, \
                        shared_embed=shared_embed, \
                        init_word_embed=init_word_embed, \
                        use_cuda=use_cuda)
        else:
            raise RuntimeError('Unknown SeqEncoder type: {}'.format(seq_enc_type))

class EncoderRNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, dropout=None, \
        bidirectional=False, shared_embed=None, init_word_embed=None, rnn_type='lstm', use_cuda=True):
        super(EncoderRNN, self).__init__()
        if rnn_type not in ('lstm', 'gru'):
            raise RuntimeError('rnn_type is expected to be lstm or gru, got {}'.format(rnn_type))
        if bidirectional:
            print('[ Using bidirectional {} encoder ]'.format(rnn_type))
        else:
            print('[ Using {} encoder ]'.format(rnn_type))
        if bidirectional and hidden_size % 2 != 0:
            raise RuntimeError('hidden_size is expected to be even in the bidirectional mode!')
        self.dropout = dropout
        self.rnn_type = rnn_type
        self.use_cuda = use_cuda
        self.hidden_size = hidden_size // 2 if bidirectional else hidden_size
        self.num_directions = 2 if bidirectional else 1
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)
        model = nn.LSTM if rnn_type == 'lstm' else nn.GRU
        self.model = model(embed_size, self.hidden_size, 1, batch_first=True, bidirectional=bidirectional)
        if shared_embed is None:
            self.init_weights(init_word_embed)

    def init_weights(self, init_word_embed):
        if init_word_embed is not None:
            print('[ Using pretrained word embeddings ]')
            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))
        else:
            self.embed.weight.data.uniform_(-0.08, 0.08)

    def forward(self, x, x_len):
        """x: [batch_size * max_length]
           x_len: [batch_size]
        """
        x = self.embed(x)
        if self.dropout:
            x = F.dropout(x, p=self.dropout, training=self.training)

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        x = pack_padded_sequence(x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
        if self.rnn_type == 'lstm':
            c0 = to_cuda(torch.zeros(self.num_directions, x_len.size(0), self.hidden_size), self.use_cuda)
            packed_h, (packed_h_t, _) = self.model(x, (h0, c0))
            if self.num_directions == 2:
                packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
            else:
                packed_h_t = packed_h_t.squeeze(0)
        else:
            packed_h, packed_h_t = self.model(x, h0)
            if self.num_directions == 2:
                packed_h_t = packed_h_t.transpose(0, 1).contiguous().view(x_len.size(0), -1)
            else:
                packed_h_t = packed_h_t.squeeze(0)

        hh, _ = pad_packed_sequence(packed_h, batch_first=True)

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        restore_hh = hh[inverse_indx]
        restore_packed_h_t = packed_h_t[inverse_indx]
        return restore_hh, restore_packed_h_t


class EncoderCNN(nn.Module):
    def __init__(self, vocab_size, embed_size, hidden_size, kernel_size=[2, 3], \
            dropout=None, shared_embed=None, init_word_embed=None, use_cuda=True):
        super(EncoderCNN, self).__init__()
        print('[ Using CNN encoder with kernel size: {} ]'.format(kernel_size))
        self.use_cuda = use_cuda
        self.dropout = dropout
        self.embed = shared_embed if shared_embed is not None else nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.cnns = nn.ModuleList([nn.Conv1d(embed_size, hidden_size, kernel_size=k, padding=k-1) for k in kernel_size])

        if len(kernel_size) > 1:
            self.fc = nn.Linear(len(kernel_size) * hidden_size, hidden_size)
        if shared_embed is None:
            self.init_weights(init_word_embed)

    def init_weights(self, init_word_embed):
        if init_word_embed is not None:
            print('[ Using pretrained word embeddings ]')
            self.embed.weight.data.copy_(torch.from_numpy(init_word_embed))
        else:
            self.embed.weight.data.uniform_(-0.08, 0.08)

    def forward(self, x, x_len=None):
        """x: [batch_size * max_length]
           x_len: reserved
        """
        x = self.embed(x)
        if self.dropout:
            x = F.dropout(x, p=self.dropout, training=self.training)
        # Turn (batch_size, seq_len, embed_size) into (batch_size, embed_size, seq_len) for Conv1d
        x = x.transpose(1, 2)
        z = [conv(x) for conv in self.cnns]
        output = [F.max_pool1d(i, kernel_size=i.size(-1)).squeeze(-1) for i in z]

        if len(output) > 1:
            output = self.fc(torch.cat(output, -1))
        else:
            output = output[0]
        return None, output


class Attention(nn.Module):
    def __init__(self, hidden_size, h_state_embed_size=None, in_memory_embed_size=None, atten_type='simple'):
        super(Attention, self).__init__()
        self.atten_type = atten_type
        if not h_state_embed_size:
            h_state_embed_size = hidden_size
        if not in_memory_embed_size:
            in_memory_embed_size = hidden_size
        if atten_type in ('mul', 'add'):
            self.W = torch.Tensor(h_state_embed_size, hidden_size)
            self.W = nn.Parameter(nn.init.xavier_uniform_(self.W))
            if atten_type == 'add':
                self.W2 = torch.Tensor(in_memory_embed_size, hidden_size)
                self.W2 = nn.Parameter(nn.init.xavier_uniform_(self.W2))
                self.W3 = torch.Tensor(hidden_size, 1)
                self.W3 = nn.Parameter(nn.init.xavier_uniform_(self.W3))
        elif atten_type == 'simple':
            pass
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

    def forward(self, query_embed, in_memory_embed, atten_mask=None):
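        """Score each memory slot against the query vector.
        query_embed: (batch_size x d_q), in_memory_embed: (batch_size x M x d_m);
        returns unnormalized attention scores of shape (batch_size x M), with masked
        slots pushed to a large negative value."""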
        if self.atten_type == 'simple': # simple attention
            attention = torch.bmm(in_memory_embed, query_embed.unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'mul': # multiplicative attention
            attention = torch.bmm(in_memory_embed, torch.mm(query_embed, self.W).unsqueeze(2)).squeeze(2)
        elif self.atten_type == 'add': # additive attention
            attention = torch.tanh(torch.mm(in_memory_embed.view(-1, in_memory_embed.size(-1)), self.W2)\
                .view(in_memory_embed.size(0), -1, self.W2.size(-1)) \
                + torch.mm(query_embed, self.W).unsqueeze(1))
            attention = torch.mm(attention.view(-1, attention.size(-1)), self.W3).view(attention.size(0), -1)
        else:
            raise RuntimeError('Unknown atten_type: {}'.format(self.atten_type))

        if atten_mask is not None:
            # Exclude masked elements from the softmax
            attention = atten_mask * attention - (1 - atten_mask) * INF
        return attention

class SelfAttention_CoAtt(nn.Module):
    def __init__(self, hidden_size, use_cuda=True):
        super(SelfAttention_CoAtt, self).__init__()
        self.use_cuda = use_cuda
        self.hidden_size = hidden_size
        self.model = nn.LSTM(2 * hidden_size, hidden_size // 2, batch_first=True, bidirectional=True)

    def forward(self, x, x_len, atten_mask):
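        """Self-attention over the question tokens: each token attends to all others,
        the attended context is concatenated to the original states and fed through a
        BiLSTM. Returns the final hidden state (batch_size x hidden_size) as the
        question summary vector."""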
        CoAtt = torch.bmm(x, x.transpose(1, 2))
        CoAtt = atten_mask.unsqueeze(1) * CoAtt - (1 - atten_mask).unsqueeze(1) * INF
        CoAtt = torch.softmax(CoAtt, dim=-1)
        new_x = torch.cat([torch.bmm(CoAtt, x), x], -1)

        sorted_x_len, indx = torch.sort(x_len, 0, descending=True)
        new_x = pack_padded_sequence(new_x[indx], sorted_x_len.data.tolist(), batch_first=True)

        h0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        c0 = to_cuda(torch.zeros(2, x_len.size(0), self.hidden_size // 2), self.use_cuda)
        packed_h, (packed_h_t, _) = self.model(new_x, (h0, c0))

        # restore the sorting
        _, inverse_indx = torch.sort(indx, 0)
        packed_h_t = torch.cat([packed_h_t[i] for i in range(packed_h_t.size(0))], -1)
        restore_packed_h_t = packed_h_t[inverse_indx]
        return restore_packed_h_t

def create_mask(x, N, use_cuda=True):
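    """Build a binary mask of shape (x.size(0), N) from a 1-D tensor of lengths `x`,
    with mask[i, :x[i]] = 1."""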
    x = x.data
    mask = np.zeros((x.size(0), N))
    for i in range(x.size(0)):
        mask[i, :x[i]] = 1
    return to_cuda(torch.Tensor(mask), use_cuda)
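

# ---------------------------------------------------------------------------
# Minimal shape-check sketch for the building blocks above. This is not part of
# the original training pipeline: the tiny vocabulary, the tensor sizes, and the
# assumption that `to_cuda(x, False)` is a no-op are all illustrative. Because of
# the relative import at the top, run it as a module inside its package
# (e.g. `python -m <package>.<this_module>`, path depends on the repo layout),
# not as a standalone script.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch_size, max_len, hidden = 2, 5, 8
    lengths = torch.LongTensor([5, 3])                    # question lengths
    tokens = torch.randint(1, 20, (batch_size, max_len))  # padded word ids (0 = pad)

    mask = create_mask(lengths, max_len, use_cuda=False)  # (2, 5) binary mask
    enc = EncoderRNN(vocab_size=20, embed_size=6, hidden_size=hidden,
                     bidirectional=True, rnn_type='lstm', use_cuda=False)
    outputs, last_h = enc(tokens, lengths)                # (2, 5, 8), (2, 8)

    att = Attention(hidden, atten_type='add')
    scores = att(last_h, outputs, atten_mask=mask)        # (2, 5) unnormalized scores
    print(outputs.shape, last_h.shape, scores.shape)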