#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com>

from tensorpack import *
import tensorpack as tp
import argparse
import cv2
import numpy as np
import psf
from scipy import ndimage

"""
Usage:

python data_provider.py --lmdb mydb2.lmdb
"""


class PSF(tp.dataflow.DataFlow):
    """TensorPack dataflow proxy around the PSF sampler.

    Each datapoint is a list of ``multiple`` PSF kernels drawn from the
    underlying generator.

    Attributes:
        kernel_shape (int): size of each PSF kernel
        multiple (int): kernels emitted per datapoint (may effectively be
            re-written downstream by tp.BatchData)
        psf_gen: psf.PSF object whose ``sample()`` yields kernels
    """

    def __init__(self, kernel_shape=7, multiple=5):
        self.kernel_shape = kernel_shape
        self.multiple = multiple
        self.psf_gen = psf.PSF(kernel_size=kernel_shape)

    def reset_state(self):
        # Stateless source -- nothing to re-seed or reopen.
        pass

    def size(self):
        # Effectively an infinite stream; report a huge fixed size.
        return 100000000

    def get_data(self):
        sampler = self.psf_gen.sample()
        while True:
            yield [next(sampler) for _ in range(self.multiple)]


class Blur(tp.dataflow.DataFlow):
    """Convolve incoming image bursts with PSF kernels (synthetic camera shake).

    Yields datapoints of the form
    [blurry1, blurry2, ..., blurry5, sharp1, sharp2, ..., sharp5]:
    only the first len(psf-datapoint) frames of each burst get blurred;
    the remaining frames pass through untouched.

    Attributes:
        ds_images: dataflow producing image bursts (should already contain
            motion blur).
        ds_psf: dataflow producing PSF kernels.
    """

    def __init__(self, ds_images, ds_psf):
        self.ds_images = ds_images
        self.ds_psf = ds_psf

    def reset_state(self):
        self.ds_psf.reset_state()
        self.ds_images.reset_state()

    def size(self):
        # Length is dictated by the image stream; the PSF stream is infinite.
        return self.ds_images.size()

    def get_data(self):
        psf_stream = self.ds_psf.get_data()

        for images in self.ds_images.get_data():
            # one camera-shake kernel per frame to be blurred
            kernels = next(psf_stream)

            # synthesize ego-motion: channel-wise 2D convolution per frame
            for idx, kernel in enumerate(kernels):
                frame = images[idx]
                for channel in range(3):
                    frame[:, :, channel] = ndimage.convolve(frame[:, :, channel], kernel, mode='constant', cval=0.0)
                images[idx] = frame

            yield images


def get_lmdb_data(lmdb_file):
    """Return a shuffled dataflow of decoded image bursts from an LMDB file.

    Args:
        lmdb_file (str): path to the LMDB database; each datapoint is
            expected to be a list of JPEG-encoded image buffers.

    Returns:
        A dataflow yielding lists of decoded BGR uint8 images.
    """

    class Decoder(MapData):
        """Decode JPEG-encoded buffers into BGR images via cv2.imdecode."""
        # NOTE: the original docstring claimed this *compresses* to JPEG;
        # the code clearly decodes (cv2.imdecode), so the docs were fixed.
        def __init__(self, df):
            def func(dp):
                return [cv2.imdecode(np.asarray(bytearray(i), dtype=np.uint8), cv2.IMREAD_COLOR) for i in dp]
            super(Decoder, self).__init__(df, func)

    ds = LMDBDataPoint(lmdb_file, shuffle=True)
    ds = Decoder(ds)
    return ds


def get_data(lmdb_file, shape=(256, 256), ego_motion_size=(17, 25, 35, 71)):
    """Build the full training dataflow: crop large, blur, center-crop, pack.

    Args:
        lmdb_file (str): path to the LMDB database of image bursts.
        shape (tuple): final (height, width) of the output patches.
        ego_motion_size: iterable of PSF kernel sizes; one PSF dataflow
            per size, chosen at random per datapoint. (Changed from a
            mutable list default to a tuple -- same values, safer default.)

    Returns:
        A dataflow yielding [blurry, sharp]: two stacked arrays holding
        the first and second half of each 10-image burst.
    """
    # Crop a region larger than `shape` so border artifacts of the
    # constant-padded convolution can be cut away by the final CenterCrop.
    # NOTE(review): the principled size would be
    #   (shape[0] + 2 * max(ego_motion_size), shape[1] + 2 * max(ego_motion_size))
    # (see the disabled formula in the original); the pipeline hardcodes
    # 306x306 -- kept for reproducibility, confirm before changing.
    s = (306, 306)

    ds_img = get_lmdb_data(lmdb_file)
    # to remove hints from border-handling we crop a slightly larger region ...
    ds_img = AugmentImageComponents(ds_img, [imgaug.RandomCrop(s)], index=range(10), copy=True)

    # ... and then apply a randomly-sized PSF kernel ...
    ds_psf = [PSF(kernel_shape=m) for m in ego_motion_size]
    ds_psf = RandomChooseData(ds_psf)

    ds = Blur(ds_img, ds_psf)
    # ... before the final crop
    ds = AugmentImageComponents(ds, [imgaug.CenterCrop(shape)], index=range(10), copy=True)

    def combine(dp):
        # First half of the burst was blurred by Blur; second half is sharp.
        half = len(dp) // 2
        return [np.array(dp[:half]), np.array(dp[half:])]

    ds = MapData(ds, combine)
    return ds


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # fix: `required` takes a bool, not the string 'True' (which only
    # worked by truthiness)
    parser.add_argument('--lmdb', type=str, help='path to lmdb', required=True)
    # fix: help text was copy-pasted from --show and described the wrong thing
    parser.add_argument('--num', type=int, help='number of datapoints to visualize', default=5)
    parser.add_argument('--show', help='display window instead of writing', action='store_true')

    args = parser.parse_args()

    ds = get_data(args.lmdb, shape=(256, 256), ego_motion_size=[17, 25, 35, 71])
    ds.reset_state()

    for counter, dp in enumerate(ds.get_data()):
        # dp = [blurry, sharp]; each stacks 5 frames along a leading axis.
        blurry = dp[0]
        blurry = [blurry[i, ...] for i in range(5)]
        blurry = np.concatenate(blurry, axis=1)

        sharp = dp[1]
        sharp = [sharp[i, ...] for i in range(5)]
        sharp = np.concatenate(sharp, axis=1)

        # side-by-side montage; [:, :, ::-1] swaps BGR<->RGB channel order
        # (presumably intentional for the viewer -- verify if colors look off)
        out = np.concatenate([blurry, sharp], axis=1)[:, :, ::-1]

        if args.show:
            cv2.imshow('stacked_blurry', out)
            cv2.waitKey(0)
        else:
            cv2.imwrite('/tmp/stacked_data_%i.jpg' % counter, out)

        # fix: original `counter > args.num` processed num + 2 datapoints;
        # stop after exactly args.num
        if counter + 1 >= args.num:
            break