from __future__ import absolute_import
import os

import h5py
from keras.models import Sequential
from keras.optimizers import Adadelta
from keras.layers.core import Dense, Activation

from . import kerashack


# Public API of this module: the only name exported by a star-import.
__all__ = [
    'TrainingRun',
]


class TrainingRun(object):
    """A Keras model bundled with the counters needed to resume its training.

    The whole state is checkpointed to a single HDF5 file containing two
    groups: ``model`` (written and read through the ``kerashack`` helpers)
    and ``metadata`` (whose attributes hold the progress counters).

    Attributes:
        filename: Path of the HDF5 checkpoint file.
        model: The model object being trained.
        epochs_completed: Number of full passes over all chunks so far.
        chunks_completed: Number of chunks finished in the current epoch.
        num_chunks: Number of chunks that make up one epoch.
    """

    def __init__(self, filename, model, epochs_completed, chunks_completed, num_chunks):
        """Wrap an existing model and its progress counters.

        Nothing is read from or written to disk here; use `load` to restore
        a checkpoint or `create` to start a new run.
        """
        self.filename = filename
        self.model = model
        self.epochs_completed = epochs_completed
        self.chunks_completed = chunks_completed
        self.num_chunks = num_chunks

    def save(self):
        """Write the model and progress counters to ``self.filename``.

        An existing checkpoint is first moved aside to ``<filename>.bak``.
        If writing the new checkpoint fails for any reason, the partially
        written file is removed and the previous checkpoint is moved back
        before the error propagates, so ``self.filename`` never ends up
        pointing at a truncated file while a good one existed.
        """
        # Backup the original file in case something goes wrong while
        # saving the new checkpoint.
        backup = None
        if os.path.exists(self.filename):
            backup = self.filename + '.bak'
            os.rename(self.filename, backup)

        saved = False
        try:
            # The context manager closes the HDF5 file even if a write fails.
            with h5py.File(self.filename, 'w') as output:
                model_out = output.create_group('model')
                kerashack.save_model_to_hdf5_group(self.model, model_out)
                metadata = output.create_group('metadata')
                metadata.attrs['epochs_completed'] = self.epochs_completed
                metadata.attrs['chunks_completed'] = self.chunks_completed
                metadata.attrs['num_chunks'] = self.num_chunks
            saved = True
        finally:
            # A flag + finally (rather than except/raise) also covers
            # KeyboardInterrupt and re-raises the original error untouched.
            if not saved:
                self._rollback_failed_save(backup)

        # If we got here, we no longer need the backup.
        if backup is not None:
            os.unlink(backup)

    def _rollback_failed_save(self, backup):
        """Drop a partially written checkpoint and restore `backup`, if any."""
        # Without this, the next save() would rename the truncated file over
        # the backup and destroy the last good checkpoint.
        if os.path.exists(self.filename):
            os.unlink(self.filename)
        if backup is not None:
            os.rename(backup, self.filename)

    def complete_chunk(self):
        """Record that one more chunk was trained, then checkpoint.

        When the chunk counter reaches ``num_chunks`` the epoch counter is
        incremented and the chunk counter wraps back to zero. The updated
        state is saved to disk on every call.
        """
        self.chunks_completed += 1
        # `>=` rather than `==`: if the counter is ever already past
        # num_chunks (e.g. num_chunks is 0), equality would never fire and
        # the epoch counter would be stuck forever.
        if self.chunks_completed >= self.num_chunks:
            self.epochs_completed += 1
            self.chunks_completed = 0
        self.save()

    @classmethod
    def load(cls, filename):
        """Restore a training run from the checkpoint at `filename`.

        Args:
            filename: Path to an HDF5 file previously produced by `save`.

        Returns:
            A new instance holding the stored model and progress counters.
        """
        # The context manager closes the file even if reading fails part-way.
        with h5py.File(filename, 'r') as inp:
            model = kerashack.load_model_from_hdf5_group(inp['model'])
            attrs = inp['metadata'].attrs
            # h5py hands scalar attributes back as NumPy scalars; convert so
            # the counters are ordinary Python ints like in a fresh run.
            return cls(filename,
                       model,
                       int(attrs['epochs_completed']),
                       int(attrs['chunks_completed']),
                       int(attrs['num_chunks']))

    @classmethod
    def create(cls, filename, index, layer_fn):
        """Build and compile a new model, then save it as a fresh checkpoint.

        Args:
            filename: Path where the checkpoint file is written.
            index: Object whose ``num_chunks`` attribute gives the number of
                chunks per epoch.
            layer_fn: Callable that takes the input shape ``(7, 19, 19)`` and
                returns an iterable of layers to add to the model, in order.

        Returns:
            The new training run, with both progress counters at zero.
        """
        # NOTE(review): (7, 19, 19) and the 19 * 19 output size look like
        # feature planes over a 19x19 board with one class per point --
        # confirm against the data encoder.
        model = Sequential()
        for layer in layer_fn((7, 19, 19)):
            model.add(layer)
        model.add(Dense(19 * 19))
        model.add(Activation('softmax'))
        opt = Adadelta(clipnorm=0.25)
        model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
        training_run = cls(filename, model, 0, 0, index.num_chunks)
        training_run.save()
        return training_run