from __future__ import division
from __future__ import absolute_import
from __future__ import print_function

import simplejson
from threading import Lock
import os


def is_generator(obj):
    import inspect

    return obj is not None and (
        inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)
        or hasattr(obj, 'next') or hasattr(obj, '__next__'))
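
# Illustrative usage (not part of the module): is_generator accepts generator
# functions, generator objects, and any object exposing the iterator protocol
# via next (Python 2) or __next__ (Python 3).
#
#     def batches():
#         yield [1, 2, 3]
#
#     is_generator(batches)       # True: generator function
#     is_generator(batches())     # True: generator object
#     is_generator(iter([1, 2]))  # True: has __next__ / next
#     is_generator([1, 2])        # False: a plain list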


class Trainer(object):
    """
    :type job_backend: aetros.backend.JobBackend
    :type settings: dict
    """

    def __init__(self, job_backend):
        self.job_backend = job_backend
        self.input_shape = []

        # Training sample count per epoch when the input is a generator;
        # same name as in Keras' fit_generator.
        self.samples_per_epoch = None
        # Validation sample count when the input is a generator;
        # same name as in Keras' fit_generator.
        self.nb_val_samples = None
        self.nb_val_steps = None

        self.callbacks = []  # used by simple models
        self.classes = None  # set by auto_dataset
        self.output_size = None  # output layer size for the code generator, set by auto_dataset
        self.model = None
        self.model = None

        self.job_model = job_backend.get_job_model()
        self.settings = {}

        classes_file = self.job_backend.git.work_tree + '/aetros/job/info/classes.json'
        if os.path.exists(classes_file):
            with open(classes_file) as f:
                self.classes = simplejson.load(f)
                self.output_size = len(self.classes)
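        # Assumed format (illustrative): classes.json is a flat JSON array of
        # label names, e.g. ["cat", "dog", "bird"], so len(self.classes)
        # directly gives the size of the output layer.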

        if 'batchSize' in self.job_backend.job['config']:
            self.settings['batchSize'] = self.job_backend.job['config']['batchSize']

        if 'epochs' in self.job_backend.job['config']:
            self.settings['epochs'] = self.job_backend.job['config']['epochs']

        self.lock = Lock()

    @property
    def logger(self):
        return self.job_backend.logger

    def set_model(self, model):
        self.model = model

    def get_batch_size(self):
        return self.job_backend.get_job_model().get_batch_size()

    def set_generator_validation_nb(self, number):
        """
        Sets self.nb_val_samples, which is passed to model.fit_generator when
        the input is a generator. The value is rounded up to the next multiple
        of the batch size; under Keras 2 it is converted from a sample count
        to a step count.
        :param number: number of validation samples
        """

        self.nb_val_samples = number
        diff_to_batch = number % self.get_batch_size()
        if diff_to_batch > 0:
            self.nb_val_samples += self.get_batch_size() - diff_to_batch

        import keras
        # Keras 2 expects a step (batch) count instead of a sample count.
        if '1' != keras.__version__[0]:
            self.nb_val_samples = self.nb_val_samples // self.get_batch_size()
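
    # Worked example (illustrative): with a batch size of 32 and number=100,
    # nb_val_samples is first rounded up to 128, the next multiple of 32;
    # under Keras 2 it is then divided by the batch size, yielding 4, which
    # matches fit_generator's step-based validation semantics.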

    def set_generator_training_nb(self, number):
        """
        Sets self.samples_per_epoch, which is passed to model.fit_generator
        when the input is a generator. The value is rounded up to the next
        multiple of the batch size.
        :param number: number of training samples per epoch
        """

        self.samples_per_epoch = number
        diff_to_batch = number % self.get_batch_size()
        if diff_to_batch > 0:
            self.samples_per_epoch += self.get_batch_size() - diff_to_batch
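
    # Worked example (illustrative): with a batch size of 32 and number=100,
    # samples_per_epoch becomes 128, so every epoch consumes whole batches.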

    def set_status(self, status):
        self.job_backend.set_status(status)

    def set_info(self, name, value):
        self.job_backend.set_info(name, value)

    def has_generator(self, values):
        """Returns True if any value of the given dict is a generator."""
        for v in values.values():
            if is_generator(v):
                return True

        return False

    def get_first_generator(self, values):
        """Returns the first generator value of the given dict, or None."""
        for v in values.values():
            if is_generator(v):
                return v

        return None
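
    # Illustrative usage (hypothetical data dict, not from this module):
    #
    #     data = {'x': my_batch_generator(), 'y': None}
    #     trainer.has_generator(data)        # True
    #     trainer.get_first_generator(data)  # the generator stored under 'x'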

    def set_job_system_info(self, key, value):
        self.job_backend.set_system_info(key, value)

    def set_job_info(self, key, value):
        self.job_backend.set_info(key, value)