import contextlib
import importlib
import json
import os
import sys

import tensorflow as tf
from tensorflow.python.eager.context import num_gpus  # note: private TF API, used to count visible GPUs


def get_current_epoch(output_dir):
    """Return the last completed epoch recorded in `stats.json`, or 0 if unavailable."""
    try:
        with open(os.path.join(output_dir, "stats.json"), "r") as f:
            return json.load(f)["epoch"]
    except (OSError, ValueError, KeyError):
        # Missing, unreadable, or malformed stats file: start from epoch 0.
        return 0
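
# Example (illustrative sketch, not part of this module): resuming training from
# the recorded epoch. `model`, `dataset`, and `output_dir` are placeholder names.
#
#     initial_epoch = get_current_epoch(output_dir)
#     model.fit(dataset, epochs=100, initial_epoch=initial_epoch)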


class ModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
    """Checkpoint callback that also records the last completed epoch in `stats.json`."""

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs=logs)
        # Write the next epoch to resume from, next to the checkpoint file.
        with open(os.path.join(os.path.dirname(self.filepath), "stats.json"), "w") as f:
            json.dump({"epoch": epoch + 1}, f)
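
# Example (illustrative sketch): the callback stores `stats.json` alongside the
# checkpoint, so `get_current_epoch` can pick up where training stopped.
# `output_dir`, `model`, and `dataset` are placeholder names.
#
#     checkpoint = ModelCheckpoint(
#         filepath=os.path.join(output_dir, "weights.h5"), save_weights_only=True
#     )
#     model.fit(dataset, epochs=100, callbacks=[checkpoint])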


def get_distribution_scope(batch_size):
    """Return a context manager for single- or multi-GPU training.

    With more than one GPU, a `tf.distribute.MirroredStrategy` scope is returned
    so the model is replicated across devices; otherwise a no-op context manager
    is returned.
    """
    if num_gpus() > 1:
        strategy = tf.distribute.MirroredStrategy()
        assert (
            batch_size % strategy.num_replicas_in_sync == 0
        ), f"Batch size {batch_size} cannot be divided evenly onto {strategy.num_replicas_in_sync} GPUs"
        distribution_scope = strategy.scope
    else:
        if sys.version_info >= (3, 7):
            distribution_scope = contextlib.nullcontext
        else:
            # On Python < 3.7 `nullcontext` is unavailable; `suppress()` with no
            # arguments behaves as a no-op context manager.
            distribution_scope = contextlib.suppress

    return distribution_scope()
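
# Example (illustrative sketch): model creation and compilation should happen
# inside the returned scope so variables are mirrored across GPUs. `build_model`
# is a placeholder for whatever constructs the Keras model.
#
#     with get_distribution_scope(batch_size=512):
#         model = build_model()
#         model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")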


def import_all_sub_modules(module):
    """Import every non-private Python file in the given sub-package.

    Importing the files runs any registration performed at import time
    (e.g. via decorators), which populates the registry.
    """
    with os.scandir(os.path.join(os.path.dirname(__file__), module)) as it:
        for entry in it:
            name = entry.name
            if name.endswith(".py") and not name.startswith("_") and entry.is_file():
                importlib.import_module(f"bnn_optimization.{module}.{name[:-3]}")


def prepare_registry():
    """Import all data and model sub-modules so they register themselves."""
    for module in ("data", "models"):
        import_all_sub_modules(module)
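
# Example (illustrative sketch): call `prepare_registry()` once before any
# name-based lookup of datasets or models, so that every module under `data/`
# and `models/` has been imported and its registration side effects have run
# (the exact registry API is not shown in this file).
#
#     prepare_registry()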