from torch import nn
from torch.nn import functional as F

from .base import GlobalConvolutionalNetwork, BoundaryRefinement, DeconvConv2dBnRelu
from .encoders import get_encoder_channel_nr


class LargeKernelMatters(nn.Module):
    """PyTorch LKM model using ResNet(18, 34, 50, 101 or 152) encoder.

        https://arxiv.org/pdf/1703.02719.pdf
    """

    def __init__(self, encoder, num_classes, kernel_size=9, internal_channels=21, use_relu=False, pool0=False,
                 use_channel_se=False, use_spatial_se=False, reduction_se=4, dropout_2d=0.0):
        super().__init__()

        self.dropout_2d = dropout_2d
        self.pool0 = pool0

        self.encoder = encoder
        encoder_channel_nr = get_encoder_channel_nr(self.encoder)

        self.gcn2 = GlobalConvolutionalNetwork(in_channels=encoder_channel_nr[0],
                                               out_channels=internal_channels,
                                               kernel_size=kernel_size,
                                               use_relu=use_relu)
        self.gcn3 = GlobalConvolutionalNetwork(in_channels=encoder_channel_nr[1],
                                               out_channels=internal_channels,
                                               kernel_size=kernel_size,
                                               use_relu=use_relu)
        self.gcn4 = GlobalConvolutionalNetwork(in_channels=encoder_channel_nr[2],
                                               out_channels=internal_channels,
                                               kernel_size=kernel_size,
                                               use_relu=use_relu)
        self.gcn5 = GlobalConvolutionalNetwork(in_channels=encoder_channel_nr[3],
                                               out_channels=internal_channels,
                                               kernel_size=kernel_size,
                                               use_relu=use_relu)
        self.enc_br2 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.enc_br3 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.enc_br4 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.enc_br5 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.dec_br1 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.dec_br2 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.dec_br3 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.dec_br4 = BoundaryRefinement(in_channels=internal_channels,
                                          out_channels=internal_channels,
                                          kernel_size=3)

        self.deconv5 = DeconvConv2dBnRelu(in_channels=internal_channels, out_channels=internal_channels,
                                          use_channel_se=use_channel_se, use_spatial_se=use_spatial_se,
                                          reduction=reduction_se)
        self.deconv4 = DeconvConv2dBnRelu(in_channels=internal_channels, out_channels=internal_channels,
                                          use_channel_se=use_channel_se, use_spatial_se=use_spatial_se,
                                          reduction=reduction_se)
        self.deconv3 = DeconvConv2dBnRelu(in_channels=internal_channels, out_channels=internal_channels,
                                          use_channel_se=use_channel_se, use_spatial_se=use_spatial_se,
                                          reduction=reduction_se)
        self.deconv2 = DeconvConv2dBnRelu(in_channels=internal_channels, out_channels=internal_channels,
                                          use_channel_se=use_channel_se, use_spatial_se=use_spatial_se,
                                          reduction=reduction_se)
        self.deconv1 = DeconvConv2dBnRelu(in_channels=internal_channels, out_channels=internal_channels,
                                          use_channel_se=use_channel_se, use_spatial_se=use_spatial_se,
                                          reduction=reduction_se)
        self.dec_br0_1 = BoundaryRefinement(in_channels=internal_channels,
                                            out_channels=internal_channels,
                                            kernel_size=3)
        self.dec_br0_2 = BoundaryRefinement(in_channels=internal_channels,
                                            out_channels=internal_channels,
                                            kernel_size=3)

        self.final = nn.Conv2d(internal_channels, num_classes, kernel_size=1, padding=0)

    def forward(self, x):
        encoder2, encoder3, encoder4, encoder5 = self.encoder(x)
        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d)

        gcn2 = self.enc_br2(self.gcn2(encoder2))
        gcn3 = self.enc_br3(self.gcn3(encoder3))
        gcn4 = self.enc_br4(self.gcn4(encoder4))
        gcn5 = self.enc_br5(self.gcn5(encoder5))

        decoder5 = self.deconv5(gcn5)
        decoder4 = self.deconv4(self.dec_br4(decoder5 + gcn4))
        decoder3 = self.deconv3(self.dec_br3(decoder4 + gcn3))
        decoder2 = self.dec_br1(self.deconv2(self.dec_br2(decoder3 + gcn2)))

        if self.pool0:
            decoder2 = self.dec_br0_2(self.deconv1(self.dec_br0_1(decoder2)))

        return self.final(decoder2)