# -*- coding: utf-8 -*-

import threading

import tensorflow as tf

from auto_pose.ae.utils import lazy_property
import time

class Queue(object):
    """Background feeder that fills a TensorFlow FIFO queue with dataset batches.

    A number of Python worker threads repeatedly call ``dataset.batch(...)``
    and push the result into a two-component ``tf.FIFOQueue``.  Consumers read
    from the queue through the ``x`` and ``y`` tensors.

    Attributes:
        x, y: Tensors produced by dequeuing up to ``batch_size`` elements
            from the queue (first and second queue component respectively).
        enqueue_op: Op that enqueues the content of the two feed placeholders.
    """

    def __init__(self, dataset, num_threads, queue_size, batch_size):
        """Build the queue, its feed placeholders and the dequeue tensors.

        Args:
            dataset: Object exposing ``shape`` (shape of one sample) and
                ``batch(batch_size)``.  ``batch`` is presumably expected to
                return two arrays of shape ``[n] + shape`` -- confirm against
                the dataset implementation.
            num_threads: Number of worker threads launched by ``start``.
            queue_size: Capacity of the FIFO queue (in samples).
            batch_size: Size requested from ``dataset.batch`` and upper bound
                on the number of samples dequeued at once.
        """
        self._dataset = dataset
        self._num_threads = num_threads
        self._queue_size = queue_size
        self._batch_size = batch_size

        sample_shape = self._dataset.shape
        datatypes = 2 * ['float32']
        shapes = 2 * [sample_shape]

        # Leading ``None`` lets a fed batch contain any number of samples.
        batch_shape = [None] + list(sample_shape)

        # Exactly one placeholder per queue component.  (Multiplying a
        # two-element list by 2 would repeat the same two placeholders and
        # produce a four-element list that no longer matches the queue.)
        self._placeholders = [
            tf.placeholder(dtype=tf.float32, shape=batch_shape),
            tf.placeholder(dtype=tf.float32, shape=batch_shape),
        ]

        self._queue = tf.FIFOQueue(self._queue_size, datatypes, shapes=shapes)
        self.x, self.y = self._queue.dequeue_up_to(self._batch_size)
        self.enqueue_op = self._queue.enqueue_many(self._placeholders)

        self._coordinator = tf.train.Coordinator()

        self._threads = []

    def start(self, session):
        """Launch the worker threads that feed the queue.

        Also starts the graph's queue runners with this object's coordinator.
        Must not be called while workers from a previous ``start`` are still
        registered.

        Args:
            session: TensorFlow session used to run the enqueue op.
        """
        assert not self._threads, 'queue workers are already running'
        tf.train.start_queue_runners(session, self._coordinator)
        for _ in range(self._num_threads):
            thread = threading.Thread(
                target=Queue.__run__,
                args=(self, session),
            )
            # Daemon threads do not keep the interpreter alive on exit.
            thread.daemon = True
            thread.start()
            self._threads.append(thread)

    def stop(self, session):
        """Ask the workers to stop, close the queue and wait for the workers.

        Args:
            session: TensorFlow session used to run the queue's close op.
        """
        self._coordinator.request_stop()
        # Cancelling pending enqueues releases workers blocked on a full queue.
        session.run(self._queue.close(cancel_pending_enqueues=True))
        self._coordinator.join(self._threads)
        self._threads[:] = []

    def __run__(self, session):
        """Worker loop: create batches and enqueue them until asked to stop.

        Args:
            session: TensorFlow session used to run the enqueue op.
        """
        while not self._coordinator.should_stop():
            batch = self._dataset.batch(self._batch_size)

            # Pair each placeholder with the corresponding batch component.
            feed_dict = dict(zip(self._placeholders, batch))
            try:
                session.run(self.enqueue_op, feed_dict)
            except tf.errors.CancelledError:
                # Raised when the enqueue is cancelled (see ``stop``); the
                # loop condition then decides whether to continue.
                print('worker was cancelled')