from __future__ import print_function

from lasagne.utils import floatX
import theano.tensor as T

from dcase_task2.lasagne_wrapper.optimization_objectives import mean_categorical_crossentropy, mean_pixel_binary_crossentropy, \
    mean_pixel_categorical_crossentropy
from dcase_task2.lasagne_wrapper.learn_rate_shedules import get_constant
from dcase_task2.lasagne_wrapper.parameter_updates import get_update_adam
from dcase_task2.lasagne_wrapper.batch_iterators import get_batch_iterator


def get_classification_TrainingStrategy(**kwargs):
    """
    Defines training strategy for neural network
    """
    return TrainingStrategy(**kwargs)


def get_binary_segmentation_TrainingStrategy(**kwargs):
    """
    Defines training strategy for neural network
    """    
    return TrainingStrategy(y_tensor_type=T.tensor4, objective=mean_pixel_binary_crossentropy, report_dices=True, **kwargs)


def get_categorical_segmentation_TrainingStrategy(**kwargs):
    """
    Defines training strategy for neural network
    """    
    return TrainingStrategy(y_tensor_type=T.itensor4, objective=mean_pixel_categorical_crossentropy, **kwargs)


def get_next_step_one_hot_TrainingStrategy(**kwargs):
    """
    Defines training strategy for next step prediction of classes (e.g. one hot vector character rnns)
    """
    return TrainingStrategy(**kwargs)
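

# Illustrative usage sketch (example values, not part of the original module):
# each factory above forwards **kwargs to TrainingStrategy, so task-specific
# defaults are fixed by the factory while generic settings can still be
# overridden at the call site.
def _example_factory_usage():
    """ Build example strategies via the factory functions above """
    # plain classification with a smaller batch size and longer training
    clf_strategy = get_classification_TrainingStrategy(batch_size=32, max_epochs=200)
    # binary segmentation: target tensor type, objective and dice reporting
    # are already fixed by the factory, only the learning rate is overridden
    seg_strategy = get_binary_segmentation_TrainingStrategy(ini_learning_rate=0.0005)
    return clf_strategy, seg_strategy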


class RefinementStrategy(object):
    """
    Defines refinement strategy for neural network
    Once a model does not improve anymore during training this will be applied
    """

    def __init__(self, n_refinement_steps=2, refinement_patience=5, learn_rate_multiplier=0.1):
        """
        Constructor
        """
        self.n_refinement_steps = n_refinement_steps
        self.refinement_patience = refinement_patience
        self.learn_rate_multiplier = learn_rate_multiplier

    def adapt_learn_rate(self, lr):
        """ Update learning rate """
        self.n_refinement_steps -= 1
        return floatX(self.learn_rate_multiplier * lr)
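

# Illustrative sketch (assumed initial rate of 0.001, not part of the original
# module): with the defaults above, each refinement multiplies the learning
# rate by 0.1, i.e. 0.001 -> 1e-4 -> 1e-5 over two refinement steps.
def _example_refinement_schedule(initial_lr=0.001):
    """ Apply a RefinementStrategy until its refinement steps are used up """
    strategy = RefinementStrategy(n_refinement_steps=2, learn_rate_multiplier=0.1)
    lr = floatX(initial_lr)
    while strategy.n_refinement_steps > 0:
        # adapt_learn_rate also decrements n_refinement_steps
        lr = strategy.adapt_learn_rate(lr)
    return lr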


class TrainingStrategy(object):
    """
    Defines training strategy for neural network
    """

    def __init__(self, batch_size=100, ini_learning_rate=0.001, max_epochs=100, patience=10, y_tensor_type=T.ivector,
                 L2=1e-4, objective=mean_categorical_crossentropy, adapt_learn_rate=get_constant(),
                 update_function=get_update_adam(), valid_batch_iter=get_batch_iterator(),
                 train_batch_iter=get_batch_iterator(), use_weights=False, samples_per_epoch=None,
                 shuffle_train=True, report_dices=False, refinement_strategy=RefinementStrategy(),
                 best_model_by_accuracy=False, debug_mode=False, layer_update_filter=None, report_map=3):
        """
        Constructor
        """
        self.batch_size = batch_size
        self.ini_learning_rate = ini_learning_rate
        self.max_epochs = max_epochs
        self.patience = patience
        self.y_tensor_type = y_tensor_type
        self.L2 = L2
        self.objective = objective
        self.adapt_learn_rate = adapt_learn_rate
        self.update_function = update_function
        self.valid_batch_iter = valid_batch_iter
        self.train_batch_iter = train_batch_iter
        self.use_weights = use_weights
        self.samples_per_epoch = samples_per_epoch
        self.shuffle_train = shuffle_train
        self.report_dices = report_dices
        self.refinement_strategy = refinement_strategy
        self.best_model_by_accuracy = best_model_by_accuracy
        self.debug_mode = debug_mode
        self.layer_update_filter = layer_update_filter
        self.report_map = report_map

    def update_learning_rate(self, lr, epoch):
        """ Update learning rate """
        return self.adapt_learn_rate(lr, epoch)

    def update_parameters(self, all_grads, all_params, learning_rate):
        """ Compute updates from gradients """
        return self.update_function(all_grads, all_params, learning_rate)

    def build_valid_batch_iterator(self):
        """ Compile batch iterator """
        return self.valid_batch_iter(self.batch_size, k_samples=None, shuffle=False)

    def build_train_batch_iterator(self):
        """ Compile batch iterator """
        return self.train_batch_iter(self.batch_size, k_samples=self.samples_per_epoch, shuffle=self.shuffle_train)
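

# Illustrative sketch (example values, not part of the original module): a
# fully customized TrainingStrategy. get_constant() keeps the learning rate
# fixed, so update_learning_rate simply returns the rate it is given.
def _example_custom_strategy():
    """ Configure a TrainingStrategy for a small classification experiment """
    strategy = TrainingStrategy(
        batch_size=50,
        ini_learning_rate=0.0005,
        max_epochs=50,
        patience=5,
        L2=1e-5,
        refinement_strategy=RefinementStrategy(n_refinement_steps=3),
    )
    # learning rate update for epoch 0 (constant schedule: unchanged)
    lr = strategy.update_learning_rate(strategy.ini_learning_rate, epoch=0)
    # instantiate the training batch iterator exactly as the trainer would
    train_iter = strategy.build_train_batch_iterator()
    return strategy, lr, train_iter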