# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch
import torch.nn as nn
import torch.nn.functional as F

# import torch.utils.checkpoint as cp

from fairseq.modules import (
    MaskedConvolution, MultiheadMaskedConvolution
)


class PAGateNet(nn.Module):
    """Stack of convolutional ``_Layer`` blocks aggregated through GLU gates.

    The returned features are the sum of the gated input and, for every
    layer, the gated convolutional output and the gated feed-forward output.
    Each layer has its own pair of gates; ``depth_gate`` is a single gate
    shared by all layers and produces the input of the next layer.

    Args:
        num_init_features: number of channels of the incoming tensor.
        args: configuration namespace. This class reads ``num_layers``,
            ``kernel_size``, ``divide_channels`` and ``gate_channels``;
            ``args`` is also handed to every ``_Layer``.
    """

    def __init__(self, num_init_features, args):
        super().__init__()
        depth = args.num_layers
        kernel_size = args.kernel_size
        divisor = args.divide_channels

        # Optional linear bottleneck on the channel dimension, only built
        # when the requested division factor is larger than one.
        if divisor > 1:
            self.reduce_channels = Linear(
                num_init_features, num_init_features // divisor
            )
        else:
            self.reduce_channels = None

        width = num_init_features // divisor
        self.output_channels = width
        self.gate_channels = args.gate_channels  # stored only; not used in this class

        # Per-layer gates and blocks (filled in the loop below).
        self.gates_ffn = nn.ModuleList()
        self.gates_attn = nn.ModuleList()
        self.blocks = nn.ModuleList()

        # Gate for the (possibly reduced) input, and the gate shared across depth.
        self.gate_embeddings = _GateLayer(width)
        self.depth_gate = _GateLayer(width)

        for _ in range(depth):
            self.blocks.append(_Layer(width, kernel_size, args))
            self.gates_attn.append(_GateLayer(width))
            self.gates_ffn.append(_GateLayer(width))

    def forward(self, x,
                encoder_mask=None,
                decoder_mask=None,
                incremental_state=None):
        """
        Input : N, Tt, Ts, C
        Output : N, Tt, Ts, C

        The masks and ``incremental_state`` are forwarded untouched to
        every layer.
        """
        if self.reduce_channels is not None:
            x = self.reduce_channels(x)

        # Running sum of gated contributions, seeded with the gated input.
        features = self.gate_embeddings(x)
        for idx, block in enumerate(self.blocks):
            ffn_out, attn_out = block(
                x,
                encoder_mask=encoder_mask,
                decoder_mask=decoder_mask,
                incremental_state=incremental_state,
            )
            features += self.gates_attn[idx](attn_out)
            features += self.gates_ffn[idx](ffn_out)
            # Residual combination passed through the shared gate becomes the
            # next layer's input (after the last layer it is not used by the
            # returned value).
            x = self.depth_gate(x + attn_out + ffn_out)
        return features


class _GateLayer(nn.Module):
    """Gated linear unit over the last dimension.

    The input is projected to twice ``num_features`` channels and passed
    through ``F.glu``, so the output again has ``num_features`` channels.
    """

    def __init__(self, num_features):
        super().__init__()
        self.linear = Linear(num_features, num_features * 2)

    def forward(self, x):
        """Return the GLU of the linear projection of ``x``."""
        projected = self.linear(x)
        return F.glu(projected, dim=-1)


class _Layer(nn.Module):
    """Single layer: 1x1 convolution, masked convolution, then a two-layer FFN.

    Args:
        num_features: number of input and output channels of the layer.
        kernel_size: kernel size handed to ``MaskedConvolution``; it is also
            used to derive the padding.
        args: configuration namespace. This class reads
            ``zero_out_conv_input``, ``convolution_dropout``, ``ffn_dim``,
            ``reduce_dim``, ``conv_stride``, ``source_dilation``,
            ``target_dilation``, ``maintain_resolution`` and ``conv_bias``;
            ``args`` is also passed on to ``MaskedConvolution``.

    Raises:
        ValueError: if ``maintain_resolution`` is set while
            ``conv_stride`` is not 1.
    """

    def __init__(self, num_features, kernel_size, args):
        super().__init__()
        self.zero_out = args.zero_out_conv_input
        self.drop_rate = args.convolution_dropout
        ffn_dim = args.ffn_dim
        mid_features = args.reduce_dim
        stride = args.conv_stride
        src_dilation = args.source_dilation
        trg_dilation = args.target_dilation
        keep_resolution = args.maintain_resolution

        if keep_resolution and stride != 1:
            raise ValueError('Could not maintain the resolution with stride=%d' % stride)

        # The target axis is always padded; the source axis is padded only
        # when the resolution has to be maintained.
        pad_trg = trg_dilation * (kernel_size - 1) // 2
        pad_src = src_dilation * (kernel_size - 1) // 2 if keep_resolution else 0
        padding = (pad_trg, pad_src)

        # Pointwise convolution mapping num_features -> reduce_dim channels.
        self.conv1 = nn.Conv2d(
            num_features,
            mid_features,
            kernel_size=1,
            stride=1,
            bias=args.conv_bias,
        )
        # Masked convolution mapping back to num_features channels.
        self.mconv2 = MaskedConvolution(
            mid_features, num_features,
            kernel_size, args,
            padding=padding,
        )
        # Position-wise feed-forward network.
        self.fc1 = Linear(num_features, ffn_dim)
        self.fc2 = Linear(ffn_dim, num_features)

    def forward(self, x,
                encoder_mask=None,
                decoder_mask=None,
                incremental_state=None):
        """Run the layer on a channels-last input.

        Returns a tuple ``(ffn_output, conv_output)``, both channels-last.
        The masks are only applied in training mode and only when
        ``zero_out_conv_input`` is enabled.
        """
        # Move channels to dim 1 for the 2d convolutions.
        hidden = x.permute(0, 3, 1, 2)

        if self.zero_out and self.training:
            # presumably the masks are 2-D (batch, length) and True marks
            # positions to zero -- verify against the caller.
            if encoder_mask is not None:
                src_mask = encoder_mask.unsqueeze(1).unsqueeze(1)
                hidden = hidden.masked_fill(src_mask, 0)
            if decoder_mask is not None:
                trg_mask = decoder_mask.unsqueeze(1).unsqueeze(-1)
                hidden = hidden.masked_fill(trg_mask, 0)

        # Pointwise convolution followed by the masked convolution.
        hidden = self.mconv2(self.conv1(hidden), incremental_state)
        if self.drop_rate:
            hidden = F.dropout(hidden, p=self.drop_rate, training=self.training)

        # Back to channels-last; this is the convolutional output.
        xattn = hidden.permute(0, 2, 3, 1)

        # FFN on top of the convolutional output.
        xffn = self.fc2(F.relu(self.fc1(xattn)))
        if self.drop_rate:
            xffn = F.dropout(xffn, p=self.drop_rate, training=self.training)
        return xffn, xattn


def Linear(in_features, out_features, bias=True):
    """Build an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has an additive bias (zero-initialised).

    Returns:
        The initialised ``nn.Linear`` module.
    """
    layer = nn.Linear(in_features, out_features, bias=bias)
    nn.init.xavier_uniform_(layer.weight)
    # nn.Linear only creates a bias parameter when ``bias`` is truthy.
    if layer.bias is not None:
        nn.init.constant_(layer.bias, 0.)
    return layer