################################################################################ # Example : perform conversion from tflearn checkpoint format to TensorFlow # protocol buffer (.pb) binary format and also .tflite files (for import into other tools) # Copyright (c) 2019 Toby Breckon, Durham University, UK # License : https://github.com/tobybreckon/fire-detection-cnn/blob/master/LICENSE # Acknowledgements: some portions - tensorflow tutorial examples and URL below ################################################################################ import glob,os ################################################################################ # import tensorflow api import tensorflow as tf from tensorflow.python.framework import graph_util from tensorflow.compat.v1.graph_util import convert_variables_to_constants from tensorflow.tools.graph_transforms import TransformGraph from tensorflow.python.tools import optimize_for_inference_lib from tensorflow.compat.v1.lite import TFLiteConverter ################################################################################ # import tflearn api import tflearn from tflearn.layers.core import * from tflearn.layers.conv import * from tflearn.layers.normalization import * from tflearn.layers.estimator import regression ################################################################################ # convert a loaded model definition by loading a checkpoint from a given path # retaining the network between the specified input and output layers # outputs to pbfilename as a binary .pb protocol buffer format file # e.g. for FireNet # model = construct_firenet (224, 224, False) # path = "models/FireNet/firenet"; # path to tflearn checkpoint including filestem # input_layer_name = 'InputData/X' # input layer of network # output_layer_name= 'FullyConnected_2/Softmax' # output layer of network # filename = "firenet.pb" # output filename def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False): model.load(path,weights_only=True) print("[INFO] Loaded CNN network weights from " + path + " ...") print("[INFO] Re-export model ...") del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:] model.save("model-tmp.tfl") # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow print("[INFO] Re-import model ...") input_checkpoint = "model-tmp.tfl" saver = tf.train.import_meta_graph(input_checkpoint + '.meta', True) sess = tf.Session(); saver.restore(sess, input_checkpoint) # print out all layers to find name of output if (verbose): op = sess.graph.get_operations() [print(m.values()) for m in op][1] print("[INFO] Freeze model to " + pbfilename + " ...") # freeze and removes nodes which are not related to feedforward prediction minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name]) graph_def = optimize_for_inference_lib.optimize_for_inference(minimal_graph, [input_layer_name], [output_layer_name], tf.float32.as_datatype_enum) graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"]) with tf.gfile.GFile(pbfilename, 'wb') as f: f.write(graph_def.SerializeToString()) # write model to logs dir so we can visualize it as: # tensorboard --logdir="logs" if (verbose): writer = tf.summary.FileWriter('logs', graph_def) writer.close() # tidy up tmp files for f in glob.glob("model-tmp.tfl*"): os.remove(f) os.remove('checkpoint') ################################################################################ # convert a binary .pb protocol buffer format model to tflite format # e.g. for FireNet # pbfilename = "firenet.pb" # input_layer_name = 'InputData/X' # input layer of network # output_layer_name= 'FullyConnected_2/Softmax' # output layer of network def convert_to_tflite(pbfilename, input_layer_name, output_layer_name, input_tensor_dim_x=224, input_tensor_dim_y=224, input_tensor_channels=3): input_tensor={input_layer_name:[1,input_tensor_dim_x,input_tensor_dim_y,input_tensor_channels]} print("[INFO] tflite model to " + pbfilename.replace(".pb",".tflite") + " ...") converter = tf.lite.TFLiteConverter.from_frozen_graph(pbfilename, [input_layer_name], [output_layer_name], input_tensor) tflite_model = converter.convert() open(pbfilename.replace(".pb",".tflite"), "wb").write(tflite_model) ################################################################################