import os

# Restrict TensorFlow to GPU 0; set before TensorFlow is imported so it takes effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
import tensorflow as tf
import cv2

import cfg
from shufflenetv2_centernet_V2 import ShuffleNetV2_centernet
# from shufflenetv2_centernet_V2_SEB import Shufflenetv2_Centernet_SEB
# from yolov3_centernet_V2 import yolov3_centernet
from create_label import CreatGroundTruth

# Fill values used by padded_batch below.
# -0.5 is what a 0-valued pixel becomes after the normalization in
# parse_color_data; -1.0 marks unused trailing label slots, which
# erase_invalid_val strips again before ground-truth construction.
IMG_PAD_VALUE = -0.5
LABEL_PAD_VALUE = -1.0


def parse_color_data(example_proto):
    """Decode one serialized tf.train.Example into an (image, label) pair.

    The record must contain the keys "img_raw" (raw uint8 bytes), "label"
    (raw float32 bytes), "width" and "height" (int64).

    Returns:
        img: float32 tensor of shape [height, width, 3], with values scaled
            from [0, 255] to [-0.5, 0.5].
        label: flat 1-D float32 tensor. Its record layout is interpreted
            downstream by create_label.CreatGroundTruth (not visible here).
    """
    features = {"img_raw": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.string),
                "width": tf.FixedLenFeature([], tf.int64),
                "height": tf.FixedLenFeature([], tf.int64)}
    parsed_features = tf.parse_single_example(example_proto, features)

    img = tf.decode_raw(parsed_features["img_raw"], tf.uint8)
    width = parsed_features["width"]
    height = parsed_features["height"]
    img = tf.reshape(img, [height, width, 3])
    # Map uint8 [0, 255] to float [-0.5, 0.5].
    img = tf.cast(img, tf.float32) * (1. / 255.) - 0.5

    label = tf.decode_raw(parsed_features["label"], tf.float32)
    return img, label


def erase_invalid_val(sequence, pad_value=LABEL_PAD_VALUE):
    """Strip padding entries from every row of a padded label batch.

    Args:
        sequence: 2-D array-like (batch, max_label_len) as produced by
            padded_batch. Plain nested lists are accepted too.
        pad_value: the fill value to remove. It defaults to the label
            padding value used when batching.

    Returns:
        A list with one Python list per row, holding only the non-padding
        values in their original order.

    Raises:
        ValueError: if ``sequence`` is not 2-D.
    """
    rows = np.asarray(sequence)
    if rows.ndim != 2:
        raise ValueError(
            "expected a 2-D label batch, got shape {}".format(rows.shape))
    return [list(row[row != pad_value]) for row in rows]


filenames = [cfg.tfrecords_path]
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.map(parse_color_data)
val1 = tf.constant(IMG_PAD_VALUE, tf.float32)
val2 = tf.constant(LABEL_PAD_VALUE, tf.float32)
# Images of different sizes are padded to the largest in the batch; label
# vectors are padded to the longest one with LABEL_PAD_VALUE.
dataset = dataset.padded_batch(cfg.batch_size,
                               padded_shapes=([None, None, 3], [None]),
                               padding_values=(val1, val2))
dataset = dataset.repeat(cfg.epochs)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

train_start_time = cv2.getTickCount()
model = ShuffleNetV2_centernet()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
train_writer = tf.summary.FileWriter("shufflenetv2_voc_summary", sess.graph)
saver = tf.train.Saver(max_to_keep=20)

# Set to True to resume from the latest checkpoint instead of fresh weights.
# NOTE(review): confirm this directory matches where the training loop
# writes its checkpoints.
RELOAD_MODEL = False
if RELOAD_MODEL:
    model_file = tf.train.latest_checkpoint('shufflenetv2_voc/')
    if model_file is None:
        # latest_checkpoint returns None when no checkpoint exists.
        raise RuntimeError("no checkpoint found in 'shufflenetv2_voc/'")
    saver.restore(sess, model_file)
    print("reload ckpt from "+model_file)
# Checkpoint prefix for periodic saves. Prefixes used in other experiments:
#   "shufflenetv2_face_SEB_summary/shufflenetv2_face_SEB.ckpt"
#   "shufflenetv2_voc/shufflenetv2_voc.ckpt"
#   "yolov3_voc/yolov3_voc.ckpt"
#   "shufflenev2_face_ori/shufflenev2_face.ckpt"
# NOTE(review): the model constructed for this run is ShuffleNetV2_centernet
# (not the SEB variant); confirm "shufflenetv2_seb_voc" is the intended target.
ckpt_prefix = "shufflenetv2_seb_voc/shufflenetv2_seb_voc.ckpt"
ckpt_dir = os.path.dirname(ckpt_prefix)
# tf.train.Saver.save fails when the parent directory does not exist.
if ckpt_dir and not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

global_step = None      # last step value fetched from the model; None before the first batch
last_saved_step = None  # step of the most recent checkpoint written by this loop
try:
    while True:
        batch_start_time = cv2.getTickCount()
        img_batch, label_batch = sess.run(next_element)
        # Remove the label padding added by padded_batch before building targets.
        label_batch = erase_invalid_val(label_batch)
        cls_gt_batch, size_gt_batch = CreatGroundTruth(label_batch)
        feed = {
            model.inputs: img_batch,
            model.is_training: True,
            model.size_gt: size_gt_batch,
            model.cls_gt: cls_gt_batch,
        }
        fetches = [
            model.cls_loss,
            model.size_loss,
            model.total_loss,
            model.global_step,
            model.lr,
            model.merged_summay,
            model.train_op,
        ]
        cls_loss, size_loss, total_loss, global_step, lr, summary, _ = sess.run(fetches, feed)
        train_writer.add_summary(summary, global_step)
        time_elapsed = (cv2.getTickCount() - batch_start_time) / cv2.getTickFrequency()

        if global_step % 200 == 0:
            saver.save(sess, ckpt_prefix, global_step=global_step)
            last_saved_step = global_step

        if global_step % 10 == 0:
            print("-------Training {0}th batch-------".format(global_step))
            print("global_step:{0} total_loss:{1:0.3f} cls_loss:{2:0.3f} size_loss:{3:0.3f}".format(global_step, total_loss, cls_loss, size_loss))
            print("learning_rate:{0:0.6f}".format(lr))
            print('The batch run total {0:0.5f}s'.format(time_elapsed))
except tf.errors.OutOfRangeError:
    print('Training has completed...')
    # Without this, every step after the last multiple-of-200 save would be
    # lost when the dataset runs out.
    if global_step is not None and global_step != last_saved_step:
        saver.save(sess, ckpt_prefix, global_step=global_step)
finally:
    # Runs on normal completion and on any escaping exception, so summaries
    # are flushed and the session is released either way.
    train_total_time = (cv2.getTickCount() - train_start_time) / cv2.getTickFrequency()
    print('Training has stopped...')
    hour, remainder = divmod(train_total_time, 3600)
    minute = remainder // 60
    print('Training runs {:.0f}h {:.0f}m...'.format(hour, minute))
    train_writer.close()
    sess.close()