Python tensorflow.python.tools.optimize_for_inference_lib.optimize_for_inference() Examples

The following are 20 code examples of tensorflow.python.tools.optimize_for_inference_lib.optimize_for_inference(). You can go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module tensorflow.python.tools.optimize_for_inference_lib.
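Before the examples, here is a minimal sketch of a typical call, assuming TF 1.x and a frozen GraphDef already on disk; the file path and node names are hypothetical placeholders:

# Minimal sketch (TF 1.x): load a frozen GraphDef and optimize it for inference.
# 'frozen_model.pb', 'input' and 'output' are hypothetical placeholders.
import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

graph_def = tf.GraphDef()
with tf.gfile.GFile('frozen_model.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

optimized = optimize_for_inference_lib.optimize_for_inference(
    graph_def,
    ['input'],                    # input node names
    ['output'],                   # output node names
    tf.float32.as_datatype_enum)  # placeholder dtype

with tf.gfile.GFile('optimized_model.pb', 'wb') as f:
    f.write(optimized.SerializeToString())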
Example #1
Source File: freeze_saved_model.py    From open_model_zoo with Apache License 2.0
def freeze(saved_model_dir, input_nodes, output_nodes, save_file):
    graph = tf.Graph()
    with tf.Session(graph=graph) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], saved_model_dir)
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_nodes
        )
        frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
            frozen_graph_def,
            input_nodes,
            output_nodes,
            tf.float32.as_datatype_enum
        )
        with open(save_file, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString()) 
Example #2
Source File: convert_to_tfmobile.py    From fritz-models with MIT License
def _optimize_graph(basename, output_dir):
    name, _ = os.path.splitext(basename)
    frozen_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
    graph_def = load_graph_def(frozen_graph_filename)

    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['conv6_interp/ResizeBilinear'],
        toco_compatible=True
    )

    optimized_graph_filename = os.path.basename(
        frozen_graph_filename).replace('frozen', 'optimized')
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s' %
                os.path.join(output_dir, optimized_graph_filename)) 
Example #3
Source File: convert_to_tfmobile.py    From fritz-models with MIT License
def _optimize_graph(basename, output_dir):
    name, _ = os.path.splitext(basename)
    frozen_graph_filename = os.path.join(output_dir, '%s_frozen.pb' % name)
    graph_def = load_graph_def(frozen_graph_filename)

    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['deprocess_stylized_image_1/mul'],
        toco_compatible=True
    )

    optimized_graph_filename = os.path.basename(
        frozen_graph_filename).replace('frozen', 'optimized')
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s' %
                os.path.join(output_dir, optimized_graph_filename)) 
Example #4
Source File: optimize_for_inference.py    From lambda-packs with MIT License
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    f = gfile.FastGFile(FLAGS.output, "w")
    f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0 
Example #5
Source File: optimize_for_inference.py    From auto-alt-text-lambda-api with MIT License
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "r") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    f = gfile.FastGFile(FLAGS.output, "w")
    f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0 
Example #6
Source File: tf_tensor.py    From spark-deep-learning with Apache License 2.0
def _optimize_for_inference(self):
        graph_def = self.getTFInputGraph().graph_def
        # Get data types of input placeholders
        placeholder_types = self._get_placeholder_types(graph_def)
        # Strip away graph nodes not used in computing the tensors with the specified output names
        input_names = [tfx.op_name(tnsr_name) for _, tnsr_name in self.getInputMapping()]
        output_names = [tfx.op_name(tnsr_name) for tnsr_name, _ in self.getOutputMapping()]
        return infr_opt.optimize_for_inference(graph_def,
                                               input_names,
                                               output_names,
                                               placeholder_types) 
Example #7
Source File: infer_detections.py    From models with Apache License 2.0
def load_graph(self):
    print('load graph from: ' + self.args.input_graph)

    self.infer_graph = tf.Graph()
    with self.infer_graph.as_default():
      graph_def = tf.compat.v1.GraphDef()
      with tf.compat.v1.gfile.FastGFile(self.args.input_graph, 'rb') as input_file:
        input_graph_content = input_file.read()
        graph_def.ParseFromString(input_graph_content)
      output_graph = optimize_for_inference(graph_def, [self.input_layer],
                              self.output_layers, dtypes.uint8.as_datatype_enum, False)
      tf.import_graph_def(output_graph, name='') 
Example #8
Source File: freeze_code.py    From Intelligent-Projects-Using-Python with MIT License
def model_freeze(path,MODEL_NAME='model'):

    # Freeze the graph

    input_graph_path = path + MODEL_NAME+'.pbtxt'
    checkpoint_path = path + 'model_ckpt'
    input_saver_def_path = ""
    input_binary = False
    output_node_names = 'positive_sentiment_probability'
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_frozen_graph_name = path + 'frozen_'+MODEL_NAME+'.pb'
    output_optimized_graph_name = path + 'optimized_'+MODEL_NAME+'.pb'
    clear_devices = True


    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path,
                            input_binary, checkpoint_path, output_node_names,
                            restore_op_name, filename_tensor_name,
    output_frozen_graph_name, clear_devices, "")

    input_graph_def = tf.GraphDef()

    with tf.gfile.Open(output_frozen_graph_name, "rb") as f:
        data = f.read()
        input_graph_def.ParseFromString(data)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def,
            ["inputs/X" ],#an array of the input node(s)
            ["positive_sentiment_probability"],
            tf.int32.as_datatype_enum # an array of output nodes
            )

    # Save the optimized graph

    f = tf.gfile.FastGFile(output_optimized_graph_name, "w")
    f.write(output_graph_def.SerializeToString()) 
Example #9
Source File: optimize_for_inference.py    From deep_image_model with Apache License 2.0
def main(unused_args):
  if not tf.gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = tf.GraphDef()
  with tf.gfile.Open(FLAGS.input, "r") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    f = tf.gfile.FastGFile(FLAGS.output, "w")
    f.write(output_graph_def.SerializeToString())
  else:
    tf.train.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0 
Example #10
Source File: optimize_for_inference.py    From keras-lambda with MIT License
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "r") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    f = gfile.FastGFile(FLAGS.output, "w")
    f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0 
Example #11
Source File: mnist_cnn1.py    From Unity-MNIST with Apache License 2.0
def export_model(saver, model, input_node_names, output_node_name):
    if not path.exists('out'):
        os.mkdir('out')

    tf.train.write_graph(K.get_session().graph_def, 'out', model_name + '_graph.pbtxt')

    saver.save(K.get_session(), 'out/' + model_name + '.chkp')

    freeze_graph.freeze_graph('out/' + model_name + '_graph.pbtxt', None, False,
                              'out/' + model_name + '.chkp', output_node_name,
                              "save/restore_all", "save/Const:0",
                              'out/frozen_' + model_name + '.bytes', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + model_name + '.bytes', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/opt_' + model_name + '.bytes', "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("graph saved!")

########################################################################################################################
# Main program 
Example #12
Source File: hangul_model.py    From tensorflow-hangul-recognition with Apache License 2.0
def export_model(model_output_dir, input_node_names, output_node_name):
    """Export the model so we can use it later.

    This will create two Protocol Buffer files in the model output directory.
    These files represent a serialized version of our model with all the
    learned weights and biases. One of the ProtoBuf files is a version
    optimized for inference-only usage.
    """

    name_base = os.path.join(model_output_dir, MODEL_NAME)
    frozen_graph_file = os.path.join(model_output_dir,
                                     'frozen_' + MODEL_NAME + '.pb')
    freeze_graph.freeze_graph(
        name_base + '.pbtxt', None, False, name_base + '.chkp',
        output_node_name, "save/restore_all", "save/Const:0",
        frozen_graph_file, True, ""
    )

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open(frozen_graph_file, "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    optimized_graph_file = os.path.join(model_output_dir,
                                        'optimized_' + MODEL_NAME + '.pb')
    with tf.gfile.GFile(optimized_graph_file, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("Inference optimized graph saved at: " + optimized_graph_file) 
Example #13
Source File: optimize_for_inference.py    From dcscn-super-resolution with MIT License
def main(unused_args):
    if not gfile.Exists(FLAGS.input):
        print("Input graph file '" + FLAGS.input + "' does not exist!")
        return -1

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(FLAGS.input, "rb") as f:
        data = f.read()
        if FLAGS.frozen_graph:
            input_graph_def.ParseFromString(data)
        else:
            text_format.Merge(data.decode("utf-8"), input_graph_def)

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        FLAGS.input_names.split(","),
        FLAGS.output_names.split(","),
        FLAGS.placeholder_type_enum,
        FLAGS.toco_compatible)

    if FLAGS.frozen_graph:
        f = gfile.FastGFile(FLAGS.output, "w")
        f.write(output_graph_def.SerializeToString())
    else:
        graph_io.write_graph(output_graph_def,
                             os.path.dirname(FLAGS.output),
                             os.path.basename(FLAGS.output))
    return 0 
Example #14
Source File: tf_utils.py    From fritz-models with MIT License
def optimize_graph(frozen_graph_filename, suffix='optimized'):
    """Optimize a TensorFlow graph for inference.

    Optimized graphs are saved to the same directory as the input frozen graph.

    Args:
        frozen_graph_filename (str): the filename of a frozen graph.
        suffix (optional, str): a suffix to append to the optimized graph file.
    
    Returns:
        optimized_graph_filename (str): a path to the saved optimized graph.
    """
    output_dir, basename = os.path.split(frozen_graph_filename)
    graph_def = load_graph_def(frozen_graph_filename)

    optimized_graph = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def=graph_def,
        input_node_names=['input_1'],
        placeholder_type_enum=dtypes.float32.as_datatype_enum,
        output_node_names=['deprocess_stylized_image_1/mul'],
        toco_compatible=True
    )

    optimized_graph_filename = os.path.basename(
        frozen_graph_filename).replace('frozen', suffix)
    tf.train.write_graph(
        optimized_graph, output_dir, optimized_graph_filename, as_text=False
    )
    logger.info('Saved optimized graph to: %s' %
                os.path.join(output_dir, optimized_graph_filename))
    return optimized_graph_filename 
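Typical usage might look like this (the filename is a hypothetical placeholder; note that the function assumes an input node named 'input_1' and the stylizer output node hard-coded above):

# Hypothetical usage: writes 'stylizer_optimized.pb' next to the frozen graph.
optimized_name = optimize_graph('models/stylizer_frozen.pb')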
Example #15
Source File: optimize_for_inference.py    From Serverless-Deep-Learning-with-TensorFlow-and-AWS-Lambda with MIT License
def main(unused_args):
  if not gfile.Exists(FLAGS.input):
    print("Input graph file '" + FLAGS.input + "' does not exist!")
    return -1

  input_graph_def = graph_pb2.GraphDef()
  with gfile.Open(FLAGS.input, "rb") as f:
    data = f.read()
    if FLAGS.frozen_graph:
      input_graph_def.ParseFromString(data)
    else:
      text_format.Merge(data.decode("utf-8"), input_graph_def)

  output_graph_def = optimize_for_inference_lib.optimize_for_inference(
      input_graph_def,
      FLAGS.input_names.split(","),
      FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)

  if FLAGS.frozen_graph:
    f = gfile.FastGFile(FLAGS.output, "w")
    f.write(output_graph_def.SerializeToString())
  else:
    graph_io.write_graph(output_graph_def,
                         os.path.dirname(FLAGS.output),
                         os.path.basename(FLAGS.output))
  return 0 
Example #16
Source File: export.py    From tensorpack with Apache License 2.0
def export_compact(self, filename, optimize=True, toco_compatible=False):
        """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

        Args:
            filename (str): path to the output graph
            optimize (bool): whether to use TensorFlow's `optimize_for_inference`
                to prune and optimize the graph. This does not work on all types of graphs.
            toco_compatible (bool): See TensorFlow's
                `optimize_for_inference
                <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
                for details. Only available after TF 1.8.
        """
        if toco_compatible:
            assert optimize, "toco_compatible is only effective when optimize=True!"
        self.graph = self.config._maybe_create_graph()
        with self.graph.as_default():
            input = PlaceholderInput()
            input.setup(self.config.input_signature)
            with PredictTowerContext(''):
                self.config.tower_func(*input.get_input_tensors())

            input_tensors = get_tensors_by_names(self.config.input_names)
            output_tensors = get_tensors_by_names(self.config.output_names)

            self.config.session_init._setup_graph()
            # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)

            dtypes = [n.dtype for n in input_tensors]

            # freeze variables to constants
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                self.graph.as_graph_def(),
                [n.name[:-2] for n in output_tensors],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

            # prune unused nodes from graph
            if optimize:
                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
                frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
                    frozen_graph_def,
                    [n.name[:-2] for n in input_tensors],
                    [n.name[:-2] for n in output_tensors],
                    [dtype.as_datatype_enum for dtype in dtypes],
                    *toco_args)

            with gfile.FastGFile(filename, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
                logger.info("Output graph written to {}.".format(filename)) 
Example #17
Source File: training_pipeline.py    From paraphraser with MIT License
def compress_graph(sess, args, model):
    """After training has completed, this function can be called to compress
    the model.  The computation graph is frozen, turning the checkpoint
    variables into constants.  Finally, the graph is optimized by stripping
    away all nodes that are not used at inference time.

    Args:
        sess: Tensorflow session
        args: ArgumentParser config object
        model: model dictionary containing tensors of interest

    """
    from tensorflow.python.tools import freeze_graph 
    from tensorflow.python.tools import optimize_for_inference_lib

    tf.train.write_graph(sess.graph_def, '/media/sdb/models/paraphraser', 'model.pb', as_text=False)

    freeze_graph.freeze_graph(
        #input_graph='/tmp/model.pbtxt', 
        input_graph='/media/sdb/models/paraphraser/model.pb',
        input_saver='',
        input_binary=True, 
        input_checkpoint=args.checkpoint,
        output_node_names='predictions',
        restore_op_name='save/restore_all', 
        filename_tensor_name='save/Const:0',
        output_graph='/media/sdb/models/paraphraser/frozen_model.pb', 
        clear_devices=True, 
        initializer_nodes='')

    '''
    input_graph_def = tf.GraphDef()
    #with tf.gfile.Open('/media/sdb/models/paraphraser/frozen_model.pb', 'rb') as f:
    with tf.gfile.Open('/tmp/frozen_model.pb', 'rb') as f:
        data = f.read()
        input_graph_def.ParseFromString(data)
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(input_graph_def)
            print(dir(graph))
            print(graph.get_tensor_by_name('import/placeholders/sampling_temperature:0'))

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        ['placeholders/source_ids', 'placeholders/sequence_source_lengths'],
        ['predictions'],
        tf.float32.as_datatype_enum)
    
    f = tf.gfile.FastGFile('/tmp/optimized_model.pb', "w")
    f.write(output_graph_def.SerializeToString())
    ''' 
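A hypothetical invocation once training completes (the session, parsed args, and model dict are assumed to come from the training loop):

# Hypothetical: call once at the end of training.
compress_graph(sess, args, model)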
Example #18
Source File: export.py    From ADL with MIT License
def export_compact(self, filename, optimize=True, toco_compatible=False):
        """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

        Args:
            filename (str): path to the output graph
            optimize (bool): whether to use TensorFlow's `optimize_for_inference`
                to prune and optimize the graph. This does not work on all types of graphs.
            toco_compatible (bool): See TensorFlow's
                `optimize_for_inference
                <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
                for details. Only available after TF 1.8.
        """
        if toco_compatible:
            assert optimize, "toco_compatible is only effective when optimize=True!"
        self.graph = self.config._maybe_create_graph()
        with self.graph.as_default():
            input = PlaceholderInput()
            input.setup(self.config.input_signature)
            with PredictTowerContext(''):
                self.config.tower_func(*input.get_input_tensors())

            input_tensors = get_tensors_by_names(self.config.input_names)
            output_tensors = get_tensors_by_names(self.config.output_names)

            self.config.session_init._setup_graph()
            # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)

            dtypes = [n.dtype for n in input_tensors]

            # freeze variables to constants
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                self.graph.as_graph_def(),
                [n.name[:-2] for n in output_tensors],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

            # prune unused nodes from graph
            if optimize:
                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
                frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
                    frozen_graph_def,
                    [n.name[:-2] for n in input_tensors],
                    [n.name[:-2] for n in output_tensors],
                    [dtype.as_datatype_enum for dtype in dtypes],
                    *toco_args)

            with gfile.FastGFile(filename, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
                logger.info("Output graph written to {}.".format(filename)) 
Example #19
Source File: export.py    From petridishnn with MIT License
def export_compact(self, filename, optimize=True, toco_compatible=False):
        """Create a self-contained inference-only graph and write final graph (in pb format) to disk.

        Args:
            filename (str): path to the output graph
            optimize (bool): whether to use TensorFlow's `optimize_for_inference`
                to prune and optimize the graph. This does not work on all types of graphs.
            toco_compatible (bool): See TensorFlow's
                `optimize_for_inference
                <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference.py>`_
                for details. Only available after TF 1.8.
        """
        if toco_compatible:
            assert optimize, "toco_compatible is only effective when optimize=True!"
        self.graph = self.config._maybe_create_graph()
        with self.graph.as_default():
            input = PlaceholderInput()
            input.setup(self.config.input_signature)
            with PredictTowerContext(''):
                self.config.tower_func(*input.get_input_tensors())

            input_tensors = get_tensors_by_names(self.config.input_names)
            output_tensors = get_tensors_by_names(self.config.output_names)

            self.config.session_init._setup_graph()
            # we cannot use "self.config.session_creator.create_session()" here since it finalizes the graph
            sess = tfv1.Session(config=tfv1.ConfigProto(allow_soft_placement=True))
            self.config.session_init._run_init(sess)

            dtypes = [n.dtype for n in input_tensors]

            # freeze variables to constants
            frozen_graph_def = graph_util.convert_variables_to_constants(
                sess,
                self.graph.as_graph_def(),
                [n.name[:-2] for n in output_tensors],
                variable_names_whitelist=None,
                variable_names_blacklist=None)

            # prune unused nodes from graph
            if optimize:
                toco_args = () if get_tf_version_tuple() < (1, 8) else (toco_compatible, )
                frozen_graph_def = optimize_for_inference_lib.optimize_for_inference(
                    frozen_graph_def,
                    [n.name[:-2] for n in input_tensors],
                    [n.name[:-2] for n in output_tensors],
                    [dtype.as_datatype_enum for dtype in dtypes],
                    *toco_args)

            with gfile.FastGFile(filename, "wb") as f:
                f.write(frozen_graph_def.SerializeToString())
                logger.info("Output graph written to {}.".format(filename)) 
Example #20
Source File: converter.py    From fire-detection-cnn with MIT License
def convert_to_pb(model, path, input_layer_name, output_layer_name, pbfilename, verbose=False):

  model.load(path, weights_only=True)
  print("[INFO] Loaded CNN network weights from " + path + " ...")

  print("[INFO] Re-export model ...")
  del tf.get_collection_ref(tf.GraphKeys.TRAIN_OPS)[:]
  model.save("model-tmp.tfl")

  # taken from: https://stackoverflow.com/questions/34343259/is-there-an-example-on-how-to-generate-protobuf-files-holding-trained-tensorflow

  print("[INFO] Re-import model ...")

  input_checkpoint = "model-tmp.tfl"
  saver = tf.train.import_meta_graph(input_checkpoint + '.meta', True)
  sess = tf.Session()
  saver.restore(sess, input_checkpoint)

  # print out all layers to find name of output

  if verbose:
      for m in sess.graph.get_operations():
          print(m.values())

  print("[INFO] Freeze model to " +  pbfilename + " ...")

  # freeze and removes nodes which are not related to feedforward prediction

  minimal_graph = convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_layer_name])

  graph_def = optimize_for_inference_lib.optimize_for_inference(minimal_graph, [input_layer_name], [output_layer_name], tf.float32.as_datatype_enum)
  graph_def = TransformGraph(graph_def, [input_layer_name], [output_layer_name], ["sort_by_execution_order"])

  with tf.gfile.GFile(pbfilename, 'wb') as f:
      f.write(graph_def.SerializeToString())

  # write model to logs dir so we can visualize it as:
  # tensorboard --logdir="logs"

  if verbose:
      writer = tf.summary.FileWriter('logs', graph_def)
      writer.close()

  # tidy up tmp files

  for f in glob.glob("model-tmp.tfl*"):
      os.remove(f)

  os.remove('checkpoint')

################################################################################
# convert a binary .pb protocol buffer format model to tflite format

# e.g. for FireNet
#    pbfilename = "firenet.pb"
#    input_layer_name = 'InputData/X'                  # input layer of network
#    output_layer_name= 'FullyConnected_2/Softmax'     # output layer of network
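
A sketch of that .pb-to-tflite conversion using the TF 1.x converter API (tf.lite.TFLiteConverter.from_frozen_graph is available in recent TF 1.x releases; the FireNet names come from the comments above):

# Sketch: convert the frozen .pb to TFLite with the TF 1.x converter API.
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    'firenet.pb',                      # pbfilename
    ['InputData/X'],                   # input layer of network
    ['FullyConnected_2/Softmax'])      # output layer of network
tflite_model = converter.convert()
with open('firenet.tflite', 'wb') as f:
    f.write(tflite_model)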