"""
The Trainer class prepares and launches training of a model.
Most importantly, it compiles the tf.keras Model object with according to the
specified optimizer, loss and metrics and implements the .fit method for
training the model given a set of parameters and (non-initialized) callbacks.
"""

from multiprocessing import cpu_count
from tensorflow.keras import optimizers, losses
from tensorflow.python.framework.errors_impl import (ResourceExhaustedError,
                                                     InternalError)
from MultiPlanarUNet.callbacks import (init_callback_objects,
                                       remove_validation_callbacks)
from MultiPlanarUNet.logging import ScreenLogger
from MultiPlanarUNet.callbacks import (DividerLine, LearningCurve)
from MultiPlanarUNet.utils import ensure_list_or_tuple
from MultiPlanarUNet.train.utils import (ensure_sparse, init_losses,
                                         init_metrics)
from utime.callbacks import Validation
from utime.train.utils import get_steps


class Trainer(object):
    """
    Handles initialization and logging of model fitting sessions.
    """
    def __init__(self, model, org_model=None, logger=None):
        """
        Init. simply accepts a model and stores it.
        Optionally, an 'org_model' (original model) may be passed and stored
        as well. This is for training multi-GPU models prepared by the
        tf.keras.utils.multi_gpu_model utility, which returns a new, split
        model for training (passed as 'model' parameter here). For properly
        saving the model parameter, however, it is recommended to use the
        original, non-split model (here passed as 'org_model').

        Args:
            model:      (tf.keras Model) Initialized model to train
            org_model:  (tf.keras Model) Optional single-GPU version for the
                                         passed 'model' parameter.
            logger:     (Logger)         Optional Logger instance
        """
        self.model = model
        self.logger = logger if logger is not None else ScreenLogger()

        # Extra reference to original (non multiple-GPU) model
        # May also be set from a script at a later time (before self.fit call)
        self.org_model = org_model

    def compile_model(self, optimizer, optimizer_kwargs, loss, metrics, **kwargs):
        """
        Compile the stored tf.keras Model instance stored in self.model
        Sets the loss function, optimizer and metrics

        Args:
            optimizer:        (string) The name of a tf.keras.optimizers Optimizer
            optimizer_kwargs: (dict)   Key-word arguments passed to the Optimizer
            loss:             (string) The name of a tf.keras.losses or
                                       MultiPlanarUnet loss function
            metrics:          (list)   List of tf.keras.metrics or
                                       MultiPlanarUNet metrics.
            **kwargs:         (dict)   Key-word arguments passed to losses
                                       and/or metrics that accept such.
        """
        # Make sure sparse metrics and loss are specified as sparse
        metrics = ensure_list_or_tuple(metrics)
        losses = ensure_list_or_tuple(loss)
        ensure_sparse(metrics+losses)

        # Initialize optimizer
        optimizer = optimizers.__dict__[optimizer]
        optimizer = optimizer(**optimizer_kwargs)

        # Initialize loss(es) and metrics from tf.keras or MultiPlanarUNet
        losses = init_losses(losses, self.logger, **kwargs)
        metrics = init_metrics(metrics, self.logger, **kwargs)

        # Compile the model
        self.model.compile(optimizer=optimizer, loss=losses, metrics=metrics)
        self.logger("Optimizer:   %s" % optimizer)
        self.logger("Loss funcs:  %s" % losses)
        self.logger("Metrics:     %s" % init_metrics)
        return self

    def fit(self, batch_size, **fit_kwargs):
        """
        Fit the stored tf.keras Model (self.model) on a set of data.

        The 'fit' method is a wrapper around the hidden '_fit' method. It
        handles KeyboardInterrupts (--> stopping training prematurely), TF
        GPU memory errors (--> batch_size is reduced by 2 and training
        restarted), and other exceptions (--> error logged and training
        terminated).

        Please refer to the self._fit method for 'fit_kwargs' argument details.

        Args:
            batch_size: (int)  The initial batch size to run training with
            fit_kwargs: (dict) Keyword arguments passed to self._fit
        """
        fitting = True
        while fitting:
            try:
                self._fit(batch_size=batch_size, **fit_kwargs)
                fitting = False
            except (ResourceExhaustedError, InternalError):
                # Reduce batch size
                batch_size -= 2
                self.logger("\n\n[MEMORY ERROR] Reducing batch size "
                            "by 2 (now %i)" % batch_size)
                if batch_size < 1:
                    self.logger("[ERROR] Batch size negative or zero!")
                    fitting = False
            except KeyboardInterrupt:
                fitting = False
            except Exception as e:
                self.logger(e)
                raise e
        self.logger.print_calling_method = True
        self.logger("Training stopped.")
        return self.model

    def _fit(self,
             train,
             val,
             batch_size,
             n_epochs,
             callbacks,
             train_samples_per_epoch,
             val_samples_per_epoch,
             verbose=1,
             init_epoch=0,
             use_multiprocessing=False,
             **unused):
        """
        Args:
            train: (Sequence)       The training Sequence object
            val    (Sequence, None) The validation Sequence object or None if no
                                    validation is to be performed
            batch_size: (int)       The batch size to use for training
            n_epochs: (int)         Number of epochs to train for
            callbacks: (list)       List of uninitialized callback kwargs.
            train_samples_per_epoch: (int) Number of training samples to sample
                                           before an epoch is determined over.
            val_samples_per_epoch:   (int) Same as 'train_samples_per_epoch'
            verbose: (int/bool)     Verbosity level passed to keras.fit_generator
            init_epoch: (int)       The initial epoch
            use_multiprocessing: (bool) Whether to use multiprocessing instead
                                        of multithreading.
        """
        train.batch_size = batch_size
        train_steps = get_steps(train_samples_per_epoch, train)
        self.logger("Using %i steps per train epoch (total batches=%i)" %
                    (train_steps, len(train)))

        if val is None or len(val) == 0:
            # No validation to be performed, remove callbacks that might need
            # validation data to function properly
            remove_validation_callbacks(callbacks, self.logger)
        else:
            val.batch_size = batch_size
            val_steps = get_steps(val_samples_per_epoch, val)
            self.logger("Using %i steps per validation epoch "
                        "(total batches=%i)" % (val_steps, len(val)))
            # Add validation callback
            # Important: Should be first in callbacks list as other CBs may
            # depend on the validation metrics/loss
            validation = Validation(val,
                                    steps=val_steps,
                                    logger=self.logger,
                                    verbose=verbose)
            callbacks = [validation] + callbacks

        # Callback for plotting learning curves
        callbacks.append(LearningCurve(logger=self.logger))
        callbacks = callbacks + [DividerLine(self.logger)]

        # Get initialized callback objects
        callbacks, cb_dict = init_callback_objects(callbacks, self.logger)

        # If ModelCheckPointClean is used, set the original model to store
        # the correct weights when using multi-GPU models
        cb = cb_dict.get("ModelCheckPointClean")
        if cb:
            cb.org_model = self.org_model

        # Fit the model
        self.logger.active_log_file = "training"
        self.logger.print_calling_method = False
        self.model.fit_generator(
            generator=train,
            steps_per_epoch=train_steps,
            epochs=n_epochs,
            callbacks=callbacks,
            initial_epoch=init_epoch,
            use_multiprocessing=use_multiprocessing,  # Normally False
            workers=min(7, cpu_count()-1),
            max_queue_size=25,
            shuffle=False,  # Determined by the chosen Sequence class
            verbose=verbose
        )