from lasagne.layers import Layer, InputLayer, DropoutLayer, DenseLayer, NonlinearityLayer, ElemwiseMergeLayer, SliceLayer, ConcatLayer, get_output, get_all_params
from lasagne.objectives import categorical_crossentropy, squared_error, aggregate
from lasagne.nonlinearities import softmax, rectify, sigmoid
from lasagne.updates import sgd, adagrad, adam, adadelta, apply_nesterov_momentum
from lasagne.init import Constant
import theano.tensor as T
import theano
import numpy as np

from utils import BatchIterator
from neuralforestlayer import NeuralForestLayer

__DEBUG_NO_FOREST__ = False

class ShallowNeuralForest:
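    """Shallow neural network whose output layer is a neural decision forest.

    A minimal usage sketch (hypothetical shapes; one-hot targets are assumed
    for classification):

        model = ShallowNeuralForest(n_inputs=20, n_outputs=3, regression=False)
        model.fit(X_train, y_train, X_val, y_val)
        proba = model.predict_proba(X_test)
        labels = model.predict(X_test)
    """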
    def __init__(self, n_inputs, n_outputs, regression, multiclass=False, depth=5, n_estimators=20, n_hidden=128, learning_rate=0.01, num_epochs=500, pi_iters=20, sgd_iters=10, batch_size=1000, momentum=0.0, dropout=0.0, loss=None, update=adagrad):
        """
        Parameters
        ----------
        n_inputs : number of input features
        n_outputs : number of classes to predict (1 for regression)
            for 2 class classification n_outputs should be 2, not 1
        regression : True for regression, False for classification
        multiclass : not used
        depth : depth of each tree in the ensemble
        n_estimators : number of trees in the ensemble
        n_hidden : number of neurons in the hidden layer
        pi_iters : number of iterations for the iterative algorithm that updates pi
        sgd_iters : number of full iterations of sgd between two consequtive updates of pi
        loss : theano loss function. If None, squared error will be used for regression and
            cross entropy will be used for classification
        update : theano update function
        """
        self._depth = depth
        self._n_estimators = n_estimators
        self._n_hidden = n_hidden
        self._n_outputs = n_outputs
        self._loss = loss
        self._regression = regression
        self._multiclass = multiclass
        self._learning_rate = learning_rate
        self._num_epochs = num_epochs
        self._pi_iters = pi_iters
        self._sgd_iters = sgd_iters
        self._batch_size = batch_size
        self._momentum = momentum
        self._update = update

        self.t_input = T.matrix('input')
        self.t_label = T.matrix('output')
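        # labels come in as a 2-D matrix: one-hot rows for classification,
        # shape (n_samples, 1) for regression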

        self._cached_trainable_params = None
        self._cached_params = None

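        # one sigmoid output per decision node: a tree of depth d has 2^d - 1
        # internal nodes, so the network feeds n_estimators * (2^depth - 1)
        # values into the forest layer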
        self._n_net_out = n_estimators * ((1 << depth) - 1)

        self.l_input = InputLayer((None, n_inputs))
        self.l_dense1 = DenseLayer(self.l_input, self._n_hidden, nonlinearity=rectify)
        if dropout != 0:
            self.l_dense1 = DropoutLayer(self.l_dense1, p=dropout)
        if not __DEBUG_NO_FOREST__:
            self.l_dense2 = DenseLayer(self.l_dense1, self._n_net_out, nonlinearity=sigmoid)
            self.l_forest = NeuralForestLayer(self.l_dense2, self._depth, self._n_estimators, self._n_outputs, self._pi_iters)
        else:
            self.l_forest = DenseLayer(self.l_dense1, self._n_outputs, nonlinearity=softmax)

    def _create_functions(self):
        # build the loss expression once and reuse it for the updates and the
        # compiled training function
        loss = self._get_loss_function()
        params = self._get_all_trainable_params()
        self._update_func = self._update(loss, params, self._learning_rate)
        if self._momentum != 0:
            self._update_func = apply_nesterov_momentum(self._update_func, params, self._momentum)
        self._loss_func = loss
        self._train_function = theano.function([self.t_input, self.t_label], loss, updates=self._update_func)

    def fit(self, X, y, X_val=None, y_val=None, on_epoch=None, verbose=False):
        """ Train the model

        Parameters
        ----------
        X : input vector for the training set
        y : output vector for the training set. One-hot encoding is required
            for classification
        X_val : if not None, input vector for the validation set
        y_val : if not None, output vector for the validation set
        on_epoch : a callback that is called after each epoch
            if X_val is None, the signature is (epoch, training_error)
            if X_val is not None, the signature is (epoch, training_error, validation_error, accuracy)
            on epochs that update pi, the training error reported is the one
            from the previous epoch
        verbose : if True, prints the current step on each epoch
        """
        self._create_functions()

        X = X.astype(np.float32)
        y = y.astype(np.float32)
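        # standardize the inputs; constant columns get a std of 1 so the
        # division below cannot divide by zero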
        self._x_mean = np.mean(X, axis=0)
        self._x_std = np.std(X, axis=0)
        self._x_std[self._x_std == 0] = 1
        X = (X - self._x_mean) / self._x_std
        if y_val is not None:
            assert X_val is not None
            X_val = X_val.astype(np.float32)
            y_val = y_val.astype(np.float32)
            X_val = (X_val - self._x_mean) / self._x_std

        if X_val is not None:
            assert y_val is not None

            predictions = self._predict_internal(self._get_output())
            accuracy = T.mean(T.eq(predictions, self._predict_internal(self.t_label)))

            test_function = theano.function([self.t_input, self.t_label], [self._get_loss_function(), accuracy])

        iterator = BatchIterator(self._batch_size)

        loss = 0
        for epoch in range(self._num_epochs):

            # update the values of pi
            if not __DEBUG_NO_FOREST__ and epoch % self._sgd_iters == 0:
                if verbose: print "updating pi"
                self.l_forest.update_pi(X, y)
                if verbose: print "recreating update funcs"
                self._create_functions()

            else:
                if verbose: print "updating theta"
                loss = 0
                deno = 0
                # update the network parameters
                for Xb, yb in iterator(X, y):
                    loss += self._train_function(Xb, yb)
                    deno += 1

                loss /= deno

            if X_val is not None:
                tloss = 0
                accur = 0
                deno = 0
                iterator = BatchIterator(self._batch_size)
                for Xb, yb in iterator(X_val, y_val):
                    tl, ac = test_function(Xb, yb)
                    tloss += tl
                    accur += ac
                    deno += 1
                tloss /= deno
                accur /= deno

            if on_epoch is not None:
                if X_val is None:
                    on_epoch(epoch, loss)
                else:
                    on_epoch(epoch, loss, tloss, accur)

        return self

    def _predict_internal(self, y):
        if self._regression:
            # regression outputs are returned unchanged
            return y
        if self._multiclass:
            # multi-label: threshold each output independently
            return y >= 0.5
        # single-label classification: pick the most probable class
        return y.argmax(axis=1)

    def predict(self, X):
        ret = self.predict_proba(X)
        return self._predict_internal(ret)

    def predict_proba(self, X):
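        # note: this compiles a fresh theano function on every call; callers
        # that predict repeatedly may want to cache it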
        X = X.astype(np.float32)
        X = (X - self._x_mean) / self._x_std
        predict_function = theano.function([self.t_input], self._get_output())
        return predict_function(X)

    def _get_loss_function(self):
        if self._loss is None:
            if self._regression:
                self._loss = squared_error
            else:
                self._loss = categorical_crossentropy
        return aggregate(self._loss(self._get_output(), self.t_label), mode='mean')

    def _get_output(self):
        return get_output(self.l_forest, self.t_input)

    def _get_all_trainable_params(self):
        if self._cached_trainable_params is None:
            self._cached_trainable_params = get_all_params(self.l_forest, trainable=True)
        return self._cached_trainable_params
    
    def _get_all_params(self):
        if self._cached_params is None:
            self._cached_params = get_all_params(self.l_forest)
        return self._cached_params
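
# Minimal smoke test sketch (hypothetical random data, 3 classes, one-hot
# targets; assumes the imports above resolve):
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    X = rng.randn(300, 10).astype(np.float32)
    labels = rng.randint(0, 3, size=300)
    y = np.eye(3, dtype=np.float32)[labels]  # one-hot encode the class labels

    model = ShallowNeuralForest(10, 3, regression=False, num_epochs=20, batch_size=100)
    model.fit(X, y, verbose=True)
    print model.predict(X)[:10]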