from collections import OrderedDict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from embedding.wordebd import WORDEBD
from embedding.auxiliary.factory import get_embedding


class CNN(nn.Module):
    '''
        An aggregation method that encodes every document with 1d convolution
        filters of different widths, followed by max-over-time pooling.
    '''
    def __init__(self, ebd, args):
        super(CNN, self).__init__()
        self.args = args

        self.ebd = ebd
        self.aux = get_embedding(args)

        self.input_dim = self.ebd.embedding_dim + self.aux.embedding_dim

        # Convolution
        self.convs = nn.ModuleList([nn.Conv1d(
                    in_channels=self.input_dim,
                    out_channels=args.cnn_num_filters,
                    kernel_size=K) for K in args.cnn_filter_sizes])

        # used for visualization
        if args.mode == 'visualize':
            self.scores = [[] for _ in args.cnn_filter_sizes]

        # dimensionality of the document representation returned by forward()
        self.ebd_dim = args.cnn_num_filters * len(args.cnn_filter_sizes)

    def _conv_max_pool(self, x, conv_filter=None, weights=None):
        '''
        Apply the 1d convolutions followed by max-over-time pooling.
        At least one of conv_filter (an nn.ModuleList of nn.Conv1d) or weights
        (a dict of parameters keyed 'convs.{i}.weight' / 'convs.{i}.bias',
        used for MAML) must be provided; conv_filter takes precedence.
        Input:
            x:      batch_size, max_doc_len, embedding_dim
        Output:     batch_size, num_filters_total
        '''
        assert(len(x.size()) == 3)

        x = x.permute(0, 2, 1)  # batch_size, embedding_dim, doc_len
        x = x.contiguous()

        # Apply the 1d conv. Resulting dimension is
        # [batch_size, num_filters, doc_len-filter_size+1] * len(filter_size)
        assert(not ((conv_filter is None) and (weights is None)))
        if conv_filter is not None:
            x = [conv(x) for conv in conv_filter]

        elif weights is not None:
            x = [F.conv1d(x, weight=weights['convs.{}.weight'.format(i)],
                          bias=weights['convs.{}.bias'.format(i)])
                 for i in range(len(self.args.cnn_filter_sizes))]

        # max pool over time. Resulting dimension is
        # [batch_size, num_filters] * len(filter_size)
        x = [F.max_pool1d(sub_x, sub_x.size(2)).squeeze(2) for sub_x in x]

        # concatenate along all filters. Resulting dimension is
        # [batch_size, num_filters_total]
        x = torch.cat(x, 1)
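        # note: ReLU is monotone, so applying it after max-pooling is
        # equivalent to the conventional conv -> ReLU -> max-pool ordering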
        x = F.relu(x)

        return x

    def forward(self, data, weights=None):
        '''
            @param data dictionary
                @key text: batch_size * max_text_len
            @param weights dict of fast weights used for MAML (None for the
                regular path)

            @return output: batch_size * ebd_dim
                (ebd_dim = num_filters * number of filter sizes)
        '''
        # Apply the word embedding, result:  batch_size, doc_len, embedding_dim
        ebd = self.ebd(data, weights)

        # concatenate the auxiliary embedding, if applicable
        aux = self.aux(data, weights)

        ebd = torch.cat([ebd, aux], dim=2)

        # apply 1d conv + max pool, result:  batch_size, num_filters_total
        if weights is None:
            ebd = self._conv_max_pool(ebd, conv_filter=self.convs)
        else:
            ebd = self._conv_max_pool(ebd, weights=weights)

        # update max scores
        if self.args.mode == 'visualize':
            for i, s in enumerate(self.compute_score(data)):
                self.scores[i].append(torch.max(s).item())

        return ebd

    def compute_score(self, data, normalize=False):
        '''
            Compute a per-token score for each filter size: the max dot
            product between every token embedding and the (flattened) filter
            weights. Used for visualization.

            @return a list, one entry per filter size, of tensors of shape
                batch_size, doc_len
        '''
        # prepare the input
        ebd = self.ebd(data)
        aux = self.aux(data)
        # (batch_size, doc_len, input_dim)
        x = torch.cat([ebd, aux], dim=-1).detach()

        # (out_channels, in_channels, kernel_size)
        w = [c.weight.data for c in self.convs]
        # (kernel_size, out_channels, in_channels)
        w = [c.permute(2, 0, 1) for c in w]
        # (kernel_size * out_channels, in_channels)
        w = [c.reshape(-1, self.input_dim) for c in w]

        # (batch_size, doc_len, kernel_size * out_channels)
        x = [x @ c.t() for c in w]
        # (batch_size, doc_len)
        x = [F.max_pool1d(z, z.shape[-1]).squeeze(-1) for z in x]

        if normalize:
            x = [sub_x / np.mean(s) for sub_x, s in zip(x, self.scores)]

        return x
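

if __name__ == '__main__':
    # Minimal, self-contained sketch (not part of the original pipeline): it
    # mirrors the conv + max-over-time pooling in `_conv_max_pool` on random
    # data and shows how a MAML-style `weights` dict keyed 'convs.{i}.weight' /
    # 'convs.{i}.bias' could be built from an nn.ModuleList. The sizes below
    # (batch 4, doc_len 50, input_dim 300, 100 filters of widths 3/4/5) are
    # illustrative assumptions, not values taken from the repo's configs.
    batch_size, doc_len, input_dim = 4, 50, 300
    num_filters, filter_sizes = 100, [3, 4, 5]

    convs = nn.ModuleList([nn.Conv1d(in_channels=input_dim,
                                     out_channels=num_filters,
                                     kernel_size=K) for K in filter_sizes])

    # batch_size, doc_len, input_dim -> batch_size, input_dim, doc_len
    x = torch.randn(batch_size, doc_len, input_dim).permute(0, 2, 1).contiguous()

    # modular path: [batch_size, num_filters, doc_len - K + 1] per filter size
    feats = [conv(x) for conv in convs]
    # max over time, then concatenate: batch_size, num_filters * len(filter_sizes)
    feats = torch.cat([F.max_pool1d(f, f.size(2)).squeeze(2) for f in feats], 1)
    feats = F.relu(feats)
    print(feats.shape)  # torch.Size([4, 300])

    # functional (MAML-style) path driven by an explicit weights dictionary
    weights = OrderedDict()
    for i, conv in enumerate(convs):
        weights['convs.{}.weight'.format(i)] = conv.weight
        weights['convs.{}.bias'.format(i)] = conv.bias
    feats_fn = [F.conv1d(x, weight=weights['convs.{}.weight'.format(i)],
                         bias=weights['convs.{}.bias'.format(i)])
                for i in range(len(filter_sizes))]
    feats_fn = torch.cat([F.max_pool1d(f, f.size(2)).squeeze(2)
                          for f in feats_fn], 1)
    feats_fn = F.relu(feats_fn)
    print(torch.allclose(feats, feats_fn))  # True: both paths agree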