#!/usr/bin/env python3
""" Base class for Models. All models should, at a minimum, inherit from this class.

    When inheriting, register each network with ``add_network``, which stores it as an NNMeta
    object in ``self.networks``. See the NNMeta class for details.
"""
import logging
import os
import sys
import time

from concurrent import futures

import keras
from keras import losses
from keras import backend as K
from keras.layers import Input
from keras.models import load_model, Model
from keras.utils import get_custom_objects, multi_gpu_model

from lib.serializer import get_serializer
from lib.model.backup_restore import Backup
from lib.model.losses import (DSSIMObjective, PenalizedLoss, gradient_loss, mask_loss_wrapper,
                              generalized_loss, l_inf_norm, gmsd_loss, gaussian_blur)
from lib.model.nn_blocks import NNBlocks
from lib.model.optimizers import Adam
from lib.utils import deprecation_warning, FaceswapError
from plugins.train._config import Config

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
# Module level cache of the current plugin's config dict. It starts as None and is populated
# on first use by ModelBase.load_config / ModelBase.config; Loss.config reads it directly.
_CONFIG = None


class ModelBase():
    """ Base class that all models should inherit from """
    def __init__(self,
                 model_dir,
                 gpus=1,
                 configfile=None,
                 snapshot_interval=0,
                 no_logs=False,
                 warp_to_landmarks=False,
                 augment_color=True,
                 no_flip=False,
                 training_image_size=256,
                 alignments_paths=None,
                 preview_scale=100,
                 input_shape=None,
                 encoder_dim=None,
                 trainer="original",
                 pingpong=False,
                 memory_saving_gradients=False,
                 optimizer_savings=False,
                 predict=False):
        """ Set up the model: load config and state, handle legacy models, then build it.

        Parameters
        ----------
        model_dir: path-like
            Folder holding the model files. NOTE(review): ``add_network`` joins it with ``/``,
            so it is presumably a ``pathlib.Path`` - confirm against the caller.
        gpus: int, optional
            Number of GPUs. Predictors are wrapped with ``multi_gpu_model`` when greater than 1.
        configfile: str, optional
            Custom config file, passed through to ``Config``.
        snapshot_interval: int, optional
            Stored in ``training_opts`` for the trainer.
        no_logs: bool, optional
            Passed to ``State``.
        warp_to_landmarks, augment_color, no_flip: bool, optional
            Augmentation flags stored in ``training_opts``.
        training_image_size: int, optional
            Passed to ``State``.
        alignments_paths: optional
            Stored in ``training_opts["alignments"]``.
        preview_scale: int, optional
            Percentage; stored as a ratio in ``training_opts["preview_scaling"]``.
        input_shape: tuple, optional
            Model input shape. May be replaced from the state file by ``load_state_info``.
        encoder_dim: int, optional
            Encoder dimension. May be replaced by ``rename_legacy`` for legacy models.
        trainer: str, optional
            Name of the trainer.
        pingpong, memory_saving_gradients, optimizer_savings: bool, optional
            VRAM saving options, validated against the backend by ``VRAMSavings``.
        predict: bool, optional
            ``True`` when the model is loaded only for prediction: model summaries are not
            logged and a missing model causes the process to exit.
        """
        logger.debug("Initializing ModelBase (%s): (model_dir: '%s', gpus: %s, configfile: %s, "
                     "snapshot_interval: %s, no_logs: %s, warp_to_landmarks: %s, augment_color: "
                     "%s, no_flip: %s, training_image_size, %s, alignments_paths: %s, "
                     "preview_scale: %s, input_shape: %s, encoder_dim: %s, trainer: %s, "
                     "pingpong: %s, memory_saving_gradients: %s, optimizer_savings: %s, "
                     "predict: %s)",
                     self.__class__.__name__, model_dir, gpus, configfile, snapshot_interval,
                     no_logs, warp_to_landmarks, augment_color, no_flip, training_image_size,
                     alignments_paths, preview_scale, input_shape, encoder_dim, trainer, pingpong,
                     memory_saving_gradients, optimizer_savings, predict)

        self.predict = predict
        self.model_dir = model_dir
        # The validated VRAM options (e.g. pingpong disabled on PlaidML) are used from here on
        self.vram_savings = VRAMSavings(pingpong, optimizer_savings, memory_saving_gradients)

        self.backup = Backup(self.model_dir, self.name)
        self.gpus = gpus
        self.configfile = configfile
        self.input_shape = input_shape
        self.encoder_dim = encoder_dim
        self.trainer = trainer

        self.load_config()  # Load config if plugin has not already referenced it

        self.state = State(self.model_dir,
                           self.name,
                           self.config_changeable_items,
                           no_logs,
                           self.vram_savings.pingpong,
                           training_image_size)

        self.blocks = NNBlocks(use_icnr_init=self.config["icnr_init"],
                               use_convaware_init=self.config["conv_aware_init"],
                               use_reflect_padding=self.config["reflect_padding"],
                               first_run=self.state.first_run)

        # rename_legacy may rewrite the state and encoder_dim for legacy models, so it must run
        # before load_state_info reads the input shape back from the state
        self.is_legacy = False
        self.rename_legacy()
        self.load_state_info()

        self.networks = dict()  # Networks for the model
        self.predictors = dict()  # Predictors for model
        self.history = dict()  # Loss history per save iteration

        # Training information specific to the model should be placed in this
        # dict for reference by the trainer.
        self.training_opts = {"alignments": alignments_paths,
                              "preview_scaling": preview_scale / 100,
                              "warp_to_landmarks": warp_to_landmarks,
                              "augment_color": augment_color,
                              "no_flip": no_flip,
                              "pingpong": self.vram_savings.pingpong,
                              "snapshot_interval": snapshot_interval,
                              "training_size": self.state.training_size,
                              "no_logs": self.state.current_session["no_logs"],
                              "coverage_ratio": self.calculate_coverage_ratio(),
                              "mask_type": self.config["mask_type"],
                              "mask_blur_kernel": self.config["mask_blur_kernel"],
                              "mask_threshold": self.config["mask_threshold"],
                              # Mask options only apply when a mask type has been selected
                              "learn_mask": (self.config["learn_mask"] and
                                             self.config["mask_type"] is not None),
                              "penalized_mask_loss": (self.config["penalized_mask_loss"] and
                                                      self.config["mask_type"] is not None)}
        logger.debug("training_opts: %s", self.training_opts)

        if self.multiple_models_in_folder:
            deprecation_warning("Support for multiple model types within the same folder",
                                additional_info="Please split each model into separate folders to "
                                                "avoid issues in future.")

        self.build()
        logger.debug("Initialized ModelBase (%s)", self.__class__.__name__)

    @property
    def config_section(self):
        """ The section name for loading config """
        retval = ".".join(self.__module__.split(".")[-2:])
        logger.debug(retval)
        return retval

    @property
    def config(self):
        """ Return config dict for current plugin """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self.config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name, configfile=self.configfile).config_dict
        return _CONFIG

    @property
    def config_changeable_items(self):
        """ Return the dict of config items that can be updated after the model
            has been created """
        return Config(self.config_section, configfile=self.configfile).changeable_items

    @property
    def name(self):
        """ Set the model name based on the subclass """
        basename = os.path.basename(sys.modules[self.__module__].__file__)
        retval = os.path.splitext(basename)[0].lower()
        logger.debug("model name: '%s'", retval)
        return retval

    @property
    def models_exist(self):
        """ Return if all files exist and clear session """
        retval = all([os.path.isfile(model.filename) for model in self.networks.values()])
        logger.debug("Pre-existing models exist: %s", retval)
        return retval

    @property
    def multiple_models_in_folder(self):
        """ Return true if there are multiple model types in the same folder, else false """
        model_files = [fname for fname in os.listdir(str(self.model_dir)) if fname.endswith(".h5")]
        retval = False if not model_files else os.path.commonprefix(model_files) == ""
        logger.debug("model_files: %s, retval: %s", model_files, retval)
        return retval

    @property
    def output_shapes(self):
        """ Return the output shapes from the main AutoEncoder """
        out = list()
        for predictor in self.predictors.values():
            out.extend([K.int_shape(output)[-3:] for output in predictor.outputs])
            break  # Only get output from one autoencoder. Shapes are the same
        return [tuple(shape) for shape in out]

    @property
    def output_shape(self):
        """ The output shape of the model (shape of largest face output) """
        return self.output_shapes[self.largest_face_index]

    @property
    def largest_face_index(self):
        """ Return the index from model.outputs of the largest face
            Required for multi-output model prediction. The largest face
            is assumed to be the final output
        """
        sizes = [shape[1] for shape in self.output_shapes if shape[2] == 3]
        if not sizes:
            return None
        max_face = max(sizes)
        retval = [idx for idx, shape in enumerate(self.output_shapes)
                  if shape[1] == max_face and shape[2] == 3][0]
        logger.debug(retval)
        return retval

    @property
    def largest_mask_index(self):
        """ Return the index from model.outputs of the largest mask
            Required for multi-output model prediction. The largest face
            is assumed to be the final output
        """
        sizes = [shape[1] for shape in self.output_shapes if shape[2] == 1]
        if not sizes:
            return None
        max_mask = max(sizes)
        retval = [idx for idx, shape in enumerate(self.output_shapes)
                  if shape[1] == max_mask and shape[2] == 1][0]
        logger.debug(retval)
        return retval

    @property
    def feed_mask(self):
        """ bool: ``True`` if the model expects a mask to be fed into input otherwise ``False`` """
        return self.config["mask_type"] is not None and (self.config["learn_mask"] or
                                                         self.config["penalized_mask_loss"])

    def load_config(self):
        """ Load the global config for reference in self.config """
        global _CONFIG  # pylint: disable=global-statement
        if not _CONFIG:
            model_name = self.config_section
            logger.debug("Loading config for: %s", model_name)
            _CONFIG = Config(model_name, configfile=self.configfile).config_dict

    def calculate_coverage_ratio(self):
        """ Coverage must be a ratio, leading to a cropped shape divisible by 2 """
        coverage_ratio = self.config.get("coverage", 62.5) / 100
        logger.debug("Requested coverage_ratio: %s", coverage_ratio)
        cropped_size = (self.state.training_size * coverage_ratio) // 2 * 2
        coverage_ratio = cropped_size / self.state.training_size
        logger.debug("Final coverage_ratio: %s", coverage_ratio)
        return coverage_ratio

    def build(self):
        """ Build the model. Override for custom build methods """
        self.add_networks()
        self.load_models(swapped=False)
        inputs = self.get_inputs()
        try:
            self.build_autoencoders(inputs)
        except ValueError as err:
            if "must be from the same graph" in str(err).lower():
                msg = ("There was an error loading saved weights. This is most likely due to "
                       "model corruption during a previous save."
                       "\nYou should restore weights from a snapshot or from backup files. "
                       "You can use the 'Restore' Tool to restore from backup.")
                raise FaceswapError(msg) from err
            if "multi_gpu_model" in str(err).lower():
                raise FaceswapError(str(err)) from err
            raise err
        self.log_summary()
        self.compile_predictors(initialize=True)

    def get_inputs(self):
        """ Return the inputs for the model """
        logger.debug("Getting inputs")
        inputs = [Input(shape=self.input_shape, name="face_in")]
        output_network = [network for network in self.networks.values() if network.is_output][0]
        if self.feed_mask:
            # TODO penalized mask doesn't have a mask output, so we can't use output shapes
            # mask should always be last output..this needs to be a rule
            mask_shape = output_network.output_shapes[-1]
            inputs.append(Input(shape=(mask_shape[1:-1] + (1,)), name="mask_in"))
        logger.debug("Got inputs: %s", inputs)
        return inputs

    def build_autoencoders(self, inputs):
        """ Override for Model Specific autoencoder builds

            Inputs is defined in self.get_inputs() and is standardized for all models
                if will generally be in the order:
                [face (the input for image),
                 mask (the input for mask if it is used)]
        """
        raise NotImplementedError

    def add_networks(self):
        """ Override to add neural networks """
        raise NotImplementedError

    def load_state_info(self):
        """ Load the input shape from state file if it exists """
        logger.debug("Loading Input Shape from State file")
        if not self.state.inputs:
            logger.debug("No input shapes saved. Using model config")
            return
        if not self.state.face_shapes:
            logger.warning("Input shapes stored in State file, but no matches for 'face'."
                           "Using model config")
            return
        input_shape = self.state.face_shapes[0]
        logger.debug("Setting input shape from state file: %s", input_shape)
        self.input_shape = input_shape

    def add_network(self, network_type, side, network, is_output=False):
        """ Add a NNMeta object """
        logger.debug("network_type: '%s', side: '%s', network: '%s', is_output: %s",
                     network_type, side, network, is_output)
        filename = "{}_{}".format(self.name, network_type.lower())
        name = network_type.lower()
        if side:
            side = side.lower()
            filename += "_{}".format(side.upper())
            name += "_{}".format(side)
        filename += ".h5"
        logger.debug("name: '%s', filename: '%s'", name, filename)
        self.networks[name] = NNMeta(str(self.model_dir / filename),
                                     network_type,
                                     side,
                                     network,
                                     is_output)

    def add_predictor(self, side, model):
        """ Add a predictor to the predictors dictionary """
        logger.debug("Adding predictor: (side: '%s', model: %s)", side, model)
        if self.gpus > 1:
            logger.debug("Converting to multi-gpu: side %s", side)
            model = multi_gpu_model(model, self.gpus)
        self.predictors[side] = model
        if not self.state.inputs:
            self.store_input_shapes(model)

    def store_input_shapes(self, model):
        """ Store the input and output shapes to state """
        logger.debug("Adding input shapes to state for model")
        inputs = {tensor.name: K.int_shape(tensor)[-3:] for tensor in model.inputs}
        if not any(inp for inp in inputs.keys() if inp.startswith("face")):
            raise ValueError("No input named 'face' was found. Check your input naming. "
                             "Current input names: {}".format(inputs))
        # Make sure they are all ints so that it can be json serialized
        inputs = {key: tuple(int(i) for i in val) for key, val in inputs.items()}
        self.state.inputs = inputs
        logger.debug("Added input shapes: %s", self.state.inputs)

    def reset_pingpong(self):
        """ Reset the models for pingpong training.

        Clears the predictors and the Keras session, rebuilds every network from the ``config``
        and ``weights`` held on its NNMeta object, then rebuilds the autoencoders and
        recompiles the predictors with ``initialize=False`` (so session loss names and loss
        history are not reset).

        NOTE(review): NNMeta captures ``config``/``weights`` when it is created; confirm they are
        refreshed (e.g. on save) before this runs, otherwise earlier weights would be restored.
        """
        logger.debug("Resetting models")

        # Clear models and graph
        self.predictors = dict()
        K.clear_session()

        # Load Models for current training run
        for model in self.networks.values():
            model.network = Model.from_config(model.config)
            model.network.set_weights(model.weights)

        inputs = self.get_inputs()
        self.build_autoencoders(inputs)
        self.compile_predictors(initialize=False)
        logger.debug("Reset models")

    def compile_predictors(self, initialize=True):
        """ Compile the predictors """
        logger.debug("Compiling Predictors")
        learning_rate = self.config.get("learning_rate", 5e-5)
        optimizer = self.get_optimizer(lr=learning_rate, beta_1=0.5, beta_2=0.999)

        for side, model in self.predictors.items():
            loss = Loss(model.inputs, model.outputs)
            model.compile(optimizer=optimizer, loss=loss.funcs)
            if initialize:
                self.state.add_session_loss_names(side, loss.names)
                self.history[side] = list()
        logger.debug("Compiled Predictors. Losses: %s", loss.names)

    def get_optimizer(self, lr=5e-5, beta_1=0.5, beta_2=0.999):  # pylint: disable=invalid-name
        """ Build and return Optimizer """
        opt_kwargs = dict(lr=lr, beta_1=beta_1, beta_2=beta_2)
        if (self.config.get("clipnorm", False) and
                keras.backend.backend() != "plaidml.keras.backend"):
            # NB: Clip-norm is ballooning VRAM usage, which is not expected behavior
            # and may be a bug in Keras/Tensorflow.
            # PlaidML has a bug regarding the clip-norm parameter
            # See: https://github.com/plaidml/plaidml/issues/228
            # Workaround by simply removing it.
            # TODO: Remove this as soon it is fixed in PlaidML.
            opt_kwargs["clipnorm"] = 1.0
        logger.debug("Optimizer kwargs: %s", opt_kwargs)
        return Adam(**opt_kwargs, cpu_mode=self.vram_savings.optimizer_savings)

    def converter(self, swap):
        """ Converter for autoencoder models """
        logger.debug("Getting Converter: (swap: %s)", swap)
        side = "a" if swap else "b"
        model = self.predictors[side]
        if self.predict:
            # Must compile the model to be thread safe
            model._make_predict_function()  # pylint: disable=protected-access
        retval = model.predict
        logger.debug("Got Converter: %s", retval)
        return retval

    @property
    def iterations(self):
        "Get current training iteration number"
        return self.state.iterations

    def map_models(self, swapped):
        """ Map the models for A/B side for swapping """
        logger.debug("Map models: (swapped: %s)", swapped)
        models_map = {"a": dict(), "b": dict()}
        sides = ("a", "b") if not swapped else ("b", "a")
        for network in self.networks.values():
            if network.side == sides[0]:
                models_map["a"][network.type] = network.filename
            if network.side == sides[1]:
                models_map["b"][network.type] = network.filename
        logger.debug("Mapped models: (models_map: %s)", models_map)
        return models_map

    def log_summary(self):
        """ Verbose log the model summaries """
        if self.predict:
            return
        for side in sorted(list(self.predictors.keys())):
            logger.verbose("[%s %s Summary]:", self.name.title(), side.upper())
            self.predictors[side].summary(print_fn=lambda x: logger.verbose("%s", x))
            for name, nnmeta in self.networks.items():
                if nnmeta.side is not None and nnmeta.side != side:
                    continue
                logger.verbose("%s:", name.title())
                nnmeta.network.summary(print_fn=lambda x: logger.verbose("%s", x))

    def do_snapshot(self):
        """ Perform a model snapshot """
        logger.debug("Performing snapshot")
        self.backup.snapshot_models(self.iterations)
        logger.debug("Performed snapshot")

    def load_models(self, swapped):
        """ Load models from file """
        logger.debug("Load model: (swapped: %s)", swapped)

        if not self.models_exist and not self.predict:
            logger.info("Creating new '%s' model in folder: '%s'", self.name, self.model_dir)
            return None
        if not self.models_exist and self.predict:
            logger.error("Model could not be found in folder '%s'. Exiting", self.model_dir)
            exit(0)

        if not self.is_legacy or not self.predict:
            K.clear_session()
        model_mapping = self.map_models(swapped)
        for network in self.networks.values():
            if not network.side:
                is_loaded = network.load()
            else:
                is_loaded = network.load(fullpath=model_mapping[network.side][network.type])
            if not is_loaded:
                break
        if is_loaded:
            logger.info("Loaded model from disk: '%s'", self.model_dir)
        return is_loaded

    def save_models(self):
        """ Backup and save the models """
        logger.debug("Backing up and saving models")
        # Insert a new line to avoid spamming the same row as loss output
        print("")
        save_averages = self.get_save_averages()
        backup_func = self.backup.backup_model if self.should_backup(save_averages) else None
        if backup_func:
            logger.info("Backing up models...")
        executor = futures.ThreadPoolExecutor()
        save_threads = [executor.submit(network.save, backup_func=backup_func)
                        for network in self.networks.values()]
        save_threads.append(executor.submit(self.state.save, backup_func=backup_func))
        futures.wait(save_threads)
        # call result() to capture errors
        _ = [thread.result() for thread in save_threads]
        msg = "[Saved models]"
        if save_averages:
            lossmsg = ["{}_{}: {:.5f}".format(self.state.loss_names[side][0],
                                              side.capitalize(),
                                              save_averages[side])
                       for side in sorted(list(save_averages.keys()))]
            msg += " - Average since last save: {}".format(", ".join(lossmsg))
        logger.info(msg)

    def get_save_averages(self):
        """ Return the average loss since the last save iteration and reset historical loss """
        logger.debug("Getting save averages")
        avgs = dict()
        for side, loss in self.history.items():
            if not loss:
                logger.debug("No loss in self.history: %s", side)
                break
            avgs[side] = sum(loss) / len(loss)
            self.history[side] = list()  # Reset historical loss
        logger.debug("Average losses since last save: %s", avgs)
        return avgs

    def should_backup(self, save_averages):
        """ Check whether the loss averages for all losses is the lowest that has been seen.

            This protects against model corruption by only backing up the model
            if any of the loss values have fallen.
            TODO This is not a perfect system. If the model corrupts on save_iteration - 1
            then model may still backup
        """
        backup = True

        if not save_averages:
            logger.debug("No save averages. Not backing up")
            return False

        for side, loss in save_averages.items():
            if not self.state.lowest_avg_loss.get(side, None):
                logger.debug("Setting initial save iteration loss average for '%s': %s",
                             side, loss)
                self.state.lowest_avg_loss[side] = loss
                continue
            if backup:
                # Only run this if backup is true. All losses must have dropped for a valid backup
                backup = self.check_loss_drop(side, loss)

        logger.debug("Lowest historical save iteration loss average: %s",
                     self.state.lowest_avg_loss)

        if backup:  # Update lowest loss values to the state
            for side, avg_loss in save_averages.items():
                logger.debug("Updating lowest save iteration average for '%s': %s", side, avg_loss)
                self.state.lowest_avg_loss[side] = avg_loss

        logger.debug("Backing up: %s", backup)
        return backup

    def check_loss_drop(self, side, avg):
        """ Check whether total loss has dropped since lowest loss """
        if avg < self.state.lowest_avg_loss[side]:
            logger.debug("Loss for '%s' has dropped", side)
            return True
        logger.debug("Loss for '%s' has not dropped", side)
        return False

    def rename_legacy(self):
        """ Rename files of legacy Original, LowMem and IAE models, which had inconsistent
        naming conventions, and update the state for them.

        Only runs for models named "iae" or "original". An old file is renamed only when it
        exists and its new name does not. If any file was renamed, the model is flagged as
        legacy, fixed legacy values are written to the state (input shape, training size and
        several config items), ``encoder_dim`` is set (512 if a LowMem file was renamed,
        otherwise 1024) and the state is saved.
        """
        legacy_mapping = {"iae": [("IAE_decoder.h5", "iae_decoder.h5"),
                                  ("IAE_encoder.h5", "iae_encoder.h5"),
                                  ("IAE_inter_A.h5", "iae_intermediate_A.h5"),
                                  ("IAE_inter_B.h5", "iae_intermediate_B.h5"),
                                  ("IAE_inter_both.h5", "iae_inter.h5")],
                          # Legacy Original and LowMem files map to the same new names; whichever
                          # is renamed first wins, as later renames are skipped once the target
                          # file exists
                          "original": [("encoder.h5", "original_encoder.h5"),
                                       ("decoder_A.h5", "original_decoder_A.h5"),
                                       ("decoder_B.h5", "original_decoder_B.h5"),
                                       ("lowmem_encoder.h5", "original_encoder.h5"),
                                       ("lowmem_decoder_A.h5", "original_decoder_A.h5"),
                                       ("lowmem_decoder_B.h5", "original_decoder_B.h5")]}
        if self.name not in legacy_mapping.keys():
            return
        logger.debug("Renaming legacy files")

        set_lowmem = False
        updated = False
        for old_name, new_name in legacy_mapping[self.name]:
            old_path = os.path.join(str(self.model_dir), old_name)
            new_path = os.path.join(str(self.model_dir), new_name)
            if os.path.exists(old_path) and not os.path.exists(new_path):
                logger.info("Updating legacy model name from: '%s' to '%s'", old_name, new_name)
                os.rename(old_path, new_path)
                if old_name.startswith("lowmem"):
                    set_lowmem = True
                updated = True

        if not updated:
            logger.debug("No legacy files to rename")
            return

        # Legacy models have no saved state, so write the fixed values they were trained with
        self.is_legacy = True
        logger.debug("Creating state file for legacy model")
        self.state.inputs = {"face:0": [64, 64, 3]}
        self.state.training_size = 256
        self.state.config["coverage"] = 62.5
        self.state.config["reflect_padding"] = False
        self.state.config["mask_type"] = None
        self.state.config["mask_blur_kernel"] = 3
        self.state.config["mask_threshold"] = 4
        self.state.config["learn_mask"] = False
        self.state.config["lowmem"] = False
        self.encoder_dim = 1024

        if set_lowmem:
            logger.debug("Setting encoder_dim and lowmem flag for legacy lowmem model")
            self.encoder_dim = 512
            self.state.config["lowmem"] = True

        self.state.replace_config(self.config_changeable_items)
        self.state.save()


class VRAMSavings():
    """ VRAM Saving training methods.

    Each requested option is validated against the active Keras backend and stored on the
    instance as a boolean. Every option is disabled, with a warning, on the PlaidML backend.

    Parameters
    ----------
    pingpong: bool
        Whether Pingpong training has been requested.
    optimizer_savings: bool
        Whether Optimizer Savings have been requested.
    memory_saving_gradients: bool
        Whether Memory Saving Gradients have been requested. When enabled, the Keras backend
        ``gradients`` function is monkey-patched.
    """
    def __init__(self, pingpong, optimizer_savings, memory_saving_gradients):
        logger.debug("Initializing %s: (pingpong: %s, optimizer_savings: %s, "
                     "memory_saving_gradients: %s)", self.__class__.__name__,
                     pingpong, optimizer_savings, memory_saving_gradients)
        self.is_plaidml = keras.backend.backend() == "plaidml.keras.backend"
        self.pingpong = self.set_pingpong(pingpong)
        self.optimizer_savings = self.set_optimizer_savings(optimizer_savings)
        self.memory_saving_gradients = self.set_gradient_type(memory_saving_gradients)
        logger.debug("Initialized: %s", self.__class__.__name__)

    def set_pingpong(self, pingpong):
        """ Return the pingpong flag, disabled for plaidML users """
        if pingpong and self.is_plaidml:
            logger.warning("Pingpong training not supported on plaidML. Disabling")
            pingpong = False
        logger.debug("pingpong: %s", pingpong)
        if pingpong:
            logger.info("Using Pingpong Training")
        return pingpong

    def set_optimizer_savings(self, optimizer_savings):
        """ Return the optimizer savings flag, disabled for plaidML users """
        # is_plaidml is already a bool: comparing it to the backend name (as before) was always
        # False, so optimizer savings were never disabled on PlaidML
        if optimizer_savings and self.is_plaidml:
            logger.warning("Optimizer Savings not supported on plaidML. Disabling")
            optimizer_savings = False
        logger.debug("optimizer_savings: %s", optimizer_savings)
        if optimizer_savings:
            logger.info("Using Optimizer Savings")
        return optimizer_savings

    def set_gradient_type(self, memory_saving_gradients):
        """ Return the memory saving gradients flag (disabled for plaidML users) and, if it is
        enabled, monkey-patch Keras gradients with Memory Saving Gradients """
        if memory_saving_gradients and self.is_plaidml:
            logger.warning("Memory Saving Gradients not supported on plaidML. Disabling")
            memory_saving_gradients = False
        logger.debug("memory_saving_gradients: %s", memory_saving_gradients)
        if memory_saving_gradients:
            logger.info("Using Memory Saving Gradients")
            # Aliased: importing as ``memory_saving_gradients`` rebinds the argument, so the
            # module object (not the flag) used to be returned
            from lib.model import memory_saving_gradients as msg_module
            K.__dict__["gradients"] = msg_module.gradients_memory
        return memory_saving_gradients


class Loss():
    """ Holds the loss names and loss functions for an Autoencoder.

    Parameters
    ----------
    inputs: list
        The model's input tensors. A mask input is one whose name starts with "mask".
    outputs: list
        The model's output tensors. One loss name and one loss function is built per output.

    Attributes
    ----------
    names: list
        ``<output>_loss`` for each output, with ``total_loss`` prepended when there is more
        than one output.
    funcs: list
        The loss function for each output, in output order.
    """
    def __init__(self, inputs, outputs):
        logger.debug("Initializing %s: (inputs: %s, outputs: %s)",
                     self.__class__.__name__, inputs, outputs)
        self.inputs = inputs
        self.outputs = outputs
        self.names = self.get_loss_names()
        # Functions are built before "total_loss" is prepended so they line up with outputs
        self.funcs = self.get_loss_functions()
        if len(self.names) > 1:
            self.names.insert(0, "total_loss")
        logger.debug("Initialized: %s", self.__class__.__name__)

    @property
    def loss_dict(self):
        """ dict: Loss name to loss function. A new ``DSSIMObjective`` is created on each
        access. """
        return {"mae": losses.mean_absolute_error,
                "mse": losses.mean_squared_error,
                "logcosh": losses.logcosh,
                "smooth_loss": generalized_loss,
                "l_inf_norm": l_inf_norm,
                "ssim": DSSIMObjective(),
                "gmsd": gmsd_loss,
                "pixel_gradient_diff": gradient_loss}

    @property
    def config(self):
        """ dict: The module level config cache (``_CONFIG``) """
        return _CONFIG

    @property
    def mask_preprocessing_func(self):
        """ The mask pre-processing function: a gaussian blur sized from dimension 1 of the
        mask shape when ``mask_blur`` is enabled in the config, otherwise ``None``. """
        if not self.config.get("mask_blur", False):
            func = None
        else:
            func = gaussian_blur(max(1, self.mask_shape[1] // 32))
        logger.debug(func)
        return func

    @property
    def selected_loss(self):
        """ The loss function named by the ``loss_function`` config item (default "mae") """
        loss_key = self.config.get("loss_function", "mae")
        func = self.loss_dict[loss_key]
        logger.debug(func)
        return func

    @property
    def selected_mask_loss(self):
        """ The mask loss: mse wrapped by ``mask_loss_wrapper`` together with the mask
        pre-processing function (which may be ``None``). """
        base_loss = self.loss_dict["mse"]
        preprocess = self.mask_preprocessing_func
        logger.debug("loss_func: %s, func: %s", base_loss, preprocess)
        return mask_loss_wrapper(base_loss, preprocessing_func=preprocess)

    @property
    def output_shapes(self):
        """ list: The shape of each output node, excluding the first (batch) dimension """
        return [K.int_shape(tensor)[1:] for tensor in self.outputs]

    @property
    def mask_input(self):
        """ The first input whose name starts with "mask", or ``None`` """
        return next((tensor for tensor in self.inputs if tensor.name.startswith("mask")), None)

    @property
    def mask_shape(self):
        """ tuple or None: The mask input's shape excluding the first dimension, or ``None``
        when there is no mask input """
        mask = self.mask_input
        return None if mask is None else K.int_shape(mask)[1:]

    def get_loss_names(self):
        """ Return one ``<name>_loss`` entry per output.

        The name is the part of the output tensor name between the first and last "/", with
        "_out" removed. If any resulting name does not start with "face" or "mask", names are
        generated from the output shapes instead (see ``update_loss_names``).
        """
        output_names = [tensor.name for tensor in self.outputs]
        logger.debug("Model output names: %s", output_names)
        loss_names = []
        for name in output_names:
            start = name.find("/") + 1
            end = name.rfind("/")
            loss_names.append(name[start:end].replace("_out", ""))
        if not all(name.startswith(("face", "mask")) for name in loss_names):
            # Handle incorrectly named/legacy outputs
            logger.debug("Renaming loss names from: %s", loss_names)
            loss_names = self.update_loss_names()
        loss_names = [name + "_loss" for name in loss_names]
        logger.debug(loss_names)
        return loss_names

    def update_loss_names(self):
        """ Name outputs "mask" (1 channel) or "face" (otherwise) from their shapes, appending
        the output index when more than one output has the same type """
        output_types = ["mask" if shape[-1] == 1 else "face" for shape in self.output_shapes]
        loss_names = []
        for idx, output_type in enumerate(output_types):
            suffix = "" if output_types.count(output_type) == 1 else "_{}".format(idx)
            loss_names.append(output_type + suffix)
        logger.debug("Renamed loss names to: %s", loss_names)
        return loss_names

    def get_loss_functions(self):
        """ Return the loss function for each name in ``self.names``.

        Mask outputs get the mask loss. Other outputs get a ``PenalizedLoss`` (scaled by the
        ratio of the output's dimension 1 to the mask's) when penalized mask loss is enabled
        and a mask type is set, otherwise the selected loss.
        """
        loss_funcs = []
        for idx, loss_name in enumerate(self.names):
            if loss_name.startswith("mask"):
                func = self.selected_mask_loss
            elif self.config["penalized_mask_loss"] and self.config["mask_type"] is not None:
                face_size = self.output_shapes[idx][1]
                mask_size = self.mask_shape[1]
                scaling = face_size / mask_size
                logger.debug("face_size: %s mask_size: %s, mask_scaling: %s",
                             face_size, mask_size, scaling)
                func = PenalizedLoss(self.mask_input, self.selected_loss,
                                     mask_scaling=scaling,
                                     preprocessing_func=self.mask_preprocessing_func)
            else:
                func = self.selected_loss
            loss_funcs.append(func)
            logger.debug("%s: %s", loss_name, func)
        logger.debug(loss_funcs)
        return loss_funcs


class NNMeta():
    """ Class to hold a neural network and its meta data

    filename:   The full path and filename of the model file for this network.
    type:       The type of network. For networks that can be swapped
                the type should be identical for the corresponding
                A and B networks, and should be unique for every A/B pair.
                Otherwise the type should be completely unique.
    side:       A, B or None. Used to identify which networks can
                be swapped.
    network:    The network (model) held by this object. It is replaced by the saved model
                when :func:`load` succeeds.
    is_output:  Set to True to indicate that this network is an output to the Autoencoder
    """

    def __init__(self, filename, network_type, side, network, is_output):
        logger.debug("Initializing %s: (filename: '%s', network_type: '%s', side: '%s', "
                     "network: %s, is_output: %s", self.__class__.__name__, filename,
                     network_type, side, network, is_output)
        self.filename = filename
        self.type = network_type.lower()
        self.side = side
        self.name = self.set_name()
        self.network = network
        self.is_output = is_output
        # Name the network "<type>" or "<type>_<side>" so A and B networks stay distinct
        self.network.name = self.name
        self.config = network.get_config()  # For pingpong restore
        self.weights = network.get_weights()  # For pingpong restore
        logger.debug("Initialized %s", self.__class__.__name__)

    @property
    def output_shapes(self):
        """ list: The output shapes of the stored network, one entry per output """
        return [K.int_shape(output) for output in self.network.outputs]

    def set_name(self):
        """ Return the network name.

        Returns
        -------
        str
            The network type, suffixed with ``_<side>`` when a side is set
        """
        name = self.type
        if self.side:
            name += "_{}".format(self.side)
        return name

    @property
    def output_names(self):
        """ list: The names of the network's output nodes.

        For output networks without any output named ``face_out...`` (legacy models), names
        generated by :func:`get_output_names` are returned instead.
        """
        output_names = [output.name for output in self.network.outputs]
        if self.is_output and not any(name.startswith("face_out") for name in output_names):
            # Saved models break if their layer names are changed, so dummy
            # in correct output names for legacy models
            output_names = self.get_output_names()
        return output_names

    def get_output_names(self):
        """ Generate output names from each output's channel count.

        Outputs with a single channel are named ``mask_out``; all others are ``face_out``. When a
        name occurs more than once, each occurrence is suffixed with the output's positional index.

        Returns
        -------
        list
            One generated name per network output
        """
        output_types = ["mask_out" if K.int_shape(output)[-1] == 1 else "face_out"
                        for output in self.network.outputs]
        output_names = ["{}{}".format(name,
                                      "" if output_types.count(name) == 1 else "_{}".format(idx))
                        for idx, name in enumerate(output_types)]
        logger.debug("Overridden output_names: %s", output_names)
        return output_names

    def load(self, fullpath=None):
        """ Load the saved model and replace the held network with it.

        Parameters
        ----------
        fullpath: str, optional
            Full path to the model file to load. Defaults to :attr:`filename`

        Returns
        -------
        bool
            ``True`` if the model was loaded (or a legacy weights file was converted), ``False``
            if loading failed and new models should be generated
        """
        fullpath = fullpath if fullpath else self.filename
        logger.debug("Loading model: '%s'", fullpath)
        try:
            # Load from the resolved path. Previously `self.filename` was always used, so an
            # explicit `fullpath` was logged but silently ignored.
            network = load_model(fullpath, custom_objects=get_custom_objects())
        except ValueError as err:
            # This error message is treated as the signature of a legacy (weights only) file.
            # NOTE: the legacy conversion operates on `self.filename`.
            if str(err).lower().startswith("cannot create group in read only mode"):
                self.convert_legacy_weights()
                return True
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        except OSError as err:
            logger.warning("Failed loading existing training data. Generating new models")
            logger.debug("Exception: %s", str(err))
            return False
        self.config = network.get_config()
        self.network = network  # Update network with saved model
        self.network.name = self.name
        return True

    def save(self, fullpath=None, backup_func=None):
        """ Save the held network, caching its current weights first.

        Parameters
        ----------
        fullpath: str, optional
            Full path to save the model to. Defaults to :attr:`filename`
        backup_func: callable, optional
            If given, called with the destination path before saving
        """
        fullpath = fullpath if fullpath else self.filename
        if backup_func:
            backup_func(fullpath)
        logger.debug("Saving model: '%s'", fullpath)
        self.weights = self.network.get_weights()
        self.network.save(fullpath)

    def convert_legacy_weights(self):
        """ Convert a legacy weights file so that it also holds the model topology.

        The weights in :attr:`filename` are loaded into the held network, which is then saved back
        to :attr:`filename` as a full model.
        """
        logger.info("Adding model topology to legacy weights file: '%s'", self.filename)
        self.network.load_weights(self.filename)
        self.save(backup_func=None)
        # Use the same side-qualified name as __init__/load so that A and B networks of the same
        # type do not end up with identical names
        self.network.name = self.name


class State():
    """ Class to hold the model's current state and autoencoder structure """
    def __init__(self, model_dir, model_name, config_changeable_items,
                 no_logs, pingpong, training_image_size):
        """ Set up the model state, load any saved state file and open a new session.

        Parameters
        ----------
        model_dir: :class:`pathlib.Path`
            Folder holding the model files. It is joined with the ``/`` operator, so it is
            presumably a :class:`pathlib.Path`; confirm with the caller.
        model_name: str
            Name of the model. Used to build the state file name.
        config_changeable_items: dict
            Config items that may differ from those stored in the state file
        no_logs: bool
            Recorded in the new session's information
        pingpong: bool
            Recorded in the new session's information
        training_image_size: int
            Training size used when no saved state overrides it
        """
        logger.debug("Initializing %s: (model_dir: '%s', model_name: '%s', "
                     "config_changeable_items: '%s', no_logs: %s, pingpong: %s, "
                     "training_image_size: '%s'", self.__class__.__name__, model_dir, model_name,
                     config_changeable_items, no_logs, pingpong, training_image_size)
        self.serializer = get_serializer("json")
        state_filename = "{}_state.{}".format(model_name, self.serializer.file_extension)
        self.filename = str(model_dir / state_filename)
        self.name = model_name
        self.iterations = self.session_iterations = 0
        self.training_size = training_image_size
        # Filled in from the saved state file (when one exists) by load()
        self.sessions, self.lowest_avg_loss, self.inputs, self.config = {}, {}, {}, {}
        self.load(config_changeable_items)
        # Must come after load() so the new id accounts for previously saved sessions
        self.session_id = self.new_session_id()
        self.create_new_session(no_logs, pingpong, config_changeable_items)
        logger.debug("Initialized %s:", self.__class__.__name__)

    @property
    def face_shapes(self):
        """ Return a list of stored face shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("face")]

    @property
    def mask_shapes(self):
        """ Return a list of stored mask shape inputs """
        return [tuple(val) for key, val in self.inputs.items() if key.startswith("mask")]

    @property
    def loss_names(self):
        """ Return the loss names for this session """
        return self.sessions[self.session_id]["loss_names"]

    @property
    def current_session(self):
        """ Return the current session dict """
        return self.sessions[self.session_id]

    @property
    def first_run(self):
        """ Return True if this is the first run else False """
        return self.session_id == 1

    def new_session_id(self):
        """ Return the id for a new session: one more than the highest existing id, or 1.

        Session keys are converted with ``int()`` because they may be strings, presumably after a
        round trip through the JSON state file.
        """
        next_id = max((int(existing) for existing in self.sessions), default=0) + 1
        logger.debug(next_id)
        return next_id

    def create_new_session(self, no_logs, pingpong, config_changeable_items):
        """ Register a fresh session information dictionary under :attr:`session_id`.

        Loss names and batch size start empty/zero and are filled in later through
        ``add_session_loss_names`` and ``add_session_batchsize``.
        """
        logger.debug("Creating new session. id: %s", self.session_id)
        session_info = dict(timestamp=time.time(),
                            no_logs=no_logs,
                            pingpong=pingpong,
                            loss_names={},
                            batchsize=0,
                            iterations=0,
                            config=config_changeable_items)
        self.sessions[self.session_id] = session_info

    def add_session_loss_names(self, side, loss_names):
        """ Record the loss names for one side in the current session's information.

        Parameters
        ----------
        side: str
            The key to store the loss names under (presumably the model side, e.g. "a"/"b";
            confirm with callers)
        loss_names: list
            The loss names used for this side
        """
        # Closing parenthesis added: the log format string previously left it unbalanced
        logger.debug("Adding session loss_names. (side: '%s', loss_names: %s)", side, loss_names)
        self.sessions[self.session_id]["loss_names"][side] = loss_names

    def add_session_batchsize(self, batchsize):
        """ Store the batch size used by the current session.

        Parameters
        ----------
        batchsize: int
            The batch size to record
        """
        logger.debug("Adding session batchsize: %s", batchsize)
        self.current_session["batchsize"] = batchsize

    def increment_iterations(self):
        """ Increment total and session iterations """
        self.iterations += 1
        self.sessions[self.session_id]["iterations"] += 1

    def load(self, config_changeable_items):
        """ Populate this state from the saved state file, if one exists.

        Items missing from the file fall back to defaults. After loading, the config stored in
        the state file replaces the current config (see ``replace_config``).

        Parameters
        ----------
        config_changeable_items: dict
            Config items that may override the values saved in the state file
        """
        logger.debug("Loading State")
        if not os.path.exists(self.filename):
            logger.info("No existing state file found. Generating.")
            return
        saved = self.serializer.load(self.filename)
        self.name = saved.get("name", self.name)
        self.iterations = saved.get("iterations", 0)
        # 256 matches ModelBase's default training_image_size; presumably older state files did
        # not record a training size
        self.training_size = saved.get("training_size", 256)
        self.sessions = saved.get("sessions", {})
        self.lowest_avg_loss = saved.get("lowest_avg_loss", {})
        self.inputs = saved.get("inputs", {})
        self.config = saved.get("config", {})
        logger.debug("Loaded state: %s", saved)
        self.replace_config(config_changeable_items)

    def save(self, backup_func=None):
        """ Write the current state to the state file.

        Parameters
        ----------
        backup_func: callable, optional
            If given, called with the state file path before it is overwritten
        """
        logger.debug("Saving State")
        if backup_func:
            backup_func(self.filename)
        # The config written is the module-level _CONFIG, not self.config
        snapshot = dict(name=self.name,
                        sessions=self.sessions,
                        lowest_avg_loss=self.lowest_avg_loss,
                        iterations=self.iterations,
                        inputs=self.inputs,
                        training_size=self.training_size,
                        config=_CONFIG)
        self.serializer.save(self.filename, snapshot)
        logger.debug("Saved State")

    def replace_config(self, config_changeable_items):
        """ Make the config saved in the state file the active module-level config.

        Legacy items are upgraded first. Any items present in the current config but missing from
        the saved one are then added, and changeable items are applied. If a legacy upgrade
        happened, the state file is re-saved.

        Parameters
        ----------
        config_changeable_items: dict
            Config items that may override the values saved in the state file
        """
        global _CONFIG  # pylint: disable=global-statement
        legacy_update = self._update_legacy_config()
        # Back-fill config items that did not exist when the state file was written
        missing_items = [(key, val) for key, val in _CONFIG.items() if key not in self.config]
        for key, val in missing_items:
            logger.info("Adding new config item to state file: '%s': '%s'", key, val)
            self.config[key] = val
        self.update_changed_config_items(config_changeable_items)
        logger.debug("Replacing config. Old config: %s", _CONFIG)
        _CONFIG = self.config
        if legacy_update:
            self.save()
        logger.debug("Replaced config. New config: %s", _CONFIG)
        logger.info("Using configuration saved in state file")

    def _update_legacy_config(self):
        """ Legacy updates for new config additions.

        When new config items are added to the Faceswap code, existing model state files need to be
        updated to handle these new items.

        Current existing legacy update items:

            * loss - If old `dssim_loss` is ``true`` set new `loss_function` to `ssim` otherwise
            set it to `mae`. Remove old `dssim_loss` item

            * masks - If `learn_mask` does not exist then it is set to ``True`` if `mask_type` is
            not ``None`` otherwise it is set to ``False``.

            * masks type - Replace removed masks 'dfl_full' and 'facehull' with `components` mask

        The updates are applied in place to :attr:`config`, in the order listed above.

        Returns
        -------
        bool
            ``True`` if any legacy item was updated in :attr:`config`, otherwise ``False``. The
            caller (``replace_config``) re-saves the state file when this is ``True``.
        """
        logger.debug("Checking for legacy state file update")
        # Parallel lists of (legacy key, new key) pairs, processed in order. "mask_type" appears
        # twice because it drives two separate updates: adding `learn_mask` and replacing removed
        # mask types.
        priors = ["dssim_loss", "mask_type", "mask_type"]
        new_items = ["loss_function", "learn_mask", "mask_type"]
        updated = False
        for old, new in zip(priors, new_items):
            if old not in self.config:
                logger.debug("Legacy item '%s' not in config. Skipping update", old)
                continue

            # dssim_loss > loss_function
            if old == "dssim_loss":
                self.config[new] = "ssim" if self.config[old] else "mae"
                del self.config[old]
                updated = True
                logger.info("Updated config from legacy dssim format. New config loss "
                            "function: '%s'", self.config[new])
                continue

            # Add learn_mask option if it is missing: True if the model has a mask_type set,
            # otherwise False
            if old == "mask_type" and new == "learn_mask" and new not in self.config:
                self.config[new] = self.config["mask_type"] is not None
                updated = True
                logger.info("Added new 'learn_mask' config item for this model. Value set to: %s",
                            self.config[new])
                continue

            # Replace removed masks with most similar equivalent
            if old == "mask_type" and new == "mask_type" and self.config[old] in ("facehull",
                                                                                  "dfl_full"):
                old_mask = self.config[old]
                self.config[new] = "components"
                updated = True
                logger.info("Updated 'mask_type' from '%s' to '%s' for this model",
                            old_mask, self.config[new])

        logger.debug("State file updated for legacy config: %s", updated)
        return updated

    def update_changed_config_items(self, config_changeable_items):
        """ Update any parameters which are not fixed and have been changed """
        if not config_changeable_items:
            logger.debug("No changeable parameters have been updated")
            return
        for key, val in config_changeable_items.items():
            old_val = self.config[key]
            if old_val == val:
                continue
            self.config[key] = val
            logger.info("Config item: '%s' has been updated from '%s' to '%s'", key, old_val, val)