# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim import Adam, AdamW

from qhoptim.pyt import QHAdam, QHAdamW

from .util import assert_optimizers_equal


def test_adam_equiv():
    lr = 3e-4
    betas = (0.9, 0.999)
    weight_decay = 0.5e-4
    eps = 1e-8

    def adam_ctor(params):
        return Adam(params, lr=lr, betas=betas, weight_decay=weight_decay, eps=eps)

    def qhadam_ctor(params):
        return QHAdam(params, lr=lr, betas=betas, weight_decay=weight_decay, nus=(1.0, 1.0), eps=eps)

    def adamw_ctor(params):
        return AdamW(params, lr=lr, betas=betas, weight_decay=weight_decay, eps=eps)

    def qhadamw_ctor(params):
        return QHAdamW(params, lr=lr, betas=betas, weight_decay=weight_decay, nus=(1.0, 1.0), eps=eps)

    assert_optimizers_equal(adam_ctor, qhadam_ctor)