"""
tensorflow/keras utilities for the neuron project

If you use this code, please cite 
Dalca AV, Guttag J, Sabuncu MR
Anatomical Priors in Convolutional Networks for Unsupervised Biomedical Segmentation, 
CVPR 2018

or for the transformation/interpolation related functions:

Unsupervised Learning for Fast Probabilistic Diffeomorphic Registration
Adrian V. Dalca, Guha Balakrishnan, John Guttag, Mert R. Sabuncu
MICCAI 2018.

Contact: adalca [at] csail [dot] mit [dot] edu
License: GPLv3
"""

# python imports
import itertools

# third party imports
import numpy as np
from tqdm import tqdm_notebook as tqdm
from pprint import pformat

import pytools.patchlib as pl
import pytools.timer as timer

# local imports
import pynd.ndutils as nd

# often changed file
from imp import reload
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
reload(pl)

def interpn(vol, loc, interp_method='linear', fill_value=None):
    """
    N-D gridded interpolation in tensorflow

    vol can have more dimensions than loc[i], in which case loc[i] acts as a slice 
    for the first dimensions

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc: a N-long list of N-D Tensors (the interpolation locations) for the new grid
            each tensor has to have the same size (but not nec. same size as vol)
            or a tensor of size [*new_vol_shape, D]
        interp_method: interpolation type 'linear' (default) or 'nearest'
        fill_value: value to use for points outside the domain. If None, the nearest
            neighbors will be used (default).

    Returns:
        new interpolated volume of the same size as the entries in loc

    TODO:
        enable optional orig_grid - the original grid points.
        check out tf.contrib.resampler, only seems to work for 2D data
    """
    
    if isinstance(loc, (list, tuple)):
        loc = tf.stack(loc, -1)
    nb_dims = loc.shape[-1]

    if len(vol.shape) not in [nb_dims, nb_dims+1]:
        raise Exception("Number of loc Tensors %d does not match volume dimension %d"
                        % (nb_dims, len(vol.shape[:-1])))

    if nb_dims > len(vol.shape):
        raise Exception("Loc dimension %d does not match volume dimension %d"
                        % (nb_dims, len(vol.shape)))

    if len(vol.shape) == nb_dims:
        vol = K.expand_dims(vol, -1)

    # flatten and float location Tensors
    loc = tf.cast(loc, 'float32')
    

    if isinstance(vol.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
        volshape = vol.shape.as_list()
    else:
        volshape = vol.shape

    max_loc = [d - 1 for d in vol.get_shape().as_list()]
    if fill_value is not None:
        below = [tf.less(loc[...,d], 0) for d in range(nb_dims)]
        above = [tf.greater(loc[...,d], max_loc[d]) for d in range(nb_dims)]
        out_of_bounds = tf.reduce_any(tf.stack(below + above, axis=-1), axis=-1, keepdims=True)

    # interpolate
    if interp_method == 'linear':
        loc0 = tf.floor(loc)

        # clip values
        clipped_loc = [tf.clip_by_value(loc[...,d], 0, max_loc[d]) for d in range(nb_dims)]
        loc0lst = [tf.clip_by_value(loc0[...,d], 0, max_loc[d]) for d in range(nb_dims)]

        # get other end of point cube
        loc1 = [tf.clip_by_value(loc0lst[d] + 1, 0, max_loc[d]) for d in range(nb_dims)]
        locs = [[tf.cast(f, 'int32') for f in loc0lst], [tf.cast(f, 'int32') for f in loc1]]

        # compute the difference between the upper value and the original value
        # differences are basically 1 - (pt - floor(pt))
        #   because: floor(pt) + 1 - pt = 1 + (floor(pt) - pt) = 1 - (pt - floor(pt))
        diff_loc1 = [loc1[d] - clipped_loc[d] for d in range(nb_dims)]
        diff_loc0 = [1 - d for d in diff_loc1]
        weights_loc = [diff_loc1, diff_loc0] # note reverse ordering since weights are inverse of diff.

        # go through all the cube corners, indexed by a ND binary vector 
        # e.g. [0, 0] means this "first" corner in a 2-D "cube"
        cube_pts = list(itertools.product([0, 1], repeat=nb_dims))
        interp_vol = 0
        
        for c in cube_pts:
            
            # get nd values
            # note re: indices above volumes via https://github.com/tensorflow/tensorflow/issues/15091
            #   It works on GPU because we do not perform index validation checking on GPU -- it's too
            #   expensive. Instead we fill the output with zero for the corresponding value. The CPU
            #   version caught the bad index and returned the appropriate error.
            subs = [locs[c[d]][d] for d in range(nb_dims)]

            # tf stacking is slow for large volumes, so we will use sub2ind and use single indexing.
            # indices = tf.stack(subs, axis=-1)
            # vol_val = tf.gather_nd(vol, indices)
            # faster way to gather than gather_nd, because the latter needs tf.stack which is slow :(
            idx = sub2ind(vol.shape[:-1], subs)
            vol_val = tf.gather(tf.reshape(vol, [-1, volshape[-1]]), idx)

            # get the weight of this cube_pt based on the distance
            # if c[d] is 0 --> want weight = 1 - (pt - floor[pt]) = diff_loc1
            # if c[d] is 1 --> want weight = pt - floor[pt] = diff_loc0
            wts_lst = [weights_loc[c[d]][d] for d in range(nb_dims)]
            # tf stacking is slow, we will use prod_n()
            # wlm = tf.stack(wts_lst, axis=0)
            # wt = tf.reduce_prod(wlm, axis=0)
            wt = prod_n(wts_lst)
            wt = K.expand_dims(wt, -1)
            
            # compute final weighted value for each cube corner
            interp_vol += wt * vol_val
        
    else:
        assert interp_method == 'nearest'
        roundloc = tf.cast(tf.round(loc), 'int32')
        roundloc = [tf.clip_by_value(roundloc[...,d], 0, max_loc[d]) for d in range(nb_dims)]

        # get values
        # tf stacking is slow. replace with gather
        # roundloc = tf.stack(roundloc, axis=-1)
        # interp_vol = tf.gather_nd(vol, roundloc)
        idx = sub2ind(vol.shape[:-1], roundloc)
        interp_vol = tf.gather(tf.reshape(vol, [-1, vol.shape[-1]]), idx) 

    if fill_value is not None:
        out_type = interp_vol.dtype
        fill_value = tf.constant(fill_value, dtype=out_type)
        interp_vol *= tf.cast(tf.logical_not(out_of_bounds), dtype=out_type)
        interp_vol += tf.cast(out_of_bounds, dtype=out_type) * fill_value

    return interp_vol


def resize(vol, zoom_factor, interp_method='linear'):
    """
    if zoom_factor is a list, it will determine the ndims, in which case vol has to be of length ndims of ndims + 1

    if zoom_factor is an integer, then vol must be of length ndims + 1

    """

    if isinstance(zoom_factor, (list, tuple)):
        ndims = len(zoom_factor)
        vol_shape = vol.shape[:ndims]
        
        assert len(vol_shape) in (ndims, ndims+1), \
            "zoom_factor length %d does not match ndims %d" % (len(vol_shape), ndims)

    else:
        vol_shape = vol.shape[:-1]
        ndims = len(vol_shape)
        zoom_factor = [zoom_factor] * ndims
    if not isinstance(vol_shape[0], int):
        vol_shape = vol_shape.as_list()

    new_shape = [vol_shape[f] * zoom_factor[f] for f in range(ndims)]
    new_shape = [int(f) for f in new_shape]

    lin = [tf.linspace(0., vol_shape[d]-1., new_shape[d]) for d in range(ndims)]
    grid = ndgrid(*lin)

    return interpn(vol, grid, interp_method=interp_method)


zoom = resize


def affine_to_shift(affine_matrix, volshape, shift_center=True, indexing='ij'):
    """
    transform an affine matrix to a dense location shift tensor in tensorflow

    Algorithm:
        - get grid and shift grid to be centered at the center of the image (optionally)
        - apply affine matrix to each index.
        - subtract grid

    Parameters:
        affine_matrix: ND+1 x ND+1 or ND x ND+1 matrix (Tensor)
        volshape: 1xN Nd Tensor of the size of the volume.
        shift_center (optional)

    Returns:
        shift field (Tensor) of size *volshape x N

    TODO: 
        allow affine_matrix to be a vector of size nb_dims * (nb_dims + 1)
    """


    if isinstance(volshape, (tf.compat.v1.Dimension, tf.TensorShape)):
        volshape = volshape.as_list()
    
    if affine_matrix.dtype != 'float32':
        affine_matrix = tf.cast(affine_matrix, 'float32')

    nb_dims = len(volshape)

    if len(affine_matrix.shape) == 1:
        if len(affine_matrix) != (nb_dims * (nb_dims + 1)) :
            raise ValueError('transform is supposed a vector of len ndims * (ndims + 1).'
                             'Got len %d' % len(affine_matrix))

        affine_matrix = tf.reshape(affine_matrix, [nb_dims, nb_dims + 1])

    if not (affine_matrix.shape[0] in [nb_dims, nb_dims + 1] and affine_matrix.shape[1] == (nb_dims + 1)):
        raise Exception('Affine matrix shape should match'
                        '%d+1 x %d+1 or ' % (nb_dims, nb_dims) + \
                        '%d x %d+1.' % (nb_dims, nb_dims) + \
                        'Got: ' + str(affine_matrix.shape))

    # list of volume ndgrid
    # N-long list, each entry of shape volshape
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)  
    mesh = [tf.cast(f, 'float32') for f in mesh]
    
    if shift_center:
        mesh = [mesh[f] - (volshape[f]-1)/2 for f in range(len(volshape))]

    # add an all-ones entry and transform into a large matrix
    flat_mesh = [flatten(f) for f in mesh]
    flat_mesh.append(tf.ones(flat_mesh[0].shape, dtype='float32'))
    mesh_matrix = tf.transpose(tf.stack(flat_mesh, axis=1))  # 4 x nb_voxels

    # compute locations
    loc_matrix = tf.matmul(affine_matrix, mesh_matrix)  # N+1 x nb_voxels
    loc_matrix = tf.transpose(loc_matrix[:nb_dims, :])  # nb_voxels x N
    loc = tf.reshape(loc_matrix, list(volshape) + [nb_dims])  # *volshape x N
    # loc = [loc[..., f] for f in range(nb_dims)]  # N-long list, each entry of shape volshape

    # get shifts and return
    return loc - tf.stack(mesh, axis=nb_dims)


def transform(vol, loc_shift, interp_method='linear', indexing='ij', fill_value=None):
    """
    transform (interpolation N-D volumes (features) given shifts at each location in tensorflow

    Essentially interpolates volume vol at locations determined by loc_shift. 
    This is a spatial transform in the sense that at location [x] we now have the data from, 
    [x + shift] so we've moved data.

    Parameters:
        vol: volume with size vol_shape or [*vol_shape, nb_features]
        loc_shift: shift volume [*new_vol_shape, N]
        interp_method (default:'linear'): 'linear', 'nearest'
        indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
            In general, prefer to leave this 'ij'
        fill_value (default: None): value to use for points outside the domain.
            If None, the nearest neighbors will be used.
    
    Return:
        new interpolated volumes in the same size as loc_shift[0]
    
    Keyworks:
        interpolation, sampler, resampler, linear, bilinear
    """

    # parse shapes

    
    
    
    
    
    

    if isinstance(loc_shift.shape, (tf.compat.v1.Dimension, tf.TensorShape)):
        volshape = loc_shift.shape[:-1].as_list()
    else:
        volshape = loc_shift.shape[:-1]
    nb_dims = len(volshape)

    # location should be mesh and delta
    mesh = volshape_to_meshgrid(volshape, indexing=indexing)  # volume mesh
    loc = [tf.cast(mesh[d], 'float32') + loc_shift[..., d] for d in range(nb_dims)]

    # test single
    return interpn(vol, loc, interp_method=interp_method, fill_value=fill_value)


def compose(disp_1, disp_2, indexing='ij'):
    """
    compose two dense deformations specified by their displacements

    We have two fields
        A --> B (so field is in space of B)
        and
        B --> C (so field is in the space of C)
    this function gives a new warp field
        A --> C (so field is in the sapce of C)

    Parameters:
        disp_1: first displacement (A-->B) with size [*vol_shape, ndims]
        disp_2: second  displacement (B-->C) with size [*vol_shape, ndims]
        indexing (default: 'ij'): 'ij' (matrix) or 'xy' (cartesian).
            In general, prefer to leave this 'ij'

    Returns:
        composed field disp_3 which takes data from A to C
    """

    assert indexing == 'ij', "currently only ij indexing is implemented in compose"

    return disp_2 + transform(disp_1, disp_2, interp_method='linear', indexing=indexing)


def integrate_vec(vec, time_dep=False, method='ss', **kwargs):
    """
    Integrate (stationary of time-dependent) vector field (N-D Tensor) in tensorflow
    
    Aside from directly using tensorflow's numerical integration odeint(), also implements 
    "scaling and squaring", and quadrature. Note that the diff. equation given to odeint
    is the one used in quadrature.   

    Parameters:
        vec: the Tensor field to integrate. 
            If vol_size is the size of the intrinsic volume, and vol_ndim = len(vol_size),
            then vector shape (vec_shape) should be 
            [vol_size, vol_ndim] (if stationary)
            [vol_size, vol_ndim, nb_time_steps] (if time dependent)
        time_dep: bool whether vector is time dependent
        method: 'scaling_and_squaring' or 'ss' or 'ode' or 'quadrature'
        
        if using 'scaling_and_squaring': currently only supports integrating to time point 1.
            nb_steps: int number of steps. Note that this means the vec field gets broken
            down to 2**nb_steps. so nb_steps of 0 means integral = vec.

        if using 'ode':
            out_time_pt (optional): a time point or list of time points at which to evaluate
                Default: 1
            init (optional): if using 'ode', the initialization method.
                Currently only supporting 'zero'. Default: 'zero'
            ode_args (optional): dictionary of all other parameters for 
                tf.contrib.integrate.odeint()

    Returns:
        int_vec: integral of vector field.
        Same shape as the input if method is 'scaling_and_squaring', 'ss', 'quadrature', 
        or 'ode' with out_time_pt not a list. Will have shape [*vec_shape, len(out_time_pt)]
        if method is 'ode' with out_time_pt being a list.

    Todo:
        quadrature for more than just intrinsically out_time_pt = 1
    """

    if method not in ['ss', 'scaling_and_squaring', 'ode', 'quadrature']:
        raise ValueError("method has to be 'scaling_and_squaring' or 'ode'. found: %s" % method)

    if method in ['ss', 'scaling_and_squaring']:
        nb_steps = kwargs['nb_steps']
        assert nb_steps >= 0, 'nb_steps should be >= 0, found: %d' % nb_steps

        if time_dep:
            svec = K.permute_dimensions(vec, [-1, *range(0, vec.shape[-1] - 1)])
            assert 2**nb_steps == svec.shape[0], "2**nb_steps and vector shape don't match"

            svec = svec/(2**nb_steps)
            for _ in range(nb_steps):
                svec = svec[0::2] + tf.map_fn(transform, svec[1::2,:], svec[0::2,:])

            disp = svec[0, :]

        else:
            vec = vec/(2**nb_steps)
            for _ in range(nb_steps):
                vec += transform(vec, vec)
            disp = vec

    elif method == 'quadrature':
        # TODO: could output more than a single timepoint!
        nb_steps = kwargs['nb_steps']
        assert nb_steps >= 1, 'nb_steps should be >= 1, found: %d' % nb_steps

        vec = vec/nb_steps

        if time_dep:
            disp = vec[...,0]
            for si in range(nb_steps-1):
                disp += transform(vec[...,si+1], disp)
        else:
            disp = vec
            for _ in range(nb_steps-1):
                disp += transform(vec, disp)

    else:
        assert not time_dep, "odeint not implemented with time-dependent vector field"
        fn = lambda disp, _: transform(vec, disp)  

        # process time point.
        out_time_pt = kwargs['out_time_pt'] if 'out_time_pt' in kwargs.keys() else 1
        out_time_pt = tf.cast(K.flatten(out_time_pt), tf.float32)
        len_out_time_pt = out_time_pt.get_shape().as_list()[0]
        assert len_out_time_pt is not None, 'len_out_time_pt is None :('
        z = out_time_pt[0:1]*0.0  # initializing with something like tf.zeros(1) gives a control flow issue.
        K_out_time_pt = K.concatenate([z, out_time_pt], 0)

        # enable a new integration function than tf.contrib.integrate.odeint
        odeint_fn = tf.contrib.integrate.odeint
        if 'odeint_fn' in kwargs.keys() and kwargs['odeint_fn'] is not None:
            odeint_fn = kwargs['odeint_fn']

        # process initialization
        if 'init' not in kwargs.keys() or kwargs['init'] == 'zero':
            disp0 = vec*0  # initial displacement is 0
        else:
            raise ValueError('non-zero init for ode method not implemented')

        # compute integration with odeint
        if 'ode_args' not in kwargs.keys(): kwargs['ode_args'] = {}
        disp = odeint_fn(fn, disp0, K_out_time_pt, **kwargs['ode_args'])
        disp = K.permute_dimensions(disp[1:len_out_time_pt+1, :], [*range(1,len(disp.shape)), 0])

        # return
        if len_out_time_pt == 1: 
            disp = disp[...,0]

    return disp


def tf_map_fn_axis(fn, elems, axis, **kwargs):
    """
    apply map_fn along a specific axis
    
    if elems is a Tensor, axis is an int
    if elems is a list, axis is a list of same length
    """
    
    # determine lists
    islist = isinstance(elems, (tuple, list))
    if not islist:
        elems = [elems]
        assert not isinstance(axis, (tuple, list)), 'axis cannot be list if elements are not list'
        axis = [axis]
        
        
    elems_perm = []
    for xi, x in enumerate(elems):
        a = axis[xi]
        s = len(x.get_shape().as_list())
        if a == -1: a = s - 1

        # move channels to front, so x will be [axis, ...]
        perm = [a] + list(range(0, a)) + list(range(a + 1, s))
        elems_perm.append(K.permute_dimensions(x, perm))

    # compute sptial deformation regularization for this channel
    if not islist:
        elems_perm = elems_perm[0]
        
    x_perm_trf = tf.map_fn(fn, elems_perm, **kwargs)
    if not islist:
        x_perm_trf = [x_perm_trf]
        

    # move in_channels back to end
    elems_trf = []
    for xi, x in enumerate(x_perm_trf):
        a = axis[xi]
        s = len(x.get_shape().as_list())
        if a == -1: a = s - 1
            
        perm = list(range(1, a + 1)) + [0] + list(range(a + 1, s))
        elems_trf.append(K.permute_dimensions(x, perm))
        
    if not islist:
        elems_trf = elems_trf[0]
    
    return elems_trf






def volshape_to_ndgrid(volshape, **kwargs):
    """
    compute Tensor ndgrid from a volume size

    Parameters:
        volshape: the volume size
        **args: "name" (optional)

    Returns:
        A list of Tensors

    See Also:
        ndgrid
    """
    
    isint = [float(d).is_integer() for d in volshape]
    if not all(isint):
        raise ValueError("volshape needs to be a list of integers")

    linvec = [tf.range(0, d) for d in volshape]
    return ndgrid(*linvec, **kwargs)


def volshape_to_meshgrid(volshape, **kwargs):
    """
    compute Tensor meshgrid from a volume size

    Parameters:
        volshape: the volume size
        **args: "name" (optional)

    Returns:
        A list of Tensors

    See Also:
        tf.meshgrid, meshgrid, ndgrid, volshape_to_ndgrid
    """
    
    isint = [float(d).is_integer() for d in volshape]
    if not all(isint):
        raise ValueError("volshape needs to be a list of integers")

    linvec = [tf.range(0, d) for d in volshape]
    return meshgrid(*linvec, **kwargs)


def ndgrid(*args, **kwargs):
    """
    broadcast Tensors on an N-D grid with ij indexing
    uses meshgrid with ij indexing

    Parameters:
        *args: Tensors with rank 1
        **args: "name" (optional)

    Returns:
        A list of Tensors
    
    """
    return meshgrid(*args, indexing='ij', **kwargs)


def meshgrid(*args, **kwargs):
    """
    
    meshgrid code that builds on (copies) tensorflow's meshgrid but dramatically
    improves runtime by changing the last step to tiling instead of multiplication.
    https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/python/ops/array_ops.py#L1921
    
    Broadcasts parameters for evaluation on an N-D grid.
    Given N one-dimensional coordinate arrays `*args`, returns a list `outputs`
    of N-D coordinate arrays for evaluating expressions on an N-D grid.
    Notes:
    `meshgrid` supports cartesian ('xy') and matrix ('ij') indexing conventions.
    When the `indexing` argument is set to 'xy' (the default), the broadcasting
    instructions for the first two dimensions are swapped.
    Examples:
    Calling `X, Y = meshgrid(x, y)` with the tensors
    ```python
    x = [1, 2, 3]
    y = [4, 5, 6]
    X, Y = meshgrid(x, y)
    # X = [[1, 2, 3],
    #      [1, 2, 3],
    #      [1, 2, 3]]
    # Y = [[4, 4, 4],
    #      [5, 5, 5],
    #      [6, 6, 6]]
    ```
    Args:
    *args: `Tensor`s with rank 1.
    **kwargs:
      - indexing: Either 'xy' or 'ij' (optional, default: 'xy').
      - name: A name for the operation (optional).
    Returns:
    outputs: A list of N `Tensor`s with rank N.
    Raises:
    TypeError: When no keyword arguments (kwargs) are passed.
    ValueError: When indexing keyword argument is not one of `xy` or `ij`.
    """

    indexing = kwargs.pop("indexing", "xy")
    name = kwargs.pop("name", "meshgrid")
    if kwargs:
        key = list(kwargs.keys())[0]
        raise TypeError("'{}' is an invalid keyword argument "
                    "for this function".format(key))

    if indexing not in ("xy", "ij"):
        raise ValueError("indexing parameter must be either 'xy' or 'ij'")

    # with ops.name_scope(name, "meshgrid", args) as name:
    ndim = len(args)
    s0 = (1,) * ndim

    # Prepare reshape by inserting dimensions with size 1 where needed
    output = []
    for i, x in enumerate(args):
        output.append(tf.reshape(tf.stack(x), (s0[:i] + (-1,) + s0[i + 1::])))
    # Create parameters for broadcasting each tensor to the full size
    shapes = [tf.size(x) for x in args]
    sz = [x.get_shape().as_list()[0] for x in args]

    # output_dtype = tf.convert_to_tensor(args[0]).dtype.base_dtype
    if indexing == "xy" and ndim > 1:
        output[0] = tf.reshape(output[0], (1, -1) + (1,) * (ndim - 2))
        output[1] = tf.reshape(output[1], (-1, 1) + (1,) * (ndim - 2))
        shapes[0], shapes[1] = shapes[1], shapes[0]
        sz[0], sz[1] = sz[1], sz[0]

    # This is the part of the implementation from tf that is slow. 
    # We replace it below to get a ~6x speedup (essentially using tile instead of * tf.ones())
    # TODO(nolivia): improve performance with a broadcast  
    # mult_fact = tf.ones(shapes, output_dtype)
    # return [x * mult_fact for x in output]
    for i in range(len(output)):       
        stack_sz = [*sz[:i], 1, *sz[(i+1):]]
        if indexing == 'xy' and ndim > 1 and i < 2:
            stack_sz[0], stack_sz[1] = stack_sz[1], stack_sz[0]
        output[i] = tf.tile(output[i], tf.stack(stack_sz))
    return output


def flatten(v):
    """
    flatten Tensor v
    
    Parameters:
        v: Tensor to be flattened
    
    Returns:
        flat Tensor
    """

    return tf.reshape(v, [-1])

    

    
def prod_n(lst):
    """
    Alternative to tf.stacking and prod, since tf.stacking can be slow
    """
    prod = lst[0]
    for p in lst[1:]:
        prod *= p
    return prod



def gaussian_kernel(sigma, windowsize=None, indexing='ij'):
    """
    sigma will be a number of a list of numbers.

    # some guidance from my MATLAB file 
    https://github.com/adalca/mivt/blob/master/src/gaussFilt.m

    Parameters:
        sigma: scalar or list of scalars
        windowsize (optional): scalar or list of scalars indicating the shape of the kernel
    
    Returns:
        ND kernel the same dimensiosn as the number of sigmas.

    Todo: could use MultivariateNormalDiag
    """

    if not isinstance(sigma, (list, tuple)):
        sigma = [sigma]
    sigma = [np.maximum(f, np.finfo(float).eps) for f in sigma]

    nb_dims = len(sigma)

    # compute windowsize
    if windowsize is None:
        windowsize = [np.round(f * 3) * 2 + 1 for f in sigma]

    if len(sigma) != len(windowsize):
        raise ValueError('sigma and windowsize should have the same length.'
                         'Got vectors: ' + str(sigma) + 'and' + str(windowsize))

    # ok, let's get to work.
    mid = [(w - 1)/2 for w in windowsize]

    # list of volume ndgrid
    # N-long list, each entry of shape volshape
    mesh = volshape_to_meshgrid(windowsize, indexing=indexing)  
    mesh = [tf.cast(f, 'float32') for f in mesh]

    # compute independent gaussians
    diff = [mesh[f] - mid[f] for f in range(len(windowsize))]
    exp_term = [- K.square(diff[f])/(2 * (sigma[f]**2)) for f in range(nb_dims)]
    norms = [exp_term[f] - np.log(sigma[f] * np.sqrt(2 * np.pi)) for f in range(nb_dims)]

    # add an all-ones entry and transform into a large matrix
    norms_matrix = tf.stack(norms, axis=-1)  # *volshape x N
    g = K.sum(norms_matrix, -1)  # volshape
    g = tf.exp(g)
    g /= tf.reduce_sum(g)

    return g




def _softmax(x, axis=-1, alpha=1):
    """
    building on keras implementation, with additional alpha parameter

    Softmax activation function.
    # Arguments
        x : Tensor.
        axis: Integer, axis along which the softmax normalization is applied.
        alpha: a value to multiply all x
    # Returns
        Tensor, output of softmax transformation.
    # Raises
        ValueError: In case `dim(x) == 1`.
    """
    x = alpha * x
    ndim = K.ndim(x)
    if ndim == 2:
        return K.softmax(x)
    elif ndim > 2:
        e = K.exp(x - K.max(x, axis=axis, keepdims=True))
        s = K.sum(e, axis=axis, keepdims=True)
        return e / s
    else:
        raise ValueError('Cannot apply softmax to a tensor that is 1D')



def stack_models(models, connecting_node_ids=None):
    """
    stacks keras models sequentially without nesting the models into layers
        (the nominal behaviour in keras as of 1/13/2018 is to nest models)
    This preserves the layers (i.e. does not copy layers). This means that if you modify the
    original layer weights, you are automatically affecting the new stacked model.

    Parameters:
        models: a list of models, in order of: [input_model, second_model, ..., final_output_model]
        connecting_node_ids (optional): a list of connecting node pointers from Nth model to N+1th model

    Returns:
        new stacked model pointer
    """

    output_tensors = models[0].outputs
    stacked_inputs = [*models[0].inputs]

    # go through models 1 onwards and stack with current graph
    for mi in range(1, len(models)):
        
        # prepare input nodes - a combination of 
        new_input_nodes = list(models[mi].inputs)
        stacked_inputs_contrib = list(models[mi].inputs)

        if connecting_node_ids is None: 
            conn_id = list(range(len(new_input_nodes)))
            assert len(new_input_nodes) == len(models[mi-1].outputs), \
                'argument count does not match'
        else:
            conn_id = connecting_node_ids[mi-1]

        for out_idx, ii in enumerate(conn_id):
            new_input_nodes[ii] = output_tensors[out_idx]
            stacked_inputs_contrib[ii] = None
        
        output_tensors = mod_submodel(models[mi], new_input_nodes=new_input_nodes)
        stacked_inputs = stacked_inputs + stacked_inputs_contrib

    stacked_inputs_ = [i for i in stacked_inputs if i is not None]
    # check for unique, but keep order:
    stacked_inputs = []
    for inp in stacked_inputs_:
        if inp not in stacked_inputs:
            stacked_inputs.append(inp)
    new_model = keras.models.Model(stacked_inputs, output_tensors)
    return new_model


def mod_submodel(orig_model,
                 new_input_nodes=None,
                 input_layers=None):
    """
    modify (cut and/or stitch) keras submodel

    layer objects themselved will be untouched - the new model, even if it includes, 
    say, a subset of the previous layers, those layer objects will be shared with
    the original model

    given an original model:
        model stitching: given new input node(s), get output tensors of having pushed these 
        nodes through the model
        
        model cutting: given input layer (pointers) inside the model, the new input nodes
        will match the new input layers, hence allowing cutting the model

    Parameters:
        orig_model: original keras model pointer
        new_input_nodes: a pointer to a new input node replacement
        input_layers: the name of the layer in the original model to replace input nodes
    
    Returns:
        pointer to modified model
    """

    def _layer_dependency_dict(orig_model):
        """
        output: a dictionary of all layers in the orig_model
        for each layer:
            dct[layer] is a list of lists of layers.
        """

        if hasattr(orig_model, 'output_layers'):
            out_layers = orig_model.output_layers
            out_node_idx = orig_model.output_layers_node_indices
            node_list = [ol._inbound_nodes[out_node_idx[i]] for i, ol in enumerate(out_layers)]

        else:
            out_layers = orig_model._output_layers
            
            node_list = []
            for i, ol in enumerate(orig_model._output_layers):
                node_list += ol._inbound_nodes
            node_list  = list(set(node_list ))
            
        dct = {}
        dct_node_idx = {}
        while len(node_list) > 0:
            node = node_list.pop(0)
            node_input_layers = node.inbound_layers
            node_indices = node.node_indices
            if not isinstance(node_input_layers, (list, tuple)):
                node_input_layers = [node_input_layers]
                node_indices = [node_indices]
                
            add = True
            # if not empty. we need to check that we're not adding the same layers through the same node.
            if len(dct.setdefault(node.outbound_layer, [])) > 0:
                for li, layers in enumerate(dct[node.outbound_layer]):
                    if layers == node.inbound_layers and \
                        dct_node_idx[node.outbound_layer][li] == node_indices:
                        add = False
                        break
            if add:
                dct[node.outbound_layer].append(node_input_layers)
                dct_node_idx.setdefault(node.outbound_layer, []).append(node_indices)
            # append is in place

            # add new node
            
            for li, layer in enumerate(node_input_layers):
                if hasattr(layer, '_inbound_nodes'):
                    node_list.append(layer._inbound_nodes[node_indices[li]])
            
        return dct

    def _get_new_layer_output(layer, new_layer_outputs, inp_layers):
        """
        (recursive) given a layer, get new outbound_nodes based on new inbound_nodes

        new_layer_outputs is a (reference) dictionary that we will be adding
        to within the recursion stack.
        """

        if layer not in new_layer_outputs:

            if layer not in inp_layers:
                raise Exception('layer %s is not in inp_layers' % layer.name)

            # for all input layers to this layer, gather their output (our input)
            for group in inp_layers[layer]:
                input_nodes = [None] * len(group)
                for li, inp_layer in enumerate(group):
                    if inp_layer in new_layer_outputs:
                        input_nodes[li] = new_layer_outputs[inp_layer]
                    else: # recursive call
                        input_nodes[li] = _get_new_layer_output(inp_layer, new_layer_outputs, inp_layers)

                # layer call
                if len(input_nodes) == 1:
                    new_layer_outputs[layer] = layer(*input_nodes)
                else:
                    new_layer_outputs[layer] = layer(input_nodes)

        return new_layer_outputs[layer]



    # for each layer create list of input layers
    inp_layers = _layer_dependency_dict(orig_model)

    # get input layers
    #   These layers will be 'ignored' in that they will not be called!
    #   instead, the outbound nodes of the layers will be the input nodes
    #   computed below or passed in
    if input_layers is None: # if none provided, search for them
        # InputLayerClass = keras.engine.topology.InputLayer
        InputLayerClass = type(tf.keras.layers.InputLayer())
        input_layers = [l for l in orig_model.layers if isinstance(l, InputLayerClass)]

    else:
        if not isinstance(input_layers, (tuple, list)):
            input_layers = [input_layers]
        for idx, input_layer in enumerate(input_layers):
            # if it's a string, assume it's layer name, and get the layer pointer
            if isinstance(input_layer, str):
                input_layers[idx] = orig_model.get_layer(input_layer)

    # process new input nodes
    if new_input_nodes is None:
        input_nodes = list(orig_model.inputs)
    else:
        input_nodes = new_input_nodes
    assert len(input_nodes) == len(input_layers), 'input_nodes (%d) and input_layers (%d) have to match' % (len(input_nodes), len(input_layers))

    # initialize dictionary of layer:new_output_node
    #   note: the input layers are not called, instead their outbound nodes
    #   are assumed to be the given input nodes. If we call the nodes, we can run
    #   into multiple-inbound-nodes problems, or if we completely skip the layers altogether
    #   we have problems with multiple inbound input layers into subsequent layers
    new_layer_outputs = {}
    for i, input_layer in enumerate(input_layers):
        new_layer_outputs[input_layer] = input_nodes[i]

    # recursively go back from output layers and request new input nodes
    output_layers = []
    for layer in orig_model.layers:
        if hasattr(layer, '_inbound_nodes'):
            for i in range(len(layer._inbound_nodes)):
                if layer.get_output_at(i) in orig_model.outputs:
                    output_layers.append(layer)
                    break
    assert len(output_layers) == len(orig_model.outputs), "Number of output layers don't match"

    outputs = [None] * len(output_layers)
    for li, output_layer in enumerate(output_layers):
        outputs[li] = _get_new_layer_output(output_layer, new_layer_outputs, inp_layers)

    return outputs


def reset_weights(model, session=None):
    """
    reset weights of model with the appropriate initializer.
    Note: only uses "kernel_initializer" and "bias_initializer"
    does not close session.

    Reference:
    https://www.codementor.io/nitinsurya/how-to-re-initialize-keras-model-weights-et41zre2g

    Parameters:
        model: keras model to reset
        session (optional): the current session
    """

    if session is None:
        session = K.get_session()

    for layer in model.layers: 
        reset = False
        if hasattr(layer, 'kernel_initializer'):
            layer.kernel.initializer.run(session=session)
            reset = True
        
        if hasattr(layer, 'bias_initializer'):
            layer.bias.initializer.run(session=session)
            reset = True
        
        if not reset:
            print('Could not find initializer for layer %s. skipping', layer.name)


def copy_model_weights(src_model, dst_model):
    """
    copy weights from the src keras model to the dst keras model via layer names

    Parameters:
        src_model: source keras model to copy from
        dst_model: destination keras model to copy to
    """

    for layer in tqdm(dst_model.layers):
        try:
            wts = src_model.get_layer(layer.name).get_weights()
            layer.set_weights(wts)
        except:
            print('Could not copy weights of %s' % layer.name)
            continue


def robust_multi_gpu_model(model, gpus, verbose=True):
    """
    re-work keras model for multi-gpus if number of gpus is > 1

    Parameters:
        model: keras Model
        gpus: list of gpus to split to (e.g. [1, 4, 6]), or count of gpus available (e.g. 3)
            Note: if given int, assume that is the count of gpus, 
            so if you want a single specific gpu, this function will not do that.
        verbose: whether to display what happened (default: True)
    
    Returns:
        keras model
    """

    islist = isinstance(gpus, (list, tuple))
    if (islist and len(gpus) > 1) or (not islist and gpus > 1):
        count = gpus if not islist else len(gpus)
        print("Returning multi-gpu (%d) model" % count)
        return keras.utils.multi_gpu_model(model, count)

    else:
        print("Returning keras model back (single gpu found)")
        return model







def logtanh(x, a=1):
    """
    log * tanh

    See Also: arcsinh
    """
    return K.tanh(x) *  K.log(2 + a * abs(x))


def arcsinh(x, alpha=1):
    """
    asignh

    See Also: logtanh
    """
    return tf.asinh(x * alpha) / alpha


def logistic(x, x0=0., alpha=1., L=1.):
    """
    returns L/(1+exp(-alpha * (x-x0)))
    """
    assert L > 0, 'L (height of logistic) should be > 0'
    assert alpha > 0, 'alpha (slope) of logistic should be > 0'
    
    return L / (1 + tf.exp(-alpha * (x-x0)))

def sigmoid(x):
    return logistic(x, x0=0., alpha=1., L=1.)

def logistic_fixed_ends(x, start=-1., end=1., L=1., **kwargs):
    """
    f is logistic with fixed ends, so that f(start) = 0, and f(end) = L.
    this is currently done a bit heuristically: it's a sigmoid, with a linear function added to correct the ends.
    """
    assert end > start, 'End of fixed points should be greater than start'
    # tf.assert_greater(end, start, message='assert')
    
    # clip to start and end
    x = tf.clip_by_value(x, start, end)
    
    # logistic function
    xv = logistic(x, L=L, **kwargs)
    
    # ends of linear corrective function
    sv = logistic(start, L=L, **kwargs)
    ev = logistic(end, L=L, **kwargs)
    
    # corrective function
    df = end - start
    linear_corr = (end-x)/df * (- sv) + (x-start)/df * (-ev + L)
    
    # return fixed logistic
    return xv + linear_corr

def sigmoid_fixed_ends(x, start=-1., end=1., L=1., **kwargs):
    return logistic_fixed_ends(x, start=-1., end=1., L=1., x0=0., alpha=1.)

def soft_round(x, alpha=25):
    fx = tf.floor(x)
    xd = x - fx
    return fx + logistic_fixed_ends(xd, start=0., end=1., x0=0.5, alpha=alpha)

def soft_delta(x, x0=0., alpha=100, reg='l1'):
    """
    recommended defaults:
    alpha = 100 for l1
    alpha = 1000 for l2
    """
    if reg == 'l1':
        xa = tf.abs(x - x0)
    else:
        assert reg == 'l2'
        xa = tf.square(x - x0)
    return (1 - logistic(xa, alpha=alpha)) * 2


def odd_shifted_relu(x, shift=-0.5, scale=2.0):
    """
    Odd shifted ReLu
    Essentially in x > 0, it is a shifted ReLu, and in x < 0 it's a negative mirror. 
    """

    shift = float(shift)
    scale = float(scale)
    return scale * K.relu(x - shift)  - scale * K.relu(- x - shift)






def predict_volumes(models,
                    data_generator,
                    batch_size,
                    patch_size,
                    patch_stride,
                    grid_size,
                    nan_func=np.nanmedian,
                    do_extra_vol=False,  # should compute vols beyond label
                    do_prob_of_true=False,  # should compute prob_of_true vols
                    verbose=False):
    """
    Note: we allow models to be a list or a single model.
    Normally, if you'd like to run a function over a list for some param,
    you can simply loop outside of the function. here, however, we are dealing with a generator,
    and want the output of that generator to be consistent for each model.

    Returns:
    if models isa list of more than one model:
        a tuple of model entried, each entry is a tuple of:
        true_label, pred_label, <vol>, <prior_label>, <pred_prob_of_true>, <prior_prob_of_true>
    if models is just one model:
        a tuple of
        (true_label, pred_label, <vol>, <prior_label>, <pred_prob_of_true>, <prior_prob_of_true>)

    TODO: could add prior
    """

    if not isinstance(models, (list, tuple)):
        models = (models,)

    # get the input and prediction stack
    with timer.Timer('predict_volume_stack', verbose):
        vol_stack = predict_volume_stack(models,
                                         data_generator,
                                         batch_size,
                                         grid_size,
                                         verbose)
    if len(models) == 1:
        do_prior = len(vol_stack) == 4
    else:
        do_prior = len(vol_stack[0]) == 4

    # go through models and volumes
    ret = ()
    for midx, _ in enumerate(models):

        stack = vol_stack if len(models) == 1 else vol_stack[midx]

        if do_prior:
            all_true, all_pred, all_vol, all_prior = stack
        else:
            all_true, all_pred, all_vol = stack

        # get max labels
        all_true_label, all_pred_label = pred_to_label(all_true, all_pred)

        # quilt volumes and aggregate overlapping patches, if any
        args = [patch_size, grid_size, patch_stride]
        label_kwargs = {'nan_func_layers':nan_func, 'nan_func_K':nan_func, 'verbose':verbose}
        vol_true_label = _quilt(all_true_label, *args, **label_kwargs).astype('int')
        vol_pred_label = _quilt(all_pred_label, *args, **label_kwargs).astype('int')

        ret_set = (vol_true_label, vol_pred_label)

        if do_extra_vol:
            vol_input = _quilt(all_vol, *args)
            ret_set += (vol_input, )

            if do_prior:
                all_prior_label, = pred_to_label(all_prior)
                vol_prior_label = _quilt(all_prior_label, *args, **label_kwargs).astype('int')
                ret_set += (vol_prior_label, )

        # compute the probability of prediction and prior
        # instead of quilting the probabilistic volumes and then computing the probability
        # of true label, which takes a long time, we'll first compute the probability of label,
        # and then quilt. This is faster, but we'll need to take median votes
        if do_extra_vol and do_prob_of_true:
            all_pp = prob_of_label(all_pred, all_true_label)
            pred_prob_of_true = _quilt(all_pp, *args, **label_kwargs)
            ret_set += (pred_prob_of_true, )

            if do_prior:
                all_pp = prob_of_label(all_prior, all_true_label)
                prior_prob_of_true = _quilt(all_pp, *args, **label_kwargs)

                ret_set += (prior_prob_of_true, )

        ret += (ret_set, )

    if len(models) == 1:
        ret = ret[0]

    # return
    return ret


def predict_volume_stack(models,
                         data_generator,
                         batch_size,
                         grid_size,
                         verbose=False):
    """
    predict all the patches in a volume

    requires batch_size to be a divisor of the number of patches (prod(grid_size))

    Note: we allow models to be a list or a single model.
    Normally, if you'd like to run a function over a list for some param,
    you can simply loop outside of the function. here, however, we are dealing with a generator,
    and want the output of that generator to be consistent for each model.

    Returns:
    if models isa list of more than one model:
        a tuple of model entried, each entry is a tuple of:
        all_true, all_pred, all_vol, <all_prior>
    if models is just one model:
        a tuple of
        all_true, all_pred, all_vol, <all_prior>
    """

    if not isinstance(models, (list, tuple)):
        models = (models,)

    # compute the number of batches we need for one volume
    # we need the batch_size to be a divisor of nb_patches,
    # in order to loop through batches and form full volumes
    nb_patches = np.prod(grid_size)
    # assert np.mod(nb_patches, batch_size) == 0, \
        # "batch_size %d should be a divisor of nb_patches %d" %(batch_size, nb_patches)
    nb_batches = ((nb_patches - 1) // batch_size) + 1

    # go through the patches
    batch_gen = tqdm(range(nb_batches)) if verbose else range(nb_batches)
    for batch_idx in batch_gen:
        sample = next(data_generator)
        nb_vox = np.prod(sample[1].shape[1:-1])
        do_prior = isinstance(sample[0], (list, tuple))

        # pre-allocate all the data
        if batch_idx == 0:
            nb_labels = sample[1].shape[-1]
            all_vol = [np.zeros((nb_patches, nb_vox)) for f in models]
            all_true = [np.zeros((nb_patches, nb_vox * nb_labels)) for f in models]
            all_pred = [np.zeros((nb_patches, nb_vox * nb_labels)) for f in models]
            all_prior = [np.zeros((nb_patches, nb_vox * nb_labels)) for f in models]

        # get in_vol, y_true, y_pred
        for idx, model in enumerate(models):
            # with timer.Timer('prediction', verbose):
            pred = model.predict(sample[0])
            assert pred.shape[0] == batch_size, \
                "batch size mismatch. sample has batch size %d, given batch size is %d" %(pred.shape[0], batch_size)
            input_batch = sample[0] if not do_prior else sample[0][0]

            # compute batch range
            batch_start = batch_idx * batch_size
            batch_end = np.minimum(batch_start + batch_size, nb_patches)
            batch_range = np.arange(batch_start, batch_end)
            batch_vox_idx = batch_end-batch_start

            # update stacks
            all_vol[idx][batch_range, :] = K.batch_flatten(input_batch)[0:batch_vox_idx, :]
            all_true[idx][batch_range, :] = K.batch_flatten(sample[1])[0:batch_vox_idx, :]
            all_pred[idx][batch_range, :] = K._batch_flatten(pred)[0:batch_vox_idx, :]
            if do_prior:
                all_prior[idx][batch_range, :] = K.batch_flatten(sample[0][1])[0:batch_vox_idx, :]

    # reshape probabilistic answers
    for idx, _ in enumerate(models):
        all_true[idx] = np.reshape(all_true[idx], [nb_patches, nb_vox, nb_labels])
        all_pred[idx] = np.reshape(all_pred[idx], [nb_patches, nb_vox, nb_labels])
        if do_prior:
            all_prior[idx] = np.reshape(all_prior[idx], [nb_patches, nb_vox, nb_labels])

    # prepare output tuple
    ret = ()
    for midx, _ in enumerate(models):
        if do_prior:
            ret += ((all_true[midx], all_pred[midx], all_vol[midx], all_prior[midx]), )
        else:
            ret += ((all_true[midx], all_pred[midx], all_vol[midx]), )

    if len(models) == 1:
        ret = ret[0]
    return ret


def prob_of_label(vol, labelvol):
    """
    compute the probability of the labels in labelvol in each of the volumes in vols

    Parameters:
        vol (float numpy array of dim (nd + 1): volume with a prob dist at each voxel in a nd vols
        labelvol (int numpy array of dim nd): nd volume of labels

    Returns:
        nd volume of probabilities
    """

    # check dimensions
    nb_dims = np.ndim(labelvol)
    assert np.ndim(vol) == nb_dims + 1, "vol dimensions do not match [%d] vs [%d]" % (np.ndim(vol)-1, nb_dims)
    shp = vol.shape
    nb_voxels = np.prod(shp[0:nb_dims])
    nb_labels = shp[-1]

    # reshape volume to be [nb_voxels, nb_labels]
    flat_vol = np.reshape(vol, (nb_voxels, nb_labels))

    # normalize accross second dimension
    rows_sums = flat_vol.sum(axis=1)
    flat_vol_norm = flat_vol / rows_sums[:, np.newaxis]

    # index into the flattened volume
    idx = list(range(nb_voxels))
    v = flat_vol_norm[idx, labelvol.flat]
    return np.reshape(v, labelvol.shape)


def next_pred_label(model, data_generator, verbose=False):
    """
    predict the next sample batch from the generator, and compute max labels
    return sample, prediction, max_labels
    """
    sample = next(data_generator)
    with timer.Timer('prediction', verbose):
        pred = model.predict(sample[0])
    sample_input = sample[0] if not isinstance(sample[0], (list, tuple)) else sample[0][0]
    max_labels = pred_to_label(sample_input, pred)
    return (sample, pred) + max_labels

def next_label(model, data_generator):
    """
    predict the next sample batch from the generator, and compute max labels
    return max_labels
    """
    batch_proc = next_pred_label(model, data_generator)
    return (batch_proc[2], batch_proc[3])

def sample_to_label(model, sample):
    """
    redict a sample batch and compute max labels
    return max_labels
    """
    # predict output for a new sample
    res = model.predict(sample[0])
    # return
    return pred_to_label(sample[1], res)

def pred_to_label(*y):
    """
    return the true and predicted labels given true and predicted nD+1 volumes
    """
    # compute resulting volume(s)
    return tuple(np.argmax(f, -1).astype(int) for f in y)

def next_vol_pred(model, data_generator, verbose=False):
    """
    get the next batch, predict model output

    returns (input_vol, y_true, y_pred, <prior>)
    """

    # batch to input, output and prediction
    sample = next(data_generator)
    with timer.Timer('prediction', verbose):
        pred = model.predict(sample[0])
    data = (sample[0], sample[1], pred)
    if isinstance(sample[0], (list, tuple)):  # if given prior, might be a list
        data = (sample[0][0], sample[1], pred, sample[0][1])

    return data







def sub2ind(siz, subs, **kwargs):
    """
    assumes column-order major
    """
    # subs is a list
    assert len(siz) == len(subs), \
        'found inconsistent siz and subs: %d %d' % (len(siz), len(subs))

    k = np.cumprod(siz[::-1])

    ndx = subs[-1]
    for i, v in enumerate(subs[:-1][::-1]):
        ndx = ndx + v * k[i]

    return ndx





###############################################################################
# functions from some external source
###############################################################################

def batch_gather(reference, indices):
    """
    C+P From Keras pull request https://github.com/keras-team/keras/pull/6377/files
    
    Batchwise gathering of row indices.

    The numpy equivalent is `reference[np.arange(batch_size), indices]`, where
    `batch_size` is the first dimension of the reference tensor.

    # Arguments
        reference: A tensor with ndim >= 2 of shape.
          (batch_size, dim1, dim2, ..., dimN)
        indices: A 1d integer tensor of shape (batch_size) satisfying
          0 <= i < dim2 for each element i.

    # Returns
        The selected tensor with shape (batch_size, dim2, ..., dimN).

    # Examples
        1. If reference is `[[3, 5, 7], [11, 13, 17]]` and indices is `[2, 1]`
        then the result is `[7, 13]`.

        2. If reference is
        ```
          [[[2, 3], [4, 5], [6, 7]],
           [[10, 11], [12, 13], [16, 17]]]
        ```
        and indices is `[2, 1]` then the result is `[[6, 7], [12, 13]]`.
    """
    batch_size = K.shape(reference)[0]
    indices = tf.stack([tf.range(batch_size), indices], axis=1)
    return tf.gather_nd(reference, indices)


def model_diagram(model):
    outfile = NamedTemporaryFile().name + '.png'
    plot_model(model, to_file=outfile, show_shapes=True)
    Image(outfile, width=100)






def perlin_vol(vol_shape, min_scale=0, max_scale=None, interp_method='linear', wt_type='monotonic'):
    """
    generate perlin noise ND volume 

    rough algorithm:
    
    vol = zeros
    for scale in scales:
        rand = generate random uniform noise at given scale
        vol += wt * upsampled rand to vol_shape 
        

    Parameters
    ----------
    vol_shape: list indicating input shape.
    min_scale: higher min_scale = less high frequency noise
      the minimum rescale vol_shape/(2**min_scale), min_scale of 0 (default) 
      means start by not rescaling, and go down.
    max_scale: maximum scale, if None computes such that smallest volume shape is [1]
    interp_order: interpolation (upscale) order, as used in ne.utils.zoom
    wt_type: the weight type between volumes. default: monotonically decreasing with image size.
      options: 'monotonic', 'random'
    
    https://github.com/adalca/matlib/blob/master/matlib/visual/perlin.m
    loosely inspired from http://nullprogram.com/blog/2007/11/20
    """

    # input handling
    assert wt_type in ['monotonic', 'random'], \
        "wt_type should be in 'monotonic', 'random', got: %s"  % wt_type

    if max_scale is None:
        max_width = np.max(vol_shape)
        max_scale = np.ceil(np.log2(max_width)).astype('int')

    # decide on scales:
    scale_shapes = []
    wts = []
    for i in range(min_scale, max_scale + 1):
        scale_shapes.append(np.ceil([f / (2**i) for f in vol_shape]).astype('int'))
    
        # determine weight
        if wt_type == 'monotonic':
            wts.append(i + 1)  # larger images (so more high frequencies) get lower weight
        else:
            wts.append(K.random_uniform([1])[0])

    wts = K.stack(wts)/K.sum(wts)
    wts = tf.cast(wts, tf.float32)


    # get perlin volume
    vol = K.zeros(vol_shape)
    for sci, sc in enumerate(scale_shapes):

        # get a small random volume
        rand_vol = K.random_uniform(sc)
        
        # interpolated rand volume to upper side
        reshape_factor = [vol_shape[d]/sc[d] for d in range(len(vol_shape))]
        interp_vol = zoom(rand_vol, reshape_factor, interp_method=interp_method)[..., 0]

        # add to existing volume
        vol = vol + wts[sci] * interp_vol
        
    return vol


###############################################################################
# helper functions
###############################################################################

def _concat(lists, dim):
    if lists[0].size == 0:
        lists = lists[1:]

    return np.concatenate(lists, dim)

def _quilt(patches, patch_size, grid_size, patch_stride, verbose=False, **kwargs):
    assert len(patches.shape) >= 2, "patches has bad shape %s" % pformat(patches.shape)

    # reshape to be [nb_patches x nb_vox]
    patches = np.reshape(patches, (patches.shape[0], -1, 1))

    # quilt
    quilted_vol = pl.quilt(patches, patch_size, grid_size, patch_stride=patch_stride, **kwargs)
    assert quilted_vol.ndim == len(patch_size), "problem with dimensions after quilt"

    # return
    return quilted_vol


# TO MOVE (numpy softmax)
def softmax(x, axis):
    """
    softmax of a numpy array along a given dimension
    """

    return np.exp(x) / np.sum(np.exp(x), axis=axis, keepdims=True)