import keras.callbacks as callbacks

import h5py
import numpy as np
import yaml


class MetaCheckpoint(callbacks.ModelCheckpoint):
    """
    Checkpoints some training information with the model. This should enable
    resuming training and having training information on every checkpoint.

    After the parent ``ModelCheckpoint`` has handled the epoch, a ``meta``
    group is written into the checkpoint's HDF5 file containing:

    * attribute ``training_args``: YAML dump of the training arguments;
    * dataset ``epochs``: every epoch index recorded so far;
    * one dataset per key of the current ``logs``: that key's history.

    Arguments:
        filepath, monitor, verbose, save_best_only, save_weights_only,
        mode, period: forwarded unchanged to ``ModelCheckpoint``.
        training_args: optional dict, or object exposing ``__dict__``
            (``vars()`` is applied to it), stored under
            ``meta['training_args']``.
        meta: optional dict of previously recorded information to continue
            from; a fresh ``{'epochs': []}`` is used when it is empty/None.

    Thanks to Roberto Estevao @robertomest - robertomest@poli.ufrj.br
    """

    def __init__(self, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', period=1, training_args=None, meta=None):

        # Forward the caller's settings; previously hard-coded defaults were
        # passed here, silently ignoring every argument but ``filepath``.
        super(MetaCheckpoint, self).__init__(
            filepath, monitor=monitor, verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode, period=period)

        self.filepath = filepath
        self.meta = meta or {'epochs': []}
        # A caller-supplied meta may lack the epoch history; make sure
        # on_epoch_end can always append to it.
        self.meta.setdefault('epochs', [])

        if training_args:
            # Accept a plain dict as well as an object with ``__dict__``.
            if not isinstance(training_args, dict):
                training_args = vars(training_args)

            self.meta['training_args'] = training_args

    def on_train_begin(self, logs=None):
        """Delegate to the parent callback; no extra work is done here."""
        super(MetaCheckpoint, self).on_train_begin(logs)

    def on_epoch_end(self, epoch, logs=None):
        """Let the parent checkpoint, record ``logs``, then store the meta.

        Arguments:
            epoch: epoch index passed in by the training loop.
            logs: optional dict of metric name -> value for this epoch.
        """
        logs = logs or {}

        # NOTE(review): relies on the parent updating ``self.best`` exactly
        # when it saves in save_best_only mode (Keras 2.x behaviour) --
        # confirm for the installed version.
        best_before = getattr(self, 'best', None)

        super(MetaCheckpoint, self).on_epoch_end(epoch, logs)

        # Get statistics
        self.meta['epochs'].append(epoch)
        for k, v in logs.items():
            # setdefault gets the value or sets (and gets) the default value
            self.meta.setdefault(k, []).append(v)

        # The counter is reset to 0 on the epochs the parent considers
        # saving; any other value means nothing was written this epoch.
        if self.epochs_since_last_save != 0:
            return

        # In save_best_only mode the counter is reset even when the monitored
        # value did not improve and no file was written; skip those epochs.
        if (getattr(self, 'save_best_only', False)
                and getattr(self, 'best', None) == best_before):
            return

        # Save to file
        # NOTE(review): some Keras releases name the checkpoint with
        # ``epoch + 1``; confirm this matches the installed version,
        # otherwise opening the file below fails.
        filepath = self.filepath.format(epoch=epoch, **logs)

        with h5py.File(filepath, 'r+') as f:
            # Defensive: replace a stale group instead of letting
            # create_group raise because the name already exists.
            if 'meta' in f:
                del f['meta']
            meta_group = f.create_group('meta')
            # Default to an empty mapping (not the string '{}') so the
            # attribute loads back from YAML as a dict.
            meta_group.attrs['training_args'] = yaml.dump(
                self.meta.get('training_args', {}))
            meta_group.create_dataset('epochs',
                                      data=np.array(self.meta['epochs']))
            for k in logs:
                meta_group.create_dataset(k, data=np.array(self.meta[k]))


class ProgbarLogger(callbacks.ProgbarLogger):
    """
    Progress-bar logger whose ``params['metrics']`` entry can be overridden.

    Arguments:
        show_metrics: optional value assigned to ``self.params['metrics']``
            when training begins; a falsy value leaves ``params`` untouched.
    """

    def __init__(self, show_metrics=None):
        super(ProgbarLogger, self).__init__()

        # Kept as-is; it is applied in on_train_begin.
        self.show_metrics = show_metrics

    def on_train_begin(self, logs=None):
        """Run the parent hook, then apply the ``show_metrics`` override."""
        super(ProgbarLogger, self).on_train_begin(logs)

        selected = self.show_metrics
        if not selected:
            # Nothing requested: keep whatever ``params`` already holds.
            return

        self.params['metrics'] = selected