import torch

from unittest import TestCase
from vdb import Gan_networks as gns

# Device every module and mock tensor in this file is moved to:
# the GPU when CUDA is available, otherwise the CPU.
device = (torch.device("cuda") if torch.cuda.is_available()
          else torch.device("cpu"))


class TestResnetBlock(TestCase):
    """Forward-pass smoke tests for ``gns.ResnetBlock``.

    Two blocks are built: one that changes the channel count (21 -> 79)
    and one that keeps it (69 -> 69), so both shortcut variants run.
    """

    def setUp(self):
        # fin != fout: in this block the shortcut is learnable
        self.resBlock_1 = gns.ResnetBlock(fin=21, fout=79).to(device)
        # fin == fout: in this block the shortcut is not learnable
        self.resBlock_2 = gns.ResnetBlock(fin=69, fout=69).to(device)

        # dump both architectures to stdout
        print("\nResblock 1:\n%s" % self.resBlock_1)
        print("\nResblock 2:\n%s" % self.resBlock_2)

    def test_forward(self):
        # (block, input channels, expected output channels)
        cases = (
            (self.resBlock_1, 21, 79),
            (self.resBlock_2, 69, 69),
        )
        for block, c_in, c_out in cases:
            out = block(torch.randn(32, c_in, 16, 16).to(device))
            # batch and spatial size are preserved, only channels change
            self.assertEqual(out.shape, (32, c_out, 16, 16))
            self.assertEqual(torch.isnan(out).sum().item(), 0)
            self.assertEqual(torch.isinf(out).sum().item(), 0)

    def tearDown(self):
        # release the references to both blocks between tests
        del self.resBlock_1
        del self.resBlock_2


class TestDiscriminator(TestCase):
    """Forward-pass smoke tests for ``gns.Discriminator``.

    Two configurations are exercised: the smallest one (``size=4``) and a
    realistic one (``size=256`` with 64 -> 512 filters). Each is run with
    both values of the positional ``mean_mode`` argument.
    """

    def setUp(self):
        # edge case discriminator:
        self.dis_edge = gns.Discriminator(size=4).to(device)

        # normal case discriminator:
        self.dis = gns.Discriminator(size=256,
                                     num_filters=64,
                                     max_filters=512).to(device)

        # print some information:
        print("\nDiscriminator 1:\n%s" % str(self.dis_edge))
        print("\nDiscriminator 2:\n%s" % str(self.dis))

    def _assert_finite(self, tensor):
        """Fail if *tensor* contains any NaN or +/-inf entry."""
        self.assertEqual(torch.isnan(tensor).sum().item(), 0)
        self.assertEqual(torch.isinf(tensor).sum().item(), 0)

    def _check_forward(self, dis, batch_size, size, feat_dim):
        """Run *dis* on one random image batch in both mean modes.

        :param dis: discriminator under test
        :param batch_size: number of images in the mock batch
        :param size: spatial resolution (height == width) of the images
        :param feat_dim: expected width of the second and third outputs
        """
        mock_in = torch.randn(batch_size, 3, size, size, device=device)
        for mean_mode in (True, False):
            # Shape/value checks need no autograd graph; skipping it keeps
            # memory low, especially for the 16 x 3 x 256 x 256 batch.
            with torch.no_grad():
                out1, out2, out3 = dis(mock_in, mean_mode)

            # check the shapes of all the three:
            self.assertEqual(out1.shape, (batch_size, 1))
            self.assertEqual(out2.shape, (batch_size, feat_dim))
            self.assertEqual(out3.shape, (batch_size, feat_dim))
            # the third output must be non-negative
            self.assertGreaterEqual(out3.min().item(), 0)
            # every output must be free of NaN / inf, not only the first
            for out in (out1, out2, out3):
                self._assert_finite(out)

    def test_forward(self):
        # edge case: 4 x 4 images, 32-wide second/third outputs
        self._check_forward(self.dis_edge, batch_size=3, size=4, feat_dim=32)
        # normal case: 256 x 256 images, 256-wide second/third outputs
        self._check_forward(self.dis, batch_size=16, size=256, feat_dim=256)

    def tearDown(self):
        # delete all the computational blocks
        del self.dis_edge, self.dis


class TestGenerator(TestCase):
    """Forward-pass smoke tests for ``gns.Generator``.

    Two configurations are exercised: the smallest one (``size=4`` from a
    128-d latent) and a realistic one (``size=256`` from an 8-d latent,
    64 -> 512 channels).
    """

    def setUp(self):
        # edge case generator:
        self.gen_edge = gns.Generator(z_dim=128, size=4).to(device)

        # normal case generator:
        self.gen = gns.Generator(z_dim=8, size=256,
                                 final_channels=64,
                                 max_channels=512).to(device)

        # print some information:
        print("\nGenerator 1:\n%s" % str(self.gen_edge))
        print("\nGenerator 2:\n%s" % str(self.gen))

    def _check_forward(self, gen, batch_size, z_dim, size):
        """Feed random latents to *gen* and check the generated images.

        :param gen: generator under test
        :param batch_size: number of latent vectors in the mock batch
        :param z_dim: latent size; must match the ``z_dim`` used in setUp
        :param size: expected spatial resolution of the 3-channel output
        """
        mock_in = torch.randn(batch_size, z_dim, device=device)
        # Shape/value checks need no autograd graph; skipping it keeps
        # memory low for the 256 x 256 generator.
        with torch.no_grad():
            mock_out = gen(mock_in)

        # the generator has a single output: a batch of 3-channel images
        self.assertEqual(mock_out.shape, (batch_size, 3, size, size))
        self.assertEqual(torch.isnan(mock_out).sum().item(), 0)
        self.assertEqual(torch.isinf(mock_out).sum().item(), 0)

    def test_forward(self):
        # test the edge generator: 128-d latent -> 4 x 4 images
        self._check_forward(self.gen_edge, batch_size=3, z_dim=128, size=4)
        # test the normal generator: 8-d latent -> 256 x 256 images
        self._check_forward(self.gen, batch_size=16, z_dim=8, size=256)

    def tearDown(self):
        # delete all the computational blocks
        del self.gen_edge, self.gen