import torch
import torch.nn as nn
import torch.nn.functional as F

from .modules import clones, LayerNorm, SublayerConnection
from .constants import print_dims

class Decoder(nn.Module):
    "Generic N layer decoder with masking"
    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)
        
    def forward(self, x, src_states, src_mask, tgt_mask):
        """
            x: (batch_size, tgt_seq_len, d_model)
            src_states: (batch_size, src_seq_len, d_model)
            src_mask: (batch_size, 1, src_seq_len)
            tgt_mask: (batch_size, tgt_seq_len, tgt_seq_len)
        """
        if print_dims:
            print("{0}: x: type: {1}, shape: {2}".format(self.__class__.__name__, x.type(), x.shape))
            print("{0}: src_states: type: {1}, shape: {2}".format(self.__class__.__name__, src_states.type(), src_states.shape))
            print("{0}: src_mask: type: {1}, shape: {2}".format(self.__class__.__name__, src_mask.type(), src_mask.shape))
            print("{0}: tgt_mask: type: {1}, shape: {2}".format(self.__class__.__name__, tgt_mask.type(), tgt_mask.shape))
        for layer in self.layers:
            x = layer(x, src_states, src_mask, tgt_mask)
        x = self.norm(x) # (batch_size, tgt_seq_len, d_model)
        if print_dims:
            print("{0}: x (output): type: {1}, shape: {2}".format(self.__class__.__name__, x.type(), x.shape))
        
        # max-pool over the target sequence dimension to collapse the per-token
        # decoder states into a single vector per example
        x = F.max_pool1d(x.permute(0,2,1), x.shape[1]).squeeze(-1) # (batch_size, d_model)
        return x
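

# Illustrative sketch, not part of the model: demonstrates what the max-pool
# step in Decoder.forward does to the tensor shapes. All sizes below are
# arbitrary example values.
def _demo_sequence_max_pool():
    batch_size, tgt_seq_len, d_model = 2, 7, 16
    x = torch.randn(batch_size, tgt_seq_len, d_model)
    # permute to (batch_size, d_model, tgt_seq_len) so max_pool1d pools over the sequence axis
    pooled = F.max_pool1d(x.permute(0, 2, 1), x.shape[1]).squeeze(-1)
    assert pooled.shape == (batch_size, d_model)
    # equivalent formulation: take the elementwise max over the sequence dimension
    assert torch.allclose(pooled, x.max(dim=1).values)
    return pooled
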

class DecoderLayer(nn.Module):
    "Decoder is made of self-attn, src-attn and feed forward"
    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)
        
    def forward(self, x, src_states, src_mask, tgt_mask):
        """norm -> self_attn -> dropout -> add -> 
        norm -> src_attn -> dropout -> add ->
        norm -> feed_forward -> dropout -> add"""
        # sublayer 0: masked self-attention over the target sequence
        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # sublayer 1: encoder-decoder attention over the source states
        x = self.sublayer[1](x, lambda x: self.src_attn(x, src_states, src_states, src_mask))
        # sublayer 2: position-wise feed-forward network
        return self.sublayer[2](x, self.feed_forward)
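

# Illustrative sketch, not part of the model: wires Decoder and DecoderLayer
# together end-to-end. It assumes clones/LayerNorm/SublayerConnection from
# .modules follow the usual pre-norm residual setup; the attention modules
# below are hypothetical stand-ins that only mimic the
# forward(query, key, value, mask) interface of a real attention layer.
class _StandInAttention(nn.Module):
    def __init__(self, d_model):
        super(_StandInAttention, self).__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        # a real attention module would attend over key/value under mask;
        # this stand-in just projects the query so the shapes line up
        return self.proj(query)


def _demo_decoder():
    batch_size, src_seq_len, tgt_seq_len, d_model = 2, 5, 7, 16
    layer = DecoderLayer(
        size=d_model,
        self_attn=_StandInAttention(d_model),
        src_attn=_StandInAttention(d_model),
        feed_forward=nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.ReLU(), nn.Linear(4 * d_model, d_model)
        ),
        dropout=0.1,
    )
    decoder = Decoder(layer, N=2)
    x = torch.randn(batch_size, tgt_seq_len, d_model)
    src_states = torch.randn(batch_size, src_seq_len, d_model)
    src_mask = torch.ones(batch_size, 1, src_seq_len)
    # causal target mask: position i may only attend to positions <= i
    tgt_mask = torch.tril(torch.ones(tgt_seq_len, tgt_seq_len)).unsqueeze(0).expand(batch_size, -1, -1)
    out = decoder(x, src_states, src_mask, tgt_mask)
    assert out.shape == (batch_size, d_model)
    return out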