"""
DCGAN discriminator model
based on the paper: https://arxiv.org/pdf/1511.06434.pdf
date: 30 April 2018
"""
import torch
import torch.nn as nn

import json
from easydict import EasyDict as edict
from graphs.weights_initializer import weights_init


class Discriminator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.relu = nn.LeakyReLU(self.config.relu_slope, inplace=True)

        self.conv1 = nn.Conv2d(in_channels=self.config.input_channels, out_channels=self.config.num_filt_d, kernel_size=4, stride=2, padding=1, bias=False)

        self.conv2 = nn.Conv2d(in_channels=self.config.num_filt_d, out_channels=self.config.num_filt_d * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(self.config.num_filt_d*2)

        self.conv3 = nn.Conv2d(in_channels=self.config.num_filt_d*2, out_channels=self.config.num_filt_d * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(self.config.num_filt_d*4)

        self.conv4 = nn.Conv2d(in_channels=self.config.num_filt_d*4, out_channels=self.config.num_filt_d*8, kernel_size=4, stride=2, padding=1, bias=False)
        self.batch_norm3 = nn.BatchNorm2d(self.config.num_filt_d*8)

        self.conv5 = nn.Conv2d(in_channels=self.config.num_filt_d*8, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False)

        self.out = nn.Sigmoid()

        self.apply(weights_init)

    def forward(self, x):
        out = self.conv1(x)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.batch_norm1(out)
        out =  self.relu(out)

        out = self.conv3(out)
        out = self.batch_norm2(out)
        out =  self.relu(out)

        out = self.conv4(out)
        out = self.batch_norm3(out)
        out =  self.relu(out)

        out = self.conv5(out)
        out = self.out(out)

        return out.view(-1, 1).squeeze(1)


"""
netD testing
"""
def main():
    config = json.load(open('../../configs/dcgan_exp_0.json'))
    config = edict(config)
    inp  = torch.autograd.Variable(torch.randn(config.batch_size, config.input_channels, config.image_size, config.image_size))
    print (inp.shape)
    netD = Discriminator(config)
    out = netD(inp)
    print (out)

if __name__ == '__main__':
    main()