from __future__ import print_function
import time
ct = time.time()


from lasagne.layers import get_output, InputLayer, NonlinearityLayer
from lasagne.nonlinearities import rectify, leaky_rectify

import sys, os
scriptpath = os.path.dirname(os.path.realpath(__file__))

import nibabel
import numpy as np
import theano
import theano.tensor as T

import lasagne

# Note that Conv3DLayer and dnn.Conv3DDNNLayer have opposite filter-flipping defaults
from lasagne.layers import Conv3DLayer, MaxPool3DLayer, Upscale3DLayer
from lasagne.layers import MergeLayer, ElemwiseSumLayer, BatchNormLayer, ConcatLayer
from lasagne.layers.merge import autocrop, autocrop_array_shapes

import pickle

cachefile = os.path.join(scriptpath, "model_hippo.pkl")

if not os.path.exists(cachefile):


    # This broadcast-enabled layer is required to apply a 3-D mask along the
    # feature dimension. From GH PR #633.
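    # e.g. it lets the (batch, 1, x, y, z) mask produced by l_blur below be
    # multiplied with the (batch, 48, x, y, z) feature maps of l_conv_f1,
    # broadcasting the single mask channel across all 48 features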
    class ElemwiseMergeLayerBroadcast(MergeLayer):
        """
        This layer performs an elementwise merge of its input layers.
        It requires all input layers to have the same output shape.
        Parameters
        ----------
        incomings : Unless `cropping` is given, all shapes must be equal, except
            for dimensions that are undefined (``None``) or broadcastable (``1``).
        merge_function : callable
            the merge function to use. Should take two arguments and return the
            updated value. Some possible merge functions are ``theano.tensor``:
            ``mul``, ``add``, ``maximum`` and ``minimum``.
        cropping : None or [crop]
            Cropping for each input axis. Cropping is described in the docstring
            for :func:`autocrop`
        See Also
        --------
        ElemwiseSumLayer : Shortcut for sum layer.
        """

        def __init__(self, incomings, merge_function, cropping=None, **kwargs):
            super(ElemwiseMergeLayerBroadcast, self).__init__(incomings, **kwargs)
            self.merge_function = merge_function
            self.cropping = cropping
            self.broadcastable = None
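            # self.broadcastable is filled in by get_output_shape_for() once the
            # input shapes are known; get_output_for() then uses it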

        def get_output_shape_for(self, input_shapes):
            input_shapes = autocrop_array_shapes(input_shapes, self.cropping)

            input_dims = [len(shp) for shp in input_shapes]
            if not all(input_dim == input_dims[0] for input_dim in input_dims):
                raise ValueError('Input dimensions must be the same but were %s' %
                                 ", ".join(map(str, input_shapes)))

            def broadcasting(input_dim):
                # Identify dimensions that will be broadcasted.
                sorted_dim = sorted(input_dim,
                                    key=lambda x: x if x is not None else -1)
                if isinstance(sorted_dim[-1], int) and sorted_dim[-1] != 1 \
                        and all([d == 1 for d in sorted_dim[:-1]]):
                    size_after_broadcast = sorted_dim[-1]
                    broadcast = [True if d == 1 else None for d in input_dim]
                    return ((size_after_broadcast,)*len(input_dim), broadcast)
                else:
                    return (input_dim, [None]*len(input_dim))

            # if the dimension is broadcastable we replace 1's with the size
            # after broadcasting.
            input_dims, broadcastable = list(zip(
                *[broadcasting(input_dim) for input_dim in zip(*input_shapes)]))

            self.broadcastable = list(zip(*broadcastable))
            input_shapes = list(zip(*input_dims))

            # Infer the output shape by grabbing, for each axis, the first
            # input size that is not `None` (if there is any)
            output_shape = tuple(next((s for s in sizes if s is not None), None)
                                 for sizes in zip(*input_shapes))

            def match(shape1, shape2):
                return (len(shape1) == len(shape2) and
                        all(s1 is None or s2 is None or s1 == s2
                            for s1, s2 in zip(shape1, shape2)))

            # Check for compatibility with inferred output shape
            if not all(match(shape, output_shape) for shape in input_shapes):
                raise ValueError("Mismatch: not all input shapes are the same")
            return output_shape

        def get_output_for(self, inputs, **kwargs):
            inputs = autocrop(inputs, self.cropping)
            # mark the size-1 axes recorded by get_output_shape_for() as
            # broadcastable, so Theano broadcasts them in merge_function
            if self.broadcastable is not None:
                for n, broadcasting_dim in enumerate(self.broadcastable):
                    for dim, broadcasting in enumerate(broadcasting_dim):
                        if broadcasting:
                            inputs[n] = T.addbroadcast(inputs[n], dim)

            output = None
            for inp in inputs:
                if output is not None:
                    output = self.merge_function(output, inp)
                else:
                    output = inp
            return output


    # Definition of the network
    conv_num_filters = 48
    l = InputLayer(shape = (None, 1, 48, 72, 64), name="input")
    l_input = l

    # # # #
    # encoding
    # # # #
    l = Conv3DLayer(l, flip_filters=False, num_filters = 16, filter_size = (1,1,3), pad = 'valid', name="conv")
    l = Conv3DLayer(l, flip_filters=False, num_filters = 16, filter_size = (1,3,1), pad = 'valid', name="conv")
    l_conv_0 = l = Conv3DLayer(l, flip_filters=False, num_filters = 16, filter_size = (3,1,1), pad = 'valid', name="conv")
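    # together, the three 1-D 'valid' convolutions above cover a 3x3x3 receptive
    # field with fewer parameters, shrinking the volume from 48x72x64 to 46x70x62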

    l = l_conv_f1 = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters, filter_size = 3, pad = 'valid', name="conv_f1")

    l = l_maxpool1 = MaxPool3DLayer(l, pool_size = 2, name ='maxpool1')
    l = BatchNormLayer(l, name="batchnorm")

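    # residual block: two 'same' convolutions refine the pooled features and are
    # summed back with the l_maxpool1 shortcut (ResNet-style identity skip)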
    l = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters, filter_size = (3,3,3), pad = "same", name="conv")
    l = l_convout1 = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters, filter_size = (3, 3, 3), pad = 'same', name ='convout1', nonlinearity = None)
    l = ElemwiseSumLayer(incomings = [l_maxpool1, l_convout1], name="sum_1s")
    l = NonlinearityLayer(l, nonlinearity = rectify, name="relu")


    conv_num_filters2 = 48
    l = l_maxpool2 = MaxPool3DLayer(l, pool_size = 2, name = 'maxpool2')
    l_maxpool2_conv = l
    l = BatchNormLayer(l, name="batchnorm")
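    # second residual block, same pattern at half the spatial resolution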
    l = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters2, filter_size = (3,3,3), pad = "same", name="conv")
    l = l_convout2 = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters2, filter_size = (3, 3, 3), pad = 'same', name ='convout2', nonlinearity = None)
    l = ElemwiseSumLayer(incomings = [l_maxpool2_conv, l_convout2], name="sum_2s")
    l = NonlinearityLayer(l, nonlinearity = rectify, name="relu")

    # # # #
    # segmentation
    # # # #
    l_middle = l
    # decoder: the two x2 Upscale3DLayers below take l_middle (11x17x15)
    # back up to the 44x68x60 resolution of l_conv_f1
    l = Conv3DLayer(l_middle, flip_filters=False, num_filters = conv_num_filters, filter_size = 3, pad = "same", name="conv")
    l = Upscale3DLayer(l, scale_factor = 2, name="upscale")
    l = Conv3DLayer(l, flip_filters=False, num_filters = conv_num_filters, filter_size = 3, pad = 1, name="conv")
    l = Upscale3DLayer(l, scale_factor = 2, name="upscale")
    l_upscale = l
    l_dec_features = Conv3DLayer(l_upscale, flip_filters=False, num_filters = 16, filter_size = 3, pad = 1, name="conv")

    # First-stage (before refinement) output: per-voxel sigmoid probability map
    l_output1 = Conv3DLayer(l_dec_features, flip_filters=False, num_filters = 1, filter_size = 1, pad = 'same', name="conv_1x", nonlinearity = lasagne.nonlinearities.sigmoid)

    # # #
    # refinement
    # # #
    ## The next output re-uses the masked first-stage features to tentatively refine the segmentation
    l_blur = Conv3DLayer(l_output1, flip_filters=False, num_filters=1, filter_size=7, stride=1, pad='same', W=lasagne.init.Constant(1.), b=lasagne.init.Constant(-7*7*7.*.10), nonlinearity=lasagne.nonlinearities.sigmoid)
    # the .10 above thresholds at ~10% of the smoothed mask: the 7x7x7 box filter
    # sums 343 voxels and the bias is -343 * 0.10 = -34.3, so the sigmoid turns on
    # where more than ~34 neighbourhood voxels are inside the mask
    # (the higher the fraction, the smaller the resulting mask)
    for x in l_blur.params.values():
        x.discard("trainable")  # never train this; it is a fixed smoothing/threshold filter

    l_masked_f1 = ElemwiseMergeLayerBroadcast([l_blur, l_conv_f1], merge_function=T.mul)
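    # the (batch, 1, x, y, z) blurred mask broadcasts across l_conv_f1's 48
    # feature maps here, suppressing features far from the first-stage segmentation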
    l_extract = Conv3DLayer(l_masked_f1, flip_filters=False, num_filters = 47, filter_size = 3, pad = 1, name="extractconv", nonlinearity=leaky_rectify)
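    # 1 first-stage output channel + 47 extracted feature maps = 48 channels after the concat below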
    l_concat = l = ConcatLayer([l_output1, l_extract], axis=1)
    l_mix = Conv3DLayer(l_concat, flip_filters=False, num_filters = 16, filter_size = 3, pad = 1, name="mixconv", nonlinearity=rectify)
    l_output2 = Conv3DLayer(l_mix, flip_filters=False, num_filters = 1, filter_size = 1, pad = 'same', name="conv_1x", nonlinearity = lasagne.nonlinearities.sigmoid)

    # Final output
    network = l_output2

    l_out = ConcatLayer([l_output2, l_output1])
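    # l_out stacks both predictions along the channel axis:
    # channel 0 = refined l_output2, channel 1 = first-stage l_output1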

    with np.load(os.path.join(scriptpath, "modelparams.npz")) as f:
        param_values = [f['arr_%d' % i] for i in range(len(f.files))]
        lasagne.layers.set_all_param_values(network, param_values)
    #print ("weights loaded on %s (init took %4.2f sec)" % (time.ctime(), time.time() - ct))

    print("Compiling")

    fn_get_output = theano.function([l_input.input_var], get_output(l_out, deterministic=True))
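    # deterministic=True makes BatchNormLayer use its stored running averages
    # rather than per-batch statistics at inference time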

    try:
        print("Pickling")
        with open(cachefile, "wb") as f:
            pickle.dump(fn_get_output, f)
    except Exception:
        print("Pickling failed")
else:
    print("Loading from cache")
    fn_get_output = pickle.load(open(cachefile,"rb"))


if __name__ == "__main__":
    for fn in sys.argv[1:]:
        print("Running %s" % fn)
        img = nibabel.load(fn)
        d = np.asanyarray(img.dataobj).astype(np.float32)
        # z-score normalisation over the whole volume
        d -= d.mean()
        d /= d.std()
        # split into Left and Right hemisphere crops (the Right one flipped to match)
        d_in = np.vstack([d[None, None, 6:54:+1, :, 2:-2], d[None, None, -7:-55:-1, :, 2:-2]])
        out = fn_get_output(d_in)

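        # out has shape (2, 2, 44, 68, 60): axis 0 indexes the two hemisphere
        # crops stacked in d_in, axis 1 the output channel (0 = refined, 1 = first-stage)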
        if 1:
            output = np.zeros((107, 72, 68, 2), np.uint8)
            output[-7:-55:-1, :, 2:-2, 0][2:-2, 2:-2, 2:-2] = np.clip(out[1, 0] * 256, 0, 255)  # * maskL
            output[6:54:+1, :, 2:-2, 1][2:-2, 2:-2, 2:-2] = np.clip(out[0, 0] * 256, 0, 255)  # * maskR
            outputfn = fn.replace(".nii.gz", "_outseg_L.nii.gz")
            nibabel.Nifti1Image(output[..., 0], img.affine).to_filename(outputfn)
            outputfn = fn.replace(".nii.gz", "_outseg_R.nii.gz")
            nibabel.Nifti1Image(output[..., 1], img.affine).to_filename(outputfn)

        if 0:  # also write the pre-refinement l_output1 maps (for debugging)
            output = np.zeros((107, 72, 68, 2), np.uint8)
            output[-7:-55:-1, :, 2:-2, 0][2:-2, 2:-2, 2:-2] = np.clip(out[1, 1] * 256, 0, 255)  # * maskL
            output[6:54:+1, :, 2:-2, 1][2:-2, 2:-2, 2:-2] = np.clip(out[0, 1] * 256, 0, 255)  # * maskR
            outputfn = fn.replace(".nii.gz", "_outseg_output1_L.nii.gz")
            nibabel.Nifti1Image(output[..., 0], img.affine).to_filename(outputfn)
            outputfn = fn.replace(".nii.gz", "_outseg_output1_R.nii.gz")
            nibabel.Nifti1Image(output[..., 1], img.affine).to_filename(outputfn)

    print("Elapsed: %4.2fs" %  (time.time() - ct))