# Based on
# https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
import os

import numpy as np
import scipy.misc
import scipy.ndimage as ndi
from skimage.color import rgb2gray, gray2rgb
from skimage import img_as_float


def optical_flow(seq, rows_idx, cols_idx, chan_idx, return_rgb=False):
    '''Dense optical flow of a sequence of frames

    Computes the Farneback dense optical flow between each pair of
    consecutive frames of a sequence. The flow associated to the first
    frame is left to zero, since it has no previous frame.

    Parameters
    ----------
    seq: numpy ndarray
        A 4D array with one sequence of frames. The values are
        multiplied by 255 and cast to uint8, so they are presumably
        expected to be in [0, 1]. The frames are converted from BGR to
        gray, so they are expected to have color channels.
    rows_idx: int
        The index of the rows axis.
    cols_idx: int
        The index of the cols axis.
    chan_idx: int
        The index of the channel axis.
    return_rgb: bool
        If True an RGB representation of the flow is returned (the
        direction is encoded in the hue, the magnitude in the value).
        If False two channels are returned: angle and magnitude.

    Returns
    -------
    numpy ndarray
        A 4D array with the same axis order of `seq`, divided by 255.
        The channel axis has size 3 if `return_rgb` is True, else 2
        (angle, which ends up in [0, 1], and magnitude / 255).

    Raises
    ------
    RuntimeError
        If `seq` is not 4D or contains a single frame.
    '''
    if seq.ndim != 4:
        raise RuntimeError('Optical flow expected 4 dimensions, got %d' %
                           seq.ndim)
    # The multiplication allocates a new array: the input is not modified
    seq = (seq * 255).astype('uint8')
    # Reshape to channel last: (seq, 0, 1, ch)
    pattern = [el for el in range(seq.ndim)
               if el not in (rows_idx, cols_idx, chan_idx)]
    pattern += [rows_idx, cols_idx, chan_idx]
    inv_pattern = [pattern.index(el) for el in range(seq.ndim)]
    seq = seq.transpose(pattern)
    if seq.shape[0] == 1:
        raise RuntimeError('Optical flow needs a sequence longer than 1 '
                           'to work')

    # Imported only once the input is known to be valid, so that the
    # argument errors do not depend on OpenCV being installed
    import cv2
    seq = seq[..., ::-1]  # Go BGR for OpenCV

    frame1 = seq[0]
    if return_rgb:
        flow_seq = np.zeros_like(seq)
        hsv = np.zeros_like(frame1)
    else:
        sh = list(seq.shape)
        sh[-1] = 2  # angle and magnitude
        flow_seq = np.zeros(sh)

    frame1 = cv2.cvtColor(frame1, cv2.COLOR_BGR2GRAY)  # Go to gray

    flow = None
    for i, frame2 in enumerate(seq[1:]):
        frame2 = cv2.cvtColor(frame2, cv2.COLOR_BGR2GRAY)  # Go to gray
        flow = cv2.calcOpticalFlowFarneback(prev=frame1,
                                            next=frame2,
                                            pyr_scale=0.5,
                                            levels=3,
                                            winsize=10,
                                            iterations=3,
                                            poly_n=5,
                                            poly_sigma=1.1,
                                            flags=0,
                                            flow=flow)
        mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1],
                                   angleInDegrees=True)
        # normalize between 0 and 255
        ang = ang / 360 * 255
        if return_rgb:
            hsv[..., 0] = ang
            hsv[..., 1] = 255
            hsv[..., 2] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
            # The hue spans the full [0, 255] range here, while the plain
            # 8-bit HSV conversion expects it in [0, 180]: use the _FULL
            # variant so that different directions do not share a color
            rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB_FULL)
            flow_seq[i+1] = rgb
        else:
            flow_seq[i+1] = np.stack((ang, mag), 2)
        frame1 = frame2
    flow_seq = flow_seq.transpose(inv_pattern)
    return flow_seq / 255.  # return in [0, 1]


def my_label2rgb(labels, cmap, bglabel=None, bg_color=(0., 0., 0.)):
    '''Paint a label mask with the colors of a color map

    Every pixel whose label is `i` receives the color `cmap[i]`, for
    each `i` smaller than `len(cmap)`. The pixels labelled `bglabel`
    (if given) receive `bg_color` instead. Any other pixel is left
    black. The result is a float64 array with a trailing axis of
    size 3.'''
    rgb = np.zeros(labels.shape + (3,), dtype=np.float64)
    # cmap is indexed by position rather than iterated, so any
    # container supporting len() and integer lookup can be used
    for label_id in range(len(cmap)):
        if label_id == bglabel:
            continue  # the background is painted below
        rgb[labels == label_id] = cmap[label_id]
    if bglabel is not None:
        rgb[labels == bglabel] = bg_color
    return rgb


def my_label2rgboverlay(labels, cmap, image, bglabel=None,
                        bg_color=(0., 0., 0.), alpha=0.2):
    '''Blend a colored label mask with a gray version of an image

    The image is converted to gray scale (and back to three channels)
    and the label mask is colored with `my_label2rgb`. The two are
    then blended: `alpha` is the weight of the gray image and
    `1 - alpha` the weight of the colored mask.'''
    # Drop the colors of the image so that they do not get confused
    # with the colors of the mask
    gray = rgb2gray(image)
    image_float = gray2rgb(img_as_float(gray))
    colored_mask = my_label2rgb(labels, cmap, bglabel=bglabel,
                                bg_color=bg_color)
    return image_float * alpha + colored_mask * (1 - alpha)


def save_img2(x, y, fname, cmap, void_label, rows_idx, cols_idx,
              chan_idx):
    '''Save an image next to its mask overlay

    Takes the first element along the leading axis of `x` (and `y`)
    and writes it to `fname`. When labels are provided the image is
    concatenated along axis 1 with a copy of itself on which the
    colored mask is superimposed (see `my_label2rgboverlay`);
    otherwise the image alone is saved.'''
    image_axes = [rows_idx, cols_idx, chan_idx]
    other_axes = [ax for ax in range(x.ndim) if ax not in image_axes]
    # Everything that is not rows/cols/channels goes first
    axes_order = other_axes + image_axes
    has_mask = y is not None and len(y) > 0

    x_view = x.transpose(axes_order)
    y_view = y.transpose(axes_order) if has_mask else None

    # Only the first element is saved
    first_image = x_view[0]
    if has_mask:
        # First element, without the extra trailing dim
        first_mask = y_view[0, ..., 0]
        overlay = my_label2rgboverlay(first_mask,
                                      cmap=cmap,
                                      image=first_image,
                                      bglabel=void_label,
                                      alpha=0.2)
        canvas = np.concatenate((first_image, overlay), axis=1)
    else:
        canvas = first_image
    scipy.misc.toimage(canvas).save(fname)


def transform_matrix_offset_center(matrix, x, y):
    '''Make a transformation act around the center of the image

    Conjugates `matrix` with a translation by
    `(x / 2 + 0.5, y / 2 + 0.5)`: the coordinates are first moved so
    that this point becomes the origin, then transformed, then moved
    back. Returns the resulting 3x3 homogeneous matrix.'''
    center_x = float(x) / 2 + 0.5
    center_y = float(y) / 2 + 0.5
    to_center = np.array([[1, 0, center_x],
                          [0, 1, center_y],
                          [0, 0, 1]])
    from_center = np.array([[1, 0, -center_x],
                            [0, 1, -center_y],
                            [0, 0, 1]])
    centered = np.dot(to_center, matrix)
    return np.dot(centered, from_center)


def apply_transform(x, transform_matrix, fill_mode='nearest', cval=0.,
                    order=0, rows_idx=1, cols_idx=2):
    '''Apply an affine transformation on each 2D plane separately.

    Parameters
    ----------
    x: numpy ndarray
        The array to be transformed. It is never modified.
    transform_matrix: numpy ndarray
        A 3x3 homogeneous matrix. Its upper-left 2x2 block is used as
        the affine matrix and the first two entries of its last column
        as the offset.
    fill_mode: string
        How to fill the points outside of the boundaries of the input
        (the `mode` argument of `scipy.ndimage.affine_transform`).
    cval: float
        The fill value used when `fill_mode` is 'constant'.
    order: int
        The order of the spline interpolation.
    rows_idx: int
        The index of the rows axis.
    cols_idx: int
        The index of the cols axis.

    Returns
    -------
    numpy ndarray
        A new array with the same shape, axis order and dtype of `x`.
    '''
    final_affine_matrix = transform_matrix[:2, :2]
    final_offset = transform_matrix[:2, 2]

    # Reshape to (*, 0, 1)
    pattern = [el for el in range(x.ndim) if el not in (rows_idx, cols_idx)]
    pattern += [rows_idx, cols_idx]
    inv_pattern = [pattern.index(el) for el in range(x.ndim)]
    x = x.transpose(pattern)
    x_shape = list(x.shape)
    planes = x.reshape([-1] + x_shape[-2:])  # squash all on the first axis

    # `planes` can be a view on the caller's data: write the result in a
    # separate array so that the input is never modified in place
    warped = np.empty_like(planes)

    # Apply the transformation on each channel, sequence, batch, ..
    for i in range(planes.shape[0]):
        # The assignment casts the result back to the dtype of the input
        warped[i] = ndi.affine_transform(planes[i], final_affine_matrix,
                                         final_offset, order=order,
                                         mode=fill_mode, cval=cval)
    warped = warped.reshape(x_shape)  # unsquash
    return warped.transpose(inv_pattern)


def random_channel_shift(x, shift_range, rows_idx, cols_idx, chan_idx):
    '''Randomly shift the intensity values of each 2D plane.

    Every (rows, cols) plane of `x`, i.e., every channel of every
    element along the remaining axes, is shifted by its own random
    value drawn uniformly in [-shift_range, shift_range]. The result is
    clipped to the range of values of the original `x`.

    Parameters
    ----------
    x: numpy ndarray
        The array to be shifted. It is never modified.
    shift_range: float
        The maximum absolute value of the shift.
    rows_idx: int
        The index of the rows axis.
    cols_idx: int
        The index of the cols axis.
    chan_idx: int
        The index of the channel axis.

    Returns
    -------
    numpy ndarray
        A new array with the same shape, axis order and dtype of `x`.
    '''
    pattern = [chan_idx]
    pattern += [el for el in range(x.ndim)
                if el not in (rows_idx, cols_idx, chan_idx)]
    pattern += [rows_idx, cols_idx]
    inv_pattern = [pattern.index(el) for el in range(x.ndim)]
    x = x.transpose(pattern)  # channel first
    if x.size == 0:
        # Nothing to shift (and no min/max to clip to)
        return x.transpose(inv_pattern).copy()
    x_shape = list(x.shape)
    # squash rows and cols together and everything else on the 1st
    planes = x.reshape((-1, x_shape[-2] * x_shape[-1]))

    # The clipping range is the one of the original data: compute it
    # once, before any plane is shifted
    min_x, max_x = np.min(planes), np.max(planes)

    # `planes` can be a view on the caller's data: write the result in a
    # separate array so that the input is never modified in place
    shifted = np.empty_like(planes)
    # Loop on the channels/batches/etc
    for i in range(planes.shape[0]):
        shift = np.random.uniform(-shift_range, shift_range)
        shifted[i] = np.clip(planes[i] + shift, min_x, max_x)
    shifted = shifted.reshape(x_shape)  # unsquash
    return shifted.transpose(inv_pattern)


def flip_axis(x, flipping_axis):
    '''Flip an axis by inverting the position of its elements

    Parameters
    ----------
    x: numpy ndarray
        The array to be flipped.
    flipping_axis: int
        The index of the axis to flip. Negative indices count from the
        last axis.

    Returns
    -------
    numpy ndarray
        A view of `x` with the elements of `flipping_axis` reversed.

    Raises
    ------
    ValueError
        If `flipping_axis` is not a valid axis of `x`.
    '''
    ndim = x.ndim
    axis = flipping_axis + ndim if flipping_axis < 0 else flipping_axis
    if not 0 <= axis < ndim:
        raise ValueError('Cannot flip axis {} of an array with {} '
                         'dimensions'.format(flipping_axis, ndim))
    # Reverse the requested axis and take every other axis as it is
    slicer = [slice(None)] * ndim
    slicer[axis] = slice(None, None, -1)
    return x[tuple(slicer)]


def pad_image(x, pad_amount, mode='reflect', constant=0.):
    '''Pad an image

    Pad the first two axes of an image by pad_amount on each side. Any
    further axis (e.g., the channels) is left untouched.

    Parameters
    ----------
    x: numpy ndarray
        The array to be padded. It must have at least two dimensions.
    pad_amount: int
        The number of pixels of the padding. Both Python and numpy
        integers are accepted. If 0 the image is only copied.
    mode: string
        The padding mode. If "constant" a constant value will be used to
        fill the padding; if "reflect" the border pixels will be used in
        inverse order to fill the padding (the border pixel itself is
        repeated); if "nearest" the border pixel closer to the padded
        area will be used to fill the padding; if "zero" the padding
        will be filled with zeros.
    constant: int
        The value used to fill the padding when "constant" mode is
        selected.

    Returns
    -------
    numpy ndarray
        A new float32 array, whose first two axes are `2 * pad_amount`
        bigger than those of `x`.

    Raises
    ------
    ValueError
        If the padding mode is not supported or `pad_amount` is
        negative.
    '''
    # The numpy.pad equivalent of each supported mode
    numpy_modes = {'constant': 'constant',
                   'zero': 'constant',
                   'reflect': 'symmetric',
                   'nearest': 'edge'}
    if mode not in numpy_modes:
        raise ValueError("Unsupported padding mode \"{}\"".format(mode))
    e = int(pad_amount)
    if e < 0:
        raise ValueError('pad_amount should not be negative. '
                         'Received: {}'.format(pad_amount))

    # Cast before padding, so that the fill value is not truncated to
    # the dtype of the input
    x = np.asarray(x, dtype=np.float32)
    pad_width = ((e, e), (e, e)) + ((0, 0),) * (x.ndim - 2)
    if numpy_modes[mode] == 'constant':
        fill_value = constant if mode == 'constant' else 0.
        return np.pad(x, pad_width, mode='constant',
                      constant_values=fill_value)
    return np.pad(x, pad_width, mode=numpy_modes[mode])


def gen_warp_field(shape, sigma=0.1, grid_size=3):
    '''Generate a random B-spline warp field

    Builds a B-spline transform with a `grid_size` x `grid_size` mesh
    on a reference image of the given `shape` (a tuple), sets the
    shifts of its control points to gaussian noise scaled by `sigma`
    (zero on the border of the grid of control points) and returns the
    corresponding displacement field, computed on the reference
    image.'''
    import SimpleITK as sitk
    # The reference image only defines the domain of the transform
    ref_image = sitk.Image(*(shape + (sitk.sitkFloat32,)))
    bspline = sitk.BSplineTransformInitializer(ref_image,
                                               [grid_size, grid_size])

    # Random shift of the control points:
    # mesh size = number of control points - spline order
    n_ctrl = grid_size + 3
    shifts = sigma * np.random.randn(n_ctrl, n_ctrl, 2)

    # Anchor the edges of the image: no shift on the border of the grid
    shifts[[0, -1], :, :] = 0
    shifts[:, [0, -1], :] = 0

    # NOTE(review): the (n, n, 2) array is flattened as it is; confirm
    # that this matches the parameters order expected by the transform
    bspline.SetParameters(shifts.flatten())

    # Turn the transform into a dense deformation field
    field_filter = sitk.TransformToDisplacementFieldFilter()
    field_filter.SetReferenceImage(ref_image)
    return field_filter.Execute(bspline)


def apply_warp(x, warp_field, fill_mode='reflect',
               interpolator=None,
               fill_constant=0, rows_idx=1, cols_idx=2):
    '''Apply a spline warp field on each 2D plane of an array

    Both the warp field and the planes of `x` are padded by the largest
    deformation of the field before warping, and the result is cropped
    back to the original size.

    Parameters
    ----------
    x: numpy ndarray
        The array to be warped. It is never modified.
    warp_field: SimpleITK image
        The displacement field, e.g., the output of `gen_warp_field`.
    fill_mode: string
        The padding mode of the planes (see `pad_image`).
    interpolator: SimpleITK interpolator
        The interpolator of the warp filter. If None, linear
        interpolation is used.
    fill_constant: int
        The padding value used when `fill_mode` is 'constant'.
    rows_idx: int
        The index of the rows axis.
    cols_idx: int
        The index of the cols axis.

    Returns
    -------
    numpy ndarray
        A new array with the same shape, axis order and dtype of `x`.
    '''
    import SimpleITK as sitk
    if interpolator is None:
        interpolator = sitk.sitkLinear
    # Expand deformation field (and later the image), padding for the largest
    # deformation
    warp_field_arr = sitk.GetArrayFromImage(warp_field)
    max_deformation = np.max(np.abs(warp_field_arr))
    # Kept as a numpy integer, as it always was
    pad = np.ceil(max_deformation).astype(np.int32)
    if pad == 0:
        # The field is zero everywhere: warping is the identity (and the
        # pad:-pad crop below would be empty)
        return x.copy()
    warp_field_padded_arr = pad_image(warp_field_arr, pad_amount=pad,
                                      mode='nearest')
    warp_field_padded = sitk.GetImageFromArray(warp_field_padded_arr,
                                               isVector=True)

    # Warp x, one filter slice at a time
    pattern = [el for el in range(x.ndim) if el not in (rows_idx, cols_idx)]
    pattern += [rows_idx, cols_idx]
    inv_pattern = [pattern.index(el) for el in range(x.ndim)]
    x = x.transpose(pattern)  # batch, channel, ...
    x_shape = list(x.shape)
    # Squash everything but the last two axes, whatever the number of
    # dimensions of x: *, r, c
    x = x.reshape([-1] + x_shape[-2:])
    # `x` can be a view on the caller's data: write the result in a
    # separate array so that the input is never modified in place
    warped = np.empty_like(x)
    warp_filter = sitk.WarpImageFilter()
    warp_filter.SetInterpolator(interpolator)
    warp_filter.SetEdgePaddingValue(np.min(x).astype(np.double))
    for i in range(x.shape[0]):
        bc_pad = pad_image(x[i], pad_amount=pad, mode=fill_mode,
                           constant=fill_constant).T
        bc_f = sitk.GetImageFromArray(bc_pad)
        bc_f_warped = warp_filter.Execute(bc_f, warp_field_padded)
        bc_warped = sitk.GetArrayFromImage(bc_f_warped)
        # Remove the padding and undo the transposition
        warped[i] = bc_warped[pad:-pad, pad:-pad].T
    warped = warped.reshape(x_shape)  # unsquash
    return warped.transpose(inv_pattern)


def random_transform(x, y=None,
                     rotation_range=0.,
                     width_shift_range=0.,
                     height_shift_range=0.,
                     shear_range=0.,
                     zoom_range=0.,
                     channel_shift_range=0.,
                     fill_mode='nearest',
                     cval=0.,
                     cval_mask=0.,
                     horizontal_flip=0.,  # probability
                     vertical_flip=0.,  # probability
                     rescale=None,
                     spline_warp=False,
                     warp_sigma=0.1,
                     warp_grid_size=3,
                     crop_size=None,
                     crop_mode='random',
                     return_optical_flow=False,
                     nclasses=None,
                     gamma=0.,
                     gain=1.,
                     chan_idx=3,  # No batch yet: (s, 0, 1, c)
                     rows_idx=1,  # No batch yet: (s, 0, 1, c)
                     cols_idx=2,  # No batch yet: (s, 0, 1, c)
                     void_label=None,
                     mask_labels=[],
                     prescale=1.0):
    '''Random Transform.

    A function to perform data augmentation of images and masks during
    the training (on-the-fly). Based on [RandomTransform1]_.

    The transformations are applied in this order: prescale, channel
    shift, gamma correction, affine transformation (rotation, shift,
    shear and zoom together), horizontal flip, vertical flip, spline
    warp, crop and/or padding, optical flow.

    Parameters
    ----------
    x: array of floats
        A 4D array of images. The default axis indices assume the
        layout (sequence, rows, cols, channels). It is not modified.
    y: array of int
        A 3D array with the labels of `x`. It can be None (or empty) if
        there are no labels. It is not modified.
    rotation_range: int
        Degrees of rotation (0 to 180).
    width_shift_range: float
        The maximum amount the image can be shifted horizontally (in
        percentage).
    height_shift_range: float
        The maximum amount the image can be shifted vertically (in
        percentage).
    shear_range: float
        The shear intensity (shear angle in radians).
    zoom_range: float or list of floats
        The amount of zoom. If set to a scalar z, the zoom will be
        randomly picked in the range [1-z, 1+z]. If set to a list
        [a, b], it will be picked in the range delimited by 1-a and
        1-b. No value can be greater than 1.
    channel_shift_range: float
        The shift range for each channel.
    fill_mode: string
        Some transformations can return pixels that are outside of the
        boundaries of the original image. The points outside the
        boundaries are filled according to the given mode (`constant`,
        `nearest`, `reflect` or `wrap`). Default: `nearest`.
    cval: int
        Value used to fill the points of the image outside the boundaries when
        fill_mode is `constant`. Default: 0.
    cval_mask: int
        Value used to fill the points of the mask outside the boundaries when
        fill_mode is `constant`. Default: 0.
    horizontal_flip: float
        The probability to randomly flip the images (and masks)
        horizontally. Default: 0.
    vertical_flip: float
        The probability to randomly flip the images (and masks)
        vertically. Default: 0.
    rescale: float
        The rescaling factor. Rescaling is not implemented: any value
        other than None or 0 raises a NotImplementedError.
    spline_warp: bool
        Whether to apply spline warping.
    warp_sigma: float
        The sigma of the gaussians used for spline warping.
    warp_grid_size: int
        The grid size of the spline warping.
    crop_size: tuple
        The size of crop to be applied to images and masks (after any
        other transformation). Where the crop is bigger than the image,
        the image is padded instead (with zeros; the mask is padded
        with `void_label`).
    crop_mode: string
        The crop strategy. Can be either 'random' or 'smart'.
        The 'random' mode randomly places the crop in the image.
        The 'smart' mode requires the labels and favours the crops that
        contain more foreground. The foreground is whatever is not
        background: the id of the background is the position in
        `mask_labels` of the label called 'background' or, if missing,
        of the one called 'void'; it is assumed to be 0 if neither is
        found. The foreground pixels are accumulated over the first
        axis of the mask (e.g., over all the frames of a sequence) and
        the position of the crop is sampled with a probability
        proportional to the amount of accumulated foreground that the
        crop would contain. If there is no foreground at all, the
        position is sampled uniformly.
    return_optical_flow: bool
        If not False a dense optical flow will be concatenated to the
        end of the channel axis of the image. If True, angle and
        magnitude will be returned, if set to 'rgb' an RGB representation
        will be returned instead. Default: False.
    nclasses: int
        The number of classes of the dataset.
    gamma: float
        Controls gamma in Gamma correction. If 0, no correction is
        applied.
    gain: float
        Controls gain in Gamma correction.
    chan_idx: int
        The index of the channel axis.
    rows_idx: int
        The index of the rows of the image.
    cols_idx: int
        The index of the cols of the image.
    void_label: int
        The index of the void label, if any. Used for padding.
    mask_labels: list of strings
        The list of the mask labels. Used in smart cropping to look for
        the background label.
    prescale: float
        The scale factor applied to each element of `x` (with bilinear
        interpolation) and of `y` (with nearest-neighbor interpolation)
        before any other transformation. Default: 1.0 (disabled).

    Returns
    -------
    tuple
        The transformed `x` and `y`. `y` is returned as it was received
        if there are no labels.

    Raises
    ------
    RuntimeError
        If `x` is not 4D, `y` is not 3D, the zoom range is greater than
        1 or smart cropping is requested without labels.
    NotImplementedError
        If `rescale` is set.
    ValueError
        If the crop mode is not supported or the mask has to be padded
        but there is no `void_label`.

    References
    ----------
    .. [RandomTransform1]
       https://github.com/fchollet/keras/blob/master/keras/preprocessing/image.py
    '''
    # Set this to a dir, if you want to save augmented images samples
    save_to_dir = None  # "./"

    # The labels are optional: None and empty arrays both mean no labels
    has_labels = y is not None and len(y) > 0

    if x.ndim != 4:
        raise RuntimeError('The x input of random transform should have '
                           '4 dimensions. Received %d instead.' % x.ndim)
    if has_labels and y.ndim != 3:
        raise RuntimeError('The y input of random transform should have '
                           '3 dimensions. Received %d instead.' % y.ndim)
    if rescale:
        raise NotImplementedError()

    # Do not modify the original images
    x = x.copy()
    if has_labels:
        y = y[..., None]  # Add extra dim to y to simplify computation
        y = y.copy()

    # Prescale each image/mask in the batch
    if prescale != 1.0:
        import skimage.transform
        x = [skimage.transform.rescale(x_image, prescale,
                                       order=1,  # bilinear
                                       preserve_range=True) for x_image in x]
        x = np.stack(x, 0)
        if has_labels:
            y = [skimage.transform.rescale(y_image, prescale,
                                           order=0,  # Nearest-neighbor
                                           preserve_range=True)
                 for y_image in y]
            y = np.stack(y, 0)

    # listify zoom range
    if np.isscalar(zoom_range):
        if zoom_range > 1.:
            raise RuntimeError('Zoom range should be between 0 and 1. '
                               'Received: ', zoom_range)
        # Symmetric range around 1 (i.e., around no zoom)
        zoom_range = [1 - zoom_range, 1 + zoom_range]
    elif len(zoom_range) == 2:
        if any(el > 1. for el in zoom_range):
            raise RuntimeError('Zoom range should be between 0 and 1. '
                               'Received: ', zoom_range)
        zoom_range = [1-el for el in zoom_range]
    else:
        raise Exception('zoom_range should be a float or '
                        'a tuple or list of two floats. '
                        'Received arg: ', zoom_range)

    # Channel shift
    if channel_shift_range != 0:
        x = random_channel_shift(x, channel_shift_range, rows_idx, cols_idx,
                                 chan_idx)

    # Gamma correction
    if gamma > 0:
        scale = float(1)
        x = ((x / scale) ** gamma) * scale * gain

    # Affine transformations (zoom, rotation, shift, ..)
    if (rotation_range or height_shift_range or width_shift_range or
            shear_range or zoom_range != [1, 1]):

        # --> Rotation
        if rotation_range:
            theta = np.pi / 180 * np.random.uniform(-rotation_range,
                                                    rotation_range)
        else:
            theta = 0
        rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0],
                                    [np.sin(theta), np.cos(theta), 0],
                                    [0, 0, 1]])
        # --> Shift/Translation
        if height_shift_range:
            tx = (np.random.uniform(-height_shift_range, height_shift_range) *
                  x.shape[rows_idx])
        else:
            tx = 0
        if width_shift_range:
            ty = (np.random.uniform(-width_shift_range, width_shift_range) *
                  x.shape[cols_idx])
        else:
            ty = 0
        translation_matrix = np.array([[1, 0, tx],
                                       [0, 1, ty],
                                       [0, 0, 1]])
        # --> Shear
        if shear_range:
            shear = np.random.uniform(-shear_range, shear_range)
        else:
            shear = 0
        shear_matrix = np.array([[1, -np.sin(shear), 0],
                                 [0, np.cos(shear), 0],
                                 [0, 0, 1]])
        # --> Zoom
        if zoom_range == [1, 1]:
            zx, zy = 1, 1
        else:
            zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)
        zoom_matrix = np.array([[zx, 0, 0],
                                [0, zy, 0],
                                [0, 0, 1]])
        # Use a composition of homographies to generate the final transform
        # that has to be applied
        transform_matrix = np.dot(np.dot(np.dot(rotation_matrix,
                                                translation_matrix),
                                         shear_matrix), zoom_matrix)
        h, w = x.shape[rows_idx], x.shape[cols_idx]
        transform_matrix = transform_matrix_offset_center(transform_matrix,
                                                          h, w)
        # Apply all the transformations together
        x = apply_transform(x, transform_matrix, fill_mode=fill_mode,
                            cval=cval, order=1, rows_idx=rows_idx,
                            cols_idx=cols_idx)
        if has_labels:
            y = apply_transform(y, transform_matrix, fill_mode=fill_mode,
                                cval=cval_mask, order=0, rows_idx=rows_idx,
                                cols_idx=cols_idx)

    # Horizontal flip
    if np.random.random() < horizontal_flip:  # 0 = disabled
        x = flip_axis(x, cols_idx)
        if has_labels:
            y = flip_axis(y, cols_idx)

    # Vertical flip
    if np.random.random() < vertical_flip:  # 0 = disabled
        x = flip_axis(x, rows_idx)
        if has_labels:
            y = flip_axis(y, rows_idx)

    # Spline warp
    if spline_warp:
        import SimpleITK as sitk
        warp_field = gen_warp_field(shape=(x.shape[rows_idx],
                                           x.shape[cols_idx]),
                                    sigma=warp_sigma,
                                    grid_size=warp_grid_size)
        x = apply_warp(x, warp_field,
                       interpolator=sitk.sitkLinear,
                       fill_mode=fill_mode,
                       fill_constant=cval,
                       rows_idx=rows_idx, cols_idx=cols_idx)
        if has_labels:
            y = np.round(apply_warp(y, warp_field,
                                    interpolator=sitk.sitkNearestNeighbor,
                                    fill_mode=fill_mode,
                                    fill_constant=cval_mask,
                                    rows_idx=rows_idx, cols_idx=cols_idx))

    # Crop
    # Expects axes with shape (..., 0, 1)
    # TODO: Add center crop
    if crop_size:
        # Reshape to (..., 0, 1)
        pattern = [el for el in range(x.ndim) if el != rows_idx and
                   el != cols_idx] + [rows_idx, cols_idx]
        inv_pattern = [pattern.index(el) for el in range(x.ndim)]
        x = x.transpose(pattern)

        crop = list(crop_size)
        pad = [0, 0]
        h, w = x.shape[-2:]
        # Unknown until the crop mode (or the padding) decides them
        top = left = None

        # Compute crop and padding amounts
        if crop[0] < h:
            if crop_mode == 'random':
                top = np.random.randint(h - crop[0])
        else:
            # Set pad and disable crop
            pad[0] = crop[0] - h
            top, crop[0] = 0, h
        if crop[1] < w:
            if crop_mode == 'random':
                left = np.random.randint(w - crop[1])
        else:
            # Set pad and disable crop
            pad[1] = crop[1] - w
            left, crop[1] = 0, w

        if crop_mode == 'smart':
            if not has_labels:
                raise RuntimeError('Cannot use smart cropping without labels')

            if pad[0] == 0 or pad[1] == 0:  # We crop in at least one dimension
                # Look for the background label, or assume it to be 0
                bg_label = np.where([m.lower() == 'background' for m
                                     in mask_labels])[0]
                if len(bg_label) == 0:
                    bg_label = np.where([m.lower() == 'void' for m
                                         in mask_labels])[0]
                bg_label = bg_label[0] if len(bg_label) else 0
                # Sum the number of fg pixels in time in each location
                # NOTE(review): y has not been transposed yet, so this
                # assumes the default (seq, 0, 1) layout - TODO confirm
                fg_mask = y[..., 0] != bg_label  # 3D: seq, 0, 1
                t_fg = fg_mask.sum(axis=0)  # accumulate over time --> 2D

                # Compute the sum of the cumulated masks (i.e., the number of
                # fg pixels over time) of each candidate crop. The result is a
                # matrix of the cumulated fg values of the crop whose top-left
                # corner is positioned in each location
                from scipy.signal import fftconvolve
                # `crop` is already limited to the size of the image
                crop_filter = np.ones(crop)
                cum_t_fg = fftconvolve(t_fg, crop_filter, 'valid')
                # Account for fft numerical instability
                cum_t_fg = np.clip(cum_t_fg, 0, np.inf)

                # Convert the cumulated mask to a probability
                tot_t_fg = cum_t_fg.sum(dtype=float)
                if tot_t_fg > 0:
                    p = (cum_t_fg / tot_t_fg).flatten()
                else:
                    # No foreground at all: every location is equally good
                    p = None

                # Select some coordinates stochastically, with probability
                # of each location proportional to the cumulative amount of
                # foreground pixels in time
                n_locations = np.prod(cum_t_fg.shape)
                idx = np.random.choice(n_locations, p=p)  # 1D coord
                top, left = np.unravel_index(idx, cum_t_fg.shape)  # 2D

        if top is None or left is None:
            raise ValueError('Unsupported crop mode "{}"'.format(crop_mode))

        # Cropping
        x = x[..., top:top+crop[0], left:left+crop[1]]
        if has_labels:
            y = y.transpose(pattern)
            y = y[..., top:top+crop[0], left:left+crop[1]]
        # Padding
        if pad != [0, 0]:
            pad_pattern = ((0, 0),) * (x.ndim - 2) + (
                (pad[0]//2, pad[0] - pad[0]//2),
                (pad[1]//2, pad[1] - pad[1]//2))
            x = np.pad(x, pad_pattern, 'constant')
            if has_labels:
                if void_label is None:
                    raise ValueError('Cannot pad the image: the '
                                     'dataset has no void class')
                y = np.pad(y, pad_pattern, 'constant',
                           constant_values=void_label)

        x = x.transpose(inv_pattern)
        if has_labels:
            y = y.transpose(inv_pattern)

    if return_optical_flow:
        flow = optical_flow(x, rows_idx, cols_idx, chan_idx,
                            return_rgb=return_optical_flow == 'rgb')
        x = np.concatenate((x, flow), axis=chan_idx)

    # Save augmented images
    if save_to_dir:
        import seaborn as sns
        fname = 'data_augm_{}.png'.format(np.random.randint(10000))
        cmap = sns.hls_palette(nclasses)
        save_img2(x, y, os.path.join(save_to_dir, fname),
                  cmap, void_label, rows_idx, cols_idx, chan_idx)

    # Undo extra dim
    if has_labels:
        y = y[..., 0]

    return x, y