from torch import nn
from torch.nn import functional as F
import torch

from .base import Conv2dBnRelu, DecoderBlock
from .encoders import ResNetEncoders, SeResNetEncoders, SeResNetXtEncoders, DenseNetEncoders

"""
This module was adapted (with modifications) from:
https://github.com/ternaus/TernausNet

@ARTICLE{arXiv:1801.05746,
         author = {V. Iglovikov and A. Shvets},
          title = {TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentation},
        journal = {ArXiv e-prints},
         eprint = {1801.05746}, 
           year = 2018
        }
"""


class UNetResNet(nn.Module):
    """PyTorch U-Net model using a ResNet encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    Args:
        encoder_depth (int): Depth of the ResNet encoder (18, 34, 50, 101 or 152).
        num_classes (int): Number of output channels of the final 1x1 convolution.
        dropout_2d (float, optional): Channel-dropout probability applied to the deepest
            encoder feature map, only while the module is in training mode. Defaults to 0.0.
        pretrained (bool, optional): Forwarded to ``ResNetEncoders``.
            False - no pre-trained weights are being used.
            True  - ResNet encoder is pre-trained on ImageNet.
            Defaults to False.
        use_hypercolumn (bool, optional): If True, the outputs of all five decoder stages are
            resized to ``dec1``'s spatial size and concatenated before the final head.
            Defaults to False.
        pool0 (bool, optional): Forwarded unchanged to ``ResNetEncoders``; see that class for
            its meaning. Defaults to False.

    Raises:
        NotImplementedError: If ``encoder_depth`` is not one of the supported depths.
    """

    def __init__(self, encoder_depth, num_classes, dropout_2d=0.0, pretrained=False, use_hypercolumn=False,
                 pool0=False):
        super().__init__()
        # Validate the depth before building the (possibly pretrained) encoder so an
        # unsupported value fails fast.
        if encoder_depth in (18, 34):
            bottom_channel_nr = 512
        elif encoder_depth in (50, 101, 152):
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 18, 34, 50, 101, 152 version of Resnet are implemented')

        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.use_hypercolumn = use_hypercolumn

        self.encoders = ResNetEncoders(encoder_depth, pretrained=pretrained, pool0=pool0)

        self.center = nn.Sequential(Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr),
                                    Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr // 2),
                                    nn.AvgPool2d(kernel_size=2, stride=2)
                                    )

        # DecoderBlock channel arguments appear to be (in, middle, out), inferred from how
        # each stage's third argument feeds the next stage's input; confirm against .base.
        self.dec5 = DecoderBlock(bottom_channel_nr + bottom_channel_nr // 2,
                                 bottom_channel_nr,
                                 bottom_channel_nr // 8)

        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 2,
                                 bottom_channel_nr // 8)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 4,
                                 bottom_channel_nr // 8)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 8,
                                 bottom_channel_nr // 8)
        self.dec1 = DecoderBlock(bottom_channel_nr // 8,
                                 bottom_channel_nr // 16,
                                 bottom_channel_nr // 8)

        # With hypercolumns the head receives the five decoder outputs stacked along the
        # channel axis, each with bottom_channel_nr // 8 channels.
        head_channels = bottom_channel_nr // 8
        final_in = 5 * head_channels if use_hypercolumn else head_channels
        self.final = nn.Sequential(Conv2dBnRelu(final_in, head_channels),
                                   nn.Conv2d(head_channels, num_classes, kernel_size=1, padding=0))

    def forward(self, x):
        """Run the encoder, decoder and segmentation head on ``x``.

        Args:
            x (torch.Tensor): Input batch, passed straight to ``self.encoders``.

        Returns:
            torch.Tensor: Output of ``self.final`` (``num_classes`` channels).
        """
        encoder2, encoder3, encoder4, encoder5 = self.encoders(x)
        # F.dropout2d defaults to training=True; tie it to the module mode so that
        # channels are not also dropped at inference time.
        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d, training=self.training)

        center = self.center(encoder5)

        # Each stage receives the previous stage's output plus an encoder feature map
        # (dec1 receives only dec2).
        dec5 = self.dec5(center, encoder5)
        dec4 = self.dec4(dec5, encoder4)
        dec3 = self.dec3(dec4, encoder3)
        dec2 = self.dec2(dec3, encoder2)
        dec1 = self.dec1(dec2)

        if self.use_hypercolumn:
            # Resize to dec1's exact spatial size (instead of fixed scale factors) so the
            # concatenation cannot fail on off-by-one sizes.
            size = dec1.shape[2:]
            dec1 = torch.cat([dec1] + [F.interpolate(feat, size=size, mode='bilinear', align_corners=False)
                                       for feat in (dec2, dec3, dec4, dec5)], 1)

        return self.final(dec1)


class UNetSeResNet(nn.Module):
    """PyTorch U-Net model using an SE-ResNet encoder.

    The decoder assumes the deepest encoder feature map has 2048 channels.

    Args:
        encoder_depth (int): Forwarded to ``SeResNetEncoders``.
        num_classes (int): Number of output channels of the final 1x1 convolution.
        dropout_2d (float, optional): Channel-dropout probability applied to the deepest
            encoder feature map, only while the module is in training mode. Defaults to 0.0.
        pretrained (bool, optional): Forwarded to ``SeResNetEncoders``. Defaults to False.
        use_hypercolumn (bool, optional): If True, the outputs of all five decoder stages are
            resized to ``dec1``'s spatial size and concatenated before the final head.
            Defaults to False.
        pool0 (bool, optional): Forwarded unchanged to ``SeResNetEncoders``. Defaults to False.
    """

    def __init__(self, encoder_depth, num_classes, dropout_2d=0.0, pretrained=False, use_hypercolumn=False,
                 pool0=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.use_hypercolumn = use_hypercolumn

        self.encoders = SeResNetEncoders(encoder_depth, pretrained=pretrained, pool0=pool0)
        bottom_channel_nr = 2048

        self.center = nn.Sequential(Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr),
                                    Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr // 2),
                                    nn.AvgPool2d(kernel_size=2, stride=2)
                                    )

        # DecoderBlock channel arguments appear to be (in, middle, out), inferred from how
        # each stage's third argument feeds the next stage's input; confirm against .base.
        self.dec5 = DecoderBlock(bottom_channel_nr + bottom_channel_nr // 2,
                                 bottom_channel_nr,
                                 bottom_channel_nr // 8)

        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 2,
                                 bottom_channel_nr // 8)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 4,
                                 bottom_channel_nr // 8)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 8,
                                 bottom_channel_nr // 8)
        self.dec1 = DecoderBlock(bottom_channel_nr // 8,
                                 bottom_channel_nr // 16,
                                 bottom_channel_nr // 8)

        # With hypercolumns the head receives the five decoder outputs stacked along the
        # channel axis, each with bottom_channel_nr // 8 channels.
        head_channels = bottom_channel_nr // 8
        final_in = 5 * head_channels if use_hypercolumn else head_channels
        self.final = nn.Sequential(Conv2dBnRelu(final_in, head_channels),
                                   nn.Conv2d(head_channels, num_classes, kernel_size=1, padding=0))

    def forward(self, x):
        """Run the encoder, decoder and segmentation head on ``x``.

        Args:
            x (torch.Tensor): Input batch, passed straight to ``self.encoders``.

        Returns:
            torch.Tensor: Output of ``self.final`` (``num_classes`` channels).
        """
        encoder2, encoder3, encoder4, encoder5 = self.encoders(x)
        # F.dropout2d defaults to training=True; tie it to the module mode so that
        # channels are not also dropped at inference time.
        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d, training=self.training)

        center = self.center(encoder5)

        # Each stage receives the previous stage's output plus an encoder feature map
        # (dec1 receives only dec2).
        dec5 = self.dec5(center, encoder5)
        dec4 = self.dec4(dec5, encoder4)
        dec3 = self.dec3(dec4, encoder3)
        dec2 = self.dec2(dec3, encoder2)
        dec1 = self.dec1(dec2)

        if self.use_hypercolumn:
            # Resize to dec1's exact spatial size (instead of fixed scale factors) so the
            # concatenation cannot fail on off-by-one sizes.
            size = dec1.shape[2:]
            dec1 = torch.cat([dec1] + [F.interpolate(feat, size=size, mode='bilinear', align_corners=False)
                                       for feat in (dec2, dec3, dec4, dec5)], 1)

        return self.final(dec1)


class UNetSeResNetXt(nn.Module):
    """PyTorch U-Net model using an SE-ResNeXt encoder.

    The decoder assumes the deepest encoder feature map has 2048 channels.

    Args:
        encoder_depth (int): Forwarded to ``SeResNetXtEncoders``.
        num_classes (int): Number of output channels of the final 1x1 convolution.
        dropout_2d (float, optional): Channel-dropout probability applied to the deepest
            encoder feature map, only while the module is in training mode. Defaults to 0.0.
        pretrained (bool, optional): Forwarded to ``SeResNetXtEncoders``. Defaults to False.
        use_hypercolumn (bool, optional): If True, the outputs of all five decoder stages are
            resized to ``dec1``'s spatial size and concatenated before the final head.
            Defaults to False.
        pool0 (bool, optional): Forwarded unchanged to ``SeResNetXtEncoders``. Defaults to False.
    """

    def __init__(self, encoder_depth, num_classes, dropout_2d=0.0, pretrained=False, use_hypercolumn=False,
                 pool0=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.use_hypercolumn = use_hypercolumn

        self.encoders = SeResNetXtEncoders(encoder_depth, pretrained=pretrained, pool0=pool0)
        bottom_channel_nr = 2048

        self.center = nn.Sequential(Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr),
                                    Conv2dBnRelu(bottom_channel_nr, bottom_channel_nr // 2),
                                    nn.AvgPool2d(kernel_size=2, stride=2)
                                    )

        # DecoderBlock channel arguments appear to be (in, middle, out), inferred from how
        # each stage's third argument feeds the next stage's input; confirm against .base.
        self.dec5 = DecoderBlock(bottom_channel_nr + bottom_channel_nr // 2,
                                 bottom_channel_nr,
                                 bottom_channel_nr // 8)

        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 2,
                                 bottom_channel_nr // 8)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 4,
                                 bottom_channel_nr // 8)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + bottom_channel_nr // 8,
                                 bottom_channel_nr // 8,
                                 bottom_channel_nr // 8)
        self.dec1 = DecoderBlock(bottom_channel_nr // 8,
                                 bottom_channel_nr // 16,
                                 bottom_channel_nr // 8)

        # With hypercolumns the head receives the five decoder outputs stacked along the
        # channel axis, each with bottom_channel_nr // 8 channels.
        head_channels = bottom_channel_nr // 8
        final_in = 5 * head_channels if use_hypercolumn else head_channels
        self.final = nn.Sequential(Conv2dBnRelu(final_in, head_channels),
                                   nn.Conv2d(head_channels, num_classes, kernel_size=1, padding=0))

    def forward(self, x):
        """Run the encoder, decoder and segmentation head on ``x``.

        Args:
            x (torch.Tensor): Input batch, passed straight to ``self.encoders``.

        Returns:
            torch.Tensor: Output of ``self.final`` (``num_classes`` channels).
        """
        encoder2, encoder3, encoder4, encoder5 = self.encoders(x)
        # F.dropout2d defaults to training=True; tie it to the module mode so that
        # channels are not also dropped at inference time.
        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d, training=self.training)

        center = self.center(encoder5)

        # Each stage receives the previous stage's output plus an encoder feature map
        # (dec1 receives only dec2).
        dec5 = self.dec5(center, encoder5)
        dec4 = self.dec4(dec5, encoder4)
        dec3 = self.dec3(dec4, encoder3)
        dec2 = self.dec2(dec3, encoder2)
        dec1 = self.dec1(dec2)

        if self.use_hypercolumn:
            # Resize to dec1's exact spatial size (instead of fixed scale factors) so the
            # concatenation cannot fail on off-by-one sizes.
            size = dec1.shape[2:]
            dec1 = torch.cat([dec1] + [F.interpolate(feat, size=size, mode='bilinear', align_corners=False)
                                       for feat in (dec2, dec3, dec4, dec5)], 1)

        return self.final(dec1)


class UNetDenseNet(nn.Module):
    """PyTorch U-Net model using a DenseNet encoder.

    Args:
        encoder_depth (int): DenseNet depth; one of 121, 161, 169, 201.
        num_classes (int): Number of output channels of the final 1x1 convolution.
        dropout_2d (float, optional): Channel-dropout probability applied to the deepest
            encoder feature map, only while the module is in training mode. Defaults to 0.0.
        pretrained (bool, optional): Forwarded to ``DenseNetEncoders``. Defaults to False.
        use_hypercolumn (bool, optional): If True, the outputs of all five decoder stages are
            resized to ``dec1``'s spatial size and concatenated before the final head.
            Defaults to False.
        pool0 (bool, optional): Forwarded unchanged to ``DenseNetEncoders``. Defaults to False.

    Raises:
        NotImplementedError: If ``encoder_depth`` is not one of the supported depths.
    """

    # Channel counts the decoder wiring expects for the four encoder outputs
    # (encoder2, encoder3, encoder4, encoder5) per depth; presumably these must match
    # what DenseNetEncoders returns — confirm against .encoders.
    _ENCODER_CHANNELS = {
        121: (256, 512, 1024, 1024),
        161: (384, 768, 2112, 2208),
        169: (256, 512, 1280, 1664),
        201: (256, 512, 1792, 1920),
    }

    def __init__(self, encoder_depth, num_classes, dropout_2d=0.0, pretrained=False, use_hypercolumn=False,
                 pool0=False):
        super().__init__()
        # Validate the depth before building the (possibly pretrained) encoder so an
        # unsupported value fails fast.
        encoder_channel_nr = self._ENCODER_CHANNELS.get(encoder_depth)
        if encoder_channel_nr is None:
            raise NotImplementedError('only 121, 161, 169, 201 version of Densenet are implemented')
        enc2_ch, enc3_ch, enc4_ch, enc5_ch = encoder_channel_nr
        # Every decoder stage outputs this many channels.
        dec_ch = enc5_ch // 8

        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.use_hypercolumn = use_hypercolumn

        self.encoders = DenseNetEncoders(encoder_depth, pretrained=pretrained, pool0=pool0)

        self.center = nn.Sequential(Conv2dBnRelu(enc5_ch, enc5_ch),
                                    Conv2dBnRelu(enc5_ch, enc4_ch),
                                    nn.AvgPool2d(kernel_size=2, stride=2)
                                    )

        # DecoderBlock channel arguments appear to be (in, middle, out); each input width is
        # the skip tensor's channels plus the previous stage's output — confirm against .base.
        self.dec5 = DecoderBlock(enc5_ch + enc4_ch, enc5_ch, dec_ch)
        self.dec4 = DecoderBlock(enc4_ch + dec_ch, enc5_ch // 2, dec_ch)
        self.dec3 = DecoderBlock(enc3_ch + dec_ch, enc5_ch // 4, dec_ch)
        self.dec2 = DecoderBlock(enc2_ch + dec_ch, enc5_ch // 8, dec_ch)
        self.dec1 = DecoderBlock(dec_ch, enc5_ch // 16, dec_ch)

        # With hypercolumns the head receives the five decoder outputs stacked along the
        # channel axis, each with dec_ch channels.
        final_in = 5 * dec_ch if use_hypercolumn else dec_ch
        self.final = nn.Sequential(Conv2dBnRelu(final_in, dec_ch),
                                   nn.Conv2d(dec_ch, num_classes, kernel_size=1, padding=0))

    def forward(self, x):
        """Run the encoder, decoder and segmentation head on ``x``.

        Args:
            x (torch.Tensor): Input batch, passed straight to ``self.encoders``.

        Returns:
            torch.Tensor: Output of ``self.final`` (``num_classes`` channels).
        """
        encoder2, encoder3, encoder4, encoder5 = self.encoders(x)
        # F.dropout2d defaults to training=True; tie it to the module mode so that
        # channels are not also dropped at inference time.
        encoder5 = F.dropout2d(encoder5, p=self.dropout_2d, training=self.training)

        center = self.center(encoder5)

        # Each stage receives the previous stage's output plus an encoder feature map
        # (dec1 receives only dec2).
        dec5 = self.dec5(center, encoder5)
        dec4 = self.dec4(dec5, encoder4)
        dec3 = self.dec3(dec4, encoder3)
        dec2 = self.dec2(dec3, encoder2)
        dec1 = self.dec1(dec2)

        if self.use_hypercolumn:
            # Resize to dec1's exact spatial size (instead of fixed scale factors) so the
            # concatenation cannot fail on off-by-one sizes.
            size = dec1.shape[2:]
            dec1 = torch.cat([dec1] + [F.interpolate(feat, size=size, mode='bilinear', align_corners=False)
                                       for feat in (dec2, dec3, dec4, dec5)], 1)

        return self.final(dec1)