import numpy as np
from six import iteritems
from multiprocessing import cpu_count
from typing import Tuple, List
from logging import Logger
from tensorflow.python.client import device_lib
from keras.models import load_model as load_model_keras
from keras.callbacks import Callback
from keras.engine.training import Model
from pathlib import Path


def load_model(model_path: Path) -> Model:
    """Load a saved Keras model from disk.

    Args:
        model_path: Location of the saved model. It is converted to ``str``
            before being handed to ``keras.models.load_model``.

    Returns:
        Whatever ``keras.models.load_model`` returns for that path.
    """
    path_as_text = str(model_path)
    return load_model_keras(path_as_text)


def use_multiprocessing() -> Tuple[bool, int]:
    """Decide whether batch preprocessing should use multiprocessing.

    Returns:
        A ``(use_multiprocessing, num_workers)`` pair. When at least one GPU
        is reported, multiprocessing is enabled with one worker per CPU;
        otherwise it is disabled and a single worker is used.
    """
    gpu_present = bool(_get_available_gpus())
    # With a GPU available, use all CPUs for batch preprocessing;
    # on a CPU-only machine fall back to a single worker.
    worker_count = cpu_count() if gpu_present else 1
    return gpu_present, worker_count


def _get_available_gpus() -> List[str]:
    """Return the names of all local devices whose type is ``'GPU'``.

    The device list comes from ``device_lib.list_local_devices()``; an empty
    list is returned when none of the reported devices is a GPU.
    """
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names


class LoggingMetrics(Callback):
    """Callback for logging metrics at the end of each epoch.

    Every entry of the ``logs`` dict is rendered as ``name: value`` with four
    decimals, the entries are joined with ``' - '`` and the resulting line is
    emitted at DEBUG level.

    Args:
        logger: Logger that receives the per-epoch message.
    """

    def __init__(self, logger: Logger) -> None:
        super().__init__()
        self.logger = logger
        self.format_epoch = 'Epoch: {} - {}'
        self.format_keyvalue = '{}: {:0.4f}'
        self.format_separator = ' - '

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Log all metrics of the epoch that just finished.

        Args:
            epoch: Zero-based epoch index; it is reported one-based.
            logs: Mapping of metric name to numeric value. ``None`` is
                treated like an empty mapping.
        """
        # A fresh dict per call avoids sharing a mutable default argument and
        # keeps the callback working when it is invoked with logs=None.
        logs = logs or {}
        values = self.format_separator.join(
            self.format_keyvalue.format(name, value) for name, value in logs.items()
        )
        msg = self.format_epoch.format(epoch + 1, values)
        self.logger.debug(msg)


class LoggingModels(Callback):
    """Callback that saves the model after epochs and reports through a logger.

    Args:
        filepath: Destination of the saved model. The path is passed through
            ``str.format`` with ``epoch`` (one-based) and every key of the
            epoch ``logs``, so it may contain placeholders such as
            ``{epoch}`` or ``{val_loss}``.
        logger: Logger that receives warnings and progress messages.
        monitor: Key of ``logs`` to compare when ``save_best_only`` is set.
        verbose: Progress messages are logged only when this is greater
            than 0.
        save_best_only: If true, save only when the monitored value improves
            on the best value seen so far.
        save_weights_only: If true, call ``model.save_weights`` instead of
            ``model.save``.
        mode: ``'min'``, ``'max'`` or ``'auto'``. In ``'auto'`` mode the
            monitored value is maximised when its name contains ``'acc'`` or
            starts with ``'fmeasure'``, and minimised otherwise. Unknown
            values are logged as a warning and treated as ``'auto'``.
        period: Number of epochs between two checkpoints.
    """

    def __init__(
        self,
        filepath: Path,
        logger: Logger,
        monitor: str = 'val_loss',
        verbose: int = 0,
        save_best_only: bool = False,
        save_weights_only: bool = False,
        mode: str = 'auto',
        period: int = 1,
    ) -> None:
        # Let the base class set up its own state, as LoggingMetrics does.
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.period = period
        self.epochs_since_last_save = 0
        self.logger = logger

        if mode not in ['auto', 'min', 'max']:
            self.logger.warning(
                'ModelCheckpoint mode {} is unknown, fallback to auto mode.'.format(mode)
            )
            mode = 'auto'

        if mode == 'auto':
            # Accuracy-like metrics improve upwards, everything else downwards.
            maximise = 'acc' in self.monitor or self.monitor.startswith('fmeasure')
        else:
            maximise = mode == 'max'

        # np.inf is used instead of the np.Inf alias, which NumPy 2.0 removed.
        if maximise:
            self.monitor_op = np.greater
            self.best = -np.inf
        else:
            self.monitor_op = np.less
            self.best = np.inf

    def _save_model(self, filepath: Path) -> None:
        """Write the weights or the full model to ``filepath``, overwriting it."""
        if self.save_weights_only:
            self.model.save_weights(str(filepath), overwrite=True)
        else:
            self.model.save(str(filepath), overwrite=True)

    def on_epoch_end(self, epoch: int, logs: dict = None):
        """Save the model if a checkpoint is due for this epoch.

        Args:
            epoch: Zero-based epoch index; it is reported and formatted into
                the file path one-based.
            logs: Mapping of metric name to value for the finished epoch.
                ``None`` is treated like an empty mapping.
        """
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save < self.period:
            return

        # The counter restarts whenever a checkpoint is due, even if nothing
        # ends up being written below.
        self.epochs_since_last_save = 0
        filepath = Path(str(self.filepath).format(epoch=epoch + 1, **logs))

        if not self.save_best_only:
            if self.verbose > 0:
                self.logger.debug(
                    '\nEpoch {:05d} saving model to {}'.format(epoch + 1, filepath)
                )
            self._save_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            self.logger.warning(
                'Can save best model only with {} available, skipping.'.format(self.monitor)
            )
            return

        if not self.monitor_op(current, self.best):
            if self.verbose > 0:
                self.logger.info(
                    '\nEpoch {:05d} {} did not improve from {:0.5f}'.format(
                        epoch + 1, self.monitor, self.best
                    )
                )
            return

        if self.verbose > 0:
            self.logger.info(
                '\nEpoch {:05d} {} improved from {:0.5f} to {:0.5f},\nsaving model to {}'.format(
                    epoch + 1, self.monitor, self.best, current, filepath
                )
            )
        self.best = current
        self._save_model(filepath)