#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import unittest
from typing import Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchdp import PrivacyEngine, utils
from torchvision import models, transforms
from torchvision.datasets import FakeData


class utils_replace_all_modules_test(unittest.TestCase):
    def checkModulePresent(self, root: nn.Module, targetclass):
        result = False
        for module in root.modules():
            result |= isinstance(module, targetclass)
        self.assertTrue(result)

    def checkModuleNotPresent(self, root: nn.Module, targetclass):
        for module in root.modules():
            self.assertFalse(
                isinstance(module, targetclass),
                msg=f"{module} has the given targetclass type",
            )

    def test_replace_basic_case(self):
        model = nn.BatchNorm1d(10)
        model = utils.replace_all_modules(
            model, nn.BatchNorm1d, lambda _: nn.BatchNorm2d(10)
        )
        self.checkModulePresent(model, nn.BatchNorm2d)
        self.checkModuleNotPresent(model, nn.BatchNorm1d)

    def test_replace_sequential_case(self):
        model = nn.Sequential(nn.Conv1d(1, 2, 3), nn.Sequential(nn.Conv2d(4, 5, 6)))

        def conv(m: nn.Conv2d):
            return nn.Linear(4, 5)

        model = utils.replace_all_modules(model, nn.Conv2d, conv)
        self.checkModulePresent(model, nn.Linear)
        self.checkModuleNotPresent(model, nn.Conv2d)

    def test_nullify_resnet18(self):
        model = models.resnet18()
        # check module BatchNorms is there
        self.checkModulePresent(model, nn.BatchNorm2d)
        # nullify the module (replace with Idetity)
        model = utils.nullify_batchnorm_modules(model, nn.BatchNorm2d)
        # check module is not present
        self.checkModuleNotPresent(model, nn.BatchNorm2d)

    def test_convert_batchnorm_modules_resnet50(self):
        model = models.resnet50()
        # check module BatchNorms is there
        self.checkModulePresent(model, nn.BatchNorm2d)
        # replace the module with instancenorm
        model = utils.convert_batchnorm_modules(model)
        # check module is not present
        self.checkModuleNotPresent(model, nn.BatchNorm2d)
        self.checkModulePresent(model, nn.GroupNorm)


class BasicModel(nn.Module):
    def __init__(self, imgSize):
        super().__init__()
        self.size = imgSize[0] * imgSize[1] * imgSize[2]
        self.bn = nn.BatchNorm2d(imgSize[0])
        self.fc = nn.Linear(self.size, 2)

    def forward(self, input):
        x = self.bn(input)
        x = x.view(-1, self.size)
        x = self.fc(x)
        return x


class utils_convert_batchnorm_modules_test(unittest.TestCase):
    def setUp(self):
        self.criterion = nn.CrossEntropyLoss()

    def setUpOptimizer(
        self, model: nn.Module, data_loader: DataLoader, privacy_engine: bool = False
    ):
        # sample parameter values
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
        optimizer.zero_grad()
        if privacy_engine:
            pe = PrivacyEngine(
                model,
                batch_size=data_loader.batch_size,
                sample_size=len(data_loader.dataset),
                alphas=[1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64)),
                noise_multiplier=1.3,
                max_grad_norm=1,
            )
            pe.attach(optimizer)
        return optimizer

    def genFakeData(
        self, imgSize: Tuple[int, int, int], batch_size: int = 1, num_batches: int = 1
    ) -> DataLoader:
        self.ds = FakeData(
            size=num_batches,
            image_size=imgSize,
            num_classes=2,
            transform=transforms.Compose([transforms.ToTensor()]),
        )
        return DataLoader(self.ds, batch_size=batch_size)

    def runOneBatch(
        self,
        model: nn.Module,
        imgsize: Tuple[int, int, int],
        privacy_engine: bool = True,
    ):
        dl = self.genFakeData(imgsize, 1, 1)
        optimizer = self.setUpOptimizer(model, dl, privacy_engine)
        for x, y in dl:
            # forward
            try:
                logits = model(x)
            except Exception as err:
                self.fail(f"Failed forward step with exception: {err}")
            loss = self.criterion(logits, y)
            # backward
            try:
                loss.backward()
            except Exception as err:
                self.fail(f"Failed backward step with exception: {err}")
            # optimizer
            try:
                optimizer.step()
            except Exception as err:
                self.fail(f"Failed optimizer step with exception: {err}")
        optimizer.zero_grad()

    def test_run_basic_case(self):
        imgSize = (3, 4, 5)
        # should throw because privacy engine does not work with batch norm
        # remove the next two lines when we support batch norm
        with self.assertRaises(Exception):
            self.runOneBatch(BasicModel(imgSize), imgSize)
        self.runOneBatch(
            utils.convert_batchnorm_modules(BasicModel(imgSize)), imgSize)

    def test_run_resnet18(self):
        imgSize = (3, 224, 224)
        # should throw because privacy engine does not work with batch norm
        # remove the next two lines when we support batch norm
        with self.assertRaises(Exception):
            self.runOneBatch(models.resnet18(), imgSize)
        self.runOneBatch(
            utils.convert_batchnorm_modules(models.resnet18()), imgSize)

    def test_run_resnet34(self):
        imgSize = (3, 224, 224)
        # should throw because privacy engine does not work with batch norm
        # remove the next two lines when we support batch norm
        with self.assertRaises(Exception):
            self.runOneBatch(models.resnet34(), imgSize)
        self.runOneBatch(
            utils.convert_batchnorm_modules(models.resnet34()), imgSize)

    def test_run_resnet50(self):
        imgSize = (3, 224, 224)
        # should throw because privacy engine does not work with batch norm
        # remove the next two lines when we support batch norm
        with self.assertRaises(Exception):
            self.runOneBatch(models.resnet50(), imgSize)
        self.runOneBatch(utils.convert_batchnorm_modules(models.resnet50()), imgSize)

    def test_run_resnet101(self):
        imgSize = (3, 224, 224)
        # should throw because privacy engine does not work with batch norm
        # remove the next two lines when we support batch norm
        with self.assertRaises(Exception):
            self.runOneBatch(models.resnet101(), imgSize)
        self.runOneBatch(
            utils.convert_batchnorm_modules(models.resnet101()), imgSize)


class utils_ModelInspector_test(unittest.TestCase):
    def setUp(self):
        def pred_supported(module):
            return isinstance(module, (nn.Conv2d, nn.Linear))

        def pred_not_unsupported(module):
            return not isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d))

        def pred_requires_grad(module):
            requires_grad = True
            for p in module.parameters(recurse=False):
                requires_grad &= p.requires_grad
            return requires_grad

        self.pred_supported = pred_supported
        self.pred_not_unsupported = pred_not_unsupported
        self.pred_mix = lambda m: (not pred_requires_grad(m)) | pred_not_unsupported(m)

    def test_validate_basic(self):
        inspector = utils.ModelInspector(
            'pred', lambda model: isinstance(model, nn.Linear)
        )
        model = nn.Conv1d(1, 1, 1)
        valid = inspector.validate(model)
        self.assertFalse(valid, inspector.violators)

    def test_validate_positive_predicate_valid(self):
        # test when a positive predicate (e.g. supported) returns true
        inspector = utils.ModelInspector('pred', self.pred_supported)
        model = nn.Conv2d(1, 1, 1)
        valid = inspector.validate(model)
        self.assertTrue(valid)
        list_len = len(inspector.violators)
        self.assertEqual(list_len, 0, f'violators = {inspector.violators}')

    def test_validate_positive_predicate_invalid(self):
        # test when a positive predicate (e.g. supported) returns false
        inspector = utils.ModelInspector('pred', self.pred_supported)
        model = nn.Conv1d(1, 1, 1)
        valid = inspector.validate(model)
        self.assertFalse(valid)
        list_len = len(inspector.violators)
        self.assertEqual(list_len, 1, f'violators = {inspector.violators}')

    def test_validate_negative_predicate_ture(self):
        # test when a negative predicate (e.g. not unsupported) returns true
        inspector = utils.ModelInspector('pred1', self.pred_not_unsupported)
        model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.Linear(1, 1))
        valid = inspector.validate(model)
        self.assertTrue(valid)
        list_len = len(inspector.violators)
        self.assertEqual(list_len, 0)

    def test_validate_negative_predicate_False(self):
        # test when a negative predicate (e.g. not unsupported) returns false
        inspector = utils.ModelInspector('pred', self.pred_not_unsupported)
        model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1))
        valid = inspector.validate(model)
        self.assertFalse(valid)
        list_len = len(inspector.violators)
        self.assertEqual(list_len, 1, f'violators = {inspector.violators}')

    def test_validate_mix_predicate(self):
        # check with a mix predicate not requires grad or is not unsupported
        inspector = utils.ModelInspector('pred1', self.pred_mix)
        model = nn.Sequential(nn.Conv2d(1, 1, 1), nn.BatchNorm2d(1))
        for p in model[1].parameters():
            p.requires_grad = False
        valid = inspector.validate(model)
        self.assertTrue(valid)

    def test_check_everything_flag(self):
        # check to see if a model does not containt nn.sequential
        inspector = utils.ModelInspector(
            'pred', lambda model: not isinstance(model, nn.Sequential),
            check_leaf_nodes_only=False
        )
        model = nn.Sequential(nn.Conv1d(1, 1, 1))
        valid = inspector.validate(model)
        self.assertFalse(
            valid, f'violators = {inspector.violators}')

    def test_complicated_case(self):
        def good(x):
            return isinstance(x, (nn.Conv2d, nn.Linear))

        def bad(x):
            return isinstance(x, nn.modules.batchnorm._BatchNorm)

        inspector1 = utils.ModelInspector(
            'good_or_bad', lambda x: good(x) | bad(x))
        inspector2 = utils.ModelInspector(
            'not_bad', lambda x: not bad(x))
        model = models.resnet50()
        valid = inspector1.validate(model)
        self.assertTrue(valid, f'violators = {inspector1.violators}')
        self.assertEqual(
            len(inspector1.violators),
            0,
            f'violators = {inspector1.violators}')
        valid = inspector2.validate(model)
        self.assertFalse(valid, f'violators = {inspector2.violators}')
        self.assertEqual(
            len(inspector2.violators),
            53,
            f'violators = {inspector2.violators}')