# Copyright 2020 Zuru Tech HK Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Test Restorers.

GIVEN a correctly instantiated trainer
GIVEN some training has been done
    WHEN calling the Restorer, "Restored ... from checkpoint: .../ckpts/ckpt-..."
        should be logged.
    WHEN restoring models, every weight of the restored model and of the trained one
        should be the same.
"""
from pathlib import Path
from typing import Union

import pytest
import tensorflow as tf
from ashpy.callbacks import CounterCallback
from ashpy.restorers import (
    AdversarialRestorer,
    ClassifierRestorer,
    ModelNotConstructedError,
    Restorer,
)
from ashpy.trainers import AdversarialTrainer, ClassifierTrainer

from tests.utils.fake_training_loop import (
    FakeAdversarialTraining,
    FakeClassifierTraining,
)

# NOTE(review): presumably the name of the checkpoints sub-directory inside a logdir.
# It is not referenced by the code visible here - confirm whether it is still needed.
DEFAULT_CKPT_DIR = "ckpts"


def _check_models_weights(trained: tf.keras.Model, restored: tf.keras.Model, i=0):
    """
    Check that every weight of the trained model equals the restored model's one.

    The weights are compared pairwise, following the order of ``trained.weights``.

    Args:
        trained: Model whose weights are the reference values.
        restored: Model whose weights are expected to match the ones of ``trained``.
        i: Unused. Kept only so that existing callers passing it keep working.

    Raises:
        ModelNotConstructedError: If any pair of weights differs.
    """
    for index, trained_weight in enumerate(trained.weights):
        # Explicit check instead of assert + except AssertionError: the comparison
        # must not disappear when Python runs with assertions disabled (-O).
        weights_match = tf.reduce_all(
            tf.equal(trained_weight, restored.weights[index])
        )
        if not weights_match:
            raise ModelNotConstructedError


def _test_restore_object(restorer, placeholder, ckpt_id, capsys):
    """
    Restore ``placeholder`` under ``ckpt_id`` and verify the restore was logged.

    Args:
        restorer: Object exposing ``restore_object``.
        placeholder: Object passed to ``restore_object`` as the restore target.
        ckpt_id: Checkpoint identifier passed to ``restore_object``.
        capsys: Pytest capture fixture used to read what has been printed.
    """
    restorer.restore_object(placeholder, ckpt_id)
    # The log check reads (and therefore drains) the output captured so far.
    _check_log(restorer=restorer, ckpt_id=ckpt_id, capsys=capsys)


def _check_log(restorer, ckpt_id, capsys):
    """
    Assert that the "restored" message for ``ckpt_id`` is one of the captured lines.

    The expected message is built by formatting ``restorer._restored_log_msg`` with
    ``ckpt_id`` and ``restorer._ckpts_dir`` and must match a whole line of the
    captured standard output.
    """
    expected_line = restorer._restored_log_msg.format(ckpt_id, restorer._ckpts_dir)
    captured_stdout = capsys.readouterr()[0]
    logged_lines = captured_stdout.split("\n")
    assert expected_line in logged_lines


def test_restore_model(fake_training_fn, capsys, tmpdir):
    """
    Test that models are correctly restored.

    The test is performed by checking the logs and by comparing the weights of each
    restored model with the weights of the trained one.
    """
    logdir = Path(tmpdir).joinpath("training")
    _tmp_logdir = Path(tmpdir).joinpath("banana")

    fake_training = fake_training_fn(logdir=logdir)
    assert fake_training()
    trainer = fake_training.trainer
    restorer = Restorer(logdir=logdir)

    if isinstance(trainer, ClassifierTrainer):

        new_loop = fake_training_fn(logdir=_tmp_logdir)
        placeholder = new_loop.model

        # Ensure model have been built correctly
        x, _ = next(iter(new_loop.dataset))
        placeholder(x)

        ckpt_id = trainer.ckpt_id_model
        _test_restore_object(restorer, placeholder, ckpt_id, capsys)
        _check_models_weights(trainer._model, placeholder)

    elif isinstance(trainer, AdversarialTrainer):

        new_loop: FakeAdversarialTraining = fake_training_fn(logdir=_tmp_logdir)
        placeholder_g, placeholder_d = (new_loop.generator, new_loop.discriminator)

        # Ensure that the ModelNotConstructedError is correctly triggered.
        # One context per model: a pytest.raises block ends at the first exception,
        # so a second call placed in the same block would never be executed.
        with pytest.raises(ModelNotConstructedError):
            _test_restore_object(
                restorer, placeholder_g, trainer.ckpt_id_generator, capsys
            )
        with pytest.raises(ModelNotConstructedError):
            _test_restore_object(
                restorer, placeholder_d, trainer.ckpt_id_discriminator, capsys
            )

        # Ensure model have been built correctly
        (x, _), z = next(iter(new_loop.dataset))
        fake = placeholder_g(z)
        assert tf.reduce_all(tf.equal(fake.shape, x.shape))
        placeholder_d(x)

        _test_restore_object(restorer, placeholder_g, trainer.ckpt_id_generator, capsys)
        _check_models_weights(trainer._generator, placeholder_g)
        _test_restore_object(
            restorer, placeholder_d, trainer.ckpt_id_discriminator, capsys
        )
        _check_models_weights(trainer._discriminator, placeholder_d)


def test_restore_common_variables(fake_training_fn, capsys, tmpdir):
    """
    Test the :class:`Restorer` convenience getters for the common variables.

    The :class:`tf.Variable`s read back through :class:`Restorer` are:
        - global step
        - steps per epoch

    Both the restored values and the logged messages are checked.
    """
    training_logdir = Path(tmpdir).joinpath("training")

    training_loop = fake_training_fn(logdir=training_logdir)
    assert training_loop()
    trainer = training_loop.trainer
    restorer = Restorer(logdir=training_logdir)

    # The restored values must match the ones held by the trainer
    assert tf.equal(trainer._global_step, restorer.get_global_step())
    assert tf.equal(trainer._steps_per_epoch, restorer.get_steps_per_epoch())

    # Read the captured output once: both messages are looked up in the same lines
    captured_stdout = capsys.readouterr()[0]

    expected_ids = (trainer.ckpt_id_global_step, trainer.ckpt_id_steps_per_epoch)
    for ckpt_id in expected_ids:
        expected_line = restorer._restored_log_msg.format(ckpt_id, restorer._ckpts_dir)
        assert expected_line in captured_stdout.split("\n")


def test_restore_callbacks(fake_training_fn, capsys, tmpdir):
    """
    Test that callbacks are successfully restored.

    Only trainings driven by an :class:`AdversarialTrainer` are checked; for any other
    trainer the test ends right after the training.
    """
    training_logdir = Path(tmpdir).joinpath("training")

    training_loop = fake_training_fn(logdir=training_logdir)
    assert training_loop()
    trainer = training_loop.trainer
    restorer = Restorer(logdir=training_logdir)

    if not isinstance(trainer, AdversarialTrainer):
        return

    for position, callback in enumerate(training_loop.callbacks):
        restorer.restore_callback(callback, callback.name)
        # The restore must have been logged under the callback name
        _check_log(restorer, callback.name, capsys)

        if not isinstance(callback, CounterCallback):
            continue
        # Counter callbacks: the restored counter must equal the trained one
        trained_callback = trainer._callbacks[position]
        assert tf.equal(callback._event_counter, trained_callback._event_counter)


def test_read_checkpoint_map(fake_training_fn, tmpdir):
    """
    Test that the checkpoint map is read correctly.

    After a training, the map exposed by the restorer must equal the one generated by
    the trainer. Once ``checkpoint_map.json`` is deleted, a new restorer must expose a
    falsy map.
    """
    training_logdir = Path(tmpdir).joinpath("training")

    training_loop = fake_training_fn(logdir=training_logdir)
    assert training_loop()
    trainer = training_loop.trainer
    restorer = Restorer(logdir=training_logdir)

    restored_map = restorer.checkpoint_map
    expected_map = trainer._generate_checkpoint_map()
    assert restored_map == expected_map

    # Delete the map file and make sure it is really gone
    map_file: Path = restorer._ckpts_dir / "checkpoint_map.json"
    map_file.unlink()
    assert not map_file.exists()

    # Without the file, a new restorer has no map to expose
    restorer_without_map = Restorer(training_logdir)
    assert not restorer_without_map.checkpoint_map


# ###################################################
# Test Convenience Methods


def _test_convenience_model_restorer(
    restorer: Union[AdversarialRestorer, ClassifierRestorer],
    convenience_method,
    placeholder_model,
    trained_model,
    ckpt_id,
    capsys,
):
    """
    Test that a model is correctly restored using a convenience method.

    Args:
        restorer: Restorer whose log message is checked. Both classifier and
            adversarial restorers are passed by the tests of this module.
        convenience_method: Callable invoked with ``placeholder_model`` only.
        placeholder_model: Model handed to ``convenience_method``.
        trained_model: Model whose weights are the reference values.
        ckpt_id: Checkpoint identifier expected in the logged message.
        capsys: Pytest capture fixture used to read what has been printed.
    """
    convenience_method(placeholder_model)
    _check_log(restorer, ckpt_id, capsys)
    _check_models_weights(trained_model, placeholder_model)


def _test_convenience_optimizer_restorer(
    restorer, convenience_method, placeholder_optimizer, ckpt_id, capsys
):
    """
    Test that an optimizer is restored using a convenience method.

    Only the logged message is verified: ``convenience_method`` is invoked with
    ``placeholder_optimizer`` and the message for ``ckpt_id`` must have been printed.

    TODO: Add a more thorough check, like the one :func:`_check_models_weights`
        performs for models.
    """
    convenience_method(placeholder_optimizer)
    _check_log(restorer=restorer, ckpt_id=ckpt_id, capsys=capsys)


def test_convenience_restorer(fake_training_fn, capsys, tmpdir):
    """
    Test the model and optimizer convenience methods of the specialised restorers.

    A classifier training is checked through :class:`ClassifierRestorer`, an
    adversarial one through :class:`AdversarialRestorer`. Every restore must be
    logged and every restored model must have the same weights as the trained one.

    TODO: Add test for AdversarialEncoderRestorer
    """
    training_logdir = Path(tmpdir).joinpath("training")
    placeholder_logdir = Path(tmpdir).joinpath("banana")

    fake_training = fake_training_fn(logdir=training_logdir)
    assert fake_training()
    trainer = fake_training.trainer
    # Replaced below by the specialised restorer matching the trainer type
    restorer = Restorer(logdir=training_logdir)

    if isinstance(trainer, ClassifierTrainer):
        restorer = ClassifierRestorer(logdir=training_logdir)

        new_training = fake_training_fn(placeholder_logdir)

        placeholder_model = new_training.model
        placeholder_opt = tf.keras.optimizers.Adam()

        # Call the model on one batch so that it is built before restoring
        x, _ = next(iter(new_training.dataset))
        placeholder_model(x)

        _test_convenience_model_restorer(
            restorer=restorer,
            convenience_method=restorer.restore_model,
            placeholder_model=placeholder_model,
            trained_model=trainer._model,
            ckpt_id=trainer.ckpt_id_model,
            capsys=capsys,
        )
        _test_convenience_optimizer_restorer(
            restorer=restorer,
            convenience_method=restorer.restore_optimizer,
            placeholder_optimizer=placeholder_opt,
            ckpt_id=trainer.ckpt_id_optimizer,
            capsys=capsys,
        )

    elif isinstance(trainer, AdversarialTrainer):
        restorer = AdversarialRestorer(logdir=training_logdir)

        new_training = fake_training_fn(placeholder_logdir)

        placeholder_g = new_training.generator
        placeholder_d = new_training.discriminator
        placeholder_optimizer_g = tf.keras.optimizers.Adam()
        placeholder_optimizer_d = tf.keras.optimizers.Adam()

        # (convenience method, placeholder, trained model, checkpoint id)
        generator_case = (
            restorer.restore_generator,
            placeholder_g,
            trainer._generator,
            trainer.ckpt_id_generator,
        )
        discriminator_case = (
            restorer.restore_discriminator,
            placeholder_d,
            trainer._discriminator,
            trainer.ckpt_id_discriminator,
        )
        model_cases = (generator_case, discriminator_case)

        # (convenience method, placeholder optimizer, checkpoint id)
        optimizer_cases = (
            (
                restorer.restore_generator_optimizer,
                placeholder_optimizer_g,
                trainer.ckpt_id_optimizer_generator,
            ),
            (
                restorer.restore_discriminator_optimizer,
                placeholder_optimizer_d,
                trainer.ckpt_id_optimizer_discriminator,
            ),
        )

        # Ensure that the ModelNotConstructedError is triggered for each model,
        # each one in its own context, before the placeholders are called on data
        for model_case in model_cases:
            with pytest.raises(ModelNotConstructedError):
                _test_convenience_model_restorer(restorer, *model_case, capsys)

        # Call the models on one batch so that they are built before restoring
        (x, _), z = next(iter(new_training.dataset))
        fake = placeholder_g(z)
        assert tf.reduce_all(tf.equal(fake.shape, x.shape))
        placeholder_d(x)

        # Generator model and optimizer first, then the discriminator ones
        for model_case, optimizer_case in zip(model_cases, optimizer_cases):
            _test_convenience_model_restorer(restorer, *model_case, capsys)
            _test_convenience_optimizer_restorer(restorer, *optimizer_case, capsys)


def test_failings(tmpdir):
    """
    Test the failing cases for the Restorers.

    Three failures are expected: a :class:`FileNotFoundError` for a logdir derived
    from ``tmpdir`` with an extra suffix, a :class:`FileNotFoundError` when restoring
    a checkpoint from the bare ``tmpdir``, and a :class:`TypeError` when a placeholder
    does not pass validation against the requested type.
    """
    # NOTE(review): `tmpdir + "fuffa"` relies on the `+` of the tmpdir object,
    # presumably producing a path that does not exist - confirm.
    suffixed_logdir = Path(tmpdir + "fuffa")
    with pytest.raises(FileNotFoundError):
        Restorer(suffixed_logdir)

    # Either the construction or the restore may be the statement that raises here
    empty_checkpoint = tf.train.Checkpoint()
    with pytest.raises(FileNotFoundError):
        bare_restorer = Restorer(tmpdir)
        bare_restorer._restore_checkpoint(empty_checkpoint)

    # A keras model validated against tf.Variable must be rejected
    wrong_placeholder = tf.keras.Model()
    with pytest.raises(TypeError):
        Restorer._validate_placeholder(
            placeholders=wrong_placeholder, placeholder_type=tf.Variable
        )