# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for higher.optim."""
import copy
import unittest

from parameterized import parameterized

import torch
from torch import nn, optim

import higher

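# Parameter sweep consumed by parameterized.expand below. Each entry is
# (test_name, model_getter, optimizer_class, optional_optimizer_kwargs);
# the model getter is called with the test instance so it can return one of
# the fixtures built in TestOptim.setUp.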
_test_param_sweep = [
    (
        "simple_model_sgd",
        lambda self: self._model,
        optim.SGD,
    ),
    (
        "simple_model_sgd_momentum",
        lambda self: self._model,
        optim.SGD,
        {
            'momentum': .1
        },
    ),
    (
        "simple_model_sgd_nesterov",
        lambda self: self._model,
        optim.SGD,
        {
            'momentum': .1,
            'nesterov': True
        },
    ),
    (
        "simple_model_sgd_weight_decay",
        lambda self: self._model,
        optim.SGD,
        {
            'weight_decay': .1
        },
    ),
    (
        "share_weight_model_sgd",
        lambda self: self._shared_param_model,
        optim.SGD,
    ),
    (
        "share_weight_seq_model_sgd",
        lambda self: self._shared_param_seq_model,
        optim.SGD,
    ),
    (
        "partially_used_model_sgd",
        lambda self: self._partially_used_model,
        optim.SGD,
    ),
    (
        "simple_model_adam",
        lambda self: self._model,
        optim.Adam,
    ),
    (
        "simple_model_adam_weight_decay",
        lambda self: self._model,
        optim.Adam,
        {
            "weight_decay": 0.1
        },
    ),
    (
        "share_weight_model_adam",
        lambda self: self._shared_param_model,
        optim.Adam,
    ),
    (
        "share_weight_seq_model_adam",
        lambda self: self._shared_param_seq_model,
        optim.Adam,
    ),
    (
        "partially_used_model_adam",
        lambda self: self._partially_used_model,
        optim.Adam,
    ),
    (
        "simple_model_adadelta",
        lambda self: self._model,
        optim.Adadelta,
    ),
    (
        "simple_model_adadelta_weight_decay",
        lambda self: self._model,
        optim.Adadelta,
        {
            "weight_decay": 0.1
        },
    ),
    (
        "share_weight_model_adadelta",
        lambda self: self._shared_param_model,
        optim.Adadelta,
    ),
    (
        "share_weight_seq_model_adadelta",
        lambda self: self._shared_param_seq_model,
        optim.Adadelta,
    ),
    (
        "partially_used_model_adadelta",
        lambda self: self._partially_used_model,
        optim.Adadelta,
    ),
    (
        "simple_model_adagrad",
        lambda self: self._model,
        optim.Adagrad,
    ),
    (
        "simple_model_adagrad_weight_decay",
        lambda self: self._model,
        optim.Adagrad,
        {
            "weight_decay": 0.1
        },
    ),
    (
        "simple_model_adagrad_lr_decay",
        lambda self: self._model,
        optim.Adagrad,
        {
            "lr_decay": 0.1
        },
    ),
    (
        "share_weight_model_adagrad",
        lambda self: self._shared_param_model,
        optim.Adagrad,
    ),
    (
        "share_weight_seq_model_adagrad",
        lambda self: self._shared_param_seq_model,
        optim.Adagrad,
    ),
    (
        "partially_used_model_adagrad",
        lambda self: self._partially_used_model,
        optim.Adagrad,
    ),
    (
        "simple_model_adamax",
        lambda self: self._model,
        optim.Adamax,
    ),
    (
        "simple_model_adamax_weight_decay",
        lambda self: self._model,
        optim.Adamax,
        {
            "weight_decay": 0.1
        },
    ),
    (
        "share_weight_model_adamax",
        lambda self: self._shared_param_model,
        optim.Adamax,
    ),
    (
        "share_weight_seq_model_adamax",
        lambda self: self._shared_param_seq_model,
        optim.Adamax,
    ),
    (
        "partially_used_model_adamax",
        lambda self: self._partially_used_model,
        optim.Adamax,
    ),
    (
        "simple_model_asgd",
        lambda self: self._model,
        optim.ASGD,
    ),
    (
        "simple_model_asgd_weight_decay",
        lambda self: self._model,
        optim.ASGD,
        {
            "weight_decay": 0.1
        },
    ),
    (
        "share_weight_model_asgd",
        lambda self: self._shared_param_model,
        optim.ASGD,
    ),
    (
        "share_weight_seq_model_asgd",
        lambda self: self._shared_param_seq_model,
        optim.ASGD,
    ),
    (
        "partially_used_model_asgd",
        lambda self: self._partially_used_model,
        optim.ASGD,
    ),
    # (
    #     "simple_model_rmsprop",
    #     lambda self: self._model,
    #     optim.RMSprop,
    # ),
    # (
    #     "simple_model_rmsprop_momentum",
    #     lambda self: self._model,
    #     optim.RMSprop,
    #     {
    #         "momentum": 0.1
    #     },
    # ),
    # (
    #     "simple_model_rmsprop_weight_decay",
    #     lambda self: self._model,
    #     optim.RMSprop,
    #     {
    #         "weight_decay": 0.1
    #     },
    # ),
    # (
    #     "simple_model_rmsprop_centered",
    #     lambda self: self._model,
    #     optim.RMSprop,
    #     {
    #         "centered": True
    #     },
    # ),
    # (
    #     "share_weight_model_rmsprop",
    #     lambda self: self._shared_param_model,
    #     optim.RMSprop,
    # ),
    # (
    #     "share_weight_seq_model_rmsprop",
    #     lambda self: self._shared_param_seq_model,
    #     optim.RMSprop,
    # ),
    # (
    #     "partially_used_model_rmsprop",
    #     lambda self: self._partially_used_model,
    #     optim.RMSprop,
    # ),
]


class _NestedEnc(torch.nn.Module):
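    """Thin wrapper module that just applies the module/callable ``f``."""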
    def __init__(self, f):
        super().__init__()
        self.f = f

    def forward(self, x):
        return self.f(x)


class _Enc(torch.nn.Module):
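    """Encoder whose two submodules wrap the same Linear layer (shared weights)."""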
    def __init__(self):
        super().__init__()
        self.e1 = _NestedEnc(torch.nn.Linear(4, 2))
        self.e2 = _NestedEnc(self.e1.f)

    def forward(self, x):
        return self.e1(x) + self.e2(x)


class _PartiallyUsed(torch.nn.Module):
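    """Fixture whose parameter ``b`` is never used in the forward pass."""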
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.rand(4, 3, requires_grad=True))
        self.b = torch.nn.Parameter(torch.rand(4, 3, requires_grad=True))

    def forward(self, x):
        return x @ self.a


def finite_difference(model, closure, eps):
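    """Estimate d(closure())/d(param) for each parameter of ``model``.

    Uses a forward finite difference: each scalar entry of each parameter is
    perturbed by ``eps``, the closure is re-evaluated, the resulting slope is
    recorded, and the perturbation is undone. Returns a list of tensors with
    the same shapes as the model's parameters.
    """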
    fd_params = []
    ground = closure()
    for param in model.parameters():
        fd_param = torch.zeros_like(param)
        for p, fdp in zip(param.flatten(), fd_param.flatten()):
            p.data.add_(eps)
            fdp.fill_((closure() - ground) / eps)
            p.data.sub_(eps)
        fd_params.append(fd_param)
    return fd_params


class TestOptim(unittest.TestCase):
    """Test case for the optim module."""

    def setUp(self):
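        """Build the model fixtures used by the parameterized tests.

        These cover a plain feedforward model, two models which reuse the
        same Linear layer (to exercise shared weights), and a model with an
        unused parameter.
        """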
        self._model = nn.Sequential(
            nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 4, bias=False),
            nn.Sigmoid(), nn.Linear(4, 2)
        )

        self._shared_param_model = _Enc()

        proj = nn.Linear(4, 3)
        self._shared_param_seq_model = nn.Sequential(
            proj, nn.ReLU(), nn.Linear(3, 4), nn.Sigmoid(), proj, nn.Tanh(),
            nn.Linear(3, 2)
        )

        self._partially_used_model = _PartiallyUsed()

    @parameterized.expand(_test_param_sweep)
    def testDiffOptCorrectness(
        self, _, model_builder, opt_builder, kwargs=None
    ):
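        """Check that the differentiable optimizer tracks its eager analogue.

        The patched model is stepped with the differentiable optimizer while
        the original model is stepped with the standard optimizer on the same
        data; the resulting parameters should approximately match.
        """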
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        model = model_builder(self)
        opt = opt_builder(model.parameters(), lr=lr, **kwargs)

        self._run_correctness_test(opt, model)

    @parameterized.expand(_test_param_sweep)
    def testDiffOptGroupedParam(
        self, _, model_builder, opt_builder, kwargs=None
    ):
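        """Run the correctness test with two parameter groups.

        The first half of the parameters uses a different learning rate, so
        per-group hyperparameter handling is exercised.
        """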
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        left_lr = .2
        model = model_builder(self)
        full_parameters = list(model.parameters())
        half = len(full_parameters) // 2
        left_parameters = full_parameters[:half]
        right_parameters = full_parameters[half:]
        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr
            },
            {
                'params': right_parameters
            },
        ]
        opt = opt_builder(param_groups, lr=lr, **kwargs)

        self._run_correctness_test(opt, model)

    def _run_correctness_test(self, opt, model, override=None):
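        """Step ``opt`` and a differentiable copy in lockstep and compare.

        On each outer iteration a fresh patched model and differentiable
        optimizer are created; after a few inner steps the patched model's
        parameters should be close to those of the eagerly optimized model.
        """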
        for i in range(10):
            fmodel = higher.patch.monkeypatch(model)
            diffopt = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel, override=override
            )
            for j in range(3):
                opt.zero_grad()
                x = torch.rand(10, 4)

                y_model = model(x)
                y_fmodel = fmodel(x)

                loss_model = y_model.pow(2).sum()
                loss_fmodel = y_fmodel.pow(2).sum()

                loss_model.backward()
                diffopt.step(loss_fmodel)
                opt.step()
                self.assertEqual(
                    len(list(model.parameters())),
                    len(list(fmodel.parameters()))
                )

            for p, fp in zip(model.parameters(), fmodel.parameters()):
                torch.testing.assert_allclose(p, fp, atol=1e-5, rtol=1e-1)

    @parameterized.expand(_test_param_sweep)
    def testGradientCorrectness(
        self, _, model_builder, opt_builder, kwargs=None
    ):
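        """Compare unrolled autograd gradients against finite differences.

        Gradients of the final loss with respect to the initial parameters
        (``fmodel.parameters(time=0)``) are checked against forward
        finite-difference estimates; a fraction of trials is allowed to fail
        to tolerate numerical noise.
        """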
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        model = model_builder(self)
        eps = 1e-3

        tests = 10
        count = 0
        threshold = .6  # proportion of tests that should pass

        for i in range(tests):
            xs = [torch.rand(10, 4) for _ in range(2)]

            def closure():
                cmodel = copy.deepcopy(model)
                opt = opt_builder(cmodel.parameters(), lr=lr, **kwargs)
                for x in xs[:-1]:
                    opt.zero_grad()
                    cmodel(x).pow(2).sum().backward()
                    opt.step()
                loss = cmodel(xs[-1]).pow(2).sum()
                return loss

            fd_grads = finite_difference(model, closure, eps)

            opt = opt_builder(model.parameters(), lr=lr, **kwargs)
            with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
                for x in xs[:-1]:
                    loss = fmodel(x).pow(2).sum()
                    diffopt.step(loss)
                loss = fmodel(xs[-1]).pow(2).sum()
                grads = torch.autograd.grad(
                    loss, fmodel.parameters(time=0), allow_unused=True
                )
                close = []
                for g, fg in zip(grads, fd_grads):
                    if g is None:
                        # grad is None, so the parameter was presumably not
                        # used in the graph; count it as a match
                        close.append(True)
                    else:
                        self.assertFalse(
                            torch.any(torch.isnan(g)), "NaNs found in gradient."
                        )
                        close.append(torch.allclose(g, fg, 1e-1, 1e-1))
                if all(close):
                    count += 1
        self.assertTrue(
            count / tests >= threshold,
            msg="Proportion of successful finite-difference gradient checks "
            "({:.0f}%) fell below the {:.0f}% threshold.".format(
                100 * count / tests, threshold * 100
            )
        )

    @parameterized.expand([(
        "simple_model_adam",
        lambda self: self._model,
        optim.Adam,
    )])
    def testDiffOptGroupedParamLearn(
        self, _, model_builder, opt_builder, kwargs=None
    ):
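        """Check that hyperparameter overrides support gradient flow.

        The learning rate and betas are overridden at differentiable-optimizer
        construction time with ``requires_grad`` tensors; gradients of the
        resulting parameters with respect to these hyperparameters should be
        finite.
        """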
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        left_lr = .2
        model = model_builder(self)
        full_parameters = list(model.parameters())
        half = len(full_parameters) // 2
        left_parameters = full_parameters[:half]
        right_parameters = full_parameters[half:]
        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr
            },
            {
                'params': right_parameters
            },
        ]
        opt = opt_builder(param_groups, lr=lr, **kwargs)

        override = {
            'lr':
                [
                    torch.tensor(.3, requires_grad=True)
                ],
            'betas':
                [
                    (
                        torch.tensor(0.9, requires_grad=True),
                        torch.tensor(0.999, requires_grad=True)
                    ),
                    (
                        torch.tensor(0.8, requires_grad=True),
                        torch.tensor(0.888, requires_grad=True)
                    )
                ]
        }
        meta_params = higher.utils.flatten(override)

        for i in range(1):
            fmodel = higher.patch.monkeypatch(model)
            diffopt = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel, override=override
            )
            for j in range(3):
                x = torch.rand(10, 4)
                y_fmodel = fmodel(x)
                loss_fmodel = y_fmodel.pow(2).sum()
                diffopt.step(loss_fmodel)
            param_sum = sum(p.sum() for p in fmodel.parameters())
            for g in torch.autograd.grad(param_sum, meta_params):
                self.assertTrue(
                    torch.isfinite(g).all().item(),
                    "Nan or Inf found in hyperparameter gradients."
                )

    @staticmethod
    def _approx_equal_params(params_1, params_2):
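        """Return True iff the parameter sequences have the same length and
        all corresponding pairs are approximately equal."""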
        params_1 = list(params_1)
        params_2 = list(params_2)
        if len(params_1) != len(params_2):
            return False
        for p1, p2 in zip(params_1, params_2):
            if not torch.allclose(p1, p2):
                return False
        return True

    @parameterized.expand([(
        "simple_model_adam",
        lambda self: self._model,
        optim.Adam,
    )])
    def testDiffOptCallback(
        self, _, model_builder, opt_builder, kwargs=None
    ):
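        """Check that gradient callbacks are applied consistently.

        A callback may be supplied when the differentiable optimizer is built
        or passed to each ``step`` call; equivalent callbacks should yield
        identical parameters, and different callbacks should not.
        """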
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        left_lr = .2
        model = model_builder(self)

        full_parameters = list(model.parameters())
        half = len(full_parameters) // 2
        left_parameters = full_parameters[:half]
        right_parameters = full_parameters[half:]
        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr
            },
            {
                'params': right_parameters
            },
        ]
        opt = opt_builder(param_groups, lr=lr, **kwargs)

        # After the inner loop, the patched models defined below should
        # satisfy:
        #   fmodel_0 != fmodel_1  (no callback vs callback_1)
        #   fmodel_0 != fmodel_2  (no callback vs callback_2)
        #   fmodel_1 != fmodel_2  (different callbacks)
        #   fmodel_2 == fmodel_3  (same callback, set at construction vs per step)

        callback_1 = lambda all_grad: [g * .1 for g in all_grad]
        callback_2 = lambda all_grad: [g * .2 for g in all_grad]
        callback_3 = callback_2

        for i in range(1):
            fmodel_0 = higher.patch.monkeypatch(model)
            diffopt_0 = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel_0, grad_callback=None
            )

            fmodel_1 = higher.patch.monkeypatch(model)
            diffopt_1 = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel_1, grad_callback=callback_1
            )

            fmodel_2 = higher.patch.monkeypatch(model)
            diffopt_2 = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel_2, grad_callback=callback_2
            )

            fmodel_3 = higher.patch.monkeypatch(model)
            diffopt_3 = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel_3, grad_callback=None
            )

            for j in range(3):
                x = torch.rand(10, 4)
                diffopt_0.step(fmodel_0(x).pow(2).sum())
                diffopt_1.step(fmodel_1(x).pow(2).sum())
                diffopt_2.step(fmodel_2(x).pow(2).sum())
                diffopt_3.step(
                    fmodel_3(x).pow(2).sum(), grad_callback=callback_3
                )

            # Check that the conditions described above are satisfied.
            self.assertFalse(
                self._approx_equal_params(
                    fmodel_0.parameters(), fmodel_1.parameters()
                )
            )
            self.assertFalse(
                self._approx_equal_params(
                    fmodel_0.parameters(), fmodel_2.parameters()
                )
            )
            self.assertFalse(
                self._approx_equal_params(
                    fmodel_1.parameters(), fmodel_2.parameters()
                )
            )
            self.assertTrue(
                self._approx_equal_params(
                    fmodel_2.parameters(), fmodel_3.parameters()
                )
            )

    @parameterized.expand([(
        "simple_model_adam",
        lambda self: self._model,
        optim.Adam,
    )])
    def testDiffOptGroupedParamLearnStepwise(
        self, _, model_builder, opt_builder, kwargs=None
    ):
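        """As testDiffOptGroupedParamLearn, but override per step.

        The hyperparameter override is passed to each ``diffopt.step`` call
        instead of at construction time; gradients with respect to the
        overridden hyperparameters should still be finite.
        """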
        kwargs = {} if kwargs is None else kwargs
        lr = .1
        left_lr = .2
        model = model_builder(self)
        full_parameters = list(model.parameters())
        half = len(full_parameters) // 2
        left_parameters = full_parameters[:half]
        right_parameters = full_parameters[half:]
        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr
            },
            {
                'params': right_parameters
            },
        ]
        opt = opt_builder(param_groups, lr=lr, **kwargs)

        override = {
            'lr':
                [
                    torch.tensor(.3, requires_grad=True)
                ],
            'betas':
                [
                    (
                        torch.tensor(0.9, requires_grad=True),
                        torch.tensor(0.999, requires_grad=True)
                    ),
                    (
                        torch.tensor(0.8, requires_grad=True),
                        torch.tensor(0.888, requires_grad=True)
                    )
                ]
        }
        meta_params = higher.utils.flatten(override)

        for i in range(1):
            fmodel = higher.patch.monkeypatch(model)
            diffopt = higher.optim.get_diff_optim(
                opt, model.parameters(), fmodel, override=None
            )
            for j in range(3):
                x = torch.rand(10, 4)
                y_fmodel = fmodel(x)
                loss_fmodel = y_fmodel.pow(2).sum()
                diffopt.step(loss_fmodel, override=override)
            param_sum = sum(p.sum() for p in fmodel.parameters())
            for g in torch.autograd.grad(param_sum, meta_params):
                self.assertTrue(
                    torch.isfinite(g).all().item(),
                    "Nan or Inf found in hyperparameter gradients."
                )

    def testFrozenParameters(self):
        """Check if diffopts robuts to frozen parameters.

        Thanks to github user @seanie12 for providing the minimum working
        example for this unit test.
        """
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc1 = nn.Linear(30, 50)
                self.fc2 = nn.Linear(50, 1)
                # freeze first FC layer
                for param in self.fc1.parameters():
                    param.requires_grad = False

            def forward(self, x):
                hidden = self.fc1(x)
                logits = self.fc2(hidden).squeeze(1)
                return logits

        # random inputs and binary labels for the test
        inputs = torch.randn(16, 30)
        ones = torch.ones(8)
        zeros = torch.zeros(8)
        labels = torch.cat([ones, zeros], dim=0)

        net = Net()

        param = filter(lambda x: x.requires_grad, net.parameters())
        inner_opt = torch.optim.SGD(param, lr=1e-1)
        loss_func = nn.BCEWithLogitsLoss()

        with higher.innerloop_ctx(net, inner_opt) as (fnet, diffopt):
            logits = fnet(inputs)
            loss = loss_func(logits, labels)
            diffopt.step(loss)
            zipped = list(zip(net.parameters(), fnet.parameters()))
            self.assertTrue(torch.equal(*zipped[0]))
            self.assertTrue(torch.equal(*zipped[1]))
            self.assertFalse(torch.equal(*zipped[2]))
            self.assertFalse(torch.equal(*zipped[3]))

    def testGetApplyRoundTrip(self):
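        """Round-trip optimizer hyperparameters through the override helpers.

        ``get_trainable_opt_params`` should return differentiable copies of
        the optimizer's hyperparameters, and ``apply_trainable_opt_params``
        should write them back onto another optimizer, restoring the original
        values.
        """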
        kwargs = {}
        lr = .1
        left_lr = .2
        model = self._model
        full_parameters = list(model.parameters())
        half = len(full_parameters) // 2
        left_parameters = full_parameters[:half]
        right_parameters = full_parameters[half:]
        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr
            },
            {
                'params': right_parameters
            },
        ]
        opt = optim.Adam(param_groups, lr=lr, **kwargs)

        override = higher.optim.get_trainable_opt_params(opt)

        def assert_closure(target: torch.Tensor):
            self.assertTrue(torch.is_tensor(target) and target.requires_grad)

        # Check that every item in the override is a structure containing
        # tensors with requires_grad=True.
        for hp in override:
            higher.utils._recursive_map(override[hp], assert_closure)

        param_groups = [
            {
                'params': left_parameters,
                'lr': left_lr + 1
            },
            {
                'params': right_parameters
            },
        ]

        # Create a new optimizer with slightly different hyperparameter values
        # (both global and per parameter group) to simulate divergence from the
        # original hyperparameters.
        new_opt = optim.Adam(param_groups, lr=lr+5, **kwargs)

        # Overwrite with the "learned" hyperparameters
        higher.optim.apply_trainable_opt_params(new_opt, override)

        # Check that values match
        # TODO(egrefen): would be good to do a structure matching test, or an
        # extrinsic eval whereby we check that both opts are functionally
        # equivalent.
        old_flattened = higher.utils.flatten(opt.param_groups)
        new_flattened = higher.utils.flatten(new_opt.param_groups)
        self.assertEqual(len(old_flattened), len(new_flattened))
        zipped = zip(old_flattened, new_flattened)
        for old, new in zipped:
            self.assertEqual(type(old), type(new))
            if torch.is_tensor(old):
                torch.testing.assert_allclose(old, new)
            elif isinstance(old, (float, int)):
                self.assertAlmostEqual(old, new)
            else:
                self.assertEqual(old, new)



if __name__ == '__main__':
    unittest.main()