import torch
import torch.nn as nn
import torch.nn.functional as F

from allennlp.modules.matrix_attention.cosine_matrix_attention import CosineMatrixAttention
from matchmaker.modules.masked_softmax import MaskedSoftmax

class PACRR(nn.Module):
    '''
    Paper: PACRR: A Position-Aware Neural IR Model for Relevance Matching, Hui et al., EMNLP'17

    Reference code (in TensorFlow):

    * first-hand: https://github.com/khui/copacrr/blob/master/models/pacrr.py
    '''
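    # High-level scoring pipeline (mirrors forward() below):
    #   query/document embeddings -> cosine similarity matrix (query_len x doc_len)
    #   -> one Conv2d path per n-gram kernel size 2..max_conv_kernel_size
    #   -> per-path k-max pooling over the document axis
    #   -> concatenate all paths and weight each query term by a masked-softmax idf
    #   -> flatten -> three dense layers -> scalar relevance score per document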

    @staticmethod
    def from_config(config, word_embeddings_out_dim):
        return PACRR(unified_query_length=config["pacrr_unified_query_length"],
                     unified_document_length=config["pacrr_unified_document_length"],
                     max_conv_kernel_size=config["pacrr_max_conv_kernel_size"],
                     conv_output_size=config["pacrr_conv_output_size"],
                     kmax_pooling_size=config["pacrr_kmax_pooling_size"])
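
    # Illustrative config snippet consumed by from_config above; the key names are
    # the real ones read there, but the values are made-up examples, not defaults
    # taken from the paper:
    #   {
    #     "pacrr_unified_query_length": 30,
    #     "pacrr_unified_document_length": 200,
    #     "pacrr_max_conv_kernel_size": 3,
    #     "pacrr_conv_output_size": 32,
    #     "pacrr_kmax_pooling_size": 5
    #   }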

    def __init__(self,
                 unified_query_length: int,
                 unified_document_length: int,
                 max_conv_kernel_size: int,  # n-gram kernel sizes range from 2 to this value
                 conv_output_size: int,      # number of conv output channels per n-gram size
                 kmax_pooling_size: int):    # k for per-query-term k-max pooling
        super(PACRR, self).__init__()

        self.cosine_module = CosineMatrixAttention()

        self.unified_query_length = unified_query_length
        self.unified_document_length = unified_document_length

        self.convolutions = []
        for i in range(2, max_conv_kernel_size + 1):
            self.convolutions.append(
                nn.Sequential(
                    # pad right/bottom by i - 1 so the conv output keeps the unified shape:
                    # [batch, 1, unified_query_length + i - 1, unified_document_length + i - 1]
                    nn.ConstantPad2d((0, i - 1, 0, i - 1), 0),
                    # -> [batch, conv_output_size, unified_query_length, unified_document_length]
                    nn.Conv2d(kernel_size=i, in_channels=1, out_channels=conv_output_size),
                    # max over the channel dimension
                    # -> [batch, 1, unified_query_length, unified_document_length]
                    nn.MaxPool3d(kernel_size=(conv_output_size, 1, 1))
                ))
        self.convolutions = nn.ModuleList(self.convolutions)  # register the convs as sub-modules
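
        # Worked shape example (illustrative), for kernel size i = 3 with inputs already
        # padded/cut to the unified lengths Q x D:
        #   ConstantPad2d((0, 2, 0, 2)):        (batch, 1, Q, D) -> (batch, 1, Q + 2, D + 2)
        #   Conv2d(kernel_size=3):              -> (batch, conv_output_size, Q, D)
        #   MaxPool3d((conv_output_size, 1, 1)) takes the max over channels: -> (batch, 1, Q, D)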

        self.masked_softmax = MaskedSoftmax()
        self.kmax_pooling_size = kmax_pooling_size

        # in_features: kmax scores per query term per path; there are max_conv_kernel_size
        # paths in total (the raw cosine matrix plus one conv per n-gram size 2..max)
        self.dense = nn.Linear(kmax_pooling_size * unified_query_length * max_conv_kernel_size, out_features=100, bias=True)
        self.dense2 = nn.Linear(100, out_features=10, bias=True)
        self.dense3 = nn.Linear(10, out_features=1, bias=False)

    def forward(self, query_embeddings: torch.Tensor, document_embeddings: torch.Tensor,
                query_pad_oov_mask: torch.Tensor, document_pad_oov_mask: torch.Tensor,
                query_idfs: torch.Tensor, document_idfs: torch.Tensor, 
                output_secondary_output: bool = False) -> torch.Tensor:

        #
        # similarity matrix
        # -------------------------------------------------------

        # create the similarity matrix
        cosine_matrix = self.cosine_module.forward(query_embeddings, document_embeddings)
        # shape: (batch, 1, query_max, doc_max) -- add a channel dimension for conv2d
        cosine_matrix = cosine_matrix[:, None, :, :]

        #
        # duplicate cosine_matrix -> n-gram convolutions, then top-k pooling
        # ----------------------------------------------
        conv_results = []
        # unigram path: k-max pooling directly on the cosine matrix; squeeze(1) rather
        # than squeeze() so a batch of size 1 does not lose its batch dimension
        conv_results.append(torch.topk(cosine_matrix.squeeze(1), k=self.kmax_pooling_size, sorted=True)[0])

        for conv in self.convolutions:
            cr = conv(cosine_matrix)
            cr_kmax_result = torch.topk(cr.squeeze(1), k=self.kmax_pooling_size, sorted=True)[0]
            conv_results.append(cr_kmax_result)

        #
        # flatten all paths together & weight by query idf
        # -------------------------------------------------------
        
        per_query_results = torch.cat(conv_results, dim=-1)

        # weight each query term's pooled scores by a masked softmax over the query idfs
        weighted_per_query = per_query_results * self.masked_softmax(query_idfs, query_pad_oov_mask.unsqueeze(-1))

        # flatten the idf-weighted per-query results for the dense scoring layers
        all_flat = weighted_per_query.view(weighted_per_query.shape[0], -1)

        #
        # dense layer
        # -------------------------------------------------------

        dense_out = F.relu(self.dense(all_flat))
        dense_out = F.relu(self.dense2(dense_out))
        dense_out = self.dense3(dense_out)

        output = torch.squeeze(dense_out, 1)
        return output

    def get_param_stats(self):
        # placeholder: could report dense/dense2/dense3 weight and bias statistics
        return "PACRR: / "