from __future__ import print_function

import os
import sys
import time
import pickle
import itertools
import numpy as np

import theano
import lasagne
from lasagne.utils import floatX

import theano.tensor as T

from dcase_task2.lasagne_wrapper.utils import BColors, print_net_architecture
from dcase_task2.lasagne_wrapper.data_pool import DataPool
from dcase_task2.lasagne_wrapper.batch_iterators import threaded_generator_from_iterator


class Network(object):
    """
    Neural Network

    Wraps a lasagne output layer and provides training (`fit`), prediction
    (`predict_proba`, `predict`, `compute_layer_output`), guided-backprop
    saliency maps (`compute_saliency`) and parameter persistence
    (`save`, `load`). Theano functions are compiled lazily on first use and
    cached on the instance.
    """

    def __init__(self, net, print_architecture=True):
        """
        Constructor

        net: lasagne output layer of the network
        print_architecture: if True, print a detailed architecture summary
        """
        self.net = net
        self.compute_output = None  # compiled prediction function (lazy)
        self.compute_output_dict = dict()  # layer -> compiled prediction function
        self.saliency_function = None  # compiled guided-backprop function (lazy)
        self.iter_funcs = None  # compiled train / valid functions (lazy)

        # get input shape of network (first layer of the graph)
        l_in = lasagne.layers.helper.get_all_layers(self.net)[0]
        self.input_shape = l_in.output_shape

        if print_architecture:
            print_net_architecture(net, detailed=True)

    def fit(self, data, training_strategy, dump_file=None, log_file=None):
        """
        Train model.

        data: either a dict of raw arrays with keys 'X_train', 'y_train',
            'X_valid', 'y_valid' (wrapped into DataPools here) or a dict of
            data pools with keys 'train' and 'valid'.
        training_strategy: object providing the settings read below
            (learning-rate schedule, objective, batch iterators, patience,
            refinement strategy, ...).
        dump_file: optional path; the best parameters so far are pickled there
            whenever the model improves.
        log_file: optional path; per-epoch training curves are pickled there.

        Returns the best validation accuracy if
        `training_strategy.best_model_by_accurary` is set, otherwise the best
        validation loss. On return the network holds the best parameters.
        Note: `training_strategy.patience` is overwritten once refinement starts.
        """
        print("Training neural network...")
        col = BColors()

        # create data pool if raw data is given
        if "X_train" in data:
            data_pools = dict()
            data_pools['train'] = DataPool(data['X_train'], data['y_train'])
            data_pools['valid'] = DataPool(data['X_valid'], data['y_valid'])
        else:
            data_pools = data

        # make sure the output directories of dump and log file exist
        # (makedirs so that nested, not yet existing paths work as well)
        for file_path in (dump_file, log_file):
            if file_path is not None:
                out_path = os.path.dirname(file_path)
                if out_path != '' and not os.path.exists(out_path):
                    os.makedirs(out_path)

        # adaptive learning rate
        learn_rate = training_strategy.ini_learning_rate
        learning_rate = theano.shared(floatX(learn_rate))
        learning_rate.set_value(floatX(training_strategy.adapt_learn_rate(training_strategy.ini_learning_rate, 0)))

        # initialize evaluation output
        pred_tr_err, pred_val_err, overfitting = [], [], []
        tr_accs, va_accs = [], []

        print("Compiling theano train functions...")
        if self.iter_funcs is None:
            self.iter_funcs = self._create_iter_functions(y_tensor_type=training_strategy.y_tensor_type,
                                                          objective=training_strategy.objective, learning_rate=learning_rate,
                                                          l_2=training_strategy.L2,
                                                          compute_updates=training_strategy.update_parameters,
                                                          use_weights=training_strategy.use_weights,
                                                          debug_mode=training_strategy.debug_mode,
                                                          layer_update_filter=training_strategy.layer_update_filter)

        # initialize early stopping state before training starts, so that an
        # interrupt or an early stop without any improvement still finds
        # valid "best" parameters / optimizer state / epoch
        last_improvement = 0
        best_model = lasagne.layers.get_all_param_values(self.net)
        best_opt_state = [_u.get_value() for _u in self.iter_funcs['updates'].keys()]
        best_epoch = 0

        best_va_dice = 0.0
        prev_tr_loss, prev_va_loss = 1e7, 1e7
        prev_acc_tr, prev_acc_va = 0.0, 0.0
        prev_map_tr, prev_map_va = 0.0, 0.0

        print("Starting training...")
        now = time.time()
        try:

            # iterate training epochs
            for epoch in self._train(self.iter_funcs, data_pools, training_strategy.build_train_batch_iterator(),
                                     training_strategy.build_valid_batch_iterator(), training_strategy.report_dices,
                                     debug_mode=training_strategy.debug_mode, predict_k=training_strategy.report_map):

                print("Epoch {} of {} took {:.3f}s".format(epoch['number'], training_strategy.max_epochs, time.time() - now))
                now = time.time()

                # --- collect train output ---

                tr_loss, va_loss = epoch['train_loss'], epoch['valid_loss']
                train_acc, valid_acc = epoch['train_acc'], epoch['valid_acc']
                train_map, valid_map = epoch['train_map'], epoch['valid_map']
                train_dices, valid_dices = epoch['train_dices'], epoch['valid_dices']
                overfit = epoch['overfitting']

                # prepare early stopping (prev_* still hold the best values of
                # previous epochs here; they are updated further below)
                if training_strategy.best_model_by_accurary:
                    improvement = valid_acc > prev_acc_va
                else:
                    improvement = va_loss < prev_va_loss

                if improvement:
                    last_improvement = 0
                    best_model = lasagne.layers.get_all_param_values(self.net)
                    best_epoch = epoch['number']
                    best_opt_state = [_u.get_value() for _u in self.iter_funcs['updates'].keys()]

                    # dump net parameters during training
                    if dump_file is not None:
                        with open(dump_file, 'wb') as fp:
                            pickle.dump(best_model, fp)

                last_improvement += 1

                # print train output
                txt_tr = 'costs_tr %.5f ' % tr_loss
                if tr_loss < prev_tr_loss:
                    txt_tr = col.print_colored(txt_tr, BColors.OKGREEN)
                    prev_tr_loss = tr_loss

                txt_tr_acc = '(%.3f)' % train_acc
                if train_acc > prev_acc_tr:
                    txt_tr_acc = col.print_colored(txt_tr_acc, BColors.OKGREEN)
                    prev_acc_tr = train_acc
                txt_tr += txt_tr_acc + ', '

                txt_val = 'costs_val %.5f ' % va_loss
                if va_loss < prev_va_loss:
                    txt_val = col.print_colored(txt_val, BColors.OKGREEN)
                    prev_va_loss = va_loss

                txt_va_acc = '(%.3f)' % valid_acc
                if valid_acc > prev_acc_va:
                    txt_va_acc = col.print_colored(txt_va_acc, BColors.OKGREEN)
                    prev_acc_va = valid_acc
                txt_val += txt_va_acc + ', '

                txt_tr_map = 'tr-map@%d %.3f' % (training_strategy.report_map, train_map)
                if train_map > prev_map_tr:
                    txt_tr_map = col.print_colored(txt_tr_map, BColors.OKGREEN)
                    prev_map_tr = train_map

                txt_va_map = 'va-map@%d %.3f' % (training_strategy.report_map, valid_map)
                if valid_map > prev_map_va:
                    txt_va_map = col.print_colored(txt_va_map, BColors.OKGREEN)
                    prev_map_va = valid_map

                txt_map = "%s, %s" % (txt_tr_map, txt_va_map)

                print('  lr: %.7f, patience: %d' % (learn_rate, training_strategy.patience - last_improvement + 1))
                print('  ' + txt_tr + txt_val + 'tr/val %.3f' % overfit)
                print('  ' + txt_map)

                # report dice coefficients (sorted()/max() instead of numpy on
                # dict views, which does not work on Python 3)
                if training_strategy.report_dices:

                    train_str = '  train  |'
                    for key in sorted(train_dices):
                        train_str += ' %.2f: %.3f |' % (key, train_dices[key])
                    print(train_str)
                    train_acc = max(train_dices.values())

                    valid_str = '  valid  |'
                    epoch_best_va_dice = max(valid_dices.values())
                    for key in sorted(valid_dices):
                        txt_va_dice = ' %.2f: %.3f |' % (key, valid_dices[key])
                        if valid_dices[key] > best_va_dice and valid_dices[key] == epoch_best_va_dice:
                            best_va_dice = valid_dices[key]
                            txt_va_dice = col.print_colored(txt_va_dice, BColors.OKGREEN)
                        valid_str += txt_va_dice
                    print(valid_str)
                    valid_acc = epoch_best_va_dice

                # collect model evolution data
                tr_accs.append(train_acc)
                va_accs.append(valid_acc)
                pred_tr_err.append(tr_loss)
                pred_val_err.append(va_loss)
                overfitting.append(overfit)

                # save results
                exp_res = dict()
                exp_res['pred_tr_err'] = pred_tr_err
                exp_res['tr_accs'] = tr_accs
                exp_res['pred_val_err'] = pred_val_err
                exp_res['va_accs'] = va_accs
                exp_res['overfitting'] = overfitting

                if log_file is not None:
                    # pickles are binary data -> binary file mode
                    with open(log_file, 'wb') as fp:
                        pickle.dump(exp_res, fp)

                # --- early stopping: preserve best model ---
                if last_improvement > training_strategy.patience:
                    print(col.print_colored("Early Stopping!", BColors.WARNING))
                    status = "Epoch: %d, Best Validation Loss: %.5f: Acc: %.5f" % (
                        best_epoch, prev_va_loss, prev_acc_va)
                    print(col.print_colored(status, BColors.WARNING))

                    if training_strategy.refinement_strategy.n_refinement_steps <= 0:
                        break

                    else:

                        status = "Loading best parameters so far and refining (%d) with decreased learn rate ..." % \
                                 training_strategy.refinement_strategy.n_refinement_steps
                        print(col.print_colored(status, BColors.WARNING))

                        # reset net to best weights
                        lasagne.layers.set_all_param_values(self.net, best_model)

                        # reset optimizer
                        for _u, value in zip(self.iter_funcs['updates'].keys(), best_opt_state):
                            _u.set_value(value)

                        # update learn rate
                        # NOTE(review): n_refinement_steps is not decremented here;
                        # confirm refinement_strategy.adapt_learn_rate counts it down,
                        # otherwise refinement repeats until max_epochs.
                        learn_rate = training_strategy.refinement_strategy.adapt_learn_rate(learn_rate)
                        training_strategy.patience = training_strategy.refinement_strategy.refinement_patience
                        last_improvement = 0

                # maximum number of epochs reached
                if epoch['number'] >= training_strategy.max_epochs:
                    break

                # update learning rate
                learn_rate = training_strategy.adapt_learn_rate(learn_rate, epoch['number'])
                learning_rate.set_value(floatX(learn_rate))

        except KeyboardInterrupt:
            # manual abort of training: fall through and restore best weights
            pass

        # set net to best weights
        lasagne.layers.set_all_param_values(self.net, best_model)

        # return best validation loss
        if training_strategy.best_model_by_accurary:
            return prev_acc_va
        else:
            return prev_va_loss

    def _prepare_input(self, data):
        """
        Normalize prediction input.

        Wraps a single array into a list and adds a leading batch axis to the
        first input if it has fewer dimensions than the network input.
        Returns a new list; a list passed by the caller is not modified.
        """
        inputs = list(data) if isinstance(data, list) else [data]
        if inputs[0].ndim < len(self.input_shape):
            inputs[0] = inputs[0].reshape([1] + list(inputs[0].shape))
        return inputs

    def predict_proba(self, input):
        """
        Predict on test samples

        input: a single array or a list with one array per network input layer
        """
        inputs = self._prepare_input(input)

        if self.compute_output is None:
            self.compute_output = self._compile_prediction_function()

        return self.compute_output(*inputs)

    def predict(self, input):
        """
        Predict class labels on test samples (argmax over axis 1)
        """
        return np.argmax(self.predict_proba(input), axis=1)

    def compute_layer_output(self, input, layer):
        """
        Compute output of given layer
        layer: either a string (name of layer) or a layer object

        Raises ValueError if no layer with the given name exists.
        """
        inputs = self._prepare_input(input)

        # get layer by name
        if not isinstance(layer, lasagne.layers.Layer):
            matches = [l for l in lasagne.layers.helper.get_all_layers(self.net) if l.name == layer]
            if not matches:
                raise ValueError("No layer named %r in network" % (layer,))
            layer = matches[0]

        # compile prediction function for target layer (cached per layer)
        if layer not in self.compute_output_dict:
            self.compute_output_dict[layer] = self._compile_prediction_function(target_layer=layer)

        return self.compute_output_dict[layer](*inputs)

    def compute_saliency(self, input, nonlin=lasagne.nonlinearities.rectify):
        """
        Compute saliency maps using guided backprop

        nonlin: nonlinearity to replace by its guided-backprop variant; only
            used when the saliency function is compiled (first call).
        """
        inputs = self._prepare_input(input)

        if self.saliency_function is None:
            self.saliency_function = self._compile_saliency_function(nonlin)

        return self.saliency_function(*inputs)

    def save(self, file_path):
        """
        Save model parameters to disk (pickle, highest protocol)
        """
        params = lasagne.layers.get_all_param_values(self.net)
        with open(file_path, 'wb') as fp:
            pickle.dump(params, fp, -1)

    def load(self, file_path):
        """
        Load model parameters from disk.

        Only load files from trusted sources: unpickling can execute code.
        """
        with open(file_path, 'rb') as fp:
            params = pickle.load(fp)
        lasagne.layers.set_all_param_values(self.net, params)

    def _compile_prediction_function(self, target_layer=None):
        """
        Compile a deterministic theano function mapping all input variables
        of `target_layer` (default: network output) to its output.
        """

        # get network output and compile function
        if target_layer is None:
            target_layer = self.net

        # collect input vars
        all_layers = lasagne.layers.helper.get_all_layers(target_layer)
        input_vars = [l.input_var for l in all_layers if isinstance(l, lasagne.layers.InputLayer)]

        net_output = lasagne.layers.get_output(target_layer, deterministic=True)
        return theano.function(inputs=input_vars, outputs=net_output)

    def _create_iter_functions(self, y_tensor_type, objective, learning_rate, l_2, compute_updates, use_weights,
                               debug_mode, layer_update_filter):
        """
        Create functions for training, validation and testing to iterate one epoch.

        Returns a dict with keys 'train', 'valid', 'test' (same as 'valid'),
        'updates' (the update dict used for training), and the debug functions
        'compute_grad_norms' / 'compute_layer_outputs' (None unless debug_mode).
        """

        # init target tensor; sample weights are float32 with the same rank as
        # the targets. They must be a root variable (not a cast of one) so that
        # they can be used as an input of theano.function.
        targets = y_tensor_type('y')
        weights = T.TensorType('float32', targets.broadcastable)('w')

        # get all layers and collect input vars
        all_layers = lasagne.layers.helper.get_all_layers(self.net)
        input_vars = [l.input_var for l in all_layers if isinstance(l, lasagne.layers.InputLayer)]

        # compute train costs
        tr_output = lasagne.layers.get_output(self.net, deterministic=False)

        if use_weights:
            tr_cost = objective(tr_output, targets, weights)
            tr_input = input_vars + [targets, weights]
        else:
            tr_cost = objective(tr_output, targets)
            tr_input = input_vars + [targets]

        # regularization costs
        tr_reg_cost = 0

        # regularize RNNs: penalize squared differences of the hidden state L2
        # norms between consecutive steps of layers named "norm_reg_rnn"
        # (assumes the layer output is (batch, time, features))
        for l in all_layers:
            if l.name == "norm_reg_rnn":
                H = lasagne.layers.get_output(l, deterministic=False)
                H_l2 = T.sqrt(T.sum(H ** 2, axis=-1))
                norm_diffs = (H_l2[:, 1:] - H_l2[:, :-1]) ** 2

                beta = 1.0
                tr_reg_cost += beta * T.mean(norm_diffs)

        # compute validation costs
        va_output = lasagne.layers.get_output(self.net, deterministic=True)

        # estimate accuracy
        if y_tensor_type == T.ivector:
            va_acc = 100.0 * T.mean(T.eq(T.argmax(va_output, axis=1), targets), dtype=theano.config.floatX)
            tr_acc = 100.0 * T.mean(T.eq(T.argmax(tr_output, axis=1), targets), dtype=theano.config.floatX)

        elif y_tensor_type == T.vector:
            va_acc = 100.0 * T.mean(T.eq(T.ge(va_output.flatten(), 0.5), targets), dtype=theano.config.floatX)
            tr_acc = 100.0 * T.mean(T.eq(T.ge(tr_output.flatten(), 0.5), targets), dtype=theano.config.floatX)

        else:
            va_acc = 100.0 * T.mean(T.eq(T.argmax(va_output, axis=1), T.argmax(targets, axis=1)), dtype=theano.config.floatX)
            tr_acc = 100.0 * T.mean(T.eq(T.argmax(tr_output, axis=1), T.argmax(targets, axis=1)), dtype=theano.config.floatX)

        # collect all parameters of net and compute updates
        all_params = lasagne.layers.get_all_params(self.net, trainable=True)

        # filter parameters to update by layer name (unnamed params are skipped)
        if layer_update_filter:
            all_params = [p for p in all_params if p.name is not None and layer_update_filter in p.name]

        # add weight decay
        if l_2 is not None:
            tr_reg_cost += l_2 * lasagne.regularization.regularize_layer_params(all_layers, lasagne.regularization.l2)

        # compute updates
        all_grads = lasagne.updates.get_or_compute_grads(tr_cost + tr_reg_cost, all_params)
        updates = compute_updates(all_grads, all_params, learning_rate)

        # compile iter functions
        tr_outputs = [tr_cost, tr_output, tr_acc]
        iter_train = theano.function(tr_input, tr_outputs, updates=updates)

        va_inputs = input_vars + [targets]
        va_cost = objective(va_output, targets)
        va_outputs = [va_cost, va_output, va_acc]
        iter_valid = theano.function(va_inputs, va_outputs)

        # network debugging
        compute_grad_norms = None
        compute_layer_outputs = None
        if debug_mode:

            # compile gradient norm computation for weights
            grad_norms = [T.sqrt(T.sum(g ** 2)) for p, g in zip(all_params, all_grads)
                          if p.name is not None and "W" in p.name]
            compute_grad_norms = theano.function(tr_input, grad_norms)

            # compute output of each layer
            layer_outputs = lasagne.layers.get_output(all_layers)
            compute_layer_outputs = theano.function(input_vars, layer_outputs)

        return dict(train=iter_train, valid=iter_valid, test=iter_valid, updates=updates,
                    compute_grad_norms=compute_grad_norms,
                    compute_layer_outputs=compute_layer_outputs)

    def _compile_saliency_function(self, nonlin=lasagne.nonlinearities.rectify):
        """
        Compiles a function to compute the saliency maps (gradient of the
        maximum pre-output activation w.r.t. all network inputs) for a given
        mini batch, using guided backprop through every layer whose
        nonlinearity is `nonlin`.

        The layers' nonlinearities are swapped only while the symbolic graph
        is built and are restored afterwards, so functions compiled later
        (e.g. for training) keep the normal gradients.
        """

        class ModifiedBackprop(object):

            def __init__(self, nonlinearity):
                self.nonlinearity = nonlinearity
                self.ops = {}  # memoizes an OpFromGraph instance per tensor type

            def __call__(self, x):
                # OpFromGraph is oblique to Theano optimizations, so we need to move
                # things to GPU ourselves if needed. The old `theano.sandbox.cuda`
                # backend is missing in newer Theano versions, hence getattr.
                cuda = getattr(getattr(theano, 'sandbox', None), 'cuda', None)
                if cuda is not None and cuda.cuda_enabled:
                    maybe_to_gpu = cuda.as_cuda_ndarray_variable
                else:
                    maybe_to_gpu = lambda v: v
                # We move the input to GPU if needed.
                x = maybe_to_gpu(x)
                # We note the tensor type of the input variable to the nonlinearity
                # (mainly dimensionality and dtype); we need to create a fitting Op.
                tensor_type = x.type
                # If we did not create a suitable Op yet, this is the time to do so.
                if tensor_type not in self.ops:
                    # For the graph, we create an input variable of the correct type:
                    inp = tensor_type()
                    # We pass it through the nonlinearity (and move to GPU if needed).
                    outp = maybe_to_gpu(self.nonlinearity(inp))
                    # Then we fix the forward expression...
                    op = theano.OpFromGraph([inp], [outp])
                    # ...and replace the gradient with our own (defined in a subclass).
                    op.grad = self.grad
                    # Finally, we memoize the new Op
                    self.ops[tensor_type] = op
                # And apply the memorized Op to the input we got.
                return self.ops[tensor_type](x)

        class GuidedBackprop(ModifiedBackprop):
            def grad(self, inputs, out_grads):
                (inp,) = inputs
                (grd,) = out_grads
                dtype = inp.dtype
                # pass only positive gradients through positive activations
                return (grd * (inp > 0).astype(dtype) * (grd > 0).astype(dtype),)

        l_out = self.net

        # replace the selected nonlinearity with guided-back-prop
        # (instantiate only once so all layers share the memoized Ops)
        nonlin_layers = [layer for layer in lasagne.layers.get_all_layers(l_out)
                         if getattr(layer, 'nonlinearity', None) is nonlin]
        modded_nonlin = GuidedBackprop(nonlin)
        for layer in nonlin_layers:
            layer.nonlinearity = modded_nonlin

        try:
            # collect input vars
            all_layers = lasagne.layers.helper.get_all_layers(l_out)
            input_vars = [l.input_var for l in all_layers if isinstance(l, lasagne.layers.InputLayer)]

            # saliency of the strongest activation of the layer feeding the output layer
            outp = lasagne.layers.get_output(l_out.input_layer, deterministic=True)
            max_outp = T.max(outp, axis=1)
            saliency = theano.grad(max_outp.sum(), wrt=input_vars)
        finally:
            # the symbolic graph is built; restore the original nonlinearities
            for layer in nonlin_layers:
                layer.nonlinearity = nonlin

        return theano.function(input_vars, saliency)

    @staticmethod
    def _batch_map(targets, pred, predict_k):
        """
        Mean average precision @ predict_k of one batch.

        targets: class indices (1-D) or per-class rows (2-D, argmax is taken)
        pred: per-sample class scores, higher means more likely
        """
        from dcase_task2.lasagne_wrapper.evaluation import mapk

        y_b = targets.argmax(axis=1) if targets.ndim == 2 else targets
        actual = [[y] for y in y_b]
        predicted = [list(np.argsort(yp)[::-1][0:predict_k]) for yp in pred]
        return mapk(actual, predicted, predict_k)

    @staticmethod
    def _plot_debug_info(iter_funcs, train_input):
        """
        Plot gradient norms and hidden activation histograms for one batch
        (uses the debug functions compiled with debug_mode=True).
        """
        import matplotlib.pyplot as plt

        # compute gradient norm for last batch
        grad_norms = iter_funcs['compute_grad_norms'](*train_input)

        plt.figure("Gradient Norms")
        plt.clf()
        plt.plot(grad_norms, "g-", linewidth=3, alpha=0.7)
        plt.grid(True)
        plt.title("Gradient Norms")
        plt.ylabel("Gradient Norm")
        plt.xlabel("Weight $W_l$")
        plt.draw()

        # compute layer output for last batch
        # (assumes the last element of the batch is the target array)
        layer_outputs = iter_funcs['compute_layer_outputs'](*train_input[:-1])

        n_outputs = len(layer_outputs)
        sub_plot_dim = int(np.ceil(np.sqrt(n_outputs)))

        plt.figure("Hidden Activation Distributions")
        plt.clf()
        plt.subplots_adjust(bottom=0.05, top=0.98)
        for i, l_out in enumerate(layer_outputs):
            l_out = np.asarray(l_out).flatten()
            h, b = np.histogram(l_out, bins='auto')

            plt.subplot(sub_plot_dim, sub_plot_dim, i + 1)
            plt.plot(b[:-1], h, "g-", linewidth=3, alpha=0.7,
                     label=r"%.2f $\pm$ %.5f" % (l_out.mean(), l_out.std()))
            span = (b[-1] - b[0])
            x_min = b[0] - 0.05 * span
            x_max = b[-1] + 0.05 * span
            plt.xlim([x_min, x_max])
            plt.legend(fontsize=10)
            plt.grid(True)
            plt.yticks([])
        plt.draw()

        plt.pause(0.1)

    def _train(self, iter_funcs, data_pools, train_batch_iter, valid_batch_iter, estimate_dices, debug_mode,
               predict_k=3):
        """
        Train the model with mini-batch training; generator yielding one
        result dict per epoch (never terminates on its own).

        Yielded keys: 'number', 'train_loss', 'train_acc', 'valid_loss',
        'valid_acc', 'valid_dices', 'train_dices' (dicts threshold -> mean
        dice, or None), 'train_map', 'valid_map' and 'overfitting'
        (train loss / valid loss).
        """
        col = BColors()
        from dcase_task2.lasagne_wrapper.segmentation_utils import dice

        threshs = [0.3, 0.4, 0.5, 0.6, 0.7]

        for epoch in itertools.count(1):

            # evaluate various thresholds
            if estimate_dices:
                tr_dices = dict((thr, []) for thr in threshs)
                va_dices = dict((thr, []) for thr in threshs)
            else:
                tr_dices = None
                va_dices = None

            # iterate train batches
            batch_train_losses, batch_train_accs = [], []
            batch_train_maps = []
            iterator = train_batch_iter(data_pools['train'])
            generator = threaded_generator_from_iterator(iterator)

            batch_times = np.zeros(5, dtype=np.float32)
            train_input = None
            start, after = time.time(), time.time()
            for i_batch, train_input in enumerate(generator):
                batch_res = iter_funcs['train'](*train_input)
                batch_train_losses.append(batch_res[0])

                # collect classification accuracies and map@k
                if len(batch_res) > 2:
                    batch_train_accs.append(batch_res[2])
                    batch_train_maps.append(self._batch_map(train_input[1], batch_res[1], predict_k))

                # estimate dices for various thresholds
                if estimate_dices:
                    y_b = train_input[1]
                    pred = batch_res[1]
                    for thr in threshs:
                        for i in range(pred.shape[0]):
                            seg = pred[i, 0] > thr
                            tr_dices[thr].append(100 * dice(seg, y_b[i, 0]))

                # train time
                batch_time = time.time() - after
                after = time.time()
                train_time = (after - start)

                # estimate updates per second (running avg over the last 5
                # batches; only filled slots are averaged)
                batch_times[0:4] = batch_times[1:5]
                batch_times[4] = batch_time
                n_filled = min(i_batch + 1, len(batch_times))
                ups = 1.0 / batch_times[-n_filled:].mean()

                # report loss during training
                perc = 100 * (float(i_batch + 1) / train_batch_iter.n_batches)
                dec = int(perc // 4)
                progbar = "|" + dec * "#" + (25 - dec) * "-" + "|"
                vals = (perc, progbar, train_time, ups, np.mean(batch_train_losses))
                loss_str = " (%d%%) %s time: %.2fs, ups: %.2f, loss: %.5f" % vals
                print(col.print_colored(loss_str, col.WARNING), end="\r")
                sys.stdout.flush()

            # some debug plots on gradients and hidden activations
            if debug_mode and train_input is not None:
                self._plot_debug_info(iter_funcs, train_input)

            print("\x1b[K", end="\r")
            print(' ')
            avg_train_loss = np.mean(batch_train_losses)
            if len(batch_train_accs) > 0:
                avg_train_acc = np.mean(batch_train_accs)
                avg_train_maps = np.mean(batch_train_maps)
            else:
                avg_train_acc = avg_train_maps = 0.0
            if estimate_dices:
                for thr in threshs:
                    tr_dices[thr] = np.mean(tr_dices[thr])

            # evaluate classification power of model

            # iterate validation batches
            batch_valid_losses, batch_valid_accs = [], []
            batch_valid_maps = []
            iterator = valid_batch_iter(data_pools['valid'])
            generator = threaded_generator_from_iterator(iterator)

            batch_weights = []
            for va_input in generator:
                batch_res = iter_funcs['valid'](*va_input)
                batch_valid_losses.append(batch_res[0])
                batch_weights.append(float(va_input[0].shape[0]))

                # collect classification accuracies and map@k
                if len(batch_res) > 2:
                    batch_valid_accs.append(batch_res[2])
                    batch_valid_maps.append(self._batch_map(va_input[1], batch_res[1], predict_k))

                # estimate dices for various thresholds
                # NOTE(review): train dices use channel 0 only, validation uses all channels
                if estimate_dices:
                    y_b = va_input[1]
                    pred = batch_res[1]
                    for thr in threshs:
                        for i in range(pred.shape[0]):
                            seg = pred[i, :] > thr
                            va_dices[thr].append(100 * dice(seg, y_b[i, :]))

            # batch-level averages are weighted by the number of samples per batch
            batch_weights = np.asarray(batch_weights) / np.sum(batch_weights)
            avg_valid_loss = np.average(np.asarray(batch_valid_losses), weights=batch_weights)
            if len(batch_valid_accs) > 0:
                avg_valid_accs = np.average(np.asarray(batch_valid_accs), weights=batch_weights)
                avg_valid_maps = np.average(np.asarray(batch_valid_maps), weights=batch_weights)
            else:
                avg_valid_accs = avg_valid_maps = 0.0
            if estimate_dices:
                # dices are collected per sample (not per batch), so the plain
                # mean already is the sample-weighted average
                for thr in threshs:
                    va_dices[thr] = np.mean(va_dices[thr])

            # collect results
            yield {
                'number': epoch,
                'train_loss': avg_train_loss,
                'train_acc': avg_train_acc,
                'valid_loss': avg_valid_loss,
                'valid_acc': avg_valid_accs,
                'valid_dices': va_dices,
                'train_dices': tr_dices,
                'train_map': avg_train_maps,
                'valid_map': avg_valid_maps,
                'overfitting': avg_train_loss / avg_valid_loss,
            }


class SegmentationNetwork(Network):
    """
    Segmentation Neural Network

    Predicts dense maps. Inputs whose last two (spatial) axes differ from the
    network input are processed with an overlapping sliding window.
    """

    def predict_proba(self, input, squeeze=True, overlap=0.5):
        """
        Predict on test samples

        input: image batch whose last two axes are spatial (H, W)
        squeeze: if True, remove singleton axes from the result
        overlap: fraction of overlap between neighbouring sliding windows
        """
        if self.compute_output is None:
            self.compute_output = self._compile_prediction_function()

        # get network input shape
        l_in = lasagne.layers.helper.get_all_layers(self.net)[0]
        in_shape = l_in.output_shape[-2::]

        # standard prediction
        if input.shape[-2::] == in_shape:
            proba = self.compute_output(input)

        # sliding window prediction if images do not match
        else:
            proba = self._predict_proba_sliding_window(input, overlap=overlap)

        if squeeze:
            proba = proba.squeeze()

        return proba

    def predict(self, input, thresh=0.5):
        """
        Predict label map on test samples

        Binary output (one channel) is thresholded with `thresh`,
        multi-channel output is reduced with argmax over the channel axis.
        """
        P = self.predict_proba(input, squeeze=False)

        # binary segmentation
        if P.shape[1] == 1:
            return (P > thresh).squeeze()

        # categorical segmentation
        else:
            return np.argmax(P, axis=1).squeeze()

    @staticmethod
    def _window_starts(length, size, step):
        """
        Start offsets of windows of `size` placed every `step` pixels along an
        axis of `length` pixels. A window aligned with the end of the axis is
        appended when the stride does not land there, so that every pixel is
        covered at least once.
        """
        if length < size:
            return []
        starts = list(range(0, length - size + 1, step))
        if starts[-1] != length - size:
            starts.append(length - size)
        return starts

    def _predict_proba_sliding_window(self, images, overlap=0.5):
        """
        Sliding window prediction for images larger than the input layer

        Windows are weighted with a 2D (square-rooted) hamming window and the
        accumulated predictions are normalized by the accumulated weights.
        """
        n_images = images.shape[0]
        h, w = images.shape[2:4]
        _, Nc, sh, sw = self.net.output_shape

        # pad images for sliding window prediction
        # (np.pad returns a new array, so the caller's images are untouched)
        missing_h = int(sh * np.ceil(float(h) / sh) - h)
        missing_w = int(sw * np.ceil(float(w) / sw) - w)

        pad_top = missing_h // 2
        pad_bottom = missing_h - pad_top

        pad_left = missing_w // 2
        pad_right = missing_w - pad_left

        images = np.pad(images, ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)), mode='constant')
        padded_h, padded_w = images.shape[2], images.shape[3]

        # window positions; the step is at least one pixel (overlap >= 1 would
        # otherwise give a zero step) and the last window is end-aligned
        step_h = max(1, int(sh * (1.0 - overlap)))
        step_w = max(1, int(sw * (1.0 - overlap)))
        row_starts = self._window_starts(padded_h, sh, step_h)
        col_starts = self._window_starts(padded_w, sw, step_w)

        # hamming window weighting
        ham2d = np.sqrt(np.outer(np.hamming(sh), np.hamming(sw)))[np.newaxis, np.newaxis]

        # initialize result image and accumulated weights
        R = np.zeros((n_images, Nc, padded_h, padded_w))
        V = np.zeros((n_images, Nc, padded_h, padded_w))

        for r0 in row_starts:
            for c0 in col_starts:
                window = (slice(None), slice(None), slice(r0, r0 + sh), slice(c0, c0 + sw))
                P = self.compute_output(images[window])
                R[window] += P * ham2d
                V[window] += ham2d

        # clip to original image size again
        R = R[:, :, pad_top:padded_h - pad_bottom, pad_left:padded_w - pad_right]
        V = V[:, :, pad_top:padded_h - pad_bottom, pad_left:padded_w - pad_right]

        # normalize predictions
        R /= V
        return R