import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class StructuredAttention(nn.Module):
    """Use each word in the context to attend to words in the query.
    In my case, the context is a question-answer pair and the query is the
    object-level features in an image.

    Note the values in S are cosine similarity scores, and are in [-1, 1].
    They are scaled before the softmax so that the maximum value can receive
    a probability close to 1:
    S_ = F.softmax(S * self.scale, dim=-1)

    Consider the softmax weight of one score m competing with 24 scores at -m:
    f(m) = exp(m) / [24 * exp(-m) + exp(m)]
    Without scaling, S \in [-1, 1], so the most the maximum value can get is
    f(1) = exp(1) / [24 * exp(-1) + exp(1)] ~= 0.24 .
    With scale = 100, S * scale \in [-100, 100], and
    f(100) = exp(100) / [24 * exp(-100) + exp(100)] ~= 1.0 .
    """
    def __init__(self, dropout=0.1, scale=100, add_void=False):
        """
        Args:
            dropout: dropout probability applied to the normalized features
            scale: multiplier applied to the cosine similarities before softmax
            add_void: if True, keep the void vectors' contribution in the
                attended output; otherwise they only absorb attention weight
        """
        super(StructuredAttention, self).__init__()
        self.dropout = dropout
        self.scale = scale
        self.add_void = add_void

    def forward(self, C, Q, c_mask, q_mask, noun_mask=None, void_vector=None):
        """
        match the dim of '*', singlton is allowed
        Args:
            C: (N, 5, Li, Lqa, D)
            Q: (N, 1, Li, Lr, D)
            c_mask: (N, 5, Li, Lqa)
            q_mask: (N, 1, Li, Lr)
            noun_mask: (N, 5, Lqa) , where 1 indicate the current position is a noun
               or (N, 5, Li, Lqa), where each entry is the probability of the current
               image being a positive bag for the word
            void_vector: (D, )
        Returns:
            (N, *, Lc, D)
        """
        bsz, _, num_img, num_region, hsz = Q.shape
        if void_vector is not None:
            # append the void ("no-match") vectors to the query regions
            num_void = len(void_vector)
            Q_void = void_vector.view(1, 1, 1, num_void, hsz).repeat(bsz, 1, num_img, 1, 1)
            Q = torch.cat([Q, Q_void], dim=-2)  # (N, 1, Li, Lr+num_void, D)
            q_mask_void = q_mask.new_ones(bsz, 1, num_img, num_void)  # void positions are never masked
            q_mask = torch.cat([q_mask, q_mask_void], dim=-1)  # (N, 1, Li, Lr+num_void)

        S, S_mask = self.similarity(C, Q, c_mask, q_mask)  # (N, 5, Li, Lqa, Lr+num_void)
        S_ = F.softmax(S * self.scale, dim=-1)
        # (N, 5, Li, Lqa, Lr+num_void), the weight of each query word for a given context word
        S_ = S_ * S_mask  # zero out padded positions (handles rows that are entirely padding)

        if noun_mask is not None:
            # down-weight attention from non-noun context positions
            if noun_mask.dim() == 3:  # (N, 5, Lqa), binary noun indicators
                bsz, num_qa, lqa = noun_mask.shape
                S_ = S_ * noun_mask.view(bsz, num_qa, 1, lqa, 1)
            elif noun_mask.dim() == 4:  # (N, 5, Li, Lqa), positive-bag probabilities
                S_ = S_ * noun_mask.unsqueeze(-1)
            else:
                raise NotImplementedError

        if void_vector is not None:
            if self.add_void:
                # keep the void vectors' contribution in the attended output
                A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
                S, S_mask, S_ = S[..., :-num_void], S_mask[..., :-num_void], S_[..., :-num_void]
            else:
                # strip the void entries first, so they only absorb attention weight
                S, S_mask, S_ = S[..., :-num_void], S_mask[..., :-num_void], S_[..., :-num_void]
                Q = Q[..., :-num_void, :]  # (N, 1, Li, Lr, D)
                A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
        else:
            A = torch.matmul(S_, Q)  # (N, 5, Li, Lqa, D)
        return A, S, S_mask, S_

    def similarity(self, C, Q, c_mask, q_mask):
        """
        word2word dot-product similarity
        Args:
            C: (N, 5, Li, Lqa, D)
            Q: (N, 1, Li, Lr, D)
            c_mask: (N, 5, Li, Lqa)
            q_mask: (N, 1, Li, Lr)
        Returns:
            (N, *, Lc, Lq)
        """
        C = F.dropout(F.normalize(C, p=2, dim=-1), p=self.dropout, training=self.training)
        Q = F.dropout(F.normalize(Q, p=2, dim=-1), p=self.dropout, training=self.training)

        S_mask = torch.matmul(c_mask.unsqueeze(-1), q_mask.unsqueeze(-2))  # (N, 5, Li, Lqa, Lr)
        S = torch.matmul(C, Q.transpose(-2, -1))  # (N, 5, Li, Lqa, Lr)
        masked_S = S - 1e10*(1 - S_mask)  # (N, 5, Li, Lqa, Lr)
        return masked_S, S_mask
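

# A minimal sketch (not part of the original module) that numerically checks the
# scaling argument in the StructuredAttention docstring: one cosine score of 1.0
# competing with 24 scores of -1.0.
def demo_softmax_scaling():
    scores = torch.full((25,), -1.0)
    scores[0] = 1.0  # the single maximum cosine similarity
    unscaled = F.softmax(scores, dim=-1)[0]  # ~0.24, barely above uniform (1/25 = 0.04)
    scaled = F.softmax(scores * 100, dim=-1)[0]  # ~1.0, near one-hot
    print("max weight without scaling: {:.4f}".format(unscaled.item()))
    print("max weight with scale=100:  {:.4f}".format(scaled.item()))
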


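# A minimal usage sketch for StructuredAttention with random tensors; the toy
# sizes below (bsz=2, num_img=3, lqa=8, lr=6, hsz=16) are assumed, not from the
# original code.
def demo_structured_attention():
    attn = StructuredAttention()
    bsz, num_qa, num_img, lqa, lr, hsz = 2, 5, 3, 8, 6, 16
    C = torch.randn(bsz, num_qa, num_img, lqa, hsz)  # context: QA-pair words
    Q = torch.randn(bsz, 1, num_img, lr, hsz)  # query: image region features
    c_mask = torch.ones(bsz, num_qa, num_img, lqa)
    q_mask = torch.ones(bsz, 1, num_img, lr)
    void_vector = torch.randn(1, hsz)  # a single "no-match" vector, (num_void, D)
    A, S, S_mask, S_ = attn(C, Q, c_mask, q_mask, void_vector=void_vector)
    print(A.shape)  # (2, 5, 3, 8, 16), attended region features per QA word
    print(S.shape)  # (2, 5, 3, 8, 6), void entry already stripped

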
class ContextQueryAttention(nn.Module):
    """
    Context-to-query (sub-a) attention: each context word attends over the
    query words via scaled dot-product similarity.
    """
    def __init__(self):
        super(ContextQueryAttention, self).__init__()

    def forward(self, C, Q, c_mask, q_mask):
        """
        match the dim of '*', singlton is allowed
        :param C: (N, *, Lc, D)
        :param Q: (N, *, Lq, D)
        :param c_mask: (N, *, Lc)
        :param q_mask: (N, *, Lq)
        :return: (N, Lc, D) and (N, Lq, D)
        """

        S = self.similarity(C, Q, c_mask, q_mask)  # (N, *, Lc, Lq)
        S_ = F.softmax(S, dim=-1)  # (N, *, Lc, Lq)
        A = torch.matmul(S_, Q)  # (N, *, Lc, D)
        return A

    def similarity(self, C, Q, c_mask, q_mask):
        """
        word2word dot-product similarity
        :param C: (N, *, Lc, D)
        :param Q: (N, *, Lq, D)
        :param c_mask: (N, *, Lc)
        :param q_mask: (N, *, Lq)
        :return: (N, *, Lc, Lq)
        """
        C = F.dropout(C, p=0.1, training=self.training)
        Q = F.dropout(Q, p=0.1, training=self.training)
        hsz_root = math.sqrt(C.shape[-1])

        S_mask = torch.matmul(c_mask.unsqueeze(-1), q_mask.unsqueeze(-2))  # (N, *, Lc, Lq)
        S = torch.matmul(C, Q.transpose(-2, -1)) / hsz_root  # (N, *, Lc, Lq)
        masked_S = S - 1e10*(1 - S_mask)  # (N, *, Lc, Lq)
        return masked_S
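

# A small sketch (assumed toy sizes, not from the original tests) showing that a
# padded query position receives ~0 attention weight after masking.
def demo_masking():
    c2q = ContextQueryAttention()
    C = torch.randn(1, 4, 8)  # (N, Lc, D)
    Q = torch.randn(1, 3, 8)  # (N, Lq, D)
    c_mask = torch.ones(1, 4)
    q_mask = torch.tensor([[1., 1., 0.]])  # last query position is padding
    S = c2q.similarity(C, Q, c_mask, q_mask)  # padded entries pushed to -1e10
    S_ = F.softmax(S, dim=-1)
    print(S_[0, :, -1])  # ~0 for every context word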


def test():
    c2q = ContextQueryAttention()
    hsz = 128
    bsz = 10
    lc = 20
    lq = 10
    context = torch.randn(bsz, lc, hsz).float()  # (N, Lc, D)
    context_mask = torch.ones(bsz, lc).float()  # (N, Lc)
    query = torch.randn(bsz, lq, hsz).float()  # (N, Lq, D)
    query_mask = torch.ones(bsz, lq).float()  # (N, Lq)
    attended = c2q(context, query, context_mask, query_mask)  # (N, Lc, D)
    print("input size", context.shape, context_mask.shape, query.shape, query_mask.shape)
    print("output size", attended.shape)


if __name__ == '__main__':
    test()