import logging import os import keras import tensorflow as tf from tensorflow.python.platform import gfile from tensorflow.python.tools import freeze_graph from tensorflow.python.tools import optimize_for_inference_lib from tensorflow.python.framework import dtypes from style_transfer import models logger = logging.getLogger(__name__) def freeze_keras_model_graph(model, basename, output_dir): """Extract and freeze the tensorflow graph from a Keras model. Args: model (keras.models.Model): A Keras model. basename (str): the basename of the Keras model. E.g. starry_night.h5 output_dir (str): a directory to output the frozen graph Returns: output_graph_filename (str): a path to the saved frozen graph. """ name, _ = os.path.splitext(basename) saver = tf.train.Saver() with keras.backend.get_session() as sess: checkpoint_filename = os.path.join(output_dir, '%s.ckpt' % name) output_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name) saver.save(sess, checkpoint_filename) tf.train.write_graph( sess.graph_def, output_dir, '%s_graph_def.pbtext' % name ) freeze_graph.freeze_graph( input_graph=os.path.join(output_dir, '%s_graph_def.pbtext' % name), input_saver='', input_binary=False, input_checkpoint=checkpoint_filename, output_graph=output_graph_filename, output_node_names='deprocess_stylized_image_1/mul', restore_op_name="save/restore_all", filename_tensor_name="save/Const:0", clear_devices=True, initializer_nodes=None ) logger.info('Saved frozen graph to: %s' % output_graph_filename) return output_graph_filename def optimize_graph(frozen_graph_filename, suffix='optimized'): """Optimize a TensorFlow graph for inference. Optimized graphs are saved to the same directory as the input frozen graph. Args: frozen_graph_filename (str): the filename of a frozen graph. suffix (optional, str): a suffix to append to the optimized graph file. Returns: optimized_graph_filename (str): a path to the saved optimized graph. """ output_dir, basename = os.path.split(frozen_graph_filename) graph_def = load_graph_def(frozen_graph_filename) optimized_graph = optimize_for_inference_lib.optimize_for_inference( input_graph_def=graph_def, input_node_names=['input_1'], placeholder_type_enum=dtypes.float32.as_datatype_enum, output_node_names=['deprocess_stylized_image_1/mul'], toco_compatible=True ) optimized_graph_filename = os.path.basename( frozen_graph_filename).replace('frozen', suffix) optimized_graph_filename = optimized_graph_filename tf.train.write_graph( optimized_graph, output_dir, optimized_graph_filename, as_text=False ) logger.info('Saved optimized graph to: %s' % os.path.join(output_dir, optimized_graph_filename)) return optimized_graph_filename def load_graph_def(filename): """Load a graph_def file. Args: filename (str): a filename to load Returns: graph_def """ input_graph_def = tf.GraphDef() with gfile.FastGFile(filename, 'rb') as file: data = file.read() input_graph_def.ParseFromString(data) return input_graph_def