""" tfnet secondary (helper) methods """ from ..utils.loader import create_loader from time import time as timer import tensorflow as tf import numpy as np import sys import cv2 import os old_graph_msg = 'Resolving old graph def {} (no guarantee)' def build_train_op(self): self.framework.loss(self.out) self.say('Building {} train op'.format(self.meta['model'])) optimizer = self._TRAINER[self.FLAGS.trainer](self.FLAGS.lr) gradients = optimizer.compute_gradients(self.framework.loss) self.train_op = optimizer.apply_gradients(gradients) def load_from_ckpt(self): if self.FLAGS.load < 0: # load lastest ckpt with open(os.path.join(self.FLAGS.backup, 'checkpoint'), 'r') as f: last = f.readlines()[-1].strip() load_point = last.split(' ')[1] load_point = load_point.split('"')[1] load_point = load_point.split('-')[-1] self.FLAGS.load = int(load_point) load_point = os.path.join(self.FLAGS.backup, self.meta['name']) load_point = '{}-{}'.format(load_point, self.FLAGS.load) self.say('Loading from {}'.format(load_point)) try: self.saver.restore(self.sess, load_point) except: load_old_graph(self, load_point) def say(self, *msgs): if not self.FLAGS.verbalise: return msgs = list(msgs) for msg in msgs: if msg is None: continue print(msg) def load_old_graph(self, ckpt): ckpt_loader = create_loader(ckpt) self.say(old_graph_msg.format(ckpt)) for var in tf.global_variables(): name = var.name.split(':')[0] args = [name, var.get_shape()] val = ckpt_loader(args) assert val is not None, \ 'Cannot find and load {}'.format(var.name) shp = val.shape plh = tf.placeholder(tf.float32, shp) op = tf.assign(var, plh) self.sess.run(op, {plh: val}) def _get_fps(self, frame): elapsed = int() start = timer() preprocessed = self.framework.preprocess(frame) feed_dict = {self.inp: [preprocessed]} net_out = self.sess.run(self.out, feed_dict)[0] processed = self.framework.postprocess(net_out, frame, False) return timer() - start def camera(self): file = self.FLAGS.demo SaveVideo = self.FLAGS.saveVideo if file == 'camera': file = 0 else: assert os.path.isfile(file), \ 'file {} does not exist'.format(file) camera = cv2.VideoCapture(file) if file == 0: self.say('Press [ESC] to quit demo') assert camera.isOpened(), \ 'Cannot capture source' if file == 0:#camera window cv2.namedWindow('', 0) _, frame = camera.read() height, width, _ = frame.shape cv2.resizeWindow('', width, height) else: _, frame = camera.read() height, width, _ = frame.shape if SaveVideo: fourcc = cv2.VideoWriter_fourcc(*'XVID') if file == 0:#camera window fps = 1 / self._get_fps(frame) if fps < 1: fps = 1 else: fps = round(camera.get(cv2.CAP_PROP_FPS)) videoWriter = cv2.VideoWriter( 'video.avi', fourcc, fps, (width, height)) # buffers for demo in batch buffer_inp = list() buffer_pre = list() elapsed = int() start = timer() self.say('Press [ESC] to quit demo') # Loop through frames while camera.isOpened(): elapsed += 1 _, frame = camera.read() if frame is None: print ('\nEnd of Video') break preprocessed = self.framework.preprocess(frame) buffer_inp.append(frame) buffer_pre.append(preprocessed) # Only process and imshow when queue is full if elapsed % self.FLAGS.queue == 0: feed_dict = {self.inp: buffer_pre} net_out = self.sess.run(self.out, feed_dict) for img, single_out in zip(buffer_inp, net_out): postprocessed = self.framework.postprocess( single_out, img, False) if SaveVideo: videoWriter.write(postprocessed) if file == 0: #camera window cv2.imshow('', postprocessed) # Clear Buffers buffer_inp = list() buffer_pre = list() if elapsed % 5 == 0: sys.stdout.write('\r') sys.stdout.write('{0:3.3f} FPS'.format( elapsed / (timer() - start))) sys.stdout.flush() if file == 0: #camera window choice = cv2.waitKey(1) if choice == 27: break sys.stdout.write('\n') if SaveVideo: videoWriter.release() camera.release() if file == 0: #camera window cv2.destroyAllWindows() def to_darknet(self): darknet_ckpt = self.darknet with self.graph.as_default() as g: for var in tf.global_variables(): name = var.name.split(':')[0] var_name = name.split('-') l_idx = int(var_name[0]) w_sig = var_name[1].split('/')[-1] l = darknet_ckpt.layers[l_idx] l.w[w_sig] = var.eval(self.sess) for layer in darknet_ckpt.layers: for ph in layer.h: layer.h[ph] = None return darknet_ckpt