import math
import os
import pickle
import time
from multiprocessing.pool import ThreadPool

import numpy as np
import tensorflow as tf

# Banner logged once at the start of training; filled with
# (learning rate, batch size, epoch count, save interval).
train_stats = (
    'Training statistics: \n'
    '\tLearning rate : {}\n'
    '\tBatch size : {}\n'
    '\tEpoch number : {}\n'
    '\tBackup every : {}'
)

# Shared worker pool used by predict() for pre- and post-processing.
pool = ThreadPool()


def _save_ckpt(self, step, loss_profile):
    """Write the loss profile and a session checkpoint for ``step``.

    Both files go under ``self.FLAGS.backup`` and are named
    ``<model name>-<step>``; the pickled loss profile additionally
    carries a ``.profile`` suffix.

    Args:
        self: object providing ``meta``, ``FLAGS``, ``say``, ``saver``
            and ``sess``.
        step: step number embedded in the file names.
        loss_profile: object pickled into the ``.profile`` file.
    """
    name_pattern = '{}-{}{}'
    model = self.meta['name']

    profile_path = os.path.join(
        self.FLAGS.backup, name_pattern.format(model, step, '.profile'))
    with open(profile_path, 'wb') as profile_ckpt:
        pickle.dump(loss_profile, profile_ckpt)

    ckpt_path = os.path.join(
        self.FLAGS.backup, name_pattern.format(model, step, ''))
    self.say('Checkpoint at step {}'.format(step))
    self.saver.save(self.sess, ckpt_path)


def train(self):
    """Run the training loop over ``self.framework.shuffle()``.

    For every batch this runs ``self.train_op`` together with the
    framework loss, logs the loss and its exponential moving average
    (0.9 / 0.1 weights), optionally writes a summary, and saves a
    checkpoint every ``FLAGS.save // FLAGS.batch`` steps. A final
    checkpoint is written after the loop when the last step was not
    itself a checkpoint step.

    If no batch is produced the function returns without saving.
    """
    loss_ph = self.framework.placeholders
    batches = self.framework.shuffle()
    loss_op = self.framework.loss

    # Clamp to 1 so a save interval smaller than the batch size means
    # "checkpoint every step" rather than a modulo-by-zero crash.
    save_every = max(1, self.FLAGS.save // self.FLAGS.batch)

    loss_mva = None
    profile = []
    step_now = None
    # True while the most recent step has not been checkpointed yet.
    # Initialised here so an empty batch stream skips the final save.
    unsaved = False

    for i, (x_batch, datum) in enumerate(batches):
        if i == 0:
            self.say(train_stats.format(
                self.FLAGS.lr, self.FLAGS.batch,
                self.FLAGS.epoch, self.FLAGS.save
            ))

        feed_dict = {loss_ph[key]: datum[key] for key in loss_ph}
        feed_dict[self.inp] = x_batch
        feed_dict.update(self.feed)

        fetches = [self.train_op, loss_op]
        if self.FLAGS.summary:
            fetches.append(self.summary_op)

        fetched = self.sess.run(fetches, feed_dict)
        loss = fetched[1]

        # Seed the moving average with the first observed loss.
        if loss_mva is None:
            loss_mva = loss
        loss_mva = .9 * loss_mva + .1 * loss
        step_now = self.FLAGS.load + i + 1

        if self.FLAGS.summary:
            # The summary is the third fetch, present only when enabled.
            self.writer.add_summary(fetched[2], step_now)

        form = 'step {} - loss {} - moving ave loss {}'
        self.say(form.format(step_now, loss, loss_mva))
        profile.append((loss, loss_mva))

        unsaved = (i + 1) % save_every != 0
        if not unsaved:
            _save_ckpt(self, step_now, profile)

    if unsaved:
        _save_ckpt(self, step_now, profile)


def return_predict(self, im):
    """Run the network on a single image and return its detections.

    Args:
        im: image as a 3-dimensional ``np.ndarray``; the first two
            axes are passed to ``process_box`` as height and width.

    Returns:
        A list of dicts with keys ``"label"``, ``"confidence"``,
        ``"topleft"`` and ``"bottomright"``; boxes for which
        ``self.framework.process_box`` returns ``None`` are skipped.

    Raises:
        AssertionError: if ``im`` is not a ``np.ndarray``.
    """
    assert isinstance(im, np.ndarray), \
        'Image is not a np.ndarray'
    h, w, _ = im.shape

    resized = self.framework.resize_input(im)
    this_inp = np.expand_dims(resized, 0)
    out = self.sess.run(self.out, {self.inp: this_inp})[0]

    boxes = self.framework.findboxes(out)
    threshold = self.FLAGS.threshold

    boxes_info = []
    for box in boxes:
        processed = self.framework.process_box(box, h, w, threshold)
        if processed is None:
            continue
        # process_box result layout as consumed here:
        # [0]=left x, [1]=right x, [2]=top y, [3]=bottom y,
        # [4]=label, [6]=confidence.
        boxes_info.append({
            "label": processed[4],
            "confidence": processed[6],
            "topleft": {
                "x": processed[0],
                "y": processed[2]},
            "bottomright": {
                "x": processed[1],
                "y": processed[3]}
        })
    return boxes_info


def _say_rate(self, elapsed, count):
    """Log elapsed seconds, item count and items-per-second."""
    # A coarse timer can report zero elapsed time; do not divide by it.
    rate = count / elapsed if elapsed > 0 else float('inf')
    self.say('Total time = {}s / {} inps = {} ips'.format(
        elapsed, count, rate))


def predict(self):
    """Run the network over every input found in ``self.FLAGS.imgdir``.

    Directory entries accepted by ``self.framework.is_inp`` are
    processed in batches of at most ``self.FLAGS.batch``: each batch is
    preprocessed on the shared thread pool, forwarded through the
    network in one ``sess.run`` call, and every output is handed to
    ``self.framework.postprocess`` together with its input path.

    Raises:
        SystemExit: if the directory contains no accepted input.
    """
    inp_path = self.FLAGS.imgdir
    all_inps = [
        name for name in os.listdir(inp_path)
        if self.framework.is_inp(name)
    ]
    if not all_inps:
        msg = 'Failed to find any images in {} .'
        # Raise directly: the ``exit`` builtin is installed by ``site``
        # for interactive use and is not guaranteed to exist.
        raise SystemExit('Error: {}'.format(msg.format(inp_path)))

    batch = min(self.FLAGS.batch, len(all_inps))

    def _load(name):
        # Preprocess one file and add a leading batch axis.
        return np.expand_dims(
            self.framework.preprocess(os.path.join(inp_path, name)), 0)

    # Predict in batches.
    for from_idx in range(0, len(all_inps), batch):
        # Collect the inputs of this batch; slicing clips at the end.
        this_batch = all_inps[from_idx:from_idx + batch]
        inp_feed = pool.map(_load, this_batch)

        # Feed to the net.
        feed_dict = {self.inp: np.concatenate(inp_feed, 0)}
        self.say('Forwarding {} inputs ...'.format(len(inp_feed)))
        start = time.time()
        out = self.sess.run(self.out, feed_dict)
        _say_rate(self, time.time() - start, len(inp_feed))

        # Post processing.
        self.say('Post processing {} inputs ...'.format(len(inp_feed)))

        def _finish(indexed, names=this_batch):
            # Pair each output with the path of the input it came from.
            i, prediction = indexed
            return self.framework.postprocess(
                prediction, os.path.join(inp_path, names[i]))

        start = time.time()
        pool.map(_finish, enumerate(out))
        _say_rate(self, time.time() - start, len(inp_feed))