import torch

from torch import nn
from torch.nn import functional as F


class GatedConv2d(nn.Module):
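    """Conv2d followed by a gated linear unit (GLU): the convolution
    produces 2 * out_channels channels, and F.glu gates one half with the
    sigmoid of the other, returning out_channels channels.
    """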
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation=1):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, 2 * out_channels, kernel_size,
                              stride, padding, dilation=dilation)

    def forward(self, inputs):
        temps = self.conv(inputs)
        outputs = F.glu(temps, dim=1)
        return outputs


class GatedConvTranspose2d(nn.Module):
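    """ConvTranspose2d followed by a GLU gate; the transposed-convolution
    counterpart of GatedConv2d for upsampling paths.
    """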
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, output_padding=0, dilation=1):
        super().__init__()

        self.conv_transpose = nn.ConvTranspose2d(in_channels, 2 * out_channels,
                                                 kernel_size, stride, padding,
                                                 output_padding, dilation=dilation)

    def forward(self, inputs):
        temps = self.conv_transpose(inputs)
        outputs = F.glu(temps, dim=1)
        return outputs


class SylvesterFlowConvEncoderNet(nn.Module):
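    """Gated convolutional encoder in the style of the Sylvester normalizing
    flows VAE (van den Berg et al., 2018). With the default (7, 7) last
    kernel it maps 1x28x28 inputs (e.g. MNIST) to a flat vector of
    `context_features` features.
    """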
    def __init__(self, context_features, last_kernel_shape=(7, 7)):
        super().__init__()
        self.context_features = context_features
        self.last_kernel_shape = last_kernel_shape

        self.gated_conv_layers = nn.ModuleList([
            GatedConv2d(  # 1: 1x28x28 -> 32x28x28 (shapes assume 28x28 inputs)
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConv2d(  # 2: 32x28x28 -> 32x14x14
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=2
            ),
            GatedConv2d(  # 3: 32x14x14 -> 64x14x14
                in_channels=32,
                out_channels=64,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConv2d(  # 4: 64x14x14 -> 64x7x7
                in_channels=64,
                out_channels=64,
                kernel_size=5,
                padding=2,
                stride=2
            ),
            GatedConv2d(  # 5: 64x7x7 -> 64x7x7
                in_channels=64,
                out_channels=64,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConv2d(  # 6: 64x7x7 -> 64x7x7
                in_channels=64,
                out_channels=64,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConv2d(  # 7: 64x7x7 -> 256x1x1 (kernel spans the full 7x7 map)
                in_channels=64,
                out_channels=256,
                kernel_size=self.last_kernel_shape,
                padding=0,
                stride=1
            )
        ])

        self.context_layer = nn.Linear(
            in_features=256,
            out_features=self.context_features
        )

    def forward(self, inputs):
        batch_size = inputs.shape[0]
        temps = inputs
        for gated_conv in self.gated_conv_layers:
            temps = gated_conv(temps)
        # Flatten the 256x1x1 feature maps and project to the context vector.
        outputs = self.context_layer(temps.reshape(batch_size, -1))
        return outputs


class SylvesterFlowConvDecoderNet(nn.Module):
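    """Gated transposed-convolutional decoder mirroring
    SylvesterFlowConvEncoderNet: each latent vector is reshaped to 1x1
    feature maps and, with the default (7, 7) first kernel, upsampled back
    to a 1x28x28 image.
    """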
    def __init__(self, latent_features, last_kernel_shape=(7, 7)):
        super().__init__()
        self.latent_features = latent_features
        self.last_kernel_shape = last_kernel_shape

        self.gated_conv_transpose_layers = nn.ModuleList([
            GatedConvTranspose2d(  # 1: latent_features x 1x1 -> 64x7x7
                in_channels=self.latent_features,
                out_channels=64,
                kernel_size=self.last_kernel_shape,
                padding=0,
                stride=1
            ),
            GatedConvTranspose2d(  # 2: 64x7x7 -> 64x7x7
                in_channels=64,
                out_channels=64,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConvTranspose2d(  # 3: 64x7x7 -> 32x14x14
                in_channels=64,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=2,
                output_padding=1
            ),
            GatedConvTranspose2d(  # 4: 32x14x14 -> 32x14x14
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConvTranspose2d(  # 5: 32x14x14 -> 32x28x28
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=2,
                output_padding=1
            ),
            GatedConv2d(  # 6: 32x28x28 -> 32x28x28
                in_channels=32,
                out_channels=32,
                kernel_size=5,
                padding=2,
                stride=1
            ),
            GatedConv2d(  # 7: 32x28x28 -> 1x28x28 (1x1 output convolution)
                in_channels=32,
                out_channels=1,
                kernel_size=1,
                padding=0,
                stride=1
            )
        ])

    def forward(self, inputs):
        # Treat each latent vector as a stack of 1x1 feature maps.
        temps = inputs[..., None, None]
        for gated_conv_transpose in self.gated_conv_transpose_layers:
            temps = gated_conv_transpose(temps)
        return temps


class ResidualBlock(nn.Module):
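    """Pre-activation residual block with optional resampling.

    `resample` is None (shape-preserving), 'down' (stride-2 convolutions
    that halve the spatial size and double the channels), or 'up' (stride-2
    transposed convolutions that double the spatial size and halve the
    channels). `first=True` drops the output padding of an 'up' block so
    that a 4x4 map grows to 7x7 rather than 8x8.
    """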
    def __init__(self, in_channels, resample=None, activation=F.relu,
                 dropout_probability=0., first=False):
        super().__init__()
        self.in_channels = in_channels
        self.resample = resample
        self.activation = activation

        self.residual_layer_1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            padding=1
        )

        if resample is None:
            self.shortcut_layer = nn.Identity()
            self.residual_layer_2 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=3,
                padding=1
            )
        elif resample == 'down':
            self.shortcut_layer = nn.Conv2d(
                in_channels=in_channels,
                out_channels=2 * in_channels,
                kernel_size=3,
                stride=2,
                padding=1
            )
            self.residual_layer_2 = nn.Conv2d(
                in_channels=in_channels,
                out_channels=2 * in_channels,
                kernel_size=3,
                stride=2,
                padding=1
            )
        elif resample == 'up':
            self.shortcut_layer = nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=0 if first else 1
            )
            self.residual_layer_2 = nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=in_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=0 if first else 1
            )
        else:
            raise ValueError(f'Invalid resample value: {resample}')

        if dropout_probability > 0:
            self.dropout = nn.Dropout(dropout_probability)
        else:
            self.dropout = None

    def forward(self, inputs):
        shortcut = self.shortcut_layer(inputs)
        residual_1 = self.activation(inputs)
        residual_1 = self.residual_layer_1(residual_1)
        if self.dropout is not None:
            residual_1 = self.dropout(residual_1)
        residual_2 = self.activation(residual_1)
        residual_2 = self.residual_layer_2(residual_2)

        return shortcut + residual_2


class ConvEncoder(nn.Module):
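    """Residual convolutional encoder: a 1x1 lift to channels_multiplier
    channels, six residual blocks (three of them downsampling), and a
    linear projection to context_features. The final linear layer assumes
    28x28 inputs, which the three stride-2 blocks reduce to 4x4.
    """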
    def __init__(self, context_features, channels_multiplier,
                 activation=F.relu, dropout_probability=0.):
        super().__init__()
        self.context_features = context_features
        self.channels_multiplier = channels_multiplier
        self.activation = activation

        self.initial_layer = nn.Conv2d(1, channels_multiplier, kernel_size=1)
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(in_channels=channels_multiplier,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier, resample='down',
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 2,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 2, resample='down',
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 4,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 4, resample='down',
                          dropout_probability=dropout_probability)
        ])
        self.final_layer = nn.Linear(
            in_features=(4 * 4 * channels_multiplier * 8),
            out_features=context_features
        )

    def forward(self, inputs):
        temps = self.initial_layer(inputs)
        for residual_block in self.residual_blocks:
            temps = residual_block(temps)
        temps = self.activation(temps)
        outputs = self.final_layer(
            temps.reshape(-1, 4 * 4 * self.channels_multiplier * 8))
        return outputs


class ConvDecoder(nn.Module):
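    """Residual convolutional decoder mirroring ConvEncoder: a linear lift
    to a 4x4 feature map, six residual blocks (three of them upsampling
    back to 28x28 for MNIST-sized targets), and a 1x1 convolution to a
    single output channel.
    """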
    def __init__(self, latent_features, channels_multiplier,
                 activation=F.relu, dropout_probability=0.):
        super().__init__()
        self.latent_features = latent_features
        self.channels_multiplier = channels_multiplier
        self.activation = activation

        self.initial_layer = nn.Linear(
            in_features=latent_features,
            out_features=(4 * 4 * channels_multiplier * 8)
        )
        self.residual_blocks = nn.ModuleList([
            ResidualBlock(in_channels=channels_multiplier * 8,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 8, resample='up', first=True,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 4,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 4, resample='up',
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 2,
                          dropout_probability=dropout_probability),
            ResidualBlock(in_channels=channels_multiplier * 2, resample='up',
                          dropout_probability=dropout_probability)
        ])
        self.final_layer = nn.Conv2d(
            in_channels=channels_multiplier,
            out_channels=1,
            kernel_size=1
        )

    def forward(self, inputs):
        temps = self.initial_layer(inputs).reshape(
            -1, self.channels_multiplier * 8, 4, 4
        )
        for residual_block in self.residual_blocks:
            temps = residual_block(temps)
        temps = self.activation(temps)
        outputs = self.final_layer(temps)
        return outputs


def main():
    # Smoke test on MNIST-sized inputs; PyTorch convolutions expect NCHW.
    batch_size, channels, height, width = 16, 1, 28, 28
    inputs = torch.rand(batch_size, channels, height, width)

    encoder = ConvEncoder(context_features=24, channels_multiplier=16)
    context = encoder(inputs)
    assert context.shape == (batch_size, 24)

    decoder = ConvDecoder(latent_features=24, channels_multiplier=16)
    reconstruction = decoder(context)
    assert reconstruction.shape == inputs.shape
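
    # Also exercise the Sylvester-flow nets; their default (7, 7) last
    # kernel likewise assumes 28x28 inputs.
    encoder = SylvesterFlowConvEncoderNet(context_features=24)
    context = encoder(inputs)
    assert context.shape == (batch_size, 24)

    decoder = SylvesterFlowConvDecoderNet(latent_features=24)
    reconstruction = decoder(context)
    assert reconstruction.shape == inputs.shape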


if __name__ == '__main__':
    main()