# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for polybeast learn function implementation."""

import copy
import unittest
from unittest import mock

import numpy as np
import torch
from torchbeast import polybeast


def _state_dict_to_numpy(state_dict):
    return {key: value.numpy() for key, value in state_dict.items()}


class LearnTest(unittest.TestCase):
    """Checks on ``polybeast.learn`` driven by a single mocked batch.

    ``setUp`` builds a learner network, an actor network, an SGD optimizer, a
    step scheduler, mocked flags and logger, and a learner queue that yields
    exactly one batch. Each test calls ``polybeast.learn`` once with those
    arguments and then inspects the models or the stats dictionary.
    """

    def setUp(self):
        """Create the models, optimizer, mocks and the batch passed to learn."""
        unroll = 2  # Arbitrary.
        batch = 4  # Arbitrary.
        frame_size = 84  # Has to match what is expected by the model.
        action_count = 6  # Specific to each environment.
        channels = 4  # Has to match the first conv layer of the net.

        # The following hyperparameters are arbitrary.
        self.lr = 0.1
        total_steps = 100000

        # Seed manually so that the results are reproducible.
        torch.manual_seed(0)

        self.model = polybeast.Net(num_actions=action_count, use_lstm=False)
        self.actor_model = polybeast.Net(num_actions=action_count, use_lstm=False)
        # Snapshots of the freshly initialised weights of both networks.
        self.initial_model_dict = copy.deepcopy(self.model.state_dict())
        self.initial_actor_model_dict = copy.deepcopy(self.actor_model.state_dict())

        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=total_steps // 10
        )

        self.stats = {}

        # Logger stand-in: calling plogger.log performs no action.
        plogger = mock.Mock(log=mock.Mock())

        # Mocked command line flags.
        mock_flags = mock.Mock(
            learner_device=torch.device("cpu"),
            reward_clipping="abs_one",  # Default value from cmd.
            discounting=0.99,  # Default value from cmd.
            baseline_cost=0.5,  # Default value from cmd.
            entropy_cost=0.0006,  # Default value from cmd.
            unroll_length=unroll,
            batch_size=batch,
            grad_norm_clipping=40,
        )

        # Content of the mocked learner queue. Most tensors share the leading
        # (unroll, batch) shape.
        time_batch = (unroll, batch)
        env_outputs = (
            # Frame.
            torch.ones(time_batch + (channels, frame_size, frame_size)),
            # Rewards.
            torch.ones(time_batch),
            # Done.
            torch.zeros(time_batch, dtype=torch.uint8),
            # Episode step.
            torch.ones(time_batch),
            # Episode return.
            torch.ones(time_batch),
        )
        actor_outputs = (
            # Actions taken.
            torch.randint(low=0, high=action_count, size=time_batch),
            # Logits.
            torch.randn(*time_batch, action_count),
            # Baseline.
            torch.rand(time_batch),
        )
        initial_agent_state = ()  # No lstm.
        tensors = ((env_outputs, actor_outputs), initial_agent_state)

        # Mocked learner queue: iterating over it yields the batch above once.
        mock_learner_queue = mock.MagicMock()
        mock_learner_queue.__iter__.return_value = iter([tensors])

        self.learn_args = (
            mock_flags,
            mock_learner_queue,
            self.model,
            self.actor_model,
            optimizer,
            scheduler,
            self.stats,
            plogger,
        )

    def _reset_models(self):
        """Load the weights snapshotted in setUp back into both models."""
        self.model.load_state_dict(self.initial_model_dict)
        self.actor_model.load_state_dict(self.initial_actor_model_dict)

    def _learn_once(self):
        """Call polybeast.learn with the arguments prepared in setUp."""
        polybeast.learn(*self.learn_args)

    def test_parameters_copied_to_actor_model(self):
        """The actor model ends up with the same parameters as the learner."""
        self._reset_models()
        self._learn_once()

        actor_params = _state_dict_to_numpy(self.actor_model.state_dict())
        learner_params = _state_dict_to_numpy(self.model.state_dict())
        np.testing.assert_equal(actor_params, learner_params)

    def test_weights_update(self):
        """Trainable parameters move by one SGD step after one iteration."""
        self._reset_models()
        self._learn_once()

        learner_vars = self.model.state_dict(keep_vars=True)
        actor_vars = self.actor_model.state_dict(keep_vars=True)
        for name, before in self.initial_model_dict.items():
            learner_var = learner_vars[name]
            actor_var = actor_vars[name]
            grad = learner_var.grad

            # The learner must have a non-zero gradient.
            self.assertGreater(torch.norm(grad), 0.0)
            # The actor must have no gradient at all.
            # Even though actor model tensors have requires_grad == True, no
            # gradients are ever calculated for them because the inference
            # function in polybeast.py (which performs forward passes with the
            # actor_model) uses the torch.no_grad context manager.
            self.assertIsNone(actor_var.grad)

            # Redo the gradient descent step by hand and compare it with the
            # weights of both models (up to floating point error).
            expected = before.detach().numpy() - self.lr * grad.numpy()
            np.testing.assert_almost_equal(learner_var.detach().numpy(), expected)
            np.testing.assert_almost_equal(actor_var.detach().numpy(), expected)

    def test_gradients_update(self):
        """Gradients appear on the learner, and only there, after one iteration."""
        self._reset_models()

        # Nothing has been back-propagated yet, so no parameter has a gradient.
        for net in (self.model, self.actor_model):
            for param in net.parameters():
                self.assertIsNone(param.grad)

        self._learn_once()

        # Every learner parameter must now carry a gradient, and each of those
        # gradients must contain at least one non-zero entry.
        for param in self.model.parameters():
            grad = param.grad
            self.assertIsNotNone(grad)
            self.assertFalse(torch.equal(grad, torch.zeros_like(grad)))

        # The actor model must still be gradient free.
        for param in self.actor_model.parameters():
            self.assertIsNone(param.grad)

    def test_non_zero_loss(self):
        """None of the reported losses is zero after one iteration."""
        self._reset_models()
        self._learn_once()

        for stat_key in ("total_loss", "pg_loss", "baseline_loss", "entropy_loss"):
            self.assertNotEqual(self.stats[stat_key], 0.0)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()