from .loader import Loader
import tensorflow as tf
import threading
import numpy as np
import time
import glob
import os
import imageio
import cv2
import deepdish as dd


def pad(x, min_side):
    """Zero-pad the first two axes of ``x`` so each is at least ``min_side``.

    Args:
        x: Array with at least two dimensions; any trailing axes are kept
            unchanged.
        min_side: Minimum size required for each of the first two axes.

    Returns:
        ``x`` itself when both leading axes already reach ``min_side``.
        Otherwise a new zero-filled array of the same dtype with the
        original data copied into its top-left corner.
    """
    if np.min(x.shape[:2]) >= min_side:
        return x
    # Only grow axes that are too small; larger axes keep their size.
    sh = (max(min_side, x.shape[0]), max(min_side, x.shape[1])) + x.shape[2:]
    new_x = np.zeros(sh, dtype=x.dtype)
    new_x[:x.shape[0], :x.shape[1]] = x
    return new_x

def extract_optical_flow(fn, n_frames=34):
    """Build a normalized motion-magnitude map from a vertical roll of frames.

    The image at ``fn`` is loaded with ``dd.image.load`` and must have shape
    ``(128 * 34, 128, 3)``; it is split along axis 0 into 34 frames. A random
    window of ``n_frames`` consecutive frames is chosen, Farneback optical
    flow is computed between each consecutive pair inside the window, and the
    per-pair flow magnitudes are summed and min-max normalized.

    Args:
        fn: Path handed to ``dd.image.load``.
        n_frames: Number of consecutive frames in the window (at most 34).

    Returns:
        An empty list when the loaded image does not have the expected shape.
        Otherwise a one-element list ``[(frame, norm_mag)]`` where ``frame``
        is a randomly chosen frame from the window and ``norm_mag`` is the
        summed flow magnitude rescaled to start at 0 and stay below 1.
    """
    total_frames = 34  # number of frames stacked along axis 0 of the roll
    img = dd.image.load(fn)
    if img.shape != (128 * total_frames, 128, 3):
        return []
    frames = np.array_split(img, total_frames, axis=0)
    grayscale_frames = [fr.mean(-1) for fr in frames]
    mags = []
    # Random offset of the window within the roll.
    skip_frames = np.random.randint(total_frames - n_frames + 1)
    # Despite its name, this is a uniformly random frame from the window.
    middle_frame = frames[np.random.randint(skip_frames, skip_frames + n_frames)]
    im0 = grayscale_frames[skip_frames]
    for f in range(1 + skip_frames, skip_frames + n_frames):
        im1 = grayscale_frames[f]
        flow = cv2.calcOpticalFlowFarneback(
            im0, im1,
            None,  # flow
            0.5,  # pyr_scale
            3,  # levels
            np.random.randint(3, 20),  # winsize, randomized per frame pair
            3,  # iterations
            5,  # poly_n
            1.2,  # poly_sigma
            0,  # flags
        )
        mag, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
        mags.append(mag)
        im0 = im1
    if mags:
        mag = np.sum(mags, 0)
    else:
        # A single-frame window has no frame pairs, hence no motion.
        mag = np.zeros(im0.shape)
    mag = mag.clip(min=0)
    # The epsilon keeps the division finite when the map is constant.
    norm_mag = (mag - mag.min()) / (mag.max() - mag.min() + 1e-5)
    outputs = []
    outputs.append((middle_frame, norm_mag))
    return outputs

class VideoJPEGRollsFlowSaliency(Loader):
    def __init__(self, path, root_path='', batch_size=16, input_size=227, num_threads=10):
        """Read the file list and build the placeholders and input queue.

        Args:
            path: Text file whose lines are the entries of the file list
                (trailing newlines are stripped).
            root_path: Stored as ``self._root_path``.
            batch_size: Leading dimension of the image and label placeholders.
            input_size: Height and width of the image and label placeholders.
            num_threads: Stored as ``self._num_threads``.
        """
        self._path = path
        self._root_path = root_path
        with open(path) as listing:
            self._list_files = [line.rstrip('\n') for line in listing]
        print('list_files', len(self._list_files))

        self._batch_size = batch_size
        self._input_size = input_size
        self._num_threads = num_threads
        self._coord = tf.train.Coordinator()

        # Images carry 3 channels, labels a single channel.
        self._base_shape = [batch_size, input_size, input_size]
        self._image_shape = [*self._base_shape, 3]
        self._label_shape = [*self._base_shape, 1]

        image_ph = tf.placeholder(tf.float32, self._image_shape, name='x')
        label_ph = tf.placeholder(tf.float32, self._label_shape, name='y')
        self._inputs = [image_ph, label_ph]

        # FIFO queue with capacity 400 whose elements mirror the placeholders.
        dtypes = [ph.dtype for ph in self._inputs]
        shapes = [ph.get_shape() for ph in self._inputs]
        self._queue = tf.FIFOQueue(400, dtypes, shapes)
        self._enqueue_op = self._queue.enqueue(self._inputs)
        self._queue_close_op = self._queue.close(cancel_pending_enqueues=True)
        self._threads = []

    def __feed(self, rank):
        """Yield feed dicts of randomly sampled (image, flow-magnitude) batches.

        Runs forever: random files from the list are turned into samples via
        ``extract_optical_flow``, resized, cropped to ``input_size`` and
        randomly flipped along the second axis, then collected in a pool from
        which each batch is drawn without replacement.

        Args:
            rank: Index of the calling worker; not used by this method.

        Yields:
            Dict mapping the two input placeholders to the batch arrays. The
            same two arrays are reused and overwritten for every batch.
        """
        # Random 0-3 s delay before starting (presumably to stagger workers).
        time.sleep(np.random.uniform(0, 3))
        batch_x = np.zeros(self._image_shape, dtype=np.float32)
        batch_y = np.zeros(self._label_shape, dtype=np.float32)
        pool = []
        N = len(self._list_files)
        input_size = self._input_size
        while True:
            while len(pool) < self._batch_size * 30:
                i = np.random.randint(N)

                fn = os.path.join(self._root_path, self._list_files[i])
                # NOTE(review): FRAMES is not defined in the visible part of
                # this module -- confirm it exists at module level.
                outputs = extract_optical_flow(fn, n_frames=FRAMES)
                for img, mag in outputs:
                    img0 = dd.image.resize(img, min_side=input_size)
                    mag0 = dd.image.resize(mag, min_side=input_size)

                    # Now find a random window
                    h = np.random.randint(img0.shape[0] - input_size + 1)
                    w = np.random.randint(img0.shape[1] - input_size + 1)
                    window = np.s_[h:h + input_size, w:w + input_size]
                    if np.random.randint(2) == 0:
                        ss = np.s_[:]
                    else:
                        # flipped
                        ss = np.s_[:, ::-1]

                    pool.append((img0[window][ss], mag0[window][ss]))

                # Stop refilling as soon as one full batch can be drawn.
                if len(pool) >= self._batch_size:
                    break

            for b in range(self._batch_size):
                i = np.random.randint(len(pool))
                img, mag = pool.pop(i)
                batch_x[b] = img
                batch_y[b, ..., 0] = mag

            yield {self._inputs[0]: batch_x, self._inputs[1]: batch_y}

    def __thread(self, session, rank):
        """Worker loop: enqueue every batch produced by ``__feed``.

        Args:
            session: Session used to run the enqueue op.
            rank: Index of this worker, forwarded to ``__feed``.
        """
        # Any exception raised here is reported to the coordinator.
        with self._coord.stop_on_exception():
            for feed_dict in self.__feed(rank):
                session.run(self._enqueue_op, feed_dict)

    def batch(self):
        """Return the dequeue op outputs as ``(images, {'saliency': labels})``."""
        images, saliency = self._queue.dequeue()
        return images, {'saliency': saliency}

    def batch_size(self):
        """Return the batch size this loader was constructed with."""
        return self._batch_size

    def start(self, session):
        """Launch ``num_threads`` daemon worker threads that fill the queue.

        Each thread is started immediately and recorded in ``self._threads``
        so that ``check_status`` can inspect it.

        Args:
            session: Session passed to every worker for running the enqueue op.
        """
        for i in range(self._num_threads):
            t = threading.Thread(target=VideoJPEGRollsFlowSaliency.__thread,
                                 args=(self, session, i))
            t.daemon = True
            t.start()
            self._threads.append(t)

    def check_status(self):
        """Report dead worker threads.

        Prints one line per thread in ``self._threads`` that is no longer
        alive.

        Returns:
            True if at least one thread has died, False otherwise.
        """
        dead = [idx for idx, worker in enumerate(self._threads)
                if not worker.is_alive()]
        for idx in dead:
            print(f'Thread #{idx} has died')
        return bool(dead)

    def stop(self, session):