# Priority experience replay buffer for training a DDQN on seq2seq decoder
# output states. Experiences are kept in a max-heap keyed on reward; when the
# buffer fills up, the least valuable half is discarded to make room.

from random import shuffle
from random import seed
seed(123)
from threading import Thread
import numpy as np
import time
import tensorflow as tf
try:
    import Queue as Q  # ver. < 3.0
except ImportError:
    import queue as Q
from sklearn.preprocessing import normalize

PriorityQueue = Q.PriorityQueue


# Custom priority queue that is able to clear the queue once it is full
# and cut it to half. Therefore, every time the buffer reaches
# dqn_replay_buffer_size, we keep half of the most valuable states and remove
# the rest to provide space for new experiences.
class CustomQueue(PriorityQueue):
    '''
    A custom queue subclass that provides a :meth:`clear` method.
    '''

    def __init__(self, size):
        """Create a bounded priority queue holding at most `size` items."""
        PriorityQueue.__init__(self, size)

    def clear(self):
        '''
        Drops the lower half of the queued items, keeping the first half of
        the internal heap list.

        NOTE(review): ``self.queue`` is a binary heap, so its first half is
        only approximately the highest-priority half — confirm this
        approximation is intended.
        '''
        with self.mutex:
            unfinished = self.unfinished_tasks - len(self.queue)
            if unfinished <= 0:
                if unfinished < 0:
                    raise ValueError('task_done() called too many times')
                self.all_tasks_done.notify_all()
            # BUGFIX: was len(self.queue)/2 — true division returns a float in
            # Python 3 and raises TypeError when used as a slice index.
            self.queue = self.queue[0:len(self.queue) // 2]
            self.unfinished_tasks = unfinished + len(self.queue)
            self.not_full.notify_all()

    def isempty(self):
        """Return True when no items are queued (checked under the mutex)."""
        with self.mutex:
            return len(self.queue) == 0

    def isfull(self):
        """Return True when the queue holds exactly `maxsize` items."""
        with self.mutex:
            return len(self.queue) == self.maxsize


class Transition(object):
    """ A class for holding the experiences collected from seq2seq model """

    def __init__(self, state, action, state_prime, action_prime, reward, q_value, done):
        """
        Args:
            state: current decoder output state
            action: the greedy action selected from current decoder output
            state_prime: next decoder output state
            action_prime: the greedy action selected from the next decoder state
            reward: reward of the greedy action selected
            q_value: Q-value of the greedy action selected
            done: whether we reached End-Of-Sequence or not
        """
        self.state = state                # size: dqn_input_feature_len
        self.action = action              # size: 1
        self.state_prime = state_prime    # size: dqn_input_feature_len
        self.action_prime = action_prime
        self.reward = reward              # size: vocab_size
        self.q_value = q_value            # size: vocab_size
        self.done = done                  # true/false

    def __lt__(self, other):
        """Python 3 ordering used by PriorityQueue's heap.

        Items with higher reward compare as "smaller" so they sit at the head
        of the queue (max-heap behavior), matching the __cmp__ intent below.
        """
        return self.reward > other.reward

    def __cmp__(self, item):
        """ PriorityQueue uses this function to sort the rewards (Python 2 only)
        Args:
            We sort the queue such that items with higher rewards are in the head of max-heap
        """
        return cmp(item.reward, self.reward)  # bigger numbers have more priority


class ReplayBatch(object):
    """ A class for creating batches required for training DDQN. """

    def __init__(self, hps, example_list, dqn_batch_size, use_state_prime = False, max_art_oovs = 0):
        """
        Args:
            hps: seq2seq model parameters
            example_list: list of experiences
            dqn_batch_size: DDQN batch size
            use_state_prime: whether to use the next decoder state to make the batch or the current one
            max_art_oovs: number of OOV tokens in current batch

        Properties:
            _x: The input to DDQN model for training, this is basically the decoder output (dqn_batch_size, dqn_input_feature_len)
            _y: The Q-estimation (dqn_batch_size, vocab_size)
            _y_extended: The Q-estimation (dqn_batch_size, vocab_size + max_art_oovs)
        """
        self._x = np.zeros((dqn_batch_size, hps.dqn_input_feature_len))
        self._y = np.zeros((dqn_batch_size, hps.vocab_size))
        self._y_extended = np.zeros((dqn_batch_size, hps.vocab_size + max_art_oovs))
        for i, e in enumerate(example_list):
            if use_state_prime:
                self._x[i, :] = e.state_prime
            else:
                self._x[i, :] = e.state
            # NOTE(review): normalize(..., axis=1) requires a 2-D input, so
            # e.q_value is presumably shaped (1, vocab_size [+ oovs]) — confirm
            # against the producer of Transition.q_value.
            self._y[i, :] = normalize(e.q_value[0:hps.vocab_size], axis=1, norm='l1')
            if max_art_oovs == 0:
                self._y_extended[i, :] = normalize(e.q_value[0:hps.vocab_size], axis=1, norm='l1')
            else:
                self._y_extended[i, :] = e.q_value


class ReplayBuffer(object):
    """ A class for implementing the priority experience buffer. """

    BATCH_QUEUE_MAX = 100  # max number of batches the batch_queue can hold

    def __init__(self, hps):
        """Start the example/batch filler threads and the watcher thread.

        Args:
            hps: seq2seq model parameters (dqn_replay_buffer_size,
                dqn_batch_size, dqn_input_feature_len, vocab_size)
        """
        self._hps = hps
        self._buffer = CustomQueue(self._hps.dqn_replay_buffer_size)
        self._batch_queue = Q.Queue(self.BATCH_QUEUE_MAX)
        self._example_queue = Q.Queue(self.BATCH_QUEUE_MAX * self._hps.dqn_batch_size)
        self._num_example_q_threads = 1   # num threads to fill example queue
        self._num_batch_q_threads = 1     # num threads to fill batch queue
        self._bucketing_cache_size = 100  # how many batches-worth of examples to load into cache before bucketing

        # Start the threads that load the queues
        self._example_q_threads = []
        for _ in range(self._num_example_q_threads):
            self._example_q_threads.append(Thread(target=self.fill_example_queue))
            self._example_q_threads[-1].daemon = True
            self._example_q_threads[-1].start()
        self._batch_q_threads = []
        for _ in range(self._num_batch_q_threads):
            self._batch_q_threads.append(Thread(target=self.fill_batch_queue))
            self._batch_q_threads[-1].daemon = True
            self._batch_q_threads[-1].start()

        # Start a thread that watches the other threads and restarts them if they're dead
        self._watch_thread = Thread(target=self.watch_threads)
        self._watch_thread.daemon = True
        self._watch_thread.start()

    def next_batch(self):
        """Return a Batch from the batch queue.

        If mode='decode' then each batch contains a single example repeated
        beam_size-many times; this is necessary for beam search.

        Returns:
            batch: a Batch object, or None if the batch queue is currently empty.
        """
        # If the batch queue is empty, print a warning
        if self._batch_queue.qsize() == 0:
            tf.logging.warning('Bucket input queue is empty when calling next_batch. Bucket queue size: %i, Input queue size: %i',
                               self._batch_queue.qsize(), self._example_queue.qsize())
            return None
        batch = self._batch_queue.get()  # get the next Batch
        return batch

    @staticmethod
    def create_batch(_hps, batch, batch_size, use_state_prime=False, max_art_oovs=0):
        """ Create a DDQN-compatible batch from the input transitions
        Args:
            _hps: seq2seq model parameters
            batch: a list of Transitions
            batch_size: DDQN batch size
            use_state_prime: whether to use the next decoder state to make the batch or the current one
            max_art_oovs: number of OOV tokens in current batch

        Returns:
            An object of ReplayBatch class
        """
        return ReplayBatch(_hps, batch, batch_size, use_state_prime, max_art_oovs)

    def fill_example_queue(self):
        """Reads experiences from the buffer and places them into the example queue."""
        # BUGFIX: the original re-created the generator on every iteration and
        # used the Python 2-only .next(); build it once and use builtin next().
        input_gen = self._example_generator()
        while True:
            try:
                example = next(input_gen)
            except StopIteration:  # if there are no more examples:
                tf.logging.info("The example generator for this example queue filling thread has exhausted data.")
                raise Exception("single_pass mode is off but the example generator is out of data; error.")
            self._example_queue.put(example)  # place the pair in the example queue.

    def fill_batch_queue(self):
        """Takes Examples out of example queue, processes them into Batches and places them in the batch queue."""
        while True:
            # Get bucketing_cache_size-many batches of Examples into a list
            inputs = []
            for _ in range(self._hps.dqn_batch_size * self._bucketing_cache_size):
                inputs.append(self._example_queue.get())
            # feed back all the samples to the buffer
            self.add(inputs)
            # Group the Examples into batches, shuffle the batches, and place in the batch queue.
            batches = []
            for i in range(0, len(inputs), self._hps.dqn_batch_size):
                batches.append(inputs[i:i + self._hps.dqn_batch_size])
            shuffle(batches)
            for b in batches:  # each b is a list of Example objects
                self._batch_queue.put(ReplayBatch(self._hps, b, self._hps.dqn_batch_size))

    def watch_threads(self):
        """Watch example queue and batch queue threads and restart if dead."""
        while True:
            time.sleep(60)
            for idx, t in enumerate(self._example_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found example queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_example_queue)
                    self._example_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()
            for idx, t in enumerate(self._batch_q_threads):
                if not t.is_alive():  # if the thread is dead
                    tf.logging.error('Found batch queue thread dead. Restarting.')
                    new_t = Thread(target=self.fill_batch_queue)
                    self._batch_q_threads[idx] = new_t
                    new_t.daemon = True
                    new_t.start()

    def add(self, items):
        """ Adding a list of experiences to the buffer. When buffer is full,
        we get rid of half of the least important experiences and keep the rest.

        Args:
            items: A list of experiences of size (batch_size, k, max_dec_steps, hidden_dim)
        """
        for item in items:
            if not self._buffer.isfull():
                self._buffer.put_nowait(item)
            else:
                print('Replay Buffer is full, getting rid of unimportant transitions...')
                self._buffer.clear()
                self._buffer.put_nowait(item)
        print('ReplayBatch size: {}'.format(self._buffer.qsize()))
        print('ReplayBatch example queue size: {}'.format(self._example_queue.qsize()))
        print('ReplayBatch batch queue size: {}'.format(self._batch_queue.qsize()))

    def _buffer_len(self):
        """Current number of experiences held in the priority buffer."""
        return self._buffer.qsize()

    def _example_generator(self):
        """Yield experiences from the buffer as they become available.

        NOTE(review): busy-spins while the buffer is empty — confirm this is
        acceptable for the filler thread.
        """
        while True:
            if not self._buffer.isempty():
                item = self._buffer.get_nowait()
                self._buffer.task_done()
                yield item